test_itpruner.py 576 字节
Newer Older
R
rehulisw 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
import sys
sys.path.append("../")
import unittest
import paddle
from paddleslim.nas.itpruner import ITPruner
from paddleslim.nas.itpruner.Cifar.nets.resnet_cifar import ResNetCifar


class TestITPruner(unittest.TestCase):
    """Smoke test for ``ITPruner`` applied to a small ``ResNetCifar`` model."""

    def test_itpruner(self):
        """Run ``ITPruner.prune`` end to end.

        The test passes as long as construction and pruning complete without
        raising; no assertion is made on the pruning result.
        """
        # Seed the global Paddle RNG first so that both the network's weight
        # initialization and the random input batch below are identical on
        # every run, which makes any failure reproducible.
        paddle.seed(2021)

        net = ResNetCifar(depth=20, num_classes=10, cfg=None)
        # Random batch of shape [100, 3, 32, 32] handed to the ITPruner
        # constructor; presumably NCHW (batch, channels, height, width) to
        # match CIFAR-sized images — TODO confirm against ITPruner's docs.
        data = paddle.normal(shape=[100, 3, 32, 32])

        itpruner = ITPruner(net, data)
        # Target FLOPs value passed to prune().
        target_flops = 20800000
        # NOTE(review): the meaning of ``beta`` is defined by ITPruner.prune;
        # this value is kept unchanged from the original test.
        beta = 243

        itpruner.prune(target_flops, beta)


if __name__ == "__main__":
    # Running this file directly (e.g. `python test_itpruner.py`) discovers
    # and runs every TestCase defined in this module.
    unittest.main()