未验证 提交 6e0e7061 编写于 作者: 乔龙飞 Qiao Longfei 提交者: GitHub

Revert "cpu reduce mode did not need to broadcast params test=develop"

上级 0a3fd011
......@@ -133,15 +133,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
void AppendMultiDevPass(const BuildStrategy &strategy) {
ir::Pass *multi_devices_pass;
if (strategy_.is_distribution_) {
VLOG(3) << "multi device dist train mode";
multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else {
if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
VLOG(3) << "multi device allreduce mode";
multi_devices_pass =
AppendPass("allreduce_mode_multi_devices_pass").get();
} else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
VLOG(3) << "multi device reduce mode";
multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
} else {
PADDLE_THROW("Unknown reduce strategy.");
......
......@@ -731,6 +731,7 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
}
}
insert_op = true;
need_broadcast_var_ = true;
} else if (OpHaveRole(*node, OpRole::kDist)) {
int op_dev_id = CreateDistTrainOp(result, node);
if (node->Op()->Type() == "concat") {
......@@ -924,8 +925,9 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
}
void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const {
// Only GPU reduce mode needs to broadcast parameters to each device.
if (UseGPU() && strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
if (need_broadcast_var_ ||
(UseGPU() &&
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce)) {
if (strategy_.fuse_broadcast_op_) {
CreateFusedBroadcastOp(result, bcast_var_name_set_);
} else {
......
......@@ -174,6 +174,7 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;
mutable bool need_broadcast_var_{false};
};
std::unordered_set<std::string> &MultiDevSSAGraphBuilder();
......
......@@ -19,7 +19,6 @@ import sys
from .. import compat as cpt
from . import core
from . import framework
__all__ = ['CompiledProgram', 'ExecutionStrategy', 'BuildStrategy']
......@@ -35,15 +34,6 @@ def _place_obj(place):
return p
def _is_pserver_mode(main_program):
main = main_program if main_program \
else framework.default_main_program()
for op in main.global_block().ops:
if op.type in ["send", "recv"]:
return True
return False
class CompiledProgram(object):
"""
Compiles a Program for execution.
......@@ -120,7 +110,6 @@ class CompiledProgram(object):
self._exec_strategy = ExecutionStrategy()
if self._build_strategy is None:
self._build_strategy = BuildStrategy()
self._build_strategy.is_distribution = _is_pserver_mode(self._program)
return self
def with_inference_optimize(self, config):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册