merge: Фаза 2 — авторизация

- internal/store: миграция sessions/password + методы users/sessions/projects
- internal/auth: argon2id пароли + session store (sha256 токена)
- internal/api: auth-хендлеры (register/login/logout/me) + cookie, RequireAuth+RequireProjectAccess middleware
- IDOR закрыт: все /projects/{pid}/* под middleware, LoadDomainFull scoped, projectID из контекста
- web: AuthContext + клиент под cookie, Login/Register, protected routes, logout, 401→/login
Финальный ревью: READY TO MERGE, IDOR закрыт end-to-end. Go 105+/15 пакетов, web 58 тестов.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-07-03 21:40:45 +07:00
55 changed files with 3228 additions and 246 deletions
+8 -1
View File
@@ -5,10 +5,12 @@ import (
"log" "log"
"net/http" "net/http"
"strings" "strings"
"time"
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
"github.com/vasyakrg/dns-autoresolver/internal/api" "github.com/vasyakrg/dns-autoresolver/internal/api"
"github.com/vasyakrg/dns-autoresolver/internal/auth"
"github.com/vasyakrg/dns-autoresolver/internal/config" "github.com/vasyakrg/dns-autoresolver/internal/config"
"github.com/vasyakrg/dns-autoresolver/internal/crypto" "github.com/vasyakrg/dns-autoresolver/internal/crypto"
"github.com/vasyakrg/dns-autoresolver/internal/provider/registry" "github.com/vasyakrg/dns-autoresolver/internal/provider/registry"
@@ -18,6 +20,10 @@ import (
"github.com/vasyakrg/dns-autoresolver/internal/web" "github.com/vasyakrg/dns-autoresolver/internal/web"
) )
// sessionTTL is how long a login session cookie remains valid before the
// user must re-authenticate.
const sessionTTL = 720 * time.Hour
// isAPIPath reports whether path must be routed to the API router rather // isAPIPath reports whether path must be routed to the API router rather
// than the SPA. "/api" (no trailing slash) counts as an API path too — // than the SPA. "/api" (no trailing slash) counts as an API path too —
// only strings.HasPrefix(path, "/api/") would otherwise miss it and fall // only strings.HasPrefix(path, "/api/") would otherwise miss it and fall
@@ -46,12 +52,13 @@ func main() {
log.Fatalf("cipher: %v", err) log.Fatalf("cipher: %v", err)
} }
st := store.New(pool) st := store.New(pool)
sessions := auth.NewSessions(st, sessionTTL)
reg := registry.New() reg := registry.New()
reg.Register(selectel.New()) reg.Register(selectel.New())
svc := service.New(st, st, reg, cipher) svc := service.New(st, st, reg, cipher)
a := &api.API{Svc: svc, Store: st, Cipher: cipher, Reg: reg} a := &api.API{Svc: svc, Store: st, Cipher: cipher, Reg: reg, Auth: st, Sessions: sessions}
apiRouter := api.NewRouter(a) apiRouter := api.NewRouter(a)
webHandler, err := web.Handler() webHandler, err := web.Handler()
+3 -3
View File
@@ -9,6 +9,7 @@ require (
github.com/pressly/goose/v3 v3.27.2 github.com/pressly/goose/v3 v3.27.2
github.com/testcontainers/testcontainers-go v0.43.0 github.com/testcontainers/testcontainers-go v0.43.0
github.com/testcontainers/testcontainers-go/modules/postgres v0.43.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.43.0
golang.org/x/crypto v0.53.0
) )
require ( require (
@@ -64,9 +65,8 @@ require (
go.opentelemetry.io/otel/metric v1.43.0 // indirect go.opentelemetry.io/otel/metric v1.43.0 // indirect
go.opentelemetry.io/otel/trace v1.43.0 // indirect go.opentelemetry.io/otel/trace v1.43.0 // indirect
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
golang.org/x/crypto v0.52.0 // indirect
golang.org/x/sync v0.21.0 // indirect golang.org/x/sync v0.21.0 // indirect
golang.org/x/sys v0.45.0 // indirect golang.org/x/sys v0.46.0 // indirect
golang.org/x/text v0.37.0 // indirect golang.org/x/text v0.38.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )
+8 -8
View File
@@ -150,19 +150,19 @@ go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988= golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto=
golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc= golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio=
golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM= golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM=
golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY= golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw=
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4= golang.org/x/term v0.44.0 h1:0rLvDRCtNj0gZkyIXhCyOb2OAzEhLVqc4B+hrsBhrmc=
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk= golang.org/x/term v0.44.0/go.mod h1:7ze4MdzUzLXpSAoFP1H0bOI9aXDqveSvatT5vKcFh2Y=
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE=
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
+40 -2
View File
@@ -3,6 +3,7 @@ package api
import ( import (
"context" "context"
"net/http" "net/http"
"time"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware" "github.com/go-chi/chi/v5/middleware"
@@ -17,8 +18,8 @@ import (
// CheckApplier is the service surface the API depends on. // CheckApplier is the service surface the API depends on.
type CheckApplier interface { type CheckApplier interface {
Check(ctx context.Context, domainID uuid.UUID) (diff.Changeset, error) Check(ctx context.Context, projectID, domainID uuid.UUID) (diff.Changeset, error)
Apply(ctx context.Context, domainID uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) Apply(ctx context.Context, projectID, domainID uuid.UUID, req service.ApplyRequest) (diff.Changeset, error)
} }
// TenantStore is the narrow persistence surface the CRUD handlers depend on. // TenantStore is the narrow persistence surface the CRUD handlers depend on.
@@ -54,12 +55,36 @@ type ProviderRegistry interface {
ByName(name string) (provider.Provider, error) ByName(name string) (provider.Provider, error)
} }
// AuthStore is the persistence surface the auth handlers depend on.
// *store.Store satisfies it directly (see internal/store/store.go); tests
// can supply their own mock.
type AuthStore interface {
RegisterUser(ctx context.Context, email, passwordHash string) (store.User, store.Project, error)
GetUserByEmail(ctx context.Context, email string) (store.User, error)
GetUserByID(ctx context.Context, userID uuid.UUID) (store.User, error)
GetUserProject(ctx context.Context, userID uuid.UUID) (store.Project, error)
// GetProjectOwned looks up projectID and returns it only if it's owned by
// userID — RequireProjectAccess uses this to reject foreign/nonexistent
// projects with 404 before any handler runs.
GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error)
}
// SessionManager creates/validates/destroys login sessions. *auth.Sessions
// satisfies it directly (see internal/auth/session.go).
type SessionManager interface {
Create(ctx context.Context, userID uuid.UUID) (string, time.Time, error)
Validate(ctx context.Context, token string) (uuid.UUID, error)
Destroy(ctx context.Context, token string) error
}
// API holds handler dependencies. // API holds handler dependencies.
type API struct { type API struct {
Svc CheckApplier Svc CheckApplier
Store TenantStore Store TenantStore
Cipher Cipher Cipher Cipher
Reg ProviderRegistry Reg ProviderRegistry
Auth AuthStore
Sessions SessionManager
} }
func NewRouter(a *API) http.Handler { func NewRouter(a *API) http.Handler {
@@ -67,7 +92,20 @@ func NewRouter(a *API) http.Handler {
r.Use(middleware.RequestID) r.Use(middleware.RequestID)
r.Use(middleware.Recoverer) r.Use(middleware.Recoverer)
r.Route("/api/v1/auth", func(r chi.Router) {
r.Post("/register", a.handleRegister)
r.Post("/login", a.handleLogin)
r.Group(func(r chi.Router) {
r.Use(a.RequireAuth)
r.Post("/logout", a.handleLogout)
r.Get("/me", a.handleMe)
})
})
r.Route("/api/v1/projects/{pid}", func(r chi.Router) { r.Route("/api/v1/projects/{pid}", func(r chi.Router) {
r.Use(a.RequireAuth)
r.Use(a.RequireProjectAccess)
r.Route("/domains", func(r chi.Router) { r.Route("/domains", func(r chi.Router) {
r.Post("/", a.handleCreateDomain) r.Post("/", a.handleCreateDomain)
r.Get("/", a.handleListDomains) r.Get("/", a.handleListDomains)
+13 -8
View File
@@ -20,18 +20,23 @@ type mockCheckApplier struct {
lastReq service.ApplyRequest lastReq service.ApplyRequest
} }
func (m *mockCheckApplier) Check(context.Context, uuid.UUID) (diff.Changeset, error) { func (m *mockCheckApplier) Check(context.Context, uuid.UUID, uuid.UUID) (diff.Changeset, error) {
d := model.Record{Type: model.A, Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}} d := model.Record{Type: model.A, Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}}
return diff.Changeset{Diffs: []diff.RecordDiff{{Kind: diff.Add, Type: d.Type, Name: d.Name, Desired: &d}}}, nil return diff.Changeset{Diffs: []diff.RecordDiff{{Kind: diff.Add, Type: d.Type, Name: d.Name, Desired: &d}}}, nil
} }
func (m *mockCheckApplier) Apply(_ context.Context, _ uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) { func (m *mockCheckApplier) Apply(_ context.Context, _, _ uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) {
m.lastReq = req m.lastReq = req
return diff.Changeset{}, nil return diff.Changeset{}, nil
} }
// newTestAPI wires a fixed authenticated user who owns whatever project id
// is requested (via alwaysOwnedAuthStore/alwaysValidSessions in
// middleware_test.go) — these tests exercise check/apply behavior past the
// RequireAuth/RequireProjectAccess boundary, which is covered separately by
// middleware_test.go's own tests and the IDOR regression.
func newTestAPI() (*API, *mockCheckApplier) { func newTestAPI() (*API, *mockCheckApplier) {
m := &mockCheckApplier{} m := &mockCheckApplier{}
return &API{Svc: m}, m // остальные зависимости (store/cipher) nil — CRUD-тесты добавит реализатор return &API{Svc: m, Auth: alwaysOwnedAuthStore(), Sessions: alwaysValidSessions(uuid.New())}, m
} }
func TestCheckEndpoint(t *testing.T) { func TestCheckEndpoint(t *testing.T) {
@@ -39,7 +44,7 @@ func TestCheckEndpoint(t *testing.T) {
router := NewRouter(a) router := NewRouter(a)
did := uuid.New().String() did := uuid.New().String()
req := httptest.NewRequest(http.MethodGet, req := requestWithSessionCookie(http.MethodGet,
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/check", nil) "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/check", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -62,7 +67,7 @@ func TestApplyDefaultsPruneFalse(t *testing.T) {
did := uuid.New().String() did := uuid.New().String()
body := `{"applyUpdates":true}` // applyPrunes отсутствует → false body := `{"applyUpdates":true}` // applyPrunes отсутствует → false
req := httptest.NewRequest(http.MethodPost, req := requestWithSessionCookie(http.MethodPost,
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply", "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply",
strings.NewReader(body)) strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -81,7 +86,7 @@ func TestApplyEmptyBodyOK(t *testing.T) {
router := NewRouter(a) router := NewRouter(a)
did := uuid.New().String() did := uuid.New().String()
req := httptest.NewRequest(http.MethodPost, req := requestWithSessionCookie(http.MethodPost,
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply", nil) "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -100,7 +105,7 @@ func TestApplyMalformedBody(t *testing.T) {
did := uuid.New().String() did := uuid.New().String()
body := `{"applyUpdates":` body := `{"applyUpdates":`
req := httptest.NewRequest(http.MethodPost, req := requestWithSessionCookie(http.MethodPost,
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply", "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply",
strings.NewReader(body)) strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -114,7 +119,7 @@ func TestApplyMalformedBody(t *testing.T) {
func TestApplyBadUUID(t *testing.T) { func TestApplyBadUUID(t *testing.T) {
a, _ := newTestAPI() a, _ := newTestAPI()
router := NewRouter(a) router := NewRouter(a)
req := httptest.NewRequest(http.MethodPost, req := requestWithSessionCookie(http.MethodPost,
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/not-a-uuid/apply", "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/not-a-uuid/apply",
bytes.NewReader([]byte(`{}`))) bytes.NewReader([]byte(`{}`)))
w := httptest.NewRecorder() w := httptest.NewRecorder()
+182
View File
@@ -0,0 +1,182 @@
package api
import (
"errors"
"log"
"net/http"
"strings"
"time"
"github.com/vasyakrg/dns-autoresolver/internal/auth"
"github.com/vasyakrg/dns-autoresolver/internal/store"
)
const sessionCookieName = "session"
// dummyPasswordHash is a valid-format argon2 hash with no real matching
// password. handleLogin runs VerifyPassword against it whenever the email
// lookup fails, so a login attempt for an unregistered email takes the same
// wall-clock time as one for a registered email with a wrong password —
// otherwise the timing difference would let an attacker enumerate which
// emails are registered.
var dummyPasswordHash string
func init() {
h, err := auth.HashPassword("dns-autoresolver-timing-guard-dummy")
if err != nil {
panic("api: failed to initialize dummy password hash: " + err.Error())
}
dummyPasswordHash = h
}
// normalizeEmail trims surrounding whitespace and lowercases the email so
// storage and lookup are always consistent regardless of how the client
// cased or padded the input.
func normalizeEmail(email string) string {
return strings.ToLower(strings.TrimSpace(email))
}
func setSessionCookie(w http.ResponseWriter, token string, exp time.Time) {
http.SetCookie(w, &http.Cookie{
Name: sessionCookieName, Value: token, Path: "/",
HttpOnly: true, Secure: true, SameSite: http.SameSiteLaxMode, Expires: exp,
})
}
func clearSessionCookie(w http.ResponseWriter) {
http.SetCookie(w, &http.Cookie{
Name: sessionCookieName, Value: "", Path: "/",
HttpOnly: true, Secure: true, SameSite: http.SameSiteLaxMode, MaxAge: -1,
})
}
func (a *API) handleRegister(w http.ResponseWriter, r *http.Request) {
var req registerRequest
if !decodeBody(w, r, &req) {
return
}
email := normalizeEmail(req.Email)
if email == "" || req.Password == "" {
writeErr(w, http.StatusBadRequest, "email and password are required")
return
}
// Server-side minimum length is the source of truth: the client-side
// zod min(8) check is UX only and can be bypassed with a direct POST.
if len(req.Password) < 8 {
writeErr(w, http.StatusBadRequest, "password must be at least 8 characters")
return
}
hash, err := auth.HashPassword(req.Password)
if err != nil {
log.Printf("api: hash password failed: %v", err)
writeErr(w, http.StatusInternalServerError, "internal error")
return
}
u, p, err := a.Auth.RegisterUser(r.Context(), email, hash)
if err != nil {
if errors.Is(err, store.ErrEmailTaken) {
writeErr(w, http.StatusConflict, "email already registered")
return
}
log.Printf("api: register user failed: %v", err)
writeErr(w, http.StatusInternalServerError, "internal error")
return
}
token, exp, err := a.Sessions.Create(r.Context(), u.ID)
if err != nil {
log.Printf("api: create session failed: %v", err)
writeErr(w, http.StatusInternalServerError, "internal error")
return
}
setSessionCookie(w, token, exp)
writeJSON(w, http.StatusOK, toAuthResponse(u, p))
}
// invalidCredentials is deliberately identical for "no such user" and "wrong
// password" — disclosing which one occurred would let an attacker enumerate
// registered emails.
func invalidCredentials(w http.ResponseWriter) {
writeErr(w, http.StatusUnauthorized, "invalid credentials")
}
func (a *API) handleLogin(w http.ResponseWriter, r *http.Request) {
var req loginRequest
if !decodeBody(w, r, &req) {
return
}
email := normalizeEmail(req.Email)
u, err := a.Auth.GetUserByEmail(r.Context(), email)
if err != nil {
// No such user: still spend the argon2 verification cost against a
// fixed dummy hash (see dummyPasswordHash) so this path isn't
// distinguishable by timing from a wrong-password rejection below.
_, _ = auth.VerifyPassword(dummyPasswordHash, req.Password)
invalidCredentials(w)
return
}
ok, err := auth.VerifyPassword(u.PasswordHash, req.Password)
if err != nil || !ok {
invalidCredentials(w)
return
}
p, err := a.Auth.GetUserProject(r.Context(), u.ID)
if err != nil {
log.Printf("api: get user project failed: %v", err)
writeErr(w, http.StatusInternalServerError, "internal error")
return
}
token, exp, err := a.Sessions.Create(r.Context(), u.ID)
if err != nil {
log.Printf("api: create session failed: %v", err)
writeErr(w, http.StatusInternalServerError, "internal error")
return
}
setSessionCookie(w, token, exp)
writeJSON(w, http.StatusOK, toAuthResponse(u, p))
}
func (a *API) handleLogout(w http.ResponseWriter, r *http.Request) {
if c, err := r.Cookie(sessionCookieName); err == nil && c.Value != "" {
if err := a.Sessions.Destroy(r.Context(), c.Value); err != nil {
log.Printf("api: destroy session failed: %v", err)
}
}
clearSessionCookie(w)
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
}
// handleMe returns the authenticated caller's identity + default project.
// The user ID comes from the request context, set by RequireAuth after
// validating the session cookie.
func (a *API) handleMe(w http.ResponseWriter, r *http.Request) {
userID, ok := userIDFrom(r.Context())
if !ok {
writeErr(w, http.StatusUnauthorized, "authentication required")
return
}
u, err := a.Auth.GetUserByID(r.Context(), userID)
if err != nil {
log.Printf("api: get user by id failed: %v", err)
writeErr(w, http.StatusInternalServerError, "internal error")
return
}
p, err := a.Auth.GetUserProject(r.Context(), userID)
if err != nil {
log.Printf("api: get user project failed: %v", err)
writeErr(w, http.StatusInternalServerError, "internal error")
return
}
writeJSON(w, http.StatusOK, toAuthResponse(u, p))
}
+432
View File
@@ -0,0 +1,432 @@
package api
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/vasyakrg/dns-autoresolver/internal/auth"
"github.com/vasyakrg/dns-autoresolver/internal/store"
)
// --- mocks ---
type mockAuthStore struct {
registerUserFn func(ctx context.Context, email, passwordHash string) (store.User, store.Project, error)
getUserByEmailFn func(ctx context.Context, email string) (store.User, error)
getUserByIDFn func(ctx context.Context, userID uuid.UUID) (store.User, error)
getUserProjectFn func(ctx context.Context, userID uuid.UUID) (store.Project, error)
getProjectOwnedFn func(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error)
}
func (m *mockAuthStore) RegisterUser(ctx context.Context, email, passwordHash string) (store.User, store.Project, error) {
return m.registerUserFn(ctx, email, passwordHash)
}
func (m *mockAuthStore) GetUserByEmail(ctx context.Context, email string) (store.User, error) {
return m.getUserByEmailFn(ctx, email)
}
func (m *mockAuthStore) GetUserByID(ctx context.Context, userID uuid.UUID) (store.User, error) {
return m.getUserByIDFn(ctx, userID)
}
func (m *mockAuthStore) GetUserProject(ctx context.Context, userID uuid.UUID) (store.Project, error) {
return m.getUserProjectFn(ctx, userID)
}
func (m *mockAuthStore) GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) {
return m.getProjectOwnedFn(ctx, projectID, userID)
}
type mockSessionManager struct {
createFn func(ctx context.Context, userID uuid.UUID) (string, time.Time, error)
validateFn func(ctx context.Context, token string) (uuid.UUID, error)
destroyCalled bool
destroyToken string
destroyErr error
}
func (m *mockSessionManager) Create(ctx context.Context, userID uuid.UUID) (string, time.Time, error) {
return m.createFn(ctx, userID)
}
func (m *mockSessionManager) Validate(ctx context.Context, token string) (uuid.UUID, error) {
if m.validateFn != nil {
return m.validateFn(ctx, token)
}
return uuid.Nil, nil
}
func (m *mockSessionManager) Destroy(ctx context.Context, token string) error {
m.destroyCalled = true
m.destroyToken = token
return m.destroyErr
}
func newTestAuthAPI() (*API, *mockAuthStore, *mockSessionManager) {
authStore := &mockAuthStore{}
sessions := &mockSessionManager{
createFn: func(_ context.Context, userID uuid.UUID) (string, time.Time, error) {
return "test-token", time.Now().Add(time.Hour), nil
},
}
return &API{Auth: authStore, Sessions: sessions}, authStore, sessions
}
func findCookie(resp *http.Response, name string) *http.Cookie {
for _, c := range resp.Cookies() {
if c.Name == name {
return c
}
}
return nil
}
// --- register ---
func TestAuthRegister_Success(t *testing.T) {
a, authStore, _ := newTestAuthAPI()
userID := uuid.New()
projectID := uuid.New()
authStore.registerUserFn = func(_ context.Context, email, passwordHash string) (store.User, store.Project, error) {
if passwordHash == "" {
t.Fatal("expected non-empty password hash passed to RegisterUser")
}
return store.User{ID: userID, Email: email, PasswordHash: passwordHash},
store.Project{ID: projectID, UserID: userID, Name: "default"}, nil
}
router := NewRouter(a)
body := `{"email":"alice@example.com","password":"correct-horse"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status %d, body %s", w.Code, w.Body.String())
}
resp := w.Result()
cookie := findCookie(resp, sessionCookieName)
if cookie == nil {
t.Fatal("expected session cookie to be set")
}
if cookie.Value != "test-token" {
t.Fatalf("unexpected cookie value: %q", cookie.Value)
}
if strings.Contains(w.Body.String(), "password") {
t.Fatalf("response body must not contain password/password_hash: %s", w.Body.String())
}
var got authResponse
if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil {
t.Fatal(err)
}
if got.User.ID != userID.String() || got.User.Email != "alice@example.com" {
t.Fatalf("unexpected user in response: %+v", got.User)
}
if got.Project.ID != projectID.String() || got.Project.Name != "default" {
t.Fatalf("unexpected project in response: %+v", got.Project)
}
}
// TestAuthRegister_NormalizesEmail verifies the fix for the email-consistency
// gap: a padded/mixed-case email is trimmed+lowercased before it reaches the
// store, so storage and later lookups are always consistent.
func TestAuthRegister_NormalizesEmail(t *testing.T) {
a, authStore, _ := newTestAuthAPI()
userID := uuid.New()
var gotEmail string
authStore.registerUserFn = func(_ context.Context, email, passwordHash string) (store.User, store.Project, error) {
gotEmail = email
return store.User{ID: userID, Email: email}, store.Project{ID: uuid.New(), UserID: userID, Name: "default"}, nil
}
router := NewRouter(a)
body := `{"email":" Alice@X.com ","password":"correct-horse"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status %d, body %s", w.Code, w.Body.String())
}
if gotEmail != "alice@x.com" {
t.Fatalf("expected normalized email passed to RegisterUser, got %q", gotEmail)
}
}
// TestAuthRegister_ShortPasswordReturns400 verifies the server-side password
// length floor: the client's zod min(8) is UX only and can be bypassed with a
// direct POST, so the handler itself must reject a password under 8 chars
// before ever calling RegisterUser.
func TestAuthRegister_ShortPasswordReturns400(t *testing.T) {
a, authStore, _ := newTestAuthAPI()
registerCalled := false
authStore.registerUserFn = func(context.Context, string, string) (store.User, store.Project, error) {
registerCalled = true
return store.User{}, store.Project{}, nil
}
router := NewRouter(a)
body := `{"email":"alice@example.com","password":"short"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status %d, body %s", w.Code, w.Body.String())
}
var got map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil {
t.Fatal(err)
}
if got["error"] != "password must be at least 8 characters" {
t.Fatalf(`expected error "password must be at least 8 characters", got %q`, got["error"])
}
if registerCalled {
t.Fatal("expected RegisterUser not to be called for a too-short password")
}
if findCookie(w.Result(), sessionCookieName) != nil {
t.Fatal("expected no session cookie on rejected register")
}
}
// TestAuthRegister_DuplicateEmailReturns409 verifies the fix for the
// duplicate-registration gap: RegisterUser reporting store.ErrEmailTaken
// must surface as 409, not a generic 500.
func TestAuthRegister_DuplicateEmailReturns409(t *testing.T) {
a, authStore, _ := newTestAuthAPI()
authStore.registerUserFn = func(context.Context, string, string) (store.User, store.Project, error) {
return store.User{}, store.Project{}, store.ErrEmailTaken
}
router := NewRouter(a)
body := `{"email":"dup@example.com","password":"correct-horse"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusConflict {
t.Fatalf("status %d, body %s", w.Code, w.Body.String())
}
var got map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil {
t.Fatal(err)
}
if got["error"] != "email already registered" {
t.Fatalf(`expected error "email already registered", got %q`, got["error"])
}
}
// --- login ---
func TestAuthLogin_CorrectPassword(t *testing.T) {
a, authStore, _ := newTestAuthAPI()
hash, err := auth.HashPassword("correct-horse")
if err != nil {
t.Fatal(err)
}
userID := uuid.New()
projectID := uuid.New()
authStore.getUserByEmailFn = func(_ context.Context, email string) (store.User, error) {
return store.User{ID: userID, Email: email, PasswordHash: hash}, nil
}
authStore.getUserProjectFn = func(_ context.Context, uid uuid.UUID) (store.Project, error) {
return store.Project{ID: projectID, UserID: uid, Name: "default"}, nil
}
router := NewRouter(a)
body := `{"email":"bob@example.com","password":"correct-horse"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status %d, body %s", w.Code, w.Body.String())
}
if findCookie(w.Result(), sessionCookieName) == nil {
t.Fatal("expected session cookie to be set")
}
if strings.Contains(w.Body.String(), "password") {
t.Fatalf("response body must not contain password/password_hash: %s", w.Body.String())
}
}
// TestAuthLogin_NormalizesEmail verifies that a login for a padded/mixed-case
// email reaches GetUserByEmail already trimmed+lowercased — the same
// normalization applied on register, so "Alice@X.com" at registration and
// "alice@x.com" at login resolve to the same account.
func TestAuthLogin_NormalizesEmail(t *testing.T) {
a, authStore, _ := newTestAuthAPI()
hash, err := auth.HashPassword("correct-horse")
if err != nil {
t.Fatal(err)
}
userID := uuid.New()
var gotEmail string
authStore.getUserByEmailFn = func(_ context.Context, email string) (store.User, error) {
gotEmail = email
return store.User{ID: userID, Email: email, PasswordHash: hash}, nil
}
authStore.getUserProjectFn = func(_ context.Context, uid uuid.UUID) (store.Project, error) {
return store.Project{ID: uuid.New(), UserID: uid, Name: "default"}, nil
}
router := NewRouter(a)
body := `{"email":" Alice@X.com ","password":"correct-horse"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status %d, body %s", w.Code, w.Body.String())
}
if gotEmail != "alice@x.com" {
t.Fatalf("expected normalized email passed to GetUserByEmail, got %q", gotEmail)
}
}
func TestAuthLogin_WrongPassword(t *testing.T) {
a, authStore, _ := newTestAuthAPI()
hash, err := auth.HashPassword("correct-horse")
if err != nil {
t.Fatal(err)
}
authStore.getUserByEmailFn = func(_ context.Context, email string) (store.User, error) {
return store.User{ID: uuid.New(), Email: email, PasswordHash: hash}, nil
}
router := NewRouter(a)
body := `{"email":"bob@example.com","password":"wrong-password"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assertInvalidCredentials(t, w)
}
func TestAuthLogin_UnknownEmail(t *testing.T) {
a, authStore, _ := newTestAuthAPI()
authStore.getUserByEmailFn = func(_ context.Context, email string) (store.User, error) {
return store.User{}, errNoRowsForTest
}
router := NewRouter(a)
body := `{"email":"nobody@example.com","password":"whatever"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", strings.NewReader(body))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assertInvalidCredentials(t, w)
}
func assertInvalidCredentials(t *testing.T, w *httptest.ResponseRecorder) {
t.Helper()
if w.Code != http.StatusUnauthorized {
t.Fatalf("status %d, body %s", w.Code, w.Body.String())
}
var got map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil {
t.Fatal(err)
}
if got["error"] != "invalid credentials" {
t.Fatalf(`expected error "invalid credentials", got %q`, got["error"])
}
if findCookie(w.Result(), sessionCookieName) != nil {
t.Fatal("expected no session cookie on failed login")
}
}
// --- logout ---
func TestAuthLogout_ClearsSessionAndDestroys(t *testing.T) {
a, _, sessions := newTestAuthAPI()
router := NewRouter(a)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", nil)
req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "some-token"})
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status %d, body %s", w.Code, w.Body.String())
}
if !sessions.destroyCalled {
t.Fatal("expected Sessions.Destroy to be called")
}
if sessions.destroyToken != "some-token" {
t.Fatalf("expected Destroy called with cookie token, got %q", sessions.destroyToken)
}
cookie := findCookie(w.Result(), sessionCookieName)
if cookie == nil {
t.Fatal("expected a session cookie in the response (clearing cookie)")
}
if cookie.MaxAge > 0 {
t.Fatalf("expected cookie MaxAge <= 0 (cleared), got %d", cookie.MaxAge)
}
}
// --- me ---
// TestAuthMe_ReturnsRealEmail verifies the fix for the /me gap: the handler
// now resolves the authenticated user via GetUserByID and returns their real
// email, instead of leaving it blank.
func TestAuthMe_ReturnsRealEmail(t *testing.T) {
a, authStore, sessions := newTestAuthAPI()
userID := uuid.New()
projectID := uuid.New()
authStore.getUserByIDFn = func(_ context.Context, id uuid.UUID) (store.User, error) {
if id != userID {
t.Fatalf("unexpected user id: %s", id)
}
return store.User{ID: userID, Email: "me@example.com"}, nil
}
authStore.getUserProjectFn = func(_ context.Context, uid uuid.UUID) (store.Project, error) {
return store.Project{ID: projectID, UserID: uid, Name: "default"}, nil
}
// /me is now behind RequireAuth: the session cookie must resolve to
// userID via Sessions.Validate rather than being injected directly.
sessions.validateFn = func(context.Context, string) (uuid.UUID, error) {
return userID, nil
}
router := NewRouter(a)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil)
req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "some-valid-token"})
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status %d, body %s", w.Code, w.Body.String())
}
var got authResponse
if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil {
t.Fatal(err)
}
if got.User.ID != userID.String() || got.User.Email != "me@example.com" {
t.Fatalf("unexpected user in /me response: %+v", got.User)
}
if got.Project.ID != projectID.String() {
t.Fatalf("unexpected project in /me response: %+v", got.Project)
}
}
// errNoRowsForTest stands in for a "not found" error a real store would
// return (e.g. pgx.ErrNoRows) — handlers must not distinguish it from any
// other GetUserByEmail failure in the response they send.
var errNoRowsForTest = &notFoundErr{}
type notFoundErr struct{}
func (*notFoundErr) Error() string { return "not found" }
+38 -1
View File
@@ -1,6 +1,43 @@
package api package api
import "github.com/vasyakrg/dns-autoresolver/internal/diff" import (
"github.com/vasyakrg/dns-autoresolver/internal/diff"
"github.com/vasyakrg/dns-autoresolver/internal/store"
)
type registerRequest struct {
Email string `json:"email"`
Password string `json:"password"`
}
type loginRequest struct {
Email string `json:"email"`
Password string `json:"password"`
}
// userResponse and projectResponse deliberately expose only id/email and
// id/name — password_hash must never reach a client response.
type userResponse struct {
ID string `json:"id"`
Email string `json:"email"`
}
type projectResponse struct {
ID string `json:"id"`
Name string `json:"name"`
}
type authResponse struct {
User userResponse `json:"user"`
Project projectResponse `json:"project"`
}
func toAuthResponse(u store.User, p store.Project) authResponse {
return authResponse{
User: userResponse{ID: u.ID.String(), Email: u.Email},
Project: projectResponse{ID: p.ID.String(), Name: p.Name},
}
}
type applyRequest struct { type applyRequest struct {
ApplyUpdates bool `json:"applyUpdates"` ApplyUpdates bool `json:"applyUpdates"`
+8 -2
View File
@@ -24,12 +24,15 @@ func writeErr(w http.ResponseWriter, status int, msg string) {
} }
func (a *API) handleCheck(w http.ResponseWriter, r *http.Request) { func (a *API) handleCheck(w http.ResponseWriter, r *http.Request) {
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
did, err := uuid.Parse(chi.URLParam(r, "did")) did, err := uuid.Parse(chi.URLParam(r, "did"))
if err != nil { if err != nil {
writeErr(w, http.StatusBadRequest, "invalid domain id") writeErr(w, http.StatusBadRequest, "invalid domain id")
return return
} }
cs, err := a.Svc.Check(r.Context(), did) cs, err := a.Svc.Check(r.Context(), pid, did)
if err != nil { if err != nil {
log.Printf("api: check failed: %v", err) log.Printf("api: check failed: %v", err)
writeErr(w, http.StatusInternalServerError, "internal error") writeErr(w, http.StatusInternalServerError, "internal error")
@@ -39,6 +42,9 @@ func (a *API) handleCheck(w http.ResponseWriter, r *http.Request) {
} }
func (a *API) handleApply(w http.ResponseWriter, r *http.Request) { func (a *API) handleApply(w http.ResponseWriter, r *http.Request) {
// pid is guaranteed present and owned by the caller — RequireProjectAccess
// validated it before this handler ever runs.
pid, _ := projectIDFrom(r.Context())
did, err := uuid.Parse(chi.URLParam(r, "did")) did, err := uuid.Parse(chi.URLParam(r, "did"))
if err != nil { if err != nil {
writeErr(w, http.StatusBadRequest, "invalid domain id") writeErr(w, http.StatusBadRequest, "invalid domain id")
@@ -54,7 +60,7 @@ func (a *API) handleApply(w http.ResponseWriter, r *http.Request) {
return return
} }
} }
cs, err := a.Svc.Apply(r.Context(), did, service.ApplyRequest{ cs, err := a.Svc.Apply(r.Context(), pid, did, service.ApplyRequest{
ApplyUpdates: req.ApplyUpdates, ApplyPrunes: req.ApplyPrunes, ApplyUpdates: req.ApplyUpdates, ApplyPrunes: req.ApplyPrunes,
}) })
if err != nil { if err != nil {
+75
View File
@@ -0,0 +1,75 @@
package api
import (
"context"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
)
type ctxKey string
const (
ctxUserID ctxKey = "userID"
ctxProjectID ctxKey = "projectID"
)
// RequireAuth validates the session cookie and, on success, stores the
// authenticated user's ID in the request context (see userIDFrom). Any
// failure — missing cookie or an invalid/expired token — is rejected with
// 401 before reaching the wrapped handler.
func (a *API) RequireAuth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := r.Cookie(sessionCookieName)
if err != nil {
writeErr(w, http.StatusUnauthorized, "unauthorized")
return
}
uid, err := a.Sessions.Validate(r.Context(), c.Value)
if err != nil {
writeErr(w, http.StatusUnauthorized, "unauthorized")
return
}
ctx := context.WithValue(r.Context(), ctxUserID, uid)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// RequireProjectAccess verifies that the {pid} URL segment names a project
// owned by the authenticated user (set by RequireAuth, which must run
// first) and, on success, stores the project ID in the request context (see
// projectIDFrom). A project that doesn't exist or isn't owned by the caller
// is rejected with 404 — not 403 — so a caller can't distinguish "not yours"
// from "doesn't exist" (closes IDOR by never confirming another tenant's
// project exists).
func (a *API) RequireProjectAccess(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
uid, ok := userIDFrom(r.Context())
if !ok {
writeErr(w, http.StatusUnauthorized, "unauthorized")
return
}
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid project id")
return
}
if _, err := a.Auth.GetProjectOwned(r.Context(), pid, uid); err != nil {
writeErr(w, http.StatusNotFound, "not found")
return
}
ctx := context.WithValue(r.Context(), ctxProjectID, pid)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func userIDFrom(ctx context.Context) (uuid.UUID, bool) {
v, ok := ctx.Value(ctxUserID).(uuid.UUID)
return v, ok
}
func projectIDFrom(ctx context.Context) (uuid.UUID, bool) {
v, ok := ctx.Value(ctxProjectID).(uuid.UUID)
return v, ok
}
+265
View File
@@ -0,0 +1,265 @@
package api
import (
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/vasyakrg/dns-autoresolver/internal/diff"
"github.com/vasyakrg/dns-autoresolver/internal/service"
"github.com/vasyakrg/dns-autoresolver/internal/store"
)
// --- shared test doubles (also used by api_test.go / tenant_test.go) ---
// stubSessions is a configurable SessionManager test double.
type stubSessions struct {
validateFn func(ctx context.Context, token string) (uuid.UUID, error)
}
func (s stubSessions) Create(context.Context, uuid.UUID) (string, time.Time, error) {
return "stub-token", time.Now().Add(time.Hour), nil
}
func (s stubSessions) Validate(ctx context.Context, token string) (uuid.UUID, error) {
return s.validateFn(ctx, token)
}
func (s stubSessions) Destroy(context.Context, string) error { return nil }
// alwaysValidSessions builds a stubSessions whose Validate always succeeds
// with the given fixed user ID — for tests that only care about behavior
// past the RequireAuth boundary.
func alwaysValidSessions(uid uuid.UUID) stubSessions {
return stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) { return uid, nil }}
}
// stubAuthStore is a configurable AuthStore test double. Methods besides
// GetProjectOwned return zero values — tests exercising register/login/me
// use their own dedicated mock (mockAuthStore in auth_test.go).
type stubAuthStore struct {
getProjectOwnedFn func(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error)
}
func (s stubAuthStore) RegisterUser(context.Context, string, string) (store.User, store.Project, error) {
return store.User{}, store.Project{}, nil
}
func (s stubAuthStore) GetUserByEmail(context.Context, string) (store.User, error) {
return store.User{}, nil
}
func (s stubAuthStore) GetUserByID(context.Context, uuid.UUID) (store.User, error) {
return store.User{}, nil
}
func (s stubAuthStore) GetUserProject(context.Context, uuid.UUID) (store.Project, error) {
return store.Project{}, nil
}
func (s stubAuthStore) GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) {
return s.getProjectOwnedFn(ctx, projectID, userID)
}
// alwaysOwnedAuthStore builds a stubAuthStore whose GetProjectOwned always
// succeeds, treating the caller as the owner of whatever project id is
// requested — for tests that only care about behavior past the
// RequireProjectAccess boundary (CRUD/check/apply tests).
func alwaysOwnedAuthStore() stubAuthStore {
return stubAuthStore{getProjectOwnedFn: func(_ context.Context, pid, uid uuid.UUID) (store.Project, error) {
return store.Project{ID: pid, UserID: uid}, nil
}}
}
// requestWithSessionCookie builds an httptest request carrying a session
// cookie so it clears RequireAuth in tests using stubSessions/alwaysValidSessions
// (which ignore the token value and validate unconditionally).
func requestWithSessionCookie(method, url string, body io.Reader) *http.Request {
req := httptest.NewRequest(method, url, body)
req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "test-session-token"})
return req
}
// recordingCheckApplier is a CheckApplier test double that records whether
// Check/Apply were invoked — used by the IDOR regression test to assert the
// service layer is never reached for a project the caller doesn't own.
type recordingCheckApplier struct {
checkCalled bool
applyCalled bool
}
func (r *recordingCheckApplier) Check(context.Context, uuid.UUID, uuid.UUID) (diff.Changeset, error) {
r.checkCalled = true
return diff.Changeset{}, nil
}
func (r *recordingCheckApplier) Apply(context.Context, uuid.UUID, uuid.UUID, service.ApplyRequest) (diff.Changeset, error) {
r.applyCalled = true
return diff.Changeset{}, nil
}
// --- RequireAuth ---
func TestRequireAuth_NoCookie_Returns401(t *testing.T) {
a := &API{Sessions: stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) {
t.Fatal("Validate must not be called when there is no cookie")
return uuid.Nil, nil
}}}
nextCalled := false
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true })
req := httptest.NewRequest(http.MethodGet, "/whatever", nil)
w := httptest.NewRecorder()
a.RequireAuth(next).ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
if nextCalled {
t.Fatal("next must not be called without a session cookie")
}
}
func TestRequireAuth_InvalidToken_Returns401(t *testing.T) {
a := &API{Sessions: stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) {
return uuid.Nil, errors.New("invalid or expired session")
}}}
nextCalled := false
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true })
req := requestWithSessionCookie(http.MethodGet, "/whatever", nil)
w := httptest.NewRecorder()
a.RequireAuth(next).ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
if nextCalled {
t.Fatal("next must not be called when Sessions.Validate fails")
}
}
func TestRequireAuth_Success_CallsNextWithUserID(t *testing.T) {
uid := uuid.New()
a := &API{Sessions: alwaysValidSessions(uid)}
var gotUID uuid.UUID
var gotOK bool
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotUID, gotOK = userIDFrom(r.Context())
w.WriteHeader(http.StatusOK)
})
req := requestWithSessionCookie(http.MethodGet, "/whatever", nil)
w := httptest.NewRecorder()
a.RequireAuth(next).ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200 (next called), got %d", w.Code)
}
if !gotOK || gotUID != uid {
t.Fatalf("expected userIDFrom to yield %s, got %s (ok=%v)", uid, gotUID, gotOK)
}
}
// --- RequireProjectAccess ---
func TestRequireProjectAccess_NotOwned_Returns404(t *testing.T) {
a := &API{Auth: stubAuthStore{getProjectOwnedFn: func(context.Context, uuid.UUID, uuid.UUID) (store.Project, error) {
return store.Project{}, errors.New("no rows")
}}}
nextCalled := false
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true })
router := chi.NewRouter()
router.With(a.RequireProjectAccess).Get("/projects/{pid}", next.ServeHTTP)
req := httptest.NewRequest(http.MethodGet, "/projects/"+uuid.New().String(), nil)
req = req.WithContext(context.WithValue(req.Context(), ctxUserID, uuid.New()))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404 for an unowned/foreign project, got %d body %s", w.Code, w.Body.String())
}
if nextCalled {
t.Fatal("next must not be called when the project isn't owned by the caller")
}
}
func TestRequireProjectAccess_Success_CallsNextWithProjectID(t *testing.T) {
a := &API{Auth: alwaysOwnedAuthStore()}
pid := uuid.New()
uid := uuid.New()
var gotPID uuid.UUID
var gotOK bool
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPID, gotOK = projectIDFrom(r.Context())
w.WriteHeader(http.StatusOK)
})
router := chi.NewRouter()
router.With(a.RequireProjectAccess).Get("/projects/{pid}", next.ServeHTTP)
req := httptest.NewRequest(http.MethodGet, "/projects/"+pid.String(), nil)
req = req.WithContext(context.WithValue(req.Context(), ctxUserID, uid))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200 (next called), got %d body %s", w.Code, w.Body.String())
}
if !gotOK || gotPID != pid {
t.Fatalf("expected projectIDFrom to yield %s, got %s (ok=%v)", pid, gotPID, gotOK)
}
}
// --- IDOR regression ---
// TestIDOR_CheckForeignProject_Returns404AndServiceNotCalled is the IDOR
// regression required by Task 4: a request for project B's domain, made
// with user A's session (who owns project A, not B), must be rejected by
// RequireProjectAccess with 404 before service.Check is ever invoked — a
// caller must never be able to check/apply another tenant's domain by
// guessing/leaking its ids.
func TestIDOR_CheckForeignProject_Returns404AndServiceNotCalled(t *testing.T) {
userA := uuid.New()
pidA := uuid.New()
pidB := uuid.New() // owned by a different user; not A
svc := &recordingCheckApplier{}
auth := stubAuthStore{getProjectOwnedFn: func(_ context.Context, pid, uid uuid.UUID) (store.Project, error) {
if pid == pidA && uid == userA {
return store.Project{ID: pidA, UserID: userA}, nil
}
return store.Project{}, errors.New("not found")
}}
a := &API{Svc: svc, Auth: auth, Sessions: alwaysValidSessions(userA)}
router := NewRouter(a)
did := uuid.New().String()
req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+pidB.String()+"/domains/"+did+"/check", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404 for user A requesting project B's domain, got %d body %s", w.Code, w.Body.String())
}
if svc.checkCalled {
t.Fatal("service.Check must not be called when the caller doesn't own the project")
}
// Sanity check: the same user against their own project succeeds and
// does reach the service — proving the 404 above is really about
// project ownership, not e.g. a broken route.
req2 := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+pidA.String()+"/domains/"+did+"/check", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Fatalf("expected 200 for user A requesting their own project, got %d body %s", w2.Code, w2.Body.String())
}
if !svc.checkCalled {
t.Fatal("expected service.Check to be called for the caller's own project")
}
}
+36 -60
View File
@@ -24,11 +24,9 @@ func decodeBody(w http.ResponseWriter, r *http.Request, v any) bool {
// --- accounts --- // --- accounts ---
func (a *API) handleCreateAccount(w http.ResponseWriter, r *http.Request) { func (a *API) handleCreateAccount(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
var req accountRequest var req accountRequest
if !decodeBody(w, r, &req) { if !decodeBody(w, r, &req) {
return return
@@ -53,11 +51,9 @@ func (a *API) handleCreateAccount(w http.ResponseWriter, r *http.Request) {
} }
func (a *API) handleListAccounts(w http.ResponseWriter, r *http.Request) { func (a *API) handleListAccounts(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
accs, err := a.Store.ListAccounts(r.Context(), pid) accs, err := a.Store.ListAccounts(r.Context(), pid)
if err != nil { if err != nil {
log.Printf("api: list accounts failed: %v", err) log.Printf("api: list accounts failed: %v", err)
@@ -72,11 +68,9 @@ func (a *API) handleListAccounts(w http.ResponseWriter, r *http.Request) {
} }
func (a *API) handleDeleteAccount(w http.ResponseWriter, r *http.Request) { func (a *API) handleDeleteAccount(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
aid, err := uuid.Parse(chi.URLParam(r, "aid")) aid, err := uuid.Parse(chi.URLParam(r, "aid"))
if err != nil { if err != nil {
writeErr(w, http.StatusBadRequest, "invalid account id") writeErr(w, http.StatusBadRequest, "invalid account id")
@@ -93,11 +87,9 @@ func (a *API) handleDeleteAccount(w http.ResponseWriter, r *http.Request) {
// handleImportZones lists zones from the provider for the given account and // handleImportZones lists zones from the provider for the given account and
// creates one domain per zone (template_id left unset). // creates one domain per zone (template_id left unset).
func (a *API) handleImportZones(w http.ResponseWriter, r *http.Request) { func (a *API) handleImportZones(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
aid, err := uuid.Parse(chi.URLParam(r, "aid")) aid, err := uuid.Parse(chi.URLParam(r, "aid"))
if err != nil { if err != nil {
writeErr(w, http.StatusBadRequest, "invalid account id") writeErr(w, http.StatusBadRequest, "invalid account id")
@@ -145,11 +137,9 @@ func (a *API) handleImportZones(w http.ResponseWriter, r *http.Request) {
// --- templates --- // --- templates ---
func (a *API) handleCreateTemplate(w http.ResponseWriter, r *http.Request) { func (a *API) handleCreateTemplate(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
var req templateRequest var req templateRequest
if !decodeBody(w, r, &req) { if !decodeBody(w, r, &req) {
return return
@@ -169,11 +159,9 @@ func (a *API) handleCreateTemplate(w http.ResponseWriter, r *http.Request) {
} }
func (a *API) handleListTemplates(w http.ResponseWriter, r *http.Request) { func (a *API) handleListTemplates(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
tpls, err := a.Store.ListTemplates(r.Context(), pid) tpls, err := a.Store.ListTemplates(r.Context(), pid)
if err != nil { if err != nil {
log.Printf("api: list templates failed: %v", err) log.Printf("api: list templates failed: %v", err)
@@ -188,11 +176,9 @@ func (a *API) handleListTemplates(w http.ResponseWriter, r *http.Request) {
} }
func (a *API) handleUpdateTemplate(w http.ResponseWriter, r *http.Request) { func (a *API) handleUpdateTemplate(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
tid, err := uuid.Parse(chi.URLParam(r, "tid")) tid, err := uuid.Parse(chi.URLParam(r, "tid"))
if err != nil { if err != nil {
writeErr(w, http.StatusBadRequest, "invalid template id") writeErr(w, http.StatusBadRequest, "invalid template id")
@@ -217,11 +203,9 @@ func (a *API) handleUpdateTemplate(w http.ResponseWriter, r *http.Request) {
} }
func (a *API) handleDeleteTemplate(w http.ResponseWriter, r *http.Request) { func (a *API) handleDeleteTemplate(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
tid, err := uuid.Parse(chi.URLParam(r, "tid")) tid, err := uuid.Parse(chi.URLParam(r, "tid"))
if err != nil { if err != nil {
writeErr(w, http.StatusBadRequest, "invalid template id") writeErr(w, http.StatusBadRequest, "invalid template id")
@@ -238,11 +222,9 @@ func (a *API) handleDeleteTemplate(w http.ResponseWriter, r *http.Request) {
// --- domains --- // --- domains ---
func (a *API) handleCreateDomain(w http.ResponseWriter, r *http.Request) { func (a *API) handleCreateDomain(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
var req domainRequest var req domainRequest
if !decodeBody(w, r, &req) { if !decodeBody(w, r, &req) {
return return
@@ -286,11 +268,9 @@ func (a *API) handleCreateDomain(w http.ResponseWriter, r *http.Request) {
} }
func (a *API) handleListDomains(w http.ResponseWriter, r *http.Request) { func (a *API) handleListDomains(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
doms, err := a.Store.ListDomains(r.Context(), pid) doms, err := a.Store.ListDomains(r.Context(), pid)
if err != nil { if err != nil {
log.Printf("api: list domains failed: %v", err) log.Printf("api: list domains failed: %v", err)
@@ -308,11 +288,9 @@ func (a *API) handleListDomains(w http.ResponseWriter, r *http.Request) {
// check/apply a domain — this is what makes an imported domain (which // check/apply a domain — this is what makes an imported domain (which
// starts with template_id=NULL) checkable, closing the import→check loop. // starts with template_id=NULL) checkable, closing the import→check loop.
func (a *API) handleSetDomainTemplate(w http.ResponseWriter, r *http.Request) { func (a *API) handleSetDomainTemplate(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
did, err := uuid.Parse(chi.URLParam(r, "did")) did, err := uuid.Parse(chi.URLParam(r, "did"))
if err != nil { if err != nil {
writeErr(w, http.StatusBadRequest, "invalid domain id") writeErr(w, http.StatusBadRequest, "invalid domain id")
@@ -339,11 +317,9 @@ func (a *API) handleSetDomainTemplate(w http.ResponseWriter, r *http.Request) {
} }
func (a *API) handleDeleteDomain(w http.ResponseWriter, r *http.Request) { func (a *API) handleDeleteDomain(w http.ResponseWriter, r *http.Request) {
pid, err := uuid.Parse(chi.URLParam(r, "pid")) // pid is guaranteed present and owned by the caller — RequireProjectAccess
if err != nil { // validated it before this handler ever runs.
writeErr(w, http.StatusBadRequest, "invalid project id") pid, _ := projectIDFrom(r.Context())
return
}
did, err := uuid.Parse(chi.URLParam(r, "did")) did, err := uuid.Parse(chi.URLParam(r, "did"))
if err != nil { if err != nil {
writeErr(w, http.StatusBadRequest, "invalid domain id") writeErr(w, http.StatusBadRequest, "invalid domain id")
+29 -19
View File
@@ -132,7 +132,9 @@ func (m *mockTenantStore) ImportDomains(_ context.Context, projectID, accountID
type mockCipher struct{} type mockCipher struct{}
func (mockCipher) Encrypt(plaintext []byte) (string, error) { return "ENC(" + string(plaintext) + ")", nil } func (mockCipher) Encrypt(plaintext []byte) (string, error) {
return "ENC(" + string(plaintext) + ")", nil
}
func (mockCipher) Decrypt(enc string) ([]byte, error) { func (mockCipher) Decrypt(enc string) ([]byte, error) {
return []byte(strings.TrimSuffix(strings.TrimPrefix(enc, "ENC("), ")")), nil return []byte(strings.TrimSuffix(strings.TrimPrefix(enc, "ENC("), ")")), nil
} }
@@ -160,9 +162,17 @@ func (mockProvider) ApplyChanges(context.Context, provider.Credentials, string,
return nil return nil
} }
// newTenantTestAPI wires a fixed authenticated user who owns whatever
// project id is requested (alwaysOwnedAuthStore/alwaysValidSessions, see
// middleware_test.go) — these tests exercise CRUD behavior past the
// RequireAuth/RequireProjectAccess boundary, which has its own dedicated
// coverage in middleware_test.go.
func newTenantTestAPI() (*API, *mockTenantStore) { func newTenantTestAPI() (*API, *mockTenantStore) {
ts := &mockTenantStore{} ts := &mockTenantStore{}
a := &API{Store: ts, Cipher: mockCipher{}, Reg: &mockRegistry{}} a := &API{
Store: ts, Cipher: mockCipher{}, Reg: &mockRegistry{},
Auth: alwaysOwnedAuthStore(), Sessions: alwaysValidSessions(uuid.New()),
}
return a, ts return a, ts
} }
@@ -173,7 +183,7 @@ func TestCreateAccount_SecretEncryptedAndNotInResponse(t *testing.T) {
router := NewRouter(a) router := NewRouter(a)
body := `{"provider":"selectel","secret":"super-secret-token","comment":"prod"}` body := `{"provider":"selectel","secret":"super-secret-token","comment":"prod"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts", strings.NewReader(body)) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts", strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -211,7 +221,7 @@ func TestListAccounts_NoSecretsInResponse(t *testing.T) {
} }
router := NewRouter(a) router := NewRouter(a)
req := httptest.NewRequest(http.MethodGet, "/api/v1/projects/"+testPID+"/accounts", nil) req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+testPID+"/accounts", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -234,7 +244,7 @@ func TestDeleteAccount_BadUUID(t *testing.T) {
a, _ := newTenantTestAPI() a, _ := newTenantTestAPI()
router := NewRouter(a) router := NewRouter(a)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid", nil) req := requestWithSessionCookie(http.MethodDelete, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -250,7 +260,7 @@ func TestCreateTemplate_SavesRecords(t *testing.T) {
router := NewRouter(a) router := NewRouter(a)
body := `{"name":"base","records":[{"type":"A","name":"@","ttl":300,"values":["1.2.3.4"]}]}` body := `{"name":"base","records":[{"type":"A","name":"@","ttl":300,"values":["1.2.3.4"]}]}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/templates", strings.NewReader(body)) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/templates", strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -278,7 +288,7 @@ func TestUpdateTemplate_BadUUID(t *testing.T) {
router := NewRouter(a) router := NewRouter(a)
body := `{"name":"x","records":[]}` body := `{"name":"x","records":[]}`
req := httptest.NewRequest(http.MethodPut, "/api/v1/projects/"+testPID+"/templates/not-a-uuid", strings.NewReader(body)) req := requestWithSessionCookie(http.MethodPut, "/api/v1/projects/"+testPID+"/templates/not-a-uuid", strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -299,7 +309,7 @@ func TestImportZones_CreatesDomainPerZone(t *testing.T) {
}} }}
router := NewRouter(a) router := NewRouter(a)
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -337,7 +347,7 @@ func TestImportZones_AtomicRollbackOnError(t *testing.T) {
}} }}
router := NewRouter(a) router := NewRouter(a)
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -356,7 +366,7 @@ func TestImportZones_BadAccountUUID(t *testing.T) {
a, _ := newTenantTestAPI() a, _ := newTenantTestAPI()
router := NewRouter(a) router := NewRouter(a)
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid/import", nil) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid/import", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -370,7 +380,7 @@ func TestCreateDomain_BadProjectUUID(t *testing.T) {
router := NewRouter(a) router := NewRouter(a)
body := `{"providerAccountId":"` + uuid.New().String() + `","zoneName":"example.com","zoneId":"z1"}` body := `{"providerAccountId":"` + uuid.New().String() + `","zoneName":"example.com","zoneId":"z1"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/not-a-uuid/domains", strings.NewReader(body)) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/not-a-uuid/domains", strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -390,7 +400,7 @@ func TestCreateDomain_AccountNotFoundInProject(t *testing.T) {
// ts.accounts is empty, so GetAccount will not find this id. // ts.accounts is empty, so GetAccount will not find this id.
foreignAccID := uuid.New() foreignAccID := uuid.New()
body := `{"providerAccountId":"` + foreignAccID.String() + `","zoneName":"example.com","zoneId":"z1"}` body := `{"providerAccountId":"` + foreignAccID.String() + `","zoneName":"example.com","zoneId":"z1"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -413,7 +423,7 @@ func TestCreateDomain_TemplateNotFoundInProject(t *testing.T) {
foreignTplID := uuid.New() foreignTplID := uuid.New()
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + foreignTplID.String() + `"}` body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + foreignTplID.String() + `"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -434,7 +444,7 @@ func TestCreateDomain_HappyPath(t *testing.T) {
router := NewRouter(a) router := NewRouter(a)
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1"}` body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -464,7 +474,7 @@ func TestCreateDomain_ValidTemplateInProject(t *testing.T) {
router := NewRouter(a) router := NewRouter(a)
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + tplID.String() + `"}` body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + tplID.String() + `"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -490,7 +500,7 @@ func TestSetDomainTemplate_ValidTemplateId(t *testing.T) {
router := NewRouter(a) router := NewRouter(a)
body := `{"templateId":"` + tplID.String() + `"}` body := `{"templateId":"` + tplID.String() + `"}`
req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body)) req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -513,7 +523,7 @@ func TestSetDomainTemplate_BadTemplateUUID(t *testing.T) {
router := NewRouter(a) router := NewRouter(a)
body := `{"templateId":"not-a-uuid"}` body := `{"templateId":"not-a-uuid"}`
req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body)) req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -530,7 +540,7 @@ func TestSetDomainTemplate_TemplateNotFound(t *testing.T) {
router := NewRouter(a) router := NewRouter(a)
body := `{"templateId":"` + uuid.New().String() + `"}` body := `{"templateId":"` + uuid.New().String() + `"}`
req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body)) req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -543,7 +553,7 @@ func TestDeleteDomain_BadUUID(t *testing.T) {
a, _ := newTenantTestAPI() a, _ := newTenantTestAPI()
router := NewRouter(a) router := NewRouter(a)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/projects/"+testPID+"/domains/not-a-uuid", nil) req := requestWithSessionCookie(http.MethodDelete, "/api/v1/projects/"+testPID+"/domains/not-a-uuid", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
+93
View File
@@ -0,0 +1,93 @@
package auth
import (
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"fmt"
"regexp"
"strconv"
"strings"
"golang.org/x/crypto/argon2"
)
const (
argonTime = 1
argonMemory = 64 * 1024
argonThreads = 4
argonKeyLen = 32
argonSaltLen = 16
// Upper bounds guard against a corrupted/attacker-controlled
// password_hash forcing an oversized argon2 computation (DoS).
argonMaxMemoryKiB = 1 << 21 // 2 GiB in KiB
argonMaxTime = 10
argonMaxThreads = 16
)
// paramsRe strictly matches the "m=<n>,t=<n>,p=<n>" parameter segment,
// requiring the whole segment to be consumed (no trailing garbage).
var paramsRe = regexp.MustCompile(`^m=([0-9]+),t=([0-9]+),p=([0-9]+)$`)
func HashPassword(password string) (string, error) {
salt := make([]byte, argonSaltLen)
if _, err := rand.Read(salt); err != nil {
return "", err
}
key := argon2.IDKey([]byte(password), salt, argonTime, argonMemory, argonThreads, argonKeyLen)
b64 := base64.RawStdEncoding.EncodeToString
return fmt.Sprintf("$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
argonMemory, argonTime, argonThreads, b64(salt), b64(key)), nil
}
func VerifyPassword(encoded, password string) (bool, error) {
parts := strings.Split(encoded, "$")
if len(parts) != 6 || parts[1] != "argon2id" {
return false, fmt.Errorf("auth: bad hash format")
}
if parts[2] != "v=19" {
return false, fmt.Errorf("auth: unsupported version")
}
matches := paramsRe.FindStringSubmatch(parts[3])
if matches == nil {
return false, fmt.Errorf("auth: bad hash format")
}
m64, err := strconv.ParseUint(matches[1], 10, 32)
if err != nil {
return false, fmt.Errorf("auth: bad hash format")
}
t64, err := strconv.ParseUint(matches[2], 10, 32)
if err != nil {
return false, fmt.Errorf("auth: bad hash format")
}
p64, err := strconv.ParseUint(matches[3], 10, 8)
if err != nil {
return false, fmt.Errorf("auth: bad hash format")
}
m, t, p := uint32(m64), uint32(t64), uint8(p64)
// argon2.IDKey panics if time<1 ("number of rounds too small") or
// threads<1 ("parallelism degree too low"). Reject before calling it.
// Minimum memory per argon2 spec is 8*parallelism (in KiB).
if t < 1 || p < 1 || m < 8*uint32(p) {
return false, fmt.Errorf("auth: bad hash params")
}
// Upper bounds guard against DoS via inflated parameters in a
// corrupted or attacker-controlled stored hash.
if m > argonMaxMemoryKiB || t > argonMaxTime || p > argonMaxThreads {
return false, fmt.Errorf("auth: bad hash params")
}
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
if err != nil {
return false, err
}
want, err := base64.RawStdEncoding.DecodeString(parts[5])
if err != nil {
return false, err
}
got := argon2.IDKey([]byte(password), salt, t, m, p, uint32(len(want)))
return subtle.ConstantTimeCompare(got, want) == 1, nil
}
+87
View File
@@ -0,0 +1,87 @@
package auth
import "testing"
func TestHashVerifyRoundTrip(t *testing.T) {
h, err := HashPassword("s3cret-pw")
if err != nil {
t.Fatal(err)
}
if h == "s3cret-pw" || len(h) < 20 {
t.Fatalf("bad hash %q", h)
}
ok, err := VerifyPassword(h, "s3cret-pw")
if err != nil || !ok {
t.Fatalf("verify failed: %v %v", ok, err)
}
bad, _ := VerifyPassword(h, "wrong")
if bad {
t.Fatal("wrong password must not verify")
}
}
func TestHashNonDeterministic(t *testing.T) {
a, _ := HashPassword("same")
b, _ := HashPassword("same")
if a == b {
t.Fatal("salt must randomize hash")
}
}
func TestVerifyPasswordBadTimeDoesNotPanic(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Fatalf("VerifyPassword panicked: %v", r)
}
}()
encoded := "$argon2id$v=19$m=65536,t=0,p=4$c29tZXNhbHRzb21lc2FsdA$c29tZWhhc2hzb21laGFzaA"
ok, err := VerifyPassword(encoded, "anything")
if err == nil {
t.Fatal("expected error for t=0, got nil")
}
if ok {
t.Fatal("expected ok=false for t=0")
}
}
func TestVerifyPasswordBadThreadsDoesNotPanic(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Fatalf("VerifyPassword panicked: %v", r)
}
}()
encoded := "$argon2id$v=19$m=65536,t=1,p=0$c29tZXNhbHRzb21lc2FsdA$c29tZWhhc2hzb21laGFzaA"
ok, err := VerifyPassword(encoded, "anything")
if err == nil {
t.Fatal("expected error for p=0, got nil")
}
if ok {
t.Fatal("expected ok=false for p=0")
}
}
func TestVerifyPasswordUnsupportedVersion(t *testing.T) {
encoded := "$argon2id$v=18$m=65536,t=1,p=4$c29tZXNhbHRzb21lc2FsdA$c29tZWhhc2hzb21laGFzaA"
ok, err := VerifyPassword(encoded, "anything")
if err == nil {
t.Fatal("expected error for unsupported version, got nil")
}
if ok {
t.Fatal("expected ok=false for unsupported version")
}
}
func TestVerifyPasswordGarbageFormatDoesNotPanic(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Fatalf("VerifyPassword panicked: %v", r)
}
}()
ok, err := VerifyPassword("notahash", "anything")
if err == nil {
t.Fatal("expected error for garbage format, got nil")
}
if ok {
t.Fatal("expected ok=false for garbage format")
}
}
+56
View File
@@ -0,0 +1,56 @@
package auth
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
"time"
"github.com/google/uuid"
)
var ErrNoSession = errors.New("auth: no such session")
type SessionStore interface {
CreateSession(ctx context.Context, userID uuid.UUID, tokenHash string, expiresAt time.Time) error
GetSessionUser(ctx context.Context, tokenHash string) (uuid.UUID, error)
DeleteSession(ctx context.Context, tokenHash string) error
}
type Sessions struct {
store SessionStore
ttl time.Duration
}
func NewSessions(store SessionStore, ttl time.Duration) *Sessions {
return &Sessions{store: store, ttl: ttl}
}
func TokenHash(token string) string {
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:])
}
func (s *Sessions) Create(ctx context.Context, userID uuid.UUID) (string, time.Time, error) {
raw := make([]byte, 32)
if _, err := rand.Read(raw); err != nil {
return "", time.Time{}, err
}
token := base64.RawURLEncoding.EncodeToString(raw)
exp := time.Now().Add(s.ttl)
if err := s.store.CreateSession(ctx, userID, TokenHash(token), exp); err != nil {
return "", time.Time{}, err
}
return token, exp, nil
}
func (s *Sessions) Validate(ctx context.Context, token string) (uuid.UUID, error) {
return s.store.GetSessionUser(ctx, TokenHash(token))
}
func (s *Sessions) Destroy(ctx context.Context, token string) error {
return s.store.DeleteSession(ctx, TokenHash(token))
}
+55
View File
@@ -0,0 +1,55 @@
package auth
import (
"context"
"testing"
"time"
"github.com/google/uuid"
)
type memStore struct {
byHash map[string]uuid.UUID
exp map[string]time.Time
}
func newMem() *memStore { return &memStore{byHash: map[string]uuid.UUID{}, exp: map[string]time.Time{}} }
func (m *memStore) CreateSession(_ context.Context, uid uuid.UUID, h string, e time.Time) error {
m.byHash[h] = uid
m.exp[h] = e
return nil
}
func (m *memStore) GetSessionUser(_ context.Context, h string) (uuid.UUID, error) {
uid, ok := m.byHash[h]
if !ok || time.Now().After(m.exp[h]) {
return uuid.Nil, ErrNoSession
}
return uid, nil
}
func (m *memStore) DeleteSession(_ context.Context, h string) error { delete(m.byHash, h); return nil }
func TestSessionCreateValidateDestroy(t *testing.T) {
s := NewSessions(newMem(), time.Hour)
uid := uuid.New()
token, exp, err := s.Create(context.Background(), uid)
if err != nil || token == "" || exp.Before(time.Now()) {
t.Fatalf("create: %v %q", err, token)
}
got, err := s.Validate(context.Background(), token)
if err != nil || got != uid {
t.Fatalf("validate: %v %v", got, err)
}
if err := s.Destroy(context.Background(), token); err != nil {
t.Fatal(err)
}
if _, err := s.Validate(context.Background(), token); err == nil {
t.Fatal("destroyed session must not validate")
}
}
func TestValidateUnknownToken(t *testing.T) {
s := NewSessions(newMem(), time.Hour)
if _, err := s.Validate(context.Background(), "nope"); err == nil {
t.Fatal("unknown token must error")
}
}
+9 -7
View File
@@ -21,7 +21,7 @@ type DomainRef struct {
} }
type Loader interface { type Loader interface {
LoadDomain(ctx context.Context, domainID uuid.UUID) (DomainRef, error) LoadDomain(ctx context.Context, projectID, domainID uuid.UUID) (DomainRef, error)
} }
type Recorder interface { type Recorder interface {
@@ -45,8 +45,10 @@ func New(loader Loader, rec Recorder, reg *registry.Registry, cipher *crypto.Cip
} }
// resolve loads the domain, its provider and decrypted credentials, and computes the diff. // resolve loads the domain, its provider and decrypted credentials, and computes the diff.
func (s *DomainService) resolve(ctx context.Context, domainID uuid.UUID) (provider.Provider, provider.Credentials, DomainRef, diff.Changeset, error) { // projectID scopes the lookup so a domainID belonging to another tenant's
ref, err := s.loader.LoadDomain(ctx, domainID) // project can never be resolved here (closes IDOR).
func (s *DomainService) resolve(ctx context.Context, projectID, domainID uuid.UUID) (provider.Provider, provider.Credentials, DomainRef, diff.Changeset, error) {
ref, err := s.loader.LoadDomain(ctx, projectID, domainID)
if err != nil { if err != nil {
return nil, provider.Credentials{}, ref, diff.Changeset{}, err return nil, provider.Credentials{}, ref, diff.Changeset{}, err
} }
@@ -68,8 +70,8 @@ func (s *DomainService) resolve(ctx context.Context, domainID uuid.UUID) (provid
} }
// Check computes and records the diff between template and zone. // Check computes and records the diff between template and zone.
func (s *DomainService) Check(ctx context.Context, domainID uuid.UUID) (diff.Changeset, error) { func (s *DomainService) Check(ctx context.Context, projectID, domainID uuid.UUID) (diff.Changeset, error) {
_, _, _, cs, err := s.resolve(ctx, domainID) _, _, _, cs, err := s.resolve(ctx, projectID, domainID)
if err != nil { if err != nil {
return diff.Changeset{}, err return diff.Changeset{}, err
} }
@@ -80,8 +82,8 @@ func (s *DomainService) Check(ctx context.Context, domainID uuid.UUID) (diff.Cha
} }
// Apply applies updates always (when ApplyUpdates) and prunes only when ApplyPrunes. // Apply applies updates always (when ApplyUpdates) and prunes only when ApplyPrunes.
func (s *DomainService) Apply(ctx context.Context, domainID uuid.UUID, req ApplyRequest) (diff.Changeset, error) { func (s *DomainService) Apply(ctx context.Context, projectID, domainID uuid.UUID, req ApplyRequest) (diff.Changeset, error) {
p, creds, ref, cs, err := s.resolve(ctx, domainID) p, creds, ref, cs, err := s.resolve(ctx, projectID, domainID)
if err != nil { if err != nil {
return diff.Changeset{}, err return diff.Changeset{}, err
} }
+6 -4
View File
@@ -44,7 +44,9 @@ func (f *fakeProvider) ApplyChanges(_ context.Context, _ provider.Credentials, _
type fakeLoader struct{ ref DomainRef } type fakeLoader struct{ ref DomainRef }
func (l fakeLoader) LoadDomain(context.Context, uuid.UUID) (DomainRef, error) { return l.ref, nil } func (l fakeLoader) LoadDomain(context.Context, uuid.UUID, uuid.UUID) (DomainRef, error) {
return l.ref, nil
}
type nopRecorder struct{} type nopRecorder struct{}
@@ -66,7 +68,7 @@ func TestCheckProducesDiff(t *testing.T) {
{Type: "A", Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}}, // update {Type: "A", Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}}, // update
}} }}
svc, _ := setup(t, actual, tmpl) svc, _ := setup(t, actual, tmpl)
cs, err := svc.Check(context.Background(), uuid.New()) cs, err := svc.Check(context.Background(), uuid.New(), uuid.New())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -87,7 +89,7 @@ func TestApplyRespectsPruneGuard(t *testing.T) {
// applyPrunes=false → удаление b НЕ применяется // applyPrunes=false → удаление b НЕ применяется
svc, fp := setup(t, actual, tmpl) svc, fp := setup(t, actual, tmpl)
if _, err := svc.Apply(context.Background(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: false}); err != nil { if _, err := svc.Apply(context.Background(), uuid.New(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: false}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
for _, d := range fp.applied.Diffs { for _, d := range fp.applied.Diffs {
@@ -98,7 +100,7 @@ func TestApplyRespectsPruneGuard(t *testing.T) {
// applyPrunes=true → удаление b применяется // applyPrunes=true → удаление b применяется
svc2, fp2 := setup(t, actual, tmpl) svc2, fp2 := setup(t, actual, tmpl)
if _, err := svc2.Apply(context.Background(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: true}); err != nil { if _, err := svc2.Apply(context.Background(), uuid.New(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: true}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
var sawDelete bool var sawDelete bool
+185
View File
@@ -0,0 +1,185 @@
package store
import (
"errors"
"testing"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
)
// TestRegisterUser_CreatesUserAndOwnedProject verifies the RegisterUser
// transaction: a user and a default project are created together, and the
// project belongs to that user.
func TestRegisterUser_CreatesUserAndOwnedProject(t *testing.T) {
s, ctx := newStore(t)
u, p, err := s.RegisterUser(ctx, "alice@example.com", "argon2-hash")
if err != nil {
t.Fatal(err)
}
if u.Email != "alice@example.com" || u.PasswordHash != "argon2-hash" {
t.Fatalf("unexpected user: %+v", u)
}
if p.UserID != u.ID {
t.Fatalf("expected project to belong to user %s, got %+v", u.ID, p)
}
owned, err := s.GetProjectOwned(ctx, p.ID, u.ID)
if err != nil {
t.Fatal(err)
}
if owned.ID != p.ID {
t.Fatalf("expected owned project %s, got %+v", p.ID, owned)
}
}
// TestGetUserByEmail_FindsRegisteredUser verifies email lookup returns the
// same user created by RegisterUser.
func TestGetUserByEmail_FindsRegisteredUser(t *testing.T) {
s, ctx := newStore(t)
u, _, err := s.RegisterUser(ctx, "bob@example.com", "argon2-hash")
if err != nil {
t.Fatal(err)
}
got, err := s.GetUserByEmail(ctx, "bob@example.com")
if err != nil {
t.Fatal(err)
}
if got.ID != u.ID || got.PasswordHash != "argon2-hash" {
t.Fatalf("unexpected user: %+v", got)
}
if _, err := s.GetUserByEmail(ctx, "nobody@example.com"); err == nil {
t.Fatal("expected error for unknown email, got nil")
}
}
// TestRegisterUser_DuplicateEmailReturnsErrEmailTaken verifies the fix for
// the duplicate-registration gap: a second RegisterUser call for an
// already-taken email must fail with the ErrEmailTaken sentinel (mapped from
// the UNIQUE constraint violation on users.email), not a generic pgx error.
func TestRegisterUser_DuplicateEmailReturnsErrEmailTaken(t *testing.T) {
s, ctx := newStore(t)
if _, _, err := s.RegisterUser(ctx, "dup@example.com", "argon2-hash"); err != nil {
t.Fatal(err)
}
if _, _, err := s.RegisterUser(ctx, "dup@example.com", "argon2-hash"); !errors.Is(err, ErrEmailTaken) {
t.Fatalf("expected ErrEmailTaken, got %v", err)
}
}
// TestGetUserByID_ReturnsUser verifies the fix for the /me gap: GetUserByID
// returns the same user created by RegisterUser, including their real email.
func TestGetUserByID_ReturnsUser(t *testing.T) {
s, ctx := newStore(t)
u, _, err := s.RegisterUser(ctx, "gina@example.com", "argon2-hash")
if err != nil {
t.Fatal(err)
}
got, err := s.GetUserByID(ctx, u.ID)
if err != nil {
t.Fatal(err)
}
if got.ID != u.ID || got.Email != "gina@example.com" {
t.Fatalf("unexpected user: %+v", got)
}
}
// TestSessionLifecycle_CreateGetDelete verifies CreateSession + GetSessionUser
// round-trips to the owning user ID, an expired session is excluded from
// GetSessionUser, and DeleteSession removes the session.
func TestSessionLifecycle_CreateGetDelete(t *testing.T) {
s, ctx := newStore(t)
u, _, err := s.RegisterUser(ctx, "carol@example.com", "argon2-hash")
if err != nil {
t.Fatal(err)
}
tokenHash := "sha256-token-hash"
if err := s.CreateSession(ctx, u.ID, tokenHash, time.Now().Add(time.Hour)); err != nil {
t.Fatal(err)
}
gotUserID, err := s.GetSessionUser(ctx, tokenHash)
if err != nil {
t.Fatal(err)
}
if gotUserID != u.ID {
t.Fatalf("expected user %s, got %s", u.ID, gotUserID)
}
if err := s.DeleteSession(ctx, tokenHash); err != nil {
t.Fatal(err)
}
if _, err := s.GetSessionUser(ctx, tokenHash); err == nil {
t.Fatal("expected error after DeleteSession, got nil")
}
}
// TestGetSessionUser_ExpiredSessionNotReturned verifies the query's
// expires_at > now() condition: a session created with an expiry in the
// past must not be returned by GetSessionUser.
func TestGetSessionUser_ExpiredSessionNotReturned(t *testing.T) {
s, ctx := newStore(t)
u, _, err := s.RegisterUser(ctx, "dave@example.com", "argon2-hash")
if err != nil {
t.Fatal(err)
}
tokenHash := "expired-token-hash"
if err := s.CreateSession(ctx, u.ID, tokenHash, time.Now().Add(-time.Hour)); err != nil {
t.Fatal(err)
}
if _, err := s.GetSessionUser(ctx, tokenHash); err == nil {
t.Fatal("expected expired session to not be returned, got nil error")
} else if err != pgx.ErrNoRows {
t.Fatalf("expected pgx.ErrNoRows, got %v", err)
}
}
// TestGetProjectOwned_ForeignUserRejected verifies that looking up a project
// with the wrong user ID fails, so one tenant cannot address another
// tenant's project by guessing its ID.
func TestGetProjectOwned_ForeignUserRejected(t *testing.T) {
s, ctx := newStore(t)
_, p, err := s.RegisterUser(ctx, "erin@example.com", "argon2-hash")
if err != nil {
t.Fatal(err)
}
foreignUserID := uuid.New()
if _, err := s.GetProjectOwned(ctx, p.ID, foreignUserID); err == nil {
t.Fatal("expected error for foreign user ID, got nil")
}
}
// TestGetUserProject_ReturnsTheUsersProject verifies GetUserProject returns
// the project created for that user by RegisterUser.
func TestGetUserProject_ReturnsTheUsersProject(t *testing.T) {
s, ctx := newStore(t)
u, p, err := s.RegisterUser(ctx, "frank@example.com", "argon2-hash")
if err != nil {
t.Fatal(err)
}
got, err := s.GetUserProject(ctx, u.ID)
if err != nil {
t.Fatal(err)
}
if got.ID != p.ID {
t.Fatalf("expected project %s, got %+v", p.ID, got)
}
}
+8 -3
View File
@@ -162,9 +162,14 @@ SELECT d.zone_id, a.provider, a.secret_enc, t.doc
FROM domains d FROM domains d
JOIN provider_accounts a ON a.id = d.provider_account_id JOIN provider_accounts a ON a.id = d.provider_account_id
LEFT JOIN templates t ON t.id = d.template_id LEFT JOIN templates t ON t.id = d.template_id
WHERE d.id = $1 WHERE d.id = $1 AND d.project_id = $2
` `
type LoadDomainFullParams struct {
ID uuid.UUID `json:"id"`
ProjectID uuid.UUID `json:"project_id"`
}
type LoadDomainFullRow struct { type LoadDomainFullRow struct {
ZoneID string `json:"zone_id"` ZoneID string `json:"zone_id"`
Provider string `json:"provider"` Provider string `json:"provider"`
@@ -172,8 +177,8 @@ type LoadDomainFullRow struct {
Doc *dto.TemplateDoc `json:"doc"` Doc *dto.TemplateDoc `json:"doc"`
} }
func (q *Queries) LoadDomainFull(ctx context.Context, id uuid.UUID) (LoadDomainFullRow, error) { func (q *Queries) LoadDomainFull(ctx context.Context, arg LoadDomainFullParams) (LoadDomainFullRow, error) {
row := q.db.QueryRow(ctx, loadDomainFull, id) row := q.db.QueryRow(ctx, loadDomainFull, arg.ID, arg.ProjectID)
var i LoadDomainFullRow var i LoadDomainFullRow
err := row.Scan( err := row.Scan(
&i.ZoneID, &i.ZoneID,
+9
View File
@@ -43,6 +43,14 @@ type ProviderAccount struct {
CreatedAt pgtype.Timestamptz `json:"created_at"` CreatedAt pgtype.Timestamptz `json:"created_at"`
} }
type Session struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
TokenHash string `json:"token_hash"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
}
type Template struct { type Template struct {
ID uuid.UUID `json:"id"` ID uuid.UUID `json:"id"`
ProjectID uuid.UUID `json:"project_id"` ProjectID uuid.UUID `json:"project_id"`
@@ -57,4 +65,5 @@ type User struct {
ID uuid.UUID `json:"id"` ID uuid.UUID `json:"id"`
Email string `json:"email"` Email string `json:"email"`
CreatedAt pgtype.Timestamptz `json:"created_at"` CreatedAt pgtype.Timestamptz `json:"created_at"`
PasswordHash *string `json:"password_hash"`
} }
+71
View File
@@ -0,0 +1,71 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.31.1
// source: projects.sql
package db
import (
"context"
"github.com/google/uuid"
)
const createProject = `-- name: CreateProject :one
INSERT INTO projects (id, user_id, name) VALUES ($1, $2, $3) RETURNING id, user_id, name, created_at
`
type CreateProjectParams struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
Name string `json:"name"`
}
func (q *Queries) CreateProject(ctx context.Context, arg CreateProjectParams) (Project, error) {
row := q.db.QueryRow(ctx, createProject, arg.ID, arg.UserID, arg.Name)
var i Project
err := row.Scan(
&i.ID,
&i.UserID,
&i.Name,
&i.CreatedAt,
)
return i, err
}
const getProjectOwned = `-- name: GetProjectOwned :one
SELECT id, user_id, name, created_at FROM projects WHERE id = $1 AND user_id = $2
`
type GetProjectOwnedParams struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
}
func (q *Queries) GetProjectOwned(ctx context.Context, arg GetProjectOwnedParams) (Project, error) {
row := q.db.QueryRow(ctx, getProjectOwned, arg.ID, arg.UserID)
var i Project
err := row.Scan(
&i.ID,
&i.UserID,
&i.Name,
&i.CreatedAt,
)
return i, err
}
const getUserProject = `-- name: GetUserProject :one
SELECT id, user_id, name, created_at FROM projects WHERE user_id = $1 ORDER BY created_at LIMIT 1
`
func (q *Queries) GetUserProject(ctx context.Context, userID uuid.UUID) (Project, error) {
row := q.db.QueryRow(ctx, getUserProject, userID)
var i Project
err := row.Scan(
&i.ID,
&i.UserID,
&i.Name,
&i.CreatedAt,
)
return i, err
}
+54
View File
@@ -0,0 +1,54 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.31.1
// source: sessions.sql
package db
import (
"context"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
)
const createSession = `-- name: CreateSession :exec
INSERT INTO sessions (id, user_id, token_hash, expires_at) VALUES ($1, $2, $3, $4)
`
type CreateSessionParams struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
TokenHash string `json:"token_hash"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
}
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) error {
_, err := q.db.Exec(ctx, createSession,
arg.ID,
arg.UserID,
arg.TokenHash,
arg.ExpiresAt,
)
return err
}
const deleteSession = `-- name: DeleteSession :exec
DELETE FROM sessions WHERE token_hash = $1
`
func (q *Queries) DeleteSession(ctx context.Context, tokenHash string) error {
_, err := q.db.Exec(ctx, deleteSession, tokenHash)
return err
}
const getSessionUser = `-- name: GetSessionUser :one
SELECT user_id FROM sessions WHERE token_hash = $1 AND expires_at > now()
`
func (q *Queries) GetSessionUser(ctx context.Context, tokenHash string) (uuid.UUID, error) {
row := q.db.QueryRow(ctx, getSessionUser, tokenHash)
var user_id uuid.UUID
err := row.Scan(&user_id)
return user_id, err
}
+66
View File
@@ -0,0 +1,66 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.31.1
// source: users.sql
package db
import (
"context"
"github.com/google/uuid"
)
const createUser = `-- name: CreateUser :one
INSERT INTO users (id, email, password_hash) VALUES ($1, $2, $3) RETURNING id, email, created_at, password_hash
`
type CreateUserParams struct {
ID uuid.UUID `json:"id"`
Email string `json:"email"`
PasswordHash *string `json:"password_hash"`
}
func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) {
row := q.db.QueryRow(ctx, createUser, arg.ID, arg.Email, arg.PasswordHash)
var i User
err := row.Scan(
&i.ID,
&i.Email,
&i.CreatedAt,
&i.PasswordHash,
)
return i, err
}
const getUserByEmail = `-- name: GetUserByEmail :one
SELECT id, email, created_at, password_hash FROM users WHERE email = $1
`
func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error) {
row := q.db.QueryRow(ctx, getUserByEmail, email)
var i User
err := row.Scan(
&i.ID,
&i.Email,
&i.CreatedAt,
&i.PasswordHash,
)
return i, err
}
const getUserByID = `-- name: GetUserByID :one
SELECT id, email, created_at, password_hash FROM users WHERE id = $1
`
func (q *Queries) GetUserByID(ctx context.Context, id uuid.UUID) (User, error) {
row := q.db.QueryRow(ctx, getUserByID, id)
var i User
err := row.Scan(
&i.ID,
&i.Email,
&i.CreatedAt,
&i.PasswordHash,
)
return i, err
}
+5 -3
View File
@@ -13,9 +13,11 @@ import (
) )
// LoadDomain joins domains+provider_accounts+templates to build the // LoadDomain joins domains+provider_accounts+templates to build the
// service.DomainRef needed to check/apply a domain's DNS records. // service.DomainRef needed to check/apply a domain's DNS records. Scoped by
func (s *Store) LoadDomain(ctx context.Context, domainID uuid.UUID) (service.DomainRef, error) { // projectID so a domain belonging to another tenant's project can never be
row, err := s.q.LoadDomainFull(ctx, domainID) // loaded, even if its domainID is guessed/leaked (closes IDOR).
func (s *Store) LoadDomain(ctx context.Context, projectID, domainID uuid.UUID) (service.DomainRef, error) {
row, err := s.q.LoadDomainFull(ctx, db.LoadDomainFullParams{ID: domainID, ProjectID: projectID})
if err != nil { if err != nil {
return service.DomainRef{}, err return service.DomainRef{}, err
} }
+2 -2
View File
@@ -40,7 +40,7 @@ func TestLoadDomainAndSaveCheckRun(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
ref, err := s.LoadDomain(ctx, domain.ID) ref, err := s.LoadDomain(ctx, defaultProject, domain.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -87,7 +87,7 @@ func TestLoadDomainNoTemplate(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if _, err := s.LoadDomain(ctx, domain.ID); err == nil { if _, err := s.LoadDomain(ctx, defaultProject, domain.ID); err == nil {
t.Fatal("expected error for domain without template, got nil") t.Fatal("expected error for domain without template, got nil")
} }
} }
+15
View File
@@ -0,0 +1,15 @@
-- +goose Up
ALTER TABLE users ADD COLUMN password_hash text;
CREATE TABLE sessions (
id uuid PRIMARY KEY,
user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE,
token_hash text NOT NULL UNIQUE,
expires_at timestamptz NOT NULL,
created_at timestamptz NOT NULL DEFAULT now()
);
CREATE INDEX sessions_token_hash_idx ON sessions (token_hash);
-- +goose Down
DROP TABLE sessions;
ALTER TABLE users DROP COLUMN password_hash;
+1 -1
View File
@@ -27,4 +27,4 @@ SELECT d.zone_id, a.provider, a.secret_enc, t.doc
FROM domains d FROM domains d
JOIN provider_accounts a ON a.id = d.provider_account_id JOIN provider_accounts a ON a.id = d.provider_account_id
LEFT JOIN templates t ON t.id = d.template_id LEFT JOIN templates t ON t.id = d.template_id
WHERE d.id = $1; WHERE d.id = $1 AND d.project_id = $2;
+8
View File
@@ -0,0 +1,8 @@
-- name: CreateProject :one
INSERT INTO projects (id, user_id, name) VALUES ($1, $2, $3) RETURNING *;
-- name: GetProjectOwned :one
SELECT * FROM projects WHERE id = $1 AND user_id = $2;
-- name: GetUserProject :one
SELECT * FROM projects WHERE user_id = $1 ORDER BY created_at LIMIT 1;
+8
View File
@@ -0,0 +1,8 @@
-- name: CreateSession :exec
INSERT INTO sessions (id, user_id, token_hash, expires_at) VALUES ($1, $2, $3, $4);
-- name: GetSessionUser :one
SELECT user_id FROM sessions WHERE token_hash = $1 AND expires_at > now();
-- name: DeleteSession :exec
DELETE FROM sessions WHERE token_hash = $1;
+8
View File
@@ -0,0 +1,8 @@
-- name: CreateUser :one
INSERT INTO users (id, email, password_hash) VALUES ($1, $2, $3) RETURNING *;
-- name: GetUserByEmail :one
SELECT * FROM users WHERE email = $1;
-- name: GetUserByID :one
SELECT * FROM users WHERE id = $1;
+2 -2
View File
@@ -214,7 +214,7 @@ func TestSetDomainTemplate_ClosesImportCheckLoop(t *testing.T) {
dom := doms[0] dom := doms[0]
// Before binding, the domain is not checkable. // Before binding, the domain is not checkable.
if _, err := s.LoadDomain(ctx, dom.ID); err == nil { if _, err := s.LoadDomain(ctx, defaultProject, dom.ID); err == nil {
t.Fatal("expected LoadDomain to fail before a template is bound") t.Fatal("expected LoadDomain to fail before a template is bound")
} }
@@ -234,7 +234,7 @@ func TestSetDomainTemplate_ClosesImportCheckLoop(t *testing.T) {
t.Fatalf("expected domain.TemplateID=%s, got %+v", tpl.ID, updated.TemplateID) t.Fatalf("expected domain.TemplateID=%s, got %+v", tpl.ID, updated.TemplateID)
} }
ref, err := s.LoadDomain(ctx, dom.ID) ref, err := s.LoadDomain(ctx, defaultProject, dom.ID)
if err != nil { if err != nil {
t.Fatalf("expected LoadDomain to succeed after binding template, got error: %v", err) t.Fatalf("expected LoadDomain to succeed after binding template, got error: %v", err)
} }
+147
View File
@@ -3,15 +3,23 @@ package store
import ( import (
"context" "context"
"errors" "errors"
"time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
"github.com/vasyakrg/dns-autoresolver/internal/provider" "github.com/vasyakrg/dns-autoresolver/internal/provider"
"github.com/vasyakrg/dns-autoresolver/internal/store/db" "github.com/vasyakrg/dns-autoresolver/internal/store/db"
"github.com/vasyakrg/dns-autoresolver/internal/store/dto" "github.com/vasyakrg/dns-autoresolver/internal/store/dto"
) )
// ErrEmailTaken is returned by RegisterUser when the email is already
// registered — a UNIQUE constraint violation (pgerrcode 23505) on
// users.email.
var ErrEmailTaken = errors.New("store: email already registered")
// Account/Template/Domain are provider-neutral domain structs returned by the // Account/Template/Domain are provider-neutral domain structs returned by the
// thin wrappers below, so callers (internal/api) never need to import // thin wrappers below, so callers (internal/api) never need to import
// internal/store/db directly. // internal/store/db directly.
@@ -222,3 +230,142 @@ func (s *Store) SetDomainTemplate(ctx context.Context, domainID, projectID uuid.
} }
return domainFromDB(d), nil return domainFromDB(d), nil
} }
// User and Project are provider-neutral domain structs for the auth/tenant
// layer (Фаза 2), mirroring the Account/Template/Domain wrappers above so
// callers never need to import internal/store/db directly.
type User struct {
ID uuid.UUID
Email string
PasswordHash string
}
type Project struct {
ID uuid.UUID
UserID uuid.UUID
Name string
}
// ptr is a small helper for passing a Go string into a nullable text column
// (password_hash) via sqlc's generated *string param type.
func ptr(s string) *string { return &s }
// strFromPtr converts a nullable text column back into a plain string; a
// nil password_hash never happens on the real registration flow (an argon2
// hash is always supplied), but is handled defensively here.
func strFromPtr(p *string) string {
if p == nil {
return ""
}
return *p
}
func toUser(u db.User) User {
return User{ID: u.ID, Email: u.Email, PasswordHash: strFromPtr(u.PasswordHash)}
}
func toProject(p db.Project) Project {
return Project{ID: p.ID, UserID: p.UserID, Name: p.Name}
}
func (s *Store) CreateUser(ctx context.Context, email, passwordHash string) (User, error) {
u, err := s.q.CreateUser(ctx, db.CreateUserParams{ID: uuid.New(), Email: email, PasswordHash: ptr(passwordHash)})
if err != nil {
return User{}, err
}
return toUser(u), nil
}
func (s *Store) GetUserByEmail(ctx context.Context, email string) (User, error) {
u, err := s.q.GetUserByEmail(ctx, email)
if err != nil {
return User{}, err
}
return toUser(u), nil
}
// GetUserByID looks up a user by primary key — used by handleMe (Task 3
// hardening) to return the authenticated caller's real email instead of
// leaving it blank.
func (s *Store) GetUserByID(ctx context.Context, id uuid.UUID) (User, error) {
u, err := s.q.GetUserByID(ctx, id)
if err != nil {
return User{}, err
}
return toUser(u), nil
}
func (s *Store) CreateProjectForUser(ctx context.Context, userID uuid.UUID, name string) (Project, error) {
p, err := s.q.CreateProject(ctx, db.CreateProjectParams{ID: uuid.New(), UserID: userID, Name: name})
if err != nil {
return Project{}, err
}
return toProject(p), nil
}
func (s *Store) GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (Project, error) {
p, err := s.q.GetProjectOwned(ctx, db.GetProjectOwnedParams{ID: projectID, UserID: userID})
if err != nil {
return Project{}, err
}
return toProject(p), nil
}
func (s *Store) GetUserProject(ctx context.Context, userID uuid.UUID) (Project, error) {
p, err := s.q.GetUserProject(ctx, userID)
if err != nil {
return Project{}, err
}
return toProject(p), nil
}
func (s *Store) CreateSession(ctx context.Context, userID uuid.UUID, tokenHash string, expiresAt time.Time) error {
return s.q.CreateSession(ctx, db.CreateSessionParams{
ID: uuid.New(),
UserID: userID,
TokenHash: tokenHash,
ExpiresAt: pgtype.Timestamptz{Time: expiresAt, Valid: true},
})
}
// GetSessionUser returns the owning user ID for a non-expired session token;
// expired sessions are excluded by the query itself (expires_at > now()).
func (s *Store) GetSessionUser(ctx context.Context, tokenHash string) (uuid.UUID, error) {
return s.q.GetSessionUser(ctx, tokenHash)
}
func (s *Store) DeleteSession(ctx context.Context, tokenHash string) error {
return s.q.DeleteSession(ctx, tokenHash)
}
// RegisterUser creates a user and their default project in one transaction,
// mirroring the ImportDomains pattern above: if project creation fails, the
// user insert is rolled back too, so a caller never observes a user without
// a default project.
func (s *Store) RegisterUser(ctx context.Context, email, passwordHash string) (User, Project, error) {
tx, err := s.pool.Begin(ctx)
if err != nil {
return User{}, Project{}, err
}
defer tx.Rollback(ctx) // no-op once Commit has succeeded
q := s.q.WithTx(tx)
uid := uuid.New()
dbu, err := q.CreateUser(ctx, db.CreateUserParams{ID: uid, Email: email, PasswordHash: ptr(passwordHash)})
if err != nil {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
return User{}, Project{}, ErrEmailTaken
}
return User{}, Project{}, err
}
dbp, err := q.CreateProject(ctx, db.CreateProjectParams{ID: uuid.New(), UserID: uid, Name: "default"})
if err != nil {
return User{}, Project{}, err
}
if err := tx.Commit(ctx); err != nil {
return User{}, Project{}, err
}
return toUser(dbu), toProject(dbp), nil
}
+13 -2
View File
@@ -1,19 +1,30 @@
import { render, screen, within } from "@testing-library/react" import { render, screen, within } from "@testing-library/react"
import { MemoryRouter } from "react-router-dom" import { MemoryRouter } from "react-router-dom"
import { QueryClient, QueryClientProvider } from "@tanstack/react-query" import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
import { vi, test, expect } from "vitest"
import { App } from "./App" import { App } from "./App"
import { AuthProvider } from "@/auth/AuthContext"
import { api } from "@/api/client"
test("renders navigation and redirects to domains", async () => {
vi.spyOn(api.auth, "me").mockResolvedValue({
user: { id: "u1", email: "a@b.com" },
project: { id: "p1", name: "Default" },
})
vi.spyOn(api, "listDomains").mockResolvedValue([])
test("renders navigation and redirects to domains", () => {
render( render(
<QueryClientProvider client={new QueryClient()}> <QueryClientProvider client={new QueryClient()}>
<AuthProvider>
<MemoryRouter initialEntries={["/"]}> <MemoryRouter initialEntries={["/"]}>
<App /> <App />
</MemoryRouter> </MemoryRouter>
</AuthProvider>
</QueryClientProvider>, </QueryClientProvider>,
) )
// Sidebar nav also renders a "Domains" link label, so scope the assertion // Sidebar nav also renders a "Domains" link label, so scope the assertion
// to the routed page content to unambiguously confirm the redirect + page. // to the routed page content to unambiguously confirm the redirect + page.
const main = screen.getByRole("main") const main = await screen.findByRole("main")
expect(within(main).getByText("Domains")).toBeInTheDocument() expect(within(main).getByText("Domains")).toBeInTheDocument()
expect(screen.getByRole("link", { name: /domains/i })).toBeInTheDocument() expect(screen.getByRole("link", { name: /domains/i })).toBeInTheDocument()
}) })
+20 -6
View File
@@ -1,20 +1,34 @@
import type { ReactNode } from "react"
import { Routes, Route, Navigate } from "react-router-dom" import { Routes, Route, Navigate } from "react-router-dom"
import { ProtectedRoute } from "@/auth/ProtectedRoute"
import { Layout } from "@/components/Layout" import { Layout } from "@/components/Layout"
import { AccountsPage } from "@/pages/AccountsPage" import { AccountsPage } from "@/pages/AccountsPage"
import { DomainDiffPage } from "@/pages/DomainDiffPage" import { DomainDiffPage } from "@/pages/DomainDiffPage"
import { DomainsPage } from "@/pages/DomainsPage" import { DomainsPage } from "@/pages/DomainsPage"
import { LoginPage } from "@/pages/LoginPage"
import { RegisterPage } from "@/pages/RegisterPage"
import { TemplatesPage } from "@/pages/TemplatesPage" import { TemplatesPage } from "@/pages/TemplatesPage"
// Every non-auth route shares the same guard + chrome; wrapping here keeps
// each <Route> below a one-liner instead of repeating both on every page.
function Protected({ children }: { children: ReactNode }) {
return (
<ProtectedRoute>
<Layout>{children}</Layout>
</ProtectedRoute>
)
}
export function App() { export function App() {
return ( return (
<Layout>
<Routes> <Routes>
<Route path="/login" element={<LoginPage />} />
<Route path="/register" element={<RegisterPage />} />
<Route path="/" element={<Navigate to="/domains" replace />} /> <Route path="/" element={<Navigate to="/domains" replace />} />
<Route path="/domains" element={<DomainsPage />} /> <Route path="/domains" element={<Protected><DomainsPage /></Protected>} />
<Route path="/domains/:id" element={<DomainDiffPage />} /> <Route path="/domains/:id" element={<Protected><DomainDiffPage /></Protected>} />
<Route path="/accounts" element={<AccountsPage />} /> <Route path="/accounts" element={<Protected><AccountsPage /></Protected>} />
<Route path="/templates" element={<TemplatesPage />} /> <Route path="/templates" element={<Protected><TemplatesPage /></Protected>} />
</Routes> </Routes>
</Layout>
) )
} }
+95 -10
View File
@@ -1,6 +1,7 @@
import { describe, it, expect, vi, beforeEach } from "vitest" import { describe, it, expect, vi, beforeEach } from "vitest"
import { api } from "./client" import { api, UnauthorizedError } from "./client"
import { DEFAULT_PROJECT_ID } from "@/lib/config"
const PROJECT_ID = "11111111-1111-1111-1111-111111111111"
beforeEach(() => { vi.restoreAllMocks() }) beforeEach(() => { vi.restoreAllMocks() })
@@ -13,19 +14,72 @@ function mockFetch(body: unknown, ok = true, status = 200) {
} }
describe("api client", () => { describe("api client", () => {
it("lists accounts at project-scoped path", async () => { it("sends credentials:include on every request", async () => {
const spy = mockFetch([])
await api.listAccounts(PROJECT_ID)
const [, opts] = spy.mock.calls[0]
expect((opts as RequestInit).credentials).toBe("include")
})
describe("api.auth", () => {
it("login POSTs to /api/v1/auth/login with credentials:include", async () => {
const spy = mockFetch({ user: { id: "u1", email: "a@b.com" }, project: { id: "p1", name: "Default" } })
await api.auth.login("a@b.com", "secret")
const [url, opts] = spy.mock.calls[0]
expect(url).toBe("/api/v1/auth/login")
expect((opts as RequestInit).method).toBe("POST")
expect((opts as RequestInit).credentials).toBe("include")
expect(String((opts as RequestInit).body)).toContain("secret")
})
it("register POSTs to /api/v1/auth/register", async () => {
const spy = mockFetch({ user: { id: "u1", email: "a@b.com" }, project: { id: "p1", name: "Default" } })
await api.auth.register("a@b.com", "secret")
const [url, opts] = spy.mock.calls[0]
expect(url).toBe("/api/v1/auth/register")
expect((opts as RequestInit).method).toBe("POST")
})
it("logout POSTs to /api/v1/auth/logout", async () => {
const spy = mockFetch(undefined, true, 204)
await api.auth.logout()
const [url, opts] = spy.mock.calls[0]
expect(url).toBe("/api/v1/auth/logout")
expect((opts as RequestInit).method).toBe("POST")
})
it("me GETs /api/v1/auth/me and returns AuthState", async () => {
const state = { user: { id: "u1", email: "a@b.com" }, project: { id: "p1", name: "Default" } }
const spy = mockFetch(state)
const result = await api.auth.me()
const [url] = spy.mock.calls[0]
expect(url).toBe("/api/v1/auth/me")
expect(result).toEqual(state)
})
})
it("resource methods hit project-scoped path with projectId first", async () => {
const spy = mockFetch([{ id: "a1", provider: "selectel", comment: "" }]) const spy = mockFetch([{ id: "a1", provider: "selectel", comment: "" }])
const accounts = await api.listAccounts() const accounts = await api.listAccounts(PROJECT_ID)
expect(accounts).toHaveLength(1) expect(accounts).toHaveLength(1)
expect(spy).toHaveBeenCalledWith( expect(spy).toHaveBeenCalledWith(
`/api/v1/projects/${DEFAULT_PROJECT_ID}/accounts`, `/api/v1/projects/${PROJECT_ID}/accounts`,
expect.objectContaining({ method: "GET" }),
)
})
it("listDomains(projectId) hits /api/v1/projects/{projectId}/domains", async () => {
const spy = mockFetch([])
await api.listDomains(PROJECT_ID)
expect(spy).toHaveBeenCalledWith(
`/api/v1/projects/${PROJECT_ID}/domains`,
expect.objectContaining({ method: "GET" }), expect.objectContaining({ method: "GET" }),
) )
}) })
it("sends secret on account creation but path has no secret leakage in response typing", async () => { it("sends secret on account creation but path has no secret leakage in response typing", async () => {
const spy = mockFetch({ id: "a2", provider: "selectel", comment: "prod" }) const spy = mockFetch({ id: "a2", provider: "selectel", comment: "prod" })
await api.createAccount({ provider: "selectel", secret: "TOKEN", comment: "prod" }) await api.createAccount(PROJECT_ID, { provider: "selectel", secret: "TOKEN", comment: "prod" })
const [, opts] = spy.mock.calls[0] const [, opts] = spy.mock.calls[0]
expect((opts as RequestInit).method).toBe("POST") expect((opts as RequestInit).method).toBe("POST")
expect(String((opts as RequestInit).body)).toContain("TOKEN") expect(String((opts as RequestInit).body)).toContain("TOKEN")
@@ -33,14 +87,45 @@ describe("api client", () => {
it("throws on non-ok response", async () => { it("throws on non-ok response", async () => {
mockFetch({ error: "boom" }, false, 500) mockFetch({ error: "boom" }, false, 500)
await expect(api.listDomains()).rejects.toThrow() await expect(api.listDomains(PROJECT_ID)).rejects.toThrow()
}) })
it("applies with prune flag", async () => { it("throws UnauthorizedError on 401", async () => {
mockFetch({ error: "unauthorized" }, false, 401)
await expect(api.listDomains(PROJECT_ID)).rejects.toThrow(UnauthorizedError)
})
it("applies with prune flag using projectId, id, body order", async () => {
const spy = mockFetch({ updates: [], prunes: [], readOnly: [], inSyncCount: 0 }) const spy = mockFetch({ updates: [], prunes: [], readOnly: [], inSyncCount: 0 })
await api.applyDomain("d1", { applyUpdates: true, applyPrunes: true }) await api.applyDomain(PROJECT_ID, "d1", { applyUpdates: true, applyPrunes: true })
const [url, opts] = spy.mock.calls[0] const [url, opts] = spy.mock.calls[0]
expect(url).toContain("/domains/d1/apply") expect(url).toBe(`/api/v1/projects/${PROJECT_ID}/domains/d1/apply`)
expect(String((opts as RequestInit).body)).toContain("applyPrunes") expect(String((opts as RequestInit).body)).toContain("applyPrunes")
}) })
it("checkDomain(projectId, id) hits project-scoped check path", async () => {
const spy = mockFetch({ updates: [], prunes: [], readOnly: [], inSyncCount: 0 })
await api.checkDomain(PROJECT_ID, "d1")
expect(spy).toHaveBeenCalledWith(
`/api/v1/projects/${PROJECT_ID}/domains/d1/check`,
expect.objectContaining({ method: "GET" }),
)
})
it("importZones(projectId, accountId) hits project-scoped import path", async () => {
const spy = mockFetch([])
await api.importZones(PROJECT_ID, "acc1")
expect(spy).toHaveBeenCalledWith(
`/api/v1/projects/${PROJECT_ID}/accounts/acc1/import`,
expect.objectContaining({ method: "POST" }),
)
})
it("setDomainTemplate(projectId, id, templateId) hits project-scoped domain path", async () => {
const spy = mockFetch({ id: "d1", providerAccountId: "acc1", zoneName: "x.", zoneId: "z1" })
await api.setDomainTemplate(PROJECT_ID, "d1", "t1")
const [url, opts] = spy.mock.calls[0]
expect(url).toBe(`/api/v1/projects/${PROJECT_ID}/domains/d1`)
expect((opts as RequestInit).method).toBe("PATCH")
})
}) })
+60 -27
View File
@@ -1,15 +1,25 @@
import { API_BASE } from "@/lib/config" import { API_ROOT } from "@/lib/config"
import type { import type {
AuthState,
Account, CreateAccountInput, Template, CreateTemplateInput, Account, CreateAccountInput, Template, CreateTemplateInput,
Domain, CreateDomainInput, ChangesetResponse, ApplyRequest, Domain, CreateDomainInput, ChangesetResponse, ApplyRequest,
} from "./types" } from "./types"
export class UnauthorizedError extends Error {
constructor() {
super("Unauthorized")
this.name = "UnauthorizedError"
}
}
async function req<T>(path: string, init?: RequestInit): Promise<T> { async function req<T>(path: string, init?: RequestInit): Promise<T> {
const res = await fetch(`${API_BASE}${path}`, { const res = await fetch(path, {
headers: { "Content-Type": "application/json" }, headers: { "Content-Type": "application/json" },
method: "GET", method: "GET",
credentials: "include",
...init, ...init,
}) })
if (res.status === 401) throw new UnauthorizedError()
if (!res.ok) { if (!res.ok) {
let msg = `HTTP ${res.status}` let msg = `HTTP ${res.status}`
try { const b = await res.json(); if (b?.error) msg = String(b.error) } catch { /* ignore */ } try { const b = await res.json(); if (b?.error) msg = String(b.error) } catch { /* ignore */ }
@@ -19,29 +29,52 @@ async function req<T>(path: string, init?: RequestInit): Promise<T> {
return (await res.json()) as T return (await res.json()) as T
} }
export const api = { function projectPath(projectId: string, path: string): string {
listAccounts: () => req<Account[]>("/accounts"), return `${API_ROOT}/projects/${projectId}${path}`
createAccount: (input: CreateAccountInput) => }
req<Account>("/accounts", { method: "POST", body: JSON.stringify(input) }),
deleteAccount: (id: string) => req<void>(`/accounts/${id}`, { method: "DELETE" }), export const api = {
auth: {
listTemplates: () => req<Template[]>("/templates"), register: (email: string, password: string) =>
createTemplate: (input: CreateTemplateInput) => req<AuthState>(`${API_ROOT}/auth/register`, {
req<Template>("/templates", { method: "POST", body: JSON.stringify(input) }), method: "POST",
updateTemplate: (id: string, input: CreateTemplateInput) => body: JSON.stringify({ email, password }),
req<Template>(`/templates/${id}`, { method: "PUT", body: JSON.stringify(input) }), }),
deleteTemplate: (id: string) => req<void>(`/templates/${id}`, { method: "DELETE" }), login: (email: string, password: string) =>
req<AuthState>(`${API_ROOT}/auth/login`, {
listDomains: () => req<Domain[]>("/domains"), method: "POST",
createDomain: (input: CreateDomainInput) => body: JSON.stringify({ email, password }),
req<Domain>("/domains", { method: "POST", body: JSON.stringify(input) }), }),
deleteDomain: (id: string) => req<void>(`/domains/${id}`, { method: "DELETE" }), logout: () => req<void>(`${API_ROOT}/auth/logout`, { method: "POST" }),
importZones: (accountId: string) => me: () => req<AuthState>(`${API_ROOT}/auth/me`),
req<Domain[]>(`/accounts/${accountId}/import`, { method: "POST" }), },
setDomainTemplate: (id: string, templateId: string | null) =>
req<Domain>(`/domains/${id}`, { method: "PATCH", body: JSON.stringify({ templateId }) }), listAccounts: (projectId: string) => req<Account[]>(projectPath(projectId, "/accounts")),
createAccount: (projectId: string, input: CreateAccountInput) =>
checkDomain: (id: string) => req<ChangesetResponse>(`/domains/${id}/check`), req<Account>(projectPath(projectId, "/accounts"), { method: "POST", body: JSON.stringify(input) }),
applyDomain: (id: string, body: ApplyRequest) => deleteAccount: (projectId: string, id: string) =>
req<ChangesetResponse>(`/domains/${id}/apply`, { method: "POST", body: JSON.stringify(body) }), req<void>(projectPath(projectId, `/accounts/${id}`), { method: "DELETE" }),
listTemplates: (projectId: string) => req<Template[]>(projectPath(projectId, "/templates")),
createTemplate: (projectId: string, input: CreateTemplateInput) =>
req<Template>(projectPath(projectId, "/templates"), { method: "POST", body: JSON.stringify(input) }),
updateTemplate: (projectId: string, id: string, input: CreateTemplateInput) =>
req<Template>(projectPath(projectId, `/templates/${id}`), { method: "PUT", body: JSON.stringify(input) }),
deleteTemplate: (projectId: string, id: string) =>
req<void>(projectPath(projectId, `/templates/${id}`), { method: "DELETE" }),
listDomains: (projectId: string) => req<Domain[]>(projectPath(projectId, "/domains")),
createDomain: (projectId: string, input: CreateDomainInput) =>
req<Domain>(projectPath(projectId, "/domains"), { method: "POST", body: JSON.stringify(input) }),
deleteDomain: (projectId: string, id: string) =>
req<void>(projectPath(projectId, `/domains/${id}`), { method: "DELETE" }),
importZones: (projectId: string, accountId: string) =>
req<Domain[]>(projectPath(projectId, `/accounts/${accountId}/import`), { method: "POST" }),
setDomainTemplate: (projectId: string, id: string, templateId: string | null) =>
req<Domain>(projectPath(projectId, `/domains/${id}`), { method: "PATCH", body: JSON.stringify({ templateId }) }),
checkDomain: (projectId: string, id: string) =>
req<ChangesetResponse>(projectPath(projectId, `/domains/${id}/check`)),
applyDomain: (projectId: string, id: string, body: ApplyRequest) =>
req<ChangesetResponse>(projectPath(projectId, `/domains/${id}/apply`), { method: "POST", body: JSON.stringify(body) }),
} }
+4
View File
@@ -1,3 +1,7 @@
export interface User { id: string; email: string }
export interface Project { id: string; name: string }
export interface AuthState { user: User; project: Project }
export interface Account { id: string; provider: string; comment: string } export interface Account { id: string; provider: string; comment: string }
export interface CreateAccountInput { provider: string; secret: string; comment: string } export interface CreateAccountInput { provider: string; secret: string; comment: string }
+126
View File
@@ -0,0 +1,126 @@
import { render, screen, waitFor } from "@testing-library/react"
import userEvent from "@testing-library/user-event"
import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
import { describe, it, expect, vi, beforeEach } from "vitest"
import { AuthProvider, notifyUnauthorized, useAuth } from "./AuthContext"
import { api, UnauthorizedError } from "@/api/client"
const USER = { id: "u1", email: "a@b.com" }
const PROJECT = { id: "p1", name: "Default" }
function Probe() {
const { user, project, loading, login, register, logout } = useAuth()
return (
<div>
<span data-testid="loading">{String(loading)}</span>
<span data-testid="user">{user ? user.email : "none"}</span>
<span data-testid="project">{project ? project.name : "none"}</span>
<button onClick={() => login("a@b.com", "secret")}>login</button>
<button onClick={() => register("a@b.com", "secret")}>register</button>
<button onClick={() => logout()}>logout</button>
</div>
)
}
function renderProbe(qc: QueryClient = new QueryClient()) {
return render(
<QueryClientProvider client={qc}>
<AuthProvider>
<Probe />
</AuthProvider>
</QueryClientProvider>,
)
}
beforeEach(() => {
vi.restoreAllMocks()
})
describe("AuthContext", () => {
it("populates user/project from api.auth.me() on mount", async () => {
vi.spyOn(api.auth, "me").mockResolvedValue({ user: USER, project: PROJECT })
renderProbe()
expect(screen.getByTestId("loading").textContent).toBe("true")
await waitFor(() => expect(screen.getByTestId("loading").textContent).toBe("false"))
expect(screen.getByTestId("user").textContent).toBe(USER.email)
expect(screen.getByTestId("project").textContent).toBe(PROJECT.name)
})
it("treats 401 from api.auth.me() as unauthenticated, not an error", async () => {
vi.spyOn(api.auth, "me").mockRejectedValue(new UnauthorizedError())
renderProbe()
await waitFor(() => expect(screen.getByTestId("loading").textContent).toBe("false"))
expect(screen.getByTestId("user").textContent).toBe("none")
expect(screen.getByTestId("project").textContent).toBe("none")
})
it("login sets user/project in context", async () => {
vi.spyOn(api.auth, "me").mockRejectedValue(new UnauthorizedError())
vi.spyOn(api.auth, "login").mockResolvedValue({ user: USER, project: PROJECT })
const user = userEvent.setup()
renderProbe()
await waitFor(() => expect(screen.getByTestId("loading").textContent).toBe("false"))
await user.click(screen.getByRole("button", { name: "login" }))
await waitFor(() => expect(screen.getByTestId("user").textContent).toBe(USER.email))
expect(screen.getByTestId("project").textContent).toBe(PROJECT.name)
})
it("treats a non-401 error from api.auth.me() as logged-out but logs it for diagnostics", async () => {
const consoleErrorSpy = vi.spyOn(console, "error").mockImplementation(() => {})
const err = new Error("network down")
vi.spyOn(api.auth, "me").mockRejectedValue(err)
renderProbe()
await waitFor(() => expect(screen.getByTestId("loading").textContent).toBe("false"))
expect(screen.getByTestId("user").textContent).toBe("none")
expect(screen.getByTestId("project").textContent).toBe("none")
expect(consoleErrorSpy).toHaveBeenCalledWith(err)
})
it("logout clears user/project from context", async () => {
vi.spyOn(api.auth, "me").mockResolvedValue({ user: USER, project: PROJECT })
vi.spyOn(api.auth, "logout").mockResolvedValue(undefined)
const user = userEvent.setup()
renderProbe()
await waitFor(() => expect(screen.getByTestId("user").textContent).toBe(USER.email))
await user.click(screen.getByRole("button", { name: "logout" }))
await waitFor(() => expect(screen.getByTestId("user").textContent).toBe("none"))
expect(screen.getByTestId("project").textContent).toBe("none")
})
it("logout clears the react-query cache", async () => {
vi.spyOn(api.auth, "me").mockResolvedValue({ user: USER, project: PROJECT })
vi.spyOn(api.auth, "logout").mockResolvedValue(undefined)
const qc = new QueryClient()
const clearSpy = vi.spyOn(qc, "clear")
const user = userEvent.setup()
renderProbe(qc)
await waitFor(() => expect(screen.getByTestId("user").textContent).toBe(USER.email))
await user.click(screen.getByRole("button", { name: "logout" }))
await waitFor(() => expect(clearSpy).toHaveBeenCalled())
})
it("notifyUnauthorized (triggered by any 401 elsewhere in the app) drops the session and clears the cache", async () => {
vi.spyOn(api.auth, "me").mockResolvedValue({ user: USER, project: PROJECT })
const qc = new QueryClient()
const clearSpy = vi.spyOn(qc, "clear")
renderProbe(qc)
await waitFor(() => expect(screen.getByTestId("user").textContent).toBe(USER.email))
notifyUnauthorized()
await waitFor(() => expect(screen.getByTestId("user").textContent).toBe("none"))
expect(screen.getByTestId("project").textContent).toBe("none")
expect(clearSpy).toHaveBeenCalled()
})
})
+113
View File
@@ -0,0 +1,113 @@
import { createContext, useCallback, useContext, useEffect, useState, type ReactNode } from "react"
import { useQueryClient } from "@tanstack/react-query"
import { api, UnauthorizedError } from "@/api/client"
import type { User, Project } from "@/api/types"
export interface AuthContextValue {
user: User | null
project: Project | null
loading: boolean
login: (email: string, password: string) => Promise<void>
register: (email: string, password: string) => Promise<void>
logout: () => Promise<void>
}
const AuthContext = createContext<AuthContextValue | undefined>(undefined)
// AuthProvider registers a handler here so code outside the React tree (the
// QueryClient's QueryCache/MutationCache onError, wired up in main.tsx) can
// report a 401 from *any* query/mutation and have AuthContext drop the
// session — the same "unauthenticated" state ProtectedRoute already reacts
// to. There is exactly one AuthProvider in the app, so a single module-level
// slot is enough; it's registered/unregistered via useEffect below.
type UnauthorizedHandler = () => void
let unauthorizedHandler: UnauthorizedHandler | null = null
export function registerUnauthorizedHandler(handler: UnauthorizedHandler | null) {
unauthorizedHandler = handler
}
export function notifyUnauthorized() {
unauthorizedHandler?.()
}
export function AuthProvider({ children }: { children: ReactNode }) {
const [user, setUser] = useState<User | null>(null)
const [project, setProject] = useState<Project | null>(null)
const [loading, setLoading] = useState(true)
const qc = useQueryClient()
useEffect(() => {
let cancelled = false
api.auth
.me()
.then((state) => {
if (cancelled) return
setUser(state.user)
setProject(state.project)
})
.catch((err) => {
// Unauthenticated (401) — normal logged-out state, no need to log.
// Any other failure (network/500/etc) — still treat as logged-out so
// we don't get stuck in loading, but surface it for diagnostics
// instead of swallowing it silently. Redirect handling is out of
// scope here (Task 6).
if (!(err instanceof UnauthorizedError)) {
console.error(err)
}
if (cancelled) return
setUser(null)
setProject(null)
})
.finally(() => {
if (!cancelled) setLoading(false)
})
return () => {
cancelled = true
}
}, [])
const login = useCallback(async (email: string, password: string) => {
const state = await api.auth.login(email, password)
setUser(state.user)
setProject(state.project)
}, [])
const register = useCallback(async (email: string, password: string) => {
const state = await api.auth.register(email, password)
setUser(state.user)
setProject(state.project)
}, [])
const logout = useCallback(async () => {
await api.auth.logout()
setUser(null)
setProject(null)
qc.clear()
}, [qc])
// Any query/mutation elsewhere in the app that hits a 401 reports it here
// (see notifyUnauthorized/registerUnauthorizedHandler above) — drop the
// session the same way logout() would, so ProtectedRoute redirects to
// /login instead of the UI silently sitting on stale, now-invalid data.
useEffect(() => {
registerUnauthorizedHandler(() => {
setUser(null)
setProject(null)
qc.clear()
})
return () => registerUnauthorizedHandler(null)
}, [qc])
return (
<AuthContext.Provider value={{ user, project, loading, login, register, logout }}>
{children}
</AuthContext.Provider>
)
}
export function useAuth(): AuthContextValue {
const ctx = useContext(AuthContext)
if (!ctx) throw new Error("useAuth must be used within an AuthProvider")
return ctx
}
+57
View File
@@ -0,0 +1,57 @@
import { render, screen } from "@testing-library/react"
import { MemoryRouter, Routes, Route } from "react-router-dom"
import { describe, it, expect, vi } from "vitest"
import { ProtectedRoute } from "./ProtectedRoute"
import * as AuthContextModule from "./AuthContext"
function renderWithAuth(authValue: Partial<AuthContextModule.AuthContextValue>) {
vi.spyOn(AuthContextModule, "useAuth").mockReturnValue({
user: null,
project: null,
loading: false,
login: vi.fn(),
register: vi.fn(),
logout: vi.fn(),
...authValue,
})
return render(
<MemoryRouter initialEntries={["/domains"]}>
<Routes>
<Route path="/login" element={<div>login page</div>} />
<Route
path="/domains"
element={
<ProtectedRoute>
<div>protected content</div>
</ProtectedRoute>
}
/>
</Routes>
</MemoryRouter>,
)
}
describe("ProtectedRoute", () => {
it("показывает спиннер, пока идёт проверка сессии", () => {
renderWithAuth({ user: null, loading: true })
expect(screen.queryByText("protected content")).not.toBeInTheDocument()
expect(screen.queryByText("login page")).not.toBeInTheDocument()
expect(screen.getByRole("status")).toBeInTheDocument()
})
it("редиректит на /login, когда пользователь не авторизован", () => {
renderWithAuth({ user: null, loading: false })
expect(screen.getByText("login page")).toBeInTheDocument()
expect(screen.queryByText("protected content")).not.toBeInTheDocument()
})
it("рендерит children, когда пользователь авторизован", () => {
renderWithAuth({ user: { id: "u1", email: "a@b.com" }, loading: false })
expect(screen.getByText("protected content")).toBeInTheDocument()
expect(screen.queryByText("login page")).not.toBeInTheDocument()
})
})
+24
View File
@@ -0,0 +1,24 @@
import type { ReactNode } from "react"
import { Navigate } from "react-router-dom"
import { Loader2 } from "lucide-react"
import { useAuth } from "@/auth/AuthContext"
export function ProtectedRoute({ children }: { children: ReactNode }) {
const { user, loading } = useAuth()
if (loading) {
return (
<div
role="status"
aria-label="Проверка сессии"
className="flex h-screen w-full items-center justify-center bg-background"
>
<Loader2 className="size-6 animate-spin text-muted-foreground" strokeWidth={1.75} />
</div>
)
}
if (!user) return <Navigate to="/login" replace />
return <>{children}</>
}
+21 -3
View File
@@ -1,6 +1,8 @@
import type { ReactNode } from "react" import type { ReactNode } from "react"
import { NavLink, useLocation } from "react-router-dom" import { NavLink, useLocation, useNavigate } from "react-router-dom"
import { Globe, Users, LayoutTemplate, SquareTerminal } from "lucide-react" import { Globe, LogOut, Users, LayoutTemplate, SquareTerminal } from "lucide-react"
import { useAuth } from "@/auth/AuthContext"
import { Button } from "@/components/ui/button"
import { cn } from "@/lib/utils" import { cn } from "@/lib/utils"
const NAV = [ const NAV = [
@@ -11,6 +13,13 @@ const NAV = [
export function Layout({ children }: { children: ReactNode }) { export function Layout({ children }: { children: ReactNode }) {
const location = useLocation() const location = useLocation()
const navigate = useNavigate()
const { user, logout } = useAuth()
async function onLogout() {
await logout()
navigate("/login", { replace: true })
}
return ( return (
<div className="flex h-screen w-full overflow-hidden bg-background text-foreground"> <div className="flex h-screen w-full overflow-hidden bg-background text-foreground">
@@ -64,10 +73,19 @@ export function Layout({ children }: { children: ReactNode }) {
</aside> </aside>
<div className="flex flex-1 flex-col overflow-hidden"> <div className="flex flex-1 flex-col overflow-hidden">
<header className="flex h-11 shrink-0 items-center border-b border-border px-6"> <header className="flex h-11 shrink-0 items-center justify-between border-b border-border px-6">
<span className="font-dns text-xs text-muted-foreground"> <span className="font-dns text-xs text-muted-foreground">
{location.pathname} {location.pathname}
</span> </span>
{user && (
<div className="flex items-center gap-3">
<span className="font-dns text-xs text-muted-foreground">{user.email}</span>
<Button variant="ghost" size="sm" onClick={onLogout}>
<LogOut className="size-3.5" strokeWidth={1.75} />
Выйти
</Button>
</div>
)}
</header> </header>
<main className="flex-1 overflow-auto">{children}</main> <main className="flex-1 overflow-auto">{children}</main>
</div> </div>
+39
View File
@@ -0,0 +1,39 @@
import { renderHook, waitFor } from "@testing-library/react"
import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
import { describe, it, expect, vi, beforeEach } from "vitest"
import type { ReactNode } from "react"
import { AuthProvider } from "@/auth/AuthContext"
import { api, UnauthorizedError } from "@/api/client"
import { useDeleteAccount } from "./useApi"
beforeEach(() => {
vi.restoreAllMocks()
})
function wrapper({ children }: { children: ReactNode }) {
const qc = new QueryClient({ defaultOptions: { mutations: { retry: false } } })
return (
<QueryClientProvider client={qc}>
<AuthProvider>{children}</AuthProvider>
</QueryClientProvider>
)
}
describe("useApi mutations — null project guard", () => {
it("mutate() without an active project fails with a clear error, not a TypeError", async () => {
// No session yet => AuthContext resolves project to null.
vi.spyOn(api.auth, "me").mockRejectedValue(new UnauthorizedError())
vi.spyOn(api, "deleteAccount")
const { result } = renderHook(() => useDeleteAccount(), { wrapper })
result.current.mutate("acc-1")
await waitFor(() => expect(result.current.isError).toBe(true))
expect(result.current.error).toBeInstanceOf(Error)
expect(result.current.error).not.toBeInstanceOf(TypeError)
expect((result.current.error as Error).message).toBe("no active project")
expect(api.deleteAccount).not.toHaveBeenCalled()
})
})
+89 -27
View File
@@ -1,81 +1,143 @@
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query" import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query"
import { api } from "@/api/client" import { api } from "@/api/client"
import type { CreateAccountInput, CreateTemplateInput, ApplyRequest } from "@/api/types" import { useAuth } from "@/auth/AuthContext"
import type { CreateAccountInput, CreateTemplateInput, ApplyRequest, Project } from "@/api/types"
function requireProjectId(project: Project | null): string {
if (!project) throw new Error("no active project")
return project.id
}
export function useAccounts() { export function useAccounts() {
return useQuery({ queryKey: ["accounts"], queryFn: api.listAccounts }) const { project } = useAuth()
return useQuery({
queryKey: ["accounts", project?.id],
queryFn: () => api.listAccounts(project!.id),
enabled: !!project,
})
} }
export function useCreateAccount() { export function useCreateAccount() {
const { project } = useAuth()
const qc = useQueryClient() const qc = useQueryClient()
return useMutation({ return useMutation({
mutationFn: (input: CreateAccountInput) => api.createAccount(input), mutationFn: (input: CreateAccountInput) => {
onSuccess: () => qc.invalidateQueries({ queryKey: ["accounts"] }), const pid = requireProjectId(project)
return api.createAccount(pid, input)
},
onSuccess: () => qc.invalidateQueries({ queryKey: ["accounts", project?.id] }),
}) })
} }
export function useDeleteAccount() { export function useDeleteAccount() {
const { project } = useAuth()
const qc = useQueryClient() const qc = useQueryClient()
return useMutation({ return useMutation({
mutationFn: (id: string) => api.deleteAccount(id), mutationFn: (id: string) => {
onSuccess: () => qc.invalidateQueries({ queryKey: ["accounts"] }), const pid = requireProjectId(project)
return api.deleteAccount(pid, id)
},
onSuccess: () => qc.invalidateQueries({ queryKey: ["accounts", project?.id] }),
}) })
} }
export function useTemplates() { export function useTemplates() {
return useQuery({ queryKey: ["templates"], queryFn: api.listTemplates }) const { project } = useAuth()
return useQuery({
queryKey: ["templates", project?.id],
queryFn: () => api.listTemplates(project!.id),
enabled: !!project,
})
} }
export function useCreateTemplate() { export function useCreateTemplate() {
const { project } = useAuth()
const qc = useQueryClient() const qc = useQueryClient()
return useMutation({ return useMutation({
mutationFn: (input: CreateTemplateInput) => api.createTemplate(input), mutationFn: (input: CreateTemplateInput) => {
onSuccess: () => qc.invalidateQueries({ queryKey: ["templates"] }), const pid = requireProjectId(project)
return api.createTemplate(pid, input)
},
onSuccess: () => qc.invalidateQueries({ queryKey: ["templates", project?.id] }),
}) })
} }
export function useUpdateTemplate() { export function useUpdateTemplate() {
const { project } = useAuth()
const qc = useQueryClient() const qc = useQueryClient()
return useMutation({ return useMutation({
mutationFn: ({ id, input }: { id: string; input: CreateTemplateInput }) => api.updateTemplate(id, input), mutationFn: ({ id, input }: { id: string; input: CreateTemplateInput }) => {
onSuccess: () => qc.invalidateQueries({ queryKey: ["templates"] }), const pid = requireProjectId(project)
return api.updateTemplate(pid, id, input)
},
onSuccess: () => qc.invalidateQueries({ queryKey: ["templates", project?.id] }),
}) })
} }
export function useDeleteTemplate() { export function useDeleteTemplate() {
const { project } = useAuth()
const qc = useQueryClient() const qc = useQueryClient()
return useMutation({ return useMutation({
mutationFn: (id: string) => api.deleteTemplate(id), mutationFn: (id: string) => {
onSuccess: () => qc.invalidateQueries({ queryKey: ["templates"] }), const pid = requireProjectId(project)
return api.deleteTemplate(pid, id)
},
onSuccess: () => qc.invalidateQueries({ queryKey: ["templates", project?.id] }),
}) })
} }
export function useDomains() { export function useDomains() {
return useQuery({ queryKey: ["domains"], queryFn: api.listDomains }) const { project } = useAuth()
return useQuery({
queryKey: ["domains", project?.id],
queryFn: () => api.listDomains(project!.id),
enabled: !!project,
})
} }
export function useImportZones() { export function useImportZones() {
const { project } = useAuth()
const qc = useQueryClient() const qc = useQueryClient()
return useMutation({ return useMutation({
mutationFn: (accountId: string) => api.importZones(accountId), mutationFn: (accountId: string) => {
onSuccess: () => qc.invalidateQueries({ queryKey: ["domains"] }), const pid = requireProjectId(project)
return api.importZones(pid, accountId)
},
onSuccess: () => qc.invalidateQueries({ queryKey: ["domains", project?.id] }),
}) })
} }
export function useSetDomainTemplate() { export function useSetDomainTemplate() {
const { project } = useAuth()
const qc = useQueryClient() const qc = useQueryClient()
return useMutation({ return useMutation({
mutationFn: ({ id, templateId }: { id: string; templateId: string | null }) => api.setDomainTemplate(id, templateId), mutationFn: ({ id, templateId }: { id: string; templateId: string | null }) => {
onSuccess: () => qc.invalidateQueries({ queryKey: ["domains"] }), const pid = requireProjectId(project)
return api.setDomainTemplate(pid, id, templateId)
},
onSuccess: () => qc.invalidateQueries({ queryKey: ["domains", project?.id] }),
}) })
} }
export function useDeleteDomain() { export function useDeleteDomain() {
const { project } = useAuth()
const qc = useQueryClient() const qc = useQueryClient()
return useMutation({ return useMutation({
mutationFn: (id: string) => api.deleteDomain(id), mutationFn: (id: string) => {
onSuccess: () => qc.invalidateQueries({ queryKey: ["domains"] }), const pid = requireProjectId(project)
return api.deleteDomain(pid, id)
},
onSuccess: () => qc.invalidateQueries({ queryKey: ["domains", project?.id] }),
}) })
} }
export function useCheckDomain(id: string) { export function useCheckDomain(id: string) {
return useQuery({ queryKey: ["check", id], queryFn: () => api.checkDomain(id), enabled: !!id }) const { project } = useAuth()
} return useQuery({
export function useApplyDomain(id: string) { queryKey: ["check", project?.id, id],
const qc = useQueryClient() queryFn: () => api.checkDomain(project!.id, id),
return useMutation({ enabled: !!project && !!id,
mutationFn: (body: ApplyRequest) => api.applyDomain(id, body), })
onSuccess: () => qc.invalidateQueries({ queryKey: ["check", id] }), }
export function useApplyDomain(id: string) {
const { project } = useAuth()
const qc = useQueryClient()
return useMutation({
mutationFn: (body: ApplyRequest) => {
const pid = requireProjectId(project)
return api.applyDomain(pid, id, body)
},
onSuccess: () => qc.invalidateQueries({ queryKey: ["check", project?.id, id] }),
}) })
} }
+1 -2
View File
@@ -1,2 +1 @@
export const DEFAULT_PROJECT_ID = "00000000-0000-0000-0000-000000000002" export const API_ROOT = "/api/v1"
export const API_BASE = `/api/v1/projects/${DEFAULT_PROJECT_ID}`
+19 -2
View File
@@ -4,18 +4,35 @@ import "@fontsource/ibm-plex-mono/500.css"
import "./index.css" import "./index.css"
import React from "react" import React from "react"
import ReactDOM from "react-dom/client" import ReactDOM from "react-dom/client"
import { QueryClient, QueryClientProvider } from "@tanstack/react-query" import { QueryCache, QueryClient, QueryClientProvider, MutationCache } from "@tanstack/react-query"
import { BrowserRouter } from "react-router-dom" import { BrowserRouter } from "react-router-dom"
import { UnauthorizedError } from "@/api/client"
import { AuthProvider, notifyUnauthorized } from "@/auth/AuthContext"
import { App } from "./App" import { App } from "./App"
const queryClient = new QueryClient() // A 401 from *any* query or mutation means the session died server-side
// (expired/destroyed cookie) — drop it from here rather than requiring every
// hook in useApi.ts to remember to handle it individually. AuthContext reacts
// via notifyUnauthorized (registered by AuthProvider), which resets
// user/project and clears the cache; ProtectedRoute then redirects to
// /login on the next render.
function onQueryError(error: unknown) {
if (error instanceof UnauthorizedError) notifyUnauthorized()
}
const queryClient = new QueryClient({
queryCache: new QueryCache({ onError: onQueryError }),
mutationCache: new MutationCache({ onError: onQueryError }),
})
ReactDOM.createRoot(document.getElementById("root")!).render( ReactDOM.createRoot(document.getElementById("root")!).render(
<React.StrictMode> <React.StrictMode>
<QueryClientProvider client={queryClient}> <QueryClientProvider client={queryClient}>
<AuthProvider>
<BrowserRouter> <BrowserRouter>
<App /> <App />
</BrowserRouter> </BrowserRouter>
</AuthProvider>
</QueryClientProvider> </QueryClientProvider>
</React.StrictMode>, </React.StrictMode>,
) )
+11 -2
View File
@@ -3,10 +3,12 @@ import userEvent from "@testing-library/user-event"
import { MemoryRouter } from "react-router-dom" import { MemoryRouter } from "react-router-dom"
import { QueryClient, QueryClientProvider } from "@tanstack/react-query" import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
import { AccountsPage } from "./AccountsPage" import { AccountsPage } from "./AccountsPage"
import { AuthProvider } from "@/auth/AuthContext"
import { api } from "@/api/client" import { api } from "@/api/client"
import { vi, beforeEach, test, expect } from "vitest" import { vi, beforeEach, test, expect } from "vitest"
import type { Account } from "@/api/types" import type { Account } from "@/api/types"
const PROJECT_ID = "p1"
const accounts: Account[] = [ const accounts: Account[] = [
{ id: "acc1", provider: "selectel", comment: "Main" }, { id: "acc1", provider: "selectel", comment: "Main" },
{ id: "acc2", provider: "selectel", comment: "Backup" }, { id: "acc2", provider: "selectel", comment: "Backup" },
@@ -16,14 +18,21 @@ function renderPage() {
const qc = new QueryClient() const qc = new QueryClient()
return render( return render(
<QueryClientProvider client={qc}> <QueryClientProvider client={qc}>
<AuthProvider>
<MemoryRouter initialEntries={["/accounts"]}> <MemoryRouter initialEntries={["/accounts"]}>
<AccountsPage /> <AccountsPage />
</MemoryRouter> </MemoryRouter>
</AuthProvider>
</QueryClientProvider>, </QueryClientProvider>,
) )
} }
beforeEach(() => { beforeEach(() => {
vi.restoreAllMocks()
vi.spyOn(api.auth, "me").mockResolvedValue({
user: { id: "u1", email: "a@b.com" },
project: { id: PROJECT_ID, name: "Default" },
})
vi.spyOn(api, "listAccounts").mockResolvedValue(accounts) vi.spyOn(api, "listAccounts").mockResolvedValue(accounts)
}) })
@@ -55,7 +64,7 @@ test("форма создания вызывает api.createAccount с введ
await user.click(screen.getByRole("button", { name: /добавить учётку/i })) await user.click(screen.getByRole("button", { name: /добавить учётку/i }))
await waitFor(() => await waitFor(() =>
expect(createSpy).toHaveBeenCalledWith({ expect(createSpy).toHaveBeenCalledWith(PROJECT_ID, {
provider: "selectel", provider: "selectel",
secret: "super-secret-token-123", secret: "super-secret-token-123",
comment: "New account", comment: "New account",
@@ -99,5 +108,5 @@ test("удаление учётки вызывает api.deleteAccount", async (
await user.click(screen.getByRole("button", { name: /удалить.*main/i })) await user.click(screen.getByRole("button", { name: /удалить.*main/i }))
await waitFor(() => expect(deleteSpy).toHaveBeenCalledWith("acc1")) await waitFor(() => expect(deleteSpy).toHaveBeenCalledWith(PROJECT_ID, "acc1"))
}) })
+16 -3
View File
@@ -3,20 +3,33 @@ import userEvent from "@testing-library/user-event"
import { MemoryRouter, Routes, Route } from "react-router-dom" import { MemoryRouter, Routes, Route } from "react-router-dom"
import { QueryClient, QueryClientProvider } from "@tanstack/react-query" import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
import { DomainDiffPage } from "./DomainDiffPage" import { DomainDiffPage } from "./DomainDiffPage"
import { AuthProvider } from "@/auth/AuthContext"
import { api } from "@/api/client" import { api } from "@/api/client"
import { vi } from "vitest" import { vi, beforeEach } from "vitest"
const PROJECT_ID = "p1"
function renderPage() { function renderPage() {
const qc = new QueryClient() const qc = new QueryClient()
return render( return render(
<QueryClientProvider client={qc}> <QueryClientProvider client={qc}>
<AuthProvider>
<MemoryRouter initialEntries={["/domains/d1"]}> <MemoryRouter initialEntries={["/domains/d1"]}>
<Routes><Route path="/domains/:id" element={<DomainDiffPage />} /></Routes> <Routes><Route path="/domains/:id" element={<DomainDiffPage />} /></Routes>
</MemoryRouter> </MemoryRouter>
</AuthProvider>
</QueryClientProvider>, </QueryClientProvider>,
) )
} }
beforeEach(() => {
vi.restoreAllMocks()
vi.spyOn(api.auth, "me").mockResolvedValue({
user: { id: "u1", email: "a@b.com" },
project: { id: PROJECT_ID, name: "Default" },
})
})
test("apply sends applyPrunes=false by default, true only after opting in", async () => { test("apply sends applyPrunes=false by default, true only after opting in", async () => {
vi.spyOn(api, "checkDomain").mockResolvedValue({ vi.spyOn(api, "checkDomain").mockResolvedValue({
updates: [{ kind: "update", type: "A", name: "a.", desired: ["1"], actual: ["2"], readOnly: false }], updates: [{ kind: "update", type: "A", name: "a.", desired: ["1"], actual: ["2"], readOnly: false }],
@@ -30,12 +43,12 @@ test("apply sends applyPrunes=false by default, true only after opting in", asyn
const applyBtn = await screen.findByRole("button", { name: /apply/i }) const applyBtn = await screen.findByRole("button", { name: /apply/i })
await user.click(applyBtn) await user.click(applyBtn)
await waitFor(() => expect(applySpy).toHaveBeenCalled()) await waitFor(() => expect(applySpy).toHaveBeenCalled())
expect(applySpy.mock.calls[0][1]).toEqual({ applyUpdates: true, applyPrunes: false }) expect(applySpy.mock.calls[0]).toEqual([PROJECT_ID, "d1", { applyUpdates: true, applyPrunes: false }])
// включить prune и применить снова // включить prune и применить снова
const pruneToggle = screen.getByRole("checkbox", { name: /prune|удал/i }) const pruneToggle = screen.getByRole("checkbox", { name: /prune|удал/i })
await user.click(pruneToggle) await user.click(pruneToggle)
await user.click(screen.getByRole("button", { name: /apply/i })) await user.click(screen.getByRole("button", { name: /apply/i }))
await waitFor(() => expect(applySpy).toHaveBeenCalledTimes(2)) await waitFor(() => expect(applySpy).toHaveBeenCalledTimes(2))
expect(applySpy.mock.calls[1][1]).toEqual({ applyUpdates: true, applyPrunes: true }) expect(applySpy.mock.calls[1]).toEqual([PROJECT_ID, "d1", { applyUpdates: true, applyPrunes: true }])
}) })
+11 -2
View File
@@ -3,10 +3,12 @@ import userEvent from "@testing-library/user-event"
import { MemoryRouter, Routes, Route } from "react-router-dom" import { MemoryRouter, Routes, Route } from "react-router-dom"
import { QueryClient, QueryClientProvider } from "@tanstack/react-query" import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
import { DomainsPage } from "./DomainsPage" import { DomainsPage } from "./DomainsPage"
import { AuthProvider } from "@/auth/AuthContext"
import { api } from "@/api/client" import { api } from "@/api/client"
import { vi, beforeEach, test, expect } from "vitest" import { vi, beforeEach, test, expect } from "vitest"
import type { Account, Domain, Template } from "@/api/types" import type { Account, Domain, Template } from "@/api/types"
const PROJECT_ID = "p1"
const accounts: Account[] = [ const accounts: Account[] = [
{ id: "acc1", provider: "selectel", comment: "Main" }, { id: "acc1", provider: "selectel", comment: "Main" },
{ id: "acc2", provider: "cloudflare", comment: "Backup" }, { id: "acc2", provider: "cloudflare", comment: "Backup" },
@@ -24,17 +26,24 @@ function renderPage() {
const qc = new QueryClient() const qc = new QueryClient()
return render( return render(
<QueryClientProvider client={qc}> <QueryClientProvider client={qc}>
<AuthProvider>
<MemoryRouter initialEntries={["/domains"]}> <MemoryRouter initialEntries={["/domains"]}>
<Routes> <Routes>
<Route path="/domains" element={<DomainsPage />} /> <Route path="/domains" element={<DomainsPage />} />
<Route path="/domains/:id" element={<div>diff page</div>} /> <Route path="/domains/:id" element={<div>diff page</div>} />
</Routes> </Routes>
</MemoryRouter> </MemoryRouter>
</AuthProvider>
</QueryClientProvider>, </QueryClientProvider>,
) )
} }
beforeEach(() => { beforeEach(() => {
vi.restoreAllMocks()
vi.spyOn(api.auth, "me").mockResolvedValue({
user: { id: "u1", email: "a@b.com" },
project: { id: PROJECT_ID, name: "Default" },
})
vi.spyOn(api, "listDomains").mockResolvedValue(domains) vi.spyOn(api, "listDomains").mockResolvedValue(domains)
vi.spyOn(api, "listAccounts").mockResolvedValue(accounts) vi.spyOn(api, "listAccounts").mockResolvedValue(accounts)
vi.spyOn(api, "listTemplates").mockResolvedValue(templates) vi.spyOn(api, "listTemplates").mockResolvedValue(templates)
@@ -64,7 +73,7 @@ test("кнопка импорта вызывает api.importZones с выбра
await user.click(screen.getByRole("button", { name: /импортировать зоны/i })) await user.click(screen.getByRole("button", { name: /импортировать зоны/i }))
await waitFor(() => expect(importSpy).toHaveBeenCalledWith("acc2")) await waitFor(() => expect(importSpy).toHaveBeenCalledWith(PROJECT_ID, "acc2"))
}) })
test("привязка шаблона в строке домена вызывает api.setDomainTemplate", async () => { test("привязка шаблона в строке домена вызывает api.setDomainTemplate", async () => {
@@ -77,7 +86,7 @@ test("привязка шаблона в строке домена вызыва
await user.click(screen.getByRole("combobox", { name: /example\.com\./i })) await user.click(screen.getByRole("combobox", { name: /example\.com\./i }))
await user.click(await screen.findByRole("option", { name: /^standard$/i })) await user.click(await screen.findByRole("option", { name: /^standard$/i }))
await waitFor(() => expect(setTemplateSpy).toHaveBeenCalledWith("d1", "t1")) await waitFor(() => expect(setTemplateSpy).toHaveBeenCalledWith(PROJECT_ID, "d1", "t1"))
}) })
test("ошибка привязки шаблона отображается пользователю", async () => { test("ошибка привязки шаблона отображается пользователю", async () => {
+101
View File
@@ -0,0 +1,101 @@
import { act, render, screen, waitFor } from "@testing-library/react"
import userEvent from "@testing-library/user-event"
import { MemoryRouter, Routes, Route } from "react-router-dom"
import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
import { describe, it, expect, vi, beforeEach } from "vitest"
import { LoginPage } from "./LoginPage"
import { AuthProvider } from "@/auth/AuthContext"
import { api, UnauthorizedError } from "@/api/client"
function renderPage() {
const qc = new QueryClient()
return render(
<QueryClientProvider client={qc}>
<AuthProvider>
<MemoryRouter initialEntries={["/login"]}>
<Routes>
<Route path="/login" element={<LoginPage />} />
<Route path="/register" element={<div>register page</div>} />
<Route path="/domains" element={<div>domains page</div>} />
</Routes>
</MemoryRouter>
</AuthProvider>
</QueryClientProvider>,
)
}
beforeEach(() => {
vi.restoreAllMocks()
vi.spyOn(api.auth, "me").mockRejectedValue(new UnauthorizedError())
})
describe("LoginPage", () => {
it("ввод email+пароль и сабмит вызывает useAuth().login с введёнными данными", async () => {
const loginSpy = vi.spyOn(api.auth, "login").mockResolvedValue({
user: { id: "u1", email: "a@b.com" },
project: { id: "p1", name: "Default" },
})
const user = userEvent.setup()
renderPage()
// AuthProvider resolves the session check (api.auth.me) asynchronously;
// the form only renders once loading flips to false.
await user.type(await screen.findByLabelText(/email/i), "a@b.com")
await user.type(screen.getByLabelText(/пароль/i), "secret123")
await user.click(screen.getByRole("button", { name: /войти/i }))
await waitFor(() => expect(loginSpy).toHaveBeenCalledWith("a@b.com", "secret123"))
expect(await screen.findByText("domains page")).toBeInTheDocument()
})
it("ошибка входа отображается пользователю через role=alert", async () => {
vi.spyOn(api.auth, "login").mockRejectedValue(new Error("Неверный email или пароль"))
const user = userEvent.setup()
renderPage()
await user.type(await screen.findByLabelText(/email/i), "a@b.com")
await user.type(screen.getByLabelText(/пароль/i), "wrong-password")
await user.click(screen.getByRole("button", { name: /войти/i }))
expect(await screen.findByRole("alert")).toHaveTextContent("Неверный email или пароль")
})
it("не рендерит форму логина, пока сессия (api.auth.me) не резолвнута", async () => {
let rejectMe!: (err: unknown) => void
vi.spyOn(api.auth, "me").mockImplementation(
() =>
new Promise((_resolve, reject) => {
rejectMe = reject
}),
)
renderPage()
expect(screen.queryByRole("button", { name: /войти/i })).not.toBeInTheDocument()
expect(screen.queryByLabelText(/email/i)).not.toBeInTheDocument()
// Resolve the pending me() so the test doesn't leak an unhandled rejection.
await act(async () => {
rejectMe(new UnauthorizedError())
})
})
it("сетевая ошибка при логине показывает «Сервис недоступен»", async () => {
vi.spyOn(api.auth, "login").mockRejectedValue(new TypeError("Failed to fetch"))
const user = userEvent.setup()
renderPage()
await user.type(await screen.findByLabelText(/email/i), "a@b.com")
await user.type(screen.getByLabelText(/пароль/i), "wrong-password")
await user.click(screen.getByRole("button", { name: /войти/i }))
expect(await screen.findByRole("alert")).toHaveTextContent("Сервис недоступен, попробуйте позже")
})
it("содержит ссылку на регистрацию", async () => {
renderPage()
const link = await screen.findByRole("link", { name: /зарегистрир/i })
expect(link).toHaveAttribute("href", "/register")
})
})
+169
View File
@@ -0,0 +1,169 @@
import { useId, useState } from "react"
import { Controller, useForm } from "react-hook-form"
import { zodResolver } from "@hookform/resolvers/zod"
import { z } from "zod"
import { Link, Navigate } from "react-router-dom"
import { KeyRound, Loader2, LogIn, SquareTerminal } from "lucide-react"
import { useAuth } from "@/auth/AuthContext"
import { UnauthorizedError } from "@/api/client"
import { Button } from "@/components/ui/button"
import { Input } from "@/components/ui/input"
import {
Field,
FieldContent,
FieldDescription,
FieldError,
FieldGroup,
FieldLabel,
FieldSet,
} from "@/components/ui/field"
const loginSchema = z.object({
email: z.string().trim().min(1, "Укажите email").email("Некорректный email"),
password: z.string().min(1, "Укажите пароль"),
})
type LoginForm = z.infer<typeof loginSchema>
// describeLoginError turns a login() rejection into user-facing Russian
// copy. A network failure (TypeError from fetch itself) or a 5xx response
// means the service is unreachable/broken — that's a different situation
// from wrong credentials and should say so. Everything else (401
// UnauthorizedError, or a backend "invalid credentials" message) reads as a
// bad email/password.
function describeLoginError(err: unknown): string {
const isNetworkOrServerError =
err instanceof TypeError || (err instanceof Error && err.message.startsWith("HTTP 5"))
if (isNetworkOrServerError) return "Сервис недоступен, попробуйте позже"
const isInvalidCredentials =
err instanceof UnauthorizedError ||
(err instanceof Error && /invalid credentials/i.test(err.message))
if (isInvalidCredentials) return "Неверный email или пароль"
return "Неверный email или пароль"
}
export function LoginPage() {
const { user, loading, login } = useAuth()
const [authError, setAuthError] = useState<string | null>(null)
const emailFieldId = useId()
const passwordFieldId = useId()
const {
control,
handleSubmit,
formState: { errors, isSubmitting },
} = useForm<LoginForm>({
resolver: zodResolver(loginSchema),
defaultValues: { email: "", password: "" },
})
// Session check (api.auth.me()) hasn't resolved yet — don't flash the
// login form for a visitor who turns out to already have a valid session.
if (loading) return null
// Already authenticated (fresh session on mount, or just logged in below) —
// don't show the login form, go straight to the app.
if (user) return <Navigate to="/domains" replace />
async function onSubmit(values: LoginForm) {
setAuthError(null)
try {
await login(values.email, values.password)
} catch (err) {
setAuthError(describeLoginError(err))
}
}
return (
<div className="flex h-screen w-full items-center justify-center bg-background px-4">
<div className="flex w-full max-w-sm flex-col gap-6">
<div className="flex flex-col items-center gap-2 text-center">
<SquareTerminal className="size-6 text-primary" strokeWidth={1.75} />
<div className="flex flex-col leading-none">
<span className="text-sm font-semibold tracking-tight">DNS Autoresolver</span>
<span className="font-dns text-[10px] tracking-wider text-muted-foreground uppercase">
console
</span>
</div>
</div>
<form
onSubmit={handleSubmit(onSubmit)}
noValidate
className="flex flex-col gap-4 rounded-xl border border-border bg-card/60 p-5"
>
<FieldSet className="gap-3">
<FieldGroup className="gap-3">
<Field>
<FieldLabel htmlFor={emailFieldId}>Email</FieldLabel>
<FieldContent>
<Controller
control={control}
name="email"
render={({ field }) => (
<Input
{...field}
id={emailFieldId}
type="email"
autoComplete="email"
placeholder="you@example.com"
aria-invalid={!!errors.email}
/>
)}
/>
<FieldError errors={[errors.email]} />
</FieldContent>
</Field>
<Field>
<FieldLabel htmlFor={passwordFieldId}>Пароль</FieldLabel>
<FieldContent>
<Controller
control={control}
name="password"
render={({ field }) => (
<Input
{...field}
id={passwordFieldId}
type="password"
autoComplete="current-password"
placeholder="••••••••••••"
aria-invalid={!!errors.password}
/>
)}
/>
<FieldError errors={[errors.password]} />
</FieldContent>
</Field>
</FieldGroup>
{authError && (
<FieldDescription role="alert" className="flex items-center gap-2 text-destructive">
<KeyRound className="size-3.5 shrink-0" strokeWidth={1.75} />
{authError}
</FieldDescription>
)}
</FieldSet>
<Button type="submit" disabled={isSubmitting} className="w-full">
{isSubmitting ? (
<Loader2 className="size-4 animate-spin" strokeWidth={1.75} />
) : (
<LogIn className="size-4" strokeWidth={1.75} />
)}
Войти
</Button>
</form>
<p className="text-center text-sm text-muted-foreground">
Нет учётной записи?{" "}
<Link to="/register" className="text-primary underline-offset-4 hover:underline">
Зарегистрироваться
</Link>
</p>
</div>
</div>
)
}
+164
View File
@@ -0,0 +1,164 @@
import { useId, useState } from "react"
import { Controller, useForm } from "react-hook-form"
import { zodResolver } from "@hookform/resolvers/zod"
import { z } from "zod"
import { Link, Navigate } from "react-router-dom"
import { KeyRound, Loader2, SquareTerminal, UserPlus } from "lucide-react"
import { useAuth } from "@/auth/AuthContext"
import { Button } from "@/components/ui/button"
import { Input } from "@/components/ui/input"
import {
Field,
FieldContent,
FieldDescription,
FieldError,
FieldGroup,
FieldLabel,
FieldSet,
} from "@/components/ui/field"
const registerSchema = z.object({
email: z.string().trim().min(1, "Укажите email").email("Некорректный email"),
password: z.string().min(8, "Минимум 8 символов"),
})
type RegisterForm = z.infer<typeof registerSchema>
// describeRegisterError turns a register() rejection into user-facing
// Russian copy. A network failure (TypeError from fetch itself) or a 5xx
// response means the service is unreachable/broken, not a validation
// problem — surface that distinctly instead of an opaque "HTTP 500". Any
// other error (409 email taken, 400 password too short, etc.) already
// carries a specific backend message worth showing as-is.
function describeRegisterError(err: unknown): string {
const isNetworkOrServerError =
err instanceof TypeError || (err instanceof Error && err.message.startsWith("HTTP 5"))
if (isNetworkOrServerError) return "Сервис недоступен, попробуйте позже"
return err instanceof Error ? err.message : "Не удалось зарегистрироваться"
}
export function RegisterPage() {
const { user, loading, register: registerUser } = useAuth()
const [authError, setAuthError] = useState<string | null>(null)
const emailFieldId = useId()
const passwordFieldId = useId()
const {
control,
handleSubmit,
formState: { errors, isSubmitting },
} = useForm<RegisterForm>({
resolver: zodResolver(registerSchema),
defaultValues: { email: "", password: "" },
})
// Session check (api.auth.me()) hasn't resolved yet — don't flash the
// registration form for a visitor who turns out to already have a valid
// session.
if (loading) return null
// Already authenticated — skip straight to the app instead of showing the
// registration form again.
if (user) return <Navigate to="/domains" replace />
async function onSubmit(values: RegisterForm) {
setAuthError(null)
try {
await registerUser(values.email, values.password)
} catch (err) {
setAuthError(describeRegisterError(err))
}
}
return (
<div className="flex h-screen w-full items-center justify-center bg-background px-4">
<div className="flex w-full max-w-sm flex-col gap-6">
<div className="flex flex-col items-center gap-2 text-center">
<SquareTerminal className="size-6 text-primary" strokeWidth={1.75} />
<div className="flex flex-col leading-none">
<span className="text-sm font-semibold tracking-tight">DNS Autoresolver</span>
<span className="font-dns text-[10px] tracking-wider text-muted-foreground uppercase">
console
</span>
</div>
</div>
<form
onSubmit={handleSubmit(onSubmit)}
noValidate
className="flex flex-col gap-4 rounded-xl border border-border bg-card/60 p-5"
>
<FieldSet className="gap-3">
<FieldGroup className="gap-3">
<Field>
<FieldLabel htmlFor={emailFieldId}>Email</FieldLabel>
<FieldContent>
<Controller
control={control}
name="email"
render={({ field }) => (
<Input
{...field}
id={emailFieldId}
type="email"
autoComplete="email"
placeholder="you@example.com"
aria-invalid={!!errors.email}
/>
)}
/>
<FieldError errors={[errors.email]} />
</FieldContent>
</Field>
<Field>
<FieldLabel htmlFor={passwordFieldId}>Пароль</FieldLabel>
<FieldContent>
<Controller
control={control}
name="password"
render={({ field }) => (
<Input
{...field}
id={passwordFieldId}
type="password"
autoComplete="new-password"
placeholder="••••••••••••"
aria-invalid={!!errors.password}
/>
)}
/>
<FieldError errors={[errors.password]} />
</FieldContent>
</Field>
</FieldGroup>
{authError && (
<FieldDescription role="alert" className="flex items-center gap-2 text-destructive">
<KeyRound className="size-3.5 shrink-0" strokeWidth={1.75} />
{authError}
</FieldDescription>
)}
</FieldSet>
<Button type="submit" disabled={isSubmitting} className="w-full">
{isSubmitting ? (
<Loader2 className="size-4 animate-spin" strokeWidth={1.75} />
) : (
<UserPlus className="size-4" strokeWidth={1.75} />
)}
Зарегистрироваться
</Button>
</form>
<p className="text-center text-sm text-muted-foreground">
Уже есть аккаунт?{" "}
<Link to="/login" className="text-primary underline-offset-4 hover:underline">
Войти
</Link>
</p>
</div>
</div>
)
}
+12 -3
View File
@@ -3,10 +3,12 @@ import userEvent from "@testing-library/user-event"
import { MemoryRouter } from "react-router-dom" import { MemoryRouter } from "react-router-dom"
import { QueryClient, QueryClientProvider } from "@tanstack/react-query" import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
import { TemplatesPage } from "./TemplatesPage" import { TemplatesPage } from "./TemplatesPage"
import { AuthProvider } from "@/auth/AuthContext"
import { api } from "@/api/client" import { api } from "@/api/client"
import { vi, beforeEach, test, expect } from "vitest" import { vi, beforeEach, test, expect } from "vitest"
import type { Template } from "@/api/types" import type { Template } from "@/api/types"
const PROJECT_ID = "p1"
const templates: Template[] = [ const templates: Template[] = [
{ {
id: "t1", id: "t1",
@@ -21,14 +23,21 @@ function renderPage() {
const qc = new QueryClient() const qc = new QueryClient()
return render( return render(
<QueryClientProvider client={qc}> <QueryClientProvider client={qc}>
<AuthProvider>
<MemoryRouter initialEntries={["/templates"]}> <MemoryRouter initialEntries={["/templates"]}>
<TemplatesPage /> <TemplatesPage />
</MemoryRouter> </MemoryRouter>
</AuthProvider>
</QueryClientProvider>, </QueryClientProvider>,
) )
} }
beforeEach(() => { beforeEach(() => {
vi.restoreAllMocks()
vi.spyOn(api.auth, "me").mockResolvedValue({
user: { id: "u1", email: "a@b.com" },
project: { id: PROJECT_ID, name: "Default" },
})
vi.spyOn(api, "listTemplates").mockResolvedValue(templates) vi.spyOn(api, "listTemplates").mockResolvedValue(templates)
}) })
@@ -65,7 +74,7 @@ test("создание шаблона с записью вызывает api.cre
await user.click(screen.getByRole("button", { name: /сохранить шаблон/i })) await user.click(screen.getByRole("button", { name: /сохранить шаблон/i }))
await waitFor(() => await waitFor(() =>
expect(createSpy).toHaveBeenCalledWith({ expect(createSpy).toHaveBeenCalledWith(PROJECT_ID, {
name: "New", name: "New",
records: [{ type: "A", name: "www", ttl: 3600, values: ["1.1.1.1"] }], records: [{ type: "A", name: "www", ttl: 3600, values: ["1.1.1.1"] }],
}), }),
@@ -88,7 +97,7 @@ test("редактирование шаблона вызывает api.updateTem
await user.click(screen.getByRole("button", { name: /сохранить шаблон/i })) await user.click(screen.getByRole("button", { name: /сохранить шаблон/i }))
await waitFor(() => await waitFor(() =>
expect(updateSpy).toHaveBeenCalledWith("t1", { expect(updateSpy).toHaveBeenCalledWith(PROJECT_ID, "t1", {
name: "Standard v2", name: "Standard v2",
records: [{ type: "A", name: "@", ttl: 3600, values: ["1.2.3.4"] }], records: [{ type: "A", name: "@", ttl: 3600, values: ["1.2.3.4"] }],
}), }),
@@ -105,7 +114,7 @@ test("удаление шаблона вызывает api.deleteTemplate", asyn
await user.click(screen.getByRole("button", { name: /удалить шаблон standard/i })) await user.click(screen.getByRole("button", { name: /удалить шаблон standard/i }))
await waitFor(() => expect(deleteSpy).toHaveBeenCalledWith("t1")) await waitFor(() => expect(deleteSpy).toHaveBeenCalledWith(PROJECT_ID, "t1"))
}) })
test("ошибка создания шаблона отображается пользователю", async () => { test("ошибка создания шаблона отображается пользователю", async () => {