231 lines
5.5 KiB
Go
231 lines
5.5 KiB
Go
package dbr
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"time"
|
|
|
|
"go.opentelemetry.io/otel"
|
|
"go.opentelemetry.io/otel/codes"
|
|
)
|
|
|
|
const tracerName = "git.bit5.ru/backend/dbr"
|
|
|
|
var tracer = otel.Tracer(tracerName)
|
|
|
|
// UpdateBuilder contains the clauses for an UPDATE statement
|
|
type UpdateBuilder struct {
|
|
*Session
|
|
Runner
|
|
|
|
RawFullSql string
|
|
RawArguments []interface{}
|
|
|
|
Table string
|
|
SetClauses []*setClause
|
|
WhereFragments []*whereFragment
|
|
OrderBys []string
|
|
LimitCount uint64
|
|
LimitValid bool
|
|
OffsetCount uint64
|
|
OffsetValid bool
|
|
}
|
|
|
|
type setClause struct {
|
|
column string
|
|
value interface{}
|
|
}
|
|
|
|
// Update creates a new UpdateBuilder for the given table
|
|
func (sess *Session) Update(table string) *UpdateBuilder {
|
|
return &UpdateBuilder{
|
|
Session: sess,
|
|
Runner: sess.cxn.Db,
|
|
Table: table,
|
|
}
|
|
}
|
|
|
|
// UpdateBySql creates a new UpdateBuilder for the given SQL string and arguments
|
|
func (sess *Session) UpdateBySql(sql string, args ...interface{}) *UpdateBuilder {
|
|
return &UpdateBuilder{
|
|
Session: sess,
|
|
Runner: sess.cxn.Db,
|
|
RawFullSql: sql,
|
|
RawArguments: args,
|
|
}
|
|
}
|
|
|
|
// Update creates a new UpdateBuilder for the given table bound to a transaction
|
|
func (tx *Tx) Update(table string) *UpdateBuilder {
|
|
return &UpdateBuilder{
|
|
Session: tx.Session,
|
|
Runner: tx.Tx,
|
|
Table: table,
|
|
}
|
|
}
|
|
|
|
// UpdateBySql creates a new UpdateBuilder for the given SQL string and arguments bound to a transaction
|
|
func (tx *Tx) UpdateBySql(sql string, args ...interface{}) *UpdateBuilder {
|
|
return &UpdateBuilder{
|
|
Session: tx.Session,
|
|
Runner: tx.Tx,
|
|
RawFullSql: sql,
|
|
RawArguments: args,
|
|
}
|
|
}
|
|
|
|
// Set appends a column/value pair for the statement
|
|
func (b *UpdateBuilder) Set(column string, value interface{}) *UpdateBuilder {
|
|
b.SetClauses = append(b.SetClauses, &setClause{column: column, value: value})
|
|
return b
|
|
}
|
|
|
|
// SetMap appends the elements of the map as column/value pairs for the statement
|
|
func (b *UpdateBuilder) SetMap(clauses map[string]interface{}) *UpdateBuilder {
|
|
for col, val := range clauses {
|
|
b = b.Set(col, val)
|
|
}
|
|
return b
|
|
}
|
|
|
|
// Where appends a WHERE clause to the statement
|
|
func (b *UpdateBuilder) Where(whereSqlOrMap interface{}, args ...interface{}) *UpdateBuilder {
|
|
b.WhereFragments = append(b.WhereFragments, newWhereFragment(whereSqlOrMap, args))
|
|
return b
|
|
}
|
|
|
|
// OrderBy appends a column to ORDER the statement by
|
|
func (b *UpdateBuilder) OrderBy(ord string) *UpdateBuilder {
|
|
b.OrderBys = append(b.OrderBys, ord)
|
|
return b
|
|
}
|
|
|
|
// OrderDir appends a column to ORDER the statement by with a given direction
|
|
func (b *UpdateBuilder) OrderDir(ord string, isAsc bool) *UpdateBuilder {
|
|
if isAsc {
|
|
b.OrderBys = append(b.OrderBys, ord+" ASC")
|
|
} else {
|
|
b.OrderBys = append(b.OrderBys, ord+" DESC")
|
|
}
|
|
return b
|
|
}
|
|
|
|
// Limit sets a limit for the statement; overrides any existing LIMIT
|
|
func (b *UpdateBuilder) Limit(limit uint64) *UpdateBuilder {
|
|
b.LimitCount = limit
|
|
b.LimitValid = true
|
|
return b
|
|
}
|
|
|
|
// Offset sets an offset for the statement; overrides any existing OFFSET
|
|
func (b *UpdateBuilder) Offset(offset uint64) *UpdateBuilder {
|
|
b.OffsetCount = offset
|
|
b.OffsetValid = true
|
|
return b
|
|
}
|
|
|
|
// ToSql serialized the UpdateBuilder to a SQL string
|
|
// It returns the string with placeholders and a slice of query arguments
|
|
func (b *UpdateBuilder) ToSql() (string, []interface{}) {
|
|
if b.RawFullSql != "" {
|
|
return b.RawFullSql, b.RawArguments
|
|
}
|
|
|
|
if len(b.Table) == 0 {
|
|
panic("no table specified")
|
|
}
|
|
if len(b.SetClauses) == 0 {
|
|
panic("no set clauses specified")
|
|
}
|
|
|
|
var sql bytes.Buffer
|
|
var args []interface{}
|
|
|
|
sql.WriteString("UPDATE ")
|
|
sql.WriteString(b.Table)
|
|
sql.WriteString(" SET ")
|
|
|
|
// Build SET clause SQL with placeholders and add values to args
|
|
for i, c := range b.SetClauses {
|
|
if i > 0 {
|
|
sql.WriteString(", ")
|
|
}
|
|
Quoter.writeQuotedColumn(c.column, &sql)
|
|
if e, ok := c.value.(*expr); ok {
|
|
sql.WriteString(" = ")
|
|
sql.WriteString(e.Sql)
|
|
args = append(args, e.Values...)
|
|
} else {
|
|
sql.WriteString(" = ?")
|
|
args = append(args, c.value)
|
|
}
|
|
}
|
|
|
|
// Write WHERE clause if we have any fragments
|
|
if len(b.WhereFragments) > 0 {
|
|
sql.WriteString(" WHERE ")
|
|
writeWhereFragmentsToSql(b.WhereFragments, &sql, &args)
|
|
}
|
|
|
|
// Ordering and limiting
|
|
if len(b.OrderBys) > 0 {
|
|
sql.WriteString(" ORDER BY ")
|
|
for i, s := range b.OrderBys {
|
|
if i > 0 {
|
|
sql.WriteString(", ")
|
|
}
|
|
sql.WriteString(s)
|
|
}
|
|
}
|
|
|
|
if b.LimitValid {
|
|
sql.WriteString(" LIMIT ")
|
|
fmt.Fprint(&sql, b.LimitCount)
|
|
}
|
|
|
|
if b.OffsetValid {
|
|
sql.WriteString(" OFFSET ")
|
|
fmt.Fprint(&sql, b.OffsetCount)
|
|
}
|
|
|
|
return sql.String(), args
|
|
}
|
|
|
|
func (b *UpdateBuilder) Exec() (sql.Result, error) {
|
|
sql, args := b.ToSql()
|
|
|
|
fullSql, err := Interpolate(sql, args)
|
|
if err != nil {
|
|
return nil, b.EventErrKv("dbr.update.exec.interpolate", err, kvs{"sql": fullSql})
|
|
}
|
|
|
|
// Start the timer:
|
|
startTime := time.Now()
|
|
defer func() { b.TimingKv("dbr.update", time.Since(startTime).Nanoseconds(), kvs{"sql": fullSql}) }()
|
|
|
|
result, err := b.Runner.Exec(fullSql)
|
|
if err != nil {
|
|
return result, b.EventErrKv("dbr.update.exec.exec", err, kvs{"sql": fullSql})
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// ExecContext executes the statement represented by the UpdateBuilder
|
|
// It returns the raw database/sql Result and an error if there was one
|
|
func (b *UpdateBuilder) ExecContext(ctx context.Context) (sql.Result, error) {
|
|
_, span := tracer.Start(ctx, "UpdateBuilder.ExecContext")
|
|
defer span.End()
|
|
|
|
res, err := b.Exec()
|
|
if err != nil {
|
|
span.RecordError(err)
|
|
span.SetStatus(codes.Error, err.Error())
|
|
return nil, err
|
|
}
|
|
|
|
return res, err
|
|
}
|