feat(api): RequireAuth+RequireProjectAccess middleware, IDOR-scope check/apply по projectID

This commit is contained in:
2026-07-03 20:47:40 +07:00
parent 35ffe73ae3
commit 4533b0ca25
16 changed files with 498 additions and 143 deletions
+14 -4
View File
@@ -18,8 +18,8 @@ import (
// CheckApplier is the service surface the API depends on. // CheckApplier is the service surface the API depends on.
type CheckApplier interface { type CheckApplier interface {
Check(ctx context.Context, domainID uuid.UUID) (diff.Changeset, error) Check(ctx context.Context, projectID, domainID uuid.UUID) (diff.Changeset, error)
Apply(ctx context.Context, domainID uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) Apply(ctx context.Context, projectID, domainID uuid.UUID, req service.ApplyRequest) (diff.Changeset, error)
} }
// TenantStore is the narrow persistence surface the CRUD handlers depend on. // TenantStore is the narrow persistence surface the CRUD handlers depend on.
@@ -63,6 +63,10 @@ type AuthStore interface {
GetUserByEmail(ctx context.Context, email string) (store.User, error) GetUserByEmail(ctx context.Context, email string) (store.User, error)
GetUserByID(ctx context.Context, userID uuid.UUID) (store.User, error) GetUserByID(ctx context.Context, userID uuid.UUID) (store.User, error)
GetUserProject(ctx context.Context, userID uuid.UUID) (store.Project, error) GetUserProject(ctx context.Context, userID uuid.UUID) (store.Project, error)
// GetProjectOwned looks up projectID and returns it only if it's owned by
// userID — RequireProjectAccess uses this to reject foreign/nonexistent
// projects with 404 before any handler runs.
GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error)
} }
// SessionManager creates/validates/destroys login sessions. *auth.Sessions // SessionManager creates/validates/destroys login sessions. *auth.Sessions
@@ -91,11 +95,17 @@ func NewRouter(a *API) http.Handler {
r.Route("/api/v1/auth", func(r chi.Router) { r.Route("/api/v1/auth", func(r chi.Router) {
r.Post("/register", a.handleRegister) r.Post("/register", a.handleRegister)
r.Post("/login", a.handleLogin) r.Post("/login", a.handleLogin)
r.Post("/logout", a.handleLogout) // защитится RequireAuth в Task 4 r.Group(func(r chi.Router) {
r.Get("/me", a.handleMe) // защитится RequireAuth в Task 4 r.Use(a.RequireAuth)
r.Post("/logout", a.handleLogout)
r.Get("/me", a.handleMe)
})
}) })
r.Route("/api/v1/projects/{pid}", func(r chi.Router) { r.Route("/api/v1/projects/{pid}", func(r chi.Router) {
r.Use(a.RequireAuth)
r.Use(a.RequireProjectAccess)
r.Route("/domains", func(r chi.Router) { r.Route("/domains", func(r chi.Router) {
r.Post("/", a.handleCreateDomain) r.Post("/", a.handleCreateDomain)
r.Get("/", a.handleListDomains) r.Get("/", a.handleListDomains)
+13 -8
View File
@@ -20,18 +20,23 @@ type mockCheckApplier struct {
lastReq service.ApplyRequest lastReq service.ApplyRequest
} }
func (m *mockCheckApplier) Check(context.Context, uuid.UUID) (diff.Changeset, error) { func (m *mockCheckApplier) Check(context.Context, uuid.UUID, uuid.UUID) (diff.Changeset, error) {
d := model.Record{Type: model.A, Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}} d := model.Record{Type: model.A, Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}}
return diff.Changeset{Diffs: []diff.RecordDiff{{Kind: diff.Add, Type: d.Type, Name: d.Name, Desired: &d}}}, nil return diff.Changeset{Diffs: []diff.RecordDiff{{Kind: diff.Add, Type: d.Type, Name: d.Name, Desired: &d}}}, nil
} }
func (m *mockCheckApplier) Apply(_ context.Context, _ uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) { func (m *mockCheckApplier) Apply(_ context.Context, _, _ uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) {
m.lastReq = req m.lastReq = req
return diff.Changeset{}, nil return diff.Changeset{}, nil
} }
// newTestAPI wires a fixed authenticated user who owns whatever project id
// is requested (via alwaysOwnedAuthStore/alwaysValidSessions in
// middleware_test.go) — these tests exercise check/apply behavior past the
// RequireAuth/RequireProjectAccess boundary, which is covered separately by
// middleware_test.go's own tests and the IDOR regression.
func newTestAPI() (*API, *mockCheckApplier) { func newTestAPI() (*API, *mockCheckApplier) {
m := &mockCheckApplier{} m := &mockCheckApplier{}
return &API{Svc: m}, m // остальные зависимости (store/cipher) nil — CRUD-тесты добавит реализатор return &API{Svc: m, Auth: alwaysOwnedAuthStore(), Sessions: alwaysValidSessions(uuid.New())}, m
} }
func TestCheckEndpoint(t *testing.T) { func TestCheckEndpoint(t *testing.T) {
@@ -39,7 +44,7 @@ func TestCheckEndpoint(t *testing.T) {
router := NewRouter(a) router := NewRouter(a)
did := uuid.New().String() did := uuid.New().String()
req := httptest.NewRequest(http.MethodGet, req := requestWithSessionCookie(http.MethodGet,
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/check", nil) "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/check", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -62,7 +67,7 @@ func TestApplyDefaultsPruneFalse(t *testing.T) {
did := uuid.New().String() did := uuid.New().String()
body := `{"applyUpdates":true}` // applyPrunes отсутствует → false body := `{"applyUpdates":true}` // applyPrunes отсутствует → false
req := httptest.NewRequest(http.MethodPost, req := requestWithSessionCookie(http.MethodPost,
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply", "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply",
strings.NewReader(body)) strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -81,7 +86,7 @@ func TestApplyEmptyBodyOK(t *testing.T) {
router := NewRouter(a) router := NewRouter(a)
did := uuid.New().String() did := uuid.New().String()
req := httptest.NewRequest(http.MethodPost, req := requestWithSessionCookie(http.MethodPost,
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply", nil) "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -100,7 +105,7 @@ func TestApplyMalformedBody(t *testing.T) {
did := uuid.New().String() did := uuid.New().String()
body := `{"applyUpdates":` body := `{"applyUpdates":`
req := httptest.NewRequest(http.MethodPost, req := requestWithSessionCookie(http.MethodPost,
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply", "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply",
strings.NewReader(body)) strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -114,7 +119,7 @@ func TestApplyMalformedBody(t *testing.T) {
func TestApplyBadUUID(t *testing.T) { func TestApplyBadUUID(t *testing.T) {
a, _ := newTestAPI() a, _ := newTestAPI()
router := NewRouter(a) router := NewRouter(a)
req := httptest.NewRequest(http.MethodPost, req := requestWithSessionCookie(http.MethodPost,
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/not-a-uuid/apply", "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/not-a-uuid/apply",
bytes.NewReader([]byte(`{}`))) bytes.NewReader([]byte(`{}`)))
w := httptest.NewRecorder() w := httptest.NewRecorder()
+3 -20
View File
@@ -1,15 +1,12 @@
package api package api
import ( import (
"context"
"errors" "errors"
"log" "log"
"net/http" "net/http"
"strings" "strings"
"time" "time"
"github.com/google/uuid"
"github.com/vasyakrg/dns-autoresolver/internal/auth" "github.com/vasyakrg/dns-autoresolver/internal/auth"
"github.com/vasyakrg/dns-autoresolver/internal/store" "github.com/vasyakrg/dns-autoresolver/internal/store"
) )
@@ -39,19 +36,6 @@ func normalizeEmail(email string) string {
return strings.ToLower(strings.TrimSpace(email)) return strings.ToLower(strings.TrimSpace(email))
} }
// ctxKeyUserID is a private context key carrying the authenticated user's ID.
// Task 4's RequireAuth middleware sets it after validating the session
// cookie; handleMe reads it back.
type ctxKeyUserID struct{}
// userIDFromContext extracts the authenticated user ID set by RequireAuth
// (Task 4). Until that middleware is wired in, tests set it directly via
// context.WithValue.
func userIDFromContext(ctx context.Context) (uuid.UUID, bool) {
id, ok := ctx.Value(ctxKeyUserID{}).(uuid.UUID)
return id, ok
}
func setSessionCookie(w http.ResponseWriter, token string, exp time.Time) { func setSessionCookie(w http.ResponseWriter, token string, exp time.Time) {
http.SetCookie(w, &http.Cookie{ http.SetCookie(w, &http.Cookie{
Name: sessionCookieName, Value: token, Path: "/", Name: sessionCookieName, Value: token, Path: "/",
@@ -165,11 +149,10 @@ func (a *API) handleLogout(w http.ResponseWriter, r *http.Request) {
} }
// handleMe returns the authenticated caller's identity + default project. // handleMe returns the authenticated caller's identity + default project.
// The user ID comes from the request context, set by Task 4's RequireAuth // The user ID comes from the request context, set by RequireAuth after
// middleware after validating the session cookie (tests set it directly via // validating the session cookie.
// context.WithValue in the interim).
func (a *API) handleMe(w http.ResponseWriter, r *http.Request) { func (a *API) handleMe(w http.ResponseWriter, r *http.Request) {
userID, ok := userIDFromContext(r.Context()) userID, ok := userIDFrom(r.Context())
if !ok { if !ok {
writeErr(w, http.StatusUnauthorized, "authentication required") writeErr(w, http.StatusUnauthorized, "authentication required")
return return
+17 -3
View File
@@ -22,6 +22,7 @@ type mockAuthStore struct {
getUserByEmailFn func(ctx context.Context, email string) (store.User, error) getUserByEmailFn func(ctx context.Context, email string) (store.User, error)
getUserByIDFn func(ctx context.Context, userID uuid.UUID) (store.User, error) getUserByIDFn func(ctx context.Context, userID uuid.UUID) (store.User, error)
getUserProjectFn func(ctx context.Context, userID uuid.UUID) (store.Project, error) getUserProjectFn func(ctx context.Context, userID uuid.UUID) (store.Project, error)
getProjectOwnedFn func(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error)
} }
func (m *mockAuthStore) RegisterUser(ctx context.Context, email, passwordHash string) (store.User, store.Project, error) { func (m *mockAuthStore) RegisterUser(ctx context.Context, email, passwordHash string) (store.User, store.Project, error) {
@@ -40,8 +41,13 @@ func (m *mockAuthStore) GetUserProject(ctx context.Context, userID uuid.UUID) (s
return m.getUserProjectFn(ctx, userID) return m.getUserProjectFn(ctx, userID)
} }
func (m *mockAuthStore) GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) {
return m.getProjectOwnedFn(ctx, projectID, userID)
}
type mockSessionManager struct { type mockSessionManager struct {
createFn func(ctx context.Context, userID uuid.UUID) (string, time.Time, error) createFn func(ctx context.Context, userID uuid.UUID) (string, time.Time, error)
validateFn func(ctx context.Context, token string) (uuid.UUID, error)
destroyCalled bool destroyCalled bool
destroyToken string destroyToken string
@@ -52,7 +58,10 @@ func (m *mockSessionManager) Create(ctx context.Context, userID uuid.UUID) (stri
return m.createFn(ctx, userID) return m.createFn(ctx, userID)
} }
func (m *mockSessionManager) Validate(context.Context, string) (uuid.UUID, error) { func (m *mockSessionManager) Validate(ctx context.Context, token string) (uuid.UUID, error) {
if m.validateFn != nil {
return m.validateFn(ctx, token)
}
return uuid.Nil, nil return uuid.Nil, nil
} }
@@ -338,7 +347,7 @@ func TestAuthLogout_ClearsSessionAndDestroys(t *testing.T) {
// now resolves the authenticated user via GetUserByID and returns their real // now resolves the authenticated user via GetUserByID and returns their real
// email, instead of leaving it blank. // email, instead of leaving it blank.
func TestAuthMe_ReturnsRealEmail(t *testing.T) { func TestAuthMe_ReturnsRealEmail(t *testing.T) {
a, authStore, _ := newTestAuthAPI() a, authStore, sessions := newTestAuthAPI()
userID := uuid.New() userID := uuid.New()
projectID := uuid.New() projectID := uuid.New()
authStore.getUserByIDFn = func(_ context.Context, id uuid.UUID) (store.User, error) { authStore.getUserByIDFn = func(_ context.Context, id uuid.UUID) (store.User, error) {
@@ -350,10 +359,15 @@ func TestAuthMe_ReturnsRealEmail(t *testing.T) {
authStore.getUserProjectFn = func(_ context.Context, uid uuid.UUID) (store.Project, error) { authStore.getUserProjectFn = func(_ context.Context, uid uuid.UUID) (store.Project, error) {
return store.Project{ID: projectID, UserID: uid, Name: "default"}, nil return store.Project{ID: projectID, UserID: uid, Name: "default"}, nil
} }
// /me is now behind RequireAuth: the session cookie must resolve to
// userID via Sessions.Validate rather than being injected directly.
sessions.validateFn = func(context.Context, string) (uuid.UUID, error) {
return userID, nil
}
router := NewRouter(a) router := NewRouter(a)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil) req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil)
req = req.WithContext(context.WithValue(req.Context(), ctxKeyUserID{}, userID)) req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "some-valid-token"})
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
+8 -2
View File
@@ -24,12 +24,15 @@ func writeErr(w http.ResponseWriter, status int, msg string) {
} }
func (a *API) handleCheck(w http.ResponseWriter, r *http.Request) { func (a *API) handleCheck(w http.ResponseWriter, r *http.Request) {
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
did, err := uuid.Parse(chi.URLParam(r, "did")) did, err := uuid.Parse(chi.URLParam(r, "did"))
if err != nil { if err != nil {
writeErr(w, http.StatusBadRequest, "invalid domain id") writeErr(w, http.StatusBadRequest, "invalid domain id")
return return
} }
cs, err := a.Svc.Check(r.Context(), did) cs, err := a.Svc.Check(r.Context(), pid, did)
if err != nil { if err != nil {
log.Printf("api: check failed: %v", err) log.Printf("api: check failed: %v", err)
writeErr(w, http.StatusInternalServerError, "internal error") writeErr(w, http.StatusInternalServerError, "internal error")
@@ -39,6 +42,9 @@ func (a *API) handleCheck(w http.ResponseWriter, r *http.Request) {
} }
func (a *API) handleApply(w http.ResponseWriter, r *http.Request) { func (a *API) handleApply(w http.ResponseWriter, r *http.Request) {
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
did, err := uuid.Parse(chi.URLParam(r, "did")) did, err := uuid.Parse(chi.URLParam(r, "did"))
if err != nil { if err != nil {
writeErr(w, http.StatusBadRequest, "invalid domain id") writeErr(w, http.StatusBadRequest, "invalid domain id")
@@ -54,7 +60,7 @@ func (a *API) handleApply(w http.ResponseWriter, r *http.Request) {
return return
} }
} }
cs, err := a.Svc.Apply(r.Context(), did, service.ApplyRequest{ cs, err := a.Svc.Apply(r.Context(), pid, did, service.ApplyRequest{
ApplyUpdates: req.ApplyUpdates, ApplyPrunes: req.ApplyPrunes, ApplyUpdates: req.ApplyUpdates, ApplyPrunes: req.ApplyPrunes,
}) })
if err != nil { if err != nil {
+75
View File
@@ -0,0 +1,75 @@
package api
import (
"context"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
)
type ctxKey string
const (
ctxUserID ctxKey = "userID"
ctxProjectID ctxKey = "projectID"
)
// RequireAuth validates the session cookie and, on success, stores the
// authenticated user's ID in the request context (see userIDFrom). Any
// failure — missing cookie or an invalid/expired token — is rejected with
// 401 before reaching the wrapped handler.
func (a *API) RequireAuth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := r.Cookie(sessionCookieName)
if err != nil {
writeErr(w, http.StatusUnauthorized, "unauthorized")
return
}
uid, err := a.Sessions.Validate(r.Context(), c.Value)
if err != nil {
writeErr(w, http.StatusUnauthorized, "unauthorized")
return
}
ctx := context.WithValue(r.Context(), ctxUserID, uid)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// RequireProjectAccess verifies that the {pid} URL segment names a project
// owned by the authenticated user (set by RequireAuth, which must run
// first) and, on success, stores the project ID in the request context (see
// projectIDFrom). A project that doesn't exist or isn't owned by the caller
// is rejected with 404 — not 403 — so a caller can't distinguish "not yours"
// from "doesn't exist" (closes IDOR by never confirming another tenant's
// project exists).
func (a *API) RequireProjectAccess(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
uid, ok := userIDFrom(r.Context())
if !ok {
writeErr(w, http.StatusUnauthorized, "unauthorized")
return
}
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
if _, err := a.Auth.GetProjectOwned(r.Context(), pid, uid); err != nil {
writeErr(w, http.StatusNotFound, "not found")
return
}
ctx := context.WithValue(r.Context(), ctxProjectID, pid)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func userIDFrom(ctx context.Context) (uuid.UUID, bool) {
v, ok := ctx.Value(ctxUserID).(uuid.UUID)
return v, ok
}
func projectIDFrom(ctx context.Context) (uuid.UUID, bool) {
v, ok := ctx.Value(ctxProjectID).(uuid.UUID)
return v, ok
}
+265
View File
@@ -0,0 +1,265 @@
package api
import (
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/vasyakrg/dns-autoresolver/internal/diff"
"github.com/vasyakrg/dns-autoresolver/internal/service"
"github.com/vasyakrg/dns-autoresolver/internal/store"
)
// --- shared test doubles (also used by api_test.go / tenant_test.go) ---
// stubSessions is a configurable SessionManager test double.
type stubSessions struct {
validateFn func(ctx context.Context, token string) (uuid.UUID, error)
}
func (s stubSessions) Create(context.Context, uuid.UUID) (string, time.Time, error) {
return "stub-token", time.Now().Add(time.Hour), nil
}
func (s stubSessions) Validate(ctx context.Context, token string) (uuid.UUID, error) {
return s.validateFn(ctx, token)
}
func (s stubSessions) Destroy(context.Context, string) error { return nil }
// alwaysValidSessions builds a stubSessions whose Validate always succeeds
// with the given fixed user ID — for tests that only care about behavior
// past the RequireAuth boundary.
func alwaysValidSessions(uid uuid.UUID) stubSessions {
return stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) { return uid, nil }}
}
// stubAuthStore is a configurable AuthStore test double. Methods besides
// GetProjectOwned return zero values — tests exercising register/login/me
// use their own dedicated mock (mockAuthStore in auth_test.go).
type stubAuthStore struct {
getProjectOwnedFn func(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error)
}
func (s stubAuthStore) RegisterUser(context.Context, string, string) (store.User, store.Project, error) {
return store.User{}, store.Project{}, nil
}
func (s stubAuthStore) GetUserByEmail(context.Context, string) (store.User, error) {
return store.User{}, nil
}
func (s stubAuthStore) GetUserByID(context.Context, uuid.UUID) (store.User, error) {
return store.User{}, nil
}
func (s stubAuthStore) GetUserProject(context.Context, uuid.UUID) (store.Project, error) {
return store.Project{}, nil
}
func (s stubAuthStore) GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) {
return s.getProjectOwnedFn(ctx, projectID, userID)
}
// alwaysOwnedAuthStore builds a stubAuthStore whose GetProjectOwned always
// succeeds, treating the caller as the owner of whatever project id is
// requested — for tests that only care about behavior past the
// RequireProjectAccess boundary (CRUD/check/apply tests).
func alwaysOwnedAuthStore() stubAuthStore {
return stubAuthStore{getProjectOwnedFn: func(_ context.Context, pid, uid uuid.UUID) (store.Project, error) {
return store.Project{ID: pid, UserID: uid}, nil
}}
}
// requestWithSessionCookie builds an httptest request carrying a session
// cookie so it clears RequireAuth in tests using stubSessions/alwaysValidSessions
// (which ignore the token value and validate unconditionally).
func requestWithSessionCookie(method, url string, body io.Reader) *http.Request {
req := httptest.NewRequest(method, url, body)
req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "test-session-token"})
return req
}
// recordingCheckApplier is a CheckApplier test double that records whether
// Check/Apply were invoked — used by the IDOR regression test to assert the
// service layer is never reached for a project the caller doesn't own.
type recordingCheckApplier struct {
checkCalled bool
applyCalled bool
}
func (r *recordingCheckApplier) Check(context.Context, uuid.UUID, uuid.UUID) (diff.Changeset, error) {
r.checkCalled = true
return diff.Changeset{}, nil
}
func (r *recordingCheckApplier) Apply(context.Context, uuid.UUID, uuid.UUID, service.ApplyRequest) (diff.Changeset, error) {
r.applyCalled = true
return diff.Changeset{}, nil
}
// --- RequireAuth ---
func TestRequireAuth_NoCookie_Returns401(t *testing.T) {
a := &API{Sessions: stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) {
t.Fatal("Validate must not be called when there is no cookie")
return uuid.Nil, nil
}}}
nextCalled := false
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true })
req := httptest.NewRequest(http.MethodGet, "/whatever", nil)
w := httptest.NewRecorder()
a.RequireAuth(next).ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
if nextCalled {
t.Fatal("next must not be called without a session cookie")
}
}
func TestRequireAuth_InvalidToken_Returns401(t *testing.T) {
a := &API{Sessions: stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) {
return uuid.Nil, errors.New("invalid or expired session")
}}}
nextCalled := false
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true })
req := requestWithSessionCookie(http.MethodGet, "/whatever", nil)
w := httptest.NewRecorder()
a.RequireAuth(next).ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
if nextCalled {
t.Fatal("next must not be called when Sessions.Validate fails")
}
}
func TestRequireAuth_Success_CallsNextWithUserID(t *testing.T) {
uid := uuid.New()
a := &API{Sessions: alwaysValidSessions(uid)}
var gotUID uuid.UUID
var gotOK bool
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotUID, gotOK = userIDFrom(r.Context())
w.WriteHeader(http.StatusOK)
})
req := requestWithSessionCookie(http.MethodGet, "/whatever", nil)
w := httptest.NewRecorder()
a.RequireAuth(next).ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200 (next called), got %d", w.Code)
}
if !gotOK || gotUID != uid {
t.Fatalf("expected userIDFrom to yield %s, got %s (ok=%v)", uid, gotUID, gotOK)
}
}
// --- RequireProjectAccess ---
func TestRequireProjectAccess_NotOwned_Returns404(t *testing.T) {
a := &API{Auth: stubAuthStore{getProjectOwnedFn: func(context.Context, uuid.UUID, uuid.UUID) (store.Project, error) {
return store.Project{}, errors.New("no rows")
}}}
nextCalled := false
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true })
router := chi.NewRouter()
router.With(a.RequireProjectAccess).Get("/projects/{pid}", next.ServeHTTP)
req := httptest.NewRequest(http.MethodGet, "/projects/"+uuid.New().String(), nil)
req = req.WithContext(context.WithValue(req.Context(), ctxUserID, uuid.New()))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404 for an unowned/foreign project, got %d body %s", w.Code, w.Body.String())
}
if nextCalled {
t.Fatal("next must not be called when the project isn't owned by the caller")
}
}
func TestRequireProjectAccess_Success_CallsNextWithProjectID(t *testing.T) {
a := &API{Auth: alwaysOwnedAuthStore()}
pid := uuid.New()
uid := uuid.New()
var gotPID uuid.UUID
var gotOK bool
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPID, gotOK = projectIDFrom(r.Context())
w.WriteHeader(http.StatusOK)
})
router := chi.NewRouter()
router.With(a.RequireProjectAccess).Get("/projects/{pid}", next.ServeHTTP)
req := httptest.NewRequest(http.MethodGet, "/projects/"+pid.String(), nil)
req = req.WithContext(context.WithValue(req.Context(), ctxUserID, uid))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200 (next called), got %d body %s", w.Code, w.Body.String())
}
if !gotOK || gotPID != pid {
t.Fatalf("expected projectIDFrom to yield %s, got %s (ok=%v)", pid, gotPID, gotOK)
}
}
// --- IDOR regression ---
// TestIDOR_CheckForeignProject_Returns404AndServiceNotCalled is the IDOR
// regression required by Task 4: a request for project B's domain, made
// with user A's session (who owns project A, not B), must be rejected by
// RequireProjectAccess with 404 before service.Check is ever invoked — a
// caller must never be able to check/apply another tenant's domain by
// guessing/leaking its ids.
func TestIDOR_CheckForeignProject_Returns404AndServiceNotCalled(t *testing.T) {
userA := uuid.New()
pidA := uuid.New()
pidB := uuid.New() // owned by a different user; not A
svc := &recordingCheckApplier{}
auth := stubAuthStore{getProjectOwnedFn: func(_ context.Context, pid, uid uuid.UUID) (store.Project, error) {
if pid == pidA && uid == userA {
return store.Project{ID: pidA, UserID: userA}, nil
}
return store.Project{}, errors.New("not found")
}}
a := &API{Svc: svc, Auth: auth, Sessions: alwaysValidSessions(userA)}
router := NewRouter(a)
did := uuid.New().String()
req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+pidB.String()+"/domains/"+did+"/check", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404 for user A requesting project B's domain, got %d body %s", w.Code, w.Body.String())
}
if svc.checkCalled {
t.Fatal("service.Check must not be called when the caller doesn't own the project")
}
// Sanity check: the same user against their own project succeeds and
// does reach the service — proving the 404 above is really about
// project ownership, not e.g. a broken route.
req2 := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+pidA.String()+"/domains/"+did+"/check", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Fatalf("expected 200 for user A requesting their own project, got %d body %s", w2.Code, w2.Body.String())
}
if !svc.checkCalled {
t.Fatal("expected service.Check to be called for the caller's own project")
}
}
+36 -60
View File
@@ -24,11 +24,9 @@ func decodeBody(w http.ResponseWriter, r *http.Request, v any) bool {
// --- accounts --- // --- accounts ---
func (a *API) handleCreateAccount(w http.ResponseWriter, r *http.Request) { func (a *API) handleCreateAccount(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
var req accountRequest var req accountRequest
if !decodeBody(w, r, &req) { if !decodeBody(w, r, &req) {
return return
@@ -53,11 +51,9 @@ func (a *API) handleCreateAccount(w http.ResponseWriter, r *http.Request) {
} }
func (a *API) handleListAccounts(w http.ResponseWriter, r *http.Request) { func (a *API) handleListAccounts(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
accs, err := a.Store.ListAccounts(r.Context(), pid) accs, err := a.Store.ListAccounts(r.Context(), pid)
if err != nil { if err != nil {
log.Printf("api: list accounts failed: %v", err) log.Printf("api: list accounts failed: %v", err)
@@ -72,11 +68,9 @@ func (a *API) handleListAccounts(w http.ResponseWriter, r *http.Request) {
} }
func (a *API) handleDeleteAccount(w http.ResponseWriter, r *http.Request) { func (a *API) handleDeleteAccount(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
aid, err := uuid.Parse(chi.URLParam(r, "aid")) aid, err := uuid.Parse(chi.URLParam(r, "aid"))
if err != nil { if err != nil {
writeErr(w, http.StatusBadRequest, "invalid account id") writeErr(w, http.StatusBadRequest, "invalid account id")
@@ -93,11 +87,9 @@ func (a *API) handleDeleteAccount(w http.ResponseWriter, r *http.Request) {
// handleImportZones lists zones from the provider for the given account and // handleImportZones lists zones from the provider for the given account and
// creates one domain per zone (template_id left unset). // creates one domain per zone (template_id left unset).
func (a *API) handleImportZones(w http.ResponseWriter, r *http.Request) { func (a *API) handleImportZones(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
aid, err := uuid.Parse(chi.URLParam(r, "aid")) aid, err := uuid.Parse(chi.URLParam(r, "aid"))
if err != nil { if err != nil {
writeErr(w, http.StatusBadRequest, "invalid account id") writeErr(w, http.StatusBadRequest, "invalid account id")
@@ -145,11 +137,9 @@ func (a *API) handleImportZones(w http.ResponseWriter, r *http.Request) {
// --- templates --- // --- templates ---
func (a *API) handleCreateTemplate(w http.ResponseWriter, r *http.Request) { func (a *API) handleCreateTemplate(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
var req templateRequest var req templateRequest
if !decodeBody(w, r, &req) { if !decodeBody(w, r, &req) {
return return
@@ -169,11 +159,9 @@ func (a *API) handleCreateTemplate(w http.ResponseWriter, r *http.Request) {
} }
func (a *API) handleListTemplates(w http.ResponseWriter, r *http.Request) { func (a *API) handleListTemplates(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
tpls, err := a.Store.ListTemplates(r.Context(), pid) tpls, err := a.Store.ListTemplates(r.Context(), pid)
if err != nil { if err != nil {
log.Printf("api: list templates failed: %v", err) log.Printf("api: list templates failed: %v", err)
@@ -188,11 +176,9 @@ func (a *API) handleListTemplates(w http.ResponseWriter, r *http.Request) {
} }
func (a *API) handleUpdateTemplate(w http.ResponseWriter, r *http.Request) { func (a *API) handleUpdateTemplate(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
tid, err := uuid.Parse(chi.URLParam(r, "tid")) tid, err := uuid.Parse(chi.URLParam(r, "tid"))
if err != nil { if err != nil {
writeErr(w, http.StatusBadRequest, "invalid template id") writeErr(w, http.StatusBadRequest, "invalid template id")
@@ -217,11 +203,9 @@ func (a *API) handleUpdateTemplate(w http.ResponseWriter, r *http.Request) {
} }
func (a *API) handleDeleteTemplate(w http.ResponseWriter, r *http.Request) { func (a *API) handleDeleteTemplate(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
tid, err := uuid.Parse(chi.URLParam(r, "tid")) tid, err := uuid.Parse(chi.URLParam(r, "tid"))
if err != nil { if err != nil {
writeErr(w, http.StatusBadRequest, "invalid template id") writeErr(w, http.StatusBadRequest, "invalid template id")
@@ -238,11 +222,9 @@ func (a *API) handleDeleteTemplate(w http.ResponseWriter, r *http.Request) {
// --- domains --- // --- domains ---
func (a *API) handleCreateDomain(w http.ResponseWriter, r *http.Request) { func (a *API) handleCreateDomain(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
var req domainRequest var req domainRequest
if !decodeBody(w, r, &req) { if !decodeBody(w, r, &req) {
return return
@@ -286,11 +268,9 @@ func (a *API) handleCreateDomain(w http.ResponseWriter, r *http.Request) {
} }
func (a *API) handleListDomains(w http.ResponseWriter, r *http.Request) { func (a *API) handleListDomains(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
doms, err := a.Store.ListDomains(r.Context(), pid) doms, err := a.Store.ListDomains(r.Context(), pid)
if err != nil { if err != nil {
log.Printf("api: list domains failed: %v", err) log.Printf("api: list domains failed: %v", err)
@@ -308,11 +288,9 @@ func (a *API) handleListDomains(w http.ResponseWriter, r *http.Request) {
// check/apply a domain — this is what makes an imported domain (which // check/apply a domain — this is what makes an imported domain (which
// starts with template_id=NULL) checkable, closing the import→check loop. // starts with template_id=NULL) checkable, closing the import→check loop.
func (a *API) handleSetDomainTemplate(w http.ResponseWriter, r *http.Request) { func (a *API) handleSetDomainTemplate(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
did, err := uuid.Parse(chi.URLParam(r, "did")) did, err := uuid.Parse(chi.URLParam(r, "did"))
if err != nil { if err != nil {
writeErr(w, http.StatusBadRequest, "invalid domain id") writeErr(w, http.StatusBadRequest, "invalid domain id")
@@ -339,11 +317,9 @@ func (a *API) handleSetDomainTemplate(w http.ResponseWriter, r *http.Request) {
} }
func (a *API) handleDeleteDomain(w http.ResponseWriter, r *http.Request) { func (a *API) handleDeleteDomain(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
did, err := uuid.Parse(chi.URLParam(r, "did")) did, err := uuid.Parse(chi.URLParam(r, "did"))
if err != nil { if err != nil {
writeErr(w, http.StatusBadRequest, "invalid domain id") writeErr(w, http.StatusBadRequest, "invalid domain id")
+29 -19
View File
@@ -132,7 +132,9 @@ func (m *mockTenantStore) ImportDomains(_ context.Context, projectID, accountID
type mockCipher struct{} type mockCipher struct{}
func (mockCipher) Encrypt(plaintext []byte) (string, error) { return "ENC(" + string(plaintext) + ")", nil } func (mockCipher) Encrypt(plaintext []byte) (string, error) {
return "ENC(" + string(plaintext) + ")", nil
}
func (mockCipher) Decrypt(enc string) ([]byte, error) { func (mockCipher) Decrypt(enc string) ([]byte, error) {
return []byte(strings.TrimSuffix(strings.TrimPrefix(enc, "ENC("), ")")), nil return []byte(strings.TrimSuffix(strings.TrimPrefix(enc, "ENC("), ")")), nil
} }
@@ -160,9 +162,17 @@ func (mockProvider) ApplyChanges(context.Context, provider.Credentials, string,
return nil return nil
} }
// newTenantTestAPI wires a fixed authenticated user who owns whatever
// project id is requested (alwaysOwnedAuthStore/alwaysValidSessions, see
// middleware_test.go) — these tests exercise CRUD behavior past the
// RequireAuth/RequireProjectAccess boundary, which has its own dedicated
// coverage in middleware_test.go.
func newTenantTestAPI() (*API, *mockTenantStore) { func newTenantTestAPI() (*API, *mockTenantStore) {
ts := &mockTenantStore{} ts := &mockTenantStore{}
a := &API{Store: ts, Cipher: mockCipher{}, Reg: &mockRegistry{}} a := &API{
Store: ts, Cipher: mockCipher{}, Reg: &mockRegistry{},
Auth: alwaysOwnedAuthStore(), Sessions: alwaysValidSessions(uuid.New()),
}
return a, ts return a, ts
} }
@@ -173,7 +183,7 @@ func TestCreateAccount_SecretEncryptedAndNotInResponse(t *testing.T) {
router := NewRouter(a) router := NewRouter(a)
body := `{"provider":"selectel","secret":"super-secret-token","comment":"prod"}` body := `{"provider":"selectel","secret":"super-secret-token","comment":"prod"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts", strings.NewReader(body)) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts", strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -211,7 +221,7 @@ func TestListAccounts_NoSecretsInResponse(t *testing.T) {
} }
router := NewRouter(a) router := NewRouter(a)
req := httptest.NewRequest(http.MethodGet, "/api/v1/projects/"+testPID+"/accounts", nil) req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+testPID+"/accounts", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -234,7 +244,7 @@ func TestDeleteAccount_BadUUID(t *testing.T) {
a, _ := newTenantTestAPI() a, _ := newTenantTestAPI()
router := NewRouter(a) router := NewRouter(a)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid", nil) req := requestWithSessionCookie(http.MethodDelete, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -250,7 +260,7 @@ func TestCreateTemplate_SavesRecords(t *testing.T) {
router := NewRouter(a) router := NewRouter(a)
body := `{"name":"base","records":[{"type":"A","name":"@","ttl":300,"values":["1.2.3.4"]}]}` body := `{"name":"base","records":[{"type":"A","name":"@","ttl":300,"values":["1.2.3.4"]}]}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/templates", strings.NewReader(body)) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/templates", strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -278,7 +288,7 @@ func TestUpdateTemplate_BadUUID(t *testing.T) {
router := NewRouter(a) router := NewRouter(a)
body := `{"name":"x","records":[]}` body := `{"name":"x","records":[]}`
req := httptest.NewRequest(http.MethodPut, "/api/v1/projects/"+testPID+"/templates/not-a-uuid", strings.NewReader(body)) req := requestWithSessionCookie(http.MethodPut, "/api/v1/projects/"+testPID+"/templates/not-a-uuid", strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -299,7 +309,7 @@ func TestImportZones_CreatesDomainPerZone(t *testing.T) {
}} }}
router := NewRouter(a) router := NewRouter(a)
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -337,7 +347,7 @@ func TestImportZones_AtomicRollbackOnError(t *testing.T) {
}} }}
router := NewRouter(a) router := NewRouter(a)
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -356,7 +366,7 @@ func TestImportZones_BadAccountUUID(t *testing.T) {
a, _ := newTenantTestAPI() a, _ := newTenantTestAPI()
router := NewRouter(a) router := NewRouter(a)
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid/import", nil) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid/import", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -370,7 +380,7 @@ func TestCreateDomain_BadProjectUUID(t *testing.T) {
router := NewRouter(a) router := NewRouter(a)
body := `{"providerAccountId":"` + uuid.New().String() + `","zoneName":"example.com","zoneId":"z1"}` body := `{"providerAccountId":"` + uuid.New().String() + `","zoneName":"example.com","zoneId":"z1"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/not-a-uuid/domains", strings.NewReader(body)) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/not-a-uuid/domains", strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -390,7 +400,7 @@ func TestCreateDomain_AccountNotFoundInProject(t *testing.T) {
// ts.accounts is empty, so GetAccount will not find this id. // ts.accounts is empty, so GetAccount will not find this id.
foreignAccID := uuid.New() foreignAccID := uuid.New()
body := `{"providerAccountId":"` + foreignAccID.String() + `","zoneName":"example.com","zoneId":"z1"}` body := `{"providerAccountId":"` + foreignAccID.String() + `","zoneName":"example.com","zoneId":"z1"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -413,7 +423,7 @@ func TestCreateDomain_TemplateNotFoundInProject(t *testing.T) {
foreignTplID := uuid.New() foreignTplID := uuid.New()
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + foreignTplID.String() + `"}` body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + foreignTplID.String() + `"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -434,7 +444,7 @@ func TestCreateDomain_HappyPath(t *testing.T) {
router := NewRouter(a) router := NewRouter(a)
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1"}` body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -464,7 +474,7 @@ func TestCreateDomain_ValidTemplateInProject(t *testing.T) {
router := NewRouter(a) router := NewRouter(a)
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + tplID.String() + `"}` body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + tplID.String() + `"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -490,7 +500,7 @@ func TestSetDomainTemplate_ValidTemplateId(t *testing.T) {
router := NewRouter(a) router := NewRouter(a)
body := `{"templateId":"` + tplID.String() + `"}` body := `{"templateId":"` + tplID.String() + `"}`
req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body)) req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -513,7 +523,7 @@ func TestSetDomainTemplate_BadTemplateUUID(t *testing.T) {
router := NewRouter(a) router := NewRouter(a)
body := `{"templateId":"not-a-uuid"}` body := `{"templateId":"not-a-uuid"}`
req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body)) req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -530,7 +540,7 @@ func TestSetDomainTemplate_TemplateNotFound(t *testing.T) {
router := NewRouter(a) router := NewRouter(a)
body := `{"templateId":"` + uuid.New().String() + `"}` body := `{"templateId":"` + uuid.New().String() + `"}`
req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body)) req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -543,7 +553,7 @@ func TestDeleteDomain_BadUUID(t *testing.T) {
a, _ := newTenantTestAPI() a, _ := newTenantTestAPI()
router := NewRouter(a) router := NewRouter(a)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/projects/"+testPID+"/domains/not-a-uuid", nil) req := requestWithSessionCookie(http.MethodDelete, "/api/v1/projects/"+testPID+"/domains/not-a-uuid", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
+9 -7
View File
@@ -21,7 +21,7 @@ type DomainRef struct {
} }
type Loader interface { type Loader interface {
LoadDomain(ctx context.Context, domainID uuid.UUID) (DomainRef, error) LoadDomain(ctx context.Context, projectID, domainID uuid.UUID) (DomainRef, error)
} }
type Recorder interface { type Recorder interface {
@@ -45,8 +45,10 @@ func New(loader Loader, rec Recorder, reg *registry.Registry, cipher *crypto.Cip
} }
// resolve loads the domain, its provider and decrypted credentials, and computes the diff. // resolve loads the domain, its provider and decrypted credentials, and computes the diff.
func (s *DomainService) resolve(ctx context.Context, domainID uuid.UUID) (provider.Provider, provider.Credentials, DomainRef, diff.Changeset, error) { // projectID scopes the lookup so a domainID belonging to another tenant's
ref, err := s.loader.LoadDomain(ctx, domainID) // project can never be resolved here (closes IDOR).
func (s *DomainService) resolve(ctx context.Context, projectID, domainID uuid.UUID) (provider.Provider, provider.Credentials, DomainRef, diff.Changeset, error) {
ref, err := s.loader.LoadDomain(ctx, projectID, domainID)
if err != nil { if err != nil {
return nil, provider.Credentials{}, ref, diff.Changeset{}, err return nil, provider.Credentials{}, ref, diff.Changeset{}, err
} }
@@ -68,8 +70,8 @@ func (s *DomainService) resolve(ctx context.Context, domainID uuid.UUID) (provid
} }
// Check computes and records the diff between template and zone. // Check computes and records the diff between template and zone.
func (s *DomainService) Check(ctx context.Context, domainID uuid.UUID) (diff.Changeset, error) { func (s *DomainService) Check(ctx context.Context, projectID, domainID uuid.UUID) (diff.Changeset, error) {
_, _, _, cs, err := s.resolve(ctx, domainID) _, _, _, cs, err := s.resolve(ctx, projectID, domainID)
if err != nil { if err != nil {
return diff.Changeset{}, err return diff.Changeset{}, err
} }
@@ -80,8 +82,8 @@ func (s *DomainService) Check(ctx context.Context, domainID uuid.UUID) (diff.Cha
} }
// Apply applies updates always (when ApplyUpdates) and prunes only when ApplyPrunes. // Apply applies updates always (when ApplyUpdates) and prunes only when ApplyPrunes.
func (s *DomainService) Apply(ctx context.Context, domainID uuid.UUID, req ApplyRequest) (diff.Changeset, error) { func (s *DomainService) Apply(ctx context.Context, projectID, domainID uuid.UUID, req ApplyRequest) (diff.Changeset, error) {
p, creds, ref, cs, err := s.resolve(ctx, domainID) p, creds, ref, cs, err := s.resolve(ctx, projectID, domainID)
if err != nil { if err != nil {
return diff.Changeset{}, err return diff.Changeset{}, err
} }
+6 -4
View File
@@ -44,7 +44,9 @@ func (f *fakeProvider) ApplyChanges(_ context.Context, _ provider.Credentials, _
type fakeLoader struct{ ref DomainRef } type fakeLoader struct{ ref DomainRef }
func (l fakeLoader) LoadDomain(context.Context, uuid.UUID) (DomainRef, error) { return l.ref, nil } func (l fakeLoader) LoadDomain(context.Context, uuid.UUID, uuid.UUID) (DomainRef, error) {
return l.ref, nil
}
type nopRecorder struct{} type nopRecorder struct{}
@@ -66,7 +68,7 @@ func TestCheckProducesDiff(t *testing.T) {
{Type: "A", Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}}, // update {Type: "A", Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}}, // update
}} }}
svc, _ := setup(t, actual, tmpl) svc, _ := setup(t, actual, tmpl)
cs, err := svc.Check(context.Background(), uuid.New()) cs, err := svc.Check(context.Background(), uuid.New(), uuid.New())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -87,7 +89,7 @@ func TestApplyRespectsPruneGuard(t *testing.T) {
// applyPrunes=false → удаление b НЕ применяется // applyPrunes=false → удаление b НЕ применяется
svc, fp := setup(t, actual, tmpl) svc, fp := setup(t, actual, tmpl)
if _, err := svc.Apply(context.Background(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: false}); err != nil { if _, err := svc.Apply(context.Background(), uuid.New(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: false}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
for _, d := range fp.applied.Diffs { for _, d := range fp.applied.Diffs {
@@ -98,7 +100,7 @@ func TestApplyRespectsPruneGuard(t *testing.T) {
// applyPrunes=true → удаление b применяется // applyPrunes=true → удаление b применяется
svc2, fp2 := setup(t, actual, tmpl) svc2, fp2 := setup(t, actual, tmpl)
if _, err := svc2.Apply(context.Background(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: true}); err != nil { if _, err := svc2.Apply(context.Background(), uuid.New(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: true}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
var sawDelete bool var sawDelete bool
+8 -3
View File
@@ -162,9 +162,14 @@ SELECT d.zone_id, a.provider, a.secret_enc, t.doc
FROM domains d FROM domains d
JOIN provider_accounts a ON a.id = d.provider_account_id JOIN provider_accounts a ON a.id = d.provider_account_id
LEFT JOIN templates t ON t.id = d.template_id LEFT JOIN templates t ON t.id = d.template_id
WHERE d.id = $1 WHERE d.id = $1 AND d.project_id = $2
` `
type LoadDomainFullParams struct {
ID uuid.UUID `json:"id"`
ProjectID uuid.UUID `json:"project_id"`
}
type LoadDomainFullRow struct { type LoadDomainFullRow struct {
ZoneID string `json:"zone_id"` ZoneID string `json:"zone_id"`
Provider string `json:"provider"` Provider string `json:"provider"`
@@ -172,8 +177,8 @@ type LoadDomainFullRow struct {
Doc *dto.TemplateDoc `json:"doc"` Doc *dto.TemplateDoc `json:"doc"`
} }
func (q *Queries) LoadDomainFull(ctx context.Context, id uuid.UUID) (LoadDomainFullRow, error) { func (q *Queries) LoadDomainFull(ctx context.Context, arg LoadDomainFullParams) (LoadDomainFullRow, error) {
row := q.db.QueryRow(ctx, loadDomainFull, id) row := q.db.QueryRow(ctx, loadDomainFull, arg.ID, arg.ProjectID)
var i LoadDomainFullRow var i LoadDomainFullRow
err := row.Scan( err := row.Scan(
&i.ZoneID, &i.ZoneID,
+5 -3
View File
@@ -13,9 +13,11 @@ import (
) )
// LoadDomain joins domains+provider_accounts+templates to build the // LoadDomain joins domains+provider_accounts+templates to build the
// service.DomainRef needed to check/apply a domain's DNS records. // service.DomainRef needed to check/apply a domain's DNS records. Scoped by
func (s *Store) LoadDomain(ctx context.Context, domainID uuid.UUID) (service.DomainRef, error) { // projectID so a domain belonging to another tenant's project can never be
row, err := s.q.LoadDomainFull(ctx, domainID) // loaded, even if its domainID is guessed/leaked (closes IDOR).
func (s *Store) LoadDomain(ctx context.Context, projectID, domainID uuid.UUID) (service.DomainRef, error) {
row, err := s.q.LoadDomainFull(ctx, db.LoadDomainFullParams{ID: domainID, ProjectID: projectID})
if err != nil { if err != nil {
return service.DomainRef{}, err return service.DomainRef{}, err
} }
+2 -2
View File
@@ -40,7 +40,7 @@ func TestLoadDomainAndSaveCheckRun(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
ref, err := s.LoadDomain(ctx, domain.ID) ref, err := s.LoadDomain(ctx, defaultProject, domain.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -87,7 +87,7 @@ func TestLoadDomainNoTemplate(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if _, err := s.LoadDomain(ctx, domain.ID); err == nil { if _, err := s.LoadDomain(ctx, defaultProject, domain.ID); err == nil {
t.Fatal("expected error for domain without template, got nil") t.Fatal("expected error for domain without template, got nil")
} }
} }
+1 -1
View File
@@ -27,4 +27,4 @@ SELECT d.zone_id, a.provider, a.secret_enc, t.doc
FROM domains d FROM domains d
JOIN provider_accounts a ON a.id = d.provider_account_id JOIN provider_accounts a ON a.id = d.provider_account_id
LEFT JOIN templates t ON t.id = d.template_id LEFT JOIN templates t ON t.id = d.template_id
WHERE d.id = $1; WHERE d.id = $1 AND d.project_id = $2;
+2 -2
View File
@@ -214,7 +214,7 @@ func TestSetDomainTemplate_ClosesImportCheckLoop(t *testing.T) {
dom := doms[0] dom := doms[0]
// Before binding, the domain is not checkable. // Before binding, the domain is not checkable.
if _, err := s.LoadDomain(ctx, dom.ID); err == nil { if _, err := s.LoadDomain(ctx, defaultProject, dom.ID); err == nil {
t.Fatal("expected LoadDomain to fail before a template is bound") t.Fatal("expected LoadDomain to fail before a template is bound")
} }
@@ -234,7 +234,7 @@ func TestSetDomainTemplate_ClosesImportCheckLoop(t *testing.T) {
t.Fatalf("expected domain.TemplateID=%s, got %+v", tpl.ID, updated.TemplateID) t.Fatalf("expected domain.TemplateID=%s, got %+v", tpl.ID, updated.TemplateID)
} }
ref, err := s.LoadDomain(ctx, dom.ID) ref, err := s.LoadDomain(ctx, defaultProject, dom.ID)
if err != nil { if err != nil {
t.Fatalf("expected LoadDomain to succeed after binding template, got error: %v", err) t.Fatalf("expected LoadDomain to succeed after binding template, got error: %v", err)
} }