未验证 提交 4b0f1b0c 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] fix the not ready kernel can not register. (#47758)

上级 e5eb3f55
...@@ -16,6 +16,9 @@ ...@@ -16,6 +16,9 @@
#include "glog/logging.h" #include "glog/logging.h"
static std::vector<std::string> gpu_exclusive_kernels({"sync_batch_norm",
"sync_batch_norm_grad"});
namespace phi { namespace phi {
void CustomKernelMap::RegisterCustomKernel(const std::string& name, void CustomKernelMap::RegisterCustomKernel(const std::string& name,
...@@ -41,12 +44,15 @@ void CustomKernelMap::RegisterCustomKernels() { ...@@ -41,12 +44,15 @@ void CustomKernelMap::RegisterCustomKernels() {
} }
auto& kernels = KernelFactory::Instance().kernels(); auto& kernels = KernelFactory::Instance().kernels();
for (auto& pair : kernels_) { for (auto& pair : kernels_) {
PADDLE_ENFORCE_NE( if (kernels.find(pair.first) == kernels.cend()) {
kernels.find(pair.first), if (std::find(gpu_exclusive_kernels.cbegin(),
kernels.end(), gpu_exclusive_kernels.cend(),
phi::errors::InvalidArgument( pair.first) == gpu_exclusive_kernels.cend()) {
PADDLE_THROW(phi::errors::InvalidArgument(
"The kernel %s is not ready for custom kernel registering.", "The kernel %s is not ready for custom kernel registering.",
pair.first)); pair.first));
}
}
for (auto& info_pair : pair.second) { for (auto& info_pair : pair.second) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册