文章目录
- 0 Ray深度强化学习框架概述
- 1 Ray使用场景—多进程(通过ray.remote装饰器实现)
- 2 Ray使用场景—进程间通信
- 3 Tune使用场景—调参
- 4 RLLib使用场景—RL算法
- 5 Ray、Tune和RLLib关系
- 6 Ray系统架构(实现多进程和跨节点通信)
- 6.1 Ray系统架构—概述
- 6.2 Ray系统架构—内存管理
- 7 Tune系统架构(实验资源分配 调参)
- 8 RLLib系统架构(Trainer、Policy和Agent)
- 8.1 Policy实现功能
- 8.2 Trainer实现功能
- 参考文献
0 Ray深度强化学习框架概述
Ray
——分布式框架的基础设施,提供多进程通信和集群维护等功能Tune
——基于Ray的中间库,主要功能是调参(如使用BPT算法异步调参)和多实验排队管理RLLib
——基于Ray的分布式和Tune的调参,实现抽象RL算法,可支持层次RL和Multi Agent学习等
1 Ray使用场景—多进程(通过ray.remote装饰器实现)
2 Ray使用场景—进程间通信
通过拿到远程函数的ID,可以在集群的任何地方,通过get(ID)
获取该函数返回值
3 Tune使用场景—调参
对于PPO通过5个学习率参数,每组实验做两遍,总共10个实验,目前共有8个CPU,每个实验需要1个CPU,Tune可以把这些实验放入到队列中。若目前CPU已满,则等待,下图所示为8个CPU正在作业,剩余2个实验正在等待中
4 RLLib使用场景—RL算法
RLLib基于Tune和Ray实现强化学习算法,下图基于IMPALA框架,图中Trainer维护一个model,每个Rollout Worker会创建1个进程,根据model去异步Trajectory Sampling,然后将多个采样结果反馈给Trainer,Trainer根据采样更新model网络权重,再更新Rollout worker
- Trainer 包含采样、训练、优化、数据处理和传输功能
RLLib对于算法的抽象,一切算法本质都是以下两者的交互
- Workers
- Learner
while True:
trainer.train():
# 1.通过worker去采样data
data = trainer.worker.sample();
# 2.通过data和相应loss反向传播计算更新weight
trainer.optimizer.undate(data);
# 3.将新weight同步到worker中
trainer.worker.sync_weight();
5 Ray、Tune和RLLib关系
6 Ray系统架构(实现多进程和跨节点通信)
6.1 Ray系统架构—概述
- Object Store是跨进程的数据库,类似全局数据库,不同进程可以通过Obj Store数据库获取对应函数Obj ID从而获取数据
在Slurm集群上的脚本案例
6.2 Ray系统架构—内存管理
7 Tune系统架构(实验资源分配 调参)
Tune同时维护多个实验,合理为每个实验的不同请求分配资源,每个实验被抽象成1个Trainable
,TrialExecutor
会根据每个Trainable
需要的CPU/GPU分配合理资源,本质就是优先队列
while (true) {
trainable.train(); // 需设定终止条件
...
}
使用PBT异步调参算法,借鉴遗传算法思想,不同于传统随机算法调参,传统的不同参数是并行且独立调整,因为是固定变量法,有些参数越调越好(比如学习率),有些参数在较差的参数组合下始终无法调好(比如折扣因子),使得浪费计算资源。PBT使得参数调整之间并非独立,会将好的其他参数(如学习率)拿到其他较差的参数(折扣因子)中进行试探,使得在不额外增加计算资源的情况下快速调优
8 RLLib系统架构(Trainer、Policy和Agent)
8.1 Policy实现功能
RLLib有一套完善的build model
系统,只要给定Env
,比如图像,它就会自动创建CNN
Model
Policy主要实现功能
Loss Fun
——用来优化Postprocess Function
——用于数据处理Build Model
——根据Env自动创建适配Model
8.2 Trainer实现功能
- 指定
Policy
——如上1步的PPOTFPolicy
- 选择
Optimizer
——此处为更抽象的optimizer(比Adam更抽象),包含模型 数据的输入,loss的计算和GPU多卡训练等
参考文献
[1] 强化学习系统怎么实现?Ray是啥?Tune和RLLib又是什么? [2] Ray v1.2 官方文档