diff --git a/cmd/main.go b/cmd/main.go index f838676..69118ef 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -1,14 +1,19 @@ package main import ( + "context" "flag" "fmt" "log" + "net" "os" + "os/signal" + "syscall" "time" "github.com/localhots/bocadillo/reader" "github.com/localhots/bocadillo/reader/slave" + "github.com/pkg/errors" ) func main() { @@ -31,21 +36,40 @@ func main() { log.Fatalf("Failed to create reader: %v", err) } - // for i := 0; i < 100; i++ { + done := handleShutdown() + ctx := context.Background() for { - evt, err := reader.ReadEvent() - if err != nil { - log.Fatalf("Failed to read event: %v", err) - } - ts := time.Unix(int64(evt.Header.Timestamp), 0).Format(time.RFC3339) - log.Printf("Event received: %s %s, %d\n", evt.Header.Type.String(), ts, evt.Header.NextOffset) - - if evt.Table != nil { - _, err := evt.DecodeRows() + ctx, cancel := context.WithTimeout(ctx, 1*time.Second) + defer cancel() + select { + case <-done: + log.Println("Closing reader") + err := reader.Close() if err != nil { - log.Fatalf("Failed to parse rows event: %v", err) + log.Fatalf("Failed to close reader: %v", err) + } + return + default: + evt, err := reader.ReadEvent(ctx) + if err != nil { + if isTimeout(err) { + log.Println("Event read timeout") + continue + } + log.Fatalf("Failed to read event: %v", err) + } + + ts := time.Unix(int64(evt.Header.Timestamp), 0).Format(time.RFC3339) + log.Printf("Event received: %s %s, %d\n", evt.Header.Type.String(), ts, evt.Header.NextOffset) + + if evt.Table != nil { + _, err := evt.DecodeRows() + if err != nil { + log.Fatalf("Failed to parse rows event: %v", err) + } } } + cancel() } } @@ -56,3 +80,27 @@ func validate(cond bool, msg string) { os.Exit(2) } } + +func handleShutdown() <-chan struct{} { + sig := make(chan os.Signal, 1) + done := make(chan struct{}) + signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sig + log.Println("Shutdown requested") + close(done) + }() + return done +} + +func isTimeout(err error) bool { + if err == nil { + return false + } + err = errors.Cause(err) + if err == context.DeadlineExceeded || err == context.Canceled { + return true + } + ne, ok := err.(*net.OpError) + return ok && ne.Timeout() +} diff --git a/reader/enhanced_reader.go b/reader/enhanced_reader.go index fc3d754..3b377ad 100644 --- a/reader/enhanced_reader.go +++ b/reader/enhanced_reader.go @@ -59,8 +59,8 @@ func (r *EnhancedReader) WhitelistTables(database string, tables ...string) erro } // ReadEvent reads next event from the binary log. -func (r *EnhancedReader) ReadEvent() (*Event, error) { - evt, err := r.reader.ReadEvent() +func (r *EnhancedReader) ReadEvent(ctx context.Context) (*Event, error) { + evt, err := r.reader.ReadEvent(ctx) if err != nil { return nil, err } @@ -78,30 +78,8 @@ func (r *EnhancedReader) ReadEvent() (*Event, error) { // NextRowsEvent returns the next rows event for a whitelisted table. It blocks // until next event is received or context is cancelled. func (r *EnhancedReader) NextRowsEvent(ctx context.Context) (*EnhancedRowsEvent, error) { - evtch := make(chan *EnhancedRowsEvent) - errch := make(chan error) - go func() { - evt, err := r.nextRowsEvent() - if err != nil { - errch <- err - } else { - evtch <- evt - } - }() - - select { - case evt := <-evtch: - return evt, nil - case err := <-errch: - return nil, err - case <-ctx.Done(): - return nil, nil - } -} - -func (r *EnhancedReader) nextRowsEvent() (*EnhancedRowsEvent, error) { for { - evt, err := r.reader.ReadEvent() + evt, err := r.reader.ReadEvent(ctx) if err != nil { return nil, err } diff --git a/reader/reader.go b/reader/reader.go index 3e6eac1..0815237 100644 --- a/reader/reader.go +++ b/reader/reader.go @@ -1,6 +1,8 @@ package reader import ( + "context" + "github.com/juju/errors" "github.com/localhots/bocadillo/binlog" "github.com/localhots/bocadillo/reader/slave" @@ -61,8 +63,8 @@ func New(dsn string, sc slave.Config) (*Reader, error) { } // ReadEvent reads next event from the binary log. -func (r *Reader) ReadEvent() (*Event, error) { - connBuff, err := r.conn.ReadPacket() +func (r *Reader) ReadEvent(ctx context.Context) (*Event, error) { + connBuff, err := r.conn.ReadPacket(ctx) if err != nil { return nil, errors.Annotate(err, "read next event") } diff --git a/reader/slave/slave_conn.go b/reader/slave/slave_conn.go index 04dfa52..ffbfea6 100644 --- a/reader/slave/slave_conn.go +++ b/reader/slave/slave_conn.go @@ -1,6 +1,7 @@ package slave import ( + "context" "database/sql/driver" "fmt" "io" @@ -65,8 +66,8 @@ func Connect(dsn string, conf Config) (*Conn, error) { // ReadPacket reads next packet from the server and processes the first status // byte. -func (c *Conn) ReadPacket() ([]byte, error) { - data, err := c.conn.ReadPacket() +func (c *Conn) ReadPacket(ctx context.Context) ([]byte, error) { + data, err := c.conn.ReadPacket(ctx) if err != nil { return nil, err }