# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle.autograd.py_layer import PyLayer


class FakeTensor(paddle.fluid.core.VarBase):
    def __init__(self):
        pass


class TestPyLayer(unittest.TestCase):
    def test_simple_pylayer_multiple_output(self):
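        # forward returns Tensors mixed with non-Tensor values; gradients of
        # the Tensor outputs should match the equivalent native tanh ops.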
        class tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1, x2, func1, func2=paddle.square):
                ctx.func = func2
                y1 = func1(x1)
                y2 = func1(x2)
                ctx.save_for_backward(y1, y2)
                return y1, 1, y2, None

            @staticmethod
            def backward(ctx, dy1, dy2):
                y1, y2 = ctx.saved_tensor()
                re1 = dy1 * (1 - ctx.func(y1))
                re2 = dy2 * (1 - paddle.square(y2))
                return re1, re2

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = input1.detach().clone()
        input1.stop_gradient = False
        input2.stop_gradient = False
        z = tanh.apply(input1, input1, paddle.tanh, paddle.square)
        z = z[0] + z[2]
        z.mean().backward()

        z2 = paddle.tanh(input2) + paddle.tanh(input2)
        z2.mean().backward()

        self.assertTrue(
            np.max(np.abs((input1.grad.numpy() - input2.grad.numpy()))) < 1e-10
        )

    def test_simple_pylayer_return_none_with_no_grad(self):
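        # backward may return None for an input whose stop_gradient is True;
        # the remaining gradient still matches the native computation.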
        class tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1, x2, func1, func2=paddle.square):
                ctx.func = func2
                y1 = func1(x1)
                y2 = func1(x2)
                ctx.save_for_backward(y1, y2)
                return 1, None, y1, y2, ''

            @staticmethod
            def backward(ctx, dy1, dy2):
                y1, y2 = ctx.saved_tensor()
                re1 = dy1 * (1 - ctx.func(y1))
                re2 = dy2 * (1 - paddle.square(y2))
                return re1, None

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = input1.detach().clone()
        input3 = input1.detach().clone()
        input4 = input1.detach().clone()
        input1.stop_gradient = False
        input2.stop_gradient = False
        input3.stop_gradient = True
        input4.stop_gradient = True
        z = tanh.apply(input1, input3, paddle.tanh, paddle.square)
        z = z[2] + z[3]
        z.mean().backward()

        z2 = paddle.tanh(input2) + paddle.tanh(input4)
        z2.mean().backward()

        self.assertTrue(
            np.max(np.abs((input1.grad.numpy() - input2.grad.numpy()))) < 1e-10
        )

    def test_simple_pylayer_single_output(self):
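        # single-Tensor output, applied with keyword arguments.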
        class tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1, func1, func2=paddle.square):
                ctx.func = func2
                y1 = func1(x1)
                ctx.save_for_backward(y1)
                return y1

            @staticmethod
            def backward(ctx, dy1):
                (y1,) = ctx.saved_tensor()
                re1 = dy1 * (1 - ctx.func(y1))
                return re1

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = input1.detach().clone()
        input1.stop_gradient = False
        input2.stop_gradient = False
        z = tanh.apply(x1=input1, func1=paddle.tanh)
        z.mean().backward()
        z2 = paddle.tanh(input2)
        z2.mean().backward()

        self.assertTrue(
            np.max(np.abs((input1.grad.numpy() - input2.grad.numpy()))) < 1e-10
        )

    def test_pylayer_num_output_match(self):
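        # backward returns fewer gradients than forward has differentiable
        # inputs, so backward() raises ValueError.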
        class tanh(PyLayer):
            @staticmethod
            def forward(
                ctx,
                x1,
                x2,
            ):
                return x1 + x2

            @staticmethod
            def backward(ctx, dy1):
                return dy1 + 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = input1.detach().clone()
        input1.stop_gradient = False
        input2.stop_gradient = False
        z = tanh.apply(input1, input2)
        with self.assertRaises(ValueError):
            z.mean().backward()

    def test_pylayer_dtype(self):
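        # casting inside forward must still propagate a gradient to the input.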
        class tanh(PyLayer):
            @staticmethod
            def forward(ctx, x, dtype):
                y = paddle.cast(x, dtype)
                return y

            @staticmethod
            def backward(ctx, dy1):
                return dy1

        dtypes = [
            'bool',
            'float16',
            'float32',
            'float64',
            'uint8',
            'int32',
            'int64',
        ]
        for dtype in dtypes:
            input1 = paddle.randn([2, 3])
            input1.stop_gradient = False
            self.assertIsNone(input1.grad)

            z = tanh.apply(input1, dtype)
            z = paddle.cast(z, "float32")
            z.sum().backward()
            self.assertIsNotNone(input1.grad)

    def test_pylayer_Exception_forward(self):
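        # forward must be implemented and return at least one Tensor.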
        class Layer_None1(PyLayer):
            @staticmethod
            def forward(ctx, *args):
                return None

            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        with self.assertRaises(ValueError):
            z = Layer_None1.apply(input1)

        class Layer_None2(PyLayer):
            @staticmethod
            def forward(ctx, *args):
                return [None, args[0]]

            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        # forward may return None alongside a Tensor, so apply() succeeds here
        z = Layer_None2.apply(input1)

        class Layer_one1(PyLayer):
            @staticmethod
            def forward(ctx, *args):
                return 1

            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        # At least one output of `PyLayer.forward` must be a `Tensor`
        with self.assertRaises(ValueError):
            z = Layer_one1.apply(input1)

        class Layer_one2(PyLayer):
            @staticmethod
            def forward(ctx, *args):
                return [1, 2, args[0]]

            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        # forward may return plain ints alongside a Tensor, so apply() succeeds here
        z = Layer_one2.apply(input1)

        class Layer_no_fw(PyLayer):
            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        with self.assertRaises(NotImplementedError):
            z = Layer_no_fw.apply(input1)

    def test_pylayer_nograd(self):
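        # no input requires grad, so backward never runs and z has no gradient.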
        class tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1, func1, func2=paddle.square, xx=None):
                ctx.func = func2
                y1 = func1(x1)
                return y1

            @staticmethod
            def backward(ctx, x1, y1, dy1):
                re1 = dy1 * (1 - ctx.func(y1))
                return re1

        input1 = paddle.randn([2, 3]).astype("float64")
        z = tanh.apply(input1, paddle.tanh, paddle.square)
        z.mean().backward()
        self.assertIsNone(z.grad)

    def test_pylayer_Exception_bk(self):
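        # invalid backward results (None, plain numbers, wrong count, or a
        # missing backward) raise errors when backward() is triggered.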
        class Layer_bk_none1(PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x * 2

            @staticmethod
            def backward(ctx, dy1):
                return None

        input2 = paddle.randn([2, 3]).astype("float64")
        input2.stop_gradient = False
        z = Layer_bk_none1.apply(input2)

        with self.assertRaises(ValueError):
            z.sum().backward()

        class Layer_bk_none2(PyLayer):
            @staticmethod
            def forward(ctx, x1, x2):
                return x1 + x2

            @staticmethod
            def backward(ctx, dy1):
                return None, dy1

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False
        z = Layer_bk_none2.apply(input1, input1)

        with self.assertRaises(ValueError):
            z.mean().backward()

        class Layer_bk_one1(PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x + x

            @staticmethod
            def backward(ctx, dy):
                return 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False
        z = Layer_bk_one1.apply(input1)

        with self.assertRaises(ValueError):
            z.mean().backward()

        class Layer_bk_one2(PyLayer):
            @staticmethod
            def forward(ctx, x1, x2):
                return x1 * 2, x2 * 5

            @staticmethod
            def backward(ctx, *args):
                return 1, 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False

        y = Layer_bk_one2.apply(input1, input1)
        z = y[0] + y[1]
        with self.assertRaises(ValueError):
            z.mean().backward()

        class Layer_no_bk(PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x * 2, x * 5

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False
        z = Layer_no_bk.apply(input1)

        with self.assertRaises(OSError):
            z = z[0] + z[1]
            z.mean().backward()

        class Layer_bk_match(PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x * 2, x * 5

            @staticmethod
            def backward(ctx, dy1, dy2):
                return dy2 * 2, dy1 * 2

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False
        z = Layer_bk_match.apply(input1)
        with self.assertRaises(ValueError):
            z = z[0] + z[1]
            z.mean().backward()

    def test_pylayer_bk_return_none(self):
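        # backward returns a non-Tensor for an input that requires grad.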
        class Layer_bk_none1(PyLayer):
            @staticmethod
            def forward(ctx, x1, x2):
                return x1 + x2

            @staticmethod
            def backward(ctx, dy):
                return 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = True
        input2.stop_gradient = False
        z = Layer_bk_none1.apply(input1, input2)

        with self.assertRaises(ValueError):
            z.mean().backward()

        class Layer_bk_none2(PyLayer):
            @staticmethod
            def forward(ctx, x1, x2):
                return x1 * 2, x2 * 5

            @staticmethod
            def backward(ctx, *args):
                return 1, 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = True
        input2.stop_gradient = False
        z = Layer_bk_none2.apply(input1, input2)
        z = z[0] + z[1]
        with self.assertRaises(ValueError):
            z.mean().backward()

    def test_pylayer_inplace(self):
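        # returning the input from forward (in-place) works on a non-leaf tensor.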
        class cus_tanh(PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, dy):
                return dy

        class Layer(paddle.nn.Layer):
            def __init__(self):
                super().__init__()

            def forward(self, data):
                data = data**2
                z = paddle.tanh(data)
                z = cus_tanh.apply(data)
                return z.mean()

        for i in range(2):
            data = paddle.ones([2, 3], dtype="float64") / (i + 1)
            data.stop_gradient = False
            layer = Layer()
            z = layer(data)
            z.backward()
            self.assertIsNotNone(data.grad)

    def test_pylayer_inplace_backward_error(self):
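        # the in-place PyLayer bumps var_b's tensor version after var_c was
        # built from it, so backward detects the version mismatch.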
        class cus_tanh(PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, dy):
                return dy

        class Layer(paddle.nn.Layer):
            def __init__(self):
                super().__init__()

            def forward(self, data):
                var_b = data**2
                var_c = var_b**2
                z = cus_tanh.apply(var_b)
                loss = paddle.nn.functional.relu(var_c)
                return loss

        data = paddle.ones([2, 3], dtype="float64")
        data.stop_gradient = False
        layer = Layer()
        z = layer(data)
        with self.assertRaisesRegex(
            RuntimeError,
            "received tensor_version:{} != wrapper_version_snapshot:{}".format(
                1, 0
            ),
        ):
            z.backward()

    def test_pylayer_inplace_backward_success_1(self):
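        # the in-place output is consumed downstream, so tensor versions stay
        # consistent and backward succeeds.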
        class cus_tanh(PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, dy):
                return dy

        class Layer(paddle.nn.Layer):
            def __init__(self):
                super().__init__()

            def forward(self, data):
                var_b = data**2
                var_c = cus_tanh.apply(var_b)
                var_d = var_c**2
                loss = var_d.sum()
                return loss

        for i in range(2):
            data = paddle.ones([2, 3], dtype="float64") / (i + 1)
            data.stop_gradient = False
            layer = Layer()
            z = layer(data)
            z.backward()
            self.assertIsNotNone(data.grad)

    def test_pylayer_inplace_backward_success_2(self):
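        # same as success_1, but the in-place output feeds an addition.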
        class cus_tanh(PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, dy):
                return dy

        class Layer(paddle.nn.Layer):
            def __init__(self):
                super().__init__()

            def forward(self, data):
                var_b = data**2
                var_c = cus_tanh.apply(var_b)
                var_d = var_c + var_c
                loss = var_d.sum()
                return loss

        for i in range(2):
            data = paddle.ones([2, 3], dtype="float64") / (i + 1)
            data.stop_gradient = False
            layer = Layer()
            z = layer(data)
            z.backward()
            self.assertIsNotNone(data.grad)

    def test_pylayer_inplace_and_leaf_exception(self):
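        # applying an in-place PyLayer directly to a leaf tensor that
        # requires grad raises ValueError.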
        class cus_pylayer_op(PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, dy):
                return dy

        class Layer(paddle.nn.Layer):
            def __init__(self):
                super().__init__()

            def forward(self, data):
                z = cus_pylayer_op.apply(data)
                return z.mean()

        for i in range(2):
            data = paddle.ones([2, 3], dtype="float64") / (i + 1)
            data.stop_gradient = False
            layer = Layer()

            with self.assertRaises(ValueError):
                z = layer(data)

    def test_backward_in_backward(self):
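        # a PyLayer whose backward would run a nested backward pass under
        # paddle.set_grad_enabled(True).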
        class cus_tanh(PyLayer):
            @staticmethod
            def forward(ctx, x):
                temp = x.detach()
                ctx.inputs = temp
                return x.mean()

            @staticmethod
            def backward(ctx, dy):
                with paddle.set_grad_enabled(True):
                    temp = ctx.inputs
                    temp.stop_gradient = False
                    z = paddle.tanh(temp)
                    z.backward()
                    self.assertIsNotNone(temp.grad)
                    return paddle.to_tensor(temp.grad)

        for i in range(2):
            data = paddle.ones([2, 3], dtype="float32") / (i + 1)
            data.stop_gradient = False
            data = paddle.nn.functional.relu(data)
            z = paddle.tanh(data)
            z = cus_tanh.apply(data)

    def test_return_to_tensor(self):
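        # forward returns a mix of Tensor, int, None, and str outputs.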
        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1):
                y1 = paddle.tanh(x1)
                ctx.save_for_backward(y1)
                tensor_1 = paddle.to_tensor([1, 2], dtype='float32')
                return y1, 5, None, "helloworld", tensor_1

            @staticmethod
            def backward(ctx, dy1, dy2):
                (y1,) = ctx.saved_tensor()
                re1 = dy1 * (1 - paddle.square(y1))
                return dy1

        input1 = paddle.randn([2, 3]).astype("float32")
        input2 = input1.detach().clone()
        input1.stop_gradient = False
        input2.stop_gradient = False
        z, number, none_item, string_item, tensor1 = Tanh.apply(x1=input1)
        z.mean().backward()

    def test_materialize_grads(self):
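        # by default, the grad of an unused output is materialized as zeros.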
        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x):
                ctx.mark_not_inplace(x)
                return x, x + x

            @staticmethod
            def backward(ctx, grad, grad2):
                self.assertEqual(grad2, paddle.zeros([1]))
                return grad

        x = paddle.ones([1], dtype="float64")
        x.stop_gradient = False
        Tanh.apply(x)[0].backward()

    def test_dont_materialize_grads(self):
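        # with set_materialize_grads(False), the unused output's grad is None.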
        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x):
                ctx.mark_not_inplace(x)
                ctx.set_materialize_grads(False)
                return x, x + x

            @staticmethod
            def backward(ctx, grad, grad2):
                self.assertIsNone(grad2)
                return grad

        x = paddle.ones([1], dtype="float64")
        x.stop_gradient = False
        Tanh.apply(x)[0].backward()

    def test_mark_non_differentiable(self):
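        # the only output is marked non-differentiable, so backward is never called.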
        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x):
                a = x + x
                ctx.mark_non_differentiable(a)
                return a

            @staticmethod
            def backward(ctx, grad):
                self.assertTrue(False)  # backward should not be called
                return paddle.ones([1], dtype="float64")

        x = paddle.ones([1], dtype="float64")
        x.stop_gradient = False
        y = Tanh.apply(x)
        y.sum().backward()

    def test_mark_non_differentiable2(self):
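        # output `a` is non-differentiable (zero grad); `b` still contributes.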
        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x):
                a = x + x
                b = x + x + x
                ctx.mark_non_differentiable(a)
                return a, b

            @staticmethod
            def backward(ctx, grad_a, grad_b):
                self.assertEqual(grad_a, paddle.zeros([1]))
                self.assertEqual(grad_b, paddle.ones([1], dtype="float64"))
                return grad_b

        x = paddle.ones([1], dtype="float64")
        x.stop_gradient = False
        a, b = Tanh.apply(x)
        b.sum().backward()
        self.assertEqual(x.grad, paddle.ones([1], dtype="float64"))


if __name__ == '__main__':
    unittest.main()