Unverified · Commit ece74c4c · Authored by Zhen Wang · Committed by GitHub

Update the _get_fake_quant_type definition in imperative QAT. (#27222)

Parent f6be5989
@@ -192,7 +192,6 @@ class ImperativeQuantAware(object):
         assert len(input_dtype) == len(
             feed), "The length of input_shape should be equal to feed's."
-        prog_trans = dygraph.ProgramTranslator()
         with dygraph.guard():
             model.eval()
             input_vars = []
...
@@ -209,15 +209,24 @@ class FakeQuantAbsMax(layers.Layer):
         return quant_out


-def _get_fake_quant_type(quant_type, name, moving_rate, quant_bits, dtype,
-                         quant_on_weight):
+def _get_fake_quant_type(quant_type, **kwargs):
+    call_args = {
+        "name": kwargs.get("name", None),
+        "quant_bits": kwargs.get("quant_bits", 8),
+        "dtype": kwargs.get("dtype", "float32")
+    }
+
+    if quant_type == 'abs_max':
+        call_args["quant_on_weight"] = kwargs.get("quant_on_weight", False)
+    elif quant_type == 'moving_average_abs_max':
+        call_args["moving_rate"] = kwargs.get("moving_rate", 0.9)
+
     fake_quant_map = {
-        'abs_max':
-        lambda: FakeQuantAbsMax(name, quant_bits, dtype, quant_on_weight),
-        'moving_average_abs_max':
-        lambda: FakeQuantMovingAverage(name, moving_rate, quant_bits, dtype)
+        'abs_max': FakeQuantAbsMax,
+        'moving_average_abs_max': FakeQuantMovingAverage
     }
-    return fake_quant_map[quant_type]()
+
+    return fake_quant_map[quant_type](**call_args)


 class QuantizedConv2D(layers.Layer):
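
For readers skimming the diff: the old helper took six positional arguments and built zero-argument lambdas that closed over all of them, even though each quantizer only uses a subset. The new helper accepts **kwargs, fills in defaults, and forwards only the arguments the chosen quantizer accepts. Below is a minimal, self-contained sketch of that dispatch pattern; the two stub classes are simplified stand-ins for the real Paddle layers (they only record their constructor arguments), not the actual implementations.

# Sketch of the keyword-based dispatch added above. The stub classes stand
# in for the real FakeQuantAbsMax / FakeQuantMovingAverage layers.
class FakeQuantAbsMax(object):
    def __init__(self, name=None, quant_bits=8, dtype='float32',
                 quant_on_weight=False):
        self.config = dict(name=name, quant_bits=quant_bits, dtype=dtype,
                           quant_on_weight=quant_on_weight)

class FakeQuantMovingAverage(object):
    def __init__(self, name=None, moving_rate=0.9, quant_bits=8,
                 dtype='float32'):
        self.config = dict(name=name, moving_rate=moving_rate,
                           quant_bits=quant_bits, dtype=dtype)

def _get_fake_quant_type(quant_type, **kwargs):
    # Arguments shared by both quantizer types, with the same defaults
    # as the patched helper.
    call_args = {
        "name": kwargs.get("name", None),
        "quant_bits": kwargs.get("quant_bits", 8),
        "dtype": kwargs.get("dtype", "float32")
    }
    # Forward only the arguments the chosen quantizer actually accepts.
    if quant_type == 'abs_max':
        call_args["quant_on_weight"] = kwargs.get("quant_on_weight", False)
    elif quant_type == 'moving_average_abs_max':
        call_args["moving_rate"] = kwargs.get("moving_rate", 0.9)
    fake_quant_map = {
        'abs_max': FakeQuantAbsMax,
        'moving_average_abs_max': FakeQuantMovingAverage
    }
    return fake_quant_map[quant_type](**call_args)

# moving_rate never reaches FakeQuantAbsMax, and omitted keywords fall
# back to the helper's defaults:
q = _get_fake_quant_type('abs_max', name='w_0', quant_on_weight=True)
print(q.config)  # {'name': 'w_0', 'quant_bits': 8, 'dtype': 'float32', 'quant_on_weight': True}
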
@@ -247,11 +256,18 @@ class QuantizedConv2D(layers.Layer):
         self.bias = getattr(layer, 'bias')
         # For FakeQuant
         self._fake_quant_weight = _get_fake_quant_type(
-            weight_quantize_type, self.weight.name, moving_rate, weight_bits,
-            self._dtype, True)
+            weight_quantize_type,
+            name=self.weight.name,
+            moving_rate=moving_rate,
+            quant_bits=weight_bits,
+            dtype=self._dtype,
+            quant_on_weight=True)
         self._fake_quant_input = _get_fake_quant_type(
             activation_quantize_type,
-            layer.full_name(), moving_rate, activation_bits, self._dtype, False)
+            name=layer.full_name(),
+            moving_rate=moving_rate,
+            quant_bits=activation_bits,
+            dtype=self._dtype)

     def forward(self, input):
         quant_input = self._fake_quant_input(input)
@@ -326,11 +342,18 @@ class QuantizedLinear(layers.Layer):
         self.bias = getattr(layer, 'bias')
         # For FakeQuant
         self._fake_quant_weight = _get_fake_quant_type(
-            weight_quantize_type, self.weight.name, moving_rate, weight_bits,
-            self._dtype, True)
+            weight_quantize_type,
+            name=self.weight.name,
+            moving_rate=moving_rate,
+            quant_bits=weight_bits,
+            dtype=self._dtype,
+            quant_on_weight=True)
         self._fake_quant_input = _get_fake_quant_type(
             activation_quantize_type,
-            layer.full_name(), moving_rate, activation_bits, self._dtype, False)
+            name=layer.full_name(),
+            moving_rate=moving_rate,
+            quant_bits=activation_bits,
+            dtype=self._dtype)

     def forward(self, input):
         quant_input = self._fake_quant_input(input)
...
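
The call-site changes in QuantizedConv2D and QuantizedLinear show the payoff of the new signature: the old positional form forced every caller to supply moving_rate even for 'abs_max' (which never uses it) and to pass a bare True/False for quant_on_weight. With keywords, each call names exactly what it needs and everything else falls back to the helper's defaults. A hypothetical call in the style of the new activation-quantizer setup, using the sketch above ('conv2d_0' stands in for layer.full_name()):

act_quant = _get_fake_quant_type(
    'moving_average_abs_max',
    name='conv2d_0',    # stand-in for layer.full_name()
    quant_bits=8,
    dtype='float32')    # moving_rate falls back to its 0.9 default
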