Migrating to more consistent usage of Pool

This commit is contained in:
Pavel Shevaev 2022-10-28 10:52:25 +03:00
parent 046f7d0d7b
commit d6d473e848
2 changed files with 116 additions and 179 deletions

114
db.go
View File

@ -8,7 +8,6 @@ import (
"git.bit5.ru/backend/dbr"
"git.bit5.ru/backend/errors"
"git.bit5.ru/backend/mysql"
"git.bit5.ru/backend/res_tracker"
)
type Settings struct {
@ -22,78 +21,57 @@ func (s *Settings) ConnStr() string {
return s.User + ":" + s.Pass + "@tcp(" + s.Host + ":" + s.Port + ")/" + s.Name + s.Params
}
type Pool struct {
DB *sql.DB
S Settings
}
func (p *Pool) Close() {
p.DB.Close()
}
type DBC struct {
Logger *colog.CoLog
s Settings
_con *dbr.Connection //lazy one, should be accessed via con() method
_sess *dbr.Session //lazy one, should be accessed via sess() method
P *Pool
//NOTE: it's not a 'connection', it embeds sql.Pool (dbr uses strange names for entities)
con *dbr.Connection
sess *dbr.Session
trx *dbr.Tx
trxRefs int
commitTry int
}
func GetDBC(logger *colog.CoLog, s Settings) *DBC {
logger = logger.Clone().AddPrefix("[" + s.Prefix + "] ")
dbc := &DBC{Logger: logger, s: s, _con: nil, _sess: nil}
return dbc
}
//NOTE: In its current implementation this method creates a new Pool
// on each connection request. This is subotimal and should be
// addressed in a new version of the package
func (dbc *DBC) con() *dbr.Connection {
if dbc._con == nil {
driver := dbc.s.Driver
func OpenPool(s Settings) *Pool {
driver := s.Driver
if len(driver) == 0 {
driver = "mysql"
}
//NOTE: sql.Open(..) doesn't happen to return an error
pool, _ := sql.Open(driver, dbc.s.ConnStr())
sqlDb, _ := sql.Open(driver, s.ConnStr())
dbc._con = dbr.NewConnection(pool, nil)
//TODO: take values from Settings
sqlDb.SetMaxIdleConns(100)
sqlDb.SetMaxOpenConns(0)
sqlDb.SetConnMaxLifetime(0)
sqlDb.SetConnMaxIdleTime(0)
res_tracker.Track(dbc)
}
return dbc._con
return &Pool{DB: sqlDb, S: s}
}
func (dbc *DBC) sess() *dbr.Session {
if dbc._sess == nil {
dbc._sess = dbc.con().NewSession(&EventReceiver{logger: dbc.Logger, s: dbc.s})
}
return dbc._sess
func GetDBC(p *Pool, logger *colog.CoLog) *DBC {
logger = logger.Clone().AddPrefix("[" + p.S.Prefix + "] ")
con := dbr.NewConnection(p.DB, nil)
sess := con.NewSession(&EventReceiver{logger: logger, s: p.S})
dbc := &DBC{Logger: logger, P: p, con: con, sess: sess}
return dbc
}
func (dbc *DBC) Open() {
dbc.con()
}
func (dbc *DBC) IsOpen() bool {
return dbc._con != nil
}
func (dbc *DBC) Close() error {
if dbc._con != nil {
res_tracker.Untrack(dbc)
dbc.Rollback()
err := dbc._con.Db.Close()
dbc._con = nil
dbc._sess = nil
return err
} else {
return nil
}
}
//NOTE: for low level stuff
func (dbc *DBC) DB() *sql.DB {
return dbc.con().Db
return dbc.P.DB
}
func (dbc *DBC) Transaction(txFunc func(dbc *DBC) error) (err error) {
@ -121,7 +99,7 @@ func (dbc *DBC) Begin() error {
dbc.trxRefs++
return nil
}
trx, err := dbc.sess().Begin()
trx, err := dbc.sess.Begin()
if err == nil {
dbc.trx = trx
dbc.trxRefs = 1
@ -169,8 +147,8 @@ func (dbc *DBC) Commit() error {
func (dbc *DBC) Update(table string) *dbr.UpdateBuilder {
if dbc.trx == nil {
return &dbr.UpdateBuilder{
Session: dbc.sess(),
Runner: dbc.con().Db,
Session: dbc.sess,
Runner: dbc.P.DB,
Table: table,
}
}
@ -186,8 +164,8 @@ func (dbc *DBC) Update(table string) *dbr.UpdateBuilder {
func (dbc *DBC) UpdateBySQL(sql string, args ...interface{}) *dbr.UpdateBuilder {
if dbc.trx == nil {
return &dbr.UpdateBuilder{
Session: dbc.sess(),
Runner: dbc.con().Db,
Session: dbc.sess,
Runner: dbc.P.DB,
RawFullSql: sql,
RawArguments: args,
}
@ -205,8 +183,8 @@ func (dbc *DBC) UpdateBySQL(sql string, args ...interface{}) *dbr.UpdateBuilder
func (dbc *DBC) DeleteFrom(from string) *dbr.DeleteBuilder {
if dbc.trx == nil {
return &dbr.DeleteBuilder{
Session: dbc.sess(),
Runner: dbc.con().Db,
Session: dbc.sess,
Runner: dbc.P.DB,
From: from,
}
}
@ -222,8 +200,8 @@ func (dbc *DBC) DeleteFrom(from string) *dbr.DeleteBuilder {
func (dbc *DBC) Select(cols ...string) *dbr.SelectBuilder {
if dbc.trx == nil {
return &dbr.SelectBuilder{
Session: dbc.sess(),
Runner: dbc.con().Db,
Session: dbc.sess,
Runner: dbc.P.DB,
Columns: cols,
}
}
@ -239,8 +217,8 @@ func (dbc *DBC) Select(cols ...string) *dbr.SelectBuilder {
func (dbc *DBC) SelectBySQL(sql string, args ...interface{}) *dbr.SelectBuilder {
if dbc.trx == nil {
return &dbr.SelectBuilder{
Session: dbc.sess(),
Runner: dbc.con().Db,
Session: dbc.sess,
Runner: dbc.P.DB,
RawFullSql: sql,
RawArguments: args,
}
@ -326,8 +304,8 @@ func (dbc *DBC) SelectBySQLWithChunkedIN(sql string, chunkSize int, args ...inte
func (dbc *DBC) InsertInto(into string) *dbr.InsertBuilder {
if dbc.trx == nil {
return &dbr.InsertBuilder{
Session: dbc.sess(),
Runner: dbc.con().Db,
Session: dbc.sess,
Runner: dbc.P.DB,
Into: into,
}
}

View File

@ -11,17 +11,25 @@ import (
"github.com/stretchr/testify/assert"
)
var settings = db.Settings{Host: "127.0.0.1", Port: "3306", User: "root", Pass: "test", Name: "tests"}
var logger = colog.NewCoLog(os.Stderr, "", 0)
func getDBC() *db.DBC {
dbc := db.GetDBC(logger, settings)
func getSettings() db.Settings {
//TODO: use ENV settings as well
return db.Settings{Host: "127.0.0.1", Port: "3306", User: "root", Pass: "test", Name: "tests"}
}
func getPool() *db.Pool {
return db.OpenPool(getSettings())
}
func getDBC(p *db.Pool) *db.DBC {
dbc := db.GetDBC(p, logger)
return dbc
}
func TestDefaultClientCharsetAndCollation(t *testing.T) {
dbc := getDBC()
defer dbc.Close()
p := getPool()
dbc := getDBC(p)
var result = make(map[string]string)
@ -58,13 +66,14 @@ func TestDefaultClientCharsetAndCollation(t *testing.T) {
}
func TestClientCharsetAndCollation(t *testing.T) {
DSNWithLatinCollation := settings
DSNWithLatinCollation := getSettings()
DSNWithLatinCollation.Params = "?collation=latin1_swedish_ci"
pool := db.OpenPool(DSNWithLatinCollation)
defer pool.Close()
var resultsLatin1 = make(map[string]string)
dbLatin1 := db.GetDBC(logger, DSNWithLatinCollation)
defer dbLatin1.Close()
dbLatin1 := db.GetDBC(pool, logger)
characterSets, err := dbLatin1.DB().Query("show variables where Variable_name in ('character_set_client', 'character_set_connection', 'character_set_results');")
assert.Nil(t, err)
@ -98,55 +107,10 @@ func TestClientCharsetAndCollation(t *testing.T) {
}
func TestCloseConn(t *testing.T) {
dbc := getDBC()
defer dbc.Close()
var res int
err := dbc.SelectBySQL("SELECT 1").LoadValue(&res)
assert.Nil(t, err)
assert.EqualValues(t, 1, res)
assert.True(t, dbc.IsOpen())
dbc.Close()
assert.False(t, dbc.IsOpen())
//connection is automatically restored
err = dbc.SelectBySQL("SELECT 1").LoadValue(&res)
assert.True(t, dbc.IsOpen())
assert.Nil(t, err)
assert.EqualValues(t, 1, res)
}
func TestDoubleCloseConnIsOk(t *testing.T) {
dbc := getDBC()
defer dbc.Close()
var res int
err := dbc.SelectBySQL("SELECT 1").LoadValue(&res)
assert.Nil(t, err)
assert.EqualValues(t, 1, res)
assert.True(t, dbc.IsOpen())
dbc.Close()
assert.False(t, dbc.IsOpen())
//connection is automatically restored
err = dbc.SelectBySQL("SELECT 1").LoadValue(&res)
assert.True(t, dbc.IsOpen())
assert.Nil(t, err)
assert.EqualValues(t, 1, res)
dbc.Close()
assert.False(t, dbc.IsOpen())
//connection is automatically restored
err = dbc.SelectBySQL("SELECT 1").LoadValue(&res)
assert.True(t, dbc.IsOpen())
assert.Nil(t, err)
assert.EqualValues(t, 1, res)
}
func createFooTable(t *testing.T) {
dbc := getDBC()
defer dbc.Close()
p := getPool()
dbc := getDBC(p)
defer p.Close()
_, err := dbc.DB().Exec("DROP TABLE IF EXISTS foo")
assert.Nil(t, err)
@ -157,15 +121,15 @@ func createFooTable(t *testing.T) {
func TestTransactionCommit(t *testing.T) {
createFooTable(t)
dbc := getDBC()
defer dbc.Close()
p := getPool()
dbc := getDBC(p)
defer p.Close()
assert.Nil(t, dbc.Begin())
dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
//let's try a fresh connection
dbc1 := getDBC()
defer dbc1.Close()
dbc1 := getDBC(p)
assert.EqualValues(t, 0, countFoos(t, dbc1))
assert.EqualValues(t, 2, countFoos(t, dbc))
@ -173,23 +137,22 @@ func TestTransactionCommit(t *testing.T) {
assert.EqualValues(t, 2, countFoos(t, dbc))
//let's try a fresh connection
dbc2 := getDBC()
defer dbc2.Close()
dbc2 := getDBC(p)
assert.EqualValues(t, 2, countFoos(t, dbc2))
}
func TestTransactionCommitNestedAllOk(t *testing.T) {
createFooTable(t)
dbc := getDBC()
defer dbc.Close()
p := getPool()
dbc := getDBC(p)
defer p.Close()
//begin 1
assert.Nil(t, dbc.Begin())
dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
//let's try a fresh connection
dbc1 := getDBC()
defer dbc1.Close()
dbc1 := getDBC(p)
assert.EqualValues(t, 0, countFoos(t, dbc1))
//begin 2
@ -200,8 +163,7 @@ func TestTransactionCommitNestedAllOk(t *testing.T) {
assert.EqualValues(t, 3, countFoos(t, dbc))
//let's try a fresh connection
dbc2 := getDBC()
defer dbc2.Close()
dbc2 := getDBC(p)
assert.EqualValues(t, 0, countFoos(t, dbc2))
//commit 1
@ -209,23 +171,22 @@ func TestTransactionCommitNestedAllOk(t *testing.T) {
assert.EqualValues(t, 3, countFoos(t, dbc))
//let's try a fresh connection
db3 := getDBC()
defer db3.Close()
db3 := getDBC(p)
assert.EqualValues(t, 3, countFoos(t, db3))
}
func TestTransactionCommitNestedRollback(t *testing.T) {
createFooTable(t)
dbc := getDBC()
defer dbc.Close()
p := getPool()
dbc := getDBC(p)
defer p.Close()
//begin 1
assert.Nil(t, dbc.Begin())
dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
//let's try a fresh connection
dbc1 := getDBC()
defer dbc1.Close()
dbc1 := getDBC(p)
assert.EqualValues(t, 0, countFoos(t, dbc1))
//begin 2
@ -238,8 +199,7 @@ func TestTransactionCommitNestedRollback(t *testing.T) {
assert.EqualValues(t, 3, countFoos(t, dbc))
//let's try a fresh connection
dbc2 := getDBC()
defer dbc2.Close()
dbc2 := getDBC(p)
assert.EqualValues(t, 0, countFoos(t, dbc2))
//rollback 1
@ -247,16 +207,16 @@ func TestTransactionCommitNestedRollback(t *testing.T) {
assert.EqualValues(t, 0, countFoos(t, dbc))
//let's try a fresh connection
db3 := getDBC()
defer db3.Close()
db3 := getDBC(p)
assert.EqualValues(t, 0, countFoos(t, db3))
}
func TestTransactionRollbackOnDeferAllOK(t *testing.T) {
createFooTable(t)
dbc := getDBC()
defer dbc.Close()
p := getPool()
dbc := getDBC(p)
defer p.Close()
fn := func() {
dbc.Begin()
@ -278,16 +238,16 @@ func TestTransactionRollbackOnDeferAllOK(t *testing.T) {
}
fn()
dbc1 := getDBC()
defer dbc1.Close()
dbc1 := getDBC(p)
assert.EqualValues(t, 3, countFoos(t, dbc1))
}
func TestTransactionRollbackOnDefer(t *testing.T) {
createFooTable(t)
dbc := getDBC()
defer dbc.Close()
p := getPool()
dbc := getDBC(p)
defer p.Close()
fn := func() {
dbc.Begin()
@ -312,23 +272,22 @@ func TestTransactionRollbackOnDefer(t *testing.T) {
fn()
dbc1 := getDBC()
defer dbc1.Close()
dbc1 := getDBC(p)
assert.EqualValues(t, 0, countFoos(t, dbc1))
}
func TestTransaction(t *testing.T) {
createFooTable(t)
dbc := getDBC()
defer dbc.Close()
p := getPool()
dbc := getDBC(p)
defer p.Close()
assert.Nil(t, dbc.Transaction(func(dbs *db.DBC) error {
dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
//let's try a fresh connection
dbc1 := getDBC()
defer dbc1.Close()
dbc1 := getDBC(p)
assert.EqualValues(t, 0, countFoos(t, dbc1))
assert.EqualValues(t, 2, countFoos(t, dbc))
@ -336,16 +295,16 @@ func TestTransaction(t *testing.T) {
}))
//let's try a fresh connection
dbc2 := getDBC()
defer dbc2.Close()
dbc2 := getDBC(p)
assert.EqualValues(t, 2, countFoos(t, dbc2))
}
func TestTransactionRollbackOnError(t *testing.T) {
createFooTable(t)
dbc := getDBC()
defer dbc.Close()
p := getPool()
dbc := getDBC(p)
defer p.Close()
err := dbc.Transaction(func(dbs *db.DBC) error {
dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
@ -356,16 +315,16 @@ func TestTransactionRollbackOnError(t *testing.T) {
assert.EqualValues(t, "Opps", err.Error())
//let's try a fresh connection
dbc2 := getDBC()
defer dbc2.Close()
dbc2 := getDBC(p)
assert.EqualValues(t, 0, countFoos(t, dbc2))
}
func TestTransactionRollbackOnPanic(t *testing.T) {
createFooTable(t)
dbc := getDBC()
defer dbc.Close()
p := getPool()
dbc := getDBC(p)
defer p.Close()
defer func() {
if r := recover(); r != nil {
@ -373,8 +332,7 @@ func TestTransactionRollbackOnPanic(t *testing.T) {
assert.EqualValues(t, str, "Ooops")
//let's try a fresh connection
dbc2 := getDBC()
defer dbc2.Close()
dbc2 := getDBC(p)
assert.EqualValues(t, 0, countFoos(t, dbc2))
}
@ -399,8 +357,9 @@ func countFoos(t *testing.T, dbc *db.DBC) int {
func TestTransactionRollback(t *testing.T) {
createFooTable(t)
dbc := getDBC()
defer dbc.Close()
p := getPool()
dbc := getDBC(p)
defer p.Close()
assert.Nil(t, dbc.Begin())
{
_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
@ -422,8 +381,7 @@ func TestTransactionRollback(t *testing.T) {
}
//let's try a fresh connection
dbc2 := getDBC()
defer dbc2.Close()
dbc2 := getDBC(p)
{
err := dbc2.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res)
assert.Nil(t, err)
@ -434,8 +392,9 @@ func TestTransactionRollback(t *testing.T) {
func TestTransactionNestedRollback(t *testing.T) {
createFooTable(t)
dbc := getDBC()
defer dbc.Close()
p := getPool()
dbc := getDBC(p)
defer p.Close()
//begin 1
assert.Nil(t, dbc.Begin())
{