未验证 提交 09f7796c 编写于 作者: D Double_V 提交者: GitHub

fix fconv in paddle 1.8 (#4705)

上级 5c244bf1
......@@ -3,6 +3,7 @@ from paddle.fluid import dygraph
from paddle.fluid.dygraph import nn
from pytracking.libs.Fconv2d import Conv2D
from pytracking.libs.Fconv2d import FConv2D
class SiamFCEstimator(dygraph.layers.Layer):
......@@ -40,8 +41,7 @@ class SiamFCEstimator(dygraph.layers.Layer):
instance = fluid.layers.reshape(
instance, shape=[1, -1, shape[2], shape[3]])
cross_conv = Conv2D(stride=1, padding=0, dilation=1, groups=shape[0])
score_map = cross_conv(instance, exemplar)
score_map = FConv2D(instance, exemplar, stride=1, padding=0, dilation=1, groups=shape[0])
score_map = fluid.layers.transpose(score_map, [1, 0, 2, 3])
score_map = self.adjust_conv(score_map)
return score_map
......@@ -6,7 +6,7 @@ from paddle.fluid import layers
import cv2 as cv
from pytracking.features.preprocessing import numpy_to_paddle, paddle_to_numpy
from pytracking.libs.Fconv2d import Fconv2d
from pytracking.libs.Fconv2d import FConv2D
from pytracking.libs.paddle_utils import PTensor, _padding, n2p
......@@ -192,13 +192,13 @@ class Blur(Transform):
if isinstance(image, PTensor):
sz = image.shape[2:]
filter = [n2p(f) for f in self.filter_np]
im1 = Fconv2d(
im1 = FConv2D(
layers.reshape(image, [-1, 1, sz[0], sz[1]]),
filter[0],
padding=(self.filter_size[0], 0))
return self.crop_to_output(
layers.reshape(
Fconv2d(
FConv2D(
im1, filter[1], padding=(0, self.filter_size[1])),
[1, -1, sz[0], sz[1]]))
else:
......
from __future__ import print_function
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.layer_object_helper import LayerObjectHelper
from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.framework import Variable, OpProtoHolder, in_dygraph_mode
from paddle.fluid.layers import utils
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
from paddle.fluid.dygraph import dygraph_utils
from paddle.fluid.framework import Variable, OpProtoHolder, in_dygraph_mode
from paddle.fluid.layers import utils
from paddle.fluid.framework import Variable, in_dygraph_mode
from paddle.fluid import core, dygraph_utils
from paddle.fluid.layers import nn, utils
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
def Fconv2d(
input,
filter,
stride=1,
padding=0,
dilation=1,
groups=1,
use_cudnn=True, ):
"""
Similar with conv2d, this is a convolution2D layers. Difference
is filter can be token as input directly instead of setting filter size
and number of fliters. Filter is a 4-D tensor with shape
[num_filter, num_channel, filter_size_h, filter_size_w].
Args:
input (Variable): The input image with [N, C, H, W] format.
filter(Variable): The input filter with [out_channels, in_channels, H, W] format.
stride (int|tuple): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. Default: stride = 1.
padding (int|tuple): The padding size. If padding is a tuple, it must
contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding. Default: padding = 0.
dilation (int|tuple): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. Default: dilation = 1.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv2d.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True
act (str): Activation type, if it is set to None, activation is not appended.
Default: None
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None
Returns:
Variable: The tensor variable storing the convolution and \
non-linearity activation result.
Raises:
ValueError: If the shapes of input, filter_size, stride, padding and
groups mismatch.
Examples:
.. code-block:: python
data = fluid.data(name='data', shape=[None, 3, 32, 32], \
dtype='float32')
filter = fluid.data(name='filter',shape=[10,3,3,3], \
dtype='float32')
conv2d = fluid.layers.conv2d(input=data,
filter=filter,
act="relu")
"""
conv_with_filter = Conv2D(
stride=stride, padding=padding, dilation=dilation, groups=groups)
return conv_with_filter(input, filter)
def _is_list_or_tuple(input):
return isinstance(input, (list, tuple))
class Conv2D(fluid.dygraph.layers.Layer):
"""
This interface is used to construct a callable object of the ``Conv2D`` class.
For more details, refer to code examples.
The convolution2D layer calculates the output based on the input, filter
and strides, paddings, dilations, groups parameters. Input and
Output are in NCHW format, where N is batch size, C is the number of
the feature map, H is the height of the feature map, and W is the width of the feature map.
Filter's shape is [MCHW] , where M is the number of output feature map,
C is the number of input feature map, H is the height of the filter,
and W is the width of the filter. If the groups is greater than 1,
C will equal the number of input feature map divided by the groups.
Please refer to UFLDL's `convolution
<http://ufldl.stanford.edu/tutorial/supervised/FeatureExtractionUsingConvolution/>`_
for more detials.
If bias attribution and activation type are provided, bias is added to the
output of the convolution, and the corresponding activation function is
applied to the final result.
For each input :math:`X`, the equation is:
.. math::
Out = \\sigma (W \\ast X + b)
Where:
* :math:`X`: Input value, a ``Tensor`` with NCHW format.
* :math:`W`: Filter value, a ``Tensor`` with shape [MCHW] .
* :math:`\\ast`: Convolution operation.
* :math:`b`: Bias value, a 2-D ``Tensor`` with shape [M, 1].
* :math:`\\sigma`: Activation function.
* :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
Example:
- Input:
Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)`
- Output:
Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
Where
.. math::
H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1 \\\\
W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
Parameters:
num_channels(int): The number of channels in the input image.
num_filters(int): The number of filter. It is as same as the output
feature map.
filter_size (int or tuple): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square.
stride (int or tuple, optional): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. Default: 1.
padding (int or tuple, optional): The padding size. If padding is a tuple, it must
contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding. Default: 0.
dilation (int or tuple, optional): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. Default: 1.
groups (int, optional): The groups number of the Conv2d Layer. According to grouped
convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
the first half of the filters is only connected to the first half
of the input channels, while the second half of the filters is only
connected to the second half of the input channels. Default: 1.
param_attr (ParamAttr, optional): The parameter attribute for learnable weights(Parameter)
of conv2d. If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with :math:`Normal(0.0, std)`,
and the :math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None.
bias_attr (ParamAttr or bool, optional): The attribute for the bias of conv2d.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
use_cudnn (bool, optional): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True.
act (str, optional): Activation type, if it is set to None, activation is not appended.
Default: None.
dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32".
Attribute:
**weight** (Parameter): the learnable weights of filter of this layer.
**bias** (Parameter or None): the learnable bias of this layer.
Returns:
None
def _zero_padding_in_batch_and_channel(padding, channel_last):
if channel_last:
return list(padding[0]) == [0, 0] and list(padding[-1]) == [0, 0]
else:
return list(padding[0]) == [0, 0] and list(padding[1]) == [0, 0]
Raises:
ValueError: if ``use_cudnn`` is not a bool value.
Examples:
.. code-block:: python
from paddle.fluid.dygraph.base import to_variable
import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D
import numpy as np
data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
with fluid.dygraph.guard():
conv2d = Conv2D(3, 2, 3)
data = to_variable(data)
conv = conv2d(data)
"""
def __init__(self,
stride=1,
padding=0,
dilation=1,
groups=None,
use_cudnn=True,
act=None,
dtype='float32'):
super(Conv2D, self).__init__()
self._groups = groups
self._stride = utils.convert_to_list(stride, 2, 'stride')
self._padding = utils.convert_to_list(padding, 2, 'padding')
self._dilation = utils.convert_to_list(dilation, 2, 'dilation')
self._act = act
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
self._use_cudnn = use_cudnn
self._dtype = dtype
def _exclude_padding_in_batch_and_channel(padding, channel_last):
padding_ = padding[1:-1] if channel_last else padding[2:]
padding_ = [elem for pad_a_dim in padding_ for elem in pad_a_dim]
return padding_
# TODO: recover the usage of depthwise_conv2d when it's
# kernel fixed https://github.com/PaddlePaddle/Paddle/issues/17098
# if (self._num_channels == self._groups and
# num_filters % self._num_channels == 0 and not self._use_cudnn):
# self._l_type = 'depthwise_conv2d'
# else:
# self._l_type = 'conv2d'
self._l_type = 'conv2d'
def forward(self, input, weight, bias=None):
inputs = {
'Input': [input],
'Filter': [weight],
}
def _update_padding_nd(padding, channel_last, num_dims):
if isinstance(padding, str):
padding = padding.upper()
if padding not in ["SAME", "VALID"]:
raise ValueError(
"Unknown padding: '{}'. It can only be 'SAME' or 'VALID'.".
format(padding))
if padding == "VALID":
padding_algorithm = "VALID"
padding = [0] * num_dims
else:
padding_algorithm = "SAME"
padding = [0] * num_dims
elif _is_list_or_tuple(padding):
# for padding like
# [(pad_before, pad_after), (pad_before, pad_after), ...]
# padding for batch_dim and channel_dim included
if len(padding) == 2 + num_dims and _is_list_or_tuple(padding[0]):
if not _zero_padding_in_batch_and_channel(padding, channel_last):
raise ValueError(
"Non-zero padding({}) in the batch or channel dimensions "
"is not supported.".format(padding))
padding_algorithm = "EXPLICIT"
padding = _exclude_padding_in_batch_and_channel(padding,
channel_last)
if utils._is_symmetric_padding(padding, num_dims):
padding = padding[0::2]
# for padding like [pad_before, pad_after, pad_before, pad_after, ...]
elif len(padding) == 2 * num_dims and isinstance(padding[0], int):
padding_algorithm = "EXPLICIT"
padding = utils.convert_to_list(padding, 2 * num_dims, 'padding')
if utils._is_symmetric_padding(padding, num_dims):
padding = padding[0::2]
# for padding like [pad_d1, pad_d2, ...]
elif len(padding) == num_dims and isinstance(padding[0], int):
padding_algorithm = "EXPLICIT"
padding = utils.convert_to_list(padding, num_dims, 'padding')
else:
raise ValueError("In valid padding: {}".format(padding))
# for integer padding
else:
padding_algorithm = "EXPLICIT"
padding = utils.convert_to_list(padding, num_dims, 'padding')
return padding, padding_algorithm
def FConv2D(input,
weight,
bias=None,
padding=0,
stride=1,
dilation=1,
groups=1,
use_cudnn=True,
act=None,
data_format="NCHW",
name=None):
# entry checks
if not isinstance(use_cudnn, bool):
raise ValueError("Attr(use_cudnn) should be True or False. "
"Received Attr(use_cudnn): {}.".format(use_cudnn))
if data_format not in ["NCHW", "NHWC"]:
raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'. "
"Received Attr(data_format): {}.".format(data_format))
channel_last = (data_format == "NHWC")
channel_dim = -1 if channel_last else 1
num_channels = input.shape[channel_dim]
num_filters = weight.shape[0]
if num_channels < 0:
raise ValueError("The channel dimmention of the input({}) "
"should be defined. Received: {}.".format(
input.shape, num_channels))
if num_channels % groups != 0:
raise ValueError(
"the channel of input must be divisible by groups,"
"received: the channel of input is {}, the shape of input is {}"
", the groups is {}".format(num_channels, input.shape, groups))
if num_filters % groups != 0:
raise ValueError(
"the number of filters must be divisible by groups,"
"received: the number of filters is {}, the shape of weight is {}"
", the groups is {}".format(num_filters, weight.shape, groups))
# update attrs
padding, padding_algorithm = _update_padding_nd(padding, channel_last, 2)
stride = utils.convert_to_list(stride, 2, 'stride')
dilation = utils.convert_to_list(dilation, 2, 'dilation')
l_type = "conv2d"
if (num_channels == groups and num_filters % num_channels == 0 and
not use_cudnn):
l_type = 'depthwise_conv2d'
inputs = {'Input': [input], 'Filter': [weight]}
attrs = {
'strides': stride,
'paddings': padding,
'dilations': dilation,
'groups': groups,
'use_cudnn': use_cudnn,
'use_mkldnn': False,
'fuse_relu_before_depthwise_conv': False,
"padding_algorithm": padding_algorithm,
"data_format": data_format
}
if in_dygraph_mode():
attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation,
'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn', False,
'fuse_relu_before_depthwise_conv', False, "padding_algorithm",
padding_algorithm, "data_format", data_format)
pre_bias = getattr(core.ops, l_type)(input, weight, *attrs)
if bias is not None:
pre_act = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
else:
pre_act = pre_bias
out = dygraph_utils._append_activation_in_dygraph(
pre_act, act, use_cudnn=use_cudnn)
else:
inputs = {'Input': [input], 'Filter': [weight]}
attrs = {
'strides': self._stride,
'paddings': self._padding,
'dilations': self._dilation,
'groups': self._groups if self._groups else 1,
'use_cudnn': self._use_cudnn,
'strides': stride,
'paddings': padding,
'dilations': dilation,
'groups': groups,
'use_cudnn': use_cudnn,
'use_mkldnn': False,
'fuse_relu_before_depthwise_conv': False,
"padding_algorithm": padding_algorithm,
"data_format": data_format
}
check_variable_and_dtype(input, 'input',
['float16', 'float32', 'float64'], 'conv2d')
helper = LayerHelper(l_type, **locals())
dtype = helper.input_dtype()
pre_bias = helper.create_variable_for_type_inference(dtype)
outputs = {"Output": [pre_bias]}
helper.append_op(
type=l_type, inputs=inputs, outputs=outputs, attrs=attrs)
if bias is not None:
pre_act = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
else:
pre_act = pre_bias
out = helper.append_activation(pre_act)
return out
if in_dygraph_mode():
outs = core.ops.conv2d(inputs, attrs)
pre_bias = outs['Output'][0]
pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, bias, 1)
return dygraph_utils._append_activation_in_dygraph(pre_act,
self._act)
pre_bias = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
def test_conv2d_with_filter():
self._helper.append_op(
type=self._l_type,
inputs={
'Input': input,
'Filter': weight,
},
outputs={"Output": pre_bias},
attrs=attrs)
import paddle.fluid.dygraph as dygraph
import numpy as np
if bias is not None:
pre_act = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
self._helper.append_op(
type='elementwise_add',
inputs={'X': [pre_bias],
'Y': [bias]},
outputs={'Out': [pre_act]},
attrs={'axis': 1})
else:
pre_act = pre_bias
exemplar = np.random.random((8, 4, 6, 6)).astype(np.float32)
instance = np.random.random((8, 4, 22, 22)).astype(np.float32)
# Currently, we don't support inplace in dygraph mode
return self._helper.append_activation(pre_act, act=self._act)
with dygraph.guard():
exem = dygraph.to_variable(exemplar)
inst = dygraph.to_variable(instance)
res = FConv2D(inst, exem, groups=1)
print(res.shape)
\ No newline at end of file
from paddle import fluid
from paddle.fluid import layers
from pytracking.libs.Fconv2d import Fconv2d
from pytracking.libs.Fconv2d import FConv2D
from pytracking.libs.tensorlist import tensor_operation, TensorList
from paddle.fluid.framework import Variable as PTensor
......@@ -37,7 +37,7 @@ def conv2d(input: PTensor,
raise ValueError('Unknown mode for padding.')
assert bias is None
out = Fconv2d(
out = FConv2D(
input,
weight,
stride=stride,
......@@ -56,4 +56,4 @@ def conv1x1(input: PTensor, weight: PTensor):
if weight is None:
return input
return Fconv2d(input, weight)
return FConv2D(input, weight)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册