diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 00a499350044903c0c987240ae4cedda5c385bc2..a03d75e3fe79af53a6c4fb4b114ecf72a575e29f 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1447,6 +1447,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, const platform::Place& place, RuntimeContext* runtime_ctx) const { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + bool fallback_to_cpu = false; auto* dev_ctx = pool.Get(place); #ifdef PADDLE_WITH_ASCEND_CL @@ -1637,6 +1638,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, || (is_xpu_unsupport && !is_xpu_kp_support) #endif ) { + fallback_to_cpu = true; auto phi_cpu_kernel_key = FallBackToCpu(*kernel_type_.get(), phi_kernel_key, *this); phi_kernel_.reset( @@ -1720,6 +1722,9 @@ void OperatorWithKernel::RunImpl(const Scope& scope, (*kernel_func_)( ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx)); } + if (fallback_to_cpu) { + phi_kernel_.release(); + } } if (!transfered_inplace_vars.empty()) {