Commit af2f5fc8 authored by Qiao Longfei

fix some bugs

Parent 5d5e0656
@@ -862,7 +862,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
     if (node->Op()->Type() == "fetch_barrier") {
       outvar_dev_id =
           GetVarDeviceID(*result, output->Name(), *sharded_var_device);
-      PADDLE_ENFORCE_NE(outvar_dev_id, -1);
+      PADDLE_ENFORCE_NE(outvar_dev_id, -1, "output name %s", output->Name());
     }
     p = places_[outvar_dev_id];
     ir::Node *new_node = nullptr;
...
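For context, the change above only enriches the failure message: PADDLE_ENFORCE_NE aborts when its two arguments compare equal, so naming the offending output turns an opaque abort into an actionable one. A minimal plain-Python sketch of the pattern (the enforce_ne helper and the variable name are illustrative, not Paddle's API):

def enforce_ne(a, b, msg_fmt, *args):
    # Mimics PADDLE_ENFORCE_NE(a, b, fmt, ...): fail when a == b,
    # appending the caller-supplied context to the message.
    if a == b:
        raise RuntimeError("Enforce failed: %r != %r. " % (a, b) + msg_fmt % args)

outvar_dev_id = -1  # unresolved device id, the case the check guards against
try:
    enforce_ne(outvar_dev_id, -1, "output name %s", "fetch_barrier_out")
except RuntimeError as e:
    print(e)  # Enforce failed: -1 != -1. output name fetch_barrier_out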
...
@@ -100,7 +100,7 @@ inline void SplitIdsIntoMultipleVarsBySection(
   }
 }
 
-inline void MergeMultipleVarsIntoOnBySection(
+inline void MergeMultipleVarsIntoOneBySection(
     const std::string& id_name, const std::string& out_name,
     const std::vector<std::string>& out_var_names,
     const std::vector<int64_t>& height_section,
@@ -125,25 +125,30 @@ inline void MergeMultipleVarsIntoOnBySection(
   for (size_t section_idx = 0; section_idx < out_var_names.size();
        ++section_idx) {
     auto& ids_in_this_section = splited_ids[section_idx];
-    auto& prefetch_out_var =
-        scope->Var(out_var_names[section_idx])->Get<framework::LoDTensor>();
-    const auto* out_var_data = prefetch_out_var.data<float>();
-    auto& dims = prefetch_out_var.dims();
+    if (!ids_in_this_section.empty()) {
+      auto& prefetch_out_var =
+          scope->Var(out_var_names[section_idx])->Get<framework::LoDTensor>();
+      const auto* out_var_data = prefetch_out_var.data<float>();
+      auto& dims = prefetch_out_var.dims();
 
-    PADDLE_ENFORCE_EQ(dims.size(), 2, "");
-    PADDLE_ENFORCE_EQ(ids_in_this_section.size(), dims[0]);
+      PADDLE_ENFORCE_EQ(dims.size(), 2, "");
+      PADDLE_ENFORCE_EQ(ids_in_this_section.size(), dims[0]);
 
-    auto row_numel = dims[1];
+      auto row_numel = dims[1];
 
-    for (size_t i = 0; i < dims[0]; ++i) {
-      auto id = ids_in_this_section[i];
-      auto origin_id = id + abs_sections[section_idx];
-      auto& offsets = id_to_offset[origin_id];
-      for (auto& offset : offsets) {
-        // should support GPU tensor
-        memory::Copy(cpu_place, out_tensor_data + offset * row_numel, cpu_place,
-                     out_var_data + i * row_numel, sizeof(float) * row_numel);
+      for (size_t i = 0; i < dims[0]; ++i) {
+        auto id = ids_in_this_section[i];
+        auto origin_id = id + abs_sections[section_idx];
+        auto& offsets = id_to_offset[origin_id];
+        for (auto& offset : offsets) {
+          // should support GPU tensor
+          memory::Copy(cpu_place, out_tensor_data + offset * row_numel,
+                       cpu_place, out_var_data + i * row_numel,
+                       sizeof(float) * row_numel);
+        }
       }
+    } else {
+      VLOG(30) << "ids in this section is empty";
     }
   }
 }
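The guard introduced above matters because a section that received no ids never gets an initialized prefetch output, so calling data<float>() or dims() on it would fail; empty sections are now logged and skipped. A toy Python sketch of the same control flow (variable names mirror the C++, the data is made up):

splited_ids = [[0, 3], [], [7]]  # ids routed to each parameter-server section

for section_idx, ids_in_this_section in enumerate(splited_ids):
    if ids_in_this_section:
        # The real code copies row i of the prefetched tensor to every
        # offset recorded for that id in id_to_offset.
        print("section %d: merging %d rows" % (section_idx, len(ids_in_this_section)))
    else:
        # Before the fix this case fell through to tensor accesses on an
        # uninitialized output variable.
        print("ids in section %d is empty, skip" % section_idx)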
@@ -190,13 +195,14 @@ void prefetch(const std::string& id_name, const std::string& out_name,
       VLOG(30) << "don't send no-initialied variable: " << out_var_names[i];
     }
   }
+
   for (size_t i = 0; i < rets.size(); i++) {
     PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
   }
 
-  MergeMultipleVarsIntoOnBySection(id_name, out_name, out_var_names,
-                                   height_sections, splited_ids, context,
-                                   &local_scope);
+  MergeMultipleVarsIntoOneBySection(id_name, out_name, out_var_names,
+                                    height_sections, splited_ids, context,
+                                    &local_scope);
 
   context.scope().DeleteScope(&local_scope);
 }
...
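prefetch() must block on every outstanding RPC handle before the merge runs, because the per-section outputs are only valid once their requests complete. A rough stand-in for that wait-then-merge ordering, using concurrent.futures instead of Paddle's RPCClient:

from concurrent.futures import ThreadPoolExecutor

with ThreadPoolExecutor(max_workers=3) as pool:
    # Stand-ins for the async prefetch requests; each returns a status flag.
    rets = [pool.submit(lambda: True) for _ in range(3)]
    # Equivalent of: PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient")
    for ret in rets:
        if not ret.result():
            raise RuntimeError("internal error in RPCClient")

print("all prefetch RPCs finished; safe to merge sections")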
@@ -444,7 +444,7 @@ class DistributeTranspiler(object):
                 # connect deps to send op in async mode
                 recv_dep_in = self.grad_name_to_send_dummy_out[
                     self.param_name_to_grad_name[param_varname]]
-            all_recv_outputs.extend(splited_var)
+
             # get recv op_role_var, if not splited, the grad should have .trainer suffix
             # if splited, grad should be the original grad var name. ParallelExecutor
             # will use op_role_var to get expected device place to run this op.
@@ -460,6 +460,7 @@ class DistributeTranspiler(object):
                 self._update_remote_sparse_update_op(param_varname,
                                                      height_sections, eps)
             else:
+                all_recv_outputs.extend(splited_var)
                 program.global_block().append_op(
                     type="recv",
                     inputs={"X": [recv_dep_in]},
...
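The transpiler fix moves all_recv_outputs.extend(splited_var) into the else branch: a parameter handled by the remote sparse-update path gets no local recv op, so its split vars must not be registered as recv outputs either. A toy sketch of that branching (the function and data are hypothetical, not transpiler code):

def collect_recv_outputs(params, sparse_params, splited_vars):
    all_recv_outputs = []
    for p in params:
        if p in sparse_params:
            pass  # handled by _update_remote_sparse_update_op; no recv op appended
        else:
            all_recv_outputs.extend(splited_vars[p])  # a recv op is appended here
    return all_recv_outputs

print(collect_recv_outputs(
    ["w", "emb"], {"emb"},
    {"w": ["w.block0", "w.block1"], "emb": ["emb.block0"]}))
# ['w.block0', 'w.block1']  -- emb's split vars are excluded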