未验证 提交 7858d332 编写于 作者: C cc 提交者: GitHub

[dygraph qat] change default config and fix bug (#34047)

上级 1f28968b
......@@ -46,8 +46,8 @@ class ImperativePTQ(object):
Args:
quant_config(PTQConfig): the config of post training quantization.
The config has weight_quantizer and activation_quantizer.
In default, the weight_quantizer and activation_quantizer are
AbsmaxQuantizer.
In default, the weight_quantizer is PerChannelAbsmaxQuantizer
and the activation_quantizer is KLQuantizer.
"""
super(ImperativePTQ, self).__init__()
......@@ -70,9 +70,9 @@ class ImperativePTQ(object):
"The model must be the instance of paddle.nn.Layer."
if not inplace:
new_model = copy.deepcopy(model)
model = copy.deepcopy(model)
for name, layer in new_model.named_sublayers():
for name, layer in model.named_sublayers():
if PTQRegistry.is_supported_layer(layer) \
and utils.is_leaf_layer(layer) \
and not self._is_skip_layer(layer):
......@@ -90,13 +90,13 @@ class ImperativePTQ(object):
layer._forward_post_hooks.move_to_end(
quant_hook_handle._hook_id, last=False)
return new_model
return model
def save_quantized_model(self, model, path, input_spec=None, **config):
"""
1. Convert the quantized model
2. Call jit.save to save the inference model
3. Load and postprocess the inference model.
3. Post process the inference model.
Args:
model (Layer): The model to be saved.
......@@ -207,8 +207,19 @@ class ImperativePTQ(object):
assert isinstance(model, paddle.nn.Layer), \
"The input model must be the instance of paddle.nn.Layer."
total_num = 0
cur_num = 0
for name, sub_layer in model.named_sublayers():
if self._is_quant_layer(sub_layer):
total_num += 1
for name, sub_layer in model.named_sublayers():
if self._is_quant_layer(sub_layer):
cur_num += 1
if cur_num % 5 == 0:
_logger.info("Process the %s / %s layer" %
(cur_num, total_num))
quant_config = sub_layer._quant_config
if quant_config.enable_in_act_quantizer:
......
......@@ -53,4 +53,4 @@ class PTQConfig(object):
self.enable_in_act_quantizer = False
default_ptq_config = PTQConfig(AbsmaxQuantizer(), AbsmaxQuantizer())
default_ptq_config = PTQConfig(KLQuantizer(), PerChannelAbsmaxQuantizer())
......@@ -82,7 +82,8 @@ class TestImperativePTQ(unittest.TestCase):
return data_cache_folder
def set_vars(self):
self.ptq = ImperativePTQ(default_ptq_config)
config = PTQConfig(AbsmaxQuantizer(), AbsmaxQuantizer())
self.ptq = ImperativePTQ(config)
self.batch_num = 10
self.batch_size = 10
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册