未验证 提交 12b9b03e 编写于 作者: G Guanghua Yu 提交者: GitHub

[cherry-pick] update dygraph PTQ export_model api (#47415)

* update dygraph PTQ export_model api

* remove postprocess
上级 df64e790
...@@ -31,9 +31,9 @@ from .ptq_registry import PTQRegistry ...@@ -31,9 +31,9 @@ from .ptq_registry import PTQRegistry
__all__ = ['ImperativePTQ'] __all__ = ['ImperativePTQ']
_logger = get_logger(__name__, _logger = get_logger(
logging.INFO, __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
fmt='%(asctime)s-%(levelname)s: %(message)s') )
class ImperativePTQ(object): class ImperativePTQ(object):
...@@ -75,17 +75,20 @@ class ImperativePTQ(object): ...@@ -75,17 +75,20 @@ class ImperativePTQ(object):
Return Return
quantized_model(paddle.nn.Layer): The quantized model. quantized_model(paddle.nn.Layer): The quantized model.
""" """
assert isinstance(model, paddle.nn.Layer), \ assert isinstance(
"The model must be the instance of paddle.nn.Layer." model, paddle.nn.Layer
), "The model must be the instance of paddle.nn.Layer."
if not inplace: if not inplace:
model = copy.deepcopy(model) model = copy.deepcopy(model)
if fuse: if fuse:
model.eval() model.eval()
model = fuse_utils.fuse_layers(model, fuse_list) model = fuse_utils.fuse_layers(model, fuse_list)
for name, layer in model.named_sublayers(): for name, layer in model.named_sublayers():
if PTQRegistry.is_supported_layer(layer) \ if (
and utils.is_leaf_layer(layer) \ PTQRegistry.is_supported_layer(layer)
and not self._is_skip_layer(layer): and utils.is_leaf_layer(layer)
and not self._is_skip_layer(layer)
):
# Add quant config # Add quant config
quant_config = copy.deepcopy(self._quant_config) quant_config = copy.deepcopy(self._quant_config)
...@@ -98,7 +101,8 @@ class ImperativePTQ(object): ...@@ -98,7 +101,8 @@ class ImperativePTQ(object):
quant_hook_handle = layer.register_forward_post_hook(hook) quant_hook_handle = layer.register_forward_post_hook(hook)
quant_config.quant_hook_handle = quant_hook_handle quant_config.quant_hook_handle = quant_hook_handle
layer._forward_post_hooks.move_to_end( layer._forward_post_hooks.move_to_end(
quant_hook_handle._hook_id, last=False) quant_hook_handle._hook_id, last=False
)
return model return model
...@@ -110,14 +114,14 @@ class ImperativePTQ(object): ...@@ -110,14 +114,14 @@ class ImperativePTQ(object):
Args: Args:
model (Layer): The model to be saved. model (Layer): The model to be saved.
path (str): The path prefix to save model. The format is path (str): The path prefix to save model. The format is
``dirname/file_prefix`` or ``file_prefix``. ``dirname/file_prefix`` or ``file_prefix``.
input_spec (list[InputSpec|Tensor], optional): Describes the input input_spec (list[InputSpec|Tensor], optional): Describes the input
of the saved model's forward method, which can be described by of the saved model's forward method, which can be described by
InputSpec or example Tensor. If None, all input variables of InputSpec or example Tensor. If None, all input variables of
the original Layer's forward method would be the inputs of the original Layer's forward method would be the inputs of
the saved model. Default None. the saved model. Default None.
**configs (dict, optional): Other save configuration options for **config (dict, optional): Other save configuration options for
compatibility. We do not recommend using these configurations, compatibility. We do not recommend using these configurations,
they may be removed in the future. If not necessary, DO NOT use they may be removed in the future. If not necessary, DO NOT use
them. Default None. them. Default None.
...@@ -125,16 +129,17 @@ class ImperativePTQ(object): ...@@ -125,16 +129,17 @@ class ImperativePTQ(object):
(1) output_spec (list[Tensor]): Selects the output targets of (1) output_spec (list[Tensor]): Selects the output targets of
the saved model. By default, all return variables of original the saved model. By default, all return variables of original
Layer's forward method are kept as the output of the saved model. Layer's forward method are kept as the output of the saved model.
If the provided ``output_spec`` list is not all output variables, If the provided ``output_spec`` list is not all output variables,
the saved model will be pruned according to the given the saved model will be pruned according to the given
``output_spec`` list. ``output_spec`` list.
Returns: Returns:
None None
""" """
assert isinstance(model, paddle.nn.Layer), \ assert isinstance(
"The model must be the instance of paddle.nn.Layer." model, paddle.nn.Layer
), "The model must be the instance of paddle.nn.Layer."
# Convert and save dygraph quantized model # Convert and save dygraph quantized model
self._convert(model) self._convert(model)
...@@ -156,12 +161,16 @@ class ImperativePTQ(object): ...@@ -156,12 +161,16 @@ class ImperativePTQ(object):
model_filename = basename + INFER_MODEL_SUFFIX model_filename = basename + INFER_MODEL_SUFFIX
params_filename = basename + INFER_PARAMS_SUFFIX params_filename = basename + INFER_PARAMS_SUFFIX
[infer_program, feed_target_names, [
fetch_targets] = (paddle.fluid.io.load_inference_model( infer_program,
dirname=dirname, feed_target_names,
executor=exe, fetch_targets,
model_filename=model_filename, ] = paddle.fluid.io.load_inference_model(
params_filename=params_filename)) dirname=dirname,
executor=exe,
model_filename=model_filename,
params_filename=params_filename,
)
# Process inference program # Process inference program
self._clean_up(infer_program) self._clean_up(infer_program)
...@@ -169,13 +178,15 @@ class ImperativePTQ(object): ...@@ -169,13 +178,15 @@ class ImperativePTQ(object):
self._remove_scale_op(infer_program) self._remove_scale_op(infer_program)
# Save final program # Save final program
paddle.fluid.io.save_inference_model(dirname=dirname, paddle.fluid.io.save_inference_model(
feeded_var_names=feed_target_names, dirname=dirname,
target_vars=fetch_targets, feeded_var_names=feed_target_names,
executor=exe, target_vars=fetch_targets,
main_program=infer_program.clone(), executor=exe,
model_filename=model_filename, main_program=infer_program.clone(),
params_filename=params_filename) model_filename=model_filename,
params_filename=params_filename,
)
if is_dynamic_mode: if is_dynamic_mode:
paddle.disable_static() paddle.disable_static()
...@@ -213,8 +224,9 @@ class ImperativePTQ(object): ...@@ -213,8 +224,9 @@ class ImperativePTQ(object):
Returns: Returns:
None None
""" """
assert isinstance(model, paddle.nn.Layer), \ assert isinstance(
"The input model must be the instance of paddle.nn.Layer." model, paddle.nn.Layer
), "The input model must be the instance of paddle.nn.Layer."
total_num = 0 total_num = 0
cur_num = 0 cur_num = 0
...@@ -226,8 +238,9 @@ class ImperativePTQ(object): ...@@ -226,8 +238,9 @@ class ImperativePTQ(object):
if self._is_quant_layer(sub_layer): if self._is_quant_layer(sub_layer):
cur_num += 1 cur_num += 1
if cur_num % 5 == 0: if cur_num % 5 == 0:
_logger.info("Process the %s / %s layer" % _logger.info(
(cur_num, total_num)) "Process the %s / %s layer" % (cur_num, total_num)
)
quant_config = sub_layer._quant_config quant_config = sub_layer._quant_config
...@@ -236,7 +249,7 @@ class ImperativePTQ(object): ...@@ -236,7 +249,7 @@ class ImperativePTQ(object):
quant_config.out_act_quantizer.cal_thresholds() quant_config.out_act_quantizer.cal_thresholds()
if PTQRegistry.is_simulated_quant_layer(sub_layer): if PTQRegistry.is_simulated_quant_layer(sub_layer):
weights = (sub_layer.weight, ) weights = (sub_layer.weight,)
quant_config.wt_quantizer.sample_data(sub_layer, weights) quant_config.wt_quantizer.sample_data(sub_layer, weights)
quant_config.wt_quantizer.cal_thresholds() quant_config.wt_quantizer.cal_thresholds()
...@@ -250,18 +263,25 @@ class ImperativePTQ(object): ...@@ -250,18 +263,25 @@ class ImperativePTQ(object):
Returns: Returns:
None None
""" """
assert isinstance(sub_layer, paddle.nn.Layer), \ assert isinstance(
"The input model must be the instance of paddle.nn.Layer." sub_layer, paddle.nn.Layer
), "The input model must be the instance of paddle.nn.Layer."
layer_info = PTQRegistry.layer_info(sub_layer) layer_info = PTQRegistry.layer_info(sub_layer)
output_names = layer_info.output_names output_names = layer_info.output_names
output_thresholds = quant_config.out_act_quantizer.thresholds output_thresholds = quant_config.out_act_quantizer.thresholds
assert len(output_names) == 1 assert len(output_names) == 1
assert len(output_thresholds) == 1 if len(output_thresholds) == 1:
save_name = output_names[0] + str(0) + "_threshold" save_name = output_names[0] + str(0) + "_threshold"
sub_layer._set_op_attrs({save_name: output_thresholds[0]}) sub_layer._set_op_attrs({save_name: output_thresholds[0]})
sub_layer._set_op_attrs({"out_threshold": output_thresholds[0]}) sub_layer._set_op_attrs({"out_threshold": output_thresholds[0]})
else:
_logger.warning(
"output_thresholds shape of {} need to be 1, but received {}".format(
output_names[0], len(output_thresholds)
)
)
def _wrap_simulated_layers(self, model): def _wrap_simulated_layers(self, model):
""" """
...@@ -272,12 +292,14 @@ class ImperativePTQ(object): ...@@ -272,12 +292,14 @@ class ImperativePTQ(object):
Returns: Returns:
None None
""" """
assert isinstance(model, paddle.nn.Layer), \ assert isinstance(
"The input model must be the instance of paddle.nn.Layer." model, paddle.nn.Layer
), "The input model must be the instance of paddle.nn.Layer."
for name, sub_layer in model.named_sublayers(): for name, sub_layer in model.named_sublayers():
if self._is_quant_layer(sub_layer) \ if self._is_quant_layer(
and PTQRegistry.is_simulated_quant_layer(sub_layer): sub_layer
) and PTQRegistry.is_simulated_quant_layer(sub_layer):
quant_config = sub_layer._quant_config quant_config = sub_layer._quant_config
assert quant_config.enable_in_act_quantizer == True assert quant_config.enable_in_act_quantizer == True
...@@ -303,36 +325,44 @@ class ImperativePTQ(object): ...@@ -303,36 +325,44 @@ class ImperativePTQ(object):
"activation_bits": in_act_quantizer.quant_bits, "activation_bits": in_act_quantizer.quant_bits,
} }
quant_layer = quant_layers.__dict__[quant_layer_name](sub_layer, quant_layer = quant_layers.__dict__[quant_layer_name](
**kwargs) sub_layer, **kwargs
)
# save the input thresholds # save the input thresholds
assert hasattr(quant_layer, "_fake_quant_input") assert hasattr(quant_layer, "_fake_quant_input")
assert hasattr(quant_layer._fake_quant_input, "_scale") assert hasattr(quant_layer._fake_quant_input, "_scale")
assert len(in_act_quantizer.thresholds) == 1 if len(in_act_quantizer.thresholds) == 1:
input_threshold = np.array([in_act_quantizer.thresholds[0]], input_threshold = np.array(
dtype=np.float32) [in_act_quantizer.thresholds[0]], dtype=np.float32
quant_layer._fake_quant_input._scale.set_value(input_threshold) )
quant_layer._fake_quant_input._scale.set_value(
input_threshold
)
assert hasattr(quant_layer, "_fake_quant_weight") assert hasattr(quant_layer, "_fake_quant_weight")
assert hasattr(quant_layer._fake_quant_weight, "_scale") assert hasattr(quant_layer._fake_quant_weight, "_scale")
assert len(wt_quantizer.thresholds) == 1 assert len(wt_quantizer.thresholds) == 1
weight_threshold = wt_quantizer.thresholds[0] weight_threshold = wt_quantizer.thresholds[0]
if isinstance(weight_threshold, list): if isinstance(weight_threshold, list):
weight_threshold = np.array(weight_threshold, weight_threshold = np.array(
dtype=np.float32) weight_threshold, dtype=np.float32
)
else: else:
weight_threshold = np.array([weight_threshold], weight_threshold = np.array(
dtype=np.float32) [weight_threshold], dtype=np.float32
)
quant_layer._fake_quant_weight._scale.set_value( quant_layer._fake_quant_weight._scale.set_value(
weight_threshold) weight_threshold
)
# save the output thresholds # save the output thresholds
self._save_output_thresholds(quant_layer, quant_config) self._save_output_thresholds(quant_layer, quant_config)
# replace the layer # replace the layer
parent_layer, sub_name = \ parent_layer, sub_name = utils.find_parent_layer_and_sub_name(
utils.find_parent_layer_and_sub_name(model, name) model, name
)
setattr(parent_layer, sub_name, quant_layer) setattr(parent_layer, sub_name, quant_layer)
def _gather_input_thresholds(self, program, scope): def _gather_input_thresholds(self, program, scope):
...@@ -351,30 +381,37 @@ class ImperativePTQ(object): ...@@ -351,30 +381,37 @@ class ImperativePTQ(object):
if previous_op is None: if previous_op is None:
continue continue
if "quantize_dequantize" in previous_op.type or \ if (
previous_op.type == "moving_average_abs_max_scale": "quantize_dequantize" in previous_op.type
or previous_op.type == "moving_average_abs_max_scale"
):
attr_name = previous_op.output('OutScale')[0] attr_name = previous_op.output('OutScale')[0]
in_threshold = utils.load_variable_data(scope, attr_name) in_threshold = utils.load_variable_data(scope, attr_name)
in_threshold = utils.fp_numpy_to_naive(in_threshold) in_threshold = utils.fp_numpy_to_naive(in_threshold)
argname, index = utils._get_input_name_index( argname, index = utils._get_input_name_index(
op, in_var_name) op, in_var_name
op._set_attr(argname + str(index) + "_threshold", )
in_threshold) op._set_attr(
argname + str(index) + "_threshold", in_threshold
)
op._set_attr("with_quant_attr", True) op._set_attr("with_quant_attr", True)
else: else:
for out_var_name in utils._get_op_output_var_names( for out_var_name in utils._get_op_output_var_names(
previous_op): previous_op
):
if out_var_name != in_var_name: if out_var_name != in_var_name:
continue continue
argname, index = utils._get_output_name_index( argname, index = utils._get_output_name_index(
previous_op, out_var_name) previous_op, out_var_name
)
attr_name = argname + str(index) + "_threshold" attr_name = argname + str(index) + "_threshold"
if not previous_op.has_attr(attr_name): if not previous_op.has_attr(attr_name):
continue continue
threshold = previous_op.attr(attr_name) threshold = previous_op.attr(attr_name)
argname, index = utils._get_input_name_index( argname, index = utils._get_input_name_index(
op, in_var_name) op, in_var_name
)
attr_name = argname + str(index) + "_threshold" attr_name = argname + str(index) + "_threshold"
op._set_attr(attr_name, threshold) op._set_attr(attr_name, threshold)
op._set_attr("with_quant_attr", True) op._set_attr("with_quant_attr", True)
...@@ -390,8 +427,11 @@ class ImperativePTQ(object): ...@@ -390,8 +427,11 @@ class ImperativePTQ(object):
""" """
def _helper(op, next_op, old_attr_name, new_attr_name): def _helper(op, next_op, old_attr_name, new_attr_name):
if op.has_attr(old_attr_name) and next_op.has_attr(old_attr_name) \ if (
and op.attr(old_attr_name) == next_op.attr(old_attr_name): op.has_attr(old_attr_name)
and next_op.has_attr(old_attr_name)
and op.attr(old_attr_name) == next_op.attr(old_attr_name)
):
threshold = op.attr(old_attr_name) threshold = op.attr(old_attr_name)
op._remove_attr(old_attr_name) op._remove_attr(old_attr_name)
next_op._remove_attr(old_attr_name) next_op._remove_attr(old_attr_name)
...@@ -417,8 +457,8 @@ class ImperativePTQ(object): ...@@ -417,8 +457,8 @@ class ImperativePTQ(object):
old_attr_name = argname + str(index) + "_threshold" old_attr_name = argname + str(index) + "_threshold"
argname, index = utils._get_output_name_index( argname, index = utils._get_output_name_index(
next_op, next_op, next_op.output("Out")[0]
next_op.output("Out")[0]) )
new_attr_name = argname + str(index) + "_threshold" new_attr_name = argname + str(index) + "_threshold"
_helper(op, next_op, old_attr_name, new_attr_name) _helper(op, next_op, old_attr_name, new_attr_name)
......
...@@ -41,6 +41,7 @@ PTQ_LAYERS_INFO = [ ...@@ -41,6 +41,7 @@ PTQ_LAYERS_INFO = [
LayerInfo(paddle.nn.ReLU, ['X'], [], ['Out']), LayerInfo(paddle.nn.ReLU, ['X'], [], ['Out']),
LayerInfo(paddle.nn.ReLU6, ['X'], [], ['Out']), LayerInfo(paddle.nn.ReLU6, ['X'], [], ['Out']),
LayerInfo(paddle.nn.Hardswish, ['X'], [], ['Out']), LayerInfo(paddle.nn.Hardswish, ['X'], [], ['Out']),
LayerInfo(paddle.nn.Swish, ['X'], [], ['Out']),
LayerInfo(paddle.nn.Sigmoid, ['X'], [], ['Out']), LayerInfo(paddle.nn.Sigmoid, ['X'], [], ['Out']),
LayerInfo(paddle.nn.Softmax, ['X'], [], ['Out']), LayerInfo(paddle.nn.Softmax, ['X'], [], ['Out']),
LayerInfo(paddle.nn.Tanh, ['X'], [], ['Out']), LayerInfo(paddle.nn.Tanh, ['X'], [], ['Out']),
...@@ -48,10 +49,15 @@ PTQ_LAYERS_INFO = [ ...@@ -48,10 +49,15 @@ PTQ_LAYERS_INFO = [
] ]
QUANT_LAYERS_INFO = [ QUANT_LAYERS_INFO = [
LayerInfo(paddle.nn.quant.quant_layers.QuantizedConv2D, ['Input'], LayerInfo(
['Filter'], ['Output']), paddle.nn.quant.quant_layers.QuantizedConv2D,
LayerInfo(paddle.nn.quant.quant_layers.QuantizedLinear, ['X'], ['Y'], ['Input'],
['Out']), ['Filter'],
['Output'],
),
LayerInfo(
paddle.nn.quant.quant_layers.QuantizedLinear, ['X'], ['Y'], ['Out']
),
] ]
SIMULATED_LAYERS = [paddle.nn.Conv2D, paddle.nn.Linear] SIMULATED_LAYERS = [paddle.nn.Conv2D, paddle.nn.Linear]
...@@ -61,6 +67,7 @@ class PTQRegistry(object): ...@@ -61,6 +67,7 @@ class PTQRegistry(object):
""" """
Register the supported layers for PTQ and provide layers info. Register the supported layers for PTQ and provide layers info.
""" """
supported_layers_map = {} supported_layers_map = {}
registered_layers_map = {} registered_layers_map = {}
is_inited = False is_inited = False
...@@ -89,8 +96,9 @@ class PTQRegistry(object): ...@@ -89,8 +96,9 @@ class PTQRegistry(object):
flag(bool): Whether the layer is supported. flag(bool): Whether the layer is supported.
""" """
cls._init() cls._init()
return layer in cls.supported_layers_map or \ return layer in cls.supported_layers_map or isinstance(
isinstance(layer, tuple(cls.supported_layers_map.keys())) layer, tuple(cls.supported_layers_map.keys())
)
@classmethod @classmethod
def is_registered_layer(cls, layer): def is_registered_layer(cls, layer):
...@@ -102,8 +110,9 @@ class PTQRegistry(object): ...@@ -102,8 +110,9 @@ class PTQRegistry(object):
flag(bool): Whether the layer has a registered layer_info. flag(bool): Whether the layer has a registered layer_info.
""" """
cls._init() cls._init()
return layer in cls.registered_layers_map or \ return layer in cls.registered_layers_map or isinstance(
isinstance(layer, tuple(cls.registered_layers_map.keys())) layer, tuple(cls.registered_layers_map.keys())
)
@classmethod @classmethod
def is_simulated_quant_layer(cls, layer): def is_simulated_quant_layer(cls, layer):
...@@ -114,8 +123,9 @@ class PTQRegistry(object): ...@@ -114,8 +123,9 @@ class PTQRegistry(object):
Returns: Returns:
flag(bool): Whether the layer is supported. flag(bool): Whether the layer is supported.
""" """
return layer in SIMULATED_LAYERS or \ return layer in SIMULATED_LAYERS or isinstance(
isinstance(layer, tuple(SIMULATED_LAYERS)) layer, tuple(SIMULATED_LAYERS)
)
@classmethod @classmethod
def layer_info(cls, layer): def layer_info(cls, layer):
...@@ -126,8 +136,9 @@ class PTQRegistry(object): ...@@ -126,8 +136,9 @@ class PTQRegistry(object):
Returns: Returns:
layer_info(LayerInfo): The layer info of the input layer. layer_info(LayerInfo): The layer info of the input layer.
""" """
assert cls.is_registered_layer(layer), \ assert cls.is_registered_layer(
"The input layer is not register." layer
), "The input layer is not register."
for layer_key, layer_info in cls.registered_layers_map.items(): for layer_key, layer_info in cls.registered_layers_map.items():
if layer == layer_key or isinstance(layer, layer_key): if layer == layer_key or isinstance(layer, layer_key):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册