Add thread pool implementation
This commit is contained in:
parent
c70d409479
commit
611b508425
|
@ -0,0 +1,19 @@
|
|||
Copyright (c) 2018 Gregory Eremin
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
|
@ -0,0 +1,63 @@
|
|||
package threadpool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ThreadPool implements a thread pool model. It allocates a pool of threads
|
||||
// ready to perform tasks concurrently.
|
||||
type ThreadPool struct {
|
||||
Logger interface {
|
||||
Printf(f string, args ...interface{})
|
||||
}
|
||||
queue chan func()
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// New creates a thread pool with a given number of workers.
|
||||
func New(size int) *ThreadPool {
|
||||
tp := &ThreadPool{
|
||||
Logger: &log.Logger{},
|
||||
queue: make(chan func(), size),
|
||||
}
|
||||
|
||||
tp.wg.Add(size)
|
||||
for i := 0; i < size; i++ {
|
||||
go tp.worker()
|
||||
}
|
||||
|
||||
return tp
|
||||
}
|
||||
|
||||
// Enqueue adds a task to queue.
|
||||
func (tp *ThreadPool) Enqueue(ctx context.Context, task func()) {
|
||||
select {
|
||||
case tp.queue <- task:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
// Close waits for all currently accepted tasks to be processed and returns.
|
||||
// Attempts to enqueue a task after calling Close would result in a panic.
|
||||
func (tp *ThreadPool) Close() {
|
||||
close(tp.queue)
|
||||
tp.wg.Wait()
|
||||
}
|
||||
|
||||
func (tp *ThreadPool) worker() {
|
||||
defer tp.wg.Done()
|
||||
for task := range tp.queue {
|
||||
tp.perform(task)
|
||||
}
|
||||
}
|
||||
|
||||
func (tp *ThreadPool) perform(task func()) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
tp.Logger.Printf("Thread pool task recovered from panic: %v", err)
|
||||
}
|
||||
}()
|
||||
task()
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
package threadpool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestThreadPool(t *testing.T) {
|
||||
const n = 100
|
||||
var s int64
|
||||
ctx := context.Background()
|
||||
|
||||
pool := New(n / 10)
|
||||
for i := 0; i < n; i++ {
|
||||
pool.Enqueue(ctx, func() { atomic.AddInt64(&s, 1) })
|
||||
}
|
||||
pool.Close()
|
||||
|
||||
if s != n {
|
||||
t.Errorf("Thread pool result doesn't match: expected %d, got %d", n, s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestThreadPoolPanicHandling(t *testing.T) {
|
||||
logger := &bufLogger{buf: bytes.NewBuffer(nil)}
|
||||
ctx := context.Background()
|
||||
|
||||
pool := New(1)
|
||||
pool.Logger = logger
|
||||
pool.Enqueue(ctx, func() { panic("oh no!") })
|
||||
pool.Close()
|
||||
|
||||
out := logger.buf.String()
|
||||
exp := "Thread pool task recovered from panic: oh no!"
|
||||
if out != exp {
|
||||
t.Errorf("Expected logger to receive message %q, got %q", exp, out)
|
||||
}
|
||||
}
|
||||
|
||||
type bufLogger struct {
|
||||
buf *bytes.Buffer
|
||||
}
|
||||
|
||||
func (b *bufLogger) Printf(f string, args ...interface{}) {
|
||||
b.buf.WriteString(fmt.Sprintf(f, args...))
|
||||
}
|
Loading…
Reference in New Issue