// db/db_test.go (552 lines, 12 KiB, Go)

package db_test
import (
"testing"
"git.bit5.ru/backend/db"
"git.bit5.ru/backend/errors"
"game/autogen"
"game/dbmeta"
"game/dbshrd"
"game/env"
"game/tests"
"game/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// getDBC opens a handle to the main test database configured in the shared
// test globals (tests.Globs().Settings.DB_MAIN). Callers own the returned
// handle and close it when done.
func getDBC() *db.DBC {
	globs := tests.Globs()
	return db.GetDBC(globs.Logger, globs.Settings.DB_MAIN)
}
// TestDefaultClientCharsetAndCollation checks that a connection opened with the
// default DB_MAIN settings uses utf8 for client, connection and results, and
// utf8_general_ci as the connection collation.
func TestDefaultClientCharsetAndCollation(t *testing.T) {
	dbc := getDBC()
	defer dbc.Close()

	result := make(map[string]string)

	// collect runs a SHOW VARIABLES query and stores every (Variable_name, Value)
	// row in result. The rows are always closed so the pooled connection is
	// released, and scan/iteration errors fail the test instead of being dropped.
	collect := func(query string) {
		rows, err := dbc.DB().Query(query)
		require.NoError(t, err)
		defer rows.Close()
		for rows.Next() {
			var variableName, value string
			require.NoError(t, rows.Scan(&variableName, &value))
			result[variableName] = value
		}
		require.NoError(t, rows.Err())
	}

	collect("show variables where Variable_name in ('character_set_client', 'character_set_connection', 'character_set_results');")
	collect("show variables where Variable_name = 'collation_connection';")

	assert.Equal(t, "utf8", result["character_set_client"])
	assert.Equal(t, "utf8", result["character_set_connection"])
	assert.Equal(t, "utf8", result["character_set_results"])
	assert.Equal(t, "utf8_general_ci", result["collation_connection"])
}
// TestClientCharsetAndCollation checks that passing collation=latin1_swedish_ci
// in the DSN params switches the client, connection and results character sets
// to latin1 and the connection collation to latin1_swedish_ci.
func TestClientCharsetAndCollation(t *testing.T) {
	g := tests.Globs()
	// NOTE(review): this relies on DB_MAIN being a value type, so the assignment
	// below only changes a local copy. If DB_MAIN is a pointer, it would mutate
	// the shared test settings for every later test. Confirm.
	dsnLatin1 := g.Settings.DB_MAIN
	dsnLatin1.Params = "?collation=latin1_swedish_ci"

	dbLatin1 := db.GetDBC(g.Logger, dsnLatin1)
	defer dbLatin1.Close()

	resultsLatin1 := make(map[string]string)

	// collect runs a SHOW VARIABLES query and stores every (Variable_name, Value)
	// row in resultsLatin1. The rows are always closed, and scan/iteration errors
	// fail the test.
	collect := func(query string) {
		rows, err := dbLatin1.DB().Query(query)
		require.NoError(t, err)
		defer rows.Close()
		for rows.Next() {
			var variableName, value string
			require.NoError(t, rows.Scan(&variableName, &value))
			resultsLatin1[variableName] = value
		}
		require.NoError(t, rows.Err())
	}

	collect("show variables where Variable_name in ('character_set_client', 'character_set_connection', 'character_set_results');")
	collect("show variables where Variable_name = 'collation_connection';")

	assert.Equal(t, "latin1", resultsLatin1["character_set_client"])
	assert.Equal(t, "latin1", resultsLatin1["character_set_connection"])
	assert.Equal(t, "latin1", resultsLatin1["character_set_results"])
	assert.Equal(t, "latin1_swedish_ci", resultsLatin1["collation_connection"])
}
// TestCloseConn checks that Close marks the handle as closed and that the next
// query transparently reopens the connection.
func TestCloseConn(t *testing.T) {
	dbc := getDBC()
	defer dbc.Close()

	var res int
	// selectOne runs "SELECT 1" on dbc, scanning the value into res.
	selectOne := func() error {
		return dbc.SelectBySQL("SELECT 1").LoadValue(&res)
	}

	err := selectOne()
	assert.Nil(t, err)
	assert.EqualValues(t, 1, res)
	assert.True(t, dbc.IsOpen())

	dbc.Close()
	assert.False(t, dbc.IsOpen())

	// The connection is restored automatically by the next query.
	err = selectOne()
	assert.True(t, dbc.IsOpen())
	assert.Nil(t, err)
	assert.EqualValues(t, 1, res)
}
// TestDoubleCloseConnIsOk checks that a handle can be closed and automatically
// reopened repeatedly: after each Close it reports closed, and the next query
// reopens it and succeeds.
func TestDoubleCloseConnIsOk(t *testing.T) {
	dbc := getDBC()
	defer dbc.Close()

	var res int
	// selectOne runs "SELECT 1" on dbc, scanning the value into res.
	selectOne := func() error {
		return dbc.SelectBySQL("SELECT 1").LoadValue(&res)
	}

	err := selectOne()
	assert.Nil(t, err)
	assert.EqualValues(t, 1, res)
	assert.True(t, dbc.IsOpen())

	// Two close/reconnect rounds in a row.
	for round := 0; round < 2; round++ {
		dbc.Close()
		assert.False(t, dbc.IsOpen())

		// The connection is restored automatically by the next query.
		err = selectOne()
		assert.True(t, dbc.IsOpen())
		assert.Nil(t, err)
		assert.EqualValues(t, 1, res)
	}
}
// createFooTable drops and recreates the `foo` table (a single non-null int
// column `id`) using its own short-lived connection. Every caller depends on
// the table existing, so setup failures abort the calling test immediately
// instead of letting it continue against a missing or stale table.
func createFooTable(t *testing.T) {
	t.Helper()
	dbc := getDBC()
	defer dbc.Close()
	_, err := dbc.DB().Exec("DROP TABLE IF EXISTS foo")
	require.NoError(t, err)
	_, err = dbc.DB().Exec("CREATE TABLE IF NOT EXISTS foo(id int not null)")
	require.NoError(t, err)
}
// TestTransactionCommit checks that rows inserted between Begin and Commit are
// visible to the owning connection but not to a fresh one until Commit, and
// are visible to a fresh connection afterwards.
func TestTransactionCommit(t *testing.T) {
	createFooTable(t)
	dbc := getDBC()
	defer dbc.Close()
	// Without a transaction the insert would autocommit and invalidate the
	// checks below, so stop right away if Begin fails.
	require.NoError(t, dbc.Begin())
	_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
	require.NoError(t, err)
	// A fresh connection must not see the uncommitted rows.
	dbc1 := getDBC()
	defer dbc1.Close()
	assert.EqualValues(t, 0, countFoos(t, dbc1))
	assert.EqualValues(t, 2, countFoos(t, dbc))
	dbc.Commit()
	assert.EqualValues(t, 2, countFoos(t, dbc))
	// After Commit a fresh connection sees the rows.
	dbc2 := getDBC()
	defer dbc2.Close()
	assert.EqualValues(t, 2, countFoos(t, dbc2))
}
// TestTransactionCommitNestedAllOk checks nested Begin/Commit. Committing the
// inner level leaves the rows invisible to other connections. Only the outer
// Commit makes all inserted rows visible to a fresh connection.
func TestTransactionCommitNestedAllOk(t *testing.T) {
	createFooTable(t)
	dbc := getDBC()
	defer dbc.Close()
	// begin 1 (outer)
	require.NoError(t, dbc.Begin())
	_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
	require.NoError(t, err)
	// A fresh connection must not see the uncommitted rows.
	dbc1 := getDBC()
	defer dbc1.Close()
	assert.EqualValues(t, 0, countFoos(t, dbc1))
	// begin 2 (nested)
	require.NoError(t, dbc.Begin())
	_, err = dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(3)").Exec()
	require.NoError(t, err)
	// commit 2: only closes the nested level; other connections still see nothing.
	dbc.Commit()
	assert.EqualValues(t, 3, countFoos(t, dbc))
	dbc2 := getDBC()
	defer dbc2.Close()
	assert.EqualValues(t, 0, countFoos(t, dbc2))
	// commit 1: the outer commit makes all three rows visible.
	dbc.Commit()
	assert.EqualValues(t, 3, countFoos(t, dbc))
	dbc3 := getDBC()
	defer dbc3.Close()
	assert.EqualValues(t, 3, countFoos(t, dbc3))
}
// TestTransactionCommitNestedRollback checks nested Begin/Rollback. Rolling back
// the inner level has no visible effect while the outer transaction is open.
// Rolling back the outer level discards every row inserted by either level.
func TestTransactionCommitNestedRollback(t *testing.T) {
	createFooTable(t)
	dbc := getDBC()
	defer dbc.Close()
	// begin 1 (outer)
	require.NoError(t, dbc.Begin())
	_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
	require.NoError(t, err)
	// A fresh connection must not see the uncommitted rows.
	dbc1 := getDBC()
	defer dbc1.Close()
	assert.EqualValues(t, 0, countFoos(t, dbc1))
	// begin 2 (nested)
	require.NoError(t, dbc.Begin())
	_, err = dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(3)").Exec()
	require.NoError(t, err)
	// rollback 2: this has no effect because we are still inside the
	// outer transaction.
	dbc.Rollback()
	assert.EqualValues(t, 3, countFoos(t, dbc))
	dbc2 := getDBC()
	defer dbc2.Close()
	assert.EqualValues(t, 0, countFoos(t, dbc2))
	// rollback 1: discards all three rows.
	dbc.Rollback()
	assert.EqualValues(t, 0, countFoos(t, dbc))
	dbc3 := getDBC()
	defer dbc3.Close()
	assert.EqualValues(t, 0, countFoos(t, dbc3))
}
// TestTransactionRollbackOnDeferAllOK checks the Begin + defer RollbackOnDefer
// pattern. When every level reaches its Commit, the deferred RollbackOnDefer
// calls leave the data in place, and a fresh connection sees all rows.
func TestTransactionRollbackOnDeferAllOK(t *testing.T) {
	createFooTable(t)
	dbc := getDBC()
	defer dbc.Close()
	fn := func() {
		// Begin must succeed before RollbackOnDefer is registered; otherwise the
		// inserts would run outside a transaction.
		require.NoError(t, dbc.Begin())
		defer dbc.RollbackOnDefer()
		_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
		require.NoError(t, err)
		assert.EqualValues(t, 2, countFoos(t, dbc))
		fnNested := func() {
			require.NoError(t, dbc.Begin())
			defer dbc.RollbackOnDefer()
			_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(3)").Exec()
			require.NoError(t, err)
			assert.EqualValues(t, 3, countFoos(t, dbc))
			dbc.Commit()
		}
		fnNested()
		dbc.Commit()
	}
	fn()
	dbc1 := getDBC()
	defer dbc1.Close()
	assert.EqualValues(t, 3, countFoos(t, dbc1))
}
// TestTransactionRollbackOnDefer checks the Begin + defer RollbackOnDefer
// pattern when no level reaches its Commit (emulating an early error exit).
// Afterwards a fresh connection sees none of the inserted rows.
func TestTransactionRollbackOnDefer(t *testing.T) {
	createFooTable(t)
	dbc := getDBC()
	defer dbc.Close()
	fn := func() {
		// Begin must succeed before RollbackOnDefer is registered; otherwise the
		// inserts would run outside a transaction.
		require.NoError(t, dbc.Begin())
		defer dbc.RollbackOnDefer()
		_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
		require.NoError(t, err)
		assert.EqualValues(t, 2, countFoos(t, dbc))
		fnNested := func() {
			require.NoError(t, dbc.Begin())
			defer dbc.RollbackOnDefer()
			_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(3)").Exec()
			require.NoError(t, err)
			assert.EqualValues(t, 3, countFoos(t, dbc))
			//commit is missing for some reason, emulating error
			//dbc.Commit()
		}
		fnNested()
		//commit is missing for some reason, emulating error
		//dbc.Commit()
	}
	fn()
	dbc1 := getDBC()
	defer dbc1.Close()
	assert.EqualValues(t, 0, countFoos(t, dbc1))
}
// TestTransaction checks that rows inserted inside a successful Transaction
// callback are hidden from a fresh connection during the callback and visible
// to one after Transaction returns nil.
func TestTransaction(t *testing.T) {
	createFooTable(t)
	dbc := getDBC()
	defer dbc.Close()
	assert.NoError(t, dbc.Transaction(func(dbs *db.DBC) error {
		// Propagate insert failures through the transaction instead of ignoring them.
		if _, err := dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec(); err != nil {
			return err
		}
		// A fresh connection must not see the uncommitted rows.
		dbc1 := getDBC()
		defer dbc1.Close()
		assert.EqualValues(t, 0, countFoos(t, dbc1))
		assert.EqualValues(t, 2, countFoos(t, dbc))
		return nil
	}))
	// After Transaction returns, a fresh connection sees the rows.
	dbc2 := getDBC()
	defer dbc2.Close()
	assert.EqualValues(t, 2, countFoos(t, dbc2))
}
// TestTransactionRollbackOnError checks that when the Transaction callback
// returns an error, Transaction returns that error and the inserted rows are
// not visible to a fresh connection.
func TestTransactionRollbackOnError(t *testing.T) {
	createFooTable(t)
	dbc := getDBC()
	defer dbc.Close()
	err := dbc.Transaction(func(dbs *db.DBC) error {
		_, execErr := dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
		assert.NoError(t, execErr)
		assert.EqualValues(t, 2, countFoos(t, dbc))
		return errors.New("Opps")
	})
	// Guard against a nil error so the message check cannot panic.
	require.Error(t, err)
	assert.EqualError(t, err, "Opps")
	// A fresh connection must not see the rolled-back rows.
	dbc2 := getDBC()
	defer dbc2.Close()
	assert.EqualValues(t, 0, countFoos(t, dbc2))
}
// TestTransactionRollbackOnPanic checks that a panic inside a Transaction
// callback does not leave the inserted rows committed. If the panic propagates
// out of Transaction, its value must be the original "Ooops".
func TestTransactionRollbackOnPanic(t *testing.T) {
	createFooTable(t)
	dbc := getDBC()
	defer dbc.Close()

	// Run the panicking transaction in its own function and recover there. That
	// way the rollback check below always runs, whether or not Transaction
	// re-panics, and a non-string panic value cannot crash the deferred handler.
	func() {
		defer func() {
			if r := recover(); r != nil {
				assert.Equal(t, "Ooops", r)
			}
		}()
		// The returned error is irrelevant here: the rollback is verified below.
		_ = dbc.Transaction(func(dbs *db.DBC) error {
			_, err := dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
			assert.NoError(t, err)
			assert.EqualValues(t, 2, countFoos(t, dbc))
			panic("Ooops")
		})
	}()

	// A fresh connection must not see the rolled-back rows.
	dbc2 := getDBC()
	defer dbc2.Close()
	assert.EqualValues(t, 0, countFoos(t, dbc2))
}
// countFoos returns SELECT COUNT(id) FROM foo as seen through dbc, including
// any rows visible in dbc's open transaction. A query error is reported on t
// (attributed to the caller's line via t.Helper), and the function returns
// whatever was scanned, which is 0 on failure.
//
// It uses assert rather than require because it is also called from inside
// Transaction callbacks, where FailNow would unwind through the transaction
// wrapper.
func countFoos(t *testing.T, dbc *db.DBC) int {
	t.Helper()
	var res int
	err := dbc.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res)
	assert.NoError(t, err)
	return res
}
// TestTransactionRollback checks that Rollback discards rows inserted after
// Begin. They disappear both for the owning connection and for a fresh one.
func TestTransactionRollback(t *testing.T) {
	createFooTable(t)
	dbc := getDBC()
	defer dbc.Close()
	assert.Nil(t, dbc.Begin())

	_, insertErr := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
	assert.Nil(t, insertErr)

	var res int
	// expectCount loads COUNT(id) from foo through conn into res and checks it
	// against want.
	expectCount := func(conn *db.DBC, want int) {
		t.Helper()
		err := conn.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res)
		assert.Nil(t, err)
		assert.EqualValues(t, want, res)
	}

	expectCount(dbc, 2)
	dbc.Rollback()
	expectCount(dbc, 0)

	// A fresh connection must not see the rolled-back rows either.
	dbc2 := getDBC()
	defer dbc2.Close()
	expectCount(dbc2, 0)
}
// TestTransactionNestedRollback checks that rolling back a nested transaction
// keeps its rows while the outer transaction is open. Rolling back the outer
// transaction then discards everything.
func TestTransactionNestedRollback(t *testing.T) {
	createFooTable(t)
	dbc := getDBC()
	defer dbc.Close()

	var res int
	// insertOne adds a single row with id 1 to foo.
	insertOne := func() {
		t.Helper()
		_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1)").Exec()
		assert.Nil(t, err)
	}
	// expectCount loads COUNT(id) from foo into res and checks it against want.
	expectCount := func(want int) {
		t.Helper()
		err := dbc.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res)
		assert.Nil(t, err)
		assert.EqualValues(t, want, res)
	}

	// begin 1 (outer)
	assert.Nil(t, dbc.Begin())
	insertOne()
	expectCount(1)

	// begin 2 (nested)
	assert.Nil(t, dbc.Begin())
	insertOne()

	// rollback 2: nothing is really rolled back yet, since we are still inside
	// the 'bigger' outer transaction.
	dbc.Rollback()
	expectCount(2)

	// rollback 1: discards both inserts.
	dbc.Rollback()
	expectCount(0)
}
// makePlayers creates `amount` players with dbshrd.CreateShardPlayer and then
// saves an autogen.NewDataPlayer() row, with Id set, for each of them in shard
// DB 1. It returns the created player ids in creation order. Any failure aborts
// the calling test.
//
// The parameter is named testEnv so it does not shadow the imported env package.
func makePlayers(t *testing.T, testEnv *env.Env, amount int) []uint32 {
	t.Helper()
	var players []uint32
	mainDb := testEnv.MainDb()
	for i := 0; i < amount; i++ {
		// NOTE(review): the trailing 1 presumably selects shard 1, matching the
		// note below. Confirm against dbshrd.CreateShardPlayer.
		shardPlayer, err := dbshrd.CreateShardPlayer(testEnv.Settings.DB_SHARDS, mainDb, 1)
		require.NoError(t, err)
		players = append(players, shardPlayer.Id)
	}
	//Note: There is only one test shard db, so the shard id = 1 for all players
	shardDb, err := dbshrd.GetShardDb(testEnv.Logger, testEnv.Settings.DB_SHARDS, 1)
	require.NoError(t, err)
	require.NotNil(t, shardDb)
	defer shardDb.Close()
	for _, id := range players {
		player := autogen.NewDataPlayer()
		player.Id = id
		require.NoError(t, dbmeta.SaveRow(shardDb, player))
	}
	return players
}
// TestSelectBySQLWithChunkedIN checks how SelectBySQLWithChunkedIN splits the
// IN list for every chunk size from 1 up to one more than the number of ids.
// The last size covers the single, partial chunk case. For each chunk size it
// verifies that:
//   - the number of builders is ceil(len(ids) / chunk);
//   - each full chunk returns exactly `chunk` rows and the trailing partial
//     chunk returns the remainder;
//   - together the builders return exactly the created player ids.
func TestSelectBySQLWithChunkedIN(t *testing.T) {
	testEnv := tests.NewEnvCleanStorage()
	defer testEnv.Close()
	//Note: There is only one test shard db, so the shard id = 1 for all players
	shardDb, err := dbshrd.GetShardDb(testEnv.Logger, testEnv.Settings.DB_SHARDS, 1)
	require.NoError(t, err)
	require.NotNil(t, shardDb)
	defer shardDb.Close()

	playersAmount := 100
	playerIds := makePlayers(t, testEnv, playersAmount)

	// The upper bound is derived from playersAmount so the "chunk larger than
	// the list" case stays covered if the amount changes.
	for chunkSizeForIN := 1; chunkSizeForIN <= playersAmount+1; chunkSizeForIN++ {
		var totalIds []uint32
		fullChunksAmount := playersAmount / chunkSizeForIN
		modulo := playersAmount % chunkSizeForIN
		query := "SELECT id FROM player WHERE 1 = ? AND id IN ? AND 2 = ?"
		builders := shardDb.SelectBySQLWithChunkedIN(query, chunkSizeForIN, 1, playerIds, 2)
		require.Len(t, builders, fullChunksAmount+util.BoolToInt(modulo > 0))
		for queryIndex, builder := range builders {
			var queryIds []uint32
			_, err := builder.LoadValues(&queryIds)
			require.NoError(t, err)
			require.NotNil(t, queryIds)
			totalIds = append(totalIds, queryIds...)
			if queryIndex < fullChunksAmount {
				require.Len(t, queryIds, chunkSizeForIN)
			} else {
				require.Len(t, queryIds, modulo)
			}
		}
		require.Len(t, totalIds, playersAmount)
		// Counts alone would not catch a chunk boundary that repeats or skips
		// ids, so compare the actual ids as well.
		require.ElementsMatch(t, playerIds, totalIds)
	}
}