package db import ( "database/sql" "reflect" "time" "git.bit5.ru/backend/dbr" "git.bit5.ru/backend/errors" "git.bit5.ru/backend/mysql" "github.com/go-logr/logr" ) type Settings struct { Host, Port, User, Pass, Name, Prefix, Params string Driver string LogLevel int Weight uint32 MaxIdleConns, MaxOpenConns int ConnMaxLifetimeSec, ConnMaxIdleTimeSec int } func (s *Settings) ConnStr() string { return s.User + ":" + s.Pass + "@tcp(" + s.Host + ":" + s.Port + ")/" + s.Name + s.Params } type Pool struct { DB *sql.DB S Settings } func (p *Pool) Close() { p.DB.Close() } type DBC struct { Logger logr.Logger P *Pool //NOTE: it's not a 'connection', it embeds sql.Pool (dbr uses strange names for entities) con *dbr.Connection sess *dbr.Session trx *dbr.Tx trxRefs int commitTry int } func OpenPool(s Settings) *Pool { driver := s.Driver if len(driver) == 0 { driver = "mysql" } //NOTE: sql.Open(..) doesn't happen to return an error sqlDb, _ := sql.Open(driver, s.ConnStr()) if s.MaxIdleConns == 0 { //NOTE: using default sql.DB settings sqlDb.SetMaxIdleConns(2) } else { sqlDb.SetMaxIdleConns(s.MaxIdleConns) } if s.MaxOpenConns != 0 { sqlDb.SetMaxOpenConns(s.MaxOpenConns) } if s.ConnMaxLifetimeSec != 0 { sqlDb.SetConnMaxLifetime(time.Second * time.Duration(s.ConnMaxLifetimeSec)) } if s.ConnMaxIdleTimeSec != 0 { sqlDb.SetConnMaxIdleTime(time.Second * time.Duration(s.ConnMaxIdleTimeSec)) } return &Pool{DB: sqlDb, S: s} } func GetDBC(p *Pool, logger logr.Logger) *DBC { logger = logger.WithValues("db", p.S.Prefix) con := dbr.NewConnection(p.DB, nil) sess := con.NewSession(&EventReceiver{logger: logger, s: p.S}) dbc := &DBC{Logger: logger, P: p, con: con, sess: sess} return dbc } func (dbc *DBC) DB() *sql.DB { return dbc.P.DB } func (dbc *DBC) Transaction(txFunc func(dbc *DBC) error) (err error) { err = dbc.Begin() if err != nil { return } defer func() { if p := recover(); p != nil { dbc.Rollback() panic(p) // re-throw panic after Rollback } else if err != nil { dbc.Rollback() // err is non-nil; don't change it } else { err = dbc.Commit() // err is nil; if Commit returns error update err } }() err = txFunc(dbc) return err } func (dbc *DBC) Begin() error { //check if we are already in a transaction if dbc.trx != nil { dbc.trxRefs++ return nil } trx, err := dbc.sess.Begin() if err == nil { dbc.trx = trx dbc.trxRefs = 1 dbc.commitTry = 0 } return err } func (dbc *DBC) Rollback() error { if dbc.trxRefs > 0 { dbc.trxRefs-- if dbc.trxRefs == 0 { if err := dbc.trx.Rollback(); err != nil { return err } dbc.trx = nil } } return nil } func (dbc *DBC) RollbackOnDefer() { if dbc.commitTry == 0 { dbc.Rollback() } else { dbc.commitTry-- } } func (dbc *DBC) Commit() error { dbc.commitTry++ if dbc.trxRefs > 0 { dbc.trxRefs-- if dbc.trxRefs == 0 { if err := dbc.trx.Commit(); err != nil { return err } dbc.trx = nil } } return nil } // Update creates a new UpdateBuilder for the given table func (dbc *DBC) Update(table string) *dbr.UpdateBuilder { if dbc.trx == nil { return &dbr.UpdateBuilder{ Session: dbc.sess, Runner: dbc.P.DB, Table: table, } } return &dbr.UpdateBuilder{ Session: dbc.trx.Session, Runner: dbc.trx.Tx, Table: table, } } // UpdateBySQL creates a new UpdateBuilder for the given SQL string and arguments func (dbc *DBC) UpdateBySQL(sql string, args ...interface{}) *dbr.UpdateBuilder { if dbc.trx == nil { return &dbr.UpdateBuilder{ Session: dbc.sess, Runner: dbc.P.DB, RawFullSql: sql, RawArguments: args, } } return &dbr.UpdateBuilder{ Session: dbc.trx.Session, Runner: dbc.trx.Tx, RawFullSql: sql, RawArguments: args, } } // DeleteFrom creates a new DeleteBuilder for the given table func (dbc *DBC) DeleteFrom(from string) *dbr.DeleteBuilder { if dbc.trx == nil { return &dbr.DeleteBuilder{ Session: dbc.sess, Runner: dbc.P.DB, From: from, } } return &dbr.DeleteBuilder{ Session: dbc.trx.Session, Runner: dbc.trx.Tx, From: from, } } // Select creates a new SelectBuilder that select that given columns func (dbc *DBC) Select(cols ...string) *dbr.SelectBuilder { if dbc.trx == nil { return &dbr.SelectBuilder{ Session: dbc.sess, Runner: dbc.P.DB, Columns: cols, } } return &dbr.SelectBuilder{ Session: dbc.trx.Session, Runner: dbc.trx.Tx, Columns: cols, } } // SelectBySQL creates a new SelectBuilder for the given SQL string and arguments func (dbc *DBC) SelectBySQL(sql string, args ...interface{}) *dbr.SelectBuilder { if dbc.trx == nil { return &dbr.SelectBuilder{ Session: dbc.sess, Runner: dbc.P.DB, RawFullSql: sql, RawArguments: args, } } return &dbr.SelectBuilder{ Session: dbc.trx.Session, Runner: dbc.trx.Tx, RawFullSql: sql, RawArguments: args, } } //Note: Creates a new slice of dbr.SelectBuilder for the given SQL string //Supported chunking only for first IN-list func (dbc *DBC) SelectBySQLWithChunkedIN(sql string, chunkSize int, args ...interface{}) []*dbr.SelectBuilder { var builders []*dbr.SelectBuilder listsForIN := make(map[int]reflect.Value) chunkedListsForIN := make(map[int][]reflect.Value) for i, arg := range args { valueOfDest := reflect.ValueOf(arg) kindOfDest := valueOfDest.Kind() if kindOfDest == reflect.Slice || kindOfDest == reflect.Array { //Note: i is index of arg listsForIN[i] = valueOfDest } } if len(listsForIN) == 0 { builder := dbc.SelectBySQL(sql, args) builders = append(builders, builder) return builders } for index, listForIN := range listsForIN { var chunks []reflect.Value valuesAmount := listForIN.Len() fullChunksAmount := valuesAmount / chunkSize modulo := valuesAmount % chunkSize for i := 0; i < fullChunksAmount; i++ { chunkStartIndex := i * chunkSize chunkEndIndex := chunkStartIndex + chunkSize chunkValues := listForIN.Slice(chunkStartIndex, chunkEndIndex) chunks = append(chunks, chunkValues) } if modulo > 0 { chunkStartIndex := fullChunksAmount * chunkSize chunkEndIndex := chunkStartIndex + modulo chunkValues := listForIN.Slice(chunkStartIndex, chunkEndIndex) chunks = append(chunks, chunkValues) } chunkedListsForIN[index] = chunks } //TODO: Supported only first IN-list, because the several IN-lists can generate too many sql queries(chunks amount of 1-st IN-list * chunks amount of 2-d IN-list * etc.) for argIndex, argChunks := range chunkedListsForIN { for c := 0; c < len(argChunks); c++ { argChunk := argChunks[c] var sqlArgs = make([]interface{}, len(args)) //NOTE: Fill args for a separate sql query for i := 0; i < len(args); i++ { if i != argIndex { sqlArgs[i] = args[i] } else { sqlArgs[i] = argChunk.Interface() } } builder := dbc.SelectBySQL(sql, sqlArgs...) builders = append(builders, builder) } } return builders } // InsertInto instantiates a InsertBuilder for the given table func (dbc *DBC) InsertInto(into string) *dbr.InsertBuilder { if dbc.trx == nil { return &dbr.InsertBuilder{ Session: dbc.sess, Runner: dbc.P.DB, Into: into, } } return &dbr.InsertBuilder{ Session: dbc.trx.Session, Runner: dbc.trx.Tx, Into: into, } } func IsDuplicateRecordError(err error) bool { var myerr *mysql.MySQLError if errors.As(err, &myerr) && myerr.Number == 1062 { return true } return false } func IsNotFoundError(err error) bool { return errors.Is(err, dbr.ErrNotFound) } func LastInsertIdU32(res sql.Result) uint32 { lastId, _ := res.LastInsertId() return uint32(lastId) }