未验证 提交 8f8a02fd 编写于 作者: L LielinJiang 提交者: GitHub

Optimize conv performance (#28766)

* optimize conv performance
上级 00e55ded
......@@ -95,6 +95,68 @@ def _update_padding_nd(padding, channel_last, num_dims):
return padding, padding_algorithm
def _conv_nd(x,
weight,
bias=None,
stride=1,
padding=0,
padding_algorithm=None,
dilation=1,
groups=1,
data_format="NCHW",
channel_dim=1,
op_type="conv2d",
use_cudnn=True,
use_mkldnn=False,
name=None):
if in_dygraph_mode():
attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation,
'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn',
use_mkldnn, 'fuse_relu_before_depthwise_conv', False,
"padding_algorithm", padding_algorithm, "data_format",
data_format)
pre_bias = getattr(core.ops, op_type)(x, weight, *attrs)
if bias is not None:
out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
else:
out = pre_bias
else:
inputs = {'Input': [x], 'Filter': [weight]}
attrs = {
'strides': stride,
'paddings': padding,
'dilations': dilation,
'groups': groups,
'use_cudnn': use_cudnn,
'use_mkldnn': use_mkldnn,
'fuse_relu_before_depthwise_conv': False,
"padding_algorithm": padding_algorithm,
"data_format": data_format
}
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
op_type)
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype(input_param_name='x')
pre_bias = helper.create_variable_for_type_inference(dtype)
outputs = {"Output": [pre_bias]}
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
if bias is not None:
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='elementwise_add',
inputs={'X': [pre_bias],
'Y': [bias]},
outputs={'Out': [out]},
attrs={'axis': channel_dim,
'use_mkldnn': use_mkldnn})
else:
out = pre_bias
return out
def conv1d(x,
weight,
bias=None,
......@@ -472,12 +534,13 @@ def conv2d(x,
"received: the number of filters is {}, the shape of weight is {}"
", the groups is {}".format(num_filters, weight.shape, groups))
# use_cudnn = True if core.is_compiled_with_cuda() else False
cudnn_version = get_cudnn_version()
use_cudnn = True if (core.is_compiled_with_cuda() and
cudnn_version is not None) else False
use_mkldnn = core.globals()["FLAGS_use_mkldnn"]
# update attrs
padding, padding_algorithm = _update_padding_nd(padding, channel_last, 2)
stride = utils.convert_to_list(stride, 2, 'stride')
......@@ -489,56 +552,9 @@ def conv2d(x,
l_type = 'depthwise_conv2d'
use_cudnn = False
inputs = {'Input': [x], 'Filter': [weight]}
attrs = {
'strides': stride,
'paddings': padding,
'dilations': dilation,
'groups': groups,
'use_cudnn': use_cudnn,
'use_mkldnn': False,
'fuse_relu_before_depthwise_conv': False,
"padding_algorithm": padding_algorithm,
"data_format": data_format
}
if in_dygraph_mode():
attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation,
'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn', False,
'fuse_relu_before_depthwise_conv', False, "padding_algorithm",
padding_algorithm, "data_format", data_format)
pre_bias = getattr(core.ops, l_type)(x, weight, *attrs)
if bias is not None:
out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
else:
out = pre_bias
else:
inputs = {'Input': [x], 'Filter': [weight]}
attrs = {
'strides': stride,
'paddings': padding,
'dilations': dilation,
'groups': groups,
'use_cudnn': use_cudnn,
'use_mkldnn': False,
'fuse_relu_before_depthwise_conv': False,
"padding_algorithm": padding_algorithm,
"data_format": data_format
}
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'conv2d')
helper = LayerHelper(l_type, **locals())
dtype = helper.input_dtype(input_param_name='x')
pre_bias = helper.create_variable_for_type_inference(dtype)
outputs = {"Output": [pre_bias]}
helper.append_op(
type=l_type, inputs=inputs, outputs=outputs, attrs=attrs)
if bias is not None:
out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
else:
out = pre_bias
return out
return _conv_nd(x, weight, bias, stride, padding, padding_algorithm,
dilation, groups, data_format, channel_dim, l_type,
use_cudnn, use_mkldnn, name)
def conv1d_transpose(x,
......@@ -1201,44 +1217,9 @@ def conv3d(x,
dilation = utils.convert_to_list(dilation, 3, 'dilation')
op_type = "conv3d"
if in_dygraph_mode():
attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation,
'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn', False,
"padding_algorithm", padding_algorithm, "data_format",
data_format)
pre_bias = getattr(core.ops, op_type)(x, weight, *attrs)
if bias is not None:
out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
else:
out = pre_bias
else:
inputs = {'Input': [x], 'Filter': [weight]}
attrs = {
'strides': stride,
'paddings': padding,
'dilations': dilation,
'groups': groups,
'use_cudnn': use_cudnn,
'use_mkldnn': False,
"padding_algorithm": padding_algorithm,
"data_format": data_format
}
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype(input_param_name='x')
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'conv3d')
pre_bias = helper.create_variable_for_type_inference(dtype)
outputs = {"Output": [pre_bias]}
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
if bias is not None:
out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
else:
out = pre_bias
return out
return _conv_nd(x, weight, bias, stride, padding, padding_algorithm,
dilation, groups, data_format, channel_dim, op_type,
use_cudnn, False, name)
def conv3d_transpose(x,
......
......@@ -25,6 +25,8 @@ __all__ = [
import numpy as np
from ...fluid import core
from ...device import get_cudnn_version
from ...fluid.dygraph import layers
from ...fluid.initializer import Normal
from .. import functional as F
......@@ -83,6 +85,13 @@ class _ConvNd(layers.Layer):
"when padding_mode in ['reflect', 'replicate', 'circular'], type of padding must be int"
)
channel_last = (data_format == "NHWC") or (data_format == "NDHWC") or (
data_format == "NLC")
if channel_last:
self._channel_dim = len(data_format) - 1
else:
self._channel_dim = 1
self._stride = utils.convert_to_list(stride, dims, 'stride')
self._dilation = utils.convert_to_list(dilation, dims, 'dilation')
self._kernel_size = utils.convert_to_list(kernel_size, dims,
......@@ -90,10 +99,15 @@ class _ConvNd(layers.Layer):
self._padding = padding
self._padding_mode = padding_mode
self.output_padding = output_padding
if dims != 1:
self._padding, self._padding_algorithm = _update_padding_nd(
padding, channel_last, dims)
if transposed:
filter_shape = [self._in_channels, out_channels // groups
] + self._kernel_size
self._padding, self._padding_algorithm = _update_padding_nd(
padding, channel_last, dims)
else:
if in_channels % groups != 0:
raise ValueError("in_channels must be divisible by groups.")
......@@ -104,6 +118,8 @@ class _ConvNd(layers.Layer):
self._reversed_padding_repeated_twice = _reverse_repeat_list(
_paired_padding, 2)
self._padding, _ = _update_padding_nd(0, channel_last, dims)
filter_shape = [out_channels, in_channels // groups
] + self._kernel_size
......@@ -112,6 +128,17 @@ class _ConvNd(layers.Layer):
self.bias = self.create_parameter(
attr=self._bias_attr, shape=[self._out_channels], is_bias=True)
cudnn_version = get_cudnn_version()
self._use_cudnn = True if (core.is_compiled_with_cuda() and
cudnn_version is not None) else False
self._op_type = "conv" + str(dims) + 'd'
if dims == 2 and (in_channels == groups and in_channels != 1 and
out_channels % in_channels == 0):
self.op_type = 'depthwise_conv2d'
self._use_cudnn = False
class Conv1D(_ConvNd):
"""
......@@ -581,24 +608,20 @@ class Conv2D(_ConvNd):
self._reversed_padding_repeated_twice,
mode=self._padding_mode,
data_format=self._data_format)
return F.conv2d(
x,
self.weight,
bias=self.bias,
stride=self._stride,
dilation=self._dilation,
groups=self._groups,
data_format=self._data_format)
out = F.conv2d(
out = F.conv._conv_nd(
x,
self.weight,
bias=self.bias,
padding=self._padding,
stride=self._stride,
padding=self._padding,
padding_algorithm=self._padding_algorithm,
dilation=self._dilation,
groups=self._groups,
data_format=self._data_format)
data_format=self._data_format,
channel_dim=self._channel_dim,
op_type=self._op_type,
use_cudnn=self._use_cudnn)
return out
......@@ -902,24 +925,20 @@ class Conv3D(_ConvNd):
self._reversed_padding_repeated_twice,
mode=self._padding_mode,
data_format=self._data_format)
return F.conv3d(
x,
self.weight,
bias=self.bias,
stride=self._stride,
dilation=self._dilation,
groups=self._groups,
data_format=self._data_format)
out = F.conv3d(
out = F.conv._conv_nd(
x,
self.weight,
bias=self.bias,
padding=self._padding,
stride=self._stride,
padding=self._padding,
padding_algorithm=self._padding_algorithm,
dilation=self._dilation,
groups=self._groups,
data_format=self._data_format)
data_format=self._data_format,
channel_dim=self._channel_dim,
op_type=self._op_type,
use_cudnn=self._use_cudnn)
return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册