1
0
Fork 0

Don't mess with driver code, include it as-is

It doesn't seem worth the trouble to reduce amount of code in the driver. Removing automatic registration (along with distinguishing between network timeouts and context timeouts) seem sufficient.
This commit is contained in:
Gregory Eremin 2018-12-11 20:52:44 +01:00
parent 2f12207dbe
commit 05009b19d7
34 changed files with 7311 additions and 425 deletions

View File

@ -1,374 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package driver
import (
"context"
"database/sql/driver"
"net"
"strings"
"time"
)
type mysqlConn struct {
buf buffer
netConn net.Conn
affectedRows uint64
insertID uint64
cfg *Config
maxAllowedPacket int
maxWriteSize int
writeTimeout time.Duration
flags clientFlag
status statusFlag
sequence uint8
parseTime bool
// for context support (Go 1.8+)
watching bool
watcher chan<- context.Context
closech chan struct{}
finished chan<- struct{}
canceled atomicError // set non-nil if conn is canceled
closed atomicBool // set when conn is closed, before closech is closed
}
func newConnection(dsn string) (*mysqlConn, error) {
var err error
// New mysqlConn
mc := &mysqlConn{
maxAllowedPacket: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
closech: make(chan struct{}),
}
mc.cfg, err = ParseDSN(dsn)
if err != nil {
return nil, err
}
mc.parseTime = mc.cfg.ParseTime
// Connect to Server
nd := net.Dialer{Timeout: mc.cfg.Timeout}
mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr)
if err != nil {
return nil, err
}
// Enable TCP Keepalives on TCP connections
if tc, ok := mc.netConn.(*net.TCPConn); ok {
if err := tc.SetKeepAlive(true); err != nil {
// Don't send COM_QUIT before handshake.
mc.netConn.Close()
mc.netConn = nil
return nil, err
}
}
// Call startWatcher for context support (From Go 1.8)
mc.startWatcher()
mc.buf = newBuffer(mc.netConn)
// Set I/O timeouts
mc.buf.timeout = mc.cfg.ReadTimeout
mc.writeTimeout = mc.cfg.WriteTimeout
// Reading Handshake Initialization Packet
authData, plugin, err := mc.readHandshakePacket()
if err != nil {
mc.cleanup()
return nil, err
}
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
if err != nil {
// try the default auth plugin, if using the requested plugin failed
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
plugin = defaultAuthPlugin
authResp, addNUL, err = mc.auth(authData, plugin)
if err != nil {
mc.cleanup()
return nil, err
}
}
if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil {
mc.cleanup()
return nil, err
}
// Handle response to auth packet, switch methods if possible
if err = mc.handleAuthResult(authData, plugin); err != nil {
// Authentication failed and MySQL has already closed the connection
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
// Do not send COM_QUIT, just cleanup and return the error.
mc.cleanup()
return nil, err
}
if mc.cfg.MaxAllowedPacket > 0 {
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
} else {
// Get max allowed packet size
maxap, err := mc.getSystemVar("max_allowed_packet")
if err != nil {
mc.Close()
return nil, err
}
mc.maxAllowedPacket = stringToInt(maxap) - 1
}
if mc.maxAllowedPacket < maxPacketSize {
mc.maxWriteSize = mc.maxAllowedPacket
}
// Handle DSN Params
err = mc.handleParams()
if err != nil {
mc.Close()
return nil, err
}
return mc, nil
}
// Handles parameters set in DSN after the connection is established
func (mc *mysqlConn) handleParams() (err error) {
for param, val := range mc.cfg.Params {
switch param {
// Charset
case "charset":
charsets := strings.Split(val, ",")
for i := range charsets {
// ignore errors here - a charset may not exist
err = mc.exec("SET NAMES " + charsets[i])
if err == nil {
break
}
}
if err != nil {
return
}
// System Vars
default:
err = mc.exec("SET " + param + "=" + val)
if err != nil {
return
}
}
}
return
}
func (mc *mysqlConn) markBadConn(err error) error {
if mc == nil {
return err
}
if err != errBadConnNoWrite {
return err
}
return driver.ErrBadConn
}
func (mc *mysqlConn) Close() (err error) {
// Makes Close idempotent
if !mc.closed.IsSet() {
err = mc.writeCommandPacket(comQuit)
}
mc.cleanup()
return
}
// Closes the network connection and unsets internal variables. Do not call this
// function after successfully authentication, call Close instead. This function
// is called before auth or on auth failure because MySQL will have already
// closed the network connection.
func (mc *mysqlConn) cleanup() {
if !mc.closed.TrySet(true) {
return
}
// Makes cleanup idempotent
close(mc.closech)
if mc.netConn == nil {
return
}
if err := mc.netConn.Close(); err != nil {
errLog.Print(err)
}
}
func (mc *mysqlConn) error() error {
if mc.closed.IsSet() {
if err := mc.canceled.Value(); err != nil {
return err
}
return ErrInvalidConn
}
return nil
}
// Internal function to execute commands
func (mc *mysqlConn) exec(query string) error {
// Send command
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
return mc.markBadConn(err)
}
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err != nil {
return err
}
if resLen > 0 {
// columns
if err := mc.readUntilEOF(); err != nil {
return err
}
// rows
if err := mc.readUntilEOF(); err != nil {
return err
}
}
return mc.discardResults()
}
// Gets the value of the given MySQL System Variable
// The returned byte slice is only valid until the next read
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
// Send command
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
return nil, err
}
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
if resLen > 0 {
// Columns
if err := mc.readUntilEOF(); err != nil {
return nil, err
}
}
dest := make([]driver.Value, resLen)
if err = rows.readRow(dest); err == nil {
return dest[0].([]byte), mc.readUntilEOF()
}
}
return nil, err
}
// finish is called when the query has canceled.
func (mc *mysqlConn) cancel(err error) {
mc.canceled.Set(err)
mc.cleanup()
}
// finish is called when the query has succeeded.
func (mc *mysqlConn) finish() {
if !mc.watching || mc.finished == nil {
return
}
select {
case mc.finished <- struct{}{}:
mc.watching = false
case <-mc.closech:
}
}
// Ping implements driver.Pinger interface
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return driver.ErrBadConn
}
if err = mc.watchCancel(ctx); err != nil {
return
}
defer mc.finish()
if err = mc.writeCommandPacket(comPing); err != nil {
return
}
return mc.readResultOK()
}
func (mc *mysqlConn) watchCancel(ctx context.Context) error {
if mc.watching {
// Reach here if canceled,
// so the connection is already invalid
mc.cleanup()
return nil
}
if ctx.Done() == nil {
return nil
}
mc.watching = true
select {
default:
case <-ctx.Done():
return ctx.Err()
}
if mc.watcher == nil {
return nil
}
mc.watcher <- ctx
return nil
}
func (mc *mysqlConn) startWatcher() {
watcher := make(chan context.Context, 1)
mc.watcher = watcher
finished := make(chan struct{})
mc.finished = finished
go func() {
for {
var ctx context.Context
select {
case ctx = <-watcher:
case <-mc.closech:
return
}
select {
case <-ctx.Done():
mc.cancel(ctx.Err())
case <-finished:
case <-mc.closech:
return
}
}
}()
}
// ResetSession implements driver.SessionResetter.
// (From Go 1.10)
func (mc *mysqlConn) ResetSession(ctx context.Context) error {
if mc.closed.IsSet() {
return driver.ErrBadConn
}
return nil
}

View File

