diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index ae01cab882206c2f33a86bba65ff6bc106414206..b843986b0e7d6de8b2c15b7c62b0f55b8acb52f3 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -178,23 +178,19 @@ class FakeQuantWithMinMax(Cell): if self.is_ascend: self.fake_quant = quant_fun(num_bits=self.num_bits, symmetric=self.symmetric, - narrow_range=self.narrow_range, - training=self.training) + narrow_range=self.narrow_range) else: self.fake_quant = quant_fun(num_bits=self.num_bits, ema=self.ema, ema_decay=ema_decay, quant_delay=quant_delay, symmetric=self.symmetric, - narrow_range=self.narrow_range, - training=self.training) - if self.training: - self.ema_update = ema_fun(num_bits=self.num_bits, - ema=self.ema, - ema_decay=self.ema_decay, - symmetric=self.symmetric, - narrow_range=self.narrow_range, - training=self.training) + narrow_range=self.narrow_range) + self.ema_update = ema_fun(num_bits=self.num_bits, + ema=self.ema, + ema_decay=self.ema_decay, + symmetric=self.symmetric, + narrow_range=self.narrow_range) def extend_repr(self): s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \