test_assert.py 2.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

17
import numpy
H
hong 已提交
18
from dygraph_to_static_util import test_and_compare_with_new_ir
19

A
Aurelius84 已提交
20
import paddle
21
from paddle import fluid
H
hjyp 已提交
22
from paddle.jit.api import to_static
23 24


A
Aurelius84 已提交
25
@paddle.jit.to_static
def dyfunc_assert_variable(x):
    """Wrap ``x`` with ``fluid.dygraph.to_variable`` and ``assert`` on it.

    Decorated with ``paddle.jit.to_static`` so the ``assert`` statement is
    exercised both through dygraph-to-static conversion and, when
    conversion is disabled, as plain dygraph code.
    """
    x_v = fluid.dygraph.to_variable(x)
    assert x_v


H
hjyp 已提交
31
@to_static
def dyfunc_assert_non_variable(x=True):
    """``assert`` on ``x`` directly, without converting it to a variable.

    Decorated with ``to_static`` so the ``assert`` on a plain Python value
    is exercised under dygraph-to-static conversion as well as in dygraph.
    """
    assert x


class TestAssertVariable(unittest.TestCase):
    """Check ``assert`` behavior with dygraph-to-static enabled and disabled.

    Every case runs the target function twice, once with
    ``paddle.jit.enable_to_static(True)`` and once with ``False``, and checks
    whether the call raises.
    """

    def _run(self, func, x, with_exception, to_static):
        """Call ``func(x)`` once inside ``fluid.dygraph.guard()``.

        Args:
            func: The function under test.
            x: The argument passed to ``func``.
            with_exception: If True, the call is expected to raise.
            to_static: Value passed to ``paddle.jit.enable_to_static`` for
                the duration of this call.
        """
        paddle.jit.enable_to_static(to_static)
        try:
            if with_exception:
                # The raised type is not pinned here (it may differ between
                # converted and dygraph execution), so accept any Exception.
                # BaseException would also swallow KeyboardInterrupt and
                # SystemExit and report them as a passing test.
                with self.assertRaises(Exception):
                    with fluid.dygraph.guard():
                        func(x)
            else:
                with fluid.dygraph.guard():
                    func(x)
        finally:
            # enable_to_static toggles global state; restore it so the
            # setting chosen here does not leak into other tests.
            # NOTE(review): assumes True is the framework default — confirm.
            paddle.jit.enable_to_static(True)

    def _run_dy_static(self, func, x, with_exception):
        """Run ``func(x)`` with to_static enabled, then with it disabled."""
        self._run(func, x, with_exception, True)
        self._run(func, x, with_exception, False)

    @test_and_compare_with_new_ir(False)
    def test_non_variable(self):
        """A plain Python bool: False must raise, True must not."""
        self._run_dy_static(
            dyfunc_assert_non_variable, x=False, with_exception=True
        )
        self._run_dy_static(
            dyfunc_assert_non_variable, x=True, with_exception=False
        )

    @test_and_compare_with_new_ir(False)
    def test_bool_variable(self):
        """A one-element bool array: [False] must raise, [True] must not."""
        self._run_dy_static(
            dyfunc_assert_variable, x=numpy.array([False]), with_exception=True
        )
        self._run_dy_static(
            dyfunc_assert_variable, x=numpy.array([True]), with_exception=False
        )

    def test_int_variable(self):
        """A one-element int array: [0] must raise, [1] must not."""
        self._run_dy_static(
            dyfunc_assert_variable, x=numpy.array([0]), with_exception=True
        )
        self._run_dy_static(
            dyfunc_assert_variable, x=numpy.array([1]), with_exception=False
        )


if __name__ == '__main__':
    unittest.main()