db/db_test.go

442 lines
9.6 KiB
Go

package db_test
import (
"log"
"os"
"testing"
"git.bit5.ru/backend/db"
"git.bit5.ru/backend/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/go-logr/stdr"
)
// verbosity is the level handed to stdr.SetVerbosity by getSettings.
// Bump it (2 was used previously) to get more detailed library logging.
const verbosity = 0

// logger is passed to every DBC created by the tests; it writes to stdout
// with each line tagged by the short file name and line number of the caller.
var logger = stdr.New(log.New(os.Stdout, "", log.Lshortfile))
// getSettings returns the connection settings for the test database and, as
// a side effect, applies the package-level verbosity to stdr.
//
// Every connection field defaults to the local test server
// (user root, password test, 127.0.0.1:3306, database tests) and can be
// overridden with the environment variables DB_TEST_HOST, DB_TEST_PORT,
// DB_TEST_USER, DB_TEST_PASS and DB_TEST_NAME. A variable that is set to the
// empty string is honoured, so e.g. an empty password can be configured.
func getSettings() db.Settings {
	stdr.SetVerbosity(verbosity)
	return db.Settings{
		Host:   settingFromEnv("DB_TEST_HOST", "127.0.0.1"),
		Port:   settingFromEnv("DB_TEST_PORT", "3306"),
		User:   settingFromEnv("DB_TEST_USER", "root"),
		Pass:   settingFromEnv("DB_TEST_PASS", "test"),
		Name:   settingFromEnv("DB_TEST_NAME", "tests"),
		Prefix: "tests",
	}
}

