package notify

import (
	"context"
	"encoding/json"
	"errors"
	"net"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/google/uuid"
	"github.com/vasyakrg/dns-autoresolver/internal/store"
)

// TestTelegramSendSuccess verifies that Telegram.Send requests
// <BaseURL>/bot<token>/sendMessage with the configured chat_id and a message
// text that mentions the event's domain and status.
func TestTelegramSendSuccess(t *testing.T) {
	var gotPath string
	var gotBody map[string]string
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotPath = r.URL.Path
		// The decode error is ignored; a malformed or incomplete body still
		// surfaces below as a failed chat_id/text assertion.
		_ = json.NewDecoder(r.Body).Decode(&gotBody)
		w.WriteHeader(http.StatusOK)
	}))
	defer srv.Close()

	tg := &Telegram{BaseURL: srv.URL, HTTP: srv.Client()}
	ev := Event{Project: "proj", Domain: "example.com", Status: "drift", Summary: "A record changed", At: time.Now()}
	err := tg.Send(context.Background(), json.RawMessage(`{"chat_id":"12345"}`), "sekret-token", ev)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if gotPath != "/botsekret-token/sendMessage" {
		t.Fatalf("unexpected path: %s", gotPath)
	}
	if gotBody["chat_id"] != "12345" {
		t.Fatalf("unexpected chat_id: %+v", gotBody)
	}
	if !strings.Contains(gotBody["text"], "example.com") || !strings.Contains(gotBody["text"], "drift") {
		t.Fatalf("unexpected text: %+v", gotBody)
	}
}

// TestTelegramSendServerError verifies that an HTTP 500 from the Bot API is
// reported to the caller as an error.
func TestTelegramSendServerError(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	}))
	defer srv.Close()

	tg := &Telegram{BaseURL: srv.URL, HTTP: srv.Client()}
	ev := Event{Project: "proj", Domain: "example.com", Status: "drift", Summary: "x", At: time.Now()}
	if err := tg.Send(context.Background(), json.RawMessage(`{"chat_id":"1"}`), "tok", ev); err == nil {
		t.Fatal("expected error on 500 response")
	}
}

// TestTelegramSendTransportErrorDoesNotLeakSecret verifies that when the
// request cannot be delivered at all, the returned error mentions neither the
// bot token nor the request URL.
func TestTelegramSendTransportErrorDoesNotLeakSecret(t *testing.T) {
	// Bind and immediately close a server: its address is now unreachable
	// (connection refused), which makes http.Client.Do return a *url.Error
	// whose Error() embeds the full request URL, including the
	// /bot<token>/ path segment that carries the secret.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
	deadURL := srv.URL
	srv.Close()

	tg := &Telegram{BaseURL: deadURL, HTTP: srv.Client()}
	ev := Event{Project: "proj", Domain: "example.com", Status: "drift", Summary: "x", At: time.Now()}
	const secret = "super-secret-bot-token"
	err := tg.Send(context.Background(), json.RawMessage(`{"chat_id":"1"}`), secret, ev)
	if err == nil {
		t.Fatal("expected error for unreachable host")
	}
	if strings.Contains(err.Error(), secret) {
		t.Fatalf("error leaks secret: %v", err)
	}
	if strings.Contains(err.Error(), deadURL) {
		t.Fatalf("error leaks request URL: %v", err)
	}
}

// TestWebhookSendSuccess verifies that Webhook.Send POSTs the event as JSON to
// the URL from the channel config and treats a 200 as success.
func TestWebhookSendSuccess(t *testing.T) {
	var gotEvent Event
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			t.Errorf("unexpected method: %s", r.Method)
		}
		_ = json.NewDecoder(r.Body).Decode(&gotEvent)
		w.WriteHeader(http.StatusOK)
	}))
	defer srv.Close()

	// allowPrivate: true — httptest.Server listens on 127.0.0.1, which the
	// SSRF guard would otherwise reject; production dispatchers never set
	// this (see TestIsAllowedURL / TestNewDispatcherDoesNotAllowPrivate).
	wh := &Webhook{HTTP: srv.Client(), allowPrivate: true}
	ev := Event{Project: "proj", Domain: "example.com", Status: "in_sync", Summary: "resolved", At: time.Now()}
	// Marshalling a map[string]string cannot fail, so the error is dropped.
	cfg, _ := json.Marshal(map[string]string{"url": srv.URL})
	if err := wh.Send(context.Background(), cfg, "", ev); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if gotEvent.Domain != "example.com" || gotEvent.Status != "in_sync" {
		t.Fatalf("unexpected event delivered: %+v", gotEvent)
	}
}

