提交 ab6d12ca 编写于 作者: M Megvii Engine Team

feat(mge): add conv padding mode

GitOrigin-RevId: 147ced856e196437cfc5b371c91094ebcbba6ea8
上级 177001d5
...@@ -35,6 +35,7 @@ void PaddingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { ...@@ -35,6 +35,7 @@ void PaddingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
param().padding_val, stream); \ param().padding_val, stream); \
} }
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb #undef cb
} }
......
...@@ -60,7 +60,8 @@ __global__ void paddingConst_kernel( ...@@ -60,7 +60,8 @@ __global__ void paddingConst_kernel(
params.src_stride[dim].divisor(); params.src_stride[dim].divisor();
*/ */
} }
dst[out_index] = in_src_valid_area ? src[in_index] : padding_val; dst[out_index] =
in_src_valid_area ? src[in_index] : static_cast<T>(padding_val);
} }
} }
...@@ -256,6 +257,7 @@ void padding_backward_proxy( ...@@ -256,6 +257,7 @@ void padding_backward_proxy(
const float_t padding_val, cudaStream_t stream); const float_t padding_val, cudaStream_t stream);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype) #define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb #undef cb
#undef INST #undef INST
......
...@@ -171,7 +171,7 @@ void PaddingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { ...@@ -171,7 +171,7 @@ void PaddingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
switch (param().padding_mode) { switch (param().padding_mode) {
case param::Padding::PaddingMode::CONSTANT: case param::Padding::PaddingMode::CONSTANT:
#define cb(DType) \ #define cb(DType) \
if (src.layout.dtype == DType()) { \ if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
using T = typename DTypeTrait<DType>::ctype; \ using T = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR(exec_const_internal<T>( \ MEGDNN_DISPATCH_CPU_KERN_OPR(exec_const_internal<T>( \
src.layout.ndim, n, src.ptr<T>(), dst.ptr<T>(), params, \ src.layout.ndim, n, src.ptr<T>(), dst.ptr<T>(), params, \
...@@ -179,28 +179,31 @@ void PaddingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { ...@@ -179,28 +179,31 @@ void PaddingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
return; \ return; \
} }
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb #undef cb
break; break;
case param::Padding::PaddingMode::REPLICATE: case param::Padding::PaddingMode::REPLICATE:
#define cb(DType) \ #define cb(DType) \
if (src.layout.dtype == DType()) { \ if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
using T = typename DTypeTrait<DType>::ctype; \ using T = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR(exec_replicate_internal<T>( \ MEGDNN_DISPATCH_CPU_KERN_OPR(exec_replicate_internal<T>( \
src.layout.ndim, n, src.ptr<T>(), dst.ptr<T>(), params)); \ src.layout.ndim, n, src.ptr<T>(), dst.ptr<T>(), params)); \
return; \ return; \
} }
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb #undef cb
break; break;
case param::Padding::PaddingMode::REFLECT: case param::Padding::PaddingMode::REFLECT:
#define cb(DType) \ #define cb(DType) \
if (src.layout.dtype == DType()) { \ if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
using T = typename DTypeTrait<DType>::ctype; \ using T = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR(exec_reflect_internal<T>( \ MEGDNN_DISPATCH_CPU_KERN_OPR(exec_reflect_internal<T>( \
src.layout.ndim, n, src.ptr<T>(), dst.ptr<T>(), params)); \ src.layout.ndim, n, src.ptr<T>(), dst.ptr<T>(), params)); \
return; \ return; \
} }
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb #undef cb
break; break;
default: default:
......
...@@ -101,6 +101,36 @@ TEST_F(CUDA, PADDING_REFLECT2) { ...@@ -101,6 +101,36 @@ TEST_F(CUDA, PADDING_REFLECT2) {
4, 1, 6, 3, 6, 1, 6, 3})}); 4, 1, 6, 3, 6, 1, 6, 3})});
} }
TEST_F(CUDA, PADDING_REFLECT2_QUANTIZED) {
Checker<Padding> checker(handle_cuda(), false);
param::Padding param;
param.padding_mode = param::Padding::PaddingMode::REFLECT;
param.front_offset_dim0 = 2;
param.front_offset_dim1 = 1;
param.front_offset_dim2 = 0;
param.front_offset_dim3 = 0;
param.front_offset_dim4 = 0;
param.front_offset_dim5 = 0;
param.front_offset_dim6 = 0;
param.back_offset_dim0 = 0;
param.back_offset_dim1 = 2;
param.back_offset_dim2 = 0;
param.back_offset_dim3 = 0;
param.back_offset_dim4 = 0;
param.back_offset_dim5 = 0;
param.back_offset_dim6 = 0;
checker.set_param(param).exect(
Testcase{
TensorValue(
{3, 3}, dtype::QuantizedS8(), {1, 2, 3, 4, 5, 6, 7, 8, 9}),
{}},
Testcase{{}, TensorValue({5, 6}, dtype::QuantizedS8(), {8, 7, 8, 9, 8, 7, 5,
4, 5, 6, 5, 4, 2, 1,
2, 3, 2, 1, 5, 4, 5,
6, 5, 4, 8, 7, 8, 9,
8, 7})});
}
TEST_F(CUDA, PADDING_REPLICATE) { TEST_F(CUDA, PADDING_REPLICATE) {
Checker<Padding> checker(handle_cuda(), false); Checker<Padding> checker(handle_cuda(), false);
param::Padding param; param::Padding param;
......
...@@ -83,6 +83,36 @@ TEST_F(NAIVE, PADDING_REFLECT) { ...@@ -83,6 +83,36 @@ TEST_F(NAIVE, PADDING_REFLECT) {
{10}, dtype::Float32(), {3, 2, 1, 2, 3, 4, 5, 4, 3, 2})}); {10}, dtype::Float32(), {3, 2, 1, 2, 3, 4, 5, 4, 3, 2})});
} }
TEST_F(NAIVE, PADDING_REFLECT2) {
Checker<Padding> checker(handle(), false);
param::Padding param;
param.padding_mode = param::Padding::PaddingMode::REFLECT;
param.front_offset_dim0 = 2;
param.front_offset_dim1 = 1;
param.front_offset_dim2 = 0;
param.front_offset_dim3 = 0;
param.front_offset_dim4 = 0;
param.front_offset_dim5 = 0;
param.front_offset_dim6 = 0;
param.back_offset_dim0 = 0;
param.back_offset_dim1 = 2;
param.back_offset_dim2 = 0;
param.back_offset_dim3 = 0;
param.back_offset_dim4 = 0;
param.back_offset_dim5 = 0;
param.back_offset_dim6 = 0;
checker.set_param(param).exect(
Testcase{
TensorValue(
{3, 3}, dtype::QuantizedS8(), {1, 2, 3, 4, 5, 6, 7, 8, 9}),
{}},
Testcase{{}, TensorValue({5, 6}, dtype::QuantizedS8(), {8, 7, 8, 9, 8, 7, 5,
4, 5, 6, 5, 4, 2, 1,
2, 3, 2, 1, 5, 4, 5,
6, 5, 4, 8, 7, 8, 9,
8, 7})});
}
TEST_F(NAIVE, PADDING_REPLICATE) { TEST_F(NAIVE, PADDING_REPLICATE) {
Checker<Padding> checker(handle(), false); Checker<Padding> checker(handle(), false);
param::Padding param; param::Padding param;
......
...@@ -18,6 +18,7 @@ from ..functional import ( ...@@ -18,6 +18,7 @@ from ..functional import (
conv_transpose3d, conv_transpose3d,
deformable_conv2d, deformable_conv2d,
local_conv2d, local_conv2d,
pad,
relu, relu,
) )
from ..tensor import Parameter from ..tensor import Parameter
...@@ -126,7 +127,7 @@ class Conv1d(_ConvNd): ...@@ -126,7 +127,7 @@ class Conv1d(_ConvNd):
kernel_size: size of weight on spatial dimensions. kernel_size: size of weight on spatial dimensions.
stride: stride of the 1D convolution operation. stride: stride of the 1D convolution operation.
padding: size of the paddings added to the input on both sides of its padding: size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0 spatial dimensions. Default: 0
dilation: dilation of the 1D convolution operation. Default: 1 dilation: dilation of the 1D convolution operation. Default: 1
groups: number of groups to divide input and output channels into, groups: number of groups to divide input and output channels into,
so as to perform a "grouped convolution". When ``groups`` is not 1, so as to perform a "grouped convolution". When ``groups`` is not 1,
...@@ -139,6 +140,8 @@ class Conv1d(_ConvNd): ...@@ -139,6 +140,8 @@ class Conv1d(_ConvNd):
placed on the precision of intermediate results. When set to "float32", placed on the precision of intermediate results. When set to "float32",
"float32" would be used for accumulator and intermediate result, but only "float32" would be used for accumulator and intermediate result, but only
effective when input and output are of float16 dtype. effective when input and output are of float16 dtype.
padding_mode: "zeros", "reflect" or "replicate". Default: "zeros".
Refer to :class:`~.module.padding.Pad` for more information.
Note: Note:
* ``weight`` usually has shape ``(out_channels, in_channels, kernel_size)`` , * ``weight`` usually has shape ``(out_channels, in_channels, kernel_size)`` ,
...@@ -177,6 +180,7 @@ class Conv1d(_ConvNd): ...@@ -177,6 +180,7 @@ class Conv1d(_ConvNd):
bias: bool = True, bias: bool = True,
conv_mode: str = "cross_correlation", conv_mode: str = "cross_correlation",
compute_mode: str = "default", compute_mode: str = "default",
padding_mode: str = "zeros",
**kwargs **kwargs
): ):
kernel_size = kernel_size kernel_size = kernel_size
...@@ -185,6 +189,7 @@ class Conv1d(_ConvNd): ...@@ -185,6 +189,7 @@ class Conv1d(_ConvNd):
dilation = dilation dilation = dilation
self.conv_mode = conv_mode self.conv_mode = conv_mode
self.compute_mode = compute_mode self.compute_mode = compute_mode
self.padding_mode = padding_mode
super().__init__( super().__init__(
in_channels, in_channels,
out_channels, out_channels,
...@@ -223,7 +228,27 @@ class Conv1d(_ConvNd): ...@@ -223,7 +228,27 @@ class Conv1d(_ConvNd):
# Assume format is NCH(W=1) # Assume format is NCH(W=1)
return (1, self.out_channels, 1) return (1, self.out_channels, 1)
def get_pad_witdth(self):
return ((0, 0), (0, 0), (self.padding, self.padding))
def calc_conv(self, inp, weight, bias): def calc_conv(self, inp, weight, bias):
assert self.padding_mode in [
"zeros",
"reflect",
"replicate",
]
if self.padding_mode != "zeros":
return conv1d(
pad(inp, self.get_pad_witdth(), self.padding_mode),
weight,
bias,
self.stride,
0,
self.dilation,
self.groups,
self.conv_mode,
self.compute_mode,
)
return conv1d( return conv1d(
inp, inp,
weight, weight,
...@@ -287,7 +312,7 @@ class Conv2d(_ConvNd): ...@@ -287,7 +312,7 @@ class Conv2d(_ConvNd):
``(kernel_size, kernel_size)``. ``(kernel_size, kernel_size)``.
stride: stride of the 2D convolution operation. Default: 1 stride: stride of the 2D convolution operation. Default: 1
padding: size of the paddings added to the input on both sides of its padding: size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0 spatial dimensions. Default: 0
dilation: dilation of the 2D convolution operation. Default: 1 dilation: dilation of the 2D convolution operation. Default: 1
groups: number of groups into which the input and output channels are divided, groups: number of groups into which the input and output channels are divided,
so as to perform a ``grouped convolution``. When ``groups`` is not 1, so as to perform a ``grouped convolution``. When ``groups`` is not 1,
...@@ -300,6 +325,8 @@ class Conv2d(_ConvNd): ...@@ -300,6 +325,8 @@ class Conv2d(_ConvNd):
placed on the precision of intermediate results. When set to "float32", placed on the precision of intermediate results. When set to "float32",
"float32" would be used for accumulator and intermediate result, but only "float32" would be used for accumulator and intermediate result, but only
effective when input and output are of float16 dtype. effective when input and output are of float16 dtype.
padding_mode: "zeros", "reflect" or "replicate". Default: "zeros".
Refer to :class:`~.module.padding.Pad` for more information.
Note: Note:
* ``weight`` usually has shape ``(out_channels, in_channels, height, width)`` , * ``weight`` usually has shape ``(out_channels, in_channels, height, width)`` ,
...@@ -338,6 +365,7 @@ class Conv2d(_ConvNd): ...@@ -338,6 +365,7 @@ class Conv2d(_ConvNd):
bias: bool = True, bias: bool = True,
conv_mode: str = "cross_correlation", conv_mode: str = "cross_correlation",
compute_mode: str = "default", compute_mode: str = "default",
padding_mode: str = "zeros",
**kwargs **kwargs
): ):
kernel_size = _pair_nonzero(kernel_size) kernel_size = _pair_nonzero(kernel_size)
...@@ -346,6 +374,7 @@ class Conv2d(_ConvNd): ...@@ -346,6 +374,7 @@ class Conv2d(_ConvNd):
dilation = _pair_nonzero(dilation) dilation = _pair_nonzero(dilation)
self.conv_mode = conv_mode self.conv_mode = conv_mode
self.compute_mode = compute_mode self.compute_mode = compute_mode
self.padding_mode = padding_mode
super().__init__( super().__init__(
in_channels, in_channels,
out_channels, out_channels,
...@@ -384,7 +413,32 @@ class Conv2d(_ConvNd): ...@@ -384,7 +413,32 @@ class Conv2d(_ConvNd):
# Assume format is NCHW # Assume format is NCHW
return (1, self.out_channels, 1, 1) return (1, self.out_channels, 1, 1)
def get_pad_witdth(self):
return (
(0, 0),
(0, 0),
(self.padding[0], self.padding[0]),
(self.padding[1], self.padding[1]),
)
def calc_conv(self, inp, weight, bias): def calc_conv(self, inp, weight, bias):
assert self.padding_mode in [
"zeros",
"reflect",
"replicate",
]
if self.padding_mode != "zeros":
return conv2d(
pad(inp, self.get_pad_witdth(), self.padding_mode),
weight,
bias,
self.stride,
0,
self.dilation,
self.groups,
self.conv_mode,
self.compute_mode,
)
return conv2d( return conv2d(
inp, inp,
weight, weight,
......
...@@ -30,6 +30,7 @@ class _ConvBnActivation2d(Module): ...@@ -30,6 +30,7 @@ class _ConvBnActivation2d(Module):
momentum=0.9, momentum=0.9,
affine=True, affine=True,
track_running_stats=True, track_running_stats=True,
padding_mode: str = "zeros",
**kwargs **kwargs
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
...@@ -44,6 +45,7 @@ class _ConvBnActivation2d(Module): ...@@ -44,6 +45,7 @@ class _ConvBnActivation2d(Module):
bias, bias,
conv_mode, conv_mode,
compute_mode, compute_mode,
padding_mode,
**kwargs, **kwargs,
) )
self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats) self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
......
...@@ -38,6 +38,7 @@ class Conv2d(Float.Conv2d, QATModule): ...@@ -38,6 +38,7 @@ class Conv2d(Float.Conv2d, QATModule):
float_module.bias is not None, float_module.bias is not None,
float_module.conv_mode, float_module.conv_mode,
float_module.compute_mode, float_module.compute_mode,
float_module.padding_mode,
name=float_module.name, name=float_module.name,
) )
qat_module.weight = float_module.weight qat_module.weight = float_module.weight
......
...@@ -147,6 +147,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule): ...@@ -147,6 +147,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
float_module.conv.bias is not None, float_module.conv.bias is not None,
float_module.conv.conv_mode, float_module.conv.conv_mode,
float_module.conv.compute_mode, float_module.conv.compute_mode,
padding_mode=float_module.conv.padding_mode,
name=float_module.name, name=float_module.name,
) )
qat_module.conv.weight = float_module.conv.weight qat_module.conv.weight = float_module.conv.weight
......
...@@ -11,7 +11,7 @@ import numpy as np ...@@ -11,7 +11,7 @@ import numpy as np
from ... import module as Float from ... import module as Float
from ...core.tensor import dtype from ...core.tensor import dtype
from ...functional.nn import conv_bias_activation from ...functional.nn import conv_bias_activation, pad
from ...functional.quantized import conv_transpose2d from ...functional.quantized import conv_transpose2d
from ...tensor import Parameter from ...tensor import Parameter
from ..qat import conv as QAT from ..qat import conv as QAT
...@@ -38,6 +38,7 @@ class Conv2d(Float.Conv2d, QuantizedModule): ...@@ -38,6 +38,7 @@ class Conv2d(Float.Conv2d, QuantizedModule):
conv_mode: str = "cross_correlation", conv_mode: str = "cross_correlation",
compute_mode: str = "default", compute_mode: str = "default",
dtype=None, dtype=None,
padding_mode: str = "zeros",
**kwargs **kwargs
): ):
super().__init__( super().__init__(
...@@ -51,13 +52,33 @@ class Conv2d(Float.Conv2d, QuantizedModule): ...@@ -51,13 +52,33 @@ class Conv2d(Float.Conv2d, QuantizedModule):
True, True,
conv_mode, conv_mode,
compute_mode, compute_mode,
padding_mode,
) )
self.output_dtype = dtype self.output_dtype = dtype
def calc_conv_quantized(self, inp, nonlinear_mode="identity"): def calc_conv_quantized(self, inp, nonlinear_mode="identity"):
assert self.padding_mode in [
"zeros",
"reflect",
"replicate",
]
inp_scale = dtype.get_scale(inp.dtype) inp_scale = dtype.get_scale(inp.dtype)
w_scale = dtype.get_scale(self.weight.dtype) w_scale = dtype.get_scale(self.weight.dtype)
bias_scale = inp_scale * w_scale bias_scale = inp_scale * w_scale
if self.padding_mode != "zeros":
return conv_bias_activation(
pad(inp, self.get_pad_witdth(), self.padding_mode),
self.weight,
self.bias.astype(dtype.qint32(bias_scale)),
self.output_dtype,
self.stride,
0,
self.dilation,
self.groups,
conv_mode=self.conv_mode,
compute_mode=self.compute_mode,
nonlinear_mode=nonlinear_mode,
)
return conv_bias_activation( return conv_bias_activation(
inp, inp,
self.weight, self.weight,
...@@ -88,6 +109,7 @@ class Conv2d(Float.Conv2d, QuantizedModule): ...@@ -88,6 +109,7 @@ class Conv2d(Float.Conv2d, QuantizedModule):
qat_module.dilation, qat_module.dilation,
qat_module.groups, qat_module.groups,
dtype=output_dtype, dtype=output_dtype,
padding_mode=qat_module.padding_mode,
name=qat_module.name, name=qat_module.name,
) )
weight = qat_module.weight.astype(qat_module.get_weight_dtype()) weight = qat_module.weight.astype(qat_module.get_weight_dtype())
......
...@@ -31,6 +31,7 @@ class _ConvBnActivation2d(Conv2d): ...@@ -31,6 +31,7 @@ class _ConvBnActivation2d(Conv2d):
qat_module.conv.groups, qat_module.conv.groups,
dtype=output_dtype, dtype=output_dtype,
name=qat_module.name, name=qat_module.name,
padding_mode=qat_module.conv.padding_mode,
) )
w_fold, b_fold = qat_module.fold_weight_bias( w_fold, b_fold = qat_module.fold_weight_bias(
qat_module.bn.running_mean, qat_module.bn.running_var qat_module.bn.running_mean, qat_module.bn.running_var
......
...@@ -126,6 +126,9 @@ def convbn2d_module_loader(expr): ...@@ -126,6 +126,9 @@ def convbn2d_module_loader(expr):
module = expr.inputs[0].owner module = expr.inputs[0].owner
if not hasattr(module.bn, "param_dim"): if not hasattr(module.bn, "param_dim"):
module.bn.param_dim = "dim_1c11" module.bn.param_dim = "dim_1c11"
module = expr.inputs[0].owner
if not hasattr(module.conv, "padding_mode"):
module.conv.padding_mode = "zeros"
@register_opdef_loader(BatchNorm) @register_opdef_loader(BatchNorm)
...@@ -170,3 +173,28 @@ def pad_func_loader(expr): ...@@ -170,3 +173,28 @@ def pad_func_loader(expr):
kwargs = expr.kwargs kwargs = expr.kwargs
kwargs["pad_width"] = kwargs.pop("pad_witdth") kwargs["pad_width"] = kwargs.pop("pad_witdth")
expr.set_args_kwargs(*expr.args, **kwargs) expr.set_args_kwargs(*expr.args, **kwargs)
@register_module_loader(
("megengine.module.conv", "Conv1d"),
("megengine.module.conv", "Conv2d"),
("megengine.module.conv", "ConvRelu2d"),
("megengine.module.qat.conv", "Conv2d"),
("megengine.module.qat.conv", "ConvRelu2d"),
("megengine.module.quantized.conv", "Conv2d"),
("megengine.module.quantized.conv", "ConvRelu2d"),
)
def conv2d_module_loader(expr):
module = expr.inputs[0].owner
if not hasattr(module, "padding_mode"):
module.padding_mode = "zeros"
@register_module_loader(
("megengine.module.quantized.conv_bn", "ConvBn2d"),
("megengine.module.quantized.conv_bn", "ConvBnRelu2d"),
)
def quantized_convbn2d_module_loader(expr):
module = expr.inputs[0].owner
if not hasattr(module, "padding_mode"):
module.padding_mode = "zeros"
...@@ -60,7 +60,18 @@ def test_qat_convbn2d(): ...@@ -60,7 +60,18 @@ def test_qat_convbn2d():
) )
def test_qat_conv(): @pytest.mark.parametrize(
"padding, padding_mode",
[
(0, "zeros"),
((1, 2), "zeros"),
(3, "reflect"),
((1, 2), "reflect"),
(4, "replicate"),
((1, 2), "replicate"),
],
)
def test_qat_conv(padding, padding_mode):
in_channels = 32 in_channels = 32
out_channels = 64 out_channels = 64
...@@ -72,7 +83,13 @@ def test_qat_conv(): ...@@ -72,7 +83,13 @@ def test_qat_conv():
self.quant = QuantStub() self.quant = QuantStub()
self.dequant = DequantStub() self.dequant = DequantStub()
self.conv = Conv2d( self.conv = Conv2d(
in_channels, out_channels, kernel_size, groups=groups, bias=bias in_channels,
out_channels,
kernel_size,
groups=groups,
bias=bias,
padding=padding,
padding_mode=padding_mode,
) )
self.conv_relu = ConvRelu2d( self.conv_relu = ConvRelu2d(
out_channels, in_channels, kernel_size, groups=groups, bias=bias out_channels, in_channels, kernel_size, groups=groups, bias=bias
......
...@@ -236,11 +236,16 @@ def test_linear(): ...@@ -236,11 +236,16 @@ def test_linear():
@pytest.mark.parametrize("module", ["Conv2d", "ConvBn2d", "ConvBnRelu2d"]) @pytest.mark.parametrize("module", ["Conv2d", "ConvBn2d", "ConvBnRelu2d"])
def test_conv(module): @pytest.mark.parametrize("padding_mode", ["zeros", "reflect", "replicate"])
normal_net = getattr(Float, module)(3, 3, 3, 1, 1, 1, bias=True) def test_conv(module, padding_mode):
normal_net = getattr(Float, module)(
3, 3, 3, 1, 1, 1, bias=True, padding_mode=padding_mode
)
normal_net.eval() normal_net.eval()
qat_net = getattr(QAT, module)(3, 3, 3, 1, 1, 1, bias=True) qat_net = getattr(QAT, module)(
3, 3, 3, 1, 1, 1, bias=True, padding_mode=padding_mode
)
qat_net.eval() qat_net.eval()
disable_observer(qat_net) disable_observer(qat_net)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册