未验证 提交 f2f1de7b 编写于 作者: R Ruibiao Chen 提交者: GitHub

Support cinn_launch op in standalone executor (#42046)

* Support cinn_launch OP in standalone executor

* Remove some redundant code
上级 3da8066a
...@@ -428,8 +428,17 @@ void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) { ...@@ -428,8 +428,17 @@ void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) {
} }
outs_map.emplace(var_name_item.first, std::move(out_vars)); outs_map.emplace(var_name_item.first, std::move(out_vars));
} }
// set runtime_ctx and infershape_ctx_ // set runtime_ctx and infershape_ctx_
instr_node->ResetContext(ins_map, outs_map); if (instr_node->OpBase()->Type() == "cinn_launch") { // OP use scope in
// kernel
Scope* local_scope = create_local_scope_
? global_scope_->GetMutableLocalScope()
: global_scope_->GetMutableScope();
instr_node->ResetContextWithScope(ins_map, outs_map, *local_scope);
} else {
instr_node->ResetContext(ins_map, outs_map);
}
} }
void InterpreterCore::BuildSkipShareLoDInfo() { void InterpreterCore::BuildSkipShareLoDInfo() {
......
...@@ -392,8 +392,19 @@ void build_op_func_list(const platform::Place& place, ...@@ -392,8 +392,19 @@ void build_op_func_list(const platform::Place& place,
platform::DeviceContextPool::Instance(); platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
Scope scope; Scope scope;
Scope* runtime_scope = &scope;
// NOTE(Ruibiao): We do not encourage directly using scope in OP kernel.
// But some OPs do have such behavior (e.g., cinn_launch OP). Here special
// treatment for them.
if (op_with_kernel->Type() == "cinn_launch") {
VLOG(6) << "OP(" << op_with_kernel->Type() << ") use scope in kernel, "
"so pass a real scope to "
"ExecutionContext";
runtime_scope = local_scope;
}
auto expected_kernel_key = op_with_kernel->GetExpectedKernelType( auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(
ExecutionContext(*op, scope, *dev_ctx, runtime_context)); ExecutionContext(*op, *runtime_scope, *dev_ctx, runtime_context));
op_with_kernel->ResetKernelType(new OpKernelType(expected_kernel_key)); op_with_kernel->ResetKernelType(new OpKernelType(expected_kernel_key));
// change device by the device_guard() // change device by the device_guard()
...@@ -441,8 +452,8 @@ void build_op_func_list(const platform::Place& place, ...@@ -441,8 +452,8 @@ void build_op_func_list(const platform::Place& place,
op_with_kernel->Info().infer_shape_(&infer_shape_ctx); op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
} }
auto exec_ctx = auto exec_ctx = ExecutionContext(*op_with_kernel, *runtime_scope,
ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context); *dev_ctx, runtime_context);
auto run_phi_kernel = false; auto run_phi_kernel = false;
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel( if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(
......
...@@ -755,6 +755,16 @@ void Instruction::ResetContext(const VariableValueMap& in_vars, ...@@ -755,6 +755,16 @@ void Instruction::ResetContext(const VariableValueMap& in_vars,
new ExecutionContext(*OpBase(), scope_, dev_ctx_, *runtime_ctx_.get())); new ExecutionContext(*OpBase(), scope_, dev_ctx_, *runtime_ctx_.get()));
} }
void Instruction::ResetContextWithScope(const VariableValueMap& in_vars,
const VariableValueMap& out_vars,
const framework::Scope& scope) {
runtime_ctx_.reset(new RuntimeContext(in_vars, out_vars));
infershape_ctx_.reset(
new InterpretercoreInferShapeContext(*OpBase(), *runtime_ctx_.get()));
execution_ctx_.reset(
new ExecutionContext(*OpBase(), scope, dev_ctx_, *runtime_ctx_.get()));
}
std::shared_ptr<RuntimeContext> Instruction::InnerRuntimeContext() const { std::shared_ptr<RuntimeContext> Instruction::InnerRuntimeContext() const {
return runtime_ctx_; return runtime_ctx_;
} }
......
...@@ -347,6 +347,10 @@ class Instruction { ...@@ -347,6 +347,10 @@ class Instruction {
void ResetContext(const VariableValueMap& in_vars, void ResetContext(const VariableValueMap& in_vars,
const VariableValueMap& out_vars); const VariableValueMap& out_vars);
void ResetContextWithScope(const VariableValueMap& in_vars,
const VariableValueMap& out_vars,
const framework::Scope& scope);
std::shared_ptr<RuntimeContext> InnerRuntimeContext() const; std::shared_ptr<RuntimeContext> InnerRuntimeContext() const;
std::shared_ptr<InterpretercoreInferShapeContext> InnerInferShapeContext() std::shared_ptr<InterpretercoreInferShapeContext> InnerInferShapeContext()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册