diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
index fb8e557c88e35cb190e6795da6f261c61c7bee6a..055c9ff2383cb6eaa3248c83aef90322b93e6a88 100644
--- a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
+++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
@@ -185,8 +185,12 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
   if (!is_build_) {
     LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
-    ::ir::BuildScope(
-        *ir_program_->block(), scope_, local_scope_, &value_2_var_name_map_);
+    ::ir::BuildScope(*ir_program_->block(),
+                     InnerScope(),
+                     &value_2_var_name_,
+                     &variable_2_var_name_,
+                     &var_name_2_id_,
+                     &variable_list_);
 
     std::vector<paddle::framework::OpFuncNode> op_func_nodes;
     interpreter::BuildOpFuncList(place_,
@@ -194,7 +198,7 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
                                  &op_func_nodes,
                                  scope_,
                                  local_scope_,
-                                 value_2_var_name_map_,
+                                 value_2_var_name_,
                                  execution_config_);
     // SetFeedVarsInplaceSkip(feed_names);
     // convert vec func_list to graph
@@ -237,8 +241,12 @@ FetchList NewIRInterpreter::BetaRun(const std::vector<std::string>& feed_names,
   SetDeviceId(place_);
   if (!is_build_) {
     LOG_FIRST_N(INFO, 1) << "New Executor is BetaRunning.";
-    ::ir::BuildScope(
-        *ir_program_->block(), scope_, local_scope_, &value_2_var_name_map_);
+    ::ir::BuildScope(*ir_program_->block(),
+                     InnerScope(),
+                     &value_2_var_name_,
+                     &variable_2_var_name_,
+                     &var_name_2_id_,
+                     &variable_list_);
     BuildInstruction();
     for (size_t instr_id = 0; instr_id < vec_instruction_base_.size();
          ++instr_id) {
@@ -1526,13 +1534,8 @@ void NewIRInterpreter::BuildInstruction() {
        ++it) {
     VLOG(0) << "Build Instruction for op: " << op_idx;
     if ((*it)->dialect()->name() == "pd_kernel") {
-      vec_instruction_base_.emplace_back(
-          std::make_unique<PhiKernelInstruction>(op_idx++,
-                                                 place_,
-                                                 (*it),
-                                                 scope_,
-                                                 local_scope_,
-                                                 value_2_var_name_map_));
+      vec_instruction_base_.emplace_back(std::make_unique<PhiKernelInstruction>(
+          op_idx++, place_, (*it), scope_, local_scope_, value_2_var_name_));
     } else {
       PADDLE_THROW(platform::errors::Unimplemented(
           "Now only support pd_kernel dialect."));
diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.h b/paddle/fluid/framework/new_executor/new_ir_interpreter.h
index 0040e1bb946bc304c7e14562ea71a69104d7c8cd..7f84fdfcdb8806c1446767ee639a65336c5311a0 100644
--- a/paddle/fluid/framework/new_executor/new_ir_interpreter.h
+++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.h
@@ -192,7 +192,11 @@ class NewIRInterpreter : public InterpreterBaseImpl {
 
   std::vector<std::unique_ptr<InstructionBase>> vec_instruction_base_;
 
-  std::unordered_map<::ir::Value, std::string> value_2_var_name_map_;
+  std::unordered_map<::ir::Value, std::string> value_2_var_name_;
+  std::unordered_map<const paddle::framework::Variable*, std::string>
+      variable_2_var_name_;
+  std::map<std::string, int> var_name_2_id_;
+  std::vector<paddle::framework::Variable*> variable_list_;
 };
 
 }  // namespace framework
diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h
index 0c67837648dc58a49880d4257fcc9371b851a721..fd4aecbada17bffa5607042ebdd6a9bfc401e6a9 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h
@@ -55,8 +55,18 @@ class PhiKernelAdaptor {
   void run_kernel_prog(ir::Program* program) {
     auto block = program->block();
-    std::unordered_map<ir::Value, std::string> name_map;
-    BuildScope(*block, scope_, nullptr, &name_map);
+    std::unordered_map<ir::Value, std::string> value_2_var_name;
+    std::unordered_map<const paddle::framework::Variable*, std::string>
+        variable_2_var_name;
+    std::map<std::string, int> var_name_2_id;
+    std::vector<paddle::framework::Variable*> variable_list;
+
+    BuildScope(*block,
+               scope_,
+               &value_2_var_name,
+               &variable_2_var_name,
+               &var_name_2_id,
+               &variable_list);
     ir::IrContext* ctx = ir::IrContext::Instance();
     ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
@@ -88,7 +98,8 @@ class PhiKernelAdaptor {
         phi::MetaTensor,
         paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
         paddle::small_vector<phi::MetaTensor, phi::kOutputSmallVectorSize>,
-        false>((*it), name_map, scope_, nullptr, op_yaml_info_parser, &ctx);
+        false>(
+        (*it), value_2_var_name, scope_, nullptr, op_yaml_info_parser, &ctx);
 
     infer_meta_impl->infer_meta_(&ctx);
@@ -108,12 +119,16 @@ class PhiKernelAdaptor {
         phi::TensorBase*,
         paddle::small_vector<const phi::TensorBase*>,
         paddle::small_vector<phi::TensorBase*>,
-        true>(
-        (*it), name_map, scope_, nullptr, op_yaml_info_parser, &kernel_ctx);
+        true>((*it),
+              value_2_var_name,
+              scope_,
+              nullptr,
+              op_yaml_info_parser,
+              &kernel_ctx);
     kernel_fn(&kernel_ctx);
 
     auto out_value = (*it)->result(0);
-    out_name = name_map[out_value];
+    out_name = value_2_var_name[out_value];
   }
 }
diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
index cb7a484c2a794bcaa17e4e7fc2dbf39f18de4dc2..c44a674275f444fcdbe6722e0d7168a6358b3a44 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
@@ -46,10 +46,15 @@ namespace ir {
 using VariableNameMap =
     std::unordered_map<const paddle::framework::Variable*, std::string>;
 
-paddle::framework::Variable* CreateVar(ir::Value value,
-                                       const std::string& name,
-                                       paddle::framework::Scope* scope,
-                                       paddle::framework::Scope* local_scope) {
+paddle::framework::Variable* CreateVar(
+    ir::Value value,
+    paddle::framework::Scope* inner_scope,
+    bool force_persisable,
+    std::unordered_map<ir::Value, std::string>* value_2_var_name,
+    std::unordered_map<const paddle::framework::Variable*, std::string>*
+        variable_2_var_name,
+    std::map<std::string, int>* var_name_2_id,
+    std::vector<paddle::framework::Variable*>* variable_list) {
   Operation* def_op = value.GetDefiningOp();
   bool is_persisable = false;
   if (def_op->attributes().count("is_persisable")) {
@@ -58,27 +63,41 @@ paddle::framework::Variable* CreateVar(ir::Value value,
                         .dyn_cast<ir::BoolAttribute>()
                         .data();
   }
-  if (is_persisable) {
-    VLOG(6) << "Create var: " << name << " in scope " << scope->root();
-    return const_cast<paddle::framework::Scope*>(scope->root())->Var(name);
+
+  paddle::framework::Variable* var = nullptr;
+  std::string name = "inner_var_" + std::to_string(variable_2_var_name->size());
+  if (force_persisable || is_persisable) {
+    VLOG(6) << "Create var: " << name << " in scope " << inner_scope->root();
+    var = const_cast<paddle::framework::Scope*>(inner_scope->root())->Var(name);
   } else {
-    VLOG(6) << "Create var: " << name << " in scope " << local_scope;
-    return local_scope->Var(name);
+    VLOG(6) << "Create var: " << name << " in scope " << inner_scope;
+    var = inner_scope->Var(name);
   }
+  value_2_var_name->emplace(value, name);
+  variable_2_var_name->emplace(var, name);
+  auto id = var_name_2_id->size();
+  var_name_2_id->emplace(name, id);
+  variable_list->push_back(var);
+  PADDLE_ENFORCE_EQ(
+      variable_list->size(),
+      var_name_2_id->size(),
+      paddle::platform::errors::InvalidArgument(
+          "The size of variable_list and var_name_2_id map should be equal"));
+  return var;
 }
 
 void CheckInputVars(
     ir::Operation* op,
     const std::string& op_name,
-    const std::unordered_map<ir::Value, std::string>& name_map) {
+    const std::unordered_map<ir::Value, std::string>& value_2_var_name) {
   size_t input_num = op->num_operands();
   if (input_num > 0) {
     for (size_t i = 0; i < input_num; ++i) {
       auto value = op->operand(i);
       if (value) {
         PADDLE_ENFORCE_NE(
-            name_map.find(value),
-            name_map.end(),
+            value_2_var_name.find(value),
+            value_2_var_name.end(),
             phi::errors::PreconditionNotMet(
                 "input should in name map, [%d] 'th input of [%s] op",
                 i,
@@ -89,20 +108,25 @@ void CheckInputVars(
 }
 
 void BuildValue(ir::Value value,
-                paddle::framework::Scope* scope,
-                paddle::framework::Scope* local_scope,
-                std::unordered_map<ir::Value, std::string>* name_map,
-                VariableNameMap* variable_name_map,
-                int& count) {  // NOLINT
-  auto inner_local_scope = local_scope != nullptr ? local_scope : scope;
-  std::string name;
-  if (name_map->find(value) != name_map->end()) {
-    name = name_map->at(value);
+                paddle::framework::Scope* inner_scope,
+                std::unordered_map<ir::Value, std::string>* value_2_var_name,
+                std::unordered_map<const paddle::framework::Variable*,
+                                   std::string>* variable_2_var_name,
+                std::map<std::string, int>* var_name_2_id,
+                std::vector<paddle::framework::Variable*>* variable_list) {
+  paddle::framework::Variable* var = nullptr;
+  if (value_2_var_name->find(value) != value_2_var_name->end()) {
+    var = inner_scope->FindVar(value_2_var_name->at(value));
   } else {
-    name = "inner_var_" + std::to_string(count++);
-    name_map->emplace(value, name);
+    var = CreateVar(value,
+                    inner_scope,
+                    false,
+                    value_2_var_name,
+                    variable_2_var_name,
+                    var_name_2_id,
+                    variable_list);
   }
-  auto var = CreateVar(value, name, scope, inner_local_scope);
+
   // Only support DenseTensor or Vector<DenseTensor>
   if (!value.type()) {
     var->GetMutable<phi::DenseTensor>();
@@ -120,11 +144,15 @@ void BuildValue(ir::Value value,
                         paddle::platform::errors::Fatal(
                             "Element of VectorType output only support "
                             "DenseTensorType"));
-      std::string name_i = "inner_var_" + std::to_string(count++);
-      auto var_i = CreateVar(value, name_i, scope, inner_local_scope);
+      auto var_i = CreateVar(value,
+                             inner_scope,
+                             false,
+                             value_2_var_name,
+                             variable_2_var_name,
+                             var_name_2_id,
+                             variable_list);
       var_i->GetMutable<phi::DenseTensor>();
       tensor_array->emplace_back(var_i);
-      variable_name_map->emplace(var_i, name_i);
     }
   } else {
     PADDLE_THROW(phi::errors::PreconditionNotMet(
@@ -132,24 +160,25 @@ void BuildValue(ir::Value value,
   }
 }
 
-void HandleForSpecialOp(ir::Operation* op,
-                        const VariableNameMap& variable_name_map,
-                        paddle::framework::Scope* scope,
-                        paddle::framework::Scope* local_scope,
-                        std::unordered_map<ir::Value, std::string>* name_map,
-                        int& count) {  // NOLINT
+void HandleForSpecialOp(
+    ir::Operation* op,
+    paddle::framework::Scope* inner_scope,
+    std::unordered_map<ir::Value, std::string>* value_2_var_name,
+    std::unordered_map<const paddle::framework::Variable*, std::string>*
+        variable_2_var_name,
+    std::map<std::string, int>* var_name_2_id,
+    std::vector<paddle::framework::Variable*>* variable_list) {
   std::string op_name = op->name();
   if (op->attributes().count("op_name")) {
     op_name =
         op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
   }
-  size_t input_num = op->num_operands();
 
   if (op_name == "pd.fetch") {
     // fetch is a very special op, with no output
-    VLOG(6) << "Handle for pd.fetch:";
-    auto var = scope->Var("fetch");
-    VLOG(6) << "Create var: fetch in scope " << scope;
+    auto var = const_cast<paddle::framework::Scope*>(inner_scope->root())
+                   ->Var("fetch");
+    VLOG(6) << "Create var: fetch in scope " << inner_scope->root();
     auto fetch_list = var->GetMutable<paddle::framework::FetchList>();
     int index =
         op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
@@ -157,16 +186,20 @@ void HandleForSpecialOp(ir::Operation* op,
   }
 
   if (op_name == "pd.feed") {
-    VLOG(6) << "Handle for pd.feed:";
     auto value = op->result(0);
-    std::string name = "inner_var_" + std::to_string(count++);
-    name_map->emplace(value, name);
-    auto var = CreateVar(value, name, scope, local_scope);
+    auto var = CreateVar(value,
+                         inner_scope,
+                         false,
+                         value_2_var_name,
+                         variable_2_var_name,
+                         var_name_2_id,
+                         variable_list);
     // TODO(phlrain): need to update here, support StringTensor
     auto out_tensor = var->GetMutable<phi::DenseTensor>();
-    auto feed_var = scope->Var("feed");
-    VLOG(6) << "Create var: feed in scope " << scope;
+    auto feed_var =
+        const_cast<paddle::framework::Scope*>(inner_scope->root())->Var("feed");
+    VLOG(6) << "Create var: feed in scope " << inner_scope->root();
     int index =
         op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
     auto feed_list = feed_var->Get<paddle::framework::FeedList>();
@@ -176,30 +209,33 @@ void HandleForSpecialOp(ir::Operation* op,
   }
 
   if (op_name == "builtin.combine") {
-    VLOG(6) << "Handle for builtin.combine:";
     auto out_value = op->result(0);
-    std::string name;
-    if (name_map->find(out_value) != name_map->end()) {
-      name = name_map->at(out_value);
+
+    paddle::framework::Variable* var = nullptr;
+    if (value_2_var_name->find(out_value) != value_2_var_name->end()) {
+      var = inner_scope->FindVar(value_2_var_name->at(out_value));
     } else {
-      name = "inner_var_" + std::to_string(count++);
-      name_map->emplace(out_value, name);
+      var = CreateVar(out_value,
+                      inner_scope,
+                      false,
+                      value_2_var_name,
+                      variable_2_var_name,
+                      var_name_2_id,
+                      variable_list);
    }
-    auto var = CreateVar(out_value, name, scope, local_scope);
 
     auto tensor_array = var->GetMutable<paddle::framework::VariableRefArray>();
     // clear tensor array
     tensor_array->clear();
-
+    size_t input_num = op->num_operands();
     for (size_t i = 0; i < input_num; ++i) {
       auto value = op->operand(i);
-
       PADDLE_ENFORCE_EQ(
-          name_map->count(value),
+          value_2_var_name->count(value),
           true,
           phi::errors::PreconditionNotMet("can not found input of combine op"));
       tensor_array->emplace_back(
-          CreateVar(value, name_map->at(value), scope, local_scope));
+          inner_scope->FindVar(value_2_var_name->at(value)));
     }
   }
@@ -210,14 +246,15 @@ void HandleForSpecialOp(ir::Operation* op,
                              .dyn_cast<ir::StrAttribute>()
                              .data();
 
-    auto in_ptr = op->operand(0);
+    auto value = op->operand(0);
     // change opreand name to param_name
+    auto orig_name = value_2_var_name->at(value);
 
-    auto orig_name = name_map->at(in_ptr);
-    if (scope->FindVar(param_name) == nullptr) {
-      scope->Rename(orig_name, param_name);
+    if (inner_scope->root()->FindVar(param_name) == nullptr) {
+      const_cast<paddle::framework::Scope*>(inner_scope->root())
+          ->Rename(orig_name, param_name);
     }
-    (*name_map)[in_ptr] = param_name;
+    (*value_2_var_name)[value] = param_name;
   }
 
   if (op_name == "builtin.get_parameter") {
@@ -226,44 +263,44 @@ void HandleForSpecialOp(ir::Operation* op,
                              .at("parameter_name")
                              .dyn_cast<ir::StrAttribute>()
                              .data();
-    auto out_ptr = op->result(0);
-    name_map->emplace(out_ptr, param_name);
+    auto value = op->result(0);
+    value_2_var_name->emplace(value, param_name);
   }
 
   if (op_name == "builtin.slice") {
     VLOG(6) << "Handle for builtin.slice";
     auto out_value = op->result(0);
-
     auto in_value = op->operand(0);
-
-    PADDLE_ENFORCE_EQ(name_map->count(in_value),
+    PADDLE_ENFORCE_EQ(value_2_var_name->count(in_value),
                       true,
                       phi::errors::PreconditionNotMet(
                           "input of buildin slice not in name map"));
 
     int index =
         op->attributes().at("index").dyn_cast<ir::Int32Attribute>().data();
-    auto in_var = scope->FindVar(name_map->at(in_value));
+    auto in_var = inner_scope->FindVar(value_2_var_name->at(in_value));
     auto variable_array = in_var->Get<paddle::framework::VariableRefArray>();
 
     PADDLE_ENFORCE_EQ(
-        variable_name_map.count(variable_array[index]),
+        variable_2_var_name->count(variable_array[index]),
        true,
        phi::errors::PreconditionNotMet("[%d] the variable in build slice "
                                        "input MUST in variable name map",
                                        index));
 
-    std::string var_name = variable_name_map.at(variable_array[index]);
-
-    name_map->emplace(out_value, var_name);
+    std::string var_name = variable_2_var_name->at(variable_array[index]);
+    value_2_var_name->emplace(out_value, var_name);
   }
 }
 
-void HandleForInplaceOp(ir::Operation* op,
-                        paddle::framework::Scope* scope,
-                        paddle::framework::Scope* local_scope,
-                        std::unordered_map<ir::Value, std::string>* name_map,
-                        int& count) {  // NOLINT
+void HandleForInplaceOp(
+    ir::Operation* op,
+    paddle::framework::Scope* inner_scope,
+    std::unordered_map<ir::Value, std::string>* value_2_var_name,
+    std::unordered_map<const paddle::framework::Variable*, std::string>*
+        variable_2_var_name,
+    std::map<std::string, int>* var_name_2_id,
+    std::vector<paddle::framework::Variable*>* variable_list) {
   if (op->num_results() < 1) return;
   ir::IrContext* ctx = ir::IrContext::Instance();
   std::string op_name = op->name();
@@ -271,12 +308,12 @@ void HandleForInplaceOp(ir::Operation* op,
     op_name =
         op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
   }
-  VLOG(4) << "HandleForInplaceOp: " << op_name;
+
   ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
   paddle::dialect::OpYamlInfoParser yaml_parser(
       op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>()
          ->get_op_info_());
-  VariableNameMap variable_name_map;
+
   for (size_t i = 0; i < op->num_results(); ++i) {
     ir::Value value = op->result(i);
     std::string value_name = yaml_parser.OutputNames()[i];
@@ -284,35 +321,36 @@ void HandleForInplaceOp(ir::Operation* op,
       std::string inplace_name = yaml_parser.InplaceName(value_name);
       ir::Value inplace_value =
           op->operand(yaml_parser.InputName2Id().at(inplace_name));
-      std::string var_name = name_map->at(inplace_value);
+      std::string var_name = value_2_var_name->at(inplace_value);
       VLOG(4) << "inplace: " << value_name << " -> " << inplace_name
               << " (var: " << var_name << ")";
-      name_map->emplace(value, var_name);
+      value_2_var_name->emplace(value, var_name);
     } else {
-      BuildValue(
-          value, scope, local_scope, name_map, &variable_name_map, count);
+      BuildValue(value,
+                 inner_scope,
+                 value_2_var_name,
+                 variable_2_var_name,
+                 var_name_2_id,
+                 variable_list);
    }
  }
 }
 
+// NOTE(zhiqiu): the persistable is created in inner_scope's root, and other is
+// created in inner_scope.
 void BuildScope(const ir::Block& block,
-                paddle::framework::Scope* scope,
-                paddle::framework::Scope* local_scope,
-                std::unordered_map<ir::Value, std::string>* name_map) {
-  VLOG(4) << "***** [before build] scope: ******\n"
+                paddle::framework::Scope* inner_scope,
+                std::unordered_map<ir::Value, std::string>* value_2_var_name,
+                std::unordered_map<const paddle::framework::Variable*,
+                                   std::string>* variable_2_var_name,
+                std::map<std::string, int>* var_name_2_id,
+                std::vector<paddle::framework::Variable*>* variable_list) {
+  VLOG(4) << "***** [before build] scope"
+          << "(" << inner_scope << ") ******\n"
           << paddle::framework::GenScopeTreeDebugInfo(
-                 const_cast<paddle::framework::Scope*>(scope->root()));
-  // NOTE(zhiqiu): if use local_scope (local_scope != nullptr), the persistable
-  // is created in scope , and other is created in local_scope.
-  auto inner_local_scope = local_scope != nullptr ? local_scope : scope;
-  VLOG(6) << "Build: scope [" << scope << "] inner_local_scope ["
-          << inner_local_scope << "]";
-
-  std::unordered_map<const paddle::framework::Variable*, std::string>
-      variable_name_map;
-
-  // int count = name_map->size();
-  int count = name_map->size();
+                 const_cast<paddle::framework::Scope*>(inner_scope->root()));
+
+  // int count = value_2_var_name->size();
 
   for (auto it = block.begin(); it != block.end(); ++it) {
     ir::Operation* op = *it;
@@ -321,19 +359,21 @@ void BuildScope(const ir::Block& block,
       op_name =
           op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
     }
-
-    VLOG(4) << "BuildScope for :" << op_name;
+    VLOG(4) << "build op:" << op_name;
 
     if (op_name == "pd.feed" || op_name == "pd.fetch" ||
         op_name == "builtin.combine" || op_name == "builtin.set_parameter" ||
         op_name == "builtin.get_parameter" || op_name == "builtin.slice") {
-      VLOG(6) << "HandleForSpecialOp: " << op_name;
-      HandleForSpecialOp(
-          op, variable_name_map, scope, inner_local_scope, name_map, count);
+      HandleForSpecialOp(op,
+                         inner_scope,
+                         value_2_var_name,
+                         variable_2_var_name,
+                         var_name_2_id,
+                         variable_list);
       continue;
     }
 
-    CheckInputVars(op, op_name, *name_map);
+    CheckInputVars(op, op_name, *value_2_var_name);
 
     if (op->num_results() < 1) continue;
     if (op->attributes().count("is_inplace") != 0 &&
@@ -341,23 +381,29 @@ void BuildScope(const ir::Block& block,
            .at("is_inplace")
            .dyn_cast<ir::BoolAttribute>()
            .data()) {
-      HandleForInplaceOp(op, scope, inner_local_scope, name_map, count);
+      HandleForInplaceOp(op,
+                         inner_scope,
+                         value_2_var_name,
+                         variable_2_var_name,
+                         var_name_2_id,
+                         variable_list);
       continue;
     } else {
       for (size_t i = 0; i < op->num_results(); ++i) {
         BuildValue(op->result(i),
-                   scope,
-                   local_scope,
-                   name_map,
-                   &variable_name_map,
-                   count);
+                   inner_scope,
+                   value_2_var_name,
+                   variable_2_var_name,
+                   var_name_2_id,
+                   variable_list);
      }
    }
  }
 
-  VLOG(4) << "***** [after build] scope: ******\n"
+  VLOG(4) << "***** [after build] scope"
+          << "(" << inner_scope << ") ******\n"
           << paddle::framework::GenScopeTreeDebugInfo(
-                 const_cast<paddle::framework::Scope*>(scope->root()));
+                 const_cast<paddle::framework::Scope*>(inner_scope->root()));
 }
 
 }  // namespace ir
diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
index 136f1e3b930f27eb5cde007ed3d6c141738c6b23..7f6a804382921a9836523654342327486609464f 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
@@ -41,36 +41,13 @@
 #include "glog/logging.h"
 
 namespace ir {
-paddle::framework::Variable* CreateVar(ir::Value value,
-                                       const std::string& name,
-                                       paddle::framework::Scope* scope,
-                                       paddle::framework::Scope* local_scope);
-
-void BuildValue(ir::Value value,
-                paddle::framework::Scope* scope,
-                paddle::framework::Scope* local_scope,
-                std::unordered_map<ir::Value, std::string>* name_map,
-                int& count);  // NOLINT
-
-void HandleForSpecialOp(ir::Operation* op,
-                        paddle::framework::Scope* scope,
-                        paddle::framework::Scope* local_scope,
-                        std::unordered_map<ir::Value, std::string>* name_map,
-                        int& count);  // NOLINT
-
-void HandleForInplaceOp(ir::Operation* op,
-                        paddle::framework::Scope* scope,
-                        paddle::framework::Scope* local_scope,
-                        std::unordered_map<ir::Value, std::string>* name_map,
-                        int& count);  // NOLINT
-
-void CheckInputVars(ir::Operation* op,
-                    const std::unordered_map<ir::Value, std::string>& name_map);
-
 void BuildScope(const ir::Block& block,
-                paddle::framework::Scope* scope,
-                paddle::framework::Scope* local_scope,
-                std::unordered_map<ir::Value, std::string>* name_map);
+                paddle::framework::Scope* inner_scope,
+                std::unordered_map<ir::Value, std::string>* value_2_var_name,
+                std::unordered_map<const paddle::framework::Variable*,
+                                   std::string>* variable_2_var_name,
+                std::map<std::string, int>* var_name_2_id,
+                std::vector<paddle::framework::Variable*>* variable_list);
 
 template <typename Context,
           typename InType,
@@ -313,6 +290,7 @@ void BuildPhiContext(
   for (size_t i = 0; i < op->num_results(); ++i) {
     ir::Value out_ptr = op->result(i);
     auto name = name_map.at(out_ptr);
+    VLOG(6) << "ctx->EmplaceBackOutput: " << name;
     auto out_type = out_ptr.type();
     if (!out_type) {
       phi::DenseTensor* ptr = nullptr;
@@ -329,14 +307,14 @@ void BuildPhiContext(
       ctx->EmplaceBackOutput(out_ptr);
     } else if (out_type.isa<paddle::dialect::AllocatedDenseTensorType>()) {
       ctx->EmplaceBackOutput(OutType(const_cast<phi::DenseTensor*>(
-          &(inner_scope->Var(name)->Get<phi::DenseTensor>()))));
+          &(inner_scope->FindVar(name)->Get<phi::DenseTensor>()))));
     } else if (out_type.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
       ctx->EmplaceBackOutput(OutType(const_cast<phi::SelectedRows*>(
-          &(inner_scope->Var(name)->Get<phi::SelectedRows>()))));
+          &(inner_scope->FindVar(name)->Get<phi::SelectedRows>()))));
     } else if (out_type.isa<ir::VectorType>()) {
       OutListType outputs;
       auto& variable_array =
-          scope->Var(name)->Get<paddle::framework::VariableRefArray>();
+          scope->FindVar(name)->Get<paddle::framework::VariableRefArray>();
       for (size_t i = 0; i < variable_array.size(); ++i) {
         outputs.emplace_back(OutType(const_cast<phi::DenseTensor*>(
             &(variable_array[i]->Get<phi::DenseTensor>()))));
@@ -360,6 +338,7 @@ void BuildPhiContext(
      }
    }
  }
+  VLOG(6) << "Done build phi context";
 }
 
 }  // namespace ir
diff --git a/test/cpp/new_executor/standalone_executor_new_ir_test.cc b/test/cpp/new_executor/standalone_executor_new_ir_test.cc
index 706f9f33c8853d99b521d9a622cb593daf01cbee..4c52621190227ec0d7325d554c987c36752cf992 100644
--- a/test/cpp/new_executor/standalone_executor_new_ir_test.cc
+++ b/test/cpp/new_executor/standalone_executor_new_ir_test.cc
@@ -89,5 +89,42 @@ TEST(StandaloneExecutor, run) {
   EXPECT_EQ(res3, true);
 }
 
+TEST(StandaloneExecutor, run_inplace_sqrt) {
+  ir::IrContext* ctx = ir::IrContext::Instance();
+  ir::Program program((ctx));
+  ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
+  ir::Builder builder = ir::Builder(ctx, program.block());
+
+  paddle::dialect::FullOp full = builder.Build<paddle::dialect::FullOp>(
+      std::vector<int64_t>{2, 2}, 4.0, phi::DataType::FLOAT32, phi::CPUPlace());
+
+  builder.Build<paddle::dialect::Sqrt_Op>(full->result(0));
+
+  auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
+
+  auto place = platform::CPUPlace();
+  Scope scope;
+  InterpreterCore test_core(place, std::move(kernel_program), &scope);
+  test_core.BetaRun({});
+
+  auto out_tensor = test_core.local_scope() == nullptr
+                        ? scope.FindVar("inner_var_0")->Get<phi::DenseTensor>()
+                        : test_core.local_scope()
+                              ->FindVar("inner_var_0")
+                              ->Get<phi::DenseTensor>();
+
+  bool res0 = simple_cmp(out_tensor.data<float>()[0], 2.0);
+  bool res1 = simple_cmp(out_tensor.data<float>()[1], 2.0);
+  bool res2 = simple_cmp(out_tensor.data<float>()[2], 2.0);
+  bool res3 = simple_cmp(out_tensor.data<float>()[3], 2.0);
+
+  EXPECT_EQ(scope.kids().size(), 1u);
+  EXPECT_EQ(scope.kids().front()->Size(), 1u);
+  EXPECT_EQ(res0, true);
+  EXPECT_EQ(res1, true);
+  EXPECT_EQ(res2, true);
+  EXPECT_EQ(res3, true);
+}
+
 }  // namespace framework
 }  // namespace paddle
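A note on the core bookkeeping change in this patch: `BuildScope` and its helpers no longer thread a raw `int count` through every call to mint `inner_var_N` names. Instead, `CreateVar` derives the name from the current size of `variable_2_var_name` and records every new variable in four synchronized structures (`value_2_var_name`, `variable_2_var_name`, `var_name_2_id`, `variable_list`), enforcing that the id map and the variable list stay the same size. Below is a minimal standalone sketch of that contract using plain STL types; `FakeValue`, `FakeVariable`, and `ScopeBookkeeping` are hypothetical stand-ins for `ir::Value`, `paddle::framework::Variable`, and the real `Scope` machinery, not Paddle APIs:

```cpp
#include <cassert>
#include <iostream>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-ins for ir::Value and paddle::framework::Variable.
using FakeValue = int;
struct FakeVariable {};

// The four structures the refactored BuildScope threads through its helpers.
struct ScopeBookkeeping {
  std::unordered_map<FakeValue, std::string> value_2_var_name;
  std::unordered_map<const FakeVariable*, std::string> variable_2_var_name;
  std::map<std::string, int> var_name_2_id;
  std::vector<FakeVariable*> variable_list;
};

// Mirrors the new CreateVar contract: the "inner_var_N" name is derived from
// the current map size rather than an external `int count`, and all four
// structures are updated together.
FakeVariable* CreateVar(FakeValue value, ScopeBookkeeping* bk) {
  std::string name =
      "inner_var_" + std::to_string(bk->variable_2_var_name.size());
  FakeVariable* var = new FakeVariable();  // real code allocates via Scope::Var
  bk->value_2_var_name.emplace(value, name);
  bk->variable_2_var_name.emplace(var, name);
  bk->var_name_2_id.emplace(name, static_cast<int>(bk->var_name_2_id.size()));
  bk->variable_list.push_back(var);
  // The invariant the diff guards with PADDLE_ENFORCE_EQ.
  assert(bk->variable_list.size() == bk->var_name_2_id.size());
  return var;
}

int main() {
  ScopeBookkeeping bk;
  CreateVar(/*value=*/100, &bk);  // -> inner_var_0, id 0
  CreateVar(/*value=*/101, &bk);  // -> inner_var_1, id 1
  for (const auto& kv : bk.var_name_2_id) {
    std::cout << kv.first << " -> id " << kv.second << "\n";
  }
  for (FakeVariable* v : bk.variable_list) delete v;
  return 0;
}
```

Deriving the name and id from the container sizes (rather than a counter passed by reference) makes it impossible for the name sequence and the id map to drift apart, which is exactly what the added `PADDLE_ENFORCE_EQ` check asserts.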
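The new `run_inplace_sqrt` test relies on the aliasing performed by `HandleForInplaceOp`: for an op whose `is_inplace` attribute is set, the output `ir::Value` is mapped to the var name of its inplaced input instead of a fresh variable being created. That is why the test reads `sqrt_`'s result back from `inner_var_0` (the `full` output) and expects the inner scope to hold exactly one variable. A hedged sketch of that mapping, reusing the hypothetical `FakeValue`/`ScopeBookkeeping` types from the example above:

```cpp
// Sketch only: reuses the hypothetical FakeValue / ScopeBookkeeping types
// from the previous example. Mirrors HandleForInplaceOp: the inplace output
// aliases the input's var name, so no new variable, name, or id is created.
void AliasInplaceOutput(FakeValue in, FakeValue out, ScopeBookkeeping* bk) {
  const std::string& var_name = bk->value_2_var_name.at(in);
  bk->value_2_var_name.emplace(out, var_name);
}
```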