// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

package mysql

import (
	"crypto/sha1"
	"crypto/tls"
	"database/sql/driver"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"net/url"
	"strings"
	"time"
)

var (
	// tlsConfigRegister maps the name used in the DSN (tls=<name>) to the
	// tls.Config registered under that name via RegisterTLSConfig.
	tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs

	// errInvalidDSNUnescaped is returned by parseDSN when the address part is
	// not terminated directly before the last '/' but does contain a ')'.
	errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?")
	// errInvalidDSNAddr is returned by parseDSN when an address was opened
	// with '(' but no closing ')' was found.
	errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)")
	// errInvalidDSNNoSlash is returned by parseDSN for a non-empty DSN that
	// contains no '/' at all.
	errInvalidDSNNoSlash = errors.New("Invalid DSN: Missing the slash separating the database name")
)

// init allocates the TLS config register so that RegisterTLSConfig can write
// to it without a nil-map panic.
func init() {
	tlsConfigRegister = make(map[string]*tls.Config)
}

// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
// Use the key as a value in the DSN where tls=value.
// // rootCertPool := x509.NewCertPool() // pem, err := ioutil.ReadFile("/path/ca-cert.pem") // if err != nil { // log.Fatal(err) // } // if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { // log.Fatal("Failed to append PEM.") // } // clientCert := make([]tls.Certificate, 0, 1) // certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem") // if err != nil { // log.Fatal(err) // } // clientCert = append(clientCert, certs) // mysql.RegisterTLSConfig("custom", &tls.Config{ // RootCAs: rootCertPool, // Certificates: clientCert, // }) // db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") // func RegisterTLSConfig(key string, config *tls.Config) error { if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" { return fmt.Errorf("Key '%s' is reserved", key) } tlsConfigRegister[key] = config return nil } // DeregisterTLSConfig removes the tls.Config associated with key. func DeregisterTLSConfig(key string) { delete(tlsConfigRegister, key) } // parseDSN parses the DSN string to a config func parseDSN(dsn string) (cfg *config, err error) { // New config with some default values cfg = &config{ loc: time.UTC, collation: defaultCollation, } // TODO: use strings.IndexByte when we can depend on Go 1.2 // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] // Find the last '/' (since the password or the net addr might contain a '/') foundSlash := false for i := len(dsn) - 1; i >= 0; i-- { if dsn[i] == '/' { foundSlash = true var j, k int // left part is empty if i <= 0 if i > 0 { // [username[:password]@][protocol[(address)]] // Find the last '@' in dsn[:i] for j = i; j >= 0; j-- { if dsn[j] == '@' { // username[:password] // Find the first ':' in dsn[:j] for k = 0; k < j; k++ { if dsn[k] == ':' { cfg.passwd = dsn[k+1 : j] break } } cfg.user = dsn[:k] break } } // [protocol[(address)]] // Find the first '(' in dsn[j+1:i] for k = j + 1; k < i; k++ { if dsn[k] == '(' { // dsn[i-1] must be == ')' if 
an address is specified if dsn[i-1] != ')' { if strings.ContainsRune(dsn[k+1:i], ')') { return nil, errInvalidDSNUnescaped } return nil, errInvalidDSNAddr } cfg.addr = dsn[k+1 : i-1] break } } cfg.net = dsn[j+1 : k] } // dbname[?param1=value1&...¶mN=valueN] // Find the first '?' in dsn[i+1:] for j = i + 1; j < len(dsn); j++ { if dsn[j] == '?' { if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { return } break } } cfg.dbname = dsn[i+1 : j] break } } if !foundSlash && len(dsn) > 0 { return nil, errInvalidDSNNoSlash } // Set default network if empty if cfg.net == "" { cfg.net = "tcp" } // Set default address if empty if cfg.addr == "" { switch cfg.net { case "tcp": cfg.addr = "127.0.0.1:3306" case "unix": cfg.addr = "/tmp/mysql.sock" default: return nil, errors.New("Default addr for network '" + cfg.net + "' unknown") } } return } // parseDSNParams parses the DSN "query string" // Values must be url.QueryEscape'ed func parseDSNParams(cfg *config, params string) (err error) { for _, v := range strings.Split(params, "&") { param := strings.SplitN(v, "=", 2) if len(param) != 2 { continue } // cfg params switch value := param[1]; param[0] { // Disable INFILE whitelist / enable all files case "allowAllFiles": var isBool bool cfg.allowAllFiles, isBool = readBool(value) if !isBool { return fmt.Errorf("Invalid Bool value: %s", value) } // Use old authentication mode (pre MySQL 4.1) case "allowOldPasswords": var isBool bool cfg.allowOldPasswords, isBool = readBool(value) if !isBool { return fmt.Errorf("Invalid Bool value: %s", value) } // Switch "rowsAffected" mode case "clientFoundRows": var isBool bool cfg.clientFoundRows, isBool = readBool(value) if !isBool { return fmt.Errorf("Invalid Bool value: %s", value) } // Collation case "collation": collation, ok := collations[value] if !ok { // Note possibility for false negatives: // could be triggered although the collation is valid if the // collations map does not contain entries the server supports. 
err = errors.New("unknown collation") return } cfg.collation = collation break case "columnsWithAlias": var isBool bool cfg.columnsWithAlias, isBool = readBool(value) if !isBool { return fmt.Errorf("Invalid Bool value: %s", value) } // Time Location case "loc": if value, err = url.QueryUnescape(value); err != nil { return } cfg.loc, err = time.LoadLocation(value) if err != nil { return } // Dial Timeout case "timeout": cfg.timeout, err = time.ParseDuration(value) if err != nil { return } // TLS-Encryption case "tls": boolValue, isBool := readBool(value) if isBool { if boolValue { cfg.tls = &tls.Config{} } } else { if strings.ToLower(value) == "skip-verify" { cfg.tls = &tls.Config{InsecureSkipVerify: true} } else if tlsConfig, ok := tlsConfigRegister[value]; ok { if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { host, _, err := net.SplitHostPort(cfg.addr) if err == nil { tlsConfig.ServerName = host } } cfg.tls = tlsConfig } else { return fmt.Errorf("Invalid value / unknown config name: %s", value) } } default: // lazy init if cfg.params == nil { cfg.params = make(map[string]string) } if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil { return } } } return } // Returns the bool value of the input. 
// The 2nd return value indicates if the input was a valid bool value func readBool(input string) (value bool, valid bool) { switch input { case "1", "true", "TRUE", "True": return true, true case "0", "false", "FALSE", "False": return false, true } // Not a valid bool value return } /****************************************************************************** * Authentication * ******************************************************************************/ // Encrypt password using 4.1+ method func scramblePassword(scramble, password []byte) []byte { if len(password) == 0 { return nil } // stage1Hash = SHA1(password) crypt := sha1.New() crypt.Write(password) stage1 := crypt.Sum(nil) // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) // inner Hash crypt.Reset() crypt.Write(stage1) hash := crypt.Sum(nil) // outer Hash crypt.Reset() crypt.Write(scramble) crypt.Write(hash) scramble = crypt.Sum(nil) // token = scrambleHash XOR stage1Hash for i := range scramble { scramble[i] ^= stage1[i] } return scramble } // Encrypt password using pre 4.1 (old password) method // https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c type myRnd struct { seed1, seed2 uint32 } const myRndMaxVal = 0x3FFFFFFF // Pseudo random number generator func newMyRnd(seed1, seed2 uint32) *myRnd { return &myRnd{ seed1: seed1 % myRndMaxVal, seed2: seed2 % myRndMaxVal, } } // Tested to be equivalent to MariaDB's floating point variant // http://play.golang.org/p/QHvhd4qved // http://play.golang.org/p/RG0q4ElWDx func (r *myRnd) NextByte() byte { r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal return byte(uint64(r.seed1) * 31 / myRndMaxVal) } // Generate binary hash from byte string using insecure pre 4.1 method func pwHash(password []byte) (result [2]uint32) { var add uint32 = 7 var tmp uint32 result[0] = 1345345333 result[1] = 0x12345671 for _, c := range password { // skip spaces and tabs in password if c == ' ' || c == '\t' { continue } tmp = 
uint32(c) result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8) result[1] += (result[1] << 8) ^ result[0] add += tmp } // Remove sign bit (1<<31)-1) result[0] &= 0x7FFFFFFF result[1] &= 0x7FFFFFFF return } // Encrypt password using insecure pre 4.1 method func scrambleOldPassword(scramble, password []byte) []byte { if len(password) == 0 { return nil } scramble = scramble[:8] hashPw := pwHash(password) hashSc := pwHash(scramble) r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) var out [8]byte for i := range out { out[i] = r.NextByte() + 64 } mask := r.NextByte() for i := range out { out[i] ^= mask } return out[:] } /****************************************************************************** * Time related utils * ******************************************************************************/ // NullTime represents a time.Time that may be NULL. // NullTime implements the Scanner interface so // it can be used as a scan destination: // // var nt NullTime // err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) // ... // if nt.Valid { // // use nt.Time // } else { // // NULL value // } // // This NullTime implementation is not driver-specific type NullTime struct { Time time.Time Valid bool // Valid is true if Time is not NULL } // Scan implements the Scanner interface. // The value type must be time.Time or string / []byte (formatted time-string), // otherwise Scan fails. func (nt *NullTime) Scan(value interface{}) (err error) { if value == nil { nt.Time, nt.Valid = time.Time{}, false return } switch v := value.(type) { case time.Time: nt.Time, nt.Valid = v, true return case []byte: nt.Time, err = parseDateTime(string(v), time.UTC) nt.Valid = (err == nil) return case string: nt.Time, err = parseDateTime(v, time.UTC) nt.Valid = (err == nil) return } nt.Valid = false return fmt.Errorf("Can't convert %T to time.Time", value) } // Value implements the driver Valuer interface. 
func (nt NullTime) Value() (driver.Value, error) { if !nt.Valid { return nil, nil } return nt.Time, nil } func parseDateTime(str string, loc *time.Location) (t time.Time, err error) { base := "0000-00-00 00:00:00.0000000" switch len(str) { case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM" if str == base[:len(str)] { return } t, err = time.Parse(timeFormat[:len(str)], str) default: err = fmt.Errorf("Invalid Time-String: %s", str) return } // Adjust location if err == nil && loc != time.UTC { y, mo, d := t.Date() h, mi, s := t.Clock() t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil } return } func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) { switch num { case 0: return time.Time{}, nil case 4: return time.Date( int(binary.LittleEndian.Uint16(data[:2])), // year time.Month(data[2]), // month int(data[3]), // day 0, 0, 0, 0, loc, ), nil case 7: return time.Date( int(binary.LittleEndian.Uint16(data[:2])), // year time.Month(data[2]), // month int(data[3]), // day int(data[4]), // hour int(data[5]), // minutes int(data[6]), // seconds 0, loc, ), nil case 11: return time.Date( int(binary.LittleEndian.Uint16(data[:2])), // year time.Month(data[2]), // month int(data[3]), // day int(data[4]), // hour int(data[5]), // minutes int(data[6]), // seconds int(binary.LittleEndian.Uint32(data[7:11]))*1000, // nanoseconds loc, ), nil } return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num) } // zeroDateTime is used in formatBinaryDateTime to avoid an allocation // if the DATE or DATETIME has the zero value. // It must never be changed. // The current behavior depends on database/sql copying the result. 
var zeroDateTime = []byte("0000-00-00 00:00:00.000000")

// formatBinaryDateTime renders a binary protocol DATE / DATETIME / TIME value
// as its textual form ("YYYY-MM-DD HH:MM:SS.ffffff" or "HH:MM:SS.ffffff").
//
// src is the value payload without its length byte. length is the nominal
// width of the textual form, i.e. the width of the zero value: 10, 19 or
// 21-26 for DATE / DATETIME and 8 or 10-15 for TIME. A leading '-' and a third
// hour digit of a TIME value are added on top of length if needed.
// justTime selects the TIME layout.
//
// For an empty src a sub-slice of the shared zeroDateTime is returned, which
// the caller must not modify. An unsupported length or payload size yields an
// error.
func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value, error) {
	const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"
	const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999"

	// Validate length before it is used for any slicing, including the
	// zero-value fast path below.
	if justTime {
		switch length {
		case
			8,                      // time (can be up to 10 when negative and 100+ hours)
			10, 11, 12, 13, 14, 15: // time with fractional seconds
		default:
			return nil, fmt.Errorf("illegal TIME length %d", length)
		}
	} else {
		switch length {
		case 10, 19, 21, 22, 23, 24, 25, 26:
		default:
			t := "DATE"
			if length > 10 {
				t += "TIME"
			}
			return nil, fmt.Errorf("illegal %s length %d", t, length)
		}
	}

	if len(src) == 0 {
		if justTime {
			return zeroDateTime[11 : 11+length], nil
		}
		return zeroDateTime[:length], nil
	}

	var dst []byte          // return value
	var pt, p1, p2, p3 byte // current digit pair
	var zOffs byte          // offset of value in zeroDateTime

	if justTime {
		switch len(src) {
		case 8, 12:
		default:
			return nil, fmt.Errorf("Invalid TIME-packet length %d", len(src))
		}

		// +2 to enable negative time and 100+ hours
		dst = make([]byte, 0, length+2)
		if src[0] == 1 {
			dst = append(dst, '-')
		}
		if src[1] != 0 {
			// fold the days into the hours; this adds a third hour digit
			hour := uint16(src[1])*24 + uint16(src[5])
			pt = byte(hour / 100)
			p1 = byte(hour - 100*uint16(pt))
			dst = append(dst, digits01[pt])
		} else {
			p1 = src[5]
		}
		zOffs = 11
		src = src[6:]
	} else {
		switch len(src) {
		case 4, 7, 11:
		default:
			t := "DATE"
			if length > 10 {
				t += "TIME"
			}
			return nil, fmt.Errorf("illegal %s-packet length %d", t, len(src))
		}

		dst = make([]byte, 0, length)
		// start with the date
		year := binary.LittleEndian.Uint16(src[:2])
		pt = byte(year / 100)
		p1 = byte(year - 100*uint16(pt))
		p2, p3 = src[2], src[3]
		dst = append(dst,
			digits10[pt], digits01[pt],
			digits10[p1], digits01[p1], '-',
			digits10[p2], digits01[p2], '-',
			digits10[p3], digits01[p3],
		)
		if length == 10 {
			return dst, nil
		}
		if len(src) == 4 {
			// date-only payload: the time part is all zero
			return append(dst, zeroDateTime[10:length]...), nil
		}
		dst = append(dst, ' ')
		p1 = src[4] // hour
		src = src[5:]
	}

	// p1 is 2-digit hour, src is after hour
	p2, p3 = src[0], src[1]
	dst = append(dst,
		digits10[p1], digits01[p1], ':',
		digits10[p2], digits01[p2], ':',
		digits10[p3], digits01[p3],
	)

	// No fractional seconds requested. This is decided on the nominal width
	// and not on len(dst): a TIME value carrying a sign and a third hour digit
	// is already 10 bytes long and would otherwise lose its fraction.
	if zOffs+length <= 19 {
		return dst, nil
	}

	src = src[2:]
	if len(src) == 0 {
		// payload without microseconds: pad with zeros
		return append(dst, zeroDateTime[19:zOffs+length]...), nil
	}

	microsecs := binary.LittleEndian.Uint32(src[:4])
	p1 = byte(microsecs / 10000)
	microsecs -= 10000 * uint32(p1)
	p2 = byte(microsecs / 100)
	microsecs -= 100 * uint32(p2)
	p3 = byte(microsecs)

	fraction := [7]byte{
		'.',
		digits10[p1], digits01[p1],
		digits10[p2], digits01[p2],
		digits10[p3], digits01[p3],
	}

	// The validated lengths guarantee 1 to 6 decimals at this point.
	decimals := int(zOffs+length) - 20
	return append(dst, fraction[:decimals+1]...), nil
}

/******************************************************************************
*                       Convert from and to bytes                             *
******************************************************************************/

// uint64ToBytes returns the 8 byte little-endian encoding of n.
func uint64ToBytes(n uint64) []byte {
	return []byte{
		byte(n),
		byte(n >> 8),
		byte(n >> 16),
		byte(n >> 24),
		byte(n >> 32),
		byte(n >> 40),
		byte(n >> 48),
		byte(n >> 56),
	}
}

// uint64ToString returns the decimal ASCII representation of n as a byte
// slice (despite its name it does not return a string).
func uint64ToString(n uint64) []byte {
	// 20 digits are enough for the largest uint64
	var a [20]byte
	i := 20

	// U+0030 = 0
	// ...
	// U+0039 = 9

	var q uint64
	for n >= 10 {
		i--
		q = n / 10
		a[i] = uint8(n-q*10) + 0x30
		n = q
	}

	i--
	a[i] = uint8(n) + 0x30

	return a[i:]
}

// stringToInt treats the bytes as the decimal representation of an unsigned
// integer. The input is not validated; an empty input yields 0.
func stringToInt(b []byte) int {
	val := 0
	for i := range b {
		val *= 10
		val += int(b[i] - 0x30)
	}
	return val
}

// readLengthEncodedString returns the string read as a bytes slice, whether
// the value is NULL, the number of bytes read and an error, in case the string
// is longer than the input slice. The returned slice aliases b.
func readLengthEncodedString(b []byte) ([]byte, bool, int, error) {
	// Get length
	num, isNull, n := readLengthEncodedInteger(b)
	if num < 1 {
		if n > len(b) {
			// only reachable for empty input, which is reported as NULL
			return nil, isNull, n, nil
		}
		return b[n:n], isNull, n, nil
	}

	n += int(num)

	// Check data length
	if len(b) >= n {
		return b[n-int(num) : n], false, n, nil
	}
	return nil, false, n, io.EOF
}

// skipLengthEncodedString returns the number of bytes skipped and an error,
// in case the string is longer than the input slice.
func skipLengthEncodedString(b []byte) (int, error) {
	// Get length
	num, _, n := readLengthEncodedInteger(b)
	if num < 1 {
		return n, nil
	}

	n += int(num)

	// Check data length
	if len(b) >= n {
		return n, nil
	}
	return n, io.EOF
}

// readLengthEncodedInteger returns the number read, whether the value is NULL
// and the number of bytes read. Empty input is reported as NULL with one byte
// read instead of panicking. The multi-byte forms expect b to hold all of
// their bytes.
func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
	if len(b) == 0 {
		return 0, true, 1
	}

	switch b[0] {

	// 251: NULL
	case 0xfb:
		return 0, true, 1

	// 252: value of following 2
	case 0xfc:
		return uint64(b[1]) | uint64(b[2])<<8, false, 3

	// 253: value of following 3
	case 0xfd:
		return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4

	// 254: value of following 8
	case 0xfe:
		return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
				uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
				uint64(b[7])<<48 | uint64(b[8])<<56,
			false, 9
	}

	// 0-250: value of first byte
	return uint64(b[0]), false, 1
}

// appendLengthEncodedInteger encodes a uint64 value in the shortest
// length-encoded form and appends it to the given bytes slice.
func appendLengthEncodedInteger(b []byte, n uint64) []byte {
	switch {
	case n <= 250:
		return append(b, byte(n))

	case n <= 0xffff:
		return append(b, 0xfc, byte(n), byte(n>>8))

	case n <= 0xffffff:
		return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16))
	}
	return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24),
		byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
}