提交 99309fa3 编写于 作者: M Megvii Engine Team

feat(mge/functional): add param output_padding for deconv ops

GitOrigin-RevId: 8a69608953a69b40db4d23b489435a7ae03c9523
上级 116781ba
...@@ -335,6 +335,7 @@ def conv_transpose2d( ...@@ -335,6 +335,7 @@ def conv_transpose2d(
bias: Optional[Tensor] = None, bias: Optional[Tensor] = None,
stride: Union[int, Tuple[int, int]] = 1, stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0, padding: Union[int, Tuple[int, int]] = 0,
output_padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1, dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1, groups: int = 1,
conv_mode="cross_correlation", conv_mode="cross_correlation",
...@@ -352,6 +353,7 @@ def conv_transpose2d( ...@@ -352,6 +353,7 @@ def conv_transpose2d(
stride: stride of the 2D convolution operation. Default: 1 stride: stride of the 2D convolution operation. Default: 1
padding: size of the paddings added to the input on both sides of its padding: size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0 spatial dimensions. Only zero-padding is supported. Default: 0
output_padding: size of paddings appended to output. Default: 0
dilation: dilation of the 2D convolution operation. Default: 1 dilation: dilation of the 2D convolution operation. Default: 1
groups: number of groups into which the input and output channels are divided, groups: number of groups into which the input and output channels are divided,
so as to perform a ``grouped convolution``. When ``groups`` is not 1, so as to perform a ``grouped convolution``. When ``groups`` is not 1,
...@@ -374,6 +376,7 @@ def conv_transpose2d( ...@@ -374,6 +376,7 @@ def conv_transpose2d(
stride_h, stride_w = expand_hw(stride) stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding) pad_h, pad_w = expand_hw(padding)
output_pad_h, output_pad_w = expand_hw(output_padding)
dilate_h, dilate_w = expand_hw(dilation) dilate_h, dilate_w = expand_hw(dilation)
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
...@@ -389,6 +392,31 @@ def conv_transpose2d( ...@@ -389,6 +392,31 @@ def conv_transpose2d(
compute_mode=compute_mode, compute_mode=compute_mode,
sparse=sparse_type, sparse=sparse_type,
) )
if output_pad_h != 0 or output_pad_h != 0:
assert (
output_pad_h < stride[0]
), "output_padding[0] shoule be less than stride[0]"
assert (
output_pad_w < stride[1]
), "output_padding[1] shoule be less than stride[1]"
Hout = (
(inp.shape[2] - 1) * stride[0]
- 2 * padding[0]
+ dilation[0] * (weight.shape[2] - 1)
+ output_pad_h
+ 1
)
Wout = (
(inp.shape[3] - 1) * stride[1]
- 2 * padding[1]
+ dilation[1] * (weight.shape[3] - 1)
+ output_pad_w
+ 1
)
output_shape = [inp.shape[0], weight.shape[1], Hout, Wout]
output_shape = astensor1d(output_shape)
(output,) = apply(op, weight, inp, output_shape)
else:
(output,) = apply(op, weight, inp) (output,) = apply(op, weight, inp)
if bias is not None: if bias is not None:
if amp._enabled: if amp._enabled:
...@@ -528,6 +556,7 @@ def conv_transpose3d( ...@@ -528,6 +556,7 @@ def conv_transpose3d(
bias: Optional[Tensor] = None, bias: Optional[Tensor] = None,
stride: Union[int, Tuple[int, int, int]] = 1, stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0, padding: Union[int, Tuple[int, int, int]] = 0,
output_padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1, dilation: Union[int, Tuple[int, int, int]] = 1,
groups: int = 1, groups: int = 1,
) -> Tensor: ) -> Tensor:
...@@ -544,6 +573,7 @@ def conv_transpose3d( ...@@ -544,6 +573,7 @@ def conv_transpose3d(
stride: stride of the 3D convolution operation. Default: 1 stride: stride of the 3D convolution operation. Default: 1
padding: size of the paddings added to the input on all sides of its padding: size of the paddings added to the input on all sides of its
spatial dimensions. Only zero-padding is supported. Default: 0 spatial dimensions. Only zero-padding is supported. Default: 0
output_padding: size of paddings appended to output. Default: 0
dilation: dilation of the 3D convolution operation. Default: 1 dilation: dilation of the 3D convolution operation. Default: 1
groups: number of groups into which the input and output channels are divided, groups: number of groups into which the input and output channels are divided,
so as to perform a ``grouped convolution``. When ``groups`` is not 1, so as to perform a ``grouped convolution``. When ``groups`` is not 1,
...@@ -558,6 +588,7 @@ def conv_transpose3d( ...@@ -558,6 +588,7 @@ def conv_transpose3d(
pad = expand_dhw(padding) pad = expand_dhw(padding)
stride = expand_dhw(stride) stride = expand_dhw(stride)
dilate = expand_dhw(dilation) dilate = expand_dhw(dilation)
output_padding = expand_dhw(output_padding)
sparse_type = "dense" if groups == 1 else "group" sparse_type = "dense" if groups == 1 else "group"
op = builtin.Convolution3DBackwardData( op = builtin.Convolution3DBackwardData(
...@@ -573,6 +604,41 @@ def conv_transpose3d( ...@@ -573,6 +604,41 @@ def conv_transpose3d(
strategy=get_execution_strategy(), strategy=get_execution_strategy(),
sparse=sparse_type, sparse=sparse_type,
) )
if output_padding[0] != 0 or output_padding[1] != 0 or output_padding[2] != 0:
assert (
output_padding[0] < stride[0]
), "output_padding[0] shoule be less than stride[0]"
assert (
output_padding[1] < stride[1]
), "output_padding[1] shoule be less than stride[1]"
assert (
output_padding[2] < stride[2]
), "output_padding[2] shoule be less than stride[2]"
Dout = (
(inp.shape[2] - 1) * stride[0]
- 2 * padding[0]
+ dilation[0] * (weight.shape[2] - 1)
+ output_padding[0]
+ 1
)
Hout = (
(inp.shape[3] - 1) * stride[1]
- 2 * padding[1]
+ dilation[1] * (weight.shape[3] - 1)
+ output_padding[1]
+ 1
)
Wout = (
(inp.shape[4] - 1) * stride[2]
- 2 * padding[2]
+ dilation[2] * (weight.shape[4] - 1)
+ output_padding[2]
+ 1
)
output_shape = [inp.shape[0], weight.shape[1], Dout, Hout, Wout]
output_shape = astensor1d(output_shape)
(output,) = apply(op, weight, inp, output_shape)
else:
(output,) = apply(op, weight, inp) (output,) = apply(op, weight, inp)
if bias is not None: if bias is not None:
output += bias output += bias
......
...@@ -134,6 +134,7 @@ def conv_transpose2d( ...@@ -134,6 +134,7 @@ def conv_transpose2d(
dtype=None, dtype=None,
stride: Union[int, Tuple[int, int]] = 1, stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0, padding: Union[int, Tuple[int, int]] = 0,
output_padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1, dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1, groups: int = 1,
conv_mode="cross_correlation", conv_mode="cross_correlation",
...@@ -156,6 +157,7 @@ def conv_transpose2d( ...@@ -156,6 +157,7 @@ def conv_transpose2d(
) )
pad_h, pad_w = _pair(padding) pad_h, pad_w = _pair(padding)
output_pad_h, output_pad_w = _pair(output_padding)
stride_h, stride_w = _pair_nonzero(stride) stride_h, stride_w = _pair_nonzero(stride)
dilate_h, dilate_w = _pair_nonzero(dilation) dilate_h, dilate_w = _pair_nonzero(dilation)
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
...@@ -173,5 +175,30 @@ def conv_transpose2d( ...@@ -173,5 +175,30 @@ def conv_transpose2d(
compute_mode=compute_mode, compute_mode=compute_mode,
mode=conv_mode, mode=conv_mode,
) )
if output_pad_h != 0 or output_pad_h != 0:
assert (
output_pad_h < stride[0]
), "output_padding[0] shoule be less than stride[0]"
assert (
output_pad_w < stride[1]
), "output_padding[1] shoule be less than stride[1]"
Hout = (
(inp.shape[2] - 1) * stride[0]
- 2 * padding[0]
+ dilation[0] * (weight.shape[2] - 1)
+ output_pad_h
+ 1
)
Wout = (
(inp.shape[3] - 1) * stride[1]
- 2 * padding[1]
+ dilation[1] * (weight.shape[3] - 1)
+ output_pad_w
+ 1
)
output_shape = [inp.shape[0], weight.shape[1], Hout, Wout]
output_shape = Tensor(output_shape)
(output,) = apply(op, weight, inp, output_shape)
else:
(output,) = apply(op, weight, inp) (output,) = apply(op, weight, inp)
return output return output
...@@ -30,6 +30,7 @@ class _ConvNd(Module): ...@@ -30,6 +30,7 @@ class _ConvNd(Module):
kernel_size: Union[int, Tuple[int, int]], kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]],
padding: Union[int, Tuple[int, int]], padding: Union[int, Tuple[int, int]],
output_padding: Union[int, Tuple[int, int]],
dilation: Union[int, Tuple[int, int]], dilation: Union[int, Tuple[int, int]],
groups: int, groups: int,
bias: bool = True, bias: bool = True,
...@@ -45,6 +46,7 @@ class _ConvNd(Module): ...@@ -45,6 +46,7 @@ class _ConvNd(Module):
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.stride = stride self.stride = stride
self.padding = padding self.padding = padding
self.output_padding = output_padding
self.dilation = dilation self.dilation = dilation
self.groups = groups self.groups = groups
...@@ -178,6 +180,7 @@ class Conv1d(_ConvNd): ...@@ -178,6 +180,7 @@ class Conv1d(_ConvNd):
kernel_size, kernel_size,
stride, stride,
padding, padding,
0,
dilation, dilation,
groups, groups,
bias, bias,
...@@ -352,6 +355,7 @@ class Conv2d(_ConvNd): ...@@ -352,6 +355,7 @@ class Conv2d(_ConvNd):
kernel_size, kernel_size,
stride, stride,
padding, padding,
0,
dilation, dilation,
groups, groups,
bias, bias,
...@@ -505,6 +509,7 @@ class Conv3d(_ConvNd): ...@@ -505,6 +509,7 @@ class Conv3d(_ConvNd):
kernel_size, kernel_size,
stride, stride,
padding, padding,
0,
dilation, dilation,
groups, groups,
bias, bias,
...@@ -572,6 +577,7 @@ class ConvTranspose2d(_ConvNd): ...@@ -572,6 +577,7 @@ class ConvTranspose2d(_ConvNd):
stride: stride of the 2D convolution operation. Default: 1 stride: stride of the 2D convolution operation. Default: 1
padding: size of the paddings added to the input on both sides of its padding: size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0 spatial dimensions. Only zero-padding is supported. Default: 0
output_padding: size of paddings appended to output. Default: 0
dilation: dilation of the 2D convolution operation. Default: 1 dilation: dilation of the 2D convolution operation. Default: 1
groups: number of groups into which the input and output channels are divided, groups: number of groups into which the input and output channels are divided,
so as to perform a ``grouped convolution``. When ``groups`` is not 1, so as to perform a ``grouped convolution``. When ``groups`` is not 1,
...@@ -591,6 +597,8 @@ class ConvTranspose2d(_ConvNd): ...@@ -591,6 +597,8 @@ class ConvTranspose2d(_ConvNd):
* ``bias`` usually has shape ``(1, out_channels, *1)`` * ``bias`` usually has shape ``(1, out_channels, *1)``
""" """
output_padding = 0
def __init__( def __init__(
self, self,
in_channels: int, in_channels: int,
...@@ -598,6 +606,7 @@ class ConvTranspose2d(_ConvNd): ...@@ -598,6 +606,7 @@ class ConvTranspose2d(_ConvNd):
kernel_size: Union[int, Tuple[int, int]], kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1, stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0, padding: Union[int, Tuple[int, int]] = 0,
output_padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1, dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1, groups: int = 1,
bias: bool = True, bias: bool = True,
...@@ -608,6 +617,7 @@ class ConvTranspose2d(_ConvNd): ...@@ -608,6 +617,7 @@ class ConvTranspose2d(_ConvNd):
kernel_size = _pair_nonzero(kernel_size) kernel_size = _pair_nonzero(kernel_size)
stride = _pair_nonzero(stride) stride = _pair_nonzero(stride)
padding = _pair(padding) padding = _pair(padding)
output_padding = _pair(output_padding)
dilation = _pair_nonzero(dilation) dilation = _pair_nonzero(dilation)
self.conv_mode = conv_mode self.conv_mode = conv_mode
self.compute_mode = compute_mode self.compute_mode = compute_mode
...@@ -617,6 +627,7 @@ class ConvTranspose2d(_ConvNd): ...@@ -617,6 +627,7 @@ class ConvTranspose2d(_ConvNd):
kernel_size, kernel_size,
stride, stride,
padding, padding,
output_padding,
dilation, dilation,
groups, groups,
bias, bias,
...@@ -656,6 +667,7 @@ class ConvTranspose2d(_ConvNd): ...@@ -656,6 +667,7 @@ class ConvTranspose2d(_ConvNd):
bias, bias,
self.stride, self.stride,
self.padding, self.padding,
self.output_padding,
self.dilation, self.dilation,
self.groups, self.groups,
self.conv_mode, self.conv_mode,
...@@ -817,6 +829,7 @@ class DeformableConv2d(_ConvNd): ...@@ -817,6 +829,7 @@ class DeformableConv2d(_ConvNd):
kernel_size, kernel_size,
stride, stride,
padding, padding,
0,
dilation, dilation,
groups, groups,
bias, bias,
...@@ -889,6 +902,7 @@ class ConvTranspose3d(_ConvNd): ...@@ -889,6 +902,7 @@ class ConvTranspose3d(_ConvNd):
stride: stride of the 3D convolution operation. Default: 1 stride: stride of the 3D convolution operation. Default: 1
padding: size of the paddings added to the input on all sides of its padding: size of the paddings added to the input on all sides of its
spatial dimensions. Only zero-padding is supported. Default: 0 spatial dimensions. Only zero-padding is supported. Default: 0
output_padding: size of paddings appended to output. Default: 0
dilation: dilation of the 3D convolution operation. Default: 1 dilation: dilation of the 3D convolution operation. Default: 1
groups: number of groups into which the input and output channels are divided, groups: number of groups into which the input and output channels are divided,
so as to perform a ``grouped convolution``. When ``groups`` is not 1, so as to perform a ``grouped convolution``. When ``groups`` is not 1,
...@@ -902,6 +916,8 @@ class ConvTranspose3d(_ConvNd): ...@@ -902,6 +916,8 @@ class ConvTranspose3d(_ConvNd):
* ``bias`` usually has shape ``(1, out_channels, *1)`` * ``bias`` usually has shape ``(1, out_channels, *1)``
""" """
output_padding = 0
def __init__( def __init__(
self, self,
in_channels: int, in_channels: int,
...@@ -909,6 +925,7 @@ class ConvTranspose3d(_ConvNd): ...@@ -909,6 +925,7 @@ class ConvTranspose3d(_ConvNd):
kernel_size: Union[int, Tuple[int, int, int]], kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1, stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0, padding: Union[int, Tuple[int, int, int]] = 0,
output_padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1, dilation: Union[int, Tuple[int, int, int]] = 1,
groups: int = 1, groups: int = 1,
bias: bool = True, bias: bool = True,
...@@ -923,6 +940,7 @@ class ConvTranspose3d(_ConvNd): ...@@ -923,6 +940,7 @@ class ConvTranspose3d(_ConvNd):
kernel_size=kernel_size, kernel_size=kernel_size,
stride=stride, stride=stride,
padding=padding, padding=padding,
output_padding=output_padding,
dilation=dilation, dilation=dilation,
groups=groups, groups=groups,
bias=bias, bias=bias,
...@@ -956,5 +974,11 @@ class ConvTranspose3d(_ConvNd): ...@@ -956,5 +974,11 @@ class ConvTranspose3d(_ConvNd):
def forward(self, inp): def forward(self, inp):
return conv_transpose3d( return conv_transpose3d(
inp, self.weight, self.bias, self.stride, self.padding, self.dilation, inp,
self.weight,
self.bias,
self.stride,
self.padding,
self.output_padding,
self.dilation,
) )
...@@ -74,6 +74,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QATModule): ...@@ -74,6 +74,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QATModule):
float_module.kernel_size, float_module.kernel_size,
float_module.stride, float_module.stride,
float_module.padding, float_module.padding,
float_module.output_padding,
float_module.dilation, float_module.dilation,
float_module.groups, float_module.groups,
float_module.bias is not None, float_module.bias is not None,
......
...@@ -138,6 +138,8 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): ...@@ -138,6 +138,8 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule):
dtype: data type of the output, should be qint8. dtype: data type of the output, should be qint8.
""" """
output_padding = 0
def __init__( def __init__(
self, self,
in_channels: int, in_channels: int,
...@@ -145,6 +147,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): ...@@ -145,6 +147,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule):
kernel_size: Union[int, Tuple[int, int]], kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1, stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0, padding: Union[int, Tuple[int, int]] = 0,
output_padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1, dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1, groups: int = 1,
bias: bool = True, bias: bool = True,
...@@ -159,6 +162,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): ...@@ -159,6 +162,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule):
kernel_size=kernel_size, kernel_size=kernel_size,
stride=stride, stride=stride,
padding=padding, padding=padding,
output_padding=output_padding,
dilation=dilation, dilation=dilation,
groups=groups, groups=groups,
bias=bias, bias=bias,
...@@ -180,6 +184,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): ...@@ -180,6 +184,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule):
qat_module.kernel_size, qat_module.kernel_size,
qat_module.stride, qat_module.stride,
qat_module.padding, qat_module.padding,
qat_module.output_padding,
qat_module.dilation, qat_module.dilation,
qat_module.groups, qat_module.groups,
qat_module.bias is not None, qat_module.bias is not None,
...@@ -212,6 +217,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): ...@@ -212,6 +217,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule):
dtype=self.output_dtype, dtype=self.output_dtype,
stride=self.stride, stride=self.stride,
padding=self.padding, padding=self.padding,
output_padding=self.output_padding,
dilation=self.dilation, dilation=self.dilation,
groups=self.groups, groups=self.groups,
conv_mode=self.conv_mode, conv_mode=self.conv_mode,
......
...@@ -18,7 +18,8 @@ from megengine.core._trace_option import use_symbolic_shape ...@@ -18,7 +18,8 @@ from megengine.core._trace_option import use_symbolic_shape
from megengine.core.autodiff.grad import Grad from megengine.core.autodiff.grad import Grad
from megengine.core.tensor.utils import make_shape_tuple from megengine.core.tensor.utils import make_shape_tuple
from megengine.device import get_device_count from megengine.device import get_device_count
from megengine.module import LayerNorm from megengine.jit.tracing import trace
from megengine.module import ConvTranspose2d, ConvTranspose3d, LayerNorm
_assert_allclose = partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6) _assert_allclose = partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)
...@@ -1374,3 +1375,37 @@ def test_local_conv2d(stride, padding, dilation, ksize, groups): ...@@ -1374,3 +1375,37 @@ def test_local_conv2d(stride, padding, dilation, ksize, groups):
) )
ref = local_conv2d_np(data, weight, stride, padding, dilation) ref = local_conv2d_np(data, weight, stride, padding, dilation)
np.testing.assert_almost_equal(output.numpy(), ref, 5) np.testing.assert_almost_equal(output.numpy(), ref, 5)
def test_conv_transpose2d():
m = ConvTranspose2d(
16, 33, (3, 5), output_padding=(1, 2), stride=(2, 3), padding=(4, 2)
)
@trace(symbolic=True)
def fwd(inp: Tensor):
return m(inp)
input = Tensor(np.random.rand(20, 16, 50, 100))
output = fwd(input)
output_shape = Tensor(output.shape)
np.testing.assert_equal(
output_shape.numpy(), np.array([20, 33, 94, 300], dtype=np.int32)
)
def test_conv_transpose3d():
m = ConvTranspose3d(
16, 33, (3, 5, 2), output_padding=(2, 1, 1), stride=(3, 2, 2), padding=(0, 4, 2)
)
@trace(symbolic=True)
def fwd(inp: Tensor):
return m(inp)
input = Tensor(np.random.rand(20, 16, 10, 50, 100))
output = fwd(input)
output_shape = Tensor(output.shape)
np.testing.assert_equal(
output_shape.numpy(), np.array([20, 33, 32, 96, 197], dtype=np.int32)
)
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "../op_trait.h" #include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" #include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/tensor_gen.h"
namespace mgb { namespace mgb {
namespace imperative { namespace imperative {
...@@ -152,8 +153,11 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { ...@@ -152,8 +153,11 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
inputs[0], inputs[1], conv.param(), conv.policy(), config); inputs[0], inputs[1], conv.param(), conv.policy(), config);
} else { } else {
mgb_assert(inputs.size() == 3); mgb_assert(inputs.size() == 3);
auto* src_for_shape =
opr::Alloc::make(inputs[2], inputs[0]->dtype(), {}).node();
return opr::ConvolutionBackwardData::make( return opr::ConvolutionBackwardData::make(
inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config); inputs[0], inputs[1], src_for_shape, conv.param(), conv.policy(),
config);
} }
} }
...@@ -168,6 +172,14 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( ...@@ -168,6 +172,14 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
if (filter.ndim && diff.ndim) { if (filter.ndim && diff.ndim) {
// deduce_layout won't override existing dtype // deduce_layout won't override existing dtype
dnn_opr.opr().deduce_layout(filter, diff, output_layout); dnn_opr.opr().deduce_layout(filter, diff, output_layout);
if (inputs.size() == 3) {
if (!inputs[2].value.empty()) {
cg::copy_tensor_value_to_shape(output_layout, inputs[2].value);
output_layout.init_contiguous_stride();
} else {
output_layout.ndim = 0;
}
}
} }
return {{{output_layout, inputs[0].comp_node}}, output_layout.ndim != 0}; return {{{output_layout, inputs[0].comp_node}}, output_layout.ndim != 0};
} }
...@@ -185,8 +197,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor( ...@@ -185,8 +197,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
return output_descs[0].layout; return output_descs[0].layout;
} else { } else {
TensorLayout out_layout{inputs[0]->dtype()}; TensorLayout out_layout{inputs[0]->dtype()};
dnn_opr.op()->deduce_layout( if (inputs.size() == 3) {
inputs[0]->layout(), inputs[1]->layout(), out_layout); cg::copy_tensor_value_to_shape(
out_layout, inputs[2]->get_value().proxy_to_default_cpu());
out_layout.init_contiguous_stride();
}
return out_layout; return out_layout;
} }
}(); }();
...@@ -263,50 +278,74 @@ namespace convolution3d_backward_data { ...@@ -263,50 +278,74 @@ namespace convolution3d_backward_data {
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
mgb_assert( mgb_assert(
inputs.size() == 2, inputs.size() == 2 || inputs.size() == 3,
"inputs num of conv_transpose3d should be 2 but you give %zu", "inputs num of conv_transpose3d should be 2 or 3 but you give %zu",
inputs.size()); inputs.size());
auto&& op_def = def.cast_final_safe<Convolution3DBackwardData>(); auto&& conv3dbwd = def.cast_final_safe<Convolution3DBackwardData>();
auto&& weight = inputs[0]; DnnOprHelper<megdnn::Convolution3DBackwardData> dnn_opr(conv3dbwd.param());
auto&& filter = inputs[0];
auto&& diff = inputs[1]; auto&& diff = inputs[1];
if (!(weight.layout.ndim && diff.layout.ndim)) {
return {{{TensorLayout{weight.layout.dtype}, weight.comp_node}}, false}; if (!(filter.layout.ndim && diff.layout.ndim)) {
return {{{TensorLayout{filter.layout.dtype}, filter.comp_node}}, false};
} }
DnnOprHelper<megdnn::Convolution3DBackwardData> dnn_opr(op_def.param());
auto oup_layout = dnn_opr.deduce_layout(weight.layout, diff.layout); TensorLayout output_layout = dnn_opr.deduce_layout(filter.layout, diff.layout);
return {{{oup_layout, weight.comp_node}}, true}; if (filter.layout.ndim && diff.layout.ndim) {
if (inputs.size() == 3) {
if (!inputs[2].value.empty()) {
cg::copy_tensor_value_to_shape(output_layout, inputs[2].value);
output_layout.init_contiguous_stride();
} else {
output_layout.ndim = 0;
}
}
}
return {{{output_layout, inputs[0].comp_node}}, output_layout.ndim != 0};
} }
SmallVector<TensorPtr> apply_on_physical_tensor( SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs, const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& conv = def.cast_final_safe<Convolution3DBackwardData>(); auto&& conv3dbwd = def.cast_final_safe<Convolution3DBackwardData>();
auto cn = inputs[0]->comp_node(); CompNode cn = inputs[0]->comp_node();
DnnOprCaller<megdnn::Convolution3DBackwardData> dnn_opr(
auto&& wlayout = inputs[0]->layout(); cn, conv3dbwd.param(), conv3dbwd.policy());
auto&& dlayout = inputs[1]->layout(); auto out_layout = [&] {
DnnOprCaller<megdnn::Convolution3DBackwardData> dnn_op(
cn, conv.param(), conv.policy());
auto oup_layout = [&] {
if (validated) { if (validated) {
return output_descs[0].layout; return output_descs[0].layout;
} else { } else {
return dnn_op.deduce_layout(wlayout, dlayout); TensorLayout out_layout{inputs[0]->dtype()};
dnn_opr.op()->deduce_layout(
inputs[0]->layout(), inputs[1]->layout(), out_layout);
if (inputs.size() == 3) {
cg::copy_tensor_value_to_shape(
out_layout, inputs[2]->get_value().proxy_to_default_cpu());
out_layout.init_contiguous_stride();
}
return out_layout;
} }
}(); }();
auto oup = Tensor::make(oup_layout, cn); auto out = Tensor::make(out_layout, cn);
dnn_op.exec_fastrun(inputs[0], inputs[1], oup); dnn_opr.exec_fastrun(inputs[0], inputs[1], out);
return {oup}; return {out};
} }
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& conv = static_cast<const Convolution3DBackwardData&>(def); auto&& conv = static_cast<const Convolution3DBackwardData&>(def);
OperatorNodeConfig config{conv.make_name()}; OperatorNodeConfig config{conv.make_name()};
mgb_assert(inputs.size() == 2); if (inputs.size() == 2) {
return opr::Convolution3DBackwardData::make( return opr::Convolution3DBackwardData::make(
inputs[0], inputs[1], conv.param(), conv.policy(), config); inputs[0], inputs[1], conv.param(), conv.policy(), config);
} else {
mgb_assert(inputs.size() == 3);
// The output shape is calculated in advance and given as input
auto* src_for_shape =
opr::Alloc::make(inputs[2], inputs[0]->dtype(), {}).node();
return opr::Convolution3DBackwardData::make(
inputs[0], inputs[1], src_for_shape, conv.param(), conv.policy(),
config);
}
} }
OP_TRAIT_REG(Convolution3DBackwardData, Convolution3DBackwardData) OP_TRAIT_REG(Convolution3DBackwardData, Convolution3DBackwardData)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册