未验证 提交 965b45e8 编写于 作者: J juncaipeng 提交者: GitHub

Move pool2d to add_quant_dequant_pass, test=develop (#20586) (#20675)

* move pool2d to add_quant_dequant_pass, test=develop
上级 e083f149
...@@ -26,7 +26,7 @@ __all__ = [ ...@@ -26,7 +26,7 @@ __all__ = [
'AddQuantDequantPass' 'AddQuantDequantPass'
] ]
_quantizable_op_list = ['conv2d', 'depthwise_conv2d', 'mul', 'pool2d'] _quantizable_op_list = ['conv2d', 'depthwise_conv2d', 'mul']
_fake_quant_op_list = [ _fake_quant_op_list = [
'fake_quantize_abs_max', 'fake_quantize_range_abs_max', 'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
...@@ -161,13 +161,11 @@ class QuantizationTransformPass(object): ...@@ -161,13 +161,11 @@ class QuantizationTransformPass(object):
persistable_vars = [p.name() for p in graph.all_persistable_nodes()] persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
def _quant_preprocess(op_node): def _quant_preprocess(op_node):
pool_skipped = op_node.op().has_attr("pooling_type") and \
op_node.op().attr("pooling_type") == 'avg'
user_skipped = isinstance(self._skip_pattern, str) and \ user_skipped = isinstance(self._skip_pattern, str) and \
op_node.op().has_attr("op_namescope") and \ op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1 op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
if pool_skipped or user_skipped: if user_skipped:
op_node.op()._set_attr("skip_quant", True) op_node.op()._set_attr("skip_quant", True)
def _transform_forward(graph, op): def _transform_forward(graph, op):
...@@ -1163,10 +1161,15 @@ class ScaleForInferencePass(object): ...@@ -1163,10 +1161,15 @@ class ScaleForInferencePass(object):
class AddQuantDequantPass(object): class AddQuantDequantPass(object):
def __init__(self, scope=None, place=None, moving_rate=0.9, quant_bits=8): def __init__(self,
scope=None,
place=None,
moving_rate=0.9,
quant_bits=8,
skip_pattern='skip_quant'):
""" """
This pass is used to add quant_dequant op for some ops, such as the This pass is used to add quant_dequant op for some ops, such as the
'elementwise_add' and 'average pool2d' op. 'elementwise_add' and 'pool2d' op.
""" """
self._scope = scope self._scope = scope
self._place = place self._place = place
...@@ -1175,11 +1178,12 @@ class AddQuantDequantPass(object): ...@@ -1175,11 +1178,12 @@ class AddQuantDequantPass(object):
self._is_test = None self._is_test = None
self._target_ops = ["elementwise_add", "pool2d"] self._target_ops = ["elementwise_add", "pool2d"]
self._target_grad_ops = ['%s_grad' % (op) for op in self._target_ops] self._target_grad_ops = ['%s_grad' % (op) for op in self._target_ops]
self._skip_pattern = skip_pattern
def apply(self, graph): def apply(self, graph):
""" """
Add quant_dequant before some ops, such as the 'elementwise_add' Add quant_dequant before some ops, such as the 'elementwise_add'
and 'average pool2d' op. and 'pool2d' op.
Args: Args:
graph(IrGraph): the target graph. graph(IrGraph): the target graph.
""" """
...@@ -1191,6 +1195,11 @@ class AddQuantDequantPass(object): ...@@ -1191,6 +1195,11 @@ class AddQuantDequantPass(object):
for op_node in ops: for op_node in ops:
if op_node.name() in self._target_ops: if op_node.name() in self._target_ops:
if isinstance(self._skip_pattern, str) and \
op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1:
continue
in_nodes_all_not_persistable = True in_nodes_all_not_persistable = True
for input_name in op_node.input_arg_names(): for input_name in op_node.input_arg_names():
in_node = graph._find_node_by_name(op_node.inputs, in_node = graph._find_node_by_name(op_node.inputs,
...@@ -1201,10 +1210,6 @@ class AddQuantDequantPass(object): ...@@ -1201,10 +1210,6 @@ class AddQuantDequantPass(object):
if not in_nodes_all_not_persistable: if not in_nodes_all_not_persistable:
continue continue
if op_node.op().has_attr("pooling_type") and \
op_node.op().attr("pooling_type") == 'max':
continue
input_names = op_node.input_arg_names() input_names = op_node.input_arg_names()
for input_name in input_names: for input_name in input_names:
in_node = graph._find_node_by_name(op_node.inputs, in_node = graph._find_node_by_name(op_node.inputs,
......
...@@ -42,7 +42,7 @@ def linear_fc(num): ...@@ -42,7 +42,7 @@ def linear_fc(num):
return loss return loss
def residual_block(num): def residual_block(num, quant_skip_pattern=None):
def conv_bn_layer(input, def conv_bn_layer(input,
ch_out, ch_out,
filter_size, filter_size,
...@@ -67,8 +67,14 @@ def residual_block(num): ...@@ -67,8 +67,14 @@ def residual_block(num):
conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True) conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None) short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu') hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
pool = fluid.layers.pool2d(
input=hidden, pool_size=2, pool_type='avg', pool_stride=2) if quant_skip_pattern:
with fluid.name_scope(quant_skip_pattern):
pool = fluid.layers.pool2d(
input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
else:
pool = fluid.layers.pool2d(
input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
fc = fluid.layers.fc(input=pool, size=10) fc = fluid.layers.fc(input=pool, size=10)
loss = fluid.layers.cross_entropy(input=fc, label=label) loss = fluid.layers.cross_entropy(input=fc, label=label)
loss = fluid.layers.mean(loss) loss = fluid.layers.mean(loss)
...@@ -134,7 +140,10 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -134,7 +140,10 @@ class TestQuantizationTransformPass(unittest.TestCase):
arg_name.endswith('.quantized.dequantized')) arg_name.endswith('.quantized.dequantized'))
self.assertTrue(arg_name in quantized_ops) self.assertTrue(arg_name in quantized_ops)
def linear_fc_quant(self, activation_quant_type, for_ci=True): def linear_fc_quant(self,
activation_quant_type,
weight_quantize_type,
for_ci=True):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
...@@ -146,7 +155,8 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -146,7 +155,8 @@ class TestQuantizationTransformPass(unittest.TestCase):
transform_pass = QuantizationTransformPass( transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(), scope=fluid.global_scope(),
place=place, place=place,
activation_quantize_type=activation_quant_type) activation_quantize_type=activation_quant_type,
weight_quantize_type=weight_quantize_type)
transform_pass.apply(graph) transform_pass.apply(graph)
if not for_ci: if not for_ci:
marked_nodes = set() marked_nodes = set()
...@@ -167,15 +177,19 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -167,15 +177,19 @@ class TestQuantizationTransformPass(unittest.TestCase):
val_marked_nodes) val_marked_nodes)
def test_linear_fc_quant_abs_max(self): def test_linear_fc_quant_abs_max(self):
self.linear_fc_quant('abs_max', for_ci=True) self.linear_fc_quant('abs_max', 'abs_max', for_ci=True)
def test_linear_fc_quant_range_abs_max(self): def test_linear_fc_quant_range_abs_max(self):
self.linear_fc_quant('range_abs_max', for_ci=True) self.linear_fc_quant('range_abs_max', 'abs_max', for_ci=True)
def test_linear_fc_quant_moving_average_abs_max(self): def test_linear_fc_quant_moving_average_abs_max(self):
self.linear_fc_quant('moving_average_abs_max', for_ci=True) self.linear_fc_quant(
'moving_average_abs_max', 'channel_wise_abs_max', for_ci=True)
def residual_block_quant(self, activation_quant_type, for_ci=True): def residual_block_quant(self,
activation_quant_type,
weight_quantize_type,
for_ci=True):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
...@@ -187,7 +201,8 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -187,7 +201,8 @@ class TestQuantizationTransformPass(unittest.TestCase):
transform_pass = QuantizationTransformPass( transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(), scope=fluid.global_scope(),
place=place, place=place,
activation_quantize_type=activation_quant_type) activation_quantize_type=activation_quant_type,
weight_quantize_type=weight_quantize_type)
transform_pass.apply(graph) transform_pass.apply(graph)
if not for_ci: if not for_ci:
marked_nodes = set() marked_nodes = set()
...@@ -208,13 +223,14 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -208,13 +223,14 @@ class TestQuantizationTransformPass(unittest.TestCase):
val_marked_nodes) val_marked_nodes)
def test_residual_block_abs_max(self): def test_residual_block_abs_max(self):
self.residual_block_quant('abs_max', for_ci=True) self.residual_block_quant('abs_max', 'abs_max', for_ci=True)
def test_residual_block_range_abs_max(self): def test_residual_block_range_abs_max(self):
self.residual_block_quant('range_abs_max', for_ci=True) self.residual_block_quant('range_abs_max', 'abs_max', for_ci=True)
def test_residual_block_moving_average_abs_max(self): def test_residual_block_moving_average_abs_max(self):
self.residual_block_quant('moving_average_abs_max', for_ci=True) self.residual_block_quant(
'moving_average_abs_max', 'channel_wise_abs_max', for_ci=True)
class TestQuantizationFreezePass(unittest.TestCase): class TestQuantizationFreezePass(unittest.TestCase):
...@@ -494,11 +510,14 @@ class TestAddQuantDequantPass(unittest.TestCase): ...@@ -494,11 +510,14 @@ class TestAddQuantDequantPass(unittest.TestCase):
self._target_ops = {'elementwise_add', 'pool2d'} self._target_ops = {'elementwise_add', 'pool2d'}
self._target_grad_ops = {'elementwise_add_grad', 'pool2d_grad'} self._target_grad_ops = {'elementwise_add_grad', 'pool2d_grad'}
def check_graph(self, graph): def check_graph(self, graph, skip_pattern=None):
ops = graph.all_op_nodes() ops = graph.all_op_nodes()
for op_node in ops: for op_node in ops:
if op_node.name() in self._target_ops: if op_node.name() in self._target_ops:
if skip_pattern and op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(skip_pattern) != -1:
continue
in_nodes_all_not_persistable = True in_nodes_all_not_persistable = True
for input_name in op_node.input_arg_names(): for input_name in op_node.input_arg_names():
in_node = graph._find_node_by_name(op_node.inputs, in_node = graph._find_node_by_name(op_node.inputs,
...@@ -508,20 +527,15 @@ class TestAddQuantDequantPass(unittest.TestCase): ...@@ -508,20 +527,15 @@ class TestAddQuantDequantPass(unittest.TestCase):
not in_node.persistable()) not in_node.persistable())
if not in_nodes_all_not_persistable: if not in_nodes_all_not_persistable:
continue continue
if op_node.op().has_attr("pooling_type") and \
op_node.op().attr("pooling_type") == 'max':
continue
input_names = op_node.input_arg_names() input_names = op_node.input_arg_names()
for input_name in input_names: for input_name in input_names:
self.assertTrue(input_name.endswith('.quant_dequant')) self.assertTrue(input_name.endswith('.quant_dequant'))
def residual_block_quant(self, for_ci=True): def residual_block_quant(self, skip_pattern=None, for_ci=True):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
loss = residual_block(1) loss = residual_block(2, skip_pattern)
opt = fluid.optimizer.Adam(learning_rate=0.001) opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss) opt.minimize(loss)
place = fluid.CPUPlace() place = fluid.CPUPlace()
...@@ -535,7 +549,7 @@ class TestAddQuantDequantPass(unittest.TestCase): ...@@ -535,7 +549,7 @@ class TestAddQuantDequantPass(unittest.TestCase):
if op.name().find('quant') > -1: if op.name().find('quant') > -1:
marked_nodes.add(op) marked_nodes.add(op)
graph.draw('.', 'add_quant_dequant_graph', marked_nodes) graph.draw('.', 'add_quant_dequant_graph', marked_nodes)
self.check_graph(graph) self.check_graph(graph, skip_pattern)
program = graph.to_program() program = graph.to_program()
val_graph = IrGraph(core.Graph(program.desc), for_test=False) val_graph = IrGraph(core.Graph(program.desc), for_test=False)
if not for_ci: if not for_ci:
...@@ -546,7 +560,10 @@ class TestAddQuantDequantPass(unittest.TestCase): ...@@ -546,7 +560,10 @@ class TestAddQuantDequantPass(unittest.TestCase):
val_graph.draw('.', 'val_add_quant_dequant_graph', val_marked_nodes) val_graph.draw('.', 'val_add_quant_dequant_graph', val_marked_nodes)
def test_residual_block(self): def test_residual_block(self):
self.residual_block_quant(for_ci=True) self.residual_block_quant(skip_pattern=None, for_ci=True)
def test_residual_block_skip_pattern(self):
self.residual_block_quant(skip_pattern='skip_quant', for_ci=True)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册