未验证 提交 5429d145 编写于 作者: G Guanghua Yu 提交者: GitHub

Update dygraph PTQ export_model API (#47284)

上级 b68c4a1e
...@@ -121,7 +121,7 @@ class ImperativePTQ(object): ...@@ -121,7 +121,7 @@ class ImperativePTQ(object):
InputSpec or example Tensor. If None, all input variables of InputSpec or example Tensor. If None, all input variables of
the original Layer's forward method would be the inputs of the original Layer's forward method would be the inputs of
the saved model. Default None. the saved model. Default None.
**configs (dict, optional): Other save configuration options for **config (dict, optional): Other save configuration options for
compatibility. We do not recommend using these configurations, compatibility. We do not recommend using these configurations,
they may be removed in the future. If not necessary, DO NOT use they may be removed in the future. If not necessary, DO NOT use
them. Default None. them. Default None.
...@@ -140,11 +140,15 @@ class ImperativePTQ(object): ...@@ -140,11 +140,15 @@ class ImperativePTQ(object):
assert isinstance( assert isinstance(
model, paddle.nn.Layer model, paddle.nn.Layer
), "The model must be the instance of paddle.nn.Layer." ), "The model must be the instance of paddle.nn.Layer."
is_postprocess = config.get('postprocess', False)
config.pop('postprocess', None)
# Convert and save dygraph quantized model # Convert and save dygraph quantized model
self._convert(model) self._convert(model)
paddle.jit.save(layer=model, path=path, input_spec=input_spec, **config) paddle.jit.save(layer=model, path=path, input_spec=input_spec, **config)
if not is_postprocess:
return
# Load inference program # Load inference program
is_dynamic_mode = False is_dynamic_mode = False
...@@ -272,10 +276,16 @@ class ImperativePTQ(object): ...@@ -272,10 +276,16 @@ class ImperativePTQ(object):
output_names = layer_info.output_names output_names = layer_info.output_names
output_thresholds = quant_config.out_act_quantizer.thresholds output_thresholds = quant_config.out_act_quantizer.thresholds
assert len(output_names) == 1 assert len(output_names) == 1
assert len(output_thresholds) == 1 if len(output_thresholds) == 1:
save_name = output_names[0] + str(0) + "_threshold" save_name = output_names[0] + str(0) + "_threshold"
sub_layer._set_op_attrs({save_name: output_thresholds[0]}) sub_layer._set_op_attrs({save_name: output_thresholds[0]})
sub_layer._set_op_attrs({"out_threshold": output_thresholds[0]}) sub_layer._set_op_attrs({"out_threshold": output_thresholds[0]})
else:
_logger.warning(
"output_thresholds shape of {} need to be 1, but received {}".format(
output_names[0], len(output_thresholds)
)
)
def _wrap_simulated_layers(self, model): def _wrap_simulated_layers(self, model):
""" """
...@@ -326,11 +336,13 @@ class ImperativePTQ(object): ...@@ -326,11 +336,13 @@ class ImperativePTQ(object):
# save the input thresholds # save the input thresholds
assert hasattr(quant_layer, "_fake_quant_input") assert hasattr(quant_layer, "_fake_quant_input")
assert hasattr(quant_layer._fake_quant_input, "_scale") assert hasattr(quant_layer._fake_quant_input, "_scale")
assert len(in_act_quantizer.thresholds) == 1 if len(in_act_quantizer.thresholds) == 1:
input_threshold = np.array( input_threshold = np.array(
[in_act_quantizer.thresholds[0]], dtype=np.float32 [in_act_quantizer.thresholds[0]], dtype=np.float32
) )
quant_layer._fake_quant_input._scale.set_value(input_threshold) quant_layer._fake_quant_input._scale.set_value(
input_threshold
)
assert hasattr(quant_layer, "_fake_quant_weight") assert hasattr(quant_layer, "_fake_quant_weight")
assert hasattr(quant_layer._fake_quant_weight, "_scale") assert hasattr(quant_layer._fake_quant_weight, "_scale")
......
...@@ -41,6 +41,7 @@ PTQ_LAYERS_INFO = [ ...@@ -41,6 +41,7 @@ PTQ_LAYERS_INFO = [
LayerInfo(paddle.nn.ReLU, ['X'], [], ['Out']), LayerInfo(paddle.nn.ReLU, ['X'], [], ['Out']),
LayerInfo(paddle.nn.ReLU6, ['X'], [], ['Out']), LayerInfo(paddle.nn.ReLU6, ['X'], [], ['Out']),
LayerInfo(paddle.nn.Hardswish, ['X'], [], ['Out']), LayerInfo(paddle.nn.Hardswish, ['X'], [], ['Out']),
LayerInfo(paddle.nn.Swish, ['X'], [], ['Out']),
LayerInfo(paddle.nn.Sigmoid, ['X'], [], ['Out']), LayerInfo(paddle.nn.Sigmoid, ['X'], [], ['Out']),
LayerInfo(paddle.nn.Softmax, ['X'], [], ['Out']), LayerInfo(paddle.nn.Softmax, ['X'], [], ['Out']),
LayerInfo(paddle.nn.Tanh, ['X'], [], ['Out']), LayerInfo(paddle.nn.Tanh, ['X'], [], ['Out']),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册