Unverified commit 84a55138, authored by cc, committed by GitHub

[dygraph qat] Refine saving output scale to infer program (#31784)

* Refine saving output scale to infer program
Parent 68497e7b
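
For context, the sketch below is not part of the commit; it shows one way to check that the refined flow actually attaches output scales to the saved inference program, by loading the model written by save_quantized_model and listing the ops that carry the "out_threshold" attribute. The model path is hypothetical, and it assumes the Paddle 2.x static-graph APIs.

import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

# "./quant_model/lenet" stands in for the `path` passed to save_quantized_model().
[infer_program, feed_names, fetch_targets] = paddle.static.load_inference_model(
    path_prefix="./quant_model/lenet", executor=exe)

# Ops whose dygraph Layer recorded an output scale should now carry
# the "out_threshold" attribute in the static program.
for op in infer_program.global_block().ops:
    if op.has_attr("out_threshold"):
        print(op.type, op.attr("out_threshold"))
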
@@ -251,8 +251,8 @@ class ImperativeQuantizeInputs(object):
         super(ImperativeQuantizeInputs, self).__init__()

         self._quantizable_layer_type = tuple(
-            utils.supported_quant_layers_map[layer]
-            if layer in utils.supported_quant_layers_map else layer
+            utils.quant_input_layers_map[layer]
+            if layer in utils.quant_input_layers_map else layer
             for layer in quantizable_layer_type)
         for layer in self._quantizable_layer_type:
             assert not isinstance(layer, str), \
@@ -324,12 +324,11 @@ class ImperativeQuantizeInputs(object):
             target = name[last_idx:idx]

             quant_layer = self._get_quantized_layer(layer)
-            setattr(quant_layer, "layer_name", layer.full_name())
             setattr(obj, target, quant_layer)

     def _get_quantized_layer(self, layer):
         quant_layer_name = None
-        for key, value in utils.supported_quant_layers_map.items():
+        for key, value in utils.quant_input_layers_map.items():
             if isinstance(layer, value):
                 quant_layer_name = 'Quantized' + key
                 break
@@ -372,6 +371,9 @@ class ImperativeCalcOutputScale(object):
         """
         assert isinstance(model, dygraph.Layer), \
             "The model must be the instance of dygraph.Layer."
+
+        # Calculate the target ops's output scale, and don't consider
+        # the skip_quant attr
         for _, layer in model.named_sublayers():
             if self._is_target_layer(layer):
                 self._init_scale_params(layer)
@@ -411,24 +413,21 @@ class ImperativeCalcOutputScale(object):
         assert isinstance(layer, dygraph.Layer), \
             "The model must be the instance of dygraph.Layer."

-        # remove handles and collect output scales
+        self._gather_output_scale(layer)
+
         with dygraph.guard():
             layer.eval()
             for handle in self._register_hook_handle_list:
                 handle.remove()
-            for _, sub_layer in layer.named_sublayers():
-                if self._is_target_layer(sub_layer):
-                    if hasattr(sub_layer, "layer_name"):
-                        layer_name = sub_layer.layer_name
-                    else:
-                        layer_name = sub_layer.full_name()
-                    if hasattr(sub_layer, "_quant_out_scale"):
-                        self._out_scale_dict[layer_name] = float(
-                            sub_layer._quant_out_scale)

+        # save the quantized model that doesn't have output scales
         paddle.jit.save(layer=layer, path=path, input_spec=input_spec, **config)

+        if len(self._out_scale_dict) == 0:
+            warnings.warn("Warning: No Layer of the model while to be " \
+                "saved contains the out_threshold attribute, so the " \
+                "generated inference model would not contain the " \
+                "out_threshold.")
+            return
+
         # load static model
         is_dynamic_mode = False
         if paddle.in_dynamic_mode():
@@ -443,79 +442,26 @@ class ImperativeCalcOutputScale(object):
         basename = os.path.basename(path)
         model_filename = basename + INFER_MODEL_SUFFIX
         params_filename = basename + INFER_PARAMS_SUFFIX
-        [inference_program, feed_target_names, fetch_targets] = (
+        [infer_program, feed_target_names, fetch_targets] = (
             load_inference_model(
                 dirname=dirname,
                 executor=exe,
                 model_filename=model_filename,
                 params_filename=params_filename))

+        # TODO(jc): analyse whether the dygraph model has
+        # several blocks before applying qat
+        assert infer_program.num_blocks == 1, \
+            "Quantization aware training (QAT) requires the program " \
+            "only has a block for now. When the model has if-else or " \
+            "while, the program will have several blocks."
+
         # set output scales to the static model
-        check_behind_op = False
-        op_count = 0
-        ops_list = [key for key, _ in self._out_scale_dict.items()]
-        if len(ops_list) == 0:
-            warnings.warn(
-                "Warning: No Layer of the model while to be saved contains "
-                "the out_threshold attribute, so the generated inference "
-                "model would not contain the out_threshold.")
-        else:
-            # Because the Layer in dygraph may correspond to multiple ops
-            # in static program after being saved. To ensure correctness,
-            # the outscale collected for output of dygraph Layer can only
-            # be set to the last op in the corresponding ops in static program.
-            #
-            # We can judge the execution order of the ops which corresponding
-            # to dygraph Layer by check_behind_op
-            forward_op = None
-            for block in inference_program.blocks:
-                for op in block.ops:
-                    if op.type in utils.op_real_in_out_name:
-                        if op_count > len(ops_list):
-                            warnings.warn(
-                                "The number of Layer which has "
-                                "out_threshold attribute should be bigger than "
-                                "the op in inference model")
-                            break
-                        if check_behind_op:
-                            check_behind_op = False
-                            if op.type == "elementwise_add":
-                                if self._is_op_matched(ops_list[op_count], op,
-                                                       block):
-                                    op._set_attr("out_threshold",
-                                                 self._out_scale_dict[ops_list[
-                                                     op_count]])
-                                    op_count += 1
-                                    forward_op = None
-                                continue
-                            else:
-                                if forward_op is None:
-                                    raise ValueError(
-                                        "forward_op should not be None")
-                                if self._is_op_matched(ops_list[op_count],
-                                                       forward_op, block):
-                                    forward_op._set_attr(
-                                        "out_threshold", self._out_scale_dict[
-                                            ops_list[op_count]])
-                                    op_count += 1
-                                forward_op = None
-
-                        if op.type in ["conv2d", "depthwise_conv2d", "matmul"]:
-                            check_behind_op = True
-                            forward_op = op
-                            continue
-                        if op_count >= len(ops_list):
-                            warnings.warn(
-                                "The number of Layer which has out_threshold attribute should be bigger than the op in inference model"
-                            )
-                            break
-                        if self._is_op_matched(ops_list[op_count], op, block):
-                            op._set_attr(
-                                "out_threshold",
-                                self._out_scale_dict[ops_list[op_count]])
-                            op_count += 1
-
-        self._set_skip_quant_attr(inference_program)
+        self._save_output_scale(infer_program)
+
+        # process skip quant
+        self._set_skip_quant_attr(infer_program)

         # save the final quantized model that has output scales
         save_inference_model(
@@ -523,16 +469,75 @@ class ImperativeCalcOutputScale(object):
             feeded_var_names=feed_target_names,
             target_vars=fetch_targets,
             executor=exe,
-            main_program=inference_program.clone(),
+            main_program=infer_program.clone(),
             model_filename=model_filename,
             params_filename=params_filename)

         if is_dynamic_mode:
             paddle.disable_static()

+    def _gather_output_scale(self, layer):
+        """
+        Gather all output scales to self._out_scale_dict
+        """
+        with dygraph.guard():
+            layer.eval()
+            for _, sub_layer in layer.named_sublayers():
+                if self._is_target_layer(sub_layer):
+                    layer_name = sub_layer.full_name()
+                    if hasattr(sub_layer, "_quant_out_scale"):
+                        self._out_scale_dict[layer_name] = float(
+                            sub_layer._quant_out_scale)
+
+    def _save_output_scale(self, infer_program):
+        """
+        Save all output scales to the corresponding ops in static
+        inference program.
+
+        Because the Layer in dygraph may correspond to multiple ops
+        in static program after being saved. To ensure correctness,
+        the outscale collected for output of dygraph Layer can only
+        be set to the last op in the corresponding ops in static program.
+        """
+        assert infer_program.num_blocks == 1, \
+            "The inference program should only have a block."
+
+        global_block = infer_program.global_block()
+        target_ops = global_block.ops
+
+        scale_idx = 0
+        op_idx = 0
+        attr_name = "out_threshold"
+
+        for scale_name, scale_value in self._out_scale_dict.items():
+            while True:
+                if op_idx >= len(target_ops):
+                    break
+
+                op = target_ops[op_idx]
+                if not self._is_scale_op_matched(scale_name, op, global_block):
+                    op_idx += 1
+                else:
+                    if op.type in utils.weight_op_types \
+                        and op_idx + 1 < len(target_ops) \
+                        and target_ops[op_idx+1].type == "elementwise_add":
+                        target_ops[op_idx + 1]._set_attr(attr_name, scale_value)
+                        op_idx += 2
+                    else:
+                        op._set_attr(attr_name, scale_value)
+                        op_idx += 1
+                    scale_idx += 1
+                    break
+
+        if scale_idx != len(self._out_scale_dict):
+            _logger.warning("Warning: the model have %s output scales, "\
+                "but it only saves %s output scales." \
+                % (len(self._out_scale_dict), scale_idx))
+
     def _is_target_layer(self, layer):
-        return isinstance(layer, utils.out_scale_layers_list) \
-            or 'quantized_' in layer.full_name()
+        return isinstance(layer, tuple(utils.quant_output_layers_map.values())) \
+            or ('quantized_' in layer.full_name() and \
+                'quantized_noweight' not in layer.full_name())

     def _init_scale_params(self, layer, name=None):
         """
@@ -570,27 +575,39 @@ class ImperativeCalcOutputScale(object):
         layer._quant_out_accum = _create_param(layer, name, "accum", dtype)
         layer._quant_out_accum.stop_gradient = True

-    # Judge whether the op in program matches the Layer in dynamic model
-    def _is_op_matched(self, layer_name, op, block):
-        output_var_names = quantization_pass._get_op_output_var_names(op)
-        for output_var_name in output_var_names:
-            output_var_tensor = block.var(output_var_name)
-            if output_var_tensor.dtype not in [
-                    core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32
-            ]:
-                return False
-
-        # Because the naming styles of static and dynamic graph are different,
-        # in order to avoid mistakes, we unify the name here.
-        op_type = output_var_names[0].split(".")[0]
-        op_type = op_type.rsplit("_", 1)[0]
-        if op_type == 'depthwise_conv2d':
-            op_type = 'conv2d'
-        if 'prelu' in op_type:
-            op_type = op_type.replace('prelu', 'p_re_lu')
-        if 'relu' in op_type:
-            op_type = op_type.replace('relu', 're_lu')
-        return op_type in layer_name
+    def _is_scale_op_matched(self, scale_name, op, block):
+        """
+        Based on the op name and attrs to judge whether the op in
+        program matches the scale_name. We must know the corresponding
+        name between dgraph and static model.
+        """
+        fp_type = [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]
+        if op.type in quantization_pass._op_real_in_out_name.keys():
+            output_var_names = quantization_pass._get_op_output_var_names(op)
+            for output_var_name in output_var_names:
+                output_var_tensor = block.var(output_var_name)
+                if output_var_tensor.dtype not in fp_type:
+                    return False
+
+        # corresponding_map: [name, op_types, function]
+        # Note that, the items have priority in corresponding_map
+        corresponding_map = [
+            ['conv2d_tranpose', ['conv2d_transpose', \
+                'depthwise_conv2d_transpose'], None],
+            ['conv2d', ['conv2d', 'depthwise_conv2d'], None],
+            ['linear', ['matmul'], None],
+            ['re_lu6', ['relu6'], None],
+            ['p_re_lu', ['prelu'], None],
+            ['leaky_re_lu', ['leaky_relu'], None],
+            ['re_lu', ['relu'], None],
+        ]
+
+        for item in corresponding_map:
+            if item[0] in scale_name:
+                return (op.type in item[1]) and \
+                    (len(item) == 2 or item[2] is None or item[2](op))
+
+        return op.type in scale_name

     def _set_skip_quant_attr(self, program):
         block = program.global_block()
...
@@ -30,7 +30,7 @@ op_real_in_out_name = {
     "swish": [["X"], ["Out"]],
 }

-supported_quant_layers_map = {
+quant_input_layers_map = {
     'Conv2D': paddle.nn.Conv2D,
     'Linear': paddle.nn.Linear,
     'AdaptiveAvgPool2D': paddle.nn.AdaptiveAvgPool2D,
@@ -58,8 +58,30 @@ fake_quantize_dequantize_types = [
     "fake_quantize_dequantize_moving_average_abs_max"
 ]

-out_scale_layers_list = (
-    paddle.nn.Conv2D, paddle.nn.Linear, paddle.nn.MaxPool2D,
-    paddle.nn.BatchNorm, paddle.nn.BatchNorm2D, paddle.nn.SyncBatchNorm,
-    paddle.nn.LeakyReLU, paddle.nn.PReLU, paddle.nn.ReLU, paddle.nn.ReLU6,
-    paddle.nn.Sigmoid, paddle.nn.Softmax, paddle.nn.Tanh, paddle.nn.Swish)
+quant_output_layers_map = {
+    'Conv2D': paddle.nn.Conv2D,
+    'Conv2DTranspose': paddle.nn.Conv2DTranspose,
+    'Linear': paddle.nn.Linear,
+    'AdaptiveAvgPool2D': paddle.nn.AdaptiveAvgPool2D,
+    'AdaptiveMaxPool2D': paddle.nn.AdaptiveMaxPool2D,
+    'AvgPool2D': paddle.nn.AvgPool2D,
+    'MaxPool2D': paddle.nn.MaxPool2D,
+    'BatchNorm': paddle.nn.BatchNorm,
+    'BatchNorm2D': paddle.nn.BatchNorm2D,
+    'SyncBatchNorm': paddle.nn.SyncBatchNorm,
+    'ELU': paddle.nn.ELU,
+    'GELU': paddle.nn.GELU,
+    'LeakyReLU': paddle.nn.LeakyReLU,
+    'PReLU': paddle.nn.PReLU,
+    'ReLU': paddle.nn.ReLU,
+    'ReLU6': paddle.nn.ReLU6,
+    'Sigmoid': paddle.nn.Sigmoid,
+    'Softmax': paddle.nn.Softmax,
+    'Tanh': paddle.nn.Tanh,
+    'Swish': paddle.nn.Swish,
+}
+
+weight_op_types = [
+    "conv2d", "depthwise_conv2d", "matmul", "conv2d_transpose",
+    "depthwise_conv2d_transpose"
+]
...
@@ -33,7 +33,6 @@ from paddle.fluid.dygraph.container import Sequential
 from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
 from paddle.nn.layer import ReLU, LeakyReLU, Sigmoid, Softmax, PReLU
 from paddle.nn import Linear, Conv2D, Softmax, BatchNorm2D, MaxPool2D
-from paddle.fluid.dygraph.nn import Pool2D
 from paddle.fluid.log_helper import get_logger
 from paddle.fluid.dygraph import nn
@@ -131,8 +130,8 @@ class ImperativeLenet(fluid.dygraph.Layer):
                 bias_attr=False),
             BatchNorm2D(6),
             ReLU(),
-            Pool2D(
-                pool_size=2, pool_type='max', pool_stride=2),
+            MaxPool2D(
+                kernel_size=2, stride=2),
             Conv2D(
                 in_channels=6,
                 out_channels=16,
@@ -357,7 +356,6 @@ class TestImperativeOutSclae(unittest.TestCase):
                     "diff({}) at {}, dynamic loss = {}, static loss = {}".
                     format(diff, i, loss_d, loss_s))
                 break
-
         self.assertTrue(
             np.allclose(
                 np.array(dynamic_loss_rec),
@@ -398,10 +396,15 @@ class TestImperativeOutSclae(unittest.TestCase):
             if dynamic_ops[i].has_attr("out_threshold"):
                 op_count += 1
                 self.assertTrue(dynamic_ops[i].type == static_ops[i].type)
+                if dynamic_ops[i].attr("out_threshold") != static_ops[i].attr(
+                        "out_threshold"):
+                    _logger.info(dynamic_ops[i].attr("out_threshold"))
+                    _logger.info(static_ops[i].attr("out_threshold"))
                 self.assertTrue(dynamic_ops[i].attr("out_threshold") ==
                                 static_ops[i].attr("out_threshold"))

-        self.assertTrue(op_count == 13)
+        _logger.info("op_cout: {}".format(op_count))
+        self.assertTrue(op_count == 14)


 class TestSaveQuanztizedModelFromCheckPoint(unittest.TestCase):
@@ -470,7 +473,9 @@ class TestSaveQuanztizedModelFromCheckPoint(unittest.TestCase):
             self.assertTrue(dynamic_ops[i].type == static_ops[i].type)
             self.assertTrue(dynamic_ops[i].attr("out_threshold") ==
                             static_ops[i].attr("out_threshold"))
-        self.assertTrue(op_count == 13)
+
+        _logger.info("op_cout: {}".format(op_count))
+        self.assertTrue(op_count == 14)


 class TestSaveQuantizedModel_Warning(unittest.TestCase):
@@ -490,8 +495,10 @@ class TestSaveQuantizedModel_Warning(unittest.TestCase):
                 shape=[None, 1, 28, 28], dtype='float32')
         ])

-        warning_message = "Warning: No Layer of the model while to be saved contains the out_threshold attribute, " \
-                          "so the generated inference model would not contain the out_threshold."
+        warning_message = "Warning: No Layer of the model while to be " \
+            "saved contains the out_threshold attribute, so the " \
+            "generated inference model would not contain the " \
+            "out_threshold."
         num = get_vaild_warning_num(warning_message, w)
         assert num == 1
...