# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from test_declarative import foo_func

import paddle
from paddle.jit.dy2static.function_spec import FunctionSpec
from paddle.static import InputSpec

# Enable Paddle's static mode once, at module import, before any test runs;
# test_args_to_input_spec builds its input placeholders with
# paddle.static.data.
paddle.enable_static()


class TestFunctionSpec(unittest.TestCase):
    """Unit tests for ``FunctionSpec`` wrapping ``foo_func``.

    The expectations below pin ``foo_func``'s signature as seen through
    ``FunctionSpec``: parameters ``a, b, c, d``, where an omitted ``c``
    resolves to 1 and an omitted ``d`` resolves to 2.
    """

    def test_constructor(self):
        """Without ``input_spec``, the spec records the wrapped function and
        its argument names, and leaves ``input_spec`` as ``None``."""
        foo_spec = FunctionSpec(foo_func)
        self.assertListEqual(foo_spec.args_name, ['a', 'b', 'c', 'd'])
        self.assertEqual(foo_spec.dygraph_function, foo_func)
        self.assertIsNone(foo_spec.input_spec)

    def test_verify_input_spec(self):
        """``input_spec`` must be a list or tuple of specs."""
        a_spec = InputSpec([None, 10], name='a')
        b_spec = InputSpec([10], name='b')

        # A bare InputSpec (not wrapped in a list/tuple) is a TypeError.
        with self.assertRaises(TypeError):
            FunctionSpec(foo_func, input_spec=a_spec)

        # Two specs in a list yield two entries in flat_input_spec.
        foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec])
        self.assertEqual(len(foo_spec.flat_input_spec), 2)

    def test_unified_args_and_kwargs(self):
        """``unified_args_and_kwargs`` folds keyword arguments and defaults
        into one positional tuple and returns empty kwargs."""
        foo_spec = FunctionSpec(foo_func)

        # case 1: foo(10, 20, c=4)
        args, kwargs = foo_spec.unified_args_and_kwargs([10, 20], {'c': 4})
        self.assertTupleEqual(args, (10, 20, 4, 2))
        self.assertEqual(len(kwargs), 0)

        # case 2: foo(a=10, b=20, d=4)
        args, kwargs = foo_spec.unified_args_and_kwargs(
            [], {'a': 10, 'b': 20, 'd': 4}
        )
        self.assertTupleEqual(args, (10, 20, 1, 4))
        self.assertEqual(len(kwargs), 0)

        # case 3: foo(10, b=20)
        args, kwargs = foo_spec.unified_args_and_kwargs([10], {'b': 20})
        self.assertTupleEqual(args, (10, 20, 1, 2))
        self.assertEqual(len(kwargs), 0)

        # More positional arguments than parameters is a ValueError.
        with self.assertRaises(ValueError):
            foo_spec.unified_args_and_kwargs([10, 20, 30, 40, 50], {'c': 4})

        # ``b`` is passed neither positionally nor by keyword: ValueError.
        with self.assertRaises(ValueError):
            foo_spec.unified_args_and_kwargs([10], {'c': 4})

    def test_args_to_input_spec(self):
        """``args_to_input_spec`` pairs call arguments with ``input_spec``."""
        a_spec = InputSpec([None, 10], name='a', stop_gradient=True)
        b_spec = InputSpec([10], name='b', stop_gradient=True)

        a_tensor = paddle.static.data(name='a_var', shape=[4, 10])
        b_tensor = paddle.static.data(name='b_var', shape=[4, 10])

        # case 1: one spec per tensor argument, plus two plain values.
        foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec])
        input_with_spec, _ = foo_spec.args_to_input_spec(
            (a_tensor, b_tensor, 1, 2), {}
        )
        self.assertEqual(len(input_with_spec), 4)
        self.assertEqual(input_with_spec[0], a_spec)  # a
        # The expected spec for b carries b_tensor's shape [4, 10],
        # not b_spec's declared shape [10].
        ans_b_spec = InputSpec([4, 10], name='b', stop_gradient=True)
        self.assertEqual(input_with_spec[1], ans_b_spec)  # b
        self.assertEqual(input_with_spec[2], 1)  # c
        self.assertEqual(input_with_spec[3], 2)  # d

        # case 2: fewer specs than tensor arguments; the entry for b takes
        # b_tensor's shape and name.
        foo_spec = FunctionSpec(foo_func, input_spec=[a_spec])
        input_with_spec, _ = foo_spec.args_to_input_spec(
            (a_tensor, b_tensor), {}
        )
        self.assertEqual(len(input_with_spec), 2)
        self.assertEqual(input_with_spec[0], a_spec)  # a
        self.assertTupleEqual(input_with_spec[1].shape, (4, 10))  # b.shape
        self.assertEqual(input_with_spec[1].name, 'b_var')  # b.name

        # case 3: keyword arguments are rejected once `input_spec` is set.
        foo_spec = FunctionSpec(foo_func, input_spec=[a_spec])
        with self.assertRaises(ValueError):
            foo_spec.args_to_input_spec((a_tensor, b_tensor), {'c': 4})

        # case 4: fewer positional arguments than specs is rejected.
        foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec])
        with self.assertRaises(ValueError):
            foo_spec.args_to_input_spec((a_tensor,), {})


# Allow running this test module directly, outside an external test runner.
if __name__ == '__main__':
    unittest.main()