diff --git a/paddle/phi/core/custom_kernel.cc b/paddle/phi/core/custom_kernel.cc index 356dd1482c9d8330e97a0395e3ead25edf5f75a8..75c651c5402c7cf39ac9d08c643b7b68d58b9df5 100644 --- a/paddle/phi/core/custom_kernel.cc +++ b/paddle/phi/core/custom_kernel.cc @@ -16,6 +16,9 @@ #include "glog/logging.h" +static std::vector gpu_exclusive_kernels({"sync_batch_norm", + "sync_batch_norm_grad"}); + namespace phi { void CustomKernelMap::RegisterCustomKernel(const std::string& name, @@ -41,12 +44,15 @@ void CustomKernelMap::RegisterCustomKernels() { } auto& kernels = KernelFactory::Instance().kernels(); for (auto& pair : kernels_) { - PADDLE_ENFORCE_NE( - kernels.find(pair.first), - kernels.end(), - phi::errors::InvalidArgument( + if (kernels.find(pair.first) == kernels.cend()) { + if (std::find(gpu_exclusive_kernels.cbegin(), + gpu_exclusive_kernels.cend(), + pair.first) == gpu_exclusive_kernels.cend()) { + PADDLE_THROW(phi::errors::InvalidArgument( "The kernel %s is not ready for custom kernel registering.", pair.first)); + } + } for (auto& info_pair : pair.second) { PADDLE_ENFORCE_EQ(