// Reader is a wrapper for standard library CSV Reader that allows unmarshalling
// into a slice of structs.
type Reader struct {
	// TagName is handed to reflect2.AssociateColumns by Load together with the
	// column names; presumably it is the struct tag key used to match fields
	// to columns (the default is "csv").
	TagName string

	// ColumnNamesInFirstRow makes Load take the column names from the first
	// row it reads. When it is false the names must be supplied through
	// SetColumnNames before calling Load.
	ColumnNamesInFirstRow bool

	// Reader is the underlying CSV reader that rows are read from.
	Reader *csv.Reader

	// cols holds the column names, either assigned via SetColumnNames or read
	// from the first row by Load.
	cols []string
}

const defaultTagName = "csv"

// NewReader creates a new reader from a standard CSV Reader. The returned
// Reader uses the "csv" struct tag and expects column names in the first row.
func NewReader(r *csv.Reader) *Reader {
	return &Reader{
		TagName:               defaultTagName,
		ColumnNamesInFirstRow: true,
		Reader:                r,
	}
}

// SetColumnNames assigns column names. Use this function if column names are
// not provided in the first row.
//
// It panics when ColumnNamesInFirstRow is set, because the names would then
// conflict with the ones expected in the first row. The given slice is copied,
// so later changes made by the caller do not affect the reader.
func (r *Reader) SetColumnNames(cols []string) {
	if r.ColumnNamesInFirstRow {
		panic("Should not assign column names when they are expected in the first row")
	}
	if cols == nil {
		// Keep nil as nil: Load treats nil column names as "not defined".
		r.cols = nil
		return
	}
	// make+copy (rather than append to a nil slice) keeps an empty, non-nil
	// input non-nil, exactly as a plain assignment would.
	r.cols = make([]string, len(cols))
	copy(r.cols, cols)
}

// Load reads CSV contents and unmarshals them into given destination.
// Destination value must be a pointer to a slice of structs.
+func (r *Reader) Load(dest interface{}) error { + destT := reflect.TypeOf(dest) + if destT.Kind() != reflect.Ptr || + destT.Elem().Kind() != reflect.Slice || + destT.Elem().Elem().Kind() != reflect.Struct { + return errors.New("Destination must be a pointer to a slice of structs") + } + + if r.cols == nil { + if r.ColumnNamesInFirstRow { + cols, err := r.Reader.Read() + if err != nil { + return errors.Annotate(err, "Failed to read column names from first row") + } + r.cols = cols + } else { + return errors.New("Column names are not defined") + } + } + + destV := reflect.ValueOf(dest).Elem() + valT := destT.Elem().Elem() + colIndex := reflect2.AssociateColumns(valT, r.TagName, r.cols) + + for { + row, err := r.Reader.Read() + if err != nil { + if err == io.EOF { + break + } + return errors.Annotate(err, "Failed to read CSV row") + } + + val := reflect.New(valT).Elem() + for iCol, iField := range colIndex { + err := unmarshal(row[iCol], val.Field(iField)) + if err != nil { + return errors.Annotate(err, "Failed to process CSV row") + } + } + + destV.Set(reflect.Append(destV, val)) + } + return nil +} + +func unmarshal(v string, dest reflect.Value) error { + switch dest.Kind() { + case reflect.String: + dest.SetString(v) + case reflect.Bool: + b, err := strconv.ParseBool(v) + if err != nil { + return unmarshalError(err, v, dest.Kind()) + } + dest.SetBool(b) + case reflect.Int, reflect.Int64: + return unmarshalInt(v, dest, 64) + case reflect.Int8: + return unmarshalInt(v, dest, 8) + case reflect.Int16: + return unmarshalInt(v, dest, 16) + case reflect.Int32: + return unmarshalInt(v, dest, 32) + case reflect.Uint, reflect.Uint64: + return unmarshalUint(v, dest, 64) + case reflect.Uint8: + return unmarshalUint(v, dest, 8) + case reflect.Uint16: + return unmarshalUint(v, dest, 16) + case reflect.Uint32: + return unmarshalUint(v, dest, 32) + case reflect.Float32: + return unmarshalFloat(v, dest, 32) + case reflect.Float64: + return unmarshalFloat(v, dest, 64) + } + + 
return nil +} + +func unmarshalInt(v string, dest reflect.Value, bitSize int) error { + i, err := strconv.ParseInt(v, 10, bitSize) + if err != nil { + return unmarshalError(err, v, dest.Kind()) + } + dest.SetInt(i) + return nil +} + +func unmarshalUint(v string, dest reflect.Value, bitSize int) error { + i, err := strconv.ParseUint(v, 10, bitSize) + if err != nil { + return unmarshalError(err, v, dest.Kind()) + } + dest.SetUint(i) + return nil +} + +func unmarshalFloat(v string, dest reflect.Value, bitSize int) error { + f, err := strconv.ParseFloat(v, bitSize) + if err != nil { + return unmarshalError(err, v, dest.Kind()) + } + dest.SetFloat(f) + return nil +} + +func unmarshalError(err error, v string, k reflect.Kind) error { + return errors.Annotatef(err, "Can't unmarshal %q into value of type %s", v, k) +} diff --git a/csv2/struct_test.go b/csv2/struct_test.go new file mode 100644 index 0000000..cd2dfe8 --- /dev/null +++ b/csv2/struct_test.go @@ -0,0 +1,42 @@ +package csv2 + +import ( + "encoding/csv" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestLoad(t *testing.T) { + type specie struct { + Name string `csv:"name"` + FavFood string `csv:"fav_food"` + Age uint16 `csv:"age"` + Weight float32 `csv:"weight"` + Available bool `csv:"available"` + } + body := `name,available,fav_food,weight,age +Alice,false,Bananas,19.22,5 +Frank,true,Burrito,14,9 +Joel,true,Pesto,32.5,21` + exp := []specie{ + {Name: "Alice", FavFood: "Bananas", Age: 5, Weight: 19.22, Available: false}, + {Name: "Frank", FavFood: "Burrito", Age: 9, Weight: 14, Available: true}, + {Name: "Joel", FavFood: "Pesto", Age: 21, Weight: 32.5, Available: true}, + } + + csvReader := csv.NewReader(strings.NewReader(body)) + r := NewReader(csvReader) + r.ColumnNamesInFirstRow = true + + var out []specie + err := r.Load(&out) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if !cmp.Equal(exp, out) { + t.Errorf("Result value is different: %s", cmp.Diff(exp, out)) + } +}