提交 f28c2584 编写于 作者: Q Qiao Longfei

code clean test=develop

上级 e92ad8a2
...@@ -167,10 +167,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl( ...@@ -167,10 +167,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
bool is_forwarding = true; bool is_forwarding = true;
bool insert_collection_ops = NeedCollectiveOps(); bool insert_collection_ops = NeedCollectiveOps();
if (strategy_.async_mode_) {
// async mode did not need to merge gradient
insert_collection_ops = false;
}
for (ir::Node *node : sorted_ops) { for (ir::Node *node : sorted_ops) {
if (DealWithSpecialOp(&result, node)) { if (DealWithSpecialOp(&result, node)) {
...@@ -749,10 +745,6 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result, ...@@ -749,10 +745,6 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
ir::Node *node) const { ir::Node *node) const {
bool insert_op = false; bool insert_op = false;
if (OpHaveRole(*node, OpRole::kRPC)) { if (OpHaveRole(*node, OpRole::kRPC)) {
// in async_mode, each graph will send it's own gradient.
if (strategy_.async_mode_ && node->Op()->Type() == "send") {
return false;
}
int op_dev_id = CreateRPCOp(result, node); int op_dev_id = CreateRPCOp(result, node);
PADDLE_ENFORCE(op_dev_id != -1, PADDLE_ENFORCE(op_dev_id != -1,
"Can not schedule the RPC operator to the right place."); "Can not schedule the RPC operator to the right place.");
...@@ -768,11 +760,6 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result, ...@@ -768,11 +760,6 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
insert_op = true; insert_op = true;
need_broadcast_var_ = true; need_broadcast_var_ = true;
} else if (OpHaveRole(*node, OpRole::kDist)) { } else if (OpHaveRole(*node, OpRole::kDist)) {
// in async_mode, each graph will send it's own gradient, do not need to
// merge gradient.
if (strategy_.async_mode_ && node->Op()->Type() != "concat") {
return false;
}
int op_dev_id = CreateDistTrainOp(result, node); int op_dev_id = CreateDistTrainOp(result, node);
if (node->Op()->Type() == "concat") { if (node->Op()->Type() == "concat") {
// the input(block of parameter) of concat is on different device, // the input(block of parameter) of concat is on different device,
...@@ -844,7 +831,7 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const { ...@@ -844,7 +831,7 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
} }
auto recv_param_grad = boost::get<std::vector<std::string>>( auto recv_param_grad = boost::get<std::vector<std::string>>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
if (recv_param_grad.size() == 2U && !strategy_.async_mode_) { if (recv_param_grad.size() == 2U) {
op_dev_id = GetVarDeviceID(recv_param_grad[1]); op_dev_id = GetVarDeviceID(recv_param_grad[1]);
VLOG(10) << "recv param " << recv_param_grad[0] VLOG(10) << "recv param " << recv_param_grad[0]
<< " get grad place: " << recv_param_grad[1] << " get grad place: " << recv_param_grad[1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册