// TestWebhookSendNonSuccessStatus verifies that a 4xx response from the
// webhook endpoint is reported as an error.
func TestWebhookSendNonSuccessStatus(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusBadRequest)
	}))
	defer srv.Close()

	wh := &Webhook{HTTP: srv.Client(), allowPrivate: true}
	ev := Event{Project: "proj", Domain: "example.com", Status: "drift", Summary: "x", At: time.Now()}
	cfg, _ := json.Marshal(map[string]string{"url": srv.URL})
	if err := wh.Send(context.Background(), cfg, "", ev); err == nil {
		t.Fatal("expected error on 400 response")
	}
}

// TestWebhookSendRejectsPrivateDestinationByDefault verifies that a Webhook
// without allowPrivate refuses a loopback URL with a "destination not allowed"
// error.
func TestWebhookSendRejectsPrivateDestinationByDefault(t *testing.T) {
	wh := &Webhook{HTTP: http.DefaultClient} // allowPrivate not set: SSRF guard active
	ev := Event{Project: "proj", Domain: "example.com", Status: "drift", Summary: "x", At: time.Now()}
	cfg, _ := json.Marshal(map[string]string{"url": "http://127.0.0.1:1/hook"})
	err := wh.Send(context.Background(), cfg, "", ev)
	if err == nil {
		t.Fatal("expected error for loopback destination")
	}
	if !strings.Contains(err.Error(), "destination not allowed") {
		t.Fatalf("unexpected error: %v", err)
	}
}

// TestIsAllowedURL checks the pre-request URL guard against loopback,
// link-local, private, unspecified and multicast hosts, a non-HTTP scheme,
// and one public IP that must pass.
func TestIsAllowedURL(t *testing.T) {
	cases := []struct {
		name    string
		rawurl  string
		allowed bool
	}{
		{"localhost hostname", "http://localhost/hook", false},
		{"loopback ip", "http://127.0.0.1/hook", false},
		{"loopback ipv6", "http://[::1]/hook", false},
		{"link-local metadata", "http://169.254.169.254/latest/meta-data", false},
		{"private class a", "http://10.0.0.1/hook", false},
		{"private class c", "http://192.168.1.1/hook", false},
		{"private class b", "http://172.16.0.1/hook", false},
		{"unspecified", "http://0.0.0.0/hook", false},
		{"multicast", "http://224.0.0.1/hook", false},
		{"non-http scheme", "ftp://example.com/hook", false},
		{"public ip", "http://93.184.216.34/hook", true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := isAllowedURL(tc.rawurl)
			if tc.allowed && err != nil {
				t.Fatalf("expected %q to be allowed, got error: %v", tc.rawurl, err)
			}
			if !tc.allowed && err == nil {
				t.Fatalf("expected %q to be rejected, got nil error", tc.rawurl)
			}
		})
	}
}

