提交 07e6a942 编写于 作者: I itminner 提交者: whs

paddleslim quantization skip pattern support list of string (#21141)

上级 d8e7d252
...@@ -99,7 +99,7 @@ class QuantizationTransformPass(object): ...@@ -99,7 +99,7 @@ class QuantizationTransformPass(object):
weight_quantize_type='abs_max', weight_quantize_type='abs_max',
window_size=10000, window_size=10000,
moving_rate=0.9, moving_rate=0.9,
skip_pattern='skip_quant', skip_pattern=['skip_quant'],
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']): quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
""" """
Convert and rewrite the IrGraph according to weight and Convert and rewrite the IrGraph according to weight and
...@@ -126,9 +126,9 @@ class QuantizationTransformPass(object): ...@@ -126,9 +126,9 @@ class QuantizationTransformPass(object):
model is well trained. model is well trained.
window_size(int): the window size for 'range_abs_max' quantization. window_size(int): the window size for 'range_abs_max' quantization.
moving_rate(float): the param for 'moving_average_abs_max' quantization. moving_rate(float): the param for 'moving_average_abs_max' quantization.
skip_pattern(str): The user-defined quantization skip pattern, which skip_pattern(str or str list): The user-defined quantization skip pattern, which
will be presented in the name scope of an op. When the skip pattern is will be presented in the name scope of an op. When the skip pattern is
detected in an op's name scope, the corresponding op will not be quantized. detected in an op's name scope, the corresponding op will not be quantized.
quantizable_op_type(list[str]): List the type of ops that will be quantized. quantizable_op_type(list[str]): List the type of ops that will be quantized.
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
QuantizationFreezePass and ConvertToInt8Pass must be the same as this. QuantizationFreezePass and ConvertToInt8Pass must be the same as this.
...@@ -206,9 +206,13 @@ class QuantizationTransformPass(object): ...@@ -206,9 +206,13 @@ class QuantizationTransformPass(object):
persistable_vars = [p.name() for p in graph.all_persistable_nodes()] persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
def _quant_preprocess(op_node): def _quant_preprocess(op_node):
user_skipped = isinstance(self._skip_pattern, str) and \ user_skipped = False
op_node.op().has_attr("op_namescope") and \ if isinstance(self._skip_pattern, list):
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1 user_skipped = op_node.op().has_attr("op_namescope") and \
any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern)
elif isinstance(self._skip_pattern, str):
user_skipped = op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
if user_skipped: if user_skipped:
op_node.op()._set_attr("skip_quant", True) op_node.op()._set_attr("skip_quant", True)
...@@ -1245,7 +1249,7 @@ class AddQuantDequantPass(object): ...@@ -1245,7 +1249,7 @@ class AddQuantDequantPass(object):
place=None, place=None,
moving_rate=0.9, moving_rate=0.9,
quant_bits=8, quant_bits=8,
skip_pattern='skip_quant', skip_pattern=["skip_quant"],
quantizable_op_type=["elementwise_add", "pool2d", "concat"], quantizable_op_type=["elementwise_add", "pool2d", "concat"],
is_full_quantized=False): is_full_quantized=False):
""" """
...@@ -1313,9 +1317,15 @@ class AddQuantDequantPass(object): ...@@ -1313,9 +1317,15 @@ class AddQuantDequantPass(object):
all_op_nodes = graph.all_op_nodes() all_op_nodes = graph.all_op_nodes()
for op_node in all_op_nodes: for op_node in all_op_nodes:
if op_node.name() in self._quantizable_op_type: if op_node.name() in self._quantizable_op_type:
if isinstance(self._skip_pattern, str) and \ user_skipped = False
op_node.op().has_attr("op_namescope") and \ if isinstance(self._skip_pattern, list):
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1: user_skipped = op_node.op().has_attr("op_namescope") and \
any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern)
elif isinstance(self._skip_pattern, str):
user_skipped = op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
if user_skipped:
continue continue
if not self._is_input_all_not_persistable(graph, op_node): if not self._is_input_all_not_persistable(graph, op_node):
......
...@@ -531,7 +531,7 @@ def quant_dequant_residual_block(num, quant_skip_pattern=None): ...@@ -531,7 +531,7 @@ def quant_dequant_residual_block(num, quant_skip_pattern=None):
short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None) short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu') hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
if quant_skip_pattern: if isinstance(quant_skip_pattern, str):
with fluid.name_scope(quant_skip_pattern): with fluid.name_scope(quant_skip_pattern):
pool1 = fluid.layers.pool2d( pool1 = fluid.layers.pool2d(
input=hidden, pool_size=2, pool_type='avg', pool_stride=2) input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
...@@ -539,6 +539,18 @@ def quant_dequant_residual_block(num, quant_skip_pattern=None): ...@@ -539,6 +539,18 @@ def quant_dequant_residual_block(num, quant_skip_pattern=None):
input=hidden, pool_size=2, pool_type='max', pool_stride=2) input=hidden, pool_size=2, pool_type='max', pool_stride=2)
pool_add = fluid.layers.elementwise_add( pool_add = fluid.layers.elementwise_add(
x=pool1, y=pool2, act='relu') x=pool1, y=pool2, act='relu')
elif isinstance(quant_skip_pattern, list):
assert len(
quant_skip_pattern
) > 1, 'test config error: the len of quant_skip_pattern list should be greater than 1.'
with fluid.name_scope(quant_skip_pattern[0]):
pool1 = fluid.layers.pool2d(
input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
pool2 = fluid.layers.pool2d(
input=hidden, pool_size=2, pool_type='max', pool_stride=2)
with fluid.name_scope(quant_skip_pattern[1]):
pool_add = fluid.layers.elementwise_add(
x=pool1, y=pool2, act='relu')
else: else:
pool1 = fluid.layers.pool2d( pool1 = fluid.layers.pool2d(
input=hidden, pool_size=2, pool_type='avg', pool_stride=2) input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
...@@ -560,8 +572,15 @@ class TestAddQuantDequantPass(unittest.TestCase): ...@@ -560,8 +572,15 @@ class TestAddQuantDequantPass(unittest.TestCase):
ops = graph.all_op_nodes() ops = graph.all_op_nodes()
for op_node in ops: for op_node in ops:
if op_node.name() in self._target_ops: if op_node.name() in self._target_ops:
if skip_pattern and op_node.op().has_attr("op_namescope") and \ user_skipped = False
op_node.op().attr("op_namescope").find(skip_pattern) != -1: if isinstance(skip_pattern, list):
user_skipped = op_node.op().has_attr("op_namescope") and \
any(pattern in op_node.op().attr("op_namescope") for pattern in skip_pattern)
elif isinstance(skip_pattern, str):
user_skipped = op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(skip_pattern) != -1
if user_skipped:
continue continue
in_nodes_all_not_persistable = True in_nodes_all_not_persistable = True
...@@ -587,7 +606,7 @@ class TestAddQuantDequantPass(unittest.TestCase): ...@@ -587,7 +606,7 @@ class TestAddQuantDequantPass(unittest.TestCase):
place = fluid.CPUPlace() place = fluid.CPUPlace()
graph = IrGraph(core.Graph(main.desc), for_test=False) graph = IrGraph(core.Graph(main.desc), for_test=False)
add_quant_dequant_pass = AddQuantDequantPass( add_quant_dequant_pass = AddQuantDequantPass(
scope=fluid.global_scope(), place=place) scope=fluid.global_scope(), place=place, skip_pattern=skip_pattern)
add_quant_dequant_pass.apply(graph) add_quant_dequant_pass.apply(graph)
if not for_ci: if not for_ci:
marked_nodes = set() marked_nodes = set()
...@@ -611,6 +630,10 @@ class TestAddQuantDequantPass(unittest.TestCase): ...@@ -611,6 +630,10 @@ class TestAddQuantDequantPass(unittest.TestCase):
def test_residual_block_skip_pattern(self): def test_residual_block_skip_pattern(self):
self.residual_block_quant(skip_pattern='skip_quant', for_ci=True) self.residual_block_quant(skip_pattern='skip_quant', for_ci=True)
def test_residual_block_skip_pattern(self):
self.residual_block_quant(
skip_pattern=['skip_quant1', 'skip_quant2'], for_ci=True)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册