package api

import (
	"context"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/google/uuid"

	"github.com/vasyakrg/dns-autoresolver/internal/diff"
	"github.com/vasyakrg/dns-autoresolver/internal/service"
	"github.com/vasyakrg/dns-autoresolver/internal/store"
)

// --- shared test doubles (also used by api_test.go / tenant_test.go) ---

// stubSessions is a SessionManager test double. Each test supplies
// validateFn to decide how Validate behaves. Create and Destroy always
// succeed. A stub used with Validate must be built with a non-nil
// validateFn.
type stubSessions struct {
	validateFn func(ctx context.Context, token string) (uuid.UUID, error)
}

// Create ignores its arguments and hands back a fixed token that expires
// one hour after the call.
func (ss stubSessions) Create(context.Context, uuid.UUID) (string, time.Time, error) {
	expiresAt := time.Now().Add(time.Hour)
	return "stub-token", expiresAt, nil
}

// Validate delegates straight to the configured validateFn.
func (ss stubSessions) Validate(ctx context.Context, token string) (uuid.UUID, error) {
	validate := ss.validateFn
	return validate(ctx, token)
}

// Destroy is a no-op that always reports success.
func (ss stubSessions) Destroy(context.Context, string) error {
	return nil
}

// alwaysValidSessions returns a stubSessions that accepts every token and
// reports uid as its owner. Use it in tests that only care about what
// happens after RequireAuth has let the request through.
func alwaysValidSessions(uid uuid.UUID) stubSessions {
	accept := func(context.Context, string) (uuid.UUID, error) {
		return uid, nil
	}
	return stubSessions{validateFn: accept}
}

