db/db.go

375 lines
8.2 KiB
Go

package db
import (
"context"
"database/sql"
"reflect"
"time"
"git.bit5.ru/backend/dbr"
"git.bit5.ru/backend/errors"
"git.bit5.ru/backend/mysql"
"github.com/go-logr/logr"
)
// Settings describes how to reach one database and how to size its pool.
type Settings struct {
	// Connection parameters, concatenated by ConnStr into a DSN.
	Host string
	Port string
	User string
	Pass string
	Name string
	// Prefix, when non-empty, is attached to log lines as the "db" value.
	Prefix string
	// Params is appended verbatim after the database name in the DSN;
	// presumably it carries its own leading "?" — TODO confirm with callers.
	Params string

	// Driver is the database/sql driver name; OpenPool uses "mysql" when empty.
	Driver string
	// NOTE(review): LogLevel and Weight are not read in this file; presumably
	// they are consumed elsewhere (e.g. by EventReceiver) — confirm.
	LogLevel int
	Weight   uint32

	// Pool sizing applied by OpenPool. MaxIdleConns of 0 becomes 2;
	// MaxOpenConns of 0 leaves database/sql's default in place.
	MaxIdleConns int
	MaxOpenConns int
	// Connection lifetime limits in seconds; 0 leaves them unset.
	ConnMaxLifetimeSec int
	ConnMaxIdleTimeSec int
}

// ConnStr builds the data source name "user:pass@tcp(host:port)/name"
// followed by Params verbatim. No escaping is applied to any component.
func (s *Settings) ConnStr() string {
	credentials := s.User + ":" + s.Pass
	address := s.Host + ":" + s.Port
	return credentials + "@tcp(" + address + ")/" + s.Name + s.Params
}
// Pool pairs the *sql.DB opened by OpenPool with the Settings it was opened from.
type Pool struct {
	DB *sql.DB
	S  Settings
}