@ -0,0 +1,167 @@
## Version 1.4 (2018-06-03)
Changes:
- Documentation fixes (#530, #535, #567)
- Refactoring (#575, #579, #580, #581, #603, #615, #704)
- Cache column names (#444)
- Sort the DSN parameters in DSNs generated from a config (#637)
- Allow native password authentication by default (#644)
- Use the default port if it is missing in the DSN (#668)
- Removed the `strict` mode (#676)
- Do not query `max_allowed_packet` by default (#680)
- Dropped support Go 1.6 and lower (#696)
- Updated `ConvertValue()` to match the database/sql/driver implementation (#760)
- Document the usage of `0000-00-00T00:00:00` as the time.Time zero value (#783)
- Improved the compatibility of the authentication system (#807)
New Features:
- Multi-Results support (#537)
- `rejectReadOnly` DSN option (#604)
- `context.Context` support (#608, #612, #627, #761)
- Transaction isolation level support (#619, #744)
- Read-Only transactions support (#618, #634)
- `NewConfig` function which initializes a config with default values (#679)
- Implemented the `ColumnType` interfaces (#667, #724)
- Support for custom string types in `ConvertValue` (#623)
- Implemented `NamedValueChecker`, improving support for uint64 with high bit set (#690, #709, #710)
- `caching_sha2_password` authentication plugin support (#794, #800, #801, #802)
- Implemented `driver.SessionResetter` (#779)
- `sha256_password` authentication plugin support (#808)
Bugfixes:
- Use the DSN hostname as TLS default ServerName if `tls=true` (#564, #718)
- Fixed LOAD LOCAL DATA INFILE for empty files (#590)
- Removed columns definition cache since it sometimes cached invalid data (#592)
- Don't mutate registered TLS configs (#600)
- Make RegisterTLSConfig concurrency-safe (#613)
- Handle missing auth data in the handshake packet correctly (#646)
- Do not retry queries when data was written to avoid data corruption (#302, #736)
- Cache the connection pointer for error handling before invalidating it (#678)
- Fixed imports for appengine/cloudsql (#700)
- Fix sending STMT_LONG_DATA for 0 byte data (#734)
- Set correct capacity for []bytes read from length-encoded strings (#766)
- Make RegisterDial concurrency-safe (#773)
## Version 1.3 (2016-12-01)
Changes:
- Go 1.1 is no longer supported
- Use decimals fields in MySQL to format time types (#249)
- Buffer optimizations (#269)
- TLS ServerName defaults to the host (#283)
- Refactoring (#400, #410, #437)
- Adjusted documentation for second generation CloudSQL (#485)
- Documented DSN system var quoting rules (#502)
- Made statement.Close() calls idempotent to avoid errors in Go 1.6+ (#512)
New Features:
- Enable microsecond resolution on TIME, DATETIME and TIMESTAMP (#249)
- Support for returning table alias on Columns() (#289, #359, #382)
- Placeholder interpolation, can be actived with the DSN parameter `interpolateParams=true` (#309, #318, #490)
- Support for uint64 parameters with high bit set (#332, #345)
- Cleartext authentication plugin support (#327)
- Exported ParseDSN function and the Config struct (#403, #419, #429)
- Read / Write timeouts (#401)
- Support for JSON field type (#414)
- Support for multi-statements and multi-results (#411, #431)
- DSN parameter to set the driver-side max_allowed_packet value manually (#489)
- Native password authentication plugin support (#494, #524)
Bugfixes:
- Fixed handling of queries without columns and rows (#255)
- Fixed a panic when SetKeepAlive() failed (#298)
- Handle ERR packets while reading rows (#321)
- Fixed reading NULL length-encoded integers in MySQL 5.6+ (#349)
- Fixed absolute paths support in LOAD LOCAL DATA INFILE (#356)
- Actually zero out bytes in handshake response (#378)
- Fixed race condition in registering LOAD DATA INFILE handler (#383)
- Fixed tests with MySQL 5.7.9+ (#380)
- QueryUnescape TLS config names (#397)
- Fixed "broken pipe" error by writing to closed socket (#390)
- Fixed LOAD LOCAL DATA INFILE buffering (#424)
- Fixed parsing of floats into float64 when placeholders are used (#434)
- Fixed DSN tests with Go 1.7+ (#459)
- Handle ERR packets while waiting for EOF (#473)
- Invalidate connection on error while discarding additional results (#513)
- Allow terminating packets of length 0 (#516)
## Version 1.2 (2014-06-03)
Changes:
- We switched back to a "rolling release". `go get` installs the current master branch again
- Version v1 of the driver will not be maintained anymore. Go 1.0 is no longer supported by this driver
- Exported errors to allow easy checking from application code
- Enabled TCP Keepalives on TCP connections
- Optimized INFILE handling (better buffer size calculation, lazy init, ...)
- The DSN parser also checks for a missing separating slash
- Faster binary date / datetime to string formatting
- Also exported the MySQLWarning type
- mysqlConn.Close returns the first error encountered instead of ignoring all errors
- writePacket() automatically writes the packet size to the header
- readPacket() uses an iterative approach instead of the recursive approach to merge splitted packets
New Features:
- `RegisterDial` allows the usage of a custom dial function to establish the network connection
- Setting the connection collation is possible with the `collation` DSN parameter. This parameter should be preferred over the `charset` parameter
- Logging of critical errors is configurable with `SetLogger`
- Google CloudSQL support
Bugfixes:
- Allow more than 32 parameters in prepared statements
- Various old_password fixes
- Fixed TestConcurrent test to pass Go's race detection
- Fixed appendLengthEncodedInteger for large numbers
- Renamed readLengthEnodedString to readLengthEncodedString and skipLengthEnodedString to skipLengthEncodedString (fixed typo)
## Version 1.1 (2013-11-02)
Changes:
- Go-MySQL-Driver now requires Go 1.1
- Connections now use the collation `utf8_general_ci` by default. Adding `&charset=UTF8` to the DSN should not be necessary anymore
- Made closing rows and connections error tolerant. This allows for example deferring rows.Close() without checking for errors
- `[]byte(nil)` is now treated as a NULL value. Before, it was treated like an empty string / `[]byte("")`
- DSN parameter values must now be url.QueryEscape'ed. This allows text values to contain special characters, such as '&'.
- Use the IO buffer also for writing. This results in zero allocations (by the driver) for most queries
- Optimized the buffer for reading
- stmt.Query now caches column metadata
- New Logo
- Changed the copyright header to include all contributors
- Improved the LOAD INFILE documentation
- The driver struct is now exported to make the driver directly accessible
- Refactored the driver tests
- Added more benchmarks and moved all to a separate file
- Other small refactoring
New Features:
- Added *old_passwords* support: Required in some cases, but must be enabled by adding `allowOldPasswords=true` to the DSN since it is insecure
- Added a `clientFoundRows` parameter: Return the number of matching rows instead of the number of rows changed on UPDATEs
- Added TLS/SSL support: Use a TLS/SSL encrypted connection to the server. Custom TLS configs can be registered and used
Bugfixes:
- Fixed MySQL 4.1 support: MySQL 4.1 sends packets with lengths which differ from the specification
- Convert to DB timezone when inserting `time.Time`
- Splitted packets (more than 16MB) are now merged correctly
- Fixed false positive `io.EOF` errors when the data was fully read
- Avoid panics on reuse of closed connections
- Fixed empty string producing false nil values
- Fixed sign byte for positive TIME fields
## Version 1.0 (2013-05-14)
Initial Release

View File

@ -0,0 +1,23 @@
# Contributing Guidelines
## Reporting Issues
Before creating a new Issue, please check first if a similar Issue [already exists](https://github.com/go-sql-driver/mysql/issues?state=open) or was [recently closed](https://github.com/go-sql-driver/mysql/issues?direction=desc&page=1&sort=updated&state=closed).
## Contributing Code
By contributing to this project, you share your code under the Mozilla Public License 2, as specified in the LICENSE file.
Don't forget to add yourself to the AUTHORS file.
### Code Review
Everyone is invited to review and comment on pull requests.
If it looks fine to you, comment with "LGTM" (Looks good to me).
If changes are required, notice the reviewers with "PTAL" (Please take another look) after committing the fixes.
Before merging the Pull Request, at least one [team member](https://github.com/go-sql-driver?tab=members) must have commented with "LGTM".
## Development Ideas
If you are looking for ideas for code contributions, please check our [Development Ideas](https://github.com/go-sql-driver/mysql/wiki/Development-Ideas) Wiki page.

View File

@ -0,0 +1,19 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// +build appengine
package mysql
import (
"google.golang.org/appengine/cloudsql"
)
func init() {
RegisterDial("cloudsql", cloudsql.Dial)
}

View File

@ -6,7 +6,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package driver
package mysql
import (
"crypto/rand"

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,319 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"fmt"
"math"
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
type TB testing.B
func (tb *TB) check(err error) {
if err != nil {
tb.Fatal(err)
}
}
func (tb *TB) checkDB(db *sql.DB, err error) *sql.DB {
tb.check(err)
return db
}
func (tb *TB) checkRows(rows *sql.Rows, err error) *sql.Rows {
tb.check(err)
return rows
}
func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt {
tb.check(err)
return stmt
}
func initDB(b *testing.B, queries ...string) *sql.DB {
tb := (*TB)(b)
db := tb.checkDB(sql.Open("mysql", dsn))
for _, query := range queries {
if _, err := db.Exec(query); err != nil {
b.Fatalf("error on %q: %v", query, err)
}
}
return db
}
const concurrencyLevel = 10
func BenchmarkQuery(b *testing.B) {
tb := (*TB)(b)
b.StopTimer()
b.ReportAllocs()
db := initDB(b,
"DROP TABLE IF EXISTS foo",
"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
`INSERT INTO foo VALUES (1, "one")`,
`INSERT INTO foo VALUES (2, "two")`,
)
db.SetMaxIdleConns(concurrencyLevel)
defer db.Close()
stmt := tb.checkStmt(db.Prepare("SELECT val FROM foo WHERE id=?"))
defer stmt.Close()
remain := int64(b.N)
var wg sync.WaitGroup
wg.Add(concurrencyLevel)
defer wg.Wait()
b.StartTimer()
for i := 0; i < concurrencyLevel; i++ {
go func() {
for {
if atomic.AddInt64(&remain, -1) < 0 {
wg.Done()
return
}
var got string
tb.check(stmt.QueryRow(1).Scan(&got))
if got != "one" {
b.Errorf("query = %q; want one", got)
wg.Done()
return
}
}
}()
}
}
func BenchmarkExec(b *testing.B) {
tb := (*TB)(b)
b.StopTimer()
b.ReportAllocs()
db := tb.checkDB(sql.Open("mysql", dsn))
db.SetMaxIdleConns(concurrencyLevel)
defer db.Close()
stmt := tb.checkStmt(db.Prepare("DO 1"))
defer stmt.Close()
remain := int64(b.N)
var wg sync.WaitGroup
wg.Add(concurrencyLevel)
defer wg.Wait()
b.StartTimer()
for i := 0; i < concurrencyLevel; i++ {
go func() {
for {
if atomic.AddInt64(&remain, -1) < 0 {
wg.Done()
return
}
if _, err := stmt.Exec(); err != nil {
b.Fatal(err.Error())
}
}
}()
}
}
// data, but no db writes
var roundtripSample []byte
func initRoundtripBenchmarks() ([]byte, int, int) {
if roundtripSample == nil {
roundtripSample = []byte(strings.Repeat("0123456789abcdef", 1024*1024))
}
return roundtripSample, 16, len(roundtripSample)
}
func BenchmarkRoundtripTxt(b *testing.B) {
b.StopTimer()
sample, min, max := initRoundtripBenchmarks()
sampleString := string(sample)
b.ReportAllocs()
tb := (*TB)(b)
db := tb.checkDB(sql.Open("mysql", dsn))
defer db.Close()
b.StartTimer()
var result string
for i := 0; i < b.N; i++ {
length := min + i
if length > max {
length = max
}
test := sampleString[0:length]
rows := tb.checkRows(db.Query(`SELECT "` + test + `"`))
if !rows.Next() {
rows.Close()
b.Fatalf("crashed")
}
err := rows.Scan(&result)
if err != nil {
rows.Close()
b.Fatalf("crashed")
}
if result != test {
rows.Close()
b.Errorf("mismatch")
}
rows.Close()
}
}
func BenchmarkRoundtripBin(b *testing.B) {
b.StopTimer()
sample, min, max := initRoundtripBenchmarks()
b.ReportAllocs()
tb := (*TB)(b)
db := tb.checkDB(sql.Open("mysql", dsn))
defer db.Close()
stmt := tb.checkStmt(db.Prepare("SELECT ?"))
defer stmt.Close()
b.StartTimer()
var result sql.RawBytes
for i := 0; i < b.N; i++ {
length := min + i
if length > max {
length = max
}
test := sample[0:length]
rows := tb.checkRows(stmt.Query(test))
if !rows.Next() {
rows.Close()
b.Fatalf("crashed")
}
err := rows.Scan(&result)
if err != nil {
rows.Close()
b.Fatalf("crashed")
}
if !bytes.Equal(result, test) {
rows.Close()
b.Errorf("mismatch")
}
rows.Close()
}
}
func BenchmarkInterpolation(b *testing.B) {
mc := &mysqlConn{
cfg: &Config{
InterpolateParams: true,
Loc: time.UTC,
},
maxAllowedPacket: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
buf: newBuffer(nil),
}
args := []driver.Value{
int64(42424242),
float64(math.Pi),
false,
time.Unix(1423411542, 807015000),
[]byte("bytes containing special chars ' \" \a \x00"),
"string containing special chars ' \" \a \x00",
}
q := "SELECT ?, ?, ?, ?, ?, ?"
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := mc.interpolateParams(q, args)
if err != nil {
b.Fatal(err)
}
}
}
func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0))
tb := (*TB)(b)
stmt := tb.checkStmt(db.PrepareContext(ctx, "SELECT val FROM foo WHERE id=?"))
defer stmt.Close()
b.SetParallelism(p)
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
var got string
for pb.Next() {
tb.check(stmt.QueryRow(1).Scan(&got))
if got != "one" {
b.Fatalf("query = %q; want one", got)
}
}
})
}
func BenchmarkQueryContext(b *testing.B) {
db := initDB(b,
"DROP TABLE IF EXISTS foo",
"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
`INSERT INTO foo VALUES (1, "one")`,
`INSERT INTO foo VALUES (2, "two")`,
)
defer db.Close()
for _, p := range []int{1, 2, 3, 4} {
b.Run(fmt.Sprintf("%d", p), func(b *testing.B) {
benchmarkQueryContext(b, db, p)
})
}
}
func benchmarkExecContext(b *testing.B, db *sql.DB, p int) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0))
tb := (*TB)(b)
stmt := tb.checkStmt(db.PrepareContext(ctx, "DO 1"))
defer stmt.Close()
b.SetParallelism(p)
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if _, err := stmt.ExecContext(ctx); err != nil {
b.Fatal(err)
}
}
})
}
func BenchmarkExecContext(b *testing.B) {
db := initDB(b,
"DROP TABLE IF EXISTS foo",
"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
`INSERT INTO foo VALUES (1, "one")`,
`INSERT INTO foo VALUES (2, "two")`,
)
defer db.Close()
for _, p := range []int{1, 2, 3, 4} {
b.Run(fmt.Sprintf("%d", p), func(b *testing.B) {
benchmarkQueryContext(b, db, p)
})
}
}

View File

@ -1,35 +1,45 @@
package driver
package mysql
import (
"context"
"database/sql/driver"
"errors"
"time"
)
// ExtendConn creates an extended connection.
func ExtendConn(conn driver.Conn) (*ExtendedConn, error) {
if conn == nil {
return nil, errors.New("Connection is nil")
}
mc, ok := conn.(*mysqlConn)
if !ok || mc == nil {
return nil, errors.New("Invalid connection")
}
return &ExtendedConn{mc}, nil
}
// ExtendedConn provides access to internal packet functions.
type ExtendedConn struct {
*mysqlConn
}
// NewExtendedConnection creates a new connection extended with packet access
// methods.
func NewExtendedConnection(dsn string) (*ExtendedConn, error) {
conn, err := newConnection(dsn)
if err != nil {
return nil, err
}
return &ExtendedConn{conn}, nil
// Close ...
func (c *ExtendedConn) Close() error {
c.buf.length = 0
return c.mysqlConn.Close()
}
// Exec executes a query.
// Exec ...
func (c *ExtendedConn) Exec(query string) error {
return c.exec(query)
}
// ReadPacket reads a packet from a given connection. If given context has a
// deadline it would be used as read timeout.
// ReadPacket reads a packet from a given connection.
func (c *ExtendedConn) ReadPacket(ctx context.Context) ([]byte, error) {
if dl, ok := ctx.Deadline(); ok {
dur := time.Until(dl)
dur := dl.Sub(time.Now())
if dur < 0 {
return nil, context.DeadlineExceeded
}
@ -46,27 +56,17 @@ func (c *ExtendedConn) WritePacket(p []byte) error {
return c.writePacket(p)
}
// ReadResultOK returns an error if packet is not an OK_Packet.
// Spec: https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
// ReadResultOK ...
func (c *ExtendedConn) ReadResultOK() error {
return c.readResultOK()
}
// HandleErrorPacket reads error message from ERR_Packet.
// Spec: https://dev.mysql.com/doc/internals/en/packet-ERR_Packet.html
// HandleErrorPacket ...
func (c *ExtendedConn) HandleErrorPacket(data []byte) error {
return c.handleErrorPacket(data)
}
// ResetSequence resets command sequence counter.
// ResetSequence ...
func (c *ExtendedConn) ResetSequence() {
c.sequence = 0
}
// Close the connection.
func (c *ExtendedConn) Close() error {
// Reset buffer length parameter
// If it's not zero bad stuff happens
c.buf.length = 0
return c.mysqlConn.Close()
}

View File

@ -6,7 +6,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package driver
package mysql
import (
"io"

View File

@ -6,7 +6,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package driver
package mysql
const defaultCollation = "utf8_general_ci"
const binaryCollation = "binary"

View File

@ -0,0 +1,654 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"context"
"database/sql"
"database/sql/driver"
"io"
"net"
"strconv"
"strings"
"time"
)
// a copy of context.Context for Go 1.7 and earlier
type mysqlContext interface {
Done() <-chan struct{}
Err() error
// defined in context.Context, but not used in this driver:
// Deadline() (deadline time.Time, ok bool)
// Value(key interface{}) interface{}
}
type mysqlConn struct {
buf buffer
netConn net.Conn
affectedRows uint64
insertId uint64
cfg *Config
maxAllowedPacket int
maxWriteSize int
writeTimeout time.Duration
flags clientFlag
status statusFlag
sequence uint8
parseTime bool
// for context support (Go 1.8+)
watching bool
watcher chan<- mysqlContext
closech chan struct{}
finished chan<- struct{}
canceled atomicError // set non-nil if conn is canceled
closed atomicBool // set when conn is closed, before closech is closed
}
// Handles parameters set in DSN after the connection is established
func (mc *mysqlConn) handleParams() (err error) {
for param, val := range mc.cfg.Params {
switch param {
// Charset
case "charset":
charsets := strings.Split(val, ",")
for i := range charsets {
// ignore errors here - a charset may not exist
err = mc.exec("SET NAMES " + charsets[i])
if err == nil {
break
}
}
if err != nil {
return
}
// System Vars
default:
err = mc.exec("SET " + param + "=" + val + "")
if err != nil {
return
}
}
}
return
}
func (mc *mysqlConn) markBadConn(err error) error {
if mc == nil {
return err
}
if err != errBadConnNoWrite {
return err
}
return driver.ErrBadConn
}
func (mc *mysqlConn) Begin() (driver.Tx, error) {
return mc.begin(false)
}
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
var q string
if readOnly {
q = "START TRANSACTION READ ONLY"
} else {
q = "START TRANSACTION"
}
err := mc.exec(q)
if err == nil {
return &mysqlTx{mc}, err
}
return nil, mc.markBadConn(err)
}
func (mc *mysqlConn) Close() (err error) {
// Makes Close idempotent
if !mc.closed.IsSet() {
err = mc.writeCommandPacket(comQuit)
}
mc.cleanup()
return
}
// Closes the network connection and unsets internal variables. Do not call this
// function after successfully authentication, call Close instead. This function
// is called before auth or on auth failure because MySQL will have already
// closed the network connection.
func (mc *mysqlConn) cleanup() {
if !mc.closed.TrySet(true) {
return
}
// Makes cleanup idempotent
close(mc.closech)
if mc.netConn == nil {
return
}
if err := mc.netConn.Close(); err != nil {
errLog.Print(err)
}
}
func (mc *mysqlConn) error() error {
if mc.closed.IsSet() {
if err := mc.canceled.Value(); err != nil {
return err
}
return ErrInvalidConn
}
return nil
}
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := mc.writeCommandPacketStr(comStmtPrepare, query)
if err != nil {
return nil, mc.markBadConn(err)
}
stmt := &mysqlStmt{
mc: mc,
}
// Read Result
columnCount, err := stmt.readPrepareResultPacket()
if err == nil {
if stmt.paramCount > 0 {
if err = mc.readUntilEOF(); err != nil {
return nil, err
}
}
if columnCount > 0 {
err = mc.readUntilEOF()
}
}
return stmt, err
}
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
// Number of ? should be same to len(args)
if strings.Count(query, "?") != len(args) {
return "", driver.ErrSkip
}
buf := mc.buf.takeCompleteBuffer()
if buf == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return "", ErrInvalidConn
}
buf = buf[:0]
argPos := 0
for i := 0; i < len(query); i++ {
q := strings.IndexByte(query[i:], '?')
if q == -1 {
buf = append(buf, query[i:]...)
break
}
buf = append(buf, query[i:i+q]...)
i += q
arg := args[argPos]
argPos++
if arg == nil {
buf = append(buf, "NULL"...)
continue
}
switch v := arg.(type) {
case int64:
buf = strconv.AppendInt(buf, v, 10)
case float64:
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
case bool:
if v {
buf = append(buf, '1')
} else {
buf = append(buf, '0')
}
case time.Time:
if v.IsZero() {
buf = append(buf, "'0000-00-00'"...)
} else {
v := v.In(mc.cfg.Loc)
v = v.Add(time.Nanosecond * 500) // To round under microsecond
year := v.Year()
year100 := year / 100
year1 := year % 100
month := v.Month()
day := v.Day()
hour := v.Hour()
minute := v.Minute()
second := v.Second()
micro := v.Nanosecond() / 1000
buf = append(buf, []byte{
'\'',
digits10[year100], digits01[year100],
digits10[year1], digits01[year1],
'-',
digits10[month], digits01[month],
'-',
digits10[day], digits01[day],
' ',
digits10[hour], digits01[hour],
':',
digits10[minute], digits01[minute],
':',
digits10[second], digits01[second],
}...)
if micro != 0 {
micro10000 := micro / 10000
micro100 := micro / 100 % 100
micro1 := micro % 100
buf = append(buf, []byte{
'.',
digits10[micro10000], digits01[micro10000],
digits10[micro100], digits01[micro100],
digits10[micro1], digits01[micro1],
}...)
}
buf = append(buf, '\'')
}
case []byte:
if v == nil {
buf = append(buf, "NULL"...)
} else {
buf = append(buf, "_binary'"...)
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeBytesBackslash(buf, v)
} else {
buf = escapeBytesQuotes(buf, v)
}
buf = append(buf, '\'')
}
case string:
buf = append(buf, '\'')
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeStringBackslash(buf, v)
} else {
buf = escapeStringQuotes(buf, v)
}
buf = append(buf, '\'')
default:
return "", driver.ErrSkip
}
if len(buf)+4 > mc.maxAllowedPacket {
return "", driver.ErrSkip
}
}
if argPos != len(args) {
return "", driver.ErrSkip
}
return string(buf), nil
}
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) != 0 {
if !mc.cfg.InterpolateParams {
return nil, driver.ErrSkip
}
// try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
prepared, err := mc.interpolateParams(query, args)
if err != nil {
return nil, err
}
query = prepared
}
mc.affectedRows = 0
mc.insertId = 0
err := mc.exec(query)
if err == nil {
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, err
}
return nil, mc.markBadConn(err)
}
// Internal function to execute commands
func (mc *mysqlConn) exec(query string) error {
// Send command
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
return mc.markBadConn(err)
}
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err != nil {
return err
}
if resLen > 0 {
// columns
if err := mc.readUntilEOF(); err != nil {
return err
}
// rows
if err := mc.readUntilEOF(); err != nil {
return err
}
}
return mc.discardResults()
}
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
return mc.query(query, args)
}
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) != 0 {
if !mc.cfg.InterpolateParams {
return nil, driver.ErrSkip
}
// try client-side prepare to reduce roundtrip
prepared, err := mc.interpolateParams(query, args)
if err != nil {
return nil, err
}
query = prepared
}
// Send command
err := mc.writeCommandPacketStr(comQuery, query)
if err == nil {
// Read Result
var resLen int
resLen, err = mc.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
if resLen == 0 {
rows.rs.done = true
switch err := rows.NextResultSet(); err {
case nil, io.EOF:
return rows, nil
default:
return nil, err
}
}
// Columns
rows.rs.columns, err = mc.readColumns(resLen)
return rows, err
}
}
return nil, mc.markBadConn(err)
}
// Gets the value of the given MySQL System Variable
// The returned byte slice is only valid until the next read
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
// Send command
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
return nil, err
}
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
if resLen > 0 {
// Columns
if err := mc.readUntilEOF(); err != nil {
return nil, err
}
}
dest := make([]driver.Value, resLen)
if err = rows.readRow(dest); err == nil {
return dest[0].([]byte), mc.readUntilEOF()
}
}
return nil, err
}
// finish is called when the query has canceled.
func (mc *mysqlConn) cancel(err error) {
mc.canceled.Set(err)
mc.cleanup()
}
// finish is called when the query has succeeded.
func (mc *mysqlConn) finish() {
if !mc.watching || mc.finished == nil {
return
}
select {
case mc.finished <- struct{}{}:
mc.watching = false
case <-mc.closech:
}
}
// Ping implements driver.Pinger interface
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return driver.ErrBadConn
}
if err = mc.watchCancel(ctx); err != nil {
return
}
defer mc.finish()
if err = mc.writeCommandPacket(comPing); err != nil {
return
}
return mc.readResultOK()
}
// BeginTx implements driver.ConnBeginTx interface
func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
defer mc.finish()
if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
level, err := mapIsolationLevel(opts.Isolation)
if err != nil {
return nil, err
}
err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level)
if err != nil {
return nil, err
}
}
return mc.begin(opts.ReadOnly)
}
func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
rows, err := mc.query(query, dargs)
if err != nil {
mc.finish()
return nil, err
}
rows.finish = mc.finish
return rows, err
}
func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
defer mc.finish()
return mc.Exec(query, dargs)
}
func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
stmt, err := mc.Prepare(query)
mc.finish()
if err != nil {
return nil, err
}
select {
default:
case <-ctx.Done():
stmt.Close()
return nil, ctx.Err()
}
return stmt, nil
}
func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
if err := stmt.mc.watchCancel(ctx); err != nil {
return nil, err
}
rows, err := stmt.query(dargs)
if err != nil {
stmt.mc.finish()
return nil, err
}
rows.finish = stmt.mc.finish
return rows, err
}
func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
if err := stmt.mc.watchCancel(ctx); err != nil {
return nil, err
}
defer stmt.mc.finish()
return stmt.Exec(dargs)
}
func (mc *mysqlConn) watchCancel(ctx context.Context) error {
if mc.watching {
// Reach here if canceled,
// so the connection is already invalid
mc.cleanup()
return nil
}
if ctx.Done() == nil {
return nil
}
mc.watching = true
select {
default:
case <-ctx.Done():
return ctx.Err()
}
if mc.watcher == nil {
return nil
}
mc.watcher <- ctx
return nil
}
func (mc *mysqlConn) startWatcher() {
watcher := make(chan mysqlContext, 1)
mc.watcher = watcher
finished := make(chan struct{})
mc.finished = finished
go func() {
for {
var ctx mysqlContext
select {
case ctx = <-watcher:
case <-mc.closech:
return
}
select {
case <-ctx.Done():
mc.cancel(ctx.Err())
case <-finished:
case <-mc.closech:
return
}
}
}()
}
func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
nv.Value, err = converter{}.ConvertValue(nv.Value)
return
}
// ResetSession implements driver.SessionResetter.
// (From Go 1.10)
func (mc *mysqlConn) ResetSession(ctx context.Context) error {
if mc.closed.IsSet() {
return driver.ErrBadConn
}
return nil
}

View File

@ -0,0 +1,81 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql/driver"
"testing"
)
func TestInterpolateParams(t *testing.T) {
mc := &mysqlConn{
buf: newBuffer(nil),
maxAllowedPacket: maxPacketSize,
cfg: &Config{
InterpolateParams: true,
},
}
q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"})
if err != nil {
t.Errorf("Expected err=nil, got %#v", err)
return
}
expected := `SELECT 42+'gopher'`
if q != expected {
t.Errorf("Expected: %q\nGot: %q", expected, q)
}
}
func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
mc := &mysqlConn{
buf: newBuffer(nil),
maxAllowedPacket: maxPacketSize,
cfg: &Config{
InterpolateParams: true,
},
}
q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)})
if err != driver.ErrSkip {
t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
}
}
// We don't support placeholder in string literal for now.
// https://github.com/go-sql-driver/mysql/pull/490
func TestInterpolateParamsPlaceholderInString(t *testing.T) {
mc := &mysqlConn{
buf: newBuffer(nil),
maxAllowedPacket: maxPacketSize,
cfg: &Config{
InterpolateParams: true,
},
}
q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
// When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`
if err != driver.ErrSkip {
t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
}
}
func TestCheckNamedValue(t *testing.T) {
value := driver.NamedValue{Value: ^uint64(0)}
x := &mysqlConn{}
err := x.CheckNamedValue(&value)
if err != nil {
t.Fatal("uint64 high-bit not convertible", err)
}
if value.Value != "18446744073709551615" {
t.Fatalf("uint64 high-bit not converted, got %#v %T", value.Value, value.Value)
}
}

View File

@ -6,7 +6,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package driver
package mysql
const (
defaultAuthPlugin = "mysql_native_password"

View File

@ -0,0 +1,157 @@
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// Package mysql provides a MySQL driver for Go's database/sql package.
//
// The driver should be used via the database/sql package:
//
// import "database/sql"
// import _ "github.com/go-sql-driver/mysql"
//
// db, err := sql.Open("mysql", "user:password@/dbname")
//
// See https://github.com/go-sql-driver/mysql#usage for details
package mysql
import (
"database/sql/driver"
"net"
"sync"
)
// MySQLDriver is exported to make the driver directly accessible.
// In general the driver is used via the database/sql package.
type MySQLDriver struct{}
// DialFunc is a function which can be used to establish the network connection.
// Custom dial functions must be registered with RegisterDial
type DialFunc func(addr string) (net.Conn, error)
var (
dialsLock sync.RWMutex
dials map[string]DialFunc
)
// RegisterDial registers a custom dial function. It can then be used by the
// network address mynet(addr), where mynet is the registered new network.
// addr is passed as a parameter to the dial function.
func RegisterDial(net string, dial DialFunc) {
dialsLock.Lock()
defer dialsLock.Unlock()
if dials == nil {
dials = make(map[string]DialFunc)
}
dials[net] = dial
}
// Open new Connection.
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
// the DSN string is formated
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
var err error
// New mysqlConn
mc := &mysqlConn{
maxAllowedPacket: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
closech: make(chan struct{}),
}
mc.cfg, err = ParseDSN(dsn)
if err != nil {
return nil, err
}
mc.parseTime = mc.cfg.ParseTime
// Connect to Server
dialsLock.RLock()
dial, ok := dials[mc.cfg.Net]
dialsLock.RUnlock()
if ok {
mc.netConn, err = dial(mc.cfg.Addr)
} else {
nd := net.Dialer{Timeout: mc.cfg.Timeout}
mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr)
}
if err != nil {
return nil, err
}
// Enable TCP Keepalives on TCP connections
if tc, ok := mc.netConn.(*net.TCPConn); ok {
if err := tc.SetKeepAlive(true); err != nil {
// Don't send COM_QUIT before handshake.
mc.netConn.Close()
mc.netConn = nil
return nil, err
}
}
// Call startWatcher for context support (From Go 1.8)
mc.startWatcher()
mc.buf = newBuffer(mc.netConn)
// Set I/O timeouts
mc.buf.timeout = mc.cfg.ReadTimeout
mc.writeTimeout = mc.cfg.WriteTimeout
// Reading Handshake Initialization Packet
authData, plugin, err := mc.readHandshakePacket()
if err != nil {
mc.cleanup()
return nil, err
}
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
if err != nil {
// try the default auth plugin, if using the requested plugin failed
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
plugin = defaultAuthPlugin
authResp, addNUL, err = mc.auth(authData, plugin)
if err != nil {
mc.cleanup()
return nil, err
}
}
if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil {
mc.cleanup()
return nil, err
}
// Handle response to auth packet, switch methods if possible
if err = mc.handleAuthResult(authData, plugin); err != nil {
// Authentication failed and MySQL has already closed the connection
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
// Do not send COM_QUIT, just cleanup and return the error.
mc.cleanup()
return nil, err
}
if mc.cfg.MaxAllowedPacket > 0 {
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
} else {
// Get max allowed packet size
maxap, err := mc.getSystemVar("max_allowed_packet")
if err != nil {
mc.Close()
return nil, err
}
mc.maxAllowedPacket = stringToInt(maxap) - 1
}
if mc.maxAllowedPacket < maxPacketSize {
mc.maxWriteSize = mc.maxAllowedPacket
}
// Handle DSN Params
err = mc.handleParams()
if err != nil {
mc.Close()
return nil, err
}
return mc, nil
}

File diff suppressed because it is too large Load Diff

View File

@ -6,7 +6,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package driver
package mysql
import (
"bytes"
@ -463,6 +463,7 @@ func parseDSNParams(cfg *Config, params string) (err error) {
// Collation
case "collation":
cfg.Collation = value
break
case "columnsWithAlias":
var isBool bool

View File

@ -0,0 +1,331 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"crypto/tls"
"fmt"
"net/url"
"reflect"
"testing"
"time"
)
var testDSNs = []struct {
in string
out *Config
}{{
"username:password@protocol(address)/dbname?param=value",
&Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
}, {
"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true",
&Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, ColumnsWithAlias: true},
}, {
"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true&multiStatements=true",
&Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, ColumnsWithAlias: true, MultiStatements: true},
}, {
"user@unix(/path/to/socket)/dbname?charset=utf8",
&Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
}, {
"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "true"},
}, {
"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "skip-verify"},
}, {
"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true, MaxAllowedPacket: 16777216},
}, {
"user:password@/dbname?allowNativePasswords=false&maxAllowedPacket=0",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false},
}, {
"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local",
&Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
}, {
"/dbname",
&Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
}, {
"@/",
&Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
}, {
"/",
&Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
}, {
"",
&Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
}, {
"user:p@/ssword@/",
&Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
}, {
"unix/?arg=%2Fsome%2Fpath.ext",
&Config{Net: "unix", Addr: "/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
}, {
"tcp(127.0.0.1)/dbname",
&Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
}, {
"tcp(de:ad:be:ef::ca:fe)/dbname",
&Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
},
}
func TestDSNParser(t *testing.T) {
for i, tst := range testDSNs {
cfg, err := ParseDSN(tst.in)
if err != nil {
t.Error(err.Error())
}
// pointer not static
cfg.tls = nil
if !reflect.DeepEqual(cfg, tst.out) {
t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out)
}
}
}
func TestDSNParserInvalid(t *testing.T) {
var invalidDSNs = []string{
"@net(addr/", // no closing brace
"@tcp(/", // no closing brace
"tcp(/", // no closing brace
"(/", // no closing brace
"net(addr)//", // unescaped
"User:pass@tcp(1.2.3.4:3306)", // no trailing slash
"net()/", // unknown default addr
//"/dbname?arg=/some/unescaped/path",
}
for i, tst := range invalidDSNs {
if _, err := ParseDSN(tst); err == nil {
t.Errorf("invalid DSN #%d. (%s) didn't error!", i, tst)
}
}
}
func TestDSNReformat(t *testing.T) {
for i, tst := range testDSNs {
dsn1 := tst.in
cfg1, err := ParseDSN(dsn1)
if err != nil {
t.Error(err.Error())
continue
}
cfg1.tls = nil // pointer not static
res1 := fmt.Sprintf("%+v", cfg1)
dsn2 := cfg1.FormatDSN()
cfg2, err := ParseDSN(dsn2)
if err != nil {
t.Error(err.Error())
continue
}
cfg2.tls = nil // pointer not static
res2 := fmt.Sprintf("%+v", cfg2)
if res1 != res2 {
t.Errorf("%d. %q does not match %q", i, res2, res1)
}
}
}
func TestDSNServerPubKey(t *testing.T) {
baseDSN := "User:password@tcp(localhost:5555)/dbname?serverPubKey="
RegisterServerPubKey("testKey", testPubKeyRSA)
defer DeregisterServerPubKey("testKey")
tst := baseDSN + "testKey"
cfg, err := ParseDSN(tst)
if err != nil {
t.Error(err.Error())
}
if cfg.ServerPubKey != "testKey" {
t.Errorf("unexpected cfg.ServerPubKey value: %v", cfg.ServerPubKey)
}
if cfg.pubKey != testPubKeyRSA {
t.Error("pub key pointer doesn't match")
}
// Key is missing
tst = baseDSN + "invalid_name"
cfg, err = ParseDSN(tst)
if err == nil {
t.Errorf("invalid name in DSN (%s) but did not error. Got config: %#v", tst, cfg)
}
}
func TestDSNServerPubKeyQueryEscape(t *testing.T) {
const name = "&%!:"
dsn := "User:password@tcp(localhost:5555)/dbname?serverPubKey=" + url.QueryEscape(name)
RegisterServerPubKey(name, testPubKeyRSA)
defer DeregisterServerPubKey(name)
cfg, err := ParseDSN(dsn)
if err != nil {
t.Error(err.Error())
}
if cfg.pubKey != testPubKeyRSA {
t.Error("pub key pointer doesn't match")
}
}
func TestDSNWithCustomTLS(t *testing.T) {
baseDSN := "User:password@tcp(localhost:5555)/dbname?tls="
tlsCfg := tls.Config{}
RegisterTLSConfig("utils_test", &tlsCfg)
defer DeregisterTLSConfig("utils_test")
// Custom TLS is missing
tst := baseDSN + "invalid_tls"
cfg, err := ParseDSN(tst)
if err == nil {
t.Errorf("invalid custom TLS in DSN (%s) but did not error. Got config: %#v", tst, cfg)
}
tst = baseDSN + "utils_test"
// Custom TLS with a server name
name := "foohost"
tlsCfg.ServerName = name
cfg, err = ParseDSN(tst)
if err != nil {
t.Error(err.Error())
} else if cfg.tls.ServerName != name {
t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst)
}
// Custom TLS without a server name
name = "localhost"
tlsCfg.ServerName = ""
cfg, err = ParseDSN(tst)
if err != nil {
t.Error(err.Error())
} else if cfg.tls.ServerName != name {
t.Errorf("did not get the correct ServerName (%s) parsing DSN (%s).", name, tst)
} else if tlsCfg.ServerName != "" {
t.Errorf("tlsCfg was mutated ServerName (%s) should be empty parsing DSN (%s).", name, tst)
}
}
func TestDSNTLSConfig(t *testing.T) {
expectedServerName := "example.com"
dsn := "tcp(example.com:1234)/?tls=true"
cfg, err := ParseDSN(dsn)
if err != nil {
t.Error(err.Error())
}
if cfg.tls == nil {
t.Error("cfg.tls should not be nil")
}
if cfg.tls.ServerName != expectedServerName {
t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName)
}
dsn = "tcp(example.com)/?tls=true"
cfg, err = ParseDSN(dsn)
if err != nil {
t.Error(err.Error())
}
if cfg.tls == nil {
t.Error("cfg.tls should not be nil")
}
if cfg.tls.ServerName != expectedServerName {
t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName)
}
}
func TestDSNWithCustomTLSQueryEscape(t *testing.T) {
const configKey = "&%!:"
dsn := "User:password@tcp(localhost:5555)/dbname?tls=" + url.QueryEscape(configKey)
name := "foohost"
tlsCfg := tls.Config{ServerName: name}
RegisterTLSConfig(configKey, &tlsCfg)
defer DeregisterTLSConfig(configKey)
cfg, err := ParseDSN(dsn)
if err != nil {
t.Error(err.Error())
} else if cfg.tls.ServerName != name {
t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, dsn)
}
}
func TestDSNUnsafeCollation(t *testing.T) {
_, err := ParseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=true")
if err != errInvalidDSNUnsafeCollation {
t.Errorf("expected %v, got %v", errInvalidDSNUnsafeCollation, err)
}
_, err = ParseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=false")
if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
_, err = ParseDSN("/dbname?collation=gbk_chinese_ci")
if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
_, err = ParseDSN("/dbname?collation=ascii_bin&interpolateParams=true")
if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
_, err = ParseDSN("/dbname?collation=latin1_german1_ci&interpolateParams=true")
if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
_, err = ParseDSN("/dbname?collation=utf8_general_ci&interpolateParams=true")
if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
_, err = ParseDSN("/dbname?collation=utf8mb4_general_ci&interpolateParams=true")
if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
}
func TestParamsAreSorted(t *testing.T) {
expected := "/dbname?interpolateParams=true&foobar=baz&quux=loo"
cfg := NewConfig()
cfg.DBName = "dbname"
cfg.InterpolateParams = true
cfg.Params = map[string]string{
"quux": "loo",
"foobar": "baz",
}
actual := cfg.FormatDSN()
if actual != expected {
t.Errorf("generic Config.Params were not sorted: want %#v, got %#v", expected, actual)
}
}
func BenchmarkParseDSN(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
for _, tst := range testDSNs {
if _, err := ParseDSN(tst.in); err != nil {
b.Error(err.Error())
}
}
}
}

View File

@ -6,7 +6,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package driver
package mysql
import (
"errors"
@ -22,7 +22,7 @@ var (
ErrMalformPkt = errors.New("malformed packet")
ErrNoTLS = errors.New("TLS requested but server does not support TLS")
ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN")
ErrNativePassword = errors.New("this user requires mysql native password authentication")
ErrNativePassword = errors.New("this user requires mysql native password authentication.")
ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
ErrUnknownPlugin = errors.New("this authentication plugin is not supported")
ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+")
@ -55,14 +55,14 @@ func SetLogger(logger Logger) error {
return nil
}
// Error is an error type which represents a single MySQL error
type Error struct {
// MySQLError is an error type which represents a single MySQL error
type MySQLError struct {
Number uint16
Message string
}
func (e *Error) Error() string {
return fmt.Sprintf("Error %d: %s", e.Number, e.Message)
func (me *MySQLError) Error() string {
return fmt.Sprintf("Error %d: %s", me.Number, me.Message)
}
func timeoutError(err error) bool {

View File

@ -0,0 +1,42 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"log"
"testing"
)
func TestErrorsSetLogger(t *testing.T) {
previous := errLog
defer func() {
errLog = previous
}()
// set up logger
const expected = "prefix: test\n"
buffer := bytes.NewBuffer(make([]byte, 0, 64))
logger := log.New(buffer, "prefix: ", 0)
// print
SetLogger(logger)
errLog.Print("test")
// check result
if actual := buffer.String(); actual != expected {
t.Errorf("expected %q, got %q", expected, actual)
}
}
func TestErrorsStrictIgnoreNotes(t *testing.T) {
runTests(t, dsn+"&sql_notes=false", func(dbt *DBTest) {
dbt.mustExec("DROP TABLE IF EXISTS does_not_exist")
})
}

View File

@ -6,7 +6,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package driver
package mysql
import (
"database/sql"

View File

@ -6,7 +6,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package driver
package mysql
import (
"fmt"

View File

@ -6,7 +6,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package driver
package mysql
import (
"bytes"
@ -583,7 +583,7 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
}
// Error Message [string]
return &Error{
return &MySQLError{
Number: errno,
Message: string(data[pos:]),
}
@ -604,7 +604,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
// Insert id [Length Coded Binary]
mc.insertID, _, m = readLengthEncodedInteger(data[1+n:])
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
// server_status [2 bytes]
mc.status = readStatus(data[1+n+m : 1+n+m+2])
@ -806,6 +806,307 @@ func (mc *mysqlConn) readUntilEOF() error {
}
}
/******************************************************************************
* Prepared Statements *
******************************************************************************/
// Prepare Result Packets
// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
data, err := stmt.mc.readPacket()
if err == nil {
// packet indicator [1 byte]
if data[0] != iOK {
return 0, stmt.mc.handleErrorPacket(data)
}
// statement id [4 bytes]
stmt.id = binary.LittleEndian.Uint32(data[1:5])
// Column count [16 bit uint]
columnCount := binary.LittleEndian.Uint16(data[5:7])
// Param count [16 bit uint]
stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9]))
// Reserved [8 bit]
// Warning count [16 bit uint]
return columnCount, nil
}
return 0, err
}
// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
maxLen := stmt.mc.maxAllowedPacket - 1
pktLen := maxLen
// After the header (bytes 0-3) follows before the data:
// 1 byte command
// 4 bytes stmtID
// 2 bytes paramID
const dataOffset = 1 + 4 + 2
// Cannot use the write buffer since
// a) the buffer is too small
// b) it is in use
data := make([]byte, 4+1+4+2+len(arg))
copy(data[4+dataOffset:], arg)
for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset {
if dataOffset+argLen < maxLen {
pktLen = dataOffset + argLen
}
stmt.mc.sequence = 0
// Add command byte [1 byte]
data[4] = comStmtSendLongData
// Add stmtID [32 bit]
data[5] = byte(stmt.id)
data[6] = byte(stmt.id >> 8)
data[7] = byte(stmt.id >> 16)
data[8] = byte(stmt.id >> 24)
// Add paramID [16 bit]
data[9] = byte(paramID)
data[10] = byte(paramID >> 8)
// Send CMD packet
err := stmt.mc.writePacket(data[:4+pktLen])
if err == nil {
data = data[pktLen-dataOffset:]
continue
}
return err
}
// Reset Packet Sequence
stmt.mc.sequence = 0
return nil
}
// Execute Prepared Statement
// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html
func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if len(args) != stmt.paramCount {
return fmt.Errorf(
"argument count mismatch (got: %d; has: %d)",
len(args),
stmt.paramCount,
)
}
const minPktLen = 4 + 1 + 4 + 1 + 4
mc := stmt.mc
// Determine threshould dynamically to avoid packet size shortage.
longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
if longDataSize < 64 {
longDataSize = 64
}
// Reset packet-sequence
mc.sequence = 0
var data []byte
if len(args) == 0 {
data = mc.buf.takeBuffer(minPktLen)
} else {
data = mc.buf.takeCompleteBuffer()
}
if data == nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return errBadConnNoWrite
}
// command [1 byte]
data[4] = comStmtExecute
// statement_id [4 bytes]
data[5] = byte(stmt.id)
data[6] = byte(stmt.id >> 8)
data[7] = byte(stmt.id >> 16)
data[8] = byte(stmt.id >> 24)
// flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
data[9] = 0x00
// iteration_count (uint32(1)) [4 bytes]
data[10] = 0x01
data[11] = 0x00
data[12] = 0x00
data[13] = 0x00
if len(args) > 0 {
pos := minPktLen
var nullMask []byte
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
// buffer has to be extended but we don't know by how much so
// we depend on append after all data with known sizes fit.
// We stop at that because we deal with a lot of columns here
// which makes the required allocation size hard to guess.
tmp := make([]byte, pos+maskLen+typesLen)
copy(tmp[:pos], data[:pos])
data = tmp
nullMask = data[pos : pos+maskLen]
pos += maskLen
} else {
nullMask = data[pos : pos+maskLen]
for i := 0; i < maskLen; i++ {
nullMask[i] = 0
}
pos += maskLen
}
// newParameterBoundFlag 1 [1 byte]
data[pos] = 0x01
pos++
// type of each parameter [len(args)*2 bytes]
paramTypes := data[pos:]
pos += len(args) * 2
// value of each parameter [n bytes]
paramValues := data[pos:pos]
valuesCap := cap(paramValues)
for i, arg := range args {
// build NULL-bitmap
if arg == nil {
nullMask[i/8] |= 1 << (uint(i) & 7)
paramTypes[i+i] = byte(fieldTypeNULL)
paramTypes[i+i+1] = 0x00
continue
}
// cache types and values
switch v := arg.(type) {
case int64:
paramTypes[i+i] = byte(fieldTypeLongLong)
paramTypes[i+i+1] = 0x00
if cap(paramValues)-len(paramValues)-8 >= 0 {
paramValues = paramValues[:len(paramValues)+8]
binary.LittleEndian.PutUint64(
paramValues[len(paramValues)-8:],
uint64(v),
)
} else {
paramValues = append(paramValues,
uint64ToBytes(uint64(v))...,
)
}
case float64:
paramTypes[i+i] = byte(fieldTypeDouble)
paramTypes[i+i+1] = 0x00
if cap(paramValues)-len(paramValues)-8 >= 0 {
paramValues = paramValues[:len(paramValues)+8]
binary.LittleEndian.PutUint64(
paramValues[len(paramValues)-8:],
math.Float64bits(v),
)
} else {
paramValues = append(paramValues,
uint64ToBytes(math.Float64bits(v))...,
)
}
case bool:
paramTypes[i+i] = byte(fieldTypeTiny)
paramTypes[i+i+1] = 0x00
if v {
paramValues = append(paramValues, 0x01)
} else {
paramValues = append(paramValues, 0x00)
}
case []byte:
// Common case (non-nil value) first
if v != nil {
paramTypes[i+i] = byte(fieldTypeString)
paramTypes[i+i+1] = 0x00
if len(v) < longDataSize {
paramValues = appendLengthEncodedInteger(paramValues,
uint64(len(v)),
)
paramValues = append(paramValues, v...)
} else {
if err := stmt.writeCommandLongData(i, v); err != nil {
return err
}
}
continue
}
// Handle []byte(nil) as a NULL value
nullMask[i/8] |= 1 << (uint(i) & 7)
paramTypes[i+i] = byte(fieldTypeNULL)
paramTypes[i+i+1] = 0x00
case string:
paramTypes[i+i] = byte(fieldTypeString)
paramTypes[i+i+1] = 0x00
if len(v) < longDataSize {
paramValues = appendLengthEncodedInteger(paramValues,
uint64(len(v)),
)
paramValues = append(paramValues, v...)
} else {
if err := stmt.writeCommandLongData(i, []byte(v)); err != nil {
return err
}
}
case time.Time:
paramTypes[i+i] = byte(fieldTypeString)
paramTypes[i+i+1] = 0x00
var a [64]byte
var b = a[:0]
if v.IsZero() {
b = append(b, "0000-00-00"...)
} else {
b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat)
}
paramValues = appendLengthEncodedInteger(paramValues,
uint64(len(b)),
)
paramValues = append(paramValues, b...)
default:
return fmt.Errorf("cannot convert type: %T", arg)
}
}
// Check if param values exceeded the available buffer
// In that case we must build the data packet with the new values buffer
if valuesCap != cap(paramValues) {
data = append(data[:pos], paramValues...)
mc.buf.buf = data
}
pos += len(paramValues)
data = data[:pos]
}
return mc.writePacket(data)
}
func (mc *mysqlConn) discardResults() error {
for mc.status&statusMoreResultsExists != 0 {
resLen, err := mc.readResultSetHeaderPacket()

View File

@ -0,0 +1,336 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"errors"
"net"
"testing"
"time"
)
var (
errConnClosed = errors.New("connection is closed")
errConnTooManyReads = errors.New("too many reads")
errConnTooManyWrites = errors.New("too many writes")
)
// struct to mock a net.Conn for testing purposes
type mockConn struct {
laddr net.Addr
raddr net.Addr
data []byte
written []byte
queuedReplies [][]byte
closed bool
read int
reads int
writes int
maxReads int
maxWrites int
}
func (m *mockConn) Read(b []byte) (n int, err error) {
if m.closed {
return 0, errConnClosed
}
m.reads++
if m.maxReads > 0 && m.reads > m.maxReads {
return 0, errConnTooManyReads
}
n = copy(b, m.data)
m.read += n
m.data = m.data[n:]
return
}
func (m *mockConn) Write(b []byte) (n int, err error) {
if m.closed {
return 0, errConnClosed
}
m.writes++
if m.maxWrites > 0 && m.writes > m.maxWrites {
return 0, errConnTooManyWrites
}
n = len(b)
m.written = append(m.written, b...)
if n > 0 && len(m.queuedReplies) > 0 {
m.data = m.queuedReplies[0]
m.queuedReplies = m.queuedReplies[1:]
}
return
}
func (m *mockConn) Close() error {
m.closed = true
return nil
}
func (m *mockConn) LocalAddr() net.Addr {
return m.laddr
}
func (m *mockConn) RemoteAddr() net.Addr {
return m.raddr
}
func (m *mockConn) SetDeadline(t time.Time) error {
return nil
}
func (m *mockConn) SetReadDeadline(t time.Time) error {
return nil
}
func (m *mockConn) SetWriteDeadline(t time.Time) error {
return nil
}
// make sure mockConn implements the net.Conn interface
var _ net.Conn = new(mockConn)
func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) {
conn := new(mockConn)
mc := &mysqlConn{
buf: newBuffer(conn),
cfg: NewConfig(),
netConn: conn,
closech: make(chan struct{}),
maxAllowedPacket: defaultMaxAllowedPacket,
sequence: sequence,
}
return conn, mc
}
func TestReadPacketSingleByte(t *testing.T) {
conn := new(mockConn)
mc := &mysqlConn{
buf: newBuffer(conn),
}
conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
conn.maxReads = 1
packet, err := mc.readPacket()
if err != nil {
t.Fatal(err)
}
if len(packet) != 1 {
t.Fatalf("unexpected packet length: expected %d, got %d", 1, len(packet))
}
if packet[0] != 0xff {
t.Fatalf("unexpected packet content: expected %x, got %x", 0xff, packet[0])
}
}
func TestReadPacketWrongSequenceID(t *testing.T) {
conn := new(mockConn)
mc := &mysqlConn{
buf: newBuffer(conn),
}
// too low sequence id
conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
conn.maxReads = 1
mc.sequence = 1
_, err := mc.readPacket()
if err != ErrPktSync {
t.Errorf("expected ErrPktSync, got %v", err)
}
// reset
conn.reads = 0
mc.sequence = 0
mc.buf = newBuffer(conn)
// too high sequence id
conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff}
_, err = mc.readPacket()
if err != ErrPktSyncMul {
t.Errorf("expected ErrPktSyncMul, got %v", err)
}
}
func TestReadPacketSplit(t *testing.T) {
conn := new(mockConn)
mc := &mysqlConn{
buf: newBuffer(conn),
}
data := make([]byte, maxPacketSize*2+4*3)
const pkt2ofs = maxPacketSize + 4
const pkt3ofs = 2 * (maxPacketSize + 4)
// case 1: payload has length maxPacketSize
data = data[:pkt2ofs+4]
// 1st packet has maxPacketSize length and sequence id 0
// ff ff ff 00 ...
data[0] = 0xff
data[1] = 0xff
data[2] = 0xff
// mark the payload start and end of 1st packet so that we can check if the
// content was correctly appended
data[4] = 0x11
data[maxPacketSize+3] = 0x22
// 2nd packet has payload length 0 and squence id 1
// 00 00 00 01
data[pkt2ofs+3] = 0x01
conn.data = data
conn.maxReads = 3
packet, err := mc.readPacket()
if err != nil {
t.Fatal(err)
}
if len(packet) != maxPacketSize {
t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize, len(packet))
}
if packet[0] != 0x11 {
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
}
if packet[maxPacketSize-1] != 0x22 {
t.Fatalf("unexpected payload end: expected %x, got %x", 0x22, packet[maxPacketSize-1])
}
// case 2: payload has length which is a multiple of maxPacketSize
data = data[:cap(data)]
// 2nd packet now has maxPacketSize length
data[pkt2ofs] = 0xff
data[pkt2ofs+1] = 0xff
data[pkt2ofs+2] = 0xff
// mark the payload start and end of the 2nd packet
data[pkt2ofs+4] = 0x33
data[pkt2ofs+maxPacketSize+3] = 0x44
// 3rd packet has payload length 0 and squence id 2
// 00 00 00 02
data[pkt3ofs+3] = 0x02
conn.data = data
conn.reads = 0
conn.maxReads = 5
mc.sequence = 0
packet, err = mc.readPacket()
if err != nil {
t.Fatal(err)
}
if len(packet) != 2*maxPacketSize {
t.Fatalf("unexpected packet length: expected %d, got %d", 2*maxPacketSize, len(packet))
}
if packet[0] != 0x11 {
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
}
if packet[2*maxPacketSize-1] != 0x44 {
t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[2*maxPacketSize-1])
}
// case 3: payload has a length larger maxPacketSize, which is not an exact
// multiple of it
data = data[:pkt2ofs+4+42]
data[pkt2ofs] = 0x2a
data[pkt2ofs+1] = 0x00
data[pkt2ofs+2] = 0x00
data[pkt2ofs+4+41] = 0x44
conn.data = data
conn.reads = 0
conn.maxReads = 4
mc.sequence = 0
packet, err = mc.readPacket()
if err != nil {
t.Fatal(err)
}
if len(packet) != maxPacketSize+42 {
t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize+42, len(packet))
}
if packet[0] != 0x11 {
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
}
if packet[maxPacketSize+41] != 0x44 {
t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[maxPacketSize+41])
}
}
func TestReadPacketFail(t *testing.T) {
conn := new(mockConn)
mc := &mysqlConn{
buf: newBuffer(conn),
closech: make(chan struct{}),
}
// illegal empty (stand-alone) packet
conn.data = []byte{0x00, 0x00, 0x00, 0x00}
conn.maxReads = 1
_, err := mc.readPacket()
if err != ErrInvalidConn {
t.Errorf("expected ErrInvalidConn, got %v", err)
}
// reset
conn.reads = 0
mc.sequence = 0
mc.buf = newBuffer(conn)
// fail to read header
conn.closed = true
_, err = mc.readPacket()
if err != ErrInvalidConn {
t.Errorf("expected ErrInvalidConn, got %v", err)
}
// reset
conn.closed = false
conn.reads = 0
mc.sequence = 0
mc.buf = newBuffer(conn)
// fail to read body
conn.maxReads = 1
_, err = mc.readPacket()
if err != ErrInvalidConn {
t.Errorf("expected ErrInvalidConn, got %v", err)
}
}
// https://github.com/go-sql-driver/mysql/pull/801
// not-NUL terminated plugin_name in init packet
func TestRegression801(t *testing.T) {
conn := new(mockConn)
mc := &mysqlConn{
buf: newBuffer(conn),
cfg: new(Config),
sequence: 42,
closech: make(chan struct{}),
}
conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0,
60, 70, 63, 58, 68, 104, 34, 97, 0, 223, 247, 33, 2, 0, 15, 128, 21, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 98, 120, 114, 47, 85, 75, 109, 99, 51, 77,
50, 64, 0, 109, 121, 115, 113, 108, 95, 110, 97, 116, 105, 118, 101, 95,
112, 97, 115, 115, 119, 111, 114, 100}
conn.maxReads = 1
authData, pluginName, err := mc.readHandshakePacket()
if err != nil {
t.Fatalf("got error: %v", err)
}
if pluginName != "mysql_native_password" {
t.Errorf("expected plugin name 'mysql_native_password', got '%s'", pluginName)
}
expectedAuthData := []byte{60, 70, 63, 58, 68, 104, 34, 97, 98, 120, 114,
47, 85, 75, 109, 99, 51, 77, 50, 64}
if !bytes.Equal(authData, expectedAuthData) {
t.Errorf("expected authData '%v', got '%v'", expectedAuthData, authData)
}
}

View File

@ -6,15 +6,15 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package driver
package mysql
type mysqlResult struct {
affectedRows int64
insertID int64
insertId int64
}
func (res *mysqlResult) LastInsertId() (int64, error) {
return res.insertID, nil
return res.insertId, nil
}
func (res *mysqlResult) RowsAffected() (int64, error) {

View File

@ -6,7 +6,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package driver
package mysql
import (
"database/sql/driver"

View File

@ -0,0 +1,211 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql/driver"
"fmt"
"io"
"reflect"
"strconv"
)
type mysqlStmt struct {
mc *mysqlConn
id uint32
paramCount int
}
func (stmt *mysqlStmt) Close() error {
if stmt.mc == nil || stmt.mc.closed.IsSet() {
// driver.Stmt.Close can be called more than once, thus this function
// has to be idempotent.
// See also Issue #450 and golang/go#16019.
//errLog.Print(ErrInvalidConn)
return driver.ErrBadConn
}
err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
stmt.mc = nil
return err
}
func (stmt *mysqlStmt) NumInput() int {
return stmt.paramCount
}
func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
return converter{}
}
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
if stmt.mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := stmt.writeExecutePacket(args)
if err != nil {
return nil, stmt.mc.markBadConn(err)
}
mc := stmt.mc
mc.affectedRows = 0
mc.insertId = 0
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err != nil {
return nil, err
}
if resLen > 0 {
// Columns
if err = mc.readUntilEOF(); err != nil {
return nil, err
}
// Rows
if err := mc.readUntilEOF(); err != nil {
return nil, err
}
}
if err := mc.discardResults(); err != nil {
return nil, err
}
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, nil
}
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
return stmt.query(args)
}
func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
if stmt.mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := stmt.writeExecutePacket(args)
if err != nil {
return nil, stmt.mc.markBadConn(err)
}
mc := stmt.mc
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err != nil {
return nil, err
}
rows := new(binaryRows)
if resLen > 0 {
rows.mc = mc
rows.rs.columns, err = mc.readColumns(resLen)
} else {
rows.rs.done = true
switch err := rows.NextResultSet(); err {
case nil, io.EOF:
return rows, nil
default:
return nil, err
}
}
return rows, err
}
type converter struct{}
// ConvertValue mirrors the reference/default converter in database/sql/driver
// with _one_ exception. We support uint64 with their high bit and the default
// implementation does not. This function should be kept in sync with
// database/sql/driver defaultConverter.ConvertValue() except for that
// deliberate difference.
func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
if driver.IsValue(v) {
return v, nil
}
if vr, ok := v.(driver.Valuer); ok {
sv, err := callValuerValue(vr)
if err != nil {
return nil, err
}
if !driver.IsValue(sv) {
return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
}
return sv, nil
}
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Ptr:
// indirect pointers
if rv.IsNil() {
return nil, nil
} else {
return c.ConvertValue(rv.Elem().Interface())
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return rv.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
return int64(rv.Uint()), nil
case reflect.Uint64:
u64 := rv.Uint()
if u64 >= 1<<63 {
return strconv.FormatUint(u64, 10), nil
}
return int64(u64), nil
case reflect.Float32, reflect.Float64:
return rv.Float(), nil
case reflect.Bool:
return rv.Bool(), nil
case reflect.Slice:
ek := rv.Type().Elem().Kind()
if ek == reflect.Uint8 {
return rv.Bytes(), nil
}
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
case reflect.String:
return rv.String(), nil
}
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
}
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
// callValuerValue returns vr.Value(), with one exception:
// If vr.Value is an auto-generated method on a pointer type and the
// pointer is nil, it would panic at runtime in the panicwrap
// method. Treat it like nil instead.
//
// This is so people can implement driver.Value on value types and
// still use nil pointers to those types to mean nil/NULL, just like
// string/*string.
//
// This is an exact copy of the same-named unexported function from the
// database/sql package.
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
rv.IsNil() &&
rv.Type().Elem().Implements(valuerReflectType) {
return nil, nil
}
return vr.Value()
}

View File

@ -0,0 +1,126 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"testing"
)
func TestConvertDerivedString(t *testing.T) {
type derived string
output, err := converter{}.ConvertValue(derived("value"))
if err != nil {
t.Fatal("Derived string type not convertible", err)
}
if output != "value" {
t.Fatalf("Derived string type not converted, got %#v %T", output, output)
}
}
func TestConvertDerivedByteSlice(t *testing.T) {
type derived []uint8
output, err := converter{}.ConvertValue(derived("value"))
if err != nil {
t.Fatal("Byte slice not convertible", err)
}
if bytes.Compare(output.([]byte), []byte("value")) != 0 {
t.Fatalf("Byte slice not converted, got %#v %T", output, output)
}
}
func TestConvertDerivedUnsupportedSlice(t *testing.T) {
type derived []int
_, err := converter{}.ConvertValue(derived{1})
if err == nil || err.Error() != "unsupported type mysql.derived, a slice of int" {
t.Fatal("Unexpected error", err)
}
}
func TestConvertDerivedBool(t *testing.T) {
type derived bool
output, err := converter{}.ConvertValue(derived(true))
if err != nil {
t.Fatal("Derived bool type not convertible", err)
}
if output != true {
t.Fatalf("Derived bool type not converted, got %#v %T", output, output)
}
}
func TestConvertPointer(t *testing.T) {
str := "value"
output, err := converter{}.ConvertValue(&str)
if err != nil {
t.Fatal("Pointer type not convertible", err)
}
if output != "value" {
t.Fatalf("Pointer type not converted, got %#v %T", output, output)
}
}
func TestConvertSignedIntegers(t *testing.T) {
values := []interface{}{
int8(-42),
int16(-42),
int32(-42),
int64(-42),
int(-42),
}
for _, value := range values {
output, err := converter{}.ConvertValue(value)
if err != nil {
t.Fatalf("%T type not convertible %s", value, err)
}
if output != int64(-42) {
t.Fatalf("%T type not converted, got %#v %T", value, output, output)
}
}
}
func TestConvertUnsignedIntegers(t *testing.T) {
values := []interface{}{
uint8(42),
uint16(42),
uint32(42),
uint64(42),
uint(42),
}
for _, value := range values {
output, err := converter{}.ConvertValue(value)
if err != nil {
t.Fatalf("%T type not convertible %s", value, err)
}
if output != int64(42) {
t.Fatalf("%T type not converted, got %#v %T", value, output, output)
}
}
output, err := converter{}.ConvertValue(^uint64(0))
if err != nil {
t.Fatal("uint64 high-bit not convertible", err)
}
if output != "18446744073709551615" {
t.Fatalf("uint64 high-bit not converted, got %#v %T", output, output)
}
}

View File

@ -0,0 +1,31 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
type mysqlTx struct {
mc *mysqlConn
}
func (tx *mysqlTx) Commit() (err error) {
if tx.mc == nil || tx.mc.closed.IsSet() {
return ErrInvalidConn
}
err = tx.mc.exec("COMMIT")
tx.mc = nil
return
}
func (tx *mysqlTx) Rollback() (err error) {
if tx.mc == nil || tx.mc.closed.IsSet() {
return ErrInvalidConn
}
err = tx.mc.exec("ROLLBACK")
tx.mc = nil
return
}

View File

@ -6,7 +6,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package driver
package mysql
import (
"crypto/tls"

View File

@ -0,0 +1,334 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"database/sql"
"database/sql/driver"
"encoding/binary"
"testing"
"time"
)
func TestScanNullTime(t *testing.T) {
var scanTests = []struct {
in interface{}
error bool
valid bool
time time.Time
}{
{tDate, false, true, tDate},
{sDate, false, true, tDate},
{[]byte(sDate), false, true, tDate},
{tDateTime, false, true, tDateTime},
{sDateTime, false, true, tDateTime},
{[]byte(sDateTime), false, true, tDateTime},
{tDate0, false, true, tDate0},
{sDate0, false, true, tDate0},
{[]byte(sDate0), false, true, tDate0},
{sDateTime0, false, true, tDate0},
{[]byte(sDateTime0), false, true, tDate0},
{"", true, false, tDate0},
{"1234", true, false, tDate0},
{0, true, false, tDate0},
}
var nt = NullTime{}
var err error
for _, tst := range scanTests {
err = nt.Scan(tst.in)
if (err != nil) != tst.error {
t.Errorf("%v: expected error status %t, got %t", tst.in, tst.error, (err != nil))
}
if nt.Valid != tst.valid {
t.Errorf("%v: expected valid status %t, got %t", tst.in, tst.valid, nt.Valid)
}
if nt.Time != tst.time {
t.Errorf("%v: expected time %v, got %v", tst.in, tst.time, nt.Time)
}
}
}
func TestLengthEncodedInteger(t *testing.T) {
var integerTests = []struct {
num uint64
encoded []byte
}{
{0x0000000000000000, []byte{0x00}},
{0x0000000000000012, []byte{0x12}},
{0x00000000000000fa, []byte{0xfa}},
{0x0000000000000100, []byte{0xfc, 0x00, 0x01}},
{0x0000000000001234, []byte{0xfc, 0x34, 0x12}},
{0x000000000000ffff, []byte{0xfc, 0xff, 0xff}},
{0x0000000000010000, []byte{0xfd, 0x00, 0x00, 0x01}},
{0x0000000000123456, []byte{0xfd, 0x56, 0x34, 0x12}},
{0x0000000000ffffff, []byte{0xfd, 0xff, 0xff, 0xff}},
{0x0000000001000000, []byte{0xfe, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}},
{0x123456789abcdef0, []byte{0xfe, 0xf0, 0xde, 0xbc, 0x9a, 0x78, 0x56, 0x34, 0x12}},
{0xffffffffffffffff, []byte{0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
}
for _, tst := range integerTests {
num, isNull, numLen := readLengthEncodedInteger(tst.encoded)
if isNull {
t.Errorf("%x: expected %d, got NULL", tst.encoded, tst.num)
}
if num != tst.num {
t.Errorf("%x: expected %d, got %d", tst.encoded, tst.num, num)
}
if numLen != len(tst.encoded) {
t.Errorf("%x: expected size %d, got %d", tst.encoded, len(tst.encoded), numLen)
}
encoded := appendLengthEncodedInteger(nil, num)
if !bytes.Equal(encoded, tst.encoded) {
t.Errorf("%v: expected %x, got %x", num, tst.encoded, encoded)
}
}
}
func TestFormatBinaryDateTime(t *testing.T) {
rawDate := [11]byte{}
binary.LittleEndian.PutUint16(rawDate[:2], 1978) // years
rawDate[2] = 12 // months
rawDate[3] = 30 // days
rawDate[4] = 15 // hours
rawDate[5] = 46 // minutes
rawDate[6] = 23 // seconds
binary.LittleEndian.PutUint32(rawDate[7:], 987654) // microseconds
expect := func(expected string, inlen, outlen uint8) {
actual, _ := formatBinaryDateTime(rawDate[:inlen], outlen)
bytes, ok := actual.([]byte)
if !ok {
t.Errorf("formatBinaryDateTime must return []byte, was %T", actual)
}
if string(bytes) != expected {
t.Errorf(
"expected %q, got %q for length in %d, out %d",
expected, actual, inlen, outlen,
)
}
}
expect("0000-00-00", 0, 10)
expect("0000-00-00 00:00:00", 0, 19)
expect("1978-12-30", 4, 10)
expect("1978-12-30 15:46:23", 7, 19)
expect("1978-12-30 15:46:23.987654", 11, 26)
}
func TestFormatBinaryTime(t *testing.T) {
expect := func(expected string, src []byte, outlen uint8) {
actual, _ := formatBinaryTime(src, outlen)
bytes, ok := actual.([]byte)
if !ok {
t.Errorf("formatBinaryDateTime must return []byte, was %T", actual)
}
if string(bytes) != expected {
t.Errorf(
"expected %q, got %q for src=%q and outlen=%d",
expected, actual, src, outlen)
}
}
// binary format:
// sign (0: positive, 1: negative), days(4), hours, minutes, seconds, micro(4)
// Zeros
expect("00:00:00", []byte{}, 8)
expect("00:00:00.0", []byte{}, 10)
expect("00:00:00.000000", []byte{}, 15)
// Without micro(4)
expect("12:34:56", []byte{0, 0, 0, 0, 0, 12, 34, 56}, 8)
expect("-12:34:56", []byte{1, 0, 0, 0, 0, 12, 34, 56}, 8)
expect("12:34:56.00", []byte{0, 0, 0, 0, 0, 12, 34, 56}, 11)
expect("24:34:56", []byte{0, 1, 0, 0, 0, 0, 34, 56}, 8)
expect("-99:34:56", []byte{1, 4, 0, 0, 0, 3, 34, 56}, 8)
expect("103079215103:34:56", []byte{0, 255, 255, 255, 255, 23, 34, 56}, 8)
// With micro(4)
expect("12:34:56.00", []byte{0, 0, 0, 0, 0, 12, 34, 56, 99, 0, 0, 0}, 11)
expect("12:34:56.000099", []byte{0, 0, 0, 0, 0, 12, 34, 56, 99, 0, 0, 0}, 15)
}
func TestEscapeBackslash(t *testing.T) {
expect := func(expected, value string) {
actual := string(escapeBytesBackslash([]byte{}, []byte(value)))
if actual != expected {
t.Errorf(
"expected %s, got %s",
expected, actual,
)
}
actual = string(escapeStringBackslash([]byte{}, value))
if actual != expected {
t.Errorf(
"expected %s, got %s",
expected, actual,
)
}
}
expect("foo\\0bar", "foo\x00bar")
expect("foo\\nbar", "foo\nbar")
expect("foo\\rbar", "foo\rbar")
expect("foo\\Zbar", "foo\x1abar")
expect("foo\\\"bar", "foo\"bar")
expect("foo\\\\bar", "foo\\bar")
expect("foo\\'bar", "foo'bar")
}
func TestEscapeQuotes(t *testing.T) {
expect := func(expected, value string) {
actual := string(escapeBytesQuotes([]byte{}, []byte(value)))
if actual != expected {
t.Errorf(
"expected %s, got %s",
expected, actual,
)
}
actual = string(escapeStringQuotes([]byte{}, value))
if actual != expected {
t.Errorf(
"expected %s, got %s",
expected, actual,
)
}
}
expect("foo\x00bar", "foo\x00bar") // not affected
expect("foo\nbar", "foo\nbar") // not affected
expect("foo\rbar", "foo\rbar") // not affected
expect("foo\x1abar", "foo\x1abar") // not affected
expect("foo''bar", "foo'bar") // affected
expect("foo\"bar", "foo\"bar") // not affected
}
func TestAtomicBool(t *testing.T) {
var ab atomicBool
if ab.IsSet() {
t.Fatal("Expected value to be false")
}
ab.Set(true)
if ab.value != 1 {
t.Fatal("Set(true) did not set value to 1")
}
if !ab.IsSet() {
t.Fatal("Expected value to be true")
}
ab.Set(true)
if !ab.IsSet() {
t.Fatal("Expected value to be true")
}
ab.Set(false)
if ab.value != 0 {
t.Fatal("Set(false) did not set value to 0")
}
if ab.IsSet() {
t.Fatal("Expected value to be false")
}
ab.Set(false)
if ab.IsSet() {
t.Fatal("Expected value to be false")
}
if ab.TrySet(false) {
t.Fatal("Expected TrySet(false) to fail")
}
if !ab.TrySet(true) {
t.Fatal("Expected TrySet(true) to succeed")
}
if !ab.IsSet() {
t.Fatal("Expected value to be true")
}
ab.Set(true)
if !ab.IsSet() {
t.Fatal("Expected value to be true")
}
if ab.TrySet(true) {
t.Fatal("Expected TrySet(true) to fail")
}
if !ab.TrySet(false) {
t.Fatal("Expected TrySet(false) to succeed")
}
if ab.IsSet() {
t.Fatal("Expected value to be false")
}
ab._noCopy.Lock() // we've "tested" it ¯\_(ツ)_/¯
}
func TestAtomicError(t *testing.T) {
var ae atomicError
if ae.Value() != nil {
t.Fatal("Expected value to be nil")
}
ae.Set(ErrMalformPkt)
if v := ae.Value(); v != ErrMalformPkt {
if v == nil {
t.Fatal("Value is still nil")
}
t.Fatal("Error did not match")
}
ae.Set(ErrPktSync)
if ae.Value() == ErrMalformPkt {
t.Fatal("Error still matches old error")
}
if v := ae.Value(); v != ErrPktSync {
t.Fatal("Error did not match")
}
}
func TestIsolationLevelMapping(t *testing.T) {
data := []struct {
level driver.IsolationLevel
expected string
}{
{
level: driver.IsolationLevel(sql.LevelReadCommitted),
expected: "READ COMMITTED",
},
{
level: driver.IsolationLevel(sql.LevelRepeatableRead),
expected: "REPEATABLE READ",
},
{
level: driver.IsolationLevel(sql.LevelReadUncommitted),
expected: "READ UNCOMMITTED",
},
{
level: driver.IsolationLevel(sql.LevelSerializable),
expected: "SERIALIZABLE",
},
}
for i, td := range data {
if actual, err := mapIsolationLevel(td.level); actual != td.expected || err != nil {
t.Fatal(i, td.expected, actual, err)
}
}
// check unsupported mapping
expectedErr := "mysql: unsupported isolation level: 7"
actual, err := mapIsolationLevel(driver.IsolationLevel(sql.LevelLinearizable))
if actual != "" || err == nil {
t.Fatal("Expected error on unsupported isolation level")
}
if err.Error() != expectedErr {
t.Fatalf("Expected error to be %q, got %q", expectedErr, err)
}
}

View File

@ -6,12 +6,12 @@ import (
"os"
"github.com/localhots/bocadillo/buffer"
"github.com/localhots/bocadillo/mysql/driver"
"github.com/localhots/bocadillo/mysql/slave/internal/mysql"
)
// Conn is a slave connection used to issue a binlog dump command.
type Conn struct {
conn *driver.ExtendedConn
conn *mysql.ExtendedConn
conf Config
}
@ -49,12 +49,17 @@ func Connect(dsn string, conf Config) (*Conn, error) {
conf.Offset = 4
}
conn, err := driver.NewExtendedConnection(dsn)
conn, err := (mysql.MySQLDriver{}).Open(dsn)
if err != nil {
return nil, err
}
return &Conn{conn: conn, conf: conf}, nil
extconn, err := mysql.ExtendConn(conn)
if err != nil {
return nil, err
}
return &Conn{conn: extconn, conf: conf}, nil
}
// ReadPacket reads next packet from the server and processes the first status