fix(api): tenant-проверка account/template в CreateDomain (HIGH), атомарный import через транзакцию (MEDIUM)

This commit is contained in:
2026-07-03 15:08:16 +07:00
parent ae6a4d7f4c
commit 2aca92d070
5 changed files with 288 additions and 11 deletions
+167 -3
View File
@@ -3,6 +3,7 @@ package api
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
@@ -30,6 +31,10 @@ type mockTenantStore struct {
domains []store.Domain
createDomains int
importDomains []store.Domain
importDomainsErr error
importCalled bool
}
func (m *mockTenantStore) CreateAccount(_ context.Context, projectID uuid.UUID, prov, secretEnc, comment string) (store.Account, error) {
@@ -49,7 +54,7 @@ func (m *mockTenantStore) GetAccount(_ context.Context, id, _ uuid.UUID) (store.
return a, nil
}
}
return m.accounts[0], nil
return store.Account{}, errors.New("account not found")
}
func (m *mockTenantStore) DeleteAccount(context.Context, uuid.UUID, uuid.UUID) error { return nil }
@@ -71,6 +76,15 @@ func (m *mockTenantStore) UpdateTemplate(_ context.Context, id, projectID uuid.U
func (m *mockTenantStore) DeleteTemplate(context.Context, uuid.UUID, uuid.UUID) error { return nil }
func (m *mockTenantStore) GetTemplate(_ context.Context, id, _ uuid.UUID) (store.Template, error) {
for _, t := range m.templates {
if t.ID == id {
return t, nil
}
}
return store.Template{}, errors.New("template not found")
}
func (m *mockTenantStore) CreateDomain(_ context.Context, projectID, accountID uuid.UUID, zoneName, zoneID string, templateID *uuid.UUID) (store.Domain, error) {
m.createDomains++
d := store.Domain{ID: uuid.New(), ProjectID: projectID, ProviderAccountID: accountID, ZoneName: zoneName, ZoneID: zoneID, TemplateID: templateID}
@@ -84,6 +98,21 @@ func (m *mockTenantStore) ListDomains(context.Context, uuid.UUID) ([]store.Domai
func (m *mockTenantStore) DeleteDomain(context.Context, uuid.UUID, uuid.UUID) error { return nil }
func (m *mockTenantStore) ImportDomains(_ context.Context, projectID, accountID uuid.UUID, zones []provider.Zone) ([]store.Domain, error) {
m.importCalled = true
if m.importDomainsErr != nil {
return nil, m.importDomainsErr
}
out := make([]store.Domain, 0, len(zones))
for _, z := range zones {
d := store.Domain{ID: uuid.New(), ProjectID: projectID, ProviderAccountID: accountID, ZoneName: z.Name, ZoneID: z.ID}
out = append(out, d)
}
m.domains = append(m.domains, out...)
m.importDomains = out
return out, nil
}
type mockCipher struct{}
func (mockCipher) Encrypt(plaintext []byte) (string, error) { return "ENC(" + string(plaintext) + ")", nil }
@@ -260,8 +289,11 @@ func TestImportZones_CreatesDomainPerZone(t *testing.T) {
if w.Code != http.StatusCreated {
t.Fatalf("status %d body %s", w.Code, w.Body.String())
}
if ts.createDomains != 2 {
t.Fatalf("expected 2 CreateDomain calls, got %d", ts.createDomains)
if !ts.importCalled {
t.Fatal("expected ImportDomains to be called")
}
if len(ts.importDomains) != 2 {
t.Fatalf("expected 2 domains created via ImportDomains, got %d", len(ts.importDomains))
}
var resp []domainResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
@@ -272,6 +304,37 @@ func TestImportZones_CreatesDomainPerZone(t *testing.T) {
}
}
// TestImportZones_AtomicRollbackOnError verifies that when the store fails
// to import the batch (e.g. a mid-batch DB error), the handler surfaces a
// 500 and — per store.ImportDomains' transactional contract — no partial
// set of domains is left behind (modeled here by ImportDomains returning no
// domains alongside the error).
func TestImportZones_AtomicRollbackOnError(t *testing.T) {
a, ts := newTenantTestAPI()
accID := uuid.New()
ts.accounts = []store.Account{{ID: accID, Provider: "selectel", SecretEnc: "ENC(token)"}}
ts.importDomainsErr = errors.New("boom: mid-batch failure")
a.Reg = &mockRegistry{zones: []provider.Zone{
{ID: "z1", Name: "example.com"},
{ID: "z2", Name: "example.net"},
}}
router := NewRouter(a)
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d body %s", w.Code, w.Body.String())
}
if strings.Contains(w.Body.String(), "boom") {
t.Fatalf("internal error details leaked to response: %s", w.Body.String())
}
if len(ts.domains) != 0 {
t.Fatalf("expected no domains to be created on rollback, got %d", len(ts.domains))
}
}
func TestImportZones_BadAccountUUID(t *testing.T) {
a, _ := newTenantTestAPI()
router := NewRouter(a)
@@ -299,6 +362,107 @@ func TestCreateDomain_BadProjectUUID(t *testing.T) {
}
}
// TestCreateDomain_AccountNotFoundInProject covers the HIGH tenant-isolation
// fix: a providerAccountId that scoped GetAccount can't find within this
// project must be rejected before any domain is created — otherwise a
// caller could attach a domain to another tenant's provider account.
func TestCreateDomain_AccountNotFoundInProject(t *testing.T) {
a, ts := newTenantTestAPI()
router := NewRouter(a)
// ts.accounts is empty, so GetAccount will not find this id.
foreignAccID := uuid.New()
body := `{"providerAccountId":"` + foreignAccID.String() + `","zoneName":"example.com","zoneId":"z1"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d body %s", w.Code, w.Body.String())
}
if ts.createDomains != 0 {
t.Fatalf("expected CreateDomain not to be called, got %d calls", ts.createDomains)
}
}
// TestCreateDomain_TemplateNotFoundInProject covers the same isolation fix
// for the optional templateId: a template belonging to another project (or
// nonexistent) must reject the request before the domain is created.
func TestCreateDomain_TemplateNotFoundInProject(t *testing.T) {
a, ts := newTenantTestAPI()
accID := uuid.New()
ts.accounts = []store.Account{{ID: accID, Provider: "selectel", SecretEnc: "ENC(token)"}}
router := NewRouter(a)
foreignTplID := uuid.New()
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + foreignTplID.String() + `"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d body %s", w.Code, w.Body.String())
}
if ts.createDomains != 0 {
t.Fatalf("expected CreateDomain not to be called, got %d calls", ts.createDomains)
}
}
// TestCreateDomain_HappyPath ensures the tenant-isolation checks don't break
// the existing success path: a valid account in-project and no template.
func TestCreateDomain_HappyPath(t *testing.T) {
a, ts := newTenantTestAPI()
accID := uuid.New()
ts.accounts = []store.Account{{ID: accID, Provider: "selectel", SecretEnc: "ENC(token)"}}
router := NewRouter(a)
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d body %s", w.Code, w.Body.String())
}
if ts.createDomains != 1 {
t.Fatalf("expected 1 CreateDomain call, got %d", ts.createDomains)
}
var resp domainResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatal(err)
}
if resp.ZoneName != "example.com" || resp.TemplateID != nil {
t.Fatalf("unexpected response: %+v", resp)
}
}
// TestCreateDomain_ValidTemplateInProject ensures a template that scoped
// GetTemplate does find (i.e. belongs to this project) is accepted.
func TestCreateDomain_ValidTemplateInProject(t *testing.T) {
a, ts := newTenantTestAPI()
accID := uuid.New()
tplID := uuid.New()
ts.accounts = []store.Account{{ID: accID, Provider: "selectel", SecretEnc: "ENC(token)"}}
ts.templates = []store.Template{{ID: tplID, Name: "base"}}
router := NewRouter(a)
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + tplID.String() + `"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d body %s", w.Code, w.Body.String())
}
var resp domainResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatal(err)
}
if resp.TemplateID == nil || *resp.TemplateID != tplID.String() {
t.Fatalf("unexpected response: %+v", resp)
}
}
func TestDeleteDomain_BadUUID(t *testing.T) {
a, _ := newTenantTestAPI()
router := NewRouter(a)