diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..0d1fe41 --- /dev/null +++ b/Makefile @@ -0,0 +1,7 @@ +.PHONY: test +test: + go test ./... + +.PHONY: build +build: + go build ./... diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..d340757 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/vasyakrg/dns-autoresolver + +go 1.26.4 diff --git a/internal/diff/diff.go b/internal/diff/diff.go new file mode 100644 index 0000000..2d1edef --- /dev/null +++ b/internal/diff/diff.go @@ -0,0 +1,79 @@ +package diff + +import "github.com/vasyakrg/dns-autoresolver/internal/model" + +type ChangeKind string + +const ( + InSync ChangeKind = "in_sync" + Add ChangeKind = "add" + Update ChangeKind = "update" + Delete ChangeKind = "delete" +) + +// RecordDiff describes one RRset's deviation between template and zone. +type RecordDiff struct { + Kind ChangeKind + Type model.RecordType + Name string + Desired *model.Record // nil for Delete + Actual *model.Record // nil for Add + ReadOnly bool // NS/SOA — shown but never applied +} + +type Changeset struct { + Diffs []RecordDiff +} + +// Actionable returns managed diffs that are not in sync. +func (c Changeset) Actionable() []RecordDiff { + var out []RecordDiff + for _, d := range c.Diffs { + if d.ReadOnly || d.Kind == InSync { + continue + } + out = append(out, d) + } + return out +} + +// Diff compares a template against the actual zone records. +// Records present in the zone but absent from the template yield Delete. +func Diff(template, actual []model.Record) Changeset { + current := index(actual) + seen := make(map[string]bool, len(template)) + var diffs []RecordDiff + + for _, t := range template { + tt := t + key := tt.Key() + seen[key] = true + ro := !tt.Type.Managed() + if a, ok := current[key]; ok { + ac := a + kind := Update + if tt.Equal(ac) { + kind = InSync + } + diffs = append(diffs, RecordDiff{Kind: kind, Type: tt.Type, Name: tt.Name, Desired: &tt, Actual: &ac, ReadOnly: ro}) + } else { + diffs = append(diffs, RecordDiff{Kind: Add, Type: tt.Type, Name: tt.Name, Desired: &tt, ReadOnly: ro}) + } + } + for _, a := range actual { + ac := a + if seen[ac.Key()] { + continue + } + diffs = append(diffs, RecordDiff{Kind: Delete, Type: ac.Type, Name: ac.Name, Actual: &ac, ReadOnly: !ac.Type.Managed()}) + } + return Changeset{Diffs: diffs} +} + +func index(recs []model.Record) map[string]model.Record { + m := make(map[string]model.Record, len(recs)) + for _, r := range recs { + m[r.Key()] = r + } + return m +} diff --git a/internal/diff/diff_test.go b/internal/diff/diff_test.go new file mode 100644 index 0000000..eac6b43 --- /dev/null +++ b/internal/diff/diff_test.go @@ -0,0 +1,115 @@ +package diff + +import ( + "testing" + + "github.com/vasyakrg/dns-autoresolver/internal/model" +) + +func find(cs Changeset, key string) *RecordDiff { + for i := range cs.Diffs { + d := cs.Diffs[i] + var r *model.Record + if d.Desired != nil { + r = d.Desired + } else { + r = d.Actual + } + if r.Key() == key { + return &cs.Diffs[i] + } + } + return nil +} + +func TestDiffAddUpdateDeleteInSync(t *testing.T) { + tmpl := []model.Record{ + {Type: model.A, Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}}, // in sync + {Type: model.A, Name: "b.example.com.", TTL: 300, Values: []string{"2.2.2.2"}}, // update + {Type: model.A, Name: "c.example.com.", TTL: 300, Values: []string{"3.3.3.3"}}, // add + } + actual := []model.Record{ + {Type: model.A, Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}}, + {Type: model.A, Name: "b.example.com.", TTL: 300, Values: []string{"9.9.9.9"}}, + {Type: model.A, Name: "d.example.com.", TTL: 300, Values: []string{"4.4.4.4"}}, // delete (extra) + } + cs := Diff(tmpl, actual) + + if d := find(cs, "A a.example.com."); d == nil || d.Kind != InSync { + t.Fatalf("a should be InSync, got %+v", d) + } + if d := find(cs, "A b.example.com."); d == nil || d.Kind != Update { + t.Fatalf("b should be Update, got %+v", d) + } + if d := find(cs, "A c.example.com."); d == nil || d.Kind != Add { + t.Fatalf("c should be Add, got %+v", d) + } + if d := find(cs, "A d.example.com."); d == nil || d.Kind != Delete { + t.Fatalf("d should be Delete, got %+v", d) + } +} + +func TestDiffMarksReadOnlyForNSSOA(t *testing.T) { + tmpl := []model.Record{{Type: model.NS, Name: "example.com.", TTL: 3600, Values: []string{"ns1.example.com."}}} + actual := []model.Record{{Type: model.NS, Name: "example.com.", TTL: 3600, Values: []string{"ns9.other.com."}}} + cs := Diff(tmpl, actual) + d := find(cs, "NS example.com.") + if d == nil || d.Kind != Update || !d.ReadOnly { + t.Fatalf("NS diff must be Update and ReadOnly, got %+v", d) + } +} + +// Global Constraint: an empty/nil template must not silently no-op — every managed +// record in the zone must surface as a Delete, while read-only records (NS/SOA) +// stay ReadOnly and excluded from Actionable(). This guards against mass deletion +// bugs where a missing template accidentally wipes the zone unattended. +func TestDiffEmptyTemplateDeletesAllManagedKeepsNSReadOnly(t *testing.T) { + actual := []model.Record{ + {Type: model.A, Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}}, + {Type: model.A, Name: "b.example.com.", TTL: 300, Values: []string{"2.2.2.2"}}, + {Type: model.NS, Name: "example.com.", TTL: 3600, Values: []string{"ns1.example.com."}}, + } + cs := Diff(nil, actual) + + if len(cs.Diffs) != 3 { + t.Fatalf("expected 3 diffs (2 A deletes + 1 NS), got %d: %+v", len(cs.Diffs), cs.Diffs) + } + + da := find(cs, "A a.example.com.") + if da == nil || da.Kind != Delete || da.ReadOnly { + t.Fatalf("A a.example.com. must be a non-read-only Delete, got %+v", da) + } + db := find(cs, "A b.example.com.") + if db == nil || db.Kind != Delete || db.ReadOnly { + t.Fatalf("A b.example.com. must be a non-read-only Delete, got %+v", db) + } + dns := find(cs, "NS example.com.") + if dns == nil || dns.Kind != Delete || !dns.ReadOnly { + t.Fatalf("NS example.com. must be a ReadOnly Delete, got %+v", dns) + } + + act := cs.Actionable() + if len(act) != 2 { + t.Fatalf("expected 2 actionable deletes (A records only), got %d: %+v", len(act), act) + } + for _, d := range act { + if d.Type == model.NS { + t.Fatalf("NS must be excluded from Actionable(), got %+v", d) + } + } +} + +func TestActionableExcludesInSyncAndReadOnly(t *testing.T) { + tmpl := []model.Record{ + {Type: model.A, Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}}, // in sync + {Type: model.A, Name: "b.example.com.", TTL: 300, Values: []string{"2.2.2.2"}}, // add + {Type: model.NS, Name: "example.com.", TTL: 3600, Values: []string{"ns1.example.com."}}, // read-only add + } + actual := []model.Record{ + {Type: model.A, Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}}, + } + act := Diff(tmpl, actual).Actionable() + if len(act) != 1 || act[0].Name != "b.example.com." { + t.Fatalf("only b.example.com. is actionable, got %+v", act) + } +} diff --git a/internal/model/record.go b/internal/model/record.go new file mode 100644 index 0000000..8aa2d0b --- /dev/null +++ b/internal/model/record.go @@ -0,0 +1,106 @@ +package model + +import ( + "sort" + "strings" +) + +type RecordType string + +const ( + A RecordType = "A" + AAAA RecordType = "AAAA" + CNAME RecordType = "CNAME" + MX RecordType = "MX" + TXT RecordType = "TXT" + SRV RecordType = "SRV" + NS RecordType = "NS" + SOA RecordType = "SOA" +) + +// Managed reports whether the type participates in diff+apply. +// NS and SOA are read-only. +func (t RecordType) Managed() bool { + switch t { + case A, AAAA, CNAME, MX, TXT, SRV: + return true + default: + return false + } +} + +// Record is the provider-neutral representation of a DNS RRset. +// For MX the value is " "; for SRV it is +// " ". Values is an unordered set. +type Record struct { + Type RecordType + Name string + TTL int + Values []string +} + +// Key uniquely identifies an RRset within a zone. +func (r Record) Key() string { + return string(r.Type) + " " + normalizeName(r.Name) +} + +func normalizeName(name string) string { + n := strings.ToLower(strings.TrimSpace(name)) + if n != "" && !strings.HasSuffix(n, ".") { + n += "." + } + return n +} + +// normalizeValue canonicalizes a single RR value for comparison. +func normalizeValue(t RecordType, content string) string { + if t == TXT { + return content // byte-exact — case and whitespace are significant (DKIM/SPF/DMARC) + } + c := strings.Join(strings.Fields(content), " ") // collapse whitespace + switch t { + case MX: + parts := strings.SplitN(c, " ", 2) + if len(parts) == 2 { + return parts[0] + " " + normalizeName(parts[1]) + } + return c + case SRV: + f := strings.Fields(c) + if len(f) == 4 { + return f[0] + " " + f[1] + " " + f[2] + " " + normalizeName(f[3]) + } + return c + case CNAME, NS: + return normalizeName(c) + default: // A, AAAA, SOA + return strings.ToLower(c) + } +} + +// NormalizedValues returns sorted, normalized values. +func (r Record) NormalizedValues() []string { + out := make([]string, len(r.Values)) + for i, v := range r.Values { + out[i] = normalizeValue(r.Type, v) + } + sort.Strings(out) + return out +} + +// Equal reports whether two records have the same TTL and value set. +func (r Record) Equal(o Record) bool { + if r.TTL != o.TTL { + return false + } + a, b := r.NormalizedValues(), o.NormalizedValues() + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/internal/model/record_test.go b/internal/model/record_test.go new file mode 100644 index 0000000..4df2f43 --- /dev/null +++ b/internal/model/record_test.go @@ -0,0 +1,116 @@ +package model + +import "testing" + +func TestManaged(t *testing.T) { + managed := []RecordType{A, AAAA, CNAME, MX, TXT, SRV} + for _, rt := range managed { + if !rt.Managed() { + t.Errorf("%s should be managed", rt) + } + } + for _, rt := range []RecordType{NS, SOA} { + if rt.Managed() { + t.Errorf("%s should be read-only", rt) + } + } +} + +func TestKeyNormalizesName(t *testing.T) { + r1 := Record{Type: A, Name: "www.Example.com"} + r2 := Record{Type: A, Name: "www.example.com."} + if r1.Key() != r2.Key() { + t.Fatalf("keys differ: %q vs %q", r1.Key(), r2.Key()) + } + if r1.Key() != "A www.example.com." { + t.Fatalf("unexpected key %q", r1.Key()) + } +} + +func TestEqualMXPriorityAndOrder(t *testing.T) { + a := Record{Type: MX, Name: "example.com.", TTL: 3600, Values: []string{"10 mx1.example.com.", "20 mx2.Example.com."}} + b := Record{Type: MX, Name: "example.com.", TTL: 3600, Values: []string{"20 mx2.example.com.", "10 mx1.example.com."}} + if !a.Equal(b) { + t.Fatal("MX records equal regardless of order and target case") + } + c := Record{Type: MX, Name: "example.com.", TTL: 3600, Values: []string{"30 mx1.example.com."}} + if a.Equal(c) { + t.Fatal("different priority must not be equal") + } + + // Isolated case: same value count and same target, only priority differs — + // must fail on priority comparison, not on a length mismatch shortcut. + d := Record{Type: MX, Name: "example.com.", TTL: 3600, Values: []string{"10 mx1.example.com."}} + e := Record{Type: MX, Name: "example.com.", TTL: 3600, Values: []string{"20 mx1.example.com."}} + if d.Equal(e) { + t.Fatal("different MX priority with same target and value count must not be equal") + } +} + +func TestEqualSRVBasic(t *testing.T) { + a := Record{Type: SRV, Name: "_sip._tcp.example.com.", TTL: 3600, Values: []string{ + "10 20 5060 sipserver.example.com.", + "5 10 5061 backup.example.com.", + }} + b := Record{Type: SRV, Name: "_sip._tcp.example.com.", TTL: 3600, Values: []string{ + "5 10 5061 BACKUP.Example.COM.", + "10 20 5060 SIPServer.Example.com.", + }} + if !a.Equal(b) { + t.Fatal("SRV records equal regardless of order and target case") + } + + // Isolated: identical single value except priority differs. + c1 := Record{Type: SRV, Name: "_sip._tcp.example.com.", TTL: 3600, Values: []string{"10 20 5060 sipserver.example.com."}} + c2 := Record{Type: SRV, Name: "_sip._tcp.example.com.", TTL: 3600, Values: []string{"20 20 5060 sipserver.example.com."}} + if c1.Equal(c2) { + t.Fatal("different SRV priority must not be equal") + } + + // Isolated: identical single value except port differs. + d1 := Record{Type: SRV, Name: "_sip._tcp.example.com.", TTL: 3600, Values: []string{"10 20 5060 sipserver.example.com."}} + d2 := Record{Type: SRV, Name: "_sip._tcp.example.com.", TTL: 3600, Values: []string{"10 20 5061 sipserver.example.com."}} + if d1.Equal(d2) { + t.Fatal("different SRV port must not be equal") + } +} + +func TestNormalizeValueIncompleteNoPanic(t *testing.T) { + // MX value missing the target field. + a := Record{Type: MX, Name: "example.com.", TTL: 300, Values: []string{"10"}} + b := Record{Type: MX, Name: "example.com.", TTL: 300, Values: []string{"10"}} + if !a.Equal(b) { + t.Fatal("incomplete MX values with identical content should be equal, not panic") + } + + // SRV value missing port and target fields. + c := Record{Type: SRV, Name: "_sip._tcp.example.com.", TTL: 300, Values: []string{"10 20"}} + d := Record{Type: SRV, Name: "_sip._tcp.example.com.", TTL: 300, Values: []string{"10 20"}} + if !c.Equal(d) { + t.Fatal("incomplete SRV values with identical content should be equal, not panic") + } +} + +func TestEqualTXTCaseSensitive(t *testing.T) { + a := Record{Type: TXT, Name: "example.com.", TTL: 60, Values: []string{"v=DKIM1; p=AbC"}} + b := Record{Type: TXT, Name: "example.com.", TTL: 60, Values: []string{"v=DKIM1; p=abc"}} + if a.Equal(b) { + t.Fatal("TXT is case-sensitive") + } +} + +func TestEqualTXTWhitespaceSignificant(t *testing.T) { + a := Record{Type: TXT, Name: "example.com.", TTL: 60, Values: []string{"v=spf1 a"}} + b := Record{Type: TXT, Name: "example.com.", TTL: 60, Values: []string{"v=spf1 a"}} + if a.Equal(b) { + t.Fatal("TXT records differing only in whitespace count must not be equal (byte-exact comparison)") + } +} + +func TestEqualTTLMatters(t *testing.T) { + a := Record{Type: A, Name: "example.com.", TTL: 300, Values: []string{"1.2.3.4"}} + b := Record{Type: A, Name: "example.com.", TTL: 600, Values: []string{"1.2.3.4"}} + if a.Equal(b) { + t.Fatal("different TTL must not be equal") + } +} diff --git a/internal/provider/provider.go b/internal/provider/provider.go new file mode 100644 index 0000000..f2a8f23 --- /dev/null +++ b/internal/provider/provider.go @@ -0,0 +1,28 @@ +package provider + +import ( + "context" + + "github.com/vasyakrg/dns-autoresolver/internal/diff" + "github.com/vasyakrg/dns-autoresolver/internal/model" +) + +// Credentials holds the secret used to authenticate against a provider. +// For Selectel this is the project-scoped token sent as X-Auth-Token. +type Credentials struct { + Secret string +} + +// Zone is a provider-neutral DNS zone reference. +type Zone struct { + ID string + Name string +} + +// Provider is implemented per DNS provider (Selectel first). +type Provider interface { + Name() string + ListZones(ctx context.Context, creds Credentials) ([]Zone, error) + GetRecords(ctx context.Context, creds Credentials, zoneID string) ([]model.Record, error) + ApplyChanges(ctx context.Context, creds Credentials, zoneID string, cs diff.Changeset) error +} diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go new file mode 100644 index 0000000..800aad0 --- /dev/null +++ b/internal/provider/provider_test.go @@ -0,0 +1,31 @@ +package provider + +import ( + "context" + "testing" + + "github.com/vasyakrg/dns-autoresolver/internal/diff" + "github.com/vasyakrg/dns-autoresolver/internal/model" +) + +// stubProvider проверяет, что интерфейс реализуем. +type stubProvider struct{} + +func (stubProvider) Name() string { return "stub" } +func (stubProvider) ListZones(context.Context, Credentials) ([]Zone, error) { + return []Zone{{ID: "1", Name: "example.com."}}, nil +} +func (stubProvider) GetRecords(context.Context, Credentials, string) ([]model.Record, error) { + return nil, nil +} +func (stubProvider) ApplyChanges(context.Context, Credentials, string, diff.Changeset) error { + return nil +} + +func TestProviderInterfaceSatisfied(t *testing.T) { + var p Provider = stubProvider{} + zs, err := p.ListZones(context.Background(), Credentials{Secret: "x"}) + if err != nil || len(zs) != 1 || zs[0].Name != "example.com." { + t.Fatalf("unexpected: %v %v", zs, err) + } +} diff --git a/internal/provider/selectel/selectel.go b/internal/provider/selectel/selectel.go new file mode 100644 index 0000000..91c4949 --- /dev/null +++ b/internal/provider/selectel/selectel.go @@ -0,0 +1,215 @@ +package selectel + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "github.com/vasyakrg/dns-autoresolver/internal/diff" + "github.com/vasyakrg/dns-autoresolver/internal/model" + "github.com/vasyakrg/dns-autoresolver/internal/provider" +) + +const DefaultBaseURL = "https://api.selectel.ru/domains/v2" + +// Client implements provider.Provider for Selectel DNS API v2. +type Client struct { + BaseURL string + HTTP *http.Client +} + +func New() *Client { + return &Client{BaseURL: DefaultBaseURL, HTTP: &http.Client{Timeout: 30 * time.Second}} +} + +func (c *Client) Name() string { return "selectel" } + +// --- wire types --- + +type apiZone struct { + ID string `json:"id"` + Name string `json:"name"` +} +type apiZoneList struct { + Result []apiZone `json:"result"` + NextOffset int `json:"next_offset"` +} +type apiRec struct { + Content string `json:"content"` + Disabled bool `json:"disabled,omitempty"` +} +type apiRRSet struct { + ID string `json:"id,omitempty"` + Name string `json:"name"` + Type string `json:"type"` + TTL int `json:"ttl"` + Records []apiRec `json:"records"` +} +type apiRRSetList struct { + Result []apiRRSet `json:"result"` + NextOffset int `json:"next_offset"` +} + +// --- HTTP helper --- + +func (c *Client) do(ctx context.Context, method, path, token string, body any, out any) error { + var reader io.Reader + if body != nil { + b, err := json.Marshal(body) + if err != nil { + return err + } + reader = bytes.NewReader(b) + } + req, err := http.NewRequestWithContext(ctx, method, c.BaseURL+path, reader) + if err != nil { + return err + } + req.Header.Set("X-Auth-Token", token) + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + resp, err := c.HTTP.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode >= 300 { + msg, _ := io.ReadAll(resp.Body) + return fmt.Errorf("selectel %s %s: %d: %s", method, path, resp.StatusCode, string(msg)) + } + if out != nil { + return json.NewDecoder(resp.Body).Decode(out) + } + return nil +} + +// --- Provider implementation --- + +func (c *Client) ListZones(ctx context.Context, creds provider.Credentials) ([]provider.Zone, error) { + var zones []provider.Zone + offset := 0 + for { + var page apiZoneList + path := fmt.Sprintf("/zones?limit=1000&offset=%d", offset) + if err := c.do(ctx, http.MethodGet, path, creds.Secret, nil, &page); err != nil { + return nil, err + } + for _, z := range page.Result { + zones = append(zones, provider.Zone{ID: z.ID, Name: z.Name}) + } + if page.NextOffset == 0 || len(page.Result) == 0 { + break + } + offset = page.NextOffset + } + return zones, nil +} + +func (c *Client) GetRecords(ctx context.Context, creds provider.Credentials, zoneID string) ([]model.Record, error) { + rrsets, err := c.listRRSets(ctx, creds.Secret, zoneID) + if err != nil { + return nil, err + } + recs := make([]model.Record, 0, len(rrsets)) + for _, rr := range rrsets { + recs = append(recs, toRecord(rr)) + } + return recs, nil +} + +func (c *Client) listRRSets(ctx context.Context, token, zoneID string) ([]apiRRSet, error) { + var all []apiRRSet + offset := 0 + for { + var page apiRRSetList + path := fmt.Sprintf("/zones/%s/rrset?limit=1000&offset=%d", url.PathEscape(zoneID), offset) + if err := c.do(ctx, http.MethodGet, path, token, nil, &page); err != nil { + return nil, err + } + all = append(all, page.Result...) + if page.NextOffset == 0 || len(page.Result) == 0 { + break + } + offset = page.NextOffset + } + return all, nil +} + +func (c *Client) ApplyChanges(ctx context.Context, creds provider.Credentials, zoneID string, cs diff.Changeset) error { + // resolve rrset ids for update/delete + existing, err := c.listRRSets(ctx, creds.Secret, zoneID) + if err != nil { + return err + } + idByKey := make(map[string]string, len(existing)) + for _, rr := range existing { + idByKey[toRecord(rr).Key()] = rr.ID + } + + base := "/zones/" + url.PathEscape(zoneID) + "/rrset" + for _, d := range cs.Diffs { + if d.ReadOnly || d.Kind == diff.InSync { + continue + } + switch d.Kind { + case diff.Add: + if d.Desired == nil { + return fmt.Errorf("selectel: add/update diff without Desired record") + } + if err := c.do(ctx, http.MethodPost, base, creds.Secret, toRRSet(*d.Desired), nil); err != nil { + return err + } + case diff.Update: + if d.Desired == nil { + return fmt.Errorf("selectel: add/update diff without Desired record") + } + id, ok := idByKey[d.Desired.Key()] + if !ok { + return fmt.Errorf("cannot update: rrset %s not found in zone", d.Desired.Key()) + } + if err := c.do(ctx, http.MethodPatch, base+"/"+url.PathEscape(id), creds.Secret, toRRSet(*d.Desired), nil); err != nil { + return err + } + case diff.Delete: + if d.Actual == nil { + return fmt.Errorf("selectel: delete diff without Actual record") + } + id, ok := idByKey[d.Actual.Key()] + if !ok { + return fmt.Errorf("cannot delete: rrset %s not found in zone", d.Actual.Key()) + } + if err := c.do(ctx, http.MethodDelete, base+"/"+url.PathEscape(id), creds.Secret, nil, nil); err != nil { + return err + } + } + } + return nil +} + +func toRecord(rr apiRRSet) model.Record { + vals := make([]string, 0, len(rr.Records)) + for _, r := range rr.Records { + if r.Disabled { + continue + } + vals = append(vals, r.Content) + } + return model.Record{Type: model.RecordType(rr.Type), Name: rr.Name, TTL: rr.TTL, Values: vals} +} + +func toRRSet(rec model.Record) apiRRSet { + rs := apiRRSet{Name: rec.Name, Type: string(rec.Type), TTL: rec.TTL} + for _, v := range rec.Values { + rs.Records = append(rs.Records, apiRec{Content: v}) + } + return rs +} + +// compile-time check +var _ provider.Provider = (*Client)(nil) diff --git a/internal/provider/selectel/selectel_test.go b/internal/provider/selectel/selectel_test.go new file mode 100644 index 0000000..4b08519 --- /dev/null +++ b/internal/provider/selectel/selectel_test.go @@ -0,0 +1,373 @@ +package selectel + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/vasyakrg/dns-autoresolver/internal/diff" + "github.com/vasyakrg/dns-autoresolver/internal/model" + "github.com/vasyakrg/dns-autoresolver/internal/provider" +) + +func creds() provider.Credentials { return provider.Credentials{Secret: "secret-token"} } + +func newTestClient(h http.Handler) (*Client, *httptest.Server) { + srv := httptest.NewServer(h) + return &Client{BaseURL: srv.URL, HTTP: srv.Client()}, srv +} + +func TestListZonesSendsTokenAndParses(t *testing.T) { + var gotToken string + c, srv := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotToken = r.Header.Get("X-Auth-Token") + json.NewEncoder(w).Encode(map[string]any{ + "result": []map[string]any{ + {"id": "z1", "name": "example.com."}, + {"id": "z2", "name": "test.org."}, + }, + "next_offset": 0, + }) + })) + defer srv.Close() + + zs, err := c.ListZones(context.Background(), creds()) + if err != nil { + t.Fatal(err) + } + if gotToken != "secret-token" { + t.Fatalf("token not sent, got %q", gotToken) + } + if len(zs) != 2 || zs[0].ID != "z1" || zs[1].Name != "test.org." { + t.Fatalf("unexpected zones: %+v", zs) + } +} + +func TestGetRecordsMapsRRSet(t *testing.T) { + c, srv := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]any{ + "result": []map[string]any{ + {"id": "r1", "name": "example.com.", "type": "MX", "ttl": 3600, + "records": []map[string]any{{"content": "10 mx1.example.com.", "disabled": false}}}, + {"id": "r2", "name": "www.example.com.", "type": "A", "ttl": 300, + "records": []map[string]any{{"content": "1.2.3.4"}, {"content": "5.6.7.8", "disabled": true}}}, + }, + "next_offset": 0, + }) + })) + defer srv.Close() + + recs, err := c.GetRecords(context.Background(), creds(), "z1") + if err != nil { + t.Fatal(err) + } + if len(recs) != 2 { + t.Fatalf("want 2 records, got %d", len(recs)) + } + var a model.Record + for _, r := range recs { + if r.Type == model.A { + a = r + } + } + // disabled record dropped -> only one value + if len(a.Values) != 1 || a.Values[0] != "1.2.3.4" { + t.Fatalf("disabled record must be skipped, got %+v", a.Values) + } +} + +func TestApplyChangesRoutesVerbs(t *testing.T) { + type call struct{ method, path string } + var calls []call + c, srv := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // GET rrset -> return existing set with ids for update/delete resolution + if r.Method == http.MethodGet { + json.NewEncoder(w).Encode(map[string]any{ + "result": []map[string]any{ + {"id": "up1", "name": "b.example.com.", "type": "A", "ttl": 300, + "records": []map[string]any{{"content": "9.9.9.9"}}}, + {"id": "del1", "name": "d.example.com.", "type": "A", "ttl": 300, + "records": []map[string]any{{"content": "4.4.4.4"}}}, + }, + "next_offset": 0, + }) + return + } + calls = append(calls, call{r.Method, r.URL.Path}) + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + add := model.Record{Type: model.A, Name: "c.example.com.", TTL: 300, Values: []string{"3.3.3.3"}} + updDesired := model.Record{Type: model.A, Name: "b.example.com.", TTL: 300, Values: []string{"2.2.2.2"}} + delActual := model.Record{Type: model.A, Name: "d.example.com.", TTL: 300, Values: []string{"4.4.4.4"}} + ns := model.Record{Type: model.NS, Name: "example.com.", TTL: 3600, Values: []string{"ns1.example.com."}} + + cs := diff.Changeset{Diffs: []diff.RecordDiff{ + {Kind: diff.Add, Type: add.Type, Name: add.Name, Desired: &add}, + {Kind: diff.Update, Type: updDesired.Type, Name: updDesired.Name, Desired: &updDesired}, + {Kind: diff.Delete, Type: delActual.Type, Name: delActual.Name, Actual: &delActual}, + {Kind: diff.Update, Type: ns.Type, Name: ns.Name, Desired: &ns, ReadOnly: true}, // must be skipped + }} + + if err := c.ApplyChanges(context.Background(), creds(), "z1", cs); err != nil { + t.Fatal(err) + } + + want := map[string]bool{ + "POST /zones/z1/rrset": true, + "PATCH /zones/z1/rrset/up1": true, + "DELETE /zones/z1/rrset/del1": true, + } + if len(calls) != len(want) { + t.Fatalf("want %d calls, got %v", len(want), calls) + } + for _, cl := range calls { + if !want[cl.method+" "+cl.path] { + t.Fatalf("unexpected call %s %s", cl.method, cl.path) + } + } +} + +// Global Constraint: id not found for Update -> error, and mutation must not proceed. +func TestApplyChangesUpdateIDNotFoundReturnsErrorAndSkipsMutation(t *testing.T) { + var calls []string + c, srv := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + // empty existing rrset set -> nothing resolves to an id + json.NewEncoder(w).Encode(map[string]any{"result": []map[string]any{}, "next_offset": 0}) + return + } + calls = append(calls, r.Method+" "+r.URL.Path) + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + missing := model.Record{Type: model.A, Name: "missing.example.com.", TTL: 300, Values: []string{"1.1.1.1"}} + add := model.Record{Type: model.A, Name: "new.example.com.", TTL: 300, Values: []string{"2.2.2.2"}} + + cs := diff.Changeset{Diffs: []diff.RecordDiff{ + {Kind: diff.Update, Type: missing.Type, Name: missing.Name, Desired: &missing}, + {Kind: diff.Add, Type: add.Type, Name: add.Name, Desired: &add}, + }} + + err := c.ApplyChanges(context.Background(), creds(), "z1", cs) + if err == nil { + t.Fatal("expected non-nil error when update rrset id is not found") + } + if len(calls) != 0 { + t.Fatalf("expected no mutating requests to be sent, got %v", calls) + } +} + +// Global Constraint: id not found for Delete -> error. +func TestApplyChangesDeleteIDNotFoundReturnsError(t *testing.T) { + c, srv := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + json.NewEncoder(w).Encode(map[string]any{"result": []map[string]any{}, "next_offset": 0}) + return + } + t.Fatalf("unexpected mutating call %s %s, delete should have errored before reaching HTTP", r.Method, r.URL.Path) + })) + defer srv.Close() + + missing := model.Record{Type: model.A, Name: "missing.example.com.", TTL: 300, Values: []string{"1.1.1.1"}} + cs := diff.Changeset{Diffs: []diff.RecordDiff{ + {Kind: diff.Delete, Type: missing.Type, Name: missing.Name, Actual: &missing}, + }} + + if err := c.ApplyChanges(context.Background(), creds(), "z1", cs); err == nil { + t.Fatal("expected non-nil error when delete rrset id is not found") + } +} + +// Global Constraint: X-Auth-Token must be sent on mutating requests (POST/PATCH/DELETE), not only on GET. +func TestApplyChangesSendsTokenOnMutations(t *testing.T) { + var tokens []string + c, srv := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + json.NewEncoder(w).Encode(map[string]any{ + "result": []map[string]any{ + {"id": "up1", "name": "b.example.com.", "type": "A", "ttl": 300, + "records": []map[string]any{{"content": "9.9.9.9"}}}, + {"id": "del1", "name": "d.example.com.", "type": "A", "ttl": 300, + "records": []map[string]any{{"content": "4.4.4.4"}}}, + }, + "next_offset": 0, + }) + return + } + tokens = append(tokens, r.Header.Get("X-Auth-Token")) + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + add := model.Record{Type: model.A, Name: "c.example.com.", TTL: 300, Values: []string{"3.3.3.3"}} + updDesired := model.Record{Type: model.A, Name: "b.example.com.", TTL: 300, Values: []string{"2.2.2.2"}} + delActual := model.Record{Type: model.A, Name: "d.example.com.", TTL: 300, Values: []string{"4.4.4.4"}} + + cs := diff.Changeset{Diffs: []diff.RecordDiff{ + {Kind: diff.Add, Type: add.Type, Name: add.Name, Desired: &add}, + {Kind: diff.Update, Type: updDesired.Type, Name: updDesired.Name, Desired: &updDesired}, + {Kind: diff.Delete, Type: delActual.Type, Name: delActual.Name, Actual: &delActual}, + }} + + if err := c.ApplyChanges(context.Background(), creds(), "z1", cs); err != nil { + t.Fatal(err) + } + if len(tokens) != 3 { + t.Fatalf("expected 3 mutating requests (POST/PATCH/DELETE), got %d", len(tokens)) + } + for _, tok := range tokens { + if tok != "secret-token" { + t.Fatalf("expected X-Auth-Token %q on mutation, got %q", "secret-token", tok) + } + } +} + +// Global Constraint: multi-page pagination must accumulate records across pages without an infinite loop. +func TestListZonesPaginatesAcrossMultiplePages(t *testing.T) { + var offsets []string + c, srv := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + offset := r.URL.Query().Get("offset") + offsets = append(offsets, offset) + if len(offsets) > 2 { + t.Fatalf("too many requests, possible infinite pagination loop: %v", offsets) + } + switch offset { + case "0": + json.NewEncoder(w).Encode(map[string]any{ + "result": []map[string]any{{"id": "z1", "name": "first.example.com."}}, + "next_offset": 1000, + }) + case "1000": + json.NewEncoder(w).Encode(map[string]any{ + "result": []map[string]any{{"id": "z2", "name": "second.example.com."}}, + "next_offset": 0, + }) + default: + t.Fatalf("unexpected offset %q", offset) + } + })) + defer srv.Close() + + zs, err := c.ListZones(context.Background(), creds()) + if err != nil { + t.Fatal(err) + } + if len(offsets) != 2 { + t.Fatalf("expected exactly 2 page requests, got %d: %v", len(offsets), offsets) + } + if len(zs) != 2 || zs[0].ID != "z1" || zs[1].ID != "z2" { + t.Fatalf("expected accumulated zones from both pages, got %+v", zs) + } +} + +// Global Constraint: listRRSets (via GetRecords) must paginate across multiple pages, +// accumulating records from every page, and must stop as soon as next_offset is 0 — +// no third request should ever be issued. +func TestGetRecordsPaginatesAcrossMultiplePages(t *testing.T) { + var offsets []string + c, srv := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + offset := r.URL.Query().Get("offset") + offsets = append(offsets, offset) + if len(offsets) > 2 { + t.Fatalf("too many requests, possible infinite pagination loop: %v", offsets) + } + switch offset { + case "0": + json.NewEncoder(w).Encode(map[string]any{ + "result": []map[string]any{ + {"id": "r1", "name": "a.example.com.", "type": "A", "ttl": 300, + "records": []map[string]any{{"content": "1.1.1.1"}}}, + }, + "next_offset": 1000, + }) + case "1000": + json.NewEncoder(w).Encode(map[string]any{ + "result": []map[string]any{ + {"id": "r2", "name": "b.example.com.", "type": "A", "ttl": 300, + "records": []map[string]any{{"content": "2.2.2.2"}}}, + }, + "next_offset": 0, + }) + default: + t.Fatalf("unexpected offset %q", offset) + } + })) + defer srv.Close() + + recs, err := c.GetRecords(context.Background(), creds(), "z1") + if err != nil { + t.Fatal(err) + } + if len(offsets) != 2 { + t.Fatalf("expected exactly 2 page requests, got %d: %v", len(offsets), offsets) + } + if len(recs) != 2 { + t.Fatalf("expected accumulated records from both pages, got %+v", recs) + } + names := map[string]bool{recs[0].Name: true, recs[1].Name: true} + if !names["a.example.com."] || !names["b.example.com."] { + t.Fatalf("expected records from both pages, got %+v", recs) + } +} + +// Global Constraint: ApplyChanges must not panic on a Changeset with a nil Desired +// record for Add/Update, and must instead return a clear error. +func TestApplyChangesAddWithNilDesiredReturnsErrorNoPanic(t *testing.T) { + c, srv := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + json.NewEncoder(w).Encode(map[string]any{"result": []map[string]any{}, "next_offset": 0}) + return + } + t.Fatalf("unexpected mutating call %s %s, nil Desired should have errored before reaching HTTP", r.Method, r.URL.Path) + })) + defer srv.Close() + + cs := diff.Changeset{Diffs: []diff.RecordDiff{ + {Kind: diff.Add, Type: model.A, Name: "nil-desired.example.com.", Desired: nil}, + }} + + defer func() { + if r := recover(); r != nil { + t.Fatalf("ApplyChanges panicked on nil Desired: %v", r) + } + }() + + err := c.ApplyChanges(context.Background(), creds(), "z1", cs) + if err == nil { + t.Fatal("expected non-nil error for Add diff with nil Desired") + } +} + +// Global Constraint: HTTP errors (status >= 300) must surface a non-nil error whose text +// includes the method/path/status (or response body) for diagnosability. +func TestListZonesHTTPErrorIncludesMethodPathStatus(t *testing.T) { + c, srv := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("zone not found")) + })) + defer srv.Close() + + _, err := c.ListZones(context.Background(), creds()) + if err == nil { + t.Fatal("expected non-nil error on non-2xx response") + } + msg := err.Error() + if !strings.Contains(msg, http.MethodGet) { + t.Fatalf("error should mention HTTP method %q, got %q", http.MethodGet, msg) + } + if !strings.Contains(msg, "404") { + t.Fatalf("error should mention status code 404, got %q", msg) + } + if !strings.Contains(msg, "/zones") { + t.Fatalf("error should mention request path, got %q", msg) + } + if !strings.Contains(msg, "zone not found") { + t.Fatalf("error should include response body, got %q", msg) + } +}