summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/model_table_m2m.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/model_table_m2m.go')
-rw-r--r--vendor/github.com/uptrace/bun/model_table_m2m.go138
1 files changed, 138 insertions, 0 deletions
diff --git a/vendor/github.com/uptrace/bun/model_table_m2m.go b/vendor/github.com/uptrace/bun/model_table_m2m.go
new file mode 100644
index 000000000..4357e3a8e
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/model_table_m2m.go
@@ -0,0 +1,138 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "reflect"
+
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/schema"
+)
+
+type m2mModel struct {
+ *sliceTableModel
+ baseTable *schema.Table
+ rel *schema.Relation
+
+ baseValues map[internal.MapKey][]reflect.Value
+ structKey []interface{}
+}
+
+var _ tableModel = (*m2mModel)(nil)
+
+func newM2MModel(j *join) *m2mModel {
+ baseTable := j.BaseModel.Table()
+ joinModel := j.JoinModel.(*sliceTableModel)
+ baseValues := baseValues(joinModel, baseTable.PKs)
+ if len(baseValues) == 0 {
+ return nil
+ }
+ m := &m2mModel{
+ sliceTableModel: joinModel,
+ baseTable: baseTable,
+ rel: j.Relation,
+
+ baseValues: baseValues,
+ }
+ if !m.sliceOfPtr {
+ m.strct = reflect.New(m.table.Type).Elem()
+ }
+ return m
+}
+
+func (m *m2mModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) {
+ columns, err := rows.Columns()
+ if err != nil {
+ return 0, err
+ }
+
+ m.columns = columns
+ dest := makeDest(m, len(columns))
+
+ var n int
+
+ for rows.Next() {
+ if m.sliceOfPtr {
+ m.strct = reflect.New(m.table.Type).Elem()
+ } else {
+ m.strct.Set(m.table.ZeroValue)
+ }
+ m.structInited = false
+
+ m.scanIndex = 0
+ m.structKey = m.structKey[:0]
+ if err := rows.Scan(dest...); err != nil {
+ return 0, err
+ }
+
+ if err := m.parkStruct(); err != nil {
+ return 0, err
+ }
+
+ n++
+ }
+ if err := rows.Err(); err != nil {
+ return 0, err
+ }
+
+ return n, nil
+}
+
+func (m *m2mModel) Scan(src interface{}) error {
+ column := m.columns[m.scanIndex]
+ m.scanIndex++
+
+ field, ok := m.table.FieldMap[column]
+ if !ok {
+ return m.scanM2MColumn(column, src)
+ }
+
+ if err := field.ScanValue(m.strct, src); err != nil {
+ return err
+ }
+
+ for _, fk := range m.rel.M2MBaseFields {
+ if fk.Name == field.Name {
+ m.structKey = append(m.structKey, field.Value(m.strct).Interface())
+ break
+ }
+ }
+
+ return nil
+}
+
+func (m *m2mModel) scanM2MColumn(column string, src interface{}) error {
+ for _, field := range m.rel.M2MBaseFields {
+ if field.Name == column {
+ dest := reflect.New(field.IndirectType).Elem()
+ if err := field.Scan(dest, src); err != nil {
+ return err
+ }
+ m.structKey = append(m.structKey, dest.Interface())
+ break
+ }
+ }
+
+ _, err := m.scanColumn(column, src)
+ return err
+}
+
+func (m *m2mModel) parkStruct() error {
+ baseValues, ok := m.baseValues[internal.NewMapKey(m.structKey)]
+ if !ok {
+ return fmt.Errorf(
+ "bun: m2m relation=%s does not have base %s with key=%q (check join conditions)",
+ m.rel.Field.GoName, m.baseTable, m.structKey)
+ }
+
+ for _, v := range baseValues {
+ if m.sliceOfPtr {
+ v.Set(reflect.Append(v, m.strct.Addr()))
+ } else {
+ v.Set(reflect.Append(v, m.strct))
+ }
+ }
+
+ return nil
+}