提交 d23fec06 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(mge/functional): rm useless arguments

group and compute_mode are not used by local_conv2d

GitOrigin-RevId: 8e4f25bfd851d791f132ded0daddb3d636f65144
上级 3b08bd9e
...@@ -65,6 +65,7 @@ from .nn import ( ...@@ -65,6 +65,7 @@ from .nn import (
interpolate, interpolate,
leaky_relu, leaky_relu,
linear, linear,
local_conv2d,
matrix_mul, matrix_mul,
max_pool2d, max_pool2d,
one_hot, one_hot,
......
...@@ -170,6 +170,34 @@ def conv_transpose2d( ...@@ -170,6 +170,34 @@ def conv_transpose2d(
return res return res
@wrap_io_tensor
def local_conv2d(
inp: Tensor,
weight: Tensor,
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
conv_mode="CROSS_CORRELATION",
) -> Tensor:
"""Applies spatial 2D convolution over an image with untied kernels.
Refer to :class:`~.LocalConv2d` for more information.
"""
ret = mgb.opr.group_local(
inp,
weight,
pad_h=padding[0],
pad_w=padding[1],
stride_h=stride[0],
stride_w=stride[1],
dilate_h=dilation[0],
dilate_w=dilation[1],
format="NCHW",
mode=conv_mode,
)
return ret
@wrap_io_tensor @wrap_io_tensor
def max_pool2d( def max_pool2d(
inp: Tensor, inp: Tensor,
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
from .batchnorm import BatchNorm1d, BatchNorm2d from .batchnorm import BatchNorm1d, BatchNorm2d
from .concat import Concat from .concat import Concat
from .conv import Conv2d, ConvTranspose2d from .conv import Conv2d, ConvTranspose2d, LocalConv2d
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .dropout import Dropout from .dropout import Dropout
from .elemwise import Elemwise from .elemwise import Elemwise
......
...@@ -14,7 +14,7 @@ import numpy as np ...@@ -14,7 +14,7 @@ import numpy as np
import megengine._internal as mgb import megengine._internal as mgb
from ..core import Parameter from ..core import Parameter
from ..functional import conv2d, conv_transpose2d from ..functional import conv2d, conv_transpose2d, local_conv2d
from ..utils.types import _pair, _pair_nonzero from ..utils.types import _pair, _pair_nonzero
from . import init from . import init
from .module import Module from .module import Module
...@@ -224,7 +224,7 @@ class ConvTranspose2d(_ConvNd): ...@@ -224,7 +224,7 @@ class ConvTranspose2d(_ConvNd):
``in_channels`` and ``out_channels`` must be divisible by ``groups``, ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and there would be an extra dimension at the beginning of the weight's and there would be an extra dimension at the beginning of the weight's
shape. Specifically, the shape of weight would be ``(groups, shape. Specifically, the shape of weight would be ``(groups,
out_channel // groups, in_channels // groups, *kernel_size)``. Default: 1 out_channels // groups, in_channels // groups, *kernel_size)``. Default: 1
:param bias: wether to add a bias onto the result of convolution. Default: :param bias: wether to add a bias onto the result of convolution. Default:
True True
:param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default: :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default:
...@@ -306,3 +306,77 @@ class ConvTranspose2d(_ConvNd): ...@@ -306,3 +306,77 @@ class ConvTranspose2d(_ConvNd):
self.conv_mode, self.conv_mode,
self.compute_mode, self.compute_mode,
) )
class LocalConv2d(Conv2d):
r"""Applies a spatial convolution with untied kernels over an input 4D tensor.
It is also known as the locally connected layer.
:param in_channels: number of input channels.
:param out_channels: number of output channels.
:param input_height: the height of the input images.
:param input_width: the width of the input images.
:param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
an :class:`int`, the actual kernel size would be
``(kernel_size, kernel_size)``. Default: 1
:param stride: stride of the 2D convolution operation. Default: 1
:param padding: size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0
:param groups: number of groups to divide input and output channels into,
so as to perform a "grouped convolution". When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``.
The shape of weight is ``(groups, output_height, output_width,
in_channels // groups, *kernel_size, out_channels // groups)``.
"""
_conv_mode_type = mgb.opr_param_defs.Convolution.Mode
def __init__(
self,
in_channels: int,
out_channels: int,
input_height: int,
input_width: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
conv_mode: str = "CROSS_CORRELATION",
):
self.input_height = input_height
self.input_width = input_width
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias=False,
)
def _infer_weight_shape(self):
group = self.groups
output_height = (
self.input_height + self.padding[0] * 2 - self.kernel_size[0]
) // self.stride[0] + 1
output_width = (
self.input_width + self.padding[1] * 2 - self.kernel_size[1]
) // self.stride[1] + 1
# Assume format is NCHW
return (
group,
output_height,
output_width,
self.in_channels // group,
self.kernel_size[0],
self.kernel_size[1],
self.out_channels // group,
)
def forward(self, inp):
return local_conv2d(
inp, self.weight, self.stride, self.padding, self.dilation, self.conv_mode
)
...@@ -11,7 +11,7 @@ import itertools ...@@ -11,7 +11,7 @@ import itertools
import numpy as np import numpy as np
from megengine import Parameter, tensor from megengine import Parameter, tensor
from megengine.module import ConvTranspose2d from megengine.module import ConvTranspose2d, LocalConv2d
from megengine.test import assertTensorClose from megengine.test import assertTensorClose
...@@ -50,3 +50,61 @@ def test_conv_transpose2d(): ...@@ -50,3 +50,61 @@ def test_conv_transpose2d():
y = conv_transpose2d(tensor(inp)) y = conv_transpose2d(tensor(inp))
assertTensorClose(out, y.numpy(), max_err=2e-6) assertTensorClose(out, y.numpy(), max_err=2e-6)
def test_local_conv2d():
batch_size = 10
in_channels = 4
out_channels = 8
input_height = 8
input_width = 8
kernel_size = 3
stride = 1
padding = 1
dilation = 1
groups = 1
local_conv2d = LocalConv2d(
in_channels=in_channels,
out_channels=out_channels,
input_height=input_height,
input_width=input_width,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
inputs = np.random.normal(
size=(batch_size, in_channels, input_height, input_width)
).astype(np.float32)
output_height = (input_height + padding * 2 - kernel_size) // stride + 1
output_width = (input_width + padding * 2 - kernel_size) // stride + 1
weights = np.random.normal(
size=(
groups,
output_height,
output_width,
in_channels // groups,
kernel_size,
kernel_size,
out_channels // groups,
)
).astype(np.float32)
local_conv2d.weight = Parameter(weights)
outputs = local_conv2d(tensor(inputs))
# naive calculation use numpy
# only test output_height == input_height, output_width == input_width, group == 1
inputs = np.pad(inputs, ((0, 0), (0, 0), (1, 1), (1, 1)))
expected = np.zeros(
(batch_size, out_channels, output_height, output_width), dtype=np.float32,
)
for n, oc, oh, ow in itertools.product(
*map(range, [batch_size, out_channels, output_height, output_width])
):
ih, iw = oh * stride, ow * stride
expected[n, oc, ih, iw] = np.sum(
inputs[n, :, ih : ih + kernel_size, iw : iw + kernel_size]
* weights[0, oh, ow, :, :, :, oc]
)
assertTensorClose(outputs.numpy(), expected, max_err=1e-5)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册