package callbacks

import (
	"reflect"
	"strings"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
	"gorm.io/gorm/schema"
)

// SaveBeforeAssociations returns a callback that saves the belongs-to
// associations of the statement's model and, when the save succeeds, copies
// the associated records' primary key values into the model's foreign key
// fields.
//
// create is forwarded to Statement.SelectAndOmitColumns (as create and
// !create) to decide which relations are selected or omitted.
func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
	return func(db *gorm.DB) {
		if db.Error != nil || db.Statement.Schema == nil {
			return
		}

		selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)

		// Save Belongs To associations
		for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
			// Skip relations that are explicitly omitted, or not selected
			// while the selection is restricted.
			if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
				continue
			}

			// setupReferences copies the primary key values of elem (the
			// associated record) into the foreign key fields of obj. When
			// the statement destination is a map it is updated as well.
			setupReferences := func(obj reflect.Value, elem reflect.Value) {
				for _, ref := range rel.References {
					if ref.OwnPrimaryKey {
						continue
					}

					pv, _ := ref.PrimaryKey.ValueOf(elem)
					db.AddError(ref.ForeignKey.Set(obj, pv))

					if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
						dest[ref.ForeignKey.DBName] = pv
						if _, ok := dest[rel.Name]; ok {
							dest[rel.Name] = elem.Interface()
						}
					}
				}
			}

			switch db.Statement.ReflectValue.Kind() {
			case reflect.Slice, reflect.Array:
				var (
					objs      = make([]reflect.Value, 0, db.Statement.ReflectValue.Len())
					fieldType = rel.Field.FieldType
					isPtr     = fieldType.Kind() == reflect.Ptr
				)

				// The collected slice always holds pointers to the relation values.
				if !isPtr {
					fieldType = reflect.PtrTo(fieldType)
				}

				elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
				for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
					obj := db.Statement.ReflectValue.Index(i)

					// Stop scanning at the first element that is not a struct.
					if reflect.Indirect(obj).Kind() != reflect.Struct {
						break
					}

					if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
						rv := rel.Field.ReflectValueOf(obj) // relation reflect value
						objs = append(objs, obj)
						if isPtr {
							elems = reflect.Append(elems, rv)
						} else {
							elems = reflect.Append(elems, rv.Addr())
						}
					}
				}

				if elems.Len() > 0 {
					if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil {
						// objs[i] is the owner of elems.Index(i).
						for i := 0; i < elems.Len(); i++ {
							setupReferences(objs[i], elems.Index(i))
						}
					}
				}
			case reflect.Struct:
				if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
					rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value
					if rv.Kind() != reflect.Ptr {
						rv = rv.Addr()
					}

					if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil {
						setupReferences(db.Statement.ReflectValue, rv)
					}
				}
			}
		}
	}
}

