Unverified commit 6566b8f5 — authored by Guanghua Yu, committed by GitHub

fix dygraph new format problem export in QAT (#47023)

Parent f4ea771d
@@ -72,7 +72,8 @@ class ImperativeQuantAware(object):
                  weight_preprocess_layer=None,
                  act_preprocess_layer=None,
                  weight_quantize_layer=None,
-                 act_quantize_layer=None):
+                 act_quantize_layer=None,
+                 onnx_format=False):
         """
         The constructor for ImperativeQuantAware.
@@ -124,6 +125,8 @@ class ImperativeQuantAware(object):
                 activation and returns dequantized activation.
                 If None, will use quantization op defined by 'activation_quantize_type'.
                 Default is None.
+            onnx_format (bool, optional): Whether to export the quantized model
+                with format of ONNX. Default is False.
 
         Note:
             If user sets attribute 'skip_quant' to a Layer that support dynamic
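
For context on the new constructor flag: the ONNX-compatible ("new format") export is now chosen when the quantizer is built instead of when the model is saved. A minimal usage sketch under that reading (the LeNet model, input shape, and output path are illustrative, not part of this patch):

    import paddle
    from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

    # Select the ONNX-style export format up front, at construction time.
    quanter = ImperativeQuantAware(
        weight_quantize_type='abs_max',
        activation_quantize_type='moving_average_abs_max',
        onnx_format=True)

    model = paddle.vision.models.LeNet()  # placeholder network
    quanter.quantize(model)               # insert fake quant/dequant layers in place
    # ... run quantization-aware training here ...
    quanter.save_quantized_model(
        model, './qat_lenet',
        input_spec=[
            paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype='float32')
        ])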
@@ -224,7 +227,7 @@ class ImperativeQuantAware(object):
         self._quantize_inputs = ImperativeQuantizeInputs(**kwargs)
         self._quantize_outputs = ImperativeQuantizeOutputs(
-            moving_rate, activation_bits)
+            moving_rate, activation_bits, onnx_format)
 
     def quantize(self, model):
         """
@@ -415,7 +418,7 @@ class ImperativeQuantizeOutputs(object):
     Calculate the output scales for target layers.
     """
 
-    def __init__(self, moving_rate=0.9, activation_bits=8):
+    def __init__(self, moving_rate=0.9, activation_bits=8, onnx_format=False):
         """
         The constructor for ImperativeQuantizeOutputs.
@@ -427,6 +430,7 @@ class ImperativeQuantizeOutputs(object):
         super(ImperativeQuantizeOutputs, self).__init__()
         self._moving_rate = moving_rate
         self._activation_bits = activation_bits
+        self._onnx_format = onnx_format
 
     def apply(self, model):
         """
@@ -463,12 +467,7 @@ class ImperativeQuantizeOutputs(object):
 
                 setattr(parent_layer, sub_name, cur_quant_layer)
 
-    def save_quantized_model(self,
-                             model,
-                             path,
-                             input_spec=None,
-                             onnx_format=False,
-                             **config):
+    def save_quantized_model(self, model, path, input_spec=None, **config):
         """
         Save the quantized model for the inference.
@@ -481,8 +480,6 @@ class ImperativeQuantizeOutputs(object):
                 InputSpec or example Tensor. If None, all input variables of
                 the original Layer's forward method would be the inputs of
                 the saved model. Default None.
-            onnx_format (bool, optional): Whether to export the quantized model
-                with format of ONNX. Default is False.
             **config (dict, optional): Other save configuration options for
                 compatibility. We do not recommend using these configurations,
                 they may be removed in the future. If not necessary, DO NOT use
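
Since `onnx_format` disappears from this signature, existing call sites need a one-line migration; a hedged before/after sketch (object names are illustrative):

    # Before this patch: the flag was passed at save time.
    #     quanter.save_quantized_model(model, save_path, input_spec=spec,
    #                                  onnx_format=True)
    # After this patch: the flag is passed once, to the constructor.
    #     quanter = ImperativeQuantAware(..., onnx_format=True)
    #     quanter.save_quantized_model(model, save_path, input_spec=spec)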
@@ -523,7 +520,7 @@ class ImperativeQuantizeOutputs(object):
                 model_filename=model_filename,
                 params_filename=params_filename))
 
-        if not onnx_format:
+        if not self._onnx_format:
             self._gather_scales(infer_program, scope, fetch_targets)
 
             # Remove `moving_average_abs_max_scale` node in sub graphs.
@@ -542,10 +539,14 @@ class ImperativeQuantizeOutputs(object):
             graph = IrGraph(core.Graph(infer_program.desc), for_test=False)
             transform_pass = ReplaceFakeQuantDequantPass(
                 scope, place, quant_bits=self._activation_bits)
-            transform_pass.apply(graph)
+            for sub_graph in graph.all_sub_graphs():
+                sub_graph._for_test = True
+                transform_pass.apply(sub_graph)
 
             quant_weight_pass = QuantWeightPass(scope, place)
-            quant_weight_pass.apply(graph)
+            for sub_graph in graph.all_sub_graphs():
+                sub_graph._for_test = True
+                quant_weight_pass.apply(sub_graph)
 
             infer_program = graph.to_program()
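
A commented restatement of the pattern this hunk adopts (names mirror the diff; the remark about control-flow blocks is my reading of why the per-sub-graph loop is needed):

    # `graph` is the IrGraph built from the inference program above. Each block
    # of the program becomes one sub-graph, so a pass applied only to the main
    # graph would skip ops living inside control-flow blocks. Hence every pass
    # is applied to every sub-graph, each marked as an inference graph first.
    for sub_graph in graph.all_sub_graphs():
        sub_graph._for_test = True
        transform_pass.apply(sub_graph)     # rewrite fake quant/dequant ops

    for sub_graph in graph.all_sub_graphs():
        sub_graph._for_test = True
        quant_weight_pass.apply(sub_graph)  # store weights as quantized values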
@@ -567,18 +568,24 @@ class ImperativeQuantizeOutputs(object):
         """
         Whether the layer needs to calculate output scales.
         """
+        # exclude fake_quant ops in quant_layers file
+        if not isinstance(layer, dygraph.Layer):
+            return False
+
+        if self._onnx_format:
+            return True if isinstance(layer, tuple(
+                utils.fake_quant_wrap_layers)) else False
+
         flag = False
-        if isinstance(layer, dygraph.Layer):
-            # exclude fake_quant ops in quant_layers file
-            if utils.is_leaf_layer(layer) and \
-                not isinstance(layer, tuple(utils.fake_quant_leaf_layers)):
-                flag = True
+        if utils.is_leaf_layer(layer) and \
+                not isinstance(layer, tuple(utils.fake_quant_leaf_layers)):
+            flag = True
 
-            if isinstance(layer, tuple(utils.fake_quant_wrap_layers)):
-                flag = True
+        if isinstance(layer, tuple(utils.fake_quant_wrap_layers)):
+            flag = True
 
-            if isinstance(layer, paddle.nn.quant.FloatFunctionalLayer):
-                flag = True
+        if isinstance(layer, paddle.nn.quant.FloatFunctionalLayer):
+            flag = True
 
         return flag
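
Restated compactly, the selection rule after this change reads as below (a paraphrase for readability, assuming the same module-level imports as the file itself; not code from the patch):

    def _is_output_scale_target(layer, onnx_format):
        # Paraphrase of _is_target_layer; `dygraph`, `utils`, and `paddle`
        # are the module's own imports.
        if not isinstance(layer, dygraph.Layer):
            return False
        if onnx_format:
            # New format: only wrapped fake-quant layers need output scales.
            return isinstance(layer, tuple(utils.fake_quant_wrap_layers))
        # Old format: leaf layers (except fake-quant leaves), wrapped
        # fake-quant layers, and FloatFunctionalLayer all need output scales.
        return ((utils.is_leaf_layer(layer)
                 and not isinstance(layer, tuple(utils.fake_quant_leaf_layers)))
                or isinstance(layer, tuple(utils.fake_quant_wrap_layers))
                or isinstance(layer, paddle.nn.quant.FloatFunctionalLayer))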
@@ -59,6 +59,7 @@ _fake_dequant_op_list = [
 _fake_quant_dequant_op_list = [
     'fake_quantize_dequantize_moving_average_abs_max',
     "fake_channel_wise_quantize_dequantize_abs_max",
+    "fake_quantize_dequantize_abs_max",
 ]
 
 _conv_ops = ['conv2d', 'depthwise_conv2d', 'conv2d_transpose']
@@ -334,9 +334,11 @@ def quant_tensor(x, scale, quant_axis=0, weight_bits=8, onnx_format=False):
         x[x < -scale] = -scale
         return x
 
-    assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
     bnt = (1 << (weight_bits - 1)) - 1
+    if isinstance(scale, list) and len(scale) == 1:
+        scale = scale[0]
     if isinstance(scale, list):
+        assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
         for i, s in enumerate(scale):
             if s == 0.0:
                 s = 1e-8
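
As a worked example of the single-element-scale case the new branch unwraps, assuming the usual symmetric mapping round(x / scale * bnt) (the exact body of quant_tensor is not reproduced here):

    import numpy as np

    weight_bits = 8
    bnt = (1 << (weight_bits - 1)) - 1      # 127 for 8-bit weights
    x = np.array([0.5, -1.0, 0.25], dtype=np.float32)

    scale = [1.0]                           # single-element list, unwrapped to a scalar
    if isinstance(scale, list) and len(scale) == 1:
        scale = scale[0]

    q = np.round(x / scale * bnt)           # -> [  64., -127.,   32.]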
@@ -66,7 +66,8 @@ class TestImperativeQat(unittest.TestCase):
         imperative_qat = ImperativeQuantAware(
             weight_quantize_type=self.weight_quantize_type,
             activation_quantize_type=self.activation_quantize_type,
-            fuse_conv_bn=self.fuse_conv_bn)
+            fuse_conv_bn=self.fuse_conv_bn,
+            onnx_format=self.onnx_format)
 
         with fluid.dygraph.guard():
             # For CI coverage
@@ -185,8 +186,7 @@ class TestImperativeQat(unittest.TestCase):
                 input_spec=[
                     paddle.static.InputSpec(shape=[None, 1, 28, 28],
                                             dtype='float32')
-                ],
-                onnx_format=self.onnx_format)
+                ])
             print('Quantized model saved in %s' % tmpdir)
 
             if core.is_compiled_with_cuda():