path: root/vendor/codeberg.org/gruf/go-mutexes/map.go
package mutexes

import (
	"sync"
	"sync/atomic"
	"unsafe"
)

const (
	// possible lock types.
	lockTypeRead  = uint8(1) << 0
	lockTypeWrite = uint8(1) << 1

	// frequency of GC cycles
	// per no. unlocks. i.e.
	// every 'gcfreq' unlocks.
	gcfreq = 1024
)

// MutexMap is a structure that allows read / write locking
// per key, performing as you'd expect a map[string]*RWMutex
// to perform, without you needing to worry about deadlocks
// between competing read / write locks and the map's own mutex.
// It uses memory pooling for the internal "mutex" (ish) types
// and performs self-eviction of keys.
//
// Under the hood this is achieved using a single mutex for the
// map, state tracking for individual keys, and some sync.Cond{}
// like structures for sleeping / waking waiting goroutines.
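//
// An illustrative usage sketch (the keys and surrounding code here
// are assumptions for the example, not part of this package):
//
//	var mm mutexes.MutexMap
//
//	// Exclusive (write) lock on a key, released via the returned func.
//	unlock := mm.Lock("user:1")
//	defer unlock()
//
//	// Shared (read) locks on another key may be held concurrently.
//	runlock := mm.RLock("user:2")
//	defer runlock()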
type MutexMap struct {
	mapmu  sync.Mutex
	mumap  map[string]*rwmutex
	mupool rwmutexPool
	count  uint32
}

// checkInit ensures MutexMap is initialized (UNSAFE).
func (mm *MutexMap) checkInit() {
	if mm.mumap == nil {
		mm.mumap = make(map[string]*rwmutex)
	}
}

// Lock acquires a write lock on key in map, returning unlock function.
func (mm *MutexMap) Lock(key string) func() {
	return mm.lock(key, lockTypeWrite)
}

// RLock acquires a read lock on key in map, returning runlock function.
func (mm *MutexMap) RLock(key string) func() {
	return mm.lock(key, lockTypeRead)
}

func (mm *MutexMap) lock(key string, lt uint8) func() {
	// Perform first map lock
	// and check initialization
	// OUTSIDE the main loop.
	mm.mapmu.Lock()
	mm.checkInit()

	for {
		// Check map for mu.
		mu := mm.mumap[key]

		if mu == nil {
			// Allocate new mutex.
			mu = mm.mupool.Acquire()
			mm.mumap[key] = mu
		}

		if !mu.Lock(lt) {
			// Wait for mutex unlock, then
			// immediately relock map mu.
			mu.WaitRelock(&mm.mapmu)
			continue
		}

		// Done with map.
		mm.mapmu.Unlock()

		// Return mutex unlock function.
		return func() { mm.unlock(key, mu) }
	}
}

func (mm *MutexMap) unlock(key string, mu *rwmutex) {
	// Get map lock.
	mm.mapmu.Lock()

	// Unlock mutex.
	if mu.Unlock() {

		// Mutex fully unlocked
		// with zero waiters. Self
		// evict and release it.
		delete(mm.mumap, key)
		mm.mupool.Release(mu)
	}

	if mm.count++; mm.count%gcfreq == 0 {
		// Every 'gcfreq' unlocks perform
		// a garbage collection to keep
		// us squeaky clean :]
		mm.mupool.GC()
	}

	// Done with map.
	mm.mapmu.Unlock()
}

// rwmutexPool is a very simple memory pool of rwmutex objects.
type rwmutexPool struct {
	current []*rwmutex
	victim  []*rwmutex
}

// Acquire will return a rwmutex from the pool (or alloc new).
func (p *rwmutexPool) Acquire() *rwmutex {
	// First try the current queue
	if l := len(p.current) - 1; l >= 0 {
		mu := p.current[l]
		p.current = p.current[:l]
		return mu
	}

	// Next try the victim queue.
	if l := len(p.victim) - 1; l >= 0 {
		mu := p.victim[l]
		p.victim = p.victim[:l]
		return mu
	}

	// Lastly, alloc new.
	mu := new(rwmutex)
	return mu
}

// Release places a rwmutex back in the pool.
func (p *rwmutexPool) Release(mu *rwmutex) {
	p.current = append(p.current, mu)
}

// GC will clear out unused entries from the rwmutexPool.
// The current queue is demoted to the victim queue, and
// entries still unused after the next GC cycle are dropped.
func (p *rwmutexPool) GC() {
	current := p.current
	p.current = nil
	p.victim = current
}
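
// For illustration, the pool's two-generation behaviour (a sketch,
// assuming no intervening Acquire calls):
//
//	p.Release(mu) // mu sits in p.current
//	p.GC()        // mu moves to p.victim
//	p.GC()        // mu is dropped, left to the runtime GC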

// rwmutex represents a RW mutex when used correctly within
// a MutexMap. It should ONLY be accessed when protected by
// the outer map lock, except for the 'notifyList' which is
// a runtime internal structure borrowed from sync.Cond{}.
//
// This functions very similarly to a sync.Cond{}, but with
// lock state tracking, and returning on 'Broadcast()' whether
// any goroutines were actually awoken. It also has a less
// confusing API than sync.Cond{} with the outer locking
// mechanism we use, as otherwise every Cond{}.L would
// reference the same outer map mutex.
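//
// An illustrative sequence (a sketch of how MutexMap drives this
// internal type, not an exported API):
//
//	mu.Lock(lockTypeRead)  // true: first reader takes the lock
//	mu.Lock(lockTypeRead)  // true: readers may share
//	mu.Lock(lockTypeWrite) // false: caller must WaitRelock and retry
//	mu.Unlock()            // false: one read lock still held
//	mu.Unlock()            // fully unlocked; true only if no waiters were awoken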
type rwmutex struct {
	n notifyList // 'trigger' mechanism
	l int32      // no. locks
	t uint8      // lock type
}

// Lock will lock the mutex for the given lock type, in
// the sense that it will update the internal state tracker
// accordingly. Return value is true on successful lock.
func (mu *rwmutex) Lock(lt uint8) bool {
	switch mu.t {
	case lockTypeRead:
		// already read locked,
		// only permit more reads.
		if lt != lockTypeRead {
			return false
		}

	case lockTypeWrite:
		// already write locked,
		// no other locks allowed.
		return false

	default:
		// Fully unlocked,
		// set incoming type.
		mu.t = lt
	}

	// Update
	// count.
	mu.l++

	return true
}

// Unlock will unlock the mutex, in the sense that it
// will update the internal state tracker accordingly.
// On totally unlocked state, it will awaken all
// sleeping goroutines waiting on this mutex. Returns
// true only when fully unlocked with zero waiters.
func (mu *rwmutex) Unlock() bool {
	switch mu.l--; {
	case mu.l > 0 && mu.t == lockTypeWrite:
		panic("BUG: multiple writer locks")
	case mu.l < 0:
		panic("BUG: negative lock count")

	case mu.l == 0:
		// Fully unlocked.
		mu.t = 0

		// Awake all blocked goroutines and check
		// for change in the last notified ticket.
		before := atomic.LoadUint32(&mu.n.notify)
		runtime_notifyListNotifyAll(&mu.n)
		after := atomic.LoadUint32(&mu.n.notify)

		// If ticket changed, this indicates
		// AT LEAST one goroutine was awoken.
		//
		// (before != after) => (waiters > 0)
		// (before == after) => (waiters = 0)
		return (before == after)

	default:
		// i.e. mutex still
		// locked by others.
		return false
	}
}

// WaitRelock expects a mutex to be passed in, already in the
// locked state. It increments the notifyList waiter count before
// unlocking the outer mutex and blocking on the notifyList wait.
// On wake it will decrement the wait count and relock the outer mutex.
func (mu *rwmutex) WaitRelock(outer *sync.Mutex) {

	// Add ourselves to the list while still
	// under protection of outer map lock.
	t := runtime_notifyListAdd(&mu.n)

	// Finished with
	// outer map lock.
	outer.Unlock()

	// Block until awoken by another
	// goroutine within mu.Unlock().
	runtime_notifyListWait(&mu.n, t)

	// Relock!
	outer.Lock()
}

// notifyList mirrors the runtime's notifyList;
// unused fields are left un-named for safety.
type notifyList struct {
	_      uint32         // wait   uint32
	notify uint32         // notify uint32
	_      uintptr        // lock   mutex
	_      unsafe.Pointer // head   *sudog
	_      unsafe.Pointer // tail   *sudog
}

// See runtime/sema.go for documentation.
//
//go:linkname runtime_notifyListAdd sync.runtime_notifyListAdd
func runtime_notifyListAdd(l *notifyList) uint32

// See runtime/sema.go for documentation.
//
//go:linkname runtime_notifyListWait sync.runtime_notifyListWait
func runtime_notifyListWait(l *notifyList, t uint32)

// See runtime/sema.go for documentation.
//
//go:linkname runtime_notifyListNotifyAll sync.runtime_notifyListNotifyAll
func runtime_notifyListNotifyAll(l *notifyList)

// Ensure that sync and runtime agree on size of notifyList.
//
//go:linkname runtime_notifyListCheck sync.runtime_notifyListCheck
func runtime_notifyListCheck(size uintptr)
func init() {
	var n notifyList
	runtime_notifyListCheck(unsafe.Sizeof(n))
}