#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np

import paddle
import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode
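# This suite checks Paddle's dygraph inplace machinery:
#  - `Tensor.inplace_version` bookkeeping for setitem, paddle.assign and the
#    trailing-underscore inplace APIs,
#  - the version-mismatch RuntimeError raised by backward() when a tensor needed
#    for gradient computation was modified in place,
#  - gradient equivalence between inplace and non-inplace APIs.
# Each func_test_* helper runs twice: once under _test_eager_guard() and once in
# the default dygraph mode.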


class TestInplace(unittest.TestCase):

    def func_test_forward_version(self):
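        # `inplace_version` starts at 0 and is expected to increase by 1 after every
        # in-place modification of the tensor (setitem, paddle.assign, `*_` APIs).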
        with paddle.fluid.dygraph.guard():
            var = paddle.to_tensor(np.ones((4, 2, 3)).astype(np.float32))
            self.assertEqual(var.inplace_version, 0)

            var[0] = 1.1
            self.assertEqual(var.inplace_version, 1)

            paddle.assign(paddle.ones(shape=[3]), var)

            # NOTE(liym27): assign(input, output) is an inplace operation for output.
            # Because the assign API has inplace-related processing, var.inplace_version
            # should be 2, not 1.
            self.assertEqual(var.inplace_version, 2)

            var[2] = 3
            self.assertEqual(var.inplace_version, 3)

    def test_forward_version(self):
        with _test_eager_guard():
            self.func_test_forward_version()
        self.func_test_forward_version()

    def func_test_backward_error(self):
        # It raises an error because the inplace operator will result
        # in incorrect gradient computation.
        with paddle.fluid.dygraph.guard():
            var_a = paddle.ones(shape=[4, 2, 3], dtype="float32")
            var_a.stop_gradient = False

            var_b = var_a**2

            # Here, the gradient computation will use the value of var_b
            var_c = var_b**2
            var_b[1:2] = 3.3  # var_b is modified inplace after using it

            var_d = var_b**2

            loss = paddle.nn.functional.relu(var_c + var_d)
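            # The grad of `var_c = var_b**2` needs var_b's forward value. var_b was
            # modified in place after var_c recorded its version snapshot (0), so its
            # tensor_version is now 1 and backward() should report the mismatch below.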
            with self.assertRaisesRegex(
                    RuntimeError,
                    "received tensor_version:{} != wrapper_version_snapshot:{}".
                    format(1, 0)):
                loss.backward()

    def test_backward_error(self):
        with _test_eager_guard():
            self.func_test_backward_error()
        self.func_test_backward_error()

    def func_test_backward_success_1(self):
        # var_b is modified inplace before it is used, so the inplace operation
        # doesn't result in incorrect gradient computation.
        with paddle.fluid.dygraph.guard():
            var_a = paddle.ones(shape=[4, 2, 3], dtype="float32")
            var_a.stop_gradient = False

            var_b = var_a**2
            var_b[1:2] = 3  # var_b is modified inplace before using it

            # Here, the gradient computation will use the value of var_b
            var_c = var_b**2
            loss = var_c.sum()
            loss.backward()

    def test_backward_success_1(self):
        with _test_eager_guard():
            self.func_test_backward_success_1()
        self.func_test_backward_success_1()

    def func_test_backward_success_2(self):
        # Although var_b is modified inplace after using it, its value is not used
        # in gradient computation, so the inplace operation doesn't result in an
        # incorrect gradient.
        with paddle.fluid.dygraph.guard():
            var_a = paddle.ones(shape=[4, 2, 3], dtype="float32")
            var_a.stop_gradient = False

            var_b = var_a**2

            var_b[1:2] = 3  # var_b is modified inplace before using it

            var_c = var_b + var_b  # Here, the grad op of add doesn't use the value of var_b
            loss = var_c.sum()

            var_b[1:2] = 3  # var_b is modified inplace after using it

            loss.backward()

    def test_backward_success_2(self):
        with _test_eager_guard():
            self.func_test_backward_success_2()
        self.func_test_backward_success_2()


class TestDygraphInplace(unittest.TestCase):

    def setUp(self):
        self.init_data()
        self.set_np_compare_func()

    def init_data(self):
        self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1])
        self.dtype = "float32"

    def set_np_compare_func(self):
        self.np_compare = np.array_equal

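    # Subclasses override the two hooks below with a non-inplace / inplace API pair;
    # every test method in this base class is then reused for that API.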
    def non_inplace_api_processing(self, var):
        return paddle.squeeze(var)

    def inplace_api_processing(self, var):
        return paddle.squeeze_(var)

    def func_test_inplace_api(self):
        var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
        inplace_var = self.inplace_api_processing(var)
        self.assertTrue(id(var) == id(inplace_var))

        inplace_var[0] = 2.
        np.testing.assert_array_equal(var.numpy(), inplace_var.numpy())

    def test_inplace_api(self):
        with _test_eager_guard():
            self.func_test_inplace_api()
        self.func_test_inplace_api()

    def func_test_forward_version(self):
        with paddle.fluid.dygraph.guard():
            var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
            self.assertEqual(var.inplace_version, 0)

            inplace_var = self.inplace_api_processing(var)
            self.assertEqual(var.inplace_version, 1)

            inplace_var[0] = 2.
            self.assertEqual(var.inplace_version, 2)

            inplace_var = self.inplace_api_processing(inplace_var)
            self.assertEqual(var.inplace_version, 3)

    def test_forward_version(self):
        with _test_eager_guard():
            self.func_test_forward_version()
        self.func_test_forward_version()

    def func_test_leaf_inplace_var_error(self):
        with paddle.fluid.dygraph.guard():
            var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
            var.stop_gradient = False

            def leaf_inplace_error():
                self.inplace_api_processing(var)

            self.assertRaises(ValueError, leaf_inplace_error)

    def test_leaf_inplace_var_error(self):
        with _test_eager_guard():
            self.func_test_leaf_inplace_var_error()
        self.func_test_leaf_inplace_var_error()

    def func_test_backward_error(self):
        # It raises an error because the inplace operator will result
        # in incorrect gradient computation.
        with paddle.fluid.dygraph.guard():
            var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
            var_a.stop_gradient = False

            var_b = var_a**2

            # Here, the gradient computation will use the value of var_b
            var_c = var_b**2
            self.inplace_api_processing(var_b)

            loss = paddle.nn.functional.relu(var_c)
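            # var_b was modified in place by the inplace API after var_c recorded its
            # version snapshot (0), so backward() should detect tensor_version 1 != 0.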
            with self.assertRaisesRegex(
                    RuntimeError,
                    "received tensor_version:{} != wrapper_version_snapshot:{}".
                    format(1, 0)):
                loss.backward()

    def test_backward_error(self):
        with _test_eager_guard():
            self.func_test_backward_error()
        self.func_test_backward_error()

    def func_test_backward_success_1(self):
        # var_b is modified inplace before it is used, so the inplace operation
        # doesn't result in incorrect gradient computation.
        grad_var_a, grad_var_a_inplace = 0, 1
        with paddle.fluid.dygraph.guard():
            var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
            var_a.stop_gradient = False

            var_b = var_a**2
            var_c = self.inplace_api_processing(
                var_b)  # var_b is modified inplace before using it

            # Here, the gradient computation will use the value of var_b
            var_d = var_c**2
            loss = var_d.sum()
            loss.backward()
            grad_var_a_inplace = var_a.grad.numpy()

        with paddle.fluid.dygraph.guard():
            var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
            var_a.stop_gradient = False

            var_b = var_a**2
            var_c = self.non_inplace_api_processing(var_b)
            var_d = var_c**2
            loss = var_d.sum()
            loss.backward()
            grad_var_a = var_a.grad.numpy()

        self.assertTrue(self.np_compare(grad_var_a_inplace, grad_var_a))

    def test_backward_success_1(self):
        with _test_eager_guard():
            self.func_test_backward_success_1()
        self.func_test_backward_success_1()

    def func_test_backward_success_2(self):
        # Although var_b is modified inplace after using it, its value is not used
        # in gradient computation, so the inplace operation doesn't result in an
        # incorrect gradient.
        grad_var_a, grad_var_a_inplace = 0, 1
        with paddle.fluid.dygraph.guard():
            var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
            var_a.stop_gradient = False

            var_b = var_a**2

            var_c = self.inplace_api_processing(
                var_b)  # var_b is modified inplace before using it

            var_d = var_c + var_c  # Here, the grad op of add doesn't use the value of var_b
            loss = var_d.sum()

            loss.backward()
            grad_var_a_inplace = var_a.grad.numpy()

        with paddle.fluid.dygraph.guard():
            var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
            var_a.stop_gradient = False

            var_b = var_a**2

            var_c = self.non_inplace_api_processing(var_b)

            var_d = var_c + var_c  # Here, the grad op of add doesn't use the value of var_b
            loss = var_d.sum()

            loss.backward()
            grad_var_a = var_a.grad.numpy()
        np.testing.assert_array_equal(grad_var_a_inplace, grad_var_a)

    def test_backward_success_2(self):
        with _test_eager_guard():
            self.func_test_backward_success_2()
        self.func_test_backward_success_2()


class TestDygraphInplaceUnsqueeze(TestDygraphInplace):

    def non_inplace_api_processing(self, var):
        return paddle.unsqueeze(var, -1)

    def inplace_api_processing(self, var):
        return paddle.unsqueeze_(var, -1)


class TestDygraphInplaceReshape(TestDygraphInplace):

    def non_inplace_api_processing(self, var):
        return paddle.reshape(var, [-1])

    def inplace_api_processing(self, var):
        return paddle.reshape_(var, [-1])


class TestDygraphInplaceReshapeTensor(TestDygraphInplace):

    def non_inplace_api_processing(self, var):
        shape = paddle.to_tensor(-1)
        return paddle.reshape(var, shape)

    def inplace_api_processing(self, var):
        shape = paddle.to_tensor(-1)
        return paddle.reshape_(var, shape)


class TestDygraphInplaceFlatten(TestDygraphInplace):

    def non_inplace_api_processing(self, var):
        return var.flatten()

    def inplace_api_processing(self, var):
        return var.flatten_()


class TestDygraphInplaceScatter(TestDygraphInplace):

    def init_data(self):
        self.input_var_numpy = np.array([[1, 1], [2, 2], [3, 3]])
        self.dtype = "float32"

    def non_inplace_api_processing(self, var):
        index = paddle.to_tensor([2, 1, 0, 1], dtype='int64')
        updates = paddle.to_tensor([[1, 1], [2, 2], [3, 3], [4, 4]],
                                   dtype='float32')
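        # With overwrite=False, paddle.scatter accumulates updates at duplicated
        # indices (index 1 appears twice here) rather than overwriting them.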

        return paddle.scatter(var, index, updates, overwrite=False)

    def inplace_api_processing(self, var):
        index = paddle.to_tensor([2, 1, 0, 1], dtype='int64')
        updates = paddle.to_tensor([[1, 1], [2, 2], [3, 3], [4, 4]],
                                   dtype='float32')

        return paddle.scatter_(var, index, updates, overwrite=False)


class TestDygraphInplaceElu(TestDygraphInplace):

    def non_inplace_api_processing(self, var):
        return paddle.nn.functional.elu(var)

    def inplace_api_processing(self, var):
        return paddle.nn.functional.elu_(var)


class TestDygraphInplaceRelu(TestDygraphInplace):

    def non_inplace_api_processing(self, var):
        return paddle.nn.functional.relu(var)

    def inplace_api_processing(self, var):
        return paddle.nn.functional.relu_(var)


class TestDygraphInplaceSoftmax(TestDygraphInplace):

    def non_inplace_api_processing(self, var):
        return paddle.nn.functional.softmax(var)

    def inplace_api_processing(self, var):
        return paddle.nn.functional.softmax_(var)


class TestDygraphInplaceTanh(TestDygraphInplace):

    def non_inplace_api_processing(self, var):
        return paddle.tanh(var)

    def inplace_api_processing(self, var):
        return paddle.tanh_(var)


class TestDygraphInplaceCeil(TestDygraphInplace):

    def non_inplace_api_processing(self, var):
        return var.ceil()

    def inplace_api_processing(self, var):
        return var.ceil_()


class TestDygraphInplaceFloor(TestDygraphInplace):

    def non_inplace_api_processing(self, var):
        return var.floor()

    def inplace_api_processing(self, var):
        return var.floor_()


class TestDygraphInplaceExp(TestDygraphInplace):

    def set_np_compare_func(self):
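        # The inplace and non-inplace exp results may differ by floating-point
        # rounding, so compare approximately instead of exactly.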
        self.np_compare = np.allclose

    def non_inplace_api_processing(self, var):
        return var.exp()

    def inplace_api_processing(self, var):
        return var.exp_()


class TestDygraphInplaceReciprocal(TestDygraphInplace):

    def non_inplace_api_processing(self, var):
        return var.reciprocal()

    def inplace_api_processing(self, var):
        return var.reciprocal_()


class TestDygraphInplaceRound(TestDygraphInplace):

    def non_inplace_api_processing(self, var):
        return var.round()

    def inplace_api_processing(self, var):
        return var.round_()


class TestDygraphInplaceSqrt(TestDygraphInplace):

    def init_data(self):
        self.input_var_numpy = np.random.uniform(0, 5, [10, 20, 1])
        self.dtype = "float32"

    def non_inplace_api_processing(self, var):
        return var.sqrt()

    def inplace_api_processing(self, var):
        return var.sqrt_()


class TestDygraphInplaceRsqrt(TestDygraphInplaceSqrt):

    def non_inplace_api_processing(self, var):
        return var.rsqrt()

    def inplace_api_processing(self, var):
        return var.rsqrt_()


class TestDygraphInplaceClip(TestDygraphInplace):

    def non_inplace_api_processing(self, var):
        return var.clip(0.6, 1.5)

    def inplace_api_processing(self, var):
        return var.clip_(0.6, 1.5)


class TestDygraphInplaceScale(TestDygraphInplace):

    def non_inplace_api_processing(self, var):
        return var.scale(scale=2.0, bias=3.0)

    def inplace_api_processing(self, var):
        return var.scale_(scale=2.0, bias=3.0)


class TestDygraphInplaceAdd(TestDygraphInplace):

    def init_data(self):
        self.input_var_numpy = np.random.rand(2, 3, 4)
        self.dtype = "float32"
        self.input_var_numpy_2 = np.random.rand(2, 3, 4).astype(self.dtype)

    def non_inplace_api_processing(self, var):
        input_var_2 = paddle.to_tensor(self.input_var_numpy_2)
        return var.add(input_var_2)

    def inplace_api_processing(self, var):
        input_var_2 = paddle.to_tensor(self.input_var_numpy_2)
        return var.add_(input_var_2)


class TestDygraphInplaceSubtract(TestDygraphInplaceAdd):

    def non_inplace_api_processing(self, var):
        input_var_2 = paddle.to_tensor(self.input_var_numpy_2)
        return var.subtract(input_var_2)

    def inplace_api_processing(self, var):
        input_var_2 = paddle.to_tensor(self.input_var_numpy_2)
        return var.subtract_(input_var_2)


class TestDygraphInplaceRemainder(TestDygraphInplaceAdd):

    def non_inplace_api_processing(self, var):
        input_var_2 = paddle.to_tensor(self.input_var_numpy_2)
        return var.remainder(input_var_2)

    def inplace_api_processing(self, var):
        input_var_2 = paddle.to_tensor(self.input_var_numpy_2)
        return var.remainder_(input_var_2)

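    # The gradient-related checks from the base class are skipped below,
    # presumably because remainder/remainder_ does not support backward.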
    def test_leaf_inplace_var_error(self):
        pass

    def test_backward_error(self):
        pass

    def test_backward_success_1(self):
        pass

    def test_backward_success_2(self):
        pass


class TestLossIsInplaceVar(unittest.TestCase):

    def func_test_loss_is_inplace_var(self):
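        # The loss itself is the output of an inplace API (tanh_); its gradient
        # w.r.t. var_a should match the non-inplace tanh version below.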
        with paddle.fluid.dygraph.guard():
            var_a = paddle.ones((2, 2))
            var_a.stop_gradient = False

            var_b = var_a * 2
            loss = var_b.tanh_()

            loss.backward()
            inplace_grad_var_a = var_a.grad.numpy()

        with paddle.fluid.dygraph.guard():
            var_a = paddle.ones((2, 2))
            var_a.stop_gradient = False

            var_b = var_a * 2
            loss = var_b.tanh()

            loss.backward()
            grad_var_a = var_a.grad.numpy()

        np.testing.assert_array_equal(inplace_grad_var_a, grad_var_a)

    def test_loss_is_inplace_var(self):
        with _test_eager_guard():
            self.func_test_loss_is_inplace_var()
        self.func_test_loss_is_inplace_var()


class TestContinuouslyInplace(unittest.TestCase):

    def func_test_continuously_inplace(self):
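        # Several chained inplace reshape_ calls on a non-leaf tensor should still
        # allow backward() to run without a version-mismatch error.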
        a = paddle.rand([2, 3])
        a.stop_gradient = False
        b = a * 2

        b.reshape_([-1])
        b.reshape_([2, 3])
        b.reshape_([-1])

        b.backward()

    def test_continuously_inplace(self):
        with _test_eager_guard():
            self.func_test_continuously_inplace()
        self.func_test_continuously_inplace()


class TestGetitemBeforeInplace(unittest.TestCase):

    def test_getitem_before_inplace(self):
        with _test_eager_guard():
            a = paddle.ones(shape=[4, 2, 3], dtype="float32")
            a.stop_gradient = False
            b = a**2
            b[0] = 3
            # getitem has no_need_buffer input
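            # so its backward pass does not need b's forward buffer, and the later
            # setitem on b does not trigger the version-mismatch check.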
            c = b[0:2]
            loss = c.sum()
            b[1] = 2
            loss.backward()


if __name__ == '__main__':
    unittest.main()