// Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. package driver import ( "context" "database/sql/driver" "net" "strconv" "strings" "time" ) type mysqlConn struct { buf buffer netConn net.Conn affectedRows uint64 insertID uint64 cfg *Config maxAllowedPacket int maxWriteSize int writeTimeout time.Duration flags clientFlag status statusFlag sequence uint8 parseTime bool // for context support (Go 1.8+) watching bool watcher chan<- context.Context closech chan struct{} finished chan<- struct{} canceled atomicError // set non-nil if conn is canceled closed atomicBool // set when conn is closed, before closech is closed } func newConnection(dsn string) (*mysqlConn, error) { var err error // New mysqlConn mc := &mysqlConn{ maxAllowedPacket: maxPacketSize, maxWriteSize: maxPacketSize - 1, closech: make(chan struct{}), } mc.cfg, err = ParseDSN(dsn) if err != nil { return nil, err } mc.parseTime = mc.cfg.ParseTime // Connect to Server nd := net.Dialer{Timeout: mc.cfg.Timeout} mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) if err != nil { return nil, err } // Enable TCP Keepalives on TCP connections if tc, ok := mc.netConn.(*net.TCPConn); ok { if err := tc.SetKeepAlive(true); err != nil { // Don't send COM_QUIT before handshake. mc.netConn.Close() mc.netConn = nil return nil, err } } // Call startWatcher for context support (From Go 1.8) mc.startWatcher() mc.buf = newBuffer(mc.netConn) // Set I/O timeouts mc.buf.timeout = mc.cfg.ReadTimeout mc.writeTimeout = mc.cfg.WriteTimeout // Reading Handshake Initialization Packet authData, plugin, err := mc.readHandshakePacket() if err != nil { mc.cleanup() return nil, err } // Send Client Authentication Packet authResp, addNUL, err := mc.auth(authData, plugin) if err != nil { // try the default auth plugin, if using the requested plugin failed errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) plugin = defaultAuthPlugin authResp, addNUL, err = mc.auth(authData, plugin) if err != nil { mc.cleanup() return nil, err } } if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil { mc.cleanup() return nil, err } // Handle response to auth packet, switch methods if possible if err = mc.handleAuthResult(authData, plugin); err != nil { // Authentication failed and MySQL has already closed the connection // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). // Do not send COM_QUIT, just cleanup and return the error. mc.cleanup() return nil, err } if mc.cfg.MaxAllowedPacket > 0 { mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket } else { // Get max allowed packet size maxap, err := mc.getSystemVar("max_allowed_packet") if err != nil { mc.Close() return nil, err } mc.maxAllowedPacket = stringToInt(maxap) - 1 } if mc.maxAllowedPacket < maxPacketSize { mc.maxWriteSize = mc.maxAllowedPacket } // Handle DSN Params err = mc.handleParams() if err != nil { mc.Close() return nil, err } return mc, nil } // Handles parameters set in DSN after the connection is established func (mc *mysqlConn) handleParams() (err error) { for param, val := range mc.cfg.Params { switch param { // Charset case "charset": charsets := strings.Split(val, ",") for i := range charsets { // ignore errors here - a charset may not exist err = mc.exec("SET NAMES " + charsets[i]) if err == nil { break } } if err != nil { return } // System Vars default: err = mc.exec("SET " + param + "=" + val) if err != nil { return } } } return } func (mc *mysqlConn) markBadConn(err error) error { if mc == nil { return err } if err != errBadConnNoWrite { return err } return driver.ErrBadConn } func (mc *mysqlConn) Close() (err error) { // Makes Close idempotent if !mc.closed.IsSet() { err = mc.writeCommandPacket(comQuit) } mc.cleanup() return } // Closes the network connection and unsets internal variables. Do not call this // function after successfully authentication, call Close instead. This function // is called before auth or on auth failure because MySQL will have already // closed the network connection. func (mc *mysqlConn) cleanup() { if !mc.closed.TrySet(true) { return } // Makes cleanup idempotent close(mc.closech) if mc.netConn == nil { return } if err := mc.netConn.Close(); err != nil { errLog.Print(err) } } func (mc *mysqlConn) error() error { if mc.closed.IsSet() { if err := mc.canceled.Value(); err != nil { return err } return ErrInvalidConn } return nil } func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) { // Number of ? should be same to len(args) if strings.Count(query, "?") != len(args) { return "", driver.ErrSkip } buf := mc.buf.takeCompleteBuffer() if buf == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) return "", ErrInvalidConn } buf = buf[:0] argPos := 0 for i := 0; i < len(query); i++ { q := strings.IndexByte(query[i:], '?') if q == -1 { buf = append(buf, query[i:]...) break } buf = append(buf, query[i:i+q]...) i += q arg := args[argPos] argPos++ if arg == nil { buf = append(buf, "NULL"...) continue } switch v := arg.(type) { case int64: buf = strconv.AppendInt(buf, v, 10) case float64: buf = strconv.AppendFloat(buf, v, 'g', -1, 64) case bool: if v { buf = append(buf, '1') } else { buf = append(buf, '0') } case time.Time: if v.IsZero() { buf = append(buf, "'0000-00-00'"...) } else { v := v.In(mc.cfg.Loc) v = v.Add(time.Nanosecond * 500) // To round under microsecond year := v.Year() year100 := year / 100 year1 := year % 100 month := v.Month() day := v.Day() hour := v.Hour() minute := v.Minute() second := v.Second() micro := v.Nanosecond() / 1000 buf = append(buf, []byte{ '\'', digits10[year100], digits01[year100], digits10[year1], digits01[year1], '-', digits10[month], digits01[month], '-', digits10[day], digits01[day], ' ', digits10[hour], digits01[hour], ':', digits10[minute], digits01[minute], ':', digits10[second], digits01[second], }...) if micro != 0 { micro10000 := micro / 10000 micro100 := micro / 100 % 100 micro1 := micro % 100 buf = append(buf, []byte{ '.', digits10[micro10000], digits01[micro10000], digits10[micro100], digits01[micro100], digits10[micro1], digits01[micro1], }...) } buf = append(buf, '\'') } case []byte: if v == nil { buf = append(buf, "NULL"...) } else { buf = append(buf, "_binary'"...) if mc.status&statusNoBackslashEscapes == 0 { buf = escapeBytesBackslash(buf, v) } else { buf = escapeBytesQuotes(buf, v) } buf = append(buf, '\'') } case string: buf = append(buf, '\'') if mc.status&statusNoBackslashEscapes == 0 { buf = escapeStringBackslash(buf, v) } else { buf = escapeStringQuotes(buf, v) } buf = append(buf, '\'') default: return "", driver.ErrSkip } if len(buf)+4 > mc.maxAllowedPacket { return "", driver.ErrSkip } } if argPos != len(args) { return "", driver.ErrSkip } return string(buf), nil } func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } if len(args) != 0 { if !mc.cfg.InterpolateParams { return nil, driver.ErrSkip } // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement prepared, err := mc.interpolateParams(query, args) if err != nil { return nil, err } query = prepared } mc.affectedRows = 0 mc.insertID = 0 err := mc.exec(query) if err == nil { return &mysqlResult{ affectedRows: int64(mc.affectedRows), insertID: int64(mc.insertID), }, err } return nil, mc.markBadConn(err) } // Internal function to execute commands func (mc *mysqlConn) exec(query string) error { // Send command if err := mc.writeCommandPacketStr(comQuery, query); err != nil { return mc.markBadConn(err) } // Read Result resLen, err := mc.readResultSetHeaderPacket() if err != nil { return err } if resLen > 0 { // columns if err := mc.readUntilEOF(); err != nil { return err } // rows if err := mc.readUntilEOF(); err != nil { return err } } return mc.discardResults() } // Gets the value of the given MySQL System Variable // The returned byte slice is only valid until the next read func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { // Send command if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil { return nil, err } // Read Result resLen, err := mc.readResultSetHeaderPacket() if err == nil { rows := new(textRows) rows.mc = mc rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} if resLen > 0 { // Columns if err := mc.readUntilEOF(); err != nil { return nil, err } } dest := make([]driver.Value, resLen) if err = rows.readRow(dest); err == nil { return dest[0].([]byte), mc.readUntilEOF() } } return nil, err } // finish is called when the query has canceled. func (mc *mysqlConn) cancel(err error) { mc.canceled.Set(err) mc.cleanup() } // finish is called when the query has succeeded. func (mc *mysqlConn) finish() { if !mc.watching || mc.finished == nil { return } select { case mc.finished <- struct{}{}: mc.watching = false case <-mc.closech: } } // Ping implements driver.Pinger interface func (mc *mysqlConn) Ping(ctx context.Context) (err error) { if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return driver.ErrBadConn } if err = mc.watchCancel(ctx); err != nil { return } defer mc.finish() if err = mc.writeCommandPacket(comPing); err != nil { return } return mc.readResultOK() } func (mc *mysqlConn) watchCancel(ctx context.Context) error { if mc.watching { // Reach here if canceled, // so the connection is already invalid mc.cleanup() return nil } if ctx.Done() == nil { return nil } mc.watching = true select { default: case <-ctx.Done(): return ctx.Err() } if mc.watcher == nil { return nil } mc.watcher <- ctx return nil } func (mc *mysqlConn) startWatcher() { watcher := make(chan context.Context, 1) mc.watcher = watcher finished := make(chan struct{}) mc.finished = finished go func() { for { var ctx context.Context select { case ctx = <-watcher: case <-mc.closech: return } select { case <-ctx.Done(): mc.cancel(ctx.Err()) case <-finished: case <-mc.closech: return } } }() } // ResetSession implements driver.SessionResetter. // (From Go 1.10) func (mc *mysqlConn) ResetSession(ctx context.Context) error { if mc.closed.IsSet() { return driver.ErrBadConn } return nil }