From 08743191fe9a32558e6e4ab29cc2d5bdc0ddfe30 Mon Sep 17 00:00:00 2001 From: Vladislav Veselskiy Date: Thu, 7 Mar 2024 15:05:36 +0300 Subject: [PATCH] Add method DBC.TransactionContext --- db.go | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/db.go b/db.go index b5c32c8..82a7f61 100644 --- a/db.go +++ b/db.go @@ -1,6 +1,7 @@ package db import ( + "context" "database/sql" "reflect" "time" @@ -75,7 +76,7 @@ func OpenPool(s Settings) *Pool { func GetDBC(p *Pool, logger logr.Logger) *DBC { if len(p.S.Prefix) > 0 { logger = logger.WithValues("db", p.S.Prefix) - } + } con := dbr.NewConnection(p.DB, nil) sess := con.NewSession(&EventReceiver{logger: logger, s: p.S}) @@ -108,6 +109,29 @@ func (dbc *DBC) Transaction(txFunc func(dbc *DBC) error) (err error) { return err } +func (dbc *DBC) TransactionContext( + ctx context.Context, + txFunc func(ctx context.Context, dbc *DBC) error, +) (err error) { + + err = dbc.Begin() + if err != nil { + return + } + defer func() { + if p := recover(); p != nil { + dbc.Rollback() + panic(p) // re-throw panic after Rollback + } else if err != nil { + dbc.Rollback() // err is non-nil; don't change it + } else { + err = dbc.Commit() // err is nil; if Commit returns error update err + } + }() + err = txFunc(ctx, dbc) + return err +} + func (dbc *DBC) Begin() error { //check if we are already in a transaction if dbc.trx != nil { @@ -247,8 +271,8 @@ func (dbc *DBC) SelectBySQL(sql string, args ...interface{}) *dbr.SelectBuilder } } -//Note: Creates a new slice of dbr.SelectBuilder for the given SQL string -//Supported chunking only for first IN-list +// Note: Creates a new slice of dbr.SelectBuilder for the given SQL string +// Supported chunking only for first IN-list func (dbc *DBC) SelectBySQLWithChunkedIN(sql string, chunkSize int, args ...interface{}) []*dbr.SelectBuilder { var builders []*dbr.SelectBuilder listsForIN := make(map[int]reflect.Value)