# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle.autograd.py_layer import PyLayer


class FakeTensor(paddle.fluid.core.VarBase):
    """Bare ``VarBase`` subclass whose constructor takes no arguments."""

    def __init__(self):
        """Construct the object without calling ``VarBase.__init__``."""


class TestPyLayer(unittest.TestCase):
    """Tests for ``paddle.autograd.py_layer.PyLayer``.

    Each test defines a small custom ``PyLayer`` (a ``forward``/``backward``
    pair of static methods). It then either compares the resulting input
    gradients with a reference graph built from native Paddle ops, or checks
    that a malformed layer raises the expected exception.
    """

    def _assert_grad_close(self, actual, expected, atol=1e-10):
        """Assert that two tensors' ``.grad`` values agree element-wise.

        Args:
            actual: tensor whose gradient flowed through the custom PyLayer.
            expected: tensor whose gradient came from the reference graph.
            atol: maximum allowed absolute difference per element.
        """
        np.testing.assert_allclose(
            actual.grad.numpy(), expected.grad.numpy(), rtol=0, atol=atol
        )

    def test_simple_pylayer_multiple_output(self):
        class tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1, x2, func1, func2=paddle.square):
                ctx.func = func2
                y1 = func1(x1)
                y2 = func1(x2)
                ctx.save_for_backward(y1, y2)
                # Non-tensor outputs are interleaved with the tensor outputs.
                # backward is declared with one gradient per tensor output.
                return y1, 1, y2, None

            @staticmethod
            def backward(ctx, dy1, dy2):
                y1, y2 = ctx.saved_tensor()
                re1 = dy1 * (1 - ctx.func(y1))
                re2 = dy2 * (1 - paddle.square(y2))
                return re1, re2

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = input1.detach().clone()
        input1.stop_gradient = False
        input2.stop_gradient = False
        # input1 is passed as both x1 and x2, so its gradient accumulates from
        # both tensor outputs. The reference mirrors that with tanh(input2)
        # used twice.
        z = tanh.apply(input1, input1, paddle.tanh, paddle.square)
        z = z[0] + z[2]
        z.mean().backward()

        z2 = paddle.tanh(input2) + paddle.tanh(input2)
        z2.mean().backward()

        self._assert_grad_close(input1, input2)

    def test_simple_pylayer_return_none_with_no_grad(self):
        class tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1, x2, func1, func2=paddle.square):
                ctx.func = func2
                y1 = func1(x1)
                y2 = func1(x2)
                ctx.save_for_backward(y1, y2)
                return 1, None, y1, y2, ''

            @staticmethod
            def backward(ctx, dy1, dy2):
                y1, y2 = ctx.saved_tensor()
                re1 = dy1 * (1 - ctx.func(y1))
                # The second input (x2) has stop_gradient=True in this test,
                # so None is returned as its gradient.
                return re1, None

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = input1.detach().clone()
        input3 = input1.detach().clone()
        input4 = input1.detach().clone()
        input1.stop_gradient = False
        input2.stop_gradient = False
        input3.stop_gradient = True
        input4.stop_gradient = True
        z = tanh.apply(input1, input3, paddle.tanh, paddle.square)
        z = z[2] + z[3]
        z.mean().backward()

        z2 = paddle.tanh(input2) + paddle.tanh(input4)
        z2.mean().backward()

        self._assert_grad_close(input1, input2)

    def test_simple_pylayer_single_output(self):
        class tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1, func1, func2=paddle.square):
                ctx.func = func2
                y1 = func1(x1)
                ctx.save_for_backward(y1)
                return y1

            @staticmethod
            def backward(ctx, dy1):
                (y1,) = ctx.saved_tensor()
                re1 = dy1 * (1 - ctx.func(y1))
                return re1

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = input1.detach().clone()
        input1.stop_gradient = False
        input2.stop_gradient = False
        z = tanh.apply(x1=input1, func1=paddle.tanh)
        z.mean().backward()
        z2 = paddle.tanh(input2)
        z2.mean().backward()

        self._assert_grad_close(input1, input2)

    def test_simple_pylayer_multi_output(self):
        class tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1, func1, func2=paddle.split):
                ctx.func = func2
                y1 = func1(x1)
                ctx.save_for_backward(y1)
                return y1

            @staticmethod
            def backward(ctx, dy1):
                (y1,) = ctx.saved_tensor()
                # Split the concatenated gradient into one piece per input.
                re1 = ctx.func(dy1, 3)
                return re1

        inputs = [paddle.randn([2, 3]).astype("float64") for _ in range(3)]
        # Independent clones form the reference graph. Comparing two inputs
        # of the same graph with each other would pass no matter how the
        # gradient pieces were routed.
        refs = [t.detach().clone() for t in inputs]
        for t in inputs + refs:
            t.stop_gradient = False

        # A position-dependent weight makes the upstream gradient
        # non-uniform, so swapped or misrouted split pieces are detected.
        # concat of three [2, 3] tensors along axis 0 gives shape [6, 3].
        weight = paddle.to_tensor(
            np.arange(18, dtype="float64").reshape(6, 3)
        )

        z = tanh.apply(x1=inputs, func1=paddle.concat)
        (z * weight).sum().backward()
        z2 = paddle.concat(refs)
        (z2 * weight).sum().backward()

        for actual, expected in zip(inputs, refs):
            self._assert_grad_close(actual, expected)

    def test_pylayer_num_output_match(self):
        class tanh(PyLayer):
            @staticmethod
            def forward(
                ctx,
                x1,
                x2,
            ):
                return x1 + x2

            @staticmethod
            def backward(ctx, dy1):
                # One gradient is returned for two differentiable inputs.
                return dy1 + 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = input1.detach().clone()
        input1.stop_gradient = False
        input2.stop_gradient = False
        z = tanh.apply(input1, input2)
        with self.assertRaises(ValueError):
            z.mean().backward()

    def test_pylayer_dtype(self):
        class tanh(PyLayer):
            @staticmethod
            def forward(ctx, x, dtype):
                y = paddle.cast(x, dtype)
                return y

            @staticmethod
            def backward(ctx, dy1):
                return dy1

        dtypes = [
            'bool',
            'float16',
            'float32',
            'float64',
            'uint8',
            'int32',
            'int64',
        ]
        for dtype in dtypes:
            input1 = paddle.randn([2, 3])
            input1.stop_gradient = False
            self.assertIsNone(input1.grad)

            # forward casts to each dtype in turn. The result is cast back to
            # float32 so that a scalar loss can be formed.
            z = tanh.apply(input1, dtype)
            z = paddle.cast(z, "float32")
            z.sum().backward()
            self.assertIsNotNone(input1.grad)

    def test_pylayer_Exception_forward(self):
        class Layer_None1(PyLayer):
            @staticmethod
            def forward(ctx, *args):
                return None

            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        # A forward that returns no Tensor at all is rejected.
        with self.assertRaises(ValueError):
            Layer_None1.apply(input1)

        class Layer_None2(PyLayer):
            @staticmethod
            def forward(ctx, *args):
                return [None, args[0]]

            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        # A None output alongside a Tensor output is accepted.
        Layer_None2.apply(input1)

        class Layer_one1(PyLayer):
            @staticmethod
            def forward(ctx, *args):
                return 1

            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        # At least one output of `PyLayer.forward` must be a `Tensor`.
        with self.assertRaises(ValueError):
            Layer_one1.apply(input1)

        class Layer_one2(PyLayer):
            @staticmethod
            def forward(ctx, *args):
                return [1, 2, args[0]]

            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        # int outputs alongside a Tensor output are accepted.
        Layer_one2.apply(input1)

        class Layer_no_fw(PyLayer):
            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        with self.assertRaises(NotImplementedError):
            Layer_no_fw.apply(input1)

    def test_pylayer_nograd(self):
        class tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1, func1, func2=paddle.square, xx=None):
                ctx.func = func2
                y1 = func1(x1)
                return y1

            @staticmethod
            def backward(ctx, x1, y1, dy1):
                re1 = dy1 * (1 - ctx.func(y1))
                return re1

        # NOTE(review): input1 never has stop_gradient set to False here. It
        # presumably keeps Paddle's default (True), so the mismatched
        # backward signature above is never exercised.
        input1 = paddle.randn([2, 3]).astype("float64")
        z = tanh.apply(input1, paddle.tanh, paddle.square)
        z.mean().backward()
        self.assertIsNone(z.grad)

    def test_pylayer_Exception_bk(self):
        class Layer_bk_none1(PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x * 2

            @staticmethod
            def backward(ctx, dy1):
                return None

        input2 = paddle.randn([2, 3]).astype("float64")
        input2.stop_gradient = False
        z = Layer_bk_none1.apply(input2)

        z.sum().backward()
        # backward returned None, so no gradient reaches the input.
        self.assertIsNone(input2.grad)

        class Layer_bk_none2(PyLayer):
            @staticmethod
            def forward(ctx, x1, x2):
                return x1 + x2

            @staticmethod
            def backward(ctx, dy1):
                return None, dy1

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False
        z = Layer_bk_none2.apply(input1, input1)

        z.mean().backward()
        self.assertIsNone(z.grad)

        class Layer_bk_one1(PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x + x

            @staticmethod
            def backward(ctx, dy):
                return 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False
        z = Layer_bk_one1.apply(input1)

        # A non-Tensor gradient from backward is rejected.
        with self.assertRaises(ValueError):
            z.mean().backward()

        class Layer_bk_one2(PyLayer):
            @staticmethod
            def forward(ctx, x1, x2):
                return x1 * 2, x2 * 5

            @staticmethod
            def backward(ctx, *args):
                return 1, 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False

        y = Layer_bk_one2.apply(input1, input1)
        z = y[0] + y[1]
        with self.assertRaises(ValueError):
            z.mean().backward()

        class Layer_no_bk(PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x * 2, x * 5

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False
        z = Layer_no_bk.apply(input1)

        # A PyLayer without backward fails when backward is run.
        with self.assertRaises(OSError):
            z = z[0] + z[1]
            z.mean().backward()

        class Layer_bk_match(PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x * 2, x * 5

            @staticmethod
            def backward(ctx, dy1, dy2):
                return dy2 * 2, dy1 * 2

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False
        z = Layer_bk_match.apply(input1)
        # Two gradients are returned for a single forward input.
        with self.assertRaises(ValueError):
            z = z[0] + z[1]
            z.mean().backward()

    def test_pylayer_bk_return_none(self):
        class Layer_bk_none1(PyLayer):
            @staticmethod
            def forward(ctx, x1, x2):
                return x1 + x2

            @staticmethod
            def backward(ctx, dy):
                return 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = True
        input2.stop_gradient = False
        z = Layer_bk_none1.apply(input1, input2)

        with self.assertRaises(ValueError):
            z.mean().backward()

        class Layer_bk_none2(PyLayer):
            @staticmethod
            def forward(ctx, x1, x2):
                return x1 * 2, x2 * 5

            @staticmethod
            def backward(ctx, *args):
                return 1, 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = True
        input2.stop_gradient = False
        z = Layer_bk_none2.apply(input1, input2)
        z = z[0] + z[1]
        with self.assertRaises(ValueError):
            z.mean().backward()

    def test_pylayer_inplace(self):
        class cus_tanh(PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, dy):
                return dy

        class Layer(paddle.nn.Layer):
            def __init__(self):
                super().__init__()

            def forward(self, data):
                data = data**2
                z = paddle.tanh(data)
                z = cus_tanh.apply(data)
                return z.mean()

        for i in range(2):
            data = paddle.ones([2, 3], dtype="float64") / (i + 1)
            data.stop_gradient = False
            layer = Layer()
            z = layer(data)
            z.backward()
            self.assertIsNotNone(data.grad)

    def test_pylayer_inplace_backward_error(self):
        class cus_tanh(PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, dy):
                return dy

        class Layer(paddle.nn.Layer):
            def __init__(self):
                super().__init__()

            def forward(self, data):
                var_b = data**2
                var_c = var_b**2
                # The result is unused, but the apply call is what triggers the
                # version-mismatch error below. cus_tanh returns its input
                # unchanged; presumably that bumps var_b's inplace version
                # after var_c was computed from it.
                z = cus_tanh.apply(var_b)
                loss = paddle.nn.functional.relu(var_c)
                return loss

        data = paddle.ones([2, 3], dtype="float64")
        data.stop_gradient = False
        layer = Layer()
        z = layer(data)
        with self.assertRaisesRegex(
            RuntimeError,
            "received tensor_version:{} != wrapper_version_snapshot:{}".format(
                1, 0
            ),
        ):
            z.backward()

    def test_pylayer_inplace_backward_success_1(self):
        class cus_tanh(PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, dy):
                return dy

        class Layer(paddle.nn.Layer):
            def __init__(self):
                super().__init__()

            def forward(self, data):
                var_b = data**2
                var_c = cus_tanh.apply(var_b)
                var_d = var_c**2
                loss = var_d.sum()
                return loss

        for i in range(2):
            data = paddle.ones([2, 3], dtype="float64") / (i + 1)
            data.stop_gradient = False
            layer = Layer()
            z = layer(data)
            z.backward()
            self.assertIsNotNone(data.grad)

    def test_pylayer_inplace_backward_success_2(self):
        class cus_tanh(PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, dy):
                return dy

        class Layer(paddle.nn.Layer):
            def __init__(self):
                super().__init__()

            def forward(self, data):
                var_b = data**2
                var_c = cus_tanh.apply(var_b)
                var_d = var_c + var_c
                loss = var_d.sum()
                return loss

        for i in range(2):
            data = paddle.ones([2, 3], dtype="float64") / (i + 1)
            data.stop_gradient = False
            layer = Layer()
            z = layer(data)
            z.backward()
            self.assertIsNotNone(data.grad)

    def test_pylayer_inplace_and_leaf_exception(self):
        class cus_pylayer_op(PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, dy):
                return dy

        class Layer(paddle.nn.Layer):
            def __init__(self):
                super().__init__()

            def forward(self, data):
                z = cus_pylayer_op.apply(data)
                return z.mean()

        for i in range(2):
            data = paddle.ones([2, 3], dtype="float64") / (i + 1)
            data.stop_gradient = False
            layer = Layer()

            # The layer returns its input unchanged. Applying it directly to
            # a leaf tensor that requires grad is expected to raise.
            with self.assertRaises(ValueError):
                z = layer(data)

    def test_backward_in_backward(self):
        class cus_tanh(PyLayer):
            @staticmethod
            def forward(ctx, x):
                temp = x.detach()
                ctx.inputs = temp
                return x.mean()

            @staticmethod
            def backward(ctx, dy):
                with paddle.set_grad_enabled(True):
                    temp = ctx.inputs
                    temp.stop_gradient = False
                    z = paddle.tanh(temp)
                    z.backward()
                    self.assertIsNotNone(temp.grad)
                    return paddle.to_tensor(temp.grad)

        for i in range(2):
            data = paddle.ones([2, 3], dtype="float32") / (i + 1)
            data.stop_gradient = False
            data = paddle.nn.functional.relu(data)
            z = paddle.tanh(data)
            z = cus_tanh.apply(data)
            # NOTE(review): nothing calls z.backward() here, so
            # cus_tanh.backward (and the assertion inside it) never runs.
            # Confirm whether a backward call was intended.

    def test_return_to_tensor(self):
        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1):
                y1 = paddle.tanh(x1)
                ctx.save_for_backward(y1)
                tensor_1 = paddle.to_tensor([1, 2], dtype='float32')
                return y1, 5, None, "helloworld", tensor_1

            @staticmethod
            def backward(ctx, dy1, dy2):
                (y1,) = ctx.saved_tensor()
                # d/dx tanh(x) = 1 - tanh(x)^2. Previously this was computed
                # but dy1 was returned instead, giving an identity gradient.
                re1 = dy1 * (1 - paddle.square(y1))
                return re1

        input1 = paddle.randn([2, 3]).astype("float32")
        input2 = input1.detach().clone()
        input1.stop_gradient = False
        input2.stop_gradient = False
        z, number, none_item, string_item, tensor1 = Tanh.apply(x1=input1)
        z.mean().backward()

        z2 = paddle.tanh(input2)
        z2.mean().backward()
        # float32 arithmetic, so a looser tolerance than the float64 tests.
        self._assert_grad_close(input1, input2, atol=1e-6)

    def test_materialize_grads(self):
        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x):
                ctx.mark_not_inplace(x)
                return x, x + x

            @staticmethod
            def backward(ctx, grad, grad2):
                # Only the first output feeds the loss. With grad
                # materialization left at its default, grad2 is expected to
                # be a zero tensor.
                self.assertEqual(grad2, paddle.zeros([1]))
                return grad

        x = paddle.ones([1], dtype="float64")
        x.stop_gradient = False
        Tanh.apply(x)[0].backward()

    def test_dont_materialize_grads(self):
        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x):
                ctx.mark_not_inplace(x)
                ctx.set_materialize_grads(False)
                return x, x + x

            @staticmethod
            def backward(ctx, grad, grad2):
                # Materialization is disabled, so the unused output's grad
                # is expected to be None.
                self.assertIsNone(grad2)
                return grad

        x = paddle.ones([1], dtype="float64")
        x.stop_gradient = False
        Tanh.apply(x)[0].backward()

    def test_mark_non_differentiable(self):
        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x):
                a = x + x
                ctx.mark_non_differentiable(a)
                return a

            @staticmethod
            def backward(ctx, grad):
                self.fail(
                    "backward should not be called for a "
                    "non-differentiable output"
                )
                return paddle.ones([1], dtype="float64")

        x = paddle.ones([1], dtype="float64")
        x.stop_gradient = False
        y = Tanh.apply(x)
        y.sum().backward()

    def test_mark_non_differentiable2(self):
        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x):
                a = x + x
                b = x + x + x
                ctx.mark_non_differentiable(a)
                return a, b

            @staticmethod
            def backward(ctx, grad_a, grad_b):
                self.assertEqual(grad_a, paddle.zeros([1]))
                self.assertEqual(grad_b, paddle.ones([1], dtype="float64"))
                return grad_b

        x = paddle.ones([1], dtype="float64")
        x.stop_gradient = False
        a, b = Tanh.apply(x)
        b.sum().backward()
        self.assertEqual(x.grad, paddle.ones([1], dtype="float64"))


if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()