// TestDialControlBlocksActualConnectingAddress calls the dialer Control func
// directly with literal host:port addresses. With allowPrivate=false it must
// reject non-public targets and accept a public one.
func TestDialControlBlocksActualConnectingAddress(t *testing.T) {
	cases := []struct {
		name    string
		address string
		blocked bool
	}{
		{"loopback v4", "127.0.0.1:80", true},
		{"loopback v6", "[::1]:80", true},
		{"metadata link-local", "169.254.169.254:80", true},
		{"private class a", "10.0.0.1:80", true},
		{"private class b", "172.16.0.1:80", true},
		{"private class c", "192.168.1.1:80", true},
		{"unspecified", "0.0.0.0:80", true},
		{"multicast", "224.0.0.1:80", true},
		{"public ip", "93.184.216.34:443", false},
	}
	control := dialControl(false)
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// The RawConn argument is nil: these cases only exercise the
			// address check.
			err := control("tcp", tc.address, nil)
			if tc.blocked && err == nil {
				t.Fatalf("expected %q to be blocked", tc.address)
			}
			if !tc.blocked && err != nil {
				t.Fatalf("expected %q to be allowed, got: %v", tc.address, err)
			}
		})
	}
}

// TestIsBlockedIPCGNATRange pins both edges of the carrier-grade NAT block
// 100.64.0.0/10 (100.64.0.0 – 100.127.255.255) and the addresses just outside it.
func TestIsBlockedIPCGNATRange(t *testing.T) {
	cases := []struct {
		name    string
		ip      string
		blocked bool
	}{
		{"cgnat start", "100.64.0.1", true},
		{"cgnat end", "100.127.255.255", true},
		{"just below cgnat", "100.63.255.255", false},
		{"just above cgnat", "100.128.0.0", false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			ip := net.ParseIP(tc.ip)
			if ip == nil {
				t.Fatalf("failed to parse %q", tc.ip)
			}
			got := isBlockedIP(ip)
			if got != tc.blocked {
				t.Fatalf("isBlockedIP(%q) = %v, want %v", tc.ip, got, tc.blocked)
			}
		})
	}
}

// TestDialControlAllowsEverythingWhenAllowPrivate verifies that
// dialControl(true) does not reject a loopback address.
func TestDialControlAllowsEverythingWhenAllowPrivate(t *testing.T) {
	control := dialControl(true)
	if err := control("tcp", "127.0.0.1:80", nil); err != nil {
		t.Fatalf("expected allowPrivate to skip the dial guard, got: %v", err)
	}
}

// TestWebhookControlBlocksConnectionEvenWhenPreCheckPasses simulates the
// DNS-rebinding TOCTOU. allowPrivate=true skips the pre-request isAllowedURL
// check, standing in for a rebinding attacker answering a public IP to that
// lookup. The Transport's Control func is wired independently of
// Webhook.allowPrivate: it still inspects the literal address the dialer
// connects to and must reject it. Without Control, this request would reach
// the httptest handler; it must not.
func TestWebhookControlBlocksConnectionEvenWhenPreCheckPasses(t *testing.T) {
	var handlerCalled bool
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		handlerCalled = true
		w.WriteHeader(http.StatusOK)
	}))
	defer srv.Close()

	wh := &Webhook{
		HTTP:         &http.Client{Transport: newWebhookTransport(false)},
		allowPrivate: true, // pre-check bypassed on purpose; Control is not
	}
	ev := Event{Project: "proj", Domain: "example.com", Status: "drift", Summary: "x", At: time.Now()}
	cfg, _ := json.Marshal(map[string]string{"url": srv.URL})
	err := wh.Send(context.Background(), cfg, "", ev)
	if err == nil {
		t.Fatal("expected error: Control should have blocked the loopback connection")
	}
	if !strings.Contains(err.Error(), "destination not allowed") {
		t.Fatalf("unexpected error: %v", err)
	}
	if handlerCalled {
		t.Fatal("Control should have rejected the dial before the handler ran")
	}
}

// --- Dispatcher ---

// mockChannelStore returns a fixed channel list (and error) for any project;
// it ignores ctx and projectID.
type mockChannelStore struct {
	channels []store.Channel
	err      error
}

func (m *mockChannelStore) ListEnabledChannels(ctx context.Context, projectID uuid.UUID) ([]store.Channel, error) {
	return m.channels, m.err
}

// mockDecryptor "decrypts" by prefixing the input with "decrypted-", or
// returns errBoom for every call when fail is set.
type mockDecryptor struct {
	fail bool
}

