From bb2e7f0bbed1cfcf47b5b8e90bc9e35b46c13b50 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Sat, 8 Dec 2018 12:31:33 +0800 Subject: [PATCH] add scope in prefetch --- paddle/fluid/operators/distributed/parameter_prefetch.cc | 8 ++++---- paddle/fluid/operators/distributed/parameter_prefetch.h | 3 ++- paddle/fluid/operators/nce_op.h | 9 +++++---- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.cc b/paddle/fluid/operators/distributed/parameter_prefetch.cc index 67b56bd21..f6a2d5bbe 100644 --- a/paddle/fluid/operators/distributed/parameter_prefetch.cc +++ b/paddle/fluid/operators/distributed/parameter_prefetch.cc @@ -104,7 +104,7 @@ static void MergeMultipleVarsIntoOneBySection( const std::vector>& splited_ids, const framework::ExecutionContext& context, const framework::Scope& actual_scope, framework::Scope* scope, - platform::DeviceContext* actual_ctx, ) { + platform::DeviceContext* actual_ctx) { PADDLE_ENFORCE_EQ(out_var_names.size(), height_section.size(), ""); auto cpu_place = platform::CPUPlace(); @@ -175,7 +175,7 @@ void prefetch(const std::string& id_name, const std::string& out_name, const std::vector& height_sections, const framework::ExecutionContext& context, const framework::Scope& scope) { - auto& local_scope = scope.NewScope(); + auto& local_scope = context.scope().NewScope(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& cpu_ctx = *pool.Get(platform::CPUPlace()); @@ -192,7 +192,7 @@ void prefetch(const std::string& id_name, const std::string& out_name, out_var_names.push_back(out_name + "@" + epmap[i]); } - auto& id_tensor = local_scope.FindVar(id_name)->Get(); + auto& id_tensor = scope.FindVar(id_name)->Get(); std::vector ids_vector; if (platform::is_cpu_place(id_tensor.place())) { auto* id_data = id_tensor.data(); @@ -248,7 +248,7 @@ void prefetch(const std::string& id_name, const std::string& out_name, MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name, out_var_names, height_sections, splited_ids, context, scope, &local_scope, &actual_ctx); - scope.DeleteScope(&local_scope); + context.scope().DeleteScope(&local_scope); } }; // namespace distributed diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.h b/paddle/fluid/operators/distributed/parameter_prefetch.h index 53b0fbfb5..53482c4c4 100644 --- a/paddle/fluid/operators/distributed/parameter_prefetch.h +++ b/paddle/fluid/operators/distributed/parameter_prefetch.h @@ -27,7 +27,8 @@ void prefetch(const std::string& id_name, const std::string& out_name, const std::vector& table_names, const std::vector& epmap, const std::vector& height_sections, - const framework::ExecutionContext& context); + const framework::ExecutionContext& context, + const framework::Scope& scope); }; // namespace distributed }; // namespace operators diff --git a/paddle/fluid/operators/nce_op.h b/paddle/fluid/operators/nce_op.h index 9789e3038..2e51c6740 100644 --- a/paddle/fluid/operators/nce_op.h +++ b/paddle/fluid/operators/nce_op.h @@ -180,7 +180,7 @@ class NCEKernel : public framework::OpKernel { labels.size() * sizeof(int64_t)); local_scope.Var("Weight@Local") - ->GetMutable() + ->GetMutable() ->mutable_data(context.GetPlace()); #ifdef PADDLE_WITH_DISTRIBUTE @@ -194,7 +194,7 @@ class NCEKernel : public framework::OpKernel { #endif auto weight_mat = EigenMatrix::From( - (local_scope.Var("Weight@Local")->Get())); + (local_scope.Var("Weight@Local")->Get())); for (int64_t i = 0; i < sample_labels->numel(); ++i) { std::vector::iterator it = std::find(labels.begin(), labels.end(), sample_labels_data[i]); @@ -208,8 +208,9 @@ class NCEKernel : public framework::OpKernel { sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i]))); } - context.scope().DeleteScope(&local_scope); - + if (context.scope().HasKid(&local_scope)) { + context.scope().DeleteScope(&local_scope); + } } else { auto weight_mat = EigenMatrix::From(*(context.Input("Weight"))); -- GitLab