From e18afa0b05d5ebe057c33d0677eff528ab6ce5b3 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 22 Apr 2021 16:09:16 +0800
Subject: [PATCH] feat(mge/module): python wrapper for conv_transpose3d

GitOrigin-RevId: 61097b871338d09e882cd7c5a453efeb2e0b1d7f
---
Review note: ConvTranspose3d._infer_weight_shape now returns
(in_channels, out_channels, kt, kh, kw), matching the layout consumed by
Convolution3DBackwardData and used by test_conv_transpose3d. The blob hashes
on the nn.py and conv.py "index" lines predate the edits made in review.

 imperative/python/megengine/functional/nn.py | 49 +++++++++++++
 .../python/megengine/module/__init__.py      |  1 +
 imperative/python/megengine/module/conv.py   | 73 +++++++++++++++++++
 .../python/test/unit/module/test_conv.py     | 63 +++++++++++++++-
 imperative/src/impl/ops/convolution.cpp      | 15 ++++
 src/core/include/megbrain/ir/ops.td          |  2 +
 6 files changed, 202 insertions(+), 1 deletion(-)

diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 4d9e27a6b..9eb34758f 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -48,6 +48,7 @@ __all__ = [
     "conv2d",
     "conv3d",
     "conv_transpose2d",
+    "conv_transpose3d",
     "deformable_conv2d",
     "deformable_psroi_pooling",
     "dropout",
@@ -488,6 +489,54 @@ def local_conv2d(
     return output
 
 
+def conv_transpose3d(
+    inp: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    stride: Union[int, Tuple[int, int, int]] = 1,
+    padding: Union[int, Tuple[int, int, int]] = 0,
+    dilation: Union[int, Tuple[int, int, int]] = 1,
+) -> Tensor:
+    """
+    3D transposed convolution operation. Only support the case that group = 1
+    and conv_mode = "cross_correlation".
+
+    Refer to :class:`~.ConvTranspose3d` for more information.
+
+    :param inp: feature map of the convolution operation.
+    :param weight: convolution kernel in (IC, OC, KT, KH, KW) layout.
+    :param bias: bias added to the result of convolution (if given).
+    :param stride: stride of the 3D convolution operation. Default: 1
+    :param padding: size of the paddings added to the input on all sides of its
+        spatial dimensions. Only zero-padding is supported. Default: 0
+    :param dilation: dilation of the 3D convolution operation. Default: 1
+    :return: output tensor.
+    """
+    D, H, W = 0, 1, 2
+
+    pad = _triple(padding)
+    stride = _triple_nonzero(stride)
+    dilate = _triple_nonzero(dilation)
+
+    op = builtin.Convolution3DBackwardData(
+        pad_d=pad[D],
+        pad_h=pad[H],
+        pad_w=pad[W],
+        stride_d=stride[D],
+        stride_h=stride[H],
+        stride_w=stride[W],
+        dilate_d=dilate[D],
+        dilate_h=dilate[H],
+        dilate_w=dilate[W],
+        strategy=get_execution_strategy(),
+    )
+    weight, inp = utils.convert_inputs(weight, inp)
+    (output,) = apply(op, weight, inp)
+    if bias is not None:
+        output += bias
+    return output
+
+
 def max_pool2d(
     inp: Tensor,
     kernel_size: Union[int, Tuple[int, int]],
diff --git a/imperative/python/megengine/module/__init__.py b/imperative/python/megengine/module/__init__.py
index 162d2cc64..bdffc89dd 100644
--- a/imperative/python/megengine/module/__init__.py
+++ b/imperative/python/megengine/module/__init__.py
@@ -18,6 +18,7 @@ from .conv import (
     Conv3d,
     ConvRelu2d,
     ConvTranspose2d,
+    ConvTranspose3d,
     DeformableConv2d,
     LocalConv2d,
 )
diff --git a/imperative/python/megengine/module/conv.py b/imperative/python/megengine/module/conv.py
index 1732f2d3e..cd156eece 100644
--- a/imperative/python/megengine/module/conv.py
+++ b/imperative/python/megengine/module/conv.py
@@ -15,6 +15,7 @@ from ..functional import (
     conv2d,
     conv3d,
     conv_transpose2d,
+    conv_transpose3d,
     deformable_conv2d,
     local_conv2d,
     relu,
@@ -842,3 +843,75 @@ class DeformableConv2d(_ConvNd):
 
     def forward(self, inp, offset, mask):
         return self.calc_conv(inp, self.weight, offset, mask, self.bias)
+
+
+class ConvTranspose3d(_ConvNd):
+    r"""
+    Applies a 3D transposed convolution over an input tensor.
+
+    Only support the case that group = 1 and conv_mode = "cross_correlation".
+
+    :class:`ConvTranspose3d` can be seen as the gradient of :class:`Conv3d` operation
+    with respect to its input.
+
+    Convolution3D usually reduces the size of input, while transposed convolution3d
+    works the opposite way, transforming a smaller input to a larger output while
+    preserving the connectivity pattern.
+
+    :param in_channels: number of input channels.
+    :param out_channels: number of output channels.
+    :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
+        an :class:`int`, the actual kernel size would be
+        ``(kernel_size, kernel_size, kernel_size)``. Default: 1
+    :param stride: stride of the 3D convolution operation. Default: 1
+    :param padding: size of the paddings added to the input on all sides of its
+        spatial dimensions. Only zero-padding is supported. Default: 0
+    :param dilation: dilation of the 3D convolution operation. Default: 1
+    :param bias: whether to add a bias onto the result of convolution. Default:
+        True
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int, int]],
+        stride: Union[int, Tuple[int, int, int]] = 1,
+        padding: Union[int, Tuple[int, int, int]] = 0,
+        dilation: Union[int, Tuple[int, int, int]] = 1,
+        bias: bool = True,
+    ):
+        kernel_size = _triple_nonzero(kernel_size)
+        stride = _triple_nonzero(stride)
+        padding = _triple(padding)
+        dilation = _triple_nonzero(dilation)
+        super().__init__(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=1,
+            bias=bias,
+        )
+
+    def _get_fanin(self):
+        kt, kh, kw = self.kernel_size
+        ic = self.in_channels
+        return kt * kh * kw * ic
+
+    def _infer_weight_shape(self):
+        # Backward-data filter layout: dim 0 must match the input's channels.
+        ichl, ochl = self.in_channels, self.out_channels
+        kt, kh, kw = self.kernel_size
+        return (ichl, ochl, kt, kh, kw)
+
+    def _infer_bias_shape(self):
+        # Assume format is NCTHW
+        return (1, self.out_channels, 1, 1, 1)
+
+    def forward(self, inp):
+        return conv_transpose3d(
+            inp, self.weight, self.bias, self.stride, self.padding, self.dilation,
+        )
diff --git a/imperative/python/test/unit/module/test_conv.py b/imperative/python/test/unit/module/test_conv.py
index 6c51deb62..0f99608c3 100644
--- a/imperative/python/test/unit/module/test_conv.py
+++ b/imperative/python/test/unit/module/test_conv.py
@@ -11,7 +11,7 @@ import itertools
 import numpy as np
 
 from megengine import Parameter, tensor
-from megengine.module import ConvTranspose2d, LocalConv2d
+from megengine.module import ConvTranspose2d, ConvTranspose3d, LocalConv2d
 
 
 def test_conv_transpose2d():
@@ -120,3 +120,64 @@ def test_local_conv2d():
     test_func(10, 4, 4, 5, 5, 3, 1, 1, 1, 1)
     test_func(10, 32, 32, 8, 8, 3, 1, 1, 1, 2)
     test_func(10, 32, 32, 8, 8, 3, 1, 1, 1, 4)
+
+
+def test_conv_transpose3d():
+    def getsize(inp, kernel, stride, dilate):
+        return (inp - 1) * stride + kernel * dilate - dilate + 1
+
+    def test_func(
+        N,
+        IC,
+        ID,
+        IH,
+        IW,
+        OC,
+        KD,
+        KH,
+        KW,
+        SD,
+        SH,
+        SW,
+        PD,
+        PH,
+        PW,
+        DD,
+        DH,
+        DW,
+        bias=True,
+    ):
+        conv_transpose3d = ConvTranspose3d(
+            in_channels=IC,
+            out_channels=OC,
+            kernel_size=(KD, KH, KW),
+            stride=(SD, SH, SW),
+            padding=(PD, PH, PW),
+            dilation=(DD, DH, DW),
+            bias=bias,
+        )
+
+        OD = getsize(ID, KD, SD, DD)
+        OH = getsize(IH, KH, SH, DH)
+        OW = getsize(IW, KW, SW, DW)
+
+        inp = np.random.normal(size=(N, IC, ID, IH, IW))
+        weight = np.random.normal(size=(IC, OC, KD, KH, KW))
+        out_np = np.zeros((N, OC, OD, OH, OW), dtype=np.float32)
+
+        for n, ic, idepth, ih, iw in itertools.product(
+            *map(range, [N, IC, ID, IH, IW])
+        ):
+            od, oh, ow = idepth * SD, ih * SH, iw * SW
+            out_np[n, :, od : od + KD, oh : oh + KH, ow : ow + KW] += (
+                inp[n, ic, idepth, ih, iw] * weight[ic]
+            )
+        out_np = out_np[:, :, PD : OD - PD, PH : OH - PH, PW : OW - PW]
+
+        conv_transpose3d.weight = Parameter(weight)
+        out_meg = conv_transpose3d.forward(tensor(inp))
+
+        np.testing.assert_almost_equal(out_meg.numpy(), out_np, 1e-5)
+
+    test_func(4, 3, 8, 16, 16, 8, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1)
+    test_func(4, 8, 16, 32, 32, 16, 1, 3, 1, 2, 1, 2, 0, 1, 0, 1, 1, 1)
diff --git a/imperative/src/impl/ops/convolution.cpp b/imperative/src/impl/ops/convolution.cpp
index afc057508..8a517712b 100644
--- a/imperative/src/impl/ops/convolution.cpp
+++ b/imperative/src/impl/ops/convolution.cpp
@@ -75,5 +75,20 @@ OP_TRAIT_REG(Convolution3D, Convolution3D, opr::Convolution3D)
     .fallback();
 }} // convolution3d
 
+namespace { namespace convolution3d_backward_data {
+auto apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs) {
+    auto&& conv = static_cast<const Convolution3DBackwardData&>(def);
+    OperatorNodeConfig config{conv.make_name()};
+    mgb_assert(inputs.size() == 2);
+    return opr::Convolution3DBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
+}
+
+OP_TRAIT_REG(Convolution3DBackwardData, Convolution3DBackwardData)
+    .apply_on_var_node(apply_on_var_node)
+    .fallback();
+}} // convolution3d_backward_data
+
 }
 }
diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td
index 01c91da3d..e13bb859a 100644
--- a/src/core/include/megbrain/ir/ops.td
+++ b/src/core/include/megbrain/ir/ops.td
@@ -53,6 +53,8 @@ def ConvolutionBackwardData: MgbHashableOp<"ConvolutionBackwardData", [Convoluti
 
 def Convolution3D: MgbHashableOp<"Convolution3D", [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>;
 
+def Convolution3DBackwardData: MgbHashableOp<"Convolution3DBackwardData", [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>;
+
 def DeformableConv : MgbHashableOp<"DeformableConv", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;
 
 def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>;
-- 
GitLab