// settingFromEnv returns the value of the environment variable key, or def
// when the variable is not set at all.
func settingFromEnv(key, def string) string {
	if v, ok := os.LookupEnv(key); ok {
		return v
	}
	return def
}
// getPool opens a new pool configured from getSettings. Callers own the pool
// and are responsible for closing it.
func getPool() *db.Pool {
	settings := getSettings()
	pool := db.OpenPool(settings)
	return pool
}
// getDBC obtains a DBC from pool p that logs through the shared test logger.
func getDBC(p *db.Pool) *db.DBC {
	return db.GetDBC(p, logger)
}
// TestDefaultClientCharsetAndCollation checks that, with the default
// settings, the client/connection/results character sets are utf8 and the
// connection collation is utf8_general_ci.
func TestDefaultClientCharsetAndCollation(t *testing.T) {
	p := getPool()
	defer p.Close()
	dbc := getDBC(p)

	result := make(map[string]string)
	// loadVariables runs a SHOW VARIABLES query and stores each
	// (Variable_name, Value) row in result. Rows are always closed, and both
	// Scan and iteration errors fail the test instead of being dropped.
	loadVariables := func(query string) {
		rows, err := dbc.DB().Query(query)
		require.Nil(t, err)
		defer rows.Close()
		for rows.Next() {
			var variableName, value string
			require.Nil(t, rows.Scan(&variableName, &value))
			result[variableName] = value
		}
		require.Nil(t, rows.Err())
	}
	loadVariables("SHOW VARIABLES WHERE Variable_name in ('character_set_client', 'character_set_connection', 'character_set_results');")
	loadVariables("show variables where Variable_name = 'collation_connection';")

	assert.Equal(t, "utf8", result["character_set_client"])
	assert.Equal(t, "utf8", result["character_set_connection"])
	assert.Equal(t, "utf8", result["character_set_results"])
	assert.Equal(t, "utf8_general_ci", result["collation_connection"])
}
// TestClientCharsetAndCollation checks that a collation passed through
// Settings.Params ("?collation=latin1_swedish_ci") switches the client,
// connection and results character sets to latin1 and the connection
// collation to latin1_swedish_ci.
func TestClientCharsetAndCollation(t *testing.T) {
	settings := getSettings()
	settings.Params = "?collation=latin1_swedish_ci"
	pool := db.OpenPool(settings)
	defer pool.Close()
	dbLatin1 := db.GetDBC(pool, logger)

	resultsLatin1 := make(map[string]string)
	// loadVariables runs a SHOW VARIABLES query and stores each
	// (Variable_name, Value) row in resultsLatin1. Rows are always closed, and
	// both Scan and iteration errors fail the test instead of being dropped.
	loadVariables := func(query string) {
		rows, err := dbLatin1.DB().Query(query)
		require.Nil(t, err)
		defer rows.Close()
		for rows.Next() {
			var variableName, value string
			require.Nil(t, rows.Scan(&variableName, &value))
			resultsLatin1[variableName] = value
		}
		require.Nil(t, rows.Err())
	}
	loadVariables("show variables where Variable_name in ('character_set_client', 'character_set_connection', 'character_set_results');")
	loadVariables("show variables where Variable_name = 'collation_connection';")

	assert.Equal(t, "latin1", resultsLatin1["character_set_client"])
	assert.Equal(t, "latin1", resultsLatin1["character_set_connection"])
	assert.Equal(t, "latin1", resultsLatin1["character_set_results"])
	assert.Equal(t, "latin1_swedish_ci", resultsLatin1["collation_connection"])
}
// createFooTable (re)creates an empty `foo` table with a single non-null int
// column `id`, using a short-lived pool that is closed before returning.
func createFooTable(t *testing.T) {
	pool := getPool()
	conn := getDBC(pool)
	defer pool.Close()

	statements := []string{
		"DROP TABLE IF EXISTS foo",
		"CREATE TABLE IF NOT EXISTS foo(id int not null)",
	}
	for _, stmt := range statements {
		_, err := conn.DB().Exec(stmt)
		require.Nil(t, err)
	}
}
// TestTransactionCommit checks that rows inserted inside a transaction are
// visible to the transaction's own DBC but not to a fresh DBC until Commit,
// and visible to everyone afterwards.
func TestTransactionCommit(t *testing.T) {
	createFooTable(t)
	p := getPool()
	dbc := getDBC(p)
	defer p.Close()

	require.Nil(t, dbc.Begin())
	// Check the insert itself so a failure here is reported directly rather
	// than as a confusing count mismatch below.
	_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
	require.Nil(t, err)

	// A fresh connection must not see the uncommitted rows...
	dbc1 := getDBC(p)
	assert.EqualValues(t, 0, countFoos(t, dbc1))
	// ...while the transaction itself does.
	assert.EqualValues(t, 2, countFoos(t, dbc))

	dbc.Commit()
	assert.EqualValues(t, 2, countFoos(t, dbc))

	// After the commit a fresh connection sees the rows too.
	dbc2 := getDBC(p)
	assert.EqualValues(t, 2, countFoos(t, dbc2))
}
// TestTransactionCommitNestedAllOk checks nested Begin/Commit: committing the
// inner level leaves the rows invisible to other connections, and only the
// outermost Commit makes them visible.
func TestTransactionCommitNestedAllOk(t *testing.T) {
	createFooTable(t)
	p := getPool()
	dbc := getDBC(p)
	defer p.Close()

	// begin 1
	require.Nil(t, dbc.Begin())
	_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
	require.Nil(t, err)
	// let's try a fresh connection
	dbc1 := getDBC(p)
	assert.EqualValues(t, 0, countFoos(t, dbc1))

	// begin 2
	require.Nil(t, dbc.Begin())
	_, err = dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(3)").Exec()
	require.Nil(t, err)

	// commit 2: still inside the outer transaction
	dbc.Commit()
	assert.EqualValues(t, 3, countFoos(t, dbc))
	// let's try a fresh connection
	dbc2 := getDBC(p)
	assert.EqualValues(t, 0, countFoos(t, dbc2))

	// commit 1
	dbc.Commit()
	assert.EqualValues(t, 3, countFoos(t, dbc))
	// let's try a fresh connection
	dbc3 := getDBC(p)
	assert.EqualValues(t, 3, countFoos(t, dbc3))
}
// TestTransactionCommitNestedRollback checks nested Begin/Rollback: rolling
// back the inner level has no visible effect, and only the outermost Rollback
// discards all rows.
func TestTransactionCommitNestedRollback(t *testing.T) {
	createFooTable(t)
	p := getPool()
	dbc := getDBC(p)
	defer p.Close()

	// begin 1
	require.Nil(t, dbc.Begin())
	_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
	require.Nil(t, err)
	// let's try a fresh connection
	dbc1 := getDBC(p)
	assert.EqualValues(t, 0, countFoos(t, dbc1))

	// begin 2
	require.Nil(t, dbc.Begin())
	_, err = dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(3)").Exec()
	require.Nil(t, err)

	// rollback 2: doesn't have an effect since we are in the nested transaction
	dbc.Rollback()
	assert.EqualValues(t, 3, countFoos(t, dbc))
	// let's try a fresh connection
	dbc2 := getDBC(p)
	assert.EqualValues(t, 0, countFoos(t, dbc2))

	// rollback 1
	dbc.Rollback()
	assert.EqualValues(t, 0, countFoos(t, dbc))
	// let's try a fresh connection
	dbc3 := getDBC(p)
	assert.EqualValues(t, 0, countFoos(t, dbc3))
}
// TestTransactionRollbackOnDeferAllOK checks that a deferred RollbackOnDefer
// does not undo work when every nesting level explicitly commits.
func TestTransactionRollbackOnDeferAllOK(t *testing.T) {
	createFooTable(t)
	p := getPool()
	dbc := getDBC(p)
	defer p.Close()

	fn := func() {
		// Begin's error was previously ignored; fail fast if it can't start.
		require.Nil(t, dbc.Begin())
		defer dbc.RollbackOnDefer()
		_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
		require.Nil(t, err)
		assert.EqualValues(t, 2, countFoos(t, dbc))

		fnNested := func() {
			require.Nil(t, dbc.Begin())
			defer dbc.RollbackOnDefer()
			_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(3)").Exec()
			require.Nil(t, err)
			assert.EqualValues(t, 3, countFoos(t, dbc))
			dbc.Commit()
		}
		fnNested()
		dbc.Commit()
	}
	fn()

	// A fresh connection sees all committed rows.
	dbc1 := getDBC(p)
	assert.EqualValues(t, 3, countFoos(t, dbc1))
}
// TestTransactionRollbackOnDefer checks that a deferred RollbackOnDefer rolls
// the whole transaction back when the commits are missing (emulating an
// error path that returns early).
func TestTransactionRollbackOnDefer(t *testing.T) {
	createFooTable(t)
	p := getPool()
	dbc := getDBC(p)
	defer p.Close()

	fn := func() {
		// Begin's error was previously ignored; fail fast if it can't start.
		require.Nil(t, dbc.Begin())
		defer dbc.RollbackOnDefer()
		_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
		require.Nil(t, err)
		assert.EqualValues(t, 2, countFoos(t, dbc))

		fnNested := func() {
			require.Nil(t, dbc.Begin())
			defer dbc.RollbackOnDefer()
			_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(3)").Exec()
			require.Nil(t, err)
			assert.EqualValues(t, 3, countFoos(t, dbc))
			// commit is missing for some reason, emulating error
			//dbc.Commit()
		}
		fnNested()
		// commit is missing for some reason, emulating error
		//dbc.Commit()
	}
	fn()

	// Nothing was committed, so a fresh connection sees no rows.
	dbc1 := getDBC(p)
	assert.EqualValues(t, 0, countFoos(t, dbc1))
}
// TestTransaction checks that a callback passed to Transaction that returns
// nil has its writes committed: invisible to a fresh connection while the
// callback runs, visible afterwards.
func TestTransaction(t *testing.T) {
	createFooTable(t)
	p := getPool()
	dbc := getDBC(p)
	defer p.Close()

	require.Nil(t, dbc.Transaction(func(dbs *db.DBC) error {
		_, err := dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
		require.Nil(t, err)
		// let's try a fresh connection
		dbc1 := getDBC(p)
		assert.EqualValues(t, 0, countFoos(t, dbc1))
		assert.EqualValues(t, 2, countFoos(t, dbc))
		return nil
	}))

	// let's try a fresh connection
	dbc2 := getDBC(p)
	assert.EqualValues(t, 2, countFoos(t, dbc2))
}
// TestTransactionRollbackOnError checks that when the Transaction callback
// returns an error, that error is returned to the caller and the writes are
// rolled back.
func TestTransactionRollbackOnError(t *testing.T) {
	createFooTable(t)
	p := getPool()
	dbc := getDBC(p)
	defer p.Close()

	err := dbc.Transaction(func(dbs *db.DBC) error {
		_, err := dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
		require.Nil(t, err)
		assert.EqualValues(t, 2, countFoos(t, dbc))
		return errors.New("Opps")
	})
	// EqualError fails cleanly on a nil error, where err.Error() would panic.
	assert.EqualError(t, err, "Opps")

	// let's try a fresh connection
	dbc2 := getDBC(p)
	assert.EqualValues(t, 0, countFoos(t, dbc2))
}
// TestTransactionRollbackOnPanic checks that a panic inside the Transaction
// callback propagates to the caller with its original value and that the
// writes made before the panic are rolled back.
func TestTransactionRollbackOnPanic(t *testing.T) {
	createFooTable(t)
	p := getPool()
	dbc := getDBC(p)
	defer p.Close()

	// Previously a recover() in a defer was used: if the panic did not
	// propagate, nothing at all was asserted. PanicsWithValue makes the
	// expectation explicit and avoids the unchecked r.(string) assertion.
	assert.PanicsWithValue(t, "Ooops", func() {
		_ = dbc.Transaction(func(dbs *db.DBC) error {
			_, err := dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
			require.Nil(t, err)
			assert.EqualValues(t, 2, countFoos(t, dbc))
			panic("Ooops")
		})
	})

	// Verified unconditionally so the rollback is checked even if the panic
	// was not propagated.
	dbc2 := getDBC(p)
	assert.EqualValues(t, 0, countFoos(t, dbc2))
}
// countFoos returns the number of rows in `foo` as seen through dbc (and so
// through any transaction dbc currently has open), failing the test if the
// query errors.
func countFoos(t *testing.T, dbc *db.DBC) (count int) {
	query := dbc.SelectBySQL("SELECT COUNT(id) FROM foo")
	require.Nil(t, query.LoadValue(&count))
	return count
}
// TestTransactionRollback checks that Rollback discards the rows inserted in
// the transaction, both for the transaction's own DBC and for a fresh one.
func TestTransactionRollback(t *testing.T) {
	createFooTable(t)
	pool := getPool()
	conn := getDBC(pool)
	defer pool.Close()

	require.Nil(t, conn.Begin())
	_, err := conn.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
	require.Nil(t, err)
	assert.EqualValues(t, 2, countFoos(t, conn))

	conn.Rollback()
	assert.EqualValues(t, 0, countFoos(t, conn))

	// let's try a fresh connection
	fresh := getDBC(pool)
	assert.EqualValues(t, 0, countFoos(t, fresh))
}
// TestTransactionNestedRollback checks that rolling back an inner, nested
// transaction leaves the rows in place, while rolling back the outermost one
// discards everything.
func TestTransactionNestedRollback(t *testing.T) {
	createFooTable(t)
	pool := getPool()
	conn := getDBC(pool)
	defer pool.Close()

	// begin 1
	require.Nil(t, conn.Begin())
	_, err := conn.UpdateBySQL("INSERT INTO foo(id) VALUES(1)").Exec()
	require.Nil(t, err)
	assert.EqualValues(t, 1, countFoos(t, conn))

	// begin 2
	require.Nil(t, conn.Begin())
	_, err = conn.UpdateBySQL("INSERT INTO foo(id) VALUES(1)").Exec()
	require.Nil(t, err)

	// rollback 2: no real rollback happens here since we are in a 'bigger'
	// transaction
	conn.Rollback()
	assert.EqualValues(t, 2, countFoos(t, conn))

	// rollback 1
	conn.Rollback()
	assert.EqualValues(t, 0, countFoos(t, conn))
}