feat(api): RequireAuth+RequireProjectAccess middleware, IDOR-scope check/apply по projectID
This commit is contained in:
+14
-4
@@ -18,8 +18,8 @@ import (
|
||||
|
||||
// CheckApplier is the service surface the API depends on.
|
||||
type CheckApplier interface {
|
||||
Check(ctx context.Context, domainID uuid.UUID) (diff.Changeset, error)
|
||||
Apply(ctx context.Context, domainID uuid.UUID, req service.ApplyRequest) (diff.Changeset, error)
|
||||
Check(ctx context.Context, projectID, domainID uuid.UUID) (diff.Changeset, error)
|
||||
Apply(ctx context.Context, projectID, domainID uuid.UUID, req service.ApplyRequest) (diff.Changeset, error)
|
||||
}
|
||||
|
||||
// TenantStore is the narrow persistence surface the CRUD handlers depend on.
|
||||
@@ -63,6 +63,10 @@ type AuthStore interface {
|
||||
GetUserByEmail(ctx context.Context, email string) (store.User, error)
|
||||
GetUserByID(ctx context.Context, userID uuid.UUID) (store.User, error)
|
||||
GetUserProject(ctx context.Context, userID uuid.UUID) (store.Project, error)
|
||||
// GetProjectOwned looks up projectID and returns it only if it's owned by
|
||||
// userID — RequireProjectAccess uses this to reject foreign/nonexistent
|
||||
// projects with 404 before any handler runs.
|
||||
GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error)
|
||||
}
|
||||
|
||||
// SessionManager creates/validates/destroys login sessions. *auth.Sessions
|
||||
@@ -91,11 +95,17 @@ func NewRouter(a *API) http.Handler {
|
||||
r.Route("/api/v1/auth", func(r chi.Router) {
|
||||
r.Post("/register", a.handleRegister)
|
||||
r.Post("/login", a.handleLogin)
|
||||
r.Post("/logout", a.handleLogout) // защитится RequireAuth в Task 4
|
||||
r.Get("/me", a.handleMe) // защитится RequireAuth в Task 4
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(a.RequireAuth)
|
||||
r.Post("/logout", a.handleLogout)
|
||||
r.Get("/me", a.handleMe)
|
||||
})
|
||||
})
|
||||
|
||||
r.Route("/api/v1/projects/{pid}", func(r chi.Router) {
|
||||
r.Use(a.RequireAuth)
|
||||
r.Use(a.RequireProjectAccess)
|
||||
|
||||
r.Route("/domains", func(r chi.Router) {
|
||||
r.Post("/", a.handleCreateDomain)
|
||||
r.Get("/", a.handleListDomains)
|
||||
|
||||
@@ -20,18 +20,23 @@ type mockCheckApplier struct {
|
||||
lastReq service.ApplyRequest
|
||||
}
|
||||
|
||||
func (m *mockCheckApplier) Check(context.Context, uuid.UUID) (diff.Changeset, error) {
|
||||
func (m *mockCheckApplier) Check(context.Context, uuid.UUID, uuid.UUID) (diff.Changeset, error) {
|
||||
d := model.Record{Type: model.A, Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}}
|
||||
return diff.Changeset{Diffs: []diff.RecordDiff{{Kind: diff.Add, Type: d.Type, Name: d.Name, Desired: &d}}}, nil
|
||||
}
|
||||
func (m *mockCheckApplier) Apply(_ context.Context, _ uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) {
|
||||
func (m *mockCheckApplier) Apply(_ context.Context, _, _ uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) {
|
||||
m.lastReq = req
|
||||
return diff.Changeset{}, nil
|
||||
}
|
||||
|
||||
// newTestAPI wires a fixed authenticated user who owns whatever project id
|
||||
// is requested (via alwaysOwnedAuthStore/alwaysValidSessions in
|
||||
// middleware_test.go) — these tests exercise check/apply behavior past the
|
||||
// RequireAuth/RequireProjectAccess boundary, which is covered separately by
|
||||
// middleware_test.go's own tests and the IDOR regression.
|
||||
func newTestAPI() (*API, *mockCheckApplier) {
|
||||
m := &mockCheckApplier{}
|
||||
return &API{Svc: m}, m // остальные зависимости (store/cipher) nil — CRUD-тесты добавит реализатор
|
||||
return &API{Svc: m, Auth: alwaysOwnedAuthStore(), Sessions: alwaysValidSessions(uuid.New())}, m
|
||||
}
|
||||
|
||||
func TestCheckEndpoint(t *testing.T) {
|
||||
@@ -39,7 +44,7 @@ func TestCheckEndpoint(t *testing.T) {
|
||||
router := NewRouter(a)
|
||||
|
||||
did := uuid.New().String()
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
req := requestWithSessionCookie(http.MethodGet,
|
||||
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/check", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
@@ -62,7 +67,7 @@ func TestApplyDefaultsPruneFalse(t *testing.T) {
|
||||
|
||||
did := uuid.New().String()
|
||||
body := `{"applyUpdates":true}` // applyPrunes отсутствует → false
|
||||
req := httptest.NewRequest(http.MethodPost,
|
||||
req := requestWithSessionCookie(http.MethodPost,
|
||||
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply",
|
||||
strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
@@ -81,7 +86,7 @@ func TestApplyEmptyBodyOK(t *testing.T) {
|
||||
router := NewRouter(a)
|
||||
|
||||
did := uuid.New().String()
|
||||
req := httptest.NewRequest(http.MethodPost,
|
||||
req := requestWithSessionCookie(http.MethodPost,
|
||||
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
@@ -100,7 +105,7 @@ func TestApplyMalformedBody(t *testing.T) {
|
||||
|
||||
did := uuid.New().String()
|
||||
body := `{"applyUpdates":`
|
||||
req := httptest.NewRequest(http.MethodPost,
|
||||
req := requestWithSessionCookie(http.MethodPost,
|
||||
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply",
|
||||
strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
@@ -114,7 +119,7 @@ func TestApplyMalformedBody(t *testing.T) {
|
||||
func TestApplyBadUUID(t *testing.T) {
|
||||
a, _ := newTestAPI()
|
||||
router := NewRouter(a)
|
||||
req := httptest.NewRequest(http.MethodPost,
|
||||
req := requestWithSessionCookie(http.MethodPost,
|
||||
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/not-a-uuid/apply",
|
||||
bytes.NewReader([]byte(`{}`)))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/auth"
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/store"
|
||||
)
|
||||
@@ -39,19 +36,6 @@ func normalizeEmail(email string) string {
|
||||
return strings.ToLower(strings.TrimSpace(email))
|
||||
}
|
||||
|
||||
// ctxKeyUserID is a private context key carrying the authenticated user's ID.
|
||||
// Task 4's RequireAuth middleware sets it after validating the session
|
||||
// cookie; handleMe reads it back.
|
||||
type ctxKeyUserID struct{}
|
||||
|
||||
// userIDFromContext extracts the authenticated user ID set by RequireAuth
|
||||
// (Task 4). Until that middleware is wired in, tests set it directly via
|
||||
// context.WithValue.
|
||||
func userIDFromContext(ctx context.Context) (uuid.UUID, bool) {
|
||||
id, ok := ctx.Value(ctxKeyUserID{}).(uuid.UUID)
|
||||
return id, ok
|
||||
}
|
||||
|
||||
func setSessionCookie(w http.ResponseWriter, token string, exp time.Time) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sessionCookieName, Value: token, Path: "/",
|
||||
@@ -165,11 +149,10 @@ func (a *API) handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// handleMe returns the authenticated caller's identity + default project.
|
||||
// The user ID comes from the request context, set by Task 4's RequireAuth
|
||||
// middleware after validating the session cookie (tests set it directly via
|
||||
// context.WithValue in the interim).
|
||||
// The user ID comes from the request context, set by RequireAuth after
|
||||
// validating the session cookie.
|
||||
func (a *API) handleMe(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := userIDFromContext(r.Context())
|
||||
userID, ok := userIDFrom(r.Context())
|
||||
if !ok {
|
||||
writeErr(w, http.StatusUnauthorized, "authentication required")
|
||||
return
|
||||
|
||||
@@ -18,10 +18,11 @@ import (
|
||||
// --- mocks ---
|
||||
|
||||
type mockAuthStore struct {
|
||||
registerUserFn func(ctx context.Context, email, passwordHash string) (store.User, store.Project, error)
|
||||
getUserByEmailFn func(ctx context.Context, email string) (store.User, error)
|
||||
getUserByIDFn func(ctx context.Context, userID uuid.UUID) (store.User, error)
|
||||
getUserProjectFn func(ctx context.Context, userID uuid.UUID) (store.Project, error)
|
||||
registerUserFn func(ctx context.Context, email, passwordHash string) (store.User, store.Project, error)
|
||||
getUserByEmailFn func(ctx context.Context, email string) (store.User, error)
|
||||
getUserByIDFn func(ctx context.Context, userID uuid.UUID) (store.User, error)
|
||||
getUserProjectFn func(ctx context.Context, userID uuid.UUID) (store.Project, error)
|
||||
getProjectOwnedFn func(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error)
|
||||
}
|
||||
|
||||
func (m *mockAuthStore) RegisterUser(ctx context.Context, email, passwordHash string) (store.User, store.Project, error) {
|
||||
@@ -40,8 +41,13 @@ func (m *mockAuthStore) GetUserProject(ctx context.Context, userID uuid.UUID) (s
|
||||
return m.getUserProjectFn(ctx, userID)
|
||||
}
|
||||
|
||||
func (m *mockAuthStore) GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) {
|
||||
return m.getProjectOwnedFn(ctx, projectID, userID)
|
||||
}
|
||||
|
||||
type mockSessionManager struct {
|
||||
createFn func(ctx context.Context, userID uuid.UUID) (string, time.Time, error)
|
||||
createFn func(ctx context.Context, userID uuid.UUID) (string, time.Time, error)
|
||||
validateFn func(ctx context.Context, token string) (uuid.UUID, error)
|
||||
|
||||
destroyCalled bool
|
||||
destroyToken string
|
||||
@@ -52,7 +58,10 @@ func (m *mockSessionManager) Create(ctx context.Context, userID uuid.UUID) (stri
|
||||
return m.createFn(ctx, userID)
|
||||
}
|
||||
|
||||
func (m *mockSessionManager) Validate(context.Context, string) (uuid.UUID, error) {
|
||||
func (m *mockSessionManager) Validate(ctx context.Context, token string) (uuid.UUID, error) {
|
||||
if m.validateFn != nil {
|
||||
return m.validateFn(ctx, token)
|
||||
}
|
||||
return uuid.Nil, nil
|
||||
}
|
||||
|
||||
@@ -338,7 +347,7 @@ func TestAuthLogout_ClearsSessionAndDestroys(t *testing.T) {
|
||||
// now resolves the authenticated user via GetUserByID and returns their real
|
||||
// email, instead of leaving it blank.
|
||||
func TestAuthMe_ReturnsRealEmail(t *testing.T) {
|
||||
a, authStore, _ := newTestAuthAPI()
|
||||
a, authStore, sessions := newTestAuthAPI()
|
||||
userID := uuid.New()
|
||||
projectID := uuid.New()
|
||||
authStore.getUserByIDFn = func(_ context.Context, id uuid.UUID) (store.User, error) {
|
||||
@@ -350,10 +359,15 @@ func TestAuthMe_ReturnsRealEmail(t *testing.T) {
|
||||
authStore.getUserProjectFn = func(_ context.Context, uid uuid.UUID) (store.Project, error) {
|
||||
return store.Project{ID: projectID, UserID: uid, Name: "default"}, nil
|
||||
}
|
||||
// /me is now behind RequireAuth: the session cookie must resolve to
|
||||
// userID via Sessions.Validate rather than being injected directly.
|
||||
sessions.validateFn = func(context.Context, string) (uuid.UUID, error) {
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
router := NewRouter(a)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil)
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxKeyUserID{}, userID))
|
||||
req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "some-valid-token"})
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
|
||||
@@ -24,12 +24,15 @@ func writeErr(w http.ResponseWriter, status int, msg string) {
|
||||
}
|
||||
|
||||
func (a *API) handleCheck(w http.ResponseWriter, r *http.Request) {
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
did, err := uuid.Parse(chi.URLParam(r, "did"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid domain id")
|
||||
return
|
||||
}
|
||||
cs, err := a.Svc.Check(r.Context(), did)
|
||||
cs, err := a.Svc.Check(r.Context(), pid, did)
|
||||
if err != nil {
|
||||
log.Printf("api: check failed: %v", err)
|
||||
writeErr(w, http.StatusInternalServerError, "internal error")
|
||||
@@ -39,6 +42,9 @@ func (a *API) handleCheck(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (a *API) handleApply(w http.ResponseWriter, r *http.Request) {
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
did, err := uuid.Parse(chi.URLParam(r, "did"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid domain id")
|
||||
@@ -54,7 +60,7 @@ func (a *API) handleApply(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
}
|
||||
cs, err := a.Svc.Apply(r.Context(), did, service.ApplyRequest{
|
||||
cs, err := a.Svc.Apply(r.Context(), pid, did, service.ApplyRequest{
|
||||
ApplyUpdates: req.ApplyUpdates, ApplyPrunes: req.ApplyPrunes,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type ctxKey string
|
||||
|
||||
const (
|
||||
ctxUserID ctxKey = "userID"
|
||||
ctxProjectID ctxKey = "projectID"
|
||||
)
|
||||
|
||||
// RequireAuth validates the session cookie and, on success, stores the
|
||||
// authenticated user's ID in the request context (see userIDFrom). Any
|
||||
// failure — missing cookie or an invalid/expired token — is rejected with
|
||||
// 401 before reaching the wrapped handler.
|
||||
func (a *API) RequireAuth(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusUnauthorized, "unauthorized")
|
||||
return
|
||||
}
|
||||
uid, err := a.Sessions.Validate(r.Context(), c.Value)
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusUnauthorized, "unauthorized")
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), ctxUserID, uid)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// RequireProjectAccess verifies that the {pid} URL segment names a project
|
||||
// owned by the authenticated user (set by RequireAuth, which must run
|
||||
// first) and, on success, stores the project ID in the request context (see
|
||||
// projectIDFrom). A project that doesn't exist or isn't owned by the caller
|
||||
// is rejected with 404 — not 403 — so a caller can't distinguish "not yours"
|
||||
// from "doesn't exist" (closes IDOR by never confirming another tenant's
|
||||
// project exists).
|
||||
func (a *API) RequireProjectAccess(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
uid, ok := userIDFrom(r.Context())
|
||||
if !ok {
|
||||
writeErr(w, http.StatusUnauthorized, "unauthorized")
|
||||
return
|
||||
}
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
if _, err := a.Auth.GetProjectOwned(r.Context(), pid, uid); err != nil {
|
||||
writeErr(w, http.StatusNotFound, "not found")
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), ctxProjectID, pid)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func userIDFrom(ctx context.Context) (uuid.UUID, bool) {
|
||||
v, ok := ctx.Value(ctxUserID).(uuid.UUID)
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func projectIDFrom(ctx context.Context) (uuid.UUID, bool) {
|
||||
v, ok := ctx.Value(ctxProjectID).(uuid.UUID)
|
||||
return v, ok
|
||||
}
|
||||
@@ -0,0 +1,265 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/diff"
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/service"
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/store"
|
||||
)
|
||||
|
||||
// --- shared test doubles (also used by api_test.go / tenant_test.go) ---
|
||||
|
||||
// stubSessions is a configurable SessionManager test double.
|
||||
type stubSessions struct {
|
||||
validateFn func(ctx context.Context, token string) (uuid.UUID, error)
|
||||
}
|
||||
|
||||
func (s stubSessions) Create(context.Context, uuid.UUID) (string, time.Time, error) {
|
||||
return "stub-token", time.Now().Add(time.Hour), nil
|
||||
}
|
||||
func (s stubSessions) Validate(ctx context.Context, token string) (uuid.UUID, error) {
|
||||
return s.validateFn(ctx, token)
|
||||
}
|
||||
func (s stubSessions) Destroy(context.Context, string) error { return nil }
|
||||
|
||||
// alwaysValidSessions builds a stubSessions whose Validate always succeeds
|
||||
// with the given fixed user ID — for tests that only care about behavior
|
||||
// past the RequireAuth boundary.
|
||||
func alwaysValidSessions(uid uuid.UUID) stubSessions {
|
||||
return stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) { return uid, nil }}
|
||||
}
|
||||
|
||||
// stubAuthStore is a configurable AuthStore test double. Methods besides
|
||||
// GetProjectOwned return zero values — tests exercising register/login/me
|
||||
// use their own dedicated mock (mockAuthStore in auth_test.go).
|
||||
type stubAuthStore struct {
|
||||
getProjectOwnedFn func(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error)
|
||||
}
|
||||
|
||||
func (s stubAuthStore) RegisterUser(context.Context, string, string) (store.User, store.Project, error) {
|
||||
return store.User{}, store.Project{}, nil
|
||||
}
|
||||
func (s stubAuthStore) GetUserByEmail(context.Context, string) (store.User, error) {
|
||||
return store.User{}, nil
|
||||
}
|
||||
func (s stubAuthStore) GetUserByID(context.Context, uuid.UUID) (store.User, error) {
|
||||
return store.User{}, nil
|
||||
}
|
||||
func (s stubAuthStore) GetUserProject(context.Context, uuid.UUID) (store.Project, error) {
|
||||
return store.Project{}, nil
|
||||
}
|
||||
func (s stubAuthStore) GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) {
|
||||
return s.getProjectOwnedFn(ctx, projectID, userID)
|
||||
}
|
||||
|
||||
// alwaysOwnedAuthStore builds a stubAuthStore whose GetProjectOwned always
|
||||
// succeeds, treating the caller as the owner of whatever project id is
|
||||
// requested — for tests that only care about behavior past the
|
||||
// RequireProjectAccess boundary (CRUD/check/apply tests).
|
||||
func alwaysOwnedAuthStore() stubAuthStore {
|
||||
return stubAuthStore{getProjectOwnedFn: func(_ context.Context, pid, uid uuid.UUID) (store.Project, error) {
|
||||
return store.Project{ID: pid, UserID: uid}, nil
|
||||
}}
|
||||
}
|
||||
|
||||
// requestWithSessionCookie builds an httptest request carrying a session
|
||||
// cookie so it clears RequireAuth in tests using stubSessions/alwaysValidSessions
|
||||
// (which ignore the token value and validate unconditionally).
|
||||
func requestWithSessionCookie(method, url string, body io.Reader) *http.Request {
|
||||
req := httptest.NewRequest(method, url, body)
|
||||
req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "test-session-token"})
|
||||
return req
|
||||
}
|
||||
|
||||
// recordingCheckApplier is a CheckApplier test double that records whether
|
||||
// Check/Apply were invoked — used by the IDOR regression test to assert the
|
||||
// service layer is never reached for a project the caller doesn't own.
|
||||
type recordingCheckApplier struct {
|
||||
checkCalled bool
|
||||
applyCalled bool
|
||||
}
|
||||
|
||||
func (r *recordingCheckApplier) Check(context.Context, uuid.UUID, uuid.UUID) (diff.Changeset, error) {
|
||||
r.checkCalled = true
|
||||
return diff.Changeset{}, nil
|
||||
}
|
||||
func (r *recordingCheckApplier) Apply(context.Context, uuid.UUID, uuid.UUID, service.ApplyRequest) (diff.Changeset, error) {
|
||||
r.applyCalled = true
|
||||
return diff.Changeset{}, nil
|
||||
}
|
||||
|
||||
// --- RequireAuth ---
|
||||
|
||||
func TestRequireAuth_NoCookie_Returns401(t *testing.T) {
|
||||
a := &API{Sessions: stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) {
|
||||
t.Fatal("Validate must not be called when there is no cookie")
|
||||
return uuid.Nil, nil
|
||||
}}}
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/whatever", nil)
|
||||
w := httptest.NewRecorder()
|
||||
a.RequireAuth(next).ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d", w.Code)
|
||||
}
|
||||
if nextCalled {
|
||||
t.Fatal("next must not be called without a session cookie")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuth_InvalidToken_Returns401(t *testing.T) {
|
||||
a := &API{Sessions: stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) {
|
||||
return uuid.Nil, errors.New("invalid or expired session")
|
||||
}}}
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true })
|
||||
|
||||
req := requestWithSessionCookie(http.MethodGet, "/whatever", nil)
|
||||
w := httptest.NewRecorder()
|
||||
a.RequireAuth(next).ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d", w.Code)
|
||||
}
|
||||
if nextCalled {
|
||||
t.Fatal("next must not be called when Sessions.Validate fails")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuth_Success_CallsNextWithUserID(t *testing.T) {
|
||||
uid := uuid.New()
|
||||
a := &API{Sessions: alwaysValidSessions(uid)}
|
||||
|
||||
var gotUID uuid.UUID
|
||||
var gotOK bool
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotUID, gotOK = userIDFrom(r.Context())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
req := requestWithSessionCookie(http.MethodGet, "/whatever", nil)
|
||||
w := httptest.NewRecorder()
|
||||
a.RequireAuth(next).ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 (next called), got %d", w.Code)
|
||||
}
|
||||
if !gotOK || gotUID != uid {
|
||||
t.Fatalf("expected userIDFrom to yield %s, got %s (ok=%v)", uid, gotUID, gotOK)
|
||||
}
|
||||
}
|
||||
|
||||
// --- RequireProjectAccess ---
|
||||
|
||||
func TestRequireProjectAccess_NotOwned_Returns404(t *testing.T) {
|
||||
a := &API{Auth: stubAuthStore{getProjectOwnedFn: func(context.Context, uuid.UUID, uuid.UUID) (store.Project, error) {
|
||||
return store.Project{}, errors.New("no rows")
|
||||
}}}
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true })
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.With(a.RequireProjectAccess).Get("/projects/{pid}", next.ServeHTTP)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/projects/"+uuid.New().String(), nil)
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxUserID, uuid.New()))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("expected 404 for an unowned/foreign project, got %d body %s", w.Code, w.Body.String())
|
||||
}
|
||||
if nextCalled {
|
||||
t.Fatal("next must not be called when the project isn't owned by the caller")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireProjectAccess_Success_CallsNextWithProjectID(t *testing.T) {
|
||||
a := &API{Auth: alwaysOwnedAuthStore()}
|
||||
pid := uuid.New()
|
||||
uid := uuid.New()
|
||||
|
||||
var gotPID uuid.UUID
|
||||
var gotOK bool
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPID, gotOK = projectIDFrom(r.Context())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.With(a.RequireProjectAccess).Get("/projects/{pid}", next.ServeHTTP)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/projects/"+pid.String(), nil)
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxUserID, uid))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 (next called), got %d body %s", w.Code, w.Body.String())
|
||||
}
|
||||
if !gotOK || gotPID != pid {
|
||||
t.Fatalf("expected projectIDFrom to yield %s, got %s (ok=%v)", pid, gotPID, gotOK)
|
||||
}
|
||||
}
|
||||
|
||||
// --- IDOR regression ---
|
||||
|
||||
// TestIDOR_CheckForeignProject_Returns404AndServiceNotCalled is the IDOR
|
||||
// regression required by Task 4: a request for project B's domain, made
|
||||
// with user A's session (who owns project A, not B), must be rejected by
|
||||
// RequireProjectAccess with 404 before service.Check is ever invoked — a
|
||||
// caller must never be able to check/apply another tenant's domain by
|
||||
// guessing/leaking its ids.
|
||||
func TestIDOR_CheckForeignProject_Returns404AndServiceNotCalled(t *testing.T) {
|
||||
userA := uuid.New()
|
||||
pidA := uuid.New()
|
||||
pidB := uuid.New() // owned by a different user; not A
|
||||
|
||||
svc := &recordingCheckApplier{}
|
||||
auth := stubAuthStore{getProjectOwnedFn: func(_ context.Context, pid, uid uuid.UUID) (store.Project, error) {
|
||||
if pid == pidA && uid == userA {
|
||||
return store.Project{ID: pidA, UserID: userA}, nil
|
||||
}
|
||||
return store.Project{}, errors.New("not found")
|
||||
}}
|
||||
a := &API{Svc: svc, Auth: auth, Sessions: alwaysValidSessions(userA)}
|
||||
router := NewRouter(a)
|
||||
|
||||
did := uuid.New().String()
|
||||
req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+pidB.String()+"/domains/"+did+"/check", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("expected 404 for user A requesting project B's domain, got %d body %s", w.Code, w.Body.String())
|
||||
}
|
||||
if svc.checkCalled {
|
||||
t.Fatal("service.Check must not be called when the caller doesn't own the project")
|
||||
}
|
||||
|
||||
// Sanity check: the same user against their own project succeeds and
|
||||
// does reach the service — proving the 404 above is really about
|
||||
// project ownership, not e.g. a broken route.
|
||||
req2 := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+pidA.String()+"/domains/"+did+"/check", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w2, req2)
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 for user A requesting their own project, got %d body %s", w2.Code, w2.Body.String())
|
||||
}
|
||||
if !svc.checkCalled {
|
||||
t.Fatal("expected service.Check to be called for the caller's own project")
|
||||
}
|
||||
}
|
||||
@@ -24,11 +24,9 @@ func decodeBody(w http.ResponseWriter, r *http.Request, v any) bool {
|
||||
// --- accounts ---
|
||||
|
||||
func (a *API) handleCreateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
var req accountRequest
|
||||
if !decodeBody(w, r, &req) {
|
||||
return
|
||||
@@ -53,11 +51,9 @@ func (a *API) handleCreateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (a *API) handleListAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
accs, err := a.Store.ListAccounts(r.Context(), pid)
|
||||
if err != nil {
|
||||
log.Printf("api: list accounts failed: %v", err)
|
||||
@@ -72,11 +68,9 @@ func (a *API) handleListAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (a *API) handleDeleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
aid, err := uuid.Parse(chi.URLParam(r, "aid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid account id")
|
||||
@@ -93,11 +87,9 @@ func (a *API) handleDeleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
// handleImportZones lists zones from the provider for the given account and
|
||||
// creates one domain per zone (template_id left unset).
|
||||
func (a *API) handleImportZones(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
aid, err := uuid.Parse(chi.URLParam(r, "aid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid account id")
|
||||
@@ -145,11 +137,9 @@ func (a *API) handleImportZones(w http.ResponseWriter, r *http.Request) {
|
||||
// --- templates ---
|
||||
|
||||
func (a *API) handleCreateTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
var req templateRequest
|
||||
if !decodeBody(w, r, &req) {
|
||||
return
|
||||
@@ -169,11 +159,9 @@ func (a *API) handleCreateTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (a *API) handleListTemplates(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
tpls, err := a.Store.ListTemplates(r.Context(), pid)
|
||||
if err != nil {
|
||||
log.Printf("api: list templates failed: %v", err)
|
||||
@@ -188,11 +176,9 @@ func (a *API) handleListTemplates(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (a *API) handleUpdateTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
tid, err := uuid.Parse(chi.URLParam(r, "tid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid template id")
|
||||
@@ -217,11 +203,9 @@ func (a *API) handleUpdateTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (a *API) handleDeleteTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
tid, err := uuid.Parse(chi.URLParam(r, "tid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid template id")
|
||||
@@ -238,11 +222,9 @@ func (a *API) handleDeleteTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
// --- domains ---
|
||||
|
||||
func (a *API) handleCreateDomain(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
var req domainRequest
|
||||
if !decodeBody(w, r, &req) {
|
||||
return
|
||||
@@ -286,11 +268,9 @@ func (a *API) handleCreateDomain(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (a *API) handleListDomains(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
doms, err := a.Store.ListDomains(r.Context(), pid)
|
||||
if err != nil {
|
||||
log.Printf("api: list domains failed: %v", err)
|
||||
@@ -308,11 +288,9 @@ func (a *API) handleListDomains(w http.ResponseWriter, r *http.Request) {
|
||||
// check/apply a domain — this is what makes an imported domain (which
|
||||
// starts with template_id=NULL) checkable, closing the import→check loop.
|
||||
func (a *API) handleSetDomainTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
did, err := uuid.Parse(chi.URLParam(r, "did"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid domain id")
|
||||
@@ -339,11 +317,9 @@ func (a *API) handleSetDomainTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (a *API) handleDeleteDomain(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
did, err := uuid.Parse(chi.URLParam(r, "did"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid domain id")
|
||||
|
||||
+29
-19
@@ -132,7 +132,9 @@ func (m *mockTenantStore) ImportDomains(_ context.Context, projectID, accountID
|
||||
|
||||
type mockCipher struct{}
|
||||
|
||||
func (mockCipher) Encrypt(plaintext []byte) (string, error) { return "ENC(" + string(plaintext) + ")", nil }
|
||||
func (mockCipher) Encrypt(plaintext []byte) (string, error) {
|
||||
return "ENC(" + string(plaintext) + ")", nil
|
||||
}
|
||||
func (mockCipher) Decrypt(enc string) ([]byte, error) {
|
||||
return []byte(strings.TrimSuffix(strings.TrimPrefix(enc, "ENC("), ")")), nil
|
||||
}
|
||||
@@ -160,9 +162,17 @@ func (mockProvider) ApplyChanges(context.Context, provider.Credentials, string,
|
||||
return nil
|
||||
}
|
||||
|
||||
// newTenantTestAPI wires a fixed authenticated user who owns whatever
|
||||
// project id is requested (alwaysOwnedAuthStore/alwaysValidSessions, see
|
||||
// middleware_test.go) — these tests exercise CRUD behavior past the
|
||||
// RequireAuth/RequireProjectAccess boundary, which has its own dedicated
|
||||
// coverage in middleware_test.go.
|
||||
func newTenantTestAPI() (*API, *mockTenantStore) {
|
||||
ts := &mockTenantStore{}
|
||||
a := &API{Store: ts, Cipher: mockCipher{}, Reg: &mockRegistry{}}
|
||||
a := &API{
|
||||
Store: ts, Cipher: mockCipher{}, Reg: &mockRegistry{},
|
||||
Auth: alwaysOwnedAuthStore(), Sessions: alwaysValidSessions(uuid.New()),
|
||||
}
|
||||
return a, ts
|
||||
}
|
||||
|
||||
@@ -173,7 +183,7 @@ func TestCreateAccount_SecretEncryptedAndNotInResponse(t *testing.T) {
|
||||
router := NewRouter(a)
|
||||
|
||||
body := `{"provider":"selectel","secret":"super-secret-token","comment":"prod"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts", strings.NewReader(body))
|
||||
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -211,7 +221,7 @@ func TestListAccounts_NoSecretsInResponse(t *testing.T) {
|
||||
}
|
||||
router := NewRouter(a)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/projects/"+testPID+"/accounts", nil)
|
||||
req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+testPID+"/accounts", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -234,7 +244,7 @@ func TestDeleteAccount_BadUUID(t *testing.T) {
|
||||
a, _ := newTenantTestAPI()
|
||||
router := NewRouter(a)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid", nil)
|
||||
req := requestWithSessionCookie(http.MethodDelete, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -250,7 +260,7 @@ func TestCreateTemplate_SavesRecords(t *testing.T) {
|
||||
router := NewRouter(a)
|
||||
|
||||
body := `{"name":"base","records":[{"type":"A","name":"@","ttl":300,"values":["1.2.3.4"]}]}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/templates", strings.NewReader(body))
|
||||
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/templates", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -278,7 +288,7 @@ func TestUpdateTemplate_BadUUID(t *testing.T) {
|
||||
router := NewRouter(a)
|
||||
|
||||
body := `{"name":"x","records":[]}`
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/projects/"+testPID+"/templates/not-a-uuid", strings.NewReader(body))
|
||||
req := requestWithSessionCookie(http.MethodPut, "/api/v1/projects/"+testPID+"/templates/not-a-uuid", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -299,7 +309,7 @@ func TestImportZones_CreatesDomainPerZone(t *testing.T) {
|
||||
}}
|
||||
router := NewRouter(a)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil)
|
||||
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -337,7 +347,7 @@ func TestImportZones_AtomicRollbackOnError(t *testing.T) {
|
||||
}}
|
||||
router := NewRouter(a)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil)
|
||||
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -356,7 +366,7 @@ func TestImportZones_BadAccountUUID(t *testing.T) {
|
||||
a, _ := newTenantTestAPI()
|
||||
router := NewRouter(a)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid/import", nil)
|
||||
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid/import", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -370,7 +380,7 @@ func TestCreateDomain_BadProjectUUID(t *testing.T) {
|
||||
router := NewRouter(a)
|
||||
|
||||
body := `{"providerAccountId":"` + uuid.New().String() + `","zoneName":"example.com","zoneId":"z1"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/not-a-uuid/domains", strings.NewReader(body))
|
||||
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/not-a-uuid/domains", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -390,7 +400,7 @@ func TestCreateDomain_AccountNotFoundInProject(t *testing.T) {
|
||||
// ts.accounts is empty, so GetAccount will not find this id.
|
||||
foreignAccID := uuid.New()
|
||||
body := `{"providerAccountId":"` + foreignAccID.String() + `","zoneName":"example.com","zoneId":"z1"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
|
||||
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -413,7 +423,7 @@ func TestCreateDomain_TemplateNotFoundInProject(t *testing.T) {
|
||||
|
||||
foreignTplID := uuid.New()
|
||||
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + foreignTplID.String() + `"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
|
||||
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -434,7 +444,7 @@ func TestCreateDomain_HappyPath(t *testing.T) {
|
||||
router := NewRouter(a)
|
||||
|
||||
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
|
||||
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -464,7 +474,7 @@ func TestCreateDomain_ValidTemplateInProject(t *testing.T) {
|
||||
router := NewRouter(a)
|
||||
|
||||
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + tplID.String() + `"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
|
||||
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -490,7 +500,7 @@ func TestSetDomainTemplate_ValidTemplateId(t *testing.T) {
|
||||
router := NewRouter(a)
|
||||
|
||||
body := `{"templateId":"` + tplID.String() + `"}`
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
|
||||
req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -513,7 +523,7 @@ func TestSetDomainTemplate_BadTemplateUUID(t *testing.T) {
|
||||
router := NewRouter(a)
|
||||
|
||||
body := `{"templateId":"not-a-uuid"}`
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
|
||||
req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -530,7 +540,7 @@ func TestSetDomainTemplate_TemplateNotFound(t *testing.T) {
|
||||
router := NewRouter(a)
|
||||
|
||||
body := `{"templateId":"` + uuid.New().String() + `"}`
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
|
||||
req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -543,7 +553,7 @@ func TestDeleteDomain_BadUUID(t *testing.T) {
|
||||
a, _ := newTenantTestAPI()
|
||||
router := NewRouter(a)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/projects/"+testPID+"/domains/not-a-uuid", nil)
|
||||
req := requestWithSessionCookie(http.MethodDelete, "/api/v1/projects/"+testPID+"/domains/not-a-uuid", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ type DomainRef struct {
|
||||
}
|
||||
|
||||
type Loader interface {
|
||||
LoadDomain(ctx context.Context, domainID uuid.UUID) (DomainRef, error)
|
||||
LoadDomain(ctx context.Context, projectID, domainID uuid.UUID) (DomainRef, error)
|
||||
}
|
||||
|
||||
type Recorder interface {
|
||||
@@ -45,8 +45,10 @@ func New(loader Loader, rec Recorder, reg *registry.Registry, cipher *crypto.Cip
|
||||
}
|
||||
|
||||
// resolve loads the domain, its provider and decrypted credentials, and computes the diff.
|
||||
func (s *DomainService) resolve(ctx context.Context, domainID uuid.UUID) (provider.Provider, provider.Credentials, DomainRef, diff.Changeset, error) {
|
||||
ref, err := s.loader.LoadDomain(ctx, domainID)
|
||||
// projectID scopes the lookup so a domainID belonging to another tenant's
|
||||
// project can never be resolved here (closes IDOR).
|
||||
func (s *DomainService) resolve(ctx context.Context, projectID, domainID uuid.UUID) (provider.Provider, provider.Credentials, DomainRef, diff.Changeset, error) {
|
||||
ref, err := s.loader.LoadDomain(ctx, projectID, domainID)
|
||||
if err != nil {
|
||||
return nil, provider.Credentials{}, ref, diff.Changeset{}, err
|
||||
}
|
||||
@@ -68,8 +70,8 @@ func (s *DomainService) resolve(ctx context.Context, domainID uuid.UUID) (provid
|
||||
}
|
||||
|
||||
// Check computes and records the diff between template and zone.
|
||||
func (s *DomainService) Check(ctx context.Context, domainID uuid.UUID) (diff.Changeset, error) {
|
||||
_, _, _, cs, err := s.resolve(ctx, domainID)
|
||||
func (s *DomainService) Check(ctx context.Context, projectID, domainID uuid.UUID) (diff.Changeset, error) {
|
||||
_, _, _, cs, err := s.resolve(ctx, projectID, domainID)
|
||||
if err != nil {
|
||||
return diff.Changeset{}, err
|
||||
}
|
||||
@@ -80,8 +82,8 @@ func (s *DomainService) Check(ctx context.Context, domainID uuid.UUID) (diff.Cha
|
||||
}
|
||||
|
||||
// Apply applies updates always (when ApplyUpdates) and prunes only when ApplyPrunes.
|
||||
func (s *DomainService) Apply(ctx context.Context, domainID uuid.UUID, req ApplyRequest) (diff.Changeset, error) {
|
||||
p, creds, ref, cs, err := s.resolve(ctx, domainID)
|
||||
func (s *DomainService) Apply(ctx context.Context, projectID, domainID uuid.UUID, req ApplyRequest) (diff.Changeset, error) {
|
||||
p, creds, ref, cs, err := s.resolve(ctx, projectID, domainID)
|
||||
if err != nil {
|
||||
return diff.Changeset{}, err
|
||||
}
|
||||
|
||||
@@ -44,7 +44,9 @@ func (f *fakeProvider) ApplyChanges(_ context.Context, _ provider.Credentials, _
|
||||
|
||||
type fakeLoader struct{ ref DomainRef }
|
||||
|
||||
func (l fakeLoader) LoadDomain(context.Context, uuid.UUID) (DomainRef, error) { return l.ref, nil }
|
||||
func (l fakeLoader) LoadDomain(context.Context, uuid.UUID, uuid.UUID) (DomainRef, error) {
|
||||
return l.ref, nil
|
||||
}
|
||||
|
||||
type nopRecorder struct{}
|
||||
|
||||
@@ -66,7 +68,7 @@ func TestCheckProducesDiff(t *testing.T) {
|
||||
{Type: "A", Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}}, // update
|
||||
}}
|
||||
svc, _ := setup(t, actual, tmpl)
|
||||
cs, err := svc.Check(context.Background(), uuid.New())
|
||||
cs, err := svc.Check(context.Background(), uuid.New(), uuid.New())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -87,7 +89,7 @@ func TestApplyRespectsPruneGuard(t *testing.T) {
|
||||
|
||||
// applyPrunes=false → удаление b НЕ применяется
|
||||
svc, fp := setup(t, actual, tmpl)
|
||||
if _, err := svc.Apply(context.Background(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: false}); err != nil {
|
||||
if _, err := svc.Apply(context.Background(), uuid.New(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: false}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, d := range fp.applied.Diffs {
|
||||
@@ -98,7 +100,7 @@ func TestApplyRespectsPruneGuard(t *testing.T) {
|
||||
|
||||
// applyPrunes=true → удаление b применяется
|
||||
svc2, fp2 := setup(t, actual, tmpl)
|
||||
if _, err := svc2.Apply(context.Background(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: true}); err != nil {
|
||||
if _, err := svc2.Apply(context.Background(), uuid.New(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: true}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var sawDelete bool
|
||||
|
||||
@@ -162,9 +162,14 @@ SELECT d.zone_id, a.provider, a.secret_enc, t.doc
|
||||
FROM domains d
|
||||
JOIN provider_accounts a ON a.id = d.provider_account_id
|
||||
LEFT JOIN templates t ON t.id = d.template_id
|
||||
WHERE d.id = $1
|
||||
WHERE d.id = $1 AND d.project_id = $2
|
||||
`
|
||||
|
||||
type LoadDomainFullParams struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
ProjectID uuid.UUID `json:"project_id"`
|
||||
}
|
||||
|
||||
type LoadDomainFullRow struct {
|
||||
ZoneID string `json:"zone_id"`
|
||||
Provider string `json:"provider"`
|
||||
@@ -172,8 +177,8 @@ type LoadDomainFullRow struct {
|
||||
Doc *dto.TemplateDoc `json:"doc"`
|
||||
}
|
||||
|
||||
func (q *Queries) LoadDomainFull(ctx context.Context, id uuid.UUID) (LoadDomainFullRow, error) {
|
||||
row := q.db.QueryRow(ctx, loadDomainFull, id)
|
||||
func (q *Queries) LoadDomainFull(ctx context.Context, arg LoadDomainFullParams) (LoadDomainFullRow, error) {
|
||||
row := q.db.QueryRow(ctx, loadDomainFull, arg.ID, arg.ProjectID)
|
||||
var i LoadDomainFullRow
|
||||
err := row.Scan(
|
||||
&i.ZoneID,
|
||||
|
||||
@@ -13,9 +13,11 @@ import (
|
||||
)
|
||||
|
||||
// LoadDomain joins domains+provider_accounts+templates to build the
|
||||
// service.DomainRef needed to check/apply a domain's DNS records.
|
||||
func (s *Store) LoadDomain(ctx context.Context, domainID uuid.UUID) (service.DomainRef, error) {
|
||||
row, err := s.q.LoadDomainFull(ctx, domainID)
|
||||
// service.DomainRef needed to check/apply a domain's DNS records. Scoped by
|
||||
// projectID so a domain belonging to another tenant's project can never be
|
||||
// loaded, even if its domainID is guessed/leaked (closes IDOR).
|
||||
func (s *Store) LoadDomain(ctx context.Context, projectID, domainID uuid.UUID) (service.DomainRef, error) {
|
||||
row, err := s.q.LoadDomainFull(ctx, db.LoadDomainFullParams{ID: domainID, ProjectID: projectID})
|
||||
if err != nil {
|
||||
return service.DomainRef{}, err
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ func TestLoadDomainAndSaveCheckRun(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ref, err := s.LoadDomain(ctx, domain.ID)
|
||||
ref, err := s.LoadDomain(ctx, defaultProject, domain.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -87,7 +87,7 @@ func TestLoadDomainNoTemplate(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := s.LoadDomain(ctx, domain.ID); err == nil {
|
||||
if _, err := s.LoadDomain(ctx, defaultProject, domain.ID); err == nil {
|
||||
t.Fatal("expected error for domain without template, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,4 +27,4 @@ SELECT d.zone_id, a.provider, a.secret_enc, t.doc
|
||||
FROM domains d
|
||||
JOIN provider_accounts a ON a.id = d.provider_account_id
|
||||
LEFT JOIN templates t ON t.id = d.template_id
|
||||
WHERE d.id = $1;
|
||||
WHERE d.id = $1 AND d.project_id = $2;
|
||||
|
||||
@@ -214,7 +214,7 @@ func TestSetDomainTemplate_ClosesImportCheckLoop(t *testing.T) {
|
||||
dom := doms[0]
|
||||
|
||||
// Before binding, the domain is not checkable.
|
||||
if _, err := s.LoadDomain(ctx, dom.ID); err == nil {
|
||||
if _, err := s.LoadDomain(ctx, defaultProject, dom.ID); err == nil {
|
||||
t.Fatal("expected LoadDomain to fail before a template is bound")
|
||||
}
|
||||
|
||||
@@ -234,7 +234,7 @@ func TestSetDomainTemplate_ClosesImportCheckLoop(t *testing.T) {
|
||||
t.Fatalf("expected domain.TemplateID=%s, got %+v", tpl.ID, updated.TemplateID)
|
||||
}
|
||||
|
||||
ref, err := s.LoadDomain(ctx, dom.ID)
|
||||
ref, err := s.LoadDomain(ctx, defaultProject, dom.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("expected LoadDomain to succeed after binding template, got error: %v", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user