[源码解析] PyTorch 如何实现后向传播 (4)---- 具体算法
0x00 摘要
前文中我们介绍了反向传播引擎的动态逻辑,因为具体反向传播算法是在设备线程中完成的,所以我们单独用一章来讲解。
0x01 工作线程主体
thread_main是工作线程的主体函数,主要逻辑就是围绕着 ReadyQueue 执行一个 while 循环,工作线程阻塞在 ReadyQueue -> pop 这里,如果主线程或者其他线程插入了一个 NodeTask,则 pop 会返回取出一个 NodeTask,工作线程处理这个 NodeTask,完成后向计算的一个环节,如果有需要就继续往某一ReadyQueue插入新的 NodeTask,驱动引擎继续执行后向计算其他环节。
thread_main 从如下途径被调用:
- CUDA, XLA 设备的 autograd threads 会调用。
- CPU 之上的反向传播主线程会调用。
- 前两个case 进行可重入反向传播,也会调用。
1.1 线程主体代码
工作线程的计算始于动态图的GraphRoot函数,反向传播就以 Node 的edge为纽带,层层从前向后计算,直到来到了leaf节点,最终完成了反向计算,具体如下:
- local_graph_task表示我们从队列中检索的graph_task。外部graph_ 任务表示我们需要执行的可重入执行的总体 graph_任务。
- 从自己的ReadyQueue之中取出NodeTask实例,使用 local_graph_task 为参数来执行evaluate_function(反向传播函数)。
- outstanding_tasks 自减 1。
- 如果本 local_graph_task 已经结束(可重入反向传播会运行多个 GraphTask),即:
- 执行后续操作 exec_post_processing,然后使用 future_result_->markCompleted。
- 如果这个task是来自其它worker thread,即 worker_device != base_owner,则向那个worker thread的queue发送一个dummy function task,让那个工作线程也执行起来。
具体代码如下:
代码语言:javascript复制// thread_main is used by:
// 1). autograd threads for devices (i.e. CUDA, XLA)
// 2). the caller/owning thread of the backward call on CPU (sync mode)
// 3). Renetrant backward that invoked by either 1) or 2)
// The exit conditions are different for the above three cases.
// For 1), we are spinning on running the thread_main on device autograd
// threads throughout the Engine lifetime, thread_main will get
// terminated during Engine destruction by pushing shutdown tasks
// For 2), the owning thread of the backward call drives the thread_main
// synchronously until the graph_task of that owning thread is
// completed and exit the thread_main to continue executing the
// result of caller's code.
// For 3), the reentrant backward that invokes
// thread_main, either from 1) or 2), will not spin and will exit as
// long as graph_task is completed and notify the owning thread as
// needed.
auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
// When graph_task is nullptr, this is a long running thread that processes
// tasks (ex: device threads). When graph_task is non-null (ex: reentrant
// backwards, user thread), this function is expected to exit once that
// graph_task complete.
// local_ready_queue should already been initialized when we get into thread_main
while (graph_task == nullptr || !graph_task->future_result_->completed()) {
// local_graph_task represents the graph_task we retrieve from the queue.
// The outer graph_task represents the overall graph_task we need to execute
// for reentrant execution.
std::shared_ptr<GraphTask> local_graph_task;
{
// Scope this block of execution since NodeTask is not needed after this
// block and can be deallocated (release any references to grad tensors
// as part of inputs_).
NodeTask task = local_ready_queue->pop(); // 阻塞等待
// This will only work if the worker is running a non backward task
// TODO Needs to be fixed this to work in all cases
if (task.isShutdownTask_) {
break;
}
if (!(local_graph_task = task.base_.lock())) {
// GraphTask for function is no longer valid, skipping further
// execution.
continue;
}
if (task.fn_ && !local_graph_task->has_error_.load()) {
// 利用grad_mode_来配置AutoGradMode,整个反向计算期间的代码都靠GradMode::is_enabled()来判断当前是否是要计算grad
AutoGradMode grad_mode(local_graph_task->grad_mode_);
try {
// The guard sets the thread_local current_graph_task on construction
// and restores it on exit. The current_graph_task variable helps
// queue_callback() to find the target GraphTask to append final
// callbacks.
GraphTaskGuard guard(local_graph_task);
NodeGuard ndguard(task.fn_);
// 执行后向计算
evaluate_function(local_graph_task, task.fn_.get(), task.inputs_, local_graph_task->cpu_ready_queue_);
} catch (std::exception& e) {
thread_on_exception(local_graph_task, task.fn_, e);
}
}
}
// Decrement the outstanding tasks.
--local_graph_task->outstanding_tasks_;
// Check if we've completed execution.
if (local_graph_task->completed()) { // 已经结束了,进行后续处理
local_graph_task->mark_as_completed_and_run_post_processing();
auto base_owner = local_graph_task->owner_; // 后续是需要在 GraphTask 的 owner_ 处理
// The current worker thread finish the graph_task, but the owning thread
// of the graph_task might be sleeping on pop() if it does not have work.
// So we need to send a dummy function task to the owning thread just to
// ensure that it's not sleeping, so that we can exit the thread_main.
// If it has work, it might see that graph_task->outstanding_tasks_ == 0
// before it gets to the task, but it's a no-op anyway.
//
// NB: This is not necessary if the current thread is the owning thread.
if (worker_device != base_owner) {
// Synchronize outstanding_tasks_ with queue mutex
std::atomic_thread_fence(std::memory_order_release);
// 获取后续工作的queue
ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
}
}
}
}
1.2 使用 Ready Queue
上述代码之中,最后使用 ready_queue_by_index 获取到后续工作对应的queue。
代码语言:javascript复制ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
如何获取Ready Queue?具体策略是:
- 如果下一个 需要执行的设备是 CPU,则选用cpu_ready_queue。
- 否则从device_ready_queues_选取一个GPU对应的 ReadyQueue。
代码如下:
代码语言:javascript复制auto Engine::ready_queue_by_index(std::shared_ptr<ReadyQueue> cpu_ready_queue, int device_index) -> std::shared_ptr<ReadyQueue> {
if (device_index == CPU_DEVICE) {
// return the cpu ready queue passed in
TORCH_INTERNAL_ASSERT(cpu_ready_queue);
return cpu_ready_queue;
} else {
// Static cast is ok here as the number of device should never overflow an int.
TORCH_INTERNAL_ASSERT(0 <= device_index && device_index < static_cast<int>(device_ready_queues_.size()));
// See Note [Allocating GPUs to autograd threads]
// NB: This function would become obsolete if we truly allocated a CPU thread
// per device, rather than colocate.
return device_ready_queues_.at(device_index);
}
}
逻辑如下:
代码语言:javascript复制 ---------------------------------------------------------------------
| Main Thread |
| |
| push(NodeTask) -------------- |
| | |
---------------------------------------------------------------------
|
|
v
------ -----
| |
| ReadyQueue |
| |
------ -----
|
|
|
---------------------------------------------------------------------
| Worker Thread 1 | |
| | |
| thread_main{ | |
| v |
| NodeTask task = local_ready_queue->pop() |
| |
| evaluate_function(task.fn_.get(),task.inputs_) |
| } |
---------------------------------------------------------------------
0x02 反向计算总体逻辑
evaluate_function 方法完成了反向计算的逻辑,总体逻辑如下:
- 准备工作:如果exec_info需要处理,则处理 captured_vars_。
- 反向计算:调用 call_function(graph_task, func, inputs),这是反向传播中计算相关的核心逻辑:
- 调用pre hooks。
- 调用fn进行计算。
- 调用post hooks。
- 扫尾工作:
- 如果不需要keep graph,则fn.release_variables();
- 依据 call_function的输出 outputs,进行计算 num_outputs = outputs.size(),得到 num_outputs的元素数量(该数量等同于当前fn的next_edge()返回的list中的元素数量)。
- 准备下一步工作,具体就是查找后续需要计算的NodeTask,num_outputs 就是在这里被用到。这部分比较复杂。
总体代码如下:
代码语言:javascript复制void Engine::evaluate_function(
std::shared_ptr<GraphTask>& graph_task,
Node* func, // 导数计算方法
InputBuffer& inputs, // 当前Node的输入梯度
const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
// 进行准备工作
// If exec_info_ is not empty, we have to instrument the execution
auto& exec_info_ = graph_task->exec_info_;
if (!exec_info_.empty()) {
auto& fn_info = exec_info_.at(func); // 取出当前的进行处理
if (auto* capture_vec = fn_info.captures_.get()) {
// Lock mutex for writing to graph_task->captured_vars_.
std::lock_guard<std::mutex> lock(graph_task->mutex_);
for (const auto& capture : *capture_vec) {
// captured_grad 就是临时存储下,每次node计算都会更新,最终输出给调用者,相当于引用
// 1. captured_grad 引用了captured_vars_[capture.output_idx_],
auto& captured_grad = graph_task->captured_vars_[capture.output_idx_];
// 2. 给 captured_vars_[capture.output_idx_] 赋值 inputs[capture.input_idx_]
captured_grad = inputs[capture.input_idx_];
// 遍历hooks,链式调用hook进行计算,captured_grad 不停的作为输入和输出在流水线中流淌
// 就是针对 captured_vars_[capture.output_idx_]不停的计算,最终结果还是在 captured_vars_[capture.output_idx_] 之中。
for (auto& hook : capture.hooks_) {
captured_grad = (*hook)(captured_grad);
}
}
}
if (!fn_info.needed_) {
// Skip execution if we don't need to execute the function.
return;
}
}
// Set the ThreadLocalState before calling the function.
// NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask
// always saves ThreadLocalState without grad_mode.
at::ThreadLocalStateGuard tls_guard(graph_task->thread_locals_);
// Switches to a function's CUDA stream (if applicable) before calling it
const auto opt_parent_stream = (*func).stream(c10::DeviceType::CUDA);
c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};
// 进行反向计算
auto outputs = call_function(graph_task, func, inputs);
// 如果不需要保持计算图,则本节点释放变量
auto& fn = *func;
if (!graph_task->keep_graph_) {
fn.release_variables();
}
// 得到 num_outputs的元素数量(该数量等同于当前fn的next_edge()返回的list中的元素数量),后续遍历本节点输出时候会用到
int num_outputs = outputs.size();
if (num_outputs == 0) { // Note: doesn't acquire the mutex
// Records leaf stream (if applicable)
// See note "Streaming backwards"
if (opt_parent_stream) {
std::lock_guard<std::mutex> lock(graph_task->mutex_);
graph_task->leaf_streams.emplace(*opt_parent_stream);
}
return;
}
if (AnomalyMode::is_enabled()) {
AutoGradMode grad_mode(false);
for (int i = 0; i < num_outputs; i) {
auto& output = outputs[i];
at::OptionalDeviceGuard guard(device_of(output));
if (output.defined() && isnan(output).any().item<uint8_t>()) {
std::stringstream ss;
}
}
}
// 准备下一步工作
// Lock mutex for the accesses to GraphTask dependencies_, not_ready_ and cpu_ready_queue_ below
std::lock_guard<std::mutex> lock(graph_task->mutex_);
for (int i = 0; i < num_outputs; i) {
auto& output = outputs[i];
const auto& next = fn.next_edge(i); // next_edge是该node在前向传播图中的输入,在反向传播时候就是本节点的输出,所以next就是下一个可能运算的节点
if (!next.is_valid()) continue;
// Check if the next function is ready to be computed
bool is_ready = false;
auto& dependencies = graph_task->dependencies_;
auto it = dependencies.find(next.function.get()); // 找到下一个节点的依赖
if (it == dependencies.end()) {
auto name = next.function->name();
throw std::runtime_error(std::string("dependency not found for ") name);
} else if (--it->second == 0) {
dependencies.erase(it);
is_ready = true; // 下一个节点没有入度了,那么说明计算该节点梯度依赖的其他节点梯度都已经计算完成
}
// 要去 not_ready里面看看,是否已经存储了
auto& not_ready = graph_task->not_ready_;
auto not_ready_it = not_ready.find(next.function.get());
if (not_ready_it == not_ready.end()) {
// 下一个节点的梯度还没有进行计算
// Skip functions that aren't supposed to be executed
// 跳过不需要计算的节点
if (!exec_info_.empty()) {
auto it = exec_info_.find(next.function.get());
if (it == exec_info_.end() || !it->second.should_execute()) {
continue;
}
}
// No buffers have been allocated for the function
InputBuffer input_buffer(next.function->num_inputs()); // 下一个节点前置梯度的buffer,就是下一个节点的输入梯度
// Accumulates into buffer
// 下一个节点的输入梯度就是当前节点的输出,所以要拷贝过去
const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
input_buffer.add(next.input_nr,
std::move(output),
opt_parent_stream,
opt_next_stream);
if (is_ready) {
auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
// 既然依赖全部完成,就插入到ReadyQueue 之中
queue->push(
NodeTask(graph_task, next.function, std::move(input_buffer)));
} else {
// 下一个节点的输入依赖还没有完成,就放到not_ready之中。
not_ready.emplace(next.function.get(), std::move(input_buffer));
}
} else {
// 如果下一个节点已经开始计算,但是没有完成(就是依赖梯度还有),此时应该在not_ready之中
// The function already has a buffer
auto &input_buffer = not_ready_it->second;
// Accumulates into buffer
const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
input_buffer.add(next.input_nr,
std::move(output),
opt_parent_stream,
opt_next_stream);
// Graph中每一个node(fn)的输出是下一个node(fn)的输入,下面4句代码来将前一个fn的输出转化为下一个fn的输入
if (is_ready) {
// 如果此时已经没有输入依赖,就放入新的NodeTask,就是下一个需要计算梯度的NodeTask
auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
queue->push(
NodeTask(graph_task, next.function, std::move(input_buffer)));
//已经完成下一个节点前置梯度计算,从not_ready中移除相应的buffer
not_ready.erase(not_ready_it);
}
}
}
}
因为这部分代码十分复杂,我们逐一进行分析。
0x03 准备工作
首先我们看看准备工作,具体如下:
- 取出当前 Node 的 ExecInfo。
- 取出其 captures_,遍历其中每一个 Capture。
- 遍历Capture 的 hooks,链式调用hook进行计算。
- captured_grad 不停的作为输入和输出在流水线中流淌,针对
captured_vars_[capture.output_idx_]
陆续计算。 - 最终结果保存在
captured_vars_[capture.output_idx_]
之中。
- captured_grad 不停的作为输入和输出在流水线中流淌,针对
代码中有一个细节,就是captured_grad 只是临时存储,每次node计算都会更新,最终输出给调用者,相当于引用。
代码语言:javascript复制void Engine::evaluate_function(
std::shared_ptr<GraphTask>& graph_task,
Node* func, // 导数计算方法
InputBuffer& inputs, // 当前Node的输入梯度
const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
// 进行准备工作
// If exec_info_ is not empty, we have to instrument the execution
auto& exec_info_ = graph_task->exec_info_;
if (!exec_info_.empty()) {
auto& fn_info = exec_info_.at(func); // 取出当前的进行处理
if (auto* capture_vec = fn_info.captures_.get()) {
// Lock mutex for writing to graph_task->captured_vars_.
std::lock_guard<std::mutex> lock(graph_task->mutex_);
for (const auto& capture : *capture_vec) {
// captured_grad 就是临时存储下,每次node计算都会更新,最终输出给调用者,相当于引用
// 1. captured_grad 引用了captured_vars_[capture.output_idx_],
auto& captured_grad = graph_task->captured_vars_[capture.output_idx_];
// 2. 给 captured_vars_[capture.output_idx_] 赋值 inputs[capture.input_idx_]
captured_grad = inputs[capture.input_idx_];
// 遍历hooks,链式调用hook进行计算,captured_grad 不停的作为输入和输出在流水线中流淌
// 就是针对 captured_vars_[capture.output_idx_]不停的计算,最终结果还是在 captured_vars_[capture.output_idx_] 之中。
for (auto& hook : capture.hooks_) {
captured_grad = (*hook)(captured_grad);
}
}
}
if (!fn_info.needed_) {
// Skip execution if we don't need to execute the function.
return;
}
}
0x04 核心逻辑
call_function是反向传播中计算相关的核心逻辑。
- 调用注册在本 node上的pre_hooks;
- 调用node本身,比如MeanBackward0、MulBackward0等。
- 输入是
InputBuffer::variables(std::move(inputBuffer))
,一组Variable的实例。当动态图刚开始进行反向计算时,引擎首先执行的是图的根节点——graph_root,它的输入是task.inputs——InputBuffer(0)。 - 调用的是fn的apply(),apply是多态实现,针对不同的operation会dispatch到operation对应的apply实现上。
- 输出也是一组Variable的实例 outputs = fn(std::move(inputs_copy)),outputs 要作为下一个fn的输入。
- 输入是
- 调用注册在node上的post hooks。
- 返回当前节点对应的导数,这是一个variable_list。
具体代码如下:
代码语言:javascript复制static variable_list call_function(
std::shared_ptr<GraphTask>& graph_task,
Node* func,
InputBuffer& inputBuffer) {
CheckpointValidGuard cpvguard(graph_task);
auto& fn = *func;
auto inputs =
call_pre_hooks(fn, InputBuffer::variables(std::move(inputBuffer)));
if (!graph_task->keep_graph_) {
fn.will_release_variables();
}
const auto has_post_hooks = !fn.post_hooks().empty();
variable_list outputs;
if (has_post_hooks) {
// In functions/accumulate_grad.cpp, there is some logic to check the
// conditions under which the incoming gradient can be stolen directly
// (which elides a deep copy) instead of cloned. One of these conditions
// is that the incoming gradient's refcount must be 1 (nothing else is
// referencing the same data). Stashing inputs_copy here bumps the
// refcount, so if post hooks are employed, it's actually still ok for
// accumulate_grad.cpp to steal the gradient if the refcount is 2.
//
// "new_grad.use_count() <= 1 !post_hooks().empty()" in
// accumulate_grad.cpp accounts for this, but also creates a silent
// dependency between engine.cpp (ie, this particular engine
// implementation) and accumulate_grad.cpp.
//
// If you change the logic here, make sure it's compatible with
// accumulate_grad.cpp.
auto inputs_copy = inputs;
outputs = fn(std::move(inputs_copy));
} else {
outputs = fn(std::move(inputs));
}
validate_outputs(fn.next_edges(), outputs, [&](const std::string& msg) {
std::ostringstream ss;
return ss.str();
});
if(has_post_hooks){
return call_post_hooks(fn, std::move(outputs), inputs);
}
return outputs;
}
0x05 准备下一步工作
这部分是反向传播的复杂之处。
现在调用 call_function,得到了后向传播的输出,记录到了 outputs 之中。
代码语言:javascript复制auto outputs = call_function(graph_task, func, inputs);
所以,后半部分就是从 outputs 之中寻找后续可以计算的Node。
总体思路就是:遍历后向传播的输出节点(就是该节点在前向计算图中的入边连接的节点),逐一衡量输出节点。遍历循环中分为两段代码,对于每一个输出节点做如下操作:
- 第一段是依据依赖排查这个节点,得到这个节点是否就绪。核心就是看看这个输出节点在GraphTask的dependencies的计数是否降为0。
- 如果是0,就说明这个节点就绪了,说明这个node不会被未来的计算所依赖了。
- 如果非0,就说明这个节点有多个输入,即,被多个node连接,而且有的输入还没有计算完成梯度。
- 第二段是依据是否就绪来处理这个节点,比如放入哪一个queue。
5.1 依据依赖排查节点
第一段代码功能是依据依赖关系来 排查节点,得到这个节点是否就绪,具体如下:
代码语言:javascript复制 for (int i = 0; i < num_outputs; i) { // 遍历输出节点,逐一衡量
auto& output = outputs[i];
const auto& next = fn.next_edge(i); // 获得一个输出节点
if (!next.is_valid()) continue;
// Check if the next function is ready to be computed
bool is_ready = false;
auto& dependencies = graph_task->dependencies_; // 拿到GraphTask的依赖关系
auto it = dependencies.find(next.function.get()); // 找到输出节点的依赖项
if (it == dependencies.end()) {
auto name = next.function->name(); // 没找到
throw std::runtime_error(std::string("dependency not found for ") name);
} else if (--it->second == 0) {
dependencies.erase(it); // 找到了,并且已经计算完毕
is_ready = true;
}
auto& not_ready = graph_task->not_ready_;
auto not_ready_it = not_ready.find(next.function.get()); // 找到输入buffer
现在已经找到了某一个输出节点,也知道其是否计算完毕(依据有没有依赖项),也拿到了其存在"未就绪队列"的输入buffer(如果存在的话)。
5.2 处理这个节点
第二段是依据是否就绪来处理这个节点,比如放入哪一个queue,是就绪队列?还是未就绪队列?核心是:
- 如果就绪,就放到该节点对应的 ReadyQueue 去处理。
- 如果没有就绪,就新建立一个NodeTask放到 GraphTask的 not_ready 等待后续处理。需要注意的是,这个新的NodeTask 是在 worker thread 之中创建的。
- 如何找到 ReadyQueue?需要看这个 Node 节点的 input_buffer.device() ,即,这个新 NodeTask 应该发送到 input_buffer.device() 那个 device 对应的 ReadyQueue。
我们具体看看如何依据 is_ready 的数值来对 not_ready 进行操作。
- 如果在 未就绪队列 not_ready 之中 没有找到 next_edge 对应的元素,则:
- 如果 exec_info_ 不为空,则在 exec_info_ 之中查找 next_edge 对应的元素,如果有元素且注明了不需要执行,就跳到for循环的下一个。
- 用 next_edge 的流,inut_nr 等信息构建一个 input_buffer。
- 如果 is_ready 是 True,就用 本 GraphTask,next.function,input_buffer构建一个NodeTask,放入 ReadyQueue(利用 input_buffer.device() 来得到对应的 queue)。这就要唤醒下一个 worker 线程。
- 如果 is_ready 是 False,这通常表明这个node有多个输入(被更多的node连接,使用num_inputs()可以获得数量),也说明此次处理的是这个node的第一个输入,后续还需要使用这个 next_edge,所以这个 next_edge 需要被放到 not_ready 之中。则把 next.function,input_buffer 放入到 not_ready 之中,这个input_buffer 就是 next_edge 后续执行时候需要的各种输入。
- 如果在 未就绪队列 not_ready 之中找到了 next_edge 对应的元素,则:
- 拿出来该元素对应的 input_buffer,把信息累积到 input_buffer 之中。此次累积的是该节点的其他输入。 input_buffer.add(next.input_nr, std::move(output), opt_parent_stream, opt_next_stream) 完成了累积操作,next.input_nr 就表明当前的node是反向传播中要流向的node(next)的第几个输入。
- 如果is_ready 是 True,就用 本 GraphTask,next.function,input_buffer构建一个NodeTask,放入 ReadyQueue。这就要唤醒下一个 worker 线程。
- 从 not_ready 之中移除此元素,就是从 GraphTask 的依赖关系之中去除。
代码如下:
代码语言:javascript复制 if (not_ready_it == not_ready.end()) {
// Skip functions that aren't supposed to be executed
if (!exec_info_.empty()) {
auto it = exec_info_.find(next.function.get());
if (it == exec_info_.end() || !it->second.should_execute()) {
continue;
}
}
// No buffers have been allocated for the function
InputBuffer input_buffer(next.function->num_inputs());
// Accumulates into buffer
const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
input_buffer.add(next.input_nr,
std::move(output),
opt_parent_stream,
opt_next_stream);
if (is_ready) {
// 找出了下一个Node的queue
auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
queue->push( //
NodeTask(graph_task, next.function, std::move(input_buffer)));
} else {
not_ready.emplace(next.function.get(), std::move(input_buffer));
}
} else {
// The function already has a buffer
auto &input_buffer = not_ready_it->second;
// Accumulates into buffer
const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
input_buffer.add(next.input_nr,
std::move(output),
opt_parent_stream,
opt_next_stream);
if (is_ready) {
// 找出了下一个Node的queue
auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
queue->push(
NodeTask(graph_task, next.function, std::move(input_buffer)));
not_ready.erase(not_ready_it);
}
}
具体逻辑图如下:
- func 指向了目前正在进行反向计算的 Node。
- func 调用自己的 apply 方法进行计算,得出了 outputs,假设有3个输出,遍历,我们选择第三个为 output。
- func 的边是 next_edges_ 成员变量,遍历,我们选择第三个边为next。
- 用 next 和 GraphTask 的 dependencies_ 来判断 next 是不是就绪。
- 如果就绪,把 output 构建一个 input_buffer,然后生成一个 NodeTask,插入到对应的 ReadyQuieue。
- 如果没就绪,把 output 构建一个 input_buffer,和 next 一起放入 GraphTask 的 not_ready_,后续会使用。
1 ---------------
func --> | Node | ---> ...
| | |
| | |
| apply() ------> outputs ------> ... 2
| | |
| | |
| | | --------------
| | ---> output --> | input_buffer --
| | -------------- |
| | |
| | |
| | | 5
| | |
| | |
| | ----> ... |
| | | ---------
| | | | |
| next_edges_ ---> ----> ... 3 | |
| | | | |
| | | | |
| | | 5 v |
| | ----> next ------> YES | ------------
--------------- | ---> push(NodeTask) -----> | ReadyQueue |
| 4 | | ------------
| | |
--------------- --> Ready? - |
| GraphTask | | | 6 |
| | | | NO | 6
| | | ----> next.function |
| dependencies_ --> map<Node*, int> --> |
| | | |
| | | |
| | 6 v v
| not_ready_ ---------------------------------------------> map<Node*, InputBuffer>
| |
---------------
手机如下:
0x06 扫尾操作
在 thread_main 之中,如果本task已经结束,即做后续操作,具体代码如下。
代码语言:javascript复制auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
// 忽略前面代码
// Check if we've completed execution.
if (local_graph_task->completed()) { // 判断是否结束
// 如果结束了,就进行后续操作
local_graph_task->mark_as_completed_and_run_post_processing();
auto base_owner = local_graph_task->owner_;
// The current worker thread finish the graph_task, but the owning thread
// of the graph_task might be sleeping on pop() if it does not have work.
// So we need to send a dummy function task to the owning thread just to
// ensure that it's not sleeping, so that we can exit the thread_main.
// If it has work, it might see that graph_task->outstanding_tasks_ == 0
// before it gets to the task, but it's a no-op anyway.
//
// NB: This is not necessary if the current thread is the owning thread.
if (worker_device != base_owner) {
// Synchronize outstanding_tasks_ with queue mutex
std::atomic_thread_fence(std::memory_order_release);
ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
}
}
我们接下来分析这些扫尾工作。注意,这里是 thread_main 之中的扫尾工作。
6.1 判断结束
以下代码用来判断本 GraphTask是否结束,其实就是 ReadyQueue 之中是否还有待运行的 NodeTask。
outstanding_tasks_ 是待处理 NodeTask的数量,用来判断该GrapTask是否还需要执行,其数值总是先加再减,如果数目为0,则说明任务结束了。
- 当 GraphTask 被创建出来时候,此数值为0。
- 如果有一个NodeTask被送入到 ReadyQueue,则outstanding_tasks_ 增加 1。
- 如果在工作线程作执行一次 evaluate_function(task)后,outstanding_tasks的值减 1。
- 如果这个数量不为0,则此GraphTask依然需要运行。
bool GraphTask::completed() {
// outstanding_tasks在evaluate_function中可能会被改变
return outstanding_tasks_.load() == 0 ||
(exit_on_error_ && has_error_.load());
}
6.2 后续&通知
mark_as_completed_and_run_post_processing 就是进行后续处理。
执行后续操作 exec_post_processing,然后使用 future_result_->markCompleted 通知主线程。
代码语言:javascript复制void GraphTask::mark_as_completed_and_run_post_processing() {
// Allow only one thread one attempt to process this logic.
if (future_completed_.exchange(true)) {
// Future is already marked complete, or being marked as such.
// In case the marking complete is only in progress, we add a
// wait() to guarantee the future is marked complete on exit.
future_result_->wait();
return;
}
try {
// Run post processing, before marking the future as complete.
// Drop lock prior to completing, to avoid holding across callbacks.
std::unique_lock<std::mutex> lock(mutex_);
exec_post_processing(); // 进行后续操作
std::vector<Variable> vars = std::move(captured_vars_);
// Need to unlock before we call markCompleted to avoid holding locks
// when the callbacks are called.
lock.unlock();
future_result_->markCompleted(std::move(vars)); // 通知主线程
} catch (std::exception& e) {
future_result_->setErrorIfNeeded(std::current_exception());
}
}
6.2.1 后续操作
后续操作,如果之前有注册了 callback,则进行调用。也会进行流同步。
代码语言:javascript复制void GraphTask::exec_post_processing() {
if (!not_ready_.empty()) {
throw std::runtime_error("could not compute gradients for some functions");
}
// set the thread_local current_graph_task_ as more callbacks can be installed
// by existing final callbacks.
GraphTaskGuard guard(shared_from_this());
// Lock mutex during each iteration for accessing final_callbacks.size()
// Unlocking is necessary, because the callback can register
// more callbacks (or they can be registered from other threads
// while it's waiting.
std::unique_lock<std::mutex> cb_lock(final_callbacks_lock_);
// WARNING: Don't use a range-for loop here because more callbacks may be
// added in between callback calls, so iterators may become invalidated.
for (size_t i = 0; i < final_callbacks_.size(); i) {
cb_lock.unlock();
final_callbacks_[i]();
cb_lock.lock();
}
// Syncs leaf streams with default streams (if necessary)
// See note "Streaming backwards"
for (const auto& leaf_stream : leaf_streams) {
const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};
const auto default_stream = guard.getDefaultStream(leaf_stream.device());
if (leaf_stream != default_stream) {
auto event = c10::Event{c10::DeviceType::CUDA};
event.record(leaf_stream);
default_stream.wait(event);
}
}
}
6.2.2 通知主线程
之前在 execute 之中会用 fut->wait() 来等待任务完成。下面我们省略了部分代码。
代码语言:javascript复制auto Engine::execute(const edge_list& roots,
const variable_list& inputs,
bool keep_graph,
bool create_graph,
bool accumulate_grad,
const edge_list& outputs) -> variable_list {
// Queue the root
if (skip_dummy_node) {
execute_with_graph_task(graph_task, graph_root, std::move(input_buffer));
} else {
execute_with_graph_task(graph_task, graph_root, InputBuffer(variable_list()));
}
auto& fut = graph_task->future_result_;
fut->wait();
return fut->value().toTensorVector();
}
在 mark_as_completed_and_run_post_processing 会用如下代码来通知主线程。
代码语言:javascript复制future_result_->markCompleted(std::move(vars)); // 通知主线程
6.3 通知其他线程
如果这个task是来自其它work thread,即 worker_device != base_owner,则向那个worker thread的queue发送一个dummy function task,让那个工作线程也执行起来。
local_graph_task
表示我们从队列中检索的 graph_task
。外部graph_
任务表示我们需要执行的可重入执行的总体graph_任务。
在 thread_main 之中,有一个 work around。就是:当前工作线程完成 graph_task,但此时,拥有graph_task的线程可能正在pop()上等待休眠。因此,我们需要向所属线程发送一个仿造的函数任务,以唤醒它,这样我们可以退出thread_main。
这种情况发生在可重入反向传播的情形。
代码语言:javascript复制// If worker_device is any devices (i.e. CPU, CUDA): this is a re-entrant
// backward call from that device.
graph_task->owner_ = worker_device;
具体代码如下:
代码语言:javascript复制 // Check if we've completed execution.
if (local_graph_task->completed()) {
local_graph_task->mark_as_completed_and_run_post_processing();
auto base_owner = local_graph_task->owner_; // 当前设备
if (worker_device != base_owner) {
// 不是同一个设备
// Synchronize outstanding_tasks_ with queue mutex
std::atomic_thread_fence(std::memory_order_release);
ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
->push(NodeTask(local_graph_task, nullptr, InputBuffer(0))); // dummy task
}
}
其他线程当收到了 dummy task 之后,不会处理,因为 function 是 nullptr,然后就调用 local_ready_queue->pop() 继续从自己的queue 中读取下一个 task。
具体如下:
- 主线程等待。
- 如果工作线程发现GraphTask 已经结束,就通知主线程。
- 如果需要唤醒其他线程,就向该线程对应的 queue 插入 NodeTask。
- 对应线程取出 NodeTask 进行执行。
------------------------------------------------
| Worker Thread 1 |
| |
| thread_main{ |
| |
| mark_as_completed_and_run_post_processing |
2 markCompleted() | { |
------------------- |
| | } |
| | |
--------------- | | push(NodeTask) ----- |
| Main Thread | | | | |
| | | | } | |
| | | | | |
| | | ------------------------------------------------
| | | |
| | | 3 |
| | v v
| | ------- -------
| | 1 ---------------- | |
| | wait() | | | ReadyQueue |
| ------------> | future_result_ | | |
| | | | ------- -------
| | ---------------- |
| | |
| | 4 | pop(NodeTask)
| | |
| | v
| | -------- ---------------------
| | | Worker Thread 2 |
| | | |
| | | |
--------------- | |
| |
| |
------------------------------
至此,后向传播已经分析完毕,从下一篇开始,我们正式进入 PyTorch 分布式训练。
0xFF 参考
https://www.zhihu.com/column/gemfield
【PyTorch】聊聊 backward 背后的代码
pytorch笔记(计算图 autograd)-Node(1)
详解Pytorch中的网络构造
PyTorch的优化器
PyTorch的分布式
PyTorch的Tensor(下)
PyTorch的Tensor(中)
PyTorch的Tensor(上)
PyTorch的动态图(下)
PyTorch的动态图(上)
PyTorch Internals 5:Autograd的实现
A GENTLE INTRODUCTION TO TORCH.AUTOGRAD
PyTorch学习笔记(12)——PyTorch中的Autograd机制介绍
PyTorch 的 Autograd