From 2aca92d070e23cc6e797d55258dfe1a0e54797d2 Mon Sep 17 00:00:00 2001 From: Vassiliy Yegorov Date: Fri, 3 Jul 2026 15:08:16 +0700 Subject: [PATCH] =?UTF-8?q?fix(api):=20tenant-=D0=BF=D1=80=D0=BE=D0=B2?= =?UTF-8?q?=D0=B5=D1=80=D0=BA=D0=B0=20account/template=20=D0=B2=20CreateDo?= =?UTF-8?q?main=20(HIGH),=20=D0=B0=D1=82=D0=BE=D0=BC=D0=B0=D1=80=D0=BD?= =?UTF-8?q?=D1=8B=D0=B9=20import=20=D1=87=D0=B5=D1=80=D0=B5=D0=B7=20=D1=82?= =?UTF-8?q?=D1=80=D0=B0=D0=BD=D0=B7=D0=B0=D0=BA=D1=86=D0=B8=D1=8E=20(MEDIU?= =?UTF-8?q?M)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/api/api.go | 2 + internal/api/tenant_handlers.go | 33 +++++-- internal/api/tenant_test.go | 170 +++++++++++++++++++++++++++++++- internal/store/store_test.go | 55 +++++++++++ internal/store/tenant.go | 39 ++++++++ 5 files changed, 288 insertions(+), 11 deletions(-) diff --git a/internal/api/api.go b/internal/api/api.go index a4c53eb..249d72b 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -34,10 +34,12 @@ type TenantStore interface { ListTemplates(ctx context.Context, projectID uuid.UUID) ([]store.Template, error) UpdateTemplate(ctx context.Context, id, projectID uuid.UUID, name string, doc dto.TemplateDoc) (store.Template, error) DeleteTemplate(ctx context.Context, id, projectID uuid.UUID) error + GetTemplate(ctx context.Context, id, projectID uuid.UUID) (store.Template, error) CreateDomain(ctx context.Context, projectID, accountID uuid.UUID, zoneName, zoneID string, templateID *uuid.UUID) (store.Domain, error) ListDomains(ctx context.Context, projectID uuid.UUID) ([]store.Domain, error) DeleteDomain(ctx context.Context, id, projectID uuid.UUID) error + ImportDomains(ctx context.Context, projectID, accountID uuid.UUID, zones []provider.Zone) ([]store.Domain, error) } // Cipher encrypts/decrypts provider account secrets. *crypto.Cipher satisfies it. diff --git a/internal/api/tenant_handlers.go b/internal/api/tenant_handlers.go index 0e79188..e952d79 100644 --- a/internal/api/tenant_handlers.go +++ b/internal/api/tenant_handlers.go @@ -127,14 +127,16 @@ func (a *API) handleImportZones(w http.ResponseWriter, r *http.Request) { writeErr(w, http.StatusInternalServerError, "internal error") return } - created := make([]domainResponse, 0, len(zones)) - for _, z := range zones { - d, err := a.Store.CreateDomain(r.Context(), pid, aid, z.Name, z.ID, nil) - if err != nil { - log.Printf("api: import: create domain failed: %v", err) - writeErr(w, http.StatusInternalServerError, "internal error") - return - } + // Imported atomically: either every zone becomes a domain or none does, + // so a mid-batch provider/DB error never leaves a partial import behind. + doms, err := a.Store.ImportDomains(r.Context(), pid, aid, zones) + if err != nil { + log.Printf("api: import: create domains failed: %v", err) + writeErr(w, http.StatusInternalServerError, "internal error") + return + } + created := make([]domainResponse, 0, len(doms)) + for _, d := range doms { created = append(created, toDomainResponse(d)) } writeJSON(w, http.StatusCreated, created) @@ -259,6 +261,21 @@ func (a *API) handleCreateDomain(w http.ResponseWriter, r *http.Request) { writeErr(w, http.StatusBadRequest, "invalid templateId") return } + + // Tenant isolation: the account (and template, if given) must belong to + // this project — otherwise a caller could attach a domain to another + // tenant's provider account or DNS template. + if _, err := a.Store.GetAccount(r.Context(), accID, pid); err != nil { + writeErr(w, http.StatusNotFound, "provider account not found") + return + } + if templateID != nil { + if _, err := a.Store.GetTemplate(r.Context(), *templateID, pid); err != nil { + writeErr(w, http.StatusNotFound, "template not found") + return + } + } + dom, err := a.Store.CreateDomain(r.Context(), pid, accID, req.ZoneName, req.ZoneID, templateID) if err != nil { log.Printf("api: create domain failed: %v", err) diff --git a/internal/api/tenant_test.go b/internal/api/tenant_test.go index 4e24b01..afc91f3 100644 --- a/internal/api/tenant_test.go +++ b/internal/api/tenant_test.go @@ -3,6 +3,7 @@ package api import ( "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "strings" @@ -30,6 +31,10 @@ type mockTenantStore struct { domains []store.Domain createDomains int + + importDomains []store.Domain + importDomainsErr error + importCalled bool } func (m *mockTenantStore) CreateAccount(_ context.Context, projectID uuid.UUID, prov, secretEnc, comment string) (store.Account, error) { @@ -49,7 +54,7 @@ func (m *mockTenantStore) GetAccount(_ context.Context, id, _ uuid.UUID) (store. return a, nil } } - return m.accounts[0], nil + return store.Account{}, errors.New("account not found") } func (m *mockTenantStore) DeleteAccount(context.Context, uuid.UUID, uuid.UUID) error { return nil } @@ -71,6 +76,15 @@ func (m *mockTenantStore) UpdateTemplate(_ context.Context, id, projectID uuid.U func (m *mockTenantStore) DeleteTemplate(context.Context, uuid.UUID, uuid.UUID) error { return nil } +func (m *mockTenantStore) GetTemplate(_ context.Context, id, _ uuid.UUID) (store.Template, error) { + for _, t := range m.templates { + if t.ID == id { + return t, nil + } + } + return store.Template{}, errors.New("template not found") +} + func (m *mockTenantStore) CreateDomain(_ context.Context, projectID, accountID uuid.UUID, zoneName, zoneID string, templateID *uuid.UUID) (store.Domain, error) { m.createDomains++ d := store.Domain{ID: uuid.New(), ProjectID: projectID, ProviderAccountID: accountID, ZoneName: zoneName, ZoneID: zoneID, TemplateID: templateID} @@ -84,6 +98,21 @@ func (m *mockTenantStore) ListDomains(context.Context, uuid.UUID) ([]store.Domai func (m *mockTenantStore) DeleteDomain(context.Context, uuid.UUID, uuid.UUID) error { return nil } +func (m *mockTenantStore) ImportDomains(_ context.Context, projectID, accountID uuid.UUID, zones []provider.Zone) ([]store.Domain, error) { + m.importCalled = true + if m.importDomainsErr != nil { + return nil, m.importDomainsErr + } + out := make([]store.Domain, 0, len(zones)) + for _, z := range zones { + d := store.Domain{ID: uuid.New(), ProjectID: projectID, ProviderAccountID: accountID, ZoneName: z.Name, ZoneID: z.ID} + out = append(out, d) + } + m.domains = append(m.domains, out...) + m.importDomains = out + return out, nil +} + type mockCipher struct{} func (mockCipher) Encrypt(plaintext []byte) (string, error) { return "ENC(" + string(plaintext) + ")", nil } @@ -260,8 +289,11 @@ func TestImportZones_CreatesDomainPerZone(t *testing.T) { if w.Code != http.StatusCreated { t.Fatalf("status %d body %s", w.Code, w.Body.String()) } - if ts.createDomains != 2 { - t.Fatalf("expected 2 CreateDomain calls, got %d", ts.createDomains) + if !ts.importCalled { + t.Fatal("expected ImportDomains to be called") + } + if len(ts.importDomains) != 2 { + t.Fatalf("expected 2 domains created via ImportDomains, got %d", len(ts.importDomains)) } var resp []domainResponse if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { @@ -272,6 +304,37 @@ func TestImportZones_CreatesDomainPerZone(t *testing.T) { } } +// TestImportZones_AtomicRollbackOnError verifies that when the store fails +// to import the batch (e.g. a mid-batch DB error), the handler surfaces a +// 500 and — per store.ImportDomains' transactional contract — no partial +// set of domains is left behind (modeled here by ImportDomains returning no +// domains alongside the error). +func TestImportZones_AtomicRollbackOnError(t *testing.T) { + a, ts := newTenantTestAPI() + accID := uuid.New() + ts.accounts = []store.Account{{ID: accID, Provider: "selectel", SecretEnc: "ENC(token)"}} + ts.importDomainsErr = errors.New("boom: mid-batch failure") + a.Reg = &mockRegistry{zones: []provider.Zone{ + {ID: "z1", Name: "example.com"}, + {ID: "z2", Name: "example.net"}, + }} + router := NewRouter(a) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Fatalf("expected 500, got %d body %s", w.Code, w.Body.String()) + } + if strings.Contains(w.Body.String(), "boom") { + t.Fatalf("internal error details leaked to response: %s", w.Body.String()) + } + if len(ts.domains) != 0 { + t.Fatalf("expected no domains to be created on rollback, got %d", len(ts.domains)) + } +} + func TestImportZones_BadAccountUUID(t *testing.T) { a, _ := newTenantTestAPI() router := NewRouter(a) @@ -299,6 +362,107 @@ func TestCreateDomain_BadProjectUUID(t *testing.T) { } } +// TestCreateDomain_AccountNotFoundInProject covers the HIGH tenant-isolation +// fix: a providerAccountId that scoped GetAccount can't find within this +// project must be rejected before any domain is created — otherwise a +// caller could attach a domain to another tenant's provider account. +func TestCreateDomain_AccountNotFoundInProject(t *testing.T) { + a, ts := newTenantTestAPI() + router := NewRouter(a) + + // ts.accounts is empty, so GetAccount will not find this id. + foreignAccID := uuid.New() + body := `{"providerAccountId":"` + foreignAccID.String() + `","zoneName":"example.com","zoneId":"z1"}` + req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Fatalf("expected 404, got %d body %s", w.Code, w.Body.String()) + } + if ts.createDomains != 0 { + t.Fatalf("expected CreateDomain not to be called, got %d calls", ts.createDomains) + } +} + +// TestCreateDomain_TemplateNotFoundInProject covers the same isolation fix +// for the optional templateId: a template belonging to another project (or +// nonexistent) must reject the request before the domain is created. +func TestCreateDomain_TemplateNotFoundInProject(t *testing.T) { + a, ts := newTenantTestAPI() + accID := uuid.New() + ts.accounts = []store.Account{{ID: accID, Provider: "selectel", SecretEnc: "ENC(token)"}} + router := NewRouter(a) + + foreignTplID := uuid.New() + body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + foreignTplID.String() + `"}` + req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Fatalf("expected 404, got %d body %s", w.Code, w.Body.String()) + } + if ts.createDomains != 0 { + t.Fatalf("expected CreateDomain not to be called, got %d calls", ts.createDomains) + } +} + +// TestCreateDomain_HappyPath ensures the tenant-isolation checks don't break +// the existing success path: a valid account in-project and no template. +func TestCreateDomain_HappyPath(t *testing.T) { + a, ts := newTenantTestAPI() + accID := uuid.New() + ts.accounts = []store.Account{{ID: accID, Provider: "selectel", SecretEnc: "ENC(token)"}} + router := NewRouter(a) + + body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1"}` + req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("expected 201, got %d body %s", w.Code, w.Body.String()) + } + if ts.createDomains != 1 { + t.Fatalf("expected 1 CreateDomain call, got %d", ts.createDomains) + } + var resp domainResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + if resp.ZoneName != "example.com" || resp.TemplateID != nil { + t.Fatalf("unexpected response: %+v", resp) + } +} + +// TestCreateDomain_ValidTemplateInProject ensures a template that scoped +// GetTemplate does find (i.e. belongs to this project) is accepted. +func TestCreateDomain_ValidTemplateInProject(t *testing.T) { + a, ts := newTenantTestAPI() + accID := uuid.New() + tplID := uuid.New() + ts.accounts = []store.Account{{ID: accID, Provider: "selectel", SecretEnc: "ENC(token)"}} + ts.templates = []store.Template{{ID: tplID, Name: "base"}} + router := NewRouter(a) + + body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + tplID.String() + `"}` + req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("expected 201, got %d body %s", w.Code, w.Body.String()) + } + var resp domainResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + if resp.TemplateID == nil || *resp.TemplateID != tplID.String() { + t.Fatalf("unexpected response: %+v", resp) + } +} + func TestDeleteDomain_BadUUID(t *testing.T) { a, _ := newTenantTestAPI() router := NewRouter(a) diff --git a/internal/store/store_test.go b/internal/store/store_test.go index fb89309..0f08f2c 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -7,6 +7,7 @@ import ( "github.com/google/uuid" "github.com/jackc/pgx/v5/pgxpool" + "github.com/vasyakrg/dns-autoresolver/internal/provider" "github.com/vasyakrg/dns-autoresolver/internal/store/db" "github.com/vasyakrg/dns-autoresolver/internal/store/dto" ) @@ -88,3 +89,57 @@ func TestTemplateJSONBRoundTrip(t *testing.T) { t.Fatal(err) } } + +func TestImportDomains_CommitsAllOnSuccess(t *testing.T) { + s, ctx := newStore(t) + acc, err := s.Queries().CreateAccount(ctx, db.CreateAccountParams{ + ID: uuid.New(), ProjectID: defaultProject, Provider: "selectel", SecretEnc: "enc-blob", + }) + if err != nil { + t.Fatal(err) + } + + zones := []provider.Zone{ + {ID: "z1", Name: "a.example.com"}, + {ID: "z2", Name: "b.example.com"}, + } + doms, err := s.ImportDomains(ctx, defaultProject, acc.ID, zones) + if err != nil { + t.Fatal(err) + } + if len(doms) != 2 { + t.Fatalf("expected 2 domains returned, got %d", len(doms)) + } + + list, err := s.ListDomains(ctx, defaultProject) + if err != nil { + t.Fatal(err) + } + if len(list) != 2 { + t.Fatalf("expected 2 persisted domains, got %d", len(list)) + } +} + +// TestImportDomains_RollsBackAllOnError verifies the transactional contract: +// if any zone in the batch fails to insert (here, an FK violation because +// the account doesn't exist), none of the batch is left committed. +func TestImportDomains_RollsBackAllOnError(t *testing.T) { + s, ctx := newStore(t) + bogusAccountID := uuid.New() // no matching provider_accounts row + + zones := []provider.Zone{ + {ID: "z1", Name: "a.example.com"}, + {ID: "z2", Name: "b.example.com"}, + } + if _, err := s.ImportDomains(ctx, defaultProject, bogusAccountID, zones); err == nil { + t.Fatal("expected FK violation error, got nil") + } + + list, err := s.ListDomains(ctx, defaultProject) + if err != nil { + t.Fatal(err) + } + if len(list) != 0 { + t.Fatalf("expected 0 domains after rollback, got %d", len(list)) + } +} diff --git a/internal/store/tenant.go b/internal/store/tenant.go index 5bfc460..dc8908e 100644 --- a/internal/store/tenant.go +++ b/internal/store/tenant.go @@ -5,6 +5,7 @@ import ( "github.com/google/uuid" + "github.com/vasyakrg/dns-autoresolver/internal/provider" "github.com/vasyakrg/dns-autoresolver/internal/store/db" "github.com/vasyakrg/dns-autoresolver/internal/store/dto" ) @@ -109,6 +110,16 @@ func (s *Store) DeleteTemplate(ctx context.Context, id, projectID uuid.UUID) err return s.q.DeleteTemplate(ctx, db.DeleteTemplateParams{ID: id, ProjectID: projectID}) } +// GetTemplate is a scoped lookup used to verify a template belongs to +// projectID before it is referenced elsewhere (e.g. CreateDomain). +func (s *Store) GetTemplate(ctx context.Context, id, projectID uuid.UUID) (Template, error) { + t, err := s.q.GetTemplate(ctx, db.GetTemplateParams{ID: id, ProjectID: projectID}) + if err != nil { + return Template{}, err + } + return templateFromDB(t), nil +} + type Domain struct { ID uuid.UUID ProjectID uuid.UUID @@ -151,3 +162,31 @@ func (s *Store) ListDomains(ctx context.Context, projectID uuid.UUID) ([]Domain, func (s *Store) DeleteDomain(ctx context.Context, id, projectID uuid.UUID) error { return s.q.DeleteDomain(ctx, db.DeleteDomainParams{ID: id, ProjectID: projectID}) } + +// ImportDomains creates one domain per zone inside a single transaction: if +// any zone fails to be created, the whole batch is rolled back so callers +// never observe a partially-imported set of domains. +func (s *Store) ImportDomains(ctx context.Context, projectID, accountID uuid.UUID, zones []provider.Zone) ([]Domain, error) { + tx, err := s.pool.Begin(ctx) + if err != nil { + return nil, err + } + defer tx.Rollback(ctx) // no-op once Commit has succeeded + + q := s.q.WithTx(tx) + out := make([]Domain, 0, len(zones)) + for _, z := range zones { + d, err := q.CreateDomain(ctx, db.CreateDomainParams{ + ID: uuid.New(), ProjectID: projectID, ProviderAccountID: accountID, + ZoneName: z.Name, ZoneID: z.ID, TemplateID: nil, + }) + if err != nil { + return nil, err + } + out = append(out, domainFromDB(d)) + } + if err := tx.Commit(ctx); err != nil { + return nil, err + } + return out, nil +}