0b26923586
RecordDiff.Key() gives a stable normalized identifier ("TYPE name.") for
every diff kind, exposed as recordView.Key. ApplyRequest now takes
Updates/Prunes key lists instead of two booleans, so callers can apply a
subset of records. service.Apply builds the applied set with selected
prunes (Delete) added before selected updates (Add/Update) — an
invariant, not an option — since the provider rejects an Add/Update
whose name still conflicts with an existing record (e.g. a CNAME cannot
be created while an A on the same name still exists).
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01BwxdSt4reTm7Dj1oxRvpP3
290 lines
11 KiB
Go
290 lines
11 KiB
Go
package api
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/google/uuid"
|
|
|
|
"github.com/vasyakrg/dns-autoresolver/internal/diff"
|
|
"github.com/vasyakrg/dns-autoresolver/internal/model"
|
|
"github.com/vasyakrg/dns-autoresolver/internal/service"
|
|
)
|
|
|
|
type mockCheckApplier struct {
|
|
lastReq service.ApplyRequest
|
|
zoneRecords []model.Record
|
|
zoneErr error
|
|
|
|
// checkCS/checkErr, when set, override Check's default fixed changeset —
|
|
// used by handleCheck status-persistence tests (drift/in_sync/error).
|
|
checkCS *diff.Changeset
|
|
checkErr error
|
|
}
|
|
|
|
func (m *mockCheckApplier) Check(context.Context, uuid.UUID, uuid.UUID) (diff.Changeset, error) {
|
|
if m.checkErr != nil {
|
|
return diff.Changeset{}, m.checkErr
|
|
}
|
|
if m.checkCS != nil {
|
|
return *m.checkCS, nil
|
|
}
|
|
d := model.Record{Type: model.A, Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}}
|
|
return diff.Changeset{Diffs: []diff.RecordDiff{{Kind: diff.Add, Type: d.Type, Name: d.Name, Desired: &d}}}, nil
|
|
}
|
|
func (m *mockCheckApplier) Apply(_ context.Context, _, _ uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) {
|
|
m.lastReq = req
|
|
return diff.Changeset{}, nil
|
|
}
|
|
func (m *mockCheckApplier) ZoneRecords(context.Context, uuid.UUID, uuid.UUID) ([]model.Record, error) {
|
|
if m.zoneErr != nil {
|
|
return nil, m.zoneErr
|
|
}
|
|
return m.zoneRecords, nil
|
|
}
|
|
|
|
// newTestAPI wires a fixed authenticated user who owns whatever project id
|
|
// is requested (via alwaysOwnedAuthStore/alwaysValidSessions in
|
|
// middleware_test.go) — these tests exercise check/apply behavior past the
|
|
// RequireAuth/RequireProjectAccess boundary, which is covered separately by
|
|
// middleware_test.go's own tests and the IDOR regression.
|
|
func newTestAPI() (*API, *mockCheckApplier) {
|
|
m := &mockCheckApplier{}
|
|
return &API{
|
|
Svc: m, Store: &mockTenantStore{},
|
|
Auth: alwaysOwnedAuthStore(), Sessions: alwaysValidSessions(uuid.New()),
|
|
}, m
|
|
}
|
|
|
|
func TestCheckEndpoint(t *testing.T) {
|
|
a, _ := newTestAPI()
|
|
router := NewRouter(a)
|
|
|
|
did := uuid.New().String()
|
|
req := requestWithSessionCookie(http.MethodGet,
|
|
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/check", nil)
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("status %d, body %s", w.Code, w.Body.String())
|
|
}
|
|
var resp changesetResponse
|
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(resp.Updates) != 1 {
|
|
t.Fatalf("expected 1 update in response, got %+v", resp)
|
|
}
|
|
}
|
|
|
|
// TestApplySendsSelectedKeys covers the per-record selection request shape:
|
|
// POST /apply with only "prunes" set must reach the service with that key in
|
|
// Prunes and an empty Updates.
|
|
func TestApplySendsSelectedKeys(t *testing.T) {
|
|
a, m := newTestAPI()
|
|
router := NewRouter(a)
|
|
|
|
did := uuid.New().String()
|
|
body := `{"prunes":["A gitlocator.com."]}`
|
|
req := requestWithSessionCookie(http.MethodPost,
|
|
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply",
|
|
strings.NewReader(body))
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("status %d body %s", w.Code, w.Body.String())
|
|
}
|
|
if len(m.lastReq.Prunes) != 1 || m.lastReq.Prunes[0] != "A gitlocator.com." || len(m.lastReq.Updates) != 0 {
|
|
t.Fatalf("apply request mismatch: %+v", m.lastReq)
|
|
}
|
|
}
|
|
|
|
func TestApplyEmptyBodyOK(t *testing.T) {
|
|
a, m := newTestAPI()
|
|
router := NewRouter(a)
|
|
|
|
did := uuid.New().String()
|
|
req := requestWithSessionCookie(http.MethodPost,
|
|
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply", nil)
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("status %d body %s", w.Code, w.Body.String())
|
|
}
|
|
if len(m.lastReq.Updates) != 0 || len(m.lastReq.Prunes) != 0 {
|
|
t.Fatalf("expected empty Updates/Prunes for empty body, got %+v", m.lastReq)
|
|
}
|
|
}
|
|
|
|
func TestApplyMalformedBody(t *testing.T) {
|
|
a, _ := newTestAPI()
|
|
router := NewRouter(a)
|
|
|
|
did := uuid.New().String()
|
|
body := `{"updates":`
|
|
req := requestWithSessionCookie(http.MethodPost,
|
|
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply",
|
|
strings.NewReader(body))
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusBadRequest {
|
|
t.Fatalf("expected 400 for malformed body, got %d body %s", w.Code, w.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestApplyBadUUID(t *testing.T) {
|
|
a, _ := newTestAPI()
|
|
router := NewRouter(a)
|
|
req := requestWithSessionCookie(http.MethodPost,
|
|
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/not-a-uuid/apply",
|
|
bytes.NewReader([]byte(`{}`)))
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
if w.Code != http.StatusBadRequest {
|
|
t.Fatalf("expected 400 for bad uuid, got %d", w.Code)
|
|
}
|
|
}
|
|
|
|
// TestCheckEndpoint_PersistsDriftStatus covers the core bug fix: a manual
|
|
// check (GET .../check) whose changeset has an actionable prune must persist
|
|
// "drift" via Store.SetDomainStatus — previously only the scheduler wrote
|
|
// last_check_status, leaving a manually-checked domain stuck at "unknown".
|
|
func TestCheckEndpoint_PersistsDriftStatus(t *testing.T) {
|
|
a, m := newTestAPI()
|
|
ts := a.Store.(*mockTenantStore)
|
|
rec := model.Record{Type: model.A, Name: "x.example.com.", Values: []string{"1.1.1.1"}}
|
|
m.checkCS = &diff.Changeset{Diffs: []diff.RecordDiff{{Kind: diff.Delete, Type: rec.Type, Name: rec.Name, Actual: &rec}}}
|
|
router := NewRouter(a)
|
|
|
|
did := uuid.New()
|
|
req := requestWithSessionCookie(http.MethodGet,
|
|
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did.String()+"/check", nil)
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("status %d body %s", w.Code, w.Body.String())
|
|
}
|
|
wantPID := uuid.MustParse("00000000-0000-0000-0000-000000000002")
|
|
if len(ts.statusCalls) != 1 || ts.statusCalls[0].domainID != did || ts.statusCalls[0].projectID != wantPID || ts.statusCalls[0].status != service.StatusDrift {
|
|
t.Fatalf("expected SetDomainStatus(%s, %s, drift), got %+v", did, wantPID, ts.statusCalls)
|
|
}
|
|
}
|
|
|
|
// TestCheckEndpoint_PersistsInSyncStatus covers the no-drift case: a
|
|
// changeset with no actionable diffs persists "in_sync".
|
|
func TestCheckEndpoint_PersistsInSyncStatus(t *testing.T) {
|
|
a, m := newTestAPI()
|
|
ts := a.Store.(*mockTenantStore)
|
|
m.checkCS = &diff.Changeset{}
|
|
router := NewRouter(a)
|
|
|
|
did := uuid.New()
|
|
req := requestWithSessionCookie(http.MethodGet,
|
|
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did.String()+"/check", nil)
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("status %d body %s", w.Code, w.Body.String())
|
|
}
|
|
wantPID := uuid.MustParse("00000000-0000-0000-0000-000000000002")
|
|
if len(ts.statusCalls) != 1 || ts.statusCalls[0].projectID != wantPID || ts.statusCalls[0].status != service.StatusInSync {
|
|
t.Fatalf("expected SetDomainStatus(_, %s, in_sync), got %+v", wantPID, ts.statusCalls)
|
|
}
|
|
}
|
|
|
|
// TestCheckEndpoint_ErrorPersistsErrorStatus covers the failure path: when
|
|
// Svc.Check itself fails (provider/loader error), the handler must persist
|
|
// "error" before returning 500 — a write failure of the status itself must
|
|
// not mask the original 500.
|
|
func TestCheckEndpoint_ErrorPersistsErrorStatus(t *testing.T) {
|
|
a, m := newTestAPI()
|
|
ts := a.Store.(*mockTenantStore)
|
|
m.checkErr = errors.New("boom: provider unreachable")
|
|
router := NewRouter(a)
|
|
|
|
did := uuid.New()
|
|
req := requestWithSessionCookie(http.MethodGet,
|
|
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did.String()+"/check", nil)
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusInternalServerError {
|
|
t.Fatalf("expected 500, got %d body %s", w.Code, w.Body.String())
|
|
}
|
|
wantPID := uuid.MustParse("00000000-0000-0000-0000-000000000002")
|
|
if len(ts.statusCalls) != 1 || ts.statusCalls[0].domainID != did || ts.statusCalls[0].projectID != wantPID || ts.statusCalls[0].status != service.StatusError {
|
|
t.Fatalf("expected SetDomainStatus(%s, %s, error), got %+v", did, wantPID, ts.statusCalls)
|
|
}
|
|
}
|
|
|
|
// TestCheckEndpoint_ErrorScopesStatusToCallerProject covers the HIGH
|
|
// IDOR-on-write fix: handleCheck's error branch must persist the failure
|
|
// status scoped to the caller's OWN project (pid from the URL/context), even
|
|
// when the domain ID in the URL belongs to (or doesn't exist in) a different
|
|
// tenant. The handler itself has no way to know whether did is foreign — the
|
|
// scoping guarantee comes from always passing pid through to
|
|
// Store.SetDomainStatus, which is enforced to be a no-op for a mismatched
|
|
// project_id at the store/SQL layer (see internal/store/schedule_test.go's
|
|
// TestSetDomainStatus_ScopedByProject_ForeignProjectIsNoOp). This test proves
|
|
// the handler holds up its side of that contract: pid, never a zero value or
|
|
// some other project, is what gets passed down.
|
|
func TestCheckEndpoint_ErrorScopesStatusToCallerProject(t *testing.T) {
|
|
a, m := newTestAPI()
|
|
ts := a.Store.(*mockTenantStore)
|
|
m.checkErr = errors.New("boom: provider unreachable")
|
|
router := NewRouter(a)
|
|
|
|
callerPID := uuid.New()
|
|
// foreignDID stands in for a domain ID the caller does not own — from the
|
|
// handler's perspective it's just whatever {did} was in the URL; only the
|
|
// store layer can (and does) enforce that it isn't actually foreignPID's.
|
|
foreignDID := uuid.New()
|
|
req := requestWithSessionCookie(http.MethodGet,
|
|
"/api/v1/projects/"+callerPID.String()+"/domains/"+foreignDID.String()+"/check", nil)
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusInternalServerError {
|
|
t.Fatalf("expected 500, got %d body %s", w.Code, w.Body.String())
|
|
}
|
|
if len(ts.statusCalls) != 1 {
|
|
t.Fatalf("expected exactly 1 SetDomainStatus call, got %+v", ts.statusCalls)
|
|
}
|
|
call := ts.statusCalls[0]
|
|
if call.projectID != callerPID {
|
|
t.Fatalf("expected SetDomainStatus scoped to caller's own pid %s, got projectID %s (never empty/foreign)", callerPID, call.projectID)
|
|
}
|
|
if call.domainID != foreignDID || call.status != service.StatusError {
|
|
t.Fatalf("expected SetDomainStatus(%s, %s, error), got %+v", foreignDID, callerPID, call)
|
|
}
|
|
}
|
|
|
|
// TestChangesetResponseEmptyMarshalsToArrays guards the белый-экран bug: an
|
|
// empty changeset (zone matches its template exactly, e.g. right after a
|
|
// snapshot) must marshal updates/prunes/readOnly as [] not null — a nil slice
|
|
// becomes JSON null and crashes the client's .length/.map calls.
|
|
func TestChangesetResponseEmptyMarshalsToArrays(t *testing.T) {
|
|
b, err := json.Marshal(toChangesetResponse(diff.Changeset{}))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s := string(b)
|
|
for _, want := range []string{`"updates":[]`, `"prunes":[]`, `"readOnly":[]`} {
|
|
if !strings.Contains(s, want) {
|
|
t.Fatalf("expected %s in %s", want, s)
|
|
}
|
|
}
|
|
}
|