From 56aeea98c95b8e78dc74a8bd0b2a8e27290fd72d Mon Sep 17 00:00:00 2001 From: Pavel Shevaev Date: Wed, 26 Oct 2022 13:53:06 +0300 Subject: [PATCH] First commit --- LICENSE | 20 +++ README.md | 336 +++++++++++++++++++++++++++++++++++++++++++ circle.yml | 4 + dbr.go | 52 +++++++ dbr_test.go | 80 +++++++++++ delete.go | 143 +++++++++++++++++++ delete_test.go | 67 +++++++++ errors.go | 14 ++ event.go | 50 +++++++ expr.go | 11 ++ go.mod | 9 ++ go.sum | 28 ++++ insert.go | 176 +++++++++++++++++++++++ insert_test.go | 116 +++++++++++++++ interpolate.go | 212 ++++++++++++++++++++++++++++ interpolate_test.go | 104 ++++++++++++++ now.go | 18 +++ quote.go | 22 +++ select.go | 210 +++++++++++++++++++++++++++ select_load.go | 310 ++++++++++++++++++++++++++++++++++++++++ select_return.go | 66 +++++++++ select_test.go | 337 ++++++++++++++++++++++++++++++++++++++++++++ struct_mapping.go | 109 ++++++++++++++ tags | 270 +++++++++++++++++++++++++++++++++++ thots.txt | 11 ++ transaction.go | 62 ++++++++ transaction_test.go | 54 +++++++ types.go | 70 +++++++++ update.go | 208 +++++++++++++++++++++++++++ update_test.go | 131 +++++++++++++++++ util.go | 23 +++ where.go | 104 ++++++++++++++ 32 files changed, 3427 insertions(+) create mode 100644 LICENSE create mode 100644 README.md create mode 100644 circle.yml create mode 100644 dbr.go create mode 100644 dbr_test.go create mode 100644 delete.go create mode 100644 delete_test.go create mode 100644 errors.go create mode 100644 event.go create mode 100644 expr.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 insert.go create mode 100644 insert_test.go create mode 100644 interpolate.go create mode 100644 interpolate_test.go create mode 100644 now.go create mode 100644 quote.go create mode 100644 select.go create mode 100644 select_load.go create mode 100644 select_return.go create mode 100644 select_test.go create mode 100644 struct_mapping.go create mode 100644 tags create mode 100644 thots.txt create mode 100644 transaction.go create mode 100644 transaction_test.go create mode 100644 types.go create mode 100644 update.go create mode 100644 update_test.go create mode 100644 util.go create mode 100644 where.go diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..a13da86 --- /dev/null +++ b/LICENSE @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Copyright (c) 2014 Jonathan Novak, Tyler Smith + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..f18f002 --- /dev/null +++ b/README.md @@ -0,0 +1,336 @@ +# gocraft/dbr (database records) [![GoDoc](https://godoc.org/github.com/gocraft/web?status.png)](https://godoc.org/github.com/gocraft/dbr) + +gocraft/dbr provides additions to Go's database/sql for super fast performance and convenience. + +## Getting Started + +```go +package main + +import ( + "database/sql" + "fmt" + _ "github.com/go-sql-driver/mysql" + "github.com/gocraft/dbr" +) + +// Simple data model +type Suggestion struct { + Id int64 + Title string + CreatedAt dbr.NullTime +} + +// Hold a single global connection (pooling provided by sql driver) +var connection *dbr.Connection + +func main() { + // Create the connection during application initialization + db, _ := sql.Open("mysql", "root@unix(/tmp/mysqld.sock)/your_database") + connection = dbr.NewConnection(db, nil) + + // Create a session for each business unit of execution (e.g. a web request or goworkers job) + dbrSess := connection.NewSession(nil) + + // Get a record + var suggestion Suggestion + err := dbrSess.Select("id, title").From("suggestions").Where("id = ?", 13).LoadStruct(&suggestion) + + if err != nil { + fmt.Println(err.Error()) + } else { + fmt.Println("Title:", suggestion.Title) + } + + // JSON-ready, with dbr.Null* types serialized like you want + recordJson, _ := json.Marshal(&suggestion) + fmt.Println(string(recordJson)) +} +``` + +## Feature highlights + +### Automatically map results to structs +Querying is the heart of gocraft/dbr. Automatically map results to structs: +```go +var posts []*struct { + Id int64 + Title string + Body dbr.NullString +} +err := sess.Select("id, title, body"). + From("posts").Where("id = ?", id).LoadStruct(&post) +``` + +Additionally, easily query a single value or a slice of values: +```go +id, err := sess.SelectBySql("SELECT id FROM posts WHERE title=?", title).ReturnInt64() +ids, err := sess.SelectBySql("SELECT id FROM posts", title).ReturnInt64s() +``` + +See below for many more examples. + +### Use a Sweet Query Builder or use Plain SQL +gocraft/dbr supports both. + +Sweet Query Builder: +```go +builder := sess.Select("title", "body"). + From("posts"). + Where("created_at > ?", someTime). + OrderBy("id ASC"). + Limit(10) + +var posts []*Post +n, err := builder.LoadStructs(&posts) +``` + +Plain SQL: +```go +n, err := sess.SelectBySql(`SELECT title, body FROM posts WHERE created_at > ? + ORDER BY id ASC LIMIT 10`, someTime).LoadStructs(&post) +``` + +### IN queries that aren't horrible +Traditionally, database/sql uses prepared statements, which means each argument in an IN clause needs its own question mark. gocraft/dbr, on the other hand, handles interpolation itself so that you can easily use a single question mark paired with a dynamically sized slice. + +```go +// Traditional database/sql way: +ids := []int64{1,2,3,4,5} +questionMarks := []string +for _, _ := range ids { + questionMarks = append(questionMarks, "?") +} +query := fmt.Sprintf("SELECT * FROM posts WHERE id IN (%s)", + strings.Join(questionMarks, ",") // lolwut +rows, err := db.Query(query, ids) + +// gocraft/dbr way: +ids := []int64{1,2,3,4,5} +n, err := sess.SelectBySql("SELECT * FROM posts WHERE id IN ?", ids) // yay +``` + +### Amazing instrumentation +Writing instrumented code is a first-class concern for gocraft/dbr. We instrument each query to emit to a gocraft/health-compatible EventReceiver interface. NOTE: we have not released gocraft/health yet. This allows you to instrument your app to easily connect gocraft/dbr to your metrics systems, such statsd. + +### Faster performance than using using database/sql directly +Every time you call database/sql's db.Query("SELECT ...") method, under the hood, the mysql driver will create a prepared statement, execute it, and then throw it away. This has a big performance cost. + +gocraft/dbr doesn't use prepared statements. We ported mysql's query escape functionality directly into our package, which means we interpolate all of those question marks with their arguments before they get to MySQL. The result of this is that it's way faster, and just as secure. + +Check out these [benchmarks](https://github.com/tyler-smith/golang-sql-benchmark). + +### JSON Friendly +Every try to JSON-encode a sql.NullString? You get: +```json +{ + "str1": { + "Valid": true, + "String": "Hi!" + }, + "str2": { + "Valid": false, + "String": "" + } +} +``` + +Not quite what you want. gocraft/dbr has dbr.NullString (and the rest of the Null* types) that encode correctly, giving you: + +```json +{ + "str1": "Hi!", + "str2": null +} +``` + +## Driver support +Currently only MySQL has been tested because that is what we use. Feel free to make an issue for Postgres if you're interested in adding support and we can discuss what it would take. + +## Usage Examples + +### Making a session +All queries in gocraft/dbr are made in the context of a session. This is because when instrumenting your app, it's important to understand which business action the query took place in. See gocraft/health for more detail. + +Here's an example web endpoint that makes a session: +```go +// At app startup. If you have a gocraft/health stream, pass it here instead of nil. +dbrCxn = dbr.NewConnection(db, nil) + +func SuggestionsIndex(rw http.ResponseWriter, r *http.Request) { + // Make a session. If you have a gocraft/health job, pass it here instead of nil. + dbrSess := connection.NewSession(nil) + + // Do queries with the session: + var sugg Suggestion + err := dbrSess.Select("id, title").From("suggestions"). + Where("id = ?", suggestion.Id).LoadStruct(&sugg) + + // Render stuff, etc. Nothing else needs to be done with dbr. +} +``` + +### Simple Record CRUD +```go +// Create a new suggestion record +suggestion := &Suggestion{Title: "My Cool Suggestion", State: "open"} + +// Insert; inserting a record automatically sets an int64 Id field if present +response, err := dbrSess.InsertInto("suggestions"). + Columns("title", "state").Record(suggestion).Exec() + +// Update +response, err = dbrSess.Update("suggestions"). + Set("title", "My New Title").Where("id = ?", suggestion.Id).Exec() + +// Select +var otherSuggestion Suggestion +err = dbrSess.Select("id, title").From("suggestions"). + Where("id = ?", suggestion.Id).LoadStruct(&otherSuggestion) + +// Delete +response, err = dbrSess.DeleteFrom("suggestions"). + Where("id = ?", otherSuggestion.Id).Limit(1).Exec() +``` + +### Primitive Values +```go +// Load primitives into existing variables +var ids []int64 +idCount, err := sess.Select("id").From("suggestions").LoadValues(&ids) + +var titles []string +titleCount, err := sess.Select("title").From("suggestions").LoadValues(&titles) + +// Or return them directly +ids, err = sess.Select("id").From("suggestions").ReturnInt64s() +titles, err = sess.Select("title").From("suggestions").ReturnStrings() +``` + +### Overriding Column Names With Struct Tags +```go +// By default dbr converts CamelCase property names to snake_case column_names +// You can override this with struct tags, just like with JSON tags +// This is especially helpful while migrating from legacy systems +type Suggestion struct { + Id int64 + Title dbr.NullString `db:"subject"` // subjects are called titles now + CreatedAt dbr.NullTime +} +``` + +### Embedded structs +```go +// Columns are mapped to fields breadth-first +type Suggestion struct { + Id int64 + Title string + User *struct { + Id int64 `db:"user_id"` + } +} + +var suggestion Suggestion +err := dbrSess.Select("id, title, user_id").From("suggestions"). + Limit(1).LoadStruct(&suggestion) +``` + +### JSON encoding of Null* types +```go +// dbr.Null* types serialize to JSON like you want +suggestion := &Suggestion{Id: 1, Title: "Test Title"} +jsonBytes, err := json.Marshal(&suggestion) +fmt.Println(string(jsonBytes)) // {"id":1,"title":"Test Title","created_at":null} +``` + +### Inserting Multiple Records +```go +// Start bulding an INSERT statement +createDevsBuilder := sess.InsertInto("developers"). + Columns("name", "language", "employee_number") + +// Add some new developers +for i := 0; i < 3; i++ { + createDevsBuilder.Record(&Dev{Name: "Gopher", Language: "Go", EmployeeNumber: i}) +} + +// Execute statment +_, err := createDevsBuilder.Exec() +if err != nil { + log.Fatalln("Error creating developers", err) +} +``` + +### Updating Records +```go +// Update any rubyists to gophers +response, err := sess.Update("developers"). + Set("name", "Gopher"). + Set("language", "Go"). + Where("language = ?", "Ruby").Exec() + + +// Alternatively use a map of attributes to update +attrsMap := map[string]interface{}{"name": "Gopher", "language": "Go"} +response, err := sess.Update("developers"). + SetMap(attrsMap).Where("language = ?", "Ruby").Exec() +``` + +### Transactions +```go +// Start txn +tx, err := c.Dbr.Begin() +if err != nil { + return err +} + +// Rollback unless we're successful. You can also manually call tx.Rollback() if you'd like. +defer tx.RollbackUnlessCommitted() + +// Issue statements that might cause errors +res, err := tx.Update("suggestions").Set("state", "deleted").Where("deleted_at IS NOT NULL").Exec() +if err != nil { + return err +} + +// Commit the transaction +if err := tx.Commit(); err != nil { + return err +} +``` + +### Generate SQL without executing +```go +// Create builder +builder := dbrSess.Select("*").From("suggestions").Where("subdomain_id = ?", 1) + +// Get builder's SQL and arguments +sql, args := builder.ToSql() +fmt.Println(sql) // SELECT * FROM suggestions WHERE (subdomain_id = ?) +fmt.Println(args) // [1] + +// Use raw database/sql for actual query +rows, err := db.Query(sql, args...) +if err != nil { + log.Fatalln(err) +} + +// Alternatively you can build the full query +query, err := dbr.Interpolate(builder.ToSql()) +if err != nil { + log.Fatalln(err) +} +fmt.Println(query) // SELECT * FROM suggestions WHERE (subdomain_id = 1) +``` + +## Contributing +We gladly accept contributions. We want to keep dbr pretty light but I certainly don't mind discussing any changes or additions. Feel free to open an issue if you'd like to discus a potential change. + +## Thanks & Authors +Inspiration from these excellent libraries: +* [sqlx](https://github.com/jmoiron/sqlx) - various useful tools and utils for interacting with database/sql. +* [Squirrel](https://github.com/lann/squirrel) - simple fluent query builder. + +Authors: +* Jonathan Novak -- [https://github.com/cypriss](https://github.com/cypriss) +* Tyler Smith -- [https://github.com/tyler-smith](https://github.com/tyler-smith) diff --git a/circle.yml b/circle.yml new file mode 100644 index 0000000..ffc9a49 --- /dev/null +++ b/circle.yml @@ -0,0 +1,4 @@ +## Customize the test machine +machine: + environment: + DBR_TEST_DSN: "ubuntu:@unix(/var/run/mysqld/mysqld.sock)/circle_test?charset=utf8&parseTime=true" diff --git a/dbr.go b/dbr.go new file mode 100644 index 0000000..824f2d3 --- /dev/null +++ b/dbr.go @@ -0,0 +1,52 @@ +package dbr + +import ( + "database/sql" +) + +// Connection is a connection to the database with an EventReceiver +// to send events, errors, and timings to +type Connection struct { + Db *sql.DB + EventReceiver +} + +// Session represents a business unit of execution for some connection +type Session struct { + cxn *Connection + EventReceiver +} + +// NewConnection instantiates a Connection for a given database/sql connection +// and event receiver +func NewConnection(db *sql.DB, log EventReceiver) *Connection { + if log == nil { + log = nullReceiver + } + + return &Connection{Db: db, EventReceiver: log} +} + +// NewSession instantiates a Session for the Connection +func (cxn *Connection) NewSession(log EventReceiver) *Session { + if log == nil { + log = cxn.EventReceiver // Use parent instrumentation + } + return &Session{cxn: cxn, EventReceiver: log} +} + +// SessionRunner can do anything that a Session can except start a transaction. +type SessionRunner interface { + Select(cols ...string) *SelectBuilder + SelectBySql(sql string, args ...interface{}) *SelectBuilder + + InsertInto(into string) *InsertBuilder + Update(table string) *UpdateBuilder + UpdateBySql(sql string, args ...interface{}) *UpdateBuilder + DeleteFrom(from string) *DeleteBuilder +} + +type Runner interface { + Exec(query string, args ...interface{}) (sql.Result, error) + Query(query string, args ...interface{}) (*sql.Rows, error) +} diff --git a/dbr_test.go b/dbr_test.go new file mode 100644 index 0000000..59f420e --- /dev/null +++ b/dbr_test.go @@ -0,0 +1,80 @@ +package dbr + +import ( + "database/sql" + "fmt" + "log" + "os" +) + +// +// Test helpers +// + +// Returns a session that's not backed by a database +func createFakeSession() *Session { + cxn := NewConnection(nil, nil) + return cxn.NewSession(nil) +} + +func createRealSession() *Session { + cxn := NewConnection(realDb(), nil) + return cxn.NewSession(nil) +} + +func createRealSessionWithFixtures() *Session { + sess := createRealSession() + installFixtures(sess.cxn.Db) + return sess +} + +func realDb() *sql.DB { + driver := os.Getenv("DBR_TEST_DRIVER") + if driver == "" { + driver = "mysql" + } + + dsn := os.Getenv("DBR_TEST_DSN") + if dsn == "" { + dsn = "root:unprotected@unix(/tmp/mysql.sock)/uservoice_development?charset=utf8&parseTime=true" + } + + db, err := sql.Open(driver, dsn) + if err != nil { + log.Fatalln("Mysql error ", err) + } + + return db +} + +type dbrPerson struct { + Id int64 + Name string + Email NullString + Key NullString +} + +func installFixtures(db *sql.DB) { + createTablePeople := fmt.Sprintf(` + CREATE TABLE dbr_people ( + id int(11) DEFAULT NULL auto_increment PRIMARY KEY, + name varchar(255) NOT NULL, + email varchar(255), + %s varchar(255) + ) + `, "`key`") + + sqlToRun := []string{ + "DROP TABLE IF EXISTS dbr_people", + createTablePeople, + "INSERT INTO dbr_people (name,email) VALUES ('Jonathan', 'jonathan@uservoice.com')", + "INSERT INTO dbr_people (name,email) VALUES ('Dmitri', 'zavorotni@jadius.com')", + } + + for _, v := range sqlToRun { + _, err := db.Exec(v) + if err != nil { + log.Fatalln("Failed to execute statement: ", v, " Got error: ", err) + } + } +} diff --git a/delete.go b/delete.go new file mode 100644 index 0000000..f3d037b --- /dev/null +++ b/delete.go @@ -0,0 +1,143 @@ +package dbr + +import ( + "bytes" + "database/sql" + "fmt" + "time" +) + +// DeleteBuilder contains the clauses for a DELETE statement +type DeleteBuilder struct { + *Session + Runner + + From string + WhereFragments []*whereFragment + OrderBys []string + LimitCount uint64 + LimitValid bool + OffsetCount uint64 + OffsetValid bool +} + +// DeleteFrom creates a new DeleteBuilder for the given table +func (sess *Session) DeleteFrom(from string) *DeleteBuilder { + return &DeleteBuilder{ + Session: sess, + Runner: sess.cxn.Db, + From: from, + } +} + +// DeleteFrom creates a new DeleteBuilder for the given table +// in the context for a transaction +func (tx *Tx) DeleteFrom(from string) *DeleteBuilder { + return &DeleteBuilder{ + Session: tx.Session, + Runner: tx.Tx, + From: from, + } +} + +// Where appends a WHERE clause to the statement whereSqlOrMap can be a +// string or map. If it's a string, args wil replaces any places holders +func (b *DeleteBuilder) Where(whereSqlOrMap interface{}, args ...interface{}) *DeleteBuilder { + b.WhereFragments = append(b.WhereFragments, newWhereFragment(whereSqlOrMap, args)) + return b +} + +// OrderBy appends an ORDER BY clause to the statement +func (b *DeleteBuilder) OrderBy(ord string) *DeleteBuilder { + b.OrderBys = append(b.OrderBys, ord) + return b +} + +// OrderDir appends an ORDER BY clause with a direction to the statement +func (b *DeleteBuilder) OrderDir(ord string, isAsc bool) *DeleteBuilder { + if isAsc { + b.OrderBys = append(b.OrderBys, ord+" ASC") + } else { + b.OrderBys = append(b.OrderBys, ord+" DESC") + } + return b +} + +// Limit sets a LIMIT clause for the statement; overrides any existing LIMIT +func (b *DeleteBuilder) Limit(limit uint64) *DeleteBuilder { + b.LimitCount = limit + b.LimitValid = true + return b +} + +// Offset sets an OFFSET clause for the statement; overrides any existing OFFSET +func (b *DeleteBuilder) Offset(offset uint64) *DeleteBuilder { + b.OffsetCount = offset + b.OffsetValid = true + return b +} + +// ToSql serialized the DeleteBuilder to a SQL string +// It returns the string with placeholders and a slice of query arguments +func (b *DeleteBuilder) ToSql() (string, []interface{}) { + if len(b.From) == 0 { + panic("no table specified") + } + + var sql bytes.Buffer + var args []interface{} + + sql.WriteString("DELETE FROM ") + sql.WriteString(b.From) + + // Write WHERE clause if we have any fragments + if len(b.WhereFragments) > 0 { + sql.WriteString(" WHERE ") + writeWhereFragmentsToSql(b.WhereFragments, &sql, &args) + } + + // Ordering and limiting + if len(b.OrderBys) > 0 { + sql.WriteString(" ORDER BY ") + for i, s := range b.OrderBys { + if i > 0 { + sql.WriteString(", ") + } + sql.WriteString(s) + } + } + + if b.LimitValid { + sql.WriteString(" LIMIT ") + fmt.Fprint(&sql, b.LimitCount) + } + + if b.OffsetValid { + sql.WriteString(" OFFSET ") + fmt.Fprint(&sql, b.OffsetCount) + } + + return sql.String(), args +} + +// Exec executes the statement represented by the DeleteBuilder +// It returns the raw database/sql Result and an error if there was one +func (b *DeleteBuilder) Exec() (sql.Result, error) { + sql, args := b.ToSql() + + fullSql, err := Interpolate(sql, args) + if err != nil { + return nil, b.EventErrKv("dbr.delete.exec.interpolate", err, kvs{"sql": fullSql}) + } + + // Start the timer: + startTime := time.Now() + defer func() { b.TimingKv("dbr.delete", time.Since(startTime).Nanoseconds(), kvs{"sql": fullSql}) }() + + result, err := b.Runner.Exec(fullSql) + if err != nil { + return result, b.EventErrKv("dbr.delete.exec.exec", err, kvs{"sql": fullSql}) + } + + return result, nil +} diff --git a/delete_test.go b/delete_test.go new file mode 100644 index 0000000..0adcb37 --- /dev/null +++ b/delete_test.go @@ -0,0 +1,67 @@ +package dbr + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func BenchmarkDeleteSql(b *testing.B) { + s := createFakeSession() + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + s.DeleteFrom("alpha").Where("a", "b").Limit(1).OrderDir("id", true).ToSql() + } +} + +func TestDeleteAllToSql(t *testing.T) { + s := createFakeSession() + + sql, _ := s.DeleteFrom("a").ToSql() + + assert.Equal(t, sql, "DELETE FROM a") +} + +func TestDeleteSingleToSql(t *testing.T) { + s := createFakeSession() + + sql, args := s.DeleteFrom("a").Where("id = ?", 1).ToSql() + + assert.Equal(t, sql, "DELETE FROM a WHERE (id = ?)") + assert.Equal(t, args, []interface{}{1}) +} + +func TestDeleteTenStaringFromTwentyToSql(t *testing.T) { + s := createFakeSession() + + sql, _ := s.DeleteFrom("a").Limit(10).Offset(20).OrderBy("id").ToSql() + + assert.Equal(t, sql, "DELETE FROM a ORDER BY id LIMIT 10 OFFSET 20") +} + +func TestDeleteReal(t *testing.T) { + s := createRealSessionWithFixtures() + + // Insert a Barack + res, err := s.InsertInto("dbr_people").Columns("name", "email").Values("Barack", "barack@whitehouse.gov").Exec() + assert.NoError(t, err) + + // Get Barack's ID + id, err := res.LastInsertId() + assert.NoError(t, err) + + // Delete Barack + res, err = s.DeleteFrom("dbr_people").Where("id = ?", id).Exec() + assert.NoError(t, err) + + // Ensure we only reflected one row and that the id no longer exists + rowsAff, err := res.RowsAffected() + assert.NoError(t, err) + assert.Equal(t, rowsAff, 1) + + var count int64 + err = s.Select("count(*)").From("dbr_people").Where("id = ?", id).LoadValue(&count) + assert.NoError(t, err) + assert.Equal(t, count, 0) +} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..b125ead --- /dev/null +++ b/errors.go @@ -0,0 +1,14 @@ +package dbr + +import ( + "errors" +) + +var ( + ErrNotFound = errors.New("not found") + ErrNotUTF8 = errors.New("invalid UTF-8") + ErrInvalidSliceLength = errors.New("length of slice is 0. length must be >= 1") + ErrInvalidSliceValue = errors.New("trying to interpolate invalid slice value into query") + ErrInvalidValue = errors.New("trying to interpolate invalid value into query") + ErrArgumentMismatch = errors.New("mismatch between ? (placeholders) and arguments") +) diff --git a/event.go b/event.go new file mode 100644 index 0000000..c4eb474 --- /dev/null +++ b/event.go @@ -0,0 +1,50 @@ +package dbr + +// EventReceiver gets events from dbr methods for logging purposes +type EventReceiver interface { + Event(eventName string) + EventKv(eventName string, kvs map[string]string) + EventErr(eventName string, err error) error + EventErrKv(eventName string, err error, kvs map[string]string) error + Timing(eventName string, nanoseconds int64) + TimingKv(eventName string, nanoseconds int64, kvs map[string]string) +} + +type kvs map[string]string + +// NullEventReceiver is a sentinel EventReceiver; use it if the caller doesn't supply one +type NullEventReceiver struct{} + +var nullReceiver = &NullEventReceiver{} + +// Event receives a simple notification when various events occur +func (n *NullEventReceiver) Event(eventName string) { + // noop +} + +// EventKv receives a notification when various events occur along with +// optional key/value data +func (n *NullEventReceiver) EventKv(eventName string, kvs map[string]string) { + // noop +} + +// EventErr receives a notification of an error if one occurs +func (n *NullEventReceiver) EventErr(eventName string, err error) error { + return err +} + +// EventErrKv receives a notification of an error if one occurs along with +// optional key/value data +func (n *NullEventReceiver) EventErrKv(eventName string, err error, kvs map[string]string) error { + return err +} + +// Timing receives the time an event took to happen +func (n *NullEventReceiver) Timing(eventName string, nanoseconds int64) { + // noop +} + +// TimingKv receives the time an event took to happen along with optional key/value data +func (n *NullEventReceiver) TimingKv(eventName string, nanoseconds int64, kvs map[string]string) { + // noop +} diff --git a/expr.go b/expr.go new file mode 100644 index 0000000..6b52b9f --- /dev/null +++ b/expr.go @@ -0,0 +1,11 @@ +package dbr + +type expr struct { + Sql string + Values []interface{} +} + +// Expr is a SQL fragment with placeholders, and a slice of args to replace them with +func Expr(sql string, values ...interface{}) *expr { + return &expr{Sql: sql, Values: values} +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..0a28e79 --- /dev/null +++ b/go.mod @@ -0,0 +1,9 @@ +module git.bit5.ru/backend/dbr + +go 1.13 + +require ( + github.com/go-sql-driver/mysql v1.4.1 + github.com/stretchr/testify v1.8.1 + google.golang.org/appengine v1.6.7 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..8b3ac53 --- /dev/null +++ b/go.sum @@ -0,0 +1,28 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= +github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= +google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/insert.go b/insert.go new file mode 100644 index 0000000..0ed5d05 --- /dev/null +++ b/insert.go @@ -0,0 +1,176 @@ +package dbr + +import ( + "bytes" + "database/sql" + "fmt" + "reflect" + "time" +) + +// InsertBuilder contains the clauses for an INSERT statement +type InsertBuilder struct { + *Session + Runner + + Into string + Cols []string + Vals [][]interface{} + Recs []interface{} +} + +// InsertInto instantiates a InsertBuilder for the given table +func (sess *Session) InsertInto(into string) *InsertBuilder { + return &InsertBuilder{ + Session: sess, + Runner: sess.cxn.Db, + Into: into, + } +} + +// InsertInto instantiates a InsertBuilder for the given table bound to a transaction +func (tx *Tx) InsertInto(into string) *InsertBuilder { + return &InsertBuilder{ + Session: tx.Session, + Runner: tx.Tx, + Into: into, + } +} + +// Columns appends columns to insert in the statement +func (b *InsertBuilder) Columns(columns ...string) *InsertBuilder { + b.Cols = columns + return b +} + +// Values appends a set of values to the statement +func (b *InsertBuilder) Values(vals ...interface{}) *InsertBuilder { + b.Vals = append(b.Vals, vals) + return b +} + +// Record pulls in values to match Columns from the record +func (b *InsertBuilder) Record(record interface{}) *InsertBuilder { + b.Recs = append(b.Recs, record) + return b +} + +// Pair adds a key/value pair to the statement +func (b *InsertBuilder) Pair(column string, value interface{}) *InsertBuilder { + b.Cols = append(b.Cols, column) + lenVals := len(b.Vals) + if lenVals == 0 { + args := []interface{}{value} + b.Vals = [][]interface{}{args} + } else if lenVals == 1 { + b.Vals[0] = append(b.Vals[0], value) + } else { + panic("pair only allows you to specify 1 record to insret") + } + return b +} + +// ToSql serialized the InsertBuilder to a SQL string +// It returns the string with placeholders and a slice of query arguments +func (b *InsertBuilder) ToSql() (string, []interface{}) { + if len(b.Into) == 0 { + panic("no table specified") + } + if len(b.Cols) == 0 { + panic("no columns specified") + } + if len(b.Vals) == 0 && len(b.Recs) == 0 { + panic("no values or records specified") + } + + var sql bytes.Buffer + var placeholder bytes.Buffer // Build the placeholder like "(?,?,?)" + var args []interface{} + + sql.WriteString("INSERT INTO ") + sql.WriteString(b.Into) + sql.WriteString(" (") + + // Simulataneously write the cols to the sql buffer, and build a placeholder + placeholder.WriteRune('(') + for i, c := range b.Cols { + if i > 0 { + sql.WriteRune(',') + placeholder.WriteRune(',') + } + Quoter.writeQuotedColumn(c, &sql) + placeholder.WriteRune('?') + } + sql.WriteString(") VALUES ") + placeholder.WriteRune(')') + placeholderStr := placeholder.String() + + // Go thru each value we want to insert. Write the placeholders, and collect args + for i, row := range b.Vals { + if i > 0 { + sql.WriteRune(',') + } + sql.WriteString(placeholderStr) + + for _, v := range row { + args = append(args, v) + } + } + anyVals := len(b.Vals) > 0 + + // Go thru the records. Write the placeholders, and do reflection on the records to extract args + for i, rec := range b.Recs { + if i > 0 || anyVals { + sql.WriteRune(',') + } + sql.WriteString(placeholderStr) + + ind := reflect.Indirect(reflect.ValueOf(rec)) + vals, err := b.valuesFor(ind.Type(), ind, b.Cols) + if err != nil { + panic(err.Error()) + } + for _, v := range vals { + args = append(args, v) + } + } + + return sql.String(), args +} + +// Exec executes the statement represented by the InsertBuilder +// It returns the raw database/sql Result and an error if there was one +func (b *InsertBuilder) Exec() (sql.Result, error) { + sql, args := b.ToSql() + + fullSql, err := Interpolate(sql, args) + if err != nil { + return nil, b.EventErrKv("dbr.insert.exec.interpolate", err, kvs{"sql": sql, "args": fmt.Sprint(args)}) + } + + // Start the timer: + startTime := time.Now() + defer func() { b.TimingKv("dbr.insert", time.Since(startTime).Nanoseconds(), kvs{"sql": fullSql}) }() + + result, err := b.Runner.Exec(fullSql) + if err != nil { + return result, b.EventErrKv("dbr.insert.exec.exec", err, kvs{"sql": fullSql}) + } + + // If the structure has an "Id" field which is an int64, set it from the LastInsertId(). Otherwise, don't bother. + if len(b.Recs) == 1 { + rec := b.Recs[0] + val := reflect.Indirect(reflect.ValueOf(rec)) + if val.Kind() == reflect.Struct && val.CanSet() { + if idField := val.FieldByName("Id"); idField.IsValid() && idField.Kind() == reflect.Int64 { + if lastID, err := result.LastInsertId(); err == nil { + idField.Set(reflect.ValueOf(lastID)) + } else { + b.EventErrKv("dbr.insert.exec.last_inserted_id", err, kvs{"sql": fullSql}) + } + } + } + } + + return result, nil +} diff --git a/insert_test.go b/insert_test.go new file mode 100644 index 0000000..7f47222 --- /dev/null +++ b/insert_test.go @@ -0,0 +1,116 @@ +package dbr + +import ( + "database/sql" + "testing" + + "github.com/stretchr/testify/assert" +) + +type someRecord struct { + SomethingId int + UserId int64 + Other bool +} + +func BenchmarkInsertValuesSql(b *testing.B) { + s := createFakeSession() + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + s.InsertInto("alpha").Columns("something_id", "user_id", "other").Values(1, 2, true).ToSql() + } +} + +func BenchmarkInsertRecordsSql(b *testing.B) { + s := createFakeSession() + obj := someRecord{1, 99, false} + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + s.InsertInto("alpha").Columns("something_id", "user_id", "other").Record(obj).ToSql() + } +} + +func TestInsertSingleToSql(t *testing.T) { + s := createFakeSession() + + sql, args := s.InsertInto("a").Columns("b", "c").Values(1, 2).ToSql() + + assert.Equal(t, sql, "INSERT INTO a (`b`,`c`) VALUES (?,?)") + assert.Equal(t, args, []interface{}{1, 2}) +} + +func TestInsertMultipleToSql(t *testing.T) { + s := createFakeSession() + + sql, args := s.InsertInto("a").Columns("b", "c").Values(1, 2).Values(3, 4).ToSql() + + assert.Equal(t, sql, "INSERT INTO a (`b`,`c`) VALUES (?,?),(?,?)") + assert.Equal(t, args, []interface{}{1, 2, 3, 4}) +} + +func TestInsertRecordsToSql(t *testing.T) { + s := createFakeSession() + + objs := []someRecord{{1, 88, false}, {2, 99, true}} + sql, args := s.InsertInto("a").Columns("something_id", "user_id", "other").Record(objs[0]).Record(objs[1]).ToSql() + + assert.Equal(t, sql, "INSERT INTO a (`something_id`,`user_id`,`other`) VALUES (?,?,?),(?,?,?)") + assert.Equal(t, args, []interface{}{1, 88, false, 2, 99, true}) +} + +func TestInsertKeywordColumnName(t *testing.T) { + // Insert a column whose name is reserved + s := createRealSessionWithFixtures() + res, err := s.InsertInto("dbr_people").Columns("name", "key").Values("Barack", "44").Exec() + assert.NoError(t, err) + + rowsAff, err := res.RowsAffected() + assert.NoError(t, err) + assert.Equal(t, rowsAff, 1) +} + +func TestInsertReal(t *testing.T) { + // Insert by specifying values + s := createRealSessionWithFixtures() + res, err := s.InsertInto("dbr_people").Columns("name", "email").Values("Barack", "obama@whitehouse.gov").Exec() + validateInsertingBarack(t, s, res, err) + + // Insert by specifying a record (ptr to struct) + s = createRealSessionWithFixtures() + person := dbrPerson{Name: "Barack"} + person.Email.Valid = true + person.Email.String = "obama@whitehouse.gov" + res, err = s.InsertInto("dbr_people").Columns("name", "email").Record(&person).Exec() + validateInsertingBarack(t, s, res, err) + + // Insert by specifying a record (struct) + s = createRealSessionWithFixtures() + res, err = s.InsertInto("dbr_people").Columns("name", "email").Record(person).Exec() + validateInsertingBarack(t, s, res, err) +} + +func validateInsertingBarack(t *testing.T, s *Session, res sql.Result, err error) { + assert.NoError(t, err) + id, err := res.LastInsertId() + assert.NoError(t, err) + rowsAff, err := res.RowsAffected() + assert.NoError(t, err) + + assert.True(t, id > 0) + assert.Equal(t, rowsAff, 1) + + var person dbrPerson + err = s.Select("*").From("dbr_people").Where("id = ?", id).LoadStruct(&person) + assert.NoError(t, err) + + assert.Equal(t, person.Id, id) + assert.Equal(t, person.Name, "Barack") + assert.Equal(t, person.Email.Valid, true) + assert.Equal(t, person.Email.String, "obama@whitehouse.gov") +} + +// TODO: do a real test inserting multiple records diff --git a/interpolate.go b/interpolate.go new file mode 100644 index 0000000..2e97c06 --- /dev/null +++ b/interpolate.go @@ -0,0 +1,212 @@ +package dbr + +import ( + // "fmt" + "bytes" + "database/sql/driver" + "reflect" + "strconv" + "strings" + "time" + "unicode/utf8" +) + +// Need to turn \x00, \n, \r, \, ', " and \x1a +// Returns an escaped, quoted string. eg, "hello 'world'" -> "'hello \'world\''" +func escapeAndQuoteString(val string) string { + buf := bytes.Buffer{} + + buf.WriteRune('\'') + + for _, char := range val { + if char == '\'' { // single quote: ' -> \' + buf.WriteString("\\'") + } else if char == '"' { // double quote: " -> \" + buf.WriteString("\\\"") + } else if char == '\\' { // slash: \ -> "\\" + buf.WriteString("\\\\") + } else if char == '\n' { // control: newline: \n -> "\n" + buf.WriteString("\\n") + } else if char == '\r' { // control: return: \r -> "\r" + buf.WriteString("\\r") + } else if char == 0 { // control: NUL: 0 -> "\x00" + buf.WriteString("\\x00") + } else if char == 0x1a { // control: \x1a -> "\x1a" + buf.WriteString("\\x1a") + } else { + buf.WriteRune(char) + } + } + + buf.WriteRune('\'') + + return buf.String() +} + +func isUint(k reflect.Kind) bool { + return (k == reflect.Uint) || + (k == reflect.Uint8) || + (k == reflect.Uint16) || + (k == reflect.Uint32) || + (k == reflect.Uint64) +} + +func isInt(k reflect.Kind) bool { + return (k == reflect.Int) || + (k == reflect.Int8) || + (k == reflect.Int16) || + (k == reflect.Int32) || + (k == reflect.Int64) +} + +func isFloat(k reflect.Kind) bool { + return (k == reflect.Float32) || + (k == reflect.Float64) +} + +// sql is like "id = ? OR username = ?" +// vals is like []interface{}{4, "bob"} +// NOTE that vals can only have values of certain types: +// - Integers (signed and unsigned) +// - floats +// - strings (that are valid utf-8) +// - booleans +// - times +var typeOfTime = reflect.TypeOf(time.Time{}) + +// Interpolate takes a SQL string with placeholders and a list of arguments to +// replace them with. Returns a blank string and error if the number of placeholders +// does not match the number of arguments. +func Interpolate(sql string, vals []interface{}) (string, error) { + // Get the number of arguments to add to this query + maxVals := len(vals) + + // If our query is blank and has no args return early + // Args with a blank query is an error + if sql == "" { + if maxVals != 0 { + return "", ErrArgumentMismatch + } + return "", nil + } + + // If we have no args and the query has no place holders return early + // No args for a query with place holders is an error + if len(vals) == 0 { + for _, c := range sql { + if c == '?' { + return "", ErrArgumentMismatch + } + } + return sql, nil + } + + // Iterate over each rune in the sql string and replace with the next arg if it's a place holder + curVal := 0 + buf := bytes.Buffer{} + + for _, r := range sql { + if r != '?' { + buf.WriteRune(r) + } else if r == '?' && curVal < maxVals { + v := vals[curVal] + + valuer, ok := v.(driver.Valuer) + if ok { + val, err := valuer.Value() + if err != nil { + return "", err + } + v = val + } + + valueOfV := reflect.ValueOf(v) + kindOfV := valueOfV.Kind() + + if v == nil { + buf.WriteString("NULL") + } else if isInt(kindOfV) { + var ival = valueOfV.Int() + + buf.WriteString(strconv.FormatInt(ival, 10)) + } else if isUint(kindOfV) { + var uival = valueOfV.Uint() + + buf.WriteString(strconv.FormatUint(uival, 10)) + } else if kindOfV == reflect.String { + var str = valueOfV.String() + + if !utf8.ValidString(str) { + return "", ErrNotUTF8 + } + + buf.WriteString(escapeAndQuoteString(str)) + } else if isFloat(kindOfV) { + var fval = valueOfV.Float() + + buf.WriteString(strconv.FormatFloat(fval, 'f', -1, 64)) + } else if kindOfV == reflect.Bool { + var bval = valueOfV.Bool() + + if bval { + buf.WriteRune('1') + } else { + buf.WriteRune('0') + } + } else if kindOfV == reflect.Struct { + if typeOfV := valueOfV.Type(); typeOfV == typeOfTime { + t := valueOfV.Interface().(time.Time) + buf.WriteString(escapeAndQuoteString(t.UTC().Format(timeFormat))) + } else { + return "", ErrInvalidValue + } + } else if kindOfV == reflect.Slice { + typeOfV := reflect.TypeOf(v) + subtype := typeOfV.Elem() + kindOfSubtype := subtype.Kind() + + sliceLen := valueOfV.Len() + stringSlice := make([]string, 0, sliceLen) + + if sliceLen == 0 { + return "", ErrInvalidSliceLength + } else if isInt(kindOfSubtype) { + for i := 0; i < sliceLen; i++ { + var ival = valueOfV.Index(i).Int() + stringSlice = append(stringSlice, strconv.FormatInt(ival, 10)) + } + } else if isUint(kindOfSubtype) { + for i := 0; i < sliceLen; i++ { + var uival = valueOfV.Index(i).Uint() + stringSlice = append(stringSlice, strconv.FormatUint(uival, 10)) + } + } else if kindOfSubtype == reflect.String { + for i := 0; i < sliceLen; i++ { + var str = valueOfV.Index(i).String() + if !utf8.ValidString(str) { + return "", ErrNotUTF8 + } + stringSlice = append(stringSlice, escapeAndQuoteString(str)) + } + } else { + return "", ErrInvalidSliceValue + } + buf.WriteRune('(') + buf.WriteString(strings.Join(stringSlice, ",")) + buf.WriteRune(')') + } else { + return "", ErrInvalidValue + } + + curVal++ + } else { + return "", ErrArgumentMismatch + } + } + + if curVal != maxVals { + return "", ErrArgumentMismatch + } + + return buf.String(), nil +} diff --git a/interpolate_test.go b/interpolate_test.go new file mode 100644 index 0000000..0364d7c --- /dev/null +++ b/interpolate_test.go @@ -0,0 +1,104 @@ +package dbr + +import ( + "database/sql/driver" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestInterpolateNil(t *testing.T) { + args := []interface{}{nil} + + str, err := Interpolate("SELECT * FROM x WHERE a = ?", args) + assert.NoError(t, err) + assert.Equal(t, str, "SELECT * FROM x WHERE a = NULL") +} + +func TestInterpolateInts(t *testing.T) { + args := []interface{}{ + int(1), + int8(-2), + int16(3), + int32(4), + int64(5), + uint(6), + uint8(7), + uint16(8), + uint32(9), + uint64(10), + } + + str, err := Interpolate("SELECT * FROM x WHERE a = ? AND b = ? AND c = ? AND d = ? AND e = ? AND f = ? AND g = ? AND h = ? AND i = ? AND j = ?", args) + assert.NoError(t, err) + assert.Equal(t, str, "SELECT * FROM x WHERE a = 1 AND b = -2 AND c = 3 AND d = 4 AND e = 5 AND f = 6 AND g = 7 AND h = 8 AND i = 9 AND j = 10") +} + +func TestInterpolateBools(t *testing.T) { + args := []interface{}{true, false} + + str, err := Interpolate("SELECT * FROM x WHERE a = ? AND b = ?", args) + assert.NoError(t, err) + assert.Equal(t, str, "SELECT * FROM x WHERE a = 1 AND b = 0") +} + +func TestInterpolateFloats(t *testing.T) { + args := []interface{}{float32(0.15625), float64(3.14159)} + + str, err := Interpolate("SELECT * FROM x WHERE a = ? AND b = ?", args) + assert.NoError(t, err) + assert.Equal(t, str, "SELECT * FROM x WHERE a = 0.15625 AND b = 3.14159") +} + +func TestInterpolateStrings(t *testing.T) { + args := []interface{}{"hello", "\"hello's \\ world\" \n\r\x00\x1a"} + + str, err := Interpolate("SELECT * FROM x WHERE a = ? AND b = ?", args) + assert.NoError(t, err) + assert.Equal(t, str, "SELECT * FROM x WHERE a = 'hello' AND b = '\\\"hello\\'s \\\\ world\\\" \\n\\r\\x00\\x1a'") +} + +func TestInterpolateSlices(t *testing.T) { + args := []interface{}{[]int{1}, []int{1, 2, 3}, []uint32{5, 6, 7}, []string{"wat", "ok"}} + + str, err := Interpolate("SELECT * FROM x WHERE a = ? AND b = ? AND c = ? AND d = ?", args) + assert.NoError(t, err) + assert.Equal(t, str, "SELECT * FROM x WHERE a = (1) AND b = (1,2,3) AND c = (5,6,7) AND d = ('wat','ok')") +} + +type myString struct { + Present bool + Val string +} + +func (m myString) Value() (driver.Value, error) { + if m.Present { + return m.Val, nil + } else { + return nil, nil + } +} + +func TestIntepolatingValuers(t *testing.T) { + args := []interface{}{myString{true, "wat"}, myString{false, "fry"}} + + str, err := Interpolate("SELECT * FROM x WHERE a = ? AND b = ?", args) + assert.NoError(t, err) + assert.Equal(t, str, "SELECT * FROM x WHERE a = 'wat' AND b = NULL") +} + +func TestInterpolateErrors(t *testing.T) { + _, err := Interpolate("SELECT * FROM x WHERE a = ? AND b = ?", []interface{}{1}) + assert.Equal(t, err, ErrArgumentMismatch) + + _, err = Interpolate("SELECT * FROM x WHERE", []interface{}{1}) + assert.Equal(t, err, ErrArgumentMismatch) + + _, err = Interpolate("SELECT * FROM x WHERE a = ?", []interface{}{string([]byte{0x34, 0xFF, 0xFE})}) + assert.Equal(t, err, ErrNotUTF8) + + _, err = Interpolate("SELECT * FROM x WHERE a = ?", []interface{}{struct{}{}}) + assert.Equal(t, err, ErrInvalidValue) + + _, err = Interpolate("SELECT * FROM x WHERE a = ?", []interface{}{[]struct{}{struct{}{}, struct{}{}}}) + assert.Equal(t, err, ErrInvalidSliceValue) +} diff --git a/now.go b/now.go new file mode 100644 index 0000000..e2aca6a --- /dev/null +++ b/now.go @@ -0,0 +1,18 @@ +package dbr + +import ( + "database/sql/driver" + "time" +) + +type nowSentinel struct{} + +// Now is a value that serializes to the curren time +var Now = nowSentinel{} +var timeFormat = "2006-01-02 15:04:05" + +// Value implements a valuer for compatibility +func (n nowSentinel) Value() (driver.Value, error) { + now := time.Now().UTC().Format(timeFormat) + return now, nil +} diff --git a/quote.go b/quote.go new file mode 100644 index 0000000..45664f7 --- /dev/null +++ b/quote.go @@ -0,0 +1,22 @@ +package dbr + +import ( + "bytes" +) + +// Quoter is the quoter to use for quoting text; use Mysql quoting by default +var Quoter = MysqlQuoter{} + +// Interface for driver-swappable quoting +type quoter interface { + writeQuotedColumn() +} + +// MysqlQuoter implements Mysql-specific quoting +type MysqlQuoter struct{} + +func (q MysqlQuoter) writeQuotedColumn(column string, sql *bytes.Buffer) { + sql.WriteRune('`') + sql.WriteString(column) + sql.WriteRune('`') +} diff --git a/select.go b/select.go new file mode 100644 index 0000000..ea978c5 --- /dev/null +++ b/select.go @@ -0,0 +1,210 @@ +package dbr + +import ( + "bytes" + "fmt" +) + +// SelectBuilder contains the clauses for a SELECT statement +type SelectBuilder struct { + *Session + Runner + + RawFullSql string + RawArguments []interface{} + + IsDistinct bool + Columns []string + FromTable string + WhereFragments []*whereFragment + GroupBys []string + HavingFragments []*whereFragment + OrderBys []string + LimitCount uint64 + LimitValid bool + OffsetCount uint64 + OffsetValid bool +} + +// Select creates a new SelectBuilder that select that given columns +func (sess *Session) Select(cols ...string) *SelectBuilder { + return &SelectBuilder{ + Session: sess, + Runner: sess.cxn.Db, + Columns: cols, + } +} + +// SelectBySql creates a new SelectBuilder for the given SQL string and arguments +func (sess *Session) SelectBySql(sql string, args ...interface{}) *SelectBuilder { + return &SelectBuilder{ + Session: sess, + Runner: sess.cxn.Db, + RawFullSql: sql, + RawArguments: args, + } +} + +// Select creates a new SelectBuilder that select that given columns bound to the transaction +func (tx *Tx) Select(cols ...string) *SelectBuilder { + return &SelectBuilder{ + Session: tx.Session, + Runner: tx.Tx, + Columns: cols, + } +} + +// SelectBySql creates a new SelectBuilder for the given SQL string and arguments bound to the transaction +func (tx *Tx) SelectBySql(sql string, args ...interface{}) *SelectBuilder { + return &SelectBuilder{ + Session: tx.Session, + Runner: tx.Tx, + RawFullSql: sql, + RawArguments: args, + } +} + +// Distinct marks the statement as a DISTINCT SELECT +func (b *SelectBuilder) Distinct() *SelectBuilder { + b.IsDistinct = true + return b +} + +// From sets the table to SELECT FROM +func (b *SelectBuilder) From(from string) *SelectBuilder { + b.FromTable = from + return b +} + +// Where appends a WHERE clause to the statement for the given string and args +// or map of column/value pairs +func (b *SelectBuilder) Where(whereSqlOrMap interface{}, args ...interface{}) *SelectBuilder { + b.WhereFragments = append(b.WhereFragments, newWhereFragment(whereSqlOrMap, args)) + return b +} + +// GroupBy appends a column to group the statement +func (b *SelectBuilder) GroupBy(group string) *SelectBuilder { + b.GroupBys = append(b.GroupBys, group) + return b +} + +// Having appends a HAVING clause to the statement +func (b *SelectBuilder) Having(whereSqlOrMap interface{}, args ...interface{}) *SelectBuilder { + b.HavingFragments = append(b.HavingFragments, newWhereFragment(whereSqlOrMap, args)) + return b +} + +// OrderBy appends a column to ORDER the statement by +func (b *SelectBuilder) OrderBy(ord string) *SelectBuilder { + b.OrderBys = append(b.OrderBys, ord) + return b +} + +// OrderDir appends a column to ORDER the statement by with a given direction +func (b *SelectBuilder) OrderDir(ord string, isAsc bool) *SelectBuilder { + if isAsc { + b.OrderBys = append(b.OrderBys, ord+" ASC") + } else { + b.OrderBys = append(b.OrderBys, ord+" DESC") + } + return b +} + +// Limit sets a limit for the statement; overrides any existing LIMIT +func (b *SelectBuilder) Limit(limit uint64) *SelectBuilder { + b.LimitCount = limit + b.LimitValid = true + return b +} + +// Offset sets an offset for the statement; overrides any existing OFFSET +func (b *SelectBuilder) Offset(offset uint64) *SelectBuilder { + b.OffsetCount = offset + b.OffsetValid = true + return b +} + +// Paginate sets LIMIT/OFFSET for the statement based on the given page/perPage +// Assumes page/perPage are valid. Page and perPage must be >= 1 +func (b *SelectBuilder) Paginate(page, perPage uint64) *SelectBuilder { + b.Limit(perPage) + b.Offset((page - 1) * perPage) + return b +} + +// ToSql serialized the SelectBuilder to a SQL string +// It returns the string with placeholders and a slice of query arguments +func (b *SelectBuilder) ToSql() (string, []interface{}) { + if b.RawFullSql != "" { + return b.RawFullSql, b.RawArguments + } + + if len(b.Columns) == 0 { + panic("no columns specified") + } + if len(b.FromTable) == 0 { + panic("no table specified") + } + + var sql bytes.Buffer + var args []interface{} + + sql.WriteString("SELECT ") + + if b.IsDistinct { + sql.WriteString("DISTINCT ") + } + + for i, s := range b.Columns { + if i > 0 { + sql.WriteString(", ") + } + sql.WriteString(s) + } + + sql.WriteString(" FROM ") + sql.WriteString(b.FromTable) + + if len(b.WhereFragments) > 0 { + sql.WriteString(" WHERE ") + writeWhereFragmentsToSql(b.WhereFragments, &sql, &args) + } + + if len(b.GroupBys) > 0 { + sql.WriteString(" GROUP BY ") + for i, s := range b.GroupBys { + if i > 0 { + sql.WriteString(", ") + } + sql.WriteString(s) + } + } + + if len(b.HavingFragments) > 0 { + sql.WriteString(" HAVING ") + writeWhereFragmentsToSql(b.HavingFragments, &sql, &args) + } + + if len(b.OrderBys) > 0 { + sql.WriteString(" ORDER BY ") + for i, s := range b.OrderBys { + if i > 0 { + sql.WriteString(", ") + } + sql.WriteString(s) + } + } + + if b.LimitValid { + sql.WriteString(" LIMIT ") + fmt.Fprint(&sql, b.LimitCount) + } + + if b.OffsetValid { + sql.WriteString(" OFFSET ") + fmt.Fprint(&sql, b.OffsetCount) + } + + return sql.String(), args +} diff --git a/select_load.go b/select_load.go new file mode 100644 index 0000000..eed4baf --- /dev/null +++ b/select_load.go @@ -0,0 +1,310 @@ +package dbr + +import ( + "reflect" + "time" +) + +// Unvetted thots: +// Given a query and given a structure (field list), there's 2 sets of fields. +// Take the intersection. We can fill those in. great. +// For fields in the structure that aren't in the query, we'll let that slide if db:"-" +// For fields in the structure that aren't in the query but without db:"-", return error +// For fields in the query that aren't in the structure, we'll ignore them. + +// LoadStructs executes the SelectBuilder and loads the resulting data into a slice of structs +// dest must be a pointer to a slice of pointers to structs +// Returns the number of items found (which is not necessarily the # of items set) +func (b *SelectBuilder) LoadStructs(dest interface{}) (int, error) { + // + // Validate the dest, and extract the reflection values we need. + // + + // This must be a pointer to a slice + valueOfDest := reflect.ValueOf(dest) + kindOfDest := valueOfDest.Kind() + + if kindOfDest != reflect.Ptr { + panic("invalid type passed to LoadStructs. Need a pointer to a slice") + } + + // This must a slice + valueOfDest = reflect.Indirect(valueOfDest) + kindOfDest = valueOfDest.Kind() + + if kindOfDest != reflect.Slice { + panic("invalid type passed to LoadStructs. Need a pointer to a slice") + } + + // The slice elements must be pointers to structures + recordType := valueOfDest.Type().Elem() + if recordType.Kind() != reflect.Ptr { + panic("Elements need to be pointers to structures") + } + + recordType = recordType.Elem() + if recordType.Kind() != reflect.Struct { + panic("Elements need to be pointers to structures") + } + + // + // Get full SQL + // + fullSql, err := Interpolate(b.ToSql()) + if err != nil { + return 0, b.EventErr("dbr.select.load_all.interpolate", err) + } + + numberOfRowsReturned := 0 + + // Start the timer: + startTime := time.Now() + defer func() { b.TimingKv("dbr.select", time.Since(startTime).Nanoseconds(), kvs{"sql": fullSql}) }() + + // Run the query: + rows, err := b.Runner.Query(fullSql) + if err != nil { + return 0, b.EventErrKv("dbr.select.load_all.query", err, kvs{"sql": fullSql}) + } + defer rows.Close() + + // Get the columns returned + columns, err := rows.Columns() + if err != nil { + return numberOfRowsReturned, b.EventErrKv("dbr.select.load_one.rows.Columns", err, kvs{"sql": fullSql}) + } + + // Create a map of this result set to the struct fields + fieldMap, err := b.calculateFieldMap(recordType, columns, false) + if err != nil { + return numberOfRowsReturned, b.EventErrKv("dbr.select.load_all.calculateFieldMap", err, kvs{"sql": fullSql}) + } + + // Build a 'holder', which is an []interface{}. Each value will be the set to address of the field corresponding to our newly made records: + holder := make([]interface{}, len(fieldMap)) + + // Iterate over rows and scan their data into the structs + sliceValue := valueOfDest + for rows.Next() { + // Create a new record to store our row: + pointerToNewRecord := reflect.New(recordType) + newRecord := reflect.Indirect(pointerToNewRecord) + + // Prepare the holder for this record + scannable, err := b.prepareHolderFor(newRecord, fieldMap, holder) + if err != nil { + return numberOfRowsReturned, b.EventErrKv("dbr.select.load_all.holderFor", err, kvs{"sql": fullSql}) + } + + // Load up our new structure with the row's values + err = rows.Scan(scannable...) + if err != nil { + return numberOfRowsReturned, b.EventErrKv("dbr.select.load_all.scan", err, kvs{"sql": fullSql}) + } + + // Append our new record to the slice: + sliceValue = reflect.Append(sliceValue, pointerToNewRecord) + + numberOfRowsReturned++ + } + valueOfDest.Set(sliceValue) + + // Check for errors at the end. Supposedly these are error that can happen during iteration. + if err = rows.Err(); err != nil { + return numberOfRowsReturned, b.EventErrKv("dbr.select.load_all.rows_err", err, kvs{"sql": fullSql}) + } + + return numberOfRowsReturned, nil +} + +// LoadStruct executes the SelectBuilder and loads the resulting data into a struct +// dest must be a pointer to a struct +// Returns ErrNotFound if nothing was found +func (b *SelectBuilder) LoadStruct(dest interface{}) error { + // + // Validate the dest, and extract the reflection values we need. + // + valueOfDest := reflect.ValueOf(dest) + indirectOfDest := reflect.Indirect(valueOfDest) + kindOfDest := valueOfDest.Kind() + + if kindOfDest != reflect.Ptr || indirectOfDest.Kind() != reflect.Struct { + panic("you need to pass in the address of a struct") + } + + recordType := indirectOfDest.Type() + + // + // Get full SQL + // + fullSql, err := Interpolate(b.ToSql()) + if err != nil { + return err + } + + // Start the timer: + startTime := time.Now() + defer func() { b.TimingKv("dbr.select", time.Since(startTime).Nanoseconds(), kvs{"sql": fullSql}) }() + + // Run the query: + rows, err := b.Runner.Query(fullSql) + if err != nil { + return b.EventErrKv("dbr.select.load_one.query", err, kvs{"sql": fullSql}) + } + defer rows.Close() + + // Get the columns of this result set + columns, err := rows.Columns() + if err != nil { + return b.EventErrKv("dbr.select.load_one.rows.Columns", err, kvs{"sql": fullSql}) + } + + // Create a map of this result set to the struct columns + fieldMap, err := b.calculateFieldMap(recordType, columns, false) + if err != nil { + return b.EventErrKv("dbr.select.load_one.calculateFieldMap", err, kvs{"sql": fullSql}) + } + + // Build a 'holder', which is an []interface{}. Each value will be the set to address of the field corresponding to our newly made records: + holder := make([]interface{}, len(fieldMap)) + + if rows.Next() { + // Build a 'holder', which is an []interface{}. Each value will be the address of the field corresponding to our newly made record: + scannable, err := b.prepareHolderFor(indirectOfDest, fieldMap, holder) + if err != nil { + return b.EventErrKv("dbr.select.load_one.holderFor", err, kvs{"sql": fullSql}) + } + + // Load up our new structure with the row's values + err = rows.Scan(scannable...) + if err != nil { + return b.EventErrKv("dbr.select.load_one.scan", err, kvs{"sql": fullSql}) + } + return nil + } + + if err := rows.Err(); err != nil { + return b.EventErrKv("dbr.select.load_one.rows_err", err, kvs{"sql": fullSql}) + } + + return ErrNotFound +} + +// LoadValues executes the SelectBuilder and loads the resulting data into a slice of primitive values +// Returns ErrNotFound if no value was found, and it was therefore not set. +func (b *SelectBuilder) LoadValues(dest interface{}) (int, error) { + // Validate the dest and reflection values we need + + // This must be a pointer to a slice + valueOfDest := reflect.ValueOf(dest) + kindOfDest := valueOfDest.Kind() + + if kindOfDest != reflect.Ptr { + panic("invalid type passed to LoadValues. Need a pointer to a slice") + } + + // This must a slice + valueOfDest = reflect.Indirect(valueOfDest) + kindOfDest = valueOfDest.Kind() + + if kindOfDest != reflect.Slice { + panic("invalid type passed to LoadValues. Need a pointer to a slice") + } + + recordType := valueOfDest.Type().Elem() + + recordTypeIsPtr := recordType.Kind() == reflect.Ptr + if recordTypeIsPtr { + reflect.ValueOf(dest) + } + + // + // Get full SQL + // + fullSql, err := Interpolate(b.ToSql()) + if err != nil { + return 0, err + } + + numberOfRowsReturned := 0 + + // Start the timer: + startTime := time.Now() + defer func() { b.TimingKv("dbr.select", time.Since(startTime).Nanoseconds(), kvs{"sql": fullSql}) }() + + // Run the query: + rows, err := b.Runner.Query(fullSql) + if err != nil { + return numberOfRowsReturned, b.EventErrKv("dbr.select.load_all_values.query", err, kvs{"sql": fullSql}) + } + defer rows.Close() + + sliceValue := valueOfDest + for rows.Next() { + // Create a new value to store our row: + pointerToNewValue := reflect.New(recordType) + newValue := reflect.Indirect(pointerToNewValue) + + err = rows.Scan(pointerToNewValue.Interface()) + if err != nil { + return numberOfRowsReturned, b.EventErrKv("dbr.select.load_all_values.scan", err, kvs{"sql": fullSql}) + } + + // Append our new value to the slice: + sliceValue = reflect.Append(sliceValue, newValue) + + numberOfRowsReturned++ + } + valueOfDest.Set(sliceValue) + + if err := rows.Err(); err != nil { + return numberOfRowsReturned, b.EventErrKv("dbr.select.load_all_values.rows_err", err, kvs{"sql": fullSql}) + } + + return numberOfRowsReturned, nil +} + +// LoadValue executes the SelectBuilder and loads the resulting data into a primitive value +// Returns ErrNotFound if no value was found, and it was therefore not set. +func (b *SelectBuilder) LoadValue(dest interface{}) error { + // Validate the dest + valueOfDest := reflect.ValueOf(dest) + kindOfDest := valueOfDest.Kind() + + if kindOfDest != reflect.Ptr { + panic("Destination must be a pointer") + } + + // + // Get full SQL + // + fullSql, err := Interpolate(b.ToSql()) + if err != nil { + return err + } + + // Start the timer: + startTime := time.Now() + defer func() { b.TimingKv("dbr.select", time.Since(startTime).Nanoseconds(), kvs{"sql": fullSql}) }() + + // Run the query: + rows, err := b.Runner.Query(fullSql) + if err != nil { + return b.EventErrKv("dbr.select.load_value.query", err, kvs{"sql": fullSql}) + } + defer rows.Close() + + if rows.Next() { + err = rows.Scan(dest) + if err != nil { + return b.EventErrKv("dbr.select.load_value.scan", err, kvs{"sql": fullSql}) + } + return nil + } + + if err := rows.Err(); err != nil { + return b.EventErrKv("dbr.select.load_value.rows_err", err, kvs{"sql": fullSql}) + } + + return ErrNotFound +} diff --git a/select_return.go b/select_return.go new file mode 100644 index 0000000..27e1594 --- /dev/null +++ b/select_return.go @@ -0,0 +1,66 @@ +package dbr + +// +// These are a set of helpers that just call LoadValue and return the value. +// They return (_, ErrNotFound) if nothing was found. +// + +// The inclusion of these helpers in the package is not an obvious choice: +// Benefits: +// - slight increase in code clarity/conciseness b/c you can use ":=" to define the variable +// +// count, err := d.Select("COUNT(*)").From("users").Where("x = ?", x).ReturnInt64() +// +// vs +// +// var count int64 +// err := d.Select("COUNT(*)").From("users").Where("x = ?", x).LoadValue(&count) +// +// Downsides: +// - very small increase in code cost, although it's not complex code +// - increase in conceptual model / API footprint when presenting the package to new users +// - no functionality that you can't achieve calling .LoadValue directly. +// - There's a lot of possible types. Do we want to include ALL of them? u?int{8,16,32,64}?, strings, null varieties, etc. +// - Let's just do the common, non-null varieties. + +// ReturnInt64 executes the SelectBuilder and returns the value as an int64 +func (b *SelectBuilder) ReturnInt64() (int64, error) { + var v int64 + err := b.LoadValue(&v) + return v, err +} + +// ReturnInt64s executes the SelectBuilder and returns the value as a slice of int64s +func (b *SelectBuilder) ReturnInt64s() ([]int64, error) { + var v []int64 + _, err := b.LoadValues(&v) + return v, err +} + +// ReturnUint64 executes the SelectBuilder and returns the value as an uint64 +func (b *SelectBuilder) ReturnUint64() (uint64, error) { + var v uint64 + err := b.LoadValue(&v) + return v, err +} + +// ReturnUint64s executes the SelectBuilder and returns the value as a slice of uint64s +func (b *SelectBuilder) ReturnUint64s() ([]uint64, error) { + var v []uint64 + _, err := b.LoadValues(&v) + return v, err +} + +// ReturnString executes the SelectBuilder and returns the value as a string +func (b *SelectBuilder) ReturnString() (string, error) { + var v string + err := b.LoadValue(&v) + return v, err +} + +// ReturnStrings executes the SelectBuilder and returns the value as a slice of strings +func (b *SelectBuilder) ReturnStrings() ([]string, error) { + var v []string + _, err := b.LoadValues(&v) + return v, err +} diff --git a/select_test.go b/select_test.go new file mode 100644 index 0000000..84c21b0 --- /dev/null +++ b/select_test.go @@ -0,0 +1,337 @@ +package dbr + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func BenchmarkSelectBasicSql(b *testing.B) { + s := createFakeSession() + + // Do some allocations outside the loop so they don't affect the results + argEq := Eq{"a": []int{1, 2, 3}} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Select("something_id", "user_id", "other"). + From("some_table"). + Where("d = ? OR e = ?", 1, "wat"). + Where(argEq). + OrderDir("id", false). + Paginate(1, 20). + ToSql() + } +} + +func BenchmarkSelectFullSql(b *testing.B) { + s := createFakeSession() + + // Do some allocations outside the loop so they don't affect the results + argEq1 := Eq{"f": 2, "x": "hi"} + argEq2 := map[string]interface{}{"g": 3} + argEq3 := Eq{"h": []int{1, 2, 3}} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Select("a", "b", "z", "y", "x"). + Distinct(). + From("c"). + Where("d = ? OR e = ?", 1, "wat"). + Where(argEq1). + Where(argEq2). + Where(argEq3). + GroupBy("i"). + GroupBy("ii"). + GroupBy("iii"). + Having("j = k"). + Having("jj = ?", 1). + Having("jjj = ?", 2). + OrderBy("l"). + OrderBy("l"). + OrderBy("l"). + Limit(7). + Offset(8). + ToSql() + } +} + +func TestSelectBasicToSql(t *testing.T) { + s := createFakeSession() + + sql, args := s.Select("a", "b").From("c").Where("id = ?", 1).ToSql() + + assert.Equal(t, sql, "SELECT a, b FROM c WHERE (id = ?)") + assert.Equal(t, args, []interface{}{1}) +} + +func TestSelectFullToSql(t *testing.T) { + s := createFakeSession() + + sql, args := s.Select("a", "b"). + Distinct(). + From("c"). + Where("d = ? OR e = ?", 1, "wat"). + Where(Eq{"f": 2}). + Where(map[string]interface{}{"g": 3}). + Where(Eq{"h": []int{4, 5, 6}}). + GroupBy("i"). + Having("j = k"). + OrderBy("l"). + Limit(7). + Offset(8). + ToSql() + + assert.Equal(t, sql, "SELECT DISTINCT a, b FROM c WHERE (d = ? OR e = ?) AND (`f` = ?) AND (`g` = ?) AND (`h` IN ?) GROUP BY i HAVING (j = k) ORDER BY l LIMIT 7 OFFSET 8") + assert.Equal(t, args, []interface{}{1, "wat", 2, 3, []int{4, 5, 6}}) +} + +func TestSelectPaginateOrderDirToSql(t *testing.T) { + s := createFakeSession() + + sql, args := s.Select("a", "b"). + From("c"). + Where("d = ?", 1). + Paginate(1, 20). + OrderDir("id", false). + ToSql() + + assert.Equal(t, sql, "SELECT a, b FROM c WHERE (d = ?) ORDER BY id DESC LIMIT 20 OFFSET 0") + assert.Equal(t, args, []interface{}{1}) + + sql, args = s.Select("a", "b"). + From("c"). + Where("d = ?", 1). + Paginate(3, 30). + OrderDir("id", true). + ToSql() + + assert.Equal(t, sql, "SELECT a, b FROM c WHERE (d = ?) ORDER BY id ASC LIMIT 30 OFFSET 60") + assert.Equal(t, args, []interface{}{1}) +} + +func TestSelectNoWhereSql(t *testing.T) { + s := createFakeSession() + + sql, args := s.Select("a", "b").From("c").ToSql() + + assert.Equal(t, sql, "SELECT a, b FROM c") + assert.Equal(t, args, []interface{}(nil)) +} + +func TestSelectMultiHavingSql(t *testing.T) { + s := createFakeSession() + + sql, args := s.Select("a", "b").From("c").Where("p = ?", 1).GroupBy("z").Having("z = ?", 2).Having("y = ?", 3).ToSql() + + assert.Equal(t, sql, "SELECT a, b FROM c WHERE (p = ?) GROUP BY z HAVING (z = ?) AND (y = ?)") + assert.Equal(t, args, []interface{}{1, 2, 3}) +} + +func TestSelectMultiOrderSql(t *testing.T) { + s := createFakeSession() + + sql, args := s.Select("a", "b").From("c").OrderBy("name ASC").OrderBy("id DESC").ToSql() + + assert.Equal(t, sql, "SELECT a, b FROM c ORDER BY name ASC, id DESC") + assert.Equal(t, args, []interface{}(nil)) +} + +func TestSelectWhereMapSql(t *testing.T) { + s := createFakeSession() + + sql, args := s.Select("a").From("b").Where(map[string]interface{}{"a": 1}).ToSql() + assert.Equal(t, sql, "SELECT a FROM b WHERE (`a` = ?)") + assert.Equal(t, args, []interface{}{1}) + + sql, args = s.Select("a").From("b").Where(map[string]interface{}{"a": 1, "b": true}).ToSql() + if sql == "SELECT a FROM b WHERE (`a` = ?) AND (`b` = ?)" { + assert.Equal(t, args, []interface{}{1, true}) + } else { + assert.Equal(t, sql, "SELECT a FROM b WHERE (`b` = ?) AND (`a` = ?)") + assert.Equal(t, args, []interface{}{true, 1}) + } + + sql, args = s.Select("a").From("b").Where(map[string]interface{}{"a": nil}).ToSql() + assert.Equal(t, sql, "SELECT a FROM b WHERE (`a` IS NULL)") + assert.Equal(t, args, []interface{}(nil)) + + sql, args = s.Select("a").From("b").Where(map[string]interface{}{"a": []int{1, 2, 3}}).ToSql() + assert.Equal(t, sql, "SELECT a FROM b WHERE (`a` IN ?)") + assert.Equal(t, args, []interface{}{[]int{1, 2, 3}}) + + sql, args = s.Select("a").From("b").Where(map[string]interface{}{"a": []int{1}}).ToSql() + assert.Equal(t, sql, "SELECT a FROM b WHERE (`a` = ?)") + assert.Equal(t, args, []interface{}{1}) + + // NOTE: a has no valid values, we want a query that returns nothing + sql, args = s.Select("a").From("b").Where(map[string]interface{}{"a": []int{}}).ToSql() + assert.Equal(t, sql, "SELECT a FROM b WHERE (1=0)") + assert.Equal(t, args, []interface{}(nil)) + + var aval []int + sql, args = s.Select("a").From("b").Where(map[string]interface{}{"a": aval}).ToSql() + assert.Equal(t, sql, "SELECT a FROM b WHERE (`a` IS NULL)") + assert.Equal(t, args, []interface{}(nil)) + + sql, args = s.Select("a").From("b"). + Where(map[string]interface{}{"a": []int(nil)}). + Where(map[string]interface{}{"b": false}). + ToSql() + assert.Equal(t, sql, "SELECT a FROM b WHERE (`a` IS NULL) AND (`b` = ?)") + assert.Equal(t, args, []interface{}{false}) +} + +func TestSelectWhereEqSql(t *testing.T) { + s := createFakeSession() + + sql, args := s.Select("a").From("b").Where(Eq{"a": 1, "b": []int64{1, 2, 3}}).ToSql() + if sql == "SELECT a FROM b WHERE (`a` = ?) AND (`b` IN ?)" { + assert.Equal(t, args, []interface{}{1, []int64{1, 2, 3}}) + } else { + assert.Equal(t, sql, "SELECT a FROM b WHERE (`b` IN ?) AND (`a` = ?)") + assert.Equal(t, args, []interface{}{[]int64{1, 2, 3}, 1}) + } +} + +func TestSelectBySql(t *testing.T) { + s := createFakeSession() + + sql, args := s.SelectBySql("SELECT * FROM users WHERE x = 1").ToSql() + assert.Equal(t, sql, "SELECT * FROM users WHERE x = 1") + assert.Equal(t, args, []interface{}(nil)) + + sql, args = s.SelectBySql("SELECT * FROM users WHERE x = ? AND y IN ?", 9, []int{5, 6, 7}).ToSql() + assert.Equal(t, sql, "SELECT * FROM users WHERE x = ? AND y IN ?") + assert.Equal(t, args, []interface{}{9, []int{5, 6, 7}}) + + // Doesn't fix shit if it's broken: + sql, args = s.SelectBySql("wat", 9, []int{5, 6, 7}).ToSql() + assert.Equal(t, sql, "wat") + assert.Equal(t, args, []interface{}{9, []int{5, 6, 7}}) +} + +func TestSelectVarieties(t *testing.T) { + s := createFakeSession() + + sql, _ := s.Select("id, name, email").From("users").ToSql() + sql2, _ := s.Select("id", "name", "email").From("users").ToSql() + assert.Equal(t, sql, sql2) +} + +func TestSelectLoadStructs(t *testing.T) { + s := createRealSessionWithFixtures() + + var people []*dbrPerson + count, err := s.Select("id", "name", "email").From("dbr_people").OrderBy("id ASC").LoadStructs(&people) + + assert.NoError(t, err) + assert.Equal(t, count, 2) + + assert.Equal(t, len(people), 2) + if len(people) == 2 { + // Make sure that the Ids are set. It's possible (maybe?) that different DBs set ids differently so + // don't assume they're 1 and 2. + assert.True(t, people[0].Id > 0) + assert.True(t, people[1].Id > people[0].Id) + + assert.Equal(t, people[0].Name, "Jonathan") + assert.True(t, people[0].Email.Valid) + assert.Equal(t, people[0].Email.String, "jonathan@uservoice.com") + assert.Equal(t, people[1].Name, "Dmitri") + assert.True(t, people[1].Email.Valid) + assert.Equal(t, people[1].Email.String, "zavorotni@jadius.com") + } + + // TODO: test map +} + +func TestSelectLoadStruct(t *testing.T) { + s := createRealSessionWithFixtures() + + // Found: + var person dbrPerson + err := s.Select("id", "name", "email").From("dbr_people").Where("email = ?", "jonathan@uservoice.com").LoadStruct(&person) + assert.NoError(t, err) + assert.True(t, person.Id > 0) + assert.Equal(t, person.Name, "Jonathan") + assert.True(t, person.Email.Valid) + assert.Equal(t, person.Email.String, "jonathan@uservoice.com") + + // Not found: + var person2 dbrPerson + err = s.Select("id", "name", "email").From("dbr_people").Where("email = ?", "dontexist@uservoice.com").LoadStruct(&person2) + assert.Equal(t, err, ErrNotFound) +} + +func TestSelectBySqlLoadStructs(t *testing.T) { + s := createRealSessionWithFixtures() + + var people []*dbrPerson + count, err := s.SelectBySql("SELECT name FROM dbr_people WHERE email IN ?", []string{"jonathan@uservoice.com"}).LoadStructs(&people) + + assert.NoError(t, err) + assert.Equal(t, count, 1) + if len(people) == 1 { + assert.Equal(t, people[0].Name, "Jonathan") + assert.Equal(t, people[0].Id, 0) // not set + assert.Equal(t, people[0].Email.Valid, false) // not set + assert.Equal(t, people[0].Email.String, "") // not set + } +} + +func TestSelectLoadValue(t *testing.T) { + s := createRealSessionWithFixtures() + + var name string + err := s.Select("name").From("dbr_people").Where("email = 'jonathan@uservoice.com'").LoadValue(&name) + + assert.NoError(t, err) + assert.Equal(t, name, "Jonathan") + + var id int64 + err = s.Select("id").From("dbr_people").Limit(1).LoadValue(&id) + + assert.NoError(t, err) + assert.True(t, id > 0) +} + +func TestSelectLoadValues(t *testing.T) { + s := createRealSessionWithFixtures() + + var names []string + count, err := s.Select("name").From("dbr_people").LoadValues(&names) + + assert.NoError(t, err) + assert.Equal(t, count, 2) + assert.Equal(t, names, []string{"Jonathan", "Dmitri"}) + + var ids []int64 + count, err = s.Select("id").From("dbr_people").Limit(1).LoadValues(&ids) + + assert.NoError(t, err) + assert.Equal(t, count, 1) + assert.Equal(t, ids, []int64{1}) +} + +func TestSelectReturn(t *testing.T) { + s := createRealSessionWithFixtures() + + name, err := s.Select("name").From("dbr_people").Where("email = 'jonathan@uservoice.com'").ReturnString() + assert.NoError(t, err) + assert.Equal(t, name, "Jonathan") + + count, err := s.Select("COUNT(*)").From("dbr_people").ReturnInt64() + assert.NoError(t, err) + assert.Equal(t, count, 2) + + names, err := s.Select("name").From("dbr_people").Where("email = 'jonathan@uservoice.com'").ReturnStrings() + assert.NoError(t, err) + assert.Equal(t, names, []string{"Jonathan"}) + + counts, err := s.Select("COUNT(*)").From("dbr_people").ReturnInt64s() + assert.NoError(t, err) + assert.Equal(t, counts, []int64{2}) +} + +// Series of tests that test mapping struct fields to columns diff --git a/struct_mapping.go b/struct_mapping.go new file mode 100644 index 0000000..fa4cb7c --- /dev/null +++ b/struct_mapping.go @@ -0,0 +1,109 @@ +package dbr + +import ( + "errors" + "fmt" + "reflect" +) + +var destDummy interface{} + +type fieldMapQueueElement struct { + Type reflect.Type + Idxs []int +} + +// recordType is the type of a structure +func (sess *Session) calculateFieldMap(recordType reflect.Type, columns []string, requireAllColumns bool) ([][]int, error) { + // each value is either the slice to get to the field via FieldByIndex(index []int) in the record, or nil if we don't want to map it to the structure. + lenColumns := len(columns) + fieldMap := make([][]int, lenColumns) + + for i, col := range columns { + fieldMap[i] = nil + + queue := []fieldMapQueueElement{fieldMapQueueElement{Type: recordType, Idxs: nil}} + + QueueLoop: + for len(queue) > 0 { + curEntry := queue[0] + queue = queue[1:] + + curType := curEntry.Type + curIdxs := curEntry.Idxs + lenFields := curType.NumField() + + for j := 0; j < lenFields; j++ { + fieldStruct := curType.Field(j) + + // Skip unexported field + if len(fieldStruct.PkgPath) != 0 { + continue + } + + name := fieldStruct.Tag.Get("db") + if name != "-" { + if name == "" { + name = NameMapping(fieldStruct.Name) + } + if name == col { + fieldMap[i] = append(curIdxs, j) + break QueueLoop + } + } + + if fieldStruct.Type.Kind() == reflect.Struct { + var idxs2 []int + copy(idxs2, curIdxs) + idxs2 = append(idxs2, j) + queue = append(queue, fieldMapQueueElement{Type: fieldStruct.Type, Idxs: idxs2}) + } + } + } + + if requireAllColumns && fieldMap[i] == nil { + return nil, errors.New(fmt.Sprint("couldn't find match for column ", col)) + } + } + + return fieldMap, nil +} + +func (sess *Session) prepareHolderFor(record reflect.Value, fieldMap [][]int, holder []interface{}) ([]interface{}, error) { + // Given a query and given a structure (field list), there's 2 sets of fields. + // Take the intersection. We can fill those in. great. + // For fields in the structure that aren't in the query, we'll let that slide if db:"-" + // For fields in the structure that aren't in the query but without db:"-", return error + // For fields in the query that aren't in the structure, we'll ignore them. + + for i, fieldIndex := range fieldMap { + if fieldIndex == nil { + holder[i] = &destDummy + } else { + field := record.FieldByIndex(fieldIndex) + holder[i] = field.Addr().Interface() + } + } + + return holder, nil +} + +func (sess *Session) valuesFor(recordType reflect.Type, record reflect.Value, columns []string) ([]interface{}, error) { + fieldMap, err := sess.calculateFieldMap(recordType, columns, true) + if err != nil { + fmt.Println("err: calc field map") + return nil, err + } + + values := make([]interface{}, len(columns)) + for i, fieldIndex := range fieldMap { + if fieldIndex == nil { + panic("wtf bro") + } else { + field := record.FieldByIndex(fieldIndex) + values[i] = field.Interface() + } + } + + return values, nil +} diff --git a/tags b/tags new file mode 100644 index 0000000..f815ea0 --- /dev/null +++ b/tags @@ -0,0 +1,270 @@ +!_TAG_FILE_FORMAT 2 /extended format; --format=1 will not append ;" to lines/ +!_TAG_FILE_SORTED 1 /0=unsorted, 1=sorted, 2=foldcase/ +!_TAG_OUTPUT_MODE u-ctags /u-ctags or e-ctags/ +!_TAG_PROGRAM_AUTHOR Universal Ctags Team // +!_TAG_PROGRAM_NAME Universal Ctags /Derived from Exuberant Ctags/ +!_TAG_PROGRAM_URL https://ctags.io/ /official site/ +!_TAG_PROGRAM_VERSION 0.0.0 /7fcdea3/ +Begin transaction.go /^func (sess *Session) Begin() (*Tx, error) {$/;" f +BenchmarkDeleteSql delete_test.go /^func BenchmarkDeleteSql(b *testing.B) {$/;" f +BenchmarkInsertRecordsSql insert_test.go /^func BenchmarkInsertRecordsSql(b *testing.B) {$/;" f +BenchmarkInsertValuesSql insert_test.go /^func BenchmarkInsertValuesSql(b *testing.B) {$/;" f +BenchmarkSelectBasicSql select_test.go /^func BenchmarkSelectBasicSql(b *testing.B) {$/;" f +BenchmarkSelectFullSql select_test.go /^func BenchmarkSelectFullSql(b *testing.B) {$/;" f +BenchmarkUpdateValueMapSql update_test.go /^func BenchmarkUpdateValueMapSql(b *testing.B) {$/;" f +BenchmarkUpdateValuesSql update_test.go /^func BenchmarkUpdateValuesSql(b *testing.B) {$/;" f +Cols insert.go /^ Cols []string$/;" m struct:InsertBuilder +Columns insert.go /^func (b *InsertBuilder) Columns(columns ...string) *InsertBuilder {$/;" f +Columns select.go /^ Columns []string$/;" m struct:SelectBuilder +Commit transaction.go /^func (tx *Tx) Commit() error {$/;" f +Condition where.go /^ Condition string$/;" m struct:whereFragment +Connection dbr.go /^type Connection struct {$/;" s +Db dbr.go /^ Db *sql.DB$/;" m struct:Connection +DeleteBuilder delete.go /^type DeleteBuilder struct {$/;" s +DeleteFrom delete.go /^func (sess *Session) DeleteFrom(from string) *DeleteBuilder {$/;" f +DeleteFrom delete.go /^func (tx *Tx) DeleteFrom(from string) *DeleteBuilder {$/;" f +Distinct select.go /^func (b *SelectBuilder) Distinct() *SelectBuilder {$/;" f +Email dbr_test.go /^ Email NullString$/;" m struct:dbrPerson +Eq where.go /^type Eq map[string]interface{}$/;" t +EqualityMap where.go /^ EqualityMap map[string]interface{}$/;" m struct:whereFragment +ErrArgumentMismatch errors.go /^ ErrArgumentMismatch = errors.New("mismatch between ? (placeholders) and arguments")$/;" v +ErrInvalidSliceLength errors.go /^ ErrInvalidSliceLength = errors.New("length of slice is 0. length must be >= 1")$/;" v +ErrInvalidSliceValue errors.go /^ ErrInvalidSliceValue = errors.New("trying to interpolate invalid slice value into query")$/;" v +ErrInvalidValue errors.go /^ ErrInvalidValue = errors.New("trying to interpolate invalid value into query")$/;" v +ErrNotFound errors.go /^ ErrNotFound = errors.New("not found")$/;" v +ErrNotUTF8 errors.go /^ ErrNotUTF8 = errors.New("invalid UTF-8")$/;" v +Event event.go /^func (n *NullEventReceiver) Event(eventName string) {$/;" f +EventErr event.go /^func (n *NullEventReceiver) EventErr(eventName string, err error) error {$/;" f +EventErrKv event.go /^func (n *NullEventReceiver) EventErrKv(eventName string, err error, kvs map[string]string) error/;" f +EventKv event.go /^func (n *NullEventReceiver) EventKv(eventName string, kvs map[string]string) {$/;" f +EventReceiver event.go /^type EventReceiver interface {$/;" i +Exec delete.go /^func (b *DeleteBuilder) Exec() (sql.Result, error) {$/;" f +Exec insert.go /^func (b *InsertBuilder) Exec() (sql.Result, error) {$/;" f +Exec update.go /^func (b *UpdateBuilder) Exec() (sql.Result, error) {$/;" f +Expr expr.go /^func Expr(sql string, values ...interface{}) *expr {$/;" f +From delete.go /^ From string$/;" m struct:DeleteBuilder +From select.go /^func (b *SelectBuilder) From(from string) *SelectBuilder {$/;" f +FromTable select.go /^ FromTable string$/;" m struct:SelectBuilder +GroupBy select.go /^func (b *SelectBuilder) GroupBy(group string) *SelectBuilder {$/;" f +GroupBys select.go /^ GroupBys []string$/;" m struct:SelectBuilder +Having select.go /^func (b *SelectBuilder) Having(whereSqlOrMap interface{}, args ...interface{}) *SelectBuilder {$/;" f +HavingFragments select.go /^ HavingFragments []*whereFragment$/;" m struct:SelectBuilder +Id dbr_test.go /^ Id int64$/;" m struct:dbrPerson +Idxs struct_mapping.go /^ Idxs []int$/;" m struct:fieldMapQueueElement +InsertBuilder insert.go /^type InsertBuilder struct {$/;" s +InsertInto insert.go /^func (sess *Session) InsertInto(into string) *InsertBuilder {$/;" f +InsertInto insert.go /^func (tx *Tx) InsertInto(into string) *InsertBuilder {$/;" f +Interpolate interpolate.go /^func Interpolate(sql string, vals []interface{}) (string, error) {$/;" f +Into insert.go /^ Into string$/;" m struct:InsertBuilder +IsDistinct select.go /^ IsDistinct bool$/;" m struct:SelectBuilder +Key dbr_test.go /^ Key NullString$/;" m struct:dbrPerson +Limit delete.go /^func (b *DeleteBuilder) Limit(limit uint64) *DeleteBuilder {$/;" f +Limit select.go /^func (b *SelectBuilder) Limit(limit uint64) *SelectBuilder {$/;" f +Limit update.go /^func (b *UpdateBuilder) Limit(limit uint64) *UpdateBuilder {$/;" f +LimitCount delete.go /^ LimitCount uint64$/;" m struct:DeleteBuilder +LimitCount select.go /^ LimitCount uint64$/;" m struct:SelectBuilder +LimitCount update.go /^ LimitCount uint64$/;" m struct:UpdateBuilder +LimitValid delete.go /^ LimitValid bool$/;" m struct:DeleteBuilder +LimitValid select.go /^ LimitValid bool$/;" m struct:SelectBuilder +LimitValid update.go /^ LimitValid bool$/;" m struct:UpdateBuilder +LoadStruct select_load.go /^func (b *SelectBuilder) LoadStruct(dest interface{}) error {$/;" f +LoadStructs select_load.go /^func (b *SelectBuilder) LoadStructs(dest interface{}) (int, error) {$/;" f +LoadValue select_load.go /^func (b *SelectBuilder) LoadValue(dest interface{}) error {$/;" f +LoadValues select_load.go /^func (b *SelectBuilder) LoadValues(dest interface{}) (int, error) {$/;" f +MarshalJSON types.go /^func (n *NullBool) MarshalJSON() ([]byte, error) {$/;" f +MarshalJSON types.go /^func (n *NullInt64) MarshalJSON() ([]byte, error) {$/;" f +MarshalJSON types.go /^func (n *NullString) MarshalJSON() ([]byte, error) {$/;" f +MarshalJSON types.go /^func (n *NullTime) MarshalJSON() ([]byte, error) {$/;" f +MysqlQuoter quote.go /^type MysqlQuoter struct{}$/;" s +Name dbr_test.go /^ Name string$/;" m struct:dbrPerson +NameMapping util.go /^var NameMapping = camelCaseToSnakeCase$/;" v +NewConnection dbr.go /^func NewConnection(db *sql.DB, log EventReceiver) *Connection {$/;" f +NewSession dbr.go /^func (cxn *Connection) NewSession(log EventReceiver) *Session {$/;" f +Now now.go /^var Now = nowSentinel{}$/;" v +NullBool types.go /^type NullBool struct {$/;" s +NullEventReceiver event.go /^type NullEventReceiver struct{}$/;" s +NullInt64 types.go /^type NullInt64 struct {$/;" s +NullString types.go /^type NullString struct {$/;" s +NullTime types.go /^type NullTime struct {$/;" s +Offset delete.go /^func (b *DeleteBuilder) Offset(offset uint64) *DeleteBuilder {$/;" f +Offset select.go /^func (b *SelectBuilder) Offset(offset uint64) *SelectBuilder {$/;" f +Offset update.go /^func (b *UpdateBuilder) Offset(offset uint64) *UpdateBuilder {$/;" f +OffsetCount delete.go /^ OffsetCount uint64$/;" m struct:DeleteBuilder +OffsetCount select.go /^ OffsetCount uint64$/;" m struct:SelectBuilder +OffsetCount update.go /^ OffsetCount uint64$/;" m struct:UpdateBuilder +OffsetValid delete.go /^ OffsetValid bool$/;" m struct:DeleteBuilder +OffsetValid select.go /^ OffsetValid bool$/;" m struct:SelectBuilder +OffsetValid update.go /^ OffsetValid bool$/;" m struct:UpdateBuilder +OrderBy delete.go /^func (b *DeleteBuilder) OrderBy(ord string) *DeleteBuilder {$/;" f +OrderBy select.go /^func (b *SelectBuilder) OrderBy(ord string) *SelectBuilder {$/;" f +OrderBy update.go /^func (b *UpdateBuilder) OrderBy(ord string) *UpdateBuilder {$/;" f +OrderBys delete.go /^ OrderBys []string$/;" m struct:DeleteBuilder +OrderBys select.go /^ OrderBys []string$/;" m struct:SelectBuilder +OrderBys update.go /^ OrderBys []string$/;" m struct:UpdateBuilder +OrderDir delete.go /^func (b *DeleteBuilder) OrderDir(ord string, isAsc bool) *DeleteBuilder {$/;" f +OrderDir select.go /^func (b *SelectBuilder) OrderDir(ord string, isAsc bool) *SelectBuilder {$/;" f +OrderDir update.go /^func (b *UpdateBuilder) OrderDir(ord string, isAsc bool) *UpdateBuilder {$/;" f +Other insert_test.go /^ Other bool$/;" m struct:someRecord +Paginate select.go /^func (b *SelectBuilder) Paginate(page, perPage uint64) *SelectBuilder {$/;" f +Pair insert.go /^func (b *InsertBuilder) Pair(column string, value interface{}) *InsertBuilder {$/;" f +Present interpolate_test.go /^ Present bool$/;" m struct:myString +Quoter quote.go /^var Quoter = MysqlQuoter{}$/;" v +RawArguments select.go /^ RawArguments []interface{}$/;" m struct:SelectBuilder +RawArguments update.go /^ RawArguments []interface{}$/;" m struct:UpdateBuilder +RawFullSql select.go /^ RawFullSql string$/;" m struct:SelectBuilder +RawFullSql update.go /^ RawFullSql string$/;" m struct:UpdateBuilder +Record insert.go /^func (b *InsertBuilder) Record(record interface{}) *InsertBuilder {$/;" f +Recs insert.go /^ Recs []interface{}$/;" m struct:InsertBuilder +ReturnInt64 select_return.go /^func (b *SelectBuilder) ReturnInt64() (int64, error) {$/;" f +ReturnInt64s select_return.go /^func (b *SelectBuilder) ReturnInt64s() ([]int64, error) {$/;" f +ReturnString select_return.go /^func (b *SelectBuilder) ReturnString() (string, error) {$/;" f +ReturnStrings select_return.go /^func (b *SelectBuilder) ReturnStrings() ([]string, error) {$/;" f +ReturnUint64 select_return.go /^func (b *SelectBuilder) ReturnUint64() (uint64, error) {$/;" f +ReturnUint64s select_return.go /^func (b *SelectBuilder) ReturnUint64s() ([]uint64, error) {$/;" f +Rollback transaction.go /^func (tx *Tx) Rollback() error {$/;" f +RollbackUnlessCommitted transaction.go /^func (tx *Tx) RollbackUnlessCommitted() {$/;" f +Runner dbr.go /^type Runner interface {$/;" i +Select select.go /^func (sess *Session) Select(cols ...string) *SelectBuilder {$/;" f +Select select.go /^func (tx *Tx) Select(cols ...string) *SelectBuilder {$/;" f +SelectBuilder select.go /^type SelectBuilder struct {$/;" s +SelectBySql select.go /^func (sess *Session) SelectBySql(sql string, args ...interface{}) *SelectBuilder {$/;" f +SelectBySql select.go /^func (tx *Tx) SelectBySql(sql string, args ...interface{}) *SelectBuilder {$/;" f +Session dbr.go /^type Session struct {$/;" s +SessionRunner dbr.go /^type SessionRunner interface {$/;" i +Set update.go /^func (b *UpdateBuilder) Set(column string, value interface{}) *UpdateBuilder {$/;" f +SetClauses update.go /^ SetClauses []*setClause$/;" m struct:UpdateBuilder +SetMap update.go /^func (b *UpdateBuilder) SetMap(clauses map[string]interface{}) *UpdateBuilder {$/;" f +SomethingId insert_test.go /^ SomethingId int$/;" m struct:someRecord +Sql expr.go /^ Sql string$/;" m struct:expr +Table update.go /^ Table string$/;" m struct:UpdateBuilder +TestDeleteAllToSql delete_test.go /^func TestDeleteAllToSql(t *testing.T) {$/;" f +TestDeleteReal delete_test.go /^func TestDeleteReal(t *testing.T) {$/;" f +TestDeleteSingleToSql delete_test.go /^func TestDeleteSingleToSql(t *testing.T) {$/;" f +TestDeleteTenStaringFromTwentyToSql delete_test.go /^func TestDeleteTenStaringFromTwentyToSql(t *testing.T) {$/;" f +TestInsertKeywordColumnName insert_test.go /^func TestInsertKeywordColumnName(t *testing.T) {$/;" f +TestInsertMultipleToSql insert_test.go /^func TestInsertMultipleToSql(t *testing.T) {$/;" f +TestInsertReal insert_test.go /^func TestInsertReal(t *testing.T) {$/;" f +TestInsertRecordsToSql insert_test.go /^func TestInsertRecordsToSql(t *testing.T) {$/;" f +TestInsertSingleToSql insert_test.go /^func TestInsertSingleToSql(t *testing.T) {$/;" f +TestIntepolatingValuers interpolate_test.go /^func TestIntepolatingValuers(t *testing.T) {$/;" f +TestInterpolateBools interpolate_test.go /^func TestInterpolateBools(t *testing.T) {$/;" f +TestInterpolateErrors interpolate_test.go /^func TestInterpolateErrors(t *testing.T) {$/;" f +TestInterpolateFloats interpolate_test.go /^func TestInterpolateFloats(t *testing.T) {$/;" f +TestInterpolateInts interpolate_test.go /^func TestInterpolateInts(t *testing.T) {$/;" f +TestInterpolateNil interpolate_test.go /^func TestInterpolateNil(t *testing.T) {$/;" f +TestInterpolateSlices interpolate_test.go /^func TestInterpolateSlices(t *testing.T) {$/;" f +TestInterpolateStrings interpolate_test.go /^func TestInterpolateStrings(t *testing.T) {$/;" f +TestSelectBasicToSql select_test.go /^func TestSelectBasicToSql(t *testing.T) {$/;" f +TestSelectBySql select_test.go /^func TestSelectBySql(t *testing.T) {$/;" f +TestSelectBySqlLoadStructs select_test.go /^func TestSelectBySqlLoadStructs(t *testing.T) {$/;" f +TestSelectFullToSql select_test.go /^func TestSelectFullToSql(t *testing.T) {$/;" f +TestSelectLoadStruct select_test.go /^func TestSelectLoadStruct(t *testing.T) {$/;" f +TestSelectLoadStructs select_test.go /^func TestSelectLoadStructs(t *testing.T) {$/;" f +TestSelectLoadValue select_test.go /^func TestSelectLoadValue(t *testing.T) {$/;" f +TestSelectLoadValues select_test.go /^func TestSelectLoadValues(t *testing.T) {$/;" f +TestSelectMultiHavingSql select_test.go /^func TestSelectMultiHavingSql(t *testing.T) {$/;" f +TestSelectMultiOrderSql select_test.go /^func TestSelectMultiOrderSql(t *testing.T) {$/;" f +TestSelectNoWhereSql select_test.go /^func TestSelectNoWhereSql(t *testing.T) {$/;" f +TestSelectPaginateOrderDirToSql select_test.go /^func TestSelectPaginateOrderDirToSql(t *testing.T) {$/;" f +TestSelectReturn select_test.go /^func TestSelectReturn(t *testing.T) {$/;" f +TestSelectVarieties select_test.go /^func TestSelectVarieties(t *testing.T) {$/;" f +TestSelectWhereEqSql select_test.go /^func TestSelectWhereEqSql(t *testing.T) {$/;" f +TestSelectWhereMapSql select_test.go /^func TestSelectWhereMapSql(t *testing.T) {$/;" f +TestTransactionReal transaction_test.go /^func TestTransactionReal(t *testing.T) {$/;" f +TestTransactionRollbackReal transaction_test.go /^func TestTransactionRollbackReal(t *testing.T) {$/;" f +TestUpdateAllToSql update_test.go /^func TestUpdateAllToSql(t *testing.T) {$/;" f +TestUpdateKeywordColumnName update_test.go /^func TestUpdateKeywordColumnName(t *testing.T) {$/;" f +TestUpdateReal update_test.go /^func TestUpdateReal(t *testing.T) {$/;" f +TestUpdateSetExprToSql update_test.go /^func TestUpdateSetExprToSql(t *testing.T) {$/;" f +TestUpdateSetMapToSql update_test.go /^func TestUpdateSetMapToSql(t *testing.T) {$/;" f +TestUpdateSingleToSql update_test.go /^func TestUpdateSingleToSql(t *testing.T) {$/;" f +TestUpdateTenStaringFromTwentyToSql update_test.go /^func TestUpdateTenStaringFromTwentyToSql(t *testing.T) {$/;" f +Timing event.go /^func (n *NullEventReceiver) Timing(eventName string, nanoseconds int64) {$/;" f +TimingKv event.go /^func (n *NullEventReceiver) TimingKv(eventName string, nanoseconds int64, kvs map[string]string)/;" f +ToSql delete.go /^func (b *DeleteBuilder) ToSql() (string, []interface{}) {$/;" f +ToSql insert.go /^func (b *InsertBuilder) ToSql() (string, []interface{}) {$/;" f +ToSql select.go /^func (b *SelectBuilder) ToSql() (string, []interface{}) {$/;" f +ToSql update.go /^func (b *UpdateBuilder) ToSql() (string, []interface{}) {$/;" f +Tx transaction.go /^type Tx struct {$/;" s +Type struct_mapping.go /^ Type reflect.Type$/;" m struct:fieldMapQueueElement +Update update.go /^func (sess *Session) Update(table string) *UpdateBuilder {$/;" f +Update update.go /^func (tx *Tx) Update(table string) *UpdateBuilder {$/;" f +UpdateBuilder update.go /^type UpdateBuilder struct {$/;" s +UpdateBySql update.go /^func (sess *Session) UpdateBySql(sql string, args ...interface{}) *UpdateBuilder {$/;" f +UpdateBySql update.go /^func (tx *Tx) UpdateBySql(sql string, args ...interface{}) *UpdateBuilder {$/;" f +UserId insert_test.go /^ UserId int64$/;" m struct:someRecord +Val interpolate_test.go /^ Val string$/;" m struct:myString +Vals insert.go /^ Vals [][]interface{}$/;" m struct:InsertBuilder +Value interpolate_test.go /^func (m myString) Value() (driver.Value, error) {$/;" f +Value now.go /^func (n nowSentinel) Value() (driver.Value, error) {$/;" f +Values expr.go /^ Values []interface{}$/;" m struct:expr +Values insert.go /^func (b *InsertBuilder) Values(vals ...interface{}) *InsertBuilder {$/;" f +Values where.go /^ Values []interface{}$/;" m struct:whereFragment +Where delete.go /^func (b *DeleteBuilder) Where(whereSqlOrMap interface{}, args ...interface{}) *DeleteBuilder {$/;" f +Where select.go /^func (b *SelectBuilder) Where(whereSqlOrMap interface{}, args ...interface{}) *SelectBuilder {$/;" f +Where update.go /^func (b *UpdateBuilder) Where(whereSqlOrMap interface{}, args ...interface{}) *UpdateBuilder {$/;" f +WhereFragments delete.go /^ WhereFragments []*whereFragment$/;" m struct:DeleteBuilder +WhereFragments select.go /^ WhereFragments []*whereFragment$/;" m struct:SelectBuilder +WhereFragments update.go /^ WhereFragments []*whereFragment$/;" m struct:UpdateBuilder +calculateFieldMap struct_mapping.go /^func (sess *Session) calculateFieldMap(recordType reflect.Type, columns []string, requireAllColu/;" f +camelCaseToSnakeCase util.go /^func camelCaseToSnakeCase(name string) string {$/;" f +column update.go /^ column string$/;" m struct:setClause +createFakeSession dbr_test.go /^func createFakeSession() *Session {$/;" f +createRealSession dbr_test.go /^func createRealSession() *Session {$/;" f +createRealSessionWithFixtures dbr_test.go /^func createRealSessionWithFixtures() *Session {$/;" f +cxn dbr.go /^ cxn *Connection$/;" m struct:Session +dbr dbr.go /^package dbr$/;" p +dbr dbr_test.go /^package dbr$/;" p +dbr delete.go /^package dbr$/;" p +dbr delete_test.go /^package dbr$/;" p +dbr errors.go /^package dbr$/;" p +dbr event.go /^package dbr$/;" p +dbr expr.go /^package dbr$/;" p +dbr insert.go /^package dbr$/;" p +dbr insert_test.go /^package dbr$/;" p +dbr interpolate.go /^package dbr$/;" p +dbr interpolate_test.go /^package dbr$/;" p +dbr now.go /^package dbr$/;" p +dbr quote.go /^package dbr$/;" p +dbr select.go /^package dbr$/;" p +dbr select_load.go /^package dbr$/;" p +dbr select_return.go /^package dbr$/;" p +dbr select_test.go /^package dbr$/;" p +dbr struct_mapping.go /^package dbr$/;" p +dbr transaction.go /^package dbr$/;" p +dbr transaction_test.go /^package dbr$/;" p +dbr types.go /^package dbr$/;" p +dbr update.go /^package dbr$/;" p +dbr update_test.go /^package dbr$/;" p +dbr util.go /^package dbr$/;" p +dbr where.go /^package dbr$/;" p +dbrPerson dbr_test.go /^type dbrPerson struct {$/;" s +destDummy struct_mapping.go /^var destDummy interface{}$/;" v +escapeAndQuoteString interpolate.go /^func escapeAndQuoteString(val string) string {$/;" f +expr expr.go /^type expr struct {$/;" s +fieldMapQueueElement struct_mapping.go /^type fieldMapQueueElement struct {$/;" s +installFixtures dbr_test.go /^func installFixtures(db *sql.DB) {$/;" f +isFloat interpolate.go /^func isFloat(k reflect.Kind) bool {$/;" f +isInt interpolate.go /^func isInt(k reflect.Kind) bool {$/;" f +isUint interpolate.go /^func isUint(k reflect.Kind) bool {$/;" f +kvs event.go /^type kvs map[string]string$/;" t +myString interpolate_test.go /^type myString struct {$/;" s +newWhereFragment where.go /^func newWhereFragment(whereSqlOrMap interface{}, args []interface{}) *whereFragment {$/;" f +nowSentinel now.go /^type nowSentinel struct{}$/;" s +nullReceiver event.go /^var nullReceiver = &NullEventReceiver{}$/;" v +nullString types.go /^var nullString = []byte("null")$/;" v +prepareHolderFor struct_mapping.go /^func (sess *Session) prepareHolderFor(record reflect.Value, fieldMap [][]int, holder []interface/;" f +quoter quote.go /^type quoter interface {$/;" i +realDb dbr_test.go /^func realDb() *sql.DB {$/;" f +setClause update.go /^type setClause struct {$/;" s +someRecord insert_test.go /^type someRecord struct {$/;" s +timeFormat now.go /^var timeFormat = "2006-01-02 15:04:05"$/;" v +typeOfTime interpolate.go /^var typeOfTime = reflect.TypeOf(time.Time{})$/;" v +validateInsertingBarack insert_test.go /^func validateInsertingBarack(t *testing.T, s *Session, res sql.Result, err error) {$/;" f +value update.go /^ value interface{}$/;" m struct:setClause +valuesFor struct_mapping.go /^func (sess *Session) valuesFor(recordType reflect.Type, record reflect.Value, columns []string) /;" f +whereFragment where.go /^type whereFragment struct {$/;" s +writeEqualityMapToSql where.go /^func writeEqualityMapToSql(eq map[string]interface{}, sql *bytes.Buffer, args *[]interface{}, an/;" f +writeQuotedColumn quote.go /^func (q MysqlQuoter) writeQuotedColumn(column string, sql *bytes.Buffer) {$/;" f +writeWhereCondition where.go /^func writeWhereCondition(sql *bytes.Buffer, k string, pred string, anyConditions bool) bool {$/;" f +writeWhereFragmentsToSql where.go /^func writeWhereFragmentsToSql(fragments []*whereFragment, sql *bytes.Buffer, args *[]interface{}/;" f diff --git a/thots.txt b/thots.txt new file mode 100644 index 0000000..4dae4cd --- /dev/null +++ b/thots.txt @@ -0,0 +1,11 @@ + +TODO: + - wire up insert to instrument, make a test for that + - any time we get an error do an EventErr + - add a perf test for query sql gen + - add a perf test for query sql with record mapping + + - Ideas: + - selectBuilder.Query() *sql.Rows, err // We might want to provide native sql.Rows support? Q: how does that impact metrics + - selectBuilder.Count() -- ignores Columns, Limit, Offset, Order and calculates COUNT(*) -- but: having could reference cols? + - i know rails/AR does some non-trivial stuff when you have joins. Things sometimes end up with DISTINCT in them. \ No newline at end of file diff --git a/transaction.go b/transaction.go new file mode 100644 index 0000000..aa9e284 --- /dev/null +++ b/transaction.go @@ -0,0 +1,62 @@ +package dbr + +import ( + "database/sql" +) + +// Tx is a transaction for the given Session +type Tx struct { + *Session + *sql.Tx +} + +// Begin creates a transaction for the given session +func (sess *Session) Begin() (*Tx, error) { + tx, err := sess.cxn.Db.Begin() + if err != nil { + return nil, sess.EventErr("dbr.begin.error", err) + } else { + sess.Event("dbr.begin") + } + + return &Tx{ + Session: sess, + Tx: tx, + }, nil +} + +// Commit finishes the transaction +func (tx *Tx) Commit() error { + err := tx.Tx.Commit() + if err != nil { + return tx.EventErr("dbr.commit.error", err) + } else { + tx.Event("dbr.commit") + } + return nil +} + +// Rollback cancels the transaction +func (tx *Tx) Rollback() error { + err := tx.Tx.Rollback() + if err != nil { + return tx.EventErr("dbr.rollback", err) + } else { + tx.Event("dbr.rollback") + } + return nil +} + +// RollbackUnlessCommitted rollsback the transaction unless it has already been committed or rolled back. +// Useful to defer tx.RollbackUnlessCommitted() -- so you don't have to handle N failure cases +// Keep in mind the only way to detect an error on the rollback is via the event log. +func (tx *Tx) RollbackUnlessCommitted() { + err := tx.Tx.Rollback() + if err == sql.ErrTxDone { + // ok + } else if err != nil { + tx.EventErr("dbr.rollback_unless_committed", err) + } else { + tx.Event("dbr.rollback") + } +} diff --git a/transaction_test.go b/transaction_test.go new file mode 100644 index 0000000..f2c8158 --- /dev/null +++ b/transaction_test.go @@ -0,0 +1,54 @@ +package dbr + +import ( + // "database/sql" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTransactionReal(t *testing.T) { + s := createRealSessionWithFixtures() + + tx, err := s.Begin() + assert.NoError(t, err) + + res, err := tx.InsertInto("dbr_people").Columns("name", "email").Values("Barack", "obama@whitehouse.gov").Exec() + + assert.NoError(t, err) + id, err := res.LastInsertId() + assert.NoError(t, err) + rowsAff, err := res.RowsAffected() + assert.NoError(t, err) + + assert.True(t, id > 0) + assert.Equal(t, rowsAff, 1) + + var person dbrPerson + err = tx.Select("*").From("dbr_people").Where("id = ?", id).LoadStruct(&person) + assert.NoError(t, err) + + assert.Equal(t, person.Id, id) + assert.Equal(t, person.Name, "Barack") + assert.Equal(t, person.Email.Valid, true) + assert.Equal(t, person.Email.String, "obama@whitehouse.gov") + + err = tx.Commit() + assert.NoError(t, err) +} + +func TestTransactionRollbackReal(t *testing.T) { + // Insert by specifying values + s := createRealSessionWithFixtures() + + tx, err := s.Begin() + assert.NoError(t, err) + + var person dbrPerson + err = tx.Select("*").From("dbr_people").Where("email = ?", "jonathan@uservoice.com").LoadStruct(&person) + assert.NoError(t, err) + assert.Equal(t, person.Name, "Jonathan") + + err = tx.Rollback() + assert.NoError(t, err) +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..c942fd5 --- /dev/null +++ b/types.go @@ -0,0 +1,70 @@ +package dbr + +import ( + "database/sql" + "encoding/json" + + "github.com/go-sql-driver/mysql" +) + +// +// Your app can use these Null types instead of the defaults. The sole benefit you get is a MarshalJSON method that is not retarded. +// + +// NullString is a type that can be null or a string +type NullString struct { + sql.NullString +} + +// NullInt64 is a type that can be null or an int +type NullInt64 struct { + sql.NullInt64 +} + +// NullTime is a type that can be null or a time +type NullTime struct { + mysql.NullTime +} + +// NullBool is a type that can be null or a bool +type NullBool struct { + sql.NullBool +} + +var nullString = []byte("null") + +// MarshalJSON correctly serializes a NullString to JSON +func (n *NullString) MarshalJSON() ([]byte, error) { + if n.Valid { + j, e := json.Marshal(n.String) + return j, e + } + return nullString, nil +} + +// MarshalJSON correctly serializes a NullInt64 to JSON +func (n *NullInt64) MarshalJSON() ([]byte, error) { + if n.Valid { + j, e := json.Marshal(n.Int64) + return j, e + } + return nullString, nil +} + +// MarshalJSON correctly serializes a NullTime to JSON +func (n *NullTime) MarshalJSON() ([]byte, error) { + if n.Valid { + j, e := json.Marshal(n.Time) + return j, e + } + return nullString, nil +} + +// MarshalJSON correctly serializes a NullBool to JSON +func (n *NullBool) MarshalJSON() ([]byte, error) { + if n.Valid { + j, e := json.Marshal(n.Bool) + return j, e + } + return nullString, nil +} diff --git a/update.go b/update.go new file mode 100644 index 0000000..33864e9 --- /dev/null +++ b/update.go @@ -0,0 +1,208 @@ +package dbr + +import ( + "bytes" + "database/sql" + "fmt" + "time" +) + +// UpdateBuilder contains the clauses for an UPDATE statement +type UpdateBuilder struct { + *Session + Runner + + RawFullSql string + RawArguments []interface{} + + Table string + SetClauses []*setClause + WhereFragments []*whereFragment + OrderBys []string + LimitCount uint64 + LimitValid bool + OffsetCount uint64 + OffsetValid bool +} + +type setClause struct { + column string + value interface{} +} + +// Update creates a new UpdateBuilder for the given table +func (sess *Session) Update(table string) *UpdateBuilder { + return &UpdateBuilder{ + Session: sess, + Runner: sess.cxn.Db, + Table: table, + } +} + +// UpdateBySql creates a new UpdateBuilder for the given SQL string and arguments +func (sess *Session) UpdateBySql(sql string, args ...interface{}) *UpdateBuilder { + return &UpdateBuilder{ + Session: sess, + Runner: sess.cxn.Db, + RawFullSql: sql, + RawArguments: args, + } +} + +// Update creates a new UpdateBuilder for the given table bound to a transaction +func (tx *Tx) Update(table string) *UpdateBuilder { + return &UpdateBuilder{ + Session: tx.Session, + Runner: tx.Tx, + Table: table, + } +} + +// UpdateBySql creates a new UpdateBuilder for the given SQL string and arguments bound to a transaction +func (tx *Tx) UpdateBySql(sql string, args ...interface{}) *UpdateBuilder { + return &UpdateBuilder{ + Session: tx.Session, + Runner: tx.Tx, + RawFullSql: sql, + RawArguments: args, + } +} + +// Set appends a column/value pair for the statement +func (b *UpdateBuilder) Set(column string, value interface{}) *UpdateBuilder { + b.SetClauses = append(b.SetClauses, &setClause{column: column, value: value}) + return b +} + +// SetMap appends the elements of the map as column/value pairs for the statement +func (b *UpdateBuilder) SetMap(clauses map[string]interface{}) *UpdateBuilder { + for col, val := range clauses { + b = b.Set(col, val) + } + return b +} + +// Where appends a WHERE clause to the statement +func (b *UpdateBuilder) Where(whereSqlOrMap interface{}, args ...interface{}) *UpdateBuilder { + b.WhereFragments = append(b.WhereFragments, newWhereFragment(whereSqlOrMap, args)) + return b +} + +// OrderBy appends a column to ORDER the statement by +func (b *UpdateBuilder) OrderBy(ord string) *UpdateBuilder { + b.OrderBys = append(b.OrderBys, ord) + return b +} + +// OrderDir appends a column to ORDER the statement by with a given direction +func (b *UpdateBuilder) OrderDir(ord string, isAsc bool) *UpdateBuilder { + if isAsc { + b.OrderBys = append(b.OrderBys, ord+" ASC") + } else { + b.OrderBys = append(b.OrderBys, ord+" DESC") + } + return b +} + +// Limit sets a limit for the statement; overrides any existing LIMIT +func (b *UpdateBuilder) Limit(limit uint64) *UpdateBuilder { + b.LimitCount = limit + b.LimitValid = true + return b +} + +// Offset sets an offset for the statement; overrides any existing OFFSET +func (b *UpdateBuilder) Offset(offset uint64) *UpdateBuilder { + b.OffsetCount = offset + b.OffsetValid = true + return b +} + +// ToSql serialized the UpdateBuilder to a SQL string +// It returns the string with placeholders and a slice of query arguments +func (b *UpdateBuilder) ToSql() (string, []interface{}) { + if b.RawFullSql != "" { + return b.RawFullSql, b.RawArguments + } + + if len(b.Table) == 0 { + panic("no table specified") + } + if len(b.SetClauses) == 0 { + panic("no set clauses specified") + } + + var sql bytes.Buffer + var args []interface{} + + sql.WriteString("UPDATE ") + sql.WriteString(b.Table) + sql.WriteString(" SET ") + + // Build SET clause SQL with placeholders and add values to args + for i, c := range b.SetClauses { + if i > 0 { + sql.WriteString(", ") + } + Quoter.writeQuotedColumn(c.column, &sql) + if e, ok := c.value.(*expr); ok { + sql.WriteString(" = ") + sql.WriteString(e.Sql) + args = append(args, e.Values...) + } else { + sql.WriteString(" = ?") + args = append(args, c.value) + } + } + + // Write WHERE clause if we have any fragments + if len(b.WhereFragments) > 0 { + sql.WriteString(" WHERE ") + writeWhereFragmentsToSql(b.WhereFragments, &sql, &args) + } + + // Ordering and limiting + if len(b.OrderBys) > 0 { + sql.WriteString(" ORDER BY ") + for i, s := range b.OrderBys { + if i > 0 { + sql.WriteString(", ") + } + sql.WriteString(s) + } + } + + if b.LimitValid { + sql.WriteString(" LIMIT ") + fmt.Fprint(&sql, b.LimitCount) + } + + if b.OffsetValid { + sql.WriteString(" OFFSET ") + fmt.Fprint(&sql, b.OffsetCount) + } + + return sql.String(), args +} + +// Exec executes the statement represented by the UpdateBuilder +// It returns the raw database/sql Result and an error if there was one +func (b *UpdateBuilder) Exec() (sql.Result, error) { + sql, args := b.ToSql() + + fullSql, err := Interpolate(sql, args) + if err != nil { + return nil, b.EventErrKv("dbr.update.exec.interpolate", err, kvs{"sql": fullSql}) + } + + // Start the timer: + startTime := time.Now() + defer func() { b.TimingKv("dbr.update", time.Since(startTime).Nanoseconds(), kvs{"sql": fullSql}) }() + + result, err := b.Runner.Exec(fullSql) + if err != nil { + return result, b.EventErrKv("dbr.update.exec.exec", err, kvs{"sql": fullSql}) + } + + return result, nil +} diff --git a/update_test.go b/update_test.go new file mode 100644 index 0000000..dfd593a --- /dev/null +++ b/update_test.go @@ -0,0 +1,131 @@ +package dbr + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func BenchmarkUpdateValuesSql(b *testing.B) { + s := createFakeSession() + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + s.Update("alpha").Set("something_id", 1).Where("id", 1).ToSql() + } +} + +func BenchmarkUpdateValueMapSql(b *testing.B) { + s := createFakeSession() + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + s.Update("alpha").Set("something_id", 1).SetMap(map[string]interface{}{"b": 1, "c": 2}).Where("id", 1).ToSql() + } +} + +func TestUpdateAllToSql(t *testing.T) { + s := createFakeSession() + + sql, args := s.Update("a").Set("b", 1).Set("c", 2).ToSql() + + assert.Equal(t, sql, "UPDATE a SET `b` = ?, `c` = ?") + assert.Equal(t, args, []interface{}{1, 2}) +} + +func TestUpdateSingleToSql(t *testing.T) { + s := createFakeSession() + + sql, args := s.Update("a").Set("b", 1).Set("c", 2).Where("id = ?", 1).ToSql() + + assert.Equal(t, sql, "UPDATE a SET `b` = ?, `c` = ? WHERE (id = ?)") + assert.Equal(t, args, []interface{}{1, 2, 1}) +} + +func TestUpdateSetMapToSql(t *testing.T) { + s := createFakeSession() + + sql, args := s.Update("a").SetMap(map[string]interface{}{"b": 1, "c": 2}).Where("id = ?", 1).ToSql() + + if sql == "UPDATE a SET `b` = ?, `c` = ? WHERE (id = ?)" { + assert.Equal(t, args, []interface{}{1, 2, 1}) + } else { + assert.Equal(t, sql, "UPDATE a SET `c` = ?, `b` = ? WHERE (id = ?)") + assert.Equal(t, args, []interface{}{2, 1, 1}) + } +} + +func TestUpdateSetExprToSql(t *testing.T) { + s := createFakeSession() + + sql, args := s.Update("a").Set("foo", 1).Set("bar", Expr("COALESCE(bar, 0) + 1")).Where("id = ?", 9).ToSql() + + assert.Equal(t, sql, "UPDATE a SET `foo` = ?, `bar` = COALESCE(bar, 0) + 1 WHERE (id = ?)") + assert.Equal(t, args, []interface{}{1, 9}) + + sql, args = s.Update("a").Set("foo", 1).Set("bar", Expr("COALESCE(bar, 0) + ?", 2)).Where("id = ?", 9).ToSql() + + assert.Equal(t, sql, "UPDATE a SET `foo` = ?, `bar` = COALESCE(bar, 0) + ? WHERE (id = ?)") + assert.Equal(t, args, []interface{}{1, 2, 9}) +} + +func TestUpdateTenStaringFromTwentyToSql(t *testing.T) { + s := createFakeSession() + + sql, args := s.Update("a").Set("b", 1).Limit(10).Offset(20).ToSql() + + assert.Equal(t, sql, "UPDATE a SET `b` = ? LIMIT 10 OFFSET 20") + assert.Equal(t, args, []interface{}{1}) +} + +func TestUpdateKeywordColumnName(t *testing.T) { + s := createRealSessionWithFixtures() + + // Insert a user with a key + res, err := s.InsertInto("dbr_people").Columns("name", "email", "key").Values("Benjamin", "ben@whitehouse.gov", "6").Exec() + assert.NoError(t, err) + + // Update the key + res, err = s.Update("dbr_people").Set("key", "6-revoked").Where(Eq{"key": "6"}).Exec() + assert.NoError(t, err) + + // Assert our record was updated (and only our record) + rowsAff, err := res.RowsAffected() + assert.NoError(t, err) + assert.Equal(t, rowsAff, 1) + + var person dbrPerson + err = s.Select("*").From("dbr_people").Where(Eq{"email": "ben@whitehouse.gov"}).LoadStruct(&person) + assert.NoError(t, err) + + assert.Equal(t, person.Name, "Benjamin") + assert.Equal(t, person.Key.String, "6-revoked") +} + +func TestUpdateReal(t *testing.T) { + s := createRealSessionWithFixtures() + + // Insert a George + res, err := s.InsertInto("dbr_people").Columns("name", "email").Values("George", "george@whitehouse.gov").Exec() + assert.NoError(t, err) + + // Get George's ID + id, err := res.LastInsertId() + assert.NoError(t, err) + + // Rename our George to Barack + res, err = s.Update("dbr_people").SetMap(map[string]interface{}{"name": "Barack", "email": "barack@whitehouse.gov"}).Where("id = ?", id).Exec() + + assert.NoError(t, err) + + var person dbrPerson + err = s.Select("*").From("dbr_people").Where("id = ?", id).LoadStruct(&person) + assert.NoError(t, err) + + assert.Equal(t, person.Id, id) + assert.Equal(t, person.Name, "Barack") + assert.Equal(t, person.Email.Valid, true) + assert.Equal(t, person.Email.String, "barack@whitehouse.gov") +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..59de044 --- /dev/null +++ b/util.go @@ -0,0 +1,23 @@ +package dbr + +// NameMapping is the routine to use when mapping column names to struct properties +var NameMapping = camelCaseToSnakeCase + +func camelCaseToSnakeCase(name string) string { + var newstr []rune + firstTime := true + + for _, chr := range name { + if isUpper := 'A' <= chr && chr <= 'Z'; isUpper { + if firstTime == true { + firstTime = false + } else { + newstr = append(newstr, '_') + } + chr -= ('A' - 'a') + } + newstr = append(newstr, chr) + } + + return string(newstr) +} diff --git a/where.go b/where.go new file mode 100644 index 0000000..a7f012f --- /dev/null +++ b/where.go @@ -0,0 +1,104 @@ +package dbr + +import ( + "bytes" + "reflect" +) + +// Eq is a map column -> value pairs which must be matched in a query +type Eq map[string]interface{} + +type whereFragment struct { + Condition string + Values []interface{} + EqualityMap map[string]interface{} +} + +func newWhereFragment(whereSqlOrMap interface{}, args []interface{}) *whereFragment { + switch pred := whereSqlOrMap.(type) { + case string: + return &whereFragment{Condition: pred, Values: args} + case map[string]interface{}: + return &whereFragment{EqualityMap: pred} + case Eq: + return &whereFragment{EqualityMap: map[string]interface{}(pred)} + default: + panic("Invalid argument passed to Where. Pass a string or an Eq map.") + } + + return nil +} + +// Invariant: only called when len(fragments) > 0 +func writeWhereFragmentsToSql(fragments []*whereFragment, sql *bytes.Buffer, args *[]interface{}) { + anyConditions := false + for _, f := range fragments { + if f.Condition != "" { + if anyConditions { + sql.WriteString(" AND (") + } else { + sql.WriteRune('(') + anyConditions = true + } + sql.WriteString(f.Condition) + sql.WriteRune(')') + if len(f.Values) > 0 { + *args = append(*args, f.Values...) + } + } else if f.EqualityMap != nil { + anyConditions = writeEqualityMapToSql(f.EqualityMap, sql, args, anyConditions) + } else { + panic("invalid equality map") + } + } +} + +func writeEqualityMapToSql(eq map[string]interface{}, sql *bytes.Buffer, args *[]interface{}, anyConditions bool) bool { + for k, v := range eq { + if v == nil { + anyConditions = writeWhereCondition(sql, k, " IS NULL", anyConditions) + } else { + vVal := reflect.ValueOf(v) + + if vVal.Kind() == reflect.Array || vVal.Kind() == reflect.Slice { + vValLen := vVal.Len() + if vValLen == 0 { + if vVal.IsNil() { + anyConditions = writeWhereCondition(sql, k, " IS NULL", anyConditions) + } else { + if anyConditions { + sql.WriteString(" AND (1=0)") + } else { + sql.WriteString("(1=0)") + } + } + } else if vValLen == 1 { + anyConditions = writeWhereCondition(sql, k, " = ?", anyConditions) + *args = append(*args, vVal.Index(0).Interface()) + } else { + anyConditions = writeWhereCondition(sql, k, " IN ?", anyConditions) + *args = append(*args, v) + } + } else { + anyConditions = writeWhereCondition(sql, k, " = ?", anyConditions) + *args = append(*args, v) + } + } + } + + return anyConditions +} + +func writeWhereCondition(sql *bytes.Buffer, k string, pred string, anyConditions bool) bool { + if anyConditions { + sql.WriteString(" AND (") + } else { + sql.WriteRune('(') + anyConditions = true + } + Quoter.writeQuotedColumn(k, sql) + sql.WriteString(pred) + sql.WriteRune(')') + + return anyConditions +}