From e439d735638ca89ad0b614a5061185aaa7718cd3 Mon Sep 17 00:00:00 2001 From: Aganlengzi Date: Fri, 29 Jul 2022 15:40:52 +0800 Subject: [PATCH] add FLAGS_enable_api_kernel_fallback (#44706) * add FLAGS_enable_api_kernel_fallback * deal with more cases * add ut for coverage --- paddle/fluid/platform/flags.cc | 13 +++++++ paddle/phi/core/kernel_factory.cc | 36 ++++++++++++++++--- .../custom_runtime/test_custom_cpu_plugin.py | 10 ++++++ 3 files changed, 54 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index d5a9381735..b3b16356b9 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -184,6 +184,19 @@ PADDLE_DEFINE_EXPORTED_string( "please refer to the documents"); #endif +/* + * Kernel related FLAG + * Name: FLAGS_enable_api_kernel_fallback + * Since Version: 2.4 + * Value Range: bool, default=true + * Example: FLAGS_enable_api_kernel_fallback=true would allow kernel of current + * backend fallback to CPU one when not found + */ +PADDLE_DEFINE_EXPORTED_bool( + enable_api_kernel_fallback, + true, + "Whether enable api kernel fallback to CPU one when not found"); + #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) /** * CUDNN related FLAG diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index 3bee07f8a3..6e16029ee4 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -17,6 +17,8 @@ #include "glog/logging.h" #include "paddle/phi/core/enforce.h" +DECLARE_bool(enable_api_kernel_fallback); + namespace phi { const static Kernel empty_kernel; // NOLINT @@ -120,8 +122,15 @@ KernelResult KernelFactory::SelectKernelOrThrowError( kernel_iter = iter->second.find(any_layout_kernel_key); } - bool has_fallback_cpu = false; - if (kernel_iter == iter->second.end()) { + PADDLE_ENFORCE_NE( + kernel_iter == iter->second.end() && kernel_key.backend() == Backend::CPU, + true, + phi::errors::NotFound( + "The kernel with key %s of kernel `%s` is not registered.", + kernel_key, + kernel_name)); + + if (FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) { // Fallback CPU backend phi::KernelKey cpu_kernel_key( phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); @@ -132,18 +141,35 @@ KernelResult KernelFactory::SelectKernelOrThrowError( phi::Backend::CPU, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype()); kernel_iter = iter->second.find(any_layout_kernel_key); } - has_fallback_cpu = true; + + PADDLE_ENFORCE_NE( + kernel_iter, + iter->second.end(), + phi::errors::NotFound( + "The kernel with key %s of kernel `%s` is not registered and" + " fail to fallback to CPU one.", + kernel_key, + kernel_name)); + + VLOG(3) << "missing " << kernel_key.backend() << " kernel: " << kernel_name + << ", expected_kernel_key:" << kernel_key + << ", fallbacking to CPU one!"; + + return {kernel_iter->second, true}; } PADDLE_ENFORCE_NE( kernel_iter, iter->second.end(), phi::errors::NotFound( - "The kernel with key %s of kernel `%s` is not registered.", + "The kernel with key %s of kernel `%s` is not registered and" + " the current value of FLAGS_enable_api_kernel_fallback(bool," + " default true) is false. If you want to fallback this kernel" + " to CPU one, please set the flag true before run again.", kernel_key, kernel_name)); - return {kernel_iter->second, has_fallback_cpu}; + return {kernel_iter->second, false}; } const KernelArgsDef& KernelFactory::GetFirstKernelArgsDef( diff --git a/python/paddle/fluid/tests/custom_runtime/test_custom_cpu_plugin.py b/python/paddle/fluid/tests/custom_runtime/test_custom_cpu_plugin.py index 02dabf899c..dcdb7d2d12 100644 --- a/python/paddle/fluid/tests/custom_runtime/test_custom_cpu_plugin.py +++ b/python/paddle/fluid/tests/custom_runtime/test_custom_cpu_plugin.py @@ -40,6 +40,7 @@ class TestCustomCPUPlugin(unittest.TestCase): self._test_custom_device_mnist() self._test_eager_backward_api() self._test_eager_copy_to() + self._test_fallback_kernel() self._test_custom_device_dataloader() self._test_custom_device_mnist() @@ -160,6 +161,15 @@ class TestCustomCPUPlugin(unittest.TestCase): self.assertTrue(np.array_equal(another_custom_cpu_tensor, x)) self.assertTrue(another_custom_cpu_tensor.place.is_custom_place()) + def _test_fallback_kernel(self): + # using (custom_cpu, add, int16) which is not registered + import paddle + r = np.array([6, 6, 6], 'int16') + x = paddle.to_tensor([5, 4, 3], 'int16') + y = paddle.to_tensor([1, 2, 3], 'int16') + z = paddle.add(x, y) + self.assertTrue(np.array_equal(z, r)) + def tearDown(self): del os.environ['CUSTOM_DEVICE_ROOT'] -- GitLab