未验证 提交 7ffde4bc 编写于 作者: H HongyuJia 提交者: GitHub

simplify FallbackToCpu (#49124)

上级 ff79c144
......@@ -651,8 +651,8 @@ void BuildOpFuncList(const platform::Place& place,
} else {
if (!op_with_kernel->SupportsKernelType(expected_kernel_key,
exec_ctx)) {
auto phi_cpu_kernel_key = FallBackToCpu(
expected_kernel_key, phi_kernel_key, *op_with_kernel);
auto phi_cpu_kernel_key =
FallBackToCpu(phi_kernel_key, *op_with_kernel);
op_with_kernel->ResetPhiKernel(
new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
phi_kernel_name, phi_cpu_kernel_key)));
......
......@@ -1808,8 +1808,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
#endif
) {
fallback_to_cpu = true;
auto phi_cpu_kernel_key =
FallBackToCpu(*kernel_type_.get(), phi_kernel_key, *this);
auto phi_cpu_kernel_key = FallBackToCpu(phi_kernel_key, *this);
phi_kernel_.reset(
new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
phi_kernel_name, phi_cpu_kernel_key)));
......
......@@ -100,58 +100,59 @@ phi::KernelKey TransOpKernelTypeToPhiKernelKey(
framework::TransToPhiDataType(kernel_type.data_type_));
}
phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
const phi::KernelKey& kernel_key,
phi::KernelKey FallBackToCpu(const phi::KernelKey& kernel_key,
const framework::OperatorBase& op) {
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(expected_kernel_key.place_) ||
if (kernel_key.backend() == phi::Backend::XPU ||
paddle::platform::is_in_xpu_black_list(op.Type())) {
VLOG(3) << "phi missing XPU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
<< ", expected_kernel_key:" << kernel_key
<< ", fallback to CPU one!";
return phi::KernelKey(
phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype());
}
#endif
#ifdef PADDLE_WITH_ASCEND_CL
if (platform::is_npu_place(expected_kernel_key.place_)) {
if (kernel_key.backend() == phi::Backend::NPU) {
VLOG(3) << "phi missing NPU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
<< ", expected_kernel_key:" << kernel_key
<< ", fallback to CPU one!";
return phi::KernelKey(
phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype());
}
#endif
#ifdef PADDLE_WITH_MLU
if (platform::is_mlu_place(expected_kernel_key.place_)) {
if (kernel_key.backend() == phi::Backend::MLU) {
VLOG(3) << "phi missing MLU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
<< ", expected_kernel_key:" << kernel_key
<< ", fallback to CPU one!";
return phi::KernelKey(
phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype());
}
#endif
#ifdef PADDLE_WITH_IPU
if (platform::is_ipu_place(expected_kernel_key.place_)) {
if (kernel_key.backend() == phi::Backend::IPU) {
VLOG(3) << "phi missing IPU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
<< ", expected_kernel_key:" << kernel_key
<< ", fallback to CPU one!";
return phi::KernelKey(
phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype());
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (platform::is_custom_place(expected_kernel_key.place_)) {
VLOG(3) << "phi missing " << expected_kernel_key.place_.GetDeviceType()
auto place = phi::TransToPhiPlace(kernel_key.backend());
if (platform::is_custom_place(place)) {
VLOG(3) << "phi missing " << place.GetDeviceType()
<< " kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
<< ", expected_kernel_key:" << kernel_key
<< ", fallback to CPU one!";
return phi::KernelKey(
phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype());
}
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(expected_kernel_key.place_)) {
if (kernel_key.backend() == phi::Backend::GPU ||
kernel_key.backend() == phi::Backend::GPUDNN) {
PADDLE_THROW(platform::errors::Unavailable(
"For GPU kernel, they must not fallback into CPU kernel."));
}
......
......@@ -43,8 +43,7 @@ namespace framework {
OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key);
phi::KernelKey TransOpKernelTypeToPhiKernelKey(const OpKernelType& kernel_type);
phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
const phi::KernelKey& kernel_key,
phi::KernelKey FallBackToCpu(const phi::KernelKey& kernel_key,
const framework::OperatorBase& op);
/* Kernel Args parse */
......
......@@ -393,8 +393,7 @@ PreparedOp PrepareImpl(
#endif
) {
if (has_phi_kernel) {
auto phi_cpu_kernel_key =
FallBackToCpu(expected_kernel_key, phi_kernel_key, op);
auto phi_cpu_kernel_key = FallBackToCpu(phi_kernel_key, op);
auto& phi_cpu_kernel =
phi_kernel_factory.SelectKernel(phi_kernel_name, phi_cpu_kernel_key);
if (phi_cpu_kernel.IsValid()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册