442 lines
9.6 KiB
Go
442 lines
9.6 KiB
Go
package db_test
|
|
|
|
import (
|
|
"log"
|
|
"os"
|
|
"testing"
|
|
|
|
"git.bit5.ru/backend/db"
|
|
"git.bit5.ru/backend/errors"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/go-logr/stdr"
|
|
)
|
|
|
|
const verbosity = 0//2
|
|
|
|
var logger = stdr.New(log.New(os.Stdout, "", log.Lshortfile))
|
|
|
|
func getSettings() db.Settings {
|
|
//TODO: use ENV settings as well
|
|
stdr.SetVerbosity(verbosity)
|
|
return db.Settings{Host: "127.0.0.1", Port: "3306", User: "root", Pass: "test", Name: "tests", Prefix: "tests"}
|
|
}
|
|
|
|
func getPool() *db.Pool {
|
|
return db.OpenPool(getSettings())
|
|
}
|
|
|
|
func getDBC(p *db.Pool) *db.DBC {
|
|
dbc := db.GetDBC(p, logger)
|
|
return dbc
|
|
}
|
|
|
|
func TestDefaultClientCharsetAndCollation(t *testing.T) {
|
|
p := getPool()
|
|
dbc := getDBC(p)
|
|
|
|
var result = make(map[string]string)
|
|
|
|
characterSets, err := dbc.DB().Query("SHOW VARIABLES WHERE Variable_name in ('character_set_client', 'character_set_connection', 'character_set_results');")
|
|
require.Nil(t, err)
|
|
|
|
for characterSets.Next() {
|
|
|
|
var variableName string
|
|
var value string
|
|
|
|
characterSets.Scan(&variableName, &value)
|
|
|
|
result[variableName] = value
|
|
}
|
|
|
|
collations, err := dbc.DB().Query("show variables where Variable_name = 'collation_connection';")
|
|
require.Nil(t, err)
|
|
|
|
for collations.Next() {
|
|
|
|
var variableName string
|
|
var value string
|
|
|
|
collations.Scan(&variableName, &value)
|
|
|
|
result[variableName] = value
|
|
}
|
|
|
|
assert.Equal(t, "utf8", result["character_set_client"])
|
|
assert.Equal(t, "utf8", result["character_set_connection"])
|
|
assert.Equal(t, "utf8", result["character_set_results"])
|
|
assert.Equal(t, "utf8_general_ci", result["collation_connection"])
|
|
}
|
|
|
|
func TestClientCharsetAndCollation(t *testing.T) {
|
|
DSNWithLatinCollation := getSettings()
|
|
DSNWithLatinCollation.Params = "?collation=latin1_swedish_ci"
|
|
pool := db.OpenPool(DSNWithLatinCollation)
|
|
defer pool.Close()
|
|
|
|
var resultsLatin1 = make(map[string]string)
|
|
|
|
dbLatin1 := db.GetDBC(pool, logger)
|
|
|
|
characterSets, err := dbLatin1.DB().Query("show variables where Variable_name in ('character_set_client', 'character_set_connection', 'character_set_results');")
|
|
require.Nil(t, err)
|
|
|
|
for characterSets.Next() {
|
|
var variableName string
|
|
var value string
|
|
|
|
characterSets.Scan(&variableName, &value)
|
|
|
|
resultsLatin1[variableName] = value
|
|
}
|
|
|
|
collations, err := dbLatin1.DB().Query("show variables where Variable_name = 'collation_connection';")
|
|
require.Nil(t, err)
|
|
|
|
for collations.Next() {
|
|
|
|
var variableName string
|
|
var value string
|
|
|
|
collations.Scan(&variableName, &value)
|
|
|
|
resultsLatin1[variableName] = value
|
|
}
|
|
|
|
assert.Equal(t, "latin1", resultsLatin1["character_set_client"])
|
|
assert.Equal(t, "latin1", resultsLatin1["character_set_connection"])
|
|
assert.Equal(t, "latin1", resultsLatin1["character_set_results"])
|
|
assert.Equal(t, "latin1_swedish_ci", resultsLatin1["collation_connection"])
|
|
|
|
}
|
|
|
|
func createFooTable(t *testing.T) {
|
|
p := getPool()
|
|
dbc := getDBC(p)
|
|
defer p.Close()
|
|
|
|
_, err := dbc.DB().Exec("DROP TABLE IF EXISTS foo")
|
|
require.Nil(t, err)
|
|
_, err = dbc.DB().Exec("CREATE TABLE IF NOT EXISTS foo(id int not null)")
|
|
require.Nil(t, err)
|
|
}
|
|
|
|
func TestTransactionCommit(t *testing.T) {
|
|
createFooTable(t)
|
|
|
|
p := getPool()
|
|
dbc := getDBC(p)
|
|
defer p.Close()
|
|
|
|
require.Nil(t, dbc.Begin())
|
|
dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
|
|
|
|
//let's try a fresh connection
|
|
dbc1 := getDBC(p)
|
|
assert.EqualValues(t, 0, countFoos(t, dbc1))
|
|
|
|
assert.EqualValues(t, 2, countFoos(t, dbc))
|
|
dbc.Commit()
|
|
assert.EqualValues(t, 2, countFoos(t, dbc))
|
|
|
|
//let's try a fresh connection
|
|
dbc2 := getDBC(p)
|
|
assert.EqualValues(t, 2, countFoos(t, dbc2))
|
|
}
|
|
|
|
func TestTransactionCommitNestedAllOk(t *testing.T) {
|
|
createFooTable(t)
|
|
|
|
p := getPool()
|
|
dbc := getDBC(p)
|
|
defer p.Close()
|
|
//begin 1
|
|
require.Nil(t, dbc.Begin())
|
|
dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
|
|
|
|
//let's try a fresh connection
|
|
dbc1 := getDBC(p)
|
|
assert.EqualValues(t, 0, countFoos(t, dbc1))
|
|
|
|
//begin 2
|
|
require.Nil(t, dbc.Begin())
|
|
dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(3)").Exec()
|
|
//commit 2
|
|
dbc.Commit()
|
|
assert.EqualValues(t, 3, countFoos(t, dbc))
|
|
|
|
//let's try a fresh connection
|
|
dbc2 := getDBC(p)
|
|
assert.EqualValues(t, 0, countFoos(t, dbc2))
|
|
|
|
//commit 1
|
|
dbc.Commit()
|
|
assert.EqualValues(t, 3, countFoos(t, dbc))
|
|
|
|
//let's try a fresh connection
|
|
db3 := getDBC(p)
|
|
assert.EqualValues(t, 3, countFoos(t, db3))
|
|
}
|
|
|
|
func TestTransactionCommitNestedRollback(t *testing.T) {
|
|
createFooTable(t)
|
|
|
|
p := getPool()
|
|
dbc := getDBC(p)
|
|
defer p.Close()
|
|
//begin 1
|
|
require.Nil(t, dbc.Begin())
|
|
dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
|
|
|
|
//let's try a fresh connection
|
|
dbc1 := getDBC(p)
|
|
assert.EqualValues(t, 0, countFoos(t, dbc1))
|
|
|
|
//begin 2
|
|
require.Nil(t, dbc.Begin())
|
|
dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(3)").Exec()
|
|
//rollback 2
|
|
dbc.Rollback()
|
|
//rollback above doesn't have an effect since we are in the
|
|
//nested transaction
|
|
assert.EqualValues(t, 3, countFoos(t, dbc))
|
|
|
|
//let's try a fresh connection
|
|
dbc2 := getDBC(p)
|
|
assert.EqualValues(t, 0, countFoos(t, dbc2))
|
|
|
|
//rollback 1
|
|
dbc.Rollback()
|
|
assert.EqualValues(t, 0, countFoos(t, dbc))
|
|
|
|
//let's try a fresh connection
|
|
db3 := getDBC(p)
|
|
assert.EqualValues(t, 0, countFoos(t, db3))
|
|
}
|
|
|
|
func TestTransactionRollbackOnDeferAllOK(t *testing.T) {
|
|
createFooTable(t)
|
|
|
|
p := getPool()
|
|
dbc := getDBC(p)
|
|
defer p.Close()
|
|
|
|
fn := func() {
|
|
dbc.Begin()
|
|
defer dbc.RollbackOnDefer()
|
|
dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
|
|
assert.EqualValues(t, 2, countFoos(t, dbc))
|
|
|
|
fnNested := func() {
|
|
dbc.Begin()
|
|
defer dbc.RollbackOnDefer()
|
|
dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(3)").Exec()
|
|
assert.EqualValues(t, 3, countFoos(t, dbc))
|
|
dbc.Commit()
|
|
}
|
|
|
|
fnNested()
|
|
|
|
dbc.Commit()
|
|
}
|
|
|
|
fn()
|
|
dbc1 := getDBC(p)
|
|
assert.EqualValues(t, 3, countFoos(t, dbc1))
|
|
}
|
|
|
|
func TestTransactionRollbackOnDefer(t *testing.T) {
|
|
createFooTable(t)
|
|
|
|
p := getPool()
|
|
dbc := getDBC(p)
|
|
defer p.Close()
|
|
|
|
fn := func() {
|
|
dbc.Begin()
|
|
defer dbc.RollbackOnDefer()
|
|
dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
|
|
assert.EqualValues(t, 2, countFoos(t, dbc))
|
|
|
|
fnNested := func() {
|
|
dbc.Begin()
|
|
defer dbc.RollbackOnDefer()
|
|
dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(3)").Exec()
|
|
assert.EqualValues(t, 3, countFoos(t, dbc))
|
|
//commit is missing for some reason, emulating error
|
|
//dbc.Commit()
|
|
}
|
|
|
|
fnNested()
|
|
|
|
//commit is missing for some reason, emulating error
|
|
//dbc.Commit()
|
|
}
|
|
|
|
fn()
|
|
|
|
dbc1 := getDBC(p)
|
|
assert.EqualValues(t, 0, countFoos(t, dbc1))
|
|
}
|
|
|
|
func TestTransaction(t *testing.T) {
|
|
createFooTable(t)
|
|
|
|
p := getPool()
|
|
dbc := getDBC(p)
|
|
defer p.Close()
|
|
|
|
require.Nil(t, dbc.Transaction(func(dbs *db.DBC) error {
|
|
dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
|
|
|
|
//let's try a fresh connection
|
|
dbc1 := getDBC(p)
|
|
assert.EqualValues(t, 0, countFoos(t, dbc1))
|
|
|
|
assert.EqualValues(t, 2, countFoos(t, dbc))
|
|
return nil
|
|
}))
|
|
|
|
//let's try a fresh connection
|
|
dbc2 := getDBC(p)
|
|
assert.EqualValues(t, 2, countFoos(t, dbc2))
|
|
}
|
|
|
|
func TestTransactionRollbackOnError(t *testing.T) {
|
|
createFooTable(t)
|
|
|
|
p := getPool()
|
|
dbc := getDBC(p)
|
|
defer p.Close()
|
|
|
|
err := dbc.Transaction(func(dbs *db.DBC) error {
|
|
dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
|
|
|
|
assert.EqualValues(t, 2, countFoos(t, dbc))
|
|
return errors.New("Opps")
|
|
})
|
|
assert.EqualValues(t, "Opps", err.Error())
|
|
|
|
//let's try a fresh connection
|
|
dbc2 := getDBC(p)
|
|
assert.EqualValues(t, 0, countFoos(t, dbc2))
|
|
}
|
|
|
|
func TestTransactionRollbackOnPanic(t *testing.T) {
|
|
createFooTable(t)
|
|
|
|
p := getPool()
|
|
dbc := getDBC(p)
|
|
defer p.Close()
|
|
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
str := r.(string)
|
|
assert.EqualValues(t, str, "Ooops")
|
|
|
|
//let's try a fresh connection
|
|
dbc2 := getDBC(p)
|
|
assert.EqualValues(t, 0, countFoos(t, dbc2))
|
|
|
|
}
|
|
}()
|
|
|
|
dbc.Transaction(func(dbs *db.DBC) error {
|
|
dbs.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
|
|
|
|
assert.EqualValues(t, 2, countFoos(t, dbc))
|
|
panic("Ooops")
|
|
})
|
|
}
|
|
|
|
func countFoos(t *testing.T, dbc *db.DBC) int {
|
|
var res int
|
|
err := dbc.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res)
|
|
require.Nil(t, err)
|
|
return res
|
|
}
|
|
|
|
func TestTransactionRollback(t *testing.T) {
|
|
createFooTable(t)
|
|
|
|
p := getPool()
|
|
dbc := getDBC(p)
|
|
defer p.Close()
|
|
require.Nil(t, dbc.Begin())
|
|
{
|
|
_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1),(2)").Exec()
|
|
require.Nil(t, err)
|
|
}
|
|
|
|
var res int
|
|
{
|
|
err := dbc.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res)
|
|
require.Nil(t, err)
|
|
assert.EqualValues(t, 2, res)
|
|
}
|
|
dbc.Rollback()
|
|
|
|
{
|
|
err := dbc.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res)
|
|
require.Nil(t, err)
|
|
assert.EqualValues(t, 0, res)
|
|
}
|
|
|
|
//let's try a fresh connection
|
|
dbc2 := getDBC(p)
|
|
{
|
|
err := dbc2.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res)
|
|
require.Nil(t, err)
|
|
assert.EqualValues(t, 0, res)
|
|
}
|
|
}
|
|
|
|
func TestTransactionNestedRollback(t *testing.T) {
|
|
createFooTable(t)
|
|
|
|
p := getPool()
|
|
dbc := getDBC(p)
|
|
defer p.Close()
|
|
//begin 1
|
|
require.Nil(t, dbc.Begin())
|
|
{
|
|
_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1)").Exec()
|
|
require.Nil(t, err)
|
|
}
|
|
|
|
var res int
|
|
{
|
|
err := dbc.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res)
|
|
require.Nil(t, err)
|
|
assert.EqualValues(t, 1, res)
|
|
}
|
|
|
|
//begin 2
|
|
require.Nil(t, dbc.Begin())
|
|
{
|
|
_, err := dbc.UpdateBySQL("INSERT INTO foo(id) VALUES(1)").Exec()
|
|
require.Nil(t, err)
|
|
}
|
|
|
|
//no real rollback happens here since we are in a 'bigger' transaction
|
|
//rollback 2
|
|
dbc.Rollback()
|
|
{
|
|
err := dbc.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res)
|
|
require.Nil(t, err)
|
|
assert.EqualValues(t, 2, res)
|
|
}
|
|
|
|
//rollback 1
|
|
dbc.Rollback()
|
|
|
|
{
|
|
err := dbc.SelectBySQL("SELECT COUNT(id) FROM foo").LoadValue(&res)
|
|
require.Nil(t, err)
|
|
assert.EqualValues(t, 0, res)
|
|
}
|
|
|
|
}
|