From 61a1f68845f65a12723ac2a667d083a1ab27399e Mon Sep 17 00:00:00 2001
From: Guanghua Yu <742925032@qq.com>
Date: Mon, 5 Dec 2022 15:14:18 +0800
Subject: [PATCH] Support matmul in QAT and loading quantized models in PTQ
 (#47892)

---
 .../slim/quantization/imperative/utils.py    |  1 +
 .../slim/quantization/quantization_pass.py   | 18 ++++++++++++++++++
 python/paddle/nn/quant/__init__.py           |  1 +
 python/paddle/nn/quant/functional_layers.py  | 10 +++++++++-
 4 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py
index d771b51e09..e5ed14cb9f 100644
--- a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py
+++ b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py
@@ -63,6 +63,7 @@ fake_quant_output_layers = [
     paddle.nn.quant.subtract,
     paddle.nn.quant.multiply,
     paddle.nn.quant.divide,
+    paddle.nn.quant.matmul,
 ]
 
 fake_quant_leaf_layers = [
diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
index 6d99f0949d..705b0e5e69 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
@@ -1939,6 +1939,15 @@ class AddQuantDequantPass:
                     op_node.op()._set_attr("activation_bits", self._quant_bits)
                     op_node.op()._set_attr("with_quant_attr", True)
                     arg_names = utils._get_op_input_var_names(op_node)
+                    # If already quanted, skip it.
+                    skip_quant = False
+                    for arg_name in arg_names:
+                        if "quantized.dequantized" in arg_name:
+                            skip_quant = True
+                            break
+                    if skip_quant:
+                        continue
+
                     for arg_name in arg_names:
                         in_node = graph._find_node_by_name(
                             op_node.inputs, arg_name
@@ -2797,6 +2806,15 @@ class AddQuantDequantPassV2:
                         continue
 
                 arg_names = utils._get_op_input_var_names(op_node)
+                # If already quanted, skip it.
+                skip_quant = False
+                for arg_name in arg_names:
+                    if "quantized.dequantized" in arg_name:
+                        skip_quant = True
+                        break
+                if skip_quant:
+                    continue
+
                 for arg_name in arg_names:
                     in_node = graph._find_node_by_name(
                         op_node.inputs, arg_name
diff --git a/python/paddle/nn/quant/__init__.py b/python/paddle/nn/quant/__init__.py
index 8973761ab6..f96558bfbe 100644
--- a/python/paddle/nn/quant/__init__.py
+++ b/python/paddle/nn/quant/__init__.py
@@ -21,6 +21,7 @@ from .functional_layers import reshape  # noqa: F401
 from .functional_layers import transpose  # noqa: F401
 from .functional_layers import concat  # noqa: F401
 from .functional_layers import flatten  # noqa: F401
+from .functional_layers import matmul  # noqa: F401
 from .quant_layers import QuantStub  # noqa: F401
 
 __all__ = []
diff --git a/python/paddle/nn/quant/functional_layers.py b/python/paddle/nn/quant/functional_layers.py
index 2986e3e050..3a0fafe6b6 100644
--- a/python/paddle/nn/quant/functional_layers.py
+++ b/python/paddle/nn/quant/functional_layers.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ...tensor import manipulation, math
+from ...tensor import linalg, manipulation, math
 from .. import Layer
 
 __all__ = []
@@ -85,3 +85,11 @@ class flatten(FloatFunctionalLayer):
 
     def forward(self, x, start_axis=0, stop_axis=-1, name=None):
         return manipulation.flatten(x, start_axis, stop_axis, name)
+
+
+class matmul(FloatFunctionalLayer):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y, transpose_x=False, transpose_y=False, name=None):
+        return linalg.matmul(x, y, transpose_x, transpose_y, name)
-- 
GitLab
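
Usage sketch (not part of the patch above): with this change, a dygraph model can route matmul through the new paddle.nn.quant.matmul functional layer, which the imperative QAT pass treats as a fake-quant output layer (it is registered in fake_quant_output_layers). The hypothetical AttentionScore layer and the ImperativeQuantAware calls below are assumptions based on the existing imperative QAT workflow, not something introduced by this patch.

import paddle
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware


class AttentionScore(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        # Stateless wrapper added by this patch; QAT can hook its output.
        self.matmul = paddle.nn.quant.matmul()

    def forward(self, q, k):
        # Same math as paddle.matmul(q, k, transpose_y=True), but expressed
        # as a layer so fake quant/dequant ops can be inserted around it.
        return self.matmul(q, k, transpose_x=False, transpose_y=True)


model = AttentionScore()
quanter = ImperativeQuantAware()  # assumed default QAT configuration
quanter.quantize(model)  # fake quant/dequant now also wraps the matmul output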