Commit 7ed98dd5 authored by Megvii Engine Team

feat(mge/functional): add conv_transpose2d support

GitOrigin-RevId: 9ad87a4ea9f1c067418e46f98867445b593c0464
Parent 20332705
......@@ -53,6 +53,7 @@ from .nn import (
    batch_norm2d,
    batched_matrix_mul,
    conv2d,
    conv_transpose2d,
    dropout,
    embedding,
    eye,
......
......@@ -100,6 +100,69 @@ def conv2d(
    return res


@wrap_io_tensor
def conv_transpose2d(
    inp: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    groups: int = 1,
    conv_mode="CROSS_CORRELATION",
    compute_mode="DEFAULT",
) -> Tensor:
    """2D transposed convolution operation.

    :param inp: The feature map of the convolution operation
    :param weight: The convolution kernel
    :param bias: The bias added to the result of convolution (if given)
    :param stride: Stride of the 2D convolution operation. Default: 1
    :param padding: Size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: Dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups to divide input and output channels into,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and the shape of weight should be ``(groups, in_channels // groups,
        out_channels // groups, height, width)``. Default: 1
    :type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode`
    :param conv_mode: Supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
        'CROSS_CORRELATION'.
    :type compute_mode: string or
        :class:`mgb.opr_param_defs.Convolution.ComputeMode`
    :param compute_mode: When set to 'DEFAULT', no special requirements will be
        placed on the precision of intermediate results. When set to 'FLOAT32',
        Float32 would be used for accumulator and intermediate result, but only
        effective when input and output are of Float16 dtype.

    Refer to :class:`~.ConvTranspose2d` for more information.
    """
    ph, pw = _pair(padding)
    sh, sw = _pair_nonzero(stride)
    dh, dw = _pair_nonzero(dilation)
    Sparse = mgb.opr_param_defs.Convolution.Sparse
    sparse_type = Sparse.DENSE if groups == 1 else Sparse.GROUP
    res = mgb.opr.deconvolution(
        inp,
        weight,
        pad_h=ph,
        pad_w=pw,
        stride_h=sh,
        stride_w=sw,
        dilate_h=dh,
        dilate_w=dw,
        format="NCHW",
        strategy=get_conv_execution_strategy(),
        mode=conv_mode,
        compute_mode=compute_mode,
        sparse=sparse_type,
    )
    if bias is not None:
        res += bias
    return res
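As a quick sanity check, here is a minimal usage sketch of the functional API added above (hypothetical shapes; assumes this commit's build, where conv_transpose2d is re-exported from megengine.functional). The dense weight layout is (in_channels, out_channels, kh, kw), and the output spatial size follows (i - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1:

import numpy as np
import megengine.functional as F
from megengine import tensor

inp = tensor(np.random.randn(1, 4, 8, 8).astype(np.float32))     # NCHW input
weight = tensor(np.random.randn(4, 6, 3, 3).astype(np.float32))  # (IC, OC, KH, KW)
out = F.conv_transpose2d(inp, weight, stride=2, padding=1)
print(out.shape)  # expected (1, 6, 15, 15), since (8 - 1) * 2 - 2 * 1 + (3 - 1) + 1 = 15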
@wrap_io_tensor
def max_pool2d(
    inp: Tensor,
......
......@@ -8,7 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
from .batchnorm import BatchNorm1d, BatchNorm2d
from .conv import Conv2d
from .conv import Conv2d, ConvTranspose2d
from .dropout import Dropout
from .embedding import Embedding
from .identity import Identity
......
......@@ -14,7 +14,7 @@ import numpy as np
import megengine._internal as mgb
from ..core import Parameter
from ..functional import conv2d
from ..functional import conv2d, conv_transpose2d
from ..utils.types import _pair, _pair_nonzero
from . import init
from .module import Module
......@@ -31,7 +31,6 @@ class _ConvNd(Module):
        stride: Union[int, Tuple[int, int]],
        padding: Union[int, Tuple[int, int]],
        dilation: Union[int, Tuple[int, int]],
        output_padding: Union[int, Tuple[int, int]],
        groups: int,
        bias: bool = True,
    ):
......@@ -46,7 +45,6 @@
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.output_padding = output_padding
        self.groups = groups
        self.weight = Parameter(np.zeros(self._infer_weight_shape(), dtype=np.float32))
......@@ -154,7 +152,6 @@ class Conv2d(_ConvNd):
            stride,
            padding,
            dilation,
            (0, 0),
            groups,
            bias,
        )
......@@ -197,3 +194,112 @@ class Conv2d(_ConvNd):
            self.conv_mode,
            self.compute_mode,
        )
class ConvTranspose2d(_ConvNd):
    r"""Applies a 2D transposed convolution over an input tensor.

    This module is also known as a deconvolution or a fractionally-strided
    convolution. :class:`ConvTranspose2d` can be seen as the gradient of the
    :class:`Conv2d` operation with respect to its input.

    Convolution usually reduces the size of its input, while transposed
    convolution works the other way, transforming a smaller input into a larger
    output while preserving the connectivity pattern.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
        an :class:`int`, the actual kernel size would be
        ``(kernel_size, kernel_size)``. Default: 1
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups to divide input and output channels into,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and there would be an extra dimension at the beginning of the weight's
        shape. Specifically, the shape of weight would be ``(groups,
        in_channels // groups, out_channels // groups, *kernel_size)``. Default: 1
    :param bias: whether to add a bias onto the result of convolution. Default:
        True
    :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default:
        `CROSS_CORRELATION`.
    :param compute_mode: When set to `DEFAULT`, no special requirements will be
        placed on the precision of intermediate results. When set to `FLOAT32`,
        float32 would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.
    """

    _conv_mode_type = mgb.opr_param_defs.Convolution.Mode
    _compute_mode_type = mgb.opr_param_defs.Convolution.ComputeMode

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
    ):
        kernel_size = _pair_nonzero(kernel_size)
        stride = _pair_nonzero(stride)
        padding = _pair(padding)
        dilation = _pair_nonzero(dilation)
        self.conv_mode = self._conv_mode_type.convert(conv_mode)
        self.compute_mode = self._compute_mode_type.convert(compute_mode)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def _get_fanin(self):
        kh, kw = self.kernel_size
        oc = self.out_channels
        # In a transposed convolution the roles of input and output channels
        # are swapped, so fan-in is computed from the output channels.
        return kh * kw * oc

    def _infer_weight_shape(self):
        group = self.groups
        ichl = self.in_channels
        ochl = self.out_channels
        kh, kw = self.kernel_size
        if group == 1:
            # Assume format is NCHW
            return (ichl, ochl, kh, kw)

        assert (
            ichl % group == 0 and ochl % group == 0
        ), "invalid config: input_channels={} output_channels={} group={}".format(
            ichl, ochl, group
        )
        # Assume format is NCHW
        return (group, ichl // group, ochl // group, kh, kw)

    def _infer_bias_shape(self):
        # Assume format is NCHW
        return (1, self.out_channels, 1, 1)

    def forward(self, inp):
        return conv_transpose2d(
            inp,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.conv_mode,
            self.compute_mode,
        )
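For context, a minimal usage sketch of the new module (hypothetical shapes; assumes this build of MegEngine). The output spatial size follows (i - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1:

import numpy as np
from megengine import tensor
from megengine.module import ConvTranspose2d

# Upsample an 8x8 feature map to 16x16: (8 - 1) * 2 - 2 * 1 + (4 - 1) + 1 = 16
m = ConvTranspose2d(in_channels=4, out_channels=6, kernel_size=4, stride=2, padding=1)
y = m(tensor(np.random.randn(2, 4, 8, 8).astype(np.float32)))
print(y.shape)  # expected (2, 6, 16, 16)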
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import itertools
import numpy as np
import pytest
import torch
import megengine as mge
from megengine import Parameter, tensor
from megengine.module import Conv2d, ConvTranspose2d
from megengine.test import assertTensorClose
def test_conv_transpose2d():
    SH, SW = 3, 1
    PH, PW = 2, 0
    N, IC, IH, IW = 4, 5, 8, 6
    KH, KW = 3, 4
    OC = 3
    BIAS = True

    def getsize(inp, kern, stride):
        return (inp - 1) * stride + kern

    OH = getsize(IH, KH, SH)
    OW = getsize(IW, KW, SW)

    inp = np.random.normal(size=(N, IC, IH, IW)).astype(np.float32)
    out = np.zeros((N, OC, OH, OW), dtype=np.float32)
    weight = np.random.normal(size=(IC, OC, KH, KW)).astype(np.float32)
    bias = np.random.normal(size=(1, OC, 1, 1)).astype(np.float32)

    # Reference implementation: scatter each input element, scaled by its
    # kernel slice, into the full-size output, then crop the padding.
    for n, ic, ih, iw in itertools.product(*map(range, [N, IC, IH, IW])):
        oh, ow = ih * SH, iw * SW
        out[n, :, oh : oh + KH, ow : ow + KW] += inp[n, ic, ih, iw] * weight[ic]
    out = out[:, :, PH : OH - PH, PW : OW - PW]
    if BIAS:
        out += bias

    conv_transpose2d = ConvTranspose2d(IC, OC, (KH, KW), (SH, SW), (PH, PW), bias=BIAS)
    conv_transpose2d.weight = Parameter(weight, dtype=np.float32)
    if BIAS:
        conv_transpose2d.bias = Parameter(bias, dtype=np.float32)

    y = conv_transpose2d(tensor(inp))
    assertTensorClose(out, y.numpy(), max_err=2e-6)

    # Cross-check against PyTorch, which uses the same (IC, OC, KH, KW)
    # weight layout for transposed convolution.
    torch_conv_transpose2d = torch.nn.ConvTranspose2d(
        IC, OC, (KH, KW), stride=(SH, SW), padding=(PH, PW), bias=BIAS
    )
    torch_conv_transpose2d.weight = torch.nn.parameter.Parameter(torch.Tensor(weight))
    if BIAS:
        torch_conv_transpose2d.bias = torch.nn.parameter.Parameter(
            torch.Tensor(bias).reshape(OC)
        )
    torch_y = torch_conv_transpose2d(torch.Tensor(inp))
    assertTensorClose(torch_y.detach().numpy(), y.numpy(), max_err=2e-6)
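The geometry the test relies on can be checked by hand. A small arithmetic sketch (plain Python, independent of MegEngine) of the size formula behind getsize and the padding crop:

# Full transposed-conv output is (i - 1) * stride + kernel; cropping
# `padding` rows/cols from each side yields the usual formula.
IH, KH, SH, PH = 8, 3, 3, 2
full_h = (IH - 1) * SH + KH        # 24 rows before cropping
cropped_h = full_h - 2 * PH        # 20 rows after cropping
assert cropped_h == (IH - 1) * SH - 2 * PH + KH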