1
0
Fork 0

Move stuff around

This commit is contained in:
Gregory Eremin 2018-07-07 14:11:23 +02:00
parent 9b845584ed
commit 6c6ccaf0b3
No known key found for this signature in database
GPG Key ID: 8CB79D42167BEB7F
10 changed files with 102 additions and 50 deletions

11
dbc/bench/dbc_test.go Normal file
View File

@ -0,0 +1,11 @@
package bench
import (
"testing"
)
func BenchLoadSliceOfStructs(b *testing.B) {
for i := 0; i < b.N; i++ {
}
}

View File

@ -1,4 +1,4 @@
package sqldb package dbc
import ( import (
"context" "context"

View File

@ -1,4 +1,4 @@
package sqldb package dbc
import ( import (
"context" "context"
@ -7,7 +7,7 @@ import (
func TestCallChain(t *testing.T) { func TestCallChain(t *testing.T) {
ctx := context.Background() ctx := context.Background()
mustT(t, conn. mustExec(t, conn.
Exec(ctx, "INSERT INTO sqldb_test (id, name) VALUES (3, 'Fred')").Then(). Exec(ctx, "INSERT INTO sqldb_test (id, name) VALUES (3, 'Fred')").Then().
Exec(ctx, "UPDATE sqldb_test SET name = 'Wilson' WHERE id = 3").Then(). Exec(ctx, "UPDATE sqldb_test SET name = 'Wilson' WHERE id = 3").Then().
Exec(ctx, "DELETE FROM sqldb_test WHERE id = 3"), Exec(ctx, "DELETE FROM sqldb_test WHERE id = 3"),

View File

@ -1,4 +1,4 @@
package sqldb package dbc
import ( import (
"context" "context"
@ -10,10 +10,19 @@ import (
// Conn represents database connection. // Conn represents database connection.
type Conn interface { type Conn interface {
connOrTx connOrTx
// Begin executes a transaction.
Begin(context.Context, func(Tx) error) error Begin(context.Context, func(Tx) error) error
// BeginCustom executes a transaction with provided options.
BeginCustom(context.Context, func(Tx) error, *sql.TxOptions) error
// Close closes the connection.
Close() error Close() error
// DB returns the underlying DB object.
DB() *sql.DB DB() *sql.DB
// Before adds a callback function that would be called before a query is
// executed.
Before(BeforeCallback) Before(BeforeCallback)
// After adds a callback function that would be called after a query was
// executed.
After(AfterCallback) After(AfterCallback)
} }
@ -58,9 +67,15 @@ func Connect(ctx context.Context, f Flavor, dsn string) (Conn, error) {
}, nil }, nil
} }
// Begin executes a transaction.
func (c *dbWrapper) Begin(ctx context.Context, fn func(tx Tx) error) error { func (c *dbWrapper) Begin(ctx context.Context, fn func(tx Tx) error) error {
tx, err := c.conn.BeginTx(ctx, &sql.TxOptions{}) return c.BeginCustom(ctx, fn, nil)
}
func (c *dbWrapper) BeginCustom(ctx context.Context, fn func(tx Tx) error, opts *sql.TxOptions) error {
if opts == nil {
opts = &sql.TxOptions{}
}
tx, err := c.conn.BeginTx(ctx, opts)
if err != nil { if err != nil {
return err return err
} }
@ -71,24 +86,18 @@ func (c *dbWrapper) Begin(ctx context.Context, fn func(tx Tx) error) error {
return err return err
} }
// DB returns the underlying DB object.
func (c *dbWrapper) DB() *sql.DB { func (c *dbWrapper) DB() *sql.DB {
return c.conn return c.conn
} }
// Close closes the connection.
func (c *dbWrapper) Close() error { func (c *dbWrapper) Close() error {
return c.conn.Close() return c.conn.Close()
} }
// Before adds a callback function that would be called before a query is
// executed.
func (c *dbWrapper) Before(cb BeforeCallback) { func (c *dbWrapper) Before(cb BeforeCallback) {
c.cb.addBefore(cb) c.cb.addBefore(cb)
} }
// After adds a callback function that would be called after a query was
// executed.
func (c *dbWrapper) After(cb AfterCallback) { func (c *dbWrapper) After(cb AfterCallback) {
c.cb.addAfter(cb) c.cb.addAfter(cb)
} }

View File

@ -1,4 +1,4 @@
package sqldb package dbc
import ( import (
"context" "context"

View File

@ -1,4 +1,4 @@
package sqldb package dbc
import ( import (
"context" "context"
@ -20,8 +20,8 @@ type executer interface {
type queryPerformer interface { type queryPerformer interface {
// Query executes a query and returns a result object that can later be used // Query executes a query and returns a result object that can later be used
// to retrieve values. // to retrieve values.
Query(ctx context.Context, query string, args ...interface{}) QueryResult Query(ctx context.Context, query string, args ...interface{}) Rows
QueryNamed(ctx context.Context, query string, arg interface{}) QueryResult QueryNamed(ctx context.Context, query string, arg interface{}) Rows
} }
type stdConnOrTx interface { type stdConnOrTx interface {
@ -51,17 +51,17 @@ func (c *caller) ExecNamed(ctx context.Context, query string, arg interface{}) E
return nil return nil
} }
func (c *caller) Query(ctx context.Context, query string, args ...interface{}) QueryResult { func (c *caller) Query(ctx context.Context, query string, args ...interface{}) Rows {
c.cb.callBefore(ctx, query) c.cb.callBefore(ctx, query)
startedAt := time.Now() startedAt := time.Now()
rows, err := c.db.QueryContext(ctx, query, args...) r, err := c.db.QueryContext(ctx, query, args...)
c.cb.callAfter(ctx, query, time.Since(startedAt), err) c.cb.callAfter(ctx, query, time.Since(startedAt), err)
return &queryResult{ return &rows{
err: err, err: err,
rows: rows, rows: r,
} }
} }
func (c *caller) QueryNamed(ctx context.Context, query string, arg interface{}) QueryResult { func (c *caller) QueryNamed(ctx context.Context, query string, arg interface{}) Rows {
return nil return nil
} }

View File

@ -1,4 +1,4 @@
package sqldb package dbc
import ( import (
"context" "context"
@ -11,7 +11,7 @@ func TestLoadSingleValue(t *testing.T) {
ctx := context.Background() ctx := context.Background()
exp := int(1) exp := int(1)
var out int var out int
mustT(t, conn.Query(ctx, "SELECT 1").Load(&out)) mustQuery(t, conn.Query(ctx, "SELECT 1").Load(&out))
if exp != out { if exp != out {
t.Errorf("Value doesn't match: expected %d, got %d", exp, out) t.Errorf("Value doesn't match: expected %d, got %d", exp, out)
} }
@ -21,7 +21,7 @@ func TestLoadSlice(t *testing.T) {
ctx := context.Background() ctx := context.Background()
exp := []int{1, 2} exp := []int{1, 2}
var out []int var out []int
mustT(t, conn.Query(ctx, "SELECT id FROM sqldb_test").Load(&out)) mustQuery(t, conn.Query(ctx, "SELECT id FROM sqldb_test").Load(&out))
if !cmp.Equal(exp, out) { if !cmp.Equal(exp, out) {
t.Errorf("Values dont't match: %s", cmp.Diff(exp, out)) t.Errorf("Values dont't match: %s", cmp.Diff(exp, out))
} }
@ -31,7 +31,7 @@ func TestLoadMap(t *testing.T) {
ctx := context.Background() ctx := context.Background()
exp := map[string]interface{}{"id": int64(1), "name": "Alice"} exp := map[string]interface{}{"id": int64(1), "name": "Alice"}
var out map[string]interface{} var out map[string]interface{}
mustT(t, conn.Query(ctx, "SELECT * FROM sqldb_test WHERE id = 1").Load(&out)) mustQuery(t, conn.Query(ctx, "SELECT * FROM sqldb_test WHERE id = 1").Load(&out))
if !cmp.Equal(exp, out) { if !cmp.Equal(exp, out) {
t.Errorf("Record doesn't match: %s", cmp.Diff(exp, out)) t.Errorf("Record doesn't match: %s", cmp.Diff(exp, out))
} }
@ -44,7 +44,7 @@ func TestLoadSliceOfMaps(t *testing.T) {
{"id": int64(2), "name": "Bob"}, {"id": int64(2), "name": "Bob"},
} }
var out []map[string]interface{} var out []map[string]interface{}
mustT(t, conn.Query(ctx, "SELECT * FROM sqldb_test ORDER BY id ASC").Load(&out)) mustQuery(t, conn.Query(ctx, "SELECT * FROM sqldb_test ORDER BY id ASC").Load(&out))
if !cmp.Equal(exp, out) { if !cmp.Equal(exp, out) {
t.Errorf("Records don't match: %s", cmp.Diff(exp, out)) t.Errorf("Records don't match: %s", cmp.Diff(exp, out))
} }
@ -54,7 +54,7 @@ func TestLoadStruct(t *testing.T) {
ctx := context.Background() ctx := context.Background()
exp := record{ID: 1, Name: "Alice"} exp := record{ID: 1, Name: "Alice"}
var out record var out record
mustT(t, conn.Query(ctx, "SELECT * FROM sqldb_test WHERE id = 1").Load(&out)) mustQuery(t, conn.Query(ctx, "SELECT * FROM sqldb_test WHERE id = 1").Load(&out))
if !cmp.Equal(exp, out) { if !cmp.Equal(exp, out) {
t.Errorf("Record doesn't match: %s", cmp.Diff(exp, out)) t.Errorf("Record doesn't match: %s", cmp.Diff(exp, out))
} }
@ -67,7 +67,7 @@ func TestLoadSliceOfStructs(t *testing.T) {
{ID: 2, Name: "Bob"}, {ID: 2, Name: "Bob"},
} }
var out []record var out []record
mustT(t, conn.Query(ctx, "SELECT * FROM sqldb_test").Load(&out)) mustQuery(t, conn.Query(ctx, "SELECT * FROM sqldb_test").Load(&out))
if !cmp.Equal(exp, out) { if !cmp.Equal(exp, out) {
t.Errorf("Records don't match: %s", cmp.Diff(exp, out)) t.Errorf("Records don't match: %s", cmp.Diff(exp, out))
} }

View File

@ -1,4 +1,4 @@
package sqldb package dbc
import ( import (
"database/sql" "database/sql"
@ -8,8 +8,8 @@ import (
"github.com/localhots/gobelt/reflect2" "github.com/localhots/gobelt/reflect2"
) )
// QueryResult ... // Rows ...
type QueryResult interface { type Rows interface {
Error() error Error() error
Load(dest interface{}) error Load(dest interface{}) error
Rows() *sql.Rows Rows() *sql.Rows
@ -17,20 +17,20 @@ type QueryResult interface {
const tagName = "db" const tagName = "db"
type queryResult struct { type rows struct {
err error err error
rows *sql.Rows rows *sql.Rows
} }
func (r *queryResult) Rows() *sql.Rows { func (r *rows) Rows() *sql.Rows {
return r.rows return r.rows
} }
func (r *queryResult) Error() error { func (r *rows) Error() error {
return r.err return r.err
} }
func (r *queryResult) Load(dest interface{}) error { func (r *rows) Load(dest interface{}) error {
if r.err != nil { if r.err != nil {
return r.err return r.err
} }
@ -67,13 +67,13 @@ func (r *queryResult) Load(dest interface{}) error {
return nil return nil
} }
func (r *queryResult) loadValue(dest interface{}) { func (r *rows) loadValue(dest interface{}) {
if r.rows.Next() { if r.rows.Next() {
r.err = r.rows.Scan(dest) r.err = r.rows.Scan(dest)
} }
} }
func (r *queryResult) loadSlice(typ reflect.Type, dest interface{}) { func (r *rows) loadSlice(typ reflect.Type, dest interface{}) {
vSlice := reflect.MakeSlice(typ, 0, 0) vSlice := reflect.MakeSlice(typ, 0, 0)
for r.rows.Next() { for r.rows.Next() {
val := reflect.New(typ.Elem()) val := reflect.New(typ.Elem())
@ -86,7 +86,7 @@ func (r *queryResult) loadSlice(typ reflect.Type, dest interface{}) {
reflect.ValueOf(dest).Elem().Set(vSlice) reflect.ValueOf(dest).Elem().Set(vSlice)
} }
func (r *queryResult) loadMap(dest *map[string]interface{}) { func (r *rows) loadMap(dest *map[string]interface{}) {
if !r.rows.Next() { if !r.rows.Next() {
return return
} }
@ -127,7 +127,7 @@ func (r *queryResult) loadMap(dest *map[string]interface{}) {
} }
} }
func (r *queryResult) loadSliceOfMaps(dest *[]map[string]interface{}) { func (r *rows) loadSliceOfMaps(dest *[]map[string]interface{}) {
cols, err := r.rows.Columns() cols, err := r.rows.Columns()
if err != nil { if err != nil {
r.err = err r.err = err
@ -168,7 +168,7 @@ func (r *queryResult) loadSliceOfMaps(dest *[]map[string]interface{}) {
} }
} }
func (r *queryResult) loadStruct(typ reflect.Type, dest interface{}) { func (r *rows) loadStruct(typ reflect.Type, dest interface{}) {
if !r.rows.Next() { if !r.rows.Next() {
return return
} }
@ -204,7 +204,7 @@ func (r *queryResult) loadStruct(typ reflect.Type, dest interface{}) {
} }
} }
func (r *queryResult) loadSliceOfStructs(typ reflect.Type, dest interface{}) { func (r *rows) loadSliceOfStructs(typ reflect.Type, dest interface{}) {
cols, err := r.rows.Columns() cols, err := r.rows.Columns()
if err != nil { if err != nil {
r.err = err r.err = err
@ -243,7 +243,7 @@ func (r *queryResult) loadSliceOfStructs(typ reflect.Type, dest interface{}) {
} }
} }
func (r *queryResult) withError(err error) *queryResult { func (r *rows) withError(err error) Rows {
r.err = err r.err = err
return r return r
} }

View File

@ -1,4 +1,4 @@
package sqldb package dbc
import ( import (
"context" "context"
@ -16,7 +16,7 @@ type record struct {
Name string `db:"name"` Name string `db:"name"`
} }
var conn *Conn var conn Conn
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
dsn := flag.String("dsn", "", "Database source name") dsn := flag.String("dsn", "", "Database source name")
@ -35,17 +35,17 @@ func TestMain(m *testing.M) {
} }
log.Println("Seeding database") log.Println("Seeding database")
must(conn.Exec(ctx, ` mustExecMain(conn.Exec(ctx, `
DROP TABLE IF EXISTS sqldb_test DROP TABLE IF EXISTS sqldb_test
`)) `))
must(conn.Exec(ctx, ` mustExecMain(conn.Exec(ctx, `
CREATE TABLE sqldb_test ( CREATE TABLE sqldb_test (
id int(11) UNSIGNED NOT NULL, id int(11) UNSIGNED NOT NULL,
name VARCHAR(10) DEFAULT '', name VARCHAR(10) DEFAULT '',
PRIMARY KEY (id) PRIMARY KEY (id)
) ENGINE=InnoDB DEFAULT CHARSET=ascii ) ENGINE=InnoDB DEFAULT CHARSET=ascii
`)) `))
must(conn.Exec(ctx, ` mustExecMain(conn.Exec(ctx, `
INSERT INTO sqldb_test (id, name) INSERT INTO sqldb_test (id, name)
VALUES VALUES
(1, "Alice"), (1, "Alice"),
@ -61,17 +61,24 @@ func TestMain(m *testing.M) {
os.Exit(exitCode) os.Exit(exitCode)
} }
func mustT(t *testing.T, r Result) Result { func mustExec(t *testing.T, r ExecResult) ExecResult {
t.Helper() t.Helper()
if r.Error() != nil { if r.Error() != nil {
t.Fatalf("Query failed: %v", r.Error()) t.Fatalf("Exec failed: %v", r.Error())
} }
return r return r
} }
func must(r Result) Result { func mustExecMain(r ExecResult) ExecResult {
if r.Error() != nil { if r.Error() != nil {
log.Fatalf("Query failed: %v\n", r.Error()) log.Fatalf("Query failed: %v\n", r.Error())
} }
return r return r
} }
func mustQuery(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Fatalf("Query failed: %v", err)
}
}

View File

@ -0,0 +1,25 @@
package main
import (
"context"
"flag"
"github.com/localhots/gobelt/dbc"
"github.com/localhots/gobelt/log"
)
func main() {
dsn := flag.String("dsn", "", "Database source name")
flag.Parse()
ctx := context.Background()
conn, err := dbc.Connect(ctx, dbc.MySQL, *dsn)
if err != nil {
log.Fatal(ctx, "Failed to establish database conneciton", log.F{
"dsn": dsn,
"error": err,
})
}
_ = conn
}