package pg_test

import (
	"bytes"
	"database/sql/driver"
	"net"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/go-pg/pg"

	. "github.com/onsi/ginkgo"
	. "gopkg.in/check.v1"
)

// TestUnixSocket connects over a unix domain socket with TLS disabled and
// runs a trivial statement to check that the connection is usable.
func TestUnixSocket(t *testing.T) {
	opt := pgOptions()
	opt.Network = "unix"
	opt.Addr = "/var/run/postgresql/.s.PGSQL.5432"
	opt.TLSConfig = nil
	db := pg.Connect(opt)
	defer db.Close()

	_, err := db.Exec("SELECT 'test_unix_socket'")
	if err != nil {
		t.Fatal(err)
	}
}

// TestGocheck hooks the gocheck suites registered in this package into
// "go test".
func TestGocheck(t *testing.T) { TestingT(t) }

var _ = Suite(&DBTest{})

// DBTest is a gocheck suite; db is replaced with a fresh connection before
// every test (see SetUpTest).
type DBTest struct {
	db *pg.DB
}

// SetUpTest opens a new connection for the test that is about to run.
func (t *DBTest) SetUpTest(c *C) {
	t.db = pg.Connect(pgOptions())
}

// TearDownTest closes the connection opened by SetUpTest and asserts that
// closing succeeded.
func (t *DBTest) TearDownTest(c *C) {
	c.Assert(t.db.Close(), IsNil)
}

// TestQueryZeroRows checks that a query matching no rows succeeds and reports
// zero affected rows.
func (t *DBTest) TestQueryZeroRows(c *C) {
	res, err := t.db.Query(pg.Discard, "SELECT 1 WHERE 1 != 1")
	c.Assert(err, IsNil)
	c.Assert(res.RowsAffected(), Equals, 0)
}

// TestQueryOneErrNoRows checks that QueryOne returns pg.ErrNoRows when the
// query yields no rows.
func (t *DBTest) TestQueryOneErrNoRows(c *C) {
	_, err := t.db.QueryOne(pg.Discard, "SELECT 1 WHERE 1 != 1")
	c.Assert(err, Equals, pg.ErrNoRows)
}

// TestQueryOneErrMultiRows checks that QueryOne returns pg.ErrMultiRows when
// the query yields more than one row.
func (t *DBTest) TestQueryOneErrMultiRows(c *C) {
	_, err := t.db.QueryOne(pg.Discard, "SELECT generate_series(0, 1)")
	c.Assert(err, Equals, pg.ErrMultiRows)
}

// TestExecOne checks that ExecOne succeeds and reports one affected row for a
// single-row statement.
func (t *DBTest) TestExecOne(c *C) {
	res, err := t.db.ExecOne("SELECT 'test_exec_one'")
	c.Assert(err, IsNil)
	c.Assert(res.RowsAffected(), Equals, 1)
}

// TestExecOneErrNoRows checks that ExecOne returns pg.ErrNoRows when the
// statement yields no rows.
func (t *DBTest) TestExecOneErrNoRows(c *C) {
	_, err := t.db.ExecOne("SELECT 1 WHERE 1 != 1")
	c.Assert(err, Equals, pg.ErrNoRows)
}

// TestExecOneErrMultiRows checks that ExecOne returns pg.ErrMultiRows when the
// statement yields more than one row.
func (t *DBTest) TestExecOneErrMultiRows(c *C) {
	_, err := t.db.ExecOne("SELECT generate_series(0, 1)")
	c.Assert(err, Equals, pg.ErrMultiRows)
}

// TestScan checks that pg.Scan writes a single selected column into an int.
func (t *DBTest) TestScan(c *C) {
	var dst int
	_, err := t.db.QueryOne(pg.Scan(&dst), "SELECT 1")
	c.Assert(err, IsNil)
	c.Assert(dst, Equals, 1)
}

// TestExec checks the affected-row counts reported by Exec: -1 is expected
// for CREATE TEMP TABLE and 1 for a single-row INSERT.
func (t *DBTest) TestExec(c *C) {
	res, err := t.db.Exec("CREATE TEMP TABLE test(id serial PRIMARY KEY)")
	c.Assert(err, IsNil)
	c.Assert(res.RowsAffected(), Equals, -1)

	res, err = t.db.Exec("INSERT INTO test VALUES (1)")
	c.Assert(err, IsNil)
	c.Assert(res.RowsAffected(), Equals, 1)
}

// TestStatementExec checks that a prepared INSERT executed with one argument
// reports one affected row.
func (t *DBTest) TestStatementExec(c *C) {
	res, err := t.db.Exec("CREATE TEMP TABLE test(id serial PRIMARY KEY)")
	c.Assert(err, IsNil)
	c.Assert(res.RowsAffected(), Equals, -1)

	stmt, err := t.db.Prepare("INSERT INTO test VALUES($1)")
	c.Assert(err, IsNil)
	defer stmt.Close()

	res, err = stmt.Exec(1)
	c.Assert(err, IsNil)
	c.Assert(res.RowsAffected(), Equals, 1)
}

// TestLargeWriteRead sends a 1e6-byte parameter to the server and checks that
// the same bytes are read back.
func (t *DBTest) TestLargeWriteRead(c *C) {
	src := bytes.Repeat([]byte{0x1}, 1e6)

	var dst []byte
	_, err := t.db.QueryOne(pg.Scan(&dst), "SELECT ?", src)
	c.Assert(err, IsNil)
	c.Assert(dst, DeepEquals, src)
}

