diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/model_table_struct.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/model_table_struct.go | 62 |
1 files changed, 32 insertions, 30 deletions
diff --git a/vendor/github.com/uptrace/bun/model_table_struct.go b/vendor/github.com/uptrace/bun/model_table_struct.go index 3bb0c14dd..ee207ea08 100644 --- a/vendor/github.com/uptrace/bun/model_table_struct.go +++ b/vendor/github.com/uptrace/bun/model_table_struct.go @@ -6,8 +6,8 @@ import ( "fmt" "reflect" "strings" + "time" - "github.com/uptrace/bun/dialect" "github.com/uptrace/bun/schema" ) @@ -16,7 +16,7 @@ type structTableModel struct { table *schema.Table rel *schema.Relation - joins []join + joins []relationJoin dest interface{} root reflect.Value @@ -151,7 +151,7 @@ func (m *structTableModel) AfterScan(ctx context.Context) error { return firstErr } -func (m *structTableModel) GetJoin(name string) *join { +func (m *structTableModel) GetJoin(name string) *relationJoin { for i := range m.joins { j := &m.joins[i] if j.Relation.Field.Name == name || j.Relation.Field.GoName == name { @@ -161,30 +161,28 @@ func (m *structTableModel) GetJoin(name string) *join { return nil } -func (m *structTableModel) GetJoins() []join { +func (m *structTableModel) GetJoins() []relationJoin { return m.joins } -func (m *structTableModel) AddJoin(j join) *join { +func (m *structTableModel) AddJoin(j relationJoin) *relationJoin { m.joins = append(m.joins, j) return &m.joins[len(m.joins)-1] } -func (m *structTableModel) Join(name string, apply func(*SelectQuery) *SelectQuery) *join { - return m.join(m.strct, name, apply) +func (m *structTableModel) Join(name string) *relationJoin { + return m.join(m.strct, name) } -func (m *structTableModel) join( - bind reflect.Value, name string, apply func(*SelectQuery) *SelectQuery, -) *join { +func (m *structTableModel) join(bind reflect.Value, name string) *relationJoin { path := strings.Split(name, ".") index := make([]int, 0, len(path)) - currJoin := join{ + currJoin := relationJoin{ BaseModel: m, JoinModel: m, } - var lastJoin *join + var lastJoin *relationJoin for _, name := range path { relation, ok := currJoin.JoinModel.Table().Relations[name] @@ -214,20 +212,12 @@ func (m *structTableModel) join( } } - // No joins with such name. - if lastJoin == nil { - return nil - } - if apply != nil { - lastJoin.ApplyQueryFunc = apply - } - return lastJoin } -func (m *structTableModel) updateSoftDeleteField() error { +func (m *structTableModel) updateSoftDeleteField(tm time.Time) error { fv := m.table.SoftDeleteField.Value(m.strct) - return m.table.UpdateSoftDeleteField(fv) + return m.table.UpdateSoftDeleteField(fv, tm) } func (m *structTableModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { @@ -235,20 +225,24 @@ func (m *structTableModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, e return 0, rows.Err() } + var n int + if err := m.ScanRow(ctx, rows); err != nil { return 0, err } + n++ - // For inserts, SQLite3 can return a row like it was inserted sucessfully and then - // an actual error for the next row. See issues/100. - if m.db.dialect.Name() == dialect.SQLite { - _ = rows.Next() - if err := rows.Err(); err != nil { - return 0, err - } + // And discard the rest. This is especially important for SQLite3, which can return + // a row like it was inserted sucessfully and then return an actual error for the next row. + // See issues/100. + for rows.Next() { + n++ + } + if err := rows.Err(); err != nil { + return 0, err } - return 1, nil + return n, nil } func (m *structTableModel) ScanRow(ctx context.Context, rows *sql.Rows) error { @@ -305,6 +299,9 @@ func (m *structTableModel) scanColumn(column string, src interface{}) (bool, err } if field, ok := m.table.FieldMap[column]; ok { + if src == nil && m.isNil() { + return true, nil + } return true, field.ScanValue(m.strct, src) } @@ -312,6 +309,7 @@ func (m *structTableModel) scanColumn(column string, src interface{}) (bool, err if join := m.GetJoin(joinName); join != nil { return true, join.JoinModel.ScanColumn(column, src) } + if m.table.ModelName == joinName { return true, m.ScanColumn(column, src) } @@ -320,6 +318,10 @@ func (m *structTableModel) scanColumn(column string, src interface{}) (bool, err return false, nil } +func (m *structTableModel) isNil() bool { + return m.strct.Kind() == reflect.Ptr && m.strct.IsNil() +} + func (m *structTableModel) AppendNamedArg( fmter schema.Formatter, b []byte, name string, ) ([]byte, bool) { |