package api

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/google/uuid"
	"github.com/vasyakrg/dns-autoresolver/internal/diff"
	"github.com/vasyakrg/dns-autoresolver/internal/model"
	"github.com/vasyakrg/dns-autoresolver/internal/service"
)

// mockCheckApplier is a canned stand-in for the service layer used by the
// check/apply handler tests.
type mockCheckApplier struct {
	// lastReq is the ApplyRequest passed to the most recent Apply call.
	lastReq service.ApplyRequest
	// applyCalls counts Apply invocations. lastReq alone cannot distinguish
	// "Apply ran with an empty request" from "Apply was never reached" — both
	// leave it at its zero value.
	applyCalls int
	// zoneRecords/zoneErr are what ZoneRecords returns.
	zoneRecords []model.Record
	zoneErr     error
	// checkCS/checkErr, when set, override Check's default fixed changeset —
	// used by handleCheck status-persistence tests (drift/in_sync/error).
	checkCS  *diff.Changeset
	checkErr error
}

// Check returns checkErr if set, otherwise *checkCS if set, otherwise a fixed
// changeset holding a single Add of an A record (a.example.com. -> 1.1.1.1).
func (m *mockCheckApplier) Check(context.Context, uuid.UUID, uuid.UUID) (diff.Changeset, error) {
	if m.checkErr != nil {
		return diff.Changeset{}, m.checkErr
	}
	if m.checkCS != nil {
		return *m.checkCS, nil
	}
	d := model.Record{Type: model.A, Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}}
	return diff.Changeset{Diffs: []diff.RecordDiff{{Kind: diff.Add, Type: d.Type, Name: d.Name, Desired: &d}}}, nil
}

// Apply counts the call, records req in lastReq, and returns an empty
// changeset with no error.
func (m *mockCheckApplier) Apply(_ context.Context, _, _ uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) {
	m.applyCalls++
	m.lastReq = req
	return diff.Changeset{}, nil
}

// ZoneRecords returns zoneErr if set, otherwise zoneRecords.
func (m *mockCheckApplier) ZoneRecords(context.Context, uuid.UUID, uuid.UUID) ([]model.Record, error) {
	if m.zoneErr != nil {
		return nil, m.zoneErr
	}
	return m.zoneRecords, nil
}

// newTestAPI wires a fixed authenticated user who owns whatever project id
// is requested (via alwaysOwnedAuthStore/alwaysValidSessions in
// middleware_test.go) — these tests exercise check/apply behavior past the
// RequireAuth/RequireProjectAccess boundary, which is covered separately by
// middleware_test.go's own tests and the IDOR regression.
func newTestAPI() (*API, *mockCheckApplier) {
	m := &mockCheckApplier{}
	return &API{
		Svc:      m,
		Store:    &mockTenantStore{},
		Auth:     alwaysOwnedAuthStore(),
		Sessions: alwaysValidSessions(uuid.New()),
	}, m
}

// TestCheckEndpoint verifies GET .../check returns 200 and serializes the
// mock's default single-Add changeset as exactly one update.
func TestCheckEndpoint(t *testing.T) {
	a, _ := newTestAPI()
	router := NewRouter(a)
	did := uuid.New().String()
	req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/check", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("status %d, body %s", w.Code, w.Body.String())
	}
	var resp changesetResponse
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatal(err)
	}
	if len(resp.Updates) != 1 {
		t.Fatalf("expected 1 update in response, got %+v", resp)
	}
}

// TestApplySendsSelectedKeys covers the per-record selection request shape:
// POST /apply with only "prunes" set must reach the service exactly once with
// that key in Prunes and an empty Updates.
func TestApplySendsSelectedKeys(t *testing.T) {
	a, m := newTestAPI()
	router := NewRouter(a)
	did := uuid.New().String()
	body := `{"prunes":["A gitlocator.com."]}`
	req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply", strings.NewReader(body))
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("status %d body %s", w.Code, w.Body.String())
	}
	if m.applyCalls != 1 {
		t.Fatalf("expected exactly 1 Apply call, got %d", m.applyCalls)
	}
	if len(m.lastReq.Prunes) != 1 || m.lastReq.Prunes[0] != "A gitlocator.com." || len(m.lastReq.Updates) != 0 {
		t.Fatalf("apply request mismatch: %+v", m.lastReq)
	}
}

