db/db_test.go

546 lines
13 KiB
Go
Raw Normal View History

2022-10-26 17:28:42 +03:00
package db_test
import (
"os"
2022-10-26 17:28:42 +03:00
"testing"
"git.bit5.ru/backend/colog"
2022-10-26 17:28:42 +03:00
"git.bit5.ru/backend/db"
"git.bit5.ru/backend/errors"
"github.com/stretchr/testify/assert"
)
// Shared fixtures for every test in this file.
//
// settings points at a local database (127.0.0.1:3306, schema "tests");
// NOTE(review): the tests assume that server is running and reachable with
// these credentials — confirm against the CI setup.
// logger writes database log output to stderr with no prefix and no flags.
var (
	settings = db.Settings{
		Host: "127.0.0.1",
		Port: "3306",
		User: "root",
		Pass: "test",
		Name: "tests",
	}
	logger = colog.NewCoLog(os.Stderr, "", 0)
)
2022-10-26 17:28:42 +03:00
func getDBC() *db.DBC {
dbc := db.GetDBC(logger, settings)
2022-10-26 17:28:42 +03:00
return dbc
}
func TestDefaultClientCharsetAndCollation(t *testing.T) {
dbc := getDBC()
defer dbc.Close()
var result = make(map[string]string)
characterSets, err := dbc.DB().Query("show variables where Variable_name in ('character_set_client', 'character_set_connection', 'character_set_results');")
assert.Nil(t, err)
for characterSets.Next() {
var variableName string
var value string
characterSets.Scan(&variableName, &value)
result[variableName] = value
}
collations, err := dbc.DB().Query("show variables where Variable_name = 'collation_connection';")
assert.Nil(t, err)
for collations.Next() {
var variableName string
var value string
collations.Scan(&variableName, &value)
result[variableName] = value
}
assert.Equal(t, "utf8", result["character_set_client"])
assert.Equal(t, "utf8", result["character_set_connection"])
assert.Equal(t, "utf8", result["character_set_results"])
assert.Equal(t, "utf8_general_ci", result["collation_connection"])
}
func TestClientCharsetAndCollation(t *testing.T) {
DSNWithLatinCollation := settings
2022-10-26 17:28:42 +03:00
DSNWithLatinCollation.Params = "?collation=latin1_swedish_ci"
var resultsLatin1 = make(map[string]string)
dbLatin1 := db.GetDBC(logger, DSNWithLatinCollation)
2022-10-26 17:28:42 +03:00
defer dbLatin1.Close()
characterSets, err := dbLatin1.DB().Query("show variables where Variable_name in ('character_set_client', 'character_set_connection', 'character_set_results');")
assert.Nil(t, err)
for characterSets.Next() {
var variableName string
var value string
characterSets.Scan(&variableName, &value)
resultsLatin1[variableName] = value
}
collations, err := dbLatin1.DB().Query("show variables where Variable_name = 'collation_connection';")
assert.Nil(t, err)
for collations.Next() {
var variableName string
var value string
collations.Scan(&variableName, &value)
resultsLatin1[variableName] = value
}
assert.Equal(t, "latin1", resultsLatin1["character_set_client"])
assert.Equal(t, "latin1", resultsLatin1["character_set_connection"])
assert.Equal(t, "latin1", resultsLatin1["character_set_results"])
assert.Equal(t, "latin1_swedish_ci", resultsLatin1["collation_connection"])
}
func TestCloseConn(t *testing.T) {
dbc := getDBC()
defer dbc.Close()
var res int
err := dbc.SelectBySQL("SELECT 1").LoadValue(&res)
assert.Nil(t, err)
assert.EqualValues(t, 1, res)
assert.True(t, dbc.IsOpen())
dbc.Close()
assert.False(t, dbc.IsOpen())
//connection is automatically restored
err = dbc.SelectBySQL("SELECT 1").LoadValue(&res)
assert.True(t, dbc.IsOpen())
assert.Nil(t, err)
assert.EqualValues(t, 1, res)
}
func TestDoubleCloseConnIsOk(t *testing.T) {
dbc := getDBC()
defer dbc.Close()
var res int
err := dbc.SelectBySQL("SELECT 1").LoadValue(&res)
assert.Nil(t, err)
assert.EqualValues(t, 1, res)
assert.True(t, dbc.IsOpen())
dbc.Close()
assert.False(t, dbc.IsOpen())
//connection is automatically restored
err = dbc.SelectBySQL("SELECT 1").LoadValue(&res)
assert.True(t, dbc.IsOpen())
assert.Nil(t, err)
assert.EqualValues(t, 1, res)
dbc.Close()
assert.False(t, dbc.IsOpen())
//connection is automatically restored
err = dbc.SelectBySQL("SELECT 1").LoadValue(&res)
assert.True(t, dbc.IsOpen())
assert.Nil(t, err)
assert.EqualValues(t, 1, res)
}
// createFooTable drops and recreates table foo(id int not null) on its own
// short-lived connection, so each transaction test starts with zero rows.
func createFooTable(t *testing.T) {
	t.Helper()
	dbc := getDBC()
	defer dbc.Close()

	if _, err := dbc.DB().Exec("DROP TABLE IF EXISTS foo"); !assert.Nil(t, err) {
		// Skip the CREATE: with IF NOT EXISTS it would silently keep an old
		// table and its rows, making later counts misleading.
		return
	}
	_, err := dbc.DB().Exec("CREATE TABLE IF NOT EXISTS foo(id int not null)")
	assert.Nil(t, err)
}
// TestTransactionCommit checks that rows inserted between Begin and Commit
// are visible only through the owning handle until Commit, and through any
// connection afterwards.
func TestTransactionCommit(t *testing.T) {
	createFooTable(t)
	dbc := getDBC()
	defer dbc.Close()

	assert.Nil(t, dbc.Begin())
	// Check the insert itself so a failure is reported here rather than as a
	// confusing row-count mismatch below.
	_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
	assert.Nil(t, err)

	// A fresh connection must not see the uncommitted rows.
	dbc1 := getDBC()
	defer dbc1.Close()
	assert.EqualValues(t, 0, countFoos(t, dbc1))
	assert.EqualValues(t, 2, countFoos(t, dbc))

	dbc.Commit()
	assert.EqualValues(t, 2, countFoos(t, dbc))

	// After Commit a fresh connection sees them too.
	dbc2 := getDBC()
	defer dbc2.Close()
	assert.EqualValues(t, 2, countFoos(t, dbc2))
}
// TestTransactionCommitNestedAllOk checks that committing an inner
// (nested) transaction does not publish rows while the outer one is still
// open, and that the outer Commit publishes rows from both levels.
func TestTransactionCommitNestedAllOk(t *testing.T) {
	createFooTable(t)
	dbc := getDBC()
	defer dbc.Close()

	// begin 1
	assert.Nil(t, dbc.Begin())
	_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
	assert.Nil(t, err)

	// A fresh connection must not see the uncommitted rows.
	dbc1 := getDBC()
	defer dbc1.Close()
	assert.EqualValues(t, 0, countFoos(t, dbc1))

	// begin 2
	assert.Nil(t, dbc.Begin())
	_, err = dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(3)").Exec()
	assert.Nil(t, err)

	// commit 2: closes only the inner level.
	dbc.Commit()
	assert.EqualValues(t, 3, countFoos(t, dbc))

	// Still invisible to other connections while level 1 is open.
	dbc2 := getDBC()
	defer dbc2.Close()
	assert.EqualValues(t, 0, countFoos(t, dbc2))

	// commit 1: publishes all three rows.
	dbc.Commit()
	assert.EqualValues(t, 3, countFoos(t, dbc))

	dbc3 := getDBC()
	defer dbc3.Close()
	assert.EqualValues(t, 3, countFoos(t, dbc3))
}
// TestTransactionCommitNestedRollback checks that rolling back an inner
// (nested) transaction has no effect while the outer one is open, and that
// rolling back the outer level discards rows from both levels.
func TestTransactionCommitNestedRollback(t *testing.T) {
	createFooTable(t)
	dbc := getDBC()
	defer dbc.Close()

	// begin 1
	assert.Nil(t, dbc.Begin())
	_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
	assert.Nil(t, err)

	// A fresh connection must not see the uncommitted rows.
	dbc1 := getDBC()
	defer dbc1.Close()
	assert.EqualValues(t, 0, countFoos(t, dbc1))

	// begin 2
	assert.Nil(t, dbc.Begin())
	_, err = dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(3)").Exec()
	assert.Nil(t, err)

	// rollback 2: no effect yet because we are still inside level 1.
	dbc.Rollback()
	assert.EqualValues(t, 3, countFoos(t, dbc))

	dbc2 := getDBC()
	defer dbc2.Close()
	assert.EqualValues(t, 0, countFoos(t, dbc2))

	// rollback 1: discards everything.
	dbc.Rollback()
	assert.EqualValues(t, 0, countFoos(t, dbc))

	dbc3 := getDBC()
	defer dbc3.Close()
	assert.EqualValues(t, 0, countFoos(t, dbc3))
}
// TestTransactionRollbackOnDeferAllOK checks that a deferred
// RollbackOnDefer does not undo work that was committed at both nesting
// levels.
func TestTransactionRollbackOnDeferAllOK(t *testing.T) {
	createFooTable(t)
	dbc := getDBC()
	defer dbc.Close()

	fn := func() {
		assert.Nil(t, dbc.Begin())
		defer dbc.RollbackOnDefer()
		_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
		assert.Nil(t, err)
		assert.EqualValues(t, 2, countFoos(t, dbc))

		fnNested := func() {
			assert.Nil(t, dbc.Begin())
			defer dbc.RollbackOnDefer()
			_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(3)").Exec()
			assert.Nil(t, err)
			assert.EqualValues(t, 3, countFoos(t, dbc))
			dbc.Commit()
		}
		fnNested()
		dbc.Commit()
	}
	fn()

	// All three rows must be visible to a fresh connection.
	dbc1 := getDBC()
	defer dbc1.Close()
	assert.EqualValues(t, 3, countFoos(t, dbc1))
}
// TestTransactionRollbackOnDefer checks that a deferred RollbackOnDefer
// discards the work of both nesting levels when neither reaches Commit
// (emulating an error path).
func TestTransactionRollbackOnDefer(t *testing.T) {
	createFooTable(t)
	dbc := getDBC()
	defer dbc.Close()

	fn := func() {
		assert.Nil(t, dbc.Begin())
		defer dbc.RollbackOnDefer()
		_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
		assert.Nil(t, err)
		assert.EqualValues(t, 2, countFoos(t, dbc))

		fnNested := func() {
			assert.Nil(t, dbc.Begin())
			defer dbc.RollbackOnDefer()
			_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(3)").Exec()
			assert.Nil(t, err)
			assert.EqualValues(t, 3, countFoos(t, dbc))
			// Commit is deliberately missing, emulating an error.
		}
		fnNested()
		// Commit is deliberately missing, emulating an error.
	}
	fn()

	// Nothing may have been persisted.
	dbc1 := getDBC()
	defer dbc1.Close()
	assert.EqualValues(t, 0, countFoos(t, dbc1))
}
// TestTransaction checks that Transaction keeps the callback's writes
// private until the callback returns nil, then makes them visible to other
// connections.
func TestTransaction(t *testing.T) {
	createFooTable(t)
	dbc := getDBC()
	defer dbc.Close()

	assert.Nil(t, dbc.Transaction(func(dbs *db.DBC) error {
		_, err := dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
		assert.Nil(t, err)

		// A fresh connection must not see the uncommitted rows.
		dbc1 := getDBC()
		defer dbc1.Close()
		assert.EqualValues(t, 0, countFoos(t, dbc1))
		assert.EqualValues(t, 2, countFoos(t, dbc))
		return nil
	}))

	// After a successful callback the rows are visible everywhere.
	dbc2 := getDBC()
	defer dbc2.Close()
	assert.EqualValues(t, 2, countFoos(t, dbc2))
}
// TestTransactionRollbackOnError checks that Transaction returns the
// callback's error unchanged and rolls back the callback's writes.
func TestTransactionRollbackOnError(t *testing.T) {
	createFooTable(t)
	dbc := getDBC()
	defer dbc.Close()

	err := dbc.Transaction(func(dbs *db.DBC) error {
		_, execErr := dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
		assert.Nil(t, execErr)
		assert.EqualValues(t, 2, countFoos(t, dbc))
		return errors.New("Opps")
	})
	// Guard the dereference: if Transaction lost the callback's error,
	// err.Error() on a nil error would panic instead of failing the test.
	if assert.NotNil(t, err) {
		assert.EqualValues(t, "Opps", err.Error())
	}

	// A fresh connection must see no rows.
	dbc2 := getDBC()
	defer dbc2.Close()
	assert.EqualValues(t, 0, countFoos(t, dbc2))
}
// TestTransactionRollbackOnPanic checks that when the callback panics,
// Transaction rolls back its writes and lets the panic value propagate to
// the caller.
func TestTransactionRollbackOnPanic(t *testing.T) {
	createFooTable(t)
	dbc := getDBC()
	defer dbc.Close()

	// PanicsWithValue fails the test if no panic escapes Transaction or if the
	// panic value differs. A bare recover() check would let the test pass
	// silently in the no-panic case.
	assert.PanicsWithValue(t, "Ooops", func() {
		dbc.Transaction(func(dbs *db.DBC) error {
			_, err := dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
			assert.Nil(t, err)
			assert.EqualValues(t, 2, countFoos(t, dbc))
			panic("Ooops")
		})
	})

	// The inserted rows must have been rolled back.
	dbc2 := getDBC()
	defer dbc2.Close()
	assert.EqualValues(t, 0, countFoos(t, dbc2))
}
// countFoos returns the number of rows in foo as seen through dbc. When dbc
// has an open transaction, that includes the transaction's own uncommitted
// rows. A query error fails the test; the returned count is then not
// meaningful.
func countFoos(t *testing.T, dbc *db.DBC) int {
	// Report assertion failures at the caller's line, not inside this helper.
	t.Helper()
	var res int
	err := dbc.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res)
	assert.Nil(t, err)
	return res
}
// TestTransactionRollback checks that Rollback discards rows inserted since
// Begin, both for the owning handle and for a fresh connection.
func TestTransactionRollback(t *testing.T) {
	createFooTable(t)

	dbc := getDBC()
	defer dbc.Close()

	// res is shared across every count, matching a single reused destination.
	var res int
	expectCount := func(c *db.DBC, want int) {
		err := c.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res)
		assert.Nil(t, err)
		assert.EqualValues(t, want, res)
	}

	assert.Nil(t, dbc.Begin())
	_, insertErr := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
	assert.Nil(t, insertErr)

	expectCount(dbc, 2)
	dbc.Rollback()
	expectCount(dbc, 0)

	// A fresh connection must not see the rolled-back rows either.
	fresh := getDBC()
	defer fresh.Close()
	expectCount(fresh, 0)
}
// TestTransactionNestedRollback checks that an inner Rollback is a no-op
// while an outer transaction is open, and that the outer Rollback discards
// rows from both levels.
func TestTransactionNestedRollback(t *testing.T) {
	createFooTable(t)

	dbc := getDBC()
	defer dbc.Close()

	insertOne := func() {
		_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1)").Exec()
		assert.Nil(t, err)
	}
	// res is shared across every count, matching a single reused destination.
	var res int
	expectCount := func(want int) {
		err := dbc.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res)
		assert.Nil(t, err)
		assert.EqualValues(t, want, res)
	}

	// Outer transaction (level 1).
	assert.Nil(t, dbc.Begin())
	insertOne()
	expectCount(1)

	// Inner transaction (level 2).
	assert.Nil(t, dbc.Begin())
	insertOne()

	// Rolling back level 2 does no real rollback while level 1 is open, so
	// both rows remain.
	dbc.Rollback()
	expectCount(2)

	// Rolling back level 1 discards everything.
	dbc.Rollback()
	expectCount(0)
}
//func makePlayers(t *testing.T, env *env.Env, amount int) []uint32 {
// var players []uint32
// mainDb := env.MainDb()
//
// for i := 0; i < amount; i++ {
// shardPlayer, err := dbshrd.CreateShardPlayer(env.Settings.DB_SHARDS, mainDb, 1)
// require.NoError(t, err)
// players = append(players, shardPlayer.Id)
// }
//
// //Note: There is only one test shard db, so the shard id = 1 for all players
// shardDb, err := dbshrd.GetShardDb(env.Logger, env.Settings.DB_SHARDS, 1)
// require.NoError(t, err)
// require.NotNil(t, shardDb)
// defer shardDb.Close()
//
// for i := 0; i < len(players); i++ {
// player := autogen.NewDataPlayer()
// player.Id = players[i]
// err = dbmeta.SaveRow(shardDb, player)
// require.NoError(t, err)
// }
//
// return players
//}
//func TestSelectBySQLWithChunkedIN(t *testing.T) {
// env := tests.NewEnvCleanStorage()
// defer env.Close()
//
// //Note: There is only one test shard db, so the shard id = 1 for all players
// shardDb, err := dbshrd.GetShardDb(env.Logger, env.Settings.DB_SHARDS, 1)
// require.NoError(t, err)
// require.NotNil(t, shardDb)
// defer shardDb.Close()
//
// var playersAmount = 100
// var playerIds []uint32
// playerIds = makePlayers(t, env, playersAmount)
//
// for chunkSizeForIN := 1; chunkSizeForIN <= 101; chunkSizeForIN++ {
// var totalIds []uint32
// fullChunksAmount := playersAmount / chunkSizeForIN
// modulo := playersAmount % chunkSizeForIN
// sql := "SELECT id FROM player WHERE 1 = ? AND id IN ? AND 2 = ?"
// builders := shardDb.SelectBySQLWithChunkedIN(sql, chunkSizeForIN, 1, playerIds, 2)
//
// require.Len(t, builders, fullChunksAmount+util.BoolToInt(modulo > 0))
//
// for queryIndex, builder := range builders {
// var queryIds []uint32
// _, err := builder.LoadValues(&queryIds)
// require.NoError(t, err)
// require.NotNil(t, queryIds)
// totalIds = append(totalIds, queryIds...)
//
// if queryIndex < fullChunksAmount {
// require.Len(t, queryIds, chunkSizeForIN)
// } else {
// require.Len(t, queryIds, modulo)
// }
// }
//
// require.Len(t, totalIds, playersAmount)
// }
//}