Unverified commit 28bea706, authored by minghaoBD, committed by GitHub

[Unstructured_prune]add static masks for QAT (#917)

* add static masks for quant

* Update unstructured_pruner.rst

* Update unstructured_prune_api.rst

* add the UT

* fix typo

* API docs
Co-authored-by: gmm <38800877+mmglove@users.noreply.github.com>
Parent 6410ed1a
@@ -113,6 +113,35 @@ UnstructuredPruner
..
.. py:method:: paddleslim.UnstructuredPruner.set_static_masks()
This API is special and is rarely needed. It should only be called when running quantization-aware training on top of an FP32 sparse model, because the weights that were already pruned to 0 must stay fixed at 0. Concretely, for an input of parameters=[0, 3, 0, 4, 5.5, 0], the generated mask is [0, 1, 0, 1, 1, 0], and this mask does not change as the parameters are updated during training. Before evaluating or saving the model, call pruner.update_params() to apply the mask to the parameters, so that the parameters whose values are 0 effectively do not participate in training.
**Example:**
.. code-block:: python

  import paddle
  from paddleslim import UnstructuredPruner
  from paddle.vision.models import LeNet as net
  import numpy as np

  place = paddle.set_device('cpu')
  model = net(num_classes=10)
  pruner = UnstructuredPruner(model, mode='threshold', threshold=0.5)

  '''Placeholder for the quantization-related code and parameter loading:
  QAT configs and APIs
  restore the sparse FP32 weights
  '''
  pruner.set_static_masks()

  # quantization-aware training a batch

  pruner.update_params()  # must be called before the model is evaluated or saved

  # eval or save pruned model
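The mask semantics described above can be sketched with plain NumPy, independent of Paddle (the array names here are illustrative, not part of the PaddleSlim API):

```python
import numpy as np

# A sparse FP32 weight tensor: zeros are the previously pruned entries.
parameters = np.array([0, 3, 0, 4, 5.5, 0], dtype=np.float32)

# set_static_masks(): the mask marks the currently non-zero entries.
mask = (parameters != 0).astype(parameters.dtype)
print(mask)  # [0. 1. 0. 1. 1. 0.]

# QAT updates may perturb every entry, including the pruned ones...
parameters += np.float32(0.1)

# ...so update_params() re-applies the fixed mask before eval/save,
# restoring the zeros and hence the original sparsity pattern.
parameters *= mask
print(parameters)
```

Because the mask is frozen at the moment `set_static_masks()` is called, the zero positions survive any number of optimizer steps, which is exactly the guarantee QAT on a sparse model needs.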
..
.. py:method:: paddleslim.UnstructuredPruner.total_sparse(model)
A static method of UnstructuredPruner that computes and returns the sparsity of the given model. It is static so that sparsity can be measured during standalone model evaluation without instantiating an UnstructuredPruner.
......
@@ -75,7 +75,7 @@ UnstrucuturedPruner
pruner = UnstructuredPruner(paddle.static.default_main_program(), 'ratio', ratio=0.55, place=place)
..
- .. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.step()
+ .. py:method:: paddleslim.prune.UnstructuredPruner.step()
Updates the sparsification threshold. In 'threshold' mode the configured threshold is kept; in 'ratio' mode the threshold is recomputed from the optimized model parameters and the configured ratio. Call this after optimizer.step() on every batch during training.
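In 'ratio' mode, recomputing the threshold amounts to taking the ratio-quantile of the absolute weight values. A minimal NumPy sketch of that idea (an illustration under that assumption, not PaddleSlim's actual implementation):

```python
import numpy as np

def recompute_threshold(weights, ratio):
    """Return t such that roughly `ratio` of all weights satisfy |w| < t."""
    all_abs = np.concatenate([np.abs(w).ravel() for w in weights])
    return np.quantile(all_abs, ratio)

rng = np.random.default_rng(0)
weights = [rng.standard_normal((8, 8)), rng.standard_normal((16,))]
t = recompute_threshold(weights, ratio=0.55)

total = sum(w.size for w in weights)
pruned = sum(int((np.abs(w) < t).sum()) for w in weights)
print(pruned / total)  # close to 0.55
```

Because the weights move every optimizer step, the quantile (and hence the threshold) drifts, which is why step() must be re-run each batch.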
@@ -110,7 +110,7 @@ UnstrucuturedPruner
print(pruner.threshold)  # Note this threshold differs from the one printed above: step() recomputed it from the configured ratio for the pruning operation.
..
- .. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.update_params()
+ .. py:method:: paddleslim.prune.UnstructuredPruner.update_params()
Resets to zero, after each optimization step, the weights that were originally zero. It is typically called before model evaluation and saving to guarantee the model's sparsity. During training itself it is unnecessary, because an op that sparsifies the weights is inserted into the forward pass.
@@ -149,7 +149,50 @@ UnstrucuturedPruner
..
- .. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.total_sparse(program)
+ .. py:method:: paddleslim.prune.UnstructuredPruner.set_static_masks()
This API is special and is rarely needed. It should only be called when running quantization-aware training on top of an FP32 sparse model, because the weights that were already pruned to 0 must stay fixed at 0. Concretely, for an input of parameters=[0, 3, 0, 4, 5.5, 0], the generated mask is [0, 1, 0, 1, 1, 0], and this mask does not change as the parameters are updated during training. Before evaluating or saving the model, call pruner.update_params() to apply the mask to the parameters, so that the parameters whose values are 0 effectively do not participate in training.
**Example:**
.. code-block:: python

  import paddle
  import paddle.fluid as fluid
  from paddleslim.prune import UnstructuredPruner

  paddle.enable_static()

  train_program = paddle.static.default_main_program()
  startup_program = paddle.static.default_startup_program()

  with fluid.program_guard(train_program, startup_program):
      image = fluid.data(name='x', shape=[None, 1, 28, 28])
      label = fluid.data(name='label', shape=[None, 1], dtype='int64')
      conv = fluid.layers.conv2d(image, 32, 1)
      feature = fluid.layers.fc(conv, 10, act='softmax')
      cost = fluid.layers.cross_entropy(input=feature, label=label)
      avg_cost = fluid.layers.mean(x=cost)

  place = paddle.static.cpu_places()[0]
  exe = paddle.static.Executor(place)
  exe.run(startup_program)

  pruner = UnstructuredPruner(paddle.static.default_main_program(), 'ratio', ratio=0.55, place=place)

  '''Placeholder for the quantization-related code and parameter loading:
  QAT configs and APIs
  restore the sparse FP32 weights
  '''
  pruner.set_static_masks()

  # quantization-aware training a batch

  pruner.update_params()  # must be called before the model is evaluated or saved

  # eval or save pruned model
..
.. py:method:: paddleslim.prune.UnstructuredPruner.total_sparse(program)
A static method of UnstructuredPruner that computes and returns the sparsity of the given program. It is static so that sparsity can be measured during standalone model evaluation without instantiating an UnstructuredPruner.
@@ -191,7 +234,7 @@ UnstrucuturedPruner
..
- .. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.total_sparse_conv1x1(program)
+ .. py:method:: paddleslim.prune.UnstructuredPruner.total_sparse_conv1x1(program)
A static method of UnstructuredPruner that computes and returns the sparsity of the 1x1 convolutions in the given program. It is static so that sparsity can be measured during standalone model evaluation without instantiating an UnstructuredPruner.
@@ -234,7 +277,7 @@ UnstrucuturedPruner
..
- .. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.summarize_weights(program, ratio=0.1)
+ .. py:method:: paddleslim.prune.UnstructuredPruner.summarize_weights(program, ratio=0.1)
This function estimates the distribution of the parameters in a pretrained model, which is especially useful when it is unclear how to set the threshold. For example, with ratio=0.1 the function returns a value v such that the weights whose absolute value is smaller than v account for 100*ratio% of all weights.
@@ -351,7 +394,7 @@ GMPUnstrucuturedPruner
print(pruner.ratio)  # the ratio increases non-linearly from 0.15 to 0.55.
..
- .. py:method:: paddleslim.prune.unstructured_pruner.GMPUnstructuredPruner.step()
+ .. py:method:: paddleslim.prune.GMPUnstructuredPruner.step()
Recomputes the threshold from the optimized model parameters and the configured ratio, and updates the masks. Call this after optimizer.step() on every batch during training.
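The non-linear ratio growth noted above (e.g. from 0.15 to 0.55) is commonly implemented in gradual magnitude pruning as a cubic schedule in the style of Zhu & Gupta. A hedged sketch of such a schedule follows; the function name and signature are illustrative, not PaddleSlim's exact formula:

```python
def gmp_ratio(step, initial_ratio, final_ratio, begin_step, end_step):
    """Cubic sparsity schedule: ramps from initial_ratio to final_ratio."""
    if step <= begin_step:
        return initial_ratio
    if step >= end_step:
        return final_ratio
    progress = (step - begin_step) / (end_step - begin_step)
    # (1 - progress)**3 decays from 1 to 0, so the ratio rises quickly
    # at first and flattens out as it approaches final_ratio.
    return final_ratio + (initial_ratio - final_ratio) * (1 - progress) ** 3

ratios = [gmp_ratio(s, 0.15, 0.55, 0, 100) for s in (0, 50, 100)]
print(ratios)
```

Pruning aggressively early (while the network can still recover) and gently late is the usual motivation for the cubic shape.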
......
@@ -126,6 +126,14 @@ class UnstructuredPruner():
                bool_tmp = (paddle.abs(param) >= self.threshold)
                paddle.assign(bool_tmp, output=mask)

    def set_static_masks(self):
        for name, sub_layer in self.model.named_sublayers():
            if not self._should_prune_layer(sub_layer): continue
            for param in sub_layer.parameters(include_sublayers=False):
                mask = self.masks.get(param.name)
                bool_tmp = (paddle.abs(param) != 0.0)
                paddle.assign(bool_tmp, output=mask)
    def summarize_weights(self, model, ratio=0.1):
        """
        The function is used to get the weights corresponding to a given ratio
......
@@ -195,6 +195,17 @@ class UnstructuredPruner():
            v_mask = (v_param != 0).astype(v_param.dtype)
            t_mask.set(v_mask, self.place)

    def set_static_masks(self):
        for param in self.masks:
            if not self._should_prune_param(param):
                continue
            mask_name = self.masks[param]
            t_param = self.scope.find_var(param).get_tensor()
            t_mask = self.scope.find_var(mask_name).get_tensor()
            v_param = np.array(t_param)
            v_mask = (v_param != 0).astype(v_param.dtype)
            t_mask.set(v_mask, self.place)
    def step(self):
        """
        Update the threshold and masks.
......
import sys
sys.path.append("../../")
import unittest
import paddle
import numpy as np
from paddleslim import UnstructuredPruner
from paddle.vision.models import mobilenet_v1
import paddle.vision.transforms as T
import paddle.fluid as fluid
from paddle.static import InputSpec as Input
import paddle.nn.functional as F


class TestStaticMasks(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(TestStaticMasks, self).__init__(*args, **kwargs)
        paddle.disable_static()
        transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
        self.train_dataset = paddle.vision.datasets.MNIST(
            mode="train", backend="cv2", transform=transform)
        self.train_loader = paddle.io.DataLoader(
            self.train_dataset,
            places=paddle.set_device('cpu'),
            return_list=True)

    def _update_masks(self, pruner, t):
        for name, sub_layer in pruner.model.named_sublayers():
            for param in sub_layer.parameters(include_sublayers=False):
                mask = pruner.masks.get(param.name)
                bool_tmp = (paddle.abs(param) < t)
                paddle.assign(bool_tmp, output=mask)

    def runTest(self):
        with fluid.unique_name.guard():
            net = paddle.vision.models.LeNet()
            optimizer = paddle.optimizer.Adam(
                learning_rate=0.001, parameters=net.parameters())
            inputs = [Input([None, 1, 28, 28], 'float32', name='image')]
            labels = [Input([None, 1], 'int64', name='label')]
            pruner = UnstructuredPruner(net, mode='ratio', ratio=0.55)
            net.train()
            self._update_masks(pruner, 0.0)
            pruner.update_params()
            self._update_masks(pruner, 1.0)
            pruner.set_static_masks()
            sparsity_0 = UnstructuredPruner.total_sparse(net)
            for i, data in enumerate(self.train_loader):
                x_data = data[0]
                y_data = paddle.to_tensor(data[1])
                logits = net(x_data)
                loss = F.cross_entropy(logits, y_data)
                loss.backward()
                optimizer.step()
                optimizer.clear_grad()
                if i == 10:
                    break
            sparsity_1 = UnstructuredPruner.total_sparse(net)
            pruner.update_params()
            sparsity_2 = UnstructuredPruner.total_sparse(net)
            print(sparsity_0, sparsity_1, sparsity_2)
            self.assertEqual(sparsity_0, 1.0)
            self.assertEqual(sparsity_2, 1.0)
            self.assertLess(sparsity_1, 1.0)


if __name__ == "__main__":
    unittest.main()
import sys
sys.path.append("../")
import unittest
from static_case import StaticCase
import paddle.fluid as fluid
import paddle
from paddleslim.prune import UnstructuredPruner
from layers import conv_bn_layer
import numpy as np


class TestStaticMasks(StaticCase):
    def _update_masks(self, pruner, t):
        for param in pruner.masks:
            mask_name = pruner.masks[param]
            t_param = pruner.scope.find_var(param).get_tensor()
            t_mask = pruner.scope.find_var(mask_name).get_tensor()
            v_param = np.array(t_param)
            v_mask = (np.abs(v_param) < t).astype(v_param.dtype)
            t_mask.set(v_mask, pruner.place)

    def test_set_static_masks(self):
        main_program = paddle.static.default_main_program()
        startup_program = paddle.static.default_startup_program()
        with paddle.static.program_guard(main_program, startup_program):
            input = paddle.static.data(name='image', shape=[None, 3, 16, 16])
            label = fluid.data(name='label', shape=[None, 1], dtype='int64')
            conv1 = conv_bn_layer(input, 8, 1, "conv1")
            conv2 = conv_bn_layer(conv1, 8, 1, "conv2")
            conv3 = fluid.layers.conv2d_transpose(
                input=conv2, num_filters=16, filter_size=2, stride=2)
            predict = fluid.layers.fc(input=conv3, size=10, act='softmax')
            cost = fluid.layers.cross_entropy(input=predict, label=label)
            adam_optimizer = fluid.optimizer.AdamOptimizer(0.01)
            avg_cost = fluid.layers.mean(cost)
            adam_optimizer.minimize(avg_cost)

        place = paddle.static.cpu_places()[0]
        exe = paddle.static.Executor(place)
        scope = paddle.static.global_scope()
        exe.run(startup_program, scope=scope)

        pruner = UnstructuredPruner(
            main_program, 'ratio', scope=scope, place=place)
        self._update_masks(pruner, 0.0)
        pruner.update_params()
        self._update_masks(pruner, 1.0)
        pruner.set_static_masks()
        sparsity_0 = pruner.total_sparse(main_program)

        x = np.random.random(size=(10, 3, 16, 16)).astype('float32')
        label = np.random.random(size=(10, 1)).astype('int64')
        loss_data, = exe.run(main_program,
                             feed={"image": x,
                                   "label": label},
                             fetch_list=[cost.name])
        sparsity_1 = UnstructuredPruner.total_sparse(main_program)
        pruner.update_params()
        sparsity_2 = UnstructuredPruner.total_sparse(main_program)

        print(sparsity_0, sparsity_1, sparsity_2)
        self.assertEqual(sparsity_0, 1.0)
        self.assertEqual(sparsity_2, 1.0)
        self.assertLess(sparsity_1, 1.0)


if __name__ == '__main__':
    unittest.main()