diff --git a/python/paddle/nn/functional/conv.py b/python/paddle/nn/functional/conv.py index 3c1482e69c3c36232ee5d70f2156a8d16c2d212a..5cf4953933242292c6a732513dbee2164811dd35 100644 --- a/python/paddle/nn/functional/conv.py +++ b/python/paddle/nn/functional/conv.py @@ -267,8 +267,8 @@ def conv1d(x, dilation = utils.convert_to_list(dilation, 1, 'dilation') + [1] l_type = "conv2d" - if (num_channels == groups and num_filters % num_channels == 0 and - not use_cudnn): + if (num_channels == groups and num_channels != 1 and + num_filters % num_channels == 0 and not use_cudnn): l_type = 'depthwise_conv2d' use_cudnn = False @@ -491,7 +491,8 @@ def conv2d(x, dilation = utils.convert_to_list(dilation, 2, 'dilation') l_type = "conv2d" - if (num_channels == groups and num_filters % num_channels == 0): + if (num_channels == groups and num_channels != 1 and + num_filters % num_channels == 0): l_type = 'depthwise_conv2d' use_cudnn = False @@ -761,7 +762,8 @@ def conv_transpose1d(x, op_type = 'conv2d_transpose' num_filters = weight.shape[1] - if (num_channels == groups and num_filters == 1 and not use_cudnn): + if (num_channels == groups and num_channels != 1 and num_filters == 1 and + not use_cudnn): op_type = 'depthwise_conv2d_transpose' use_cudnn = False @@ -1010,7 +1012,7 @@ def conv_transpose2d(x, op_type = 'conv2d_transpose' num_filters = weight.shape[1] - if (num_channels == groups and num_filters == 1): + if (num_channels == groups and num_channels != 1 and num_filters == 1): op_type = 'depthwise_conv2d_transpose' use_cudnn = False