未验证 提交 6c737c67 编写于 作者: R ronnywang 提交者: GitHub

fix cached kernel bug when fallback to cpu (#45676)

上级 6813f41e
...@@ -1447,6 +1447,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1447,6 +1447,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place, const platform::Place& place,
RuntimeContext* runtime_ctx) const { RuntimeContext* runtime_ctx) const {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
bool fallback_to_cpu = false;
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
#ifdef PADDLE_WITH_ASCEND_CL #ifdef PADDLE_WITH_ASCEND_CL
...@@ -1637,6 +1638,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1637,6 +1638,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|| (is_xpu_unsupport && !is_xpu_kp_support) || (is_xpu_unsupport && !is_xpu_kp_support)
#endif #endif
) { ) {
fallback_to_cpu = true;
auto phi_cpu_kernel_key = auto phi_cpu_kernel_key =
FallBackToCpu(*kernel_type_.get(), phi_kernel_key, *this); FallBackToCpu(*kernel_type_.get(), phi_kernel_key, *this);
phi_kernel_.reset( phi_kernel_.reset(
...@@ -1720,6 +1722,9 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1720,6 +1722,9 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
(*kernel_func_)( (*kernel_func_)(
ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx)); ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
} }
if (fallback_to_cpu) {
phi_kernel_.release();
}
} }
if (!transfered_inplace_vars.empty()) { if (!transfered_inplace_vars.empty()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册