diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index 6c850e6c708342584dfe25e2ca9536c2d6bd7026..cde81a764414aa64eea6c8b92872a264f7a3da69 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""Quantization aware."""
+"""Quantization aware training."""
 
 from functools import partial
 import numpy as np
@@ -43,6 +43,7 @@ __all__ = [
     'Conv2dQuant',
     'DenseQuant',
     'ActQuant',
+    'LeakyReLUQuant',
     'HSwishQuant',
     'HSigmoidQuant',
     'TensorAddQuant',
@@ -349,7 +350,7 @@ class FakeQuantWithMinMax(Cell):
         self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
 
         # init fake quant relative op
-        if per_channel:
+        if self.per_channel:
             quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis)
             ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis)
         else:
@@ -369,7 +370,7 @@ class FakeQuantWithMinMax(Cell):
                             num_bits=self.num_bits,
                             symmetric=self.symmetric,
                             narrow_range=self.narrow_range,
-                            quant_delay=quant_delay)
+                            quant_delay=self.quant_delay)
         self.fake_quant_train = quant_fun(training=True)
         self.fake_quant_infer = quant_fun(training=False)
 
@@ -832,7 +833,7 @@ class ActQuant(_QuantActivation):
         Tensor, with the same type and shape as the `x`.
 
     Examples:
-        >>> act_quant = nn.ActQuant(4, 1)
+        >>> act_quant = nn.ActQuant(nn.ReLU())
         >>> input_x = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mindspore.float32)
         >>> result = act_quant(input_x)
     """
@@ -855,7 +856,7 @@ class ActQuant(_QuantActivation):
                                              symmetric=symmetric,
                                              narrow_range=narrow_range,
                                              quant_delay=quant_delay)
-        self.act = activation()
+        self.act = activation
 
     def construct(self, x):
         x = self.act(x)
@@ -865,6 +866,75 @@ class ActQuant(_QuantActivation):
         x = self.fake_quant(x)
         return x
 
     def get_origin(self):
         return self.act
 
+
+class LeakyReLUQuant(_QuantActivation):
+    r"""
+    LeakyReLUQuant activation function. Add Fake Quant OPs before and after LeakyReLU OP.
+
+    For a more detailed overview of LeakyReLU op.
+
+    Args:
+        activation (Cell): Activation cell instance, which must be `nn.LeakyReLU`.
+        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
+        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
+        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
+        symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
+        narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
+        quant_delay (int): Quantization delay parameter according to the global step. Default: 0.
+
+    Inputs:
+        - **x** (Tensor) - The input of LeakyReLUQuant.
+
+    Outputs:
+        Tensor, with the same type and shape as the `x`.
+
+    Examples:
+        >>> activation = nn.LeakyReLUQuant(nn.LeakyReLU())
+        >>> input = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
+        >>> result = activation(input)
+    """
+
+    def __init__(self,
+                 activation,
+                 ema_decay=0.999,
+                 per_channel=False,
+                 num_bits=8,
+                 symmetric=False,
+                 narrow_range=False,
+                 quant_delay=0):
+        super(LeakyReLUQuant, self).__init__()
+        self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
+                                                         max_init=6,
+                                                         ema=True,
+                                                         ema_decay=ema_decay,
+                                                         per_channel=per_channel,
+                                                         num_bits=num_bits,
+                                                         symmetric=symmetric,
+                                                         narrow_range=narrow_range,
+                                                         quant_delay=quant_delay)
+        self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
+                                                        max_init=6,
+                                                        ema=True,
+                                                        ema_decay=ema_decay,
+                                                        per_channel=per_channel,
+                                                        num_bits=num_bits,
+                                                        symmetric=symmetric,
+                                                        narrow_range=narrow_range,
+                                                        quant_delay=quant_delay)
+        if issubclass(activation.__class__, nn.LeakyReLU):
+            self.act = activation
+        else:
+            raise ValueError("Activation should be `nn.LeakyReLU`")
+
+    def construct(self, x):
+        x = self.fake_quant_act_before(x)
+        x = self.act(x)
+        x = self.fake_quant_act_after(x)
+        return x
+
+    def get_origin(self):
+        return self.act
+
+
 class HSwishQuant(_QuantActivation):
     r"""
@@ -888,9 +958,9 @@ class HSwishQuant(_QuantActivation):
         Tensor, with the same type and shape as the `x`.
 
     Examples:
-        >>> hswish_quant = nn.HSwishQuant(4, 1)
-        >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
-        >>> result = hswish_quant(input_x)
+        >>> activation = nn.HSwishQuant(nn.HSwish())
+        >>> input = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
+        >>> result = activation(input)
     """
 
     def __init__(self,
@@ -920,8 +990,8 @@ class HSwishQuant(_QuantActivation):
                                                         symmetric=symmetric,
                                                         narrow_range=narrow_range,
                                                         quant_delay=quant_delay)
-        if issubclass(activation, nn.HSwish):
-            self.act = activation()
+        if issubclass(activation.__class__, nn.HSwish):
+            self.act = activation
         else:
             raise ValueError("Activation should be `nn.HSwish`")
 
@@ -957,9 +1027,9 @@ class HSigmoidQuant(_QuantActivation):
         Tensor, with the same type and shape as the `x`.
 
     Examples:
-        >>> hsigmoid_quant = nn.HSigmoidQuant(4, 1)
-        >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
-        >>> result = hsigmoid_quant(input_x)
+        >>> activation = nn.HSigmoidQuant(nn.HSigmoid())
+        >>> input = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
+        >>> result = activation(input)
     """
 
    def __init__(self,
@@ -989,8 +1059,8 @@ class HSigmoidQuant(_QuantActivation):
                                                         symmetric=symmetric,
                                                         narrow_range=narrow_range,
                                                         quant_delay=quant_delay)
-        if issubclass(activation, nn.HSigmoid):
-            self.act = activation()
+        if issubclass(activation.__class__, nn.HSigmoid):
+            self.act = activation
         else:
             raise ValueError("Activation should be `nn.HSigmoid`")
 
diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py
index 1f4de03d3cf29ff14f3bc20fea119fefc8043fff..d34e322bb5d99f03272dc44239c91a29b5f9c317 100644
--- a/mindspore/ops/operations/_quant_ops.py
+++ b/mindspore/ops/operations/_quant_ops.py
@@ -386,6 +386,8 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
             raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
         if not self.is_ascend:
             validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
+        if len(x_shape) == 1:
+            self.channel_axis = 0
         validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
         validator.check_integer(
             "min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py
index 4048525029ef913d4a0de799f192d9347cd2586b..d73dce25ae457156d1988ad2892dab7c2f3b0abe 100644
--- a/mindspore/train/quant/quant.py
+++ b/mindspore/train/quant/quant.py
@@ -35,8 +35,8 @@ from . import quant_utils
 
 _ACTIVATION_MAP = {nn.ReLU: quant.ActQuant,
                    nn.ReLU6: quant.ActQuant,
-                   nn.LeakyReLU: quant.ActQuant,
                    nn.Sigmoid: quant.ActQuant,
+                   nn.LeakyReLU: quant.LeakyReLUQuant,
                    nn.HSigmoid: quant.HSigmoidQuant,
                    nn.HSwish: quant.HSwishQuant}
 
@@ -167,32 +167,35 @@ class ConvertToQuantNetwork:
         convert Conv2d cell to quant cell
         """
         conv_inner = subcell.conv
-        if subcell.has_bn and self.bn_fold:
-            bn_inner = subcell.batchnorm
-            conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels,
-                                                    conv_inner.out_channels,
-                                                    kernel_size=conv_inner.kernel_size,
-                                                    stride=conv_inner.stride,
-                                                    pad_mode=conv_inner.pad_mode,
-                                                    padding=conv_inner.padding,
-                                                    dilation=conv_inner.dilation,
-                                                    group=conv_inner.group,
-                                                    eps=bn_inner.eps,
-                                                    quant_delay=self.weight_qdelay,
-                                                    freeze_bn=self.freeze_bn,
-                                                    per_channel=self.weight_channel,
-                                                    num_bits=self.weight_bits,
-                                                    fake=True,
-                                                    symmetric=self.weight_symmetric,
-                                                    narrow_range=self.weight_range)
-            # change original network BatchNormal OP parameters to quant network
-            conv_inner.gamma = subcell.batchnorm.gamma
-            conv_inner.beta = subcell.batchnorm.beta
-            conv_inner.moving_mean = subcell.batchnorm.moving_mean
-            conv_inner.moving_variance = subcell.batchnorm.moving_variance
-            del subcell.batchnorm
-            subcell.batchnorm = None
-            subcell.has_bn = False
+        if subcell.has_bn:
+            if self.bn_fold:
+                bn_inner = subcell.batchnorm
+                conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels,
+                                                        conv_inner.out_channels,
+                                                        kernel_size=conv_inner.kernel_size,
+                                                        stride=conv_inner.stride,
+                                                        pad_mode=conv_inner.pad_mode,
+                                                        padding=conv_inner.padding,
+                                                        dilation=conv_inner.dilation,
+                                                        group=conv_inner.group,
+                                                        eps=bn_inner.eps,
+                                                        quant_delay=self.weight_qdelay,
+                                                        freeze_bn=self.freeze_bn,
+                                                        per_channel=self.weight_channel,
+                                                        num_bits=self.weight_bits,
+                                                        fake=True,
+                                                        symmetric=self.weight_symmetric,
+                                                        narrow_range=self.weight_range)
+                # change original network BatchNormal OP parameters to quant network
+                conv_inner.gamma = subcell.batchnorm.gamma
+                conv_inner.beta = subcell.batchnorm.beta
+                conv_inner.moving_mean = subcell.batchnorm.moving_mean
+                conv_inner.moving_variance = subcell.batchnorm.moving_variance
+                del subcell.batchnorm
+                subcell.batchnorm = None
+                subcell.has_bn = False
+            else:
+                raise ValueError("Only Batchnorm fold mode is supported.")
         else:
             conv_inner = quant.Conv2dQuant(conv_inner.in_channels,
                                            conv_inner.out_channels,
@@ -259,7 +262,7 @@
         act_class = activation.__class__
         if act_class not in _ACTIVATION_MAP:
             raise ValueError("Unsupported activation in auto quant: ", act_class)
-        return _ACTIVATION_MAP[act_class](activation=act_class,
+        return _ACTIVATION_MAP[act_class](activation=activation,
                                           num_bits=self.act_bits,
                                           quant_delay=self.act_qdelay,
                                           per_channel=self.act_channel,
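
A minimal usage sketch of the new cell (not part of the patch), based on the docstring example above. It assumes a MindSpore build that already contains this change, so that mindspore.nn exposes LeakyReLUQuant:

import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor

# Wrap an existing LeakyReLU *instance* (not the class) with fake-quant cells.
act_quant = nn.LeakyReLUQuant(nn.LeakyReLU(), num_bits=8, quant_delay=0)
x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
result = act_quant(x)            # fake quant -> LeakyReLU -> fake quant
origin = act_quant.get_origin()  # recovers the wrapped nn.LeakyReLU cell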
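
The issubclass(activation.__class__, ...) changes above switch these wrappers from accepting an activation class to accepting an activation instance, matching the activation=activation argument that _convert_activation now passes. A small plain-Python illustration (stand-in class, for explanation only) of why the old check breaks once an instance is passed:

class LeakyReLU:        # stand-in for nn.LeakyReLU, illustration only
    pass

act = LeakyReLU()       # an instance, as _convert_activation now supplies

# Old check: issubclass(act, LeakyReLU) raises
#   TypeError: issubclass() arg 1 must be a class
# New check accepts instances (equivalent to isinstance(act, LeakyReLU)):
assert issubclass(act.__class__, LeakyReLU)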