diff options
Diffstat (limited to 'vendor/github.com/go-pg/pg/v10/hook.go')
-rw-r--r-- | vendor/github.com/go-pg/pg/v10/hook.go | 139 |
1 files changed, 139 insertions, 0 deletions
diff --git a/vendor/github.com/go-pg/pg/v10/hook.go b/vendor/github.com/go-pg/pg/v10/hook.go new file mode 100644 index 000000000..a95dc20bc --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/hook.go @@ -0,0 +1,139 @@ +package pg + +import ( + "context" + "fmt" + "time" + + "github.com/go-pg/pg/v10/orm" +) + +type ( + BeforeScanHook = orm.BeforeScanHook + AfterScanHook = orm.AfterScanHook + AfterSelectHook = orm.AfterSelectHook + BeforeInsertHook = orm.BeforeInsertHook + AfterInsertHook = orm.AfterInsertHook + BeforeUpdateHook = orm.BeforeUpdateHook + AfterUpdateHook = orm.AfterUpdateHook + BeforeDeleteHook = orm.BeforeDeleteHook + AfterDeleteHook = orm.AfterDeleteHook +) + +//------------------------------------------------------------------------------ + +type dummyFormatter struct{} + +func (dummyFormatter) FormatQuery(b []byte, query string, params ...interface{}) []byte { + return append(b, query...) +} + +// QueryEvent ... +type QueryEvent struct { + StartTime time.Time + DB orm.DB + Model interface{} + Query interface{} + Params []interface{} + fmtedQuery []byte + Result Result + Err error + + Stash map[interface{}]interface{} +} + +// QueryHook ... +type QueryHook interface { + BeforeQuery(context.Context, *QueryEvent) (context.Context, error) + AfterQuery(context.Context, *QueryEvent) error +} + +// UnformattedQuery returns the unformatted query of a query event. +// The query is only valid until the query Result is returned to the user. +func (e *QueryEvent) UnformattedQuery() ([]byte, error) { + return queryString(e.Query) +} + +func queryString(query interface{}) ([]byte, error) { + switch query := query.(type) { + case orm.TemplateAppender: + return query.AppendTemplate(nil) + case string: + return dummyFormatter{}.FormatQuery(nil, query), nil + default: + return nil, fmt.Errorf("pg: can't append %T", query) + } +} + +// FormattedQuery returns the formatted query of a query event. +// The query is only valid until the query Result is returned to the user. +func (e *QueryEvent) FormattedQuery() ([]byte, error) { + return e.fmtedQuery, nil +} + +// AddQueryHook adds a hook into query processing. +func (db *baseDB) AddQueryHook(hook QueryHook) { + db.queryHooks = append(db.queryHooks, hook) +} + +func (db *baseDB) beforeQuery( + ctx context.Context, + ormDB orm.DB, + model, query interface{}, + params []interface{}, + fmtedQuery []byte, +) (context.Context, *QueryEvent, error) { + if len(db.queryHooks) == 0 { + return ctx, nil, nil + } + + event := &QueryEvent{ + StartTime: time.Now(), + DB: ormDB, + Model: model, + Query: query, + Params: params, + fmtedQuery: fmtedQuery, + } + + for i, hook := range db.queryHooks { + var err error + ctx, err = hook.BeforeQuery(ctx, event) + if err != nil { + if err := db.afterQueryFromIndex(ctx, event, i); err != nil { + return ctx, nil, err + } + return ctx, nil, err + } + } + + return ctx, event, nil +} + +func (db *baseDB) afterQuery( + ctx context.Context, + event *QueryEvent, + res Result, + err error, +) error { + if event == nil { + return nil + } + + event.Err = err + event.Result = res + return db.afterQueryFromIndex(ctx, event, len(db.queryHooks)-1) +} + +func (db *baseDB) afterQueryFromIndex(ctx context.Context, event *QueryEvent, hookIndex int) error { + for ; hookIndex >= 0; hookIndex-- { + if err := db.queryHooks[hookIndex].AfterQuery(ctx, event); err != nil { + return err + } + } + return nil +} + +func copyQueryHooks(s []QueryHook) []QueryHook { + return s[:len(s):len(s)] +} |