Merge feature/phase-3: scheduler, notifications (Telegram/Webhook), Prometheus metrics

Планировщик периодических проверок (read-only), уведомления по смене статуса
(bot_token шифруется, SSRF-guard с пиннингом IP + CGNAT), Prometheus /metrics.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01BwxdSt4reTm7Dj1oxRvpP3
This commit is contained in:
2026-07-04 15:16:58 +07:00
47 changed files with 4278 additions and 21 deletions
+70 -11
View File
@@ -2,9 +2,13 @@ package main
import (
"context"
"errors"
"log"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/jackc/pgx/v5/pgxpool"
@@ -13,8 +17,11 @@ import (
"github.com/vasyakrg/dns-autoresolver/internal/auth"
"github.com/vasyakrg/dns-autoresolver/internal/config"
"github.com/vasyakrg/dns-autoresolver/internal/crypto"
"github.com/vasyakrg/dns-autoresolver/internal/metrics"
"github.com/vasyakrg/dns-autoresolver/internal/notify"
"github.com/vasyakrg/dns-autoresolver/internal/provider/registry"
"github.com/vasyakrg/dns-autoresolver/internal/provider/selectel"
"github.com/vasyakrg/dns-autoresolver/internal/scheduler"
"github.com/vasyakrg/dns-autoresolver/internal/service"
"github.com/vasyakrg/dns-autoresolver/internal/store"
"github.com/vasyakrg/dns-autoresolver/internal/web"
@@ -24,6 +31,16 @@ import (
// user must re-authenticate.
const sessionTTL = 720 * time.Hour
// schedulerTick is how often the in-process scheduler checks for due project
// schedules. Individual projects only actually run when their own
// schedules.interval_seconds has elapsed (see internal/store ListDueSchedules) —
// this is just the polling granularity.
const schedulerTick = time.Minute
// shutdownTimeout bounds how long graceful shutdown waits for in-flight HTTP
// requests to finish before forcing the listener closed.
const shutdownTimeout = 10 * time.Second
// isAPIPath reports whether path must be routed to the API router rather
// than the SPA. "/api" (no trailing slash) counts as an API path too —
// only strings.HasPrefix(path, "/api/") would otherwise miss it and fall
@@ -33,7 +50,9 @@ func isAPIPath(path string) bool {
}
func main() {
ctx := context.Background()
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
cfg, err := config.Load()
if err != nil {
log.Fatalf("config: %v", err)
@@ -58,7 +77,14 @@ func main() {
reg.Register(selectel.New())
svc := service.New(st, st, reg, cipher)
a := &api.API{Svc: svc, Store: st, Cipher: cipher, Reg: reg, Auth: st, Sessions: sessions}
m := metrics.New()
dispatcher := notify.NewDispatcher(st, cipher)
a := &api.API{
Svc: svc, Store: st, Cipher: cipher, Reg: reg, Auth: st, Sessions: sessions,
Schedule: st, Dispatch: dispatcher,
}
apiRouter := api.NewRouter(a)
webHandler, err := web.Handler()
@@ -66,20 +92,53 @@ func main() {
log.Printf("web: static UI unavailable: %v", err)
}
// The scheduler only checks and notifies — it never applies zone changes
// (Apply stays a manual, explicit API call). Its own errors are logged
// internally and never stop the loop; ctx cancellation (signal) is the
// only thing that ends Run.
sched := scheduler.New(st, svc, dispatcher, m)
go sched.Run(ctx, schedulerTick)
mux := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if isAPIPath(r.URL.Path) {
switch {
case r.URL.Path == "/metrics":
// Public by design (no auth) — Metrics.Handler only ever exposes
// aggregate counters/gauges, never per-domain or secret data.
m.Handler().ServeHTTP(w, r)
case isAPIPath(r.URL.Path):
apiRouter.ServeHTTP(w, r)
return
}
if webHandler != nil {
case webHandler != nil:
webHandler.ServeHTTP(w, r)
return
default:
http.NotFound(w, r)
}
http.NotFound(w, r)
})
log.Printf("listening on %s", cfg.ListenAddr)
if err := http.ListenAndServe(cfg.ListenAddr, mux); err != nil {
log.Fatal(err)
srv := &http.Server{Addr: cfg.ListenAddr, Handler: mux}
serveErr := make(chan error, 1)
go func() {
log.Printf("listening on %s", cfg.ListenAddr)
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
serveErr <- err
return
}
serveErr <- nil
}()
select {
case <-ctx.Done():
log.Printf("shutdown signal received, draining connections (timeout %s)", shutdownTimeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
log.Printf("server: graceful shutdown failed: %v", err)
}
<-serveErr
log.Printf("server stopped")
case err := <-serveErr:
if err != nil {
log.Fatalf("server: %v", err)
}
}
}
+9
View File
@@ -7,6 +7,7 @@ require (
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.10.0
github.com/pressly/goose/v3 v3.27.2
github.com/prometheus/client_golang v1.23.2
github.com/testcontainers/testcontainers-go v0.43.0
github.com/testcontainers/testcontainers-go/modules/postgres v0.43.0
golang.org/x/crypto v0.53.0
@@ -16,6 +17,7 @@ require (
dario.cat/mergo v1.0.2 // indirect
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/errdefs v1.0.0 // indirect
@@ -36,6 +38,7 @@ require (
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/klauspost/compress v1.18.5 // indirect
github.com/kylelemons/godebug v1.1.0 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/magiconair/properties v1.8.10 // indirect
github.com/mfridman/interpolate v0.0.2 // indirect
@@ -48,10 +51,14 @@ require (
github.com/moby/sys/user v0.4.0 // indirect
github.com/moby/sys/userns v0.1.0 // indirect
github.com/moby/term v0.5.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.20.1 // indirect
github.com/sethvargo/go-retry v0.3.0 // indirect
github.com/shirou/gopsutil/v4 v4.26.5 // indirect
github.com/sirupsen/logrus v1.9.4 // indirect
@@ -65,8 +72,10 @@ require (
go.opentelemetry.io/otel/metric v1.43.0 // indirect
go.opentelemetry.io/otel/trace v1.43.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/sync v0.21.0 // indirect
golang.org/x/sys v0.46.0 // indirect
golang.org/x/text v0.38.0 // indirect
google.golang.org/protobuf v1.36.11 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
+20
View File
@@ -6,6 +6,8 @@ github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEK
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
@@ -65,6 +67,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
@@ -95,6 +99,8 @@ github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ=
github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
@@ -107,6 +113,14 @@ github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/pressly/goose/v3 v3.27.2 h1:FjKNzcmMdGrQlSIu5alMSmakQtJFBgtw+A0bb1p/LC8=
github.com/pressly/goose/v3 v3.27.2/go.mod h1:qWW+/8dkVtJYjJrbIpwD5xxnEJTUKvxkQ9JKQp9LaIM=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
github.com/prometheus/procfs v0.20.1 h1:XwbrGOIplXW/AU3YhIhLODXMJYyC1isLFfYCsTEycfc=
github.com/prometheus/procfs v0.20.1/go.mod h1:o9EMBZGRyvDrSPH1RqdxhojkuXstoe4UlK79eF5TGGo=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
@@ -148,8 +162,12 @@ go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfC
go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A=
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto=
golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio=
golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM=
@@ -164,6 +182,8 @@ golang.org/x/term v0.44.0/go.mod h1:7ze4MdzUzLXpSAoFP1H0bOI9aXDqveSvatT5vKcFh2Y=
golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE=
golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
+14
View File
@@ -85,6 +85,8 @@ type API struct {
Reg ProviderRegistry
Auth AuthStore
Sessions SessionManager
Schedule ScheduleStore
Dispatch TestSender
}
func NewRouter(a *API) http.Handler {
@@ -114,6 +116,18 @@ func NewRouter(a *API) http.Handler {
r.Post("/apply", a.handleApply)
r.Patch("/", a.handleSetDomainTemplate)
r.Delete("/", a.handleDeleteDomain)
r.Get("/history", a.handleDomainHistory)
})
})
r.Get("/schedule", a.handleGetSchedule)
r.Put("/schedule", a.handlePutSchedule)
r.Route("/channels", func(r chi.Router) {
r.Post("/", a.handleCreateChannel)
r.Get("/", a.handleListChannels)
r.Route("/{cid}", func(r chi.Router) {
r.Delete("/", a.handleDeleteChannel)
r.Post("/test", a.handleTestChannel)
})
})
+280
View File
@@ -0,0 +1,280 @@
package api
import (
"context"
"encoding/json"
"errors"
"log"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/vasyakrg/dns-autoresolver/internal/store"
)
// ScheduleStore is the persistence surface the schedule/channels/history
// handlers depend on. *store.Store satisfies it directly via its thin
// wrapper methods (see internal/store/tenant.go, internal/store/loader.go);
// tests can supply their own mock.
type ScheduleStore interface {
GetSchedule(ctx context.Context, projectID uuid.UUID) (store.Schedule, error)
UpsertSchedule(ctx context.Context, projectID uuid.UUID, interval int32, enabled bool) (store.Schedule, error)
CreateChannel(ctx context.Context, projectID uuid.UUID, ctype string, config json.RawMessage, secretEnc string) (store.Channel, error)
ListChannels(ctx context.Context, projectID uuid.UUID) ([]store.Channel, error)
GetChannel(ctx context.Context, id, projectID uuid.UUID) (store.Channel, error)
DeleteChannel(ctx context.Context, id, projectID uuid.UUID) error
// GetDomain verifies domainID belongs to projectID — required before
// ListCheckRuns, which is not itself scoped by project.
GetDomain(ctx context.Context, id, projectID uuid.UUID) (store.Domain, error)
ListCheckRuns(ctx context.Context, domainID uuid.UUID) ([]store.CheckRun, error)
}
// TestSender sends a one-off test notification through a single
// notification channel (POST /channels/{cid}/test). It's a narrow surface
// deliberately decoupled from notify.Dispatcher (which fans a single event
// out to every enabled channel of a project by project ID, not a single
// channel by ID) — cmd/server wiring (Task 6) supplies a concrete adapter
// over internal/notify's Notifiers; tests supply a mock.
type TestSender interface {
SendTest(ctx context.Context, channelType string, config json.RawMessage, secret string) error
}
const minScheduleIntervalSeconds = 60
// defaultScheduleResponse is what GET /schedule returns when the project has
// never created a schedule row (store.GetSchedule returns pgx.ErrNoRows).
var defaultScheduleResponse = scheduleResponse{IntervalSeconds: 3600, Enabled: false}
type scheduleRequest struct {
IntervalSeconds int32 `json:"intervalSeconds"`
Enabled bool `json:"enabled"`
}
type scheduleResponse struct {
IntervalSeconds int32 `json:"intervalSeconds"`
Enabled bool `json:"enabled"`
LastRunAt *time.Time `json:"lastRunAt,omitempty"`
}
func toScheduleResponse(s store.Schedule) scheduleResponse {
return scheduleResponse{IntervalSeconds: s.IntervalSeconds, Enabled: s.Enabled, LastRunAt: s.LastRunAt}
}
// --- schedule ---
func (a *API) handleGetSchedule(w http.ResponseWriter, r *http.Request) {
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
sc, err := a.Schedule.GetSchedule(r.Context(), pid)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
writeJSON(w, http.StatusOK, defaultScheduleResponse)
return
}
log.Printf("api: get schedule failed: %v", err)
writeErr(w, http.StatusInternalServerError, "internal error")
return
}
writeJSON(w, http.StatusOK, toScheduleResponse(sc))
}
func (a *API) handlePutSchedule(w http.ResponseWriter, r *http.Request) {
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
var req scheduleRequest
if !decodeBody(w, r, &req) {
return
}
if req.IntervalSeconds < minScheduleIntervalSeconds {
writeErr(w, http.StatusBadRequest, "intervalSeconds must be >= 60")
return
}
sc, err := a.Schedule.UpsertSchedule(r.Context(), pid, req.IntervalSeconds, req.Enabled)
if err != nil {
log.Printf("api: upsert schedule failed: %v", err)
writeErr(w, http.StatusInternalServerError, "internal error")
return
}
writeJSON(w, http.StatusOK, toScheduleResponse(sc))
}
// --- channels ---
type channelRequest struct {
Type string `json:"type"`
Config json.RawMessage `json:"config"`
Secret string `json:"secret"`
}
// channelResponse deliberately excludes the secret (plaintext or encrypted) —
// bot tokens/webhook signing keys must never reach an API response.
type channelResponse struct {
ID string `json:"id"`
Type string `json:"type"`
Config json.RawMessage `json:"config"`
Enabled bool `json:"enabled"`
}
func toChannelResponse(c store.Channel) channelResponse {
return channelResponse{ID: c.ID.String(), Type: c.Type, Config: c.Config, Enabled: c.Enabled}
}
func (a *API) handleCreateChannel(w http.ResponseWriter, r *http.Request) {
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
var req channelRequest
if !decodeBody(w, r, &req) {
return
}
if req.Type == "" {
writeErr(w, http.StatusBadRequest, "type is required")
return
}
if len(req.Config) == 0 {
req.Config = json.RawMessage("{}")
}
secretEnc := ""
if req.Secret != "" {
enc, err := a.Cipher.Encrypt([]byte(req.Secret))
if err != nil {
log.Printf("api: encrypt channel secret failed: %v", err)
writeErr(w, http.StatusInternalServerError, "internal error")
return
}
secretEnc = enc
}
ch, err := a.Schedule.CreateChannel(r.Context(), pid, req.Type, req.Config, secretEnc)
if err != nil {
log.Printf("api: create channel failed: %v", err)
writeErr(w, http.StatusInternalServerError, "internal error")
return
}
writeJSON(w, http.StatusCreated, toChannelResponse(ch))
}
func (a *API) handleListChannels(w http.ResponseWriter, r *http.Request) {
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
chs, err := a.Schedule.ListChannels(r.Context(), pid)
if err != nil {
log.Printf("api: list channels failed: %v", err)
writeErr(w, http.StatusInternalServerError, "internal error")
return
}
resp := make([]channelResponse, 0, len(chs))
for _, c := range chs {
resp = append(resp, toChannelResponse(c))
}
writeJSON(w, http.StatusOK, resp)
}
func (a *API) handleDeleteChannel(w http.ResponseWriter, r *http.Request) {
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
cid, err := uuid.Parse(chi.URLParam(r, "cid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid channel id")
return
}
if err := a.Schedule.DeleteChannel(r.Context(), cid, pid); err != nil {
log.Printf("api: delete channel failed: %v", err)
writeErr(w, http.StatusInternalServerError, "internal error")
return
}
w.WriteHeader(http.StatusNoContent)
}
// handleTestChannel sends a one-off test notification through a single
// channel so a user can verify their bot_token/chat_id or webhook URL work
// before enabling the schedule. The channel's secret is decrypted only in
// memory to make the outbound call — it's never echoed back, and a failure
// from the remote channel (bad token, unreachable webhook) is reported as
// 502 without including any secret material in the error body.
func (a *API) handleTestChannel(w http.ResponseWriter, r *http.Request) {
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
cid, err := uuid.Parse(chi.URLParam(r, "cid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid channel id")
return
}
ch, err := a.Schedule.GetChannel(r.Context(), cid, pid)
if err != nil {
writeErr(w, http.StatusNotFound, "channel not found")
return
}
secret := ""
if ch.SecretEnc != "" {
dec, err := a.Cipher.Decrypt(ch.SecretEnc)
if err != nil {
log.Printf("api: decrypt channel secret failed: %v", err)
writeErr(w, http.StatusInternalServerError, "internal error")
return
}
secret = string(dec)
}
if err := a.Dispatch.SendTest(r.Context(), ch.Type, ch.Config, secret); err != nil {
// Defense-in-depth: notify implementations sanitize errors before
// returning them (no secret/URL material), but this log deliberately
// omits the raw error (%v) anyway so a lower-layer regression can
// never leak a bot token or webhook URL into logs.
log.Printf("api: test channel %s failed", cid)
writeErr(w, http.StatusBadGateway, "channel test failed")
return
}
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
}
// --- history ---
type checkRunResponse struct {
ID string `json:"id"`
CreatedAt time.Time `json:"createdAt"`
Result json.RawMessage `json:"result"`
}
func toCheckRunResponse(c store.CheckRun) checkRunResponse {
return checkRunResponse{ID: c.ID.String(), CreatedAt: c.CreatedAt, Result: c.Result}
}
// handleDomainHistory returns the most recent check_runs for a domain.
// check_runs.domain_id has no project scoping of its own, so this handler
// must first confirm the domain belongs to the caller's project (GetDomain)
// before listing its history — otherwise a caller could enumerate another
// tenant's domain IDs to read their check history (IDOR).
func (a *API) handleDomainHistory(w http.ResponseWriter, r *http.Request) {
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
did, err := uuid.Parse(chi.URLParam(r, "did"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid domain id")
return
}
if _, err := a.Schedule.GetDomain(r.Context(), did, pid); err != nil {
writeErr(w, http.StatusNotFound, "domain not found")
return
}
runs, err := a.Schedule.ListCheckRuns(r.Context(), did)
if err != nil {
log.Printf("api: list check runs failed: %v", err)
writeErr(w, http.StatusInternalServerError, "internal error")
return
}
resp := make([]checkRunResponse, 0, len(runs))
for _, cr := range runs {
resp = append(resp, toCheckRunResponse(cr))
}
writeJSON(w, http.StatusOK, resp)
}
+433
View File
@@ -0,0 +1,433 @@
package api
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/vasyakrg/dns-autoresolver/internal/store"
)
// --- mock ScheduleStore ---
type mockScheduleStore struct {
schedule store.Schedule
scheduleErr error
upsertCalled bool
upsertInterval int32
upsertEnabled bool
upsertResult store.Schedule
upsertErr error
channels map[uuid.UUID]store.Channel
createChannelIn []struct {
ctype string
config json.RawMessage
secretEnc string
}
createChannelErr error
deleteChannelCalled bool
deletedChannelID uuid.UUID
deleteChannelErr error
domains map[uuid.UUID]store.Domain
checkRuns map[uuid.UUID][]store.CheckRun
listRunsErr error
}
func newMockScheduleStore() *mockScheduleStore {
return &mockScheduleStore{
channels: map[uuid.UUID]store.Channel{},
domains: map[uuid.UUID]store.Domain{},
checkRuns: map[uuid.UUID][]store.CheckRun{},
}
}
func (m *mockScheduleStore) GetSchedule(context.Context, uuid.UUID) (store.Schedule, error) {
if m.scheduleErr != nil {
return store.Schedule{}, m.scheduleErr
}
return m.schedule, nil
}
func (m *mockScheduleStore) UpsertSchedule(_ context.Context, projectID uuid.UUID, interval int32, enabled bool) (store.Schedule, error) {
m.upsertCalled = true
m.upsertInterval = interval
m.upsertEnabled = enabled
if m.upsertErr != nil {
return store.Schedule{}, m.upsertErr
}
if m.upsertResult.ID != uuid.Nil {
return m.upsertResult, nil
}
return store.Schedule{ID: uuid.New(), ProjectID: projectID, IntervalSeconds: interval, Enabled: enabled}, nil
}
func (m *mockScheduleStore) CreateChannel(_ context.Context, projectID uuid.UUID, ctype string, config json.RawMessage, secretEnc string) (store.Channel, error) {
m.createChannelIn = append(m.createChannelIn, struct {
ctype string
config json.RawMessage
secretEnc string
}{ctype, config, secretEnc})
if m.createChannelErr != nil {
return store.Channel{}, m.createChannelErr
}
ch := store.Channel{ID: uuid.New(), ProjectID: projectID, Type: ctype, Config: config, SecretEnc: secretEnc, Enabled: true}
m.channels[ch.ID] = ch
return ch, nil
}
func (m *mockScheduleStore) ListChannels(context.Context, uuid.UUID) ([]store.Channel, error) {
out := make([]store.Channel, 0, len(m.channels))
for _, c := range m.channels {
out = append(out, c)
}
return out, nil
}
func (m *mockScheduleStore) GetChannel(_ context.Context, id, _ uuid.UUID) (store.Channel, error) {
c, ok := m.channels[id]
if !ok {
return store.Channel{}, errors.New("channel not found")
}
return c, nil
}
func (m *mockScheduleStore) DeleteChannel(_ context.Context, id, _ uuid.UUID) error {
m.deleteChannelCalled = true
m.deletedChannelID = id
if m.deleteChannelErr != nil {
return m.deleteChannelErr
}
delete(m.channels, id)
return nil
}
func (m *mockScheduleStore) GetDomain(_ context.Context, id, _ uuid.UUID) (store.Domain, error) {
d, ok := m.domains[id]
if !ok {
return store.Domain{}, errors.New("domain not found")
}
return d, nil
}
func (m *mockScheduleStore) ListCheckRuns(_ context.Context, domainID uuid.UUID) ([]store.CheckRun, error) {
if m.listRunsErr != nil {
return nil, m.listRunsErr
}
return m.checkRuns[domainID], nil
}
// --- mock TestSender ---
type mockTestSender struct {
err error
calledType string
calledConfig json.RawMessage
calledSecret string
called bool
}
func (m *mockTestSender) SendTest(_ context.Context, channelType string, config json.RawMessage, secret string) error {
m.called = true
m.calledType = channelType
m.calledConfig = config
m.calledSecret = secret
return m.err
}
// newScheduleTestAPI wires a fixed authenticated user who owns whatever
// project id is requested (alwaysOwnedAuthStore/alwaysValidSessions, see
// middleware_test.go) — these tests exercise schedule/channels/history
// behavior past the RequireAuth/RequireProjectAccess boundary.
func newScheduleTestAPI() (*API, *mockScheduleStore, *mockTestSender) {
ms := newMockScheduleStore()
mts := &mockTestSender{}
a := &API{
Schedule: ms, Dispatch: mts, Cipher: mockCipher{},
Auth: alwaysOwnedAuthStore(), Sessions: alwaysValidSessions(uuid.New()),
}
return a, ms, mts
}
// --- schedule ---
func TestGetSchedule_DefaultWhenNoRow(t *testing.T) {
a, ms, _ := newScheduleTestAPI()
ms.scheduleErr = pgx.ErrNoRows
router := NewRouter(a)
req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+testPID+"/schedule", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status %d body %s", w.Code, w.Body.String())
}
var resp scheduleResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatal(err)
}
if resp.IntervalSeconds != 3600 || resp.Enabled != false {
t.Fatalf("expected default {3600,false}, got %+v", resp)
}
}
func TestGetSchedule_Existing(t *testing.T) {
a, ms, _ := newScheduleTestAPI()
ms.schedule = store.Schedule{ID: uuid.New(), IntervalSeconds: 120, Enabled: true}
router := NewRouter(a)
req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+testPID+"/schedule", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status %d body %s", w.Code, w.Body.String())
}
var resp scheduleResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatal(err)
}
if resp.IntervalSeconds != 120 || !resp.Enabled {
t.Fatalf("expected {120,true}, got %+v", resp)
}
}
func TestPutSchedule_RejectsIntervalBelow60(t *testing.T) {
a, ms, _ := newScheduleTestAPI()
router := NewRouter(a)
body := `{"intervalSeconds":59,"enabled":true}`
req := requestWithSessionCookie(http.MethodPut, "/api/v1/projects/"+testPID+"/schedule", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400 for interval<60, got %d body %s", w.Code, w.Body.String())
}
if ms.upsertCalled {
t.Fatal("UpsertSchedule must not be called when validation fails")
}
}
func TestPutSchedule_Success(t *testing.T) {
a, ms, _ := newScheduleTestAPI()
router := NewRouter(a)
body := `{"intervalSeconds":300,"enabled":true}`
req := requestWithSessionCookie(http.MethodPut, "/api/v1/projects/"+testPID+"/schedule", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status %d body %s", w.Code, w.Body.String())
}
if !ms.upsertCalled || ms.upsertInterval != 300 || !ms.upsertEnabled {
t.Fatalf("expected UpsertSchedule(300,true), got called=%v interval=%d enabled=%v", ms.upsertCalled, ms.upsertInterval, ms.upsertEnabled)
}
}
// --- channels ---
func TestCreateChannel_EncryptsSecretAndOmitsFromResponse(t *testing.T) {
a, ms, _ := newScheduleTestAPI()
router := NewRouter(a)
body := `{"type":"telegram","config":{"chat_id":"123"},"secret":"super-bot-token"}`
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/channels", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("status %d body %s", w.Code, w.Body.String())
}
if strings.Contains(w.Body.String(), "super-bot-token") {
t.Fatalf("response leaks plaintext secret: %s", w.Body.String())
}
if len(ms.createChannelIn) != 1 {
t.Fatalf("expected 1 CreateChannel call, got %d", len(ms.createChannelIn))
}
got := ms.createChannelIn[0]
if got.secretEnc == "" || got.secretEnc == "super-bot-token" || !strings.Contains(got.secretEnc, "super-bot-token") {
// mockCipher.Encrypt wraps as ENC(...) — assert it's the *encrypted* form, not the raw plaintext passed unchanged.
t.Fatalf("expected secret to be passed through cipher.Encrypt, got secretEnc=%q", got.secretEnc)
}
if got.secretEnc == "super-bot-token" {
t.Fatalf("secret was stored unencrypted")
}
var resp channelResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatal(err)
}
if resp.Type != "telegram" || resp.ID == "" {
t.Fatalf("unexpected response: %+v", resp)
}
}
func TestListChannels_NoSecrets(t *testing.T) {
a, ms, _ := newScheduleTestAPI()
ms.channels[uuid.New()] = store.Channel{ID: uuid.New(), Type: "webhook", Config: json.RawMessage(`{"url":"https://example.com"}`), SecretEnc: "ENC(top-secret)", Enabled: true}
router := NewRouter(a)
req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+testPID+"/channels", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status %d body %s", w.Code, w.Body.String())
}
if strings.Contains(w.Body.String(), "top-secret") || strings.Contains(w.Body.String(), "secretEnc") {
t.Fatalf("channel list leaks secret: %s", w.Body.String())
}
}
func TestDeleteChannel(t *testing.T) {
a, ms, _ := newScheduleTestAPI()
cid := uuid.New()
ms.channels[cid] = store.Channel{ID: cid, Type: "webhook"}
router := NewRouter(a)
req := requestWithSessionCookie(http.MethodDelete, "/api/v1/projects/"+testPID+"/channels/"+cid.String(), nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNoContent {
t.Fatalf("status %d body %s", w.Code, w.Body.String())
}
if !ms.deleteChannelCalled || ms.deletedChannelID != cid {
t.Fatalf("expected DeleteChannel(%s), called=%v got=%s", cid, ms.deleteChannelCalled, ms.deletedChannelID)
}
}
func TestDeleteChannel_InvalidUUID(t *testing.T) {
a, _, _ := newScheduleTestAPI()
router := NewRouter(a)
req := requestWithSessionCookie(http.MethodDelete, "/api/v1/projects/"+testPID+"/channels/not-a-uuid", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400 for bad channel uuid, got %d", w.Code)
}
}
func TestTestChannel_Success(t *testing.T) {
a, ms, mts := newScheduleTestAPI()
cid := uuid.New()
ms.channels[cid] = store.Channel{ID: cid, Type: "telegram", Config: json.RawMessage(`{"chat_id":"1"}`), SecretEnc: "ENC(bot-token)"}
router := NewRouter(a)
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/channels/"+cid.String()+"/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status %d body %s", w.Code, w.Body.String())
}
if !mts.called || mts.calledType != "telegram" || mts.calledSecret != "bot-token" {
t.Fatalf("expected SendTest(telegram,...,bot-token), got called=%v type=%s secret=%s", mts.called, mts.calledType, mts.calledSecret)
}
}
func TestTestChannel_SenderError_Returns502WithoutSecret(t *testing.T) {
a, ms, mts := newScheduleTestAPI()
cid := uuid.New()
ms.channels[cid] = store.Channel{ID: cid, Type: "telegram", Config: json.RawMessage(`{"chat_id":"1"}`), SecretEnc: "ENC(bot-token)"}
mts.err = errors.New("telegram: status 401 Unauthorized (token=bot-token)")
router := NewRouter(a)
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/channels/"+cid.String()+"/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadGateway {
t.Fatalf("expected 502 on channel test failure, got %d body %s", w.Code, w.Body.String())
}
if strings.Contains(w.Body.String(), "bot-token") {
t.Fatalf("error response leaks secret: %s", w.Body.String())
}
}
func TestTestChannel_UnknownChannel_Returns404(t *testing.T) {
a, _, _ := newScheduleTestAPI()
router := NewRouter(a)
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/channels/"+uuid.New().String()+"/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404 for unknown channel, got %d", w.Code)
}
}
// --- history ---
func TestDomainHistory_List(t *testing.T) {
a, ms, _ := newScheduleTestAPI()
did := uuid.New()
ms.domains[did] = store.Domain{ID: did}
ms.checkRuns[did] = []store.CheckRun{
{ID: uuid.New(), DomainID: did, Result: json.RawMessage(`{"updates":1,"prunes":0}`)},
{ID: uuid.New(), DomainID: did, Result: json.RawMessage(`{"updates":0,"prunes":0}`)},
}
router := NewRouter(a)
req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+testPID+"/domains/"+did.String()+"/history", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status %d body %s", w.Code, w.Body.String())
}
var resp []checkRunResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatal(err)
}
if len(resp) != 2 {
t.Fatalf("expected 2 history entries, got %d", len(resp))
}
}
func TestDomainHistory_InvalidUUID(t *testing.T) {
a, _, _ := newScheduleTestAPI()
router := NewRouter(a)
req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+testPID+"/domains/not-a-uuid/history", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400 for bad domain uuid, got %d", w.Code)
}
}
// TestDomainHistory_ForeignDomain_Returns404 is the IDOR guard for history:
// check_runs.domain_id has no project scoping of its own, so the handler
// must verify domain ownership via GetDomain before calling ListCheckRuns —
// a domain id the mock doesn't know about (i.e. not in this project) must
// 404 rather than fall through to an unscoped history lookup.
func TestDomainHistory_ForeignDomain_Returns404(t *testing.T) {
a, _, _ := newScheduleTestAPI()
router := NewRouter(a)
did := uuid.New() // never registered in ms.domains
req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+testPID+"/domains/"+did.String()+"/history", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404 for a domain not owned by this project, got %d body %s", w.Code, w.Body.String())
}
}
+2 -1
View File
@@ -59,12 +59,13 @@ type domainResponse struct {
ZoneName string `json:"zoneName"`
ZoneID string `json:"zoneId"`
TemplateID *string `json:"templateId,omitempty"`
LastCheckStatus string `json:"lastCheckStatus"`
}
func toDomainResponse(d store.Domain) domainResponse {
resp := domainResponse{
ID: d.ID.String(), ProviderAccountID: d.ProviderAccountID.String(),
ZoneName: d.ZoneName, ZoneID: d.ZoneID,
ZoneName: d.ZoneName, ZoneID: d.ZoneID, LastCheckStatus: d.LastCheckStatus,
}
if d.TemplateID != nil {
s := d.TemplateID.String()
+80
View File
@@ -0,0 +1,80 @@
// Package metrics предоставляет Prometheus-метрики DNS Autoresolver.
package metrics
import (
"net/http"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// Metrics агрегирует Prometheus-метрики приложения на собственном реестре.
type Metrics struct {
Registry *prometheus.Registry
ChecksTotal *prometheus.CounterVec
CheckDuration prometheus.Histogram
DriftDomains prometheus.Gauge
NotificationsTotal *prometheus.CounterVec
}
// New создаёт реестр метрик, регистрирует стандартные Go/Process-коллекторы
// и все метрики приложения.
func New() *Metrics {
reg := prometheus.NewRegistry()
reg.MustRegister(
collectors.NewGoCollector(),
collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
)
f := promauto.With(reg)
return &Metrics{
Registry: reg,
ChecksTotal: f.NewCounterVec(prometheus.CounterOpts{
Name: "dns_ar_checks_total",
Help: "Общее количество выполненных проверок доменов по статусу.",
}, []string{"status"}),
CheckDuration: f.NewHistogram(prometheus.HistogramOpts{
Name: "dns_ar_check_duration_seconds",
Help: "Длительность выполнения проверки домена в секундах.",
// Проверка домена — сетевой вызов DNS-провайдера, а не
// внутрипроцессная операция (для которой рассчитан
// prometheus.DefBuckets, начинающийся с 5мс). Бакеты подобраны
// под реалистичный диапазон задержек такого вызова, включая
// таймауты/ретраи медленных провайдеров.
Buckets: []float64{0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30},
}),
DriftDomains: f.NewGauge(prometheus.GaugeOpts{
Name: "dns_ar_drift_domains",
Help: "Текущее количество доменов в состоянии drift.",
}),
NotificationsTotal: f.NewCounterVec(prometheus.CounterOpts{
Name: "dns_ar_notifications_total",
Help: "Общее количество отправленных уведомлений по каналу и статусу.",
}, []string{"channel", "status"}),
}
}
// Handler возвращает HTTP-обработчик для отдачи метрик реестра.
func (m *Metrics) Handler() http.Handler {
return promhttp.HandlerFor(m.Registry, promhttp.HandlerOpts{})
}
// ObserveCheck фиксирует результат проверки: статус и длительность.
func (m *Metrics) ObserveCheck(status string, dur time.Duration) {
m.ChecksTotal.WithLabelValues(status).Inc()
m.CheckDuration.Observe(dur.Seconds())
}
// SetDrift устанавливает текущее количество доменов в состоянии drift.
func (m *Metrics) SetDrift(n int) {
m.DriftDomains.Set(float64(n))
}
// IncNotification фиксирует отправку уведомления по каналу и статусу.
func (m *Metrics) IncNotification(channel, status string) {
m.NotificationsTotal.WithLabelValues(channel, status).Inc()
}
+58
View File
@@ -0,0 +1,58 @@
package metrics
import (
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/prometheus/client_golang/prometheus/testutil"
)
func TestMetricsRecord(t *testing.T) {
m := New()
m.ObserveCheck("drift", 100*time.Millisecond)
m.ObserveCheck("in_sync", 50*time.Millisecond)
m.IncNotification("telegram", "ok")
m.SetDrift(3)
if got := testutil.ToFloat64(m.ChecksTotal.WithLabelValues("drift")); got != 1 {
t.Fatalf("checks drift = %v", got)
}
if got := testutil.ToFloat64(m.DriftDomains); got != 3 {
t.Fatalf("drift gauge = %v", got)
}
if got := testutil.ToFloat64(m.NotificationsTotal.WithLabelValues("telegram", "ok")); got != 1 {
t.Fatalf("notif = %v", got)
}
}
func TestCheckDurationUsesNetworkCallBuckets(t *testing.T) {
m := New()
m.ObserveCheck("in_sync", 100*time.Millisecond)
rec := httptest.NewRecorder()
m.Handler().ServeHTTP(rec, httptest.NewRequest("GET", "/metrics", nil))
body := rec.Body.String()
// DefBuckets (le="0.005", ...) is tuned for sub-10ms in-process calls;
// dns_ar_check_duration_seconds is a network call to a DNS provider, so
// it must use the wider explicit buckets instead.
for _, want := range []string{`le="0.05"`, `le="1"`, `le="30"`} {
if !strings.Contains(body, `dns_ar_check_duration_seconds_bucket{`+want) {
t.Fatalf("expected bucket %s in exposed metrics:\n%s", want, body)
}
}
if strings.Contains(body, `dns_ar_check_duration_seconds_bucket{le="0.005"`) {
t.Fatalf("found default histogram bucket 0.005, expected custom buckets:\n%s", body)
}
}
func TestHandlerExposesMetrics(t *testing.T) {
m := New()
m.ObserveCheck("in_sync", time.Millisecond)
rec := httptest.NewRecorder()
m.Handler().ServeHTTP(rec, httptest.NewRequest("GET", "/metrics", nil))
if rec.Code != 200 || !strings.Contains(rec.Body.String(), "dns_ar_checks_total") {
t.Fatalf("metrics not exposed: %d", rec.Code)
}
}
+101
View File
@@ -0,0 +1,101 @@
package notify
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"time"
"github.com/google/uuid"
"github.com/vasyakrg/dns-autoresolver/internal/store"
)
// ChannelStore is the narrow store dependency Dispatcher needs: the set of
// enabled notification channels for a project.
type ChannelStore interface {
ListEnabledChannels(ctx context.Context, projectID uuid.UUID) ([]store.Channel, error)
}
// Decryptor decrypts a channel's stored secret (bot token, signing key, ...).
type Decryptor interface {
Decrypt(enc string) ([]byte, error)
}
// Dispatcher fans an Event out to every enabled channel of a project,
// picking the Notifier implementation by channel type. A failure on one
// channel does not stop delivery to the others; all errors are aggregated
// via errors.Join.
type Dispatcher struct {
store ChannelStore
cipher Decryptor
byType map[string]Notifier
}
// NewDispatcher builds a Dispatcher wired with the default Telegram and
// Webhook notifiers.
func NewDispatcher(store ChannelStore, cipher Decryptor) *Dispatcher {
return &Dispatcher{
store: store,
cipher: cipher,
byType: map[string]Notifier{
"telegram": &Telegram{BaseURL: "https://api.telegram.org", HTTP: &http.Client{Timeout: 15 * time.Second}},
"webhook": &Webhook{HTTP: &http.Client{
Timeout: 15 * time.Second,
Transport: newWebhookTransport(false),
}},
},
}
}
// Send delivers ev to every enabled channel of projectID. Errors from
// individual channels are aggregated (via errors.Join) rather than aborting
// delivery to the remaining channels.
func (d *Dispatcher) Send(ctx context.Context, projectID uuid.UUID, ev Event) error {
channels, err := d.store.ListEnabledChannels(ctx, projectID)
if err != nil {
return err
}
var errs []error
for _, ch := range channels {
n, ok := d.byType[ch.Type]
if !ok {
continue
}
secret := ""
if ch.SecretEnc != "" {
b, err := d.cipher.Decrypt(ch.SecretEnc)
if err != nil {
errs = append(errs, err)
continue
}
secret = string(b)
}
if err := n.Send(ctx, ch.Config, secret, ev); err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
}
// SendTest sends a single synthetic Event directly through the Notifier for
// channelType, bypassing project/channel lookup entirely. It satisfies
// api.TestSender and backs POST /channels/{cid}/test, letting a user verify
// a channel's bot_token/chat_id or webhook URL works before enabling the
// schedule — the api layer resolves the channel and decrypts its secret; this
// method only performs the actual delivery attempt.
func (d *Dispatcher) SendTest(ctx context.Context, channelType string, config json.RawMessage, secret string) error {
n, ok := d.byType[channelType]
if !ok {
return fmt.Errorf("notify: unknown channel type %q", channelType)
}
ev := Event{
Project: "test",
Domain: "test",
Status: "test",
Summary: "test notification",
At: time.Now(),
}
return n.Send(ctx, config, secret, ev)
}
+27
View File
@@ -0,0 +1,27 @@
// Package notify sends drift/error notifications to project-configured
// channels (Telegram, generic webhooks, ...). Notifier implementations must
// never log the secret they receive (bot tokens, HMAC keys, etc.).
package notify
import (
"context"
"encoding/json"
"time"
)
// Event describes a single notification-worthy occurrence for a domain
// belonging to a project (e.g. a status change detected by the scheduler).
type Event struct {
Project string
Domain string
Status string
Summary string
At time.Time
}
// Notifier delivers an Event to a channel described by cfg (channel-type
// specific JSON config) and secret (decrypted credential, e.g. a bot token).
// Implementations must not log secret.
type Notifier interface {
Send(ctx context.Context, cfg json.RawMessage, secret string, ev Event) error
}
+392
View File
@@ -0,0 +1,392 @@
package notify
import (
"context"
"encoding/json"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/vasyakrg/dns-autoresolver/internal/store"
)
func TestTelegramSendSuccess(t *testing.T) {
var gotPath string
var gotBody map[string]string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
_ = json.NewDecoder(r.Body).Decode(&gotBody)
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
tg := &Telegram{BaseURL: srv.URL, HTTP: srv.Client()}
ev := Event{Project: "proj", Domain: "example.com", Status: "drift", Summary: "A record changed", At: time.Now()}
err := tg.Send(context.Background(), json.RawMessage(`{"chat_id":"12345"}`), "sekret-token", ev)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if gotPath != "/botsekret-token/sendMessage" {
t.Fatalf("unexpected path: %s", gotPath)
}
if gotBody["chat_id"] != "12345" {
t.Fatalf("unexpected chat_id: %+v", gotBody)
}
if !strings.Contains(gotBody["text"], "example.com") || !strings.Contains(gotBody["text"], "drift") {
t.Fatalf("unexpected text: %+v", gotBody)
}
}
func TestTelegramSendServerError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer srv.Close()
tg := &Telegram{BaseURL: srv.URL, HTTP: srv.Client()}
ev := Event{Project: "proj", Domain: "example.com", Status: "drift", Summary: "x", At: time.Now()}
if err := tg.Send(context.Background(), json.RawMessage(`{"chat_id":"1"}`), "tok", ev); err == nil {
t.Fatal("expected error on 500 response")
}
}
func TestTelegramSendTransportErrorDoesNotLeakSecret(t *testing.T) {
// Bind and immediately close a server: its address is now unreachable
// (connection refused), which makes http.Client.Do return a *url.Error
// whose Error() embeds the full request URL — including /bot<secret>/.
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
deadURL := srv.URL
srv.Close()
tg := &Telegram{BaseURL: deadURL, HTTP: srv.Client()}
ev := Event{Project: "proj", Domain: "example.com", Status: "drift", Summary: "x", At: time.Now()}
const secret = "super-secret-bot-token"
err := tg.Send(context.Background(), json.RawMessage(`{"chat_id":"1"}`), secret, ev)
if err == nil {
t.Fatal("expected error for unreachable host")
}
if strings.Contains(err.Error(), secret) {
t.Fatalf("error leaks secret: %v", err)
}
if strings.Contains(err.Error(), deadURL) {
t.Fatalf("error leaks request URL: %v", err)
}
}
func TestWebhookSendSuccess(t *testing.T) {
var gotEvent Event
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("unexpected method: %s", r.Method)
}
_ = json.NewDecoder(r.Body).Decode(&gotEvent)
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
// allowPrivate: true — httptest.Server listens on 127.0.0.1, which the
// SSRF guard would otherwise reject; production dispatchers never set
// this (see TestIsAllowedURL / TestNewDispatcherDoesNotAllowPrivate).
wh := &Webhook{HTTP: srv.Client(), allowPrivate: true}
ev := Event{Project: "proj", Domain: "example.com", Status: "in_sync", Summary: "resolved", At: time.Now()}
cfg, _ := json.Marshal(map[string]string{"url": srv.URL})
if err := wh.Send(context.Background(), cfg, "", ev); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if gotEvent.Domain != "example.com" || gotEvent.Status != "in_sync" {
t.Fatalf("unexpected event delivered: %+v", gotEvent)
}
}
func TestWebhookSendNonSuccessStatus(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
}))
defer srv.Close()
wh := &Webhook{HTTP: srv.Client(), allowPrivate: true}
ev := Event{Project: "proj", Domain: "example.com", Status: "drift", Summary: "x", At: time.Now()}
cfg, _ := json.Marshal(map[string]string{"url": srv.URL})
if err := wh.Send(context.Background(), cfg, "", ev); err == nil {
t.Fatal("expected error on 400 response")
}
}
func TestWebhookSendRejectsPrivateDestinationByDefault(t *testing.T) {
wh := &Webhook{HTTP: http.DefaultClient} // allowPrivate not set: SSRF guard active
ev := Event{Project: "proj", Domain: "example.com", Status: "drift", Summary: "x", At: time.Now()}
cfg, _ := json.Marshal(map[string]string{"url": "http://127.0.0.1:1/hook"})
err := wh.Send(context.Background(), cfg, "", ev)
if err == nil {
t.Fatal("expected error for loopback destination")
}
if !strings.Contains(err.Error(), "destination not allowed") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestIsAllowedURL(t *testing.T) {
cases := []struct {
name string
rawurl string
allowed bool
}{
{"localhost hostname", "http://localhost/hook", false},
{"loopback ip", "http://127.0.0.1/hook", false},
{"loopback ipv6", "http://[::1]/hook", false},
{"link-local metadata", "http://169.254.169.254/latest/meta-data", false},
{"private class a", "http://10.0.0.1/hook", false},
{"private class c", "http://192.168.1.1/hook", false},
{"private class b", "http://172.16.0.1/hook", false},
{"unspecified", "http://0.0.0.0/hook", false},
{"multicast", "http://224.0.0.1/hook", false},
{"non-http scheme", "ftp://example.com/hook", false},
{"public ip", "http://93.184.216.34/hook", true},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
err := isAllowedURL(tc.rawurl)
if tc.allowed && err != nil {
t.Fatalf("expected %q to be allowed, got error: %v", tc.rawurl, err)
}
if !tc.allowed && err == nil {
t.Fatalf("expected %q to be rejected, got nil error", tc.rawurl)
}
})
}
}
func TestDialControlBlocksActualConnectingAddress(t *testing.T) {
cases := []struct {
name string
address string
blocked bool
}{
{"loopback v4", "127.0.0.1:80", true},
{"loopback v6", "[::1]:80", true},
{"metadata link-local", "169.254.169.254:80", true},
{"private class a", "10.0.0.1:80", true},
{"private class b", "172.16.0.1:80", true},
{"private class c", "192.168.1.1:80", true},
{"unspecified", "0.0.0.0:80", true},
{"multicast", "224.0.0.1:80", true},
{"public ip", "93.184.216.34:443", false},
}
control := dialControl(false)
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
err := control("tcp", tc.address, nil)
if tc.blocked && err == nil {
t.Fatalf("expected %q to be blocked", tc.address)
}
if !tc.blocked && err != nil {
t.Fatalf("expected %q to be allowed, got: %v", tc.address, err)
}
})
}
}
func TestIsBlockedIPCGNATRange(t *testing.T) {
cases := []struct {
name string
ip string
blocked bool
}{
{"cgnat start", "100.64.0.1", true},
{"cgnat end", "100.127.255.255", true},
{"just below cgnat", "100.63.255.255", false},
{"just above cgnat", "100.128.0.0", false},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
ip := net.ParseIP(tc.ip)
if ip == nil {
t.Fatalf("failed to parse %q", tc.ip)
}
got := isBlockedIP(ip)
if got != tc.blocked {
t.Fatalf("isBlockedIP(%q) = %v, want %v", tc.ip, got, tc.blocked)
}
})
}
}
func TestDialControlAllowsEverythingWhenAllowPrivate(t *testing.T) {
control := dialControl(true)
if err := control("tcp", "127.0.0.1:80", nil); err != nil {
t.Fatalf("expected allowPrivate to skip the dial guard, got: %v", err)
}
}
// TestWebhookControlBlocksConnectionEvenWhenPreCheckPasses simulates the
// DNS-rebinding TOCTOU: allowPrivate=true skips the pre-request isAllowedURL
// check (standing in for a rebinding attacker answering a public IP to that
// lookup), but the Transport's Control func — wired independently of
// Webhook.allowPrivate — still inspects the literal address the dialer
// connects to and must still reject it. If Control did not exist, this
// request would reach the httptest handler; it must not.
func TestWebhookControlBlocksConnectionEvenWhenPreCheckPasses(t *testing.T) {
var handlerCalled bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlerCalled = true
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
wh := &Webhook{
HTTP: &http.Client{Transport: newWebhookTransport(false)},
allowPrivate: true, // pre-check bypassed on purpose; Control is not
}
ev := Event{Project: "proj", Domain: "example.com", Status: "drift", Summary: "x", At: time.Now()}
cfg, _ := json.Marshal(map[string]string{"url": srv.URL})
err := wh.Send(context.Background(), cfg, "", ev)
if err == nil {
t.Fatal("expected error: Control should have blocked the loopback connection")
}
if !strings.Contains(err.Error(), "destination not allowed") {
t.Fatalf("unexpected error: %v", err)
}
if handlerCalled {
t.Fatal("Control should have rejected the dial before the handler ran")
}
}
// --- Dispatcher ---
type mockChannelStore struct {
channels []store.Channel
err error
}
func (m *mockChannelStore) ListEnabledChannels(ctx context.Context, projectID uuid.UUID) ([]store.Channel, error) {
return m.channels, m.err
}
type mockDecryptor struct {
fail bool
}
func (m *mockDecryptor) Decrypt(enc string) ([]byte, error) {
if m.fail {
return nil, errBoom
}
return []byte("decrypted-" + enc), nil
}
var errBoom = &boomErr{}
type boomErr struct{}
func (*boomErr) Error() string { return "decrypt boom" }
func TestDispatcherSendsToAllChannelsAndAggregatesErrors(t *testing.T) {
var tgCalled, whCalled bool
var tgSecret string
tgSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tgCalled = true
w.WriteHeader(http.StatusOK)
}))
defer tgSrv.Close()
whSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
whCalled = true
w.WriteHeader(http.StatusInternalServerError) // webhook fails
}))
defer whSrv.Close()
projectID := uuid.New()
channels := []store.Channel{
{
ID: uuid.New(), ProjectID: projectID, Type: "telegram",
Config: json.RawMessage(`{"chat_id":"1"}`), SecretEnc: "enc-token", Enabled: true,
},
{
ID: uuid.New(), ProjectID: projectID, Type: "webhook",
Config: json.RawMessage(`{"url":"` + whSrv.URL + `"}`), SecretEnc: "", Enabled: true,
},
}
d := NewDispatcher(&mockChannelStore{channels: channels}, &mockDecryptor{})
// Redirect telegram to the httptest server and capture the decrypted secret.
d.byType["telegram"] = notifierFunc(func(ctx context.Context, cfg json.RawMessage, secret string, ev Event) error {
tgSecret = secret
tg := &Telegram{BaseURL: tgSrv.URL, HTTP: tgSrv.Client()}
return tg.Send(ctx, cfg, secret, ev)
})
// httptest servers listen on loopback, which the SSRF guard rejects by
// default; swap in an allowPrivate webhook so this test can still hit it.
d.byType["webhook"] = &Webhook{HTTP: whSrv.Client(), allowPrivate: true}
ev := Event{Project: "proj", Domain: "example.com", Status: "drift", Summary: "changed", At: time.Now()}
err := d.Send(context.Background(), projectID, ev)
if !tgCalled {
t.Error("expected telegram notifier to be called")
}
if !whCalled {
t.Error("expected webhook notifier to be called")
}
if err == nil {
t.Fatal("expected aggregated error because webhook failed")
}
if tgSecret != "decrypted-enc-token" {
t.Fatalf("expected decrypted secret to be passed to telegram, got %q", tgSecret)
}
}
func TestDispatcherSkipsUnknownChannelType(t *testing.T) {
projectID := uuid.New()
channels := []store.Channel{
{ID: uuid.New(), ProjectID: projectID, Type: "carrier-pigeon", Config: json.RawMessage(`{}`), Enabled: true},
}
d := NewDispatcher(&mockChannelStore{channels: channels}, &mockDecryptor{})
if err := d.Send(context.Background(), projectID, Event{Project: "p", Domain: "d", Status: "drift"}); err != nil {
t.Fatalf("unexpected error for unknown channel type: %v", err)
}
}
func TestDispatcherDecryptFailureIsAggregatedNotFatal(t *testing.T) {
var whCalled bool
whSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
whCalled = true
w.WriteHeader(http.StatusOK)
}))
defer whSrv.Close()
projectID := uuid.New()
channels := []store.Channel{
{ID: uuid.New(), ProjectID: projectID, Type: "telegram", Config: json.RawMessage(`{"chat_id":"1"}`), SecretEnc: "enc", Enabled: true},
{ID: uuid.New(), ProjectID: projectID, Type: "webhook", Config: json.RawMessage(`{"url":"` + whSrv.URL + `"}`), Enabled: true},
}
d := NewDispatcher(&mockChannelStore{channels: channels}, &mockDecryptor{fail: true})
// httptest servers listen on loopback, which the SSRF guard rejects by
// default; swap in an allowPrivate webhook so this test can still hit it.
d.byType["webhook"] = &Webhook{HTTP: whSrv.Client(), allowPrivate: true}
err := d.Send(context.Background(), projectID, Event{Project: "p", Domain: "d", Status: "drift"})
if err == nil {
t.Fatal("expected error due to decrypt failure")
}
if !whCalled {
t.Error("expected webhook channel to still be attempted after telegram decrypt failure")
}
}
// notifierFunc adapts a function to the Notifier interface for tests.
type notifierFunc func(ctx context.Context, cfg json.RawMessage, secret string, ev Event) error
func (f notifierFunc) Send(ctx context.Context, cfg json.RawMessage, secret string, ev Event) error {
return f(ctx, cfg, secret, ev)
}
+49
View File
@@ -0,0 +1,49 @@
package notify
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
)
// Telegram delivers notifications via the Telegram Bot API sendMessage
// endpoint. Config is {"chat_id": "..."}; secret is the bot token and is
// never logged.
type Telegram struct {
BaseURL string
HTTP *http.Client
}
func (t *Telegram) Send(ctx context.Context, cfg json.RawMessage, secret string, ev Event) error {
var c struct {
ChatID string `json:"chat_id"`
}
if err := json.Unmarshal(cfg, &c); err != nil {
return err
}
body, _ := json.Marshal(map[string]string{
"chat_id": c.ChatID,
"text": fmt.Sprintf("[%s] %s → %s\n%s", ev.Project, ev.Domain, ev.Status, ev.Summary),
})
url := t.BaseURL + "/bot" + secret + "/sendMessage"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
resp, err := t.HTTP.Do(req)
if err != nil {
// Do NOT wrap/return err as-is: *url.Error.Error() embeds the full
// request URL, which contains the bot token (/bot<secret>/...). A
// caller logging this error would leak the secret.
return errors.New("telegram: request failed")
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
return fmt.Errorf("telegram: status %d", resp.StatusCode)
}
return nil
}
+200
View File
@@ -0,0 +1,200 @@
package notify
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"syscall"
"time"
)
// Webhook delivers notifications as a JSON POST of the Event to a
// project-configured URL. Config is {"url": "..."}. secret is currently
// unused (reserved for future request signing) and is never logged.
//
// The destination URL is project-controlled (any project owner can set it),
// so it is treated as untrusted input. Two layers guard against SSRF:
//
// 1. isAllowedURL is a pre-request fast-fail check on the URL's scheme and
// (resolved) hostname.
// 2. HTTP's Transport, when built via newWebhookTransport, wires a
// net.Dialer.Control that re-checks the actual "ip:port" being dialed for
// every connection net/http opens — including the DNS resolution
// http.Client.Do performs internally, independent of (1).
//
// Layer (2) is the source of truth: DNS answers are attacker-influenceable
// (an attacker with authoritative DNS and a low TTL can answer a public IP to
// a pre-request lookup and a private/loopback IP to the actual connection —
// DNS rebinding). Relying on (1) alone leaves that TOCTOU window open; (2)
// closes it because it inspects the address the connection is actually made
// to, not a name. Redirects are not followed, since a redirect response
// could otherwise be used to bypass the destination checks.
type Webhook struct {
HTTP *http.Client
// allowPrivate disables the isAllowedURL pre-check. It exists only so
// tests can exercise Send happy-paths against httptest servers, which
// listen on loopback. Production Dispatchers (NewDispatcher) must never
// set this; they also wire a Transport whose Control func enforces the
// same guard at dial time regardless of this flag.
allowPrivate bool
}
// isAllowedURL rejects any URL that is not a plain http/https request to a
// public, resolvable address. It resolves hostnames and checks every
// returned address — a hostname that resolves to even one
// private/loopback/link-local/unspecified address is rejected, since DNS
// answers are attacker-influenceable (rebinding) and partial trust is not
// safe.
func isAllowedURL(rawurl string) error {
u, err := url.Parse(rawurl)
if err != nil {
return fmt.Errorf("webhook: invalid url: %w", err)
}
if u.Scheme != "http" && u.Scheme != "https" {
return errors.New("webhook: destination not allowed")
}
host := u.Hostname()
if host == "" {
return errors.New("webhook: destination not allowed")
}
var ips []net.IP
if ip := net.ParseIP(host); ip != nil {
ips = []net.IP{ip}
} else {
resolved, err := net.LookupIP(host)
if err != nil {
return errors.New("webhook: destination not allowed")
}
ips = resolved
}
for _, ip := range ips {
if isBlockedIP(ip) {
return errors.New("webhook: destination not allowed")
}
}
return nil
}
// cgnatBlock is the shared address space reserved for carrier-grade NAT
// (RFC 6598, 100.64.0.0/10). net.IP.IsPrivate() only covers RFC1918/RFC4193
// and does not treat this range as private, so it must be checked
// explicitly or CGNAT-addressed internal services would be reachable via
// webhook SSRF.
var cgnatBlock = func() *net.IPNet {
_, block, err := net.ParseCIDR("100.64.0.0/10")
if err != nil {
panic(err)
}
return block
}()
// isBlockedIP reports whether ip must never be connected to: loopback,
// private (RFC1918 etc.), link-local, unspecified, multicast, or
// carrier-grade NAT (RFC 6598). Used both by isAllowedURL's pre-request
// check and by dialControl's per-connection check.
func isBlockedIP(ip net.IP) bool {
if v4 := ip.To4(); v4 != nil {
ip = v4
}
return ip.IsLoopback() ||
ip.IsPrivate() ||
ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() ||
ip.IsUnspecified() ||
ip.IsMulticast() ||
cgnatBlock.Contains(ip)
}
// dialControl returns a net.Dialer.Control function enforcing the SSRF guard
// on the literal address ("ip:port") that net/http is about to connect to.
// It runs after any DNS resolution net/http performs internally — including
// resolution done independently of, and possibly later than, isAllowedURL's
// own lookup — so it sees the real connecting IP and closes the DNS-rebinding
// TOCTOU window described on Webhook.
//
// allowPrivate disables the check entirely; it exists so tests can dial
// httptest servers, which listen on loopback.
func dialControl(allowPrivate bool) func(network, address string, c syscall.RawConn) error {
return func(network, address string, c syscall.RawConn) error {
if allowPrivate {
return nil
}
host, _, err := net.SplitHostPort(address)
if err != nil {
return errors.New("webhook: destination not allowed")
}
ip := net.ParseIP(host)
if ip == nil {
return errors.New("webhook: destination not allowed")
}
if isBlockedIP(ip) {
return errors.New("webhook: destination not allowed")
}
return nil
}
}
// newWebhookTransport builds an http.Transport whose dialer enforces the
// SSRF guard on the actual address being connected to, for every connection
// it opens (see dialControl). This is the guard of record; isAllowedURL is
// only a fast pre-request rejection layered in front of it.
func newWebhookTransport(allowPrivate bool) *http.Transport {
dialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
Control: dialControl(allowPrivate),
}
t := http.DefaultTransport.(*http.Transport).Clone()
t.DialContext = dialer.DialContext
return t
}
func (w *Webhook) Send(ctx context.Context, cfg json.RawMessage, secret string, ev Event) error {
var c struct {
URL string `json:"url"`
}
if err := json.Unmarshal(cfg, &c); err != nil {
return err
}
if !w.allowPrivate {
if err := isAllowedURL(c.URL); err != nil {
return err
}
}
body, err := json.Marshal(ev)
if err != nil {
return err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.URL, bytes.NewReader(body))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
client := w.HTTP
if client.CheckRedirect == nil {
clientCopy := *client
clientCopy.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
client = &clientCopy
}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("webhook: status %d", resp.StatusCode)
}
return nil
}
+211
View File
@@ -0,0 +1,211 @@
// Package scheduler runs an in-process loop that periodically checks every
// domain of every due project schedule, records the resulting status, and
// notifies configured channels on meaningful status transitions.
package scheduler
import (
"context"
"fmt"
"log"
"time"
"github.com/google/uuid"
"github.com/vasyakrg/dns-autoresolver/internal/diff"
"github.com/vasyakrg/dns-autoresolver/internal/metrics"
"github.com/vasyakrg/dns-autoresolver/internal/notify"
"github.com/vasyakrg/dns-autoresolver/internal/store"
)
// Domain check statuses persisted via SchedStore.SetDomainStatus /
// surfaced via GetDomainStatus. "unknown" is the DB default for a domain
// that has never been checked (see migrations/0004_schedule_notify.sql).
const (
StatusUnknown = "unknown"
StatusInSync = "in_sync"
StatusDrift = "drift"
StatusError = "error"
)
// SchedStore is the narrow store dependency the scheduler needs: due
// schedules, their domains, and per-domain status bookkeeping. Persisting
// the check result itself (check_runs) is the Checker's job — see Checker
// below — not the scheduler's.
type SchedStore interface {
ListDueSchedules(ctx context.Context, now time.Time) ([]store.Schedule, error)
TouchScheduleRun(ctx context.Context, projectID uuid.UUID, at time.Time) error
ListDomains(ctx context.Context, projectID uuid.UUID) ([]store.Domain, error)
GetDomainStatus(ctx context.Context, domainID uuid.UUID) (string, error)
SetDomainStatus(ctx context.Context, domainID uuid.UUID, status string) error
CountDriftDomains(ctx context.Context) (int, error)
}
// Checker computes the diff between a domain's desired template and its
// actual zone state. internal/service.DomainService satisfies this.
type Checker interface {
Check(ctx context.Context, projectID, domainID uuid.UUID) (diff.Changeset, error)
}
// NotifySender delivers a status-change event to a project's notification
// channels. internal/notify.Dispatcher satisfies this.
type NotifySender interface {
Send(ctx context.Context, projectID uuid.UUID, ev notify.Event) error
}
// Scheduler drives periodic domain checks for every due project schedule.
type Scheduler struct {
store SchedStore
checker Checker
notifier NotifySender
metrics *metrics.Metrics
}
// New builds a Scheduler wired with its store, checker, notifier and metrics
// dependencies.
func New(store SchedStore, checker Checker, notifier NotifySender, m *metrics.Metrics) *Scheduler {
return &Scheduler{store: store, checker: checker, notifier: notifier, metrics: m}
}
// Run ticks every `tick` and calls RunOnce until ctx is cancelled. A failed
// iteration is logged, never fatal — the loop keeps ticking so a transient
// store/provider outage does not permanently stop future checks.
func (s *Scheduler) Run(ctx context.Context, tick time.Duration) {
ticker := time.NewTicker(tick)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.RunOnce(ctx, time.Now()); err != nil {
log.Printf("scheduler: run once failed: %v", err)
}
}
}
}
// RunOnce performs a single scheduling pass: every due project schedule is
// checked, each of its domains is diffed against its template, its status
// is updated, and channels are notified on a meaningful status transition.
func (s *Scheduler) RunOnce(ctx context.Context, now time.Time) error {
due, err := s.store.ListDueSchedules(ctx, now)
if err != nil {
return fmt.Errorf("list due schedules: %w", err)
}
for _, sch := range due {
domains, err := s.store.ListDomains(ctx, sch.ProjectID)
if err != nil {
log.Printf("scheduler: list domains for project %s failed: %v", sch.ProjectID, err)
continue
}
for _, d := range domains {
// A domain with no template attached is not yet configured for
// checking (a valid, expected state right after import) — not a
// failure. Checking it would make LoadDomain return "domain has
// no template", turning into a StatusError that spams a
// notification and shows a red badge for a domain the user
// simply hasn't set up yet. Skip it silently: no check, no
// status change, no notification.
if d.TemplateID == nil {
continue
}
s.checkDomain(ctx, sch.ProjectID, d, now)
}
if err := s.store.TouchScheduleRun(ctx, sch.ProjectID, now); err != nil {
log.Printf("scheduler: touch schedule run for project %s failed: %v", sch.ProjectID, err)
}
}
// The real, system-wide count of drift domains — not a local
// accumulator scoped to this tick's due projects — so the gauge
// reflects reality even across ticks where different projects are due.
count, err := s.store.CountDriftDomains(ctx)
if err != nil {
log.Printf("scheduler: count drift domains failed: %v", err)
} else {
s.metrics.SetDrift(count)
}
return nil
}
// checkDomain runs a single domain's check, persists the outcome, and fires
// a notification if the status transition warrants one. It returns the new
// status.
func (s *Scheduler) checkDomain(ctx context.Context, projectID uuid.UUID, d store.Domain, now time.Time) string {
start := time.Now()
cs, checkErr := s.checker.Check(ctx, projectID, d.ID)
dur := time.Since(start)
newStatus := StatusInSync
switch {
case checkErr != nil:
newStatus = StatusError
case len(cs.Actionable()) > 0:
newStatus = StatusDrift
}
s.metrics.ObserveCheck(newStatus, dur)
prev, err := s.store.GetDomainStatus(ctx, d.ID)
if err != nil {
log.Printf("scheduler: get domain status for %s failed: %v", d.ID, err)
prev = StatusUnknown
}
// Persisting the check_runs row is the Checker's job: DomainService.Check
// already calls Recorder.SaveCheckRun internally on every successful
// check (drift or in_sync). Calling it again here would double-write
// check_runs history for the same check.
if err := s.store.SetDomainStatus(ctx, d.ID, newStatus); err != nil {
log.Printf("scheduler: set domain status for %s failed: %v", d.ID, err)
}
if shouldNotify(prev, newStatus) {
ev := notify.Event{
Project: projectID.String(),
Domain: d.ID.String(),
Status: newStatus,
Summary: summarize(newStatus, cs, checkErr),
At: now,
}
if err := s.notifier.Send(ctx, projectID, ev); err != nil {
log.Printf("scheduler: notify send for project %s domain %s failed: %v", projectID, d.ID, err)
}
s.metrics.IncNotification("dispatch", newStatus)
}
return newStatus
}
// shouldNotify decides whether a prev -> new status transition is worth
// alerting on:
// - entering drift or error from any other status is always notified;
// - recovering from drift OR error back to in_sync ("resolved") is
// notified — including recovery after a provider/check failure;
// - the initial unknown -> in_sync transition (first successful check of a
// domain that never drifted or errored) is NOT notified — it is not
// news, it is the expected steady state.
func shouldNotify(prev, newStatus string) bool {
if (newStatus == StatusDrift || newStatus == StatusError) && newStatus != prev {
return true
}
if (prev == StatusDrift || prev == StatusError) && newStatus == StatusInSync {
return true
}
return false
}
// summarize builds a short, secret-free human-readable message for an Event.
func summarize(status string, cs diff.Changeset, checkErr error) string {
if checkErr != nil {
return fmt.Sprintf("check failed: %v", checkErr)
}
if status == StatusDrift {
return fmt.Sprintf("%d actionable diff(s) detected", len(cs.Actionable()))
}
return "zone back in sync with template"
}
+321
View File
@@ -0,0 +1,321 @@
package scheduler
import (
"context"
"errors"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/vasyakrg/dns-autoresolver/internal/diff"
"github.com/vasyakrg/dns-autoresolver/internal/metrics"
"github.com/vasyakrg/dns-autoresolver/internal/notify"
"github.com/vasyakrg/dns-autoresolver/internal/store"
)
// mockStore is an in-memory SchedStore double.
type mockStore struct {
mu sync.Mutex
schedules []store.Schedule
domains map[uuid.UUID][]store.Domain
status map[uuid.UUID]string
touchedProjects []uuid.UUID
// driftCount is what CountDriftDomains returns — a canned system-wide
// count, independent of what this RunOnce's due projects touched.
driftCount int
}
func newMockStore() *mockStore {
return &mockStore{
domains: make(map[uuid.UUID][]store.Domain),
status: make(map[uuid.UUID]string),
}
}
func (m *mockStore) ListDueSchedules(ctx context.Context, now time.Time) ([]store.Schedule, error) {
return m.schedules, nil
}
func (m *mockStore) TouchScheduleRun(ctx context.Context, projectID uuid.UUID, at time.Time) error {
m.mu.Lock()
defer m.mu.Unlock()
m.touchedProjects = append(m.touchedProjects, projectID)
return nil
}
func (m *mockStore) ListDomains(ctx context.Context, projectID uuid.UUID) ([]store.Domain, error) {
return m.domains[projectID], nil
}
func (m *mockStore) GetDomainStatus(ctx context.Context, domainID uuid.UUID) (string, error) {
m.mu.Lock()
defer m.mu.Unlock()
if st, ok := m.status[domainID]; ok {
return st, nil
}
return StatusUnknown, nil
}
func (m *mockStore) SetDomainStatus(ctx context.Context, domainID uuid.UUID, status string) error {
m.mu.Lock()
defer m.mu.Unlock()
m.status[domainID] = status
return nil
}
func (m *mockStore) CountDriftDomains(ctx context.Context) (int, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.driftCount, nil
}
// mockChecker returns a preset Changeset or error per domainID, and records
// which domain IDs it was called with.
type mockChecker struct {
mu sync.Mutex
results map[uuid.UUID]diff.Changeset
errs map[uuid.UUID]error
calls []uuid.UUID
}
func (c *mockChecker) Check(ctx context.Context, projectID, domainID uuid.UUID) (diff.Changeset, error) {
c.mu.Lock()
c.calls = append(c.calls, domainID)
c.mu.Unlock()
if err, ok := c.errs[domainID]; ok {
return diff.Changeset{}, err
}
return c.results[domainID], nil
}
// mockNotifier records every Event it is asked to Send.
type mockNotifier struct {
mu sync.Mutex
events []notify.Event
}
func (n *mockNotifier) Send(ctx context.Context, projectID uuid.UUID, ev notify.Event) error {
n.mu.Lock()
defer n.mu.Unlock()
n.events = append(n.events, ev)
return nil
}
func (n *mockNotifier) count() int {
n.mu.Lock()
defer n.mu.Unlock()
return len(n.events)
}
func driftChangeset() diff.Changeset {
return diff.Changeset{Diffs: []diff.RecordDiff{{Kind: diff.Update, Name: "www"}}}
}
func TestRunOnce_NotifiesOnDriftNotOnFirstInSync(t *testing.T) {
projectID := uuid.New()
templateID := uuid.New()
domainA := store.Domain{ID: uuid.New(), ProjectID: projectID, TemplateID: &templateID}
domainB := store.Domain{ID: uuid.New(), ProjectID: projectID, TemplateID: &templateID}
st := newMockStore()
st.schedules = []store.Schedule{{ID: uuid.New(), ProjectID: projectID, IntervalSeconds: 3600, Enabled: true}}
st.domains[projectID] = []store.Domain{domainA, domainB}
checker := &mockChecker{
results: map[uuid.UUID]diff.Changeset{
domainA.ID: driftChangeset(),
domainB.ID: {},
},
}
notifier := &mockNotifier{}
m := metrics.New()
// CountDriftDomains is the real system-wide count, independent of what
// this tick touched — set it to something that would NOT match a local
// per-tick accumulator (only 1 of 2 domains here drifted) to prove the
// gauge comes from the store call, not a local tally.
st.driftCount = 7
sched := New(st, checker, notifier, m)
if err := sched.RunOnce(context.Background(), time.Now()); err != nil {
t.Fatalf("RunOnce: %v", err)
}
if st.status[domainA.ID] != StatusDrift {
t.Fatalf("domain A status = %q, want drift", st.status[domainA.ID])
}
if st.status[domainB.ID] != StatusInSync {
t.Fatalf("domain B status = %q, want in_sync", st.status[domainB.ID])
}
if got := notifier.count(); got != 1 {
t.Fatalf("notifications sent = %d, want 1 (only domain A)", got)
}
if notifier.events[0].Domain != domainA.ID.String() {
t.Fatalf("notified domain = %q, want domain A (%s)", notifier.events[0].Domain, domainA.ID)
}
if notifier.events[0].Status != StatusDrift {
t.Fatalf("notified status = %q, want drift", notifier.events[0].Status)
}
if len(st.touchedProjects) != 1 || st.touchedProjects[0] != projectID {
t.Fatalf("TouchScheduleRun calls = %v, want [%s]", st.touchedProjects, projectID)
}
if got := testutil.ToFloat64(m.ChecksTotal.WithLabelValues(StatusDrift)); got != 1 {
t.Fatalf("ChecksTotal{drift} = %v, want 1", got)
}
if got := testutil.ToFloat64(m.ChecksTotal.WithLabelValues(StatusInSync)); got != 1 {
t.Fatalf("ChecksTotal{in_sync} = %v, want 1", got)
}
if got := testutil.ToFloat64(m.DriftDomains); got != float64(st.driftCount) {
t.Fatalf("DriftDomains gauge = %v, want %d (from CountDriftDomains)", got, st.driftCount)
}
}
func TestRunOnce_Idempotent_NoRepeatNotifyOnUnchangedDrift(t *testing.T) {
projectID := uuid.New()
templateID := uuid.New()
domainA := store.Domain{ID: uuid.New(), ProjectID: projectID, TemplateID: &templateID}
st := newMockStore()
st.schedules = []store.Schedule{{ID: uuid.New(), ProjectID: projectID, IntervalSeconds: 3600, Enabled: true}}
st.domains[projectID] = []store.Domain{domainA}
checker := &mockChecker{
results: map[uuid.UUID]diff.Changeset{domainA.ID: driftChangeset()},
}
notifier := &mockNotifier{}
m := metrics.New()
sched := New(st, checker, notifier, m)
if err := sched.RunOnce(context.Background(), time.Now()); err != nil {
t.Fatalf("first RunOnce: %v", err)
}
if got := notifier.count(); got != 1 {
t.Fatalf("after first run notifications = %d, want 1", got)
}
if err := sched.RunOnce(context.Background(), time.Now()); err != nil {
t.Fatalf("second RunOnce: %v", err)
}
if got := notifier.count(); got != 1 {
t.Fatalf("after second run (drift->drift) notifications = %d, want still 1 (no repeat)", got)
}
}
func TestRunOnce_CheckError_StatusErrorAndNotify(t *testing.T) {
projectID := uuid.New()
templateID := uuid.New()
domainA := store.Domain{ID: uuid.New(), ProjectID: projectID, TemplateID: &templateID}
st := newMockStore()
st.schedules = []store.Schedule{{ID: uuid.New(), ProjectID: projectID, IntervalSeconds: 3600, Enabled: true}}
st.domains[projectID] = []store.Domain{domainA}
checker := &mockChecker{
errs: map[uuid.UUID]error{domainA.ID: errors.New("provider timeout")},
}
notifier := &mockNotifier{}
m := metrics.New()
sched := New(st, checker, notifier, m)
if err := sched.RunOnce(context.Background(), time.Now()); err != nil {
t.Fatalf("RunOnce: %v", err)
}
if st.status[domainA.ID] != StatusError {
t.Fatalf("domain A status = %q, want error", st.status[domainA.ID])
}
if got := notifier.count(); got != 1 {
t.Fatalf("notifications = %d, want 1 (unknown->error)", got)
}
if notifier.events[0].Status != StatusError {
t.Fatalf("notified status = %q, want error", notifier.events[0].Status)
}
if got := testutil.ToFloat64(m.ChecksTotal.WithLabelValues(StatusError)); got != 1 {
t.Fatalf("ChecksTotal{error} = %v, want 1", got)
}
}
func TestRunOnce_SkipsDomainWithoutTemplate(t *testing.T) {
projectID := uuid.New()
templateID := uuid.New()
domainNoTemplate := store.Domain{ID: uuid.New(), ProjectID: projectID, TemplateID: nil}
domainWithTemplate := store.Domain{ID: uuid.New(), ProjectID: projectID, TemplateID: &templateID}
st := newMockStore()
st.schedules = []store.Schedule{{ID: uuid.New(), ProjectID: projectID, IntervalSeconds: 3600, Enabled: true}}
st.domains[projectID] = []store.Domain{domainNoTemplate, domainWithTemplate}
checker := &mockChecker{
results: map[uuid.UUID]diff.Changeset{domainWithTemplate.ID: {}},
}
notifier := &mockNotifier{}
m := metrics.New()
sched := New(st, checker, notifier, m)
if err := sched.RunOnce(context.Background(), time.Now()); err != nil {
t.Fatalf("RunOnce: %v", err)
}
for _, id := range checker.calls {
if id == domainNoTemplate.ID {
t.Fatalf("Checker.Check was called for templateless domain %s, want skipped", id)
}
}
if len(checker.calls) != 1 || checker.calls[0] != domainWithTemplate.ID {
t.Fatalf("Checker.Check calls = %v, want exactly [%s]", checker.calls, domainWithTemplate.ID)
}
if _, ok := st.status[domainNoTemplate.ID]; ok {
t.Fatalf("templateless domain status = %q, want no status set (never checked)", st.status[domainNoTemplate.ID])
}
if st.status[domainWithTemplate.ID] != StatusInSync {
t.Fatalf("domain with template status = %q, want in_sync", st.status[domainWithTemplate.ID])
}
if got := notifier.count(); got != 0 {
t.Fatalf("notifications sent = %d, want 0 (templateless skip is silent, and template domain unknown->in_sync is not news)", got)
}
if got := testutil.ToFloat64(m.ChecksTotal.WithLabelValues(StatusInSync)); got != 1 {
t.Fatalf("ChecksTotal{in_sync} = %v, want 1 (only the templated domain was checked)", got)
}
}
func TestShouldNotify(t *testing.T) {
cases := []struct {
name string
prev string
new string
want bool
}{
{"unknown->drift notifies", StatusUnknown, StatusDrift, true},
{"unknown->error notifies", StatusUnknown, StatusError, true},
{"unknown->in_sync is silent (first sync is not news)", StatusUnknown, StatusInSync, false},
{"drift->drift does not repeat", StatusDrift, StatusDrift, false},
{"error->error does not repeat", StatusError, StatusError, false},
{"drift->in_sync notifies (resolved)", StatusDrift, StatusInSync, true},
{"in_sync->drift notifies", StatusInSync, StatusDrift, true},
{"in_sync->error notifies", StatusInSync, StatusError, true},
{"in_sync->in_sync is silent", StatusInSync, StatusInSync, false},
{"error->drift notifies (still bad, different bad)", StatusError, StatusDrift, true},
{"error->in_sync notifies (resolved after failure)", StatusError, StatusInSync, true},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
if got := shouldNotify(tc.prev, tc.new); got != tc.want {
t.Fatalf("shouldNotify(%q, %q) = %v, want %v", tc.prev, tc.new, got, tc.want)
}
})
}
}
+148
View File
@@ -0,0 +1,148 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.31.1
// source: channels.sql
package db
import (
"context"
"github.com/google/uuid"
)
const createChannel = `-- name: CreateChannel :one
INSERT INTO notification_channels (id, project_id, type, config, secret_enc)
VALUES ($1, $2, $3, $4, $5) RETURNING id, project_id, type, config, secret_enc, enabled, created_at
`
type CreateChannelParams struct {
ID uuid.UUID `json:"id"`
ProjectID uuid.UUID `json:"project_id"`
Type string `json:"type"`
Config []byte `json:"config"`
SecretEnc string `json:"secret_enc"`
}
func (q *Queries) CreateChannel(ctx context.Context, arg CreateChannelParams) (NotificationChannel, error) {
row := q.db.QueryRow(ctx, createChannel,
arg.ID,
arg.ProjectID,
arg.Type,
arg.Config,
arg.SecretEnc,
)
var i NotificationChannel
err := row.Scan(
&i.ID,
&i.ProjectID,
&i.Type,
&i.Config,
&i.SecretEnc,
&i.Enabled,
&i.CreatedAt,
)
return i, err
}
const deleteChannel = `-- name: DeleteChannel :exec
DELETE FROM notification_channels WHERE id = $1 AND project_id = $2
`
type DeleteChannelParams struct {
ID uuid.UUID `json:"id"`
ProjectID uuid.UUID `json:"project_id"`
}
func (q *Queries) DeleteChannel(ctx context.Context, arg DeleteChannelParams) error {
_, err := q.db.Exec(ctx, deleteChannel, arg.ID, arg.ProjectID)
return err
}
const getChannel = `-- name: GetChannel :one
SELECT id, project_id, type, config, secret_enc, enabled, created_at FROM notification_channels WHERE id = $1 AND project_id = $2
`
type GetChannelParams struct {
ID uuid.UUID `json:"id"`
ProjectID uuid.UUID `json:"project_id"`
}
func (q *Queries) GetChannel(ctx context.Context, arg GetChannelParams) (NotificationChannel, error) {
row := q.db.QueryRow(ctx, getChannel, arg.ID, arg.ProjectID)
var i NotificationChannel
err := row.Scan(
&i.ID,
&i.ProjectID,
&i.Type,
&i.Config,
&i.SecretEnc,
&i.Enabled,
&i.CreatedAt,
)
return i, err
}
const listChannels = `-- name: ListChannels :many
SELECT id, project_id, type, config, secret_enc, enabled, created_at FROM notification_channels WHERE project_id = $1 ORDER BY created_at
`
func (q *Queries) ListChannels(ctx context.Context, projectID uuid.UUID) ([]NotificationChannel, error) {
rows, err := q.db.Query(ctx, listChannels, projectID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []NotificationChannel
for rows.Next() {
var i NotificationChannel
if err := rows.Scan(
&i.ID,
&i.ProjectID,
&i.Type,
&i.Config,
&i.SecretEnc,
&i.Enabled,
&i.CreatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listEnabledChannels = `-- name: ListEnabledChannels :many
SELECT id, project_id, type, config, secret_enc, enabled, created_at FROM notification_channels WHERE project_id = $1 AND enabled ORDER BY created_at
`
func (q *Queries) ListEnabledChannels(ctx context.Context, projectID uuid.UUID) ([]NotificationChannel, error) {
rows, err := q.db.Query(ctx, listEnabledChannels, projectID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []NotificationChannel
for rows.Next() {
var i NotificationChannel
if err := rows.Scan(
&i.ID,
&i.ProjectID,
&i.Type,
&i.Config,
&i.SecretEnc,
&i.Enabled,
&i.CreatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
+29
View File
@@ -34,3 +34,32 @@ func (q *Queries) CreateCheckRun(ctx context.Context, arg CreateCheckRunParams)
)
return i, err
}
const listCheckRuns = `-- name: ListCheckRuns :many
SELECT id, domain_id, result, created_at FROM check_runs WHERE domain_id = $1 ORDER BY created_at DESC LIMIT 50
`
func (q *Queries) ListCheckRuns(ctx context.Context, domainID uuid.UUID) ([]CheckRun, error) {
rows, err := q.db.Query(ctx, listCheckRuns, domainID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []CheckRun
for rows.Next() {
var i CheckRun
if err := rows.Scan(
&i.ID,
&i.DomainID,
&i.Result,
&i.CreatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
+46 -5
View File
@@ -12,10 +12,21 @@ import (
dto "github.com/vasyakrg/dns-autoresolver/internal/store/dto"
)
const countDriftDomains = `-- name: CountDriftDomains :one
SELECT count(*) FROM domains WHERE last_check_status = 'drift'
`
func (q *Queries) CountDriftDomains(ctx context.Context) (int64, error) {
row := q.db.QueryRow(ctx, countDriftDomains)
var count int64
err := row.Scan(&count)
return count, err
}
const createDomain = `-- name: CreateDomain :one
INSERT INTO domains (id, project_id, provider_account_id, zone_name, zone_id, template_id)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, project_id, provider_account_id, zone_name, zone_id, template_id, created_at
RETURNING id, project_id, provider_account_id, zone_name, zone_id, template_id, created_at, last_check_status
`
type CreateDomainParams struct {
@@ -45,6 +56,7 @@ func (q *Queries) CreateDomain(ctx context.Context, arg CreateDomainParams) (Dom
&i.ZoneID,
&i.TemplateID,
&i.CreatedAt,
&i.LastCheckStatus,
)
return i, err
}
@@ -64,7 +76,7 @@ func (q *Queries) DeleteDomain(ctx context.Context, arg DeleteDomainParams) erro
}
const getDomain = `-- name: GetDomain :one
SELECT id, project_id, provider_account_id, zone_name, zone_id, template_id, created_at FROM domains WHERE id = $1 AND project_id = $2
SELECT id, project_id, provider_account_id, zone_name, zone_id, template_id, created_at, last_check_status FROM domains WHERE id = $1 AND project_id = $2
`
type GetDomainParams struct {
@@ -83,15 +95,27 @@ func (q *Queries) GetDomain(ctx context.Context, arg GetDomainParams) (Domain, e
&i.ZoneID,
&i.TemplateID,
&i.CreatedAt,
&i.LastCheckStatus,
)
return i, err
}
const getDomainStatus = `-- name: GetDomainStatus :one
SELECT last_check_status FROM domains WHERE id = $1
`
func (q *Queries) GetDomainStatus(ctx context.Context, id uuid.UUID) (string, error) {
row := q.db.QueryRow(ctx, getDomainStatus, id)
var last_check_status string
err := row.Scan(&last_check_status)
return last_check_status, err
}
const importDomain = `-- name: ImportDomain :one
INSERT INTO domains (id, project_id, provider_account_id, zone_name, zone_id, template_id)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (project_id, zone_id) DO NOTHING
RETURNING id, project_id, provider_account_id, zone_name, zone_id, template_id, created_at
RETURNING id, project_id, provider_account_id, zone_name, zone_id, template_id, created_at, last_check_status
`
type ImportDomainParams struct {
@@ -121,12 +145,13 @@ func (q *Queries) ImportDomain(ctx context.Context, arg ImportDomainParams) (Dom
&i.ZoneID,
&i.TemplateID,
&i.CreatedAt,
&i.LastCheckStatus,
)
return i, err
}
const listDomains = `-- name: ListDomains :many
SELECT id, project_id, provider_account_id, zone_name, zone_id, template_id, created_at FROM domains WHERE project_id = $1 ORDER BY created_at
SELECT id, project_id, provider_account_id, zone_name, zone_id, template_id, created_at, last_check_status FROM domains WHERE project_id = $1 ORDER BY created_at
`
func (q *Queries) ListDomains(ctx context.Context, projectID uuid.UUID) ([]Domain, error) {
@@ -146,6 +171,7 @@ func (q *Queries) ListDomains(ctx context.Context, projectID uuid.UUID) ([]Domai
&i.ZoneID,
&i.TemplateID,
&i.CreatedAt,
&i.LastCheckStatus,
); err != nil {
return nil, err
}
@@ -189,9 +215,23 @@ func (q *Queries) LoadDomainFull(ctx context.Context, arg LoadDomainFullParams)
return i, err
}
const setDomainStatus = `-- name: SetDomainStatus :exec
UPDATE domains SET last_check_status = $2 WHERE id = $1
`
type SetDomainStatusParams struct {
ID uuid.UUID `json:"id"`
LastCheckStatus string `json:"last_check_status"`
}
func (q *Queries) SetDomainStatus(ctx context.Context, arg SetDomainStatusParams) error {
_, err := q.db.Exec(ctx, setDomainStatus, arg.ID, arg.LastCheckStatus)
return err
}
const updateDomainTemplate = `-- name: UpdateDomainTemplate :one
UPDATE domains SET template_id = $3 WHERE id = $1 AND project_id = $2
RETURNING id, project_id, provider_account_id, zone_name, zone_id, template_id, created_at
RETURNING id, project_id, provider_account_id, zone_name, zone_id, template_id, created_at, last_check_status
`
type UpdateDomainTemplateParams struct {
@@ -211,6 +251,7 @@ func (q *Queries) UpdateDomainTemplate(ctx context.Context, arg UpdateDomainTemp
&i.ZoneID,
&i.TemplateID,
&i.CreatedAt,
&i.LastCheckStatus,
)
return i, err
}
+20
View File
@@ -25,6 +25,17 @@ type Domain struct {
ZoneID string `json:"zone_id"`
TemplateID *uuid.UUID `json:"template_id"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
LastCheckStatus string `json:"last_check_status"`
}
type NotificationChannel struct {
ID uuid.UUID `json:"id"`
ProjectID uuid.UUID `json:"project_id"`
Type string `json:"type"`
Config []byte `json:"config"`
SecretEnc string `json:"secret_enc"`
Enabled bool `json:"enabled"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
}
type Project struct {
@@ -43,6 +54,15 @@ type ProviderAccount struct {
CreatedAt pgtype.Timestamptz `json:"created_at"`
}
type Schedule struct {
ID uuid.UUID `json:"id"`
ProjectID uuid.UUID `json:"project_id"`
IntervalSeconds int32 `json:"interval_seconds"`
Enabled bool `json:"enabled"`
LastRunAt pgtype.Timestamptz `json:"last_run_at"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
}
type Session struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
+110
View File
@@ -0,0 +1,110 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.31.1
// source: schedules.sql
package db
import (
"context"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
)
const getSchedule = `-- name: GetSchedule :one
SELECT id, project_id, interval_seconds, enabled, last_run_at, created_at FROM schedules WHERE project_id = $1
`
func (q *Queries) GetSchedule(ctx context.Context, projectID uuid.UUID) (Schedule, error) {
row := q.db.QueryRow(ctx, getSchedule, projectID)
var i Schedule
err := row.Scan(
&i.ID,
&i.ProjectID,
&i.IntervalSeconds,
&i.Enabled,
&i.LastRunAt,
&i.CreatedAt,
)
return i, err
}
const listDueSchedules = `-- name: ListDueSchedules :many
SELECT id, project_id, interval_seconds, enabled, last_run_at, created_at FROM schedules
WHERE enabled AND (last_run_at IS NULL OR last_run_at + (interval_seconds || ' seconds')::interval <= $1)
`
func (q *Queries) ListDueSchedules(ctx context.Context, lastRunAt pgtype.Timestamptz) ([]Schedule, error) {
rows, err := q.db.Query(ctx, listDueSchedules, lastRunAt)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Schedule
for rows.Next() {
var i Schedule
if err := rows.Scan(
&i.ID,
&i.ProjectID,
&i.IntervalSeconds,
&i.Enabled,
&i.LastRunAt,
&i.CreatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const touchScheduleRun = `-- name: TouchScheduleRun :exec
UPDATE schedules SET last_run_at = $2 WHERE project_id = $1
`
type TouchScheduleRunParams struct {
ProjectID uuid.UUID `json:"project_id"`
LastRunAt pgtype.Timestamptz `json:"last_run_at"`
}
func (q *Queries) TouchScheduleRun(ctx context.Context, arg TouchScheduleRunParams) error {
_, err := q.db.Exec(ctx, touchScheduleRun, arg.ProjectID, arg.LastRunAt)
return err
}
const upsertSchedule = `-- name: UpsertSchedule :one
INSERT INTO schedules (id, project_id, interval_seconds, enabled)
VALUES ($1, $2, $3, $4)
ON CONFLICT (project_id) DO UPDATE SET interval_seconds = $3, enabled = $4
RETURNING id, project_id, interval_seconds, enabled, last_run_at, created_at
`
type UpsertScheduleParams struct {
ID uuid.UUID `json:"id"`
ProjectID uuid.UUID `json:"project_id"`
IntervalSeconds int32 `json:"interval_seconds"`
Enabled bool `json:"enabled"`
}
func (q *Queries) UpsertSchedule(ctx context.Context, arg UpsertScheduleParams) (Schedule, error) {
row := q.db.QueryRow(ctx, upsertSchedule,
arg.ID,
arg.ProjectID,
arg.IntervalSeconds,
arg.Enabled,
)
var i Schedule
err := row.Scan(
&i.ID,
&i.ProjectID,
&i.IntervalSeconds,
&i.Enabled,
&i.LastRunAt,
&i.CreatedAt,
)
return i, err
}
+35
View File
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
@@ -51,6 +52,40 @@ func (s *Store) SaveCheckRun(ctx context.Context, domainID uuid.UUID, cs diff.Ch
return err
}
// CheckRun is a provider-neutral summary of a past check/apply run, returned
// by ListCheckRuns for the domain history endpoint (Фаза 3).
type CheckRun struct {
ID uuid.UUID
DomainID uuid.UUID
Result json.RawMessage
CreatedAt time.Time
}
func checkRunFromDB(c db.CheckRun) CheckRun {
return CheckRun{
ID: c.ID,
DomainID: c.DomainID,
Result: json.RawMessage(c.Result),
CreatedAt: c.CreatedAt.Time,
}
}
// ListCheckRuns returns the most recent check_runs rows for a domain (newest
// first, capped at 50). Not scoped by project itself — callers must verify
// the domain belongs to the caller's project first (e.g. via GetDomain)
// since check_runs only references domain_id.
func (s *Store) ListCheckRuns(ctx context.Context, domainID uuid.UUID) ([]CheckRun, error) {
rows, err := s.q.ListCheckRuns(ctx, domainID)
if err != nil {
return nil, err
}
out := make([]CheckRun, 0, len(rows))
for _, r := range rows {
out = append(out, checkRunFromDB(r))
}
return out, nil
}
// compile-time interface checks
var _ service.Loader = (*Store)(nil)
var _ service.Recorder = (*Store)(nil)
@@ -0,0 +1,24 @@
-- +goose Up
CREATE TABLE schedules (
id uuid PRIMARY KEY,
project_id uuid NOT NULL UNIQUE REFERENCES projects(id) ON DELETE CASCADE,
interval_seconds int NOT NULL DEFAULT 3600,
enabled boolean NOT NULL DEFAULT false,
last_run_at timestamptz,
created_at timestamptz NOT NULL DEFAULT now()
);
CREATE TABLE notification_channels (
id uuid PRIMARY KEY,
project_id uuid NOT NULL REFERENCES projects(id) ON DELETE CASCADE,
type text NOT NULL,
config jsonb NOT NULL,
secret_enc text NOT NULL DEFAULT '',
enabled boolean NOT NULL DEFAULT true,
created_at timestamptz NOT NULL DEFAULT now()
);
ALTER TABLE domains ADD COLUMN last_check_status text NOT NULL DEFAULT 'unknown';
-- +goose Down
ALTER TABLE domains DROP COLUMN last_check_status;
DROP TABLE notification_channels;
DROP TABLE schedules;
+15
View File
@@ -0,0 +1,15 @@
-- name: CreateChannel :one
INSERT INTO notification_channels (id, project_id, type, config, secret_enc)
VALUES ($1, $2, $3, $4, $5) RETURNING *;
-- name: ListChannels :many
SELECT * FROM notification_channels WHERE project_id = $1 ORDER BY created_at;
-- name: ListEnabledChannels :many
SELECT * FROM notification_channels WHERE project_id = $1 AND enabled ORDER BY created_at;
-- name: GetChannel :one
SELECT * FROM notification_channels WHERE id = $1 AND project_id = $2;
-- name: DeleteChannel :exec
DELETE FROM notification_channels WHERE id = $1 AND project_id = $2;
+3
View File
@@ -2,3 +2,6 @@
INSERT INTO check_runs (id, domain_id, result)
VALUES ($1, $2, $3)
RETURNING *;
-- name: ListCheckRuns :many
SELECT * FROM check_runs WHERE domain_id = $1 ORDER BY created_at DESC LIMIT 50;
+9
View File
@@ -28,3 +28,12 @@ FROM domains d
JOIN provider_accounts a ON a.id = d.provider_account_id
LEFT JOIN templates t ON t.id = d.template_id
WHERE d.id = $1 AND d.project_id = $2;
-- name: GetDomainStatus :one
SELECT last_check_status FROM domains WHERE id = $1;
-- name: SetDomainStatus :exec
UPDATE domains SET last_check_status = $2 WHERE id = $1;
-- name: CountDriftDomains :one
SELECT count(*) FROM domains WHERE last_check_status = 'drift';
+15
View File
@@ -0,0 +1,15 @@
-- name: GetSchedule :one
SELECT * FROM schedules WHERE project_id = $1;
-- name: UpsertSchedule :one
INSERT INTO schedules (id, project_id, interval_seconds, enabled)
VALUES ($1, $2, $3, $4)
ON CONFLICT (project_id) DO UPDATE SET interval_seconds = $3, enabled = $4
RETURNING *;
-- name: ListDueSchedules :many
SELECT * FROM schedules
WHERE enabled AND (last_run_at IS NULL OR last_run_at + (interval_seconds || ' seconds')::interval <= $1);
-- name: TouchScheduleRun :exec
UPDATE schedules SET last_run_at = $2 WHERE project_id = $1;
+267
View File
@@ -0,0 +1,267 @@
package store
import (
"encoding/json"
"errors"
"testing"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
)
// TestUpsertSchedule_InsertThenUpdate verifies UpsertSchedule inserts a new
// row for a project on the first call and updates that same row (rather
// than inserting a second one) on a subsequent call, per the
// ON CONFLICT (project_id) DO UPDATE clause.
func TestUpsertSchedule_InsertThenUpdate(t *testing.T) {
s, ctx := newStore(t)
_, p, err := s.RegisterUser(ctx, "sched-upsert@example.com", "argon2-hash")
if err != nil {
t.Fatal(err)
}
created, err := s.UpsertSchedule(ctx, p.ID, 1800, true)
if err != nil {
t.Fatal(err)
}
if created.IntervalSeconds != 1800 || !created.Enabled {
t.Fatalf("unexpected created schedule: %+v", created)
}
updated, err := s.UpsertSchedule(ctx, p.ID, 7200, false)
if err != nil {
t.Fatal(err)
}
if updated.ID != created.ID {
t.Fatalf("expected same row id, got created=%s updated=%s", created.ID, updated.ID)
}
if updated.IntervalSeconds != 7200 || updated.Enabled {
t.Fatalf("unexpected updated schedule: %+v", updated)
}
got, err := s.GetSchedule(ctx, p.ID)
if err != nil {
t.Fatal(err)
}
if got.IntervalSeconds != 7200 || got.Enabled {
t.Fatalf("GetSchedule mismatch after update: %+v", got)
}
}
// TestGetSchedule_NoRowReturnsErrNoRows verifies the contract used by the API
// layer (Task 5): a project with no schedule row yet returns pgx.ErrNoRows,
// which the API translates into the default {interval:3600, enabled:false}.
func TestGetSchedule_NoRowReturnsErrNoRows(t *testing.T) {
s, ctx := newStore(t)
_, p, err := s.RegisterUser(ctx, "sched-norow@example.com", "argon2-hash")
if err != nil {
t.Fatal(err)
}
if _, err := s.GetSchedule(ctx, p.ID); !errors.Is(err, pgx.ErrNoRows) {
t.Fatalf("expected pgx.ErrNoRows, got %v", err)
}
}
// TestListDueSchedules verifies the due-selection logic: an enabled schedule
// that never ran (last_run_at IS NULL) is due; a disabled schedule is never
// due; and an enabled schedule that ran recently with a long interval is not
// yet due.
func TestListDueSchedules(t *testing.T) {
s, ctx := newStore(t)
now := time.Now().UTC()
_, neverRunProject, err := s.RegisterUser(ctx, "sched-neverrun@example.com", "argon2-hash")
if err != nil {
t.Fatal(err)
}
if _, err := s.UpsertSchedule(ctx, neverRunProject.ID, 3600, true); err != nil {
t.Fatal(err)
}
_, disabledProject, err := s.RegisterUser(ctx, "sched-disabled@example.com", "argon2-hash")
if err != nil {
t.Fatal(err)
}
if _, err := s.UpsertSchedule(ctx, disabledProject.ID, 60, false); err != nil {
t.Fatal(err)
}
_, recentProject, err := s.RegisterUser(ctx, "sched-recent@example.com", "argon2-hash")
if err != nil {
t.Fatal(err)
}
if _, err := s.UpsertSchedule(ctx, recentProject.ID, 3600, true); err != nil {
t.Fatal(err)
}
if err := s.TouchScheduleRun(ctx, recentProject.ID, now); err != nil {
t.Fatal(err)
}
due, err := s.ListDueSchedules(ctx, now)
if err != nil {
t.Fatal(err)
}
byProject := make(map[uuid.UUID]bool, len(due))
for _, d := range due {
byProject[d.ProjectID] = true
}
if !byProject[neverRunProject.ID] {
t.Errorf("expected enabled/never-run schedule for project %s to be due", neverRunProject.ID)
}
if byProject[disabledProject.ID] {
t.Errorf("did not expect disabled schedule for project %s to be due", disabledProject.ID)
}
if byProject[recentProject.ID] {
t.Errorf("did not expect recently-run schedule (long interval) for project %s to be due", recentProject.ID)
}
}
// TestTouchScheduleRun_SetsLastRunAt verifies TouchScheduleRun persists
// last_run_at, which GetSchedule then returns as a non-nil *time.Time close
// to the value passed in.
func TestTouchScheduleRun_SetsLastRunAt(t *testing.T) {
s, ctx := newStore(t)
_, p, err := s.RegisterUser(ctx, "sched-touch@example.com", "argon2-hash")
if err != nil {
t.Fatal(err)
}
if _, err := s.UpsertSchedule(ctx, p.ID, 3600, true); err != nil {
t.Fatal(err)
}
at := time.Now().UTC().Truncate(time.Second)
if err := s.TouchScheduleRun(ctx, p.ID, at); err != nil {
t.Fatal(err)
}
got, err := s.GetSchedule(ctx, p.ID)
if err != nil {
t.Fatal(err)
}
if got.LastRunAt == nil {
t.Fatal("expected non-nil LastRunAt after TouchScheduleRun")
}
if diff := got.LastRunAt.Sub(at); diff < -time.Second || diff > time.Second {
t.Fatalf("expected LastRunAt ~%v, got %v", at, *got.LastRunAt)
}
}
// TestChannelCRUD_ScopedByProject verifies CreateChannel/ListChannels/
// GetChannel/DeleteChannel round-trip correctly and that GetChannel scopes
// by project_id: looking up a channel with the wrong project ID must fail
// with pgx.ErrNoRows rather than returning another tenant's channel.
func TestChannelCRUD_ScopedByProject(t *testing.T) {
s, ctx := newStore(t)
_, p1, err := s.RegisterUser(ctx, "chan-owner@example.com", "argon2-hash")
if err != nil {
t.Fatal(err)
}
_, p2, err := s.RegisterUser(ctx, "chan-other@example.com", "argon2-hash")
if err != nil {
t.Fatal(err)
}
cfg := json.RawMessage(`{"webhook_url":"https://example.com/hook"}`)
ch, err := s.CreateChannel(ctx, p1.ID, "telegram", cfg, "enc-secret")
if err != nil {
t.Fatal(err)
}
// jsonb round-trips through Postgres with its own canonical formatting
// (e.g. a space after ':'), so compare decoded values rather than raw
// bytes.
var gotCfg, wantCfg map[string]string
if err := json.Unmarshal(ch.Config, &gotCfg); err != nil {
t.Fatalf("unmarshal returned config: %v", err)
}
if err := json.Unmarshal(cfg, &wantCfg); err != nil {
t.Fatalf("unmarshal expected config: %v", err)
}
if ch.Type != "telegram" || !ch.Enabled || gotCfg["webhook_url"] != wantCfg["webhook_url"] || ch.SecretEnc != "enc-secret" {
t.Fatalf("unexpected created channel: %+v", ch)
}
list, err := s.ListChannels(ctx, p1.ID)
if err != nil {
t.Fatal(err)
}
if len(list) != 1 || list[0].ID != ch.ID {
t.Fatalf("unexpected ListChannels result: %+v", list)
}
enabledList, err := s.ListEnabledChannels(ctx, p1.ID)
if err != nil {
t.Fatal(err)
}
if len(enabledList) != 1 || enabledList[0].ID != ch.ID {
t.Fatalf("unexpected ListEnabledChannels result: %+v", enabledList)
}
got, err := s.GetChannel(ctx, ch.ID, p1.ID)
if err != nil {
t.Fatal(err)
}
if got.ID != ch.ID {
t.Fatalf("GetChannel mismatch: %+v", got)
}
if _, err := s.GetChannel(ctx, ch.ID, p2.ID); !errors.Is(err, pgx.ErrNoRows) {
t.Fatalf("expected pgx.ErrNoRows for foreign project, got %v", err)
}
if err := s.DeleteChannel(ctx, ch.ID, p1.ID); err != nil {
t.Fatal(err)
}
if _, err := s.GetChannel(ctx, ch.ID, p1.ID); !errors.Is(err, pgx.ErrNoRows) {
t.Fatalf("expected pgx.ErrNoRows after delete, got %v", err)
}
}
// TestDomainStatus_RoundTrip verifies SetDomainStatus/GetDomainStatus
// round-trip, and that a freshly-imported domain defaults to "unknown" per
// the migration's DEFAULT 'unknown'.
func TestDomainStatus_RoundTrip(t *testing.T) {
s, ctx := newStore(t)
_, p, err := s.RegisterUser(ctx, "domain-status@example.com", "argon2-hash")
if err != nil {
t.Fatal(err)
}
acc, err := s.CreateAccount(ctx, p.ID, "selectel", "enc-blob", "test")
if err != nil {
t.Fatal(err)
}
d, err := s.CreateDomain(ctx, p.ID, acc.ID, "example.com", "zone-1", nil)
if err != nil {
t.Fatal(err)
}
status, err := s.GetDomainStatus(ctx, d.ID)
if err != nil {
t.Fatal(err)
}
if status != "unknown" {
t.Fatalf("expected default status 'unknown', got %q", status)
}
if err := s.SetDomainStatus(ctx, d.ID, "ok"); err != nil {
t.Fatal(err)
}
status, err = s.GetDomainStatus(ctx, d.ID)
if err != nil {
t.Fatal(err)
}
if status != "ok" {
t.Fatalf("expected status 'ok' after SetDomainStatus, got %q", status)
}
domains, err := s.ListDomains(ctx, p.ID)
if err != nil {
t.Fatal(err)
}
if len(domains) != 1 || domains[0].LastCheckStatus != "ok" {
t.Fatalf("expected ListDomains to reflect updated status: %+v", domains)
}
}
+193
View File
@@ -2,6 +2,7 @@ package store
import (
"context"
"encoding/json"
"errors"
"time"
@@ -137,12 +138,14 @@ type Domain struct {
ZoneName string
ZoneID string
TemplateID *uuid.UUID
LastCheckStatus string
}
func domainFromDB(d db.Domain) Domain {
return Domain{
ID: d.ID, ProjectID: d.ProjectID, ProviderAccountID: d.ProviderAccountID,
ZoneName: d.ZoneName, ZoneID: d.ZoneID, TemplateID: d.TemplateID,
LastCheckStatus: d.LastCheckStatus,
}
}
@@ -173,6 +176,17 @@ func (s *Store) DeleteDomain(ctx context.Context, id, projectID uuid.UUID) error
return s.q.DeleteDomain(ctx, db.DeleteDomainParams{ID: id, ProjectID: projectID})
}
// GetDomain is a scoped lookup used to verify a domain belongs to projectID
// before it's referenced elsewhere (e.g. history — check_runs isn't itself
// scoped by project, so callers must confirm domain ownership first).
func (s *Store) GetDomain(ctx context.Context, id, projectID uuid.UUID) (Domain, error) {
d, err := s.q.GetDomain(ctx, db.GetDomainParams{ID: id, ProjectID: projectID})
if err != nil {
return Domain{}, err
}
return domainFromDB(d), nil
}
// ImportDomains creates one domain per zone inside a single transaction: if
// any zone fails to be created, the whole batch is rolled back so callers
// never observe a partially-imported set of domains.
@@ -231,6 +245,27 @@ func (s *Store) SetDomainTemplate(ctx context.Context, domainID, projectID uuid.
return domainFromDB(d), nil
}
// GetDomainStatus returns the last known check status for a domain (Фаза 3
// scheduler/checker). Callers scope access to the domain themselves (e.g.
// via a prior GetDomain) — this lookup is by primary key alone.
func (s *Store) GetDomainStatus(ctx context.Context, domainID uuid.UUID) (string, error) {
return s.q.GetDomainStatus(ctx, domainID)
}
// SetDomainStatus records the outcome of the most recent check/apply run for
// a domain (e.g. "ok", "drift", "error").
func (s *Store) SetDomainStatus(ctx context.Context, domainID uuid.UUID, status string) error {
return s.q.SetDomainStatus(ctx, db.SetDomainStatusParams{ID: domainID, LastCheckStatus: status})
}
// CountDriftDomains returns the current number of domains system-wide whose
// last check status is "drift". This is a global count (not per-project) —
// it backs the dns_ar_drift_domains gauge, which is a system-level metric.
func (s *Store) CountDriftDomains(ctx context.Context) (int, error) {
n, err := s.q.CountDriftDomains(ctx)
return int(n), err
}
// User and Project are provider-neutral domain structs for the auth/tenant
// layer (Фаза 2), mirroring the Account/Template/Domain wrappers above so
// callers never need to import internal/store/db directly.
@@ -369,3 +404,161 @@ func (s *Store) RegisterUser(ctx context.Context, email, passwordHash string) (U
}
return toUser(dbu), toProject(dbp), nil
}
// Schedule and Channel are provider-neutral domain structs for the
// scheduler/notifications layer (Фаза 3), mirroring the wrappers above so
// callers never need to import internal/store/db or pgtype directly.
type Schedule struct {
ID uuid.UUID
ProjectID uuid.UUID
IntervalSeconds int32
Enabled bool
LastRunAt *time.Time
}
// timeFromTimestamptz converts a nullable pgtype.Timestamptz (schedules.last_run_at)
// into a *time.Time, nil when the column is NULL (schedule never ran).
func timeFromTimestamptz(t pgtype.Timestamptz) *time.Time {
if !t.Valid {
return nil
}
tt := t.Time
return &tt
}
// timestamptzFromTime is the inverse of timeFromTimestamptz, used to pass a
// Go time.Time (or nil) into a nullable timestamptz query parameter.
func timestamptzFromTime(t *time.Time) pgtype.Timestamptz {
if t == nil {
return pgtype.Timestamptz{}
}
return pgtype.Timestamptz{Time: *t, Valid: true}
}
func scheduleFromDB(s db.Schedule) Schedule {
return Schedule{
ID: s.ID,
ProjectID: s.ProjectID,
IntervalSeconds: s.IntervalSeconds,
Enabled: s.Enabled,
LastRunAt: timeFromTimestamptz(s.LastRunAt),
}
}
// GetSchedule looks up the schedule row for projectID. When no schedule has
// ever been created for the project it returns pgx.ErrNoRows unwrapped —
// the API layer (Task 5) is expected to treat that as the default schedule
// {interval: 3600, enabled: false} rather than an error.
func (s *Store) GetSchedule(ctx context.Context, projectID uuid.UUID) (Schedule, error) {
sc, err := s.q.GetSchedule(ctx, projectID)
if err != nil {
return Schedule{}, err
}
return scheduleFromDB(sc), nil
}
// UpsertSchedule creates or updates the (single, UNIQUE) schedule row for a
// project: an existing row has its interval/enabled flag updated in place
// rather than a second row being inserted.
func (s *Store) UpsertSchedule(ctx context.Context, projectID uuid.UUID, interval int32, enabled bool) (Schedule, error) {
sc, err := s.q.UpsertSchedule(ctx, db.UpsertScheduleParams{
ID: uuid.New(), ProjectID: projectID, IntervalSeconds: interval, Enabled: enabled,
})
if err != nil {
return Schedule{}, err
}
return scheduleFromDB(sc), nil
}
// ListDueSchedules returns every enabled schedule that is due to run at
// `now`: either it has never run (last_run_at IS NULL) or its interval has
// elapsed since the last run.
func (s *Store) ListDueSchedules(ctx context.Context, now time.Time) ([]Schedule, error) {
rows, err := s.q.ListDueSchedules(ctx, pgtype.Timestamptz{Time: now, Valid: true})
if err != nil {
return nil, err
}
out := make([]Schedule, 0, len(rows))
for _, r := range rows {
out = append(out, scheduleFromDB(r))
}
return out, nil
}
// TouchScheduleRun records that a project's schedule ran at `at`, so the
// next ListDueSchedules call excludes it until the interval elapses again.
func (s *Store) TouchScheduleRun(ctx context.Context, projectID uuid.UUID, at time.Time) error {
return s.q.TouchScheduleRun(ctx, db.TouchScheduleRunParams{
ProjectID: projectID, LastRunAt: timestamptzFromTime(&at),
})
}
type Channel struct {
ID uuid.UUID
ProjectID uuid.UUID
Type string
Config json.RawMessage
SecretEnc string
Enabled bool
}
// channelFromDB never logs SecretEnc — callers must not either (secrets are
// encrypted at rest, but the plaintext blob should still stay out of logs).
func channelFromDB(c db.NotificationChannel) Channel {
return Channel{
ID: c.ID, ProjectID: c.ProjectID, Type: c.Type,
Config: json.RawMessage(c.Config), SecretEnc: c.SecretEnc, Enabled: c.Enabled,
}
}
func (s *Store) CreateChannel(ctx context.Context, projectID uuid.UUID, ctype string, config json.RawMessage, secretEnc string) (Channel, error) {
c, err := s.q.CreateChannel(ctx, db.CreateChannelParams{
ID: uuid.New(), ProjectID: projectID, Type: ctype, Config: []byte(config), SecretEnc: secretEnc,
})
if err != nil {
return Channel{}, err
}
return channelFromDB(c), nil
}
func (s *Store) ListChannels(ctx context.Context, projectID uuid.UUID) ([]Channel, error) {
rows, err := s.q.ListChannels(ctx, projectID)
if err != nil {
return nil, err
}
out := make([]Channel, 0, len(rows))
for _, r := range rows {
out = append(out, channelFromDB(r))
}
return out, nil
}
// ListEnabledChannels returns only the channels a project has enabled — used
// by the notification dispatcher (Фаза 3) so disabled channels are silently
// skipped rather than filtered by every caller.
func (s *Store) ListEnabledChannels(ctx context.Context, projectID uuid.UUID) ([]Channel, error) {
rows, err := s.q.ListEnabledChannels(ctx, projectID)
if err != nil {
return nil, err
}
out := make([]Channel, 0, len(rows))
for _, r := range rows {
out = append(out, channelFromDB(r))
}
return out, nil
}
// GetChannel is scoped by projectID: a channel ID belonging to another
// tenant's project returns pgx.ErrNoRows rather than the foreign channel.
func (s *Store) GetChannel(ctx context.Context, id, projectID uuid.UUID) (Channel, error) {
c, err := s.q.GetChannel(ctx, db.GetChannelParams{ID: id, ProjectID: projectID})
if err != nil {
return Channel{}, err
}
return channelFromDB(c), nil
}
func (s *Store) DeleteChannel(ctx context.Context, id, projectID uuid.UUID) error {
return s.q.DeleteChannel(ctx, db.DeleteChannelParams{ID: id, ProjectID: projectID})
}
+4
View File
@@ -3,10 +3,12 @@ import { Routes, Route, Navigate } from "react-router-dom"
import { ProtectedRoute } from "@/auth/ProtectedRoute"
import { Layout } from "@/components/Layout"
import { AccountsPage } from "@/pages/AccountsPage"
import { ChannelsPage } from "@/pages/ChannelsPage"
import { DomainDiffPage } from "@/pages/DomainDiffPage"
import { DomainsPage } from "@/pages/DomainsPage"
import { LoginPage } from "@/pages/LoginPage"
import { RegisterPage } from "@/pages/RegisterPage"
import { SchedulePage } from "@/pages/SchedulePage"
import { TemplatesPage } from "@/pages/TemplatesPage"
// Every non-auth route shares the same guard + chrome; wrapping here keeps
@@ -29,6 +31,8 @@ export function App() {
<Route path="/domains/:id" element={<Protected><DomainDiffPage /></Protected>} />
<Route path="/accounts" element={<Protected><AccountsPage /></Protected>} />
<Route path="/templates" element={<Protected><TemplatesPage /></Protected>} />
<Route path="/schedule" element={<Protected><SchedulePage /></Protected>} />
<Route path="/channels" element={<Protected><ChannelsPage /></Protected>} />
</Routes>
)
}
+68
View File
@@ -128,4 +128,72 @@ describe("api client", () => {
expect(url).toBe(`/api/v1/projects/${PROJECT_ID}/domains/d1`)
expect((opts as RequestInit).method).toBe("PATCH")
})
describe("schedule", () => {
it("getSchedule(projectId) GETs /schedule", async () => {
const spy = mockFetch({ intervalSeconds: 3600, enabled: false })
await api.getSchedule(PROJECT_ID)
expect(spy).toHaveBeenCalledWith(
`/api/v1/projects/${PROJECT_ID}/schedule`,
expect.objectContaining({ method: "GET", credentials: "include" }),
)
})
it("putSchedule(projectId, {intervalSeconds,enabled}) PUTs /schedule", async () => {
const spy = mockFetch({ intervalSeconds: 120, enabled: true })
await api.putSchedule(PROJECT_ID, { intervalSeconds: 120, enabled: true })
const [url, opts] = spy.mock.calls[0]
expect(url).toBe(`/api/v1/projects/${PROJECT_ID}/schedule`)
expect((opts as RequestInit).method).toBe("PUT")
expect((opts as RequestInit).credentials).toBe("include")
expect(String((opts as RequestInit).body)).toContain("intervalSeconds")
})
})
describe("channels", () => {
it("listChannels(projectId) GETs /channels", async () => {
const spy = mockFetch([])
await api.listChannels(PROJECT_ID)
expect(spy).toHaveBeenCalledWith(
`/api/v1/projects/${PROJECT_ID}/channels`,
expect.objectContaining({ method: "GET" }),
)
})
it("createChannel(projectId, {type,config,secret}) POSTs /channels with secret in body", async () => {
const spy = mockFetch({ id: "c1", type: "telegram", config: { chat_id: "1" }, enabled: true })
await api.createChannel(PROJECT_ID, { type: "telegram", config: { chat_id: "1" }, secret: "BOT_TOKEN" })
const [url, opts] = spy.mock.calls[0]
expect(url).toBe(`/api/v1/projects/${PROJECT_ID}/channels`)
expect((opts as RequestInit).method).toBe("POST")
expect(String((opts as RequestInit).body)).toContain("BOT_TOKEN")
})
it("deleteChannel(projectId, id) DELETEs /channels/{id}", async () => {
const spy = mockFetch(undefined, true, 204)
await api.deleteChannel(PROJECT_ID, "c1")
expect(spy).toHaveBeenCalledWith(
`/api/v1/projects/${PROJECT_ID}/channels/c1`,
expect.objectContaining({ method: "DELETE" }),
)
})
it("testChannel(projectId, id) POSTs /channels/{id}/test", async () => {
const spy = mockFetch({ status: "ok" })
await api.testChannel(PROJECT_ID, "c1")
expect(spy).toHaveBeenCalledWith(
`/api/v1/projects/${PROJECT_ID}/channels/c1/test`,
expect.objectContaining({ method: "POST" }),
)
})
})
it("domainHistory(projectId, domainId) GETs /domains/{did}/history", async () => {
const spy = mockFetch([])
await api.domainHistory(PROJECT_ID, "d1")
expect(spy).toHaveBeenCalledWith(
`/api/v1/projects/${PROJECT_ID}/domains/d1/history`,
expect.objectContaining({ method: "GET", credentials: "include" }),
)
})
})
+15
View File
@@ -3,6 +3,7 @@ import type {
AuthState,
Account, CreateAccountInput, Template, CreateTemplateInput,
Domain, CreateDomainInput, ChangesetResponse, ApplyRequest,
Schedule, Channel, CreateChannelInput, CheckRun,
} from "./types"
export class UnauthorizedError extends Error {
@@ -77,4 +78,18 @@ export const api = {
req<ChangesetResponse>(projectPath(projectId, `/domains/${id}/check`)),
applyDomain: (projectId: string, id: string, body: ApplyRequest) =>
req<ChangesetResponse>(projectPath(projectId, `/domains/${id}/apply`), { method: "POST", body: JSON.stringify(body) }),
domainHistory: (projectId: string, id: string) =>
req<CheckRun[]>(projectPath(projectId, `/domains/${id}/history`)),
getSchedule: (projectId: string) => req<Schedule>(projectPath(projectId, "/schedule")),
putSchedule: (projectId: string, input: Schedule) =>
req<Schedule>(projectPath(projectId, "/schedule"), { method: "PUT", body: JSON.stringify(input) }),
listChannels: (projectId: string) => req<Channel[]>(projectPath(projectId, "/channels")),
createChannel: (projectId: string, input: CreateChannelInput) =>
req<Channel>(projectPath(projectId, "/channels"), { method: "POST", body: JSON.stringify(input) }),
deleteChannel: (projectId: string, id: string) =>
req<void>(projectPath(projectId, `/channels/${id}`), { method: "DELETE" }),
testChannel: (projectId: string, id: string) =>
req<{ status: string }>(projectPath(projectId, `/channels/${id}/test`), { method: "POST" }),
}
+8
View File
@@ -15,6 +15,7 @@ export interface Domain {
zoneName: string
zoneId: string
templateId?: string | null // Go omitempty: поле может отсутствовать (undefined), а не null
lastCheckStatus?: string // "unknown" | "in_sync" | "drift" | "error"
}
export interface CreateDomainInput {
providerAccountId: string
@@ -23,6 +24,13 @@ export interface CreateDomainInput {
templateId?: string | null
}
export interface Schedule { intervalSeconds: number; enabled: boolean }
export interface Channel { id: string; type: string; config: object; enabled: boolean }
export interface CreateChannelInput { type: string; config: object; secret: string }
export interface CheckRun { id?: string; createdAt: string; result: object }
export interface RecordView {
kind: string // add | update | delete | in_sync
type: string
+47
View File
@@ -0,0 +1,47 @@
import { render, screen } from "@testing-library/react"
import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
import { AuthProvider } from "@/auth/AuthContext"
import { api } from "@/api/client"
import { vi, beforeEach, test, expect } from "vitest"
import { DomainHistory } from "./DomainHistory"
const PROJECT_ID = "p1"
function renderComponent() {
const qc = new QueryClient()
return render(
<QueryClientProvider client={qc}>
<AuthProvider>
<DomainHistory domainId="d1" />
</AuthProvider>
</QueryClientProvider>,
)
}
beforeEach(() => {
vi.restoreAllMocks()
vi.spyOn(api.auth, "me").mockResolvedValue({
user: { id: "u1", email: "a@b.com" },
project: { id: PROJECT_ID, name: "Default" },
})
})
test("отрисовывает список проверок со сводкой updates/prunes", async () => {
vi.spyOn(api, "domainHistory").mockResolvedValue([
{ id: "r1", createdAt: "2026-07-01T10:00:00Z", result: { updates: 2, prunes: 1 } },
{ id: "r2", createdAt: "2026-06-30T10:00:00Z", result: { updates: 0, prunes: 0 } },
])
renderComponent()
expect(await screen.findByText(/updates:\s*2/i)).toBeInTheDocument()
expect(screen.getByText(/prunes:\s*1/i)).toBeInTheDocument()
})
test("пустое состояние при отсутствии истории", async () => {
vi.spyOn(api, "domainHistory").mockResolvedValue([])
renderComponent()
expect(await screen.findByText(/проверок пока нет/i)).toBeInTheDocument()
})
+66
View File
@@ -0,0 +1,66 @@
import { History, Loader2 } from "lucide-react"
import { useDomainHistory } from "@/hooks/useApi"
// check_runs.result is a provider-neutral JSON summary written by
// store.SaveCheckRun (Фаза 3, T1/T4): {"updates": N, "prunes": N}. We read it
// defensively since it's typed as `object` end-to-end and older/foreign rows
// could in principle omit a key.
function summarize(result: object): string {
const r = result as Record<string, unknown>
const updates = typeof r.updates === "number" ? r.updates : 0
const prunes = typeof r.prunes === "number" ? r.prunes : 0
return `updates: ${updates} · prunes: ${prunes}`
}
function formatTimestamp(iso: string): string {
const date = new Date(iso)
if (Number.isNaN(date.getTime())) return iso
return date.toLocaleString("ru-RU", {
day: "2-digit",
month: "2-digit",
year: "numeric",
hour: "2-digit",
minute: "2-digit",
})
}
export function DomainHistory({ domainId }: { domainId: string }) {
const history = useDomainHistory(domainId)
const runs = history.data ?? []
return (
<section className="flex flex-col gap-2">
<header className="flex items-center gap-2 px-0.5">
<History className="size-3.5 text-muted-foreground" strokeWidth={1.75} />
<h2 className="text-xs font-semibold tracking-wide text-foreground uppercase">
История проверок
</h2>
</header>
{history.isPending ? (
<div className="flex items-center gap-2 rounded-lg border border-border bg-card px-4 py-4 text-sm text-muted-foreground">
<Loader2 className="size-4 animate-spin" strokeWidth={1.75} />
Загружаю историю
</div>
) : runs.length === 0 ? (
<p className="rounded-lg border border-dashed border-border px-3 py-2.5 text-sm text-muted-foreground/70">
Проверок пока нет история появится после первого запуска планировщика.
</p>
) : (
<div className="flex flex-col divide-y divide-border rounded-lg bg-card ring-1 ring-border">
{runs.map((run, i) => (
<div
key={run.id ?? i}
className="flex items-center justify-between gap-3 px-3 py-2 text-sm"
>
<span className="font-dns text-xs text-muted-foreground">
{formatTimestamp(run.createdAt)}
</span>
<span className="font-dns text-xs text-foreground">{summarize(run.result)}</span>
</div>
))}
</div>
)}
</section>
)
}
+3 -1
View File
@@ -1,6 +1,6 @@
import type { ReactNode } from "react"
import { NavLink, useLocation, useNavigate } from "react-router-dom"
import { Globe, LogOut, Users, LayoutTemplate, SquareTerminal } from "lucide-react"
import { BellRing, CalendarClock, Globe, LogOut, Users, LayoutTemplate, SquareTerminal } from "lucide-react"
import { useAuth } from "@/auth/AuthContext"
import { Button } from "@/components/ui/button"
import { cn } from "@/lib/utils"
@@ -9,6 +9,8 @@ const NAV = [
{ to: "/domains", label: "Domains", icon: Globe },
{ to: "/accounts", label: "Accounts", icon: Users },
{ to: "/templates", label: "Templates", icon: LayoutTemplate },
{ to: "/schedule", label: "Schedule", icon: CalendarClock },
{ to: "/channels", label: "Channels", icon: BellRing },
] as const
export function Layout({ children }: { children: ReactNode }) {
+39
View File
@@ -0,0 +1,39 @@
import { render, screen } from "@testing-library/react"
import { test, expect } from "vitest"
import { StatusBadge } from "./StatusBadge"
test("in_sync — emerald, текст «in sync»", () => {
render(<StatusBadge status="in_sync" />)
expect(screen.getByText("in sync")).toBeInTheDocument()
const badge = screen.getByText("in sync").closest('[data-slot="status-badge"]')
expect(badge).toHaveAttribute("data-status", "in_sync")
expect(screen.getByTestId("status-dot")).toHaveStyle({ background: "var(--diff-add)" })
})
test("drift — amber, текст «drift»", () => {
render(<StatusBadge status="drift" />)
expect(screen.getByText("drift")).toBeInTheDocument()
expect(screen.getByTestId("status-dot")).toHaveStyle({ background: "var(--diff-update)" })
})
test("error — rose, текст «error»", () => {
render(<StatusBadge status="error" />)
expect(screen.getByText("error")).toBeInTheDocument()
expect(screen.getByTestId("status-dot")).toHaveStyle({ background: "var(--diff-delete)" })
})
test("unknown — muted, текст «unknown»", () => {
render(<StatusBadge status="unknown" />)
expect(screen.getByText("unknown")).toBeInTheDocument()
expect(screen.getByTestId("status-dot")).toHaveStyle({ background: "var(--diff-readonly)" })
})
test("отсутствие статуса трактуется как unknown", () => {
render(<StatusBadge />)
expect(screen.getByText("unknown")).toBeInTheDocument()
})
test("неизвестное значение статуса не падает и рендерит unknown", () => {
render(<StatusBadge status="bogus" />)
expect(screen.getByText("unknown")).toBeInTheDocument()
})
+42
View File
@@ -0,0 +1,42 @@
import { Badge } from "@/components/ui/badge"
import { cn } from "@/lib/utils"
// Mirrors backend check status (store.Domain.LastCheckStatus / T4-T5):
// unknown | in_sync | drift | error. Colors reuse the diff-* tokens already
// established for the domain-diff console so a drifted zone reads the same
// "amber" whether you're looking at the list or the diff view.
export type CheckStatus = "unknown" | "in_sync" | "drift" | "error"
const STATUS_META: Record<CheckStatus, { label: string; color: string }> = {
in_sync: { label: "in sync", color: "var(--diff-add)" },
drift: { label: "drift", color: "var(--diff-update)" },
error: { label: "error", color: "var(--diff-delete)" },
unknown: { label: "unknown", color: "var(--diff-readonly)" },
}
function resolveStatus(status?: string): CheckStatus {
if (status === "in_sync" || status === "drift" || status === "error") return status
return "unknown"
}
export function StatusBadge({ status, className }: { status?: string; className?: string }) {
const resolved = resolveStatus(status)
const meta = STATUS_META[resolved]
return (
<Badge
variant="outline"
data-slot="status-badge"
data-status={resolved}
className={cn("font-dns gap-1.5 border-border text-foreground", className)}
>
<span
data-testid="status-dot"
className="size-1.5 shrink-0 rounded-full"
style={{ background: meta.color }}
aria-hidden
/>
{meta.label}
</Badge>
)
}
+69 -1
View File
@@ -1,7 +1,7 @@
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query"
import { api } from "@/api/client"
import { useAuth } from "@/auth/AuthContext"
import type { CreateAccountInput, CreateTemplateInput, ApplyRequest, Project } from "@/api/types"
import type { CreateAccountInput, CreateTemplateInput, ApplyRequest, Project, Schedule, CreateChannelInput } from "@/api/types"
function requireProjectId(project: Project | null): string {
if (!project) throw new Error("no active project")
@@ -141,3 +141,71 @@ export function useApplyDomain(id: string) {
onSuccess: () => qc.invalidateQueries({ queryKey: ["check", project?.id, id] }),
})
}
export function useDomainHistory(id: string) {
const { project } = useAuth()
return useQuery({
queryKey: ["domainHistory", project?.id, id],
queryFn: () => api.domainHistory(project!.id, id),
enabled: !!project && !!id,
})
}
export function useSchedule() {
const { project } = useAuth()
return useQuery({
queryKey: ["schedule", project?.id],
queryFn: () => api.getSchedule(project!.id),
enabled: !!project,
})
}
export function useUpdateSchedule() {
const { project } = useAuth()
const qc = useQueryClient()
return useMutation({
mutationFn: (input: Schedule) => {
const pid = requireProjectId(project)
return api.putSchedule(pid, input)
},
onSuccess: () => qc.invalidateQueries({ queryKey: ["schedule", project?.id] }),
})
}
export function useChannels() {
const { project } = useAuth()
return useQuery({
queryKey: ["channels", project?.id],
queryFn: () => api.listChannels(project!.id),
enabled: !!project,
})
}
export function useCreateChannel() {
const { project } = useAuth()
const qc = useQueryClient()
return useMutation({
mutationFn: (input: CreateChannelInput) => {
const pid = requireProjectId(project)
return api.createChannel(pid, input)
},
onSuccess: () => qc.invalidateQueries({ queryKey: ["channels", project?.id] }),
})
}
export function useDeleteChannel() {
const { project } = useAuth()
const qc = useQueryClient()
return useMutation({
mutationFn: (id: string) => {
const pid = requireProjectId(project)
return api.deleteChannel(pid, id)
},
onSuccess: () => qc.invalidateQueries({ queryKey: ["channels", project?.id] }),
})
}
export function useTestChannel() {
const { project } = useAuth()
return useMutation({
mutationFn: (id: string) => {
const pid = requireProjectId(project)
return api.testChannel(pid, id)
},
})
}
+146
View File
@@ -0,0 +1,146 @@
import { render, screen, waitFor } from "@testing-library/react"
import userEvent from "@testing-library/user-event"
import { MemoryRouter } from "react-router-dom"
import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
import { ChannelsPage } from "./ChannelsPage"
import { AuthProvider } from "@/auth/AuthContext"
import { api } from "@/api/client"
import { vi, beforeEach, test, expect } from "vitest"
import type { Channel } from "@/api/types"
const PROJECT_ID = "p1"
const channels: Channel[] = [
{ id: "c1", type: "telegram", config: { chat_id: "123456" }, enabled: true },
{ id: "c2", type: "webhook", config: { url: "https://hooks.example.com/x" }, enabled: false },
]
function renderPage() {
const qc = new QueryClient()
return render(
<QueryClientProvider client={qc}>
<AuthProvider>
<MemoryRouter initialEntries={["/channels"]}>
<ChannelsPage />
</MemoryRouter>
</AuthProvider>
</QueryClientProvider>,
)
}
beforeEach(() => {
vi.restoreAllMocks()
vi.spyOn(api.auth, "me").mockResolvedValue({
user: { id: "u1", email: "a@b.com" },
project: { id: PROJECT_ID, name: "Default" },
})
vi.spyOn(api, "listChannels").mockResolvedValue(channels)
})
test("отрисовывает список каналов без секрета", async () => {
renderPage()
expect(await screen.findByText("telegram")).toBeInTheDocument()
expect(screen.getByText("webhook")).toBeInTheDocument()
expect(screen.getByText(/123456/)).toBeInTheDocument()
expect(screen.getByText(/hooks\.example\.com/)).toBeInTheDocument()
expect(document.body.textContent).not.toMatch(/bot_token/i)
expect(screen.queryByDisplayValue(/123456/)).not.toBeInTheDocument()
})
test("создание telegram-канала собирает config.chat_id + secret=bot_token", async () => {
const createSpy = vi.spyOn(api, "createChannel").mockResolvedValue({
id: "c3", type: "telegram", config: { chat_id: "999" }, enabled: true,
})
const user = userEvent.setup()
renderPage()
await screen.findByText("telegram")
await user.click(screen.getByRole("combobox", { name: /тип канала/i }))
await user.click(await screen.findByRole("option", { name: /telegram/i }))
await user.type(screen.getByLabelText(/chat id/i), "999")
await user.type(screen.getByLabelText(/bot token/i), "SECRET_TOKEN")
await user.click(screen.getByRole("button", { name: /добавить канал/i }))
await waitFor(() =>
expect(createSpy).toHaveBeenCalledWith(PROJECT_ID, {
type: "telegram",
config: { chat_id: "999" },
secret: "SECRET_TOKEN",
}),
)
expect(document.body.textContent).not.toMatch(/SECRET_TOKEN/)
})
test("создание webhook-канала собирает config.url без секрета", async () => {
const createSpy = vi.spyOn(api, "createChannel").mockResolvedValue({
id: "c4", type: "webhook", config: { url: "https://hooks.example.com/y" }, enabled: true,
})
const user = userEvent.setup()
renderPage()
await screen.findByText("telegram")
await user.click(screen.getByRole("combobox", { name: /тип канала/i }))
await user.click(await screen.findByRole("option", { name: /webhook/i }))
await user.type(screen.getByLabelText(/url/i), "https://hooks.example.com/y")
await user.click(screen.getByRole("button", { name: /добавить канал/i }))
await waitFor(() =>
expect(createSpy).toHaveBeenCalledWith(PROJECT_ID, {
type: "webhook",
config: { url: "https://hooks.example.com/y" },
secret: "",
}),
)
})
test("удаление канала вызывает api.deleteChannel", async () => {
const deleteSpy = vi.spyOn(api, "deleteChannel").mockResolvedValue(undefined)
vi.spyOn(window, "confirm").mockReturnValue(true)
const user = userEvent.setup()
renderPage()
await screen.findByText("telegram")
await user.click(screen.getByRole("button", { name: /удалить канал telegram/i }))
await waitFor(() => expect(deleteSpy).toHaveBeenCalledWith(PROJECT_ID, "c1"))
})
test("кнопка «Тест» вызывает api.testChannel", async () => {
const testSpy = vi.spyOn(api, "testChannel").mockResolvedValue({ status: "ok" })
const user = userEvent.setup()
renderPage()
await screen.findByText("telegram")
const testButtons = screen.getAllByRole("button", { name: /тест/i })
await user.click(testButtons[0])
await waitFor(() => expect(testSpy).toHaveBeenCalledWith(PROJECT_ID, "c1"))
})
test("ошибка тест-отправки отображается как alert", async () => {
vi.spyOn(api, "testChannel").mockRejectedValue(new Error("Канал не отвечает"))
const user = userEvent.setup()
renderPage()
await screen.findByText("telegram")
const testButtons = screen.getAllByRole("button", { name: /тест/i })
await user.click(testButtons[0])
expect(await screen.findByRole("alert")).toHaveTextContent("Канал не отвечает")
})
test("пустое состояние при отсутствии каналов", async () => {
vi.spyOn(api, "listChannels").mockResolvedValue([])
renderPage()
expect(await screen.findByText(/каналов пока нет/i)).toBeInTheDocument()
})
+342
View File
@@ -0,0 +1,342 @@
import { useId, useState } from "react"
import { Controller, useForm, useWatch } from "react-hook-form"
import { zodResolver } from "@hookform/resolvers/zod"
import { z } from "zod"
import { Inbox, Loader2, Plus, Send, Trash2 } from "lucide-react"
import { Button } from "@/components/ui/button"
import { Input } from "@/components/ui/input"
import {
Field,
FieldContent,
FieldError,
FieldGroup,
FieldLabel,
FieldSet,
} from "@/components/ui/field"
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select"
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table"
import { useChannels, useCreateChannel, useDeleteChannel, useTestChannel } from "@/hooks/useApi"
import { cn } from "@/lib/utils"
import type { Channel, CreateChannelInput } from "@/api/types"
const CHANNEL_TYPES = [
{ value: "telegram", label: "Telegram" },
{ value: "webhook", label: "Webhook" },
] as const
const channelFormSchema = z
.object({
type: z.enum(["telegram", "webhook"]),
chatId: z.string(),
botToken: z.string(),
url: z.string(),
})
.superRefine((values, ctx) => {
if (values.type === "telegram") {
if (!values.chatId.trim()) {
ctx.addIssue({ code: "custom", path: ["chatId"], message: "Укажите chat_id" })
}
if (!values.botToken.trim()) {
ctx.addIssue({ code: "custom", path: ["botToken"], message: "Укажите bot token" })
}
return
}
if (!values.url.trim()) {
ctx.addIssue({ code: "custom", path: ["url"], message: "Укажите URL" })
return
}
try {
new URL(values.url)
} catch {
ctx.addIssue({ code: "custom", path: ["url"], message: "Некорректный URL, включая протокол http(s)://" })
}
})
type ChannelForm = z.infer<typeof channelFormSchema>
const EMPTY_FORM: ChannelForm = { type: "telegram", chatId: "", botToken: "", url: "" }
// Channel.config is a generic `object` end-to-end (T5/T7) — it never carries
// the secret (bot_token/signing key), only the public half (chat_id/url), so
// rendering every key is always safe to show in the list.
function formatConfig(config: object): string {
const entries = Object.entries(config as Record<string, unknown>)
if (entries.length === 0) return "—"
return entries.map(([k, v]) => `${k}=${String(v)}`).join(" · ")
}
function ChannelForm({ onCreated }: { onCreated: () => void }) {
const createChannel = useCreateChannel()
const typeFieldId = useId()
const chatIdFieldId = useId()
const botTokenFieldId = useId()
const urlFieldId = useId()
const {
control,
handleSubmit,
reset,
formState: { errors },
} = useForm<ChannelForm>({
resolver: zodResolver(channelFormSchema),
defaultValues: EMPTY_FORM,
})
const type = useWatch({ control, name: "type" })
function onSubmit(values: ChannelForm) {
const input: CreateChannelInput =
values.type === "telegram"
? { type: "telegram", config: { chat_id: values.chatId.trim() }, secret: values.botToken.trim() }
: { type: "webhook", config: { url: values.url.trim() }, secret: "" }
createChannel.mutate(input, {
onSuccess: () => {
reset(EMPTY_FORM)
onCreated()
},
})
}
return (
<form
onSubmit={handleSubmit(onSubmit)}
noValidate
className="flex flex-col gap-4 rounded-xl border border-border bg-card/60 p-4"
>
<FieldSet className="gap-3">
<FieldGroup className="flex-row flex-wrap items-start gap-3">
<Field className="w-40">
<FieldLabel htmlFor={typeFieldId}>Тип канала</FieldLabel>
<FieldContent>
<Controller
control={control}
name="type"
render={({ field }) => (
<Select
items={CHANNEL_TYPES}
value={field.value}
onValueChange={(v) => field.onChange(v)}
>
<SelectTrigger id={typeFieldId} aria-label="Тип канала" className="w-full">
<SelectValue placeholder="Выберите тип" />
</SelectTrigger>
<SelectContent>
{CHANNEL_TYPES.map((item) => (
<SelectItem key={item.value} value={item.value}>
{item.label}
</SelectItem>
))}
</SelectContent>
</Select>
)}
/>
</FieldContent>
</Field>
{type === "telegram" ? (
<>
<Field className="w-56">
<FieldLabel htmlFor={chatIdFieldId}>Chat ID</FieldLabel>
<FieldContent>
<Controller
control={control}
name="chatId"
render={({ field }) => (
<Input
{...field}
id={chatIdFieldId}
placeholder="123456789"
className="font-dns"
aria-invalid={!!errors.chatId}
/>
)}
/>
<FieldError errors={[errors.chatId]} />
</FieldContent>
</Field>
<Field className="w-64">
<FieldLabel htmlFor={botTokenFieldId}>Bot token</FieldLabel>
<FieldContent>
<Controller
control={control}
name="botToken"
render={({ field }) => (
<Input
{...field}
id={botTokenFieldId}
type="password"
autoComplete="off"
placeholder="123456:ABC-DEF…"
className="font-dns"
aria-invalid={!!errors.botToken}
/>
)}
/>
<FieldError errors={[errors.botToken]} />
</FieldContent>
</Field>
</>
) : (
<Field className="w-80">
<FieldLabel htmlFor={urlFieldId}>URL</FieldLabel>
<FieldContent>
<Controller
control={control}
name="url"
render={({ field }) => (
<Input
{...field}
id={urlFieldId}
placeholder="https://hooks.example.com/…"
className="font-dns"
aria-invalid={!!errors.url}
/>
)}
/>
<FieldError errors={[errors.url]} />
</FieldContent>
</Field>
)}
</FieldGroup>
</FieldSet>
<div className="flex items-center justify-between gap-3 border-t border-border pt-3">
{createChannel.isError && (
<span role="alert" className="font-dns text-xs text-destructive">
{createChannel.error.message}
</span>
)}
<Button type="submit" disabled={createChannel.isPending} className="ml-auto">
{createChannel.isPending ? (
<Loader2 className="size-4 animate-spin" strokeWidth={1.75} />
) : (
<Plus className="size-4" strokeWidth={1.75} />
)}
Добавить канал
</Button>
</div>
</form>
)
}
export function ChannelsPage() {
const channels = useChannels()
const deleteChannel = useDeleteChannel()
const testChannel = useTestChannel()
const [testingId, setTestingId] = useState<string | null>(null)
const channelList = channels.data ?? []
function onDelete(channel: Channel) {
if (window.confirm(`Удалить канал «${channel.type}»? Действие необратимо.`)) {
deleteChannel.mutate(channel.id)
}
}
function onTest(channel: Channel) {
setTestingId(channel.id)
testChannel.mutate(channel.id)
}
return (
<div className="mx-auto flex max-w-4xl flex-col gap-6 px-6 py-8">
<header className="flex flex-col gap-1">
<span className="font-dns text-[11px] tracking-wider text-muted-foreground uppercase">
notifications
</span>
<h1 className="text-xl font-semibold tracking-tight text-foreground">Каналы уведомлений</h1>
</header>
<ChannelForm onCreated={() => setTestingId(null)} />
{deleteChannel.isError && (
<span role="alert" className="font-dns text-xs text-destructive">
{deleteChannel.error.message}
</span>
)}
{testChannel.isError && (
<span role="alert" className="font-dns text-xs text-destructive">
{testChannel.error.message}
</span>
)}
{channelList.length === 0 ? (
<div className="flex flex-col items-center gap-2 rounded-xl border border-dashed border-border px-4 py-12 text-center text-sm text-muted-foreground">
<Inbox className="size-6" strokeWidth={1.5} />
Каналов пока нет добавьте Telegram или Webhook выше.
</div>
) : (
<Table>
<TableHeader>
<TableRow>
<TableHead>Тип</TableHead>
<TableHead>Конфигурация</TableHead>
<TableHead>Статус</TableHead>
<TableHead className="text-right">Действия</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{channelList.map((c) => (
<TableRow key={c.id}>
<TableCell className="font-dns">{c.type}</TableCell>
<TableCell className="font-dns text-xs text-muted-foreground">
{formatConfig(c.config)}
</TableCell>
<TableCell>
<span
className={cn(
"font-dns text-xs",
c.enabled ? "text-foreground" : "text-muted-foreground",
)}
>
{c.enabled ? "включён" : "выключен"}
</span>
</TableCell>
<TableCell className="text-right">
<div className="flex justify-end gap-1.5">
<Button
variant="outline"
size="sm"
onClick={() => onTest(c)}
disabled={testChannel.isPending && testingId === c.id}
>
{testChannel.isPending && testingId === c.id ? (
<Loader2 className="size-3.5 animate-spin" strokeWidth={1.75} />
) : (
<Send className="size-3.5" strokeWidth={1.75} />
)}
Тест
</Button>
<Button
variant="destructive"
size="icon-sm"
aria-label={`Удалить канал ${c.type}`}
onClick={() => onDelete(c)}
disabled={deleteChannel.isPending}
>
<Trash2 className="size-3.5" strokeWidth={1.75} />
</Button>
</div>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
)}
</div>
)
}
+1
View File
@@ -28,6 +28,7 @@ beforeEach(() => {
user: { id: "u1", email: "a@b.com" },
project: { id: PROJECT_ID, name: "Default" },
})
vi.spyOn(api, "domainHistory").mockResolvedValue([])
})
test("apply sends applyPrunes=false by default, true only after opting in", async () => {
+3
View File
@@ -2,6 +2,7 @@ import { useId, useState } from "react"
import { useParams } from "react-router-dom"
import { AlertTriangle, Loader2, Play, RefreshCw, TriangleAlert } from "lucide-react"
import { DiffView } from "@/components/DiffView"
import { DomainHistory } from "@/components/DomainHistory"
import { Button } from "@/components/ui/button"
import { Checkbox } from "@/components/ui/checkbox"
import { Label } from "@/components/ui/label"
@@ -137,6 +138,8 @@ export function DomainDiffPage() {
</div>
</>
)}
<DomainHistory domainId={id} />
</div>
)
}
+11 -2
View File
@@ -18,8 +18,8 @@ const templates: Template[] = [
{ id: "t2", name: "Minimal", records: [], version: 1 },
]
const domains: Domain[] = [
{ id: "d1", providerAccountId: "acc1", zoneName: "example.com.", zoneId: "z1", templateId: null },
{ id: "d2", providerAccountId: "acc2", zoneName: "test.org.", zoneId: "z2", templateId: "t1" },
{ id: "d1", providerAccountId: "acc1", zoneName: "example.com.", zoneId: "z1", templateId: null, lastCheckStatus: "drift" },
{ id: "d2", providerAccountId: "acc2", zoneName: "test.org.", zoneId: "z2", templateId: "t1", lastCheckStatus: "in_sync" },
]
function renderPage() {
@@ -108,3 +108,12 @@ test("пустое состояние при отсутствии доменов
expect(await screen.findByText(/доменов пока нет/i)).toBeInTheDocument()
})
test("drift-badge отражает lastCheckStatus каждого домена", async () => {
renderPage()
await screen.findByText("example.com.")
expect(screen.getByText("drift")).toBeInTheDocument()
expect(screen.getByText("in sync")).toBeInTheDocument()
})
+5
View File
@@ -2,6 +2,7 @@ import { useState } from "react"
import { Link } from "react-router-dom"
import { Inbox, Loader2, Trash2, Upload } from "lucide-react"
import { Button } from "@/components/ui/button"
import { StatusBadge } from "@/components/StatusBadge"
import {
Select,
SelectContent,
@@ -134,6 +135,7 @@ export function DomainsPage() {
<TableHead>Zone</TableHead>
<TableHead>Учётка</TableHead>
<TableHead>Шаблон</TableHead>
<TableHead>Статус</TableHead>
<TableHead className="text-right">Действия</TableHead>
</TableRow>
</TableHeader>
@@ -167,6 +169,9 @@ export function DomainsPage() {
</SelectContent>
</Select>
</TableCell>
<TableCell>
<StatusBadge status={d.lastCheckStatus} />
</TableCell>
<TableCell className="text-right">
<div className="flex justify-end gap-1.5">
<Button variant="outline" size="sm" render={<Link to={`/domains/${d.id}`} />}>
+83
View File
@@ -0,0 +1,83 @@
import { render, screen, waitFor } from "@testing-library/react"
import userEvent from "@testing-library/user-event"
import { MemoryRouter } from "react-router-dom"
import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
import { SchedulePage } from "./SchedulePage"
import { AuthProvider } from "@/auth/AuthContext"
import { api } from "@/api/client"
import { vi, beforeEach, test, expect } from "vitest"
const PROJECT_ID = "p1"
function renderPage() {
const qc = new QueryClient()
return render(
<QueryClientProvider client={qc}>
<AuthProvider>
<MemoryRouter initialEntries={["/schedule"]}>
<SchedulePage />
</MemoryRouter>
</AuthProvider>
</QueryClientProvider>,
)
}
beforeEach(() => {
vi.restoreAllMocks()
vi.spyOn(api.auth, "me").mockResolvedValue({
user: { id: "u1", email: "a@b.com" },
project: { id: PROJECT_ID, name: "Default" },
})
vi.spyOn(api, "getSchedule").mockResolvedValue({ intervalSeconds: 120, enabled: true })
})
test("показывает текущий интервал и enabled", async () => {
renderPage()
const intervalInput = await screen.findByLabelText(/интервал/i)
expect(intervalInput).toHaveValue(120)
expect(screen.getByRole("checkbox", { name: /включ/i })).toBeChecked()
})
test("сохранение вызывает updateSchedule с новыми значениями", async () => {
const updateSpy = vi.spyOn(api, "putSchedule").mockResolvedValue({ intervalSeconds: 300, enabled: false })
const user = userEvent.setup()
renderPage()
const intervalInput = await screen.findByLabelText(/интервал/i)
await user.clear(intervalInput)
await user.type(intervalInput, "300")
await user.click(screen.getByRole("checkbox", { name: /включ/i }))
await user.click(screen.getByRole("button", { name: /сохранить/i }))
await waitFor(() =>
expect(updateSpy).toHaveBeenCalledWith(PROJECT_ID, { intervalSeconds: 300, enabled: false }),
)
})
test("валидация: интервал меньше 60 блокирует сохранение", async () => {
const updateSpy = vi.spyOn(api, "putSchedule").mockResolvedValue({ intervalSeconds: 120, enabled: true })
const user = userEvent.setup()
renderPage()
const intervalInput = await screen.findByLabelText(/интервал/i)
await user.clear(intervalInput)
await user.type(intervalInput, "30")
await user.click(screen.getByRole("button", { name: /сохранить/i }))
expect(await screen.findByRole("alert")).toHaveTextContent(/60/)
expect(updateSpy).not.toHaveBeenCalled()
})
test("ошибка сохранения отображается пользователю", async () => {
vi.spyOn(api, "putSchedule").mockRejectedValue(new Error("Не удалось сохранить расписание"))
const user = userEvent.setup()
renderPage()
const intervalInput = await screen.findByLabelText(/интервал/i)
await user.clear(intervalInput)
await user.type(intervalInput, "180")
await user.click(screen.getByRole("button", { name: /сохранить/i }))
expect(await screen.findByRole("alert")).toHaveTextContent("Не удалось сохранить расписание")
})
+145
View File
@@ -0,0 +1,145 @@
import { useId } from "react"
import { Controller, useForm } from "react-hook-form"
import { zodResolver } from "@hookform/resolvers/zod"
import { z } from "zod"
import { Loader2, Save } from "lucide-react"
import { Button } from "@/components/ui/button"
import { Checkbox } from "@/components/ui/checkbox"
import { Input } from "@/components/ui/input"
import { Label } from "@/components/ui/label"
import {
Field,
FieldContent,
FieldError,
FieldGroup,
FieldLabel,
FieldSet,
} from "@/components/ui/field"
import { useSchedule, useUpdateSchedule } from "@/hooks/useApi"
import type { Schedule } from "@/api/types"
const scheduleFormSchema = z.object({
intervalSeconds: z
.number({ error: "Укажите интервал в секундах" })
.int("Интервал — целое число секунд")
.min(60, "Интервал не может быть меньше 60 секунд"),
enabled: z.boolean(),
})
type ScheduleForm = z.infer<typeof scheduleFormSchema>
const DEFAULT_SCHEDULE: Schedule = { intervalSeconds: 300, enabled: false }
// Rendered only once the current schedule is loaded, so useForm's
// defaultValues (a one-time snapshot in react-hook-form) are always seeded
// with the real values — no reset()-after-fetch race between the query and
// the form's first render.
function ScheduleFormCard({ initial }: { initial: Schedule }) {
const updateSchedule = useUpdateSchedule()
const intervalFieldId = useId()
const {
control,
handleSubmit,
formState: { errors },
} = useForm<ScheduleForm>({
resolver: zodResolver(scheduleFormSchema),
defaultValues: initial,
})
function onSubmit(values: ScheduleForm) {
updateSchedule.mutate(values)
}
return (
<form
onSubmit={handleSubmit(onSubmit)}
noValidate
className="flex flex-col gap-4 rounded-xl border border-border bg-card/60 p-4"
>
<FieldSet className="gap-3">
<FieldGroup className="gap-3">
<Field className="sm:max-w-56">
<FieldLabel htmlFor={intervalFieldId}>Интервал (секунд)</FieldLabel>
<FieldContent>
<Controller
control={control}
name="intervalSeconds"
render={({ field }) => (
<Input
{...field}
id={intervalFieldId}
type="number"
min={60}
step={1}
className="font-dns"
aria-invalid={!!errors.intervalSeconds}
onChange={(e) => field.onChange(e.target.valueAsNumber)}
/>
)}
/>
<FieldError errors={[errors.intervalSeconds]} />
</FieldContent>
</Field>
<Field>
<Controller
control={control}
name="enabled"
render={({ field }) => (
<Label className="flex items-center gap-2.5 text-sm font-normal">
<Checkbox
aria-label="Включено"
checked={field.value}
onCheckedChange={(v) => field.onChange(v === true)}
/>
Автоматические проверки включены
</Label>
)}
/>
</Field>
</FieldGroup>
</FieldSet>
<div className="flex items-center justify-between gap-3 border-t border-border pt-3">
{updateSchedule.isError && (
<span role="alert" className="font-dns text-xs text-destructive">
{updateSchedule.error.message}
</span>
)}
<Button type="submit" disabled={updateSchedule.isPending} className="ml-auto">
{updateSchedule.isPending ? (
<Loader2 className="size-4 animate-spin" strokeWidth={1.75} />
) : (
<Save className="size-4" strokeWidth={1.75} />
)}
Сохранить расписание
</Button>
</div>
</form>
)
}
export function SchedulePage() {
const schedule = useSchedule()
return (
<div className="mx-auto flex max-w-2xl flex-col gap-6 px-6 py-8">
<header className="flex flex-col gap-1">
<span className="font-dns text-[11px] tracking-wider text-muted-foreground uppercase">
scheduler
</span>
<h1 className="text-xl font-semibold tracking-tight text-foreground">Расписание проверок</h1>
</header>
{schedule.isPending ? (
<div className="flex items-center gap-2 rounded-lg border border-border bg-card px-4 py-8 text-sm text-muted-foreground">
<Loader2 className="size-4 animate-spin" strokeWidth={1.75} />
Загружаю расписание
</div>
) : (
<ScheduleFormCard initial={schedule.data ?? DEFAULT_SCHEDULE} />
)}
</div>
)
}