字节开源Go协程池 gopool

2024-04-17 18:50:25 浏览数 (1)

字节开源Go协程池gopool

Java 中线程池,也支持自定义线程池,为啥 Golang 官方没有提供协程池的实现?Golang 官方偏向轻量级的并发, 希望通过 go func() 解决问题。

问题

  • 协程数量不可控,在代码并发处理过程中,一不小心 ,go 出了数万个协程, goruntine 虽然轻量级的执行流程,但是不限制的大量创建 goruntine ,对系统性能影响会很大,一个 goruntine 初始栈内存为 2KB,如果新建过多协程,过多 goruntine,内存会达到G级别,如何让协程数可控,是一个问题。
  • 协程泄漏问题,如果协程的bug,导致协程无法被回收,日积月累,可能导致程序崩溃,需要有工具避免协程泄漏问题。

先写一个协程池

一般来说,用 waitGroup 结合 channel ,可以实现一个协程池的功能。一个协程池,一般要具有如下三个功能:

  • 提交任务
  • 启动协程
  • 等待协程执行结束
代码语言:javascript复制
package main

import (
    "fmt"
    "sync"
    "testing"
)

// 任务结构体
type Task struct {
    ID int
    // 任务
    Job func()
}

// 协程池结构体
type Pool struct {
    // 任务通道
    taskChan chan Task
    // 工作协程数量
    workerCount int
    // 等待组
    wg sync.WaitGroup
}

// 创建协程池
func NewPool(workerCount int) *Pool {
    workChannel := make(chan Task, workerCount)
    return &Pool{
       taskChan:    workChannel,
       workerCount: workerCount,
       wg:          sync.WaitGroup{},
    }
}

// 向协程池提交任务
func (p *Pool) SubmitTask(task Task) {
    p.taskChan <- task
    p.wg.Add(1)
}

// 启动工作协程
func (p *Pool) StartWorkers() {
    for i := 0; i < p.workerCount; i   {
       go p.worker()
    }
}

// 工作协程
func (p *Pool) worker() {
    for task := range p.taskChan {
       defer p.wg.Done()
       fmt.Printf("Worker received task %dn", task.ID)
       task.Job()
       fmt.Printf("Worker completed task %dn", task.ID)
    }
}

func TestThreadPool(t *testing.T) {
    // 创建一个协程池,设置工作协程数量为 5
    pool := NewPool(5)

    // 提交任务到协程池
    for i := 1; i < 5; i   {
       task := Task{
          ID: i,
          Job: func() {
             fmt.Printf("Task %d is runningn", i)
          },
       }
       pool.SubmitTask(task)
    }

    // 启动工作协程
    pool.StartWorkers()

    // 等待所有任务完成
    pool.wg.Wait()
}

执行结果:

代码语言:javascript复制
=== RUN   TestThreadPool
Worker received task 1
Task 5 is running
Worker completed task 1
Worker received task 4
Task 5 is running
Worker completed task 4
Worker received task 2
Task 5 is running
Worker completed task 2
Worker received task 3
Task 5 is running
Worker completed task 3
--- PASS: TestThreadPool (0.00s)
PASS

优化一下上面的代码:

  • 将提交任务和协程池启动放一块
  • 引入 ctx, 其中某个协程错误,取消整个协程。
代码语言:javascript复制
package utils

import (
    "context"
    "sync"
)

// Semaphore 使用waitGroup和channel实现并发同时控制最大并发量
// 参考golang.org/x/sync.errgroup实现返回err功能
type Semaphore struct {
    c       chan struct{}
    wg      sync.WaitGroup
    cancel  func()
    errOnce sync.Once
    err     error
}

func NewSemaphore(maxSize int) *Semaphore {
    return &Semaphore{
       c: make(chan struct{}, maxSize),
    }
}

func NewSemaphoreWithContext(ctx context.Context, maxSize int) (*Semaphore, context.Context) {
    ctx, cancel := context.WithCancel(ctx)
    return &Semaphore{
       c:      make(chan struct{}, maxSize),
       cancel: cancel,
    }, ctx
}

func (s *Semaphore) Go(f func() error) {
    s.wg.Add(1)
    s.c <- struct{}{}
    go func() {
       defer func() {
          if err := recover(); err != nil {
          }
       }()
       defer func() {
          <-s.c
          s.wg.Done()
       }()
       if err := f(); err != nil {
          s.errOnce.Do(func() {
             s.err = err
             if s.cancel != nil {
                s.cancel()
             }
          })
       }
    }()
}

func (s *Semaphore) Wait() error {
    s.wg.Wait()
    if s.cancel != nil {
       s.cancel()
    }
    return s.err
}

测试代码:

代码语言:javascript复制
package utils

import (
    "math"
    "testing"
    "time"

    "github.com/bmizerany/assert"
)

func sleep1s() error {
    time.Sleep(time.Second)
    return nil
}

