From 0f6412c0c645e9a3c901cbcf4fa83c314ab85a37 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Sat, 2 Apr 2022 19:08:56 +0800 Subject: [PATCH] do not use scope in op kernel (#41316) --- .../pscore/distributed_lookup_table_op.h | 48 +++++++------------ 1 file changed, 17 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/operators/pscore/distributed_lookup_table_op.h b/paddle/fluid/operators/pscore/distributed_lookup_table_op.h index da439407a42..c2717c19b2d 100644 --- a/paddle/fluid/operators/pscore/distributed_lookup_table_op.h +++ b/paddle/fluid/operators/pscore/distributed_lookup_table_op.h @@ -26,17 +26,13 @@ template class DistributedLookupTableKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { - auto &scope = context.scope(); - auto padding_idx = context.Attr("padding_idx"); auto table_id = context.Attr("table_id"); bool is_test = context.Attr("is_test"); - auto embedding_name = context.InputNames("W").front(); + auto *var = context.InputVar("W"); int64_t emb_dim = 0; - auto *var = scope.FindVar(embedding_name); - if (var->IsType()) { emb_dim = var->Get().dims()[1]; } else if (var->IsType()) { @@ -61,35 +57,31 @@ class DistributedLookupTableKernel : public framework::OpKernel { } else { auto inputs_variable = context.MultiInputVar("Ids"); auto outputs_variable = context.MultiOutputVar("Outputs"); - auto inputs_name = context.InputNames("Ids"); - auto outputs_name = context.OutputNames("Outputs"); auto cpu_place = platform::CPUPlace(); - framework::Scope *tmp_scope = scope.NewTmpScope().release(); std::vector tmp_input_vec; auto input_var_size = inputs_variable.size(); std::vector tmp_output_vec; auto output_var_size = outputs_variable.size(); + std::vector> tmp_tensors; + // create temp input for (size_t idx = 0; idx < input_var_size; ++idx) { - framework::Variable *tmp_input_var = tmp_scope->Var(inputs_name[idx]); - framework::LoDTensor *tmp_input_tensor = - tmp_input_var->GetMutable(); + tmp_tensors.emplace_back(std::make_shared()); + auto *p = tmp_tensors.back().get(); framework::TensorCopy(inputs_variable[idx]->Get(), - cpu_place, context.device_context(), - tmp_input_tensor); - tmp_input_vec.push_back(tmp_input_tensor); + cpu_place, context.device_context(), p); + tmp_input_vec.push_back(p); } // create temp output for (size_t idx = 0; idx < output_var_size; ++idx) { - framework::Variable *tmp_output_var = tmp_scope->Var(outputs_name[idx]); - framework::LoDTensor *tmp_output_tensor = - tmp_output_var->GetMutable(); - tmp_output_tensor->Resize(outputs[idx]->dims()); - tmp_output_vec.push_back(tmp_output_tensor); + tmp_tensors.emplace_back(std::make_shared()); + auto *p = tmp_tensors.back().get(); + p->Resize(outputs[idx]->dims()); + tmp_output_vec.push_back(p); } // use fleet->PullSparse @@ -100,27 +92,21 @@ class DistributedLookupTableKernel : public framework::OpKernel { // cp temp to origin for (size_t idx = 0; idx < output_var_size; ++idx) { - framework::Variable *tmp_output_var = tmp_scope->Var(outputs_name[idx]); - framework::LoDTensor *tmp_output_tensor = - tmp_output_var->GetMutable(); framework::TensorCopy( - *tmp_output_tensor, context.GetPlace(), context.device_context(), + *tmp_output_vec[idx], context.GetPlace(), context.device_context(), outputs_variable[idx]->GetMutable()); } - delete tmp_scope; } - auto id_names = context.InputNames("Ids"); - auto out_names = context.OutputNames("Outputs"); auto lookup_table_version = context.Attr("lookup_table_version"); + auto id_vars = context.MultiInputVar("Ids"); + auto out_vars = context.MultiOutputVar("Outputs"); if (lookup_table_version == "lookup_table_v2") { - for (size_t i = 0; i < id_names.size(); ++i) { - auto *id_var = scope.FindVar(id_names[i]); - auto *out_var = scope.FindVar(out_names[i]); - auto *id_tensor = id_var->GetMutable(); - auto *out_tensor = out_var->GetMutable(); + for (size_t i = 0; i < id_vars.size(); ++i) { + auto *id_tensor = id_vars[i]->GetMutable(); + auto *out_tensor = out_vars[i]->GetMutable(); auto id_dims = id_tensor->dims(); out_tensor->Resize(phi::make_ddim({static_cast(id_dims[0]), -- GitLab