From ece74c4cd49189480bb196904e2577eba7653990 Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Thu, 10 Sep 2020 19:11:18 +0800 Subject: [PATCH] Update the _get_fake_quant_type definition in imperative QAT. (#27222) --- .../slim/quantization/imperative/qat.py | 1 - .../slim/quantization/imperative/quant_nn.py | 49 ++++++++++++++----- 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index 5662284483..8d399c9290 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -192,7 +192,6 @@ class ImperativeQuantAware(object): assert len(input_dtype) == len( feed), "The length of input_shape should be equal to feed's." - prog_trans = dygraph.ProgramTranslator() with dygraph.guard(): model.eval() input_vars = [] diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py b/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py index 59dd9867ab..e22c980b0a 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py @@ -209,15 +209,24 @@ class FakeQuantAbsMax(layers.Layer): return quant_out -def _get_fake_quant_type(quant_type, name, moving_rate, quant_bits, dtype, - quant_on_weight): +def _get_fake_quant_type(quant_type, **kwargs): + call_args = { + "name": kwargs.get("name", None), + "quant_bits": kwargs.get("quant_bits", 8), + "dtype": kwargs.get("dtype", "float32") + } + + if quant_type == 'abs_max': + call_args["quant_on_weight"] = kwargs.get("quant_on_weight", False) + elif quant_type == 'moving_average_abs_max': + call_args["moving_rate"] = kwargs.get("moving_rate", 0.9) + fake_quant_map = { - 'abs_max': - lambda: FakeQuantAbsMax(name, quant_bits, dtype, quant_on_weight), - 'moving_average_abs_max': - lambda: FakeQuantMovingAverage(name, moving_rate, quant_bits, dtype) + 'abs_max': FakeQuantAbsMax, + 'moving_average_abs_max': FakeQuantMovingAverage } - return fake_quant_map[quant_type]() + + return fake_quant_map[quant_type](**call_args) class QuantizedConv2D(layers.Layer): @@ -247,11 +256,18 @@ class QuantizedConv2D(layers.Layer): self.bias = getattr(layer, 'bias') # For FakeQuant self._fake_quant_weight = _get_fake_quant_type( - weight_quantize_type, self.weight.name, moving_rate, weight_bits, - self._dtype, True) + weight_quantize_type, + name=self.weight.name, + moving_rate=moving_rate, + quant_bits=weight_bits, + dtype=self._dtype, + quant_on_weight=True) self._fake_quant_input = _get_fake_quant_type( activation_quantize_type, - layer.full_name(), moving_rate, activation_bits, self._dtype, False) + name=layer.full_name(), + moving_rate=moving_rate, + quant_bits=activation_bits, + dtype=self._dtype) def forward(self, input): quant_input = self._fake_quant_input(input) @@ -326,11 +342,18 @@ class QuantizedLinear(layers.Layer): self.bias = getattr(layer, 'bias') # For FakeQuant self._fake_quant_weight = _get_fake_quant_type( - weight_quantize_type, self.weight.name, moving_rate, weight_bits, - self._dtype, True) + weight_quantize_type, + name=self.weight.name, + moving_rate=moving_rate, + quant_bits=weight_bits, + dtype=self._dtype, + quant_on_weight=True) self._fake_quant_input = _get_fake_quant_type( activation_quantize_type, - layer.full_name(), moving_rate, activation_bits, self._dtype, False) + name=layer.full_name(), + moving_rate=moving_rate, + quant_bits=activation_bits, + dtype=self._dtype) def forward(self, input): quant_input = self._fake_quant_input(input) -- GitLab