未验证 提交 bc9b6e26 编写于 作者: S shentanyue 提交者: GitHub

[Phi] Fix phi kernel error when kernel fallback cpu (#53879)

上级 ad49e0fb
......@@ -1677,6 +1677,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
RuntimeContext* runtime_ctx) const {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
bool fallback_to_cpu = false;
phi::KernelKey phi_cpu_kernel_key;
auto* dev_ctx = pool.Get(place);
// using cache
if (kernel_type_.get()) {
......@@ -1895,7 +1896,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
if (in_custom_back_list) {
VLOG(3) << "fluid in black list: " << phi_kernel_name;
}
auto phi_cpu_kernel_key = FallBackToCpu(phi_kernel_key, *this);
phi_cpu_kernel_key = FallBackToCpu(phi_kernel_key, *this);
phi_kernel_.reset(
new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
phi_kernel_name, phi_cpu_kernel_key)));
......@@ -1926,12 +1927,20 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
1,
platform::EventRole::kInnerOp);
if (need_prepare_data_) {
transfer_scope =
PrepareData(scope,
framework::TransOpKernelTypeToPhiKernelKey(*kernel_type_),
&transfered_inplace_vars,
runtime_ctx,
dev_ctx->GetPlace());
if (fallback_to_cpu) {
transfer_scope = PrepareData(scope,
phi_cpu_kernel_key,
&transfered_inplace_vars,
runtime_ctx,
dev_ctx->GetPlace());
} else {
transfer_scope = PrepareData(
scope,
framework::TransOpKernelTypeToPhiKernelKey(*kernel_type_),
&transfered_inplace_vars,
runtime_ctx,
dev_ctx->GetPlace());
}
}
}
// exec scope is the scope that kernel actually executed on.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册