# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
from paddle.nn import Layer


class Net(Layer):
    """Toy layer that projects four inputs with one shared ``Linear``."""

    def __init__(self):
        super().__init__()
        # A single projection shared by all four inputs; it consumes rows of
        # 16 features, which is why ``forward`` reshapes to ``[-1, 16]``.
        self.fc = paddle.nn.Linear(16, 3)

    def forward(self, x, y, m, n):
        """Return the sum over the projections of ``x``, ``y``, ``m``, ``n``.

        Each input is reshaped to ``[-1, 16]`` and passed through
        ``self.fc``; the four results are stacked and reduced with
        ``paddle.sum``.
        """
        inputs = [x, y, m, n]
        outs = []
        for var in inputs:
            # Reshape the current loop variable. The earlier code reshaped
            # ``x`` on every iteration, so ``y``, ``m`` and ``n`` were
            # accepted as arguments but never contributed to the result.
            out = paddle.reshape(var, [-1, 16])
            out = self.fc(out)
            outs.append(out)

        out = paddle.stack(outs)
        return paddle.sum(out)


class TestArgsSpecName(unittest.TestCase):
    """Check trace counts and input naming of a ``to_static`` network.

    The same network is called with four tensors arranged in different
    aliasing patterns (all distinct, some repeated, all identical). After
    every call the test checks the cumulative value of
    ``net.forward.get_traced_count()`` and the repetition pattern of the
    names found in ``net.forward.inputs``.
    """

    def read_from_dataset(self):
        """Store four new random ``[4, 2, 8]`` tensors as x, y, m and n."""
        self.x = paddle.randn([4, 2, 8])
        self.y = paddle.randn([4, 2, 8])
        self.m = paddle.randn([4, 2, 8])
        self.n = paddle.randn([4, 2, 8])

    def test_spec_name_hash(self):
        """Walk through the aliasing patterns in a fixed order.

        The expected trace counts are cumulative, so the order of the calls
        below matters. Tensors are re-created before every call; presumably
        this is so that a cache hit depends on the aliasing pattern rather
        than on one particular set of tensors (confirm against the
        ``to_static`` cache-key implementation).
        """
        net = Net()
        net = paddle.jit.to_static(net)

        # Four distinct inputs: first program is traced.
        self.read_from_dataset()
        self.run_test(net, [self.x, self.y, self.m, self.n], 1, [0, 1, 2, 3])

        # Three distinct inputs (first two aliased): a new program is traced.
        self.read_from_dataset()
        self.run_test(net, [self.x, self.x, self.m, self.n], 2, [0, 0, 1, 2])

        # Two distinct inputs, paired as (a, a, b, b): a new program is traced.
        self.read_from_dataset()
        self.run_test(net, [self.x, self.x, self.m, self.m], 3, [0, 0, 1, 1])

        # Same (a, a, b, b) pattern with other tensors: the trace count is
        # expected to stay at 3, i.e. the cached program is reused.
        self.read_from_dataset()
        self.run_test(net, [self.n, self.n, self.y, self.y], 3, [0, 0, 1, 1])

        # Two distinct inputs, interleaved as (a, b, a, b): a new program is
        # traced.
        self.read_from_dataset()
        self.run_test(net, [self.x, self.y, self.x, self.y], 4, [0, 1, 0, 1])

        # Same (a, b, a, b) pattern with other tensors: cached program reused.
        self.read_from_dataset()
        self.run_test(net, [self.m, self.n, self.m, self.n], 4, [0, 1, 0, 1])

        # One tensor passed four times: a new program is traced.
        self.read_from_dataset()
        self.run_test(net, [self.x, self.x, self.x, self.x], 5, [0, 0, 0, 0])

        # Same (a, a, a, a) pattern with another tensor: cached program reused.
        self.read_from_dataset()
        self.run_test(net, [self.m, self.m, self.m, self.m], 5, [0, 0, 0, 0])

    def run_test(self, net, inputs, trace_count, mode):
        """Call ``net`` once, then check trace count and input-name pattern.

        Args:
            net: The network returned by ``paddle.jit.to_static``.
            inputs: Positional arguments forwarded to ``net``.
            trace_count: Expected ``net.forward.get_traced_count()`` after
                the call.
            mode: Expected repetition pattern of the names in
                ``net.forward.inputs`` (see ``assert_feed_mode``).
        """
        # The return value is not needed; the call is made for its effect on
        # the trace count and on ``net.forward.inputs``.
        net(*inputs)
        self.assertEqual(net.forward.get_traced_count(), trace_count)
        self.assert_feed_mode(net.forward.inputs, mode)

    def assert_feed_mode(self, inputs, expect_mode):
        """Assert that the names in ``inputs`` repeat as ``expect_mode`` says.

        Every distinct ``var.name`` is numbered in order of first appearance
        (0, 1, 2, ...). The list of those numbers, one per element of
        ``inputs``, must equal ``expect_mode``. For example, names
        ``["a", "a", "b"]`` produce ``[0, 0, 1]``.

        Args:
            inputs: List of objects exposing a ``name`` attribute.
            expect_mode: List of expected first-appearance indices.
        """
        # unittest assertions instead of bare ``assert`` so the checks are
        # not stripped when Python runs with ``-O``.
        self.assertIsInstance(inputs, list)
        self.assertIsInstance(expect_mode, list)

        name_ids = {}
        # ``setdefault`` stores the next free index the first time a name is
        # seen and returns the stored index on every later occurrence.
        mode = [name_ids.setdefault(var.name, len(name_ids)) for var in inputs]
        self.assertEqual(mode, expect_mode)


# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()