Unverified commit b0ceed6f, authored by J. juncaipeng, committed via GitHub

add fake_quant_dequant_op for average pool2d, test=develop (#19880)

* add fake_quant_dequant_op for average pool2d
* add test
Parent commit: cb8f3c03
......@@ -90,6 +90,9 @@ class QuantizationTransformPass(object):
usually is not used for weight, since weights are fixed once the
model is well trained.
window_size (int): the window size for 'range_abs_max' quantization.
skip_pattern(str): The user-defined quantization skip pattern, which
will be presented in the name scope of an op. When the skip pattern is
detected in an op's name scope, the corresponding op will not be quantized.
Examples:
.. code-block:: python
......@@ -1163,29 +1166,31 @@ class AddQuantDequantPass(object):
def __init__(self, scope=None, place=None, moving_rate=0.9, quant_bits=8):
    """
    Pass that inserts quant-dequant ops in front of selected ops
    ('elementwise_add' and average 'pool2d') so the resulting graph can
    be consumed by TensorRT-style int8 inference.

    Args:
        scope: scope used to hold the moving-average scale variables.
        place: device place used to initialize newly created variables.
        moving_rate (float): decay rate of the moving-average abs-max scale.
        quant_bits (int): bit width used for the fake quantization.
    """
    self._scope = scope
    self._place = place
    self._moving_rate = moving_rate
    self._quant_bits = quant_bits
    # Set by apply() from the target graph; None until the pass runs.
    self._is_test = None
    # Forward ops that get a quant-dequant op before each input.
    self._target_ops = ["elementwise_add", "pool2d"]
    # Corresponding backward ops; apply() re-links their inputs to the
    # dequantized variables created for the forward ops.
    self._target_grad_ops = [op + '_grad' for op in self._target_ops]
def apply(self, graph):
    """
    Insert a quant-dequant (moving-average abs-max) op in front of every
    non-persistable input of each target op ('elementwise_add' and average
    'pool2d'); required by TensorRT int8 inference.

    Args:
        graph(IrGraph): the target graph, modified in place.

    Returns:
        IrGraph: the same graph object.
    """
    assert isinstance(graph,
                      IrGraph), 'graph must be the instance of IrGraph.'
    self._is_test = graph.is_test()
    # Maps original input var name -> the quant-dequant output var node,
    # reused below to re-link the matching backward ops.
    dequantized_vars_map = collections.OrderedDict()

    ops = graph.all_op_nodes()
    for op_node in ops:
        # NOTE(review): this span is a diff mash — the next two lines are
        # the pre-change form of the condition that follows them, and a
        # hunk of the original file is elided at the '@@' marker below.
        name = op_node.name()
        if name in self._target_ops:
        if op_node.name() in self._target_ops:
            # Only quantize when every input is non-persistable
            # (i.e. skip ops that read trained weights/parameters).
            in_nodes_all_not_persistable = True
            for input_name in op_node.input_arg_names():
                in_node = graph._find_node_by_name(op_node.inputs,
......@@ -1195,13 +1200,31 @@ class AddQuantDequantPass(object):
                    not in_node.persistable())
            if not in_nodes_all_not_persistable:
                continue
            # Only average pooling is quantized; max pooling is skipped.
            if op_node.op().has_attr("pooling_type") and \
                op_node.op().attr("pooling_type") == 'max':
                continue

            input_names = op_node.input_arg_names()
            for input_name in input_names:
                in_node = graph._find_node_by_name(op_node.inputs,
                                                   input_name)
                # NOTE(review): old and new forms of this call are both
                # present below (diff mash). 'inser' looks like a typo for
                # 'insert' in the helper's name, but the helper is defined
                # outside this view, so it cannot be renamed here.
                quant_var_node, scale_var_node = self._inser_quant_dequant_moving_average_abs_max_op(
                quant_var_node, scale_var_node = \
                    self._inser_quant_dequant_moving_average_abs_max_op(
                    graph, in_node, self._quant_bits)
                dequantized_vars_map[input_name] = quant_var_node
                graph.update_input_link(in_node, quant_var_node, op_node)

    # Re-link the backward ops so they read the same quant-dequant
    # variables as their forward counterparts.
    for op_node in ops:
        if op_node.name() in self._target_grad_ops:
            for input_name in op_node.input_arg_names():
                if input_name in dequantized_vars_map:
                    in_node = graph._find_node_by_name(op_node.inputs,
                                                       input_name)
                    dequant_var_node = dequantized_vars_map[input_name]
                    graph.update_input_link(in_node, dequant_var_node,
                                            op_node)

    graph.resolve_hazard()
    return graph
......
......@@ -24,6 +24,7 @@ from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
from paddle.fluid import core
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
......@@ -66,7 +67,9 @@ def residual_block(num):
conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
fc = fluid.layers.fc(input=hidden, size=10)
pool = fluid.layers.pool2d(
input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
fc = fluid.layers.fc(input=pool, size=10)
loss = fluid.layers.cross_entropy(input=fc, label=label)
loss = fluid.layers.mean(loss)
return loss
......@@ -486,5 +489,65 @@ class TestQuantizationFreezePass(unittest.TestCase):
for_ci=True)
class TestAddQuantDequantPass(unittest.TestCase):
    """Checks that AddQuantDequantPass wraps eligible elementwise_add and
    average-pool2d inputs with quant-dequant variables on a residual net."""

    def setUp(self):
        # Forward ops the pass should quantize, plus their backward ops.
        self._target_ops = {'elementwise_add', 'pool2d'}
        self._target_grad_ops = {'elementwise_add_grad', 'pool2d_grad'}

    def check_graph(self, graph):
        """Assert every eligible target op reads '.quant_dequant' inputs."""
        for node in graph.all_op_nodes():
            if node.name() not in self._target_ops:
                continue
            # Mirror the pass: ops with any persistable input are skipped.
            has_persistable_input = False
            for arg_name in node.input_arg_names():
                var_node = graph._find_node_by_name(node.inputs, arg_name)
                if var_node.persistable():
                    has_persistable_input = True
            if has_persistable_input:
                continue
            # Mirror the pass: max pooling is left unquantized.
            if node.op().has_attr("pooling_type") and \
                    node.op().attr("pooling_type") == 'max':
                continue
            for arg_name in node.input_arg_names():
                self.assertTrue(arg_name.endswith('.quant_dequant'))

    def residual_block_quant(self, for_ci=True):
        main_program = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(main_program, startup_program):
            loss = residual_block(1)
            optimizer = fluid.optimizer.Adam(learning_rate=0.001)
            optimizer.minimize(loss)
        place = fluid.CPUPlace()
        graph = IrGraph(core.Graph(main_program.desc), for_test=False)
        quant_pass = AddQuantDequantPass(
            scope=fluid.global_scope(), place=place)
        quant_pass.apply(graph)
        if not for_ci:
            # Visualize the quantized graph for local debugging only.
            marked = {
                op
                for op in graph.all_op_nodes() if op.name().find('quant') > -1
            }
            graph.draw('.', 'add_quant_dequant_graph', marked)
        self.check_graph(graph)
        # Round-trip through a Program to make sure the graph converts back.
        rebuilt = graph.to_program()
        val_graph = IrGraph(core.Graph(rebuilt.desc), for_test=False)
        if not for_ci:
            val_marked = {
                op
                for op in val_graph.all_op_nodes()
                if op.name().find('quant') > -1
            }
            val_graph.draw('.', 'val_add_quant_dequant_graph', val_marked)

    def test_residual_block(self):
        self.residual_block_quant(for_ci=True)
# Allow running this test module directly from the command line.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment.