Unverified commit 3398f996, authored by Zhen Wang, committed by GitHub

Adding AddQuantDequantPass for TensorRT int8 (#17529)

* add quant_dequant_pass, test=develop

* Add quant_dequant before some ops, such as the elementwise_add op. This is required by TensorRT. test=develop
Parent f9796b12
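
For orientation before the diff: a sketch of the graph rewrite this pass performs. The op type and the ".quant_dequant" naming come from the code below; the inputs X and Y are hypothetical:

    # Hypothetical before/after for an elementwise_add with inputs X and Y.
    #
    #   before:  X ----------------------------------> elementwise_add
    #            Y ----------------------------------> elementwise_add
    #
    #   after:   X -> fake_quantize_dequantize_moving_average_abs_max -> X.quant_dequant -> elementwise_add
    #            Y -> fake_quantize_dequantize_moving_average_abs_max -> Y.quant_dequant -> elementwise_add
    #
    # Each inserted op records an input scale in "<name>.quant_dequant.scale";
    # TensorRT needs these scales to execute the op in int8.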
@@ -22,7 +22,8 @@ from .... import unique_name
 __all__ = [
     'QuantizationTransformPass', 'QuantizationFreezePass', 'ConvertToInt8Pass',
-    'TransformForMobilePass', 'ScaleForTrainingPass', 'ScaleForInferencePass'
+    'TransformForMobilePass', 'ScaleForTrainingPass', 'ScaleForInferencePass',
+    'AddQuantDequantPass'
 ]
@@ -994,6 +995,8 @@ class ScaleForTrainingPass(object):
         Args:
             graph(IrGraph): the target graph.
         """
+        assert isinstance(graph,
+                          IrGraph), 'graph must be an instance of IrGraph.'
         self._is_test = graph.is_test()
         ops = graph.all_op_nodes()
         for op_node in ops:
@@ -1099,6 +1102,8 @@ class ScaleForInferencePass(object):
         Args:
             graph(IrGraph): the target graph.
         """
+        assert isinstance(graph,
+                          IrGraph), 'graph must be an instance of IrGraph.'
         ops = graph.all_op_nodes()
         for op_node in ops:
             name = op_node.name()
@@ -1117,3 +1122,137 @@ class ScaleForInferencePass(object):
         Return the scale name for the var named `var_name`.
         """
         return "%s@scale" % (var_name)
+
+
+class AddQuantDequantPass(object):
+    def __init__(self, scope=None, place=None, moving_rate=0.9, quant_bits=8):
+        """
+        This pass adds a quant_dequant op before the inputs of some ops, such
+        as the `elementwise_add` op.
+        """
+        self._scope = scope
+        self._place = place
+        self._moving_rate = moving_rate
+        self._quant_bits = quant_bits
+        self._is_test = None
+        # Only the inputs of these ops are wrapped with quant_dequant.
+        self._target_ops = ["elementwise_add", "pool2d"]
+
+    def apply(self, graph):
+        """
+        Add quant_dequant before some ops, such as the `elementwise_add` op.
+        This is required by TensorRT.
+        Args:
+            graph(IrGraph): the target graph.
+        """
+        assert isinstance(graph,
+                          IrGraph), 'graph must be an instance of IrGraph.'
+        self._is_test = graph.is_test()
+        ops = graph.all_op_nodes()
+        for op_node in ops:
+            name = op_node.name()
+            if name in self._target_ops:
+                # Skip the op if any of its inputs is persistable.
+                in_nodes_all_not_persistable = True
+                for input_name in op_node.input_arg_names():
+                    in_node = graph._find_node_by_name(op_node.inputs,
+                                                       input_name)
+                    in_nodes_all_not_persistable = (
+                        in_nodes_all_not_persistable and
+                        not in_node.persistable())
+                if not in_nodes_all_not_persistable:
+                    continue
+                # Route every input through a quant_dequant op.
+                input_names = op_node.input_arg_names()
+                for input_name in input_names:
+                    in_node = graph._find_node_by_name(op_node.inputs,
+                                                       input_name)
+                    quant_var_node, scale_var_node = self._insert_quant_dequant_moving_average_abs_max_op(
+                        graph, in_node, self._quant_bits)
+                    graph.update_input_link(in_node, quant_var_node, op_node)
+        graph.resolve_hazard()
+        return graph
+
+    def _insert_quant_dequant_moving_average_abs_max_op(self, graph, var_node,
+                                                        quant_bits):
+        """Insert a fake_quantize_dequantize_moving_average_abs_max op.
+        """
+        quant_var_node = graph.create_var_node(
+            name="{}.quant_dequant".format(var_node.name()),
+            var_type=var_node.type(),
+            shape=var_node.shape(),
+            var_dtype=var_node.dtype())
+        scale_in_node = graph.create_persistable_node(
+            name="{}.quant_dequant.scale".format(var_node.name()),
+            var_type=core.VarDesc.VarType.LOD_TENSOR,
+            shape=[1],
+            var_dtype=var_node.dtype())
+        data_type = ('float64' if var_node.dtype() == core.VarDesc.VarType.FP64
+                     else 'float32')
+        # Start from a small positive scale; training updates it afterwards.
+        _init_var_node(scale_in_node,
+                       np.array([0.001], dtype=data_type), self._scope,
+                       self._place)
+        scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
+        ins = {'X': var_node, 'InScale': scale_in_node}
+        outs = {'Out': quant_var_node, 'OutScale': scale_out_node}
+        if not self._is_test:
+            # Training mode also tracks the moving-average state and
+            # accumulator across steps.
+            state_in_node = graph.create_persistable_node(
+                name=unique_name.generate('quant_dequant.state'),
+                var_type=core.VarDesc.VarType.LOD_TENSOR,
+                var_dtype=var_node.dtype(),
+                shape=[1])
+            _init_var_node(state_in_node,
+                           np.ones([1], dtype=data_type), self._scope,
+                           self._place)
+            accum_in_node = graph.create_persistable_node(
+                name=unique_name.generate('quant_dequant.accum'),
+                var_type=core.VarDesc.VarType.LOD_TENSOR,
+                var_dtype=var_node.dtype(),
+                shape=[1])
+            _init_var_node(accum_in_node,
+                           np.ones([1], dtype=data_type), self._scope,
+                           self._place)
+            state_out_node = graph.create_var_node_from_desc(
+                state_in_node.var())
+            accum_out_node = graph.create_var_node_from_desc(
+                accum_in_node.var())
+
+            ins['InState'] = state_in_node
+            ins['InAccum'] = accum_in_node
+            outs['OutState'] = state_out_node
+            outs['OutAccum'] = accum_out_node
+
+        attrs = {
+            'bit_length': quant_bits,
+            'moving_rate': self._moving_rate,
+            'is_test': self._is_test,
+            'op_role': core.op_proto_and_checker_maker.OpRole.Forward
+        }
+        quant_op_node = graph.create_op_node(
+            op_type='fake_quantize_dequantize_moving_average_abs_max',
+            attrs=attrs,
+            inputs=ins,
+            outputs=outs)
+
+        graph.link_to(var_node, quant_op_node)
+        graph.link_to(scale_in_node, quant_op_node)
+        graph.link_to(quant_op_node, quant_var_node)
+        graph.link_to(quant_op_node, scale_out_node)
+
+        if not self._is_test:
+            graph.link_to(state_in_node, quant_op_node)
+            graph.link_to(accum_in_node, quant_op_node)
+            graph.link_to(quant_op_node, state_out_node)
+            graph.link_to(quant_op_node, accum_out_node)
+
+        return quant_var_node, scale_out_node
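
For intuition, a small numpy sketch of what the inserted op presumably computes during training. This is an assumption inferred from the op name and the InScale/InState/InAccum bookkeeping above, not code from this commit:

    import numpy as np

    def fake_quant_dequant_step(x, state, accum, moving_rate=0.9, quant_bits=8):
        # Moving-average estimate of max(|x|); assumed semantics of
        # OutState/OutAccum/OutScale.
        state = moving_rate * state + 1.0
        accum = moving_rate * accum + np.abs(x).max()
        scale = accum / state
        # Fake quantization: clip, round onto the int grid, map back to float.
        bnt = (1 << (quant_bits - 1)) - 1  # 127 for 8 bits
        out = np.round(np.clip(x, -scale, scale) / scale * bnt) * scale / bnt
        return out, state, accum, scale

At inference time (is_test == True) the InState/InAccum inputs are omitted and the stored scale is used directly.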
@@ -24,6 +24,7 @@ from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
 from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
 from paddle.fluid.contrib.slim.quantization import ScaleForTrainingPass
 from paddle.fluid.contrib.slim.quantization import ScaleForInferencePass
+from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
 from paddle.fluid import core
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
@@ -98,6 +99,7 @@ class TestQuantizationScalePass(unittest.TestCase):
         scope = fluid.Scope()
         with fluid.scope_guard(scope):
             exe.run(startup)
+
         transform_pass = QuantizationTransformPass(
             scope=scope,
             place=place,
@@ -105,8 +107,14 @@ class TestQuantizationScalePass(unittest.TestCase):
             weight_quantize_type=weight_quant_type)
         transform_pass.apply(main_graph)
         transform_pass.apply(test_graph)
+
+        add_quant_dequant_pass = AddQuantDequantPass(scope=scope, place=place)
+        add_quant_dequant_pass.apply(main_graph)
+        add_quant_dequant_pass.apply(test_graph)
+
         scale_training_pass = ScaleForTrainingPass(scope=scope, place=place)
         scale_training_pass.apply(main_graph)
+
         dev_name = '_gpu' if use_cuda else '_cpu'
         if not for_ci:
             marked_nodes = set()
......
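
For completeness, a condensed usage sketch of the new pass on its own, assuming the paddle.fluid 1.x API used throughout this diff and a program that already contains the model with an initialized scope:

    import paddle.fluid as fluid
    from paddle.fluid import core
    from paddle.fluid.framework import IrGraph
    from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass

    place = fluid.CPUPlace()
    scope = fluid.global_scope()
    main_prog = fluid.default_main_program()  # assumed to hold the model
    graph = IrGraph(core.Graph(main_prog.desc), for_test=False)

    add_quant_dequant_pass = AddQuantDequantPass(scope=scope, place=place)
    graph = add_quant_dequant_pass.apply(graph)
    # quant_dequant ops now precede every elementwise_add/pool2d whose
    # inputs are all non-persistable.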