Add thread pool implementation

This commit is contained in:
2018-06-17 12:57:47 +02:00
parent c70d409479
commit 611b508425
4 changed files with 132 additions and 0 deletions
+63
View File
@@ -0,0 +1,63 @@
package threadpool
import (
"context"
"log"
"sync"
)
// ThreadPool implements a thread pool model. It allocates a pool of threads
// ready to perform tasks concurrently.
type ThreadPool struct {
Logger interface {
Printf(f string, args ...interface{})
}
queue chan func()
wg sync.WaitGroup
}
// New creates a thread pool with a given number of workers.
func New(size int) *ThreadPool {
tp := &ThreadPool{
Logger: &log.Logger{},
queue: make(chan func(), size),
}
tp.wg.Add(size)
for i := 0; i < size; i++ {
go tp.worker()
}
return tp
}
// Enqueue adds a task to queue.
func (tp *ThreadPool) Enqueue(ctx context.Context, task func()) {
select {
case tp.queue <- task:
case <-ctx.Done():
}
}
// Close waits for all currently accepted tasks to be processed and returns.
// Attempts to enqueue a task after calling Close would result in a panic.
func (tp *ThreadPool) Close() {
close(tp.queue)
tp.wg.Wait()
}
func (tp *ThreadPool) worker() {
defer tp.wg.Done()
for task := range tp.queue {
tp.perform(task)
}
}
func (tp *ThreadPool) perform(task func()) {
defer func() {
if err := recover(); err != nil {
tp.Logger.Printf("Thread pool task recovered from panic: %v", err)
}
}()
task()
}
+49
View File
@@ -0,0 +1,49 @@
package threadpool
import (
"bytes"
"context"
"fmt"
"sync/atomic"
"testing"
)
func TestThreadPool(t *testing.T) {
const n = 100
var s int64
ctx := context.Background()
pool := New(n / 10)
for i := 0; i < n; i++ {
pool.Enqueue(ctx, func() { atomic.AddInt64(&s, 1) })
}
pool.Close()
if s != n {
t.Errorf("Thread pool result doesn't match: expected %d, got %d", n, s)
}
}
func TestThreadPoolPanicHandling(t *testing.T) {
logger := &bufLogger{buf: bytes.NewBuffer(nil)}
ctx := context.Background()
pool := New(1)
pool.Logger = logger
pool.Enqueue(ctx, func() { panic("oh no!") })
pool.Close()
out := logger.buf.String()
exp := "Thread pool task recovered from panic: oh no!"
if out != exp {
t.Errorf("Expected logger to receive message %q, got %q", exp, out)
}
}
type bufLogger struct {
buf *bytes.Buffer
}
func (b *bufLogger) Printf(f string, args ...interface{}) {
b.buf.WriteString(fmt.Sprintf(f, args...))
}