diff --git a/rdb.go b/rdb.go index 048f676..3a36d55 100644 --- a/rdb.go +++ b/rdb.go @@ -18,7 +18,7 @@ const ( ExpireRDMutexSec = 30 ) -type Settings struct { +type RdSettings struct { Host, Prefix string //will be shown in CLIENT LIST ClientName string @@ -29,16 +29,16 @@ type Settings struct { IdleTimeoutSec int } -func (s *Settings) ConnStr() string { +func (s *RdSettings) ConnStr() string { return s.Host + ":" + strconv.Itoa(s.Port) } type Pool struct { - S Settings + S RdSettings RP *redis.Pool } -func OpenPool(s Settings, logger *colog.CoLog) *Pool { +func OpenPool(s RdSettings, logger *colog.CoLog) *Pool { p := &Pool{S: s, RP: newRedisPool(s, logger)} return p } @@ -59,12 +59,16 @@ func (p *Pool) Get() redis.Conn { return conn } +func (p *Pool) Auto() redis.Conn { + return &autoConn{p, nil} +} + // NOTE: redis connection logs all operations type rdb struct { orig redis.Conn name string logger *colog.CoLog - s Settings + s RdSettings } func (rd *rdb) Close() error { @@ -121,7 +125,7 @@ func (rd *rdb) Receive() (interface{}, error) { return rd.orig.Receive() } -func newRedisPool(s Settings, logger *colog.CoLog) *redis.Pool { +func newRedisPool(s RdSettings, logger *colog.CoLog) *redis.Pool { maxIdle := s.MaxIdle if maxIdle == 0 { @@ -396,14 +400,57 @@ type tracked struct { subj redis.Conn } -func (t tracked) Close() error { +func (t *tracked) Close() error { res_tracker.Untrack(t) return t.subj.Close() } -func (t tracked) Do(cmd string, args ...interface{}) (interface{}, error) { +func (t *tracked) Do(cmd string, args ...interface{}) (interface{}, error) { return t.subj.Do(cmd, args...) } -func (t tracked) Send(cmd string, args ...interface{}) error { return t.subj.Send(cmd, args...) } -func (t tracked) Err() error { return t.subj.Err() } -func (t tracked) Flush() error { return t.subj.Flush() } -func (t tracked) Receive() (interface{}, error) { return t.subj.Receive() } +func (t *tracked) Send(cmd string, args ...interface{}) error { return t.subj.Send(cmd, args...) } +func (t *tracked) Err() error { return t.subj.Err() } +func (t *tracked) Flush() error { return t.subj.Flush() } +func (t *tracked) Receive() (interface{}, error) { return t.subj.Receive() } + +type autoConn struct { + p *Pool + pipe redis.Conn +} + +func (c *autoConn) Close() error { + return nil +} +func (c *autoConn) Do(cmd string, args ...interface{}) (interface{}, error) { + if c.pipe != nil { + return nil, errors.New("There is an active pipeline") + } + + rc := c.p.Get() + defer rc.Close() + return rc.Do(cmd, args...) +} + +func (c *autoConn) Send(cmd string, args ...interface{}) error { + if c.pipe == nil { + c.pipe = c.p.Get() + } + return c.pipe.Send(cmd, args...) +} +func (c *autoConn) Err() error { + if c.pipe != nil { + return c.pipe.Err() + } + return nil +} +func (c *autoConn) Flush() error { + if c.pipe != nil { + return c.pipe.Flush() + } + return errors.New("There is no active pipeline") +} +func (c *autoConn) Receive() (interface{}, error) { + if c.pipe != nil { + return c.pipe.Receive() + } + return nil, errors.New("There is no active pipeline") +} diff --git a/rdb_test.go b/rdb_test.go index e19dba0..69275d2 100644 --- a/rdb_test.go +++ b/rdb_test.go @@ -10,7 +10,7 @@ import ( ) func getPool() *rdb.Pool { - pool := rdb.OpenPool(rdb.Settings{Host: "localhost", Port: 6379, Db: 10, ClientName: "test"}, getLogger()) + pool := rdb.OpenPool(rdb.RdSettings{Host: "localhost", Port: 6379, Db: 10, ClientName: "test"}, getLogger()) return pool } @@ -23,7 +23,7 @@ func getLogger() *colog.CoLog { } func TestBadConnection(t *testing.T) { - pool := rdb.OpenPool(rdb.Settings{Host: "dummy", Port: 80, ClientName: "test"}, getLogger()) + pool := rdb.OpenPool(rdb.RdSettings{Host: "dummy", Port: 80, ClientName: "test"}, getLogger()) conn := pool.Get() assert.NotNil(t, conn) _, err := conn.Do("PING")