merge: Фаза 2 — авторизация
- internal/store: миграция sessions/password + методы users/sessions/projects
- internal/auth: argon2id пароли + session store (sha256 токена)
- internal/api: auth-хендлеры (register/login/logout/me) + cookie, RequireAuth+RequireProjectAccess middleware
- IDOR закрыт: все /projects/{pid}/* под middleware, LoadDomainFull scoped, projectID из контекста
- web: AuthContext + клиент под cookie, Login/Register, protected routes, logout, 401→/login
Финальный ревью: READY TO MERGE, IDOR закрыт end-to-end. Go 105+/15 пакетов, web 58 тестов.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
+8
-1
@@ -5,10 +5,12 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/api"
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/auth"
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/config"
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/crypto"
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/provider/registry"
|
||||
@@ -18,6 +20,10 @@ import (
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/web"
|
||||
)
|
||||
|
||||
// sessionTTL is how long a login session cookie remains valid before the
|
||||
// user must re-authenticate.
|
||||
const sessionTTL = 720 * time.Hour
|
||||
|
||||
// isAPIPath reports whether path must be routed to the API router rather
|
||||
// than the SPA. "/api" (no trailing slash) counts as an API path too —
|
||||
// only strings.HasPrefix(path, "/api/") would otherwise miss it and fall
|
||||
@@ -46,12 +52,13 @@ func main() {
|
||||
log.Fatalf("cipher: %v", err)
|
||||
}
|
||||
st := store.New(pool)
|
||||
sessions := auth.NewSessions(st, sessionTTL)
|
||||
|
||||
reg := registry.New()
|
||||
reg.Register(selectel.New())
|
||||
|
||||
svc := service.New(st, st, reg, cipher)
|
||||
a := &api.API{Svc: svc, Store: st, Cipher: cipher, Reg: reg}
|
||||
a := &api.API{Svc: svc, Store: st, Cipher: cipher, Reg: reg, Auth: st, Sessions: sessions}
|
||||
apiRouter := api.NewRouter(a)
|
||||
|
||||
webHandler, err := web.Handler()
|
||||
|
||||
@@ -9,6 +9,7 @@ require (
|
||||
github.com/pressly/goose/v3 v3.27.2
|
||||
github.com/testcontainers/testcontainers-go v0.43.0
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.43.0
|
||||
golang.org/x/crypto v0.53.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -64,9 +65,8 @@ require (
|
||||
go.opentelemetry.io/otel/metric v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.43.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/crypto v0.52.0 // indirect
|
||||
golang.org/x/sync v0.21.0 // indirect
|
||||
golang.org/x/sys v0.45.0 // indirect
|
||||
golang.org/x/text v0.37.0 // indirect
|
||||
golang.org/x/sys v0.46.0 // indirect
|
||||
golang.org/x/text v0.38.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -150,19 +150,19 @@ go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09
|
||||
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988=
|
||||
golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc=
|
||||
golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto=
|
||||
golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio=
|
||||
golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM=
|
||||
golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
|
||||
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
|
||||
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
|
||||
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
|
||||
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
|
||||
golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw=
|
||||
golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.44.0 h1:0rLvDRCtNj0gZkyIXhCyOb2OAzEhLVqc4B+hrsBhrmc=
|
||||
golang.org/x/term v0.44.0/go.mod h1:7ze4MdzUzLXpSAoFP1H0bOI9aXDqveSvatT5vKcFh2Y=
|
||||
golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE=
|
||||
golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
|
||||
+40
-2
@@ -3,6 +3,7 @@ package api
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
@@ -17,8 +18,8 @@ import (
|
||||
|
||||
// CheckApplier is the service surface the API depends on.
|
||||
type CheckApplier interface {
|
||||
Check(ctx context.Context, domainID uuid.UUID) (diff.Changeset, error)
|
||||
Apply(ctx context.Context, domainID uuid.UUID, req service.ApplyRequest) (diff.Changeset, error)
|
||||
Check(ctx context.Context, projectID, domainID uuid.UUID) (diff.Changeset, error)
|
||||
Apply(ctx context.Context, projectID, domainID uuid.UUID, req service.ApplyRequest) (diff.Changeset, error)
|
||||
}
|
||||
|
||||
// TenantStore is the narrow persistence surface the CRUD handlers depend on.
|
||||
@@ -54,12 +55,36 @@ type ProviderRegistry interface {
|
||||
ByName(name string) (provider.Provider, error)
|
||||
}
|
||||
|
||||
// AuthStore is the persistence surface the auth handlers depend on.
|
||||
// *store.Store satisfies it directly (see internal/store/store.go); tests
|
||||
// can supply their own mock.
|
||||
type AuthStore interface {
|
||||
RegisterUser(ctx context.Context, email, passwordHash string) (store.User, store.Project, error)
|
||||
GetUserByEmail(ctx context.Context, email string) (store.User, error)
|
||||
GetUserByID(ctx context.Context, userID uuid.UUID) (store.User, error)
|
||||
GetUserProject(ctx context.Context, userID uuid.UUID) (store.Project, error)
|
||||
// GetProjectOwned looks up projectID and returns it only if it's owned by
|
||||
// userID — RequireProjectAccess uses this to reject foreign/nonexistent
|
||||
// projects with 404 before any handler runs.
|
||||
GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error)
|
||||
}
|
||||
|
||||
// SessionManager creates/validates/destroys login sessions. *auth.Sessions
|
||||
// satisfies it directly (see internal/auth/session.go).
|
||||
type SessionManager interface {
|
||||
Create(ctx context.Context, userID uuid.UUID) (string, time.Time, error)
|
||||
Validate(ctx context.Context, token string) (uuid.UUID, error)
|
||||
Destroy(ctx context.Context, token string) error
|
||||
}
|
||||
|
||||
// API holds handler dependencies.
|
||||
type API struct {
|
||||
Svc CheckApplier
|
||||
Store TenantStore
|
||||
Cipher Cipher
|
||||
Reg ProviderRegistry
|
||||
Auth AuthStore
|
||||
Sessions SessionManager
|
||||
}
|
||||
|
||||
func NewRouter(a *API) http.Handler {
|
||||
@@ -67,7 +92,20 @@ func NewRouter(a *API) http.Handler {
|
||||
r.Use(middleware.RequestID)
|
||||
r.Use(middleware.Recoverer)
|
||||
|
||||
r.Route("/api/v1/auth", func(r chi.Router) {
|
||||
r.Post("/register", a.handleRegister)
|
||||
r.Post("/login", a.handleLogin)
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(a.RequireAuth)
|
||||
r.Post("/logout", a.handleLogout)
|
||||
r.Get("/me", a.handleMe)
|
||||
})
|
||||
})
|
||||
|
||||
r.Route("/api/v1/projects/{pid}", func(r chi.Router) {
|
||||
r.Use(a.RequireAuth)
|
||||
r.Use(a.RequireProjectAccess)
|
||||
|
||||
r.Route("/domains", func(r chi.Router) {
|
||||
r.Post("/", a.handleCreateDomain)
|
||||
r.Get("/", a.handleListDomains)
|
||||
|
||||
@@ -20,18 +20,23 @@ type mockCheckApplier struct {
|
||||
lastReq service.ApplyRequest
|
||||
}
|
||||
|
||||
func (m *mockCheckApplier) Check(context.Context, uuid.UUID) (diff.Changeset, error) {
|
||||
func (m *mockCheckApplier) Check(context.Context, uuid.UUID, uuid.UUID) (diff.Changeset, error) {
|
||||
d := model.Record{Type: model.A, Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}}
|
||||
return diff.Changeset{Diffs: []diff.RecordDiff{{Kind: diff.Add, Type: d.Type, Name: d.Name, Desired: &d}}}, nil
|
||||
}
|
||||
func (m *mockCheckApplier) Apply(_ context.Context, _ uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) {
|
||||
func (m *mockCheckApplier) Apply(_ context.Context, _, _ uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) {
|
||||
m.lastReq = req
|
||||
return diff.Changeset{}, nil
|
||||
}
|
||||
|
||||
// newTestAPI wires a fixed authenticated user who owns whatever project id
|
||||
// is requested (via alwaysOwnedAuthStore/alwaysValidSessions in
|
||||
// middleware_test.go) — these tests exercise check/apply behavior past the
|
||||
// RequireAuth/RequireProjectAccess boundary, which is covered separately by
|
||||
// middleware_test.go's own tests and the IDOR regression.
|
||||
func newTestAPI() (*API, *mockCheckApplier) {
|
||||
m := &mockCheckApplier{}
|
||||
return &API{Svc: m}, m // остальные зависимости (store/cipher) nil — CRUD-тесты добавит реализатор
|
||||
return &API{Svc: m, Auth: alwaysOwnedAuthStore(), Sessions: alwaysValidSessions(uuid.New())}, m
|
||||
}
|
||||
|
||||
func TestCheckEndpoint(t *testing.T) {
|
||||
@@ -39,7 +44,7 @@ func TestCheckEndpoint(t *testing.T) {
|
||||
router := NewRouter(a)
|
||||
|
||||
did := uuid.New().String()
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
req := requestWithSessionCookie(http.MethodGet,
|
||||
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/check", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
@@ -62,7 +67,7 @@ func TestApplyDefaultsPruneFalse(t *testing.T) {
|
||||
|
||||
did := uuid.New().String()
|
||||
body := `{"applyUpdates":true}` // applyPrunes отсутствует → false
|
||||
req := httptest.NewRequest(http.MethodPost,
|
||||
req := requestWithSessionCookie(http.MethodPost,
|
||||
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply",
|
||||
strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
@@ -81,7 +86,7 @@ func TestApplyEmptyBodyOK(t *testing.T) {
|
||||
router := NewRouter(a)
|
||||
|
||||
did := uuid.New().String()
|
||||
req := httptest.NewRequest(http.MethodPost,
|
||||
req := requestWithSessionCookie(http.MethodPost,
|
||||
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
@@ -100,7 +105,7 @@ func TestApplyMalformedBody(t *testing.T) {
|
||||
|
||||
did := uuid.New().String()
|
||||
body := `{"applyUpdates":`
|
||||
req := httptest.NewRequest(http.MethodPost,
|
||||
req := requestWithSessionCookie(http.MethodPost,
|
||||
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply",
|
||||
strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
@@ -114,7 +119,7 @@ func TestApplyMalformedBody(t *testing.T) {
|
||||
func TestApplyBadUUID(t *testing.T) {
|
||||
a, _ := newTestAPI()
|
||||
router := NewRouter(a)
|
||||
req := httptest.NewRequest(http.MethodPost,
|
||||
req := requestWithSessionCookie(http.MethodPost,
|
||||
"/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/not-a-uuid/apply",
|
||||
bytes.NewReader([]byte(`{}`)))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -0,0 +1,182 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/auth"
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/store"
|
||||
)
|
||||
|
||||
const sessionCookieName = "session"
|
||||
|
||||
// dummyPasswordHash is a valid-format argon2 hash with no real matching
|
||||
// password. handleLogin runs VerifyPassword against it whenever the email
|
||||
// lookup fails, so a login attempt for an unregistered email takes the same
|
||||
// wall-clock time as one for a registered email with a wrong password —
|
||||
// otherwise the timing difference would let an attacker enumerate which
|
||||
// emails are registered.
|
||||
var dummyPasswordHash string
|
||||
|
||||
func init() {
|
||||
h, err := auth.HashPassword("dns-autoresolver-timing-guard-dummy")
|
||||
if err != nil {
|
||||
panic("api: failed to initialize dummy password hash: " + err.Error())
|
||||
}
|
||||
dummyPasswordHash = h
|
||||
}
|
||||
|
||||
// normalizeEmail trims surrounding whitespace and lowercases the email so
|
||||
// storage and lookup are always consistent regardless of how the client
|
||||
// cased or padded the input.
|
||||
func normalizeEmail(email string) string {
|
||||
return strings.ToLower(strings.TrimSpace(email))
|
||||
}
|
||||
|
||||
func setSessionCookie(w http.ResponseWriter, token string, exp time.Time) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sessionCookieName, Value: token, Path: "/",
|
||||
HttpOnly: true, Secure: true, SameSite: http.SameSiteLaxMode, Expires: exp,
|
||||
})
|
||||
}
|
||||
|
||||
func clearSessionCookie(w http.ResponseWriter) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sessionCookieName, Value: "", Path: "/",
|
||||
HttpOnly: true, Secure: true, SameSite: http.SameSiteLaxMode, MaxAge: -1,
|
||||
})
|
||||
}
|
||||
|
||||
func (a *API) handleRegister(w http.ResponseWriter, r *http.Request) {
|
||||
var req registerRequest
|
||||
if !decodeBody(w, r, &req) {
|
||||
return
|
||||
}
|
||||
email := normalizeEmail(req.Email)
|
||||
if email == "" || req.Password == "" {
|
||||
writeErr(w, http.StatusBadRequest, "email and password are required")
|
||||
return
|
||||
}
|
||||
// Server-side minimum length is the source of truth: the client-side
|
||||
// zod min(8) check is UX only and can be bypassed with a direct POST.
|
||||
if len(req.Password) < 8 {
|
||||
writeErr(w, http.StatusBadRequest, "password must be at least 8 characters")
|
||||
return
|
||||
}
|
||||
|
||||
hash, err := auth.HashPassword(req.Password)
|
||||
if err != nil {
|
||||
log.Printf("api: hash password failed: %v", err)
|
||||
writeErr(w, http.StatusInternalServerError, "internal error")
|
||||
return
|
||||
}
|
||||
|
||||
u, p, err := a.Auth.RegisterUser(r.Context(), email, hash)
|
||||
if err != nil {
|
||||
if errors.Is(err, store.ErrEmailTaken) {
|
||||
writeErr(w, http.StatusConflict, "email already registered")
|
||||
return
|
||||
}
|
||||
log.Printf("api: register user failed: %v", err)
|
||||
writeErr(w, http.StatusInternalServerError, "internal error")
|
||||
return
|
||||
}
|
||||
|
||||
token, exp, err := a.Sessions.Create(r.Context(), u.ID)
|
||||
if err != nil {
|
||||
log.Printf("api: create session failed: %v", err)
|
||||
writeErr(w, http.StatusInternalServerError, "internal error")
|
||||
return
|
||||
}
|
||||
|
||||
setSessionCookie(w, token, exp)
|
||||
writeJSON(w, http.StatusOK, toAuthResponse(u, p))
|
||||
}
|
||||
|
||||
// invalidCredentials is deliberately identical for "no such user" and "wrong
|
||||
// password" — disclosing which one occurred would let an attacker enumerate
|
||||
// registered emails.
|
||||
func invalidCredentials(w http.ResponseWriter) {
|
||||
writeErr(w, http.StatusUnauthorized, "invalid credentials")
|
||||
}
|
||||
|
||||
func (a *API) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
var req loginRequest
|
||||
if !decodeBody(w, r, &req) {
|
||||
return
|
||||
}
|
||||
email := normalizeEmail(req.Email)
|
||||
|
||||
u, err := a.Auth.GetUserByEmail(r.Context(), email)
|
||||
if err != nil {
|
||||
// No such user: still spend the argon2 verification cost against a
|
||||
// fixed dummy hash (see dummyPasswordHash) so this path isn't
|
||||
// distinguishable by timing from a wrong-password rejection below.
|
||||
_, _ = auth.VerifyPassword(dummyPasswordHash, req.Password)
|
||||
invalidCredentials(w)
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := auth.VerifyPassword(u.PasswordHash, req.Password)
|
||||
if err != nil || !ok {
|
||||
invalidCredentials(w)
|
||||
return
|
||||
}
|
||||
|
||||
p, err := a.Auth.GetUserProject(r.Context(), u.ID)
|
||||
if err != nil {
|
||||
log.Printf("api: get user project failed: %v", err)
|
||||
writeErr(w, http.StatusInternalServerError, "internal error")
|
||||
return
|
||||
}
|
||||
|
||||
token, exp, err := a.Sessions.Create(r.Context(), u.ID)
|
||||
if err != nil {
|
||||
log.Printf("api: create session failed: %v", err)
|
||||
writeErr(w, http.StatusInternalServerError, "internal error")
|
||||
return
|
||||
}
|
||||
|
||||
setSessionCookie(w, token, exp)
|
||||
writeJSON(w, http.StatusOK, toAuthResponse(u, p))
|
||||
}
|
||||
|
||||
func (a *API) handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
if c, err := r.Cookie(sessionCookieName); err == nil && c.Value != "" {
|
||||
if err := a.Sessions.Destroy(r.Context(), c.Value); err != nil {
|
||||
log.Printf("api: destroy session failed: %v", err)
|
||||
}
|
||||
}
|
||||
clearSessionCookie(w)
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
|
||||
}
|
||||
|
||||
// handleMe returns the authenticated caller's identity + default project.
|
||||
// The user ID comes from the request context, set by RequireAuth after
|
||||
// validating the session cookie.
|
||||
func (a *API) handleMe(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := userIDFrom(r.Context())
|
||||
if !ok {
|
||||
writeErr(w, http.StatusUnauthorized, "authentication required")
|
||||
return
|
||||
}
|
||||
|
||||
u, err := a.Auth.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
log.Printf("api: get user by id failed: %v", err)
|
||||
writeErr(w, http.StatusInternalServerError, "internal error")
|
||||
return
|
||||
}
|
||||
|
||||
p, err := a.Auth.GetUserProject(r.Context(), userID)
|
||||
if err != nil {
|
||||
log.Printf("api: get user project failed: %v", err)
|
||||
writeErr(w, http.StatusInternalServerError, "internal error")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, toAuthResponse(u, p))
|
||||
}
|
||||
@@ -0,0 +1,432 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/auth"
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/store"
|
||||
)
|
||||
|
||||
// --- mocks ---
|
||||
|
||||
type mockAuthStore struct {
|
||||
registerUserFn func(ctx context.Context, email, passwordHash string) (store.User, store.Project, error)
|
||||
getUserByEmailFn func(ctx context.Context, email string) (store.User, error)
|
||||
getUserByIDFn func(ctx context.Context, userID uuid.UUID) (store.User, error)
|
||||
getUserProjectFn func(ctx context.Context, userID uuid.UUID) (store.Project, error)
|
||||
getProjectOwnedFn func(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error)
|
||||
}
|
||||
|
||||
func (m *mockAuthStore) RegisterUser(ctx context.Context, email, passwordHash string) (store.User, store.Project, error) {
|
||||
return m.registerUserFn(ctx, email, passwordHash)
|
||||
}
|
||||
|
||||
func (m *mockAuthStore) GetUserByEmail(ctx context.Context, email string) (store.User, error) {
|
||||
return m.getUserByEmailFn(ctx, email)
|
||||
}
|
||||
|
||||
func (m *mockAuthStore) GetUserByID(ctx context.Context, userID uuid.UUID) (store.User, error) {
|
||||
return m.getUserByIDFn(ctx, userID)
|
||||
}
|
||||
|
||||
func (m *mockAuthStore) GetUserProject(ctx context.Context, userID uuid.UUID) (store.Project, error) {
|
||||
return m.getUserProjectFn(ctx, userID)
|
||||
}
|
||||
|
||||
func (m *mockAuthStore) GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) {
|
||||
return m.getProjectOwnedFn(ctx, projectID, userID)
|
||||
}
|
||||
|
||||
type mockSessionManager struct {
|
||||
createFn func(ctx context.Context, userID uuid.UUID) (string, time.Time, error)
|
||||
validateFn func(ctx context.Context, token string) (uuid.UUID, error)
|
||||
|
||||
destroyCalled bool
|
||||
destroyToken string
|
||||
destroyErr error
|
||||
}
|
||||
|
||||
func (m *mockSessionManager) Create(ctx context.Context, userID uuid.UUID) (string, time.Time, error) {
|
||||
return m.createFn(ctx, userID)
|
||||
}
|
||||
|
||||
func (m *mockSessionManager) Validate(ctx context.Context, token string) (uuid.UUID, error) {
|
||||
if m.validateFn != nil {
|
||||
return m.validateFn(ctx, token)
|
||||
}
|
||||
return uuid.Nil, nil
|
||||
}
|
||||
|
||||
func (m *mockSessionManager) Destroy(ctx context.Context, token string) error {
|
||||
m.destroyCalled = true
|
||||
m.destroyToken = token
|
||||
return m.destroyErr
|
||||
}
|
||||
|
||||
func newTestAuthAPI() (*API, *mockAuthStore, *mockSessionManager) {
|
||||
authStore := &mockAuthStore{}
|
||||
sessions := &mockSessionManager{
|
||||
createFn: func(_ context.Context, userID uuid.UUID) (string, time.Time, error) {
|
||||
return "test-token", time.Now().Add(time.Hour), nil
|
||||
},
|
||||
}
|
||||
return &API{Auth: authStore, Sessions: sessions}, authStore, sessions
|
||||
}
|
||||
|
||||
func findCookie(resp *http.Response, name string) *http.Cookie {
|
||||
for _, c := range resp.Cookies() {
|
||||
if c.Name == name {
|
||||
return c
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- register ---
|
||||
|
||||
func TestAuthRegister_Success(t *testing.T) {
|
||||
a, authStore, _ := newTestAuthAPI()
|
||||
userID := uuid.New()
|
||||
projectID := uuid.New()
|
||||
authStore.registerUserFn = func(_ context.Context, email, passwordHash string) (store.User, store.Project, error) {
|
||||
if passwordHash == "" {
|
||||
t.Fatal("expected non-empty password hash passed to RegisterUser")
|
||||
}
|
||||
return store.User{ID: userID, Email: email, PasswordHash: passwordHash},
|
||||
store.Project{ID: projectID, UserID: userID, Name: "default"}, nil
|
||||
}
|
||||
|
||||
router := NewRouter(a)
|
||||
body := `{"email":"alice@example.com","password":"correct-horse"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status %d, body %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
resp := w.Result()
|
||||
cookie := findCookie(resp, sessionCookieName)
|
||||
if cookie == nil {
|
||||
t.Fatal("expected session cookie to be set")
|
||||
}
|
||||
if cookie.Value != "test-token" {
|
||||
t.Fatalf("unexpected cookie value: %q", cookie.Value)
|
||||
}
|
||||
|
||||
if strings.Contains(w.Body.String(), "password") {
|
||||
t.Fatalf("response body must not contain password/password_hash: %s", w.Body.String())
|
||||
}
|
||||
|
||||
var got authResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got.User.ID != userID.String() || got.User.Email != "alice@example.com" {
|
||||
t.Fatalf("unexpected user in response: %+v", got.User)
|
||||
}
|
||||
if got.Project.ID != projectID.String() || got.Project.Name != "default" {
|
||||
t.Fatalf("unexpected project in response: %+v", got.Project)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthRegister_NormalizesEmail verifies the fix for the email-consistency
|
||||
// gap: a padded/mixed-case email is trimmed+lowercased before it reaches the
|
||||
// store, so storage and later lookups are always consistent.
|
||||
func TestAuthRegister_NormalizesEmail(t *testing.T) {
|
||||
a, authStore, _ := newTestAuthAPI()
|
||||
userID := uuid.New()
|
||||
var gotEmail string
|
||||
authStore.registerUserFn = func(_ context.Context, email, passwordHash string) (store.User, store.Project, error) {
|
||||
gotEmail = email
|
||||
return store.User{ID: userID, Email: email}, store.Project{ID: uuid.New(), UserID: userID, Name: "default"}, nil
|
||||
}
|
||||
|
||||
router := NewRouter(a)
|
||||
body := `{"email":" Alice@X.com ","password":"correct-horse"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status %d, body %s", w.Code, w.Body.String())
|
||||
}
|
||||
if gotEmail != "alice@x.com" {
|
||||
t.Fatalf("expected normalized email passed to RegisterUser, got %q", gotEmail)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthRegister_ShortPasswordReturns400 verifies the server-side password
|
||||
// length floor: the client's zod min(8) is UX only and can be bypassed with a
|
||||
// direct POST, so the handler itself must reject a password under 8 chars
|
||||
// before ever calling RegisterUser.
|
||||
func TestAuthRegister_ShortPasswordReturns400(t *testing.T) {
|
||||
a, authStore, _ := newTestAuthAPI()
|
||||
registerCalled := false
|
||||
authStore.registerUserFn = func(context.Context, string, string) (store.User, store.Project, error) {
|
||||
registerCalled = true
|
||||
return store.User{}, store.Project{}, nil
|
||||
}
|
||||
|
||||
router := NewRouter(a)
|
||||
body := `{"email":"alice@example.com","password":"short"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status %d, body %s", w.Code, w.Body.String())
|
||||
}
|
||||
var got map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got["error"] != "password must be at least 8 characters" {
|
||||
t.Fatalf(`expected error "password must be at least 8 characters", got %q`, got["error"])
|
||||
}
|
||||
if registerCalled {
|
||||
t.Fatal("expected RegisterUser not to be called for a too-short password")
|
||||
}
|
||||
if findCookie(w.Result(), sessionCookieName) != nil {
|
||||
t.Fatal("expected no session cookie on rejected register")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthRegister_DuplicateEmailReturns409 verifies the fix for the
|
||||
// duplicate-registration gap: RegisterUser reporting store.ErrEmailTaken
|
||||
// must surface as 409, not a generic 500.
|
||||
func TestAuthRegister_DuplicateEmailReturns409(t *testing.T) {
|
||||
a, authStore, _ := newTestAuthAPI()
|
||||
authStore.registerUserFn = func(context.Context, string, string) (store.User, store.Project, error) {
|
||||
return store.User{}, store.Project{}, store.ErrEmailTaken
|
||||
}
|
||||
|
||||
router := NewRouter(a)
|
||||
body := `{"email":"dup@example.com","password":"correct-horse"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusConflict {
|
||||
t.Fatalf("status %d, body %s", w.Code, w.Body.String())
|
||||
}
|
||||
var got map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got["error"] != "email already registered" {
|
||||
t.Fatalf(`expected error "email already registered", got %q`, got["error"])
|
||||
}
|
||||
}
|
||||
|
||||
// --- login ---
|
||||
|
||||
func TestAuthLogin_CorrectPassword(t *testing.T) {
|
||||
a, authStore, _ := newTestAuthAPI()
|
||||
hash, err := auth.HashPassword("correct-horse")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
userID := uuid.New()
|
||||
projectID := uuid.New()
|
||||
authStore.getUserByEmailFn = func(_ context.Context, email string) (store.User, error) {
|
||||
return store.User{ID: userID, Email: email, PasswordHash: hash}, nil
|
||||
}
|
||||
authStore.getUserProjectFn = func(_ context.Context, uid uuid.UUID) (store.Project, error) {
|
||||
return store.Project{ID: projectID, UserID: uid, Name: "default"}, nil
|
||||
}
|
||||
|
||||
router := NewRouter(a)
|
||||
body := `{"email":"bob@example.com","password":"correct-horse"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status %d, body %s", w.Code, w.Body.String())
|
||||
}
|
||||
if findCookie(w.Result(), sessionCookieName) == nil {
|
||||
t.Fatal("expected session cookie to be set")
|
||||
}
|
||||
if strings.Contains(w.Body.String(), "password") {
|
||||
t.Fatalf("response body must not contain password/password_hash: %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthLogin_NormalizesEmail verifies that a login for a padded/mixed-case
|
||||
// email reaches GetUserByEmail already trimmed+lowercased — the same
|
||||
// normalization applied on register, so "Alice@X.com" at registration and
|
||||
// "alice@x.com" at login resolve to the same account.
|
||||
func TestAuthLogin_NormalizesEmail(t *testing.T) {
|
||||
a, authStore, _ := newTestAuthAPI()
|
||||
hash, err := auth.HashPassword("correct-horse")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
userID := uuid.New()
|
||||
var gotEmail string
|
||||
authStore.getUserByEmailFn = func(_ context.Context, email string) (store.User, error) {
|
||||
gotEmail = email
|
||||
return store.User{ID: userID, Email: email, PasswordHash: hash}, nil
|
||||
}
|
||||
authStore.getUserProjectFn = func(_ context.Context, uid uuid.UUID) (store.Project, error) {
|
||||
return store.Project{ID: uuid.New(), UserID: uid, Name: "default"}, nil
|
||||
}
|
||||
|
||||
router := NewRouter(a)
|
||||
body := `{"email":" Alice@X.com ","password":"correct-horse"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status %d, body %s", w.Code, w.Body.String())
|
||||
}
|
||||
if gotEmail != "alice@x.com" {
|
||||
t.Fatalf("expected normalized email passed to GetUserByEmail, got %q", gotEmail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthLogin_WrongPassword(t *testing.T) {
|
||||
a, authStore, _ := newTestAuthAPI()
|
||||
hash, err := auth.HashPassword("correct-horse")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
authStore.getUserByEmailFn = func(_ context.Context, email string) (store.User, error) {
|
||||
return store.User{ID: uuid.New(), Email: email, PasswordHash: hash}, nil
|
||||
}
|
||||
|
||||
router := NewRouter(a)
|
||||
body := `{"email":"bob@example.com","password":"wrong-password"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assertInvalidCredentials(t, w)
|
||||
}
|
||||
|
||||
func TestAuthLogin_UnknownEmail(t *testing.T) {
|
||||
a, authStore, _ := newTestAuthAPI()
|
||||
authStore.getUserByEmailFn = func(_ context.Context, email string) (store.User, error) {
|
||||
return store.User{}, errNoRowsForTest
|
||||
}
|
||||
|
||||
router := NewRouter(a)
|
||||
body := `{"email":"nobody@example.com","password":"whatever"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assertInvalidCredentials(t, w)
|
||||
}
|
||||
|
||||
func assertInvalidCredentials(t *testing.T, w *httptest.ResponseRecorder) {
|
||||
t.Helper()
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status %d, body %s", w.Code, w.Body.String())
|
||||
}
|
||||
var got map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got["error"] != "invalid credentials" {
|
||||
t.Fatalf(`expected error "invalid credentials", got %q`, got["error"])
|
||||
}
|
||||
if findCookie(w.Result(), sessionCookieName) != nil {
|
||||
t.Fatal("expected no session cookie on failed login")
|
||||
}
|
||||
}
|
||||
|
||||
// --- logout ---
|
||||
|
||||
func TestAuthLogout_ClearsSessionAndDestroys(t *testing.T) {
|
||||
a, _, sessions := newTestAuthAPI()
|
||||
|
||||
router := NewRouter(a)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", nil)
|
||||
req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "some-token"})
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status %d, body %s", w.Code, w.Body.String())
|
||||
}
|
||||
if !sessions.destroyCalled {
|
||||
t.Fatal("expected Sessions.Destroy to be called")
|
||||
}
|
||||
if sessions.destroyToken != "some-token" {
|
||||
t.Fatalf("expected Destroy called with cookie token, got %q", sessions.destroyToken)
|
||||
}
|
||||
|
||||
cookie := findCookie(w.Result(), sessionCookieName)
|
||||
if cookie == nil {
|
||||
t.Fatal("expected a session cookie in the response (clearing cookie)")
|
||||
}
|
||||
if cookie.MaxAge > 0 {
|
||||
t.Fatalf("expected cookie MaxAge <= 0 (cleared), got %d", cookie.MaxAge)
|
||||
}
|
||||
}
|
||||
|
||||
// --- me ---
|
||||
|
||||
// TestAuthMe_ReturnsRealEmail verifies the fix for the /me gap: the handler
|
||||
// now resolves the authenticated user via GetUserByID and returns their real
|
||||
// email, instead of leaving it blank.
|
||||
func TestAuthMe_ReturnsRealEmail(t *testing.T) {
|
||||
a, authStore, sessions := newTestAuthAPI()
|
||||
userID := uuid.New()
|
||||
projectID := uuid.New()
|
||||
authStore.getUserByIDFn = func(_ context.Context, id uuid.UUID) (store.User, error) {
|
||||
if id != userID {
|
||||
t.Fatalf("unexpected user id: %s", id)
|
||||
}
|
||||
return store.User{ID: userID, Email: "me@example.com"}, nil
|
||||
}
|
||||
authStore.getUserProjectFn = func(_ context.Context, uid uuid.UUID) (store.Project, error) {
|
||||
return store.Project{ID: projectID, UserID: uid, Name: "default"}, nil
|
||||
}
|
||||
// /me is now behind RequireAuth: the session cookie must resolve to
|
||||
// userID via Sessions.Validate rather than being injected directly.
|
||||
sessions.validateFn = func(context.Context, string) (uuid.UUID, error) {
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
router := NewRouter(a)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil)
|
||||
req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "some-valid-token"})
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status %d, body %s", w.Code, w.Body.String())
|
||||
}
|
||||
var got authResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got.User.ID != userID.String() || got.User.Email != "me@example.com" {
|
||||
t.Fatalf("unexpected user in /me response: %+v", got.User)
|
||||
}
|
||||
if got.Project.ID != projectID.String() {
|
||||
t.Fatalf("unexpected project in /me response: %+v", got.Project)
|
||||
}
|
||||
}
|
||||
|
||||
// errNoRowsForTest stands in for a "not found" error a real store would
|
||||
// return (e.g. pgx.ErrNoRows) — handlers must not distinguish it from any
|
||||
// other GetUserByEmail failure in the response they send.
|
||||
var errNoRowsForTest = ¬FoundErr{}
|
||||
|
||||
type notFoundErr struct{}
|
||||
|
||||
func (*notFoundErr) Error() string { return "not found" }
|
||||
+38
-1
@@ -1,6 +1,43 @@
|
||||
package api
|
||||
|
||||
import "github.com/vasyakrg/dns-autoresolver/internal/diff"
|
||||
import (
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/diff"
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/store"
|
||||
)
|
||||
|
||||
type registerRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type loginRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// userResponse and projectResponse deliberately expose only id/email and
|
||||
// id/name — password_hash must never reach a client response.
|
||||
type userResponse struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
type projectResponse struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type authResponse struct {
|
||||
User userResponse `json:"user"`
|
||||
Project projectResponse `json:"project"`
|
||||
}
|
||||
|
||||
func toAuthResponse(u store.User, p store.Project) authResponse {
|
||||
return authResponse{
|
||||
User: userResponse{ID: u.ID.String(), Email: u.Email},
|
||||
Project: projectResponse{ID: p.ID.String(), Name: p.Name},
|
||||
}
|
||||
}
|
||||
|
||||
type applyRequest struct {
|
||||
ApplyUpdates bool `json:"applyUpdates"`
|
||||
|
||||
@@ -24,12 +24,15 @@ func writeErr(w http.ResponseWriter, status int, msg string) {
|
||||
}
|
||||
|
||||
func (a *API) handleCheck(w http.ResponseWriter, r *http.Request) {
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
did, err := uuid.Parse(chi.URLParam(r, "did"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid domain id")
|
||||
return
|
||||
}
|
||||
cs, err := a.Svc.Check(r.Context(), did)
|
||||
cs, err := a.Svc.Check(r.Context(), pid, did)
|
||||
if err != nil {
|
||||
log.Printf("api: check failed: %v", err)
|
||||
writeErr(w, http.StatusInternalServerError, "internal error")
|
||||
@@ -39,6 +42,9 @@ func (a *API) handleCheck(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (a *API) handleApply(w http.ResponseWriter, r *http.Request) {
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
did, err := uuid.Parse(chi.URLParam(r, "did"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid domain id")
|
||||
@@ -54,7 +60,7 @@ func (a *API) handleApply(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
}
|
||||
cs, err := a.Svc.Apply(r.Context(), did, service.ApplyRequest{
|
||||
cs, err := a.Svc.Apply(r.Context(), pid, did, service.ApplyRequest{
|
||||
ApplyUpdates: req.ApplyUpdates, ApplyPrunes: req.ApplyPrunes,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type ctxKey string
|
||||
|
||||
const (
|
||||
ctxUserID ctxKey = "userID"
|
||||
ctxProjectID ctxKey = "projectID"
|
||||
)
|
||||
|
||||
// RequireAuth validates the session cookie and, on success, stores the
|
||||
// authenticated user's ID in the request context (see userIDFrom). Any
|
||||
// failure — missing cookie or an invalid/expired token — is rejected with
|
||||
// 401 before reaching the wrapped handler.
|
||||
func (a *API) RequireAuth(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusUnauthorized, "unauthorized")
|
||||
return
|
||||
}
|
||||
uid, err := a.Sessions.Validate(r.Context(), c.Value)
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusUnauthorized, "unauthorized")
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), ctxUserID, uid)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// RequireProjectAccess verifies that the {pid} URL segment names a project
|
||||
// owned by the authenticated user (set by RequireAuth, which must run
|
||||
// first) and, on success, stores the project ID in the request context (see
|
||||
// projectIDFrom). A project that doesn't exist or isn't owned by the caller
|
||||
// is rejected with 404 — not 403 — so a caller can't distinguish "not yours"
|
||||
// from "doesn't exist" (closes IDOR by never confirming another tenant's
|
||||
// project exists).
|
||||
func (a *API) RequireProjectAccess(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
uid, ok := userIDFrom(r.Context())
|
||||
if !ok {
|
||||
writeErr(w, http.StatusUnauthorized, "unauthorized")
|
||||
return
|
||||
}
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
if _, err := a.Auth.GetProjectOwned(r.Context(), pid, uid); err != nil {
|
||||
writeErr(w, http.StatusNotFound, "not found")
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), ctxProjectID, pid)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func userIDFrom(ctx context.Context) (uuid.UUID, bool) {
|
||||
v, ok := ctx.Value(ctxUserID).(uuid.UUID)
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func projectIDFrom(ctx context.Context) (uuid.UUID, bool) {
|
||||
v, ok := ctx.Value(ctxProjectID).(uuid.UUID)
|
||||
return v, ok
|
||||
}
|
||||
@@ -0,0 +1,265 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/diff"
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/service"
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/store"
|
||||
)
|
||||
|
||||
// --- shared test doubles (also used by api_test.go / tenant_test.go) ---
|
||||
|
||||
// stubSessions is a configurable SessionManager test double.
|
||||
type stubSessions struct {
|
||||
validateFn func(ctx context.Context, token string) (uuid.UUID, error)
|
||||
}
|
||||
|
||||
func (s stubSessions) Create(context.Context, uuid.UUID) (string, time.Time, error) {
|
||||
return "stub-token", time.Now().Add(time.Hour), nil
|
||||
}
|
||||
func (s stubSessions) Validate(ctx context.Context, token string) (uuid.UUID, error) {
|
||||
return s.validateFn(ctx, token)
|
||||
}
|
||||
func (s stubSessions) Destroy(context.Context, string) error { return nil }
|
||||
|
||||
// alwaysValidSessions builds a stubSessions whose Validate always succeeds
|
||||
// with the given fixed user ID — for tests that only care about behavior
|
||||
// past the RequireAuth boundary.
|
||||
func alwaysValidSessions(uid uuid.UUID) stubSessions {
|
||||
return stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) { return uid, nil }}
|
||||
}
|
||||
|
||||
// stubAuthStore is a configurable AuthStore test double. Methods besides
|
||||
// GetProjectOwned return zero values — tests exercising register/login/me
|
||||
// use their own dedicated mock (mockAuthStore in auth_test.go).
|
||||
type stubAuthStore struct {
|
||||
getProjectOwnedFn func(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error)
|
||||
}
|
||||
|
||||
func (s stubAuthStore) RegisterUser(context.Context, string, string) (store.User, store.Project, error) {
|
||||
return store.User{}, store.Project{}, nil
|
||||
}
|
||||
func (s stubAuthStore) GetUserByEmail(context.Context, string) (store.User, error) {
|
||||
return store.User{}, nil
|
||||
}
|
||||
func (s stubAuthStore) GetUserByID(context.Context, uuid.UUID) (store.User, error) {
|
||||
return store.User{}, nil
|
||||
}
|
||||
func (s stubAuthStore) GetUserProject(context.Context, uuid.UUID) (store.Project, error) {
|
||||
return store.Project{}, nil
|
||||
}
|
||||
func (s stubAuthStore) GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) {
|
||||
return s.getProjectOwnedFn(ctx, projectID, userID)
|
||||
}
|
||||
|
||||
// alwaysOwnedAuthStore builds a stubAuthStore whose GetProjectOwned always
|
||||
// succeeds, treating the caller as the owner of whatever project id is
|
||||
// requested — for tests that only care about behavior past the
|
||||
// RequireProjectAccess boundary (CRUD/check/apply tests).
|
||||
func alwaysOwnedAuthStore() stubAuthStore {
|
||||
return stubAuthStore{getProjectOwnedFn: func(_ context.Context, pid, uid uuid.UUID) (store.Project, error) {
|
||||
return store.Project{ID: pid, UserID: uid}, nil
|
||||
}}
|
||||
}
|
||||
|
||||
// requestWithSessionCookie builds an httptest request carrying a session
|
||||
// cookie so it clears RequireAuth in tests using stubSessions/alwaysValidSessions
|
||||
// (which ignore the token value and validate unconditionally).
|
||||
func requestWithSessionCookie(method, url string, body io.Reader) *http.Request {
|
||||
req := httptest.NewRequest(method, url, body)
|
||||
req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "test-session-token"})
|
||||
return req
|
||||
}
|
||||
|
||||
// recordingCheckApplier is a CheckApplier test double that records whether
|
||||
// Check/Apply were invoked — used by the IDOR regression test to assert the
|
||||
// service layer is never reached for a project the caller doesn't own.
|
||||
type recordingCheckApplier struct {
|
||||
checkCalled bool
|
||||
applyCalled bool
|
||||
}
|
||||
|
||||
func (r *recordingCheckApplier) Check(context.Context, uuid.UUID, uuid.UUID) (diff.Changeset, error) {
|
||||
r.checkCalled = true
|
||||
return diff.Changeset{}, nil
|
||||
}
|
||||
func (r *recordingCheckApplier) Apply(context.Context, uuid.UUID, uuid.UUID, service.ApplyRequest) (diff.Changeset, error) {
|
||||
r.applyCalled = true
|
||||
return diff.Changeset{}, nil
|
||||
}
|
||||
|
||||
// --- RequireAuth ---
|
||||
|
||||
func TestRequireAuth_NoCookie_Returns401(t *testing.T) {
|
||||
a := &API{Sessions: stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) {
|
||||
t.Fatal("Validate must not be called when there is no cookie")
|
||||
return uuid.Nil, nil
|
||||
}}}
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/whatever", nil)
|
||||
w := httptest.NewRecorder()
|
||||
a.RequireAuth(next).ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d", w.Code)
|
||||
}
|
||||
if nextCalled {
|
||||
t.Fatal("next must not be called without a session cookie")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuth_InvalidToken_Returns401(t *testing.T) {
|
||||
a := &API{Sessions: stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) {
|
||||
return uuid.Nil, errors.New("invalid or expired session")
|
||||
}}}
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true })
|
||||
|
||||
req := requestWithSessionCookie(http.MethodGet, "/whatever", nil)
|
||||
w := httptest.NewRecorder()
|
||||
a.RequireAuth(next).ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d", w.Code)
|
||||
}
|
||||
if nextCalled {
|
||||
t.Fatal("next must not be called when Sessions.Validate fails")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuth_Success_CallsNextWithUserID(t *testing.T) {
|
||||
uid := uuid.New()
|
||||
a := &API{Sessions: alwaysValidSessions(uid)}
|
||||
|
||||
var gotUID uuid.UUID
|
||||
var gotOK bool
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotUID, gotOK = userIDFrom(r.Context())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
req := requestWithSessionCookie(http.MethodGet, "/whatever", nil)
|
||||
w := httptest.NewRecorder()
|
||||
a.RequireAuth(next).ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 (next called), got %d", w.Code)
|
||||
}
|
||||
if !gotOK || gotUID != uid {
|
||||
t.Fatalf("expected userIDFrom to yield %s, got %s (ok=%v)", uid, gotUID, gotOK)
|
||||
}
|
||||
}
|
||||
|
||||
// --- RequireProjectAccess ---
|
||||
|
||||
func TestRequireProjectAccess_NotOwned_Returns404(t *testing.T) {
|
||||
a := &API{Auth: stubAuthStore{getProjectOwnedFn: func(context.Context, uuid.UUID, uuid.UUID) (store.Project, error) {
|
||||
return store.Project{}, errors.New("no rows")
|
||||
}}}
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true })
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.With(a.RequireProjectAccess).Get("/projects/{pid}", next.ServeHTTP)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/projects/"+uuid.New().String(), nil)
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxUserID, uuid.New()))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("expected 404 for an unowned/foreign project, got %d body %s", w.Code, w.Body.String())
|
||||
}
|
||||
if nextCalled {
|
||||
t.Fatal("next must not be called when the project isn't owned by the caller")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireProjectAccess_Success_CallsNextWithProjectID(t *testing.T) {
|
||||
a := &API{Auth: alwaysOwnedAuthStore()}
|
||||
pid := uuid.New()
|
||||
uid := uuid.New()
|
||||
|
||||
var gotPID uuid.UUID
|
||||
var gotOK bool
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPID, gotOK = projectIDFrom(r.Context())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.With(a.RequireProjectAccess).Get("/projects/{pid}", next.ServeHTTP)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/projects/"+pid.String(), nil)
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxUserID, uid))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 (next called), got %d body %s", w.Code, w.Body.String())
|
||||
}
|
||||
if !gotOK || gotPID != pid {
|
||||
t.Fatalf("expected projectIDFrom to yield %s, got %s (ok=%v)", pid, gotPID, gotOK)
|
||||
}
|
||||
}
|
||||
|
||||
// --- IDOR regression ---
|
||||
|
||||
// TestIDOR_CheckForeignProject_Returns404AndServiceNotCalled is the IDOR
|
||||
// regression required by Task 4: a request for project B's domain, made
|
||||
// with user A's session (who owns project A, not B), must be rejected by
|
||||
// RequireProjectAccess with 404 before service.Check is ever invoked — a
|
||||
// caller must never be able to check/apply another tenant's domain by
|
||||
// guessing/leaking its ids.
|
||||
func TestIDOR_CheckForeignProject_Returns404AndServiceNotCalled(t *testing.T) {
|
||||
userA := uuid.New()
|
||||
pidA := uuid.New()
|
||||
pidB := uuid.New() // owned by a different user; not A
|
||||
|
||||
svc := &recordingCheckApplier{}
|
||||
auth := stubAuthStore{getProjectOwnedFn: func(_ context.Context, pid, uid uuid.UUID) (store.Project, error) {
|
||||
if pid == pidA && uid == userA {
|
||||
return store.Project{ID: pidA, UserID: userA}, nil
|
||||
}
|
||||
return store.Project{}, errors.New("not found")
|
||||
}}
|
||||
a := &API{Svc: svc, Auth: auth, Sessions: alwaysValidSessions(userA)}
|
||||
router := NewRouter(a)
|
||||
|
||||
did := uuid.New().String()
|
||||
req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+pidB.String()+"/domains/"+did+"/check", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("expected 404 for user A requesting project B's domain, got %d body %s", w.Code, w.Body.String())
|
||||
}
|
||||
if svc.checkCalled {
|
||||
t.Fatal("service.Check must not be called when the caller doesn't own the project")
|
||||
}
|
||||
|
||||
// Sanity check: the same user against their own project succeeds and
|
||||
// does reach the service — proving the 404 above is really about
|
||||
// project ownership, not e.g. a broken route.
|
||||
req2 := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+pidA.String()+"/domains/"+did+"/check", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w2, req2)
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 for user A requesting their own project, got %d body %s", w2.Code, w2.Body.String())
|
||||
}
|
||||
if !svc.checkCalled {
|
||||
t.Fatal("expected service.Check to be called for the caller's own project")
|
||||
}
|
||||
}
|
||||
@@ -24,11 +24,9 @@ func decodeBody(w http.ResponseWriter, r *http.Request, v any) bool {
|
||||
// --- accounts ---
|
||||
|
||||
func (a *API) handleCreateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
var req accountRequest
|
||||
if !decodeBody(w, r, &req) {
|
||||
return
|
||||
@@ -53,11 +51,9 @@ func (a *API) handleCreateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (a *API) handleListAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
accs, err := a.Store.ListAccounts(r.Context(), pid)
|
||||
if err != nil {
|
||||
log.Printf("api: list accounts failed: %v", err)
|
||||
@@ -72,11 +68,9 @@ func (a *API) handleListAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (a *API) handleDeleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
aid, err := uuid.Parse(chi.URLParam(r, "aid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid account id")
|
||||
@@ -93,11 +87,9 @@ func (a *API) handleDeleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
// handleImportZones lists zones from the provider for the given account and
|
||||
// creates one domain per zone (template_id left unset).
|
||||
func (a *API) handleImportZones(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
aid, err := uuid.Parse(chi.URLParam(r, "aid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid account id")
|
||||
@@ -145,11 +137,9 @@ func (a *API) handleImportZones(w http.ResponseWriter, r *http.Request) {
|
||||
// --- templates ---
|
||||
|
||||
func (a *API) handleCreateTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
var req templateRequest
|
||||
if !decodeBody(w, r, &req) {
|
||||
return
|
||||
@@ -169,11 +159,9 @@ func (a *API) handleCreateTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (a *API) handleListTemplates(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
tpls, err := a.Store.ListTemplates(r.Context(), pid)
|
||||
if err != nil {
|
||||
log.Printf("api: list templates failed: %v", err)
|
||||
@@ -188,11 +176,9 @@ func (a *API) handleListTemplates(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (a *API) handleUpdateTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
tid, err := uuid.Parse(chi.URLParam(r, "tid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid template id")
|
||||
@@ -217,11 +203,9 @@ func (a *API) handleUpdateTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (a *API) handleDeleteTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
tid, err := uuid.Parse(chi.URLParam(r, "tid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid template id")
|
||||
@@ -238,11 +222,9 @@ func (a *API) handleDeleteTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
// --- domains ---
|
||||
|
||||
func (a *API) handleCreateDomain(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
var req domainRequest
|
||||
if !decodeBody(w, r, &req) {
|
||||
return
|
||||
@@ -286,11 +268,9 @@ func (a *API) handleCreateDomain(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (a *API) handleListDomains(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
doms, err := a.Store.ListDomains(r.Context(), pid)
|
||||
if err != nil {
|
||||
log.Printf("api: list domains failed: %v", err)
|
||||
@@ -308,11 +288,9 @@ func (a *API) handleListDomains(w http.ResponseWriter, r *http.Request) {
|
||||
// check/apply a domain — this is what makes an imported domain (which
|
||||
// starts with template_id=NULL) checkable, closing the import→check loop.
|
||||
func (a *API) handleSetDomainTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
did, err := uuid.Parse(chi.URLParam(r, "did"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid domain id")
|
||||
@@ -339,11 +317,9 @@ func (a *API) handleSetDomainTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (a *API) handleDeleteDomain(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid project id")
|
||||
return
|
||||
}
|
||||
// pid is guaranteed present and owned by the caller — RequireProjectAccess
|
||||
// validated it before this handler ever runs.
|
||||
pid, _ := projectIDFrom(r.Context())
|
||||
did, err := uuid.Parse(chi.URLParam(r, "did"))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid domain id")
|
||||
|
||||
+29
-19
@@ -132,7 +132,9 @@ func (m *mockTenantStore) ImportDomains(_ context.Context, projectID, accountID
|
||||
|
||||
type mockCipher struct{}
|
||||
|
||||
func (mockCipher) Encrypt(plaintext []byte) (string, error) { return "ENC(" + string(plaintext) + ")", nil }
|
||||
func (mockCipher) Encrypt(plaintext []byte) (string, error) {
|
||||
return "ENC(" + string(plaintext) + ")", nil
|
||||
}
|
||||
func (mockCipher) Decrypt(enc string) ([]byte, error) {
|
||||
return []byte(strings.TrimSuffix(strings.TrimPrefix(enc, "ENC("), ")")), nil
|
||||
}
|
||||
@@ -160,9 +162,17 @@ func (mockProvider) ApplyChanges(context.Context, provider.Credentials, string,
|
||||
return nil
|
||||
}
|
||||
|
||||
// newTenantTestAPI wires a fixed authenticated user who owns whatever
|
||||
// project id is requested (alwaysOwnedAuthStore/alwaysValidSessions, see
|
||||
// middleware_test.go) — these tests exercise CRUD behavior past the
|
||||
// RequireAuth/RequireProjectAccess boundary, which has its own dedicated
|
||||
// coverage in middleware_test.go.
|
||||
func newTenantTestAPI() (*API, *mockTenantStore) {
|
||||
ts := &mockTenantStore{}
|
||||
a := &API{Store: ts, Cipher: mockCipher{}, Reg: &mockRegistry{}}
|
||||
a := &API{
|
||||
Store: ts, Cipher: mockCipher{}, Reg: &mockRegistry{},
|
||||
Auth: alwaysOwnedAuthStore(), Sessions: alwaysValidSessions(uuid.New()),
|
||||
}
|
||||
return a, ts
|
||||
}
|
||||
|
||||
@@ -173,7 +183,7 @@ func TestCreateAccount_SecretEncryptedAndNotInResponse(t *testing.T) {
|
||||
router := NewRouter(a)
|
||||
|
||||
body := `{"provider":"selectel","secret":"super-secret-token","comment":"prod"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts", strings.NewReader(body))
|
||||
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -211,7 +221,7 @@ func TestListAccounts_NoSecretsInResponse(t *testing.T) {
|
||||
}
|
||||
router := NewRouter(a)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/projects/"+testPID+"/accounts", nil)
|
||||
req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+testPID+"/accounts", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -234,7 +244,7 @@ func TestDeleteAccount_BadUUID(t *testing.T) {
|
||||
a, _ := newTenantTestAPI()
|
||||
router := NewRouter(a)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid", nil)
|
||||
req := requestWithSessionCookie(http.MethodDelete, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -250,7 +260,7 @@ func TestCreateTemplate_SavesRecords(t *testing.T) {
|
||||
router := NewRouter(a)
|
||||
|
||||
body := `{"name":"base","records":[{"type":"A","name":"@","ttl":300,"values":["1.2.3.4"]}]}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/templates", strings.NewReader(body))
|
||||
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/templates", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -278,7 +288,7 @@ func TestUpdateTemplate_BadUUID(t *testing.T) {
|
||||
router := NewRouter(a)
|
||||
|
||||
body := `{"name":"x","records":[]}`
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/projects/"+testPID+"/templates/not-a-uuid", strings.NewReader(body))
|
||||
req := requestWithSessionCookie(http.MethodPut, "/api/v1/projects/"+testPID+"/templates/not-a-uuid", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -299,7 +309,7 @@ func TestImportZones_CreatesDomainPerZone(t *testing.T) {
|
||||
}}
|
||||
router := NewRouter(a)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil)
|
||||
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -337,7 +347,7 @@ func TestImportZones_AtomicRollbackOnError(t *testing.T) {
|
||||
}}
|
||||
router := NewRouter(a)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil)
|
||||
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -356,7 +366,7 @@ func TestImportZones_BadAccountUUID(t *testing.T) {
|
||||
a, _ := newTenantTestAPI()
|
||||
router := NewRouter(a)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid/import", nil)
|
||||
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid/import", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -370,7 +380,7 @@ func TestCreateDomain_BadProjectUUID(t *testing.T) {
|
||||
router := NewRouter(a)
|
||||
|
||||
body := `{"providerAccountId":"` + uuid.New().String() + `","zoneName":"example.com","zoneId":"z1"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/not-a-uuid/domains", strings.NewReader(body))
|
||||
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/not-a-uuid/domains", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -390,7 +400,7 @@ func TestCreateDomain_AccountNotFoundInProject(t *testing.T) {
|
||||
// ts.accounts is empty, so GetAccount will not find this id.
|
||||
foreignAccID := uuid.New()
|
||||
body := `{"providerAccountId":"` + foreignAccID.String() + `","zoneName":"example.com","zoneId":"z1"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
|
||||
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -413,7 +423,7 @@ func TestCreateDomain_TemplateNotFoundInProject(t *testing.T) {
|
||||
|
||||
foreignTplID := uuid.New()
|
||||
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + foreignTplID.String() + `"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
|
||||
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -434,7 +444,7 @@ func TestCreateDomain_HappyPath(t *testing.T) {
|
||||
router := NewRouter(a)
|
||||
|
||||
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
|
||||
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -464,7 +474,7 @@ func TestCreateDomain_ValidTemplateInProject(t *testing.T) {
|
||||
router := NewRouter(a)
|
||||
|
||||
body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + tplID.String() + `"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
|
||||
req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -490,7 +500,7 @@ func TestSetDomainTemplate_ValidTemplateId(t *testing.T) {
|
||||
router := NewRouter(a)
|
||||
|
||||
body := `{"templateId":"` + tplID.String() + `"}`
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
|
||||
req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -513,7 +523,7 @@ func TestSetDomainTemplate_BadTemplateUUID(t *testing.T) {
|
||||
router := NewRouter(a)
|
||||
|
||||
body := `{"templateId":"not-a-uuid"}`
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
|
||||
req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -530,7 +540,7 @@ func TestSetDomainTemplate_TemplateNotFound(t *testing.T) {
|
||||
router := NewRouter(a)
|
||||
|
||||
body := `{"templateId":"` + uuid.New().String() + `"}`
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
|
||||
req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
@@ -543,7 +553,7 @@ func TestDeleteDomain_BadUUID(t *testing.T) {
|
||||
a, _ := newTenantTestAPI()
|
||||
router := NewRouter(a)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/projects/"+testPID+"/domains/not-a-uuid", nil)
|
||||
req := requestWithSessionCookie(http.MethodDelete, "/api/v1/projects/"+testPID+"/domains/not-a-uuid", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
const (
|
||||
argonTime = 1
|
||||
argonMemory = 64 * 1024
|
||||
argonThreads = 4
|
||||
argonKeyLen = 32
|
||||
argonSaltLen = 16
|
||||
|
||||
// Upper bounds guard against a corrupted/attacker-controlled
|
||||
// password_hash forcing an oversized argon2 computation (DoS).
|
||||
argonMaxMemoryKiB = 1 << 21 // 2 GiB in KiB
|
||||
argonMaxTime = 10
|
||||
argonMaxThreads = 16
|
||||
)
|
||||
|
||||
// paramsRe strictly matches the "m=<n>,t=<n>,p=<n>" parameter segment,
|
||||
// requiring the whole segment to be consumed (no trailing garbage).
|
||||
var paramsRe = regexp.MustCompile(`^m=([0-9]+),t=([0-9]+),p=([0-9]+)$`)
|
||||
|
||||
func HashPassword(password string) (string, error) {
|
||||
salt := make([]byte, argonSaltLen)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return "", err
|
||||
}
|
||||
key := argon2.IDKey([]byte(password), salt, argonTime, argonMemory, argonThreads, argonKeyLen)
|
||||
b64 := base64.RawStdEncoding.EncodeToString
|
||||
return fmt.Sprintf("$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
|
||||
argonMemory, argonTime, argonThreads, b64(salt), b64(key)), nil
|
||||
}
|
||||
|
||||
func VerifyPassword(encoded, password string) (bool, error) {
|
||||
parts := strings.Split(encoded, "$")
|
||||
if len(parts) != 6 || parts[1] != "argon2id" {
|
||||
return false, fmt.Errorf("auth: bad hash format")
|
||||
}
|
||||
if parts[2] != "v=19" {
|
||||
return false, fmt.Errorf("auth: unsupported version")
|
||||
}
|
||||
|
||||
matches := paramsRe.FindStringSubmatch(parts[3])
|
||||
if matches == nil {
|
||||
return false, fmt.Errorf("auth: bad hash format")
|
||||
}
|
||||
m64, err := strconv.ParseUint(matches[1], 10, 32)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("auth: bad hash format")
|
||||
}
|
||||
t64, err := strconv.ParseUint(matches[2], 10, 32)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("auth: bad hash format")
|
||||
}
|
||||
p64, err := strconv.ParseUint(matches[3], 10, 8)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("auth: bad hash format")
|
||||
}
|
||||
m, t, p := uint32(m64), uint32(t64), uint8(p64)
|
||||
|
||||
// argon2.IDKey panics if time<1 ("number of rounds too small") or
|
||||
// threads<1 ("parallelism degree too low"). Reject before calling it.
|
||||
// Minimum memory per argon2 spec is 8*parallelism (in KiB).
|
||||
if t < 1 || p < 1 || m < 8*uint32(p) {
|
||||
return false, fmt.Errorf("auth: bad hash params")
|
||||
}
|
||||
// Upper bounds guard against DoS via inflated parameters in a
|
||||
// corrupted or attacker-controlled stored hash.
|
||||
if m > argonMaxMemoryKiB || t > argonMaxTime || p > argonMaxThreads {
|
||||
return false, fmt.Errorf("auth: bad hash params")
|
||||
}
|
||||
|
||||
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
want, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
got := argon2.IDKey([]byte(password), salt, t, m, p, uint32(len(want)))
|
||||
return subtle.ConstantTimeCompare(got, want) == 1, nil
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
package auth
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestHashVerifyRoundTrip(t *testing.T) {
|
||||
h, err := HashPassword("s3cret-pw")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if h == "s3cret-pw" || len(h) < 20 {
|
||||
t.Fatalf("bad hash %q", h)
|
||||
}
|
||||
ok, err := VerifyPassword(h, "s3cret-pw")
|
||||
if err != nil || !ok {
|
||||
t.Fatalf("verify failed: %v %v", ok, err)
|
||||
}
|
||||
bad, _ := VerifyPassword(h, "wrong")
|
||||
if bad {
|
||||
t.Fatal("wrong password must not verify")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashNonDeterministic(t *testing.T) {
|
||||
a, _ := HashPassword("same")
|
||||
b, _ := HashPassword("same")
|
||||
if a == b {
|
||||
t.Fatal("salt must randomize hash")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyPasswordBadTimeDoesNotPanic(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("VerifyPassword panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
encoded := "$argon2id$v=19$m=65536,t=0,p=4$c29tZXNhbHRzb21lc2FsdA$c29tZWhhc2hzb21laGFzaA"
|
||||
ok, err := VerifyPassword(encoded, "anything")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for t=0, got nil")
|
||||
}
|
||||
if ok {
|
||||
t.Fatal("expected ok=false for t=0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyPasswordBadThreadsDoesNotPanic(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("VerifyPassword panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
encoded := "$argon2id$v=19$m=65536,t=1,p=0$c29tZXNhbHRzb21lc2FsdA$c29tZWhhc2hzb21laGFzaA"
|
||||
ok, err := VerifyPassword(encoded, "anything")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for p=0, got nil")
|
||||
}
|
||||
if ok {
|
||||
t.Fatal("expected ok=false for p=0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyPasswordUnsupportedVersion(t *testing.T) {
|
||||
encoded := "$argon2id$v=18$m=65536,t=1,p=4$c29tZXNhbHRzb21lc2FsdA$c29tZWhhc2hzb21laGFzaA"
|
||||
ok, err := VerifyPassword(encoded, "anything")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unsupported version, got nil")
|
||||
}
|
||||
if ok {
|
||||
t.Fatal("expected ok=false for unsupported version")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyPasswordGarbageFormatDoesNotPanic(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("VerifyPassword panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
ok, err := VerifyPassword("notahash", "anything")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for garbage format, got nil")
|
||||
}
|
||||
if ok {
|
||||
t.Fatal("expected ok=false for garbage format")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var ErrNoSession = errors.New("auth: no such session")
|
||||
|
||||
type SessionStore interface {
|
||||
CreateSession(ctx context.Context, userID uuid.UUID, tokenHash string, expiresAt time.Time) error
|
||||
GetSessionUser(ctx context.Context, tokenHash string) (uuid.UUID, error)
|
||||
DeleteSession(ctx context.Context, tokenHash string) error
|
||||
}
|
||||
|
||||
type Sessions struct {
|
||||
store SessionStore
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
func NewSessions(store SessionStore, ttl time.Duration) *Sessions {
|
||||
return &Sessions{store: store, ttl: ttl}
|
||||
}
|
||||
|
||||
func TokenHash(token string) string {
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func (s *Sessions) Create(ctx context.Context, userID uuid.UUID) (string, time.Time, error) {
|
||||
raw := make([]byte, 32)
|
||||
if _, err := rand.Read(raw); err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
token := base64.RawURLEncoding.EncodeToString(raw)
|
||||
exp := time.Now().Add(s.ttl)
|
||||
if err := s.store.CreateSession(ctx, userID, TokenHash(token), exp); err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
return token, exp, nil
|
||||
}
|
||||
|
||||
func (s *Sessions) Validate(ctx context.Context, token string) (uuid.UUID, error) {
|
||||
return s.store.GetSessionUser(ctx, TokenHash(token))
|
||||
}
|
||||
|
||||
func (s *Sessions) Destroy(ctx context.Context, token string) error {
|
||||
return s.store.DeleteSession(ctx, TokenHash(token))
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type memStore struct {
|
||||
byHash map[string]uuid.UUID
|
||||
exp map[string]time.Time
|
||||
}
|
||||
|
||||
func newMem() *memStore { return &memStore{byHash: map[string]uuid.UUID{}, exp: map[string]time.Time{}} }
|
||||
func (m *memStore) CreateSession(_ context.Context, uid uuid.UUID, h string, e time.Time) error {
|
||||
m.byHash[h] = uid
|
||||
m.exp[h] = e
|
||||
return nil
|
||||
}
|
||||
func (m *memStore) GetSessionUser(_ context.Context, h string) (uuid.UUID, error) {
|
||||
uid, ok := m.byHash[h]
|
||||
if !ok || time.Now().After(m.exp[h]) {
|
||||
return uuid.Nil, ErrNoSession
|
||||
}
|
||||
return uid, nil
|
||||
}
|
||||
func (m *memStore) DeleteSession(_ context.Context, h string) error { delete(m.byHash, h); return nil }
|
||||
|
||||
func TestSessionCreateValidateDestroy(t *testing.T) {
|
||||
s := NewSessions(newMem(), time.Hour)
|
||||
uid := uuid.New()
|
||||
token, exp, err := s.Create(context.Background(), uid)
|
||||
if err != nil || token == "" || exp.Before(time.Now()) {
|
||||
t.Fatalf("create: %v %q", err, token)
|
||||
}
|
||||
got, err := s.Validate(context.Background(), token)
|
||||
if err != nil || got != uid {
|
||||
t.Fatalf("validate: %v %v", got, err)
|
||||
}
|
||||
if err := s.Destroy(context.Background(), token); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := s.Validate(context.Background(), token); err == nil {
|
||||
t.Fatal("destroyed session must not validate")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUnknownToken(t *testing.T) {
|
||||
s := NewSessions(newMem(), time.Hour)
|
||||
if _, err := s.Validate(context.Background(), "nope"); err == nil {
|
||||
t.Fatal("unknown token must error")
|
||||
}
|
||||
}
|
||||
@@ -21,7 +21,7 @@ type DomainRef struct {
|
||||
}
|
||||
|
||||
type Loader interface {
|
||||
LoadDomain(ctx context.Context, domainID uuid.UUID) (DomainRef, error)
|
||||
LoadDomain(ctx context.Context, projectID, domainID uuid.UUID) (DomainRef, error)
|
||||
}
|
||||
|
||||
type Recorder interface {
|
||||
@@ -45,8 +45,10 @@ func New(loader Loader, rec Recorder, reg *registry.Registry, cipher *crypto.Cip
|
||||
}
|
||||
|
||||
// resolve loads the domain, its provider and decrypted credentials, and computes the diff.
|
||||
func (s *DomainService) resolve(ctx context.Context, domainID uuid.UUID) (provider.Provider, provider.Credentials, DomainRef, diff.Changeset, error) {
|
||||
ref, err := s.loader.LoadDomain(ctx, domainID)
|
||||
// projectID scopes the lookup so a domainID belonging to another tenant's
|
||||
// project can never be resolved here (closes IDOR).
|
||||
func (s *DomainService) resolve(ctx context.Context, projectID, domainID uuid.UUID) (provider.Provider, provider.Credentials, DomainRef, diff.Changeset, error) {
|
||||
ref, err := s.loader.LoadDomain(ctx, projectID, domainID)
|
||||
if err != nil {
|
||||
return nil, provider.Credentials{}, ref, diff.Changeset{}, err
|
||||
}
|
||||
@@ -68,8 +70,8 @@ func (s *DomainService) resolve(ctx context.Context, domainID uuid.UUID) (provid
|
||||
}
|
||||
|
||||
// Check computes and records the diff between template and zone.
|
||||
func (s *DomainService) Check(ctx context.Context, domainID uuid.UUID) (diff.Changeset, error) {
|
||||
_, _, _, cs, err := s.resolve(ctx, domainID)
|
||||
func (s *DomainService) Check(ctx context.Context, projectID, domainID uuid.UUID) (diff.Changeset, error) {
|
||||
_, _, _, cs, err := s.resolve(ctx, projectID, domainID)
|
||||
if err != nil {
|
||||
return diff.Changeset{}, err
|
||||
}
|
||||
@@ -80,8 +82,8 @@ func (s *DomainService) Check(ctx context.Context, domainID uuid.UUID) (diff.Cha
|
||||
}
|
||||
|
||||
// Apply applies updates always (when ApplyUpdates) and prunes only when ApplyPrunes.
|
||||
func (s *DomainService) Apply(ctx context.Context, domainID uuid.UUID, req ApplyRequest) (diff.Changeset, error) {
|
||||
p, creds, ref, cs, err := s.resolve(ctx, domainID)
|
||||
func (s *DomainService) Apply(ctx context.Context, projectID, domainID uuid.UUID, req ApplyRequest) (diff.Changeset, error) {
|
||||
p, creds, ref, cs, err := s.resolve(ctx, projectID, domainID)
|
||||
if err != nil {
|
||||
return diff.Changeset{}, err
|
||||
}
|
||||
|
||||
@@ -44,7 +44,9 @@ func (f *fakeProvider) ApplyChanges(_ context.Context, _ provider.Credentials, _
|
||||
|
||||
type fakeLoader struct{ ref DomainRef }
|
||||
|
||||
func (l fakeLoader) LoadDomain(context.Context, uuid.UUID) (DomainRef, error) { return l.ref, nil }
|
||||
func (l fakeLoader) LoadDomain(context.Context, uuid.UUID, uuid.UUID) (DomainRef, error) {
|
||||
return l.ref, nil
|
||||
}
|
||||
|
||||
type nopRecorder struct{}
|
||||
|
||||
@@ -66,7 +68,7 @@ func TestCheckProducesDiff(t *testing.T) {
|
||||
{Type: "A", Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}}, // update
|
||||
}}
|
||||
svc, _ := setup(t, actual, tmpl)
|
||||
cs, err := svc.Check(context.Background(), uuid.New())
|
||||
cs, err := svc.Check(context.Background(), uuid.New(), uuid.New())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -87,7 +89,7 @@ func TestApplyRespectsPruneGuard(t *testing.T) {
|
||||
|
||||
// applyPrunes=false → удаление b НЕ применяется
|
||||
svc, fp := setup(t, actual, tmpl)
|
||||
if _, err := svc.Apply(context.Background(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: false}); err != nil {
|
||||
if _, err := svc.Apply(context.Background(), uuid.New(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: false}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, d := range fp.applied.Diffs {
|
||||
@@ -98,7 +100,7 @@ func TestApplyRespectsPruneGuard(t *testing.T) {
|
||||
|
||||
// applyPrunes=true → удаление b применяется
|
||||
svc2, fp2 := setup(t, actual, tmpl)
|
||||
if _, err := svc2.Apply(context.Background(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: true}); err != nil {
|
||||
if _, err := svc2.Apply(context.Background(), uuid.New(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: true}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var sawDelete bool
|
||||
|
||||
@@ -0,0 +1,185 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
// TestRegisterUser_CreatesUserAndOwnedProject verifies the RegisterUser
|
||||
// transaction: a user and a default project are created together, and the
|
||||
// project belongs to that user.
|
||||
func TestRegisterUser_CreatesUserAndOwnedProject(t *testing.T) {
|
||||
s, ctx := newStore(t)
|
||||
|
||||
u, p, err := s.RegisterUser(ctx, "alice@example.com", "argon2-hash")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if u.Email != "alice@example.com" || u.PasswordHash != "argon2-hash" {
|
||||
t.Fatalf("unexpected user: %+v", u)
|
||||
}
|
||||
if p.UserID != u.ID {
|
||||
t.Fatalf("expected project to belong to user %s, got %+v", u.ID, p)
|
||||
}
|
||||
|
||||
owned, err := s.GetProjectOwned(ctx, p.ID, u.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if owned.ID != p.ID {
|
||||
t.Fatalf("expected owned project %s, got %+v", p.ID, owned)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUserByEmail_FindsRegisteredUser verifies email lookup returns the
|
||||
// same user created by RegisterUser.
|
||||
func TestGetUserByEmail_FindsRegisteredUser(t *testing.T) {
|
||||
s, ctx := newStore(t)
|
||||
|
||||
u, _, err := s.RegisterUser(ctx, "bob@example.com", "argon2-hash")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got, err := s.GetUserByEmail(ctx, "bob@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got.ID != u.ID || got.PasswordHash != "argon2-hash" {
|
||||
t.Fatalf("unexpected user: %+v", got)
|
||||
}
|
||||
|
||||
if _, err := s.GetUserByEmail(ctx, "nobody@example.com"); err == nil {
|
||||
t.Fatal("expected error for unknown email, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegisterUser_DuplicateEmailReturnsErrEmailTaken verifies the fix for
|
||||
// the duplicate-registration gap: a second RegisterUser call for an
|
||||
// already-taken email must fail with the ErrEmailTaken sentinel (mapped from
|
||||
// the UNIQUE constraint violation on users.email), not a generic pgx error.
|
||||
func TestRegisterUser_DuplicateEmailReturnsErrEmailTaken(t *testing.T) {
|
||||
s, ctx := newStore(t)
|
||||
|
||||
if _, _, err := s.RegisterUser(ctx, "dup@example.com", "argon2-hash"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, _, err := s.RegisterUser(ctx, "dup@example.com", "argon2-hash"); !errors.Is(err, ErrEmailTaken) {
|
||||
t.Fatalf("expected ErrEmailTaken, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUserByID_ReturnsUser verifies the fix for the /me gap: GetUserByID
|
||||
// returns the same user created by RegisterUser, including their real email.
|
||||
func TestGetUserByID_ReturnsUser(t *testing.T) {
|
||||
s, ctx := newStore(t)
|
||||
|
||||
u, _, err := s.RegisterUser(ctx, "gina@example.com", "argon2-hash")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got, err := s.GetUserByID(ctx, u.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got.ID != u.ID || got.Email != "gina@example.com" {
|
||||
t.Fatalf("unexpected user: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionLifecycle_CreateGetDelete verifies CreateSession + GetSessionUser
|
||||
// round-trips to the owning user ID, an expired session is excluded from
|
||||
// GetSessionUser, and DeleteSession removes the session.
|
||||
func TestSessionLifecycle_CreateGetDelete(t *testing.T) {
|
||||
s, ctx := newStore(t)
|
||||
|
||||
u, _, err := s.RegisterUser(ctx, "carol@example.com", "argon2-hash")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tokenHash := "sha256-token-hash"
|
||||
if err := s.CreateSession(ctx, u.ID, tokenHash, time.Now().Add(time.Hour)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
gotUserID, err := s.GetSessionUser(ctx, tokenHash)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if gotUserID != u.ID {
|
||||
t.Fatalf("expected user %s, got %s", u.ID, gotUserID)
|
||||
}
|
||||
|
||||
if err := s.DeleteSession(ctx, tokenHash); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := s.GetSessionUser(ctx, tokenHash); err == nil {
|
||||
t.Fatal("expected error after DeleteSession, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetSessionUser_ExpiredSessionNotReturned verifies the query's
|
||||
// expires_at > now() condition: a session created with an expiry in the
|
||||
// past must not be returned by GetSessionUser.
|
||||
func TestGetSessionUser_ExpiredSessionNotReturned(t *testing.T) {
|
||||
s, ctx := newStore(t)
|
||||
|
||||
u, _, err := s.RegisterUser(ctx, "dave@example.com", "argon2-hash")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tokenHash := "expired-token-hash"
|
||||
if err := s.CreateSession(ctx, u.ID, tokenHash, time.Now().Add(-time.Hour)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := s.GetSessionUser(ctx, tokenHash); err == nil {
|
||||
t.Fatal("expected expired session to not be returned, got nil error")
|
||||
} else if err != pgx.ErrNoRows {
|
||||
t.Fatalf("expected pgx.ErrNoRows, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetProjectOwned_ForeignUserRejected verifies that looking up a project
|
||||
// with the wrong user ID fails, so one tenant cannot address another
|
||||
// tenant's project by guessing its ID.
|
||||
func TestGetProjectOwned_ForeignUserRejected(t *testing.T) {
|
||||
s, ctx := newStore(t)
|
||||
|
||||
_, p, err := s.RegisterUser(ctx, "erin@example.com", "argon2-hash")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
foreignUserID := uuid.New()
|
||||
if _, err := s.GetProjectOwned(ctx, p.ID, foreignUserID); err == nil {
|
||||
t.Fatal("expected error for foreign user ID, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUserProject_ReturnsTheUsersProject verifies GetUserProject returns
|
||||
// the project created for that user by RegisterUser.
|
||||
func TestGetUserProject_ReturnsTheUsersProject(t *testing.T) {
|
||||
s, ctx := newStore(t)
|
||||
|
||||
u, p, err := s.RegisterUser(ctx, "frank@example.com", "argon2-hash")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got, err := s.GetUserProject(ctx, u.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got.ID != p.ID {
|
||||
t.Fatalf("expected project %s, got %+v", p.ID, got)
|
||||
}
|
||||
}
|
||||
@@ -162,9 +162,14 @@ SELECT d.zone_id, a.provider, a.secret_enc, t.doc
|
||||
FROM domains d
|
||||
JOIN provider_accounts a ON a.id = d.provider_account_id
|
||||
LEFT JOIN templates t ON t.id = d.template_id
|
||||
WHERE d.id = $1
|
||||
WHERE d.id = $1 AND d.project_id = $2
|
||||
`
|
||||
|
||||
type LoadDomainFullParams struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
ProjectID uuid.UUID `json:"project_id"`
|
||||
}
|
||||
|
||||
type LoadDomainFullRow struct {
|
||||
ZoneID string `json:"zone_id"`
|
||||
Provider string `json:"provider"`
|
||||
@@ -172,8 +177,8 @@ type LoadDomainFullRow struct {
|
||||
Doc *dto.TemplateDoc `json:"doc"`
|
||||
}
|
||||
|
||||
func (q *Queries) LoadDomainFull(ctx context.Context, id uuid.UUID) (LoadDomainFullRow, error) {
|
||||
row := q.db.QueryRow(ctx, loadDomainFull, id)
|
||||
func (q *Queries) LoadDomainFull(ctx context.Context, arg LoadDomainFullParams) (LoadDomainFullRow, error) {
|
||||
row := q.db.QueryRow(ctx, loadDomainFull, arg.ID, arg.ProjectID)
|
||||
var i LoadDomainFullRow
|
||||
err := row.Scan(
|
||||
&i.ZoneID,
|
||||
|
||||
@@ -43,6 +43,14 @@ type ProviderAccount struct {
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
TokenHash string `json:"token_hash"`
|
||||
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
}
|
||||
|
||||
type Template struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
ProjectID uuid.UUID `json:"project_id"`
|
||||
@@ -57,4 +65,5 @@ type User struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
PasswordHash *string `json:"password_hash"`
|
||||
}
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.31.1
|
||||
// source: projects.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const createProject = `-- name: CreateProject :one
|
||||
INSERT INTO projects (id, user_id, name) VALUES ($1, $2, $3) RETURNING id, user_id, name, created_at
|
||||
`
|
||||
|
||||
type CreateProjectParams struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateProject(ctx context.Context, arg CreateProjectParams) (Project, error) {
|
||||
row := q.db.QueryRow(ctx, createProject, arg.ID, arg.UserID, arg.Name)
|
||||
var i Project
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.Name,
|
||||
&i.CreatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getProjectOwned = `-- name: GetProjectOwned :one
|
||||
SELECT id, user_id, name, created_at FROM projects WHERE id = $1 AND user_id = $2
|
||||
`
|
||||
|
||||
type GetProjectOwnedParams struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetProjectOwned(ctx context.Context, arg GetProjectOwnedParams) (Project, error) {
|
||||
row := q.db.QueryRow(ctx, getProjectOwned, arg.ID, arg.UserID)
|
||||
var i Project
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.Name,
|
||||
&i.CreatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserProject = `-- name: GetUserProject :one
|
||||
SELECT id, user_id, name, created_at FROM projects WHERE user_id = $1 ORDER BY created_at LIMIT 1
|
||||
`
|
||||
|
||||
func (q *Queries) GetUserProject(ctx context.Context, userID uuid.UUID) (Project, error) {
|
||||
row := q.db.QueryRow(ctx, getUserProject, userID)
|
||||
var i Project
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.Name,
|
||||
&i.CreatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.31.1
|
||||
// source: sessions.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const createSession = `-- name: CreateSession :exec
|
||||
INSERT INTO sessions (id, user_id, token_hash, expires_at) VALUES ($1, $2, $3, $4)
|
||||
`
|
||||
|
||||
type CreateSessionParams struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
TokenHash string `json:"token_hash"`
|
||||
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) error {
|
||||
_, err := q.db.Exec(ctx, createSession,
|
||||
arg.ID,
|
||||
arg.UserID,
|
||||
arg.TokenHash,
|
||||
arg.ExpiresAt,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteSession = `-- name: DeleteSession :exec
|
||||
DELETE FROM sessions WHERE token_hash = $1
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteSession(ctx context.Context, tokenHash string) error {
|
||||
_, err := q.db.Exec(ctx, deleteSession, tokenHash)
|
||||
return err
|
||||
}
|
||||
|
||||
const getSessionUser = `-- name: GetSessionUser :one
|
||||
SELECT user_id FROM sessions WHERE token_hash = $1 AND expires_at > now()
|
||||
`
|
||||
|
||||
func (q *Queries) GetSessionUser(ctx context.Context, tokenHash string) (uuid.UUID, error) {
|
||||
row := q.db.QueryRow(ctx, getSessionUser, tokenHash)
|
||||
var user_id uuid.UUID
|
||||
err := row.Scan(&user_id)
|
||||
return user_id, err
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.31.1
|
||||
// source: users.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const createUser = `-- name: CreateUser :one
|
||||
INSERT INTO users (id, email, password_hash) VALUES ($1, $2, $3) RETURNING id, email, created_at, password_hash
|
||||
`
|
||||
|
||||
type CreateUserParams struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
PasswordHash *string `json:"password_hash"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) {
|
||||
row := q.db.QueryRow(ctx, createUser, arg.ID, arg.Email, arg.PasswordHash)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Email,
|
||||
&i.CreatedAt,
|
||||
&i.PasswordHash,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserByEmail = `-- name: GetUserByEmail :one
|
||||
SELECT id, email, created_at, password_hash FROM users WHERE email = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error) {
|
||||
row := q.db.QueryRow(ctx, getUserByEmail, email)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Email,
|
||||
&i.CreatedAt,
|
||||
&i.PasswordHash,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserByID = `-- name: GetUserByID :one
|
||||
SELECT id, email, created_at, password_hash FROM users WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetUserByID(ctx context.Context, id uuid.UUID) (User, error) {
|
||||
row := q.db.QueryRow(ctx, getUserByID, id)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Email,
|
||||
&i.CreatedAt,
|
||||
&i.PasswordHash,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -13,9 +13,11 @@ import (
|
||||
)
|
||||
|
||||
// LoadDomain joins domains+provider_accounts+templates to build the
|
||||
// service.DomainRef needed to check/apply a domain's DNS records.
|
||||
func (s *Store) LoadDomain(ctx context.Context, domainID uuid.UUID) (service.DomainRef, error) {
|
||||
row, err := s.q.LoadDomainFull(ctx, domainID)
|
||||
// service.DomainRef needed to check/apply a domain's DNS records. Scoped by
|
||||
// projectID so a domain belonging to another tenant's project can never be
|
||||
// loaded, even if its domainID is guessed/leaked (closes IDOR).
|
||||
func (s *Store) LoadDomain(ctx context.Context, projectID, domainID uuid.UUID) (service.DomainRef, error) {
|
||||
row, err := s.q.LoadDomainFull(ctx, db.LoadDomainFullParams{ID: domainID, ProjectID: projectID})
|
||||
if err != nil {
|
||||
return service.DomainRef{}, err
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ func TestLoadDomainAndSaveCheckRun(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ref, err := s.LoadDomain(ctx, domain.ID)
|
||||
ref, err := s.LoadDomain(ctx, defaultProject, domain.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -87,7 +87,7 @@ func TestLoadDomainNoTemplate(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := s.LoadDomain(ctx, domain.ID); err == nil {
|
||||
if _, err := s.LoadDomain(ctx, defaultProject, domain.ID); err == nil {
|
||||
t.Fatal("expected error for domain without template, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
-- +goose Up
|
||||
ALTER TABLE users ADD COLUMN password_hash text;
|
||||
|
||||
CREATE TABLE sessions (
|
||||
id uuid PRIMARY KEY,
|
||||
user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
token_hash text NOT NULL UNIQUE,
|
||||
expires_at timestamptz NOT NULL,
|
||||
created_at timestamptz NOT NULL DEFAULT now()
|
||||
);
|
||||
CREATE INDEX sessions_token_hash_idx ON sessions (token_hash);
|
||||
|
||||
-- +goose Down
|
||||
DROP TABLE sessions;
|
||||
ALTER TABLE users DROP COLUMN password_hash;
|
||||
@@ -27,4 +27,4 @@ SELECT d.zone_id, a.provider, a.secret_enc, t.doc
|
||||
FROM domains d
|
||||
JOIN provider_accounts a ON a.id = d.provider_account_id
|
||||
LEFT JOIN templates t ON t.id = d.template_id
|
||||
WHERE d.id = $1;
|
||||
WHERE d.id = $1 AND d.project_id = $2;
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
-- name: CreateProject :one
|
||||
INSERT INTO projects (id, user_id, name) VALUES ($1, $2, $3) RETURNING *;
|
||||
|
||||
-- name: GetProjectOwned :one
|
||||
SELECT * FROM projects WHERE id = $1 AND user_id = $2;
|
||||
|
||||
-- name: GetUserProject :one
|
||||
SELECT * FROM projects WHERE user_id = $1 ORDER BY created_at LIMIT 1;
|
||||
@@ -0,0 +1,8 @@
|
||||
-- name: CreateSession :exec
|
||||
INSERT INTO sessions (id, user_id, token_hash, expires_at) VALUES ($1, $2, $3, $4);
|
||||
|
||||
-- name: GetSessionUser :one
|
||||
SELECT user_id FROM sessions WHERE token_hash = $1 AND expires_at > now();
|
||||
|
||||
-- name: DeleteSession :exec
|
||||
DELETE FROM sessions WHERE token_hash = $1;
|
||||
@@ -0,0 +1,8 @@
|
||||
-- name: CreateUser :one
|
||||
INSERT INTO users (id, email, password_hash) VALUES ($1, $2, $3) RETURNING *;
|
||||
|
||||
-- name: GetUserByEmail :one
|
||||
SELECT * FROM users WHERE email = $1;
|
||||
|
||||
-- name: GetUserByID :one
|
||||
SELECT * FROM users WHERE id = $1;
|
||||
@@ -214,7 +214,7 @@ func TestSetDomainTemplate_ClosesImportCheckLoop(t *testing.T) {
|
||||
dom := doms[0]
|
||||
|
||||
// Before binding, the domain is not checkable.
|
||||
if _, err := s.LoadDomain(ctx, dom.ID); err == nil {
|
||||
if _, err := s.LoadDomain(ctx, defaultProject, dom.ID); err == nil {
|
||||
t.Fatal("expected LoadDomain to fail before a template is bound")
|
||||
}
|
||||
|
||||
@@ -234,7 +234,7 @@ func TestSetDomainTemplate_ClosesImportCheckLoop(t *testing.T) {
|
||||
t.Fatalf("expected domain.TemplateID=%s, got %+v", tpl.ID, updated.TemplateID)
|
||||
}
|
||||
|
||||
ref, err := s.LoadDomain(ctx, dom.ID)
|
||||
ref, err := s.LoadDomain(ctx, defaultProject, dom.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("expected LoadDomain to succeed after binding template, got error: %v", err)
|
||||
}
|
||||
|
||||
@@ -3,15 +3,23 @@ package store
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/provider"
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/store/db"
|
||||
"github.com/vasyakrg/dns-autoresolver/internal/store/dto"
|
||||
)
|
||||
|
||||
// ErrEmailTaken is returned by RegisterUser when the email is already
|
||||
// registered — a UNIQUE constraint violation (pgerrcode 23505) on
|
||||
// users.email.
|
||||
var ErrEmailTaken = errors.New("store: email already registered")
|
||||
|
||||
// Account/Template/Domain are provider-neutral domain structs returned by the
|
||||
// thin wrappers below, so callers (internal/api) never need to import
|
||||
// internal/store/db directly.
|
||||
@@ -222,3 +230,142 @@ func (s *Store) SetDomainTemplate(ctx context.Context, domainID, projectID uuid.
|
||||
}
|
||||
return domainFromDB(d), nil
|
||||
}
|
||||
|
||||
// User and Project are provider-neutral domain structs for the auth/tenant
|
||||
// layer (Фаза 2), mirroring the Account/Template/Domain wrappers above so
|
||||
// callers never need to import internal/store/db directly.
|
||||
|
||||
type User struct {
|
||||
ID uuid.UUID
|
||||
Email string
|
||||
PasswordHash string
|
||||
}
|
||||
|
||||
type Project struct {
|
||||
ID uuid.UUID
|
||||
UserID uuid.UUID
|
||||
Name string
|
||||
}
|
||||
|
||||
// ptr is a small helper for passing a Go string into a nullable text column
|
||||
// (password_hash) via sqlc's generated *string param type.
|
||||
func ptr(s string) *string { return &s }
|
||||
|
||||
// strFromPtr converts a nullable text column back into a plain string; a
|
||||
// nil password_hash never happens on the real registration flow (an argon2
|
||||
// hash is always supplied), but is handled defensively here.
|
||||
func strFromPtr(p *string) string {
|
||||
if p == nil {
|
||||
return ""
|
||||
}
|
||||
return *p
|
||||
}
|
||||
|
||||
func toUser(u db.User) User {
|
||||
return User{ID: u.ID, Email: u.Email, PasswordHash: strFromPtr(u.PasswordHash)}
|
||||
}
|
||||
|
||||
func toProject(p db.Project) Project {
|
||||
return Project{ID: p.ID, UserID: p.UserID, Name: p.Name}
|
||||
}
|
||||
|
||||
func (s *Store) CreateUser(ctx context.Context, email, passwordHash string) (User, error) {
|
||||
u, err := s.q.CreateUser(ctx, db.CreateUserParams{ID: uuid.New(), Email: email, PasswordHash: ptr(passwordHash)})
|
||||
if err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
return toUser(u), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetUserByEmail(ctx context.Context, email string) (User, error) {
|
||||
u, err := s.q.GetUserByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
return toUser(u), nil
|
||||
}
|
||||
|
||||
// GetUserByID looks up a user by primary key — used by handleMe (Task 3
|
||||
// hardening) to return the authenticated caller's real email instead of
|
||||
// leaving it blank.
|
||||
func (s *Store) GetUserByID(ctx context.Context, id uuid.UUID) (User, error) {
|
||||
u, err := s.q.GetUserByID(ctx, id)
|
||||
if err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
return toUser(u), nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateProjectForUser(ctx context.Context, userID uuid.UUID, name string) (Project, error) {
|
||||
p, err := s.q.CreateProject(ctx, db.CreateProjectParams{ID: uuid.New(), UserID: userID, Name: name})
|
||||
if err != nil {
|
||||
return Project{}, err
|
||||
}
|
||||
return toProject(p), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (Project, error) {
|
||||
p, err := s.q.GetProjectOwned(ctx, db.GetProjectOwnedParams{ID: projectID, UserID: userID})
|
||||
if err != nil {
|
||||
return Project{}, err
|
||||
}
|
||||
return toProject(p), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetUserProject(ctx context.Context, userID uuid.UUID) (Project, error) {
|
||||
p, err := s.q.GetUserProject(ctx, userID)
|
||||
if err != nil {
|
||||
return Project{}, err
|
||||
}
|
||||
return toProject(p), nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateSession(ctx context.Context, userID uuid.UUID, tokenHash string, expiresAt time.Time) error {
|
||||
return s.q.CreateSession(ctx, db.CreateSessionParams{
|
||||
ID: uuid.New(),
|
||||
UserID: userID,
|
||||
TokenHash: tokenHash,
|
||||
ExpiresAt: pgtype.Timestamptz{Time: expiresAt, Valid: true},
|
||||
})
|
||||
}
|
||||
|
||||
// GetSessionUser returns the owning user ID for a non-expired session token;
|
||||
// expired sessions are excluded by the query itself (expires_at > now()).
|
||||
func (s *Store) GetSessionUser(ctx context.Context, tokenHash string) (uuid.UUID, error) {
|
||||
return s.q.GetSessionUser(ctx, tokenHash)
|
||||
}
|
||||
|
||||
func (s *Store) DeleteSession(ctx context.Context, tokenHash string) error {
|
||||
return s.q.DeleteSession(ctx, tokenHash)
|
||||
}
|
||||
|
||||
// RegisterUser creates a user and their default project in one transaction,
|
||||
// mirroring the ImportDomains pattern above: if project creation fails, the
|
||||
// user insert is rolled back too, so a caller never observes a user without
|
||||
// a default project.
|
||||
func (s *Store) RegisterUser(ctx context.Context, email, passwordHash string) (User, Project, error) {
|
||||
tx, err := s.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return User{}, Project{}, err
|
||||
}
|
||||
defer tx.Rollback(ctx) // no-op once Commit has succeeded
|
||||
|
||||
q := s.q.WithTx(tx)
|
||||
uid := uuid.New()
|
||||
dbu, err := q.CreateUser(ctx, db.CreateUserParams{ID: uid, Email: email, PasswordHash: ptr(passwordHash)})
|
||||
if err != nil {
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
|
||||
return User{}, Project{}, ErrEmailTaken
|
||||
}
|
||||
return User{}, Project{}, err
|
||||
}
|
||||
dbp, err := q.CreateProject(ctx, db.CreateProjectParams{ID: uuid.New(), UserID: uid, Name: "default"})
|
||||
if err != nil {
|
||||
return User{}, Project{}, err
|
||||
}
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return User{}, Project{}, err
|
||||
}
|
||||
return toUser(dbu), toProject(dbp), nil
|
||||
}
|
||||
|
||||
+13
-2
@@ -1,19 +1,30 @@
|
||||
import { render, screen, within } from "@testing-library/react"
|
||||
import { MemoryRouter } from "react-router-dom"
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
|
||||
import { vi, test, expect } from "vitest"
|
||||
import { App } from "./App"
|
||||
import { AuthProvider } from "@/auth/AuthContext"
|
||||
import { api } from "@/api/client"
|
||||
|
||||
test("renders navigation and redirects to domains", async () => {
|
||||
vi.spyOn(api.auth, "me").mockResolvedValue({
|
||||
user: { id: "u1", email: "a@b.com" },
|
||||
project: { id: "p1", name: "Default" },
|
||||
})
|
||||
vi.spyOn(api, "listDomains").mockResolvedValue([])
|
||||
|
||||
test("renders navigation and redirects to domains", () => {
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AuthProvider>
|
||||
<MemoryRouter initialEntries={["/"]}>
|
||||
<App />
|
||||
</MemoryRouter>
|
||||
</AuthProvider>
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
// Sidebar nav also renders a "Domains" link label, so scope the assertion
|
||||
// to the routed page content to unambiguously confirm the redirect + page.
|
||||
const main = screen.getByRole("main")
|
||||
const main = await screen.findByRole("main")
|
||||
expect(within(main).getByText("Domains")).toBeInTheDocument()
|
||||
expect(screen.getByRole("link", { name: /domains/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
+20
-6
@@ -1,20 +1,34 @@
|
||||
import type { ReactNode } from "react"
|
||||
import { Routes, Route, Navigate } from "react-router-dom"
|
||||
import { ProtectedRoute } from "@/auth/ProtectedRoute"
|
||||
import { Layout } from "@/components/Layout"
|
||||
import { AccountsPage } from "@/pages/AccountsPage"
|
||||
import { DomainDiffPage } from "@/pages/DomainDiffPage"
|
||||
import { DomainsPage } from "@/pages/DomainsPage"
|
||||
import { LoginPage } from "@/pages/LoginPage"
|
||||
import { RegisterPage } from "@/pages/RegisterPage"
|
||||
import { TemplatesPage } from "@/pages/TemplatesPage"
|
||||
|
||||
// Every non-auth route shares the same guard + chrome; wrapping here keeps
|
||||
// each <Route> below a one-liner instead of repeating both on every page.
|
||||
function Protected({ children }: { children: ReactNode }) {
|
||||
return (
|
||||
<ProtectedRoute>
|
||||
<Layout>{children}</Layout>
|
||||
</ProtectedRoute>
|
||||
)
|
||||
}
|
||||
|
||||
export function App() {
|
||||
return (
|
||||
<Layout>
|
||||
<Routes>
|
||||
<Route path="/login" element={<LoginPage />} />
|
||||
<Route path="/register" element={<RegisterPage />} />
|
||||
<Route path="/" element={<Navigate to="/domains" replace />} />
|
||||
<Route path="/domains" element={<DomainsPage />} />
|
||||
<Route path="/domains/:id" element={<DomainDiffPage />} />
|
||||
<Route path="/accounts" element={<AccountsPage />} />
|
||||
<Route path="/templates" element={<TemplatesPage />} />
|
||||
<Route path="/domains" element={<Protected><DomainsPage /></Protected>} />
|
||||
<Route path="/domains/:id" element={<Protected><DomainDiffPage /></Protected>} />
|
||||
<Route path="/accounts" element={<Protected><AccountsPage /></Protected>} />
|
||||
<Route path="/templates" element={<Protected><TemplatesPage /></Protected>} />
|
||||
</Routes>
|
||||
</Layout>
|
||||
)
|
||||
}
|
||||
|
||||
+95
-10
@@ -1,6 +1,7 @@
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest"
|
||||
import { api } from "./client"
|
||||
import { DEFAULT_PROJECT_ID } from "@/lib/config"
|
||||
import { api, UnauthorizedError } from "./client"
|
||||
|
||||
const PROJECT_ID = "11111111-1111-1111-1111-111111111111"
|
||||
|
||||
beforeEach(() => { vi.restoreAllMocks() })
|
||||
|
||||
@@ -13,19 +14,72 @@ function mockFetch(body: unknown, ok = true, status = 200) {
|
||||
}
|
||||
|
||||
describe("api client", () => {
|
||||
it("lists accounts at project-scoped path", async () => {
|
||||
it("sends credentials:include on every request", async () => {
|
||||
const spy = mockFetch([])
|
||||
await api.listAccounts(PROJECT_ID)
|
||||
const [, opts] = spy.mock.calls[0]
|
||||
expect((opts as RequestInit).credentials).toBe("include")
|
||||
})
|
||||
|
||||
describe("api.auth", () => {
|
||||
it("login POSTs to /api/v1/auth/login with credentials:include", async () => {
|
||||
const spy = mockFetch({ user: { id: "u1", email: "a@b.com" }, project: { id: "p1", name: "Default" } })
|
||||
await api.auth.login("a@b.com", "secret")
|
||||
const [url, opts] = spy.mock.calls[0]
|
||||
expect(url).toBe("/api/v1/auth/login")
|
||||
expect((opts as RequestInit).method).toBe("POST")
|
||||
expect((opts as RequestInit).credentials).toBe("include")
|
||||
expect(String((opts as RequestInit).body)).toContain("secret")
|
||||
})
|
||||
|
||||
it("register POSTs to /api/v1/auth/register", async () => {
|
||||
const spy = mockFetch({ user: { id: "u1", email: "a@b.com" }, project: { id: "p1", name: "Default" } })
|
||||
await api.auth.register("a@b.com", "secret")
|
||||
const [url, opts] = spy.mock.calls[0]
|
||||
expect(url).toBe("/api/v1/auth/register")
|
||||
expect((opts as RequestInit).method).toBe("POST")
|
||||
})
|
||||
|
||||
it("logout POSTs to /api/v1/auth/logout", async () => {
|
||||
const spy = mockFetch(undefined, true, 204)
|
||||
await api.auth.logout()
|
||||
const [url, opts] = spy.mock.calls[0]
|
||||
expect(url).toBe("/api/v1/auth/logout")
|
||||
expect((opts as RequestInit).method).toBe("POST")
|
||||
})
|
||||
|
||||
it("me GETs /api/v1/auth/me and returns AuthState", async () => {
|
||||
const state = { user: { id: "u1", email: "a@b.com" }, project: { id: "p1", name: "Default" } }
|
||||
const spy = mockFetch(state)
|
||||
const result = await api.auth.me()
|
||||
const [url] = spy.mock.calls[0]
|
||||
expect(url).toBe("/api/v1/auth/me")
|
||||
expect(result).toEqual(state)
|
||||
})
|
||||
})
|
||||
|
||||
it("resource methods hit project-scoped path with projectId first", async () => {
|
||||
const spy = mockFetch([{ id: "a1", provider: "selectel", comment: "" }])
|
||||
const accounts = await api.listAccounts()
|
||||
const accounts = await api.listAccounts(PROJECT_ID)
|
||||
expect(accounts).toHaveLength(1)
|
||||
expect(spy).toHaveBeenCalledWith(
|
||||
`/api/v1/projects/${DEFAULT_PROJECT_ID}/accounts`,
|
||||
`/api/v1/projects/${PROJECT_ID}/accounts`,
|
||||
expect.objectContaining({ method: "GET" }),
|
||||
)
|
||||
})
|
||||
|
||||
it("listDomains(projectId) hits /api/v1/projects/{projectId}/domains", async () => {
|
||||
const spy = mockFetch([])
|
||||
await api.listDomains(PROJECT_ID)
|
||||
expect(spy).toHaveBeenCalledWith(
|
||||
`/api/v1/projects/${PROJECT_ID}/domains`,
|
||||
expect.objectContaining({ method: "GET" }),
|
||||
)
|
||||
})
|
||||
|
||||
it("sends secret on account creation but path has no secret leakage in response typing", async () => {
|
||||
const spy = mockFetch({ id: "a2", provider: "selectel", comment: "prod" })
|
||||
await api.createAccount({ provider: "selectel", secret: "TOKEN", comment: "prod" })
|
||||
await api.createAccount(PROJECT_ID, { provider: "selectel", secret: "TOKEN", comment: "prod" })
|
||||
const [, opts] = spy.mock.calls[0]
|
||||
expect((opts as RequestInit).method).toBe("POST")
|
||||
expect(String((opts as RequestInit).body)).toContain("TOKEN")
|
||||
@@ -33,14 +87,45 @@ describe("api client", () => {
|
||||
|
||||
it("throws on non-ok response", async () => {
|
||||
mockFetch({ error: "boom" }, false, 500)
|
||||
await expect(api.listDomains()).rejects.toThrow()
|
||||
await expect(api.listDomains(PROJECT_ID)).rejects.toThrow()
|
||||
})
|
||||
|
||||
it("applies with prune flag", async () => {
|
||||
it("throws UnauthorizedError on 401", async () => {
|
||||
mockFetch({ error: "unauthorized" }, false, 401)
|
||||
await expect(api.listDomains(PROJECT_ID)).rejects.toThrow(UnauthorizedError)
|
||||
})
|
||||
|
||||
it("applies with prune flag using projectId, id, body order", async () => {
|
||||
const spy = mockFetch({ updates: [], prunes: [], readOnly: [], inSyncCount: 0 })
|
||||
await api.applyDomain("d1", { applyUpdates: true, applyPrunes: true })
|
||||
await api.applyDomain(PROJECT_ID, "d1", { applyUpdates: true, applyPrunes: true })
|
||||
const [url, opts] = spy.mock.calls[0]
|
||||
expect(url).toContain("/domains/d1/apply")
|
||||
expect(url).toBe(`/api/v1/projects/${PROJECT_ID}/domains/d1/apply`)
|
||||
expect(String((opts as RequestInit).body)).toContain("applyPrunes")
|
||||
})
|
||||
|
||||
it("checkDomain(projectId, id) hits project-scoped check path", async () => {
|
||||
const spy = mockFetch({ updates: [], prunes: [], readOnly: [], inSyncCount: 0 })
|
||||
await api.checkDomain(PROJECT_ID, "d1")
|
||||
expect(spy).toHaveBeenCalledWith(
|
||||
`/api/v1/projects/${PROJECT_ID}/domains/d1/check`,
|
||||
expect.objectContaining({ method: "GET" }),
|
||||
)
|
||||
})
|
||||
|
||||
it("importZones(projectId, accountId) hits project-scoped import path", async () => {
|
||||
const spy = mockFetch([])
|
||||
await api.importZones(PROJECT_ID, "acc1")
|
||||
expect(spy).toHaveBeenCalledWith(
|
||||
`/api/v1/projects/${PROJECT_ID}/accounts/acc1/import`,
|
||||
expect.objectContaining({ method: "POST" }),
|
||||
)
|
||||
})
|
||||
|
||||
it("setDomainTemplate(projectId, id, templateId) hits project-scoped domain path", async () => {
|
||||
const spy = mockFetch({ id: "d1", providerAccountId: "acc1", zoneName: "x.", zoneId: "z1" })
|
||||
await api.setDomainTemplate(PROJECT_ID, "d1", "t1")
|
||||
const [url, opts] = spy.mock.calls[0]
|
||||
expect(url).toBe(`/api/v1/projects/${PROJECT_ID}/domains/d1`)
|
||||
expect((opts as RequestInit).method).toBe("PATCH")
|
||||
})
|
||||
})
|
||||
|
||||
+60
-27
@@ -1,15 +1,25 @@
|
||||
import { API_BASE } from "@/lib/config"
|
||||
import { API_ROOT } from "@/lib/config"
|
||||
import type {
|
||||
AuthState,
|
||||
Account, CreateAccountInput, Template, CreateTemplateInput,
|
||||
Domain, CreateDomainInput, ChangesetResponse, ApplyRequest,
|
||||
} from "./types"
|
||||
|
||||
export class UnauthorizedError extends Error {
|
||||
constructor() {
|
||||
super("Unauthorized")
|
||||
this.name = "UnauthorizedError"
|
||||
}
|
||||
}
|
||||
|
||||
async function req<T>(path: string, init?: RequestInit): Promise<T> {
|
||||
const res = await fetch(`${API_BASE}${path}`, {
|
||||
const res = await fetch(path, {
|
||||
headers: { "Content-Type": "application/json" },
|
||||
method: "GET",
|
||||
credentials: "include",
|
||||
...init,
|
||||
})
|
||||
if (res.status === 401) throw new UnauthorizedError()
|
||||
if (!res.ok) {
|
||||
let msg = `HTTP ${res.status}`
|
||||
try { const b = await res.json(); if (b?.error) msg = String(b.error) } catch { /* ignore */ }
|
||||
@@ -19,29 +29,52 @@ async function req<T>(path: string, init?: RequestInit): Promise<T> {
|
||||
return (await res.json()) as T
|
||||
}
|
||||
|
||||
export const api = {
|
||||
listAccounts: () => req<Account[]>("/accounts"),
|
||||
createAccount: (input: CreateAccountInput) =>
|
||||
req<Account>("/accounts", { method: "POST", body: JSON.stringify(input) }),
|
||||
deleteAccount: (id: string) => req<void>(`/accounts/${id}`, { method: "DELETE" }),
|
||||
|
||||
listTemplates: () => req<Template[]>("/templates"),
|
||||
createTemplate: (input: CreateTemplateInput) =>
|
||||
req<Template>("/templates", { method: "POST", body: JSON.stringify(input) }),
|
||||
updateTemplate: (id: string, input: CreateTemplateInput) =>
|
||||
req<Template>(`/templates/${id}`, { method: "PUT", body: JSON.stringify(input) }),
|
||||
deleteTemplate: (id: string) => req<void>(`/templates/${id}`, { method: "DELETE" }),
|
||||
|
||||
listDomains: () => req<Domain[]>("/domains"),
|
||||
createDomain: (input: CreateDomainInput) =>
|
||||
req<Domain>("/domains", { method: "POST", body: JSON.stringify(input) }),
|
||||
deleteDomain: (id: string) => req<void>(`/domains/${id}`, { method: "DELETE" }),
|
||||
importZones: (accountId: string) =>
|
||||
req<Domain[]>(`/accounts/${accountId}/import`, { method: "POST" }),
|
||||
setDomainTemplate: (id: string, templateId: string | null) =>
|
||||
req<Domain>(`/domains/${id}`, { method: "PATCH", body: JSON.stringify({ templateId }) }),
|
||||
|
||||
checkDomain: (id: string) => req<ChangesetResponse>(`/domains/${id}/check`),
|
||||
applyDomain: (id: string, body: ApplyRequest) =>
|
||||
req<ChangesetResponse>(`/domains/${id}/apply`, { method: "POST", body: JSON.stringify(body) }),
|
||||
function projectPath(projectId: string, path: string): string {
|
||||
return `${API_ROOT}/projects/${projectId}${path}`
|
||||
}
|
||||
|
||||
export const api = {
|
||||
auth: {
|
||||
register: (email: string, password: string) =>
|
||||
req<AuthState>(`${API_ROOT}/auth/register`, {
|
||||
method: "POST",
|
||||
body: JSON.stringify({ email, password }),
|
||||
}),
|
||||
login: (email: string, password: string) =>
|
||||
req<AuthState>(`${API_ROOT}/auth/login`, {
|
||||
method: "POST",
|
||||
body: JSON.stringify({ email, password }),
|
||||
}),
|
||||
logout: () => req<void>(`${API_ROOT}/auth/logout`, { method: "POST" }),
|
||||
me: () => req<AuthState>(`${API_ROOT}/auth/me`),
|
||||
},
|
||||
|
||||
listAccounts: (projectId: string) => req<Account[]>(projectPath(projectId, "/accounts")),
|
||||
createAccount: (projectId: string, input: CreateAccountInput) =>
|
||||
req<Account>(projectPath(projectId, "/accounts"), { method: "POST", body: JSON.stringify(input) }),
|
||||
deleteAccount: (projectId: string, id: string) =>
|
||||
req<void>(projectPath(projectId, `/accounts/${id}`), { method: "DELETE" }),
|
||||
|
||||
listTemplates: (projectId: string) => req<Template[]>(projectPath(projectId, "/templates")),
|
||||
createTemplate: (projectId: string, input: CreateTemplateInput) =>
|
||||
req<Template>(projectPath(projectId, "/templates"), { method: "POST", body: JSON.stringify(input) }),
|
||||
updateTemplate: (projectId: string, id: string, input: CreateTemplateInput) =>
|
||||
req<Template>(projectPath(projectId, `/templates/${id}`), { method: "PUT", body: JSON.stringify(input) }),
|
||||
deleteTemplate: (projectId: string, id: string) =>
|
||||
req<void>(projectPath(projectId, `/templates/${id}`), { method: "DELETE" }),
|
||||
|
||||
listDomains: (projectId: string) => req<Domain[]>(projectPath(projectId, "/domains")),
|
||||
createDomain: (projectId: string, input: CreateDomainInput) =>
|
||||
req<Domain>(projectPath(projectId, "/domains"), { method: "POST", body: JSON.stringify(input) }),
|
||||
deleteDomain: (projectId: string, id: string) =>
|
||||
req<void>(projectPath(projectId, `/domains/${id}`), { method: "DELETE" }),
|
||||
importZones: (projectId: string, accountId: string) =>
|
||||
req<Domain[]>(projectPath(projectId, `/accounts/${accountId}/import`), { method: "POST" }),
|
||||
setDomainTemplate: (projectId: string, id: string, templateId: string | null) =>
|
||||
req<Domain>(projectPath(projectId, `/domains/${id}`), { method: "PATCH", body: JSON.stringify({ templateId }) }),
|
||||
|
||||
checkDomain: (projectId: string, id: string) =>
|
||||
req<ChangesetResponse>(projectPath(projectId, `/domains/${id}/check`)),
|
||||
applyDomain: (projectId: string, id: string, body: ApplyRequest) =>
|
||||
req<ChangesetResponse>(projectPath(projectId, `/domains/${id}/apply`), { method: "POST", body: JSON.stringify(body) }),
|
||||
}
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
export interface User { id: string; email: string }
|
||||
export interface Project { id: string; name: string }
|
||||
export interface AuthState { user: User; project: Project }
|
||||
|
||||
export interface Account { id: string; provider: string; comment: string }
|
||||
export interface CreateAccountInput { provider: string; secret: string; comment: string }
|
||||
|
||||
|
||||
@@ -0,0 +1,126 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react"
|
||||
import userEvent from "@testing-library/user-event"
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest"
|
||||
import { AuthProvider, notifyUnauthorized, useAuth } from "./AuthContext"
|
||||
import { api, UnauthorizedError } from "@/api/client"
|
||||
|
||||
const USER = { id: "u1", email: "a@b.com" }
|
||||
const PROJECT = { id: "p1", name: "Default" }
|
||||
|
||||
function Probe() {
|
||||
const { user, project, loading, login, register, logout } = useAuth()
|
||||
return (
|
||||
<div>
|
||||
<span data-testid="loading">{String(loading)}</span>
|
||||
<span data-testid="user">{user ? user.email : "none"}</span>
|
||||
<span data-testid="project">{project ? project.name : "none"}</span>
|
||||
<button onClick={() => login("a@b.com", "secret")}>login</button>
|
||||
<button onClick={() => register("a@b.com", "secret")}>register</button>
|
||||
<button onClick={() => logout()}>logout</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
function renderProbe(qc: QueryClient = new QueryClient()) {
|
||||
return render(
|
||||
<QueryClientProvider client={qc}>
|
||||
<AuthProvider>
|
||||
<Probe />
|
||||
</AuthProvider>
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
describe("AuthContext", () => {
|
||||
it("populates user/project from api.auth.me() on mount", async () => {
|
||||
vi.spyOn(api.auth, "me").mockResolvedValue({ user: USER, project: PROJECT })
|
||||
renderProbe()
|
||||
|
||||
expect(screen.getByTestId("loading").textContent).toBe("true")
|
||||
|
||||
await waitFor(() => expect(screen.getByTestId("loading").textContent).toBe("false"))
|
||||
expect(screen.getByTestId("user").textContent).toBe(USER.email)
|
||||
expect(screen.getByTestId("project").textContent).toBe(PROJECT.name)
|
||||
})
|
||||
|
||||
it("treats 401 from api.auth.me() as unauthenticated, not an error", async () => {
|
||||
vi.spyOn(api.auth, "me").mockRejectedValue(new UnauthorizedError())
|
||||
renderProbe()
|
||||
|
||||
await waitFor(() => expect(screen.getByTestId("loading").textContent).toBe("false"))
|
||||
expect(screen.getByTestId("user").textContent).toBe("none")
|
||||
expect(screen.getByTestId("project").textContent).toBe("none")
|
||||
})
|
||||
|
||||
it("login sets user/project in context", async () => {
|
||||
vi.spyOn(api.auth, "me").mockRejectedValue(new UnauthorizedError())
|
||||
vi.spyOn(api.auth, "login").mockResolvedValue({ user: USER, project: PROJECT })
|
||||
const user = userEvent.setup()
|
||||
renderProbe()
|
||||
|
||||
await waitFor(() => expect(screen.getByTestId("loading").textContent).toBe("false"))
|
||||
await user.click(screen.getByRole("button", { name: "login" }))
|
||||
|
||||
await waitFor(() => expect(screen.getByTestId("user").textContent).toBe(USER.email))
|
||||
expect(screen.getByTestId("project").textContent).toBe(PROJECT.name)
|
||||
})
|
||||
|
||||
it("treats a non-401 error from api.auth.me() as logged-out but logs it for diagnostics", async () => {
|
||||
const consoleErrorSpy = vi.spyOn(console, "error").mockImplementation(() => {})
|
||||
const err = new Error("network down")
|
||||
vi.spyOn(api.auth, "me").mockRejectedValue(err)
|
||||
renderProbe()
|
||||
|
||||
await waitFor(() => expect(screen.getByTestId("loading").textContent).toBe("false"))
|
||||
expect(screen.getByTestId("user").textContent).toBe("none")
|
||||
expect(screen.getByTestId("project").textContent).toBe("none")
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(err)
|
||||
})
|
||||
|
||||
it("logout clears user/project from context", async () => {
|
||||
vi.spyOn(api.auth, "me").mockResolvedValue({ user: USER, project: PROJECT })
|
||||
vi.spyOn(api.auth, "logout").mockResolvedValue(undefined)
|
||||
const user = userEvent.setup()
|
||||
renderProbe()
|
||||
|
||||
await waitFor(() => expect(screen.getByTestId("user").textContent).toBe(USER.email))
|
||||
await user.click(screen.getByRole("button", { name: "logout" }))
|
||||
|
||||
await waitFor(() => expect(screen.getByTestId("user").textContent).toBe("none"))
|
||||
expect(screen.getByTestId("project").textContent).toBe("none")
|
||||
})
|
||||
|
||||
it("logout clears the react-query cache", async () => {
|
||||
vi.spyOn(api.auth, "me").mockResolvedValue({ user: USER, project: PROJECT })
|
||||
vi.spyOn(api.auth, "logout").mockResolvedValue(undefined)
|
||||
const qc = new QueryClient()
|
||||
const clearSpy = vi.spyOn(qc, "clear")
|
||||
const user = userEvent.setup()
|
||||
renderProbe(qc)
|
||||
|
||||
await waitFor(() => expect(screen.getByTestId("user").textContent).toBe(USER.email))
|
||||
await user.click(screen.getByRole("button", { name: "logout" }))
|
||||
|
||||
await waitFor(() => expect(clearSpy).toHaveBeenCalled())
|
||||
})
|
||||
|
||||
it("notifyUnauthorized (triggered by any 401 elsewhere in the app) drops the session and clears the cache", async () => {
|
||||
vi.spyOn(api.auth, "me").mockResolvedValue({ user: USER, project: PROJECT })
|
||||
const qc = new QueryClient()
|
||||
const clearSpy = vi.spyOn(qc, "clear")
|
||||
renderProbe(qc)
|
||||
|
||||
await waitFor(() => expect(screen.getByTestId("user").textContent).toBe(USER.email))
|
||||
|
||||
notifyUnauthorized()
|
||||
|
||||
await waitFor(() => expect(screen.getByTestId("user").textContent).toBe("none"))
|
||||
expect(screen.getByTestId("project").textContent).toBe("none")
|
||||
expect(clearSpy).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,113 @@
|
||||
import { createContext, useCallback, useContext, useEffect, useState, type ReactNode } from "react"
|
||||
import { useQueryClient } from "@tanstack/react-query"
|
||||
import { api, UnauthorizedError } from "@/api/client"
|
||||
import type { User, Project } from "@/api/types"
|
||||
|
||||
export interface AuthContextValue {
|
||||
user: User | null
|
||||
project: Project | null
|
||||
loading: boolean
|
||||
login: (email: string, password: string) => Promise<void>
|
||||
register: (email: string, password: string) => Promise<void>
|
||||
logout: () => Promise<void>
|
||||
}
|
||||
|
||||
const AuthContext = createContext<AuthContextValue | undefined>(undefined)
|
||||
|
||||
// AuthProvider registers a handler here so code outside the React tree (the
|
||||
// QueryClient's QueryCache/MutationCache onError, wired up in main.tsx) can
|
||||
// report a 401 from *any* query/mutation and have AuthContext drop the
|
||||
// session — the same "unauthenticated" state ProtectedRoute already reacts
|
||||
// to. There is exactly one AuthProvider in the app, so a single module-level
|
||||
// slot is enough; it's registered/unregistered via useEffect below.
|
||||
type UnauthorizedHandler = () => void
|
||||
let unauthorizedHandler: UnauthorizedHandler | null = null
|
||||
|
||||
export function registerUnauthorizedHandler(handler: UnauthorizedHandler | null) {
|
||||
unauthorizedHandler = handler
|
||||
}
|
||||
|
||||
export function notifyUnauthorized() {
|
||||
unauthorizedHandler?.()
|
||||
}
|
||||
|
||||
export function AuthProvider({ children }: { children: ReactNode }) {
|
||||
const [user, setUser] = useState<User | null>(null)
|
||||
const [project, setProject] = useState<Project | null>(null)
|
||||
const [loading, setLoading] = useState(true)
|
||||
const qc = useQueryClient()
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false
|
||||
api.auth
|
||||
.me()
|
||||
.then((state) => {
|
||||
if (cancelled) return
|
||||
setUser(state.user)
|
||||
setProject(state.project)
|
||||
})
|
||||
.catch((err) => {
|
||||
// Unauthenticated (401) — normal logged-out state, no need to log.
|
||||
// Any other failure (network/500/etc) — still treat as logged-out so
|
||||
// we don't get stuck in loading, but surface it for diagnostics
|
||||
// instead of swallowing it silently. Redirect handling is out of
|
||||
// scope here (Task 6).
|
||||
if (!(err instanceof UnauthorizedError)) {
|
||||
console.error(err)
|
||||
}
|
||||
if (cancelled) return
|
||||
setUser(null)
|
||||
setProject(null)
|
||||
})
|
||||
.finally(() => {
|
||||
if (!cancelled) setLoading(false)
|
||||
})
|
||||
return () => {
|
||||
cancelled = true
|
||||
}
|
||||
}, [])
|
||||
|
||||
const login = useCallback(async (email: string, password: string) => {
|
||||
const state = await api.auth.login(email, password)
|
||||
setUser(state.user)
|
||||
setProject(state.project)
|
||||
}, [])
|
||||
|
||||
const register = useCallback(async (email: string, password: string) => {
|
||||
const state = await api.auth.register(email, password)
|
||||
setUser(state.user)
|
||||
setProject(state.project)
|
||||
}, [])
|
||||
|
||||
const logout = useCallback(async () => {
|
||||
await api.auth.logout()
|
||||
setUser(null)
|
||||
setProject(null)
|
||||
qc.clear()
|
||||
}, [qc])
|
||||
|
||||
// Any query/mutation elsewhere in the app that hits a 401 reports it here
|
||||
// (see notifyUnauthorized/registerUnauthorizedHandler above) — drop the
|
||||
// session the same way logout() would, so ProtectedRoute redirects to
|
||||
// /login instead of the UI silently sitting on stale, now-invalid data.
|
||||
useEffect(() => {
|
||||
registerUnauthorizedHandler(() => {
|
||||
setUser(null)
|
||||
setProject(null)
|
||||
qc.clear()
|
||||
})
|
||||
return () => registerUnauthorizedHandler(null)
|
||||
}, [qc])
|
||||
|
||||
return (
|
||||
<AuthContext.Provider value={{ user, project, loading, login, register, logout }}>
|
||||
{children}
|
||||
</AuthContext.Provider>
|
||||
)
|
||||
}
|
||||
|
||||
export function useAuth(): AuthContextValue {
|
||||
const ctx = useContext(AuthContext)
|
||||
if (!ctx) throw new Error("useAuth must be used within an AuthProvider")
|
||||
return ctx
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
import { render, screen } from "@testing-library/react"
|
||||
import { MemoryRouter, Routes, Route } from "react-router-dom"
|
||||
import { describe, it, expect, vi } from "vitest"
|
||||
import { ProtectedRoute } from "./ProtectedRoute"
|
||||
import * as AuthContextModule from "./AuthContext"
|
||||
|
||||
function renderWithAuth(authValue: Partial<AuthContextModule.AuthContextValue>) {
|
||||
vi.spyOn(AuthContextModule, "useAuth").mockReturnValue({
|
||||
user: null,
|
||||
project: null,
|
||||
loading: false,
|
||||
login: vi.fn(),
|
||||
register: vi.fn(),
|
||||
logout: vi.fn(),
|
||||
...authValue,
|
||||
})
|
||||
|
||||
return render(
|
||||
<MemoryRouter initialEntries={["/domains"]}>
|
||||
<Routes>
|
||||
<Route path="/login" element={<div>login page</div>} />
|
||||
<Route
|
||||
path="/domains"
|
||||
element={
|
||||
<ProtectedRoute>
|
||||
<div>protected content</div>
|
||||
</ProtectedRoute>
|
||||
}
|
||||
/>
|
||||
</Routes>
|
||||
</MemoryRouter>,
|
||||
)
|
||||
}
|
||||
|
||||
describe("ProtectedRoute", () => {
|
||||
it("показывает спиннер, пока идёт проверка сессии", () => {
|
||||
renderWithAuth({ user: null, loading: true })
|
||||
|
||||
expect(screen.queryByText("protected content")).not.toBeInTheDocument()
|
||||
expect(screen.queryByText("login page")).not.toBeInTheDocument()
|
||||
expect(screen.getByRole("status")).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it("редиректит на /login, когда пользователь не авторизован", () => {
|
||||
renderWithAuth({ user: null, loading: false })
|
||||
|
||||
expect(screen.getByText("login page")).toBeInTheDocument()
|
||||
expect(screen.queryByText("protected content")).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it("рендерит children, когда пользователь авторизован", () => {
|
||||
renderWithAuth({ user: { id: "u1", email: "a@b.com" }, loading: false })
|
||||
|
||||
expect(screen.getByText("protected content")).toBeInTheDocument()
|
||||
expect(screen.queryByText("login page")).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,24 @@
|
||||
import type { ReactNode } from "react"
|
||||
import { Navigate } from "react-router-dom"
|
||||
import { Loader2 } from "lucide-react"
|
||||
import { useAuth } from "@/auth/AuthContext"
|
||||
|
||||
export function ProtectedRoute({ children }: { children: ReactNode }) {
|
||||
const { user, loading } = useAuth()
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<div
|
||||
role="status"
|
||||
aria-label="Проверка сессии"
|
||||
className="flex h-screen w-full items-center justify-center bg-background"
|
||||
>
|
||||
<Loader2 className="size-6 animate-spin text-muted-foreground" strokeWidth={1.75} />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (!user) return <Navigate to="/login" replace />
|
||||
|
||||
return <>{children}</>
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { ReactNode } from "react"
|
||||
import { NavLink, useLocation } from "react-router-dom"
|
||||
import { Globe, Users, LayoutTemplate, SquareTerminal } from "lucide-react"
|
||||
import { NavLink, useLocation, useNavigate } from "react-router-dom"
|
||||
import { Globe, LogOut, Users, LayoutTemplate, SquareTerminal } from "lucide-react"
|
||||
import { useAuth } from "@/auth/AuthContext"
|
||||
import { Button } from "@/components/ui/button"
|
||||
import { cn } from "@/lib/utils"
|
||||
|
||||
const NAV = [
|
||||
@@ -11,6 +13,13 @@ const NAV = [
|
||||
|
||||
export function Layout({ children }: { children: ReactNode }) {
|
||||
const location = useLocation()
|
||||
const navigate = useNavigate()
|
||||
const { user, logout } = useAuth()
|
||||
|
||||
async function onLogout() {
|
||||
await logout()
|
||||
navigate("/login", { replace: true })
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex h-screen w-full overflow-hidden bg-background text-foreground">
|
||||
@@ -64,10 +73,19 @@ export function Layout({ children }: { children: ReactNode }) {
|
||||
</aside>
|
||||
|
||||
<div className="flex flex-1 flex-col overflow-hidden">
|
||||
<header className="flex h-11 shrink-0 items-center border-b border-border px-6">
|
||||
<header className="flex h-11 shrink-0 items-center justify-between border-b border-border px-6">
|
||||
<span className="font-dns text-xs text-muted-foreground">
|
||||
{location.pathname}
|
||||
</span>
|
||||
{user && (
|
||||
<div className="flex items-center gap-3">
|
||||
<span className="font-dns text-xs text-muted-foreground">{user.email}</span>
|
||||
<Button variant="ghost" size="sm" onClick={onLogout}>
|
||||
<LogOut className="size-3.5" strokeWidth={1.75} />
|
||||
Выйти
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</header>
|
||||
<main className="flex-1 overflow-auto">{children}</main>
|
||||
</div>
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
import { renderHook, waitFor } from "@testing-library/react"
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest"
|
||||
import type { ReactNode } from "react"
|
||||
import { AuthProvider } from "@/auth/AuthContext"
|
||||
import { api, UnauthorizedError } from "@/api/client"
|
||||
import { useDeleteAccount } from "./useApi"
|
||||
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
function wrapper({ children }: { children: ReactNode }) {
|
||||
const qc = new QueryClient({ defaultOptions: { mutations: { retry: false } } })
|
||||
return (
|
||||
<QueryClientProvider client={qc}>
|
||||
<AuthProvider>{children}</AuthProvider>
|
||||
</QueryClientProvider>
|
||||
)
|
||||
}
|
||||
|
||||
describe("useApi mutations — null project guard", () => {
|
||||
it("mutate() without an active project fails with a clear error, not a TypeError", async () => {
|
||||
// No session yet => AuthContext resolves project to null.
|
||||
vi.spyOn(api.auth, "me").mockRejectedValue(new UnauthorizedError())
|
||||
vi.spyOn(api, "deleteAccount")
|
||||
|
||||
const { result } = renderHook(() => useDeleteAccount(), { wrapper })
|
||||
|
||||
result.current.mutate("acc-1")
|
||||
|
||||
await waitFor(() => expect(result.current.isError).toBe(true))
|
||||
|
||||
expect(result.current.error).toBeInstanceOf(Error)
|
||||
expect(result.current.error).not.toBeInstanceOf(TypeError)
|
||||
expect((result.current.error as Error).message).toBe("no active project")
|
||||
expect(api.deleteAccount).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
+89
-27
@@ -1,81 +1,143 @@
|
||||
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query"
|
||||
import { api } from "@/api/client"
|
||||
import type { CreateAccountInput, CreateTemplateInput, ApplyRequest } from "@/api/types"
|
||||
import { useAuth } from "@/auth/AuthContext"
|
||||
import type { CreateAccountInput, CreateTemplateInput, ApplyRequest, Project } from "@/api/types"
|
||||
|
||||
function requireProjectId(project: Project | null): string {
|
||||
if (!project) throw new Error("no active project")
|
||||
return project.id
|
||||
}
|
||||
|
||||
export function useAccounts() {
|
||||
return useQuery({ queryKey: ["accounts"], queryFn: api.listAccounts })
|
||||
const { project } = useAuth()
|
||||
return useQuery({
|
||||
queryKey: ["accounts", project?.id],
|
||||
queryFn: () => api.listAccounts(project!.id),
|
||||
enabled: !!project,
|
||||
})
|
||||
}
|
||||
export function useCreateAccount() {
|
||||
const { project } = useAuth()
|
||||
const qc = useQueryClient()
|
||||
return useMutation({
|
||||
mutationFn: (input: CreateAccountInput) => api.createAccount(input),
|
||||
onSuccess: () => qc.invalidateQueries({ queryKey: ["accounts"] }),
|
||||
mutationFn: (input: CreateAccountInput) => {
|
||||
const pid = requireProjectId(project)
|
||||
return api.createAccount(pid, input)
|
||||
},
|
||||
onSuccess: () => qc.invalidateQueries({ queryKey: ["accounts", project?.id] }),
|
||||
})
|
||||
}
|
||||
export function useDeleteAccount() {
|
||||
const { project } = useAuth()
|
||||
const qc = useQueryClient()
|
||||
return useMutation({
|
||||
mutationFn: (id: string) => api.deleteAccount(id),
|
||||
onSuccess: () => qc.invalidateQueries({ queryKey: ["accounts"] }),
|
||||
mutationFn: (id: string) => {
|
||||
const pid = requireProjectId(project)
|
||||
return api.deleteAccount(pid, id)
|
||||
},
|
||||
onSuccess: () => qc.invalidateQueries({ queryKey: ["accounts", project?.id] }),
|
||||
})
|
||||
}
|
||||
|
||||
export function useTemplates() {
|
||||
return useQuery({ queryKey: ["templates"], queryFn: api.listTemplates })
|
||||
const { project } = useAuth()
|
||||
return useQuery({
|
||||
queryKey: ["templates", project?.id],
|
||||
queryFn: () => api.listTemplates(project!.id),
|
||||
enabled: !!project,
|
||||
})
|
||||
}
|
||||
export function useCreateTemplate() {
|
||||
const { project } = useAuth()
|
||||
const qc = useQueryClient()
|
||||
return useMutation({
|
||||
mutationFn: (input: CreateTemplateInput) => api.createTemplate(input),
|
||||
onSuccess: () => qc.invalidateQueries({ queryKey: ["templates"] }),
|
||||
mutationFn: (input: CreateTemplateInput) => {
|
||||
const pid = requireProjectId(project)
|
||||
return api.createTemplate(pid, input)
|
||||
},
|
||||
onSuccess: () => qc.invalidateQueries({ queryKey: ["templates", project?.id] }),
|
||||
})
|
||||
}
|
||||
export function useUpdateTemplate() {
|
||||
const { project } = useAuth()
|
||||
const qc = useQueryClient()
|
||||
return useMutation({
|
||||
mutationFn: ({ id, input }: { id: string; input: CreateTemplateInput }) => api.updateTemplate(id, input),
|
||||
onSuccess: () => qc.invalidateQueries({ queryKey: ["templates"] }),
|
||||
mutationFn: ({ id, input }: { id: string; input: CreateTemplateInput }) => {
|
||||
const pid = requireProjectId(project)
|
||||
return api.updateTemplate(pid, id, input)
|
||||
},
|
||||
onSuccess: () => qc.invalidateQueries({ queryKey: ["templates", project?.id] }),
|
||||
})
|
||||
}
|
||||
export function useDeleteTemplate() {
|
||||
const { project } = useAuth()
|
||||
const qc = useQueryClient()
|
||||
return useMutation({
|
||||
mutationFn: (id: string) => api.deleteTemplate(id),
|
||||
onSuccess: () => qc.invalidateQueries({ queryKey: ["templates"] }),
|
||||
mutationFn: (id: string) => {
|
||||
const pid = requireProjectId(project)
|
||||
return api.deleteTemplate(pid, id)
|
||||
},
|
||||
onSuccess: () => qc.invalidateQueries({ queryKey: ["templates", project?.id] }),
|
||||
})
|
||||
}
|
||||
|
||||
export function useDomains() {
|
||||
return useQuery({ queryKey: ["domains"], queryFn: api.listDomains })
|
||||
const { project } = useAuth()
|
||||
return useQuery({
|
||||
queryKey: ["domains", project?.id],
|
||||
queryFn: () => api.listDomains(project!.id),
|
||||
enabled: !!project,
|
||||
})
|
||||
}
|
||||
export function useImportZones() {
|
||||
const { project } = useAuth()
|
||||
const qc = useQueryClient()
|
||||
return useMutation({
|
||||
mutationFn: (accountId: string) => api.importZones(accountId),
|
||||
onSuccess: () => qc.invalidateQueries({ queryKey: ["domains"] }),
|
||||
mutationFn: (accountId: string) => {
|
||||
const pid = requireProjectId(project)
|
||||
return api.importZones(pid, accountId)
|
||||
},
|
||||
onSuccess: () => qc.invalidateQueries({ queryKey: ["domains", project?.id] }),
|
||||
})
|
||||
}
|
||||
export function useSetDomainTemplate() {
|
||||
const { project } = useAuth()
|
||||
const qc = useQueryClient()
|
||||
return useMutation({
|
||||
mutationFn: ({ id, templateId }: { id: string; templateId: string | null }) => api.setDomainTemplate(id, templateId),
|
||||
onSuccess: () => qc.invalidateQueries({ queryKey: ["domains"] }),
|
||||
mutationFn: ({ id, templateId }: { id: string; templateId: string | null }) => {
|
||||
const pid = requireProjectId(project)
|
||||
return api.setDomainTemplate(pid, id, templateId)
|
||||
},
|
||||
onSuccess: () => qc.invalidateQueries({ queryKey: ["domains", project?.id] }),
|
||||
})
|
||||
}
|
||||
export function useDeleteDomain() {
|
||||
const { project } = useAuth()
|
||||
const qc = useQueryClient()
|
||||
return useMutation({
|
||||
mutationFn: (id: string) => api.deleteDomain(id),
|
||||
onSuccess: () => qc.invalidateQueries({ queryKey: ["domains"] }),
|
||||
mutationFn: (id: string) => {
|
||||
const pid = requireProjectId(project)
|
||||
return api.deleteDomain(pid, id)
|
||||
},
|
||||
onSuccess: () => qc.invalidateQueries({ queryKey: ["domains", project?.id] }),
|
||||
})
|
||||
}
|
||||
export function useCheckDomain(id: string) {
|
||||
return useQuery({ queryKey: ["check", id], queryFn: () => api.checkDomain(id), enabled: !!id })
|
||||
}
|
||||
export function useApplyDomain(id: string) {
|
||||
const qc = useQueryClient()
|
||||
return useMutation({
|
||||
mutationFn: (body: ApplyRequest) => api.applyDomain(id, body),
|
||||
onSuccess: () => qc.invalidateQueries({ queryKey: ["check", id] }),
|
||||
const { project } = useAuth()
|
||||
return useQuery({
|
||||
queryKey: ["check", project?.id, id],
|
||||
queryFn: () => api.checkDomain(project!.id, id),
|
||||
enabled: !!project && !!id,
|
||||
})
|
||||
}
|
||||
export function useApplyDomain(id: string) {
|
||||
const { project } = useAuth()
|
||||
const qc = useQueryClient()
|
||||
return useMutation({
|
||||
mutationFn: (body: ApplyRequest) => {
|
||||
const pid = requireProjectId(project)
|
||||
return api.applyDomain(pid, id, body)
|
||||
},
|
||||
onSuccess: () => qc.invalidateQueries({ queryKey: ["check", project?.id, id] }),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,2 +1 @@
|
||||
export const DEFAULT_PROJECT_ID = "00000000-0000-0000-0000-000000000002"
|
||||
export const API_BASE = `/api/v1/projects/${DEFAULT_PROJECT_ID}`
|
||||
export const API_ROOT = "/api/v1"
|
||||
|
||||
+19
-2
@@ -4,18 +4,35 @@ import "@fontsource/ibm-plex-mono/500.css"
|
||||
import "./index.css"
|
||||
import React from "react"
|
||||
import ReactDOM from "react-dom/client"
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
|
||||
import { QueryCache, QueryClient, QueryClientProvider, MutationCache } from "@tanstack/react-query"
|
||||
import { BrowserRouter } from "react-router-dom"
|
||||
import { UnauthorizedError } from "@/api/client"
|
||||
import { AuthProvider, notifyUnauthorized } from "@/auth/AuthContext"
|
||||
import { App } from "./App"
|
||||
|
||||
const queryClient = new QueryClient()
|
||||
// A 401 from *any* query or mutation means the session died server-side
|
||||
// (expired/destroyed cookie) — drop it from here rather than requiring every
|
||||
// hook in useApi.ts to remember to handle it individually. AuthContext reacts
|
||||
// via notifyUnauthorized (registered by AuthProvider), which resets
|
||||
// user/project and clears the cache; ProtectedRoute then redirects to
|
||||
// /login on the next render.
|
||||
function onQueryError(error: unknown) {
|
||||
if (error instanceof UnauthorizedError) notifyUnauthorized()
|
||||
}
|
||||
|
||||
const queryClient = new QueryClient({
|
||||
queryCache: new QueryCache({ onError: onQueryError }),
|
||||
mutationCache: new MutationCache({ onError: onQueryError }),
|
||||
})
|
||||
|
||||
ReactDOM.createRoot(document.getElementById("root")!).render(
|
||||
<React.StrictMode>
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<AuthProvider>
|
||||
<BrowserRouter>
|
||||
<App />
|
||||
</BrowserRouter>
|
||||
</AuthProvider>
|
||||
</QueryClientProvider>
|
||||
</React.StrictMode>,
|
||||
)
|
||||
|
||||
@@ -3,10 +3,12 @@ import userEvent from "@testing-library/user-event"
|
||||
import { MemoryRouter } from "react-router-dom"
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
|
||||
import { AccountsPage } from "./AccountsPage"
|
||||
import { AuthProvider } from "@/auth/AuthContext"
|
||||
import { api } from "@/api/client"
|
||||
import { vi, beforeEach, test, expect } from "vitest"
|
||||
import type { Account } from "@/api/types"
|
||||
|
||||
const PROJECT_ID = "p1"
|
||||
const accounts: Account[] = [
|
||||
{ id: "acc1", provider: "selectel", comment: "Main" },
|
||||
{ id: "acc2", provider: "selectel", comment: "Backup" },
|
||||
@@ -16,14 +18,21 @@ function renderPage() {
|
||||
const qc = new QueryClient()
|
||||
return render(
|
||||
<QueryClientProvider client={qc}>
|
||||
<AuthProvider>
|
||||
<MemoryRouter initialEntries={["/accounts"]}>
|
||||
<AccountsPage />
|
||||
</MemoryRouter>
|
||||
</AuthProvider>
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
vi.spyOn(api.auth, "me").mockResolvedValue({
|
||||
user: { id: "u1", email: "a@b.com" },
|
||||
project: { id: PROJECT_ID, name: "Default" },
|
||||
})
|
||||
vi.spyOn(api, "listAccounts").mockResolvedValue(accounts)
|
||||
})
|
||||
|
||||
@@ -55,7 +64,7 @@ test("форма создания вызывает api.createAccount с введ
|
||||
await user.click(screen.getByRole("button", { name: /добавить учётку/i }))
|
||||
|
||||
await waitFor(() =>
|
||||
expect(createSpy).toHaveBeenCalledWith({
|
||||
expect(createSpy).toHaveBeenCalledWith(PROJECT_ID, {
|
||||
provider: "selectel",
|
||||
secret: "super-secret-token-123",
|
||||
comment: "New account",
|
||||
@@ -99,5 +108,5 @@ test("удаление учётки вызывает api.deleteAccount", async (
|
||||
|
||||
await user.click(screen.getByRole("button", { name: /удалить.*main/i }))
|
||||
|
||||
await waitFor(() => expect(deleteSpy).toHaveBeenCalledWith("acc1"))
|
||||
await waitFor(() => expect(deleteSpy).toHaveBeenCalledWith(PROJECT_ID, "acc1"))
|
||||
})
|
||||
|
||||
@@ -3,20 +3,33 @@ import userEvent from "@testing-library/user-event"
|
||||
import { MemoryRouter, Routes, Route } from "react-router-dom"
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
|
||||
import { DomainDiffPage } from "./DomainDiffPage"
|
||||
import { AuthProvider } from "@/auth/AuthContext"
|
||||
import { api } from "@/api/client"
|
||||
import { vi } from "vitest"
|
||||
import { vi, beforeEach } from "vitest"
|
||||
|
||||
const PROJECT_ID = "p1"
|
||||
|
||||
function renderPage() {
|
||||
const qc = new QueryClient()
|
||||
return render(
|
||||
<QueryClientProvider client={qc}>
|
||||
<AuthProvider>
|
||||
<MemoryRouter initialEntries={["/domains/d1"]}>
|
||||
<Routes><Route path="/domains/:id" element={<DomainDiffPage />} /></Routes>
|
||||
</MemoryRouter>
|
||||
</AuthProvider>
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
vi.spyOn(api.auth, "me").mockResolvedValue({
|
||||
user: { id: "u1", email: "a@b.com" },
|
||||
project: { id: PROJECT_ID, name: "Default" },
|
||||
})
|
||||
})
|
||||
|
||||
test("apply sends applyPrunes=false by default, true only after opting in", async () => {
|
||||
vi.spyOn(api, "checkDomain").mockResolvedValue({
|
||||
updates: [{ kind: "update", type: "A", name: "a.", desired: ["1"], actual: ["2"], readOnly: false }],
|
||||
@@ -30,12 +43,12 @@ test("apply sends applyPrunes=false by default, true only after opting in", asyn
|
||||
const applyBtn = await screen.findByRole("button", { name: /apply/i })
|
||||
await user.click(applyBtn)
|
||||
await waitFor(() => expect(applySpy).toHaveBeenCalled())
|
||||
expect(applySpy.mock.calls[0][1]).toEqual({ applyUpdates: true, applyPrunes: false })
|
||||
expect(applySpy.mock.calls[0]).toEqual([PROJECT_ID, "d1", { applyUpdates: true, applyPrunes: false }])
|
||||
|
||||
// включить prune и применить снова
|
||||
const pruneToggle = screen.getByRole("checkbox", { name: /prune|удал/i })
|
||||
await user.click(pruneToggle)
|
||||
await user.click(screen.getByRole("button", { name: /apply/i }))
|
||||
await waitFor(() => expect(applySpy).toHaveBeenCalledTimes(2))
|
||||
expect(applySpy.mock.calls[1][1]).toEqual({ applyUpdates: true, applyPrunes: true })
|
||||
expect(applySpy.mock.calls[1]).toEqual([PROJECT_ID, "d1", { applyUpdates: true, applyPrunes: true }])
|
||||
})
|
||||
|
||||
@@ -3,10 +3,12 @@ import userEvent from "@testing-library/user-event"
|
||||
import { MemoryRouter, Routes, Route } from "react-router-dom"
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
|
||||
import { DomainsPage } from "./DomainsPage"
|
||||
import { AuthProvider } from "@/auth/AuthContext"
|
||||
import { api } from "@/api/client"
|
||||
import { vi, beforeEach, test, expect } from "vitest"
|
||||
import type { Account, Domain, Template } from "@/api/types"
|
||||
|
||||
const PROJECT_ID = "p1"
|
||||
const accounts: Account[] = [
|
||||
{ id: "acc1", provider: "selectel", comment: "Main" },
|
||||
{ id: "acc2", provider: "cloudflare", comment: "Backup" },
|
||||
@@ -24,17 +26,24 @@ function renderPage() {
|
||||
const qc = new QueryClient()
|
||||
return render(
|
||||
<QueryClientProvider client={qc}>
|
||||
<AuthProvider>
|
||||
<MemoryRouter initialEntries={["/domains"]}>
|
||||
<Routes>
|
||||
<Route path="/domains" element={<DomainsPage />} />
|
||||
<Route path="/domains/:id" element={<div>diff page</div>} />
|
||||
</Routes>
|
||||
</MemoryRouter>
|
||||
</AuthProvider>
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
vi.spyOn(api.auth, "me").mockResolvedValue({
|
||||
user: { id: "u1", email: "a@b.com" },
|
||||
project: { id: PROJECT_ID, name: "Default" },
|
||||
})
|
||||
vi.spyOn(api, "listDomains").mockResolvedValue(domains)
|
||||
vi.spyOn(api, "listAccounts").mockResolvedValue(accounts)
|
||||
vi.spyOn(api, "listTemplates").mockResolvedValue(templates)
|
||||
@@ -64,7 +73,7 @@ test("кнопка импорта вызывает api.importZones с выбра
|
||||
|
||||
await user.click(screen.getByRole("button", { name: /импортировать зоны/i }))
|
||||
|
||||
await waitFor(() => expect(importSpy).toHaveBeenCalledWith("acc2"))
|
||||
await waitFor(() => expect(importSpy).toHaveBeenCalledWith(PROJECT_ID, "acc2"))
|
||||
})
|
||||
|
||||
test("привязка шаблона в строке домена вызывает api.setDomainTemplate", async () => {
|
||||
@@ -77,7 +86,7 @@ test("привязка шаблона в строке домена вызыва
|
||||
await user.click(screen.getByRole("combobox", { name: /example\.com\./i }))
|
||||
await user.click(await screen.findByRole("option", { name: /^standard$/i }))
|
||||
|
||||
await waitFor(() => expect(setTemplateSpy).toHaveBeenCalledWith("d1", "t1"))
|
||||
await waitFor(() => expect(setTemplateSpy).toHaveBeenCalledWith(PROJECT_ID, "d1", "t1"))
|
||||
})
|
||||
|
||||
test("ошибка привязки шаблона отображается пользователю", async () => {
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
import { act, render, screen, waitFor } from "@testing-library/react"
|
||||
import userEvent from "@testing-library/user-event"
|
||||
import { MemoryRouter, Routes, Route } from "react-router-dom"
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest"
|
||||
import { LoginPage } from "./LoginPage"
|
||||
import { AuthProvider } from "@/auth/AuthContext"
|
||||
import { api, UnauthorizedError } from "@/api/client"
|
||||
|
||||
function renderPage() {
|
||||
const qc = new QueryClient()
|
||||
return render(
|
||||
<QueryClientProvider client={qc}>
|
||||
<AuthProvider>
|
||||
<MemoryRouter initialEntries={["/login"]}>
|
||||
<Routes>
|
||||
<Route path="/login" element={<LoginPage />} />
|
||||
<Route path="/register" element={<div>register page</div>} />
|
||||
<Route path="/domains" element={<div>domains page</div>} />
|
||||
</Routes>
|
||||
</MemoryRouter>
|
||||
</AuthProvider>
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
vi.spyOn(api.auth, "me").mockRejectedValue(new UnauthorizedError())
|
||||
})
|
||||
|
||||
describe("LoginPage", () => {
|
||||
it("ввод email+пароль и сабмит вызывает useAuth().login с введёнными данными", async () => {
|
||||
const loginSpy = vi.spyOn(api.auth, "login").mockResolvedValue({
|
||||
user: { id: "u1", email: "a@b.com" },
|
||||
project: { id: "p1", name: "Default" },
|
||||
})
|
||||
const user = userEvent.setup()
|
||||
renderPage()
|
||||
|
||||
// AuthProvider resolves the session check (api.auth.me) asynchronously;
|
||||
// the form only renders once loading flips to false.
|
||||
await user.type(await screen.findByLabelText(/email/i), "a@b.com")
|
||||
await user.type(screen.getByLabelText(/пароль/i), "secret123")
|
||||
await user.click(screen.getByRole("button", { name: /войти/i }))
|
||||
|
||||
await waitFor(() => expect(loginSpy).toHaveBeenCalledWith("a@b.com", "secret123"))
|
||||
|
||||
expect(await screen.findByText("domains page")).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it("ошибка входа отображается пользователю через role=alert", async () => {
|
||||
vi.spyOn(api.auth, "login").mockRejectedValue(new Error("Неверный email или пароль"))
|
||||
const user = userEvent.setup()
|
||||
renderPage()
|
||||
|
||||
await user.type(await screen.findByLabelText(/email/i), "a@b.com")
|
||||
await user.type(screen.getByLabelText(/пароль/i), "wrong-password")
|
||||
await user.click(screen.getByRole("button", { name: /войти/i }))
|
||||
|
||||
expect(await screen.findByRole("alert")).toHaveTextContent("Неверный email или пароль")
|
||||
})
|
||||
|
||||
it("не рендерит форму логина, пока сессия (api.auth.me) не резолвнута", async () => {
|
||||
let rejectMe!: (err: unknown) => void
|
||||
vi.spyOn(api.auth, "me").mockImplementation(
|
||||
() =>
|
||||
new Promise((_resolve, reject) => {
|
||||
rejectMe = reject
|
||||
}),
|
||||
)
|
||||
renderPage()
|
||||
|
||||
expect(screen.queryByRole("button", { name: /войти/i })).not.toBeInTheDocument()
|
||||
expect(screen.queryByLabelText(/email/i)).not.toBeInTheDocument()
|
||||
|
||||
// Resolve the pending me() so the test doesn't leak an unhandled rejection.
|
||||
await act(async () => {
|
||||
rejectMe(new UnauthorizedError())
|
||||
})
|
||||
})
|
||||
|
||||
it("сетевая ошибка при логине показывает «Сервис недоступен»", async () => {
|
||||
vi.spyOn(api.auth, "login").mockRejectedValue(new TypeError("Failed to fetch"))
|
||||
const user = userEvent.setup()
|
||||
renderPage()
|
||||
|
||||
await user.type(await screen.findByLabelText(/email/i), "a@b.com")
|
||||
await user.type(screen.getByLabelText(/пароль/i), "wrong-password")
|
||||
await user.click(screen.getByRole("button", { name: /войти/i }))
|
||||
|
||||
expect(await screen.findByRole("alert")).toHaveTextContent("Сервис недоступен, попробуйте позже")
|
||||
})
|
||||
|
||||
it("содержит ссылку на регистрацию", async () => {
|
||||
renderPage()
|
||||
|
||||
const link = await screen.findByRole("link", { name: /зарегистрир/i })
|
||||
expect(link).toHaveAttribute("href", "/register")
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,169 @@
|
||||
import { useId, useState } from "react"
|
||||
import { Controller, useForm } from "react-hook-form"
|
||||
import { zodResolver } from "@hookform/resolvers/zod"
|
||||
import { z } from "zod"
|
||||
import { Link, Navigate } from "react-router-dom"
|
||||
import { KeyRound, Loader2, LogIn, SquareTerminal } from "lucide-react"
|
||||
import { useAuth } from "@/auth/AuthContext"
|
||||
import { UnauthorizedError } from "@/api/client"
|
||||
import { Button } from "@/components/ui/button"
|
||||
import { Input } from "@/components/ui/input"
|
||||
import {
|
||||
Field,
|
||||
FieldContent,
|
||||
FieldDescription,
|
||||
FieldError,
|
||||
FieldGroup,
|
||||
FieldLabel,
|
||||
FieldSet,
|
||||
} from "@/components/ui/field"
|
||||
|
||||
const loginSchema = z.object({
|
||||
email: z.string().trim().min(1, "Укажите email").email("Некорректный email"),
|
||||
password: z.string().min(1, "Укажите пароль"),
|
||||
})
|
||||
|
||||
type LoginForm = z.infer<typeof loginSchema>
|
||||
|
||||
// describeLoginError turns a login() rejection into user-facing Russian
|
||||
// copy. A network failure (TypeError from fetch itself) or a 5xx response
|
||||
// means the service is unreachable/broken — that's a different situation
|
||||
// from wrong credentials and should say so. Everything else (401
|
||||
// UnauthorizedError, or a backend "invalid credentials" message) reads as a
|
||||
// bad email/password.
|
||||
function describeLoginError(err: unknown): string {
|
||||
const isNetworkOrServerError =
|
||||
err instanceof TypeError || (err instanceof Error && err.message.startsWith("HTTP 5"))
|
||||
if (isNetworkOrServerError) return "Сервис недоступен, попробуйте позже"
|
||||
|
||||
const isInvalidCredentials =
|
||||
err instanceof UnauthorizedError ||
|
||||
(err instanceof Error && /invalid credentials/i.test(err.message))
|
||||
if (isInvalidCredentials) return "Неверный email или пароль"
|
||||
|
||||
return "Неверный email или пароль"
|
||||
}
|
||||
|
||||
export function LoginPage() {
|
||||
const { user, loading, login } = useAuth()
|
||||
const [authError, setAuthError] = useState<string | null>(null)
|
||||
const emailFieldId = useId()
|
||||
const passwordFieldId = useId()
|
||||
|
||||
const {
|
||||
control,
|
||||
handleSubmit,
|
||||
formState: { errors, isSubmitting },
|
||||
} = useForm<LoginForm>({
|
||||
resolver: zodResolver(loginSchema),
|
||||
defaultValues: { email: "", password: "" },
|
||||
})
|
||||
|
||||
// Session check (api.auth.me()) hasn't resolved yet — don't flash the
|
||||
// login form for a visitor who turns out to already have a valid session.
|
||||
if (loading) return null
|
||||
|
||||
// Already authenticated (fresh session on mount, or just logged in below) —
|
||||
// don't show the login form, go straight to the app.
|
||||
if (user) return <Navigate to="/domains" replace />
|
||||
|
||||
async function onSubmit(values: LoginForm) {
|
||||
setAuthError(null)
|
||||
try {
|
||||
await login(values.email, values.password)
|
||||
} catch (err) {
|
||||
setAuthError(describeLoginError(err))
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex h-screen w-full items-center justify-center bg-background px-4">
|
||||
<div className="flex w-full max-w-sm flex-col gap-6">
|
||||
<div className="flex flex-col items-center gap-2 text-center">
|
||||
<SquareTerminal className="size-6 text-primary" strokeWidth={1.75} />
|
||||
<div className="flex flex-col leading-none">
|
||||
<span className="text-sm font-semibold tracking-tight">DNS Autoresolver</span>
|
||||
<span className="font-dns text-[10px] tracking-wider text-muted-foreground uppercase">
|
||||
console
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<form
|
||||
onSubmit={handleSubmit(onSubmit)}
|
||||
noValidate
|
||||
className="flex flex-col gap-4 rounded-xl border border-border bg-card/60 p-5"
|
||||
>
|
||||
<FieldSet className="gap-3">
|
||||
<FieldGroup className="gap-3">
|
||||
<Field>
|
||||
<FieldLabel htmlFor={emailFieldId}>Email</FieldLabel>
|
||||
<FieldContent>
|
||||
<Controller
|
||||
control={control}
|
||||
name="email"
|
||||
render={({ field }) => (
|
||||
<Input
|
||||
{...field}
|
||||
id={emailFieldId}
|
||||
type="email"
|
||||
autoComplete="email"
|
||||
placeholder="you@example.com"
|
||||
aria-invalid={!!errors.email}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
<FieldError errors={[errors.email]} />
|
||||
</FieldContent>
|
||||
</Field>
|
||||
|
||||
<Field>
|
||||
<FieldLabel htmlFor={passwordFieldId}>Пароль</FieldLabel>
|
||||
<FieldContent>
|
||||
<Controller
|
||||
control={control}
|
||||
name="password"
|
||||
render={({ field }) => (
|
||||
<Input
|
||||
{...field}
|
||||
id={passwordFieldId}
|
||||
type="password"
|
||||
autoComplete="current-password"
|
||||
placeholder="••••••••••••"
|
||||
aria-invalid={!!errors.password}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
<FieldError errors={[errors.password]} />
|
||||
</FieldContent>
|
||||
</Field>
|
||||
</FieldGroup>
|
||||
|
||||
{authError && (
|
||||
<FieldDescription role="alert" className="flex items-center gap-2 text-destructive">
|
||||
<KeyRound className="size-3.5 shrink-0" strokeWidth={1.75} />
|
||||
{authError}
|
||||
</FieldDescription>
|
||||
)}
|
||||
</FieldSet>
|
||||
|
||||
<Button type="submit" disabled={isSubmitting} className="w-full">
|
||||
{isSubmitting ? (
|
||||
<Loader2 className="size-4 animate-spin" strokeWidth={1.75} />
|
||||
) : (
|
||||
<LogIn className="size-4" strokeWidth={1.75} />
|
||||
)}
|
||||
Войти
|
||||
</Button>
|
||||
</form>
|
||||
|
||||
<p className="text-center text-sm text-muted-foreground">
|
||||
Нет учётной записи?{" "}
|
||||
<Link to="/register" className="text-primary underline-offset-4 hover:underline">
|
||||
Зарегистрироваться
|
||||
</Link>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
import { useId, useState } from "react"
|
||||
import { Controller, useForm } from "react-hook-form"
|
||||
import { zodResolver } from "@hookform/resolvers/zod"
|
||||
import { z } from "zod"
|
||||
import { Link, Navigate } from "react-router-dom"
|
||||
import { KeyRound, Loader2, SquareTerminal, UserPlus } from "lucide-react"
|
||||
import { useAuth } from "@/auth/AuthContext"
|
||||
import { Button } from "@/components/ui/button"
|
||||
import { Input } from "@/components/ui/input"
|
||||
import {
|
||||
Field,
|
||||
FieldContent,
|
||||
FieldDescription,
|
||||
FieldError,
|
||||
FieldGroup,
|
||||
FieldLabel,
|
||||
FieldSet,
|
||||
} from "@/components/ui/field"
|
||||
|
||||
const registerSchema = z.object({
|
||||
email: z.string().trim().min(1, "Укажите email").email("Некорректный email"),
|
||||
password: z.string().min(8, "Минимум 8 символов"),
|
||||
})
|
||||
|
||||
type RegisterForm = z.infer<typeof registerSchema>
|
||||
|
||||
// describeRegisterError turns a register() rejection into user-facing
|
||||
// Russian copy. A network failure (TypeError from fetch itself) or a 5xx
|
||||
// response means the service is unreachable/broken, not a validation
|
||||
// problem — surface that distinctly instead of an opaque "HTTP 500". Any
|
||||
// other error (409 email taken, 400 password too short, etc.) already
|
||||
// carries a specific backend message worth showing as-is.
|
||||
function describeRegisterError(err: unknown): string {
|
||||
const isNetworkOrServerError =
|
||||
err instanceof TypeError || (err instanceof Error && err.message.startsWith("HTTP 5"))
|
||||
if (isNetworkOrServerError) return "Сервис недоступен, попробуйте позже"
|
||||
|
||||
return err instanceof Error ? err.message : "Не удалось зарегистрироваться"
|
||||
}
|
||||
|
||||
export function RegisterPage() {
|
||||
const { user, loading, register: registerUser } = useAuth()
|
||||
const [authError, setAuthError] = useState<string | null>(null)
|
||||
const emailFieldId = useId()
|
||||
const passwordFieldId = useId()
|
||||
|
||||
const {
|
||||
control,
|
||||
handleSubmit,
|
||||
formState: { errors, isSubmitting },
|
||||
} = useForm<RegisterForm>({
|
||||
resolver: zodResolver(registerSchema),
|
||||
defaultValues: { email: "", password: "" },
|
||||
})
|
||||
|
||||
// Session check (api.auth.me()) hasn't resolved yet — don't flash the
|
||||
// registration form for a visitor who turns out to already have a valid
|
||||
// session.
|
||||
if (loading) return null
|
||||
|
||||
// Already authenticated — skip straight to the app instead of showing the
|
||||
// registration form again.
|
||||
if (user) return <Navigate to="/domains" replace />
|
||||
|
||||
async function onSubmit(values: RegisterForm) {
|
||||
setAuthError(null)
|
||||
try {
|
||||
await registerUser(values.email, values.password)
|
||||
} catch (err) {
|
||||
setAuthError(describeRegisterError(err))
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex h-screen w-full items-center justify-center bg-background px-4">
|
||||
<div className="flex w-full max-w-sm flex-col gap-6">
|
||||
<div className="flex flex-col items-center gap-2 text-center">
|
||||
<SquareTerminal className="size-6 text-primary" strokeWidth={1.75} />
|
||||
<div className="flex flex-col leading-none">
|
||||
<span className="text-sm font-semibold tracking-tight">DNS Autoresolver</span>
|
||||
<span className="font-dns text-[10px] tracking-wider text-muted-foreground uppercase">
|
||||
console
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<form
|
||||
onSubmit={handleSubmit(onSubmit)}
|
||||
noValidate
|
||||
className="flex flex-col gap-4 rounded-xl border border-border bg-card/60 p-5"
|
||||
>
|
||||
<FieldSet className="gap-3">
|
||||
<FieldGroup className="gap-3">
|
||||
<Field>
|
||||
<FieldLabel htmlFor={emailFieldId}>Email</FieldLabel>
|
||||
<FieldContent>
|
||||
<Controller
|
||||
control={control}
|
||||
name="email"
|
||||
render={({ field }) => (
|
||||
<Input
|
||||
{...field}
|
||||
id={emailFieldId}
|
||||
type="email"
|
||||
autoComplete="email"
|
||||
placeholder="you@example.com"
|
||||
aria-invalid={!!errors.email}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
<FieldError errors={[errors.email]} />
|
||||
</FieldContent>
|
||||
</Field>
|
||||
|
||||
<Field>
|
||||
<FieldLabel htmlFor={passwordFieldId}>Пароль</FieldLabel>
|
||||
<FieldContent>
|
||||
<Controller
|
||||
control={control}
|
||||
name="password"
|
||||
render={({ field }) => (
|
||||
<Input
|
||||
{...field}
|
||||
id={passwordFieldId}
|
||||
type="password"
|
||||
autoComplete="new-password"
|
||||
placeholder="••••••••••••"
|
||||
aria-invalid={!!errors.password}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
<FieldError errors={[errors.password]} />
|
||||
</FieldContent>
|
||||
</Field>
|
||||
</FieldGroup>
|
||||
|
||||
{authError && (
|
||||
<FieldDescription role="alert" className="flex items-center gap-2 text-destructive">
|
||||
<KeyRound className="size-3.5 shrink-0" strokeWidth={1.75} />
|
||||
{authError}
|
||||
</FieldDescription>
|
||||
)}
|
||||
</FieldSet>
|
||||
|
||||
<Button type="submit" disabled={isSubmitting} className="w-full">
|
||||
{isSubmitting ? (
|
||||
<Loader2 className="size-4 animate-spin" strokeWidth={1.75} />
|
||||
) : (
|
||||
<UserPlus className="size-4" strokeWidth={1.75} />
|
||||
)}
|
||||
Зарегистрироваться
|
||||
</Button>
|
||||
</form>
|
||||
|
||||
<p className="text-center text-sm text-muted-foreground">
|
||||
Уже есть аккаунт?{" "}
|
||||
<Link to="/login" className="text-primary underline-offset-4 hover:underline">
|
||||
Войти
|
||||
</Link>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -3,10 +3,12 @@ import userEvent from "@testing-library/user-event"
|
||||
import { MemoryRouter } from "react-router-dom"
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
|
||||
import { TemplatesPage } from "./TemplatesPage"
|
||||
import { AuthProvider } from "@/auth/AuthContext"
|
||||
import { api } from "@/api/client"
|
||||
import { vi, beforeEach, test, expect } from "vitest"
|
||||
import type { Template } from "@/api/types"
|
||||
|
||||
const PROJECT_ID = "p1"
|
||||
const templates: Template[] = [
|
||||
{
|
||||
id: "t1",
|
||||
@@ -21,14 +23,21 @@ function renderPage() {
|
||||
const qc = new QueryClient()
|
||||
return render(
|
||||
<QueryClientProvider client={qc}>
|
||||
<AuthProvider>
|
||||
<MemoryRouter initialEntries={["/templates"]}>
|
||||
<TemplatesPage />
|
||||
</MemoryRouter>
|
||||
</AuthProvider>
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
vi.spyOn(api.auth, "me").mockResolvedValue({
|
||||
user: { id: "u1", email: "a@b.com" },
|
||||
project: { id: PROJECT_ID, name: "Default" },
|
||||
})
|
||||
vi.spyOn(api, "listTemplates").mockResolvedValue(templates)
|
||||
})
|
||||
|
||||
@@ -65,7 +74,7 @@ test("создание шаблона с записью вызывает api.cre
|
||||
await user.click(screen.getByRole("button", { name: /сохранить шаблон/i }))
|
||||
|
||||
await waitFor(() =>
|
||||
expect(createSpy).toHaveBeenCalledWith({
|
||||
expect(createSpy).toHaveBeenCalledWith(PROJECT_ID, {
|
||||
name: "New",
|
||||
records: [{ type: "A", name: "www", ttl: 3600, values: ["1.1.1.1"] }],
|
||||
}),
|
||||
@@ -88,7 +97,7 @@ test("редактирование шаблона вызывает api.updateTem
|
||||
await user.click(screen.getByRole("button", { name: /сохранить шаблон/i }))
|
||||
|
||||
await waitFor(() =>
|
||||
expect(updateSpy).toHaveBeenCalledWith("t1", {
|
||||
expect(updateSpy).toHaveBeenCalledWith(PROJECT_ID, "t1", {
|
||||
name: "Standard v2",
|
||||
records: [{ type: "A", name: "@", ttl: 3600, values: ["1.2.3.4"] }],
|
||||
}),
|
||||
@@ -105,7 +114,7 @@ test("удаление шаблона вызывает api.deleteTemplate", asyn
|
||||
|
||||
await user.click(screen.getByRole("button", { name: /удалить шаблон standard/i }))
|
||||
|
||||
await waitFor(() => expect(deleteSpy).toHaveBeenCalledWith("t1"))
|
||||
await waitFor(() => expect(deleteSpy).toHaveBeenCalledWith(PROJECT_ID, "t1"))
|
||||
})
|
||||
|
||||
test("ошибка создания шаблона отображается пользователю", async () => {
|
||||
|
||||
Reference in New Issue
Block a user