diff --git a/config/config.go b/config/config.go index 1233502..8325413 100644 --- a/config/config.go +++ b/config/config.go @@ -11,21 +11,34 @@ func Require(key string, dest interface{}) { configs[key] = dest } -// Load reads the config file and distributes provided configuration to +// Load ... +func Load(src string) error { + return load(src, toml.Decode) +} + +// LoadFile reads the config file and distributes provided configuration to // requested destinations. -func Load(path string) error { - var conf map[string]toml.Primitive - meta, err := toml.DecodeFile(path, &conf) +func LoadFile(path string) error { + return load(path, toml.DecodeFile) +} + +func load(from string, withFn func(string, interface{}) (toml.MetaData, error)) error { + var sections map[string]toml.Primitive + meta, err := withFn(from, §ions) if err != nil { return err } + return decode(meta, sections) +} - for key, prim := range conf { - dest, ok := configs[key] +func decode(meta toml.MetaData, sections map[string]toml.Primitive) error { + for section, conf := range sections { + dest, ok := configs[section] if !ok { continue } - err := meta.PrimitiveDecode(prim, dest) + + err := meta.PrimitiveDecode(conf, dest) if err != nil { return err } diff --git a/sqldb/callback.go b/sqldb/callback.go new file mode 100644 index 0000000..78cb2d8 --- /dev/null +++ b/sqldb/callback.go @@ -0,0 +1,40 @@ +package sqldb + +import ( + "context" + "time" +) + +type ( + // BeforeCallback is a kind of function that can be called before a query is + // executed. + BeforeCallback func(ctx context.Context, query string) + // AfterCallback is a kind of function that can be called after a query was + // executed. + AfterCallback func(ctx context.Context, query string, took time.Duration, err error) +) + +type callbacks struct { + before []BeforeCallback + after []AfterCallback +} + +func (c *callbacks) addBefore(cb BeforeCallback) { + c.before = append(c.before, cb) +} + +func (c *callbacks) addAfter(cb AfterCallback) { + c.after = append(c.after, cb) +} + +func (c *callbacks) callBefore(ctx context.Context, query string) { + for _, cb := range c.before { + cb(ctx, query) + } +} + +func (c *callbacks) callAfter(ctx context.Context, query string, took time.Duration, err error) { + for _, cb := range c.after { + cb(ctx, query, took, err) + } +} diff --git a/sqldb/caller.go b/sqldb/caller.go new file mode 100644 index 0000000..4211e41 --- /dev/null +++ b/sqldb/caller.go @@ -0,0 +1,67 @@ +package sqldb + +import ( + "context" + "database/sql" + "time" +) + +type connOrTx interface { + executer + queryPerformer +} + +type executer interface { + // Exec executes a query and does not expect any result. + Exec(ctx context.Context, query string, args ...interface{}) ExecResult + ExecNamed(ctx context.Context, query string, arg interface{}) ExecResult +} + +type queryPerformer interface { + // Query executes a query and returns a result object that can later be used + // to retrieve values. + Query(ctx context.Context, query string, args ...interface{}) QueryResult + QueryNamed(ctx context.Context, query string, arg interface{}) QueryResult +} + +type stdConnOrTx interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) +} + +type caller struct { + db stdConnOrTx + cb *callbacks +} + +func (c *caller) Exec(ctx context.Context, query string, args ...interface{}) ExecResult { + c.cb.callBefore(ctx, query) + startedAt := time.Now() + res, err := c.db.ExecContext(ctx, query, args...) + c.cb.callAfter(ctx, query, time.Since(startedAt), err) + + return &execResult{ + db: c, + err: err, + res: res, + } +} + +func (c *caller) ExecNamed(ctx context.Context, query string, arg interface{}) ExecResult { + return nil +} + +func (c *caller) Query(ctx context.Context, query string, args ...interface{}) QueryResult { + c.cb.callBefore(ctx, query) + startedAt := time.Now() + rows, err := c.db.QueryContext(ctx, query, args...) + c.cb.callAfter(ctx, query, time.Since(startedAt), err) + return &queryResult{ + err: err, + rows: rows, + } +} + +func (c *caller) QueryNamed(ctx context.Context, query string, arg interface{}) QueryResult { + return nil +} diff --git a/sqldb/chain_test.go b/sqldb/chain_test.go new file mode 100644 index 0000000..38408aa --- /dev/null +++ b/sqldb/chain_test.go @@ -0,0 +1,15 @@ +package sqldb + +import ( + "context" + "testing" +) + +func TestCallChain(t *testing.T) { + ctx := context.Background() + mustT(t, conn. + Exec(ctx, "INSERT INTO sqldb_test (id, name) VALUES (3, 'Fred')").Then(). + Exec(ctx, "UPDATE sqldb_test SET name = 'Wilson' WHERE id = 3").Then(). + Exec(ctx, "DELETE FROM sqldb_test WHERE id = 3"), + ) +} diff --git a/sqldb/conn.go b/sqldb/conn.go index 50cc13f..8f66607 100644 --- a/sqldb/conn.go +++ b/sqldb/conn.go @@ -3,27 +3,31 @@ package sqldb import ( "context" "database/sql" - "time" "github.com/juju/errors" ) // Conn represents database connection. -type Conn struct { - db *sql.DB - - beforeCallbacks []BeforeCallback - afterCallbacks []AfterCallback +type Conn interface { + connOrTx + Begin(context.Context, func(Tx) error) error + Close() error + DB() *sql.DB + Before(BeforeCallback) + After(AfterCallback) } -type ( - // BeforeCallback is a kind of function that can be called before a query is - // executed. - BeforeCallback func(ctx context.Context, query string) - // AfterCallback is a kind of function that can be called after a query was - // executed. - AfterCallback func(ctx context.Context, query string, took time.Duration, err error) -) +// Tx represents database transacation. +type Tx interface { + connOrTx + Commit() error + Rollback() error +} + +type dbWrapper struct { + conn *sql.DB + *caller +} // Flavor defines a kind of SQL database. type Flavor string @@ -36,7 +40,7 @@ const ( ) // Connect establishes a new database connection. -func Connect(ctx context.Context, f Flavor, dsn string) (*Conn, error) { +func Connect(ctx context.Context, f Flavor, dsn string) (Conn, error) { conn, err := sql.Open(string(f), dsn) if err != nil { return nil, errors.Annotate(err, "Failed to establish connection") @@ -45,62 +49,69 @@ func Connect(ctx context.Context, f Flavor, dsn string) (*Conn, error) { if err != nil { return nil, errors.Annotate(err, "Connection is not responding") } - return &Conn{db: conn}, nil + return &dbWrapper{ + conn: conn, + caller: &caller{ + db: conn, + cb: &callbacks{}, + }, + }, nil } -// Exec executes a query and does not expect any result. -func (c *Conn) Exec(ctx context.Context, query string, args ...interface{}) Result { - c.callBefore(ctx, query) - startedAt := time.Now() - res, err := c.db.ExecContext(ctx, query, args...) - c.callAfter(ctx, query, time.Since(startedAt), err) +// Begin executes a transaction. +func (c *dbWrapper) Begin(ctx context.Context, fn func(tx Tx) error) error { + tx, err := c.conn.BeginTx(ctx, &sql.TxOptions{}) if err != nil { - return result{err: err} + return err } - return result{res: res} -} - -// Query executes a query and returns a result object that can later be used to -// retrieve values. -func (c *Conn) Query(ctx context.Context, query string, args ...interface{}) Result { - var r result - c.callBefore(ctx, query) - startedAt := time.Now() - r.rows, r.err = c.db.QueryContext(ctx, query, args...) - c.callAfter(ctx, query, time.Since(startedAt), r.err) - return r + err = fn(c.wrapTx(tx)) + if err != nil { + tx.Rollback() + } + return err } // DB returns the underlying DB object. -func (c *Conn) DB() *sql.DB { - return c.db +func (c *dbWrapper) DB() *sql.DB { + return c.conn } // Close closes the connection. -func (c *Conn) Close() error { - return c.db.Close() +func (c *dbWrapper) Close() error { + return c.conn.Close() } // Before adds a callback function that would be called before a query is // executed. -func (c *Conn) Before(cb BeforeCallback) { - c.beforeCallbacks = append(c.beforeCallbacks, cb) +func (c *dbWrapper) Before(cb BeforeCallback) { + c.cb.addBefore(cb) } // After adds a callback function that would be called after a query was // executed. -func (c *Conn) After(cb AfterCallback) { - c.afterCallbacks = append(c.afterCallbacks, cb) +func (c *dbWrapper) After(cb AfterCallback) { + c.cb.addAfter(cb) } -func (c *Conn) callBefore(ctx context.Context, query string) { - for _, cb := range c.beforeCallbacks { - cb(ctx, query) +func (c *dbWrapper) wrapTx(tx *sql.Tx) Tx { + return &txWrapper{ + tx: tx, + connOrTx: &caller{ + db: tx, + cb: c.cb, + }, } } -func (c *Conn) callAfter(ctx context.Context, query string, took time.Duration, err error) { - for _, cb := range c.afterCallbacks { - cb(ctx, query, took, err) - } +type txWrapper struct { + tx *sql.Tx + connOrTx +} + +func (w *txWrapper) Commit() error { + return w.tx.Commit() +} + +func (w *txWrapper) Rollback() error { + return w.tx.Rollback() } diff --git a/sqldb/exec_result.go b/sqldb/exec_result.go new file mode 100644 index 0000000..9deebd5 --- /dev/null +++ b/sqldb/exec_result.go @@ -0,0 +1,79 @@ +package sqldb + +import ( + "context" + "database/sql" +) + +// ExecResult ... +type ExecResult interface { + Error() error + LastInsertID() int64 + RowsAffected() int64 + Result() sql.Result + Then() ExecChain +} + +type execResult struct { + db executer + err error + res sql.Result +} + +func (r *execResult) Result() sql.Result { + return r.res +} + +func (r *execResult) Error() error { + return r.err +} + +func (r *execResult) LastInsertID() int64 { + if r.res == nil { + return 0 + } + id, err := r.res.LastInsertId() + if err != nil { + return 0 + } + return id +} + +func (r *execResult) RowsAffected() int64 { + if r.res == nil { + return 0 + } + ra, err := r.res.RowsAffected() + if err != nil { + return 0 + } + return ra +} + +func (r *execResult) Then() ExecChain { + if r.err != nil { + return &execChain{r.db} + } + return r.db +} + +// +// Chain +// + +// ExecChain ... +type ExecChain interface { + executer +} + +type execChain struct { + executer +} + +type brokenChain struct { + err error +} + +func (c *brokenChain) Exec(_ context.Context, _ string, _ ...interface{}) ExecResult { + return &execResult{err: c.err} +} diff --git a/sqldb/result.go b/sqldb/query_result.go similarity index 75% rename from sqldb/result.go rename to sqldb/query_result.go index b222cf9..c33dff3 100644 --- a/sqldb/result.go +++ b/sqldb/query_result.go @@ -6,61 +6,29 @@ import ( "reflect" ) -// Result represents query result. -type Result interface { - // Load decodes rows into provided variable. - Load(dest interface{}) Result - // Error returns an error if one happened during query execution. +// QueryResult ... +type QueryResult interface { Error() error - // Rows returns original database rows object. + Load(dest interface{}) error Rows() *sql.Rows - // LastInsertID returns the last inserted record ID for results obtained - // from Exec calls. - LastInsertID() int64 - // RowsAffected returns the number of rows affected for results obtained - // from Exec calls. - RowsAffected() int64 } -type result struct { - rows *sql.Rows - res sql.Result +type queryResult struct { err error + rows *sql.Rows } -func (r result) Rows() *sql.Rows { +func (r *queryResult) Rows() *sql.Rows { return r.rows } -func (r result) Error() error { +func (r *queryResult) Error() error { return r.err } -func (r result) LastInsertID() int64 { - if r.res == nil { - return 0 - } - id, err := r.res.LastInsertId() - if err != nil { - return 0 - } - return id -} - -func (r result) RowsAffected() int64 { - if r.res == nil { - return 0 - } - ra, err := r.res.RowsAffected() - if err != nil { - return 0 - } - return ra -} - -func (r result) Load(dest interface{}) Result { +func (r *queryResult) Load(dest interface{}) error { if r.err != nil { - return r + return r.err } defer r.rows.Close() @@ -89,19 +57,19 @@ func (r result) Load(dest interface{}) Result { } if r.err == nil && r.rows.Err() != nil { - return r.withError(r.rows.Err()) + return r.rows.Err() } - return r + return nil } -func (r *result) loadValue(dest interface{}) { +func (r *queryResult) loadValue(dest interface{}) { if r.rows.Next() { r.err = r.rows.Scan(dest) } } -func (r *result) loadSlice(typ reflect.Type, dest interface{}) { +func (r *queryResult) loadSlice(typ reflect.Type, dest interface{}) { vSlice := reflect.MakeSlice(typ, 0, 0) for r.rows.Next() { val := reflect.New(typ.Elem()) @@ -114,7 +82,7 @@ func (r *result) loadSlice(typ reflect.Type, dest interface{}) { reflect.ValueOf(dest).Elem().Set(vSlice) } -func (r *result) loadMap(dest *map[string]interface{}) { +func (r *queryResult) loadMap(dest *map[string]interface{}) { if !r.rows.Next() { return } @@ -155,7 +123,7 @@ func (r *result) loadMap(dest *map[string]interface{}) { } } -func (r *result) loadSliceOfMaps(dest *[]map[string]interface{}) { +func (r *queryResult) loadSliceOfMaps(dest *[]map[string]interface{}) { cols, err := r.rows.Columns() if err != nil { r.err = err @@ -196,7 +164,7 @@ func (r *result) loadSliceOfMaps(dest *[]map[string]interface{}) { } } -func (r *result) loadStruct(typ reflect.Type, dest interface{}) { +func (r *queryResult) loadStruct(typ reflect.Type, dest interface{}) { if !r.rows.Next() { return } @@ -232,7 +200,7 @@ func (r *result) loadStruct(typ reflect.Type, dest interface{}) { } } -func (r *result) loadSliceOfStructs(typ reflect.Type, dest interface{}) { +func (r *queryResult) loadSliceOfStructs(typ reflect.Type, dest interface{}) { cols, err := r.rows.Columns() if err != nil { r.err = err @@ -271,7 +239,7 @@ func (r *result) loadSliceOfStructs(typ reflect.Type, dest interface{}) { } } -func (r result) withError(err error) Result { +func (r *queryResult) withError(err error) *queryResult { r.err = err return r }