diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 74310a6046c7ddaaaa57846560ea7273816baf9c..a4fcf0773f6239b9044b194eea369f8482a3a894 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -428,8 +428,17 @@ void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) { } outs_map.emplace(var_name_item.first, std::move(out_vars)); } + // set runtime_ctx and infershape_ctx_ - instr_node->ResetContext(ins_map, outs_map); + if (instr_node->OpBase()->Type() == "cinn_launch") { // OP use scope in + // kernel + Scope* local_scope = create_local_scope_ + ? global_scope_->GetMutableLocalScope() + : global_scope_->GetMutableScope(); + instr_node->ResetContextWithScope(ins_map, outs_map, *local_scope); + } else { + instr_node->ResetContext(ins_map, outs_map); + } } void InterpreterCore::BuildSkipShareLoDInfo() { diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index ed813c78bc3689cf85b0495c523ba29d787a708c..afddcb580b9d878693ea45fb85eceaec0d806081 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -392,8 +392,19 @@ void build_op_func_list(const platform::Place& place, platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); Scope scope; + Scope* runtime_scope = &scope; + // NOTE(Ruibiao): We do not encourage directly using scope in OP kernel. + // But some OPs do have such behavior (e.g., cinn_launch OP). Here special + // treatment for them. + if (op_with_kernel->Type() == "cinn_launch") { + VLOG(6) << "OP(" << op_with_kernel->Type() << ") use scope in kernel, " + "so pass a real scope to " + "ExecutionContext"; + runtime_scope = local_scope; + } + auto expected_kernel_key = op_with_kernel->GetExpectedKernelType( - ExecutionContext(*op, scope, *dev_ctx, runtime_context)); + ExecutionContext(*op, *runtime_scope, *dev_ctx, runtime_context)); op_with_kernel->ResetKernelType(new OpKernelType(expected_kernel_key)); // change device by the device_guard() @@ -441,8 +452,8 @@ void build_op_func_list(const platform::Place& place, op_with_kernel->Info().infer_shape_(&infer_shape_ctx); } - auto exec_ctx = - ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context); + auto exec_ctx = ExecutionContext(*op_with_kernel, *runtime_scope, + *dev_ctx, runtime_context); auto run_phi_kernel = false; if (phi::KernelFactory::Instance().HasCompatiblePhiKernel( diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.cc b/paddle/fluid/framework/new_executor/new_executor_defs.cc index 86d534b0b4edd641675ad6e125133d404f05528e..3c2395d4320a17edb80fc2308f0bb3e554d470ed 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.cc +++ b/paddle/fluid/framework/new_executor/new_executor_defs.cc @@ -755,6 +755,16 @@ void Instruction::ResetContext(const VariableValueMap& in_vars, new ExecutionContext(*OpBase(), scope_, dev_ctx_, *runtime_ctx_.get())); } +void Instruction::ResetContextWithScope(const VariableValueMap& in_vars, + const VariableValueMap& out_vars, + const framework::Scope& scope) { + runtime_ctx_.reset(new RuntimeContext(in_vars, out_vars)); + infershape_ctx_.reset( + new InterpretercoreInferShapeContext(*OpBase(), *runtime_ctx_.get())); + execution_ctx_.reset( + new ExecutionContext(*OpBase(), scope, dev_ctx_, *runtime_ctx_.get())); +} + std::shared_ptr Instruction::InnerRuntimeContext() const { return runtime_ctx_; } diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index 6a1e46e3592421e35d1b8b5b04f6a09916e03e6a..28b9f6f0130f5b2fbd209d6a46ee95be544a5877 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -347,6 +347,10 @@ class Instruction { void ResetContext(const VariableValueMap& in_vars, const VariableValueMap& out_vars); + void ResetContextWithScope(const VariableValueMap& in_vars, + const VariableValueMap& out_vars, + const framework::Scope& scope); + std::shared_ptr InnerRuntimeContext() const; std::shared_ptr InnerInferShapeContext()