Go-Back-Skeleton/vendor/github.com/go-pg/pg/orm/table.go
2017-09-25 20:20:52 +02:00

568 lines
12 KiB
Go

package orm
import (
"bytes"
"database/sql"
"encoding/json"
"fmt"
"net"
"reflect"
"strings"
"time"
"github.com/go-pg/pg/internal"
"github.com/go-pg/pg/types"
)
// reflect.Type values for the types that sqlType (and isColumn) special-case.
var (
	timeType  = reflect.TypeOf(time.Time{})
	ipType    = reflect.TypeOf(net.IP(nil))
	ipNetType = reflect.TypeOf(net.IPNet{})

	// sql.Scanner is an interface, so its type must be taken through a
	// nil pointer and Elem.
	scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()

	nullBoolType   = reflect.TypeOf(sql.NullBool{})
	nullFloatType  = reflect.TypeOf(sql.NullFloat64{})
	nullIntType    = reflect.TypeOf(sql.NullInt64{})
	nullStringType = reflect.TypeOf(sql.NullString{})
)
// Table describes how a Go struct type maps to a database table.
// Tables are built and cached by newTable.
type Table struct {
// Type is the struct type the table was built from.
Type reflect.Type
// zeroStruct is reflect.Zero(Type).
zeroStruct reflect.Value
// TypeName is the exported form of Type's name. It is also the default
// foreign-key prefix when resolving relations.
TypeName string
// Name is the quoted table name. It defaults to tableNameInflector(ModelName)
// and can be overridden by a tableName/TableName field's sql tag.
Name types.Q
// Alias is the quoted table alias. It defaults to ModelName and can be
// overridden by the "alias" option of the tableName field's sql tag.
Alias types.Q
// ModelName is the underscored form of Type's name.
ModelName string
// Fields holds every field added through AddField, in insertion order.
Fields []*Field
// PKs holds the fields of Fields that carry PrimaryKeyFlag.
PKs []*Field
// Columns holds the fields of Fields that do not carry PrimaryKeyFlag.
Columns []*Field
// FieldsMap indexes fields by SQL name. Besides the entries of Fields, it
// also holds some fields that are not in Fields (skipped "-" fields,
// struct relation fields, and flattened "parent__child" columns).
FieldsMap map[string]*Field
// Methods holds the exported methods of *Type that take no arguments and
// return exactly one value, keyed by method name. AppendParam uses them.
Methods map[string]*Method
// Relations holds the detected relations, keyed by the relation field's
// Go name.
Relations map[string]*Relation
// flags is a bit set managed by SetFlag/HasFlag (e.g. hook flags).
flags uint8
}
// SetFlag turns on every bit of flag in t's flag set.
func (t *Table) SetFlag(flag uint8) {
	t.flags = t.flags | flag
}
// HasFlag reports whether any bit of flag is set on t. It is safe to call
// on a nil *Table, which reports no flags.
func (t *Table) HasFlag(flag uint8) bool {
	return t != nil && t.flags&flag != 0
}
// HasField reports whether t has a field with the given SQL name, as
// resolved by GetField.
func (t *Table) HasField(field string) bool {
	if _, err := t.GetField(field); err != nil {
		return false
	}
	return true
}
// checkPKs returns an error when t has no primary-key fields.
func (t *Table) checkPKs() error {
	if len(t.PKs) > 0 {
		return nil
	}
	return fmt.Errorf("model=%s does not have primary keys", t.Type.Name())
}
// AddField registers field with t. The field is appended to Fields, then to
// PKs when it carries PrimaryKeyFlag (otherwise to Columns), and it is
// indexed in FieldsMap by SQL name, replacing any previous entry.
func (t *Table) AddField(field *Field) {
	t.Fields = append(t.Fields, field)

	bucket := &t.Columns
	if field.HasFlag(PrimaryKeyFlag) {
		bucket = &t.PKs
	}
	*bucket = append(*bucket, field)

	t.FieldsMap[field.SQLName] = field
}
// GetField looks up a field by its SQL name in t.FieldsMap. It returns an
// error naming the column and table when no such field exists.
func (t *Table) GetField(fieldName string) (*Field, error) {
	if field, found := t.FieldsMap[fieldName]; found {
		return field, nil
	}
	return nil, fmt.Errorf("can't find column=%s in table=%s", fieldName, t.Name)
}
// AppendParam appends the value of the parameter called name, read from the
// struct value strct, to b.
//
// name is resolved first as an SQL field name in t.FieldsMap, then as a
// method name in t.Methods. A method is invoked through strct.Addr(), so
// strct must be addressable in that case. The boolean result reports whether
// name was found; when it is false, b is returned unchanged.
func (t *Table) AppendParam(b []byte, strct reflect.Value, name string) ([]byte, bool) {
	field, isField := t.FieldsMap[name]
	if isField {
		return field.AppendValue(b, strct, 1), true
	}

	method, isMethod := t.Methods[name]
	if isMethod {
		return method.AppendValue(b, strct.Addr(), 1), true
	}

	return b, false
}
// addRelation stores rel in t.Relations under the Go name of the relation's
// field. The map is allocated lazily on first use.
func (t *Table) addRelation(rel *Relation) {
	relations := t.Relations
	if relations == nil {
		relations = make(map[string]*Relation)
		t.Relations = relations
	}
	relations[rel.Field.GoName] = rel
}
// newTable returns the *Table for struct type typ. It builds the table and
// caches it in Tables.tables on first use.
//
// The table is cached before its fields are added. A struct that refers back
// to itself (addFields -> newField -> newTable) therefore receives the
// cached, partially built table instead of recursing forever. After the
// fields are added, hook flags and usable methods are collected from the
// pointer type *typ.
func newTable(typ reflect.Type) *Table {
	if cached, ok := Tables.tables[typ]; ok {
		return cached
	}

	modelName := internal.Underscore(typ.Name())
	numField := typ.NumField()
	table := &Table{
		Type:       typ,
		zeroStruct: reflect.Zero(typ),
		TypeName:   internal.ToExported(typ.Name()),
		Name:       types.Q(types.AppendField(nil, tableNameInflector(modelName), 1)),
		Alias:      types.Q(types.AppendField(nil, modelName, 1)),
		ModelName:  modelName,
		Fields:     make([]*Field, 0, numField),
		FieldsMap:  make(map[string]*Field, numField),
	}
	Tables.tables[typ] = table

	table.addFields(typ, nil)

	ptrType := reflect.PtrTo(typ)

	// Each hook interface implemented by *typ turns on its flag.
	hooks := []struct {
		iface reflect.Type
		flag  uint8
	}{
		{afterQueryHookType, AfterQueryHookFlag},
		{afterSelectHookType, AfterSelectHookFlag},
		{beforeInsertHookType, BeforeInsertHookFlag},
		{afterInsertHookType, AfterInsertHookFlag},
		{beforeUpdateHookType, BeforeUpdateHookFlag},
		{afterUpdateHookType, AfterUpdateHookFlag},
		{beforeDeleteHookType, BeforeDeleteHookFlag},
		{afterDeleteHookType, AfterDeleteHookFlag},
	}
	for _, h := range hooks {
		if ptrType.Implements(h.iface) {
			table.SetFlag(h.flag)
		}
	}

	if table.Methods == nil {
		table.Methods = make(map[string]*Method)
	}
	for i, n := 0, ptrType.NumMethod(); i < n; i++ {
		m := ptrType.Method(i)
		// Keep only exported methods that take no argument besides the
		// receiver and return exactly one value.
		if m.PkgPath != "" || m.Type.NumIn() > 1 || m.Type.NumOut() != 1 {
			continue
		}
		table.Methods[m.Name] = &Method{
			Index:    m.Index,
			appender: types.Appender(m.Type.Out(0)),
		}
	}

	return table
}
// addFields adds the fields of struct type typ to t. baseIndex is the reflect
// index path of typ inside t.Type; it is nil for t.Type itself.
//
// Anonymous (embedded) fields are flattened by recursing into the embedded
// table's type. When the embedded field's pg tag has the "override" option,
// t also takes over the embedded table's TypeName, Name, Alias and
// ModelName.
func (t *Table) addFields(typ reflect.Type, baseIndex []int) {
	for i, n := 0, typ.NumField(); i < n; i++ {
		sf := typ.Field(i)

		// Give each field its own copy of baseIndex so that appends never
		// write into a slice shared with sibling fields.
		index := append([]int(nil), baseIndex...)

		if !sf.Anonymous {
			if field := t.newField(sf, index); field != nil {
				t.AddField(field)
			}
			continue
		}

		embedded := newTable(indirectType(sf.Type))
		pgTag := parseTag(sf.Tag.Get("pg"))
		if _, override := pgTag.Options["override"]; override {
			t.TypeName = embedded.TypeName
			t.Name = embedded.Name
			t.Alias = embedded.Alias
			t.ModelName = embedded.ModelName
		}
		t.addFields(embedded.Type, append(index, sf.Index...))
	}
}
// getField returns t's field with Go name name.
//
// Fields already in t.Fields are returned directly. Otherwise the field is
// looked up on t.Type and built with newField, which does not append it to
// t.Fields. It returns nil when t.Type has no such field or newField rejects
// it.
func (t *Table) getField(name string) *Field {
	for i := range t.Fields {
		if t.Fields[i].GoName == name {
			return t.Fields[i]
		}
	}

	sf, found := t.Type.FieldByName(name)
	if !found {
		return nil
	}
	return t.newField(sf, nil)
}
// newField builds the *Field describing struct field f of t. index is the
// reflect index path of the struct that declares f; it is nil when f is
// declared directly on t.Type.
//
// Besides building the field, newField may configure t itself:
//   - a "tableName"/"TableName" field declared directly on t.Type sets t.Name
//     and, through the "alias" option, t.Alias. It is never a column.
//   - a slice-of-struct field may become a has-many or many2many relation,
//     and a struct field a has-one or belongs-to relation. This is driven by
//     the pg tag options "fk", "joinFK", "many2many" and "polymorphic".
//   - the fields of a nested struct's table are also indexed in t.FieldsMap
//     as "<field>__<column>".
//
// It returns nil when the field must not be appended to t.Fields. That
// covers the tableName marker, unexported fields, fields turned into
// relations, and fields whose sql tag name is "-" (those are still recorded
// in t.FieldsMap). When t.FieldsMap already holds a field with the same SQL
// name and Go name, that existing field is returned.
func (t *Table) newField(f reflect.StructField, index []int) *Field {
	sqlTag := parseTag(f.Tag.Get("sql"))

	switch f.Name {
	case "tableName", "TableName":
		// Only a marker declared directly on the model configures the
		// table. Markers reached through embedded structs are ignored.
		if index != nil {
			return nil
		}
		if sqlTag.Name != "" {
			if isPostgresKeyword(sqlTag.Name) {
				sqlTag.Name = `"` + sqlTag.Name + `"`
			}
			t.Name = types.Q(sqlTag.Name)
		}
		if alias, ok := sqlTag.Options["alias"]; ok {
			t.Alias = types.Q(alias)
		}
		return nil
	}

	// Unexported fields are never mapped.
	if f.PkgPath != "" {
		return nil
	}

	skip := sqlTag.Name == "-"
	if skip || sqlTag.Name == "" {
		sqlTag.Name = internal.Underscore(f.Name)
	}

	if field, ok := t.FieldsMap[sqlTag.Name]; ok {
		if field.GoName == f.Name {
			return field
		}
	}

	// Build the index path in a fresh, exactly sized slice. The field then
	// never shares a backing array with the caller's index, and later
	// appends to it (see the struct case below) cannot write in place.
	fieldIndex := make([]int, 0, len(index)+len(f.Index))
	fieldIndex = append(fieldIndex, index...)
	fieldIndex = append(fieldIndex, f.Index...)

	field := Field{
		Type:    indirectType(f.Type),
		GoName:  f.Name,
		SQLName: sqlTag.Name,
		Column:  types.Q(types.AppendField(nil, sqlTag.Name, 1)),
		Index:   fieldIndex,
	}

	if _, ok := sqlTag.Options["notnull"]; ok {
		field.SetFlag(NotNullFlag)
	}
	if _, ok := sqlTag.Options["unique"]; ok {
		field.SetFlag(UniqueFlag)
	}
	if v, ok := sqlTag.Options["default"]; ok {
		// When unquote succeeds, the value is re-encoded with
		// types.AppendString. Otherwise it is used verbatim.
		v, ok = unquote(v)
		if ok {
			field.Default = types.Q(types.AppendString(nil, v, 1))
		} else {
			field.Default = types.Q(v)
		}
	}

	// An "id"/"uuid" column becomes the primary key while t has none yet.
	// "*_id"/"*_uuid" columns are marked as foreign keys.
	if len(t.PKs) == 0 && (field.SQLName == "id" || field.SQLName == "uuid") {
		field.SetFlag(PrimaryKeyFlag)
	} else if _, ok := sqlTag.Options["pk"]; ok {
		field.SetFlag(PrimaryKeyFlag)
	} else if strings.HasSuffix(field.SQLName, "_id") ||
		strings.HasSuffix(field.SQLName, "_uuid") {
		field.SetFlag(ForeignKeyFlag)
	}

	pgTag := parseTag(f.Tag.Get("pg"))
	if _, ok := pgTag.Options["array"]; ok {
		field.SetFlag(ArrayFlag)
	}

	field.SQLType = fieldSQLType(&field, sqlTag)
	if strings.HasSuffix(field.SQLType, "[]") {
		field.SetFlag(ArrayFlag)
	}

	if _, ok := pgTag.Options["json_use_number"]; ok {
		field.append = types.Appender(f.Type)
		field.scan = scanJSONValue
	} else if field.HasFlag(ArrayFlag) {
		field.append = types.ArrayAppender(f.Type)
		field.scan = types.ArrayScanner(f.Type)
	} else if _, ok := pgTag.Options["hstore"]; ok {
		field.append = types.HstoreAppender(f.Type)
		field.scan = types.HstoreScanner(f.Type)
	} else {
		field.append = types.Appender(f.Type)
		field.scan = types.Scanner(f.Type)
	}
	field.isZero = isZeroFunc(f.Type)

	// Types implementing sql.Scanner are always plain columns.
	if !skip && isColumn(f.Type) {
		return &field
	}

	switch field.Type.Kind() {
	case reflect.Slice:
		elemType := indirectType(field.Type.Elem())
		if elemType.Kind() != reflect.Struct {
			break
		}

		joinTable := newTable(elemType)

		fk, ok := pgTag.Options["fk"]
		if !ok {
			fk = t.TypeName
		}

		if m2mTable := pgTag.Options["many2many"]; m2mTable != "" {
			// The alias drops everything up to the first '.', e.g. a
			// schema prefix.
			m2mTableAlias := m2mTable
			if ind := strings.IndexByte(m2mTable, '.'); ind >= 0 {
				m2mTableAlias = m2mTable[ind+1:]
			}

			joinFK, ok := pgTag.Options["joinFK"]
			if !ok {
				joinFK = joinTable.TypeName
			}

			t.addRelation(&Relation{
				Type:          Many2ManyRelation,
				Field:         &field,
				JoinTable:     joinTable,
				M2MTableName:  types.Q(m2mTable),
				M2MTableAlias: types.Q(m2mTableAlias),
				BasePrefix:    internal.Underscore(fk + "_"),
				JoinPrefix:    internal.Underscore(joinFK + "_"),
			})
			return nil
		}

		s, polymorphic := pgTag.Options["polymorphic"]
		if polymorphic {
			fk = s
		}

		fks := foreignKeys(t, joinTable, fk, t.TypeName)
		if len(fks) > 0 {
			t.addRelation(&Relation{
				Type:        HasManyRelation,
				Polymorphic: polymorphic,
				Field:       &field,
				FKs:         fks,
				JoinTable:   joinTable,
				BasePrefix:  internal.Underscore(fk + "_"),
			})
			return nil
		}
	case reflect.Struct:
		joinTable := newTable(field.Type)
		if len(joinTable.Fields) == 0 {
			break
		}

		for _, ff := range joinTable.FieldsMap {
			ff = ff.Copy()
			ff.SQLName = field.SQLName + "__" + ff.SQLName
			ff.Column = types.Q(types.AppendField(nil, ff.SQLName, 1))

			// Allocate a fresh index for each nested field. Appending to
			// field.Index directly may reuse its spare capacity. Every
			// flattened field would then share one backing array and end
			// up with the index of the last one written.
			nestedIndex := make([]int, 0, len(field.Index)+len(ff.Index))
			nestedIndex = append(nestedIndex, field.Index...)
			ff.Index = append(nestedIndex, ff.Index...)

			if _, ok := t.FieldsMap[ff.SQLName]; !ok {
				t.FieldsMap[ff.SQLName] = ff
			}
		}

		if t.tryHasOne(joinTable, &field, pgTag) ||
			t.tryBelongsToOne(joinTable, &field, pgTag) {
			t.FieldsMap[field.SQLName] = &field
			return nil
		}
	}

	if skip {
		t.FieldsMap[field.SQLName] = &field
		return nil
	}
	return &field
}
// isPostgresKeyword reports whether s is one of the reserved words newField
// quotes when used as a table name. Currently only "user" is listed, and the
// match is case-sensitive.
func isPostgresKeyword(s string) bool {
	return s == "user"
}
// isColumn reports whether typ, or a pointer to typ, implements sql.Scanner.
// newField always treats such types as plain columns, never as relations.
func isColumn(typ reflect.Type) bool {
	if typ.Implements(scannerType) {
		return true
	}
	return reflect.PtrTo(typ).Implements(scannerType)
}
// fieldSQLType returns the PostgreSQL column type for field.
//
// An explicit sql "type" option takes precedence. It is unquoted, and the
// field gets customTypeFlag. Otherwise:
//   - array fields map to "<element type>[]";
//   - primary keys whose type maps to smallint/integer/bigint become
//     smallserial/serial/bigserial;
//   - all other fields use sqlType directly.
func fieldSQLType(field *Field, sqlTag *tag) string {
	if custom, ok := sqlTag.Options["type"]; ok {
		field.SetFlag(customTypeFlag)
		custom, _ = unquote(custom)
		return custom
	}

	if field.HasFlag(ArrayFlag) {
		return sqlType(field.Type.Elem()) + "[]"
	}

	pgType := sqlType(field.Type)
	if !field.HasFlag(PrimaryKeyFlag) {
		return pgType
	}

	switch pgType {
	case "smallint":
		return "smallserial"
	case "integer":
		return "serial"
	case "bigint":
		return "bigserial"
	default:
		return pgType
	}
}
// sqlType maps a Go type to the PostgreSQL type used for its column.
//
// time.Time, net.IP, net.IPNet and the database/sql Null* wrappers have
// dedicated mappings. Everything else is mapped by reflect.Kind:
//   - unsigned integers get a signed type wide enough to hold them (uint and
//     uint64 become decimal);
//   - maps and structs become jsonb;
//   - byte slices/arrays become bytea, other slices/arrays jsonb;
//   - unhandled kinds fall back to the kind's name.
func sqlType(typ reflect.Type) string {
	switch {
	case typ == timeType:
		return "timestamptz"
	case typ == ipType:
		return "inet"
	case typ == ipNetType:
		return "cidr"
	case typ == nullBoolType:
		return "boolean"
	case typ == nullFloatType:
		return "double precision"
	case typ == nullIntType:
		return "bigint"
	case typ == nullStringType:
		return "text"
	}

	kind := typ.Kind()
	switch kind {
	case reflect.Int8, reflect.Int16, reflect.Uint8:
		return "smallint"
	case reflect.Int32, reflect.Uint16:
		return "integer"
	case reflect.Int, reflect.Int64, reflect.Uint32:
		return "bigint"
	case reflect.Uint, reflect.Uint64:
		return "decimal"
	case reflect.Float32:
		return "real"
	case reflect.Float64:
		return "double precision"
	case reflect.Bool:
		return "boolean"
	case reflect.String:
		return "text"
	case reflect.Map, reflect.Struct:
		return "jsonb"
	case reflect.Array, reflect.Slice:
		if typ.Elem().Kind() == reflect.Uint8 {
			return "bytea"
		}
		return "jsonb"
	}
	return kind.String()
}
// foreignKeys finds the fields of join that reference base's primary keys.
//
// For each primary key pk of base, it looks for a field of join whose Go
// name is fk+pk.GoName. If none is found and fk is non-empty and differs
// from fieldName, it falls back to a field named exactly fk. The result is
// nil when nothing matches.
func foreignKeys(base, join *Table, fk, fieldName string) []*Field {
	var fks []*Field
	for _, pk := range base.PKs {
		f := join.getField(fk + pk.GoName)
		if f == nil {
			continue
		}
		fks = append(fks, f)
	}

	if len(fks) == 0 && fk != "" && fk != fieldName {
		if f := join.getField(fk); f != nil {
			fks = append(fks, f)
		}
	}
	return fks
}
// tryHasOne registers field as a has-one relation of t when t has fields
// that reference joinTable's primary keys. The key-name prefix comes from
// the "fk" tag option and defaults to the field's Go name. It reports
// whether the relation was added.
func (t *Table) tryHasOne(joinTable *Table, field *Field, tag *tag) bool {
	fk := field.GoName
	if custom, ok := tag.Options["fk"]; ok {
		fk = custom
	}

	fks := foreignKeys(joinTable, t, fk, field.GoName)
	if len(fks) == 0 {
		return false
	}

	t.addRelation(&Relation{
		Type:      HasOneRelation,
		Field:     field,
		FKs:       fks,
		JoinTable: joinTable,
	})
	return true
}
// tryBelongsToOne registers field as a belongs-to relation of t when
// joinTable has fields that reference t's primary keys. The key-name prefix
// comes from the "fk" tag option and defaults to t.TypeName. It reports
// whether the relation was added.
func (t *Table) tryBelongsToOne(joinTable *Table, field *Field, tag *tag) bool {
	fk := t.TypeName
	if custom, ok := tag.Options["fk"]; ok {
		fk = custom
	}

	fks := foreignKeys(t, joinTable, fk, t.TypeName)
	if len(fks) == 0 {
		return false
	}

	t.addRelation(&Relation{
		Type:      BelongsToRelation,
		Field:     field,
		FKs:       fks,
		JoinTable: joinTable,
	})
	return true
}
// scanJSONValue decodes the JSON in b into v with json.Decoder.UseNumber
// enabled, so numbers decoded into interface values become json.Number
// rather than float64.
//
// A nil b resets v to its zero value (presumably a database NULL). v must be
// settable; otherwise an error is returned.
func scanJSONValue(v reflect.Value, b []byte) error {
	switch {
	case !v.CanSet():
		return fmt.Errorf("pg: Scan(non-pointer %s)", v.Type())
	case b == nil:
		v.Set(reflect.Zero(v.Type()))
		return nil
	}

	decoder := json.NewDecoder(bytes.NewReader(b))
	decoder.UseNumber()
	return decoder.Decode(v.Addr().Interface())
}