test_unstructured_prune.py 1.5 KB
Newer Older
M
minghaoBD 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
import sys
sys.path.append("../../")
import unittest
import paddle
import numpy as np
from paddleslim.dygraph.prune.unstructured_pruner import UnstructuredPruner
from paddle.vision.models import mobilenet_v1


class TestUnstructuredPruner(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(TestUnstructuredPruner, self).__init__(*args, **kwargs)
        paddle.disable_static()
        self._gen_model()

    def _gen_model(self):
        self.net = mobilenet_v1(num_classes=10, pretrained=False)
        self.pruner = UnstructuredPruner(
            self.net, mode='ratio', ratio=0.98, threshold=0.0)

    def test_prune(self):
        ori_density = UnstructuredPruner.total_sparse(self.net)
        ori_threshold = self.pruner.threshold
        self.pruner.step()
        self.net(
            paddle.to_tensor(
                np.random.uniform(0, 1, [16, 3, 32, 32]), dtype='float32'))
        cur_density = UnstructuredPruner.total_sparse(self.net)
        cur_threshold = self.pruner.threshold
        print("Original threshold: {}".format(ori_threshold))
        print("Current threshold: {}".format(cur_threshold))
        print("Original density: {}".format(ori_density))
        print("Current density: {}".format(cur_density))
        self.assertLessEqual(ori_threshold, cur_threshold)
        self.assertLessEqual(cur_density, ori_density)

        self.pruner.update_params()
        self.assertEqual(cur_density, UnstructuredPruner.total_sparse(self.net))


if __name__ == "__main__":
    unittest.main()