diff --git a/cmd/server/main.go b/cmd/server/main.go index c29fbb0..08ecbe9 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -8,6 +8,7 @@ import ( "os" "os/signal" "strings" + "sync" "syscall" "time" @@ -49,7 +50,57 @@ func isAPIPath(path string) bool { return path == "/api" || strings.HasPrefix(path, "/api/") } +// buildMux wires the public /healthz + /metrics endpoints, the API router, +// and the embedded SPA. /healthz and /metrics are intentionally auth-free — +// /healthz is a liveness probe (always 200 while the process serves), and +// metricsHandler only ever exposes aggregate counters/gauges. +func buildMux(metricsHandler http.Handler, apiRouter http.Handler, webHandler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.URL.Path == "/healthz": + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + case r.URL.Path == "/metrics": + metricsHandler.ServeHTTP(w, r) + case isAPIPath(r.URL.Path): + apiRouter.ServeHTTP(w, r) + case webHandler != nil: + webHandler.ServeHTTP(w, r) + default: + http.NotFound(w, r) + } + }) +} + +// healthcheck performs an in-process liveness probe used as the container +// HEALTHCHECK — distroless images have no curl/wget. It GETs /healthz on the +// configured listen address and maps 200 -> 0, anything else -> 1. +func healthcheck() int { + addr := os.Getenv("DNS_AR_LISTEN") + if addr == "" { + addr = ":8080" + } + // ":8080" -> "127.0.0.1:8080" + if strings.HasPrefix(addr, ":") { + addr = "127.0.0.1" + addr + } + c := &http.Client{Timeout: 3 * time.Second} + resp, err := c.Get("http://" + addr + "/healthz") + if err != nil { + return 1 + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusOK { + return 0 + } + return 1 +} + func main() { + if len(os.Args) > 1 && os.Args[1] == "-healthcheck" { + os.Exit(healthcheck()) + } + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() @@ -97,22 +148,14 @@ func main() { // internally and never stop the loop; ctx cancellation (signal) is the // only thing that ends Run. sched := scheduler.New(st, svc, dispatcher, m) - go sched.Run(ctx, schedulerTick) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + sched.Run(ctx, schedulerTick) + }() - mux := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch { - case r.URL.Path == "/metrics": - // Public by design (no auth) — Metrics.Handler only ever exposes - // aggregate counters/gauges, never per-domain or secret data. - m.Handler().ServeHTTP(w, r) - case isAPIPath(r.URL.Path): - apiRouter.ServeHTTP(w, r) - case webHandler != nil: - webHandler.ServeHTTP(w, r) - default: - http.NotFound(w, r) - } - }) + mux := buildMux(m.Handler(), apiRouter, webHandler) srv := &http.Server{Addr: cfg.ListenAddr, Handler: mux} @@ -135,6 +178,10 @@ func main() { log.Printf("server: graceful shutdown failed: %v", err) } <-serveErr + // Wait for the in-flight scheduler RunOnce (interrupted by the + // cancelled ctx passed into checker.Check) to finish before exiting, + // so we never kill the process mid-write of a check/notify status. + wg.Wait() log.Printf("server stopped") case err := <-serveErr: if err != nil { diff --git a/cmd/server/main_test.go b/cmd/server/main_test.go index e056b4e..3424c8b 100644 --- a/cmd/server/main_test.go +++ b/cmd/server/main_test.go @@ -1,6 +1,10 @@ package main -import "testing" +import ( + "net/http" + "net/http/httptest" + "testing" +) func TestIsAPIPath(t *testing.T) { cases := []struct { @@ -20,3 +24,68 @@ func TestIsAPIPath(t *testing.T) { } } } + +func TestBuildMux(t *testing.T) { + var metricsHit, apiHit, webHit bool + + metricsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + metricsHit = true + w.WriteHeader(http.StatusOK) + }) + apiRouter := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + apiHit = true + w.WriteHeader(http.StatusOK) + }) + webHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + webHit = true + w.WriteHeader(http.StatusOK) + }) + + mux := buildMux(metricsHandler, apiRouter, webHandler) + + t.Run("healthz returns 200 ok", func(t *testing.T) { + metricsHit, apiHit, webHit = false, false, false + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/healthz", nil) + mux.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rr.Code, http.StatusOK) + } + if rr.Body.String() != "ok" { + t.Fatalf("body = %q, want %q", rr.Body.String(), "ok") + } + if metricsHit || apiHit || webHit { + t.Fatalf("healthz must not fall through to other handlers") + } + }) + + t.Run("metrics routed to metrics handler", func(t *testing.T) { + metricsHit, apiHit, webHit = false, false, false + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) + mux.ServeHTTP(rr, req) + if !metricsHit { + t.Fatalf("expected metrics handler to be hit") + } + }) + + t.Run("api path routed to api router", func(t *testing.T) { + metricsHit, apiHit, webHit = false, false, false + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/domains", nil) + mux.ServeHTTP(rr, req) + if !apiHit { + t.Fatalf("expected api router to be hit") + } + }) + + t.Run("other path routed to web handler", func(t *testing.T) { + metricsHit, apiHit, webHit = false, false, false + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/domains/xyz", nil) + mux.ServeHTTP(rr, req) + if !webHit { + t.Fatalf("expected web handler to be hit") + } + }) +}