diff options
author | 2023-05-12 14:33:40 +0200 | |
---|---|---|
committer | 2023-05-12 14:33:40 +0200 | |
commit | ec325fee141c1e9757144a0a4094061b56839b78 (patch) | |
tree | 2948ab4ef5702cc8478ab2be841340b946bdb867 /vendor/github.com/jackc/pgx/v5/rows.go | |
parent | [frogend/bugfix] fix dynamicSpoiler elements (#1771) (diff) | |
download | gotosocial-ec325fee141c1e9757144a0a4094061b56839b78.tar.xz |
[chore] Update a bunch of database dependencies (#1772)
* [chore] Update a bunch of database dependencies
* fix lil thing
Diffstat (limited to 'vendor/github.com/jackc/pgx/v5/rows.go')
-rw-r--r-- | vendor/github.com/jackc/pgx/v5/rows.go | 649 |
1 files changed, 649 insertions, 0 deletions
diff --git a/vendor/github.com/jackc/pgx/v5/rows.go b/vendor/github.com/jackc/pgx/v5/rows.go new file mode 100644 index 000000000..ffe739b02 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/rows.go @@ -0,0 +1,649 @@ +package pgx + +import ( + "context" + "errors" + "fmt" + "reflect" + "strings" + "time" + + "github.com/jackc/pgx/v5/internal/stmtcache" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" +) + +// Rows is the result set returned from *Conn.Query. Rows must be closed before +// the *Conn can be used again. Rows are closed by explicitly calling Close(), +// calling Next() until it returns false, or when a fatal error occurs. +// +// Once a Rows is closed the only methods that may be called are Close(), Err(), and CommandTag(). +// +// Rows is an interface instead of a struct to allow tests to mock Query. However, +// adding a method to an interface is technically a breaking change. Because of this +// the Rows interface is partially excluded from semantic version requirements. +// Methods will not be removed or changed, but new methods may be added. +type Rows interface { + // Close closes the rows, making the connection ready for use again. It is safe + // to call Close after rows is already closed. + Close() + + // Err returns any error that occurred while reading. + Err() error + + // CommandTag returns the command tag from this query. It is only available after Rows is closed. + CommandTag() pgconn.CommandTag + + FieldDescriptions() []pgconn.FieldDescription + + // Next prepares the next row for reading. It returns true if there is another + // row and false if no more rows are available. It automatically closes rows + // when all rows are read. + Next() bool + + // Scan reads the values from the current row into dest values positionally. + // dest can include pointers to core types, values implementing the Scanner + // interface, and nil. nil will skip the value entirely. It is an error to + // call Scan without first calling Next() and checking that it returned true. + Scan(dest ...any) error + + // Values returns the decoded row values. As with Scan(), it is an error to + // call Values without first calling Next() and checking that it returned + // true. + Values() ([]any, error) + + // RawValues returns the unparsed bytes of the row values. The returned data is only valid until the next Next + // call or the Rows is closed. + RawValues() [][]byte + + // Conn returns the underlying *Conn on which the query was executed. This may return nil if Rows did not come from a + // *Conn (e.g. if it was created by RowsFromResultReader) + Conn() *Conn +} + +// Row is a convenience wrapper over Rows that is returned by QueryRow. +// +// Row is an interface instead of a struct to allow tests to mock QueryRow. However, +// adding a method to an interface is technically a breaking change. Because of this +// the Row interface is partially excluded from semantic version requirements. +// Methods will not be removed or changed, but new methods may be added. +type Row interface { + // Scan works the same as Rows. with the following exceptions. If no + // rows were found it returns ErrNoRows. If multiple rows are returned it + // ignores all but the first. + Scan(dest ...any) error +} + +// RowScanner scans an entire row at a time into the RowScanner. +type RowScanner interface { + // ScanRows scans the row. + ScanRow(rows Rows) error +} + +// connRow implements the Row interface for Conn.QueryRow. +type connRow baseRows + +func (r *connRow) Scan(dest ...any) (err error) { + rows := (*baseRows)(r) + + if rows.Err() != nil { + return rows.Err() + } + + for _, d := range dest { + if _, ok := d.(*pgtype.DriverBytes); ok { + rows.Close() + return fmt.Errorf("cannot scan into *pgtype.DriverBytes from QueryRow") + } + } + + if !rows.Next() { + if rows.Err() == nil { + return ErrNoRows + } + return rows.Err() + } + + rows.Scan(dest...) + rows.Close() + return rows.Err() +} + +// baseRows implements the Rows interface for Conn.Query. +type baseRows struct { + typeMap *pgtype.Map + resultReader *pgconn.ResultReader + + values [][]byte + + commandTag pgconn.CommandTag + err error + closed bool + + scanPlans []pgtype.ScanPlan + scanTypes []reflect.Type + + conn *Conn + multiResultReader *pgconn.MultiResultReader + + queryTracer QueryTracer + batchTracer BatchTracer + ctx context.Context + startTime time.Time + sql string + args []any + rowCount int +} + +func (rows *baseRows) FieldDescriptions() []pgconn.FieldDescription { + return rows.resultReader.FieldDescriptions() +} + +func (rows *baseRows) Close() { + if rows.closed { + return + } + + rows.closed = true + + if rows.resultReader != nil { + var closeErr error + rows.commandTag, closeErr = rows.resultReader.Close() + if rows.err == nil { + rows.err = closeErr + } + } + + if rows.multiResultReader != nil { + closeErr := rows.multiResultReader.Close() + if rows.err == nil { + rows.err = closeErr + } + } + + if rows.err != nil && rows.conn != nil && rows.sql != "" { + if stmtcache.IsStatementInvalid(rows.err) { + if sc := rows.conn.statementCache; sc != nil { + sc.Invalidate(rows.sql) + } + + if sc := rows.conn.descriptionCache; sc != nil { + sc.Invalidate(rows.sql) + } + } + } + + if rows.batchTracer != nil { + rows.batchTracer.TraceBatchQuery(rows.ctx, rows.conn, TraceBatchQueryData{SQL: rows.sql, Args: rows.args, CommandTag: rows.commandTag, Err: rows.err}) + } else if rows.queryTracer != nil { + rows.queryTracer.TraceQueryEnd(rows.ctx, rows.conn, TraceQueryEndData{rows.commandTag, rows.err}) + } +} + +func (rows *baseRows) CommandTag() pgconn.CommandTag { + return rows.commandTag +} + +func (rows *baseRows) Err() error { + return rows.err +} + +// fatal signals an error occurred after the query was sent to the server. It +// closes the rows automatically. +func (rows *baseRows) fatal(err error) { + if rows.err != nil { + return + } + + rows.err = err + rows.Close() +} + +func (rows *baseRows) Next() bool { + if rows.closed { + return false + } + + if rows.resultReader.NextRow() { + rows.rowCount++ + rows.values = rows.resultReader.Values() + return true + } else { + rows.Close() + return false + } +} + +func (rows *baseRows) Scan(dest ...any) error { + m := rows.typeMap + fieldDescriptions := rows.FieldDescriptions() + values := rows.values + + if len(fieldDescriptions) != len(values) { + err := fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values)) + rows.fatal(err) + return err + } + + if len(dest) == 1 { + if rc, ok := dest[0].(RowScanner); ok { + return rc.ScanRow(rows) + } + } + + if len(fieldDescriptions) != len(dest) { + err := fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest)) + rows.fatal(err) + return err + } + + if rows.scanPlans == nil { + rows.scanPlans = make([]pgtype.ScanPlan, len(values)) + rows.scanTypes = make([]reflect.Type, len(values)) + for i := range dest { + rows.scanPlans[i] = m.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i]) + rows.scanTypes[i] = reflect.TypeOf(dest[i]) + } + } + + for i, dst := range dest { + if dst == nil { + continue + } + + if rows.scanTypes[i] != reflect.TypeOf(dst) { + rows.scanPlans[i] = m.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i]) + rows.scanTypes[i] = reflect.TypeOf(dest[i]) + } + + err := rows.scanPlans[i].Scan(values[i], dst) + if err != nil { + err = ScanArgError{ColumnIndex: i, Err: err} + rows.fatal(err) + return err + } + } + + return nil +} + +func (rows *baseRows) Values() ([]any, error) { + if rows.closed { + return nil, errors.New("rows is closed") + } + + values := make([]any, 0, len(rows.FieldDescriptions())) + + for i := range rows.FieldDescriptions() { + buf := rows.values[i] + fd := &rows.FieldDescriptions()[i] + + if buf == nil { + values = append(values, nil) + continue + } + + if dt, ok := rows.typeMap.TypeForOID(fd.DataTypeOID); ok { + value, err := dt.Codec.DecodeValue(rows.typeMap, fd.DataTypeOID, fd.Format, buf) + if err != nil { + rows.fatal(err) + } + values = append(values, value) + } else { + switch fd.Format { + case TextFormatCode: + values = append(values, string(buf)) + case BinaryFormatCode: + newBuf := make([]byte, len(buf)) + copy(newBuf, buf) + values = append(values, newBuf) + default: + rows.fatal(errors.New("Unknown format code")) + } + } + + if rows.Err() != nil { + return nil, rows.Err() + } + } + + return values, rows.Err() +} + +func (rows *baseRows) RawValues() [][]byte { + return rows.values +} + +func (rows *baseRows) Conn() *Conn { + return rows.conn +} + +type ScanArgError struct { + ColumnIndex int + Err error +} + +func (e ScanArgError) Error() string { + return fmt.Sprintf("can't scan into dest[%d]: %v", e.ColumnIndex, e.Err) +} + +func (e ScanArgError) Unwrap() error { + return e.Err +} + +// ScanRow decodes raw row data into dest. It can be used to scan rows read from the lower level pgconn interface. +// +// typeMap - OID to Go type mapping. +// fieldDescriptions - OID and format of values +// values - the raw data as returned from the PostgreSQL server +// dest - the destination that values will be decoded into +func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgconn.FieldDescription, values [][]byte, dest ...any) error { + if len(fieldDescriptions) != len(values) { + return fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values)) + } + if len(fieldDescriptions) != len(dest) { + return fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest)) + } + + for i, d := range dest { + if d == nil { + continue + } + + err := typeMap.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d) + if err != nil { + return ScanArgError{ColumnIndex: i, Err: err} + } + } + + return nil +} + +// RowsFromResultReader returns a Rows that will read from values resultReader and decode with typeMap. It can be used +// to read from the lower level pgconn interface. +func RowsFromResultReader(typeMap *pgtype.Map, resultReader *pgconn.ResultReader) Rows { + return &baseRows{ + typeMap: typeMap, + resultReader: resultReader, + } +} + +// ForEachRow iterates through rows. For each row it scans into the elements of scans and calls fn. If any row +// fails to scan or fn returns an error the query will be aborted and the error will be returned. Rows will be closed +// when ForEachRow returns. +func ForEachRow(rows Rows, scans []any, fn func() error) (pgconn.CommandTag, error) { + defer rows.Close() + + for rows.Next() { + err := rows.Scan(scans...) + if err != nil { + return pgconn.CommandTag{}, err + } + + err = fn() + if err != nil { + return pgconn.CommandTag{}, err + } + } + + if err := rows.Err(); err != nil { + return pgconn.CommandTag{}, err + } + + return rows.CommandTag(), nil +} + +// CollectableRow is the subset of Rows methods that a RowToFunc is allowed to call. +type CollectableRow interface { + FieldDescriptions() []pgconn.FieldDescription + Scan(dest ...any) error + Values() ([]any, error) + RawValues() [][]byte +} + +// RowToFunc is a function that scans or otherwise converts row to a T. +type RowToFunc[T any] func(row CollectableRow) (T, error) + +// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T. +func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) { + defer rows.Close() + + slice := []T{} + + for rows.Next() { + value, err := fn(rows) + if err != nil { + return nil, err + } + slice = append(slice, value) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return slice, nil +} + +// CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true. +// CollectOneRow is to CollectRows as QueryRow is to Query. +func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { + defer rows.Close() + + var value T + var err error + + if !rows.Next() { + if err = rows.Err(); err != nil { + return value, err + } + return value, ErrNoRows + } + + value, err = fn(rows) + if err != nil { + return value, err + } + + rows.Close() + return value, rows.Err() +} + +// RowTo returns a T scanned from row. +func RowTo[T any](row CollectableRow) (T, error) { + var value T + err := row.Scan(&value) + return value, err +} + +// RowTo returns a the address of a T scanned from row. +func RowToAddrOf[T any](row CollectableRow) (*T, error) { + var value T + err := row.Scan(&value) + return &value, err +} + +// RowToMap returns a map scanned from row. +func RowToMap(row CollectableRow) (map[string]any, error) { + var value map[string]any + err := row.Scan((*mapRowScanner)(&value)) + return value, err +} + +type mapRowScanner map[string]any + +func (rs *mapRowScanner) ScanRow(rows Rows) error { + values, err := rows.Values() + if err != nil { + return err + } + + *rs = make(mapRowScanner, len(values)) + + for i := range values { + (*rs)[string(rows.FieldDescriptions()[i].Name)] = values[i] + } + + return nil +} + +// RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number a public fields as row +// has fields. The row and T fields will by matched by position. +func RowToStructByPos[T any](row CollectableRow) (T, error) { + var value T + err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) + return value, err +} + +// RowToAddrOfStructByPos returns the address of a T scanned from row. T must be a struct. T must have the same number a +// public fields as row has fields. The row and T fields will by matched by position. +func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) { + var value T + err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) + return &value, err +} + +type positionalStructRowScanner struct { + ptrToStruct any +} + +func (rs *positionalStructRowScanner) ScanRow(rows Rows) error { + dst := rs.ptrToStruct + dstValue := reflect.ValueOf(dst) + if dstValue.Kind() != reflect.Ptr { + return fmt.Errorf("dst not a pointer") + } + + dstElemValue := dstValue.Elem() + scanTargets := rs.appendScanTargets(dstElemValue, nil) + + if len(rows.RawValues()) > len(scanTargets) { + return fmt.Errorf("got %d values, but dst struct has only %d fields", len(rows.RawValues()), len(scanTargets)) + } + + return rows.Scan(scanTargets...) +} + +func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any) []any { + dstElemType := dstElemValue.Type() + + if scanTargets == nil { + scanTargets = make([]any, 0, dstElemType.NumField()) + } + + for i := 0; i < dstElemType.NumField(); i++ { + sf := dstElemType.Field(i) + if sf.PkgPath == "" { + // Handle anonymous struct embedding, but do not try to handle embedded pointers. + if sf.Anonymous && sf.Type.Kind() == reflect.Struct { + scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets) + } else { + scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface()) + } + } + } + + return scanTargets +} + +// RowToStructByName returns a T scanned from row. T must be a struct. T must have the same number of named public +// fields as row has fields. The row and T fields will by matched by name. The match is case-insensitive. The database +// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored. +func RowToStructByName[T any](row CollectableRow) (T, error) { + var value T + err := row.Scan(&namedStructRowScanner{ptrToStruct: &value}) + return value, err +} + +// RowToAddrOfStructByName returns the address of a T scanned from row. T must be a struct. T must have the same number +// of named public fields as row has fields. The row and T fields will by matched by name. The match is +// case-insensitive. The database column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" +// then the field will be ignored. +func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) { + var value T + err := row.Scan(&namedStructRowScanner{ptrToStruct: &value}) + return &value, err +} + +type namedStructRowScanner struct { + ptrToStruct any +} + +func (rs *namedStructRowScanner) ScanRow(rows Rows) error { + dst := rs.ptrToStruct + dstValue := reflect.ValueOf(dst) + if dstValue.Kind() != reflect.Ptr { + return fmt.Errorf("dst not a pointer") + } + + dstElemValue := dstValue.Elem() + scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions()) + + if err != nil { + return err + } + + for i, t := range scanTargets { + if t == nil { + return fmt.Errorf("struct doesn't have corresponding row field %s", rows.FieldDescriptions()[i].Name) + } + } + + return rows.Scan(scanTargets...) +} + +const structTagKey = "db" + +func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) { + i = -1 + for i, desc := range fldDescs { + if strings.EqualFold(desc.Name, field) { + return i + } + } + return +} + +func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any, fldDescs []pgconn.FieldDescription) ([]any, error) { + var err error + dstElemType := dstElemValue.Type() + + if scanTargets == nil { + scanTargets = make([]any, len(fldDescs)) + } + + for i := 0; i < dstElemType.NumField(); i++ { + sf := dstElemType.Field(i) + if sf.PkgPath != "" && !sf.Anonymous { + // Field is unexported, skip it. + continue + } + // Handle anoymous struct embedding, but do not try to handle embedded pointers. + if sf.Anonymous && sf.Type.Kind() == reflect.Struct { + scanTargets, err = rs.appendScanTargets(dstElemValue.Field(i), scanTargets, fldDescs) + if err != nil { + return nil, err + } + } else { + dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey) + if dbTagPresent { + dbTag = strings.Split(dbTag, ",")[0] + } + if dbTag == "-" { + // Field is ignored, skip it. + continue + } + colName := dbTag + if !dbTagPresent { + colName = sf.Name + } + fpos := fieldPosByName(fldDescs, colName) + if fpos == -1 || fpos >= len(scanTargets) { + return nil, fmt.Errorf("cannot find field %s in returned row", colName) + } + scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface() + } + } + + return scanTargets, err +} |