367 lines
13 KiB
Go
367 lines
13 KiB
Go
package notify
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/vasyakrg/dns-autoresolver/internal/store"
|
|
)
|
|
|
|
func TestTelegramSendSuccess(t *testing.T) {
|
|
var gotPath string
|
|
var gotBody map[string]string
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
gotPath = r.URL.Path
|
|
_ = json.NewDecoder(r.Body).Decode(&gotBody)
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
tg := &Telegram{BaseURL: srv.URL, HTTP: srv.Client()}
|
|
ev := Event{Project: "proj", Domain: "example.com", Status: "drift", Summary: "A record changed", At: time.Now()}
|
|
|
|
err := tg.Send(context.Background(), json.RawMessage(`{"chat_id":"12345"}`), "sekret-token", ev)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if gotPath != "/botsekret-token/sendMessage" {
|
|
t.Fatalf("unexpected path: %s", gotPath)
|
|
}
|
|
if gotBody["chat_id"] != "12345" {
|
|
t.Fatalf("unexpected chat_id: %+v", gotBody)
|
|
}
|
|
if !strings.Contains(gotBody["text"], "example.com") || !strings.Contains(gotBody["text"], "drift") {
|
|
t.Fatalf("unexpected text: %+v", gotBody)
|
|
}
|
|
}
|
|
|
|
func TestTelegramSendServerError(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
tg := &Telegram{BaseURL: srv.URL, HTTP: srv.Client()}
|
|
ev := Event{Project: "proj", Domain: "example.com", Status: "drift", Summary: "x", At: time.Now()}
|
|
|
|
if err := tg.Send(context.Background(), json.RawMessage(`{"chat_id":"1"}`), "tok", ev); err == nil {
|
|
t.Fatal("expected error on 500 response")
|
|
}
|
|
}
|
|
|
|
func TestTelegramSendTransportErrorDoesNotLeakSecret(t *testing.T) {
|
|
// Bind and immediately close a server: its address is now unreachable
|
|
// (connection refused), which makes http.Client.Do return a *url.Error
|
|
// whose Error() embeds the full request URL — including /bot<secret>/.
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
|
deadURL := srv.URL
|
|
srv.Close()
|
|
|
|
tg := &Telegram{BaseURL: deadURL, HTTP: srv.Client()}
|
|
ev := Event{Project: "proj", Domain: "example.com", Status: "drift", Summary: "x", At: time.Now()}
|
|
|
|
const secret = "super-secret-bot-token"
|
|
err := tg.Send(context.Background(), json.RawMessage(`{"chat_id":"1"}`), secret, ev)
|
|
if err == nil {
|
|
t.Fatal("expected error for unreachable host")
|
|
}
|
|
if strings.Contains(err.Error(), secret) {
|
|
t.Fatalf("error leaks secret: %v", err)
|
|
}
|
|
if strings.Contains(err.Error(), deadURL) {
|
|
t.Fatalf("error leaks request URL: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestWebhookSendSuccess(t *testing.T) {
|
|
var gotEvent Event
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
t.Errorf("unexpected method: %s", r.Method)
|
|
}
|
|
_ = json.NewDecoder(r.Body).Decode(&gotEvent)
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
// allowPrivate: true — httptest.Server listens on 127.0.0.1, which the
|
|
// SSRF guard would otherwise reject; production dispatchers never set
|
|
// this (see TestIsAllowedURL / TestNewDispatcherDoesNotAllowPrivate).
|
|
wh := &Webhook{HTTP: srv.Client(), allowPrivate: true}
|
|
ev := Event{Project: "proj", Domain: "example.com", Status: "in_sync", Summary: "resolved", At: time.Now()}
|
|
|
|
cfg, _ := json.Marshal(map[string]string{"url": srv.URL})
|
|
if err := wh.Send(context.Background(), cfg, "", ev); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if gotEvent.Domain != "example.com" || gotEvent.Status != "in_sync" {
|
|
t.Fatalf("unexpected event delivered: %+v", gotEvent)
|
|
}
|
|
}
|
|
|
|
func TestWebhookSendNonSuccessStatus(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
wh := &Webhook{HTTP: srv.Client(), allowPrivate: true}
|
|
ev := Event{Project: "proj", Domain: "example.com", Status: "drift", Summary: "x", At: time.Now()}
|
|
cfg, _ := json.Marshal(map[string]string{"url": srv.URL})
|
|
|
|
if err := wh.Send(context.Background(), cfg, "", ev); err == nil {
|
|
t.Fatal("expected error on 400 response")
|
|
}
|
|
}
|
|
|
|
func TestWebhookSendRejectsPrivateDestinationByDefault(t *testing.T) {
|
|
wh := &Webhook{HTTP: http.DefaultClient} // allowPrivate not set: SSRF guard active
|
|
ev := Event{Project: "proj", Domain: "example.com", Status: "drift", Summary: "x", At: time.Now()}
|
|
cfg, _ := json.Marshal(map[string]string{"url": "http://127.0.0.1:1/hook"})
|
|
|
|
err := wh.Send(context.Background(), cfg, "", ev)
|
|
if err == nil {
|
|
t.Fatal("expected error for loopback destination")
|
|
}
|
|
if !strings.Contains(err.Error(), "destination not allowed") {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestIsAllowedURL(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
rawurl string
|
|
allowed bool
|
|
}{
|
|
{"localhost hostname", "http://localhost/hook", false},
|
|
{"loopback ip", "http://127.0.0.1/hook", false},
|
|
{"loopback ipv6", "http://[::1]/hook", false},
|
|
{"link-local metadata", "http://169.254.169.254/latest/meta-data", false},
|
|
{"private class a", "http://10.0.0.1/hook", false},
|
|
{"private class c", "http://192.168.1.1/hook", false},
|
|
{"private class b", "http://172.16.0.1/hook", false},
|
|
{"unspecified", "http://0.0.0.0/hook", false},
|
|
{"multicast", "http://224.0.0.1/hook", false},
|
|
{"non-http scheme", "ftp://example.com/hook", false},
|
|
{"public ip", "http://93.184.216.34/hook", true},
|
|
}
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
err := isAllowedURL(tc.rawurl)
|
|
if tc.allowed && err != nil {
|
|
t.Fatalf("expected %q to be allowed, got error: %v", tc.rawurl, err)
|
|
}
|
|
if !tc.allowed && err == nil {
|
|
t.Fatalf("expected %q to be rejected, got nil error", tc.rawurl)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDialControlBlocksActualConnectingAddress(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
address string
|
|
blocked bool
|
|
}{
|
|
{"loopback v4", "127.0.0.1:80", true},
|
|
{"loopback v6", "[::1]:80", true},
|
|
{"metadata link-local", "169.254.169.254:80", true},
|
|
{"private class a", "10.0.0.1:80", true},
|
|
{"private class b", "172.16.0.1:80", true},
|
|
{"private class c", "192.168.1.1:80", true},
|
|
{"unspecified", "0.0.0.0:80", true},
|
|
{"multicast", "224.0.0.1:80", true},
|
|
{"public ip", "93.184.216.34:443", false},
|
|
}
|
|
control := dialControl(false)
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
err := control("tcp", tc.address, nil)
|
|
if tc.blocked && err == nil {
|
|
t.Fatalf("expected %q to be blocked", tc.address)
|
|
}
|
|
if !tc.blocked && err != nil {
|
|
t.Fatalf("expected %q to be allowed, got: %v", tc.address, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDialControlAllowsEverythingWhenAllowPrivate(t *testing.T) {
|
|
control := dialControl(true)
|
|
if err := control("tcp", "127.0.0.1:80", nil); err != nil {
|
|
t.Fatalf("expected allowPrivate to skip the dial guard, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestWebhookControlBlocksConnectionEvenWhenPreCheckPasses simulates the
|
|
// DNS-rebinding TOCTOU: allowPrivate=true skips the pre-request isAllowedURL
|
|
// check (standing in for a rebinding attacker answering a public IP to that
|
|
// lookup), but the Transport's Control func — wired independently of
|
|
// Webhook.allowPrivate — still inspects the literal address the dialer
|
|
// connects to and must still reject it. If Control did not exist, this
|
|
// request would reach the httptest handler; it must not.
|
|
func TestWebhookControlBlocksConnectionEvenWhenPreCheckPasses(t *testing.T) {
|
|
var handlerCalled bool
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
handlerCalled = true
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
wh := &Webhook{
|
|
HTTP: &http.Client{Transport: newWebhookTransport(false)},
|
|
allowPrivate: true, // pre-check bypassed on purpose; Control is not
|
|
}
|
|
ev := Event{Project: "proj", Domain: "example.com", Status: "drift", Summary: "x", At: time.Now()}
|
|
cfg, _ := json.Marshal(map[string]string{"url": srv.URL})
|
|
|
|
err := wh.Send(context.Background(), cfg, "", ev)
|
|
if err == nil {
|
|
t.Fatal("expected error: Control should have blocked the loopback connection")
|
|
}
|
|
if !strings.Contains(err.Error(), "destination not allowed") {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if handlerCalled {
|
|
t.Fatal("Control should have rejected the dial before the handler ran")
|
|
}
|
|
}
|
|
|
|
// --- Dispatcher ---
|
|
|
|
type mockChannelStore struct {
|
|
channels []store.Channel
|
|
err error
|
|
}
|
|
|
|
func (m *mockChannelStore) ListEnabledChannels(ctx context.Context, projectID uuid.UUID) ([]store.Channel, error) {
|
|
return m.channels, m.err
|
|
}
|
|
|
|
type mockDecryptor struct {
|
|
fail bool
|
|
}
|
|
|
|
func (m *mockDecryptor) Decrypt(enc string) ([]byte, error) {
|
|
if m.fail {
|
|
return nil, errBoom
|
|
}
|
|
return []byte("decrypted-" + enc), nil
|
|
}
|
|
|
|
// boomErr is the concrete, data-free type behind errBoom.
type boomErr struct{}

// Error implements the error interface with a fixed message.
func (*boomErr) Error() string {
	return "decrypt boom"
}

// errBoom is the error mockDecryptor returns when configured to fail.
var errBoom = &boomErr{}
|
|
|
|
func TestDispatcherSendsToAllChannelsAndAggregatesErrors(t *testing.T) {
|
|
var tgCalled, whCalled bool
|
|
var tgSecret string
|
|
|
|
tgSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
tgCalled = true
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer tgSrv.Close()
|
|
|
|
whSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
whCalled = true
|
|
w.WriteHeader(http.StatusInternalServerError) // webhook fails
|
|
}))
|
|
defer whSrv.Close()
|
|
|
|
projectID := uuid.New()
|
|
channels := []store.Channel{
|
|
{
|
|
ID: uuid.New(), ProjectID: projectID, Type: "telegram",
|
|
Config: json.RawMessage(`{"chat_id":"1"}`), SecretEnc: "enc-token", Enabled: true,
|
|
},
|
|
{
|
|
ID: uuid.New(), ProjectID: projectID, Type: "webhook",
|
|
Config: json.RawMessage(`{"url":"` + whSrv.URL + `"}`), SecretEnc: "", Enabled: true,
|
|
},
|
|
}
|
|
|
|
d := NewDispatcher(&mockChannelStore{channels: channels}, &mockDecryptor{})
|
|
// Redirect telegram to the httptest server and capture the decrypted secret.
|
|
d.byType["telegram"] = notifierFunc(func(ctx context.Context, cfg json.RawMessage, secret string, ev Event) error {
|
|
tgSecret = secret
|
|
tg := &Telegram{BaseURL: tgSrv.URL, HTTP: tgSrv.Client()}
|
|
return tg.Send(ctx, cfg, secret, ev)
|
|
})
|
|
// httptest servers listen on loopback, which the SSRF guard rejects by
|
|
// default; swap in an allowPrivate webhook so this test can still hit it.
|
|
d.byType["webhook"] = &Webhook{HTTP: whSrv.Client(), allowPrivate: true}
|
|
|
|
ev := Event{Project: "proj", Domain: "example.com", Status: "drift", Summary: "changed", At: time.Now()}
|
|
err := d.Send(context.Background(), projectID, ev)
|
|
|
|
if !tgCalled {
|
|
t.Error("expected telegram notifier to be called")
|
|
}
|
|
if !whCalled {
|
|
t.Error("expected webhook notifier to be called")
|
|
}
|
|
if err == nil {
|
|
t.Fatal("expected aggregated error because webhook failed")
|
|
}
|
|
if tgSecret != "decrypted-enc-token" {
|
|
t.Fatalf("expected decrypted secret to be passed to telegram, got %q", tgSecret)
|
|
}
|
|
}
|
|
|
|
func TestDispatcherSkipsUnknownChannelType(t *testing.T) {
|
|
projectID := uuid.New()
|
|
channels := []store.Channel{
|
|
{ID: uuid.New(), ProjectID: projectID, Type: "carrier-pigeon", Config: json.RawMessage(`{}`), Enabled: true},
|
|
}
|
|
d := NewDispatcher(&mockChannelStore{channels: channels}, &mockDecryptor{})
|
|
if err := d.Send(context.Background(), projectID, Event{Project: "p", Domain: "d", Status: "drift"}); err != nil {
|
|
t.Fatalf("unexpected error for unknown channel type: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestDispatcherDecryptFailureIsAggregatedNotFatal(t *testing.T) {
|
|
var whCalled bool
|
|
whSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
whCalled = true
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer whSrv.Close()
|
|
|
|
projectID := uuid.New()
|
|
channels := []store.Channel{
|
|
{ID: uuid.New(), ProjectID: projectID, Type: "telegram", Config: json.RawMessage(`{"chat_id":"1"}`), SecretEnc: "enc", Enabled: true},
|
|
{ID: uuid.New(), ProjectID: projectID, Type: "webhook", Config: json.RawMessage(`{"url":"` + whSrv.URL + `"}`), Enabled: true},
|
|
}
|
|
d := NewDispatcher(&mockChannelStore{channels: channels}, &mockDecryptor{fail: true})
|
|
// httptest servers listen on loopback, which the SSRF guard rejects by
|
|
// default; swap in an allowPrivate webhook so this test can still hit it.
|
|
d.byType["webhook"] = &Webhook{HTTP: whSrv.Client(), allowPrivate: true}
|
|
|
|
err := d.Send(context.Background(), projectID, Event{Project: "p", Domain: "d", Status: "drift"})
|
|
if err == nil {
|
|
t.Fatal("expected error due to decrypt failure")
|
|
}
|
|
if !whCalled {
|
|
t.Error("expected webhook channel to still be attempted after telegram decrypt failure")
|
|
}
|
|
}
|
|
|
|
// notifierFunc adapts a function to the Notifier interface for tests.
|
|
type notifierFunc func(ctx context.Context, cfg json.RawMessage, secret string, ev Event) error
|
|
|
|
func (f notifierFunc) Send(ctx context.Context, cfg json.RawMessage, secret string, ev Event) error {
|
|
return f(ctx, cfg, secret, ev)
|
|
}
|