Unverified commit b94df7f4, authored by Juncheng, committed by GitHub

Explicitly specify the SBP in NonDistributedOptimizerPass (#3937)

Parent 6e71b007
......@@ -146,6 +146,17 @@ Maybe<void> NonDistributedOptimizerPass::Apply(const OpGraph& op_graph, JobBuild
   for (const OpNode* node : last_node7node_seqs.second) {
     builder->MutParallelConfOnlyOnce(node->op().op_name(), parallel_conf);
   }
+  const LogicalBlobId last_node_out_lbi = last_node->op().BnInOp2Lbi(last_node->op().SoleObn());
+  last_node->ForEachNodeOnOutEdge([&](const OpNode* dst_node) {
+    for (const std::string& ibn : dst_node->op().input_bns()) {
+      if (dst_node->op().BnInOp2Lbi(ibn) == last_node_out_lbi) {
+        OpBlobArg op_blob_arg;
+        op_blob_arg.set_op_name(dst_node->op().op_name());
+        op_blob_arg.set_bn_in_op(ibn);
+        builder->MutSbpParallel4Oba(op_blob_arg)->mutable_broadcast_parallel();
+      }
+    }
+  });
   ParallelDesc new_pd(parallel_conf);
   pd2last_nodes[new_pd].push_back(last_node);
 }
......
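The added lines walk every consumer of the last node's sole output and, for each input blob name that refers to that output, request a broadcast SBP. The sketch below restates that matching logic in isolation. It is only an illustration under stated assumptions: `ToyOp`, `Sbp`, `SbpHints`, and `PinBroadcastOnConsumers` are hypothetical stand-ins, logical blob ids are modelled as plain strings, and none of this is OneFlow's actual API.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an SBP signature entry (not OneFlow's type).
enum class Sbp { kUnset, kSplit, kBroadcast };

// Hypothetical stand-in for an op: a name plus a map from each
// input blob name (ibn) to the logical blob id it reads.
struct ToyOp {
  std::string name;
  std::map<std::string, std::string> ibn2lbi;
};

// Key: (op name, input blob name). Value: the SBP requested for that input.
using SbpHints = std::map<std::pair<std::string, std::string>, Sbp>;

// For every consumer, mark each input that reads `producer_out_lbi`
// as broadcast. Inputs that read other blobs are left untouched.
SbpHints PinBroadcastOnConsumers(const std::string& producer_out_lbi,
                                 const std::vector<ToyOp>& consumers) {
  SbpHints hints;
  for (const ToyOp& op : consumers) {
    for (const auto& kv : op.ibn2lbi) {
      if (kv.second == producer_out_lbi) {
        hints[std::make_pair(op.name, kv.first)] = Sbp::kBroadcast;
      }
    }
  }
  return hints;
}
```

Calling `PinBroadcastOnConsumers` with the producer's output id and its consumer list returns one hint per matching input. An op that consumes the same blob through two different input names would receive two hints, which mirrors the per-`ibn` loop in the diff.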