diff --git a/internal/httpapi/auth.go b/internal/httpapi/auth.go new file mode 100644 index 0000000..6de0309 --- /dev/null +++ b/internal/httpapi/auth.go @@ -0,0 +1,67 @@ +package httpapi + +import ( + "crypto/subtle" + "encoding/json" + "net/http" + "time" + + "github.com/vasyansk/imap-copier/internal/config" + "github.com/vasyansk/imap-copier/internal/crypto" + "github.com/vasyansk/imap-copier/internal/orchestrator" + "github.com/vasyansk/imap-copier/internal/store" + "github.com/vasyansk/imap-copier/internal/wshub" +) + +const cookieName = "session" + +type Server struct { + cfg config.Config + store *store.Store + orch *orchestrator.Orchestrator + hub *wshub.Hub +} + +func NewServer(cfg config.Config, s *store.Store, orch *orchestrator.Orchestrator, hub *wshub.Hub) *Server { + return &Server{cfg: cfg, store: s, orch: orch, hub: hub} +} + +func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) { + var body struct{ User, Pass string } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "bad json", http.StatusBadRequest) + return + } + uOK := subtle.ConstantTimeCompare([]byte(body.User), []byte(s.cfg.AuthUser)) == 1 + pOK := subtle.ConstantTimeCompare([]byte(body.Pass), []byte(s.cfg.AuthPass)) == 1 + if !uOK || !pOK { + http.Error(w, "invalid credentials", http.StatusUnauthorized) + return + } + tok := crypto.SignSession(s.cfg.SessionSecret, body.User, time.Now().Add(24*time.Hour)) + http.SetCookie(w, &http.Cookie{ + Name: cookieName, Value: tok, Path: "/", + HttpOnly: true, SameSite: http.SameSiteLaxMode, MaxAge: 86400, + }) + w.WriteHeader(http.StatusOK) +} + +func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) { + http.SetCookie(w, &http.Cookie{Name: cookieName, Value: "", Path: "/", MaxAge: -1}) + w.WriteHeader(http.StatusOK) +} + +func (s *Server) requireAuth(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := r.Cookie(cookieName) + if err != nil { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if _, ok := crypto.VerifySession(s.cfg.SessionSecret, c.Value, time.Now()); !ok { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/internal/httpapi/auth_test.go b/internal/httpapi/auth_test.go new file mode 100644 index 0000000..71b9c82 --- /dev/null +++ b/internal/httpapi/auth_test.go @@ -0,0 +1,57 @@ +package httpapi + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/vasyansk/imap-copier/internal/config" +) + +func testServer() *Server { + return &Server{cfg: config.Config{ + AuthUser: "admin", AuthPass: "pw", SessionSecret: []byte("sekret"), + }} +} + +func TestLoginSetsCookie(t *testing.T) { + s := testServer() + req := httptest.NewRequest("POST", "/api/login", strings.NewReader(`{"user":"admin","pass":"pw"}`)) + rw := httptest.NewRecorder() + s.handleLogin(rw, req) + if rw.Code != http.StatusOK { + t.Fatalf("code=%d", rw.Code) + } + if len(rw.Result().Cookies()) == 0 { + t.Fatal("no session cookie set") + } +} + +func TestRequireAuthBlocksNoCookie(t *testing.T) { + s := testServer() + h := s.requireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) })) + rw := httptest.NewRecorder() + h.ServeHTTP(rw, httptest.NewRequest("GET", "/api/tasks", nil)) + if rw.Code != http.StatusUnauthorized { + t.Fatalf("want 401, got %d", rw.Code) + } +} + +func TestRequireAuthAllowsValidCookie(t *testing.T) { + s := testServer() + // логинимся, забираем cookie, повторяем запрос + lr := httptest.NewRequest("POST", "/api/login", strings.NewReader(`{"user":"admin","pass":"pw"}`)) + lrw := httptest.NewRecorder() + s.handleLogin(lrw, lr) + cookie := lrw.Result().Cookies()[0] + + h := s.requireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) })) + req := httptest.NewRequest("GET", "/api/tasks", nil) + req.AddCookie(cookie) + rw := httptest.NewRecorder() + h.ServeHTTP(rw, req) + if rw.Code != 200 { + t.Fatalf("want 200, got %d", rw.Code) + } +}