未验证 提交 1888d874 编写于 作者: Z zyfncg 提交者: GitHub

add cudnn flag in yaml (#41368)

上级 77cf305f
......@@ -75,13 +75,31 @@ bool KernelFactory::IsSelectKernelValid(const std::string& kernel_name,
}
const Kernel& KernelFactory::SelectKernelOrThrowError(
const std::string& kernel_name, const KernelKey& kernel_key) const {
const std::string& kernel_name,
const KernelKey& kernel_key,
bool use_cudnn) const {
auto iter = kernels_.find(kernel_name);
PADDLE_ENFORCE_NE(
iter,
kernels_.end(),
phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (use_cudnn && kernel_key.backend() == Backend::GPU) {
auto kernel_iter = iter->second.find(
{Backend::GPUDNN, kernel_key.layout(), kernel_key.dtype()});
if (kernel_iter == iter->second.end() &&
kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) {
kernel_iter = iter->second.find(
{Backend::GPUDNN, DataLayout::ALL_LAYOUT, kernel_key.dtype()});
}
if (kernel_iter != iter->second.end()) {
return kernel_iter->second;
}
LOG(WARNING) << "The cudnn kernel for [" << kernel_name
<< "] is not registered.";
}
#endif
auto kernel_iter = iter->second.find(kernel_key);
// TODO(chenweihang): polish refind impl here
if (kernel_iter == iter->second.end() &&
......
......@@ -238,7 +238,8 @@ class KernelFactory {
}
const Kernel& SelectKernelOrThrowError(const std::string& kernel_name,
const KernelKey& kernel_key) const;
const KernelKey& kernel_key,
bool use_cudnn = false) const;
const Kernel& SelectKernelOrThrowError(const std::string& kernel_name,
Backend backend,
......
......@@ -238,7 +238,8 @@ class BaseAPI(object):
'param': None,
'backend': None,
'layout': None,
'data_type': None
'data_type': None,
'use_cudnn': 'false'
}
if 'backend' in kernel_config and len(kernel_config['backend']) > 0:
kernel['backend'] = kernel_config['backend']
......@@ -248,6 +249,10 @@ class BaseAPI(object):
kernel['data_type'] = kernel_config['data_type']
if 'param' in kernel_config:
kernel['param'] = kernel_config['param']
if 'use_cudnn' in kernel_config:
kernel['use_cudnn'] = kernel_config['use_cudnn']
if isinstance(kernel['use_cudnn'], bool):
kernel['use_cudnn'] = str(kernel['use_cudnn']).lower()
kernel['func'] = [
kernel_fn.strip() for kernel_fn in kernel_config['func'].split(',')
]
......@@ -713,10 +718,12 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self
outputs_args, kernel_output_names, output_create = self.gene_output(
self.outputs['types'], 'SetKernelOutput', code_indent, inplace_flag)
api_func_name = self.get_api_func_name() + ('_' if inplace_flag else '')
cudnn_args = '' if self.kernel[
'use_cudnn'] == 'false' else ', ' + self.kernel['use_cudnn']
return f"""
{code_indent} VLOG(6) << "{self.api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
{code_indent} const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
{code_indent} "{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}});
{code_indent} "{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}}{cudnn_args});
{code_indent} VLOG(6) << "{self.api} API kernel: " << kernel;
{code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
......
......@@ -163,6 +163,8 @@ def source_include(header_file_path):
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
DECLARE_bool(conv2d_disable_cudnn);
"""
......
......@@ -179,6 +179,8 @@ def source_include(header_file_path):
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
DECLARE_bool(conv2d_disable_cudnn);
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册