提交 527946df 编写于 作者: T tangwei12

add scope in prefetch

上级 b653ed05
...@@ -102,8 +102,9 @@ static void MergeMultipleVarsIntoOneBySection( ...@@ -102,8 +102,9 @@ static void MergeMultipleVarsIntoOneBySection(
const std::string& out_name, const std::vector<std::string>& out_var_names, const std::string& out_name, const std::vector<std::string>& out_var_names,
const std::vector<int>& height_section, const std::vector<int>& height_section,
const std::vector<std::vector<int64_t>>& splited_ids, const std::vector<std::vector<int64_t>>& splited_ids,
const framework::ExecutionContext& context, framework::Scope* scope, const framework::ExecutionContext& context,
platform::DeviceContext* actual_ctx) { const framework::Scope& actual_scope, framework::Scope* scope,
platform::DeviceContext* actual_ctx, ) {
PADDLE_ENFORCE_EQ(out_var_names.size(), height_section.size(), ""); PADDLE_ENFORCE_EQ(out_var_names.size(), height_section.size(), "");
auto cpu_place = platform::CPUPlace(); auto cpu_place = platform::CPUPlace();
...@@ -114,9 +115,9 @@ static void MergeMultipleVarsIntoOneBySection( ...@@ -114,9 +115,9 @@ static void MergeMultipleVarsIntoOneBySection(
id_to_offset[ids_vector[i]].push_back(i); id_to_offset[ids_vector[i]].push_back(i);
} }
auto& id_tensor = scope->FindVar(id_name)->Get<framework::LoDTensor>(); auto& id_tensor = actual_scope.FindVar(id_name)->Get<framework::LoDTensor>();
auto* out_tensor = auto* out_tensor =
scope->FindVar(out_name)->GetMutable<framework::LoDTensor>(); actual_scope.FindVar(out_name)->GetMutable<framework::LoDTensor>();
auto* out_tensor_data = out_tensor->mutable_data<float>(id_tensor.place()); auto* out_tensor_data = out_tensor->mutable_data<float>(id_tensor.place());
bool is_on_cpu_place = true; bool is_on_cpu_place = true;
...@@ -172,8 +173,9 @@ void prefetch(const std::string& id_name, const std::string& out_name, ...@@ -172,8 +173,9 @@ void prefetch(const std::string& id_name, const std::string& out_name,
const std::vector<std::string>& table_names, const std::vector<std::string>& table_names,
const std::vector<std::string>& epmap, const std::vector<std::string>& epmap,
const std::vector<int>& height_sections, const std::vector<int>& height_sections,
const framework::ExecutionContext& context) { const framework::ExecutionContext& context,
auto& local_scope = context.scope().NewScope(); const framework::Scope& scope) {
auto& local_scope = scope.NewScope();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& cpu_ctx = *pool.Get(platform::CPUPlace()); auto& cpu_ctx = *pool.Get(platform::CPUPlace());
...@@ -245,9 +247,8 @@ void prefetch(const std::string& id_name, const std::string& out_name, ...@@ -245,9 +247,8 @@ void prefetch(const std::string& id_name, const std::string& out_name,
MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name, MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name,
out_var_names, height_sections, splited_ids, out_var_names, height_sections, splited_ids,
context, &local_scope, &actual_ctx); context, scope, &local_scope, &actual_ctx);
scope.DeleteScope(&local_scope);
context.scope().DeleteScope(&local_scope);
} }
}; // namespace distributed }; // namespace distributed
......
...@@ -59,7 +59,8 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -59,7 +59,8 @@ class LookupTableKernel : public framework::OpKernel<T> {
// server // server
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
operators::distributed::prefetch(id_name, out_name, table_names, epmap, operators::distributed::prefetch(id_name, out_name, table_names, epmap,
height_sections, context); height_sections, context,
context.scope());
#else #else
PADDLE_THROW( PADDLE_THROW(
"paddle is not compiled with distribute support, can not do " "paddle is not compiled with distribute support, can not do "
......
...@@ -170,18 +170,31 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -170,18 +170,31 @@ class NCEKernel : public framework::OpKernel<T> {
auto height_sections = context.Attr<std::vector<int>>("height_sections"); auto height_sections = context.Attr<std::vector<int>>("height_sections");
auto table_names = context.Attr<std::vector<std::string>>("table_names"); auto table_names = context.Attr<std::vector<std::string>>("table_names");
local_scope.Var("Ids"); auto *ids = local_scope.Var("Ids");
local_scope.Var("Weight"); auto *x_tensor = ids->GetMutable<framework::LoDTensor>();
x_tensor->mutable_data<int64_t>(
framework::make_ddim({static_cast<int64_t>(labels.size()), 1}),
context.GetPlace());
// copy.
std::memcpy(x_tensor->data<int64_t>(), labels.data(),
labels.size() * sizeof(int64_t));
local_scope.Var("Weight@Local")
->GetMutable<framework::Tensor>()
->mutable_data<T>(context.GetPlace());
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
operators::distributed::prefetch("Ids", "Weight", table_names, epmap, operators::distributed::prefetch("Ids", "Weight@Local", table_names,
height_sections, context); epmap, height_sections, context,
&local_scope);
#else #else
PADDLE_THROW( PADDLE_THROW(
"paddle is not compiled with distribute support, can not do " "paddle is not compiled with distribute support, can not do "
"parameter prefetch!"); "parameter prefetch!");
#endif
auto weight_mat = EigenMatrix<T>::From(*(weight->Get<T>())); auto weight_mat = EigenMatrix<T>::From(
(local_scope.Var("Weight@Local")->Get<framework::Tensor>()));
for (int64_t i = 0; i < sample_labels->numel(); ++i) { for (int64_t i = 0; i < sample_labels->numel(); ++i) {
std::vector<int64_t>::iterator it = std::vector<int64_t>::iterator it =
std::find(labels.begin(), labels.end(), sample_labels_data[i]); std::find(labels.begin(), labels.end(), sample_labels_data[i]);
...@@ -196,7 +209,7 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -196,7 +209,7 @@ class NCEKernel : public framework::OpKernel<T> {
} }
context.scope().DeleteScope(&local_scope); context.scope().DeleteScope(&local_scope);
#endif
} else { } else {
auto weight_mat = auto weight_mat =
EigenMatrix<T>::From(*(context.Input<Tensor>("Weight"))); EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册