diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/db.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/db.go | 170 |
1 files changed, 162 insertions, 8 deletions
diff --git a/vendor/github.com/uptrace/bun/db.go b/vendor/github.com/uptrace/bun/db.go index c283f56bd..067996d1c 100644 --- a/vendor/github.com/uptrace/bun/db.go +++ b/vendor/github.com/uptrace/bun/db.go @@ -9,6 +9,7 @@ import ( "reflect" "strings" "sync/atomic" + "time" "github.com/uptrace/bun/dialect/feature" "github.com/uptrace/bun/internal" @@ -26,32 +27,56 @@ type DBStats struct { type DBOption func(db *DB) +func WithOptions(opts ...DBOption) DBOption { + return func(db *DB) { + for _, opt := range opts { + opt(db) + } + } +} + func WithDiscardUnknownColumns() DBOption { return func(db *DB) { db.flags = db.flags.Set(discardUnknownColumns) } } -type DB struct { - *sql.DB +func WithConnResolver(resolver ConnResolver) DBOption { + return func(db *DB) { + db.resolver = resolver + } +} - dialect schema.Dialect +type DB struct { + // Must be a pointer so we copy the whole state, not individual fields. + *noCopyState queryHooks []QueryHook fmter schema.Formatter - flags internal.Flag - stats DBStats } +// noCopyState contains DB fields that must not be copied on clone(), +// for example, it is forbidden to copy atomic.Pointer. +type noCopyState struct { + *sql.DB + dialect schema.Dialect + resolver ConnResolver + + flags internal.Flag + closed atomic.Bool +} + func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB { dialect.Init(sqldb) db := &DB{ - DB: sqldb, - dialect: dialect, - fmter: schema.NewFormatter(dialect), + noCopyState: &noCopyState{ + DB: sqldb, + dialect: dialect, + }, + fmter: schema.NewFormatter(dialect), } for _, opt := range opts { @@ -69,6 +94,22 @@ func (db *DB) String() string { return b.String() } +func (db *DB) Close() error { + if db.closed.Swap(true) { + return nil + } + + firstErr := db.DB.Close() + + if db.resolver != nil { + if err := db.resolver.Close(); err != nil && firstErr == nil { + firstErr = err + } + } + + return firstErr +} + func (db *DB) DBStats() DBStats { return DBStats{ Queries: atomic.LoadUint32(&db.stats.Queries), @@ -703,3 +744,116 @@ func (tx Tx) NewDropColumn() *DropColumnQuery { func (db *DB) makeQueryBytes() []byte { return internal.MakeQueryBytes() } + +//------------------------------------------------------------------------------ + +// ConnResolver enables routing queries to multiple databases. +type ConnResolver interface { + ResolveConn(query Query) IConn + Close() error +} + +// TODO: +// - make monitoring interval configurable +// - make ping timeout configutable +// - allow adding read/write replicas for multi-master replication +type ReadWriteConnResolver struct { + replicas []*sql.DB // read-only replicas + healthyReplicas atomic.Pointer[[]*sql.DB] + nextReplica atomic.Int64 + closed atomic.Bool +} + +func NewReadWriteConnResolver(opts ...ReadWriteConnResolverOption) *ReadWriteConnResolver { + r := new(ReadWriteConnResolver) + + for _, opt := range opts { + opt(r) + } + + if len(r.replicas) > 0 { + r.healthyReplicas.Store(&r.replicas) + go r.monitor() + } + + return r +} + +type ReadWriteConnResolverOption func(r *ReadWriteConnResolver) + +func WithReadOnlyReplica(dbs ...*sql.DB) ReadWriteConnResolverOption { + return func(r *ReadWriteConnResolver) { + r.replicas = append(r.replicas, dbs...) + } +} + +func (r *ReadWriteConnResolver) Close() error { + if r.closed.Swap(true) { + return nil + } + + var firstErr error + for _, db := range r.replicas { + if err := db.Close(); err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr +} + +// healthyReplica returns a random healthy replica. +func (r *ReadWriteConnResolver) ResolveConn(query Query) IConn { + if len(r.replicas) == 0 || !isReadOnlyQuery(query) { + return nil + } + + replicas := r.loadHealthyReplicas() + if len(replicas) == 0 { + return nil + } + if len(replicas) == 1 { + return replicas[0] + } + i := r.nextReplica.Add(1) + return replicas[int(i)%len(replicas)] +} + +func isReadOnlyQuery(query Query) bool { + sel, ok := query.(*SelectQuery) + if !ok { + return false + } + for _, el := range sel.with { + if !isReadOnlyQuery(el.query) { + return false + } + } + return true +} + +func (r *ReadWriteConnResolver) loadHealthyReplicas() []*sql.DB { + if ptr := r.healthyReplicas.Load(); ptr != nil { + return *ptr + } + return nil +} + +func (r *ReadWriteConnResolver) monitor() { + const interval = 5 * time.Second + for !r.closed.Load() { + healthy := make([]*sql.DB, 0, len(r.replicas)) + + for _, replica := range r.replicas { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + err := replica.PingContext(ctx) + cancel() + + if err == nil { + healthy = append(healthy, replica) + } + } + + r.healthyReplicas.Store(&healthy) + time.Sleep(interval) + } +} |