func (m *mockDecryptor) Decrypt(enc string) ([]byte, error) {
	if m.fail {
		return nil, errBoom
	}
	return []byte("decrypted-" + enc), nil
}

// errBoom is the sentinel failure returned by mockDecryptor{fail: true}; tests
// match it with errors.Is. It is built with errors.New rather than a pointer
// to an empty struct because the Go spec does not guarantee that pointers to
// distinct zero-size values compare unequal, which would make the sentinel's
// identity unreliable.
var errBoom = errors.New("decrypt boom")

// TestDispatcherSendsToAllChannelsAndAggregatesErrors verifies that Send
// attempts every enabled channel, passes the decrypted secret through, returns
// an aggregated error when one channel fails, and still reports a per-channel
// result for both the success and the failure.
func TestDispatcherSendsToAllChannelsAndAggregatesErrors(t *testing.T) {
	var tgCalled, whCalled bool
	var tgSecret string
	tgSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		tgCalled = true
		w.WriteHeader(http.StatusOK)
	}))
	defer tgSrv.Close()
	whSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		whCalled = true
		w.WriteHeader(http.StatusInternalServerError) // webhook fails
	}))
	defer whSrv.Close()

	projectID := uuid.New()
	channels := []store.Channel{
		{
			ID:        uuid.New(),
			ProjectID: projectID,
			Type:      "telegram",
			Config:    json.RawMessage(`{"chat_id":"1"}`),
			SecretEnc: "enc-token",
			Enabled:   true,
		},
		{
			ID:        uuid.New(),
			ProjectID: projectID,
			Type:      "webhook",
			Config:    json.RawMessage(`{"url":"` + whSrv.URL + `"}`),
			SecretEnc: "",
			Enabled:   true,
		},
	}
	d := NewDispatcher(&mockChannelStore{channels: channels}, &mockDecryptor{})

	// Redirect telegram to the httptest server and capture the decrypted secret.
	d.byType["telegram"] = notifierFunc(func(ctx context.Context, cfg json.RawMessage, secret string, ev Event) error {
		tgSecret = secret
		tg := &Telegram{BaseURL: tgSrv.URL, HTTP: tgSrv.Client()}
		return tg.Send(ctx, cfg, secret, ev)
	})
	// httptest servers listen on loopback, which the SSRF guard rejects by
	// default; swap in an allowPrivate webhook so this test can still hit it.
	d.byType["webhook"] = &Webhook{HTTP: whSrv.Client(), allowPrivate: true}

	ev := Event{Project: "proj", Domain: "example.com", Status: "drift", Summary: "changed", At: time.Now()}
	results, err := d.Send(context.Background(), projectID, ev)
	if !tgCalled {
		t.Error("expected telegram notifier to be called")
	}
	if !whCalled {
		t.Error("expected webhook notifier to be called")
	}
	if err == nil {
		t.Fatal("expected aggregated error because webhook failed")
	}
	if tgSecret != "decrypted-enc-token" {
		t.Fatalf("expected decrypted secret to be passed to telegram, got %q", tgSecret)
	}
	if len(results) != 2 {
		t.Fatalf("results = %d, want 2", len(results))
	}
	byType := make(map[string]ChannelResult, len(results))
	for _, r := range results {
		byType[r.Type] = r
	}
	if tg, ok := byType["telegram"]; !ok || tg.Err != nil {
		t.Fatalf("telegram result = %+v, want ok result", tg)
	}
	if wh, ok := byType["webhook"]; !ok || wh.Err == nil {
		t.Fatalf("webhook result = %+v, want error result", wh)
	}
}

