diff options
Diffstat (limited to 'vendor/github.com/go-pg/pg/v10/orm/table.go')
-rw-r--r-- | vendor/github.com/go-pg/pg/v10/orm/table.go | 1560 |
1 files changed, 0 insertions, 1560 deletions
diff --git a/vendor/github.com/go-pg/pg/v10/orm/table.go b/vendor/github.com/go-pg/pg/v10/orm/table.go deleted file mode 100644 index 8b57bbfc0..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/table.go +++ /dev/null @@ -1,1560 +0,0 @@ -package orm - -import ( - "database/sql" - "encoding/json" - "fmt" - "net" - "reflect" - "strconv" - "strings" - "sync" - "time" - - "github.com/jinzhu/inflection" - "github.com/vmihailenco/tagparser" - - "github.com/go-pg/pg/v10/internal" - "github.com/go-pg/pg/v10/internal/pool" - "github.com/go-pg/pg/v10/pgjson" - "github.com/go-pg/pg/v10/types" - "github.com/go-pg/zerochecker" -) - -const ( - beforeScanHookFlag = uint16(1) << iota - afterScanHookFlag - afterSelectHookFlag - beforeInsertHookFlag - afterInsertHookFlag - beforeUpdateHookFlag - afterUpdateHookFlag - beforeDeleteHookFlag - afterDeleteHookFlag - discardUnknownColumnsFlag -) - -var ( - timeType = reflect.TypeOf((*time.Time)(nil)).Elem() - nullTimeType = reflect.TypeOf((*types.NullTime)(nil)).Elem() - sqlNullTimeType = reflect.TypeOf((*sql.NullTime)(nil)).Elem() - ipType = reflect.TypeOf((*net.IP)(nil)).Elem() - ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() - scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() - nullBoolType = reflect.TypeOf((*sql.NullBool)(nil)).Elem() - nullFloatType = reflect.TypeOf((*sql.NullFloat64)(nil)).Elem() - nullIntType = reflect.TypeOf((*sql.NullInt64)(nil)).Elem() - nullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem() - jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() -) - -var tableNameInflector = inflection.Plural - -// SetTableNameInflector overrides the default func that pluralizes -// model name to get table name, e.g. my_article becomes my_articles. -func SetTableNameInflector(fn func(string) string) { - tableNameInflector = fn -} - -// Table represents a SQL table created from Go struct. -type Table struct { - Type reflect.Type - zeroStruct reflect.Value - - TypeName string - Alias types.Safe - ModelName string - - SQLName types.Safe - SQLNameForSelects types.Safe - - Tablespace types.Safe - - PartitionBy string - - allFields []*Field // read only - skippedFields []*Field - - Fields []*Field // PKs + DataFields - PKs []*Field - DataFields []*Field - fieldsMapMu sync.RWMutex - FieldsMap map[string]*Field - - Methods map[string]*Method - Relations map[string]*Relation - Unique map[string][]*Field - - SoftDeleteField *Field - SetSoftDeleteField func(fv reflect.Value) error - - flags uint16 -} - -func newTable(typ reflect.Type) *Table { - t := new(Table) - t.Type = typ - t.zeroStruct = reflect.New(t.Type).Elem() - t.TypeName = internal.ToExported(t.Type.Name()) - t.ModelName = internal.Underscore(t.Type.Name()) - tableName := tableNameInflector(t.ModelName) - t.setName(quoteIdent(tableName)) - t.Alias = quoteIdent(t.ModelName) - - typ = reflect.PtrTo(t.Type) - if typ.Implements(beforeScanHookType) { - t.setFlag(beforeScanHookFlag) - } - if typ.Implements(afterScanHookType) { - t.setFlag(afterScanHookFlag) - } - if typ.Implements(afterSelectHookType) { - t.setFlag(afterSelectHookFlag) - } - if typ.Implements(beforeInsertHookType) { - t.setFlag(beforeInsertHookFlag) - } - if typ.Implements(afterInsertHookType) { - t.setFlag(afterInsertHookFlag) - } - if typ.Implements(beforeUpdateHookType) { - t.setFlag(beforeUpdateHookFlag) - } - if typ.Implements(afterUpdateHookType) { - t.setFlag(afterUpdateHookFlag) - } - if typ.Implements(beforeDeleteHookType) { - t.setFlag(beforeDeleteHookFlag) - } - if typ.Implements(afterDeleteHookType) { - t.setFlag(afterDeleteHookFlag) - } - - return t -} - -func (t *Table) init1() { - t.initFields() - t.initMethods() -} - -func (t *Table) init2() { - t.initInlines() - t.initRelations() - t.skippedFields = nil -} - -func (t *Table) setName(name types.Safe) { - t.SQLName = name - t.SQLNameForSelects = name - if t.Alias == "" { - t.Alias = name - } -} - -func (t *Table) String() string { - return "model=" + t.TypeName -} - -func (t *Table) setFlag(flag uint16) { - t.flags |= flag -} - -func (t *Table) hasFlag(flag uint16) bool { - if t == nil { - return false - } - return t.flags&flag != 0 -} - -func (t *Table) checkPKs() error { - if len(t.PKs) == 0 { - return fmt.Errorf("pg: %s does not have primary keys", t) - } - return nil -} - -func (t *Table) mustSoftDelete() error { - if t.SoftDeleteField == nil { - return fmt.Errorf("pg: %s does not support soft deletes", t) - } - return nil -} - -func (t *Table) AddField(field *Field) { - t.Fields = append(t.Fields, field) - if field.hasFlag(PrimaryKeyFlag) { - t.PKs = append(t.PKs, field) - } else { - t.DataFields = append(t.DataFields, field) - } - t.FieldsMap[field.SQLName] = field -} - -func (t *Table) RemoveField(field *Field) { - t.Fields = removeField(t.Fields, field) - if field.hasFlag(PrimaryKeyFlag) { - t.PKs = removeField(t.PKs, field) - } else { - t.DataFields = removeField(t.DataFields, field) - } - delete(t.FieldsMap, field.SQLName) -} - -func removeField(fields []*Field, field *Field) []*Field { - for i, f := range fields { - if f == field { - fields = append(fields[:i], fields[i+1:]...) - } - } - return fields -} - -func (t *Table) getField(name string) *Field { - t.fieldsMapMu.RLock() - field := t.FieldsMap[name] - t.fieldsMapMu.RUnlock() - return field -} - -func (t *Table) HasField(name string) bool { - _, ok := t.FieldsMap[name] - return ok -} - -func (t *Table) GetField(name string) (*Field, error) { - field, ok := t.FieldsMap[name] - if !ok { - return nil, fmt.Errorf("pg: %s does not have column=%s", t, name) - } - return field, nil -} - -func (t *Table) AppendParam(b []byte, strct reflect.Value, name string) ([]byte, bool) { - field, ok := t.FieldsMap[name] - if ok { - b = field.AppendValue(b, strct, 1) - return b, true - } - - method, ok := t.Methods[name] - if ok { - b = method.AppendValue(b, strct.Addr(), 1) - return b, true - } - - return b, false -} - -func (t *Table) initFields() { - t.Fields = make([]*Field, 0, t.Type.NumField()) - t.FieldsMap = make(map[string]*Field, t.Type.NumField()) - t.addFields(t.Type, nil) -} - -func (t *Table) addFields(typ reflect.Type, baseIndex []int) { - for i := 0; i < typ.NumField(); i++ { - f := typ.Field(i) - - // Make a copy so slice is not shared between fields. - index := make([]int, len(baseIndex)) - copy(index, baseIndex) - - if f.Anonymous { - if f.Tag.Get("sql") == "-" || f.Tag.Get("pg") == "-" { - continue - } - - fieldType := indirectType(f.Type) - if fieldType.Kind() != reflect.Struct { - continue - } - t.addFields(fieldType, append(index, f.Index...)) - - pgTag := tagparser.Parse(f.Tag.Get("pg")) - if _, inherit := pgTag.Options["inherit"]; inherit { - embeddedTable := _tables.get(fieldType, true) - t.TypeName = embeddedTable.TypeName - t.SQLName = embeddedTable.SQLName - t.SQLNameForSelects = embeddedTable.SQLNameForSelects - t.Alias = embeddedTable.Alias - t.ModelName = embeddedTable.ModelName - } - - continue - } - - field := t.newField(f, index) - if field != nil { - t.AddField(field) - } - } -} - -//nolint -func (t *Table) newField(f reflect.StructField, index []int) *Field { - pgTag := tagparser.Parse(f.Tag.Get("pg")) - - switch f.Name { - case "tableName": - if len(index) > 0 { - return nil - } - - if isKnownTableOption(pgTag.Name) { - internal.Warn.Printf( - "%s.%s tag name %q is also an option name; is it a mistake?", - t.TypeName, f.Name, pgTag.Name, - ) - } - - for name := range pgTag.Options { - if !isKnownTableOption(name) { - internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name) - } - } - - if tableSpace, ok := pgTag.Options["tablespace"]; ok { - s, _ := tagparser.Unquote(tableSpace) - t.Tablespace = quoteIdent(s) - } - - partitionBy, ok := pgTag.Options["partition_by"] - if !ok { - partitionBy, ok = pgTag.Options["partitionBy"] - if ok { - internal.Deprecated.Printf("partitionBy is renamed to partition_by") - } - } - if ok { - s, _ := tagparser.Unquote(partitionBy) - t.PartitionBy = s - } - - if pgTag.Name == "_" { - t.setName("") - } else if pgTag.Name != "" { - s, _ := tagparser.Unquote(pgTag.Name) - t.setName(types.Safe(quoteTableName(s))) - } - - if s, ok := pgTag.Options["select"]; ok { - s, _ = tagparser.Unquote(s) - t.SQLNameForSelects = types.Safe(quoteTableName(s)) - } - - if v, ok := pgTag.Options["alias"]; ok { - v, _ = tagparser.Unquote(v) - t.Alias = quoteIdent(v) - } - - pgTag := tagparser.Parse(f.Tag.Get("pg")) - if _, ok := pgTag.Options["discard_unknown_columns"]; ok { - t.setFlag(discardUnknownColumnsFlag) - } - - return nil - } - - if f.PkgPath != "" { - return nil - } - - sqlName := internal.Underscore(f.Name) - - if pgTag.Name != sqlName && isKnownFieldOption(pgTag.Name) { - internal.Warn.Printf( - "%s.%s tag name %q is also an option name; is it a mistake?", - t.TypeName, f.Name, pgTag.Name, - ) - } - - for name := range pgTag.Options { - if !isKnownFieldOption(name) { - internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name) - } - } - - skip := pgTag.Name == "-" - if !skip && pgTag.Name != "" { - sqlName = pgTag.Name - } - - index = append(index, f.Index...) - if field := t.getField(sqlName); field != nil { - if indexEqual(field.Index, index) { - return field - } - t.RemoveField(field) - } - - field := &Field{ - Field: f, - Type: indirectType(f.Type), - - GoName: f.Name, - SQLName: sqlName, - Column: quoteIdent(sqlName), - - Index: index, - } - - if _, ok := pgTag.Options["notnull"]; ok { - field.setFlag(NotNullFlag) - } - if v, ok := pgTag.Options["unique"]; ok { - if v == "" { - field.setFlag(UniqueFlag) - } - // Split the value by comma, this will allow multiple names to be specified. - // We can use this to create multiple named unique constraints where a single column - // might be included in multiple constraints. - v, _ = tagparser.Unquote(v) - for _, uniqueName := range strings.Split(v, ",") { - if t.Unique == nil { - t.Unique = make(map[string][]*Field) - } - t.Unique[uniqueName] = append(t.Unique[uniqueName], field) - } - } - if v, ok := pgTag.Options["default"]; ok { - v, ok = tagparser.Unquote(v) - if ok { - field.Default = types.Safe(types.AppendString(nil, v, 1)) - } else { - field.Default = types.Safe(v) - } - } - - //nolint - if _, ok := pgTag.Options["pk"]; ok { - field.setFlag(PrimaryKeyFlag) - } else if strings.HasSuffix(field.SQLName, "_id") || - strings.HasSuffix(field.SQLName, "_uuid") { - field.setFlag(ForeignKeyFlag) - } else if strings.HasPrefix(field.SQLName, "fk_") { - field.setFlag(ForeignKeyFlag) - } else if len(t.PKs) == 0 && !pgTag.HasOption("nopk") { - switch field.SQLName { - case "id", "uuid", "pk_" + t.ModelName: - field.setFlag(PrimaryKeyFlag) - } - } - - if _, ok := pgTag.Options["use_zero"]; ok { - field.setFlag(UseZeroFlag) - } - if _, ok := pgTag.Options["array"]; ok { - field.setFlag(ArrayFlag) - } - - field.SQLType = fieldSQLType(field, pgTag) - if strings.HasSuffix(field.SQLType, "[]") { - field.setFlag(ArrayFlag) - } - - if v, ok := pgTag.Options["on_delete"]; ok { - field.OnDelete = v - } - - if v, ok := pgTag.Options["on_update"]; ok { - field.OnUpdate = v - } - - if _, ok := pgTag.Options["composite"]; ok { - field.append = compositeAppender(f.Type) - field.scan = compositeScanner(f.Type) - } else if _, ok := pgTag.Options["json_use_number"]; ok { - field.append = types.Appender(f.Type) - field.scan = scanJSONValue - } else if field.hasFlag(ArrayFlag) { - field.append = types.ArrayAppender(f.Type) - field.scan = types.ArrayScanner(f.Type) - } else if _, ok := pgTag.Options["hstore"]; ok { - field.append = types.HstoreAppender(f.Type) - field.scan = types.HstoreScanner(f.Type) - } else if field.SQLType == pgTypeBigint && field.Type.Kind() == reflect.Uint64 { - if f.Type.Kind() == reflect.Ptr { - field.append = appendUintPtrAsInt - } else { - field.append = appendUintAsInt - } - field.scan = types.Scanner(f.Type) - } else if _, ok := pgTag.Options["msgpack"]; ok { - field.append = msgpackAppender(f.Type) - field.scan = msgpackScanner(f.Type) - } else { - field.append = types.Appender(f.Type) - field.scan = types.Scanner(f.Type) - } - field.isZero = zerochecker.Checker(f.Type) - - if v, ok := pgTag.Options["alias"]; ok { - v, _ = tagparser.Unquote(v) - t.FieldsMap[v] = field - } - - t.allFields = append(t.allFields, field) - if skip { - t.skippedFields = append(t.skippedFields, field) - t.FieldsMap[field.SQLName] = field - return nil - } - - if _, ok := pgTag.Options["soft_delete"]; ok { - t.SetSoftDeleteField = setSoftDeleteFieldFunc(f.Type) - if t.SetSoftDeleteField == nil { - err := fmt.Errorf( - "pg: soft_delete is only supported for time.Time, pg.NullTime, sql.NullInt64, and int64 (or implement ValueScanner that scans time)") - panic(err) - } - t.SoftDeleteField = field - } - - return field -} - -func (t *Table) initMethods() { - t.Methods = make(map[string]*Method) - typ := reflect.PtrTo(t.Type) - for i := 0; i < typ.NumMethod(); i++ { - m := typ.Method(i) - if m.PkgPath != "" { - continue - } - if m.Type.NumIn() > 1 { - continue - } - if m.Type.NumOut() != 1 { - continue - } - - retType := m.Type.Out(0) - t.Methods[m.Name] = &Method{ - Index: m.Index, - - appender: types.Appender(retType), - } - } -} - -func (t *Table) initInlines() { - for _, f := range t.skippedFields { - if f.Type.Kind() == reflect.Struct { - t.inlineFields(f, nil) - } - } -} - -func (t *Table) initRelations() { - for i := 0; i < len(t.Fields); { - f := t.Fields[i] - if t.tryRelation(f) { - t.Fields = removeField(t.Fields, f) - t.DataFields = removeField(t.DataFields, f) - } else { - i++ - } - - if f.Type.Kind() == reflect.Struct { - t.inlineFields(f, nil) - } - } -} - -func (t *Table) tryRelation(field *Field) bool { - pgTag := tagparser.Parse(field.Field.Tag.Get("pg")) - - if rel, ok := pgTag.Options["rel"]; ok { - return t.tryRelationType(field, rel, pgTag) - } - if _, ok := pgTag.Options["many2many"]; ok { - return t.tryRelationType(field, "many2many", pgTag) - } - - if field.UserSQLType != "" || isScanner(field.Type) { - return false - } - - switch field.Type.Kind() { - case reflect.Slice: - return t.tryRelationSlice(field, pgTag) - case reflect.Struct: - return t.tryRelationStruct(field, pgTag) - } - return false -} - -func (t *Table) tryRelationType(field *Field, rel string, pgTag *tagparser.Tag) bool { - switch rel { - case "has-one": - return t.mustHasOneRelation(field, pgTag) - case "belongs-to": - return t.mustBelongsToRelation(field, pgTag) - case "has-many": - return t.mustHasManyRelation(field, pgTag) - case "many2many": - return t.mustM2MRelation(field, pgTag) - default: - panic(fmt.Errorf("pg: unknown relation=%s on field=%s", rel, field.GoName)) - } -} - -func (t *Table) mustHasOneRelation(field *Field, pgTag *tagparser.Tag) bool { - joinTable := _tables.get(field.Type, true) - if err := joinTable.checkPKs(); err != nil { - panic(err) - } - fkPrefix, fkOK := pgTag.Options["fk"] - - if fkOK && len(joinTable.PKs) == 1 { - fk := t.getField(fkPrefix) - if fk == nil { - panic(fmt.Errorf( - "pg: %s has-one %s: %s must have column %s "+ - "(use fk:custom_column tag on %s field to specify custom column)", - t.TypeName, field.GoName, t.TypeName, fkPrefix, field.GoName, - )) - } - - t.addRelation(&Relation{ - Type: HasOneRelation, - Field: field, - JoinTable: joinTable, - BaseFKs: []*Field{fk}, - JoinFKs: joinTable.PKs, - }) - return true - } - - if !fkOK { - fkPrefix = internal.Underscore(field.GoName) + "_" - } - fks := make([]*Field, 0, len(joinTable.PKs)) - - for _, joinPK := range joinTable.PKs { - fkName := fkPrefix + joinPK.SQLName - if fk := t.getField(fkName); fk != nil { - fks = append(fks, fk) - continue - } - - if fk := t.getField(joinPK.SQLName); fk != nil { - fks = append(fks, fk) - continue - } - - panic(fmt.Errorf( - "pg: %s has-one %s: %s must have column %s "+ - "(use fk:custom_column tag on %s field to specify custom column)", - t.TypeName, field.GoName, t.TypeName, fkName, field.GoName, - )) - } - - t.addRelation(&Relation{ - Type: HasOneRelation, - Field: field, - JoinTable: joinTable, - BaseFKs: fks, - JoinFKs: joinTable.PKs, - }) - return true -} - -func (t *Table) mustBelongsToRelation(field *Field, pgTag *tagparser.Tag) bool { - if err := t.checkPKs(); err != nil { - panic(err) - } - joinTable := _tables.get(field.Type, true) - fkPrefix, fkOK := pgTag.Options["join_fk"] - - if fkOK && len(t.PKs) == 1 { - fk := joinTable.getField(fkPrefix) - if fk == nil { - panic(fmt.Errorf( - "pg: %s belongs-to %s: %s must have column %s "+ - "(use join_fk:custom_column tag on %s field to specify custom column)", - field.GoName, t.TypeName, joinTable.TypeName, fkPrefix, field.GoName, - )) - } - - t.addRelation(&Relation{ - Type: BelongsToRelation, - Field: field, - JoinTable: joinTable, - BaseFKs: t.PKs, - JoinFKs: []*Field{fk}, - }) - return true - } - - if !fkOK { - fkPrefix = internal.Underscore(t.ModelName) + "_" - } - fks := make([]*Field, 0, len(t.PKs)) - - for _, pk := range t.PKs { - fkName := fkPrefix + pk.SQLName - if fk := joinTable.getField(fkName); fk != nil { - fks = append(fks, fk) - continue - } - - if fk := joinTable.getField(pk.SQLName); fk != nil { - fks = append(fks, fk) - continue - } - - panic(fmt.Errorf( - "pg: %s belongs-to %s: %s must have column %s "+ - "(use join_fk:custom_column tag on %s field to specify custom column)", - field.GoName, t.TypeName, joinTable.TypeName, fkName, field.GoName, - )) - } - - t.addRelation(&Relation{ - Type: BelongsToRelation, - Field: field, - JoinTable: joinTable, - BaseFKs: t.PKs, - JoinFKs: fks, - }) - return true -} - -func (t *Table) mustHasManyRelation(field *Field, pgTag *tagparser.Tag) bool { - if err := t.checkPKs(); err != nil { - panic(err) - } - if field.Type.Kind() != reflect.Slice { - panic(fmt.Errorf( - "pg: %s.%s has-many relation requires slice, got %q", - t.TypeName, field.GoName, field.Type.Kind(), - )) - } - - joinTable := _tables.get(indirectType(field.Type.Elem()), true) - fkPrefix, fkOK := pgTag.Options["join_fk"] - _, polymorphic := pgTag.Options["polymorphic"] - - if fkOK && !polymorphic && len(t.PKs) == 1 { - fk := joinTable.getField(fkPrefix) - if fk == nil { - panic(fmt.Errorf( - "pg: %s has-many %s: %s must have column %s "+ - "(use join_fk:custom_column tag on %s field to specify custom column)", - t.TypeName, field.GoName, joinTable.TypeName, fkPrefix, field.GoName, - )) - } - - t.addRelation(&Relation{ - Type: HasManyRelation, - Field: field, - JoinTable: joinTable, - BaseFKs: t.PKs, - JoinFKs: []*Field{fk}, - }) - return true - } - - if !fkOK { - fkPrefix = internal.Underscore(t.ModelName) + "_" - } - fks := make([]*Field, 0, len(t.PKs)) - - for _, pk := range t.PKs { - fkName := fkPrefix + pk.SQLName - if fk := joinTable.getField(fkName); fk != nil { - fks = append(fks, fk) - continue - } - - if fk := joinTable.getField(pk.SQLName); fk != nil { - fks = append(fks, fk) - continue - } - - panic(fmt.Errorf( - "pg: %s has-many %s: %s must have column %s "+ - "(use join_fk:custom_column tag on %s field to specify custom column)", - t.TypeName, field.GoName, joinTable.TypeName, fkName, field.GoName, - )) - } - - var typeField *Field - - if polymorphic { - typeFieldName := fkPrefix + "type" - typeField = joinTable.getField(typeFieldName) - if typeField == nil { - panic(fmt.Errorf( - "pg: %s has-many %s: %s must have polymorphic column %s", - t.TypeName, field.GoName, joinTable.TypeName, typeFieldName, - )) - } - } - - t.addRelation(&Relation{ - Type: HasManyRelation, - Field: field, - JoinTable: joinTable, - BaseFKs: t.PKs, - JoinFKs: fks, - Polymorphic: typeField, - }) - return true -} - -func (t *Table) mustM2MRelation(field *Field, pgTag *tagparser.Tag) bool { - if field.Type.Kind() != reflect.Slice { - panic(fmt.Errorf( - "pg: %s.%s many2many relation requires slice, got %q", - t.TypeName, field.GoName, field.Type.Kind(), - )) - } - joinTable := _tables.get(indirectType(field.Type.Elem()), true) - - if err := t.checkPKs(); err != nil { - panic(err) - } - if err := joinTable.checkPKs(); err != nil { - panic(err) - } - - m2mTableNameString, ok := pgTag.Options["many2many"] - if !ok { - panic(fmt.Errorf("pg: %s must have many2many tag option", field.GoName)) - } - m2mTableName := quoteTableName(m2mTableNameString) - - m2mTable := _tables.getByName(m2mTableName) - if m2mTable == nil { - panic(fmt.Errorf( - "pg: can't find %s table (use orm.RegisterTable to register the model)", - m2mTableName, - )) - } - - var baseFKs []string - var joinFKs []string - - { - fkPrefix, ok := pgTag.Options["fk"] - if !ok { - fkPrefix = internal.Underscore(t.ModelName) + "_" - } - - if ok && len(t.PKs) == 1 { - if m2mTable.getField(fkPrefix) == nil { - panic(fmt.Errorf( - "pg: %s many2many %s: %s must have column %s "+ - "(use fk:custom_column tag on %s field to specify custom column)", - t.TypeName, field.GoName, m2mTable.TypeName, fkPrefix, field.GoName, - )) - } - baseFKs = []string{fkPrefix} - } else { - for _, pk := range t.PKs { - fkName := fkPrefix + pk.SQLName - if m2mTable.getField(fkName) != nil { - baseFKs = append(baseFKs, fkName) - continue - } - - if m2mTable.getField(pk.SQLName) != nil { - baseFKs = append(baseFKs, pk.SQLName) - continue - } - - panic(fmt.Errorf( - "pg: %s many2many %s: %s must have column %s "+ - "(use fk:custom_column tag on %s field to specify custom column)", - t.TypeName, field.GoName, m2mTable.TypeName, fkName, field.GoName, - )) - } - } - } - - { - joinFKPrefix, ok := pgTag.Options["join_fk"] - if !ok { - joinFKPrefix = internal.Underscore(joinTable.ModelName) + "_" - } - - if ok && len(joinTable.PKs) == 1 { - if m2mTable.getField(joinFKPrefix) == nil { - panic(fmt.Errorf( - "pg: %s many2many %s: %s must have column %s "+ - "(use join_fk:custom_column tag on %s field to specify custom column)", - joinTable.TypeName, field.GoName, m2mTable.TypeName, joinFKPrefix, field.GoName, - )) - } - joinFKs = []string{joinFKPrefix} - } else { - for _, joinPK := range joinTable.PKs { - fkName := joinFKPrefix + joinPK.SQLName - if m2mTable.getField(fkName) != nil { - joinFKs = append(joinFKs, fkName) - continue - } - - if m2mTable.getField(joinPK.SQLName) != nil { - joinFKs = append(joinFKs, joinPK.SQLName) - continue - } - - panic(fmt.Errorf( - "pg: %s many2many %s: %s must have column %s "+ - "(use join_fk:custom_column tag on %s field to specify custom column)", - t.TypeName, field.GoName, m2mTable.TypeName, fkName, field.GoName, - )) - } - } - } - - t.addRelation(&Relation{ - Type: Many2ManyRelation, - Field: field, - JoinTable: joinTable, - M2MTableName: m2mTableName, - M2MTableAlias: m2mTable.Alias, - M2MBaseFKs: baseFKs, - M2MJoinFKs: joinFKs, - }) - return true -} - -//nolint -func (t *Table) tryRelationSlice(field *Field, pgTag *tagparser.Tag) bool { - if t.tryM2MRelation(field, pgTag) { - internal.Deprecated.Printf( - `add pg:"rel:many2many" to %s.%s field tag`, t.TypeName, field.GoName) - return true - } - if t.tryHasManyRelation(field, pgTag) { - internal.Deprecated.Printf( - `add pg:"rel:has-many" to %s.%s field tag`, t.TypeName, field.GoName) - return true - } - return false -} - -func (t *Table) tryM2MRelation(field *Field, pgTag *tagparser.Tag) bool { - elemType := indirectType(field.Type.Elem()) - if elemType.Kind() != reflect.Struct { - return false - } - - joinTable := _tables.get(elemType, true) - - fk, fkOK := pgTag.Options["fk"] - if fkOK { - if fk == "-" { - return false - } - fk = tryUnderscorePrefix(fk) - } - - m2mTableName := pgTag.Options["many2many"] - if m2mTableName == "" { - return false - } - - m2mTable := _tables.getByName(quoteIdent(m2mTableName)) - - var m2mTableAlias types.Safe - if m2mTable != nil { - m2mTableAlias = m2mTable.Alias - } else if ind := strings.IndexByte(m2mTableName, '.'); ind >= 0 { - m2mTableAlias = quoteIdent(m2mTableName[ind+1:]) - } else { - m2mTableAlias = quoteIdent(m2mTableName) - } - - var fks []string - if !fkOK { - fk = t.ModelName + "_" - } - if m2mTable != nil { - keys := foreignKeys(t, m2mTable, fk, fkOK) - if len(keys) == 0 { - return false - } - for _, fk := range keys { - fks = append(fks, fk.SQLName) - } - } else { - if fkOK && len(t.PKs) == 1 { - fks = append(fks, fk) - } else { - for _, pk := range t.PKs { - fks = append(fks, fk+pk.SQLName) - } - } - } - - joinFK, joinFKOk := pgTag.Options["join_fk"] - if !joinFKOk { - joinFK, joinFKOk = pgTag.Options["joinFK"] - if joinFKOk { - internal.Deprecated.Printf("joinFK is renamed to join_fk") - } - } - if joinFKOk { - joinFK = tryUnderscorePrefix(joinFK) - } else { - joinFK = joinTable.ModelName + "_" - } - - var joinFKs []string - if m2mTable != nil { - keys := foreignKeys(joinTable, m2mTable, joinFK, joinFKOk) - if len(keys) == 0 { - return false - } - for _, fk := range keys { - joinFKs = append(joinFKs, fk.SQLName) - } - } else { - if joinFKOk && len(joinTable.PKs) == 1 { - joinFKs = append(joinFKs, joinFK) - } else { - for _, pk := range joinTable.PKs { - joinFKs = append(joinFKs, joinFK+pk.SQLName) - } - } - } - - t.addRelation(&Relation{ - Type: Many2ManyRelation, - Field: field, - JoinTable: joinTable, - M2MTableName: quoteIdent(m2mTableName), - M2MTableAlias: m2mTableAlias, - M2MBaseFKs: fks, - M2MJoinFKs: joinFKs, - }) - return true -} - -func (t *Table) tryHasManyRelation(field *Field, pgTag *tagparser.Tag) bool { - elemType := indirectType(field.Type.Elem()) - if elemType.Kind() != reflect.Struct { - return false - } - - joinTable := _tables.get(elemType, true) - - fk, fkOK := pgTag.Options["fk"] - if fkOK { - if fk == "-" { - return false - } - fk = tryUnderscorePrefix(fk) - } - - s, polymorphic := pgTag.Options["polymorphic"] - var typeField *Field - if polymorphic { - fk = tryUnderscorePrefix(s) - - typeField = joinTable.getField(fk + "type") - if typeField == nil { - return false - } - } else if !fkOK { - fk = t.ModelName + "_" - } - - fks := foreignKeys(t, joinTable, fk, fkOK || polymorphic) - if len(fks) == 0 { - return false - } - - var fkValues []*Field - fkValue, ok := pgTag.Options["fk_value"] - if ok { - if len(fks) > 1 { - panic(fmt.Errorf("got fk_value, but there are %d fks", len(fks))) - } - - f := t.getField(fkValue) - if f == nil { - panic(fmt.Errorf("fk_value=%q not found in %s", fkValue, t)) - } - fkValues = append(fkValues, f) - } else { - fkValues = t.PKs - } - - if len(fks) != len(fkValues) { - panic("len(fks) != len(fkValues)") - } - - if len(fks) > 0 { - t.addRelation(&Relation{ - Type: HasManyRelation, - Field: field, - JoinTable: joinTable, - BaseFKs: fkValues, - JoinFKs: fks, - Polymorphic: typeField, - }) - return true - } - - return false -} - -func (t *Table) tryRelationStruct(field *Field, pgTag *tagparser.Tag) bool { - joinTable := _tables.get(field.Type, true) - - if len(joinTable.allFields) == 0 { - return false - } - - if t.tryHasOne(joinTable, field, pgTag) { - internal.Deprecated.Printf( - `add pg:"rel:has-one" to %s.%s field tag`, t.TypeName, field.GoName) - t.inlineFields(field, nil) - return true - } - - if t.tryBelongsToOne(joinTable, field, pgTag) { - internal.Deprecated.Printf( - `add pg:"rel:belongs-to" to %s.%s field tag`, t.TypeName, field.GoName) - t.inlineFields(field, nil) - return true - } - - t.inlineFields(field, nil) - return false -} - -func (t *Table) inlineFields(strct *Field, path map[reflect.Type]struct{}) { - if path == nil { - path = map[reflect.Type]struct{}{ - t.Type: {}, - } - } - - if _, ok := path[strct.Type]; ok { - return - } - path[strct.Type] = struct{}{} - - joinTable := _tables.get(strct.Type, true) - for _, f := range joinTable.allFields { - f = f.Clone() - f.GoName = strct.GoName + "_" + f.GoName - f.SQLName = strct.SQLName + "__" + f.SQLName - f.Column = quoteIdent(f.SQLName) - f.Index = appendNew(strct.Index, f.Index...) - - t.fieldsMapMu.Lock() - if _, ok := t.FieldsMap[f.SQLName]; !ok { - t.FieldsMap[f.SQLName] = f - } - t.fieldsMapMu.Unlock() - - if f.Type.Kind() != reflect.Struct { - continue - } - - if _, ok := path[f.Type]; !ok { - t.inlineFields(f, path) - } - } -} - -func appendNew(dst []int, src ...int) []int { - cp := make([]int, len(dst)+len(src)) - copy(cp, dst) - copy(cp[len(dst):], src) - return cp -} - -func isScanner(typ reflect.Type) bool { - return typ.Implements(scannerType) || reflect.PtrTo(typ).Implements(scannerType) -} - -func fieldSQLType(field *Field, pgTag *tagparser.Tag) string { - if typ, ok := pgTag.Options["type"]; ok { - typ, _ = tagparser.Unquote(typ) - field.UserSQLType = typ - typ = normalizeSQLType(typ) - return typ - } - - if typ, ok := pgTag.Options["composite"]; ok { - typ, _ = tagparser.Unquote(typ) - return typ - } - - if _, ok := pgTag.Options["hstore"]; ok { - return "hstore" - } else if _, ok := pgTag.Options["hstore"]; ok { - return "hstore" - } - - if field.hasFlag(ArrayFlag) { - switch field.Type.Kind() { - case reflect.Slice, reflect.Array: - sqlType := sqlType(field.Type.Elem()) - return sqlType + "[]" - } - } - - sqlType := sqlType(field.Type) - return sqlType -} - -func sqlType(typ reflect.Type) string { - switch typ { - case timeType, nullTimeType, sqlNullTimeType: - return pgTypeTimestampTz - case ipType: - return pgTypeInet - case ipNetType: - return pgTypeCidr - case nullBoolType: - return pgTypeBoolean - case nullFloatType: - return pgTypeDoublePrecision - case nullIntType: - return pgTypeBigint - case nullStringType: - return pgTypeText - case jsonRawMessageType: - return pgTypeJSONB - } - - switch typ.Kind() { - case reflect.Int8, reflect.Uint8, reflect.Int16: - return pgTypeSmallint - case reflect.Uint16, reflect.Int32: - return pgTypeInteger - case reflect.Uint32, reflect.Int64, reflect.Int: - return pgTypeBigint - case reflect.Uint, reflect.Uint64: - // Unsigned bigint is not supported - use bigint. - return pgTypeBigint - case reflect.Float32: - return pgTypeReal - case reflect.Float64: - return pgTypeDoublePrecision - case reflect.Bool: - return pgTypeBoolean - case reflect.String: - return pgTypeText - case reflect.Map, reflect.Struct: - return pgTypeJSONB - case reflect.Array, reflect.Slice: - if typ.Elem().Kind() == reflect.Uint8 { - return pgTypeBytea - } - return pgTypeJSONB - default: - return typ.Kind().String() - } -} - -func normalizeSQLType(s string) string { - switch s { - case "int2": - return pgTypeSmallint - case "int4", "int", "serial": - return pgTypeInteger - case "int8", pgTypeBigserial: - return pgTypeBigint - case "float4": - return pgTypeReal - case "float8": - return pgTypeDoublePrecision - } - return s -} - -func sqlTypeEqual(a, b string) bool { - return a == b -} - -func (t *Table) tryHasOne(joinTable *Table, field *Field, pgTag *tagparser.Tag) bool { - fk, fkOK := pgTag.Options["fk"] - if fkOK { - if fk == "-" { - return false - } - fk = tryUnderscorePrefix(fk) - } else { - fk = internal.Underscore(field.GoName) + "_" - } - - fks := foreignKeys(joinTable, t, fk, fkOK) - if len(fks) > 0 { - t.addRelation(&Relation{ - Type: HasOneRelation, - Field: field, - JoinTable: joinTable, - BaseFKs: fks, - JoinFKs: joinTable.PKs, - }) - return true - } - return false -} - -func (t *Table) tryBelongsToOne(joinTable *Table, field *Field, pgTag *tagparser.Tag) bool { - fk, fkOK := pgTag.Options["fk"] - if fkOK { - if fk == "-" { - return false - } - fk = tryUnderscorePrefix(fk) - } else { - fk = internal.Underscore(t.TypeName) + "_" - } - - fks := foreignKeys(t, joinTable, fk, fkOK) - if len(fks) > 0 { - t.addRelation(&Relation{ - Type: BelongsToRelation, - Field: field, - JoinTable: joinTable, - BaseFKs: t.PKs, - JoinFKs: fks, - }) - return true - } - return false -} - -func (t *Table) addRelation(rel *Relation) { - if t.Relations == nil { - t.Relations = make(map[string]*Relation) - } - _, ok := t.Relations[rel.Field.GoName] - if ok { - panic(fmt.Errorf("%s already has %s", t, rel)) - } - t.Relations[rel.Field.GoName] = rel -} - -func foreignKeys(base, join *Table, fk string, tryFK bool) []*Field { - var fks []*Field - - for _, pk := range base.PKs { - fkName := fk + pk.SQLName - f := join.getField(fkName) - if f != nil && sqlTypeEqual(pk.SQLType, f.SQLType) { - fks = append(fks, f) - continue - } - - if strings.IndexByte(pk.SQLName, '_') == -1 { - continue - } - - f = join.getField(pk.SQLName) - if f != nil && sqlTypeEqual(pk.SQLType, f.SQLType) { - fks = append(fks, f) - continue - } - } - if len(fks) > 0 && len(fks) == len(base.PKs) { - return fks - } - - fks = nil - for _, pk := range base.PKs { - if !strings.HasPrefix(pk.SQLName, "pk_") { - continue - } - fkName := "fk_" + pk.SQLName[3:] - f := join.getField(fkName) - if f != nil && sqlTypeEqual(pk.SQLType, f.SQLType) { - fks = append(fks, f) - } - } - if len(fks) > 0 && len(fks) == len(base.PKs) { - return fks - } - - if fk == "" || len(base.PKs) != 1 { - return nil - } - - if tryFK { - f := join.getField(fk) - if f != nil && sqlTypeEqual(base.PKs[0].SQLType, f.SQLType) { - return []*Field{f} - } - } - - for _, suffix := range []string{"id", "uuid"} { - f := join.getField(fk + suffix) - if f != nil && sqlTypeEqual(base.PKs[0].SQLType, f.SQLType) { - return []*Field{f} - } - } - - return nil -} - -func scanJSONValue(v reflect.Value, rd types.Reader, n int) error { - // Zero value so it works with SelectOrInsert. - // TODO: better handle slices - v.Set(reflect.New(v.Type()).Elem()) - - if n == -1 { - return nil - } - - dec := pgjson.NewDecoder(rd) - dec.UseNumber() - return dec.Decode(v.Addr().Interface()) -} - -func appendUintAsInt(b []byte, v reflect.Value, _ int) []byte { - return strconv.AppendInt(b, int64(v.Uint()), 10) -} - -func appendUintPtrAsInt(b []byte, v reflect.Value, _ int) []byte { - return strconv.AppendInt(b, int64(v.Elem().Uint()), 10) -} - -func tryUnderscorePrefix(s string) string { - if s == "" { - return s - } - if c := s[0]; internal.IsUpper(c) { - return internal.Underscore(s) + "_" - } - return s -} - -func quoteTableName(s string) types.Safe { - // Don't quote if table name contains placeholder (?) or parentheses. - if strings.IndexByte(s, '?') >= 0 || - strings.IndexByte(s, '(') >= 0 && strings.IndexByte(s, ')') >= 0 { - return types.Safe(s) - } - return quoteIdent(s) -} - -func quoteIdent(s string) types.Safe { - return types.Safe(types.AppendIdent(nil, s, 1)) -} - -func setSoftDeleteFieldFunc(typ reflect.Type) func(fv reflect.Value) error { - switch typ { - case timeType: - return func(fv reflect.Value) error { - ptr := fv.Addr().Interface().(*time.Time) - *ptr = time.Now() - return nil - } - case nullTimeType: - return func(fv reflect.Value) error { - ptr := fv.Addr().Interface().(*types.NullTime) - *ptr = types.NullTime{Time: time.Now()} - return nil - } - case nullIntType: - return func(fv reflect.Value) error { - ptr := fv.Addr().Interface().(*sql.NullInt64) - *ptr = sql.NullInt64{Int64: time.Now().UnixNano()} - return nil - } - } - - switch typ.Kind() { - case reflect.Int64: - return func(fv reflect.Value) error { - ptr := fv.Addr().Interface().(*int64) - *ptr = time.Now().UnixNano() - return nil - } - case reflect.Ptr: - break - default: - return setSoftDeleteFallbackFunc(typ) - } - - originalType := typ - typ = typ.Elem() - - switch typ { //nolint:gocritic - case timeType: - return func(fv reflect.Value) error { - now := time.Now() - fv.Set(reflect.ValueOf(&now)) - return nil - } - } - - switch typ.Kind() { //nolint:gocritic - case reflect.Int64: - return func(fv reflect.Value) error { - utime := time.Now().UnixNano() - fv.Set(reflect.ValueOf(&utime)) - return nil - } - } - - return setSoftDeleteFallbackFunc(originalType) -} - -func setSoftDeleteFallbackFunc(typ reflect.Type) func(fv reflect.Value) error { - scanner := types.Scanner(typ) - if scanner == nil { - return nil - } - - return func(fv reflect.Value) error { - var flags int - b := types.AppendTime(nil, time.Now(), flags) - return scanner(fv, pool.NewBytesReader(b), len(b)) - } -} - -func isKnownTableOption(name string) bool { - switch name { - case "alias", - "select", - "tablespace", - "partition_by", - "discard_unknown_columns": - return true - } - return false -} - -func isKnownFieldOption(name string) bool { - switch name { - case "alias", - "type", - "array", - "hstore", - "composite", - "json_use_number", - "msgpack", - "notnull", - "use_zero", - "default", - "unique", - "soft_delete", - "on_delete", - "on_update", - - "pk", - "nopk", - "rel", - "fk", - "join_fk", - "many2many", - "polymorphic": - return true - } - return false -} |