提交 f1a3fb04 编写于 作者: Q Qiao Longfei

Merge branch 'fix_lookuptable_in_reduce' of...

Merge branch 'fix_lookuptable_in_reduce' of https://github.com/seiriosPlus/Paddle into cpu-for-1.1-merge
......@@ -680,7 +680,8 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
}
if (node->Op()->Type() == "split_byref" ||
node->Op()->Type() == "split_selected_rows") {
node->Op()->Type() == "split_selected_rows" ||
node->Op()->Type() == "split_ids") {
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
......
......@@ -69,6 +69,12 @@ bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
return true;
}
if (!(var.find(".block") == std::string::npos &&
var.find(".pserver") == std::string::npos) &&
std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
return true;
}
}
return false;
};
......
......@@ -1126,7 +1126,8 @@ to transpile() call.")
inputs={
'Ids': [program.global_block().vars[table_grad_name]]
},
outputs={"Out": self.trainer_side_table_grad_list})
outputs={"Out": self.trainer_side_table_grad_list},
attrs={RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE})
program.global_block()._insert_op(
index=op_index + 2,
type="send",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册