package api

import (
	"context"
	"net/http"

	"github.com/go-chi/chi/v5"
	"github.com/google/uuid"
)

// ctxKey is the package-private type used for request-context keys. Using a
// defined type rather than a bare string keeps these keys from colliding with
// context values stored by other packages.
type ctxKey string

// Request-context keys written by the middleware in this file.
const (
	ctxUserID    ctxKey = "userID"    // written by RequireAuth, read by userIDFrom
	ctxProjectID ctxKey = "projectID" // written by RequireProjectAccess, read by projectIDFrom
)

// RequireAuth is middleware that authenticates a request by its session
// cookie. It reads the cookie named sessionCookieName and hands its value to
// a.Sessions.Validate; the user ID returned is stored in the request context
// under ctxUserID (retrieve it with userIDFrom) before next is invoked.
//
// If the cookie is absent, or Validate reports an error (e.g. an invalid or
// expired session), the response is 401 "unauthorized" and next is never
// called.
func (a *API) RequireAuth(next http.Handler) http.Handler {
	handle := func(w http.ResponseWriter, r *http.Request) {
		cookie, cookieErr := r.Cookie(sessionCookieName)
		if cookieErr != nil {
			writeErr(w, http.StatusUnauthorized, "unauthorized")
			return
		}

		userID, validateErr := a.Sessions.Validate(r.Context(), cookie.Value)
		if validateErr != nil {
			writeErr(w, http.StatusUnauthorized, "unauthorized")
			return
		}

		authed := r.WithContext(context.WithValue(r.Context(), ctxUserID, userID))
		next.ServeHTTP(w, authed)
	}
	return http.HandlerFunc(handle)
}

// RequireProjectAccess is middleware that verifies the {pid} URL segment names
// a project owned by the authenticated user and, on success, stores the parsed
// project ID in the request context (see projectIDFrom).
//
// It must be mounted after RequireAuth: when no user ID is present in the
// context the request is rejected with 401. A {pid} that is not a valid UUID
// is rejected with 400. Any error from GetProjectOwned (a project that doesn't
// exist or isn't owned by the caller) is rejected with 404, not 403, so a
// caller can't distinguish "not yours" from "doesn't exist". This closes an
// IDOR vector by never confirming that another tenant's project exists.
func (a *API) RequireProjectAccess(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { uid, ok := userIDFrom(r.Context()) if !ok { writeErr(w, http.StatusUnauthorized, "unauthorized") return } pid, err := uuid.Parse(chi.URLParam(r, "pid")) if err != nil { writeErr(w, http.StatusBadRequest, "invalid project id") return } if _, err := a.Auth.GetProjectOwned(r.Context(), pid, uid); err != nil { writeErr(w, http.StatusNotFound, "not found") return } ctx := context.WithValue(r.Context(), ctxProjectID, pid) next.ServeHTTP(w, r.WithContext(ctx)) }) } func userIDFrom(ctx context.Context) (uuid.UUID, bool) { v, ok := ctx.Value(ctxUserID).(uuid.UUID) return v, ok } func projectIDFrom(ctx context.Context) (uuid.UUID, bool) { v, ok := ctx.Value(ctxProjectID).(uuid.UUID) return v, ok }