# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle import fluid
from paddle.jit.api import to_static
from paddle.jit.dy2static.utils import Dygraph2StaticException

# Seed NumPy's global RNG once at import time so the random inputs built in
# the test cases below are reproducible from run to run.
SEED = 2020
np.random.seed(seed=SEED)


class TestDy2staticException(unittest.TestCase):
    """Base case: converting ``self.dyfunc`` with ``to_static`` must raise.

    Subclasses set ``self.dyfunc`` to a dygraph function and ``self.error``
    to a regex expected in the raised ``Dygraph2StaticException``. When
    ``self.dyfunc`` is ``None`` (as in this base class) the conversion check
    is skipped and only the global-state reset runs.
    """

    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = None
        self.error = "Your if/else have different number of return value."

    def test_error(self):
        try:
            if self.dyfunc:
                with self.assertRaisesRegex(
                    Dygraph2StaticException, self.error
                ):
                    paddle.jit.enable_to_static(True)
                    self.assertTrue(to_static(self.dyfunc)(self.x))
        finally:
            # Reset the global dy2static state even when the assertion above
            # fails. Previously a failure skipped these lines and the modified
            # state leaked into tests that run afterwards.
            paddle.fluid.dygraph.base.global_var._in_declarative_mode_ = False
            # NOTE(review): this leaves to_static disabled for any test that
            # runs later in the same process; confirm that restoring the
            # default (True) was not intended.
            paddle.jit.enable_to_static(False)


def test_continue_in_for(x):
    # Dygraph fixture used by TestContinueInFor: `continue` inside a Python
    # `for` loop over range(10), compared against its to_static result.
    x = fluid.dygraph.to_variable(x)
    for i in range(10):
        x += 1
        if i > 5:
            continue
            # Unreachable after `continue`; presumably kept to check that the
            # converter tolerates dead code after the jump.
            x += 10086
        x += i
    return x


def test_continue_in_for_at_end(x):
    # Dygraph fixture: the `if ... continue` is the last statement of the
    # `for` body, so nothing follows the jump within the iteration.
    x = fluid.dygraph.to_variable(x)
    for i in range(10):
        x += 1
        if i > 5:
            continue
    return x


def test_continue_in_while(x):
    # Dygraph fixture: `continue` inside a `while` loop whose condition reads
    # the tensor counter `i` (created with fill_constant).
    x = fluid.dygraph.to_variable(x)
    i = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=0)
    while i < 10:
        i += 1
        if i > 5:
            continue
            x += 10086  # unreachable after `continue`
        x += i
    return x


def test_break_in_for(x):
    # Dygraph fixture: `break` inside a Python `for` loop over range(10);
    # the loop exits on the first iteration where i > 5.
    x = fluid.dygraph.to_variable(x)
    for i in range(10):
        x += 1
        if i > 5:
            break
            # Unreachable after `break`; presumably kept to check that the
            # converter tolerates dead code after the jump.
            x += 10086
        x += i
    return x


def test_break_in_for_at_end(x):
    # Dygraph fixture: the `if ... break` is the last statement of the `for`
    # body, so nothing follows the jump within the iteration.
    x = fluid.dygraph.to_variable(x)
    for i in range(10):
        x += 1
        if i > 5:
            break
    return x


def test_break_in_while(x):
    # Dygraph fixture: `break` inside a `while` loop whose condition reads
    # the tensor counter `i` (created with fill_constant).
    x = fluid.dygraph.to_variable(x)
    i = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=0)
    while i < 10:
        i += 1
        if i > 5:
            break
            x += 10086  # unreachable after `break`
        x += i
    return x


def test_break_continue_in_for(x):
    # Dygraph fixture combining `break` and `continue` in two `for` loops:
    # first over a Python range, then over `range(b)` where `b` is a tensor.
    x = fluid.dygraph.to_variable(x)

    # Loop 1: Python range. `continue` for i <= 4, `break` at i == 5; the
    # trailing `x += 10086` is unreachable on every path.
    for i in range(1, 10, 1):
        if i <= 4:
            x += 1
            continue
        else:
            x += 10010
            break
        x += 10086

    a = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=0)
    b = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=3)
    # Loop 2: `b` is a tensor here rather than a Python int (an earlier
    # variant used `b = 10`), and the branch condition reads tensor `a`.
    # TODO: raise an error with a usage suggestion when a Python `for` loop
    #   containing break/continue depends on tensor control flow.
    for i in range(b):
        if a <= 4:
            x += 1
            a += 1
            continue
        else:
            x += 10010
            break
        x += 10086

    return x


def test_for_in_else(x):
    # Dygraph fixture: `for` loops containing `break`, nested inside the
    # `else` branch of a constant-False `if`.
    x = fluid.dygraph.to_variable(x)

    # Case 1: conditional `break` once i > 5.
    if False:
        pass
    else:
        for i in range(0, 10):
            if i > 5:
                x += 1
                break
            x += i

    # Case 2: unconditional `break` on the first iteration; `x += i` is
    # unreachable.
    if False:
        pass
    else:
        for i in range(0, 10):
            x += 1
            break
            x += i
    return x


