package db_test

import (
	"os"
	"testing"

	"git.bit5.ru/backend/colog"
	"git.bit5.ru/backend/db"
	"git.bit5.ru/backend/errors"
	"github.com/stretchr/testify/assert"
)

// logger is shared by every DBC created in these tests and writes to stderr.
var logger = colog.NewCoLog(os.Stderr, "", 0)

// Queries reading the connection-level character set and collation session
// variables; each row is a (Variable_name, Value) pair.
const (
	charsetQuery   = "show variables where Variable_name in ('character_set_client', 'character_set_connection', 'character_set_results');"
	collationQuery = "show variables where Variable_name = 'collation_connection';"
)

// getSettings returns the connection settings of the local MySQL test
// database (root@127.0.0.1:3306, schema "tests").
func getSettings() db.Settings {
	//TODO: use ENV settings as well
	return db.Settings{Host: "127.0.0.1", Port: "3306", User: "root", Pass: "test", Name: "tests"}
}

// getPool opens a new pool for the test database; callers are expected to
// Close it.
func getPool() *db.Pool {
	return db.OpenPool(getSettings())
}

// getDBC obtains a DBC from pool p that logs through the shared logger.
// The tests use a freshly obtained DBC as "a fresh connection" to observe
// what is visible outside another DBC's transaction.
func getDBC(p *db.Pool) *db.DBC {
	return db.GetDBC(p, logger)
}

// loadVariables runs a SHOW VARIABLES query on dbc and stores every returned
// (Variable_name, Value) pair into result.
//
// Query, Scan and iteration errors are reported as test failures. The rows
// are always closed so the underlying connection is released.
func loadVariables(t *testing.T, dbc *db.DBC, query string, result map[string]string) {
	t.Helper()
	rows, err := dbc.DB().Query(query)
	if !assert.Nil(t, err) {
		// rows is unusable when Query fails; iterating it would panic.
		return
	}
	defer rows.Close()
	for rows.Next() {
		var name, value string
		if assert.Nil(t, rows.Scan(&name, &value)) {
			result[name] = value
		}
	}
	assert.Nil(t, rows.Err())
}

// execUpdate runs sql through dbc's UpdateBySQL builder and reports any
// execution error as a test failure.
func execUpdate(t *testing.T, dbc *db.DBC, sql string) {
	t.Helper()
	_, err := dbc.UpdateBySQL(sql).Exec()
	assert.Nil(t, err)
}

func TestDefaultClientCharsetAndCollation(t *testing.T) {
	p := getPool()
	defer p.Close()
	dbc := getDBC(p)

	result := make(map[string]string)
	loadVariables(t, dbc, charsetQuery, result)
	loadVariables(t, dbc, collationQuery, result)

	assert.Equal(t, "utf8", result["character_set_client"])
	assert.Equal(t, "utf8", result["character_set_connection"])
	assert.Equal(t, "utf8", result["character_set_results"])
	assert.Equal(t, "utf8_general_ci", result["collation_connection"])
}

func TestClientCharsetAndCollation(t *testing.T) {
	settings := getSettings()
	settings.Params = "?collation=latin1_swedish_ci"
	pool := db.OpenPool(settings)
	defer pool.Close()
	dbLatin1 := getDBC(pool)

	resultsLatin1 := make(map[string]string)
	loadVariables(t, dbLatin1, charsetQuery, resultsLatin1)
	loadVariables(t, dbLatin1, collationQuery, resultsLatin1)

	assert.Equal(t, "latin1", resultsLatin1["character_set_client"])
	assert.Equal(t, "latin1", resultsLatin1["character_set_connection"])
	assert.Equal(t, "latin1", resultsLatin1["character_set_results"])
	assert.Equal(t, "latin1_swedish_ci", resultsLatin1["collation_connection"])
}

