#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy

import paddle
from paddle import fluid
from paddle.jit.api import to_static


@paddle.jit.to_static
def dyfunc_assert_variable(x):
    """Convert ``x`` to a Paddle variable and assert on the result.

    Args:
        x: Value handed to ``fluid.dygraph.to_variable``. The tests in this
            file pass one-element ``numpy`` arrays.

    Raises:
        An exception when the converted variable is falsy. The concrete
        exception type is not pinned here; the tests in this file only
        check for ``BaseException``.
    """
    x_v = fluid.dygraph.to_variable(x)
    assert x_v


@to_static
def dyfunc_assert_non_variable(x=True):
    """Assert on ``x`` directly, without converting it to a variable.

    Args:
        x: Value whose truthiness is asserted. Defaults to ``True``.
    """
    assert x


class TestAssertVariable(unittest.TestCase):
    """Check ``assert`` handling with dygraph-to-static both enabled and disabled."""

    def _run(self, func, x, with_exception, to_static):
        """Call ``func(x)`` inside a dygraph guard with the given to_static setting.

        Args:
            func: Callable under test; invoked as ``func(x)``.
            x: Argument forwarded to ``func``.
            with_exception: When true, the call is expected to raise.
            to_static: Value passed to ``paddle.jit.enable_to_static``.
        """
        paddle.jit.enable_to_static(to_static)
        try:
            if with_exception:
                # BaseException is deliberately broad: the test does not pin
                # the exception type raised by a failing assert.
                with self.assertRaises(BaseException):
                    with fluid.dygraph.guard():
                        func(x)
            else:
                with fluid.dygraph.guard():
                    func(x)
        finally:
            # Do not leak a disabled global switch into other tests.
            # NOTE(review): assumes "enabled" is Paddle's default state --
            # confirm against the Paddle version in use.
            paddle.jit.enable_to_static(True)

    def _run_dy_static(self, func, x, with_exception):
        """Run ``func(x)`` twice: first with to_static enabled, then disabled."""
        self._run(func, x, with_exception, True)
        self._run(func, x, with_exception, False)

    def test_non_variable(self):
        """A plain Python bool: ``False`` must raise, ``True`` must pass."""
        self._run_dy_static(
            dyfunc_assert_non_variable, x=False, with_exception=True
        )
        self._run_dy_static(
            dyfunc_assert_non_variable, x=True, with_exception=False
        )

    def test_bool_variable(self):
        """A one-element bool array: ``[False]`` must raise, ``[True]`` must pass."""
        self._run_dy_static(
            dyfunc_assert_variable, x=numpy.array([False]), with_exception=True
        )
        self._run_dy_static(
            dyfunc_assert_variable, x=numpy.array([True]), with_exception=False
        )

    def test_int_variable(self):
        """A one-element int array: ``[0]`` must raise, ``[1]`` must pass."""
        self._run_dy_static(
            dyfunc_assert_variable, x=numpy.array([0]), with_exception=True
        )
        self._run_dy_static(
            dyfunc_assert_variable, x=numpy.array([1]), with_exception=False
        )


# Run the test cases in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()