summaryrefslogtreecommitdiff
path: root/vendor/github.com/go-pg/pg/v10/hook.go
blob: a95dc20bca085dd7685f7331de4be8ebdbcb5050 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package pg

import (
	"context"
	"fmt"
	"time"

	"github.com/go-pg/pg/v10/orm"
)

// BeforeScanHook is an alias for orm.BeforeScanHook; the two names denote the same type.
type BeforeScanHook = orm.BeforeScanHook

// AfterScanHook is an alias for orm.AfterScanHook.
type AfterScanHook = orm.AfterScanHook

// AfterSelectHook is an alias for orm.AfterSelectHook.
type AfterSelectHook = orm.AfterSelectHook

// BeforeInsertHook is an alias for orm.BeforeInsertHook.
type BeforeInsertHook = orm.BeforeInsertHook

// AfterInsertHook is an alias for orm.AfterInsertHook.
type AfterInsertHook = orm.AfterInsertHook

// BeforeUpdateHook is an alias for orm.BeforeUpdateHook.
type BeforeUpdateHook = orm.BeforeUpdateHook

// AfterUpdateHook is an alias for orm.AfterUpdateHook.
type AfterUpdateHook = orm.AfterUpdateHook

// BeforeDeleteHook is an alias for orm.BeforeDeleteHook.
type BeforeDeleteHook = orm.BeforeDeleteHook

// AfterDeleteHook is an alias for orm.AfterDeleteHook.
type AfterDeleteHook = orm.AfterDeleteHook

//------------------------------------------------------------------------------

// dummyFormatter is a no-op formatter: it performs no parameter substitution
// and simply copies the query text through.
type dummyFormatter struct{}

// FormatQuery appends query verbatim to b and returns the extended slice.
// params is accepted only to satisfy the formatter signature and is ignored.
func (dummyFormatter) FormatQuery(b []byte, query string, params ...interface{}) []byte {
	b = append(b, query...)
	return b
}

// QueryEvent carries the details of a single query through the registered
// QueryHooks. beforeQuery creates it and afterQuery records the outcome on it.
type QueryEvent struct {
	// StartTime is set to time.Now() when beforeQuery creates the event,
	// before any BeforeQuery hook runs.
	StartTime  time.Time
	// DB is the orm.DB value passed to beforeQuery (presumably the handle the
	// query was issued through — confirm against callers).
	DB         orm.DB
	// Model is the model value passed to beforeQuery alongside the query.
	Model      interface{}
	// Query is the unformatted query value; UnformattedQuery renders it.
	Query      interface{}
	// Params are the query parameters passed to beforeQuery.
	Params     []interface{}
	// fmtedQuery holds the formatted query bytes returned by FormattedQuery.
	fmtedQuery []byte
	// Result and Err are filled in by afterQuery before the AfterQuery hooks run.
	Result     Result
	Err        error

	// Stash is left nil by beforeQuery; presumably free-form storage that hooks
	// allocate themselves, e.g. to pass data from BeforeQuery to AfterQuery.
	Stash map[interface{}]interface{}
}

// QueryHook observes query processing. Hooks are registered with AddQueryHook.
//
// BeforeQuery is called on each hook in registration order. The context it
// returns is passed to the next hook, and on success the final context is
// handed back to beforeQuery's caller. When a BeforeQuery call fails, the
// remaining BeforeQuery calls are skipped, AfterQuery is invoked on the failing
// hook and on every hook registered before it, and beforeQuery returns an
// error (presumably aborting the query — confirm against callers).
//
// AfterQuery is called in reverse registration order. On the normal path
// (afterQuery), the query's Result and Err are recorded on the event first.
type QueryHook interface {
	BeforeQuery(context.Context, *QueryEvent) (context.Context, error)
	AfterQuery(context.Context, *QueryEvent) error
}

// UnformattedQuery renders e.Query as raw SQL text without substituting
// e.Params. It fails if e.Query is neither a template appender nor a string.
// The returned query is only valid until the query Result is returned to the user.
func (e *QueryEvent) UnformattedQuery() ([]byte, error) {
	raw, err := queryString(e.Query)
	return raw, err
}

