提交 631aa3d1 编写于 作者: Y Yu Yang

Wait for all inputs to be ready

上级 9b1f4d5d
...@@ -375,6 +375,12 @@ struct NCCLAllReduceOpHandle : public OpHandle { ...@@ -375,6 +375,12 @@ struct NCCLAllReduceOpHandle : public OpHandle {
if (this->inputs_.size() == 1) { if (this->inputs_.size() == 1) {
return; // No need to all reduce when GPU count = 1; return; // No need to all reduce when GPU count = 1;
} else { } else {
// Wait input done
for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_;
in->generated_op_->Wait(dev_ctx_[p]);
}
auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_; auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
VLOG(3) << "Invoke NCCL AllReduce"; VLOG(3) << "Invoke NCCL AllReduce";
int dtype = -1; int dtype = -1;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册