在多卡训练模型时,遇到一些multiprocessing中spawn模块相关的错误,本文记录相关内容及解决方案。
问题复现
使用 mmdetection
训练时出现过一些莫名其妙的错误:
- 多卡训练时无法 pickle lambda 函数
AttributeError: Can't pickle local object 'Dataset.load_data.<locals>.<lambda>'
- 训练过程中修改代码,改动会引入到已经在运行的程序当中(细思恐极)
单卡时没有相关错误出现
问题原因
在使用 multiprocessing ,Start()
方法被Process
类调用的时候 ,有三种不同的启用子进程的方法,这个方法可以通过multiprocessing.set_start_method
来声明。
该方法有三种字符型的参数:
spawn
父进程会开启一个新的python解释器进程。子进程只会继承需要用来跑run
方法的资源。更具体的,不需要的文件描述以及handles将不会被继承。使用这个方法来启动进程是三种方法里最慢的。(Windows上的默认方法)fork
父进程使用os.fork()
方法来开启子进程。通过这个方式开启的子进程与父进程一毛一样,父进程所有的资源都会被子进程继承。这个只限于Unix类的系统上,Unix,Linux,MacOS的默认方法。forkserver
这个参数会开启一个服务进程,之后,一旦有新的进程需求,父进程就会向server请求一个新的进程。这个fork server是单线程的所以它使用os.fork()
方法是安全的。这个方法不会继承非必须的系统资源。这个参数支持Unix系统。
其中
os.fork()
会避免上述错误内容的出现。
解决方案
- 强制multiprocessing模块使用
fork
方法开启进程
import multiprocessing as mp
import torch.multiprocessing as t_mp
mp.set_start_method('fork', force = True)
start_method = mp.get_start_method()
print(f"mp: {start_method}")
t_mp.set_start_method('fork', force = True)
start_method = t_mp.get_start_method()
print(f"t_mp: {start_method}")
需要在进程开启前设置相关参数方可生效。
参考资料
- https://zhuanlan.zhihu.com/p/136995403