// createFooTable (re)creates an empty foo(id int not null) table using its
// own short-lived pool.
func createFooTable(t *testing.T) {
	p := getPool()
	defer p.Close()
	dbc := getDBC(p)

	_, err := dbc.DB().Exec("DROP TABLE IF EXISTS foo")
	assert.Nil(t, err)
	_, err = dbc.DB().Exec("CREATE TABLE IF NOT EXISTS foo(id int not null)")
	assert.Nil(t, err)
}

func TestTransactionCommit(t *testing.T) {
	createFooTable(t)
	p := getPool()
	defer p.Close()
	dbc := getDBC(p)

	assert.Nil(t, dbc.Begin())
	execUpdate(t, dbc, "INSERT INTO foo(id) VALUES(1),(2)")

	// A fresh connection must not see the uncommitted rows...
	assert.EqualValues(t, 0, countFoos(t, getDBC(p)))
	// ...while the transaction itself does.
	assert.EqualValues(t, 2, countFoos(t, dbc))

	dbc.Commit()
	assert.EqualValues(t, 2, countFoos(t, dbc))
	// After the commit a fresh connection sees them too.
	assert.EqualValues(t, 2, countFoos(t, getDBC(p)))
}

func TestTransactionCommitNestedAllOk(t *testing.T) {
	createFooTable(t)
	p := getPool()
	defer p.Close()
	dbc := getDBC(p)

	// begin 1
	assert.Nil(t, dbc.Begin())
	execUpdate(t, dbc, "INSERT INTO foo(id) VALUES(1),(2)")
	assert.EqualValues(t, 0, countFoos(t, getDBC(p)))

	// begin 2 (nested)
	assert.Nil(t, dbc.Begin())
	execUpdate(t, dbc, "INSERT INTO foo(id) VALUES(3)")

	// commit 2: the test expects nothing to become visible outside yet,
	// because the outer transaction is still open.
	dbc.Commit()
	assert.EqualValues(t, 3, countFoos(t, dbc))
	assert.EqualValues(t, 0, countFoos(t, getDBC(p)))

	// commit 1: now everything is committed.
	dbc.Commit()
	assert.EqualValues(t, 3, countFoos(t, dbc))
	assert.EqualValues(t, 3, countFoos(t, getDBC(p)))
}

func TestTransactionCommitNestedRollback(t *testing.T) {
	createFooTable(t)
	p := getPool()
	defer p.Close()
	dbc := getDBC(p)

	// begin 1
	assert.Nil(t, dbc.Begin())
	execUpdate(t, dbc, "INSERT INTO foo(id) VALUES(1),(2)")
	assert.EqualValues(t, 0, countFoos(t, getDBC(p)))

	// begin 2 (nested)
	assert.Nil(t, dbc.Begin())
	execUpdate(t, dbc, "INSERT INTO foo(id) VALUES(3)")

	// rollback 2: expected to have no effect yet since we are still inside
	// the outer transaction.
	dbc.Rollback()
	assert.EqualValues(t, 3, countFoos(t, dbc))
	assert.EqualValues(t, 0, countFoos(t, getDBC(p)))

	// rollback 1: discards everything.
	dbc.Rollback()
	assert.EqualValues(t, 0, countFoos(t, dbc))
	assert.EqualValues(t, 0, countFoos(t, getDBC(p)))
}

func TestTransactionRollbackOnDeferAllOK(t *testing.T) {
	createFooTable(t)
	p := getPool()
	defer p.Close()
	dbc := getDBC(p)

	fn := func() {
		assert.Nil(t, dbc.Begin())
		defer dbc.RollbackOnDefer()
		execUpdate(t, dbc, "INSERT INTO foo(id) VALUES(1),(2)")
		assert.EqualValues(t, 2, countFoos(t, dbc))

		fnNested := func() {
			assert.Nil(t, dbc.Begin())
			defer dbc.RollbackOnDefer()
			execUpdate(t, dbc, "INSERT INTO foo(id) VALUES(3)")
			assert.EqualValues(t, 3, countFoos(t, dbc))
			dbc.Commit()
		}
		fnNested()

		dbc.Commit()
	}
	fn()

	assert.EqualValues(t, 3, countFoos(t, getDBC(p)))
}

