diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.cc b/paddle/fluid/operators/distributed/parameter_prefetch.cc index f6a2d5bbe52732dc7ce7c6b86461c65375b15c72..4cdeae81a1021c156ea9a2c2f97ec0fd982b80ad 100644 --- a/paddle/fluid/operators/distributed/parameter_prefetch.cc +++ b/paddle/fluid/operators/distributed/parameter_prefetch.cc @@ -102,8 +102,7 @@ static void MergeMultipleVarsIntoOneBySection( const std::string& out_name, const std::vector& out_var_names, const std::vector& height_section, const std::vector>& splited_ids, - const framework::ExecutionContext& context, - const framework::Scope& actual_scope, framework::Scope* scope, + const framework::ExecutionContext& context, framework::Scope* scope, platform::DeviceContext* actual_ctx) { PADDLE_ENFORCE_EQ(out_var_names.size(), height_section.size(), ""); @@ -115,9 +114,9 @@ static void MergeMultipleVarsIntoOneBySection( id_to_offset[ids_vector[i]].push_back(i); } - auto& id_tensor = actual_scope.FindVar(id_name)->Get(); + auto& id_tensor = scope.FindVar(id_name)->Get(); auto* out_tensor = - actual_scope.FindVar(out_name)->GetMutable(); + scope.FindVar(out_name)->GetMutable(); auto* out_tensor_data = out_tensor->mutable_data(id_tensor.place()); bool is_on_cpu_place = true; @@ -175,7 +174,7 @@ void prefetch(const std::string& id_name, const std::string& out_name, const std::vector& height_sections, const framework::ExecutionContext& context, const framework::Scope& scope) { - auto& local_scope = context.scope().NewScope(); + auto& local_scope = scope.NewScope(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& cpu_ctx = *pool.Get(platform::CPUPlace()); @@ -247,8 +246,8 @@ void prefetch(const std::string& id_name, const std::string& out_name, MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name, out_var_names, height_sections, splited_ids, - context, scope, &local_scope, &actual_ctx); - context.scope().DeleteScope(&local_scope); + context, &local_scope, &actual_ctx); + scope.DeleteScope(&local_scope); } }; // namespace distributed diff --git a/paddle/fluid/operators/nce_op.h b/paddle/fluid/operators/nce_op.h index 2e51c67401fe4b6e042443cff31974bdd75e1f6a..862064be182f8906fce8bc4c2370344c2df80bed 100644 --- a/paddle/fluid/operators/nce_op.h +++ b/paddle/fluid/operators/nce_op.h @@ -170,7 +170,7 @@ class NCEKernel : public framework::OpKernel { auto height_sections = context.Attr>("height_sections"); auto table_names = context.Attr>("table_names"); - auto *ids = local_scope.Var("Ids"); + auto *ids = local_scope.Var("Ids@Local"); auto *x_tensor = ids->GetMutable(); x_tensor->mutable_data( framework::make_ddim({static_cast(labels.size()), 1}), @@ -179,12 +179,10 @@ class NCEKernel : public framework::OpKernel { std::memcpy(x_tensor->data(), labels.data(), labels.size() * sizeof(int64_t)); - local_scope.Var("Weight@Local") - ->GetMutable() - ->mutable_data(context.GetPlace()); + local_scope.Var("Weight@Local"); #ifdef PADDLE_WITH_DISTRIBUTE - operators::distributed::prefetch("Ids", "Weight@Local", table_names, + operators::distributed::prefetch("Ids@Local", "Weight@Local", table_names, epmap, height_sections, context, &local_scope); #else @@ -207,10 +205,7 @@ class NCEKernel : public framework::OpKernel { sample_out_data[i] += result(0); sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i]))); } - - if (context.scope().HasKid(&local_scope)) { - context.scope().DeleteScope(&local_scope); - } + context.scope().DeleteScope(&local_scope); } else { auto weight_mat = EigenMatrix::From(*(context.Input("Weight")));