diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
index cc5e87b22ed7a0967ad4a838170901101ea3fe5a..d0d69ae91a16b670494f1ec5310a029e52d8894b 100644
--- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
+++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
@@ -244,7 +244,9 @@ class PostTrainingQuantization(object):
             drop_last=True,
             places=self._place)

-        # collect the variable names for sampling
+        # collect the variable names for sampling.
+        # TODO(juncaipeng): consider the name_scope of skip_quant and
+        # reduce the variables for sampling
        persistable_var_names = []
        for var in self._program.list_vars():
            if var.persistable:
@@ -257,16 +259,18 @@ class PostTrainingQuantization(object):
                     self._quantized_act_var_name.add(op.input("Input")[0])
                     self._quantized_weight_var_name.add(op.input("Filter")[0])
                     self._quantized_act_var_name.add(op.output("Output")[0])
-                elif op_type == "mul":
-                    if self._is_input_all_not_persistable(
-                            op, persistable_var_names):
-                        op._set_attr("skip_quant", True)
-                        _logger.warning("Skip quant a mul op for two "
-                                        "input variables are not persistable")
+                elif op_type in ["mul", "matmul"]:
+                    x_var_name = op.input("X")[0]
+                    if x_var_name in persistable_var_names:
+                        self._quantized_weight_var_name.add(x_var_name)
+                    else:
+                        self._quantized_act_var_name.add(x_var_name)
+                    y_var_name = op.input("Y")[0]
+                    if y_var_name in persistable_var_names:
+                        self._quantized_weight_var_name.add(y_var_name)
                     else:
-                        self._quantized_act_var_name.add(op.input("X")[0])
-                        self._quantized_weight_var_name.add(op.input("Y")[0])
-                    self._quantized_act_var_name.add(op.output("Out")[0])
+                        self._quantized_act_var_name.add(y_var_name)
+                    self._quantized_act_var_name.add(op.output("Out")[0])
                 else:
                     # process other quantizable op type, the input must all not persistable
                     if self._is_input_all_not_persistable(
diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
index 062f1e6f0eb9cf08791196a7789e93bd24c28ac1..9edf473546f274d5b8862fe148b45b3c6ed71fa1 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
@@ -46,6 +46,7 @@ _op_real_in_out_name = {
     "conv2d": [["Input", "Filter"], ["Output"]],
     "depthwise_conv2d": [["Input"], ["Output"]],
     "mul": [["X", "Y"], ["Out"]],
+    "matmul": [["X", "Y"], ["Out"]],
     "pool2d": [["X"], ["Out"]],
     "elementwise_add": [["X", "Y"], ["Out"]],
     "concat": [["X"], ["Out"]],
@@ -87,8 +88,25 @@ def _init_var_node(var_node, value, scope, place):
     tensor.set(value, place)


+def _is_input_all_not_persistable(graph, op_node):
+    '''
+    Analyse whether the real inputs of the op node are all not persistable.
+    '''
+    is_input_all_not_persistable = True
+    op_node_name = op_node.name()
+    input_name_list = _op_real_in_out_name[op_node_name][0]
+    for input_name in input_name_list:
+        for arg_name in op_node.input(input_name):
+            in_node = graph._find_node_by_name(op_node.inputs, arg_name)
+            is_input_all_not_persistable = (is_input_all_not_persistable and \
+                (not in_node.persistable()))
+    return is_input_all_not_persistable
+
+
 class QuantizationTransformPass(object):
-    _supported_quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
+    _supported_quantizable_op_type = [
+        'conv2d', 'depthwise_conv2d', 'mul', 'matmul'
+    ]

     def __init__(self,
                  scope=None,
@@ -225,7 +243,7 @@ class QuantizationTransformPass(object):
                     dequant_var_node = dequantized_vars[var_node.name()]
                 else:
                     quant_bits = self._weight_bits if var_node.name() in persistable_vars \
-                        else self._activation_bits
+                                 else self._activation_bits
                     quant_type = self._weight_quantize_type if var_node.name() \
                         in persistable_vars else self._activation_quantize_type
                     if quant_type == 'channel_wise_abs_max':
@@ -252,17 +270,12 @@ class QuantizationTransformPass(object):
                     graph.update_input_link(var_node, dequant_var_node, op)

         def _transform_backward(graph, op):
-            no_dequanted_input_vars = True
             for var_node in op.inputs:
                 if var_node.name() not in op.input_arg_names():
                     continue
                 if var_node.name() in dequantized_vars:
                     dequant_var_node = dequantized_vars[var_node.name()]
                     graph.update_input_link(var_node, dequant_var_node, op)
-                    no_dequanted_input_vars = False
-            if no_dequanted_input_vars:
-                raise ValueError("There is no dequanted inputs for op %s." %
-                                 (op.name()))

         if not self._is_test:
             self._create_global_step(graph)
@@ -277,18 +290,11 @@ class QuantizationTransformPass(object):
         # The loop for transforming the forward graph:
         for op in ops:
             if op.name() in self._quantizable_ops:
-                skipped = op.op().has_attr("skip_quant") and \
-                    op.op().attr("skip_quant")
-                if skipped:
-                    continue
-                _transform_forward(graph, op)
+                if not QuantizationTransformPass._is_skip_quant(graph, op):
+                    _transform_forward(graph, op)
         # The loop for renaming the inputs of backward op.
         for op in ops:
             if op.name() in self._quantizable_grad_ops:
-                skipped = op.op().has_attr("skip_quant") and \
-                    op.op().attr("skip_quant")
-                if skipped:
-                    continue
                 _transform_backward(graph, op)
         graph.resolve_hazard()
         return graph
@@ -630,6 +636,22 @@ class QuantizationTransformPass(object):
         """
         return "%s.scale" % (var_name)

+    @staticmethod
+    def _is_skip_quant(graph, op_node):
+        """
+        Analyse whether the op node skips quantization.
+        """
+        is_skip = False
+        if op_node.op().has_attr("skip_quant") and \
+            op_node.op().attr("skip_quant"):
+            is_skip = True
+        # if all inputs of a mul/matmul op are not persistable, use
+        # AddQuantDequantPass to quantize it.
+        if op_node.name() in ["mul", "matmul"] and \
+            _is_input_all_not_persistable(graph, op_node):
+            is_skip = True
+        return is_skip
+

 class QuantizationFreezePass(object):
     _supported_quantizable_op_type = \
@@ -733,14 +755,17 @@ class QuantizationFreezePass(object):
         for op_node in ops:
             op_name = op_node.name()
             if op_name in self._quantizable_ops:
-                skipped = op_node.op().has_attr("skip_quant") and \
-                    op_node.op().attr("skip_quant")
-                if skipped:
-                    continue
-                if self._weight_quantize_type == 'channel_wise_abs_max' and op_name in self._conv_ops:
-                    self._insert_post_channel_dequant_op(graph, op_node)
-                else:
-                    self._insert_post_dequant_op(graph, op_node)
+                # only process the node that is quantized by QuantizationTransformPass
+                is_op_node_quantized = False
+                for var_node in op_node.inputs:
+                    var_name = var_node.name()
+                    if var_name.endswith('.dequantized'):
+                        is_op_node_quantized = True
+                if is_op_node_quantized:
+                    if self._weight_quantize_type == 'channel_wise_abs_max' and op_name in self._conv_ops:
+                        self._insert_post_channel_dequant_op(graph, op_node)
+                    else:
+                        self._insert_post_dequant_op(graph, op_node)

         for op_node in ops:
             # insert dequant_op after fc/conv, need to rename inputs of the followed ops
@@ -829,10 +854,6 @@ class QuantizationFreezePass(object):

     def _insert_post_dequant_op(self, graph, op_node):
         persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
-        if len(op_node.input_arg_names()) >= 2 and len(persistable_vars) == 0:
-            raise ValueError("The op %s has more than one inputs "
-                             "and all of them are not persistable. "
-                             "Now, it is not supported!" % (op_node.name()))
         max_range = 1
         param_range = (1 << (self._weight_bits - 1)) - 1
         act_range = (1 << (self._activation_bits - 1)) - 1
@@ -987,9 +1008,7 @@ class ConvertToInt8Pass(object):
         for op_node in ops:
             op_name = op_node.name()
             if op_name in self._quantizable_ops:
-                skipped = op_node.op().has_attr("skip_quant") and \
-                    op_node.op().attr("skip_quant")
-                if skipped:
+                if QuantizationTransformPass._is_skip_quant(graph, op_node):
                     continue
                 for var_node in op_node.inputs:
                     name = var_node.name()
@@ -1240,7 +1259,7 @@ class AddQuantDequantPass(object):
         "equal", "gather", "greater_equal", "greater_than", "less_equal",
         "less_than", "mean", "not_equal", "reshape", "reshape2",
         "bilinear_interp", "nearest_interp", "trilinear_interp", "slice",
-        "squeeze", "elementwise_sub"
+        "squeeze", "elementwise_sub", "mul", "matmul"
     ]
     _activation_type = ["relu", "relu6", "leaky_relu", "tanh", "swish"]

@@ -1317,34 +1336,38 @@ class AddQuantDequantPass(object):
         all_op_nodes = graph.all_op_nodes()
         for op_node in all_op_nodes:
             if op_node.name() in self._quantizable_op_type:
-                user_skipped = False
+                is_skip = False
                 if isinstance(self._skip_pattern, list):
-                    user_skipped = op_node.op().has_attr("op_namescope") and \
+                    is_skip = op_node.op().has_attr("op_namescope") and \
                         any(pattern in op_node.op().attr("op_namescope")
                             for pattern in self._skip_pattern)
                 elif isinstance(self._skip_pattern, str):
-                    user_skipped = op_node.op().has_attr("op_namescope") and \
+                    is_skip = op_node.op().has_attr("op_namescope") and \
                         op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
-                if user_skipped:
-                    continue
+                is_op_node_quantized = False
+                for var_node in op_node.inputs:
+                    var_name = var_node.name()
+                    if var_name.endswith('.dequantized'):
+                        is_op_node_quantized = True

-                if not self._is_input_all_not_persistable(graph, op_node):
+                if is_skip or is_op_node_quantized or \
+                    (not _is_input_all_not_persistable(graph, op_node)):
                     continue

                 input_name_list = _op_real_in_out_name[op_node.name()][0]
+                arg_names = []
                 for input_name in input_name_list:
-                    for arg_name in op_node.input(input_name):
-                        in_node = graph._find_node_by_name(op_node.inputs,
-                                                           arg_name)
-                        if arg_name in dequantized_vars_map:
-                            quant_var_node = dequantized_vars_map[arg_name]
-                        else:
-                            quant_var_node, _ = \
-                                self._inser_quant_dequant_moving_average_abs_max_op(
-                                    graph, in_node, self._quant_bits)
-                            dequantized_vars_map[arg_name] = quant_var_node
-                        graph.update_input_link(in_node, quant_var_node,
-                                                op_node)
+                    arg_names.extend(op_node.input(input_name))
+                for arg_name in arg_names:
+                    in_node = graph._find_node_by_name(op_node.inputs, arg_name)
+                    if arg_name in dequantized_vars_map:
+                        quant_var_node = dequantized_vars_map[arg_name]
+                    else:
+                        quant_var_node, _ = \
+                            self._inser_quant_dequant_moving_average_abs_max_op(
+                                graph, in_node, self._quant_bits)
+                        dequantized_vars_map[arg_name] = quant_var_node
+                    graph.update_input_link(in_node, quant_var_node, op_node)

         # Backward stage, update input link
         for op_node in all_op_nodes:
@@ -1360,21 +1383,6 @@ class AddQuantDequantPass(object):
         graph.resolve_hazard()
         return graph

-    def _is_input_all_not_persistable(self, graph, op_node):
-        '''
-        Analyse the real inputs of the op node are all not persistable.
-        '''
-        is_input_all_not_persistable = True
-        op_node_name = op_node.name()
-
-        input_name_list = _op_real_in_out_name[op_node_name][0]
-        for input_name in input_name_list:
-            for arg_name in op_node.input(input_name):
-                in_node = graph._find_node_by_name(op_node.inputs, arg_name)
-                is_input_all_not_persistable = (is_input_all_not_persistable and \
-                    (not in_node.persistable()))
-        return is_input_all_not_persistable
-
     def _inser_quant_dequant_moving_average_abs_max_op(self, graph, var_node,
                                                        quant_bits):
         """Insert fake_quantize_dequantize_moving_average_abs_max op.
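
Taken together, the quantization_pass.py changes split mul/matmul handling between the two passes: QuantizationTransformPass quantizes a mul/matmul whose weight input is persistable, and its new _is_skip_quant check skips one whose inputs are all activations, which AddQuantDequantPass (whose _quantizable_op_type now lists "mul" and "matmul") picks up instead. The sketch below drives both passes over a toy program to show that division of labor; it assumes the same fluid 1.x imports the tests use, and the program and variable names are illustrative, not part of the patch.

import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass

main, startup = fluid.Program(), fluid.Program()
with fluid.program_guard(main, startup):
    x = fluid.layers.data(name='x', shape=[16, 32], dtype='float32')
    y = fluid.layers.data(name='y', shape=[32, 16], dtype='float32')
    # fc lowers to a mul op with a persistable weight: the weight path.
    fc = fluid.layers.fc(input=x, size=16)
    # matmul over two activations: no input is persistable.
    out = fluid.layers.matmul(x, y)

place = fluid.CPUPlace()
graph = IrGraph(core.Graph(main.desc), for_test=False)

# Quantizes the fc/mul op (its Y input is persistable); the activation-only
# matmul is skipped here via the new _is_skip_quant check.
QuantizationTransformPass(
    scope=fluid.global_scope(),
    place=place,
    quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul',
                         'matmul']).apply(graph)

# Picks up the skipped matmul and inserts quant-dequant ops on its inputs;
# the mul op's inputs now end in '.dequantized', so it is left alone.
AddQuantDequantPass(
    scope=fluid.global_scope(),
    place=place,
    quantizable_op_type=['mul', 'matmul']).apply(graph)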
diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py
index 0141cc0f8ad847506f5e657e40ce0946fecf8144..eb86b667c0a3c3aaa372ed2ac83517f5de5a7b83 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py
@@ -60,14 +60,21 @@ def residual_block(num, quant_skip_pattern=None):
             bias_attr=bias_attr)
         return fluid.layers.batch_norm(input=tmp, act=act)

-    data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
-    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+    data = fluid.layers.data(
+        name='image',
+        shape=[1, 1, 32, 32],
+        dtype='float32',
+        append_batch_size=False)
+    label = fluid.layers.data(
+        name='label', shape=[1, 1], dtype='int64', append_batch_size=False)
     hidden = data
     for _ in six.moves.xrange(num):
         conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
         short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
         hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
-
+    matmul_weight = fluid.layers.create_parameter(
+        shape=[1, 16, 32, 32], dtype='float32')
+    hidden = fluid.layers.matmul(hidden, matmul_weight, True, True)
     if quant_skip_pattern:
         with fluid.name_scope(quant_skip_pattern):
             pool = fluid.layers.pool2d(
@@ -189,6 +196,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
     def residual_block_quant(self,
                              activation_quant_type,
                              weight_quantize_type,
+                             quantizable_op_type,
                              for_ci=True):
         main = fluid.Program()
         startup = fluid.Program()
@@ -202,7 +210,8 @@ class TestQuantizationTransformPass(unittest.TestCase):
             scope=fluid.global_scope(),
             place=place,
             activation_quantize_type=activation_quant_type,
-            weight_quantize_type=weight_quantize_type)
+            weight_quantize_type=weight_quantize_type,
+            quantizable_op_type=quantizable_op_type)
         transform_pass.apply(graph)
         if not for_ci:
             marked_nodes = set()
@@ -223,14 +232,22 @@ class TestQuantizationTransformPass(unittest.TestCase):
                           val_marked_nodes)

     def test_residual_block_abs_max(self):
-        self.residual_block_quant('abs_max', 'abs_max', for_ci=True)
+        quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul', 'matmul']
+        self.residual_block_quant(
+            'abs_max', 'abs_max', quantizable_op_type, for_ci=True)

     def test_residual_block_range_abs_max(self):
-        self.residual_block_quant('range_abs_max', 'abs_max', for_ci=True)
+        quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul', 'matmul']
+        self.residual_block_quant(
+            'range_abs_max', 'abs_max', quantizable_op_type, for_ci=True)

     def test_residual_block_moving_average_abs_max(self):
+        quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul', 'matmul']
         self.residual_block_quant(
-            'moving_average_abs_max', 'channel_wise_abs_max', for_ci=True)
+            'moving_average_abs_max',
+            'channel_wise_abs_max',
+            quantizable_op_type,
+            for_ci=True)


 class TestQuantizationFreezePass(unittest.TestCase):
@@ -523,14 +540,16 @@ def quant_dequant_residual_block(num, quant_skip_pattern=None):
             bias_attr=bias_attr)
         return fluid.layers.batch_norm(input=tmp, act=act)

-    data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
+    data1 = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
+    data2 = fluid.layers.data(
+        name='matmul_input', shape=[16, 32, 32], dtype='float32')
     label = fluid.layers.data(name='label', shape=[1], dtype='int64')
-    hidden = data
+    hidden = data1
     for _ in six.moves.xrange(num):
         conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
         short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
         hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
-
+    hidden = fluid.layers.matmul(hidden, data2, True, True)
     if isinstance(quant_skip_pattern, str):
         with fluid.name_scope(quant_skip_pattern):
             pool1 = fluid.layers.pool2d(
@@ -596,7 +615,10 @@ class TestAddQuantDequantPass(unittest.TestCase):
                 for input_name in input_names:
                     self.assertTrue(input_name.endswith('.quant_dequant'))

-    def residual_block_quant(self, skip_pattern=None, for_ci=True):
+    def residual_block_quant(self,
+                             quantizable_op_type,
+                             skip_pattern=None,
+                             for_ci=True):
         main = fluid.Program()
         startup = fluid.Program()
         with fluid.program_guard(main, startup):
@@ -606,7 +628,10 @@ class TestAddQuantDequantPass(unittest.TestCase):
             place = fluid.CPUPlace()
         graph = IrGraph(core.Graph(main.desc), for_test=False)
         add_quant_dequant_pass = AddQuantDequantPass(
-            scope=fluid.global_scope(), place=place, skip_pattern=skip_pattern)
+            scope=fluid.global_scope(),
+            place=place,
+            skip_pattern=skip_pattern,
+            quantizable_op_type=quantizable_op_type)
         add_quant_dequant_pass.apply(graph)
         if not for_ci:
             marked_nodes = set()
@@ -625,14 +650,21 @@ class TestAddQuantDequantPass(unittest.TestCase):
         val_graph.draw('.', 'val_add_quant_dequant_graph', val_marked_nodes)

     def test_residual_block(self):
-        self.residual_block_quant(skip_pattern=None, for_ci=True)
+        quantizable_op_type = ['elementwise_add', 'pool2d', 'mul', 'matmul']
+        self.residual_block_quant(
+            quantizable_op_type, skip_pattern=None, for_ci=True)

     def test_residual_block_skip_pattern(self):
-        self.residual_block_quant(skip_pattern='skip_quant', for_ci=True)
+        quantizable_op_type = ['elementwise_add', 'pool2d', 'mul', 'matmul']
+        self.residual_block_quant(
+            quantizable_op_type, skip_pattern='skip_quant', for_ci=True)

     def test_residual_block_skip_pattern(self):
+        quantizable_op_type = ['elementwise_add', 'pool2d', 'mul', 'matmul']
         self.residual_block_quant(
-            skip_pattern=['skip_quant1', 'skip_quant2'], for_ci=True)
+            quantizable_op_type,
+            skip_pattern=['skip_quant1', 'skip_quant2'],
+            for_ci=True)


 if __name__ == '__main__':
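
On the post-training side, the new mul/matmul branch in post_training_quantization.py samples an input backed by a persistable variable as a weight, and everything else, including the output, as an activation. That split can be reproduced standalone on a toy program; a minimal sketch under the same fluid 1.x assumptions (all names here are illustrative, not part of the patch):

import paddle.fluid as fluid

main, startup = fluid.Program(), fluid.Program()
with fluid.program_guard(main, startup):
    x = fluid.layers.data(name='x', shape=[16, 32], dtype='float32')
    w = fluid.layers.create_parameter(shape=[32, 16], dtype='float32')
    out = fluid.layers.matmul(x, w)  # one persistable input, one activation

persistable = set(v.name for v in main.list_vars() if v.persistable)
weight_names, act_names = set(), set()
for op in main.global_block().ops:
    if op.type in ('mul', 'matmul'):
        for slot in ('X', 'Y'):
            name = op.input(slot)[0]
            # persistable input -> sampled as a weight, else as an activation
            (weight_names if name in persistable else act_names).add(name)
        act_names.add(op.output('Out')[0])

print(weight_names)  # the create_parameter variable
print(act_names)     # 'x' and the matmul output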