76 lines
2.3 KiB
Go
76 lines
2.3 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// ctxKey is the unexported type used for every request-context key set by
// this package. Because context.Value compares keys with ==, which for
// interface values checks both dynamic type and value, a ctxKey("userID")
// can never collide with a plain string "userID" stored by another package.
type ctxKey string

// Request-context keys written by the middleware in this file.
const (
	// ctxUserID carries the authenticated user's ID (set by RequireAuth).
	ctxUserID = ctxKey("userID")
	// ctxProjectID carries the authorized project's ID (set by RequireProjectAccess).
	ctxProjectID = ctxKey("projectID")
)
|
|
|
|
// RequireAuth validates the session cookie and, on success, stores the
|
|
// authenticated user's ID in the request context (see userIDFrom). Any
|
|
// failure — missing cookie or an invalid/expired token — is rejected with
|
|
// 401 before reaching the wrapped handler.
|
|
func (a *API) RequireAuth(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
c, err := r.Cookie(sessionCookieName)
|
|
if err != nil {
|
|
writeErr(w, http.StatusUnauthorized, "unauthorized")
|
|
return
|
|
}
|
|
uid, err := a.Sessions.Validate(r.Context(), c.Value)
|
|
if err != nil {
|
|
writeErr(w, http.StatusUnauthorized, "unauthorized")
|
|
return
|
|
}
|
|
ctx := context.WithValue(r.Context(), ctxUserID, uid)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
// RequireProjectAccess verifies that the {pid} URL segment names a project
|
|
// owned by the authenticated user (set by RequireAuth, which must run
|
|
// first) and, on success, stores the project ID in the request context (see
|
|
// projectIDFrom). A project that doesn't exist or isn't owned by the caller
|
|
// is rejected with 404 — not 403 — so a caller can't distinguish "not yours"
|
|
// from "doesn't exist" (closes IDOR by never confirming another tenant's
|
|
// project exists).
|
|
func (a *API) RequireProjectAccess(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
uid, ok := userIDFrom(r.Context())
|
|
if !ok {
|
|
writeErr(w, http.StatusUnauthorized, "unauthorized")
|
|
return
|
|
}
|
|
pid, err := uuid.Parse(chi.URLParam(r, "pid"))
|
|
if err != nil {
|
|
writeErr(w, http.StatusBadRequest, "invalid project id")
|
|
return
|
|
}
|
|
if _, err := a.Auth.GetProjectOwned(r.Context(), pid, uid); err != nil {
|
|
writeErr(w, http.StatusNotFound, "not found")
|
|
return
|
|
}
|
|
ctx := context.WithValue(r.Context(), ctxProjectID, pid)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
func userIDFrom(ctx context.Context) (uuid.UUID, bool) {
|
|
v, ok := ctx.Value(ctxUserID).(uuid.UUID)
|
|
return v, ok
|
|
}
|
|
|
|
func projectIDFrom(ctx context.Context) (uuid.UUID, bool) {
|
|
v, ok := ctx.Value(ctxProjectID).(uuid.UUID)
|
|
return v, ok
|
|
}
|