Go中sync.WaitGroup处理协程同步

2023-08-11 19:37:05 浏览数 (2)

简介

一个 sync.WaitGroup 对象可以等待一组协程结束。它很好地解决了 goroutine 同步的问题。

通常用于以下几种场景:

  • 需要等待 goroutine 多路任务完成
  • 主 goroutine 需要等待子 goroutine
  • 顺序任务需要等待前置任务

使用方法

  • main协程通过调用 wg.Add(delta int) 设置 worker 协程的个数,然后创建 worker 协程
  • worker协程执行结束以后,都要调用 wg.Done()
  • main协程调用 wg.Wait(),直到所有 worker 协程全部执行结束后返回

使用示例

  • WaitGroup内部使用一个计数器count
  • Add方法会增加计数器的值
  • Done方法会减少计数器的值
  • Wait方法会阻塞,直到计数器的值变为0
代码语言:go复制
// 初始化 WaitGroup
var wg sync.WaitGroup

// 告诉 WaitGroup 有 2 个 goroutine 需要等待
wg.Add(2)

// 启动第一个 goroutine
go func() {
    defer wg.Done()

    // do something
    time.Sleep(time.Second * 1)
    fmt.Println("goroutine 1 done")
}()

// 启动第二个 goroutine
go func() {  
    defer wg.Done()

    // do something
    time.Sleep(time.Second * 2)
    fmt.Println("goroutine 2 done")
}()

// 等待所有注册过的 goroutine 都执行完
wg.Wait()

// 主 goroutine 等待 wg.Wait() 完成
// go next

实现原理

  • 通过原子操作统一记录计数和等待变量。
  • 在计数操作与等待操作之间加入同步机制。
  • 使用信号量机制通知等待线程。
  • 通过可见性和竞争检测保证正确性。

具体一点:

  1. 使用一个64位的原子操作变量state来存储计数和等待线程数。高32位作为计数,低32位作为等待线程数。
  2. Add方法通过原子操作将计数调整,加入必要的同步操作保证顺序。
  3. Wait方法通过循环检测计数值,如果不为0则加1等待变量,否则返回。加等待变量表示有新的等待线程。
  4. 多次Add调用可能导致计数临界下降为0时有等待线程,这时需要额外同步检查避免错误。
  5. 32位系统需要检查变量对齐情况,可能需要交换变量存储位置保证原子方式有效。
  6. 内部使用runtime提供的信号量调用runtime_Semacquire/runtime_Semrelease来实现等待通知功能。
  7. 使用内存锁race.Enable完成可见性保证和竞争检测。

sync.WaitGroup 源码

代码语言:go复制
package sync

import (
	"internal/race"
	"sync/atomic"
	"unsafe"
)

// WaitGroup等待一组协程完成。
// 主协程调用Add来设置
// 等待的协程。然后是每个协程
// 运行并在完成时调用Done。同时,
// Wait可以用来阻塞,直到所有的协程都完成。

// WaitGroup首次使用后不能复制。
type WaitGroup struct {
	noCopy noCopy

	// 64位值:高32位为计数器,低32位为等待计数。
    // 64位原子操作需要64位对齐,但是32位编译器只保证64位字段是32位对齐的。
	// 出于这个原因,在32位体系结构上,我们需要检查state()中state1是否对齐,并在需要时动态地“交换”字段顺序。
    state1 uint64
	state2 uint32
}

// State返回指向存储在wg.state*中的State和sema字段的指针。
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	if unsafe.Alignof(wg.state1) == 8 || uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		// State1是64位对齐的:不做任何事情。
		return &wg.state1, &wg.state2
	} else {
		// State1是32位对齐,但不是64位对齐:这意味着(&state1) 4是64位对齐的。
		state := (*[3]uint32)(unsafe.Pointer(&wg.state1))
		return (*uint64)(unsafe.Pointer(&state[1])), &state[0]
	}
}

// Add向WaitGroup计数器添加增量,增量可能为负。
// 如果计数器变为零,则释放被Wait阻塞的所有协程。
// 如果计数器为负,则添加panics。

// 请注意,当计数器为零时,具有正增量的调用必须在Wait之前发生。
// 具有负增量的调用,或者在计数器大于零时开始的具有正增量的调用,可能在任何时候发生。
// 通常,这意味着对Add的调用应该在语句创建要等待的程序或其他事件之前执行。
// 如果重用WaitGroup来等待几个独立的事件集,则必须在所有先前的wait调用返回之后发生新的Add调用。
// 参见WaitGroup示例。
func (wg *WaitGroup) Add(delta int) {
	statep, semap := wg.state()
	if race.Enabled {
		_ = *statep // 提前触发nil延迟
		if delta < 0 {
			// 与Wait同步减量。
			race.ReleaseMerge(unsafe.Pointer(wg))
		}
		race.Disable()
		defer race.Enable()
	}
	state := atomic.AddUint64(statep, uint64(delta)<<32)
	v := int32(state >> 32)
	w := uint32(state)
	if race.Enabled && delta > 0 && v == int32(delta) {
		// 第一个增量必须与Wait同步。
        // 需要将其建模为读取,因为可能有多个并发的wg。计数器从0转换。
		race.Read(unsafe.Pointer(semap))
	}
	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	if v > 0 || w == 0 {
		return
	}
	// 当 waiters > 0时,这个协程将counter设置为0。
    // 状态不可能同时发生突变
    // -添加不能与等待同时发生,
    // - Wait如果看到counter == 0,则不会增加waiters。
    // 仍然要做一个便宜的完整性检查来检测WaitGroup的误用。
	if *statep != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// 将waiters计数重置为0。 
	*statep = 0
	for ; w != 0; w-- {
		runtime_Semrelease(semap, false, 0)
	}
}

// Done将WaitGroup counter减1。
func (wg *WaitGroup) Done() {
	wg.Add(-1)
}

// 等待阻塞直到WaitGroup counter为0。
func (wg *WaitGroup) Wait() {
	statep, semap := wg.state()
	if race.Enabled {
		_ = *statep // 提前触发nil延迟
		race.Disable()
	}
	for {
		state := atomic.LoadUint64(statep)
		v := int32(state >> 32)
		w := uint32(state)
		if v == 0 {
			// Counter is 0, no need to wait.
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
		// Increment waiters count.
		if atomic.CompareAndSwapUint64(statep, state, state 1) {
			if race.Enabled && w == 0 {
				// Wait必须与第一个Add同步。
                // 需要将其建模为写操作与Add中的读操作竞争。
                // 因此,只能给第一个 waiter 写入,
                // 否则并发等待将相互竞争。
				race.Write(unsafe.Pointer(semap))
			}
			runtime_Semacquire(semap)
			if *statep != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
	}
}

internal/race

主要用于静态编译时的并发数据竞争检测,可以更便捷地检查并发程序是否安全。

race.Enabled表示是否开启竞争检测功能。

race.Enable()开启竞争检测。

race.Disable()关闭竞争检测。

race.Acquire()模拟对共享资源获取锁。

race.ReleaseMerge()模拟对共享资源解锁并合并锁定计数。

race.Write()模拟对共享资源的写操作。

race.Read()模拟对共享资源的读操作。

信号量 semaphore

在系统中,会给每一个进程一个信号量,代表每个进程目前的状态。未得到控制权的进程,会在特定的地方被迫停下来,等待可以继续进行的信号到来。

0 人点赞