Commit 5c94db16 authored by Megvii Engine Team

feat(mge/functional): add groups support for conv_transpose2d & 3d

GitOrigin-RevId: b75b792fb4c911e1671a0db40bea1c96c6d832fb
Parent f2f33565
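This commit routes the `groups` argument of `conv_transpose2d` / `conv_transpose3d` through to the underlying backward-data ops instead of raising `NotImplementedError`. A minimal usage sketch (not part of the diff; the shapes are illustrative, and the grouped weight layout `(groups, in_channels // groups, out_channels // groups, kh, kw)` follows the `_infer_weight_shape` change below):

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

groups, ic, oc, kh, kw = 2, 4, 6, 3, 3
inp = tensor(np.random.randn(1, ic, 8, 8).astype("float32"))
# grouped weight layout: (groups, IC // groups, OC // groups, KH, KW)
weight = tensor(
    np.random.randn(groups, ic // groups, oc // groups, kh, kw).astype("float32")
)
out = F.conv_transpose2d(inp, weight, stride=1, padding=0, groups=groups)
print(out.shape)  # (1, 6, 10, 10): (8 - 1) * 1 + 3 = 10 per spatial dim
```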
@@ -372,6 +372,7 @@ def conv_transpose2d(
Args:
inp: feature map of the convolution operation.
weight: convolution kernel.
weight usually has shape ``(in_channels, out_channels, height, width)``.
bias: bias added to the result of convolution (if given).
stride: stride of the 2D convolution operation. Default: 1
padding: size of the paddings added to the input on both sides of its
@@ -405,14 +406,12 @@ def conv_transpose2d(
if weight.dtype != dtype:
weight = weight.astype(dtype)
if groups != 1:
raise NotImplementedError("group transposed conv2d is not supported yet.")
stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding)
dilate_h, dilate_w = expand_hw(dilation)
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
sparse_type = "dense" if groups == 1 else "group"
op = builtin.ConvolutionBackwardData(
stride_h=stride_h,
stride_w=stride_w,
@@ -422,6 +421,7 @@ def conv_transpose2d(
dilate_w=dilate_w,
strategy=get_execution_strategy(),
compute_mode=compute_mode,
sparse=sparse_type,
)
(output,) = apply(op, weight, inp)
if bias is not None:
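Setting `sparse="group"` on `ConvolutionBackwardData` should be equivalent to running a dense transposed convolution per group on the matching slice of input channels and concatenating the results along the channel axis. A hedged equivalence check (illustrative only, not part of the commit; shapes are assumptions):

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

g, ic = 2, 4
x = tensor(np.random.randn(1, ic, 5, 5).astype("float32"))
# (groups, IC // groups, OC // groups, KH, KW); OC // groups = 3 here
w = tensor(np.random.randn(g, ic // g, 3, 3, 3).astype("float32"))

grouped = F.conv_transpose2d(x, w, groups=g)
# per-group dense transposed conv over the matching channel slice
per_group = [
    F.conv_transpose2d(x[:, i * (ic // g) : (i + 1) * (ic // g)], w[i])
    for i in range(g)
]
np.testing.assert_allclose(
    grouped.numpy(), F.concat(per_group, axis=1).numpy(), rtol=1e-4, atol=1e-5
)
```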
@@ -447,6 +447,7 @@ def deformable_conv2d(
Args:
inp: input feature map.
weight: convolution kernel.
weight usually has shape ``(out_channels, in_channels, height, width)``.
offset: input offset to the kernel; the channel count of this tensor must match the deformable settings.
mask: input mask to the kernel; the channel count of this tensor must match the deformable settings.
bias: bias added to the result of convolution (if given).
@@ -551,6 +552,7 @@ def conv_transpose3d(
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
groups: int = 1,
) -> Tensor:
r"""3D transposed convolution operation. Only support the case that groups = 1
and conv_mode = "cross_correlation".
@@ -581,6 +583,7 @@ def conv_transpose3d(
if weight.dtype != dtype:
weight = weight.astype(dtype)
sparse_type = "dense" if groups == 1 else "group"
op = builtin.Convolution3DBackwardData(
pad_d=pad[D],
pad_h=pad[H],
@@ -592,6 +595,7 @@
dilate_h=dilate[H],
dilate_w=dilate[W],
strategy=get_execution_strategy(),
sparse=sparse_type,
)
(output,) = apply(op, weight, inp)
if bias is not None:
......
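The 3D path mirrors the 2D one. A short sketch, assuming the grouped weight layout `(groups, in_channels // groups, out_channels // groups, kd, kh, kw)`:

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

x = tensor(np.random.randn(1, 4, 4, 6, 6).astype("float32"))
# groups=2, IC // groups = 2, OC // groups = 3, kernel (1, 3, 3)
w = tensor(np.random.randn(2, 2, 3, 1, 3, 3).astype("float32"))
y = F.conv_transpose3d(x, w, groups=2)
print(y.shape)  # (1, 6, 4, 8, 8)
```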
@@ -891,6 +891,7 @@ class ConvTranspose3d(_ConvNd):
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
bias: bool = True,
groups: int = 1,
):
kernel_size = _triple_nonzero(kernel_size)
stride = _triple_nonzero(stride)
@@ -903,7 +904,7 @@
stride=stride,
padding=padding,
dilation=dilation,
groups=1,
groups=groups,
bias=bias,
)
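At module level, `ConvTranspose3d` should now accept `groups > 1` whenever both channel counts are divisible by it. A sketch under that assumption:

```python
import numpy as np
import megengine.module as M
from megengine import tensor

m = M.ConvTranspose3d(4, 6, kernel_size=3, groups=2)
# weight layout per _infer_weight_shape: (groups, IC // groups, OC // groups, kt, kh, kw)
assert m.weight.numpy().shape == (2, 2, 3, 3, 3, 3)
y = m(tensor(np.random.randn(1, 4, 5, 5, 5).astype("float32")))
print(y.shape)  # (1, 6, 7, 7, 7): (5 - 1) * 1 + 3 = 7 per spatial dim
```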
@@ -913,10 +914,21 @@
return kt * kh * kw * ic
def _infer_weight_shape(self):
group = self.groups
ichl = self.in_channels
ochl = self.out_channels
kt, kh, kw = self.kernel_size
return (ichl, ochl, kt, kh, kw)
if group == 1:
# Assume format is NCTHW
return (ichl, ochl, kt, kh, kw)
assert (
ichl % group == 0 and ochl % group == 0
), "invalid config: in_channels={} out_channels={} group={}".format(
ichl, ochl, group
)
# Assume format is NCTHW
return (group, ichl // group, ochl // group, kt, kh, kw)
def _infer_bias_shape(self):
# Assume format is NCTHW
......
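A quick standalone check of the shape arithmetic in `_infer_weight_shape` (pure-Python sketch, not MegEngine code): grouping divides the weight's parameter count by `groups` while the total in/out channels stay fixed.

```python
import numpy as np

def transposed_weight_shape(ic, oc, groups, k=(3, 3, 3)):
    # mirrors _infer_weight_shape above (sketch)
    assert ic % groups == 0 and oc % groups == 0
    return (ic, oc, *k) if groups == 1 else (groups, ic // groups, oc // groups, *k)

dense = np.prod(transposed_weight_shape(8, 8, 1))    # 8 * 8 * 27 = 1728
grouped = np.prod(transposed_weight_shape(8, 8, 4))  # 4 * 2 * 2 * 27 = 432
assert grouped == dense // 4
```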
@@ -1290,3 +1290,62 @@ def test_set_warp_perspective_config():
expected = F.vision.warp_perspective(inp, M, (2, 2), format="NHWC")
np.testing.assert_allclose(config_out.numpy(), expected.numpy())
np.testing.assert_allclose(context_out.numpy(), expected.numpy())
@pytest.mark.parametrize("stride", [(1, 1)])
@pytest.mark.parametrize("padding", [(1, 1)])
@pytest.mark.parametrize("dilation", [(1, 1)])
@pytest.mark.parametrize("ksize", [(3, 3)])
@pytest.mark.parametrize("groups", [1, 2])
def test_local_conv2d(stride, padding, dilation, ksize, groups):
batch_size, in_channels, out_channels = 2, 4, 8
input_height, input_width = 10, 10
output_height = (input_height + padding[0] * 2 - ksize[0]) // stride[0] + 1
output_width = (input_width + padding[1] * 2 - ksize[1]) // stride[1] + 1
def local_conv2d_np(data, weight, stride, padding, dilation):
# naive calculation using numpy
# only tests the case output_height == input_height, output_width == input_width
data = np.pad(data, ((0, 0), (0, 0), (padding[0],) * 2, (padding[1],) * 2))
expected = np.zeros(
(batch_size, out_channels, output_height, output_width), dtype=np.float32,
)
ic_group_size = in_channels // groups
oc_group_size = out_channels // groups
for n, oc, oh, ow in itertools.product(
*map(range, [batch_size, out_channels, output_height, output_width])
):
ih, iw = oh * stride[0], ow * stride[1]
g_id = oc // oc_group_size
expected[n, oc, ih, iw] = np.sum(
data[
n,
g_id * ic_group_size : (g_id + 1) * ic_group_size,
ih : ih + ksize[0],
iw : iw + ksize[1],
]
* weight[g_id, oh, ow, :, :, :, oc % oc_group_size]
)
return expected
data = np.random.rand(batch_size, in_channels, input_height, input_width).astype(
"float32"
)
weight = np.random.rand(
groups,
output_height,
output_width,
in_channels // groups,
*ksize,
out_channels // groups,
).astype("float32")
output = F.local_conv2d(
tensor(data),
tensor(weight),
None,
stride=stride,
padding=padding,
dilation=dilation,
)
ref = local_conv2d_np(data, weight, stride, padding, dilation)
np.testing.assert_almost_equal(output.numpy(), ref, 5)
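As the test above implies, `local_conv2d` uses untied ("local") weights, one filter per output position, so the weight tensor carries the output spatial dims: `(groups, OH, OW, IC // groups, KH, KW, OC // groups)`. A shape-only sketch (values are illustrative):

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

x = tensor(np.random.rand(2, 4, 10, 10).astype("float32"))
# (groups, OH, OW, IC // groups, KH, KW, OC // groups)
w = tensor(np.random.rand(2, 10, 10, 2, 3, 3, 4).astype("float32"))
y = F.local_conv2d(x, w, None, stride=(1, 1), padding=(1, 1), dilation=(1, 1))
print(y.shape)  # (2, 8, 10, 10): OC = 2 groups * 4 channels per group
```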
@@ -42,162 +42,3 @@ def test_conv_dtype_promotion(name, reproducible):
m = getattr(M, name)(Ci, Co, K)
x = tensor(np.random.random(size=(N, Ci) + S).astype("float16"))
np.testing.assert_equal(m(x).numpy(), m(x.astype("float32")).numpy())
def test_conv_transpose2d():
SH, SW = 3, 1
PH, PW = 2, 0
N, IC, IH, IW = 4, 5, 8, 6
KH, KW = 3, 4
OC = 3
BIAS = False
def getsize(inp, kern, stride):
return (inp - 1) * stride + kern
OH = getsize(IH, KH, SH)
OW = getsize(IW, KW, SW)
inp = np.random.normal(size=(N, IC, IH, IW)).astype(np.float32)
out = np.zeros((N, OC, OH, OW), dtype=np.float32)
weight = np.random.normal(size=(IC, OC, KH, KW)).astype(np.float32)
bias = np.random.normal(size=(1, OC, 1, 1)).astype(np.float32)
# naive calculation using numpy
for n, ic, ih, iw in itertools.product(*map(range, [N, IC, IH, IW])):
oh, ow = ih * SH, iw * SW
out[n, :, oh : oh + KH, ow : ow + KW] += inp[n, ic, ih, iw] * weight[ic]
out = out[:, :, PH : OH - PH, PW : OW - PW]
if BIAS:
out += bias
# megengine conv_transpose2d calculation
conv_transpose2d = ConvTranspose2d(IC, OC, (KH, KW), (SH, SW), (PH, PW), bias=BIAS)
conv_transpose2d.weight = Parameter(weight, dtype=np.float32)
if BIAS:
conv_transpose2d.bias = Parameter(bias, dtype=np.float32)
y = conv_transpose2d(tensor(inp))
np.testing.assert_almost_equal(out, y.numpy(), 2e-6)
def test_local_conv2d():
def test_func(
batch_size,
in_channels,
out_channels,
input_height,
input_width,
kernel_size,
stride,
padding,
dilation,
groups,
):
local_conv2d = LocalConv2d(
in_channels=in_channels,
out_channels=out_channels,
input_height=input_height,
input_width=input_width,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
inputs = np.random.normal(
size=(batch_size, in_channels, input_height, input_width)
).astype(np.float32)
output_height = (input_height + padding * 2 - kernel_size) // stride + 1
output_width = (input_width + padding * 2 - kernel_size) // stride + 1
weights = local_conv2d.weight.numpy()
outputs = local_conv2d(tensor(inputs))
# naive calculation using numpy
# only tests the case output_height == input_height, output_width == input_width
inputs = np.pad(inputs, ((0, 0), (0, 0), (1, 1), (1, 1)))
expected = np.zeros(
(batch_size, out_channels, output_height, output_width), dtype=np.float32,
)
ic_group_size = in_channels // groups
oc_group_size = out_channels // groups
for n, oc, oh, ow in itertools.product(
*map(range, [batch_size, out_channels, output_height, output_width])
):
ih, iw = oh * stride, ow * stride
g_id = oc // oc_group_size
expected[n, oc, ih, iw] = np.sum(
inputs[
n,
g_id * ic_group_size : (g_id + 1) * ic_group_size,
ih : ih + kernel_size,
iw : iw + kernel_size,
]
* weights[g_id, oh, ow, :, :, :, oc % oc_group_size]
)
np.testing.assert_almost_equal(outputs.numpy(), expected, 1e-5)
test_func(10, 4, 4, 5, 5, 3, 1, 1, 1, 1)
test_func(10, 32, 32, 8, 8, 3, 1, 1, 1, 2)
test_func(10, 32, 32, 8, 8, 3, 1, 1, 1, 4)
def test_conv_transpose3d():
def getsize(inp, kernel, stride, dilate):
return (inp - 1) * stride + kernel * dilate - dilate + 1
def test_func(
N,
IC,
ID,
IH,
IW,
OC,
KD,
KH,
KW,
SD,
SH,
SW,
PD,
PH,
PW,
DD,
DH,
DW,
bias=True,
):
conv_transpose3d = ConvTranspose3d(
in_channels=IC,
out_channels=OC,
kernel_size=(KD, KH, KW),
stride=(SD, SH, SW),
padding=(PD, PH, PW),
dilation=(DD, DH, DW),
bias=bias,
)
OD = getsize(ID, KD, SD, DD)
OH = getsize(IH, KH, SH, DH)
OW = getsize(IW, KW, SW, DW)
inp = np.random.normal(size=(N, IC, ID, IH, IW))
weight = np.random.normal(size=(IC, OC, KD, KH, KW))
out_np = np.zeros((N, OC, OD, OH, OW), dtype=np.float32)
for n, ic, idepth, ih, iw in itertools.product(
*map(range, [N, IC, ID, IH, IW])
):
od, oh, ow = idepth * SD, ih * SH, iw * SW
out_np[n, :, od : od + KD, oh : oh + KH, ow : ow + KW] += (
inp[n, ic, idepth, ih, iw] * weight[ic]
)
out_np = out_np[:, :, PD : OD - PD, PH : OH - PH, PW : OW - PW]
assert conv_transpose3d.weight.numpy().shape == weight.shape
conv_transpose3d.weight = Parameter(weight)
out_meg = conv_transpose3d.forward(tensor(inp))
np.testing.assert_almost_equal(out_meg.numpy(), out_np, 1e-5)
test_func(4, 3, 8, 16, 16, 8, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1)
test_func(4, 8, 16, 32, 32, 16, 1, 3, 1, 2, 1, 2, 0, 1, 0, 1, 1, 1)