Commit bc9f9cd4 authored by Megvii Engine Team

feat(imperative/module): add support for fusing Linear with BN and ReLU

GitOrigin-RevId: c342f687974f9b47304c26ac39a59b74caaf7f70
Parent f0a3ab97
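A minimal usage sketch of the fused modules introduced by this commit, mirroring the patterns used in the new tests; the wrapper class `Net` and the tensor values are illustrative, not part of the change:

import numpy as np
import megengine as mge
import megengine.module as M
from megengine.quantization.quantize import quantize_qat

class Net(M.Module):
    def __init__(self):
        super().__init__()
        self.quant = M.QuantStub()
        self.dequant = M.DequantStub()
        # fused Linear -> BatchNorm1d -> ReLU added by this commit
        self.linear_bn_relu = M.LinearBnRelu1d(10, 5, bias=True)

    def forward(self, x):
        return self.dequant(self.linear_bn_relu(self.quant(x)))

net = Net()
x = mge.tensor(np.random.randn(4, 10).astype(np.float32))
y = net(x)  # float path, equivalent to relu(bn(linear(x)))

# swap in the QAT counterpart (qat.LinearBnRelu1d) for fake-quant training
qat_net = quantize_qat(net, inplace=False)
y_qat = qat_net(x)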
@@ -24,7 +24,8 @@ from .dropout import Dropout
from .elemwise import Elemwise
from .embedding import Embedding
from .identity import Identity
from .linear import Linear
from .linear import Linear, LinearRelu
from .linear_bn import LinearBn1d, LinearBnRelu1d
from .lrn import LocalResponseNorm
from .module import Module
from .multiheadattn import MultiHeadAttention
import numpy as np
from ..functional.nn import linear
from ..functional.nn import linear, relu
from ..tensor import Parameter
from . import init
from .module import Module
@@ -62,13 +62,22 @@ class Linear(Module):
if self.bias is not None:
init.zeros_(self.bias)
def _calc_linear(self, x, weight, bias):
def calc_linear(self, x, weight, bias):
return linear(x, weight, bias, compute_mode=self.compute_mode)
def forward(self, x):
return self._calc_linear(x, self.weight, self.bias)
return self.calc_linear(x, self.weight, self.bias)
def _module_info_string(self) -> str:
return "in_features={}, out_features={}, bias={}".format(
self.in_features, self.out_features, self.bias is not None
)
class LinearRelu(Linear):
r"""A fused :class:`~.Module` including :class:`~.module.Linear` and :func:`~.relu`.
Could be replaced with :class:`~.QATModule` version :class:`~.qat.LinearRelu` using :func:`~.quantize.quantize_qat`.
"""
def forward(self, inp):
return relu(self.calc_linear(inp, self.weight, self.bias))
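# Usage note (illustrative): LinearRelu reuses the Linear parameters unchanged and
# only fuses the activation, so for m = LinearRelu(in_features, out_features) the
# call m(x) is equivalent to relu(linear(x, m.weight, m.bias)).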
import numpy as np
from ..functional import relu
from ..tensor import Parameter
from .batchnorm import BatchNorm1d
from .linear import Linear
from .module import Module
class _LinearBnActivation1d(Module):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
compute_mode: str = "default",
eps=1e-5,
momentum=0.9,
affine=True,
track_running_stats=True,
**kwargs
):
super().__init__(**kwargs)
self.out_features = out_features
self.in_features = in_features
self.bias = None
if bias:
b_shape = (out_features,)
self.bias = Parameter(np.zeros(b_shape, dtype=np.float32))
self.linear = Linear(in_features, out_features, bias, compute_mode, **kwargs,)
self.bn = BatchNorm1d(out_features, eps, momentum, affine, track_running_stats)
class LinearBn1d(_LinearBnActivation1d):
r"""A fused :class:`~.Module` including :class:`~.module.Linear` and :class:`~.module.BatchNorm1d`.
Could be replaced with :class:`~.QATModule` version :class:`~.qat.LinearBn1d` using
:func:`~.quantize.quantize_qat`.
"""
def forward(self, inp):
return self.bn(self.linear(inp))
class LinearBnRelu1d(_LinearBnActivation1d):
r"""A fused :class:`~.Module` including :class:`~.module.Linear`, :class:`~.module.BatchNorm1d` and :func:`~.relu`.
Could be replaced with :class:`~.QATModule` version :class:`~.qat.LinearBnRelu1d` using :func:`~.quantize.quantize_qat`.
"""
def forward(self, inp):
return relu(self.bn(self.linear(inp)))
@@ -4,6 +4,7 @@ from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, ConvTransposeRelu2d
from .conv_bn import ConvBn2d, ConvBnRelu2d
from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d
from .elemwise import Elemwise
from .linear import Linear
from .linear import Linear, LinearRelu
from .linear_bn import LinearBn1d, LinearBnRelu1d
from .module import QATModule
from .quant_dequant import DequantStub, QuantStub
from ... import functional as F
from .. import linear as Float
from .module import QATModule
@@ -13,10 +14,16 @@ class Linear(Float.Linear, QATModule):
Default: True
"""
def calc_linear_qat(self, inp):
w_qat = self.apply_quant_weight(self.weight)
b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
linear = self.calc_linear(inp, w_qat, b_qat)
return linear
def forward(self, inp):
w_qat = self.apply_quant_weight(self.weight)
b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat))
return self.apply_quant_activation(self.calc_linear(inp, w_qat, b_qat))
@classmethod
def from_float_module(cls, float_module: Float.Linear):
@@ -30,3 +37,12 @@ class Linear(Float.Linear, QATModule):
qmod.weight = float_module.weight
qmod.bias = float_module.bias
return qmod
class LinearRelu(Linear):
r"""A :class:`~.QATModule` include :class:`~.module.Linear` and :func:`~.relu` with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
"""
def forward(self, inp):
return self.apply_quant_activation(F.relu(self.calc_linear_qat(inp)))
from ...functional import linear, ones, relu, sqrt, sum, zeros
from .. import linear_bn as Float
from .module import QATModule
class _LinearBnActivation1d(Float._LinearBnActivation1d, QATModule):
def get_batch_mean_var(self, inp):
def _sum_channel(inp, axis=0, keepdims=True):
if isinstance(axis, int):
out = sum(inp, axis=axis, keepdims=keepdims)
elif isinstance(axis, tuple):
for idx, elem in enumerate(axis):
out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
return out
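# Per-channel batch moments from the two reduced sums:
#   mean = sum(x) / reduce_size
#   var  = E[x^2] - E[x]^2 = (sum(x^2) - sum(x)^2 / reduce_size) / reduce_size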
sum1 = _sum_channel(inp, (0, 2, 3))
sum2 = _sum_channel(inp ** 2, (0, 2, 3))
reduce_size = inp.size / inp.shape[1]
batch_mean = sum1 / reduce_size
batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
return batch_mean, batch_var
def fold_weight_bias(self, bn_mean, bn_var):
weight_shape = [1] * len(self.linear.weight.shape)
weight_shape[0] = -1
bias_shape = [1] * len(self.linear.weight.shape)
bias_shape[1] = -1
# get fold bn linear param
gamma = self.bn.weight
if gamma is None:
gamma = ones((self.bn.num_features,), dtype="float32")
gamma = gamma.reshape(-1)
beta = self.bn.bias
if beta is None:
beta = zeros((self.bn.num_features,), dtype="float32")
beta = beta.reshape(-1)
if bn_mean is None:
bn_mean = zeros((self.bn.num_features,), dtype="float32")
bn_mean = bn_mean.reshape(-1)
if bn_var is None:
bn_var = ones((self.bn.num_features,), dtype="float32")
bn_var = bn_var.reshape(-1)
linear_bias = self.linear.bias
if linear_bias is None:
linear_bias = zeros(beta.shape, dtype="float32")
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
scale_factor = gamma * bn_istd
w_fold = self.linear.weight * scale_factor.reshape(weight_shape)
w_fold = self.apply_quant_weight(w_fold)
b_fold = beta + gamma * (linear_bias - bn_mean) * bn_istd
return w_fold, b_fold
def update_running_mean_and_running_var(
self, bn_mean, bn_var, num_elements_per_channel
):
# update running mean and running var. no grad, use unbiased bn var
bn_mean = bn_mean.detach()
bn_var = (
bn_var.detach() * num_elements_per_channel / (num_elements_per_channel - 1)
)
exponential_average_factor = 1 - self.bn.momentum
self.bn.running_mean *= self.bn.momentum
self.bn.running_mean += exponential_average_factor * bn_mean
self.bn.running_var *= self.bn.momentum
self.bn.running_var += exponential_average_factor * bn_var
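# Note on calc_linear_bn_qat below: BatchNorm is folded into the linear weight
# before fake quantization. Outside training, or when approx=False, the fold is
# exact and the folded weight/bias are used directly (approx=False additionally
# computes batch statistics and updates the running stats first). In the default
# training path (approx=True) only the weight is scaled by
# gamma / sqrt(running_var + eps); the linear output is rescaled back and fed
# through the real BatchNorm1d so that batch statistics keep updating.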
def calc_linear_bn_qat(self, inp, approx=True):
if self.training and not approx:
linear = self.linear(inp)
bn_mean, bn_var = self.get_batch_mean_var(linear)
num_elements_per_channel = linear.size / linear.shape[1]
self.update_running_mean_and_running_var(
bn_mean, bn_var, num_elements_per_channel
)
else:
bn_mean, bn_var = (
self.bn.running_mean.reshape(-1),
self.bn.running_var.reshape(-1),
)
weight_shape = [1] * len(self.linear.weight.shape)
weight_shape[0] = -1
bias_shape = [1] * len(self.linear.weight.shape)
bias_shape[1] = -1
# get gamma and beta in BatchNorm
gamma = self.bn.weight
if gamma is None:
gamma = ones((self.bn.num_features,), dtype="float32")
gamma = gamma.reshape(-1)
beta = self.bn.bias
if beta is None:
beta = zeros((self.bn.num_features,), dtype="float32")
beta = beta.reshape(-1)
# linear_bias
linear_bias = self.linear.bias
if linear_bias is None:
linear_bias = zeros(beta.shape, dtype="float32")
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
scale_factor = gamma * bn_istd
w_fold = self.linear.weight * scale_factor.reshape(weight_shape)
b_fold = None
if not (self.training and approx):
b_fold = beta + gamma * (linear_bias - bn_mean) * bn_istd
w_qat = self.apply_quant_weight(w_fold)
b_qat = self.apply_quant_bias(b_fold, inp, w_qat)
linear = self.linear.calc_linear(inp, w_qat, b_qat)
if not (self.training and approx):
return linear
# rescale linear to get original linear output
orig_linear = linear / scale_factor.reshape(bias_shape)
if self.linear.bias is not None:
orig_linear = orig_linear + self.linear.bias.reshape(bias_shape)
# calculate batch norm
linear = self.bn(orig_linear)
return linear
@classmethod
def from_float_module(cls, float_module: Float._LinearBnActivation1d):
qat_module = cls(
float_module.linear.in_features,
float_module.linear.out_features,
float_module.linear.bias is not None,
float_module.linear.compute_mode,
float_module.bn.eps,
float_module.bn.momentum,
float_module.bn.affine,
float_module.bn.track_running_stats,
name=float_module.name,
)
qat_module.linear.weight = float_module.linear.weight
qat_module.linear.bias = float_module.linear.bias
qat_module.bn = float_module.bn
return qat_module
class LinearBn1d(_LinearBnActivation1d):
r"""A fused :class:`~.QATModule` including :class:`~.module.Linear` and :class:`~.module.BatchNorm1d` with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
"""
def forward(self, inp):
return self.apply_quant_activation(self.calc_linear_bn_qat(inp))
class LinearBnRelu1d(_LinearBnActivation1d):
r"""A fused :class:`~.QATModule` including :class:`~.module.Linear`, :class:`~.module.BatchNorm1d` and :func:`~.relu` with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
"""
def forward(self, inp):
return self.apply_quant_activation(relu(self.calc_linear_bn_qat(inp)))
@@ -4,6 +4,7 @@ from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, ConvTransposeRelu2d
from .conv_bn import ConvBn2d, ConvBnRelu2d
from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d
from .elemwise import Elemwise
from .linear import Linear
from .linear import Linear, LinearRelu
from .linear_bn import LinearBn1d, LinearBnRelu1d
from .module import QuantizedModule
from .quant_dequant import DequantStub, QuantStub
import numpy as np
from ... import functional as F
from ... import module as Float
from ...core.tensor import dtype
from ...tensor import Parameter
from ..qat import linear as QAT
@@ -16,20 +17,30 @@ class Linear(QuantizedModule):
self.bias = None
self.output_dtype = dtype
def forward(self, inp):
def calc_linear_quantized(self, inp, nonlinear_mode="identity"):
if self.training:
raise ValueError("quantized module only support inference.")
assert nonlinear_mode in ["identity", "relu"]
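# The quantized matmul accumulates in int32 with scale inp_scale * w_scale, so
# the bias is cast to qint32 with that combined scale before being added.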
inp_scale = dtype.get_scale(inp.dtype)
w_scale = dtype.get_scale(self.weight.dtype)
bias_dtype = dtype.qint32(inp_scale * w_scale)
ret = F.nn.linear(
ret = F.linear(
inp,
self.weight,
None if self.bias is None else self.bias.astype(bias_dtype),
)
ret = ret if self.output_dtype is None else ret.astype(self.output_dtype)
if nonlinear_mode == "relu":
ret = F.relu(ret)
return ret
def forward(self, inp):
return self.calc_linear_quantized(inp)
@classmethod
def from_qat_module(cls, qat_module: QAT.Linear):
r"""
@@ -38,8 +49,16 @@ class Linear(QuantizedModule):
"""
output_dtype = qat_module.get_activation_dtype()
qmod = cls(dtype=output_dtype, name=qat_module.name)
qmod.name = qat_module.name
weight = qat_module.weight.astype(qat_module.get_weight_dtype())
qmod.weight = Parameter(weight.numpy(), name=qat_module.weight.name)
if qat_module.bias is not None:
qmod.bias = Parameter(qat_module.bias.numpy(), name=qat_module.bias.name)
return qmod
class LinearRelu(Linear):
r"""Quantized version of :class:`~.qat.LinearRelu`."""
def forward(self, inp):
return self.calc_linear_quantized(inp, nonlinear_mode="relu")
from ...tensor import Parameter
from ..qat import linear_bn as QAT
from .linear import Linear
class _LinearBnActivation1d(Linear):
r"""Applies a Linear over a quantized input tensor, used for inference only.
"""
@classmethod
def from_qat_module(cls, qat_module: QAT._LinearBnActivation1d):
r"""
Return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
output_dtype = qat_module.get_activation_dtype()
qlinear = cls(dtype=output_dtype, name=qat_module.name,)
w_fold, b_fold = qat_module.fold_weight_bias(
qat_module.bn.running_mean, qat_module.bn.running_var
)
weight = w_fold.astype(qat_module.get_weight_dtype())
qlinear.weight = Parameter(weight.numpy(), name=qat_module.linear.weight.name)
qlinear.bias = Parameter(b_fold.numpy())
if qat_module.linear.bias is not None:
qlinear.bias.name = qat_module.linear.bias.name
return qlinear
class LinearBn1d(_LinearBnActivation1d):
r"""Quantized version of :class:`~.qat.LinearBn1d`."""
def forward(self, inp):
return self.calc_linear_quantized(inp, nonlinear_mode="identity")
class LinearBnRelu1d(_LinearBnActivation1d):
r"""Quantized version of :class:`~.qat.LinearBnRelu1d`."""
def forward(self, inp):
return self.calc_linear_quantized(inp, nonlinear_mode="relu")
@@ -2,6 +2,7 @@ from copy import deepcopy
from ..functional import ones, sqrt, zeros
from ..module import (
BatchNorm1d,
BatchNorm2d,
Conv2d,
ConvBn2d,
@@ -11,6 +12,10 @@ from ..module import (
ConvTransposeBn2d,
ConvTransposeBnRelu2d,
ConvTransposeRelu2d,
Linear,
LinearBn1d,
LinearBnRelu1d,
LinearRelu,
ReLU,
)
from ..tensor import Parameter
@@ -26,10 +31,15 @@ _MAP_TO_FUSED_MODULE = {
(ConvTranspose2d, BatchNorm2d, False): ConvTranspose2d,
(ConvTranspose2d, BatchNorm2d, True): ConvTransposeBn2d,
(ConvTranspose2d, ReLU): ConvTransposeRelu2d,
(Linear, BatchNorm1d, ReLU, False): LinearRelu,
(Linear, BatchNorm1d, ReLU, True): LinearBnRelu1d,
(Linear, BatchNorm1d, False): Linear,
(Linear, BatchNorm1d, True): LinearBn1d,
(Linear, ReLU): LinearRelu,
}
def fold_weight_bias(
def _fold_conv_bn_weight_bias(
weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False
):
shape = (-1, 1, 1, 1)
@@ -76,6 +86,57 @@ def fold_weight_bias(
return w_fold, b_fold
def _fold_linear_bn_weight_bias(weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5):
bn_mean, bn_var = bn_mean.reshape(-1), bn_var.reshape(-1)
weight_shape = [1] * len(weight.shape)
weight_shape[0] = -1
bias_shape = [1] * len(weight.shape)
bias_shape[1] = -1
out_features = weight.shape[0]
if gamma is None:
gamma = ones((out_features,), dtype="float32")
else:
gamma = gamma.reshape(-1)
if beta is None:
beta = zeros((out_features,), dtype="float32")
else:
beta = beta.reshape(-1)
if bn_mean is None:
bn_mean = zeros((out_features,), dtype="float32")
else:
bn_mean = bn_mean.reshape(-1)
if bn_var is None:
bn_var = ones((out_features,), dtype="float32")
else:
bn_var = bn_var.reshape(-1)
if bias is None:
bias = zeros((beta.shape), dtype="float32")
else:
bias = bias.reshape(-1)
bn_istd = 1.0 / sqrt(bn_var + eps)
scale_factor = gamma * bn_istd
w_fold = weight * scale_factor.reshape(*weight_shape)
b_fold = beta + gamma * (bias - bn_mean) * bn_istd
return w_fold, b_fold
def fold_weight_bias(
weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False
):
if weight.ndim != 2:
return _fold_conv_bn_weight_bias(
weight, bias, gamma, beta, bn_mean, bn_var, eps, transpose
)
return _fold_linear_bn_weight_bias(weight, bias, gamma, beta, bn_mean, bn_var, eps)
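# Illustrative numpy check (a sketch, not part of this module; the helper name is
# hypothetical) of the folding identity implemented by _fold_linear_bn_weight_bias:
#   w_fold = W * (gamma / sqrt(var + eps))[:, None]
#   b_fold = beta + gamma * (b - mean) / sqrt(var + eps)
# so that, in eval mode, bn(x @ W.T + b) == x @ w_fold.T + b_fold.
def _illustrative_numpy_fold_check():
    import numpy as np

    rng = np.random.default_rng(0)
    n, c_in, c_out, eps = 4, 10, 5, 1e-5
    x = rng.standard_normal((n, c_in))
    w = rng.standard_normal((c_out, c_in))
    b = rng.standard_normal(c_out)
    gamma, beta = rng.standard_normal(c_out), rng.standard_normal(c_out)
    mean, var = rng.standard_normal(c_out), rng.random(c_out) + 0.5

    # reference: linear followed by eval-mode batch norm
    bn_out = gamma * ((x @ w.T + b) - mean) / np.sqrt(var + eps) + beta
    # folded parameters
    istd = 1.0 / np.sqrt(var + eps)
    w_fold = w * (gamma * istd)[:, None]
    b_fold = beta + gamma * (b - mean) * istd
    np.testing.assert_allclose(bn_out, x @ w_fold.T + b_fold, rtol=1e-6, atol=1e-8)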
def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
module_key = tuple([type(m) for m in [conv, bn, relu] if m])
if bn:
@@ -137,3 +198,37 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
module.bn = deepcopy(bn)
new_conv.training = conv.training
return module
def fuse_linear_bn_relu_module(linear: Linear, bn: BatchNorm1d, relu: ReLU):
module_key = tuple([type(m) for m in [linear, bn, relu] if m])
if bn:
assert (
linear.training == bn.training
), "Linear and BN both must be in the same mode (train or eval)."
assert (
bn.num_features == linear.out_features
), "Output channel of Linear must match num_features of BatchNorm1d"
module_key = module_key + (linear.training,)
module = _MAP_TO_FUSED_MODULE[module_key](
in_features=linear.in_features,
out_features=linear.out_features,
bias=linear.bias is not None,
compute_mode=linear.compute_mode,
name=linear.name,
)
new_linear = module if bn is None or not linear.training else module.linear
weight, bias = linear.weight, linear.bias
if not linear.training and bn is not None:
weight, bias = fold_weight_bias(
weight, bias, bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps,
)
new_linear.weight = Parameter(weight)
if bias is not None:
new_linear.bias = Parameter(bias)
if bn is not None and linear.training:
module.bn = deepcopy(bn)
new_linear.training = linear.training
return module
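# Usage sketch (illustrative; mirrors test_LinearBn1d_fold_weight_bias in the test
# changes below):
#   linear, bn, act = Linear(10, 5), BatchNorm1d(5), ReLU()
#   linear.eval(); bn.eval()
#   fused = fuse_linear_bn_relu_module(linear, bn, act)
#   # eval mode with BN -> returns a LinearRelu whose weight/bias already absorb
#   # the BN statistics, so fused(x) matches act(bn(linear(x))) up to float error.
#   # In training mode with a BN, a LinearBn[Relu]1d holding a deepcopy of bn is
#   # returned instead, so batch statistics keep updating.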
@@ -19,6 +19,10 @@ from megengine.module import (
ConvTransposeBn2d,
ConvTransposeRelu2d,
DequantStub,
Linear,
LinearBn1d,
LinearBnRelu1d,
LinearRelu,
Module,
QuantStub,
)
@@ -330,3 +334,127 @@ def test_qat_conv_transpose2d():
np.testing.assert_allclose(
normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6
)
def test_qat_linearbn1d():
in_features = 10
out_features = 5
class TestNet(Module):
def __init__(self, bias):
super().__init__()
self.quant = QuantStub()
self.dequant = DequantStub()
self.linear_bn = LinearBn1d(in_features, out_features, bias=bias,)
def forward(self, inp):
out = self.quant(inp)
out = self.linear_bn(out)
out = self.dequant(out)
return out
inputs = tensor(np.random.randn(4, in_features).astype(np.float32))
for bias in [True, False]:
net = TestNet(bias)
net.train()
qat_net = quantize_qat(net, inplace=False)
disable_fake_quant(qat_net)
normal_outputs = net(inputs)
qat_outputs = qat_net(inputs)
np.testing.assert_allclose(
normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6,
)
np.testing.assert_allclose(
net.linear_bn.bn.running_mean.numpy(),
qat_net.linear_bn.bn.running_mean.numpy(),
atol=5e-8,
)
np.testing.assert_allclose(
net.linear_bn.bn.running_var.numpy(),
qat_net.linear_bn.bn.running_var.numpy(),
atol=5e-7,
)
net.eval()
normal_outputs = net(inputs)
qat_net.eval()
qat_outputs = qat_net(inputs)
np.testing.assert_allclose(
normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6,
)
def test_qat_linear_relu():
in_features = 10
out_features = 5
class TestNet(Module):
def __init__(self, bias):
super().__init__()
self.quant = QuantStub()
self.dequant = DequantStub()
self.linear_relu = LinearRelu(in_features, out_features, bias=bias,)
def forward(self, inp):
out = self.quant(inp)
out = self.linear_relu(out)
out = self.dequant(out)
return out
inputs = tensor(np.random.randn(4, in_features).astype(np.float32))
for bias in [True, False]:
net = TestNet(bias)
net.train()
qat_net = quantize_qat(net, inplace=False)
disable_fake_quant(qat_net)
normal_outputs = net(inputs)
qat_outputs = qat_net(inputs)
np.testing.assert_allclose(
normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6,
)
net.eval()
normal_outputs = net(inputs)
qat_net.eval()
qat_outputs = qat_net(inputs)
np.testing.assert_allclose(
normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6,
)
def test_qat_linear_bn_relu():
in_features = 10
out_features = 5
class TestNet(Module):
def __init__(self, bias):
super().__init__()
self.quant = QuantStub()
self.dequant = DequantStub()
self.linear_bn_relu = LinearBnRelu1d(in_features, out_features, bias=bias,)
def forward(self, inp):
out = self.quant(inp)
out = self.linear_bn_relu(out)
out = self.dequant(out)
return out
inputs = tensor(np.random.randn(4, in_features).astype(np.float32))
for bias in [True, False]:
net = TestNet(bias)
net.train()
qat_net = quantize_qat(net, inplace=False)
disable_fake_quant(qat_net)
normal_outputs = net(inputs)
qat_outputs = qat_net(inputs)
np.testing.assert_allclose(
normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6,
)
net.eval()
normal_outputs = net(inputs)
qat_net.eval()
qat_outputs = qat_net(inputs)
np.testing.assert_allclose(
normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6,
)
@@ -5,11 +5,14 @@ from megengine import Parameter, Tensor
from megengine import module as Float
from megengine.functional import ones, zeros
from megengine.module import (
BatchNorm1d,
BatchNorm2d,
Conv2d,
ConvBn2d,
ConvTranspose2d,
ConvTransposeBn2d,
Linear,
LinearBn1d,
ReLU,
)
from megengine.module import qat as QAT
@@ -33,7 +36,10 @@ from megengine.quantization.quantize import (
quantize_qat,
reset_qconfig,
)
from megengine.utils.bn_fusion import fuse_conv_bn_relu_module
from megengine.utils.bn_fusion import (
fuse_conv_bn_relu_module,
fuse_linear_bn_relu_module,
)
class FloatNet(Float.Module):
@@ -383,3 +389,43 @@ def test_ConvTransposeBn2d_fold_weight_bias():
np.testing.assert_allclose(
expected_result.numpy(), actual_result.numpy(), atol=1e-4
)
def test_LinearBn1d_fold_weight_bias():
in_features = 10
out_features = 5
linear = Linear(in_features, out_features)
bn = BatchNorm1d(out_features)
relu = ReLU()
fused_linear = fuse_linear_bn_relu_module(linear, bn, relu)
bn.eval()
fused_linear.eval()
inputs = Tensor(np.random.randn(4, in_features).astype(np.float32))
expected_result = relu(bn(linear(inputs)))
actual_result = fused_linear(inputs)
np.testing.assert_allclose(
expected_result.numpy(), actual_result.numpy(), atol=1e-4
)
linear.eval()
bn.eval()
relu.eval()
fused_linear = fuse_linear_bn_relu_module(linear, bn, relu)
fused_linear.eval()
expected_result = relu(linear(inputs))
actual_result = fused_linear(inputs)
np.testing.assert_allclose(
expected_result.numpy(), actual_result.numpy(), atol=1e-4
)
linear.train()
bn.train()
fused_linear = fuse_linear_bn_relu_module(linear, bn, None)
fused_linear.train()
expected_result = bn(linear(inputs))
actual_result = fused_linear(inputs)
np.testing.assert_allclose(
expected_result.numpy(), actual_result.numpy(), atol=1e-4
)