diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
index dc355fec0d362a4208ebb048a7b29925dabb6ead..90caee6c7a947023317f76874e8eac83c8b249f2 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
@@ -272,7 +272,8 @@ class QuantizationTransformPass(object):
     the quantized ops's inputs.
     """
     _supported_quantizable_op_type = [
-        'conv2d', 'depthwise_conv2d', 'conv2d_transpose', 'mul', 'matmul'
+        'conv2d', 'depthwise_conv2d', 'conv2d_transpose', 'mul', 'matmul',
+        'matmul_v2'
     ]
 
     def __init__(self,
@@ -520,6 +521,16 @@ class QuantizationTransformPass(object):
                     dequant_var_node = dequantized_vars[var_node.name()]
                     graph.update_input_link(var_node, dequant_var_node, op)
 
+        def _has_weight(op):
+            has_weight = False
+            for var_node in op.inputs:
+                if var_node.name() not in op.input_arg_names():
+                    continue
+                name = var_node.name()
+                if var_node.name() in persistable_vars:
+                    has_weight = True
+            return has_weight
+
         if not self._is_test:
             self._create_global_step(graph)
         ops = graph.all_op_nodes()
@@ -535,11 +546,11 @@ class QuantizationTransformPass(object):
         # The loop for transforming the forward graph:
         for op in ops:
             if op.name() in self._quantizable_ops:
-                if not self._is_skip_quant(graph, op):
+                if not self._is_skip_quant(graph, op) and _has_weight(op):
                     _transform_forward(graph, op)
         # The loop for renaming the inputs of backward op.
         for op in ops:
-            if op.name() in self._quantizable_grad_ops:
+            if op.name() in self._quantizable_grad_ops and _has_weight(op):
                 _transform_backward(graph, op)
         graph.resolve_hazard()
         return graph
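
Note on the change: the new _has_weight closure gates both the forward and
backward transform loops, so an op is now quantized only when at least one of
the inputs it actually consumes is a persistable variable (i.e. a weight).
This matters for the newly added 'matmul'/'matmul_v2' ops, which may multiply
two activations and carry no weight at all. Below is a minimal standalone
sketch of that predicate; FakeVarNode and FakeOpNode are hypothetical
stand-ins for Paddle's IR node API (only op.inputs, var_node.name() and
op.input_arg_names(), mirroring the calls in the patch), not the real classes.

    # sketch.py -- illustrative only; FakeVarNode/FakeOpNode are hypothetical
    # stand-ins for the IR graph nodes used inside quantization_pass.py.

    class FakeVarNode:
        def __init__(self, name):
            self._name = name

        def name(self):
            return self._name


    class FakeOpNode:
        def __init__(self, inputs, input_arg_names):
            self.inputs = inputs                      # list of FakeVarNode
            self._input_arg_names = input_arg_names   # names the op consumes

        def input_arg_names(self):
            return self._input_arg_names


    def has_weight(op, persistable_vars):
        # Mirrors the _has_weight closure from the patch: True iff any input
        # the op consumes is a persistable variable (a weight parameter).
        for var_node in op.inputs:
            if var_node.name() not in op.input_arg_names():
                continue
            if var_node.name() in persistable_vars:
                return True
        return False


    if __name__ == "__main__":
        x = FakeVarNode("x")            # activation
        w = FakeVarNode("matmul_w")     # weight (persistable)
        weighted = FakeOpNode([x, w], ["x", "matmul_w"])
        act_only = FakeOpNode([x], ["x"])
        persistable = {"matmul_w"}
        print(has_weight(weighted, persistable))   # True  -> transformed
        print(has_weight(act_only, persistable))   # False -> skipped

With this gate in place, a matmul between two activations is left untouched by
QuantizationTransformPass rather than receiving quant/dequant ops it cannot use.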