// queryString converts a query value into its raw text without parameter
// substitution. orm.TemplateAppender values are rendered via AppendTemplate,
// and strings are appended to a nil slice by dummyFormatter. Any other type
// yields an error naming the dynamic type.
func queryString(query interface{}) ([]byte, error) {
	if appender, ok := query.(orm.TemplateAppender); ok {
		return appender.AppendTemplate(nil)
	}
	if text, ok := query.(string); ok {
		return dummyFormatter{}.FormatQuery(nil, text), nil
	}
	return nil, fmt.Errorf("pg: can't append %T", query)
}

// FormattedQuery returns the formatted query bytes supplied when the event was
// created. The slice is returned as-is (not copied), and the error is always nil.
// The returned query is only valid until the query Result is returned to the user.
func (e *QueryEvent) FormattedQuery() ([]byte, error) {
	formatted := e.fmtedQuery
	return formatted, nil
}

// AddQueryHook appends hook to db's hook list. BeforeQuery hooks run in the
// order they were added and AfterQuery hooks run in reverse order.
// NOTE(review): no locking is done here; presumably hooks must be added before
// db is used concurrently.
func (db *baseDB) AddQueryHook(hook QueryHook) {
	hooks := append(db.queryHooks, hook)
	db.queryHooks = hooks
}

// beforeQuery builds a QueryEvent and runs every registered hook's BeforeQuery
// in registration order, threading the context each hook returns into the next.
//
// When no hooks are registered it returns (ctx, nil, nil) without allocating an
// event. afterQuery treats a nil event as a no-op.
//
// If a hook's BeforeQuery fails, event.Err is set to that error, and AfterQuery
// is run on that hook and every earlier one (in reverse order). Hooks that
// already ran BeforeQuery thus get their matching AfterQuery and can see why
// the query was aborted. The BeforeQuery error is returned unless the
// unwinding itself fails, in which case the unwinding error is returned. The
// returned event is nil whenever an error is returned.
func (db *baseDB) beforeQuery(
	ctx context.Context,
	ormDB orm.DB,
	model, query interface{},
	params []interface{},
	fmtedQuery []byte,
) (context.Context, *QueryEvent, error) {
	if len(db.queryHooks) == 0 {
		return ctx, nil, nil
	}

	event := &QueryEvent{
		StartTime:  time.Now(),
		DB:         ormDB,
		Model:      model,
		Query:      query,
		Params:     params,
		fmtedQuery: fmtedQuery,
	}

	for i, hook := range db.queryHooks {
		hookCtx, hookErr := hook.BeforeQuery(ctx, event)
		// A nil context would otherwise be handed to later hooks and to the
		// query itself; keep the previous context in that case.
		if hookCtx != nil {
			ctx = hookCtx
		}
		if hookErr == nil {
			continue
		}

		// Without this, the unwinding AfterQuery calls would see Err == nil and
		// treat the aborted query as a success.
		event.Err = hookErr
		if unwindErr := db.afterQueryFromIndex(ctx, event, i); unwindErr != nil {
			return ctx, nil, unwindErr
		}
		return ctx, nil, hookErr
	}

	return ctx, event, nil
}

// afterQuery records the query outcome (res and err) on event and then runs
// every registered hook's AfterQuery in reverse registration order, returning
// the hooks' error, if any. A nil event, which beforeQuery returns when no
// hooks are registered, makes this a no-op.
func (db *baseDB) afterQuery(
	ctx context.Context,
	event *QueryEvent,
	res Result,
	err error,
) error {
	if event != nil {
		event.Result, event.Err = res, err
		return db.afterQueryFromIndex(ctx, event, len(db.queryHooks)-1)
	}
	return nil
}

// afterQueryFromIndex runs AfterQuery on hooks hookIndex, hookIndex-1, ..., 0,
// i.e. in reverse registration order. A negative hookIndex makes it a no-op.
//
// Every hook in that range is called even if an earlier AfterQuery fails, so
// one misbehaving hook cannot prevent outer hooks from finishing work they
// started in BeforeQuery. The first error encountered (from the innermost
// failing hook) is returned; later errors are dropped.
func (db *baseDB) afterQueryFromIndex(ctx context.Context, event *QueryEvent, hookIndex int) error {
	var firstErr error
	for ; hookIndex >= 0; hookIndex-- {
		if err := db.queryHooks[hookIndex].AfterQuery(ctx, event); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}

// copyQueryHooks returns s re-sliced so that its capacity equals its length.
// A later append on the result must therefore allocate a new backing array
// instead of writing into storage shared with s. The hooks themselves are
// not copied.
func copyQueryHooks(s []QueryHook) []QueryHook {
	n := len(s)
	return s[0:n:n]
}