# test_prune_walker.py
import sys
sys.path.append("../../")
import unittest
import numpy as np
import paddle
from paddleslim.dygraph import L1NormFilterPruner
from paddle.nn import Conv2D, Linear, Layer


class Net(Layer):
    """Tiny conv -> flatten -> linear network used as a pruning target.

    The linear layer's input size (8 * 30 * 30) is hard-wired for inputs of
    shape (N, 3, 32, 32): a 3x3 kernel with Conv2D's default stride 1 and
    padding 0 shrinks 32x32 to 30x30 across 8 output channels.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = Conv2D(in_channels=3, out_channels=8, kernel_size=3)
        self.linear = Linear(in_features=8 * 30 * 30, out_features=5)

    def forward(self, x):
        # Keep the batch axis; collapse channels and spatial dims into one.
        features = paddle.flatten(self.conv1(x), start_axis=1)
        return self.linear(features)


class TestWalker(unittest.TestCase):
    """Check that pruning conv filters propagates to the following linear layer.

    The pruner is asked to prune ``conv2d_0.w_0`` (presumably the
    auto-generated name of ``Net.conv1``'s weight). Removing output channels
    there must also shrink the input dimension of ``Net.linear``, which
    consumes the flattened conv output.
    """

    def runTest(self):
        x_shape = (1, 3, 32, 32)
        net = Net()
        x = np.random.uniform(-1, 1, x_shape).astype('float32')
        pruner = L1NormFilterPruner(net, [paddle.to_tensor(x)])
        # Prune 20% of conv1's filters along axis 0 (output channels).
        pruner.prune_vars({"conv2d_0.w_0": 0.2}, [0])

        # The expected linear size 5400 == 6 * 30 * 30 implies 6 of the 8
        # filters survive, so the conv weight itself must shrink to 6 filters.
        # list() normalizes the shape so assertEqual reports a readable diff.
        self.assertEqual(list(net.conv1.weight.shape), [6, 3, 3, 3])
        self.assertEqual(list(net.linear.weight.shape), [5400, 5])


if __name__ == "__main__":
    # Run this module's test cases when executed directly as a script.
    unittest.main()