Unverified · Commit b0ceed6f authored by juncaipeng, committed by GitHub

add fake_quant_dequant_op for average pool2d, test=develop (#19880)

* add fake_quant_dequant_op for average pool2d
* add test
Parent cb8f3c03
@@ -90,6 +90,9 @@ class QuantizationTransformPass(object):
             usually is not used for weight, since weights are fixed once the
             model is well trained.
         window_size (int): the window size for 'range_abs_max' quantization.
+        skip_pattern(str): The user-defined quantization skip pattern, which
+            will be presented in the name scope of an op. When the skip pattern
+            is detected in an op's name scope, that op will not be quantized.
 
     Examples:
     .. code-block:: python
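For context on the `skip_pattern` argument documented above, here is a minimal usage sketch. It assumes the pass is driven directly on an `IrGraph`; the tiny network, the `'skip_quant'` pattern value, and the variable names are illustrative, not taken from this commit:

```python
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass

main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    image = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
    # Ops created under this name scope carry 'skip_quant' in their name
    # scope, so a pass constructed with skip_pattern='skip_quant' should
    # leave them unquantized.
    with fluid.name_scope('skip_quant'):
        fc = fluid.layers.fc(input=image, size=10)

graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass(
    scope=fluid.global_scope(),
    place=fluid.CPUPlace(),
    skip_pattern='skip_quant')  # illustrative pattern value
transform_pass.apply(graph)
```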
@@ -1163,29 +1166,31 @@ class AddQuantDequantPass(object):
     def __init__(self, scope=None, place=None, moving_rate=0.9, quant_bits=8):
         """
         This pass is used to add quant_dequant op for some ops, such as the
-        `elementwise_add` op.
+        'elementwise_add' and average 'pool2d' ops.
         """
         self._scope = scope
         self._place = place
         self._moving_rate = moving_rate
         self._quant_bits = quant_bits
         self._is_test = None
-        self._target_ops = ["elementwise_add"]
+        self._target_ops = ["elementwise_add", "pool2d"]
+        self._target_grad_ops = ['%s_grad' % (op) for op in self._target_ops]
 
     def apply(self, graph):
         """
-        Add quant_dequant before some ops, such as the `elementwise_add` op.
-        This is required by TensorRT.
+        Add quant_dequant before some ops, such as the 'elementwise_add'
+        and average 'pool2d' ops.
 
         Args:
             graph(IrGraph): the target graph.
         """
         assert isinstance(graph,
                           IrGraph), 'graph must be the instance of IrGraph.'
         self._is_test = graph.is_test()
+        dequantized_vars_map = collections.OrderedDict()
         ops = graph.all_op_nodes()
         for op_node in ops:
-            name = op_node.name()
-            if name in self._target_ops:
+            if op_node.name() in self._target_ops:
                 in_nodes_all_not_persistable = True
                 for input_name in op_node.input_arg_names():
                     in_node = graph._find_node_by_name(op_node.inputs,
@@ -1195,13 +1200,31 @@ class AddQuantDequantPass(object):
                         not in_node.persistable())
                 if not in_nodes_all_not_persistable:
                     continue
+                if op_node.op().has_attr("pooling_type") and \
+                        op_node.op().attr("pooling_type") == 'max':
+                    continue
+
                 input_names = op_node.input_arg_names()
                 for input_name in input_names:
                     in_node = graph._find_node_by_name(op_node.inputs,
                                                        input_name)
-                    quant_var_node, scale_var_node = self._inser_quant_dequant_moving_average_abs_max_op(
+                    quant_var_node, scale_var_node = \
+                        self._inser_quant_dequant_moving_average_abs_max_op(
                             graph, in_node, self._quant_bits)
+                    dequantized_vars_map[input_name] = quant_var_node
                     graph.update_input_link(in_node, quant_var_node, op_node)
+
+        for op_node in ops:
+            if op_node.name() in self._target_grad_ops:
+                for input_name in op_node.input_arg_names():
+                    if input_name in dequantized_vars_map:
+                        in_node = graph._find_node_by_name(op_node.inputs,
+                                                           input_name)
+                        dequant_var_node = dequantized_vars_map[input_name]
+                        graph.update_input_link(in_node, dequant_var_node,
+                                                op_node)
+
         graph.resolve_hazard()
         return graph
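For intuition about what this pass inserts: the fake quant_dequant op simulates INT8 quantization in float32, so the inputs of average `pool2d` (and `elementwise_add`) already carry quantization error during training, while the grad loop above rewires the corresponding backward ops to those quantized inputs. A rough NumPy sketch of the quant-dequant arithmetic follows; the moving-average scale tracking is omitted and the helper name is ours, not Paddle's:

```python
import numpy as np

def fake_quant_dequant(x, scale, quant_bits=8):
    # Map x onto the integer grid [-bnt, bnt], round, then map back to
    # float; the rounding error mimics what a real INT8 engine introduces.
    bnt = (1 << (quant_bits - 1)) - 1  # 127 for 8-bit quantization
    quantized = np.round(np.clip(x, -scale, scale) / scale * bnt)
    return quantized / bnt * scale

x = np.array([0.02, -0.51, 0.99])
print(fake_quant_dequant(x, scale=1.0))  # values snapped to the 8-bit grid
```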
@@ -24,6 +24,7 @@
 from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
 from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
 from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
 from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
+from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
 from paddle.fluid import core
 
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
@@ -66,7 +67,9 @@ def residual_block(num):
         conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
         short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
         hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
-    fc = fluid.layers.fc(input=hidden, size=10)
+    pool = fluid.layers.pool2d(
+        input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
+    fc = fluid.layers.fc(input=pool, size=10)
     loss = fluid.layers.cross_entropy(input=fc, label=label)
     loss = fluid.layers.mean(loss)
     return loss
@@ -486,5 +489,65 @@ class TestQuantizationFreezePass(unittest.TestCase):
                 for_ci=True)
 
 
+class TestAddQuantDequantPass(unittest.TestCase):
+    def setUp(self):
+        self._target_ops = {'elementwise_add', 'pool2d'}
+        self._target_grad_ops = {'elementwise_add_grad', 'pool2d_grad'}
+
+    def check_graph(self, graph):
+        ops = graph.all_op_nodes()
+        for op_node in ops:
+            if op_node.name() in self._target_ops:
+                in_nodes_all_not_persistable = True
+                for input_name in op_node.input_arg_names():
+                    in_node = graph._find_node_by_name(op_node.inputs,
+                                                       input_name)
+                    in_nodes_all_not_persistable = (
+                        in_nodes_all_not_persistable and
+                        not in_node.persistable())
+                if not in_nodes_all_not_persistable:
+                    continue
+                if op_node.op().has_attr("pooling_type") and \
+                        op_node.op().attr("pooling_type") == 'max':
+                    continue
+                input_names = op_node.input_arg_names()
+                for input_name in input_names:
+                    self.assertTrue(input_name.endswith('.quant_dequant'))
+
+    def residual_block_quant(self, for_ci=True):
+        main = fluid.Program()
+        startup = fluid.Program()
+        with fluid.program_guard(main, startup):
+            loss = residual_block(1)
+            opt = fluid.optimizer.Adam(learning_rate=0.001)
+            opt.minimize(loss)
+        place = fluid.CPUPlace()
+        graph = IrGraph(core.Graph(main.desc), for_test=False)
+        add_quant_dequant_pass = AddQuantDequantPass(
+            scope=fluid.global_scope(), place=place)
+        add_quant_dequant_pass.apply(graph)
+        if not for_ci:
+            marked_nodes = set()
+            for op in graph.all_op_nodes():
+                if op.name().find('quant') > -1:
+                    marked_nodes.add(op)
+            graph.draw('.', 'add_quant_dequant_graph', marked_nodes)
+        self.check_graph(graph)
+        program = graph.to_program()
+        val_graph = IrGraph(core.Graph(program.desc), for_test=False)
+        if not for_ci:
+            val_marked_nodes = set()
+            for op in val_graph.all_op_nodes():
+                if op.name().find('quant') > -1:
+                    val_marked_nodes.add(op)
+            val_graph.draw('.', 'val_add_quant_dequant_graph',
+                           val_marked_nodes)
+
+    def test_residual_block(self):
+        self.residual_block_quant(for_ci=True)
+
+
 if __name__ == '__main__':
     unittest.main()