230 lines
4.8 KiB
Go
230 lines
4.8 KiB
Go
package pg_test
|
|
|
|
import (
|
|
"bytes"
|
|
"database/sql/driver"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/go-pg/pg"
|
|
|
|
. "github.com/onsi/ginkgo"
|
|
. "gopkg.in/check.v1"
|
|
)
|
|
|
|
func TestUnixSocket(t *testing.T) {
|
|
opt := pgOptions()
|
|
opt.Network = "unix"
|
|
opt.Addr = "/var/run/postgresql/.s.PGSQL.5432"
|
|
opt.TLSConfig = nil
|
|
db := pg.Connect(opt)
|
|
defer db.Close()
|
|
|
|
_, err := db.Exec("SELECT 'test_unix_socket'")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
// TestGocheck is the "go test" entry point that hands the *testing.T over to
// gocheck's TestingT.
func TestGocheck(t *testing.T) {
	TestingT(t)
}
|
|
|
|
var _ = Suite(&DBTest{})
|
|
|
|
type DBTest struct {
|
|
db *pg.DB
|
|
}
|
|
|
|
func (t *DBTest) SetUpTest(c *C) {
|
|
t.db = pg.Connect(pgOptions())
|
|
}
|
|
|
|
func (t *DBTest) TearDownTest(c *C) {
|
|
c.Assert(t.db.Close(), IsNil)
|
|
}
|
|
|
|
func (t *DBTest) TestQueryZeroRows(c *C) {
|
|
res, err := t.db.Query(pg.Discard, "SELECT 1 WHERE 1 != 1")
|
|
c.Assert(err, IsNil)
|
|
c.Assert(res.RowsAffected(), Equals, 0)
|
|
}
|
|
|
|
func (t *DBTest) TestQueryOneErrNoRows(c *C) {
|
|
_, err := t.db.QueryOne(pg.Discard, "SELECT 1 WHERE 1 != 1")
|
|
c.Assert(err, Equals, pg.ErrNoRows)
|
|
}
|
|
|
|
func (t *DBTest) TestQueryOneErrMultiRows(c *C) {
|
|
_, err := t.db.QueryOne(pg.Discard, "SELECT generate_series(0, 1)")
|
|
c.Assert(err, Equals, pg.ErrMultiRows)
|
|
}
|
|
|
|
func (t *DBTest) TestExecOne(c *C) {
|
|
res, err := t.db.ExecOne("SELECT 'test_exec_one'")
|
|
c.Assert(err, IsNil)
|
|
c.Assert(res.RowsAffected(), Equals, 1)
|
|
}
|
|
|
|
func (t *DBTest) TestExecOneErrNoRows(c *C) {
|
|
_, err := t.db.ExecOne("SELECT 1 WHERE 1 != 1")
|
|
c.Assert(err, Equals, pg.ErrNoRows)
|
|
}
|
|
|
|
func (t *DBTest) TestExecOneErrMultiRows(c *C) {
|
|
_, err := t.db.ExecOne("SELECT generate_series(0, 1)")
|
|
c.Assert(err, Equals, pg.ErrMultiRows)
|
|
}
|
|
|
|
func (t *DBTest) TestScan(c *C) {
|
|
var dst int
|
|
_, err := t.db.QueryOne(pg.Scan(&dst), "SELECT 1")
|
|
c.Assert(err, IsNil)
|
|
c.Assert(dst, Equals, 1)
|
|
}
|
|
|
|
func (t *DBTest) TestExec(c *C) {
|
|
res, err := t.db.Exec("CREATE TEMP TABLE test(id serial PRIMARY KEY)")
|
|
c.Assert(err, IsNil)
|
|
c.Assert(res.RowsAffected(), Equals, -1)
|
|
|
|
res, err = t.db.Exec("INSERT INTO test VALUES (1)")
|
|
c.Assert(err, IsNil)
|
|
c.Assert(res.RowsAffected(), Equals, 1)
|
|
}
|
|
|
|
func (t *DBTest) TestStatementExec(c *C) {
|
|
res, err := t.db.Exec("CREATE TEMP TABLE test(id serial PRIMARY KEY)")
|
|
c.Assert(err, IsNil)
|
|
c.Assert(res.RowsAffected(), Equals, -1)
|
|
|
|
stmt, err := t.db.Prepare("INSERT INTO test VALUES($1)")
|
|
c.Assert(err, IsNil)
|
|
defer stmt.Close()
|
|
|
|
res, err = stmt.Exec(1)
|
|
c.Assert(err, IsNil)
|
|
c.Assert(res.RowsAffected(), Equals, 1)
|
|
}
|
|
|
|
func (t *DBTest) TestLargeWriteRead(c *C) {
|
|
src := bytes.Repeat([]byte{0x1}, 1e6)
|
|
var dst []byte
|
|
_, err := t.db.QueryOne(pg.Scan(&dst), "SELECT ?", src)
|
|
c.Assert(err, IsNil)
|
|
c.Assert(dst, DeepEquals, src)
|
|
}
|
|
|
|
func (t *DBTest) TestIntegrityError(c *C) {
|
|
_, err := t.db.Exec("DO $$BEGIN RAISE unique_violation USING MESSAGE='foo'; END$$;")
|
|
c.Assert(err.(pg.Error).IntegrityViolation(), Equals, true)
|
|
}
|
|
|
|
// customStrSlice is a string slice that is stored as one newline-separated
// value, implemented through the Value and Scan methods below.
type customStrSlice []string

// Value joins the elements with "\n" and returns the resulting string.
func (s customStrSlice) Value() (driver.Value, error) {
	joined := strings.Join(s, "\n")
	return joined, nil
}

// Scan fills the receiver from v. A nil v yields a nil slice, an empty byte
// slice yields an empty non-nil slice, and anything else is split on "\n".
// v must be nil or a []byte; any other type makes the assertion panic.
func (s *customStrSlice) Scan(v interface{}) error {
	if v == nil {
		*s = nil
		return nil
	}

	raw := v.([]byte)

	parts := []string{}
	if len(raw) > 0 {
		parts = strings.Split(string(raw), "\n")
	}
	*s = parts
	return nil
}
|
|
|
|
func (t *DBTest) TestScannerValueOnStruct(c *C) {
|
|
src := customStrSlice{"foo", "bar"}
|
|
dst := struct{ Dst customStrSlice }{}
|
|
_, err := t.db.QueryOne(&dst, "SELECT ? AS dst", src)
|
|
c.Assert(err, IsNil)
|
|
c.Assert(dst.Dst, DeepEquals, src)
|
|
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
// badConnError is a string-backed error type. Its Error, Timeout and
// Temporary methods give it the method set of net.Error.
type badConnError string

// Error returns the underlying string as the error message.
func (e badConnError) Error() string {
	return string(e)
}

// Timeout always reports false.
func (e badConnError) Timeout() bool {
	return false
}

// Temporary always reports false.
func (e badConnError) Temporary() bool {
	return false
}
|
|
|
|
type badConn struct {
|
|
net.TCPConn
|
|
|
|
readDelay, writeDelay time.Duration
|
|
readErr, writeErr error
|
|
}
|
|
|
|
var _ net.Conn = &badConn{}
|
|
|
|
func (cn *badConn) Read([]byte) (int, error) {
|
|
if cn.readDelay != 0 {
|
|
time.Sleep(cn.readDelay)
|
|
}
|
|
if cn.readErr != nil {
|
|
return 0, cn.readErr
|
|
}
|
|
return 0, badConnError("bad connection")
|
|
}
|
|
|
|
func (cn *badConn) Write([]byte) (int, error) {
|
|
if cn.writeDelay != 0 {
|
|
time.Sleep(cn.writeDelay)
|
|
}
|
|
if cn.writeErr != nil {
|
|
return 0, cn.writeErr
|
|
}
|
|
return 0, badConnError("bad connection")
|
|
}
|
|
|
|
func perform(n int, cbs ...func(int)) {
|
|
var wg sync.WaitGroup
|
|
for _, cb := range cbs {
|
|
for i := 0; i < n; i++ {
|
|
wg.Add(1)
|
|
go func(cb func(int), i int) {
|
|
defer GinkgoRecover()
|
|
defer wg.Done()
|
|
|
|
cb(i)
|
|
}(cb, i)
|
|
}
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
// eventually calls fn repeatedly in a background goroutine, pausing
// timeout/100 between attempts, until fn returns nil or timeout elapses.
//
// It returns nil as soon as an attempt succeeds. On timeout it returns the
// error from the most recent completed attempt; if no attempt has completed
// by then, that is nil.
//
// After a timeout the goroutine is only asked to stop: it exits once the
// in-flight call to fn and the following pause have finished.
func eventually(fn func() error, timeout time.Duration) error {
	var (
		mu      sync.Mutex // guards lastErr, shared with the polling goroutine
		lastErr error
		exit    int32
	)
	done := make(chan struct{})

	go func() {
		for atomic.LoadInt32(&exit) == 0 {
			err := fn()
			if err == nil {
				close(done)
				return
			}

			// Publish the failure under the lock; the timeout branch below
			// may read it concurrently.
			mu.Lock()
			lastErr = err
			mu.Unlock()

			time.Sleep(timeout / 100)
		}
	}()

	select {
	case <-done:
		return nil
	case <-time.After(timeout):
		atomic.StoreInt32(&exit, 1)

		mu.Lock()
		defer mu.Unlock()
		return lastErr
	}
}
|