From 5c94db16f2c19f9693b8a2cc0d80e9e0a7076700 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 19 Nov 2021 19:11:13 +0800
Subject: [PATCH] feat(mge/functional): add groups support for conv_transpose2d & 3d

GitOrigin-RevId: b75b792fb4c911e1671a0db40bea1c96c6d832fb
---
 imperative/python/megengine/functional/nn.py  |  16 +-
 imperative/python/megengine/module/conv.py    |  16 +-
 .../test/unit/functional/test_functional.py   |  59 +++++++
 .../python/test/unit/module/test_conv.py      | 159 ------------------
 4 files changed, 83 insertions(+), 167 deletions(-)

diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 5400f8a48..ce2e16558 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -372,6 +372,7 @@ def conv_transpose2d(
     Args:
         inp: feature map of the convolution operation.
         weight: convolution kernel.
+            weight usually has shape ``(in_channels, out_channels, height, width)``.
         bias: bias added to the result of convolution (if given).
         stride: stride of the 2D convolution operation. Default: 1
         padding: size of the paddings added to the input on both sides of its
@@ -405,14 +406,12 @@ def conv_transpose2d(
     if weight.dtype != dtype:
         weight = weight.astype(dtype)
 
-    if groups != 1:
-        raise NotImplementedError("group transposed conv2d is not supported yet.")
-
     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)
     dilate_h, dilate_w = expand_hw(dilation)
 
-    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
+    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
+    sparse_type = "dense" if groups == 1 else "group"
     op = builtin.ConvolutionBackwardData(
         stride_h=stride_h,
         stride_w=stride_w,
@@ -422,6 +421,7 @@ def conv_transpose2d(
         dilate_w=dilate_w,
         strategy=get_execution_strategy(),
         compute_mode=compute_mode,
+        sparse=sparse_type,
     )
     (output,) = apply(op, weight, inp)
     if bias is not None:
@@ -447,6 +447,7 @@ def deformable_conv2d(
     Args:
         inp: input feature map.
         weight: convolution kernel.
+            weight usually has shape ``(out_channels, in_channels, height, width)``.
         offset: input offset to kernel, channel of this tensor should match the deformable settings.
         mask: input mask to kernel, channel of this tensor should match the deformable settings.
         bias: bias added to the result of convolution (if given).
@@ -551,6 +552,7 @@ def conv_transpose3d(
     stride: Union[int, Tuple[int, int, int]] = 1,
     padding: Union[int, Tuple[int, int, int]] = 0,
     dilation: Union[int, Tuple[int, int, int]] = 1,
+    groups: int = 1,
 ) -> Tensor:
-    r"""3D transposed convolution operation. Only support the case that groups = 1
-    and conv_mode = "cross_correlation".
+    r"""3D transposed convolution operation. Only support the case that
+    conv_mode = "cross_correlation".
@@ -581,6 +583,7 @@ def conv_transpose3d(
     if weight.dtype != dtype:
         weight = weight.astype(dtype)
 
+    sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution3DBackwardData(
         pad_d=pad[D],
         pad_h=pad[H],
@@ -592,6 +595,7 @@ def conv_transpose3d(
         dilate_h=dilate[H],
         dilate_w=dilate[W],
         strategy=get_execution_strategy(),
+        sparse=sparse_type,
     )
     (output,) = apply(op, weight, inp)
     if bias is not None:
diff --git a/imperative/python/megengine/module/conv.py b/imperative/python/megengine/module/conv.py
index bd5612818..a4117d46b 100644
--- a/imperative/python/megengine/module/conv.py
+++ b/imperative/python/megengine/module/conv.py
@@ -891,6 +891,7 @@ class ConvTranspose3d(_ConvNd):
         padding: Union[int, Tuple[int, int, int]] = 0,
         dilation: Union[int, Tuple[int, int, int]] = 1,
         bias: bool = True,
+        groups: int = 1,
     ):
         kernel_size = _triple_nonzero(kernel_size)
         stride = _triple_nonzero(stride)
@@ -903,7 +904,7 @@ class ConvTranspose3d(_ConvNd):
             stride=stride,
             padding=padding,
             dilation=dilation,
-            groups=1,
+            groups=groups,
             bias=bias,
         )
 
@@ -913,10 +914,21 @@ class ConvTranspose3d(_ConvNd):
         return kt * kh * kw * ic
 
     def _infer_weight_shape(self):
+        group = self.groups
         ichl = self.in_channels
         ochl = self.out_channels
         kt, kh, kw = self.kernel_size
-        return (ichl, ochl, kt, kh, kw)
+        if group == 1:
+            # Assume format is NCTHW
+            return (ichl, ochl, kt, kh, kw)
+
+        assert (
+            ichl % group == 0 and ochl % group == 0
+        ), "invalid config: in_channels={} out_channels={} group={}".format(
+            ichl, ochl, group
+        )
+        # Assume format is NCTHW
+        return (group, ichl // group, ochl // group, kt, kh, kw)
 
     def _infer_bias_shape(self):
         # Assume format is NCTHW
diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py
index 119a75d6e..1a82b30db 100644
--- a/imperative/python/test/unit/functional/test_functional.py
+++ b/imperative/python/test/unit/functional/test_functional.py
@@ -1290,3 +1290,62 @@ def test_set_warp_perspective_config():
     expected = F.vision.warp_perspective(inp, M, (2, 2), format="NHWC")
     np.testing.assert_allclose(config_out.numpy(), expected.numpy())
     np.testing.assert_allclose(context_out.numpy(), expected.numpy())
+
+
+@pytest.mark.parametrize("stride", [(1, 1)])
+@pytest.mark.parametrize("padding", [(1, 1)])
+@pytest.mark.parametrize("dilation", [(1, 1)])
+@pytest.mark.parametrize("ksize", [(3, 3)])
+@pytest.mark.parametrize("groups", [1, 2])
+def test_local_conv2d(stride, padding, dilation, ksize, groups):
+    batch_size, in_channels, out_channels = 2, 4, 8
+    input_height, input_width = 10, 10
+    output_height = (input_height + padding[0] * 2 - ksize[0]) // stride[0] + 1
+    output_width = (input_width + padding[1] * 2 - ksize[1]) // stride[1] + 1
+
+    def local_conv2d_np(data, weight, stride, padding, dilation):
+        # naive calculation using numpy
+        # only test output_height == input_height, output_width == input_width
+        data = np.pad(data, ((0, 0), (0, 0), (1, 1), (1, 1)))
+        expected = np.zeros(
+            (batch_size, out_channels, output_height, output_width), dtype=np.float32,
+        )
+        ic_group_size = in_channels // groups
+        oc_group_size = out_channels // groups
+        for n, oc, oh, ow in itertools.product(
+            *map(range, [batch_size, out_channels, output_height, output_width])
+        ):
+            ih, iw = oh * stride[0], ow * stride[1]
+            g_id = oc // oc_group_size
+            expected[n, oc, ih, iw] = np.sum(
+                data[
+                    n,
+                    g_id * ic_group_size : (g_id + 1) * ic_group_size,
+                    ih : ih + ksize[0],
+                    iw : iw + ksize[1],
+                ]
+                * weight[g_id, oh, ow, :, :, :, oc % oc_group_size]
+            )
+        return expected
+
+    data = np.random.rand(batch_size, in_channels, input_height, input_width).astype(
+        "float32"
+    )
+    weight = np.random.rand(
+        groups,
+        output_height,
+        output_width,
+        in_channels // groups,
+        *ksize,
+        out_channels // groups,
+    ).astype("float32")
+    output = F.local_conv2d(
+        tensor(data),
+        tensor(weight),
+        None,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+    )
+    ref = local_conv2d_np(data, weight, stride, padding, dilation)
+    np.testing.assert_almost_equal(output.numpy(), ref, 5)
diff --git a/imperative/python/test/unit/module/test_conv.py b/imperative/python/test/unit/module/test_conv.py
index f2143ae18..e782a077a 100644
--- a/imperative/python/test/unit/module/test_conv.py
+++ b/imperative/python/test/unit/module/test_conv.py
@@ -42,162 +42,3 @@ def test_conv_dtype_promotion(name, reproducible):
     m = getattr(M, name)(Ci, Co, K)
     x = tensor(np.random.random(size=(N, Ci) + S).astype("float16"))
     np.testing.assert_equal(m(x).numpy(), m(x.astype("float32")).numpy())
-
-
-def test_conv_transpose2d():
-    SH, SW = 3, 1
-    PH, PW = 2, 0
-    N, IC, IH, IW = 4, 5, 8, 6
-    KH, KW = 3, 4
-    OC = 3
-    BIAS = False
-
-    def getsize(inp, kern, stride):
-        return (inp - 1) * stride + kern
-
-    OH = getsize(IH, KH, SH)
-    OW = getsize(IW, KW, SW)
-
-    inp = np.random.normal(size=(N, IC, IH, IW)).astype(np.float32)
-    out = np.zeros((N, OC, OH, OW), dtype=np.float32)
-    weight = np.random.normal(size=(IC, OC, KH, KW)).astype(np.float32)
-    bias = np.random.normal(size=(1, OC, 1, 1)).astype(np.float32)
-
-    # naive calculation use numpy
-    for n, ic, ih, iw in itertools.product(*map(range, [N, IC, IH, IW])):
-        oh, ow = ih * SH, iw * SW
-        out[n, :, oh : oh + KH, ow : ow + KW] += inp[n, ic, ih, iw] * weight[ic]
-    out = out[:, :, PH : OH - PH, PW : OW - PW]
-    if BIAS:
-        out += bias
-
-    # megengine conv_transpose2d calculation
-    conv_transpose2d = ConvTranspose2d(IC, OC, (KH, KW), (SH, SW), (PH, PW), bias=BIAS)
-    conv_transpose2d.weight = Parameter(weight, dtype=np.float32)
-    if BIAS:
-        conv_transpose2d.bias = Parameter(bias, dtype=np.float32)
-    y = conv_transpose2d(tensor(inp))
-
-    np.testing.assert_almost_equal(out, y.numpy(), 2e-6)
-
-
-def test_local_conv2d():
-    def test_func(
-        batch_size,
-        in_channels,
-        out_channels,
-        input_height,
-        input_width,
-        kernel_size,
-        stride,
-        padding,
-        dilation,
-        groups,
-    ):
-        local_conv2d = LocalConv2d(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            input_height=input_height,
-            input_width=input_width,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            groups=groups,
-        )
-        inputs = np.random.normal(
-            size=(batch_size, in_channels, input_height, input_width)
-        ).astype(np.float32)
-        output_height = (input_height + padding * 2 - kernel_size) // stride + 1
-        output_width = (input_width + padding * 2 - kernel_size) // stride + 1
-        weights = local_conv2d.weight.numpy()
-        outputs = local_conv2d(tensor(inputs))
-        # naive calculation use numpy
-        # only test output_height == input_height, output_width == input_width
-        inputs = np.pad(inputs, ((0, 0), (0, 0), (1, 1), (1, 1)))
-        expected = np.zeros(
-            (batch_size, out_channels, output_height, output_width), dtype=np.float32,
-        )
-        ic_group_size = in_channels // groups
-        oc_group_size = out_channels // groups
-        for n, oc, oh, ow in itertools.product(
-            *map(range, [batch_size, out_channels, output_height, output_width])
-        ):
-            ih, iw = oh * stride, ow * stride
-            g_id = oc // oc_group_size
-            expected[n, oc, ih, iw] = np.sum(
-                inputs[
-                    n,
-                    g_id * ic_group_size : (g_id + 1) * ic_group_size,
-                    ih : ih + kernel_size,
-                    iw : iw + kernel_size,
-                ]
-                * weights[g_id, oh, ow, :, :, :, oc % oc_group_size]
-            )
-        np.testing.assert_almost_equal(outputs.numpy(), expected, 1e-5)
-
-    test_func(10, 4, 4, 5, 5, 3, 1, 1, 1, 1)
-    test_func(10, 32, 32, 8, 8, 3, 1, 1, 1, 2)
-    test_func(10, 32, 32, 8, 8, 3, 1, 1, 1, 4)
-
-
-def test_conv_transpose3d():
-    def getsize(inp, kernel, stride, dilate):
-        return (inp - 1) * stride + kernel * dilate - dilate + 1
-
-    def test_func(
-        N,
-        IC,
-        ID,
-        IH,
-        IW,
-        OC,
-        KD,
-        KH,
-        KW,
-        SD,
-        SH,
-        SW,
-        PD,
-        PH,
-        PW,
-        DD,
-        DH,
-        DW,
-        bias=True,
-    ):
-        conv_transpose3d = ConvTranspose3d(
-            in_channels=IC,
-            out_channels=OC,
-            kernel_size=(KD, KH, KW),
-            stride=(SD, SH, SW),
-            padding=(PD, PH, PW),
-            dilation=(DD, DH, DW),
-            bias=bias,
-        )
-
-        OD = getsize(ID, KD, SD, DD)
-        OH = getsize(IH, KH, SH, DH)
-        OW = getsize(IW, KW, SW, DW)
-
-        inp = np.random.normal(size=(N, IC, ID, IH, IW))
-        weight = np.random.normal(size=(IC, OC, KD, KH, KW))
-        out_np = np.zeros((N, OC, OD, OH, OW), dtype=np.float32)
-
-        for n, ic, idepth, ih, iw in itertools.product(
-            *map(range, [N, IC, ID, IH, IW])
-        ):
-            od, oh, ow = idepth * SD, ih * SH, iw * SW
-            out_np[n, :, od : od + KD, oh : oh + KH, ow : ow + KW] += (
-                inp[n, ic, idepth, ih, iw] * weight[ic]
-            )
-        out_np = out_np[:, :, PD : OD - PD, PH : OH - PH, PW : OW - PW]
-
-        assert conv_transpose3d.weight.numpy().shape == weight.shape
-        conv_transpose3d.weight = Parameter(weight)
-        out_meg = conv_transpose3d.forward(tensor(inp))
-
-        np.testing.assert_almost_equal(out_meg.numpy(), out_np, 1e-5)
-
-    test_func(4, 3, 8, 16, 16, 8, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1)
-    test_func(4, 8, 16, 32, 32, 16, 1, 3, 1, 2, 1, 2, 0, 1, 0, 1, 1, 1)
-- 
GitLab
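
Usage note (not part of the patch): a minimal sketch of the grouped
conv_transpose2d call this change enables. The grouped weight layout
(groups, in_channels // groups, out_channels // groups, kh, kw) is an
assumption made by analogy with ConvTranspose3d._infer_weight_shape above;
tensor values and sizes are illustrative only.

    import numpy as np

    import megengine.functional as F
    from megengine import tensor

    groups, icpg, ocpg = 2, 3, 4  # 6 input channels, 8 output channels in total
    inp = tensor(np.random.randn(1, groups * icpg, 8, 8).astype("float32"))
    # grouped weight: one (icpg, ocpg, 3, 3) kernel block per group
    weight = tensor(np.random.randn(groups, icpg, ocpg, 3, 3).astype("float32"))

    # groups != 1 now selects sparse="group" on ConvolutionBackwardData
    # instead of raising NotImplementedError
    out = F.conv_transpose2d(inp, weight, stride=2, padding=1, groups=groups)
    print(out.numpy().shape)  # (1, 8, 15, 15): (8 - 1) * 2 - 2 * 1 + (3 - 1) + 1 = 15

conv_transpose3d accepts the same groups argument with the analogous
six-dimensional grouped weight.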