Move stuff around
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
package bench
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchLoadSliceOfStructs(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package dbc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type (
|
||||
// BeforeCallback is a kind of function that can be called before a query is
|
||||
// executed.
|
||||
BeforeCallback func(ctx context.Context, query string)
|
||||
// AfterCallback is a kind of function that can be called after a query was
|
||||
// executed.
|
||||
AfterCallback func(ctx context.Context, query string, took time.Duration, err error)
|
||||
)
|
||||
|
||||
type callbacks struct {
|
||||
before []BeforeCallback
|
||||
after []AfterCallback
|
||||
}
|
||||
|
||||
func (c *callbacks) addBefore(cb BeforeCallback) {
|
||||
c.before = append(c.before, cb)
|
||||
}
|
||||
|
||||
func (c *callbacks) addAfter(cb AfterCallback) {
|
||||
c.after = append(c.after, cb)
|
||||
}
|
||||
|
||||
func (c *callbacks) callBefore(ctx context.Context, query string) {
|
||||
for _, cb := range c.before {
|
||||
cb(ctx, query)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *callbacks) callAfter(ctx context.Context, query string, took time.Duration, err error) {
|
||||
for _, cb := range c.after {
|
||||
cb(ctx, query, took, err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package dbc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCallChain(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mustExec(t, conn.
|
||||
Exec(ctx, "INSERT INTO sqldb_test (id, name) VALUES (3, 'Fred')").Then().
|
||||
Exec(ctx, "UPDATE sqldb_test SET name = 'Wilson' WHERE id = 3").Then().
|
||||
Exec(ctx, "DELETE FROM sqldb_test WHERE id = 3"),
|
||||
)
|
||||
}
|
||||
+126
@@ -0,0 +1,126 @@
|
||||
package dbc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/juju/errors"
|
||||
)
|
||||
|
||||
// Conn represents database connection.
|
||||
type Conn interface {
|
||||
connOrTx
|
||||
// Begin executes a transaction.
|
||||
Begin(context.Context, func(Tx) error) error
|
||||
// BeginCustom executes a transaction with provided options.
|
||||
BeginCustom(context.Context, func(Tx) error, *sql.TxOptions) error
|
||||
// Close closes the connection.
|
||||
Close() error
|
||||
// DB returns the underlying DB object.
|
||||
DB() *sql.DB
|
||||
// Before adds a callback function that would be called before a query is
|
||||
// executed.
|
||||
Before(BeforeCallback)
|
||||
// After adds a callback function that would be called after a query was
|
||||
// executed.
|
||||
After(AfterCallback)
|
||||
}
|
||||
|
||||
// Tx represents database transacation.
|
||||
type Tx interface {
|
||||
connOrTx
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
type dbWrapper struct {
|
||||
conn *sql.DB
|
||||
*caller
|
||||
}
|
||||
|
||||
// Flavor defines a kind of SQL database.
|
||||
type Flavor string
|
||||
|
||||
const (
|
||||
// MySQL is the MySQL SQL flavor.
|
||||
MySQL Flavor = "mysql"
|
||||
// PostgreSQL is the PostgreSQL SQL flavor.
|
||||
PostgreSQL Flavor = "postgresql"
|
||||
)
|
||||
|
||||
// Connect establishes a new database connection.
|
||||
func Connect(ctx context.Context, f Flavor, dsn string) (Conn, error) {
|
||||
conn, err := sql.Open(string(f), dsn)
|
||||
if err != nil {
|
||||
return nil, errors.Annotate(err, "Failed to establish connection")
|
||||
}
|
||||
err = conn.PingContext(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Annotate(err, "Connection is not responding")
|
||||
}
|
||||
return &dbWrapper{
|
||||
conn: conn,
|
||||
caller: &caller{
|
||||
db: conn,
|
||||
cb: &callbacks{},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *dbWrapper) Begin(ctx context.Context, fn func(tx Tx) error) error {
|
||||
return c.BeginCustom(ctx, fn, nil)
|
||||
}
|
||||
|
||||
func (c *dbWrapper) BeginCustom(ctx context.Context, fn func(tx Tx) error, opts *sql.TxOptions) error {
|
||||
if opts == nil {
|
||||
opts = &sql.TxOptions{}
|
||||
}
|
||||
tx, err := c.conn.BeginTx(ctx, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = fn(c.wrapTx(tx))
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *dbWrapper) DB() *sql.DB {
|
||||
return c.conn
|
||||
}
|
||||
|
||||
func (c *dbWrapper) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *dbWrapper) Before(cb BeforeCallback) {
|
||||
c.cb.addBefore(cb)
|
||||
}
|
||||
|
||||
func (c *dbWrapper) After(cb AfterCallback) {
|
||||
c.cb.addAfter(cb)
|
||||
}
|
||||
|
||||
func (c *dbWrapper) wrapTx(tx *sql.Tx) Tx {
|
||||
return &txWrapper{
|
||||
tx: tx,
|
||||
connOrTx: &caller{
|
||||
db: tx,
|
||||
cb: c.cb,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type txWrapper struct {
|
||||
tx *sql.Tx
|
||||
connOrTx
|
||||
}
|
||||
|
||||
func (w *txWrapper) Commit() error {
|
||||
return w.tx.Commit()
|
||||
}
|
||||
|
||||
func (w *txWrapper) Rollback() error {
|
||||
return w.tx.Rollback()
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package dbc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// ExecResult ...
|
||||
type ExecResult interface {
|
||||
Error() error
|
||||
LastInsertID() int64
|
||||
RowsAffected() int64
|
||||
Result() sql.Result
|
||||
Then() ExecChain
|
||||
}
|
||||
|
||||
type execResult struct {
|
||||
db executer
|
||||
err error
|
||||
res sql.Result
|
||||
}
|
||||
|
||||
func (r *execResult) Result() sql.Result {
|
||||
return r.res
|
||||
}
|
||||
|
||||
func (r *execResult) Error() error {
|
||||
return r.err
|
||||
}
|
||||
|
||||
func (r *execResult) LastInsertID() int64 {
|
||||
if r.res == nil {
|
||||
return 0
|
||||
}
|
||||
id, err := r.res.LastInsertId()
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func (r *execResult) RowsAffected() int64 {
|
||||
if r.res == nil {
|
||||
return 0
|
||||
}
|
||||
ra, err := r.res.RowsAffected()
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return ra
|
||||
}
|
||||
|
||||
func (r *execResult) Then() ExecChain {
|
||||
if r.err != nil {
|
||||
return &execChain{r.db}
|
||||
}
|
||||
return r.db
|
||||
}
|
||||
|
||||
//
|
||||
// Chain
|
||||
//
|
||||
|
||||
// ExecChain ...
|
||||
type ExecChain interface {
|
||||
executer
|
||||
}
|
||||
|
||||
type execChain struct {
|
||||
executer
|
||||
}
|
||||
|
||||
type brokenChain struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *brokenChain) Exec(_ context.Context, _ string, _ ...interface{}) ExecResult {
|
||||
return &execResult{err: c.err}
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package dbc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
)
|
||||
|
||||
type connOrTx interface {
|
||||
executer
|
||||
queryPerformer
|
||||
}
|
||||
|
||||
type executer interface {
|
||||
// Exec executes a query and does not expect any result.
|
||||
Exec(ctx context.Context, query string, args ...interface{}) ExecResult
|
||||
ExecNamed(ctx context.Context, query string, arg interface{}) ExecResult
|
||||
}
|
||||
|
||||
type queryPerformer interface {
|
||||
// Query executes a query and returns a result object that can later be used
|
||||
// to retrieve values.
|
||||
Query(ctx context.Context, query string, args ...interface{}) Rows
|
||||
QueryNamed(ctx context.Context, query string, arg interface{}) Rows
|
||||
}
|
||||
|
||||
type stdConnOrTx interface {
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
type caller struct {
|
||||
db stdConnOrTx
|
||||
cb *callbacks
|
||||
}
|
||||
|
||||
func (c *caller) Exec(ctx context.Context, query string, args ...interface{}) ExecResult {
|
||||
c.cb.callBefore(ctx, query)
|
||||
startedAt := time.Now()
|
||||
res, err := c.db.ExecContext(ctx, query, args...)
|
||||
c.cb.callAfter(ctx, query, time.Since(startedAt), err)
|
||||
|
||||
return &execResult{
|
||||
db: c,
|
||||
err: err,
|
||||
res: res,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *caller) ExecNamed(ctx context.Context, query string, arg interface{}) ExecResult {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *caller) Query(ctx context.Context, query string, args ...interface{}) Rows {
|
||||
c.cb.callBefore(ctx, query)
|
||||
startedAt := time.Now()
|
||||
r, err := c.db.QueryContext(ctx, query, args...)
|
||||
c.cb.callAfter(ctx, query, time.Since(startedAt), err)
|
||||
return &rows{
|
||||
err: err,
|
||||
rows: r,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *caller) QueryNamed(ctx context.Context, query string, arg interface{}) Rows {
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package dbc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestLoadSingleValue(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
exp := int(1)
|
||||
var out int
|
||||
mustQuery(t, conn.Query(ctx, "SELECT 1").Load(&out))
|
||||
if exp != out {
|
||||
t.Errorf("Value doesn't match: expected %d, got %d", exp, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSlice(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
exp := []int{1, 2}
|
||||
var out []int
|
||||
mustQuery(t, conn.Query(ctx, "SELECT id FROM sqldb_test").Load(&out))
|
||||
if !cmp.Equal(exp, out) {
|
||||
t.Errorf("Values dont't match: %s", cmp.Diff(exp, out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadMap(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
exp := map[string]interface{}{"id": int64(1), "name": "Alice"}
|
||||
var out map[string]interface{}
|
||||
mustQuery(t, conn.Query(ctx, "SELECT * FROM sqldb_test WHERE id = 1").Load(&out))
|
||||
if !cmp.Equal(exp, out) {
|
||||
t.Errorf("Record doesn't match: %s", cmp.Diff(exp, out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSliceOfMaps(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
exp := []map[string]interface{}{
|
||||
{"id": int64(1), "name": "Alice"},
|
||||
{"id": int64(2), "name": "Bob"},
|
||||
}
|
||||
var out []map[string]interface{}
|
||||
mustQuery(t, conn.Query(ctx, "SELECT * FROM sqldb_test ORDER BY id ASC").Load(&out))
|
||||
if !cmp.Equal(exp, out) {
|
||||
t.Errorf("Records don't match: %s", cmp.Diff(exp, out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadStruct(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
exp := record{ID: 1, Name: "Alice"}
|
||||
var out record
|
||||
mustQuery(t, conn.Query(ctx, "SELECT * FROM sqldb_test WHERE id = 1").Load(&out))
|
||||
if !cmp.Equal(exp, out) {
|
||||
t.Errorf("Record doesn't match: %s", cmp.Diff(exp, out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSliceOfStructs(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
exp := []record{
|
||||
{ID: 1, Name: "Alice"},
|
||||
{ID: 2, Name: "Bob"},
|
||||
}
|
||||
var out []record
|
||||
mustQuery(t, conn.Query(ctx, "SELECT * FROM sqldb_test").Load(&out))
|
||||
if !cmp.Equal(exp, out) {
|
||||
t.Errorf("Records don't match: %s", cmp.Diff(exp, out))
|
||||
}
|
||||
}
|
||||
+265
@@ -0,0 +1,265 @@
|
||||
package dbc
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/localhots/gobelt/reflect2"
|
||||
)
|
||||
|
||||
// Rows ...
|
||||
type Rows interface {
|
||||
Error() error
|
||||
Load(dest interface{}) error
|
||||
Rows() *sql.Rows
|
||||
}
|
||||
|
||||
const tagName = "db"
|
||||
|
||||
type rows struct {
|
||||
err error
|
||||
rows *sql.Rows
|
||||
}
|
||||
|
||||
func (r *rows) Rows() *sql.Rows {
|
||||
return r.rows
|
||||
}
|
||||
|
||||
func (r *rows) Error() error {
|
||||
return r.err
|
||||
}
|
||||
|
||||
func (r *rows) Load(dest interface{}) error {
|
||||
if r.err != nil {
|
||||
return r.err
|
||||
}
|
||||
defer r.rows.Close()
|
||||
|
||||
dtyp := reflect.TypeOf(dest)
|
||||
if dtyp.Kind() != reflect.Ptr {
|
||||
panic("Value must be a pointer")
|
||||
}
|
||||
dtyp = dtyp.Elem()
|
||||
|
||||
switch dtyp.Kind() {
|
||||
case reflect.Struct:
|
||||
r.loadStruct(dtyp, dest)
|
||||
case reflect.Map:
|
||||
r.loadMap(dest.(*map[string]interface{}))
|
||||
case reflect.Slice:
|
||||
switch dtyp.Elem().Kind() {
|
||||
case reflect.Struct:
|
||||
r.loadSliceOfStructs(dtyp, dest)
|
||||
case reflect.Map:
|
||||
r.loadSliceOfMaps(dest.(*[]map[string]interface{}))
|
||||
default:
|
||||
r.loadSlice(dtyp, dest)
|
||||
}
|
||||
default:
|
||||
r.loadValue(dest)
|
||||
}
|
||||
|
||||
if r.err == nil && r.rows.Err() != nil {
|
||||
return r.rows.Err()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *rows) loadValue(dest interface{}) {
|
||||
if r.rows.Next() {
|
||||
r.err = r.rows.Scan(dest)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *rows) loadSlice(typ reflect.Type, dest interface{}) {
|
||||
vSlice := reflect.MakeSlice(typ, 0, 0)
|
||||
for r.rows.Next() {
|
||||
val := reflect.New(typ.Elem())
|
||||
r.err = r.rows.Scan(val.Interface())
|
||||
if r.err != nil {
|
||||
return
|
||||
}
|
||||
vSlice = reflect.Append(vSlice, val.Elem())
|
||||
}
|
||||
reflect.ValueOf(dest).Elem().Set(vSlice)
|
||||
}
|
||||
|
||||
func (r *rows) loadMap(dest *map[string]interface{}) {
|
||||
if !r.rows.Next() {
|
||||
return
|
||||
}
|
||||
|
||||
cols, err := r.rows.Columns()
|
||||
if err != nil {
|
||||
r.err = err
|
||||
return
|
||||
}
|
||||
colTypes, err := r.rows.ColumnTypes()
|
||||
if err != nil {
|
||||
r.err = err
|
||||
return
|
||||
}
|
||||
|
||||
vals := make([]interface{}, len(cols))
|
||||
for i := range cols {
|
||||
vals[i] = newValue(colTypes[i])
|
||||
}
|
||||
err = r.rows.Scan(vals...)
|
||||
if err != nil {
|
||||
r.err = err
|
||||
return
|
||||
}
|
||||
|
||||
if *dest == nil {
|
||||
*dest = make(map[string]interface{}, len(cols))
|
||||
}
|
||||
for i, col := range cols {
|
||||
switch tval := vals[i].(type) {
|
||||
case *int64:
|
||||
(*dest)[col] = *tval
|
||||
case *string:
|
||||
(*dest)[col] = *tval
|
||||
case *bool:
|
||||
(*dest)[col] = *tval
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *rows) loadSliceOfMaps(dest *[]map[string]interface{}) {
|
||||
cols, err := r.rows.Columns()
|
||||
if err != nil {
|
||||
r.err = err
|
||||
return
|
||||
}
|
||||
colTypes, err := r.rows.ColumnTypes()
|
||||
if err != nil {
|
||||
r.err = err
|
||||
return
|
||||
}
|
||||
|
||||
if *dest == nil {
|
||||
*dest = make([]map[string]interface{}, 0)
|
||||
}
|
||||
for r.rows.Next() {
|
||||
vals := make([]interface{}, len(cols))
|
||||
for i := range cols {
|
||||
vals[i] = newValue(colTypes[i])
|
||||
}
|
||||
err = r.rows.Scan(vals...)
|
||||
if err != nil {
|
||||
r.err = err
|
||||
return
|
||||
}
|
||||
|
||||
row := make(map[string]interface{}, len(cols))
|
||||
for i, col := range cols {
|
||||
switch tval := vals[i].(type) {
|
||||
case *int64:
|
||||
row[col] = *tval
|
||||
case *string:
|
||||
row[col] = *tval
|
||||
case *bool:
|
||||
row[col] = *tval
|
||||
}
|
||||
}
|
||||
*dest = append(*dest, row)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *rows) loadStruct(typ reflect.Type, dest interface{}) {
|
||||
if !r.rows.Next() {
|
||||
return
|
||||
}
|
||||
|
||||
cols, err := r.rows.Columns()
|
||||
if err != nil {
|
||||
r.err = err
|
||||
return
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(dest).Elem()
|
||||
vals := make([]interface{}, len(cols))
|
||||
tm := reflect2.AssociateColumns(val.Type(), tagName, cols)
|
||||
for i := range cols {
|
||||
if fi, ok := tm[i]; ok {
|
||||
fval := val.Field(fi)
|
||||
vals[i] = reflect.New(fval.Type()).Interface()
|
||||
} else {
|
||||
var dummy interface{}
|
||||
vals[i] = &dummy
|
||||
}
|
||||
}
|
||||
|
||||
err = r.rows.Scan(vals...)
|
||||
if err != nil {
|
||||
r.err = err
|
||||
return
|
||||
}
|
||||
|
||||
for i, fi := range tm {
|
||||
fval := val.Field(fi)
|
||||
fval.Set(reflect.ValueOf(vals[i]).Elem())
|
||||
}
|
||||
}
|
||||
|
||||
func (r *rows) loadSliceOfStructs(typ reflect.Type, dest interface{}) {
|
||||
cols, err := r.rows.Columns()
|
||||
if err != nil {
|
||||
r.err = err
|
||||
return
|
||||
}
|
||||
|
||||
vSlice := reflect.ValueOf(dest).Elem()
|
||||
tSlice := vSlice.Type()
|
||||
tElem := tSlice.Elem()
|
||||
tm := reflect2.AssociateColumns(tElem, tagName, cols)
|
||||
|
||||
for r.rows.Next() {
|
||||
vals := make([]interface{}, len(cols))
|
||||
val := reflect.New(tElem).Elem()
|
||||
for i := range cols {
|
||||
if fi, ok := tm[i]; ok {
|
||||
fval := val.Field(fi)
|
||||
vals[i] = reflect.New(fval.Type()).Interface()
|
||||
} else {
|
||||
var dummy interface{}
|
||||
vals[i] = &dummy
|
||||
}
|
||||
}
|
||||
|
||||
err = r.rows.Scan(vals...)
|
||||
if err != nil {
|
||||
r.err = err
|
||||
return
|
||||
}
|
||||
|
||||
for i, fi := range tm {
|
||||
fval := val.Field(fi)
|
||||
fval.Set(reflect.ValueOf(vals[i]).Elem())
|
||||
}
|
||||
vSlice.Set(reflect.Append(vSlice, val))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *rows) withError(err error) Rows {
|
||||
r.err = err
|
||||
return r
|
||||
}
|
||||
|
||||
func newValue(typ *sql.ColumnType) interface{} {
|
||||
switch typ.DatabaseTypeName() {
|
||||
case "VARCHAR", "NVARCHAR", "TEXT":
|
||||
var s string
|
||||
return &s
|
||||
case "INT", "BIGINT":
|
||||
var i int64
|
||||
return &i
|
||||
case "BOOL":
|
||||
var b bool
|
||||
return &b
|
||||
default:
|
||||
panic(fmt.Errorf("Unsupported MySQL type: %s", typ.DatabaseTypeName()))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package dbc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql" // MySQL driver
|
||||
)
|
||||
|
||||
type record struct {
|
||||
ID uint `db:"id"`
|
||||
Name string `db:"name"`
|
||||
}
|
||||
|
||||
var conn Conn
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
dsn := flag.String("dsn", "", "Database source name")
|
||||
flag.Parse()
|
||||
if *dsn == "" {
|
||||
log.Println("Database source name is not provided, skipping package tests")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
log.Println("Establishing connection to the test database")
|
||||
ctx := context.Background()
|
||||
var err error
|
||||
conn, err = Connect(ctx, MySQL, *dsn)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to connect: %v\n", err)
|
||||
}
|
||||
|
||||
log.Println("Seeding database")
|
||||
mustExecMain(conn.Exec(ctx, `
|
||||
DROP TABLE IF EXISTS sqldb_test
|
||||
`))
|
||||
mustExecMain(conn.Exec(ctx, `
|
||||
CREATE TABLE sqldb_test (
|
||||
id int(11) UNSIGNED NOT NULL,
|
||||
name VARCHAR(10) DEFAULT '',
|
||||
PRIMARY KEY (id)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=ascii
|
||||
`))
|
||||
mustExecMain(conn.Exec(ctx, `
|
||||
INSERT INTO sqldb_test (id, name)
|
||||
VALUES
|
||||
(1, "Alice"),
|
||||
(2, "Bob")
|
||||
`))
|
||||
|
||||
fmt.Println("Starting test suite")
|
||||
exitCode := m.Run()
|
||||
log.Println("Test suite finished")
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Printf("Failed to close connection: %v\n", err)
|
||||
}
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
func mustExec(t *testing.T, r ExecResult) ExecResult {
|
||||
t.Helper()
|
||||
if r.Error() != nil {
|
||||
t.Fatalf("Exec failed: %v", r.Error())
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func mustExecMain(r ExecResult) ExecResult {
|
||||
if r.Error() != nil {
|
||||
log.Fatalf("Query failed: %v\n", r.Error())
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func mustQuery(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
t.Fatalf("Query failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
|
||||
"github.com/localhots/gobelt/dbc"
|
||||
"github.com/localhots/gobelt/log"
|
||||
)
|
||||
|
||||
func main() {
|
||||
dsn := flag.String("dsn", "", "Database source name")
|
||||
flag.Parse()
|
||||
|
||||
ctx := context.Background()
|
||||
conn, err := dbc.Connect(ctx, dbc.MySQL, *dsn)
|
||||
if err != nil {
|
||||
log.Fatal(ctx, "Failed to establish database conneciton", log.F{
|
||||
"dsn": dsn,
|
||||
"error": err,
|
||||
})
|
||||
}
|
||||
|
||||
_ = conn
|
||||
}
|
||||
Reference in New Issue
Block a user