diff --git a/python/paddle/nn/functional/conv.py b/python/paddle/nn/functional/conv.py index 5dfd66feda509f4597dbc7a4f7a665caa9a25f27..c4410346ca17d1bed626cf176403dcb43b9668c5 100644 --- a/python/paddle/nn/functional/conv.py +++ b/python/paddle/nn/functional/conv.py @@ -95,6 +95,68 @@ def _update_padding_nd(padding, channel_last, num_dims): return padding, padding_algorithm +def _conv_nd(x, + weight, + bias=None, + stride=1, + padding=0, + padding_algorithm=None, + dilation=1, + groups=1, + data_format="NCHW", + channel_dim=1, + op_type="conv2d", + use_cudnn=True, + use_mkldnn=False, + name=None): + + if in_dygraph_mode(): + attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation, + 'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn', + use_mkldnn, 'fuse_relu_before_depthwise_conv', False, + "padding_algorithm", padding_algorithm, "data_format", + data_format) + pre_bias = getattr(core.ops, op_type)(x, weight, *attrs) + if bias is not None: + out = nn.elementwise_add(pre_bias, bias, axis=channel_dim) + else: + out = pre_bias + else: + inputs = {'Input': [x], 'Filter': [weight]} + attrs = { + 'strides': stride, + 'paddings': padding, + 'dilations': dilation, + 'groups': groups, + 'use_cudnn': use_cudnn, + 'use_mkldnn': use_mkldnn, + 'fuse_relu_before_depthwise_conv': False, + "padding_algorithm": padding_algorithm, + "data_format": data_format + } + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + op_type) + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name='x') + pre_bias = helper.create_variable_for_type_inference(dtype) + outputs = {"Output": [pre_bias]} + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) + if bias is not None: + out = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type='elementwise_add', + inputs={'X': [pre_bias], + 'Y': [bias]}, + outputs={'Out': [out]}, + attrs={'axis': channel_dim, + 'use_mkldnn': use_mkldnn}) + else: + out = pre_bias + + return out + + def conv1d(x, weight, bias=None, @@ -472,12 +534,13 @@ def conv2d(x, "received: the number of filters is {}, the shape of weight is {}" ", the groups is {}".format(num_filters, weight.shape, groups)) - # use_cudnn = True if core.is_compiled_with_cuda() else False cudnn_version = get_cudnn_version() use_cudnn = True if (core.is_compiled_with_cuda() and cudnn_version is not None) else False + use_mkldnn = core.globals()["FLAGS_use_mkldnn"] + # update attrs padding, padding_algorithm = _update_padding_nd(padding, channel_last, 2) stride = utils.convert_to_list(stride, 2, 'stride') @@ -489,56 +552,9 @@ def conv2d(x, l_type = 'depthwise_conv2d' use_cudnn = False - inputs = {'Input': [x], 'Filter': [weight]} - attrs = { - 'strides': stride, - 'paddings': padding, - 'dilations': dilation, - 'groups': groups, - 'use_cudnn': use_cudnn, - 'use_mkldnn': False, - 'fuse_relu_before_depthwise_conv': False, - "padding_algorithm": padding_algorithm, - "data_format": data_format - } - - if in_dygraph_mode(): - attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation, - 'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn', False, - 'fuse_relu_before_depthwise_conv', False, "padding_algorithm", - padding_algorithm, "data_format", data_format) - pre_bias = getattr(core.ops, l_type)(x, weight, *attrs) - if bias is not None: - out = nn.elementwise_add(pre_bias, bias, axis=channel_dim) - else: - out = pre_bias - else: - inputs = {'Input': [x], 'Filter': [weight]} - attrs = { - 'strides': stride, - 'paddings': padding, - 'dilations': dilation, - 'groups': groups, - 'use_cudnn': use_cudnn, - 'use_mkldnn': False, - 'fuse_relu_before_depthwise_conv': False, - "padding_algorithm": padding_algorithm, - "data_format": data_format - } - check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], - 'conv2d') - helper = LayerHelper(l_type, **locals()) - dtype = helper.input_dtype(input_param_name='x') - pre_bias = helper.create_variable_for_type_inference(dtype) - outputs = {"Output": [pre_bias]} - helper.append_op( - type=l_type, inputs=inputs, outputs=outputs, attrs=attrs) - if bias is not None: - out = nn.elementwise_add(pre_bias, bias, axis=channel_dim) - else: - out = pre_bias - - return out + return _conv_nd(x, weight, bias, stride, padding, padding_algorithm, + dilation, groups, data_format, channel_dim, l_type, + use_cudnn, use_mkldnn, name) def conv1d_transpose(x, @@ -1201,44 +1217,9 @@ def conv3d(x, dilation = utils.convert_to_list(dilation, 3, 'dilation') op_type = "conv3d" - if in_dygraph_mode(): - attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation, - 'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn', False, - "padding_algorithm", padding_algorithm, "data_format", - data_format) - pre_bias = getattr(core.ops, op_type)(x, weight, *attrs) - if bias is not None: - out = nn.elementwise_add(pre_bias, bias, axis=channel_dim) - else: - out = pre_bias - else: - inputs = {'Input': [x], 'Filter': [weight]} - attrs = { - 'strides': stride, - 'paddings': padding, - 'dilations': dilation, - 'groups': groups, - 'use_cudnn': use_cudnn, - 'use_mkldnn': False, - "padding_algorithm": padding_algorithm, - "data_format": data_format - } - helper = LayerHelper(op_type, **locals()) - dtype = helper.input_dtype(input_param_name='x') - check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], - 'conv3d') - - pre_bias = helper.create_variable_for_type_inference(dtype) - outputs = {"Output": [pre_bias]} - - helper.append_op( - type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) - if bias is not None: - out = nn.elementwise_add(pre_bias, bias, axis=channel_dim) - else: - out = pre_bias - - return out + return _conv_nd(x, weight, bias, stride, padding, padding_algorithm, + dilation, groups, data_format, channel_dim, op_type, + use_cudnn, False, name) def conv3d_transpose(x, diff --git a/python/paddle/nn/layer/conv.py b/python/paddle/nn/layer/conv.py index f97e549464738234b6f6bce529557750deb9fc2c..0b0d0e302b841c525d2059daa4ad45cb609159ca 100644 --- a/python/paddle/nn/layer/conv.py +++ b/python/paddle/nn/layer/conv.py @@ -25,6 +25,8 @@ __all__ = [ import numpy as np +from ...fluid import core +from ...device import get_cudnn_version from ...fluid.dygraph import layers from ...fluid.initializer import Normal from .. import functional as F @@ -83,6 +85,13 @@ class _ConvNd(layers.Layer): "when padding_mode in ['reflect', 'replicate', 'circular'], type of padding must be int" ) + channel_last = (data_format == "NHWC") or (data_format == "NDHWC") or ( + data_format == "NLC") + if channel_last: + self._channel_dim = len(data_format) - 1 + else: + self._channel_dim = 1 + self._stride = utils.convert_to_list(stride, dims, 'stride') self._dilation = utils.convert_to_list(dilation, dims, 'dilation') self._kernel_size = utils.convert_to_list(kernel_size, dims, @@ -90,10 +99,15 @@ class _ConvNd(layers.Layer): self._padding = padding self._padding_mode = padding_mode self.output_padding = output_padding + if dims != 1: + self._padding, self._padding_algorithm = _update_padding_nd( + padding, channel_last, dims) if transposed: filter_shape = [self._in_channels, out_channels // groups ] + self._kernel_size + self._padding, self._padding_algorithm = _update_padding_nd( + padding, channel_last, dims) else: if in_channels % groups != 0: raise ValueError("in_channels must be divisible by groups.") @@ -104,6 +118,8 @@ class _ConvNd(layers.Layer): self._reversed_padding_repeated_twice = _reverse_repeat_list( _paired_padding, 2) + self._padding, _ = _update_padding_nd(0, channel_last, dims) + filter_shape = [out_channels, in_channels // groups ] + self._kernel_size @@ -112,6 +128,17 @@ class _ConvNd(layers.Layer): self.bias = self.create_parameter( attr=self._bias_attr, shape=[self._out_channels], is_bias=True) + cudnn_version = get_cudnn_version() + + self._use_cudnn = True if (core.is_compiled_with_cuda() and + cudnn_version is not None) else False + + self._op_type = "conv" + str(dims) + 'd' + if dims == 2 and (in_channels == groups and in_channels != 1 and + out_channels % in_channels == 0): + self.op_type = 'depthwise_conv2d' + self._use_cudnn = False + class Conv1D(_ConvNd): """ @@ -581,24 +608,20 @@ class Conv2D(_ConvNd): self._reversed_padding_repeated_twice, mode=self._padding_mode, data_format=self._data_format) - return F.conv2d( - x, - self.weight, - bias=self.bias, - stride=self._stride, - dilation=self._dilation, - groups=self._groups, - data_format=self._data_format) - - out = F.conv2d( + + out = F.conv._conv_nd( x, self.weight, bias=self.bias, - padding=self._padding, stride=self._stride, + padding=self._padding, + padding_algorithm=self._padding_algorithm, dilation=self._dilation, groups=self._groups, - data_format=self._data_format) + data_format=self._data_format, + channel_dim=self._channel_dim, + op_type=self._op_type, + use_cudnn=self._use_cudnn) return out @@ -902,24 +925,20 @@ class Conv3D(_ConvNd): self._reversed_padding_repeated_twice, mode=self._padding_mode, data_format=self._data_format) - return F.conv3d( - x, - self.weight, - bias=self.bias, - stride=self._stride, - dilation=self._dilation, - groups=self._groups, - data_format=self._data_format) - - out = F.conv3d( + + out = F.conv._conv_nd( x, self.weight, bias=self.bias, - padding=self._padding, stride=self._stride, + padding=self._padding, + padding_algorithm=self._padding_algorithm, dilation=self._dilation, groups=self._groups, - data_format=self._data_format) + data_format=self._data_format, + channel_dim=self._channel_dim, + op_type=self._op_type, + use_cudnn=self._use_cudnn) return out