fix(api): tenant-проверка account/template в CreateDomain (HIGH), атомарный import через транзакцию (MEDIUM)

This commit is contained in:
2026-07-03 15:08:16 +07:00
parent ae6a4d7f4c
commit 2aca92d070
5 changed files with 288 additions and 11 deletions
+2
View File
@@ -34,10 +34,12 @@ type TenantStore interface {
ListTemplates(ctx context.Context, projectID uuid.UUID) ([]store.Template, error)
UpdateTemplate(ctx context.Context, id, projectID uuid.UUID, name string, doc dto.TemplateDoc) (store.Template, error)
DeleteTemplate(ctx context.Context, id, projectID uuid.UUID) error
GetTemplate(ctx context.Context, id, projectID uuid.UUID) (store.Template, error)
CreateDomain(ctx context.Context, projectID, accountID uuid.UUID, zoneName, zoneID string, templateID *uuid.UUID) (store.Domain, error)
ListDomains(ctx context.Context, projectID uuid.UUID) ([]store.Domain, error)
DeleteDomain(ctx context.Context, id, projectID uuid.UUID) error
ImportDomains(ctx context.Context, projectID, accountID uuid.UUID, zones []provider.Zone) ([]store.Domain, error)
}
// Cipher encrypts/decrypts provider account secrets. *crypto.Cipher satisfies it.
+25 -8
View File
@@ -127,14 +127,16 @@ func (a *API) handleImportZones(w http.ResponseWriter, r *http.Request) {
writeErr(w, http.StatusInternalServerError, "internal error")
return
}
created := make([]domainResponse, 0, len(zones))
for _, z := range zones {
d, err := a.Store.CreateDomain(r.Context(), pid, aid, z.Name, z.ID, nil)
if err != nil {
log.Printf("api: import: create domain failed: %v", err)
writeErr(w, http.StatusInternalServerError, "internal error")
return
}
// Imported atomically: either every zone becomes a domain or none does,
// so a mid-batch provider/DB error never leaves a partial import behind.
doms, err := a.Store.ImportDomains(r.Context(), pid, aid, zones)
if err != nil {
log.Printf("api: import: create domains failed: %v", err)
writeErr(w, http.StatusInternalServerError, "internal error")
return
}
created := make([]domainResponse, 0, len(doms))
for _, d := range doms {
created = append(created, toDomainResponse(d))
}
writeJSON(w, http.StatusCreated, created)
@@ -259,6 +261,21 @@ func (a *API) handleCreateDomain(w http.ResponseWriter, r *http.Request) {
writeErr(w, http.StatusBadRequest, "invalid templateId")
return
}
// Tenant isolation: the account (and template, if given) must belong to
// this project — otherwise a caller could attach a domain to another
// tenant's provider account or DNS template.
if _, err := a.Store.GetAccount(r.Context(), accID, pid); err != nil {
writeErr(w, http.StatusNotFound, "provider account not found")
return
}
if templateID != nil {
if _, err := a.Store.GetTemplate(r.Context(), *templateID, pid); err != nil {
writeErr(w, http.StatusNotFound, "template not found")
return
}
}
dom, err := a.Store.CreateDomain(r.Context(), pid, accID, req.ZoneName, req.ZoneID, templateID)
if err != nil {
log.Printf("api: create domain failed: %v", err)
+167 -3
View File
@@ -3,6 +3,7 @@ package api
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
@@ -30,6 +31,10 @@ type mockTenantStore struct {
domains []store.Domain
createDomains int
importDomains []store.Domain
importDomainsErr error
importCalled bool
}
func (m *mockTenantStore) CreateAccount(_ context.Context, projectID uuid.UUID, prov, secretEnc, comment string) (store.Account, error) {
@@ -49,7 +54,7 @@ func (m *mockTenantStore) GetAccount(_ context.Context, id, _ uuid.UUID) (store.
return a, nil
}
}
return m.accounts[0], nil
return store.Account{}, errors.New("account not found")
}
func (m *mockTenantStore) DeleteAccount(context.Context, uuid.UUID, uuid.UUID) error { return nil }
@@ -71,6 +76,15 @@ func (m *mockTenantStore) UpdateTemplate(_ context.Context, id, projectID uuid.U
func (m *mockTenantStore) DeleteTemplate(context.Context, uuid.UUID, uuid.UUID) error { return nil }
func (m *mockTenantStore) GetTemplate(_ context.Context, id, _ uuid.UUID) (store.Template, error) {
for _, t := range m.templates {
if t.ID == id {
return t, nil
}
}
return store.Template{}, errors.New("template not found")
}
func (m *mockTenantStore) CreateDomain(_ context.Context, projectID, accountID uuid.UUID, zoneName, zoneID string, templateID *uuid.UUID) (store.Domain, error) {
m.createDomains++
d := store.Domain{ID: uuid.New(), ProjectID: projectID, ProviderAccountID: accountID, ZoneName: zoneName, ZoneID: zoneID, TemplateID: templateID}
@@ -84,6 +98,21 @@ func (m *mockTenantStore) ListDomains(context.Context, uuid.UUID) ([]store.Domai
func (m *mockTenantStore) DeleteDomain(context.Context, uuid.UUID, uuid.UUID) error { return nil }
func (m *mockTenantStore) ImportDomains(_ context.Context, projectID, accountID uuid.UUID, zones []provider.Zone) ([]store.Domain, error) {
m.importCalled = true
if m.importDomainsErr != nil {
return nil, m.importDomainsErr
}
out := make([]store.Domain, 0, len(zones))
for _, z := range zones {
d := store.Domain{ID: uuid.New(), ProjectID: projectID, ProviderAccountID: accountID, ZoneName: z.Name, ZoneID: z.ID}
out = append(out, d)
}
m.domains = append(m.domains, out...)
m.importDomains = out
return out, nil
}
type mockCipher struct{}
func (mockCipher) Encrypt(plaintext []byte) (string, error) { return "ENC(" + string(plaintext) + ")", nil }
@@ -260,8 +289,11 @@ func TestImportZones_CreatesDomainPerZone(t *testing.T) {
if w.Code != http.StatusCreated {
t.Fatalf("status %d body %s", w.Code, w.Body.String())
}
if ts.createDomains != 2 {
t.Fatalf("expected 2 CreateDomain calls, got %d", ts.createDomains)
if !ts.importCalled {
t.Fatal("expected ImportDomains to be called")
}
if len(ts.importDomains) != 2 {
t.Fatalf("expected 2 domains created via ImportDomains, got %d", len(ts.importDomains))
}
var resp []domainResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
@@ -272,6 +304,37 @@ func TestImportZones_CreatesDomainPerZone(t *testing.T) {
}
}
// TestImportZones_AtomicRollbackOnError verifies that when the store fails
// to import the batch (e.g. a mid-batch DB error), the handler surfaces a
// 500 and — per store.ImportDomains' transactional contract — no partial
// set of domains is left behind (modeled here by ImportDomains returning no
// domains alongside the error).
func TestImportZones_AtomicRollbackOnError(t *testing.T) {
a, ts := newTenantTestAPI()
accID := uuid.New()
ts.accounts = []store.Account{{ID: accID, Provider: "selectel", SecretEnc: "ENC(token)"}}
ts.importDomainsErr = errors.New("boom: mid-batch failure")
a.Reg = &mockRegistry{zones: []provider.Zone{
{ID: "z1", Name: "example.com"},
{ID: "z2", Name: "example.net"},
}}
router := NewRouter(a)
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d body %s", w.Code, w.Body.String())
}
if strings.Contains(w.Body.String(), "boom") {
t.Fatalf("internal error details leaked to response: %s", w.Body.String())
}
if len(ts.domains) != 0 {
t.Fatalf("expected no domains to be created on rollback, got %d", len(ts.domains))
}
}
func TestImportZones_BadAccountUUID(t *testing.T) {
a, _ := newTenantTestAPI()
router := NewRouter(a)
@@ -299,6 +362,107 @@ func TestCreateDomain_BadProjectUUID(t *testing.T) {
}
}
// TestCreateDomain_AccountNotFoundInProject covers the HIGH tenant-isolation
// fix: a providerAccountId that scoped GetAccount can't find within this
// project must be rejected before any domain is created — otherwise a
// caller could attach a domain to another tenant's provider account.
func TestCreateDomain_AccountNotFoundInProject(t *testing.T) {
a, ts := newTenantTestAPI()
router := NewRouter(a)
// ts.accounts is empty, so GetAccount will not find this id.
foreignAccID := uuid.New()
body := `{"providerAccountId":"` + foreignAccID.String() + `","zoneName":"example.com","zoneId":"z1"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d body %s", w.Code, w.Body.String())
}
if ts.createDomains != 0 {
t.Fatalf("expected CreateDomain not to be called, got %d calls", ts.createDomains)
}
}
// TestCreateDomain_TemplateNotFoundInProject covers the same isolation fix
// for the optional templateId: a template belonging to another project (or
// nonexistent) must reject the request before the domain is created.
func TestCreateDomain_TemplateNotFoundInProject(t *testing.T) {
a, ts := newTenantTestAPI()
accID := uuid.New()
ts.accounts = []store.Account{{ID: accID, Provider: "selectel", SecretEnc: "ENC(token)"}}
router := NewRouter(a)
foreignTplID := uuid.New()
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + foreignTplID.String() + `"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d body %s", w.Code, w.Body.String())
}
if ts.createDomains != 0 {
t.Fatalf("expected CreateDomain not to be called, got %d calls", ts.createDomains)
}
}
// TestCreateDomain_HappyPath ensures the tenant-isolation checks don't break
// the existing success path: a valid account in-project and no template.
func TestCreateDomain_HappyPath(t *testing.T) {
a, ts := newTenantTestAPI()
accID := uuid.New()
ts.accounts = []store.Account{{ID: accID, Provider: "selectel", SecretEnc: "ENC(token)"}}
router := NewRouter(a)
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d body %s", w.Code, w.Body.String())
}
if ts.createDomains != 1 {
t.Fatalf("expected 1 CreateDomain call, got %d", ts.createDomains)
}
var resp domainResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatal(err)
}
if resp.ZoneName != "example.com" || resp.TemplateID != nil {
t.Fatalf("unexpected response: %+v", resp)
}
}
// TestCreateDomain_ValidTemplateInProject ensures a template that scoped
// GetTemplate does find (i.e. belongs to this project) is accepted.
func TestCreateDomain_ValidTemplateInProject(t *testing.T) {
a, ts := newTenantTestAPI()
accID := uuid.New()
tplID := uuid.New()
ts.accounts = []store.Account{{ID: accID, Provider: "selectel", SecretEnc: "ENC(token)"}}
ts.templates = []store.Template{{ID: tplID, Name: "base"}}
router := NewRouter(a)
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + tplID.String() + `"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d body %s", w.Code, w.Body.String())
}
var resp domainResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatal(err)
}
if resp.TemplateID == nil || *resp.TemplateID != tplID.String() {
t.Fatalf("unexpected response: %+v", resp)
}
}
func TestDeleteDomain_BadUUID(t *testing.T) {
a, _ := newTenantTestAPI()
router := NewRouter(a)
+55
View File
@@ -7,6 +7,7 @@ import (
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/vasyakrg/dns-autoresolver/internal/provider"
"github.com/vasyakrg/dns-autoresolver/internal/store/db"
"github.com/vasyakrg/dns-autoresolver/internal/store/dto"
)
@@ -88,3 +89,57 @@ func TestTemplateJSONBRoundTrip(t *testing.T) {
t.Fatal(err)
}
}
func TestImportDomains_CommitsAllOnSuccess(t *testing.T) {
s, ctx := newStore(t)
acc, err := s.Queries().CreateAccount(ctx, db.CreateAccountParams{
ID: uuid.New(), ProjectID: defaultProject, Provider: "selectel", SecretEnc: "enc-blob",
})
if err != nil {
t.Fatal(err)
}
zones := []provider.Zone{
{ID: "z1", Name: "a.example.com"},
{ID: "z2", Name: "b.example.com"},
}
doms, err := s.ImportDomains(ctx, defaultProject, acc.ID, zones)
if err != nil {
t.Fatal(err)
}
if len(doms) != 2 {
t.Fatalf("expected 2 domains returned, got %d", len(doms))
}
list, err := s.ListDomains(ctx, defaultProject)
if err != nil {
t.Fatal(err)
}
if len(list) != 2 {
t.Fatalf("expected 2 persisted domains, got %d", len(list))
}
}
// TestImportDomains_RollsBackAllOnError verifies the transactional contract:
// if any zone in the batch fails to insert (here, an FK violation because
// the account doesn't exist), none of the batch is left committed.
func TestImportDomains_RollsBackAllOnError(t *testing.T) {
s, ctx := newStore(t)
bogusAccountID := uuid.New() // no matching provider_accounts row
zones := []provider.Zone{
{ID: "z1", Name: "a.example.com"},
{ID: "z2", Name: "b.example.com"},
}
if _, err := s.ImportDomains(ctx, defaultProject, bogusAccountID, zones); err == nil {
t.Fatal("expected FK violation error, got nil")
}
list, err := s.ListDomains(ctx, defaultProject)
if err != nil {
t.Fatal(err)
}
if len(list) != 0 {
t.Fatalf("expected 0 domains after rollback, got %d", len(list))
}
}
+39
View File
@@ -5,6 +5,7 @@ import (
"github.com/google/uuid"
"github.com/vasyakrg/dns-autoresolver/internal/provider"
"github.com/vasyakrg/dns-autoresolver/internal/store/db"
"github.com/vasyakrg/dns-autoresolver/internal/store/dto"
)
@@ -109,6 +110,16 @@ func (s *Store) DeleteTemplate(ctx context.Context, id, projectID uuid.UUID) err
return s.q.DeleteTemplate(ctx, db.DeleteTemplateParams{ID: id, ProjectID: projectID})
}
// GetTemplate is a scoped lookup used to verify a template belongs to
// projectID before it is referenced elsewhere (e.g. CreateDomain).
func (s *Store) GetTemplate(ctx context.Context, id, projectID uuid.UUID) (Template, error) {
t, err := s.q.GetTemplate(ctx, db.GetTemplateParams{ID: id, ProjectID: projectID})
if err != nil {
return Template{}, err
}
return templateFromDB(t), nil
}
type Domain struct {
ID uuid.UUID
ProjectID uuid.UUID
@@ -151,3 +162,31 @@ func (s *Store) ListDomains(ctx context.Context, projectID uuid.UUID) ([]Domain,
func (s *Store) DeleteDomain(ctx context.Context, id, projectID uuid.UUID) error {
return s.q.DeleteDomain(ctx, db.DeleteDomainParams{ID: id, ProjectID: projectID})
}
// ImportDomains creates one domain per zone inside a single transaction: if
// any zone fails to be created, the whole batch is rolled back so callers
// never observe a partially-imported set of domains.
func (s *Store) ImportDomains(ctx context.Context, projectID, accountID uuid.UUID, zones []provider.Zone) ([]Domain, error) {
tx, err := s.pool.Begin(ctx)
if err != nil {
return nil, err
}
defer tx.Rollback(ctx) // no-op once Commit has succeeded
q := s.q.WithTx(tx)
out := make([]Domain, 0, len(zones))
for _, z := range zones {
d, err := q.CreateDomain(ctx, db.CreateDomainParams{
ID: uuid.New(), ProjectID: projectID, ProviderAccountID: accountID,
ZoneName: z.Name, ZoneID: z.ID, TemplateID: nil,
})
if err != nil {
return nil, err
}
out = append(out, domainFromDB(d))
}
if err := tx.Commit(ctx); err != nil {
return nil, err
}
return out, nil
}