Commit da7f250c authored by Megvii Engine Team

feat(mge/module): deconv fuse bn and relu

GitOrigin-RevId: 5619b397a4686edec3f98f02c66cf3e70b197092
Parent dbd94839
@@ -12,11 +12,13 @@ from .conv import (
ConvRelu2d,
ConvTranspose2d,
ConvTranspose3d,
ConvTransposeRelu2d,
DeformableConv2d,
LocalConv2d,
RegionRestrictedConv,
)
from .conv_bn import ConvBn2d, ConvBnRelu2d
from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d
from .deformable_psroi_pooling import DeformablePSROIPooling
from .dropout import Dropout
from .elemwise import Elemwise
......
@@ -773,6 +773,15 @@ class ConvRelu2d(Conv2d):
return relu(self.calc_conv(inp, self.weight, self.bias))
class ConvTransposeRelu2d(ConvTranspose2d):
r"""A fused :class:`~.Module` including :class:`~.module.ConvTranspose2d` and :func:`~.relu`.
Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeRelu2d` using :func:`~.quantize.quantize_qat`.
"""
def forward(self, inp):
return relu(self.calc_conv_transpose2d(inp, self.weight, self.bias))
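For context, a minimal equivalence sketch for the new fused float module (illustrative shapes, not part of this diff):

import numpy as np
import megengine as mge
import megengine.functional as F
import megengine.module as M

# The fused module should match ConvTranspose2d followed by relu.
fused = M.ConvTransposeRelu2d(4, 2, 3)
plain = M.ConvTranspose2d(4, 2, 3)
plain.weight, plain.bias = fused.weight, fused.bias  # share parameters

x = mge.tensor(np.random.randn(1, 4, 5, 5).astype("float32"))
np.testing.assert_allclose(fused(x).numpy(), F.relu(plain(x)).numpy())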
class DeformableConv2d(_ConvNd):
r"""Deformable Convolution.
......
from typing import Tuple, Union
from ..functional import relu
from .batchnorm import BatchNorm2d
from .conv import ConvTranspose2d
from .module import Module
class _ConvTransposeBnActivation2d(Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
output_padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: bool = True,
conv_mode: str = "cross_correlation",
compute_mode: str = "default",
eps=1e-5,
momentum=0.9,
affine=True,
track_running_stats=True,
**kwargs
):
super().__init__(**kwargs)
self.conv_transpose2d = ConvTranspose2d(
in_channels,
out_channels,
kernel_size,
stride,
padding,
output_padding,
dilation,
groups,
bias,
conv_mode,
compute_mode,
**kwargs,
)
self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
class ConvTransposeBn2d(_ConvTransposeBnActivation2d):
r"""A fused :class:`~.Module` including :class:`~.module.ConvTranspose2d` and :class:`~.module.BatchNorm2d`.
Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeBn2d` using :func:`~.quantize.quantize_qat`.
"""
def forward(self, inp):
return self.bn(self.conv_transpose2d(inp))
class ConvTransposeBnRelu2d(_ConvTransposeBnActivation2d):
r"""A fused :class:`~.Module` including :class:`~.module.ConvTranspose2d`, :class:`~.module.BatchNorm2d` and :func:`~.relu`.
Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeBnRelu2d` using :func:`~.quantize.quantize_qat`.
"""
def forward(self, inp):
return relu(self.bn(self.conv_transpose2d(inp)))
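A hedged usage sketch for the new fused float modules (made-up shapes; assumes the standard megengine.quantization.quantize_qat entry point):

import numpy as np
import megengine as mge
import megengine.module as M
from megengine.quantization import quantize_qat

m = M.ConvTransposeBnRelu2d(16, 8, kernel_size=3)
x = mge.tensor(np.random.randn(2, 16, 10, 10).astype("float32"))
y = m(x)                 # float forward: relu(bn(conv_transpose2d(x)))
qat_m = quantize_qat(m)  # swapped for the QAT version defined below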
from .batch_matmul_activation import BatchMatMulActivation
from .concat import Concat
from .conv import Conv2d, ConvRelu2d, ConvTranspose2d
from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, ConvTransposeRelu2d
from .conv_bn import ConvBn2d, ConvBnRelu2d
from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d
from .elemwise import Elemwise
from .linear import Linear
from .module import QATModule
......
@@ -59,8 +59,8 @@ class ConvTranspose2d(Float.ConvTranspose2d, QATModule):
def calc_conv_transpose2d_qat(self, inp):
w_qat = self.apply_quant_weight(self.weight)
b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
conv = self.calc_conv_transpose2d(inp, w_qat, b_qat)
return conv
conv_transpose2d = self.calc_conv_transpose2d(inp, w_qat, b_qat)
return conv_transpose2d
@classmethod
def from_float_module(cls, float_module: Float.ConvTranspose2d):
@@ -88,3 +88,12 @@ class ConvTranspose2d(Float.ConvTranspose2d, QATModule):
def forward(self, inp):
return self.apply_quant_activation(self.calc_conv_transpose2d_qat(inp))
class ConvTransposeRelu2d(ConvTranspose2d):
r"""A :class:`~.QATModule` include :class:`~.module.ConvTranspose2d` and :func:`~.relu` with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
"""
def forward(self, inp):
return self.apply_quant_activation(F.relu(self.calc_conv_transpose2d_qat(inp)))
from ...functional import ones, relu, sqrt, sum, zeros
from .. import conv_transpose_bn as Float
from .module import QATModule
class _ConvTransposeBnActivation2d(Float._ConvTransposeBnActivation2d, QATModule):
def get_batch_mean_var(self, inp):
def _sum_channel(inp, axis=0, keepdims=True):
if isinstance(axis, int):
out = sum(inp, axis=axis, keepdims=keepdims)
elif isinstance(axis, tuple):
for idx, elem in enumerate(axis):
out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
return out
sum1 = _sum_channel(inp, (0, 2, 3))
sum2 = _sum_channel(inp ** 2, (0, 2, 3))
reduce_size = inp.size / inp.shape[1]
batch_mean = sum1 / reduce_size
batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
return batch_mean, batch_var
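The statistics above reduce over the N, H and W axes with a sum / sum-of-squares formulation; a small numpy sanity check of the same arithmetic (illustrative only):

import numpy as np

x = np.random.randn(4, 3, 5, 5).astype("float32")
reduce_size = x.size / x.shape[1]                     # N * H * W
sum1 = x.sum(axis=(0, 2, 3), keepdims=True)
sum2 = (x ** 2).sum(axis=(0, 2, 3), keepdims=True)
mean = sum1 / reduce_size
var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size  # E[x^2] - E[x]^2, biased
np.testing.assert_allclose(var, x.var(axis=(0, 2, 3), keepdims=True), atol=1e-4)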
def fold_weight_bias(self, bn_mean, bn_var):
# fold the BN statistics into the conv_transpose2d parameters
gamma = self.bn.weight
if gamma is None:
gamma = ones((self.bn.num_features), dtype="float32")
gamma = gamma.reshape(1, -1, 1, 1)
beta = self.bn.bias
if beta is None:
beta = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
if bn_mean is None:
bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
if bn_var is None:
bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")
conv_transpose2d_bias = self.conv_transpose2d.bias
if conv_transpose2d_bias is None:
conv_transpose2d_bias = zeros(
self.conv_transpose2d._infer_bias_shape(), dtype="float32"
)
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
scale_factor = gamma * bn_istd
if self.conv_transpose2d.groups == 1:
w_fold = self.conv_transpose2d.weight * scale_factor.reshape(1, -1, 1, 1)
else:
w_fold = self.conv_transpose2d.weight * scale_factor.reshape(
self.conv_transpose2d.groups, 1, -1, 1, 1
)
w_fold = self.apply_quant_weight(w_fold)
b_fold = beta + gamma * (conv_transpose2d_bias - bn_mean) * bn_istd
return w_fold, b_fold
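The fold relies on eval-mode BN being a per-channel affine map, so it can be absorbed into the deconvolution's weight and bias; a numpy sketch of that identity (illustrative only):

import numpy as np

y = np.random.randn(2, 4, 6, 6)      # stand-in for a conv_transpose2d output
gamma, beta = np.random.randn(4), np.random.randn(4)
mean, var, eps = np.random.randn(4), np.random.rand(4) + 0.1, 1e-5

def c(v):                            # broadcast a per-channel vector
    return v.reshape(1, -1, 1, 1)

istd = 1.0 / np.sqrt(var + eps)
bn_out = c(gamma) * (y - c(mean)) * c(istd) + c(beta)
# Identical to scaling the output (i.e. the folded weight) plus a shift:
np.testing.assert_allclose(bn_out, y * c(gamma * istd) + c(beta - gamma * mean * istd))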
def update_running_mean_and_running_var(
self, bn_mean, bn_var, num_elements_per_channel
):
# update running mean and running var. no grad, use unbiased bn var
bn_mean = bn_mean.detach()
bn_var = (
bn_var.detach() * num_elements_per_channel / (num_elements_per_channel - 1)
)
exponential_average_factor = 1 - self.bn.momentum
self.bn.running_mean *= self.bn.momentum
self.bn.running_mean += exponential_average_factor * bn_mean
self.bn.running_var *= self.bn.momentum
self.bn.running_var += exponential_average_factor * bn_var
def calc_conv_transpose2d_bn_qat(self, inp, approx=True):
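# With approx (the default), folding always uses the running statistics; in
# training, the folded output is rescaled back below and passed through the
# real BN, so batch statistics still shape the result. With approx=False,
# exact batch statistics are computed and folded in directly.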
if self.training and not approx:
conv_transpose2d = self.conv_transpose2d(inp)
bn_mean, bn_var = self.get_batch_mean_var(conv_transpose2d)
num_elements_per_channel = conv_transpose2d.size / conv_transpose2d.shape[1]
self.update_running_mean_and_running_var(
bn_mean, bn_var, num_elements_per_channel
)
else:
bn_mean, bn_var = self.bn.running_mean, self.bn.running_var
# get gamma and beta in BatchNorm
gamma = self.bn.weight
if gamma is None:
gamma = ones((self.bn.num_features), dtype="float32")
gamma = gamma.reshape(1, -1, 1, 1)
beta = self.bn.bias
if beta is None:
beta = zeros((self.bn.num_features), dtype="float32")
beta = beta.reshape(1, -1, 1, 1)
# conv_transpose2d_bias
conv_transpose2d_bias = self.conv_transpose2d.bias
if conv_transpose2d_bias is None:
conv_transpose2d_bias = zeros(
self.conv_transpose2d._infer_bias_shape(), dtype="float32"
)
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
scale_factor = gamma * bn_istd
if self.conv_transpose2d.groups == 1:
w_fold = self.conv_transpose2d.weight * scale_factor.reshape(1, -1, 1, 1)
else:
w_fold = self.conv_transpose2d.weight * scale_factor.reshape(
self.conv_transpose2d.groups, 1, -1, 1, 1
)
b_fold = None
if not (self.training and approx):
b_fold = beta + gamma * (conv_transpose2d_bias - bn_mean) * bn_istd
w_qat = self.apply_quant_weight(w_fold)
b_qat = self.apply_quant_bias(b_fold, inp, w_qat)
conv_transpose2d = self.conv_transpose2d.calc_conv_transpose2d(
inp, w_qat, b_qat
)
if not (self.training and approx):
return conv_transpose2d
# rescale conv_transpose2d to get original conv_transpose2d output
orig_conv_transpose2d = conv_transpose2d / scale_factor.reshape(1, -1, 1, 1)
if self.conv_transpose2d.bias is not None:
orig_conv_transpose2d = orig_conv_transpose2d + self.conv_transpose2d.bias
# calculate batch norm
conv_transpose2d = self.bn(orig_conv_transpose2d)
return conv_transpose2d
@classmethod
def from_float_module(cls, float_module: Float._ConvTransposeBnActivation2d):
qat_module = cls(
float_module.conv_transpose2d.in_channels,
float_module.conv_transpose2d.out_channels,
float_module.conv_transpose2d.kernel_size,
float_module.conv_transpose2d.stride,
float_module.conv_transpose2d.padding,
float_module.conv_transpose2d.output_padding,
float_module.conv_transpose2d.dilation,
float_module.conv_transpose2d.groups,
float_module.conv_transpose2d.bias is not None,
float_module.conv_transpose2d.conv_mode,
float_module.conv_transpose2d.compute_mode,
name=float_module.name,
)
qat_module.conv_transpose2d.weight = float_module.conv_transpose2d.weight
qat_module.conv_transpose2d.bias = float_module.conv_transpose2d.bias
qat_module.bn = float_module.bn
return qat_module
class ConvTransposeBn2d(_ConvTransposeBnActivation2d):
r"""A fused :class:`~.QATModule` including :class:`~.module.ConvTranspose2d` and :class:`~.module.BatchNorm2d` with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
"""
def forward(self, inp):
return self.apply_quant_activation(self.calc_conv_transpose2d_bn_qat(inp))
class ConvTransposeBnRelu2d(_ConvTransposeBnActivation2d):
r"""A fused :class:`~.QATModule` including :class:`~.module.ConvTranspose2d`, :class:`~.module.BatchNorm2d` and :func:`~.relu` with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
"""
def forward(self, inp):
return self.apply_quant_activation(relu(self.calc_conv_transpose2d_bn_qat(inp)))
from .batch_matmul_activation import BatchMatMulActivation
from .concat import Concat
from .conv import Conv2d, ConvRelu2d, ConvTranspose2d
from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, ConvTransposeRelu2d
from .conv_bn import ConvBn2d, ConvBnRelu2d
from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d
from .elemwise import Elemwise
from .linear import Linear
from .module import QuantizedModule
......
@@ -178,7 +178,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule):
:class:`~.QATModule` instance.
"""
output_dtype = qat_module.get_activation_dtype()
qconv = cls(
qconv_transpose2d = cls(
qat_module.in_channels,
qat_module.out_channels,
qat_module.kernel_size,
@@ -194,15 +194,19 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule):
name=qat_module.name,
)
weight = qat_module.weight.astype(qat_module.get_weight_dtype())
qconv.weight = Parameter(weight.numpy(), name=qat_module.weight.name)
qconv.bias = (
qconv_transpose2d.weight = Parameter(
weight.numpy(), name=qat_module.weight.name
)
qconv_transpose2d.bias = (
Parameter(qat_module.bias.numpy(), name=qat_module.bias.name)
if qat_module.bias is not None
else None
)
return qconv
return qconv_transpose2d
def calc_conv_transpose2d_quantized(self, inp, nonlinear_mode):
assert nonlinear_mode == "identity", "nonlinear_mode shoule be 'identity'"
def calc_conv_transpose2d_quantized(self, inp):
if self.bias is not None:
inp_scale = dtype.get_scale(inp.dtype)
w_scale = dtype.get_scale(self.weight.dtype)
@@ -225,4 +229,11 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule):
)
def forward(self, inp):
return self.calc_conv_transpose2d_quantized(inp)
return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="identity")
class ConvTransposeRelu2d(ConvTranspose2d):
r"""Quantized version of :class:`~.qat.ConvTransposeRelu2d`."""
def forward(self, inp):
return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="relu")
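A rough end-to-end sketch of how these quantized modules are produced (hypothetical flow; assumes megengine.quantization.quantize converts a QAT module once its observers have seen data):

import numpy as np
import megengine as mge
import megengine.module as M
from megengine.quantization import quantize, quantize_qat

net = M.ConvTransposeRelu2d(8, 4, 3)
qat_net = quantize_qat(net)
x = mge.tensor(np.random.randn(1, 8, 7, 7).astype("float32"))
qat_net(x)                 # one forward pass so observers record value ranges
q_net = quantize(qat_net)  # yields the quantized ConvTransposeRelu2d above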
from ...tensor import Parameter
from ..qat import conv_transpose_bn as QAT
from .conv import ConvTranspose2d
class _ConvTransposeBnActivation2d(ConvTranspose2d):
r"""Applies a 2D deconvolution over a quantized input tensor, used for inference only.
"""
@classmethod
def from_qat_module(cls, qat_module: QAT._ConvTransposeBnActivation2d):
r"""
Return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
output_dtype = qat_module.get_activation_dtype()
qconv_transpose2d = cls(
qat_module.conv_transpose2d.in_channels,
qat_module.conv_transpose2d.out_channels,
qat_module.conv_transpose2d.kernel_size,
qat_module.conv_transpose2d.stride,
qat_module.conv_transpose2d.padding,
qat_module.conv_transpose2d.output_padding,
qat_module.conv_transpose2d.dilation,
qat_module.conv_transpose2d.groups,
dtype=output_dtype,
name=qat_module.name,
)
w_fold, b_fold = qat_module.fold_weight_bias(
qat_module.bn.running_mean, qat_module.bn.running_var
)
weight = w_fold.astype(qat_module.get_weight_dtype())
qconv_transpose2d.weight = Parameter(
weight.numpy(), name=qat_module.conv_transpose2d.weight.name
)
qconv_transpose2d.bias = Parameter(b_fold.numpy())
if qat_module.conv_transpose2d.bias is not None:
qconv_transpose2d.bias.name = qat_module.conv_transpose2d.bias.name
return qconv_transpose2d
class ConvTransposeBn2d(_ConvTransposeBnActivation2d):
r"""Quantized version of :class:`~.qat.ConvTransposeBn2d`."""
def forward(self, inp):
return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="identity")
class ConvTransposeBnRelu2d(_ConvTransposeBnActivation2d):
r"""Quantized version of :class:`~.qat.ConvTransposeBnRelu2d`."""
def forward(self, inp):
return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="relu")
from copy import deepcopy
from ..functional import ones, sqrt, zeros
from ..module import BatchNorm2d, Conv2d, ConvBn2d, ConvBnRelu2d, ConvRelu2d, ReLU
from ..module import (
BatchNorm2d,
Conv2d,
ConvBn2d,
ConvBnRelu2d,
ConvRelu2d,
ConvTranspose2d,
ConvTransposeBn2d,
ConvTransposeBnRelu2d,
ConvTransposeRelu2d,
ReLU,
)
from ..tensor import Parameter
_MAP_TO_FUSED_MODULE = {
(Conv2d, BatchNorm2d, ReLU, False): ConvRelu2d,
(Conv2d, BatchNorm2d, ReLU, True): ConvBnRelu2d,
(ConvTranspose2d, BatchNorm2d, ReLU, False): ConvTransposeRelu2d,
(ConvTranspose2d, BatchNorm2d, ReLU, True): ConvTransposeBnRelu2d,
(Conv2d, BatchNorm2d, False): Conv2d,
(Conv2d, BatchNorm2d, True): ConvBn2d,
(Conv2d, ReLU): ConvRelu2d,
(ConvTranspose2d, BatchNorm2d, False): ConvTranspose2d,
(ConvTranspose2d, BatchNorm2d, True): ConvTransposeBn2d,
(ConvTranspose2d, ReLU): ConvTransposeRelu2d,
}
def fold_weight_bias(weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5):
# get fold bn conv param
def fold_weight_bias(
weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False
):
shape = (-1, 1, 1, 1)
if transpose:
shape = (1, -1, 1, 1)
kernel_shape = weight.shape
if len(kernel_shape) == 5:
groups, num_features = kernel_shape[0], kernel_shape[1]
else:
groups, num_features = 1, kernel_shape[0]
out_channels = groups * num_features
if gamma is None:
gamma = ones((num_features), dtype="float32")
gamma = ones((out_channels,), dtype="float32")
gamma = gamma.reshape(1, -1, 1, 1)
if beta is None:
beta = zeros((num_features), dtype="float32")
beta = zeros((out_channels,), dtype="float32")
beta = beta.reshape(1, -1, 1, 1)
if bn_mean is None:
bn_mean = zeros((1, num_features, 1, 1), dtype="float32")
bn_mean = zeros((1, out_channels, 1, 1), dtype="float32")
if bn_var is None:
bn_var = ones((1, num_features, 1, 1), dtype="float32")
bn_var = ones((1, out_channels, 1, 1), dtype="float32")
if bias is None:
bias = zeros((1, num_features, 1, 1), dtype="float32")
bias = zeros((1, out_channels, 1, 1), dtype="float32")
bn_istd = 1.0 / sqrt(bn_var + eps)
scale_factor = gamma * bn_istd
if groups == 1:
w_fold = weight * scale_factor.reshape(-1, 1, 1, 1)
w_fold = weight * scale_factor.reshape(*shape)
else:
w_fold = weight * scale_factor.reshape(groups, -1, 1, 1, 1)
w_fold = weight * scale_factor.reshape(groups, *shape)
b_fold = beta + gamma * (bias - bn_mean) * bn_istd
return w_fold, b_fold
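The transpose flag exists because the two weight layouts put output channels on different axes (Conv2d: (out, in, kh, kw); ConvTranspose2d: (in, out, kh, kw)), so the per-output-channel scale must broadcast differently; a small illustration:

import numpy as np

scale = np.random.randn(8)              # one BN scale per output channel
w_conv = np.random.randn(8, 4, 3, 3)    # Conv2d weight: (out, in, kh, kw)
w_deconv = np.random.randn(4, 8, 3, 3)  # ConvTranspose2d: (in, out, kh, kw)

w_conv_fold = w_conv * scale.reshape(-1, 1, 1, 1)      # transpose=False
w_deconv_fold = w_deconv * scale.reshape(1, -1, 1, 1)  # transpose=True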
@@ -84,3 +106,55 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
module.bn = deepcopy(bn)
new_conv.training = conv.training
return module
def fuse_conv_transpose2d_bn_relu_module(
conv_transpose2d: ConvTranspose2d, bn: BatchNorm2d, relu: ReLU
):
module_key = tuple([type(m) for m in [conv_transpose2d, bn, relu] if m])
if bn:
assert (
conv_transpose2d.training == bn.training
), "ConvTranspose2d and BN both must be in the same mode (train or eval)."
assert (
bn.num_features == conv_transpose2d.out_channels
), "Output channel of ConvTranspose2d must match num_features of BatchNorm2d"
module_key = module_key + (conv_transpose2d.training,)
module = _MAP_TO_FUSED_MODULE[module_key](
in_channels=conv_transpose2d.in_channels,
out_channels=conv_transpose2d.out_channels,
kernel_size=conv_transpose2d.kernel_size,
stride=conv_transpose2d.stride,
padding=conv_transpose2d.padding,
output_padding=conv_transpose2d.output_padding,
dilation=conv_transpose2d.dilation,
groups=conv_transpose2d.groups,
bias=conv_transpose2d.bias is not None,
conv_mode=conv_transpose2d.conv_mode,
compute_mode=conv_transpose2d.compute_mode,
name=conv_transpose2d.name,
)
new_conv_transpose2d = (
module
if bn is None or not conv_transpose2d.training
else module.conv_transpose2d
)
weight, bias = conv_transpose2d.weight, conv_transpose2d.bias
if not conv_transpose2d.training and bn is not None:
weight, bias = fold_weight_bias(
weight,
bias,
bn.weight,
bn.bias,
bn.running_mean,
bn.running_var,
bn.eps,
transpose=True,
)
new_conv_transpose2d.weight = Parameter(weight)
if bias is not None:
new_conv_transpose2d.bias = Parameter(bias)
if bn is not None and conv_transpose2d.training:
module.bn = deepcopy(bn)
new_conv_transpose2d.training = conv_transpose2d.training
return module
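A hedged usage sketch of the new helper in eval mode (names as defined in this file; the fused module should reproduce the unfused pipeline):

import numpy as np
import megengine as mge
import megengine.module as M

deconv = M.ConvTranspose2d(8, 4, 3, bias=True)
bn = M.BatchNorm2d(4)
act = M.ReLU()
for m in (deconv, bn, act):
    m.eval()  # eval mode folds BN into the fused module's weights

fused = fuse_conv_transpose2d_bn_relu_module(deconv, bn, act)
x = mge.tensor(np.random.randn(1, 8, 7, 7).astype("float32"))
np.testing.assert_allclose(
    fused(x).numpy(), act(bn(deconv(x))).numpy(), rtol=1e-4, atol=1e-5
)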
@@ -5,7 +5,9 @@ import numpy as np
import pytest
import megengine.utils.comp_graph_tools as cgtools
from megengine import jit, tensor
from megengine import jit
from megengine import module as M
from megengine import tensor
from megengine.device import get_device_count
from megengine.functional import expand_dims
from megengine.module import (
@@ -14,6 +16,8 @@ from megengine.module import (
ConvBn2d,
ConvRelu2d,
ConvTranspose2d,
ConvTransposeBn2d,
ConvTransposeRelu2d,
DequantStub,
Module,
QuantStub,
@@ -34,6 +38,49 @@ def test_qat_convbn2d():
module = ConvBn2d(
in_channels, out_channels, kernel_size, groups=groups, bias=bias
)
M.init.normal_(module.bn.weight)
M.init.normal_(module.bn.bias)
module.train()
qat_module = quantize_qat(module, inplace=False)
disable_fake_quant(qat_module)
inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
normal_outputs = module(inputs)
qat_outputs = qat_module(inputs)
np.testing.assert_allclose(
normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
)
np.testing.assert_allclose(
module.bn.running_mean.numpy(),
qat_module.bn.running_mean.numpy(),
atol=5e-8,
)
np.testing.assert_allclose(
module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7,
)
module.eval()
normal_outputs = module(inputs)
qat_module.eval()
qat_outputs = qat_module(inputs)
np.testing.assert_allclose(
normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
)
def test_qat_convtransposebn2d():
in_channels = 32
out_channels = 64
kernel_size = 3
for groups, bias in product([1, 4], [True, False]):
module = ConvTransposeBn2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
output_padding=0,
groups=groups,
bias=bias,
)
M.init.normal_(module.bn.weight)
M.init.normal_(module.bn.bias)
module.train()
qat_module = quantize_qat(module, inplace=False)
disable_fake_quant(qat_module)
@@ -235,10 +282,14 @@ def test_qat_conv_transpose2d():
self.conv = ConvTranspose2d(
in_channels, out_channels, kernel_size, bias=bias
)
self.conv_transpose2d_relu = ConvTransposeRelu2d(
out_channels, in_channels, kernel_size, bias=bias
)
def forward(self, inp):
out = self.quant(inp)
out = self.conv(out)
out = self.conv_transpose2d_relu(out)
out = self.dequant(out)
return out
@@ -250,10 +301,14 @@ def test_qat_conv_transpose2d():
disable_fake_quant(qat_net)
normal_outputs = net(inputs)
qat_outputs = qat_net(inputs)
np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
np.testing.assert_allclose(
normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6
)
net.eval()
normal_outputs = net(inputs)
qat_net.eval()
qat_outputs = qat_net(inputs)
np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
np.testing.assert_allclose(
normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6
)