1
0
Fork 0

Use simple exec

This commit is contained in:
Gregory Eremin 2018-11-20 22:54:50 +01:00
parent fb382876ef
commit 57d2b30132
3 changed files with 17 additions and 165 deletions

View File

@ -12,7 +12,6 @@ import (
"context"
"database/sql/driver"
"net"
"strconv"
"strings"
"time"
)
@ -219,157 +218,6 @@ func (mc *mysqlConn) error() error {
return nil
}
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
// Number of ? should be same to len(args)
if strings.Count(query, "?") != len(args) {
return "", driver.ErrSkip
}
buf := mc.buf.takeCompleteBuffer()
if buf == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return "", ErrInvalidConn
}
buf = buf[:0]
argPos := 0
for i := 0; i < len(query); i++ {
q := strings.IndexByte(query[i:], '?')
if q == -1 {
buf = append(buf, query[i:]...)
break
}
buf = append(buf, query[i:i+q]...)
i += q
arg := args[argPos]
argPos++
if arg == nil {
buf = append(buf, "NULL"...)
continue
}
switch v := arg.(type) {
case int64:
buf = strconv.AppendInt(buf, v, 10)
case float64:
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
case bool:
if v {
buf = append(buf, '1')
} else {
buf = append(buf, '0')
}
case time.Time:
if v.IsZero() {
buf = append(buf, "'0000-00-00'"...)
} else {
v := v.In(mc.cfg.Loc)
v = v.Add(time.Nanosecond * 500) // To round under microsecond
year := v.Year()
year100 := year / 100
year1 := year % 100
month := v.Month()
day := v.Day()
hour := v.Hour()
minute := v.Minute()
second := v.Second()
micro := v.Nanosecond() / 1000
buf = append(buf, []byte{
'\'',
digits10[year100], digits01[year100],
digits10[year1], digits01[year1],
'-',
digits10[month], digits01[month],
'-',
digits10[day], digits01[day],
' ',
digits10[hour], digits01[hour],
':',
digits10[minute], digits01[minute],
':',
digits10[second], digits01[second],
}...)
if micro != 0 {
micro10000 := micro / 10000
micro100 := micro / 100 % 100
micro1 := micro % 100
buf = append(buf, []byte{
'.',
digits10[micro10000], digits01[micro10000],
digits10[micro100], digits01[micro100],
digits10[micro1], digits01[micro1],
}...)
}
buf = append(buf, '\'')
}
case []byte:
if v == nil {
buf = append(buf, "NULL"...)
} else {
buf = append(buf, "_binary'"...)
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeBytesBackslash(buf, v)
} else {
buf = escapeBytesQuotes(buf, v)
}
buf = append(buf, '\'')
}
case string:
buf = append(buf, '\'')
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeStringBackslash(buf, v)
} else {
buf = escapeStringQuotes(buf, v)
}
buf = append(buf, '\'')
default:
return "", driver.ErrSkip
}
if len(buf)+4 > mc.maxAllowedPacket {
return "", driver.ErrSkip
}
}
if argPos != len(args) {
return "", driver.ErrSkip
}
return string(buf), nil
}
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) != 0 {
if !mc.cfg.InterpolateParams {
return nil, driver.ErrSkip
}
// try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
prepared, err := mc.interpolateParams(query, args)
if err != nil {
return nil, err
}
query = prepared
}
mc.affectedRows = 0
mc.insertID = 0
err := mc.exec(query)
if err == nil {
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertID: int64(mc.insertID),
}, err
}
return nil, mc.markBadConn(err)
}
// Internal function to execute commands
func (mc *mysqlConn) exec(query string) error {
// Send command

View File

@ -5,6 +5,11 @@ import (
"time"
)
// ExtendedConn provides access to internal packet functions.
type ExtendedConn struct {
*mysqlConn
}
// NewExtendedConnection creates a new connection extended with packet access
// methods.
func NewExtendedConnection(dsn string) (*ExtendedConn, error) {
@ -15,17 +20,9 @@ func NewExtendedConnection(dsn string) (*ExtendedConn, error) {
return &ExtendedConn{conn}, nil
}
// ExtendedConn provides access to internal packet functions.
type ExtendedConn struct {
*mysqlConn
}
// Close the connection.
func (c *ExtendedConn) Close() error {
// Reset buffer length parameter
// If it's not zero bad stuff happens
c.buf.length = 0
return c.mysqlConn.Close()
// Exec executes a query.
func (c *ExtendedConn) Exec(query string) error {
return c.exec(query)
}
// ReadPacket reads a packet from a given connection. If given context has a
@ -65,3 +62,11 @@ func (c *ExtendedConn) HandleErrorPacket(data []byte) error {
func (c *ExtendedConn) ResetSequence() {
c.sequence = 0
}
// Close the connection.
func (c *ExtendedConn) Close() error {
// Reset buffer length parameter
// If it's not zero bad stuff happens
c.buf.length = 0
return c.mysqlConn.Close()
}

View File

@ -121,8 +121,7 @@ func (c *Conn) DisableChecksum() error {
// SetVar assigns a new value to the given variable.
func (c *Conn) SetVar(name, val string) error {
_, err := c.conn.Exec(fmt.Sprintf("SET %s=%q", name, val), nil)
return err
return c.conn.Exec(fmt.Sprintf("SET %s=%q", name, val))
}
// Close the connection.