未验证 提交 123949eb 编写于 作者: R ronnywang 提交者: GitHub

[ROCM] added a cudnn switch of conv2d for rocm platform (#31836)

上级 61805d8f
......@@ -564,3 +564,15 @@ DEFINE_string(tracer_mkldnn_ops_on, "",
*/
DEFINE_string(tracer_mkldnn_ops_off, "",
"List of OneDNN operation types to be turned off");
/**
* CUDNN related FLAG
* Name: conv2d_disable_cudnn
* Since Version:
* Value Range: bool, default=false
* Example:
* Note: Disable cudnn in conv2d.
*/
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
DEFINE_bool(conv2d_disable_cudnn, false, "Disable cudnn in conv2d");
#endif
......@@ -72,6 +72,7 @@ DECLARE_uint64(conv_workspace_size_limit);
DECLARE_bool(cudnn_batchnorm_spatial_persistent);
DECLARE_bool(cudnn_deterministic);
DECLARE_bool(cudnn_exhaustive_search);
DECLARE_bool(conv2d_disable_cudnn);
// data processing
DECLARE_bool(enable_cublas_tensor_op_math);
// device management
......@@ -367,7 +368,8 @@ static void RegisterGlobalVarGetterSetter() {
FLAGS_fraction_of_cuda_pinned_memory_to_use,
FLAGS_fraction_of_gpu_memory_to_use, FLAGS_initial_gpu_memory_in_mb,
FLAGS_reallocate_gpu_memory_in_mb, FLAGS_enable_cublas_tensor_op_math,
FLAGS_selected_gpus, FLAGS_sync_nccl_allreduce);
FLAGS_selected_gpus, FLAGS_sync_nccl_allreduce,
FLAGS_conv2d_disable_cudnn);
#endif
#ifdef PADDLE_WITH_XPU
REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_selected_xpus);
......
......@@ -230,6 +230,7 @@ def __bootstrap__():
'gpu_allocator_retry_time',
'local_exe_sub_scope_limit',
'gpu_memory_limit_mb',
'conv2d_disable_cudnn',
]
core.init_gflags(["--tryfromenv=" + ",".join(read_env_flags)])
core.init_glog(sys.argv[0])
......
......@@ -1603,6 +1603,10 @@ def conv2d(input,
pre_bias = helper.create_variable_for_type_inference(dtype)
if (core.is_compiled_with_cuda() and paddle.fluid.get_flags(
"FLAGS_conv2d_disable_cudnn")["FLAGS_conv2d_disable_cudnn"]):
use_cudnn = False
helper.append_op(
type=l_type,
inputs={
......
......@@ -1465,5 +1465,41 @@ class TestConv2DAPI_Error(unittest.TestCase):
self.assertRaises(ValueError, run_7)
# --------- test environment variable ------
@unittest.skipIf(
not (core.is_compiled_with_cuda() or core.is_compiled_with_rocm()),
"core is not compiled with CUDA or ROCM")
class TestConv2DEnviron(unittest.TestCase):
def run_conv2d_api(self):
inputs = fluid.layers.data(
shape=[2, 3, 5, 5],
append_batch_size=False,
name="inputs",
dtype="float32")
fluid.layers.conv2d(
input=inputs,
num_filters=4,
filter_size=[3, 3],
stride=[1, 1],
padding=0,
dilation=[1, 1],
groups=1,
data_format="NCHW")
x_var = paddle.uniform((2, 3, 5, 5), dtype="float32", min=-1., max=1.)
conv = paddle.nn.Conv2D(
in_channels=3,
out_channels=4,
kernel_size=(3, 3),
data_format="NCHW")
y_var = conv(x_var)
def test_environ(self):
fluid.set_flags({'FLAGS_conv2d_disable_cudnn': False})
self.run_conv2d_api()
fluid.set_flags({'FLAGS_conv2d_disable_cudnn': True})
self.run_conv2d_api()
if __name__ == '__main__':
unittest.main()
......@@ -25,6 +25,7 @@ __all__ = [
import numpy as np
from ...fluid import get_flags
from ...fluid import core
from ...device import get_cudnn_version
from ...fluid.dygraph import layers
......@@ -644,6 +645,10 @@ class Conv2D(_ConvNd):
bias_attr=bias_attr,
data_format=data_format)
if (core.is_compiled_with_cuda() and get_flags(
"FLAGS_conv2d_disable_cudnn")["FLAGS_conv2d_disable_cudnn"]):
self._use_cudnn = False
def forward(self, x):
if self._padding_mode != 'zeros':
x = F.pad(x,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册