用设计模式包装协程池

2023-08-10 18:47:46 浏览数 (1)

协程池是一种常见的并发编程模式,它可以在多个协程之间共享一组固定数量的协程,以避免创建过多的协程导致系统资源耗尽。在 Go 语言中,协程池通常使用 sync.WaitGroup 和 chan 类型来实现。

在本文中,我们将介绍一种用户设计模式,即封装协程池。该模式可以将协程池的实现细节隐藏在一个简单的接口后面,使用户可以轻松地使用协程池而不必了解其内部实现。

实现协程池

首先,我们定义一个 workerFunc 类型,它表示一个可以在协程池中运行的函数。然后,我们定义一个 Pool 类型,它包含一个 workers 通道和一个 limit 通道。workers 通道用于存储要运行的函数,limit 通道用于限制协程池中的协程数量。

代码语言:javascript复制
type workerFunc func() error
type Pool struct {
    workers chan workerFunc
    wg      sync.WaitGroup
    limit   chan struct{}
}

接下来,我们定义一个 NewPool 函数,它接受一个整数参数 size,表示协程池中的协程数量。NewPool 函数返回一个指向 Pool 类型的指针。

代码语言:javascript复制
func NewPool(size int) *Pool {
    return &Pool{
        workers: make(chan workerFunc),
        limit:   make(chan struct{}, size),
    }
}

然后,我们定义一个 Add 方法,它接受一个 context.Context 类型的参数 ctx 和一个 workerFunc 类型的参数 fn。Add 方法将 fn 函数添加到 workers 通道中,如果 ctx 被取消,则返回 ctx.Err()。

代码语言:javascript复制
func (p *Pool) Add(ctx context.Context, fn workerFunc) error {
    select {
    case <-ctx.Done():
        return ctx.Err()
    case p.workers <- fn:
        return nil
    }
}

接下来,我们定义一个 Run 方法,它接受一个 context.Context 类型的参数 ctx。Run 方法从 workers 通道中读取函数并运行它们。在运行函数之前,Run 方法会从 limit 通道中获取一个空结构体,以限制协程池中的协程数量。在函数运行完成后,Run 方法会将空结构体放回 limit 通道中,并使用 sync.WaitGroup 等待所有函数完成。

代码语言:javascript复制
func (p *Pool) Run(ctx context.Context) {
    for fn := range p.workers {
        p.limit <- struct{}{}
        p.wg.Add(1)
        go func(fn workerFunc) {
            defer func() {
                <-p.limit
                p.wg.Done()
            }()
            if err := fn(); err != nil {
                // handle error
                fmt.Printf("err: %vn", err)
            }
        }(fn)
    }
    p.wg.Wait()
}

最后,我们定义一个 Stop 方法,它关闭 workers 通道并等待所有函数完成。

代码语言:javascript复制
func (p *Pool) Stop() {
    close(p.workers)
    p.wg.Wait()
}

完整协程池代码:

代码语言:javascript复制
package gopool

import (
  "context"
  "fmt"
  "sync"
)

type workerFunc func() error

type Pool struct {
  workers chan workerFunc
  wg      sync.WaitGroup
  limit   chan struct{}
}

func NewPool(size int) *Pool {
  return &Pool{
    workers: make(chan workerFunc),
    limit:   make(chan struct{}, size),
  }
}

func (p *Pool) Add(ctx context.Context, fn workerFunc) error {
  select {
  case <-ctx.Done():
    return ctx.Err()
  case p.workers <- fn:
    return nil
  }
}

func (p *Pool) Run(ctx context.Context) {
  for fn := range p.workers {
    p.limit <- struct{}{}
    p.wg.Add(1)
    go func(fn workerFunc) {
      defer func() {
        <-p.limit
        p.wg.Done()
      }()
      if err := fn(); err != nil {
        // handle error
        fmt.Printf("err: %vn", err)
      }
    }(fn)
  }
  p.wg.Wait()
}

func (p *Pool) Stop() {
  close(p.workers)
  p.wg.Wait()
}

现在,我们已经完成了协程池的封装。用户可以使用以下代码来创建一个协程池并运行函数:

代码语言:javascript复制
pool := NewPool(10)
defer pool.Stop()

for i := 0; i < 100; i   {
    err := pool.Add(context.Background(), func() error {
        // do some work
        return nil
    })
    if err != nil {
        // handle error
    }
}

