提交 5856c2f3 编写于 作者: Q Qiao Longfei

change Var to FindVar

上级 312b7786
...@@ -63,7 +63,7 @@ inline std::vector<int64_t> ToAbsoluteSection( ...@@ -63,7 +63,7 @@ inline std::vector<int64_t> ToAbsoluteSection(
inline std::vector<std::vector<int64_t>> SplitIds( inline std::vector<std::vector<int64_t>> SplitIds(
const std::string& id_name, const std::vector<int64_t>& height_section, const std::string& id_name, const std::vector<int64_t>& height_section,
framework::Scope* scope) { framework::Scope* scope) {
auto& id_tensor = scope->Var(id_name)->Get<framework::LoDTensor>(); auto& id_tensor = scope->FindVar(id_name)->Get<framework::LoDTensor>();
auto* id_data = id_tensor.data<int64_t>(); auto* id_data = id_tensor.data<int64_t>();
std::set<int64_t> all_ids; std::set<int64_t> all_ids;
for (size_t i = 0; i < id_tensor.numel(); ++i) { for (size_t i = 0; i < id_tensor.numel(); ++i) {
...@@ -111,14 +111,15 @@ inline void MergeMultipleVarsIntoOnBySection( ...@@ -111,14 +111,15 @@ inline void MergeMultipleVarsIntoOnBySection(
auto cpu_place = platform::CPUPlace(); auto cpu_place = platform::CPUPlace();
auto abs_sections = ToAbsoluteSection(height_section); auto abs_sections = ToAbsoluteSection(height_section);
auto& id_tensor = scope->Var(id_name)->Get<framework::LoDTensor>(); auto& id_tensor = scope->FindVar(id_name)->Get<framework::LoDTensor>();
auto* id_data = id_tensor.data<int64_t>(); auto* id_data = id_tensor.data<int64_t>();
std::unordered_map<int64_t, std::vector<size_t>> id_to_offset; std::unordered_map<int64_t, std::vector<size_t>> id_to_offset;
for (size_t i = 0; i < id_tensor.numel(); ++i) { for (size_t i = 0; i < id_tensor.numel(); ++i) {
id_to_offset[id_data[i]].push_back(i); id_to_offset[id_data[i]].push_back(i);
} }
auto* out_tensor = scope->Var(out_name)->GetMutable<framework::LoDTensor>(); auto* out_tensor =
scope->FindVar(out_name)->GetMutable<framework::LoDTensor>();
auto* out_tensor_data = out_tensor->mutable_data<float>(context.GetPlace()); auto* out_tensor_data = out_tensor->mutable_data<float>(context.GetPlace());
for (size_t section_idx = 0; section_idx < out_var_names.size(); for (size_t section_idx = 0; section_idx < out_var_names.size();
......
...@@ -56,7 +56,8 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -56,7 +56,8 @@ class LookupTableKernel : public framework::OpKernel<T> {
context.Attr<std::vector<int64_t>>("height_sections"); context.Attr<std::vector<int64_t>>("height_sections");
if (remote_prefetch) { if (remote_prefetch) {
// if emap is not empty, then the paramter will be fetched from remote parameter // if emap is not empty, then the parameter will be fetched from remote
// parameter
// server // server
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
operators::distributed::prefetch(id_name, out_name, table_name, epmap, operators::distributed::prefetch(id_name, out_name, table_name, epmap,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册