summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/model_table_struct.go
diff options
context:
space:
mode:
authorLibravatar tobi <31960611+tsmethurst@users.noreply.github.com>2021-08-25 15:34:33 +0200
committerLibravatar GitHub <noreply@github.com>2021-08-25 15:34:33 +0200
commit2dc9fc1626507bb54417fc4a1920b847cafb27a2 (patch)
tree4ddeac479b923db38090aac8bd9209f3646851c1 /vendor/github.com/uptrace/bun/model_table_struct.go
parentManually approves followers (#146) (diff)
downloadgotosocial-2dc9fc1626507bb54417fc4a1920b847cafb27a2.tar.xz
Pg to bun (#148)
* start moving to bun * changing more stuff * more * and yet more * tests passing * seems stable now * more big changes * small fix * little fixes
Diffstat (limited to 'vendor/github.com/uptrace/bun/model_table_struct.go')
-rw-r--r--vendor/github.com/uptrace/bun/model_table_struct.go345
1 files changed, 345 insertions, 0 deletions
diff --git a/vendor/github.com/uptrace/bun/model_table_struct.go b/vendor/github.com/uptrace/bun/model_table_struct.go
new file mode 100644
index 000000000..3bb0c14dd
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/model_table_struct.go
@@ -0,0 +1,345 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "reflect"
+ "strings"
+
+ "github.com/uptrace/bun/dialect"
+ "github.com/uptrace/bun/schema"
+)
+
+type structTableModel struct {
+ db *DB
+ table *schema.Table
+
+ rel *schema.Relation
+ joins []join
+
+ dest interface{}
+ root reflect.Value
+ index []int
+
+ strct reflect.Value
+ structInited bool
+ structInitErr error
+
+ columns []string
+ scanIndex int
+}
+
+var _ tableModel = (*structTableModel)(nil)
+
+func newStructTableModel(db *DB, dest interface{}, table *schema.Table) *structTableModel {
+ return &structTableModel{
+ db: db,
+ table: table,
+ dest: dest,
+ }
+}
+
+func newStructTableModelValue(db *DB, dest interface{}, v reflect.Value) *structTableModel {
+ return &structTableModel{
+ db: db,
+ table: db.Table(v.Type()),
+ dest: dest,
+ root: v,
+ strct: v,
+ }
+}
+
+func (m *structTableModel) Value() interface{} {
+ return m.dest
+}
+
+func (m *structTableModel) Table() *schema.Table {
+ return m.table
+}
+
+func (m *structTableModel) Relation() *schema.Relation {
+ return m.rel
+}
+
+func (m *structTableModel) Root() reflect.Value {
+ return m.root
+}
+
+func (m *structTableModel) Index() []int {
+ return m.index
+}
+
+func (m *structTableModel) ParentIndex() []int {
+ return m.index[:len(m.index)-len(m.rel.Field.Index)]
+}
+
+func (m *structTableModel) Mount(host reflect.Value) {
+ m.strct = host.FieldByIndex(m.rel.Field.Index)
+ m.structInited = false
+}
+
+func (m *structTableModel) initStruct() error {
+ if m.structInited {
+ return m.structInitErr
+ }
+ m.structInited = true
+
+ switch m.strct.Kind() {
+ case reflect.Invalid:
+ m.structInitErr = errNilModel
+ return m.structInitErr
+ case reflect.Interface:
+ m.strct = m.strct.Elem()
+ }
+
+ if m.strct.Kind() == reflect.Ptr {
+ if m.strct.IsNil() {
+ m.strct.Set(reflect.New(m.strct.Type().Elem()))
+ m.strct = m.strct.Elem()
+ } else {
+ m.strct = m.strct.Elem()
+ }
+ }
+
+ m.mountJoins()
+
+ return nil
+}
+
+func (m *structTableModel) mountJoins() {
+ for i := range m.joins {
+ j := &m.joins[i]
+ switch j.Relation.Type {
+ case schema.HasOneRelation, schema.BelongsToRelation:
+ j.JoinModel.Mount(m.strct)
+ }
+ }
+}
+
+var _ schema.BeforeScanHook = (*structTableModel)(nil)
+
+func (m *structTableModel) BeforeScan(ctx context.Context) error {
+ if !m.table.HasBeforeScanHook() {
+ return nil
+ }
+ return callBeforeScanHook(ctx, m.strct.Addr())
+}
+
+var _ schema.AfterScanHook = (*structTableModel)(nil)
+
+func (m *structTableModel) AfterScan(ctx context.Context) error {
+ if !m.table.HasAfterScanHook() || !m.structInited {
+ return nil
+ }
+
+ var firstErr error
+
+ if err := callAfterScanHook(ctx, m.strct.Addr()); err != nil && firstErr == nil {
+ firstErr = err
+ }
+
+ for _, j := range m.joins {
+ switch j.Relation.Type {
+ case schema.HasOneRelation, schema.BelongsToRelation:
+ if err := j.JoinModel.AfterScan(ctx); err != nil && firstErr == nil {
+ firstErr = err
+ }
+ }
+ }
+
+ return firstErr
+}
+
+func (m *structTableModel) GetJoin(name string) *join {
+ for i := range m.joins {
+ j := &m.joins[i]
+ if j.Relation.Field.Name == name || j.Relation.Field.GoName == name {
+ return j
+ }
+ }
+ return nil
+}
+
+func (m *structTableModel) GetJoins() []join {
+ return m.joins
+}
+
+func (m *structTableModel) AddJoin(j join) *join {
+ m.joins = append(m.joins, j)
+ return &m.joins[len(m.joins)-1]
+}
+
+func (m *structTableModel) Join(name string, apply func(*SelectQuery) *SelectQuery) *join {
+ return m.join(m.strct, name, apply)
+}
+
+func (m *structTableModel) join(
+ bind reflect.Value, name string, apply func(*SelectQuery) *SelectQuery,
+) *join {
+ path := strings.Split(name, ".")
+ index := make([]int, 0, len(path))
+
+ currJoin := join{
+ BaseModel: m,
+ JoinModel: m,
+ }
+ var lastJoin *join
+
+ for _, name := range path {
+ relation, ok := currJoin.JoinModel.Table().Relations[name]
+ if !ok {
+ return nil
+ }
+
+ currJoin.Relation = relation
+ index = append(index, relation.Field.Index...)
+
+ if j := currJoin.JoinModel.GetJoin(name); j != nil {
+ currJoin.BaseModel = j.BaseModel
+ currJoin.JoinModel = j.JoinModel
+
+ lastJoin = j
+ } else {
+ model, err := newTableModelIndex(m.db, m.table, bind, index, relation)
+ if err != nil {
+ return nil
+ }
+
+ currJoin.Parent = lastJoin
+ currJoin.BaseModel = currJoin.JoinModel
+ currJoin.JoinModel = model
+
+ lastJoin = currJoin.BaseModel.AddJoin(currJoin)
+ }
+ }
+
+ // No joins with such name.
+ if lastJoin == nil {
+ return nil
+ }
+ if apply != nil {
+ lastJoin.ApplyQueryFunc = apply
+ }
+
+ return lastJoin
+}
+
+func (m *structTableModel) updateSoftDeleteField() error {
+ fv := m.table.SoftDeleteField.Value(m.strct)
+ return m.table.UpdateSoftDeleteField(fv)
+}
+
+func (m *structTableModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) {
+ if !rows.Next() {
+ return 0, rows.Err()
+ }
+
+ if err := m.ScanRow(ctx, rows); err != nil {
+ return 0, err
+ }
+
+ // For inserts, SQLite3 can return a row like it was inserted sucessfully and then
+ // an actual error for the next row. See issues/100.
+ if m.db.dialect.Name() == dialect.SQLite {
+ _ = rows.Next()
+ if err := rows.Err(); err != nil {
+ return 0, err
+ }
+ }
+
+ return 1, nil
+}
+
+func (m *structTableModel) ScanRow(ctx context.Context, rows *sql.Rows) error {
+ columns, err := rows.Columns()
+ if err != nil {
+ return err
+ }
+
+ m.columns = columns
+ dest := makeDest(m, len(columns))
+
+ return m.scanRow(ctx, rows, dest)
+}
+
+func (m *structTableModel) scanRow(ctx context.Context, rows *sql.Rows, dest []interface{}) error {
+ if err := m.BeforeScan(ctx); err != nil {
+ return err
+ }
+
+ m.scanIndex = 0
+ if err := rows.Scan(dest...); err != nil {
+ return err
+ }
+
+ if err := m.AfterScan(ctx); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (m *structTableModel) Scan(src interface{}) error {
+ column := m.columns[m.scanIndex]
+ m.scanIndex++
+
+ return m.ScanColumn(unquote(column), src)
+}
+
+func (m *structTableModel) ScanColumn(column string, src interface{}) error {
+ if ok, err := m.scanColumn(column, src); ok {
+ return err
+ }
+ if column == "" || column[0] == '_' || m.db.flags.Has(discardUnknownColumns) {
+ return nil
+ }
+ return fmt.Errorf("bun: %s does not have column %q", m.table.TypeName, column)
+}
+
+func (m *structTableModel) scanColumn(column string, src interface{}) (bool, error) {
+ if src != nil {
+ if err := m.initStruct(); err != nil {
+ return true, err
+ }
+ }
+
+ if field, ok := m.table.FieldMap[column]; ok {
+ return true, field.ScanValue(m.strct, src)
+ }
+
+ if joinName, column := splitColumn(column); joinName != "" {
+ if join := m.GetJoin(joinName); join != nil {
+ return true, join.JoinModel.ScanColumn(column, src)
+ }
+ if m.table.ModelName == joinName {
+ return true, m.ScanColumn(column, src)
+ }
+ }
+
+ return false, nil
+}
+
+func (m *structTableModel) AppendNamedArg(
+ fmter schema.Formatter, b []byte, name string,
+) ([]byte, bool) {
+ return m.table.AppendNamedArg(fmter, b, name, m.strct)
+}
+
+// sqlite3 sometimes does not unquote columns.
+func unquote(s string) string {
+ if s == "" {
+ return s
+ }
+ if s[0] == '"' && s[len(s)-1] == '"' {
+ return s[1 : len(s)-1]
+ }
+ return s
+}
+
+func splitColumn(s string) (string, string) {
+ if i := strings.Index(s, "__"); i >= 0 {
+ return s[:i], s[i+2:]
+ }
+ return "", s
+}