# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np

import paddle
from paddle.autograd.py_layer import LegacyPyLayer, EagerPyLayer
from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode


class FakeTensor(paddle.fluid.core.VarBase):
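    # An uninitialized VarBase; the return-type tests below pass or return
    # it to trigger PyLayer's tensor type checks.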

    def __init__(self):
        pass


class TestPyLayer(unittest.TestCase):

    def func_test_simple_pylayer_multiple_output(self):
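        # forward returns a mix of Tensors and non-Tensor values; the input
        # gradient must match the plain paddle.tanh reference below.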

        class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def forward(ctx, x1, x2, func1, func2=paddle.square):
                ctx.func = func2
                y1 = func1(x1)
                y2 = func1(x2)
                ctx.save_for_backward(y1, y2)
                return y1, 1, y2, None

            @staticmethod
            def backward(ctx, dy1, dy2):
                y1, y2 = ctx.saved_tensor()
                re1 = dy1 * (1 - ctx.func(y1))
                re2 = dy2 * (1 - paddle.square(y2))
                return re1, re2

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = input1.detach().clone()
        input1.stop_gradient = False
        input2.stop_gradient = False
        z = tanh.apply(input1, input1, paddle.tanh, paddle.square)
        z = z[0] + z[2]
        z.mean().backward()

        z2 = paddle.tanh(input2) + paddle.tanh(input2)
        z2.mean().backward()

        self.assertTrue(
            np.max(np.abs((input1.grad.numpy() - input2.grad.numpy()))) < 1e-10)

    def test_simple_pylayer_multiple_output(self):
        with _test_eager_guard():
            self.func_test_simple_pylayer_multiple_output()
        self.func_test_simple_pylayer_multiple_output()

    def func_test_simple_pylayer_return_none_with_no_grad(self):
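        # backward may return None for an input with stop_gradient=True; the
        # remaining gradient must match the plain paddle.tanh reference.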

        class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def forward(ctx, x1, x2, func1, func2=paddle.square):
                ctx.func = func2
                y1 = func1(x1)
                y2 = func1(x2)
                ctx.save_for_backward(y1, y2)
                return 1, None, y1, y2, ''

            @staticmethod
            def backward(ctx, dy1, dy2):
                y1, y2 = ctx.saved_tensor()
                re1 = dy1 * (1 - ctx.func(y1))
                re2 = dy2 * (1 - paddle.square(y2))
                return re1, None

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = input1.detach().clone()
        input3 = input1.detach().clone()
        input4 = input1.detach().clone()
        input1.stop_gradient = False
        input2.stop_gradient = False
        input3.stop_gradient = True
        input4.stop_gradient = True
        z = tanh.apply(input1, input3, paddle.tanh, paddle.square)
        z = z[2] + z[3]
        z.mean().backward()

        z2 = paddle.tanh(input2) + paddle.tanh(input4)
        z2.mean().backward()

        self.assertTrue(
            np.max(np.abs((input1.grad.numpy() - input2.grad.numpy()))) < 1e-10)

    def test_simple_pylayer_return_none_with_no_grad(self):
        with _test_eager_guard():
            self.func_test_simple_pylayer_return_none_with_no_grad()
        self.func_test_simple_pylayer_return_none_with_no_grad()

    def func_test_simple_pylayer_single_output(self):
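        # Single Tensor output, applied with keyword arguments.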

        class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def forward(ctx, x1, func1, func2=paddle.square):
                ctx.func = func2
                y1 = func1(x1)
                ctx.save_for_backward(y1)
                return y1

            @staticmethod
            def backward(ctx, dy1):
                y1, = ctx.saved_tensor()
                re1 = dy1 * (1 - ctx.func(y1))
                return re1

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = input1.detach().clone()
        input1.stop_gradient = False
        input2.stop_gradient = False
        z = tanh.apply(x1=input1, func1=paddle.tanh)
        z.mean().backward()
        z2 = paddle.tanh(input2)
        z2.mean().backward()

        self.assertTrue(
            np.max(np.abs((input1.grad.numpy() - input2.grad.numpy()))) < 1e-10)

    def test_simple_pylayer_single_output(self):
        with _test_eager_guard():
            self.func_test_simple_pylayer_single_output()
        self.func_test_simple_pylayer_single_output()

    def func_test_pylayer_num_output_match(self):
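        # backward returns fewer gradients than forward has Tensor inputs, so
        # calling backward() raises ValueError.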

        class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def forward(ctx, x1, x2):
                return x1 + x2

            @staticmethod
            def backward(ctx, dy1):
                return dy1 + 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = input1.detach().clone()
        input1.stop_gradient = False
        input2.stop_gradient = False
        z = tanh.apply(input1, input2)
        with self.assertRaises(ValueError):
            z.mean().backward()

    def test_pylayer_num_output_match(self):
        with _test_eager_guard():
            self.func_test_pylayer_num_output_match()
        self.func_test_pylayer_num_output_match()

    def func_test_pylayer_dtype(self):
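        # Casting inside a PyLayer must still propagate a gradient back to the
        # input for every dtype below.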

        class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def forward(ctx, x, dtype):
                y = paddle.cast(x, dtype)
                return y

            @staticmethod
            def backward(ctx, dy1):
                return dy1

        dtypes = [
            'bool', 'float16', 'float32', 'float64', 'uint8', 'int32', 'int64'
        ]
        for dtype in dtypes:
            input1 = (paddle.randn([2, 3]))
            input1.stop_gradient = False
            self.assertTrue(input1.grad is None)

            z = tanh.apply(input1, dtype)
            z = paddle.cast(z, "float32")
            z.sum().backward()
            self.assertTrue(input1.grad is not None)

    def test_pylayer_dtype(self):
        with _test_eager_guard():
            self.func_test_pylayer_dtype()
        self.func_test_pylayer_dtype()

    def func_test_pylayer_Exception_forward(self):
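        # forward must be implemented and must return at least one Tensor;
        # otherwise apply() raises.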

        class Layer_None1(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def forward(ctx, *args):
                return None

            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        with self.assertRaises(ValueError):
            z = Layer_None1.apply(input1)

        class Layer_None2(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def forward(ctx, *args):
                return [None, args[0]]

            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        # forward returns [None, Tensor]: allowed, since at least one output is a Tensor
        z = Layer_None2.apply(input1)

        class Layer_one1(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def forward(ctx, *args):
                return 1

            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        # At least one output of `PyLayer.forward` must be a `Tensor`
        with self.assertRaises(ValueError):
            z = Layer_one1.apply(input1)

        class Layer_one2(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def forward(ctx, *args):
                return [1, 2, args[0]]

            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        # forward returns [1, 2, Tensor]: ints are allowed alongside a Tensor
        z = Layer_one2.apply(input1)

        class Layer_no_fw(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        with self.assertRaises(NotImplementedError):
            z = Layer_no_fw.apply(input1)

    def test_pylayer_Exception_forward(self):
        with _test_eager_guard():
            self.func_test_pylayer_Exception_forward()
        self.func_test_pylayer_Exception_forward()

    def func_test_pylayer_nograd(self):
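        # The input's stop_gradient is left as the default (True), so no
        # gradient is recorded and z.grad stays None.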

        class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def forward(ctx, x1, func1, func2=paddle.square, xx=None):
                ctx.func = func2
                y1 = func1(x1)
                return y1

            @staticmethod
            def backward(ctx, x1, y1, dy1):
                re1 = dy1 * (1 - ctx.func(y1))
                return re1

        input1 = paddle.randn([2, 3]).astype("float64")
        z = tanh.apply(input1, paddle.tanh, paddle.square)
        z.mean().backward()
        self.assertTrue(z.grad is None)

    def test_pylayer_nograd(self):
        with _test_eager_guard():
            self.func_test_pylayer_nograd()
        self.func_test_pylayer_nograd()

    def func_test_pylayer_Exception_bk(self):
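        # Invalid backward definitions (returning None or plain ints, returning
        # the wrong number of gradients, or omitting backward) raise when the
        # backward pass runs.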

        class Layer_bk_none1(
                EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def forward(ctx, x):
                return x * 2

            @staticmethod
            def backward(ctx, dy1):
                return None

        input2 = paddle.randn([2, 3]).astype("float64")
        input2.stop_gradient = False
        z = Layer_bk_none1.apply(input2)

        with self.assertRaises(ValueError):
            z.sum().backward()

        class Layer_bk_none2(
                EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def forward(ctx, x1, x2):
                return x1 + x2

            @staticmethod
            def backward(ctx, dy1):
                return None, dy1

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False
        z = Layer_bk_none2.apply(input1, input1)

        with self.assertRaises(ValueError):
            z.mean().backward()

        class Layer_bk_one1(
                EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def forward(ctx, x):
                return x + x

            @staticmethod
            def backward(ctx, dy):
                return 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False
        z = Layer_bk_one1.apply(input1)

        with self.assertRaises(ValueError):
            z.mean().backward()

        class Layer_bk_one2(
                EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def forward(ctx, x1, x2):
                return x1 * 2, x2 * 5

            @staticmethod
            def backward(ctx, *args):
                return 1, 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False

        y = Layer_bk_one2.apply(input1, input1)
        z = y[0] + y[1]
        with self.assertRaises(ValueError):
            z.mean().backward()

        class Layer_no_bk(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def forward(ctx, x):
                return x * 2, x * 5

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False
        z = Layer_no_bk.apply(input1)

        with self.assertRaises(OSError):
            z = z[0] + z[1]
            z.mean().backward()

        class Layer_bk_match(
                EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def forward(ctx, x):
                return x * 2, x * 5

            @staticmethod
            def backward(ctx, dy1, dy2):
                return dy2 * 2, dy1 * 2

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False
        z = Layer_bk_match.apply(input1)
        with self.assertRaises(ValueError):
            z = z[0] + z[1]
            z.mean().backward()

    def test_pylayer_Exception_bk(self):
        with _test_eager_guard():
            self.func_test_pylayer_Exception_bk()
        self.func_test_pylayer_Exception_bk()

    def func_test_pylayer_bk_return_none(self):
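        # backward must still return a Tensor for every input that requires
        # grad; returning plain ints raises ValueError.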

        class Layer_bk_none1(
                EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def forward(ctx, x1, x2):
                return x1 + x2

            @staticmethod
            def backward(ctx, dy):
                return 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = True
        input2.stop_gradient = False
        z = Layer_bk_none1.apply(input1, input2)

        with self.assertRaises(ValueError):
            z.mean().backward()

        class Layer_bk_none2(
                EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def forward(ctx, x1, x2):
                return x1 * 2, x2 * 5

            @staticmethod
            def backward(ctx, *args):
                return 1, 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = True
        input2.stop_gradient = False
        z = Layer_bk_none2.apply(input1, input2)
        z = z[0] + z[1]
        with self.assertRaises(ValueError):
            z.mean().backward()

    def test_pylayer_bk_return_none(self):
        with _test_eager_guard():
            self.func_test_pylayer_bk_return_none()
        self.func_test_pylayer_bk_return_none()

    def func_test_pylayer_inplace(self):
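        # An inplace PyLayer applied to an intermediate (non-leaf) tensor still
        # propagates gradients back to the leaf input.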

        class cus_tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, dy):
                return dy

        class Layer(paddle.nn.Layer):

            def __init__(self):
                super(Layer, self).__init__()

            def forward(self, data):
                data = data**2
                z = paddle.tanh(data)
                z = cus_tanh.apply(data)
                return z.mean()

        for i in range(2):
            data = paddle.ones([2, 3], dtype="float64") / (i + 1)
            data.stop_gradient = False
            layer = Layer()
            z = layer(data)
            z.backward()
            self.assertTrue(data.grad is not None)

    def test_pylayer_inplace(self):
        with _test_eager_guard():
            self.func_test_pylayer_inplace()
        self.func_test_pylayer_inplace()

    def test_pylayer_inplace_backward_error(self):
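        # The PyLayer modifies var_b in place after var_b has been used to
        # compute var_c, so backward detects the tensor-version mismatch.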
        with _test_eager_guard():

            class cus_tanh(
                    EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

                @staticmethod
                def forward(ctx, x):
                    return x

                @staticmethod
                def backward(ctx, dy):
                    return dy

            class Layer(paddle.nn.Layer):

                def __init__(self):
                    super(Layer, self).__init__()

                def forward(self, data):
                    var_b = data**2
                    var_c = var_b**2
                    z = cus_tanh.apply(var_b)
                    loss = paddle.nn.functional.relu(var_c)
                    return loss

            data = paddle.ones([2, 3], dtype="float64")
            data.stop_gradient = False
            layer = Layer()
            z = layer(data)
            with self.assertRaisesRegex(
                    RuntimeError,
                    "received tensor_version:{} != wrapper_version_snapshot:{}".
                    format(1, 0)):
                z.backward()

    def test_pylayer_inplace_backward_success_1(self):
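        # Inplace modification is safe here: var_b is consumed only by the
        # PyLayer, so the version check passes and gradients are produced.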
        with _test_eager_guard():

            class cus_tanh(
                    EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

                @staticmethod
                def forward(ctx, x):
                    return x

                @staticmethod
                def backward(ctx, dy):
                    return dy

            class Layer(paddle.nn.Layer):

                def __init__(self):
                    super(Layer, self).__init__()

                def forward(self, data):
                    var_b = data**2
                    var_c = cus_tanh.apply(var_b)
                    var_d = var_c**2
                    loss = var_d.sum()
                    return loss

            for i in range(2):
                data = paddle.ones([2, 3], dtype="float64") / (i + 1)
                data.stop_gradient = False
                layer = Layer()
                z = layer(data)
                z.backward()
                self.assertTrue(data.grad is not None)

    def test_pylayer_inplace_backward_success_2(self):
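        # As above, but the inplace result is consumed twice afterwards;
        # backward still succeeds.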
        with _test_eager_guard():

            class cus_tanh(
                    EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

                @staticmethod
                def forward(ctx, x):
                    return x

                @staticmethod
                def backward(ctx, dy):
                    return dy

            class Layer(paddle.nn.Layer):

                def __init__(self):
                    super(Layer, self).__init__()

                def forward(self, data):
                    var_b = data**2
                    var_c = cus_tanh.apply(var_b)
                    var_d = var_c + var_c
                    loss = var_d.sum()
                    return loss

            for i in range(2):
                data = paddle.ones([2, 3], dtype="float64") / (i + 1)
                data.stop_gradient = False
                layer = Layer()
                z = layer(data)
                z.backward()
                self.assertTrue(data.grad is not None)

    def func_test_pylayer_inplace_and_leaf_exception(self):
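        # Applying an inplace PyLayer directly to a leaf tensor that requires
        # grad raises ValueError.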

        class cus_pylayer_op(
                EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, dy):
                return dy

        class Layer(paddle.nn.Layer):

            def __init__(self):
                super(Layer, self).__init__()

            def forward(self, data):
                z = cus_pylayer_op.apply(data)
                return z.mean()

        for i in range(2):
            data = paddle.ones([2, 3], dtype="float64") / (i + 1)
            data.stop_gradient = False
            layer = Layer()

            with self.assertRaises(ValueError):
                z = layer(data)

    def test_pylayer_inplace_and_leaf_exception(self):
        with _test_eager_guard():
            self.func_test_pylayer_inplace_and_leaf_exception()
        self.func_test_pylayer_inplace_and_leaf_exception()

    def func_test_backward_in_backward(self):
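        # backward() itself runs another backward pass under
        # paddle.set_grad_enabled(True).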

        class cus_tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def forward(ctx, x):
                temp = x.detach()
                ctx.inputs = temp
                return x.mean()

            @staticmethod
            def backward(ctx, dy):
                with paddle.set_grad_enabled(True):
                    temp = ctx.inputs
                    temp.stop_gradient = False
                    z = paddle.tanh(temp)
                    z.backward()
                    self.assertTrue(temp.grad is not None)
                    return paddle.to_tensor(temp.grad)

        for i in range(2):
            data = paddle.ones([2, 3], dtype="float32") / (i + 1)
            data.stop_gradient = False
            data = paddle.nn.functional.relu(data)
            z = paddle.tanh(data)
            z = cus_tanh.apply(data)

    def test_backward_in_backward(self):
        with _test_eager_guard():
            self.func_test_backward_in_backward()
        self.func_test_backward_in_backward()

    def func_test_return_to_tensor(self):
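        # forward may return ints, None and strings alongside Tensors; apply()
        # passes the non-Tensor values through unchanged.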

        class Tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):

            @staticmethod
            def forward(ctx, x1):
                y1 = paddle.tanh(x1)
                ctx.save_for_backward(y1)
                tensor_1 = paddle.to_tensor([1, 2], dtype='float32')
                return y1, 5, None, "helloworld", tensor_1

            @staticmethod
            def backward(ctx, dy1, dy2):
                y1, = ctx.saved_tensor()
                re1 = dy1 * (1 - paddle.square(y1))
                return dy1

        input1 = paddle.randn([2, 3]).astype("float32")
        input2 = input1.detach().clone()
        input1.stop_gradient = False
        input2.stop_gradient = False
        z, number, none_item, string_item, tensor1 = Tanh.apply(x1=input1)
        z.mean().backward()

    def test_return_to_tensor(self):
        with _test_eager_guard():
            self.func_test_return_to_tensor()
        self.func_test_return_to_tensor()

    def test_materialize_grads(self):
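        # By default, the gradient of an output that does not contribute to the
        # loss is materialized as zeros in backward.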
        with _test_eager_guard():

            class Tanh(EagerPyLayer):

                @staticmethod
                def forward(ctx, x):
                    ctx.mark_not_inplace(x)
                    return x, x + x

                @staticmethod
                def backward(ctx, grad, grad2):
                    self.assertEqual(grad2, paddle.zeros([1]))
                    return grad

            x = paddle.ones([1], dtype="float64")
            x.stop_gradient = False
            Tanh.apply(x)[0].backward()

    def test_dont_materialize_grads(self):
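        # With set_materialize_grads(False), the unused output's gradient is
        # passed to backward as None.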
        with _test_eager_guard():

            class Tanh(EagerPyLayer):

                @staticmethod
                def forward(ctx, x):
                    ctx.mark_not_inplace(x)
                    ctx.set_materialize_grads(False)
                    return x, x + x

                @staticmethod
                def backward(ctx, grad, grad2):
                    self.assertIsNone(grad2)
                    return grad

            x = paddle.ones([1], dtype="float64")
            x.stop_gradient = False
            Tanh.apply(x)[0].backward()

    def test_mark_non_differentiable(self):
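        # backward must not be called when the only output is marked
        # non-differentiable.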
        with _test_eager_guard():

            class Tanh(EagerPyLayer):

                @staticmethod
                def forward(ctx, x):
                    a = x + x
                    ctx.mark_non_differentiable(a)
                    return a

                @staticmethod
                def backward(ctx, grad):
                    self.assertTrue(False)  # should not be called
                    return paddle.ones([1], dtype="float64")

            x = paddle.ones([1], dtype="float64")
            x.stop_gradient = False
            y = Tanh.apply(x)
            y.sum().backward()

    def test_mark_non_differentiable2(self):
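        # The non-differentiable output receives a zero gradient, while the
        # differentiable output drives the backward pass.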
        with _test_eager_guard():

            class Tanh(EagerPyLayer):

                @staticmethod
                def forward(ctx, x):
                    a = x + x
                    b = x + x + x
                    ctx.mark_non_differentiable(a)
                    return a, b

                @staticmethod
                def backward(ctx, grad_a, grad_b):
                    self.assertEqual(grad_a, paddle.zeros([1]))
                    self.assertEqual(grad_b, paddle.ones([1], dtype="float64"))
                    return grad_b

            x = paddle.ones([1], dtype="float64")
            x.stop_gradient = False
            a, b = Tanh.apply(x)
            b.sum().backward()
            self.assertEqual(x.grad, paddle.ones([1], dtype="float64"))


class TestPyLayerReturnType(unittest.TestCase):
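    # Passing or returning a FakeTensor (an uninitialized VarBase) from
    # forward or backward must raise ValueError.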

    def test_forward_args_fake_tensor(self):

        class Tanh(LegacyPyLayer):

            @staticmethod
            def forward(ctx, x1):
                y1 = FakeTensor()
                return y1, x1

            @staticmethod
            def backward(ctx, dy1, dy2):
                return dy1

        input1 = FakeTensor()

        with self.assertRaises(ValueError):
            y1, y2 = Tanh.apply(input1)

    def test_forward_kwargs_fake_tensor(self):

        class Tanh(LegacyPyLayer):

            @staticmethod
            def forward(ctx, x1):

                return x1

            @staticmethod
            def backward(ctx, dy1, dy2):
                return dy1

        input1 = FakeTensor()

        with self.assertRaises(ValueError):
            y = Tanh.apply(x1=input1)

    def test_forward_return_fake_tensor(self):

        class Tanh(LegacyPyLayer):

            @staticmethod
            def forward(ctx, x1):

                return FakeTensor()

            @staticmethod
            def backward(ctx, dy1, dy2):
                return dy1

        input1 = paddle.randn([3, 2])

        with self.assertRaises(ValueError):
            y = Tanh.apply(x1=input1)

    def test_forward_return_fake_tensor_tuple(self):

        class Tanh(LegacyPyLayer):

            @staticmethod
            def forward(ctx, x1):

                return FakeTensor(), FakeTensor()

            @staticmethod
            def backward(ctx, dy1, dy2):
                return dy1

        input1 = paddle.randn([3, 2])

        with self.assertRaises(ValueError):
            y = Tanh.apply(x1=input1)

    def test_backward_return_fake_tensor_tuple(self):

        class Tanh(LegacyPyLayer):

            @staticmethod
            def forward(ctx, x1, x2):
                return x1 + 1, x1 + 2

            @staticmethod
            def backward(ctx, dy1, dy2):

                return FakeTensor(), 2

        input1 = paddle.randn([3, 2])
        input1.stop_gradient = False
        y, _ = Tanh.apply(input1, 1 + input1)

        with self.assertRaises(ValueError):
            y.mean().backward()

    def test_backward_return_fake_tensor(self):

        class Tanh(LegacyPyLayer):

            @staticmethod
            def forward(ctx, x1):
                return x1 + 1, x1 + 2

            @staticmethod
            def backward(ctx, dy1, dy2):
                return FakeTensor()

        input1 = paddle.randn([3, 2])
        input1.stop_gradient = False
        y, _ = Tanh.apply(input1)

        with self.assertRaises(ValueError):
            y.mean().backward()


if __name__ == '__main__':
    unittest.main()