未验证 提交 12b9b03e 编写于 作者: G Guanghua Yu 提交者: GitHub

[cherry-pick] update dygraph PTQ export_model api (#47415)

* update dygraph PTQ export_model api

* remove postprocess
上级 df64e790
......@@ -31,9 +31,9 @@ from .ptq_registry import PTQRegistry
__all__ = ['ImperativePTQ']
_logger = get_logger(__name__,
logging.INFO,
fmt='%(asctime)s-%(levelname)s: %(message)s')
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
class ImperativePTQ(object):
......@@ -75,17 +75,20 @@ class ImperativePTQ(object):
Return
quantized_model(paddle.nn.Layer): The quantized model.
"""
assert isinstance(model, paddle.nn.Layer), \
"The model must be the instance of paddle.nn.Layer."
assert isinstance(
model, paddle.nn.Layer
), "The model must be the instance of paddle.nn.Layer."
if not inplace:
model = copy.deepcopy(model)
if fuse:
model.eval()
model = fuse_utils.fuse_layers(model, fuse_list)
for name, layer in model.named_sublayers():
if PTQRegistry.is_supported_layer(layer) \
and utils.is_leaf_layer(layer) \
and not self._is_skip_layer(layer):
if (
PTQRegistry.is_supported_layer(layer)
and utils.is_leaf_layer(layer)
and not self._is_skip_layer(layer)
):
# Add quant config
quant_config = copy.deepcopy(self._quant_config)
......@@ -98,7 +101,8 @@ class ImperativePTQ(object):
quant_hook_handle = layer.register_forward_post_hook(hook)
quant_config.quant_hook_handle = quant_hook_handle
layer._forward_post_hooks.move_to_end(
quant_hook_handle._hook_id, last=False)
quant_hook_handle._hook_id, last=False
)
return model
......@@ -110,14 +114,14 @@ class ImperativePTQ(object):
Args:
model (Layer): The model to be saved.
path (str): The path prefix to save model. The format is
path (str): The path prefix to save model. The format is
``dirname/file_prefix`` or ``file_prefix``.
input_spec (list[InputSpec|Tensor], optional): Describes the input
of the saved model's forward method, which can be described by
InputSpec or example Tensor. If None, all input variables of
InputSpec or example Tensor. If None, all input variables of
the original Layer's forward method would be the inputs of
the saved model. Default None.
**configs (dict, optional): Other save configuration options for
**config (dict, optional): Other save configuration options for
compatibility. We do not recommend using these configurations,
they may be removed in the future. If not necessary, DO NOT use
them. Default None.
......@@ -125,16 +129,17 @@ class ImperativePTQ(object):
(1) output_spec (list[Tensor]): Selects the output targets of
the saved model. By default, all return variables of original
Layer's forward method are kept as the output of the saved model.
If the provided ``output_spec`` list is not all output variables,
If the provided ``output_spec`` list is not all output variables,
the saved model will be pruned according to the given
``output_spec`` list.
``output_spec`` list.
Returns:
None
"""
assert isinstance(model, paddle.nn.Layer), \
"The model must be the instance of paddle.nn.Layer."
assert isinstance(
model, paddle.nn.Layer
), "The model must be the instance of paddle.nn.Layer."
# Convert and save dygraph quantized model
self._convert(model)
......@@ -156,12 +161,16 @@ class ImperativePTQ(object):
model_filename = basename + INFER_MODEL_SUFFIX
params_filename = basename + INFER_PARAMS_SUFFIX
[infer_program, feed_target_names,
fetch_targets] = (paddle.fluid.io.load_inference_model(
dirname=dirname,
executor=exe,
model_filename=model_filename,
params_filename=params_filename))
[
infer_program,
feed_target_names,
fetch_targets,
] = paddle.fluid.io.load_inference_model(
dirname=dirname,
executor=exe,
model_filename=model_filename,
params_filename=params_filename,
)
# Process inference program
self._clean_up(infer_program)
......@@ -169,13 +178,15 @@ class ImperativePTQ(object):
self._remove_scale_op(infer_program)
# Save final program
paddle.fluid.io.save_inference_model(dirname=dirname,
feeded_var_names=feed_target_names,
target_vars=fetch_targets,
executor=exe,
main_program=infer_program.clone(),
model_filename=model_filename,
params_filename=params_filename)
paddle.fluid.io.save_inference_model(
dirname=dirname,
feeded_var_names=feed_target_names,
target_vars=fetch_targets,
executor=exe,
main_program=infer_program.clone(),
model_filename=model_filename,
params_filename=params_filename,
)
if is_dynamic_mode:
paddle.disable_static()
......@@ -213,8 +224,9 @@ class ImperativePTQ(object):
Returns:
None
"""
assert isinstance(model, paddle.nn.Layer), \
"The input model must be the instance of paddle.nn.Layer."
assert isinstance(
model, paddle.nn.Layer
), "The input model must be the instance of paddle.nn.Layer."
total_num = 0
cur_num = 0
......@@ -226,8 +238,9 @@ class ImperativePTQ(object):
if self._is_quant_layer(sub_layer):
cur_num += 1
if cur_num % 5 == 0:
_logger.info("Process the %s / %s layer" %
(cur_num, total_num))
_logger.info(
"Process the %s / %s layer" % (cur_num, total_num)
)
quant_config = sub_layer._quant_config
......@@ -236,7 +249,7 @@ class ImperativePTQ(object):
quant_config.out_act_quantizer.cal_thresholds()
if PTQRegistry.is_simulated_quant_layer(sub_layer):
weights = (sub_layer.weight, )
weights = (sub_layer.weight,)
quant_config.wt_quantizer.sample_data(sub_layer, weights)
quant_config.wt_quantizer.cal_thresholds()
......@@ -250,18 +263,25 @@ class ImperativePTQ(object):
Returns:
None
"""
assert isinstance(sub_layer, paddle.nn.Layer), \
"The input model must be the instance of paddle.nn.Layer."
assert isinstance(
sub_layer, paddle.nn.Layer
), "The input model must be the instance of paddle.nn.Layer."
layer_info = PTQRegistry.layer_info(sub_layer)
output_names = layer_info.output_names
output_thresholds = quant_config.out_act_quantizer.thresholds
assert len(output_names) == 1
assert len(output_thresholds) == 1
save_name = output_names[0] + str(0) + "_threshold"
sub_layer._set_op_attrs({save_name: output_thresholds[0]})
sub_layer._set_op_attrs({"out_threshold": output_thresholds[0]})
if len(output_thresholds) == 1:
save_name = output_names[0] + str(0) + "_threshold"
sub_layer._set_op_attrs({save_name: output_thresholds[0]})
sub_layer._set_op_attrs({"out_threshold": output_thresholds[0]})
else:
_logger.warning(
"output_thresholds shape of {} need to be 1, but received {}".format(
output_names[0], len(output_thresholds)
)
)
def _wrap_simulated_layers(self, model):
"""
......@@ -272,12 +292,14 @@ class ImperativePTQ(object):
Returns:
None
"""
assert isinstance(model, paddle.nn.Layer), \
"The input model must be the instance of paddle.nn.Layer."
assert isinstance(
model, paddle.nn.Layer
), "The input model must be the instance of paddle.nn.Layer."
for name, sub_layer in model.named_sublayers():
if self._is_quant_layer(sub_layer) \
and PTQRegistry.is_simulated_quant_layer(sub_layer):
if self._is_quant_layer(
sub_layer
) and PTQRegistry.is_simulated_quant_layer(sub_layer):
quant_config = sub_layer._quant_config
assert quant_config.enable_in_act_quantizer == True
......@@ -303,36 +325,44 @@ class ImperativePTQ(object):
"activation_bits": in_act_quantizer.quant_bits,
}
quant_layer = quant_layers.__dict__[quant_layer_name](sub_layer,
**kwargs)
quant_layer = quant_layers.__dict__[quant_layer_name](
sub_layer, **kwargs
)
# save the input thresholds
assert hasattr(quant_layer, "_fake_quant_input")
assert hasattr(quant_layer._fake_quant_input, "_scale")
assert len(in_act_quantizer.thresholds) == 1
input_threshold = np.array([in_act_quantizer.thresholds[0]],
dtype=np.float32)
quant_layer._fake_quant_input._scale.set_value(input_threshold)
if len(in_act_quantizer.thresholds) == 1:
input_threshold = np.array(
[in_act_quantizer.thresholds[0]], dtype=np.float32
)
quant_layer._fake_quant_input._scale.set_value(
input_threshold
)
assert hasattr(quant_layer, "_fake_quant_weight")
assert hasattr(quant_layer._fake_quant_weight, "_scale")
assert len(wt_quantizer.thresholds) == 1
weight_threshold = wt_quantizer.thresholds[0]
if isinstance(weight_threshold, list):
weight_threshold = np.array(weight_threshold,
dtype=np.float32)
weight_threshold = np.array(
weight_threshold, dtype=np.float32
)
else:
weight_threshold = np.array([weight_threshold],
dtype=np.float32)
weight_threshold = np.array(
[weight_threshold], dtype=np.float32
)
quant_layer._fake_quant_weight._scale.set_value(
weight_threshold)
weight_threshold
)
# save the output thresholds
self._save_output_thresholds(quant_layer, quant_config)
# replace the layer
parent_layer, sub_name = \
utils.find_parent_layer_and_sub_name(model, name)
parent_layer, sub_name = utils.find_parent_layer_and_sub_name(
model, name
)
setattr(parent_layer, sub_name, quant_layer)
def _gather_input_thresholds(self, program, scope):
......@@ -351,30 +381,37 @@ class ImperativePTQ(object):
if previous_op is None:
continue
if "quantize_dequantize" in previous_op.type or \
previous_op.type == "moving_average_abs_max_scale":
if (
"quantize_dequantize" in previous_op.type
or previous_op.type == "moving_average_abs_max_scale"
):
attr_name = previous_op.output('OutScale')[0]
in_threshold = utils.load_variable_data(scope, attr_name)
in_threshold = utils.fp_numpy_to_naive(in_threshold)
argname, index = utils._get_input_name_index(
op, in_var_name)
op._set_attr(argname + str(index) + "_threshold",
in_threshold)
op, in_var_name
)
op._set_attr(
argname + str(index) + "_threshold", in_threshold
)
op._set_attr("with_quant_attr", True)
else:
for out_var_name in utils._get_op_output_var_names(
previous_op):
previous_op
):
if out_var_name != in_var_name:
continue
argname, index = utils._get_output_name_index(
previous_op, out_var_name)
previous_op, out_var_name
)
attr_name = argname + str(index) + "_threshold"
if not previous_op.has_attr(attr_name):
continue
threshold = previous_op.attr(attr_name)
argname, index = utils._get_input_name_index(
op, in_var_name)
op, in_var_name
)
attr_name = argname + str(index) + "_threshold"
op._set_attr(attr_name, threshold)
op._set_attr("with_quant_attr", True)
......@@ -390,8 +427,11 @@ class ImperativePTQ(object):
"""
def _helper(op, next_op, old_attr_name, new_attr_name):
if op.has_attr(old_attr_name) and next_op.has_attr(old_attr_name) \
and op.attr(old_attr_name) == next_op.attr(old_attr_name):
if (
op.has_attr(old_attr_name)
and next_op.has_attr(old_attr_name)
and op.attr(old_attr_name) == next_op.attr(old_attr_name)
):
threshold = op.attr(old_attr_name)
op._remove_attr(old_attr_name)
next_op._remove_attr(old_attr_name)
......@@ -417,8 +457,8 @@ class ImperativePTQ(object):
old_attr_name = argname + str(index) + "_threshold"
argname, index = utils._get_output_name_index(
next_op,
next_op.output("Out")[0])
next_op, next_op.output("Out")[0]
)
new_attr_name = argname + str(index) + "_threshold"
_helper(op, next_op, old_attr_name, new_attr_name)
......
......@@ -41,6 +41,7 @@ PTQ_LAYERS_INFO = [
LayerInfo(paddle.nn.ReLU, ['X'], [], ['Out']),
LayerInfo(paddle.nn.ReLU6, ['X'], [], ['Out']),
LayerInfo(paddle.nn.Hardswish, ['X'], [], ['Out']),
LayerInfo(paddle.nn.Swish, ['X'], [], ['Out']),
LayerInfo(paddle.nn.Sigmoid, ['X'], [], ['Out']),
LayerInfo(paddle.nn.Softmax, ['X'], [], ['Out']),
LayerInfo(paddle.nn.Tanh, ['X'], [], ['Out']),
......@@ -48,10 +49,15 @@ PTQ_LAYERS_INFO = [
]
QUANT_LAYERS_INFO = [
LayerInfo(paddle.nn.quant.quant_layers.QuantizedConv2D, ['Input'],
['Filter'], ['Output']),
LayerInfo(paddle.nn.quant.quant_layers.QuantizedLinear, ['X'], ['Y'],
['Out']),
LayerInfo(
paddle.nn.quant.quant_layers.QuantizedConv2D,
['Input'],
['Filter'],
['Output'],
),
LayerInfo(
paddle.nn.quant.quant_layers.QuantizedLinear, ['X'], ['Y'], ['Out']
),
]
SIMULATED_LAYERS = [paddle.nn.Conv2D, paddle.nn.Linear]
......@@ -61,6 +67,7 @@ class PTQRegistry(object):
"""
Register the supported layers for PTQ and provide layers info.
"""
supported_layers_map = {}
registered_layers_map = {}
is_inited = False
......@@ -89,8 +96,9 @@ class PTQRegistry(object):
flag(bool): Whether the layer is supported.
"""
cls._init()
return layer in cls.supported_layers_map or \
isinstance(layer, tuple(cls.supported_layers_map.keys()))
return layer in cls.supported_layers_map or isinstance(
layer, tuple(cls.supported_layers_map.keys())
)
@classmethod
def is_registered_layer(cls, layer):
......@@ -102,8 +110,9 @@ class PTQRegistry(object):
flag(bool): Whether the layer has registered layer_info.
"""
cls._init()
return layer in cls.registered_layers_map or \
isinstance(layer, tuple(cls.registered_layers_map.keys()))
return layer in cls.registered_layers_map or isinstance(
layer, tuple(cls.registered_layers_map.keys())
)
@classmethod
def is_simulated_quant_layer(cls, layer):
......@@ -114,8 +123,9 @@ class PTQRegistry(object):
Returns:
flag(bool): Whether the layer is supported.
"""
return layer in SIMULATED_LAYERS or \
isinstance(layer, tuple(SIMULATED_LAYERS))
return layer in SIMULATED_LAYERS or isinstance(
layer, tuple(SIMULATED_LAYERS)
)
@classmethod
def layer_info(cls, layer):
......@@ -126,8 +136,9 @@ class PTQRegistry(object):
Returns:
layer_info(LayerInfo): The layer info of the input layer.
"""
assert cls.is_registered_layer(layer), \
"The input layer is not register."
assert cls.is_registered_layer(
layer
), "The input layer is not register."
for layer_key, layer_info in cls.registered_layers_map.items():
if layer == layer_key or isinstance(layer, layer_key):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册