未验证 提交 bc9b6e26 编写于 作者: S shentanyue 提交者: GitHub

[Phi] Fix phi kernel error when kernel fallback cpu (#53879)

上级 ad49e0fb
...@@ -1677,6 +1677,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1677,6 +1677,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
RuntimeContext* runtime_ctx) const { RuntimeContext* runtime_ctx) const {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
bool fallback_to_cpu = false; bool fallback_to_cpu = false;
phi::KernelKey phi_cpu_kernel_key;
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
// using cache // using cache
if (kernel_type_.get()) { if (kernel_type_.get()) {
...@@ -1895,7 +1896,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1895,7 +1896,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
if (in_custom_back_list) { if (in_custom_back_list) {
VLOG(3) << "fluid in black list: " << phi_kernel_name; VLOG(3) << "fluid in black list: " << phi_kernel_name;
} }
auto phi_cpu_kernel_key = FallBackToCpu(phi_kernel_key, *this); phi_cpu_kernel_key = FallBackToCpu(phi_kernel_key, *this);
phi_kernel_.reset( phi_kernel_.reset(
new phi::Kernel(phi::KernelFactory::Instance().SelectKernel( new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
phi_kernel_name, phi_cpu_kernel_key))); phi_kernel_name, phi_cpu_kernel_key)));
...@@ -1926,14 +1927,22 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1926,14 +1927,22 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
1, 1,
platform::EventRole::kInnerOp); platform::EventRole::kInnerOp);
if (need_prepare_data_) { if (need_prepare_data_) {
transfer_scope = if (fallback_to_cpu) {
PrepareData(scope, transfer_scope = PrepareData(scope,
phi_cpu_kernel_key,
&transfered_inplace_vars,
runtime_ctx,
dev_ctx->GetPlace());
} else {
transfer_scope = PrepareData(
scope,
framework::TransOpKernelTypeToPhiKernelKey(*kernel_type_), framework::TransOpKernelTypeToPhiKernelKey(*kernel_type_),
&transfered_inplace_vars, &transfered_inplace_vars,
runtime_ctx, runtime_ctx,
dev_ctx->GetPlace()); dev_ctx->GetPlace());
} }
} }
}
// exec scope is the scope that kernel actually executed on. // exec scope is the scope that kernel actually executed on.
const Scope& exec_scope = const Scope& exec_scope =
(transfer_scope == nullptr ? scope : *transfer_scope); (transfer_scope == nullptr ? scope : *transfer_scope);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册