pool.Run(context.Background())

在上面的代码中,我们创建了一个大小为 10 的协程池,并向其中添加了 100 个函数。然后,我们调用 Run 方法来运行这些函数。在函数运行完成后,我们调用 Stop 方法来关闭协程池。

通过封装协程池,我们可以将协程池的实现细节隐藏在一个简单的接口后面,使用户可以轻松地使用协程池而不必了解其内部实现。这种用户设计模式可以提高代码的可读性和可维护性,并使代码更易于重用。

使用普通工厂模式封装

代码语言:javascript复制
package gopool

import "context"

type PoolFactory struct {
  size int
  pool *Pool // 存储 *Pool
}

func NewPoolFactory(size int) *PoolFactory {
  return &PoolFactory{
    size: size,
    pool: &Pool{ // 初始化 *Pool
      workers: make(chan workerFunc),
      limit:   make(chan struct{}, size),
    },
  }
}

func (f *PoolFactory) getPool() *Pool {
  return f.pool // 直接返回存储的 *Pool
}

func (f *PoolFactory) Add(ctx context.Context, fn workerFunc) error {
  pool := f.getPool()
  select {
  case <-ctx.Done():
    return ctx.Err()
  case pool.workers <- fn:
    return nil
  }
}

func (f *PoolFactory) Stop(ctx context.Context) error {
  pool := f.getPool()
  close(pool.workers)
  for i := 0; i < f.size; i   {
    select {
    case <-ctx.Done():
      return ctx.Err()
    case pool.limit <- struct{}{}:
    }
  }
  return nil
}

func (f *PoolFactory) Run(ctx context.Context) {
  pool := f.getPool()
  pool.Run(ctx)
}

使用工厂方法模式封装

代码语言:javascript复制
package gopool

import "context"

type PoolFactory interface {
  Add(ctx context.Context, fn workerFunc) error
  Stop(ctx context.Context) error
  Run(ctx context.Context)
}

type poolFactory struct {
  size int
  pool *Pool // 存储 *Pool
}

func NewPoolFactory(size int) PoolFactory {
  return &poolFactory{
    size: size,
    pool: &Pool{ // 初始化 *Pool
      workers: make(chan workerFunc),
      limit:   make(chan struct{}, size),
    },
  }
}

func (f *poolFactory) Add(ctx context.Context, fn workerFunc) error {
  pool := f.pool
  select {
  case <-ctx.Done():
    return ctx.Err()
  case pool.workers <- fn:
    return nil
  }
}

func (f *poolFactory) Stop(ctx context.Context) error {
  pool := f.pool
  close(pool.workers)
  for i := 0; i < f.size; i   {
    select {
    case <-ctx.Done():
      return ctx.Err()
    case pool.limit <- struct{}{}:
    }
  }
  return nil
}

func (f *poolFactory) Run(ctx context.Context) {
  pool := f.pool
  pool.Run(ctx)
}

抽象工厂模式封装

代码语言:javascript复制
package gopool

import (
  "fmt"
  "sync"
)
type PoolFactory interface {
  CreatePool(size int) Pool
}

type poolFactory struct{}

func NewPoolFactory() PoolFactory {
  return &poolFactory{}
}

func (pf *poolFactory) CreatePool(size int) Pool {
  return NewPool(size)
}
type workerFunc func() error

type Pool interface {
  Add(fn workerFunc)
  Run()
  Stop()
}

type pool struct {
  workers chan workerFunc
  wg      sync.WaitGroup
  limit   chan struct{}
}

func NewPool(size int) Pool {
  return &pool{
    workers: make(chan workerFunc, size),
    limit:   make(chan struct{}, size),
  }
}

func (p *pool) Add(fn workerFunc) {
  p.workers <- fn
}

func (p *pool) Run() {
  for fn := range p.workers {
    p.limit <- struct{}{}
    p.wg.Add(1)
    go func(fn workerFunc) {
      defer func() {
        <-p.limit
        p.wg.Done()
      }()
      if err := fn(); err != nil {
        // handle error
        fmt.Printf("err: %vn", err)
      }
    }(fn)
  }
  p.wg.Wait()
}

func (p *pool) Stop() {
  close(p.workers)
  p.wg.Wait()
}

更多请看仓库:https://github.com/xilu0/gopool.git

0 人点赞