Commit af2f5fc8 authored by Qiao Longfei

fix some bugs

Parent 5d5e0656
@@ -862,7 +862,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
     if (node->Op()->Type() == "fetch_barrier") {
       outvar_dev_id =
           GetVarDeviceID(*result, output->Name(), *sharded_var_device);
-      PADDLE_ENFORCE_NE(outvar_dev_id, -1);
+      PADDLE_ENFORCE_NE(outvar_dev_id, -1, "output name %s", output->Name());
     }
     p = places_[outvar_dev_id];
     ir::Node *new_node = nullptr;
...
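The change above only adds context to the failure message: when GetVarDeviceID cannot resolve a device for the output, the check now reports which output name triggered it instead of a bare "-1 != -1". A minimal standalone sketch of that idea, not Paddle's actual PADDLE_ENFORCE_NE macro; the helper name, error text, and the sample variable name below are made up for illustration:

// Minimal sketch (not Paddle's real PADDLE_ENFORCE_NE): a failed check should
// say *which* variable broke it, not just that some device id stayed at -1.
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>

// Hypothetical helper: throws if dev_id was never resolved, embedding the
// variable name in the error text so the log pinpoints the bad output.
inline void EnforceDeviceResolved(int dev_id, const std::string& var_name) {
  if (dev_id == -1) {
    std::ostringstream msg;
    msg << "device id of output var '" << var_name << "' was not assigned";
    throw std::runtime_error(msg.str());
  }
}

int main() {
  try {
    EnforceDeviceResolved(-1, "fc_0.w_0@GRAD");  // illustrative variable name
  } catch (const std::exception& e) {
    std::cerr << e.what() << "\n";  // message now names the offending output
  }
}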
@@ -100,7 +100,7 @@ inline void SplitIdsIntoMultipleVarsBySection(
   }
 }
 
-inline void MergeMultipleVarsIntoOnBySection(
+inline void MergeMultipleVarsIntoOneBySection(
     const std::string& id_name, const std::string& out_name,
     const std::vector<std::string>& out_var_names,
     const std::vector<int64_t>& height_section,
@@ -125,6 +125,7 @@ inline void MergeMultipleVarsIntoOnBySection(
   for (size_t section_idx = 0; section_idx < out_var_names.size();
        ++section_idx) {
     auto& ids_in_this_section = splited_ids[section_idx];
+    if (!ids_in_this_section.empty()) {
     auto& prefetch_out_var =
         scope->Var(out_var_names[section_idx])->Get<framework::LoDTensor>();
     const auto* out_var_data = prefetch_out_var.data<float>();
@@ -141,10 +142,14 @@ inline void MergeMultipleVarsIntoOnBySection(
       auto& offsets = id_to_offset[origin_id];
       for (auto& offset : offsets) {
         // should support GPU tensor
-        memory::Copy(cpu_place, out_tensor_data + offset * row_numel, cpu_place,
-                     out_var_data + i * row_numel, sizeof(float) * row_numel);
+        memory::Copy(cpu_place, out_tensor_data + offset * row_numel,
+                     cpu_place, out_var_data + i * row_numel,
+                     sizeof(float) * row_numel);
       }
     }
+    } else {
+      VLOG(30) << "ids in this section is empty";
+    }
   }
 }
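Read together, the two hunks above guard the merge step against sections that received no ids (previously an empty section would still be dereferenced) and reflow the long memory::Copy call. A self-contained sketch of the merge-by-section pattern with the same empty-section guard, assuming plain CPU float buffers in place of Scope/LoDTensor; all names below are illustrative, not Paddle's API:

// Standalone sketch of merging split prefetch results back by section.
#include <cstring>
#include <iostream>
#include <unordered_map>
#include <vector>

// For each section, copy the fetched rows back into the output buffer at the
// offsets where the original ids appeared; sections that got no ids are
// skipped, which is exactly what the new empty-section guard protects.
void MergeBySection(
    const std::vector<std::vector<int64_t>>& splited_ids,      // ids per section
    const std::vector<std::vector<float>>& section_out,        // fetched rows per section
    const std::unordered_map<int64_t, std::vector<size_t>>& id_to_offset,
    size_t row_numel, std::vector<float>* out) {
  for (size_t section_idx = 0; section_idx < splited_ids.size(); ++section_idx) {
    const auto& ids = splited_ids[section_idx];
    if (ids.empty()) continue;  // nothing was prefetched for this section
    const float* src = section_out[section_idx].data();
    for (size_t i = 0; i < ids.size(); ++i) {
      for (size_t offset : id_to_offset.at(ids[i])) {
        std::memcpy(out->data() + offset * row_numel, src + i * row_numel,
                    sizeof(float) * row_numel);
      }
    }
  }
}

int main() {
  // Two sections; the second one received no ids, so it is skipped.
  std::vector<std::vector<int64_t>> splited_ids = {{3, 5}, {}};
  std::vector<std::vector<float>> section_out = {{1.f, 1.f, 2.f, 2.f}, {}};
  std::unordered_map<int64_t, std::vector<size_t>> id_to_offset = {{3, {0}}, {5, {1}}};
  std::vector<float> out(4, 0.f);
  MergeBySection(splited_ids, section_out, id_to_offset, /*row_numel=*/2, &out);
  for (float v : out) std::cout << v << " ";  // prints: 1 1 2 2
  std::cout << "\n";
}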
@@ -190,11 +195,12 @@ void prefetch(const std::string& id_name, const std::string& out_name,
       VLOG(30) << "don't send no-initialied variable: " << out_var_names[i];
     }
   }
 
   for (size_t i = 0; i < rets.size(); i++) {
     PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
   }
 
-  MergeMultipleVarsIntoOnBySection(id_name, out_name, out_var_names,
-                                   height_sections, splited_ids, context,
-                                   &local_scope);
+  MergeMultipleVarsIntoOneBySection(id_name, out_name, out_var_names,
+                                    height_sections, splited_ids, context,
+                                    &local_scope);
...
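The call site above is updated to the corrected MergeMultipleVarsIntoOneBySection name, and merging only happens after every RPC handle in rets has been waited on. A rough sketch of that wait-then-merge ordering, using std::future as a stand-in for Paddle's RPC client handles; the structure here is an assumption for illustration, not the real RPCClient API:

// Sketch of the wait-then-merge pattern: block on every outstanding fetch
// before touching the fetched buffers, mirroring the PADDLE_ENFORCE loop above.
#include <future>
#include <iostream>
#include <stdexcept>
#include <vector>

int main() {
  // Pretend each future is one outstanding "get rows from a parameter server"
  // call; a false result plays the role of an RPC client failure.
  std::vector<std::future<bool>> rets;
  for (int section = 0; section < 3; ++section) {
    rets.push_back(std::async(std::launch::async, [] { return true; }));
  }

  // Wait on every request; only then is it safe to merge results by section.
  for (auto& ret : rets) {
    if (!ret.get()) throw std::runtime_error("internal error in RPCClient");
  }
  std::cout << "all sections fetched; safe to merge by section\n";
}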
@@ -444,7 +444,7 @@ class DistributeTranspiler(object):
             # connect deps to send op in async mode
             recv_dep_in = self.grad_name_to_send_dummy_out[
                 self.param_name_to_grad_name[param_varname]]
-            all_recv_outputs.extend(splited_var)
             # get recv op_role_var, if not splited, the grad should have .trainer suffix
             # if splited, grad should be the original grad var name. ParallelExecutor
             # will use op_role_var to get expected device place to run this op.
@@ -460,6 +460,7 @@ class DistributeTranspiler(object):
                 self._update_remote_sparse_update_op(param_varname,
                                                      height_sections, eps)
             else:
+                all_recv_outputs.extend(splited_var)
                 program.global_block().append_op(
                     type="recv",
                     inputs={"X": [recv_dep_in]},
...