diff --git a/db.go b/db.go index 49a9b6b..c7ef812 100644 --- a/db.go +++ b/db.go @@ -8,7 +8,6 @@ import ( "git.bit5.ru/backend/dbr" "git.bit5.ru/backend/errors" "git.bit5.ru/backend/mysql" - "git.bit5.ru/backend/res_tracker" ) type Settings struct { @@ -22,78 +21,57 @@ func (s *Settings) ConnStr() string { return s.User + ":" + s.Pass + "@tcp(" + s.Host + ":" + s.Port + ")/" + s.Name + s.Params } +type Pool struct { + DB *sql.DB + S Settings +} + +func (p *Pool) Close() { + p.DB.Close() +} + type DBC struct { - Logger *colog.CoLog - s Settings - _con *dbr.Connection //lazy one, should be accessed via con() method - _sess *dbr.Session //lazy one, should be accessed via sess() method + Logger *colog.CoLog + P *Pool + //NOTE: it's not a 'connection', it embeds sql.Pool (dbr uses strange names for entities) + con *dbr.Connection + sess *dbr.Session trx *dbr.Tx trxRefs int commitTry int } -func GetDBC(logger *colog.CoLog, s Settings) *DBC { +func OpenPool(s Settings) *Pool { + driver := s.Driver + if len(driver) == 0 { + driver = "mysql" + } + //NOTE: sql.Open(..) doesn't happen to return an error + sqlDb, _ := sql.Open(driver, s.ConnStr()) - logger = logger.Clone().AddPrefix("[" + s.Prefix + "] ") + //TODO: take values from Settings + sqlDb.SetMaxIdleConns(100) + sqlDb.SetMaxOpenConns(0) + sqlDb.SetConnMaxLifetime(0) + sqlDb.SetConnMaxIdleTime(0) - dbc := &DBC{Logger: logger, s: s, _con: nil, _sess: nil} + return &Pool{DB: sqlDb, S: s} +} + +func GetDBC(p *Pool, logger *colog.CoLog) *DBC { + + logger = logger.Clone().AddPrefix("[" + p.S.Prefix + "] ") + + con := dbr.NewConnection(p.DB, nil) + sess := con.NewSession(&EventReceiver{logger: logger, s: p.S}) + + dbc := &DBC{Logger: logger, P: p, con: con, sess: sess} return dbc } -//NOTE: In its current implementation this method creates a new Pool -// on each connection request. This is subotimal and should be -// addressed in a new version of the package -func (dbc *DBC) con() *dbr.Connection { - if dbc._con == nil { - driver := dbc.s.Driver - if len(driver) == 0 { - driver = "mysql" - } - //NOTE: sql.Open(..) doesn't happen to return an error - pool, _ := sql.Open(driver, dbc.s.ConnStr()) - - dbc._con = dbr.NewConnection(pool, nil) - - res_tracker.Track(dbc) - } - return dbc._con -} - -func (dbc *DBC) sess() *dbr.Session { - if dbc._sess == nil { - dbc._sess = dbc.con().NewSession(&EventReceiver{logger: dbc.Logger, s: dbc.s}) - } - return dbc._sess -} - -func (dbc *DBC) Open() { - dbc.con() -} - -func (dbc *DBC) IsOpen() bool { - return dbc._con != nil -} - -func (dbc *DBC) Close() error { - if dbc._con != nil { - res_tracker.Untrack(dbc) - - dbc.Rollback() - err := dbc._con.Db.Close() - - dbc._con = nil - dbc._sess = nil - - return err - } else { - return nil - } -} - -//NOTE: for low level stuff func (dbc *DBC) DB() *sql.DB { - return dbc.con().Db + return dbc.P.DB } func (dbc *DBC) Transaction(txFunc func(dbc *DBC) error) (err error) { @@ -121,7 +99,7 @@ func (dbc *DBC) Begin() error { dbc.trxRefs++ return nil } - trx, err := dbc.sess().Begin() + trx, err := dbc.sess.Begin() if err == nil { dbc.trx = trx dbc.trxRefs = 1 @@ -169,8 +147,8 @@ func (dbc *DBC) Commit() error { func (dbc *DBC) Update(table string) *dbr.UpdateBuilder { if dbc.trx == nil { return &dbr.UpdateBuilder{ - Session: dbc.sess(), - Runner: dbc.con().Db, + Session: dbc.sess, + Runner: dbc.P.DB, Table: table, } } @@ -186,8 +164,8 @@ func (dbc *DBC) Update(table string) *dbr.UpdateBuilder { func (dbc *DBC) UpdateBySQL(sql string, args ...interface{}) *dbr.UpdateBuilder { if dbc.trx == nil { return &dbr.UpdateBuilder{ - Session: dbc.sess(), - Runner: dbc.con().Db, + Session: dbc.sess, + Runner: dbc.P.DB, RawFullSql: sql, RawArguments: args, } @@ -205,8 +183,8 @@ func (dbc *DBC) UpdateBySQL(sql string, args ...interface{}) *dbr.UpdateBuilder func (dbc *DBC) DeleteFrom(from string) *dbr.DeleteBuilder { if dbc.trx == nil { return &dbr.DeleteBuilder{ - Session: dbc.sess(), - Runner: dbc.con().Db, + Session: dbc.sess, + Runner: dbc.P.DB, From: from, } } @@ -222,8 +200,8 @@ func (dbc *DBC) DeleteFrom(from string) *dbr.DeleteBuilder { func (dbc *DBC) Select(cols ...string) *dbr.SelectBuilder { if dbc.trx == nil { return &dbr.SelectBuilder{ - Session: dbc.sess(), - Runner: dbc.con().Db, + Session: dbc.sess, + Runner: dbc.P.DB, Columns: cols, } } @@ -239,8 +217,8 @@ func (dbc *DBC) Select(cols ...string) *dbr.SelectBuilder { func (dbc *DBC) SelectBySQL(sql string, args ...interface{}) *dbr.SelectBuilder { if dbc.trx == nil { return &dbr.SelectBuilder{ - Session: dbc.sess(), - Runner: dbc.con().Db, + Session: dbc.sess, + Runner: dbc.P.DB, RawFullSql: sql, RawArguments: args, } @@ -326,8 +304,8 @@ func (dbc *DBC) SelectBySQLWithChunkedIN(sql string, chunkSize int, args ...inte func (dbc *DBC) InsertInto(into string) *dbr.InsertBuilder { if dbc.trx == nil { return &dbr.InsertBuilder{ - Session: dbc.sess(), - Runner: dbc.con().Db, + Session: dbc.sess, + Runner: dbc.P.DB, Into: into, } } diff --git a/db_test.go b/db_test.go index f62c766..bd97cee 100644 --- a/db_test.go +++ b/db_test.go @@ -11,17 +11,25 @@ import ( "github.com/stretchr/testify/assert" ) -var settings = db.Settings{Host: "127.0.0.1", Port: "3306", User: "root", Pass: "test", Name: "tests"} var logger = colog.NewCoLog(os.Stderr, "", 0) -func getDBC() *db.DBC { - dbc := db.GetDBC(logger, settings) +func getSettings() db.Settings { + //TODO: use ENV settings as well + return db.Settings{Host: "127.0.0.1", Port: "3306", User: "root", Pass: "test", Name: "tests"} +} + +func getPool() *db.Pool { + return db.OpenPool(getSettings()) +} + +func getDBC(p *db.Pool) *db.DBC { + dbc := db.GetDBC(p, logger) return dbc } func TestDefaultClientCharsetAndCollation(t *testing.T) { - dbc := getDBC() - defer dbc.Close() + p := getPool() + dbc := getDBC(p) var result = make(map[string]string) @@ -58,13 +66,14 @@ func TestDefaultClientCharsetAndCollation(t *testing.T) { } func TestClientCharsetAndCollation(t *testing.T) { - DSNWithLatinCollation := settings + DSNWithLatinCollation := getSettings() DSNWithLatinCollation.Params = "?collation=latin1_swedish_ci" + pool := db.OpenPool(DSNWithLatinCollation) + defer pool.Close() var resultsLatin1 = make(map[string]string) - dbLatin1 := db.GetDBC(logger, DSNWithLatinCollation) - defer dbLatin1.Close() + dbLatin1 := db.GetDBC(pool, logger) characterSets, err := dbLatin1.DB().Query("show variables where Variable_name in ('character_set_client', 'character_set_connection', 'character_set_results');") assert.Nil(t, err) @@ -98,55 +107,10 @@ func TestClientCharsetAndCollation(t *testing.T) { } -func TestCloseConn(t *testing.T) { - dbc := getDBC() - defer dbc.Close() - - var res int - err := dbc.SelectBySQL("SELECT 1").LoadValue(&res) - assert.Nil(t, err) - assert.EqualValues(t, 1, res) - assert.True(t, dbc.IsOpen()) - - dbc.Close() - assert.False(t, dbc.IsOpen()) - //connection is automatically restored - err = dbc.SelectBySQL("SELECT 1").LoadValue(&res) - assert.True(t, dbc.IsOpen()) - assert.Nil(t, err) - assert.EqualValues(t, 1, res) -} - -func TestDoubleCloseConnIsOk(t *testing.T) { - dbc := getDBC() - defer dbc.Close() - - var res int - err := dbc.SelectBySQL("SELECT 1").LoadValue(&res) - assert.Nil(t, err) - assert.EqualValues(t, 1, res) - assert.True(t, dbc.IsOpen()) - - dbc.Close() - assert.False(t, dbc.IsOpen()) - //connection is automatically restored - err = dbc.SelectBySQL("SELECT 1").LoadValue(&res) - assert.True(t, dbc.IsOpen()) - assert.Nil(t, err) - assert.EqualValues(t, 1, res) - - dbc.Close() - assert.False(t, dbc.IsOpen()) - //connection is automatically restored - err = dbc.SelectBySQL("SELECT 1").LoadValue(&res) - assert.True(t, dbc.IsOpen()) - assert.Nil(t, err) - assert.EqualValues(t, 1, res) -} - func createFooTable(t *testing.T) { - dbc := getDBC() - defer dbc.Close() + p := getPool() + dbc := getDBC(p) + defer p.Close() _, err := dbc.DB().Exec("DROP TABLE IF EXISTS foo") assert.Nil(t, err) @@ -157,15 +121,15 @@ func createFooTable(t *testing.T) { func TestTransactionCommit(t *testing.T) { createFooTable(t) - dbc := getDBC() - defer dbc.Close() + p := getPool() + dbc := getDBC(p) + defer p.Close() assert.Nil(t, dbc.Begin()) dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec() //let's try a fresh connection - dbc1 := getDBC() - defer dbc1.Close() + dbc1 := getDBC(p) assert.EqualValues(t, 0, countFoos(t, dbc1)) assert.EqualValues(t, 2, countFoos(t, dbc)) @@ -173,23 +137,22 @@ func TestTransactionCommit(t *testing.T) { assert.EqualValues(t, 2, countFoos(t, dbc)) //let's try a fresh connection - dbc2 := getDBC() - defer dbc2.Close() + dbc2 := getDBC(p) assert.EqualValues(t, 2, countFoos(t, dbc2)) } func TestTransactionCommitNestedAllOk(t *testing.T) { createFooTable(t) - dbc := getDBC() - defer dbc.Close() + p := getPool() + dbc := getDBC(p) + defer p.Close() //begin 1 assert.Nil(t, dbc.Begin()) dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec() //let's try a fresh connection - dbc1 := getDBC() - defer dbc1.Close() + dbc1 := getDBC(p) assert.EqualValues(t, 0, countFoos(t, dbc1)) //begin 2 @@ -200,8 +163,7 @@ func TestTransactionCommitNestedAllOk(t *testing.T) { assert.EqualValues(t, 3, countFoos(t, dbc)) //let's try a fresh connection - dbc2 := getDBC() - defer dbc2.Close() + dbc2 := getDBC(p) assert.EqualValues(t, 0, countFoos(t, dbc2)) //commit 1 @@ -209,23 +171,22 @@ func TestTransactionCommitNestedAllOk(t *testing.T) { assert.EqualValues(t, 3, countFoos(t, dbc)) //let's try a fresh connection - db3 := getDBC() - defer db3.Close() + db3 := getDBC(p) assert.EqualValues(t, 3, countFoos(t, db3)) } func TestTransactionCommitNestedRollback(t *testing.T) { createFooTable(t) - dbc := getDBC() - defer dbc.Close() + p := getPool() + dbc := getDBC(p) + defer p.Close() //begin 1 assert.Nil(t, dbc.Begin()) dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec() //let's try a fresh connection - dbc1 := getDBC() - defer dbc1.Close() + dbc1 := getDBC(p) assert.EqualValues(t, 0, countFoos(t, dbc1)) //begin 2 @@ -238,8 +199,7 @@ func TestTransactionCommitNestedRollback(t *testing.T) { assert.EqualValues(t, 3, countFoos(t, dbc)) //let's try a fresh connection - dbc2 := getDBC() - defer dbc2.Close() + dbc2 := getDBC(p) assert.EqualValues(t, 0, countFoos(t, dbc2)) //rollback 1 @@ -247,16 +207,16 @@ func TestTransactionCommitNestedRollback(t *testing.T) { assert.EqualValues(t, 0, countFoos(t, dbc)) //let's try a fresh connection - db3 := getDBC() - defer db3.Close() + db3 := getDBC(p) assert.EqualValues(t, 0, countFoos(t, db3)) } func TestTransactionRollbackOnDeferAllOK(t *testing.T) { createFooTable(t) - dbc := getDBC() - defer dbc.Close() + p := getPool() + dbc := getDBC(p) + defer p.Close() fn := func() { dbc.Begin() @@ -278,16 +238,16 @@ func TestTransactionRollbackOnDeferAllOK(t *testing.T) { } fn() - dbc1 := getDBC() - defer dbc1.Close() + dbc1 := getDBC(p) assert.EqualValues(t, 3, countFoos(t, dbc1)) } func TestTransactionRollbackOnDefer(t *testing.T) { createFooTable(t) - dbc := getDBC() - defer dbc.Close() + p := getPool() + dbc := getDBC(p) + defer p.Close() fn := func() { dbc.Begin() @@ -312,23 +272,22 @@ func TestTransactionRollbackOnDefer(t *testing.T) { fn() - dbc1 := getDBC() - defer dbc1.Close() + dbc1 := getDBC(p) assert.EqualValues(t, 0, countFoos(t, dbc1)) } func TestTransaction(t *testing.T) { createFooTable(t) - dbc := getDBC() - defer dbc.Close() + p := getPool() + dbc := getDBC(p) + defer p.Close() assert.Nil(t, dbc.Transaction(func(dbs *db.DBC) error { dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec() //let's try a fresh connection - dbc1 := getDBC() - defer dbc1.Close() + dbc1 := getDBC(p) assert.EqualValues(t, 0, countFoos(t, dbc1)) assert.EqualValues(t, 2, countFoos(t, dbc)) @@ -336,16 +295,16 @@ func TestTransaction(t *testing.T) { })) //let's try a fresh connection - dbc2 := getDBC() - defer dbc2.Close() + dbc2 := getDBC(p) assert.EqualValues(t, 2, countFoos(t, dbc2)) } func TestTransactionRollbackOnError(t *testing.T) { createFooTable(t) - dbc := getDBC() - defer dbc.Close() + p := getPool() + dbc := getDBC(p) + defer p.Close() err := dbc.Transaction(func(dbs *db.DBC) error { dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec() @@ -356,16 +315,16 @@ func TestTransactionRollbackOnError(t *testing.T) { assert.EqualValues(t, "Opps", err.Error()) //let's try a fresh connection - dbc2 := getDBC() - defer dbc2.Close() + dbc2 := getDBC(p) assert.EqualValues(t, 0, countFoos(t, dbc2)) } func TestTransactionRollbackOnPanic(t *testing.T) { createFooTable(t) - dbc := getDBC() - defer dbc.Close() + p := getPool() + dbc := getDBC(p) + defer p.Close() defer func() { if r := recover(); r != nil { @@ -373,8 +332,7 @@ func TestTransactionRollbackOnPanic(t *testing.T) { assert.EqualValues(t, str, "Ooops") //let's try a fresh connection - dbc2 := getDBC() - defer dbc2.Close() + dbc2 := getDBC(p) assert.EqualValues(t, 0, countFoos(t, dbc2)) } @@ -399,8 +357,9 @@ func countFoos(t *testing.T, dbc *db.DBC) int { func TestTransactionRollback(t *testing.T) { createFooTable(t) - dbc := getDBC() - defer dbc.Close() + p := getPool() + dbc := getDBC(p) + defer p.Close() assert.Nil(t, dbc.Begin()) { _, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec() @@ -422,8 +381,7 @@ func TestTransactionRollback(t *testing.T) { } //let's try a fresh connection - dbc2 := getDBC() - defer dbc2.Close() + dbc2 := getDBC(p) { err := dbc2.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res) assert.Nil(t, err) @@ -434,8 +392,9 @@ func TestTransactionRollback(t *testing.T) { func TestTransactionNestedRollback(t *testing.T) { createFooTable(t) - dbc := getDBC() - defer dbc.Close() + p := getPool() + dbc := getDBC(p) + defer p.Close() //begin 1 assert.Nil(t, dbc.Begin()) {