提交 f1a3fb04 编写于 作者: Q Qiao Longfei

Merge branch 'fix_lookuptable_in_reduce' of...

Merge branch 'fix_lookuptable_in_reduce' of https://github.com/seiriosPlus/Paddle into cpu-for-1.1-merge
...@@ -680,7 +680,8 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, ...@@ -680,7 +680,8 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
} }
if (node->Op()->Type() == "split_byref" || if (node->Op()->Type() == "split_byref" ||
node->Op()->Type() == "split_selected_rows") { node->Op()->Type() == "split_selected_rows" ||
node->Op()->Type() == "split_ids") {
// TODO(paddle-dev): getting the first var is not safe. // TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(*result, input_var_names[0]); op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
......
...@@ -69,6 +69,12 @@ bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars, ...@@ -69,6 +69,12 @@ bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) { std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
return true; return true;
} }
if (!(var.find(".block") == std::string::npos &&
var.find(".pserver") == std::string::npos) &&
std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
return true;
}
} }
return false; return false;
}; };
......
...@@ -1126,7 +1126,8 @@ to transpile() call.") ...@@ -1126,7 +1126,8 @@ to transpile() call.")
inputs={ inputs={
'Ids': [program.global_block().vars[table_grad_name]] 'Ids': [program.global_block().vars[table_grad_name]]
}, },
outputs={"Out": self.trainer_side_table_grad_list}) outputs={"Out": self.trainer_side_table_grad_list},
attrs={RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE})
program.global_block()._insert_op( program.global_block()._insert_op(
index=op_index + 2, index=op_index + 2,
type="send", type="send",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册