From f28c25845330cf47250f7f6cba67f6f4cdaae97d Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Tue, 5 Mar 2019 17:10:17 +0800 Subject: [PATCH] code clean test=develop --- .../framework/details/multi_devices_graph_pass.cc | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 109037c3e..c8e9c5d68 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -167,10 +167,6 @@ std::unique_ptr MultiDevSSAGraphBuilderBase::ApplyImpl( bool is_forwarding = true; bool insert_collection_ops = NeedCollectiveOps(); - if (strategy_.async_mode_) { - // async mode did not need to merge gradient - insert_collection_ops = false; - } for (ir::Node *node : sorted_ops) { if (DealWithSpecialOp(&result, node)) { @@ -749,10 +745,6 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result, ir::Node *node) const { bool insert_op = false; if (OpHaveRole(*node, OpRole::kRPC)) { - // in async_mode, each graph will send it's own gradient. - if (strategy_.async_mode_ && node->Op()->Type() == "send") { - return false; - } int op_dev_id = CreateRPCOp(result, node); PADDLE_ENFORCE(op_dev_id != -1, "Can not schedule the RPC operator to the right place."); @@ -768,11 +760,6 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result, insert_op = true; need_broadcast_var_ = true; } else if (OpHaveRole(*node, OpRole::kDist)) { - // in async_mode, each graph will send it's own gradient, do not need to - // merge gradient. - if (strategy_.async_mode_ && node->Op()->Type() != "concat") { - return false; - } int op_dev_id = CreateDistTrainOp(result, node); if (node->Op()->Type() == "concat") { // the input(block of parameter) of concat is on different device, @@ -844,7 +831,7 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const { } auto recv_param_grad = boost::get>( node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); - if (recv_param_grad.size() == 2U && !strategy_.async_mode_) { + if (recv_param_grad.size() == 2U) { op_dev_id = GetVarDeviceID(recv_param_grad[1]); VLOG(10) << "recv param " << recv_param_grad[0] << " get grad place: " << recv_param_grad[1] -- GitLab