From adb69ed660d2b0605e3eb8d3514be5586f2517cf Mon Sep 17 00:00:00 2001 From: whs Date: Tue, 28 Mar 2023 11:57:42 +0800 Subject: [PATCH] Add constraints module used in QAT (#1667) * Add module to analysis dygraph * Support fusing conv bn * Add tutorial for qat with conv bn fused constraint * Add what_is_constraints --- demo/quant/quant_post/quant_post.py | 4 +- .../how_to_qat_with_constraint.ipynb | 220 +++++++++++++++ .../what_is_constraints.md | 135 +++++++++ paddleslim/core/__init__.py | 3 + paddleslim/core/graph.py | 113 ++++++++ paddleslim/core/graph_tracer.py | 163 +++++++++++ paddleslim/quant/__init__.py | 3 + paddleslim/quant/config.py | 31 +++ paddleslim/quant/constraints/__init__.py | 18 ++ paddleslim/quant/constraints/constraint.py | 91 ++++++ .../quant/constraints/conv_bn_constraints.py | 61 +++++ paddleslim/quant/nn/__init__.py | 17 ++ paddleslim/quant/nn/conv_bn.py | 259 ++++++++++++++++++ paddleslim/quant/qat.py | 42 +++ .../quantization/test_conv_bn_constraints.py | 103 +++++++ tests/quantization/test_graph_tracer.py | 58 ++++ tests/quantization/test_qat.py | 139 ++++++++++ 17 files changed, 1458 insertions(+), 2 deletions(-) create mode 100644 example/quantization/qat_with_constraints/how_to_qat_with_constraint.ipynb create mode 100644 example/quantization/qat_with_constraints/what_is_constraints.md create mode 100644 paddleslim/core/graph.py create mode 100644 paddleslim/core/graph_tracer.py create mode 100644 paddleslim/quant/config.py create mode 100644 paddleslim/quant/constraints/__init__.py create mode 100644 paddleslim/quant/constraints/constraint.py create mode 100644 paddleslim/quant/constraints/conv_bn_constraints.py create mode 100644 paddleslim/quant/nn/__init__.py create mode 100644 paddleslim/quant/nn/conv_bn.py create mode 100644 paddleslim/quant/qat.py create mode 100644 tests/quantization/test_conv_bn_constraints.py create mode 100644 tests/quantization/test_graph_tracer.py create mode 100644 tests/quantization/test_qat.py diff --git a/demo/quant/quant_post/quant_post.py b/demo/quant/quant_post/quant_post.py index 2e3f7b40..b9fed6cb 100755 --- a/demo/quant/quant_post/quant_post.py +++ b/demo/quant/quant_post/quant_post.py @@ -22,13 +22,13 @@ parser = argparse.ArgumentParser(description=__doc__) add_arg = functools.partial(add_arguments, argparser=parser) # yapf: disable add_arg('batch_size', int, 32, "Minibatch size.") -add_arg('batch_num', int, 1, "Batch number") +add_arg('batch_num', int, 10, "Batch number") add_arg('use_gpu', bool, True, "Whether to use GPU or not.") add_arg('model_path', str, "./inference_model/MobileNetV1_infer/", "model dir") add_arg('save_path', str, "./quant_model/MobileNet/", "model dir to save quanted model") add_arg('model_filename', str, 'inference.pdmodel', "model file name") add_arg('params_filename', str, 'inference.pdiparams', "params file name") -add_arg('algo', str, 'hist', "calibration algorithm") +add_arg('algo', str, 'avg', "calibration algorithm") add_arg('round_type', str, 'round', "The method of converting the quantized weights.") add_arg('hist_percent', float, 0.9999, "The percentile of algo:hist") add_arg('is_full_quantize', bool, False, "Whether is full quantization or not.") diff --git a/example/quantization/qat_with_constraints/how_to_qat_with_constraint.ipynb b/example/quantization/qat_with_constraints/how_to_qat_with_constraint.ipynb new file mode 100644 index 00000000..1d9ce5c1 --- /dev/null +++ b/example/quantization/qat_with_constraints/how_to_qat_with_constraint.ipynb @@ -0,0 +1,220 @@ +{ + "cells": [ + { 
+ "cell_type": "markdown", + "metadata": {}, + "source": [ + "# QAT with convolution and batchnorm fused constraints" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "本教程以Conv2D和BatchNorm的融合为例,介绍如何使用PaddleSlim接口快速为量化训练添加训练约束(constraints)。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. 添加依赖" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import paddle\n", + "from paddle.vision.models import resnet18\n", + "from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver\n", + "from paddleslim.quant import SlimQuantConfig as QuantConfig\n", + "from paddleslim.quant import SlimQAT\n", + "from paddleslim.quant.constraints import FreezedConvBNConstraint\n", + "paddle.set_device(\"cpu\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 构造模型" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "model = resnet18()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 量化训练\n", + "\n", + "**配置**\n", + "\n", + "构造量化配置实例,并将激活和权重的量化方式指定为基础的 AbsMax 量化策略。然后,调用 `add_constraints` 配置一个Constraints实例。\n", + "最后,使用量化配置构造SlimQAT对象。代码如下所示:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)\n", + "q_config = QuantConfig(activation=quanter, weight=quanter)\n", + "q_config.add_constraints(FreezedConvBNConstraint())\n", + "qat = SlimQAT(q_config)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**量化**\n", + "\n", + "调用SlimQAT对象的quantize方法,将模型转为用于量化训练的模型。\n", + "用户需要指定一个inputs, 用于推断分析动态图的执行拓扑图,以便自动检测出所有Conv2D和BN的组合。\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x = paddle.rand([1, 3, 224, 224])\n", + "quant_model = qat.quantize(model, inplace=True, inputs=x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "在该步骤,所有的Conv2D和BN的组合,都会被替换为QuantedConv2DBatchNorm layer。在QuantedConv2DBatchNorm中,参考 [Quantizing deep convolutional networks for efficient inference: A whitepaper](https://arxiv.org/abs/1806.08342) ,在量化训练过程中模拟Conv2D和BatchNorm的融合。原理如下图所示:\n", + "\n", + "
Conv BN 量化训练矫正方案示意图
\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**训练**" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "quant_model.train()\n", + "out = quant_model(x)\n", + "out.backward()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**查看量化训练模型结构**\n", + "\n", + "直接在终端输出量化训练模型的结构,或者将模型保存到文件系统,并使用[netron](https://netron.app/)查看模型结构。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(quant_model)\n", + "quant_model.eval()\n", + "paddle.jit.save(quant_model, \"./qat_model\", input_spec=[x])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "模拟量化训练模型结构如下图所示:\n", + "\n", + "
模拟量化模型结构示意图
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**保存推理模型**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "infer_model = qat.convert(quant_model, inplace=True)\n", + "print(infer_model)\n", + "paddle.jit.save(infer_model, \"./infer_model\", input_spec=[x])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "推理模型结构如下图所示:\n", + "\n", + "
量化推理模型结构示意图
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + }, + "vscode": { + "interpreter": { + "hash": "2f394aca7ca06fed1e6064aef884364492d7cdda3614a461e02e6407fc40ba69" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/example/quantization/qat_with_constraints/what_is_constraints.md b/example/quantization/qat_with_constraints/what_is_constraints.md new file mode 100644 index 00000000..dfe1c928 --- /dev/null +++ b/example/quantization/qat_with_constraints/what_is_constraints.md @@ -0,0 +1,135 @@ +# QAT Constraints 原理 + +## 1. 概述 +约束(Constraints)是指在模拟量化训练阶段,施加到模型上的一些约束规则。约束的对象包括且不限于:模型参数、量化参数、Layer 的执行。 +需要添加约束的原因:模拟量化的前向数值计算需要和推理的前向数值计算对齐。 +实现独立且可扩展的约束模块的原因:不同的推理库和硬件,对量化推理的支持程度和方式不同,则在模拟量化训练阶段需要添加不同的约束。 + +## 2. 各种约束介绍 + + +### 2.1 Conv/MatMul Bias约束 +给定图1所示的计算图,FP32 数值类型的 Weight 和 Input 经过矩阵乘,再加上一个同是 FP32 数值类型的 Bias,最终的 FP32 数值类型的输出交给后续的 Layer,这里以 SkipLayernorm 为例。 + +
+图1 - Matmul Add计算示意图
+ +对于上述计算,可以列举出三种模拟 Int8 量化方式,分别对应三种 Int8 推理实现。 + +#### 第一种量化实现 + +这种实现无需特殊约束,只用正常的模拟量化训练。对该实现介绍的目的是供其它量化方式对比。 +如下图所示,在模拟量化阶段,在 Matmul 的两个输入之后插入 Per-Tensor 的 QDQ Layer(量化反量化操作),来模拟对weight和input的量化。 +weigh 和 input 对应的量化 scale 分别为 $S_w$ 和 $S_i$。右图为推理实现,其中: + +- 删除 weight 的模拟量化 Layer,以 Int8 数值格式存储 weight +- 将 Input 的模拟量化 Layer 替换为量化 Layer。量化 Layer 负责将 FP32 类型的 input 量化为 Int8 数值类型,所用量化 scale 为模拟量化阶段统计到的 $S_i$ +- Matmul 以 Int8 数值类型进行乘计算,并将结果收集累加为 Int32 数值 +- Matmul 之后的 Per-Tensor 反量化 Layer 将 Int32 数值反量化为FP32,所用量化 scale为$S_w * S_i$ +- Add 正常用 FP32 数值类型计算 + +
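+下面用一段基于 paddle 的示意代码演示第一种实现中“模拟量化(QDQ)前向”与“Int8 推理前向”两条路径的对应关系。代码仅为示意,张量形状与变量名均为本文假设,并非 PaddleSlim 的实际实现:
+
+```python
+import paddle
+
+# 基于 AbsMax 统计 per-tensor 量化 scale,对应下文公式中的 S_w 与 S_i
+def absmax_scale(x, bits=8):
+    return x.abs().max() / (2 ** (bits - 1))
+
+def quant(x, scale, bits=8):
+    qmax = 2 ** (bits - 1) - 1
+    return paddle.clip(paddle.round(x / scale), -qmax - 1, qmax)
+
+paddle.seed(0)
+inputs = paddle.rand([8, 64])    # FP32 Input
+weight = paddle.rand([64, 32])   # FP32 Weight
+bias = paddle.rand([32])         # FP32 Bias
+s_w, s_i = absmax_scale(weight), absmax_scale(inputs)
+
+# 模拟量化训练:在 Matmul 的两个输入后插入 QDQ(量化后立刻反量化),Add 仍用 FP32 计算
+sim_out = paddle.matmul(quant(inputs, s_i) * s_i, quant(weight, s_w) * s_w) + bias
+
+# Int8 推理:Int8 Matmul 累加为 Int32,再用 S_w * S_i 反量化为 FP32,Add 仍用 FP32 计算
+infer_out = paddle.matmul(quant(inputs, s_i), quant(weight, s_w)) * (s_w * s_i) + bias
+
+# 两条路径在数值上等价(仅有浮点误差),这正是“模拟量化与推理对齐”的含义
+print(float((sim_out - infer_out).abs().max()))
+```
+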
+图2 - 第一种量化实现示意图
+ +计算公式如下: + +$$S_w = \frac{AbsMax_w }{2^{8-1}}$$ +$$S_i = \frac{AbsMax_i}{2^{8-1}}$$ +$$QWeight = round(\frac{weight} { S_w})$$ +$$QInput = round(\frac{input} { S_i})$$ +$$DqOutput = output * S_w * S_i $$ + +#### 第二种量化实现 + +这种实现需要对量化 Scale 进行约束: +1. Bias 的量化 scale 必须与 Matmul 的反量化 scale 一致 +2. Matmul 的量化 scale 受 Bias 限制,需要保证将量化后的 Bias 限制在 16bits 以内(仅限 TI 芯片) +如下图左侧模型结构所示,不仅模拟量化 Matmul 两个输入,还模拟量化了 Bias。对 Bias 的模拟量化与传统模拟量化不同,量化 scale 不是通过统计Bias收集得到的,而是直接用的 $S_w * S_i$,也就是 Matmul Layer 的反量化 scale。 +这里就是对 Matmul+Bias 的一个约束规则,在模拟量化训练阶段,要以 $S_w * S_i$ 为量化 scale对 Bias 进行模拟量化。该约束是为了与推理实现对齐,如下图右侧模型结构所示。 +在推理实现中,Matmul 以 Int8 数值类型进行乘计算,收集相乘后的数值,累加到 Int32 类型的 accumulator 中。为了能将 Bias 也累加到这个 accumulator,需要将 Bias 量化为 Int32。并且,量化Bias 所用的量化 scale 需要与 accumulator 的量化 scale 一样,即 $S_w * S_i$。 + +
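+以下示意代码沿用上文示例中的 quant、inputs、weight、bias、s_w、s_i 等定义,演示第二种实现中 Bias 的量化方式以及 TI 芯片的额外约束(仅为示意,非实际推理库的实现):
+
+```python
+# Bias 使用 Matmul 的反量化 scale(S_w * S_i)量化为 Int32,可直接累加到 Int32 accumulator
+q_bias = paddle.round(bias / (s_w * s_i))
+acc_int32 = paddle.matmul(quant(inputs, s_i), quant(weight, s_w)) + q_bias
+fp32_out = acc_int32 * (s_w * s_i)   # (QWeight * QInput + QBias) * S_w * S_i
+
+# TI 芯片场景下的第二个约束:量化后的 Bias 需落在 16bits 表示范围内
+assert float(q_bias.abs().max()) <= 2 ** 15
+```
+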
+图3 - 第二种量化实现示意图
+ +整个推理执行的公式为: + +$$FP32Output = (QWeight * QInput + QBias) * S_w * S_i $$ + +其中: + +- $QWeight$: 量化后的 weight +- $QInput$:计算方式为 $round(\frac{FP32Input}{S_i})$ +- $QBias$: 量化后的 bias,计算方式为 $round(\frac{FP32Bias} {S_w * S_i})$ + +对于TI芯片,基本过程如上所述,但是在某些场景下,需要限制量化后的 bias 到 Int16 的范围内。用公式表示为: + +$$\frac{AbsMax(FP32Bias)}{S_w * S_i} <= 2^{15}$$ + +在模拟量化训练过程中,需要满足上述公式的约束,即本小节开始处提到的第二个约束。 + +#### 第三种量化实现 + +这种量化方式,只需对 Add 单独加约束,即:保证 Add Layer 的两个输入的量化 scale 保持一致。 +如下图左侧模型结构所示,Matmul 和 Add 相互独立,都分别添加了模拟量化 Layer。Add 的两个输入的量化 scale 为与 Matmul 输入的量化 scale 无关。 +在推理时,无法直接将 bias 累加到 matmul 的 Int32 accumulator 上,因为 bias 的量化 scale 与 accumulator 不一样。而是,先用 $S_w * S_i$ 将 Int32 accumulator 反量化为 FP32,然后再用 $S_o$ 将 FP32 数值量化为 Int8,并与 Int8 数值类型的 Bias 相加。最终将 Add 结果用 $S_o$ 反量化为 FP32 数值。整个推理的计算公式为: + +$$FP32Output = (round(\frac{(QWeight * QInput) * S_w*S_i}{S_o}) + QBias) * S_o$$ + +其中: + +$QBias$: 量化后的 bias,计算方式为 $round(\frac{FP32Bias} {S_o})$ +这种实现方式,理论上没有第二种合理,这里列出来是为了对比第二种方法中的 Bias 的量化方式。 + +
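+以下示意代码沿用上文示例中的变量,演示第三种实现:Matmul 的 Int32 累加结果先反量化,再用 Add 输出的 scale $S_o$ 重新量化,Bias 也用 $S_o$ 量化为 Int8(仅为示意,$S_o$ 的统计方式为本文假设):
+
+```python
+# 假设 S_o 由对 Add 输出做 AbsMax 统计得到
+s_o = absmax_scale(paddle.matmul(inputs, weight) + bias)
+
+acc_fp32 = paddle.matmul(quant(inputs, s_i), quant(weight, s_w)) * (s_w * s_i)
+q_out = paddle.round(acc_fp32 / s_o)     # 重新量化到 S_o 对应的 Int8 域
+q_bias = quant(bias, s_o)                # QBias = round(FP32Bias / S_o)
+fp32_out = (q_out + q_bias) * s_o        # Add 的结果再用 S_o 反量化回 FP32
+```
+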
+图4 - 第三种量化实现示意图
+
+
+### 2.2 Convolution/mul BatchNorm约束
+
+本小节介绍两种在模拟量化训练中处理 convolution 和 batch norm 的方式。
+Batch normalization 被普遍用于计算机视觉模型训练,它可以减少模型层与层之间的影响,使模型训练更稳定、更快速。
+训练时,公式如下:
+
+$$x_{bn} = \gamma (\frac{x-\mu_{B}}{\sigma_B} ) + \beta$$
+
+推理时,公式如下:
+
+$$x_{bn} = \gamma (\frac{x-\mu}{\sigma} ) + \beta$$
+
+其中,$\mu_B$ 和 $\sigma_B$ 是当前单个 batch 的平均值和标准差;$\mu$ 和 $\sigma$ 是在训练阶段从多个 batch 中统计得到的平均值和标准差。
+在推理时,会将 batch normalization 融合到 convolution 或 mul layer 中。公式如下:
+
+$$W_{inference} = \frac{\gamma * W}{\sigma}$$
+$$Bias_{inference} = \beta - \frac{\gamma \mu}{\sigma}$$
+
+在模拟量化训练阶段,需要模拟上述 convolution 和 batch normalization 的融合操作(下文给出融合计算的示意代码)。
+
+可以有两个选择:
+- 策略1:Unfreeze BN,使用单个 batch 统计的平均值($\mu_B$)和标准差($\sigma_B$)
+  - 缺点:训练不稳定。batch 间 BN 的统计值 $\sigma_B$ 变化比较大,与 BN 融合之后,convolution 的 weight 也会随 $\sigma_B$ 频繁变化,最终导致 weight 量化 scale 的不稳定。而在推理时,convolution 融合的是全局统计的 $\sigma$。所以,训练时 weight 量化 scale 的不稳定,会体现为推理精度的不稳定。
+  - 现象:如图5绿色曲线所示。
+- 策略2:Freeze BN,即使用训练阶段统计的平均值 $\mu$ 和标准差 $\sigma$
+  - 缺点:没有用到 batch norm 的特性,对激活做 normalization 所用的 mean 和 variance 不符合当前 batch 的激活分布,导致量化训练不稳定
+  - 现象:
+
+在策略1中,训练不稳定的直接原因是『weight 的量化 scale 频繁变化』。使用 moving average 来计算 weight 的 scale 可以缓解该问题,但不能完全避免,如图5橙色曲线所示。
+
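+以下为 convolution 与 batch norm 融合计算的示意代码(最小示例,对应上文 $W_{inference}$ 与 $Bias_{inference}$ 的公式,变量名为本文假设,省略了 epsilon 等细节):
+
+```python
+import paddle
+
+def fold_bn_into_conv(conv_weight, gamma, beta, mean, std):
+    """把 BN 参数折叠进 conv 的 weight 和 bias:W' = gamma*W/std,b' = beta - gamma*mean/std。
+    策略1(Unfreeze BN)传入当前 batch 的 mu_B/sigma_B,策略2(Freeze BN)传入全局统计的 mu/sigma。"""
+    w_fold = conv_weight * (gamma / std).reshape([-1, 1, 1, 1])
+    b_fold = beta - gamma * mean / std
+    return w_fold, b_fold
+
+# 随机参数仅用于演示调用方式(out_channels=16, in_channels=3, kernel=3x3)
+w, gamma, beta = paddle.rand([16, 3, 3, 3]), paddle.rand([16]), paddle.rand([16])
+mu, sigma = paddle.rand([16]), paddle.rand([16]) + 0.5
+w_fold, b_fold = fold_bn_into_conv(w, gamma, beta, mu, sigma)
+```
+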
+图5 - 各种 Conv BN 训练方法效果对比
+
+最优的方案是综合策略1和策略2的优点:让 weight 使用全局的 BN 参数 $\sigma$,使 weight 更稳定;让激活使用当前 batch 统计到的 $\sigma_B$ 和 $\mu_B$,使激活的 BN 更准确。该方案如图6所示,矫正前向的示意代码见下。
+
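+以下示意代码给出该矫正方案在 freeze BN 开关关闭(训练前期)时的前向计算。这只是一个基于上述公式的最小示例,并非本 PR 中 QuantedConv2DBatchNorm 的实现;fake_quant 代表任意权重模拟量化函数,其余参数名均为本文假设:
+
+```python
+import paddle
+import paddle.nn.functional as F
+
+def corrected_conv_bn(x, conv_weight, gamma, beta, mu_b, sigma_b, sigma, fake_quant):
+    # 1) 让 conv weight 更稳定:先用全局统计的 sigma 折叠 weight,再做模拟量化
+    w_fold = conv_weight * (gamma / sigma).reshape([-1, 1, 1, 1])
+    y = F.conv2d(x, fake_quant(w_fold))
+    # 2) 让激活 BN 更准确:乘上 sigma / sigma_b,等价于用当前 batch 的统计量做 normalization
+    y = y * (sigma / sigma_b).reshape([1, -1, 1, 1])
+    return y + (beta - gamma * mu_b / sigma_b).reshape([1, -1, 1, 1])
+```
+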
+图6 - Conv BN 量化训练矫正方案示意图
+ + +图6方案有如下3个关键点: +1. 让 conv weight 更稳定 +在量化之前,将全局统计的 BN 参数 $\sigma$ 合并到 conv weight,确保 weight 量化的稳定性。 +2. 让激活BN更准确 +在训练前期,通过乘上 $\frac{\sigma}{\sigma_B}$ 来恢复 weight 融合对激活的影响,使激活是经过$\sigma_B$ 和 $\mu_B$ normalize 后的结果,行为更接近训练时标准的BN. 在上图中,对应 freeze Batch norm 开关为0。 +3. Freeze BN +经过一段 unfreeze batch norm 的训练,开启 freeze batch norm 开关,使激活的 normalization 也使用全局的 BN 参数,与推理保持一致。相当于,消除了 BN 带来的训练与推理的差异,专心学习量化。 +上述方案,效果如图6中蓝色曲线所示。 diff --git a/paddleslim/core/__init__.py b/paddleslim/core/__init__.py index e3076a7a..91299ffd 100644 --- a/paddleslim/core/__init__.py +++ b/paddleslim/core/__init__.py @@ -18,8 +18,11 @@ from ..core import registry from .registry import * from ..core import dygraph from .dygraph import * +from .graph_tracer import GraphTracer +from .graph import Graph __all__ = [] __all__ += graph_wrapper.__all__ __all__ += registry.__all__ __all__ += dygraph.__all__ +__all__ += ["GraphTracer", "Graph"] diff --git a/paddleslim/core/graph.py b/paddleslim/core/graph.py new file mode 100644 index 00000000..6929250a --- /dev/null +++ b/paddleslim/core/graph.py @@ -0,0 +1,113 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" A structure used to describe the network of a model.""" + +from typing import List, Tuple +import paddle + +__all__ = ["Graph"] + + +class Node(): + """ It is the node of the model's executing DAG. A node is a layer's once-calling. + Args: + layer(paddle.nn.Layer): The layer executed in the current node. + """ + + def __init__(self, layer, call_count): + self._layer = layer + self._layer_name = layer.full_name() + self._call_count = call_count + self._hash_name = f"{self._layer_name}_{self._call_count}" + self._next_nodes = [] + self._previous_nodes = [] + + @property + def next_nodes(self): + """ Get the next nodes of the current node. + + Returns: + A list of nodes representing the next nodes of the current node. + + """ + return self._next_nodes + + @property + def previous_nodes(self): + """ Get the previous nodes of the current node. + + Returns: + A list of nodes representing the previous nodes of the current node. + """ + return self._previous_nodes + + @property + def name(self): + return self._hash_name + + @property + def layer_name(self) -> str: + """ Get the full name of the layer executed in the current node.""" + return self._layer_name + + @property + def layer(self) -> paddle.nn.Layer: + """ Get the layer executed in the current node.""" + return self._layer + + def is_leaf(self): + """ Whether this node is a leaf node. It is a leaf when the layer called by this node has no sublayers.""" + return isinstance(self._layer, + paddle.nn.Layer) and len(self._layer.sublayers()) == 0 + + def __hash__(self): + return self._hash_name + + def __str__(self): + return self._hash_name + + +class Graph(): + """ Directed Acyclic Graph used to describe the executing of the model. + Usually, the graph is built with tracer tools. There is no need to call the constructor directly. 
+ """ + + def __init__(self): + self._name2node = {} + + def __str__(self): + return str(self._name2node.keys()) + + @property + def nodes(self) -> List[Node]: + """ Get all the nodes in the graph. + """ + return self._name2node.values() + + def find_conv_bn(self) -> List[Tuple[Node, Node]]: + """ Find all the convolution and batch normalization pairs. + Returns: + A list of convolution and batch normalization pairs. Each pair is a tuple with two + nodes, the first being the convolution node and the second being the batch normalization node. + """ + conv2d_types = [paddle.nn.Conv2D] + bn_types = [paddle.nn.BatchNorm2D] + results = [] + for node in self.nodes: + if type(node.layer) in conv2d_types: + for _next_node in node.next_nodes: + if type(_next_node.layer) in bn_types: + results.append((node, _next_node)) + break + return results diff --git a/paddleslim/core/graph_tracer.py b/paddleslim/core/graph_tracer.py new file mode 100644 index 00000000..5b202d2f --- /dev/null +++ b/paddleslim/core/graph_tracer.py @@ -0,0 +1,163 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" A structure used to describe the network of a model.""" + +import sys +import types +import paddle + +from paddleslim.core import GraphWrapper +from .graph import Graph, Node +from paddleslim.core.dygraph import dygraph2program + +__all__ = ["GraphTracer"] + + +def _apply(layer, func): + for name, child in layer.named_children(): + func(child) + _apply(child, func) + + +def _add_call_hook(module, + function_new, + method_name='forward', + backup_name='__forward_orig__'): + def _call_hook_enable(op): + # do not patch the top level modules. makes it easy to invoke by self.module(x) + if op is not module: + assert not hasattr( + op, backup_name + ), f'in {op.__class__.__name__} detected an existing function {backup_name} : please double check' + # backup the original forward of op into backup_name + method_orig = getattr(op, method_name) + setattr(op, backup_name, method_orig) + # set new method + method_new = types.MethodType(function_new, op) + setattr(op, method_name, method_new) + + _apply(module, _call_hook_enable) + + +def _remove_call_hook(module, + method_name='forward', + backup_name='__forward_orig__'): + def _call_hook_disable(op): + if op is not module: + if hasattr(op, backup_name): + method_new = getattr(op, method_name) + method_orig = getattr(op, backup_name) + setattr(op, method_name, method_orig) + # delete the backup + setattr(op, backup_name, method_new) + delattr(op, backup_name) + + _apply(module, _call_hook_disable) + + +class GraphTracer(paddle.nn.Layer): + """ A tool used to trace the execution of the model. + Call the forward of the model decorated by this tracer + and it will create a graph. + + Args: + model(paddle.nn.Layer): The model to be traced. + + Examples: + .. 
code-block:: python + from paddeslim.core.graph_tracer import GraphTracer + from paddle.vision.models import resnet18 + + model = resnet18() + x = paddle.rand([1, 3, 224, 224]) + tracer = GraphTracer(model) + tracer(x) + print(tracer.graph) + + """ + + def __init__(self, model: paddle.nn.Layer): + super(GraphTracer, self).__init__() + self._model = model + self._graph = None + self._call_count = {} + self._tensor_previous = {} + + @property + def graph(self) -> Graph: + assert self._graph is not None, "Please trace the graph by calling forward function of current tracer." + return self._graph + + def forward(self, inputs, *args, **kwargs): + self._graph = Graph() + _add_call_hook(self._model, self._analyze_modules_op) + self._model(inputs, *args, **kwargs) + _remove_call_hook(self._model) + + def _analyze_modules_op(self, op, inputs, *args, **kwargs): + node = self._trace_in(op, inputs) + #print(f"inputs: {inputs.name}") + outputs = op.__forward_orig__(inputs, *args, **kwargs) + #print(f"outputs: {outputs.name}") + self._trace_out(node, outputs) + return outputs + + def _call_layer(self, layer): + layer_name = layer.full_name() + if layer_name not in self._call_count: + self._call_count[layer_name] = 0 + self._call_count[layer_name] += 1 + return self._call_count[layer_name] + + def _trace_in(self, layer, inputs): + inputs = self._convert_to_list(inputs) + call_cout = self._call_layer(layer) + current_node = Node(layer, call_cout) + + if current_node.name not in self._graph._name2node: + self._graph._name2node[current_node.name] = current_node + current_node = self._graph._name2node[current_node.name] + for inp in inputs: + last_node = self._tensor_previous.get(inp.name, None) + if last_node is not None: + assert isinstance(last_node, Node) + if last_node not in current_node.previous_nodes: + current_node.previous_nodes.append(last_node) + if current_node not in last_node.next_nodes: + last_node.next_nodes.append(current_node) + + return current_node + + def _trace_out(self, current_node, outputs): + assert current_node is not None, "The current node has not been visited." + if current_node.is_leaf(): + outputs = self._convert_to_list(outputs) + for out in outputs: + self._tensor_previous[out.name] = current_node + + def _convert_to_list(self, tensors): + """ Convert tensor to list. + It is important to convert the inputs to a list. + Because visiting the tensor by 'for ... in' will create new + temp variables and break the tracing process. + """ + if isinstance(tensors, paddle.Tensor): + return [tensors] + elif isinstance(tensors, (list, tuple)): + for _t in tensors: + assert isinstance(_t, paddle.Tensor) + return tensors + raise TypeError( + f"Unsopported type: {type(tensors)}; The inputs type should be paddle.Tensor' or list of paddle.Tensor." + ) diff --git a/paddleslim/quant/__init__.py b/paddleslim/quant/__init__.py index cdd60119..8bd83a9c 100755 --- a/paddleslim/quant/__init__.py +++ b/paddleslim/quant/__init__.py @@ -39,3 +39,6 @@ except Exception as e: "please use Paddle >= {} or develop version".format(min_paddle_version)) from .quant_embedding import quant_embedding +from . import nn +from .qat import SlimQAT +from .config import SlimQuantConfig diff --git a/paddleslim/quant/config.py b/paddleslim/quant/config.py new file mode 100644 index 00000000..3c3a14a4 --- /dev/null +++ b/paddleslim/quant/config.py @@ -0,0 +1,31 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.quantization.config import QuantConfig +from paddle.quantization.factory import QuanterFactory + + +class SlimQuantConfig(QuantConfig): + def __init__(self, activation: QuanterFactory, weight: QuanterFactory): + super(SlimQuantConfig, self).__init__(activation, weight) + self._constraints = [] + + @property + def constraints(self): + return self._constraints + + def add_constraints(self, constraints): + if not isinstance(constraints, (list, tuple)): + constraints = [constraints] + self._constraints.extend(constraints) \ No newline at end of file diff --git a/paddleslim/quant/constraints/__init__.py b/paddleslim/quant/constraints/__init__.py new file mode 100644 index 00000000..28a2a739 --- /dev/null +++ b/paddleslim/quant/constraints/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .constraint import FusionConstraint, Constraint +from .conv_bn_constraints import FreezedConvBNConstraint + +__all__ = ["Constraint", "FusionConstraint", "FreezedConvBNConstraint"] diff --git a/paddleslim/quant/constraints/constraint.py b/paddleslim/quant/constraints/constraint.py new file mode 100644 index 00000000..43acad98 --- /dev/null +++ b/paddleslim/quant/constraints/constraint.py @@ -0,0 +1,91 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import abc +from typing import List +from paddleslim.core.graph import Graph +from paddleslim.quant.config import SlimQuantConfig + +__all__ = ["FusionConstraint", "Constraint"] + + +class Constraint(metaclass=abc.ABCMeta): + """ Constraints are the rules applied on layers during quantization-aware training. + The abstract class defined the interface of constraint. All the constraints should + extend this abstract class and implement the 'apply' function. 
+ """ + + @abc.abstractmethod + def apply(self, + model: paddle.nn.Layer, + graph: Graph, + qconfig: SlimQuantConfig) -> None: + """ Apply the rules of the constraint on the model in place. + Usually, there are three steps to implement the apply function: + 1. Find all the targets to be constrained. Each target can be one layer or a group of layers. + 2. If the target is a group of layers, fuse it into a fusion layer. + 3. Mapping the fusion layer to a QAT strategy layer by the "add_qat_layer_mapping" function of "qconfig". + + Args: + - model(paddle.nn.Layer): The model to be constrained. + - graph(Graph): A structure stored the DAG of executing model. It is used to + help find layers by some pattern. + - qconfig(SlimQuantConfig): The configuration of quantization. + """ + pass + + +class FusionConstraint(Constraint, metaclass=abc.ABCMeta): + """ Define some functions used to fuse operators. + """ + + def _find_parent_and_sublayer_name(self, model, layer): + for _name, _sub_layer in model.named_sublayers(): + if layer.full_name() == _sub_layer.full_name(): + return model, _name + else: + result = self._find_parent_and_sublayer_name(_sub_layer, layer) + if result is not None: + return result + + def _replace_layer(self, model, source, target): + result = self._find_parent_and_sublayer_name(model, source) + assert result is not None, f"Can't find the parent layer of {source.full_name()} in model {model.full_name()}!!!" + parent_layer, sub_name = result + parent_layer._sub_layers[sub_name] = target + setattr(parent_layer, sub_name, target) + + def fuse_ops(self, + model: paddle.nn.Layer, + fused_layer_type, + layers: List[paddle.nn.Layer], + config: SlimQuantConfig): + """ Fuse layers into fusion layer. + Args: + - model(paddle.nn.Layer): The model whose sublayers will be fused. + - fused_layer_type(type): Type of fusion layer. + - layers(list): List of layers to be fused. It will be the arguments to create 'fused_layer_type' instance + by calling fused_layer_type(*layers). + - configs(SlimQuantConfig): The configuration of quantization. + """ + fused_layer = fused_layer_type(*layers) + + for _layer in layers[1:]: + self._replace_layer(model, _layer, paddle.nn.Identity()) + if _layer in config._layer2config: + config._layer2config.pop(_layer) + + self._replace_layer(model, layers[0], fused_layer) + config._layer2config[fused_layer] = config._layer2config[layers[0]] diff --git a/paddleslim/quant/constraints/conv_bn_constraints.py b/paddleslim/quant/constraints/conv_bn_constraints.py new file mode 100644 index 00000000..2b89b782 --- /dev/null +++ b/paddleslim/quant/constraints/conv_bn_constraints.py @@ -0,0 +1,61 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .constraint import FusionConstraint +from ..nn import Conv2DBatchNormWrapper, QuantedConv2DBatchNorm + +FUSED_LAYER = Conv2DBatchNormWrapper +QAT_FUSED_LAYER = QuantedConv2DBatchNorm + + +class FreezedConvBNConstraint(FusionConstraint): + """ Simulate the folding of convolution and batch norm with a correction strategy. + Reference to: Quantizing deep convolutional networks for efficient inference: A whitepaper + Args: + - freeze_bn_delay(int): After the 'freeze_bn_delay' steps training, switch from using batch + statistics to long-term moving averages for batch normalization. + + Examples: + .. code-block:: python + + import paddle + from paddleslim.quant import SlimQuantConfig + from paddleslim.quant import SlimQAT + from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver + from paddle.vision.models import resnet18 + + model = resnet18() + quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9) + q_config = SlimQuantConfig(activation=quanter, weight=quanter) + q_config.add_constraints(FreezedConvBNConstraint(freeze_bn_delay=1)) + qat = SlimQAT(q_config) + quant_model = qat.quantize(model, inplace=True, inputs=x) + + """ + + def __init__(self, freeze_bn_delay=0): + self._freeze_bn_delay = freeze_bn_delay + + def apply(self, model, graph, config): + conv_bn_pairs = graph.find_conv_bn() + for pair in conv_bn_pairs: + pair = [node._layer for node in pair] + self.fuse_ops(model, FUSED_LAYER, pair, config) + config.add_qat_layer_mapping(FUSED_LAYER, QAT_FUSED_LAYER) + + def _set_freeze_bn_delay(layer): + if isinstance(layer, FUSED_LAYER): + setattr(layer, "_freeze_bn_delay", self._freeze_bn_delay) + + model.apply(_set_freeze_bn_delay) diff --git a/paddleslim/quant/nn/__init__.py b/paddleslim/quant/nn/__init__.py new file mode 100644 index 00000000..abd8b725 --- /dev/null +++ b/paddleslim/quant/nn/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .conv_bn import QuantedConv2DBatchNorm, Conv2DBatchNormWrapper + +__all__ = ["QuantedConv2DBatchNorm", "Conv2DBatchNormWrapper"] diff --git a/paddleslim/quant/nn/conv_bn.py b/paddleslim/quant/nn/conv_bn.py new file mode 100644 index 00000000..d63ab27b --- /dev/null +++ b/paddleslim/quant/nn/conv_bn.py @@ -0,0 +1,259 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Wrapper layer to simulate folding batch norms during quantization-aware training. 
+""" +import paddle +from paddle.nn import Layer +from paddle.nn import functional as F +from paddle.nn.layer.norm import _BatchNormBase +from paddle.nn.quant.format import ConvertibleQuantedLayer + + +class BatchNorm(Layer): + r"""Wrapper of batchnorm layer. It is used to get the mean and variance of cerrent mini-batch. + Args: + - bn(paddle.nn.layer.norm._BatchNormBase): The batchnorm layer to be wrapped. + """ + + def __init__(self, bn: _BatchNormBase): + super(BatchNorm, self).__init__() + self._bn = bn + self._runing_mean = self._bn._mean + self._runing_variance = self._bn._variance + self._weight = self._bn.weight + self._bias = self._bn.bias + self._batch_mean = None + self._batch_variance = None + self._training = True + self._batch_mean = None + self._batch_variance = None + assert self._bias is not None + + def forward(self, input): + r""" + Args: + - input(Tensor): The input to be normalized. + Return: + A tuple with 3 elements that is normalized output, mean of current mini-natch and variance of + current mini-batch. + """ + ( + batch_norm_out, _, _, batch_mean, batch_variance, _, + ) = paddle._C_ops.batch_norm( + input, + self. + _runing_mean, # TODO: should freeze this runing mean to avoid instability in training + self. + _runing_variance, # TODO: should freeze this runing variance to avoid instability in training + self._weight, + self._bias, + not self._training, + self._bn._momentum, + self._bn._epsilon, + self._bn._data_format, + False, # use_global_stats + False, # trainable_statistics + ) + batch_mean.stop_gradient = True + batch_variance.stop_gradient = True + self._batch_mean = batch_mean + self._batch_variance = batch_variance + return batch_norm_out, batch_mean, batch_variance + + +class Conv2DBatchNormWrapper(paddle.nn.Layer): + def __init__(self, conv, bn): + super(Conv2DBatchNormWrapper, self).__init__() + self._conv = conv + self._bn = bn + self._freeze_bn_delay = None + + def forward(self, inputs): + return self._bn(self._conv(inputs)) + + +class QuantedConv2DBatchNorm(ConvertibleQuantedLayer): + r"""Wrapper layer to simulate folding batch norms during quantization-aware training. + Fisrtly, it will execute once convolution and batch norms prior to quantizing weights + to get the long term statistics(i.e. runing mean and runing variance). + And it always scale the convolution's weights with a correction factor to the long term + statistics prior to quantization. This ensures that there is no jitter in the quantized + weights due to batch to batch variation. + Then the training will be divided into two phases: + 1. During the initial phase of training, with freeze_bn is false. It undo the scaling of + the weights so that outputs are identical to regular batch normalization. + 2. After sufficient training, with freeze_bn is true. Switch from using batch statistics + to long term moving averages for batch normalization. 
+ Reference: Quantizing deep convolutional networks for efficient inference: A whitepaper + """ + + def __init__(self, conv_bn: Conv2DBatchNormWrapper, q_config): + super(QuantedConv2DBatchNorm, self).__init__() + conv = conv_bn._conv + bn = conv_bn._bn + self._freeze_bn_delay = conv_bn._freeze_bn_delay + # For Conv2D + self._groups = getattr(conv, '_groups') + self._stride = getattr(conv, '_stride') + self._padding = getattr(conv, '_padding') + self._padding_mode = getattr(conv, '_padding_mode') + if self._padding_mode != 'zeros': + self._reversed_padding_repeated_twice = getattr( + conv, '_reversed_padding_repeated_twice') + self._dilation = getattr(conv, '_dilation') + self._data_format = getattr(conv, '_data_format') + self.conv_weight = getattr(conv, 'weight') + self.conv_bias = getattr(conv, 'bias') + + if self.conv_bias is None: + self.conv_bias = paddle.create_parameter( + bn.bias.shape, dtype=self.conv_weight.dtype, is_bias=True) + + self.bn = BatchNorm(bn) + + self.weight_quanter = None + self.activation_quanter = None + if q_config.weight is not None: + self.weight_quanter = q_config.weight._instance(conv) + if q_config.activation is not None: + self.activation_quanter = q_config.activation._instance(conv) + + self._freeze_bn = False + + self._steps = 0 + + def freeze_bn(self): + self._freeze_bn = True + + def unfreeze_bn(self): + self._freeze_bn = False + + def forward(self, input): + if self.training: + if self._freeze_bn_delay == self._steps: + self.freeze_bn() + + if self._freeze_bn: + out = self._forward_with_bn_freezed(input) + else: + out = self._forward_with_bn_unfreezed(input) + self._steps += 1 + return out + else: + return self._eval_forward(input) + + def _forward_with_bn_unfreezed(self, input): + quant_input = input + if self.activation_quanter is not None: + quant_input = self.activation_quanter(quant_input) + + # Step 1: Excute conv bn to get global runing mean and variance + _, _, batch_variance = self._conv_bn(quant_input, self.conv_weight) + # Step 2: Merge and quantize weights of conv and bn + merged_weight = self._merge_conv_bn_weight(self.conv_weight, self.bn) + quant_weight = merged_weight + if self.weight_quanter is not None: + quant_weight = self.weight_quanter(quant_weight) + # Step 3: Excute conv with merged weights + conv_out = self._conv_forward(quant_input, quant_weight) + # Step 4: Scale output of conv and merge bias + factor = paddle.reshape((self.bn._runing_variance / batch_variance), + [1, -1, 1, 1]) + factor.stop_gradient = True + conv_out = conv_out * factor + merged_bias = self._merge_conv_unfreezed_bn_bias( + self.conv_bias, self.bn) + return conv_out + paddle.reshape(merged_bias, [1, -1, 1, 1]) + + def _eval_forward(self, input): + quant_input = input + if self.activation_quanter is not None: + quant_input = self.activation_quanter(quant_input) + if self.weight_quanter is not None: + quant_weight = self.weight_quanter(self.conv_weight) + conv_out = self._conv_forward(quant_input, quant_weight) + conv_out = conv_out if self.conv_bias is None else conv_out + paddle.reshape( + self.conv_bias, [1, -1, 1, 1]) + return conv_out + + def _forward_with_bn_freezed(self, input): + quant_input = input + if self.activation_quanter is not None: + quant_input = self.activation_quanter(quant_input) + # Step 1: Excute conv bn to get global runing mean and variance + self._conv_bn(quant_input, self.conv_weight) + # Step 2: Merge and quantize weights of conv and bn + merged_weight = self._merge_conv_bn_weight(self.conv_weight, self.bn) + quant_weight = 
merged_weight + if self.weight_quanter is not None: + quant_weight = self.weight_quanter(quant_weight) + # Step 3: Excute conv with merged weights + conv_out = self._conv_forward(quant_input, quant_weight) + # Step 4: Merge bias of conv and bn + merged_bias = self._merge_conv_freezed_bn_bias(self.conv_bias, self.bn) + return conv_out + paddle.reshape(merged_bias, [1, -1, 1, 1]) + + def _merge_conv_bn_weight(self, conv_weight, bn: BatchNorm): + merged_weight = paddle.reshape( + bn._weight, [-1, 1, 1, 1]) * conv_weight / paddle.reshape( + bn._runing_variance, [-1, 1, 1, 1]) + conv_weight.set_value(merged_weight) + return conv_weight + + def _merge_conv_freezed_bn_bias(self, conv_bias, bn: BatchNorm): + assert conv_bias is not None + + merged_bias = (conv_bias + bn._bias - + (bn._weight * bn._runing_mean) / bn._runing_variance) + conv_bias.set_value(merged_bias) + return conv_bias + + def _merge_conv_unfreezed_bn_bias(self, conv_bias, bn: BatchNorm): + assert conv_bias is not None + merged_bias = (conv_bias + bn._bias - + bn._weight * bn._batch_mean / bn._batch_variance) + conv_bias.set_value(merged_bias) + return conv_bias + + def _conv_bn(self, input, conv_weight): + return self._bn_forward(self._conv_forward(input, conv_weight)) + + def _bn_forward(self, input): + return self.bn.forward(input) + + def _conv_forward(self, inputs, weights): + if self._padding_mode != 'zeros': + inputs = F.pad( + inputs, + self._reversed_padding_repeated_twice, + mode=self._padding_mode, + data_format=self._data_format, ) + self._padding = 0 + + return F.conv2d( + inputs, + weights, + bias=self.conv_bias, + padding=self._padding, + stride=self._stride, + dilation=self._dilation, + groups=self._groups, + data_format=self._data_format, ) + + def weights_to_quanters(self): + return [('conv_weight', 'weight_quanter')] + + def activation_quanters(self): + return ['activation_quanter'] diff --git a/paddleslim/quant/qat.py b/paddleslim/quant/qat.py new file mode 100644 index 00000000..e3e3d2d0 --- /dev/null +++ b/paddleslim/quant/qat.py @@ -0,0 +1,42 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +import copy +from paddle.quantization import QAT, QuantConfig +from ..core import GraphTracer + + +class SlimQAT(QAT): + def __init__(self, q_config): + super(SlimQAT, self).__init__(q_config) + + def _apply_constraints(self, model, graph): + for _contraint in self._config.constraints: + _contraint.apply(model, graph, self._config) + + def _analysis_model(self, _model, inputs): + assert inputs is not None + tracer = GraphTracer(_model) + tracer(inputs) + return tracer.graph + + def quantize(self, model: paddle.nn.Layer, inplace=False, inputs=None): + _model = model if inplace else copy.deepcopy(model) + self._config._specify(_model) + graph = self._analysis_model(_model, inputs) + self._apply_constraints(_model, graph) + self._convert_to_quant_layers(_model, self._config) + self._insert_activation_observers(_model, self._config) + return _model diff --git a/tests/quantization/test_conv_bn_constraints.py b/tests/quantization/test_conv_bn_constraints.py new file mode 100644 index 00000000..78da1ce1 --- /dev/null +++ b/tests/quantization/test_conv_bn_constraints.py @@ -0,0 +1,103 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +import unittest +import tempfile +sys.path.append("../../") +import paddle +from paddle.vision.models import resnet18 +from paddleslim.quant import SlimQuantConfig as QuantConfig +from paddleslim.quant import SlimQAT +from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver +from paddleslim.quant.nn.conv_bn import QuantedConv2DBatchNorm +from paddleslim.quant.constraints import FreezedConvBNConstraint +from test_qat import TestQuantAwareTraining, load_model_and_count_layer + + +class TestConvBNConstraintsBaseCase(TestQuantAwareTraining): + """ Common cases for testing 'quantize', 'convert' and 'jit.save' function.""" + + def extra_qconfig(self, qconfig): + qconfig.add_constraints(FreezedConvBNConstraint(freeze_bn_delay=1)) + + +class TestConvBNConstraints(unittest.TestCase): + """ More special cases on convolution and batch norm constraints.""" + + def setUp(self): + paddle.set_device("cpu") + self.temp_dir = tempfile.TemporaryDirectory(dir="./") + self.path = os.path.join(self.temp_dir.name, 'conv_bn_constraints') + + def tearDown(self): + self.temp_dir.cleanup() + + def _count_layers(self, model, layer_type): + count = 0 + for _layer in model.sublayers(True): + if isinstance(_layer, layer_type): + count += 1 + return count + + def _get_one_layer(self, model, layer_type): + for _layer in model.sublayers(True): + if isinstance(_layer, layer_type): + return _layer + return None + + def test_conv_bn(self): + model = resnet18() + conv_count = self._count_layers(model, paddle.nn.Conv2D) + quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9) + q_config = QuantConfig(activation=quanter, weight=quanter) + # It will freeze the batch normaliztion after 'freeze_bn_delay' steps + q_config.add_constraints(FreezedConvBNConstraint(freeze_bn_delay=1)) + + qat = 
SlimQAT(q_config) + x = paddle.rand([1, 3, 224, 224]) + paddle.jit.save(model, "./test_model", input_spec=[x]) + quant_model = qat.quantize(model, inplace=True, inputs=x) + + # check freeze_bn_delay + qat_conv_bn_layer = self._get_one_layer(quant_model, + QuantedConv2DBatchNorm) + self.assertIsNotNone(qat_conv_bn_layer) + self.assertFalse(qat_conv_bn_layer._freeze_bn) + quant_model.train() + out = quant_model(x) + out.backward() + out = quant_model(x) + out.backward() + self.assertTrue(qat_conv_bn_layer._freeze_bn) + + # check the count of QAT layers in QAT model + qat_layer_count = self._count_layers(quant_model, + QuantedConv2DBatchNorm) + self.assertEqual(qat_layer_count, conv_count) + + # check the count of convolution and batch norm in saved static graph + quant_model.eval() + infer_model = qat.convert(quant_model, inplace=True) + save_path = os.path.join(self.path, 'infer_model') + paddle.jit.save(infer_model, save_path, input_spec=[x]) + layer2count = load_model_and_count_layer(save_path, + ['conv2d', 'batch_norm']) + self.assertEqual(layer2count['conv2d'], conv_count) + self.assertEqual(layer2count['batch_norm'], 0) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/quantization/test_graph_tracer.py b/tests/quantization/test_graph_tracer.py new file mode 100644 index 00000000..059ad62e --- /dev/null +++ b/tests/quantization/test_graph_tracer.py @@ -0,0 +1,58 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +import unittest +import tempfile +sys.path.append("../../") +import paddle +from paddle.vision.models import resnet18 +from paddleslim.core.graph_tracer import GraphTracer +from paddle.vision.models import resnet18 + + +class TestConvBNConstraints(unittest.TestCase): + def _count_layers(self, model, layer_type): + count = 0 + for _layer in model.sublayers(True): + if isinstance(_layer, layer_type): + count += 1 + return count + + def test_conv_bn(self): + paddle.set_device("cpu") + model = resnet18() + conv_count = self._count_layers(model, paddle.nn.Conv2D) + x = paddle.rand([1, 3, 224, 224]) + tracer = GraphTracer(model) + tracer(x) + conv_bn_pairs = tracer.graph.find_conv_bn() + self.assertEqual(len(conv_bn_pairs), conv_count) + + conv_names = set() + bn_names = set() + for _layer in model.sublayers(): + if isinstance(_layer, paddle.nn.Conv2D): + conv_names.add(_layer.full_name()) + elif isinstance(_layer, paddle.nn.BatchNorm2D): + bn_names.add(_layer.full_name()) + + for _conv, _bn in conv_bn_pairs: + self.assertTrue(_bn.layer_name in bn_names) + self.assertTrue(_conv.layer_name in conv_names) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/quantization/test_qat.py b/tests/quantization/test_qat.py new file mode 100644 index 00000000..e1d933a6 --- /dev/null +++ b/tests/quantization/test_qat.py @@ -0,0 +1,139 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +import unittest +sys.path.append("../../") +import paddle +import tempfile +from paddle.vision.models import resnet18 +from paddleslim.quant import SlimQuantConfig as QuantConfig +from paddleslim.quant import SlimQAT +from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver +from paddle.quantization.quanters.abs_max import FakeQuanterWithAbsMaxObserverLayer +from paddle.nn.quant.format import LinearDequanter, LinearQuanter + + +def load_model_and_count_layer(model_path, layer_types): + layer2count = dict([(_layer, 0) for _layer in layer_types]) + paddle.enable_static() + exe = paddle.static.Executor(paddle.CPUPlace()) + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + [ + inference_program, + feed_target_names, + fetch_targets, + ] = paddle.static.load_inference_model(model_path, exe) + for _op in inference_program.global_block().ops: + if _op.type in layer2count: + layer2count[_op.type] += 1 + paddle.disable_static() + return layer2count + + +class TestQuantAwareTraining(unittest.TestCase): + def setUp(self): + paddle.set_device("cpu") + self.init_case() + self.dummy_input = paddle.rand([1, 3, 224, 224]) + self.temp_dir = tempfile.TemporaryDirectory(dir="./") + self.path = os.path.join(self.temp_dir.name, 'qat') + + def tearDown(self): + self.temp_dir.cleanup() + + def extra_qconfig(self, q_config): + """The subclass of TestQuantAwareTraining can implement the function to add more special configuration.""" + pass + + def init_case(self): + quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9) + self.quantizer_type = FakeQuanterWithAbsMaxObserverLayer + self.quantizer_type_in_static = "fake_quantize_dequantize_moving_average_abs_max" + self.q_config = QuantConfig(activation=None, weight=None) + self.q_config.add_type_config( + paddle.nn.Conv2D, activation=quanter, weight=quanter) + self.extra_qconfig(self.q_config) + + def _count_layers(self, model, layer_type): + count = 0 + for _layer in model.sublayers(True): + if isinstance(_layer, layer_type): + count += 1 + return count + + def test_quantize(self): + model = resnet18() + conv_count = self._count_layers(model, paddle.nn.Conv2D) + qat = SlimQAT(self.q_config) + quant_model = qat.quantize(model, inplace=True, inputs=self.dummy_input) + quant_model.train() + out = quant_model(self.dummy_input) + out.backward() + quantizer_cnt = self._count_layers(quant_model, self.quantizer_type) + self.assertEqual(quantizer_cnt, conv_count * 2) + + def test_convert(self): + model = resnet18() + conv_count = self._count_layers(model, paddle.nn.Conv2D) + qat = SlimQAT(self.q_config) + quant_model = qat.quantize(model, inplace=True, inputs=self.dummy_input) + converted_model = qat.convert(quant_model, inplace=True) + + # check count of LinearQuanter and LinearDequanter in dygraph + quantizer_count_in_dygraph = self._count_layers(converted_model, + LinearQuanter) + 
dequantizer_count_in_dygraph = self._count_layers( + converted_model, LinearDequanter) + + self.assertEqual(quantizer_count_in_dygraph, conv_count) + self.assertEqual(dequantizer_count_in_dygraph, conv_count * 2) + + # check count of LinearQuanter and LinearDequanter in static model saved by jit.save + save_path = os.path.join(self.path, 'converted_model') + quant_model.eval() + paddle.jit.save(quant_model, save_path, input_spec=[self.dummy_input]) + + layer2count = load_model_and_count_layer( + save_path, ["quantize_linear", "dequantize_linear"]) + quantizer_count_in_static_model = layer2count['quantize_linear'] + dequantizer_count_in_static_model = layer2count['dequantize_linear'] + self.assertEqual(quantizer_count_in_dygraph, + quantizer_count_in_static_model) + self.assertEqual(dequantizer_count_in_dygraph, + dequantizer_count_in_static_model) + + def test_trace_qat_graph(self): + model = resnet18() + qat = SlimQAT(self.q_config) + quant_model = qat.quantize(model, inplace=True, inputs=self.dummy_input) + quantizer_count_indygraph = self._count_layers(quant_model, + self.quantizer_type) + save_path = os.path.join(self.path, 'qat_model') + quant_model.eval() + paddle.jit.save(quant_model, save_path, input_spec=[self.dummy_input]) + + layer2count = load_model_and_count_layer( + save_path, [self.quantizer_type_in_static]) + quantizer_count_in_static_model = layer2count[ + self.quantizer_type_in_static] + self.assertEqual(quantizer_count_indygraph, + quantizer_count_in_static_model) + + +if __name__ == '__main__': + unittest.main() -- GitLab