fix(store): scope SetDomainStatus by project (IDOR); scheduler reuses DeriveStatus
handleCheck's error branch wrote last_check_status via an id-only UPDATE, so an authenticated caller's own valid project id paired with a foreign domain id in the URL could flip a stranger's domain to "error" even though Check itself is project-scoped and would 404/error out first. Add project_id to the WHERE clause (queries/domains.sql + generated db/domains.sql.go), thread projectID through Store/TenantStore/SchedStore SetDomainStatus, and pass pid from context at both call sites in handleCheck plus the scheduler. Also collapse checkDomain's inline status derivation in scheduler.go into a call to service.DeriveStatus, the same helper handleCheck already uses, so there's a single source of truth for "drift vs in_sync" instead of two copies that could drift apart. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01BwxdSt4reTm7Dj1oxRvpP3
This commit is contained in:
+3
-2
@@ -52,8 +52,9 @@ type TenantStore interface {
|
||||
SetDomainTemplate(ctx context.Context, domainID, projectID uuid.UUID, templateID *uuid.UUID) (store.Domain, error)
|
||||
// SetDomainStatus persists the outcome of a manual check (handleCheck) so
|
||||
// the domain's badge reflects reality immediately, instead of staying
|
||||
// "unknown" until the scheduler's next tick.
|
||||
SetDomainStatus(ctx context.Context, domainID uuid.UUID, status string) error
|
||||
// "unknown" until the scheduler's next tick. Scoped by projectID so a
|
||||
// foreign domain ID can never have its status overwritten (IDOR-on-write).
|
||||
SetDomainStatus(ctx context.Context, domainID, projectID uuid.UUID, status string) error
|
||||
}
|
||||
|
||||
// Cipher encrypts/decrypts provider account secrets. *crypto.Cipher satisfies it.
|
||||
|
||||
@@ -172,8 +172,9 @@ func TestCheckEndpoint_PersistsDriftStatus(t *testing.T) {
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status %d body %s", w.Code, w.Body.String())
|
||||
}
|
||||
if len(ts.statusCalls) != 1 || ts.statusCalls[0].domainID != did || ts.statusCalls[0].status != service.StatusDrift {
|
||||
t.Fatalf("expected SetDomainStatus(%s, drift), got %+v", did, ts.statusCalls)
|
||||
wantPID := uuid.MustParse("00000000-0000-0000-0000-000000000002")
|
||||
if len(ts.statusCalls) != 1 || ts.statusCalls[0].domainID != did || ts.statusCalls[0].projectID != wantPID || ts.statusCalls[0].status != service.StatusDrift {
|
||||
t.Fatalf("expected SetDomainStatus(%s, %s, drift), got %+v", did, wantPID, ts.statusCalls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -194,8 +195,9 @@ func TestCheckEndpoint_PersistsInSyncStatus(t *testing.T) {
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status %d body %s", w.Code, w.Body.String())
|
||||
}
|
||||
if len(ts.statusCalls) != 1 || ts.statusCalls[0].status != service.StatusInSync {
|
||||
t.Fatalf("expected SetDomainStatus(_, in_sync), got %+v", ts.statusCalls)
|
||||
wantPID := uuid.MustParse("00000000-0000-0000-0000-000000000002")
|
||||
if len(ts.statusCalls) != 1 || ts.statusCalls[0].projectID != wantPID || ts.statusCalls[0].status != service.StatusInSync {
|
||||
t.Fatalf("expected SetDomainStatus(_, %s, in_sync), got %+v", wantPID, ts.statusCalls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -218,8 +220,51 @@ func TestCheckEndpoint_ErrorPersistsErrorStatus(t *testing.T) {
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected 500, got %d body %s", w.Code, w.Body.String())
|
||||
}
|
||||
if len(ts.statusCalls) != 1 || ts.statusCalls[0].domainID != did || ts.statusCalls[0].status != service.StatusError {
|
||||
t.Fatalf("expected SetDomainStatus(%s, error), got %+v", did, ts.statusCalls)
|
||||
wantPID := uuid.MustParse("00000000-0000-0000-0000-000000000002")
|
||||
if len(ts.statusCalls) != 1 || ts.statusCalls[0].domainID != did || ts.statusCalls[0].projectID != wantPID || ts.statusCalls[0].status != service.StatusError {
|
||||
t.Fatalf("expected SetDomainStatus(%s, %s, error), got %+v", did, wantPID, ts.statusCalls)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckEndpoint_ErrorScopesStatusToCallerProject covers the HIGH
|
||||
// IDOR-on-write fix: handleCheck's error branch must persist the failure
|
||||
// status scoped to the caller's OWN project (pid from the URL/context), even
|
||||
// when the domain ID in the URL belongs to (or doesn't exist in) a different
|
||||
// tenant. The handler itself has no way to know whether did is foreign — the
|
||||
// scoping guarantee comes from always passing pid through to
|
||||
// Store.SetDomainStatus, which is enforced to be a no-op for a mismatched
|
||||
// project_id at the store/SQL layer (see internal/store/schedule_test.go's
|
||||
// TestSetDomainStatus_ScopedByProject_ForeignProjectIsNoOp). This test proves
|
||||
// the handler holds up its side of that contract: pid, never a zero value or
|
||||
// some other project, is what gets passed down.
|
||||
func TestCheckEndpoint_ErrorScopesStatusToCallerProject(t *testing.T) {
|
||||
a, m := newTestAPI()
|
||||
ts := a.Store.(*mockTenantStore)
|
||||
m.checkErr = errors.New("boom: provider unreachable")
|
||||
router := NewRouter(a)
|
||||
|
||||
callerPID := uuid.New()
|
||||
// foreignDID stands in for a domain ID the caller does not own — from the
|
||||
// handler's perspective it's just whatever {did} was in the URL; only the
|
||||
// store layer can (and does) enforce that it isn't actually foreignPID's.
|
||||
foreignDID := uuid.New()
|
||||
req := requestWithSessionCookie(http.MethodGet,
|
||||
"/api/v1/projects/"+callerPID.String()+"/domains/"+foreignDID.String()+"/check", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected 500, got %d body %s", w.Code, w.Body.String())
|
||||
}
|
||||
if len(ts.statusCalls) != 1 {
|
||||
t.Fatalf("expected exactly 1 SetDomainStatus call, got %+v", ts.statusCalls)
|
||||
}
|
||||
call := ts.statusCalls[0]
|
||||
if call.projectID != callerPID {
|
||||
t.Fatalf("expected SetDomainStatus scoped to caller's own pid %s, got projectID %s (never empty/foreign)", callerPID, call.projectID)
|
||||
}
|
||||
if call.domainID != foreignDID || call.status != service.StatusError {
|
||||
t.Fatalf("expected SetDomainStatus(%s, %s, error), got %+v", foreignDID, callerPID, call)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ func (a *API) handleCheck(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
// Persist the failure so the domain badge reflects it instead of stale
|
||||
// "unknown"; the write error (if any) is logged, never masks the 500.
|
||||
if serr := a.Store.SetDomainStatus(r.Context(), did, service.StatusError); serr != nil {
|
||||
if serr := a.Store.SetDomainStatus(r.Context(), did, pid, service.StatusError); serr != nil {
|
||||
log.Printf("api: set domain status (error) failed: %v", serr)
|
||||
}
|
||||
log.Printf("api: check failed: %v", err)
|
||||
@@ -48,7 +48,7 @@ func (a *API) handleCheck(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
// Manual check persists status/history only — no notification. Notify
|
||||
// remains the scheduler's responsibility (see internal/scheduler).
|
||||
if serr := a.Store.SetDomainStatus(r.Context(), did, service.DeriveStatus(cs)); serr != nil {
|
||||
if serr := a.Store.SetDomainStatus(r.Context(), did, pid, service.DeriveStatus(cs)); serr != nil {
|
||||
log.Printf("api: set domain status failed: %v", serr)
|
||||
}
|
||||
writeJSON(w, http.StatusOK, toChangesetResponse(cs))
|
||||
|
||||
@@ -40,11 +40,13 @@ type mockTenantStore struct {
|
||||
|
||||
setDomainTemplateErr error
|
||||
|
||||
// statusCalls records every SetDomainStatus(domainID, status) call, in
|
||||
// order, so tests can assert what the handler persisted.
|
||||
// statusCalls records every SetDomainStatus(domainID, projectID, status)
|
||||
// call, in order, so tests can assert what the handler persisted — and,
|
||||
// crucially, which projectID it scoped the write to (IDOR regression).
|
||||
statusCalls []struct {
|
||||
domainID uuid.UUID
|
||||
status string
|
||||
domainID uuid.UUID
|
||||
projectID uuid.UUID
|
||||
status string
|
||||
}
|
||||
setDomainStatusErr error
|
||||
}
|
||||
@@ -135,12 +137,14 @@ func (m *mockTenantStore) SetDomainTemplate(_ context.Context, domainID, project
|
||||
}
|
||||
|
||||
// SetDomainStatus records the call for assertion instead of actually mutating
|
||||
// m.domains — handleCheck tests only need to verify what was written.
|
||||
func (m *mockTenantStore) SetDomainStatus(_ context.Context, domainID uuid.UUID, status string) error {
|
||||
// m.domains — handleCheck tests only need to verify what was written, and
|
||||
// which projectID it was scoped to (IDOR regression coverage).
|
||||
func (m *mockTenantStore) SetDomainStatus(_ context.Context, domainID, projectID uuid.UUID, status string) error {
|
||||
m.statusCalls = append(m.statusCalls, struct {
|
||||
domainID uuid.UUID
|
||||
status string
|
||||
}{domainID, status})
|
||||
domainID uuid.UUID
|
||||
projectID uuid.UUID
|
||||
status string
|
||||
}{domainID, projectID, status})
|
||||
return m.setDomainStatusErr
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user