package db_test import ( "log" "os" "testing" "git.bit5.ru/backend/db" "git.bit5.ru/backend/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/go-logr/stdr" ) var logger = stdr.New(log.New(os.Stdout, "", log.Lshortfile)) func getSettings() db.Settings { //TODO: use ENV settings as well //stdr.SetVerbosity(2) return db.Settings{Host: "127.0.0.1", Port: "3306", User: "root", Pass: "test", Name: "tests", Prefix: "tests"} } func getPool() *db.Pool { return db.OpenPool(getSettings()) } func getDBC(p *db.Pool) *db.DBC { dbc := db.GetDBC(p, logger) return dbc } func TestDefaultClientCharsetAndCollation(t *testing.T) { p := getPool() dbc := getDBC(p) var result = make(map[string]string) characterSets, err := dbc.DB().Query("SHOW VARIABLES WHERE Variable_name in ('character_set_client', 'character_set_connection', 'character_set_results');") require.Nil(t, err) for characterSets.Next() { var variableName string var value string characterSets.Scan(&variableName, &value) result[variableName] = value } collations, err := dbc.DB().Query("show variables where Variable_name = 'collation_connection';") require.Nil(t, err) for collations.Next() { var variableName string var value string collations.Scan(&variableName, &value) result[variableName] = value } assert.Equal(t, "utf8", result["character_set_client"]) assert.Equal(t, "utf8", result["character_set_connection"]) assert.Equal(t, "utf8", result["character_set_results"]) assert.Equal(t, "utf8_general_ci", result["collation_connection"]) } func TestClientCharsetAndCollation(t *testing.T) { DSNWithLatinCollation := getSettings() DSNWithLatinCollation.Params = "?collation=latin1_swedish_ci" pool := db.OpenPool(DSNWithLatinCollation) defer pool.Close() var resultsLatin1 = make(map[string]string) dbLatin1 := db.GetDBC(pool, logger) characterSets, err := dbLatin1.DB().Query("show variables where Variable_name in ('character_set_client', 'character_set_connection', 'character_set_results');") require.Nil(t, err) for characterSets.Next() { var variableName string var value string characterSets.Scan(&variableName, &value) resultsLatin1[variableName] = value } collations, err := dbLatin1.DB().Query("show variables where Variable_name = 'collation_connection';") require.Nil(t, err) for collations.Next() { var variableName string var value string collations.Scan(&variableName, &value) resultsLatin1[variableName] = value } assert.Equal(t, "latin1", resultsLatin1["character_set_client"]) assert.Equal(t, "latin1", resultsLatin1["character_set_connection"]) assert.Equal(t, "latin1", resultsLatin1["character_set_results"]) assert.Equal(t, "latin1_swedish_ci", resultsLatin1["collation_connection"]) } func createFooTable(t *testing.T) { p := getPool() dbc := getDBC(p) defer p.Close() _, err := dbc.DB().Exec("DROP TABLE IF EXISTS foo") require.Nil(t, err) _, err = dbc.DB().Exec("CREATE TABLE IF NOT EXISTS foo(id int not null)") require.Nil(t, err) } func TestTransactionCommit(t *testing.T) { createFooTable(t) p := getPool() dbc := getDBC(p) defer p.Close() require.Nil(t, dbc.Begin()) dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec() //let's try a fresh connection dbc1 := getDBC(p) assert.EqualValues(t, 0, countFoos(t, dbc1)) assert.EqualValues(t, 2, countFoos(t, dbc)) dbc.Commit() assert.EqualValues(t, 2, countFoos(t, dbc)) //let's try a fresh connection dbc2 := getDBC(p) assert.EqualValues(t, 2, countFoos(t, dbc2)) } func TestTransactionCommitNestedAllOk(t *testing.T) { createFooTable(t) p := getPool() dbc := getDBC(p) defer p.Close() //begin 1 require.Nil(t, dbc.Begin()) dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec() //let's try a fresh connection dbc1 := getDBC(p) assert.EqualValues(t, 0, countFoos(t, dbc1)) //begin 2 require.Nil(t, dbc.Begin()) dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(3)").Exec() //commit 2 dbc.Commit() assert.EqualValues(t, 3, countFoos(t, dbc)) //let's try a fresh connection dbc2 := getDBC(p) assert.EqualValues(t, 0, countFoos(t, dbc2)) //commit 1 dbc.Commit() assert.EqualValues(t, 3, countFoos(t, dbc)) //let's try a fresh connection db3 := getDBC(p) assert.EqualValues(t, 3, countFoos(t, db3)) } func TestTransactionCommitNestedRollback(t *testing.T) { createFooTable(t) p := getPool() dbc := getDBC(p) defer p.Close() //begin 1 require.Nil(t, dbc.Begin()) dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec() //let's try a fresh connection dbc1 := getDBC(p) assert.EqualValues(t, 0, countFoos(t, dbc1)) //begin 2 require.Nil(t, dbc.Begin()) dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(3)").Exec() //rollback 2 dbc.Rollback() //rollback above doesn't have an effect since we are in the //nested transaction assert.EqualValues(t, 3, countFoos(t, dbc)) //let's try a fresh connection dbc2 := getDBC(p) assert.EqualValues(t, 0, countFoos(t, dbc2)) //rollback 1 dbc.Rollback() assert.EqualValues(t, 0, countFoos(t, dbc)) //let's try a fresh connection db3 := getDBC(p) assert.EqualValues(t, 0, countFoos(t, db3)) } func TestTransactionRollbackOnDeferAllOK(t *testing.T) { createFooTable(t) p := getPool() dbc := getDBC(p) defer p.Close() fn := func() { dbc.Begin() defer dbc.RollbackOnDefer() dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec() assert.EqualValues(t, 2, countFoos(t, dbc)) fnNested := func() { dbc.Begin() defer dbc.RollbackOnDefer() dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(3)").Exec() assert.EqualValues(t, 3, countFoos(t, dbc)) dbc.Commit() } fnNested() dbc.Commit() } fn() dbc1 := getDBC(p) assert.EqualValues(t, 3, countFoos(t, dbc1)) } func TestTransactionRollbackOnDefer(t *testing.T) { createFooTable(t) p := getPool() dbc := getDBC(p) defer p.Close() fn := func() { dbc.Begin() defer dbc.RollbackOnDefer() dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec() assert.EqualValues(t, 2, countFoos(t, dbc)) fnNested := func() { dbc.Begin() defer dbc.RollbackOnDefer() dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(3)").Exec() assert.EqualValues(t, 3, countFoos(t, dbc)) //commit is missing for some reason, emulating error //dbc.Commit() } fnNested() //commit is missing for some reason, emulating error //dbc.Commit() } fn() dbc1 := getDBC(p) assert.EqualValues(t, 0, countFoos(t, dbc1)) } func TestTransaction(t *testing.T) { createFooTable(t) p := getPool() dbc := getDBC(p) defer p.Close() require.Nil(t, dbc.Transaction(func(dbs *db.DBC) error { dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec() //let's try a fresh connection dbc1 := getDBC(p) assert.EqualValues(t, 0, countFoos(t, dbc1)) assert.EqualValues(t, 2, countFoos(t, dbc)) return nil })) //let's try a fresh connection dbc2 := getDBC(p) assert.EqualValues(t, 2, countFoos(t, dbc2)) } func TestTransactionRollbackOnError(t *testing.T) { createFooTable(t) p := getPool() dbc := getDBC(p) defer p.Close() err := dbc.Transaction(func(dbs *db.DBC) error { dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec() assert.EqualValues(t, 2, countFoos(t, dbc)) return errors.New("Opps") }) assert.EqualValues(t, "Opps", err.Error()) //let's try a fresh connection dbc2 := getDBC(p) assert.EqualValues(t, 0, countFoos(t, dbc2)) } func TestTransactionRollbackOnPanic(t *testing.T) { createFooTable(t) p := getPool() dbc := getDBC(p) defer p.Close() defer func() { if r := recover(); r != nil { str := r.(string) assert.EqualValues(t, str, "Ooops") //let's try a fresh connection dbc2 := getDBC(p) assert.EqualValues(t, 0, countFoos(t, dbc2)) } }() dbc.Transaction(func(dbs *db.DBC) error { dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec() assert.EqualValues(t, 2, countFoos(t, dbc)) panic("Ooops") return nil }) } func countFoos(t *testing.T, dbc *db.DBC) int { var res int err := dbc.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res) require.Nil(t, err) return res } func TestTransactionRollback(t *testing.T) { createFooTable(t) p := getPool() dbc := getDBC(p) defer p.Close() require.Nil(t, dbc.Begin()) { _, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec() require.Nil(t, err) } var res int { err := dbc.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res) require.Nil(t, err) assert.EqualValues(t, 2, res) } dbc.Rollback() { err := dbc.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res) require.Nil(t, err) assert.EqualValues(t, 0, res) } //let's try a fresh connection dbc2 := getDBC(p) { err := dbc2.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res) require.Nil(t, err) assert.EqualValues(t, 0, res) } } func TestTransactionNestedRollback(t *testing.T) { createFooTable(t) p := getPool() dbc := getDBC(p) defer p.Close() //begin 1 require.Nil(t, dbc.Begin()) { _, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1)").Exec() require.Nil(t, err) } var res int { err := dbc.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res) require.Nil(t, err) assert.EqualValues(t, 1, res) } //begin 2 require.Nil(t, dbc.Begin()) { _, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1)").Exec() require.Nil(t, err) } //no real rollback happens here since we are in a 'bigger' transaction //rollback 2 dbc.Rollback() { err := dbc.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res) require.Nil(t, err) assert.EqualValues(t, 2, res) } //rollback 1 dbc.Rollback() { err := dbc.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res) require.Nil(t, err) assert.EqualValues(t, 0, res) } }