Add thread pool implementation
This commit is contained in:
@@ -0,0 +1,63 @@
|
||||
package threadpool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ThreadPool implements a thread pool model. It allocates a pool of threads
|
||||
// ready to perform tasks concurrently.
|
||||
type ThreadPool struct {
|
||||
Logger interface {
|
||||
Printf(f string, args ...interface{})
|
||||
}
|
||||
queue chan func()
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// New creates a thread pool with a given number of workers.
|
||||
func New(size int) *ThreadPool {
|
||||
tp := &ThreadPool{
|
||||
Logger: &log.Logger{},
|
||||
queue: make(chan func(), size),
|
||||
}
|
||||
|
||||
tp.wg.Add(size)
|
||||
for i := 0; i < size; i++ {
|
||||
go tp.worker()
|
||||
}
|
||||
|
||||
return tp
|
||||
}
|
||||
|
||||
// Enqueue adds a task to queue.
|
||||
func (tp *ThreadPool) Enqueue(ctx context.Context, task func()) {
|
||||
select {
|
||||
case tp.queue <- task:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
// Close waits for all currently accepted tasks to be processed and returns.
|
||||
// Attempts to enqueue a task after calling Close would result in a panic.
|
||||
func (tp *ThreadPool) Close() {
|
||||
close(tp.queue)
|
||||
tp.wg.Wait()
|
||||
}
|
||||
|
||||
func (tp *ThreadPool) worker() {
|
||||
defer tp.wg.Done()
|
||||
for task := range tp.queue {
|
||||
tp.perform(task)
|
||||
}
|
||||
}
|
||||
|
||||
func (tp *ThreadPool) perform(task func()) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
tp.Logger.Printf("Thread pool task recovered from panic: %v", err)
|
||||
}
|
||||
}()
|
||||
task()
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package threadpool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestThreadPool(t *testing.T) {
|
||||
const n = 100
|
||||
var s int64
|
||||
ctx := context.Background()
|
||||
|
||||
pool := New(n / 10)
|
||||
for i := 0; i < n; i++ {
|
||||
pool.Enqueue(ctx, func() { atomic.AddInt64(&s, 1) })
|
||||
}
|
||||
pool.Close()
|
||||
|
||||
if s != n {
|
||||
t.Errorf("Thread pool result doesn't match: expected %d, got %d", n, s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestThreadPoolPanicHandling(t *testing.T) {
|
||||
logger := &bufLogger{buf: bytes.NewBuffer(nil)}
|
||||
ctx := context.Background()
|
||||
|
||||
pool := New(1)
|
||||
pool.Logger = logger
|
||||
pool.Enqueue(ctx, func() { panic("oh no!") })
|
||||
pool.Close()
|
||||
|
||||
out := logger.buf.String()
|
||||
exp := "Thread pool task recovered from panic: oh no!"
|
||||
if out != exp {
|
||||
t.Errorf("Expected logger to receive message %q, got %q", exp, out)
|
||||
}
|
||||
}
|
||||
|
||||
type bufLogger struct {
|
||||
buf *bytes.Buffer
|
||||
}
|
||||
|
||||
func (b *bufLogger) Printf(f string, args ...interface{}) {
|
||||
b.buf.WriteString(fmt.Sprintf(f, args...))
|
||||
}
|
||||
Reference in New Issue
Block a user