feat(api): RequireAuth+RequireProjectAccess middleware, IDOR-scope check/apply по projectID

This commit is contained in:
2026-07-03 20:47:40 +07:00
parent 35ffe73ae3
commit 4533b0ca25
16 changed files with 498 additions and 143 deletions
+14 -4
View File
@@ -18,8 +18,8 @@ import (
// CheckApplier is the service surface the API depends on.
type CheckApplier interface {
Check(ctx context.Context, domainID uuid.UUID) (diff.Changeset, error)
Apply(ctx context.Context, domainID uuid.UUID, req service.ApplyRequest) (diff.Changeset, error)
Check(ctx context.Context, projectID, domainID uuid.UUID) (diff.Changeset, error)
Apply(ctx context.Context, projectID, domainID uuid.UUID, req service.ApplyRequest) (diff.Changeset, error)
}
// TenantStore is the narrow persistence surface the CRUD handlers depend on.
@@ -63,6 +63,10 @@ type AuthStore interface {
GetUserByEmail(ctx context.Context, email string) (store.User, error)
GetUserByID(ctx context.Context, userID uuid.UUID) (store.User, error)
GetUserProject(ctx context.Context, userID uuid.UUID) (store.Project, error)
// GetProjectOwned looks up projectID and returns it only if it's owned by
// userID — RequireProjectAccess uses this to reject foreign/nonexistent
// projects with 404 before any handler runs.
GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error)
}
// SessionManager creates/validates/destroys login sessions. *auth.Sessions
@@ -91,11 +95,17 @@ func NewRouter(a *API) http.Handler {
r.Route("/api/v1/auth", func(r chi.Router) {
r.Post("/register", a.handleRegister)
r.Post("/login", a.handleLogin)
r.Post("/logout", a.handleLogout) // защитится RequireAuth в Task 4
r.Get("/me", a.handleMe) // защитится RequireAuth в Task 4
r.Group(func(r chi.Router) {
r.Use(a.RequireAuth)
r.Post("/logout", a.handleLogout)
r.Get("/me", a.handleMe)
})
})
r.Route("/api/v1/projects/{pid}", func(r chi.Router) {
r.Use(a.RequireAuth)
r.Use(a.RequireProjectAccess)
r.Route("/domains", func(r chi.Router) {
r.Post("/", a.handleCreateDomain)
r.Get("/", a.handleListDomains)
+13 -8
View File
@@ -20,18 +20,23 @@ type mockCheckApplier struct {
lastReq service.ApplyRequest
}
func (m *mockCheckApplier) Check(context.Context, uuid.UUID) (diff.Changeset, error) {
func (m *mockCheckApplier) Check(context.Context, uuid.UUID, uuid.UUID) (diff.Changeset, error) {
d := model.Record{Type: model.A, Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}}
return diff.Changeset{Diffs: []diff.RecordDiff{{Kind: diff.Add, Type: d.Type, Name: d.Name, Desired: &d}}}, nil
}
func (m *mockCheckApplier) Apply(_ context.Context, _ uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) {
func (m *mockCheckApplier) Apply(_ context.Context, _, _ uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) {
m.lastReq = req
return diff.Changeset{}, nil
}
// newTestAPI wires a fixed authenticated user who owns whatever project id
// is requested (via alwaysOwnedAuthStore/alwaysValidSessions in
// middleware_test.go) — these tests exercise check/apply behavior past the
// RequireAuth/RequireProjectAccess boundary, which is covered separately by
// middleware_test.go's own tests and the IDOR regression.
func newTestAPI() (*API, *mockCheckApplier) {
m := &mockCheckApplier{}
return &API{Svc: m}, m // остальные зависимости (store/cipher) nil — CRUD-тесты добавит реализатор
return &API{Svc: m, Auth: alwaysOwnedAuthStore(), Sessions: alwaysValidSessions(uuid.New())}, m
}
func TestCheckEndpoint(t *testing.T) {
@@ -39,7 +44,7 @@ func TestCheckEndpoint(t *testing.T) {
router := NewRouter(a)
did := uuid.New().String()
req := httptest.NewRequest(http.MethodGet,
req := requestWithSessionCookie(http.MethodGet,
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/check", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -62,7 +67,7 @@ func TestApplyDefaultsPruneFalse(t *testing.T) {
did := uuid.New().String()
body := `{"applyUpdates":true}` // applyPrunes отсутствует → false
req := httptest.NewRequest(http.MethodPost,
req := requestWithSessionCookie(http.MethodPost,
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply",
strings.NewReader(body))
w := httptest.NewRecorder()
@@ -81,7 +86,7 @@ func TestApplyEmptyBodyOK(t *testing.T) {
router := NewRouter(a)
did := uuid.New().String()
req := httptest.NewRequest(http.MethodPost,
req := requestWithSessionCookie(http.MethodPost,
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -100,7 +105,7 @@ func TestApplyMalformedBody(t *testing.T) {
did := uuid.New().String()
body := `{"applyUpdates":`
req := httptest.NewRequest(http.MethodPost,
req := requestWithSessionCookie(http.MethodPost,
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply",
strings.NewReader(body))
w := httptest.NewRecorder()
@@ -114,7 +119,7 @@ func TestApplyMalformedBody(t *testing.T) {
func TestApplyBadUUID(t *testing.T) {
a, _ := newTestAPI()
router := NewRouter(a)
req := httptest.NewRequest(http.MethodPost,
req := requestWithSessionCookie(http.MethodPost,
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/not-a-uuid/apply",
bytes.NewReader([]byte(`{}`)))
w := httptest.NewRecorder()
+3 -20
View File
@@ -1,15 +1,12 @@
package api
import (
"context"
"errors"
"log"
"net/http"
"strings"
"time"
"github.com/google/uuid"
"github.com/vasyakrg/dns-autoresolver/internal/auth"
"github.com/vasyakrg/dns-autoresolver/internal/store"
)
@@ -39,19 +36,6 @@ func normalizeEmail(email string) string {
return strings.ToLower(strings.TrimSpace(email))
}
// ctxKeyUserID is a private context key carrying the authenticated user's ID.
// Task 4's RequireAuth middleware sets it after validating the session
// cookie; handleMe reads it back.
type ctxKeyUserID struct{}
// userIDFromContext extracts the authenticated user ID set by RequireAuth
// (Task 4). Until that middleware is wired in, tests set it directly via
// context.WithValue.
func userIDFromContext(ctx context.Context) (uuid.UUID, bool) {
id, ok := ctx.Value(ctxKeyUserID{}).(uuid.UUID)
return id, ok
}
func setSessionCookie(w http.ResponseWriter, token string, exp time.Time) {
http.SetCookie(w, &http.Cookie{
Name: sessionCookieName, Value: token, Path: "/",
@@ -165,11 +149,10 @@ func (a *API) handleLogout(w http.ResponseWriter, r *http.Request) {
}
// handleMe returns the authenticated caller's identity + default project.
// The user ID comes from the request context, set by Task 4's RequireAuth
// middleware after validating the session cookie (tests set it directly via
// context.WithValue in the interim).
// The user ID comes from the request context, set by RequireAuth after
// validating the session cookie.
func (a *API) handleMe(w http.ResponseWriter, r *http.Request) {
userID, ok := userIDFromContext(r.Context())
userID, ok := userIDFrom(r.Context())
if !ok {
writeErr(w, http.StatusUnauthorized, "authentication required")
return
+22 -8
View File
@@ -18,10 +18,11 @@ import (
// --- mocks ---
type mockAuthStore struct {
registerUserFn func(ctx context.Context, email, passwordHash string) (store.User, store.Project, error)
getUserByEmailFn func(ctx context.Context, email string) (store.User, error)
getUserByIDFn func(ctx context.Context, userID uuid.UUID) (store.User, error)
getUserProjectFn func(ctx context.Context, userID uuid.UUID) (store.Project, error)
registerUserFn func(ctx context.Context, email, passwordHash string) (store.User, store.Project, error)
getUserByEmailFn func(ctx context.Context, email string) (store.User, error)
getUserByIDFn func(ctx context.Context, userID uuid.UUID) (store.User, error)
getUserProjectFn func(ctx context.Context, userID uuid.UUID) (store.Project, error)
getProjectOwnedFn func(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error)
}
func (m *mockAuthStore) RegisterUser(ctx context.Context, email, passwordHash string) (store.User, store.Project, error) {
@@ -40,8 +41,13 @@ func (m *mockAuthStore) GetUserProject(ctx context.Context, userID uuid.UUID) (s
return m.getUserProjectFn(ctx, userID)
}
func (m *mockAuthStore) GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) {
return m.getProjectOwnedFn(ctx, projectID, userID)
}
type mockSessionManager struct {
createFn func(ctx context.Context, userID uuid.UUID) (string, time.Time, error)
createFn func(ctx context.Context, userID uuid.UUID) (string, time.Time, error)
validateFn func(ctx context.Context, token string) (uuid.UUID, error)
destroyCalled bool
destroyToken string
@@ -52,7 +58,10 @@ func (m *mockSessionManager) Create(ctx context.Context, userID uuid.UUID) (stri
return m.createFn(ctx, userID)
}
func (m *mockSessionManager) Validate(context.Context, string) (uuid.UUID, error) {
func (m *mockSessionManager) Validate(ctx context.Context, token string) (uuid.UUID, error) {
if m.validateFn != nil {
return m.validateFn(ctx, token)
}
return uuid.Nil, nil
}
@@ -338,7 +347,7 @@ func TestAuthLogout_ClearsSessionAndDestroys(t *testing.T) {
// now resolves the authenticated user via GetUserByID and returns their real
// email, instead of leaving it blank.
func TestAuthMe_ReturnsRealEmail(t *testing.T) {
a, authStore, _ := newTestAuthAPI()
a, authStore, sessions := newTestAuthAPI()
userID := uuid.New()
projectID := uuid.New()
authStore.getUserByIDFn = func(_ context.Context, id uuid.UUID) (store.User, error) {
@@ -350,10 +359,15 @@ func TestAuthMe_ReturnsRealEmail(t *testing.T) {
authStore.getUserProjectFn = func(_ context.Context, uid uuid.UUID) (store.Project, error) {
return store.Project{ID: projectID, UserID: uid, Name: "default"}, nil
}
// /me is now behind RequireAuth: the session cookie must resolve to
// userID via Sessions.Validate rather than being injected directly.
sessions.validateFn = func(context.Context, string) (uuid.UUID, error) {
return userID, nil
}
router := NewRouter(a)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil)
req = req.WithContext(context.WithValue(req.Context(), ctxKeyUserID{}, userID))
req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "some-valid-token"})
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
+8 -2
View File
@@ -24,12 +24,15 @@ func writeErr(w http.ResponseWriter, status int, msg string) {
}
func (a *API) handleCheck(w http.ResponseWriter, r *http.Request) {
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
did, err := uuid.Parse(chi.URLParam(r, "did"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid domain id")
return
}
cs, err := a.Svc.Check(r.Context(), did)
cs, err := a.Svc.Check(r.Context(), pid, did)
if err != nil {
log.Printf("api: check failed: %v", err)
writeErr(w, http.StatusInternalServerError, "internal error")
@@ -39,6 +42,9 @@ func (a *API) handleCheck(w http.ResponseWriter, r *http.Request) {
}
func (a *API) handleApply(w http.ResponseWriter, r *http.Request) {
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
did, err := uuid.Parse(chi.URLParam(r, "did"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid domain id")
@@ -54,7 +60,7 @@ func (a *API) handleApply(w http.ResponseWriter, r *http.Request) {
return
}
}
cs, err := a.Svc.Apply(r.Context(), did, service.ApplyRequest{
cs, err := a.Svc.Apply(r.Context(), pid, did, service.ApplyRequest{
ApplyUpdates: req.ApplyUpdates, ApplyPrunes: req.ApplyPrunes,
})
if err != nil {
+75
View File
@@ -0,0 +1,75 @@
package api
import (
"context"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
)
type ctxKey string
const (
ctxUserID ctxKey = "userID"
ctxProjectID ctxKey = "projectID"
)
// RequireAuth validates the session cookie and, on success, stores the
// authenticated user's ID in the request context (see userIDFrom). Any
// failure — missing cookie or an invalid/expired token — is rejected with
// 401 before reaching the wrapped handler.
func (a *API) RequireAuth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := r.Cookie(sessionCookieName)
if err != nil {
writeErr(w, http.StatusUnauthorized, "unauthorized")
return
}
uid, err := a.Sessions.Validate(r.Context(), c.Value)
if err != nil {
writeErr(w, http.StatusUnauthorized, "unauthorized")
return
}
ctx := context.WithValue(r.Context(), ctxUserID, uid)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// RequireProjectAccess verifies that the {pid} URL segment names a project
// owned by the authenticated user (set by RequireAuth, which must run
// first) and, on success, stores the project ID in the request context (see
// projectIDFrom). A project that doesn't exist or isn't owned by the caller
// is rejected with 404 — not 403 — so a caller can't distinguish "not yours"
// from "doesn't exist" (closes IDOR by never confirming another tenant's
// project exists).
func (a *API) RequireProjectAccess(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
uid, ok := userIDFrom(r.Context())
if !ok {
writeErr(w, http.StatusUnauthorized, "unauthorized")
return
}
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
if _, err := a.Auth.GetProjectOwned(r.Context(), pid, uid); err != nil {
writeErr(w, http.StatusNotFound, "not found")
return
}
ctx := context.WithValue(r.Context(), ctxProjectID, pid)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func userIDFrom(ctx context.Context) (uuid.UUID, bool) {
v, ok := ctx.Value(ctxUserID).(uuid.UUID)
return v, ok
}
func projectIDFrom(ctx context.Context) (uuid.UUID, bool) {
v, ok := ctx.Value(ctxProjectID).(uuid.UUID)
return v, ok
}
+265
View File
@@ -0,0 +1,265 @@
package api
import (
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/vasyakrg/dns-autoresolver/internal/diff"
"github.com/vasyakrg/dns-autoresolver/internal/service"
"github.com/vasyakrg/dns-autoresolver/internal/store"
)
// --- shared test doubles (also used by api_test.go / tenant_test.go) ---
// stubSessions is a configurable SessionManager test double.
type stubSessions struct {
validateFn func(ctx context.Context, token string) (uuid.UUID, error)
}
func (s stubSessions) Create(context.Context, uuid.UUID) (string, time.Time, error) {
return "stub-token", time.Now().Add(time.Hour), nil
}
func (s stubSessions) Validate(ctx context.Context, token string) (uuid.UUID, error) {
return s.validateFn(ctx, token)
}
func (s stubSessions) Destroy(context.Context, string) error { return nil }
// alwaysValidSessions builds a stubSessions whose Validate always succeeds
// with the given fixed user ID — for tests that only care about behavior
// past the RequireAuth boundary.
func alwaysValidSessions(uid uuid.UUID) stubSessions {
return stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) { return uid, nil }}
}
// stubAuthStore is a configurable AuthStore test double. Methods besides
// GetProjectOwned return zero values — tests exercising register/login/me
// use their own dedicated mock (mockAuthStore in auth_test.go).
type stubAuthStore struct {
getProjectOwnedFn func(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error)
}
func (s stubAuthStore) RegisterUser(context.Context, string, string) (store.User, store.Project, error) {
return store.User{}, store.Project{}, nil
}
func (s stubAuthStore) GetUserByEmail(context.Context, string) (store.User, error) {
return store.User{}, nil
}
func (s stubAuthStore) GetUserByID(context.Context, uuid.UUID) (store.User, error) {
return store.User{}, nil
}
func (s stubAuthStore) GetUserProject(context.Context, uuid.UUID) (store.Project, error) {
return store.Project{}, nil
}
func (s stubAuthStore) GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) {
return s.getProjectOwnedFn(ctx, projectID, userID)
}
// alwaysOwnedAuthStore builds a stubAuthStore whose GetProjectOwned always
// succeeds, treating the caller as the owner of whatever project id is
// requested — for tests that only care about behavior past the
// RequireProjectAccess boundary (CRUD/check/apply tests).
func alwaysOwnedAuthStore() stubAuthStore {
return stubAuthStore{getProjectOwnedFn: func(_ context.Context, pid, uid uuid.UUID) (store.Project, error) {
return store.Project{ID: pid, UserID: uid}, nil
}}
}
// requestWithSessionCookie builds an httptest request carrying a session
// cookie so it clears RequireAuth in tests using stubSessions/alwaysValidSessions
// (which ignore the token value and validate unconditionally).
func requestWithSessionCookie(method, url string, body io.Reader) *http.Request {
req := httptest.NewRequest(method, url, body)
req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "test-session-token"})
return req
}
// recordingCheckApplier is a CheckApplier test double that records whether
// Check/Apply were invoked — used by the IDOR regression test to assert the
// service layer is never reached for a project the caller doesn't own.
type recordingCheckApplier struct {
checkCalled bool
applyCalled bool
}
func (r *recordingCheckApplier) Check(context.Context, uuid.UUID, uuid.UUID) (diff.Changeset, error) {
r.checkCalled = true
return diff.Changeset{}, nil
}
func (r *recordingCheckApplier) Apply(context.Context, uuid.UUID, uuid.UUID, service.ApplyRequest) (diff.Changeset, error) {
r.applyCalled = true
return diff.Changeset{}, nil
}
// --- RequireAuth ---
func TestRequireAuth_NoCookie_Returns401(t *testing.T) {
a := &API{Sessions: stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) {
t.Fatal("Validate must not be called when there is no cookie")
return uuid.Nil, nil
}}}
nextCalled := false
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true })
req := httptest.NewRequest(http.MethodGet, "/whatever", nil)
w := httptest.NewRecorder()
a.RequireAuth(next).ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
if nextCalled {
t.Fatal("next must not be called without a session cookie")
}
}
func TestRequireAuth_InvalidToken_Returns401(t *testing.T) {
a := &API{Sessions: stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) {
return uuid.Nil, errors.New("invalid or expired session")
}}}
nextCalled := false
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true })
req := requestWithSessionCookie(http.MethodGet, "/whatever", nil)
w := httptest.NewRecorder()
a.RequireAuth(next).ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
if nextCalled {
t.Fatal("next must not be called when Sessions.Validate fails")
}
}
func TestRequireAuth_Success_CallsNextWithUserID(t *testing.T) {
uid := uuid.New()
a := &API{Sessions: alwaysValidSessions(uid)}
var gotUID uuid.UUID
var gotOK bool
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotUID, gotOK = userIDFrom(r.Context())
w.WriteHeader(http.StatusOK)
})
req := requestWithSessionCookie(http.MethodGet, "/whatever", nil)
w := httptest.NewRecorder()
a.RequireAuth(next).ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200 (next called), got %d", w.Code)
}
if !gotOK || gotUID != uid {
t.Fatalf("expected userIDFrom to yield %s, got %s (ok=%v)", uid, gotUID, gotOK)
}
}
// --- RequireProjectAccess ---
func TestRequireProjectAccess_NotOwned_Returns404(t *testing.T) {
a := &API{Auth: stubAuthStore{getProjectOwnedFn: func(context.Context, uuid.UUID, uuid.UUID) (store.Project, error) {
return store.Project{}, errors.New("no rows")
}}}
nextCalled := false
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true })
router := chi.NewRouter()
router.With(a.RequireProjectAccess).Get("/projects/{pid}", next.ServeHTTP)
req := httptest.NewRequest(http.MethodGet, "/projects/"+uuid.New().String(), nil)
req = req.WithContext(context.WithValue(req.Context(), ctxUserID, uuid.New()))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404 for an unowned/foreign project, got %d body %s", w.Code, w.Body.String())
}
if nextCalled {
t.Fatal("next must not be called when the project isn't owned by the caller")
}
}
func TestRequireProjectAccess_Success_CallsNextWithProjectID(t *testing.T) {
a := &API{Auth: alwaysOwnedAuthStore()}
pid := uuid.New()
uid := uuid.New()
var gotPID uuid.UUID
var gotOK bool
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPID, gotOK = projectIDFrom(r.Context())
w.WriteHeader(http.StatusOK)
})
router := chi.NewRouter()
router.With(a.RequireProjectAccess).Get("/projects/{pid}", next.ServeHTTP)
req := httptest.NewRequest(http.MethodGet, "/projects/"+pid.String(), nil)
req = req.WithContext(context.WithValue(req.Context(), ctxUserID, uid))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200 (next called), got %d body %s", w.Code, w.Body.String())
}
if !gotOK || gotPID != pid {
t.Fatalf("expected projectIDFrom to yield %s, got %s (ok=%v)", pid, gotPID, gotOK)
}
}
// --- IDOR regression ---
// TestIDOR_CheckForeignProject_Returns404AndServiceNotCalled is the IDOR
// regression required by Task 4: a request for project B's domain, made
// with user A's session (who owns project A, not B), must be rejected by
// RequireProjectAccess with 404 before service.Check is ever invoked — a
// caller must never be able to check/apply another tenant's domain by
// guessing/leaking its ids.
func TestIDOR_CheckForeignProject_Returns404AndServiceNotCalled(t *testing.T) {
userA := uuid.New()
pidA := uuid.New()
pidB := uuid.New() // owned by a different user; not A
svc := &recordingCheckApplier{}
auth := stubAuthStore{getProjectOwnedFn: func(_ context.Context, pid, uid uuid.UUID) (store.Project, error) {
if pid == pidA && uid == userA {
return store.Project{ID: pidA, UserID: userA}, nil
}
return store.Project{}, errors.New("not found")
}}
a := &API{Svc: svc, Auth: auth, Sessions: alwaysValidSessions(userA)}
router := NewRouter(a)
did := uuid.New().String()
req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+pidB.String()+"/domains/"+did+"/check", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404 for user A requesting project B's domain, got %d body %s", w.Code, w.Body.String())
}
if svc.checkCalled {
t.Fatal("service.Check must not be called when the caller doesn't own the project")
}
// Sanity check: the same user against their own project succeeds and
// does reach the service — proving the 404 above is really about
// project ownership, not e.g. a broken route.
req2 := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+pidA.String()+"/domains/"+did+"/check", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Fatalf("expected 200 for user A requesting their own project, got %d body %s", w2.Code, w2.Body.String())
}
if !svc.checkCalled {
t.Fatal("expected service.Check to be called for the caller's own project")
}
}
+36 -60
View File
@@ -24,11 +24,9 @@ func decodeBody(w http.ResponseWriter, r *http.Request, v any) bool {
// --- accounts ---
func (a *API) handleCreateAccount(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
var req accountRequest
if !decodeBody(w, r, &req) {
return
@@ -53,11 +51,9 @@ func (a *API) handleCreateAccount(w http.ResponseWriter, r *http.Request) {
}
func (a *API) handleListAccounts(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
accs, err := a.Store.ListAccounts(r.Context(), pid)
if err != nil {
log.Printf("api: list accounts failed: %v", err)
@@ -72,11 +68,9 @@ func (a *API) handleListAccounts(w http.ResponseWriter, r *http.Request) {
}
func (a *API) handleDeleteAccount(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
aid, err := uuid.Parse(chi.URLParam(r, "aid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid account id")
@@ -93,11 +87,9 @@ func (a *API) handleDeleteAccount(w http.ResponseWriter, r *http.Request) {
// handleImportZones lists zones from the provider for the given account and
// creates one domain per zone (template_id left unset).
func (a *API) handleImportZones(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
aid, err := uuid.Parse(chi.URLParam(r, "aid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid account id")
@@ -145,11 +137,9 @@ func (a *API) handleImportZones(w http.ResponseWriter, r *http.Request) {
// --- templates ---
func (a *API) handleCreateTemplate(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
var req templateRequest
if !decodeBody(w, r, &req) {
return
@@ -169,11 +159,9 @@ func (a *API) handleCreateTemplate(w http.ResponseWriter, r *http.Request) {
}
func (a *API) handleListTemplates(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
tpls, err := a.Store.ListTemplates(r.Context(), pid)
if err != nil {
log.Printf("api: list templates failed: %v", err)
@@ -188,11 +176,9 @@ func (a *API) handleListTemplates(w http.ResponseWriter, r *http.Request) {
}
func (a *API) handleUpdateTemplate(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
tid, err := uuid.Parse(chi.URLParam(r, "tid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid template id")
@@ -217,11 +203,9 @@ func (a *API) handleUpdateTemplate(w http.ResponseWriter, r *http.Request) {
}
func (a *API) handleDeleteTemplate(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
tid, err := uuid.Parse(chi.URLParam(r, "tid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid template id")
@@ -238,11 +222,9 @@ func (a *API) handleDeleteTemplate(w http.ResponseWriter, r *http.Request) {
// --- domains ---
func (a *API) handleCreateDomain(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
var req domainRequest
if !decodeBody(w, r, &req) {
return
@@ -286,11 +268,9 @@ func (a *API) handleCreateDomain(w http.ResponseWriter, r *http.Request) {
}
func (a *API) handleListDomains(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
doms, err := a.Store.ListDomains(r.Context(), pid)
if err != nil {
log.Printf("api: list domains failed: %v", err)
@@ -308,11 +288,9 @@ func (a *API) handleListDomains(w http.ResponseWriter, r *http.Request) {
// check/apply a domain — this is what makes an imported domain (which
// starts with template_id=NULL) checkable, closing the import→check loop.
func (a *API) handleSetDomainTemplate(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
did, err := uuid.Parse(chi.URLParam(r, "did"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid domain id")
@@ -339,11 +317,9 @@ func (a *API) handleSetDomainTemplate(w http.ResponseWriter, r *http.Request) {
}
func (a *API) handleDeleteDomain(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
did, err := uuid.Parse(chi.URLParam(r, "did"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid domain id")
+29 -19
View File
@@ -132,7 +132,9 @@ func (m *mockTenantStore) ImportDomains(_ context.Context, projectID, accountID
type mockCipher struct{}
func (mockCipher) Encrypt(plaintext []byte) (string, error) { return "ENC(" + string(plaintext) + ")", nil }
func (mockCipher) Encrypt(plaintext []byte) (string, error) {
return "ENC(" + string(plaintext) + ")", nil
}
func (mockCipher) Decrypt(enc string) ([]byte, error) {
return []byte(strings.TrimSuffix(strings.TrimPrefix(enc, "ENC("), ")")), nil
}
@@ -160,9 +162,17 @@ func (mockProvider) ApplyChanges(context.Context, provider.Credentials, string,
return nil
}
// newTenantTestAPI wires a fixed authenticated user who owns whatever
// project id is requested (alwaysOwnedAuthStore/alwaysValidSessions, see
// middleware_test.go) — these tests exercise CRUD behavior past the
// RequireAuth/RequireProjectAccess boundary, which has its own dedicated
// coverage in middleware_test.go.
func newTenantTestAPI() (*API, *mockTenantStore) {
ts := &mockTenantStore{}
a := &API{Store: ts, Cipher: mockCipher{}, Reg: &mockRegistry{}}
a := &API{
Store: ts, Cipher: mockCipher{}, Reg: &mockRegistry{},
Auth: alwaysOwnedAuthStore(), Sessions: alwaysValidSessions(uuid.New()),
}
return a, ts
}
@@ -173,7 +183,7 @@ func TestCreateAccount_SecretEncryptedAndNotInResponse(t *testing.T) {
router := NewRouter(a)
body := `{"provider":"selectel","secret":"super-secret-token","comment":"prod"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts", strings.NewReader(body))
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -211,7 +221,7 @@ func TestListAccounts_NoSecretsInResponse(t *testing.T) {
}
router := NewRouter(a)
req := httptest.NewRequest(http.MethodGet, "/api/v1/projects/"+testPID+"/accounts", nil)
req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+testPID+"/accounts", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -234,7 +244,7 @@ func TestDeleteAccount_BadUUID(t *testing.T) {
a, _ := newTenantTestAPI()
router := NewRouter(a)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid", nil)
req := requestWithSessionCookie(http.MethodDelete, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -250,7 +260,7 @@ func TestCreateTemplate_SavesRecords(t *testing.T) {
router := NewRouter(a)
body := `{"name":"base","records":[{"type":"A","name":"@","ttl":300,"values":["1.2.3.4"]}]}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/templates", strings.NewReader(body))
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/templates", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -278,7 +288,7 @@ func TestUpdateTemplate_BadUUID(t *testing.T) {
router := NewRouter(a)
body := `{"name":"x","records":[]}`
req := httptest.NewRequest(http.MethodPut, "/api/v1/projects/"+testPID+"/templates/not-a-uuid", strings.NewReader(body))
req := requestWithSessionCookie(http.MethodPut, "/api/v1/projects/"+testPID+"/templates/not-a-uuid", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -299,7 +309,7 @@ func TestImportZones_CreatesDomainPerZone(t *testing.T) {
}}
router := NewRouter(a)
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil)
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -337,7 +347,7 @@ func TestImportZones_AtomicRollbackOnError(t *testing.T) {
}}
router := NewRouter(a)
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil)
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -356,7 +366,7 @@ func TestImportZones_BadAccountUUID(t *testing.T) {
a, _ := newTenantTestAPI()
router := NewRouter(a)
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid/import", nil)
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid/import", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -370,7 +380,7 @@ func TestCreateDomain_BadProjectUUID(t *testing.T) {
router := NewRouter(a)
body := `{"providerAccountId":"` + uuid.New().String() + `","zoneName":"example.com","zoneId":"z1"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/not-a-uuid/domains", strings.NewReader(body))
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/not-a-uuid/domains", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -390,7 +400,7 @@ func TestCreateDomain_AccountNotFoundInProject(t *testing.T) {
// ts.accounts is empty, so GetAccount will not find this id.
foreignAccID := uuid.New()
body := `{"providerAccountId":"` + foreignAccID.String() + `","zoneName":"example.com","zoneId":"z1"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -413,7 +423,7 @@ func TestCreateDomain_TemplateNotFoundInProject(t *testing.T) {
foreignTplID := uuid.New()
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + foreignTplID.String() + `"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -434,7 +444,7 @@ func TestCreateDomain_HappyPath(t *testing.T) {
router := NewRouter(a)
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -464,7 +474,7 @@ func TestCreateDomain_ValidTemplateInProject(t *testing.T) {
router := NewRouter(a)
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + tplID.String() + `"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -490,7 +500,7 @@ func TestSetDomainTemplate_ValidTemplateId(t *testing.T) {
router := NewRouter(a)
body := `{"templateId":"` + tplID.String() + `"}`
req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -513,7 +523,7 @@ func TestSetDomainTemplate_BadTemplateUUID(t *testing.T) {
router := NewRouter(a)
body := `{"templateId":"not-a-uuid"}`
req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -530,7 +540,7 @@ func TestSetDomainTemplate_TemplateNotFound(t *testing.T) {
router := NewRouter(a)
body := `{"templateId":"` + uuid.New().String() + `"}`
req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -543,7 +553,7 @@ func TestDeleteDomain_BadUUID(t *testing.T) {
a, _ := newTenantTestAPI()
router := NewRouter(a)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/projects/"+testPID+"/domains/not-a-uuid", nil)
req := requestWithSessionCookie(http.MethodDelete, "/api/v1/projects/"+testPID+"/domains/not-a-uuid", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)