diff --git a/internal/api/api.go b/internal/api/api.go index 16710cd..cd01cf4 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -18,8 +18,8 @@ import ( // CheckApplier is the service surface the API depends on. type CheckApplier interface { - Check(ctx context.Context, domainID uuid.UUID) (diff.Changeset, error) - Apply(ctx context.Context, domainID uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) + Check(ctx context.Context, projectID, domainID uuid.UUID) (diff.Changeset, error) + Apply(ctx context.Context, projectID, domainID uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) } // TenantStore is the narrow persistence surface the CRUD handlers depend on. @@ -63,6 +63,10 @@ type AuthStore interface { GetUserByEmail(ctx context.Context, email string) (store.User, error) GetUserByID(ctx context.Context, userID uuid.UUID) (store.User, error) GetUserProject(ctx context.Context, userID uuid.UUID) (store.Project, error) + // GetProjectOwned looks up projectID and returns it only if it's owned by + // userID — RequireProjectAccess uses this to reject foreign/nonexistent + // projects with 404 before any handler runs. + GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) } // SessionManager creates/validates/destroys login sessions. *auth.Sessions @@ -91,11 +95,17 @@ func NewRouter(a *API) http.Handler { r.Route("/api/v1/auth", func(r chi.Router) { r.Post("/register", a.handleRegister) r.Post("/login", a.handleLogin) - r.Post("/logout", a.handleLogout) // защитится RequireAuth в Task 4 - r.Get("/me", a.handleMe) // защитится RequireAuth в Task 4 + r.Group(func(r chi.Router) { + r.Use(a.RequireAuth) + r.Post("/logout", a.handleLogout) + r.Get("/me", a.handleMe) + }) }) r.Route("/api/v1/projects/{pid}", func(r chi.Router) { + r.Use(a.RequireAuth) + r.Use(a.RequireProjectAccess) + r.Route("/domains", func(r chi.Router) { r.Post("/", a.handleCreateDomain) r.Get("/", a.handleListDomains) diff --git a/internal/api/api_test.go b/internal/api/api_test.go index 0a48ee0..876d06b 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -20,18 +20,23 @@ type mockCheckApplier struct { lastReq service.ApplyRequest } -func (m *mockCheckApplier) Check(context.Context, uuid.UUID) (diff.Changeset, error) { +func (m *mockCheckApplier) Check(context.Context, uuid.UUID, uuid.UUID) (diff.Changeset, error) { d := model.Record{Type: model.A, Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}} return diff.Changeset{Diffs: []diff.RecordDiff{{Kind: diff.Add, Type: d.Type, Name: d.Name, Desired: &d}}}, nil } -func (m *mockCheckApplier) Apply(_ context.Context, _ uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) { +func (m *mockCheckApplier) Apply(_ context.Context, _, _ uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) { m.lastReq = req return diff.Changeset{}, nil } +// newTestAPI wires a fixed authenticated user who owns whatever project id +// is requested (via alwaysOwnedAuthStore/alwaysValidSessions in +// middleware_test.go) — these tests exercise check/apply behavior past the +// RequireAuth/RequireProjectAccess boundary, which is covered separately by +// middleware_test.go's own tests and the IDOR regression. func newTestAPI() (*API, *mockCheckApplier) { m := &mockCheckApplier{} - return &API{Svc: m}, m // остальные зависимости (store/cipher) nil — CRUD-тесты добавит реализатор + return &API{Svc: m, Auth: alwaysOwnedAuthStore(), Sessions: alwaysValidSessions(uuid.New())}, m } func TestCheckEndpoint(t *testing.T) { @@ -39,7 +44,7 @@ func TestCheckEndpoint(t *testing.T) { router := NewRouter(a) did := uuid.New().String() - req := httptest.NewRequest(http.MethodGet, + req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/check", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -62,7 +67,7 @@ func TestApplyDefaultsPruneFalse(t *testing.T) { did := uuid.New().String() body := `{"applyUpdates":true}` // applyPrunes отсутствует → false - req := httptest.NewRequest(http.MethodPost, + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply", strings.NewReader(body)) w := httptest.NewRecorder() @@ -81,7 +86,7 @@ func TestApplyEmptyBodyOK(t *testing.T) { router := NewRouter(a) did := uuid.New().String() - req := httptest.NewRequest(http.MethodPost, + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -100,7 +105,7 @@ func TestApplyMalformedBody(t *testing.T) { did := uuid.New().String() body := `{"applyUpdates":` - req := httptest.NewRequest(http.MethodPost, + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply", strings.NewReader(body)) w := httptest.NewRecorder() @@ -114,7 +119,7 @@ func TestApplyMalformedBody(t *testing.T) { func TestApplyBadUUID(t *testing.T) { a, _ := newTestAPI() router := NewRouter(a) - req := httptest.NewRequest(http.MethodPost, + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/not-a-uuid/apply", bytes.NewReader([]byte(`{}`))) w := httptest.NewRecorder() diff --git a/internal/api/auth_handlers.go b/internal/api/auth_handlers.go index 76d3c38..60ff2d3 100644 --- a/internal/api/auth_handlers.go +++ b/internal/api/auth_handlers.go @@ -1,15 +1,12 @@ package api import ( - "context" "errors" "log" "net/http" "strings" "time" - "github.com/google/uuid" - "github.com/vasyakrg/dns-autoresolver/internal/auth" "github.com/vasyakrg/dns-autoresolver/internal/store" ) @@ -39,19 +36,6 @@ func normalizeEmail(email string) string { return strings.ToLower(strings.TrimSpace(email)) } -// ctxKeyUserID is a private context key carrying the authenticated user's ID. -// Task 4's RequireAuth middleware sets it after validating the session -// cookie; handleMe reads it back. -type ctxKeyUserID struct{} - -// userIDFromContext extracts the authenticated user ID set by RequireAuth -// (Task 4). Until that middleware is wired in, tests set it directly via -// context.WithValue. -func userIDFromContext(ctx context.Context) (uuid.UUID, bool) { - id, ok := ctx.Value(ctxKeyUserID{}).(uuid.UUID) - return id, ok -} - func setSessionCookie(w http.ResponseWriter, token string, exp time.Time) { http.SetCookie(w, &http.Cookie{ Name: sessionCookieName, Value: token, Path: "/", @@ -165,11 +149,10 @@ func (a *API) handleLogout(w http.ResponseWriter, r *http.Request) { } // handleMe returns the authenticated caller's identity + default project. -// The user ID comes from the request context, set by Task 4's RequireAuth -// middleware after validating the session cookie (tests set it directly via -// context.WithValue in the interim). +// The user ID comes from the request context, set by RequireAuth after +// validating the session cookie. func (a *API) handleMe(w http.ResponseWriter, r *http.Request) { - userID, ok := userIDFromContext(r.Context()) + userID, ok := userIDFrom(r.Context()) if !ok { writeErr(w, http.StatusUnauthorized, "authentication required") return diff --git a/internal/api/auth_test.go b/internal/api/auth_test.go index 58f9402..5525fc7 100644 --- a/internal/api/auth_test.go +++ b/internal/api/auth_test.go @@ -18,10 +18,11 @@ import ( // --- mocks --- type mockAuthStore struct { - registerUserFn func(ctx context.Context, email, passwordHash string) (store.User, store.Project, error) - getUserByEmailFn func(ctx context.Context, email string) (store.User, error) - getUserByIDFn func(ctx context.Context, userID uuid.UUID) (store.User, error) - getUserProjectFn func(ctx context.Context, userID uuid.UUID) (store.Project, error) + registerUserFn func(ctx context.Context, email, passwordHash string) (store.User, store.Project, error) + getUserByEmailFn func(ctx context.Context, email string) (store.User, error) + getUserByIDFn func(ctx context.Context, userID uuid.UUID) (store.User, error) + getUserProjectFn func(ctx context.Context, userID uuid.UUID) (store.Project, error) + getProjectOwnedFn func(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) } func (m *mockAuthStore) RegisterUser(ctx context.Context, email, passwordHash string) (store.User, store.Project, error) { @@ -40,8 +41,13 @@ func (m *mockAuthStore) GetUserProject(ctx context.Context, userID uuid.UUID) (s return m.getUserProjectFn(ctx, userID) } +func (m *mockAuthStore) GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) { + return m.getProjectOwnedFn(ctx, projectID, userID) +} + type mockSessionManager struct { - createFn func(ctx context.Context, userID uuid.UUID) (string, time.Time, error) + createFn func(ctx context.Context, userID uuid.UUID) (string, time.Time, error) + validateFn func(ctx context.Context, token string) (uuid.UUID, error) destroyCalled bool destroyToken string @@ -52,7 +58,10 @@ func (m *mockSessionManager) Create(ctx context.Context, userID uuid.UUID) (stri return m.createFn(ctx, userID) } -func (m *mockSessionManager) Validate(context.Context, string) (uuid.UUID, error) { +func (m *mockSessionManager) Validate(ctx context.Context, token string) (uuid.UUID, error) { + if m.validateFn != nil { + return m.validateFn(ctx, token) + } return uuid.Nil, nil } @@ -338,7 +347,7 @@ func TestAuthLogout_ClearsSessionAndDestroys(t *testing.T) { // now resolves the authenticated user via GetUserByID and returns their real // email, instead of leaving it blank. func TestAuthMe_ReturnsRealEmail(t *testing.T) { - a, authStore, _ := newTestAuthAPI() + a, authStore, sessions := newTestAuthAPI() userID := uuid.New() projectID := uuid.New() authStore.getUserByIDFn = func(_ context.Context, id uuid.UUID) (store.User, error) { @@ -350,10 +359,15 @@ func TestAuthMe_ReturnsRealEmail(t *testing.T) { authStore.getUserProjectFn = func(_ context.Context, uid uuid.UUID) (store.Project, error) { return store.Project{ID: projectID, UserID: uid, Name: "default"}, nil } + // /me is now behind RequireAuth: the session cookie must resolve to + // userID via Sessions.Validate rather than being injected directly. + sessions.validateFn = func(context.Context, string) (uuid.UUID, error) { + return userID, nil + } router := NewRouter(a) req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil) - req = req.WithContext(context.WithValue(req.Context(), ctxKeyUserID{}, userID)) + req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "some-valid-token"}) w := httptest.NewRecorder() router.ServeHTTP(w, req) diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 2815dec..3553bfa 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -24,12 +24,15 @@ func writeErr(w http.ResponseWriter, status int, msg string) { } func (a *API) handleCheck(w http.ResponseWriter, r *http.Request) { + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) did, err := uuid.Parse(chi.URLParam(r, "did")) if err != nil { writeErr(w, http.StatusBadRequest, "invalid domain id") return } - cs, err := a.Svc.Check(r.Context(), did) + cs, err := a.Svc.Check(r.Context(), pid, did) if err != nil { log.Printf("api: check failed: %v", err) writeErr(w, http.StatusInternalServerError, "internal error") @@ -39,6 +42,9 @@ func (a *API) handleCheck(w http.ResponseWriter, r *http.Request) { } func (a *API) handleApply(w http.ResponseWriter, r *http.Request) { + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) did, err := uuid.Parse(chi.URLParam(r, "did")) if err != nil { writeErr(w, http.StatusBadRequest, "invalid domain id") @@ -54,7 +60,7 @@ func (a *API) handleApply(w http.ResponseWriter, r *http.Request) { return } } - cs, err := a.Svc.Apply(r.Context(), did, service.ApplyRequest{ + cs, err := a.Svc.Apply(r.Context(), pid, did, service.ApplyRequest{ ApplyUpdates: req.ApplyUpdates, ApplyPrunes: req.ApplyPrunes, }) if err != nil { diff --git a/internal/api/middleware.go b/internal/api/middleware.go new file mode 100644 index 0000000..339e171 --- /dev/null +++ b/internal/api/middleware.go @@ -0,0 +1,75 @@ +package api + +import ( + "context" + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/google/uuid" +) + +type ctxKey string + +const ( + ctxUserID ctxKey = "userID" + ctxProjectID ctxKey = "projectID" +) + +// RequireAuth validates the session cookie and, on success, stores the +// authenticated user's ID in the request context (see userIDFrom). Any +// failure — missing cookie or an invalid/expired token — is rejected with +// 401 before reaching the wrapped handler. +func (a *API) RequireAuth(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := r.Cookie(sessionCookieName) + if err != nil { + writeErr(w, http.StatusUnauthorized, "unauthorized") + return + } + uid, err := a.Sessions.Validate(r.Context(), c.Value) + if err != nil { + writeErr(w, http.StatusUnauthorized, "unauthorized") + return + } + ctx := context.WithValue(r.Context(), ctxUserID, uid) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// RequireProjectAccess verifies that the {pid} URL segment names a project +// owned by the authenticated user (set by RequireAuth, which must run +// first) and, on success, stores the project ID in the request context (see +// projectIDFrom). A project that doesn't exist or isn't owned by the caller +// is rejected with 404 — not 403 — so a caller can't distinguish "not yours" +// from "doesn't exist" (closes IDOR by never confirming another tenant's +// project exists). +func (a *API) RequireProjectAccess(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + uid, ok := userIDFrom(r.Context()) + if !ok { + writeErr(w, http.StatusUnauthorized, "unauthorized") + return + } + pid, err := uuid.Parse(chi.URLParam(r, "pid")) + if err != nil { + writeErr(w, http.StatusBadRequest, "invalid project id") + return + } + if _, err := a.Auth.GetProjectOwned(r.Context(), pid, uid); err != nil { + writeErr(w, http.StatusNotFound, "not found") + return + } + ctx := context.WithValue(r.Context(), ctxProjectID, pid) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func userIDFrom(ctx context.Context) (uuid.UUID, bool) { + v, ok := ctx.Value(ctxUserID).(uuid.UUID) + return v, ok +} + +func projectIDFrom(ctx context.Context) (uuid.UUID, bool) { + v, ok := ctx.Value(ctxProjectID).(uuid.UUID) + return v, ok +} diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go new file mode 100644 index 0000000..b24ecbd --- /dev/null +++ b/internal/api/middleware_test.go @@ -0,0 +1,265 @@ +package api + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + + "github.com/vasyakrg/dns-autoresolver/internal/diff" + "github.com/vasyakrg/dns-autoresolver/internal/service" + "github.com/vasyakrg/dns-autoresolver/internal/store" +) + +// --- shared test doubles (also used by api_test.go / tenant_test.go) --- + +// stubSessions is a configurable SessionManager test double. +type stubSessions struct { + validateFn func(ctx context.Context, token string) (uuid.UUID, error) +} + +func (s stubSessions) Create(context.Context, uuid.UUID) (string, time.Time, error) { + return "stub-token", time.Now().Add(time.Hour), nil +} +func (s stubSessions) Validate(ctx context.Context, token string) (uuid.UUID, error) { + return s.validateFn(ctx, token) +} +func (s stubSessions) Destroy(context.Context, string) error { return nil } + +// alwaysValidSessions builds a stubSessions whose Validate always succeeds +// with the given fixed user ID — for tests that only care about behavior +// past the RequireAuth boundary. +func alwaysValidSessions(uid uuid.UUID) stubSessions { + return stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) { return uid, nil }} +} + +// stubAuthStore is a configurable AuthStore test double. Methods besides +// GetProjectOwned return zero values — tests exercising register/login/me +// use their own dedicated mock (mockAuthStore in auth_test.go). +type stubAuthStore struct { + getProjectOwnedFn func(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) +} + +func (s stubAuthStore) RegisterUser(context.Context, string, string) (store.User, store.Project, error) { + return store.User{}, store.Project{}, nil +} +func (s stubAuthStore) GetUserByEmail(context.Context, string) (store.User, error) { + return store.User{}, nil +} +func (s stubAuthStore) GetUserByID(context.Context, uuid.UUID) (store.User, error) { + return store.User{}, nil +} +func (s stubAuthStore) GetUserProject(context.Context, uuid.UUID) (store.Project, error) { + return store.Project{}, nil +} +func (s stubAuthStore) GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) { + return s.getProjectOwnedFn(ctx, projectID, userID) +} + +// alwaysOwnedAuthStore builds a stubAuthStore whose GetProjectOwned always +// succeeds, treating the caller as the owner of whatever project id is +// requested — for tests that only care about behavior past the +// RequireProjectAccess boundary (CRUD/check/apply tests). +func alwaysOwnedAuthStore() stubAuthStore { + return stubAuthStore{getProjectOwnedFn: func(_ context.Context, pid, uid uuid.UUID) (store.Project, error) { + return store.Project{ID: pid, UserID: uid}, nil + }} +} + +// requestWithSessionCookie builds an httptest request carrying a session +// cookie so it clears RequireAuth in tests using stubSessions/alwaysValidSessions +// (which ignore the token value and validate unconditionally). +func requestWithSessionCookie(method, url string, body io.Reader) *http.Request { + req := httptest.NewRequest(method, url, body) + req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "test-session-token"}) + return req +} + +// recordingCheckApplier is a CheckApplier test double that records whether +// Check/Apply were invoked — used by the IDOR regression test to assert the +// service layer is never reached for a project the caller doesn't own. +type recordingCheckApplier struct { + checkCalled bool + applyCalled bool +} + +func (r *recordingCheckApplier) Check(context.Context, uuid.UUID, uuid.UUID) (diff.Changeset, error) { + r.checkCalled = true + return diff.Changeset{}, nil +} +func (r *recordingCheckApplier) Apply(context.Context, uuid.UUID, uuid.UUID, service.ApplyRequest) (diff.Changeset, error) { + r.applyCalled = true + return diff.Changeset{}, nil +} + +// --- RequireAuth --- + +func TestRequireAuth_NoCookie_Returns401(t *testing.T) { + a := &API{Sessions: stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) { + t.Fatal("Validate must not be called when there is no cookie") + return uuid.Nil, nil + }}} + nextCalled := false + next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true }) + + req := httptest.NewRequest(http.MethodGet, "/whatever", nil) + w := httptest.NewRecorder() + a.RequireAuth(next).ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", w.Code) + } + if nextCalled { + t.Fatal("next must not be called without a session cookie") + } +} + +func TestRequireAuth_InvalidToken_Returns401(t *testing.T) { + a := &API{Sessions: stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) { + return uuid.Nil, errors.New("invalid or expired session") + }}} + nextCalled := false + next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true }) + + req := requestWithSessionCookie(http.MethodGet, "/whatever", nil) + w := httptest.NewRecorder() + a.RequireAuth(next).ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", w.Code) + } + if nextCalled { + t.Fatal("next must not be called when Sessions.Validate fails") + } +} + +func TestRequireAuth_Success_CallsNextWithUserID(t *testing.T) { + uid := uuid.New() + a := &API{Sessions: alwaysValidSessions(uid)} + + var gotUID uuid.UUID + var gotOK bool + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotUID, gotOK = userIDFrom(r.Context()) + w.WriteHeader(http.StatusOK) + }) + + req := requestWithSessionCookie(http.MethodGet, "/whatever", nil) + w := httptest.NewRecorder() + a.RequireAuth(next).ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200 (next called), got %d", w.Code) + } + if !gotOK || gotUID != uid { + t.Fatalf("expected userIDFrom to yield %s, got %s (ok=%v)", uid, gotUID, gotOK) + } +} + +// --- RequireProjectAccess --- + +func TestRequireProjectAccess_NotOwned_Returns404(t *testing.T) { + a := &API{Auth: stubAuthStore{getProjectOwnedFn: func(context.Context, uuid.UUID, uuid.UUID) (store.Project, error) { + return store.Project{}, errors.New("no rows") + }}} + nextCalled := false + next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true }) + + router := chi.NewRouter() + router.With(a.RequireProjectAccess).Get("/projects/{pid}", next.ServeHTTP) + + req := httptest.NewRequest(http.MethodGet, "/projects/"+uuid.New().String(), nil) + req = req.WithContext(context.WithValue(req.Context(), ctxUserID, uuid.New())) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Fatalf("expected 404 for an unowned/foreign project, got %d body %s", w.Code, w.Body.String()) + } + if nextCalled { + t.Fatal("next must not be called when the project isn't owned by the caller") + } +} + +func TestRequireProjectAccess_Success_CallsNextWithProjectID(t *testing.T) { + a := &API{Auth: alwaysOwnedAuthStore()} + pid := uuid.New() + uid := uuid.New() + + var gotPID uuid.UUID + var gotOK bool + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPID, gotOK = projectIDFrom(r.Context()) + w.WriteHeader(http.StatusOK) + }) + + router := chi.NewRouter() + router.With(a.RequireProjectAccess).Get("/projects/{pid}", next.ServeHTTP) + + req := httptest.NewRequest(http.MethodGet, "/projects/"+pid.String(), nil) + req = req.WithContext(context.WithValue(req.Context(), ctxUserID, uid)) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200 (next called), got %d body %s", w.Code, w.Body.String()) + } + if !gotOK || gotPID != pid { + t.Fatalf("expected projectIDFrom to yield %s, got %s (ok=%v)", pid, gotPID, gotOK) + } +} + +// --- IDOR regression --- + +// TestIDOR_CheckForeignProject_Returns404AndServiceNotCalled is the IDOR +// regression required by Task 4: a request for project B's domain, made +// with user A's session (who owns project A, not B), must be rejected by +// RequireProjectAccess with 404 before service.Check is ever invoked — a +// caller must never be able to check/apply another tenant's domain by +// guessing/leaking its ids. +func TestIDOR_CheckForeignProject_Returns404AndServiceNotCalled(t *testing.T) { + userA := uuid.New() + pidA := uuid.New() + pidB := uuid.New() // owned by a different user; not A + + svc := &recordingCheckApplier{} + auth := stubAuthStore{getProjectOwnedFn: func(_ context.Context, pid, uid uuid.UUID) (store.Project, error) { + if pid == pidA && uid == userA { + return store.Project{ID: pidA, UserID: userA}, nil + } + return store.Project{}, errors.New("not found") + }} + a := &API{Svc: svc, Auth: auth, Sessions: alwaysValidSessions(userA)} + router := NewRouter(a) + + did := uuid.New().String() + req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+pidB.String()+"/domains/"+did+"/check", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Fatalf("expected 404 for user A requesting project B's domain, got %d body %s", w.Code, w.Body.String()) + } + if svc.checkCalled { + t.Fatal("service.Check must not be called when the caller doesn't own the project") + } + + // Sanity check: the same user against their own project succeeds and + // does reach the service — proving the 404 above is really about + // project ownership, not e.g. a broken route. + req2 := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+pidA.String()+"/domains/"+did+"/check", nil) + w2 := httptest.NewRecorder() + router.ServeHTTP(w2, req2) + if w2.Code != http.StatusOK { + t.Fatalf("expected 200 for user A requesting their own project, got %d body %s", w2.Code, w2.Body.String()) + } + if !svc.checkCalled { + t.Fatal("expected service.Check to be called for the caller's own project") + } +} diff --git a/internal/api/tenant_handlers.go b/internal/api/tenant_handlers.go index 02f3dcc..13711a4 100644 --- a/internal/api/tenant_handlers.go +++ b/internal/api/tenant_handlers.go @@ -24,11 +24,9 @@ func decodeBody(w http.ResponseWriter, r *http.Request, v any) bool { // --- accounts --- func (a *API) handleCreateAccount(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) var req accountRequest if !decodeBody(w, r, &req) { return @@ -53,11 +51,9 @@ func (a *API) handleCreateAccount(w http.ResponseWriter, r *http.Request) { } func (a *API) handleListAccounts(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) accs, err := a.Store.ListAccounts(r.Context(), pid) if err != nil { log.Printf("api: list accounts failed: %v", err) @@ -72,11 +68,9 @@ func (a *API) handleListAccounts(w http.ResponseWriter, r *http.Request) { } func (a *API) handleDeleteAccount(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) aid, err := uuid.Parse(chi.URLParam(r, "aid")) if err != nil { writeErr(w, http.StatusBadRequest, "invalid account id") @@ -93,11 +87,9 @@ func (a *API) handleDeleteAccount(w http.ResponseWriter, r *http.Request) { // handleImportZones lists zones from the provider for the given account and // creates one domain per zone (template_id left unset). func (a *API) handleImportZones(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) aid, err := uuid.Parse(chi.URLParam(r, "aid")) if err != nil { writeErr(w, http.StatusBadRequest, "invalid account id") @@ -145,11 +137,9 @@ func (a *API) handleImportZones(w http.ResponseWriter, r *http.Request) { // --- templates --- func (a *API) handleCreateTemplate(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) var req templateRequest if !decodeBody(w, r, &req) { return @@ -169,11 +159,9 @@ func (a *API) handleCreateTemplate(w http.ResponseWriter, r *http.Request) { } func (a *API) handleListTemplates(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) tpls, err := a.Store.ListTemplates(r.Context(), pid) if err != nil { log.Printf("api: list templates failed: %v", err) @@ -188,11 +176,9 @@ func (a *API) handleListTemplates(w http.ResponseWriter, r *http.Request) { } func (a *API) handleUpdateTemplate(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) tid, err := uuid.Parse(chi.URLParam(r, "tid")) if err != nil { writeErr(w, http.StatusBadRequest, "invalid template id") @@ -217,11 +203,9 @@ func (a *API) handleUpdateTemplate(w http.ResponseWriter, r *http.Request) { } func (a *API) handleDeleteTemplate(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) tid, err := uuid.Parse(chi.URLParam(r, "tid")) if err != nil { writeErr(w, http.StatusBadRequest, "invalid template id") @@ -238,11 +222,9 @@ func (a *API) handleDeleteTemplate(w http.ResponseWriter, r *http.Request) { // --- domains --- func (a *API) handleCreateDomain(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) var req domainRequest if !decodeBody(w, r, &req) { return @@ -286,11 +268,9 @@ func (a *API) handleCreateDomain(w http.ResponseWriter, r *http.Request) { } func (a *API) handleListDomains(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) doms, err := a.Store.ListDomains(r.Context(), pid) if err != nil { log.Printf("api: list domains failed: %v", err) @@ -308,11 +288,9 @@ func (a *API) handleListDomains(w http.ResponseWriter, r *http.Request) { // check/apply a domain — this is what makes an imported domain (which // starts with template_id=NULL) checkable, closing the import→check loop. func (a *API) handleSetDomainTemplate(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) did, err := uuid.Parse(chi.URLParam(r, "did")) if err != nil { writeErr(w, http.StatusBadRequest, "invalid domain id") @@ -339,11 +317,9 @@ func (a *API) handleSetDomainTemplate(w http.ResponseWriter, r *http.Request) { } func (a *API) handleDeleteDomain(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) did, err := uuid.Parse(chi.URLParam(r, "did")) if err != nil { writeErr(w, http.StatusBadRequest, "invalid domain id") diff --git a/internal/api/tenant_test.go b/internal/api/tenant_test.go index 2b807a8..73c9f89 100644 --- a/internal/api/tenant_test.go +++ b/internal/api/tenant_test.go @@ -132,7 +132,9 @@ func (m *mockTenantStore) ImportDomains(_ context.Context, projectID, accountID type mockCipher struct{} -func (mockCipher) Encrypt(plaintext []byte) (string, error) { return "ENC(" + string(plaintext) + ")", nil } +func (mockCipher) Encrypt(plaintext []byte) (string, error) { + return "ENC(" + string(plaintext) + ")", nil +} func (mockCipher) Decrypt(enc string) ([]byte, error) { return []byte(strings.TrimSuffix(strings.TrimPrefix(enc, "ENC("), ")")), nil } @@ -160,9 +162,17 @@ func (mockProvider) ApplyChanges(context.Context, provider.Credentials, string, return nil } +// newTenantTestAPI wires a fixed authenticated user who owns whatever +// project id is requested (alwaysOwnedAuthStore/alwaysValidSessions, see +// middleware_test.go) — these tests exercise CRUD behavior past the +// RequireAuth/RequireProjectAccess boundary, which has its own dedicated +// coverage in middleware_test.go. func newTenantTestAPI() (*API, *mockTenantStore) { ts := &mockTenantStore{} - a := &API{Store: ts, Cipher: mockCipher{}, Reg: &mockRegistry{}} + a := &API{ + Store: ts, Cipher: mockCipher{}, Reg: &mockRegistry{}, + Auth: alwaysOwnedAuthStore(), Sessions: alwaysValidSessions(uuid.New()), + } return a, ts } @@ -173,7 +183,7 @@ func TestCreateAccount_SecretEncryptedAndNotInResponse(t *testing.T) { router := NewRouter(a) body := `{"provider":"selectel","secret":"super-secret-token","comment":"prod"}` - req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts", strings.NewReader(body)) + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts", strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -211,7 +221,7 @@ func TestListAccounts_NoSecretsInResponse(t *testing.T) { } router := NewRouter(a) - req := httptest.NewRequest(http.MethodGet, "/api/v1/projects/"+testPID+"/accounts", nil) + req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+testPID+"/accounts", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -234,7 +244,7 @@ func TestDeleteAccount_BadUUID(t *testing.T) { a, _ := newTenantTestAPI() router := NewRouter(a) - req := httptest.NewRequest(http.MethodDelete, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid", nil) + req := requestWithSessionCookie(http.MethodDelete, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -250,7 +260,7 @@ func TestCreateTemplate_SavesRecords(t *testing.T) { router := NewRouter(a) body := `{"name":"base","records":[{"type":"A","name":"@","ttl":300,"values":["1.2.3.4"]}]}` - req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/templates", strings.NewReader(body)) + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/templates", strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -278,7 +288,7 @@ func TestUpdateTemplate_BadUUID(t *testing.T) { router := NewRouter(a) body := `{"name":"x","records":[]}` - req := httptest.NewRequest(http.MethodPut, "/api/v1/projects/"+testPID+"/templates/not-a-uuid", strings.NewReader(body)) + req := requestWithSessionCookie(http.MethodPut, "/api/v1/projects/"+testPID+"/templates/not-a-uuid", strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -299,7 +309,7 @@ func TestImportZones_CreatesDomainPerZone(t *testing.T) { }} router := NewRouter(a) - req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil) + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -337,7 +347,7 @@ func TestImportZones_AtomicRollbackOnError(t *testing.T) { }} router := NewRouter(a) - req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil) + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -356,7 +366,7 @@ func TestImportZones_BadAccountUUID(t *testing.T) { a, _ := newTenantTestAPI() router := NewRouter(a) - req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid/import", nil) + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid/import", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -370,7 +380,7 @@ func TestCreateDomain_BadProjectUUID(t *testing.T) { router := NewRouter(a) body := `{"providerAccountId":"` + uuid.New().String() + `","zoneName":"example.com","zoneId":"z1"}` - req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/not-a-uuid/domains", strings.NewReader(body)) + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/not-a-uuid/domains", strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -390,7 +400,7 @@ func TestCreateDomain_AccountNotFoundInProject(t *testing.T) { // ts.accounts is empty, so GetAccount will not find this id. foreignAccID := uuid.New() body := `{"providerAccountId":"` + foreignAccID.String() + `","zoneName":"example.com","zoneId":"z1"}` - req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -413,7 +423,7 @@ func TestCreateDomain_TemplateNotFoundInProject(t *testing.T) { foreignTplID := uuid.New() body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + foreignTplID.String() + `"}` - req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -434,7 +444,7 @@ func TestCreateDomain_HappyPath(t *testing.T) { router := NewRouter(a) body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1"}` - req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -464,7 +474,7 @@ func TestCreateDomain_ValidTemplateInProject(t *testing.T) { router := NewRouter(a) body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + tplID.String() + `"}` - req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -490,7 +500,7 @@ func TestSetDomainTemplate_ValidTemplateId(t *testing.T) { router := NewRouter(a) body := `{"templateId":"` + tplID.String() + `"}` - req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body)) + req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -513,7 +523,7 @@ func TestSetDomainTemplate_BadTemplateUUID(t *testing.T) { router := NewRouter(a) body := `{"templateId":"not-a-uuid"}` - req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body)) + req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -530,7 +540,7 @@ func TestSetDomainTemplate_TemplateNotFound(t *testing.T) { router := NewRouter(a) body := `{"templateId":"` + uuid.New().String() + `"}` - req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body)) + req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -543,7 +553,7 @@ func TestDeleteDomain_BadUUID(t *testing.T) { a, _ := newTenantTestAPI() router := NewRouter(a) - req := httptest.NewRequest(http.MethodDelete, "/api/v1/projects/"+testPID+"/domains/not-a-uuid", nil) + req := requestWithSessionCookie(http.MethodDelete, "/api/v1/projects/"+testPID+"/domains/not-a-uuid", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) diff --git a/internal/service/service.go b/internal/service/service.go index 2dec925..ea8c46d 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -21,7 +21,7 @@ type DomainRef struct { } type Loader interface { - LoadDomain(ctx context.Context, domainID uuid.UUID) (DomainRef, error) + LoadDomain(ctx context.Context, projectID, domainID uuid.UUID) (DomainRef, error) } type Recorder interface { @@ -45,8 +45,10 @@ func New(loader Loader, rec Recorder, reg *registry.Registry, cipher *crypto.Cip } // resolve loads the domain, its provider and decrypted credentials, and computes the diff. -func (s *DomainService) resolve(ctx context.Context, domainID uuid.UUID) (provider.Provider, provider.Credentials, DomainRef, diff.Changeset, error) { - ref, err := s.loader.LoadDomain(ctx, domainID) +// projectID scopes the lookup so a domainID belonging to another tenant's +// project can never be resolved here (closes IDOR). +func (s *DomainService) resolve(ctx context.Context, projectID, domainID uuid.UUID) (provider.Provider, provider.Credentials, DomainRef, diff.Changeset, error) { + ref, err := s.loader.LoadDomain(ctx, projectID, domainID) if err != nil { return nil, provider.Credentials{}, ref, diff.Changeset{}, err } @@ -68,8 +70,8 @@ func (s *DomainService) resolve(ctx context.Context, domainID uuid.UUID) (provid } // Check computes and records the diff between template and zone. -func (s *DomainService) Check(ctx context.Context, domainID uuid.UUID) (diff.Changeset, error) { - _, _, _, cs, err := s.resolve(ctx, domainID) +func (s *DomainService) Check(ctx context.Context, projectID, domainID uuid.UUID) (diff.Changeset, error) { + _, _, _, cs, err := s.resolve(ctx, projectID, domainID) if err != nil { return diff.Changeset{}, err } @@ -80,8 +82,8 @@ func (s *DomainService) Check(ctx context.Context, domainID uuid.UUID) (diff.Cha } // Apply applies updates always (when ApplyUpdates) and prunes only when ApplyPrunes. -func (s *DomainService) Apply(ctx context.Context, domainID uuid.UUID, req ApplyRequest) (diff.Changeset, error) { - p, creds, ref, cs, err := s.resolve(ctx, domainID) +func (s *DomainService) Apply(ctx context.Context, projectID, domainID uuid.UUID, req ApplyRequest) (diff.Changeset, error) { + p, creds, ref, cs, err := s.resolve(ctx, projectID, domainID) if err != nil { return diff.Changeset{}, err } diff --git a/internal/service/service_test.go b/internal/service/service_test.go index e799306..d3e7cf7 100644 --- a/internal/service/service_test.go +++ b/internal/service/service_test.go @@ -44,7 +44,9 @@ func (f *fakeProvider) ApplyChanges(_ context.Context, _ provider.Credentials, _ type fakeLoader struct{ ref DomainRef } -func (l fakeLoader) LoadDomain(context.Context, uuid.UUID) (DomainRef, error) { return l.ref, nil } +func (l fakeLoader) LoadDomain(context.Context, uuid.UUID, uuid.UUID) (DomainRef, error) { + return l.ref, nil +} type nopRecorder struct{} @@ -66,7 +68,7 @@ func TestCheckProducesDiff(t *testing.T) { {Type: "A", Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}}, // update }} svc, _ := setup(t, actual, tmpl) - cs, err := svc.Check(context.Background(), uuid.New()) + cs, err := svc.Check(context.Background(), uuid.New(), uuid.New()) if err != nil { t.Fatal(err) } @@ -87,7 +89,7 @@ func TestApplyRespectsPruneGuard(t *testing.T) { // applyPrunes=false → удаление b НЕ применяется svc, fp := setup(t, actual, tmpl) - if _, err := svc.Apply(context.Background(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: false}); err != nil { + if _, err := svc.Apply(context.Background(), uuid.New(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: false}); err != nil { t.Fatal(err) } for _, d := range fp.applied.Diffs { @@ -98,7 +100,7 @@ func TestApplyRespectsPruneGuard(t *testing.T) { // applyPrunes=true → удаление b применяется svc2, fp2 := setup(t, actual, tmpl) - if _, err := svc2.Apply(context.Background(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: true}); err != nil { + if _, err := svc2.Apply(context.Background(), uuid.New(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: true}); err != nil { t.Fatal(err) } var sawDelete bool diff --git a/internal/store/db/domains.sql.go b/internal/store/db/domains.sql.go index 3b5aa32..527bd66 100644 --- a/internal/store/db/domains.sql.go +++ b/internal/store/db/domains.sql.go @@ -162,9 +162,14 @@ SELECT d.zone_id, a.provider, a.secret_enc, t.doc FROM domains d JOIN provider_accounts a ON a.id = d.provider_account_id LEFT JOIN templates t ON t.id = d.template_id -WHERE d.id = $1 +WHERE d.id = $1 AND d.project_id = $2 ` +type LoadDomainFullParams struct { + ID uuid.UUID `json:"id"` + ProjectID uuid.UUID `json:"project_id"` +} + type LoadDomainFullRow struct { ZoneID string `json:"zone_id"` Provider string `json:"provider"` @@ -172,8 +177,8 @@ type LoadDomainFullRow struct { Doc *dto.TemplateDoc `json:"doc"` } -func (q *Queries) LoadDomainFull(ctx context.Context, id uuid.UUID) (LoadDomainFullRow, error) { - row := q.db.QueryRow(ctx, loadDomainFull, id) +func (q *Queries) LoadDomainFull(ctx context.Context, arg LoadDomainFullParams) (LoadDomainFullRow, error) { + row := q.db.QueryRow(ctx, loadDomainFull, arg.ID, arg.ProjectID) var i LoadDomainFullRow err := row.Scan( &i.ZoneID, diff --git a/internal/store/loader.go b/internal/store/loader.go index 984c85c..620db85 100644 --- a/internal/store/loader.go +++ b/internal/store/loader.go @@ -13,9 +13,11 @@ import ( ) // LoadDomain joins domains+provider_accounts+templates to build the -// service.DomainRef needed to check/apply a domain's DNS records. -func (s *Store) LoadDomain(ctx context.Context, domainID uuid.UUID) (service.DomainRef, error) { - row, err := s.q.LoadDomainFull(ctx, domainID) +// service.DomainRef needed to check/apply a domain's DNS records. Scoped by +// projectID so a domain belonging to another tenant's project can never be +// loaded, even if its domainID is guessed/leaked (closes IDOR). +func (s *Store) LoadDomain(ctx context.Context, projectID, domainID uuid.UUID) (service.DomainRef, error) { + row, err := s.q.LoadDomainFull(ctx, db.LoadDomainFullParams{ID: domainID, ProjectID: projectID}) if err != nil { return service.DomainRef{}, err } diff --git a/internal/store/loader_test.go b/internal/store/loader_test.go index 33f210f..c53a1c7 100644 --- a/internal/store/loader_test.go +++ b/internal/store/loader_test.go @@ -40,7 +40,7 @@ func TestLoadDomainAndSaveCheckRun(t *testing.T) { t.Fatal(err) } - ref, err := s.LoadDomain(ctx, domain.ID) + ref, err := s.LoadDomain(ctx, defaultProject, domain.ID) if err != nil { t.Fatal(err) } @@ -87,7 +87,7 @@ func TestLoadDomainNoTemplate(t *testing.T) { t.Fatal(err) } - if _, err := s.LoadDomain(ctx, domain.ID); err == nil { + if _, err := s.LoadDomain(ctx, defaultProject, domain.ID); err == nil { t.Fatal("expected error for domain without template, got nil") } } diff --git a/internal/store/queries/domains.sql b/internal/store/queries/domains.sql index 1f05fda..9e16759 100644 --- a/internal/store/queries/domains.sql +++ b/internal/store/queries/domains.sql @@ -27,4 +27,4 @@ SELECT d.zone_id, a.provider, a.secret_enc, t.doc FROM domains d JOIN provider_accounts a ON a.id = d.provider_account_id LEFT JOIN templates t ON t.id = d.template_id -WHERE d.id = $1; +WHERE d.id = $1 AND d.project_id = $2; diff --git a/internal/store/store_test.go b/internal/store/store_test.go index bcbdc12..9572617 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -214,7 +214,7 @@ func TestSetDomainTemplate_ClosesImportCheckLoop(t *testing.T) { dom := doms[0] // Before binding, the domain is not checkable. - if _, err := s.LoadDomain(ctx, dom.ID); err == nil { + if _, err := s.LoadDomain(ctx, defaultProject, dom.ID); err == nil { t.Fatal("expected LoadDomain to fail before a template is bound") } @@ -234,7 +234,7 @@ func TestSetDomainTemplate_ClosesImportCheckLoop(t *testing.T) { t.Fatalf("expected domain.TemplateID=%s, got %+v", tpl.ID, updated.TemplateID) } - ref, err := s.LoadDomain(ctx, dom.ID) + ref, err := s.LoadDomain(ctx, defaultProject, dom.ID) if err != nil { t.Fatalf("expected LoadDomain to succeed after binding template, got error: %v", err) }