Don't mess with driver code, include it as-is
It doesn't seem worth the trouble to reduce amount of code in the driver. Removing automatic registration (along with distinguishing between network timeouts and context timeouts) seem sufficient.
This commit is contained in:
parent
2f12207dbe
commit
05009b19d7
|
@ -1,374 +0,0 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type mysqlConn struct {
|
||||
buf buffer
|
||||
netConn net.Conn
|
||||
affectedRows uint64
|
||||
insertID uint64
|
||||
cfg *Config
|
||||
maxAllowedPacket int
|
||||
maxWriteSize int
|
||||
writeTimeout time.Duration
|
||||
flags clientFlag
|
||||
status statusFlag
|
||||
sequence uint8
|
||||
parseTime bool
|
||||
|
||||
// for context support (Go 1.8+)
|
||||
watching bool
|
||||
watcher chan<- context.Context
|
||||
closech chan struct{}
|
||||
finished chan<- struct{}
|
||||
canceled atomicError // set non-nil if conn is canceled
|
||||
closed atomicBool // set when conn is closed, before closech is closed
|
||||
}
|
||||
|
||||
func newConnection(dsn string) (*mysqlConn, error) {
|
||||
var err error
|
||||
|
||||
// New mysqlConn
|
||||
mc := &mysqlConn{
|
||||
maxAllowedPacket: maxPacketSize,
|
||||
maxWriteSize: maxPacketSize - 1,
|
||||
closech: make(chan struct{}),
|
||||
}
|
||||
mc.cfg, err = ParseDSN(dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mc.parseTime = mc.cfg.ParseTime
|
||||
|
||||
// Connect to Server
|
||||
nd := net.Dialer{Timeout: mc.cfg.Timeout}
|
||||
mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Enable TCP Keepalives on TCP connections
|
||||
if tc, ok := mc.netConn.(*net.TCPConn); ok {
|
||||
if err := tc.SetKeepAlive(true); err != nil {
|
||||
// Don't send COM_QUIT before handshake.
|
||||
mc.netConn.Close()
|
||||
mc.netConn = nil
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Call startWatcher for context support (From Go 1.8)
|
||||
mc.startWatcher()
|
||||
|
||||
mc.buf = newBuffer(mc.netConn)
|
||||
|
||||
// Set I/O timeouts
|
||||
mc.buf.timeout = mc.cfg.ReadTimeout
|
||||
mc.writeTimeout = mc.cfg.WriteTimeout
|
||||
|
||||
// Reading Handshake Initialization Packet
|
||||
authData, plugin, err := mc.readHandshakePacket()
|
||||
if err != nil {
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
// try the default auth plugin, if using the requested plugin failed
|
||||
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
|
||||
plugin = defaultAuthPlugin
|
||||
authResp, addNUL, err = mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil {
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Handle response to auth packet, switch methods if possible
|
||||
if err = mc.handleAuthResult(authData, plugin); err != nil {
|
||||
// Authentication failed and MySQL has already closed the connection
|
||||
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
|
||||
// Do not send COM_QUIT, just cleanup and return the error.
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if mc.cfg.MaxAllowedPacket > 0 {
|
||||
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
|
||||
} else {
|
||||
// Get max allowed packet size
|
||||
maxap, err := mc.getSystemVar("max_allowed_packet")
|
||||
if err != nil {
|
||||
mc.Close()
|
||||
return nil, err
|
||||
}
|
||||
mc.maxAllowedPacket = stringToInt(maxap) - 1
|
||||
}
|
||||
if mc.maxAllowedPacket < maxPacketSize {
|
||||
mc.maxWriteSize = mc.maxAllowedPacket
|
||||
}
|
||||
|
||||
// Handle DSN Params
|
||||
err = mc.handleParams()
|
||||
if err != nil {
|
||||
mc.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mc, nil
|
||||
}
|
||||
|
||||
// Handles parameters set in DSN after the connection is established
|
||||
func (mc *mysqlConn) handleParams() (err error) {
|
||||
for param, val := range mc.cfg.Params {
|
||||
switch param {
|
||||
// Charset
|
||||
case "charset":
|
||||
charsets := strings.Split(val, ",")
|
||||
for i := range charsets {
|
||||
// ignore errors here - a charset may not exist
|
||||
err = mc.exec("SET NAMES " + charsets[i])
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// System Vars
|
||||
default:
|
||||
err = mc.exec("SET " + param + "=" + val)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) markBadConn(err error) error {
|
||||
if mc == nil {
|
||||
return err
|
||||
}
|
||||
if err != errBadConnNoWrite {
|
||||
return err
|
||||
}
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) Close() (err error) {
|
||||
// Makes Close idempotent
|
||||
if !mc.closed.IsSet() {
|
||||
err = mc.writeCommandPacket(comQuit)
|
||||
}
|
||||
|
||||
mc.cleanup()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Closes the network connection and unsets internal variables. Do not call this
|
||||
// function after successfully authentication, call Close instead. This function
|
||||
// is called before auth or on auth failure because MySQL will have already
|
||||
// closed the network connection.
|
||||
func (mc *mysqlConn) cleanup() {
|
||||
if !mc.closed.TrySet(true) {
|
||||
return
|
||||
}
|
||||
|
||||
// Makes cleanup idempotent
|
||||
close(mc.closech)
|
||||
if mc.netConn == nil {
|
||||
return
|
||||
}
|
||||
if err := mc.netConn.Close(); err != nil {
|
||||
errLog.Print(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) error() error {
|
||||
if mc.closed.IsSet() {
|
||||
if err := mc.canceled.Value(); err != nil {
|
||||
return err
|
||||
}
|
||||
return ErrInvalidConn
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Internal function to execute commands
|
||||
func (mc *mysqlConn) exec(query string) error {
|
||||
// Send command
|
||||
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
|
||||
return mc.markBadConn(err)
|
||||
}
|
||||
|
||||
// Read Result
|
||||
resLen, err := mc.readResultSetHeaderPacket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resLen > 0 {
|
||||
// columns
|
||||
if err := mc.readUntilEOF(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// rows
|
||||
if err := mc.readUntilEOF(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return mc.discardResults()
|
||||
}
|
||||
|
||||
// Gets the value of the given MySQL System Variable
|
||||
// The returned byte slice is only valid until the next read
|
||||
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
|
||||
// Send command
|
||||
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read Result
|
||||
resLen, err := mc.readResultSetHeaderPacket()
|
||||
if err == nil {
|
||||
rows := new(textRows)
|
||||
rows.mc = mc
|
||||
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
|
||||
|
||||
if resLen > 0 {
|
||||
// Columns
|
||||
if err := mc.readUntilEOF(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
dest := make([]driver.Value, resLen)
|
||||
if err = rows.readRow(dest); err == nil {
|
||||
return dest[0].([]byte), mc.readUntilEOF()
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// finish is called when the query has canceled.
|
||||
func (mc *mysqlConn) cancel(err error) {
|
||||
mc.canceled.Set(err)
|
||||
mc.cleanup()
|
||||
}
|
||||
|
||||
// finish is called when the query has succeeded.
|
||||
func (mc *mysqlConn) finish() {
|
||||
if !mc.watching || mc.finished == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case mc.finished <- struct{}{}:
|
||||
mc.watching = false
|
||||
case <-mc.closech:
|
||||
}
|
||||
}
|
||||
|
||||
// Ping implements driver.Pinger interface
|
||||
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
|
||||
if mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
|
||||
if err = mc.watchCancel(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
defer mc.finish()
|
||||
|
||||
if err = mc.writeCommandPacket(comPing); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return mc.readResultOK()
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) watchCancel(ctx context.Context) error {
|
||||
if mc.watching {
|
||||
// Reach here if canceled,
|
||||
// so the connection is already invalid
|
||||
mc.cleanup()
|
||||
return nil
|
||||
}
|
||||
if ctx.Done() == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mc.watching = true
|
||||
select {
|
||||
default:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
if mc.watcher == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mc.watcher <- ctx
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) startWatcher() {
|
||||
watcher := make(chan context.Context, 1)
|
||||
mc.watcher = watcher
|
||||
finished := make(chan struct{})
|
||||
mc.finished = finished
|
||||
go func() {
|
||||
for {
|
||||
var ctx context.Context
|
||||
select {
|
||||
case ctx = <-watcher:
|
||||
case <-mc.closech:
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
mc.cancel(ctx.Err())
|
||||
case <-finished:
|
||||
case <-mc.closech:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// ResetSession implements driver.SessionResetter.
|
||||
// (From Go 1.10)
|
||||
func (mc *mysqlConn) ResetSession(ctx context.Context) error {
|
||||
if mc.closed.IsSet() {
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,167 @@
|
|||
## Version 1.4 (2018-06-03)
|
||||
|
||||
Changes:
|
||||
|
||||
- Documentation fixes (#530, #535, #567)
|
||||
- Refactoring (#575, #579, #580, #581, #603, #615, #704)
|
||||
- Cache column names (#444)
|
||||
- Sort the DSN parameters in DSNs generated from a config (#637)
|
||||
- Allow native password authentication by default (#644)
|
||||
- Use the default port if it is missing in the DSN (#668)
|
||||
- Removed the `strict` mode (#676)
|
||||
- Do not query `max_allowed_packet` by default (#680)
|
||||
- Dropped support Go 1.6 and lower (#696)
|
||||
- Updated `ConvertValue()` to match the database/sql/driver implementation (#760)
|
||||
- Document the usage of `0000-00-00T00:00:00` as the time.Time zero value (#783)
|
||||
- Improved the compatibility of the authentication system (#807)
|
||||
|
||||
New Features:
|
||||
|
||||
- Multi-Results support (#537)
|
||||
- `rejectReadOnly` DSN option (#604)
|
||||
- `context.Context` support (#608, #612, #627, #761)
|
||||
- Transaction isolation level support (#619, #744)
|
||||
- Read-Only transactions support (#618, #634)
|
||||
- `NewConfig` function which initializes a config with default values (#679)
|
||||
- Implemented the `ColumnType` interfaces (#667, #724)
|
||||
- Support for custom string types in `ConvertValue` (#623)
|
||||
- Implemented `NamedValueChecker`, improving support for uint64 with high bit set (#690, #709, #710)
|
||||
- `caching_sha2_password` authentication plugin support (#794, #800, #801, #802)
|
||||
- Implemented `driver.SessionResetter` (#779)
|
||||
- `sha256_password` authentication plugin support (#808)
|
||||
|
||||
Bugfixes:
|
||||
|
||||
- Use the DSN hostname as TLS default ServerName if `tls=true` (#564, #718)
|
||||
- Fixed LOAD LOCAL DATA INFILE for empty files (#590)
|
||||
- Removed columns definition cache since it sometimes cached invalid data (#592)
|
||||
- Don't mutate registered TLS configs (#600)
|
||||
- Make RegisterTLSConfig concurrency-safe (#613)
|
||||
- Handle missing auth data in the handshake packet correctly (#646)
|
||||
- Do not retry queries when data was written to avoid data corruption (#302, #736)
|
||||
- Cache the connection pointer for error handling before invalidating it (#678)
|
||||
- Fixed imports for appengine/cloudsql (#700)
|
||||
- Fix sending STMT_LONG_DATA for 0 byte data (#734)
|
||||
- Set correct capacity for []bytes read from length-encoded strings (#766)
|
||||
- Make RegisterDial concurrency-safe (#773)
|
||||
|
||||
|
||||
## Version 1.3 (2016-12-01)
|
||||
|
||||
Changes:
|
||||
|
||||
- Go 1.1 is no longer supported
|
||||
- Use decimals fields in MySQL to format time types (#249)
|
||||
- Buffer optimizations (#269)
|
||||
- TLS ServerName defaults to the host (#283)
|
||||
- Refactoring (#400, #410, #437)
|
||||
- Adjusted documentation for second generation CloudSQL (#485)
|
||||
- Documented DSN system var quoting rules (#502)
|
||||
- Made statement.Close() calls idempotent to avoid errors in Go 1.6+ (#512)
|
||||
|
||||
New Features:
|
||||
|
||||
- Enable microsecond resolution on TIME, DATETIME and TIMESTAMP (#249)
|
||||
- Support for returning table alias on Columns() (#289, #359, #382)
|
||||
- Placeholder interpolation, can be actived with the DSN parameter `interpolateParams=true` (#309, #318, #490)
|
||||
- Support for uint64 parameters with high bit set (#332, #345)
|
||||
- Cleartext authentication plugin support (#327)
|
||||
- Exported ParseDSN function and the Config struct (#403, #419, #429)
|
||||
- Read / Write timeouts (#401)
|
||||
- Support for JSON field type (#414)
|
||||
- Support for multi-statements and multi-results (#411, #431)
|
||||
- DSN parameter to set the driver-side max_allowed_packet value manually (#489)
|
||||
- Native password authentication plugin support (#494, #524)
|
||||
|
||||
Bugfixes:
|
||||
|
||||
- Fixed handling of queries without columns and rows (#255)
|
||||
- Fixed a panic when SetKeepAlive() failed (#298)
|
||||
- Handle ERR packets while reading rows (#321)
|
||||
- Fixed reading NULL length-encoded integers in MySQL 5.6+ (#349)
|
||||
- Fixed absolute paths support in LOAD LOCAL DATA INFILE (#356)
|
||||
- Actually zero out bytes in handshake response (#378)
|
||||
- Fixed race condition in registering LOAD DATA INFILE handler (#383)
|
||||
- Fixed tests with MySQL 5.7.9+ (#380)
|
||||
- QueryUnescape TLS config names (#397)
|
||||
- Fixed "broken pipe" error by writing to closed socket (#390)
|
||||
- Fixed LOAD LOCAL DATA INFILE buffering (#424)
|
||||
- Fixed parsing of floats into float64 when placeholders are used (#434)
|
||||
- Fixed DSN tests with Go 1.7+ (#459)
|
||||
- Handle ERR packets while waiting for EOF (#473)
|
||||
- Invalidate connection on error while discarding additional results (#513)
|
||||
- Allow terminating packets of length 0 (#516)
|
||||
|
||||
|
||||
## Version 1.2 (2014-06-03)
|
||||
|
||||
Changes:
|
||||
|
||||
- We switched back to a "rolling release". `go get` installs the current master branch again
|
||||
- Version v1 of the driver will not be maintained anymore. Go 1.0 is no longer supported by this driver
|
||||
- Exported errors to allow easy checking from application code
|
||||
- Enabled TCP Keepalives on TCP connections
|
||||
- Optimized INFILE handling (better buffer size calculation, lazy init, ...)
|
||||
- The DSN parser also checks for a missing separating slash
|
||||
- Faster binary date / datetime to string formatting
|
||||
- Also exported the MySQLWarning type
|
||||
- mysqlConn.Close returns the first error encountered instead of ignoring all errors
|
||||
- writePacket() automatically writes the packet size to the header
|
||||
- readPacket() uses an iterative approach instead of the recursive approach to merge splitted packets
|
||||
|
||||
New Features:
|
||||
|
||||
- `RegisterDial` allows the usage of a custom dial function to establish the network connection
|
||||
- Setting the connection collation is possible with the `collation` DSN parameter. This parameter should be preferred over the `charset` parameter
|
||||
- Logging of critical errors is configurable with `SetLogger`
|
||||
- Google CloudSQL support
|
||||
|
||||
Bugfixes:
|
||||
|
||||
- Allow more than 32 parameters in prepared statements
|
||||
- Various old_password fixes
|
||||
- Fixed TestConcurrent test to pass Go's race detection
|
||||
- Fixed appendLengthEncodedInteger for large numbers
|
||||
- Renamed readLengthEnodedString to readLengthEncodedString and skipLengthEnodedString to skipLengthEncodedString (fixed typo)
|
||||
|
||||
|
||||
## Version 1.1 (2013-11-02)
|
||||
|
||||
Changes:
|
||||
|
||||
- Go-MySQL-Driver now requires Go 1.1
|
||||
- Connections now use the collation `utf8_general_ci` by default. Adding `&charset=UTF8` to the DSN should not be necessary anymore
|
||||
- Made closing rows and connections error tolerant. This allows for example deferring rows.Close() without checking for errors
|
||||
- `[]byte(nil)` is now treated as a NULL value. Before, it was treated like an empty string / `[]byte("")`
|
||||
- DSN parameter values must now be url.QueryEscape'ed. This allows text values to contain special characters, such as '&'.
|
||||
- Use the IO buffer also for writing. This results in zero allocations (by the driver) for most queries
|
||||
- Optimized the buffer for reading
|
||||
- stmt.Query now caches column metadata
|
||||
- New Logo
|
||||
- Changed the copyright header to include all contributors
|
||||
- Improved the LOAD INFILE documentation
|
||||
- The driver struct is now exported to make the driver directly accessible
|
||||
- Refactored the driver tests
|
||||
- Added more benchmarks and moved all to a separate file
|
||||
- Other small refactoring
|
||||
|
||||
New Features:
|
||||
|
||||
- Added *old_passwords* support: Required in some cases, but must be enabled by adding `allowOldPasswords=true` to the DSN since it is insecure
|
||||
- Added a `clientFoundRows` parameter: Return the number of matching rows instead of the number of rows changed on UPDATEs
|
||||
- Added TLS/SSL support: Use a TLS/SSL encrypted connection to the server. Custom TLS configs can be registered and used
|
||||
|
||||
Bugfixes:
|
||||
|
||||
- Fixed MySQL 4.1 support: MySQL 4.1 sends packets with lengths which differ from the specification
|
||||
- Convert to DB timezone when inserting `time.Time`
|
||||
- Splitted packets (more than 16MB) are now merged correctly
|
||||
- Fixed false positive `io.EOF` errors when the data was fully read
|
||||
- Avoid panics on reuse of closed connections
|
||||
- Fixed empty string producing false nil values
|
||||
- Fixed sign byte for positive TIME fields
|
||||
|
||||
|
||||
## Version 1.0 (2013-05-14)
|
||||
|
||||
Initial Release
|
|
@ -0,0 +1,23 @@
|
|||
# Contributing Guidelines
|
||||
|
||||
## Reporting Issues
|
||||
|
||||
Before creating a new Issue, please check first if a similar Issue [already exists](https://github.com/go-sql-driver/mysql/issues?state=open) or was [recently closed](https://github.com/go-sql-driver/mysql/issues?direction=desc&page=1&sort=updated&state=closed).
|
||||
|
||||
## Contributing Code
|
||||
|
||||
By contributing to this project, you share your code under the Mozilla Public License 2, as specified in the LICENSE file.
|
||||
Don't forget to add yourself to the AUTHORS file.
|
||||
|
||||
### Code Review
|
||||
|
||||
Everyone is invited to review and comment on pull requests.
|
||||
If it looks fine to you, comment with "LGTM" (Looks good to me).
|
||||
|
||||
If changes are required, notice the reviewers with "PTAL" (Please take another look) after committing the fixes.
|
||||
|
||||
Before merging the Pull Request, at least one [team member](https://github.com/go-sql-driver?tab=members) must have commented with "LGTM".
|
||||
|
||||
## Development Ideas
|
||||
|
||||
If you are looking for ideas for code contributions, please check our [Development Ideas](https://github.com/go-sql-driver/mysql/wiki/Development-Ideas) Wiki page.
|
|
@ -0,0 +1,19 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// +build appengine
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"google.golang.org/appengine/cloudsql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterDial("cloudsql", cloudsql.Dial)
|
||||
}
|
|
@ -6,7 +6,7 @@
|
|||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package driver
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,319 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"math"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TB testing.B
|
||||
|
||||
func (tb *TB) check(err error) {
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (tb *TB) checkDB(db *sql.DB, err error) *sql.DB {
|
||||
tb.check(err)
|
||||
return db
|
||||
}
|
||||
|
||||
func (tb *TB) checkRows(rows *sql.Rows, err error) *sql.Rows {
|
||||
tb.check(err)
|
||||
return rows
|
||||
}
|
||||
|
||||
func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt {
|
||||
tb.check(err)
|
||||
return stmt
|
||||
}
|
||||
|
||||
func initDB(b *testing.B, queries ...string) *sql.DB {
|
||||
tb := (*TB)(b)
|
||||
db := tb.checkDB(sql.Open("mysql", dsn))
|
||||
for _, query := range queries {
|
||||
if _, err := db.Exec(query); err != nil {
|
||||
b.Fatalf("error on %q: %v", query, err)
|
||||
}
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
const concurrencyLevel = 10
|
||||
|
||||
func BenchmarkQuery(b *testing.B) {
|
||||
tb := (*TB)(b)
|
||||
b.StopTimer()
|
||||
b.ReportAllocs()
|
||||
db := initDB(b,
|
||||
"DROP TABLE IF EXISTS foo",
|
||||
"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
|
||||
`INSERT INTO foo VALUES (1, "one")`,
|
||||
`INSERT INTO foo VALUES (2, "two")`,
|
||||
)
|
||||
db.SetMaxIdleConns(concurrencyLevel)
|
||||
defer db.Close()
|
||||
|
||||
stmt := tb.checkStmt(db.Prepare("SELECT val FROM foo WHERE id=?"))
|
||||
defer stmt.Close()
|
||||
|
||||
remain := int64(b.N)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(concurrencyLevel)
|
||||
defer wg.Wait()
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < concurrencyLevel; i++ {
|
||||
go func() {
|
||||
for {
|
||||
if atomic.AddInt64(&remain, -1) < 0 {
|
||||
wg.Done()
|
||||
return
|
||||
}
|
||||
|
||||
var got string
|
||||
tb.check(stmt.QueryRow(1).Scan(&got))
|
||||
if got != "one" {
|
||||
b.Errorf("query = %q; want one", got)
|
||||
wg.Done()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExec(b *testing.B) {
|
||||
tb := (*TB)(b)
|
||||
b.StopTimer()
|
||||
b.ReportAllocs()
|
||||
db := tb.checkDB(sql.Open("mysql", dsn))
|
||||
db.SetMaxIdleConns(concurrencyLevel)
|
||||
defer db.Close()
|
||||
|
||||
stmt := tb.checkStmt(db.Prepare("DO 1"))
|
||||
defer stmt.Close()
|
||||
|
||||
remain := int64(b.N)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(concurrencyLevel)
|
||||
defer wg.Wait()
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < concurrencyLevel; i++ {
|
||||
go func() {
|
||||
for {
|
||||
if atomic.AddInt64(&remain, -1) < 0 {
|
||||
wg.Done()
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
b.Fatal(err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// data, but no db writes
|
||||
var roundtripSample []byte
|
||||
|
||||
func initRoundtripBenchmarks() ([]byte, int, int) {
|
||||
if roundtripSample == nil {
|
||||
roundtripSample = []byte(strings.Repeat("0123456789abcdef", 1024*1024))
|
||||
}
|
||||
return roundtripSample, 16, len(roundtripSample)
|
||||
}
|
||||
|
||||
func BenchmarkRoundtripTxt(b *testing.B) {
|
||||
b.StopTimer()
|
||||
sample, min, max := initRoundtripBenchmarks()
|
||||
sampleString := string(sample)
|
||||
b.ReportAllocs()
|
||||
tb := (*TB)(b)
|
||||
db := tb.checkDB(sql.Open("mysql", dsn))
|
||||
defer db.Close()
|
||||
b.StartTimer()
|
||||
var result string
|
||||
for i := 0; i < b.N; i++ {
|
||||
length := min + i
|
||||
if length > max {
|
||||
length = max
|
||||
}
|
||||
test := sampleString[0:length]
|
||||
rows := tb.checkRows(db.Query(`SELECT "` + test + `"`))
|
||||
if !rows.Next() {
|
||||
rows.Close()
|
||||
b.Fatalf("crashed")
|
||||
}
|
||||
err := rows.Scan(&result)
|
||||
if err != nil {
|
||||
rows.Close()
|
||||
b.Fatalf("crashed")
|
||||
}
|
||||
if result != test {
|
||||
rows.Close()
|
||||
b.Errorf("mismatch")
|
||||
}
|
||||
rows.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRoundtripBin(b *testing.B) {
|
||||
b.StopTimer()
|
||||
sample, min, max := initRoundtripBenchmarks()
|
||||
b.ReportAllocs()
|
||||
tb := (*TB)(b)
|
||||
db := tb.checkDB(sql.Open("mysql", dsn))
|
||||
defer db.Close()
|
||||
stmt := tb.checkStmt(db.Prepare("SELECT ?"))
|
||||
defer stmt.Close()
|
||||
b.StartTimer()
|
||||
var result sql.RawBytes
|
||||
for i := 0; i < b.N; i++ {
|
||||
length := min + i
|
||||
if length > max {
|
||||
length = max
|
||||
}
|
||||
test := sample[0:length]
|
||||
rows := tb.checkRows(stmt.Query(test))
|
||||
if !rows.Next() {
|
||||
rows.Close()
|
||||
b.Fatalf("crashed")
|
||||
}
|
||||
err := rows.Scan(&result)
|
||||
if err != nil {
|
||||
rows.Close()
|
||||
b.Fatalf("crashed")
|
||||
}
|
||||
if !bytes.Equal(result, test) {
|
||||
rows.Close()
|
||||
b.Errorf("mismatch")
|
||||
}
|
||||
rows.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkInterpolation(b *testing.B) {
|
||||
mc := &mysqlConn{
|
||||
cfg: &Config{
|
||||
InterpolateParams: true,
|
||||
Loc: time.UTC,
|
||||
},
|
||||
maxAllowedPacket: maxPacketSize,
|
||||
maxWriteSize: maxPacketSize - 1,
|
||||
buf: newBuffer(nil),
|
||||
}
|
||||
|
||||
args := []driver.Value{
|
||||
int64(42424242),
|
||||
float64(math.Pi),
|
||||
false,
|
||||
time.Unix(1423411542, 807015000),
|
||||
[]byte("bytes containing special chars ' \" \a \x00"),
|
||||
"string containing special chars ' \" \a \x00",
|
||||
}
|
||||
q := "SELECT ?, ?, ?, ?, ?, ?"
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := mc.interpolateParams(q, args)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0))
|
||||
|
||||
tb := (*TB)(b)
|
||||
stmt := tb.checkStmt(db.PrepareContext(ctx, "SELECT val FROM foo WHERE id=?"))
|
||||
defer stmt.Close()
|
||||
|
||||
b.SetParallelism(p)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
var got string
|
||||
for pb.Next() {
|
||||
tb.check(stmt.QueryRow(1).Scan(&got))
|
||||
if got != "one" {
|
||||
b.Fatalf("query = %q; want one", got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkQueryContext(b *testing.B) {
|
||||
db := initDB(b,
|
||||
"DROP TABLE IF EXISTS foo",
|
||||
"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
|
||||
`INSERT INTO foo VALUES (1, "one")`,
|
||||
`INSERT INTO foo VALUES (2, "two")`,
|
||||
)
|
||||
defer db.Close()
|
||||
for _, p := range []int{1, 2, 3, 4} {
|
||||
b.Run(fmt.Sprintf("%d", p), func(b *testing.B) {
|
||||
benchmarkQueryContext(b, db, p)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkExecContext(b *testing.B, db *sql.DB, p int) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0))
|
||||
|
||||
tb := (*TB)(b)
|
||||
stmt := tb.checkStmt(db.PrepareContext(ctx, "DO 1"))
|
||||
defer stmt.Close()
|
||||
|
||||
b.SetParallelism(p)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkExecContext(b *testing.B) {
|
||||
db := initDB(b,
|
||||
"DROP TABLE IF EXISTS foo",
|
||||
"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
|
||||
`INSERT INTO foo VALUES (1, "one")`,
|
||||
`INSERT INTO foo VALUES (2, "two")`,
|
||||
)
|
||||
defer db.Close()
|
||||
for _, p := range []int{1, 2, 3, 4} {
|
||||
b.Run(fmt.Sprintf("%d", p), func(b *testing.B) {
|
||||
benchmarkQueryContext(b, db, p)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,35 +1,45 @@
|
|||
package driver
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ExtendConn creates an extended connection.
|
||||
func ExtendConn(conn driver.Conn) (*ExtendedConn, error) {
|
||||
if conn == nil {
|
||||
return nil, errors.New("Connection is nil")
|
||||
}
|
||||
mc, ok := conn.(*mysqlConn)
|
||||
if !ok || mc == nil {
|
||||
return nil, errors.New("Invalid connection")
|
||||
}
|
||||
|
||||
return &ExtendedConn{mc}, nil
|
||||
}
|
||||
|
||||
// ExtendedConn provides access to internal packet functions.
|
||||
type ExtendedConn struct {
|
||||
*mysqlConn
|
||||
}
|
||||
|
||||
// NewExtendedConnection creates a new connection extended with packet access
|
||||
// methods.
|
||||
func NewExtendedConnection(dsn string) (*ExtendedConn, error) {
|
||||
conn, err := newConnection(dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ExtendedConn{conn}, nil
|
||||
// Close ...
|
||||
func (c *ExtendedConn) Close() error {
|
||||
c.buf.length = 0
|
||||
return c.mysqlConn.Close()
|
||||
}
|
||||
|
||||
// Exec executes a query.
|
||||
// Exec ...
|
||||
func (c *ExtendedConn) Exec(query string) error {
|
||||
return c.exec(query)
|
||||
}
|
||||
|
||||
// ReadPacket reads a packet from a given connection. If given context has a
|
||||
// deadline it would be used as read timeout.
|
||||
// ReadPacket reads a packet from a given connection.
|
||||
func (c *ExtendedConn) ReadPacket(ctx context.Context) ([]byte, error) {
|
||||
if dl, ok := ctx.Deadline(); ok {
|
||||
dur := time.Until(dl)
|
||||
dur := dl.Sub(time.Now())
|
||||
if dur < 0 {
|
||||
return nil, context.DeadlineExceeded
|
||||
}
|
||||
|
@ -46,27 +56,17 @@ func (c *ExtendedConn) WritePacket(p []byte) error {
|
|||
return c.writePacket(p)
|
||||
}
|
||||
|
||||
// ReadResultOK returns an error if packet is not an OK_Packet.
|
||||
// Spec: https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
|
||||
// ReadResultOK ...
|
||||
func (c *ExtendedConn) ReadResultOK() error {
|
||||
return c.readResultOK()
|
||||
}
|
||||
|
||||
// HandleErrorPacket reads error message from ERR_Packet.
|
||||
// Spec: https://dev.mysql.com/doc/internals/en/packet-ERR_Packet.html
|
||||
// HandleErrorPacket ...
|
||||
func (c *ExtendedConn) HandleErrorPacket(data []byte) error {
|
||||
return c.handleErrorPacket(data)
|
||||
}
|
||||
|
||||
// ResetSequence resets command sequence counter.
|
||||
// ResetSequence ...
|
||||
func (c *ExtendedConn) ResetSequence() {
|
||||
c.sequence = 0
|
||||
}
|
||||
|
||||
// Close the connection.
|
||||
func (c *ExtendedConn) Close() error {
|
||||
// Reset buffer length parameter
|
||||
// If it's not zero bad stuff happens
|
||||
c.buf.length = 0
|
||||
return c.mysqlConn.Close()
|
||||
}
|
|
@ -6,7 +6,7 @@
|
|||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package driver
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"io"
|
|
@ -6,7 +6,7 @@
|
|||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package driver
|
||||
package mysql
|
||||
|
||||
const defaultCollation = "utf8_general_ci"
|
||||
const binaryCollation = "binary"
|
|
@ -0,0 +1,654 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// a copy of context.Context for Go 1.7 and earlier
|
||||
type mysqlContext interface {
|
||||
Done() <-chan struct{}
|
||||
Err() error
|
||||
|
||||
// defined in context.Context, but not used in this driver:
|
||||
// Deadline() (deadline time.Time, ok bool)
|
||||
// Value(key interface{}) interface{}
|
||||
}
|
||||
|
||||
type mysqlConn struct {
|
||||
buf buffer
|
||||
netConn net.Conn
|
||||
affectedRows uint64
|
||||
insertId uint64
|
||||
cfg *Config
|
||||
maxAllowedPacket int
|
||||
maxWriteSize int
|
||||
writeTimeout time.Duration
|
||||
flags clientFlag
|
||||
status statusFlag
|
||||
sequence uint8
|
||||
parseTime bool
|
||||
|
||||
// for context support (Go 1.8+)
|
||||
watching bool
|
||||
watcher chan<- mysqlContext
|
||||
closech chan struct{}
|
||||
finished chan<- struct{}
|
||||
canceled atomicError // set non-nil if conn is canceled
|
||||
closed atomicBool // set when conn is closed, before closech is closed
|
||||
}
|
||||
|
||||
// Handles parameters set in DSN after the connection is established
|
||||
func (mc *mysqlConn) handleParams() (err error) {
|
||||
for param, val := range mc.cfg.Params {
|
||||
switch param {
|
||||
// Charset
|
||||
case "charset":
|
||||
charsets := strings.Split(val, ",")
|
||||
for i := range charsets {
|
||||
// ignore errors here - a charset may not exist
|
||||
err = mc.exec("SET NAMES " + charsets[i])
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// System Vars
|
||||
default:
|
||||
err = mc.exec("SET " + param + "=" + val + "")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) markBadConn(err error) error {
|
||||
if mc == nil {
|
||||
return err
|
||||
}
|
||||
if err != errBadConnNoWrite {
|
||||
return err
|
||||
}
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) Begin() (driver.Tx, error) {
|
||||
return mc.begin(false)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
|
||||
if mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
var q string
|
||||
if readOnly {
|
||||
q = "START TRANSACTION READ ONLY"
|
||||
} else {
|
||||
q = "START TRANSACTION"
|
||||
}
|
||||
err := mc.exec(q)
|
||||
if err == nil {
|
||||
return &mysqlTx{mc}, err
|
||||
}
|
||||
return nil, mc.markBadConn(err)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) Close() (err error) {
|
||||
// Makes Close idempotent
|
||||
if !mc.closed.IsSet() {
|
||||
err = mc.writeCommandPacket(comQuit)
|
||||
}
|
||||
|
||||
mc.cleanup()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Closes the network connection and unsets internal variables. Do not call this
|
||||
// function after successfully authentication, call Close instead. This function
|
||||
// is called before auth or on auth failure because MySQL will have already
|
||||
// closed the network connection.
|
||||
func (mc *mysqlConn) cleanup() {
|
||||
if !mc.closed.TrySet(true) {
|
||||
return
|
||||
}
|
||||
|
||||
// Makes cleanup idempotent
|
||||
close(mc.closech)
|
||||
if mc.netConn == nil {
|
||||
return
|
||||
}
|
||||
if err := mc.netConn.Close(); err != nil {
|
||||
errLog.Print(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) error() error {
|
||||
if mc.closed.IsSet() {
|
||||
if err := mc.canceled.Value(); err != nil {
|
||||
return err
|
||||
}
|
||||
return ErrInvalidConn
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
|
||||
if mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
// Send command
|
||||
err := mc.writeCommandPacketStr(comStmtPrepare, query)
|
||||
if err != nil {
|
||||
return nil, mc.markBadConn(err)
|
||||
}
|
||||
|
||||
stmt := &mysqlStmt{
|
||||
mc: mc,
|
||||
}
|
||||
|
||||
// Read Result
|
||||
columnCount, err := stmt.readPrepareResultPacket()
|
||||
if err == nil {
|
||||
if stmt.paramCount > 0 {
|
||||
if err = mc.readUntilEOF(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if columnCount > 0 {
|
||||
err = mc.readUntilEOF()
|
||||
}
|
||||
}
|
||||
|
||||
return stmt, err
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
|
||||
// Number of ? should be same to len(args)
|
||||
if strings.Count(query, "?") != len(args) {
|
||||
return "", driver.ErrSkip
|
||||
}
|
||||
|
||||
buf := mc.buf.takeCompleteBuffer()
|
||||
if buf == nil {
|
||||
// can not take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
return "", ErrInvalidConn
|
||||
}
|
||||
buf = buf[:0]
|
||||
argPos := 0
|
||||
|
||||
for i := 0; i < len(query); i++ {
|
||||
q := strings.IndexByte(query[i:], '?')
|
||||
if q == -1 {
|
||||
buf = append(buf, query[i:]...)
|
||||
break
|
||||
}
|
||||
buf = append(buf, query[i:i+q]...)
|
||||
i += q
|
||||
|
||||
arg := args[argPos]
|
||||
argPos++
|
||||
|
||||
if arg == nil {
|
||||
buf = append(buf, "NULL"...)
|
||||
continue
|
||||
}
|
||||
|
||||
switch v := arg.(type) {
|
||||
case int64:
|
||||
buf = strconv.AppendInt(buf, v, 10)
|
||||
case float64:
|
||||
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
|
||||
case bool:
|
||||
if v {
|
||||
buf = append(buf, '1')
|
||||
} else {
|
||||
buf = append(buf, '0')
|
||||
}
|
||||
case time.Time:
|
||||
if v.IsZero() {
|
||||
buf = append(buf, "'0000-00-00'"...)
|
||||
} else {
|
||||
v := v.In(mc.cfg.Loc)
|
||||
v = v.Add(time.Nanosecond * 500) // To round under microsecond
|
||||
year := v.Year()
|
||||
year100 := year / 100
|
||||
year1 := year % 100
|
||||
month := v.Month()
|
||||
day := v.Day()
|
||||
hour := v.Hour()
|
||||
minute := v.Minute()
|
||||
second := v.Second()
|
||||
micro := v.Nanosecond() / 1000
|
||||
|
||||
buf = append(buf, []byte{
|
||||
'\'',
|
||||
digits10[year100], digits01[year100],
|
||||
digits10[year1], digits01[year1],
|
||||
'-',
|
||||
digits10[month], digits01[month],
|
||||
'-',
|
||||
digits10[day], digits01[day],
|
||||
' ',
|
||||
digits10[hour], digits01[hour],
|
||||
':',
|
||||
digits10[minute], digits01[minute],
|
||||
':',
|
||||
digits10[second], digits01[second],
|
||||
}...)
|
||||
|
||||
if micro != 0 {
|
||||
micro10000 := micro / 10000
|
||||
micro100 := micro / 100 % 100
|
||||
micro1 := micro % 100
|
||||
buf = append(buf, []byte{
|
||||
'.',
|
||||
digits10[micro10000], digits01[micro10000],
|
||||
digits10[micro100], digits01[micro100],
|
||||
digits10[micro1], digits01[micro1],
|
||||
}...)
|
||||
}
|
||||
buf = append(buf, '\'')
|
||||
}
|
||||
case []byte:
|
||||
if v == nil {
|
||||
buf = append(buf, "NULL"...)
|
||||
} else {
|
||||
buf = append(buf, "_binary'"...)
|
||||
if mc.status&statusNoBackslashEscapes == 0 {
|
||||
buf = escapeBytesBackslash(buf, v)
|
||||
} else {
|
||||
buf = escapeBytesQuotes(buf, v)
|
||||
}
|
||||
buf = append(buf, '\'')
|
||||
}
|
||||
case string:
|
||||
buf = append(buf, '\'')
|
||||
if mc.status&statusNoBackslashEscapes == 0 {
|
||||
buf = escapeStringBackslash(buf, v)
|
||||
} else {
|
||||
buf = escapeStringQuotes(buf, v)
|
||||
}
|
||||
buf = append(buf, '\'')
|
||||
default:
|
||||
return "", driver.ErrSkip
|
||||
}
|
||||
|
||||
if len(buf)+4 > mc.maxAllowedPacket {
|
||||
return "", driver.ErrSkip
|
||||
}
|
||||
}
|
||||
if argPos != len(args) {
|
||||
return "", driver.ErrSkip
|
||||
}
|
||||
return string(buf), nil
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
|
||||
if mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
if len(args) != 0 {
|
||||
if !mc.cfg.InterpolateParams {
|
||||
return nil, driver.ErrSkip
|
||||
}
|
||||
// try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
|
||||
prepared, err := mc.interpolateParams(query, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
query = prepared
|
||||
}
|
||||
mc.affectedRows = 0
|
||||
mc.insertId = 0
|
||||
|
||||
err := mc.exec(query)
|
||||
if err == nil {
|
||||
return &mysqlResult{
|
||||
affectedRows: int64(mc.affectedRows),
|
||||
insertId: int64(mc.insertId),
|
||||
}, err
|
||||
}
|
||||
return nil, mc.markBadConn(err)
|
||||
}
|
||||
|
||||
// Internal function to execute commands
|
||||
func (mc *mysqlConn) exec(query string) error {
|
||||
// Send command
|
||||
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
|
||||
return mc.markBadConn(err)
|
||||
}
|
||||
|
||||
// Read Result
|
||||
resLen, err := mc.readResultSetHeaderPacket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resLen > 0 {
|
||||
// columns
|
||||
if err := mc.readUntilEOF(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// rows
|
||||
if err := mc.readUntilEOF(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return mc.discardResults()
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
|
||||
return mc.query(query, args)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
|
||||
if mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
if len(args) != 0 {
|
||||
if !mc.cfg.InterpolateParams {
|
||||
return nil, driver.ErrSkip
|
||||
}
|
||||
// try client-side prepare to reduce roundtrip
|
||||
prepared, err := mc.interpolateParams(query, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
query = prepared
|
||||
}
|
||||
// Send command
|
||||
err := mc.writeCommandPacketStr(comQuery, query)
|
||||
if err == nil {
|
||||
// Read Result
|
||||
var resLen int
|
||||
resLen, err = mc.readResultSetHeaderPacket()
|
||||
if err == nil {
|
||||
rows := new(textRows)
|
||||
rows.mc = mc
|
||||
|
||||
if resLen == 0 {
|
||||
rows.rs.done = true
|
||||
|
||||
switch err := rows.NextResultSet(); err {
|
||||
case nil, io.EOF:
|
||||
return rows, nil
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Columns
|
||||
rows.rs.columns, err = mc.readColumns(resLen)
|
||||
return rows, err
|
||||
}
|
||||
}
|
||||
return nil, mc.markBadConn(err)
|
||||
}
|
||||
|
||||
// Gets the value of the given MySQL System Variable
|
||||
// The returned byte slice is only valid until the next read
|
||||
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
|
||||
// Send command
|
||||
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read Result
|
||||
resLen, err := mc.readResultSetHeaderPacket()
|
||||
if err == nil {
|
||||
rows := new(textRows)
|
||||
rows.mc = mc
|
||||
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
|
||||
|
||||
if resLen > 0 {
|
||||
// Columns
|
||||
if err := mc.readUntilEOF(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
dest := make([]driver.Value, resLen)
|
||||
if err = rows.readRow(dest); err == nil {
|
||||
return dest[0].([]byte), mc.readUntilEOF()
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// finish is called when the query has canceled.
|
||||
func (mc *mysqlConn) cancel(err error) {
|
||||
mc.canceled.Set(err)
|
||||
mc.cleanup()
|
||||
}
|
||||
|
||||
// finish is called when the query has succeeded.
|
||||
func (mc *mysqlConn) finish() {
|
||||
if !mc.watching || mc.finished == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case mc.finished <- struct{}{}:
|
||||
mc.watching = false
|
||||
case <-mc.closech:
|
||||
}
|
||||
}
|
||||
|
||||
// Ping implements driver.Pinger interface
|
||||
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
|
||||
if mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
|
||||
if err = mc.watchCancel(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
defer mc.finish()
|
||||
|
||||
if err = mc.writeCommandPacket(comPing); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return mc.readResultOK()
|
||||
}
|
||||
|
||||
// BeginTx implements driver.ConnBeginTx interface
|
||||
func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||
if err := mc.watchCancel(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer mc.finish()
|
||||
|
||||
if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
|
||||
level, err := mapIsolationLevel(opts.Isolation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return mc.begin(opts.ReadOnly)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
||||
dargs, err := namedValueToValue(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := mc.watchCancel(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := mc.query(query, dargs)
|
||||
if err != nil {
|
||||
mc.finish()
|
||||
return nil, err
|
||||
}
|
||||
rows.finish = mc.finish
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
dargs, err := namedValueToValue(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := mc.watchCancel(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer mc.finish()
|
||||
|
||||
return mc.Exec(query, dargs)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
||||
if err := mc.watchCancel(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stmt, err := mc.Prepare(query)
|
||||
mc.finish()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
select {
|
||||
default:
|
||||
case <-ctx.Done():
|
||||
stmt.Close()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
return stmt, nil
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
dargs, err := namedValueToValue(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := stmt.mc.watchCancel(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := stmt.query(dargs)
|
||||
if err != nil {
|
||||
stmt.mc.finish()
|
||||
return nil, err
|
||||
}
|
||||
rows.finish = stmt.mc.finish
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
dargs, err := namedValueToValue(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := stmt.mc.watchCancel(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.mc.finish()
|
||||
|
||||
return stmt.Exec(dargs)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) watchCancel(ctx context.Context) error {
|
||||
if mc.watching {
|
||||
// Reach here if canceled,
|
||||
// so the connection is already invalid
|
||||
mc.cleanup()
|
||||
return nil
|
||||
}
|
||||
if ctx.Done() == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mc.watching = true
|
||||
select {
|
||||
default:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
if mc.watcher == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mc.watcher <- ctx
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) startWatcher() {
|
||||
watcher := make(chan mysqlContext, 1)
|
||||
mc.watcher = watcher
|
||||
finished := make(chan struct{})
|
||||
mc.finished = finished
|
||||
go func() {
|
||||
for {
|
||||
var ctx mysqlContext
|
||||
select {
|
||||
case ctx = <-watcher:
|
||||
case <-mc.closech:
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
mc.cancel(ctx.Err())
|
||||
case <-finished:
|
||||
case <-mc.closech:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
|
||||
nv.Value, err = converter{}.ConvertValue(nv.Value)
|
||||
return
|
||||
}
|
||||
|
||||
// ResetSession implements driver.SessionResetter.
|
||||
// (From Go 1.10)
|
||||
func (mc *mysqlConn) ResetSession(ctx context.Context) error {
|
||||
if mc.closed.IsSet() {
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,81 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInterpolateParams(t *testing.T) {
|
||||
mc := &mysqlConn{
|
||||
buf: newBuffer(nil),
|
||||
maxAllowedPacket: maxPacketSize,
|
||||
cfg: &Config{
|
||||
InterpolateParams: true,
|
||||
},
|
||||
}
|
||||
|
||||
q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"})
|
||||
if err != nil {
|
||||
t.Errorf("Expected err=nil, got %#v", err)
|
||||
return
|
||||
}
|
||||
expected := `SELECT 42+'gopher'`
|
||||
if q != expected {
|
||||
t.Errorf("Expected: %q\nGot: %q", expected, q)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
|
||||
mc := &mysqlConn{
|
||||
buf: newBuffer(nil),
|
||||
maxAllowedPacket: maxPacketSize,
|
||||
cfg: &Config{
|
||||
InterpolateParams: true,
|
||||
},
|
||||
}
|
||||
|
||||
q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)})
|
||||
if err != driver.ErrSkip {
|
||||
t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
|
||||
}
|
||||
}
|
||||
|
||||
// We don't support placeholder in string literal for now.
|
||||
// https://github.com/go-sql-driver/mysql/pull/490
|
||||
func TestInterpolateParamsPlaceholderInString(t *testing.T) {
|
||||
mc := &mysqlConn{
|
||||
buf: newBuffer(nil),
|
||||
maxAllowedPacket: maxPacketSize,
|
||||
cfg: &Config{
|
||||
InterpolateParams: true,
|
||||
},
|
||||
}
|
||||
|
||||
q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
|
||||
// When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`
|
||||
if err != driver.ErrSkip {
|
||||
t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckNamedValue(t *testing.T) {
|
||||
value := driver.NamedValue{Value: ^uint64(0)}
|
||||
x := &mysqlConn{}
|
||||
err := x.CheckNamedValue(&value)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("uint64 high-bit not convertible", err)
|
||||
}
|
||||
|
||||
if value.Value != "18446744073709551615" {
|
||||
t.Fatalf("uint64 high-bit not converted, got %#v %T", value.Value, value.Value)
|
||||
}
|
||||
}
|
|
@ -6,7 +6,7 @@
|
|||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package driver
|
||||
package mysql
|
||||
|
||||
const (
|
||||
defaultAuthPlugin = "mysql_native_password"
|
|
@ -0,0 +1,157 @@
|
|||
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// Package mysql provides a MySQL driver for Go's database/sql package.
|
||||
//
|
||||
// The driver should be used via the database/sql package:
|
||||
//
|
||||
// import "database/sql"
|
||||
// import _ "github.com/go-sql-driver/mysql"
|
||||
//
|
||||
// db, err := sql.Open("mysql", "user:password@/dbname")
|
||||
//
|
||||
// See https://github.com/go-sql-driver/mysql#usage for details
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// MySQLDriver is exported to make the driver directly accessible.
|
||||
// In general the driver is used via the database/sql package.
|
||||
type MySQLDriver struct{}
|
||||
|
||||
// DialFunc is a function which can be used to establish the network connection.
|
||||
// Custom dial functions must be registered with RegisterDial
|
||||
type DialFunc func(addr string) (net.Conn, error)
|
||||
|
||||
var (
|
||||
dialsLock sync.RWMutex
|
||||
dials map[string]DialFunc
|
||||
)
|
||||
|
||||
// RegisterDial registers a custom dial function. It can then be used by the
|
||||
// network address mynet(addr), where mynet is the registered new network.
|
||||
// addr is passed as a parameter to the dial function.
|
||||
func RegisterDial(net string, dial DialFunc) {
|
||||
dialsLock.Lock()
|
||||
defer dialsLock.Unlock()
|
||||
if dials == nil {
|
||||
dials = make(map[string]DialFunc)
|
||||
}
|
||||
dials[net] = dial
|
||||
}
|
||||
|
||||
// Open new Connection.
|
||||
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
|
||||
// the DSN string is formated
|
||||
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
|
||||
var err error
|
||||
|
||||
// New mysqlConn
|
||||
mc := &mysqlConn{
|
||||
maxAllowedPacket: maxPacketSize,
|
||||
maxWriteSize: maxPacketSize - 1,
|
||||
closech: make(chan struct{}),
|
||||
}
|
||||
mc.cfg, err = ParseDSN(dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mc.parseTime = mc.cfg.ParseTime
|
||||
|
||||
// Connect to Server
|
||||
dialsLock.RLock()
|
||||
dial, ok := dials[mc.cfg.Net]
|
||||
dialsLock.RUnlock()
|
||||
if ok {
|
||||
mc.netConn, err = dial(mc.cfg.Addr)
|
||||
} else {
|
||||
nd := net.Dialer{Timeout: mc.cfg.Timeout}
|
||||
mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Enable TCP Keepalives on TCP connections
|
||||
if tc, ok := mc.netConn.(*net.TCPConn); ok {
|
||||
if err := tc.SetKeepAlive(true); err != nil {
|
||||
// Don't send COM_QUIT before handshake.
|
||||
mc.netConn.Close()
|
||||
mc.netConn = nil
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Call startWatcher for context support (From Go 1.8)
|
||||
mc.startWatcher()
|
||||
|
||||
mc.buf = newBuffer(mc.netConn)
|
||||
|
||||
// Set I/O timeouts
|
||||
mc.buf.timeout = mc.cfg.ReadTimeout
|
||||
mc.writeTimeout = mc.cfg.WriteTimeout
|
||||
|
||||
// Reading Handshake Initialization Packet
|
||||
authData, plugin, err := mc.readHandshakePacket()
|
||||
if err != nil {
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
// try the default auth plugin, if using the requested plugin failed
|
||||
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
|
||||
plugin = defaultAuthPlugin
|
||||
authResp, addNUL, err = mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil {
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Handle response to auth packet, switch methods if possible
|
||||
if err = mc.handleAuthResult(authData, plugin); err != nil {
|
||||
// Authentication failed and MySQL has already closed the connection
|
||||
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
|
||||
// Do not send COM_QUIT, just cleanup and return the error.
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if mc.cfg.MaxAllowedPacket > 0 {
|
||||
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
|
||||
} else {
|
||||
// Get max allowed packet size
|
||||
maxap, err := mc.getSystemVar("max_allowed_packet")
|
||||
if err != nil {
|
||||
mc.Close()
|
||||
return nil, err
|
||||
}
|
||||
mc.maxAllowedPacket = stringToInt(maxap) - 1
|
||||
}
|
||||
if mc.maxAllowedPacket < maxPacketSize {
|
||||
mc.maxWriteSize = mc.maxAllowedPacket
|
||||
}
|
||||
|
||||
// Handle DSN Params
|
||||
err = mc.handleParams()
|
||||
if err != nil {
|
||||
mc.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mc, nil
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -6,7 +6,7 @@
|
|||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package driver
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
@ -463,6 +463,7 @@ func parseDSNParams(cfg *Config, params string) (err error) {
|
|||
// Collation
|
||||
case "collation":
|
||||
cfg.Collation = value
|
||||
break
|
||||
|
||||
case "columnsWithAlias":
|
||||
var isBool bool
|
|
@ -0,0 +1,331 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var testDSNs = []struct {
|
||||
in string
|
||||
out *Config
|
||||
}{{
|
||||
"username:password@protocol(address)/dbname?param=value",
|
||||
&Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
|
||||
}, {
|
||||
"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true",
|
||||
&Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, ColumnsWithAlias: true},
|
||||
}, {
|
||||
"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true&multiStatements=true",
|
||||
&Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, ColumnsWithAlias: true, MultiStatements: true},
|
||||
}, {
|
||||
"user@unix(/path/to/socket)/dbname?charset=utf8",
|
||||
&Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
|
||||
}, {
|
||||
"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true",
|
||||
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "true"},
|
||||
}, {
|
||||
"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify",
|
||||
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "skip-verify"},
|
||||
}, {
|
||||
"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216",
|
||||
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true, MaxAllowedPacket: 16777216},
|
||||
}, {
|
||||
"user:password@/dbname?allowNativePasswords=false&maxAllowedPacket=0",
|
||||
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false},
|
||||
}, {
|
||||
"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local",
|
||||
&Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
|
||||
}, {
|
||||
"/dbname",
|
||||
&Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
|
||||
}, {
|
||||
"@/",
|
||||
&Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
|
||||
}, {
|
||||
"/",
|
||||
&Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
|
||||
}, {
|
||||
"",
|
||||
&Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
|
||||
}, {
|
||||
"user:p@/ssword@/",
|
||||
&Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
|
||||
}, {
|
||||
"unix/?arg=%2Fsome%2Fpath.ext",
|
||||
&Config{Net: "unix", Addr: "/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
|
||||
}, {
|
||||
"tcp(127.0.0.1)/dbname",
|
||||
&Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
|
||||
}, {
|
||||
"tcp(de:ad:be:ef::ca:fe)/dbname",
|
||||
&Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
|
||||
},
|
||||
}
|
||||
|
||||
func TestDSNParser(t *testing.T) {
|
||||
for i, tst := range testDSNs {
|
||||
cfg, err := ParseDSN(tst.in)
|
||||
if err != nil {
|
||||
t.Error(err.Error())
|
||||
}
|
||||
|
||||
// pointer not static
|
||||
cfg.tls = nil
|
||||
|
||||
if !reflect.DeepEqual(cfg, tst.out) {
|
||||
t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDSNParserInvalid(t *testing.T) {
|
||||
var invalidDSNs = []string{
|
||||
"@net(addr/", // no closing brace
|
||||
"@tcp(/", // no closing brace
|
||||
"tcp(/", // no closing brace
|
||||
"(/", // no closing brace
|
||||
"net(addr)//", // unescaped
|
||||
"User:pass@tcp(1.2.3.4:3306)", // no trailing slash
|
||||
"net()/", // unknown default addr
|
||||
//"/dbname?arg=/some/unescaped/path",
|
||||
}
|
||||
|
||||
for i, tst := range invalidDSNs {
|
||||
if _, err := ParseDSN(tst); err == nil {
|
||||
t.Errorf("invalid DSN #%d. (%s) didn't error!", i, tst)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDSNReformat(t *testing.T) {
|
||||
for i, tst := range testDSNs {
|
||||
dsn1 := tst.in
|
||||
cfg1, err := ParseDSN(dsn1)
|
||||
if err != nil {
|
||||
t.Error(err.Error())
|
||||
continue
|
||||
}
|
||||
cfg1.tls = nil // pointer not static
|
||||
res1 := fmt.Sprintf("%+v", cfg1)
|
||||
|
||||
dsn2 := cfg1.FormatDSN()
|
||||
cfg2, err := ParseDSN(dsn2)
|
||||
if err != nil {
|
||||
t.Error(err.Error())
|
||||
continue
|
||||
}
|
||||
cfg2.tls = nil // pointer not static
|
||||
res2 := fmt.Sprintf("%+v", cfg2)
|
||||
|
||||
if res1 != res2 {
|
||||
t.Errorf("%d. %q does not match %q", i, res2, res1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDSNServerPubKey(t *testing.T) {
|
||||
baseDSN := "User:password@tcp(localhost:5555)/dbname?serverPubKey="
|
||||
|
||||
RegisterServerPubKey("testKey", testPubKeyRSA)
|
||||
defer DeregisterServerPubKey("testKey")
|
||||
|
||||
tst := baseDSN + "testKey"
|
||||
cfg, err := ParseDSN(tst)
|
||||
if err != nil {
|
||||
t.Error(err.Error())
|
||||
}
|
||||
|
||||
if cfg.ServerPubKey != "testKey" {
|
||||
t.Errorf("unexpected cfg.ServerPubKey value: %v", cfg.ServerPubKey)
|
||||
}
|
||||
if cfg.pubKey != testPubKeyRSA {
|
||||
t.Error("pub key pointer doesn't match")
|
||||
}
|
||||
|
||||
// Key is missing
|
||||
tst = baseDSN + "invalid_name"
|
||||
cfg, err = ParseDSN(tst)
|
||||
if err == nil {
|
||||
t.Errorf("invalid name in DSN (%s) but did not error. Got config: %#v", tst, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDSNServerPubKeyQueryEscape(t *testing.T) {
|
||||
const name = "&%!:"
|
||||
dsn := "User:password@tcp(localhost:5555)/dbname?serverPubKey=" + url.QueryEscape(name)
|
||||
|
||||
RegisterServerPubKey(name, testPubKeyRSA)
|
||||
defer DeregisterServerPubKey(name)
|
||||
|
||||
cfg, err := ParseDSN(dsn)
|
||||
if err != nil {
|
||||
t.Error(err.Error())
|
||||
}
|
||||
|
||||
if cfg.pubKey != testPubKeyRSA {
|
||||
t.Error("pub key pointer doesn't match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDSNWithCustomTLS(t *testing.T) {
|
||||
baseDSN := "User:password@tcp(localhost:5555)/dbname?tls="
|
||||
tlsCfg := tls.Config{}
|
||||
|
||||
RegisterTLSConfig("utils_test", &tlsCfg)
|
||||
defer DeregisterTLSConfig("utils_test")
|
||||
|
||||
// Custom TLS is missing
|
||||
tst := baseDSN + "invalid_tls"
|
||||
cfg, err := ParseDSN(tst)
|
||||
if err == nil {
|
||||
t.Errorf("invalid custom TLS in DSN (%s) but did not error. Got config: %#v", tst, cfg)
|
||||
}
|
||||
|
||||
tst = baseDSN + "utils_test"
|
||||
|
||||
// Custom TLS with a server name
|
||||
name := "foohost"
|
||||
tlsCfg.ServerName = name
|
||||
cfg, err = ParseDSN(tst)
|
||||
|
||||
if err != nil {
|
||||
t.Error(err.Error())
|
||||
} else if cfg.tls.ServerName != name {
|
||||
t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst)
|
||||
}
|
||||
|
||||
// Custom TLS without a server name
|
||||
name = "localhost"
|
||||
tlsCfg.ServerName = ""
|
||||
cfg, err = ParseDSN(tst)
|
||||
|
||||
if err != nil {
|
||||
t.Error(err.Error())
|
||||
} else if cfg.tls.ServerName != name {
|
||||
t.Errorf("did not get the correct ServerName (%s) parsing DSN (%s).", name, tst)
|
||||
} else if tlsCfg.ServerName != "" {
|
||||
t.Errorf("tlsCfg was mutated ServerName (%s) should be empty parsing DSN (%s).", name, tst)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDSNTLSConfig(t *testing.T) {
|
||||
expectedServerName := "example.com"
|
||||
dsn := "tcp(example.com:1234)/?tls=true"
|
||||
|
||||
cfg, err := ParseDSN(dsn)
|
||||
if err != nil {
|
||||
t.Error(err.Error())
|
||||
}
|
||||
if cfg.tls == nil {
|
||||
t.Error("cfg.tls should not be nil")
|
||||
}
|
||||
if cfg.tls.ServerName != expectedServerName {
|
||||
t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName)
|
||||
}
|
||||
|
||||
dsn = "tcp(example.com)/?tls=true"
|
||||
cfg, err = ParseDSN(dsn)
|
||||
if err != nil {
|
||||
t.Error(err.Error())
|
||||
}
|
||||
if cfg.tls == nil {
|
||||
t.Error("cfg.tls should not be nil")
|
||||
}
|
||||
if cfg.tls.ServerName != expectedServerName {
|
||||
t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDSNWithCustomTLSQueryEscape(t *testing.T) {
|
||||
const configKey = "&%!:"
|
||||
dsn := "User:password@tcp(localhost:5555)/dbname?tls=" + url.QueryEscape(configKey)
|
||||
name := "foohost"
|
||||
tlsCfg := tls.Config{ServerName: name}
|
||||
|
||||
RegisterTLSConfig(configKey, &tlsCfg)
|
||||
defer DeregisterTLSConfig(configKey)
|
||||
|
||||
cfg, err := ParseDSN(dsn)
|
||||
|
||||
if err != nil {
|
||||
t.Error(err.Error())
|
||||
} else if cfg.tls.ServerName != name {
|
||||
t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, dsn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDSNUnsafeCollation(t *testing.T) {
|
||||
_, err := ParseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=true")
|
||||
if err != errInvalidDSNUnsafeCollation {
|
||||
t.Errorf("expected %v, got %v", errInvalidDSNUnsafeCollation, err)
|
||||
}
|
||||
|
||||
_, err = ParseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=false")
|
||||
if err != nil {
|
||||
t.Errorf("expected %v, got %v", nil, err)
|
||||
}
|
||||
|
||||
_, err = ParseDSN("/dbname?collation=gbk_chinese_ci")
|
||||
if err != nil {
|
||||
t.Errorf("expected %v, got %v", nil, err)
|
||||
}
|
||||
|
||||
_, err = ParseDSN("/dbname?collation=ascii_bin&interpolateParams=true")
|
||||
if err != nil {
|
||||
t.Errorf("expected %v, got %v", nil, err)
|
||||
}
|
||||
|
||||
_, err = ParseDSN("/dbname?collation=latin1_german1_ci&interpolateParams=true")
|
||||
if err != nil {
|
||||
t.Errorf("expected %v, got %v", nil, err)
|
||||
}
|
||||
|
||||
_, err = ParseDSN("/dbname?collation=utf8_general_ci&interpolateParams=true")
|
||||
if err != nil {
|
||||
t.Errorf("expected %v, got %v", nil, err)
|
||||
}
|
||||
|
||||
_, err = ParseDSN("/dbname?collation=utf8mb4_general_ci&interpolateParams=true")
|
||||
if err != nil {
|
||||
t.Errorf("expected %v, got %v", nil, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParamsAreSorted(t *testing.T) {
|
||||
expected := "/dbname?interpolateParams=true&foobar=baz&quux=loo"
|
||||
cfg := NewConfig()
|
||||
cfg.DBName = "dbname"
|
||||
cfg.InterpolateParams = true
|
||||
cfg.Params = map[string]string{
|
||||
"quux": "loo",
|
||||
"foobar": "baz",
|
||||
}
|
||||
actual := cfg.FormatDSN()
|
||||
if actual != expected {
|
||||
t.Errorf("generic Config.Params were not sorted: want %#v, got %#v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseDSN(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, tst := range testDSNs {
|
||||
if _, err := ParseDSN(tst.in); err != nil {
|
||||
b.Error(err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -6,7 +6,7 @@
|
|||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package driver
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
@ -22,7 +22,7 @@ var (
|
|||
ErrMalformPkt = errors.New("malformed packet")
|
||||
ErrNoTLS = errors.New("TLS requested but server does not support TLS")
|
||||
ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN")
|
||||
ErrNativePassword = errors.New("this user requires mysql native password authentication")
|
||||
ErrNativePassword = errors.New("this user requires mysql native password authentication.")
|
||||
ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
|
||||
ErrUnknownPlugin = errors.New("this authentication plugin is not supported")
|
||||
ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+")
|
||||
|
@ -55,14 +55,14 @@ func SetLogger(logger Logger) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Error is an error type which represents a single MySQL error
|
||||
type Error struct {
|
||||
// MySQLError is an error type which represents a single MySQL error
|
||||
type MySQLError struct {
|
||||
Number uint16
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
return fmt.Sprintf("Error %d: %s", e.Number, e.Message)
|
||||
func (me *MySQLError) Error() string {
|
||||
return fmt.Sprintf("Error %d: %s", me.Number, me.Message)
|
||||
}
|
||||
|
||||
func timeoutError(err error) bool {
|
|
@ -0,0 +1,42 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestErrorsSetLogger(t *testing.T) {
|
||||
previous := errLog
|
||||
defer func() {
|
||||
errLog = previous
|
||||
}()
|
||||
|
||||
// set up logger
|
||||
const expected = "prefix: test\n"
|
||||
buffer := bytes.NewBuffer(make([]byte, 0, 64))
|
||||
logger := log.New(buffer, "prefix: ", 0)
|
||||
|
||||
// print
|
||||
SetLogger(logger)
|
||||
errLog.Print("test")
|
||||
|
||||
// check result
|
||||
if actual := buffer.String(); actual != expected {
|
||||
t.Errorf("expected %q, got %q", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorsStrictIgnoreNotes(t *testing.T) {
|
||||
runTests(t, dsn+"&sql_notes=false", func(dbt *DBTest) {
|
||||
dbt.mustExec("DROP TABLE IF EXISTS does_not_exist")
|
||||
})
|
||||
}
|
|
@ -6,7 +6,7 @@
|
|||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package driver
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
|
@ -6,7 +6,7 @@
|
|||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package driver
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"fmt"
|
|
@ -6,7 +6,7 @@
|
|||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package driver
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
@ -583,7 +583,7 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
|
|||
}
|
||||
|
||||
// Error Message [string]
|
||||
return &Error{
|
||||
return &MySQLError{
|
||||
Number: errno,
|
||||
Message: string(data[pos:]),
|
||||
}
|
||||
|
@ -604,7 +604,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
|
|||
mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
|
||||
|
||||
// Insert id [Length Coded Binary]
|
||||
mc.insertID, _, m = readLengthEncodedInteger(data[1+n:])
|
||||
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
|
||||
|
||||
// server_status [2 bytes]
|
||||
mc.status = readStatus(data[1+n+m : 1+n+m+2])
|
||||
|
@ -806,6 +806,307 @@ func (mc *mysqlConn) readUntilEOF() error {
|
|||
}
|
||||
}
|
||||
|
||||
/******************************************************************************
|
||||
* Prepared Statements *
|
||||
******************************************************************************/
|
||||
|
||||
// Prepare Result Packets
|
||||
// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
|
||||
func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
|
||||
data, err := stmt.mc.readPacket()
|
||||
if err == nil {
|
||||
// packet indicator [1 byte]
|
||||
if data[0] != iOK {
|
||||
return 0, stmt.mc.handleErrorPacket(data)
|
||||
}
|
||||
|
||||
// statement id [4 bytes]
|
||||
stmt.id = binary.LittleEndian.Uint32(data[1:5])
|
||||
|
||||
// Column count [16 bit uint]
|
||||
columnCount := binary.LittleEndian.Uint16(data[5:7])
|
||||
|
||||
// Param count [16 bit uint]
|
||||
stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9]))
|
||||
|
||||
// Reserved [8 bit]
|
||||
|
||||
// Warning count [16 bit uint]
|
||||
|
||||
return columnCount, nil
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
|
||||
func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
|
||||
maxLen := stmt.mc.maxAllowedPacket - 1
|
||||
pktLen := maxLen
|
||||
|
||||
// After the header (bytes 0-3) follows before the data:
|
||||
// 1 byte command
|
||||
// 4 bytes stmtID
|
||||
// 2 bytes paramID
|
||||
const dataOffset = 1 + 4 + 2
|
||||
|
||||
// Cannot use the write buffer since
|
||||
// a) the buffer is too small
|
||||
// b) it is in use
|
||||
data := make([]byte, 4+1+4+2+len(arg))
|
||||
|
||||
copy(data[4+dataOffset:], arg)
|
||||
|
||||
for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset {
|
||||
if dataOffset+argLen < maxLen {
|
||||
pktLen = dataOffset + argLen
|
||||
}
|
||||
|
||||
stmt.mc.sequence = 0
|
||||
// Add command byte [1 byte]
|
||||
data[4] = comStmtSendLongData
|
||||
|
||||
// Add stmtID [32 bit]
|
||||
data[5] = byte(stmt.id)
|
||||
data[6] = byte(stmt.id >> 8)
|
||||
data[7] = byte(stmt.id >> 16)
|
||||
data[8] = byte(stmt.id >> 24)
|
||||
|
||||
// Add paramID [16 bit]
|
||||
data[9] = byte(paramID)
|
||||
data[10] = byte(paramID >> 8)
|
||||
|
||||
// Send CMD packet
|
||||
err := stmt.mc.writePacket(data[:4+pktLen])
|
||||
if err == nil {
|
||||
data = data[pktLen-dataOffset:]
|
||||
continue
|
||||
}
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
// Reset Packet Sequence
|
||||
stmt.mc.sequence = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute Prepared Statement
|
||||
// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html
|
||||
func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||
if len(args) != stmt.paramCount {
|
||||
return fmt.Errorf(
|
||||
"argument count mismatch (got: %d; has: %d)",
|
||||
len(args),
|
||||
stmt.paramCount,
|
||||
)
|
||||
}
|
||||
|
||||
const minPktLen = 4 + 1 + 4 + 1 + 4
|
||||
mc := stmt.mc
|
||||
|
||||
// Determine threshould dynamically to avoid packet size shortage.
|
||||
longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
|
||||
if longDataSize < 64 {
|
||||
longDataSize = 64
|
||||
}
|
||||
|
||||
// Reset packet-sequence
|
||||
mc.sequence = 0
|
||||
|
||||
var data []byte
|
||||
|
||||
if len(args) == 0 {
|
||||
data = mc.buf.takeBuffer(minPktLen)
|
||||
} else {
|
||||
data = mc.buf.takeCompleteBuffer()
|
||||
}
|
||||
if data == nil {
|
||||
// cannot take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
return errBadConnNoWrite
|
||||
}
|
||||
|
||||
// command [1 byte]
|
||||
data[4] = comStmtExecute
|
||||
|
||||
// statement_id [4 bytes]
|
||||
data[5] = byte(stmt.id)
|
||||
data[6] = byte(stmt.id >> 8)
|
||||
data[7] = byte(stmt.id >> 16)
|
||||
data[8] = byte(stmt.id >> 24)
|
||||
|
||||
// flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
|
||||
data[9] = 0x00
|
||||
|
||||
// iteration_count (uint32(1)) [4 bytes]
|
||||
data[10] = 0x01
|
||||
data[11] = 0x00
|
||||
data[12] = 0x00
|
||||
data[13] = 0x00
|
||||
|
||||
if len(args) > 0 {
|
||||
pos := minPktLen
|
||||
|
||||
var nullMask []byte
|
||||
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
|
||||
// buffer has to be extended but we don't know by how much so
|
||||
// we depend on append after all data with known sizes fit.
|
||||
// We stop at that because we deal with a lot of columns here
|
||||
// which makes the required allocation size hard to guess.
|
||||
tmp := make([]byte, pos+maskLen+typesLen)
|
||||
copy(tmp[:pos], data[:pos])
|
||||
data = tmp
|
||||
nullMask = data[pos : pos+maskLen]
|
||||
pos += maskLen
|
||||
} else {
|
||||
nullMask = data[pos : pos+maskLen]
|
||||
for i := 0; i < maskLen; i++ {
|
||||
nullMask[i] = 0
|
||||
}
|
||||
pos += maskLen
|
||||
}
|
||||
|
||||
// newParameterBoundFlag 1 [1 byte]
|
||||
data[pos] = 0x01
|
||||
pos++
|
||||
|
||||
// type of each parameter [len(args)*2 bytes]
|
||||
paramTypes := data[pos:]
|
||||
pos += len(args) * 2
|
||||
|
||||
// value of each parameter [n bytes]
|
||||
paramValues := data[pos:pos]
|
||||
valuesCap := cap(paramValues)
|
||||
|
||||
for i, arg := range args {
|
||||
// build NULL-bitmap
|
||||
if arg == nil {
|
||||
nullMask[i/8] |= 1 << (uint(i) & 7)
|
||||
paramTypes[i+i] = byte(fieldTypeNULL)
|
||||
paramTypes[i+i+1] = 0x00
|
||||
continue
|
||||
}
|
||||
|
||||
// cache types and values
|
||||
switch v := arg.(type) {
|
||||
case int64:
|
||||
paramTypes[i+i] = byte(fieldTypeLongLong)
|
||||
paramTypes[i+i+1] = 0x00
|
||||
|
||||
if cap(paramValues)-len(paramValues)-8 >= 0 {
|
||||
paramValues = paramValues[:len(paramValues)+8]
|
||||
binary.LittleEndian.PutUint64(
|
||||
paramValues[len(paramValues)-8:],
|
||||
uint64(v),
|
||||
)
|
||||
} else {
|
||||
paramValues = append(paramValues,
|
||||
uint64ToBytes(uint64(v))...,
|
||||
)
|
||||
}
|
||||
|
||||
case float64:
|
||||
paramTypes[i+i] = byte(fieldTypeDouble)
|
||||
paramTypes[i+i+1] = 0x00
|
||||
|
||||
if cap(paramValues)-len(paramValues)-8 >= 0 {
|
||||
paramValues = paramValues[:len(paramValues)+8]
|
||||
binary.LittleEndian.PutUint64(
|
||||
paramValues[len(paramValues)-8:],
|
||||
math.Float64bits(v),
|
||||
)
|
||||
} else {
|
||||
paramValues = append(paramValues,
|
||||
uint64ToBytes(math.Float64bits(v))...,
|
||||
)
|
||||
}
|
||||
|
||||
case bool:
|
||||
paramTypes[i+i] = byte(fieldTypeTiny)
|
||||
paramTypes[i+i+1] = 0x00
|
||||
|
||||
if v {
|
||||
paramValues = append(paramValues, 0x01)
|
||||
} else {
|
||||
paramValues = append(paramValues, 0x00)
|
||||
}
|
||||
|
||||
case []byte:
|
||||
// Common case (non-nil value) first
|
||||
if v != nil {
|
||||
paramTypes[i+i] = byte(fieldTypeString)
|
||||
paramTypes[i+i+1] = 0x00
|
||||
|
||||
if len(v) < longDataSize {
|
||||
paramValues = appendLengthEncodedInteger(paramValues,
|
||||
uint64(len(v)),
|
||||
)
|
||||
paramValues = append(paramValues, v...)
|
||||
} else {
|
||||
if err := stmt.writeCommandLongData(i, v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle []byte(nil) as a NULL value
|
||||
nullMask[i/8] |= 1 << (uint(i) & 7)
|
||||
paramTypes[i+i] = byte(fieldTypeNULL)
|
||||
paramTypes[i+i+1] = 0x00
|
||||
|
||||
case string:
|
||||
paramTypes[i+i] = byte(fieldTypeString)
|
||||
paramTypes[i+i+1] = 0x00
|
||||
|
||||
if len(v) < longDataSize {
|
||||
paramValues = appendLengthEncodedInteger(paramValues,
|
||||
uint64(len(v)),
|
||||
)
|
||||
paramValues = append(paramValues, v...)
|
||||
} else {
|
||||
if err := stmt.writeCommandLongData(i, []byte(v)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
case time.Time:
|
||||
paramTypes[i+i] = byte(fieldTypeString)
|
||||
paramTypes[i+i+1] = 0x00
|
||||
|
||||
var a [64]byte
|
||||
var b = a[:0]
|
||||
|
||||
if v.IsZero() {
|
||||
b = append(b, "0000-00-00"...)
|
||||
} else {
|
||||
b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat)
|
||||
}
|
||||
|
||||
paramValues = appendLengthEncodedInteger(paramValues,
|
||||
uint64(len(b)),
|
||||
)
|
||||
paramValues = append(paramValues, b...)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("cannot convert type: %T", arg)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if param values exceeded the available buffer
|
||||
// In that case we must build the data packet with the new values buffer
|
||||
if valuesCap != cap(paramValues) {
|
||||
data = append(data[:pos], paramValues...)
|
||||
mc.buf.buf = data
|
||||
}
|
||||
|
||||
pos += len(paramValues)
|
||||
data = data[:pos]
|
||||
}
|
||||
|
||||
return mc.writePacket(data)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) discardResults() error {
|
||||
for mc.status&statusMoreResultsExists != 0 {
|
||||
resLen, err := mc.readResultSetHeaderPacket()
|
|
@ -0,0 +1,336 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
errConnClosed = errors.New("connection is closed")
|
||||
errConnTooManyReads = errors.New("too many reads")
|
||||
errConnTooManyWrites = errors.New("too many writes")
|
||||
)
|
||||
|
||||
// struct to mock a net.Conn for testing purposes
|
||||
type mockConn struct {
|
||||
laddr net.Addr
|
||||
raddr net.Addr
|
||||
data []byte
|
||||
written []byte
|
||||
queuedReplies [][]byte
|
||||
closed bool
|
||||
read int
|
||||
reads int
|
||||
writes int
|
||||
maxReads int
|
||||
maxWrites int
|
||||
}
|
||||
|
||||
func (m *mockConn) Read(b []byte) (n int, err error) {
|
||||
if m.closed {
|
||||
return 0, errConnClosed
|
||||
}
|
||||
|
||||
m.reads++
|
||||
if m.maxReads > 0 && m.reads > m.maxReads {
|
||||
return 0, errConnTooManyReads
|
||||
}
|
||||
|
||||
n = copy(b, m.data)
|
||||
m.read += n
|
||||
m.data = m.data[n:]
|
||||
return
|
||||
}
|
||||
func (m *mockConn) Write(b []byte) (n int, err error) {
|
||||
if m.closed {
|
||||
return 0, errConnClosed
|
||||
}
|
||||
|
||||
m.writes++
|
||||
if m.maxWrites > 0 && m.writes > m.maxWrites {
|
||||
return 0, errConnTooManyWrites
|
||||
}
|
||||
|
||||
n = len(b)
|
||||
m.written = append(m.written, b...)
|
||||
|
||||
if n > 0 && len(m.queuedReplies) > 0 {
|
||||
m.data = m.queuedReplies[0]
|
||||
m.queuedReplies = m.queuedReplies[1:]
|
||||
}
|
||||
return
|
||||
}
|
||||
func (m *mockConn) Close() error {
|
||||
m.closed = true
|
||||
return nil
|
||||
}
|
||||
func (m *mockConn) LocalAddr() net.Addr {
|
||||
return m.laddr
|
||||
}
|
||||
func (m *mockConn) RemoteAddr() net.Addr {
|
||||
return m.raddr
|
||||
}
|
||||
func (m *mockConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// make sure mockConn implements the net.Conn interface
|
||||
var _ net.Conn = new(mockConn)
|
||||
|
||||
func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) {
|
||||
conn := new(mockConn)
|
||||
mc := &mysqlConn{
|
||||
buf: newBuffer(conn),
|
||||
cfg: NewConfig(),
|
||||
netConn: conn,
|
||||
closech: make(chan struct{}),
|
||||
maxAllowedPacket: defaultMaxAllowedPacket,
|
||||
sequence: sequence,
|
||||
}
|
||||
return conn, mc
|
||||
}
|
||||
|
||||
func TestReadPacketSingleByte(t *testing.T) {
|
||||
conn := new(mockConn)
|
||||
mc := &mysqlConn{
|
||||
buf: newBuffer(conn),
|
||||
}
|
||||
|
||||
conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
|
||||
conn.maxReads = 1
|
||||
packet, err := mc.readPacket()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(packet) != 1 {
|
||||
t.Fatalf("unexpected packet length: expected %d, got %d", 1, len(packet))
|
||||
}
|
||||
if packet[0] != 0xff {
|
||||
t.Fatalf("unexpected packet content: expected %x, got %x", 0xff, packet[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadPacketWrongSequenceID(t *testing.T) {
|
||||
conn := new(mockConn)
|
||||
mc := &mysqlConn{
|
||||
buf: newBuffer(conn),
|
||||
}
|
||||
|
||||
// too low sequence id
|
||||
conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
|
||||
conn.maxReads = 1
|
||||
mc.sequence = 1
|
||||
_, err := mc.readPacket()
|
||||
if err != ErrPktSync {
|
||||
t.Errorf("expected ErrPktSync, got %v", err)
|
||||
}
|
||||
|
||||
// reset
|
||||
conn.reads = 0
|
||||
mc.sequence = 0
|
||||
mc.buf = newBuffer(conn)
|
||||
|
||||
// too high sequence id
|
||||
conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff}
|
||||
_, err = mc.readPacket()
|
||||
if err != ErrPktSyncMul {
|
||||
t.Errorf("expected ErrPktSyncMul, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadPacketSplit(t *testing.T) {
|
||||
conn := new(mockConn)
|
||||
mc := &mysqlConn{
|
||||
buf: newBuffer(conn),
|
||||
}
|
||||
|
||||
data := make([]byte, maxPacketSize*2+4*3)
|
||||
const pkt2ofs = maxPacketSize + 4
|
||||
const pkt3ofs = 2 * (maxPacketSize + 4)
|
||||
|
||||
// case 1: payload has length maxPacketSize
|
||||
data = data[:pkt2ofs+4]
|
||||
|
||||
// 1st packet has maxPacketSize length and sequence id 0
|
||||
// ff ff ff 00 ...
|
||||
data[0] = 0xff
|
||||
data[1] = 0xff
|
||||
data[2] = 0xff
|
||||
|
||||
// mark the payload start and end of 1st packet so that we can check if the
|
||||
// content was correctly appended
|
||||
data[4] = 0x11
|
||||
data[maxPacketSize+3] = 0x22
|
||||
|
||||
// 2nd packet has payload length 0 and squence id 1
|
||||
// 00 00 00 01
|
||||
data[pkt2ofs+3] = 0x01
|
||||
|
||||
conn.data = data
|
||||
conn.maxReads = 3
|
||||
packet, err := mc.readPacket()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(packet) != maxPacketSize {
|
||||
t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize, len(packet))
|
||||
}
|
||||
if packet[0] != 0x11 {
|
||||
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
|
||||
}
|
||||
if packet[maxPacketSize-1] != 0x22 {
|
||||
t.Fatalf("unexpected payload end: expected %x, got %x", 0x22, packet[maxPacketSize-1])
|
||||
}
|
||||
|
||||
// case 2: payload has length which is a multiple of maxPacketSize
|
||||
data = data[:cap(data)]
|
||||
|
||||
// 2nd packet now has maxPacketSize length
|
||||
data[pkt2ofs] = 0xff
|
||||
data[pkt2ofs+1] = 0xff
|
||||
data[pkt2ofs+2] = 0xff
|
||||
|
||||
// mark the payload start and end of the 2nd packet
|
||||
data[pkt2ofs+4] = 0x33
|
||||
data[pkt2ofs+maxPacketSize+3] = 0x44
|
||||
|
||||
// 3rd packet has payload length 0 and squence id 2
|
||||
// 00 00 00 02
|
||||
data[pkt3ofs+3] = 0x02
|
||||
|
||||
conn.data = data
|
||||
conn.reads = 0
|
||||
conn.maxReads = 5
|
||||
mc.sequence = 0
|
||||
packet, err = mc.readPacket()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(packet) != 2*maxPacketSize {
|
||||
t.Fatalf("unexpected packet length: expected %d, got %d", 2*maxPacketSize, len(packet))
|
||||
}
|
||||
if packet[0] != 0x11 {
|
||||
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
|
||||
}
|
||||
if packet[2*maxPacketSize-1] != 0x44 {
|
||||
t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[2*maxPacketSize-1])
|
||||
}
|
||||
|
||||
// case 3: payload has a length larger maxPacketSize, which is not an exact
|
||||
// multiple of it
|
||||
data = data[:pkt2ofs+4+42]
|
||||
data[pkt2ofs] = 0x2a
|
||||
data[pkt2ofs+1] = 0x00
|
||||
data[pkt2ofs+2] = 0x00
|
||||
data[pkt2ofs+4+41] = 0x44
|
||||
|
||||
conn.data = data
|
||||
conn.reads = 0
|
||||
conn.maxReads = 4
|
||||
mc.sequence = 0
|
||||
packet, err = mc.readPacket()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(packet) != maxPacketSize+42 {
|
||||
t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize+42, len(packet))
|
||||
}
|
||||
if packet[0] != 0x11 {
|
||||
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
|
||||
}
|
||||
if packet[maxPacketSize+41] != 0x44 {
|
||||
t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[maxPacketSize+41])
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadPacketFail(t *testing.T) {
|
||||
conn := new(mockConn)
|
||||
mc := &mysqlConn{
|
||||
buf: newBuffer(conn),
|
||||
closech: make(chan struct{}),
|
||||
}
|
||||
|
||||
// illegal empty (stand-alone) packet
|
||||
conn.data = []byte{0x00, 0x00, 0x00, 0x00}
|
||||
conn.maxReads = 1
|
||||
_, err := mc.readPacket()
|
||||
if err != ErrInvalidConn {
|
||||
t.Errorf("expected ErrInvalidConn, got %v", err)
|
||||
}
|
||||
|
||||
// reset
|
||||
conn.reads = 0
|
||||
mc.sequence = 0
|
||||
mc.buf = newBuffer(conn)
|
||||
|
||||
// fail to read header
|
||||
conn.closed = true
|
||||
_, err = mc.readPacket()
|
||||
if err != ErrInvalidConn {
|
||||
t.Errorf("expected ErrInvalidConn, got %v", err)
|
||||
}
|
||||
|
||||
// reset
|
||||
conn.closed = false
|
||||
conn.reads = 0
|
||||
mc.sequence = 0
|
||||
mc.buf = newBuffer(conn)
|
||||
|
||||
// fail to read body
|
||||
conn.maxReads = 1
|
||||
_, err = mc.readPacket()
|
||||
if err != ErrInvalidConn {
|
||||
t.Errorf("expected ErrInvalidConn, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/go-sql-driver/mysql/pull/801
|
||||
// not-NUL terminated plugin_name in init packet
|
||||
func TestRegression801(t *testing.T) {
|
||||
conn := new(mockConn)
|
||||
mc := &mysqlConn{
|
||||
buf: newBuffer(conn),
|
||||
cfg: new(Config),
|
||||
sequence: 42,
|
||||
closech: make(chan struct{}),
|
||||
}
|
||||
|
||||
conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0,
|
||||
60, 70, 63, 58, 68, 104, 34, 97, 0, 223, 247, 33, 2, 0, 15, 128, 21, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 98, 120, 114, 47, 85, 75, 109, 99, 51, 77,
|
||||
50, 64, 0, 109, 121, 115, 113, 108, 95, 110, 97, 116, 105, 118, 101, 95,
|
||||
112, 97, 115, 115, 119, 111, 114, 100}
|
||||
conn.maxReads = 1
|
||||
|
||||
authData, pluginName, err := mc.readHandshakePacket()
|
||||
if err != nil {
|
||||
t.Fatalf("got error: %v", err)
|
||||
}
|
||||
|
||||
if pluginName != "mysql_native_password" {
|
||||
t.Errorf("expected plugin name 'mysql_native_password', got '%s'", pluginName)
|
||||
}
|
||||
|
||||
expectedAuthData := []byte{60, 70, 63, 58, 68, 104, 34, 97, 98, 120, 114,
|
||||
47, 85, 75, 109, 99, 51, 77, 50, 64}
|
||||
if !bytes.Equal(authData, expectedAuthData) {
|
||||
t.Errorf("expected authData '%v', got '%v'", expectedAuthData, authData)
|
||||
}
|
||||
}
|
|
@ -6,15 +6,15 @@
|
|||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package driver
|
||||
package mysql
|
||||
|
||||
type mysqlResult struct {
|
||||
affectedRows int64
|
||||
insertID int64
|
||||
insertId int64
|
||||
}
|
||||
|
||||
func (res *mysqlResult) LastInsertId() (int64, error) {
|
||||
return res.insertID, nil
|
||||
return res.insertId, nil
|
||||
}
|
||||
|
||||
func (res *mysqlResult) RowsAffected() (int64, error) {
|
|
@ -6,7 +6,7 @@
|
|||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package driver
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
|
@ -0,0 +1,211 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type mysqlStmt struct {
|
||||
mc *mysqlConn
|
||||
id uint32
|
||||
paramCount int
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) Close() error {
|
||||
if stmt.mc == nil || stmt.mc.closed.IsSet() {
|
||||
// driver.Stmt.Close can be called more than once, thus this function
|
||||
// has to be idempotent.
|
||||
// See also Issue #450 and golang/go#16019.
|
||||
//errLog.Print(ErrInvalidConn)
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
|
||||
err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
|
||||
stmt.mc = nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) NumInput() int {
|
||||
return stmt.paramCount
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
|
||||
return converter{}
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
if stmt.mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
// Send command
|
||||
err := stmt.writeExecutePacket(args)
|
||||
if err != nil {
|
||||
return nil, stmt.mc.markBadConn(err)
|
||||
}
|
||||
|
||||
mc := stmt.mc
|
||||
|
||||
mc.affectedRows = 0
|
||||
mc.insertId = 0
|
||||
|
||||
// Read Result
|
||||
resLen, err := mc.readResultSetHeaderPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resLen > 0 {
|
||||
// Columns
|
||||
if err = mc.readUntilEOF(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Rows
|
||||
if err := mc.readUntilEOF(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := mc.discardResults(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &mysqlResult{
|
||||
affectedRows: int64(mc.affectedRows),
|
||||
insertId: int64(mc.insertId),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
return stmt.query(args)
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
|
||||
if stmt.mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
// Send command
|
||||
err := stmt.writeExecutePacket(args)
|
||||
if err != nil {
|
||||
return nil, stmt.mc.markBadConn(err)
|
||||
}
|
||||
|
||||
mc := stmt.mc
|
||||
|
||||
// Read Result
|
||||
resLen, err := mc.readResultSetHeaderPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows := new(binaryRows)
|
||||
|
||||
if resLen > 0 {
|
||||
rows.mc = mc
|
||||
rows.rs.columns, err = mc.readColumns(resLen)
|
||||
} else {
|
||||
rows.rs.done = true
|
||||
|
||||
switch err := rows.NextResultSet(); err {
|
||||
case nil, io.EOF:
|
||||
return rows, nil
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return rows, err
|
||||
}
|
||||
|
||||
type converter struct{}
|
||||
|
||||
// ConvertValue mirrors the reference/default converter in database/sql/driver
|
||||
// with _one_ exception. We support uint64 with their high bit and the default
|
||||
// implementation does not. This function should be kept in sync with
|
||||
// database/sql/driver defaultConverter.ConvertValue() except for that
|
||||
// deliberate difference.
|
||||
func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
|
||||
if driver.IsValue(v) {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
if vr, ok := v.(driver.Valuer); ok {
|
||||
sv, err := callValuerValue(vr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !driver.IsValue(sv) {
|
||||
return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
|
||||
}
|
||||
return sv, nil
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(v)
|
||||
switch rv.Kind() {
|
||||
case reflect.Ptr:
|
||||
// indirect pointers
|
||||
if rv.IsNil() {
|
||||
return nil, nil
|
||||
} else {
|
||||
return c.ConvertValue(rv.Elem().Interface())
|
||||
}
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return rv.Int(), nil
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
|
||||
return int64(rv.Uint()), nil
|
||||
case reflect.Uint64:
|
||||
u64 := rv.Uint()
|
||||
if u64 >= 1<<63 {
|
||||
return strconv.FormatUint(u64, 10), nil
|
||||
}
|
||||
return int64(u64), nil
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return rv.Float(), nil
|
||||
case reflect.Bool:
|
||||
return rv.Bool(), nil
|
||||
case reflect.Slice:
|
||||
ek := rv.Type().Elem().Kind()
|
||||
if ek == reflect.Uint8 {
|
||||
return rv.Bytes(), nil
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
|
||||
case reflect.String:
|
||||
return rv.String(), nil
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
|
||||
}
|
||||
|
||||
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
|
||||
|
||||
// callValuerValue returns vr.Value(), with one exception:
|
||||
// If vr.Value is an auto-generated method on a pointer type and the
|
||||
// pointer is nil, it would panic at runtime in the panicwrap
|
||||
// method. Treat it like nil instead.
|
||||
//
|
||||
// This is so people can implement driver.Value on value types and
|
||||
// still use nil pointers to those types to mean nil/NULL, just like
|
||||
// string/*string.
|
||||
//
|
||||
// This is an exact copy of the same-named unexported function from the
|
||||
// database/sql package.
|
||||
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
|
||||
if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
|
||||
rv.IsNil() &&
|
||||
rv.Type().Elem().Implements(valuerReflectType) {
|
||||
return nil, nil
|
||||
}
|
||||
return vr.Value()
|
||||
}
|
|
@ -0,0 +1,126 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConvertDerivedString(t *testing.T) {
|
||||
type derived string
|
||||
|
||||
output, err := converter{}.ConvertValue(derived("value"))
|
||||
if err != nil {
|
||||
t.Fatal("Derived string type not convertible", err)
|
||||
}
|
||||
|
||||
if output != "value" {
|
||||
t.Fatalf("Derived string type not converted, got %#v %T", output, output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertDerivedByteSlice(t *testing.T) {
|
||||
type derived []uint8
|
||||
|
||||
output, err := converter{}.ConvertValue(derived("value"))
|
||||
if err != nil {
|
||||
t.Fatal("Byte slice not convertible", err)
|
||||
}
|
||||
|
||||
if bytes.Compare(output.([]byte), []byte("value")) != 0 {
|
||||
t.Fatalf("Byte slice not converted, got %#v %T", output, output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertDerivedUnsupportedSlice(t *testing.T) {
|
||||
type derived []int
|
||||
|
||||
_, err := converter{}.ConvertValue(derived{1})
|
||||
if err == nil || err.Error() != "unsupported type mysql.derived, a slice of int" {
|
||||
t.Fatal("Unexpected error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertDerivedBool(t *testing.T) {
|
||||
type derived bool
|
||||
|
||||
output, err := converter{}.ConvertValue(derived(true))
|
||||
if err != nil {
|
||||
t.Fatal("Derived bool type not convertible", err)
|
||||
}
|
||||
|
||||
if output != true {
|
||||
t.Fatalf("Derived bool type not converted, got %#v %T", output, output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertPointer(t *testing.T) {
|
||||
str := "value"
|
||||
|
||||
output, err := converter{}.ConvertValue(&str)
|
||||
if err != nil {
|
||||
t.Fatal("Pointer type not convertible", err)
|
||||
}
|
||||
|
||||
if output != "value" {
|
||||
t.Fatalf("Pointer type not converted, got %#v %T", output, output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertSignedIntegers(t *testing.T) {
|
||||
values := []interface{}{
|
||||
int8(-42),
|
||||
int16(-42),
|
||||
int32(-42),
|
||||
int64(-42),
|
||||
int(-42),
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
output, err := converter{}.ConvertValue(value)
|
||||
if err != nil {
|
||||
t.Fatalf("%T type not convertible %s", value, err)
|
||||
}
|
||||
|
||||
if output != int64(-42) {
|
||||
t.Fatalf("%T type not converted, got %#v %T", value, output, output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertUnsignedIntegers(t *testing.T) {
|
||||
values := []interface{}{
|
||||
uint8(42),
|
||||
uint16(42),
|
||||
uint32(42),
|
||||
uint64(42),
|
||||
uint(42),
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
output, err := converter{}.ConvertValue(value)
|
||||
if err != nil {
|
||||
t.Fatalf("%T type not convertible %s", value, err)
|
||||
}
|
||||
|
||||
if output != int64(42) {
|
||||
t.Fatalf("%T type not converted, got %#v %T", value, output, output)
|
||||
}
|
||||
}
|
||||
|
||||
output, err := converter{}.ConvertValue(^uint64(0))
|
||||
if err != nil {
|
||||
t.Fatal("uint64 high-bit not convertible", err)
|
||||
}
|
||||
|
||||
if output != "18446744073709551615" {
|
||||
t.Fatalf("uint64 high-bit not converted, got %#v %T", output, output)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
type mysqlTx struct {
|
||||
mc *mysqlConn
|
||||
}
|
||||
|
||||
func (tx *mysqlTx) Commit() (err error) {
|
||||
if tx.mc == nil || tx.mc.closed.IsSet() {
|
||||
return ErrInvalidConn
|
||||
}
|
||||
err = tx.mc.exec("COMMIT")
|
||||
tx.mc = nil
|
||||
return
|
||||
}
|
||||
|
||||
func (tx *mysqlTx) Rollback() (err error) {
|
||||
if tx.mc == nil || tx.mc.closed.IsSet() {
|
||||
return ErrInvalidConn
|
||||
}
|
||||
err = tx.mc.exec("ROLLBACK")
|
||||
tx.mc = nil
|
||||
return
|
||||
}
|
|
@ -6,7 +6,7 @@
|
|||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package driver
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
|
@ -0,0 +1,334 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestScanNullTime(t *testing.T) {
|
||||
var scanTests = []struct {
|
||||
in interface{}
|
||||
error bool
|
||||
valid bool
|
||||
time time.Time
|
||||
}{
|
||||
{tDate, false, true, tDate},
|
||||
{sDate, false, true, tDate},
|
||||
{[]byte(sDate), false, true, tDate},
|
||||
{tDateTime, false, true, tDateTime},
|
||||
{sDateTime, false, true, tDateTime},
|
||||
{[]byte(sDateTime), false, true, tDateTime},
|
||||
{tDate0, false, true, tDate0},
|
||||
{sDate0, false, true, tDate0},
|
||||
{[]byte(sDate0), false, true, tDate0},
|
||||
{sDateTime0, false, true, tDate0},
|
||||
{[]byte(sDateTime0), false, true, tDate0},
|
||||
{"", true, false, tDate0},
|
||||
{"1234", true, false, tDate0},
|
||||
{0, true, false, tDate0},
|
||||
}
|
||||
|
||||
var nt = NullTime{}
|
||||
var err error
|
||||
|
||||
for _, tst := range scanTests {
|
||||
err = nt.Scan(tst.in)
|
||||
if (err != nil) != tst.error {
|
||||
t.Errorf("%v: expected error status %t, got %t", tst.in, tst.error, (err != nil))
|
||||
}
|
||||
if nt.Valid != tst.valid {
|
||||
t.Errorf("%v: expected valid status %t, got %t", tst.in, tst.valid, nt.Valid)
|
||||
}
|
||||
if nt.Time != tst.time {
|
||||
t.Errorf("%v: expected time %v, got %v", tst.in, tst.time, nt.Time)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLengthEncodedInteger(t *testing.T) {
|
||||
var integerTests = []struct {
|
||||
num uint64
|
||||
encoded []byte
|
||||
}{
|
||||
{0x0000000000000000, []byte{0x00}},
|
||||
{0x0000000000000012, []byte{0x12}},
|
||||
{0x00000000000000fa, []byte{0xfa}},
|
||||
{0x0000000000000100, []byte{0xfc, 0x00, 0x01}},
|
||||
{0x0000000000001234, []byte{0xfc, 0x34, 0x12}},
|
||||
{0x000000000000ffff, []byte{0xfc, 0xff, 0xff}},
|
||||
{0x0000000000010000, []byte{0xfd, 0x00, 0x00, 0x01}},
|
||||
{0x0000000000123456, []byte{0xfd, 0x56, 0x34, 0x12}},
|
||||
{0x0000000000ffffff, []byte{0xfd, 0xff, 0xff, 0xff}},
|
||||
{0x0000000001000000, []byte{0xfe, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}},
|
||||
{0x123456789abcdef0, []byte{0xfe, 0xf0, 0xde, 0xbc, 0x9a, 0x78, 0x56, 0x34, 0x12}},
|
||||
{0xffffffffffffffff, []byte{0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
|
||||
}
|
||||
|
||||
for _, tst := range integerTests {
|
||||
num, isNull, numLen := readLengthEncodedInteger(tst.encoded)
|
||||
if isNull {
|
||||
t.Errorf("%x: expected %d, got NULL", tst.encoded, tst.num)
|
||||
}
|
||||
if num != tst.num {
|
||||
t.Errorf("%x: expected %d, got %d", tst.encoded, tst.num, num)
|
||||
}
|
||||
if numLen != len(tst.encoded) {
|
||||
t.Errorf("%x: expected size %d, got %d", tst.encoded, len(tst.encoded), numLen)
|
||||
}
|
||||
encoded := appendLengthEncodedInteger(nil, num)
|
||||
if !bytes.Equal(encoded, tst.encoded) {
|
||||
t.Errorf("%v: expected %x, got %x", num, tst.encoded, encoded)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatBinaryDateTime(t *testing.T) {
|
||||
rawDate := [11]byte{}
|
||||
binary.LittleEndian.PutUint16(rawDate[:2], 1978) // years
|
||||
rawDate[2] = 12 // months
|
||||
rawDate[3] = 30 // days
|
||||
rawDate[4] = 15 // hours
|
||||
rawDate[5] = 46 // minutes
|
||||
rawDate[6] = 23 // seconds
|
||||
binary.LittleEndian.PutUint32(rawDate[7:], 987654) // microseconds
|
||||
expect := func(expected string, inlen, outlen uint8) {
|
||||
actual, _ := formatBinaryDateTime(rawDate[:inlen], outlen)
|
||||
bytes, ok := actual.([]byte)
|
||||
if !ok {
|
||||
t.Errorf("formatBinaryDateTime must return []byte, was %T", actual)
|
||||
}
|
||||
if string(bytes) != expected {
|
||||
t.Errorf(
|
||||
"expected %q, got %q for length in %d, out %d",
|
||||
expected, actual, inlen, outlen,
|
||||
)
|
||||
}
|
||||
}
|
||||
expect("0000-00-00", 0, 10)
|
||||
expect("0000-00-00 00:00:00", 0, 19)
|
||||
expect("1978-12-30", 4, 10)
|
||||
expect("1978-12-30 15:46:23", 7, 19)
|
||||
expect("1978-12-30 15:46:23.987654", 11, 26)
|
||||
}
|
||||
|
||||
func TestFormatBinaryTime(t *testing.T) {
|
||||
expect := func(expected string, src []byte, outlen uint8) {
|
||||
actual, _ := formatBinaryTime(src, outlen)
|
||||
bytes, ok := actual.([]byte)
|
||||
if !ok {
|
||||
t.Errorf("formatBinaryDateTime must return []byte, was %T", actual)
|
||||
}
|
||||
if string(bytes) != expected {
|
||||
t.Errorf(
|
||||
"expected %q, got %q for src=%q and outlen=%d",
|
||||
expected, actual, src, outlen)
|
||||
}
|
||||
}
|
||||
|
||||
// binary format:
|
||||
// sign (0: positive, 1: negative), days(4), hours, minutes, seconds, micro(4)
|
||||
|
||||
// Zeros
|
||||
expect("00:00:00", []byte{}, 8)
|
||||
expect("00:00:00.0", []byte{}, 10)
|
||||
expect("00:00:00.000000", []byte{}, 15)
|
||||
|
||||
// Without micro(4)
|
||||
expect("12:34:56", []byte{0, 0, 0, 0, 0, 12, 34, 56}, 8)
|
||||
expect("-12:34:56", []byte{1, 0, 0, 0, 0, 12, 34, 56}, 8)
|
||||
expect("12:34:56.00", []byte{0, 0, 0, 0, 0, 12, 34, 56}, 11)
|
||||
expect("24:34:56", []byte{0, 1, 0, 0, 0, 0, 34, 56}, 8)
|
||||
expect("-99:34:56", []byte{1, 4, 0, 0, 0, 3, 34, 56}, 8)
|
||||
expect("103079215103:34:56", []byte{0, 255, 255, 255, 255, 23, 34, 56}, 8)
|
||||
|
||||
// With micro(4)
|
||||
expect("12:34:56.00", []byte{0, 0, 0, 0, 0, 12, 34, 56, 99, 0, 0, 0}, 11)
|
||||
expect("12:34:56.000099", []byte{0, 0, 0, 0, 0, 12, 34, 56, 99, 0, 0, 0}, 15)
|
||||
}
|
||||
|
||||
func TestEscapeBackslash(t *testing.T) {
|
||||
expect := func(expected, value string) {
|
||||
actual := string(escapeBytesBackslash([]byte{}, []byte(value)))
|
||||
if actual != expected {
|
||||
t.Errorf(
|
||||
"expected %s, got %s",
|
||||
expected, actual,
|
||||
)
|
||||
}
|
||||
|
||||
actual = string(escapeStringBackslash([]byte{}, value))
|
||||
if actual != expected {
|
||||
t.Errorf(
|
||||
"expected %s, got %s",
|
||||
expected, actual,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
expect("foo\\0bar", "foo\x00bar")
|
||||
expect("foo\\nbar", "foo\nbar")
|
||||
expect("foo\\rbar", "foo\rbar")
|
||||
expect("foo\\Zbar", "foo\x1abar")
|
||||
expect("foo\\\"bar", "foo\"bar")
|
||||
expect("foo\\\\bar", "foo\\bar")
|
||||
expect("foo\\'bar", "foo'bar")
|
||||
}
|
||||
|
||||
func TestEscapeQuotes(t *testing.T) {
|
||||
expect := func(expected, value string) {
|
||||
actual := string(escapeBytesQuotes([]byte{}, []byte(value)))
|
||||
if actual != expected {
|
||||
t.Errorf(
|
||||
"expected %s, got %s",
|
||||
expected, actual,
|
||||
)
|
||||
}
|
||||
|
||||
actual = string(escapeStringQuotes([]byte{}, value))
|
||||
if actual != expected {
|
||||
t.Errorf(
|
||||
"expected %s, got %s",
|
||||
expected, actual,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
expect("foo\x00bar", "foo\x00bar") // not affected
|
||||
expect("foo\nbar", "foo\nbar") // not affected
|
||||
expect("foo\rbar", "foo\rbar") // not affected
|
||||
expect("foo\x1abar", "foo\x1abar") // not affected
|
||||
expect("foo''bar", "foo'bar") // affected
|
||||
expect("foo\"bar", "foo\"bar") // not affected
|
||||
}
|
||||
|
||||
func TestAtomicBool(t *testing.T) {
|
||||
var ab atomicBool
|
||||
if ab.IsSet() {
|
||||
t.Fatal("Expected value to be false")
|
||||
}
|
||||
|
||||
ab.Set(true)
|
||||
if ab.value != 1 {
|
||||
t.Fatal("Set(true) did not set value to 1")
|
||||
}
|
||||
if !ab.IsSet() {
|
||||
t.Fatal("Expected value to be true")
|
||||
}
|
||||
|
||||
ab.Set(true)
|
||||
if !ab.IsSet() {
|
||||
t.Fatal("Expected value to be true")
|
||||
}
|
||||
|
||||
ab.Set(false)
|
||||
if ab.value != 0 {
|
||||
t.Fatal("Set(false) did not set value to 0")
|
||||
}
|
||||
if ab.IsSet() {
|
||||
t.Fatal("Expected value to be false")
|
||||
}
|
||||
|
||||
ab.Set(false)
|
||||
if ab.IsSet() {
|
||||
t.Fatal("Expected value to be false")
|
||||
}
|
||||
if ab.TrySet(false) {
|
||||
t.Fatal("Expected TrySet(false) to fail")
|
||||
}
|
||||
if !ab.TrySet(true) {
|
||||
t.Fatal("Expected TrySet(true) to succeed")
|
||||
}
|
||||
if !ab.IsSet() {
|
||||
t.Fatal("Expected value to be true")
|
||||
}
|
||||
|
||||
ab.Set(true)
|
||||
if !ab.IsSet() {
|
||||
t.Fatal("Expected value to be true")
|
||||
}
|
||||
if ab.TrySet(true) {
|
||||
t.Fatal("Expected TrySet(true) to fail")
|
||||
}
|
||||
if !ab.TrySet(false) {
|
||||
t.Fatal("Expected TrySet(false) to succeed")
|
||||
}
|
||||
if ab.IsSet() {
|
||||
t.Fatal("Expected value to be false")
|
||||
}
|
||||
|
||||
ab._noCopy.Lock() // we've "tested" it ¯\_(ツ)_/¯
|
||||
}
|
||||
|
||||
func TestAtomicError(t *testing.T) {
|
||||
var ae atomicError
|
||||
if ae.Value() != nil {
|
||||
t.Fatal("Expected value to be nil")
|
||||
}
|
||||
|
||||
ae.Set(ErrMalformPkt)
|
||||
if v := ae.Value(); v != ErrMalformPkt {
|
||||
if v == nil {
|
||||
t.Fatal("Value is still nil")
|
||||
}
|
||||
t.Fatal("Error did not match")
|
||||
}
|
||||
ae.Set(ErrPktSync)
|
||||
if ae.Value() == ErrMalformPkt {
|
||||
t.Fatal("Error still matches old error")
|
||||
}
|
||||
if v := ae.Value(); v != ErrPktSync {
|
||||
t.Fatal("Error did not match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsolationLevelMapping(t *testing.T) {
|
||||
data := []struct {
|
||||
level driver.IsolationLevel
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
level: driver.IsolationLevel(sql.LevelReadCommitted),
|
||||
expected: "READ COMMITTED",
|
||||
},
|
||||
{
|
||||
level: driver.IsolationLevel(sql.LevelRepeatableRead),
|
||||
expected: "REPEATABLE READ",
|
||||
},
|
||||
{
|
||||
level: driver.IsolationLevel(sql.LevelReadUncommitted),
|
||||
expected: "READ UNCOMMITTED",
|
||||
},
|
||||
{
|
||||
level: driver.IsolationLevel(sql.LevelSerializable),
|
||||
expected: "SERIALIZABLE",
|
||||
},
|
||||
}
|
||||
|
||||
for i, td := range data {
|
||||
if actual, err := mapIsolationLevel(td.level); actual != td.expected || err != nil {
|
||||
t.Fatal(i, td.expected, actual, err)
|
||||
}
|
||||
}
|
||||
|
||||
// check unsupported mapping
|
||||
expectedErr := "mysql: unsupported isolation level: 7"
|
||||
actual, err := mapIsolationLevel(driver.IsolationLevel(sql.LevelLinearizable))
|
||||
if actual != "" || err == nil {
|
||||
t.Fatal("Expected error on unsupported isolation level")
|
||||
}
|
||||
if err.Error() != expectedErr {
|
||||
t.Fatalf("Expected error to be %q, got %q", expectedErr, err)
|
||||
}
|
||||
}
|
|
@ -6,12 +6,12 @@ import (
|
|||
"os"
|
||||
|
||||
"github.com/localhots/bocadillo/buffer"
|
||||
"github.com/localhots/bocadillo/mysql/driver"
|
||||
"github.com/localhots/bocadillo/mysql/slave/internal/mysql"
|
||||
)
|
||||
|
||||
// Conn is a slave connection used to issue a binlog dump command.
|
||||
type Conn struct {
|
||||
conn *driver.ExtendedConn
|
||||
conn *mysql.ExtendedConn
|
||||
conf Config
|
||||
}
|
||||
|
||||
|
@ -49,12 +49,17 @@ func Connect(dsn string, conf Config) (*Conn, error) {
|
|||
conf.Offset = 4
|
||||
}
|
||||
|
||||
conn, err := driver.NewExtendedConnection(dsn)
|
||||
conn, err := (mysql.MySQLDriver{}).Open(dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Conn{conn: conn, conf: conf}, nil
|
||||
extconn, err := mysql.ExtendConn(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Conn{conn: extconn, conf: conf}, nil
|
||||
}
|
||||
|
||||
// ReadPacket reads next packet from the server and processes the first status
|
||||
|
|
Loading…
Reference in New Issue