Move stuff around
This commit is contained in:
parent
9b845584ed
commit
6c6ccaf0b3
|
@ -0,0 +1,11 @@
|
||||||
|
package bench
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchLoadSliceOfStructs(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package sqldb
|
package dbc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
|
@ -1,4 +1,4 @@
|
||||||
package sqldb
|
package dbc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -7,7 +7,7 @@ import (
|
||||||
|
|
||||||
func TestCallChain(t *testing.T) {
|
func TestCallChain(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
mustT(t, conn.
|
mustExec(t, conn.
|
||||||
Exec(ctx, "INSERT INTO sqldb_test (id, name) VALUES (3, 'Fred')").Then().
|
Exec(ctx, "INSERT INTO sqldb_test (id, name) VALUES (3, 'Fred')").Then().
|
||||||
Exec(ctx, "UPDATE sqldb_test SET name = 'Wilson' WHERE id = 3").Then().
|
Exec(ctx, "UPDATE sqldb_test SET name = 'Wilson' WHERE id = 3").Then().
|
||||||
Exec(ctx, "DELETE FROM sqldb_test WHERE id = 3"),
|
Exec(ctx, "DELETE FROM sqldb_test WHERE id = 3"),
|
|
@ -1,4 +1,4 @@
|
||||||
package sqldb
|
package dbc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -10,10 +10,19 @@ import (
|
||||||
// Conn represents database connection.
|
// Conn represents database connection.
|
||||||
type Conn interface {
|
type Conn interface {
|
||||||
connOrTx
|
connOrTx
|
||||||
|
// Begin executes a transaction.
|
||||||
Begin(context.Context, func(Tx) error) error
|
Begin(context.Context, func(Tx) error) error
|
||||||
|
// BeginCustom executes a transaction with provided options.
|
||||||
|
BeginCustom(context.Context, func(Tx) error, *sql.TxOptions) error
|
||||||
|
// Close closes the connection.
|
||||||
Close() error
|
Close() error
|
||||||
|
// DB returns the underlying DB object.
|
||||||
DB() *sql.DB
|
DB() *sql.DB
|
||||||
|
// Before adds a callback function that would be called before a query is
|
||||||
|
// executed.
|
||||||
Before(BeforeCallback)
|
Before(BeforeCallback)
|
||||||
|
// After adds a callback function that would be called after a query was
|
||||||
|
// executed.
|
||||||
After(AfterCallback)
|
After(AfterCallback)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,9 +67,15 @@ func Connect(ctx context.Context, f Flavor, dsn string) (Conn, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Begin executes a transaction.
|
|
||||||
func (c *dbWrapper) Begin(ctx context.Context, fn func(tx Tx) error) error {
|
func (c *dbWrapper) Begin(ctx context.Context, fn func(tx Tx) error) error {
|
||||||
tx, err := c.conn.BeginTx(ctx, &sql.TxOptions{})
|
return c.BeginCustom(ctx, fn, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dbWrapper) BeginCustom(ctx context.Context, fn func(tx Tx) error, opts *sql.TxOptions) error {
|
||||||
|
if opts == nil {
|
||||||
|
opts = &sql.TxOptions{}
|
||||||
|
}
|
||||||
|
tx, err := c.conn.BeginTx(ctx, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -71,24 +86,18 @@ func (c *dbWrapper) Begin(ctx context.Context, fn func(tx Tx) error) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DB returns the underlying DB object.
|
|
||||||
func (c *dbWrapper) DB() *sql.DB {
|
func (c *dbWrapper) DB() *sql.DB {
|
||||||
return c.conn
|
return c.conn
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the connection.
|
|
||||||
func (c *dbWrapper) Close() error {
|
func (c *dbWrapper) Close() error {
|
||||||
return c.conn.Close()
|
return c.conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Before adds a callback function that would be called before a query is
|
|
||||||
// executed.
|
|
||||||
func (c *dbWrapper) Before(cb BeforeCallback) {
|
func (c *dbWrapper) Before(cb BeforeCallback) {
|
||||||
c.cb.addBefore(cb)
|
c.cb.addBefore(cb)
|
||||||
}
|
}
|
||||||
|
|
||||||
// After adds a callback function that would be called after a query was
|
|
||||||
// executed.
|
|
||||||
func (c *dbWrapper) After(cb AfterCallback) {
|
func (c *dbWrapper) After(cb AfterCallback) {
|
||||||
c.cb.addAfter(cb)
|
c.cb.addAfter(cb)
|
||||||
}
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package sqldb
|
package dbc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
|
@ -1,4 +1,4 @@
|
||||||
package sqldb
|
package dbc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -20,8 +20,8 @@ type executer interface {
|
||||||
type queryPerformer interface {
|
type queryPerformer interface {
|
||||||
// Query executes a query and returns a result object that can later be used
|
// Query executes a query and returns a result object that can later be used
|
||||||
// to retrieve values.
|
// to retrieve values.
|
||||||
Query(ctx context.Context, query string, args ...interface{}) QueryResult
|
Query(ctx context.Context, query string, args ...interface{}) Rows
|
||||||
QueryNamed(ctx context.Context, query string, arg interface{}) QueryResult
|
QueryNamed(ctx context.Context, query string, arg interface{}) Rows
|
||||||
}
|
}
|
||||||
|
|
||||||
type stdConnOrTx interface {
|
type stdConnOrTx interface {
|
||||||
|
@ -51,17 +51,17 @@ func (c *caller) ExecNamed(ctx context.Context, query string, arg interface{}) E
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *caller) Query(ctx context.Context, query string, args ...interface{}) QueryResult {
|
func (c *caller) Query(ctx context.Context, query string, args ...interface{}) Rows {
|
||||||
c.cb.callBefore(ctx, query)
|
c.cb.callBefore(ctx, query)
|
||||||
startedAt := time.Now()
|
startedAt := time.Now()
|
||||||
rows, err := c.db.QueryContext(ctx, query, args...)
|
r, err := c.db.QueryContext(ctx, query, args...)
|
||||||
c.cb.callAfter(ctx, query, time.Since(startedAt), err)
|
c.cb.callAfter(ctx, query, time.Since(startedAt), err)
|
||||||
return &queryResult{
|
return &rows{
|
||||||
err: err,
|
err: err,
|
||||||
rows: rows,
|
rows: r,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *caller) QueryNamed(ctx context.Context, query string, arg interface{}) QueryResult {
|
func (c *caller) QueryNamed(ctx context.Context, query string, arg interface{}) Rows {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package sqldb
|
package dbc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -11,7 +11,7 @@ func TestLoadSingleValue(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
exp := int(1)
|
exp := int(1)
|
||||||
var out int
|
var out int
|
||||||
mustT(t, conn.Query(ctx, "SELECT 1").Load(&out))
|
mustQuery(t, conn.Query(ctx, "SELECT 1").Load(&out))
|
||||||
if exp != out {
|
if exp != out {
|
||||||
t.Errorf("Value doesn't match: expected %d, got %d", exp, out)
|
t.Errorf("Value doesn't match: expected %d, got %d", exp, out)
|
||||||
}
|
}
|
||||||
|
@ -21,7 +21,7 @@ func TestLoadSlice(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
exp := []int{1, 2}
|
exp := []int{1, 2}
|
||||||
var out []int
|
var out []int
|
||||||
mustT(t, conn.Query(ctx, "SELECT id FROM sqldb_test").Load(&out))
|
mustQuery(t, conn.Query(ctx, "SELECT id FROM sqldb_test").Load(&out))
|
||||||
if !cmp.Equal(exp, out) {
|
if !cmp.Equal(exp, out) {
|
||||||
t.Errorf("Values dont't match: %s", cmp.Diff(exp, out))
|
t.Errorf("Values dont't match: %s", cmp.Diff(exp, out))
|
||||||
}
|
}
|
||||||
|
@ -31,7 +31,7 @@ func TestLoadMap(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
exp := map[string]interface{}{"id": int64(1), "name": "Alice"}
|
exp := map[string]interface{}{"id": int64(1), "name": "Alice"}
|
||||||
var out map[string]interface{}
|
var out map[string]interface{}
|
||||||
mustT(t, conn.Query(ctx, "SELECT * FROM sqldb_test WHERE id = 1").Load(&out))
|
mustQuery(t, conn.Query(ctx, "SELECT * FROM sqldb_test WHERE id = 1").Load(&out))
|
||||||
if !cmp.Equal(exp, out) {
|
if !cmp.Equal(exp, out) {
|
||||||
t.Errorf("Record doesn't match: %s", cmp.Diff(exp, out))
|
t.Errorf("Record doesn't match: %s", cmp.Diff(exp, out))
|
||||||
}
|
}
|
||||||
|
@ -44,7 +44,7 @@ func TestLoadSliceOfMaps(t *testing.T) {
|
||||||
{"id": int64(2), "name": "Bob"},
|
{"id": int64(2), "name": "Bob"},
|
||||||
}
|
}
|
||||||
var out []map[string]interface{}
|
var out []map[string]interface{}
|
||||||
mustT(t, conn.Query(ctx, "SELECT * FROM sqldb_test ORDER BY id ASC").Load(&out))
|
mustQuery(t, conn.Query(ctx, "SELECT * FROM sqldb_test ORDER BY id ASC").Load(&out))
|
||||||
if !cmp.Equal(exp, out) {
|
if !cmp.Equal(exp, out) {
|
||||||
t.Errorf("Records don't match: %s", cmp.Diff(exp, out))
|
t.Errorf("Records don't match: %s", cmp.Diff(exp, out))
|
||||||
}
|
}
|
||||||
|
@ -54,7 +54,7 @@ func TestLoadStruct(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
exp := record{ID: 1, Name: "Alice"}
|
exp := record{ID: 1, Name: "Alice"}
|
||||||
var out record
|
var out record
|
||||||
mustT(t, conn.Query(ctx, "SELECT * FROM sqldb_test WHERE id = 1").Load(&out))
|
mustQuery(t, conn.Query(ctx, "SELECT * FROM sqldb_test WHERE id = 1").Load(&out))
|
||||||
if !cmp.Equal(exp, out) {
|
if !cmp.Equal(exp, out) {
|
||||||
t.Errorf("Record doesn't match: %s", cmp.Diff(exp, out))
|
t.Errorf("Record doesn't match: %s", cmp.Diff(exp, out))
|
||||||
}
|
}
|
||||||
|
@ -67,7 +67,7 @@ func TestLoadSliceOfStructs(t *testing.T) {
|
||||||
{ID: 2, Name: "Bob"},
|
{ID: 2, Name: "Bob"},
|
||||||
}
|
}
|
||||||
var out []record
|
var out []record
|
||||||
mustT(t, conn.Query(ctx, "SELECT * FROM sqldb_test").Load(&out))
|
mustQuery(t, conn.Query(ctx, "SELECT * FROM sqldb_test").Load(&out))
|
||||||
if !cmp.Equal(exp, out) {
|
if !cmp.Equal(exp, out) {
|
||||||
t.Errorf("Records don't match: %s", cmp.Diff(exp, out))
|
t.Errorf("Records don't match: %s", cmp.Diff(exp, out))
|
||||||
}
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package sqldb
|
package dbc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
@ -8,8 +8,8 @@ import (
|
||||||
"github.com/localhots/gobelt/reflect2"
|
"github.com/localhots/gobelt/reflect2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// QueryResult ...
|
// Rows ...
|
||||||
type QueryResult interface {
|
type Rows interface {
|
||||||
Error() error
|
Error() error
|
||||||
Load(dest interface{}) error
|
Load(dest interface{}) error
|
||||||
Rows() *sql.Rows
|
Rows() *sql.Rows
|
||||||
|
@ -17,20 +17,20 @@ type QueryResult interface {
|
||||||
|
|
||||||
const tagName = "db"
|
const tagName = "db"
|
||||||
|
|
||||||
type queryResult struct {
|
type rows struct {
|
||||||
err error
|
err error
|
||||||
rows *sql.Rows
|
rows *sql.Rows
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *queryResult) Rows() *sql.Rows {
|
func (r *rows) Rows() *sql.Rows {
|
||||||
return r.rows
|
return r.rows
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *queryResult) Error() error {
|
func (r *rows) Error() error {
|
||||||
return r.err
|
return r.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *queryResult) Load(dest interface{}) error {
|
func (r *rows) Load(dest interface{}) error {
|
||||||
if r.err != nil {
|
if r.err != nil {
|
||||||
return r.err
|
return r.err
|
||||||
}
|
}
|
||||||
|
@ -67,13 +67,13 @@ func (r *queryResult) Load(dest interface{}) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *queryResult) loadValue(dest interface{}) {
|
func (r *rows) loadValue(dest interface{}) {
|
||||||
if r.rows.Next() {
|
if r.rows.Next() {
|
||||||
r.err = r.rows.Scan(dest)
|
r.err = r.rows.Scan(dest)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *queryResult) loadSlice(typ reflect.Type, dest interface{}) {
|
func (r *rows) loadSlice(typ reflect.Type, dest interface{}) {
|
||||||
vSlice := reflect.MakeSlice(typ, 0, 0)
|
vSlice := reflect.MakeSlice(typ, 0, 0)
|
||||||
for r.rows.Next() {
|
for r.rows.Next() {
|
||||||
val := reflect.New(typ.Elem())
|
val := reflect.New(typ.Elem())
|
||||||
|
@ -86,7 +86,7 @@ func (r *queryResult) loadSlice(typ reflect.Type, dest interface{}) {
|
||||||
reflect.ValueOf(dest).Elem().Set(vSlice)
|
reflect.ValueOf(dest).Elem().Set(vSlice)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *queryResult) loadMap(dest *map[string]interface{}) {
|
func (r *rows) loadMap(dest *map[string]interface{}) {
|
||||||
if !r.rows.Next() {
|
if !r.rows.Next() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -127,7 +127,7 @@ func (r *queryResult) loadMap(dest *map[string]interface{}) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *queryResult) loadSliceOfMaps(dest *[]map[string]interface{}) {
|
func (r *rows) loadSliceOfMaps(dest *[]map[string]interface{}) {
|
||||||
cols, err := r.rows.Columns()
|
cols, err := r.rows.Columns()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.err = err
|
r.err = err
|
||||||
|
@ -168,7 +168,7 @@ func (r *queryResult) loadSliceOfMaps(dest *[]map[string]interface{}) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *queryResult) loadStruct(typ reflect.Type, dest interface{}) {
|
func (r *rows) loadStruct(typ reflect.Type, dest interface{}) {
|
||||||
if !r.rows.Next() {
|
if !r.rows.Next() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -204,7 +204,7 @@ func (r *queryResult) loadStruct(typ reflect.Type, dest interface{}) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *queryResult) loadSliceOfStructs(typ reflect.Type, dest interface{}) {
|
func (r *rows) loadSliceOfStructs(typ reflect.Type, dest interface{}) {
|
||||||
cols, err := r.rows.Columns()
|
cols, err := r.rows.Columns()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.err = err
|
r.err = err
|
||||||
|
@ -243,7 +243,7 @@ func (r *queryResult) loadSliceOfStructs(typ reflect.Type, dest interface{}) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *queryResult) withError(err error) *queryResult {
|
func (r *rows) withError(err error) Rows {
|
||||||
r.err = err
|
r.err = err
|
||||||
return r
|
return r
|
||||||
}
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package sqldb
|
package dbc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -16,7 +16,7 @@ type record struct {
|
||||||
Name string `db:"name"`
|
Name string `db:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var conn *Conn
|
var conn Conn
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
dsn := flag.String("dsn", "", "Database source name")
|
dsn := flag.String("dsn", "", "Database source name")
|
||||||
|
@ -35,17 +35,17 @@ func TestMain(m *testing.M) {
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Println("Seeding database")
|
log.Println("Seeding database")
|
||||||
must(conn.Exec(ctx, `
|
mustExecMain(conn.Exec(ctx, `
|
||||||
DROP TABLE IF EXISTS sqldb_test
|
DROP TABLE IF EXISTS sqldb_test
|
||||||
`))
|
`))
|
||||||
must(conn.Exec(ctx, `
|
mustExecMain(conn.Exec(ctx, `
|
||||||
CREATE TABLE sqldb_test (
|
CREATE TABLE sqldb_test (
|
||||||
id int(11) UNSIGNED NOT NULL,
|
id int(11) UNSIGNED NOT NULL,
|
||||||
name VARCHAR(10) DEFAULT '',
|
name VARCHAR(10) DEFAULT '',
|
||||||
PRIMARY KEY (id)
|
PRIMARY KEY (id)
|
||||||
) ENGINE=InnoDB DEFAULT CHARSET=ascii
|
) ENGINE=InnoDB DEFAULT CHARSET=ascii
|
||||||
`))
|
`))
|
||||||
must(conn.Exec(ctx, `
|
mustExecMain(conn.Exec(ctx, `
|
||||||
INSERT INTO sqldb_test (id, name)
|
INSERT INTO sqldb_test (id, name)
|
||||||
VALUES
|
VALUES
|
||||||
(1, "Alice"),
|
(1, "Alice"),
|
||||||
|
@ -61,17 +61,24 @@ func TestMain(m *testing.M) {
|
||||||
os.Exit(exitCode)
|
os.Exit(exitCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
func mustT(t *testing.T, r Result) Result {
|
func mustExec(t *testing.T, r ExecResult) ExecResult {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
if r.Error() != nil {
|
if r.Error() != nil {
|
||||||
t.Fatalf("Query failed: %v", r.Error())
|
t.Fatalf("Exec failed: %v", r.Error())
|
||||||
}
|
}
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func must(r Result) Result {
|
func mustExecMain(r ExecResult) ExecResult {
|
||||||
if r.Error() != nil {
|
if r.Error() != nil {
|
||||||
log.Fatalf("Query failed: %v\n", r.Error())
|
log.Fatalf("Query failed: %v\n", r.Error())
|
||||||
}
|
}
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mustQuery(t *testing.T, err error) {
|
||||||
|
t.Helper()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Query failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"flag"
|
||||||
|
|
||||||
|
"github.com/localhots/gobelt/dbc"
|
||||||
|
"github.com/localhots/gobelt/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
dsn := flag.String("dsn", "", "Database source name")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
conn, err := dbc.Connect(ctx, dbc.MySQL, *dsn)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(ctx, "Failed to establish database conneciton", log.F{
|
||||||
|
"dsn": dsn,
|
||||||
|
"error": err,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = conn
|
||||||
|
}
|
Loading…
Reference in New Issue