Unverified · Commit ece74c4c authored by Zhen Wang, committed by GitHub

Update the _get_fake_quant_type definition in imperative QAT. (#27222)

Parent f6be5989
```diff
@@ -192,7 +192,6 @@ class ImperativeQuantAware(object):
         assert len(input_dtype) == len(
             feed), "The length of input_shape should be equal to feed's."
 
-        prog_trans = dygraph.ProgramTranslator()
         with dygraph.guard():
             model.eval()
             input_vars = []
```
......
```diff
@@ -209,15 +209,24 @@ class FakeQuantAbsMax(layers.Layer):
         return quant_out
 
 
-def _get_fake_quant_type(quant_type, name, moving_rate, quant_bits, dtype,
-                         quant_on_weight):
+def _get_fake_quant_type(quant_type, **kwargs):
+    call_args = {
+        "name": kwargs.get("name", None),
+        "quant_bits": kwargs.get("quant_bits", 8),
+        "dtype": kwargs.get("dtype", "float32")
+    }
+
+    if quant_type == 'abs_max':
+        call_args["quant_on_weight"] = kwargs.get("quant_on_weight", False)
+    elif quant_type == 'moving_average_abs_max':
+        call_args["moving_rate"] = kwargs.get("moving_rate", 0.9)
+
     fake_quant_map = {
-        'abs_max':
-        lambda: FakeQuantAbsMax(name, quant_bits, dtype, quant_on_weight),
-        'moving_average_abs_max':
-        lambda: FakeQuantMovingAverage(name, moving_rate, quant_bits, dtype)
+        'abs_max': FakeQuantAbsMax,
+        'moving_average_abs_max': FakeQuantMovingAverage
     }
-    return fake_quant_map[quant_type]()
+
+    return fake_quant_map[quant_type](**call_args)
 
 
 class QuantizedConv2D(layers.Layer):
```
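To see the new dispatch in isolation, here is a minimal runnable sketch. The two classes are simplified stand-ins for Paddle's real `FakeQuantAbsMax` and `FakeQuantMovingAverage` layers, modeling only the constructor parameters visible in the old lambdas above; the parameter names in the example calls are made up for illustration.

```python
# Minimal sketch of the refactored dispatch. The two classes below are
# simplified stand-ins for the real Paddle layers, NOT the actual
# implementations.

class FakeQuantAbsMax(object):
    def __init__(self, name, quant_bits, dtype, quant_on_weight):
        self.name, self.quant_bits = name, quant_bits
        self.dtype, self.quant_on_weight = dtype, quant_on_weight


class FakeQuantMovingAverage(object):
    def __init__(self, name, moving_rate, quant_bits, dtype):
        self.name, self.moving_rate = name, moving_rate
        self.quant_bits, self.dtype = quant_bits, dtype


def _get_fake_quant_type(quant_type, **kwargs):
    # Arguments common to both fake-quant layers, with defaults.
    call_args = {
        "name": kwargs.get("name", None),
        "quant_bits": kwargs.get("quant_bits", 8),
        "dtype": kwargs.get("dtype", "float32")
    }
    # Each quant type only receives the extra argument its constructor
    # actually accepts; anything else the caller passed is dropped.
    if quant_type == 'abs_max':
        call_args["quant_on_weight"] = kwargs.get("quant_on_weight", False)
    elif quant_type == 'moving_average_abs_max':
        call_args["moving_rate"] = kwargs.get("moving_rate", 0.9)

    fake_quant_map = {
        'abs_max': FakeQuantAbsMax,
        'moving_average_abs_max': FakeQuantMovingAverage
    }
    return fake_quant_map[quant_type](**call_args)


# Callers can pass one uniform keyword set; 'abs_max' ignores moving_rate.
w_quant = _get_fake_quant_type(
    'abs_max',
    name='conv2d_0.w_0',   # hypothetical parameter name
    moving_rate=0.9,       # silently filtered out for 'abs_max'
    quant_bits=8,
    quant_on_weight=True)
assert w_quant.quant_on_weight is True

# Omitted kwargs fall back to the defaults (quant_bits=8, moving_rate=0.9).
a_quant = _get_fake_quant_type('moving_average_abs_max', name='conv2d_0')
assert a_quant.moving_rate == 0.9 and a_quant.quant_bits == 8
```

This filtering is what lets the call sites below pass the full keyword set unconditionally while each constructor keeps its own narrow signature.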
```diff
@@ -247,11 +256,18 @@ class QuantizedConv2D(layers.Layer):
         self.bias = getattr(layer, 'bias')
         # For FakeQuant
         self._fake_quant_weight = _get_fake_quant_type(
-            weight_quantize_type, self.weight.name, moving_rate, weight_bits,
-            self._dtype, True)
+            weight_quantize_type,
+            name=self.weight.name,
+            moving_rate=moving_rate,
+            quant_bits=weight_bits,
+            dtype=self._dtype,
+            quant_on_weight=True)
         self._fake_quant_input = _get_fake_quant_type(
             activation_quantize_type,
-            layer.full_name(), moving_rate, activation_bits, self._dtype, False)
+            name=layer.full_name(),
+            moving_rate=moving_rate,
+            quant_bits=activation_bits,
+            dtype=self._dtype)
 
     def forward(self, input):
         quant_input = self._fake_quant_input(input)
```
```diff
@@ -326,11 +342,18 @@ class QuantizedLinear(layers.Layer):
         self.bias = getattr(layer, 'bias')
         # For FakeQuant
         self._fake_quant_weight = _get_fake_quant_type(
-            weight_quantize_type, self.weight.name, moving_rate, weight_bits,
-            self._dtype, True)
+            weight_quantize_type,
+            name=self.weight.name,
+            moving_rate=moving_rate,
+            quant_bits=weight_bits,
+            dtype=self._dtype,
+            quant_on_weight=True)
         self._fake_quant_input = _get_fake_quant_type(
             activation_quantize_type,
-            layer.full_name(), moving_rate, activation_bits, self._dtype, False)
+            name=layer.full_name(),
+            moving_rate=moving_rate,
+            quant_bits=activation_bits,
+            dtype=self._dtype)
 
     def forward(self, input):
         quant_input = self._fake_quant_input(input)
```
......
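One behavior the refactor preserves, sketched here under the stand-in definitions above: an unsupported quant type still surfaces as a plain `KeyError` from the `fake_quant_map` lookup, so the quantize-type strings passed by `QuantizedConv2D` and `QuantizedLinear` are assumed to be validated upstream. The parameter name below is hypothetical.

```python
# Reusing the sketch above: unknown quant types still raise KeyError
# from the fake_quant_map lookup, exactly as before the refactor.
try:
    _get_fake_quant_type('range_abs_max', name='fc_0.w_0')  # hypothetical name
except KeyError as err:
    print('unsupported quant type:', err)
```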