Migrating to more consistent usage of Pool

This commit is contained in:
Pavel Shevaev 2022-10-28 10:52:25 +03:00
parent 046f7d0d7b
commit d6d473e848
2 changed files with 116 additions and 179 deletions

124
db.go
View File

@ -8,7 +8,6 @@ import (
"git.bit5.ru/backend/dbr" "git.bit5.ru/backend/dbr"
"git.bit5.ru/backend/errors" "git.bit5.ru/backend/errors"
"git.bit5.ru/backend/mysql" "git.bit5.ru/backend/mysql"
"git.bit5.ru/backend/res_tracker"
) )
type Settings struct { type Settings struct {
@ -22,78 +21,57 @@ func (s *Settings) ConnStr() string {
return s.User + ":" + s.Pass + "@tcp(" + s.Host + ":" + s.Port + ")/" + s.Name + s.Params return s.User + ":" + s.Pass + "@tcp(" + s.Host + ":" + s.Port + ")/" + s.Name + s.Params
} }
type Pool struct {
DB *sql.DB
S Settings
}
func (p *Pool) Close() {
p.DB.Close()
}
type DBC struct { type DBC struct {
Logger *colog.CoLog Logger *colog.CoLog
s Settings P *Pool
_con *dbr.Connection //lazy one, should be accessed via con() method //NOTE: it's not a 'connection', it embeds sql.Pool (dbr uses strange names for entities)
_sess *dbr.Session //lazy one, should be accessed via sess() method con *dbr.Connection
sess *dbr.Session
trx *dbr.Tx trx *dbr.Tx
trxRefs int trxRefs int
commitTry int commitTry int
} }
func GetDBC(logger *colog.CoLog, s Settings) *DBC { func OpenPool(s Settings) *Pool {
driver := s.Driver
if len(driver) == 0 {
driver = "mysql"
}
//NOTE: sql.Open(..) doesn't happen to return an error
sqlDb, _ := sql.Open(driver, s.ConnStr())
logger = logger.Clone().AddPrefix("[" + s.Prefix + "] ") //TODO: take values from Settings
sqlDb.SetMaxIdleConns(100)
sqlDb.SetMaxOpenConns(0)
sqlDb.SetConnMaxLifetime(0)
sqlDb.SetConnMaxIdleTime(0)
dbc := &DBC{Logger: logger, s: s, _con: nil, _sess: nil} return &Pool{DB: sqlDb, S: s}
}
func GetDBC(p *Pool, logger *colog.CoLog) *DBC {
logger = logger.Clone().AddPrefix("[" + p.S.Prefix + "] ")
con := dbr.NewConnection(p.DB, nil)
sess := con.NewSession(&EventReceiver{logger: logger, s: p.S})
dbc := &DBC{Logger: logger, P: p, con: con, sess: sess}
return dbc return dbc
} }
//NOTE: In its current implementation this method creates a new Pool
// on each connection request. This is subotimal and should be
// addressed in a new version of the package
func (dbc *DBC) con() *dbr.Connection {
if dbc._con == nil {
driver := dbc.s.Driver
if len(driver) == 0 {
driver = "mysql"
}
//NOTE: sql.Open(..) doesn't happen to return an error
pool, _ := sql.Open(driver, dbc.s.ConnStr())
dbc._con = dbr.NewConnection(pool, nil)
res_tracker.Track(dbc)
}
return dbc._con
}
func (dbc *DBC) sess() *dbr.Session {
if dbc._sess == nil {
dbc._sess = dbc.con().NewSession(&EventReceiver{logger: dbc.Logger, s: dbc.s})
}
return dbc._sess
}
func (dbc *DBC) Open() {
dbc.con()
}
func (dbc *DBC) IsOpen() bool {
return dbc._con != nil
}
func (dbc *DBC) Close() error {
if dbc._con != nil {
res_tracker.Untrack(dbc)
dbc.Rollback()
err := dbc._con.Db.Close()
dbc._con = nil
dbc._sess = nil
return err
} else {
return nil
}
}
//NOTE: for low level stuff
func (dbc *DBC) DB() *sql.DB { func (dbc *DBC) DB() *sql.DB {
return dbc.con().Db return dbc.P.DB
} }
func (dbc *DBC) Transaction(txFunc func(dbc *DBC) error) (err error) { func (dbc *DBC) Transaction(txFunc func(dbc *DBC) error) (err error) {
@ -121,7 +99,7 @@ func (dbc *DBC) Begin() error {
dbc.trxRefs++ dbc.trxRefs++
return nil return nil
} }
trx, err := dbc.sess().Begin() trx, err := dbc.sess.Begin()
if err == nil { if err == nil {
dbc.trx = trx dbc.trx = trx
dbc.trxRefs = 1 dbc.trxRefs = 1
@ -169,8 +147,8 @@ func (dbc *DBC) Commit() error {
func (dbc *DBC) Update(table string) *dbr.UpdateBuilder { func (dbc *DBC) Update(table string) *dbr.UpdateBuilder {
if dbc.trx == nil { if dbc.trx == nil {
return &dbr.UpdateBuilder{ return &dbr.UpdateBuilder{
Session: dbc.sess(), Session: dbc.sess,
Runner: dbc.con().Db, Runner: dbc.P.DB,
Table: table, Table: table,
} }
} }
@ -186,8 +164,8 @@ func (dbc *DBC) Update(table string) *dbr.UpdateBuilder {
func (dbc *DBC) UpdateBySQL(sql string, args ...interface{}) *dbr.UpdateBuilder { func (dbc *DBC) UpdateBySQL(sql string, args ...interface{}) *dbr.UpdateBuilder {
if dbc.trx == nil { if dbc.trx == nil {
return &dbr.UpdateBuilder{ return &dbr.UpdateBuilder{
Session: dbc.sess(), Session: dbc.sess,
Runner: dbc.con().Db, Runner: dbc.P.DB,
RawFullSql: sql, RawFullSql: sql,
RawArguments: args, RawArguments: args,
} }
@ -205,8 +183,8 @@ func (dbc *DBC) UpdateBySQL(sql string, args ...interface{}) *dbr.UpdateBuilder
func (dbc *DBC) DeleteFrom(from string) *dbr.DeleteBuilder { func (dbc *DBC) DeleteFrom(from string) *dbr.DeleteBuilder {
if dbc.trx == nil { if dbc.trx == nil {
return &dbr.DeleteBuilder{ return &dbr.DeleteBuilder{
Session: dbc.sess(), Session: dbc.sess,
Runner: dbc.con().Db, Runner: dbc.P.DB,
From: from, From: from,
} }
} }
@ -222,8 +200,8 @@ func (dbc *DBC) DeleteFrom(from string) *dbr.DeleteBuilder {
func (dbc *DBC) Select(cols ...string) *dbr.SelectBuilder { func (dbc *DBC) Select(cols ...string) *dbr.SelectBuilder {
if dbc.trx == nil { if dbc.trx == nil {
return &dbr.SelectBuilder{ return &dbr.SelectBuilder{
Session: dbc.sess(), Session: dbc.sess,
Runner: dbc.con().Db, Runner: dbc.P.DB,
Columns: cols, Columns: cols,
} }
} }
@ -239,8 +217,8 @@ func (dbc *DBC) Select(cols ...string) *dbr.SelectBuilder {
func (dbc *DBC) SelectBySQL(sql string, args ...interface{}) *dbr.SelectBuilder { func (dbc *DBC) SelectBySQL(sql string, args ...interface{}) *dbr.SelectBuilder {
if dbc.trx == nil { if dbc.trx == nil {
return &dbr.SelectBuilder{ return &dbr.SelectBuilder{
Session: dbc.sess(), Session: dbc.sess,
Runner: dbc.con().Db, Runner: dbc.P.DB,
RawFullSql: sql, RawFullSql: sql,
RawArguments: args, RawArguments: args,
} }
@ -326,8 +304,8 @@ func (dbc *DBC) SelectBySQLWithChunkedIN(sql string, chunkSize int, args ...inte
func (dbc *DBC) InsertInto(into string) *dbr.InsertBuilder { func (dbc *DBC) InsertInto(into string) *dbr.InsertBuilder {
if dbc.trx == nil { if dbc.trx == nil {
return &dbr.InsertBuilder{ return &dbr.InsertBuilder{
Session: dbc.sess(), Session: dbc.sess,
Runner: dbc.con().Db, Runner: dbc.P.DB,
Into: into, Into: into,
} }
} }

View File

@ -11,17 +11,25 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
var settings = db.Settings{Host: "127.0.0.1", Port: "3306", User: "root", Pass: "test", Name: "tests"}
var logger = colog.NewCoLog(os.Stderr, "", 0) var logger = colog.NewCoLog(os.Stderr, "", 0)
func getDBC() *db.DBC { func getSettings() db.Settings {
dbc := db.GetDBC(logger, settings) //TODO: use ENV settings as well
return db.Settings{Host: "127.0.0.1", Port: "3306", User: "root", Pass: "test", Name: "tests"}
}
func getPool() *db.Pool {
return db.OpenPool(getSettings())
}
func getDBC(p *db.Pool) *db.DBC {
dbc := db.GetDBC(p, logger)
return dbc return dbc
} }
func TestDefaultClientCharsetAndCollation(t *testing.T) { func TestDefaultClientCharsetAndCollation(t *testing.T) {
dbc := getDBC() p := getPool()
defer dbc.Close() dbc := getDBC(p)
var result = make(map[string]string) var result = make(map[string]string)
@ -58,13 +66,14 @@ func TestDefaultClientCharsetAndCollation(t *testing.T) {
} }
func TestClientCharsetAndCollation(t *testing.T) { func TestClientCharsetAndCollation(t *testing.T) {
DSNWithLatinCollation := settings DSNWithLatinCollation := getSettings()
DSNWithLatinCollation.Params = "?collation=latin1_swedish_ci" DSNWithLatinCollation.Params = "?collation=latin1_swedish_ci"
pool := db.OpenPool(DSNWithLatinCollation)
defer pool.Close()
var resultsLatin1 = make(map[string]string) var resultsLatin1 = make(map[string]string)
dbLatin1 := db.GetDBC(logger, DSNWithLatinCollation) dbLatin1 := db.GetDBC(pool, logger)
defer dbLatin1.Close()
characterSets, err := dbLatin1.DB().Query("show variables where Variable_name in ('character_set_client', 'character_set_connection', 'character_set_results');") characterSets, err := dbLatin1.DB().Query("show variables where Variable_name in ('character_set_client', 'character_set_connection', 'character_set_results');")
assert.Nil(t, err) assert.Nil(t, err)
@ -98,55 +107,10 @@ func TestClientCharsetAndCollation(t *testing.T) {
} }
func TestCloseConn(t *testing.T) {
dbc := getDBC()
defer dbc.Close()
var res int
err := dbc.SelectBySQL("SELECT 1").LoadValue(&res)
assert.Nil(t, err)
assert.EqualValues(t, 1, res)
assert.True(t, dbc.IsOpen())
dbc.Close()
assert.False(t, dbc.IsOpen())
//connection is automatically restored
err = dbc.SelectBySQL("SELECT 1").LoadValue(&res)
assert.True(t, dbc.IsOpen())
assert.Nil(t, err)
assert.EqualValues(t, 1, res)
}
func TestDoubleCloseConnIsOk(t *testing.T) {
dbc := getDBC()
defer dbc.Close()
var res int
err := dbc.SelectBySQL("SELECT 1").LoadValue(&res)
assert.Nil(t, err)
assert.EqualValues(t, 1, res)
assert.True(t, dbc.IsOpen())
dbc.Close()
assert.False(t, dbc.IsOpen())
//connection is automatically restored
err = dbc.SelectBySQL("SELECT 1").LoadValue(&res)
assert.True(t, dbc.IsOpen())
assert.Nil(t, err)
assert.EqualValues(t, 1, res)
dbc.Close()
assert.False(t, dbc.IsOpen())
//connection is automatically restored
err = dbc.SelectBySQL("SELECT 1").LoadValue(&res)
assert.True(t, dbc.IsOpen())
assert.Nil(t, err)
assert.EqualValues(t, 1, res)
}
func createFooTable(t *testing.T) { func createFooTable(t *testing.T) {
dbc := getDBC() p := getPool()
defer dbc.Close() dbc := getDBC(p)
defer p.Close()
_, err := dbc.DB().Exec("DROP TABLE IF EXISTS foo") _, err := dbc.DB().Exec("DROP TABLE IF EXISTS foo")
assert.Nil(t, err) assert.Nil(t, err)
@ -157,15 +121,15 @@ func createFooTable(t *testing.T) {
func TestTransactionCommit(t *testing.T) { func TestTransactionCommit(t *testing.T) {
createFooTable(t) createFooTable(t)
dbc := getDBC() p := getPool()
defer dbc.Close() dbc := getDBC(p)
defer p.Close()
assert.Nil(t, dbc.Begin()) assert.Nil(t, dbc.Begin())
dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec() dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
//let's try a fresh connection //let's try a fresh connection
dbc1 := getDBC() dbc1 := getDBC(p)
defer dbc1.Close()
assert.EqualValues(t, 0, countFoos(t, dbc1)) assert.EqualValues(t, 0, countFoos(t, dbc1))
assert.EqualValues(t, 2, countFoos(t, dbc)) assert.EqualValues(t, 2, countFoos(t, dbc))
@ -173,23 +137,22 @@ func TestTransactionCommit(t *testing.T) {
assert.EqualValues(t, 2, countFoos(t, dbc)) assert.EqualValues(t, 2, countFoos(t, dbc))
//let's try a fresh connection //let's try a fresh connection
dbc2 := getDBC() dbc2 := getDBC(p)
defer dbc2.Close()
assert.EqualValues(t, 2, countFoos(t, dbc2)) assert.EqualValues(t, 2, countFoos(t, dbc2))
} }
func TestTransactionCommitNestedAllOk(t *testing.T) { func TestTransactionCommitNestedAllOk(t *testing.T) {
createFooTable(t) createFooTable(t)
dbc := getDBC() p := getPool()
defer dbc.Close() dbc := getDBC(p)
defer p.Close()
//begin 1 //begin 1
assert.Nil(t, dbc.Begin()) assert.Nil(t, dbc.Begin())
dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec() dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
//let's try a fresh connection //let's try a fresh connection
dbc1 := getDBC() dbc1 := getDBC(p)
defer dbc1.Close()
assert.EqualValues(t, 0, countFoos(t, dbc1)) assert.EqualValues(t, 0, countFoos(t, dbc1))
//begin 2 //begin 2
@ -200,8 +163,7 @@ func TestTransactionCommitNestedAllOk(t *testing.T) {
assert.EqualValues(t, 3, countFoos(t, dbc)) assert.EqualValues(t, 3, countFoos(t, dbc))
//let's try a fresh connection //let's try a fresh connection
dbc2 := getDBC() dbc2 := getDBC(p)
defer dbc2.Close()
assert.EqualValues(t, 0, countFoos(t, dbc2)) assert.EqualValues(t, 0, countFoos(t, dbc2))
//commit 1 //commit 1
@ -209,23 +171,22 @@ func TestTransactionCommitNestedAllOk(t *testing.T) {
assert.EqualValues(t, 3, countFoos(t, dbc)) assert.EqualValues(t, 3, countFoos(t, dbc))
//let's try a fresh connection //let's try a fresh connection
db3 := getDBC() db3 := getDBC(p)
defer db3.Close()
assert.EqualValues(t, 3, countFoos(t, db3)) assert.EqualValues(t, 3, countFoos(t, db3))
} }
func TestTransactionCommitNestedRollback(t *testing.T) { func TestTransactionCommitNestedRollback(t *testing.T) {
createFooTable(t) createFooTable(t)
dbc := getDBC() p := getPool()
defer dbc.Close() dbc := getDBC(p)
defer p.Close()
//begin 1 //begin 1
assert.Nil(t, dbc.Begin()) assert.Nil(t, dbc.Begin())
dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec() dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
//let's try a fresh connection //let's try a fresh connection
dbc1 := getDBC() dbc1 := getDBC(p)
defer dbc1.Close()
assert.EqualValues(t, 0, countFoos(t, dbc1)) assert.EqualValues(t, 0, countFoos(t, dbc1))
//begin 2 //begin 2
@ -238,8 +199,7 @@ func TestTransactionCommitNestedRollback(t *testing.T) {
assert.EqualValues(t, 3, countFoos(t, dbc)) assert.EqualValues(t, 3, countFoos(t, dbc))
//let's try a fresh connection //let's try a fresh connection
dbc2 := getDBC() dbc2 := getDBC(p)
defer dbc2.Close()
assert.EqualValues(t, 0, countFoos(t, dbc2)) assert.EqualValues(t, 0, countFoos(t, dbc2))
//rollback 1 //rollback 1
@ -247,16 +207,16 @@ func TestTransactionCommitNestedRollback(t *testing.T) {
assert.EqualValues(t, 0, countFoos(t, dbc)) assert.EqualValues(t, 0, countFoos(t, dbc))
//let's try a fresh connection //let's try a fresh connection
db3 := getDBC() db3 := getDBC(p)
defer db3.Close()
assert.EqualValues(t, 0, countFoos(t, db3)) assert.EqualValues(t, 0, countFoos(t, db3))
} }
func TestTransactionRollbackOnDeferAllOK(t *testing.T) { func TestTransactionRollbackOnDeferAllOK(t *testing.T) {
createFooTable(t) createFooTable(t)
dbc := getDBC() p := getPool()
defer dbc.Close() dbc := getDBC(p)
defer p.Close()
fn := func() { fn := func() {
dbc.Begin() dbc.Begin()
@ -278,16 +238,16 @@ func TestTransactionRollbackOnDeferAllOK(t *testing.T) {
} }
fn() fn()
dbc1 := getDBC() dbc1 := getDBC(p)
defer dbc1.Close()
assert.EqualValues(t, 3, countFoos(t, dbc1)) assert.EqualValues(t, 3, countFoos(t, dbc1))
} }
func TestTransactionRollbackOnDefer(t *testing.T) { func TestTransactionRollbackOnDefer(t *testing.T) {
createFooTable(t) createFooTable(t)
dbc := getDBC() p := getPool()
defer dbc.Close() dbc := getDBC(p)
defer p.Close()
fn := func() { fn := func() {
dbc.Begin() dbc.Begin()
@ -312,23 +272,22 @@ func TestTransactionRollbackOnDefer(t *testing.T) {
fn() fn()
dbc1 := getDBC() dbc1 := getDBC(p)
defer dbc1.Close()
assert.EqualValues(t, 0, countFoos(t, dbc1)) assert.EqualValues(t, 0, countFoos(t, dbc1))
} }
func TestTransaction(t *testing.T) { func TestTransaction(t *testing.T) {
createFooTable(t) createFooTable(t)
dbc := getDBC() p := getPool()
defer dbc.Close() dbc := getDBC(p)
defer p.Close()
assert.Nil(t, dbc.Transaction(func(dbs *db.DBC) error { assert.Nil(t, dbc.Transaction(func(dbs *db.DBC) error {
dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec() dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
//let's try a fresh connection //let's try a fresh connection
dbc1 := getDBC() dbc1 := getDBC(p)
defer dbc1.Close()
assert.EqualValues(t, 0, countFoos(t, dbc1)) assert.EqualValues(t, 0, countFoos(t, dbc1))
assert.EqualValues(t, 2, countFoos(t, dbc)) assert.EqualValues(t, 2, countFoos(t, dbc))
@ -336,16 +295,16 @@ func TestTransaction(t *testing.T) {
})) }))
//let's try a fresh connection //let's try a fresh connection
dbc2 := getDBC() dbc2 := getDBC(p)
defer dbc2.Close()
assert.EqualValues(t, 2, countFoos(t, dbc2)) assert.EqualValues(t, 2, countFoos(t, dbc2))
} }
func TestTransactionRollbackOnError(t *testing.T) { func TestTransactionRollbackOnError(t *testing.T) {
createFooTable(t) createFooTable(t)
dbc := getDBC() p := getPool()
defer dbc.Close() dbc := getDBC(p)
defer p.Close()
err := dbc.Transaction(func(dbs *db.DBC) error { err := dbc.Transaction(func(dbs *db.DBC) error {
dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec() dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
@ -356,16 +315,16 @@ func TestTransactionRollbackOnError(t *testing.T) {
assert.EqualValues(t, "Opps", err.Error()) assert.EqualValues(t, "Opps", err.Error())
//let's try a fresh connection //let's try a fresh connection
dbc2 := getDBC() dbc2 := getDBC(p)
defer dbc2.Close()
assert.EqualValues(t, 0, countFoos(t, dbc2)) assert.EqualValues(t, 0, countFoos(t, dbc2))
} }
func TestTransactionRollbackOnPanic(t *testing.T) { func TestTransactionRollbackOnPanic(t *testing.T) {
createFooTable(t) createFooTable(t)
dbc := getDBC() p := getPool()
defer dbc.Close() dbc := getDBC(p)
defer p.Close()
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@ -373,8 +332,7 @@ func TestTransactionRollbackOnPanic(t *testing.T) {
assert.EqualValues(t, str, "Ooops") assert.EqualValues(t, str, "Ooops")
//let's try a fresh connection //let's try a fresh connection
dbc2 := getDBC() dbc2 := getDBC(p)
defer dbc2.Close()
assert.EqualValues(t, 0, countFoos(t, dbc2)) assert.EqualValues(t, 0, countFoos(t, dbc2))
} }
@ -399,8 +357,9 @@ func countFoos(t *testing.T, dbc *db.DBC) int {
func TestTransactionRollback(t *testing.T) { func TestTransactionRollback(t *testing.T) {
createFooTable(t) createFooTable(t)
dbc := getDBC() p := getPool()
defer dbc.Close() dbc := getDBC(p)
defer p.Close()
assert.Nil(t, dbc.Begin()) assert.Nil(t, dbc.Begin())
{ {
_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec() _, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
@ -422,8 +381,7 @@ func TestTransactionRollback(t *testing.T) {
} }
//let's try a fresh connection //let's try a fresh connection
dbc2 := getDBC() dbc2 := getDBC(p)
defer dbc2.Close()
{ {
err := dbc2.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res) err := dbc2.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res)
assert.Nil(t, err) assert.Nil(t, err)
@ -434,8 +392,9 @@ func TestTransactionRollback(t *testing.T) {
func TestTransactionNestedRollback(t *testing.T) { func TestTransactionNestedRollback(t *testing.T) {
createFooTable(t) createFooTable(t)
dbc := getDBC() p := getPool()
defer dbc.Close() dbc := getDBC(p)
defer p.Close()
//begin 1 //begin 1
assert.Nil(t, dbc.Begin()) assert.Nil(t, dbc.Begin())
{ {