Add sqldb and set packages

This commit is contained in:
2018-06-23 23:46:35 +02:00
parent 0df8b1ec84
commit b06910f223
31 changed files with 2983 additions and 0 deletions
+106
View File
@@ -0,0 +1,106 @@
package sqldb
import (
"context"
"database/sql"
"time"
"github.com/juju/errors"
)
// Conn represents database connection.
type Conn struct {
db *sql.DB
beforeCallbacks []BeforeCallback
afterCallbacks []AfterCallback
}
type (
// BeforeCallback is a kind of function that can be called before a query is
// executed.
BeforeCallback func(ctx context.Context, query string)
// AfterCallback is a kind of function that can be called after a query was
// executed.
AfterCallback func(ctx context.Context, query string, took time.Duration, err error)
)
// Flavor defines a kind of SQL database.
type Flavor string
const (
// MySQL is the MySQL SQL flavor.
MySQL Flavor = "mysql"
// PostgreSQL is the PostgreSQL SQL flavor.
PostgreSQL Flavor = "postgresql"
)
// Connect establishes a new database connection.
func Connect(ctx context.Context, f Flavor, dsn string) (*Conn, error) {
conn, err := sql.Open(string(f), dsn)
if err != nil {
return nil, errors.Annotate(err, "Failed to establish connection")
}
err = conn.PingContext(ctx)
if err != nil {
return nil, errors.Annotate(err, "Connection is not responding")
}
return &Conn{db: conn}, nil
}
// Exec executes a query and does not expect any result.
func (c *Conn) Exec(ctx context.Context, query string, args ...interface{}) Result {
c.callBefore(ctx, query)
startedAt := time.Now()
res, err := c.db.ExecContext(ctx, query, args...)
c.callAfter(ctx, query, time.Since(startedAt), err)
if err != nil {
return result{err: err}
}
return result{res: res}
}
// Query executes a query and returns a result object that can later be used to
// retrieve values.
func (c *Conn) Query(ctx context.Context, query string, args ...interface{}) Result {
var r result
c.callBefore(ctx, query)
startedAt := time.Now()
r.rows, r.err = c.db.QueryContext(ctx, query, args...)
c.callAfter(ctx, query, time.Since(startedAt), r.err)
return r
}
// DB returns the underlying DB object.
func (c *Conn) DB() *sql.DB {
return c.db
}
// Close closes the connection.
func (c *Conn) Close() error {
return c.db.Close()
}
// Before adds a callback function that would be called before a query is
// executed.
func (c *Conn) Before(cb BeforeCallback) {
c.beforeCallbacks = append(c.beforeCallbacks, cb)
}
// After adds a callback function that would be called after a query was
// executed.
func (c *Conn) After(cb AfterCallback) {
c.afterCallbacks = append(c.afterCallbacks, cb)
}
func (c *Conn) callBefore(ctx context.Context, query string) {
for _, cb := range c.beforeCallbacks {
cb(ctx, query)
}
}
func (c *Conn) callAfter(ctx context.Context, query string, took time.Duration, err error) {
for _, cb := range c.afterCallbacks {
cb(ctx, query, took, err)
}
}
+312
View File
@@ -0,0 +1,312 @@
package sqldb
import (
"database/sql"
"fmt"
"reflect"
)
// Result represents query result.
type Result interface {
// Load decodes rows into provided variable.
Load(dest interface{}) Result
// Error returns an error if one happened during query execution.
Error() error
// Rows returns original database rows object.
Rows() *sql.Rows
// LastInsertID returns the last inserted record ID for results obtained
// from Exec calls.
LastInsertID() int64
// RowsAffected returns the number of rows affected for results obtained
// from Exec calls.
RowsAffected() int64
}
type result struct {
rows *sql.Rows
res sql.Result
err error
}
func (r result) Rows() *sql.Rows {
return r.rows
}
func (r result) Error() error {
return r.err
}
func (r result) LastInsertID() int64 {
if r.res == nil {
return 0
}
id, err := r.res.LastInsertId()
if err != nil {
return 0
}
return id
}
func (r result) RowsAffected() int64 {
if r.res == nil {
return 0
}
ra, err := r.res.RowsAffected()
if err != nil {
return 0
}
return ra
}
func (r result) Load(dest interface{}) Result {
if r.err != nil {
return r
}
defer r.rows.Close()
dtyp := reflect.TypeOf(dest)
if dtyp.Kind() != reflect.Ptr {
panic("Value must be a pointer")
}
dtyp = dtyp.Elem()
switch dtyp.Kind() {
case reflect.Struct:
r.loadStruct(dtyp, dest)
case reflect.Map:
r.loadMap(dest.(*map[string]interface{}))
case reflect.Slice:
switch dtyp.Elem().Kind() {
case reflect.Struct:
r.loadSliceOfStructs(dtyp, dest)
case reflect.Map:
r.loadSliceOfMaps(dest.(*[]map[string]interface{}))
default:
r.loadSlice(dtyp, dest)
}
default:
r.loadValue(dest)
}
if r.err == nil && r.rows.Err() != nil {
return r.withError(r.rows.Err())
}
return r
}
func (r *result) loadValue(dest interface{}) {
if r.rows.Next() {
r.err = r.rows.Scan(dest)
}
}
func (r *result) loadSlice(typ reflect.Type, dest interface{}) {
vSlice := reflect.MakeSlice(typ, 0, 0)
for r.rows.Next() {
val := reflect.New(typ.Elem())
r.err = r.rows.Scan(val.Interface())
if r.err != nil {
return
}
vSlice = reflect.Append(vSlice, val.Elem())
}
reflect.ValueOf(dest).Elem().Set(vSlice)
}
func (r *result) loadMap(dest *map[string]interface{}) {
if !r.rows.Next() {
return
}
cols, err := r.rows.Columns()
if err != nil {
r.err = err
return
}
colTypes, err := r.rows.ColumnTypes()
if err != nil {
r.err = err
return
}
vals := make([]interface{}, len(cols))
for i := range cols {
vals[i] = newValue(colTypes[i])
}
err = r.rows.Scan(vals...)
if err != nil {
r.err = err
return
}
if *dest == nil {
*dest = make(map[string]interface{}, len(cols))
}
for i, col := range cols {
switch tval := vals[i].(type) {
case *int64:
(*dest)[col] = *tval
case *string:
(*dest)[col] = *tval
case *bool:
(*dest)[col] = *tval
}
}
}
func (r *result) loadSliceOfMaps(dest *[]map[string]interface{}) {
cols, err := r.rows.Columns()
if err != nil {
r.err = err
return
}
colTypes, err := r.rows.ColumnTypes()
if err != nil {
r.err = err
return
}
if *dest == nil {
*dest = make([]map[string]interface{}, 0)
}
for r.rows.Next() {
vals := make([]interface{}, len(cols))
for i := range cols {
vals[i] = newValue(colTypes[i])
}
err = r.rows.Scan(vals...)
if err != nil {
r.err = err
return
}
row := make(map[string]interface{}, len(cols))
for i, col := range cols {
switch tval := vals[i].(type) {
case *int64:
row[col] = *tval
case *string:
row[col] = *tval
case *bool:
row[col] = *tval
}
}
*dest = append(*dest, row)
}
}
func (r *result) loadStruct(typ reflect.Type, dest interface{}) {
if !r.rows.Next() {
return
}
cols, err := r.rows.Columns()
if err != nil {
r.err = err
return
}
val := reflect.ValueOf(dest).Elem()
vals := make([]interface{}, len(cols))
tm := tagMap(cols, val.Type())
for i := range cols {
if fi, ok := tm[i]; ok {
fval := val.Field(fi)
vals[i] = reflect.New(fval.Type()).Interface()
} else {
var dummy interface{}
vals[i] = &dummy
}
}
err = r.rows.Scan(vals...)
if err != nil {
r.err = err
return
}
for i, fi := range tm {
fval := val.Field(fi)
fval.Set(reflect.ValueOf(vals[i]).Elem())
}
}
func (r *result) loadSliceOfStructs(typ reflect.Type, dest interface{}) {
cols, err := r.rows.Columns()
if err != nil {
r.err = err
return
}
vSlice := reflect.ValueOf(dest).Elem()
tSlice := vSlice.Type()
tElem := tSlice.Elem()
tm := tagMap(cols, tElem)
for r.rows.Next() {
vals := make([]interface{}, len(cols))
val := reflect.New(tElem).Elem()
for i := range cols {
if fi, ok := tm[i]; ok {
fval := val.Field(fi)
vals[i] = reflect.New(fval.Type()).Interface()
} else {
var dummy interface{}
vals[i] = &dummy
}
}
err = r.rows.Scan(vals...)
if err != nil {
r.err = err
return
}
for i, fi := range tm {
fval := val.Field(fi)
fval.Set(reflect.ValueOf(vals[i]).Elem())
}
vSlice.Set(reflect.Append(vSlice, val))
}
}
func (r result) withError(err error) Result {
r.err = err
return r
}
func newValue(typ *sql.ColumnType) interface{} {
switch typ.DatabaseTypeName() {
case "VARCHAR", "NVARCHAR", "TEXT":
var s string
return &s
case "INT", "BIGINT":
var i int64
return &i
case "BOOL":
var b bool
return &b
default:
panic(fmt.Errorf("Unsupported MySQL type: %s", typ.DatabaseTypeName()))
}
}
func tagMap(cols []string, typ reflect.Type) map[int]int {
fieldIndices := map[string]int{}
for i := 0; i < typ.NumField(); i++ {
tag := typ.Field(i).Tag.Get("db")
if tag != "" {
fieldIndices[tag] = i
}
}
colFields := map[int]int{}
for i, col := range cols {
if fi, ok := fieldIndices[col]; ok {
colFields[i] = fi
}
}
return colFields
}
+74
View File
@@ -0,0 +1,74 @@
package sqldb
import (
"context"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestLoadSingleValue(t *testing.T) {
ctx := context.Background()
exp := int(1)
var out int
mustT(t, conn.Query(ctx, "SELECT 1").Load(&out))
if exp != out {
t.Errorf("Value doesn't match: expected %d, got %d", exp, out)
}
}
func TestLoadSlice(t *testing.T) {
ctx := context.Background()
exp := []int{1, 2}
var out []int
mustT(t, conn.Query(ctx, "SELECT id FROM sqldb_test").Load(&out))
if !cmp.Equal(exp, out) {
t.Errorf("Values dont't match: %s", cmp.Diff(exp, out))
}
}
func TestLoadMap(t *testing.T) {
ctx := context.Background()
exp := map[string]interface{}{"id": int64(1), "name": "Alice"}
var out map[string]interface{}
mustT(t, conn.Query(ctx, "SELECT * FROM sqldb_test WHERE id = 1").Load(&out))
if !cmp.Equal(exp, out) {
t.Errorf("Record doesn't match: %s", cmp.Diff(exp, out))
}
}
func TestLoadSliceOfMaps(t *testing.T) {
ctx := context.Background()
exp := []map[string]interface{}{
{"id": int64(1), "name": "Alice"},
{"id": int64(2), "name": "Bob"},
}
var out []map[string]interface{}
mustT(t, conn.Query(ctx, "SELECT * FROM sqldb_test ORDER BY id ASC").Load(&out))
if !cmp.Equal(exp, out) {
t.Errorf("Records don't match: %s", cmp.Diff(exp, out))
}
}
func TestLoadStruct(t *testing.T) {
ctx := context.Background()
exp := record{ID: 1, Name: "Alice"}
var out record
mustT(t, conn.Query(ctx, "SELECT * FROM sqldb_test WHERE id = 1").Load(&out))
if !cmp.Equal(exp, out) {
t.Errorf("Record doesn't match: %s", cmp.Diff(exp, out))
}
}
func TestLoadSliceOfStructs(t *testing.T) {
ctx := context.Background()
exp := []record{
{ID: 1, Name: "Alice"},
{ID: 2, Name: "Bob"},
}
var out []record
mustT(t, conn.Query(ctx, "SELECT * FROM sqldb_test").Load(&out))
if !cmp.Equal(exp, out) {
t.Errorf("Records don't match: %s", cmp.Diff(exp, out))
}
}
+77
View File
@@ -0,0 +1,77 @@
package sqldb
import (
"context"
"flag"
"fmt"
"log"
"os"
"testing"
_ "github.com/go-sql-driver/mysql" // MySQL driver
)
type record struct {
ID uint `db:"id"`
Name string `db:"name"`
}
var conn *Conn
func TestMain(m *testing.M) {
dsn := flag.String("dsn", "", "Database source name")
flag.Parse()
if *dsn == "" {
log.Println("Database source name is not provided, skipping package tests")
os.Exit(0)
}
log.Println("Establishing connection to the test database")
ctx := context.Background()
var err error
conn, err = Connect(ctx, MySQL, *dsn)
if err != nil {
log.Fatalf("Failed to connect: %v\n", err)
}
log.Println("Seeding database")
must(conn.Exec(ctx, `
DROP TABLE IF EXISTS sqldb_test
`))
must(conn.Exec(ctx, `
CREATE TABLE sqldb_test (
id int(11) UNSIGNED NOT NULL,
name VARCHAR(10) DEFAULT '',
PRIMARY KEY (id)
) ENGINE=InnoDB DEFAULT CHARSET=ascii
`))
must(conn.Exec(ctx, `
INSERT INTO sqldb_test (id, name)
VALUES
(1, "Alice"),
(2, "Bob")
`))
fmt.Println("Starting test suite")
exitCode := m.Run()
log.Println("Test suite finished")
if err := conn.Close(); err != nil {
log.Printf("Failed to close connection: %v\n", err)
}
os.Exit(exitCode)
}
func mustT(t *testing.T, r Result) Result {
t.Helper()
if r.Error() != nil {
t.Fatalf("Query failed: %v", r.Error())
}
return r
}
func must(r Result) Result {
if r.Error() != nil {
log.Fatalf("Query failed: %v\n", r.Error())
}
return r
}