diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.cc b/paddle/fluid/operators/distributed/parameter_prefetch.cc
index cf14538b1c284d297242197088a66cc156b1762c..67b56bd218079bd55b49d70622007931580df811 100644
--- a/paddle/fluid/operators/distributed/parameter_prefetch.cc
+++ b/paddle/fluid/operators/distributed/parameter_prefetch.cc
@@ -102,8 +102,9 @@ static void MergeMultipleVarsIntoOneBySection(
     const std::string& out_name, const std::vector<std::string>& out_var_names,
     const std::vector<int>& height_section,
     const std::vector<std::vector<int64_t>>& splited_ids,
-    const framework::ExecutionContext& context, framework::Scope* scope,
-    platform::DeviceContext* actual_ctx) {
+    const framework::ExecutionContext& context,
+    const framework::Scope& actual_scope, framework::Scope* scope,
+    platform::DeviceContext* actual_ctx) {
   PADDLE_ENFORCE_EQ(out_var_names.size(), height_section.size(), "");
 
   auto cpu_place = platform::CPUPlace();
@@ -114,9 +115,9 @@ static void MergeMultipleVarsIntoOneBySection(
     id_to_offset[ids_vector[i]].push_back(i);
   }
 
-  auto& id_tensor = scope->FindVar(id_name)->Get<framework::LoDTensor>();
+  auto& id_tensor = actual_scope.FindVar(id_name)->Get<framework::LoDTensor>();
   auto* out_tensor =
-      scope->FindVar(out_name)->GetMutable<framework::LoDTensor>();
+      actual_scope.FindVar(out_name)->GetMutable<framework::LoDTensor>();
   auto* out_tensor_data = out_tensor->mutable_data<float>(id_tensor.place());
 
   bool is_on_cpu_place = true;
@@ -172,8 +173,9 @@ void prefetch(const std::string& id_name, const std::string& out_name,
               const std::vector<std::string>& table_names,
               const std::vector<std::string>& epmap,
               const std::vector<int>& height_sections,
-              const framework::ExecutionContext& context) {
-  auto& local_scope = context.scope().NewScope();
+              const framework::ExecutionContext& context,
+              const framework::Scope& scope) {
+  auto& local_scope = scope.NewScope();
 
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto& cpu_ctx = *pool.Get(platform::CPUPlace());
@@ -245,9 +247,8 @@ void prefetch(const std::string& id_name, const std::string& out_name,
 
   MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name,
                                     out_var_names, height_sections, splited_ids,
-                                    context, &local_scope, &actual_ctx);
-
-  context.scope().DeleteScope(&local_scope);
+                                    context, scope, &local_scope, &actual_ctx);
+  scope.DeleteScope(&local_scope);
 }
 
 };  // namespace distributed
diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h
index 3a73a7637c6d7d3eff7443802a4a52be9149e0ef..a7d0fd4856edc74237151c64f286d468ad86e7ca 100644
--- a/paddle/fluid/operators/lookup_table_op.h
+++ b/paddle/fluid/operators/lookup_table_op.h
@@ -59,7 +59,8 @@ class LookupTableKernel : public framework::OpKernel<T> {
       // server
 #ifdef PADDLE_WITH_DISTRIBUTE
       operators::distributed::prefetch(id_name, out_name, table_names, epmap,
-                                       height_sections, context);
+                                       height_sections, context,
+                                       context.scope());
 #else
       PADDLE_THROW(
           "paddle is not compiled with distribute support, can not do "
diff --git a/paddle/fluid/operators/nce_op.h b/paddle/fluid/operators/nce_op.h
index 6567b6534a42470cec7341aee54b7e391a33dc1b..9789e3038893a9b6c693d0f12c80fa05c1b33de0 100644
--- a/paddle/fluid/operators/nce_op.h
+++ b/paddle/fluid/operators/nce_op.h
@@ -170,18 +170,31 @@ class NCEKernel : public framework::OpKernel<T> {
       auto height_sections = context.Attr<std::vector<int>>("height_sections");
       auto table_names = context.Attr<std::vector<std::string>>("table_names");
 
-      local_scope.Var("Ids");
-      local_scope.Var("Weight");
+      auto *ids = local_scope.Var("Ids");
+      auto *x_tensor = ids->GetMutable<framework::LoDTensor>();
+      x_tensor->mutable_data<int64_t>(
+          framework::make_ddim({static_cast<int64_t>(labels.size()), 1}),
+          context.GetPlace());
+      // copy.
+      std::memcpy(x_tensor->data<int64_t>(), labels.data(),
+                  labels.size() * sizeof(int64_t));
+
+      local_scope.Var("Weight@Local")
+          ->GetMutable<framework::LoDTensor>()
+          ->mutable_data<T>(context.GetPlace());
 
 #ifdef PADDLE_WITH_DISTRIBUTE
-      operators::distributed::prefetch("Ids", "Weight", table_names, epmap,
-                                       height_sections, context);
+      operators::distributed::prefetch("Ids", "Weight@Local", table_names,
+                                       epmap, height_sections, context,
+                                       local_scope);
 #else
       PADDLE_THROW(
           "paddle is not compiled with distribute support, can not do "
           "parameter prefetch!");
+#endif
 
-      auto weight_mat = EigenMatrix<T>::From(*(weight->Get<framework::LoDTensor>()));
+      auto weight_mat = EigenMatrix<T>::From(
+          (local_scope.Var("Weight@Local")->Get<framework::LoDTensor>()));
       for (int64_t i = 0; i < sample_labels->numel(); ++i) {
         std::vector<int64_t>::iterator it =
             std::find(labels.begin(), labels.end(), sample_labels_data[i]);
@@ -196,7 +209,7 @@ class NCEKernel : public framework::OpKernel<T> {
       }
 
       context.scope().DeleteScope(&local_scope);
-#endif
+
     } else {
       auto weight_mat =
          EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
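
Usage note (not part of the patch itself): with the signatures above, prefetch() no longer assumes context.scope(); the caller chooses which scope holds the id variable it reads and the output variable it writes. A minimal sketch of a caller under that assumption, with illustrative names ("Ids@Prefetch", "Out@Prefetch") that do not appear in the patch; table_names, epmap, and height_sections are taken from the op attributes as in the hunks above:

    // Inside some OpKernel<T>::Compute(const framework::ExecutionContext& context).
    // A child scope keeps the prefetch buffers out of the parent scope.
    auto& local_scope = context.scope().NewScope();
    local_scope.Var("Ids@Prefetch")->GetMutable<framework::LoDTensor>();
    local_scope.Var("Out@Prefetch")->GetMutable<framework::LoDTensor>();
    // ... fill the Ids@Prefetch tensor with the lookup ids, as in the nce_op.h hunk ...

#ifdef PADDLE_WITH_DISTRIBUTE
    // The trailing scope argument tells prefetch() where to find "Ids@Prefetch"
    // and where to write "Out@Prefetch"; it need not be context.scope().
    operators::distributed::prefetch("Ids@Prefetch", "Out@Prefetch", table_names,
                                     epmap, height_sections, context, local_scope);
#endif

    // Read the fetched rows back from the child scope, then release it.
    auto& out = local_scope.Var("Out@Prefetch")->Get<framework::LoDTensor>();
    context.scope().DeleteScope(&local_scope);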