diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index f262ace3dc6527be3358b74e7fc2736817e3a302..6335c80e0839bc3bb33ebfa78c99a7037c6799bf 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -25,7 +25,9 @@ from ....log_helper import get_logger from .quantization_pass import QuantizationTransformPass from .quantization_pass import QuantizationFreezePass from .quantization_pass import AddQuantDequantPass -from .quantization_pass import _op_real_in_out_name +from .quantization_pass import _out_scale_op_list +from .quantization_pass import _get_op_input_var_names +from .quantization_pass import _get_op_output_var_names __all__ = ['PostTrainingQuantization', 'WeightQuantization'] @@ -68,14 +70,17 @@ class PostTrainingQuantization(object): model_dir=None, model_filename=None, params_filename=None, + batch_generator=None, sample_generator=None, batch_size=10, batch_nums=None, algo="KL", quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], is_full_quantize=False, - weight_bits=8, activation_bits=8, + weight_bits=8, + activation_quantize_type='range_abs_max', + weight_quantize_type='channel_wise_abs_max', is_use_cache_file=False, cache_dir="./temp_post_training"): ''' @@ -95,9 +100,14 @@ class PostTrainingQuantization(object): When all parameters were saved in a single binary file, set it as the real filename. If parameters were saved in separate files, set it as 'None'. Default is 'None'. - sample_generator(Python Generator): The sample generator provides - calibrate data for DataLoader, and it only returns a sample every - time. + batch_generator(Python Generator): The batch generator provides + calibrate data for DataLoader, and it returns a batch every + time. Note that only one of sample_generator and batch_generator + should be set. Besides, batch_generator supports lod tensor. + sample_generator(Python Generator): The sample generator provides + calibrate data for DataLoader, and it only returns a sample every + time. Note that only one of sample_generator and batch_generator + should be set. Besides, sample_generator does not support lod tensor. batch_size(int, optional): The batch size of DataLoader. Default is 10. batch_nums(int, optional): If batch_nums is not None, the number of calibrate data is batch_size*batch_nums. If batch_nums is None, use @@ -114,8 +124,19 @@ class PostTrainingQuantization(object): apply quantization to all supported quantizable op type. If set is_full_quantized as False, only apply quantization to the op type according to the input quantizable_op_type. - weight_bits(int, optional): quantization bit number for weights. activation_bits(int): quantization bit number for activation. + weight_bits(int, optional): quantization bit number for weights. + activation_quantize_type(str): quantization type for activation, + now support 'range_abs_max', 'moving_average_abs_max' and 'abs_max'. + This param only specifies the fake ops in saving quantized model. + If it is 'range_abs_max' or 'moving_average_abs_max', we save the scale + obtained by post training quantization in fake ops. Note that, if it + is 'abs_max', the scale will not be saved in fake ops. + weight_quantize_type(str): quantization type for weights, + support 'abs_max' and 'channel_wise_abs_max'.
This param only specifies + the fake ops in saving quantized model, and we save the scale obtained + by post training quantization in fake ops. Compared to 'abs_max', + the model accuracy is usually higher when it is 'channel_wise_abs_max'. is_use_cache_file(bool, optional): If set is_use_cache_file as False, all temp data will be saved in memory. If set is_use_cache_file as True, it will save temp data to disk. When the fp32 model is complex or @@ -163,46 +184,67 @@ class PostTrainingQuantization(object): ptq.save_quantized_model(save_model_path) ''' + self._support_activation_quantize_type = [ + 'range_abs_max', 'moving_average_abs_max', 'abs_max' + ] + self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max'] + self._support_algo_type = ['KL', 'abs_max', 'min_max'] + self._support_quantize_op_type = \ + list(set(QuantizationTransformPass._supported_quantizable_op_type + + AddQuantDequantPass._supported_quantizable_op_type)) + + # Check inputs assert executor is not None, "The executor cannot be None." assert model_dir is not None, "The model_dir cannot be None." - assert sample_generator is not None, \ - "The sample_generator cannot be None." - assert algo in ['KL', 'abs_max', 'min_max'], \ + assert any(gen is not None for gen in [sample_generator, + batch_generator]), "The sample_generator and batch_generator " \ + "cannot both be None at the same time." + assert batch_size > 0, "The batch_size should be greater than 0." + assert algo in self._support_algo_type, \ "The algo should be KL, abs_max or min_max." - + assert activation_quantize_type in self._support_activation_quantize_type, \ + "The activation_quantize_type ({}) should be in ({}).".format( + activation_quantize_type, self._support_activation_quantize_type) + assert weight_quantize_type in self._support_weight_quantize_type, \ + "The weight_quantize_type ({}) should be in ({}).".format( + weight_quantize_type, self._support_weight_quantize_type) + + # Save input params self._executor = executor self._scope = global_scope() if scope == None else scope self._model_dir = model_dir self._model_filename = model_filename self._params_filename = params_filename self._sample_generator = sample_generator + self._batch_generator = batch_generator self._batch_size = batch_size self._batch_nums = batch_nums self._algo = algo - self._is_use_cache_file = is_use_cache_file - self._cache_dir = cache_dir - if self._is_use_cache_file and not os.path.exists(self._cache_dir): - os.mkdir(self._cache_dir) - - supported_quantizable_op_type = \ - QuantizationTransformPass._supported_quantizable_op_type + \ - AddQuantDequantPass._supported_quantizable_op_type + self._activation_bits = activation_bits + self._weight_bits = weight_bits + self._activation_quantize_type = activation_quantize_type + self._weight_quantize_type = weight_quantize_type + self._is_full_quantize = is_full_quantize if is_full_quantize: - self._quantizable_op_type = supported_quantizable_op_type + self._quantizable_op_type = self._support_quantize_op_type else: self._quantizable_op_type = quantizable_op_type for op_type in self._quantizable_op_type: - assert op_type in supported_quantizable_op_type, \ + assert op_type in self._support_quantize_op_type, \ op_type + " is not supported for quantization."
+ self._is_use_cache_file = is_use_cache_file + self._cache_dir = cache_dir + if self._is_use_cache_file and not os.path.exists(self._cache_dir): + os.mkdir(self._cache_dir) + # Define variables self._place = self._executor.place self._program = None self._feed_list = None self._fetch_list = None self._data_loader = None - self._op_real_in_out_name = _op_real_in_out_name - self._bit_length = 8 + self._out_scale_op_list = _out_scale_op_list self._quantized_weight_var_name = set() self._quantized_act_var_name = set() self._sampling_data = {} @@ -223,7 +265,7 @@ class PostTrainingQuantization(object): the program of quantized model. ''' self._load_model_data() - self._collect_quantized_varnames() + self._collect_target_varnames() self._set_activation_persistable() batch_id = 0 @@ -257,17 +299,28 @@ class PostTrainingQuantization(object): self._save_output_threshold() return self._program - def save_quantized_model(self, save_model_path): + def save_quantized_model(self, + save_model_path, + model_filename=None, + params_filename=None): ''' Save the quantized model to the disk. Args: - save_model_path(str): The path to save the quantized model + save_model_path(str): The path to save the quantized model. + model_filename(str, optional): If the model_filename is None, + save the model to '__model__'. Otherwise, save the model + to the specified filename. Default: None. + params_filename(str, optional): If the params_filename is None, + save params to separate files. Otherwise, save all params + to the specified filename. Returns: None ''' io.save_inference_model( dirname=save_model_path, + model_filename=model_filename, + params_filename=params_filename, feeded_var_names=self._feed_list, target_vars=self._fetch_list, executor=self._executor, @@ -287,20 +340,31 @@ class PostTrainingQuantization(object): for var_name in self._feed_list] self._data_loader = io.DataLoader.from_generator( feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True) - self._data_loader.set_sample_generator( - self._sample_generator, - batch_size=self._batch_size, - drop_last=True, - places=self._place) - - def _collect_quantized_varnames(self): + if self._sample_generator is not None: + self._data_loader.set_sample_generator( + self._sample_generator, + batch_size=self._batch_size, + drop_last=True, + places=self._place) + elif self._batch_generator is not None: + self._data_loader.set_batch_generator( + self._batch_generator, places=self._place) + + def _collect_target_varnames(self): ''' Collect the variable names for sampling, and set activation variables to be persistable.
''' + # TODO(juncaipeng), consider the name_scope of skip_quant _logger.info("Collect quantized variable names ...") - # TODO(juncaipeng), consider the name_scope of skip_quant and - # reduce the variables for sampling + + def collect_var_name(var_name_list, persistable_var_names): + for var_name in var_name_list: + if var_name in persistable_var_names: + self._quantized_weight_var_name.add(var_name) + else: + self._quantized_act_var_name.add(var_name) + persistable_var_names = [] for var in self._program.list_vars(): if var.persistable: @@ -308,30 +372,22 @@ class PostTrainingQuantization(object): for op in self._program.global_block().ops: op_type = op.type + # For quantized ops, sample inputs and outputs if op_type in self._quantizable_op_type: - name_list = self._op_real_in_out_name[op_type] - for input_name in name_list[0]: - for var_name in op.input(input_name): - if var_name in persistable_var_names: - self._quantized_weight_var_name.add(var_name) - else: - self._quantized_act_var_name.add(var_name) - for output_name in name_list[1]: - for var_name in op.output(output_name): - if var_name in persistable_var_names: - self._quantized_weight_var_name.add(var_name) - else: - self._quantized_act_var_name.add(var_name) + collect_var_name( + _get_op_input_var_names(op), persistable_var_names) + collect_var_name( + _get_op_output_var_names(op), persistable_var_names) + # For other op, only sample output scale + elif op_type in self._out_scale_op_list: + collect_var_name( + _get_op_output_var_names(op), persistable_var_names) def _set_activation_persistable(self): ''' Set activation variables to be persistable, so can obtain the tensor data in sample_data ''' - persistable_var_names = [] - for var in self._program.list_vars(): - if var.persistable: - persistable_var_names.append(var.name) for var in self._program.list_vars(): if var.name in self._quantized_act_var_name: var.persistable = True @@ -350,6 +406,7 @@ class PostTrainingQuantization(object): ''' assert self._algo in ["abs_max", "min_max"], \ "The algo should be abs_max or min_max to sample min max value." + if self._algo == "abs_max": # Only calculate abs_max value for weight for once if self._quantized_var_abs_max == {}: @@ -396,15 +453,13 @@ class PostTrainingQuantization(object): "The algo should be min_max to save input threshold." for op in self._program.global_block().ops: if op.type in self._quantizable_op_type: - input_name_list = self._op_real_in_out_name[op.type][0] - for input_name in input_name_list: - for var_name in op.input(input_name): - assert var_name in self._quantized_var_min - assert var_name in self._quantized_var_max - op._set_attr(var_name + ".min", - self._quantized_var_min[var_name]) - op._set_attr(var_name + ".max", - self._quantized_var_max[var_name]) + for var_name in _get_op_input_var_names(op): + assert var_name in self._quantized_var_min + assert var_name in self._quantized_var_max + op._set_attr(var_name + ".min", + self._quantized_var_min[var_name]) + op._set_attr(var_name + ".max", + self._quantized_var_max[var_name]) def _sample_data(self, iter): ''' @@ -438,16 +493,21 @@ class PostTrainingQuantization(object): ''' _logger.info("Calculate KL threshold ...") assert self._algo == "KL", "The algo should be KL to calculate kl threshold." 
- # apply channel_wise_abs_max quantization for weights + + # Abs_max threshold for weights for var_name in self._quantized_weight_var_name: - data = self._sampling_data[var_name] - threshold_per_channel = [] - for i in range(data.shape[0]): - abs_max_value = np.max(np.abs(data[i])) - threshold_per_channel.append(abs_max_value) - self._quantized_var_kl_threshold[var_name] = threshold_per_channel - - # apply kl quantization for activation + weight_data = self._sampling_data[var_name] + weight_threshold = None + if self._weight_quantize_type == "abs_max": + weight_threshold = np.max(np.abs(weight_data)) + elif self._weight_quantize_type == "channel_wise_abs_max": + weight_threshold = [] + for i in range(weight_data.shape[0]): + abs_max_value = np.max(np.abs(weight_data[i])) + weight_threshold.append(abs_max_value) + self._quantized_var_kl_threshold[var_name] = weight_threshold + + # KL threshold for activations if self._is_use_cache_file: for var_name in self._quantized_act_var_name: sampling_data = [] @@ -484,10 +544,10 @@ class PostTrainingQuantization(object): transform_pass = QuantizationTransformPass( scope=self._scope, place=self._place, - weight_bits=self._bit_length, - activation_bits=self._bit_length, - activation_quantize_type='moving_average_abs_max', - weight_quantize_type='channel_wise_abs_max', + weight_bits=self._weight_bits, + activation_bits=self._activation_bits, + activation_quantize_type=self._activation_quantize_type, + weight_quantize_type=self._weight_quantize_type, quantizable_op_type=major_quantizable_op_types) transform_pass.apply(graph) @@ -525,9 +585,9 @@ class PostTrainingQuantization(object): freeze_pass = QuantizationFreezePass( scope=self._scope, place=self._place, - weight_bits=self._bit_length, - activation_bits=self._bit_length, - weight_quantize_type='channel_wise_abs_max', + weight_bits=self._weight_bits, + activation_bits=self._activation_bits, + weight_quantize_type=self._weight_quantize_type, quantizable_op_type=major_quantizable_op_types) freeze_pass.apply(graph) self._program = graph.to_program() @@ -536,30 +596,37 @@ class PostTrainingQuantization(object): ''' Save output threshold to the quantized op. 
''' + + def save_info(op_node, out_var_name, threshold_map, out_info_name, + quantized_type): + assert out_var_name in threshold_map, \ + "The output ({}) of {} node does not have threshold.".format( + out_var_name, op_node.type) + op_node._set_attr(out_info_name, threshold_map[out_var_name]) + if op_node.type in self._quantizable_op_type: + op_node._set_attr("quantization_type", quantized_type) + + def analysis_and_save_info(op_node, out_var_name): + if self._algo == "KL": + save_info(op_node, out_var_name, + self._quantized_var_kl_threshold, "out_threshold", + "post_kl") + elif self._algo == "abs_max": + save_info(op_node, out_var_name, self._quantized_var_abs_max, + "out_threshold", "post_abs_max") + elif self._algo == "min_max": + save_info(op_node, out_var_name, self._quantized_var_min, + "out_min", "post_min_max") + save_info(op_node, out_var_name, self._quantized_var_max, + "out_max", "post_min_max") + for op in self._program.global_block().ops: - if op.type in self._quantizable_op_type: - output_name_list = self._op_real_in_out_name[op.type][1] - for output_name in output_name_list: - for var_name in op.output(output_name): - if self._algo == "KL": - assert var_name in self._quantized_var_kl_threshold - op._set_attr( - var_name + ".threshold", - self._quantized_var_kl_threshold[var_name]) - op._set_attr("quantization_type", "post_kl") - elif self._algo == "abs_max": - assert var_name in self._quantized_var_abs_max - op._set_attr(var_name + ".threshold", - self._quantized_var_abs_max[var_name]) - op._set_attr("quantization_type", "post_abs_max") - elif self._algo == "min_max": - assert var_name in self._quantized_var_min - assert var_name in self._quantized_var_max - op._set_attr(var_name + ".min", - self._quantized_var_min[var_name]) - op._set_attr(var_name + ".max", - self._quantized_var_max[var_name]) - op._set_attr("quantization_type", "post_min_max") + if op.type in (self._quantizable_op_type + self._out_scale_op_list): + out_var_names = _get_op_output_var_names(op) + assert len(out_var_names) == 1, "Post training " + \ + "quantization only supports one output for " + op.type + for var_name in out_var_names: + analysis_and_save_info(op, var_name) def _get_kl_scaling_factor(self, activation_blob, num_quantized_bins=255): ''' diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 201ffc35dd397c9fc883f1ff258bac00ccc38cbf..cde41e687fa90f56e715fb34ba68eec13a3b9369 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -18,12 +18,13 @@ from ..... import compat as cpt from .... import core from ....framework import IrGraph from ....framework import IrNode +from ....framework import Operator from ....
import unique_name __all__ = [ 'QuantizationTransformPass', 'QuantizationFreezePass', 'ConvertToInt8Pass', - 'TransformForMobilePass', 'ScaleForTrainingPass', 'ScaleForInferencePass', - 'AddQuantDequantPass' + 'TransformForMobilePass', 'OutScaleForTrainingPass', + 'OutScaleForInferencePass', 'AddQuantDequantPass' ] _fake_quant_op_list = [ @@ -40,9 +41,9 @@ _fake_quant_dequant_op_list = [ ] _out_scale_op_list = [ - "mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", "depthwise_conv2d", - "batch_norm", "concat", "tanh", "pad", "elementwise_add", "elementwise_mul", - "dropout", "split", "prelu", "conv2d_transpose", "leaky_relu" + "conv2d", "depthwise_conv2d", "mul", "matmul", "relu", "leaky_relu", + "relu6", "sigmoid", "tanh", "prelu", "swish", "softmax", "batch_norm", + "elementwise_add", "pool2d", "reshape2", "transpose2" ] # list op real input and output names, to avoid processing input such as AxisTensor. @@ -67,6 +68,7 @@ _op_real_in_out_name = { "not_equal": [["X", "Y"], ["Out"]], "reshape": [["X"], ["Out"]], "reshape2": [["X"], ["Out"]], + "transpose2": [["X"], ["Out"]], "bilinear_interp": [["X"], ["Out"]], "nearest_interp": [["X"], ["Out"]], "trilinear_interp": [["X"], ["Out"]], @@ -76,11 +78,49 @@ _op_real_in_out_name = { "relu": [["X"], ["Out"]], "relu6": [["X"], ["Out"]], "leaky_relu": [["X"], ["Out"]], + "prelu": [["X"], ["Out"]], "tanh": [["X"], ["Out"]], "swish": [["X"], ["Out"]], + "dropout": [["X"], ["Out"]], + "batch_norm": [["X"], ["Y"]], + "sigmoid": [["X"], ["Y"]], } +def _get_op_input_var_names(op): + """ """ + assert isinstance(op, (IrNode, Operator)), \ + "The input op should be IrNode or Operator." + var_names = [] + op_name = op.name() if isinstance(op, IrNode) \ + else op.type + name_list = _op_real_in_out_name[op_name][0] + for name in name_list: + var_name = op.input(name) + if isinstance(var_name, list): + var_names.extend(var_name) + else: + var_names.append(var_name) + return var_names + + +def _get_op_output_var_names(op): + """ """ + assert isinstance(op, (IrNode, Operator)), \ + "The input op should be IrNode or Operator." + var_names = [] + op_name = op.name() if isinstance(op, IrNode) \ + else op.type + name_list = _op_real_in_out_name[op_name][1] + for name in name_list: + var_name = op.output(name) + if isinstance(var_name, list): + var_names.extend(var_name) + else: + var_names.append(var_name) + return var_names + + def _init_var_node(var_node, value, scope, place): assert isinstance(value, np.ndarray), 'The type of value should be numpy array.' @@ -97,17 +137,18 @@ def _is_input_all_not_persistable(graph, op_node): Analyse the real inputs of the op node are all not persistable. ''' is_input_all_not_persistable = True - op_node_name = op_node.name() - input_name_list = _op_real_in_out_name[op_node_name][0] - for input_name in input_name_list: - for arg_name in op_node.input(input_name): - in_node = graph._find_node_by_name(op_node.inputs, arg_name) - is_input_all_not_persistable = (is_input_all_not_persistable and \ - (not in_node.persistable())) + for var_name in _get_op_input_var_names(op_node): + in_node = graph._find_node_by_name(op_node.inputs, var_name) + is_input_all_not_persistable = (is_input_all_not_persistable and \ + (not in_node.persistable())) return is_input_all_not_persistable class QuantizationTransformPass(object): + """ + Quantize the ops that have weights. Add quant and dequant ops for the quantized + ops's inputs. 
+ """ _supported_quantizable_op_type = [ 'conv2d', 'depthwise_conv2d', 'mul', 'matmul' ] @@ -124,8 +165,7 @@ class QuantizationTransformPass(object): skip_pattern=['skip_quant'], quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']): """ - Convert and rewrite the IrGraph according to weight and - activation quantization type. + Constructor. Args: scope(fluid.Scope): When activation use 'range_abs_max' as the quantize @@ -1088,7 +1128,7 @@ class TransformForMobilePass(object): return graph -class ScaleForTrainingPass(object): +class OutScaleForTrainingPass(object): def __init__(self, scope=None, place=None, moving_rate=0.9): """ This pass is used for calculating output scales of some operators. @@ -1195,7 +1235,7 @@ class ScaleForTrainingPass(object): return "%s@scale" % (var_name) -class ScaleForInferencePass(object): +class OutScaleForInferencePass(object): def __init__(self, scope=None): """ This pass is used for setting output scales of some operators. @@ -1226,7 +1266,7 @@ class ScaleForInferencePass(object): scale_name = self._scale_name(op_node.output_arg_names()[0]) scale_v = np.array( self._scope.find_var(scale_name).get_tensor())[0] - op_node.op()._set_attr("out_scale", float(scale_v)) + op_node.op()._set_attr("out_threshold", float(scale_v)) graph.resolve_hazard() return graph @@ -1238,6 +1278,10 @@ class ScaleForInferencePass(object): class AddQuantDequantPass(object): + """ + Quantize the ops that do not have weights, and add quant_dequant op for the + quantized ops's inputs. + """ _supported_quantizable_op_type = [ "pool2d", "elementwise_add", "concat", "softmax", "argmax", "transpose", "equal", "gather", "greater_equal", "greater_than", "less_equal", @@ -1259,9 +1303,7 @@ class AddQuantDequantPass(object): quantizable_op_type=["elementwise_add", "pool2d"], is_full_quantized=False): """ - This pass add quant_dequant op for some ops, of which all the inputs must be - not persistable. - The input scales can be obtained from the quant_dequant op. + Constructor. Args: scope(fluid.Scope): The scope is used to initialize these new parameters. 
@@ -1338,10 +1380,7 @@ class AddQuantDequantPass(object): op_node.op()._set_attr("quantization_type", "qat_without_weight") op_node.op()._set_attr("activation_bits", self._quant_bits) - input_name_list = _op_real_in_out_name[op_node.name()][0] - arg_names = [] - for input_name in input_name_list: - arg_names.extend(op_node.input(input_name)) + arg_names = _get_op_input_var_names(op_node) for arg_name in arg_names: in_node = graph._find_node_by_name(op_node.inputs, arg_name) if arg_name in dequantized_vars_map: diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt index 4ac4db6b4aca2f26743ec6db3528fb9978032463..58a3827ce6e2f9ee8de5dd25dc0ba4510f6cf8a4 100644 --- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt @@ -107,15 +107,14 @@ function(save_qat_model_test target qat_model_dir fp32_model_save_path int8_mode --quantized_ops ${quantized_ops}) endfunction() -# Disable the unittest temporary -list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1) if(WIN32) - list(REMOVE_ITEM TEST_OPS test_light_nas) + list(REMOVE_ITEM TEST_OPS test_light_nas) + list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1) list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50) list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1) endif() -# Disable unittest for random error +# Disable unittest for random error temporary list(REMOVE_ITEM TEST_OPS test_quantization_scale_pass) if(LINUX AND WITH_MKLDNN) diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py index 45140aec4e5f6159a16b52a22bda3e79dd3e3c60..50085ed4a5b7aff66a9581e0d7f2415c9b46f631 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py @@ -140,9 +140,9 @@ class TestPostTrainingQuantization(unittest.TestCase): self.batch_size = 1 if os.environ.get('DATASET') == 'full' else 50 self.sample_iterations = 50 if os.environ.get( - 'DATASET') == 'full' else 1 + 'DATASET') == 'full' else 2 self.infer_iterations = 50000 if os.environ.get( - 'DATASET') == 'full' else 1 + 'DATASET') == 'full' else 2 self.timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) self.int8_model = os.path.join(os.getcwd(), @@ -287,11 +287,12 @@ class TestPostTrainingQuantization(unittest.TestCase): (int8_throughput, int8_latency, int8_acc1) = self.run_program( self.int8_model, batch_size, infer_iterations) + print("---Post training quantization of {} method---".format(algo)) print( - "FP32 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}". + "FP32 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}.". format(model, batch_size, fp32_throughput, fp32_latency, fp32_acc1)) print( - "INT8 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}". + "INT8 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}.\n". 
format(model, batch_size, int8_throughput, int8_latency, int8_acc1)) sys.stdout.flush() @@ -308,7 +309,10 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization): ] data_md5s = ['13892b0716d26443a8cdea15b3c6438b'] quantizable_op_type = [ - "conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add" + "conv2d", + "depthwise_conv2d", + "mul", + "pool2d", ] is_full_quantize = False is_use_cache_file = False @@ -326,10 +330,12 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization): ] data_md5s = ['13892b0716d26443a8cdea15b3c6438b'] quantizable_op_type = [ - "conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add" + "conv2d", + "mul", ] is_full_quantize = False is_use_cache_file = False + # The accuracy diff of post-training quantization (abs_max) may be bigger diff_threshold = 0.05 self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type, is_full_quantize, is_use_cache_file, diff_threshold) diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantization_scale_pass.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_scale_pass.py index 8aa461c438e904e440abc2626de55dd0abd582d5..8a502bc9378724a22c678807bc3df9c2f4b016ec 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_quantization_scale_pass.py +++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_scale_pass.py @@ -22,8 +22,8 @@ import paddle from paddle.fluid.framework import IrGraph from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass -from paddle.fluid.contrib.slim.quantization import ScaleForTrainingPass -from paddle.fluid.contrib.slim.quantization import ScaleForInferencePass +from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass +from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass from paddle.fluid import core @@ -112,7 +112,7 @@ class TestQuantizationScalePass(unittest.TestCase): add_quant_dequant_pass.apply(main_graph) add_quant_dequant_pass.apply(test_graph) - scale_training_pass = ScaleForTrainingPass(scope=scope, place=place) + scale_training_pass = OutScaleForTrainingPass(scope=scope, place=place) scale_training_pass.apply(main_graph) dev_name = '_gpu' if use_cuda else '_cpu' @@ -151,7 +151,7 @@ class TestQuantizationScalePass(unittest.TestCase): if not for_ci: print('{}: {}'.format('loss' + dev_name, loss_v)) - scale_inference_pass = ScaleForInferencePass(scope=scope) + scale_inference_pass = OutScaleForInferencePass(scope=scope) scale_inference_pass.apply(test_graph) # Freeze graph for inference, but the weight of fc/conv is still float type.
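Usage sketch (reviewer note, not part of the patch): the snippet below shows how the new batch_generator argument, the activation_quantize_type/weight_quantize_type options, and the extended save_quantized_model signature fit together. The model paths, the single 3x224x224 image input, and the random calibration batches are placeholders for illustration only; a real calibration reader for the target model should be used instead.

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization


def calib_batch_generator():
    # Placeholder calibration data: 10 batches of 4 random images for a model
    # with a single image input; replace with a real calibration reader.
    for _ in range(10):
        yield [np.random.random((4, 3, 224, 224)).astype("float32")]


exe = fluid.Executor(fluid.CPUPlace())
ptq = PostTrainingQuantization(
    executor=exe,
    model_dir="./fp32_inference_model",      # placeholder path
    batch_generator=calib_batch_generator,   # set either this or sample_generator, not both
    batch_nums=10,
    algo="KL",
    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
    activation_quantize_type='range_abs_max',
    weight_quantize_type='channel_wise_abs_max')
ptq.quantize()
ptq.save_quantized_model(
    "./int8_inference_model",                # placeholder path
    model_filename="model",                  # new optional arguments
    params_filename="params")
```

When sample_generator is used instead, the DataLoader applies batch_size itself; with batch_generator, each yielded batch (including LoD tensors) is fed as-is, which is why the generator above produces already-batched arrays.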