未验证 提交 8f7372ca 编写于 作者: J juncaipeng 提交者: GitHub

add mul and matmul quantization, test=develop (#22054)

* add mul and matmul quantization, test=develop
* add test for matmul, test=develop
上级 73733498
......@@ -244,7 +244,9 @@ class PostTrainingQuantization(object):
drop_last=True,
places=self._place)
# collect the variable names for sampling
# collect the variable names for sampling.
# TODO(juncaipeng), consider the name_scope of skip_quant and
# reduce the variables for sampling
persistable_var_names = []
for var in self._program.list_vars():
if var.persistable:
......@@ -257,15 +259,17 @@ class PostTrainingQuantization(object):
self._quantized_act_var_name.add(op.input("Input")[0])
self._quantized_weight_var_name.add(op.input("Filter")[0])
self._quantized_act_var_name.add(op.output("Output")[0])
elif op_type == "mul":
if self._is_input_all_not_persistable(
op, persistable_var_names):
op._set_attr("skip_quant", True)
_logger.warning("Skip quant a mul op for two "
"input variables are not persistable")
elif op_type in ["mul", "matmul"]:
x_var_name = op.input("X")[0]
if x_var_name in persistable_var_names:
self._quantized_weight_var_name.add(x_var_name)
else:
self._quantized_act_var_name.add(x_var_name)
y_var_name = op.input("Y")[0]
if y_var_name in persistable_var_names:
self._quantized_weight_var_name.add(y_var_name)
else:
self._quantized_act_var_name.add(op.input("X")[0])
self._quantized_weight_var_name.add(op.input("Y")[0])
self._quantized_act_var_name.add(y_var_name)
self._quantized_act_var_name.add(op.output("Out")[0])
else:
# process other quantizable op type, the input must all not persistable
......
......@@ -46,6 +46,7 @@ _op_real_in_out_name = {
"conv2d": [["Input", "Filter"], ["Output"]],
"depthwise_conv2d": [["Input"], ["Output"]],
"mul": [["X", "Y"], ["Out"]],
"matmul": [["X", "Y"], ["Out"]],
"pool2d": [["X"], ["Out"]],
"elementwise_add": [["X", "Y"], ["Out"]],
"concat": [["X"], ["Out"]],
......@@ -87,8 +88,25 @@ def _init_var_node(var_node, value, scope, place):
tensor.set(value, place)
def _is_input_all_not_persistable(graph, op_node):
'''
Analyse the real inputs of the op node are all not persistable.
'''
is_input_all_not_persistable = True
op_node_name = op_node.name()
input_name_list = _op_real_in_out_name[op_node_name][0]
for input_name in input_name_list:
for arg_name in op_node.input(input_name):
in_node = graph._find_node_by_name(op_node.inputs, arg_name)
is_input_all_not_persistable = (is_input_all_not_persistable and \
(not in_node.persistable()))
return is_input_all_not_persistable
class QuantizationTransformPass(object):
_supported_quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
_supported_quantizable_op_type = [
'conv2d', 'depthwise_conv2d', 'mul', 'matmul'
]
def __init__(self,
scope=None,
......@@ -252,17 +270,12 @@ class QuantizationTransformPass(object):
graph.update_input_link(var_node, dequant_var_node, op)
def _transform_backward(graph, op):
no_dequanted_input_vars = True
for var_node in op.inputs:
if var_node.name() not in op.input_arg_names():
continue
if var_node.name() in dequantized_vars:
dequant_var_node = dequantized_vars[var_node.name()]
graph.update_input_link(var_node, dequant_var_node, op)
no_dequanted_input_vars = False
if no_dequanted_input_vars:
raise ValueError("There is no dequanted inputs for op %s." %
(op.name()))
if not self._is_test:
self._create_global_step(graph)
......@@ -277,18 +290,11 @@ class QuantizationTransformPass(object):
# The loop for transforming the forward graph:
for op in ops:
if op.name() in self._quantizable_ops:
skipped = op.op().has_attr("skip_quant") and \
op.op().attr("skip_quant")
if skipped:
continue
if not QuantizationTransformPass._is_skip_quant(graph, op):
_transform_forward(graph, op)
# The loop for renaming the inputs of backward op.
for op in ops:
if op.name() in self._quantizable_grad_ops:
skipped = op.op().has_attr("skip_quant") and \
op.op().attr("skip_quant")
if skipped:
continue
_transform_backward(graph, op)
graph.resolve_hazard()
return graph
......@@ -630,6 +636,22 @@ class QuantizationTransformPass(object):
"""
return "%s.scale" % (var_name)
@staticmethod
def _is_skip_quant(graph, op_node):
"""
Analyse whether the op node skips quantization.
"""
is_skip = False
if op_node.op().has_attr("skip_quant") and \
op_node.op().attr("skip_quant"):
is_skip = True
# if the inputs of mul and matmul are not all persistable, use
# AddQuantDequantPass to quantize them.
if op_node.name() in ["mul", "matmul"] and \
_is_input_all_not_persistable(graph, op_node):
is_skip = True
return is_skip
class QuantizationFreezePass(object):
_supported_quantizable_op_type = \
......@@ -733,10 +755,13 @@ class QuantizationFreezePass(object):
for op_node in ops:
op_name = op_node.name()
if op_name in self._quantizable_ops:
skipped = op_node.op().has_attr("skip_quant") and \
op_node.op().attr("skip_quant")
if skipped:
continue
# only process the node that is quantized by QuantizationTransformPass
is_op_node_quantized = False
for var_node in op_node.inputs:
var_name = var_node.name()
if var_name.endswith('.dequantized'):
is_op_node_quantized = True
if is_op_node_quantized:
if self._weight_quantize_type == 'channel_wise_abs_max' and op_name in self._conv_ops:
self._insert_post_channel_dequant_op(graph, op_node)
else:
......@@ -829,10 +854,6 @@ class QuantizationFreezePass(object):
def _insert_post_dequant_op(self, graph, op_node):
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
if len(op_node.input_arg_names()) >= 2 and len(persistable_vars) == 0:
raise ValueError("The op %s has more than one inputs "
"and all of them are not persistable. "
"Now, it is not supported!" % (op_node.name()))
max_range = 1
param_range = (1 << (self._weight_bits - 1)) - 1
act_range = (1 << (self._activation_bits - 1)) - 1
......@@ -987,9 +1008,7 @@ class ConvertToInt8Pass(object):
for op_node in ops:
op_name = op_node.name()
if op_name in self._quantizable_ops:
skipped = op_node.op().has_attr("skip_quant") and \
op_node.op().attr("skip_quant")
if skipped:
if QuantizationTransformPass._is_skip_quant(graph, op_node):
continue
for var_node in op_node.inputs:
name = var_node.name()
......@@ -1240,7 +1259,7 @@ class AddQuantDequantPass(object):
"equal", "gather", "greater_equal", "greater_than", "less_equal",
"less_than", "mean", "not_equal", "reshape", "reshape2",
"bilinear_interp", "nearest_interp", "trilinear_interp", "slice",
"squeeze", "elementwise_sub"
"squeeze", "elementwise_sub", "mul", "matmul"
]
_activation_type = ["relu", "relu6", "leaky_relu", "tanh", "swish"]
......@@ -1317,25 +1336,30 @@ class AddQuantDequantPass(object):
all_op_nodes = graph.all_op_nodes()
for op_node in all_op_nodes:
if op_node.name() in self._quantizable_op_type:
user_skipped = False
is_skip = False
if isinstance(self._skip_pattern, list):
user_skipped = op_node.op().has_attr("op_namescope") and \
is_skip = op_node.op().has_attr("op_namescope") and \
any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern)
elif isinstance(self._skip_pattern, str):
user_skipped = op_node.op().has_attr("op_namescope") and \
is_skip = op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
if user_skipped:
continue
is_op_node_quantized = False
for var_node in op_node.inputs:
var_name = var_node.name()
if var_name.endswith('.dequantized'):
is_op_node_quantized = True
if not self._is_input_all_not_persistable(graph, op_node):
if is_skip or is_op_node_quantized or \
(not _is_input_all_not_persistable(graph, op_node)):
continue
input_name_list = _op_real_in_out_name[op_node.name()][0]
arg_names = []
for input_name in input_name_list:
for arg_name in op_node.input(input_name):
in_node = graph._find_node_by_name(op_node.inputs,
arg_name)
arg_names.extend(op_node.input(input_name))
for arg_name in arg_names:
in_node = graph._find_node_by_name(op_node.inputs, arg_name)
if arg_name in dequantized_vars_map:
quant_var_node = dequantized_vars_map[arg_name]
else:
......@@ -1343,8 +1367,7 @@ class AddQuantDequantPass(object):
self._inser_quant_dequant_moving_average_abs_max_op(
graph, in_node, self._quant_bits)
dequantized_vars_map[arg_name] = quant_var_node
graph.update_input_link(in_node, quant_var_node,
op_node)
graph.update_input_link(in_node, quant_var_node, op_node)
# Backward stage, update input link
for op_node in all_op_nodes:
......@@ -1360,21 +1383,6 @@ class AddQuantDequantPass(object):
graph.resolve_hazard()
return graph
def _is_input_all_not_persistable(self, graph, op_node):
'''
Analyse the real inputs of the op node are all not persistable.
'''
is_input_all_not_persistable = True
op_node_name = op_node.name()
input_name_list = _op_real_in_out_name[op_node_name][0]
for input_name in input_name_list:
for arg_name in op_node.input(input_name):
in_node = graph._find_node_by_name(op_node.inputs, arg_name)
is_input_all_not_persistable = (is_input_all_not_persistable and \
(not in_node.persistable()))
return is_input_all_not_persistable
def _inser_quant_dequant_moving_average_abs_max_op(self, graph, var_node,
quant_bits):
"""Insert fake_quantize_dequantize_moving_average_abs_max op.
......
......@@ -60,14 +60,21 @@ def residual_block(num, quant_skip_pattern=None):
bias_attr=bias_attr)
return fluid.layers.batch_norm(input=tmp, act=act)
data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
data = fluid.layers.data(
name='image',
shape=[1, 1, 32, 32],
dtype='float32',
append_batch_size=False)
label = fluid.layers.data(
name='label', shape=[1, 1], dtype='int64', append_batch_size=False)
hidden = data
for _ in six.moves.xrange(num):
conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
matmul_weight = fluid.layers.create_parameter(
shape=[1, 16, 32, 32], dtype='float32')
hidden = fluid.layers.matmul(hidden, matmul_weight, True, True)
if quant_skip_pattern:
with fluid.name_scope(quant_skip_pattern):
pool = fluid.layers.pool2d(
......@@ -189,6 +196,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
def residual_block_quant(self,
activation_quant_type,
weight_quantize_type,
quantizable_op_type,
for_ci=True):
main = fluid.Program()
startup = fluid.Program()
......@@ -202,7 +210,8 @@ class TestQuantizationTransformPass(unittest.TestCase):
scope=fluid.global_scope(),
place=place,
activation_quantize_type=activation_quant_type,
weight_quantize_type=weight_quantize_type)
weight_quantize_type=weight_quantize_type,
quantizable_op_type=quantizable_op_type)
transform_pass.apply(graph)
if not for_ci:
marked_nodes = set()
......@@ -223,14 +232,22 @@ class TestQuantizationTransformPass(unittest.TestCase):
val_marked_nodes)
def test_residual_block_abs_max(self):
self.residual_block_quant('abs_max', 'abs_max', for_ci=True)
quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul', 'matmul']
self.residual_block_quant(
'abs_max', 'abs_max', quantizable_op_type, for_ci=True)
def test_residual_block_range_abs_max(self):
self.residual_block_quant('range_abs_max', 'abs_max', for_ci=True)
quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul', 'matmul']
self.residual_block_quant(
'range_abs_max', 'abs_max', quantizable_op_type, for_ci=True)
def test_residual_block_moving_average_abs_max(self):
quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul', 'matmul']
self.residual_block_quant(
'moving_average_abs_max', 'channel_wise_abs_max', for_ci=True)
'moving_average_abs_max',
'channel_wise_abs_max',
quantizable_op_type,
for_ci=True)
class TestQuantizationFreezePass(unittest.TestCase):
......@@ -523,14 +540,16 @@ def quant_dequant_residual_block(num, quant_skip_pattern=None):
bias_attr=bias_attr)
return fluid.layers.batch_norm(input=tmp, act=act)
data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
data1 = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
data2 = fluid.layers.data(
name='matmul_input', shape=[16, 32, 32], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
hidden = data
hidden = data1
for _ in six.moves.xrange(num):
conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
hidden = fluid.layers.matmul(hidden, data2, True, True)
if isinstance(quant_skip_pattern, str):
with fluid.name_scope(quant_skip_pattern):
pool1 = fluid.layers.pool2d(
......@@ -596,7 +615,10 @@ class TestAddQuantDequantPass(unittest.TestCase):
for input_name in input_names:
self.assertTrue(input_name.endswith('.quant_dequant'))
def residual_block_quant(self, skip_pattern=None, for_ci=True):
def residual_block_quant(self,
quantizable_op_type,
skip_pattern=None,
for_ci=True):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
......@@ -606,7 +628,10 @@ class TestAddQuantDequantPass(unittest.TestCase):
place = fluid.CPUPlace()
graph = IrGraph(core.Graph(main.desc), for_test=False)
add_quant_dequant_pass = AddQuantDequantPass(
scope=fluid.global_scope(), place=place, skip_pattern=skip_pattern)
scope=fluid.global_scope(),
place=place,
skip_pattern=skip_pattern,
quantizable_op_type=quantizable_op_type)
add_quant_dequant_pass.apply(graph)
if not for_ci:
marked_nodes = set()
......@@ -625,14 +650,21 @@ class TestAddQuantDequantPass(unittest.TestCase):
val_graph.draw('.', 'val_add_quant_dequant_graph', val_marked_nodes)
def test_residual_block(self):
self.residual_block_quant(skip_pattern=None, for_ci=True)
quantizable_op_type = ['elementwise_add', 'pool2d', 'mul', 'matmul']
self.residual_block_quant(
quantizable_op_type, skip_pattern=None, for_ci=True)
def test_residual_block_skip_pattern(self):
self.residual_block_quant(skip_pattern='skip_quant', for_ci=True)
quantizable_op_type = ['elementwise_add', 'pool2d', 'mul', 'matmul']
self.residual_block_quant(
quantizable_op_type, skip_pattern='skip_quant', for_ci=True)
def test_residual_block_skip_pattern(self):
quantizable_op_type = ['elementwise_add', 'pool2d', 'mul', 'matmul']
self.residual_block_quant(
skip_pattern=['skip_quant1', 'skip_quant2'], for_ci=True)
quantizable_op_type,
skip_pattern=['skip_quant1', 'skip_quant2'],
for_ci=True)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册