summaryrefslogtreecommitdiff
path: root/vendor/github.com/jackc/pgconn/internal/ctxwatch/context_watcher.go
blob: b39cb3ee57a447a8fe6c07d1c93525e202a7dc6a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
package ctxwatch

import (
	"context"
	"sync"
)

// ContextWatcher observes one context at a time and runs a callback if that context is canceled while it is being
// watched. Create instances with NewContextWatcher so that unwatchChan is initialized.
type ContextWatcher struct {
	// Callbacks supplied to NewContextWatcher. onCancel runs on the watcher goroutine when the watched context is
	// canceled; onUnwatchAfterCancel runs inside Unwatch when onCancel was called for the watch being ended.
	onCancel, onUnwatchAfterCancel func()

	// unwatchChan carries the stop signal from Unwatch to the watcher goroutine.
	unwatchChan chan struct{}

	// lock is held for the whole of Watch and Unwatch.
	lock sync.Mutex

	// watchInProgress is true while a watcher goroutine started by Watch has not yet been released by Unwatch.
	// onCancelWasCalled is reset by Watch and set by the watcher goroutine after onCancel returns.
	watchInProgress, onCancelWasCalled bool
}

// NewContextWatcher builds a ContextWatcher that is ready for use.
//
// onCancel is invoked when a context being watched is canceled. onUnwatchAfterCancel is invoked by Unwatch when
// onCancel was already called for the watch that Unwatch is ending. The unwatch channel is created unbuffered, so
// the send in Unwatch completes only once the watcher goroutine receives it.
func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher {
	return &ContextWatcher{
		onCancel:             onCancel,
		onUnwatchAfterCancel: onUnwatchAfterCancel,
		unwatchChan:          make(chan struct{}),
	}
}

// Watch begins watching ctx. If ctx is canceled before Unwatch is called, the onCancel function given to
// NewContextWatcher is invoked from a background goroutine.
//
// Watch panics if a previous watch has not been ended with Unwatch. When ctx.Done() returns nil no goroutine is
// started and a later Unwatch does nothing.
func (cw *ContextWatcher) Watch(ctx context.Context) {
	cw.lock.Lock()
	defer cw.lock.Unlock()

	if cw.watchInProgress {
		panic("Watch already in progress")
	}

	// Every watch starts with a clean record of whether onCancel ran.
	cw.onCancelWasCalled = false

	// A nil Done channel means there is nothing to wait for.
	if ctx.Done() == nil {
		cw.watchInProgress = false
		return
	}

	cw.watchInProgress = true
	go func() {
		select {
		case <-cw.unwatchChan:
			// Unwatch arrived first; onCancel is never called for this watch.
			return
		case <-ctx.Done():
		}

		cw.onCancel()
		cw.onCancelWasCalled = true

		// Stay alive until Unwatch sends its signal; the flag is written before this receive so that Unwatch
		// reads it after its send completes.
		<-cw.unwatchChan
	}()
}

// Unwatch ends the watch started by the most recent Watch call. If onCancel was called during that watch,
// onUnwatchAfterCancel is called before Unwatch returns. Unwatch does nothing when no watch is in progress.
//
// The signal is sent on an unbuffered channel, so Unwatch blocks until the watcher goroutine receives it; if
// onCancel is running at that moment, Unwatch waits for it to return.
func (cw *ContextWatcher) Unwatch() {
	cw.lock.Lock()
	defer cw.lock.Unlock()

	if !cw.watchInProgress {
		return
	}

	// Release the watcher goroutine. Once this send completes the goroutine no longer touches onCancelWasCalled.
	cw.unwatchChan <- struct{}{}

	if cw.onCancelWasCalled {
		cw.onUnwatchAfterCancel()
	}

	cw.watchInProgress = false
}