test_break_continue.py 7.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
16

17
import numpy as np
X
xiongkun 已提交
18
from dygraph_to_static_util import ast_only_test, dy2static_unittest
19

20
import paddle
21
from paddle import base
H
hjyp 已提交
22
from paddle.jit.api import to_static
23
from paddle.jit.dy2static.utils import Dygraph2StaticException
24 25 26 27 28

# Fixed seed for NumPy's global RNG so the random inputs built in the
# setUp methods below are reproducible from run to run.
SEED = 2020
np.random.seed(SEED)


X
xiongkun 已提交
29
@dy2static_unittest
class TestDy2staticException(unittest.TestCase):
    """Base case asserting that converting ``self.dyfunc`` raises an error.

    ``setUp`` leaves ``dyfunc`` as ``None``, in which case ``test_error``
    performs no assertion; subclasses override ``setUp`` to supply the
    function under test and the regex the error message must match.
    """

    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = None
        self.error = "Your if/else have different number of return value."

    @ast_only_test
    def test_error(self):
        """Expect ``Dygraph2StaticException`` matching ``self.error``."""
        try:
            if self.dyfunc:
                with self.assertRaisesRegex(
                    Dygraph2StaticException, self.error
                ):
                    paddle.jit.enable_to_static(True)
                    self.assertTrue(to_static(self.dyfunc)(self.x))
        finally:
            # Always restore the global switches, even when the assertion
            # above fails, so a failing case cannot leak to_static state
            # into the tests that run after it.
            paddle.base.dygraph.base.global_var._in_to_static_mode_ = False
            paddle.jit.enable_to_static(False)
44 45


46
def test_continue_in_for(x):
    """Fixture: ``continue`` in the middle of a Python ``for`` body.

    Adds 1 to ``x`` on each of the 10 iterations and additionally adds ``i``
    while ``i <= 5``; returns the accumulated ``x``.
    """
    x = base.dygraph.to_variable(x)
    for i in range(10):
        x += 1
        if i > 5:
            continue
            # Unreachable after `continue`. NOTE(review): presumably kept on
            # purpose as dead code for the transformer to handle — confirm
            # before removing.
            x += 10086
        x += i
    return x


def test_continue_in_for_at_end(x):
    """Fixture: ``continue`` as the last statement of a ``for`` body.

    Adds 1 to ``x`` on each of the 10 iterations and returns ``x``; the
    ``continue`` has no following statement to skip.
    """
    x = base.dygraph.to_variable(x)
    for i in range(10):
        x += 1
        if i > 5:
            continue
    return x


def test_continue_in_while(x):
    """Fixture: ``continue`` inside a ``while`` driven by a tensor counter.

    ``i`` is an int32 tensor of shape ``[1]`` starting at 0. Each pass
    increments ``i`` and adds it to ``x`` unless ``i > 5``; returns ``x``.
    """
    x = base.dygraph.to_variable(x)
    i = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=0)
    while i < 10:
        i += 1
        if i > 5:
            continue
            # Unreachable after `continue`. NOTE(review): presumably kept on
            # purpose as dead code for the transformer to handle — confirm
            # before removing.
            x += 10086
        x += i
    return x


def test_break_in_for(x):
    """Fixture: ``break`` in the middle of a Python ``for`` body.

    Adds 1 and then ``i`` to ``x`` per iteration, and leaves the loop right
    after the ``x += 1`` of the first iteration where ``i > 5``.
    """
    x = base.dygraph.to_variable(x)
    for i in range(10):
        x += 1
        if i > 5:
            break
            # Unreachable after `break`. NOTE(review): presumably kept on
            # purpose as dead code for the transformer to handle — confirm
            # before removing.
            x += 10086
        x += i
    return x


def test_break_in_for_at_end(x):
    """Fixture: ``break`` as the last statement of a ``for`` body.

    Adds 1 to ``x`` per iteration and leaves the loop once ``i > 5``;
    returns ``x``.
    """
    x = base.dygraph.to_variable(x)
    for i in range(10):
        x += 1
        if i > 5:
            break
    return x