func TestTransactionRollbackOnDefer(t *testing.T) {
	createFooTable(t)
	p := getPool()
	defer p.Close()
	dbc := getDBC(p)

	fn := func() {
		assert.Nil(t, dbc.Begin())
		defer dbc.RollbackOnDefer()
		execUpdate(t, dbc, "INSERT INTO foo(id) VALUES(1),(2)")
		assert.EqualValues(t, 2, countFoos(t, dbc))

		fnNested := func() {
			assert.Nil(t, dbc.Begin())
			defer dbc.RollbackOnDefer()
			execUpdate(t, dbc, "INSERT INTO foo(id) VALUES(3)")
			assert.EqualValues(t, 3, countFoos(t, dbc))
			//commit is missing for some reason, emulating error
			//dbc.Commit()
		}
		fnNested()

		//commit is missing for some reason, emulating error
		//dbc.Commit()
	}
	fn()

	assert.EqualValues(t, 0, countFoos(t, getDBC(p)))
}

func TestTransaction(t *testing.T) {
	createFooTable(t)
	p := getPool()
	defer p.Close()
	dbc := getDBC(p)

	assert.Nil(t, dbc.Transaction(func(dbs *db.DBC) error {
		execUpdate(t, dbs, "INSERT INTO foo(id) VALUES(1),(2)")
		// Invisible to a fresh connection until Transaction commits.
		assert.EqualValues(t, 0, countFoos(t, getDBC(p)))
		assert.EqualValues(t, 2, countFoos(t, dbc))
		return nil
	}))

	assert.EqualValues(t, 2, countFoos(t, getDBC(p)))
}

func TestTransactionRollbackOnError(t *testing.T) {
	createFooTable(t)
	p := getPool()
	defer p.Close()
	dbc := getDBC(p)

	err := dbc.Transaction(func(dbs *db.DBC) error {
		execUpdate(t, dbs, "INSERT INTO foo(id) VALUES(1),(2)")
		assert.EqualValues(t, 2, countFoos(t, dbc))
		return errors.New("Opps")
	})
	// EqualError also fails cleanly when err is nil, where err.Error()
	// would panic.
	assert.EqualError(t, err, "Opps")

	assert.EqualValues(t, 0, countFoos(t, getDBC(p)))
}

func TestTransactionRollbackOnPanic(t *testing.T) {
	createFooTable(t)
	p := getPool()
	defer p.Close()
	dbc := getDBC(p)

	// Registered after p.Close, so it runs first while the pool is still open.
	defer func() {
		// r is nil when Transaction did not re-panic, which fails here.
		r := recover()
		assert.Equal(t, "Ooops", r)
		assert.EqualValues(t, 0, countFoos(t, getDBC(p)))
	}()

	dbc.Transaction(func(dbs *db.DBC) error {
		execUpdate(t, dbs, "INSERT INTO foo(id) VALUES(1),(2)")
		assert.EqualValues(t, 2, countFoos(t, dbc))
		panic("Ooops")
	})
}

// countFoos returns the number of rows in foo as seen through dbc, reporting
// a query error as a test failure.
func countFoos(t *testing.T, dbc *db.DBC) int {
	t.Helper()
	var res int
	err := dbc.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res)
	assert.Nil(t, err)
	return res
}

func TestTransactionRollback(t *testing.T) {
	createFooTable(t)
	p := getPool()
	defer p.Close()
	dbc := getDBC(p)

	assert.Nil(t, dbc.Begin())
	execUpdate(t, dbc, "INSERT INTO foo(id) VALUES(1),(2)")
	assert.EqualValues(t, 2, countFoos(t, dbc))

	dbc.Rollback()
	assert.EqualValues(t, 0, countFoos(t, dbc))
	assert.EqualValues(t, 0, countFoos(t, getDBC(p)))
}

