From bb58ea35b992151fb66023ce8ae3aa78af0d67be Mon Sep 17 00:00:00 2001
From: chenzomi
Date: Wed, 10 Jun 2020 14:30:08 +0800
Subject: [PATCH] bug fix in fake quant training in r0.3

---
 mindspore/nn/layer/quant.py    | 7 ++-----
 mindspore/train/quant/quant.py | 2 +-
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index 21e45f517..689f1c495 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -118,7 +118,6 @@ class FakeQuantWithMinMax(Cell):
         quant_delay (int): Quantization delay parameters according by global step. Default: 0.
         symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
         narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
-        training (bool): Quantization algorithm training or not. Default: True.
 
     Inputs:
         - **x** (Tensor) - The input of FakeQuantWithMinMax.
@@ -143,8 +142,7 @@ class FakeQuantWithMinMax(Cell):
                  out_channels=1,
                  quant_delay=0,
                  symmetric=False,
-                 narrow_range=False,
-                 training=True):
+                 narrow_range=False):
         """init FakeQuantWithMinMax layer"""
         super(FakeQuantWithMinMax, self).__init__()
         self.min_init = min_init
@@ -158,7 +156,6 @@ class FakeQuantWithMinMax(Cell):
         self.quant_delay = quant_delay
         self.symmetric = symmetric
         self.narrow_range = narrow_range
-        self.training = training
         self.is_ascend = context.get_context('device_target') == "Ascend"
 
         # init tensor min and max for fake quant op
@@ -208,7 +205,7 @@ class FakeQuantWithMinMax(Cell):
         return s
 
     def construct(self, x):
-        if self.ema and self.is_ascend:
+        if self.is_ascend and self.training:
             min_up, max_up = self.ema_update(x, self.minq, self.maxq)
             out = self.fake_quant(x, min_up, max_up)
             P.Assign()(self.minq, min_up)
diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py
index e2a035bc7..ff4042693 100644
--- a/mindspore/train/quant/quant.py
+++ b/mindspore/train/quant/quant.py
@@ -247,7 +247,7 @@ def convert_quant_network(network,
         network (Cell): Obtain a pipeline through network for saving graph summary.
         quant_delay (int): Number of steps after which weights and activations are quantized during eval. Default: 0.
         bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False.
-        freeze_bn (bool): Number of steps after which BN parameters used total mean and variance. Default: 0.
+        freeze_bn (int): Number of steps after which BN parameters used total mean and variance. Default: 0.
         weight_bits (int): Number of bits to use for quantizing weights. Default: 8.
         act_bits (int): Number of bits to use for quantizing activations. Default: 8.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
-- 
GitLab
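
The patch drops the redundant training constructor argument because mindspore.nn.Cell
already carries a training flag, toggled through Cell.set_train(), and construct now
reads that built-in flag instead. A minimal sketch of how the fixed behavior is driven
from user code; the stand-in network Net is hypothetical, and the import path for
convert_quant_network is inferred from the file mindspore/train/quant/quant.py touched
above:

    from mindspore import nn
    from mindspore.train.quant.quant import convert_quant_network

    class Net(nn.Cell):
        """Hypothetical stand-in network, for illustration only."""
        def __init__(self):
            super(Net, self).__init__()
            self.conv = nn.Conv2d(3, 6, 5)
            self.bn = nn.BatchNorm2d(6)
            self.relu = nn.ReLU()

        def construct(self, x):
            return self.relu(self.bn(self.conv(x)))

    # Keyword arguments mirror the convert_quant_network docstring in the diff.
    net = convert_quant_network(Net(), quant_delay=0, bn_fold=False, freeze_bn=0,
                                weight_bits=8, act_bits=8, per_channel=False)

    # Sets Cell.training = True: per the fixed construct(), FakeQuantWithMinMax
    # on Ascend updates its EMA min/max before fake-quantizing.
    net.set_train(True)

    # Sets Cell.training = False: the stored min/max are used without updating.
    net.set_train(False)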