fix(store): scope SetDomainStatus by project (IDOR); scheduler reuses DeriveStatus

handleCheck's error branch wrote last_check_status via an id-only UPDATE, so
an authenticated caller could pair their own valid project id with a foreign
domain id in the URL and flip a stranger's domain to "error". Check itself is
project-scoped and would 404/error out for such a pair, but that failure is
exactly what routed the request into the unscoped error-branch write. Add
project_id to the WHERE clause (queries/domains.sql + generated
db/domains.sql.go), thread projectID through Store/TenantStore/SchedStore
SetDomainStatus, and pass pid from context at both call sites in handleCheck,
plus checkDomain's projectID at the scheduler call site.

Also collapse checkDomain's inline status derivation in scheduler.go into a
call to service.DeriveStatus, the same helper handleCheck already uses, so
there's a single source of truth for "drift vs in_sync" instead of two
copies that could drift apart.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01BwxdSt4reTm7Dj1oxRvpP3
This commit is contained in:
2026-07-05 14:40:13 +07:00
parent 784e7bd822
commit 27d70a987e
10 changed files with 149 additions and 35 deletions
+3 -2
View File
@@ -52,8 +52,9 @@ type TenantStore interface {
SetDomainTemplate(ctx context.Context, domainID, projectID uuid.UUID, templateID *uuid.UUID) (store.Domain, error) SetDomainTemplate(ctx context.Context, domainID, projectID uuid.UUID, templateID *uuid.UUID) (store.Domain, error)
// SetDomainStatus persists the outcome of a manual check (handleCheck) so // SetDomainStatus persists the outcome of a manual check (handleCheck) so
// the domain's badge reflects reality immediately, instead of staying // the domain's badge reflects reality immediately, instead of staying
// "unknown" until the scheduler's next tick. // "unknown" until the scheduler's next tick. Scoped by projectID so a
SetDomainStatus(ctx context.Context, domainID uuid.UUID, status string) error // foreign domain ID can never have its status overwritten (IDOR-on-write).
SetDomainStatus(ctx context.Context, domainID, projectID uuid.UUID, status string) error
} }
// Cipher encrypts/decrypts provider account secrets. *crypto.Cipher satisfies it. // Cipher encrypts/decrypts provider account secrets. *crypto.Cipher satisfies it.
+51 -6
View File
@@ -172,8 +172,9 @@ func TestCheckEndpoint_PersistsDriftStatus(t *testing.T) {
if w.Code != http.StatusOK { if w.Code != http.StatusOK {
t.Fatalf("status %d body %s", w.Code, w.Body.String()) t.Fatalf("status %d body %s", w.Code, w.Body.String())
} }
if len(ts.statusCalls) != 1 || ts.statusCalls[0].domainID != did || ts.statusCalls[0].status != service.StatusDrift { wantPID := uuid.MustParse("00000000-0000-0000-0000-000000000002")
t.Fatalf("expected SetDomainStatus(%s, drift), got %+v", did, ts.statusCalls) if len(ts.statusCalls) != 1 || ts.statusCalls[0].domainID != did || ts.statusCalls[0].projectID != wantPID || ts.statusCalls[0].status != service.StatusDrift {
t.Fatalf("expected SetDomainStatus(%s, %s, drift), got %+v", did, wantPID, ts.statusCalls)
} }
} }
@@ -194,8 +195,9 @@ func TestCheckEndpoint_PersistsInSyncStatus(t *testing.T) {
if w.Code != http.StatusOK { if w.Code != http.StatusOK {
t.Fatalf("status %d body %s", w.Code, w.Body.String()) t.Fatalf("status %d body %s", w.Code, w.Body.String())
} }
if len(ts.statusCalls) != 1 || ts.statusCalls[0].status != service.StatusInSync { wantPID := uuid.MustParse("00000000-0000-0000-0000-000000000002")
t.Fatalf("expected SetDomainStatus(_, in_sync), got %+v", ts.statusCalls) if len(ts.statusCalls) != 1 || ts.statusCalls[0].projectID != wantPID || ts.statusCalls[0].status != service.StatusInSync {
t.Fatalf("expected SetDomainStatus(_, %s, in_sync), got %+v", wantPID, ts.statusCalls)
} }
} }
@@ -218,8 +220,51 @@ func TestCheckEndpoint_ErrorPersistsErrorStatus(t *testing.T) {
if w.Code != http.StatusInternalServerError { if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d body %s", w.Code, w.Body.String()) t.Fatalf("expected 500, got %d body %s", w.Code, w.Body.String())
} }
if len(ts.statusCalls) != 1 || ts.statusCalls[0].domainID != did || ts.statusCalls[0].status != service.StatusError { wantPID := uuid.MustParse("00000000-0000-0000-0000-000000000002")
t.Fatalf("expected SetDomainStatus(%s, error), got %+v", did, ts.statusCalls) if len(ts.statusCalls) != 1 || ts.statusCalls[0].domainID != did || ts.statusCalls[0].projectID != wantPID || ts.statusCalls[0].status != service.StatusError {
t.Fatalf("expected SetDomainStatus(%s, %s, error), got %+v", did, wantPID, ts.statusCalls)
}
}
// TestCheckEndpoint_ErrorScopesStatusToCallerProject covers the HIGH
// IDOR-on-write fix: handleCheck's error branch must persist the failure
// status scoped to the caller's OWN project (pid from the URL/context), even
// when the domain ID in the URL belongs to (or doesn't exist in) a different
// tenant. The handler itself has no way to know whether did is foreign — the
// scoping guarantee comes from always passing pid through to
// Store.SetDomainStatus, which is enforced to be a no-op for a mismatched
// project_id at the store/SQL layer (see internal/store/schedule_test.go's
// TestSetDomainStatus_ScopedByProject_ForeignProjectIsNoOp). This test proves
// the handler holds up its side of that contract: pid, never a zero value or
// some other project, is what gets passed down.
func TestCheckEndpoint_ErrorScopesStatusToCallerProject(t *testing.T) {
a, m := newTestAPI()
ts := a.Store.(*mockTenantStore)
m.checkErr = errors.New("boom: provider unreachable")
router := NewRouter(a)
callerPID := uuid.New()
// foreignDID stands in for a domain ID the caller does not own — from the
// handler's perspective it's just whatever {did} was in the URL; only the
// store layer can (and does) enforce that a write scoped to callerPID never
// touches a domain belonging to another project.
foreignDID := uuid.New()
req := requestWithSessionCookie(http.MethodGet,
"/api/v1/projects/"+callerPID.String()+"/domains/"+foreignDID.String()+"/check", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d body %s", w.Code, w.Body.String())
}
if len(ts.statusCalls) != 1 {
t.Fatalf("expected exactly 1 SetDomainStatus call, got %+v", ts.statusCalls)
}
call := ts.statusCalls[0]
if call.projectID != callerPID {
t.Fatalf("expected SetDomainStatus scoped to caller's own pid %s, got projectID %s (never empty/foreign)", callerPID, call.projectID)
}
if call.domainID != foreignDID || call.status != service.StatusError {
t.Fatalf("expected SetDomainStatus(%s, %s, error), got %+v", foreignDID, callerPID, call)
} }
} }
+2 -2
View File
@@ -39,7 +39,7 @@ func (a *API) handleCheck(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
// Persist the failure so the domain badge reflects it instead of stale // Persist the failure so the domain badge reflects it instead of stale
// "unknown"; the write error (if any) is logged, never masks the 500. // "unknown"; the write error (if any) is logged, never masks the 500.
if serr := a.Store.SetDomainStatus(r.Context(), did, service.StatusError); serr != nil { if serr := a.Store.SetDomainStatus(r.Context(), did, pid, service.StatusError); serr != nil {
log.Printf("api: set domain status (error) failed: %v", serr) log.Printf("api: set domain status (error) failed: %v", serr)
} }
log.Printf("api: check failed: %v", err) log.Printf("api: check failed: %v", err)
@@ -48,7 +48,7 @@ func (a *API) handleCheck(w http.ResponseWriter, r *http.Request) {
} }
// Manual check persists status/history only — no notification. Notify // Manual check persists status/history only — no notification. Notify
// remains the scheduler's responsibility (see internal/scheduler). // remains the scheduler's responsibility (see internal/scheduler).
if serr := a.Store.SetDomainStatus(r.Context(), did, service.DeriveStatus(cs)); serr != nil { if serr := a.Store.SetDomainStatus(r.Context(), did, pid, service.DeriveStatus(cs)); serr != nil {
log.Printf("api: set domain status failed: %v", serr) log.Printf("api: set domain status failed: %v", serr)
} }
writeJSON(w, http.StatusOK, toChangesetResponse(cs)) writeJSON(w, http.StatusOK, toChangesetResponse(cs))
+9 -5
View File
@@ -40,10 +40,12 @@ type mockTenantStore struct {
setDomainTemplateErr error setDomainTemplateErr error
// statusCalls records every SetDomainStatus(domainID, status) call, in // statusCalls records every SetDomainStatus(domainID, projectID, status)
// order, so tests can assert what the handler persisted. // call, in order, so tests can assert what the handler persisted — and,
// crucially, which projectID it scoped the write to (IDOR regression).
statusCalls []struct { statusCalls []struct {
domainID uuid.UUID domainID uuid.UUID
projectID uuid.UUID
status string status string
} }
setDomainStatusErr error setDomainStatusErr error
@@ -135,12 +137,14 @@ func (m *mockTenantStore) SetDomainTemplate(_ context.Context, domainID, project
} }
// SetDomainStatus records the call for assertion instead of actually mutating // SetDomainStatus records the call for assertion instead of actually mutating
// m.domains — handleCheck tests only need to verify what was written. // m.domains — handleCheck tests only need to verify what was written, and
func (m *mockTenantStore) SetDomainStatus(_ context.Context, domainID uuid.UUID, status string) error { // which projectID it was scoped to (IDOR regression coverage).
func (m *mockTenantStore) SetDomainStatus(_ context.Context, domainID, projectID uuid.UUID, status string) error {
m.statusCalls = append(m.statusCalls, struct { m.statusCalls = append(m.statusCalls, struct {
domainID uuid.UUID domainID uuid.UUID
projectID uuid.UUID
status string status string
}{domainID, status}) }{domainID, projectID, status})
return m.setDomainStatusErr return m.setDomainStatusErr
} }
+10 -8
View File
@@ -40,7 +40,9 @@ type SchedStore interface {
TouchScheduleRun(ctx context.Context, projectID uuid.UUID, at time.Time) error TouchScheduleRun(ctx context.Context, projectID uuid.UUID, at time.Time) error
ListDomains(ctx context.Context, projectID uuid.UUID) ([]store.Domain, error) ListDomains(ctx context.Context, projectID uuid.UUID) ([]store.Domain, error)
GetDomainStatus(ctx context.Context, domainID uuid.UUID) (string, error) GetDomainStatus(ctx context.Context, domainID uuid.UUID) (string, error)
SetDomainStatus(ctx context.Context, domainID uuid.UUID, status string) error // SetDomainStatus is scoped by projectID so a foreign domain ID can never
// have its status overwritten (IDOR-on-write) — see internal/store/tenant.go.
SetDomainStatus(ctx context.Context, domainID, projectID uuid.UUID, status string) error
CountDriftDomains(ctx context.Context) (int, error) CountDriftDomains(ctx context.Context) (int, error)
} }
@@ -144,12 +146,12 @@ func (s *Scheduler) checkDomain(ctx context.Context, projectID uuid.UUID, d stor
cs, checkErr := s.checker.Check(ctx, projectID, d.ID) cs, checkErr := s.checker.Check(ctx, projectID, d.ID)
dur := time.Since(start) dur := time.Since(start)
newStatus := StatusInSync // Derive the status via the same helper the manual check handler uses
switch { // (internal/api/handlers.go) so both paths agree on what counts as
case checkErr != nil: // "drift" vs. "in sync" — a failed check is always "error" regardless.
newStatus = StatusError newStatus := StatusError
case len(cs.Actionable()) > 0: if checkErr == nil {
newStatus = StatusDrift newStatus = service.DeriveStatus(cs)
} }
s.metrics.ObserveCheck(newStatus, dur) s.metrics.ObserveCheck(newStatus, dur)
@@ -164,7 +166,7 @@ func (s *Scheduler) checkDomain(ctx context.Context, projectID uuid.UUID, d stor
// check (drift or in_sync). Calling it again here would double-write // check (drift or in_sync). Calling it again here would double-write
// check_runs history for the same check. // check_runs history for the same check.
if err := s.store.SetDomainStatus(ctx, d.ID, newStatus); err != nil { if err := s.store.SetDomainStatus(ctx, d.ID, projectID, newStatus); err != nil {
log.Printf("scheduler: set domain status for %s failed: %v", d.ID, err) log.Printf("scheduler: set domain status for %s failed: %v", d.ID, err)
} }
+5 -1
View File
@@ -62,7 +62,11 @@ func (m *mockStore) GetDomainStatus(ctx context.Context, domainID uuid.UUID) (st
return StatusUnknown, nil return StatusUnknown, nil
} }
func (m *mockStore) SetDomainStatus(ctx context.Context, domainID uuid.UUID, status string) error { // SetDomainStatus ignores projectID here — this in-memory fake is keyed by
// domainID alone and isn't exercising the IDOR scoping itself (that's
// covered at the store layer / API handler level); it exists only to match
// the SchedStore interface signature.
func (m *mockStore) SetDomainStatus(ctx context.Context, domainID, projectID uuid.UUID, status string) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
m.status[domainID] = status m.status[domainID] = status
+3 -2
View File
@@ -218,16 +218,17 @@ func (q *Queries) LoadDomainFull(ctx context.Context, arg LoadDomainFullParams)
} }
const setDomainStatus = `-- name: SetDomainStatus :exec const setDomainStatus = `-- name: SetDomainStatus :exec
UPDATE domains SET last_check_status = $2 WHERE id = $1 UPDATE domains SET last_check_status = $2 WHERE id = $1 AND project_id = $3
` `
type SetDomainStatusParams struct { type SetDomainStatusParams struct {
ID uuid.UUID `json:"id"` ID uuid.UUID `json:"id"`
LastCheckStatus string `json:"last_check_status"` LastCheckStatus string `json:"last_check_status"`
ProjectID uuid.UUID `json:"project_id"`
} }
func (q *Queries) SetDomainStatus(ctx context.Context, arg SetDomainStatusParams) error { func (q *Queries) SetDomainStatus(ctx context.Context, arg SetDomainStatusParams) error {
_, err := q.db.Exec(ctx, setDomainStatus, arg.ID, arg.LastCheckStatus) _, err := q.db.Exec(ctx, setDomainStatus, arg.ID, arg.LastCheckStatus, arg.ProjectID)
return err return err
} }
+1 -1
View File
@@ -33,7 +33,7 @@ WHERE d.id = $1 AND d.project_id = $2;
SELECT last_check_status FROM domains WHERE id = $1; SELECT last_check_status FROM domains WHERE id = $1;
-- name: SetDomainStatus :exec -- name: SetDomainStatus :exec
UPDATE domains SET last_check_status = $2 WHERE id = $1; UPDATE domains SET last_check_status = $2 WHERE id = $1 AND project_id = $3;
-- name: CountDriftDomains :one -- name: CountDriftDomains :one
SELECT count(*) FROM domains WHERE last_check_status = 'drift'; SELECT count(*) FROM domains WHERE last_check_status = 'drift';
+54 -1
View File
@@ -246,7 +246,7 @@ func TestDomainStatus_RoundTrip(t *testing.T) {
t.Fatalf("expected default status 'unknown', got %q", status) t.Fatalf("expected default status 'unknown', got %q", status)
} }
if err := s.SetDomainStatus(ctx, d.ID, "ok"); err != nil { if err := s.SetDomainStatus(ctx, d.ID, p.ID, "ok"); err != nil {
t.Fatal(err) t.Fatal(err)
} }
status, err = s.GetDomainStatus(ctx, d.ID) status, err = s.GetDomainStatus(ctx, d.ID)
@@ -265,3 +265,56 @@ func TestDomainStatus_RoundTrip(t *testing.T) {
t.Fatalf("expected ListDomains to reflect updated status: %+v", domains) t.Fatalf("expected ListDomains to reflect updated status: %+v", domains)
} }
} }
// TestSetDomainStatus_ScopedByProject_ForeignProjectIsNoOp covers the IDOR
// fix: SetDomainStatus is called with a valid domain ID but a projectID that
// does NOT own it (e.g. an authenticated caller's own pid, paired with
// another tenant's did in the URL). The WHERE id = $1 AND project_id = $3
// clause must match zero rows — no error, but the foreign domain's status
// must remain untouched, never "error"/"drift"/whatever was passed in.
func TestSetDomainStatus_ScopedByProject_ForeignProjectIsNoOp(t *testing.T) {
s, ctx := newStore(t)
_, owner, err := s.RegisterUser(ctx, "domain-status-owner@example.com", "argon2-hash")
if err != nil {
t.Fatal(err)
}
_, attacker, err := s.RegisterUser(ctx, "domain-status-attacker@example.com", "argon2-hash")
if err != nil {
t.Fatal(err)
}
acc, err := s.CreateAccount(ctx, owner.ID, "selectel", "enc-blob", "test")
if err != nil {
t.Fatal(err)
}
d, err := s.CreateDomain(ctx, owner.ID, acc.ID, "example.com", "zone-1", nil)
if err != nil {
t.Fatal(err)
}
// attacker's own (valid) project ID, paired with owner's domain ID —
// mirrors the exact request shape an authenticated attacker could send.
if err := s.SetDomainStatus(ctx, d.ID, attacker.ID, "error"); err != nil {
t.Fatal(err)
}
status, err := s.GetDomainStatus(ctx, d.ID)
if err != nil {
t.Fatal(err)
}
if status != "unknown" {
t.Fatalf("expected foreign-project SetDomainStatus to be a no-op, but status changed to %q", status)
}
// The legitimate owner can still update it — proves the no-op above was
// due to project scoping, not some unrelated write failure.
if err := s.SetDomainStatus(ctx, d.ID, owner.ID, "error"); err != nil {
t.Fatal(err)
}
status, err = s.GetDomainStatus(ctx, d.ID)
if err != nil {
t.Fatal(err)
}
if status != "error" {
t.Fatalf("expected owner's SetDomainStatus to apply, got %q", status)
}
}
+7 -3
View File
@@ -253,9 +253,13 @@ func (s *Store) GetDomainStatus(ctx context.Context, domainID uuid.UUID) (string
} }
// SetDomainStatus records the outcome of the most recent check/apply run for // SetDomainStatus records the outcome of the most recent check/apply run for
// a domain (e.g. "ok", "drift", "error"). // a domain (e.g. "ok", "drift", "error"). Scoped by projectID — a domain ID
func (s *Store) SetDomainStatus(ctx context.Context, domainID uuid.UUID, status string) error { // belonging to another tenant's project is left untouched (matches zero
return s.q.SetDomainStatus(ctx, db.SetDomainStatusParams{ID: domainID, LastCheckStatus: status}) // rows) rather than being overwritten, closing an IDOR-on-write where a
// caller's own valid pid + a foreign did could otherwise flip a stranger's
// domain status.
func (s *Store) SetDomainStatus(ctx context.Context, domainID, projectID uuid.UUID, status string) error {
return s.q.SetDomainStatus(ctx, db.SetDomainStatusParams{ID: domainID, LastCheckStatus: status, ProjectID: projectID})
} }
// CountDriftDomains returns the current number of domains system-wide whose // CountDriftDomains returns the current number of domains system-wide whose