未验证 提交 9c2a9afd 编写于 作者: A Aganlengzi 提交者: GitHub

[custom kernel] support fallback (#41212)

上级 0b0c2768
...@@ -1600,6 +1600,17 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const { ...@@ -1600,6 +1600,17 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
expected_kernel_key.place_ = platform::CPUPlace(); expected_kernel_key.place_ = platform::CPUPlace();
kernel_iter = kernels.find(expected_kernel_key); kernel_iter = kernels.find(expected_kernel_key);
} }
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (kernel_iter == kernels.end() &&
platform::is_custom_place(expected_kernel_key.place_)) {
VLOG(3) << "missing " << expected_kernel_key.place_.GetDeviceType()
<< " kernel: " << type_
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
expected_kernel_key.place_ = platform::CPUPlace();
kernel_iter = kernels.find(expected_kernel_key);
}
#endif #endif
PADDLE_ENFORCE_NE(kernel_iter, kernels.end(), PADDLE_ENFORCE_NE(kernel_iter, kernels.end(),
platform::errors::NotFound( platform::errors::NotFound(
......
...@@ -102,7 +102,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, ...@@ -102,7 +102,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
if (platform::is_xpu_place(expected_kernel_key.place_) || if (platform::is_xpu_place(expected_kernel_key.place_) ||
paddle::platform::is_in_xpu_black_list(op.Type())) { paddle::platform::is_in_xpu_black_list(op.Type())) {
VLOG(3) << "phi missing XPU kernel: " << op.Type() VLOG(3) << "phi missing XPU kernel: " << op.Type()
<< "phipected_kernel_key:" << expected_kernel_key << ", phipected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!"; << ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype()); kernel_key.dtype());
...@@ -111,7 +111,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, ...@@ -111,7 +111,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
#ifdef PADDLE_WITH_ASCEND_CL #ifdef PADDLE_WITH_ASCEND_CL
if (platform::is_npu_place(expected_kernel_key.place_)) { if (platform::is_npu_place(expected_kernel_key.place_)) {
VLOG(3) << "phi missing NPU kernel: " << op.Type() VLOG(3) << "phi missing NPU kernel: " << op.Type()
<< "phipected_kernel_key:" << expected_kernel_key << ", phipected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!"; << ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype()); kernel_key.dtype());
...@@ -120,7 +120,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, ...@@ -120,7 +120,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
#ifdef PADDLE_WITH_MLU #ifdef PADDLE_WITH_MLU
if (platform::is_mlu_place(expected_kernel_key.place_)) { if (platform::is_mlu_place(expected_kernel_key.place_)) {
VLOG(3) << "phi missing MLU kernel: " << op.Type() VLOG(3) << "phi missing MLU kernel: " << op.Type()
<< "phipected_kernel_key:" << expected_kernel_key << ", phipected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!"; << ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype()); kernel_key.dtype());
...@@ -128,8 +128,18 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, ...@@ -128,8 +128,18 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
#endif #endif
#ifdef PADDLE_WITH_IPU #ifdef PADDLE_WITH_IPU
if (platform::is_ipu_place(expected_kernel_key.place_)) { if (platform::is_ipu_place(expected_kernel_key.place_)) {
VLOG(3) << "pten missing IPU kernel: " << op.Type() VLOG(3) << "phi missing IPU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key << ", phipected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype());
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (platform::is_custom_place(expected_kernel_key.place_)) {
VLOG(3) << "phi missing " << expected_kernel_key.place_.GetDeviceType()
<< " kernel: " << op.Type()
<< ", phipected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!"; << ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype()); kernel_key.dtype());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册