summaryrefslogtreecommitdiff
path: root/vendor/github.com/jackc/pgx/v5/rows.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/jackc/pgx/v5/rows.go')
-rw-r--r--vendor/github.com/jackc/pgx/v5/rows.go266
1 files changed, 192 insertions, 74 deletions
diff --git a/vendor/github.com/jackc/pgx/v5/rows.go b/vendor/github.com/jackc/pgx/v5/rows.go
index 78ef5326a..d4f7a9016 100644
--- a/vendor/github.com/jackc/pgx/v5/rows.go
+++ b/vendor/github.com/jackc/pgx/v5/rows.go
@@ -6,6 +6,7 @@ import (
"fmt"
"reflect"
"strings"
+ "sync"
"time"
"github.com/jackc/pgx/v5/pgconn"
@@ -418,6 +419,8 @@ type CollectableRow interface {
type RowToFunc[T any] func(row CollectableRow) (T, error)
// AppendRows iterates through rows, calling fn for each row, and appending the results into a slice of T.
+//
+// This function closes the rows automatically on return.
func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) {
defer rows.Close()
@@ -437,12 +440,16 @@ func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) {
}
// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T.
+//
+// This function closes the rows automatically on return.
func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
return AppendRows([]T{}, rows, fn)
}
// CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true.
// CollectOneRow is to CollectRows as QueryRow is to Query.
+//
+// This function closes the rows automatically on return.
func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
defer rows.Close()
@@ -468,6 +475,8 @@ func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
// CollectExactlyOneRow calls fn for the first row in rows and returns the result.
// - If no rows are found returns an error where errors.Is(ErrNoRows) is true.
// - If more than 1 row is found returns an error where errors.Is(ErrTooManyRows) is true.
+//
+// This function closes the rows automatically on return.
func CollectExactlyOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
defer rows.Close()
@@ -541,7 +550,7 @@ func (rs *mapRowScanner) ScanRow(rows Rows) error {
// ignored.
func RowToStructByPos[T any](row CollectableRow) (T, error) {
var value T
- err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value})
+ err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row)
return value, err
}
@@ -550,7 +559,7 @@ func RowToStructByPos[T any](row CollectableRow) (T, error) {
// the field will be ignored.
func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) {
var value T
- err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value})
+ err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row)
return &value, err
}
@@ -558,46 +567,60 @@ type positionalStructRowScanner struct {
ptrToStruct any
}
-func (rs *positionalStructRowScanner) ScanRow(rows Rows) error {
- dst := rs.ptrToStruct
- dstValue := reflect.ValueOf(dst)
- if dstValue.Kind() != reflect.Ptr {
- return fmt.Errorf("dst not a pointer")
+func (rs *positionalStructRowScanner) ScanRow(rows CollectableRow) error {
+ typ := reflect.TypeOf(rs.ptrToStruct).Elem()
+ fields := lookupStructFields(typ)
+ if len(rows.RawValues()) > len(fields) {
+ return fmt.Errorf(
+ "got %d values, but dst struct has only %d fields",
+ len(rows.RawValues()),
+ len(fields),
+ )
}
-
- dstElemValue := dstValue.Elem()
- scanTargets := rs.appendScanTargets(dstElemValue, nil)
-
- if len(rows.RawValues()) > len(scanTargets) {
- return fmt.Errorf("got %d values, but dst struct has only %d fields", len(rows.RawValues()), len(scanTargets))
- }
-
+ scanTargets := setupStructScanTargets(rs.ptrToStruct, fields)
return rows.Scan(scanTargets...)
}
-func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any) []any {
- dstElemType := dstElemValue.Type()
+// Map from reflect.Type -> []structRowField
+var positionalStructFieldMap sync.Map
- if scanTargets == nil {
- scanTargets = make([]any, 0, dstElemType.NumField())
+func lookupStructFields(t reflect.Type) []structRowField {
+ if cached, ok := positionalStructFieldMap.Load(t); ok {
+ return cached.([]structRowField)
}
- for i := 0; i < dstElemType.NumField(); i++ {
- sf := dstElemType.Field(i)
+ fieldStack := make([]int, 0, 1)
+ fields := computeStructFields(t, make([]structRowField, 0, t.NumField()), &fieldStack)
+ fieldsIface, _ := positionalStructFieldMap.LoadOrStore(t, fields)
+ return fieldsIface.([]structRowField)
+}
+
+func computeStructFields(
+ t reflect.Type,
+ fields []structRowField,
+ fieldStack *[]int,
+) []structRowField {
+ tail := len(*fieldStack)
+ *fieldStack = append(*fieldStack, 0)
+ for i := 0; i < t.NumField(); i++ {
+ sf := t.Field(i)
+ (*fieldStack)[tail] = i
// Handle anonymous struct embedding, but do not try to handle embedded pointers.
if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
- scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets)
+ fields = computeStructFields(sf.Type, fields, fieldStack)
} else if sf.PkgPath == "" {
dbTag, _ := sf.Tag.Lookup(structTagKey)
if dbTag == "-" {
// Field is ignored, skip it.
continue
}
- scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface())
+ fields = append(fields, structRowField{
+ path: append([]int(nil), *fieldStack...),
+ })
}
}
-
- return scanTargets
+ *fieldStack = (*fieldStack)[:tail]
+ return fields
}
// RowToStructByName returns a T scanned from row. T must be a struct. T must have the same number of named public
@@ -605,7 +628,7 @@ func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Val
// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored.
func RowToStructByName[T any](row CollectableRow) (T, error) {
var value T
- err := row.Scan(&namedStructRowScanner{ptrToStruct: &value})
+ err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row)
return value, err
}
@@ -615,7 +638,7 @@ func RowToStructByName[T any](row CollectableRow) (T, error) {
// then the field will be ignored.
func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) {
var value T
- err := row.Scan(&namedStructRowScanner{ptrToStruct: &value})
+ err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row)
return &value, err
}
@@ -624,7 +647,7 @@ func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) {
// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored.
func RowToStructByNameLax[T any](row CollectableRow) (T, error) {
var value T
- err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true})
+ err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row)
return value, err
}
@@ -634,7 +657,7 @@ func RowToStructByNameLax[T any](row CollectableRow) (T, error) {
// then the field will be ignored.
func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) {
var value T
- err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true})
+ err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row)
return &value, err
}
@@ -643,64 +666,123 @@ type namedStructRowScanner struct {
lax bool
}
-func (rs *namedStructRowScanner) ScanRow(rows Rows) error {
- dst := rs.ptrToStruct
- dstValue := reflect.ValueOf(dst)
- if dstValue.Kind() != reflect.Ptr {
- return fmt.Errorf("dst not a pointer")
- }
-
- dstElemValue := dstValue.Elem()
- scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions())
+func (rs *namedStructRowScanner) ScanRow(rows CollectableRow) error {
+ typ := reflect.TypeOf(rs.ptrToStruct).Elem()
+ fldDescs := rows.FieldDescriptions()
+ namedStructFields, err := lookupNamedStructFields(typ, fldDescs)
if err != nil {
return err
}
-
- for i, t := range scanTargets {
- if t == nil {
- return fmt.Errorf("struct doesn't have corresponding row field %s", rows.FieldDescriptions()[i].Name)
- }
+ if !rs.lax && namedStructFields.missingField != "" {
+ return fmt.Errorf("cannot find field %s in returned row", namedStructFields.missingField)
}
-
+ fields := namedStructFields.fields
+ scanTargets := setupStructScanTargets(rs.ptrToStruct, fields)
return rows.Scan(scanTargets...)
}
-const structTagKey = "db"
-
-func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) {
- i = -1
- for i, desc := range fldDescs {
+// Map from namedStructFieldMap -> *namedStructFields
+var namedStructFieldMap sync.Map
- // Snake case support.
- field = strings.ReplaceAll(field, "_", "")
- descName := strings.ReplaceAll(desc.Name, "_", "")
+type namedStructFieldsKey struct {
+ t reflect.Type
+ colNames string
+}
- if strings.EqualFold(descName, field) {
- return i
- }
- }
- return
+type namedStructFields struct {
+ fields []structRowField
+ // missingField is the first field from the struct without a corresponding row field.
+ // This is used to construct the correct error message for non-lax queries.
+ missingField string
}
-func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any, fldDescs []pgconn.FieldDescription) ([]any, error) {
- var err error
- dstElemType := dstElemValue.Type()
+func lookupNamedStructFields(
+ t reflect.Type,
+ fldDescs []pgconn.FieldDescription,
+) (*namedStructFields, error) {
+ key := namedStructFieldsKey{
+ t: t,
+ colNames: joinFieldNames(fldDescs),
+ }
+ if cached, ok := namedStructFieldMap.Load(key); ok {
+ return cached.(*namedStructFields), nil
+ }
- if scanTargets == nil {
- scanTargets = make([]any, len(fldDescs))
+ // We could probably do two-levels of caching, where we compute the key -> fields mapping
+ // for a type only once, cache it by type, then use that to compute the column -> fields
+ // mapping for a given set of columns.
+ fieldStack := make([]int, 0, 1)
+ fields, missingField := computeNamedStructFields(
+ fldDescs,
+ t,
+ make([]structRowField, len(fldDescs)),
+ &fieldStack,
+ )
+ for i, f := range fields {
+ if f.path == nil {
+ return nil, fmt.Errorf(
+ "struct doesn't have corresponding row field %s",
+ fldDescs[i].Name,
+ )
+ }
}
- for i := 0; i < dstElemType.NumField(); i++ {
- sf := dstElemType.Field(i)
+ fieldsIface, _ := namedStructFieldMap.LoadOrStore(
+ key,
+ &namedStructFields{fields: fields, missingField: missingField},
+ )
+ return fieldsIface.(*namedStructFields), nil
+}
+
+func joinFieldNames(fldDescs []pgconn.FieldDescription) string {
+ switch len(fldDescs) {
+ case 0:
+ return ""
+ case 1:
+ return fldDescs[0].Name
+ }
+
+ totalSize := len(fldDescs) - 1 // Space for separator bytes.
+ for _, d := range fldDescs {
+ totalSize += len(d.Name)
+ }
+ var b strings.Builder
+ b.Grow(totalSize)
+ b.WriteString(fldDescs[0].Name)
+ for _, d := range fldDescs[1:] {
+ b.WriteByte(0) // Join with NUL byte as it's (presumably) not a valid column character.
+ b.WriteString(d.Name)
+ }
+ return b.String()
+}
+
+func computeNamedStructFields(
+ fldDescs []pgconn.FieldDescription,
+ t reflect.Type,
+ fields []structRowField,
+ fieldStack *[]int,
+) ([]structRowField, string) {
+ var missingField string
+ tail := len(*fieldStack)
+ *fieldStack = append(*fieldStack, 0)
+ for i := 0; i < t.NumField(); i++ {
+ sf := t.Field(i)
+ (*fieldStack)[tail] = i
if sf.PkgPath != "" && !sf.Anonymous {
// Field is unexported, skip it.
continue
}
// Handle anonymous struct embedding, but do not try to handle embedded pointers.
if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
- scanTargets, err = rs.appendScanTargets(dstElemValue.Field(i), scanTargets, fldDescs)
- if err != nil {
- return nil, err
+ var missingSubField string
+ fields, missingSubField = computeNamedStructFields(
+ fldDescs,
+ sf.Type,
+ fields,
+ fieldStack,
+ )
+ if missingField == "" {
+ missingField = missingSubField
}
} else {
dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey)
@@ -717,17 +799,53 @@ func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, s
}
fpos := fieldPosByName(fldDescs, colName)
if fpos == -1 {
- if rs.lax {
- continue
+ if missingField == "" {
+ missingField = colName
}
- return nil, fmt.Errorf("cannot find field %s in returned row", colName)
+ continue
}
- if fpos >= len(scanTargets) && !rs.lax {
- return nil, fmt.Errorf("cannot find field %s in returned row", colName)
+ fields[fpos] = structRowField{
+ path: append([]int(nil), *fieldStack...),
}
- scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface()
}
}
+ *fieldStack = (*fieldStack)[:tail]
+
+ return fields, missingField
+}
+
+const structTagKey = "db"
+
+func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) {
+ i = -1
+ for i, desc := range fldDescs {
- return scanTargets, err
+ // Snake case support.
+ field = strings.ReplaceAll(field, "_", "")
+ descName := strings.ReplaceAll(desc.Name, "_", "")
+
+ if strings.EqualFold(descName, field) {
+ return i
+ }
+ }
+ return
+}
+
+// structRowField describes a field of a struct.
+//
+// TODO: It would be a bit more efficient to track the path using the pointer
+// offset within the (outermost) struct and use unsafe.Pointer arithmetic to
+// construct references when scanning rows. However, it's not clear it's worth
+// using unsafe for this.
+type structRowField struct {
+ path []int
+}
+
+func setupStructScanTargets(receiver any, fields []structRowField) []any {
+ scanTargets := make([]any, len(fields))
+ v := reflect.ValueOf(receiver).Elem()
+ for i, f := range fields {
+ scanTargets[i] = v.FieldByIndex(f.path).Addr().Interface()
+ }
+ return scanTargets
}