# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np

import paddle
from paddle.autograd import PyLayer


class FakeTensor(paddle.fluid.core.VarBase):
    """``VarBase`` subclass whose constructor skips the base initializer."""

    def __init__(self):
        # Does nothing; the parent ``__init__`` is never called.
        return None


class TestPyLayer(unittest.TestCase):
    def test_simple_pylayer_multiple_output(self):
        """Check gradients when forward mixes tensor and non-tensor outputs.

        ``forward`` returns ``(y1, 1, y2, None)`` while ``backward`` takes
        two gradients (``dy1``, ``dy2``).  The gradient accumulated on
        ``input1`` is compared with the gradient the built-in
        ``paddle.tanh`` produces on a detached clone of the same data.
        """

        class tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1, x2, func1, func2=paddle.square):
                # Keep the callable that backward() applies to y1.
                ctx.func = func2
                y1 = func1(x1)
                y2 = func1(x2)
                ctx.save_for_backward(y1, y2)
                # Tensor outputs are interleaved with an int and None.
                return y1, 1, y2, None

            @staticmethod
            def backward(ctx, dy1, dy2):
                y1, y2 = ctx.saved_tensor()
                # With func1=tanh and func2=square this is dy * (1 - y^2).
                re1 = dy1 * (1 - ctx.func(y1))
                re2 = dy2 * (1 - paddle.square(y2))
                return re1, re2

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = input1.detach().clone()
        input1.stop_gradient = False
        input2.stop_gradient = False
        z = tanh.apply(input1, input1, paddle.tanh, paddle.square)
        # Only the tensor outputs (indices 0 and 2) take part in the loss.
        z = z[0] + z[2]
        z.mean().backward()

        # Reference computation using the built-in operator.
        z2 = paddle.tanh(input2) + paddle.tanh(input2)
        z2.mean().backward()

        max_diff = np.max(np.abs(input1.grad.numpy() - input2.grad.numpy()))
        self.assertLess(max_diff, 1e-10)
    def test_simple_pylayer_return_none_with_no_grad(self):
        """Check that backward may return None for a no-grad input.

        ``input3`` has ``stop_gradient = True`` and the custom ``backward``
        returns ``None`` in its slot.  The gradient accumulated on
        ``input1`` is compared with the gradient the built-in
        ``paddle.tanh`` produces on detached clones of the same data.
        """

        class tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1, x2, func1, func2=paddle.square):
                # Keep the callable that backward() applies to y1.
                ctx.func = func2
                y1 = func1(x1)
                y2 = func1(x2)
                ctx.save_for_backward(y1, y2)
                # Non-tensor values surround the two tensor outputs.
                return 1, None, y1, y2, ''

            @staticmethod
            def backward(ctx, dy1, dy2):
                y1, y2 = ctx.saved_tensor()
                re1 = dy1 * (1 - ctx.func(y1))
                # Computed but not returned: the second slot yields None.
                re2 = dy2 * (1 - paddle.square(y2))
                return re1, None

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = input1.detach().clone()
        input3 = input1.detach().clone()
        input4 = input1.detach().clone()
        input1.stop_gradient = False
        input2.stop_gradient = False
        input3.stop_gradient = True
        input4.stop_gradient = True
        z = tanh.apply(input1, input3, paddle.tanh, paddle.square)
        # Only the tensor outputs (indices 2 and 3) take part in the loss.
        z = z[2] + z[3]
        z.mean().backward()

        # Reference computation using the built-in operator.
        z2 = paddle.tanh(input2) + paddle.tanh(input4)
        z2.mean().backward()

        max_diff = np.max(np.abs(input1.grad.numpy() - input2.grad.numpy()))
        self.assertLess(max_diff, 1e-10)
    def test_simple_pylayer_single_output(self):
        """Check gradients of a single-output PyLayer called with kwargs.

        ``apply`` receives its arguments by keyword and ``func2`` falls
        back to its default (``paddle.square``).  The gradient is compared
        with the one the built-in ``paddle.tanh`` produces on a clone.
        """

        class tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1, func1, func2=paddle.square):
                # Keep the callable that backward() applies to y1.
                ctx.func = func2
                y1 = func1(x1)
                ctx.save_for_backward(y1)
                return y1

            @staticmethod
            def backward(ctx, dy1):
                y1, = ctx.saved_tensor()
                re1 = dy1 * (1 - ctx.func(y1))
                return re1

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = input1.detach().clone()
        input1.stop_gradient = False
        input2.stop_gradient = False
        z = tanh.apply(x1=input1, func1=paddle.tanh)
        z.mean().backward()

        # Reference computation using the built-in operator.
        z2 = paddle.tanh(input2)
        z2.mean().backward()

        max_diff = np.max(np.abs(input1.grad.numpy() - input2.grad.numpy()))
        self.assertLess(max_diff, 1e-10)
    def test_pylayer_num_output_match(self):
        """Expect ValueError when backward returns too few gradients.

        ``forward`` takes two tensors but ``backward`` returns a single
        value; running the backward pass must raise ``ValueError``.
        """

        class tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1, x2):
                return x1 + x2

            @staticmethod
            def backward(ctx, dy1):
                return dy1 + 1

        lhs = paddle.randn([2, 3]).astype("float64")
        rhs = lhs.detach().clone()
        for tensor in (lhs, rhs):
            tensor.stop_gradient = False

        out = tanh.apply(lhs, rhs)
        with self.assertRaises(ValueError):
            out.mean().backward()

    def test_pylayer_dtype(self):
        """Check that a gradient reaches the input for each cast dtype.

        For every dtype in the list, ``forward`` casts the input to that
        dtype, the result is cast back to ``float32`` and summed, and the
        input must have a gradient after ``backward()``.
        """

        class tanh(PyLayer):
            @staticmethod
            def forward(ctx, x, dtype):
                y = paddle.cast(x, dtype)
                return y

            @staticmethod
            def backward(ctx, dy1):
                return dy1

        dtypes = [
            'bool', 'float16', 'float32', 'float64', 'uint8', 'int32', 'int64'
        ]
        for dtype in dtypes:
            # Name the dtype on failure so the offending case is obvious.
            msg = "dtype=%s" % dtype
            input1 = paddle.randn([2, 3])
            input1.stop_gradient = False
            self.assertIsNone(input1.grad, msg)

            z = tanh.apply(input1, dtype)
            z = paddle.cast(z, "float32")
            z.sum().backward()
            self.assertIsNotNone(input1.grad, msg)

    def test_pylayer_Exception_forward(self):
        """Exercise invalid and unusual return values of ``forward``.

        Cases covered, in order: ``forward`` returns only ``None``; a list
        mixing ``None`` with a tensor; only an ``int``; a list mixing ints
        with a tensor; and a PyLayer that defines no ``forward`` at all.
        """

        class Layer_None1(PyLayer):
            @staticmethod
            def forward(ctx, *args):
                return None

            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        # forward() returns no Tensor, so apply() is expected to raise.
        with self.assertRaises(ValueError):
            z = Layer_None1.apply(input1)

        class Layer_None2(PyLayer):
            @staticmethod
            def forward(ctx, *args):
                return [None, args[0]]

            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        # None next to a Tensor in the outputs: no exception is expected.
        z = Layer_None2.apply(input1)

        class Layer_one1(PyLayer):
            @staticmethod
            def forward(ctx, *args):
                return 1

            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        # forward() returns only an int, so apply() is expected to raise.
        with self.assertRaises(ValueError):
            z = Layer_one1.apply(input1)

        class Layer_one2(PyLayer):
            @staticmethod
            def forward(ctx, *args):
                return [1, 2, args[0]]

            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        # Ints next to a Tensor in the outputs: no exception is expected.
        z = Layer_one2.apply(input1)

        class Layer_no_fw(PyLayer):
            @staticmethod
            def backward(ctx, *args):
                return args

        input1 = paddle.randn([2, 3]).astype("float64")
        # No forward() is defined on this layer.
        with self.assertRaises(NotImplementedError):
            z = Layer_no_fw.apply(input1)

    def test_pylayer_nograd(self):
        """Check that no gradient appears when the input needs none.

        ``stop_gradient`` is never set to ``False`` on ``input1``; after
        ``backward()`` the output's ``grad`` must still be ``None``.
        """

        class tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1, func1, func2=paddle.square, xx=None):
                ctx.func = func2
                y1 = func1(x1)
                return y1

            @staticmethod
            def backward(ctx, x1, y1, dy1):
                re1 = dy1 * (1 - ctx.func(y1))
                return re1

        input1 = paddle.randn([2, 3]).astype("float64")
        z = tanh.apply(input1, paddle.tanh, paddle.square)
        z.mean().backward()
        self.assertIsNone(z.grad)

    def test_pylayer_Exception_bk(self):
        """Exercise invalid ``backward`` implementations.

        Cases covered, in order: ``backward`` returns ``None``; returns
        ``None`` in the first of two slots; returns an ``int``; returns a
        tuple of ints; is not defined at all (``OSError`` expected); and
        returns two gradients for a single forward input.
        """

        class Layer_bk_none1(PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x * 2

            @staticmethod
            def backward(ctx, dy1):
                return None

        input2 = paddle.randn([2, 3]).astype("float64")
        input2.stop_gradient = False
        z = Layer_bk_none1.apply(input2)

        # The only input requires a gradient but backward() gives None.
        with self.assertRaises(ValueError):
            z.sum().backward()

        class Layer_bk_none2(PyLayer):
            @staticmethod
            def forward(ctx, x1, x2):
                return x1 + x2

            @staticmethod
            def backward(ctx, dy1):
                return None, dy1

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False
        z = Layer_bk_none2.apply(input1, input1)

        # The first slot is None although that input requires a gradient.
        with self.assertRaises(ValueError):
            z.mean().backward()

        class Layer_bk_one1(PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x + x

            @staticmethod
            def backward(ctx, dy):
                return 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False
        z = Layer_bk_one1.apply(input1)

        # backward() returns an int instead of a Tensor.
        with self.assertRaises(ValueError):
            z.mean().backward()

        class Layer_bk_one2(PyLayer):
            @staticmethod
            def forward(ctx, x1, x2):
                return x1 * 2, x2 * 5

            @staticmethod
            def backward(ctx, *args):
                return 1, 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False

        y = Layer_bk_one2.apply(input1, input1)
        z = y[0] + y[1]
        # backward() returns a tuple of ints instead of Tensors.
        with self.assertRaises(ValueError):
            z.mean().backward()

        class Layer_no_bk(PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x * 2, x * 5

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False
        z = Layer_no_bk.apply(input1)

        # No backward() is defined on this layer.
        with self.assertRaises(OSError):
            z = z[0] + z[1]
            z.mean().backward()

        class Layer_bk_match(PyLayer):
            @staticmethod
            def forward(ctx, x):
                return x * 2, x * 5

            @staticmethod
            def backward(ctx, dy1, dy2):
                return dy2 * 2, dy1 * 2

        input1 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = False
        z = Layer_bk_match.apply(input1)
        # forward() has one tensor input but backward() returns two values.
        with self.assertRaises(ValueError):
            z = z[0] + z[1]
            z.mean().backward()
    def test_pylayer_bk_return_none(self):
        """Expect ValueError when backward returns ints for mixed inputs.

        In both cases ``input1`` has ``stop_gradient = True`` and
        ``input2`` has ``stop_gradient = False``; ``backward`` returns
        plain ints and the backward pass must raise ``ValueError``.
        """

        class Layer_bk_none1(PyLayer):
            @staticmethod
            def forward(ctx, x1, x2):
                return x1 + x2

            @staticmethod
            def backward(ctx, dy):
                return 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = True
        input2.stop_gradient = False
        z = Layer_bk_none1.apply(input1, input2)

        # A single int is returned for two forward inputs.
        with self.assertRaises(ValueError):
            z.mean().backward()

        class Layer_bk_none2(PyLayer):
            @staticmethod
            def forward(ctx, x1, x2):
                return x1 * 2, x2 * 5

            @staticmethod
            def backward(ctx, *args):
                return 1, 1

        input1 = paddle.randn([2, 3]).astype("float64")
        input2 = paddle.randn([2, 3]).astype("float64")
        input1.stop_gradient = True
        input2.stop_gradient = False
        z = Layer_bk_none2.apply(input1, input2)
        z = z[0] + z[1]
        # One int per input is returned, but none of them is a Tensor.
        with self.assertRaises(ValueError):
            z.mean().backward()
    def test_pylayer_inplace(self):
        """Check gradients when forward returns its input unchanged.

        The PyLayer is applied inside a ``paddle.nn.Layer`` after a
        ``relu``; the original leaf tensor must have a gradient after
        ``backward()``.
        """

        class cus_tanh(PyLayer):
            @staticmethod
            def forward(ctx, x):
                # The very same tensor object is handed back as the output.
                return x

            @staticmethod
            def backward(ctx, dy):
                return dy

        class Layer(paddle.nn.Layer):
            def __init__(self):
                super(Layer, self).__init__()

            def forward(self, data):
                data = paddle.nn.functional.relu(data)
                # NOTE(review): this result is overwritten on the next
                # line; presumably kept so ``data`` feeds a second op --
                # confirm before removing.
                z = paddle.tanh(data)
                z = cus_tanh.apply(data)
                return z.mean()

        for i in range(2):
            data = paddle.ones([2, 3], dtype="float64") / (i + 1)
            data.stop_gradient = False
            layer = Layer()
            z = layer(data)
            z.backward()
            self.assertIsNotNone(data.grad)

    def test_backward_in_backward(self):
        """Build a PyLayer whose backward runs a nested backward pass.

        ``forward`` stores a detached copy of its input on the context;
        ``backward`` re-enables gradients, runs ``tanh`` on that copy and
        calls ``backward()`` on the result.
        """

        class cus_tanh(PyLayer):
            @staticmethod
            def forward(ctx, x):
                temp = x.detach()
                ctx.inputs = temp
                return x.mean()

            @staticmethod
            def backward(ctx, dy):
                with paddle.set_grad_enabled(True):
                    temp = ctx.inputs
                    temp.stop_gradient = False
                    z = paddle.tanh(temp)
                    z.backward()
                    # ``self`` is the enclosing TestCase (closure).
                    self.assertIsNotNone(temp.grad)
                    return paddle.to_tensor(temp.grad)

        for i in range(2):
            data = paddle.ones([2, 3], dtype="float32") / (i + 1)
            data.stop_gradient = False
            data = paddle.nn.functional.relu(data)
            # This result is overwritten on the next line.
            z = paddle.tanh(data)
            # NOTE(review): backward() is never called on ``z`` here, so
            # cus_tanh.backward does not appear to be exercised by this
            # test -- confirm whether that is intentional.
            z = cus_tanh.apply(data)

    def test_return_to_tensor(self):
        """Run forward/backward when forward returns non-tensor extras.

        ``forward`` returns a tensor, an int, ``None``, a string and a
        second tensor; the test unpacks all five and back-propagates
        through the first one.  No value is asserted.
        """

        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1):
                y1 = paddle.tanh(x1)
                ctx.save_for_backward(y1)
                tensor_1 = paddle.to_tensor([1, 2], dtype='float32')
                return y1, 5, None, "helloworld", tensor_1

            @staticmethod
            def backward(ctx, dy1, dy2):
                (saved_y, ) = ctx.saved_tensor()
                # Evaluated but not returned; dy1 is passed through as-is.
                unused_grad = dy1 * (1 - paddle.square(saved_y))
                return dy1

        source = paddle.randn([2, 3]).astype("float32")
        source_copy = source.detach().clone()
        source.stop_gradient = False
        source_copy.stop_gradient = False

        outputs = Tanh.apply(x1=source)
        out, int_item, none_item, str_item, extra_tensor = outputs
        out.mean().backward()


class TestPyLayerReturnType(unittest.TestCase):
    """Cases where a ``FakeTensor`` reaches a PyLayer; all expect ValueError."""

    def test_forward_args_fake_tensor(self):
        """A FakeTensor passed positionally to ``apply`` raises ValueError."""

        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1):
                y1 = FakeTensor()
                return y1, x1

            @staticmethod
            def backward(ctx, dy1, dy2):
                return dy1

        fake_input = FakeTensor()

        with self.assertRaises(ValueError):
            out_a, out_b = Tanh.apply(fake_input)

    def test_forward_kwargs_fake_tensor(self):
        """A FakeTensor passed by keyword to ``apply`` raises ValueError."""

        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1):
                return x1

            @staticmethod
            def backward(ctx, dy1, dy2):
                return dy1

        fake_input = FakeTensor()

        with self.assertRaises(ValueError):
            out = Tanh.apply(x1=fake_input)

    def test_forward_return_fake_tensor(self):
        """``forward`` returning one FakeTensor raises ValueError."""

        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1):
                return FakeTensor()

            @staticmethod
            def backward(ctx, dy1, dy2):
                return dy1

        real_input = paddle.randn([3, 2])

        with self.assertRaises(ValueError):
            out = Tanh.apply(x1=real_input)

    def test_forward_return_fake_tensor_tuple(self):
        """``forward`` returning a tuple of FakeTensors raises ValueError."""

        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1):
                return FakeTensor(), FakeTensor()

            @staticmethod
            def backward(ctx, dy1, dy2):
                return dy1

        real_input = paddle.randn([3, 2])

        with self.assertRaises(ValueError):
            out = Tanh.apply(x1=real_input)

    def test_backward_return_fake_tensor_tuple(self):
        """``backward`` returning ``(FakeTensor, int)`` raises ValueError."""

        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1, x2):
                return x1 + 1, x1 + 2

            @staticmethod
            def backward(ctx, dy1, dy2):
                return FakeTensor(), 2

        real_input = paddle.randn([3, 2])
        real_input.stop_gradient = False
        first_out, _ = Tanh.apply(real_input, 1 + real_input)

        with self.assertRaises(ValueError):
            first_out.mean().backward()

    def test_backward_return_fake_tensor(self):
        """``backward`` returning a single FakeTensor raises ValueError."""

        class Tanh(PyLayer):
            @staticmethod
            def forward(ctx, x1):
                return x1 + 1, x1 + 2

            @staticmethod
            def backward(ctx, dy1, dy2):
                return FakeTensor()

        real_input = paddle.randn([3, 2])
        real_input.stop_gradient = False
        first_out, _ = Tanh.apply(real_input)

        with self.assertRaises(ValueError):
            first_out.mean().backward()


# Run every test case in this module when executed as a script.
if __name__ == "__main__":
    unittest.main()