Unverified commit 6566b8f5, authored by Guanghua Yu, committed by GitHub

fix dygraph new format problem export in QAT (#47023)

Parent f4ea771d
@@ -72,7 +72,8 @@ class ImperativeQuantAware(object):
                  weight_preprocess_layer=None,
                  act_preprocess_layer=None,
                  weight_quantize_layer=None,
-                 act_quantize_layer=None):
+                 act_quantize_layer=None,
+                 onnx_format=False):
         """
         The constructor for ImperativeQuantAware.
@@ -124,6 +125,8 @@ class ImperativeQuantAware(object):
                 activation and returns dequantized activation.
                 If None, will use quantization op defined by 'activation_quantize_type'.
                 Default is None.
+            onnx_format (bool, optional): Whether to export the quantized model
+                with format of ONNX. Default is False.
 
         Note:
             If user sets attribute 'skip_quant' to a Layer that support dynamic
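
Since `onnx_format` now lives on the `ImperativeQuantAware` constructor rather than on `save_quantized_model`, a minimal usage sketch looks like the following (the `LeNet` model and the import path mirror the unit test touched by this commit; they are illustrative, not part of this hunk):

```python
import paddle
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

# Choose the ONNX-style ("new format") export at construction time.
quanter = ImperativeQuantAware(
    weight_quantize_type='abs_max',
    activation_quantize_type='moving_average_abs_max',
    onnx_format=True)

model = paddle.vision.models.LeNet()  # any dygraph Layer works here
quanter.quantize(model)               # inserts fake quant/dequant layers in place
```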
@@ -224,7 +227,7 @@ class ImperativeQuantAware(object):
         self._quantize_inputs = ImperativeQuantizeInputs(**kwargs)
         self._quantize_outputs = ImperativeQuantizeOutputs(
-            moving_rate, activation_bits)
+            moving_rate, activation_bits, onnx_format)
 
     def quantize(self, model):
         """
@@ -415,7 +418,7 @@ class ImperativeQuantizeOutputs(object):
     Calculate the output scales for target layers.
     """
 
-    def __init__(self, moving_rate=0.9, activation_bits=8):
+    def __init__(self, moving_rate=0.9, activation_bits=8, onnx_format=False):
         """
         The constructor for ImperativeQuantizeOutputs.
@@ -427,6 +430,7 @@ class ImperativeQuantizeOutputs(object):
         super(ImperativeQuantizeOutputs, self).__init__()
         self._moving_rate = moving_rate
         self._activation_bits = activation_bits
+        self._onnx_format = onnx_format
 
     def apply(self, model):
         """
@@ -463,12 +467,7 @@ class ImperativeQuantizeOutputs(object):
             setattr(parent_layer, sub_name, cur_quant_layer)
 
-    def save_quantized_model(self,
-                             model,
-                             path,
-                             input_spec=None,
-                             onnx_format=False,
-                             **config):
+    def save_quantized_model(self, model, path, input_spec=None, **config):
         """
         Save the quantized model for the inference.
@@ -481,8 +480,6 @@ class ImperativeQuantizeOutputs(object):
                 InputSpec or example Tensor. If None, all input variables of
                 the original Layer's forward method would be the inputs of
                 the saved model. Default None.
-            onnx_format (bool, optional): Whether to export the quantized model
-                with format of ONNX. Default is False.
             **config (dict, optional): Other save configuration options for
                 compatibility. We do not recommend using these configurations,
                 they may be removed in the future. If not necessary, DO NOT use
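
Correspondingly, `save_quantized_model` is now called without the flag; the export format follows whatever was passed to the constructor. A hedged sketch of the new call (the output path and `InputSpec` shape are illustrative):

```python
import paddle

# After training the quantized `model`, export it for inference.
quanter.save_quantized_model(
    model,
    path='./quant_out/lenet',  # illustrative output prefix
    input_spec=[
        paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype='float32')
    ])
# Note: no onnx_format argument here anymore; it was fixed at
# ImperativeQuantAware(onnx_format=...) construction time.
```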
@@ -523,7 +520,7 @@ class ImperativeQuantizeOutputs(object):
                 model_filename=model_filename,
                 params_filename=params_filename))
 
-        if not onnx_format:
+        if not self._onnx_format:
             self._gather_scales(infer_program, scope, fetch_targets)
 
             # Remove `moving_average_abs_max_scale` node in sub graphs.
@@ -542,10 +539,14 @@ class ImperativeQuantizeOutputs(object):
             graph = IrGraph(core.Graph(infer_program.desc), for_test=False)
             transform_pass = ReplaceFakeQuantDequantPass(
                 scope, place, quant_bits=self._activation_bits)
-            transform_pass.apply(graph)
+            for sub_graph in graph.all_sub_graphs():
+                sub_graph._for_test = True
+                transform_pass.apply(sub_graph)
 
             quant_weight_pass = QuantWeightPass(scope, place)
-            quant_weight_pass.apply(graph)
+            for sub_graph in graph.all_sub_graphs():
+                sub_graph._for_test = True
+                quant_weight_pass.apply(sub_graph)
 
             infer_program = graph.to_program()
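
The export passes previously ran only on the top-level graph; with control flow, operators live in sub-blocks, so every sub-graph has to be visited with `_for_test` set so the passes treat it as an inference graph. A hypothetical helper (not part of this commit) that captures the repeated pattern:

```python
def _apply_pass_to_all_subgraphs(ir_pass, graph):
    """Run an IR pass over every sub-graph (program block) in test mode."""
    for sub_graph in graph.all_sub_graphs():
        # ReplaceFakeQuantDequantPass / QuantWeightPass expect inference graphs
        sub_graph._for_test = True
        ir_pass.apply(sub_graph)
```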
@@ -567,18 +568,24 @@ class ImperativeQuantizeOutputs(object):
         """
         Whether the layer needs to calculate output scales.
         """
+        # exclude fake_quant ops in quant_layers file
+        if not isinstance(layer, dygraph.Layer):
+            return False
+
+        if self._onnx_format:
+            return True if isinstance(layer, tuple(
+                utils.fake_quant_wrap_layers)) else False
+
         flag = False
-        if isinstance(layer, dygraph.Layer):
-            # exclude fake_quant ops in quant_layers file
-            if utils.is_leaf_layer(layer) and \
-                not isinstance(layer, tuple(utils.fake_quant_leaf_layers)):
-                flag = True
+        if utils.is_leaf_layer(layer) and \
+                not isinstance(layer, tuple(utils.fake_quant_leaf_layers)):
+            flag = True
 
-            if isinstance(layer, tuple(utils.fake_quant_wrap_layers)):
-                flag = True
+        if isinstance(layer, tuple(utils.fake_quant_wrap_layers)):
+            flag = True
 
-            if isinstance(layer, paddle.nn.quant.FloatFunctionalLayer):
-                flag = True
+        if isinstance(layer, paddle.nn.quant.FloatFunctionalLayer):
+            flag = True
 
         return flag
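
For readability, the post-change `_is_target_layer` amounts to the following (a reconstruction with assumed indentation; the `return True if ... else False` expression is simplified to a plain `isinstance` check):

```python
def _is_target_layer(self, layer):
    """Whether the layer needs to calculate output scales."""
    # exclude fake_quant ops in quant_layers file
    if not isinstance(layer, dygraph.Layer):
        return False

    if self._onnx_format:
        # new-format export: only the fake-quant wrapper layers collect scales
        return isinstance(layer, tuple(utils.fake_quant_wrap_layers))

    flag = False
    if utils.is_leaf_layer(layer) and \
            not isinstance(layer, tuple(utils.fake_quant_leaf_layers)):
        flag = True
    if isinstance(layer, tuple(utils.fake_quant_wrap_layers)):
        flag = True
    if isinstance(layer, paddle.nn.quant.FloatFunctionalLayer):
        flag = True
    return flag
```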
...
@@ -59,6 +59,7 @@ _fake_dequant_op_list = [
 _fake_quant_dequant_op_list = [
     'fake_quantize_dequantize_moving_average_abs_max',
     "fake_channel_wise_quantize_dequantize_abs_max",
+    "fake_quantize_dequantize_abs_max",
 ]
 
 _conv_ops = ['conv2d', 'depthwise_conv2d', 'conv2d_transpose']
...
@@ -334,9 +334,11 @@ def quant_tensor(x, scale, quant_axis=0, weight_bits=8, onnx_format=False):
         x[x < -scale] = -scale
         return x
 
-    assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
     bnt = (1 << (weight_bits - 1)) - 1
+    if isinstance(scale, list) and len(scale) == 1:
+        scale = scale[0]
     if isinstance(scale, list):
+        assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
         for i, s in enumerate(scale):
             if s == 0.0:
                 s = 1e-8
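
The `quant_tensor` change treats a one-element scale list as a per-tensor scale, so the `quant_axis` assertion only applies to genuinely channel-wise (multi-element) scales. A small standalone illustration of the per-tensor path (NumPy only, not the patch itself):

```python
import numpy as np

weight_bits = 8
bnt = (1 << (weight_bits - 1)) - 1   # 127 for 8-bit symmetric quantization

x = np.array([0.5, -1.0, 0.25], dtype='float32')
scale = [1.0]                        # one-element list -> unwrap to a scalar
if isinstance(scale, list) and len(scale) == 1:
    scale = scale[0]

q = np.round(x / scale * bnt)        # symmetric per-tensor quantization
print(q)                             # approximately [64., -127., 32.]
```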
...
@@ -66,7 +66,8 @@ class TestImperativeQat(unittest.TestCase):
         imperative_qat = ImperativeQuantAware(
             weight_quantize_type=self.weight_quantize_type,
             activation_quantize_type=self.activation_quantize_type,
-            fuse_conv_bn=self.fuse_conv_bn)
+            fuse_conv_bn=self.fuse_conv_bn,
+            onnx_format=self.onnx_format)
 
         with fluid.dygraph.guard():
             # For CI coverage
@@ -185,8 +186,7 @@ class TestImperativeQat(unittest.TestCase):
                     input_spec=[
                         paddle.static.InputSpec(shape=[None, 1, 28, 28],
                                                 dtype='float32')
-                    ],
-                    onnx_format=self.onnx_format)
+                    ])
                 print('Quantized model saved in %s' % tmpdir)
 
             if core.is_compiled_with_cuda():
...