def test_break_in_while(x):
    """Fixture: ``break`` inside a ``while`` driven by a tensor counter.

    ``i`` is an int32 tensor of shape ``[1]`` starting at 0. Each pass
    increments ``i`` and adds it to ``x``; the loop is left as soon as
    ``i > 5``. Returns ``x``.
    """
    x = base.dygraph.to_variable(x)
    i = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=0)
    while i < 10:
        i += 1
        if i > 5:
            break
            # Unreachable after `break`. NOTE(review): presumably kept on
            # purpose as dead code for the transformer to handle — confirm
            # before removing.
            x += 10086
        x += i
    return x


def test_break_continue_in_for(x):
    """Fixture: ``break`` and ``continue`` in opposite branches of an if/else.

    Runs two loops over ``x``: the first iterates a plain Python ``range``
    with a Python-int condition, the second iterates ``range(b)`` where ``b``
    is a tensor and branches on the tensor ``a``. In both loops every path
    through the if/else ends in ``continue`` or ``break``, so the trailing
    ``x += 10086`` is never reached. Returns ``x``.
    """
    x = base.dygraph.to_variable(x)

    for i in range(1, 10, 1):
        if i <= 4:
            x += 1
            continue
        else:
            x += 10010
            break
        # Unreachable: both branches above leave the iteration.
        x += 10086

    a = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=0)
    b = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=3)
    # b = 10
    # TODO: add Raise Error and suggestion for usage:
    #   Py for contains break/continue depends on control-flow.
    for i in range(b):
        if a <= 4:
            x += 1
            a += 1
            continue
        else:
            x += 10010
            break
        # Unreachable: both branches above leave the iteration.
        x += 10086

    return x


def test_for_in_else(x):
    """Fixture: ``for`` loops with ``break`` nested in an ``else`` branch.

    Both cases sit in the ``else`` of an ``if False``, so the loops always
    run. Case 1 breaks conditionally (``i > 5``); case 2 breaks
    unconditionally on the first iteration. Returns ``x``.
    """
    x = base.dygraph.to_variable(x)

    # Case 1:
    if False:
        pass
    else:
        for i in range(0, 10):
            if i > 5:
                x += 1
                break
            x += i

    # Case 2:
    if False:
        pass
    else:
        for i in range(0, 10):
            x += 1
            break
            # Unreachable after the unconditional `break`.
            x += i
    return x


164
def while_loop_class_var(x):
    """Fixture: ``while`` loop that rebinds attributes of a local object.

    Inside the loop ``foo.b`` and ``foo.c`` are replaced by tensors while
    ``foo.a`` stays the Python int 3; the body also contains both a
    ``continue`` and a ``break`` guarded by comparisons on ``foo.c``.
    Returns ``foo.c``.
    """

    class Foo:
        def __init__(self):
            self.a = 3
            self.b = 4
            self.c = 5

    foo = Foo()
    i = base.dygraph.to_variable(x)
    while i < 10:
        foo.b = paddle.zeros(shape=[1], dtype='float32')
        foo.c = foo.b + foo.a
        i += 1
        if foo.c < 0:
            continue
        if foo.c > 6:
            break
    return foo.c


184 185 186 187 188 189 190 191 192 193 194 195 196 197
def test_optim_break_in_for(x):
    """Fixture: ``for`` loop whose ``break`` condition depends on a tensor.

    The loop breaks when ``x.sum() > 5``; otherwise it adds ``i`` to ``x``
    and doubles ``x`` during the first three iterations. Returns ``x``.
    """
    x = paddle.to_tensor(x)
    for i in range(10):
        if x.sum() > 5:
            break
            # Unreachable after `break`. NOTE(review): presumably kept on
            # purpose as dead code for the transformer to handle — confirm
            # before removing.
            x += 10086
        x += i
        if i < 3:
            x = x * 2
    return x


def test_optim_break_in_while(x):
    """Fixture: ``while`` loop with a ``break`` placed before the updates.

    ``i`` is an int32 tensor of shape ``[1]`` starting at 0. While
    ``i <= 5`` each pass adds ``i`` to ``x`` and increments ``i``; the loop
    is left as soon as ``i > 5``. Returns ``x``.
    """
    x = paddle.to_tensor(x)
    i = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=0)
    while i < 10:
        if i > 5:
            break
            # Unreachable after `break`. NOTE(review): presumably kept on
            # purpose as dead code for the transformer to handle — confirm
            # before removing.
            x += 10086
        x += i
        i += 1
    return x


