未验证 提交 f201b465 编写于 作者: J juncaipeng 提交者: GitHub

Move pool2d to add_quant_dequant_pass, test=develop (#20586)

* move pool2d to add_quant_dequant_pass, test=develop
上级 efa10937
......@@ -26,7 +26,7 @@ __all__ = [
'AddQuantDequantPass'
]
_quantizable_op_list = ['conv2d', 'depthwise_conv2d', 'mul', 'pool2d']
_quantizable_op_list = ['conv2d', 'depthwise_conv2d', 'mul']
_fake_quant_op_list = [
'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
......@@ -161,13 +161,11 @@ class QuantizationTransformPass(object):
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
def _quant_preprocess(op_node):
pool_skipped = op_node.op().has_attr("pooling_type") and \
op_node.op().attr("pooling_type") == 'avg'
user_skipped = isinstance(self._skip_pattern, str) and \
op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
if pool_skipped or user_skipped:
if user_skipped:
op_node.op()._set_attr("skip_quant", True)
def _transform_forward(graph, op):
......@@ -1163,10 +1161,15 @@ class ScaleForInferencePass(object):
class AddQuantDequantPass(object):
def __init__(self, scope=None, place=None, moving_rate=0.9, quant_bits=8):
def __init__(self,
scope=None,
place=None,
moving_rate=0.9,
quant_bits=8,
skip_pattern='skip_quant'):
"""
This pass is used to add quant_dequant op for some ops, such as the
'elementwise_add' and 'average pool2d' op.
'elementwise_add' and 'pool2d' op.
"""
self._scope = scope
self._place = place
......@@ -1175,11 +1178,12 @@ class AddQuantDequantPass(object):
self._is_test = None
self._target_ops = ["elementwise_add", "pool2d"]
self._target_grad_ops = ['%s_grad' % (op) for op in self._target_ops]
self._skip_pattern = skip_pattern
def apply(self, graph):
"""
Add quant_dequant before some ops, such as the 'elementwise_add'
and 'average pool2d' op.
and 'pool2d' op.
Args:
graph(IrGraph): the target graph.
"""
......@@ -1191,6 +1195,11 @@ class AddQuantDequantPass(object):
for op_node in ops:
if op_node.name() in self._target_ops:
if isinstance(self._skip_pattern, str) and \
op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1:
continue
in_nodes_all_not_persistable = True
for input_name in op_node.input_arg_names():
in_node = graph._find_node_by_name(op_node.inputs,
......@@ -1201,10 +1210,6 @@ class AddQuantDequantPass(object):
if not in_nodes_all_not_persistable:
continue
if op_node.op().has_attr("pooling_type") and \
op_node.op().attr("pooling_type") == 'max':
continue
input_names = op_node.input_arg_names()
for input_name in input_names:
in_node = graph._find_node_by_name(op_node.inputs,
......
......@@ -42,7 +42,7 @@ def linear_fc(num):
return loss
def residual_block(num):
def residual_block(num, quant_skip_pattern=None):
def conv_bn_layer(input,
ch_out,
filter_size,
......@@ -67,8 +67,14 @@ def residual_block(num):
conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
pool = fluid.layers.pool2d(
input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
if quant_skip_pattern:
with fluid.name_scope(quant_skip_pattern):
pool = fluid.layers.pool2d(
input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
else:
pool = fluid.layers.pool2d(
input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
fc = fluid.layers.fc(input=pool, size=10)
loss = fluid.layers.cross_entropy(input=fc, label=label)
loss = fluid.layers.mean(loss)
......@@ -134,7 +140,10 @@ class TestQuantizationTransformPass(unittest.TestCase):
arg_name.endswith('.quantized.dequantized'))
self.assertTrue(arg_name in quantized_ops)
def linear_fc_quant(self, activation_quant_type, for_ci=True):
def linear_fc_quant(self,
activation_quant_type,
weight_quantize_type,
for_ci=True):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
......@@ -146,7 +155,8 @@ class TestQuantizationTransformPass(unittest.TestCase):
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
place=place,
activation_quantize_type=activation_quant_type)
activation_quantize_type=activation_quant_type,
weight_quantize_type=weight_quantize_type)
transform_pass.apply(graph)
if not for_ci:
marked_nodes = set()
......@@ -167,15 +177,19 @@ class TestQuantizationTransformPass(unittest.TestCase):
val_marked_nodes)
def test_linear_fc_quant_abs_max(self):
self.linear_fc_quant('abs_max', for_ci=True)
self.linear_fc_quant('abs_max', 'abs_max', for_ci=True)
def test_linear_fc_quant_range_abs_max(self):
self.linear_fc_quant('range_abs_max', for_ci=True)
self.linear_fc_quant('range_abs_max', 'abs_max', for_ci=True)
def test_linear_fc_quant_moving_average_abs_max(self):
self.linear_fc_quant('moving_average_abs_max', for_ci=True)
self.linear_fc_quant(
'moving_average_abs_max', 'channel_wise_abs_max', for_ci=True)
def residual_block_quant(self, activation_quant_type, for_ci=True):
def residual_block_quant(self,
activation_quant_type,
weight_quantize_type,
for_ci=True):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
......@@ -187,7 +201,8 @@ class TestQuantizationTransformPass(unittest.TestCase):
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
place=place,
activation_quantize_type=activation_quant_type)
activation_quantize_type=activation_quant_type,
weight_quantize_type=weight_quantize_type)
transform_pass.apply(graph)
if not for_ci:
marked_nodes = set()
......@@ -208,13 +223,14 @@ class TestQuantizationTransformPass(unittest.TestCase):
val_marked_nodes)
def test_residual_block_abs_max(self):
self.residual_block_quant('abs_max', for_ci=True)
self.residual_block_quant('abs_max', 'abs_max', for_ci=True)
def test_residual_block_range_abs_max(self):
self.residual_block_quant('range_abs_max', for_ci=True)
self.residual_block_quant('range_abs_max', 'abs_max', for_ci=True)
def test_residual_block_moving_average_abs_max(self):
self.residual_block_quant('moving_average_abs_max', for_ci=True)
self.residual_block_quant(
'moving_average_abs_max', 'channel_wise_abs_max', for_ci=True)
class TestQuantizationFreezePass(unittest.TestCase):
......@@ -494,11 +510,14 @@ class TestAddQuantDequantPass(unittest.TestCase):
self._target_ops = {'elementwise_add', 'pool2d'}
self._target_grad_ops = {'elementwise_add_grad', 'pool2d_grad'}
def check_graph(self, graph):
def check_graph(self, graph, skip_pattern=None):
ops = graph.all_op_nodes()
for op_node in ops:
if op_node.name() in self._target_ops:
if skip_pattern and op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(skip_pattern) != -1:
continue
in_nodes_all_not_persistable = True
for input_name in op_node.input_arg_names():
in_node = graph._find_node_by_name(op_node.inputs,
......@@ -508,20 +527,15 @@ class TestAddQuantDequantPass(unittest.TestCase):
not in_node.persistable())
if not in_nodes_all_not_persistable:
continue
if op_node.op().has_attr("pooling_type") and \
op_node.op().attr("pooling_type") == 'max':
continue
input_names = op_node.input_arg_names()
for input_name in input_names:
self.assertTrue(input_name.endswith('.quant_dequant'))
def residual_block_quant(self, for_ci=True):
def residual_block_quant(self, skip_pattern=None, for_ci=True):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = residual_block(1)
loss = residual_block(2, skip_pattern)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
place = fluid.CPUPlace()
......@@ -535,7 +549,7 @@ class TestAddQuantDequantPass(unittest.TestCase):
if op.name().find('quant') > -1:
marked_nodes.add(op)
graph.draw('.', 'add_quant_dequant_graph', marked_nodes)
self.check_graph(graph)
self.check_graph(graph, skip_pattern)
program = graph.to_program()
val_graph = IrGraph(core.Graph(program.desc), for_test=False)
if not for_ci:
......@@ -546,7 +560,10 @@ class TestAddQuantDequantPass(unittest.TestCase):
val_graph.draw('.', 'val_add_quant_dequant_graph', val_marked_nodes)
def test_residual_block(self):
self.residual_block_quant(for_ci=True)
self.residual_block_quant(skip_pattern=None, for_ci=True)
def test_residual_block_skip_pattern(self):
self.residual_block_quant(skip_pattern='skip_quant', for_ci=True)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册