提交 bb2e7f0b 编写于 作者: T tangwei12

add scope in prefetch

上级 527946df
...@@ -104,7 +104,7 @@ static void MergeMultipleVarsIntoOneBySection( ...@@ -104,7 +104,7 @@ static void MergeMultipleVarsIntoOneBySection(
const std::vector<std::vector<int64_t>>& splited_ids, const std::vector<std::vector<int64_t>>& splited_ids,
const framework::ExecutionContext& context, const framework::ExecutionContext& context,
const framework::Scope& actual_scope, framework::Scope* scope, const framework::Scope& actual_scope, framework::Scope* scope,
platform::DeviceContext* actual_ctx, ) { platform::DeviceContext* actual_ctx) {
PADDLE_ENFORCE_EQ(out_var_names.size(), height_section.size(), ""); PADDLE_ENFORCE_EQ(out_var_names.size(), height_section.size(), "");
auto cpu_place = platform::CPUPlace(); auto cpu_place = platform::CPUPlace();
...@@ -175,7 +175,7 @@ void prefetch(const std::string& id_name, const std::string& out_name, ...@@ -175,7 +175,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
const std::vector<int>& height_sections, const std::vector<int>& height_sections,
const framework::ExecutionContext& context, const framework::ExecutionContext& context,
const framework::Scope& scope) { const framework::Scope& scope) {
auto& local_scope = scope.NewScope(); auto& local_scope = context.scope().NewScope();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& cpu_ctx = *pool.Get(platform::CPUPlace()); auto& cpu_ctx = *pool.Get(platform::CPUPlace());
...@@ -192,7 +192,7 @@ void prefetch(const std::string& id_name, const std::string& out_name, ...@@ -192,7 +192,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
out_var_names.push_back(out_name + "@" + epmap[i]); out_var_names.push_back(out_name + "@" + epmap[i]);
} }
auto& id_tensor = local_scope.FindVar(id_name)->Get<framework::LoDTensor>(); auto& id_tensor = scope.FindVar(id_name)->Get<framework::LoDTensor>();
std::vector<int64_t> ids_vector; std::vector<int64_t> ids_vector;
if (platform::is_cpu_place(id_tensor.place())) { if (platform::is_cpu_place(id_tensor.place())) {
auto* id_data = id_tensor.data<int64_t>(); auto* id_data = id_tensor.data<int64_t>();
...@@ -248,7 +248,7 @@ void prefetch(const std::string& id_name, const std::string& out_name, ...@@ -248,7 +248,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name, MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name,
out_var_names, height_sections, splited_ids, out_var_names, height_sections, splited_ids,
context, scope, &local_scope, &actual_ctx); context, scope, &local_scope, &actual_ctx);
scope.DeleteScope(&local_scope); context.scope().DeleteScope(&local_scope);
} }
}; // namespace distributed }; // namespace distributed
......
...@@ -27,7 +27,8 @@ void prefetch(const std::string& id_name, const std::string& out_name, ...@@ -27,7 +27,8 @@ void prefetch(const std::string& id_name, const std::string& out_name,
const std::vector<std::string>& table_names, const std::vector<std::string>& table_names,
const std::vector<std::string>& epmap, const std::vector<std::string>& epmap,
const std::vector<int>& height_sections, const std::vector<int>& height_sections,
const framework::ExecutionContext& context); const framework::ExecutionContext& context,
const framework::Scope& scope);
}; // namespace distributed }; // namespace distributed
}; // namespace operators }; // namespace operators
......
...@@ -180,7 +180,7 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -180,7 +180,7 @@ class NCEKernel : public framework::OpKernel<T> {
labels.size() * sizeof(int64_t)); labels.size() * sizeof(int64_t));
local_scope.Var("Weight@Local") local_scope.Var("Weight@Local")
->GetMutable<framework::Tensor>() ->GetMutable<framework::LoDTensor>()
->mutable_data<T>(context.GetPlace()); ->mutable_data<T>(context.GetPlace());
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
...@@ -194,7 +194,7 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -194,7 +194,7 @@ class NCEKernel : public framework::OpKernel<T> {
#endif #endif
auto weight_mat = EigenMatrix<T>::From( auto weight_mat = EigenMatrix<T>::From(
(local_scope.Var("Weight@Local")->Get<framework::Tensor>())); (local_scope.Var("Weight@Local")->Get<framework::LoDTensor>()));
for (int64_t i = 0; i < sample_labels->numel(); ++i) { for (int64_t i = 0; i < sample_labels->numel(); ++i) {
std::vector<int64_t>::iterator it = std::vector<int64_t>::iterator it =
std::find(labels.begin(), labels.end(), sample_labels_data[i]); std::find(labels.begin(), labels.end(), sample_labels_data[i]);
...@@ -208,8 +208,9 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -208,8 +208,9 @@ class NCEKernel : public framework::OpKernel<T> {
sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i]))); sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i])));
} }
context.scope().DeleteScope(&local_scope); if (context.scope().HasKid(&local_scope)) {
context.scope().DeleteScope(&local_scope);
}
} else { } else {
auto weight_mat = auto weight_mat =
EigenMatrix<T>::From(*(context.Input<Tensor>("Weight"))); EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册