未验证 提交 8f7372ca 编写于 作者: J juncaipeng 提交者: GitHub

add mul and matmul quantization, test=develop (#22054)

* add mul and matmul quantization, test=develop
* add test for matmul, test=develop
上级 73733498
...@@ -244,7 +244,9 @@ class PostTrainingQuantization(object): ...@@ -244,7 +244,9 @@ class PostTrainingQuantization(object):
drop_last=True, drop_last=True,
places=self._place) places=self._place)
# collect the variable names for sampling # collect the variable names for sampling.
# TODO(juncaipeng): take the name_scope of skip_quant into account and
# reduce the number of variables collected for sampling.
persistable_var_names = [] persistable_var_names = []
for var in self._program.list_vars(): for var in self._program.list_vars():
if var.persistable: if var.persistable:
...@@ -257,15 +259,17 @@ class PostTrainingQuantization(object): ...@@ -257,15 +259,17 @@ class PostTrainingQuantization(object):
self._quantized_act_var_name.add(op.input("Input")[0]) self._quantized_act_var_name.add(op.input("Input")[0])
self._quantized_weight_var_name.add(op.input("Filter")[0]) self._quantized_weight_var_name.add(op.input("Filter")[0])
self._quantized_act_var_name.add(op.output("Output")[0]) self._quantized_act_var_name.add(op.output("Output")[0])
elif op_type == "mul": elif op_type in ["mul", "matmul"]:
if self._is_input_all_not_persistable( x_var_name = op.input("X")[0]
op, persistable_var_names): if x_var_name in persistable_var_names:
op._set_attr("skip_quant", True) self._quantized_weight_var_name.add(x_var_name)
_logger.warning("Skip quant a mul op for two " else:
"input variables are not persistable") self._quantized_act_var_name.add(x_var_name)
y_var_name = op.input("Y")[0]
if y_var_name in persistable_var_names:
self._quantized_weight_var_name.add(y_var_name)
else: else:
self._quantized_act_var_name.add(op.input("X")[0]) self._quantized_act_var_name.add(y_var_name)
self._quantized_weight_var_name.add(op.input("Y")[0])
self._quantized_act_var_name.add(op.output("Out")[0]) self._quantized_act_var_name.add(op.output("Out")[0])
else: else:
# process other quantizable op type, the input must all not persistable # process other quantizable op type, the input must all not persistable
......
...@@ -46,6 +46,7 @@ _op_real_in_out_name = { ...@@ -46,6 +46,7 @@ _op_real_in_out_name = {
"conv2d": [["Input", "Filter"], ["Output"]], "conv2d": [["Input", "Filter"], ["Output"]],
"depthwise_conv2d": [["Input"], ["Output"]], "depthwise_conv2d": [["Input"], ["Output"]],
"mul": [["X", "Y"], ["Out"]], "mul": [["X", "Y"], ["Out"]],
"matmul": [["X", "Y"], ["Out"]],
"pool2d": [["X"], ["Out"]], "pool2d": [["X"], ["Out"]],
"elementwise_add": [["X", "Y"], ["Out"]], "elementwise_add": [["X", "Y"], ["Out"]],
"concat": [["X"], ["Out"]], "concat": [["X"], ["Out"]],
...@@ -87,8 +88,25 @@ def _init_var_node(var_node, value, scope, place): ...@@ -87,8 +88,25 @@ def _init_var_node(var_node, value, scope, place):
tensor.set(value, place) tensor.set(value, place)
def _is_input_all_not_persistable(graph, op_node):
    '''
    Report whether none of the real inputs of the op node is persistable.

    The real input slots of the op are looked up in ``_op_real_in_out_name``
    by the op's name, and every argument of those slots is resolved to its
    node among ``op_node.inputs``.

    Returns True when no resolved input node is persistable (including the
    case where the slots hold no arguments), otherwise False.
    '''
    real_input_slots = _op_real_in_out_name[op_node.name()][0]
    none_persistable = True
    for slot in real_input_slots:
        for arg_name in op_node.input(slot):
            # Every argument is resolved, even after a persistable input
            # has already been seen.
            var_node = graph._find_node_by_name(op_node.inputs, arg_name)
            if none_persistable and var_node.persistable():
                none_persistable = False
    return none_persistable
class QuantizationTransformPass(object): class QuantizationTransformPass(object):
_supported_quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul'] _supported_quantizable_op_type = [
'conv2d', 'depthwise_conv2d', 'mul', 'matmul'
]
def __init__(self, def __init__(self,
scope=None, scope=None,
...@@ -252,17 +270,12 @@ class QuantizationTransformPass(object): ...@@ -252,17 +270,12 @@ class QuantizationTransformPass(object):
graph.update_input_link(var_node, dequant_var_node, op) graph.update_input_link(var_node, dequant_var_node, op)
def _transform_backward(graph, op): def _transform_backward(graph, op):
no_dequanted_input_vars = True
for var_node in op.inputs: for var_node in op.inputs:
if var_node.name() not in op.input_arg_names(): if var_node.name() not in op.input_arg_names():
continue continue
if var_node.name() in dequantized_vars: if var_node.name() in dequantized_vars:
dequant_var_node = dequantized_vars[var_node.name()] dequant_var_node = dequantized_vars[var_node.name()]
graph.update_input_link(var_node, dequant_var_node, op) graph.update_input_link(var_node, dequant_var_node, op)
no_dequanted_input_vars = False
if no_dequanted_input_vars:
raise ValueError("There is no dequanted inputs for op %s." %
(op.name()))
if not self._is_test: if not self._is_test:
self._create_global_step(graph) self._create_global_step(graph)
...@@ -277,18 +290,11 @@ class QuantizationTransformPass(object): ...@@ -277,18 +290,11 @@ class QuantizationTransformPass(object):
# The loop for transforming the forward graph: # The loop for transforming the forward graph:
for op in ops: for op in ops:
if op.name() in self._quantizable_ops: if op.name() in self._quantizable_ops:
skipped = op.op().has_attr("skip_quant") and \ if not QuantizationTransformPass._is_skip_quant(graph, op):
op.op().attr("skip_quant")
if skipped:
continue
_transform_forward(graph, op) _transform_forward(graph, op)
# The loop for renaming the inputs of backward op. # The loop for renaming the inputs of backward op.
for op in ops: for op in ops:
if op.name() in self._quantizable_grad_ops: if op.name() in self._quantizable_grad_ops:
skipped = op.op().has_attr("skip_quant") and \
op.op().attr("skip_quant")
if skipped:
continue
_transform_backward(graph, op) _transform_backward(graph, op)
graph.resolve_hazard() graph.resolve_hazard()
return graph return graph
...@@ -630,6 +636,22 @@ class QuantizationTransformPass(object): ...@@ -630,6 +636,22 @@ class QuantizationTransformPass(object):
""" """
return "%s.scale" % (var_name) return "%s.scale" % (var_name)
@staticmethod
def _is_skip_quant(graph, op_node):
    """
    Decide whether the op node must be left out of quantization.

    An op is skipped when either of the following holds:
      * it carries a truthy ``skip_quant`` attribute;
      * it is a ``mul`` or ``matmul`` op and none of its real inputs is
        persistable -- such ops are meant to be quantized by
        AddQuantDequantPass instead.

    Returns True if the op should be skipped, otherwise False.
    """
    marked_by_attr = op_node.op().has_attr("skip_quant") and \
        op_node.op().attr("skip_quant")
    # Evaluated unconditionally, independent of the attribute check above.
    all_inputs_not_persistable = op_node.name() in ("mul", "matmul") and \
        _is_input_all_not_persistable(graph, op_node)
    return bool(marked_by_attr or all_inputs_not_persistable)
class QuantizationFreezePass(object): class QuantizationFreezePass(object):
_supported_quantizable_op_type = \ _supported_quantizable_op_type = \
...@@ -733,10 +755,13 @@ class QuantizationFreezePass(object): ...@@ -733,10 +755,13 @@ class QuantizationFreezePass(object):
for op_node in ops: for op_node in ops:
op_name = op_node.name() op_name = op_node.name()
if op_name in self._quantizable_ops: if op_name in self._quantizable_ops:
skipped = op_node.op().has_attr("skip_quant") and \ # only process the node that is quantized by QuantizationTransformPass
op_node.op().attr("skip_quant") is_op_node_quantized = False
if skipped: for var_node in op_node.inputs:
continue var_name = var_node.name()
if var_name.endswith('.dequantized'):
is_op_node_quantized = True
if is_op_node_quantized:
if self._weight_quantize_type == 'channel_wise_abs_max' and op_name in self._conv_ops: if self._weight_quantize_type == 'channel_wise_abs_max' and op_name in self._conv_ops:
self._insert_post_channel_dequant_op(graph, op_node) self._insert_post_channel_dequant_op(graph, op_node)
else: else:
...@@ -829,10 +854,6 @@ class QuantizationFreezePass(object): ...@@ -829,10 +854,6 @@ class QuantizationFreezePass(object):
def _insert_post_dequant_op(self, graph, op_node): def _insert_post_dequant_op(self, graph, op_node):
persistable_vars = [p.name() for p in graph.all_persistable_nodes()] persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
if len(op_node.input_arg_names()) >= 2 and len(persistable_vars) == 0:
raise ValueError("The op %s has more than one inputs "
"and all of them are not persistable. "
"Now, it is not supported!" % (op_node.name()))
max_range = 1 max_range = 1
param_range = (1 << (self._weight_bits - 1)) - 1 param_range = (1 << (self._weight_bits - 1)) - 1
act_range = (1 << (self._activation_bits - 1)) - 1 act_range = (1 << (self._activation_bits - 1)) - 1
...@@ -987,9 +1008,7 @@ class ConvertToInt8Pass(object): ...@@ -987,9 +1008,7 @@ class ConvertToInt8Pass(object):
for op_node in ops: for op_node in ops:
op_name = op_node.name() op_name = op_node.name()
if op_name in self._quantizable_ops: if op_name in self._quantizable_ops:
skipped = op_node.op().has_attr("skip_quant") and \ if QuantizationTransformPass._is_skip_quant(graph, op_node):
op_node.op().attr("skip_quant")
if skipped:
continue continue
for var_node in op_node.inputs: for var_node in op_node.inputs:
name = var_node.name() name = var_node.name()
...@@ -1240,7 +1259,7 @@ class AddQuantDequantPass(object): ...@@ -1240,7 +1259,7 @@ class AddQuantDequantPass(object):
"equal", "gather", "greater_equal", "greater_than", "less_equal", "equal", "gather", "greater_equal", "greater_than", "less_equal",
"less_than", "mean", "not_equal", "reshape", "reshape2", "less_than", "mean", "not_equal", "reshape", "reshape2",
"bilinear_interp", "nearest_interp", "trilinear_interp", "slice", "bilinear_interp", "nearest_interp", "trilinear_interp", "slice",
"squeeze", "elementwise_sub" "squeeze", "elementwise_sub", "mul", "matmul"
] ]
_activation_type = ["relu", "relu6", "leaky_relu", "tanh", "swish"] _activation_type = ["relu", "relu6", "leaky_relu", "tanh", "swish"]
...@@ -1317,25 +1336,30 @@ class AddQuantDequantPass(object): ...@@ -1317,25 +1336,30 @@ class AddQuantDequantPass(object):
all_op_nodes = graph.all_op_nodes() all_op_nodes = graph.all_op_nodes()
for op_node in all_op_nodes: for op_node in all_op_nodes:
if op_node.name() in self._quantizable_op_type: if op_node.name() in self._quantizable_op_type:
user_skipped = False is_skip = False
if isinstance(self._skip_pattern, list): if isinstance(self._skip_pattern, list):
user_skipped = op_node.op().has_attr("op_namescope") and \ is_skip = op_node.op().has_attr("op_namescope") and \
any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern) any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern)
elif isinstance(self._skip_pattern, str): elif isinstance(self._skip_pattern, str):
user_skipped = op_node.op().has_attr("op_namescope") and \ is_skip = op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1 op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
if user_skipped: is_op_node_quantized = False
continue for var_node in op_node.inputs:
var_name = var_node.name()
if var_name.endswith('.dequantized'):
is_op_node_quantized = True
if not self._is_input_all_not_persistable(graph, op_node): if is_skip or is_op_node_quantized or \
(not _is_input_all_not_persistable(graph, op_node)):
continue continue
input_name_list = _op_real_in_out_name[op_node.name()][0] input_name_list = _op_real_in_out_name[op_node.name()][0]
arg_names = []
for input_name in input_name_list: for input_name in input_name_list:
for arg_name in op_node.input(input_name): arg_names.extend(op_node.input(input_name))
in_node = graph._find_node_by_name(op_node.inputs, for arg_name in arg_names:
arg_name) in_node = graph._find_node_by_name(op_node.inputs, arg_name)
if arg_name in dequantized_vars_map: if arg_name in dequantized_vars_map:
quant_var_node = dequantized_vars_map[arg_name] quant_var_node = dequantized_vars_map[arg_name]
else: else:
...@@ -1343,8 +1367,7 @@ class AddQuantDequantPass(object): ...@@ -1343,8 +1367,7 @@ class AddQuantDequantPass(object):
self._inser_quant_dequant_moving_average_abs_max_op( self._inser_quant_dequant_moving_average_abs_max_op(
graph, in_node, self._quant_bits) graph, in_node, self._quant_bits)
dequantized_vars_map[arg_name] = quant_var_node dequantized_vars_map[arg_name] = quant_var_node
graph.update_input_link(in_node, quant_var_node, graph.update_input_link(in_node, quant_var_node, op_node)
op_node)
# Backward stage, update input link # Backward stage, update input link
for op_node in all_op_nodes: for op_node in all_op_nodes:
...@@ -1360,21 +1383,6 @@ class AddQuantDequantPass(object): ...@@ -1360,21 +1383,6 @@ class AddQuantDequantPass(object):
graph.resolve_hazard() graph.resolve_hazard()
return graph return graph
def _is_input_all_not_persistable(self, graph, op_node):
    '''
    Tell whether all real inputs of the op node are non-persistable.

    The real input slots come from ``_op_real_in_out_name`` keyed by the
    op's name; each argument of those slots is resolved to its node among
    ``op_node.inputs`` and checked with ``persistable()``.

    Returns True if no real input is persistable, otherwise False.
    '''
    slot_names = _op_real_in_out_name[op_node.name()][0]
    verdict = True
    for slot_name in slot_names:
        for arg_name in op_node.input(slot_name):
            node = graph._find_node_by_name(op_node.inputs, arg_name)
            # Once one persistable input has been found the answer
            # stays False.
            verdict = (not node.persistable()) if verdict else False
    return verdict
