diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
index 49cb4ea4311eea98f5a97e2e74c75f0fc751f591..7c48e29ebc94b80ecaf725844805db6baf48511f 100644
--- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
+++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
@@ -11,12 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import math
+
 import os
 import re
+import math
+import shutil
 import logging
 import numpy as np
-import shutil
+
 try:
     from tqdm import tqdm
 except:
@@ -34,7 +36,10 @@ from .cal_kl_threshold import cal_kl_threshold
 from .adaround import run_adaround
 from . import utils
 
-__all__ = ['PostTrainingQuantization', 'WeightQuantization']
+__all__ = [
+    'PostTrainingQuantization', 'WeightQuantization',
+    'PostTrainingQuantizationProgram'
+]
 
 _logger = get_logger(__name__,
                      logging.INFO,
@@ -108,9 +113,9 @@ class PostTrainingQuantization(object):
     """
 
     def __init__(self,
-                 executor=None,
+                 executor,
+                 model_dir,
                  scope=None,
-                 model_dir=None,
                  model_filename=None,
                  params_filename=None,
                  batch_generator=None,
@@ -130,10 +135,15 @@
                  activation_quantize_type='range_abs_max',
                  weight_quantize_type='channel_wise_abs_max',
                  onnx_format=False,
+                 freeze_model=True,
                  optimize_model=False,
                  is_use_cache_file=False,
                  skip_tensor_list=None,
-                 cache_dir=None):
+                 same_scale_tensor_list=None,
+                 scale_trainable=False,
+                 cache_dir=None,
+                 scale_dict=None,
+                 return_graph=False):
         '''
         Constructor.
 
@@ -206,7 +216,12 @@
                 the model accuracy is usually higher when it is 'channel_wise_abs_max'.
             onnx_format(bool): Whether to export the quantized model with format of ONNX.
                 Default is False.
-            skip_tensor_list(list): List of skip quant tensor name.
+            freeze_model(bool): Whether to convert the quantized and trained ``program``
+                to the final quantized ``program``. Default: True.
+            skip_tensor_list(list): List of tensor names to skip quantization. Default: None.
+            same_scale_tensor_list(list(list)): Nested list of tensor names; the tensors in
+                each inner list are forced to share one quantization scale, which is set
+                to the maximum of their individual scales. Default: None.
             optimize_model(bool, optional): If set to True, it applies some
                 passes to the model before quantization, and it supports
                 `conv2d/depthwise_conv2d + bn` pass so far. Some targets require the
@@ -215,6 +230,7 @@
                 `conv2d/depthwise_conv2d + bn`, the weight scale for all channels will
                 be different. To address this problem, fuse the pattern before
                 quantization. Default: False.
+            scale_trainable(bool, optional): Whether the quantization scales are trainable. Default: False.
            is_use_cache_file(bool, optional): This param is deprecated.
            cache_dir(str, optional): This param is deprecated.
        Returns:
@@ -275,7 +291,6 @@
 
         # Check inputs
         assert executor is not None, "The executor cannot be None."
-        assert model_dir is not None, "The model_dir cannot be None."
         assert any(gen is not None for gen in
                    [sample_generator, batch_generator, data_loader]), \
             "The sample_generator, batch_generator and data_loader cannot be None at the same time."
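For reference, a minimal sketch of how the new constructor arguments compose, following the same_scale_tensor_list handling implemented in _update_program() below; the model path and the calibration reader are hypothetical:

    import paddle
    import paddle.fluid as fluid
    from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

    paddle.enable_static()
    exe = fluid.Executor(fluid.CPUPlace())
    ptq = PostTrainingQuantization(
        executor=exe,
        model_dir='./fp32_model',        # hypothetical model directory
        sample_generator=calib_reader,   # hypothetical calibration reader
        batch_nums=10,
        algo='abs_max',
        same_scale_tensor_list=[
            # Tensors in one inner list end up sharing the max of their scales.
            ['concat_0.tmp_0', 'concat_1.tmp_0'],
            # A 'name#op#scalar' entry rescales a tensor's contribution: its
            # scale is multiplied ('*') or divided ('/') by scalar before the
            # max is taken, and the inverse is applied when the shared scale
            # is written back.
            ['split_0.tmp_0#*#2', 'split_0.tmp_1'],
        ])
    ptq.quantize()
    ptq.save_quantized_model('./int8_model')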
@@ -347,6 +362,11 @@ class PostTrainingQuantization(object): self._best_calibration_loss = {} # The threshold for algo = abs_max, mse or avg self._quantized_threshold = {} + self._same_scale_tensor_list = same_scale_tensor_list + self._freeze_model = freeze_model + self._scale_trainable = scale_trainable + self._scale_dict = scale_dict + self._return_graph = return_graph def quantize(self): ''' @@ -441,7 +461,11 @@ class PostTrainingQuantization(object): persistables.extend(_op.input('X')) _op.desc.set_input("X", persistables) - return self._program + if not self._return_graph: + return self._program + else: + main_graph = IrGraph(core.Graph(self._program.desc), for_test=True) + return main_graph def _adaround_apply(self): assert self._algo != "min_max", "The algo should not be min_max." @@ -495,12 +519,13 @@ class PostTrainingQuantization(object): ''' Load model and set data loader. ''' - _logger.info("Load model and set data loader ...") - [self._program, self._feed_list, self._fetch_list] = \ - io.load_inference_model(dirname=self._model_dir, - executor=self._executor, - model_filename=self._model_filename, - params_filename=self._params_filename) + if self._program is None: + _logger.info("Load model and set data loader ...") + [self._program, self._feed_list, self._fetch_list] = \ + io.load_inference_model(dirname=self._model_dir, + executor=self._executor, + model_filename=self._model_filename, + params_filename=self._params_filename) if self._optimize_model: self._optimize_fp32_model() @@ -972,7 +997,8 @@ class PostTrainingQuantization(object): activation_bits=self._activation_bits, activation_quantize_type=self._activation_quantize_type, weight_quantize_type=self._weight_quantize_type, - quantizable_op_type=major_quantizable_op_types) + quantizable_op_type=major_quantizable_op_types, + is_test=not self._scale_trainable) else: transform_pass = QuantizationTransformPassV2( scope=self._scope, @@ -981,7 +1007,8 @@ class PostTrainingQuantization(object): activation_bits=self._activation_bits, activation_quantize_type=self._activation_quantize_type, weight_quantize_type=self._weight_quantize_type, - quantizable_op_type=major_quantizable_op_types) + quantizable_op_type=major_quantizable_op_types, + is_test=not self._scale_trainable) for sub_graph in graph.all_sub_graphs(): # Insert fake_quant/fake_dequantize op must in test graph, so @@ -998,24 +1025,68 @@ class PostTrainingQuantization(object): add_quant_dequant_pass = AddQuantDequantPass( scope=self._scope, place=self._place, - quantizable_op_type=minor_quantizable_op_types) + quantizable_op_type=minor_quantizable_op_types, + is_test=not self._scale_trainable) else: add_quant_dequant_pass = AddQuantDequantPassV2( scope=self._scope, place=self._place, quantizable_op_type=minor_quantizable_op_types, - is_full_quantized=self._is_full_quantize) + is_full_quantized=self._is_full_quantize, + is_test=not self._scale_trainable) for sub_graph in graph.all_sub_graphs(): sub_graph._for_test = True add_quant_dequant_pass.apply(sub_graph) # save threshold to scale var node - if self._algo in ["KL", "hist"]: - scale_dict = self._quantized_var_threshold - else: - scale_dict = self._quantized_threshold - for key, val in scale_dict.items(): + if self._scale_dict is None: + if self._algo in ["KL", "hist"]: + scale_dict = self._quantized_var_threshold + else: + scale_dict = self._quantized_threshold + + if self._same_scale_tensor_list is not None: + for tensor_list in self._same_scale_tensor_list: + max_scale = None + tmp_tensor_list = [] + for 
tensor_name in tensor_list:
+                        if '#' in tensor_name:
+                            real_tensor_name, opera, scalar = tensor_name.split(
+                                '#')
+                            if opera == '*':
+                                scale_dict[real_tensor_name] = float(
+                                    scale_dict[real_tensor_name]) * float(
+                                        scalar)
+                            elif opera == '/':
+                                scale_dict[real_tensor_name] = float(
+                                    scale_dict[real_tensor_name]) / float(
+                                        scalar)
+                            max_scale = scale_dict[
+                                real_tensor_name] if max_scale is None else max(
+                                    max_scale, scale_dict[real_tensor_name])
+                        else:
+                            max_scale = scale_dict[
+                                tensor_name] if max_scale is None else max(
+                                    max_scale, scale_dict[tensor_name])
+
+                    for tensor_name in tensor_list:
+                        if '#' in tensor_name:
+                            real_tensor_name, opera, scalar = tensor_name.split(
+                                '#')
+                            if opera == '*':
+                                scale_dict[
+                                    real_tensor_name] = max_scale / float(
+                                        scalar)
+                            elif opera == '/':
+                                scale_dict[
+                                    real_tensor_name] = max_scale * float(
+                                        scalar)
+                        else:
+                            scale_dict[tensor_name] = max_scale
+            self._scale_dict = scale_dict
+
+        for key, val in self._scale_dict.items():
             utils.set_variable_data(self._scope, self._place, key + "@scale",
                                     np.array([val], dtype=np.float32))
             utils.set_variable_data(self._scope, self._place,
@@ -1024,19 +1095,20 @@
 
         if not self._onnx_format:
             # apply QuantizationFreezePass, and obtain the final quant model
-            freeze_pass = QuantizationFreezePass(
-                scope=self._scope,
-                place=self._place,
-                bias_correction=self._bias_correction,
-                weight_bits=self._weight_bits,
-                round_type=self._round_type,
-                activation_bits=self._activation_bits,
-                weight_quantize_type=self._weight_quantize_type,
-                quantizable_op_type=major_quantizable_op_types)
-
-            for sub_graph in graph.all_sub_graphs():
-                sub_graph._for_test = True
-                freeze_pass.apply(sub_graph)
+            if self._freeze_model:
+                freeze_pass = QuantizationFreezePass(
+                    scope=self._scope,
+                    place=self._place,
+                    bias_correction=self._bias_correction,
+                    weight_bits=self._weight_bits,
+                    round_type=self._round_type,
+                    activation_bits=self._activation_bits,
+                    weight_quantize_type=self._weight_quantize_type,
+                    quantizable_op_type=major_quantizable_op_types)
+
+                for sub_graph in graph.all_sub_graphs():
+                    sub_graph._for_test = True
+                    freeze_pass.apply(sub_graph)
         else:
             quant_weight_pass = QuantWeightPass(self._scope, self._place)
             for sub_graph in graph.all_sub_graphs():
@@ -1155,6 +1227,58 @@
         return (hist_index - 0.5) * bin_width
 
 
+class PostTrainingQuantizationProgram(PostTrainingQuantization):
+
+    def __init__(self,
+                 executor,
+                 program,
+                 feed_list=None,
+                 fetch_list=None,
+                 scope=None,
+                 batch_generator=None,
+                 sample_generator=None,
+                 data_loader=None,
+                 batch_size=10,
+                 batch_nums=None,
+                 algo="KL",
+                 hist_percent=0.99999,
+                 quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
+                 round_type='round',
+                 learning_rate=0.001,
+                 is_full_quantize=False,
+                 bias_correction=False,
+                 activation_bits=8,
+                 weight_bits=8,
+                 activation_quantize_type='range_abs_max',
+                 weight_quantize_type='channel_wise_abs_max',
+                 onnx_format=False,
+                 freeze_model=True,
+                 optimize_model=False,
+                 is_use_cache_file=False,
+                 skip_tensor_list=None,
+                 same_scale_tensor_list=None,
+                 scale_trainable=False,
+                 cache_dir=None,
+                 scale_dict=None,
+                 return_graph=True):
+        # model_dir is None and scope is forwarded to its own slot; the
+        # Program is supplied directly, so nothing is loaded from disk.
+        super().__init__(executor, None, scope, None, None, batch_generator,
+                         sample_generator, data_loader, batch_size, batch_nums,
+                         algo, hist_percent, quantizable_op_type, round_type,
+                         learning_rate, is_full_quantize, bias_correction,
+                         activation_bits, weight_bits, activation_quantize_type,
+                         weight_quantize_type, onnx_format, 
freeze_model, + optimize_model, is_use_cache_file, skip_tensor_list, + same_scale_tensor_list, scale_trainable, cache_dir, + scale_dict, return_graph) + self._program = program + assert feed_list is not None, \ + "Feed list should not be None." + assert fetch_list is not None, \ + "Fetch list should not be None." + self._feed_list = feed_list + self._fetch_list = fetch_list + + class WeightQuantization(object): _supported_quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul'] _supported_weight_quantize_type = ['channel_wise_abs_max', 'abs_max'] diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 8213b779a6a46b611059b5a9ed2ebe011af03f46..6fdd84e3491ca6e055156c1d3baf9484c964e554 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -124,7 +124,8 @@ class QuantizationTransformPass(object): weight_preprocess_func=None, act_preprocess_func=None, optimizer_func=None, - executor=None): + executor=None, + is_test=None): r""" Constructor. @@ -241,7 +242,7 @@ class QuantizationTransformPass(object): self._quantizable_grad_ops = [ '%s_grad' % (op) for op in self._quantizable_ops ] - self._is_test = None + self._is_test = is_test self._global_step = None self.create_var_map = {} @@ -260,7 +261,8 @@ class QuantizationTransformPass(object): """ assert isinstance(graph, IrGraph), 'graph must be the instance of IrGraph.' - self._is_test = graph.is_test() + if self._is_test is None: + self._is_test = graph.is_test() # marked the variable which has been dequantized. dequantized_vars = collections.OrderedDict() persistable_vars = [p.name() for p in graph.all_persistable_nodes()] @@ -449,16 +451,21 @@ class QuantizationTransformPass(object): var_type=var_node.type(), shape=var_node.shape(), var_dtype=var_node.dtype()) + scale_name = self._quantized_scale_name(name) + data_type = 'float64' if var_node.dtype( + ) == core.VarDesc.VarType.FP64 else 'float32' + try: + scale_value = np.array( + self._scope.find_var(scale_name).get_tensor()) + except: + scale_value = np.zeros([1], dtype=data_type) scale_var_node = graph.create_persistable_node( - name=self._quantized_scale_name(name), + name=scale_name, var_type=var_node.type(), shape=[1], var_dtype=var_node.dtype()) - data_type = 'float64' if var_node.dtype( - ) == core.VarDesc.VarType.FP64 else 'float32' - _init_var_node(scale_var_node, - np.zeros(scale_var_node.shape(), dtype=data_type), - self._scope, self._place) + _init_var_node(scale_var_node, scale_value, self._scope, self._place) + quant_op_node = graph.create_op_node( op_type='fake_quantize_abs_max', attrs={ @@ -487,16 +494,20 @@ class QuantizationTransformPass(object): shape=var_node.shape(), var_dtype=var_node.dtype()) + scale_name = self._quantized_scale_name(name) + data_type = 'float64' if var_node.dtype( + ) == core.VarDesc.VarType.FP64 else 'float32' + try: + scale_value = np.array( + self._scope.find_var(scale_name).get_tensor()) + except: + scale_value = np.array([_SCALE_DEFAULT_VALUE], dtype=data_type) scale_in_node = graph.create_persistable_node( - name=self._quantized_scale_name(name), + name=scale_name, var_type=core.VarDesc.VarType.LOD_TENSOR, shape=[1], var_dtype=var_node.dtype()) - data_type = 'float64' if var_node.dtype( - ) == core.VarDesc.VarType.FP64 else 'float32' - _init_var_node(scale_in_node, - np.array([_SCALE_DEFAULT_VALUE], dtype=data_type), - self._scope, self._place) + 
_init_var_node(scale_in_node, scale_value, self._scope, self._place) scale_out_node = graph.create_var_node_from_desc(scale_in_node.var()) inputs = {'X': var_node, 'InScale': scale_in_node} @@ -549,16 +560,20 @@ class QuantizationTransformPass(object): var_type=var_node.type(), shape=var_node.shape(), var_dtype=var_node.dtype()) + scale_name = self._quantized_scale_name(name) + data_type = 'float64' if var_node.dtype( + ) == core.VarDesc.VarType.FP64 else 'float32' + try: + scale_value = np.array( + self._scope.find_var(scale_name).get_tensor()) + except: + scale_value = np.array([_SCALE_DEFAULT_VALUE], dtype=data_type) scale_in_node = graph.create_persistable_node( - name=self._quantized_scale_name(name), + name=scale_name, var_type=core.VarDesc.VarType.LOD_TENSOR, shape=[1], var_dtype=var_node.dtype()) - data_type = 'float64' if var_node.dtype( - ) == core.VarDesc.VarType.FP64 else 'float32' - _init_var_node(scale_in_node, - np.array([_SCALE_DEFAULT_VALUE], dtype=data_type), - self._scope, self._place) + _init_var_node(scale_in_node, scale_value, self._scope, self._place) scale_out_node = graph.create_var_node_from_desc(scale_in_node.var()) ins = {'X': var_node, 'InScale': scale_in_node} @@ -628,16 +643,21 @@ class QuantizationTransformPass(object): var_type=var_node.type(), shape=var_node.shape(), var_dtype=var_node.dtype()) + scale_name = self._quantized_scale_name(name) + data_type = 'float64' if var_node.dtype( + ) == core.VarDesc.VarType.FP64 else 'float32' + try: + scale_value = np.array( + self._scope.find_var(scale_name).get_tensor()) + except: + scale_value = np.zeros([var_node.shape()[quant_axis]], + dtype=data_type) scale_var_node = graph.create_persistable_node( name=self._quantized_scale_name(name), var_type=var_node.type(), shape=[var_node.shape()[quant_axis]], var_dtype=var_node.dtype()) - data_type = 'float64' if var_node.dtype( - ) == core.VarDesc.VarType.FP64 else 'float32' - _init_var_node(scale_var_node, - np.zeros(scale_var_node.shape(), dtype=data_type), - self._scope, self._place) + _init_var_node(scale_var_node, scale_value, self._scope, self._place) quant_op_node = graph.create_op_node( op_type='fake_channel_wise_quantize_abs_max', attrs={ @@ -1396,7 +1416,12 @@ class TransformForMobilePass(object): class OutScaleForTrainingPass(object): - def __init__(self, scope=None, place=None, moving_rate=0.9): + def __init__(self, + scope=None, + place=None, + moving_rate=0.9, + is_test=None, + scale_dict=None): """ This pass is used for calculating output scales of some operators. These output scales may be used by tensorRT or some other inference engines. @@ -1411,8 +1436,9 @@ class OutScaleForTrainingPass(object): self._scope = scope self._place = _get_paddle_place(place) self._moving_rate = moving_rate - self._is_test = None + self._is_test = is_test self._teller_set = utils._out_scale_op_list + self._scale_dict = scale_dict def apply(self, graph): """ @@ -1424,7 +1450,8 @@ class OutScaleForTrainingPass(object): """ assert isinstance(graph, IrGraph), 'graph must be the instance of IrGraph.' 
- self._is_test = graph.is_test() + if self._is_test is None: + self._is_test = graph.is_test() target_ops = [] for op in graph.all_op_nodes(): if op.name() in self._teller_set: @@ -1440,22 +1467,29 @@ class OutScaleForTrainingPass(object): [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]: continue + data_type = 'float64' if in_node.dtype() \ + == core.VarDesc.VarType.FP64 else 'float32' try: - graph._find_node_by_name( + scale_node = graph._find_node_by_name( graph.all_var_nodes(), self._scale_name(in_node.name())) - continue except: scale_node = graph.create_persistable_node( name=self._scale_name(in_node.name()), var_type=core.VarDesc.VarType.LOD_TENSOR, shape=[1], var_dtype=in_node.dtype()) + if self._scale_dict is not None: + try: + scale_value = np.array( + [self._scale_dict[in_node.name()]]) + except: + scale_value = np.ones([1], dtype=data_type) + else: + scale_value = np.ones([1], dtype=data_type) + _init_var_node(scale_node, scale_value, self._scope, + self._place) - data_type = 'float64' if in_node.dtype() \ - == core.VarDesc.VarType.FP64 else 'float32' - _init_var_node(scale_node, np.ones([1], dtype=data_type), - self._scope, self._place) ins = {'X': in_node} outs = {'OutScale': scale_node} if not self._is_test: @@ -1589,7 +1623,9 @@ class AddQuantDequantPass(object): quant_bits=8, skip_pattern=["skip_quant"], quantizable_op_type=["elementwise_add", "pool2d"], - is_full_quantized=False): + is_full_quantized=False, + is_test=None, + scale_dict=None): """ Constructor. @@ -1616,8 +1652,9 @@ class AddQuantDequantPass(object): self._place = _get_paddle_place(place) self._moving_rate = moving_rate self._quant_bits = quant_bits - self._is_test = None + self._is_test = is_test self._skip_pattern = skip_pattern + self._scale_dict = scale_dict if is_full_quantized: self._quantizable_op_type = utils._act_supported_quantizable_op_type @@ -1645,7 +1682,8 @@ class AddQuantDequantPass(object): """ assert isinstance(graph, IrGraph), 'graph must be the instance of IrGraph.' 
- self._is_test = graph.is_test() + if self._is_test is None: + self._is_test = graph.is_test() dequantized_vars_map = collections.OrderedDict() # Forward stage, insert quant_dequant op @@ -1711,17 +1749,28 @@ class AddQuantDequantPass(object): var_type=var_node.type(), shape=var_node.shape(), var_dtype=var_node.dtype()) + scale_name = "{}.quant_dequant@scale".format(var_node.name()) + data_type = 'float64' if var_node.dtype( + ) == core.VarDesc.VarType.FP64 else 'float32' + try: + if self._scale_dict is not None and var_node.name( + ) in self._scale_dict.keys(): + scale_value = np.array([self._scale_dict[var_node.name()]], + dtype=data_type) + else: + scale_value = np.array( + self._scope.find_var(scale_name).get_tensor(), + dtype=data_type) + except: + scale_value = np.array([_SCALE_DEFAULT_VALUE], dtype=data_type) + scale_in_node = graph.create_persistable_node( name="{}.quant_dequant@scale".format(var_node.name()), var_type=core.VarDesc.VarType.LOD_TENSOR, shape=[1], var_dtype=var_node.dtype()) - data_type = 'float64' if var_node.dtype( - ) == core.VarDesc.VarType.FP64 else 'float32' - _init_var_node(scale_in_node, - np.array([_SCALE_DEFAULT_VALUE], dtype=data_type), - self._scope, self._place) + _init_var_node(scale_in_node, scale_value, self._scope, self._place) scale_out_node = graph.create_var_node_from_desc(scale_in_node.var()) ins = {'X': var_node, 'InScale': scale_in_node} outs = {'Out': quant_var_node, 'OutScale': scale_out_node} @@ -1992,7 +2041,8 @@ class QuantizationTransformPassV2(QuantizationTransformPass): weight_preprocess_func=None, act_preprocess_func=None, optimizer_func=None, - executor=None): + executor=None, + is_test=None): r""" Args: scope(paddle.Scope): When activation use 'range_abs_max' as the quantize @@ -2106,7 +2156,7 @@ class QuantizationTransformPassV2(QuantizationTransformPass): self._quantizable_grad_ops = [ '%s_grad' % (op) for op in self._quantizable_ops ] - self._is_test = None + self._is_test = is_test self._global_step = None self.create_var_map = {} @@ -2235,7 +2285,8 @@ class QuantizationTransformPassV2(QuantizationTransformPass): """ assert isinstance(graph, IrGraph), 'graph must be the instance of IrGraph.' - self._is_test = graph.is_test() + if self._is_test is None: + self._is_test = graph.is_test() self.persistable_vars = [ p.name() for p in graph.all_persistable_nodes() @@ -2285,7 +2336,8 @@ class AddQuantDequantPassV2(object): quant_bits=8, skip_pattern=["skip_quant"], quantizable_op_type=["elementwise_add", "pool2d"], - is_full_quantized=False): + is_full_quantized=False, + is_test=None): """ Args: scope(paddle.Scope): The scope is used to initialize these new parameters. @@ -2325,7 +2377,7 @@ class AddQuantDequantPassV2(object): self._place = _get_paddle_place(place) self._moving_rate = moving_rate self._quant_bits = quant_bits - self._is_test = None + self._is_test = is_test self._skip_pattern = skip_pattern if is_full_quantized: @@ -2355,7 +2407,8 @@ class AddQuantDequantPassV2(object): """ assert isinstance(graph, IrGraph), 'graph must be the instance of IrGraph.' 
- self._is_test = graph.is_test() + if self._is_test is None: + self._is_test = graph.is_test() dequantized_vars_map = collections.OrderedDict() self.persistable_vars = [ diff --git a/python/paddle/fluid/contrib/slim/quantization/utils.py b/python/paddle/fluid/contrib/slim/quantization/utils.py index e7187018c8d5bffb8c0b5bb3872093159e1cdbe7..c2c24348f5b76c5296bcc130b797ee59d625de7a 100644 --- a/python/paddle/fluid/contrib/slim/quantization/utils.py +++ b/python/paddle/fluid/contrib/slim/quantization/utils.py @@ -38,7 +38,6 @@ _act_supported_quantizable_op_type = [ "mean", "not_equal", "reshape", - "reshape2", "dropout", "bilinear_interp", "nearest_interp", diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt index de373716d8b131bf9f7ea6b828471b589f391c2e..7e38e407336e5fab6832cdb7179554b87223b193 100644 --- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt @@ -246,6 +246,7 @@ if(WIN32) list(REMOVE_ITEM TEST_OPS test_post_training_quantization_while) list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1) list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50) + list(REMOVE_ITEM TEST_OPS test_post_training_quantization_program_resnet50) list(REMOVE_ITEM TEST_OPS test_post_training_quantization_lstm_model) list(REMOVE_ITEM TEST_OPS test_imperative_ptq) list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1) @@ -520,6 +521,8 @@ endforeach() if(NOT WIN32) set_tests_properties(test_post_training_quantization_lstm_model PROPERTIES TIMEOUT 120) + set_tests_properties(test_post_training_quantization_program_resnet50 + PROPERTIES TIMEOUT 240) set_tests_properties(test_post_training_quantization_mobilenetv1 PROPERTIES TIMEOUT 600 LABELS "RUN_TYPE=NIGHTLY") set_tests_properties(test_post_training_quantization_resnet50 diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py index 70b04ebf5ef5e931514e7362d00958960a43ec25..cb6d685f721d604f89c9b5c3eb00589838ce390f 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py @@ -292,13 +292,13 @@ class TestPostTrainingQuantization(unittest.TestCase): print("Start FP32 inference for {0} on {1} images ...".format( model, infer_iterations * batch_size)) - (fp32_throughput, fp32_latency, - fp32_acc1) = self.run_program(model_cache_folder + "/model", - batch_size, infer_iterations) + (fp32_throughput, fp32_latency, fp32_acc1) = self.run_program( + os.path.join(model_cache_folder, "model"), batch_size, + infer_iterations) print("Start INT8 post training quantization for {0} on {1} images ...". 
format(model, sample_iterations * batch_size)) - self.generate_quantized_model(model_cache_folder + "/model", + self.generate_quantized_model(os.path.join(model_cache_folder, "model"), quantizable_op_type, algo, round_type, is_full_quantize, is_use_cache_file, is_optimize_model, onnx_format) @@ -454,29 +454,5 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization): onnx_format=onnx_format) -class TestPostTrainingPtfForMobilenetv1(TestPostTrainingQuantization): - - def test_post_training_ptf_mobilenetv1(self): - model = "MobileNet-V1" - algo = "ptf" - round_type = "round" - data_urls = [ - 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' - ] - data_md5s = ['13892b0716d26443a8cdea15b3c6438b'] - quantizable_op_type = [ - "conv2d", - "mul", - ] - is_full_quantize = False - is_use_cache_file = False - is_optimize_model = False - # The accuracy diff of post-training quantization (abs_max) maybe bigger - diff_threshold = 0.05 - self.run_test(model, algo, round_type, data_urls, data_md5s, - quantizable_op_type, is_full_quantize, is_use_cache_file, - is_optimize_model, diff_threshold) - - if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_program_resnet50.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_program_resnet50.py new file mode 100644 index 0000000000000000000000000000000000000000..3709d497a634d46b93a2bd61531d0477d8fd1fae --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_program_resnet50.py @@ -0,0 +1,279 @@ +# copyright (c) 2018 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. 
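+
+# These tests exercise the new PostTrainingQuantizationProgram entry point:
+# they quantize an already-loaded inference Program (with explicit feed and
+# fetch lists) instead of reloading a model from model_dir, and check that
+# the INT8 top-1 accuracy of ResNet-50 stays within diff_threshold of the
+# FP32 baseline.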
+ +import os +import sys +import time +import paddle +import random +import unittest +import functools +import contextlib +import numpy as np +import paddle.fluid as fluid +from PIL import Image, ImageEnhance +from paddle.fluid.contrib.slim.quantization import PostTrainingQuantizationProgram +from test_post_training_quantization_mobilenetv1 import TestPostTrainingQuantization + +paddle.enable_static() + +random.seed(0) +np.random.seed(0) + +THREAD = 1 +DATA_DIM = 224 +BUF_SIZE = 102400 +DATA_DIR = 'data/ILSVRC2012' + +img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) +img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) + + +def resize_short(img, target_size): + percent = float(target_size) / min(img.size[0], img.size[1]) + resized_width = int(round(img.size[0] * percent)) + resized_height = int(round(img.size[1] * percent)) + img = img.resize((resized_width, resized_height), Image.LANCZOS) + return img + + +def crop_image(img, target_size, center): + width, height = img.size + size = target_size + if center == True: + w_start = (width - size) / 2 + h_start = (height - size) / 2 + else: + w_start = np.random.randint(0, width - size + 1) + h_start = np.random.randint(0, height - size + 1) + w_end = w_start + size + h_end = h_start + size + img = img.crop((w_start, h_start, w_end, h_end)) + return img + + +def process_image(sample, mode, color_jitter, rotate): + img_path = sample[0] + img = Image.open(img_path) + img = resize_short(img, target_size=256) + img = crop_image(img, target_size=DATA_DIM, center=True) + if img.mode != 'RGB': + img = img.convert('RGB') + img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255 + img -= img_mean + img /= img_std + return img, sample[1] + + +def _reader_creator(file_list, + mode, + shuffle=False, + color_jitter=False, + rotate=False, + data_dir=DATA_DIR): + + def reader(): + with open(file_list) as flist: + full_lines = [line.strip() for line in flist] + if shuffle: + np.random.shuffle(full_lines) + lines = full_lines + + for line in lines: + img_path, label = line.split() + img_path = os.path.join(data_dir, img_path) + if not os.path.exists(img_path): + continue + yield img_path, int(label) + + mapper = functools.partial(process_image, + mode=mode, + color_jitter=color_jitter, + rotate=rotate) + + return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE) + + +def val(data_dir=DATA_DIR): + file_list = os.path.join(data_dir, 'val_list.txt') + return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir) + + +class TestPostTrainingQuantizationProgram(TestPostTrainingQuantization): + + def run_program(self, model_path, batch_size, infer_iterations): + image_shape = [3, 224, 224] + place = fluid.CPUPlace() + exe = fluid.Executor(place) + [infer_program, feed_dict, fetch_targets] = \ + fluid.io.load_inference_model(model_path, exe) + val_reader = paddle.batch(val(), batch_size) + iterations = infer_iterations + test_info = [] + cnt = 0 + periods = [] + for batch_id, data in enumerate(val_reader()): + image = np.array([x[0].reshape(image_shape) + for x in data]).astype("float32") + label = np.array([x[1] for x in data]).astype("int64") + label = label.reshape([-1, 1]) + + t1 = time.time() + _, acc1, _ = exe.run(infer_program, + feed={ + feed_dict[0]: image, + feed_dict[1]: label + }, + fetch_list=fetch_targets) + t2 = time.time() + period = t2 - t1 + periods.append(period) + + test_info.append(np.mean(acc1) * len(data)) + cnt += len(data) + + if (batch_id + 1) % 100 == 0: + print("{0} 
images,".format(batch_id + 1)) + sys.stdout.flush() + if (batch_id + 1) == iterations: + break + + throughput = cnt / np.sum(periods) + latency = np.average(periods) + acc1 = np.sum(test_info) / cnt + [infer_program, feed_dict, fetch_targets] = \ + fluid.io.load_inference_model(model_path, exe) + return (throughput, latency, acc1, infer_program, feed_dict, + fetch_targets) + + def generate_quantized_model( + self, + program, + quantizable_op_type, + feed_list, + fetch_list, + algo="KL", + round_type="round", + is_full_quantize=False, + is_use_cache_file=False, + is_optimize_model=False, + onnx_format=False, + ): + try: + os.system("mkdir " + self.int8_model) + except Exception as e: + print("Failed to create {} due to {}".format( + self.int8_model, str(e))) + sys.exit(-1) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + scope = fluid.global_scope() + val_reader = val() + same_scale_tensor_list = [[ + 'batch_norm_3.tmp_2#/#1', 'batch_norm_4.tmp_2#*#1' + ], ['batch_norm_27.tmp_2', 'batch_norm_26.tmp_2']] + ptq = PostTrainingQuantizationProgram( + executor=exe, + program=program, + sample_generator=val_reader, + batch_nums=10, + algo=algo, + quantizable_op_type=quantizable_op_type, + round_type=round_type, + is_full_quantize=is_full_quantize, + optimize_model=is_optimize_model, + onnx_format=onnx_format, + is_use_cache_file=is_use_cache_file, + feed_list=feed_list, + fetch_list=fetch_list, + same_scale_tensor_list=same_scale_tensor_list) + ptq.quantize() + ptq.save_quantized_model(self.int8_model) + + def run_test(self, + model, + algo, + round_type, + data_urls, + data_md5s, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + onnx_format=False): + infer_iterations = self.infer_iterations + batch_size = self.batch_size + sample_iterations = self.sample_iterations + + model_cache_folder = self.download_data(data_urls, data_md5s, model) + + print("Start FP32 inference for {0} on {1} images ...".format( + model, infer_iterations * batch_size)) + (fp32_throughput, fp32_latency, fp32_acc1, infer_program, feed_dict, + fetch_targets) = self.run_program( + os.path.join(model_cache_folder, "model"), batch_size, + infer_iterations) + print("Start INT8 post training quantization for {0} on {1} images ...". + format(model, sample_iterations * batch_size)) + self.generate_quantized_model(infer_program, quantizable_op_type, + feed_dict, fetch_targets, algo, + round_type, is_full_quantize, + is_use_cache_file, is_optimize_model, + onnx_format) + + print("Start INT8 inference for {0} on {1} images ...".format( + model, infer_iterations * batch_size)) + (int8_throughput, int8_latency, int8_acc1, _, _, + _) = self.run_program(self.int8_model, batch_size, infer_iterations) + + print("---Post training quantization of {} method---".format(algo)) + print( + "FP32 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}." 
+ .format(model, batch_size, fp32_throughput, fp32_latency, + fp32_acc1)) + print( + "INT8 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}.\n" + .format(model, batch_size, int8_throughput, int8_latency, + int8_acc1)) + sys.stdout.flush() + + delta_value = fp32_acc1 - int8_acc1 + self.assertLess(delta_value, diff_threshold) + + +class TestPostTrainingProgramAbsMaxForResnet50( + TestPostTrainingQuantizationProgram): + + def test_post_training_abs_max_resnet50(self): + model = "ResNet-50" + algo = "abs_max" + round_type = "round" + data_urls = [ + 'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz' + ] + data_md5s = ['4a5194524823d9b76da6e738e1367881'] + quantizable_op_type = ["conv2d", "mul"] + is_full_quantize = False + is_use_cache_file = False + is_optimize_model = False + diff_threshold = 0.025 + self.run_test(model, algo, round_type, data_urls, data_md5s, + quantizable_op_type, is_full_quantize, is_use_cache_file, + is_optimize_model, diff_threshold) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantization_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_mkldnn_pass.py index fbb1adefa11111622533b4b01688af1198dbb8b1..f6c6c4f20afff07bf406e95876996fb2cff8849f 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_quantization_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_mkldnn_pass.py @@ -118,6 +118,11 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase): activation_quantize_type=activation_quant_type, weight_quantize_type=weight_quant_type) transform_pass.apply(main_graph) + transform_pass = QuantizationTransformPass( + scope=scope, + place=place, + activation_quantize_type=activation_quant_type, + weight_quantize_type=weight_quant_type) transform_pass.apply(test_graph) build_strategy = fluid.BuildStrategy() diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py index ce06bd63a86289a6e44310433938a602f38b266f..9d61e67092e383360ffe247c1ce359cf110271f8 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py @@ -313,6 +313,12 @@ class TestQuantizationFreezePass(unittest.TestCase): weight_quantize_type=weight_quant_type, skip_pattern=quant_skip_pattern) transform_pass.apply(main_graph) + transform_pass = QuantizationTransformPass( + scope=scope, + place=place, + activation_quantize_type=activation_quant_type, + weight_quantize_type=weight_quant_type, + skip_pattern=quant_skip_pattern) transform_pass.apply(test_graph) dev_name = '_gpu_' if use_cuda else '_cpu_' if not for_ci:
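For reference, a minimal end-to-end sketch of the PostTrainingQuantizationProgram flow that the new ResNet-50 test drives; the model path and the calibration reader are hypothetical:

    import paddle
    import paddle.fluid as fluid
    from paddle.fluid.contrib.slim.quantization import PostTrainingQuantizationProgram

    paddle.enable_static()
    exe = fluid.Executor(fluid.CPUPlace())

    # Quantize an already-built Program instead of pointing at a model_dir.
    [program, feed_list, fetch_list] = fluid.io.load_inference_model(
        './fp32_model', exe)  # hypothetical path

    ptq = PostTrainingQuantizationProgram(
        executor=exe,
        program=program,
        feed_list=feed_list,
        fetch_list=fetch_list,
        sample_generator=calib_reader,  # hypothetical calibration reader
        batch_nums=10,
        algo='abs_max',
        quantizable_op_type=['conv2d', 'mul'])

    # return_graph defaults to True here, so quantize() returns an IrGraph
    # rather than a Program.
    main_graph = ptq.quantize()
    ptq.save_quantized_model('./int8_model')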