// TestIntegrityError raises unique_violation on the server and checks that
// the resulting error is a pg.Error reporting an integrity violation.
func (t *DBTest) TestIntegrityError(c *C) {
	_, err := t.db.Exec("DO $$BEGIN RAISE unique_violation USING MESSAGE='foo'; END$$;")

	// Use the comma-ok form so that a nil error or an error of another type
	// fails the assertion instead of panicking inside the test.
	pgErr, ok := err.(pg.Error)
	c.Assert(ok, Equals, true)
	c.Assert(pgErr.IntegrityViolation(), Equals, true)
}

// customStrSlice is a string slice that is written to the database as a
// single newline-joined string and read back by splitting on newlines.
type customStrSlice []string

// Value implements driver.Valuer by joining the elements with "\n".
func (s customStrSlice) Value() (driver.Value, error) {
	return strings.Join(s, "\n"), nil
}

// Scan populates s from v: nil yields a nil slice, empty bytes yield an empty
// non-nil slice, and anything else is split on "\n".
//
// NOTE(review): v is type-asserted to []byte without a check, so any other
// non-nil type panics.
func (s *customStrSlice) Scan(v interface{}) error {
	if v == nil {
		*s = nil
		return nil
	}

	b := v.([]byte)

	if len(b) == 0 {
		*s = []string{}
		return nil
	}

	*s = strings.Split(string(b), "\n")
	return nil
}

// TestScannerValueOnStruct round-trips a customStrSlice through the database
// into a struct field and checks that the value is unchanged.
func (t *DBTest) TestScannerValueOnStruct(c *C) {
	src := customStrSlice{"foo", "bar"}
	dst := struct{ Dst customStrSlice }{}
	_, err := t.db.QueryOne(&dst, "SELECT ? AS dst", src)
	c.Assert(err, IsNil)
	c.Assert(dst.Dst, DeepEquals, src)
}

//------------------------------------------------------------------------------

// badConnError is a string-backed error whose Timeout and Temporary methods
// both report false.
type badConnError string

// Error returns the error text.
func (e badConnError) Error() string { return string(e) }

// Timeout always reports false.
func (e badConnError) Timeout() bool { return false }

// Temporary always reports false.
func (e badConnError) Temporary() bool { return false }

// badConn is a net.Conn whose Read and Write never transfer data. Each call
// optionally sleeps for the configured delay and then returns the configured
// error, or a generic badConnError if none is set. The remaining net.Conn
// methods come from the embedded net.TCPConn.
type badConn struct {
	net.TCPConn

	readDelay, writeDelay time.Duration
	readErr, writeErr     error
}

var _ net.Conn = &badConn{}

// Read sleeps for readDelay if it is non-zero, then fails with readErr or,
// when readErr is nil, with badConnError("bad connection").
func (cn *badConn) Read([]byte) (int, error) {
	if cn.readDelay != 0 {
		time.Sleep(cn.readDelay)
	}
	if cn.readErr != nil {
		return 0, cn.readErr
	}
	return 0, badConnError("bad connection")
}

// Write sleeps for writeDelay if it is non-zero, then fails with writeErr or,
// when writeErr is nil, with badConnError("bad connection").
func (cn *badConn) Write([]byte) (int, error) {
	if cn.writeDelay != 0 {
		time.Sleep(cn.writeDelay)
	}
	if cn.writeErr != nil {
		return 0, cn.writeErr
	}
	return 0, badConnError("bad connection")
}

// perform calls every callback in cbs n times, each call in its own goroutine
// and receiving its iteration index (0..n-1), and returns once all calls have
// finished. Each goroutine defers GinkgoRecover.
func perform(n int, cbs ...func(int)) {
	var wg sync.WaitGroup
	for _, cb := range cbs {
		for i := 0; i < n; i++ {
			wg.Add(1)
			// cb and i are passed as arguments so that every goroutine works
			// on its own copies.
			go func(cb func(int), i int) {
				defer GinkgoRecover()
				defer wg.Done()

				cb(i)
			}(cb, i)
		}
	}
	wg.Wait()
}

// eventually calls fn repeatedly in a background goroutine, sleeping
// timeout/100 between attempts, until fn returns nil or timeout elapses.
// It returns nil as soon as an attempt succeeds; on timeout it returns the
// result of the most recently completed attempt, which is nil if no attempt
// has completed yet.
//
// The background goroutine checks the exit flag only between attempts, so an
// attempt that is still running when the timeout fires is not interrupted.
func eventually(fn func() error, timeout time.Duration) (err error) {
	var (
		mu      sync.Mutex
		lastErr error // result of the latest completed attempt; guarded by mu
		exit    int32 // set to 1 to ask the polling goroutine to stop
	)
	done := make(chan struct{})

	go func() {
		for atomic.LoadInt32(&exit) == 0 {
			attemptErr := fn()

			// Publish the result under the mutex: the caller may be reading it
			// concurrently in the timeout branch below.
			mu.Lock()
			lastErr = attemptErr
			mu.Unlock()

			if attemptErr == nil {
				close(done)
				return
			}
			time.Sleep(timeout / 100)
		}
	}()

	select {
	case <-done:
		return nil
	case <-time.After(timeout):
		atomic.StoreInt32(&exit, 1)
		mu.Lock()
		err = lastErr
		mu.Unlock()
		return err
	}
}