// stubAuthStore is a configurable AuthStore test double. Only
// GetProjectOwned is configurable; every other method returns zero values.
// Tests of register/login/me use their own dedicated mock (mockAuthStore in
// auth_test.go).
type stubAuthStore struct { getProjectOwnedFn func(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) } func (s stubAuthStore) RegisterUser(context.Context, string, string) (store.User, store.Project, error) { return store.User{}, store.Project{}, nil } func (s stubAuthStore) GetUserByEmail(context.Context, string) (store.User, error) { return store.User{}, nil } func (s stubAuthStore) GetUserByID(context.Context, uuid.UUID) (store.User, error) { return store.User{}, nil } func (s stubAuthStore) GetUserProject(context.Context, uuid.UUID) (store.Project, error) { return store.Project{}, nil } func (s stubAuthStore) GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) { return s.getProjectOwnedFn(ctx, projectID, userID) } // alwaysOwnedAuthStore builds a stubAuthStore whose GetProjectOwned always // succeeds, treating the caller as the owner of whatever project id is // requested — for tests that only care about behavior past the // RequireProjectAccess boundary (CRUD/check/apply tests). func alwaysOwnedAuthStore() stubAuthStore { return stubAuthStore{getProjectOwnedFn: func(_ context.Context, pid, uid uuid.UUID) (store.Project, error) { return store.Project{ID: pid, UserID: uid}, nil }} } // requestWithSessionCookie builds an httptest request carrying a session // cookie so it clears RequireAuth in tests using stubSessions/alwaysValidSessions // (which ignore the token value and validate unconditionally). func requestWithSessionCookie(method, url string, body io.Reader) *http.Request { req := httptest.NewRequest(method, url, body) req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "test-session-token"}) return req } // recordingCheckApplier is a CheckApplier test double that records whether // Check/Apply were invoked — used by the IDOR regression test to assert the // service layer is never reached for a project the caller doesn't own. 
type recordingCheckApplier struct { checkCalled bool applyCalled bool } func (r *recordingCheckApplier) Check(context.Context, uuid.UUID, uuid.UUID) (diff.Changeset, error) { r.checkCalled = true return diff.Changeset{}, nil } func (r *recordingCheckApplier) Apply(context.Context, uuid.UUID, uuid.UUID, service.ApplyRequest) (diff.Changeset, error) { r.applyCalled = true return diff.Changeset{}, nil } // --- RequireAuth --- func TestRequireAuth_NoCookie_Returns401(t *testing.T) { a := &API{Sessions: stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) { t.Fatal("Validate must not be called when there is no cookie") return uuid.Nil, nil }}} nextCalled := false next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true }) req := httptest.NewRequest(http.MethodGet, "/whatever", nil) w := httptest.NewRecorder() a.RequireAuth(next).ServeHTTP(w, req) if w.Code != http.StatusUnauthorized { t.Fatalf("expected 401, got %d", w.Code) } if nextCalled { t.Fatal("next must not be called without a session cookie") } } func TestRequireAuth_InvalidToken_Returns401(t *testing.T) { a := &API{Sessions: stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) { return uuid.Nil, errors.New("invalid or expired session") }}} nextCalled := false next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true }) req := requestWithSessionCookie(http.MethodGet, "/whatever", nil) w := httptest.NewRecorder() a.RequireAuth(next).ServeHTTP(w, req) if w.Code != http.StatusUnauthorized { t.Fatalf("expected 401, got %d", w.Code) } if nextCalled { t.Fatal("next must not be called when Sessions.Validate fails") } } func TestRequireAuth_Success_CallsNextWithUserID(t *testing.T) { uid := uuid.New() a := &API{Sessions: alwaysValidSessions(uid)} var gotUID uuid.UUID var gotOK bool next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotUID, gotOK = userIDFrom(r.Context()) w.WriteHeader(http.StatusOK) 
}) req := requestWithSessionCookie(http.MethodGet, "/whatever", nil) w := httptest.NewRecorder() a.RequireAuth(next).ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("expected 200 (next called), got %d", w.Code) } if !gotOK || gotUID != uid { t.Fatalf("expected userIDFrom to yield %s, got %s (ok=%v)", uid, gotUID, gotOK) } } // --- RequireProjectAccess --- func TestRequireProjectAccess_NotOwned_Returns404(t *testing.T) { a := &API{Auth: stubAuthStore{getProjectOwnedFn: func(context.Context, uuid.UUID, uuid.UUID) (store.Project, error) { return store.Project{}, errors.New("no rows") }}} nextCalled := false next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true }) router := chi.NewRouter() router.With(a.RequireProjectAccess).Get("/projects/{pid}", next.ServeHTTP) req := httptest.NewRequest(http.MethodGet, "/projects/"+uuid.New().String(), nil) req = req.WithContext(context.WithValue(req.Context(), ctxUserID, uuid.New())) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusNotFound { t.Fatalf("expected 404 for an unowned/foreign project, got %d body %s", w.Code, w.Body.String()) } if nextCalled { t.Fatal("next must not be called when the project isn't owned by the caller") } } func TestRequireProjectAccess_Success_CallsNextWithProjectID(t *testing.T) { a := &API{Auth: alwaysOwnedAuthStore()} pid := uuid.New() uid := uuid.New() var gotPID uuid.UUID var gotOK bool next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotPID, gotOK = projectIDFrom(r.Context()) w.WriteHeader(http.StatusOK) }) router := chi.NewRouter() router.With(a.RequireProjectAccess).Get("/projects/{pid}", next.ServeHTTP) req := httptest.NewRequest(http.MethodGet, "/projects/"+pid.String(), nil) req = req.WithContext(context.WithValue(req.Context(), ctxUserID, uid)) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("expected 200 (next called), got %d body %s", w.Code, 
w.Body.String()) } if !gotOK || gotPID != pid { t.Fatalf("expected projectIDFrom to yield %s, got %s (ok=%v)", pid, gotPID, gotOK) } } // --- IDOR regression --- // TestIDOR_CheckForeignProject_Returns404AndServiceNotCalled is the IDOR // regression required by Task 4: a request for project B's domain, made // with user A's session (who owns project A, not B), must be rejected by // RequireProjectAccess with 404 before service.Check is ever invoked — a // caller must never be able to check/apply another tenant's domain by // guessing/leaking its ids. func TestIDOR_CheckForeignProject_Returns404AndServiceNotCalled(t *testing.T) { userA := uuid.New() pidA := uuid.New() pidB := uuid.New() // owned by a different user; not A svc := &recordingCheckApplier{} auth := stubAuthStore{getProjectOwnedFn: func(_ context.Context, pid, uid uuid.UUID) (store.Project, error) { if pid == pidA && uid == userA { return store.Project{ID: pidA, UserID: userA}, nil } return store.Project{}, errors.New("not found") }} a := &API{Svc: svc, Auth: auth, Sessions: alwaysValidSessions(userA)} router := NewRouter(a) did := uuid.New().String() req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+pidB.String()+"/domains/"+did+"/check", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusNotFound { t.Fatalf("expected 404 for user A requesting project B's domain, got %d body %s", w.Code, w.Body.String()) } if svc.checkCalled { t.Fatal("service.Check must not be called when the caller doesn't own the project") } // Sanity check: the same user against their own project succeeds and // does reach the service — proving the 404 above is really about // project ownership, not e.g. a broken route. 
req2 := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+pidA.String()+"/domains/"+did+"/check", nil) w2 := httptest.NewRecorder() router.ServeHTTP(w2, req2) if w2.Code != http.StatusOK { t.Fatalf("expected 200 for user A requesting their own project, got %d body %s", w2.Code, w2.Body.String()) } if !svc.checkCalled { t.Fatal("expected service.Check to be called for the caller's own project") } }