27d70a987e
handleCheck's error branch wrote last_check_status via an id-only UPDATE. Because Check itself is project-scoped, an authenticated caller who paired their own valid project id with a foreign domain id in the URL would make Check 404/error out first — and that same error branch would then flip the stranger's domain to "error". Add project_id to the WHERE clause (queries/domains.sql + generated db/domains.sql.go), thread projectID through Store/TenantStore/SchedStore SetDomainStatus, and pass pid from context at both call sites in handleCheck plus the scheduler. Also collapse checkDomain's inline status derivation in scheduler.go into a call to service.DeriveStatus, the same helper handleCheck already uses, so there's a single source of truth for "drift vs in_sync" instead of two copies that could drift apart. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01BwxdSt4reTm7Dj1oxRvpP3
321 lines
9.7 KiB
Go
321 lines
9.7 KiB
Go
package store
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pgx/v5"
|
|
)
|
|
|
|
// TestUpsertSchedule_InsertThenUpdate verifies UpsertSchedule inserts a new
|
|
// row for a project on the first call and updates that same row (rather
|
|
// than inserting a second one) on a subsequent call, per the
|
|
// ON CONFLICT (project_id) DO UPDATE clause.
|
|
func TestUpsertSchedule_InsertThenUpdate(t *testing.T) {
|
|
s, ctx := newStore(t)
|
|
_, p, err := s.RegisterUser(ctx, "sched-upsert@example.com", "argon2-hash")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
created, err := s.UpsertSchedule(ctx, p.ID, 1800, true)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if created.IntervalSeconds != 1800 || !created.Enabled {
|
|
t.Fatalf("unexpected created schedule: %+v", created)
|
|
}
|
|
|
|
updated, err := s.UpsertSchedule(ctx, p.ID, 7200, false)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if updated.ID != created.ID {
|
|
t.Fatalf("expected same row id, got created=%s updated=%s", created.ID, updated.ID)
|
|
}
|
|
if updated.IntervalSeconds != 7200 || updated.Enabled {
|
|
t.Fatalf("unexpected updated schedule: %+v", updated)
|
|
}
|
|
|
|
got, err := s.GetSchedule(ctx, p.ID)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if got.IntervalSeconds != 7200 || got.Enabled {
|
|
t.Fatalf("GetSchedule mismatch after update: %+v", got)
|
|
}
|
|
}
|
|
|
|
// TestGetSchedule_NoRowReturnsErrNoRows verifies the contract used by the API
|
|
// layer (Task 5): a project with no schedule row yet returns pgx.ErrNoRows,
|
|
// which the API translates into the default {interval:3600, enabled:false}.
|
|
func TestGetSchedule_NoRowReturnsErrNoRows(t *testing.T) {
|
|
s, ctx := newStore(t)
|
|
_, p, err := s.RegisterUser(ctx, "sched-norow@example.com", "argon2-hash")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if _, err := s.GetSchedule(ctx, p.ID); !errors.Is(err, pgx.ErrNoRows) {
|
|
t.Fatalf("expected pgx.ErrNoRows, got %v", err)
|
|
}
|
|
}
|
|
|
|
// TestListDueSchedules verifies the due-selection logic: an enabled schedule
|
|
// that never ran (last_run_at IS NULL) is due; a disabled schedule is never
|
|
// due; and an enabled schedule that ran recently with a long interval is not
|
|
// yet due.
|
|
func TestListDueSchedules(t *testing.T) {
|
|
s, ctx := newStore(t)
|
|
now := time.Now().UTC()
|
|
|
|
_, neverRunProject, err := s.RegisterUser(ctx, "sched-neverrun@example.com", "argon2-hash")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if _, err := s.UpsertSchedule(ctx, neverRunProject.ID, 3600, true); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
_, disabledProject, err := s.RegisterUser(ctx, "sched-disabled@example.com", "argon2-hash")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if _, err := s.UpsertSchedule(ctx, disabledProject.ID, 60, false); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
_, recentProject, err := s.RegisterUser(ctx, "sched-recent@example.com", "argon2-hash")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if _, err := s.UpsertSchedule(ctx, recentProject.ID, 3600, true); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := s.TouchScheduleRun(ctx, recentProject.ID, now); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
due, err := s.ListDueSchedules(ctx, now)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
byProject := make(map[uuid.UUID]bool, len(due))
|
|
for _, d := range due {
|
|
byProject[d.ProjectID] = true
|
|
}
|
|
|
|
if !byProject[neverRunProject.ID] {
|
|
t.Errorf("expected enabled/never-run schedule for project %s to be due", neverRunProject.ID)
|
|
}
|
|
if byProject[disabledProject.ID] {
|
|
t.Errorf("did not expect disabled schedule for project %s to be due", disabledProject.ID)
|
|
}
|
|
if byProject[recentProject.ID] {
|
|
t.Errorf("did not expect recently-run schedule (long interval) for project %s to be due", recentProject.ID)
|
|
}
|
|
}
|
|
|
|
// TestTouchScheduleRun_SetsLastRunAt verifies TouchScheduleRun persists
|
|
// last_run_at, which GetSchedule then returns as a non-nil *time.Time close
|
|
// to the value passed in.
|
|
func TestTouchScheduleRun_SetsLastRunAt(t *testing.T) {
|
|
s, ctx := newStore(t)
|
|
_, p, err := s.RegisterUser(ctx, "sched-touch@example.com", "argon2-hash")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if _, err := s.UpsertSchedule(ctx, p.ID, 3600, true); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
at := time.Now().UTC().Truncate(time.Second)
|
|
if err := s.TouchScheduleRun(ctx, p.ID, at); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
got, err := s.GetSchedule(ctx, p.ID)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if got.LastRunAt == nil {
|
|
t.Fatal("expected non-nil LastRunAt after TouchScheduleRun")
|
|
}
|
|
if diff := got.LastRunAt.Sub(at); diff < -time.Second || diff > time.Second {
|
|
t.Fatalf("expected LastRunAt ~%v, got %v", at, *got.LastRunAt)
|
|
}
|
|
}
|
|
|
|
// TestChannelCRUD_ScopedByProject verifies CreateChannel/ListChannels/
|
|
// GetChannel/DeleteChannel round-trip correctly and that GetChannel scopes
|
|
// by project_id: looking up a channel with the wrong project ID must fail
|
|
// with pgx.ErrNoRows rather than returning another tenant's channel.
|
|
func TestChannelCRUD_ScopedByProject(t *testing.T) {
|
|
s, ctx := newStore(t)
|
|
_, p1, err := s.RegisterUser(ctx, "chan-owner@example.com", "argon2-hash")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
_, p2, err := s.RegisterUser(ctx, "chan-other@example.com", "argon2-hash")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
cfg := json.RawMessage(`{"webhook_url":"https://example.com/hook"}`)
|
|
ch, err := s.CreateChannel(ctx, p1.ID, "telegram", cfg, "enc-secret")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
// jsonb round-trips through Postgres with its own canonical formatting
|
|
// (e.g. a space after ':'), so compare decoded values rather than raw
|
|
// bytes.
|
|
var gotCfg, wantCfg map[string]string
|
|
if err := json.Unmarshal(ch.Config, &gotCfg); err != nil {
|
|
t.Fatalf("unmarshal returned config: %v", err)
|
|
}
|
|
if err := json.Unmarshal(cfg, &wantCfg); err != nil {
|
|
t.Fatalf("unmarshal expected config: %v", err)
|
|
}
|
|
if ch.Type != "telegram" || !ch.Enabled || gotCfg["webhook_url"] != wantCfg["webhook_url"] || ch.SecretEnc != "enc-secret" {
|
|
t.Fatalf("unexpected created channel: %+v", ch)
|
|
}
|
|
|
|
list, err := s.ListChannels(ctx, p1.ID)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(list) != 1 || list[0].ID != ch.ID {
|
|
t.Fatalf("unexpected ListChannels result: %+v", list)
|
|
}
|
|
|
|
enabledList, err := s.ListEnabledChannels(ctx, p1.ID)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(enabledList) != 1 || enabledList[0].ID != ch.ID {
|
|
t.Fatalf("unexpected ListEnabledChannels result: %+v", enabledList)
|
|
}
|
|
|
|
got, err := s.GetChannel(ctx, ch.ID, p1.ID)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if got.ID != ch.ID {
|
|
t.Fatalf("GetChannel mismatch: %+v", got)
|
|
}
|
|
|
|
if _, err := s.GetChannel(ctx, ch.ID, p2.ID); !errors.Is(err, pgx.ErrNoRows) {
|
|
t.Fatalf("expected pgx.ErrNoRows for foreign project, got %v", err)
|
|
}
|
|
|
|
if err := s.DeleteChannel(ctx, ch.ID, p1.ID); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if _, err := s.GetChannel(ctx, ch.ID, p1.ID); !errors.Is(err, pgx.ErrNoRows) {
|
|
t.Fatalf("expected pgx.ErrNoRows after delete, got %v", err)
|
|
}
|
|
}
|
|
|
|
// TestDomainStatus_RoundTrip verifies SetDomainStatus/GetDomainStatus
|
|
// round-trip, and that a freshly-imported domain defaults to "unknown" per
|
|
// the migration's DEFAULT 'unknown'.
|
|
func TestDomainStatus_RoundTrip(t *testing.T) {
|
|
s, ctx := newStore(t)
|
|
_, p, err := s.RegisterUser(ctx, "domain-status@example.com", "argon2-hash")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
acc, err := s.CreateAccount(ctx, p.ID, "selectel", "enc-blob", "test")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
d, err := s.CreateDomain(ctx, p.ID, acc.ID, "example.com", "zone-1", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
status, err := s.GetDomainStatus(ctx, d.ID)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if status != "unknown" {
|
|
t.Fatalf("expected default status 'unknown', got %q", status)
|
|
}
|
|
|
|
if err := s.SetDomainStatus(ctx, d.ID, p.ID, "ok"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
status, err = s.GetDomainStatus(ctx, d.ID)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if status != "ok" {
|
|
t.Fatalf("expected status 'ok' after SetDomainStatus, got %q", status)
|
|
}
|
|
|
|
domains, err := s.ListDomains(ctx, p.ID)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(domains) != 1 || domains[0].LastCheckStatus != "ok" {
|
|
t.Fatalf("expected ListDomains to reflect updated status: %+v", domains)
|
|
}
|
|
}
|
|
|
|
// TestSetDomainStatus_ScopedByProject_ForeignProjectIsNoOp covers the IDOR
|
|
// fix: SetDomainStatus is called with a valid domain ID but a projectID that
|
|
// does NOT own it (e.g. an authenticated caller's own pid, paired with
|
|
// another tenant's did in the URL). The WHERE id = $1 AND project_id = $3
|
|
// clause must match zero rows — no error, but the foreign domain's status
|
|
// must remain untouched, never "error"/"drift"/whatever was passed in.
|
|
func TestSetDomainStatus_ScopedByProject_ForeignProjectIsNoOp(t *testing.T) {
|
|
s, ctx := newStore(t)
|
|
_, owner, err := s.RegisterUser(ctx, "domain-status-owner@example.com", "argon2-hash")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
_, attacker, err := s.RegisterUser(ctx, "domain-status-attacker@example.com", "argon2-hash")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
acc, err := s.CreateAccount(ctx, owner.ID, "selectel", "enc-blob", "test")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
d, err := s.CreateDomain(ctx, owner.ID, acc.ID, "example.com", "zone-1", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// attacker's own (valid) project ID, paired with owner's domain ID —
|
|
// mirrors the exact request shape an authenticated attacker could send.
|
|
if err := s.SetDomainStatus(ctx, d.ID, attacker.ID, "error"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
status, err := s.GetDomainStatus(ctx, d.ID)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if status != "unknown" {
|
|
t.Fatalf("expected foreign-project SetDomainStatus to be a no-op, but status changed to %q", status)
|
|
}
|
|
|
|
// The legitimate owner can still update it — proves the no-op above was
|
|
// due to project scoping, not some unrelated write failure.
|
|
if err := s.SetDomainStatus(ctx, d.ID, owner.ID, "error"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
status, err = s.GetDomainStatus(ctx, d.ID)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if status != "error" {
|
|
t.Fatalf("expected owner's SetDomainStatus to apply, got %q", status)
|
|
}
|
|
}
|