208 209
class TestContinueInFor(unittest.TestCase):
    """Compare a function's direct result with its ``to_static`` result.

    Subclasses override ``init_dygraph_func`` to select the function under
    test; the comparison logic is inherited unchanged.
    """

    def setUp(self):
        # Single-element int64 zero array fed to the function under test.
        self.input = np.zeros(1).astype('int64')
        # CUDA device 0 when Paddle was built with CUDA, otherwise CPU.
        # NOTE(review): not read by any method visible in this class.
        self.place = (
            base.CUDAPlace(0)
            if base.is_compiled_with_cuda()
            else base.CPUPlace()
        )
        self.init_dygraph_func()

    def init_dygraph_func(self):
        """Select the function under test; overridden by subclasses."""
        self.dygraph_func = test_continue_in_for

    def run_dygraph_mode(self):
        """Call the function directly and return the result as ndarray."""
        with base.dygraph.guard():
            res = self.dygraph_func(self.input)
            return res.numpy()

    def run_static_mode(self):
        """Call the ``to_static``-wrapped function and return an ndarray."""
        with base.dygraph.guard():
            res = to_static(self.dygraph_func)(self.input)
            return res.numpy()

    def test_transformed_static_result(self):
        """Both execution paths must agree within ``rtol=1e-05``."""
        static_res = self.run_static_mode()
        dygraph_res = self.run_dygraph_mode()
        np.testing.assert_allclose(
            dygraph_res,
            static_res,
            rtol=1e-05,
            err_msg=(
                f'dygraph res is {dygraph_res}\nstatic_res is {static_res}'
            ),
        )
242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277


class TestContinueInForAtEnd(TestContinueInFor):
    """Runs the inherited comparison on ``test_continue_in_for_at_end``."""

    def init_dygraph_func(self):
        self.dygraph_func = test_continue_in_for_at_end


class TestBreakInFor(TestContinueInFor):
    """Runs the inherited comparison on ``test_break_in_for``."""

    def init_dygraph_func(self):
        self.dygraph_func = test_break_in_for


class TestBreakInForAtEnd(TestContinueInFor):
    """Runs the inherited comparison on ``test_break_in_for_at_end``."""

    def init_dygraph_func(self):
        self.dygraph_func = test_break_in_for_at_end


class TestBreakContinueInFor(TestContinueInFor):
    """Runs the inherited comparison on ``test_break_continue_in_for``."""

    def init_dygraph_func(self):
        self.dygraph_func = test_break_continue_in_for


class TestForInElse(TestContinueInFor):
    """Runs the inherited comparison on ``test_for_in_else``."""

    def init_dygraph_func(self):
        self.dygraph_func = test_for_in_else


class TestContinueInWhile(TestContinueInFor):
    """Runs the inherited comparison on ``test_continue_in_while``.

    Also serves as the base class for the other ``while``-loop cases.
    """

    def init_dygraph_func(self):
        self.dygraph_func = test_continue_in_while


class TestBreakInWhile(TestContinueInWhile):
    """Runs the inherited comparison on ``test_break_in_while``."""

    def init_dygraph_func(self):
        self.dygraph_func = test_break_in_while

278 279 280 281

class TestWhileLoopClassVar(TestContinueInWhile):
    """Runs the inherited comparison on ``while_loop_class_var``."""

    def init_dygraph_func(self):
        self.dygraph_func = while_loop_class_var
282 283


284 285 286 287 288
class TestOptimBreakInFor(TestDy2staticException):
    """Expects converting ``test_optim_break_in_for`` to raise an error.

    Uses the inherited ``test_error`` with the function and expected
    message configured here.
    """

    def setUp(self):
        # Fully replaces the base setUp (no super() call), so the input is
        # re-created here with the same shape and dtype as in the base class.
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = test_optim_break_in_for
        # Passed to assertRaisesRegex, i.e. treated as a regular expression.
        self.error = "python while pred change from bool to variable."
289 290 291 292 293 294 295


class TestOptimBreakInWhile(TestContinueInWhile):
    """Runs the inherited comparison on ``test_optim_break_in_while``."""

    def init_dygraph_func(self):
        self.dygraph_func = test_optim_break_in_while


296
if __name__ == '__main__':
    # Run every TestCase in this module when executed as a script.
    unittest.main()