diff --git a/mysql/driver/connection.go b/mysql/driver/connection.go index 02ad2ce..9c0a94a 100644 --- a/mysql/driver/connection.go +++ b/mysql/driver/connection.go @@ -12,7 +12,6 @@ import ( "context" "database/sql/driver" "net" - "strconv" "strings" "time" ) @@ -219,157 +218,6 @@ func (mc *mysqlConn) error() error { return nil } -func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) { - // Number of ? should be same to len(args) - if strings.Count(query, "?") != len(args) { - return "", driver.ErrSkip - } - - buf := mc.buf.takeCompleteBuffer() - if buf == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return "", ErrInvalidConn - } - buf = buf[:0] - argPos := 0 - - for i := 0; i < len(query); i++ { - q := strings.IndexByte(query[i:], '?') - if q == -1 { - buf = append(buf, query[i:]...) - break - } - buf = append(buf, query[i:i+q]...) - i += q - - arg := args[argPos] - argPos++ - - if arg == nil { - buf = append(buf, "NULL"...) - continue - } - - switch v := arg.(type) { - case int64: - buf = strconv.AppendInt(buf, v, 10) - case float64: - buf = strconv.AppendFloat(buf, v, 'g', -1, 64) - case bool: - if v { - buf = append(buf, '1') - } else { - buf = append(buf, '0') - } - case time.Time: - if v.IsZero() { - buf = append(buf, "'0000-00-00'"...) - } else { - v := v.In(mc.cfg.Loc) - v = v.Add(time.Nanosecond * 500) // To round under microsecond - year := v.Year() - year100 := year / 100 - year1 := year % 100 - month := v.Month() - day := v.Day() - hour := v.Hour() - minute := v.Minute() - second := v.Second() - micro := v.Nanosecond() / 1000 - - buf = append(buf, []byte{ - '\'', - digits10[year100], digits01[year100], - digits10[year1], digits01[year1], - '-', - digits10[month], digits01[month], - '-', - digits10[day], digits01[day], - ' ', - digits10[hour], digits01[hour], - ':', - digits10[minute], digits01[minute], - ':', - digits10[second], digits01[second], - }...) - - if micro != 0 { - micro10000 := micro / 10000 - micro100 := micro / 100 % 100 - micro1 := micro % 100 - buf = append(buf, []byte{ - '.', - digits10[micro10000], digits01[micro10000], - digits10[micro100], digits01[micro100], - digits10[micro1], digits01[micro1], - }...) - } - buf = append(buf, '\'') - } - case []byte: - if v == nil { - buf = append(buf, "NULL"...) - } else { - buf = append(buf, "_binary'"...) - if mc.status&statusNoBackslashEscapes == 0 { - buf = escapeBytesBackslash(buf, v) - } else { - buf = escapeBytesQuotes(buf, v) - } - buf = append(buf, '\'') - } - case string: - buf = append(buf, '\'') - if mc.status&statusNoBackslashEscapes == 0 { - buf = escapeStringBackslash(buf, v) - } else { - buf = escapeStringQuotes(buf, v) - } - buf = append(buf, '\'') - default: - return "", driver.ErrSkip - } - - if len(buf)+4 > mc.maxAllowedPacket { - return "", driver.ErrSkip - } - } - if argPos != len(args) { - return "", driver.ErrSkip - } - return string(buf), nil -} - -func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { - if mc.closed.IsSet() { - errLog.Print(ErrInvalidConn) - return nil, driver.ErrBadConn - } - if len(args) != 0 { - if !mc.cfg.InterpolateParams { - return nil, driver.ErrSkip - } - // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement - prepared, err := mc.interpolateParams(query, args) - if err != nil { - return nil, err - } - query = prepared - } - mc.affectedRows = 0 - mc.insertID = 0 - - err := mc.exec(query) - if err == nil { - return &mysqlResult{ - affectedRows: int64(mc.affectedRows), - insertID: int64(mc.insertID), - }, err - } - return nil, mc.markBadConn(err) -} - // Internal function to execute commands func (mc *mysqlConn) exec(query string) error { // Send command diff --git a/mysql/driver/connection_extended.go b/mysql/driver/connection_extended.go index a5064a6..62c690b 100644 --- a/mysql/driver/connection_extended.go +++ b/mysql/driver/connection_extended.go @@ -5,6 +5,11 @@ import ( "time" ) +// ExtendedConn provides access to internal packet functions. +type ExtendedConn struct { + *mysqlConn +} + // NewExtendedConnection creates a new connection extended with packet access // methods. func NewExtendedConnection(dsn string) (*ExtendedConn, error) { @@ -15,17 +20,9 @@ func NewExtendedConnection(dsn string) (*ExtendedConn, error) { return &ExtendedConn{conn}, nil } -// ExtendedConn provides access to internal packet functions. -type ExtendedConn struct { - *mysqlConn -} - -// Close the connection. -func (c *ExtendedConn) Close() error { - // Reset buffer length parameter - // If it's not zero bad stuff happens - c.buf.length = 0 - return c.mysqlConn.Close() +// Exec executes a query. +func (c *ExtendedConn) Exec(query string) error { + return c.exec(query) } // ReadPacket reads a packet from a given connection. If given context has a @@ -65,3 +62,11 @@ func (c *ExtendedConn) HandleErrorPacket(data []byte) error { func (c *ExtendedConn) ResetSequence() { c.sequence = 0 } + +// Close the connection. +func (c *ExtendedConn) Close() error { + // Reset buffer length parameter + // If it's not zero bad stuff happens + c.buf.length = 0 + return c.mysqlConn.Close() +} diff --git a/mysql/slave/slave_conn.go b/mysql/slave/slave_conn.go index c06ba65..3ba946a 100644 --- a/mysql/slave/slave_conn.go +++ b/mysql/slave/slave_conn.go @@ -121,8 +121,7 @@ func (c *Conn) DisableChecksum() error { // SetVar assigns a new value to the given variable. func (c *Conn) SetVar(name, val string) error { - _, err := c.conn.Exec(fmt.Sprintf("SET %s=%q", name, val), nil) - return err + return c.conn.Exec(fmt.Sprintf("SET %s=%q", name, val)) } // Close the connection.