375 lines
8.2 KiB
Go
375 lines
8.2 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"reflect"
|
|
"time"
|
|
|
|
"git.bit5.ru/backend/dbr"
|
|
"git.bit5.ru/backend/errors"
|
|
"git.bit5.ru/backend/mysql"
|
|
|
|
"github.com/go-logr/logr"
|
|
)
|
|
|
|
// Settings describes how to reach and tune a single database.
type Settings struct {
	// Connection coordinates and credentials. Host, Port, User, Pass, Name
	// and Params are concatenated verbatim by ConnStr; Prefix is attached
	// to the logger by GetDBC when it is non-empty.
	Host, Port, User, Pass, Name, Prefix, Params string
	// Driver is the database/sql driver name; OpenPool falls back to
	// "mysql" when it is empty.
	Driver string
	// LogLevel is not read by the code in this file.
	LogLevel int
	// Weight is not read by the code in this file.
	Weight uint32
	// Connection pool limits; zero means "leave the default" in OpenPool.
	MaxIdleConns, MaxOpenConns int
	// Connection lifetime limits in seconds; zero means "not set" in OpenPool.
	ConnMaxLifetimeSec, ConnMaxIdleTimeSec int
}

// ConnStr assembles the data source name in the form
// "user:pass@tcp(host:port)/name" followed by Params exactly as given
// (no separator is inserted before Params).
func (s *Settings) ConnStr() string {
	credentials := s.User + ":" + s.Pass
	address := s.Host + ":" + s.Port
	database := s.Name + s.Params
	return credentials + "@tcp(" + address + ")/" + database
}
|
|
|
|
type Pool struct {
|
|
DB *sql.DB
|
|
S Settings
|
|
}
|
|
|
|
func (p *Pool) Close() {
|
|
p.DB.Close()
|
|
}
|
|
|
|
type DBC struct {
|
|
Logger logr.Logger
|
|
P *Pool
|
|
//NOTE: it's not a 'connection', it embeds sql.Pool (dbr uses strange names for entities)
|
|
con *dbr.Connection
|
|
sess *dbr.Session
|
|
trx *dbr.Tx
|
|
trxRefs int
|
|
commitTry int
|
|
}
|
|
|
|
func OpenPool(s Settings) *Pool {
|
|
driver := s.Driver
|
|
if len(driver) == 0 {
|
|
driver = "mysql"
|
|
}
|
|
//NOTE: sql.Open(..) doesn't happen to return an error
|
|
sqlDb, _ := sql.Open(driver, s.ConnStr())
|
|
|
|
if s.MaxIdleConns == 0 {
|
|
//NOTE: using default sql.DB settings
|
|
sqlDb.SetMaxIdleConns(2)
|
|
} else {
|
|
sqlDb.SetMaxIdleConns(s.MaxIdleConns)
|
|
}
|
|
if s.MaxOpenConns != 0 {
|
|
sqlDb.SetMaxOpenConns(s.MaxOpenConns)
|
|
}
|
|
if s.ConnMaxLifetimeSec != 0 {
|
|
sqlDb.SetConnMaxLifetime(time.Second * time.Duration(s.ConnMaxLifetimeSec))
|
|
}
|
|
if s.ConnMaxIdleTimeSec != 0 {
|
|
sqlDb.SetConnMaxIdleTime(time.Second * time.Duration(s.ConnMaxIdleTimeSec))
|
|
}
|
|
|
|
return &Pool{DB: sqlDb, S: s}
|
|
}
|
|
|
|
func GetDBC(p *Pool, logger logr.Logger) *DBC {
|
|
if len(p.S.Prefix) > 0 {
|
|
logger = logger.WithValues("db", p.S.Prefix)
|
|
}
|
|
|
|
con := dbr.NewConnection(p.DB, nil)
|
|
sess := con.NewSession(&EventReceiver{logger: logger, s: p.S})
|
|
|
|
dbc := &DBC{Logger: logger, P: p, con: con, sess: sess}
|
|
|
|
return dbc
|
|
}
|
|
|
|
func (dbc *DBC) DB() *sql.DB {
|
|
return dbc.P.DB
|
|
}
|
|
|
|
func (dbc *DBC) Transaction(txFunc func(dbc *DBC) error) (err error) {
|
|
err = dbc.Begin()
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer func() {
|
|
if p := recover(); p != nil {
|
|
dbc.Rollback()
|
|
panic(p) // re-throw panic after Rollback
|
|
} else if err != nil {
|
|
dbc.Rollback() // err is non-nil; don't change it
|
|
} else {
|
|
err = dbc.Commit() // err is nil; if Commit returns error update err
|
|
}
|
|
}()
|
|
err = txFunc(dbc)
|
|
return err
|
|
}
|
|
|
|
func (dbc *DBC) TransactionContext(
|
|
ctx context.Context,
|
|
txFunc func(ctx context.Context, dbc *DBC) error,
|
|
) (err error) {
|
|
|
|
err = dbc.Begin()
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer func() {
|
|
if p := recover(); p != nil {
|
|
dbc.Rollback()
|
|
panic(p) // re-throw panic after Rollback
|
|
} else if err != nil {
|
|
dbc.Rollback() // err is non-nil; don't change it
|
|
} else {
|
|
err = dbc.Commit() // err is nil; if Commit returns error update err
|
|
}
|
|
}()
|
|
err = txFunc(ctx, dbc)
|
|
return err
|
|
}
|
|
|
|
func (dbc *DBC) Begin() error {
|
|
//check if we are already in a transaction
|
|
if dbc.trx != nil {
|
|
dbc.trxRefs++
|
|
return nil
|
|
}
|
|
trx, err := dbc.sess.Begin()
|
|
if err == nil {
|
|
dbc.trx = trx
|
|
dbc.trxRefs = 1
|
|
dbc.commitTry = 0
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (dbc *DBC) Rollback() error {
|
|
if dbc.trxRefs > 0 {
|
|
dbc.trxRefs--
|
|
if dbc.trxRefs == 0 {
|
|
if err := dbc.trx.Rollback(); err != nil {
|
|
return err
|
|
}
|
|
dbc.trx = nil
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (dbc *DBC) RollbackOnDefer() {
|
|
if dbc.commitTry == 0 {
|
|
dbc.Rollback()
|
|
} else {
|
|
dbc.commitTry--
|
|
}
|
|
}
|
|
|
|
func (dbc *DBC) Commit() error {
|
|
dbc.commitTry++
|
|
if dbc.trxRefs > 0 {
|
|
dbc.trxRefs--
|
|
if dbc.trxRefs == 0 {
|
|
if err := dbc.trx.Commit(); err != nil {
|
|
return err
|
|
}
|
|
dbc.trx = nil
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Update creates a new UpdateBuilder for the given table
|
|
func (dbc *DBC) Update(table string) *dbr.UpdateBuilder {
|
|
if dbc.trx == nil {
|
|
return &dbr.UpdateBuilder{
|
|
Session: dbc.sess,
|
|
Runner: dbc.P.DB,
|
|
Table: table,
|
|
}
|
|
}
|
|
|
|
return &dbr.UpdateBuilder{
|
|
Session: dbc.trx.Session,
|
|
Runner: dbc.trx.Tx,
|
|
Table: table,
|
|
}
|
|
}
|
|
|
|
// UpdateBySQL creates a new UpdateBuilder for the given SQL string and arguments
|
|
func (dbc *DBC) UpdateBySQL(sql string, args ...interface{}) *dbr.UpdateBuilder {
|
|
if dbc.trx == nil {
|
|
return &dbr.UpdateBuilder{
|
|
Session: dbc.sess,
|
|
Runner: dbc.P.DB,
|
|
RawFullSql: sql,
|
|
RawArguments: args,
|
|
}
|
|
}
|
|
|
|
return &dbr.UpdateBuilder{
|
|
Session: dbc.trx.Session,
|
|
Runner: dbc.trx.Tx,
|
|
RawFullSql: sql,
|
|
RawArguments: args,
|
|
}
|
|
}
|
|
|
|
// DeleteFrom creates a new DeleteBuilder for the given table
|
|
func (dbc *DBC) DeleteFrom(from string) *dbr.DeleteBuilder {
|
|
if dbc.trx == nil {
|
|
return &dbr.DeleteBuilder{
|
|
Session: dbc.sess,
|
|
Runner: dbc.P.DB,
|
|
From: from,
|
|
}
|
|
}
|
|
|
|
return &dbr.DeleteBuilder{
|
|
Session: dbc.trx.Session,
|
|
Runner: dbc.trx.Tx,
|
|
From: from,
|
|
}
|
|
}
|
|
|
|
// Select creates a new SelectBuilder that select that given columns
|
|
func (dbc *DBC) Select(cols ...string) *dbr.SelectBuilder {
|
|
if dbc.trx == nil {
|
|
return &dbr.SelectBuilder{
|
|
Session: dbc.sess,
|
|
Runner: dbc.P.DB,
|
|
Columns: cols,
|
|
}
|
|
}
|
|
|
|
return &dbr.SelectBuilder{
|
|
Session: dbc.trx.Session,
|
|
Runner: dbc.trx.Tx,
|
|
Columns: cols,
|
|
}
|
|
}
|
|
|
|
// SelectBySQL creates a new SelectBuilder for the given SQL string and arguments
|
|
func (dbc *DBC) SelectBySQL(sql string, args ...interface{}) *dbr.SelectBuilder {
|
|
if dbc.trx == nil {
|
|
return &dbr.SelectBuilder{
|
|
Session: dbc.sess,
|
|
Runner: dbc.P.DB,
|
|
RawFullSql: sql,
|
|
RawArguments: args,
|
|
}
|
|
}
|
|
|
|
return &dbr.SelectBuilder{
|
|
Session: dbc.trx.Session,
|
|
Runner: dbc.trx.Tx,
|
|
RawFullSql: sql,
|
|
RawArguments: args,
|
|
}
|
|
}
|
|
|
|
// Note: Creates a new slice of dbr.SelectBuilder for the given SQL string
|
|
// Supported chunking only for first IN-list
|
|
func (dbc *DBC) SelectBySQLWithChunkedIN(sql string, chunkSize int, args ...interface{}) []*dbr.SelectBuilder {
|
|
var builders []*dbr.SelectBuilder
|
|
listsForIN := make(map[int]reflect.Value)
|
|
chunkedListsForIN := make(map[int][]reflect.Value)
|
|
|
|
for i, arg := range args {
|
|
valueOfDest := reflect.ValueOf(arg)
|
|
kindOfDest := valueOfDest.Kind()
|
|
if kindOfDest == reflect.Slice || kindOfDest == reflect.Array {
|
|
//Note: i is index of arg
|
|
listsForIN[i] = valueOfDest
|
|
}
|
|
}
|
|
|
|
if len(listsForIN) == 0 {
|
|
builder := dbc.SelectBySQL(sql, args)
|
|
builders = append(builders, builder)
|
|
return builders
|
|
}
|
|
|
|
for index, listForIN := range listsForIN {
|
|
var chunks []reflect.Value
|
|
valuesAmount := listForIN.Len()
|
|
fullChunksAmount := valuesAmount / chunkSize
|
|
modulo := valuesAmount % chunkSize
|
|
|
|
for i := 0; i < fullChunksAmount; i++ {
|
|
chunkStartIndex := i * chunkSize
|
|
chunkEndIndex := chunkStartIndex + chunkSize
|
|
chunkValues := listForIN.Slice(chunkStartIndex, chunkEndIndex)
|
|
chunks = append(chunks, chunkValues)
|
|
}
|
|
|
|
if modulo > 0 {
|
|
chunkStartIndex := fullChunksAmount * chunkSize
|
|
chunkEndIndex := chunkStartIndex + modulo
|
|
chunkValues := listForIN.Slice(chunkStartIndex, chunkEndIndex)
|
|
chunks = append(chunks, chunkValues)
|
|
}
|
|
|
|
chunkedListsForIN[index] = chunks
|
|
}
|
|
|
|
//TODO: Supported only first IN-list, because the several IN-lists can generate too many sql queries(chunks amount of 1-st IN-list * chunks amount of 2-d IN-list * etc.)
|
|
for argIndex, argChunks := range chunkedListsForIN {
|
|
for c := 0; c < len(argChunks); c++ {
|
|
argChunk := argChunks[c]
|
|
var sqlArgs = make([]interface{}, len(args))
|
|
|
|
//NOTE: Fill args for a separate sql query
|
|
for i := 0; i < len(args); i++ {
|
|
if i != argIndex {
|
|
sqlArgs[i] = args[i]
|
|
} else {
|
|
sqlArgs[i] = argChunk.Interface()
|
|
}
|
|
}
|
|
|
|
builder := dbc.SelectBySQL(sql, sqlArgs...)
|
|
builders = append(builders, builder)
|
|
}
|
|
}
|
|
|
|
return builders
|
|
}
|
|
|
|
// InsertInto instantiates a InsertBuilder for the given table
|
|
func (dbc *DBC) InsertInto(into string) *dbr.InsertBuilder {
|
|
if dbc.trx == nil {
|
|
return &dbr.InsertBuilder{
|
|
Session: dbc.sess,
|
|
Runner: dbc.P.DB,
|
|
Into: into,
|
|
}
|
|
}
|
|
|
|
return &dbr.InsertBuilder{
|
|
Session: dbc.trx.Session,
|
|
Runner: dbc.trx.Tx,
|
|
Into: into,
|
|
}
|
|
}
|
|
|
|
func IsDuplicateRecordError(err error) bool {
|
|
var myerr *mysql.MySQLError
|
|
if errors.As(err, &myerr) && myerr.Number == 1062 {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func IsNotFoundError(err error) bool {
|
|
return errors.Is(err, dbr.ErrNotFound)
|
|
}
|
|
|
|
// LastInsertIdU32 returns the last insert id of res as a uint32.
//
// It returns 0 when the id cannot be trusted: when res.LastInsertId reports
// an error, or when the id is negative or does not fit into 32 bits (a
// plain conversion would silently wrap around to an unrelated id).
func LastInsertIdU32(res sql.Result) uint32 {
	const maxUint32 = 1<<32 - 1

	lastId, err := res.LastInsertId()
	if err != nil || lastId < 0 || lastId > maxUint32 {
		return 0
	}
	return uint32(lastId)
}
|