提交 521f1e59 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3457 add LeakReLUQuant OP for bug fix.

Merge pull request !3457 from chenzhongming/master
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Quantization aware."""
"""Quantization aware training."""
from functools import partial
import numpy as np
......@@ -43,6 +43,7 @@ __all__ = [
'Conv2dQuant',
'DenseQuant',
'ActQuant',
'LeakyReLUQuant',
'HSwishQuant',
'HSigmoidQuant',
'TensorAddQuant',
......@@ -349,7 +350,7 @@ class FakeQuantWithMinMax(Cell):
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
# init fake quant relative op
if per_channel:
if self.per_channel:
quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis)
ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis)
else:
......@@ -369,7 +370,7 @@ class FakeQuantWithMinMax(Cell):
num_bits=self.num_bits,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
quant_delay=quant_delay)
quant_delay=self.quant_delay)
self.fake_quant_train = quant_fun(training=True)
self.fake_quant_infer = quant_fun(training=False)
......@@ -832,7 +833,7 @@ class ActQuant(_QuantActivation):
Tensor, with the same type and shape as the `x`.
Examples:
>>> act_quant = nn.ActQuant(4, 1)
>>> act_quant = nn.ActQuant(nn.ReLU)
>>> input_x = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mindspore.float32)
>>> result = act_quant(input_x)
"""
......@@ -855,7 +856,7 @@ class ActQuant(_QuantActivation):
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
self.act = activation()
self.act = activation
def construct(self, x):
x = self.act(x)
......@@ -865,6 +866,75 @@ class ActQuant(_QuantActivation):
def get_origin(self):
return self.act
class LeakyReLUQuant(_QuantActivation):
r"""
LeakyReLUQuant activation function. Add Fake Quant OP after HSwish OP.
For a more Detailed overview of HSwish op.
Args:
activation (Cell): Activation cell class.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs:
- **x** (Tensor) - The input of HSwishQuant.
Outputs:
Tensor, with the same type and shape as the `x`.
Examples:
>>> activation = nn.LeakyReLUQuant(nn.LeakyReLU())
>>> input = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = activation(input)
"""
def __init__(self,
activation,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
super(LeakyReLUQuant, self).__init__()
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
if issubclass(activation.__class__, nn.LeakyReLU):
self.act = activation
else:
raise ValueError("Activation should be `nn.LeakyReLU`")
def construct(self, x):
x = self.fake_quant_act_before(x)
x = self.act(x)
x = self.fake_quant_act_after(x)
return x
def get_origin(self):
return self.act
class HSwishQuant(_QuantActivation):
r"""
......@@ -888,9 +958,9 @@ class HSwishQuant(_QuantActivation):
Tensor, with the same type and shape as the `x`.
Examples:
>>> hswish_quant = nn.HSwishQuant(4, 1)
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = hswish_quant(input_x)
>>> activation = nn.HSwishQuant(nn.HSwish())
>>> input = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = activation(input)
"""
def __init__(self,
......@@ -920,8 +990,8 @@ class HSwishQuant(_QuantActivation):
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
if issubclass(activation, nn.HSwish):
self.act = activation()
if issubclass(activation.__class__, nn.HSwish):
self.act = activation
else:
raise ValueError("Activation should be `nn.HSwish`")
......@@ -957,9 +1027,9 @@ class HSigmoidQuant(_QuantActivation):
Tensor, with the same type and shape as the `x`.
Examples:
>>> hsigmoid_quant = nn.HSigmoidQuant(4, 1)
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = hsigmoid_quant(input_x)
>>> activation = nn.HSigmoidQuant(nn.HSigmoid())
>>> input = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = activation(input)
"""
def __init__(self,
......@@ -989,8 +1059,8 @@ class HSigmoidQuant(_QuantActivation):
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
if issubclass(activation, nn.HSigmoid):
self.act = activation()
if issubclass(activation.__class__, nn.HSigmoid):
self.act = activation
else:
raise ValueError("Activation should be `nn.HSigmoid`")
......
......@@ -386,6 +386,8 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
if not self.is_ascend:
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
if len(x_shape) == 1:
self.channel_axis = 0
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
validator.check_integer(
"min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
......
......@@ -35,8 +35,8 @@ from . import quant_utils
_ACTIVATION_MAP = {nn.ReLU: quant.ActQuant,
nn.ReLU6: quant.ActQuant,
nn.LeakyReLU: quant.ActQuant,
nn.Sigmoid: quant.ActQuant,
nn.LeakyReLU: quant.LeakyReLUQuant,
nn.HSigmoid: quant.HSigmoidQuant,
nn.HSwish: quant.HSwishQuant}
......@@ -167,32 +167,35 @@ class ConvertToQuantNetwork:
convert Conv2d cell to quant cell
"""
conv_inner = subcell.conv
if subcell.has_bn and self.bn_fold:
bn_inner = subcell.batchnorm
conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels,
conv_inner.out_channels,
kernel_size=conv_inner.kernel_size,
stride=conv_inner.stride,
pad_mode=conv_inner.pad_mode,
padding=conv_inner.padding,
dilation=conv_inner.dilation,
group=conv_inner.group,
eps=bn_inner.eps,
quant_delay=self.weight_qdelay,
freeze_bn=self.freeze_bn,
per_channel=self.weight_channel,
num_bits=self.weight_bits,
fake=True,
symmetric=self.weight_symmetric,
narrow_range=self.weight_range)
# change original network BatchNormal OP parameters to quant network
conv_inner.gamma = subcell.batchnorm.gamma
conv_inner.beta = subcell.batchnorm.beta
conv_inner.moving_mean = subcell.batchnorm.moving_mean
conv_inner.moving_variance = subcell.batchnorm.moving_variance
del subcell.batchnorm
subcell.batchnorm = None
subcell.has_bn = False
if subcell.has_bn:
if self.bn_fold:
bn_inner = subcell.batchnorm
conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels,
conv_inner.out_channels,
kernel_size=conv_inner.kernel_size,
stride=conv_inner.stride,
pad_mode=conv_inner.pad_mode,
padding=conv_inner.padding,
dilation=conv_inner.dilation,
group=conv_inner.group,
eps=bn_inner.eps,
quant_delay=self.weight_qdelay,
freeze_bn=self.freeze_bn,
per_channel=self.weight_channel,
num_bits=self.weight_bits,
fake=True,
symmetric=self.weight_symmetric,
narrow_range=self.weight_range)
# change original network BatchNormal OP parameters to quant network
conv_inner.gamma = subcell.batchnorm.gamma
conv_inner.beta = subcell.batchnorm.beta
conv_inner.moving_mean = subcell.batchnorm.moving_mean
conv_inner.moving_variance = subcell.batchnorm.moving_variance
del subcell.batchnorm
subcell.batchnorm = None
subcell.has_bn = False
else:
raise ValueError("Only support Batchnorm fold mode.")
else:
conv_inner = quant.Conv2dQuant(conv_inner.in_channels,
conv_inner.out_channels,
......@@ -259,7 +262,7 @@ class ConvertToQuantNetwork:
act_class = activation.__class__
if act_class not in _ACTIVATION_MAP:
raise ValueError("Unsupported activation in auto quant: ", act_class)
return _ACTIVATION_MAP[act_class](activation=act_class,
return _ACTIVATION_MAP[act_class](activation=activation,
num_bits=self.act_bits,
quant_delay=self.act_qdelay,
per_channel=self.act_channel,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册