// TestDispatcherSendReturnsPerChannelResults exercises one telegram channel
// succeeding and one webhook channel failing at the Notifier. The metrics
// consumer (the scheduler) needs one result per channel, not a single
// aggregate error, to record accurate per-channel/status metrics.
func TestDispatcherSendReturnsPerChannelResults(t *testing.T) { projectID := uuid.New() channels := []store.Channel{ {ID: uuid.New(), ProjectID: projectID, Type: "telegram", Config: json.RawMessage(`{"chat_id":"1"}`), Enabled: true}, {ID: uuid.New(), ProjectID: projectID, Type: "webhook", Config: json.RawMessage(`{"url":"http://x"}`), Enabled: true}, } d := NewDispatcher(&mockChannelStore{channels: channels}, &mockDecryptor{}) d.byType["telegram"] = notifierFunc(func(ctx context.Context, cfg json.RawMessage, secret string, ev Event) error { return nil }) d.byType["webhook"] = notifierFunc(func(ctx context.Context, cfg json.RawMessage, secret string, ev Event) error { return errBoom }) results, err := d.Send(context.Background(), projectID, Event{Project: "p", Domain: "d", Status: "drift"}) if err == nil { t.Fatal("expected aggregated error because webhook failed") } if len(results) != 2 { t.Fatalf("results = %d, want 2", len(results)) } if results[0].Type != "telegram" || results[0].Err != nil { t.Fatalf("results[0] = %+v, want telegram/nil", results[0]) } if results[1].Type != "webhook" || results[1].Err == nil { t.Fatalf("results[1] = %+v, want webhook/error", results[1]) } } func TestDispatcherSkipsUnknownChannelType(t *testing.T) { projectID := uuid.New() channels := []store.Channel{ {ID: uuid.New(), ProjectID: projectID, Type: "carrier-pigeon", Config: json.RawMessage(`{}`), Enabled: true}, } d := NewDispatcher(&mockChannelStore{channels: channels}, &mockDecryptor{}) results, err := d.Send(context.Background(), projectID, Event{Project: "p", Domain: "d", Status: "drift"}) if err != nil { t.Fatalf("unexpected error for unknown channel type: %v", err) } if len(results) != 0 { t.Fatalf("results = %d, want 0: unknown channel type must not produce a result", len(results)) } } func TestDispatcherDecryptFailureIsAggregatedNotFatal(t *testing.T) { var whCalled bool whSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { whCalled = 
true w.WriteHeader(http.StatusOK) })) defer whSrv.Close() projectID := uuid.New() channels := []store.Channel{ {ID: uuid.New(), ProjectID: projectID, Type: "telegram", Config: json.RawMessage(`{"chat_id":"1"}`), SecretEnc: "enc", Enabled: true}, {ID: uuid.New(), ProjectID: projectID, Type: "webhook", Config: json.RawMessage(`{"url":"` + whSrv.URL + `"}`), Enabled: true}, } d := NewDispatcher(&mockChannelStore{channels: channels}, &mockDecryptor{fail: true}) // httptest servers listen on loopback, which the SSRF guard rejects by // default; swap in an allowPrivate webhook so this test can still hit it. d.byType["webhook"] = &Webhook{HTTP: whSrv.Client(), allowPrivate: true} results, err := d.Send(context.Background(), projectID, Event{Project: "p", Domain: "d", Status: "drift"}) if err == nil { t.Fatal("expected error due to decrypt failure") } if !whCalled { t.Error("expected webhook channel to still be attempted after telegram decrypt failure") } if len(results) != 2 { t.Fatalf("results = %d, want 2", len(results)) } byType := make(map[string]ChannelResult, len(results)) for _, r := range results { byType[r.Type] = r } tg, ok := byType["telegram"] if !ok { t.Fatal("expected a telegram result") } if tg.Err == nil { t.Fatalf("telegram result = %+v, want decrypt error", tg) } if !errors.Is(tg.Err, errBoom) { t.Fatalf("telegram result err = %v, want errBoom", tg.Err) } wh, ok := byType["webhook"] if !ok { t.Fatal("expected a webhook result") } if wh.Err != nil { t.Fatalf("webhook result = %+v, want ok result (decrypt failure on telegram must not fail webhook)", wh) } } // notifierFunc adapts a function to the Notifier interface for tests. type notifierFunc func(ctx context.Context, cfg json.RawMessage, secret string, ev Event) error func (f notifierFunc) Send(ctx context.Context, cfg json.RawMessage, secret string, ev Event) error { return f(ctx, cfg, secret, ev) }