提交 fbadd4b6 编写于 作者: Q Qiao Longfei

follow comment test=develop

上级 abf17226
...@@ -133,15 +133,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -133,15 +133,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
void AppendMultiDevPass(const BuildStrategy &strategy) { void AppendMultiDevPass(const BuildStrategy &strategy) {
ir::Pass *multi_devices_pass; ir::Pass *multi_devices_pass;
if (strategy_.is_distribution_) { if (strategy_.is_distribution_) {
VLOG(3) << "dist train mode"; VLOG(3) << "multi device dist train mode";
multi_devices_pass = AppendPass("dist_multi_devices_pass").get(); multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else { } else {
if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
VLOG(3) << "allreduce mode"; VLOG(3) << "multi device allreduce mode";
multi_devices_pass = multi_devices_pass =
AppendPass("allreduce_mode_multi_devices_pass").get(); AppendPass("allreduce_mode_multi_devices_pass").get();
} else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) { } else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
VLOG(3) << "reduce mode"; VLOG(3) << "multi device reduce mode";
multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get(); multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
} else { } else {
PADDLE_THROW("Unknown reduce strategy."); PADDLE_THROW("Unknown reduce strategy.");
......
...@@ -731,7 +731,6 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result, ...@@ -731,7 +731,6 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
} }
} }
insert_op = true; insert_op = true;
need_broadcast_var_ = true;
} else if (OpHaveRole(*node, OpRole::kDist)) { } else if (OpHaveRole(*node, OpRole::kDist)) {
int op_dev_id = CreateDistTrainOp(result, node); int op_dev_id = CreateDistTrainOp(result, node);
if (node->Op()->Type() == "concat") { if (node->Op()->Type() == "concat") {
...@@ -925,6 +924,7 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, ...@@ -925,6 +924,7 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
} }
void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const { void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const {
// only GPU reduce mode need to broadcast parameters to each device.
if (UseGPU() && strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) { if (UseGPU() && strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
if (strategy_.fuse_broadcast_op_) { if (strategy_.fuse_broadcast_op_) {
CreateFusedBroadcastOp(result, bcast_var_name_set_); CreateFusedBroadcastOp(result, bcast_var_name_set_);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册