diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/model_table_struct.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/model_table_struct.go | 63 |
1 files changed, 45 insertions, 18 deletions
diff --git a/vendor/github.com/uptrace/bun/model_table_struct.go b/vendor/github.com/uptrace/bun/model_table_struct.go index fba17f42a..fadc9284c 100644 --- a/vendor/github.com/uptrace/bun/model_table_struct.go +++ b/vendor/github.com/uptrace/bun/model_table_struct.go @@ -100,38 +100,65 @@ func (m *structTableModel) mountJoins() { } } -var _ schema.BeforeScanHook = (*structTableModel)(nil) +var _ schema.BeforeAppendModelHook = (*structTableModel)(nil) -func (m *structTableModel) BeforeScan(ctx context.Context) error { - if !m.table.HasBeforeScanHook() { +func (m *structTableModel) BeforeAppendModel(ctx context.Context, query Query) error { + if !m.table.HasBeforeAppendModelHook() || !m.strct.IsValid() { return nil } - return callBeforeScanHook(ctx, m.strct.Addr()) + return m.strct.Addr().Interface().(schema.BeforeAppendModelHook).BeforeAppendModel(ctx, query) } -var _ schema.AfterScanHook = (*structTableModel)(nil) +var _ schema.BeforeScanRowHook = (*structTableModel)(nil) -func (m *structTableModel) AfterScan(ctx context.Context) error { - if !m.table.HasAfterScanHook() || !m.structInited { +func (m *structTableModel) BeforeScanRow(ctx context.Context) error { + if m.table.HasBeforeScanRowHook() { + return m.strct.Addr().Interface().(schema.BeforeScanRowHook).BeforeScanRow(ctx) + } + if m.table.HasBeforeScanHook() { + return m.strct.Addr().Interface().(schema.BeforeScanHook).BeforeScan(ctx) + } + return nil +} + +var _ schema.AfterScanRowHook = (*structTableModel)(nil) + +func (m *structTableModel) AfterScanRow(ctx context.Context) error { + if !m.structInited { return nil } - var firstErr error + if m.table.HasAfterScanRowHook() { + firstErr := m.strct.Addr().Interface().(schema.AfterScanRowHook).AfterScanRow(ctx) + + for _, j := range m.joins { + switch j.Relation.Type { + case schema.HasOneRelation, schema.BelongsToRelation: + if err := j.JoinModel.AfterScanRow(ctx); err != nil && firstErr == nil { + firstErr = err + } + } + } - if err := callAfterScanHook(ctx, m.strct.Addr()); err != nil && firstErr == nil { - firstErr = err + return firstErr } - for _, j := range m.joins { - switch j.Relation.Type { - case schema.HasOneRelation, schema.BelongsToRelation: - if err := j.JoinModel.AfterScan(ctx); err != nil && firstErr == nil { - firstErr = err + if m.table.HasAfterScanHook() { + firstErr := m.strct.Addr().Interface().(schema.AfterScanHook).AfterScan(ctx) + + for _, j := range m.joins { + switch j.Relation.Type { + case schema.HasOneRelation, schema.BelongsToRelation: + if err := j.JoinModel.AfterScanRow(ctx); err != nil && firstErr == nil { + firstErr = err + } } } + + return firstErr } - return firstErr + return nil } func (m *structTableModel) getJoin(name string) *relationJoin { @@ -257,7 +284,7 @@ func (m *structTableModel) ScanRow(ctx context.Context, rows *sql.Rows) error { } func (m *structTableModel) scanRow(ctx context.Context, rows *sql.Rows, dest []interface{}) error { - if err := m.BeforeScan(ctx); err != nil { + if err := m.BeforeScanRow(ctx); err != nil { return err } @@ -266,7 +293,7 @@ func (m *structTableModel) scanRow(ctx context.Context, rows *sql.Rows, dest []i return err } - if err := m.AfterScan(ctx); err != nil { + if err := m.AfterScanRow(ctx); err != nil { return err } |