func TestTransactionNestedRollback(t *testing.T) {
	createFooTable(t)
	p := getPool()
	defer p.Close()
	dbc := getDBC(p)

	// begin 1
	assert.Nil(t, dbc.Begin())
	execUpdate(t, dbc, "INSERT INTO foo(id) VALUES(1)")
	assert.EqualValues(t, 1, countFoos(t, dbc))

	// begin 2
	assert.Nil(t, dbc.Begin())
	execUpdate(t, dbc, "INSERT INTO foo(id) VALUES(1)")

	// rollback 2: no real rollback happens here since we are in a 'bigger'
	// transaction.
	dbc.Rollback()
	assert.EqualValues(t, 2, countFoos(t, dbc))

	// rollback 1
	dbc.Rollback()
	assert.EqualValues(t, 0, countFoos(t, dbc))
}

// Disabled: the tests below depend on env, dbshrd, dbmeta, autogen, tests,
// util and testify/require, none of which this file imports.

//func makePlayers(t *testing.T, env *env.Env, amount int) []uint32 {
//	var players []uint32
//	mainDb := env.MainDb()
//
//	for i := 0; i < amount; i++ {
//		shardPlayer, err := dbshrd.CreateShardPlayer(env.Settings.DB_SHARDS, mainDb, 1)
//		require.NoError(t, err)
//		players = append(players, shardPlayer.Id)
//	}
//
//	//Note: There is only one test shard db, so the shard id = 1 for all players
//	shardDb, err := dbshrd.GetShardDb(env.Logger, env.Settings.DB_SHARDS, 1)
//	require.NoError(t, err)
//	require.NotNil(t, shardDb)
//	defer shardDb.Close()
//
//	for i := 0; i < len(players); i++ {
//		player := autogen.NewDataPlayer()
//		player.Id = players[i]
//		err = dbmeta.SaveRow(shardDb, player)
//		require.NoError(t, err)
//	}
//
//	return players
//}

//func TestSelectBySQLWithChunkedIN(t *testing.T) {
//	env := tests.NewEnvCleanStorage()
//	defer env.Close()
//
//	//Note: There is only one test shard db, so the shard id = 1 for all players
//	shardDb, err := dbshrd.GetShardDb(env.Logger, env.Settings.DB_SHARDS, 1)
//	require.NoError(t, err)
//	require.NotNil(t, shardDb)
//	defer shardDb.Close()
//
//	var playersAmount = 100
//	var playerIds []uint32
//	playerIds = makePlayers(t, env, playersAmount)
//
//	for chunkSizeForIN := 1; chunkSizeForIN <= 101; chunkSizeForIN++ {
//		var totalIds []uint32
//		fullChunksAmount := playersAmount / chunkSizeForIN
//		modulo := playersAmount % chunkSizeForIN
//		sql := "SELECT id FROM player WHERE 1 = ? AND id IN ? AND 2 = ?"
//		builders := shardDb.SelectBySQLWithChunkedIN(sql, chunkSizeForIN, 1, playerIds, 2)
//
//		require.Len(t, builders, fullChunksAmount+util.BoolToInt(modulo > 0))
//
//		for queryIndex, builder := range builders {
//			var queryIds []uint32
//			_, err := builder.LoadValues(&queryIds)
//			require.NoError(t, err)
//			require.NotNil(t, queryIds)
//			totalIds = append(totalIds, queryIds...)
//
//			if queryIndex < fullChunksAmount {
//				require.Len(t, queryIds, chunkSizeForIN)
//			} else {
//				require.Len(t, queryIds, modulo)
//			}
//		}
//
//		require.Len(t, totalIds, playersAmount)
//	}
//}