package api

import (
	"context"
	"encoding/json"
	"errors"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/google/uuid"
	"github.com/jackc/pgx/v5"
	"github.com/vasyakrg/dns-autoresolver/internal/store"
)

// --- mock ScheduleStore ---

// mockScheduleStore is an in-memory test double for the API's Schedule
// dependency. Mutating calls record their arguments so tests can assert on
// them; results and errors are canned via the fields below. The lookup
// methods ignore their second ID argument (presumably the project scope), so
// an ID that was never seeded stands in for "belongs to another project".
type mockScheduleStore struct {
	// GetSchedule returns schedule, or scheduleErr when it is non-nil.
	schedule    store.Schedule
	scheduleErr error

	// UpsertSchedule records its inputs here and returns upsertErr,
	// upsertResult (when it has a non-nil ID), or a freshly built row.
	upsertCalled   bool
	upsertInterval int32
	upsertEnabled  bool
	upsertResult   store.Schedule
	upsertErr      error

	// channels is keyed by channel ID. createChannelIn holds one entry per
	// CreateChannel call, in call order, even when createChannelErr is set.
	channels        map[uuid.UUID]store.Channel
	createChannelIn []struct {
		ctype     string
		config    json.RawMessage
		secretEnc string
	}
	createChannelErr error

	// DeleteChannel records the requested ID before consulting deleteChannelErr.
	deleteChannelCalled bool
	deletedChannelID    uuid.UUID
	deleteChannelErr    error

	// domains and checkRuns are both keyed by domain ID.
	domains     map[uuid.UUID]store.Domain
	checkRuns   map[uuid.UUID][]store.CheckRun
	listRunsErr error
}

// newMockScheduleStore builds a mockScheduleStore whose maps are already
// allocated, so tests can seed channels/domains/checkRuns by direct assignment.
func newMockScheduleStore() *mockScheduleStore {
	m := new(mockScheduleStore)
	m.channels = make(map[uuid.UUID]store.Channel)
	m.domains = make(map[uuid.UUID]store.Domain)
	m.checkRuns = make(map[uuid.UUID][]store.CheckRun)
	return m
}

// GetSchedule yields the canned schedule unless scheduleErr is configured.
func (m *mockScheduleStore) GetSchedule(context.Context, uuid.UUID) (store.Schedule, error) {
	if err := m.scheduleErr; err != nil {
		return store.Schedule{}, err
	}
	return m.schedule, nil
}

// UpsertSchedule records the requested interval/enabled flag, then returns
// (in priority order) upsertErr, upsertResult, or a new row echoing the inputs.
func (m *mockScheduleStore) UpsertSchedule(_ context.Context, projectID uuid.UUID, interval int32, enabled bool) (store.Schedule, error) {
	m.upsertCalled, m.upsertInterval, m.upsertEnabled = true, interval, enabled

	switch {
	case m.upsertErr != nil:
		return store.Schedule{}, m.upsertErr
	case m.upsertResult.ID != uuid.Nil:
		return m.upsertResult, nil
	default:
		return store.Schedule{
			ID:              uuid.New(),
			ProjectID:       projectID,
			IntervalSeconds: interval,
			Enabled:         enabled,
		}, nil
	}
}

// CreateChannel records the call (including the already-encrypted secret),
// then either fails with createChannelErr or stores an enabled channel under
// a new ID and returns it.
func (m *mockScheduleStore) CreateChannel(_ context.Context, projectID uuid.UUID, ctype string, config json.RawMessage, secretEnc string) (store.Channel, error) {
	call := struct {
		ctype     string
		config    json.RawMessage
		secretEnc string
	}{ctype: ctype, config: config, secretEnc: secretEnc}
	m.createChannelIn = append(m.createChannelIn, call)

	if err := m.createChannelErr; err != nil {
		return store.Channel{}, err
	}

	created := store.Channel{
		ID:        uuid.New(),
		ProjectID: projectID,
		Type:      ctype,
		Config:    config,
		SecretEnc: secretEnc,
		Enabled:   true,
	}
	m.channels[created.ID] = created
	return created, nil
}

