summaryrefslogtreecommitdiff
path: root/vendor/github.com/sourcegraph/conc/pool/result_pool.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/sourcegraph/conc/pool/result_pool.go')
-rw-r--r--vendor/github.com/sourcegraph/conc/pool/result_pool.go142
1 files changed, 142 insertions, 0 deletions
diff --git a/vendor/github.com/sourcegraph/conc/pool/result_pool.go b/vendor/github.com/sourcegraph/conc/pool/result_pool.go
new file mode 100644
index 000000000..f73a77261
--- /dev/null
+++ b/vendor/github.com/sourcegraph/conc/pool/result_pool.go
@@ -0,0 +1,142 @@
+package pool
+
+import (
+ "context"
+ "sort"
+ "sync"
+)
+
+// NewWithResults creates a new ResultPool for tasks with a result of type T.
+//
+// The configuration methods (With*) will panic if they are used after calling
+// Go() for the first time.
+func NewWithResults[T any]() *ResultPool[T] {
+ return &ResultPool[T]{
+ pool: *New(),
+ }
+}
+
+// ResultPool is a pool that executes tasks that return a generic result type.
+// Tasks are executed in the pool with Go(), then the results of the tasks are
+// returned by Wait().
+//
+// The order of the results is guaranteed to be the same as the order the
+// tasks were submitted.
+type ResultPool[T any] struct {
+ pool Pool
+ agg resultAggregator[T]
+}
+
+// Go submits a task to the pool. If all goroutines in the pool
+// are busy, a call to Go() will block until the task can be started.
+func (p *ResultPool[T]) Go(f func() T) {
+ idx := p.agg.nextIndex()
+ p.pool.Go(func() {
+ p.agg.save(idx, f(), false)
+ })
+}
+
+// Wait cleans up all spawned goroutines, propagating any panics, and returning
+// a slice of results from tasks that did not panic.
+func (p *ResultPool[T]) Wait() []T {
+ p.pool.Wait()
+ results := p.agg.collect(true)
+ p.agg = resultAggregator[T]{} // reset for reuse
+ return results
+}
+
+// MaxGoroutines returns the maximum size of the pool.
+func (p *ResultPool[T]) MaxGoroutines() int {
+ return p.pool.MaxGoroutines()
+}
+
+// WithErrors converts the pool to an ResultErrorPool so the submitted tasks
+// can return errors.
+func (p *ResultPool[T]) WithErrors() *ResultErrorPool[T] {
+ p.panicIfInitialized()
+ return &ResultErrorPool[T]{
+ errorPool: *p.pool.WithErrors(),
+ }
+}
+
+// WithContext converts the pool to a ResultContextPool for tasks that should
+// run under the same context, such that they each respect shared cancellation.
+// For example, WithCancelOnError can be configured on the returned pool to
+// signal that all goroutines should be cancelled upon the first error.
+func (p *ResultPool[T]) WithContext(ctx context.Context) *ResultContextPool[T] {
+ p.panicIfInitialized()
+ return &ResultContextPool[T]{
+ contextPool: *p.pool.WithContext(ctx),
+ }
+}
+
+// WithMaxGoroutines limits the number of goroutines in a pool.
+// Defaults to unlimited. Panics if n < 1.
+func (p *ResultPool[T]) WithMaxGoroutines(n int) *ResultPool[T] {
+ p.panicIfInitialized()
+ p.pool.WithMaxGoroutines(n)
+ return p
+}
+
+func (p *ResultPool[T]) panicIfInitialized() {
+ p.pool.panicIfInitialized()
+}
+
+// resultAggregator is a utility type that lets us safely append from multiple
+// goroutines. The zero value is valid and ready to use.
+type resultAggregator[T any] struct {
+ mu sync.Mutex
+ len int
+ results []T
+ errored []int
+}
+
+// nextIndex reserves a slot for a result. The returned value should be passed
+// to save() when adding a result to the aggregator.
+func (r *resultAggregator[T]) nextIndex() int {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ nextIdx := r.len
+ r.len += 1
+ return nextIdx
+}
+
+func (r *resultAggregator[T]) save(i int, res T, errored bool) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if i >= len(r.results) {
+ old := r.results
+ r.results = make([]T, r.len)
+ copy(r.results, old)
+ }
+
+ r.results[i] = res
+
+ if errored {
+ r.errored = append(r.errored, i)
+ }
+}
+
+// collect returns the set of aggregated results.
+func (r *resultAggregator[T]) collect(collectErrored bool) []T {
+ if !r.mu.TryLock() {
+ panic("collect should not be called until all goroutines have exited")
+ }
+
+ if collectErrored || len(r.errored) == 0 {
+ return r.results
+ }
+
+ filtered := r.results[:0]
+ sort.Ints(r.errored)
+ for i, e := range r.errored {
+ if i == 0 {
+ filtered = append(filtered, r.results[:e]...)
+ } else {
+ filtered = append(filtered, r.results[r.errored[i-1]+1:e]...)
+ }
+ }
+ return filtered
+}