Extract column indexing into reflect2 package
This commit is contained in:
parent
4aae29e4bf
commit
406888cdd5
|
@ -0,0 +1,27 @@
|
|||
package reflect2
|
||||
|
||||
import "reflect"
|
||||
|
||||
// TagIndex returns a map that associates tag values with field indices.
|
||||
func TagIndex(typ reflect.Type, tag string) map[string]int {
|
||||
tagIndex := map[string]int{}
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
tag := typ.Field(i).Tag.Get(tag)
|
||||
if tag != "" {
|
||||
tagIndex[tag] = i
|
||||
}
|
||||
}
|
||||
return tagIndex
|
||||
}
|
||||
|
||||
// AssociateColumns returns a map that associates column indices with fields.
|
||||
func AssociateColumns(typ reflect.Type, tag string, cols []string) map[int]int {
|
||||
tagIndex := TagIndex(typ, tag)
|
||||
colFields := map[int]int{}
|
||||
for i, col := range cols {
|
||||
if fi, ok := tagIndex[col]; ok {
|
||||
colFields[i] = fi
|
||||
}
|
||||
}
|
||||
return colFields
|
||||
}
|
|
@ -4,6 +4,8 @@ import (
|
|||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/localhots/gobelt/reflect2"
|
||||
)
|
||||
|
||||
// QueryResult ...
|
||||
|
@ -13,6 +15,8 @@ type QueryResult interface {
|
|||
Rows() *sql.Rows
|
||||
}
|
||||
|
||||
const tagName = "db"
|
||||
|
||||
type queryResult struct {
|
||||
err error
|
||||
rows *sql.Rows
|
||||
|
@ -177,7 +181,7 @@ func (r *queryResult) loadStruct(typ reflect.Type, dest interface{}) {
|
|||
|
||||
val := reflect.ValueOf(dest).Elem()
|
||||
vals := make([]interface{}, len(cols))
|
||||
tm := tagMap(cols, val.Type())
|
||||
tm := reflect2.AssociateColumns(val.Type(), tagName, cols)
|
||||
for i := range cols {
|
||||
if fi, ok := tm[i]; ok {
|
||||
fval := val.Field(fi)
|
||||
|
@ -210,7 +214,7 @@ func (r *queryResult) loadSliceOfStructs(typ reflect.Type, dest interface{}) {
|
|||
vSlice := reflect.ValueOf(dest).Elem()
|
||||
tSlice := vSlice.Type()
|
||||
tElem := tSlice.Elem()
|
||||
tm := tagMap(cols, tElem)
|
||||
tm := reflect2.AssociateColumns(tElem, tagName, cols)
|
||||
|
||||
for r.rows.Next() {
|
||||
vals := make([]interface{}, len(cols))
|
||||
|
@ -259,22 +263,3 @@ func newValue(typ *sql.ColumnType) interface{} {
|
|||
panic(fmt.Errorf("Unsupported MySQL type: %s", typ.DatabaseTypeName()))
|
||||
}
|
||||
}
|
||||
|
||||
func tagMap(cols []string, typ reflect.Type) map[int]int {
|
||||
fieldIndices := map[string]int{}
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
tag := typ.Field(i).Tag.Get("db")
|
||||
if tag != "" {
|
||||
fieldIndices[tag] = i
|
||||
}
|
||||
}
|
||||
|
||||
colFields := map[int]int{}
|
||||
for i, col := range cols {
|
||||
if fi, ok := fieldIndices[col]; ok {
|
||||
colFields[i] = fi
|
||||
}
|
||||
}
|
||||
|
||||
return colFields
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue