From 521ab84ce6e4c35ca3c62e8b396636adc50a2469 Mon Sep 17 00:00:00 2001 From: Justin Ruggles Date: Thu, 26 Oct 2023 09:33:37 -0400 Subject: [PATCH] Add context methods to Clock interface This allows for using timeout/deadline functionality built in to context.Context with a custom clock implementation. --- clock.go | 29 +++++++++++++++++++++ fake/fake_clock.go | 57 ++++++++++++++++++++++++++++++++++++++++++ offset/offset_clock.go | 16 ++++++++++++ 3 files changed, 102 insertions(+) diff --git a/clock.go b/clock.go index 2ff2c2a..4c2f6d7 100644 --- a/clock.go +++ b/clock.go @@ -54,6 +54,22 @@ func (c defaultClock) AfterFunc(d time.Duration, f func()) StopTimer { return time.AfterFunc(d, f) } +func (c defaultClock) ContextWithDeadline(ctx context.Context, t time.Time) (context.Context, context.CancelFunc) { + return context.WithDeadline(ctx, t) +} + +func (c defaultClock) ContextWithDeadlineCause(ctx context.Context, t time.Time, cause error) (context.Context, context.CancelFunc) { + return context.WithDeadlineCause(ctx, t, cause) +} + +func (c defaultClock) ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { + return context.WithTimeout(ctx, d) +} + +func (c defaultClock) ContextWithTimeoutCause(ctx context.Context, d time.Duration, cause error) (context.Context, context.CancelFunc) { + return context.WithTimeoutCause(ctx, d, cause) +} + // DefaultClock returns a clock that minimally wraps the `time` package func DefaultClock() Clock { return defaultClock{} @@ -80,4 +96,17 @@ type Clock interface { // The callback function f will be executed after the interval d has // elapsed, unless the returned timer's Stop() method is called first. AfterFunc(d time.Duration, f func()) StopTimer + + // ContextWithDeadline behaves like context.WithDeadline, but it uses the + // clock to determine the when the deadline has expired. + ContextWithDeadline(ctx context.Context, t time.Time) (context.Context, context.CancelFunc) + // ContextWithDeadlineCause behaves like context.WithDeadlineCause, but it + // uses the clock to determine the when the deadline has expired. + ContextWithDeadlineCause(ctx context.Context, t time.Time, cause error) (context.Context, context.CancelFunc) + // ContextWithTimeout behaves like context.WithTimeout, but it uses the + // clock to determine the when the timeout has elapsed. + ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) + // ContextWithTimeoutCause behaves like context.WithTimeoutCause, but it + // uses the clock to determine the when the timeout has elapsed. + ContextWithTimeoutCause(ctx context.Context, d time.Duration, cause error) (context.Context, context.CancelFunc) } diff --git a/fake/fake_clock.go b/fake/fake_clock.go index 2b25c6d..181b6bd 100644 --- a/fake/fake_clock.go +++ b/fake/fake_clock.go @@ -3,6 +3,7 @@ package fake import ( "context" "sync" + "sync/atomic" "time" clocks "github.com/vimeo/go-clocks" @@ -410,3 +411,59 @@ func (f *Clock) AwaitTimerAborts(n int) { func (f *Clock) WaitAfterFuncs() { f.cbsWG.Wait() } + +type deadlineContext struct { + context.Context + timedOut atomic.Bool + deadline time.Time +} + +func (d *deadlineContext) Deadline() (time.Time, bool) { + return d.deadline, true +} + +func (d *deadlineContext) Err() error { + if d.timedOut.Load() { + return context.DeadlineExceeded + } + return d.Context.Err() +} + +func (f *Clock) ContextWithDeadlineCause(ctx context.Context, t time.Time, cause error) (context.Context, context.CancelFunc) { + cctx, cancelCause := context.WithCancelCause(ctx) + dctx := &deadlineContext{ + Context: cctx, + deadline: t, + } + dur := f.Until(t) + if dur <= 0 { + dctx.timedOut.CompareAndSwap(false, true) + cancelCause(cause) + return dctx, func() { + cancelCause(context.Canceled) + } + } + stop := f.AfterFunc(dur, func() { + if cctx.Err() == nil { + dctx.timedOut.CompareAndSwap(false, true) + } + cancelCause(cause) + }) + cancel := func() { + cancelCause(context.Canceled) + stop.Stop() + } + return dctx, cancel +} + +func (c *Clock) ContextWithDeadline(ctx context.Context, t time.Time) (context.Context, context.CancelFunc) { + return c.ContextWithDeadlineCause(ctx, t, nil) +} + +func (c *Clock) ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { + return c.ContextWithDeadlineCause(ctx, c.Now().Add(d), nil) +} + +func (c *Clock) ContextWithTimeoutCause(ctx context.Context, d time.Duration, cause error) (context.Context, context.CancelFunc) { + return c.ContextWithDeadlineCause(ctx, c.Now().Add(d), cause) +} diff --git a/offset/offset_clock.go b/offset/offset_clock.go index a83375a..7b69ebf 100644 --- a/offset/offset_clock.go +++ b/offset/offset_clock.go @@ -49,6 +49,22 @@ func (o *Clock) AfterFunc(d time.Duration, f func()) clocks.StopTimer { return o.inner.AfterFunc(d, f) } +func (o *Clock) ContextWithDeadline(ctx context.Context, t time.Time) (context.Context, context.CancelFunc) { + return o.inner.ContextWithDeadline(ctx, t.Add(o.offset)) +} + +func (o *Clock) ContextWithDeadlineCause(ctx context.Context, t time.Time, cause error) (context.Context, context.CancelFunc) { + return o.inner.ContextWithDeadlineCause(ctx, t.Add(o.offset), cause) +} + +func (o *Clock) ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { + return o.inner.ContextWithTimeout(ctx, d+o.offset) +} + +func (o *Clock) ContextWithTimeoutCause(ctx context.Context, d time.Duration, cause error) (context.Context, context.CancelFunc) { + return o.inner.ContextWithTimeoutCause(ctx, d+o.offset, cause) +} + // NewOffsetClock creates an OffsetClock. offset is added to all absolute times. func NewOffsetClock(inner clocks.Clock, offset time.Duration) *Clock { return &Clock{