Unverified commit 0f6412c0, authored by Leo Chen, committed by GitHub

do not use scope in op kernel (#41316)

Parent 1b58ce14
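
This commit removes the kernel's direct dependency on framework::Scope: instead of resolving inputs by name with scope.FindVar(...), the kernel now takes already-resolved variables from the ExecutionContext (InputVar, MultiInputVar, MultiOutputVar), and temporaries are owned by a local vector instead of a temporary scope. Below is a minimal, self-contained sketch of that pattern; Scope, ExecutionContext, and Variable here are toy stand-ins, not the real Paddle classes.

// Toy illustration of the refactor pattern in this commit: the kernel stops
// reaching into the Scope by name and receives resolved Variable pointers
// from the execution context instead.
#include <cassert>
#include <map>
#include <string>

struct Variable { int payload = 0; };

// Stand-in for framework::Scope: name-to-variable lookup.
struct Scope {
  std::map<std::string, Variable> vars;
  Variable *FindVar(const std::string &name) {
    auto it = vars.find(name);
    return it == vars.end() ? nullptr : &it->second;
  }
};

// Stand-in for framework::ExecutionContext: inputs are resolved once,
// outside the kernel, so the kernel never touches the Scope.
struct ExecutionContext {
  std::map<std::string, Variable *> inputs;
  Variable *InputVar(const std::string &slot) const { return inputs.at(slot); }
};

int main() {
  Scope scope;
  scope.vars["W"].payload = 42;

  // Old pattern (removed in the diff below): name lookup inside the kernel.
  Variable *before = scope.FindVar("W");

  // New pattern: the context hands the kernel the resolved variable.
  ExecutionContext ctx{{{"W", &scope.vars["W"]}}};
  Variable *after = ctx.InputVar("W");

  assert(before == after && after->payload == 42);
}
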
@@ -26,17 +26,13 @@ template <typename DeviceContext, typename T>
 class DistributedLookupTableKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
-    auto &scope = context.scope();
     auto padding_idx = context.Attr<int64_t>("padding_idx");
     auto table_id = context.Attr<int>("table_id");
     bool is_test = context.Attr<bool>("is_test");
-    auto embedding_name = context.InputNames("W").front();
+    auto *var = context.InputVar("W");
     int64_t emb_dim = 0;
-    auto *var = scope.FindVar(embedding_name);
     if (var->IsType<framework::LoDTensor>()) {
       emb_dim = var->Get<framework::LoDTensor>().dims()[1];
     } else if (var->IsType<phi::SelectedRows>()) {
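
In the hunk above, emb_dim is the embedding width, read from dimension 1 of the W parameter whether it is stored as a dense LoDTensor or as phi::SelectedRows. The SelectedRows branch is cut off by the diff; reading it via .value().dims()[1] is an assumption about the usual SelectedRows interface, sketched here with std::variant standing in for framework::Variable.

// Toy sketch of the emb_dim type dispatch; DenseTensor ~ LoDTensor,
// SelectedRows ~ phi::SelectedRows (assumed interface).
#include <cstdint>
#include <iostream>
#include <variant>
#include <vector>

struct DenseTensor { std::vector<int64_t> dims; };
struct SelectedRows { DenseTensor value; };
using Variable = std::variant<DenseTensor, SelectedRows>;

int64_t EmbDim(const Variable &var) {
  if (auto *t = std::get_if<DenseTensor>(&var)) return t->dims[1];
  if (auto *sr = std::get_if<SelectedRows>(&var)) return sr->value.dims[1];
  return 0;
}

int main() {
  Variable dense = DenseTensor{{1000, 64}};
  Variable sparse = SelectedRows{DenseTensor{{10, 64}}};
  std::cout << EmbDim(dense) << " " << EmbDim(sparse) << "\n";  // 64 64
}
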
@@ -61,35 +57,31 @@ class DistributedLookupTableKernel : public framework::OpKernel<T> {
     } else {
       auto inputs_variable = context.MultiInputVar("Ids");
       auto outputs_variable = context.MultiOutputVar("Outputs");
-      auto inputs_name = context.InputNames("Ids");
-      auto outputs_name = context.OutputNames("Outputs");
       auto cpu_place = platform::CPUPlace();
-      framework::Scope *tmp_scope = scope.NewTmpScope().release();
       std::vector<const framework::LoDTensor *> tmp_input_vec;
       auto input_var_size = inputs_variable.size();
       std::vector<framework::LoDTensor *> tmp_output_vec;
       auto output_var_size = outputs_variable.size();
+      std::vector<std::shared_ptr<framework::LoDTensor>> tmp_tensors;
       // create temp input
       for (size_t idx = 0; idx < input_var_size; ++idx) {
-        framework::Variable *tmp_input_var = tmp_scope->Var(inputs_name[idx]);
-        framework::LoDTensor *tmp_input_tensor =
-            tmp_input_var->GetMutable<framework::LoDTensor>();
+        tmp_tensors.emplace_back(std::make_shared<framework::LoDTensor>());
+        auto *p = tmp_tensors.back().get();
         framework::TensorCopy(inputs_variable[idx]->Get<framework::LoDTensor>(),
-                              cpu_place, context.device_context(),
-                              tmp_input_tensor);
-        tmp_input_vec.push_back(tmp_input_tensor);
+                              cpu_place, context.device_context(), p);
+        tmp_input_vec.push_back(p);
       }
       // create temp output
       for (size_t idx = 0; idx < output_var_size; ++idx) {
-        framework::Variable *tmp_output_var = tmp_scope->Var(outputs_name[idx]);
-        framework::LoDTensor *tmp_output_tensor =
-            tmp_output_var->GetMutable<framework::LoDTensor>();
-        tmp_output_tensor->Resize(outputs[idx]->dims());
-        tmp_output_vec.push_back(tmp_output_tensor);
+        tmp_tensors.emplace_back(std::make_shared<framework::LoDTensor>());
+        auto *p = tmp_tensors.back().get();
+        p->Resize(outputs[idx]->dims());
+        tmp_output_vec.push_back(p);
       }
       // use fleet->PullSparse
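
The hunk above replaces scope-owned temporaries with heap tensors owned by tmp_tensors: the shared_ptr vector keeps every temporary alive, while plain pointers go into tmp_input_vec/tmp_output_vec for the copy and pull calls, so the manual delete of a temporary scope (removed in the next hunk) is no longer needed. A minimal sketch of the ownership idiom with a toy Tensor type:

// Toy sketch: a shared_ptr vector owns the temporaries, raw pointers are
// handed to the API; everything is freed when tmp_tensors goes out of scope.
#include <iostream>
#include <memory>
#include <vector>

struct Tensor { std::vector<float> data; };

int main() {
  std::vector<Tensor> inputs = {{{1, 2}}, {{3, 4}}};

  std::vector<std::shared_ptr<Tensor>> tmp_tensors;  // owns the temporaries
  std::vector<const Tensor *> tmp_input_vec;         // raw views for the call

  for (const auto &in : inputs) {
    tmp_tensors.emplace_back(std::make_shared<Tensor>());
    auto *p = tmp_tensors.back().get();
    p->data = in.data;  // stand-in for framework::TensorCopy to CPU
    tmp_input_vec.push_back(p);
  }

  std::cout << tmp_input_vec.size() << " temporaries, freed automatically\n";
}  // tmp_tensors destructs here; no manual delete, unlike the old tmp_scope
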
@@ -100,27 +92,21 @@ class DistributedLookupTableKernel : public framework::OpKernel<T> {
       // cp temp to origin
       for (size_t idx = 0; idx < output_var_size; ++idx) {
-        framework::Variable *tmp_output_var = tmp_scope->Var(outputs_name[idx]);
-        framework::LoDTensor *tmp_output_tensor =
-            tmp_output_var->GetMutable<framework::LoDTensor>();
         framework::TensorCopy(
-            *tmp_output_tensor, context.GetPlace(), context.device_context(),
+            *tmp_output_vec[idx], context.GetPlace(), context.device_context(),
             outputs_variable[idx]->GetMutable<framework::LoDTensor>());
       }
-      delete tmp_scope;
     }
-    auto id_names = context.InputNames("Ids");
-    auto out_names = context.OutputNames("Outputs");
     auto lookup_table_version =
         context.Attr<std::string>("lookup_table_version");
+    auto id_vars = context.MultiInputVar("Ids");
+    auto out_vars = context.MultiOutputVar("Outputs");
     if (lookup_table_version == "lookup_table_v2") {
-      for (size_t i = 0; i < id_names.size(); ++i) {
-        auto *id_var = scope.FindVar(id_names[i]);
-        auto *out_var = scope.FindVar(out_names[i]);
-        auto *id_tensor = id_var->GetMutable<framework::LoDTensor>();
-        auto *out_tensor = out_var->GetMutable<framework::LoDTensor>();
+      for (size_t i = 0; i < id_vars.size(); ++i) {
+        auto *id_tensor = id_vars[i]->GetMutable<framework::LoDTensor>();
+        auto *out_tensor = out_vars[i]->GetMutable<framework::LoDTensor>();
         auto id_dims = id_tensor->dims();
         out_tensor->Resize(phi::make_ddim({static_cast<int64_t>(id_dims[0]),
......
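
The last hunk walks id_vars and out_vars in lockstep instead of re-resolving each name through the scope, then resizes every output for lookup_table_v2. The Resize call is truncated above; the sketch below assumes the usual v2 shape rule, ids of shape [d0, d1] producing an output of shape [d0, d1, emb_dim].

// Toy sketch of the assumed lookup_table_v2 output-shape rule.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> OutputDims(const std::vector<int64_t> &id_dims,
                                int64_t emb_dim) {
  // Keep the id dimensions and append the embedding width.
  return {id_dims[0], id_dims[1], emb_dim};
}

int main() {
  auto dims = OutputDims({32, 128}, 64);
  assert((dims == std::vector<int64_t>{32, 128, 64}));
}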