diff --git a/demo/quant/quant_post/quant_post.py b/demo/quant/quant_post/quant_post.py
index 2e3f7b40f9a8afda86cf8bbd9c001816a8e2ea04..b9fed6cb1d65b87981de7c505e78f10e7b7d53c3 100755
--- a/demo/quant/quant_post/quant_post.py
+++ b/demo/quant/quant_post/quant_post.py
@@ -22,13 +22,13 @@ parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.")
-add_arg('batch_num', int, 1, "Batch number")
+add_arg('batch_num', int, 10, "Batch number")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model_path', str, "./inference_model/MobileNetV1_infer/", "model dir")
add_arg('save_path', str, "./quant_model/MobileNet/", "model dir to save quanted model")
add_arg('model_filename', str, 'inference.pdmodel', "model file name")
add_arg('params_filename', str, 'inference.pdiparams', "params file name")
-add_arg('algo', str, 'hist', "calibration algorithm")
+add_arg('algo', str, 'avg', "calibration algorithm")
add_arg('round_type', str, 'round', "The method of converting the quantized weights.")
add_arg('hist_percent', float, 0.9999, "The percentile of algo:hist")
add_arg('is_full_quantize', bool, False, "Whether is full quantization or not.")
diff --git a/example/quantization/qat_with_constraints/how_to_qat_with_constraint.ipynb b/example/quantization/qat_with_constraints/how_to_qat_with_constraint.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..1d9ce5c1ca52b657adf5e4048e7b7514c8aaca9f
--- /dev/null
+++ b/example/quantization/qat_with_constraints/how_to_qat_with_constraint.ipynb
@@ -0,0 +1,220 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# QAT with convolution and batchnorm fused constraints"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "本教程以Conv2D和BatchNorm的融合为例,介绍如何使用PaddleSlim接口快速为量化训练添加训练约束(constraints)。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. 添加依赖"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import paddle\n",
+ "from paddle.vision.models import resnet18\n",
+ "from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver\n",
+ "from paddleslim.quant import SlimQuantConfig as QuantConfig\n",
+ "from paddleslim.quant import SlimQAT\n",
+ "from paddleslim.quant.constraints import FreezedConvBNConstraint\n",
+ "paddle.set_device(\"cpu\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. 构造模型"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = resnet18()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. 量化训练\n",
+ "\n",
+ "**配置**\n",
+ "\n",
+ "构造量化配置实例,并将激活和权重的量化方式指定为基础的 AbsMax 量化策略。然后,调用 `add_constraints` 配置一个Constraints实例。\n",
+ "最后,使用量化配置构造SlimQAT对象。代码如下所示:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)\n",
+ "q_config = QuantConfig(activation=quanter, weight=quanter)\n",
+ "q_config.add_constraints(FreezedConvBNConstraint())\n",
+ "qat = SlimQAT(q_config)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**量化**\n",
+ "\n",
+ "调用SlimQAT对象的quantize方法,将模型转为用于量化训练的模型。\n",
+ "用户需要指定一个inputs, 用于推断分析动态图的执行拓扑图,以便自动检测出所有Conv2D和BN的组合。\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x = paddle.rand([1, 3, 224, 224])\n",
+ "quant_model = qat.quantize(model, inplace=True, inputs=x)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "在该步骤,所有的Conv2D和BN的组合,都会被替换为QuantedConv2DBatchNorm layer。在QuantedConv2DBatchNorm中,参考 [Quantizing deep convolutional networks for efficient inference: A whitepaper](https://arxiv.org/abs/1806.08342) ,在量化训练过程中模拟Conv2D和BatchNorm的融合。原理如下图所示:\n",
+ "\n",
+ "
\n",
+ "Conv BN 量化训练矫正方案示意图
\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**训练**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "quant_model.train()\n",
+ "out = quant_model(x)\n",
+ "out.backward()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**查看量化训练模型结构**\n",
+ "\n",
+ "直接在终端输出量化训练模型的结构,或者将模型保存到文件系统,并使用[netron](https://netron.app/)查看模型结构。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(quant_model)\n",
+ "quant_model.eval()\n",
+ "paddle.jit.save(quant_model, \"./qat_model\", input_spec=[x])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "模拟量化训练模型结构如下图所示:\n",
+ "\n",
+ "\n",
+ "模拟量化模型结构示意图
\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**保存推理模型**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "infer_model = qat.convert(quant_model, inplace=True)\n",
+ "print(infer_model)\n",
+ "paddle.jit.save(infer_model, \"./infer_model\", input_spec=[x])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "推理模型结构如下图所示:\n",
+ "\n",
+ "\n",
+ "量化推理模型结构示意图
"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.3"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "2f394aca7ca06fed1e6064aef884364492d7cdda3614a461e02e6407fc40ba69"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/example/quantization/qat_with_constraints/what_is_constraints.md b/example/quantization/qat_with_constraints/what_is_constraints.md
new file mode 100644
index 0000000000000000000000000000000000000000..dfe1c9280a820b3259d1573a2b67f2e1640c917d
--- /dev/null
+++ b/example/quantization/qat_with_constraints/what_is_constraints.md
@@ -0,0 +1,135 @@
+# How QAT Constraints Work
+
+## 1. Overview
+Constraints are rules imposed on a model during simulated-quantization training. The objects being constrained include, but are not limited to: model parameters, quantization parameters, and the execution of layers.
+Why constraints are needed: the forward numerics of simulated quantization must match the forward numerics of inference.
+Why the constraint module is independent and extensible: different inference libraries and hardware support quantized inference to different degrees and in different ways, so different constraints must be added during simulated-quantization training.
+
+## 2. Types of constraints
+
+
+### 2.1 Conv/MatMul bias constraint
+Consider the computation graph shown in Figure 1: an FP32 Weight and Input go through a matrix multiplication, an FP32 Bias of the same FP32 type is added, and the final FP32 output is passed to the following layer, here a SkipLayernorm.
+
+
+Figure 1 - Schematic of the Matmul + Add computation
+
+For this computation, three simulated Int8 quantization schemes can be listed, each corresponding to a different Int8 inference implementation.
+
+#### The first quantization scheme
+
+This scheme needs no special constraint; it is just regular simulated-quantization training. It is described here as a baseline against which the other schemes can be compared.
+As shown in the figure below, during simulated quantization a per-tensor QDQ layer (a quantize-dequantize pair) is inserted after each of Matmul's two inputs to simulate quantizing the weight and the input.
+The quantization scales of weight and input are $S_w$ and $S_i$ respectively. The right-hand side of the figure shows the inference implementation, in which:
+
+- The weight's simulated-quantization layer is removed and the weight is stored in Int8
+- The input's simulated-quantization layer is replaced by a quantize layer, which quantizes the FP32 input to Int8 with the scale $S_i$ collected during simulated quantization
+- Matmul multiplies in Int8 and accumulates the results into Int32
+- The per-tensor dequantize layer after Matmul converts the Int32 values back to FP32 with scale $S_w * S_i$
+- Add is computed in FP32 as usual
+
+
+
+Figure 2 - Schematic of the first quantization scheme
+
+The formulas are:
+
+$$S_w = \frac{AbsMax_w}{2^{8-1}}$$
+$$S_i = \frac{AbsMax_i}{2^{8-1}}$$
+$$QWeight = round(\frac{weight}{S_w})$$
+$$QInput = round(\frac{input}{S_i})$$
+$$DqOutput = output * S_w * S_i$$
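+
+To make the arithmetic concrete, here is a minimal NumPy sketch of this scheme (the names and shapes are illustrative, not part of PaddleSlim):
+
+```python
+import numpy as np
+
+def absmax_scale(x, bits=8):
+    # S = AbsMax(x) / 2^(bits-1)
+    return np.abs(x).max() / (2**(bits - 1))
+
+weight = np.random.randn(64, 32).astype("float32")
+inp = np.random.randn(32, 8).astype("float32")
+bias = np.random.randn(64, 1).astype("float32")
+
+s_w, s_i = absmax_scale(weight), absmax_scale(inp)
+q_weight = np.round(weight / s_w)   # QWeight, stored as Int8 at inference time
+q_input = np.round(inp / s_i)       # QInput
+int32_acc = q_weight @ q_input      # Int8 multiply, accumulated into Int32
+fp32_out = int32_acc * (s_w * s_i)  # dequantize with S_w * S_i
+fp32_out = fp32_out + bias          # Add stays in FP32
+```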
+
+#### The second quantization scheme
+
+This scheme constrains the quantization scales:
+1. The bias quantization scale must equal Matmul's dequantization scale
+2. Matmul's quantization scales are limited by the bias: the quantized bias must stay within 16 bits (TI chips only)
+
+As shown by the model structure on the left of the figure below, not only Matmul's two inputs but also the Bias is simulated-quantized. Unlike conventional simulated quantization, the bias quantization scale is not collected from statistics of the bias; it is simply $S_w * S_i$, i.e. the dequantization scale of the Matmul layer.
+This is the constraint rule for Matmul+Bias: during simulated-quantization training, the Bias must be simulated-quantized with scale $S_w * S_i$. The constraint aligns training with the inference implementation shown on the right of the figure.
+At inference time, Matmul multiplies in Int8 and accumulates the products into an Int32 accumulator. To add the Bias into this accumulator as well, the Bias must be quantized to Int32, and the scale used to quantize it must match the accumulator's scale, i.e. $S_w * S_i$.
+
+
+Figure 3 - Schematic of the second quantization scheme
+
+The whole inference computation is:
+
+$$FP32Output = (QWeight * QInput + QBias) * S_w * S_i$$
+
+where:
+
+- $QWeight$: the quantized weight
+- $QInput$: computed as $round(\frac{FP32Input}{S_i})$
+- $QBias$: the quantized bias, computed as $round(\frac{FP32Bias}{S_w * S_i})$
+
+For TI chips the basic procedure is as above, but in some scenarios the quantized bias must additionally be limited to the Int16 range:
+
+$$\frac{AbsMax(FP32Bias)}{S_w * S_i} <= 2^{15}$$
+
+Simulated-quantization training must satisfy this formula, i.e. the second constraint listed at the beginning of this subsection.
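+
+Continuing the sketch from the first scheme, the constrained bias quantization and the TI 16-bit check could look like this (illustrative only, not the PaddleSlim API):
+
+```python
+s_bias = s_w * s_i                # constraint 1: bias scale equals Matmul's dequantization scale
+q_bias = np.round(bias / s_bias)  # QBias, added directly into the Int32 accumulator
+fp32_out = (q_weight @ q_input + q_bias) * s_bias
+
+# Constraint 2 (TI chips): the quantized bias must stay within 16 bits
+assert np.abs(bias).max() / s_bias <= 2**15
+```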
+
+#### The third quantization scheme
+
+This scheme only needs a constraint on the Add itself: the two inputs of the Add layer must share the same quantization scale.
+As shown by the model structure on the left of the figure below, Matmul and Add are independent, and each gets its own simulated-quantization layers. The quantization scale of Add's two inputs is unrelated to the scale of Matmul's inputs.
+At inference time, the bias cannot be accumulated directly into Matmul's Int32 accumulator, because the bias scale differs from the accumulator's. Instead, the Int32 accumulator is first dequantized to FP32 with $S_w * S_i$, the FP32 values are then quantized to Int8 with $S_o$ and added to the Int8 Bias, and the Add result is finally dequantized to FP32 with $S_o$. The whole inference computation is:
+
+$$FP32Output = (round(\frac{(QWeight * QInput) * S_w * S_i}{S_o}) + QBias) * S_o$$
+
+where:
+
+- $QBias$: the quantized bias, computed as $round(\frac{FP32Bias}{S_o})$
+
+In theory this scheme is less sound than the second one; it is listed here to contrast with the way the bias is quantized in the second scheme.
+
+
+Figure 4 - Schematic of the third quantization scheme
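+
+For comparison, the same sketch adapted to the third scheme, with an independent output scale $S_o$ (illustrative only):
+
+```python
+s_o = absmax_scale(fp32_out)   # output scale, unrelated to S_w * S_i
+q_bias = np.round(bias / s_o)  # QBias is quantized with S_o instead
+requant = np.round((q_weight @ q_input) * (s_w * s_i) / s_o)  # Int32 -> FP32 -> Int8
+fp32_out = (requant + q_bias) * s_o
+```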
+
+
+### 2.2 Convolution/mul BatchNorm constraint
+
+This subsection describes two ways of handling convolution + batch norm during simulated-quantization training.
+Batch normalization is widely used in training computer-vision models; it reduces the interaction between layers and makes training more stable and faster.
+At training time the formula is:
+
+$$x_{bn} = \gamma (\frac{x-\mu_{B}}{\sigma_B}) + \beta$$
+
+At inference time the formula is:
+
+$$x_{bn} = \gamma (\frac{x-\mu}{\sigma}) + \beta$$
+
+where $\mu_B$ and $\sigma_B$ are the mean and standard deviation of the current mini-batch, while $\mu$ and $\sigma$ are the mean and standard deviation aggregated over many batches during training.
+At inference time, the batch normalization is folded into the convolution or mul layer:
+
+$$W_{inference} = \frac{\gamma * W}{\sigma}$$
+$$Bias_{inference} = \beta - \frac{\gamma \mu}{\sigma}$$
+
+During simulated-quantization training, this folding of convolution and batch normalization has to be simulated.
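+
+A minimal NumPy sketch of the folding itself, assuming `sigma` holds the standard deviation $\sigma$ (the helper name is illustrative):
+
+```python
+def fold_bn_into_conv(conv_w, gamma, beta, mu, sigma):
+    # W_inference = gamma * W / sigma, scaled per output channel
+    w_inf = conv_w * (gamma / sigma).reshape([-1, 1, 1, 1])
+    # Bias_inference = beta - gamma * mu / sigma
+    b_inf = beta - gamma * mu / sigma
+    return w_inf, b_inf
+```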
+
+There are two options:
+- Strategy 1: Unfreeze BN, i.e. use the mean ($\mu_B$) and standard deviation ($\sigma_B$) of the current batch
+  - Drawback: unstable training. The per-batch statistic $\sigma_B$ varies considerably from batch to batch; after folding, the convolution weight varies with $\sigma_B$ as well, which ultimately makes the weight quantization scales unstable. At inference time, however, the convolution is folded with the globally aggregated $\sigma$, so the instability of the weight scales during training shows up as unstable inference accuracy.
+  - Symptom: the green curve in Figure 5.
+- Strategy 2: Freeze BN, i.e. use the mean $\mu$ and standard deviation $\sigma$ aggregated during training
+  - Drawback: the benefit of batch norm is lost; the mean and variance used to normalize the activations do not match the activation distribution of the current batch, making quantization training unstable
+  - Symptom:
+
+In strategy 1, the immediate cause of the instability is that the quantization scale of the weight changes frequently. Computing the weight scale with a moving average mitigates the problem but does not eliminate it, as the orange curve in Figure 5 shows.
+
+
+Figure 5 - Comparison of Conv BN training schemes
+
+
+The best scheme combines the strengths of strategies 1 and 2: let the weight use the global BN parameter $\sigma$, making the weight more stable, and let the activations use the per-batch statistic $\sigma_B$, making the activation BN more accurate. The scheme is shown in Figure 6.
+
+
+Figure 6 - Schematic of the Conv BN QAT correction scheme
+
+
+The scheme in Figure 6 has three key points:
+1. Keep the conv weight stable
+Before quantization, fold the globally aggregated BN parameter $\sigma$ into the conv weight, which keeps the weight quantization stable.
+2. Keep the activation BN accurate
+In the early phase of training, multiply by $\frac{\sigma}{\sigma_B}$ to undo the effect of the weight folding on the activations, so that the activations are effectively normalized by $\sigma_B$ and $\mu_B$, which is close to standard BN behavior at training time. In the figure, this corresponds to the freeze-batch-norm switch being 0.
+3. Freeze BN
+After a period of training with batch norm unfrozen, turn on the freeze-batch-norm switch so that the activation normalization also uses the global BN parameters, consistent with inference. This removes the training/inference discrepancy introduced by BN and lets training focus on quantization.
+The effect of this scheme is shown by the blue curve in Figure 5.
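+
+The three key points can be summarized in a schematic sketch, assuming `conv2d` and `quantize` helpers exist; it mirrors the idea behind PaddleSlim's QuantedConv2DBatchNorm rather than its exact code:
+
+```python
+def corrected_conv_bn_forward(x, conv_w, gamma, beta, mu, sigma,
+                              batch_mu, batch_sigma, freeze_bn, quantize):
+    # Key point 1: always fold the long-term sigma into the weight,
+    # so the quantized weight does not jitter with batch statistics.
+    w = conv_w * (gamma / sigma).reshape([-1, 1, 1, 1])
+    y = conv2d(x, quantize(w))
+    if freeze_bn:
+        # Key point 3: after freezing, normalize with the long-term
+        # statistics, exactly as at inference time.
+        return y + (beta - gamma * mu / sigma).reshape([1, -1, 1, 1])
+    # Key point 2: before freezing, multiply by sigma / sigma_B to undo
+    # the folding, so activations follow the current batch statistics.
+    y = y * (sigma / batch_sigma).reshape([1, -1, 1, 1])
+    return y + (beta - gamma * batch_mu / batch_sigma).reshape([1, -1, 1, 1])
+```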
diff --git a/paddleslim/core/__init__.py b/paddleslim/core/__init__.py
index e3076a7a646e5a85673944cf8fd2d7756a9ae354..91299ffd26f4887edbe86e17d3beb5e8d07ce70c 100644
--- a/paddleslim/core/__init__.py
+++ b/paddleslim/core/__init__.py
@@ -18,8 +18,11 @@ from ..core import registry
from .registry import *
from ..core import dygraph
from .dygraph import *
+from .graph_tracer import GraphTracer
+from .graph import Graph
__all__ = []
__all__ += graph_wrapper.__all__
__all__ += registry.__all__
__all__ += dygraph.__all__
+__all__ += ["GraphTracer", "Graph"]
diff --git a/paddleslim/core/graph.py b/paddleslim/core/graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..6929250a027683579952fd9e40608ebf92c03fcb
--- /dev/null
+++ b/paddleslim/core/graph.py
@@ -0,0 +1,113 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" A structure used to describe the network of a model."""
+
+from typing import List, Tuple
+import paddle
+
+__all__ = ["Graph"]
+
+
+class Node():
+ """ It is the node of the model's executing DAG. A node is a layer's once-calling.
+ Args:
+ layer(paddle.nn.Layer): The layer executed in the current node.
+ """
+
+ def __init__(self, layer, call_count):
+ self._layer = layer
+ self._layer_name = layer.full_name()
+ self._call_count = call_count
+ self._hash_name = f"{self._layer_name}_{self._call_count}"
+ self._next_nodes = []
+ self._previous_nodes = []
+
+ @property
+ def next_nodes(self):
+ """ Get the next nodes of the current node.
+
+ Returns:
+ A list of nodes representing the next nodes of the current node.
+
+ """
+ return self._next_nodes
+
+ @property
+ def previous_nodes(self):
+ """ Get the previous nodes of the current node.
+
+ Returns:
+ A list of nodes representing the previous nodes of the current node.
+ """
+ return self._previous_nodes
+
+ @property
+ def name(self):
+ return self._hash_name
+
+ @property
+ def layer_name(self) -> str:
+ """ Get the full name of the layer executed in the current node."""
+ return self._layer_name
+
+ @property
+ def layer(self) -> paddle.nn.Layer:
+ """ Get the layer executed in the current node."""
+ return self._layer
+
+ def is_leaf(self):
+ """ Whether this node is a leaf node. It is a leaf when the layer called by this node has no sublayers."""
+ return isinstance(self._layer,
+ paddle.nn.Layer) and len(self._layer.sublayers()) == 0
+
+    def __hash__(self):
+        return hash(self._hash_name)
+
+ def __str__(self):
+ return self._hash_name
+
+
+class Graph():
+ """ Directed Acyclic Graph used to describe the executing of the model.
+ Usually, the graph is built with tracer tools. There is no need to call the constructor directly.
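+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            from paddle.vision.models import resnet18
+            from paddleslim.core import GraphTracer
+
+            model = resnet18()
+            tracer = GraphTracer(model)
+            tracer(paddle.rand([1, 3, 224, 224]))
+            conv_bn_pairs = tracer.graph.find_conv_bn()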
+ """
+
+ def __init__(self):
+ self._name2node = {}
+
+ def __str__(self):
+ return str(self._name2node.keys())
+
+ @property
+ def nodes(self) -> List[Node]:
+ """ Get all the nodes in the graph.
+ """
+ return self._name2node.values()
+
+ def find_conv_bn(self) -> List[Tuple[Node, Node]]:
+ """ Find all the convolution and batch normalization pairs.
+ Returns:
+ A list of convolution and batch normalization pairs. Each pair is a tuple with two
+ nodes, the first being the convolution node and the second being the batch normalization node.
+ """
+ conv2d_types = [paddle.nn.Conv2D]
+ bn_types = [paddle.nn.BatchNorm2D]
+ results = []
+ for node in self.nodes:
+ if type(node.layer) in conv2d_types:
+ for _next_node in node.next_nodes:
+ if type(_next_node.layer) in bn_types:
+ results.append((node, _next_node))
+ break
+ return results
diff --git a/paddleslim/core/graph_tracer.py b/paddleslim/core/graph_tracer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b202d2fa219ecf916809a63363ce90229168971
--- /dev/null
+++ b/paddleslim/core/graph_tracer.py
@@ -0,0 +1,163 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" A structure used to describe the network of a model."""
+
+import types
+
+import paddle
+
+from .graph import Graph, Node
+
+__all__ = ["GraphTracer"]
+
+
+def _apply(layer, func):
+ for name, child in layer.named_children():
+ func(child)
+ _apply(child, func)
+
+
+def _add_call_hook(module,
+ function_new,
+ method_name='forward',
+ backup_name='__forward_orig__'):
+ def _call_hook_enable(op):
+        # do not patch the top-level module, so it can still be called directly as model(x)
+ if op is not module:
+ assert not hasattr(
+ op, backup_name
+ ), f'in {op.__class__.__name__} detected an existing function {backup_name} : please double check'
+ # backup the original forward of op into backup_name
+ method_orig = getattr(op, method_name)
+ setattr(op, backup_name, method_orig)
+ # set new method
+ method_new = types.MethodType(function_new, op)
+ setattr(op, method_name, method_new)
+
+ _apply(module, _call_hook_enable)
+
+
+def _remove_call_hook(module,
+ method_name='forward',
+ backup_name='__forward_orig__'):
+ def _call_hook_disable(op):
+ if op is not module:
+ if hasattr(op, backup_name):
+ method_new = getattr(op, method_name)
+ method_orig = getattr(op, backup_name)
+ setattr(op, method_name, method_orig)
+ # delete the backup
+ setattr(op, backup_name, method_new)
+ delattr(op, backup_name)
+
+ _apply(module, _call_hook_disable)
+
+
+class GraphTracer(paddle.nn.Layer):
+ """ A tool used to trace the execution of the model.
+ Call the forward of the model decorated by this tracer
+ and it will create a graph.
+
+ Args:
+ model(paddle.nn.Layer): The model to be traced.
+
+ Examples:
+ .. code-block:: python
+ from paddeslim.core.graph_tracer import GraphTracer
+ from paddle.vision.models import resnet18
+
+ model = resnet18()
+ x = paddle.rand([1, 3, 224, 224])
+ tracer = GraphTracer(model)
+ tracer(x)
+ print(tracer.graph)
+
+ """
+
+ def __init__(self, model: paddle.nn.Layer):
+ super(GraphTracer, self).__init__()
+ self._model = model
+ self._graph = None
+ self._call_count = {}
+ self._tensor_previous = {}
+
+ @property
+ def graph(self) -> Graph:
+ assert self._graph is not None, "Please trace the graph by calling forward function of current tracer."
+ return self._graph
+
+ def forward(self, inputs, *args, **kwargs):
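+        """ Run the model once with the given inputs and build the execution graph."""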
+ self._graph = Graph()
+ _add_call_hook(self._model, self._analyze_modules_op)
+ self._model(inputs, *args, **kwargs)
+ _remove_call_hook(self._model)
+
+ def _analyze_modules_op(self, op, inputs, *args, **kwargs):
+ node = self._trace_in(op, inputs)
+ #print(f"inputs: {inputs.name}")
+ outputs = op.__forward_orig__(inputs, *args, **kwargs)
+ #print(f"outputs: {outputs.name}")
+ self._trace_out(node, outputs)
+ return outputs
+
+ def _call_layer(self, layer):
+ layer_name = layer.full_name()
+ if layer_name not in self._call_count:
+ self._call_count[layer_name] = 0
+ self._call_count[layer_name] += 1
+ return self._call_count[layer_name]
+
+ def _trace_in(self, layer, inputs):
+ inputs = self._convert_to_list(inputs)
+        call_count = self._call_layer(layer)
+        current_node = Node(layer, call_count)
+
+ if current_node.name not in self._graph._name2node:
+ self._graph._name2node[current_node.name] = current_node
+ current_node = self._graph._name2node[current_node.name]
+ for inp in inputs:
+ last_node = self._tensor_previous.get(inp.name, None)
+ if last_node is not None:
+ assert isinstance(last_node, Node)
+ if last_node not in current_node.previous_nodes:
+ current_node.previous_nodes.append(last_node)
+ if current_node not in last_node.next_nodes:
+ last_node.next_nodes.append(current_node)
+
+ return current_node
+
+ def _trace_out(self, current_node, outputs):
+ assert current_node is not None, "The current node has not been visited."
+ if current_node.is_leaf():
+ outputs = self._convert_to_list(outputs)
+ for out in outputs:
+ self._tensor_previous[out.name] = current_node
+
+ def _convert_to_list(self, tensors):
+ """ Convert tensor to list.
+ It is important to convert the inputs to a list.
+ Because visiting the tensor by 'for ... in' will create new
+ temp variables and break the tracing process.
+ """
+ if isinstance(tensors, paddle.Tensor):
+ return [tensors]
+ elif isinstance(tensors, (list, tuple)):
+ for _t in tensors:
+ assert isinstance(_t, paddle.Tensor)
+ return tensors
+        raise TypeError(
+            f"Unsupported type: {type(tensors)}; the inputs should be a paddle.Tensor or a list of paddle.Tensor."
+        )
diff --git a/paddleslim/quant/__init__.py b/paddleslim/quant/__init__.py
index cdd60119533bfbcdf9160fc92a83941a14a9b8e1..8bd83a9cca893f735185eaba2a4a217e8497f2ff 100755
--- a/paddleslim/quant/__init__.py
+++ b/paddleslim/quant/__init__.py
@@ -39,3 +39,6 @@ except Exception as e:
"please use Paddle >= {} or develop version".format(min_paddle_version))
from .quant_embedding import quant_embedding
+from . import nn
+from .qat import SlimQAT
+from .config import SlimQuantConfig
diff --git a/paddleslim/quant/config.py b/paddleslim/quant/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c3a14a4532e39f33b93f47850f5c860fb89764f
--- /dev/null
+++ b/paddleslim/quant/config.py
@@ -0,0 +1,31 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.quantization.config import QuantConfig
+from paddle.quantization.factory import QuanterFactory
+
+
+class SlimQuantConfig(QuantConfig):
+ def __init__(self, activation: QuanterFactory, weight: QuanterFactory):
+ super(SlimQuantConfig, self).__init__(activation, weight)
+ self._constraints = []
+
+ @property
+ def constraints(self):
+ return self._constraints
+
+ def add_constraints(self, constraints):
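+        """ Add a constraint or a list of constraints to be applied to the model
+        during quantization-aware training.
+        """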
+ if not isinstance(constraints, (list, tuple)):
+ constraints = [constraints]
+        self._constraints.extend(constraints)
diff --git a/paddleslim/quant/constraints/__init__.py b/paddleslim/quant/constraints/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..28a2a739bde3c72715479c565684d37f4f10bad5
--- /dev/null
+++ b/paddleslim/quant/constraints/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .constraint import FusionConstraint, Constraint
+from .conv_bn_constraints import FreezedConvBNConstraint
+
+__all__ = ["Constraint", "FusionConstraint", "FreezedConvBNConstraint"]
diff --git a/paddleslim/quant/constraints/constraint.py b/paddleslim/quant/constraints/constraint.py
new file mode 100644
index 0000000000000000000000000000000000000000..43acad98416d89db41abf1e0d5b1ad2652b4b664
--- /dev/null
+++ b/paddleslim/quant/constraints/constraint.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import abc
+from typing import List
+from paddleslim.core.graph import Graph
+from paddleslim.quant.config import SlimQuantConfig
+
+__all__ = ["FusionConstraint", "Constraint"]
+
+
+class Constraint(metaclass=abc.ABCMeta):
+ """ Constraints are the rules applied on layers during quantization-aware training.
+ The abstract class defined the interface of constraint. All the constraints should
+ extend this abstract class and implement the 'apply' function.
+ """
+
+ @abc.abstractmethod
+ def apply(self,
+ model: paddle.nn.Layer,
+ graph: Graph,
+ qconfig: SlimQuantConfig) -> None:
+ """ Apply the rules of the constraint on the model in place.
+ Usually, there are three steps to implement the apply function:
+ 1. Find all the targets to be constrained. Each target can be one layer or a group of layers.
+ 2. If the target is a group of layers, fuse it into a fusion layer.
+ 3. Mapping the fusion layer to a QAT strategy layer by the "add_qat_layer_mapping" function of "qconfig".
+
+ Args:
+ - model(paddle.nn.Layer): The model to be constrained.
+ - graph(Graph): A structure stored the DAG of executing model. It is used to
+ help find layers by some pattern.
+ - qconfig(SlimQuantConfig): The configuration of quantization.
+ """
+ pass
+
+
+class FusionConstraint(Constraint, metaclass=abc.ABCMeta):
+ """ Define some functions used to fuse operators.
+ """
+
+ def _find_parent_and_sublayer_name(self, model, layer):
+ for _name, _sub_layer in model.named_sublayers():
+ if layer.full_name() == _sub_layer.full_name():
+ return model, _name
+ else:
+ result = self._find_parent_and_sublayer_name(_sub_layer, layer)
+ if result is not None:
+ return result
+
+ def _replace_layer(self, model, source, target):
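+        """ Replace the 'source' sublayer of 'model' with the 'target' layer in place."""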
+ result = self._find_parent_and_sublayer_name(model, source)
+ assert result is not None, f"Can't find the parent layer of {source.full_name()} in model {model.full_name()}!!!"
+ parent_layer, sub_name = result
+ parent_layer._sub_layers[sub_name] = target
+ setattr(parent_layer, sub_name, target)
+
+ def fuse_ops(self,
+ model: paddle.nn.Layer,
+ fused_layer_type,
+ layers: List[paddle.nn.Layer],
+ config: SlimQuantConfig):
+ """ Fuse layers into fusion layer.
+ Args:
+ - model(paddle.nn.Layer): The model whose sublayers will be fused.
+ - fused_layer_type(type): Type of fusion layer.
+ - layers(list): List of layers to be fused. It will be the arguments to create 'fused_layer_type' instance
+ by calling fused_layer_type(*layers).
+        - config(SlimQuantConfig): The configuration of quantization.
+ """
+ fused_layer = fused_layer_type(*layers)
+
+ for _layer in layers[1:]:
+ self._replace_layer(model, _layer, paddle.nn.Identity())
+ if _layer in config._layer2config:
+ config._layer2config.pop(_layer)
+
+ self._replace_layer(model, layers[0], fused_layer)
+ config._layer2config[fused_layer] = config._layer2config[layers[0]]
diff --git a/paddleslim/quant/constraints/conv_bn_constraints.py b/paddleslim/quant/constraints/conv_bn_constraints.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b89b782279c460e607282c700afc47017ce39b8
--- /dev/null
+++ b/paddleslim/quant/constraints/conv_bn_constraints.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .constraint import FusionConstraint
+from ..nn import Conv2DBatchNormWrapper, QuantedConv2DBatchNorm
+
+FUSED_LAYER = Conv2DBatchNormWrapper
+QAT_FUSED_LAYER = QuantedConv2DBatchNorm
+
+
+class FreezedConvBNConstraint(FusionConstraint):
+ """ Simulate the folding of convolution and batch norm with a correction strategy.
+ Reference to: Quantizing deep convolutional networks for efficient inference: A whitepaper
+ Args:
+ - freeze_bn_delay(int): After the 'freeze_bn_delay' steps training, switch from using batch
+ statistics to long-term moving averages for batch normalization.
+
+ Examples:
+ .. code-block:: python
+
+            import paddle
+            from paddleslim.quant import SlimQuantConfig
+            from paddleslim.quant import SlimQAT
+            from paddleslim.quant.constraints import FreezedConvBNConstraint
+            from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
+            from paddle.vision.models import resnet18
+
+            model = resnet18()
+            x = paddle.rand([1, 3, 224, 224])
+            quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
+            q_config = SlimQuantConfig(activation=quanter, weight=quanter)
+            q_config.add_constraints(FreezedConvBNConstraint(freeze_bn_delay=1))
+            qat = SlimQAT(q_config)
+            quant_model = qat.quantize(model, inplace=True, inputs=x)
+
+ """
+
+ def __init__(self, freeze_bn_delay=0):
+ self._freeze_bn_delay = freeze_bn_delay
+
+ def apply(self, model, graph, config):
+ conv_bn_pairs = graph.find_conv_bn()
+ for pair in conv_bn_pairs:
+ pair = [node._layer for node in pair]
+ self.fuse_ops(model, FUSED_LAYER, pair, config)
+ config.add_qat_layer_mapping(FUSED_LAYER, QAT_FUSED_LAYER)
+
+ def _set_freeze_bn_delay(layer):
+ if isinstance(layer, FUSED_LAYER):
+ setattr(layer, "_freeze_bn_delay", self._freeze_bn_delay)
+
+ model.apply(_set_freeze_bn_delay)
diff --git a/paddleslim/quant/nn/__init__.py b/paddleslim/quant/nn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..abd8b725a0a69027a0d9eaff8398715f991f5e0b
--- /dev/null
+++ b/paddleslim/quant/nn/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .conv_bn import QuantedConv2DBatchNorm, Conv2DBatchNormWrapper
+
+__all__ = ["QuantedConv2DBatchNorm", "Conv2DBatchNormWrapper"]
diff --git a/paddleslim/quant/nn/conv_bn.py b/paddleslim/quant/nn/conv_bn.py
new file mode 100644
index 0000000000000000000000000000000000000000..d63ab27b4df81499137cb546a4ed0fdcfb94eac8
--- /dev/null
+++ b/paddleslim/quant/nn/conv_bn.py
@@ -0,0 +1,259 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Wrapper layer to simulate folding batch norms during quantization-aware training.
+"""
+import paddle
+from paddle.nn import Layer
+from paddle.nn import functional as F
+from paddle.nn.layer.norm import _BatchNormBase
+from paddle.nn.quant.format import ConvertibleQuantedLayer
+
+
+class BatchNorm(Layer):
+ r"""Wrapper of batchnorm layer. It is used to get the mean and variance of cerrent mini-batch.
+ Args:
+ - bn(paddle.nn.layer.norm._BatchNormBase): The batchnorm layer to be wrapped.
+ """
+
+ def __init__(self, bn: _BatchNormBase):
+ super(BatchNorm, self).__init__()
+ self._bn = bn
+        self._running_mean = self._bn._mean
+        self._running_variance = self._bn._variance
+        self._weight = self._bn.weight
+        self._bias = self._bn.bias
+        self._batch_mean = None
+        self._batch_variance = None
+        self._training = True
+        assert self._bias is not None
+
+ def forward(self, input):
+ r"""
+ Args:
+ - input(Tensor): The input to be normalized.
+ Return:
+            A tuple with 3 elements: the normalized output, the mean of the current mini-batch, and the variance of
+ current mini-batch.
+ """
+ (
+ batch_norm_out, _, _, batch_mean, batch_variance, _,
+ ) = paddle._C_ops.batch_norm(
+            input,
+            self._running_mean,  # TODO: should freeze this running mean to avoid instability in training
+            self._running_variance,  # TODO: should freeze this running variance to avoid instability in training
+ self._weight,
+ self._bias,
+ not self._training,
+ self._bn._momentum,
+ self._bn._epsilon,
+ self._bn._data_format,
+ False, # use_global_stats
+ False, # trainable_statistics
+ )
+ batch_mean.stop_gradient = True
+ batch_variance.stop_gradient = True
+ self._batch_mean = batch_mean
+ self._batch_variance = batch_variance
+ return batch_norm_out, batch_mean, batch_variance
+
+
+class Conv2DBatchNormWrapper(paddle.nn.Layer):
+ def __init__(self, conv, bn):
+ super(Conv2DBatchNormWrapper, self).__init__()
+ self._conv = conv
+ self._bn = bn
+ self._freeze_bn_delay = None
+
+ def forward(self, inputs):
+ return self._bn(self._conv(inputs))
+
+
+class QuantedConv2DBatchNorm(ConvertibleQuantedLayer):
+ r"""Wrapper layer to simulate folding batch norms during quantization-aware training.
+ Fisrtly, it will execute once convolution and batch norms prior to quantizing weights
+ to get the long term statistics(i.e. runing mean and runing variance).
+ And it always scale the convolution's weights with a correction factor to the long term
+ statistics prior to quantization. This ensures that there is no jitter in the quantized
+ weights due to batch to batch variation.
+ Then the training will be divided into two phases:
+ 1. During the initial phase of training, with freeze_bn is false. It undo the scaling of
+ the weights so that outputs are identical to regular batch normalization.
+ 2. After sufficient training, with freeze_bn is true. Switch from using batch statistics
+ to long term moving averages for batch normalization.
+ Reference: Quantizing deep convolutional networks for efficient inference: A whitepaper
+ """
+
+ def __init__(self, conv_bn: Conv2DBatchNormWrapper, q_config):
+ super(QuantedConv2DBatchNorm, self).__init__()
+ conv = conv_bn._conv
+ bn = conv_bn._bn
+ self._freeze_bn_delay = conv_bn._freeze_bn_delay
+ # For Conv2D
+ self._groups = getattr(conv, '_groups')
+ self._stride = getattr(conv, '_stride')
+ self._padding = getattr(conv, '_padding')
+ self._padding_mode = getattr(conv, '_padding_mode')
+ if self._padding_mode != 'zeros':
+ self._reversed_padding_repeated_twice = getattr(
+ conv, '_reversed_padding_repeated_twice')
+ self._dilation = getattr(conv, '_dilation')
+ self._data_format = getattr(conv, '_data_format')
+ self.conv_weight = getattr(conv, 'weight')
+ self.conv_bias = getattr(conv, 'bias')
+
+ if self.conv_bias is None:
+ self.conv_bias = paddle.create_parameter(
+ bn.bias.shape, dtype=self.conv_weight.dtype, is_bias=True)
+
+ self.bn = BatchNorm(bn)
+
+ self.weight_quanter = None
+ self.activation_quanter = None
+ if q_config.weight is not None:
+ self.weight_quanter = q_config.weight._instance(conv)
+ if q_config.activation is not None:
+ self.activation_quanter = q_config.activation._instance(conv)
+
+ self._freeze_bn = False
+
+ self._steps = 0
+
+ def freeze_bn(self):
+ self._freeze_bn = True
+
+ def unfreeze_bn(self):
+ self._freeze_bn = False
+
+ def forward(self, input):
+ if self.training:
+ if self._freeze_bn_delay == self._steps:
+ self.freeze_bn()
+
+ if self._freeze_bn:
+ out = self._forward_with_bn_freezed(input)
+ else:
+ out = self._forward_with_bn_unfreezed(input)
+ self._steps += 1
+ return out
+ else:
+ return self._eval_forward(input)
+
+ def _forward_with_bn_unfreezed(self, input):
+ quant_input = input
+ if self.activation_quanter is not None:
+ quant_input = self.activation_quanter(quant_input)
+
+        # Step 1: Execute conv and bn to update the global running mean and variance
+ _, _, batch_variance = self._conv_bn(quant_input, self.conv_weight)
+ # Step 2: Merge and quantize weights of conv and bn
+ merged_weight = self._merge_conv_bn_weight(self.conv_weight, self.bn)
+ quant_weight = merged_weight
+ if self.weight_quanter is not None:
+ quant_weight = self.weight_quanter(quant_weight)
+        # Step 3: Execute conv with merged weights
+ conv_out = self._conv_forward(quant_input, quant_weight)
+ # Step 4: Scale output of conv and merge bias
+        factor = paddle.reshape((self.bn._running_variance / batch_variance),
+                                [1, -1, 1, 1])
+ factor.stop_gradient = True
+ conv_out = conv_out * factor
+ merged_bias = self._merge_conv_unfreezed_bn_bias(
+ self.conv_bias, self.bn)
+ return conv_out + paddle.reshape(merged_bias, [1, -1, 1, 1])
+
+    def _eval_forward(self, input):
+        quant_input = input
+        if self.activation_quanter is not None:
+            quant_input = self.activation_quanter(quant_input)
+        # Fall back to the raw weight so 'quant_weight' is always defined
+        quant_weight = self.conv_weight
+        if self.weight_quanter is not None:
+            quant_weight = self.weight_quanter(self.conv_weight)
+        conv_out = self._conv_forward(quant_input, quant_weight)
+        conv_out = conv_out if self.conv_bias is None else conv_out + paddle.reshape(
+            self.conv_bias, [1, -1, 1, 1])
+        return conv_out
+
+ def _forward_with_bn_freezed(self, input):
+ quant_input = input
+ if self.activation_quanter is not None:
+ quant_input = self.activation_quanter(quant_input)
+        # Step 1: Execute conv and bn to update the global running mean and variance
+ self._conv_bn(quant_input, self.conv_weight)
+ # Step 2: Merge and quantize weights of conv and bn
+ merged_weight = self._merge_conv_bn_weight(self.conv_weight, self.bn)
+ quant_weight = merged_weight
+ if self.weight_quanter is not None:
+ quant_weight = self.weight_quanter(quant_weight)
+        # Step 3: Execute conv with merged weights
+ conv_out = self._conv_forward(quant_input, quant_weight)
+ # Step 4: Merge bias of conv and bn
+ merged_bias = self._merge_conv_freezed_bn_bias(self.conv_bias, self.bn)
+ return conv_out + paddle.reshape(merged_bias, [1, -1, 1, 1])
+
+ def _merge_conv_bn_weight(self, conv_weight, bn: BatchNorm):
+ merged_weight = paddle.reshape(
+ bn._weight, [-1, 1, 1, 1]) * conv_weight / paddle.reshape(
+                bn._running_variance, [-1, 1, 1, 1])
+ conv_weight.set_value(merged_weight)
+ return conv_weight
+
+ def _merge_conv_freezed_bn_bias(self, conv_bias, bn: BatchNorm):
+ assert conv_bias is not None
+
+        merged_bias = (conv_bias + bn._bias -
+                       (bn._weight * bn._running_mean) / bn._running_variance)
+ conv_bias.set_value(merged_bias)
+ return conv_bias
+
+ def _merge_conv_unfreezed_bn_bias(self, conv_bias, bn: BatchNorm):
+ assert conv_bias is not None
+ merged_bias = (conv_bias + bn._bias -
+ bn._weight * bn._batch_mean / bn._batch_variance)
+ conv_bias.set_value(merged_bias)
+ return conv_bias
+
+ def _conv_bn(self, input, conv_weight):
+ return self._bn_forward(self._conv_forward(input, conv_weight))
+
+ def _bn_forward(self, input):
+ return self.bn.forward(input)
+
+ def _conv_forward(self, inputs, weights):
+ if self._padding_mode != 'zeros':
+ inputs = F.pad(
+ inputs,
+ self._reversed_padding_repeated_twice,
+ mode=self._padding_mode,
+ data_format=self._data_format, )
+ self._padding = 0
+
+ return F.conv2d(
+ inputs,
+ weights,
+            bias=None,  # bias is added explicitly by each forward path; passing it here would double-count it
+ padding=self._padding,
+ stride=self._stride,
+ dilation=self._dilation,
+ groups=self._groups,
+ data_format=self._data_format, )
+
+ def weights_to_quanters(self):
+ return [('conv_weight', 'weight_quanter')]
+
+ def activation_quanters(self):
+ return ['activation_quanter']
diff --git a/paddleslim/quant/qat.py b/paddleslim/quant/qat.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3e3d2d092db72ca3f87bcdadff233fe70fa9fa3
--- /dev/null
+++ b/paddleslim/quant/qat.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import copy
+from paddle.quantization import QAT, QuantConfig
+from ..core import GraphTracer
+
+
+class SlimQAT(QAT):
+ def __init__(self, q_config):
+ super(SlimQAT, self).__init__(q_config)
+
+ def _apply_constraints(self, model, graph):
+        for _constraint in self._config.constraints:
+            _constraint.apply(model, graph, self._config)
+
+    def _analyze_model(self, _model, inputs):
+ assert inputs is not None
+ tracer = GraphTracer(_model)
+ tracer(inputs)
+ return tracer.graph
+
+ def quantize(self, model: paddle.nn.Layer, inplace=False, inputs=None):
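+        """ Create a model ready for quantization-aware training.
+
+        Args:
+            model(paddle.nn.Layer): The model to be quantized.
+            inplace(bool): Whether to modify the model in place; if False, a deep copy is quantized.
+            inputs(paddle.Tensor): Dummy inputs used to trace the execution graph of the model so
+                that the constraints can match patterns such as Conv2D followed by BatchNorm.
+        """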
+ _model = model if inplace else copy.deepcopy(model)
+ self._config._specify(_model)
+        graph = self._analyze_model(_model, inputs)
+ self._apply_constraints(_model, graph)
+ self._convert_to_quant_layers(_model, self._config)
+ self._insert_activation_observers(_model, self._config)
+ return _model
diff --git a/tests/quantization/test_conv_bn_constraints.py b/tests/quantization/test_conv_bn_constraints.py
new file mode 100644
index 0000000000000000000000000000000000000000..78da1ce1cf04878d951e8aef1d79785fb7730df6
--- /dev/null
+++ b/tests/quantization/test_conv_bn_constraints.py
@@ -0,0 +1,103 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import os
+import unittest
+import tempfile
+sys.path.append("../../")
+import paddle
+from paddle.vision.models import resnet18
+from paddleslim.quant import SlimQuantConfig as QuantConfig
+from paddleslim.quant import SlimQAT
+from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
+from paddleslim.quant.nn.conv_bn import QuantedConv2DBatchNorm
+from paddleslim.quant.constraints import FreezedConvBNConstraint
+from test_qat import TestQuantAwareTraining, load_model_and_count_layer
+
+
+class TestConvBNConstraintsBaseCase(TestQuantAwareTraining):
+ """ Common cases for testing 'quantize', 'convert' and 'jit.save' function."""
+
+ def extra_qconfig(self, qconfig):
+ qconfig.add_constraints(FreezedConvBNConstraint(freeze_bn_delay=1))
+
+
+class TestConvBNConstraints(unittest.TestCase):
+ """ More special cases on convolution and batch norm constraints."""
+
+ def setUp(self):
+ paddle.set_device("cpu")
+ self.temp_dir = tempfile.TemporaryDirectory(dir="./")
+ self.path = os.path.join(self.temp_dir.name, 'conv_bn_constraints')
+
+ def tearDown(self):
+ self.temp_dir.cleanup()
+
+ def _count_layers(self, model, layer_type):
+ count = 0
+ for _layer in model.sublayers(True):
+ if isinstance(_layer, layer_type):
+ count += 1
+ return count
+
+ def _get_one_layer(self, model, layer_type):
+ for _layer in model.sublayers(True):
+ if isinstance(_layer, layer_type):
+ return _layer
+ return None
+
+ def test_conv_bn(self):
+ model = resnet18()
+ conv_count = self._count_layers(model, paddle.nn.Conv2D)
+ quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
+ q_config = QuantConfig(activation=quanter, weight=quanter)
+        # It will freeze the batch normalization after 'freeze_bn_delay' steps
+ q_config.add_constraints(FreezedConvBNConstraint(freeze_bn_delay=1))
+
+ qat = SlimQAT(q_config)
+ x = paddle.rand([1, 3, 224, 224])
+ paddle.jit.save(model, "./test_model", input_spec=[x])
+ quant_model = qat.quantize(model, inplace=True, inputs=x)
+
+ # check freeze_bn_delay
+ qat_conv_bn_layer = self._get_one_layer(quant_model,
+ QuantedConv2DBatchNorm)
+ self.assertIsNotNone(qat_conv_bn_layer)
+ self.assertFalse(qat_conv_bn_layer._freeze_bn)
+ quant_model.train()
+ out = quant_model(x)
+ out.backward()
+ out = quant_model(x)
+ out.backward()
+ self.assertTrue(qat_conv_bn_layer._freeze_bn)
+
+ # check the count of QAT layers in QAT model
+ qat_layer_count = self._count_layers(quant_model,
+ QuantedConv2DBatchNorm)
+ self.assertEqual(qat_layer_count, conv_count)
+
+ # check the count of convolution and batch norm in saved static graph
+ quant_model.eval()
+ infer_model = qat.convert(quant_model, inplace=True)
+ save_path = os.path.join(self.path, 'infer_model')
+ paddle.jit.save(infer_model, save_path, input_spec=[x])
+ layer2count = load_model_and_count_layer(save_path,
+ ['conv2d', 'batch_norm'])
+ self.assertEqual(layer2count['conv2d'], conv_count)
+ self.assertEqual(layer2count['batch_norm'], 0)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/quantization/test_graph_tracer.py b/tests/quantization/test_graph_tracer.py
new file mode 100644
index 0000000000000000000000000000000000000000..059ad62ec35b681adc702378a02e219473192b46
--- /dev/null
+++ b/tests/quantization/test_graph_tracer.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import unittest
+sys.path.append("../../")
+import paddle
+from paddle.vision.models import resnet18
+from paddleslim.core.graph_tracer import GraphTracer
+
+
+class TestGraphTracer(unittest.TestCase):
+ def _count_layers(self, model, layer_type):
+ count = 0
+ for _layer in model.sublayers(True):
+ if isinstance(_layer, layer_type):
+ count += 1
+ return count
+
+ def test_conv_bn(self):
+ paddle.set_device("cpu")
+ model = resnet18()
+ conv_count = self._count_layers(model, paddle.nn.Conv2D)
+ x = paddle.rand([1, 3, 224, 224])
+ tracer = GraphTracer(model)
+ tracer(x)
+ conv_bn_pairs = tracer.graph.find_conv_bn()
+ self.assertEqual(len(conv_bn_pairs), conv_count)
+
+ conv_names = set()
+ bn_names = set()
+ for _layer in model.sublayers():
+ if isinstance(_layer, paddle.nn.Conv2D):
+ conv_names.add(_layer.full_name())
+ elif isinstance(_layer, paddle.nn.BatchNorm2D):
+ bn_names.add(_layer.full_name())
+
+ for _conv, _bn in conv_bn_pairs:
+ self.assertTrue(_bn.layer_name in bn_names)
+ self.assertTrue(_conv.layer_name in conv_names)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/quantization/test_qat.py b/tests/quantization/test_qat.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1d933a60b2e2e910ae570f2e033b2f467bd6df6
--- /dev/null
+++ b/tests/quantization/test_qat.py
@@ -0,0 +1,139 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import os
+import unittest
+sys.path.append("../../")
+import paddle
+import tempfile
+from paddle.vision.models import resnet18
+from paddleslim.quant import SlimQuantConfig as QuantConfig
+from paddleslim.quant import SlimQAT
+from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
+from paddle.quantization.quanters.abs_max import FakeQuanterWithAbsMaxObserverLayer
+from paddle.nn.quant.format import LinearDequanter, LinearQuanter
+
+
+def load_model_and_count_layer(model_path, layer_types):
+ layer2count = dict([(_layer, 0) for _layer in layer_types])
+ paddle.enable_static()
+ exe = paddle.static.Executor(paddle.CPUPlace())
+ main_program = paddle.static.Program()
+ startup_program = paddle.static.Program()
+ with paddle.static.program_guard(main_program, startup_program):
+ [
+ inference_program,
+ feed_target_names,
+ fetch_targets,
+ ] = paddle.static.load_inference_model(model_path, exe)
+ for _op in inference_program.global_block().ops:
+ if _op.type in layer2count:
+ layer2count[_op.type] += 1
+ paddle.disable_static()
+ return layer2count
+
+
+class TestQuantAwareTraining(unittest.TestCase):
+ def setUp(self):
+ paddle.set_device("cpu")
+ self.init_case()
+ self.dummy_input = paddle.rand([1, 3, 224, 224])
+ self.temp_dir = tempfile.TemporaryDirectory(dir="./")
+ self.path = os.path.join(self.temp_dir.name, 'qat')
+
+ def tearDown(self):
+ self.temp_dir.cleanup()
+
+ def extra_qconfig(self, q_config):
+ """The subclass of TestQuantAwareTraining can implement the function to add more special configuration."""
+ pass
+
+ def init_case(self):
+ quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
+ self.quantizer_type = FakeQuanterWithAbsMaxObserverLayer
+ self.quantizer_type_in_static = "fake_quantize_dequantize_moving_average_abs_max"
+ self.q_config = QuantConfig(activation=None, weight=None)
+ self.q_config.add_type_config(
+ paddle.nn.Conv2D, activation=quanter, weight=quanter)
+ self.extra_qconfig(self.q_config)
+
+ def _count_layers(self, model, layer_type):
+ count = 0
+ for _layer in model.sublayers(True):
+ if isinstance(_layer, layer_type):
+ count += 1
+ return count
+
+ def test_quantize(self):
+ model = resnet18()
+ conv_count = self._count_layers(model, paddle.nn.Conv2D)
+ qat = SlimQAT(self.q_config)
+ quant_model = qat.quantize(model, inplace=True, inputs=self.dummy_input)
+ quant_model.train()
+ out = quant_model(self.dummy_input)
+ out.backward()
+ quantizer_cnt = self._count_layers(quant_model, self.quantizer_type)
+ self.assertEqual(quantizer_cnt, conv_count * 2)
+
+ def test_convert(self):
+ model = resnet18()
+ conv_count = self._count_layers(model, paddle.nn.Conv2D)
+ qat = SlimQAT(self.q_config)
+ quant_model = qat.quantize(model, inplace=True, inputs=self.dummy_input)
+ converted_model = qat.convert(quant_model, inplace=True)
+
+ # check count of LinearQuanter and LinearDequanter in dygraph
+ quantizer_count_in_dygraph = self._count_layers(converted_model,
+ LinearQuanter)
+ dequantizer_count_in_dygraph = self._count_layers(
+ converted_model, LinearDequanter)
+
+ self.assertEqual(quantizer_count_in_dygraph, conv_count)
+ self.assertEqual(dequantizer_count_in_dygraph, conv_count * 2)
+
+ # check count of LinearQuanter and LinearDequanter in static model saved by jit.save
+ save_path = os.path.join(self.path, 'converted_model')
+ quant_model.eval()
+ paddle.jit.save(quant_model, save_path, input_spec=[self.dummy_input])
+
+ layer2count = load_model_and_count_layer(
+ save_path, ["quantize_linear", "dequantize_linear"])
+ quantizer_count_in_static_model = layer2count['quantize_linear']
+ dequantizer_count_in_static_model = layer2count['dequantize_linear']
+ self.assertEqual(quantizer_count_in_dygraph,
+ quantizer_count_in_static_model)
+ self.assertEqual(dequantizer_count_in_dygraph,
+ dequantizer_count_in_static_model)
+
+ def test_trace_qat_graph(self):
+ model = resnet18()
+ qat = SlimQAT(self.q_config)
+ quant_model = qat.quantize(model, inplace=True, inputs=self.dummy_input)
+ quantizer_count_indygraph = self._count_layers(quant_model,
+ self.quantizer_type)
+ save_path = os.path.join(self.path, 'qat_model')
+ quant_model.eval()
+ paddle.jit.save(quant_model, save_path, input_spec=[self.dummy_input])
+
+ layer2count = load_model_and_count_layer(
+ save_path, [self.quantizer_type_in_static])
+ quantizer_count_in_static_model = layer2count[
+ self.quantizer_type_in_static]
+ self.assertEqual(quantizer_count_indygraph,
+ quantizer_count_in_static_model)
+
+
+if __name__ == '__main__':
+ unittest.main()