1
0
Fork 0
This commit is contained in:
Gregory Eremin 2018-06-25 00:22:27 +02:00
parent a4b74992c4
commit 32b5607235
No known key found for this signature in database
GPG Key ID: 8CB79D42167BEB7F
7 changed files with 300 additions and 107 deletions

View File

@ -11,21 +11,34 @@ func Require(key string, dest interface{}) {
configs[key] = dest
}
// Load reads the config file and distributes provided configuration to
// Load ...
func Load(src string) error {
return load(src, toml.Decode)
}
// LoadFile reads the config file and distributes provided configuration to
// requested destinations.
func Load(path string) error {
var conf map[string]toml.Primitive
meta, err := toml.DecodeFile(path, &conf)
func LoadFile(path string) error {
return load(path, toml.DecodeFile)
}
func load(from string, withFn func(string, interface{}) (toml.MetaData, error)) error {
var sections map[string]toml.Primitive
meta, err := withFn(from, &sections)
if err != nil {
return err
}
return decode(meta, sections)
}
for key, prim := range conf {
dest, ok := configs[key]
func decode(meta toml.MetaData, sections map[string]toml.Primitive) error {
for section, conf := range sections {
dest, ok := configs[section]
if !ok {
continue
}
err := meta.PrimitiveDecode(prim, dest)
err := meta.PrimitiveDecode(conf, dest)
if err != nil {
return err
}

40
sqldb/callback.go Normal file
View File

@ -0,0 +1,40 @@
package sqldb
import (
"context"
"time"
)
type (
// BeforeCallback is a kind of function that can be called before a query is
// executed.
BeforeCallback func(ctx context.Context, query string)
// AfterCallback is a kind of function that can be called after a query was
// executed.
AfterCallback func(ctx context.Context, query string, took time.Duration, err error)
)
type callbacks struct {
before []BeforeCallback
after []AfterCallback
}
func (c *callbacks) addBefore(cb BeforeCallback) {
c.before = append(c.before, cb)
}
func (c *callbacks) addAfter(cb AfterCallback) {
c.after = append(c.after, cb)
}
func (c *callbacks) callBefore(ctx context.Context, query string) {
for _, cb := range c.before {
cb(ctx, query)
}
}
func (c *callbacks) callAfter(ctx context.Context, query string, took time.Duration, err error) {
for _, cb := range c.after {
cb(ctx, query, took, err)
}
}

67
sqldb/caller.go Normal file
View File

@ -0,0 +1,67 @@
package sqldb
import (
"context"
"database/sql"
"time"
)
type connOrTx interface {
executer
queryPerformer
}
type executer interface {
// Exec executes a query and does not expect any result.
Exec(ctx context.Context, query string, args ...interface{}) ExecResult
ExecNamed(ctx context.Context, query string, arg interface{}) ExecResult
}
type queryPerformer interface {
// Query executes a query and returns a result object that can later be used
// to retrieve values.
Query(ctx context.Context, query string, args ...interface{}) QueryResult
QueryNamed(ctx context.Context, query string, arg interface{}) QueryResult
}
type stdConnOrTx interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}
type caller struct {
db stdConnOrTx
cb *callbacks
}
func (c *caller) Exec(ctx context.Context, query string, args ...interface{}) ExecResult {
c.cb.callBefore(ctx, query)
startedAt := time.Now()
res, err := c.db.ExecContext(ctx, query, args...)
c.cb.callAfter(ctx, query, time.Since(startedAt), err)
return &execResult{
db: c,
err: err,
res: res,
}
}
func (c *caller) ExecNamed(ctx context.Context, query string, arg interface{}) ExecResult {
return nil
}
func (c *caller) Query(ctx context.Context, query string, args ...interface{}) QueryResult {
c.cb.callBefore(ctx, query)
startedAt := time.Now()
rows, err := c.db.QueryContext(ctx, query, args...)
c.cb.callAfter(ctx, query, time.Since(startedAt), err)
return &queryResult{
err: err,
rows: rows,
}
}
func (c *caller) QueryNamed(ctx context.Context, query string, arg interface{}) QueryResult {
return nil
}

15
sqldb/chain_test.go Normal file
View File

@ -0,0 +1,15 @@
package sqldb
import (
"context"
"testing"
)
func TestCallChain(t *testing.T) {
ctx := context.Background()
mustT(t, conn.
Exec(ctx, "INSERT INTO sqldb_test (id, name) VALUES (3, 'Fred')").Then().
Exec(ctx, "UPDATE sqldb_test SET name = 'Wilson' WHERE id = 3").Then().
Exec(ctx, "DELETE FROM sqldb_test WHERE id = 3"),
)
}

View File

@ -3,27 +3,31 @@ package sqldb
import (
"context"
"database/sql"
"time"
"github.com/juju/errors"
)
// Conn represents database connection.
type Conn struct {
db *sql.DB
beforeCallbacks []BeforeCallback
afterCallbacks []AfterCallback
type Conn interface {
connOrTx
Begin(context.Context, func(Tx) error) error
Close() error
DB() *sql.DB
Before(BeforeCallback)
After(AfterCallback)
}
type (
// BeforeCallback is a kind of function that can be called before a query is
// executed.
BeforeCallback func(ctx context.Context, query string)
// AfterCallback is a kind of function that can be called after a query was
// executed.
AfterCallback func(ctx context.Context, query string, took time.Duration, err error)
)
// Tx represents database transacation.
type Tx interface {
connOrTx
Commit() error
Rollback() error
}
type dbWrapper struct {
conn *sql.DB
*caller
}
// Flavor defines a kind of SQL database.
type Flavor string
@ -36,7 +40,7 @@ const (
)
// Connect establishes a new database connection.
func Connect(ctx context.Context, f Flavor, dsn string) (*Conn, error) {
func Connect(ctx context.Context, f Flavor, dsn string) (Conn, error) {
conn, err := sql.Open(string(f), dsn)
if err != nil {
return nil, errors.Annotate(err, "Failed to establish connection")
@ -45,62 +49,69 @@ func Connect(ctx context.Context, f Flavor, dsn string) (*Conn, error) {
if err != nil {
return nil, errors.Annotate(err, "Connection is not responding")
}
return &Conn{db: conn}, nil
return &dbWrapper{
conn: conn,
caller: &caller{
db: conn,
cb: &callbacks{},
},
}, nil
}
// Exec executes a query and does not expect any result.
func (c *Conn) Exec(ctx context.Context, query string, args ...interface{}) Result {
c.callBefore(ctx, query)
startedAt := time.Now()
res, err := c.db.ExecContext(ctx, query, args...)
c.callAfter(ctx, query, time.Since(startedAt), err)
// Begin executes a transaction.
func (c *dbWrapper) Begin(ctx context.Context, fn func(tx Tx) error) error {
tx, err := c.conn.BeginTx(ctx, &sql.TxOptions{})
if err != nil {
return result{err: err}
return err
}
return result{res: res}
err = fn(c.wrapTx(tx))
if err != nil {
tx.Rollback()
}
// Query executes a query and returns a result object that can later be used to
// retrieve values.
func (c *Conn) Query(ctx context.Context, query string, args ...interface{}) Result {
var r result
c.callBefore(ctx, query)
startedAt := time.Now()
r.rows, r.err = c.db.QueryContext(ctx, query, args...)
c.callAfter(ctx, query, time.Since(startedAt), r.err)
return r
return err
}
// DB returns the underlying DB object.
func (c *Conn) DB() *sql.DB {
return c.db
func (c *dbWrapper) DB() *sql.DB {
return c.conn
}
// Close closes the connection.
func (c *Conn) Close() error {
return c.db.Close()
func (c *dbWrapper) Close() error {
return c.conn.Close()
}
// Before adds a callback function that would be called before a query is
// executed.
func (c *Conn) Before(cb BeforeCallback) {
c.beforeCallbacks = append(c.beforeCallbacks, cb)
func (c *dbWrapper) Before(cb BeforeCallback) {
c.cb.addBefore(cb)
}
// After adds a callback function that would be called after a query was
// executed.
func (c *Conn) After(cb AfterCallback) {
c.afterCallbacks = append(c.afterCallbacks, cb)
func (c *dbWrapper) After(cb AfterCallback) {
c.cb.addAfter(cb)
}
func (c *Conn) callBefore(ctx context.Context, query string) {
for _, cb := range c.beforeCallbacks {
cb(ctx, query)
func (c *dbWrapper) wrapTx(tx *sql.Tx) Tx {
return &txWrapper{
tx: tx,
connOrTx: &caller{
db: tx,
cb: c.cb,
},
}
}
func (c *Conn) callAfter(ctx context.Context, query string, took time.Duration, err error) {
for _, cb := range c.afterCallbacks {
cb(ctx, query, took, err)
type txWrapper struct {
tx *sql.Tx
connOrTx
}
func (w *txWrapper) Commit() error {
return w.tx.Commit()
}
func (w *txWrapper) Rollback() error {
return w.tx.Rollback()
}

79
sqldb/exec_result.go Normal file
View File

@ -0,0 +1,79 @@
package sqldb
import (
"context"
"database/sql"
)
// ExecResult ...
type ExecResult interface {
Error() error
LastInsertID() int64
RowsAffected() int64
Result() sql.Result
Then() ExecChain
}
type execResult struct {
db executer
err error
res sql.Result
}
func (r *execResult) Result() sql.Result {
return r.res
}
func (r *execResult) Error() error {
return r.err
}
func (r *execResult) LastInsertID() int64 {
if r.res == nil {
return 0
}
id, err := r.res.LastInsertId()
if err != nil {
return 0
}
return id
}
func (r *execResult) RowsAffected() int64 {
if r.res == nil {
return 0
}
ra, err := r.res.RowsAffected()
if err != nil {
return 0
}
return ra
}
func (r *execResult) Then() ExecChain {
if r.err != nil {
return &execChain{r.db}
}
return r.db
}
//
// Chain
//
// ExecChain ...
type ExecChain interface {
executer
}
type execChain struct {
executer
}
type brokenChain struct {
err error
}
func (c *brokenChain) Exec(_ context.Context, _ string, _ ...interface{}) ExecResult {
return &execResult{err: c.err}
}

View File

@ -6,61 +6,29 @@ import (
"reflect"
)
// Result represents query result.
type Result interface {
// Load decodes rows into provided variable.
Load(dest interface{}) Result
// Error returns an error if one happened during query execution.
// QueryResult ...
type QueryResult interface {
Error() error
// Rows returns original database rows object.
Load(dest interface{}) error
Rows() *sql.Rows
// LastInsertID returns the last inserted record ID for results obtained
// from Exec calls.
LastInsertID() int64
// RowsAffected returns the number of rows affected for results obtained
// from Exec calls.
RowsAffected() int64
}
type result struct {
rows *sql.Rows
res sql.Result
type queryResult struct {
err error
rows *sql.Rows
}
func (r result) Rows() *sql.Rows {
func (r *queryResult) Rows() *sql.Rows {
return r.rows
}
func (r result) Error() error {
func (r *queryResult) Error() error {
return r.err
}
func (r result) LastInsertID() int64 {
if r.res == nil {
return 0
}
id, err := r.res.LastInsertId()
if err != nil {
return 0
}
return id
}
func (r result) RowsAffected() int64 {
if r.res == nil {
return 0
}
ra, err := r.res.RowsAffected()
if err != nil {
return 0
}
return ra
}
func (r result) Load(dest interface{}) Result {
func (r *queryResult) Load(dest interface{}) error {
if r.err != nil {
return r
return r.err
}
defer r.rows.Close()
@ -89,19 +57,19 @@ func (r result) Load(dest interface{}) Result {
}
if r.err == nil && r.rows.Err() != nil {
return r.withError(r.rows.Err())
return r.rows.Err()
}
return r
return nil
}
func (r *result) loadValue(dest interface{}) {
func (r *queryResult) loadValue(dest interface{}) {
if r.rows.Next() {
r.err = r.rows.Scan(dest)
}
}
func (r *result) loadSlice(typ reflect.Type, dest interface{}) {
func (r *queryResult) loadSlice(typ reflect.Type, dest interface{}) {
vSlice := reflect.MakeSlice(typ, 0, 0)
for r.rows.Next() {
val := reflect.New(typ.Elem())
@ -114,7 +82,7 @@ func (r *result) loadSlice(typ reflect.Type, dest interface{}) {
reflect.ValueOf(dest).Elem().Set(vSlice)
}
func (r *result) loadMap(dest *map[string]interface{}) {
func (r *queryResult) loadMap(dest *map[string]interface{}) {
if !r.rows.Next() {
return
}
@ -155,7 +123,7 @@ func (r *result) loadMap(dest *map[string]interface{}) {
}
}
func (r *result) loadSliceOfMaps(dest *[]map[string]interface{}) {
func (r *queryResult) loadSliceOfMaps(dest *[]map[string]interface{}) {
cols, err := r.rows.Columns()
if err != nil {
r.err = err
@ -196,7 +164,7 @@ func (r *result) loadSliceOfMaps(dest *[]map[string]interface{}) {
}
}
func (r *result) loadStruct(typ reflect.Type, dest interface{}) {
func (r *queryResult) loadStruct(typ reflect.Type, dest interface{}) {
if !r.rows.Next() {
return
}
@ -232,7 +200,7 @@ func (r *result) loadStruct(typ reflect.Type, dest interface{}) {
}
}
func (r *result) loadSliceOfStructs(typ reflect.Type, dest interface{}) {
func (r *queryResult) loadSliceOfStructs(typ reflect.Type, dest interface{}) {
cols, err := r.rows.Columns()
if err != nil {
r.err = err
@ -271,7 +239,7 @@ func (r *result) loadSliceOfStructs(typ reflect.Type, dest interface{}) {
}
}
func (r result) withError(err error) Result {
func (r *queryResult) withError(err error) *queryResult {
r.err = err
return r
}