未验证 提交 e439d735 编写于 作者: A Aganlengzi 提交者: GitHub

add FLAGS_enable_api_kernel_fallback (#44706)

* add FLAGS_enable_api_kernel_fallback

* deal with more cases

* add ut for coverage
上级 653885a5
...@@ -184,6 +184,19 @@ PADDLE_DEFINE_EXPORTED_string( ...@@ -184,6 +184,19 @@ PADDLE_DEFINE_EXPORTED_string(
"please refer to the documents"); "please refer to the documents");
#endif #endif
/*
 * Kernel related FLAG
 * Name: FLAGS_enable_api_kernel_fallback
 * Since Version: 2.4
 * Value Range: bool, default=true
 * Example: FLAGS_enable_api_kernel_fallback=true allows an API kernel that is
 * not found for the current backend to fall back to the CPU kernel.
 * Note: with FLAGS_enable_api_kernel_fallback=false no CPU fallback is
 * attempted, and a kernel missing for the current backend is reported as a
 * NotFound error by the kernel selection code.
 */
PADDLE_DEFINE_EXPORTED_bool(
enable_api_kernel_fallback,
true,
"Whether enable api kernel fallback to CPU one when not found");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
/** /**
* CUDNN related FLAG * CUDNN related FLAG
......
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
DECLARE_bool(enable_api_kernel_fallback);
namespace phi { namespace phi {
const static Kernel empty_kernel; // NOLINT const static Kernel empty_kernel; // NOLINT
...@@ -120,8 +122,15 @@ KernelResult KernelFactory::SelectKernelOrThrowError( ...@@ -120,8 +122,15 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
kernel_iter = iter->second.find(any_layout_kernel_key); kernel_iter = iter->second.find(any_layout_kernel_key);
} }
bool has_fallback_cpu = false; PADDLE_ENFORCE_NE(
if (kernel_iter == iter->second.end()) { kernel_iter == iter->second.end() && kernel_key.backend() == Backend::CPU,
true,
phi::errors::NotFound(
"The kernel with key %s of kernel `%s` is not registered.",
kernel_key,
kernel_name));
if (FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) {
// Fallback CPU backend // Fallback CPU backend
phi::KernelKey cpu_kernel_key( phi::KernelKey cpu_kernel_key(
phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype());
...@@ -132,18 +141,35 @@ KernelResult KernelFactory::SelectKernelOrThrowError( ...@@ -132,18 +141,35 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
phi::Backend::CPU, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype()); phi::Backend::CPU, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype());
kernel_iter = iter->second.find(any_layout_kernel_key); kernel_iter = iter->second.find(any_layout_kernel_key);
} }
has_fallback_cpu = true;
PADDLE_ENFORCE_NE(
kernel_iter,
iter->second.end(),
phi::errors::NotFound(
"The kernel with key %s of kernel `%s` is not registered and"
" fail to fallback to CPU one.",
kernel_key,
kernel_name));
VLOG(3) << "missing " << kernel_key.backend() << " kernel: " << kernel_name
<< ", expected_kernel_key:" << kernel_key
<< ", fallbacking to CPU one!";
return {kernel_iter->second, true};
} }
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
kernel_iter, kernel_iter,
iter->second.end(), iter->second.end(),
phi::errors::NotFound( phi::errors::NotFound(
"The kernel with key %s of kernel `%s` is not registered.", "The kernel with key %s of kernel `%s` is not registered and"
" the current value of FLAGS_enable_api_kernel_fallback(bool,"
" default true) is false. If you want to fallback this kernel"
" to CPU one, please set the flag true before run again.",
kernel_key, kernel_key,
kernel_name)); kernel_name));
return {kernel_iter->second, has_fallback_cpu}; return {kernel_iter->second, false};
} }
const KernelArgsDef& KernelFactory::GetFirstKernelArgsDef( const KernelArgsDef& KernelFactory::GetFirstKernelArgsDef(
......
...@@ -40,6 +40,7 @@ class TestCustomCPUPlugin(unittest.TestCase): ...@@ -40,6 +40,7 @@ class TestCustomCPUPlugin(unittest.TestCase):
self._test_custom_device_mnist() self._test_custom_device_mnist()
self._test_eager_backward_api() self._test_eager_backward_api()
self._test_eager_copy_to() self._test_eager_copy_to()
self._test_fallback_kernel()
self._test_custom_device_dataloader() self._test_custom_device_dataloader()
self._test_custom_device_mnist() self._test_custom_device_mnist()
...@@ -160,6 +161,15 @@ class TestCustomCPUPlugin(unittest.TestCase): ...@@ -160,6 +161,15 @@ class TestCustomCPUPlugin(unittest.TestCase):
self.assertTrue(np.array_equal(another_custom_cpu_tensor, x)) self.assertTrue(np.array_equal(another_custom_cpu_tensor, x))
self.assertTrue(another_custom_cpu_tensor.place.is_custom_place()) self.assertTrue(another_custom_cpu_tensor.place.is_custom_place())
def _test_fallback_kernel(self):
    """Add two int16 tensors and check the element-wise sum.

    The (custom_cpu, add, int16) kernel is not registered, so this call is
    meant to exercise the kernel fallback path (presumably the CPU kernel
    is selected instead -- confirm against the kernel factory).
    """
    import paddle

    # Operands use int16 on purpose: that dtype has no custom_cpu `add` kernel.
    lhs = paddle.to_tensor([5, 4, 3], 'int16')
    rhs = paddle.to_tensor([1, 2, 3], 'int16')
    expected = np.array([6, 6, 6], 'int16')

    total = paddle.add(lhs, rhs)
    matches = np.array_equal(total, expected)
    self.assertTrue(matches)
def tearDown(self): def tearDown(self):
del os.environ['CUSTOM_DEVICE_ROOT'] del os.environ['CUSTOM_DEVICE_ROOT']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册