Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
1888d874
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1888d874
编写于
4月 04, 2022
作者:
Z
zyfncg
提交者:
GitHub
4月 04, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add cudnn flag in yaml (#41368)
上级
77cf305f
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
34 addition
and
4 deletion
+34
-4
paddle/phi/core/kernel_factory.cc
paddle/phi/core/kernel_factory.cc
+19
-1
paddle/phi/core/kernel_factory.h
paddle/phi/core/kernel_factory.h
+2
-1
python/paddle/utils/code_gen/api_base.py
python/paddle/utils/code_gen/api_base.py
+9
-2
python/paddle/utils/code_gen/api_gen.py
python/paddle/utils/code_gen/api_gen.py
+2
-0
python/paddle/utils/code_gen/backward_api_gen.py
python/paddle/utils/code_gen/backward_api_gen.py
+2
-0
未找到文件。
paddle/phi/core/kernel_factory.cc
浏览文件 @
1888d874
...
...
@@ -75,13 +75,31 @@ bool KernelFactory::IsSelectKernelValid(const std::string& kernel_name,
}
const
Kernel
&
KernelFactory
::
SelectKernelOrThrowError
(
const
std
::
string
&
kernel_name
,
const
KernelKey
&
kernel_key
)
const
{
const
std
::
string
&
kernel_name
,
const
KernelKey
&
kernel_key
,
bool
use_cudnn
)
const
{
auto
iter
=
kernels_
.
find
(
kernel_name
);
PADDLE_ENFORCE_NE
(
iter
,
kernels_
.
end
(),
phi
::
errors
::
NotFound
(
"The kernel `%s` is not registered."
,
kernel_name
));
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if
(
use_cudnn
&&
kernel_key
.
backend
()
==
Backend
::
GPU
)
{
auto
kernel_iter
=
iter
->
second
.
find
(
{
Backend
::
GPUDNN
,
kernel_key
.
layout
(),
kernel_key
.
dtype
()});
if
(
kernel_iter
==
iter
->
second
.
end
()
&&
kernel_key
.
layout
()
!=
phi
::
DataLayout
::
ALL_LAYOUT
)
{
kernel_iter
=
iter
->
second
.
find
(
{
Backend
::
GPUDNN
,
DataLayout
::
ALL_LAYOUT
,
kernel_key
.
dtype
()});
}
if
(
kernel_iter
!=
iter
->
second
.
end
())
{
return
kernel_iter
->
second
;
}
LOG
(
WARNING
)
<<
"The cudnn kernel for ["
<<
kernel_name
<<
"] is not registered."
;
}
#endif
auto
kernel_iter
=
iter
->
second
.
find
(
kernel_key
);
// TODO(chenweihang): polish refind impl here
if
(
kernel_iter
==
iter
->
second
.
end
()
&&
...
...
paddle/phi/core/kernel_factory.h
浏览文件 @
1888d874
...
...
@@ -238,7 +238,8 @@ class KernelFactory {
}
const
Kernel
&
SelectKernelOrThrowError
(
const
std
::
string
&
kernel_name
,
const
KernelKey
&
kernel_key
)
const
;
const
KernelKey
&
kernel_key
,
bool
use_cudnn
=
false
)
const
;
const
Kernel
&
SelectKernelOrThrowError
(
const
std
::
string
&
kernel_name
,
Backend
backend
,
...
...
python/paddle/utils/code_gen/api_base.py
浏览文件 @
1888d874
...
...
@@ -238,7 +238,8 @@ class BaseAPI(object):
'param'
:
None
,
'backend'
:
None
,
'layout'
:
None
,
'data_type'
:
None
'data_type'
:
None
,
'use_cudnn'
:
'false'
}
if
'backend'
in
kernel_config
and
len
(
kernel_config
[
'backend'
])
>
0
:
kernel
[
'backend'
]
=
kernel_config
[
'backend'
]
...
...
@@ -248,6 +249,10 @@ class BaseAPI(object):
kernel
[
'data_type'
]
=
kernel_config
[
'data_type'
]
if
'param'
in
kernel_config
:
kernel
[
'param'
]
=
kernel_config
[
'param'
]
if
'use_cudnn'
in
kernel_config
:
kernel
[
'use_cudnn'
]
=
kernel_config
[
'use_cudnn'
]
if
isinstance
(
kernel
[
'use_cudnn'
],
bool
):
kernel
[
'use_cudnn'
]
=
str
(
kernel
[
'use_cudnn'
]).
lower
()
kernel
[
'func'
]
=
[
kernel_fn
.
strip
()
for
kernel_fn
in
kernel_config
[
'func'
].
split
(
','
)
]
...
...
@@ -713,10 +718,12 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self
outputs_args
,
kernel_output_names
,
output_create
=
self
.
gene_output
(
self
.
outputs
[
'types'
],
'SetKernelOutput'
,
code_indent
,
inplace_flag
)
api_func_name
=
self
.
get_api_func_name
()
+
(
'_'
if
inplace_flag
else
''
)
cudnn_args
=
''
if
self
.
kernel
[
'use_cudnn'
]
==
'false'
else
', '
+
self
.
kernel
[
'use_cudnn'
]
return
f
"""
{
code_indent
}
VLOG(6) << "
{
self
.
api
}
API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
{
code_indent
}
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
{
code_indent
}
"
{
self
.
kernel
[
'func'
][
0
]
}
", {{kernel_backend, kernel_layout, kernel_data_type}});
{
code_indent
}
"
{
self
.
kernel
[
'func'
][
0
]
}
", {{kernel_backend, kernel_layout, kernel_data_type}}
{
cudnn_args
}
);
{
code_indent
}
VLOG(6) << "
{
self
.
api
}
API kernel: " << kernel;
{
code_indent
}
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
...
...
python/paddle/utils/code_gen/api_gen.py
浏览文件 @
1888d874
...
...
@@ -163,6 +163,8 @@ def source_include(header_file_path):
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
DECLARE_bool(conv2d_disable_cudnn);
"""
...
...
python/paddle/utils/code_gen/backward_api_gen.py
浏览文件 @
1888d874
...
...
@@ -179,6 +179,8 @@ def source_include(header_file_path):
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
DECLARE_bool(conv2d_disable_cudnn);
"""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录