# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np

import paddle
from paddle.autograd import PyLayer, EagerPyLayer
from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode


class FakeTensor(paddle.fluid.core.VarBase):
    """A ``VarBase`` subclass that is never actually initialized.

    It is an instance of ``VarBase`` but skips the base-class constructor.
    ``TestPyLayerReturnType`` uses it to check that ``PyLayer`` raises
    ``ValueError`` when such an object flows through ``apply``.
    """

    def __init__(self):
        # Deliberately do not call the VarBase initializer.
        pass


class TestPyLayer(unittest.TestCase):
    """Functional tests for custom autograd layers (PyLayer / EagerPyLayer).

    Most cases are written as a ``func_test_*`` helper that builds its layer
    from ``EagerPyLayer if in_dygraph_mode() else PyLayer``. The matching
    ``test_*`` method runs it twice: once under ``_test_eager_guard()`` and
    once in legacy dygraph mode.
    """

    def _run_eager_and_legacy(self, func):
        """Run ``func`` under ``_test_eager_guard()``, then once more without it."""
        with _test_eager_guard():
            func()
        func()

    def _assert_grad_close(self, actual, expected):
        """Assert two gradient tensors differ by less than 1e-10 element-wise."""
        self.assertLess(
            np.max(np.abs(actual.numpy() - expected.numpy())), 1e-10)

    def func_test_simple_pylayer_multiple_output(self):
        class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def forward(ctx, x1, x2, func1, func2=paddle.square):
                ctx.func = func2
                y1 = func1(x1)
                y2 = func1(x2)
                ctx.save_for_backward(y1, y2)
                # Non-Tensor outputs (1, None) are mixed in on purpose;
                # backward receives one gradient per Tensor output.
                return y1, 1, y2, None

            @staticmethod
            def backward(ctx, dy1, dy2):
                y1, y2 = ctx.saved_tensor()
                re1 = dy1 * (1 - ctx.func(y1))
                re2 = dy2 * (1 - paddle.square(y2))
                return re1, re2

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = input1.detach().clone()
        input1.stop_gradient = False
        input2.stop_gradient = False
        # input1 feeds both x1 and x2, mirroring tanh(input2) + tanh(input2).
        z = tanh.apply(input1, input1, paddle.tanh, paddle.square)
        z = z[0] + z[2]
        z.mean().backward()

        z2 = paddle.tanh(input2) + paddle.tanh(input2)
        z2.mean().backward()

        self._assert_grad_close(input1.grad, input2.grad)

    def test_simple_pylayer_multiple_output(self):
        self._run_eager_and_legacy(
            self.func_test_simple_pylayer_multiple_output)

    def func_test_simple_pylayer_return_none_with_no_grad(self):
        class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def forward(ctx, x1, x2, func1, func2=paddle.square):
                ctx.func = func2
                y1 = func1(x1)
                y2 = func1(x2)
                ctx.save_for_backward(y1, y2)
                return 1, None, y1, y2, ''

            @staticmethod
            def backward(ctx, dy1, dy2):
                y1, _ = ctx.saved_tensor()
                re1 = dy1 * (1 - ctx.func(y1))
                # x2 is fed a stop_gradient=True tensor below, so its
                # gradient slot is returned as None.
                return re1, None

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = input1.detach().clone()
        input3 = input1.detach().clone()
        input4 = input1.detach().clone()
        input1.stop_gradient = False
        input2.stop_gradient = False
        input3.stop_gradient = True
        input4.stop_gradient = True
        z = tanh.apply(input1, input3, paddle.tanh, paddle.square)
        z = z[2] + z[3]
        z.mean().backward()

        z2 = paddle.tanh(input2) + paddle.tanh(input4)
        z2.mean().backward()

        self._assert_grad_close(input1.grad, input2.grad)

    def test_simple_pylayer_return_none_with_no_grad(self):
        self._run_eager_and_legacy(
            self.func_test_simple_pylayer_return_none_with_no_grad)

    def func_test_simple_pylayer_single_output(self):
        class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def forward(ctx, x1, func1, func2=paddle.square):
                ctx.func = func2
                y1 = func1(x1)
                ctx.save_for_backward(y1)
                return y1

            @staticmethod
            def backward(ctx, dy1):
                y1, = ctx.saved_tensor()
                re1 = dy1 * (1 - ctx.func(y1))
                return re1

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = input1.detach().clone()
        input1.stop_gradient = False
        input2.stop_gradient = False
        # Inputs passed by keyword must be handled like positional ones.
        z = tanh.apply(x1=input1, func1=paddle.tanh)
        z.mean().backward()
        z2 = paddle.tanh(input2)
        z2.mean().backward()

        self._assert_grad_close(input1.grad, input2.grad)

    def test_simple_pylayer_single_output(self):
        self._run_eager_and_legacy(self.func_test_simple_pylayer_single_output)

    def func_test_pylayer_num_output_match(self):
        class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def forward(ctx, x1, x2):
                return x1 + x2

            @staticmethod
            def backward(ctx, dy1):
                # Only one gradient for two Tensor inputs -> mismatch.
                return dy1 + 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = input1.detach().clone()
        input1.stop_gradient = False
        input2.stop_gradient = False
        z = tanh.apply(input1, input2)
        with self.assertRaises(ValueError):
            z.mean().backward()

    def test_pylayer_num_output_match(self):
        self._run_eager_and_legacy(self.func_test_pylayer_num_output_match)

    def func_test_pylayer_dtype(self):
        class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def forward(ctx, x, dtype):
                y = paddle.cast(x, dtype)
                return y

            @staticmethod
            def backward(ctx, dy1):
                return dy1

        dtypes = [
            'bool', 'float16', 'float32', 'float64', 'uint8', 'int32', 'int64'
        ]
        for dtype in dtypes:
            input1 = paddle.randn([2, 3])
            input1.stop_gradient = False
            self.assertIsNone(input1.grad)

            z = tanh.apply(input1, dtype)
            z = paddle.cast(z, "float32")
            z.sum().backward()
            self.assertIsNotNone(input1.grad)

    def test_pylayer_dtype(self):
        self._run_eager_and_legacy(self.func_test_pylayer_dtype)

    def func_test_pylayer_Exception_forward(self):
        class Layer_None1(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def forward(ctx, *args):
                return None

            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        # Returning only None (no Tensor at all) is rejected.
        with self.assertRaises(ValueError):
            Layer_None1.apply(input1)

        class Layer_None2(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def forward(ctx, *args):
                return [None, args[0]]

            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        # None alongside a Tensor output is accepted.
        Layer_None2.apply(input1)

        class Layer_one1(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def forward(ctx, *args):
                return 1

            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        # At least one output of `PyLayer.forward` must be a `Tensor`.
        with self.assertRaises(ValueError):
            Layer_one1.apply(input1)

        class Layer_one2(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def forward(ctx, *args):
                return [1, 2, args[0]]

            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        # ints alongside a Tensor output are accepted.
        Layer_one2.apply(input1)

        class Layer_no_fw(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        with self.assertRaises(NotImplementedError):
            Layer_no_fw.apply(input1)

    def test_pylayer_Exception_forward(self):
        self._run_eager_and_legacy(self.func_test_pylayer_Exception_forward)

    def func_test_pylayer_nograd(self):
        class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def forward(ctx, x1, func1, func2=paddle.square, xx=None):
                ctx.func = func2
                y1 = func1(x1)
                ctx.save_for_backward(y1)
                return y1

            @staticmethod
            def backward(ctx, dy1):
                # Not expected to run: the input below keeps the default
                # stop_gradient of paddle.randn.
                y1, = ctx.saved_tensor()
                re1 = dy1 * (1 - ctx.func(y1))
                return re1

        input1 = paddle.randn([2, 3]).astype("float64")
        z = tanh.apply(input1, paddle.tanh, paddle.square)
        z.mean().backward()
        self.assertIsNone(z.grad)

    def test_pylayer_nograd(self):
        self._run_eager_and_legacy(self.func_test_pylayer_nograd)

    def func_test_pylayer_Exception_bk(self):
        class Layer_bk_none1(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x * 2

            @staticmethod
            def backward(ctx, dy1):
                return None

        input2 = paddle.randn([2, 3]).astype("float64")
        input2.stop_gradient = False
        z = Layer_bk_none1.apply(input2)

        # backward returning None for an input that requires grad.
        with self.assertRaises(ValueError):
            z.sum().backward()

        class Layer_bk_none2(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def forward(ctx, x1, x2):
                return x1 + x2

            @staticmethod
            def backward(ctx, dy1):
                return None, dy1

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False
        z = Layer_bk_none2.apply(input1, input1)

        with self.assertRaises(ValueError):
            z.mean().backward()

        class Layer_bk_one1(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x + x

            @staticmethod
            def backward(ctx, dy):
                return 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False
        z = Layer_bk_one1.apply(input1)

        # backward returning a non-Tensor gradient.
        with self.assertRaises(ValueError):
            z.mean().backward()

        class Layer_bk_one2(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def forward(ctx, x1, x2):
                return x1 * 2, x2 * 5

            @staticmethod
            def backward(ctx, *args):
                return 1, 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False

        y = Layer_bk_one2.apply(input1, input1)
        z = y[0] + y[1]
        with self.assertRaises(ValueError):
            z.mean().backward()

        class Layer_no_bk(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x * 2, x * 5

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False
        z = Layer_no_bk.apply(input1)

        # A layer without backward is expected to surface as OSError here.
        with self.assertRaises(OSError):
            z = z[0] + z[1]
            z.mean().backward()

        class Layer_bk_match(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x * 2, x * 5

            @staticmethod
            def backward(ctx, dy1, dy2):
                # Two gradients for a single input -> count mismatch.
                return dy2 * 2, dy1 * 2

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False
        z = Layer_bk_match.apply(input1)
        with self.assertRaises(ValueError):
            z = z[0] + z[1]
            z.mean().backward()

    def test_pylayer_Exception_bk(self):
        self._run_eager_and_legacy(self.func_test_pylayer_Exception_bk)

    def func_test_pylayer_bk_return_none(self):
        class Layer_bk_none1(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def forward(ctx, x1, x2):
                return x1 + x2

            @staticmethod
            def backward(ctx, dy):
                return 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = True
        input2.stop_gradient = False
        z = Layer_bk_none1.apply(input1, input2)

        with self.assertRaises(ValueError):
            z.mean().backward()

        class Layer_bk_none2(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def forward(ctx, x1, x2):
                return x1 * 2, x2 * 5

            @staticmethod
            def backward(ctx, *args):
                return 1, 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = True
        input2.stop_gradient = False
        z = Layer_bk_none2.apply(input1, input2)
        z = z[0] + z[1]
        with self.assertRaises(ValueError):
            z.mean().backward()

    def test_pylayer_bk_return_none(self):
        self._run_eager_and_legacy(self.func_test_pylayer_bk_return_none)

    def func_test_pylayer_inplace(self):
        class cus_tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, dy):
                return dy

        class Layer(paddle.nn.Layer):
            def __init__(self):
                super(Layer, self).__init__()

            def forward(self, data):
                data = paddle.nn.functional.relu(data)
                # NOTE(review): overwritten on the next line; it still adds
                # another consumer of `data` to the graph.
                z = paddle.tanh(data)
                z = cus_tanh.apply(data)
                return z.mean()

        for i in range(2):
            data = paddle.ones([2, 3], dtype="float64") / (i + 1)
            data.stop_gradient = False
            layer = Layer()
            z = layer(data)
            z.backward()
            self.assertIsNotNone(data.grad)

    def test_pylayer_inplace(self):
        self._run_eager_and_legacy(self.func_test_pylayer_inplace)

    def test_pylayer_inplace_backward_error(self):
        with _test_eager_guard():

            class cus_tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
                @staticmethod
                def forward(ctx, x):
                    ctx.mark_dirty(x)
                    return x

                @staticmethod
                def backward(ctx, dy):
                    return dy

            class Layer(paddle.nn.Layer):
                def __init__(self):
                    super(Layer, self).__init__()

                def forward(self, data):
                    var_b = data**2
                    var_c = var_b**2
                    # Marks var_b dirty after var_c has consumed it; the
                    # backward of var_c is expected to detect the version bump.
                    z = cus_tanh.apply(var_b)
                    loss = paddle.nn.functional.relu(var_c)
                    return loss

            data = paddle.ones([2, 3], dtype="float64")
            data.stop_gradient = False
            layer = Layer()
            z = layer(data)
            with self.assertRaisesRegex(
                    RuntimeError,
                    "received tensor_version:{} != wrapper_version_snapshot:{}".
                    format(1, 0)):
                z.backward()

    def test_pylayer_inplace_backward_success_1(self):
        with _test_eager_guard():

            class cus_tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
                @staticmethod
                def forward(ctx, x):
                    ctx.mark_dirty(x)
                    return x

                @staticmethod
                def backward(ctx, dy):
                    return dy

            class Layer(paddle.nn.Layer):
                def __init__(self):
                    super(Layer, self).__init__()

                def forward(self, data):
                    var_b = data**2
                    var_c = cus_tanh.apply(var_b)
                    var_d = var_c**2
                    loss = var_d.sum()
                    return loss

            for i in range(2):
                data = paddle.ones([2, 3], dtype="float64") / (i + 1)
                data.stop_gradient = False
                layer = Layer()
                z = layer(data)
                z.backward()
                self.assertIsNotNone(data.grad)

    def test_pylayer_inplace_backward_success_2(self):
        with _test_eager_guard():

            class cus_tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
                @staticmethod
                def forward(ctx, x):
                    ctx.mark_dirty(x)
                    return x

                @staticmethod
                def backward(ctx, dy):
                    return dy

            class Layer(paddle.nn.Layer):
                def __init__(self):
                    super(Layer, self).__init__()

                def forward(self, data):
                    var_b = data**2
                    var_c = cus_tanh.apply(var_b)
                    var_d = var_c + var_c
                    loss = var_d.sum()
                    return loss

            for i in range(2):
                data = paddle.ones([2, 3], dtype="float64") / (i + 1)
                data.stop_gradient = False
                layer = Layer()
                z = layer(data)
                z.backward()
                self.assertIsNotNone(data.grad)

    def func_test_pylayer_inplace_and_leaf_exception(self):
        class cus_pylayer_op(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def forward(ctx, x):
                if in_dygraph_mode():
                    ctx.mark_dirty(x)
                return x

            @staticmethod
            def backward(ctx, dy):
                return dy

        class Layer(paddle.nn.Layer):
            def __init__(self):
                super(Layer, self).__init__()

            def forward(self, data):
                z = cus_pylayer_op.apply(data)
                return z.mean()

        for i in range(2):
            data = paddle.ones([2, 3], dtype="float64") / (i + 1)
            data.stop_gradient = False
            layer = Layer()

            # Passing a grad-requiring leaf straight to the layer must fail.
            with self.assertRaises(ValueError):
                layer(data)

    def test_pylayer_inplace_and_leaf_exception(self):
        self._run_eager_and_legacy(
            self.func_test_pylayer_inplace_and_leaf_exception)

    def func_test_backward_in_backward(self):
        class cus_tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def forward(ctx, x):
                temp = x.detach()
                ctx.inputs = temp
                return x.mean()

            @staticmethod
            def backward(ctx, dy):
                # Runs a nested backward pass on the detached input.
                with paddle.set_grad_enabled(True):
                    temp = ctx.inputs
                    temp.stop_gradient = False
                    z = paddle.tanh(temp)
                    z.backward()
                    self.assertIsNotNone(temp.grad)
                    return paddle.to_tensor(temp.grad)

        for i in range(2):
            data = paddle.ones([2, 3], dtype="float32") / (i + 1)
            data.stop_gradient = False
            data = paddle.nn.functional.relu(data)
            z = paddle.tanh(data)
            z = cus_tanh.apply(data)
            # NOTE(review): z is never backpropagated, so cus_tanh.backward
            # is not exercised by this test -- TODO confirm intent.

    def test_backward_in_backward(self):
        self._run_eager_and_legacy(self.func_test_backward_in_backward)

    def func_test_return_to_tensor(self):
        class Tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
            @staticmethod
            def forward(ctx, x1):
                y1 = paddle.tanh(x1)
                ctx.save_for_backward(y1)
                tensor_1 = paddle.to_tensor([1, 2], dtype='float32')
                # Mixed Tensor / int / None / str outputs.
                return y1, 5, None, "helloworld", tensor_1

            @staticmethod
            def backward(ctx, dy1, dy2):
                y1, = ctx.saved_tensor()
                # d tanh(x)/dx = 1 - tanh(x)^2
                re1 = dy1 * (1 - paddle.square(y1))
                return re1

        input1 = paddle.randn([2, 3]).astype("float32")
        input1.stop_gradient = False
        # Unpacking checks that all five outputs are returned in order.
        z, number, none_item, string_item, tensor1 = Tanh.apply(x1=input1)
        z.mean().backward()

    def test_return_to_tensor(self):
        self._run_eager_and_legacy(self.func_test_return_to_tensor)

    def test_materialize_grads(self):
        with _test_eager_guard():

            class Tanh(EagerPyLayer):
                @staticmethod
                def forward(ctx, x):
                    return x, x + x

                @staticmethod
                def backward(ctx, grad, grad2):
                    # The unused second output gets a zero-filled gradient.
                    self.assertEqual(grad2, paddle.zeros([1]))
                    return grad

            x = paddle.ones([1], dtype="float64")
            x.stop_gradient = False
            Tanh.apply(x)[0].backward()

    def test_dont_materialize_grads(self):
        with _test_eager_guard():

            class Tanh(EagerPyLayer):
                @staticmethod
                def forward(ctx, x):
                    ctx.set_materialize_grads(False)
                    return x, x + x

                @staticmethod
                def backward(ctx, grad, grad2):
                    # With materialization off, the unused gradient is None.
                    self.assertIsNone(grad2)
                    return grad

            x = paddle.ones([1], dtype="float64")
            x.stop_gradient = False
            Tanh.apply(x)[0].backward()

    def test_mark_non_differentiable(self):
        with _test_eager_guard():

            class Tanh(EagerPyLayer):
                @staticmethod
                def forward(ctx, x):
                    a = x + x
                    ctx.mark_non_differentiable(a)
                    return a

                @staticmethod
                def backward(ctx, grad):
                    self.fail("backward should not be called for a "
                              "non-differentiable output")

            x = paddle.ones([1], dtype="float64")
            x.stop_gradient = False
            y = Tanh.apply(x)
            y.sum().backward()

    def test_mark_non_differentiable2(self):
        with _test_eager_guard():

            class Tanh(EagerPyLayer):
                @staticmethod
                def forward(ctx, x):
                    a = x + x
                    b = x + x + x
                    ctx.mark_non_differentiable(a)
                    return a, b

                @staticmethod
                def backward(ctx, grad_a, grad_b):
                    self.assertEqual(grad_a, paddle.zeros([1]))
                    self.assertEqual(grad_b, paddle.ones([1], dtype="float64"))
                    return grad_b

            x = paddle.ones([1], dtype="float64")
            x.stop_gradient = False
            a, b = Tanh.apply(x)
            b.sum().backward()
            self.assertEqual(x.grad, paddle.ones([1], dtype="float64"))


class TestPyLayerReturnType(unittest.TestCase):
    """Legacy ``PyLayer`` must reject ``FakeTensor`` wherever a Tensor flows.

    Each case routes a ``FakeTensor`` through one path and expects
    ``ValueError``. The paths are: a positional forward input, a keyword
    forward input, a forward output (single or tuple), and a backward output
    (single or tuple).
    """

    def test_forward_args_fake_tensor(self):
        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1):
                fake_out = FakeTensor()
                return fake_out, x1

            @staticmethod
            def backward(ctx, dy1, dy2):
                return dy1

        fake_in = FakeTensor()

        with self.assertRaises(ValueError):
            out_a, out_b = Tanh.apply(fake_in)

    def test_forward_kwargs_fake_tensor(self):
        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1):
                return x1

            @staticmethod
            def backward(ctx, dy1, dy2):
                return dy1

        fake_in = FakeTensor()

        with self.assertRaises(ValueError):
            Tanh.apply(x1=fake_in)

    def test_forward_return_fake_tensor(self):
        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1):
                return FakeTensor()

            @staticmethod
            def backward(ctx, dy1, dy2):
                return dy1

        real_in = paddle.randn([3, 2])

        with self.assertRaises(ValueError):
            Tanh.apply(x1=real_in)

    def test_forward_return_fake_tensor_tuple(self):
        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1):
                return FakeTensor(), FakeTensor()

            @staticmethod
            def backward(ctx, dy1, dy2):
                return dy1

        real_in = paddle.randn([3, 2])

        with self.assertRaises(ValueError):
            Tanh.apply(x1=real_in)

    def test_backward_return_fake_tensor_tuple(self):
        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1, x2):
                return x1 + 1, x1 + 2

            @staticmethod
            def backward(ctx, dy1, dy2):
                return FakeTensor(), 2

        real_in = paddle.randn([3, 2])
        real_in.stop_gradient = False
        first, _ = Tanh.apply(real_in, 1 + real_in)

        # The bad gradient is only seen once backward actually runs.
        with self.assertRaises(ValueError):
            first.mean().backward()

    def test_backward_return_fake_tensor(self):
        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1):
                return x1 + 1, x1 + 2

            @staticmethod
            def backward(ctx, dy1, dy2):
                return FakeTensor()

        real_in = paddle.randn([3, 2])
        real_in.stop_gradient = False
        first, _ = Tanh.apply(real_in)

        with self.assertRaises(ValueError):
            first.mean().backward()


if __name__ == '__main__':
    # Discover and run every TestCase defined in this module.
    unittest.main()