From 406888cdd515c5738777ddedd8d76a2842417da1 Mon Sep 17 00:00:00 2001 From: Gregory Eremin Date: Tue, 3 Jul 2018 19:47:23 +0200 Subject: [PATCH] Extract column indexing into reflect2 package --- reflect2/fields.go | 27 +++++++++++++++++++++++++++ sqldb/query_result.go | 27 ++++++--------------------- 2 files changed, 33 insertions(+), 21 deletions(-) create mode 100644 reflect2/fields.go diff --git a/reflect2/fields.go b/reflect2/fields.go new file mode 100644 index 0000000..ccbdc5f --- /dev/null +++ b/reflect2/fields.go @@ -0,0 +1,27 @@ +package reflect2 + +import "reflect" + +// TagIndex returns a map that associates tag values with field indices. +func TagIndex(typ reflect.Type, tag string) map[string]int { + tagIndex := map[string]int{} + for i := 0; i < typ.NumField(); i++ { + tag := typ.Field(i).Tag.Get(tag) + if tag != "" { + tagIndex[tag] = i + } + } + return tagIndex +} + +// AssociateColumns returns a map that associates column indices with fields. +func AssociateColumns(typ reflect.Type, tag string, cols []string) map[int]int { + tagIndex := TagIndex(typ, tag) + colFields := map[int]int{} + for i, col := range cols { + if fi, ok := tagIndex[col]; ok { + colFields[i] = fi + } + } + return colFields +} diff --git a/sqldb/query_result.go b/sqldb/query_result.go index c33dff3..6340cac 100644 --- a/sqldb/query_result.go +++ b/sqldb/query_result.go @@ -4,6 +4,8 @@ import ( "database/sql" "fmt" "reflect" + + "github.com/localhots/gobelt/reflect2" ) // QueryResult ... @@ -13,6 +15,8 @@ type QueryResult interface { Rows() *sql.Rows } +const tagName = "db" + type queryResult struct { err error rows *sql.Rows @@ -177,7 +181,7 @@ func (r *queryResult) loadStruct(typ reflect.Type, dest interface{}) { val := reflect.ValueOf(dest).Elem() vals := make([]interface{}, len(cols)) - tm := tagMap(cols, val.Type()) + tm := reflect2.AssociateColumns(val.Type(), tagName, cols) for i := range cols { if fi, ok := tm[i]; ok { fval := val.Field(fi) @@ -210,7 +214,7 @@ func (r *queryResult) loadSliceOfStructs(typ reflect.Type, dest interface{}) { vSlice := reflect.ValueOf(dest).Elem() tSlice := vSlice.Type() tElem := tSlice.Elem() - tm := tagMap(cols, tElem) + tm := reflect2.AssociateColumns(tElem, tagName, cols) for r.rows.Next() { vals := make([]interface{}, len(cols)) @@ -259,22 +263,3 @@ func newValue(typ *sql.ColumnType) interface{} { panic(fmt.Errorf("Unsupported MySQL type: %s", typ.DatabaseTypeName())) } } - -func tagMap(cols []string, typ reflect.Type) map[int]int { - fieldIndices := map[string]int{} - for i := 0; i < typ.NumField(); i++ { - tag := typ.Field(i).Tag.Get("db") - if tag != "" { - fieldIndices[tag] = i - } - } - - colFields := map[int]int{} - for i, col := range cols { - if fi, ok := fieldIndices[col]; ok { - colFields[i] = fi - } - } - - return colFields -}