// TestApplyEmptyBodyOK verifies a POST /apply with no body is accepted and
// reaches the service (exactly once) with empty Updates/Prunes.
func TestApplyEmptyBodyOK(t *testing.T) {
	a, m := newTestAPI()
	router := NewRouter(a)
	did := uuid.New().String()
	req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("status %d body %s", w.Code, w.Body.String())
	}
	// Without this, the emptiness check below would also pass if the handler
	// returned 200 without ever calling Apply (lastReq stays zero-valued).
	if m.applyCalls != 1 {
		t.Fatalf("expected exactly 1 Apply call for empty body, got %d", m.applyCalls)
	}
	if len(m.lastReq.Updates) != 0 || len(m.lastReq.Prunes) != 0 {
		t.Fatalf("expected empty Updates/Prunes for empty body, got %+v", m.lastReq)
	}
}

// TestApplyMalformedBody verifies truncated JSON is rejected with 400 before
// the service is reached.
func TestApplyMalformedBody(t *testing.T) {
	a, m := newTestAPI()
	router := NewRouter(a)
	did := uuid.New().String()
	body := `{"updates":`
	req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply", strings.NewReader(body))
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusBadRequest {
		t.Fatalf("expected 400 for malformed body, got %d body %s", w.Code, w.Body.String())
	}
	if m.applyCalls != 0 {
		t.Fatalf("Apply must not be called for a malformed body, got %d call(s) with %+v", m.applyCalls, m.lastReq)
	}
}

// TestApplyBadUUID verifies a non-UUID domain id in the path is rejected with
// 400 before the service is reached.
func TestApplyBadUUID(t *testing.T) {
	a, m := newTestAPI()
	router := NewRouter(a)
	req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/not-a-uuid/apply", bytes.NewReader([]byte(`{}`)))
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusBadRequest {
		t.Fatalf("expected 400 for bad uuid, got %d", w.Code)
	}
	if m.applyCalls != 0 {
		t.Fatalf("Apply must not be called for a bad uuid, got %d call(s)", m.applyCalls)
	}
}

