Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
4fd5ed43
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4fd5ed43
编写于
3月 29, 2021
作者:
R
ronnywang
提交者:
GitHub
3月 29, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[ROCM] added a cudnn switch of conv2d for rocm platform (#31836) (#31932)
上级
9b40cb87
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
61 addition
and
1 deletion
+61
-1
paddle/fluid/platform/flags.cc
paddle/fluid/platform/flags.cc
+12
-0
paddle/fluid/pybind/global_value_getter_setter.cc
paddle/fluid/pybind/global_value_getter_setter.cc
+3
-1
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+1
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+4
-0
python/paddle/fluid/tests/unittests/test_conv2d_op.py
python/paddle/fluid/tests/unittests/test_conv2d_op.py
+36
-0
python/paddle/nn/layer/conv.py
python/paddle/nn/layer/conv.py
+5
-0
未找到文件。
paddle/fluid/platform/flags.cc
浏览文件 @
4fd5ed43
...
...
@@ -564,3 +564,15 @@ DEFINE_string(tracer_mkldnn_ops_on, "",
*/
DEFINE_string
(
tracer_mkldnn_ops_off
,
""
,
"List of OneDNN operation types to be turned off"
);
/**
* CUDNN related FLAG
* Name: conv2d_disable_cudnn
* Since Version:
* Value Range: bool, default=false
* Example:
* Note: Disable cudnn in conv2d.
*/
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
DEFINE_bool
(
conv2d_disable_cudnn
,
false
,
"Disable cudnn in conv2d"
);
#endif
paddle/fluid/pybind/global_value_getter_setter.cc
浏览文件 @
4fd5ed43
...
...
@@ -72,6 +72,7 @@ DECLARE_uint64(conv_workspace_size_limit);
DECLARE_bool
(
cudnn_batchnorm_spatial_persistent
);
DECLARE_bool
(
cudnn_deterministic
);
DECLARE_bool
(
cudnn_exhaustive_search
);
DECLARE_bool
(
conv2d_disable_cudnn
);
// data processing
DECLARE_bool
(
enable_cublas_tensor_op_math
);
// device management
...
...
@@ -367,7 +368,8 @@ static void RegisterGlobalVarGetterSetter() {
FLAGS_fraction_of_cuda_pinned_memory_to_use
,
FLAGS_fraction_of_gpu_memory_to_use
,
FLAGS_initial_gpu_memory_in_mb
,
FLAGS_reallocate_gpu_memory_in_mb
,
FLAGS_enable_cublas_tensor_op_math
,
FLAGS_selected_gpus
,
FLAGS_sync_nccl_allreduce
);
FLAGS_selected_gpus
,
FLAGS_sync_nccl_allreduce
,
FLAGS_conv2d_disable_cudnn
);
#endif
#ifdef PADDLE_WITH_XPU
REGISTER_PUBLIC_GLOBAL_VAR
(
FLAGS_selected_xpus
);
...
...
python/paddle/fluid/__init__.py
浏览文件 @
4fd5ed43
...
...
@@ -230,6 +230,7 @@ def __bootstrap__():
'gpu_allocator_retry_time'
,
'local_exe_sub_scope_limit'
,
'gpu_memory_limit_mb'
,
'conv2d_disable_cudnn'
,
]
core
.
init_gflags
([
"--tryfromenv="
+
","
.
join
(
read_env_flags
)])
core
.
init_glog
(
sys
.
argv
[
0
])
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
4fd5ed43
...
...
@@ -1603,6 +1603,10 @@ def conv2d(input,
pre_bias = helper.create_variable_for_type_inference(dtype)
if (core.is_compiled_with_cuda() and paddle.fluid.get_flags(
"FLAGS_conv2d_disable_cudnn")["FLAGS_conv2d_disable_cudnn"]):
use_cudnn = False
helper.append_op(
type=l_type,
inputs={
...
...
python/paddle/fluid/tests/unittests/test_conv2d_op.py
浏览文件 @
4fd5ed43
...
...
@@ -1465,5 +1465,41 @@ class TestConv2DAPI_Error(unittest.TestCase):
self
.
assertRaises
(
ValueError
,
run_7
)
# --------- test environment variable ------
@
unittest
.
skipIf
(
not
(
core
.
is_compiled_with_cuda
()
or
core
.
is_compiled_with_rocm
()),
"core is not compiled with CUDA or ROCM"
)
class
TestConv2DEnviron
(
unittest
.
TestCase
):
def
run_conv2d_api
(
self
):
inputs
=
fluid
.
layers
.
data
(
shape
=
[
2
,
3
,
5
,
5
],
append_batch_size
=
False
,
name
=
"inputs"
,
dtype
=
"float32"
)
fluid
.
layers
.
conv2d
(
input
=
inputs
,
num_filters
=
4
,
filter_size
=
[
3
,
3
],
stride
=
[
1
,
1
],
padding
=
0
,
dilation
=
[
1
,
1
],
groups
=
1
,
data_format
=
"NCHW"
)
x_var
=
paddle
.
uniform
((
2
,
3
,
5
,
5
),
dtype
=
"float32"
,
min
=-
1.
,
max
=
1.
)
conv
=
paddle
.
nn
.
Conv2D
(
in_channels
=
3
,
out_channels
=
4
,
kernel_size
=
(
3
,
3
),
data_format
=
"NCHW"
)
y_var
=
conv
(
x_var
)
def
test_environ
(
self
):
fluid
.
set_flags
({
'FLAGS_conv2d_disable_cudnn'
:
False
})
self
.
run_conv2d_api
()
fluid
.
set_flags
({
'FLAGS_conv2d_disable_cudnn'
:
True
})
self
.
run_conv2d_api
()
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/nn/layer/conv.py
浏览文件 @
4fd5ed43
...
...
@@ -25,6 +25,7 @@ __all__ = [
import
numpy
as
np
from
...fluid
import
get_flags
from
...fluid
import
core
from
...device
import
get_cudnn_version
from
...fluid.dygraph
import
layers
...
...
@@ -644,6 +645,10 @@ class Conv2D(_ConvNd):
bias_attr
=
bias_attr
,
data_format
=
data_format
)
if
(
core
.
is_compiled_with_cuda
()
and
get_flags
(
"FLAGS_conv2d_disable_cudnn"
)[
"FLAGS_conv2d_disable_cudnn"
]):
self
.
_use_cudnn
=
False
def
forward
(
self
,
x
):
if
self
.
_padding_mode
!=
'zeros'
:
x
=
F
.
pad
(
x
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录