1
0
Fork 0

Add support for multi-column tables to tests

This commit is contained in:
Gregory Eremin 2018-11-11 20:57:41 +01:00
parent 07fed4b4eb
commit cb10d7cecf
2 changed files with 51 additions and 38 deletions

View File

@ -47,7 +47,7 @@ func TestSet(t *testing.T) {
for in, exp := range inputs { for in, exp := range inputs {
t.Run("input "+in, func(t *testing.T) { t.Run("input "+in, func(t *testing.T) {
suite.insertAndCompareExp(t, tbl, in, exp) suite.insertAndCompareExp(t, tbl, iSlice(in), iSlice(exp))
}) })
} }
} }
@ -64,7 +64,11 @@ func TestEnum(t *testing.T) {
} }
for in, exp := range inputs { for in, exp := range inputs {
t.Run("input "+in, func(t *testing.T) { t.Run("input "+in, func(t *testing.T) {
suite.insertAndCompareExp(t, tbl, in, exp) suite.insertAndCompareExp(t, tbl, iSlice(in), iSlice(exp))
}) })
} }
} }
func iSlice(i interface{}) []interface{} {
return []interface{}{i}
}

View File

@ -29,10 +29,15 @@ const (
) )
type table struct { type table struct {
name string name string
colTyp mysql.ColumnType cols []column
conn *sql.DB
}
type column struct {
typ mysql.ColumnType
length string
attrs byte attrs byte
conn *sql.DB
} }
// //
@ -40,8 +45,16 @@ type table struct {
// //
func (s *testSuite) createTable(typ mysql.ColumnType, length string, attrs byte) *table { func (s *testSuite) createTable(typ mysql.ColumnType, length string, attrs byte) *table {
name := strings.ToLower(typ.String()) + fmt.Sprintf("_test_%d", time.Now().UnixNano()) return s.createTableMulti(column{typ, length, attrs})
cols := colDef(typ, length, attrs) }
func (s *testSuite) createTableMulti(cols ...column) *table {
name := fmt.Sprintf("test_table_%d", time.Now().UnixNano())
colDefs := make([]string, len(cols))
for i, col := range cols {
colDefs[i] = colDef(col.typ, col.length, col.attrs)
}
colsDefStr := strings.Join(colDefs, ",\n\t")
_, err := s.conn.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS %s`, name)) _, err := s.conn.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS %s`, name))
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
@ -49,10 +62,10 @@ func (s *testSuite) createTable(typ mysql.ColumnType, length string, attrs byte)
tableQuery := fmt.Sprintf(`CREATE TABLE %s ( tableQuery := fmt.Sprintf(`CREATE TABLE %s (
%s %s
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4`, name, cols) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4`, name, colsDefStr)
fmt.Println("--------") fmt.Println("--------")
fmt.Printf("-- Creating test table: type %s\n", typ.String()) fmt.Println("-- Creating test table")
fmt.Println(tableQuery) fmt.Println(tableQuery)
fmt.Println("--------") fmt.Println("--------")
@ -61,21 +74,18 @@ func (s *testSuite) createTable(typ mysql.ColumnType, length string, attrs byte)
log.Fatal(err) log.Fatal(err)
} }
return &table{ return &table{
name: name, name: name,
colTyp: typ, cols: cols,
attrs: attrs, conn: s.conn,
conn: s.conn,
} }
} }
func (tbl *table) insert(t *testing.T, val interface{}) { func (tbl *table) insert(t *testing.T, vals ...interface{}) {
t.Helper() t.Helper()
if val == nil { ph := strings.Repeat("?,", len(vals))
val = "NULL" ph = ph[:len(ph)-1]
}
// log.Printf("Table: %s Value: %v", tbl.name, val) _, err := tbl.conn.Exec(fmt.Sprintf(`INSERT INTO %s VALUES (%s)`, tbl.name, ph), vals...)
_, err := tbl.conn.Exec(fmt.Sprintf(`INSERT INTO %s VALUES (?)`, tbl.name), val)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -176,21 +186,21 @@ func colTypeSyntax(ct mysql.ColumnType) (typName, attrs string) {
// Expectations // Expectations
// //
func (s *testSuite) insertAndCompare(t *testing.T, tbl *table, val interface{}) { func (s *testSuite) insertAndCompare(t *testing.T, tbl *table, vals ...interface{}) {
t.Helper() t.Helper()
tbl.insert(t, val) tbl.insert(t, vals...)
suite.expectValue(t, tbl, val) suite.expectValue(t, tbl, vals)
} }
func (s *testSuite) insertAndCompareExp(t *testing.T, tbl *table, val, exp interface{}) { func (s *testSuite) insertAndCompareExp(t *testing.T, tbl *table, vals, exps []interface{}) {
t.Helper() t.Helper()
tbl.insert(t, val) tbl.insert(t, vals...)
suite.expectValue(t, tbl, exp) suite.expectValue(t, tbl, exps)
} }
func (s *testSuite) expectValue(t *testing.T, tbl *table, exp interface{}) { func (s *testSuite) expectValue(t *testing.T, tbl *table, exp []interface{}) {
t.Helper() t.Helper()
out := make(chan interface{}) out := make(chan []interface{})
go func() { go func() {
for { for {
evt, err := suite.reader.ReadEvent() evt, err := suite.reader.ReadEvent()
@ -203,11 +213,11 @@ func (s *testSuite) expectValue(t *testing.T, tbl *table, exp interface{}) {
if err != nil { if err != nil {
t.Fatalf("Failed to decode rows event: %v", err) t.Fatalf("Failed to decode rows event: %v", err)
} }
if len(re.Rows) != 1 && len(re.Rows[0]) != 1 { if len(re.Rows) != 1 {
t.Fatal("Expected 1 row with 1 value") t.Fatal("Expected 1 row")
} }
out <- re.Rows[0][0] out <- re.Rows[0]
return return
} }
} }
@ -215,18 +225,18 @@ func (s *testSuite) expectValue(t *testing.T, tbl *table, exp interface{}) {
select { select {
case res := <-out: case res := <-out:
s.compare(t, tbl, exp, res) for i := range res {
s.compare(t, tbl.cols[i], exp[i], res[i])
}
case <-time.After(3 * time.Second): case <-time.After(3 * time.Second):
t.Fatalf("Value was not received") t.Fatalf("Value was not received")
} }
} }
func (s *testSuite) compare(t *testing.T, tbl *table, exp, res interface{}) { func (s *testSuite) compare(t *testing.T, col column, exp, res interface{}) {
// Sign integer if necessary // Sign integer if necessary
if attrUnsigned&tbl.attrs == 0 { if attrUnsigned&col.attrs == 0 {
// old := res res = signNumber(res, col.typ)
res = signNumber(res, tbl.colTyp)
// t.Logf("Converted unsigned %d into signed %d", old, res)
} }
// Expectations would be pointers for null types, dereference them because // Expectations would be pointers for null types, dereference them because
@ -240,10 +250,9 @@ func (s *testSuite) compare(t *testing.T, tbl *table, exp, res interface{}) {
} }
} }
// fmt.Printf("VALUE RECEIVED: %T(%+v), EXPECTED: %T(%+v)\n", res, res, exp, exp)
switch texp := exp.(type) { switch texp := exp.(type) {
case []byte: case []byte:
switch tbl.colTyp { switch col.typ {
case mysql.ColumnTypeJSON: case mysql.ColumnTypeJSON:
var jExp, jRes interface{} var jExp, jRes interface{}
if err := json.Unmarshal(texp, &jExp); err != nil { if err := json.Unmarshal(texp, &jExp); err != nil {