package sqlite import ( "database/sql" "strconv" "strings" _ "github.com/mattn/go-sqlite3" "gorm.io/gorm" "gorm.io/gorm/callbacks" "gorm.io/gorm/clause" "gorm.io/gorm/logger" "gorm.io/gorm/migrator" "gorm.io/gorm/schema" ) // DriverName is the default driver name for SQLite. const DriverName = "sqlite3" type Dialector struct { DriverName string DSN string Conn gorm.ConnPool } func Open(dsn string) gorm.Dialector { return &Dialector{DSN: dsn} } func (dialector Dialector) Name() string { return "sqlite" } func (dialector Dialector) Initialize(db *gorm.DB) (err error) { if dialector.DriverName == "" { dialector.DriverName = DriverName } // register callbacks callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ LastInsertIDReversed: true, }) if dialector.Conn != nil { db.ConnPool = dialector.Conn } else { db.ConnPool, err = sql.Open(dialector.DriverName, dialector.DSN) if err != nil { return err } } for k, v := range dialector.ClauseBuilders() { db.ClauseBuilders[k] = v } return } func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { return map[string]clause.ClauseBuilder{ "INSERT": func(c clause.Clause, builder clause.Builder) { if insert, ok := c.Expression.(clause.Insert); ok { if stmt, ok := builder.(*gorm.Statement); ok { stmt.WriteString("INSERT ") if insert.Modifier != "" { stmt.WriteString(insert.Modifier) stmt.WriteByte(' ') } stmt.WriteString("INTO ") if insert.Table.Name == "" { stmt.WriteQuoted(stmt.Table) } else { stmt.WriteQuoted(insert.Table) } return } } c.Build(builder) }, "LIMIT": func(c clause.Clause, builder clause.Builder) { if limit, ok := c.Expression.(clause.Limit); ok { if limit.Limit > 0 { builder.WriteString("LIMIT ") builder.WriteString(strconv.Itoa(limit.Limit)) } if limit.Offset > 0 { if limit.Limit > 0 { builder.WriteString(" ") } builder.WriteString("OFFSET ") builder.WriteString(strconv.Itoa(limit.Offset)) } } }, "FOR": func(c clause.Clause, builder clause.Builder) { if _, ok := c.Expression.(clause.Locking); ok { // SQLite3 does not support row-level locking. return } c.Build(builder) }, } } func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression { if field.AutoIncrement { return clause.Expr{SQL: "NULL"} } // doesn't work, will raise error return clause.Expr{SQL: "DEFAULT"} } func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { return Migrator{migrator.Migrator{Config: migrator.Config{ DB: db, Dialector: dialector, CreateIndexAfterCreateTable: true, }}} } func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { writer.WriteByte('?') } func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { writer.WriteByte('`') if strings.Contains(str, ".") { for idx, str := range strings.Split(str, ".") { if idx > 0 { writer.WriteString(".`") } writer.WriteString(str) writer.WriteByte('`') } } else { writer.WriteString(str) writer.WriteByte('`') } } func (dialector Dialector) Explain(sql string, vars ...interface{}) string { return logger.ExplainSQL(sql, nil, `"`, vars...) } func (dialector Dialector) DataTypeOf(field *schema.Field) string { switch field.DataType { case schema.Bool: return "numeric" case schema.Int, schema.Uint: if field.AutoIncrement && !field.PrimaryKey { // https://www.sqlite.org/autoinc.html return "integer PRIMARY KEY AUTOINCREMENT" } else { return "integer" } case schema.Float: return "real" case schema.String: return "text" case schema.Time: return "datetime" case schema.Bytes: return "blob" } return string(field.DataType) } func (dialectopr Dialector) SavePoint(tx *gorm.DB, name string) error { tx.Exec("SAVEPOINT " + name) return nil } func (dialectopr Dialector) RollbackTo(tx *gorm.DB, name string) error { tx.Exec("ROLLBACK TO SAVEPOINT " + name) return nil }