Unverified commit 0f6412c0, authored by Leo Chen, committed by GitHub

do not use scope in op kernel (#41316)

Parent 1b58ce14
@@ -26,17 +26,13 @@ template <typename DeviceContext, typename T>
 class DistributedLookupTableKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
-    auto &scope = context.scope();
     auto padding_idx = context.Attr<int64_t>("padding_idx");
     auto table_id = context.Attr<int>("table_id");
     bool is_test = context.Attr<bool>("is_test");
+    auto *var = context.InputVar("W");
 
-    auto embedding_name = context.InputNames("W").front();
     int64_t emb_dim = 0;
 
-    auto *var = scope.FindVar(embedding_name);
-
     if (var->IsType<framework::LoDTensor>()) {
       emb_dim = var->Get<framework::LoDTensor>().dims()[1];
     } else if (var->IsType<phi::SelectedRows>()) {
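Note: the hunk above replaces a name-based Scope lookup (context.InputNames("W").front() resolved through scope.FindVar) with a single direct call, context.InputVar("W"). Below is a minimal, framework-free sketch of that pattern; ExecutionContext, Variable, and the toy tensor here are hypothetical stand-ins, not Paddle's real classes.

// Sketch: resolve an input variable through the execution context
// instead of a name lookup in a Scope. All types are toy stand-ins.
#include <cassert>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

struct Variable {
  std::vector<std::vector<float>> rows;  // stands in for a LoDTensor
};

class ExecutionContext {
 public:
  explicit ExecutionContext(std::map<std::string, Variable *> inputs)
      : inputs_(std::move(inputs)) {}

  // The kernel asks for the variable bound to a slot directly; no Scope
  // and no InputNames(...).front() indirection involved.
  Variable *InputVar(const std::string &slot) const {
    auto it = inputs_.find(slot);
    return it == inputs_.end() ? nullptr : it->second;
  }

 private:
  std::map<std::string, Variable *> inputs_;
};

int main() {
  Variable table{{{0.1f, 0.2f}, {0.3f, 0.4f}}};
  ExecutionContext ctx({{"W", &table}});

  auto *var = ctx.InputVar("W");  // direct binding, as in the new code
  assert(var != nullptr);
  std::cout << "emb_dim = " << var->rows[0].size() << "\n";  // prints 2
  return 0;
}

Resolving inputs through the context keeps the kernel independent of how variables are stored, which is what lets this commit drop context.scope() from the kernel entirely.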
@@ -61,35 +57,31 @@ class DistributedLookupTableKernel : public framework::OpKernel<T> {
     } else {
       auto inputs_variable = context.MultiInputVar("Ids");
       auto outputs_variable = context.MultiOutputVar("Outputs");
-      auto inputs_name = context.InputNames("Ids");
-      auto outputs_name = context.OutputNames("Outputs");
       auto cpu_place = platform::CPUPlace();
-      framework::Scope *tmp_scope = scope.NewTmpScope().release();
 
       std::vector<const framework::LoDTensor *> tmp_input_vec;
       auto input_var_size = inputs_variable.size();
       std::vector<framework::LoDTensor *> tmp_output_vec;
       auto output_var_size = outputs_variable.size();
+      std::vector<std::shared_ptr<framework::LoDTensor>> tmp_tensors;
 
       // create temp input
       for (size_t idx = 0; idx < input_var_size; ++idx) {
-        framework::Variable *tmp_input_var = tmp_scope->Var(inputs_name[idx]);
-        framework::LoDTensor *tmp_input_tensor =
-            tmp_input_var->GetMutable<framework::LoDTensor>();
+        tmp_tensors.emplace_back(std::make_shared<framework::LoDTensor>());
+        auto *p = tmp_tensors.back().get();
         framework::TensorCopy(inputs_variable[idx]->Get<framework::LoDTensor>(),
-                              cpu_place, context.device_context(),
-                              tmp_input_tensor);
-        tmp_input_vec.push_back(tmp_input_tensor);
+                              cpu_place, context.device_context(), p);
+        tmp_input_vec.push_back(p);
       }
 
       // create temp output
       for (size_t idx = 0; idx < output_var_size; ++idx) {
-        framework::Variable *tmp_output_var = tmp_scope->Var(outputs_name[idx]);
-        framework::LoDTensor *tmp_output_tensor =
-            tmp_output_var->GetMutable<framework::LoDTensor>();
-        tmp_output_tensor->Resize(outputs[idx]->dims());
-        tmp_output_vec.push_back(tmp_output_tensor);
+        tmp_tensors.emplace_back(std::make_shared<framework::LoDTensor>());
+        auto *p = tmp_tensors.back().get();
+        p->Resize(outputs[idx]->dims());
+        tmp_output_vec.push_back(p);
       }
 
       // use fleet->PullSparse
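Note: the key change in this hunk is ownership. Temporaries that previously had to live in a tmp_scope (created via scope.NewTmpScope().release() and later freed with delete tmp_scope) are now kept alive by a std::vector<std::shared_ptr<framework::LoDTensor>>, while the existing work lists hold raw pointers. A self-contained sketch of that pattern, with a hypothetical Buffer type standing in for LoDTensor:

// Sketch: a shared_ptr vector owns the temporaries; raw pointers go into
// the non-owning work list. Buffer is a toy stand-in for LoDTensor.
#include <iostream>
#include <memory>
#include <vector>

struct Buffer {
  std::vector<float> data;
};

int main() {
  std::vector<std::shared_ptr<Buffer>> tmp_tensors;  // owns every temporary
  std::vector<Buffer *> tmp_input_vec;               // non-owning work list

  for (int idx = 0; idx < 3; ++idx) {
    tmp_tensors.emplace_back(std::make_shared<Buffer>());
    auto *p = tmp_tensors.back().get();
    p->data.assign(4, static_cast<float>(idx));  // stands in for TensorCopy
    tmp_input_vec.push_back(p);
  }

  std::cout << "temporaries alive: " << tmp_tensors.size() << "\n";
  // No manual cleanup: when tmp_tensors goes out of scope, all temporaries
  // are released, replacing the old `delete tmp_scope;` at the end.
  return 0;
}

The raw pointers in tmp_input_vec stay valid because tmp_tensors outlives them within the kernel, and a missed delete can no longer leak the temporary scope on an early return.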
@@ -100,27 +92,21 @@ class DistributedLookupTableKernel : public framework::OpKernel<T> {
       // cp temp to origin
       for (size_t idx = 0; idx < output_var_size; ++idx) {
-        framework::Variable *tmp_output_var = tmp_scope->Var(outputs_name[idx]);
-        framework::LoDTensor *tmp_output_tensor =
-            tmp_output_var->GetMutable<framework::LoDTensor>();
         framework::TensorCopy(
-            *tmp_output_tensor, context.GetPlace(), context.device_context(),
+            *tmp_output_vec[idx], context.GetPlace(), context.device_context(),
             outputs_variable[idx]->GetMutable<framework::LoDTensor>());
       }
-      delete tmp_scope;
     }
 
-    auto id_names = context.InputNames("Ids");
-    auto out_names = context.OutputNames("Outputs");
     auto lookup_table_version =
         context.Attr<std::string>("lookup_table_version");
+    auto id_vars = context.MultiInputVar("Ids");
+    auto out_vars = context.MultiOutputVar("Outputs");
 
     if (lookup_table_version == "lookup_table_v2") {
-      for (size_t i = 0; i < id_names.size(); ++i) {
-        auto *id_var = scope.FindVar(id_names[i]);
-        auto *out_var = scope.FindVar(out_names[i]);
-        auto *id_tensor = id_var->GetMutable<framework::LoDTensor>();
-        auto *out_tensor = out_var->GetMutable<framework::LoDTensor>();
+      for (size_t i = 0; i < id_vars.size(); ++i) {
+        auto *id_tensor = id_vars[i]->GetMutable<framework::LoDTensor>();
+        auto *out_tensor = out_vars[i]->GetMutable<framework::LoDTensor>();
         auto id_dims = id_tensor->dims();
         out_tensor->Resize(phi::make_ddim({static_cast<int64_t>(id_dims[0]),
......
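Note: the last hunk applies the same idea to the lookup_table_v2 path: iterate the paired lists from MultiInputVar("Ids") and MultiOutputVar("Outputs") instead of re-resolving id_names[i] and out_names[i] through the Scope on every iteration. A toy sketch of that loop shape, with a hypothetical Tensor type and an assumed emb_dim of 16:

// Sketch: walk paired input/output variable lists and size each output
// from its matching input. Tensor is a toy stand-in.
#include <cstdio>
#include <vector>

struct Tensor {
  std::vector<long> dims;
};

int main() {
  std::vector<Tensor> ids = {{{4, 1}}, {{8, 1}}};
  std::vector<Tensor> outs(ids.size());

  // Stands in for context.MultiInputVar("Ids") / MultiOutputVar("Outputs").
  std::vector<Tensor *> id_vars, out_vars;
  for (size_t i = 0; i < ids.size(); ++i) {
    id_vars.push_back(&ids[i]);
    out_vars.push_back(&outs[i]);
  }

  for (size_t i = 0; i < id_vars.size(); ++i) {
    // Mirrors out_tensor->Resize(make_ddim({id_dims[0], emb_dim})).
    out_vars[i]->dims = {id_vars[i]->dims[0], 16L};
    std::printf("out[%zu]: %ld x %ld\n", i, out_vars[i]->dims[0],
                out_vars[i]->dims[1]);
  }
  return 0;
}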