From 1888d874b2cc62e10adc0d22b60cdce48f90fd65 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 4 Apr 2022 21:53:36 +0800 Subject: [PATCH] add cudnn flag in yaml (#41368) --- paddle/phi/core/kernel_factory.cc | 20 ++++++++++++++++++- paddle/phi/core/kernel_factory.h | 3 ++- python/paddle/utils/code_gen/api_base.py | 11 ++++++++-- python/paddle/utils/code_gen/api_gen.py | 2 ++ .../paddle/utils/code_gen/backward_api_gen.py | 2 ++ 5 files changed, 34 insertions(+), 4 deletions(-) diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index 81c43764fe..a1ce90c2c7 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -75,13 +75,31 @@ bool KernelFactory::IsSelectKernelValid(const std::string& kernel_name, } const Kernel& KernelFactory::SelectKernelOrThrowError( - const std::string& kernel_name, const KernelKey& kernel_key) const { + const std::string& kernel_name, + const KernelKey& kernel_key, + bool use_cudnn) const { auto iter = kernels_.find(kernel_name); PADDLE_ENFORCE_NE( iter, kernels_.end(), phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name)); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (use_cudnn && kernel_key.backend() == Backend::GPU) { + auto kernel_iter = iter->second.find( + {Backend::GPUDNN, kernel_key.layout(), kernel_key.dtype()}); + if (kernel_iter == iter->second.end() && + kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) { + kernel_iter = iter->second.find( + {Backend::GPUDNN, DataLayout::ALL_LAYOUT, kernel_key.dtype()}); + } + if (kernel_iter != iter->second.end()) { + return kernel_iter->second; + } + LOG(WARNING) << "The cudnn kernel for [" << kernel_name + << "] is not registered."; + } +#endif auto kernel_iter = iter->second.find(kernel_key); // TODO(chenweihang): polish refind impl here if (kernel_iter == iter->second.end() && diff --git a/paddle/phi/core/kernel_factory.h b/paddle/phi/core/kernel_factory.h index 6c098c75a0..8fd25b691b 100644 --- a/paddle/phi/core/kernel_factory.h +++ b/paddle/phi/core/kernel_factory.h @@ -238,7 +238,8 @@ class KernelFactory { } const Kernel& SelectKernelOrThrowError(const std::string& kernel_name, - const KernelKey& kernel_key) const; + const KernelKey& kernel_key, + bool use_cudnn = false) const; const Kernel& SelectKernelOrThrowError(const std::string& kernel_name, Backend backend, diff --git a/python/paddle/utils/code_gen/api_base.py b/python/paddle/utils/code_gen/api_base.py index c1a987d06b..c51e2b0acd 100644 --- a/python/paddle/utils/code_gen/api_base.py +++ b/python/paddle/utils/code_gen/api_base.py @@ -238,7 +238,8 @@ class BaseAPI(object): 'param': None, 'backend': None, 'layout': None, - 'data_type': None + 'data_type': None, + 'use_cudnn': 'false' } if 'backend' in kernel_config and len(kernel_config['backend']) > 0: kernel['backend'] = kernel_config['backend'] @@ -248,6 +249,10 @@ class BaseAPI(object): kernel['data_type'] = kernel_config['data_type'] if 'param' in kernel_config: kernel['param'] = kernel_config['param'] + if 'use_cudnn' in kernel_config: + kernel['use_cudnn'] = kernel_config['use_cudnn'] + if isinstance(kernel['use_cudnn'], bool): + kernel['use_cudnn'] = str(kernel['use_cudnn']).lower() kernel['func'] = [ kernel_fn.strip() for kernel_fn in kernel_config['func'].split(',') ] @@ -713,10 +718,12 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self outputs_args, kernel_output_names, output_create = self.gene_output( self.outputs['types'], 'SetKernelOutput', code_indent, inplace_flag) api_func_name = self.get_api_func_name() + ('_' if inplace_flag else '') + cudnn_args = '' if self.kernel[ + 'use_cudnn'] == 'false' else ', ' + self.kernel['use_cudnn'] return f""" {code_indent} VLOG(6) << "{self.api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]"; {code_indent} const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( -{code_indent} "{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}}); +{code_indent} "{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}}{cudnn_args}); {code_indent} VLOG(6) << "{self.api} API kernel: " << kernel; {code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); diff --git a/python/paddle/utils/code_gen/api_gen.py b/python/paddle/utils/code_gen/api_gen.py index f95edf6c59..4087b55b51 100644 --- a/python/paddle/utils/code_gen/api_gen.py +++ b/python/paddle/utils/code_gen/api_gen.py @@ -163,6 +163,8 @@ def source_include(header_file_path): #include "paddle/phi/infermeta/ternary.h" #include "paddle/fluid/platform/profiler/event_tracing.h" + +DECLARE_bool(conv2d_disable_cudnn); """ diff --git a/python/paddle/utils/code_gen/backward_api_gen.py b/python/paddle/utils/code_gen/backward_api_gen.py index e26f653878..970ac02247 100644 --- a/python/paddle/utils/code_gen/backward_api_gen.py +++ b/python/paddle/utils/code_gen/backward_api_gen.py @@ -179,6 +179,8 @@ def source_include(header_file_path): #include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/platform/profiler/event_tracing.h" + +DECLARE_bool(conv2d_disable_cudnn); """ -- GitLab