def while_loop_class_var(x):
    # Dygraph fixture: a `while` loop that writes attributes of a locally
    # defined object and uses them in `continue`/`break` conditions.
    class Foo:
        def __init__(self):
            self.a = 3
            self.b = 4
            self.c = 5

    foo = Foo()
    i = fluid.dygraph.to_variable(x)
    while i < 10:
        foo.b = paddle.zeros(shape=[1], dtype='float32')
        # foo.b is zeros and foo.a is 3, so foo.c is 3 on every iteration and
        # neither branch below fires; the loop ends via the `i < 10` check.
        foo.c = foo.b + foo.a
        i += 1
        if foo.c < 0:
            continue
        if foo.c > 6:
            break
    return foo.c


def test_optim_break_in_for(x):
    # Dygraph fixture: Python `for` over range(10) whose `break` condition
    # depends on tensor data (`x.sum() > 5`). TestOptimBreakInFor expects
    # to_static conversion of this function to raise Dygraph2StaticException
    # with "python while pred change from bool to variable."
    x = paddle.to_tensor(x)
    for i in range(10):
        if x.sum() > 5:
            break
            x += 10086  # unreachable after `break`
        x += i
        if i < 3:
            x = x * 2
    return x


def test_optim_break_in_while(x):
    # Dygraph fixture: `break` in a `while` loop driven by the tensor counter
    # `i`. Unlike test_break_in_while, the counter is incremented at the end
    # of the body, after the `break` check.
    x = paddle.to_tensor(x)
    i = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=0)
    while i < 10:
        if i > 5:
            break
            x += 10086  # unreachable after `break`
        x += i
        i += 1
    return x


class TestContinueInFor(unittest.TestCase):
    """Compare a dygraph function's result with its ``to_static`` result.

    Subclasses override ``init_dygraph_func`` to choose the function under
    test. Each function receives ``self.input`` (a 1-element int64 array).
    """

    def setUp(self):
        self.input = np.zeros(1).astype('int64')
        # NOTE(review): self.place is not read by any method visible in this
        # file; kept for compatibility.
        self.place = (
            fluid.CUDAPlace(0)
            if fluid.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
        self.init_dygraph_func()

    def init_dygraph_func(self):
        """Select the dygraph function under test (override in subclasses)."""
        self.dygraph_func = test_continue_in_for

    def run_dygraph_mode(self):
        """Run the function eagerly and return its result as a NumPy array."""
        with fluid.dygraph.guard():
            res = self.dygraph_func(self.input)
            return res.numpy()

    def run_static_mode(self):
        """Run the ``to_static``-converted function; return a NumPy array."""
        with fluid.dygraph.guard():
            res = to_static(self.dygraph_func)(self.input)
            return res.numpy()

    def test_transformed_static_result(self):
        static_res = self.run_static_mode()
        dygraph_res = self.run_dygraph_mode()
        np.testing.assert_allclose(
            dygraph_res,
            static_res,
            rtol=1e-05,
            err_msg='dygraph res is {}\nstatic_res is {}'.format(
                dygraph_res, static_res
            ),
        )


class TestContinueInForAtEnd(TestContinueInFor):
    """`continue` as the last statement of a `for` body."""

    def init_dygraph_func(self):
        """Use test_continue_in_for_at_end as the function under test."""
        self.dygraph_func = test_continue_in_for_at_end


class TestBreakInFor(TestContinueInFor):
    """`break` followed by unreachable code inside a `for` loop."""

    def init_dygraph_func(self):
        """Use test_break_in_for as the function under test."""
        self.dygraph_func = test_break_in_for


class TestBreakInForAtEnd(TestContinueInFor):
    """`break` as the last statement of a `for` body."""

    def init_dygraph_func(self):
        """Use test_break_in_for_at_end as the function under test."""
        self.dygraph_func = test_break_in_for_at_end


class TestBreakContinueInFor(TestContinueInFor):
    """`break` and `continue` in Python-range and tensor-range loops."""

    def init_dygraph_func(self):
        """Use test_break_continue_in_for as the function under test."""
        self.dygraph_func = test_break_continue_in_for


class TestForInElse(TestContinueInFor):
    """`for` loops with `break` nested in the `else` branch of an `if`."""

    def init_dygraph_func(self):
        """Use test_for_in_else as the function under test."""
        self.dygraph_func = test_for_in_else


class TestContinueInWhile(TestContinueInFor):
    """`continue` inside a `while` loop driven by a tensor counter."""

    def init_dygraph_func(self):
        """Use test_continue_in_while as the function under test."""
        self.dygraph_func = test_continue_in_while


class TestBreakInWhile(TestContinueInWhile):
    """`break` inside a `while` loop driven by a tensor counter."""

    def init_dygraph_func(self):
        """Use test_break_in_while as the function under test."""
        self.dygraph_func = test_break_in_while


class TestWhileLoopClassVar(TestContinueInWhile):
    """`while` loop that writes attributes of a locally defined object."""

    def init_dygraph_func(self):
        """Use while_loop_class_var as the function under test."""
        self.dygraph_func = while_loop_class_var


class TestOptimBreakInFor(TestDy2staticException):
    """Converting test_optim_break_in_for must raise Dygraph2StaticException.

    The expected message is the regex stored in ``self.error``.
    """

    def setUp(self):
        self.dyfunc = test_optim_break_in_for
        self.error = "python while pred change from bool to variable."
        self.x = np.random.random([10, 16]).astype('float32')


class TestOptimBreakInWhile(TestContinueInWhile):
    """`break` in a `while` loop whose counter is bumped after the check."""

    def init_dygraph_func(self):
        """Use test_optim_break_in_while as the function under test."""
        self.dygraph_func = test_optim_break_in_while


if __name__ == '__main__':
    # Run every TestCase in this module when executed as a script.
    unittest.main()