// TestCheckEndpoint_PersistsDriftStatus covers the core bug fix: a manual
// check (GET .../check) whose changeset has an actionable prune must persist
// "drift" via Store.SetDomainStatus — previously only the scheduler wrote
// last_check_status, leaving a manually-checked domain stuck at "unknown".
func TestCheckEndpoint_PersistsDriftStatus(t *testing.T) { a, m := newTestAPI() ts := a.Store.(*mockTenantStore) rec := model.Record{Type: model.A, Name: "x.example.com.", Values: []string{"1.1.1.1"}} m.checkCS = &diff.Changeset{Diffs: []diff.RecordDiff{{Kind: diff.Delete, Type: rec.Type, Name: rec.Name, Actual: &rec}}} router := NewRouter(a) did := uuid.New() req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did.String()+"/check", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("status %d body %s", w.Code, w.Body.String()) } wantPID := uuid.MustParse("00000000-0000-0000-0000-000000000002") if len(ts.statusCalls) != 1 || ts.statusCalls[0].domainID != did || ts.statusCalls[0].projectID != wantPID || ts.statusCalls[0].status != service.StatusDrift { t.Fatalf("expected SetDomainStatus(%s, %s, drift), got %+v", did, wantPID, ts.statusCalls) } } // TestCheckEndpoint_PersistsInSyncStatus covers the no-drift case: a // changeset with no actionable diffs persists "in_sync". 
func TestCheckEndpoint_PersistsInSyncStatus(t *testing.T) { a, m := newTestAPI() ts := a.Store.(*mockTenantStore) m.checkCS = &diff.Changeset{} router := NewRouter(a) did := uuid.New() req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did.String()+"/check", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("status %d body %s", w.Code, w.Body.String()) } wantPID := uuid.MustParse("00000000-0000-0000-0000-000000000002") if len(ts.statusCalls) != 1 || ts.statusCalls[0].projectID != wantPID || ts.statusCalls[0].status != service.StatusInSync { t.Fatalf("expected SetDomainStatus(_, %s, in_sync), got %+v", wantPID, ts.statusCalls) } } // TestCheckEndpoint_ErrorPersistsErrorStatus covers the failure path: when // Svc.Check itself fails (provider/loader error), the handler must persist // "error" before returning 500 — a write failure of the status itself must // not mask the original 500. 
func TestCheckEndpoint_ErrorPersistsErrorStatus(t *testing.T) { a, m := newTestAPI() ts := a.Store.(*mockTenantStore) m.checkErr = errors.New("boom: provider unreachable") router := NewRouter(a) did := uuid.New() req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did.String()+"/check", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusInternalServerError { t.Fatalf("expected 500, got %d body %s", w.Code, w.Body.String()) } wantPID := uuid.MustParse("00000000-0000-0000-0000-000000000002") if len(ts.statusCalls) != 1 || ts.statusCalls[0].domainID != did || ts.statusCalls[0].projectID != wantPID || ts.statusCalls[0].status != service.StatusError { t.Fatalf("expected SetDomainStatus(%s, %s, error), got %+v", did, wantPID, ts.statusCalls) } } // TestCheckEndpoint_ErrorScopesStatusToCallerProject covers the HIGH // IDOR-on-write fix: handleCheck's error branch must persist the failure // status scoped to the caller's OWN project (pid from the URL/context), even // when the domain ID in the URL belongs to (or doesn't exist in) a different // tenant. The handler itself has no way to know whether did is foreign — the // scoping guarantee comes from always passing pid through to // Store.SetDomainStatus, which is enforced to be a no-op for a mismatched // project_id at the store/SQL layer (see internal/store/schedule_test.go's // TestSetDomainStatus_ScopedByProject_ForeignProjectIsNoOp). This test proves // the handler holds up its side of that contract: pid, never a zero value or // some other project, is what gets passed down. 
func TestCheckEndpoint_ErrorScopesStatusToCallerProject(t *testing.T) { a, m := newTestAPI() ts := a.Store.(*mockTenantStore) m.checkErr = errors.New("boom: provider unreachable") router := NewRouter(a) callerPID := uuid.New() // foreignDID stands in for a domain ID the caller does not own — from the // handler's perspective it's just whatever {did} was in the URL; only the // store layer can (and does) enforce that it isn't actually foreignPID's. foreignDID := uuid.New() req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+callerPID.String()+"/domains/"+foreignDID.String()+"/check", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusInternalServerError { t.Fatalf("expected 500, got %d body %s", w.Code, w.Body.String()) } if len(ts.statusCalls) != 1 { t.Fatalf("expected exactly 1 SetDomainStatus call, got %+v", ts.statusCalls) } call := ts.statusCalls[0] if call.projectID != callerPID { t.Fatalf("expected SetDomainStatus scoped to caller's own pid %s, got projectID %s (never empty/foreign)", callerPID, call.projectID) } if call.domainID != foreignDID || call.status != service.StatusError { t.Fatalf("expected SetDomainStatus(%s, %s, error), got %+v", foreignDID, callerPID, call) } } // TestChangesetResponseEmptyMarshalsToArrays guards the белый-экран bug: an // empty changeset (zone matches its template exactly, e.g. right after a // snapshot) must marshal updates/prunes/readOnly as [] not null — a nil slice // becomes JSON null and crashes the client's .length/.map calls. func TestChangesetResponseEmptyMarshalsToArrays(t *testing.T) { b, err := json.Marshal(toChangesetResponse(diff.Changeset{})) if err != nil { t.Fatal(err) } s := string(b) for _, want := range []string{`"updates":[]`, `"prunes":[]`, `"readOnly":[]`} { if !strings.Contains(s, want) { t.Fatalf("expected %s in %s", want, s) } } }