diff --git a/internal/httpapi/ws.go b/internal/httpapi/ws.go index 177ced9..f75a9a1 100644 --- a/internal/httpapi/ws.go +++ b/internal/httpapi/ws.go @@ -25,7 +25,10 @@ func (s *Server) handleWS(w http.ResponseWriter, r *http.Request) { subID, ch := s.hub.Subscribe(taskID) defer s.hub.Unsubscribe(taskID, subID) - ctx := r.Context() + // websocket.Accept хайджекает соединение, поэтому r.Context() не отменяется + // при обрыве связи клиентом. CloseRead запускает фоновое чтение control-фреймов + // и отменяет возвращаемый контекст, когда соединение действительно умирает. + ctx := c.CloseRead(r.Context()) for { select { case ev, ok := <-ch: diff --git a/internal/httpapi/ws_test.go b/internal/httpapi/ws_test.go new file mode 100644 index 0000000..34b6587 --- /dev/null +++ b/internal/httpapi/ws_test.go @@ -0,0 +1,63 @@ +package httpapi + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/coder/websocket" + "github.com/vasyansk/imap-copier/internal/config" + "github.com/vasyansk/imap-copier/internal/crypto" + "github.com/vasyansk/imap-copier/internal/wshub" +) + +func TestWSRequiresAuth(t *testing.T) { + s := &Server{cfg: config.Config{SessionSecret: []byte("x")}, hub: wshub.New()} + srv := httptest.NewServer(s.Router()) + defer srv.Close() + // no cookie -> upgrade rejected (401) + _, resp, err := websocket.Dial(context.Background(), "ws"+srv.URL[4:]+"/ws?task_id=1", nil) + if err == nil { + t.Fatal("expected auth rejection") + } + if resp != nil && resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("want 401, got %d", resp.StatusCode) + } +} + +func TestWSUnsubscribesOnClientDisconnect(t *testing.T) { + hub := wshub.New() + secret := []byte("sekret") + s := &Server{cfg: config.Config{AuthUser: "admin", SessionSecret: secret}, hub: hub} + srv := httptest.NewServer(s.Router()) + defer srv.Close() + + tok := crypto.SignSession(secret, "admin", time.Now().Add(time.Hour)) + hdr := http.Header{} + hdr.Set("Cookie", cookieName+"="+tok) + + ctx := context.Background() + c, _, err := websocket.Dial(ctx, "ws"+srv.URL[4:]+"/ws?task_id=7", &websocket.DialOptions{HTTPHeader: hdr}) + if err != nil { + t.Fatalf("dial: %v", err) + } + // wait until subscribed + deadline := time.Now().Add(2 * time.Second) + for hub.SubscriberCount(7) == 0 { + if time.Now().After(deadline) { + t.Fatal("never subscribed") + } + time.Sleep(10 * time.Millisecond) + } + // abrupt client close -> server must detect and unsubscribe + c.CloseNow() + deadline = time.Now().Add(3 * time.Second) + for hub.SubscriberCount(7) != 0 { + if time.Now().After(deadline) { + t.Fatal("subscription leaked after client disconnect") + } + time.Sleep(20 * time.Millisecond) + } +} diff --git a/internal/wshub/wshub.go b/internal/wshub/wshub.go index ba67cb7..130b12a 100644 --- a/internal/wshub/wshub.go +++ b/internal/wshub/wshub.go @@ -45,6 +45,13 @@ func (h *Hub) Unsubscribe(taskID, id int64) { } } +// SubscriberCount returns the number of active subscribers for a task (for tests/metrics). +func (h *Hub) SubscriberCount(taskID int64) int { + h.mu.Lock() + defer h.mu.Unlock() + return len(h.subs[taskID]) +} + func (h *Hub) Publish(ev Event) { h.mu.Lock() defer h.mu.Unlock()