1
0
Fork 0
This commit is contained in:
Gregory Eremin 2018-06-25 00:22:27 +02:00
parent a4b74992c4
commit 32b5607235
No known key found for this signature in database
GPG Key ID: 8CB79D42167BEB7F
7 changed files with 300 additions and 107 deletions

View File

@ -11,21 +11,34 @@ func Require(key string, dest interface{}) {
configs[key] = dest configs[key] = dest
} }
// Load reads the config file and distributes provided configuration to // Load ...
func Load(src string) error {
return load(src, toml.Decode)
}
// LoadFile reads the config file and distributes provided configuration to
// requested destinations. // requested destinations.
func Load(path string) error { func LoadFile(path string) error {
var conf map[string]toml.Primitive return load(path, toml.DecodeFile)
meta, err := toml.DecodeFile(path, &conf) }
func load(from string, withFn func(string, interface{}) (toml.MetaData, error)) error {
var sections map[string]toml.Primitive
meta, err := withFn(from, &sections)
if err != nil { if err != nil {
return err return err
} }
return decode(meta, sections)
}
for key, prim := range conf { func decode(meta toml.MetaData, sections map[string]toml.Primitive) error {
dest, ok := configs[key] for section, conf := range sections {
dest, ok := configs[section]
if !ok { if !ok {
continue continue
} }
err := meta.PrimitiveDecode(prim, dest)
err := meta.PrimitiveDecode(conf, dest)
if err != nil { if err != nil {
return err return err
} }

40
sqldb/callback.go Normal file
View File

@ -0,0 +1,40 @@
package sqldb
import (
"context"
"time"
)
type (
// BeforeCallback is a kind of function that can be called before a query is
// executed.
BeforeCallback func(ctx context.Context, query string)
// AfterCallback is a kind of function that can be called after a query was
// executed.
AfterCallback func(ctx context.Context, query string, took time.Duration, err error)
)
type callbacks struct {
before []BeforeCallback
after []AfterCallback
}
func (c *callbacks) addBefore(cb BeforeCallback) {
c.before = append(c.before, cb)
}
func (c *callbacks) addAfter(cb AfterCallback) {
c.after = append(c.after, cb)
}
func (c *callbacks) callBefore(ctx context.Context, query string) {
for _, cb := range c.before {
cb(ctx, query)
}
}
func (c *callbacks) callAfter(ctx context.Context, query string, took time.Duration, err error) {
for _, cb := range c.after {
cb(ctx, query, took, err)
}
}

67
sqldb/caller.go Normal file
View File

@ -0,0 +1,67 @@
package sqldb
import (
"context"
"database/sql"
"time"
)
type connOrTx interface {
executer
queryPerformer
}
type executer interface {
// Exec executes a query and does not expect any result.
Exec(ctx context.Context, query string, args ...interface{}) ExecResult
ExecNamed(ctx context.Context, query string, arg interface{}) ExecResult
}
type queryPerformer interface {
// Query executes a query and returns a result object that can later be used
// to retrieve values.
Query(ctx context.Context, query string, args ...interface{}) QueryResult
QueryNamed(ctx context.Context, query string, arg interface{}) QueryResult
}
type stdConnOrTx interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}
type caller struct {
db stdConnOrTx
cb *callbacks
}
func (c *caller) Exec(ctx context.Context, query string, args ...interface{}) ExecResult {
c.cb.callBefore(ctx, query)
startedAt := time.Now()
res, err := c.db.ExecContext(ctx, query, args...)
c.cb.callAfter(ctx, query, time.Since(startedAt), err)
return &execResult{
db: c,
err: err,
res: res,
}
}
func (c *caller) ExecNamed(ctx context.Context, query string, arg interface{}) ExecResult {
return nil
}
func (c *caller) Query(ctx context.Context, query string, args ...interface{}) QueryResult {
c.cb.callBefore(ctx, query)
startedAt := time.Now()
rows, err := c.db.QueryContext(ctx, query, args...)
c.cb.callAfter(ctx, query, time.Since(startedAt), err)
return &queryResult{
err: err,
rows: rows,
}
}
func (c *caller) QueryNamed(ctx context.Context, query string, arg interface{}) QueryResult {
return nil
}

15
sqldb/chain_test.go Normal file
View File

@ -0,0 +1,15 @@
package sqldb
import (
"context"
"testing"
)
func TestCallChain(t *testing.T) {
ctx := context.Background()
mustT(t, conn.
Exec(ctx, "INSERT INTO sqldb_test (id, name) VALUES (3, 'Fred')").Then().
Exec(ctx, "UPDATE sqldb_test SET name = 'Wilson' WHERE id = 3").Then().
Exec(ctx, "DELETE FROM sqldb_test WHERE id = 3"),
)
}

View File

@ -3,27 +3,31 @@ package sqldb
import ( import (
"context" "context"
"database/sql" "database/sql"
"time"
"github.com/juju/errors" "github.com/juju/errors"
) )
// Conn represents database connection. // Conn represents database connection.
type Conn struct { type Conn interface {
db *sql.DB connOrTx
Begin(context.Context, func(Tx) error) error
beforeCallbacks []BeforeCallback Close() error
afterCallbacks []AfterCallback DB() *sql.DB
Before(BeforeCallback)
After(AfterCallback)
} }
type ( // Tx represents database transacation.
// BeforeCallback is a kind of function that can be called before a query is type Tx interface {
// executed. connOrTx
BeforeCallback func(ctx context.Context, query string) Commit() error
// AfterCallback is a kind of function that can be called after a query was Rollback() error
// executed. }
AfterCallback func(ctx context.Context, query string, took time.Duration, err error)
) type dbWrapper struct {
conn *sql.DB
*caller
}
// Flavor defines a kind of SQL database. // Flavor defines a kind of SQL database.
type Flavor string type Flavor string
@ -36,7 +40,7 @@ const (
) )
// Connect establishes a new database connection. // Connect establishes a new database connection.
func Connect(ctx context.Context, f Flavor, dsn string) (*Conn, error) { func Connect(ctx context.Context, f Flavor, dsn string) (Conn, error) {
conn, err := sql.Open(string(f), dsn) conn, err := sql.Open(string(f), dsn)
if err != nil { if err != nil {
return nil, errors.Annotate(err, "Failed to establish connection") return nil, errors.Annotate(err, "Failed to establish connection")
@ -45,62 +49,69 @@ func Connect(ctx context.Context, f Flavor, dsn string) (*Conn, error) {
if err != nil { if err != nil {
return nil, errors.Annotate(err, "Connection is not responding") return nil, errors.Annotate(err, "Connection is not responding")
} }
return &Conn{db: conn}, nil return &dbWrapper{
conn: conn,
caller: &caller{
db: conn,
cb: &callbacks{},
},
}, nil
} }
// Exec executes a query and does not expect any result. // Begin executes a transaction.
func (c *Conn) Exec(ctx context.Context, query string, args ...interface{}) Result { func (c *dbWrapper) Begin(ctx context.Context, fn func(tx Tx) error) error {
c.callBefore(ctx, query) tx, err := c.conn.BeginTx(ctx, &sql.TxOptions{})
startedAt := time.Now()
res, err := c.db.ExecContext(ctx, query, args...)
c.callAfter(ctx, query, time.Since(startedAt), err)
if err != nil { if err != nil {
return result{err: err} return err
} }
return result{res: res} err = fn(c.wrapTx(tx))
} if err != nil {
tx.Rollback()
// Query executes a query and returns a result object that can later be used to }
// retrieve values. return err
func (c *Conn) Query(ctx context.Context, query string, args ...interface{}) Result {
var r result
c.callBefore(ctx, query)
startedAt := time.Now()
r.rows, r.err = c.db.QueryContext(ctx, query, args...)
c.callAfter(ctx, query, time.Since(startedAt), r.err)
return r
} }
// DB returns the underlying DB object. // DB returns the underlying DB object.
func (c *Conn) DB() *sql.DB { func (c *dbWrapper) DB() *sql.DB {
return c.db return c.conn
} }
// Close closes the connection. // Close closes the connection.
func (c *Conn) Close() error { func (c *dbWrapper) Close() error {
return c.db.Close() return c.conn.Close()
} }
// Before adds a callback function that would be called before a query is // Before adds a callback function that would be called before a query is
// executed. // executed.
func (c *Conn) Before(cb BeforeCallback) { func (c *dbWrapper) Before(cb BeforeCallback) {
c.beforeCallbacks = append(c.beforeCallbacks, cb) c.cb.addBefore(cb)
} }
// After adds a callback function that would be called after a query was // After adds a callback function that would be called after a query was
// executed. // executed.
func (c *Conn) After(cb AfterCallback) { func (c *dbWrapper) After(cb AfterCallback) {
c.afterCallbacks = append(c.afterCallbacks, cb) c.cb.addAfter(cb)
} }
func (c *Conn) callBefore(ctx context.Context, query string) { func (c *dbWrapper) wrapTx(tx *sql.Tx) Tx {
for _, cb := range c.beforeCallbacks { return &txWrapper{
cb(ctx, query) tx: tx,
connOrTx: &caller{
db: tx,
cb: c.cb,
},
} }
} }
func (c *Conn) callAfter(ctx context.Context, query string, took time.Duration, err error) { type txWrapper struct {
for _, cb := range c.afterCallbacks { tx *sql.Tx
cb(ctx, query, took, err) connOrTx
} }
func (w *txWrapper) Commit() error {
return w.tx.Commit()
}
func (w *txWrapper) Rollback() error {
return w.tx.Rollback()
} }

79
sqldb/exec_result.go Normal file
View File

@ -0,0 +1,79 @@
package sqldb
import (
"context"
"database/sql"
)
// ExecResult ...
type ExecResult interface {
Error() error
LastInsertID() int64
RowsAffected() int64
Result() sql.Result
Then() ExecChain
}
type execResult struct {
db executer
err error
res sql.Result
}
func (r *execResult) Result() sql.Result {
return r.res
}
func (r *execResult) Error() error {
return r.err
}
func (r *execResult) LastInsertID() int64 {
if r.res == nil {
return 0
}
id, err := r.res.LastInsertId()
if err != nil {
return 0
}
return id
}
func (r *execResult) RowsAffected() int64 {
if r.res == nil {
return 0
}
ra, err := r.res.RowsAffected()
if err != nil {
return 0
}
return ra
}
func (r *execResult) Then() ExecChain {
if r.err != nil {
return &execChain{r.db}
}
return r.db
}
//
// Chain
//
// ExecChain ...
type ExecChain interface {
executer
}
type execChain struct {
executer
}
type brokenChain struct {
err error
}
func (c *brokenChain) Exec(_ context.Context, _ string, _ ...interface{}) ExecResult {
return &execResult{err: c.err}
}

View File

@ -6,61 +6,29 @@ import (
"reflect" "reflect"
) )
// Result represents query result. // QueryResult ...
type Result interface { type QueryResult interface {
// Load decodes rows into provided variable.
Load(dest interface{}) Result
// Error returns an error if one happened during query execution.
Error() error Error() error
// Rows returns original database rows object. Load(dest interface{}) error
Rows() *sql.Rows Rows() *sql.Rows
// LastInsertID returns the last inserted record ID for results obtained
// from Exec calls.
LastInsertID() int64
// RowsAffected returns the number of rows affected for results obtained
// from Exec calls.
RowsAffected() int64
} }
type result struct { type queryResult struct {
rows *sql.Rows
res sql.Result
err error err error
rows *sql.Rows
} }
func (r result) Rows() *sql.Rows { func (r *queryResult) Rows() *sql.Rows {
return r.rows return r.rows
} }
func (r result) Error() error { func (r *queryResult) Error() error {
return r.err return r.err
} }
func (r result) LastInsertID() int64 { func (r *queryResult) Load(dest interface{}) error {
if r.res == nil {
return 0
}
id, err := r.res.LastInsertId()
if err != nil {
return 0
}
return id
}
func (r result) RowsAffected() int64 {
if r.res == nil {
return 0
}
ra, err := r.res.RowsAffected()
if err != nil {
return 0
}
return ra
}
func (r result) Load(dest interface{}) Result {
if r.err != nil { if r.err != nil {
return r return r.err
} }
defer r.rows.Close() defer r.rows.Close()
@ -89,19 +57,19 @@ func (r result) Load(dest interface{}) Result {
} }
if r.err == nil && r.rows.Err() != nil { if r.err == nil && r.rows.Err() != nil {
return r.withError(r.rows.Err()) return r.rows.Err()
} }
return r return nil
} }
func (r *result) loadValue(dest interface{}) { func (r *queryResult) loadValue(dest interface{}) {
if r.rows.Next() { if r.rows.Next() {
r.err = r.rows.Scan(dest) r.err = r.rows.Scan(dest)
} }
} }
func (r *result) loadSlice(typ reflect.Type, dest interface{}) { func (r *queryResult) loadSlice(typ reflect.Type, dest interface{}) {
vSlice := reflect.MakeSlice(typ, 0, 0) vSlice := reflect.MakeSlice(typ, 0, 0)
for r.rows.Next() { for r.rows.Next() {
val := reflect.New(typ.Elem()) val := reflect.New(typ.Elem())
@ -114,7 +82,7 @@ func (r *result) loadSlice(typ reflect.Type, dest interface{}) {
reflect.ValueOf(dest).Elem().Set(vSlice) reflect.ValueOf(dest).Elem().Set(vSlice)
} }
func (r *result) loadMap(dest *map[string]interface{}) { func (r *queryResult) loadMap(dest *map[string]interface{}) {
if !r.rows.Next() { if !r.rows.Next() {
return return
} }
@ -155,7 +123,7 @@ func (r *result) loadMap(dest *map[string]interface{}) {
} }
} }
func (r *result) loadSliceOfMaps(dest *[]map[string]interface{}) { func (r *queryResult) loadSliceOfMaps(dest *[]map[string]interface{}) {
cols, err := r.rows.Columns() cols, err := r.rows.Columns()
if err != nil { if err != nil {
r.err = err r.err = err
@ -196,7 +164,7 @@ func (r *result) loadSliceOfMaps(dest *[]map[string]interface{}) {
} }
} }
func (r *result) loadStruct(typ reflect.Type, dest interface{}) { func (r *queryResult) loadStruct(typ reflect.Type, dest interface{}) {
if !r.rows.Next() { if !r.rows.Next() {
return return
} }
@ -232,7 +200,7 @@ func (r *result) loadStruct(typ reflect.Type, dest interface{}) {
} }
} }
func (r *result) loadSliceOfStructs(typ reflect.Type, dest interface{}) { func (r *queryResult) loadSliceOfStructs(typ reflect.Type, dest interface{}) {
cols, err := r.rows.Columns() cols, err := r.rows.Columns()
if err != nil { if err != nil {
r.err = err r.err = err
@ -271,7 +239,7 @@ func (r *result) loadSliceOfStructs(typ reflect.Type, dest interface{}) {
} }
} }
func (r result) withError(err error) Result { func (r *queryResult) withError(err error) *queryResult {
r.err = err r.err = err
return r return r
} }