def _inser_quant_dequant_moving_average_abs_max_op(self, graph, var_node, def _inser_quant_dequant_moving_average_abs_max_op(self, graph, var_node,
quant_bits): quant_bits):
"""Insert fake_quantize_dequantize_moving_average_abs_max op. """Insert fake_quantize_dequantize_moving_average_abs_max op.
......
...@@ -60,14 +60,21 @@ def residual_block(num, quant_skip_pattern=None): ...@@ -60,14 +60,21 @@ def residual_block(num, quant_skip_pattern=None):
bias_attr=bias_attr) bias_attr=bias_attr)
return fluid.layers.batch_norm(input=tmp, act=act) return fluid.layers.batch_norm(input=tmp, act=act)
data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32') data = fluid.layers.data(
label = fluid.layers.data(name='label', shape=[1], dtype='int64') name='image',
shape=[1, 1, 32, 32],
dtype='float32',
append_batch_size=False)
label = fluid.layers.data(
name='label', shape=[1, 1], dtype='int64', append_batch_size=False)
hidden = data hidden = data
for _ in six.moves.xrange(num): for _ in six.moves.xrange(num):
conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True) conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None) short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu') hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
matmul_weight = fluid.layers.create_parameter(
shape=[1, 16, 32, 32], dtype='float32')
hidden = fluid.layers.matmul(hidden, matmul_weight, True, True)
if quant_skip_pattern: if quant_skip_pattern:
with fluid.name_scope(quant_skip_pattern): with fluid.name_scope(quant_skip_pattern):
pool = fluid.layers.pool2d( pool = fluid.layers.pool2d(
...@@ -189,6 +196,7 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -189,6 +196,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
def residual_block_quant(self, def residual_block_quant(self,
activation_quant_type, activation_quant_type,
weight_quantize_type, weight_quantize_type,
quantizable_op_type,
for_ci=True): for_ci=True):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
...@@ -202,7 +210,8 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -202,7 +210,8 @@ class TestQuantizationTransformPass(unittest.TestCase):
scope=fluid.global_scope(), scope=fluid.global_scope(),
place=place, place=place,
activation_quantize_type=activation_quant_type, activation_quantize_type=activation_quant_type,
weight_quantize_type=weight_quantize_type) weight_quantize_type=weight_quantize_type,
quantizable_op_type=quantizable_op_type)
transform_pass.apply(graph) transform_pass.apply(graph)
if not for_ci: if not for_ci:
marked_nodes = set() marked_nodes = set()
...@@ -223,14 +232,22 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -223,14 +232,22 @@ class TestQuantizationTransformPass(unittest.TestCase):
val_marked_nodes) val_marked_nodes)
def test_residual_block_abs_max(self): def test_residual_block_abs_max(self):
self.residual_block_quant('abs_max', 'abs_max', for_ci=True) quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul', 'matmul']
self.residual_block_quant(
'abs_max', 'abs_max', quantizable_op_type, for_ci=True)
def test_residual_block_range_abs_max(self): def test_residual_block_range_abs_max(self):
self.residual_block_quant('range_abs_max', 'abs_max', for_ci=True) quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul', 'matmul']
self.residual_block_quant(
'range_abs_max', 'abs_max', quantizable_op_type, for_ci=True)
def test_residual_block_moving_average_abs_max(self): def test_residual_block_moving_average_abs_max(self):
quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul', 'matmul']
self.residual_block_quant( self.residual_block_quant(
'moving_average_abs_max', 'channel_wise_abs_max', for_ci=True) 'moving_average_abs_max',
'channel_wise_abs_max',
quantizable_op_type,
for_ci=True)
class TestQuantizationFreezePass(unittest.TestCase): class TestQuantizationFreezePass(unittest.TestCase):
...@@ -523,14 +540,16 @@ def quant_dequant_residual_block(num, quant_skip_pattern=None): ...@@ -523,14 +540,16 @@ def quant_dequant_residual_block(num, quant_skip_pattern=None):
bias_attr=bias_attr) bias_attr=bias_attr)
return fluid.layers.batch_norm(input=tmp, act=act) return fluid.layers.batch_norm(input=tmp, act=act)
data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32') data1 = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
data2 = fluid.layers.data(
name='matmul_input', shape=[16, 32, 32], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64') label = fluid.layers.data(name='label', shape=[1], dtype='int64')
hidden = data hidden = data1
for _ in six.moves.xrange(num): for _ in six.moves.xrange(num):
conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True) conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None) short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu') hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
hidden = fluid.layers.matmul(hidden, data2, True, True)
if isinstance(quant_skip_pattern, str): if isinstance(quant_skip_pattern, str):
with fluid.name_scope(quant_skip_pattern): with fluid.name_scope(quant_skip_pattern):
pool1 = fluid.layers.pool2d( pool1 = fluid.layers.pool2d(
...@@ -596,7 +615,10 @@ class TestAddQuantDequantPass(unittest.TestCase): ...@@ -596,7 +615,10 @@ class TestAddQuantDequantPass(unittest.TestCase):
for input_name in input_names: for input_name in input_names:
self.assertTrue(input_name.endswith('.quant_dequant')) self.assertTrue(input_name.endswith('.quant_dequant'))
def residual_block_quant(self, skip_pattern=None, for_ci=True): def residual_block_quant(self,
quantizable_op_type,
skip_pattern=None,
for_ci=True):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
...@@ -606,7 +628,10 @@ class TestAddQuantDequantPass(unittest.TestCase): ...@@ -606,7 +628,10 @@ class TestAddQuantDequantPass(unittest.TestCase):
place = fluid.CPUPlace() place = fluid.CPUPlace()
graph = IrGraph(core.Graph(main.desc), for_test=False) graph = IrGraph(core.Graph(main.desc), for_test=False)
add_quant_dequant_pass = AddQuantDequantPass( add_quant_dequant_pass = AddQuantDequantPass(
scope=fluid.global_scope(), place=place, skip_pattern=skip_pattern) scope=fluid.global_scope(),
place=place,
skip_pattern=skip_pattern,
quantizable_op_type=quantizable_op_type)
add_quant_dequant_pass.apply(graph) add_quant_dequant_pass.apply(graph)
if not for_ci: if not for_ci:
marked_nodes = set() marked_nodes = set()
...@@ -625,14 +650,21 @@ class TestAddQuantDequantPass(unittest.TestCase): ...@@ -625,14 +650,21 @@ class TestAddQuantDequantPass(unittest.TestCase):
val_graph.draw('.', 'val_add_quant_dequant_graph', val_marked_nodes) val_graph.draw('.', 'val_add_quant_dequant_graph', val_marked_nodes)
def test_residual_block(self): def test_residual_block(self):
self.residual_block_quant(skip_pattern=None, for_ci=True) quantizable_op_type = ['elementwise_add', 'pool2d', 'mul', 'matmul']
self.residual_block_quant(
quantizable_op_type, skip_pattern=None, for_ci=True)
def test_residual_block_skip_pattern(self): def test_residual_block_skip_pattern(self):
self.residual_block_quant(skip_pattern='skip_quant', for_ci=True) quantizable_op_type = ['elementwise_add', 'pool2d', 'mul', 'matmul']
self.residual_block_quant(
quantizable_op_type, skip_pattern='skip_quant', for_ci=True)
def test_residual_block_skip_pattern_list(self):
    """Run the quant-dequant residual block test with a list of skip patterns.

    This method used to be named ``test_residual_block_skip_pattern``, the
    same name as the string-pattern test defined just before it in the
    class. The later definition rebinds the name in the class body, so the
    string-pattern test was silently dropped. A distinct name lets both
    tests be discovered and run.
    """
    quantizable_op_type = ['elementwise_add', 'pool2d', 'mul', 'matmul']
    self.residual_block_quant(
        quantizable_op_type,
        skip_pattern=['skip_quant1', 'skip_quant2'],
        for_ci=True)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册