diff --git a/internal/api/api.go b/internal/api/api.go index 249d72b..c3cff8f 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -40,6 +40,7 @@ type TenantStore interface { ListDomains(ctx context.Context, projectID uuid.UUID) ([]store.Domain, error) DeleteDomain(ctx context.Context, id, projectID uuid.UUID) error ImportDomains(ctx context.Context, projectID, accountID uuid.UUID, zones []provider.Zone) ([]store.Domain, error) + SetDomainTemplate(ctx context.Context, domainID, projectID uuid.UUID, templateID *uuid.UUID) (store.Domain, error) } // Cipher encrypts/decrypts provider account secrets. *crypto.Cipher satisfies it. @@ -73,6 +74,7 @@ func NewRouter(a *API) http.Handler { r.Route("/{did}", func(r chi.Router) { r.Get("/check", a.handleCheck) r.Post("/apply", a.handleApply) + r.Patch("/", a.handleSetDomainTemplate) r.Delete("/", a.handleDeleteDomain) }) }) diff --git a/internal/api/tenant_dto.go b/internal/api/tenant_dto.go index 58baddb..b3dd1a2 100644 --- a/internal/api/tenant_dto.go +++ b/internal/api/tenant_dto.go @@ -47,6 +47,12 @@ type domainRequest struct { TemplateID *string `json:"templateId,omitempty"` } +// updateDomainTemplateRequest is the PATCH .../domains/{did} body used to +// bind (or clear, when templateId is null/omitted) a domain's DNS template. +type updateDomainTemplateRequest struct { + TemplateID *string `json:"templateId"` +} + type domainResponse struct { ID string `json:"id"` ProviderAccountID string `json:"providerAccountId"` diff --git a/internal/api/tenant_handlers.go b/internal/api/tenant_handlers.go index e952d79..02f3dcc 100644 --- a/internal/api/tenant_handlers.go +++ b/internal/api/tenant_handlers.go @@ -304,6 +304,40 @@ func (a *API) handleListDomains(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, resp) } +// handleSetDomainTemplate binds (or clears) the DNS template used to +// check/apply a domain — this is what makes an imported domain (which +// starts with template_id=NULL) checkable, closing the import→check loop. +func (a *API) handleSetDomainTemplate(w http.ResponseWriter, r *http.Request) { + pid, err := uuid.Parse(chi.URLParam(r, "pid")) + if err != nil { + writeErr(w, http.StatusBadRequest, "invalid project id") + return + } + did, err := uuid.Parse(chi.URLParam(r, "did")) + if err != nil { + writeErr(w, http.StatusBadRequest, "invalid domain id") + return + } + var req updateDomainTemplateRequest + if !decodeBody(w, r, &req) { + return + } + templateID, ok := parseOptionalUUID(req.TemplateID) + if !ok { + writeErr(w, http.StatusBadRequest, "invalid templateId") + return + } + + dom, err := a.Store.SetDomainTemplate(r.Context(), did, pid, templateID) + if err != nil { + // Either the domain itself or the (scoped) template wasn't found in + // this project — treat both as 404 rather than leak which one. + writeErr(w, http.StatusNotFound, "domain or template not found") + return + } + writeJSON(w, http.StatusOK, toDomainResponse(dom)) +} + func (a *API) handleDeleteDomain(w http.ResponseWriter, r *http.Request) { pid, err := uuid.Parse(chi.URLParam(r, "pid")) if err != nil { diff --git a/internal/api/tenant_test.go b/internal/api/tenant_test.go index afc91f3..2b807a8 100644 --- a/internal/api/tenant_test.go +++ b/internal/api/tenant_test.go @@ -35,6 +35,8 @@ type mockTenantStore struct { importDomains []store.Domain importDomainsErr error importCalled bool + + setDomainTemplateErr error } func (m *mockTenantStore) CreateAccount(_ context.Context, projectID uuid.UUID, prov, secretEnc, comment string) (store.Account, error) { @@ -98,6 +100,21 @@ func (m *mockTenantStore) ListDomains(context.Context, uuid.UUID) ([]store.Domai func (m *mockTenantStore) DeleteDomain(context.Context, uuid.UUID, uuid.UUID) error { return nil } +func (m *mockTenantStore) SetDomainTemplate(_ context.Context, domainID, projectID uuid.UUID, templateID *uuid.UUID) (store.Domain, error) { + if m.setDomainTemplateErr != nil { + return store.Domain{}, m.setDomainTemplateErr + } + for i, d := range m.domains { + if d.ID == domainID { + m.domains[i].TemplateID = templateID + return m.domains[i], nil + } + } + d := store.Domain{ID: domainID, ProjectID: projectID, TemplateID: templateID} + m.domains = append(m.domains, d) + return d, nil +} + func (m *mockTenantStore) ImportDomains(_ context.Context, projectID, accountID uuid.UUID, zones []provider.Zone) ([]store.Domain, error) { m.importCalled = true if m.importDomainsErr != nil { @@ -463,6 +480,65 @@ func TestCreateDomain_ValidTemplateInProject(t *testing.T) { } } +// --- domain template binding (import -> check loop) --- + +func TestSetDomainTemplate_ValidTemplateId(t *testing.T) { + a, ts := newTenantTestAPI() + domID := uuid.New() + tplID := uuid.New() + ts.domains = []store.Domain{{ID: domID, ZoneName: "example.com", ZoneID: "z1"}} + router := NewRouter(a) + + body := `{"templateId":"` + tplID.String() + `"}` + req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body)) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body %s", w.Code, w.Body.String()) + } + var resp domainResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + if resp.TemplateID == nil || *resp.TemplateID != tplID.String() { + t.Fatalf("unexpected response: %+v", resp) + } +} + +func TestSetDomainTemplate_BadTemplateUUID(t *testing.T) { + a, ts := newTenantTestAPI() + domID := uuid.New() + ts.domains = []store.Domain{{ID: domID}} + router := NewRouter(a) + + body := `{"templateId":"not-a-uuid"}` + req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body)) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body %s", w.Code, w.Body.String()) + } +} + +func TestSetDomainTemplate_TemplateNotFound(t *testing.T) { + a, ts := newTenantTestAPI() + domID := uuid.New() + ts.domains = []store.Domain{{ID: domID}} + ts.setDomainTemplateErr = errors.New("template not found in project") + router := NewRouter(a) + + body := `{"templateId":"` + uuid.New().String() + `"}` + req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body)) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Fatalf("expected 404, got %d body %s", w.Code, w.Body.String()) + } +} + func TestDeleteDomain_BadUUID(t *testing.T) { a, _ := newTenantTestAPI() router := NewRouter(a) diff --git a/internal/store/db/domains.sql.go b/internal/store/db/domains.sql.go index 24baed9..3b5aa32 100644 --- a/internal/store/db/domains.sql.go +++ b/internal/store/db/domains.sql.go @@ -87,6 +87,44 @@ func (q *Queries) GetDomain(ctx context.Context, arg GetDomainParams) (Domain, e return i, err } +const importDomain = `-- name: ImportDomain :one +INSERT INTO domains (id, project_id, provider_account_id, zone_name, zone_id, template_id) +VALUES ($1, $2, $3, $4, $5, $6) +ON CONFLICT (project_id, zone_id) DO NOTHING +RETURNING id, project_id, provider_account_id, zone_name, zone_id, template_id, created_at +` + +type ImportDomainParams struct { + ID uuid.UUID `json:"id"` + ProjectID uuid.UUID `json:"project_id"` + ProviderAccountID uuid.UUID `json:"provider_account_id"` + ZoneName string `json:"zone_name"` + ZoneID string `json:"zone_id"` + TemplateID *uuid.UUID `json:"template_id"` +} + +func (q *Queries) ImportDomain(ctx context.Context, arg ImportDomainParams) (Domain, error) { + row := q.db.QueryRow(ctx, importDomain, + arg.ID, + arg.ProjectID, + arg.ProviderAccountID, + arg.ZoneName, + arg.ZoneID, + arg.TemplateID, + ) + var i Domain + err := row.Scan( + &i.ID, + &i.ProjectID, + &i.ProviderAccountID, + &i.ZoneName, + &i.ZoneID, + &i.TemplateID, + &i.CreatedAt, + ) + return i, err +} + const listDomains = `-- name: ListDomains :many SELECT id, project_id, provider_account_id, zone_name, zone_id, template_id, created_at FROM domains WHERE project_id = $1 ORDER BY created_at ` @@ -145,3 +183,29 @@ func (q *Queries) LoadDomainFull(ctx context.Context, id uuid.UUID) (LoadDomainF ) return i, err } + +const updateDomainTemplate = `-- name: UpdateDomainTemplate :one +UPDATE domains SET template_id = $3 WHERE id = $1 AND project_id = $2 +RETURNING id, project_id, provider_account_id, zone_name, zone_id, template_id, created_at +` + +type UpdateDomainTemplateParams struct { + ID uuid.UUID `json:"id"` + ProjectID uuid.UUID `json:"project_id"` + TemplateID *uuid.UUID `json:"template_id"` +} + +func (q *Queries) UpdateDomainTemplate(ctx context.Context, arg UpdateDomainTemplateParams) (Domain, error) { + row := q.db.QueryRow(ctx, updateDomainTemplate, arg.ID, arg.ProjectID, arg.TemplateID) + var i Domain + err := row.Scan( + &i.ID, + &i.ProjectID, + &i.ProviderAccountID, + &i.ZoneName, + &i.ZoneID, + &i.TemplateID, + &i.CreatedAt, + ) + return i, err +} diff --git a/internal/store/migrations/0002_domains_unique.sql b/internal/store/migrations/0002_domains_unique.sql new file mode 100644 index 0000000..78c82d5 --- /dev/null +++ b/internal/store/migrations/0002_domains_unique.sql @@ -0,0 +1,5 @@ +-- +goose Up +ALTER TABLE domains ADD CONSTRAINT domains_project_zone_uniq UNIQUE (project_id, zone_id); + +-- +goose Down +ALTER TABLE domains DROP CONSTRAINT domains_project_zone_uniq; diff --git a/internal/store/queries/domains.sql b/internal/store/queries/domains.sql index 4a85d59..1f05fda 100644 --- a/internal/store/queries/domains.sql +++ b/internal/store/queries/domains.sql @@ -3,6 +3,16 @@ INSERT INTO domains (id, project_id, provider_account_id, zone_name, zone_id, te VALUES ($1, $2, $3, $4, $5, $6) RETURNING *; +-- name: ImportDomain :one +INSERT INTO domains (id, project_id, provider_account_id, zone_name, zone_id, template_id) +VALUES ($1, $2, $3, $4, $5, $6) +ON CONFLICT (project_id, zone_id) DO NOTHING +RETURNING *; + +-- name: UpdateDomainTemplate :one +UPDATE domains SET template_id = $3 WHERE id = $1 AND project_id = $2 +RETURNING *; + -- name: GetDomain :one SELECT * FROM domains WHERE id = $1 AND project_id = $2; diff --git a/internal/store/store_test.go b/internal/store/store_test.go index 0f08f2c..bcbdc12 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -143,3 +143,137 @@ func TestImportDomains_RollsBackAllOnError(t *testing.T) { t.Fatalf("expected 0 domains after rollback, got %d", len(list)) } } + +// TestImportDomains_IdempotentOnRepeat verifies the fix for the import +// idempotency gap: re-importing the same zones must not create duplicate +// domains (enforced by the domains_project_zone_uniq constraint + ON +// CONFLICT DO NOTHING in the ImportDomain query) and must not error. +func TestImportDomains_IdempotentOnRepeat(t *testing.T) { + s, ctx := newStore(t) + acc, err := s.Queries().CreateAccount(ctx, db.CreateAccountParams{ + ID: uuid.New(), ProjectID: defaultProject, Provider: "selectel", SecretEnc: "enc-blob", + }) + if err != nil { + t.Fatal(err) + } + + zones := []provider.Zone{ + {ID: "z1", Name: "a.example.com"}, + {ID: "z2", Name: "b.example.com"}, + } + first, err := s.ImportDomains(ctx, defaultProject, acc.ID, zones) + if err != nil { + t.Fatal(err) + } + if len(first) != 2 { + t.Fatalf("expected 2 domains on first import, got %d", len(first)) + } + + second, err := s.ImportDomains(ctx, defaultProject, acc.ID, zones) + if err != nil { + t.Fatalf("expected repeat import to succeed idempotently, got error: %v", err) + } + if len(second) != 0 { + t.Fatalf("expected 0 newly-created domains on repeat import, got %d", len(second)) + } + + list, err := s.ListDomains(ctx, defaultProject) + if err != nil { + t.Fatal(err) + } + if len(list) != 2 { + t.Fatalf("expected still exactly 2 domains (no duplicates), got %d", len(list)) + } + + var count int + row := s.pool.QueryRow(ctx, `SELECT COUNT(*) FROM domains WHERE project_id = $1 AND zone_id = $2`, defaultProject, "z1") + if err := row.Scan(&count); err != nil { + t.Fatal(err) + } + if count != 1 { + t.Fatalf("expected COUNT=1 for zone z1 (UNIQUE constraint), got %d", count) + } +} + +// TestSetDomainTemplate_ClosesImportCheckLoop verifies the fix for the +// second review gap: an imported domain (template_id=NULL) can be bound to +// a template via SetDomainTemplate, after which LoadDomain succeeds and +// returns that template — closing the import -> bind -> check cycle. +func TestSetDomainTemplate_ClosesImportCheckLoop(t *testing.T) { + s, ctx := newStore(t) + acc, err := s.Queries().CreateAccount(ctx, db.CreateAccountParams{ + ID: uuid.New(), ProjectID: defaultProject, Provider: "selectel", SecretEnc: "enc-blob", + }) + if err != nil { + t.Fatal(err) + } + doms, err := s.ImportDomains(ctx, defaultProject, acc.ID, []provider.Zone{{ID: "z1", Name: "a.example.com"}}) + if err != nil { + t.Fatal(err) + } + dom := doms[0] + + // Before binding, the domain is not checkable. + if _, err := s.LoadDomain(ctx, dom.ID); err == nil { + t.Fatal("expected LoadDomain to fail before a template is bound") + } + + doc := dto.TemplateDoc{Records: []dto.RecordDTO{ + {Type: "A", Name: "www.a.example.com.", TTL: 300, Values: []string{"1.2.3.4"}}, + }} + tpl, err := s.CreateTemplate(ctx, defaultProject, "base", doc) + if err != nil { + t.Fatal(err) + } + + updated, err := s.SetDomainTemplate(ctx, dom.ID, defaultProject, &tpl.ID) + if err != nil { + t.Fatal(err) + } + if updated.TemplateID == nil || *updated.TemplateID != tpl.ID { + t.Fatalf("expected domain.TemplateID=%s, got %+v", tpl.ID, updated.TemplateID) + } + + ref, err := s.LoadDomain(ctx, dom.ID) + if err != nil { + t.Fatalf("expected LoadDomain to succeed after binding template, got error: %v", err) + } + if len(ref.Template.Records) != 1 || ref.Template.Records[0].Type != "A" { + t.Fatalf("unexpected template loaded: %+v", ref.Template) + } +} + +// TestSetDomainTemplate_RejectsForeignProjectTemplate verifies that binding +// a template belonging to a different project is rejected rather than +// silently succeeding (which would let one tenant's domain use another +// tenant's DNS template). +func TestSetDomainTemplate_RejectsForeignProjectTemplate(t *testing.T) { + s, ctx := newStore(t) + acc, err := s.Queries().CreateAccount(ctx, db.CreateAccountParams{ + ID: uuid.New(), ProjectID: defaultProject, Provider: "selectel", SecretEnc: "enc-blob", + }) + if err != nil { + t.Fatal(err) + } + doms, err := s.ImportDomains(ctx, defaultProject, acc.ID, []provider.Zone{{ID: "z1", Name: "a.example.com"}}) + if err != nil { + t.Fatal(err) + } + dom := doms[0] + + // A template that belongs to a different (foreign) project. The default + // user is the seed tenant from migrations/0001_init.sql. + defaultUser := uuid.MustParse("00000000-0000-0000-0000-000000000001") + foreignProject := uuid.New() + if _, err := s.pool.Exec(ctx, `INSERT INTO projects (id, user_id, name) VALUES ($1, $2, 'foreign')`, foreignProject, defaultUser); err != nil { + t.Fatal(err) + } + foreignTpl, err := s.CreateTemplate(ctx, foreignProject, "foreign", dto.TemplateDoc{}) + if err != nil { + t.Fatal(err) + } + + if _, err := s.SetDomainTemplate(ctx, dom.ID, defaultProject, &foreignTpl.ID); err == nil { + t.Fatal("expected error binding a template from a different project, got nil") + } +} diff --git a/internal/store/tenant.go b/internal/store/tenant.go index dc8908e..724f811 100644 --- a/internal/store/tenant.go +++ b/internal/store/tenant.go @@ -2,8 +2,10 @@ package store import ( "context" + "errors" "github.com/google/uuid" + "github.com/jackc/pgx/v5" "github.com/vasyakrg/dns-autoresolver/internal/provider" "github.com/vasyakrg/dns-autoresolver/internal/store/db" @@ -166,6 +168,12 @@ func (s *Store) DeleteDomain(ctx context.Context, id, projectID uuid.UUID) error // ImportDomains creates one domain per zone inside a single transaction: if // any zone fails to be created, the whole batch is rolled back so callers // never observe a partially-imported set of domains. +// +// Import is idempotent: zones that already have a domain for this project +// (enforced by the domains_project_zone_uniq constraint) are silently +// skipped via ON CONFLICT DO NOTHING rather than erroring or duplicating — +// so a repeated POST .../import never creates duplicate domains. Only the +// zones that were actually newly created are returned. func (s *Store) ImportDomains(ctx context.Context, projectID, accountID uuid.UUID, zones []provider.Zone) ([]Domain, error) { tx, err := s.pool.Begin(ctx) if err != nil { @@ -176,11 +184,16 @@ func (s *Store) ImportDomains(ctx context.Context, projectID, accountID uuid.UUI q := s.q.WithTx(tx) out := make([]Domain, 0, len(zones)) for _, z := range zones { - d, err := q.CreateDomain(ctx, db.CreateDomainParams{ + d, err := q.ImportDomain(ctx, db.ImportDomainParams{ ID: uuid.New(), ProjectID: projectID, ProviderAccountID: accountID, ZoneName: z.Name, ZoneID: z.ID, TemplateID: nil, }) if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + // ON CONFLICT DO NOTHING: this zone was already imported + // for this project — skip it rather than fail the batch. + continue + } return nil, err } out = append(out, domainFromDB(d)) @@ -190,3 +203,22 @@ func (s *Store) ImportDomains(ctx context.Context, projectID, accountID uuid.UUI } return out, nil } + +// SetDomainTemplate attaches (or clears, when templateID is nil) the DNS +// template used to check/apply a domain. When templateID is non-nil it must +// belong to the same project — verified via scoped GetTemplate — otherwise +// a caller could bind a domain to another tenant's template. +func (s *Store) SetDomainTemplate(ctx context.Context, domainID, projectID uuid.UUID, templateID *uuid.UUID) (Domain, error) { + if templateID != nil { + if _, err := s.GetTemplate(ctx, *templateID, projectID); err != nil { + return Domain{}, err + } + } + d, err := s.q.UpdateDomainTemplate(ctx, db.UpdateDomainTemplateParams{ + ID: domainID, ProjectID: projectID, TemplateID: templateID, + }) + if err != nil { + return Domain{}, err + } + return domainFromDB(d), nil +}