Commit e18afa0b authored by Megvii Engine Team

feat(mge/module): python wrapper for conv_transpose3d

GitOrigin-RevId: 61097b871338d09e882cd7c5a453efeb2e0b1d7f
Parent fdf7006b
@@ -48,6 +48,7 @@ __all__ = [
"conv2d",
"conv3d",
"conv_transpose2d",
"conv_transpose3d",
"deformable_conv2d",
"deformable_psroi_pooling",
"dropout",
@@ -488,6 +489,54 @@ def local_conv2d(
return output
def conv_transpose3d(
inp: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
) -> Tensor:
"""
    3D transposed convolution operation. Only supports the case where groups = 1
    and conv_mode = "cross_correlation".
Refer to :class:`~.ConvTranspose3d` for more information.
:param inp: feature map of the convolution operation.
:param weight: convolution kernel.
:param bias: bias added to the result of convolution (if given).
:param stride: stride of the 3D convolution operation. Default: 1
    :param padding: size of the padding added to the input on all sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 3D convolution operation. Default: 1
:return: output tensor.
"""
D, H, W = 0, 1, 2
pad = _triple(padding)
stride = _triple_nonzero(stride)
dilate = _triple_nonzero(dilation)
op = builtin.Convolution3DBackwardData(
pad_d=pad[D],
pad_h=pad[H],
pad_w=pad[W],
stride_d=stride[D],
stride_h=stride[H],
stride_w=stride[W],
dilate_d=dilate[D],
dilate_h=dilate[H],
dilate_w=dilate[W],
strategy=get_execution_strategy(),
)
weight, inp = utils.convert_inputs(weight, inp)
(output,) = apply(op, weight, inp)
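    # bias must be broadcastable against the NCDHW output,
    # i.e. have shape (1, out_channels, 1, 1, 1)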
if bias is not None:
output += bias
return output
def max_pool2d(
inp: Tensor,
kernel_size: Union[int, Tuple[int, int]],
......
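A quick sanity check of the new functional interface (shapes are illustrative; the weight layout (IC, OC, KD, KH, KW) matches the test added below, and conv_transpose3d is assumed to be re-exported from the top-level functional namespace, as the __all__ change above suggests):

import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.random.randn(2, 4, 3, 5, 5).astype("float32"))  # NCDHW input
w = mge.tensor(np.random.randn(4, 8, 3, 3, 3).astype("float32"))  # (IC, OC, KD, KH, KW)
y = F.conv_transpose3d(x, w, stride=2, padding=1)
# each spatial dim grows as (in - 1)*stride + dilation*(kernel - 1) + 1 - 2*padding
print(y.shape)  # (2, 8, 5, 9, 9)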
@@ -18,6 +18,7 @@ from .conv import (
Conv3d,
ConvRelu2d,
ConvTranspose2d,
ConvTranspose3d,
DeformableConv2d,
LocalConv2d,
)
......
@@ -15,6 +15,7 @@ from ..functional import (
conv2d,
conv3d,
conv_transpose2d,
conv_transpose3d,
deformable_conv2d,
local_conv2d,
relu,
@@ -842,3 +843,75 @@ class DeformableConv2d(_ConvNd):
def forward(self, inp, offset, mask):
return self.calc_conv(inp, self.weight, offset, mask, self.bias)
class ConvTranspose3d(_ConvNd):
r"""
Applies a 3D transposed convolution over an input tensor.
    Only supports the case where groups = 1 and conv_mode = "cross_correlation".
:class:`ConvTranspose3d` can be seen as the gradient of :class:`Conv3d` operation
with respect to its input.
    Convolution3D usually reduces the size of the input, while transposed
    convolution3d works the opposite way, transforming a smaller input into a
    larger output while preserving the connectivity pattern.
:param in_channels: number of input channels.
:param out_channels: number of output channels.
:param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
an :class:`int`, the actual kernel size would be
``(kernel_size, kernel_size, kernel_size)``. Default: 1
:param stride: stride of the 3D convolution operation. Default: 1
    :param padding: size of the padding added to the input on all sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 3D convolution operation. Default: 1
    :param bias: whether to add a bias onto the result of convolution. Default: True
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
bias: bool = True,
):
kernel_size = _triple_nonzero(kernel_size)
stride = _triple_nonzero(stride)
padding = _triple(padding)
dilation = _triple_nonzero(dilation)
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=1,
bias=bias,
)
def _get_fanin(self):
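        # fan-in consumed by _ConvNd's default weight initialisation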
kt, kh, kw = self.kernel_size
ic = self.in_channels
return kt * kh * kw * ic
def _infer_weight_shape(self):
ichl = self.in_channels
ochl = self.out_channels
kt, kh, kw = self.kernel_size
        return (ichl, ochl, kt, kh, kw)
def _infer_bias_shape(self):
# Assume format is NCTHW
return (1, self.out_channels, 1, 1, 1)
def forward(self, inp):
return conv_transpose3d(
inp, self.weight, self.bias, self.stride, self.padding, self.dilation,
)
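The module form mirrors the functional call; a minimal sketch with the same (hypothetical) shapes as the functional example above:

import numpy as np
import megengine as mge
from megengine.module import ConvTranspose3d

m = ConvTranspose3d(in_channels=4, out_channels=8, kernel_size=3, stride=2, padding=1)
x = mge.tensor(np.random.randn(2, 4, 3, 5, 5).astype("float32"))
y = m(x)
print(y.shape)  # (2, 8, 5, 9, 9), same size formula as the functional example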
@@ -11,7 +11,7 @@ import itertools
import numpy as np
from megengine import Parameter, tensor
-from megengine.module import ConvTranspose2d, LocalConv2d
+from megengine.module import ConvTranspose2d, ConvTranspose3d, LocalConv2d
def test_conv_transpose2d():
@@ -120,3 +120,64 @@ def test_local_conv2d():
test_func(10, 4, 4, 5, 5, 3, 1, 1, 1, 1)
test_func(10, 32, 32, 8, 8, 3, 1, 1, 1, 2)
test_func(10, 32, 32, 8, 8, 3, 1, 1, 1, 4)
def test_conv_transpose3d():
def getsize(inp, kernel, stride, dilate):
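        # transposed-conv output size without padding:
        #   (inp - 1) * stride + dilate * (kernel - 1) + 1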
return (inp - 1) * stride + kernel * dilate - dilate + 1
def test_func(
N,
IC,
ID,
IH,
IW,
OC,
KD,
KH,
KW,
SD,
SH,
SW,
PD,
PH,
PW,
DD,
DH,
DW,
bias=True,
):
conv_transpose3d = ConvTranspose3d(
in_channels=IC,
out_channels=OC,
kernel_size=(KD, KH, KW),
stride=(SD, SH, SW),
padding=(PD, PH, PW),
dilation=(DD, DH, DW),
bias=bias,
)
OD = getsize(ID, KD, SD, DD)
OH = getsize(IH, KH, SH, DH)
OW = getsize(IW, KW, SW, DW)
inp = np.random.normal(size=(N, IC, ID, IH, IW))
weight = np.random.normal(size=(IC, OC, KD, KH, KW))
out_np = np.zeros((N, OC, OD, OH, OW), dtype=np.float32)
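        # naive reference: scatter-add each input voxel's contribution into its
        # output window (every case below uses dilation = 1, so KD/KH/KW windows suffice)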
for n, ic, idepth, ih, iw in itertools.product(
*map(range, [N, IC, ID, IH, IW])
):
od, oh, ow = idepth * SD, ih * SH, iw * SW
out_np[n, :, od : od + KD, oh : oh + KH, ow : ow + KW] += (
inp[n, ic, idepth, ih, iw] * weight[ic]
)
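        # padding in a transposed conv crops the full-size output on every side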
out_np = out_np[:, :, PD : OD - PD, PH : OH - PH, PW : OW - PW]
conv_transpose3d.weight = Parameter(weight)
out_meg = conv_transpose3d.forward(tensor(inp))
        np.testing.assert_almost_equal(out_meg.numpy(), out_np, decimal=5)
test_func(4, 3, 8, 16, 16, 8, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1)
test_func(4, 8, 16, 32, 32, 16, 1, 3, 1, 2, 1, 2, 0, 1, 0, 1, 1, 1)
@@ -75,5 +75,20 @@ OP_TRAIT_REG(Convolution3D, Convolution3D, opr::Convolution3D)
.fallback();
}} // convolution3d
namespace { namespace convolution3d_backward_data {
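// Map the imperative OpDef onto the graph operator: inputs[0] is the filter,
// inputs[1] the input feature map, matching apply(op, weight, inp) on the Python side.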
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& conv = static_cast<const Convolution3DBackwardData&>(def);
OperatorNodeConfig config{conv.make_name()};
mgb_assert(inputs.size() == 2);
return opr::Convolution3DBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
}
OP_TRAIT_REG(Convolution3DBackwardData, Convolution3DBackwardData)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // convolution3d_backward_data
}
}
@@ -53,6 +53,8 @@ def ConvolutionBackwardData: MgbHashableOp<"ConvolutionBackwardData", [Convoluti
def Convolution3D: MgbHashableOp<"Convolution3D", [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>;
def Convolution3DBackwardData: MgbHashableOp<"Convolution3DBackwardData", [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>;
def DeformableConv : MgbHashableOp<"DeformableConv", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;
def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>;
......