feat(api): RequireAuth+RequireProjectAccess middleware, IDOR-scope check/apply по projectID

This commit is contained in:
2026-07-03 20:47:40 +07:00
parent 35ffe73ae3
commit 4533b0ca25
16 changed files with 498 additions and 143 deletions
+14 -4
View File
@@ -18,8 +18,8 @@ import (
// CheckApplier is the service surface the API depends on.
type CheckApplier interface {
Check(ctx context.Context, domainID uuid.UUID) (diff.Changeset, error)
Apply(ctx context.Context, domainID uuid.UUID, req service.ApplyRequest) (diff.Changeset, error)
Check(ctx context.Context, projectID, domainID uuid.UUID) (diff.Changeset, error)
Apply(ctx context.Context, projectID, domainID uuid.UUID, req service.ApplyRequest) (diff.Changeset, error)
}
// TenantStore is the narrow persistence surface the CRUD handlers depend on.
@@ -63,6 +63,10 @@ type AuthStore interface {
GetUserByEmail(ctx context.Context, email string) (store.User, error)
GetUserByID(ctx context.Context, userID uuid.UUID) (store.User, error)
GetUserProject(ctx context.Context, userID uuid.UUID) (store.Project, error)
// GetProjectOwned looks up projectID and returns it only if it's owned by
// userID — RequireProjectAccess uses this to reject foreign/nonexistent
// projects with 404 before any handler runs.
GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error)
}
// SessionManager creates/validates/destroys login sessions. *auth.Sessions
@@ -91,11 +95,17 @@ func NewRouter(a *API) http.Handler {
r.Route("/api/v1/auth", func(r chi.Router) {
r.Post("/register", a.handleRegister)
r.Post("/login", a.handleLogin)
r.Post("/logout", a.handleLogout) // защитится RequireAuth в Task 4
r.Get("/me", a.handleMe) // защитится RequireAuth в Task 4
r.Group(func(r chi.Router) {
r.Use(a.RequireAuth)
r.Post("/logout", a.handleLogout)
r.Get("/me", a.handleMe)
})
})
r.Route("/api/v1/projects/{pid}", func(r chi.Router) {
r.Use(a.RequireAuth)
r.Use(a.RequireProjectAccess)
r.Route("/domains", func(r chi.Router) {
r.Post("/", a.handleCreateDomain)
r.Get("/", a.handleListDomains)
+13 -8
View File
@@ -20,18 +20,23 @@ type mockCheckApplier struct {
lastReq service.ApplyRequest
}
func (m *mockCheckApplier) Check(context.Context, uuid.UUID) (diff.Changeset, error) {
func (m *mockCheckApplier) Check(context.Context, uuid.UUID, uuid.UUID) (diff.Changeset, error) {
d := model.Record{Type: model.A, Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}}
return diff.Changeset{Diffs: []diff.RecordDiff{{Kind: diff.Add, Type: d.Type, Name: d.Name, Desired: &d}}}, nil
}
func (m *mockCheckApplier) Apply(_ context.Context, _ uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) {
func (m *mockCheckApplier) Apply(_ context.Context, _, _ uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) {
m.lastReq = req
return diff.Changeset{}, nil
}
// newTestAPI wires a fixed authenticated user who owns whatever project id
// is requested (via alwaysOwnedAuthStore/alwaysValidSessions in
// middleware_test.go) — these tests exercise check/apply behavior past the
// RequireAuth/RequireProjectAccess boundary, which is covered separately by
// middleware_test.go's own tests and the IDOR regression.
func newTestAPI() (*API, *mockCheckApplier) {
m := &mockCheckApplier{}
return &API{Svc: m}, m // остальные зависимости (store/cipher) nil — CRUD-тесты добавит реализатор
return &API{Svc: m, Auth: alwaysOwnedAuthStore(), Sessions: alwaysValidSessions(uuid.New())}, m
}
func TestCheckEndpoint(t *testing.T) {
@@ -39,7 +44,7 @@ func TestCheckEndpoint(t *testing.T) {
router := NewRouter(a)
did := uuid.New().String()
req := httptest.NewRequest(http.MethodGet,
req := requestWithSessionCookie(http.MethodGet,
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/check", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -62,7 +67,7 @@ func TestApplyDefaultsPruneFalse(t *testing.T) {
did := uuid.New().String()
body := `{"applyUpdates":true}` // applyPrunes отсутствует → false
req := httptest.NewRequest(http.MethodPost,
req := requestWithSessionCookie(http.MethodPost,
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply",
strings.NewReader(body))
w := httptest.NewRecorder()
@@ -81,7 +86,7 @@ func TestApplyEmptyBodyOK(t *testing.T) {
router := NewRouter(a)
did := uuid.New().String()
req := httptest.NewRequest(http.MethodPost,
req := requestWithSessionCookie(http.MethodPost,
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -100,7 +105,7 @@ func TestApplyMalformedBody(t *testing.T) {
did := uuid.New().String()
body := `{"applyUpdates":`
req := httptest.NewRequest(http.MethodPost,
req := requestWithSessionCookie(http.MethodPost,
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply",
strings.NewReader(body))
w := httptest.NewRecorder()
@@ -114,7 +119,7 @@ func TestApplyMalformedBody(t *testing.T) {
func TestApplyBadUUID(t *testing.T) {
a, _ := newTestAPI()
router := NewRouter(a)
req := httptest.NewRequest(http.MethodPost,
req := requestWithSessionCookie(http.MethodPost,
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/not-a-uuid/apply",
bytes.NewReader([]byte(`{}`)))
w := httptest.NewRecorder()
+3 -20
View File
@@ -1,15 +1,12 @@
package api
import (
"context"
"errors"
"log"
"net/http"
"strings"
"time"
"github.com/google/uuid"
"github.com/vasyakrg/dns-autoresolver/internal/auth"
"github.com/vasyakrg/dns-autoresolver/internal/store"
)
@@ -39,19 +36,6 @@ func normalizeEmail(email string) string {
return strings.ToLower(strings.TrimSpace(email))
}
// ctxKeyUserID is a private context key carrying the authenticated user's ID.
// Task 4's RequireAuth middleware sets it after validating the session
// cookie; handleMe reads it back.
type ctxKeyUserID struct{}
// userIDFromContext extracts the authenticated user ID set by RequireAuth
// (Task 4). Until that middleware is wired in, tests set it directly via
// context.WithValue.
func userIDFromContext(ctx context.Context) (uuid.UUID, bool) {
id, ok := ctx.Value(ctxKeyUserID{}).(uuid.UUID)
return id, ok
}
func setSessionCookie(w http.ResponseWriter, token string, exp time.Time) {
http.SetCookie(w, &http.Cookie{
Name: sessionCookieName, Value: token, Path: "/",
@@ -165,11 +149,10 @@ func (a *API) handleLogout(w http.ResponseWriter, r *http.Request) {
}
// handleMe returns the authenticated caller's identity + default project.
// The user ID comes from the request context, set by Task 4's RequireAuth
// middleware after validating the session cookie (tests set it directly via
// context.WithValue in the interim).
// The user ID comes from the request context, set by RequireAuth after
// validating the session cookie.
func (a *API) handleMe(w http.ResponseWriter, r *http.Request) {
userID, ok := userIDFromContext(r.Context())
userID, ok := userIDFrom(r.Context())
if !ok {
writeErr(w, http.StatusUnauthorized, "authentication required")
return
+17 -3
View File
@@ -22,6 +22,7 @@ type mockAuthStore struct {
getUserByEmailFn func(ctx context.Context, email string) (store.User, error)
getUserByIDFn func(ctx context.Context, userID uuid.UUID) (store.User, error)
getUserProjectFn func(ctx context.Context, userID uuid.UUID) (store.Project, error)
getProjectOwnedFn func(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error)
}
func (m *mockAuthStore) RegisterUser(ctx context.Context, email, passwordHash string) (store.User, store.Project, error) {
@@ -40,8 +41,13 @@ func (m *mockAuthStore) GetUserProject(ctx context.Context, userID uuid.UUID) (s
return m.getUserProjectFn(ctx, userID)
}
func (m *mockAuthStore) GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) {
return m.getProjectOwnedFn(ctx, projectID, userID)
}
type mockSessionManager struct {
createFn func(ctx context.Context, userID uuid.UUID) (string, time.Time, error)
validateFn func(ctx context.Context, token string) (uuid.UUID, error)
destroyCalled bool
destroyToken string
@@ -52,7 +58,10 @@ func (m *mockSessionManager) Create(ctx context.Context, userID uuid.UUID) (stri
return m.createFn(ctx, userID)
}
func (m *mockSessionManager) Validate(context.Context, string) (uuid.UUID, error) {
func (m *mockSessionManager) Validate(ctx context.Context, token string) (uuid.UUID, error) {
if m.validateFn != nil {
return m.validateFn(ctx, token)
}
return uuid.Nil, nil
}
@@ -338,7 +347,7 @@ func TestAuthLogout_ClearsSessionAndDestroys(t *testing.T) {
// now resolves the authenticated user via GetUserByID and returns their real
// email, instead of leaving it blank.
func TestAuthMe_ReturnsRealEmail(t *testing.T) {
a, authStore, _ := newTestAuthAPI()
a, authStore, sessions := newTestAuthAPI()
userID := uuid.New()
projectID := uuid.New()
authStore.getUserByIDFn = func(_ context.Context, id uuid.UUID) (store.User, error) {
@@ -350,10 +359,15 @@ func TestAuthMe_ReturnsRealEmail(t *testing.T) {
authStore.getUserProjectFn = func(_ context.Context, uid uuid.UUID) (store.Project, error) {
return store.Project{ID: projectID, UserID: uid, Name: "default"}, nil
}
// /me is now behind RequireAuth: the session cookie must resolve to
// userID via Sessions.Validate rather than being injected directly.
sessions.validateFn = func(context.Context, string) (uuid.UUID, error) {
return userID, nil
}
router := NewRouter(a)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil)
req = req.WithContext(context.WithValue(req.Context(), ctxKeyUserID{}, userID))
req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "some-valid-token"})
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
+8 -2
View File
@@ -24,12 +24,15 @@ func writeErr(w http.ResponseWriter, status int, msg string) {
}
func (a *API) handleCheck(w http.ResponseWriter, r *http.Request) {
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
did, err := uuid.Parse(chi.URLParam(r, "did"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid domain id")
return
}
cs, err := a.Svc.Check(r.Context(), did)
cs, err := a.Svc.Check(r.Context(), pid, did)
if err != nil {
log.Printf("api: check failed: %v", err)
writeErr(w, http.StatusInternalServerError, "internal error")
@@ -39,6 +42,9 @@ func (a *API) handleCheck(w http.ResponseWriter, r *http.Request) {
}
func (a *API) handleApply(w http.ResponseWriter, r *http.Request) {
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
did, err := uuid.Parse(chi.URLParam(r, "did"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid domain id")
@@ -54,7 +60,7 @@ func (a *API) handleApply(w http.ResponseWriter, r *http.Request) {
return
}
}
cs, err := a.Svc.Apply(r.Context(), did, service.ApplyRequest{
cs, err := a.Svc.Apply(r.Context(), pid, did, service.ApplyRequest{
ApplyUpdates: req.ApplyUpdates, ApplyPrunes: req.ApplyPrunes,
})
if err != nil {
+75
View File
@@ -0,0 +1,75 @@
package api
import (
"context"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
)
type ctxKey string
const (
ctxUserID ctxKey = "userID"
ctxProjectID ctxKey = "projectID"
)
// RequireAuth validates the session cookie and, on success, stores the
// authenticated user's ID in the request context (see userIDFrom). Any
// failure — missing cookie or an invalid/expired token — is rejected with
// 401 before reaching the wrapped handler.
func (a *API) RequireAuth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := r.Cookie(sessionCookieName)
if err != nil {
writeErr(w, http.StatusUnauthorized, "unauthorized")
return
}
uid, err := a.Sessions.Validate(r.Context(), c.Value)
if err != nil {
writeErr(w, http.StatusUnauthorized, "unauthorized")
return
}
ctx := context.WithValue(r.Context(), ctxUserID, uid)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// RequireProjectAccess verifies that the {pid} URL segment names a project
// owned by the authenticated user (set by RequireAuth, which must run
// first) and, on success, stores the project ID in the request context (see
// projectIDFrom). A project that doesn't exist or isn't owned by the caller
// is rejected with 404 — not 403 — so a caller can't distinguish "not yours"
// from "doesn't exist" (closes IDOR by never confirming another tenant's
// project exists).
func (a *API) RequireProjectAccess(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
uid, ok := userIDFrom(r.Context())
if !ok {
writeErr(w, http.StatusUnauthorized, "unauthorized")
return
}
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
if _, err := a.Auth.GetProjectOwned(r.Context(), pid, uid); err != nil {
writeErr(w, http.StatusNotFound, "not found")
return
}
ctx := context.WithValue(r.Context(), ctxProjectID, pid)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func userIDFrom(ctx context.Context) (uuid.UUID, bool) {
v, ok := ctx.Value(ctxUserID).(uuid.UUID)
return v, ok
}
func projectIDFrom(ctx context.Context) (uuid.UUID, bool) {
v, ok := ctx.Value(ctxProjectID).(uuid.UUID)
return v, ok
}
+265
View File
@@ -0,0 +1,265 @@
package api
import (
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/vasyakrg/dns-autoresolver/internal/diff"
"github.com/vasyakrg/dns-autoresolver/internal/service"
"github.com/vasyakrg/dns-autoresolver/internal/store"
)
// --- shared test doubles (also used by api_test.go / tenant_test.go) ---
// stubSessions is a configurable SessionManager test double.
type stubSessions struct {
validateFn func(ctx context.Context, token string) (uuid.UUID, error)
}
func (s stubSessions) Create(context.Context, uuid.UUID) (string, time.Time, error) {
return "stub-token", time.Now().Add(time.Hour), nil
}
func (s stubSessions) Validate(ctx context.Context, token string) (uuid.UUID, error) {
return s.validateFn(ctx, token)
}
func (s stubSessions) Destroy(context.Context, string) error { return nil }
// alwaysValidSessions builds a stubSessions whose Validate always succeeds
// with the given fixed user ID — for tests that only care about behavior
// past the RequireAuth boundary.
func alwaysValidSessions(uid uuid.UUID) stubSessions {
return stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) { return uid, nil }}
}
// stubAuthStore is a configurable AuthStore test double. Methods besides
// GetProjectOwned return zero values — tests exercising register/login/me
// use their own dedicated mock (mockAuthStore in auth_test.go).
type stubAuthStore struct {
getProjectOwnedFn func(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error)
}
func (s stubAuthStore) RegisterUser(context.Context, string, string) (store.User, store.Project, error) {
return store.User{}, store.Project{}, nil
}
func (s stubAuthStore) GetUserByEmail(context.Context, string) (store.User, error) {
return store.User{}, nil
}
func (s stubAuthStore) GetUserByID(context.Context, uuid.UUID) (store.User, error) {
return store.User{}, nil
}
func (s stubAuthStore) GetUserProject(context.Context, uuid.UUID) (store.Project, error) {
return store.Project{}, nil
}
func (s stubAuthStore) GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) {
return s.getProjectOwnedFn(ctx, projectID, userID)
}
// alwaysOwnedAuthStore builds a stubAuthStore whose GetProjectOwned always
// succeeds, treating the caller as the owner of whatever project id is
// requested — for tests that only care about behavior past the
// RequireProjectAccess boundary (CRUD/check/apply tests).
func alwaysOwnedAuthStore() stubAuthStore {
return stubAuthStore{getProjectOwnedFn: func(_ context.Context, pid, uid uuid.UUID) (store.Project, error) {
return store.Project{ID: pid, UserID: uid}, nil
}}
}
// requestWithSessionCookie builds an httptest request carrying a session
// cookie so it clears RequireAuth in tests using stubSessions/alwaysValidSessions
// (which ignore the token value and validate unconditionally).
func requestWithSessionCookie(method, url string, body io.Reader) *http.Request {
req := httptest.NewRequest(method, url, body)
req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "test-session-token"})
return req
}
// recordingCheckApplier is a CheckApplier test double that records whether
// Check/Apply were invoked — used by the IDOR regression test to assert the
// service layer is never reached for a project the caller doesn't own.
type recordingCheckApplier struct {
checkCalled bool
applyCalled bool
}
func (r *recordingCheckApplier) Check(context.Context, uuid.UUID, uuid.UUID) (diff.Changeset, error) {
r.checkCalled = true
return diff.Changeset{}, nil
}
func (r *recordingCheckApplier) Apply(context.Context, uuid.UUID, uuid.UUID, service.ApplyRequest) (diff.Changeset, error) {
r.applyCalled = true
return diff.Changeset{}, nil
}
// --- RequireAuth ---
func TestRequireAuth_NoCookie_Returns401(t *testing.T) {
a := &API{Sessions: stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) {
t.Fatal("Validate must not be called when there is no cookie")
return uuid.Nil, nil
}}}
nextCalled := false
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true })
req := httptest.NewRequest(http.MethodGet, "/whatever", nil)
w := httptest.NewRecorder()
a.RequireAuth(next).ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
if nextCalled {
t.Fatal("next must not be called without a session cookie")
}
}
func TestRequireAuth_InvalidToken_Returns401(t *testing.T) {
a := &API{Sessions: stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) {
return uuid.Nil, errors.New("invalid or expired session")
}}}
nextCalled := false
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true })
req := requestWithSessionCookie(http.MethodGet, "/whatever", nil)
w := httptest.NewRecorder()
a.RequireAuth(next).ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
if nextCalled {
t.Fatal("next must not be called when Sessions.Validate fails")
}
}
func TestRequireAuth_Success_CallsNextWithUserID(t *testing.T) {
uid := uuid.New()
a := &API{Sessions: alwaysValidSessions(uid)}
var gotUID uuid.UUID
var gotOK bool
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotUID, gotOK = userIDFrom(r.Context())
w.WriteHeader(http.StatusOK)
})
req := requestWithSessionCookie(http.MethodGet, "/whatever", nil)
w := httptest.NewRecorder()
a.RequireAuth(next).ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200 (next called), got %d", w.Code)
}
if !gotOK || gotUID != uid {
t.Fatalf("expected userIDFrom to yield %s, got %s (ok=%v)", uid, gotUID, gotOK)
}
}
// --- RequireProjectAccess ---
func TestRequireProjectAccess_NotOwned_Returns404(t *testing.T) {
a := &API{Auth: stubAuthStore{getProjectOwnedFn: func(context.Context, uuid.UUID, uuid.UUID) (store.Project, error) {
return store.Project{}, errors.New("no rows")
}}}
nextCalled := false
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true })
router := chi.NewRouter()
router.With(a.RequireProjectAccess).Get("/projects/{pid}", next.ServeHTTP)
req := httptest.NewRequest(http.MethodGet, "/projects/"+uuid.New().String(), nil)
req = req.WithContext(context.WithValue(req.Context(), ctxUserID, uuid.New()))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404 for an unowned/foreign project, got %d body %s", w.Code, w.Body.String())
}
if nextCalled {
t.Fatal("next must not be called when the project isn't owned by the caller")
}
}
func TestRequireProjectAccess_Success_CallsNextWithProjectID(t *testing.T) {
a := &API{Auth: alwaysOwnedAuthStore()}
pid := uuid.New()
uid := uuid.New()
var gotPID uuid.UUID
var gotOK bool
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPID, gotOK = projectIDFrom(r.Context())
w.WriteHeader(http.StatusOK)
})
router := chi.NewRouter()
router.With(a.RequireProjectAccess).Get("/projects/{pid}", next.ServeHTTP)
req := httptest.NewRequest(http.MethodGet, "/projects/"+pid.String(), nil)
req = req.WithContext(context.WithValue(req.Context(), ctxUserID, uid))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200 (next called), got %d body %s", w.Code, w.Body.String())
}
if !gotOK || gotPID != pid {
t.Fatalf("expected projectIDFrom to yield %s, got %s (ok=%v)", pid, gotPID, gotOK)
}
}
// --- IDOR regression ---
// TestIDOR_CheckForeignProject_Returns404AndServiceNotCalled is the IDOR
// regression required by Task 4: a request for project B's domain, made
// with user A's session (who owns project A, not B), must be rejected by
// RequireProjectAccess with 404 before service.Check is ever invoked — a
// caller must never be able to check/apply another tenant's domain by
// guessing/leaking its ids.
func TestIDOR_CheckForeignProject_Returns404AndServiceNotCalled(t *testing.T) {
userA := uuid.New()
pidA := uuid.New()
pidB := uuid.New() // owned by a different user; not A
svc := &recordingCheckApplier{}
auth := stubAuthStore{getProjectOwnedFn: func(_ context.Context, pid, uid uuid.UUID) (store.Project, error) {
if pid == pidA && uid == userA {
return store.Project{ID: pidA, UserID: userA}, nil
}
return store.Project{}, errors.New("not found")
}}
a := &API{Svc: svc, Auth: auth, Sessions: alwaysValidSessions(userA)}
router := NewRouter(a)
did := uuid.New().String()
req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+pidB.String()+"/domains/"+did+"/check", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404 for user A requesting project B's domain, got %d body %s", w.Code, w.Body.String())
}
if svc.checkCalled {
t.Fatal("service.Check must not be called when the caller doesn't own the project")
}
// Sanity check: the same user against their own project succeeds and
// does reach the service — proving the 404 above is really about
// project ownership, not e.g. a broken route.
req2 := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+pidA.String()+"/domains/"+did+"/check", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Fatalf("expected 200 for user A requesting their own project, got %d body %s", w2.Code, w2.Body.String())
}
if !svc.checkCalled {
t.Fatal("expected service.Check to be called for the caller's own project")
}
}
+36 -60
View File
@@ -24,11 +24,9 @@ func decodeBody(w http.ResponseWriter, r *http.Request, v any) bool {
// --- accounts ---
func (a *API) handleCreateAccount(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
var req accountRequest
if !decodeBody(w, r, &req) {
return
@@ -53,11 +51,9 @@ func (a *API) handleCreateAccount(w http.ResponseWriter, r *http.Request) {
}
func (a *API) handleListAccounts(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
accs, err := a.Store.ListAccounts(r.Context(), pid)
if err != nil {
log.Printf("api: list accounts failed: %v", err)
@@ -72,11 +68,9 @@ func (a *API) handleListAccounts(w http.ResponseWriter, r *http.Request) {
}
func (a *API) handleDeleteAccount(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
aid, err := uuid.Parse(chi.URLParam(r, "aid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid account id")
@@ -93,11 +87,9 @@ func (a *API) handleDeleteAccount(w http.ResponseWriter, r *http.Request) {
// handleImportZones lists zones from the provider for the given account and
// creates one domain per zone (template_id left unset).
func (a *API) handleImportZones(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
aid, err := uuid.Parse(chi.URLParam(r, "aid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid account id")
@@ -145,11 +137,9 @@ func (a *API) handleImportZones(w http.ResponseWriter, r *http.Request) {
// --- templates ---
func (a *API) handleCreateTemplate(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
var req templateRequest
if !decodeBody(w, r, &req) {
return
@@ -169,11 +159,9 @@ func (a *API) handleCreateTemplate(w http.ResponseWriter, r *http.Request) {
}
func (a *API) handleListTemplates(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
tpls, err := a.Store.ListTemplates(r.Context(), pid)
if err != nil {
log.Printf("api: list templates failed: %v", err)
@@ -188,11 +176,9 @@ func (a *API) handleListTemplates(w http.ResponseWriter, r *http.Request) {
}
func (a *API) handleUpdateTemplate(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
tid, err := uuid.Parse(chi.URLParam(r, "tid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid template id")
@@ -217,11 +203,9 @@ func (a *API) handleUpdateTemplate(w http.ResponseWriter, r *http.Request) {
}
func (a *API) handleDeleteTemplate(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
tid, err := uuid.Parse(chi.URLParam(r, "tid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid template id")
@@ -238,11 +222,9 @@ func (a *API) handleDeleteTemplate(w http.ResponseWriter, r *http.Request) {
// --- domains ---
func (a *API) handleCreateDomain(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
var req domainRequest
if !decodeBody(w, r, &req) {
return
@@ -286,11 +268,9 @@ func (a *API) handleCreateDomain(w http.ResponseWriter, r *http.Request) {
}
func (a *API) handleListDomains(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
doms, err := a.Store.ListDomains(r.Context(), pid)
if err != nil {
log.Printf("api: list domains failed: %v", err)
@@ -308,11 +288,9 @@ func (a *API) handleListDomains(w http.ResponseWriter, r *http.Request) {
// check/apply a domain — this is what makes an imported domain (which
// starts with template_id=NULL) checkable, closing the import→check loop.
func (a *API) handleSetDomainTemplate(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
did, err := uuid.Parse(chi.URLParam(r, "did"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid domain id")
@@ -339,11 +317,9 @@ func (a *API) handleSetDomainTemplate(w http.ResponseWriter, r *http.Request) {
}
func (a *API) handleDeleteDomain(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
did, err := uuid.Parse(chi.URLParam(r, "did"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid domain id")
+29 -19
View File
@@ -132,7 +132,9 @@ func (m *mockTenantStore) ImportDomains(_ context.Context, projectID, accountID
type mockCipher struct{}
func (mockCipher) Encrypt(plaintext []byte) (string, error) { return "ENC(" + string(plaintext) + ")", nil }
func (mockCipher) Encrypt(plaintext []byte) (string, error) {
return "ENC(" + string(plaintext) + ")", nil
}
func (mockCipher) Decrypt(enc string) ([]byte, error) {
return []byte(strings.TrimSuffix(strings.TrimPrefix(enc, "ENC("), ")")), nil
}
@@ -160,9 +162,17 @@ func (mockProvider) ApplyChanges(context.Context, provider.Credentials, string,
return nil
}
// newTenantTestAPI wires a fixed authenticated user who owns whatever
// project id is requested (alwaysOwnedAuthStore/alwaysValidSessions, see
// middleware_test.go) — these tests exercise CRUD behavior past the
// RequireAuth/RequireProjectAccess boundary, which has its own dedicated
// coverage in middleware_test.go.
func newTenantTestAPI() (*API, *mockTenantStore) {
ts := &mockTenantStore{}
a := &API{Store: ts, Cipher: mockCipher{}, Reg: &mockRegistry{}}
a := &API{
Store: ts, Cipher: mockCipher{}, Reg: &mockRegistry{},
Auth: alwaysOwnedAuthStore(), Sessions: alwaysValidSessions(uuid.New()),
}
return a, ts
}
@@ -173,7 +183,7 @@ func TestCreateAccount_SecretEncryptedAndNotInResponse(t *testing.T) {
router := NewRouter(a)
body := `{"provider":"selectel","secret":"super-secret-token","comment":"prod"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts", strings.NewReader(body))
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -211,7 +221,7 @@ func TestListAccounts_NoSecretsInResponse(t *testing.T) {
}
router := NewRouter(a)
req := httptest.NewRequest(http.MethodGet, "/api/v1/projects/"+testPID+"/accounts", nil)
req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+testPID+"/accounts", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -234,7 +244,7 @@ func TestDeleteAccount_BadUUID(t *testing.T) {
a, _ := newTenantTestAPI()
router := NewRouter(a)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid", nil)
req := requestWithSessionCookie(http.MethodDelete, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -250,7 +260,7 @@ func TestCreateTemplate_SavesRecords(t *testing.T) {
router := NewRouter(a)
body := `{"name":"base","records":[{"type":"A","name":"@","ttl":300,"values":["1.2.3.4"]}]}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/templates", strings.NewReader(body))
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/templates", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -278,7 +288,7 @@ func TestUpdateTemplate_BadUUID(t *testing.T) {
router := NewRouter(a)
body := `{"name":"x","records":[]}`
req := httptest.NewRequest(http.MethodPut, "/api/v1/projects/"+testPID+"/templates/not-a-uuid", strings.NewReader(body))
req := requestWithSessionCookie(http.MethodPut, "/api/v1/projects/"+testPID+"/templates/not-a-uuid", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -299,7 +309,7 @@ func TestImportZones_CreatesDomainPerZone(t *testing.T) {
}}
router := NewRouter(a)
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil)
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -337,7 +347,7 @@ func TestImportZones_AtomicRollbackOnError(t *testing.T) {
}}
router := NewRouter(a)
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil)
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -356,7 +366,7 @@ func TestImportZones_BadAccountUUID(t *testing.T) {
a, _ := newTenantTestAPI()
router := NewRouter(a)
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid/import", nil)
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid/import", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -370,7 +380,7 @@ func TestCreateDomain_BadProjectUUID(t *testing.T) {
router := NewRouter(a)
body := `{"providerAccountId":"` + uuid.New().String() + `","zoneName":"example.com","zoneId":"z1"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/not-a-uuid/domains", strings.NewReader(body))
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/not-a-uuid/domains", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -390,7 +400,7 @@ func TestCreateDomain_AccountNotFoundInProject(t *testing.T) {
// ts.accounts is empty, so GetAccount will not find this id.
foreignAccID := uuid.New()
body := `{"providerAccountId":"` + foreignAccID.String() + `","zoneName":"example.com","zoneId":"z1"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -413,7 +423,7 @@ func TestCreateDomain_TemplateNotFoundInProject(t *testing.T) {
foreignTplID := uuid.New()
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + foreignTplID.String() + `"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -434,7 +444,7 @@ func TestCreateDomain_HappyPath(t *testing.T) {
router := NewRouter(a)
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -464,7 +474,7 @@ func TestCreateDomain_ValidTemplateInProject(t *testing.T) {
router := NewRouter(a)
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + tplID.String() + `"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -490,7 +500,7 @@ func TestSetDomainTemplate_ValidTemplateId(t *testing.T) {
router := NewRouter(a)
body := `{"templateId":"` + tplID.String() + `"}`
req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -513,7 +523,7 @@ func TestSetDomainTemplate_BadTemplateUUID(t *testing.T) {
router := NewRouter(a)
body := `{"templateId":"not-a-uuid"}`
req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -530,7 +540,7 @@ func TestSetDomainTemplate_TemplateNotFound(t *testing.T) {
router := NewRouter(a)
body := `{"templateId":"` + uuid.New().String() + `"}`
req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
@@ -543,7 +553,7 @@ func TestDeleteDomain_BadUUID(t *testing.T) {
a, _ := newTenantTestAPI()
router := NewRouter(a)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/projects/"+testPID+"/domains/not-a-uuid", nil)
req := requestWithSessionCookie(http.MethodDelete, "/api/v1/projects/"+testPID+"/domains/not-a-uuid", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
+9 -7
View File
@@ -21,7 +21,7 @@ type DomainRef struct {
}
type Loader interface {
LoadDomain(ctx context.Context, domainID uuid.UUID) (DomainRef, error)
LoadDomain(ctx context.Context, projectID, domainID uuid.UUID) (DomainRef, error)
}
type Recorder interface {
@@ -45,8 +45,10 @@ func New(loader Loader, rec Recorder, reg *registry.Registry, cipher *crypto.Cip
}
// resolve loads the domain, its provider and decrypted credentials, and computes the diff.
func (s *DomainService) resolve(ctx context.Context, domainID uuid.UUID) (provider.Provider, provider.Credentials, DomainRef, diff.Changeset, error) {
ref, err := s.loader.LoadDomain(ctx, domainID)
// projectID scopes the lookup so a domainID belonging to another tenant's
// project can never be resolved here (closes IDOR).
func (s *DomainService) resolve(ctx context.Context, projectID, domainID uuid.UUID) (provider.Provider, provider.Credentials, DomainRef, diff.Changeset, error) {
ref, err := s.loader.LoadDomain(ctx, projectID, domainID)
if err != nil {
return nil, provider.Credentials{}, ref, diff.Changeset{}, err
}
@@ -68,8 +70,8 @@ func (s *DomainService) resolve(ctx context.Context, domainID uuid.UUID) (provid
}
// Check computes and records the diff between template and zone.
func (s *DomainService) Check(ctx context.Context, domainID uuid.UUID) (diff.Changeset, error) {
_, _, _, cs, err := s.resolve(ctx, domainID)
func (s *DomainService) Check(ctx context.Context, projectID, domainID uuid.UUID) (diff.Changeset, error) {
_, _, _, cs, err := s.resolve(ctx, projectID, domainID)
if err != nil {
return diff.Changeset{}, err
}
@@ -80,8 +82,8 @@ func (s *DomainService) Check(ctx context.Context, domainID uuid.UUID) (diff.Cha
}
// Apply applies updates always (when ApplyUpdates) and prunes only when ApplyPrunes.
func (s *DomainService) Apply(ctx context.Context, domainID uuid.UUID, req ApplyRequest) (diff.Changeset, error) {
p, creds, ref, cs, err := s.resolve(ctx, domainID)
func (s *DomainService) Apply(ctx context.Context, projectID, domainID uuid.UUID, req ApplyRequest) (diff.Changeset, error) {
p, creds, ref, cs, err := s.resolve(ctx, projectID, domainID)
if err != nil {
return diff.Changeset{}, err
}
+6 -4
View File
@@ -44,7 +44,9 @@ func (f *fakeProvider) ApplyChanges(_ context.Context, _ provider.Credentials, _
type fakeLoader struct{ ref DomainRef }
func (l fakeLoader) LoadDomain(context.Context, uuid.UUID) (DomainRef, error) { return l.ref, nil }
func (l fakeLoader) LoadDomain(context.Context, uuid.UUID, uuid.UUID) (DomainRef, error) {
return l.ref, nil
}
type nopRecorder struct{}
@@ -66,7 +68,7 @@ func TestCheckProducesDiff(t *testing.T) {
{Type: "A", Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}}, // update
}}
svc, _ := setup(t, actual, tmpl)
cs, err := svc.Check(context.Background(), uuid.New())
cs, err := svc.Check(context.Background(), uuid.New(), uuid.New())
if err != nil {
t.Fatal(err)
}
@@ -87,7 +89,7 @@ func TestApplyRespectsPruneGuard(t *testing.T) {
// applyPrunes=false → удаление b НЕ применяется
svc, fp := setup(t, actual, tmpl)
if _, err := svc.Apply(context.Background(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: false}); err != nil {
if _, err := svc.Apply(context.Background(), uuid.New(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: false}); err != nil {
t.Fatal(err)
}
for _, d := range fp.applied.Diffs {
@@ -98,7 +100,7 @@ func TestApplyRespectsPruneGuard(t *testing.T) {
// applyPrunes=true → удаление b применяется
svc2, fp2 := setup(t, actual, tmpl)
if _, err := svc2.Apply(context.Background(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: true}); err != nil {
if _, err := svc2.Apply(context.Background(), uuid.New(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: true}); err != nil {
t.Fatal(err)
}
var sawDelete bool
+8 -3
View File
@@ -162,9 +162,14 @@ SELECT d.zone_id, a.provider, a.secret_enc, t.doc
FROM domains d
JOIN provider_accounts a ON a.id = d.provider_account_id
LEFT JOIN templates t ON t.id = d.template_id
WHERE d.id = $1
WHERE d.id = $1 AND d.project_id = $2
`
type LoadDomainFullParams struct {
ID uuid.UUID `json:"id"`
ProjectID uuid.UUID `json:"project_id"`
}
type LoadDomainFullRow struct {
ZoneID string `json:"zone_id"`
Provider string `json:"provider"`
@@ -172,8 +177,8 @@ type LoadDomainFullRow struct {
Doc *dto.TemplateDoc `json:"doc"`
}
func (q *Queries) LoadDomainFull(ctx context.Context, id uuid.UUID) (LoadDomainFullRow, error) {
row := q.db.QueryRow(ctx, loadDomainFull, id)
func (q *Queries) LoadDomainFull(ctx context.Context, arg LoadDomainFullParams) (LoadDomainFullRow, error) {
row := q.db.QueryRow(ctx, loadDomainFull, arg.ID, arg.ProjectID)
var i LoadDomainFullRow
err := row.Scan(
&i.ZoneID,
+5 -3
View File
@@ -13,9 +13,11 @@ import (
)
// LoadDomain joins domains+provider_accounts+templates to build the
// service.DomainRef needed to check/apply a domain's DNS records.
func (s *Store) LoadDomain(ctx context.Context, domainID uuid.UUID) (service.DomainRef, error) {
row, err := s.q.LoadDomainFull(ctx, domainID)
// service.DomainRef needed to check/apply a domain's DNS records. Scoped by
// projectID so a domain belonging to another tenant's project can never be
// loaded, even if its domainID is guessed/leaked (closes IDOR).
func (s *Store) LoadDomain(ctx context.Context, projectID, domainID uuid.UUID) (service.DomainRef, error) {
row, err := s.q.LoadDomainFull(ctx, db.LoadDomainFullParams{ID: domainID, ProjectID: projectID})
if err != nil {
return service.DomainRef{}, err
}
+2 -2
View File
@@ -40,7 +40,7 @@ func TestLoadDomainAndSaveCheckRun(t *testing.T) {
t.Fatal(err)
}
ref, err := s.LoadDomain(ctx, domain.ID)
ref, err := s.LoadDomain(ctx, defaultProject, domain.ID)
if err != nil {
t.Fatal(err)
}
@@ -87,7 +87,7 @@ func TestLoadDomainNoTemplate(t *testing.T) {
t.Fatal(err)
}
if _, err := s.LoadDomain(ctx, domain.ID); err == nil {
if _, err := s.LoadDomain(ctx, defaultProject, domain.ID); err == nil {
t.Fatal("expected error for domain without template, got nil")
}
}
+1 -1
View File
@@ -27,4 +27,4 @@ SELECT d.zone_id, a.provider, a.secret_enc, t.doc
FROM domains d
JOIN provider_accounts a ON a.id = d.provider_account_id
LEFT JOIN templates t ON t.id = d.template_id
WHERE d.id = $1;
WHERE d.id = $1 AND d.project_id = $2;
+2 -2
View File
@@ -214,7 +214,7 @@ func TestSetDomainTemplate_ClosesImportCheckLoop(t *testing.T) {
dom := doms[0]
// Before binding, the domain is not checkable.
if _, err := s.LoadDomain(ctx, dom.ID); err == nil {
if _, err := s.LoadDomain(ctx, defaultProject, dom.ID); err == nil {
t.Fatal("expected LoadDomain to fail before a template is bound")
}
@@ -234,7 +234,7 @@ func TestSetDomainTemplate_ClosesImportCheckLoop(t *testing.T) {
t.Fatalf("expected domain.TemplateID=%s, got %+v", tpl.ID, updated.TemplateID)
}
ref, err := s.LoadDomain(ctx, dom.ID)
ref, err := s.LoadDomain(ctx, defaultProject, dom.ID)
if err != nil {
t.Fatalf("expected LoadDomain to succeed after binding template, got error: %v", err)
}