Add thread pool implementation
This commit is contained in:
parent
c70d409479
commit
611b508425
|
@ -0,0 +1,19 @@
|
||||||
|
Copyright (c) 2018 Gregory Eremin
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
|
@ -0,0 +1,63 @@
|
||||||
|
package threadpool
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ThreadPool implements a thread pool model. It allocates a pool of threads
|
||||||
|
// ready to perform tasks concurrently.
|
||||||
|
type ThreadPool struct {
|
||||||
|
Logger interface {
|
||||||
|
Printf(f string, args ...interface{})
|
||||||
|
}
|
||||||
|
queue chan func()
|
||||||
|
wg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a thread pool with a given number of workers.
|
||||||
|
func New(size int) *ThreadPool {
|
||||||
|
tp := &ThreadPool{
|
||||||
|
Logger: &log.Logger{},
|
||||||
|
queue: make(chan func(), size),
|
||||||
|
}
|
||||||
|
|
||||||
|
tp.wg.Add(size)
|
||||||
|
for i := 0; i < size; i++ {
|
||||||
|
go tp.worker()
|
||||||
|
}
|
||||||
|
|
||||||
|
return tp
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enqueue adds a task to queue.
|
||||||
|
func (tp *ThreadPool) Enqueue(ctx context.Context, task func()) {
|
||||||
|
select {
|
||||||
|
case tp.queue <- task:
|
||||||
|
case <-ctx.Done():
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close waits for all currently accepted tasks to be processed and returns.
|
||||||
|
// Attempts to enqueue a task after calling Close would result in a panic.
|
||||||
|
func (tp *ThreadPool) Close() {
|
||||||
|
close(tp.queue)
|
||||||
|
tp.wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tp *ThreadPool) worker() {
|
||||||
|
defer tp.wg.Done()
|
||||||
|
for task := range tp.queue {
|
||||||
|
tp.perform(task)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tp *ThreadPool) perform(task func()) {
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
tp.Logger.Printf("Thread pool task recovered from panic: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
task()
|
||||||
|
}
|
|
@ -0,0 +1,49 @@
|
||||||
|
package threadpool
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestThreadPool(t *testing.T) {
|
||||||
|
const n = 100
|
||||||
|
var s int64
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
pool := New(n / 10)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
pool.Enqueue(ctx, func() { atomic.AddInt64(&s, 1) })
|
||||||
|
}
|
||||||
|
pool.Close()
|
||||||
|
|
||||||
|
if s != n {
|
||||||
|
t.Errorf("Thread pool result doesn't match: expected %d, got %d", n, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThreadPoolPanicHandling(t *testing.T) {
|
||||||
|
logger := &bufLogger{buf: bytes.NewBuffer(nil)}
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
pool := New(1)
|
||||||
|
pool.Logger = logger
|
||||||
|
pool.Enqueue(ctx, func() { panic("oh no!") })
|
||||||
|
pool.Close()
|
||||||
|
|
||||||
|
out := logger.buf.String()
|
||||||
|
exp := "Thread pool task recovered from panic: oh no!"
|
||||||
|
if out != exp {
|
||||||
|
t.Errorf("Expected logger to receive message %q, got %q", exp, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type bufLogger struct {
|
||||||
|
buf *bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *bufLogger) Printf(f string, args ...interface{}) {
|
||||||
|
b.buf.WriteString(fmt.Sprintf(f, args...))
|
||||||
|
}
|
Loading…
Reference in New Issue