package db import ( "database/sql" "reflect" "git.bit5.ru/backend/colog" "git.bit5.ru/backend/dbr" "git.bit5.ru/backend/errors" "git.bit5.ru/backend/mysql" "git.bit5.ru/backend/res_tracker" ) const ( ChunkSizeForIN = 50000 ) type Settings struct { Host, Port, User, Pass, Name, Prefix, Params string Driver string LogLevel int Weight uint32 } func (s *Settings) ConnStr() string { return s.User + ":" + s.Pass + "@tcp(" + s.Host + ":" + s.Port + ")/" + s.Name + s.Params } type DBC struct { Logger *colog.CoLog s Settings _con *dbr.Connection //lazy one, should be accessed via con() method _sess *dbr.Session //lazy one, should be accessed via sess() method trx *dbr.Tx trxRefs int commitTry int } func GetDBC(logger *colog.CoLog, s Settings) *DBC { logger = logger.Clone().AddPrefix("[" + s.Prefix + "] ") dbc := &DBC{Logger: logger, s: s, _con: nil, _sess: nil} return dbc } func (dbc *DBC) con() *dbr.Connection { if dbc._con == nil { driver := dbc.s.Driver if len(driver) == 0 { driver = "mysql" } //NOTE: sql.Open(..) doesn't happen to return an error sqlDb, _ := sql.Open(driver, dbc.s.ConnStr()) dbc._con = dbr.NewConnection(sqlDb, nil) res_tracker.Track(dbc) } return dbc._con } func (dbc *DBC) sess() *dbr.Session { if dbc._sess == nil { dbc._sess = dbc.con().NewSession(&EventReceiver{logger: dbc.Logger, s: dbc.s}) } return dbc._sess } func (dbc *DBC) Open() { dbc.con() } func (dbc *DBC) IsOpen() bool { return dbc._con != nil } func (dbc *DBC) Close() error { if dbc._con != nil { res_tracker.Untrack(dbc) dbc.Rollback() err := dbc._con.Db.Close() dbc._con = nil dbc._sess = nil return err } else { return nil } } //NOTE: for low level stuff func (dbc *DBC) DB() *sql.DB { return dbc.con().Db } func (dbc *DBC) Transaction(txFunc func(dbc *DBC) error) (err error) { err = dbc.Begin() if err != nil { return } defer func() { if p := recover(); p != nil { dbc.Rollback() panic(p) // re-throw panic after Rollback } else if err != nil { dbc.Rollback() // err is non-nil; don't change it } else { err = dbc.Commit() // err is nil; if Commit returns error update err } }() err = txFunc(dbc) return err } func (dbc *DBC) Begin() error { //check if we are already in a transaction if dbc.trx != nil { dbc.trxRefs++ return nil } trx, err := dbc.sess().Begin() if err == nil { dbc.trx = trx dbc.trxRefs = 1 dbc.commitTry = 0 } return err } func (dbc *DBC) Rollback() error { if dbc.trxRefs > 0 { dbc.trxRefs-- if dbc.trxRefs == 0 { if err := dbc.trx.Rollback(); err != nil { return err } dbc.trx = nil } } return nil } func (dbc *DBC) RollbackOnDefer() { if dbc.commitTry == 0 { dbc.Rollback() } else { dbc.commitTry-- } } func (dbc *DBC) Commit() error { dbc.commitTry++ if dbc.trxRefs > 0 { dbc.trxRefs-- if dbc.trxRefs == 0 { if err := dbc.trx.Commit(); err != nil { return err } dbc.trx = nil } } return nil } // Update creates a new UpdateBuilder for the given table func (dbc *DBC) Update(table string) *dbr.UpdateBuilder { if dbc.trx == nil { return &dbr.UpdateBuilder{ Session: dbc.sess(), Runner: dbc.con().Db, Table: table, } } return &dbr.UpdateBuilder{ Session: dbc.trx.Session, Runner: dbc.trx.Tx, Table: table, } } // UpdateBySQL creates a new UpdateBuilder for the given SQL string and arguments func (dbc *DBC) UpdateBySQL(sql string, args ...interface{}) *dbr.UpdateBuilder { if dbc.trx == nil { return &dbr.UpdateBuilder{ Session: dbc.sess(), Runner: dbc.con().Db, RawFullSql: sql, RawArguments: args, } } return &dbr.UpdateBuilder{ Session: dbc.trx.Session, Runner: dbc.trx.Tx, RawFullSql: sql, RawArguments: args, } } // DeleteFrom creates a new DeleteBuilder for the given table func (dbc *DBC) DeleteFrom(from string) *dbr.DeleteBuilder { if dbc.trx == nil { return &dbr.DeleteBuilder{ Session: dbc.sess(), Runner: dbc.con().Db, From: from, } } return &dbr.DeleteBuilder{ Session: dbc.trx.Session, Runner: dbc.trx.Tx, From: from, } } // Select creates a new SelectBuilder that select that given columns func (dbc *DBC) Select(cols ...string) *dbr.SelectBuilder { if dbc.trx == nil { return &dbr.SelectBuilder{ Session: dbc.sess(), Runner: dbc.con().Db, Columns: cols, } } return &dbr.SelectBuilder{ Session: dbc.trx.Session, Runner: dbc.trx.Tx, Columns: cols, } } // SelectBySQL creates a new SelectBuilder for the given SQL string and arguments func (dbc *DBC) SelectBySQL(sql string, args ...interface{}) *dbr.SelectBuilder { if dbc.trx == nil { return &dbr.SelectBuilder{ Session: dbc.sess(), Runner: dbc.con().Db, RawFullSql: sql, RawArguments: args, } } return &dbr.SelectBuilder{ Session: dbc.trx.Session, Runner: dbc.trx.Tx, RawFullSql: sql, RawArguments: args, } } //Note: Creates a new slice of dbr.SelectBuilder for the given SQL string //Supported chunking only for first IN-list func (dbc *DBC) SelectBySQLWithChunkedIN(sql string, chunkSize int, args ...interface{}) []*dbr.SelectBuilder { var builders []*dbr.SelectBuilder listsForIN := make(map[int]reflect.Value) chunkedListsForIN := make(map[int][]reflect.Value) for i, arg := range args { valueOfDest := reflect.ValueOf(arg) kindOfDest := valueOfDest.Kind() if kindOfDest == reflect.Slice || kindOfDest == reflect.Array { //Note: i is index of arg listsForIN[i] = valueOfDest } } if len(listsForIN) == 0 { builder := dbc.SelectBySQL(sql, args) builders = append(builders, builder) return builders } for index, listForIN := range listsForIN { var chunks []reflect.Value valuesAmount := listForIN.Len() fullChunksAmount := valuesAmount / chunkSize modulo := valuesAmount % chunkSize for i := 0; i < fullChunksAmount; i++ { chunkStartIndex := i * chunkSize chunkEndIndex := chunkStartIndex + chunkSize chunkValues := listForIN.Slice(chunkStartIndex, chunkEndIndex) chunks = append(chunks, chunkValues) } if modulo > 0 { chunkStartIndex := fullChunksAmount * chunkSize chunkEndIndex := chunkStartIndex + modulo chunkValues := listForIN.Slice(chunkStartIndex, chunkEndIndex) chunks = append(chunks, chunkValues) } chunkedListsForIN[index] = chunks } //TODO: Supported only first IN-list, because the several IN-lists can generate too many sql queries(chunks amount of 1-st IN-list * chunks amount of 2-d IN-list * etc.) for argIndex, argChunks := range chunkedListsForIN { for c := 0; c < len(argChunks); c++ { argChunk := argChunks[c] var sqlArgs = make([]interface{}, len(args)) //NOTE: Fill args for a separate sql query for i := 0; i < len(args); i++ { if i != argIndex { sqlArgs[i] = args[i] } else { sqlArgs[i] = argChunk.Interface() } } builder := dbc.SelectBySQL(sql, sqlArgs...) builders = append(builders, builder) } } return builders } // InsertInto instantiates a InsertBuilder for the given table func (dbc *DBC) InsertInto(into string) *dbr.InsertBuilder { if dbc.trx == nil { return &dbr.InsertBuilder{ Session: dbc.sess(), Runner: dbc.con().Db, Into: into, } } return &dbr.InsertBuilder{ Session: dbc.trx.Session, Runner: dbc.trx.Tx, Into: into, } } func IsDuplicateRecordError(err error) bool { var myerr *mysql.MySQLError if errors.As(err, &myerr) && myerr.Number == 1062 { return true } return false } func IsNotFoundError(err error) bool { return errors.Is(err, dbr.ErrNotFound) } func LastInsertIdU32(res sql.Result) uint32 { lastId, _ := res.LastInsertId() return uint32(lastId) }