summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/db.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/db.go')
-rw-r--r--vendor/github.com/uptrace/bun/db.go170
1 files changed, 162 insertions, 8 deletions
diff --git a/vendor/github.com/uptrace/bun/db.go b/vendor/github.com/uptrace/bun/db.go
index c283f56bd..067996d1c 100644
--- a/vendor/github.com/uptrace/bun/db.go
+++ b/vendor/github.com/uptrace/bun/db.go
@@ -9,6 +9,7 @@ import (
"reflect"
"strings"
"sync/atomic"
+ "time"
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/internal"
@@ -26,32 +27,56 @@ type DBStats struct {
type DBOption func(db *DB)
+func WithOptions(opts ...DBOption) DBOption {
+ return func(db *DB) {
+ for _, opt := range opts {
+ opt(db)
+ }
+ }
+}
+
func WithDiscardUnknownColumns() DBOption {
return func(db *DB) {
db.flags = db.flags.Set(discardUnknownColumns)
}
}
-type DB struct {
- *sql.DB
+func WithConnResolver(resolver ConnResolver) DBOption {
+ return func(db *DB) {
+ db.resolver = resolver
+ }
+}
- dialect schema.Dialect
+type DB struct {
+ // Must be a pointer so we copy the whole state, not individual fields.
+ *noCopyState
queryHooks []QueryHook
fmter schema.Formatter
- flags internal.Flag
-
stats DBStats
}
+// noCopyState contains DB fields that must not be copied on clone(),
+// for example, it is forbidden to copy atomic.Pointer.
+type noCopyState struct {
+ *sql.DB
+ dialect schema.Dialect
+ resolver ConnResolver
+
+ flags internal.Flag
+ closed atomic.Bool
+}
+
func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB {
dialect.Init(sqldb)
db := &DB{
- DB: sqldb,
- dialect: dialect,
- fmter: schema.NewFormatter(dialect),
+ noCopyState: &noCopyState{
+ DB: sqldb,
+ dialect: dialect,
+ },
+ fmter: schema.NewFormatter(dialect),
}
for _, opt := range opts {
@@ -69,6 +94,22 @@ func (db *DB) String() string {
return b.String()
}
+func (db *DB) Close() error {
+ if db.closed.Swap(true) {
+ return nil
+ }
+
+ firstErr := db.DB.Close()
+
+ if db.resolver != nil {
+ if err := db.resolver.Close(); err != nil && firstErr == nil {
+ firstErr = err
+ }
+ }
+
+ return firstErr
+}
+
func (db *DB) DBStats() DBStats {
return DBStats{
Queries: atomic.LoadUint32(&db.stats.Queries),
@@ -703,3 +744,116 @@ func (tx Tx) NewDropColumn() *DropColumnQuery {
func (db *DB) makeQueryBytes() []byte {
return internal.MakeQueryBytes()
}
+
+//------------------------------------------------------------------------------
+
+// ConnResolver enables routing queries to multiple databases.
+type ConnResolver interface {
+ ResolveConn(query Query) IConn
+ Close() error
+}
+
+// TODO:
+// - make monitoring interval configurable
+// - make ping timeout configutable
+// - allow adding read/write replicas for multi-master replication
+type ReadWriteConnResolver struct {
+ replicas []*sql.DB // read-only replicas
+ healthyReplicas atomic.Pointer[[]*sql.DB]
+ nextReplica atomic.Int64
+ closed atomic.Bool
+}
+
+func NewReadWriteConnResolver(opts ...ReadWriteConnResolverOption) *ReadWriteConnResolver {
+ r := new(ReadWriteConnResolver)
+
+ for _, opt := range opts {
+ opt(r)
+ }
+
+ if len(r.replicas) > 0 {
+ r.healthyReplicas.Store(&r.replicas)
+ go r.monitor()
+ }
+
+ return r
+}
+
+type ReadWriteConnResolverOption func(r *ReadWriteConnResolver)
+
+func WithReadOnlyReplica(dbs ...*sql.DB) ReadWriteConnResolverOption {
+ return func(r *ReadWriteConnResolver) {
+ r.replicas = append(r.replicas, dbs...)
+ }
+}
+
+func (r *ReadWriteConnResolver) Close() error {
+ if r.closed.Swap(true) {
+ return nil
+ }
+
+ var firstErr error
+ for _, db := range r.replicas {
+ if err := db.Close(); err != nil && firstErr == nil {
+ firstErr = err
+ }
+ }
+ return firstErr
+}
+
+// healthyReplica returns a random healthy replica.
+func (r *ReadWriteConnResolver) ResolveConn(query Query) IConn {
+ if len(r.replicas) == 0 || !isReadOnlyQuery(query) {
+ return nil
+ }
+
+ replicas := r.loadHealthyReplicas()
+ if len(replicas) == 0 {
+ return nil
+ }
+ if len(replicas) == 1 {
+ return replicas[0]
+ }
+ i := r.nextReplica.Add(1)
+ return replicas[int(i)%len(replicas)]
+}
+
+func isReadOnlyQuery(query Query) bool {
+ sel, ok := query.(*SelectQuery)
+ if !ok {
+ return false
+ }
+ for _, el := range sel.with {
+ if !isReadOnlyQuery(el.query) {
+ return false
+ }
+ }
+ return true
+}
+
+func (r *ReadWriteConnResolver) loadHealthyReplicas() []*sql.DB {
+ if ptr := r.healthyReplicas.Load(); ptr != nil {
+ return *ptr
+ }
+ return nil
+}
+
+func (r *ReadWriteConnResolver) monitor() {
+ const interval = 5 * time.Second
+ for !r.closed.Load() {
+ healthy := make([]*sql.DB, 0, len(r.replicas))
+
+ for _, replica := range r.replicas {
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ err := replica.PingContext(ctx)
+ cancel()
+
+ if err == nil {
+ healthy = append(healthy, replica)
+ }
+ }
+
+ r.healthyReplicas.Store(&healthy)
+ time.Sleep(interval)
+ }
+}