From 6390b175e6b205e2e3fdf9df7ed4af0e51b686b2 Mon Sep 17 00:00:00 2001 From: XGZhang <46363693+XGZhang11@users.noreply.github.com> Date: Thu, 28 Oct 2021 15:08:10 +0800 Subject: [PATCH] support quantization of bert (#36593) --- .../slim/quantization/quantization_pass.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index dc355fec0d3..90caee6c7a9 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -272,7 +272,8 @@ class QuantizationTransformPass(object): the quantized ops's inputs. """ _supported_quantizable_op_type = [ - 'conv2d', 'depthwise_conv2d', 'conv2d_transpose', 'mul', 'matmul' + 'conv2d', 'depthwise_conv2d', 'conv2d_transpose', 'mul', 'matmul', + 'matmul_v2' ] def __init__(self, @@ -520,6 +521,16 @@ class QuantizationTransformPass(object): dequant_var_node = dequantized_vars[var_node.name()] graph.update_input_link(var_node, dequant_var_node, op) + def _has_weight(op): + has_weight = False + for var_node in op.inputs: + if var_node.name() not in op.input_arg_names(): + continue + name = var_node.name() + if var_node.name() in persistable_vars: + has_weight = True + return has_weight + if not self._is_test: self._create_global_step(graph) ops = graph.all_op_nodes() @@ -535,11 +546,11 @@ class QuantizationTransformPass(object): # The loop for transforming the forward graph: for op in ops: if op.name() in self._quantizable_ops: - if not self._is_skip_quant(graph, op): + if not self._is_skip_quant(graph, op) and _has_weight(op): _transform_forward(graph, op) # The loop for renaming the inputs of backward op. for op in ops: - if op.name() in self._quantizable_grad_ops: + if op.name() in self._quantizable_grad_ops and _has_weight(op): _transform_backward(graph, op) graph.resolve_hazard() return graph -- GitLab