未验证 提交 6e7e4501 编写于 作者: R Ruibiao Chen 提交者: GitHub

Cherry pick for standalone executor (#42281)

* [cherry-pick] Support cinn_launch op in standalone executor (#42046)

* Support cinn_launch OP in standalone executor

* Remove some redundant code

* [cherry-pick] Do not reset default stream for StreamSafeCUDAAllocator (#42149)
上级 9bc423b1
...@@ -429,8 +429,17 @@ void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) { ...@@ -429,8 +429,17 @@ void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) {
} }
outs_map.emplace(var_name_item.first, std::move(out_vars)); outs_map.emplace(var_name_item.first, std::move(out_vars));
} }
// set runtime_ctx and infershape_ctx_ // set runtime_ctx and infershape_ctx_
instr_node->ResetContext(ins_map, outs_map); if (instr_node->OpBase()->Type() == "cinn_launch") { // OP use scope in
// kernel
Scope* local_scope = create_local_scope_
? global_scope_->GetMutableLocalScope()
: global_scope_->GetMutableScope();
instr_node->ResetContextWithScope(ins_map, outs_map, *local_scope);
} else {
instr_node->ResetContext(ins_map, outs_map);
}
} }
void InterpreterCore::BuildSkipShareLoDInfo() { void InterpreterCore::BuildSkipShareLoDInfo() {
......
...@@ -393,8 +393,19 @@ void build_op_func_list(const platform::Place& place, ...@@ -393,8 +393,19 @@ void build_op_func_list(const platform::Place& place,
platform::DeviceContextPool::Instance(); platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
Scope scope; Scope scope;
Scope* runtime_scope = &scope;
// NOTE(Ruibiao): We do not encourage directly using scope in OP kernel.
// But some OPs do have such behavior (e.g., cinn_launch OP). Here special
// treatment for them.
if (op_with_kernel->Type() == "cinn_launch") {
VLOG(6) << "OP(" << op_with_kernel->Type() << ") use scope in kernel, "
"so pass a real scope to "
"ExecutionContext";
runtime_scope = local_scope;
}
auto expected_kernel_key = op_with_kernel->GetExpectedKernelType( auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(
ExecutionContext(*op, scope, *dev_ctx, runtime_context)); ExecutionContext(*op, *runtime_scope, *dev_ctx, runtime_context));
op_with_kernel->ResetKernelType(new OpKernelType(expected_kernel_key)); op_with_kernel->ResetKernelType(new OpKernelType(expected_kernel_key));
// change device by the device_guard() // change device by the device_guard()
...@@ -442,8 +453,8 @@ void build_op_func_list(const platform::Place& place, ...@@ -442,8 +453,8 @@ void build_op_func_list(const platform::Place& place,
op_with_kernel->Info().infer_shape_(&infer_shape_ctx); op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
} }
auto exec_ctx = auto exec_ctx = ExecutionContext(*op_with_kernel, *runtime_scope,
ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context); *dev_ctx, runtime_context);
auto run_phi_kernel = false; auto run_phi_kernel = false;
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel( if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(
......
...@@ -765,6 +765,16 @@ void Instruction::ResetContext(const VariableValueMap& in_vars, ...@@ -765,6 +765,16 @@ void Instruction::ResetContext(const VariableValueMap& in_vars,
new ExecutionContext(*OpBase(), scope_, dev_ctx_, *runtime_ctx_.get())); new ExecutionContext(*OpBase(), scope_, dev_ctx_, *runtime_ctx_.get()));
} }
void Instruction::ResetContextWithScope(const VariableValueMap& in_vars,
const VariableValueMap& out_vars,
const framework::Scope& scope) {
runtime_ctx_.reset(new RuntimeContext(in_vars, out_vars));
infershape_ctx_.reset(
new InterpretercoreInferShapeContext(*OpBase(), *runtime_ctx_.get()));
execution_ctx_.reset(
new ExecutionContext(*OpBase(), scope, dev_ctx_, *runtime_ctx_.get()));
}
std::shared_ptr<RuntimeContext> Instruction::InnerRuntimeContext() const { std::shared_ptr<RuntimeContext> Instruction::InnerRuntimeContext() const {
return runtime_ctx_; return runtime_ctx_;
} }
......
...@@ -351,6 +351,10 @@ class Instruction { ...@@ -351,6 +351,10 @@ class Instruction {
void ResetContext(const VariableValueMap& in_vars, void ResetContext(const VariableValueMap& in_vars,
const VariableValueMap& out_vars); const VariableValueMap& out_vars);
void ResetContextWithScope(const VariableValueMap& in_vars,
const VariableValueMap& out_vars,
const framework::Scope& scope);
std::shared_ptr<RuntimeContext> InnerRuntimeContext() const; std::shared_ptr<RuntimeContext> InnerRuntimeContext() const;
std::shared_ptr<InterpretercoreInferShapeContext> InnerInferShapeContext() std::shared_ptr<InterpretercoreInferShapeContext> InnerInferShapeContext()
......
...@@ -415,6 +415,23 @@ class AllocatorFacadePrivate { ...@@ -415,6 +415,23 @@ class AllocatorFacadePrivate {
void SetDefaultStream(const platform::CUDAPlace& place, gpuStream_t stream) { void SetDefaultStream(const platform::CUDAPlace& place, gpuStream_t stream) {
const std::shared_ptr<StreamSafeCUDAAllocator>& allocator = const std::shared_ptr<StreamSafeCUDAAllocator>& allocator =
GetDefaultStreamSafeCUDAAllocator(place); GetDefaultStreamSafeCUDAAllocator(place);
// NOTE(Ruibiao): The default stream will be set when the CUDADeviceContext
// created. Normally, the DeviceContextPool is a global singleton and one
// Place only correspond to one DeviceContext. However, to support
// multi-stream scheduling, standalone executor creates two extra
// DeviceContextPools for H2D and D2H stream in StreamAnalyzer, which make
// one Place correspond to multiple DeviceContext and unexpectedly reset the
// default stream in runtime. To avoid this behavior, we do not allow
// changing default stream after initially setting.
if (allocator->GetDefaultStream() != nullptr) {
VLOG(5) << "The default stream for StreamSafeCUDAAllocator("
<< allocator.get() << ") in " << place << " has been set to "
<< allocator->GetDefaultStream()
<< " before, not allow to change now.";
return;
}
allocator->SetDefaultStream(stream); allocator->SetDefaultStream(stream);
VLOG(8) << "Set default stream to " << stream VLOG(8) << "Set default stream to " << stream
<< " for StreamSafeCUDAAllocator(" << allocator.get() << ") in " << " for StreamSafeCUDAAllocator(" << allocator.get() << ") in "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册