// ListChannels returns every stored channel in map-iteration (unspecified)
// order. The result is never nil, even when no channels exist.
func (m *mockScheduleStore) ListChannels(context.Context, uuid.UUID) ([]store.Channel, error) {
	all := make([]store.Channel, len(m.channels))
	i := 0
	for _, ch := range m.channels {
		all[i] = ch
		i++
	}
	return all, nil
}

// GetChannel looks a channel up by ID; unknown IDs produce an error.
func (m *mockScheduleStore) GetChannel(_ context.Context, id, _ uuid.UUID) (store.Channel, error) {
	if ch, found := m.channels[id]; found {
		return ch, nil
	}
	return store.Channel{}, errors.New("channel not found")
}

// DeleteChannel records the requested ID, then either returns
// deleteChannelErr or removes the channel (a no-op for unknown IDs).
func (m *mockScheduleStore) DeleteChannel(_ context.Context, id, _ uuid.UUID) error {
	m.deleteChannelCalled, m.deletedChannelID = true, id

	if err := m.deleteChannelErr; err != nil {
		return err
	}
	delete(m.channels, id)
	return nil
}

// GetDomain looks a domain up by ID; unknown IDs produce an error.
func (m *mockScheduleStore) GetDomain(_ context.Context, id, _ uuid.UUID) (store.Domain, error) {
	if d, found := m.domains[id]; found {
		return d, nil
	}
	return store.Domain{}, errors.New("domain not found")
}

// ListCheckRuns returns the seeded runs for domainID (nil when none were
// seeded), unless listRunsErr is configured.
func (m *mockScheduleStore) ListCheckRuns(_ context.Context, domainID uuid.UUID) ([]store.CheckRun, error) {
	if err := m.listRunsErr; err != nil {
		return nil, err
	}
	return m.checkRuns[domainID], nil
}

// --- mock TestSender ---

// mockTestSender captures the arguments of the most recent SendTest call and
// returns the configured err.
type mockTestSender struct {
	err          error
	calledType   string
	calledConfig json.RawMessage
	calledSecret string
	called       bool
}

// SendTest records the channel type, config and (decrypted) secret it was
// handed, then returns m.err.
func (m *mockTestSender) SendTest(_ context.Context, channelType string, config json.RawMessage, secret string) error {
	m.called = true
	m.calledType, m.calledConfig, m.calledSecret = channelType, config, secret
	return m.err
}