// SaveAfterAssociations returns a callback that saves the has-one, has-many
// and many2many associations of the statement's model. Before saving, the
// foreign key fields of the associated records (or of the join-table rows for
// many2many) are filled from the model's primary key values or from the
// reference's fixed PrimaryValue.
//
// Errors returned while assigning foreign keys are recorded on db via
// db.AddError.
//
// create is forwarded to Statement.SelectAndOmitColumns (as create and
// !create) to decide which relations are selected or omitted.
func SaveAfterAssociations(create bool) func(db *gorm.DB) {
	return func(db *gorm.DB) {
		if db.Error != nil || db.Statement.Schema == nil {
			return
		}

		selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)

		// Save Has One associations
		for _, rel := range db.Statement.Schema.Relationships.HasOne {
			if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
				continue
			}

			switch db.Statement.ReflectValue.Kind() {
			case reflect.Slice, reflect.Array:
				var (
					fieldType = rel.Field.FieldType
					isPtr     = fieldType.Kind() == reflect.Ptr
				)

				if !isPtr {
					fieldType = reflect.PtrTo(fieldType)
				}

				elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)

				for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
					obj := db.Statement.ReflectValue.Index(i)

					if reflect.Indirect(obj).Kind() == reflect.Struct {
						if _, zero := rel.Field.ValueOf(obj); !zero {
							rv := rel.Field.ReflectValueOf(obj)
							if rv.Kind() != reflect.Ptr {
								rv = rv.Addr()
							}

							for _, ref := range rel.References {
								if ref.OwnPrimaryKey {
									fv, _ := ref.PrimaryKey.ValueOf(obj)
									db.AddError(ref.ForeignKey.Set(rv, fv))
								} else if ref.PrimaryValue != "" {
									db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue))
								}
							}

							elems = reflect.Append(elems, rv)
						}
					}
				}

				if elems.Len() > 0 {
					assignmentColumns := make([]string, 0, len(rel.References))
					for _, ref := range rel.References {
						assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
					}

					saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
				}
			case reflect.Struct:
				if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
					f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
					if f.Kind() != reflect.Ptr {
						f = f.Addr()
					}

					assignmentColumns := make([]string, 0, len(rel.References))
					for _, ref := range rel.References {
						// Report assignment failures, matching the slice branch above.
						if ref.OwnPrimaryKey {
							fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue)
							db.AddError(ref.ForeignKey.Set(f, fv))
						} else if ref.PrimaryValue != "" {
							db.AddError(ref.ForeignKey.Set(f, ref.PrimaryValue))
						}
						assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
					}

					saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns)
				}
			}
		}

		// Save Has Many associations
		for _, rel := range db.Statement.Schema.Relationships.HasMany {
			if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
				continue
			}

			fieldType := rel.Field.IndirectFieldType.Elem()
			isPtr := fieldType.Kind() == reflect.Ptr
			if !isPtr {
				fieldType = reflect.PtrTo(fieldType)
			}
			elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)

			// appendToElems sets the foreign keys on every element of v's
			// relation field and collects pointers to them in elems.
			appendToElems := func(v reflect.Value) {
				if _, zero := rel.Field.ValueOf(v); zero {
					return
				}

				f := reflect.Indirect(rel.Field.ReflectValueOf(v))

				for i := 0; i < f.Len(); i++ {
					elem := f.Index(i)
					for _, ref := range rel.References {
						if ref.OwnPrimaryKey {
							pv, _ := ref.PrimaryKey.ValueOf(v)
							db.AddError(ref.ForeignKey.Set(elem, pv))
						} else if ref.PrimaryValue != "" {
							db.AddError(ref.ForeignKey.Set(elem, ref.PrimaryValue))
						}
					}

					if isPtr {
						elems = reflect.Append(elems, elem)
					} else {
						elems = reflect.Append(elems, elem.Addr())
					}
				}
			}

			switch db.Statement.ReflectValue.Kind() {
			case reflect.Slice, reflect.Array:
				for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
					obj := db.Statement.ReflectValue.Index(i)
					if reflect.Indirect(obj).Kind() == reflect.Struct {
						appendToElems(obj)
					}
				}
			case reflect.Struct:
				appendToElems(db.Statement.ReflectValue)
			}

			if elems.Len() > 0 {
				assignmentColumns := make([]string, 0, len(rel.References))
				for _, ref := range rel.References {
					assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
				}

				saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
			}
		}

		// Save Many2Many associations
		for _, rel := range db.Statement.Schema.Relationships.Many2Many {
			if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
				continue
			}

			fieldType := rel.Field.IndirectFieldType.Elem()
			isPtr := fieldType.Kind() == reflect.Ptr
			if !isPtr {
				fieldType = reflect.PtrTo(fieldType)
			}
			elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
			joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)

			// objs[i] is the owner of elems.Index(i).
			var objs []reflect.Value

			// appendToJoins builds one join-table row linking obj and elem
			// and appends it to joins.
			appendToJoins := func(obj reflect.Value, elem reflect.Value) {
				joinValue := reflect.New(rel.JoinTable.ModelType)
				for _, ref := range rel.References {
					if ref.OwnPrimaryKey {
						fv, _ := ref.PrimaryKey.ValueOf(obj)
						db.AddError(ref.ForeignKey.Set(joinValue, fv))
					} else if ref.PrimaryValue != "" {
						db.AddError(ref.ForeignKey.Set(joinValue, ref.PrimaryValue))
					} else {
						fv, _ := ref.PrimaryKey.ValueOf(elem)
						db.AddError(ref.ForeignKey.Set(joinValue, fv))
					}
				}
				joins = reflect.Append(joins, joinValue)
			}

			// appendToElems collects pointers to every element of v's
			// relation field, recording v as the owner of each.
			appendToElems := func(v reflect.Value) {
				if _, zero := rel.Field.ValueOf(v); zero {
					return
				}

				f := reflect.Indirect(rel.Field.ReflectValueOf(v))

				for i := 0; i < f.Len(); i++ {
					elem := f.Index(i)

					objs = append(objs, v)
					if isPtr {
						elems = reflect.Append(elems, elem)
					} else {
						elems = reflect.Append(elems, elem.Addr())
					}
				}
			}

			switch db.Statement.ReflectValue.Kind() {
			case reflect.Slice, reflect.Array:
				for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
					obj := db.Statement.ReflectValue.Index(i)
					if reflect.Indirect(obj).Kind() == reflect.Struct {
						appendToElems(obj)
					}
				}
			case reflect.Struct:
				appendToElems(db.Statement.ReflectValue)
			}

			// optimize elems of reflect value length
			if elemLen := elems.Len(); elemLen > 0 {
				// The associated records are saved unless "<relation>.*" is
				// explicitly deselected; join rows are built either way.
				if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
					saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil)
				}

				for i := 0; i < elemLen; i++ {
					appendToJoins(objs[i], elems.Index(i))
				}
			}

			if joins.Len() > 0 {
				db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{
					SkipHooks:                db.Statement.SkipHooks,
					DisableNestedTransaction: true,
				}).Create(joins.Interface()).Error)
			}
		}
	}
}

