diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index ab3ca4cc1025911a5fbc12ffa6a91dd1d2fd0aa6..134529d7797251348a6d7eaceeaab8a3bd05c605 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -99,7 +99,7 @@ class QuantizationTransformPass(object): weight_quantize_type='abs_max', window_size=10000, moving_rate=0.9, - skip_pattern='skip_quant', + skip_pattern=['skip_quant'], quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']): """ Convert and rewrite the IrGraph according to weight and @@ -126,9 +126,9 @@ class QuantizationTransformPass(object): model is well trained. window_size(int): the window size for 'range_abs_max' quantization. moving_rate(float): the param for 'moving_average_abs_max' quantization. - skip_pattern(str): The user-defined quantization skip pattern, which + skip_pattern(str or str list): The user-defined quantization skip pattern, which will be presented in the name scope of an op. When the skip pattern is - detected in an op's name scope, the corresponding op will not be quantized. + detected in an op's name scope, the corresponding op will not be quantized. quantizable_op_type(list[str]): List the type of ops that will be quantized. Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in QuantizationFreezePass and ConvertToInt8Pass must be the same as this. 
@@ -206,9 +206,13 @@ class QuantizationTransformPass(object): persistable_vars = [p.name() for p in graph.all_persistable_nodes()] def _quant_preprocess(op_node): - user_skipped = isinstance(self._skip_pattern, str) and \ - op_node.op().has_attr("op_namescope") and \ - op_node.op().attr("op_namescope").find(self._skip_pattern) != -1 + user_skipped = False + if isinstance(self._skip_pattern, list): + user_skipped = op_node.op().has_attr("op_namescope") and \ + any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern) + elif isinstance(self._skip_pattern, str): + user_skipped = op_node.op().has_attr("op_namescope") and \ + op_node.op().attr("op_namescope").find(self._skip_pattern) != -1 if user_skipped: op_node.op()._set_attr("skip_quant", True) @@ -1245,7 +1249,7 @@ class AddQuantDequantPass(object): place=None, moving_rate=0.9, quant_bits=8, - skip_pattern='skip_quant', + skip_pattern=["skip_quant"], quantizable_op_type=["elementwise_add", "pool2d", "concat"], is_full_quantized=False): """ @@ -1313,9 +1317,15 @@ class AddQuantDequantPass(object): all_op_nodes = graph.all_op_nodes() for op_node in all_op_nodes: if op_node.name() in self._quantizable_op_type: - if isinstance(self._skip_pattern, str) and \ - op_node.op().has_attr("op_namescope") and \ - op_node.op().attr("op_namescope").find(self._skip_pattern) != -1: + user_skipped = False + if isinstance(self._skip_pattern, list): + user_skipped = op_node.op().has_attr("op_namescope") and \ + any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern) + elif isinstance(self._skip_pattern, str): + user_skipped = op_node.op().has_attr("op_namescope") and \ + op_node.op().attr("op_namescope").find(self._skip_pattern) != -1 + + if user_skipped: continue if not self._is_input_all_not_persistable(graph, op_node): diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py index 
3080a6e60d2cdab7c90c091c95b1b6952b1b980f..0141cc0f8ad847506f5e657e40ce0946fecf8144 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py @@ -531,7 +531,7 @@ def quant_dequant_residual_block(num, quant_skip_pattern=None): short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None) hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu') - if quant_skip_pattern: + if isinstance(quant_skip_pattern, str): with fluid.name_scope(quant_skip_pattern): pool1 = fluid.layers.pool2d( input=hidden, pool_size=2, pool_type='avg', pool_stride=2) @@ -539,6 +539,18 @@ def quant_dequant_residual_block(num, quant_skip_pattern=None): input=hidden, pool_size=2, pool_type='max', pool_stride=2) pool_add = fluid.layers.elementwise_add( x=pool1, y=pool2, act='relu') + elif isinstance(quant_skip_pattern, list): + assert len( + quant_skip_pattern + ) > 1, 'test config error: the len of quant_skip_pattern list should be greater than 1.' 
+        with fluid.name_scope(quant_skip_pattern[0]): +            pool1 = fluid.layers.pool2d( +                input=hidden, pool_size=2, pool_type='avg', pool_stride=2) +            pool2 = fluid.layers.pool2d( +                input=hidden, pool_size=2, pool_type='max', pool_stride=2) +        with fluid.name_scope(quant_skip_pattern[1]): +            pool_add = fluid.layers.elementwise_add( +                x=pool1, y=pool2, act='relu') else: pool1 = fluid.layers.pool2d( input=hidden, pool_size=2, pool_type='avg', pool_stride=2) @@ -560,8 +572,15 @@ class TestAddQuantDequantPass(unittest.TestCase): ops = graph.all_op_nodes() for op_node in ops: if op_node.name() in self._target_ops: -            if skip_pattern and op_node.op().has_attr("op_namescope") and \ -                    op_node.op().attr("op_namescope").find(skip_pattern) != -1: +            user_skipped = False +            if isinstance(skip_pattern, list): +                user_skipped = op_node.op().has_attr("op_namescope") and \ +                    any(pattern in op_node.op().attr("op_namescope") for pattern in skip_pattern) +            elif isinstance(skip_pattern, str): +                user_skipped = op_node.op().has_attr("op_namescope") and \ +                    op_node.op().attr("op_namescope").find(skip_pattern) != -1 + +            if user_skipped: continue in_nodes_all_not_persistable = True @@ -587,7 +606,7 @@ class TestAddQuantDequantPass(unittest.TestCase): place = fluid.CPUPlace() graph = IrGraph(core.Graph(main.desc), for_test=False) add_quant_dequant_pass = AddQuantDequantPass( -            scope=fluid.global_scope(), place=place) +            scope=fluid.global_scope(), place=place, skip_pattern=skip_pattern) add_quant_dequant_pass.apply(graph) if not for_ci: marked_nodes = set() @@ -611,6 +630,10 @@ class TestAddQuantDequantPass(unittest.TestCase): def test_residual_block_skip_pattern(self): self.residual_block_quant(skip_pattern='skip_quant', for_ci=True) + +    def test_residual_block_skip_pattern_list(self): +        self.residual_block_quant( +            skip_pattern=['skip_quant1', 'skip_quant2'], for_ci=True) + if __name__ == '__main__': unittest.main()