Commit de5e56be authored by chengduoZH

add og has been broadcasted

Parent 95658767
@@ -55,6 +55,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
   auto graph = new SSAGraph();
   SSAGraph &result = *graph;
+  std::unordered_set<std::string> og_has_bc;
   result.vars_.resize(places_.size());

   bool is_forwarding = true;
@@ -123,8 +124,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       if (!is_forwarding) {
         auto var_names = op->OutputArgumentNames();
         for (auto &og : var_names) {
-          if (grad_names_.count(og) != 0) {  // is param grad
-            // Insert NCCL AllReduce Op
+          if (grad_names_.count(og) != 0 &&
+              og_has_bc.count(og) == 0) {  // is param grad
+            // Insert NCCL AllReduce Op
+            og_has_bc.insert(og);
 #ifdef PADDLE_WITH_CUDA
             result.ops_.emplace_back(
                 new NCCLAllReduceOpHandle(local_scopes_, places_, *nccl_ctxs_));
...
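For context, the change guards the NCCL AllReduce insertion with a set so that each parameter gradient gets its AllReduce op inserted only once, even if the same gradient name appears in the outputs of more than one backward op. Below is a minimal standalone sketch of that dedup pattern, not the actual Paddle graph-builder code; `InsertAllReduce`, the example gradient names, and the contents of `var_names` are hypothetical stand-ins.

```cpp
// Sketch of the dedup pattern from the diff: og_has_bc remembers which
// output gradients (og) already received an AllReduce, so a gradient that
// shows up twice is reduced only once.
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for inserting an NCCLAllReduceOpHandle into the graph.
void InsertAllReduce(const std::string &og) {
  std::cout << "insert NCCL AllReduce for " << og << "\n";
}

int main() {
  // Gradients that correspond to trainable parameters.
  std::unordered_set<std::string> grad_names_ = {"w@GRAD", "b@GRAD"};
  // Gradients that already have an AllReduce op inserted.
  std::unordered_set<std::string> og_has_bc;

  // Output argument names gathered from backward ops; "w@GRAD" appears twice
  // but must be reduced only once.
  std::vector<std::string> var_names = {"w@GRAD", "tmp_0", "w@GRAD", "b@GRAD"};

  for (auto &og : var_names) {
    if (grad_names_.count(og) != 0 && og_has_bc.count(og) == 0) {
      og_has_bc.insert(og);
      InsertAllReduce(og);  // runs exactly once per parameter gradient
    }
  }
  return 0;
}
```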