diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
index ebd1d644bcce0554904ee0931827ef43a6d52b11..4f3889a71e865ce6f9799d976c6e6ec5c365429b 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -680,7 +680,8 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
   }
 
   if (node->Op()->Type() == "split_byref" ||
-      node->Op()->Type() == "split_selected_rows") {
+      node->Op()->Type() == "split_selected_rows" ||
+      node->Op()->Type() == "split_ids") {
     // TODO(paddle-dev): getting the first var is not safe.
     op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
     if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc
index 398f7095968e62f92d610f560d7574b27706d13e..cb22403de5e05a2f83438c854590376497fda918 100644
--- a/paddle/fluid/framework/ir/graph.cc
+++ b/paddle/fluid/framework/ir/graph.cc
@@ -69,6 +69,12 @@ bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
           std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
         return true;
       }
+
+      if (!(var.find(".block") == std::string::npos &&
+            var.find(".pserver") == std::string::npos) &&
+          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
+        return true;
+      }
     }
     return false;
   };
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index eab01de0acb4fb8df4d200a384fd220182e7e62d..443f0a95f3e3c0ed535333aa26ecc99dc99d64c1 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -1126,7 +1126,8 @@ to transpile() call.")
             inputs={
                 'Ids': [program.global_block().vars[table_grad_name]]
             },
-            outputs={"Out": self.trainer_side_table_grad_list})
+            outputs={"Out": self.trainer_side_table_grad_list},
+            attrs={RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE})
         program.global_block()._insert_op(
             index=op_index + 2,
             type="send",
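
Note on the `graph.cc` hunk: the added condition is the De Morgan form of "the variable name contains `.block` or `.pserver` and it appears in the RPC variable list", which lets the pass treat transpiler-generated table-gradient slices as distributed-training variables. Below is a minimal standalone C++ sketch of that equivalent check; the helper name `HasDistSuffixAndIsRpcVar` is hypothetical and only illustrates the logic that, in the patch, sits inside the checker lambda of `IsDistTrainOp`.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical standalone helper mirroring the condition added in graph.cc:
// true when `var` carries a transpiler-generated suffix (".block" or
// ".pserver") and is also listed among the RPC (send/recv) variables.
bool HasDistSuffixAndIsRpcVar(const std::string &var,
                              const std::vector<std::string> &rpc_vars) {
  const bool has_suffix = var.find(".block") != std::string::npos ||
                          var.find(".pserver") != std::string::npos;
  return has_suffix &&
         std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end();
}
```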