diff --git a/docs/zh_cn/api_cn/dygraph/pruners/unstructured_pruner.rst b/docs/zh_cn/api_cn/dygraph/pruners/unstructured_pruner.rst index 9beb48cd8b9213fff7e9a3d59f33e76eee54b272..19086aab4abd9d9c6614490cf65984e07c4a3ec4 100644 --- a/docs/zh_cn/api_cn/dygraph/pruners/unstructured_pruner.rst +++ b/docs/zh_cn/api_cn/dygraph/pruners/unstructured_pruner.rst @@ -113,6 +113,35 @@ UnstructuredPruner .. + .. py:method:: paddleslim.UnstructuredPruner.set_static_masks() + + 这个API比较特殊,一般情况下不会用到。只有在【基于 FP32 稀疏化模型】进行量化训练时需要调用,因为需要固定住原本被置0的权重,保持0不变。具体来说,对于输入的 parameters=[0, 3, 0, 4, 5.5, 0],会生成对应的mask为:[0, 1, 0, 1, 1, 0]。而且在训练过程中,该 mask 数值不会随 parameters 更新(训练)而改变。在评估/保存模型之前,可以通过调用 pruner.update_params() 将mask应用到 parameters 上,从而达到『在训练过程中 parameters 中数值为0的参数不参与训练』的效果。 + + **示例代码:** + + .. code-block:: python + + import paddle + from paddleslim import UnstructuredPruner + from paddle.vision.models import LeNet as net + import numpy as np + + place = paddle.set_device('cpu') + model = net(num_classes=10) + pruner = UnstructuredPruner(model, mode='threshold', threshold=0.5) + + '''注释中为量化训练相关代码,以及参数导入 + QAT configs and APIs + restore the sparse FP32 weights + ''' + + pruner.set_static_masks() + # quantization-aware training a batch + pruner.update_params()# 这一行代码需要在模型eval和保存前调用。 + # eval or save pruned model + + .. + .. py:method:: paddleslim.UnstructuredPruner.total_sparse(model) UnstructuredPruner中的静态方法,用于计算给定的模型(model)的稀疏度并返回。该方法为静态方法,是考虑到在单单做模型评价的时候,我们就不需要初始化一个UnstructuredPruner示例了。 diff --git a/docs/zh_cn/api_cn/static/prune/unstructured_prune_api.rst b/docs/zh_cn/api_cn/static/prune/unstructured_prune_api.rst index 7a19b8db43033522651540c85362ae684b2db26f..298175faf0a986c5ffd9cb589a512574588ba701 100644 --- a/docs/zh_cn/api_cn/static/prune/unstructured_prune_api.rst +++ b/docs/zh_cn/api_cn/static/prune/unstructured_prune_api.rst @@ -75,7 +75,7 @@ UnstrucuturedPruner pruner = UnstructuredPruner(paddle.static.default_main_program(), 'ratio', ratio=0.55, place=place) .. - .. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.step() + .. py:method:: paddleslim.prune.UnstructuredPruner.step() 更新稀疏化的阈值,如果是'threshold'模式,则维持设定的阈值,如果是'ratio'模式,则根据优化后的模型参数和设定的比例,重新计算阈值。该函数调用在训练过程中每个batch的optimizer.step()之后。 @@ -110,7 +110,7 @@ UnstrucuturedPruner print(pruner.threshold) # 可以看出,这里的threshold和上面打印的不同,这是因为step函数根据设定的ratio更新了threshold数值,便于剪裁操作。 .. - .. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.update_params() + .. py:method:: paddleslim.prune.UnstructuredPruner.update_params() 每一步优化后,重制模型中本来是0的权重。这一步通常用于模型evaluation和save之前,确保模型的稀疏率。但是,在训练过程中,由于前向过程中插入了稀疏化权重的op,故不需要开发者在训练过程中额外调用了。 @@ -149,7 +149,50 @@ UnstrucuturedPruner .. - .. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.total_sparse(program) + .. py:method:: paddleslim.prune.UnstructuredPruner.set_static_masks() + + 这个API比较特殊,一般情况下不会用到。只有在基于FP32稀疏化模型进行量化训练时需要调用,因为需要固定住原本被置0的权重,保持0不变。具体来说,对于输入的 parameters=[0, 3, 0, 4, 5.5, 0],会生成对应的mask为:[0, 1, 0, 1, 1, 0]。而且在训练过程中,该 mask 数值不会随 parameters 更新(训练)而改变。在评估/保存模型之前,可以通过调用 pruner.update_params() 将mask应用到 parameters 上,从而达到『在训练过程中 parameters 中数值为0的参数不参与训练』的效果。 + + **示例代码:** + + .. code-block:: python + + import paddle + import paddle.fluid as fluid + from paddleslim.prune import UnstructuredPruner + + paddle.enable_static() + + train_program = paddle.static.default_main_program() + startup_program = paddle.static.default_startup_program() + + with fluid.program_guard(train_program, startup_program): + image = fluid.data(name='x', shape=[None, 1, 28, 28]) + label = fluid.data(name='label', shape=[None, 1], dtype='int64') + conv = fluid.layers.conv2d(image, 32, 1) + feature = fluid.layers.fc(conv, 10, act='softmax') + cost = fluid.layers.cross_entropy(input=feature, label=label) + avg_cost = fluid.layers.mean(x=cost) + + place = paddle.static.cpu_places()[0] + exe = paddle.static.Executor(place) + exe.run(startup_program) + + pruner = UnstructuredPruner(paddle.static.default_main_program(), 'ratio', ratio=0.55, place=place) + + '''注释中为量化训练相关代码,以及参数导入 + QAT configs and APIs + restore the sparse FP32 weights + ''' + + pruner.set_static_masks() + # quantization-aware training a batch + pruner.update_params()# 这一行代码需要在模型eval和保存前调用。 + # eval or save pruned model + + .. + + .. py:method:: paddleslim.prune.UnstructuredPruner.total_sparse(program) UnstructuredPruner中的静态方法,用于计算给定的模型(program)的稀疏度并返回。该方法为静态方法,是考虑到在单单做模型评价的时候,我们就不需要初始化一个UnstructuredPruner示例了。 @@ -191,7 +234,7 @@ UnstrucuturedPruner .. - .. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.total_sparse_conv1x1(program) + .. py:method:: paddleslim.prune.UnstructuredPruner.total_sparse_conv1x1(program) UnstructuredPruner中的静态方法,用于计算给定的模型(program)的1x1卷积稀疏度并返回。该方法为静态方法,是考虑到在单单做模型评价的时候,我们就不需要初始化一个UnstructuredPruner示例了。 @@ -234,7 +277,7 @@ UnstrucuturedPruner .. - .. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.summarize_weights(program, ratio=0.1) + .. py:method:: paddleslim.prune.UnstructuredPruner.summarize_weights(program, ratio=0.1) 该函数用于估计预训练模型中参数的分布情况,尤其是在不清楚如何设置threshold的数值时,尤为有用。例如,当输入为ratio=0.1时,函数会返回一个数值v,而绝对值小于v的权重的个数占所有权重个数的(100*ratio%)。 @@ -351,7 +394,7 @@ GMPUnstrucuturedPruner print(pruner.ratio) # 可以看到ratio从0.15非线性的增加到0.55。 .. - .. py:method:: paddleslim.prune.unstructured_pruner.GMPUnstructuredPruner.step() + .. py:method:: paddleslim.prune.GMPUnstructuredPruner.step() 根据优化后的模型参数和设定的比例,重新计算阈值,并且更新mask。该函数调用在训练过程中每个batch的optimizer.step()之后。 diff --git a/paddleslim/dygraph/prune/unstructured_pruner.py b/paddleslim/dygraph/prune/unstructured_pruner.py index deaa7e3ae969e23a945b8bf17aafed54a5bd9f1e..4e7663352b782aabeac321836e54fa394f382391 100644 --- a/paddleslim/dygraph/prune/unstructured_pruner.py +++ b/paddleslim/dygraph/prune/unstructured_pruner.py @@ -126,6 +126,14 @@ class UnstructuredPruner(): bool_tmp = (paddle.abs(param) >= self.threshold) paddle.assign(bool_tmp, output=mask) + def set_static_masks(self): + for name, sub_layer in self.model.named_sublayers(): + if not self._should_prune_layer(sub_layer): continue + for param in sub_layer.parameters(include_sublayers=False): + mask = self.masks.get(param.name) + bool_tmp = (paddle.abs(param) != 0.0) + paddle.assign(bool_tmp, output=mask) + def summarize_weights(self, model, ratio=0.1): """ The function is used to get the weights corresponding to a given ratio diff --git a/paddleslim/prune/unstructured_pruner.py b/paddleslim/prune/unstructured_pruner.py index ca2fb2fd87e6be44373bcfd5dd93c19377303d61..e6787395ac8e5261cdb98cf7b959801f51236756 100644 --- a/paddleslim/prune/unstructured_pruner.py +++ b/paddleslim/prune/unstructured_pruner.py @@ -195,6 +195,17 @@ class UnstructuredPruner(): v_mask = (v_param != 0).astype(v_param.dtype) t_mask.set(v_mask, self.place) + def set_static_masks(self): + for param in self.masks: + if not self._should_prune_param(param): + continue + mask_name = self.masks[param] + t_param = self.scope.find_var(param).get_tensor() + t_mask = self.scope.find_var(mask_name).get_tensor() + v_param = np.array(t_param) + v_mask = (v_param != 0).astype(v_param.dtype) + t_mask.set(v_mask, self.place) + def step(self): """ Update the threshold and masks. diff --git a/tests/dygraph/test_unstructured_prune_quant.py b/tests/dygraph/test_unstructured_prune_quant.py new file mode 100644 index 0000000000000000000000000000000000000000..96408580d1cb9ce26943baac944f5658311a81e2 --- /dev/null +++ b/tests/dygraph/test_unstructured_prune_quant.py @@ -0,0 +1,72 @@ +import sys +sys.path.append("../../") +import unittest +import paddle +import numpy as np +from paddleslim import UnstructuredPruner +from paddle.vision.models import mobilenet_v1 +import paddle.vision.transforms as T +import paddle.fluid as fluid +from paddle.static import InputSpec as Input +import paddle.nn.functional as F + + +class TestStaticMasks(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(TestStaticMasks, self).__init__(*args, **kwargs) + paddle.disable_static() + transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])]) + self.train_dataset = paddle.vision.datasets.MNIST( + mode="train", backend="cv2", transform=transform) + self.train_loader = paddle.io.DataLoader( + self.train_dataset, + places=paddle.set_device('cpu'), + return_list=True) + + def _reader(): + for data in self.val_dataset: + yield data + + self.val_reader = _reader + + def _update_masks(self, pruner, t): + for name, sub_layer in pruner.model.named_sublayers(): + for param in sub_layer.parameters(include_sublayers=False): + mask = pruner.masks.get(param.name) + bool_tmp = (paddle.abs(param) < t) + paddle.assign(bool_tmp, output=mask) + + def runTest(self): + with fluid.unique_name.guard(): + net = paddle.vision.models.LeNet() + optimizer = paddle.optimizer.Adam( + learning_rate=0.001, parameters=net.parameters()) + inputs = [Input([None, 1, 28, 28], 'float32', name='image')] + labels = [Input([None, 1], 'int64', name='label')] + pruner = UnstructuredPruner(net, mode='ratio', ratio=0.55) + net.train() + self._update_masks(pruner, 0.0) + pruner.update_params() + self._update_masks(pruner, 1.0) + pruner.set_static_masks() + sparsity_0 = UnstructuredPruner.total_sparse(net) + for i, data in enumerate(self.train_loader): + x_data = data[0] + y_data = paddle.to_tensor(data[1]) + logits = net(x_data) + loss = F.cross_entropy(logits, y_data) + loss.backward() + optimizer.step() + optimizer.clear_grad() + if i == 10: break + sparsity_1 = UnstructuredPruner.total_sparse(net) + pruner.update_params() + sparsity_2 = UnstructuredPruner.total_sparse(net) + print(sparsity_0, sparsity_1, sparsity_2) + self.assertEqual(sparsity_0, 1.0) + self.assertEqual(sparsity_2, 1.0) + self.assertLess(sparsity_1, 1.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_unstructured_pruner_quant.py b/tests/test_unstructured_pruner_quant.py new file mode 100644 index 0000000000000000000000000000000000000000..95c877ac5f38b6d66f7dcbb562cb095397057c62 --- /dev/null +++ b/tests/test_unstructured_pruner_quant.py @@ -0,0 +1,67 @@ +import sys +sys.path.append("../") +import unittest +from static_case import StaticCase +import paddle.fluid as fluid +import paddle +from paddleslim.prune import UnstructuredPruner +from layers import conv_bn_layer +import numpy as np + + +class TestStaticMasks(StaticCase): + def _update_masks(self, pruner, t): + for param in pruner.masks: + mask_name = pruner.masks[param] + t_param = pruner.scope.find_var(param).get_tensor() + t_mask = pruner.scope.find_var(mask_name).get_tensor() + v_param = np.array(t_param) + v_mask = (np.abs(v_param) < t).astype(v_param.dtype) + t_mask.set(v_mask, pruner.place) + + def test_set_static_masks(self): + main_program = paddle.static.default_main_program() + startup_program = paddle.static.default_startup_program() + with paddle.static.program_guard(main_program, startup_program): + input = paddle.static.data(name='image', shape=[None, 3, 16, 16]) + label = fluid.data(name='label', shape=[None, 1], dtype='int64') + conv1 = conv_bn_layer(input, 8, 1, "conv1") + conv2 = conv_bn_layer(conv1, 8, 1, "conv2") + conv3 = fluid.layers.conv2d_transpose( + input=conv2, num_filters=16, filter_size=2, stride=2) + predict = fluid.layers.fc(input=conv3, size=10, act='softmax') + cost = fluid.layers.cross_entropy(input=predict, label=label) + adam_optimizer = fluid.optimizer.AdamOptimizer(0.01) + avg_cost = fluid.layers.mean(cost) + adam_optimizer.minimize(avg_cost) + + place = paddle.static.cpu_places()[0] + exe = paddle.static.Executor(place) + scope = paddle.static.global_scope() + exe.run(startup_program, scope=scope) + + pruner = UnstructuredPruner( + main_program, 'ratio', scope=scope, place=place) + + self._update_masks(pruner, 0.0) + pruner.update_params() + self._update_masks(pruner, 1.0) + pruner.set_static_masks() + sparsity_0 = pruner.total_sparse(main_program) + x = np.random.random(size=(10, 3, 16, 16)).astype('float32') + label = np.random.random(size=(10, 1)).astype('int64') + loss_data, = exe.run(main_program, + feed={"image": x, + "label": label}, + fetch_list=[cost.name]) + sparsity_1 = UnstructuredPruner.total_sparse(main_program) + pruner.update_params() + sparsity_2 = UnstructuredPruner.total_sparse(main_program) + print(sparsity_0, sparsity_1, sparsity_2) + self.assertEqual(sparsity_0, 1.0) + self.assertEqual(sparsity_2, 1.0) + self.assertLess(sparsity_1, 1.0) + + +if __name__ == '__main__': + unittest.main()