提交 07e6a942 编写于 作者: I itminner 提交者: whs

paddleslim quantization skip pattern support list of string (#21141)

上级 d8e7d252
......@@ -99,7 +99,7 @@ class QuantizationTransformPass(object):
weight_quantize_type='abs_max',
window_size=10000,
moving_rate=0.9,
skip_pattern='skip_quant',
skip_pattern=['skip_quant'],
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
"""
Convert and rewrite the IrGraph according to weight and
......@@ -126,7 +126,7 @@ class QuantizationTransformPass(object):
model is well trained.
window_size(int): the window size for 'range_abs_max' quantization.
moving_rate(float): the param for 'moving_average_abs_max' quantization.
skip_pattern(str): The user-defined quantization skip pattern, which
skip_pattern(str or str list): The user-defined quantization skip pattern, which
will be presented in the name scope of an op. When the skip pattern is
detected in an op's name scope, the corresponding op will not be quantized.
quantizable_op_type(list[str]): List the type of ops that will be quantized.
......@@ -206,8 +206,12 @@ class QuantizationTransformPass(object):
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
def _quant_preprocess(op_node):
user_skipped = isinstance(self._skip_pattern, str) and \
op_node.op().has_attr("op_namescope") and \
user_skipped = False
if isinstance(self._skip_pattern, list):
user_skipped = op_node.op().has_attr("op_namescope") and \
any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern)
elif isinstance(self._skip_pattern, str):
user_skipped = op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
if user_skipped:
......@@ -1245,7 +1249,7 @@ class AddQuantDequantPass(object):
place=None,
moving_rate=0.9,
quant_bits=8,
skip_pattern='skip_quant',
skip_pattern=["skip_quant"],
quantizable_op_type=["elementwise_add", "pool2d", "concat"],
is_full_quantized=False):
"""
......@@ -1313,9 +1317,15 @@ class AddQuantDequantPass(object):
all_op_nodes = graph.all_op_nodes()
for op_node in all_op_nodes:
if op_node.name() in self._quantizable_op_type:
if isinstance(self._skip_pattern, str) and \
op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1:
user_skipped = False
if isinstance(self._skip_pattern, list):
user_skipped = op_node.op().has_attr("op_namescope") and \
any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern)
elif isinstance(self._skip_pattern, str):
user_skipped = op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
if user_skipped:
continue
if not self._is_input_all_not_persistable(graph, op_node):
......
......@@ -531,7 +531,7 @@ def quant_dequant_residual_block(num, quant_skip_pattern=None):
short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
if quant_skip_pattern:
if isinstance(quant_skip_pattern, str):
with fluid.name_scope(quant_skip_pattern):
pool1 = fluid.layers.pool2d(
input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
......@@ -539,6 +539,18 @@ def quant_dequant_residual_block(num, quant_skip_pattern=None):
input=hidden, pool_size=2, pool_type='max', pool_stride=2)
pool_add = fluid.layers.elementwise_add(
x=pool1, y=pool2, act='relu')
elif isinstance(quant_skip_pattern, list):
assert len(
quant_skip_pattern
) > 1, 'test config error: the len of quant_skip_pattern list should be greater than 1.'
with fluid.name_scope(quant_skip_pattern[0]):
pool1 = fluid.layers.pool2d(
input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
pool2 = fluid.layers.pool2d(
input=hidden, pool_size=2, pool_type='max', pool_stride=2)
with fluid.name_scope(quant_skip_pattern[1]):
pool_add = fluid.layers.elementwise_add(
x=pool1, y=pool2, act='relu')
else:
pool1 = fluid.layers.pool2d(
input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
......@@ -560,8 +572,15 @@ class TestAddQuantDequantPass(unittest.TestCase):
ops = graph.all_op_nodes()
for op_node in ops:
if op_node.name() in self._target_ops:
if skip_pattern and op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(skip_pattern) != -1:
user_skipped = False
if isinstance(skip_pattern, list):
user_skipped = op_node.op().has_attr("op_namescope") and \
any(pattern in op_node.op().attr("op_namescope") for pattern in skip_pattern)
elif isinstance(skip_pattern, str):
user_skipped = op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(skip_pattern) != -1
if user_skipped:
continue
in_nodes_all_not_persistable = True
......@@ -587,7 +606,7 @@ class TestAddQuantDequantPass(unittest.TestCase):
place = fluid.CPUPlace()
graph = IrGraph(core.Graph(main.desc), for_test=False)
add_quant_dequant_pass = AddQuantDequantPass(
scope=fluid.global_scope(), place=place)
scope=fluid.global_scope(), place=place, skip_pattern=skip_pattern)
add_quant_dequant_pass.apply(graph)
if not for_ci:
marked_nodes = set()
......@@ -611,6 +630,10 @@ class TestAddQuantDequantPass(unittest.TestCase):
def test_residual_block_skip_pattern(self):
self.residual_block_quant(skip_pattern='skip_quant', for_ci=True)
def test_residual_block_skip_pattern(self):
self.residual_block_quant(
skip_pattern=['skip_quant1', 'skip_quant2'], for_ci=True)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册