未验证 提交 e439d735 编写于 作者: A Aganlengzi 提交者: GitHub

add FLAGS_enable_api_kernel_fallback (#44706)

* add FLAGS_enable_api_kernel_fallback

* deal with more cases

* add ut for coverage
上级 653885a5
......@@ -184,6 +184,19 @@ PADDLE_DEFINE_EXPORTED_string(
"please refer to the documents");
#endif
/*
* Kernel related FLAG
* Name: FLAGS_enable_api_kernel_fallback
* Since Version: 2.4
* Value Range: bool, default=true
* Example: FLAGS_enable_api_kernel_fallback=true would allow kernel of current
* backend fallback to CPU one when not found
*/
PADDLE_DEFINE_EXPORTED_bool(
enable_api_kernel_fallback,
true,
"Whether enable api kernel fallback to CPU one when not found");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
/**
* CUDNN related FLAG
......
......@@ -17,6 +17,8 @@
#include "glog/logging.h"
#include "paddle/phi/core/enforce.h"
DECLARE_bool(enable_api_kernel_fallback);
namespace phi {
const static Kernel empty_kernel; // NOLINT
......@@ -120,8 +122,15 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
kernel_iter = iter->second.find(any_layout_kernel_key);
}
bool has_fallback_cpu = false;
if (kernel_iter == iter->second.end()) {
PADDLE_ENFORCE_NE(
kernel_iter == iter->second.end() && kernel_key.backend() == Backend::CPU,
true,
phi::errors::NotFound(
"The kernel with key %s of kernel `%s` is not registered.",
kernel_key,
kernel_name));
if (FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) {
// Fallback CPU backend
phi::KernelKey cpu_kernel_key(
phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype());
......@@ -132,18 +141,35 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
phi::Backend::CPU, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype());
kernel_iter = iter->second.find(any_layout_kernel_key);
}
has_fallback_cpu = true;
PADDLE_ENFORCE_NE(
kernel_iter,
iter->second.end(),
phi::errors::NotFound(
"The kernel with key %s of kernel `%s` is not registered and"
" fail to fallback to CPU one.",
kernel_key,
kernel_name));
VLOG(3) << "missing " << kernel_key.backend() << " kernel: " << kernel_name
<< ", expected_kernel_key:" << kernel_key
<< ", fallbacking to CPU one!";
return {kernel_iter->second, true};
}
PADDLE_ENFORCE_NE(
kernel_iter,
iter->second.end(),
phi::errors::NotFound(
"The kernel with key %s of kernel `%s` is not registered.",
"The kernel with key %s of kernel `%s` is not registered and"
" the current value of FLAGS_enable_api_kernel_fallback(bool,"
" default true) is false. If you want to fallback this kernel"
" to CPU one, please set the flag true before run again.",
kernel_key,
kernel_name));
return {kernel_iter->second, has_fallback_cpu};
return {kernel_iter->second, false};
}
const KernelArgsDef& KernelFactory::GetFirstKernelArgsDef(
......
......@@ -40,6 +40,7 @@ class TestCustomCPUPlugin(unittest.TestCase):
self._test_custom_device_mnist()
self._test_eager_backward_api()
self._test_eager_copy_to()
self._test_fallback_kernel()
self._test_custom_device_dataloader()
self._test_custom_device_mnist()
......@@ -160,6 +161,15 @@ class TestCustomCPUPlugin(unittest.TestCase):
self.assertTrue(np.array_equal(another_custom_cpu_tensor, x))
self.assertTrue(another_custom_cpu_tensor.place.is_custom_place())
def _test_fallback_kernel(self):
# using (custom_cpu, add, int16) which is not registered
import paddle
r = np.array([6, 6, 6], 'int16')
x = paddle.to_tensor([5, 4, 3], 'int16')
y = paddle.to_tensor([1, 2, 3], 'int16')
z = paddle.add(x, y)
self.assertTrue(np.array_equal(z, r))
def tearDown(self):
del os.environ['CUSTOM_DEVICE_ROOT']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册