# test_unstructured_prune_gmp.py
import sys
sys.path.append("../../")
import unittest
import paddle
import numpy as np
from paddleslim import UnstructuredPruner, GMPUnstructuredPruner
from paddle.vision.models import mobilenet_v1


class TestUnstructuredPruner(unittest.TestCase):
    """Check the ratio schedule exposed by ``GMPUnstructuredPruner``.

    The fixture is built in ``setUp`` rather than in ``__init__`` so that an
    error or failed sanity check while creating the model/pruner is reported
    by the test runner as part of the test, instead of being raised while the
    test case object itself is being constructed.
    """

    def setUp(self):
        """Turn off Paddle's static mode and build a fresh net and pruner."""
        paddle.disable_static()
        self._gen_model()

    def _gen_model(self):
        """Create ``self.net`` and ``self.pruner`` and sanity-check the pruner.

        The network is a non-pretrained ``mobilenet_v1`` with 10 classes; the
        pruner is a ``GMPUnstructuredPruner`` targeting ``ratio=0.55``.
        """
        self.net = mobilenet_v1(num_classes=10, pretrained=False)
        # Keys are forwarded to the pruner verbatim; keep their spelling
        # (including 'tunning_iterations') exactly as written.
        configs = {
            'stable_iterations': 0,
            'pruning_iterations': 1000,
            'tunning_iterations': 1000,
            'resume_iteration': 500,
            'pruning_steps': 20,
            'initial_ratio': 0.05,
        }
        self.pruner = GMPUnstructuredPruner(
            self.net, ratio=0.55, configs=configs)

        # NOTE(review): presumably the ratio is already above 0.3 right after
        # construction because the schedule resumes at iteration 500 --
        # confirm against the pruner implementation.
        self.assertGreater(self.pruner.ratio, 0.3)

    def test_unstructured_prune_gmp(self):
        """Step the pruner until ``ratios_stack`` is empty.

        After every step the pruner's ratio must not be lower than the value
        seen after the previous step (0.0 before the first step), and once
        the stack is exhausted the ratio must equal the 0.55 target.
        """
        ratio = 0.0
        while len(self.pruner.ratios_stack) > 0:
            self.pruner.step()
            # Remember the previous observation before reading the new one.
            last_ratio, ratio = ratio, self.pruner.ratio
            self.assertGreaterEqual(ratio, last_ratio)
        self.assertEqual(ratio, 0.55)


# Run the test cases in this module with unittest's command-line runner
# when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()