func TestSemaphore(t *testing.T) {
    // 最大并发 >= 执行任务数量
    sema := NewSemaphore(4)
    now := time.Now()
    for i := 0; i < 4; i   {
       sema.Go(sleep1s)
    }
    err := sema.Wait()
    assert.Equal(t, nil, err)
    sec := math.Round(time.Since(now).Seconds())
    assert.Equal(t, 1, int(sec))

    // 设置最大并发为2
    sema = NewSemaphore(2)
    now = time.Now()
    for i := 0; i < 4; i   {
       sema.Go(sleep1s)
    }
    err = sema.Wait()
    assert.Equal(t, nil, err)
    sec = math.Round(time.Since(now).Seconds())
    assert.Equal(t, 2, int(sec))
}

sync.pool

https://github.com/bytedance/gopkg/tree/develop/util/gopool

原理简介

原理和 Java 线程池原理有点类似

工作流程

  • 工作协程(workerPool):可以设置协程池中的工作协程数(cap)。
代码语言:javascript复制
// 如果没使用 NewPool方法创建协程池 会默认 init 建一个 default pool
func init() {
    initMetrics()
    defaultPool = NewPool("gopool.DefaultPool", 10000, NewConfig())
}

func NewPool(name string, cap int32, config *Config) Pool {
    p := &pool{
       name:   name,
       cap:    cap,
       config: config,
    }
    return p
}
  • 任务队列(taskPool):用于存放待执行任务的队列,当核心线程都在执行任务时,新的任务会被放入任务队列中等待。
代码语言:javascript复制
var taskPool sync.Pool

func init() {
    taskPool.New = newTask
}

func newTask() interface{} {
    return &task{}
}

工作流程如下:

  1. 当任务到达时,会将任务加入到工作队列中队尾。
  2. 如果 task 任务数量大于阈值,阈值默认是1且目前的 worker(工作协程)数量小于上限 p.cap 或者没有工作协程,会立即执行任务。
代码语言:javascript复制
func (p *pool) CtxGo(ctx context.Context, f func()) {
    t := taskPool.Get().(*task)
    t.ctx = ctx
    t.f = f
    p.taskLock.Lock()
    if p.taskHead == nil {
       p.taskHead = t
       p.taskTail = t
    } else {
       p.taskTail.next = t
       p.taskTail = t
    }
    p.taskLock.Unlock()
    atomic.AddInt32(&p.taskCount, 1)
    // 如果 pool 已经被关闭了,就 panic
    if atomic.LoadInt32(&p.closed) == 1 {
       panic("use closed pool")
    }
    // 满足以下两个条件:
    // 1. task 数量大于阈值
    // 2. 目前的 worker 数量小于上限 p.cap(工作协程数)
    // 或者目前没有 worker
    if (atomic.LoadInt32(&p.taskCount) >= p.config.ScaleThreshold && p.WorkerCount() < atomic.LoadInt32(&p.cap)) || p.WorkerCount() == 0 {
       p.incWorkerCount()
       w := workerPool.Get().(*worker)
       w.pool = p
       w.run()
    }
}
  1. 通过 for 循环是从工作队列中取队头任务,然后移动队头指向链表下一节点,执行任务,任务完成后做清理,直至任务队列中没有任务需要执行,协程 return
代码语言:javascript复制
func (w *worker) run() {
    go func() {
       for {
          //select {
          //case <-w.stopChan:
          // w.close()
          // return
          //default:
          var t *task
          w.pool.taskLock.Lock()
          if w.pool.taskHead != nil {
             t = w.pool.taskHead
             w.pool.taskHead = w.pool.taskHead.next
             atomic.AddInt32(&w.pool.taskCount, -1)
          }
          if t == nil {
             // 如果没有任务要做了,就释放资源,退出
             w.close()
             w.pool.taskLock.Unlock()
             w.Recycle()
             return
          }
          w.pool.taskLock.Unlock()
          func() {
             defer func() {
                if r := recover(); r != nil {
                   logs.CtxFatal(t.ctx, "GOPOOL: panic in pool: %s: %v: %s", w.pool.name, r, debug.Stack())
                   if w.pool.config.EnablePanicMetrics {
                      panicMetricsClient.EmitCounter(panicKey, 1, metrics.T{Name: "pool", Value: w.pool.name})
                   }
                   w.pool.panicHandler(t.ctx, r)
                }
             }()
             t.f()
          }()
          t.Recycle()
          //}
       }
    }()
}

可能会问,为啥要写个死循环去遍历,假设不写 for 循环, 如果一个任务,run 一次,就创建一个工作协程,这个开销成本比较高,通过循环变了任务队列的方式,不断去取,可以避免创建一些不必要的工作协程。

举个例子,假设有 4个任务,任务1 执行,开启了一个工作协程1, 任务2 执行,开启了一个工作协程2,任务3执行,开启了一个工作协程3, 任务4来了,此时工作协程1执行完毕,去取任务4执行。这样的话,4个任务,只需要3个工作协程,如果工作协程执行足够快,工作协程数会更少。

实践

场景:捞取2个月的数据,然后导出 捞取一个月的动账明细数据,然后进行导出,原流程是一个开始时间,一个结束时间,每次捞取10分钟的数据,每次加10分钟,循环处理。改为并发流程后,先将时间按10分钟分段,每一段做为一个任务,交给协程池去跑。最后再对结果进行汇总。项目实测,导出效率提升10倍以上。

参考资料

https://github.com/bytedance/gopkg/tree/develop/util/gopool

0 人点赞