Commit 9944abe9 authored by mindspore-ci-bot, committed by Gitee

!1963 bug fix in fake quant training in r0.3

Merge pull request !1963 from chenzhongming/r0.3
@@ -118,7 +118,6 @@ class FakeQuantWithMinMax(Cell):
         quant_delay (int): Quantization delay parameters according by global step. Default: 0.
         symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
         narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
-        training (bool): Quantization algorithm training or not. Default: True.

     Inputs:
         - **x** (Tensor) - The input of FakeQuantWithMinMax.
@@ -143,8 +142,7 @@ class FakeQuantWithMinMax(Cell):
                  out_channels=1,
                  quant_delay=0,
                  symmetric=False,
-                 narrow_range=False,
-                 training=True):
+                 narrow_range=False):
         """init FakeQuantWithMinMax layer"""
         super(FakeQuantWithMinMax, self).__init__()
         self.min_init = min_init
@@ -158,7 +156,6 @@ class FakeQuantWithMinMax(Cell):
         self.quant_delay = quant_delay
         self.symmetric = symmetric
         self.narrow_range = narrow_range
-        self.training = training
         self.is_ascend = context.get_context('device_target') == "Ascend"
         # init tensor min and max for fake quant op
@@ -208,7 +205,7 @@ class FakeQuantWithMinMax(Cell):
         return s

     def construct(self, x):
-        if self.ema and self.is_ascend:
+        if self.is_ascend and self.training:
             min_up, max_up = self.ema_update(x, self.minq, self.maxq)
             out = self.fake_quant(x, min_up, max_up)
             P.Assign()(self.minq, min_up)
......
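The corrected branch keys the EMA range update off the Cell's built-in `self.training` flag instead of the removed `training` argument. A minimal NumPy sketch of what that training-time branch computes (the decay value and helper names here are illustrative, not taken from the patch):

```python
import numpy as np

def ema_update(x, min_q, max_q, decay=0.999):
    # Exponential moving average of the observed range; `decay` is an
    # assumed value, MindSpore takes it from the ema_decay argument.
    min_up = decay * min_q + (1.0 - decay) * float(x.min())
    max_up = decay * max_q + (1.0 - decay) * float(x.max())
    return min_up, max_up

def fake_quant(x, min_v, max_v, num_bits=8, narrow_range=False):
    # Quantize-dequantize x onto a uniform grid spanning [min_v, max_v].
    quant_min = 1 if narrow_range else 0
    quant_max = 2 ** num_bits - 1
    scale = (max_v - min_v) / (quant_max - quant_min)
    zero_point = quant_min - round(min_v / scale)
    q = np.clip(np.round(x / scale) + zero_point, quant_min, quant_max)
    return (q - zero_point) * scale

# One step on the "Ascend and training" branch: update the running range
# from the batch first, then fake-quantize with the updated range.
x = np.array([-0.4, 0.1, 0.9])
min_up, max_up = ema_update(x, min_q=-1.0, max_q=1.0)
out = fake_quant(x, min_up, max_up)
```

In eval mode (`self.training` is False) the layer would skip the EMA update and quantize with the frozen min/max, which is why tying the branch to `self.training` rather than `self.ema` fixes the fake quant training behavior.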
@@ -247,7 +247,7 @@ def convert_quant_network(network,
         network (Cell): Obtain a pipeline through network for saving graph summary.
         quant_delay (int): Number of steps after which weights and activations are quantized during eval. Default: 0.
         bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False.
-        freeze_bn (bool): Number of steps after which BN parameters used total mean and variance. Default: 0.
+        freeze_bn (int): Number of steps after which BN parameters used total mean and variance. Default: 0.
         weight_bits (int): Number of bits to use for quantizing weights. Default: 8.
         act_bits (int): Number of bits to use for quantizing activations. Default: 8.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
......
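For reference, here is a hedged sketch of how the `symmetric` and `narrow_range` flags documented above are commonly interpreted when deriving per-layer quantization parameters (an illustrative reading, not MindSpore's exact kernel):

```python
def quant_params(min_v, max_v, num_bits=8, symmetric=False, narrow_range=False):
    # symmetric: force the range to be centered on zero.
    if symmetric:
        bound = max(abs(min_v), abs(max_v))
        min_v, max_v = -bound, bound
    # narrow_range: drop the lowest code so the grid has 2**bits - 1 levels.
    quant_min = 1 if narrow_range else 0
    quant_max = 2 ** num_bits - 1
    scale = (max_v - min_v) / (quant_max - quant_min)
    zero_point = quant_min - round(min_v / scale)
    return scale, zero_point

scale, zero_point = quant_params(-1.0, 1.0, symmetric=True, narrow_range=True)
```

With a symmetric narrow range, the zero point lands exactly on the middle code, which is what makes the symmetric mode attractive for weight quantization.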