From 3bd237d5627e450062af5ffd5bdd9a9bc4b8c6cf Mon Sep 17 00:00:00 2001 From: Vassiliy Yegorov Date: Fri, 3 Jul 2026 19:44:36 +0700 Subject: [PATCH 01/10] =?UTF-8?q?feat(store):=20=D0=BC=D0=B8=D0=B3=D1=80?= =?UTF-8?q?=D0=B0=D1=86=D0=B8=D1=8F=20sessions/password=20+=20=D0=BC=D0=B5?= =?UTF-8?q?=D1=82=D0=BE=D0=B4=D1=8B=20users/sessions/projects?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Фаза 2, Task 1: добавлена таблица sessions и nullable password_hash у users, sqlc-запросы и *Store-обёртки (CreateUser, GetUserByEmail, CreateProjectForUser, GetProjectOwned, GetUserProject, CreateSession, GetSessionUser, DeleteSession, RegisterUser в транзакции), интеграционные тесты на testcontainers. Co-Authored-By: Claude Opus 4.8 (1M context) Claude-Session: https://claude.ai/code/session_01BwxdSt4reTm7Dj1oxRvpP3 --- internal/store/auth_test.go | 149 ++++++++++++++++++++++++ internal/store/db/models.go | 15 ++- internal/store/db/projects.sql.go | 71 +++++++++++ internal/store/db/sessions.sql.go | 54 +++++++++ internal/store/db/users.sql.go | 50 ++++++++ internal/store/migrations/0003_auth.sql | 15 +++ internal/store/queries/projects.sql | 8 ++ internal/store/queries/sessions.sql | 8 ++ internal/store/queries/users.sql | 5 + internal/store/tenant.go | 126 ++++++++++++++++++++ 10 files changed, 498 insertions(+), 3 deletions(-) create mode 100644 internal/store/auth_test.go create mode 100644 internal/store/db/projects.sql.go create mode 100644 internal/store/db/sessions.sql.go create mode 100644 internal/store/db/users.sql.go create mode 100644 internal/store/migrations/0003_auth.sql create mode 100644 internal/store/queries/projects.sql create mode 100644 internal/store/queries/sessions.sql create mode 100644 internal/store/queries/users.sql diff --git a/internal/store/auth_test.go b/internal/store/auth_test.go new file mode 100644 index 0000000..2bedd35 --- /dev/null +++ b/internal/store/auth_test.go @@ -0,0 +1,149 @@ +package store + +import ( + "testing" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" +) + +// TestRegisterUser_CreatesUserAndOwnedProject verifies the RegisterUser +// transaction: a user and a default project are created together, and the +// project belongs to that user. +func TestRegisterUser_CreatesUserAndOwnedProject(t *testing.T) { + s, ctx := newStore(t) + + u, p, err := s.RegisterUser(ctx, "alice@example.com", "argon2-hash") + if err != nil { + t.Fatal(err) + } + if u.Email != "alice@example.com" || u.PasswordHash != "argon2-hash" { + t.Fatalf("unexpected user: %+v", u) + } + if p.UserID != u.ID { + t.Fatalf("expected project to belong to user %s, got %+v", u.ID, p) + } + + owned, err := s.GetProjectOwned(ctx, p.ID, u.ID) + if err != nil { + t.Fatal(err) + } + if owned.ID != p.ID { + t.Fatalf("expected owned project %s, got %+v", p.ID, owned) + } +} + +// TestGetUserByEmail_FindsRegisteredUser verifies email lookup returns the +// same user created by RegisterUser. +func TestGetUserByEmail_FindsRegisteredUser(t *testing.T) { + s, ctx := newStore(t) + + u, _, err := s.RegisterUser(ctx, "bob@example.com", "argon2-hash") + if err != nil { + t.Fatal(err) + } + + got, err := s.GetUserByEmail(ctx, "bob@example.com") + if err != nil { + t.Fatal(err) + } + if got.ID != u.ID || got.PasswordHash != "argon2-hash" { + t.Fatalf("unexpected user: %+v", got) + } + + if _, err := s.GetUserByEmail(ctx, "nobody@example.com"); err == nil { + t.Fatal("expected error for unknown email, got nil") + } +} + +// TestSessionLifecycle_CreateGetDelete verifies CreateSession + GetSessionUser +// round-trips to the owning user ID, an expired session is excluded from +// GetSessionUser, and DeleteSession removes the session. +func TestSessionLifecycle_CreateGetDelete(t *testing.T) { + s, ctx := newStore(t) + + u, _, err := s.RegisterUser(ctx, "carol@example.com", "argon2-hash") + if err != nil { + t.Fatal(err) + } + + tokenHash := "sha256-token-hash" + if err := s.CreateSession(ctx, u.ID, tokenHash, time.Now().Add(time.Hour)); err != nil { + t.Fatal(err) + } + + gotUserID, err := s.GetSessionUser(ctx, tokenHash) + if err != nil { + t.Fatal(err) + } + if gotUserID != u.ID { + t.Fatalf("expected user %s, got %s", u.ID, gotUserID) + } + + if err := s.DeleteSession(ctx, tokenHash); err != nil { + t.Fatal(err) + } + if _, err := s.GetSessionUser(ctx, tokenHash); err == nil { + t.Fatal("expected error after DeleteSession, got nil") + } +} + +// TestGetSessionUser_ExpiredSessionNotReturned verifies the query's +// expires_at > now() condition: a session created with an expiry in the +// past must not be returned by GetSessionUser. +func TestGetSessionUser_ExpiredSessionNotReturned(t *testing.T) { + s, ctx := newStore(t) + + u, _, err := s.RegisterUser(ctx, "dave@example.com", "argon2-hash") + if err != nil { + t.Fatal(err) + } + + tokenHash := "expired-token-hash" + if err := s.CreateSession(ctx, u.ID, tokenHash, time.Now().Add(-time.Hour)); err != nil { + t.Fatal(err) + } + + if _, err := s.GetSessionUser(ctx, tokenHash); err == nil { + t.Fatal("expected expired session to not be returned, got nil error") + } else if err != pgx.ErrNoRows { + t.Fatalf("expected pgx.ErrNoRows, got %v", err) + } +} + +// TestGetProjectOwned_ForeignUserRejected verifies that looking up a project +// with the wrong user ID fails, so one tenant cannot address another +// tenant's project by guessing its ID. +func TestGetProjectOwned_ForeignUserRejected(t *testing.T) { + s, ctx := newStore(t) + + _, p, err := s.RegisterUser(ctx, "erin@example.com", "argon2-hash") + if err != nil { + t.Fatal(err) + } + + foreignUserID := uuid.New() + if _, err := s.GetProjectOwned(ctx, p.ID, foreignUserID); err == nil { + t.Fatal("expected error for foreign user ID, got nil") + } +} + +// TestGetUserProject_ReturnsTheUsersProject verifies GetUserProject returns +// the project created for that user by RegisterUser. +func TestGetUserProject_ReturnsTheUsersProject(t *testing.T) { + s, ctx := newStore(t) + + u, p, err := s.RegisterUser(ctx, "frank@example.com", "argon2-hash") + if err != nil { + t.Fatal(err) + } + + got, err := s.GetUserProject(ctx, u.ID) + if err != nil { + t.Fatal(err) + } + if got.ID != p.ID { + t.Fatalf("expected project %s, got %+v", p.ID, got) + } +} diff --git a/internal/store/db/models.go b/internal/store/db/models.go index d3b98ab..d541484 100644 --- a/internal/store/db/models.go +++ b/internal/store/db/models.go @@ -43,6 +43,14 @@ type ProviderAccount struct { CreatedAt pgtype.Timestamptz `json:"created_at"` } +type Session struct { + ID uuid.UUID `json:"id"` + UserID uuid.UUID `json:"user_id"` + TokenHash string `json:"token_hash"` + ExpiresAt pgtype.Timestamptz `json:"expires_at"` + CreatedAt pgtype.Timestamptz `json:"created_at"` +} + type Template struct { ID uuid.UUID `json:"id"` ProjectID uuid.UUID `json:"project_id"` @@ -54,7 +62,8 @@ type Template struct { } type User struct { - ID uuid.UUID `json:"id"` - Email string `json:"email"` - CreatedAt pgtype.Timestamptz `json:"created_at"` + ID uuid.UUID `json:"id"` + Email string `json:"email"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + PasswordHash *string `json:"password_hash"` } diff --git a/internal/store/db/projects.sql.go b/internal/store/db/projects.sql.go new file mode 100644 index 0000000..4e0d435 --- /dev/null +++ b/internal/store/db/projects.sql.go @@ -0,0 +1,71 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: projects.sql + +package db + +import ( + "context" + + "github.com/google/uuid" +) + +const createProject = `-- name: CreateProject :one +INSERT INTO projects (id, user_id, name) VALUES ($1, $2, $3) RETURNING id, user_id, name, created_at +` + +type CreateProjectParams struct { + ID uuid.UUID `json:"id"` + UserID uuid.UUID `json:"user_id"` + Name string `json:"name"` +} + +func (q *Queries) CreateProject(ctx context.Context, arg CreateProjectParams) (Project, error) { + row := q.db.QueryRow(ctx, createProject, arg.ID, arg.UserID, arg.Name) + var i Project + err := row.Scan( + &i.ID, + &i.UserID, + &i.Name, + &i.CreatedAt, + ) + return i, err +} + +const getProjectOwned = `-- name: GetProjectOwned :one +SELECT id, user_id, name, created_at FROM projects WHERE id = $1 AND user_id = $2 +` + +type GetProjectOwnedParams struct { + ID uuid.UUID `json:"id"` + UserID uuid.UUID `json:"user_id"` +} + +func (q *Queries) GetProjectOwned(ctx context.Context, arg GetProjectOwnedParams) (Project, error) { + row := q.db.QueryRow(ctx, getProjectOwned, arg.ID, arg.UserID) + var i Project + err := row.Scan( + &i.ID, + &i.UserID, + &i.Name, + &i.CreatedAt, + ) + return i, err +} + +const getUserProject = `-- name: GetUserProject :one +SELECT id, user_id, name, created_at FROM projects WHERE user_id = $1 ORDER BY created_at LIMIT 1 +` + +func (q *Queries) GetUserProject(ctx context.Context, userID uuid.UUID) (Project, error) { + row := q.db.QueryRow(ctx, getUserProject, userID) + var i Project + err := row.Scan( + &i.ID, + &i.UserID, + &i.Name, + &i.CreatedAt, + ) + return i, err +} diff --git a/internal/store/db/sessions.sql.go b/internal/store/db/sessions.sql.go new file mode 100644 index 0000000..bfd6286 --- /dev/null +++ b/internal/store/db/sessions.sql.go @@ -0,0 +1,54 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: sessions.sql + +package db + +import ( + "context" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" +) + +const createSession = `-- name: CreateSession :exec +INSERT INTO sessions (id, user_id, token_hash, expires_at) VALUES ($1, $2, $3, $4) +` + +type CreateSessionParams struct { + ID uuid.UUID `json:"id"` + UserID uuid.UUID `json:"user_id"` + TokenHash string `json:"token_hash"` + ExpiresAt pgtype.Timestamptz `json:"expires_at"` +} + +func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) error { + _, err := q.db.Exec(ctx, createSession, + arg.ID, + arg.UserID, + arg.TokenHash, + arg.ExpiresAt, + ) + return err +} + +const deleteSession = `-- name: DeleteSession :exec +DELETE FROM sessions WHERE token_hash = $1 +` + +func (q *Queries) DeleteSession(ctx context.Context, tokenHash string) error { + _, err := q.db.Exec(ctx, deleteSession, tokenHash) + return err +} + +const getSessionUser = `-- name: GetSessionUser :one +SELECT user_id FROM sessions WHERE token_hash = $1 AND expires_at > now() +` + +func (q *Queries) GetSessionUser(ctx context.Context, tokenHash string) (uuid.UUID, error) { + row := q.db.QueryRow(ctx, getSessionUser, tokenHash) + var user_id uuid.UUID + err := row.Scan(&user_id) + return user_id, err +} diff --git a/internal/store/db/users.sql.go b/internal/store/db/users.sql.go new file mode 100644 index 0000000..c76c410 --- /dev/null +++ b/internal/store/db/users.sql.go @@ -0,0 +1,50 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: users.sql + +package db + +import ( + "context" + + "github.com/google/uuid" +) + +const createUser = `-- name: CreateUser :one +INSERT INTO users (id, email, password_hash) VALUES ($1, $2, $3) RETURNING id, email, created_at, password_hash +` + +type CreateUserParams struct { + ID uuid.UUID `json:"id"` + Email string `json:"email"` + PasswordHash *string `json:"password_hash"` +} + +func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) { + row := q.db.QueryRow(ctx, createUser, arg.ID, arg.Email, arg.PasswordHash) + var i User + err := row.Scan( + &i.ID, + &i.Email, + &i.CreatedAt, + &i.PasswordHash, + ) + return i, err +} + +const getUserByEmail = `-- name: GetUserByEmail :one +SELECT id, email, created_at, password_hash FROM users WHERE email = $1 +` + +func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error) { + row := q.db.QueryRow(ctx, getUserByEmail, email) + var i User + err := row.Scan( + &i.ID, + &i.Email, + &i.CreatedAt, + &i.PasswordHash, + ) + return i, err +} diff --git a/internal/store/migrations/0003_auth.sql b/internal/store/migrations/0003_auth.sql new file mode 100644 index 0000000..2aa0ed4 --- /dev/null +++ b/internal/store/migrations/0003_auth.sql @@ -0,0 +1,15 @@ +-- +goose Up +ALTER TABLE users ADD COLUMN password_hash text; + +CREATE TABLE sessions ( + id uuid PRIMARY KEY, + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + token_hash text NOT NULL UNIQUE, + expires_at timestamptz NOT NULL, + created_at timestamptz NOT NULL DEFAULT now() +); +CREATE INDEX sessions_token_hash_idx ON sessions (token_hash); + +-- +goose Down +DROP TABLE sessions; +ALTER TABLE users DROP COLUMN password_hash; diff --git a/internal/store/queries/projects.sql b/internal/store/queries/projects.sql new file mode 100644 index 0000000..cc77dc4 --- /dev/null +++ b/internal/store/queries/projects.sql @@ -0,0 +1,8 @@ +-- name: CreateProject :one +INSERT INTO projects (id, user_id, name) VALUES ($1, $2, $3) RETURNING *; + +-- name: GetProjectOwned :one +SELECT * FROM projects WHERE id = $1 AND user_id = $2; + +-- name: GetUserProject :one +SELECT * FROM projects WHERE user_id = $1 ORDER BY created_at LIMIT 1; diff --git a/internal/store/queries/sessions.sql b/internal/store/queries/sessions.sql new file mode 100644 index 0000000..337e3f7 --- /dev/null +++ b/internal/store/queries/sessions.sql @@ -0,0 +1,8 @@ +-- name: CreateSession :exec +INSERT INTO sessions (id, user_id, token_hash, expires_at) VALUES ($1, $2, $3, $4); + +-- name: GetSessionUser :one +SELECT user_id FROM sessions WHERE token_hash = $1 AND expires_at > now(); + +-- name: DeleteSession :exec +DELETE FROM sessions WHERE token_hash = $1; diff --git a/internal/store/queries/users.sql b/internal/store/queries/users.sql new file mode 100644 index 0000000..d68dc45 --- /dev/null +++ b/internal/store/queries/users.sql @@ -0,0 +1,5 @@ +-- name: CreateUser :one +INSERT INTO users (id, email, password_hash) VALUES ($1, $2, $3) RETURNING *; + +-- name: GetUserByEmail :one +SELECT * FROM users WHERE email = $1; diff --git a/internal/store/tenant.go b/internal/store/tenant.go index 724f811..37961a0 100644 --- a/internal/store/tenant.go +++ b/internal/store/tenant.go @@ -3,9 +3,11 @@ package store import ( "context" "errors" + "time" "github.com/google/uuid" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" "github.com/vasyakrg/dns-autoresolver/internal/provider" "github.com/vasyakrg/dns-autoresolver/internal/store/db" @@ -222,3 +224,127 @@ func (s *Store) SetDomainTemplate(ctx context.Context, domainID, projectID uuid. } return domainFromDB(d), nil } + +// User and Project are provider-neutral domain structs for the auth/tenant +// layer (Фаза 2), mirroring the Account/Template/Domain wrappers above so +// callers never need to import internal/store/db directly. + +type User struct { + ID uuid.UUID + Email string + PasswordHash string +} + +type Project struct { + ID uuid.UUID + UserID uuid.UUID + Name string +} + +// ptr is a small helper for passing a Go string into a nullable text column +// (password_hash) via sqlc's generated *string param type. +func ptr(s string) *string { return &s } + +// strFromPtr converts a nullable text column back into a plain string; a +// nil password_hash never happens on the real registration flow (an argon2 +// hash is always supplied), but is handled defensively here. +func strFromPtr(p *string) string { + if p == nil { + return "" + } + return *p +} + +func toUser(u db.User) User { + return User{ID: u.ID, Email: u.Email, PasswordHash: strFromPtr(u.PasswordHash)} +} + +func toProject(p db.Project) Project { + return Project{ID: p.ID, UserID: p.UserID, Name: p.Name} +} + +func (s *Store) CreateUser(ctx context.Context, email, passwordHash string) (User, error) { + u, err := s.q.CreateUser(ctx, db.CreateUserParams{ID: uuid.New(), Email: email, PasswordHash: ptr(passwordHash)}) + if err != nil { + return User{}, err + } + return toUser(u), nil +} + +func (s *Store) GetUserByEmail(ctx context.Context, email string) (User, error) { + u, err := s.q.GetUserByEmail(ctx, email) + if err != nil { + return User{}, err + } + return toUser(u), nil +} + +func (s *Store) CreateProjectForUser(ctx context.Context, userID uuid.UUID, name string) (Project, error) { + p, err := s.q.CreateProject(ctx, db.CreateProjectParams{ID: uuid.New(), UserID: userID, Name: name}) + if err != nil { + return Project{}, err + } + return toProject(p), nil +} + +func (s *Store) GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (Project, error) { + p, err := s.q.GetProjectOwned(ctx, db.GetProjectOwnedParams{ID: projectID, UserID: userID}) + if err != nil { + return Project{}, err + } + return toProject(p), nil +} + +func (s *Store) GetUserProject(ctx context.Context, userID uuid.UUID) (Project, error) { + p, err := s.q.GetUserProject(ctx, userID) + if err != nil { + return Project{}, err + } + return toProject(p), nil +} + +func (s *Store) CreateSession(ctx context.Context, userID uuid.UUID, tokenHash string, expiresAt time.Time) error { + return s.q.CreateSession(ctx, db.CreateSessionParams{ + ID: uuid.New(), + UserID: userID, + TokenHash: tokenHash, + ExpiresAt: pgtype.Timestamptz{Time: expiresAt, Valid: true}, + }) +} + +// GetSessionUser returns the owning user ID for a non-expired session token; +// expired sessions are excluded by the query itself (expires_at > now()). +func (s *Store) GetSessionUser(ctx context.Context, tokenHash string) (uuid.UUID, error) { + return s.q.GetSessionUser(ctx, tokenHash) +} + +func (s *Store) DeleteSession(ctx context.Context, tokenHash string) error { + return s.q.DeleteSession(ctx, tokenHash) +} + +// RegisterUser creates a user and their default project in one transaction, +// mirroring the ImportDomains pattern above: if project creation fails, the +// user insert is rolled back too, so a caller never observes a user without +// a default project. +func (s *Store) RegisterUser(ctx context.Context, email, passwordHash string) (User, Project, error) { + tx, err := s.pool.Begin(ctx) + if err != nil { + return User{}, Project{}, err + } + defer tx.Rollback(ctx) // no-op once Commit has succeeded + + q := s.q.WithTx(tx) + uid := uuid.New() + dbu, err := q.CreateUser(ctx, db.CreateUserParams{ID: uid, Email: email, PasswordHash: ptr(passwordHash)}) + if err != nil { + return User{}, Project{}, err + } + dbp, err := q.CreateProject(ctx, db.CreateProjectParams{ID: uuid.New(), UserID: uid, Name: "default"}) + if err != nil { + return User{}, Project{}, err + } + if err := tx.Commit(ctx); err != nil { + return User{}, Project{}, err + } + return toUser(dbu), toProject(dbp), nil +} From 12b7945efcb61b42ea3944c776a681a7b078f345 Mon Sep 17 00:00:00 2001 From: Vassiliy Yegorov Date: Fri, 3 Jul 2026 19:50:11 +0700 Subject: [PATCH 02/10] =?UTF-8?q?feat(auth):=20argon2id=20=D0=BF=D0=B0?= =?UTF-8?q?=D1=80=D0=BE=D0=BB=D0=B8=20+=20session=20store=20(sha256=20?= =?UTF-8?q?=D1=82=D0=BE=D0=BA=D0=B5=D0=BD=D0=B0)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 6 ++-- go.sum | 16 +++++----- internal/auth/password.go | 52 +++++++++++++++++++++++++++++++ internal/auth/password_test.go | 29 ++++++++++++++++++ internal/auth/session.go | 56 ++++++++++++++++++++++++++++++++++ internal/auth/session_test.go | 55 +++++++++++++++++++++++++++++++++ 6 files changed, 203 insertions(+), 11 deletions(-) create mode 100644 internal/auth/password.go create mode 100644 internal/auth/password_test.go create mode 100644 internal/auth/session.go create mode 100644 internal/auth/session_test.go diff --git a/go.mod b/go.mod index 986f736..59e8617 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/pressly/goose/v3 v3.27.2 github.com/testcontainers/testcontainers-go v0.43.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.43.0 + golang.org/x/crypto v0.53.0 ) require ( @@ -64,9 +65,8 @@ require ( go.opentelemetry.io/otel/metric v1.43.0 // indirect go.opentelemetry.io/otel/trace v1.43.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/crypto v0.52.0 // indirect golang.org/x/sync v0.21.0 // indirect - golang.org/x/sys v0.45.0 // indirect - golang.org/x/text v0.37.0 // indirect + golang.org/x/sys v0.46.0 // indirect + golang.org/x/text v0.38.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 23b3e13..92810b7 100644 --- a/go.sum +++ b/go.sum @@ -150,19 +150,19 @@ go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09 go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= -golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988= -golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc= +golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto= +golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio= golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM= golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY= -golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= -golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4= -golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk= -golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= -golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= +golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw= +golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/term v0.44.0 h1:0rLvDRCtNj0gZkyIXhCyOb2OAzEhLVqc4B+hrsBhrmc= +golang.org/x/term v0.44.0/go.mod h1:7ze4MdzUzLXpSAoFP1H0bOI9aXDqveSvatT5vKcFh2Y= +golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE= +golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/internal/auth/password.go b/internal/auth/password.go new file mode 100644 index 0000000..e88b04b --- /dev/null +++ b/internal/auth/password.go @@ -0,0 +1,52 @@ +package auth + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "fmt" + "strings" + + "golang.org/x/crypto/argon2" +) + +const ( + argonTime = 1 + argonMemory = 64 * 1024 + argonThreads = 4 + argonKeyLen = 32 + argonSaltLen = 16 +) + +func HashPassword(password string) (string, error) { + salt := make([]byte, argonSaltLen) + if _, err := rand.Read(salt); err != nil { + return "", err + } + key := argon2.IDKey([]byte(password), salt, argonTime, argonMemory, argonThreads, argonKeyLen) + b64 := base64.RawStdEncoding.EncodeToString + return fmt.Sprintf("$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s", + argonMemory, argonTime, argonThreads, b64(salt), b64(key)), nil +} + +func VerifyPassword(encoded, password string) (bool, error) { + parts := strings.Split(encoded, "$") + if len(parts) != 6 || parts[1] != "argon2id" { + return false, fmt.Errorf("auth: bad hash format") + } + var m, t uint32 + var p uint8 + if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &m, &t, &p); err != nil { + return false, err + } + salt, err := base64.RawStdEncoding.DecodeString(parts[4]) + if err != nil { + return false, err + } + want, err := base64.RawStdEncoding.DecodeString(parts[5]) + if err != nil { + return false, err + } + got := argon2.IDKey([]byte(password), salt, t, m, p, uint32(len(want))) + return subtle.ConstantTimeCompare(got, want) == 1, nil +} diff --git a/internal/auth/password_test.go b/internal/auth/password_test.go new file mode 100644 index 0000000..61c151e --- /dev/null +++ b/internal/auth/password_test.go @@ -0,0 +1,29 @@ +package auth + +import "testing" + +func TestHashVerifyRoundTrip(t *testing.T) { + h, err := HashPassword("s3cret-pw") + if err != nil { + t.Fatal(err) + } + if h == "s3cret-pw" || len(h) < 20 { + t.Fatalf("bad hash %q", h) + } + ok, err := VerifyPassword(h, "s3cret-pw") + if err != nil || !ok { + t.Fatalf("verify failed: %v %v", ok, err) + } + bad, _ := VerifyPassword(h, "wrong") + if bad { + t.Fatal("wrong password must not verify") + } +} + +func TestHashNonDeterministic(t *testing.T) { + a, _ := HashPassword("same") + b, _ := HashPassword("same") + if a == b { + t.Fatal("salt must randomize hash") + } +} diff --git a/internal/auth/session.go b/internal/auth/session.go new file mode 100644 index 0000000..cf34883 --- /dev/null +++ b/internal/auth/session.go @@ -0,0 +1,56 @@ +package auth + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "errors" + "time" + + "github.com/google/uuid" +) + +var ErrNoSession = errors.New("auth: no such session") + +type SessionStore interface { + CreateSession(ctx context.Context, userID uuid.UUID, tokenHash string, expiresAt time.Time) error + GetSessionUser(ctx context.Context, tokenHash string) (uuid.UUID, error) + DeleteSession(ctx context.Context, tokenHash string) error +} + +type Sessions struct { + store SessionStore + ttl time.Duration +} + +func NewSessions(store SessionStore, ttl time.Duration) *Sessions { + return &Sessions{store: store, ttl: ttl} +} + +func TokenHash(token string) string { + sum := sha256.Sum256([]byte(token)) + return hex.EncodeToString(sum[:]) +} + +func (s *Sessions) Create(ctx context.Context, userID uuid.UUID) (string, time.Time, error) { + raw := make([]byte, 32) + if _, err := rand.Read(raw); err != nil { + return "", time.Time{}, err + } + token := base64.RawURLEncoding.EncodeToString(raw) + exp := time.Now().Add(s.ttl) + if err := s.store.CreateSession(ctx, userID, TokenHash(token), exp); err != nil { + return "", time.Time{}, err + } + return token, exp, nil +} + +func (s *Sessions) Validate(ctx context.Context, token string) (uuid.UUID, error) { + return s.store.GetSessionUser(ctx, TokenHash(token)) +} + +func (s *Sessions) Destroy(ctx context.Context, token string) error { + return s.store.DeleteSession(ctx, TokenHash(token)) +} diff --git a/internal/auth/session_test.go b/internal/auth/session_test.go new file mode 100644 index 0000000..6dad7da --- /dev/null +++ b/internal/auth/session_test.go @@ -0,0 +1,55 @@ +package auth + +import ( + "context" + "testing" + "time" + + "github.com/google/uuid" +) + +type memStore struct { + byHash map[string]uuid.UUID + exp map[string]time.Time +} + +func newMem() *memStore { return &memStore{byHash: map[string]uuid.UUID{}, exp: map[string]time.Time{}} } +func (m *memStore) CreateSession(_ context.Context, uid uuid.UUID, h string, e time.Time) error { + m.byHash[h] = uid + m.exp[h] = e + return nil +} +func (m *memStore) GetSessionUser(_ context.Context, h string) (uuid.UUID, error) { + uid, ok := m.byHash[h] + if !ok || time.Now().After(m.exp[h]) { + return uuid.Nil, ErrNoSession + } + return uid, nil +} +func (m *memStore) DeleteSession(_ context.Context, h string) error { delete(m.byHash, h); return nil } + +func TestSessionCreateValidateDestroy(t *testing.T) { + s := NewSessions(newMem(), time.Hour) + uid := uuid.New() + token, exp, err := s.Create(context.Background(), uid) + if err != nil || token == "" || exp.Before(time.Now()) { + t.Fatalf("create: %v %q", err, token) + } + got, err := s.Validate(context.Background(), token) + if err != nil || got != uid { + t.Fatalf("validate: %v %v", got, err) + } + if err := s.Destroy(context.Background(), token); err != nil { + t.Fatal(err) + } + if _, err := s.Validate(context.Background(), token); err == nil { + t.Fatal("destroyed session must not validate") + } +} + +func TestValidateUnknownToken(t *testing.T) { + s := NewSessions(newMem(), time.Hour) + if _, err := s.Validate(context.Background(), "nope"); err == nil { + t.Fatal("unknown token must error") + } +} From a584cf5c37f3e17c4f6265b9bc2df7377b0e5bbe Mon Sep 17 00:00:00 2001 From: Vassiliy Yegorov Date: Fri, 3 Jul 2026 19:58:54 +0700 Subject: [PATCH 03/10] =?UTF-8?q?fix(auth):=20VerifyPassword=20=D0=B2?= =?UTF-8?q?=D0=B0=D0=BB=D0=B8=D0=B4=D0=B8=D1=80=D1=83=D0=B5=D1=82=20=D0=BF?= =?UTF-8?q?=D0=B0=D1=80=D0=B0=D0=BC=D0=B5=D1=82=D1=80=D1=8B/=D0=B2=D0=B5?= =?UTF-8?q?=D1=80=D1=81=D0=B8=D1=8E,=20=D0=BD=D0=B5=20=D0=BF=D0=B0=D0=BD?= =?UTF-8?q?=D0=B8=D0=BA=D1=83=D0=B5=D1=82=20=D0=BD=D0=B0=20=D0=B1=D0=B8?= =?UTF-8?q?=D1=82=D0=BE=D0=BC=20=D1=85=D1=8D=D1=88=D0=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/auth/password.go | 49 +++++++++++++++++++++++++--- internal/auth/password_test.go | 58 ++++++++++++++++++++++++++++++++++ 2 files changed, 103 insertions(+), 4 deletions(-) diff --git a/internal/auth/password.go b/internal/auth/password.go index e88b04b..6127e58 100644 --- a/internal/auth/password.go +++ b/internal/auth/password.go @@ -5,6 +5,8 @@ import ( "crypto/subtle" "encoding/base64" "fmt" + "regexp" + "strconv" "strings" "golang.org/x/crypto/argon2" @@ -16,8 +18,18 @@ const ( argonThreads = 4 argonKeyLen = 32 argonSaltLen = 16 + + // Upper bounds guard against a corrupted/attacker-controlled + // password_hash forcing an oversized argon2 computation (DoS). + argonMaxMemoryKiB = 1 << 21 // 2 GiB in KiB + argonMaxTime = 10 + argonMaxThreads = 16 ) +// paramsRe strictly matches the "m=,t=,p=" parameter segment, +// requiring the whole segment to be consumed (no trailing garbage). +var paramsRe = regexp.MustCompile(`^m=([0-9]+),t=([0-9]+),p=([0-9]+)$`) + func HashPassword(password string) (string, error) { salt := make([]byte, argonSaltLen) if _, err := rand.Read(salt); err != nil { @@ -34,11 +46,40 @@ func VerifyPassword(encoded, password string) (bool, error) { if len(parts) != 6 || parts[1] != "argon2id" { return false, fmt.Errorf("auth: bad hash format") } - var m, t uint32 - var p uint8 - if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &m, &t, &p); err != nil { - return false, err + if parts[2] != "v=19" { + return false, fmt.Errorf("auth: unsupported version") } + + matches := paramsRe.FindStringSubmatch(parts[3]) + if matches == nil { + return false, fmt.Errorf("auth: bad hash format") + } + m64, err := strconv.ParseUint(matches[1], 10, 32) + if err != nil { + return false, fmt.Errorf("auth: bad hash format") + } + t64, err := strconv.ParseUint(matches[2], 10, 32) + if err != nil { + return false, fmt.Errorf("auth: bad hash format") + } + p64, err := strconv.ParseUint(matches[3], 10, 8) + if err != nil { + return false, fmt.Errorf("auth: bad hash format") + } + m, t, p := uint32(m64), uint32(t64), uint8(p64) + + // argon2.IDKey panics if time<1 ("number of rounds too small") or + // threads<1 ("parallelism degree too low"). Reject before calling it. + // Minimum memory per argon2 spec is 8*parallelism (in KiB). + if t < 1 || p < 1 || m < 8*uint32(p) { + return false, fmt.Errorf("auth: bad hash params") + } + // Upper bounds guard against DoS via inflated parameters in a + // corrupted or attacker-controlled stored hash. + if m > argonMaxMemoryKiB || t > argonMaxTime || p > argonMaxThreads { + return false, fmt.Errorf("auth: bad hash params") + } + salt, err := base64.RawStdEncoding.DecodeString(parts[4]) if err != nil { return false, err diff --git a/internal/auth/password_test.go b/internal/auth/password_test.go index 61c151e..035e2d8 100644 --- a/internal/auth/password_test.go +++ b/internal/auth/password_test.go @@ -27,3 +27,61 @@ func TestHashNonDeterministic(t *testing.T) { t.Fatal("salt must randomize hash") } } + +func TestVerifyPasswordBadTimeDoesNotPanic(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("VerifyPassword panicked: %v", r) + } + }() + encoded := "$argon2id$v=19$m=65536,t=0,p=4$c29tZXNhbHRzb21lc2FsdA$c29tZWhhc2hzb21laGFzaA" + ok, err := VerifyPassword(encoded, "anything") + if err == nil { + t.Fatal("expected error for t=0, got nil") + } + if ok { + t.Fatal("expected ok=false for t=0") + } +} + +func TestVerifyPasswordBadThreadsDoesNotPanic(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("VerifyPassword panicked: %v", r) + } + }() + encoded := "$argon2id$v=19$m=65536,t=1,p=0$c29tZXNhbHRzb21lc2FsdA$c29tZWhhc2hzb21laGFzaA" + ok, err := VerifyPassword(encoded, "anything") + if err == nil { + t.Fatal("expected error for p=0, got nil") + } + if ok { + t.Fatal("expected ok=false for p=0") + } +} + +func TestVerifyPasswordUnsupportedVersion(t *testing.T) { + encoded := "$argon2id$v=18$m=65536,t=1,p=4$c29tZXNhbHRzb21lc2FsdA$c29tZWhhc2hzb21laGFzaA" + ok, err := VerifyPassword(encoded, "anything") + if err == nil { + t.Fatal("expected error for unsupported version, got nil") + } + if ok { + t.Fatal("expected ok=false for unsupported version") + } +} + +func TestVerifyPasswordGarbageFormatDoesNotPanic(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("VerifyPassword panicked: %v", r) + } + }() + ok, err := VerifyPassword("notahash", "anything") + if err == nil { + t.Fatal("expected error for garbage format, got nil") + } + if ok { + t.Fatal("expected ok=false for garbage format") + } +} From aa0ef1c6a9421f4775dd6d59ba30dcbbb1a05429 Mon Sep 17 00:00:00 2001 From: Vassiliy Yegorov Date: Fri, 3 Jul 2026 20:11:00 +0700 Subject: [PATCH 04/10] =?UTF-8?q?feat(api):=20auth-=D1=85=D0=B5=D0=BD?= =?UTF-8?q?=D0=B4=D0=BB=D0=B5=D1=80=D1=8B=20register/login/logout/me=20+?= =?UTF-8?q?=20session=20cookie?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/api/api.go | 35 ++++- internal/api/auth_handlers.go | 154 +++++++++++++++++++++ internal/api/auth_test.go | 250 ++++++++++++++++++++++++++++++++++ internal/api/dto.go | 39 +++++- 4 files changed, 473 insertions(+), 5 deletions(-) create mode 100644 internal/api/auth_handlers.go create mode 100644 internal/api/auth_test.go diff --git a/internal/api/api.go b/internal/api/api.go index c3cff8f..9de9798 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -3,6 +3,7 @@ package api import ( "context" "net/http" + "time" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" @@ -54,12 +55,31 @@ type ProviderRegistry interface { ByName(name string) (provider.Provider, error) } +// AuthStore is the persistence surface the auth handlers depend on. +// *store.Store satisfies it directly (see internal/store/store.go); tests +// can supply their own mock. +type AuthStore interface { + RegisterUser(ctx context.Context, email, passwordHash string) (store.User, store.Project, error) + GetUserByEmail(ctx context.Context, email string) (store.User, error) + GetUserProject(ctx context.Context, userID uuid.UUID) (store.Project, error) +} + +// SessionManager creates/validates/destroys login sessions. *auth.Sessions +// satisfies it directly (see internal/auth/session.go). +type SessionManager interface { + Create(ctx context.Context, userID uuid.UUID) (string, time.Time, error) + Validate(ctx context.Context, token string) (uuid.UUID, error) + Destroy(ctx context.Context, token string) error +} + // API holds handler dependencies. type API struct { - Svc CheckApplier - Store TenantStore - Cipher Cipher - Reg ProviderRegistry + Svc CheckApplier + Store TenantStore + Cipher Cipher + Reg ProviderRegistry + Auth AuthStore + Sessions SessionManager } func NewRouter(a *API) http.Handler { @@ -67,6 +87,13 @@ func NewRouter(a *API) http.Handler { r.Use(middleware.RequestID) r.Use(middleware.Recoverer) + r.Route("/api/v1/auth", func(r chi.Router) { + r.Post("/register", a.handleRegister) + r.Post("/login", a.handleLogin) + r.Post("/logout", a.handleLogout) // защитится RequireAuth в Task 4 + r.Get("/me", a.handleMe) // защитится RequireAuth в Task 4 + }) + r.Route("/api/v1/projects/{pid}", func(r chi.Router) { r.Route("/domains", func(r chi.Router) { r.Post("/", a.handleCreateDomain) diff --git a/internal/api/auth_handlers.go b/internal/api/auth_handlers.go new file mode 100644 index 0000000..5ec3c04 --- /dev/null +++ b/internal/api/auth_handlers.go @@ -0,0 +1,154 @@ +package api + +import ( + "context" + "log" + "net/http" + "time" + + "github.com/google/uuid" + + "github.com/vasyakrg/dns-autoresolver/internal/auth" +) + +const sessionCookieName = "session" + +// ctxKeyUserID is a private context key carrying the authenticated user's ID. +// Task 4's RequireAuth middleware sets it after validating the session +// cookie; handleMe reads it back. +type ctxKeyUserID struct{} + +// userIDFromContext extracts the authenticated user ID set by RequireAuth +// (Task 4). Until that middleware is wired in, tests set it directly via +// context.WithValue. +func userIDFromContext(ctx context.Context) (uuid.UUID, bool) { + id, ok := ctx.Value(ctxKeyUserID{}).(uuid.UUID) + return id, ok +} + +func setSessionCookie(w http.ResponseWriter, token string, exp time.Time) { + http.SetCookie(w, &http.Cookie{ + Name: sessionCookieName, Value: token, Path: "/", + HttpOnly: true, Secure: true, SameSite: http.SameSiteLaxMode, Expires: exp, + }) +} + +func clearSessionCookie(w http.ResponseWriter) { + http.SetCookie(w, &http.Cookie{ + Name: sessionCookieName, Value: "", Path: "/", + HttpOnly: true, Secure: true, SameSite: http.SameSiteLaxMode, MaxAge: -1, + }) +} + +func (a *API) handleRegister(w http.ResponseWriter, r *http.Request) { + var req registerRequest + if !decodeBody(w, r, &req) { + return + } + if req.Email == "" || req.Password == "" { + writeErr(w, http.StatusBadRequest, "email and password are required") + return + } + + hash, err := auth.HashPassword(req.Password) + if err != nil { + log.Printf("api: hash password failed: %v", err) + writeErr(w, http.StatusInternalServerError, "internal error") + return + } + + u, p, err := a.Auth.RegisterUser(r.Context(), req.Email, hash) + if err != nil { + log.Printf("api: register user failed: %v", err) + writeErr(w, http.StatusInternalServerError, "internal error") + return + } + + token, exp, err := a.Sessions.Create(r.Context(), u.ID) + if err != nil { + log.Printf("api: create session failed: %v", err) + writeErr(w, http.StatusInternalServerError, "internal error") + return + } + + setSessionCookie(w, token, exp) + writeJSON(w, http.StatusOK, toAuthResponse(u, p)) +} + +// invalidCredentials is deliberately identical for "no such user" and "wrong +// password" — disclosing which one occurred would let an attacker enumerate +// registered emails. +func invalidCredentials(w http.ResponseWriter) { + writeErr(w, http.StatusUnauthorized, "invalid credentials") +} + +func (a *API) handleLogin(w http.ResponseWriter, r *http.Request) { + var req loginRequest + if !decodeBody(w, r, &req) { + return + } + + u, err := a.Auth.GetUserByEmail(r.Context(), req.Email) + if err != nil { + invalidCredentials(w) + return + } + + ok, err := auth.VerifyPassword(u.PasswordHash, req.Password) + if err != nil || !ok { + invalidCredentials(w) + return + } + + p, err := a.Auth.GetUserProject(r.Context(), u.ID) + if err != nil { + log.Printf("api: get user project failed: %v", err) + writeErr(w, http.StatusInternalServerError, "internal error") + return + } + + token, exp, err := a.Sessions.Create(r.Context(), u.ID) + if err != nil { + log.Printf("api: create session failed: %v", err) + writeErr(w, http.StatusInternalServerError, "internal error") + return + } + + setSessionCookie(w, token, exp) + writeJSON(w, http.StatusOK, toAuthResponse(u, p)) +} + +func (a *API) handleLogout(w http.ResponseWriter, r *http.Request) { + if c, err := r.Cookie(sessionCookieName); err == nil && c.Value != "" { + if err := a.Sessions.Destroy(r.Context(), c.Value); err != nil { + log.Printf("api: destroy session failed: %v", err) + } + } + clearSessionCookie(w) + writeJSON(w, http.StatusOK, map[string]string{"status": "ok"}) +} + +// handleMe returns the authenticated caller's identity + default project. +// The user ID comes from the request context, set by Task 4's RequireAuth +// middleware after validating the session cookie (tests set it directly via +// context.WithValue in the interim). AuthStore has no GetUserByID — the +// email field is intentionally left empty here; see task-3-report.md. +func (a *API) handleMe(w http.ResponseWriter, r *http.Request) { + userID, ok := userIDFromContext(r.Context()) + if !ok { + writeErr(w, http.StatusUnauthorized, "authentication required") + return + } + + p, err := a.Auth.GetUserProject(r.Context(), userID) + if err != nil { + log.Printf("api: get user project failed: %v", err) + writeErr(w, http.StatusInternalServerError, "internal error") + return + } + + writeJSON(w, http.StatusOK, authResponse{ + User: userResponse{ID: userID.String()}, + Project: projectResponse{ID: p.ID.String(), Name: p.Name}, + }) +} diff --git a/internal/api/auth_test.go b/internal/api/auth_test.go new file mode 100644 index 0000000..13ecbb9 --- /dev/null +++ b/internal/api/auth_test.go @@ -0,0 +1,250 @@ +package api + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/google/uuid" + + "github.com/vasyakrg/dns-autoresolver/internal/auth" + "github.com/vasyakrg/dns-autoresolver/internal/store" +) + +// --- mocks --- + +type mockAuthStore struct { + registerUserFn func(ctx context.Context, email, passwordHash string) (store.User, store.Project, error) + getUserByEmailFn func(ctx context.Context, email string) (store.User, error) + getUserProjectFn func(ctx context.Context, userID uuid.UUID) (store.Project, error) +} + +func (m *mockAuthStore) RegisterUser(ctx context.Context, email, passwordHash string) (store.User, store.Project, error) { + return m.registerUserFn(ctx, email, passwordHash) +} + +func (m *mockAuthStore) GetUserByEmail(ctx context.Context, email string) (store.User, error) { + return m.getUserByEmailFn(ctx, email) +} + +func (m *mockAuthStore) GetUserProject(ctx context.Context, userID uuid.UUID) (store.Project, error) { + return m.getUserProjectFn(ctx, userID) +} + +type mockSessionManager struct { + createFn func(ctx context.Context, userID uuid.UUID) (string, time.Time, error) + + destroyCalled bool + destroyToken string + destroyErr error +} + +func (m *mockSessionManager) Create(ctx context.Context, userID uuid.UUID) (string, time.Time, error) { + return m.createFn(ctx, userID) +} + +func (m *mockSessionManager) Validate(context.Context, string) (uuid.UUID, error) { + return uuid.Nil, nil +} + +func (m *mockSessionManager) Destroy(ctx context.Context, token string) error { + m.destroyCalled = true + m.destroyToken = token + return m.destroyErr +} + +func newTestAuthAPI() (*API, *mockAuthStore, *mockSessionManager) { + authStore := &mockAuthStore{} + sessions := &mockSessionManager{ + createFn: func(_ context.Context, userID uuid.UUID) (string, time.Time, error) { + return "test-token", time.Now().Add(time.Hour), nil + }, + } + return &API{Auth: authStore, Sessions: sessions}, authStore, sessions +} + +func findCookie(resp *http.Response, name string) *http.Cookie { + for _, c := range resp.Cookies() { + if c.Name == name { + return c + } + } + return nil +} + +// --- register --- + +func TestAuthRegister_Success(t *testing.T) { + a, authStore, _ := newTestAuthAPI() + userID := uuid.New() + projectID := uuid.New() + authStore.registerUserFn = func(_ context.Context, email, passwordHash string) (store.User, store.Project, error) { + if passwordHash == "" { + t.Fatal("expected non-empty password hash passed to RegisterUser") + } + return store.User{ID: userID, Email: email, PasswordHash: passwordHash}, + store.Project{ID: projectID, UserID: userID, Name: "default"}, nil + } + + router := NewRouter(a) + body := `{"email":"alice@example.com","password":"correct-horse"}` + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", strings.NewReader(body)) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status %d, body %s", w.Code, w.Body.String()) + } + + resp := w.Result() + cookie := findCookie(resp, sessionCookieName) + if cookie == nil { + t.Fatal("expected session cookie to be set") + } + if cookie.Value != "test-token" { + t.Fatalf("unexpected cookie value: %q", cookie.Value) + } + + if strings.Contains(w.Body.String(), "password") { + t.Fatalf("response body must not contain password/password_hash: %s", w.Body.String()) + } + + var got authResponse + if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil { + t.Fatal(err) + } + if got.User.ID != userID.String() || got.User.Email != "alice@example.com" { + t.Fatalf("unexpected user in response: %+v", got.User) + } + if got.Project.ID != projectID.String() || got.Project.Name != "default" { + t.Fatalf("unexpected project in response: %+v", got.Project) + } +} + +// --- login --- + +func TestAuthLogin_CorrectPassword(t *testing.T) { + a, authStore, _ := newTestAuthAPI() + hash, err := auth.HashPassword("correct-horse") + if err != nil { + t.Fatal(err) + } + userID := uuid.New() + projectID := uuid.New() + authStore.getUserByEmailFn = func(_ context.Context, email string) (store.User, error) { + return store.User{ID: userID, Email: email, PasswordHash: hash}, nil + } + authStore.getUserProjectFn = func(_ context.Context, uid uuid.UUID) (store.Project, error) { + return store.Project{ID: projectID, UserID: uid, Name: "default"}, nil + } + + router := NewRouter(a) + body := `{"email":"bob@example.com","password":"correct-horse"}` + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", strings.NewReader(body)) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status %d, body %s", w.Code, w.Body.String()) + } + if findCookie(w.Result(), sessionCookieName) == nil { + t.Fatal("expected session cookie to be set") + } + if strings.Contains(w.Body.String(), "password") { + t.Fatalf("response body must not contain password/password_hash: %s", w.Body.String()) + } +} + +func TestAuthLogin_WrongPassword(t *testing.T) { + a, authStore, _ := newTestAuthAPI() + hash, err := auth.HashPassword("correct-horse") + if err != nil { + t.Fatal(err) + } + authStore.getUserByEmailFn = func(_ context.Context, email string) (store.User, error) { + return store.User{ID: uuid.New(), Email: email, PasswordHash: hash}, nil + } + + router := NewRouter(a) + body := `{"email":"bob@example.com","password":"wrong-password"}` + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", strings.NewReader(body)) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + assertInvalidCredentials(t, w) +} + +func TestAuthLogin_UnknownEmail(t *testing.T) { + a, authStore, _ := newTestAuthAPI() + authStore.getUserByEmailFn = func(_ context.Context, email string) (store.User, error) { + return store.User{}, errNoRowsForTest + } + + router := NewRouter(a) + body := `{"email":"nobody@example.com","password":"whatever"}` + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", strings.NewReader(body)) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + assertInvalidCredentials(t, w) +} + +func assertInvalidCredentials(t *testing.T, w *httptest.ResponseRecorder) { + t.Helper() + if w.Code != http.StatusUnauthorized { + t.Fatalf("status %d, body %s", w.Code, w.Body.String()) + } + var got map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil { + t.Fatal(err) + } + if got["error"] != "invalid credentials" { + t.Fatalf(`expected error "invalid credentials", got %q`, got["error"]) + } + if findCookie(w.Result(), sessionCookieName) != nil { + t.Fatal("expected no session cookie on failed login") + } +} + +// --- logout --- + +func TestAuthLogout_ClearsSessionAndDestroys(t *testing.T) { + a, _, sessions := newTestAuthAPI() + + router := NewRouter(a) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", nil) + req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "some-token"}) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status %d, body %s", w.Code, w.Body.String()) + } + if !sessions.destroyCalled { + t.Fatal("expected Sessions.Destroy to be called") + } + if sessions.destroyToken != "some-token" { + t.Fatalf("expected Destroy called with cookie token, got %q", sessions.destroyToken) + } + + cookie := findCookie(w.Result(), sessionCookieName) + if cookie == nil { + t.Fatal("expected a session cookie in the response (clearing cookie)") + } + if cookie.MaxAge > 0 { + t.Fatalf("expected cookie MaxAge <= 0 (cleared), got %d", cookie.MaxAge) + } +} + +// errNoRowsForTest stands in for a "not found" error a real store would +// return (e.g. pgx.ErrNoRows) — handlers must not distinguish it from any +// other GetUserByEmail failure in the response they send. +var errNoRowsForTest = ¬FoundErr{} + +type notFoundErr struct{} + +func (*notFoundErr) Error() string { return "not found" } diff --git a/internal/api/dto.go b/internal/api/dto.go index fdc5a13..b127976 100644 --- a/internal/api/dto.go +++ b/internal/api/dto.go @@ -1,6 +1,43 @@ package api -import "github.com/vasyakrg/dns-autoresolver/internal/diff" +import ( + "github.com/vasyakrg/dns-autoresolver/internal/diff" + "github.com/vasyakrg/dns-autoresolver/internal/store" +) + +type registerRequest struct { + Email string `json:"email"` + Password string `json:"password"` +} + +type loginRequest struct { + Email string `json:"email"` + Password string `json:"password"` +} + +// userResponse and projectResponse deliberately expose only id/email and +// id/name — password_hash must never reach a client response. +type userResponse struct { + ID string `json:"id"` + Email string `json:"email"` +} + +type projectResponse struct { + ID string `json:"id"` + Name string `json:"name"` +} + +type authResponse struct { + User userResponse `json:"user"` + Project projectResponse `json:"project"` +} + +func toAuthResponse(u store.User, p store.Project) authResponse { + return authResponse{ + User: userResponse{ID: u.ID.String(), Email: u.Email}, + Project: projectResponse{ID: p.ID.String(), Name: p.Name}, + } +} type applyRequest struct { ApplyUpdates bool `json:"applyUpdates"` From 35ffe73ae33d970ceda989cd2999aadf8de85071 Mon Sep 17 00:00:00 2001 From: Vassiliy Yegorov Date: Fri, 3 Jul 2026 20:29:05 +0700 Subject: [PATCH 05/10] =?UTF-8?q?fix(auth):=20wiring=20Auth/Sessions,=20?= =?UTF-8?q?=D0=BD=D0=BE=D1=80=D0=BC=D0=B0=D0=BB=D0=B8=D0=B7=D0=B0=D1=86?= =?UTF-8?q?=D0=B8=D1=8F=20email,=20GetUserByID=20=D0=B4=D0=BB=D1=8F=20/me,?= =?UTF-8?q?=20409=20=D0=BD=D0=B0=20=D0=B4=D1=83=D0=B1=D0=BB=D1=8C,=20timin?= =?UTF-8?q?g-guard=20=D0=BB=D0=BE=D0=B3=D0=B8=D0=BD=D0=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/server/main.go | 9 ++- internal/api/api.go | 1 + internal/api/auth_handlers.go | 57 ++++++++++--- internal/api/auth_test.go | 132 +++++++++++++++++++++++++++++++ internal/store/auth_test.go | 36 +++++++++ internal/store/db/users.sql.go | 16 ++++ internal/store/queries/users.sql | 3 + internal/store/tenant.go | 21 +++++ 8 files changed, 265 insertions(+), 10 deletions(-) diff --git a/cmd/server/main.go b/cmd/server/main.go index 308c423..7ab73fc 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -5,10 +5,12 @@ import ( "log" "net/http" "strings" + "time" "github.com/jackc/pgx/v5/pgxpool" "github.com/vasyakrg/dns-autoresolver/internal/api" + "github.com/vasyakrg/dns-autoresolver/internal/auth" "github.com/vasyakrg/dns-autoresolver/internal/config" "github.com/vasyakrg/dns-autoresolver/internal/crypto" "github.com/vasyakrg/dns-autoresolver/internal/provider/registry" @@ -18,6 +20,10 @@ import ( "github.com/vasyakrg/dns-autoresolver/internal/web" ) +// sessionTTL is how long a login session cookie remains valid before the +// user must re-authenticate. +const sessionTTL = 720 * time.Hour + // isAPIPath reports whether path must be routed to the API router rather // than the SPA. "/api" (no trailing slash) counts as an API path too — // only strings.HasPrefix(path, "/api/") would otherwise miss it and fall @@ -46,12 +52,13 @@ func main() { log.Fatalf("cipher: %v", err) } st := store.New(pool) + sessions := auth.NewSessions(st, sessionTTL) reg := registry.New() reg.Register(selectel.New()) svc := service.New(st, st, reg, cipher) - a := &api.API{Svc: svc, Store: st, Cipher: cipher, Reg: reg} + a := &api.API{Svc: svc, Store: st, Cipher: cipher, Reg: reg, Auth: st, Sessions: sessions} apiRouter := api.NewRouter(a) webHandler, err := web.Handler() diff --git a/internal/api/api.go b/internal/api/api.go index 9de9798..16710cd 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -61,6 +61,7 @@ type ProviderRegistry interface { type AuthStore interface { RegisterUser(ctx context.Context, email, passwordHash string) (store.User, store.Project, error) GetUserByEmail(ctx context.Context, email string) (store.User, error) + GetUserByID(ctx context.Context, userID uuid.UUID) (store.User, error) GetUserProject(ctx context.Context, userID uuid.UUID) (store.Project, error) } diff --git a/internal/api/auth_handlers.go b/internal/api/auth_handlers.go index 5ec3c04..76d3c38 100644 --- a/internal/api/auth_handlers.go +++ b/internal/api/auth_handlers.go @@ -2,17 +2,43 @@ package api import ( "context" + "errors" "log" "net/http" + "strings" "time" "github.com/google/uuid" "github.com/vasyakrg/dns-autoresolver/internal/auth" + "github.com/vasyakrg/dns-autoresolver/internal/store" ) const sessionCookieName = "session" +// dummyPasswordHash is a valid-format argon2 hash with no real matching +// password. handleLogin runs VerifyPassword against it whenever the email +// lookup fails, so a login attempt for an unregistered email takes the same +// wall-clock time as one for a registered email with a wrong password — +// otherwise the timing difference would let an attacker enumerate which +// emails are registered. +var dummyPasswordHash string + +func init() { + h, err := auth.HashPassword("dns-autoresolver-timing-guard-dummy") + if err != nil { + panic("api: failed to initialize dummy password hash: " + err.Error()) + } + dummyPasswordHash = h +} + +// normalizeEmail trims surrounding whitespace and lowercases the email so +// storage and lookup are always consistent regardless of how the client +// cased or padded the input. +func normalizeEmail(email string) string { + return strings.ToLower(strings.TrimSpace(email)) +} + // ctxKeyUserID is a private context key carrying the authenticated user's ID. // Task 4's RequireAuth middleware sets it after validating the session // cookie; handleMe reads it back. @@ -45,7 +71,8 @@ func (a *API) handleRegister(w http.ResponseWriter, r *http.Request) { if !decodeBody(w, r, &req) { return } - if req.Email == "" || req.Password == "" { + email := normalizeEmail(req.Email) + if email == "" || req.Password == "" { writeErr(w, http.StatusBadRequest, "email and password are required") return } @@ -57,8 +84,12 @@ func (a *API) handleRegister(w http.ResponseWriter, r *http.Request) { return } - u, p, err := a.Auth.RegisterUser(r.Context(), req.Email, hash) + u, p, err := a.Auth.RegisterUser(r.Context(), email, hash) if err != nil { + if errors.Is(err, store.ErrEmailTaken) { + writeErr(w, http.StatusConflict, "email already registered") + return + } log.Printf("api: register user failed: %v", err) writeErr(w, http.StatusInternalServerError, "internal error") return @@ -87,9 +118,14 @@ func (a *API) handleLogin(w http.ResponseWriter, r *http.Request) { if !decodeBody(w, r, &req) { return } + email := normalizeEmail(req.Email) - u, err := a.Auth.GetUserByEmail(r.Context(), req.Email) + u, err := a.Auth.GetUserByEmail(r.Context(), email) if err != nil { + // No such user: still spend the argon2 verification cost against a + // fixed dummy hash (see dummyPasswordHash) so this path isn't + // distinguishable by timing from a wrong-password rejection below. + _, _ = auth.VerifyPassword(dummyPasswordHash, req.Password) invalidCredentials(w) return } @@ -131,8 +167,7 @@ func (a *API) handleLogout(w http.ResponseWriter, r *http.Request) { // handleMe returns the authenticated caller's identity + default project. // The user ID comes from the request context, set by Task 4's RequireAuth // middleware after validating the session cookie (tests set it directly via -// context.WithValue in the interim). AuthStore has no GetUserByID — the -// email field is intentionally left empty here; see task-3-report.md. +// context.WithValue in the interim). func (a *API) handleMe(w http.ResponseWriter, r *http.Request) { userID, ok := userIDFromContext(r.Context()) if !ok { @@ -140,6 +175,13 @@ func (a *API) handleMe(w http.ResponseWriter, r *http.Request) { return } + u, err := a.Auth.GetUserByID(r.Context(), userID) + if err != nil { + log.Printf("api: get user by id failed: %v", err) + writeErr(w, http.StatusInternalServerError, "internal error") + return + } + p, err := a.Auth.GetUserProject(r.Context(), userID) if err != nil { log.Printf("api: get user project failed: %v", err) @@ -147,8 +189,5 @@ func (a *API) handleMe(w http.ResponseWriter, r *http.Request) { return } - writeJSON(w, http.StatusOK, authResponse{ - User: userResponse{ID: userID.String()}, - Project: projectResponse{ID: p.ID.String(), Name: p.Name}, - }) + writeJSON(w, http.StatusOK, toAuthResponse(u, p)) } diff --git a/internal/api/auth_test.go b/internal/api/auth_test.go index 13ecbb9..58f9402 100644 --- a/internal/api/auth_test.go +++ b/internal/api/auth_test.go @@ -20,6 +20,7 @@ import ( type mockAuthStore struct { registerUserFn func(ctx context.Context, email, passwordHash string) (store.User, store.Project, error) getUserByEmailFn func(ctx context.Context, email string) (store.User, error) + getUserByIDFn func(ctx context.Context, userID uuid.UUID) (store.User, error) getUserProjectFn func(ctx context.Context, userID uuid.UUID) (store.Project, error) } @@ -31,6 +32,10 @@ func (m *mockAuthStore) GetUserByEmail(ctx context.Context, email string) (store return m.getUserByEmailFn(ctx, email) } +func (m *mockAuthStore) GetUserByID(ctx context.Context, userID uuid.UUID) (store.User, error) { + return m.getUserByIDFn(ctx, userID) +} + func (m *mockAuthStore) GetUserProject(ctx context.Context, userID uuid.UUID) (store.Project, error) { return m.getUserProjectFn(ctx, userID) } @@ -125,6 +130,59 @@ func TestAuthRegister_Success(t *testing.T) { } } +// TestAuthRegister_NormalizesEmail verifies the fix for the email-consistency +// gap: a padded/mixed-case email is trimmed+lowercased before it reaches the +// store, so storage and later lookups are always consistent. +func TestAuthRegister_NormalizesEmail(t *testing.T) { + a, authStore, _ := newTestAuthAPI() + userID := uuid.New() + var gotEmail string + authStore.registerUserFn = func(_ context.Context, email, passwordHash string) (store.User, store.Project, error) { + gotEmail = email + return store.User{ID: userID, Email: email}, store.Project{ID: uuid.New(), UserID: userID, Name: "default"}, nil + } + + router := NewRouter(a) + body := `{"email":" Alice@X.com ","password":"correct-horse"}` + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", strings.NewReader(body)) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status %d, body %s", w.Code, w.Body.String()) + } + if gotEmail != "alice@x.com" { + t.Fatalf("expected normalized email passed to RegisterUser, got %q", gotEmail) + } +} + +// TestAuthRegister_DuplicateEmailReturns409 verifies the fix for the +// duplicate-registration gap: RegisterUser reporting store.ErrEmailTaken +// must surface as 409, not a generic 500. +func TestAuthRegister_DuplicateEmailReturns409(t *testing.T) { + a, authStore, _ := newTestAuthAPI() + authStore.registerUserFn = func(context.Context, string, string) (store.User, store.Project, error) { + return store.User{}, store.Project{}, store.ErrEmailTaken + } + + router := NewRouter(a) + body := `{"email":"dup@example.com","password":"correct-horse"}` + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", strings.NewReader(body)) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusConflict { + t.Fatalf("status %d, body %s", w.Code, w.Body.String()) + } + var got map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil { + t.Fatal(err) + } + if got["error"] != "email already registered" { + t.Fatalf(`expected error "email already registered", got %q`, got["error"]) + } +} + // --- login --- func TestAuthLogin_CorrectPassword(t *testing.T) { @@ -159,6 +217,40 @@ func TestAuthLogin_CorrectPassword(t *testing.T) { } } +// TestAuthLogin_NormalizesEmail verifies that a login for a padded/mixed-case +// email reaches GetUserByEmail already trimmed+lowercased — the same +// normalization applied on register, so "Alice@X.com" at registration and +// "alice@x.com" at login resolve to the same account. +func TestAuthLogin_NormalizesEmail(t *testing.T) { + a, authStore, _ := newTestAuthAPI() + hash, err := auth.HashPassword("correct-horse") + if err != nil { + t.Fatal(err) + } + userID := uuid.New() + var gotEmail string + authStore.getUserByEmailFn = func(_ context.Context, email string) (store.User, error) { + gotEmail = email + return store.User{ID: userID, Email: email, PasswordHash: hash}, nil + } + authStore.getUserProjectFn = func(_ context.Context, uid uuid.UUID) (store.Project, error) { + return store.Project{ID: uuid.New(), UserID: uid, Name: "default"}, nil + } + + router := NewRouter(a) + body := `{"email":" Alice@X.com ","password":"correct-horse"}` + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", strings.NewReader(body)) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status %d, body %s", w.Code, w.Body.String()) + } + if gotEmail != "alice@x.com" { + t.Fatalf("expected normalized email passed to GetUserByEmail, got %q", gotEmail) + } +} + func TestAuthLogin_WrongPassword(t *testing.T) { a, authStore, _ := newTestAuthAPI() hash, err := auth.HashPassword("correct-horse") @@ -240,6 +332,46 @@ func TestAuthLogout_ClearsSessionAndDestroys(t *testing.T) { } } +// --- me --- + +// TestAuthMe_ReturnsRealEmail verifies the fix for the /me gap: the handler +// now resolves the authenticated user via GetUserByID and returns their real +// email, instead of leaving it blank. +func TestAuthMe_ReturnsRealEmail(t *testing.T) { + a, authStore, _ := newTestAuthAPI() + userID := uuid.New() + projectID := uuid.New() + authStore.getUserByIDFn = func(_ context.Context, id uuid.UUID) (store.User, error) { + if id != userID { + t.Fatalf("unexpected user id: %s", id) + } + return store.User{ID: userID, Email: "me@example.com"}, nil + } + authStore.getUserProjectFn = func(_ context.Context, uid uuid.UUID) (store.Project, error) { + return store.Project{ID: projectID, UserID: uid, Name: "default"}, nil + } + + router := NewRouter(a) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil) + req = req.WithContext(context.WithValue(req.Context(), ctxKeyUserID{}, userID)) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status %d, body %s", w.Code, w.Body.String()) + } + var got authResponse + if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil { + t.Fatal(err) + } + if got.User.ID != userID.String() || got.User.Email != "me@example.com" { + t.Fatalf("unexpected user in /me response: %+v", got.User) + } + if got.Project.ID != projectID.String() { + t.Fatalf("unexpected project in /me response: %+v", got.Project) + } +} + // errNoRowsForTest stands in for a "not found" error a real store would // return (e.g. pgx.ErrNoRows) — handlers must not distinguish it from any // other GetUserByEmail failure in the response they send. diff --git a/internal/store/auth_test.go b/internal/store/auth_test.go index 2bedd35..21092df 100644 --- a/internal/store/auth_test.go +++ b/internal/store/auth_test.go @@ -1,6 +1,7 @@ package store import ( + "errors" "testing" "time" @@ -57,6 +58,41 @@ func TestGetUserByEmail_FindsRegisteredUser(t *testing.T) { } } +// TestRegisterUser_DuplicateEmailReturnsErrEmailTaken verifies the fix for +// the duplicate-registration gap: a second RegisterUser call for an +// already-taken email must fail with the ErrEmailTaken sentinel (mapped from +// the UNIQUE constraint violation on users.email), not a generic pgx error. +func TestRegisterUser_DuplicateEmailReturnsErrEmailTaken(t *testing.T) { + s, ctx := newStore(t) + + if _, _, err := s.RegisterUser(ctx, "dup@example.com", "argon2-hash"); err != nil { + t.Fatal(err) + } + + if _, _, err := s.RegisterUser(ctx, "dup@example.com", "argon2-hash"); !errors.Is(err, ErrEmailTaken) { + t.Fatalf("expected ErrEmailTaken, got %v", err) + } +} + +// TestGetUserByID_ReturnsUser verifies the fix for the /me gap: GetUserByID +// returns the same user created by RegisterUser, including their real email. +func TestGetUserByID_ReturnsUser(t *testing.T) { + s, ctx := newStore(t) + + u, _, err := s.RegisterUser(ctx, "gina@example.com", "argon2-hash") + if err != nil { + t.Fatal(err) + } + + got, err := s.GetUserByID(ctx, u.ID) + if err != nil { + t.Fatal(err) + } + if got.ID != u.ID || got.Email != "gina@example.com" { + t.Fatalf("unexpected user: %+v", got) + } +} + // TestSessionLifecycle_CreateGetDelete verifies CreateSession + GetSessionUser // round-trips to the owning user ID, an expired session is excluded from // GetSessionUser, and DeleteSession removes the session. diff --git a/internal/store/db/users.sql.go b/internal/store/db/users.sql.go index c76c410..8fef920 100644 --- a/internal/store/db/users.sql.go +++ b/internal/store/db/users.sql.go @@ -48,3 +48,19 @@ func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error ) return i, err } + +const getUserByID = `-- name: GetUserByID :one +SELECT id, email, created_at, password_hash FROM users WHERE id = $1 +` + +func (q *Queries) GetUserByID(ctx context.Context, id uuid.UUID) (User, error) { + row := q.db.QueryRow(ctx, getUserByID, id) + var i User + err := row.Scan( + &i.ID, + &i.Email, + &i.CreatedAt, + &i.PasswordHash, + ) + return i, err +} diff --git a/internal/store/queries/users.sql b/internal/store/queries/users.sql index d68dc45..31d7fde 100644 --- a/internal/store/queries/users.sql +++ b/internal/store/queries/users.sql @@ -3,3 +3,6 @@ INSERT INTO users (id, email, password_hash) VALUES ($1, $2, $3) RETURNING *; -- name: GetUserByEmail :one SELECT * FROM users WHERE email = $1; + +-- name: GetUserByID :one +SELECT * FROM users WHERE id = $1; diff --git a/internal/store/tenant.go b/internal/store/tenant.go index 37961a0..fccc469 100644 --- a/internal/store/tenant.go +++ b/internal/store/tenant.go @@ -7,6 +7,7 @@ import ( "github.com/google/uuid" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" "github.com/vasyakrg/dns-autoresolver/internal/provider" @@ -14,6 +15,11 @@ import ( "github.com/vasyakrg/dns-autoresolver/internal/store/dto" ) +// ErrEmailTaken is returned by RegisterUser when the email is already +// registered — a UNIQUE constraint violation (pgerrcode 23505) on +// users.email. +var ErrEmailTaken = errors.New("store: email already registered") + // Account/Template/Domain are provider-neutral domain structs returned by the // thin wrappers below, so callers (internal/api) never need to import // internal/store/db directly. @@ -279,6 +285,17 @@ func (s *Store) GetUserByEmail(ctx context.Context, email string) (User, error) return toUser(u), nil } +// GetUserByID looks up a user by primary key — used by handleMe (Task 3 +// hardening) to return the authenticated caller's real email instead of +// leaving it blank. +func (s *Store) GetUserByID(ctx context.Context, id uuid.UUID) (User, error) { + u, err := s.q.GetUserByID(ctx, id) + if err != nil { + return User{}, err + } + return toUser(u), nil +} + func (s *Store) CreateProjectForUser(ctx context.Context, userID uuid.UUID, name string) (Project, error) { p, err := s.q.CreateProject(ctx, db.CreateProjectParams{ID: uuid.New(), UserID: userID, Name: name}) if err != nil { @@ -337,6 +354,10 @@ func (s *Store) RegisterUser(ctx context.Context, email, passwordHash string) (U uid := uuid.New() dbu, err := q.CreateUser(ctx, db.CreateUserParams{ID: uid, Email: email, PasswordHash: ptr(passwordHash)}) if err != nil { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && pgErr.Code == "23505" { + return User{}, Project{}, ErrEmailTaken + } return User{}, Project{}, err } dbp, err := q.CreateProject(ctx, db.CreateProjectParams{ID: uuid.New(), UserID: uid, Name: "default"}) From 4533b0ca25d3cdb9d4c847af27793b776c7cbcad Mon Sep 17 00:00:00 2001 From: Vassiliy Yegorov Date: Fri, 3 Jul 2026 20:47:40 +0700 Subject: [PATCH 06/10] =?UTF-8?q?feat(api):=20RequireAuth+RequireProjectAc?= =?UTF-8?q?cess=20middleware,=20IDOR-scope=20check/apply=20=D0=BF=D0=BE=20?= =?UTF-8?q?projectID?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/api/api.go | 18 +- internal/api/api_test.go | 21 ++- internal/api/auth_handlers.go | 23 +-- internal/api/auth_test.go | 30 +++- internal/api/handlers.go | 10 +- internal/api/middleware.go | 75 ++++++++ internal/api/middleware_test.go | 265 +++++++++++++++++++++++++++++ internal/api/tenant_handlers.go | 96 ++++------- internal/api/tenant_test.go | 48 +++--- internal/service/service.go | 16 +- internal/service/service_test.go | 10 +- internal/store/db/domains.sql.go | 11 +- internal/store/loader.go | 8 +- internal/store/loader_test.go | 4 +- internal/store/queries/domains.sql | 2 +- internal/store/store_test.go | 4 +- 16 files changed, 498 insertions(+), 143 deletions(-) create mode 100644 internal/api/middleware.go create mode 100644 internal/api/middleware_test.go diff --git a/internal/api/api.go b/internal/api/api.go index 16710cd..cd01cf4 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -18,8 +18,8 @@ import ( // CheckApplier is the service surface the API depends on. type CheckApplier interface { - Check(ctx context.Context, domainID uuid.UUID) (diff.Changeset, error) - Apply(ctx context.Context, domainID uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) + Check(ctx context.Context, projectID, domainID uuid.UUID) (diff.Changeset, error) + Apply(ctx context.Context, projectID, domainID uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) } // TenantStore is the narrow persistence surface the CRUD handlers depend on. @@ -63,6 +63,10 @@ type AuthStore interface { GetUserByEmail(ctx context.Context, email string) (store.User, error) GetUserByID(ctx context.Context, userID uuid.UUID) (store.User, error) GetUserProject(ctx context.Context, userID uuid.UUID) (store.Project, error) + // GetProjectOwned looks up projectID and returns it only if it's owned by + // userID — RequireProjectAccess uses this to reject foreign/nonexistent + // projects with 404 before any handler runs. + GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) } // SessionManager creates/validates/destroys login sessions. *auth.Sessions @@ -91,11 +95,17 @@ func NewRouter(a *API) http.Handler { r.Route("/api/v1/auth", func(r chi.Router) { r.Post("/register", a.handleRegister) r.Post("/login", a.handleLogin) - r.Post("/logout", a.handleLogout) // защитится RequireAuth в Task 4 - r.Get("/me", a.handleMe) // защитится RequireAuth в Task 4 + r.Group(func(r chi.Router) { + r.Use(a.RequireAuth) + r.Post("/logout", a.handleLogout) + r.Get("/me", a.handleMe) + }) }) r.Route("/api/v1/projects/{pid}", func(r chi.Router) { + r.Use(a.RequireAuth) + r.Use(a.RequireProjectAccess) + r.Route("/domains", func(r chi.Router) { r.Post("/", a.handleCreateDomain) r.Get("/", a.handleListDomains) diff --git a/internal/api/api_test.go b/internal/api/api_test.go index 0a48ee0..876d06b 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -20,18 +20,23 @@ type mockCheckApplier struct { lastReq service.ApplyRequest } -func (m *mockCheckApplier) Check(context.Context, uuid.UUID) (diff.Changeset, error) { +func (m *mockCheckApplier) Check(context.Context, uuid.UUID, uuid.UUID) (diff.Changeset, error) { d := model.Record{Type: model.A, Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}} return diff.Changeset{Diffs: []diff.RecordDiff{{Kind: diff.Add, Type: d.Type, Name: d.Name, Desired: &d}}}, nil } -func (m *mockCheckApplier) Apply(_ context.Context, _ uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) { +func (m *mockCheckApplier) Apply(_ context.Context, _, _ uuid.UUID, req service.ApplyRequest) (diff.Changeset, error) { m.lastReq = req return diff.Changeset{}, nil } +// newTestAPI wires a fixed authenticated user who owns whatever project id +// is requested (via alwaysOwnedAuthStore/alwaysValidSessions in +// middleware_test.go) — these tests exercise check/apply behavior past the +// RequireAuth/RequireProjectAccess boundary, which is covered separately by +// middleware_test.go's own tests and the IDOR regression. func newTestAPI() (*API, *mockCheckApplier) { m := &mockCheckApplier{} - return &API{Svc: m}, m // остальные зависимости (store/cipher) nil — CRUD-тесты добавит реализатор + return &API{Svc: m, Auth: alwaysOwnedAuthStore(), Sessions: alwaysValidSessions(uuid.New())}, m } func TestCheckEndpoint(t *testing.T) { @@ -39,7 +44,7 @@ func TestCheckEndpoint(t *testing.T) { router := NewRouter(a) did := uuid.New().String() - req := httptest.NewRequest(http.MethodGet, + req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/check", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -62,7 +67,7 @@ func TestApplyDefaultsPruneFalse(t *testing.T) { did := uuid.New().String() body := `{"applyUpdates":true}` // applyPrunes отсутствует → false - req := httptest.NewRequest(http.MethodPost, + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply", strings.NewReader(body)) w := httptest.NewRecorder() @@ -81,7 +86,7 @@ func TestApplyEmptyBodyOK(t *testing.T) { router := NewRouter(a) did := uuid.New().String() - req := httptest.NewRequest(http.MethodPost, + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -100,7 +105,7 @@ func TestApplyMalformedBody(t *testing.T) { did := uuid.New().String() body := `{"applyUpdates":` - req := httptest.NewRequest(http.MethodPost, + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/"+did+"/apply", strings.NewReader(body)) w := httptest.NewRecorder() @@ -114,7 +119,7 @@ func TestApplyMalformedBody(t *testing.T) { func TestApplyBadUUID(t *testing.T) { a, _ := newTestAPI() router := NewRouter(a) - req := httptest.NewRequest(http.MethodPost, + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/00000000-0000-0000-0000-000000000002/domains/not-a-uuid/apply", bytes.NewReader([]byte(`{}`))) w := httptest.NewRecorder() diff --git a/internal/api/auth_handlers.go b/internal/api/auth_handlers.go index 76d3c38..60ff2d3 100644 --- a/internal/api/auth_handlers.go +++ b/internal/api/auth_handlers.go @@ -1,15 +1,12 @@ package api import ( - "context" "errors" "log" "net/http" "strings" "time" - "github.com/google/uuid" - "github.com/vasyakrg/dns-autoresolver/internal/auth" "github.com/vasyakrg/dns-autoresolver/internal/store" ) @@ -39,19 +36,6 @@ func normalizeEmail(email string) string { return strings.ToLower(strings.TrimSpace(email)) } -// ctxKeyUserID is a private context key carrying the authenticated user's ID. -// Task 4's RequireAuth middleware sets it after validating the session -// cookie; handleMe reads it back. -type ctxKeyUserID struct{} - -// userIDFromContext extracts the authenticated user ID set by RequireAuth -// (Task 4). Until that middleware is wired in, tests set it directly via -// context.WithValue. -func userIDFromContext(ctx context.Context) (uuid.UUID, bool) { - id, ok := ctx.Value(ctxKeyUserID{}).(uuid.UUID) - return id, ok -} - func setSessionCookie(w http.ResponseWriter, token string, exp time.Time) { http.SetCookie(w, &http.Cookie{ Name: sessionCookieName, Value: token, Path: "/", @@ -165,11 +149,10 @@ func (a *API) handleLogout(w http.ResponseWriter, r *http.Request) { } // handleMe returns the authenticated caller's identity + default project. -// The user ID comes from the request context, set by Task 4's RequireAuth -// middleware after validating the session cookie (tests set it directly via -// context.WithValue in the interim). +// The user ID comes from the request context, set by RequireAuth after +// validating the session cookie. func (a *API) handleMe(w http.ResponseWriter, r *http.Request) { - userID, ok := userIDFromContext(r.Context()) + userID, ok := userIDFrom(r.Context()) if !ok { writeErr(w, http.StatusUnauthorized, "authentication required") return diff --git a/internal/api/auth_test.go b/internal/api/auth_test.go index 58f9402..5525fc7 100644 --- a/internal/api/auth_test.go +++ b/internal/api/auth_test.go @@ -18,10 +18,11 @@ import ( // --- mocks --- type mockAuthStore struct { - registerUserFn func(ctx context.Context, email, passwordHash string) (store.User, store.Project, error) - getUserByEmailFn func(ctx context.Context, email string) (store.User, error) - getUserByIDFn func(ctx context.Context, userID uuid.UUID) (store.User, error) - getUserProjectFn func(ctx context.Context, userID uuid.UUID) (store.Project, error) + registerUserFn func(ctx context.Context, email, passwordHash string) (store.User, store.Project, error) + getUserByEmailFn func(ctx context.Context, email string) (store.User, error) + getUserByIDFn func(ctx context.Context, userID uuid.UUID) (store.User, error) + getUserProjectFn func(ctx context.Context, userID uuid.UUID) (store.Project, error) + getProjectOwnedFn func(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) } func (m *mockAuthStore) RegisterUser(ctx context.Context, email, passwordHash string) (store.User, store.Project, error) { @@ -40,8 +41,13 @@ func (m *mockAuthStore) GetUserProject(ctx context.Context, userID uuid.UUID) (s return m.getUserProjectFn(ctx, userID) } +func (m *mockAuthStore) GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) { + return m.getProjectOwnedFn(ctx, projectID, userID) +} + type mockSessionManager struct { - createFn func(ctx context.Context, userID uuid.UUID) (string, time.Time, error) + createFn func(ctx context.Context, userID uuid.UUID) (string, time.Time, error) + validateFn func(ctx context.Context, token string) (uuid.UUID, error) destroyCalled bool destroyToken string @@ -52,7 +58,10 @@ func (m *mockSessionManager) Create(ctx context.Context, userID uuid.UUID) (stri return m.createFn(ctx, userID) } -func (m *mockSessionManager) Validate(context.Context, string) (uuid.UUID, error) { +func (m *mockSessionManager) Validate(ctx context.Context, token string) (uuid.UUID, error) { + if m.validateFn != nil { + return m.validateFn(ctx, token) + } return uuid.Nil, nil } @@ -338,7 +347,7 @@ func TestAuthLogout_ClearsSessionAndDestroys(t *testing.T) { // now resolves the authenticated user via GetUserByID and returns their real // email, instead of leaving it blank. func TestAuthMe_ReturnsRealEmail(t *testing.T) { - a, authStore, _ := newTestAuthAPI() + a, authStore, sessions := newTestAuthAPI() userID := uuid.New() projectID := uuid.New() authStore.getUserByIDFn = func(_ context.Context, id uuid.UUID) (store.User, error) { @@ -350,10 +359,15 @@ func TestAuthMe_ReturnsRealEmail(t *testing.T) { authStore.getUserProjectFn = func(_ context.Context, uid uuid.UUID) (store.Project, error) { return store.Project{ID: projectID, UserID: uid, Name: "default"}, nil } + // /me is now behind RequireAuth: the session cookie must resolve to + // userID via Sessions.Validate rather than being injected directly. + sessions.validateFn = func(context.Context, string) (uuid.UUID, error) { + return userID, nil + } router := NewRouter(a) req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil) - req = req.WithContext(context.WithValue(req.Context(), ctxKeyUserID{}, userID)) + req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "some-valid-token"}) w := httptest.NewRecorder() router.ServeHTTP(w, req) diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 2815dec..3553bfa 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -24,12 +24,15 @@ func writeErr(w http.ResponseWriter, status int, msg string) { } func (a *API) handleCheck(w http.ResponseWriter, r *http.Request) { + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) did, err := uuid.Parse(chi.URLParam(r, "did")) if err != nil { writeErr(w, http.StatusBadRequest, "invalid domain id") return } - cs, err := a.Svc.Check(r.Context(), did) + cs, err := a.Svc.Check(r.Context(), pid, did) if err != nil { log.Printf("api: check failed: %v", err) writeErr(w, http.StatusInternalServerError, "internal error") @@ -39,6 +42,9 @@ func (a *API) handleCheck(w http.ResponseWriter, r *http.Request) { } func (a *API) handleApply(w http.ResponseWriter, r *http.Request) { + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) did, err := uuid.Parse(chi.URLParam(r, "did")) if err != nil { writeErr(w, http.StatusBadRequest, "invalid domain id") @@ -54,7 +60,7 @@ func (a *API) handleApply(w http.ResponseWriter, r *http.Request) { return } } - cs, err := a.Svc.Apply(r.Context(), did, service.ApplyRequest{ + cs, err := a.Svc.Apply(r.Context(), pid, did, service.ApplyRequest{ ApplyUpdates: req.ApplyUpdates, ApplyPrunes: req.ApplyPrunes, }) if err != nil { diff --git a/internal/api/middleware.go b/internal/api/middleware.go new file mode 100644 index 0000000..339e171 --- /dev/null +++ b/internal/api/middleware.go @@ -0,0 +1,75 @@ +package api + +import ( + "context" + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/google/uuid" +) + +type ctxKey string + +const ( + ctxUserID ctxKey = "userID" + ctxProjectID ctxKey = "projectID" +) + +// RequireAuth validates the session cookie and, on success, stores the +// authenticated user's ID in the request context (see userIDFrom). Any +// failure — missing cookie or an invalid/expired token — is rejected with +// 401 before reaching the wrapped handler. +func (a *API) RequireAuth(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := r.Cookie(sessionCookieName) + if err != nil { + writeErr(w, http.StatusUnauthorized, "unauthorized") + return + } + uid, err := a.Sessions.Validate(r.Context(), c.Value) + if err != nil { + writeErr(w, http.StatusUnauthorized, "unauthorized") + return + } + ctx := context.WithValue(r.Context(), ctxUserID, uid) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// RequireProjectAccess verifies that the {pid} URL segment names a project +// owned by the authenticated user (set by RequireAuth, which must run +// first) and, on success, stores the project ID in the request context (see +// projectIDFrom). A project that doesn't exist or isn't owned by the caller +// is rejected with 404 — not 403 — so a caller can't distinguish "not yours" +// from "doesn't exist" (closes IDOR by never confirming another tenant's +// project exists). +func (a *API) RequireProjectAccess(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + uid, ok := userIDFrom(r.Context()) + if !ok { + writeErr(w, http.StatusUnauthorized, "unauthorized") + return + } + pid, err := uuid.Parse(chi.URLParam(r, "pid")) + if err != nil { + writeErr(w, http.StatusBadRequest, "invalid project id") + return + } + if _, err := a.Auth.GetProjectOwned(r.Context(), pid, uid); err != nil { + writeErr(w, http.StatusNotFound, "not found") + return + } + ctx := context.WithValue(r.Context(), ctxProjectID, pid) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func userIDFrom(ctx context.Context) (uuid.UUID, bool) { + v, ok := ctx.Value(ctxUserID).(uuid.UUID) + return v, ok +} + +func projectIDFrom(ctx context.Context) (uuid.UUID, bool) { + v, ok := ctx.Value(ctxProjectID).(uuid.UUID) + return v, ok +} diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go new file mode 100644 index 0000000..b24ecbd --- /dev/null +++ b/internal/api/middleware_test.go @@ -0,0 +1,265 @@ +package api + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + + "github.com/vasyakrg/dns-autoresolver/internal/diff" + "github.com/vasyakrg/dns-autoresolver/internal/service" + "github.com/vasyakrg/dns-autoresolver/internal/store" +) + +// --- shared test doubles (also used by api_test.go / tenant_test.go) --- + +// stubSessions is a configurable SessionManager test double. +type stubSessions struct { + validateFn func(ctx context.Context, token string) (uuid.UUID, error) +} + +func (s stubSessions) Create(context.Context, uuid.UUID) (string, time.Time, error) { + return "stub-token", time.Now().Add(time.Hour), nil +} +func (s stubSessions) Validate(ctx context.Context, token string) (uuid.UUID, error) { + return s.validateFn(ctx, token) +} +func (s stubSessions) Destroy(context.Context, string) error { return nil } + +// alwaysValidSessions builds a stubSessions whose Validate always succeeds +// with the given fixed user ID — for tests that only care about behavior +// past the RequireAuth boundary. +func alwaysValidSessions(uid uuid.UUID) stubSessions { + return stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) { return uid, nil }} +} + +// stubAuthStore is a configurable AuthStore test double. Methods besides +// GetProjectOwned return zero values — tests exercising register/login/me +// use their own dedicated mock (mockAuthStore in auth_test.go). +type stubAuthStore struct { + getProjectOwnedFn func(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) +} + +func (s stubAuthStore) RegisterUser(context.Context, string, string) (store.User, store.Project, error) { + return store.User{}, store.Project{}, nil +} +func (s stubAuthStore) GetUserByEmail(context.Context, string) (store.User, error) { + return store.User{}, nil +} +func (s stubAuthStore) GetUserByID(context.Context, uuid.UUID) (store.User, error) { + return store.User{}, nil +} +func (s stubAuthStore) GetUserProject(context.Context, uuid.UUID) (store.Project, error) { + return store.Project{}, nil +} +func (s stubAuthStore) GetProjectOwned(ctx context.Context, projectID, userID uuid.UUID) (store.Project, error) { + return s.getProjectOwnedFn(ctx, projectID, userID) +} + +// alwaysOwnedAuthStore builds a stubAuthStore whose GetProjectOwned always +// succeeds, treating the caller as the owner of whatever project id is +// requested — for tests that only care about behavior past the +// RequireProjectAccess boundary (CRUD/check/apply tests). +func alwaysOwnedAuthStore() stubAuthStore { + return stubAuthStore{getProjectOwnedFn: func(_ context.Context, pid, uid uuid.UUID) (store.Project, error) { + return store.Project{ID: pid, UserID: uid}, nil + }} +} + +// requestWithSessionCookie builds an httptest request carrying a session +// cookie so it clears RequireAuth in tests using stubSessions/alwaysValidSessions +// (which ignore the token value and validate unconditionally). +func requestWithSessionCookie(method, url string, body io.Reader) *http.Request { + req := httptest.NewRequest(method, url, body) + req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "test-session-token"}) + return req +} + +// recordingCheckApplier is a CheckApplier test double that records whether +// Check/Apply were invoked — used by the IDOR regression test to assert the +// service layer is never reached for a project the caller doesn't own. +type recordingCheckApplier struct { + checkCalled bool + applyCalled bool +} + +func (r *recordingCheckApplier) Check(context.Context, uuid.UUID, uuid.UUID) (diff.Changeset, error) { + r.checkCalled = true + return diff.Changeset{}, nil +} +func (r *recordingCheckApplier) Apply(context.Context, uuid.UUID, uuid.UUID, service.ApplyRequest) (diff.Changeset, error) { + r.applyCalled = true + return diff.Changeset{}, nil +} + +// --- RequireAuth --- + +func TestRequireAuth_NoCookie_Returns401(t *testing.T) { + a := &API{Sessions: stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) { + t.Fatal("Validate must not be called when there is no cookie") + return uuid.Nil, nil + }}} + nextCalled := false + next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true }) + + req := httptest.NewRequest(http.MethodGet, "/whatever", nil) + w := httptest.NewRecorder() + a.RequireAuth(next).ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", w.Code) + } + if nextCalled { + t.Fatal("next must not be called without a session cookie") + } +} + +func TestRequireAuth_InvalidToken_Returns401(t *testing.T) { + a := &API{Sessions: stubSessions{validateFn: func(context.Context, string) (uuid.UUID, error) { + return uuid.Nil, errors.New("invalid or expired session") + }}} + nextCalled := false + next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true }) + + req := requestWithSessionCookie(http.MethodGet, "/whatever", nil) + w := httptest.NewRecorder() + a.RequireAuth(next).ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", w.Code) + } + if nextCalled { + t.Fatal("next must not be called when Sessions.Validate fails") + } +} + +func TestRequireAuth_Success_CallsNextWithUserID(t *testing.T) { + uid := uuid.New() + a := &API{Sessions: alwaysValidSessions(uid)} + + var gotUID uuid.UUID + var gotOK bool + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotUID, gotOK = userIDFrom(r.Context()) + w.WriteHeader(http.StatusOK) + }) + + req := requestWithSessionCookie(http.MethodGet, "/whatever", nil) + w := httptest.NewRecorder() + a.RequireAuth(next).ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200 (next called), got %d", w.Code) + } + if !gotOK || gotUID != uid { + t.Fatalf("expected userIDFrom to yield %s, got %s (ok=%v)", uid, gotUID, gotOK) + } +} + +// --- RequireProjectAccess --- + +func TestRequireProjectAccess_NotOwned_Returns404(t *testing.T) { + a := &API{Auth: stubAuthStore{getProjectOwnedFn: func(context.Context, uuid.UUID, uuid.UUID) (store.Project, error) { + return store.Project{}, errors.New("no rows") + }}} + nextCalled := false + next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextCalled = true }) + + router := chi.NewRouter() + router.With(a.RequireProjectAccess).Get("/projects/{pid}", next.ServeHTTP) + + req := httptest.NewRequest(http.MethodGet, "/projects/"+uuid.New().String(), nil) + req = req.WithContext(context.WithValue(req.Context(), ctxUserID, uuid.New())) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Fatalf("expected 404 for an unowned/foreign project, got %d body %s", w.Code, w.Body.String()) + } + if nextCalled { + t.Fatal("next must not be called when the project isn't owned by the caller") + } +} + +func TestRequireProjectAccess_Success_CallsNextWithProjectID(t *testing.T) { + a := &API{Auth: alwaysOwnedAuthStore()} + pid := uuid.New() + uid := uuid.New() + + var gotPID uuid.UUID + var gotOK bool + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPID, gotOK = projectIDFrom(r.Context()) + w.WriteHeader(http.StatusOK) + }) + + router := chi.NewRouter() + router.With(a.RequireProjectAccess).Get("/projects/{pid}", next.ServeHTTP) + + req := httptest.NewRequest(http.MethodGet, "/projects/"+pid.String(), nil) + req = req.WithContext(context.WithValue(req.Context(), ctxUserID, uid)) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200 (next called), got %d body %s", w.Code, w.Body.String()) + } + if !gotOK || gotPID != pid { + t.Fatalf("expected projectIDFrom to yield %s, got %s (ok=%v)", pid, gotPID, gotOK) + } +} + +// --- IDOR regression --- + +// TestIDOR_CheckForeignProject_Returns404AndServiceNotCalled is the IDOR +// regression required by Task 4: a request for project B's domain, made +// with user A's session (who owns project A, not B), must be rejected by +// RequireProjectAccess with 404 before service.Check is ever invoked — a +// caller must never be able to check/apply another tenant's domain by +// guessing/leaking its ids. +func TestIDOR_CheckForeignProject_Returns404AndServiceNotCalled(t *testing.T) { + userA := uuid.New() + pidA := uuid.New() + pidB := uuid.New() // owned by a different user; not A + + svc := &recordingCheckApplier{} + auth := stubAuthStore{getProjectOwnedFn: func(_ context.Context, pid, uid uuid.UUID) (store.Project, error) { + if pid == pidA && uid == userA { + return store.Project{ID: pidA, UserID: userA}, nil + } + return store.Project{}, errors.New("not found") + }} + a := &API{Svc: svc, Auth: auth, Sessions: alwaysValidSessions(userA)} + router := NewRouter(a) + + did := uuid.New().String() + req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+pidB.String()+"/domains/"+did+"/check", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Fatalf("expected 404 for user A requesting project B's domain, got %d body %s", w.Code, w.Body.String()) + } + if svc.checkCalled { + t.Fatal("service.Check must not be called when the caller doesn't own the project") + } + + // Sanity check: the same user against their own project succeeds and + // does reach the service — proving the 404 above is really about + // project ownership, not e.g. a broken route. + req2 := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+pidA.String()+"/domains/"+did+"/check", nil) + w2 := httptest.NewRecorder() + router.ServeHTTP(w2, req2) + if w2.Code != http.StatusOK { + t.Fatalf("expected 200 for user A requesting their own project, got %d body %s", w2.Code, w2.Body.String()) + } + if !svc.checkCalled { + t.Fatal("expected service.Check to be called for the caller's own project") + } +} diff --git a/internal/api/tenant_handlers.go b/internal/api/tenant_handlers.go index 02f3dcc..13711a4 100644 --- a/internal/api/tenant_handlers.go +++ b/internal/api/tenant_handlers.go @@ -24,11 +24,9 @@ func decodeBody(w http.ResponseWriter, r *http.Request, v any) bool { // --- accounts --- func (a *API) handleCreateAccount(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) var req accountRequest if !decodeBody(w, r, &req) { return @@ -53,11 +51,9 @@ func (a *API) handleCreateAccount(w http.ResponseWriter, r *http.Request) { } func (a *API) handleListAccounts(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) accs, err := a.Store.ListAccounts(r.Context(), pid) if err != nil { log.Printf("api: list accounts failed: %v", err) @@ -72,11 +68,9 @@ func (a *API) handleListAccounts(w http.ResponseWriter, r *http.Request) { } func (a *API) handleDeleteAccount(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) aid, err := uuid.Parse(chi.URLParam(r, "aid")) if err != nil { writeErr(w, http.StatusBadRequest, "invalid account id") @@ -93,11 +87,9 @@ func (a *API) handleDeleteAccount(w http.ResponseWriter, r *http.Request) { // handleImportZones lists zones from the provider for the given account and // creates one domain per zone (template_id left unset). func (a *API) handleImportZones(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) aid, err := uuid.Parse(chi.URLParam(r, "aid")) if err != nil { writeErr(w, http.StatusBadRequest, "invalid account id") @@ -145,11 +137,9 @@ func (a *API) handleImportZones(w http.ResponseWriter, r *http.Request) { // --- templates --- func (a *API) handleCreateTemplate(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) var req templateRequest if !decodeBody(w, r, &req) { return @@ -169,11 +159,9 @@ func (a *API) handleCreateTemplate(w http.ResponseWriter, r *http.Request) { } func (a *API) handleListTemplates(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) tpls, err := a.Store.ListTemplates(r.Context(), pid) if err != nil { log.Printf("api: list templates failed: %v", err) @@ -188,11 +176,9 @@ func (a *API) handleListTemplates(w http.ResponseWriter, r *http.Request) { } func (a *API) handleUpdateTemplate(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) tid, err := uuid.Parse(chi.URLParam(r, "tid")) if err != nil { writeErr(w, http.StatusBadRequest, "invalid template id") @@ -217,11 +203,9 @@ func (a *API) handleUpdateTemplate(w http.ResponseWriter, r *http.Request) { } func (a *API) handleDeleteTemplate(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) tid, err := uuid.Parse(chi.URLParam(r, "tid")) if err != nil { writeErr(w, http.StatusBadRequest, "invalid template id") @@ -238,11 +222,9 @@ func (a *API) handleDeleteTemplate(w http.ResponseWriter, r *http.Request) { // --- domains --- func (a *API) handleCreateDomain(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) var req domainRequest if !decodeBody(w, r, &req) { return @@ -286,11 +268,9 @@ func (a *API) handleCreateDomain(w http.ResponseWriter, r *http.Request) { } func (a *API) handleListDomains(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) doms, err := a.Store.ListDomains(r.Context(), pid) if err != nil { log.Printf("api: list domains failed: %v", err) @@ -308,11 +288,9 @@ func (a *API) handleListDomains(w http.ResponseWriter, r *http.Request) { // check/apply a domain — this is what makes an imported domain (which // starts with template_id=NULL) checkable, closing the import→check loop. func (a *API) handleSetDomainTemplate(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) did, err := uuid.Parse(chi.URLParam(r, "did")) if err != nil { writeErr(w, http.StatusBadRequest, "invalid domain id") @@ -339,11 +317,9 @@ func (a *API) handleSetDomainTemplate(w http.ResponseWriter, r *http.Request) { } func (a *API) handleDeleteDomain(w http.ResponseWriter, r *http.Request) { - pid, err := uuid.Parse(chi.URLParam(r, "pid")) - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid project id") - return - } + // pid is guaranteed present and owned by the caller — RequireProjectAccess + // validated it before this handler ever runs. + pid, _ := projectIDFrom(r.Context()) did, err := uuid.Parse(chi.URLParam(r, "did")) if err != nil { writeErr(w, http.StatusBadRequest, "invalid domain id") diff --git a/internal/api/tenant_test.go b/internal/api/tenant_test.go index 2b807a8..73c9f89 100644 --- a/internal/api/tenant_test.go +++ b/internal/api/tenant_test.go @@ -132,7 +132,9 @@ func (m *mockTenantStore) ImportDomains(_ context.Context, projectID, accountID type mockCipher struct{} -func (mockCipher) Encrypt(plaintext []byte) (string, error) { return "ENC(" + string(plaintext) + ")", nil } +func (mockCipher) Encrypt(plaintext []byte) (string, error) { + return "ENC(" + string(plaintext) + ")", nil +} func (mockCipher) Decrypt(enc string) ([]byte, error) { return []byte(strings.TrimSuffix(strings.TrimPrefix(enc, "ENC("), ")")), nil } @@ -160,9 +162,17 @@ func (mockProvider) ApplyChanges(context.Context, provider.Credentials, string, return nil } +// newTenantTestAPI wires a fixed authenticated user who owns whatever +// project id is requested (alwaysOwnedAuthStore/alwaysValidSessions, see +// middleware_test.go) — these tests exercise CRUD behavior past the +// RequireAuth/RequireProjectAccess boundary, which has its own dedicated +// coverage in middleware_test.go. func newTenantTestAPI() (*API, *mockTenantStore) { ts := &mockTenantStore{} - a := &API{Store: ts, Cipher: mockCipher{}, Reg: &mockRegistry{}} + a := &API{ + Store: ts, Cipher: mockCipher{}, Reg: &mockRegistry{}, + Auth: alwaysOwnedAuthStore(), Sessions: alwaysValidSessions(uuid.New()), + } return a, ts } @@ -173,7 +183,7 @@ func TestCreateAccount_SecretEncryptedAndNotInResponse(t *testing.T) { router := NewRouter(a) body := `{"provider":"selectel","secret":"super-secret-token","comment":"prod"}` - req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts", strings.NewReader(body)) + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts", strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -211,7 +221,7 @@ func TestListAccounts_NoSecretsInResponse(t *testing.T) { } router := NewRouter(a) - req := httptest.NewRequest(http.MethodGet, "/api/v1/projects/"+testPID+"/accounts", nil) + req := requestWithSessionCookie(http.MethodGet, "/api/v1/projects/"+testPID+"/accounts", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -234,7 +244,7 @@ func TestDeleteAccount_BadUUID(t *testing.T) { a, _ := newTenantTestAPI() router := NewRouter(a) - req := httptest.NewRequest(http.MethodDelete, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid", nil) + req := requestWithSessionCookie(http.MethodDelete, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -250,7 +260,7 @@ func TestCreateTemplate_SavesRecords(t *testing.T) { router := NewRouter(a) body := `{"name":"base","records":[{"type":"A","name":"@","ttl":300,"values":["1.2.3.4"]}]}` - req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/templates", strings.NewReader(body)) + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/templates", strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -278,7 +288,7 @@ func TestUpdateTemplate_BadUUID(t *testing.T) { router := NewRouter(a) body := `{"name":"x","records":[]}` - req := httptest.NewRequest(http.MethodPut, "/api/v1/projects/"+testPID+"/templates/not-a-uuid", strings.NewReader(body)) + req := requestWithSessionCookie(http.MethodPut, "/api/v1/projects/"+testPID+"/templates/not-a-uuid", strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -299,7 +309,7 @@ func TestImportZones_CreatesDomainPerZone(t *testing.T) { }} router := NewRouter(a) - req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil) + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -337,7 +347,7 @@ func TestImportZones_AtomicRollbackOnError(t *testing.T) { }} router := NewRouter(a) - req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil) + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/"+accID.String()+"/import", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -356,7 +366,7 @@ func TestImportZones_BadAccountUUID(t *testing.T) { a, _ := newTenantTestAPI() router := NewRouter(a) - req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid/import", nil) + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/accounts/not-a-uuid/import", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -370,7 +380,7 @@ func TestCreateDomain_BadProjectUUID(t *testing.T) { router := NewRouter(a) body := `{"providerAccountId":"` + uuid.New().String() + `","zoneName":"example.com","zoneId":"z1"}` - req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/not-a-uuid/domains", strings.NewReader(body)) + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/not-a-uuid/domains", strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -390,7 +400,7 @@ func TestCreateDomain_AccountNotFoundInProject(t *testing.T) { // ts.accounts is empty, so GetAccount will not find this id. foreignAccID := uuid.New() body := `{"providerAccountId":"` + foreignAccID.String() + `","zoneName":"example.com","zoneId":"z1"}` - req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -413,7 +423,7 @@ func TestCreateDomain_TemplateNotFoundInProject(t *testing.T) { foreignTplID := uuid.New() body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + foreignTplID.String() + `"}` - req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -434,7 +444,7 @@ func TestCreateDomain_HappyPath(t *testing.T) { router := NewRouter(a) body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1"}` - req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -464,7 +474,7 @@ func TestCreateDomain_ValidTemplateInProject(t *testing.T) { router := NewRouter(a) body := `{"providerAccountId":"` + accID.String() + `","zoneName":"example.com","zoneId":"z1","templateId":"` + tplID.String() + `"}` - req := httptest.NewRequest(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) + req := requestWithSessionCookie(http.MethodPost, "/api/v1/projects/"+testPID+"/domains", strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -490,7 +500,7 @@ func TestSetDomainTemplate_ValidTemplateId(t *testing.T) { router := NewRouter(a) body := `{"templateId":"` + tplID.String() + `"}` - req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body)) + req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -513,7 +523,7 @@ func TestSetDomainTemplate_BadTemplateUUID(t *testing.T) { router := NewRouter(a) body := `{"templateId":"not-a-uuid"}` - req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body)) + req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -530,7 +540,7 @@ func TestSetDomainTemplate_TemplateNotFound(t *testing.T) { router := NewRouter(a) body := `{"templateId":"` + uuid.New().String() + `"}` - req := httptest.NewRequest(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body)) + req := requestWithSessionCookie(http.MethodPatch, "/api/v1/projects/"+testPID+"/domains/"+domID.String(), strings.NewReader(body)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -543,7 +553,7 @@ func TestDeleteDomain_BadUUID(t *testing.T) { a, _ := newTenantTestAPI() router := NewRouter(a) - req := httptest.NewRequest(http.MethodDelete, "/api/v1/projects/"+testPID+"/domains/not-a-uuid", nil) + req := requestWithSessionCookie(http.MethodDelete, "/api/v1/projects/"+testPID+"/domains/not-a-uuid", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) diff --git a/internal/service/service.go b/internal/service/service.go index 2dec925..ea8c46d 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -21,7 +21,7 @@ type DomainRef struct { } type Loader interface { - LoadDomain(ctx context.Context, domainID uuid.UUID) (DomainRef, error) + LoadDomain(ctx context.Context, projectID, domainID uuid.UUID) (DomainRef, error) } type Recorder interface { @@ -45,8 +45,10 @@ func New(loader Loader, rec Recorder, reg *registry.Registry, cipher *crypto.Cip } // resolve loads the domain, its provider and decrypted credentials, and computes the diff. -func (s *DomainService) resolve(ctx context.Context, domainID uuid.UUID) (provider.Provider, provider.Credentials, DomainRef, diff.Changeset, error) { - ref, err := s.loader.LoadDomain(ctx, domainID) +// projectID scopes the lookup so a domainID belonging to another tenant's +// project can never be resolved here (closes IDOR). +func (s *DomainService) resolve(ctx context.Context, projectID, domainID uuid.UUID) (provider.Provider, provider.Credentials, DomainRef, diff.Changeset, error) { + ref, err := s.loader.LoadDomain(ctx, projectID, domainID) if err != nil { return nil, provider.Credentials{}, ref, diff.Changeset{}, err } @@ -68,8 +70,8 @@ func (s *DomainService) resolve(ctx context.Context, domainID uuid.UUID) (provid } // Check computes and records the diff between template and zone. -func (s *DomainService) Check(ctx context.Context, domainID uuid.UUID) (diff.Changeset, error) { - _, _, _, cs, err := s.resolve(ctx, domainID) +func (s *DomainService) Check(ctx context.Context, projectID, domainID uuid.UUID) (diff.Changeset, error) { + _, _, _, cs, err := s.resolve(ctx, projectID, domainID) if err != nil { return diff.Changeset{}, err } @@ -80,8 +82,8 @@ func (s *DomainService) Check(ctx context.Context, domainID uuid.UUID) (diff.Cha } // Apply applies updates always (when ApplyUpdates) and prunes only when ApplyPrunes. -func (s *DomainService) Apply(ctx context.Context, domainID uuid.UUID, req ApplyRequest) (diff.Changeset, error) { - p, creds, ref, cs, err := s.resolve(ctx, domainID) +func (s *DomainService) Apply(ctx context.Context, projectID, domainID uuid.UUID, req ApplyRequest) (diff.Changeset, error) { + p, creds, ref, cs, err := s.resolve(ctx, projectID, domainID) if err != nil { return diff.Changeset{}, err } diff --git a/internal/service/service_test.go b/internal/service/service_test.go index e799306..d3e7cf7 100644 --- a/internal/service/service_test.go +++ b/internal/service/service_test.go @@ -44,7 +44,9 @@ func (f *fakeProvider) ApplyChanges(_ context.Context, _ provider.Credentials, _ type fakeLoader struct{ ref DomainRef } -func (l fakeLoader) LoadDomain(context.Context, uuid.UUID) (DomainRef, error) { return l.ref, nil } +func (l fakeLoader) LoadDomain(context.Context, uuid.UUID, uuid.UUID) (DomainRef, error) { + return l.ref, nil +} type nopRecorder struct{} @@ -66,7 +68,7 @@ func TestCheckProducesDiff(t *testing.T) { {Type: "A", Name: "a.example.com.", TTL: 300, Values: []string{"1.1.1.1"}}, // update }} svc, _ := setup(t, actual, tmpl) - cs, err := svc.Check(context.Background(), uuid.New()) + cs, err := svc.Check(context.Background(), uuid.New(), uuid.New()) if err != nil { t.Fatal(err) } @@ -87,7 +89,7 @@ func TestApplyRespectsPruneGuard(t *testing.T) { // applyPrunes=false → удаление b НЕ применяется svc, fp := setup(t, actual, tmpl) - if _, err := svc.Apply(context.Background(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: false}); err != nil { + if _, err := svc.Apply(context.Background(), uuid.New(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: false}); err != nil { t.Fatal(err) } for _, d := range fp.applied.Diffs { @@ -98,7 +100,7 @@ func TestApplyRespectsPruneGuard(t *testing.T) { // applyPrunes=true → удаление b применяется svc2, fp2 := setup(t, actual, tmpl) - if _, err := svc2.Apply(context.Background(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: true}); err != nil { + if _, err := svc2.Apply(context.Background(), uuid.New(), uuid.New(), ApplyRequest{ApplyUpdates: true, ApplyPrunes: true}); err != nil { t.Fatal(err) } var sawDelete bool diff --git a/internal/store/db/domains.sql.go b/internal/store/db/domains.sql.go index 3b5aa32..527bd66 100644 --- a/internal/store/db/domains.sql.go +++ b/internal/store/db/domains.sql.go @@ -162,9 +162,14 @@ SELECT d.zone_id, a.provider, a.secret_enc, t.doc FROM domains d JOIN provider_accounts a ON a.id = d.provider_account_id LEFT JOIN templates t ON t.id = d.template_id -WHERE d.id = $1 +WHERE d.id = $1 AND d.project_id = $2 ` +type LoadDomainFullParams struct { + ID uuid.UUID `json:"id"` + ProjectID uuid.UUID `json:"project_id"` +} + type LoadDomainFullRow struct { ZoneID string `json:"zone_id"` Provider string `json:"provider"` @@ -172,8 +177,8 @@ type LoadDomainFullRow struct { Doc *dto.TemplateDoc `json:"doc"` } -func (q *Queries) LoadDomainFull(ctx context.Context, id uuid.UUID) (LoadDomainFullRow, error) { - row := q.db.QueryRow(ctx, loadDomainFull, id) +func (q *Queries) LoadDomainFull(ctx context.Context, arg LoadDomainFullParams) (LoadDomainFullRow, error) { + row := q.db.QueryRow(ctx, loadDomainFull, arg.ID, arg.ProjectID) var i LoadDomainFullRow err := row.Scan( &i.ZoneID, diff --git a/internal/store/loader.go b/internal/store/loader.go index 984c85c..620db85 100644 --- a/internal/store/loader.go +++ b/internal/store/loader.go @@ -13,9 +13,11 @@ import ( ) // LoadDomain joins domains+provider_accounts+templates to build the -// service.DomainRef needed to check/apply a domain's DNS records. -func (s *Store) LoadDomain(ctx context.Context, domainID uuid.UUID) (service.DomainRef, error) { - row, err := s.q.LoadDomainFull(ctx, domainID) +// service.DomainRef needed to check/apply a domain's DNS records. Scoped by +// projectID so a domain belonging to another tenant's project can never be +// loaded, even if its domainID is guessed/leaked (closes IDOR). +func (s *Store) LoadDomain(ctx context.Context, projectID, domainID uuid.UUID) (service.DomainRef, error) { + row, err := s.q.LoadDomainFull(ctx, db.LoadDomainFullParams{ID: domainID, ProjectID: projectID}) if err != nil { return service.DomainRef{}, err } diff --git a/internal/store/loader_test.go b/internal/store/loader_test.go index 33f210f..c53a1c7 100644 --- a/internal/store/loader_test.go +++ b/internal/store/loader_test.go @@ -40,7 +40,7 @@ func TestLoadDomainAndSaveCheckRun(t *testing.T) { t.Fatal(err) } - ref, err := s.LoadDomain(ctx, domain.ID) + ref, err := s.LoadDomain(ctx, defaultProject, domain.ID) if err != nil { t.Fatal(err) } @@ -87,7 +87,7 @@ func TestLoadDomainNoTemplate(t *testing.T) { t.Fatal(err) } - if _, err := s.LoadDomain(ctx, domain.ID); err == nil { + if _, err := s.LoadDomain(ctx, defaultProject, domain.ID); err == nil { t.Fatal("expected error for domain without template, got nil") } } diff --git a/internal/store/queries/domains.sql b/internal/store/queries/domains.sql index 1f05fda..9e16759 100644 --- a/internal/store/queries/domains.sql +++ b/internal/store/queries/domains.sql @@ -27,4 +27,4 @@ SELECT d.zone_id, a.provider, a.secret_enc, t.doc FROM domains d JOIN provider_accounts a ON a.id = d.provider_account_id LEFT JOIN templates t ON t.id = d.template_id -WHERE d.id = $1; +WHERE d.id = $1 AND d.project_id = $2; diff --git a/internal/store/store_test.go b/internal/store/store_test.go index bcbdc12..9572617 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -214,7 +214,7 @@ func TestSetDomainTemplate_ClosesImportCheckLoop(t *testing.T) { dom := doms[0] // Before binding, the domain is not checkable. - if _, err := s.LoadDomain(ctx, dom.ID); err == nil { + if _, err := s.LoadDomain(ctx, defaultProject, dom.ID); err == nil { t.Fatal("expected LoadDomain to fail before a template is bound") } @@ -234,7 +234,7 @@ func TestSetDomainTemplate_ClosesImportCheckLoop(t *testing.T) { t.Fatalf("expected domain.TemplateID=%s, got %+v", tpl.ID, updated.TemplateID) } - ref, err := s.LoadDomain(ctx, dom.ID) + ref, err := s.LoadDomain(ctx, defaultProject, dom.ID) if err != nil { t.Fatalf("expected LoadDomain to succeed after binding template, got error: %v", err) } From b5d9e8f7ab74a163a9c0613c32532db74ee3642a Mon Sep 17 00:00:00 2001 From: Vassiliy Yegorov Date: Fri, 3 Jul 2026 21:00:18 +0700 Subject: [PATCH 07/10] =?UTF-8?q?feat(web):=20AuthContext=20+=20=D0=BA?= =?UTF-8?q?=D0=BB=D0=B8=D0=B5=D0=BD=D1=82=20=D0=BF=D0=BE=D0=B4=20cookie-?= =?UTF-8?q?=D1=81=D0=B5=D1=81=D1=81=D0=B8=D0=B8,=20projectId=20=D0=B8?= =?UTF-8?q?=D0=B7=20=D0=BA=D0=BE=D0=BD=D1=82=D0=B5=D0=BA=D1=81=D1=82=D0=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/src/api/client.test.ts | 105 +++++++++++++++++++++++++++--- web/src/api/client.ts | 87 +++++++++++++++++-------- web/src/api/types.ts | 4 ++ web/src/auth/AuthContext.test.tsx | 82 +++++++++++++++++++++++ web/src/auth/AuthContext.tsx | 74 +++++++++++++++++++++ web/src/hooks/useApi.ts | 84 ++++++++++++++++-------- web/src/lib/config.ts | 3 +- 7 files changed, 374 insertions(+), 65 deletions(-) create mode 100644 web/src/auth/AuthContext.test.tsx create mode 100644 web/src/auth/AuthContext.tsx diff --git a/web/src/api/client.test.ts b/web/src/api/client.test.ts index c894b9c..43a966f 100644 --- a/web/src/api/client.test.ts +++ b/web/src/api/client.test.ts @@ -1,6 +1,7 @@ import { describe, it, expect, vi, beforeEach } from "vitest" -import { api } from "./client" -import { DEFAULT_PROJECT_ID } from "@/lib/config" +import { api, UnauthorizedError } from "./client" + +const PROJECT_ID = "11111111-1111-1111-1111-111111111111" beforeEach(() => { vi.restoreAllMocks() }) @@ -13,19 +14,72 @@ function mockFetch(body: unknown, ok = true, status = 200) { } describe("api client", () => { - it("lists accounts at project-scoped path", async () => { + it("sends credentials:include on every request", async () => { + const spy = mockFetch([]) + await api.listAccounts(PROJECT_ID) + const [, opts] = spy.mock.calls[0] + expect((opts as RequestInit).credentials).toBe("include") + }) + + describe("api.auth", () => { + it("login POSTs to /api/v1/auth/login with credentials:include", async () => { + const spy = mockFetch({ user: { id: "u1", email: "a@b.com" }, project: { id: "p1", name: "Default" } }) + await api.auth.login("a@b.com", "secret") + const [url, opts] = spy.mock.calls[0] + expect(url).toBe("/api/v1/auth/login") + expect((opts as RequestInit).method).toBe("POST") + expect((opts as RequestInit).credentials).toBe("include") + expect(String((opts as RequestInit).body)).toContain("secret") + }) + + it("register POSTs to /api/v1/auth/register", async () => { + const spy = mockFetch({ user: { id: "u1", email: "a@b.com" }, project: { id: "p1", name: "Default" } }) + await api.auth.register("a@b.com", "secret") + const [url, opts] = spy.mock.calls[0] + expect(url).toBe("/api/v1/auth/register") + expect((opts as RequestInit).method).toBe("POST") + }) + + it("logout POSTs to /api/v1/auth/logout", async () => { + const spy = mockFetch(undefined, true, 204) + await api.auth.logout() + const [url, opts] = spy.mock.calls[0] + expect(url).toBe("/api/v1/auth/logout") + expect((opts as RequestInit).method).toBe("POST") + }) + + it("me GETs /api/v1/auth/me and returns AuthState", async () => { + const state = { user: { id: "u1", email: "a@b.com" }, project: { id: "p1", name: "Default" } } + const spy = mockFetch(state) + const result = await api.auth.me() + const [url] = spy.mock.calls[0] + expect(url).toBe("/api/v1/auth/me") + expect(result).toEqual(state) + }) + }) + + it("resource methods hit project-scoped path with projectId first", async () => { const spy = mockFetch([{ id: "a1", provider: "selectel", comment: "" }]) - const accounts = await api.listAccounts() + const accounts = await api.listAccounts(PROJECT_ID) expect(accounts).toHaveLength(1) expect(spy).toHaveBeenCalledWith( - `/api/v1/projects/${DEFAULT_PROJECT_ID}/accounts`, + `/api/v1/projects/${PROJECT_ID}/accounts`, + expect.objectContaining({ method: "GET" }), + ) + }) + + it("listDomains(projectId) hits /api/v1/projects/{projectId}/domains", async () => { + const spy = mockFetch([]) + await api.listDomains(PROJECT_ID) + expect(spy).toHaveBeenCalledWith( + `/api/v1/projects/${PROJECT_ID}/domains`, expect.objectContaining({ method: "GET" }), ) }) it("sends secret on account creation but path has no secret leakage in response typing", async () => { const spy = mockFetch({ id: "a2", provider: "selectel", comment: "prod" }) - await api.createAccount({ provider: "selectel", secret: "TOKEN", comment: "prod" }) + await api.createAccount(PROJECT_ID, { provider: "selectel", secret: "TOKEN", comment: "prod" }) const [, opts] = spy.mock.calls[0] expect((opts as RequestInit).method).toBe("POST") expect(String((opts as RequestInit).body)).toContain("TOKEN") @@ -33,14 +87,45 @@ describe("api client", () => { it("throws on non-ok response", async () => { mockFetch({ error: "boom" }, false, 500) - await expect(api.listDomains()).rejects.toThrow() + await expect(api.listDomains(PROJECT_ID)).rejects.toThrow() }) - it("applies with prune flag", async () => { + it("throws UnauthorizedError on 401", async () => { + mockFetch({ error: "unauthorized" }, false, 401) + await expect(api.listDomains(PROJECT_ID)).rejects.toThrow(UnauthorizedError) + }) + + it("applies with prune flag using projectId, id, body order", async () => { const spy = mockFetch({ updates: [], prunes: [], readOnly: [], inSyncCount: 0 }) - await api.applyDomain("d1", { applyUpdates: true, applyPrunes: true }) + await api.applyDomain(PROJECT_ID, "d1", { applyUpdates: true, applyPrunes: true }) const [url, opts] = spy.mock.calls[0] - expect(url).toContain("/domains/d1/apply") + expect(url).toBe(`/api/v1/projects/${PROJECT_ID}/domains/d1/apply`) expect(String((opts as RequestInit).body)).toContain("applyPrunes") }) + + it("checkDomain(projectId, id) hits project-scoped check path", async () => { + const spy = mockFetch({ updates: [], prunes: [], readOnly: [], inSyncCount: 0 }) + await api.checkDomain(PROJECT_ID, "d1") + expect(spy).toHaveBeenCalledWith( + `/api/v1/projects/${PROJECT_ID}/domains/d1/check`, + expect.objectContaining({ method: "GET" }), + ) + }) + + it("importZones(projectId, accountId) hits project-scoped import path", async () => { + const spy = mockFetch([]) + await api.importZones(PROJECT_ID, "acc1") + expect(spy).toHaveBeenCalledWith( + `/api/v1/projects/${PROJECT_ID}/accounts/acc1/import`, + expect.objectContaining({ method: "POST" }), + ) + }) + + it("setDomainTemplate(projectId, id, templateId) hits project-scoped domain path", async () => { + const spy = mockFetch({ id: "d1", providerAccountId: "acc1", zoneName: "x.", zoneId: "z1" }) + await api.setDomainTemplate(PROJECT_ID, "d1", "t1") + const [url, opts] = spy.mock.calls[0] + expect(url).toBe(`/api/v1/projects/${PROJECT_ID}/domains/d1`) + expect((opts as RequestInit).method).toBe("PATCH") + }) }) diff --git a/web/src/api/client.ts b/web/src/api/client.ts index 771dbce..748581d 100644 --- a/web/src/api/client.ts +++ b/web/src/api/client.ts @@ -1,15 +1,25 @@ -import { API_BASE } from "@/lib/config" +import { API_ROOT } from "@/lib/config" import type { + AuthState, Account, CreateAccountInput, Template, CreateTemplateInput, Domain, CreateDomainInput, ChangesetResponse, ApplyRequest, } from "./types" +export class UnauthorizedError extends Error { + constructor() { + super("Unauthorized") + this.name = "UnauthorizedError" + } +} + async function req(path: string, init?: RequestInit): Promise { - const res = await fetch(`${API_BASE}${path}`, { + const res = await fetch(path, { headers: { "Content-Type": "application/json" }, method: "GET", + credentials: "include", ...init, }) + if (res.status === 401) throw new UnauthorizedError() if (!res.ok) { let msg = `HTTP ${res.status}` try { const b = await res.json(); if (b?.error) msg = String(b.error) } catch { /* ignore */ } @@ -19,29 +29,52 @@ async function req(path: string, init?: RequestInit): Promise { return (await res.json()) as T } -export const api = { - listAccounts: () => req("/accounts"), - createAccount: (input: CreateAccountInput) => - req("/accounts", { method: "POST", body: JSON.stringify(input) }), - deleteAccount: (id: string) => req(`/accounts/${id}`, { method: "DELETE" }), - - listTemplates: () => req("/templates"), - createTemplate: (input: CreateTemplateInput) => - req