// Close closes the underlying *sql.DB.
// Close has no error result, so the error from (*sql.DB).Close is dropped.
func (p *Pool) Close() {
	db := p.DB
	_ = db.Close()
}
// DBC bundles a Pool with a dbr session and tracks an optional,
// reference-counted transaction that the query builders bind to.
//
// NOTE(review): the transaction state below is mutated without locking;
// presumably a DBC is not meant to be shared between goroutines — confirm.
type DBC struct {
	// Logger is the logger given to GetDBC, tagged with the pool prefix when set.
	Logger logr.Logger
	// P is the pool this handle was created from.
	P *Pool

	// con wraps P.DB. NOTE: despite its name it is not a single connection;
	// it embeds the sql pool (dbr uses unusual names for its entities).
	con *dbr.Connection
	// sess is the non-transactional session used while trx is nil.
	sess *dbr.Session
	// trx is the open transaction, or nil when none is open.
	trx *dbr.Tx
	// trxRefs counts nested Begin calls; the real transaction ends at 0.
	trxRefs int
	// commitTry counts Commit calls not yet consumed by RollbackOnDefer;
	// it is reset when an outermost transaction begins.
	commitTry int
}
// OpenPool creates a *sql.DB for s and applies the pool tuning from s.
//
// Driver defaults to "mysql" when empty. A MaxIdleConns of 0 falls back to 2,
// which mirrors database/sql's own default. MaxOpenConns, ConnMaxLifetimeSec
// and ConnMaxIdleTimeSec are only applied when non-zero.
//
// OpenPool does not contact the server: sql.Open only validates its arguments.
// If that validation fails, OpenPool panics with a message naming the driver,
// because it has no error result.
func OpenPool(s Settings) *Pool {
	driver := s.Driver
	if len(driver) == 0 {
		driver = "mysql"
	}
	// sql.Open can fail for an unregistered driver name and, for drivers that
	// implement driver.DriverContext, for a DSN that cannot be parsed. Without
	// this check the nil *sql.DB would be dereferenced below. The DSN itself is
	// deliberately left out of the message because it contains the password.
	sqlDb, err := sql.Open(driver, s.ConnStr())
	if err != nil {
		panic("db: cannot open pool for driver \"" + driver + "\": " + err.Error())
	}
	if s.MaxIdleConns == 0 {
		// 2 mirrors the default sql.DB setting.
		sqlDb.SetMaxIdleConns(2)
	} else {
		sqlDb.SetMaxIdleConns(s.MaxIdleConns)
	}
	if s.MaxOpenConns != 0 {
		sqlDb.SetMaxOpenConns(s.MaxOpenConns)
	}
	if s.ConnMaxLifetimeSec != 0 {
		sqlDb.SetConnMaxLifetime(time.Second * time.Duration(s.ConnMaxLifetimeSec))
	}
	if s.ConnMaxIdleTimeSec != 0 {
		sqlDb.SetConnMaxIdleTime(time.Second * time.Duration(s.ConnMaxIdleTimeSec))
	}
	return &Pool{DB: sqlDb, S: s}
}
// GetDBC creates a DBC over p.
//
// When p.S.Prefix is set, logger is tagged with "db"=prefix. The session's
// events go to an EventReceiver built from that logger and p.S.
func GetDBC(p *Pool, logger logr.Logger) *DBC {
	if p.S.Prefix != "" {
		logger = logger.WithValues("db", p.S.Prefix)
	}
	con := dbr.NewConnection(p.DB, nil)
	receiver := &EventReceiver{logger: logger, s: p.S}
	return &DBC{
		Logger: logger,
		P:      p,
		con:    con,
		sess:   con.NewSession(receiver),
	}
}
// DB returns the pool's underlying *sql.DB. Queries run on it bypass any
// transaction currently open on dbc.
func (dbc *DBC) DB() *sql.DB {
	pool := dbc.P
	return pool.DB
}
// Transaction runs txFunc inside a transaction on dbc.
//
// Outcomes:
//   - txFunc returns nil: the transaction is committed, and a Commit error
//     becomes the result.
//   - txFunc returns an error: it is rolled back and txFunc's error is
//     returned unchanged.
//   - txFunc panics: it is rolled back and the panic is re-raised.
//
// Calls nest through Begin's reference count. Only the outermost level really
// commits or rolls back.
//
// NOTE(review): an inner failure does not doom the outer transaction. If the
// outer txFunc swallows the error, the outer Commit still commits everything.
func (dbc *DBC) Transaction(txFunc func(dbc *DBC) error) (err error) {
	if err = dbc.Begin(); err != nil {
		return err
	}
	defer func() {
		r := recover()
		switch {
		case r != nil:
			_ = dbc.Rollback()
			panic(r) // re-raise after rolling back
		case err != nil:
			_ = dbc.Rollback() // keep txFunc's error as the result
		default:
			err = dbc.Commit() // surface a commit failure
		}
	}()
	return txFunc(dbc)
}
// TransactionContext behaves like Transaction but also passes ctx to txFunc.
// It commits on a nil result, rolls back on an error, and rolls back and
// re-panics on a panic.
//
// NOTE(review): ctx is only forwarded to txFunc. Begin, Commit and Rollback do
// not observe it, so cancelling ctx does not by itself end the transaction.
func (dbc *DBC) TransactionContext(
	ctx context.Context,
	txFunc func(ctx context.Context, dbc *DBC) error,
) (err error) {
	if err = dbc.Begin(); err != nil {
		return err
	}
	defer func() {
		r := recover()
		switch {
		case r != nil:
			_ = dbc.Rollback()
			panic(r) // re-raise after rolling back
		case err != nil:
			_ = dbc.Rollback() // keep txFunc's error as the result
		default:
			err = dbc.Commit() // surface a commit failure
		}
	}()
	return txFunc(ctx, dbc)
}
// Begin opens a transaction on dbc's session, or joins the one already open.
//
// A nested call only increments the nesting counter. A fresh transaction sets
// the counter to 1 and resets commitTry.
func (dbc *DBC) Begin() error {
	if dbc.trx != nil {
		// Already inside a transaction: just add a nesting level.
		dbc.trxRefs++
		return nil
	}
	trx, err := dbc.sess.Begin()
	if err != nil {
		return err
	}
	dbc.trx, dbc.trxRefs, dbc.commitTry = trx, 1, 0
	return nil
}
// Rollback ends one nesting level of the transaction opened by Begin.
//
// Only the call that brings the nesting counter back to zero actually rolls
// back the transaction. Inner calls just decrement the counter. Outside a
// transaction, Rollback is a no-op that returns nil.
//
// dbc is detached from the transaction before the rollback is issued. If the
// rollback fails, the already-finished transaction is therefore not left
// attached for the next Begin to silently join. database/sql treats a Tx as
// finished once Rollback has been called, even if the driver reports an
// error; presumably dbr.Tx delegates to it.
//
// NOTE(review): an inner Rollback does not doom the outer transaction. A later
// outermost Commit still commits the inner work.
func (dbc *DBC) Rollback() error {
	if dbc.trxRefs <= 0 {
		return nil
	}
	dbc.trxRefs--
	if dbc.trxRefs > 0 {
		return nil
	}
	trx := dbc.trx
	dbc.trx = nil
	return trx.Rollback()
}
// RollbackOnDefer is meant to be deferred right after Begin.
//
// If a Commit has been counted since the transaction began, it consumes one
// count and does nothing. Otherwise it calls Rollback and ignores its error.
func (dbc *DBC) RollbackOnDefer() {
	if dbc.commitTry != 0 {
		dbc.commitTry--
		return
	}
	_ = dbc.Rollback()
}
// Commit ends one nesting level of the transaction opened by Begin.
//
// Every call increments commitTry, even outside a transaction, so a deferred
// RollbackOnDefer knows a commit was attempted. Only the call that brings the
// nesting counter back to zero actually commits. Inner calls just decrement
// the counter.
//
// dbc is detached from the transaction before the commit is issued. A failed
// commit therefore does not leave a finished transaction attached for the next
// Begin to silently join. database/sql treats a Tx as finished after Commit,
// even when it fails; presumably dbr.Tx delegates to it.
func (dbc *DBC) Commit() error {
	dbc.commitTry++
	if dbc.trxRefs <= 0 {
		return nil
	}
	dbc.trxRefs--
	if dbc.trxRefs > 0 {
		return nil
	}
	trx := dbc.trx
	dbc.trx = nil
	return trx.Commit()
}
// Update creates a new UpdateBuilder for the given table.
// Inside a transaction the builder runs on it; otherwise it runs on the pool.
func (dbc *DBC) Update(table string) *dbr.UpdateBuilder {
	b := &dbr.UpdateBuilder{Table: table}
	if trx := dbc.trx; trx != nil {
		b.Session, b.Runner = trx.Session, trx.Tx
	} else {
		b.Session, b.Runner = dbc.sess, dbc.P.DB
	}
	return b
}
// UpdateBySQL creates a new UpdateBuilder for a raw SQL statement and its
// arguments. Inside a transaction the builder runs on it; otherwise it runs
// on the pool.
func (dbc *DBC) UpdateBySQL(sql string, args ...interface{}) *dbr.UpdateBuilder {
	b := &dbr.UpdateBuilder{RawFullSql: sql, RawArguments: args}
	if trx := dbc.trx; trx != nil {
		b.Session, b.Runner = trx.Session, trx.Tx
	} else {
		b.Session, b.Runner = dbc.sess, dbc.P.DB
	}
	return b
}
// DeleteFrom creates a new DeleteBuilder for the given table.
// Inside a transaction the builder runs on it; otherwise it runs on the pool.
func (dbc *DBC) DeleteFrom(from string) *dbr.DeleteBuilder {
	b := &dbr.DeleteBuilder{From: from}
	if trx := dbc.trx; trx != nil {
		b.Session, b.Runner = trx.Session, trx.Tx
	} else {
		b.Session, b.Runner = dbc.sess, dbc.P.DB
	}
	return b
}
// Select creates a new SelectBuilder that selects the given columns.
// Inside a transaction the builder runs on it; otherwise it runs on the pool.
func (dbc *DBC) Select(cols ...string) *dbr.SelectBuilder {
	b := &dbr.SelectBuilder{Columns: cols}
	if trx := dbc.trx; trx != nil {
		b.Session, b.Runner = trx.Session, trx.Tx
	} else {
		b.Session, b.Runner = dbc.sess, dbc.P.DB
	}
	return b
}
// SelectBySQL creates a new SelectBuilder for a raw SQL query and its
// arguments. Inside a transaction the builder runs on it; otherwise it runs
// on the pool.
func (dbc *DBC) SelectBySQL(sql string, args ...interface{}) *dbr.SelectBuilder {
	b := &dbr.SelectBuilder{RawFullSql: sql, RawArguments: args}
	if trx := dbc.trx; trx != nil {
		b.Session, b.Runner = trx.Session, trx.Tx
	} else {
		b.Session, b.Runner = dbc.sess, dbc.P.DB
	}
	return b
}
// SelectBySQLWithChunkedIN splits a raw SELECT whose arguments contain a long
// IN-list into several SelectBuilders (via SelectBySQL). Each builder receives
// a chunk of at most chunkSize elements of that list, and the caller combines
// their results.
//
// Only the first slice or array argument (the lowest index in args) is
// chunked; chunking several lists would multiply the number of queries. All
// other arguments, including any further slices, are passed unchanged to
// every query.
//
// Edge cases:
//   - no slice/array argument, or chunkSize <= 0: one builder with all args;
//   - the chunked list is empty: no builders are returned (nil).
func (dbc *DBC) SelectBySQLWithChunkedIN(sql string, chunkSize int, args ...interface{}) []*dbr.SelectBuilder {
	listIndex := -1
	var list reflect.Value
	for i, arg := range args {
		v := reflect.ValueOf(arg)
		if k := v.Kind(); k == reflect.Slice || k == reflect.Array {
			listIndex, list = i, v
			break
		}
	}
	if listIndex < 0 || chunkSize <= 0 {
		// Spread args: passing the slice itself would bind it as one argument.
		return []*dbr.SelectBuilder{dbc.SelectBySQL(sql, args...)}
	}
	if list.Kind() == reflect.Array {
		// reflect.Value.Slice panics on an unaddressable array, which is what
		// reflect.ValueOf returns, so copy it into an addressable value first.
		addressable := reflect.New(list.Type()).Elem()
		addressable.Set(list)
		list = addressable
	}

	var builders []*dbr.SelectBuilder
	total := list.Len()
	for start, end := 0, 0; start < total; start = end {
		// Computed as a remaining-length comparison so start+chunkSize cannot overflow.
		end = total
		if total-start > chunkSize {
			end = start + chunkSize
		}
		chunkArgs := make([]interface{}, len(args))
		copy(chunkArgs, args)
		chunkArgs[listIndex] = list.Slice(start, end).Interface()
		builders = append(builders, dbc.SelectBySQL(sql, chunkArgs...))
	}
	return builders
}
// InsertInto creates a new InsertBuilder for the given table.
// Inside a transaction the builder runs on it; otherwise it runs on the pool.
func (dbc *DBC) InsertInto(into string) *dbr.InsertBuilder {
	b := &dbr.InsertBuilder{Into: into}
	if trx := dbc.trx; trx != nil {
		b.Session, b.Runner = trx.Session, trx.Tx
	} else {
		b.Session, b.Runner = dbc.sess, dbc.P.DB
	}
	return b
}
// IsDuplicateRecordError reports whether err is, or wraps, a MySQL error with
// number 1062 (ER_DUP_ENTRY, a unique/primary key violation).
func IsDuplicateRecordError(err error) bool {
	const erDupEntry = 1062
	var mysqlErr *mysql.MySQLError
	return errors.As(err, &mysqlErr) && mysqlErr.Number == erDupEntry
}
func IsNotFoundError(err error) bool {
return errors.Is(err, dbr.ErrNotFound)
}
func LastInsertIdU32(res sql.Result) uint32 {
lastId, _ := res.LastInsertId()
return uint32(lastId)
}