From 2293385e93b9aa8ff2045f108436eee53cfb10c6 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 26 Jan 2022 13:27:42 +0800
Subject: [PATCH] feat(mge): add conv padding mode

GitOrigin-RevId: 147ced856e196437cfc5b371c91094ebcbba6ea8
---
 dnn/src/cuda/padding/opr_impl.cpp             |  1 +
 dnn/src/cuda/padding/padding.cu               |  4 +-
 dnn/src/naive/padding/opr_impl.cpp            |  9 ++-
 dnn/test/cuda/padding.cpp                     | 30 ++++++++++
 dnn/test/naive/padding.cpp                    | 30 ++++++++++
 imperative/python/megengine/module/conv.py    | 58 ++++++++++++++++++-
 imperative/python/megengine/module/conv_bn.py |  2 +
 .../python/megengine/module/qat/conv.py       |  1 +
 .../python/megengine/module/qat/conv_bn.py    |  1 +
 .../python/megengine/module/quantized/conv.py | 24 +++++++-
 .../megengine/module/quantized/conv_bn.py     |  1 +
 .../python/megengine/traced_module/compat.py  | 36 ++++++++++++
 .../python/test/unit/module/test_qat.py       | 21 ++++++-
 .../test/unit/quantization/test_module.py     | 11 +++-
 14 files changed, 217 insertions(+), 12 deletions(-)

diff --git a/dnn/src/cuda/padding/opr_impl.cpp b/dnn/src/cuda/padding/opr_impl.cpp
index 5da1c744f..3102ca432 100644
--- a/dnn/src/cuda/padding/opr_impl.cpp
+++ b/dnn/src/cuda/padding/opr_impl.cpp
@@ -35,6 +35,7 @@ void PaddingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
                 param().padding_val, stream);                                 \
     }
     MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+    MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
 #undef cb
 }
 
diff --git a/dnn/src/cuda/padding/padding.cu b/dnn/src/cuda/padding/padding.cu
index 4bd91b789..5b4678cdf 100644
--- a/dnn/src/cuda/padding/padding.cu
+++ b/dnn/src/cuda/padding/padding.cu
@@ -60,7 +60,8 @@ __global__ void paddingConst_kernel(
            params.src_stride[dim].divisor(); */
     }
-    dst[out_index] = in_src_valid_area ? src[in_index] : padding_val;
+    dst[out_index] =
+            in_src_valid_area ? src[in_index] : static_cast<T>(padding_val);
 }
 }
 
@@ -256,6 +257,7 @@ void padding_backward_proxy(
         const float_t padding_val, cudaStream_t stream);
 #define cb(DType) INST(typename DTypeTrait<DType>::ctype)
 MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
 #undef cb
 #undef INST
 
diff --git a/dnn/src/naive/padding/opr_impl.cpp b/dnn/src/naive/padding/opr_impl.cpp
index 1de574431..0d45ddc7e 100644
--- a/dnn/src/naive/padding/opr_impl.cpp
+++ b/dnn/src/naive/padding/opr_impl.cpp
@@ -171,7 +171,7 @@ void PaddingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
     switch (param().padding_mode) {
         case param::Padding::PaddingMode::CONSTANT:
 #define cb(DType)                                                            \
-    if (src.layout.dtype == DType()) {                                       \
+    if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) {              \
         using T = typename DTypeTrait<DType>::ctype;                         \
         MEGDNN_DISPATCH_CPU_KERN_OPR(exec_const_internal<T>(                 \
                 src.layout.ndim, n, src.ptr<T>(), dst.ptr<T>(), params,      \
@@ -179,28 +179,31 @@ void PaddingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
         return;                                                              \
     }
             MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+            MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
 #undef cb
             break;
         case param::Padding::PaddingMode::REPLICATE:
 #define cb(DType)                                                            \
-    if (src.layout.dtype == DType()) {                                       \
+    if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) {              \
         using T = typename DTypeTrait<DType>::ctype;                         \
         MEGDNN_DISPATCH_CPU_KERN_OPR(exec_replicate_internal<T>(             \
                 src.layout.ndim, n, src.ptr<T>(), dst.ptr<T>(), params));    \
         return;                                                              \
     }
             MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+            MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
 #undef cb
             break;
         case param::Padding::PaddingMode::REFLECT:
 #define cb(DType)                                                            \
-    if (src.layout.dtype == DType()) {                                       \
+    if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) {              \
         using T = typename DTypeTrait<DType>::ctype;                         \
         MEGDNN_DISPATCH_CPU_KERN_OPR(exec_reflect_internal<T>(               \
                 src.layout.ndim, n, src.ptr<T>(), dst.ptr<T>(), params));    \
         return;                                                              \
     }
             MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+            MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
 #undef cb
             break;
         default:
 
diff --git a/dnn/test/cuda/padding.cpp b/dnn/test/cuda/padding.cpp
index 2a1808bde..1ae5ea223 100644
--- a/dnn/test/cuda/padding.cpp
+++ b/dnn/test/cuda/padding.cpp
@@ -101,6 +101,36 @@ TEST_F(CUDA, PADDING_REFLECT2) {
                              4, 1, 6, 3, 6, 1, 6, 3})});
 }
 
+TEST_F(CUDA, PADDING_REFLECT2_QUANTIZED) {
+    Checker<Padding> checker(handle_cuda(), false);
+    param::Padding param;
+    param.padding_mode = param::Padding::PaddingMode::REFLECT;
+    param.front_offset_dim0 = 2;
+    param.front_offset_dim1 = 1;
+    param.front_offset_dim2 = 0;
+    param.front_offset_dim3 = 0;
+    param.front_offset_dim4 = 0;
+    param.front_offset_dim5 = 0;
+    param.front_offset_dim6 = 0;
+    param.back_offset_dim0 = 0;
+    param.back_offset_dim1 = 2;
+    param.back_offset_dim2 = 0;
+    param.back_offset_dim3 = 0;
+    param.back_offset_dim4 = 0;
+    param.back_offset_dim5 = 0;
+    param.back_offset_dim6 = 0;
+    checker.set_param(param).exect(
+            Testcase{
+                    TensorValue(
+                            {3, 3}, dtype::QuantizedS8(), {1, 2, 3, 4, 5, 6, 7, 8, 9}),
+                    {}},
+            Testcase{{}, TensorValue({5, 6}, dtype::QuantizedS8(), {8, 7, 8, 9, 8, 7, 5,
+                                                                    4, 5, 6, 5, 4, 2, 1,
+                                                                    2, 3, 2, 1, 5, 4, 5,
+                                                                    6, 5, 4, 8, 7, 8, 9,
+                                                                    8, 7})});
+}
+
 TEST_F(CUDA, PADDING_REPLICATE) {
     Checker<Padding> checker(handle_cuda(), false);
     param::Padding param;
 
diff --git a/dnn/test/naive/padding.cpp b/dnn/test/naive/padding.cpp
index e0c2380c2..161c1a652 100644
--- a/dnn/test/naive/padding.cpp
+++ b/dnn/test/naive/padding.cpp
@@ -83,6 +83,36 @@ TEST_F(NAIVE, PADDING_REFLECT) {
                     {10}, dtype::Float32(), {3, 2, 1, 2, 3, 4, 5, 4, 3, 2})});
 }
 
+TEST_F(NAIVE, PADDING_REFLECT2) {
+    Checker<Padding> checker(handle(), false);
+    param::Padding param;
+    param.padding_mode = param::Padding::PaddingMode::REFLECT;
+    param.front_offset_dim0 = 2;
+    param.front_offset_dim1 = 1;
+    param.front_offset_dim2 = 0;
+    param.front_offset_dim3 = 0;
+    param.front_offset_dim4 = 0;
+    param.front_offset_dim5 = 0;
+    param.front_offset_dim6 = 0;
+    param.back_offset_dim0 = 0;
+    param.back_offset_dim1 = 2;
+    param.back_offset_dim2 = 0;
+    param.back_offset_dim3 = 0;
+    param.back_offset_dim4 = 0;
+    param.back_offset_dim5 = 0;
+    param.back_offset_dim6 = 0;
+    checker.set_param(param).exect(
+            Testcase{
+                    TensorValue(
+                            {3, 3}, dtype::QuantizedS8(), {1, 2, 3, 4, 5, 6, 7, 8, 9}),
+                    {}},
+            Testcase{{}, TensorValue({5, 6}, dtype::QuantizedS8(), {8, 7, 8, 9, 8, 7, 5,
+                                                                    4, 5, 6, 5, 4, 2, 1,
+                                                                    2, 3, 2, 1, 5, 4, 5,
+                                                                    6, 5, 4, 8, 7, 8, 9,
+                                                                    8, 7})});
+}
+
 TEST_F(NAIVE, PADDING_REPLICATE) {
     Checker<Padding> checker(handle(), false);
     param::Padding param;
 
diff --git a/imperative/python/megengine/module/conv.py b/imperative/python/megengine/module/conv.py
index 5baa22ad0..e0ee3f7b3 100644
--- a/imperative/python/megengine/module/conv.py
+++ b/imperative/python/megengine/module/conv.py
@@ -18,6 +18,7 @@ from ..functional import (
     conv_transpose3d,
     deformable_conv2d,
     local_conv2d,
+    pad,
     relu,
 )
 from ..tensor import Parameter
@@ -126,7 +127,7 @@ class Conv1d(_ConvNd):
         kernel_size: size of weight on spatial dimensions.
         stride: stride of the 1D convolution operation.
         padding: size of the paddings added to the input on both sides of its
-            spatial dimensions. Only zero-padding is supported. Default: 0
+            spatial dimensions. Default: 0
         dilation: dilation of the 1D convolution operation. Default: 1
         groups: number of groups to divide input and output channels into,
             so as to perform a "grouped convolution". When ``groups`` is not 1,
@@ -139,6 +140,8 @@ class Conv1d(_ConvNd):
             placed on the precision of intermediate results. When set to "float32",
             "float32" would be used for accumulator and intermediate result, but only
             effective when input and output are of float16 dtype.
+        padding_mode: "zeros", "reflect" or "replicate". Default: "zeros".
+            Refer to :class:`~.module.padding.Pad` for more information.
 
     Note:
         * ``weight`` usually has shape ``(out_channels, in_channels, kernel_size)`` ,
@@ -177,6 +180,7 @@ class Conv1d(_ConvNd):
         bias: bool = True,
         conv_mode: str = "cross_correlation",
         compute_mode: str = "default",
+        padding_mode: str = "zeros",
         **kwargs
     ):
         kernel_size = kernel_size
@@ -185,6 +189,7 @@ class Conv1d(_ConvNd):
         dilation = dilation
         self.conv_mode = conv_mode
         self.compute_mode = compute_mode
+        self.padding_mode = padding_mode
         super().__init__(
             in_channels,
             out_channels,
@@ -223,7 +228,27 @@ class Conv1d(_ConvNd):
         # Assume format is NCH(W=1)
         return (1, self.out_channels, 1)
 
+    def get_pad_witdth(self):
+        return ((0, 0), (0, 0), (self.padding, self.padding))
+
     def calc_conv(self, inp, weight, bias):
+        assert self.padding_mode in [
+            "zeros",
+            "reflect",
+            "replicate",
+        ]
+        if self.padding_mode != "zeros":
+            return conv1d(
+                pad(inp, self.get_pad_witdth(), self.padding_mode),
+                weight,
+                bias,
+                self.stride,
+                0,
+                self.dilation,
+                self.groups,
+                self.conv_mode,
+                self.compute_mode,
+            )
         return conv1d(
             inp,
             weight,
@@ -287,7 +312,7 @@ class Conv2d(_ConvNd):
             ``(kernel_size, kernel_size)``.
         stride: stride of the 2D convolution operation. Default: 1
         padding: size of the paddings added to the input on both sides of its
-            spatial dimensions. Only zero-padding is supported. Default: 0
+            spatial dimensions. Default: 0
         dilation: dilation of the 2D convolution operation. Default: 1
         groups: number of groups into which the input and output channels are divided,
             so as to perform a ``grouped convolution``. When ``groups`` is not 1,
@@ -300,6 +325,8 @@ class Conv2d(_ConvNd):
             placed on the precision of intermediate results. When set to "float32",
             "float32" would be used for accumulator and intermediate result, but only
             effective when input and output are of float16 dtype.
+        padding_mode: "zeros", "reflect" or "replicate". Default: "zeros".
+            Refer to :class:`~.module.padding.Pad` for more information.
 
     Note:
         * ``weight`` usually has shape ``(out_channels, in_channels, height, width)`` ,
@@ -338,6 +365,7 @@ class Conv2d(_ConvNd):
         bias: bool = True,
         conv_mode: str = "cross_correlation",
         compute_mode: str = "default",
+        padding_mode: str = "zeros",
         **kwargs
     ):
         kernel_size = _pair_nonzero(kernel_size)
@@ -346,6 +374,7 @@ class Conv2d(_ConvNd):
         dilation = _pair_nonzero(dilation)
         self.conv_mode = conv_mode
         self.compute_mode = compute_mode
+        self.padding_mode = padding_mode
         super().__init__(
             in_channels,
             out_channels,
@@ -384,7 +413,32 @@ class Conv2d(_ConvNd):
         # Assume format is NCHW
         return (1, self.out_channels, 1, 1)
 
+    def get_pad_witdth(self):
+        return (
+            (0, 0),
+            (0, 0),
+            (self.padding[0], self.padding[0]),
+            (self.padding[1], self.padding[1]),
+        )
+
     def calc_conv(self, inp, weight, bias):
+        assert self.padding_mode in [
+            "zeros",
+            "reflect",
+            "replicate",
+        ]
+        if self.padding_mode != "zeros":
+            return conv2d(
+                pad(inp, self.get_pad_witdth(), self.padding_mode),
+                weight,
+                bias,
+                self.stride,
+                0,
+                self.dilation,
+                self.groups,
+                self.conv_mode,
+                self.compute_mode,
+            )
         return conv2d(
             inp,
             weight,
 
diff --git a/imperative/python/megengine/module/conv_bn.py b/imperative/python/megengine/module/conv_bn.py
index 5d87a688f..d69a955a0 100644
--- a/imperative/python/megengine/module/conv_bn.py
+++ b/imperative/python/megengine/module/conv_bn.py
@@ -30,6 +30,7 @@ class _ConvBnActivation2d(Module):
         momentum=0.9,
         affine=True,
         track_running_stats=True,
+        padding_mode: str = "zeros",
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -44,6 +45,7 @@ class _ConvBnActivation2d(Module):
             bias,
             conv_mode,
             compute_mode,
+            padding_mode,
             **kwargs,
         )
         self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
 
diff --git a/imperative/python/megengine/module/qat/conv.py b/imperative/python/megengine/module/qat/conv.py
index 04201f1fe..61a5ee83d 100644
--- a/imperative/python/megengine/module/qat/conv.py
+++ b/imperative/python/megengine/module/qat/conv.py
@@ -38,6 +38,7 @@ class Conv2d(Float.Conv2d, QATModule):
             float_module.bias is not None,
             float_module.conv_mode,
             float_module.compute_mode,
+            float_module.padding_mode,
             name=float_module.name,
         )
         qat_module.weight = float_module.weight
 
diff --git a/imperative/python/megengine/module/qat/conv_bn.py b/imperative/python/megengine/module/qat/conv_bn.py
index 2b891ed61..77af6394e 100644
--- a/imperative/python/megengine/module/qat/conv_bn.py
+++ b/imperative/python/megengine/module/qat/conv_bn.py
@@ -147,6 +147,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
             float_module.conv.bias is not None,
             float_module.conv.conv_mode,
             float_module.conv.compute_mode,
+            padding_mode=float_module.conv.padding_mode,
             name=float_module.name,
         )
         qat_module.conv.weight = float_module.conv.weight
 
diff --git a/imperative/python/megengine/module/quantized/conv.py b/imperative/python/megengine/module/quantized/conv.py
index b7bc08473..aff71bc3f 100644
--- a/imperative/python/megengine/module/quantized/conv.py
+++ b/imperative/python/megengine/module/quantized/conv.py
@@ -11,7 +11,7 @@ import numpy as np
 
 from ... import module as Float
 from ...core.tensor import dtype
-from ...functional.nn import conv_bias_activation
+from ...functional.nn import conv_bias_activation, pad
 from ...functional.quantized import conv_transpose2d
 from ...tensor import Parameter
 from ..qat import conv as QAT
@@ -38,6 +38,7 @@ class Conv2d(Float.Conv2d, QuantizedModule):
         conv_mode: str = "cross_correlation",
         compute_mode: str = "default",
         dtype=None,
+        padding_mode: str = "zeros",
         **kwargs
     ):
         super().__init__(
@@ -51,13 +52,33 @@ class Conv2d(Float.Conv2d, QuantizedModule):
             True,
             conv_mode,
             compute_mode,
+            padding_mode,
         )
         self.output_dtype = dtype
 
     def calc_conv_quantized(self, inp, nonlinear_mode="identity"):
+        assert self.padding_mode in [
+            "zeros",
+            "reflect",
+            "replicate",
+        ]
         inp_scale = dtype.get_scale(inp.dtype)
         w_scale = dtype.get_scale(self.weight.dtype)
         bias_scale = inp_scale * w_scale
+        if self.padding_mode != "zeros":
+            return conv_bias_activation(
+                pad(inp, self.get_pad_witdth(), self.padding_mode),
+                self.weight,
+                self.bias.astype(dtype.qint32(bias_scale)),
+                self.output_dtype,
+                self.stride,
+                0,
+                self.dilation,
+                self.groups,
+                conv_mode=self.conv_mode,
+                compute_mode=self.compute_mode,
+                nonlinear_mode=nonlinear_mode,
+            )
         return conv_bias_activation(
             inp,
             self.weight,
@@ -88,6 +109,7 @@ class Conv2d(Float.Conv2d, QuantizedModule):
             qat_module.dilation,
             qat_module.groups,
             dtype=output_dtype,
+            padding_mode=qat_module.padding_mode,
             name=qat_module.name,
         )
         weight = qat_module.weight.astype(qat_module.get_weight_dtype())
 
diff --git a/imperative/python/megengine/module/quantized/conv_bn.py b/imperative/python/megengine/module/quantized/conv_bn.py
index cef6b1375..4230ed49e 100644
--- a/imperative/python/megengine/module/quantized/conv_bn.py
+++ b/imperative/python/megengine/module/quantized/conv_bn.py
@@ -31,6 +31,7 @@ class _ConvBnActivation2d(Conv2d):
             qat_module.conv.groups,
             dtype=output_dtype,
             name=qat_module.name,
+            padding_mode=qat_module.conv.padding_mode,
         )
         w_fold, b_fold = qat_module.fold_weight_bias(
             qat_module.bn.running_mean, qat_module.bn.running_var
 
diff --git a/imperative/python/megengine/traced_module/compat.py b/imperative/python/megengine/traced_module/compat.py
index 9350b1bda..48978e8be 100644
--- a/imperative/python/megengine/traced_module/compat.py
+++ b/imperative/python/megengine/traced_module/compat.py
@@ -126,6 +126,9 @@ def convbn2d_module_loader(expr):
     module = expr.inputs[0].owner
     if not hasattr(module.bn, "param_dim"):
         module.bn.param_dim = "dim_1c11"
+    module = expr.inputs[0].owner
+    if not hasattr(module.conv, "padding_mode"):
+        module.conv.padding_mode = "zeros"
 
 
 @register_opdef_loader(BatchNorm)
@@ -162,3 +165,36 @@ def tensor_gen_func_loader(expr):
         else:
             device = None
         expr.set_args_kwargs(shape, dtype=dtype, device=device)
+
+
+@register_functional_loader(("megengine.functional.nn", "pad"))
+def pad_func_loader(expr):
+    if "pad_witdth" in expr.kwargs:
+        kwargs = expr.kwargs
+        kwargs["pad_width"] = kwargs.pop("pad_witdth")
+        expr.set_args_kwargs(*expr.args, **kwargs)
+
+
+@register_module_loader(
+    ("megengine.module.conv", "Conv1d"),
+    ("megengine.module.conv", "Conv2d"),
+    ("megengine.module.conv", "ConvRelu2d"),
+    ("megengine.module.qat.conv", "Conv2d"),
+    ("megengine.module.qat.conv", "ConvRelu2d"),
+    ("megengine.module.quantized.conv", "Conv2d"),
+    ("megengine.module.quantized.conv", "ConvRelu2d"),
+)
+def conv2d_module_loader(expr):
+    module = expr.inputs[0].owner
+    if not hasattr(module, "padding_mode"):
+        module.padding_mode = "zeros"
+
+
+@register_module_loader(
+    ("megengine.module.quantized.conv_bn", "ConvBn2d"),
+    ("megengine.module.quantized.conv_bn", "ConvBnRelu2d"),
+)
+def quantized_convbn2d_module_loader(expr):
+    module = expr.inputs[0].owner
+    if not hasattr(module, "padding_mode"):
+        module.padding_mode = "zeros"
 
diff --git a/imperative/python/test/unit/module/test_qat.py b/imperative/python/test/unit/module/test_qat.py
index 2cb3f5a87..1d81a17dd 100644
--- a/imperative/python/test/unit/module/test_qat.py
+++ b/imperative/python/test/unit/module/test_qat.py
@@ -60,7 +60,18 @@ def test_qat_convbn2d():
     )
 
 
-def test_qat_conv():
+@pytest.mark.parametrize(
+    "padding, padding_mode",
+    [
+        (0, "zeros"),
+        ((1, 2), "zeros"),
+        (3, "reflect"),
+        ((1, 2), "reflect"),
+        (4, "replicate"),
+        ((1, 2), "replicate"),
+    ],
+)
+def test_qat_conv(padding, padding_mode):
 
     in_channels = 32
     out_channels = 64
@@ -72,7 +83,13 @@ def test_qat_conv():
             self.quant = QuantStub()
             self.dequant = DequantStub()
             self.conv = Conv2d(
-                in_channels, out_channels, kernel_size, groups=groups, bias=bias
+                in_channels,
+                out_channels,
+                kernel_size,
+                groups=groups,
+                bias=bias,
+                padding=padding,
+                padding_mode=padding_mode,
             )
             self.conv_relu = ConvRelu2d(
                 out_channels, in_channels, kernel_size, groups=groups, bias=bias
 
diff --git a/imperative/python/test/unit/quantization/test_module.py b/imperative/python/test/unit/quantization/test_module.py
index c30d53f99..ec23ef6a9 100644
--- a/imperative/python/test/unit/quantization/test_module.py
+++ b/imperative/python/test/unit/quantization/test_module.py
@@ -236,11 +236,16 @@ def test_linear():
 
 
 @pytest.mark.parametrize("module", ["Conv2d", "ConvBn2d", "ConvBnRelu2d"])
-def test_conv(module):
-    normal_net = getattr(Float, module)(3, 3, 3, 1, 1, 1, bias=True)
+@pytest.mark.parametrize("padding_mode", ["zeros", "reflect", "replicate"])
+def test_conv(module, padding_mode):
+    normal_net = getattr(Float, module)(
+        3, 3, 3, 1, 1, 1, bias=True, padding_mode=padding_mode
+    )
     normal_net.eval()
 
-    qat_net = getattr(QAT, module)(3, 3, 3, 1, 1, 1, bias=True)
+    qat_net = getattr(QAT, module)(
+        3, 3, 3, 1, 1, 1, bias=True, padding_mode=padding_mode
+    )
     qat_net.eval()
 
     disable_observer(qat_net)
-- 
GitLab
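
A minimal usage sketch of the `padding_mode` argument this patch adds to the float `Conv2d` (channel counts, input shape and values below are arbitrary, chosen only for illustration):

import numpy as np

import megengine as mge
from megengine.module import Conv2d

# "reflect" and "replicate" take the new pad-then-conv path in calc_conv;
# the default "zeros" keeps the original zero-padding conv2d fast path.
conv = Conv2d(3, 8, kernel_size=3, padding=1, padding_mode="reflect")

x = mge.tensor(np.random.randn(2, 3, 32, 32).astype("float32"))
y = conv(x)
print(y.shape)  # (2, 8, 32, 32): padding=1 with a 3x3 kernel keeps H and W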
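
The non-zero modes are implemented as an explicit pad followed by a padding-0 convolution, so the module's output can be reproduced with the functional API directly. The check below is a sketch that simply mirrors the `calc_conv` body above (`get_pad_witdth` keeps the spelling used by the patch):

import numpy as np

import megengine as mge
import megengine.functional as F
from megengine.module import Conv2d

conv = Conv2d(3, 8, kernel_size=3, padding=1, padding_mode="replicate")
x = mge.tensor(np.random.randn(2, 3, 16, 16).astype("float32"))

# Same steps calc_conv performs for padding_mode != "zeros":
# replicate-pad H and W by self.padding, then convolve with padding 0.
y_manual = F.conv2d(
    F.pad(x, conv.get_pad_witdth(), conv.padding_mode),
    conv.weight,
    conv.bias,
    conv.stride,
    0,
    conv.dilation,
    conv.groups,
)

np.testing.assert_allclose(conv(x).numpy(), y_manual.numpy(), rtol=1e-6)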