Commit 8fad00a1 authored by Megvii Engine Team

feat(mge/functional/nn): add conv1d padding

GitOrigin-RevId: 1bbfd36b96f5757d459ddc75d3483da40867577d
Parent 4aa277a2
......@@ -62,6 +62,7 @@ __all__ = [
"softplus",
"svd",
"warp_perspective",
"conv1d",
]
......@@ -121,7 +122,7 @@ def conv2d(
and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, height, width)`.
:type conv_mode: string or :class:`P.Convolution.Mode`
:param conv_mode: supports "CROSS_CORRELATION" or "CONVOLUTION". Default:
:param conv_mode: supports "CROSS_CORRELATION". Default:
"CROSS_CORRELATION"
:type compute_mode: string or
:class:`P.Convolution.ComputeMode`
......@@ -187,7 +188,7 @@ def conv_transpose2d(
and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, height, width)`. Default: 1
:type conv_mode: string or :class:`P.Convolution.Mode`
:param conv_mode: supports "CROSS_CORRELATION" or "CONVOLUTION". Default:
:param conv_mode: supports "CROSS_CORRELATION". Default:
"CROSS_CORRELATION"
:type compute_mode: string or
:class:`P.Convolution.ComputeMode`
......@@ -232,9 +233,7 @@ def local_conv2d(
dilation: Union[int, Tuple[int, int]] = 1,
conv_mode="CROSS_CORRELATION",
):
"""
Applies spatial 2D convolution over an groupped channeled image with untied kernels.
"""
"""Applies spatial 2D convolution over an groupped channeled image with untied kernels."""
assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION"
stride_h, stride_w = expand_hw(stride)
......@@ -1585,6 +1584,82 @@ def indexing_one_hot(
return result
def conv1d(
inp: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
conv_mode="CROSS_CORRELATION",
compute_mode="DEFAULT",
) -> Tensor:
"""1D convolution operation.
Refer to :class:`~.Conv1d` for more information.
:param inp: The feature map of the convolution operation
:param weight: The convolution kernel
:param bias: The bias added to the result of convolution (if given)
:param stride: Stride of the 1D convolution operation. Default: 1
:param padding: Size of the paddings added to the input on both sides of its
spatial dimension. Only zero-padding is supported. Default: 0
:param dilation: Dilation of the 1D convolution operation. Default: 1
:param groups: number of groups to divide input and output channels into,
so as to perform a "grouped convolution". When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be ``(groups, out_channels // groups,
in_channels // groups, kernel_size)``.
:type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode`
:param conv_mode: Supports 'CROSS_CORRELATION'. Default:
'CROSS_CORRELATION'.
:type compute_mode: string or
:class:`mgb.opr_param_defs.Convolution.ComputeMode`
:param compute_mode: When set to 'DEFAULT', no special requirements will be
placed on the precision of intermediate results. When set to 'FLOAT32',
Float32 would be used for accumulator and intermediate result, but only
effective when input and output are of Float16 dtype.
"""
assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION"
assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT"
assert inp.ndim == 3, "the input dimension of conv1d should be 3"
assert weight.ndim == 3, "the weight dimension of conv1d should be 3"
inp = expand_dims(inp, 3)
weight = expand_dims(weight, 3)
if bias is not None:
assert bias.ndim == 3, "the bias dimension of conv1d should be 3"
bias = expand_dims(bias, 3)
stride_h = stride
pad_h = padding
dilate_h = dilation
Sparse = P.Convolution.Sparse
sparse_type = Sparse.DENSE if groups == 1 else Sparse.GROUP
op = builtin.Convolution(
stride_h=stride_h,
stride_w=1,
pad_h=pad_h,
pad_w=0,
dilate_h=dilate_h,
dilate_w=1,
strategy=get_conv_execution_strategy(),
mode=conv_mode,
compute_mode=compute_mode,
sparse=sparse_type,
)
inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight)
if bias is not None:
output += bias
output = squeeze(output, 3)
return output
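As a quick sanity check of the new ``padding`` argument (a sketch, not part of
the commit): conv1d above lifts the 1D input to 2D with ``expand_dims``, runs a
2D convolution with W = 1, and squeezes back, so the output length follows the
usual formula ``H_out = (H + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1``.

# Sketch: exercising conv1d's padding path through the signature defined above.
import numpy as np
import megengine as mge
import megengine.functional as F

inp = mge.tensor(np.ones((1, 1, 4), dtype=np.float32))
weight = mge.tensor(np.ones((1, 1, 3), dtype=np.float32))
out = F.conv1d(inp, weight, None, stride=1, padding=1)
print(out.numpy().shape)  # (1, 1, 4): (4 + 2*1 - 3) // 1 + 1
print(out.numpy())  # [[[2. 3. 3. 2.]]] -- border windows overlap one zero pad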
def nms(
boxes: Tensor, scores: Tensor, iou_thresh: float, max_output: Optional[int] = None
) -> Tensor:
......
......@@ -11,7 +11,7 @@ from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d
from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm
from .concat import Concat
from .conv import Conv1d, Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d
from .conv_bn import ConvBn2d, ConvBnRelu2d
from .dropout import Dropout
from .elemwise import Elemwise
......
......@@ -11,7 +11,7 @@ from typing import Tuple, Union
import numpy as np
from ..core.ops._internal import param_defs as P
from ..functional import conv1d, conv2d, conv_transpose2d, local_conv2d, relu
from ..functional.types import _pair, _pair_nonzero
from ..tensor import Parameter
from . import init
......@@ -86,6 +86,152 @@ class _ConvNd(Module):
return s.format(**self.__dict__)
class Conv1d(_ConvNd):
r"""
Applies a 1D convolution over an input tensor.
For instance, given an input of the size :math:`(N, C_{\text{in}}, H)`,
this layer generates an output of the size
:math:`(N, C_{\text{out}}, H_{\text{out}})` through the
process described as below:
.. math::
\text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
\sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)
where :math:`\star` is the valid 1D cross-correlation operator,
:math:`N` is batch size, :math:`C` denotes number of channels, and
:math:`H` is length of 1D data element.
When `groups == in_channels` and `out_channels == K * in_channels`,
where K is a positive integer, this operation is also known as depthwise
convolution.
In other words, for an input of size :math:`(N, C_{in}, H_{in})`,
a depthwise convolution with a depthwise multiplier `K`, can be constructed
by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.
:param in_channels: number of input channels.
:param out_channels: number of output channels.
:param kernel_size: size of weight on the spatial dimension of the
1D convolution kernel.
:param stride: stride of the 1D convolution operation. Default: 1
:param padding: size of the paddings added to the input on both sides of its
spatial dimension. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 1D convolution operation. Default: 1
:param groups: number of groups into which the input and output channels are
divided, so as to perform a "grouped convolution". When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and there would be an extra dimension at the beginning of the weight's
shape. Specifically, the shape of weight would be `(groups,
out_channels // groups, in_channels // groups, kernel_size)`.
:param bias: whether to add a bias onto the result of convolution. Default:
True
:param conv_mode: Supports `CROSS_CORRELATION`. Default:
`CROSS_CORRELATION`
:param compute_mode: When set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to "FLOAT32",
"Float32" would be used for accumulator and intermediate result, but only
effective when input and output are of float16 dtype.
Examples:
.. testcode::
import numpy as np
import megengine as mge
import megengine.module as M
m = M.Conv1d(in_channels=3, out_channels=1, kernel_size=3)
inp = mge.tensor(np.arange(0, 24).astype("float32").reshape(2, 3, 4))
oup = m(inp)
print(oup.numpy().shape)
Outputs:
.. testoutput::
(2, 1, 2)
"""
_conv_mode_type = P.Convolution.Mode
_compute_mode_type = P.Convolution.ComputeMode
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
conv_mode: str = "CROSS_CORRELATION",
compute_mode: str = "DEFAULT",
):
self.conv_mode = self._conv_mode_type.convert(conv_mode)
self.compute_mode = self._compute_mode_type.convert(compute_mode)
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
)
def _get_fanin(self):
kh = self.kernel_size
ic = self.in_channels
return kh * ic
def _infer_weight_shape(self):
group = self.groups
ichl = self.in_channels
ochl = self.out_channels
kh = self.kernel_size
if group == 1:
# Assume format is NCH(W=1)
return (ochl, ichl, kh)
assert (
ichl % group == 0 and ochl % group == 0
), "invalid config: input_channels={} output_channels={} group={}".format(
ichl, ochl, group
)
# Assume format is NCH(W=1)
return (group, ochl // group, ichl // group, kh)
def _infer_bias_shape(self):
# Assume format is NCH(W=1)
return (1, self.out_channels, 1)
def calc_conv(self, inp, weight, bias):
return conv1d(
inp,
weight,
bias,
self.stride,
self.padding,
self.dilation,
self.groups,
self.conv_mode,
self.compute_mode,
)
def forward(self, inp):
return self.calc_conv(inp, self.weight, self.bias)
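To illustrate the depthwise case mentioned in the docstring (a sketch, assuming
the class as defined above): with ``groups == in_channels``, each input channel
is convolved with its own ``out_channels // groups`` filters, and
``_infer_weight_shape`` produces the grouped weight layout.

# Sketch: depthwise Conv1d (groups == in_channels).
import numpy as np
import megengine as mge
import megengine.module as M

m = M.Conv1d(in_channels=4, out_channels=4, kernel_size=3, padding=1, groups=4)
print(m.weight.shape)  # (4, 1, 1, 3): (groups, out // groups, in // groups, kernel_size)
inp = mge.tensor(np.random.randn(2, 4, 8).astype("float32"))
print(m(inp).numpy().shape)  # (2, 4, 8): padding=1 preserves length for kernel_size=3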
class Conv2d(_ConvNd):
r"""
Applies a 2D convolution over an input tensor.
......@@ -128,7 +274,7 @@ class Conv2d(_ConvNd):
out_channel // groups, in_channels // groups, *kernel_size)`.
:param bias: whether to add a bias onto the result of convolution. Default:
True
:param conv_mode: Supports `CROSS_CORRELATION`. Default:
`CROSS_CORRELATION`
:param compute_mode: When set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to "FLOAT32",
......@@ -260,7 +406,7 @@ class ConvTranspose2d(_ConvNd):
out_channels // groups, in_channels // groups, *kernel_size)``. Default: 1
:param bias: whether to add a bias onto the result of convolution. Default:
True
:param conv_mode: Supports `CROSS_CORRELATION`. Default:
`CROSS_CORRELATION`
:param compute_mode: When set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to "FLOAT32",
......
......@@ -531,6 +531,18 @@ def test_zero_stride_numpy_array():
out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
def test_conv1d():
inp = tensor(np.ones((16,), dtype=np.float32).reshape(2, 2, 4))
weight = tensor(np.ones((12,), dtype=np.float32).reshape(3, 2, 2))
out = F.conv1d(inp, weight, None, 2, 0, 1, 1)
np.testing.assert_equal(
out.numpy(),
np.array(
[[[4, 4], [4, 4], [4, 4]], [[4, 4], [4, 4], [4, 4]]], dtype=np.float32
),
)
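The case above exercises stride but not the new padding path; a companion test
might look like this (a sketch in the same style, not part of the commit):

def test_conv1d_padding():
    # All-ones input and kernel make the padded border sums easy to verify.
    inp = tensor(np.ones((1, 1, 4), dtype=np.float32))
    weight = tensor(np.ones((1, 1, 3), dtype=np.float32))
    out = F.conv1d(inp, weight, None, 1, 1, 1, 1)  # stride=1, padding=1
    np.testing.assert_equal(
        out.numpy(), np.array([[[2, 3, 3, 2]]], dtype=np.float32)
    )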
def test_condtake():
x = np.array([[1, 2, 3], [4, 5, 6]])
y = np.array([[True, False, True], [False, True, True]])
......
......@@ -20,6 +20,7 @@ from megengine import Parameter, Tensor, tensor
from megengine.module import (
BatchNorm1d,
BatchNorm2d,
Conv1d,
Conv2d,
Dropout,
Linear,
......@@ -541,6 +542,43 @@ def test_shared_param():
np.testing.assert_allclose(conv0(data).numpy(), conv1(data).numpy())
class Simple2(Module):
def __init__(self):
super().__init__()
self.conv1 = Conv1d(1, 1, kernel_size=3, bias=False)
self.conv0 = Conv1d(1, 1, kernel_size=3, bias=False)
self.conv1.weight = self.conv0.weight
def forward(self, inputs):
pass
def test_shared_param_1d():
net = Simple2()
assert net.conv0.weight is net.conv1.weight
data = tensor(np.random.random((1, 1, 8)).astype(np.float32))
np.testing.assert_allclose(net.conv0(data).numpy(), net.conv1(data).numpy())
with BytesIO() as f:
mge.save(net, f)
f.seek(0)
net1 = mge.load(f)
assert net1.conv0.weight is net1.conv1.weight
np.testing.assert_allclose(net1.conv0(data).numpy(), net1.conv1(data).numpy())
with BytesIO() as f:
mge.save(net.conv0, f)
f.seek(0)
conv0 = mge.load(f)
with BytesIO() as f:
mge.save(net.conv1, f)
f.seek(0)
conv1 = mge.load(f)
assert conv0.weight is not conv1.weight
np.testing.assert_allclose(conv0(data).numpy(), conv1(data).numpy())
def test_pickle_module():
data_shape = (2, 28)
data = tensor(np.random.random(data_shape))
......