1
0
Fork 0
gobelt/sqldb/result.go

313 lines
5.7 KiB
Go

package sqldb
import (
"database/sql"
"fmt"
"reflect"
)
// Result represents query result.
type Result interface {
// Load decodes rows into provided variable.
Load(dest interface{}) Result
// Error returns an error if one happened during query execution.
Error() error
// Rows returns original database rows object.
Rows() *sql.Rows
// LastInsertID returns the last inserted record ID for results obtained
// from Exec calls.
LastInsertID() int64
// RowsAffected returns the number of rows affected for results obtained
// from Exec calls.
RowsAffected() int64
}
type result struct {
rows *sql.Rows
res sql.Result
err error
}
func (r result) Rows() *sql.Rows {
return r.rows
}
func (r result) Error() error {
return r.err
}
func (r result) LastInsertID() int64 {
if r.res == nil {
return 0
}
id, err := r.res.LastInsertId()
if err != nil {
return 0
}
return id
}
func (r result) RowsAffected() int64 {
if r.res == nil {
return 0
}
ra, err := r.res.RowsAffected()
if err != nil {
return 0
}
return ra
}
func (r result) Load(dest interface{}) Result {
if r.err != nil {
return r
}
defer r.rows.Close()
dtyp := reflect.TypeOf(dest)
if dtyp.Kind() != reflect.Ptr {
panic("Value must be a pointer")
}
dtyp = dtyp.Elem()
switch dtyp.Kind() {
case reflect.Struct:
r.loadStruct(dtyp, dest)
case reflect.Map:
r.loadMap(dest.(*map[string]interface{}))
case reflect.Slice:
switch dtyp.Elem().Kind() {
case reflect.Struct:
r.loadSliceOfStructs(dtyp, dest)
case reflect.Map:
r.loadSliceOfMaps(dest.(*[]map[string]interface{}))
default:
r.loadSlice(dtyp, dest)
}
default:
r.loadValue(dest)
}
if r.err == nil && r.rows.Err() != nil {
return r.withError(r.rows.Err())
}
return r
}
func (r *result) loadValue(dest interface{}) {
if r.rows.Next() {
r.err = r.rows.Scan(dest)
}
}
func (r *result) loadSlice(typ reflect.Type, dest interface{}) {
vSlice := reflect.MakeSlice(typ, 0, 0)
for r.rows.Next() {
val := reflect.New(typ.Elem())
r.err = r.rows.Scan(val.Interface())
if r.err != nil {
return
}
vSlice = reflect.Append(vSlice, val.Elem())
}
reflect.ValueOf(dest).Elem().Set(vSlice)
}
func (r *result) loadMap(dest *map[string]interface{}) {
if !r.rows.Next() {
return
}
cols, err := r.rows.Columns()
if err != nil {
r.err = err
return
}
colTypes, err := r.rows.ColumnTypes()
if err != nil {
r.err = err
return
}
vals := make([]interface{}, len(cols))
for i := range cols {
vals[i] = newValue(colTypes[i])
}
err = r.rows.Scan(vals...)
if err != nil {
r.err = err
return
}
if *dest == nil {
*dest = make(map[string]interface{}, len(cols))
}
for i, col := range cols {
switch tval := vals[i].(type) {
case *int64:
(*dest)[col] = *tval
case *string:
(*dest)[col] = *tval
case *bool:
(*dest)[col] = *tval
}
}
}
func (r *result) loadSliceOfMaps(dest *[]map[string]interface{}) {
cols, err := r.rows.Columns()
if err != nil {
r.err = err
return
}
colTypes, err := r.rows.ColumnTypes()
if err != nil {
r.err = err
return
}
if *dest == nil {
*dest = make([]map[string]interface{}, 0)
}
for r.rows.Next() {
vals := make([]interface{}, len(cols))
for i := range cols {
vals[i] = newValue(colTypes[i])
}
err = r.rows.Scan(vals...)
if err != nil {
r.err = err
return
}
row := make(map[string]interface{}, len(cols))
for i, col := range cols {
switch tval := vals[i].(type) {
case *int64:
row[col] = *tval
case *string:
row[col] = *tval
case *bool:
row[col] = *tval
}
}
*dest = append(*dest, row)
}
}
func (r *result) loadStruct(typ reflect.Type, dest interface{}) {
if !r.rows.Next() {
return
}
cols, err := r.rows.Columns()
if err != nil {
r.err = err
return
}
val := reflect.ValueOf(dest).Elem()
vals := make([]interface{}, len(cols))
tm := tagMap(cols, val.Type())
for i := range cols {
if fi, ok := tm[i]; ok {
fval := val.Field(fi)
vals[i] = reflect.New(fval.Type()).Interface()
} else {
var dummy interface{}
vals[i] = &dummy
}
}
err = r.rows.Scan(vals...)
if err != nil {
r.err = err
return
}
for i, fi := range tm {
fval := val.Field(fi)
fval.Set(reflect.ValueOf(vals[i]).Elem())
}
}
func (r *result) loadSliceOfStructs(typ reflect.Type, dest interface{}) {
cols, err := r.rows.Columns()
if err != nil {
r.err = err
return
}
vSlice := reflect.ValueOf(dest).Elem()
tSlice := vSlice.Type()
tElem := tSlice.Elem()
tm := tagMap(cols, tElem)
for r.rows.Next() {
vals := make([]interface{}, len(cols))
val := reflect.New(tElem).Elem()
for i := range cols {
if fi, ok := tm[i]; ok {
fval := val.Field(fi)
vals[i] = reflect.New(fval.Type()).Interface()
} else {
var dummy interface{}
vals[i] = &dummy
}
}
err = r.rows.Scan(vals...)
if err != nil {
r.err = err
return
}
for i, fi := range tm {
fval := val.Field(fi)
fval.Set(reflect.ValueOf(vals[i]).Elem())
}
vSlice.Set(reflect.Append(vSlice, val))
}
}
func (r result) withError(err error) Result {
r.err = err
return r
}
func newValue(typ *sql.ColumnType) interface{} {
switch typ.DatabaseTypeName() {
case "VARCHAR", "NVARCHAR", "TEXT":
var s string
return &s
case "INT", "BIGINT":
var i int64
return &i
case "BOOL":
var b bool
return &b
default:
panic(fmt.Errorf("Unsupported MySQL type: %s", typ.DatabaseTypeName()))
}
}
func tagMap(cols []string, typ reflect.Type) map[int]int {
fieldIndices := map[string]int{}
for i := 0; i < typ.NumField(); i++ {
tag := typ.Field(i).Tag.Get("db")
if tag != "" {
fieldIndices[tag] = i
}
}
colFields := map[int]int{}
for i, col := range cols {
if fi, ok := fieldIndices[col]; ok {
colFields[i] = fi
}
}
return colFields
}