456 lines
9.9 KiB
Go
456 lines
9.9 KiB
Go
// Copyright 2012 Gary Burd
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
|
// not use this file except in compliance with the License. You may obtain
|
|
// a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
// License for the specific language governing permissions and limitations
|
|
// under the License.
|
|
|
|
package redis
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// conn is the package-private connection type returned as a Conn. It wraps a
// net.Conn with a buffered reader and writer and tracks how many commands
// have been sent without their replies being consumed.
type conn struct {
	// Shared by the read and write paths. pending and err are read and
	// written with mu held.
	mu      sync.Mutex
	pending int      // incremented by Send, decremented by Receive, reset by Do
	err     error    // first error recorded by Close or fatal; nil while usable
	conn    net.Conn // underlying network connection

	// Read path: per-reply timeout (zero disables it) and buffered reader.
	readTimeout time.Duration
	br          *bufio.Reader

	// Write path: per-command timeout (zero disables it) and buffered writer.
	writeTimeout time.Duration
	bw           *bufio.Writer

	// lenScratch is where writeLen formats a length header: the '*' or '$'
	// prefix, the decimal length, then "\r\n".
	lenScratch [32]byte

	// numScratch is where integers and floats are formatted before they are
	// written as bulk strings.
	numScratch [40]byte
}
|
|
|
|
// Dial connects to the Redis server at the given network and address.
|
|
func Dial(network, address string) (Conn, error) {
|
|
dialer := xDialer{}
|
|
return dialer.Dial(network, address)
|
|
}
|
|
|
|
// DialTimeout acts like Dial but takes timeouts for establishing the
|
|
// connection to the server, writing a command and reading a reply.
|
|
func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) {
|
|
netDialer := net.Dialer{Timeout: connectTimeout}
|
|
dialer := xDialer{
|
|
NetDial: netDialer.Dial,
|
|
ReadTimeout: readTimeout,
|
|
WriteTimeout: writeTimeout,
|
|
}
|
|
return dialer.Dial(network, address)
|
|
}
|
|
|
|
// xDialer holds the options used when connecting to a Redis server. The zero
// value is usable: it dials with net.Dial and applies no timeouts.
type xDialer struct {
	// NetDial is the function used to create the network connection. When
	// it is nil, net.Dial is used instead.
	NetDial func(network, addr string) (net.Conn, error)

	// ReadTimeout bounds the time spent reading a single command reply.
	// Zero means no timeout.
	ReadTimeout time.Duration

	// WriteTimeout bounds the time spent writing a single command. Zero
	// means no timeout.
	WriteTimeout time.Duration
}
|
|
|
|
// Dial connects to the Redis server at address on the named network.
|
|
func (d *xDialer) Dial(network, address string) (Conn, error) {
|
|
dial := d.NetDial
|
|
if dial == nil {
|
|
dial = net.Dial
|
|
}
|
|
netConn, err := dial(network, address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &conn{
|
|
conn: netConn,
|
|
bw: bufio.NewWriter(netConn),
|
|
br: bufio.NewReader(netConn),
|
|
readTimeout: d.ReadTimeout,
|
|
writeTimeout: d.WriteTimeout,
|
|
}, nil
|
|
}
|
|
|
|
// NewConn returns a new Redigo connection for the given net connection.
|
|
func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn {
|
|
return &conn{
|
|
conn: netConn,
|
|
bw: bufio.NewWriter(netConn),
|
|
br: bufio.NewReader(netConn),
|
|
readTimeout: readTimeout,
|
|
writeTimeout: writeTimeout,
|
|
}
|
|
}
|
|
|
|
func (c *conn) Close() error {
|
|
c.mu.Lock()
|
|
err := c.err
|
|
if c.err == nil {
|
|
c.err = errors.New("redigo: closed")
|
|
err = c.conn.Close()
|
|
}
|
|
c.mu.Unlock()
|
|
return err
|
|
}
|
|
|
|
func (c *conn) fatal(err error) error {
|
|
c.mu.Lock()
|
|
if c.err == nil {
|
|
c.err = err
|
|
// Close connection to force errors on subsequent calls and to unblock
|
|
// other reader or writer.
|
|
c.conn.Close()
|
|
}
|
|
c.mu.Unlock()
|
|
return err
|
|
}
|
|
|
|
func (c *conn) Err() error {
|
|
c.mu.Lock()
|
|
err := c.err
|
|
c.mu.Unlock()
|
|
return err
|
|
}
|
|
|
|
func (c *conn) writeLen(prefix byte, n int) error {
|
|
c.lenScratch[len(c.lenScratch)-1] = '\n'
|
|
c.lenScratch[len(c.lenScratch)-2] = '\r'
|
|
i := len(c.lenScratch) - 3
|
|
for {
|
|
c.lenScratch[i] = byte('0' + n%10)
|
|
i -= 1
|
|
n = n / 10
|
|
if n == 0 {
|
|
break
|
|
}
|
|
}
|
|
c.lenScratch[i] = prefix
|
|
_, err := c.bw.Write(c.lenScratch[i:])
|
|
return err
|
|
}
|
|
|
|
func (c *conn) writeString(s string) error {
|
|
c.writeLen('$', len(s))
|
|
c.bw.WriteString(s)
|
|
_, err := c.bw.WriteString("\r\n")
|
|
return err
|
|
}
|
|
|
|
func (c *conn) writeBytes(p []byte) error {
|
|
c.writeLen('$', len(p))
|
|
c.bw.Write(p)
|
|
_, err := c.bw.WriteString("\r\n")
|
|
return err
|
|
}
|
|
|
|
func (c *conn) writeInt64(n int64) error {
|
|
return c.writeBytes(strconv.AppendInt(c.numScratch[:0], n, 10))
|
|
}
|
|
|
|
func (c *conn) writeFloat64(n float64) error {
|
|
return c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64))
|
|
}
|
|
|
|
func (c *conn) writeCommand(cmd string, args []interface{}) (err error) {
|
|
c.writeLen('*', 1+len(args))
|
|
err = c.writeString(cmd)
|
|
for _, arg := range args {
|
|
if err != nil {
|
|
break
|
|
}
|
|
switch arg := arg.(type) {
|
|
case string:
|
|
err = c.writeString(arg)
|
|
case []byte:
|
|
err = c.writeBytes(arg)
|
|
case int:
|
|
err = c.writeInt64(int64(arg))
|
|
case int64:
|
|
err = c.writeInt64(arg)
|
|
case float64:
|
|
err = c.writeFloat64(arg)
|
|
case bool:
|
|
if arg {
|
|
err = c.writeString("1")
|
|
} else {
|
|
err = c.writeString("0")
|
|
}
|
|
case nil:
|
|
err = c.writeString("")
|
|
default:
|
|
var buf bytes.Buffer
|
|
fmt.Fprint(&buf, arg)
|
|
err = c.writeBytes(buf.Bytes())
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
// protocolError is the error type used when a reply read from the server
// does not have the expected format. The string value is the short reason.
type protocolError string

// Error returns the reason prefixed with "redigo: " and followed by a hint
// about the likely causes.
func (pe protocolError) Error() string {
	const format = "redigo: %s (possible server error or unsupported concurrent read by application)"
	return fmt.Sprintf(format, string(pe))
}
|
|
|
|
func (c *conn) readLine() ([]byte, error) {
|
|
p, err := c.br.ReadSlice('\n')
|
|
if err == bufio.ErrBufferFull {
|
|
return nil, protocolError("long response line")
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
i := len(p) - 2
|
|
if i < 0 || p[i] != '\r' {
|
|
return nil, protocolError("bad response line terminator")
|
|
}
|
|
return p[:i], nil
|
|
}
|
|
|
|
// parseLen parses bulk string and array lengths.
|
|
func parseLen(p []byte) (int, error) {
|
|
if len(p) == 0 {
|
|
return -1, protocolError("malformed length")
|
|
}
|
|
|
|
if p[0] == '-' && len(p) == 2 && p[1] == '1' {
|
|
// handle $-1 and $-1 null replies.
|
|
return -1, nil
|
|
}
|
|
|
|
var n int
|
|
for _, b := range p {
|
|
n *= 10
|
|
if b < '0' || b > '9' {
|
|
return -1, protocolError("illegal bytes in length")
|
|
}
|
|
n += int(b - '0')
|
|
}
|
|
|
|
return n, nil
|
|
}
|
|
|
|
// parseInt parses an integer reply.
|
|
func parseInt(p []byte) (interface{}, error) {
|
|
if len(p) == 0 {
|
|
return 0, protocolError("malformed integer")
|
|
}
|
|
|
|
var negate bool
|
|
if p[0] == '-' {
|
|
negate = true
|
|
p = p[1:]
|
|
if len(p) == 0 {
|
|
return 0, protocolError("malformed integer")
|
|
}
|
|
}
|
|
|
|
var n int64
|
|
for _, b := range p {
|
|
n *= 10
|
|
if b < '0' || b > '9' {
|
|
return 0, protocolError("illegal bytes in length")
|
|
}
|
|
n += int64(b - '0')
|
|
}
|
|
|
|
if negate {
|
|
n = -n
|
|
}
|
|
return n, nil
|
|
}
|
|
|
|
// okReply and pongReply hold the "OK" and "PONG" status strings already
// stored in interface values, so that readReply can hand them back for the
// matching status replies without building a new string each time.
var okReply interface{} = "OK"

var pongReply interface{} = "PONG"
|
|
|
|
func (c *conn) readReply() (interface{}, error) {
|
|
line, err := c.readLine()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(line) == 0 {
|
|
return nil, protocolError("short response line")
|
|
}
|
|
switch line[0] {
|
|
case '+':
|
|
switch {
|
|
case len(line) == 3 && line[1] == 'O' && line[2] == 'K':
|
|
// Avoid allocation for frequent "+OK" response.
|
|
return okReply, nil
|
|
case len(line) == 5 && line[1] == 'P' && line[2] == 'O' && line[3] == 'N' && line[4] == 'G':
|
|
// Avoid allocation in PING command benchmarks :)
|
|
return pongReply, nil
|
|
default:
|
|
return string(line[1:]), nil
|
|
}
|
|
case '-':
|
|
return Error(string(line[1:])), nil
|
|
case ':':
|
|
return parseInt(line[1:])
|
|
case '$':
|
|
n, err := parseLen(line[1:])
|
|
if n < 0 || err != nil {
|
|
return nil, err
|
|
}
|
|
p := make([]byte, n)
|
|
_, err = io.ReadFull(c.br, p)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if line, err := c.readLine(); err != nil {
|
|
return nil, err
|
|
} else if len(line) != 0 {
|
|
return nil, protocolError("bad bulk string format")
|
|
}
|
|
return p, nil
|
|
case '*':
|
|
n, err := parseLen(line[1:])
|
|
if n < 0 || err != nil {
|
|
return nil, err
|
|
}
|
|
r := make([]interface{}, n)
|
|
for i := range r {
|
|
r[i], err = c.readReply()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return r, nil
|
|
}
|
|
return nil, protocolError("unexpected response line")
|
|
}
|
|
|
|
func (c *conn) Send(cmd string, args ...interface{}) error {
|
|
c.mu.Lock()
|
|
c.pending += 1
|
|
c.mu.Unlock()
|
|
if c.writeTimeout != 0 {
|
|
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
|
|
}
|
|
if err := c.writeCommand(cmd, args); err != nil {
|
|
return c.fatal(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *conn) Flush() error {
|
|
if c.writeTimeout != 0 {
|
|
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
|
|
}
|
|
if err := c.bw.Flush(); err != nil {
|
|
return c.fatal(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *conn) Receive() (reply interface{}, err error) {
|
|
if c.readTimeout != 0 {
|
|
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
|
}
|
|
if reply, err = c.readReply(); err != nil {
|
|
return nil, c.fatal(err)
|
|
}
|
|
// When using pub/sub, the number of receives can be greater than the
|
|
// number of sends. To enable normal use of the connection after
|
|
// unsubscribing from all channels, we do not decrement pending to a
|
|
// negative value.
|
|
//
|
|
// The pending field is decremented after the reply is read to handle the
|
|
// case where Receive is called before Send.
|
|
c.mu.Lock()
|
|
if c.pending > 0 {
|
|
c.pending -= 1
|
|
}
|
|
c.mu.Unlock()
|
|
if err, ok := reply.(Error); ok {
|
|
return nil, err
|
|
}
|
|
return
|
|
}
|
|
|
|
func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
|
|
c.mu.Lock()
|
|
pending := c.pending
|
|
c.pending = 0
|
|
c.mu.Unlock()
|
|
|
|
if cmd == "" && pending == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
if c.writeTimeout != 0 {
|
|
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
|
|
}
|
|
|
|
if cmd != "" {
|
|
c.writeCommand(cmd, args)
|
|
}
|
|
|
|
if err := c.bw.Flush(); err != nil {
|
|
return nil, c.fatal(err)
|
|
}
|
|
|
|
if c.readTimeout != 0 {
|
|
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
|
}
|
|
|
|
if cmd == "" {
|
|
reply := make([]interface{}, pending)
|
|
for i := range reply {
|
|
r, e := c.readReply()
|
|
if e != nil {
|
|
return nil, c.fatal(e)
|
|
}
|
|
reply[i] = r
|
|
}
|
|
return reply, nil
|
|
}
|
|
|
|
var err error
|
|
var reply interface{}
|
|
for i := 0; i <= pending; i++ {
|
|
var e error
|
|
if reply, e = c.readReply(); e != nil {
|
|
return nil, c.fatal(e)
|
|
}
|
|
if e, ok := reply.(Error); ok && err == nil {
|
|
err = e
|
|
}
|
|
}
|
|
return reply, err
|
|
}
|