// newScheduleTestAPI wires a fixed authenticated user who owns whatever
// project id is requested (alwaysOwnedAuthStore/alwaysValidSessions, see
// middleware_test.go) — these tests exercise schedule/channels/history
// behavior past the RequireAuth/RequireProjectAccess boundary.
func newScheduleTestAPI() (*API, *mockScheduleStore, *mockTestSender) { ms := newMockScheduleStore() mts := &mockTestSender{} a := &API{ Schedule: ms, Dispatch: mts, Cipher: mockCipher{}, Auth: alwaysOwnedAuthStore(), Sessions: alwaysValidSessions(uuid.New()), } return a, ms, mts } // --- schedule --- func TestGetSchedule_DefaultWhenNoRow(t *testing.T) { a, ms, _ := newScheduleTestAPI() ms.scheduleErr = pgx.ErrNoRows router := NewRouter(a) req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+testPID+"/schedule", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("status %d body %s", w.Code, w.Body.String()) } var resp scheduleResponse if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { t.Fatal(err) } if resp.IntervalSeconds != 3600 || resp.Enabled != false { t.Fatalf("expected default {3600,false}, got %+v", resp) } } func TestGetSchedule_Existing(t *testing.T) { a, ms, _ := newScheduleTestAPI() ms.schedule = store.Schedule{ID: uuid.New(), IntervalSeconds: 120, Enabled: true} router := NewRouter(a) req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+testPID+"/schedule", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("status %d body %s", w.Code, w.Body.String()) } var resp scheduleResponse if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { t.Fatal(err) } if resp.IntervalSeconds != 120 || !resp.Enabled { t.Fatalf("expected {120,true}, got %+v", resp) } } func TestPutSchedule_RejectsIntervalBelow60(t *testing.T) { a, ms, _ := newScheduleTestAPI() router := NewRouter(a) body := `{"intervalSeconds":59,"enabled":true}` req := requestWithSessionCookie(http.MethodPut, "/api/v1/projects/"+testPID+"/schedule", strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusBadRequest { t.Fatalf("expected 400 for interval<60, got %d body %s", w.Code, w.Body.String()) } if ms.upsertCalled { 
t.Fatal("UpsertSchedule must not be called when validation fails") } } func TestPutSchedule_Success(t *testing.T) { a, ms, _ := newScheduleTestAPI() router := NewRouter(a) body := `{"intervalSeconds":300,"enabled":true}` req := requestWithSessionCookie(http.MethodPut, "/api/v1/projects/"+testPID+"/schedule", strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("status %d body %s", w.Code, w.Body.String()) } if !ms.upsertCalled || ms.upsertInterval != 300 || !ms.upsertEnabled { t.Fatalf("expected UpsertSchedule(300,true), got called=%v interval=%d enabled=%v", ms.upsertCalled, ms.upsertInterval, ms.upsertEnabled) } } // --- channels --- func TestCreateChannel_EncryptsSecretAndOmitsFromResponse(t *testing.T) { a, ms, _ := newScheduleTestAPI() router := NewRouter(a) body := `{"type":"telegram","config":{"chat_id":"123"},"secret":"super-bot-token"}` req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/channels", strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusCreated { t.Fatalf("status %d body %s", w.Code, w.Body.String()) } if strings.Contains(w.Body.String(), "super-bot-token") { t.Fatalf("response leaks plaintext secret: %s", w.Body.String()) } if len(ms.createChannelIn) != 1 { t.Fatalf("expected 1 CreateChannel call, got %d", len(ms.createChannelIn)) } got := ms.createChannelIn[0] if got.secretEnc == "" || got.secretEnc == "super-bot-token" || !strings.Contains(got.secretEnc, "super-bot-token") { // mockCipher.Encrypt wraps as ENC(...) — assert it's the *encrypted* form, not the raw plaintext passed unchanged. 
t.Fatalf("expected secret to be passed through cipher.Encrypt, got secretEnc=%q", got.secretEnc) } if got.secretEnc == "super-bot-token" { t.Fatalf("secret was stored unencrypted") } var resp channelResponse if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { t.Fatal(err) } if resp.Type != "telegram" || resp.ID == "" { t.Fatalf("unexpected response: %+v", resp) } } func TestListChannels_NoSecrets(t *testing.T) { a, ms, _ := newScheduleTestAPI() ms.channels[uuid.New()] = store.Channel{ID: uuid.New(), Type: "webhook", Config: json.RawMessage(`{"url":"https://example.com"}`), SecretEnc: "ENC(top-secret)", Enabled: true} router := NewRouter(a) req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+testPID+"/channels", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("status %d body %s", w.Code, w.Body.String()) } if strings.Contains(w.Body.String(), "top-secret") || strings.Contains(w.Body.String(), "secretEnc") { t.Fatalf("channel list leaks secret: %s", w.Body.String()) } } func TestDeleteChannel(t *testing.T) { a, ms, _ := newScheduleTestAPI() cid := uuid.New() ms.channels[cid] = store.Channel{ID: cid, Type: "webhook"} router := NewRouter(a) req := requestWithSessionCookie(http.MethodDelete, "/api/v1/projects/"+testPID+"/channels/"+cid.String(), nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusNoContent { t.Fatalf("status %d body %s", w.Code, w.Body.String()) } if !ms.deleteChannelCalled || ms.deletedChannelID != cid { t.Fatalf("expected DeleteChannel(%s), called=%v got=%s", cid, ms.deleteChannelCalled, ms.deletedChannelID) } } func TestDeleteChannel_InvalidUUID(t *testing.T) { a, _, _ := newScheduleTestAPI() router := NewRouter(a) req := requestWithSessionCookie(http.MethodDelete, "/api/v1/projects/"+testPID+"/channels/not-a-uuid", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusBadRequest { t.Fatalf("expected 400 for bad channel 
uuid, got %d", w.Code) } } func TestTestChannel_Success(t *testing.T) { a, ms, mts := newScheduleTestAPI() cid := uuid.New() ms.channels[cid] = store.Channel{ID: cid, Type: "telegram", Config: json.RawMessage(`{"chat_id":"1"}`), SecretEnc: "ENC(bot-token)"} router := NewRouter(a) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/channels/"+cid.String()+"/test", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("status %d body %s", w.Code, w.Body.String()) } if !mts.called || mts.calledType != "telegram" || mts.calledSecret != "bot-token" { t.Fatalf("expected SendTest(telegram,...,bot-token), got called=%v type=%s secret=%s", mts.called, mts.calledType, mts.calledSecret) } } func TestTestChannel_SenderError_Returns502WithoutSecret(t *testing.T) { a, ms, mts := newScheduleTestAPI() cid := uuid.New() ms.channels[cid] = store.Channel{ID: cid, Type: "telegram", Config: json.RawMessage(`{"chat_id":"1"}`), SecretEnc: "ENC(bot-token)"} mts.err = errors.New("telegram: status 401 Unauthorized (token=bot-token)") router := NewRouter(a) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/channels/"+cid.String()+"/test", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusBadGateway { t.Fatalf("expected 502 on channel test failure, got %d body %s", w.Code, w.Body.String()) } if strings.Contains(w.Body.String(), "bot-token") { t.Fatalf("error response leaks secret: %s", w.Body.String()) } } func TestTestChannel_UnknownChannel_Returns404(t *testing.T) { a, _, _ := newScheduleTestAPI() router := NewRouter(a) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/channels/"+uuid.New().String()+"/test", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusNotFound { t.Fatalf("expected 404 for unknown channel, got %d", w.Code) } } // --- history --- func TestDomainHistory_List(t *testing.T) { a, ms, _ 
:= newScheduleTestAPI() did := uuid.New() ms.domains[did] = store.Domain{ID: did} ms.checkRuns[did] = []store.CheckRun{ {ID: uuid.New(), DomainID: did, Result: json.RawMessage(`{"updates":1,"prunes":0}`)}, {ID: uuid.New(), DomainID: did, Result: json.RawMessage(`{"updates":0,"prunes":0}`)}, } router := NewRouter(a) req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+testPID+"/domains/"+did.String()+"/history", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("status %d body %s", w.Code, w.Body.String()) } var resp []checkRunResponse if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { t.Fatal(err) } if len(resp) != 2 { t.Fatalf("expected 2 history entries, got %d", len(resp)) } } func TestDomainHistory_InvalidUUID(t *testing.T) { a, _, _ := newScheduleTestAPI() router := NewRouter(a) req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+testPID+"/domains/not-a-uuid/history", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusBadRequest { t.Fatalf("expected 400 for bad domain uuid, got %d", w.Code) } } // TestDomainHistory_ForeignDomain_Returns404 is the IDOR guard for history: // check_runs.domain_id has no project scoping of its own, so the handler // must verify domain ownership via GetDomain before calling ListCheckRuns — // a domain id the mock doesn't know about (i.e. not in this project) must // 404 rather than fall through to an unscoped history lookup. func TestDomainHistory_ForeignDomain_Returns404(t *testing.T) { a, _, _ := newScheduleTestAPI() router := NewRouter(a) did := uuid.New() // never registered in ms.domains req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+testPID+"/domains/"+did.String()+"/history", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusNotFound { t.Fatalf("expected 404 for a domain not owned by this project, got %d body %s", w.Code, w.Body.String()) } }