# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import unittest
import numpy as np

import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid import Linear
from paddle.fluid.layer_helper import LayerHelper
from test_imperative_base import new_program_scope
import paddle.fluid.dygraph_utils as dygraph_utils
from paddle.fluid.dygraph.layer_object_helper import LayerObjectHelper
import paddle
from paddle.fluid.framework import _test_eager_guard, _in_eager_mode


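# A minimal dygraph Layer: relu -> elementwise_mul -> reduce_sum, keeping the
# intermediate relu output in self._x_for_debug so tests can inspect its gradient.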
class MyLayer(fluid.Layer):
    def __init__(self):
        super(MyLayer, self).__init__()

    def forward(self, inputs):
        x = fluid.layers.relu(inputs)
        self._x_for_debug = x
        x = fluid.layers.elementwise_mul(x, x)
        x = fluid.layers.reduce_sum(x)
        return [x]


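# A two-layer perceptron (input_size -> 3 -> 4) with constant 0.1 initializers;
# forward() reduces the output to a scalar so backward() can be called on it directly.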
class MLP(fluid.Layer):
    def __init__(self, input_size):
        super(MLP, self).__init__()
        self._linear1 = Linear(
            input_size,
            3,
            param_attr=fluid.ParamAttr(
                initializer=fluid.initializer.Constant(value=0.1)),
            bias_attr=fluid.ParamAttr(
                initializer=fluid.initializer.Constant(value=0.1)))
        self._linear2 = Linear(
            3,
            4,
            param_attr=fluid.ParamAttr(
                initializer=fluid.initializer.Constant(value=0.1)),
            bias_attr=fluid.ParamAttr(
                initializer=fluid.initializer.Constant(value=0.1)))

    def forward(self, inputs):
        x = self._linear1(inputs)
        x = self._linear2(x)
        x = fluid.layers.reduce_sum(x)
        return x


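# An RNN cell assembled from raw self._helper.append_op calls (mul, elementwise_add,
# tanh, mul, softmax, reduce_sum) rather than from high-level layers.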
class SimpleRNNCell(fluid.Layer):
    def __init__(self, step_input_size, hidden_size, output_size, param_attr):
        super(SimpleRNNCell, self).__init__()
        self.step_input_size = step_input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self._dtype = core.VarDesc.VarType.FP32
        self.param_attr = param_attr

        i2h_param_shape = [self.step_input_size, self.hidden_size]
        h2h_param_shape = [self.hidden_size, self.hidden_size]
        h2o_param_shape = [self.output_size, self.hidden_size]
        self._i2h_w = None
        self._i2h_w = self.create_parameter(
            attr=self.param_attr,
            shape=i2h_param_shape,
            dtype=self._dtype,
            is_bias=False)
        self._h2h_w = self.create_parameter(
            attr=self.param_attr,
            shape=h2h_param_shape,
            dtype=self._dtype,
            is_bias=False)
        self._h2o_w = self.create_parameter(
            attr=self.param_attr,
            shape=h2o_param_shape,
            dtype=self._dtype,
            is_bias=False)

    def forward(self, input, pre_hidden):
        tmp_i2h = self.create_variable(dtype=self._dtype)
        tmp_h2h = self.create_variable(dtype=self._dtype)
        hidden = self.create_variable(dtype=self._dtype)
        out = self.create_variable(dtype=self._dtype)
        softmax_out = self.create_variable(dtype=self._dtype)
        reduce_out = self.create_variable(dtype=self._dtype)
        self._helper.append_op(
            type="mul",
            inputs={"X": input,
                    "Y": self._i2h_w},
            outputs={"Out": tmp_i2h},
            attrs={"x_num_col_dims": 1,
                   "y_num_col_dims": 1})

        self._helper.append_op(
            type="mul",
            inputs={"X": pre_hidden,
                    "Y": self._h2h_w},
            outputs={"Out": tmp_h2h},
            attrs={"x_num_col_dims": 1,
                   "y_num_col_dims": 1})

        self._helper.append_op(
            type="elementwise_add",
            inputs={'X': tmp_h2h,
                    'Y': tmp_i2h},
            outputs={'Out': hidden},
            attrs={'axis': -1,
                   'use_mkldnn': False})
        hidden = self._helper.append_activation(hidden, act='tanh')

        self._helper.append_op(
            type="mul",
            inputs={"X": hidden,
                    "Y": self._h2o_w},
            outputs={"Out": out},
            attrs={"x_num_col_dims": 1,
                   "y_num_col_dims": 1})

        self._helper.append_op(
            type="softmax",
            inputs={"X": out},
            outputs={"Out": softmax_out},
            attrs={"use_cudnn": False})

        self._helper.append_op(
            type='reduce_sum',
            inputs={'X': softmax_out},
            outputs={'Out': reduce_out},
            attrs={'keep_dim': False,
                   'reduce_all': True})

        return reduce_out, hidden


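# Unrolls SimpleRNNCell over a fixed sequence length of 4, slicing one step of the
# input at a time.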
class SimpleRNN(fluid.Layer):
    def __init__(self):
        super(SimpleRNN, self).__init__()
        self.seq_len = 4
        self._cell = SimpleRNNCell(
            3,
            3,
            3,
            fluid.ParamAttr(initializer=fluid.initializer.Constant(value=0.1)))

    def forward(self, inputs):
        outs = list()
        pre_hiddens = list()

        init_hidden = self.create_parameter(
            attr=fluid.ParamAttr(
                initializer=fluid.initializer.Constant(value=0.1)),
            shape=[1, 3],
            dtype='float32',
            is_bias=False)
        pre_hidden = init_hidden
        for i in range(self.seq_len):
            input = fluid.layers.slice(
                inputs, axes=[1], starts=[i], ends=[i + 1])
            input = fluid.layers.reshape(input, shape=[1, 3])
            out_softmax, pre_hidden = self._cell(input, pre_hidden)
            outs.append(out_softmax)

        return outs, pre_hiddens


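# Basic imperative (dygraph) checks: layer execution, gradient accumulation, and
# parity of outputs/gradients with the equivalent static-graph programs.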
class TestImperative(unittest.TestCase):
    def functional_dygraph_context(self):
        self.assertFalse(fluid.dygraph.enabled())
        fluid.enable_dygraph()
        self.assertTrue(fluid.dygraph.enabled())
        np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
        var_inp = paddle.to_tensor(np_inp)
        mlp = MLP(input_size=2)
        out = mlp(var_inp)
        dy_out1 = out.numpy()
        out.backward()
        dy_grad1 = mlp._linear1.weight.gradient()
        fluid.disable_dygraph()
        self.assertFalse(fluid.dygraph.enabled())
        with fluid.dygraph.guard():
            self.assertTrue(fluid.dygraph.enabled())
            var_inp = paddle.to_tensor(np_inp)
            mlp = MLP(input_size=2)
            out = mlp(var_inp)
            dy_out2 = out.numpy()
            out.backward()
            dy_grad2 = mlp._linear1.weight.gradient()
        self.assertFalse(fluid.dygraph.enabled())
        self.assertTrue(np.array_equal(dy_out1, dy_out2))
        self.assertTrue(np.array_equal(dy_grad1, dy_grad2))

    def test_functional_dygraph_context(self):
        with _test_eager_guard():
            self.functional_dygraph_context()
        self.functional_dygraph_context()

    def functional_paddle_imperative_dygraph_context(self):
        self.assertFalse(paddle.in_dynamic_mode())
        paddle.disable_static()
        self.assertTrue(paddle.in_dynamic_mode())
        np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
        var_inp = paddle.to_tensor(np_inp)
        mlp = MLP(input_size=2)
        out = mlp(var_inp)
        dy_out1 = out.numpy()
        out.backward()
        dy_grad1 = mlp._linear1.weight.gradient()
        paddle.enable_static()
        self.assertFalse(paddle.in_dynamic_mode())
        paddle.disable_static()
        self.assertTrue(paddle.in_dynamic_mode())
        var_inp = paddle.to_tensor(np_inp)
        mlp = MLP(input_size=2)
        out = mlp(var_inp)
        dy_out2 = out.numpy()
        out.backward()
        dy_grad2 = mlp._linear1.weight.gradient()
        paddle.enable_static()
        self.assertFalse(paddle.in_dynamic_mode())
        self.assertTrue(np.array_equal(dy_out1, dy_out2))
        self.assertTrue(np.array_equal(dy_grad1, dy_grad2))

    def test_functional_paddle_imperative_dygraph_context(self):
        with _test_eager_guard():
            self.functional_paddle_imperative_dygraph_context()
        self.functional_paddle_imperative_dygraph_context()

    def func_isinstance(self):
        var = fluid.layers.data(shape=[1], name='x', dtype='float32')
        self.assertTrue(isinstance(var, fluid.Variable))
        with fluid.dygraph.guard():
            if fluid.framework._in_eager_mode():
                var_base = paddle.to_tensor(np.array([3, 4, 5]))
                self.assertTrue(isinstance(var_base, core.eager.EagerTensor))
            else:
                var_base = paddle.to_tensor(np.array([3, 4, 5]))
                self.assertTrue(isinstance(var_base, core.VarBase))
                self.assertTrue(isinstance(var_base, fluid.Variable))

    def test_isinstance(self):
        with _test_eager_guard():
            self.func_isinstance()
        self.func_isinstance()

    def func_create_varbase(self):
        x = np.ones([2, 2], np.float32)
        y = np.zeros([3, 3], np.float32)
        t = fluid.Tensor()
        t.set(x, fluid.CPUPlace())
        if _in_eager_mode():
            # TODO(jiabin): Support Kwargs and uncomment these tests
            # egr_tmp = fluid.core.eager.EagerTensor(value=x, place=fluid.core.CPUPlace())
            egr_tmp2 = fluid.core.eager.EagerTensor(y, fluid.core.CPUPlace())
            egr_tmp3 = paddle.to_tensor(x)
            egr_tmp4 = fluid.core.eager.EagerTensor(y)
            # egr_tmp5 = fluid.core.eager.EagerTensor(value=x)
            # TODO(jiabin): Support it when we merge LoDTensor with DenseTensor
            egr_tmp6 = fluid.core.eager.EagerTensor(t)

            # self.assertTrue(np.array_equal(x, egr_tmp.numpy()))
            self.assertTrue(np.array_equal(y, egr_tmp2.numpy()))
            self.assertTrue(np.array_equal(x, egr_tmp3.numpy()))
            self.assertTrue(np.array_equal(y, egr_tmp4.numpy()))
            # self.assertTrue(np.array_equal(x, egr_tmp5.numpy()))
            self.assertTrue(np.array_equal(x, egr_tmp6.numpy()))
        else:
            tmp = fluid.core.VarBase(value=x, place=fluid.core.CPUPlace())
            tmp2 = fluid.core.VarBase(y, fluid.core.CPUPlace())
            tmp3 = paddle.to_tensor(x)
            tmp4 = fluid.core.VarBase(y)
            tmp5 = fluid.core.VarBase(value=x)
            tmp6 = fluid.core.VarBase(t)

            self.assertTrue(np.array_equal(x, tmp.numpy()))
            self.assertTrue(np.array_equal(y, tmp2.numpy()))
            self.assertTrue(np.array_equal(x, tmp3.numpy()))
            self.assertTrue(np.array_equal(y, tmp4.numpy()))
            self.assertTrue(np.array_equal(x, tmp5.numpy()))
            self.assertTrue(np.array_equal(x, tmp6.numpy()))

    def test_create_varbase(self):
        with fluid.dygraph.guard():
            with _test_eager_guard():
                self.func_create_varbase()
            self.func_create_varbase()

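    # Results produced under no_grad have stop_gradient set and receive no gradient,
    # while parameters used outside the guard still get gradients from backward().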
    def test_no_grad_guard(self):
        data = np.array([[2, 3], [4, 5]]).astype('float32')
        with fluid.dygraph.guard():
            l0 = fluid.Linear(2, 2)
            self.assertTrue(l0.weight._grad_ivar() is None)
            l1 = fluid.Linear(2, 2)
            with fluid.dygraph.no_grad():
                self.assertTrue(l1.weight.stop_gradient is False)
                tmp = l1.weight * 2
                self.assertTrue(tmp.stop_gradient)
            x = paddle.to_tensor(data)
            y = l0(x) + tmp
            o = l1(y)
            o.backward()

            self.assertTrue(tmp._grad_ivar() is None)
            self.assertTrue(l0.weight._grad_ivar() is not None)

    def test_paddle_imperative_no_grad_guard(self):
        data = np.array([[2, 3], [4, 5]]).astype('float32')
        with fluid.dygraph.guard():
            l0 = fluid.Linear(2, 2)
            self.assertTrue(l0.weight._grad_ivar() is None)
            l1 = fluid.Linear(2, 2)
            with paddle.no_grad():
                self.assertTrue(l1.weight.stop_gradient is False)
                tmp = l1.weight * 2
                self.assertTrue(tmp.stop_gradient)
            x = paddle.to_tensor(data)
            y = l0(x) + tmp
            o = l1(y)
            o.backward()

            self.assertTrue(tmp._grad_ivar() is None)
            self.assertTrue(l0.weight._grad_ivar() is not None)

    def test_paddle_imperative_set_grad_enabled(self):
        data = np.array([[2, 3], [4, 5]]).astype('float32')
        with fluid.dygraph.guard():
            l0 = fluid.Linear(2, 2)
            self.assertTrue(l0.weight._grad_ivar() is None)
            l1 = fluid.Linear(2, 2)
            with paddle.set_grad_enabled(False):
                self.assertTrue(l1.weight.stop_gradient is False)
                tmp = l1.weight * 2
                with paddle.set_grad_enabled(True):
                    tmp2 = l1.weight * 2
                self.assertTrue(tmp.stop_gradient)
                self.assertTrue(tmp2.stop_gradient is False)
            x = paddle.to_tensor(data)
            y = l0(x) + tmp2
            o = l1(y)
            o.backward()

            self.assertTrue(tmp._grad_ivar() is None)
            self.assertTrue(tmp2._grad_ivar() is not None)
            self.assertTrue(l0.weight._grad_ivar() is not None)

    def test_paddle_imperative_is_grad_enabled(self):
        with fluid.dygraph.guard():
            with paddle.set_grad_enabled(False):
                self.assertTrue(paddle.is_grad_enabled() is False)
                with paddle.set_grad_enabled(True):
                    self.assertTrue(paddle.is_grad_enabled())

    def test_sum_op(self):
        x = np.ones([2, 2], np.float32)
        with fluid.dygraph.guard():
            inputs = []
            for _ in range(10):
                tmp = paddle.to_tensor(x)
                tmp.stop_gradient = False
                inputs.append(tmp)
            ret = fluid.layers.sums(inputs)
            loss = fluid.layers.reduce_sum(ret)
            loss.backward()
        with fluid.dygraph.guard():
            inputs2 = []
            for _ in range(10):
                tmp = paddle.to_tensor(x)
                tmp.stop_gradient = False
                inputs2.append(tmp)
            ret2 = fluid.layers.sums(inputs2)
            loss2 = fluid.layers.reduce_sum(ret2)
            fluid.set_flags({'FLAGS_sort_sum_gradient': True})
            loss2.backward()

            self.assertTrue(np.allclose(ret.numpy(), x * 10))
            self.assertTrue(np.allclose(inputs[0].gradient(), x))
            self.assertTrue(np.allclose(ret2.numpy(), x * 10))
            a = inputs2[0].gradient()
            self.assertTrue(np.allclose(inputs2[0].gradient(), x))

    def test_empty_var(self):
        with fluid.dygraph.guard():
            cur_program = fluid.Program()
            cur_block = cur_program.current_block()
            new_variable = cur_block.create_var(
                name="X", shape=[-1, 23, 48], dtype='float32')
            try:
                new_variable.numpy()
            except Exception as e:
                assert type(e) == ValueError

            try:
                new_variable.backward()
            except Exception as e:
                assert type(e) == core.EnforceNotMet

            try:
                new_variable.clear_gradient()
            except Exception as e:
                assert type(e) == core.EnforceNotMet

    def test_empty_grad(self):
        with fluid.dygraph.guard():
            x = np.ones([2, 2], np.float32)
            new_var = paddle.to_tensor(x)
            try:
                new_var.gradient()
            except Exception as e:
                assert type(e) == ValueError

            try:
                new_var.clear_gradient()
            except Exception as e:
                assert type(e) == core.EnforceNotMet

        with fluid.dygraph.guard():
            cur_program = fluid.Program()
            cur_block = cur_program.current_block()
            new_variable = cur_block.create_var(
                name="X", shape=[-1, 23, 48], dtype='float32')
            try:
                new_variable.gradient()
            except Exception as e:
                assert type(e) == ValueError

    def test_set_persistable(self):
        with fluid.dygraph.guard():
            x = np.ones([2, 2], np.float32)
            new_var = paddle.to_tensor(x)
            self.assertFalse(new_var.persistable)
            new_var.persistable = True
            self.assertTrue(new_var.persistable)

    def test_layer(self):
        with fluid.dygraph.guard():
            l = fluid.Layer("l")
            self.assertRaises(NotImplementedError, l.forward, [])

    def test_layer_in_out(self):
        np_inp = np.array([1.0, 2.0, -1.0], dtype=np.float32)
        with fluid.dygraph.guard():
            var_inp = paddle.to_tensor(np_inp)
            var_inp.stop_gradient = False
            l = MyLayer()
            x = l(var_inp)[0]
            self.assertIsNotNone(x)
            dy_out = x.numpy()
            x.backward()
            dy_grad = l._x_for_debug.gradient()

        with fluid.dygraph.guard():
            var_inp2 = paddle.to_tensor(np_inp)
            var_inp2.stop_gradient = False
            l2 = MyLayer()
            x2 = l2(var_inp2)[0]
            self.assertIsNotNone(x2)
            dy_out2 = x2.numpy()
            fluid.set_flags({'FLAGS_sort_sum_gradient': True})
            x2.backward()
            dy_grad2 = l2._x_for_debug.gradient()

        with new_program_scope():
            inp = fluid.layers.data(
                name="inp", shape=[3], append_batch_size=False)
            l = MyLayer()
            x = l(inp)[0]
            param_grads = fluid.backward.append_backward(
                x, parameter_list=[l._x_for_debug.name])[0]
            exe = fluid.Executor(fluid.CPUPlace(
            ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))

            static_out, static_grad = exe.run(
                feed={inp.name: np_inp},
                fetch_list=[x.name, param_grads[1].name])

        self.assertTrue(np.allclose(dy_out, static_out))
        self.assertTrue(np.allclose(dy_grad, static_grad))
        self.assertTrue(np.allclose(dy_out2, static_out))
        self.assertTrue(np.allclose(dy_grad2, static_grad))

    def test_mlp(self):
        np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
        with fluid.dygraph.guard():
            var_inp = paddle.to_tensor(np_inp)
            mlp = MLP(input_size=2)
            out = mlp(var_inp)
            dy_out = out.numpy()
            out.backward()
            dy_grad = mlp._linear1.weight.gradient()

        with fluid.dygraph.guard():
            var_inp2 = paddle.to_tensor(np_inp)
            mlp2 = MLP(input_size=2)
            out2 = mlp2(var_inp2)
            dy_out2 = out2.numpy()
            fluid.set_flags({'FLAGS_sort_sum_gradient': True})
            out2.backward()
            dy_grad2 = mlp2._linear1.weight.gradient()

        with new_program_scope():
            inp = fluid.layers.data(
                name="inp", shape=[2, 2], append_batch_size=False)
            mlp = MLP(input_size=2)
            out = mlp(inp)
            param_grads = fluid.backward.append_backward(
                out, parameter_list=[mlp._linear1.weight.name])[0]
            exe = fluid.Executor(fluid.CPUPlace(
            ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
            exe.run(fluid.default_startup_program())

            static_out, static_grad = exe.run(
                feed={inp.name: np_inp},
                fetch_list=[out.name, param_grads[1].name])

        self.assertTrue(np.allclose(dy_out, static_out))
        self.assertTrue(np.allclose(dy_grad, static_grad))
        self.assertTrue(np.allclose(dy_out2, static_out))
        self.assertTrue(np.allclose(dy_grad2, static_grad))

        params = mlp.parameters(True)
        self.assertEqual("linear_0.w_0", params[0].name)
        self.assertEqual("linear_0.b_0", params[1].name)
        self.assertEqual("linear_1.w_0", params[2].name)
        self.assertEqual("linear_1.b_0", params[3].name)
        self.assertEqual(len(params), 4)

        sublayers = mlp.sublayers()
        self.assertEqual(mlp._linear1, sublayers[0])
        self.assertEqual(mlp._linear2, sublayers[1])
        self.assertEqual(len(sublayers), 2)

    def test_gradient_accumulation(self):
        def test_single_api(sort_sum_gradient):
            fluid.set_flags({'FLAGS_sort_sum_gradient': sort_sum_gradient})
            x = paddle.to_tensor(5., stop_gradient=False)
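            # d(x**4)/dx = 4 * x**3 = 500 at x = 5; without clearing, the gradient
            # accumulates by 500 on every backward() call.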
            for i in range(10):
                y = paddle.pow(x, 4.0)
                y.backward()
                self.assertEqual(x.grad.numpy(), (i + 1) * 500)
            x.clear_gradient()
            self.assertEqual(x.grad.numpy(), 0.)
            for i in range(10):
                y = paddle.pow(x, 4.0)
                y.backward()
                self.assertEqual(x.grad.numpy(), (i + 1) * 500)
            x.clear_grad()
            self.assertEqual(x.grad.numpy(), 0.)

        def test_simple_net(sort_sum_gradient):
            fluid.set_flags({'FLAGS_sort_sum_gradient': sort_sum_gradient})
            x = paddle.to_tensor(5., stop_gradient=False)
            y = paddle.to_tensor(2., stop_gradient=False)
            z = paddle.to_tensor(3., stop_gradient=False)

            def fun(x, y, z):
                loss1 = x * x * y
                loss2 = x * z
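                # With x=5, y=2, z=3: d(loss1+loss2)/dx = 2*x*y + z = 23,
                # d/dy = x*x = 25, d/dz = x = 5.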
                loss1.backward(retain_graph=True)
                loss2.backward(retain_graph=True)
                self.assertTrue(np.array_equal(x.grad.numpy(), [23.]))
                self.assertTrue(np.array_equal(y.grad.numpy(), [25.]))
                self.assertTrue(np.array_equal(z.grad.numpy(), [5.]))
                x.clear_grad()
                y.clear_grad()
                z.clear_grad()

                dx = paddle.grad([loss1], x, create_graph=True)[0]
                loss = loss1 + loss2 + dx
                # loss = x*x*y + x*z + 2*x*y
                return loss

            loss = fun(x, y, z)
            loss.backward(retain_graph=True)
            # x.grad = 2*x*y + z + 2*y = 27 
            self.assertTrue(np.array_equal(x.grad.numpy(), [27]))

            loss.backward(retain_graph=True)
            self.assertTrue(np.array_equal(x.grad.numpy(), [54]))

            loss.backward()
            self.assertTrue(np.array_equal(x.grad.numpy(), [81]))

            with self.assertRaises(RuntimeError):
                loss.backward()

            loss1 = x * x * y
            loss2 = x * z
            dx = paddle.grad([loss1], x, create_graph=True)[0]
            loss = loss1 + loss2 + dx
            loss.backward()
            self.assertTrue(np.array_equal(dx.grad.numpy(), [1]))
            self.assertTrue(np.array_equal(x.grad.numpy(), [108]))

        def test_mlp(sort_sum_gradient):
            fluid.set_flags({'FLAGS_sort_sum_gradient': sort_sum_gradient})
            input_size = 5
            paddle.seed(1)
            mlp1 = MLP(input_size=input_size)
            # generate the gradient of each step
            mlp2 = MLP(input_size=input_size)
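            # mlp2 runs on a detached copy of each batch; its per-step gradients are
            # summed into expected_* and compared against mlp1's accumulated gradients.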

            expected_weight1_grad = 0.
            expected_bias1_grad = 0.
            expected_weight2_grad = 0.
            expected_bias2_grad = 0.

            for batch_id in range(100):
                x = paddle.uniform([10, input_size])
                detach_x = x.detach()
                clear_loss = mlp2(detach_x)
                clear_loss.backward()
                expected_weight1_grad = (
                    expected_weight1_grad + mlp2._linear1.weight.grad.numpy())
                expected_bias1_grad = (
                    expected_bias1_grad + mlp2._linear1.bias.grad.numpy())
                expected_weight2_grad = (
                    expected_weight2_grad + mlp2._linear2.weight.grad.numpy())
                expected_bias2_grad = (
                    expected_bias2_grad + mlp2._linear2.bias.grad.numpy())

                loss = mlp1(x)
                loss.backward()

                self.assertTrue(np.array_equal(loss.grad.numpy(), [1]))
                self.assertTrue(
                    np.allclose(mlp1._linear1.weight.grad.numpy(),
                                expected_weight1_grad))
                self.assertTrue(
                    np.allclose(mlp1._linear1.bias.grad.numpy(),
                                expected_bias1_grad))
                self.assertTrue(
                    np.allclose(mlp1._linear2.weight.grad.numpy(),
                                expected_weight2_grad))
                self.assertTrue(
                    np.allclose(mlp1._linear2.bias.grad.numpy(),
                                expected_bias2_grad))

                mlp2.clear_gradients()
                self.assertTrue(np.array_equal(clear_loss.grad.numpy(), [1]))
                if ((batch_id + 1) % 10) % 2 == 0:
                    mlp1.clear_gradients()
                    expected_weight1_grad = 0.
                    expected_bias1_grad = 0.
                    expected_weight2_grad = 0.
                    expected_bias2_grad = 0.
                elif ((batch_id + 1) % 10) % 2 == 1:
                    mlp1.clear_gradients()
                    mlp1._linear1.weight._set_grad_ivar(
                        paddle.ones([input_size, 3]))
                    mlp1._linear2.weight._set_grad_ivar(paddle.ones([3, 4]))
                    expected_weight1_grad = 1.
                    expected_bias1_grad = 0.
                    expected_weight2_grad = 1.
                    expected_bias2_grad = 0.

        with fluid.dygraph.guard():
            test_single_api(False)
            test_single_api(True)
            test_simple_net(False)
            test_simple_net(True)
            test_mlp(False)
            test_mlp(True)

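    # The same data-dependent branch (add vs. subtract) is executed eagerly and then
    # rebuilt with layers.IfElse in a static program; the results must match.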
    def test_dygraph_vs_static(self):
        np_inp1 = np.random.rand(4, 3, 3)
        np_inp2 = np.random.rand(4, 3, 3)

        # dynamic graph
        with fluid.dygraph.guard():
            inp1 = paddle.to_tensor(np_inp1)
            inp2 = paddle.to_tensor(np_inp2)
            if np.sum(np_inp1) < np.sum(np_inp2):
                x = fluid.layers.elementwise_add(inp1, inp2)
            else:
                x = fluid.layers.elementwise_sub(inp1, inp2)
            dygraph_result = x.numpy()

        # static graph
        with new_program_scope():
            inp_data1 = fluid.layers.data(
                name='inp1', shape=[3, 3], dtype=np.float32)
            inp_data2 = fluid.layers.data(
                name='inp2', shape=[3, 3], dtype=np.float32)

            a = fluid.layers.expand(
                fluid.layers.reshape(
                    fluid.layers.reduce_sum(inp_data1), [1, 1]), [4, 1])
            b = fluid.layers.expand(
                fluid.layers.reshape(
                    fluid.layers.reduce_sum(inp_data2), [1, 1]), [4, 1])
            cond = fluid.layers.less_than(x=a, y=b)

            ie = fluid.layers.IfElse(cond)
            with ie.true_block():
                d1 = ie.input(inp_data1)
                d2 = ie.input(inp_data2)
                d3 = fluid.layers.elementwise_add(d1, d2)
                ie.output(d3)

            with ie.false_block():
                d1 = ie.input(inp_data1)
                d2 = ie.input(inp_data2)
                d3 = fluid.layers.elementwise_sub(d1, d2)
                ie.output(d3)
            out = ie()

            exe = fluid.Executor(fluid.CPUPlace(
            ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
            static_result = exe.run(fluid.default_main_program(),
                                    feed={'inp1': np_inp1,
                                          'inp2': np_inp2},
                                    fetch_list=out)[0]
        self.assertTrue(np.allclose(dygraph_result, static_result))

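    # Runs SimpleRNN eagerly (with and without sorted gradient sums) and as a static
    # program, comparing the final output and the cell's weight gradients.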
    def test_rnn(self):
        np_inp = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0],
                           [10.0, 11.0, 12.0]])
        np_inp = np_inp.reshape((1, 4, 3))
        np_inp = np_inp.astype(np.float32)
        with fluid.dygraph.guard():
            var_inp = paddle.to_tensor(np_inp)
            var_inp = fluid.layers.reshape(var_inp, shape=[1, 4, 3])
            simple_rnn = SimpleRNN()
            outs, pre_hiddens = simple_rnn.forward(var_inp)
            dy_out = outs[3].numpy()
            outs[3].backward()
            dy_grad_h2o = simple_rnn._cell._h2o_w.gradient()
            dy_grad_h2h = simple_rnn._cell._h2h_w.gradient()
            dy_grad_i2h = simple_rnn._cell._i2h_w.gradient()

        with fluid.dygraph.guard():
            var_inp2 = paddle.to_tensor(np_inp)
            var_inp2 = fluid.layers.reshape(var_inp2, shape=[1, 4, 3])
            simple_rnn2 = SimpleRNN()
            outs2, pre_hiddens2 = simple_rnn2.forward(var_inp2)
            dy_out2 = outs2[3].numpy()
            fluid.set_flags({'FLAGS_sort_sum_gradient': True})
            outs2[3].backward()
            dy_grad_h2o2 = simple_rnn2._cell._h2o_w.gradient()
            dy_grad_h2h2 = simple_rnn2._cell._h2h_w.gradient()
            dy_grad_i2h2 = simple_rnn2._cell._i2h_w.gradient()

        with new_program_scope():
            inp = fluid.layers.data(
                name="inp", shape=[1, 4, 3], append_batch_size=False)
            simple_rnn = SimpleRNN()
            outs, pre_hiddens = simple_rnn(inp)
            param_grads = fluid.backward.append_backward(outs[3])
            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())
            static_out, static_grad_h2o, static_grad_h2h, static_grad_i2h = exe.run(
                feed={inp.name: np_inp},
                fetch_list=[
                    outs[3].name, param_grads[0][1].name,
                    param_grads[1][1].name, param_grads[2][1].name
                ])

        self.assertTrue(np.allclose(dy_out, static_out))
        self.assertTrue(np.allclose(dy_grad_h2o, static_grad_h2o))
        self.assertTrue(np.allclose(dy_grad_h2h, static_grad_h2h))
        self.assertTrue(np.allclose(dy_grad_i2h, static_grad_i2h))
        self.assertTrue(np.allclose(dy_out2, static_out))
        self.assertTrue(np.allclose(dy_grad_h2o2, static_grad_h2o))
        self.assertTrue(np.allclose(dy_grad_h2h2, static_grad_h2h))
        self.assertTrue(np.allclose(dy_grad_i2h2, static_grad_i2h))

    def func_layer_attrs(self):
        layer = fluid.dygraph.Layer("test")
        layer.test_attr = 1
        self.assertFalse(hasattr(layer, "whatever"))
        self.assertTrue(hasattr(layer, "test_attr"))
        self.assertEqual(layer.test_attr, 1)

        my_layer = MyLayer()
        my_layer.w1 = my_layer.create_parameter([3, 3])
        my_layer.add_parameter('w2', None)
        self.assertEqual(len(my_layer.parameters()), 1)
        self.assertRaises(TypeError, my_layer.__setattr__, 'w1', 'str')
        my_layer.w1 = None
        self.assertEqual(len(my_layer.parameters()), 0)
        my_layer.l1 = fluid.dygraph.Linear(3, 3)
        self.assertEqual(len(my_layer.sublayers()), 1)
        self.assertRaises(TypeError, my_layer.__setattr__, 'l1', 'str')
        my_layer.l1 = None
        self.assertEqual(len(my_layer.sublayers()), 0)

    def test_layer_attrs(self):
        with _test_eager_guard():
            self.func_layer_attrs()
        self.func_layer_attrs()


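# Checks the dygraph fast paths in dygraph_utils and the append_activation /
# append_bias helpers against their regular layer counterparts.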
class TestDygraphUtils(unittest.TestCase):
    def func_append_activation_in_dygraph_exception(self):
        with new_program_scope():
            np_inp = np.random.random(size=(10, 20, 30)).astype(np.float32)
            a = fluid.layers.data("a", [10, 20])
            func = dygraph_utils._append_activation_in_dygraph
            self.assertRaises(AssertionError, func, a, act="sigmoid")

    def test_append_activation_in_dygraph_exception(self):
        with _test_eager_guard():
            self.func_append_activation_in_dygraph_exception()
        self.func_append_activation_in_dygraph_exception()

    def func_append_activation_in_dygraph1(self):
        a_np = np.random.random(size=(10, 20, 30)).astype(np.float32)
        func = dygraph_utils._append_activation_in_dygraph
        with fluid.dygraph.guard():
            a = paddle.to_tensor(a_np)
            res1 = func(a, act="hard_sigmoid")
            res2 = fluid.layers.hard_sigmoid(a)
            self.assertTrue(np.array_equal(res1.numpy(), res2.numpy()))

    def test_append_activation_in_dygraph1(self):
        with _test_eager_guard():
            self.func_append_activation_in_dygraph1()
        self.func_append_activation_in_dygraph1()

    def func_append_activation_in_dygraph2(self):
        a_np = np.random.random(size=(10, 20, 30)).astype(np.float32)
        func = dygraph_utils._append_activation_in_dygraph
        with fluid.dygraph.guard():
            a = paddle.to_tensor(a_np)
            res1 = func(a, act="sigmoid", use_mkldnn=True, use_cudnn=True)
            res2 = fluid.layers.sigmoid(a)
            self.assertTrue(np.allclose(res1.numpy(), res2.numpy()))

    def test_append_activation_in_dygraph2(self):
        with _test_eager_guard():
            self.func_append_activation_in_dygraph2()
        self.func_append_activation_in_dygraph2()

    def func_append_activation_in_dygraph3(self):
        a_np = np.random.random(size=(10, 20, 30)).astype(np.float32)
        helper = LayerObjectHelper(fluid.unique_name.generate("test"))
        func = helper.append_activation
        with fluid.dygraph.guard():
            a = paddle.to_tensor(a_np)
            res1 = func(a, act="sigmoid", use_cudnn=True)
            res2 = fluid.layers.sigmoid(a)
            self.assertTrue(np.array_equal(res1.numpy(), res2.numpy()))

    def test_append_activation_in_dygraph3(self):
        with _test_eager_guard():
            self.func_append_activation_in_dygraph3()
        self.func_append_activation_in_dygraph3()

    def func_append_activation_in_dygraph_use_mkldnn(self):
        a_np = np.random.uniform(-2, 2, (10, 20, 30)).astype(np.float32)
        helper = LayerHelper(
            fluid.unique_name.generate("test"), act="relu", use_mkldnn=True)
        func = helper.append_activation
        with fluid.dygraph.guard():
            a = paddle.to_tensor(a_np)
            res1 = func(a)
            res2 = fluid.layers.relu(a)
            self.assertTrue(np.array_equal(res1.numpy(), res2.numpy()))

    def test_append_activation_in_dygraph_use_mkldnn(self):
        with _test_eager_guard():
            self.func_append_activation_in_dygraph_use_mkldnn()
        self.func_append_activation_in_dygraph_use_mkldnn()

    def func_append_activation_in_dygraph_global_use_mkldnn(self):
        a_np = np.random.uniform(-2, 2, (10, 20, 30)).astype(np.float32)
        helper = LayerHelper(fluid.unique_name.generate("test"), act="relu")
        func = helper.append_activation
        with fluid.dygraph.guard(fluid.core.CPUPlace()):
            a = paddle.to_tensor(a_np)
            fluid.set_flags({'FLAGS_use_mkldnn': True})
            try:
                res1 = func(a)
            finally:
                fluid.set_flags({'FLAGS_use_mkldnn': False})
            res2 = fluid.layers.relu(a)
        self.assertTrue(np.array_equal(res1.numpy(), res2.numpy()))

    def test_append_activation_in_dygraph_global_use_mkldnn(self):
        with _test_eager_guard():
            self.func_append_activation_in_dygraph_global_use_mkldnn()
        self.func_append_activation_in_dygraph_global_use_mkldnn()

    def func_append_bias_in_dygraph_exception(self):
        with new_program_scope():
            np_inp = np.random.random(size=(10, 20, 30)).astype(np.float32)
            a = fluid.layers.data("a", [10, 20])
            func = dygraph_utils._append_bias_in_dygraph
            self.assertRaises(AssertionError, func, a)

    def test_append_bias_in_dygraph_exception(self):
        with _test_eager_guard():
            self.func_append_bias_in_dygraph_exception()
        self.func_append_bias_in_dygraph_exception()

    def func_append_bias_in_dygraph(self):
        a_np = np.random.random(size=(10, 20, 30)).astype(np.float32)
        func = dygraph_utils._append_bias_in_dygraph
        with fluid.dygraph.guard():
            a = paddle.to_tensor(a_np)
            res1 = func(a, bias=a)
            res2 = paddle.add(a, a)
            self.assertTrue(np.array_equal(res1.numpy(), res2.numpy()))

    def test_append_bias_in_dygraph(self):
        with _test_eager_guard():
            self.func_append_bias_in_dygraph()
        self.func_append_bias_in_dygraph()


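# Using a dygraph tensor after leaving the guard should raise a TypeError that
# points the user back to fluid.dygraph.guard().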
class TestDygraphGuardWithError(unittest.TestCase):
    def func_without_guard(self):
        with fluid.dygraph.guard():
            x = paddle.to_tensor(np.zeros([10, 10]))
        with self.assertRaisesRegexp(TypeError,
                                     "Please use `with fluid.dygraph.guard()"):
            y = fluid.layers.matmul(x, x)

    def test_without_guard(self):
        with _test_eager_guard():
            self.func_without_guard()
        self.func_without_guard()


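# Layer subclasses should be plain Python classes ('type'), while VarBase is a
# pybind11 type; in eager mode EagerTensor is a plain type as well.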
class TestMetaclass(unittest.TestCase):
    def func_metaclass(self):
        self.assertEqual(type(MyLayer).__name__, 'type')
        self.assertNotEqual(type(MyLayer).__name__, 'pybind11_type')
        if core._in_eager_mode():
            self.assertEqual(
                type(paddle.fluid.core.eager.EagerTensor).__name__, 'type')
        else:
            self.assertEqual(
                type(paddle.fluid.core.VarBase).__name__, 'pybind11_type')

    def test_metaclass(self):
        with _test_eager_guard():
            self.func_metaclass()
        self.func_metaclass()


if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()