// onConflictOption builds the ON CONFLICT clause used when saving associated
// records of schema s.
//
// When stmt.DB.FullSaveAssociations is set, defaultUpdatingColumns is replaced
// by every name in s.DBNames that is not a primary key and is not excluded by
// selectColumns/restricted. If the resulting list is non-empty, the clause
// targets s.PrimaryFieldDBNames and updates those columns; otherwise the
// clause is DoNothing.
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) clause.OnConflict {
	if stmt.DB.FullSaveAssociations {
		defaultUpdatingColumns = make([]string, 0, len(s.DBNames))
		for _, dbName := range s.DBNames {
			if v, ok := selectColumns[dbName]; (ok && !v) || (!ok && restricted) {
				continue
			}

			if !s.LookUpField(dbName).PrimaryKey {
				defaultUpdatingColumns = append(defaultUpdatingColumns, dbName)
			}
		}
	}

	if len(defaultUpdatingColumns) > 0 {
		columns := make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
		for _, dbName := range s.PrimaryFieldDBNames {
			columns = append(columns, clause.Column{Name: dbName})
		}

		return clause.OnConflict{
			Columns:   columns,
			DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns),
		}
	}

	return clause.OnConflict{DoNothing: true}
}

// saveAssociations creates values (the records of relation rel) in a new
// session derived from db and returns the resulting error, which is also
// recorded on db via db.AddError.
//
// Entries of selectColumns whose key starts with rel.Name + "." are forwarded
// to the new session as selects (true) or omits (false), with the prefix
// removed. defaultUpdatingColumns is passed to onConflictOption to build the
// session's ON CONFLICT clause.
func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
	var (
		selects, omits []string
		onConflict     = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns)
		refName        = rel.Name + "."
	)

	for name, ok := range selectColumns {
		columnName := ""
		if strings.HasPrefix(name, refName) {
			columnName = strings.TrimPrefix(name, refName)
		}

		if columnName != "" {
			if ok {
				selects = append(selects, columnName)
			} else {
				omits = append(omits, columnName)
			}
		}
	}

	tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{
		FullSaveAssociations:     db.FullSaveAssociations,
		SkipHooks:                db.Statement.SkipHooks,
		DisableNestedTransaction: true,
	})

	// Carry the parent statement's settings over to the new session.
	db.Statement.Settings.Range(func(k, v interface{}) bool {
		tx.Statement.Settings.Store(k, v)
		return true
	})

	if tx.Statement.FullSaveAssociations {
		tx = tx.InstanceSet("gorm:update_track_time", true)
	}

	if len(selects) > 0 {
		tx = tx.Select(selects)
	} else if restricted && len(omits) == 0 {
		// Selection is restricted but nothing under this relation was
		// selected or omitted: omit clause.Associations on the new session.
		tx = tx.Omit(clause.Associations)
	}

	if len(omits) > 0 {
		tx = tx.Omit(omits...)
	}

	return db.AddError(tx.Create(values).Error)
}