Add method DBC.TransactionContext

This commit is contained in:
Владислав Весельский 2024-03-07 15:05:36 +03:00
parent 807329120b
commit 08743191fe
1 changed files with 27 additions and 3 deletions

30
db.go
View File

@ -1,6 +1,7 @@
package db package db
import ( import (
"context"
"database/sql" "database/sql"
"reflect" "reflect"
"time" "time"
@ -75,7 +76,7 @@ func OpenPool(s Settings) *Pool {
func GetDBC(p *Pool, logger logr.Logger) *DBC { func GetDBC(p *Pool, logger logr.Logger) *DBC {
if len(p.S.Prefix) > 0 { if len(p.S.Prefix) > 0 {
logger = logger.WithValues("db", p.S.Prefix) logger = logger.WithValues("db", p.S.Prefix)
} }
con := dbr.NewConnection(p.DB, nil) con := dbr.NewConnection(p.DB, nil)
sess := con.NewSession(&EventReceiver{logger: logger, s: p.S}) sess := con.NewSession(&EventReceiver{logger: logger, s: p.S})
@ -108,6 +109,29 @@ func (dbc *DBC) Transaction(txFunc func(dbc *DBC) error) (err error) {
return err return err
} }
func (dbc *DBC) TransactionContext(
ctx context.Context,
txFunc func(ctx context.Context, dbc *DBC) error,
) (err error) {
err = dbc.Begin()
if err != nil {
return
}
defer func() {
if p := recover(); p != nil {
dbc.Rollback()
panic(p) // re-throw panic after Rollback
} else if err != nil {
dbc.Rollback() // err is non-nil; don't change it
} else {
err = dbc.Commit() // err is nil; if Commit returns error update err
}
}()
err = txFunc(ctx, dbc)
return err
}
func (dbc *DBC) Begin() error { func (dbc *DBC) Begin() error {
//check if we are already in a transaction //check if we are already in a transaction
if dbc.trx != nil { if dbc.trx != nil {
@ -247,8 +271,8 @@ func (dbc *DBC) SelectBySQL(sql string, args ...interface{}) *dbr.SelectBuilder
} }
} }
//Note: Creates a new slice of dbr.SelectBuilder for the given SQL string // Note: Creates a new slice of dbr.SelectBuilder for the given SQL string
//Supported chunking only for first IN-list // Supported chunking only for first IN-list
func (dbc *DBC) SelectBySQLWithChunkedIN(sql string, chunkSize int, args ...interface{}) []*dbr.SelectBuilder { func (dbc *DBC) SelectBySQLWithChunkedIN(sql string, chunkSize int, args ...interface{}) []*dbr.SelectBuilder {
var builders []*dbr.SelectBuilder var builders []*dbr.SelectBuilder
listsForIN := make(map[int]reflect.Value) listsForIN := make(map[int]reflect.Value)