#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
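"""Unit tests for the dygraph-to-static (Dy2St) transformation of `if`/`else`
control flow; the functions and networks under test come from
ifelse_simple_func."""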

import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
from ifelse_simple_func import (
    NetWithControlFlowIf,
    add_fn,
    base,
    dyfunc_empty_nonlocal,
    dyfunc_ifelse_ret_int1,
    dyfunc_ifelse_ret_int2,
    dyfunc_ifelse_ret_int3,
    dyfunc_ifelse_ret_int4,
    dyfunc_with_if_else,
    dyfunc_with_if_else2,
    dyfunc_with_if_else3,
    dyfunc_with_if_else_with_list_generator,
    if_tensor_case,
    if_with_and_or,
    if_with_and_or_1,
    if_with_and_or_2,
    if_with_and_or_3,
    if_with_and_or_4,
    if_with_class_var,
    loss_fn,
    nested_if_else,
    nested_if_else_2,
    nested_if_else_3,
)

import paddle
import paddle.nn.functional as F
from paddle.base import core
from paddle.jit.dy2static.utils import Dygraph2StaticException

np.random.seed(1)

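# Run on GPU when Paddle is compiled with CUDA support, otherwise fall back to CPU.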
if base.is_compiled_with_cuda():
    place = base.CUDAPlace(0)
else:
    place = base.CPUPlace()


@dy2static_unittest
class TestDy2staticException(unittest.TestCase):
    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = None
        self.error = "Your if/else have different number of return value."

    @ast_only_test
    def test_error(self):
        if self.dyfunc:
            with self.assertRaisesRegex(Dygraph2StaticException, self.error):
                paddle.jit.enable_to_static(True)
                self.assertTrue(paddle.jit.to_static(self.dyfunc)(self.x))
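        # The expected exception may leave `_in_to_static_mode_` set, so reset
        # it manually (see the longer note in TestDy2StIfElseRetInt4 below).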
        paddle.base.dygraph.base.global_var._in_to_static_mode_ = False
        paddle.jit.enable_to_static(False)


class TestDygraphIfElse(unittest.TestCase):
    """
    TestCase for transforming dygraph control flow `if/else` that depends on
    a tensor into the static `base.layers.cond` construct.
    """

    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = dyfunc_with_if_else

    def _run_static(self):
        return self._run_dygraph(to_static=True)

    def _run_dygraph(self, to_static=False):
        with base.dygraph.guard(place):
            x_v = base.dygraph.to_variable(self.x)
            if to_static:
                ret = paddle.jit.to_static(self.dyfunc)(x_v)
            else:
                ret = self.dyfunc(x_v)
            return ret.numpy()

    def test_ast_to_func(self):
        self.assertTrue((self._run_dygraph() == self._run_static()).all())


class TestDygraphIfElse2(TestDygraphIfElse):
    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = dyfunc_with_if_else2


class TestDygraphIfElse3(TestDygraphIfElse):
    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = dyfunc_with_if_else3


class TestDygraphIfElse4(TestDygraphIfElse):
    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = dyfunc_empty_nonlocal


class TestDygraphIfElseWithListGenerator(TestDygraphIfElse):
    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = dyfunc_with_if_else_with_list_generator


class TestDygraphNestedIfElse(TestDygraphIfElse):
    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = nested_if_else


class TestDygraphNestedIfElse2(TestDygraphIfElse):
    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = nested_if_else_2


class TestDygraphNestedIfElse3(TestDygraphIfElse):
    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = nested_if_else_3


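# An if-expression used inside a `while_loop` body; the in-body comments below
# describe the expected conversion to `paddle.static.nn.cond`.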
def dyfunc_ifExp_with_while(x):
    y = [x]

    def add_fn(x):
        x = x + 1
        return x

    def cond(i, ten, y):
        return i < ten

    def map_func(func, tensor_list):
        return [func(x) for x in tensor_list]

    def body(i, ten, y):
        # It will be converted into `layers.cond` as follows:
        # map_func(lambda x: paddle.static.nn.cond(i == 0, lambda: x, lambda: add_fn(x)), y)
        y = map_func(lambda x: x if (i == 0) is not None else add_fn(x), y)
        i += 1
        return [i, ten, y]

    i = paddle.tensor.fill_constant(shape=[1], dtype='int64', value=0)
    ten = paddle.tensor.fill_constant(shape=[1], dtype='int64', value=10)
    i, ten, y = paddle.static.nn.while_loop(cond, body, [i, ten, y])
    return y[0]


class TestDygraphIfElse6(TestDygraphIfElse):
    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = dyfunc_ifExp_with_while


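# A similar if-expression over a tensor condition, this time without the
# surrounding `while_loop`.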
def dyfunc_ifExp(x):
    y = [x]

    def add_fn(x):
        x = x + 1
        return x

    def map_func(func, tensor_list):
        return [func(x) for x in tensor_list]

    i = paddle.tensor.fill_constant(shape=[1], dtype='int64', value=0)
    # It will be converted into `layers.cond` as follows:
    # map_func(lambda x: paddle.static.nn.cond(i == 1, lambda: x, lambda: add_fn(x)), y)
    # `if (Tensor) == 1` is supported in dygraph.
    y = map_func(lambda x: x if i == 1 else add_fn(x), y)
    return y[0]


class TestDygraphIfElse7(TestDygraphIfElse):
    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = dyfunc_ifExp


class TestDygraphIfElseWithAndOr(TestDygraphIfElse):
    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = if_with_and_or


class TestDygraphIfElseWithAndOr1(TestDygraphIfElse):
    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = if_with_and_or_1


class TestDygraphIfElseWithAndOr2(TestDygraphIfElse):
    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = if_with_and_or_2


class TestDygraphIfElseWithAndOr3(TestDygraphIfElse):
    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = if_with_and_or_3


class TestDygraphIfElseWithAndOr4(TestDygraphIfElse):
    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = if_with_and_or_4


class TestDygraphIfElseWithClassVar(TestDygraphIfElse):
    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = if_with_class_var


class TestDygraphIfTensor(TestDygraphIfElse):
    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = if_tensor_case


class TestDygraphIfElseNet(unittest.TestCase):
    """
    TestCase for transforming dygraph control flow `if/else` that depends on
    a tensor into the static `base.layers.cond` construct.
    """

    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.Net = NetWithControlFlowIf

    def _run_static(self):
        return self._run(to_static=True)

    def _run_dygraph(self):
        return self._run(to_static=False)

    def _run(self, to_static=False):
        paddle.jit.enable_to_static(to_static)

        with base.dygraph.guard(place):
            net = self.Net()
            x_v = base.dygraph.to_variable(self.x)
            ret = net(x_v)
            return ret.numpy()

    def test_ast_to_func(self):
        self.assertTrue((self._run_dygraph() == self._run_static()).all())


# Test calling a function that is defined before its caller.
def relu(x):
    return F.relu(x)


def call_external_func(x, label=None):
    if paddle.mean(x) < 0:
        x_v = x - 1
    else:
        x_v = add_fn(x)

    x_v = relu(x_v)
    if label is not None:
        loss = loss_fn(x_v, label)
        return loss
    return x_v


class TestAst2FuncWithExternalFunc(TestDygraphIfElse):
    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = call_external_func


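# Same control flow as call_external_func, but wrapped in a Layer and calling
# `softmax`, which is defined after this class.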
class NetWithExternalFunc(paddle.nn.Layer):
    @paddle.jit.to_static
    def forward(self, x, label=None):
        if paddle.mean(x) < 0:
            x_v = x - 1
        else:
            x_v = add_fn(x)

        x_v = softmax(x_v)
        if label is not None:
            loss = loss_fn(x_v, label)
            return loss
        return x_v


# Test calling a function that is defined after its caller.
def softmax(x):
    return paddle.nn.functional.softmax(x)


class TestNetWithExternalFunc(TestDygraphIfElseNet):
    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.Net = NetWithExternalFunc


class DiffModeNet1(paddle.nn.Layer):
    def __init__(self, mode):
        super().__init__()
        self.mode = mode

    @paddle.jit.to_static
    def forward(self, x, y):
        if self.mode == 'train':
            out = x + y
        elif self.mode == 'infer':
            out = x - y
        else:
            raise ValueError('Illegal mode')
        return out


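# Same as DiffModeNet1, except that each branch returns directly instead of
# falling through to a single `return out`.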
class DiffModeNet2(paddle.nn.Layer):
    def __init__(self, mode):
        super().__init__()
        self.mode = mode

    @paddle.jit.to_static
    def forward(self, x, y):
        if self.mode == 'train':
            out = x + y
            return out
        elif self.mode == 'infer':
            out = x - y
            return out
        else:
            raise ValueError('Illegal mode')


class TestDiffModeNet(unittest.TestCase):
    """
    TestCase for the net with different modes
    """

    def setUp(self):
        self.x = paddle.randn([10, 16], 'float32')
        self.y = paddle.randn([10, 16], 'float32')
        self.init_net()

    def init_net(self):
        self.Net = DiffModeNet1

    def _run(self, mode, to_static):
        paddle.jit.enable_to_static(to_static)

        net = self.Net(mode)
        ret = net(self.x, self.y)
        return ret.numpy()

    def test_train_mode(self):
        self.assertTrue(
            (
                self._run(mode='train', to_static=True)
                == self._run(mode='train', to_static=False)
            ).all()
        )

    def test_infer_mode(self):
        self.assertTrue(
            (
                self._run(mode='infer', to_static=True)
                == self._run(mode='infer', to_static=False)
            ).all()
        )


class TestDiffModeNet2(TestDiffModeNet):
    def init_net(self):
        self.Net = DiffModeNet2


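# A variable created in only one if-branch must still be usable in a later
# loop after the Dy2St conversion.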
class TestNewVarCreateInOneBranch(unittest.TestCase):
    def test_var_used_in_another_for(self):
        def case_func(training):
            # targets and targets_list are defined dynamically, depending on training
            if training:
                targets = [1, 2, 3]
                targets_list = [targets]

            num_step = 3
            for i in range(num_step):
                if i > 0:
                    rois, rosi_num = 1, 2
                    # targets is in loop_vars.
                    if training:
                        ros, rosi_num, targets = -1, -2, [-1, -2, -3]
                        targets_list.append(targets)

            return rosi_num

        self.assertEqual(paddle.jit.to_static(case_func)(False), 2)
        self.assertEqual(paddle.jit.to_static(case_func)(True), -2)


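# dyfunc_ifelse_ret_int1 returns a (Tensor, int) pair; both elements should
# keep their types after Dy2St.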
@dy2static_unittest
class TestDy2StIfElseRetInt1(unittest.TestCase):
    def setUp(self):
        self.x = np.random.random([5]).astype('float32')
        self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int1)
        self.out = self.get_dy2stat_out()

    def get_dy2stat_out(self):
        paddle.jit.enable_to_static(True)
        static_func = paddle.jit.to_static(self.dyfunc)
        out = static_func(self.x)
        paddle.jit.enable_to_static(False)
        return out

    @ast_only_test
    def test_ast_to_func(self):
        self.setUp()
        self.assertIsInstance(self.out[0], (paddle.Tensor, core.eager.Tensor))
        self.assertIsInstance(self.out[1], int)


class TestDy2StIfElseRetInt2(TestDy2staticException):
    def setUp(self):
        self.x = np.random.random([5]).astype('float32')
        self.error = "Your if/else have different number of return value."
        self.dyfunc = dyfunc_ifelse_ret_int2


@dy2static_unittest
class TestDy2StIfElseRetInt3(TestDy2StIfElseRetInt1):
    def setUp(self):
        self.x = np.random.random([5]).astype('float32')
        self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int3)
        self.out = self.get_dy2stat_out()

    @ast_only_test
    def test_ast_to_func(self):
        self.setUp()
        self.assertIsInstance(self.out, (paddle.Tensor, core.eager.Tensor))


@dy2static_unittest
class TestDy2StIfElseRetInt4(TestDy2StIfElseRetInt1):
    def setUp(self):
        self.x = np.random.random([5]).astype('float32')
        self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int4)

    @ast_only_test
    def test_ast_to_func(self):
        paddle.jit.enable_to_static(True)
        with self.assertRaises(Dygraph2StaticException):
            static_func = paddle.jit.to_static(self.dyfunc)
            out = static_func(self.x)
        # Why do we need to set `_in_to_static_mode_` here?
        # In Dy2St we use `with _to_static_mode_guard_()` to indicate that
        # the code block is under @to_static, but in this UT an exception is
        # thrown during Dy2St, which leaves `_in_to_static_mode_` with a wrong
        # value, so we need to reset `_in_to_static_mode_` to False manually.
        paddle.base.dygraph.base.global_var._in_to_static_mode_ = False
        paddle.jit.enable_to_static(False)


class IfElseNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.param = self.create_parameter(
            shape=[3, 2], dtype='float32', is_bias=False
        )

    @paddle.jit.to_static
    def forward(self, a, b, c):
        a = paddle.matmul(a, self.param)
        a = paddle.reshape(a, (2, 4))
        cond = paddle.to_tensor([10])
        if cond == 10:
            a_argmax = a.argmax(axis=-1)
            b = b + self.param
        else:
            print(c)
        return b


class TestDy2StIfElseBackward(unittest.TestCase):
    def test_run_backward(self):
        a = paddle.randn((4, 3), dtype='float32')
        a.stop_gradient = False
        b = paddle.to_tensor([10]).astype('float32')
        b.stop_gradient = False
        c = paddle.to_tensor([2])
        c.stop_gradient = False

        net = IfElseNet()
        net.train()
        out = net(a, b, c)
        out.backward()
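        # cond == 10 is always true, so the `if` branch runs and the output
        # must equal b + net.param.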
        np.testing.assert_allclose(
            (b + net.param).numpy(), out.numpy(), rtol=1e-05
        )


if __name__ == '__main__':
    unittest.main()