"""Unit tests for ``paddleslim.core.dygraph2program`` with eager mode enabled."""
import os
import sys

# NOTE(review): presumably makes the in-repo ``paddleslim`` package importable
# when the tests are run from this directory; the path is relative to the CWD.
sys.path.append("../")

# Set before ``import paddle`` below. Presumably paddle reads FLAGS_*
# environment variables at import time, so this ordering matters — confirm.
os.environ['FLAGS_enable_eager_mode'] = "1"

import paddle
import unittest
from paddleslim.core import dygraph2program


class Model(paddle.nn.Layer):
    """Small CNN classifier used as the tracing target in these tests.

    Layers: a 3x3 conv (1 -> 256 channels, stride 1, padding 1), adaptive
    average pooling down to 1x1, and a 256 -> 10 linear head.
    """

    def __init__(self):
        super().__init__()
        # Sub-layers are created in the order conv -> pool -> linear.
        self.conv = paddle.nn.Conv2D(
            in_channels=1,
            out_channels=256,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.pool2d_avg = paddle.nn.AdaptiveAvgPool2D([1, 1])
        self.out = paddle.nn.Linear(256, 10)

    def forward(self, inputs):
        # A 0 in the target shape copies that dim from ``inputs``; each sample
        # is then viewed as a single-channel 28x28 image.
        x = paddle.reshape(inputs, shape=[0, 1, 28, 28])
        pooled = self.pool2d_avg(self.conv(x))
        # Flatten the (N, 256, 1, 1) pooled features for the linear head.
        flat = paddle.reshape(pooled, shape=[-1, 256])
        return self.out(flat)


class TestEagerDygraph2Program(unittest.TestCase):
    """Check the op sequence ``dygraph2program`` traces from ``Model``.

    Subclasses override ``prepare_inputs`` to try other forms of the
    ``inputs`` argument. Each override sets ``self.inputs`` and the expected
    list of op types ``self.ops``.
    """

    def setUp(self):
        self.prepare_inputs()
        self.prepare_layer()

    def prepare_inputs(self):
        # A flat list of ints; presumably interpreted by dygraph2program as
        # the shape of a single input — TODO confirm.
        self.inputs = [3, 28, 28]
        # Expected op types of block 0 of the traced program, in order.
        self.ops = [
            'assign_value', 'reshape2', 'conv2d', 'reshape2', 'elementwise_add',
            'pool2d', 'reshape2', 'matmul_v2', 'elementwise_add'
        ]

    def prepare_layer(self):
        self.layer = Model()

    def test_dy2prog(self):
        program = dygraph2program(self.layer, self.inputs)
        self.assert_program(program)

    def assert_program(self, program):
        """Assert block 0 of ``program`` contains exactly ``self.ops``, in order."""
        self.assertListEqual([op.type for op in program.block(0).ops], self.ops)


class TestEagerDygraph2Program2(TestEagerDygraph2Program):
    """Same check, with ``inputs`` given as a list holding one shape."""

    def prepare_inputs(self):
        # Outer list presumably holds one shape per model input — TODO confirm.
        self.inputs = [[3, 28, 28]]
        # Expected op types of block 0 of the traced program, in order.
        self.ops = [
            'assign_value', 'reshape2', 'conv2d', 'reshape2', 'elementwise_add',
            'pool2d', 'reshape2', 'matmul_v2', 'elementwise_add'
        ]


class TestEagerDygraph2Program3(TestEagerDygraph2Program):
    """Same check, with ``inputs`` given as a single tensor."""

    def prepare_inputs(self):
        self.inputs = paddle.randn([3, 28, 28])
        # Expected op types of block 0, in order. Unlike the shape-based
        # inputs, no leading 'assign_value' op is expected here.
        self.ops = [
            'reshape2', 'conv2d', 'reshape2', 'elementwise_add', 'pool2d',
            'reshape2', 'matmul_v2', 'elementwise_add'
        ]


class TestEagerDygraph2Program4(TestEagerDygraph2Program):
    """Same check, with ``inputs`` given as a list holding one tensor."""

    def prepare_inputs(self):
        self.inputs = [paddle.randn([3, 28, 28])]
        # Expected op types of block 0, in order. Unlike the shape-based
        # inputs, no leading 'assign_value' op is expected here.
        self.ops = [
            'reshape2', 'conv2d', 'reshape2', 'elementwise_add', 'pool2d',
            'reshape2', 'matmul_v2', 'elementwise_add'
        ]


if __name__ == "__main__":
    # Run every TestCase defined in this module (1 is unittest's default
    # verbosity; -v / -q on the command line still override it).
    unittest.main(verbosity=1)