From 9c2a9afd0dd688f99d9ec8d22cafcd3f6ce0bb44 Mon Sep 17 00:00:00 2001 From: Aganlengzi Date: Fri, 1 Apr 2022 11:37:33 +0800 Subject: [PATCH] [custom kernel] support fallback (#41212) --- paddle/fluid/framework/operator.cc | 11 +++++++++++ paddle/fluid/framework/phi_utils.cc | 20 +++++++++++++++----- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index efb334ebbd..83380d1f26 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1600,6 +1600,17 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const { expected_kernel_key.place_ = platform::CPUPlace(); kernel_iter = kernels.find(expected_kernel_key); } +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + if (kernel_iter == kernels.end() && + platform::is_custom_place(expected_kernel_key.place_)) { + VLOG(3) << "missing " << expected_kernel_key.place_.GetDeviceType() + << " kernel: " << type_ + << ", expected_kernel_key:" << expected_kernel_key + << ", fallbacking to CPU one!"; + expected_kernel_key.place_ = platform::CPUPlace(); + kernel_iter = kernels.find(expected_kernel_key); + } #endif PADDLE_ENFORCE_NE(kernel_iter, kernels.end(), platform::errors::NotFound( diff --git a/paddle/fluid/framework/phi_utils.cc b/paddle/fluid/framework/phi_utils.cc index 14997dd961..82c2c33931 100644 --- a/paddle/fluid/framework/phi_utils.cc +++ b/paddle/fluid/framework/phi_utils.cc @@ -102,7 +102,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, if (platform::is_xpu_place(expected_kernel_key.place_) || paddle::platform::is_in_xpu_black_list(op.Type())) { VLOG(3) << "phi missing XPU kernel: " << op.Type() - << "phipected_kernel_key:" << expected_kernel_key + << ", phipected_kernel_key:" << expected_kernel_key << ", fallbacking to CPU one!"; return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); @@ -111,7 +111,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, #ifdef PADDLE_WITH_ASCEND_CL if (platform::is_npu_place(expected_kernel_key.place_)) { VLOG(3) << "phi missing NPU kernel: " << op.Type() - << "phipected_kernel_key:" << expected_kernel_key + << ", phipected_kernel_key:" << expected_kernel_key << ", fallbacking to CPU one!"; return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); @@ -120,7 +120,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, #ifdef PADDLE_WITH_MLU if (platform::is_mlu_place(expected_kernel_key.place_)) { VLOG(3) << "phi missing MLU kernel: " << op.Type() - << "phipected_kernel_key:" << expected_kernel_key + << ", phipected_kernel_key:" << expected_kernel_key << ", fallbacking to CPU one!"; return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); @@ -128,8 +128,18 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, #endif #ifdef PADDLE_WITH_IPU if (platform::is_ipu_place(expected_kernel_key.place_)) { - VLOG(3) << "pten missing IPU kernel: " << op.Type() - << ", expected_kernel_key:" << expected_kernel_key + VLOG(3) << "phi missing IPU kernel: " << op.Type() + << ", phipected_kernel_key:" << expected_kernel_key + << ", fallbacking to CPU one!"; + return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), + kernel_key.dtype()); + } +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + if (platform::is_custom_place(expected_kernel_key.place_)) { + VLOG(3) << "phi missing " << expected_kernel_key.place_.GetDeviceType() + << " kernel: " << op.Type() + << ", phipected_kernel_key:" << expected_kernel_key << ", fallbacking to CPU one!"; return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); -- GitLab