313 lines
5.7 KiB
Go
313 lines
5.7 KiB
Go
package sqldb
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"reflect"
|
|
)
|
|
|
|
// Result represents query result.
|
|
type Result interface {
|
|
// Load decodes rows into provided variable.
|
|
Load(dest interface{}) Result
|
|
// Error returns an error if one happened during query execution.
|
|
Error() error
|
|
// Rows returns original database rows object.
|
|
Rows() *sql.Rows
|
|
// LastInsertID returns the last inserted record ID for results obtained
|
|
// from Exec calls.
|
|
LastInsertID() int64
|
|
// RowsAffected returns the number of rows affected for results obtained
|
|
// from Exec calls.
|
|
RowsAffected() int64
|
|
}
|
|
|
|
type result struct {
|
|
rows *sql.Rows
|
|
res sql.Result
|
|
err error
|
|
}
|
|
|
|
func (r result) Rows() *sql.Rows {
|
|
return r.rows
|
|
}
|
|
|
|
func (r result) Error() error {
|
|
return r.err
|
|
}
|
|
|
|
func (r result) LastInsertID() int64 {
|
|
if r.res == nil {
|
|
return 0
|
|
}
|
|
id, err := r.res.LastInsertId()
|
|
if err != nil {
|
|
return 0
|
|
}
|
|
return id
|
|
}
|
|
|
|
func (r result) RowsAffected() int64 {
|
|
if r.res == nil {
|
|
return 0
|
|
}
|
|
ra, err := r.res.RowsAffected()
|
|
if err != nil {
|
|
return 0
|
|
}
|
|
return ra
|
|
}
|
|
|
|
func (r result) Load(dest interface{}) Result {
|
|
if r.err != nil {
|
|
return r
|
|
}
|
|
defer r.rows.Close()
|
|
|
|
dtyp := reflect.TypeOf(dest)
|
|
if dtyp.Kind() != reflect.Ptr {
|
|
panic("Value must be a pointer")
|
|
}
|
|
dtyp = dtyp.Elem()
|
|
|
|
switch dtyp.Kind() {
|
|
case reflect.Struct:
|
|
r.loadStruct(dtyp, dest)
|
|
case reflect.Map:
|
|
r.loadMap(dest.(*map[string]interface{}))
|
|
case reflect.Slice:
|
|
switch dtyp.Elem().Kind() {
|
|
case reflect.Struct:
|
|
r.loadSliceOfStructs(dtyp, dest)
|
|
case reflect.Map:
|
|
r.loadSliceOfMaps(dest.(*[]map[string]interface{}))
|
|
default:
|
|
r.loadSlice(dtyp, dest)
|
|
}
|
|
default:
|
|
r.loadValue(dest)
|
|
}
|
|
|
|
if r.err == nil && r.rows.Err() != nil {
|
|
return r.withError(r.rows.Err())
|
|
}
|
|
|
|
return r
|
|
}
|
|
|
|
func (r *result) loadValue(dest interface{}) {
|
|
if r.rows.Next() {
|
|
r.err = r.rows.Scan(dest)
|
|
}
|
|
}
|
|
|
|
func (r *result) loadSlice(typ reflect.Type, dest interface{}) {
|
|
vSlice := reflect.MakeSlice(typ, 0, 0)
|
|
for r.rows.Next() {
|
|
val := reflect.New(typ.Elem())
|
|
r.err = r.rows.Scan(val.Interface())
|
|
if r.err != nil {
|
|
return
|
|
}
|
|
vSlice = reflect.Append(vSlice, val.Elem())
|
|
}
|
|
reflect.ValueOf(dest).Elem().Set(vSlice)
|
|
}
|
|
|
|
func (r *result) loadMap(dest *map[string]interface{}) {
|
|
if !r.rows.Next() {
|
|
return
|
|
}
|
|
|
|
cols, err := r.rows.Columns()
|
|
if err != nil {
|
|
r.err = err
|
|
return
|
|
}
|
|
colTypes, err := r.rows.ColumnTypes()
|
|
if err != nil {
|
|
r.err = err
|
|
return
|
|
}
|
|
|
|
vals := make([]interface{}, len(cols))
|
|
for i := range cols {
|
|
vals[i] = newValue(colTypes[i])
|
|
}
|
|
err = r.rows.Scan(vals...)
|
|
if err != nil {
|
|
r.err = err
|
|
return
|
|
}
|
|
|
|
if *dest == nil {
|
|
*dest = make(map[string]interface{}, len(cols))
|
|
}
|
|
for i, col := range cols {
|
|
switch tval := vals[i].(type) {
|
|
case *int64:
|
|
(*dest)[col] = *tval
|
|
case *string:
|
|
(*dest)[col] = *tval
|
|
case *bool:
|
|
(*dest)[col] = *tval
|
|
}
|
|
}
|
|
}
|
|
|
|
func (r *result) loadSliceOfMaps(dest *[]map[string]interface{}) {
|
|
cols, err := r.rows.Columns()
|
|
if err != nil {
|
|
r.err = err
|
|
return
|
|
}
|
|
colTypes, err := r.rows.ColumnTypes()
|
|
if err != nil {
|
|
r.err = err
|
|
return
|
|
}
|
|
|
|
if *dest == nil {
|
|
*dest = make([]map[string]interface{}, 0)
|
|
}
|
|
for r.rows.Next() {
|
|
vals := make([]interface{}, len(cols))
|
|
for i := range cols {
|
|
vals[i] = newValue(colTypes[i])
|
|
}
|
|
err = r.rows.Scan(vals...)
|
|
if err != nil {
|
|
r.err = err
|
|
return
|
|
}
|
|
|
|
row := make(map[string]interface{}, len(cols))
|
|
for i, col := range cols {
|
|
switch tval := vals[i].(type) {
|
|
case *int64:
|
|
row[col] = *tval
|
|
case *string:
|
|
row[col] = *tval
|
|
case *bool:
|
|
row[col] = *tval
|
|
}
|
|
}
|
|
*dest = append(*dest, row)
|
|
}
|
|
}
|
|
|
|
func (r *result) loadStruct(typ reflect.Type, dest interface{}) {
|
|
if !r.rows.Next() {
|
|
return
|
|
}
|
|
|
|
cols, err := r.rows.Columns()
|
|
if err != nil {
|
|
r.err = err
|
|
return
|
|
}
|
|
|
|
val := reflect.ValueOf(dest).Elem()
|
|
vals := make([]interface{}, len(cols))
|
|
tm := tagMap(cols, val.Type())
|
|
for i := range cols {
|
|
if fi, ok := tm[i]; ok {
|
|
fval := val.Field(fi)
|
|
vals[i] = reflect.New(fval.Type()).Interface()
|
|
} else {
|
|
var dummy interface{}
|
|
vals[i] = &dummy
|
|
}
|
|
}
|
|
|
|
err = r.rows.Scan(vals...)
|
|
if err != nil {
|
|
r.err = err
|
|
return
|
|
}
|
|
|
|
for i, fi := range tm {
|
|
fval := val.Field(fi)
|
|
fval.Set(reflect.ValueOf(vals[i]).Elem())
|
|
}
|
|
}
|
|
|
|
func (r *result) loadSliceOfStructs(typ reflect.Type, dest interface{}) {
|
|
cols, err := r.rows.Columns()
|
|
if err != nil {
|
|
r.err = err
|
|
return
|
|
}
|
|
|
|
vSlice := reflect.ValueOf(dest).Elem()
|
|
tSlice := vSlice.Type()
|
|
tElem := tSlice.Elem()
|
|
tm := tagMap(cols, tElem)
|
|
|
|
for r.rows.Next() {
|
|
vals := make([]interface{}, len(cols))
|
|
val := reflect.New(tElem).Elem()
|
|
for i := range cols {
|
|
if fi, ok := tm[i]; ok {
|
|
fval := val.Field(fi)
|
|
vals[i] = reflect.New(fval.Type()).Interface()
|
|
} else {
|
|
var dummy interface{}
|
|
vals[i] = &dummy
|
|
}
|
|
}
|
|
|
|
err = r.rows.Scan(vals...)
|
|
if err != nil {
|
|
r.err = err
|
|
return
|
|
}
|
|
|
|
for i, fi := range tm {
|
|
fval := val.Field(fi)
|
|
fval.Set(reflect.ValueOf(vals[i]).Elem())
|
|
}
|
|
vSlice.Set(reflect.Append(vSlice, val))
|
|
}
|
|
}
|
|
|
|
func (r result) withError(err error) Result {
|
|
r.err = err
|
|
return r
|
|
}
|
|
|
|
func newValue(typ *sql.ColumnType) interface{} {
|
|
switch typ.DatabaseTypeName() {
|
|
case "VARCHAR", "NVARCHAR", "TEXT":
|
|
var s string
|
|
return &s
|
|
case "INT", "BIGINT":
|
|
var i int64
|
|
return &i
|
|
case "BOOL":
|
|
var b bool
|
|
return &b
|
|
default:
|
|
panic(fmt.Errorf("Unsupported MySQL type: %s", typ.DatabaseTypeName()))
|
|
}
|
|
}
|
|
|
|
func tagMap(cols []string, typ reflect.Type) map[int]int {
|
|
fieldIndices := map[string]int{}
|
|
for i := 0; i < typ.NumField(); i++ {
|
|
tag := typ.Field(i).Tag.Get("db")
|
|
if tag != "" {
|
|
fieldIndices[tag] = i
|
|
}
|
|
}
|
|
|
|
colFields := map[int]int{}
|
|
for i, col := range cols {
|
|
if fi, ok := fieldIndices[col]; ok {
|
|
colFields[i] = fi
|
|
}
|
|
}
|
|
|
|
return colFields
|
|
}
|