upgrade from go-pg to bun
This commit is contained in:
parent
f59f129354
commit
1886be62bc
23 changed files with 415 additions and 385 deletions
|
|
@ -1,19 +1,22 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/dhax/go-base/auth/jwt"
|
||||
"github.com/dhax/go-base/auth/pwdless"
|
||||
"github.com/dhax/go-base/models"
|
||||
"github.com/go-pg/pg"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
// AccountStore implements database operations for account management by user.
|
||||
type AccountStore struct {
|
||||
db *pg.DB
|
||||
db *bun.DB
|
||||
}
|
||||
|
||||
// NewAccountStore returns an AccountStore.
|
||||
func NewAccountStore(db *pg.DB) *AccountStore {
|
||||
func NewAccountStore(db *bun.DB) *AccountStore {
|
||||
return &AccountStore{
|
||||
db: db,
|
||||
}
|
||||
|
|
@ -21,52 +24,71 @@ func NewAccountStore(db *pg.DB) *AccountStore {
|
|||
|
||||
// Get an account by ID.
|
||||
func (s *AccountStore) Get(id int) (*pwdless.Account, error) {
|
||||
a := pwdless.Account{ID: id}
|
||||
err := s.db.Model(&a).
|
||||
Where("account.id = ?id").
|
||||
Column("account.*", "Token").
|
||||
First()
|
||||
return &a, err
|
||||
a := &pwdless.Account{ID: id}
|
||||
err := s.db.NewSelect().
|
||||
Model(a).
|
||||
Where("id = ?", id).
|
||||
Relation("Token").
|
||||
Scan(context.Background())
|
||||
return a, err
|
||||
}
|
||||
|
||||
// Update an account.
|
||||
func (s *AccountStore) Update(a *pwdless.Account) error {
|
||||
_, err := s.db.Model(a).
|
||||
_, err := s.db.NewUpdate().
|
||||
Model(a).
|
||||
Column("email", "name").
|
||||
WherePK().
|
||||
Update()
|
||||
Exec(context.Background())
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete an account.
|
||||
func (s *AccountStore) Delete(a *pwdless.Account) error {
|
||||
err := s.db.RunInTransaction(func(tx *pg.Tx) error {
|
||||
if _, err := tx.Model(&jwt.Token{}).
|
||||
Where("account_id = ?", a.ID).
|
||||
Delete(); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Model(&models.Profile{}).
|
||||
Where("account_id = ?", a.ID).
|
||||
Delete(); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Delete(a)
|
||||
})
|
||||
return err
|
||||
ctx := context.Background()
|
||||
tx, err := s.db.BeginTx(ctx, &sql.TxOptions{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.NewDelete().
|
||||
Model((*jwt.Token)(nil)).
|
||||
Where("account_id = ?", a.ID).
|
||||
Exec(ctx); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
if _, err := tx.NewDelete().
|
||||
Model((*models.Profile)(nil)).
|
||||
Where("account_id = ?", a.ID).
|
||||
Exec(ctx); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
if _, err := tx.NewDelete().
|
||||
Model(a).
|
||||
WherePK().
|
||||
Exec(ctx); err != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateToken updates a jwt refresh token.
|
||||
func (s *AccountStore) UpdateToken(t *jwt.Token) error {
|
||||
_, err := s.db.Model(t).
|
||||
_, err := s.db.NewUpdate().
|
||||
Model(t).
|
||||
Column("identifier").
|
||||
WherePK().
|
||||
Update()
|
||||
Exec(context.Background())
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteToken deletes a jwt refresh token.
|
||||
func (s *AccountStore) DeleteToken(t *jwt.Token) error {
|
||||
err := s.db.Delete(t)
|
||||
_, err := s.db.NewDelete().
|
||||
Model(t).
|
||||
WherePK().
|
||||
Exec(context.Background())
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,15 +1,15 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/url"
|
||||
|
||||
"github.com/dhax/go-base/auth/jwt"
|
||||
"github.com/dhax/go-base/auth/pwdless"
|
||||
"github.com/dhax/go-base/models"
|
||||
"github.com/go-pg/pg"
|
||||
"github.com/go-pg/pg/orm"
|
||||
"github.com/go-pg/pg/urlvalues"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -21,11 +21,11 @@ var (
|
|||
|
||||
// AdmAccountStore implements database operations for account management by admin.
|
||||
type AdmAccountStore struct {
|
||||
db *pg.DB
|
||||
db *bun.DB
|
||||
}
|
||||
|
||||
// NewAdmAccountStore returns an AccountStore.
|
||||
func NewAdmAccountStore(db *pg.DB) *AdmAccountStore {
|
||||
func NewAdmAccountStore(db *bun.DB) *AdmAccountStore {
|
||||
return &AdmAccountStore{
|
||||
db: db,
|
||||
}
|
||||
|
|
@ -33,8 +33,9 @@ func NewAdmAccountStore(db *pg.DB) *AdmAccountStore {
|
|||
|
||||
// AccountFilter provides pagination and filtering options on accounts.
|
||||
type AccountFilter struct {
|
||||
Pager *urlvalues.Pager
|
||||
Filter *urlvalues.Filter
|
||||
Limit int
|
||||
Offset int
|
||||
Filter map[string]interface{}
|
||||
Order []string
|
||||
}
|
||||
|
||||
|
|
@ -44,29 +45,47 @@ func NewAccountFilter(params interface{}) (*AccountFilter, error) {
|
|||
if !ok {
|
||||
return nil, ErrBadParams
|
||||
}
|
||||
p := urlvalues.Values(v)
|
||||
f := &AccountFilter{
|
||||
Pager: urlvalues.NewPager(p),
|
||||
Filter: urlvalues.NewFilter(p),
|
||||
Order: p["order"],
|
||||
Limit: 10, // Default limit
|
||||
Offset: 0, // Default offset
|
||||
Filter: make(map[string]interface{}),
|
||||
Order: v["order"],
|
||||
}
|
||||
// Parse limit and offset
|
||||
if limit := v.Get("limit"); limit != "" {
|
||||
f.Limit = int(limit[0] - '0')
|
||||
}
|
||||
if offset := v.Get("offset"); offset != "" {
|
||||
f.Offset = int(offset[0] - '0')
|
||||
}
|
||||
// Parse filters
|
||||
for key, values := range v {
|
||||
if key != "limit" && key != "offset" && key != "order" {
|
||||
f.Filter[key] = values[0]
|
||||
}
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// Apply applies an AccountFilter on an orm.Query.
|
||||
func (f *AccountFilter) Apply(q *orm.Query) (*orm.Query, error) {
|
||||
q = q.Apply(f.Pager.Pagination)
|
||||
q = q.Apply(f.Filter.Filters)
|
||||
q = q.Order(f.Order...)
|
||||
return q, nil
|
||||
// Apply applies an AccountFilter on a bun.SelectQuery.
|
||||
func (f *AccountFilter) Apply(q *bun.SelectQuery) *bun.SelectQuery {
|
||||
q = q.Limit(f.Limit).Offset(f.Offset)
|
||||
for key, value := range f.Filter {
|
||||
q = q.Where("? = ?", bun.Ident(key), value)
|
||||
}
|
||||
for _, order := range f.Order {
|
||||
q = q.Order(order)
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
// List applies a filter and returns paginated array of matching results and total count.
|
||||
func (s *AdmAccountStore) List(f *AccountFilter) ([]pwdless.Account, int, error) {
|
||||
a := []pwdless.Account{}
|
||||
count, err := s.db.Model(&a).
|
||||
var a []pwdless.Account
|
||||
count, err := s.db.NewSelect().
|
||||
Model(&a).
|
||||
Apply(f.Apply).
|
||||
SelectAndCount()
|
||||
ScanAndCount(context.Background())
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
|
@ -75,55 +94,90 @@ func (s *AdmAccountStore) List(f *AccountFilter) ([]pwdless.Account, int, error)
|
|||
|
||||
// Create creates a new account.
|
||||
func (s *AdmAccountStore) Create(a *pwdless.Account) error {
|
||||
count, _ := s.db.Model(a).
|
||||
Where("email = ?email").
|
||||
Count()
|
||||
|
||||
if count != 0 {
|
||||
return ErrUniqueEmailConstraint
|
||||
exists, err := s.db.NewSelect().
|
||||
Model((*pwdless.Account)(nil)).
|
||||
Where("email = ?", a.Email).
|
||||
Exists(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := s.db.RunInTransaction(func(tx *pg.Tx) error {
|
||||
err := tx.Insert(a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p := &models.Profile{
|
||||
AccountID: a.ID,
|
||||
}
|
||||
return tx.Insert(p)
|
||||
})
|
||||
if exists {
|
||||
return ErrUniqueEmailConstraint
|
||||
}
|
||||
ctx := context.Background()
|
||||
tx, err := s.db.BeginTx(ctx, &sql.TxOptions{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.NewInsert().
|
||||
Model(a).
|
||||
Exec(ctx); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
p := &models.Profile{
|
||||
AccountID: a.ID,
|
||||
}
|
||||
if _, err := tx.NewInsert().
|
||||
Model(p).
|
||||
Exec(ctx); err != nil {
|
||||
tx.Rollback()
|
||||
|
||||
return err
|
||||
return err
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get account by ID.
|
||||
func (s *AdmAccountStore) Get(id int) (*pwdless.Account, error) {
|
||||
a := pwdless.Account{ID: id}
|
||||
err := s.db.Select(&a)
|
||||
return &a, err
|
||||
a := &pwdless.Account{ID: id}
|
||||
err := s.db.NewSelect().
|
||||
Model(a).
|
||||
WherePK().
|
||||
Scan(context.Background())
|
||||
return a, err
|
||||
}
|
||||
|
||||
// Update account.
|
||||
func (s *AdmAccountStore) Update(a *pwdless.Account) error {
|
||||
err := s.db.Update(a)
|
||||
_, err := s.db.NewUpdate().
|
||||
Model(a).
|
||||
WherePK().
|
||||
Exec(context.Background())
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete account.
|
||||
func (s *AdmAccountStore) Delete(a *pwdless.Account) error {
|
||||
err := s.db.RunInTransaction(func(tx *pg.Tx) error {
|
||||
if _, err := tx.Model(&jwt.Token{}).
|
||||
Where("account_id = ?", a.ID).
|
||||
Delete(); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Model(&models.Profile{}).
|
||||
Where("account_id = ?", a.ID).
|
||||
Delete(); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Delete(a)
|
||||
})
|
||||
return err
|
||||
ctx := context.Background()
|
||||
tx, err := s.db.BeginTx(ctx, &sql.TxOptions{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.NewDelete().
|
||||
Model((*jwt.Token)(nil)).
|
||||
Where("account_id = ?", a.ID).
|
||||
Exec(ctx); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
if _, err := tx.NewDelete().
|
||||
Model((*models.Profile)(nil)).
|
||||
Where("account_id = ?", a.ID).
|
||||
Exec(ctx); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
if _, err := tx.NewDelete().
|
||||
Model(a).
|
||||
WherePK().
|
||||
Exec(ctx); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,20 +1,21 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/dhax/go-base/auth/jwt"
|
||||
"github.com/dhax/go-base/auth/pwdless"
|
||||
"github.com/go-pg/pg"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
// AuthStore implements database operations for account pwdlessentication.
|
||||
type AuthStore struct {
|
||||
db *pg.DB
|
||||
db *bun.DB
|
||||
}
|
||||
|
||||
// NewAuthStore return an AuthStore.
|
||||
func NewAuthStore(db *pg.DB) *AuthStore {
|
||||
func NewAuthStore(db *bun.DB) *AuthStore {
|
||||
return &AuthStore{
|
||||
db: db,
|
||||
}
|
||||
|
|
@ -22,65 +23,74 @@ func NewAuthStore(db *pg.DB) *AuthStore {
|
|||
|
||||
// GetAccount returns an account by ID.
|
||||
func (s *AuthStore) GetAccount(id int) (*pwdless.Account, error) {
|
||||
a := pwdless.Account{ID: id}
|
||||
err := s.db.Model(&a).
|
||||
Column("account.*").
|
||||
Where("id = ?id").
|
||||
First()
|
||||
return &a, err
|
||||
a := &pwdless.Account{ID: id}
|
||||
err := s.db.NewSelect().
|
||||
Model(a).
|
||||
Where("id = ?", id).
|
||||
Scan(context.Background())
|
||||
return a, err
|
||||
}
|
||||
|
||||
// GetAccountByEmail returns an account by email.
|
||||
func (s *AuthStore) GetAccountByEmail(e string) (*pwdless.Account, error) {
|
||||
a := pwdless.Account{Email: e}
|
||||
err := s.db.Model(&a).
|
||||
a := &pwdless.Account{Email: e}
|
||||
err := s.db.NewSelect().
|
||||
Model(a).
|
||||
Column("id", "active", "email", "name").
|
||||
Where("email = ?email").
|
||||
First()
|
||||
return &a, err
|
||||
Where("email = ?", e).
|
||||
Scan(context.Background())
|
||||
return a, err
|
||||
}
|
||||
|
||||
// UpdateAccount upates account data related to pwdlessentication.
|
||||
func (s *AuthStore) UpdateAccount(a *pwdless.Account) error {
|
||||
_, err := s.db.Model(a).
|
||||
_, err := s.db.NewUpdate().
|
||||
Model(a).
|
||||
Column("last_login").
|
||||
WherePK().
|
||||
Update()
|
||||
Exec(context.Background())
|
||||
return err
|
||||
}
|
||||
|
||||
// GetToken returns refresh token by token identifier.
|
||||
func (s *AuthStore) GetToken(t string) (*jwt.Token, error) {
|
||||
token := jwt.Token{Token: t}
|
||||
err := s.db.Model(&token).
|
||||
Where("token = ?token").
|
||||
First()
|
||||
|
||||
return &token, err
|
||||
token := &jwt.Token{Token: t}
|
||||
err := s.db.NewSelect().
|
||||
Model(token).
|
||||
Where("token = ?", t).
|
||||
Scan(context.Background())
|
||||
return token, err
|
||||
}
|
||||
|
||||
// CreateOrUpdateToken creates or updates an existing refresh token.
|
||||
func (s *AuthStore) CreateOrUpdateToken(t *jwt.Token) error {
|
||||
var err error
|
||||
if t.ID == 0 {
|
||||
err = s.db.Insert(t)
|
||||
} else {
|
||||
err = s.db.Update(t)
|
||||
_, err := s.db.NewInsert().
|
||||
Model(t).
|
||||
Exec(context.Background())
|
||||
return err
|
||||
}
|
||||
_, err := s.db.NewUpdate().
|
||||
Model(t).
|
||||
WherePK().
|
||||
Exec(context.Background())
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteToken deletes a refresh token.
|
||||
func (s *AuthStore) DeleteToken(t *jwt.Token) error {
|
||||
err := s.db.Delete(t)
|
||||
_, err := s.db.NewDelete().
|
||||
Model(t).
|
||||
WherePK().
|
||||
Exec(context.Background())
|
||||
return err
|
||||
}
|
||||
|
||||
// PurgeExpiredToken deletes expired refresh token.
|
||||
func (s *AuthStore) PurgeExpiredToken() error {
|
||||
_, err := s.db.Model(&jwt.Token{}).
|
||||
_, err := s.db.NewDelete().
|
||||
Model((*jwt.Token)(nil)).
|
||||
Where("expiry < ?", time.Now()).
|
||||
Delete()
|
||||
|
||||
Exec(context.Background())
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,48 +0,0 @@
|
|||
package migrate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-pg/migrations"
|
||||
)
|
||||
|
||||
const bootstrapAdminAccount = `
|
||||
INSERT INTO accounts (id, email, name, active, roles)
|
||||
VALUES (DEFAULT, 'admin@boot.io', 'Admin Boot', true, '{admin}')
|
||||
`
|
||||
|
||||
const bootstrapUserAccount = `
|
||||
INSERT INTO accounts (id, email, name, active)
|
||||
VALUES (DEFAULT, 'user@boot.io', 'User Boot', true)
|
||||
`
|
||||
|
||||
func init() {
|
||||
up := []string{
|
||||
bootstrapAdminAccount,
|
||||
bootstrapUserAccount,
|
||||
}
|
||||
|
||||
down := []string{
|
||||
`TRUNCATE accounts CASCADE`,
|
||||
}
|
||||
|
||||
migrations.Register(func(db migrations.DB) error {
|
||||
fmt.Println("add bootstrap accounts")
|
||||
for _, q := range up {
|
||||
_, err := db.Exec(q)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}, func(db migrations.DB) error {
|
||||
fmt.Println("truncate accounts cascading")
|
||||
for _, q := range down {
|
||||
_, err := db.Exec(q)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
|
@ -1,63 +0,0 @@
|
|||
// Package migrate implements postgres migrations.
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/dhax/go-base/database"
|
||||
"github.com/go-pg/migrations"
|
||||
"github.com/go-pg/pg"
|
||||
)
|
||||
|
||||
// Migrate runs go-pg migrations
|
||||
func Migrate(args []string) {
|
||||
db, err := database.DBConn()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.RunInTransaction(func(tx *pg.Tx) error {
|
||||
oldVersion, newVersion, err := migrations.Run(tx, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if newVersion != oldVersion {
|
||||
log.Printf("migrated from version %d to %d\n", oldVersion, newVersion)
|
||||
} else {
|
||||
log.Printf("version is %d\n", oldVersion)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Reset runs reverts all migrations to version 0 and then applies all migrations to latest
|
||||
func Reset() {
|
||||
db, err := database.DBConn()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
version, err := migrations.Version(db)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.RunInTransaction(func(tx *pg.Tx) error {
|
||||
for version != 0 {
|
||||
oldVersion, newVersion, err := migrations.Run(tx, "down")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("migrated from version %d to %d\n", oldVersion, newVersion)
|
||||
version = newVersion
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,9 +1,10 @@
|
|||
package migrate
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-pg/migrations"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
const accountTable = `
|
||||
|
|
@ -43,7 +44,7 @@ func init() {
|
|||
`DROP TABLE accounts`,
|
||||
}
|
||||
|
||||
migrations.Register(func(db migrations.DB) error {
|
||||
Migrations.MustRegister(func(ctx context.Context, db *bun.DB) error {
|
||||
fmt.Println("creating initial tables")
|
||||
for _, q := range up {
|
||||
_, err := db.Exec(q)
|
||||
|
|
@ -52,7 +53,7 @@ func init() {
|
|||
}
|
||||
}
|
||||
return nil
|
||||
}, func(db migrations.DB) error {
|
||||
}, func(ctx context.Context, db *bun.DB) error {
|
||||
fmt.Println("dropping initial tables")
|
||||
for _, q := range down {
|
||||
_, err := db.Exec(q)
|
||||
1
database/migrations/2_bootstrap_users.tx.down.sql
Normal file
1
database/migrations/2_bootstrap_users.tx.down.sql
Normal file
|
|
@ -0,0 +1 @@
|
|||
TRUNCATE accounts CASCADE
|
||||
8
database/migrations/2_bootstrap_users.up.sql
Normal file
8
database/migrations/2_bootstrap_users.up.sql
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
INSERT INTO accounts (id, email, name, active, roles)
|
||||
VALUES (DEFAULT, 'admin@example.com', 'Admin Example', true, '{admin}');
|
||||
|
||||
--bun:split
|
||||
|
||||
INSERT INTO accounts (id, email, name, active)
|
||||
VALUES (DEFAULT, 'user@example.com', 'User Example', true);
|
||||
|
||||
|
|
@ -1,9 +1,10 @@
|
|||
package migrate
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-pg/migrations"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
const profileTable = `
|
||||
|
|
@ -30,7 +31,7 @@ func init() {
|
|||
`DROP TABLE profiles`,
|
||||
}
|
||||
|
||||
migrations.Register(func(db migrations.DB) error {
|
||||
Migrations.MustRegister(func(ctx context.Context, db *bun.DB) error {
|
||||
fmt.Println("create profile table")
|
||||
for _, q := range up {
|
||||
_, err := db.Exec(q)
|
||||
|
|
@ -39,7 +40,7 @@ func init() {
|
|||
}
|
||||
}
|
||||
return nil
|
||||
}, func(db migrations.DB) error {
|
||||
}, func(ctx context.Context, db *bun.DB) error {
|
||||
fmt.Println("drop profile table")
|
||||
for _, q := range down {
|
||||
_, err := db.Exec(q)
|
||||
49
database/migrations/main.go
Normal file
49
database/migrations/main.go
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
package migrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/dhax/go-base/database"
|
||||
"github.com/uptrace/bun/migrate"
|
||||
)
|
||||
|
||||
//go:embed *.sql
|
||||
var sqlMigrations embed.FS
|
||||
|
||||
var Migrations = migrate.NewMigrations()
|
||||
|
||||
func init() {
|
||||
if err := Migrations.Discover(sqlMigrations); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Migrate runs all migrations
|
||||
func Migrate() {
|
||||
db, err := database.DBConn()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
migrator := migrate.NewMigrator(db, Migrations)
|
||||
|
||||
err = migrator.Init(context.Background())
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
group, err := migrator.Migrate(context.Background())
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if group.ID == 0 {
|
||||
fmt.Printf("there are no new migrations to run\n")
|
||||
} else {
|
||||
fmt.Printf("migrated to %s\n", group)
|
||||
}
|
||||
}
|
||||
|
|
@ -2,54 +2,42 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"log"
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"github.com/go-pg/pg"
|
||||
"github.com/uptrace/bun"
|
||||
"github.com/uptrace/bun/dialect/pgdialect"
|
||||
"github.com/uptrace/bun/driver/pgdriver"
|
||||
"github.com/uptrace/bun/extra/bundebug"
|
||||
)
|
||||
|
||||
// DBConn returns a postgres connection pool.
|
||||
func DBConn() (*pg.DB, error) {
|
||||
func DBConn() (*bun.DB, error) {
|
||||
viper.SetDefault("db_network", "tcp")
|
||||
viper.SetDefault("db_addr", "localhost:5432")
|
||||
viper.SetDefault("db_user", "postgres")
|
||||
viper.SetDefault("db_password", "postgres")
|
||||
viper.SetDefault("db_database", "postgres")
|
||||
|
||||
db := pg.Connect(&pg.Options{
|
||||
Network: viper.GetString("db_network"),
|
||||
Addr: viper.GetString("db_addr"),
|
||||
User: viper.GetString("db_user"),
|
||||
Password: viper.GetString("db_password"),
|
||||
Database: viper.GetString("db_database"),
|
||||
})
|
||||
dsn := "postgres://" + viper.GetString("db_user") + ":" + viper.GetString("db_password") + "@" + viper.GetString("db_addr") + "/" + viper.GetString("db_database") + "?sslmode=disable"
|
||||
|
||||
sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn)))
|
||||
|
||||
db := bun.NewDB(sqldb, pgdialect.New())
|
||||
|
||||
if err := checkConn(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if viper.GetBool("db_debug") {
|
||||
db.AddQueryHook(&logSQL{})
|
||||
db.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true)))
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
type logSQL struct{}
|
||||
|
||||
func (l *logSQL) BeforeQuery(e *pg.QueryEvent) {}
|
||||
|
||||
func (l *logSQL) AfterQuery(e *pg.QueryEvent) {
|
||||
query, err := e.FormattedQuery()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
log.Println(query)
|
||||
}
|
||||
|
||||
func checkConn(db *pg.DB) error {
|
||||
func checkConn(db *bun.DB) error {
|
||||
var n int
|
||||
_, err := db.QueryOne(pg.Scan(&n), "SELECT 1")
|
||||
return err
|
||||
return db.NewSelect().ColumnExpr("1").Scan(context.Background(), &n)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,17 +1,20 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/dhax/go-base/models"
|
||||
"github.com/go-pg/pg"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
// ProfileStore implements database operations for profile management.
|
||||
type ProfileStore struct {
|
||||
db *pg.DB
|
||||
db *bun.DB
|
||||
}
|
||||
|
||||
// NewProfileStore returns a ProfileStore implementation.
|
||||
func NewProfileStore(db *pg.DB) *ProfileStore {
|
||||
func NewProfileStore(db *bun.DB) *ProfileStore {
|
||||
return &ProfileStore{
|
||||
db: db,
|
||||
}
|
||||
|
|
@ -19,16 +22,26 @@ func NewProfileStore(db *pg.DB) *ProfileStore {
|
|||
|
||||
// Get gets an profile by account ID.
|
||||
func (s *ProfileStore) Get(accountID int) (*models.Profile, error) {
|
||||
p := models.Profile{AccountID: accountID}
|
||||
_, err := s.db.Model(&p).
|
||||
p := &models.Profile{AccountID: accountID}
|
||||
err := s.db.NewSelect().
|
||||
Model(p).
|
||||
Where("account_id = ?", accountID).
|
||||
SelectOrInsert()
|
||||
Scan(context.Background())
|
||||
|
||||
return &p, err
|
||||
if err == sql.ErrNoRows {
|
||||
_, err = s.db.NewInsert().
|
||||
Model(p).
|
||||
Exec(context.Background())
|
||||
}
|
||||
|
||||
return p, err
|
||||
}
|
||||
|
||||
// Update updates profile.
|
||||
func (s *ProfileStore) Update(p *models.Profile) error {
|
||||
err := s.db.Update(p)
|
||||
_, err := s.db.NewUpdate().
|
||||
Model(p).
|
||||
WherePK().
|
||||
Exec(context.Background())
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue