diff --git a/internal/imapx/dial.go b/internal/imapx/dial.go index a23f114..8de1b20 100644 --- a/internal/imapx/dial.go +++ b/internal/imapx/dial.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "fmt" + "time" "github.com/emersion/go-imap/v2/imapclient" ) @@ -16,29 +17,45 @@ type Endpoint struct { func (e Endpoint) addr() string { return fmt.Sprintf("%s:%d", e.Host, e.Port) } -func Connect(ctx context.Context, ep Endpoint) (*imapclient.Client, error) { - var ( - c *imapclient.Client - err error - ) +func dialOnce(ep Endpoint) (*imapclient.Client, error) { switch ep.TLSMode { case "ssl": - c, err = imapclient.DialTLS(ep.addr(), &imapclient.Options{ + return imapclient.DialTLS(ep.addr(), &imapclient.Options{ TLSConfig: &tls.Config{ServerName: ep.Host}, }) case "starttls": - c, err = imapclient.DialStartTLS(ep.addr(), &imapclient.Options{ + return imapclient.DialStartTLS(ep.addr(), &imapclient.Options{ TLSConfig: &tls.Config{ServerName: ep.Host}, }) case "plain": - c, err = imapclient.DialInsecure(ep.addr(), nil) + return imapclient.DialInsecure(ep.addr(), nil) default: return nil, fmt.Errorf("unknown tls_mode %q", ep.TLSMode) } - if err != nil { - return nil, err +} + +func Connect(ctx context.Context, ep Endpoint) (*imapclient.Client, error) { + const attempts = 3 + var lastErr error + for i := 0; i < attempts; i++ { + if err := ctx.Err(); err != nil { + return nil, err + } + c, err := dialOnce(ep) + if err == nil { + return c, nil + } + lastErr = err + if i < attempts-1 { + backoff := time.Duration(200*(i+1)) * time.Millisecond + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(backoff): + } + } } - return c, nil + return nil, lastErr } func TestEndpoint(ctx context.Context, ep Endpoint) error { diff --git a/internal/imapx/dial_test.go b/internal/imapx/dial_test.go index 6f621aa..f51a006 100644 --- a/internal/imapx/dial_test.go +++ b/internal/imapx/dial_test.go @@ -40,3 +40,12 @@ func TestTestLoginListsFolders(t *testing.T) { t.Fatalf("INBOX not in folders: %v", folders) } } + +func TestConnectHonorsCancelledContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := Connect(ctx, Endpoint{Host: "10.255.255.1", Port: 993, TLSMode: "ssl"}) + if err == nil { + t.Fatal("expected error for cancelled context") + } +}