提交 3bccc1e6 编写于 作者: Qiao Longfei

optimize broadcast logic test=develop

上级 62f1248f
...@@ -925,10 +925,13 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, ...@@ -925,10 +925,13 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
} }
void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const { void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const {
// only GPU reduce mode need to broadcast parameters to each device. // broad cast received parameters when training in parameter server mode.
if (UseGPU()) { if (need_broadcast_var_) {
if (need_broadcast_var_ || // cpu reduce mode did not need to broadcast received parameters.
if (!UseGPU() &&
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) { strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
return;
}
if (strategy_.fuse_broadcast_op_) { if (strategy_.fuse_broadcast_op_) {
CreateFusedBroadcastOp(result, bcast_var_name_set_); CreateFusedBroadcastOp(result, bcast_var_name_set_);
} else { } else {
...@@ -940,7 +943,6 @@ void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const { ...@@ -940,7 +943,6 @@ void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const {
} }
} }
} }
}
} }
std::unordered_set<std::string> &MultiDevSSAGraphBuilder() { std::unordered_set<std::string> &MultiDevSSAGraphBuilder() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册