266 lines
9.5 KiB
Go
266 lines
9.5 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/google/uuid"
|
|
|
|
"github.com/vasyakrg/dns-autoresolver/internal/diff"
|
|
"github.com/vasyakrg/dns-autoresolver/internal/service"
|
|
"github.com/vasyakrg/dns-autoresolver/internal/store"
|
|
)
|
|
|
|
// --- shared test doubles (also used by api_test.go / tenant_test.go) ---
|
|
|
|
// stubSessions is a configurable SessionManager test double.
|
|
type stubSessions struct {
|
|
validateFn func(ctx context.Context, token string) (uuid.UUID, error)
|
|
}
|
|
|
|
func (s stubSessions) Create(context.Context, uuid.UUID) (string, time.Time, error) {
|
|
return "stub-token", time.Now().Add(time.Hour), nil
|
|
}
|
|
func (s stubSessions) Validate(ctx context.Context, token string) (uuid.UUID, error) {
|
|
return s.validateFn(ctx, token)
|
|
}
|
|
func (s stubSessions) Destroy(context.Context, string) error { return nil }
|
|
|
|
// alwaysValidSessions builds a stubSessions whose Validate always succeeds
|
|
// with the given fixed user ID — for tests that only care about behavior
|
|
// past the RequireAuth boundary.
|
|
func alwaysValidSessions(uid uuid.UUID) stubSessions {
|
|
return stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) { return uid, nil }}
|
|
}
|
|
|
|
// stubAuthStore is a configurable AuthStore test double. Methods besides
|
|
// GetProjectOwned return zero values — tests exercising register/login/me
|
|
// use their own dedicated mock (mockAuthStore in auth_test.go).
|
|
type stubAuthStore struct {
|
|
getProjectOwnedFn func(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error)
|
|
}
|
|
|
|
func (s stubAuthStore) RegisterUser(context.Context, string, string) (store.User, store.Project, error) {
|
|
return store.User{}, store.Project{}, nil
|
|
}
|
|
func (s stubAuthStore) GetUserByEmail(context.Context, string) (store.User, error) {
|
|
return store.User{}, nil
|
|
}
|
|
func (s stubAuthStore) GetUserByID(context.Context, uuid.UUID) (store.User, error) {
|
|
return store.User{}, nil
|
|
}
|
|
func (s stubAuthStore) GetUserProject(context.Context, uuid.UUID) (store.Project, error) {
|
|
return store.Project{}, nil
|
|
}
|
|
func (s stubAuthStore) GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) {
|
|
return s.getProjectOwnedFn(ctx, projectID, userID)
|
|
}
|
|
|
|
// alwaysOwnedAuthStore builds a stubAuthStore whose GetProjectOwned always
|
|
// succeeds, treating the caller as the owner of whatever project id is
|
|
// requested — for tests that only care about behavior past the
|
|
// RequireProjectAccess boundary (CRUD/check/apply tests).
|
|
func alwaysOwnedAuthStore() stubAuthStore {
|
|
return stubAuthStore{getProjectOwnedFn: func(_ context.Context, pid, uid uuid.UUID) (store.Project, error) {
|
|
return store.Project{ID: pid, UserID: uid}, nil
|
|
}}
|
|
}
|
|
|
|
// requestWithSessionCookie builds an httptest request carrying a session
|
|
// cookie so it clears RequireAuth in tests using stubSessions/alwaysValidSessions
|
|
// (which ignore the token value and validate unconditionally).
|
|
func requestWithSessionCookie(method, url string, body io.Reader) *http.Request {
|
|
req := httptest.NewRequest(method, url, body)
|
|
req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "test-session-token"})
|
|
return req
|
|
}
|
|
|
|
// recordingCheckApplier is a CheckApplier test double that records whether
|
|
// Check/Apply were invoked — used by the IDOR regression test to assert the
|
|
// service layer is never reached for a project the caller doesn't own.
|
|
type recordingCheckApplier struct {
|
|
checkCalled bool
|
|
applyCalled bool
|
|
}
|
|
|
|
func (r *recordingCheckApplier) Check(context.Context, uuid.UUID, uuid.UUID) (diff.Changeset, error) {
|
|
r.checkCalled = true
|
|
return diff.Changeset{}, nil
|
|
}
|
|
func (r *recordingCheckApplier) Apply(context.Context, uuid.UUID, uuid.UUID, service.ApplyRequest) (diff.Changeset, error) {
|
|
r.applyCalled = true
|
|
return diff.Changeset{}, nil
|
|
}
|
|
|
|
// --- RequireAuth ---
|
|
|
|
func TestRequireAuth_NoCookie_Returns401(t *testing.T) {
|
|
a := &API{Sessions: stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) {
|
|
t.Fatal("Validate must not be called when there is no cookie")
|
|
return uuid.Nil, nil
|
|
}}}
|
|
nextCalled := false
|
|
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true })
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/whatever", nil)
|
|
w := httptest.NewRecorder()
|
|
a.RequireAuth(next).ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusUnauthorized {
|
|
t.Fatalf("expected 401, got %d", w.Code)
|
|
}
|
|
if nextCalled {
|
|
t.Fatal("next must not be called without a session cookie")
|
|
}
|
|
}
|
|
|
|
func TestRequireAuth_InvalidToken_Returns401(t *testing.T) {
|
|
a := &API{Sessions: stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) {
|
|
return uuid.Nil, errors.New("invalid or expired session")
|
|
}}}
|
|
nextCalled := false
|
|
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true })
|
|
|
|
req := requestWithSessionCookie(http.MethodGet, "/whatever", nil)
|
|
w := httptest.NewRecorder()
|
|
a.RequireAuth(next).ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusUnauthorized {
|
|
t.Fatalf("expected 401, got %d", w.Code)
|
|
}
|
|
if nextCalled {
|
|
t.Fatal("next must not be called when Sessions.Validate fails")
|
|
}
|
|
}
|
|
|
|
func TestRequireAuth_Success_CallsNextWithUserID(t *testing.T) {
|
|
uid := uuid.New()
|
|
a := &API{Sessions: alwaysValidSessions(uid)}
|
|
|
|
var gotUID uuid.UUID
|
|
var gotOK bool
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
gotUID, gotOK = userIDFrom(r.Context())
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
req := requestWithSessionCookie(http.MethodGet, "/whatever", nil)
|
|
w := httptest.NewRecorder()
|
|
a.RequireAuth(next).ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("expected 200 (next called), got %d", w.Code)
|
|
}
|
|
if !gotOK || gotUID != uid {
|
|
t.Fatalf("expected userIDFrom to yield %s, got %s (ok=%v)", uid, gotUID, gotOK)
|
|
}
|
|
}
|
|
|
|
// --- RequireProjectAccess ---
|
|
|
|
func TestRequireProjectAccess_NotOwned_Returns404(t *testing.T) {
|
|
a := &API{Auth: stubAuthStore{getProjectOwnedFn: func(context.Context, uuid.UUID, uuid.UUID) (store.Project, error) {
|
|
return store.Project{}, errors.New("no rows")
|
|
}}}
|
|
nextCalled := false
|
|
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true })
|
|
|
|
router := chi.NewRouter()
|
|
router.With(a.RequireProjectAccess).Get("/projects/{pid}", next.ServeHTTP)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/projects/"+uuid.New().String(), nil)
|
|
req = req.WithContext(context.WithValue(req.Context(), ctxUserID, uuid.New()))
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusNotFound {
|
|
t.Fatalf("expected 404 for an unowned/foreign project, got %d body %s", w.Code, w.Body.String())
|
|
}
|
|
if nextCalled {
|
|
t.Fatal("next must not be called when the project isn't owned by the caller")
|
|
}
|
|
}
|
|
|
|
func TestRequireProjectAccess_Success_CallsNextWithProjectID(t *testing.T) {
|
|
a := &API{Auth: alwaysOwnedAuthStore()}
|
|
pid := uuid.New()
|
|
uid := uuid.New()
|
|
|
|
var gotPID uuid.UUID
|
|
var gotOK bool
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
gotPID, gotOK = projectIDFrom(r.Context())
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
router := chi.NewRouter()
|
|
router.With(a.RequireProjectAccess).Get("/projects/{pid}", next.ServeHTTP)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/projects/"+pid.String(), nil)
|
|
req = req.WithContext(context.WithValue(req.Context(), ctxUserID, uid))
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("expected 200 (next called), got %d body %s", w.Code, w.Body.String())
|
|
}
|
|
if !gotOK || gotPID != pid {
|
|
t.Fatalf("expected projectIDFrom to yield %s, got %s (ok=%v)", pid, gotPID, gotOK)
|
|
}
|
|
}
|
|
|
|
// --- IDOR regression ---
|
|
|
|
// TestIDOR_CheckForeignProject_Returns404AndServiceNotCalled is the IDOR
|
|
// regression required by Task 4: a request for project B's domain, made
|
|
// with user A's session (who owns project A, not B), must be rejected by
|
|
// RequireProjectAccess with 404 before service.Check is ever invoked — a
|
|
// caller must never be able to check/apply another tenant's domain by
|
|
// guessing/leaking its ids.
|
|
func TestIDOR_CheckForeignProject_Returns404AndServiceNotCalled(t *testing.T) {
|
|
userA := uuid.New()
|
|
pidA := uuid.New()
|
|
pidB := uuid.New() // owned by a different user; not A
|
|
|
|
svc := &recordingCheckApplier{}
|
|
auth := stubAuthStore{getProjectOwnedFn: func(_ context.Context, pid, uid uuid.UUID) (store.Project, error) {
|
|
if pid == pidA && uid == userA {
|
|
return store.Project{ID: pidA, UserID: userA}, nil
|
|
}
|
|
return store.Project{}, errors.New("not found")
|
|
}}
|
|
a := &API{Svc: svc, Auth: auth, Sessions: alwaysValidSessions(userA)}
|
|
router := NewRouter(a)
|
|
|
|
did := uuid.New().String()
|
|
req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+pidB.String()+"/domains/"+did+"/check", nil)
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusNotFound {
|
|
t.Fatalf("expected 404 for user A requesting project B's domain, got %d body %s", w.Code, w.Body.String())
|
|
}
|
|
if svc.checkCalled {
|
|
t.Fatal("service.Check must not be called when the caller doesn't own the project")
|
|
}
|
|
|
|
// Sanity check: the same user against their own project succeeds and
|
|
// does reach the service — proving the 404 above is really about
|
|
// project ownership, not e.g. a broken route.
|
|
req2 := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+pidA.String()+"/domains/"+did+"/check", nil)
|
|
w2 := httptest.NewRecorder()
|
|
router.ServeHTTP(w2, req2)
|
|
if w2.Code != http.StatusOK {
|
|
t.Fatalf("expected 200 for user A requesting their own project, got %d body %s", w2.Code, w2.Body.String())
|
|
}
|
|
if !svc.checkCalled {
|
|
t.Fatal("expected service.Check to be called for the caller's own project")
|
|
}
|
|
}
|