test_prune.py 38.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
import contextlib
import os
17 18
import unittest

19 20
import numpy as np

21
import paddle
22 23 24 25 26 27
import paddle.fluid as fluid
import paddle.fluid.framework as framework


class TestPrune(unittest.TestCase):
    """Tests for ``Program._prune`` and ``Program._prune_with_input``."""

    # Op types, in order, that the program built by ``net()`` is expected to
    # contain before any pruning.
    _NET_OP_TYPES = [
        "mul",
        "elementwise_add",
        "softmax",
        "softmax_with_cross_entropy",
        "reduce_mean",
    ]

    def net(self):
        """Build a tiny classifier and return ``(x, y, label, loss)``.

        ``y`` is the softmax output of a size-2 fc layer over ``x`` and
        ``loss`` is the mean cross-entropy between ``y`` and ``label``.
        """
        x = paddle.static.data(name='x', shape=[-1, 2], dtype='float32')
        x.desc.set_need_check_feed(False)
        label = paddle.static.data(name="label", shape=[-1, 1], dtype="int64")
        label.desc.set_need_check_feed(False)
        y = paddle.static.nn.fc(x=[x], size=2, activation="softmax")
        loss = paddle.nn.functional.cross_entropy(
            input=y, label=label, reduction='none', use_softmax=False
        )
        loss = paddle.mean(x=loss)
        return x, y, label, loss

    def _build_program(self):
        """Build ``net()`` into a fresh Program and verify its op list.

        Returns:
            tuple: ``(program, (x, y, label, loss))``.
        """
        program = framework.Program()
        startup_program = framework.Program()
        block = program.global_block()
        with fluid.program_guard(program, startup_program):
            net_vars = self.net()
        self._assert_op_types(block, self._NET_OP_TYPES)
        return program, net_vars

    def _assert_op_types(self, block, expected):
        """Assert that ``block`` holds exactly the op types in ``expected``."""
        self.assertEqual(len(block.ops), len(expected))
        self.assertEqual([op.type for op in block.ops], expected)

    def test_prune_with_input(self):
        program, (_, y, label, loss) = self._build_program()
        pruned_program = program._prune_with_input(
            feeded_var_names=[y.name, label.name], targets=[loss]
        )
        # ``y`` is fed directly, so only the ops after it remain.
        self._assert_op_types(
            pruned_program.global_block(),
            ["softmax_with_cross_entropy", "reduce_mean"],
        )

    def test_prune(self):
        program, (_, _, _, loss) = self._build_program()
        pruned_program = program._prune(targets=[loss])
        # Every op contributes to ``loss``, so nothing is pruned.
        self._assert_op_types(
            pruned_program.global_block(), self._NET_OP_TYPES
        )

    def test_prune_target_not_list(self):
        # A single target that is not wrapped in a list is accepted too.
        program, (_, _, _, loss) = self._build_program()
        pruned_program = program._prune(targets=loss)
        self._assert_op_types(
            pruned_program.global_block(), self._NET_OP_TYPES
        )

    def test_prune_target_none(self):
        program, _ = self._build_program()
        # assertRaises fails the test when no ValueError is raised; a bare
        # try/except would silently pass in that case.
        with self.assertRaises(ValueError) as ctx:
            program._prune(targets=None)
        self.assertIn(
            "All targets of Program._prune_with_input() can only be Variable or Operator",
            str(ctx.exception),
        )


def mock(self, program, feed, fetch, optimize_ops):
    """Stand-in for ``Executor._prune_program`` that only counts its calls.

    Bumps ``self.prune_called_times`` by one and hands ``program`` back
    unchanged; ``feed``, ``fetch`` and ``optimize_ops`` are ignored.
    """
    self.prune_called_times = self.prune_called_times + 1
    return program


@contextlib.contextmanager
def _mock_guard(mock):
    """Temporarily replace ``fluid.Executor._prune_program`` with ``mock``.

    The original method is restored on exit even when the guarded block
    raises (for example, a failed assertion). This keeps one failing test
    from leaving the patch installed for the tests that run after it.
    """
    original = fluid.Executor._prune_program
    fluid.Executor._prune_program = mock
    try:
        yield
    finally:
        fluid.Executor._prune_program = original


class TestExecutorRunAutoPrune(unittest.TestCase):
    def net1(self):
        """Build one fc layer over ``x`` that feeds two identical losses.

        Both losses are marked persistable. The tests then check whether
        each loss exists in the scope after ``Executor.run``.

        Returns:
            tuple: ``(x, y, label, loss1, loss2, w_param_attrs)``, where
            ``w_param_attrs`` is the ParamAttr of the fc weight
            (``"fc_weight"``).
        """
        x = paddle.static.data(name='x', shape=[-1, 2], dtype='float32')
        x.desc.set_need_check_feed(False)
        label = paddle.static.data(name="label", shape=[-1, 1], dtype="int64")
        label.desc.set_need_check_feed(False)
        w_param_attrs = fluid.ParamAttr(
            name="fc_weight",
            learning_rate=0.5,
            initializer=paddle.nn.initializer.Constant(1.0),
            trainable=True,
        )
        y = paddle.static.nn.fc(
            x=[x], size=2, activation="softmax", weight_attr=w_param_attrs
        )
        loss1 = paddle.nn.functional.cross_entropy(
            input=y, label=label, reduction='none', use_softmax=False
        )
        loss1 = paddle.mean(x=loss1)
        loss2 = paddle.nn.functional.cross_entropy(
            input=y, label=label, reduction='none', use_softmax=False
        )
        loss2 = paddle.mean(x=loss2)
        loss1.persistable = True
        loss2.persistable = True
        return x, y, label, loss1, loss2, w_param_attrs

    def net2(self):
        """Build two independent fc branches, each with its own loss.

        ``x1`` feeds an fc layer with weight ``"fc_weight1"`` into ``loss1``.
        ``x2`` feeds an fc layer with weight ``"fc_weight2"`` into ``loss2``.
        Both losses share ``label``.

        Returns:
            tuple: ``(x1, x2, y1, y2, label, loss1, loss2, w1_param_attrs,
            w2_param_attrs)``.
        """
        x1 = paddle.static.data(name='x1', shape=[-1, 2], dtype='float32')
        x1.desc.set_need_check_feed(False)
        x2 = paddle.static.data(name='x2', shape=[-1, 2], dtype='float32')
        x2.desc.set_need_check_feed(False)
        label = paddle.static.data(name="label", shape=[-1, 1], dtype="int64")
        label.desc.set_need_check_feed(False)
        w1_param_attrs = fluid.ParamAttr(
            name="fc_weight1",
            learning_rate=0.5,
            initializer=paddle.nn.initializer.Constant(1.0),
            trainable=True,
        )
        w2_param_attrs = fluid.ParamAttr(
            name="fc_weight2",
            learning_rate=0.5,
            initializer=paddle.nn.initializer.Constant(1.0),
            trainable=True,
        )
        y1 = paddle.static.nn.fc(
            x=[x1], size=2, activation="softmax", weight_attr=w1_param_attrs
        )
        y2 = paddle.static.nn.fc(
            x=[x2], size=2, activation="softmax", weight_attr=w2_param_attrs
        )
        loss1 = paddle.nn.functional.cross_entropy(
            input=y1, label=label, reduction='none', use_softmax=False
        )
        loss1 = paddle.mean(x=loss1)
        loss2 = paddle.nn.functional.cross_entropy(
            input=y2, label=label, reduction='none', use_softmax=False
        )
        loss2 = paddle.mean(x=loss2)
        return (
            x1,
            x2,
            y1,
            y2,
            label,
            loss1,
            loss2,
            w1_param_attrs,
            w2_param_attrs,
        )

    def test_not_prune(self):
        """
        If use_prune = False, the targets which are not fetched are still calculated.
        """
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                exe = fluid.Executor(fluid.CPUPlace())
                exe.run(startup_program)
                x_np = np.random.random(size=(10, 2)).astype('float32')
                # randint(1, ...) draws from [0, 1), i.e. all-zero labels.
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
                exe.run(
                    program,
                    feed={'x': x_np, 'label': label_np},
                    fetch_list=[loss1.name],
                    use_prune=False,
                )
                self.assertIsNotNone(scope.find_var(loss1.name))
                # loss2 was not fetched, yet it exists: nothing was pruned.
                self.assertIsNotNone(scope.find_var(loss2.name))

    def test_prune_fetches_without_optimizer(self):
        """
        Prune operators and variables which are not needed to generate 'fetches'.
        """
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                exe = fluid.Executor(fluid.CPUPlace())
                exe.run(startup_program)
                weight_init = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
                exe.run(
                    program,
                    feed={'x': x_np, 'label': label_np},
                    fetch_list=[loss1.name],
                    use_prune=True,
                )
                self.assertIsNotNone(scope.find_var(loss1.name))
                self.assertIsNone(scope.find_var(loss2.name))  # loss2 is pruned
                weight = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )
                # No optimizer in the program, so the weight must not change.
                np.testing.assert_array_equal(weight_init, weight)

    def test_prune_fetches_with_optimizer(self):
        """
        Prune operators and variables which are not needed to generate 'fetches'.
        In train mode, the backward and optimization operators of the fetched
        target should be kept.
        """
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.5)
                sgd_optimizer.minimize(loss1)
                exe = fluid.Executor(fluid.CPUPlace())
                exe.run(startup_program)
                weight_init = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
                exe.run(
                    program,
                    feed={'x': x_np, 'label': label_np},
                    fetch_list=[loss1.name],
                    use_prune=True,
                )
                self.assertIsNotNone(scope.find_var(loss1.name))
                self.assertIsNone(scope.find_var(loss2.name))  # loss2 is pruned
                weight = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )
                # The SGD step attached to loss1 survived pruning.
                self.assertFalse(np.array_equal(weight_init, weight))

    def test_prune_compiled_program(self):
        """
        Pruning also applies when running a CompiledProgram: loss2 is pruned
        while the optimizer attached to loss1 still updates the weight.
        """
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.5)
                sgd_optimizer.minimize(loss1)
                exe = fluid.Executor(fluid.CPUPlace())
                exe.run(startup_program)
                compiled_prog = fluid.CompiledProgram(
                    program
                ).with_data_parallel(
                    loss_name=loss1.name, places=fluid.CPUPlace()
                )
                weight_init = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
                exe.run(
                    compiled_prog,
                    feed={'x': x_np, 'label': label_np},
                    fetch_list=[loss1.name],
                    use_prune=True,
                )
                self.assertIsNotNone(scope.find_var(loss1.name))
                self.assertIsNone(scope.find_var(loss2.name))  # loss2 is pruned
                weight = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )
                self.assertFalse(
                    np.array_equal(weight_init, weight)
                )  # weight changed

    def test_prune_feed_without_optimizer(self):
        """
        Feeding the intermediate variable ``y`` directly prunes everything
        upstream of it; with no optimizer the weight stays unchanged.
        """
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                exe = fluid.Executor(fluid.CPUPlace())
                exe.run(startup_program)
                weight_init = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
                exe.run(
                    program,
                    feed={y.name: x_np, 'label': label_np},
                    fetch_list=[loss1.name],
                    use_prune=True,
                )
                self.assertIsNotNone(scope.find_var(loss1.name))
                self.assertIsNone(scope.find_var(loss2.name))
                weight = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )
                np.testing.assert_array_equal(
                    weight_init, weight
                )  # weight unchanged

    def test_prune_feed_with_optimizer(self):
        """
        Feeding the intermediate variable ``y`` while an optimizer is attached
        to loss1 is expected to make ``Executor.run`` raise.
        """
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.5)
                sgd_optimizer.minimize(loss1)
                exe = fluid.Executor(fluid.CPUPlace())
                exe.run(startup_program)
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
                with self.assertRaises(Exception):
                    exe.run(
                        program,
                        feed={y.name: x_np, 'label': label_np},
                        fetch_list=[loss1.name],
                        use_prune=True,
                    )
                self.assertIsNotNone(scope.find_var(loss1.name))
                self.assertIsNone(scope.find_var(loss2.name))

    def test_prune_with_cache_program(self):
        '''
        When use_prune=True, Executor should cache the pruned program.
        If in the next run the program, feed and fetch are not changed, Executor uses the
        cached pruned program and does not need to call _prune_program() again.
        In this test, we replace Executor._prune_program with a mock function which does nothing
        but increase Executor.prune_called_times, and we check that prune_called_times equals 1
        even though exe.run() is called 10 times with the same input arguments.
        '''
        with _mock_guard(mock):
            exe = fluid.Executor(fluid.CPUPlace())
            exe.prune_called_times = 0
            program = framework.Program()
            startup_program = framework.Program()
            scope = fluid.Scope()
            with fluid.scope_guard(scope):
                with fluid.program_guard(program, startup_program):
                    (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                    sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.5)
                    sgd_optimizer.minimize(loss1)
                    exe.run(startup_program)
                    x_np = np.random.random(size=(10, 2)).astype('float32')
                    label_np = np.random.randint(1, size=(10, 1)).astype(
                        'int64'
                    )
                    for _ in range(10):
                        exe.run(
                            program,
                            feed={'x': x_np, 'label': label_np},
                            fetch_list=[loss1.name],
                            use_prune=True,
                        )
                        # Only the first run prunes; later runs hit the cache.
                        self.assertEqual(exe.prune_called_times, 1)

    def test_prune_with_cache_program2(self):
        '''
        When use_prune=True, Executor should cache the pruned program.
        If the only difference in fetch_list across runs is the optimize_ops,
        the cache keys should differ and yield different pruned programs.
        '''
        with _mock_guard(mock):
            exe = fluid.Executor(fluid.CPUPlace())
            exe.prune_called_times = 0
            program = framework.Program()
            startup_program = framework.Program()
            scope = fluid.Scope()
            with fluid.scope_guard(scope):
                with fluid.program_guard(program, startup_program):
                    (
                        x1,
                        x2,
                        y1,
                        y2,
                        label,
                        loss1,
                        loss2,
                        w1_param_attrs,
                        w2_param_attrs,
                    ) = self.net2()
                    adam_optimizer1 = fluid.optimizer.AdamOptimizer(
                        learning_rate=0.5
                    )
                    train1 = adam_optimizer1.minimize(loss1)
                    adam_optimizer2 = fluid.optimizer.AdamOptimizer(
                        learning_rate=0.5
                    )
                    train2 = adam_optimizer2.minimize(loss2)
                    exe.run(startup_program)
                    x_np = np.random.random(size=(10, 2)).astype('float32')
                    label_np = np.random.randint(1, size=(10, 1)).astype(
                        'int64'
                    )

                    for i in range(10):
                        # Alternate the optimizer op: train2 on even runs,
                        # train1 on odd runs.
                        train_op = train1 if i % 2 else train2
                        exe.run(
                            program,
                            feed={
                                'x1': x_np,
                                'x2': x_np,
                                'label': label_np,
                            },
                            fetch_list=[loss1, loss2, train_op],
                            use_prune=True,
                        )
                        # Runs 0 and 1 each prune once (distinct cache keys);
                        # every later run reuses one of the two cached programs.
                        self.assertEqual(exe.prune_called_times, min(i + 1, 2))

    def test_prune_with_cache_compiled_program(self):
        '''
        When use_prune=True, Executor should cache the pruned program.
        If in the next run the program, feed and fetch are not changed, Executor uses the
        cached pruned program and does not need to call _prune_program() again.
        In this test, we replace Executor._prune_program with a mock function which does nothing
        but increase Executor.prune_called_times, and we check that prune_called_times equals 1
        even though exe.run() is called 10 times with the same input arguments.
        '''
        with _mock_guard(mock):
            exe = fluid.Executor(fluid.CPUPlace())
            exe.prune_called_times = 0
            program = framework.Program()
            startup_program = framework.Program()
            scope = fluid.Scope()
            with fluid.scope_guard(scope):
                with fluid.program_guard(program, startup_program):
                    (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                    sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.5)
                    sgd_optimizer.minimize(loss1)
                    exe.run(startup_program)
                    x_np = np.random.random(size=(10, 2)).astype('float32')
                    label_np = np.random.randint(1, size=(10, 1)).astype(
                        'int64'
                    )
                    compiled_prog = fluid.CompiledProgram(
                        program
                    ).with_data_parallel(
                        loss_name=loss1.name, places=fluid.CPUPlace()
                    )
                    for _ in range(10):
                        exe.run(
                            compiled_prog,
                            feed={'x': x_np, 'label': label_np},
                            fetch_list=[loss1.name],
                            use_prune=True,
                        )
                        # Only the first run prunes; later runs hit the cache.
                        self.assertEqual(exe.prune_called_times, 1)

    def test_prune_with_multi_optimizers(self):
        '''
        If there are multiple optimizers in the program, we can run a specific one by
        passing the return value of optimizer.minimize() to fetch_list.
        '''
        exe = fluid.Executor(fluid.CPUPlace())
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        # do not use_prune
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.5)
                train1, _ = sgd_optimizer.minimize(loss1)
                # Snapshot holding only loss1's optimizer: the reference run.
                cloned_program = program.clone()
                train2, _ = sgd_optimizer.minimize(loss2)
                exe.run(startup_program)
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
                exe.run(
                    program,
                    feed={'x': x_np, 'label': label_np},
                    fetch_list=[loss1.name],
                    use_prune=False,
                )
                weight_without_prune = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )

        scope = fluid.Scope()
        # use_prune
        with fluid.scope_guard(scope):
            exe.run(startup_program)
            exe.run(
                program,
                feed={'x': x_np, 'label': label_np},
                fetch_list=[loss1.name, train1],
                use_prune=True,
            )
            weight_with_prune = np.array(
                scope.find_var(w_param_attrs.name).get_tensor()
            )

        # expected
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            exe.run(startup_program)
            exe.run(
                cloned_program,
                feed={'x': x_np, 'label': label_np},
                fetch_list=[loss1.name],
                use_prune=False,
            )
            weight_expected = np.array(
                scope.find_var(w_param_attrs.name).get_tensor()
            )

        np.testing.assert_array_equal(weight_with_prune, weight_expected)
        # Without pruning both optimizers ran, so the result must differ.
        self.assertFalse(np.array_equal(weight_without_prune, weight_expected))

    def test_prune_with_multi_devices(self):
        '''
        When training a model on multiple devices, the pruned CompiledProgram should share
        the same local scopes. This tests the correctness of that behavior.
        '''
        exe = fluid.Executor(fluid.CPUPlace())
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        # Restore CPU_NUM when the test finishes (pass or fail) so that the
        # setting does not leak into other tests in the same process.
        prev_cpu_num = os.environ.get('CPU_NUM')

        def _restore_cpu_num():
            if prev_cpu_num is None:
                os.environ.pop('CPU_NUM', None)
            else:
                os.environ['CPU_NUM'] = prev_cpu_num

        self.addCleanup(_restore_cpu_num)
        os.environ['CPU_NUM'] = str(2)
        # do not use_prune
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (
                    x1,
                    x2,
                    y1,
                    y2,
                    label,
                    loss1,
                    loss2,
                    w1_param_attrs,
                    w2_param_attrs,
                ) = self.net2()
                adam_optimizer1 = fluid.optimizer.AdamOptimizer(
                    learning_rate=0.5
                )
                train1 = adam_optimizer1.minimize(loss1)
                cloned_program = program.clone()
                adam_optimizer2 = fluid.optimizer.AdamOptimizer(
                    learning_rate=0.5
                )
                train2 = adam_optimizer2.minimize(loss2)
                exe.run(startup_program)
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
                compiled_prog1 = fluid.CompiledProgram(
                    program
                ).with_data_parallel(
                    loss_name=loss1.name, places=[fluid.CPUPlace()] * 2
                )
                compiled_prog2 = fluid.CompiledProgram(
                    program
                ).with_data_parallel(
                    loss_name=loss2.name, places=[fluid.CPUPlace()] * 2
                )
                for i in range(10):
                    if i % 2 == 1:
                        # One feed dict per device: split the batch in half.
                        exe.run(
                            compiled_prog1,
                            feed=[
                                {'x1': x_np[0:5, :], 'label': label_np[0:5, :]},
                                {'x1': x_np[5:, :], 'label': label_np[5:, :]},
                            ],
                            fetch_list=[loss1.name, train1],
                            use_prune=True,
                        )
                    else:
                        exe.run(
                            compiled_prog2,
                            feed={'x2': x_np, 'label': label_np},
                            fetch_list=[loss2.name, train2],
                            use_prune=True,
                        )
                weight1 = np.array(
                    scope.find_var(w1_param_attrs.name).get_tensor()
                )
        # expected
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            exe.run(startup_program)
            for i in range(10):
                if i % 2 == 1:
                    exe.run(
                        cloned_program,
                        feed={'x1': x_np, 'x2': x_np, 'label': label_np},
                        fetch_list=[loss1.name],
                        use_prune=False,
                    )
            weight2 = np.array(scope.find_var(w1_param_attrs.name).get_tensor())
        np.testing.assert_allclose(weight1, weight2, rtol=1e-05)

    def test_prune_program_with_tupe_in_fetch_list(self):
        '''
        If there are multiple optimizers in the program, we can run a specific
        one by passing the return value of optimizer.minimize() to fetch_list.

        Here the value returned by ``minimize`` is placed into fetch_list
        as-is (not unpacked), which is what the test name refers to.
        '''
        exe = fluid.Executor(fluid.CPUPlace())
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        # 1) Do not use_prune: every op in `program`, including the ops added
        #    by both minimize() calls, is expected to run.
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.5)
                train1 = sgd_optimizer.minimize(loss1)
                # Snapshot taken before the second minimize() call: this
                # reference program only contains the loss1 training ops.
                cloned_program = program.clone()

                # Appends a second set of optimize ops (for loss2) to `program`.
                sgd_optimizer.minimize(loss2)
                exe.run(startup_program)
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')

                exe.run(
                    program,
                    feed={'x': x_np, 'label': label_np},
                    fetch_list=[loss1.name],
                    use_prune=False,
                )

                weight_without_prune = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )

        # 2) use_prune with train1 in fetch_list: only the ops needed for
        #    loss1 and its optimizer should run.
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            exe.run(startup_program)
            exe.run(
                program,
                feed={'x': x_np, 'label': label_np},
                fetch_list=[loss1.name, train1],
                use_prune=True,
            )
            weight_with_prune = np.array(
                scope.find_var(w_param_attrs.name).get_tensor()
            )

        # 3) Expected: run the clone, which only has the loss1 optimizer.
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            exe.run(startup_program)
            exe.run(
                cloned_program,
                feed={'x': x_np, 'label': label_np},
                fetch_list=[loss1.name],
                use_prune=False,
            )
            weight_expected = np.array(
                scope.find_var(w_param_attrs.name).get_tensor()
            )

        np.testing.assert_array_equal(weight_with_prune, weight_expected)
        self.assertFalse(np.array_equal(weight_without_prune, weight_expected))

    def test_prune_program_partial_parameter_updated(self):
        """
        When running the startup program, all declared parameters are initialized.
        When running the main program with use_prune=True, the parameters that
        are pruned away still exist in the scope and stay unchanged.
        """
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (
                    x1,
                    x2,
                    y1,
                    y2,
                    label,
                    loss1,
                    loss2,
                    w1_param_attrs,
                    w2_param_attrs,
                ) = self.net2()
                # Marked persistable so that a computed loss is visible in
                # `scope` after the run; the assertions below rely on this to
                # tell which branch executed (presumably non-persistable vars
                # are not kept in `scope` — verify against Executor behavior).
                loss1.persistable = True
                loss2.persistable = True
                sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.5)
                train1 = sgd_optimizer.minimize(loss1)
                sgd_optimizer1 = fluid.optimizer.SGD(learning_rate=0.5)
                # Adds optimize ops for loss2 that are expected to be pruned.
                sgd_optimizer1.minimize(loss2)
                exe = fluid.Executor(fluid.CPUPlace())
                exe.run(startup_program)
                weight1_init = np.array(
                    scope.find_var(w1_param_attrs.name).get_tensor()
                )
                weight2_init = np.array(
                    scope.find_var(w2_param_attrs.name).get_tensor()
                )
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')

                # Only x1 is fed and only train1 is fetched, so the loss2
                # branch should be pruned from the run.
                exe.run(
                    program,
                    feed={'x1': x_np, 'label': label_np},
                    fetch_list=[loss1.name, train1],
                    use_prune=True,
                )
                self.assertIsNotNone(scope.find_var(w1_param_attrs.name))
                self.assertIsNotNone(scope.find_var(w2_param_attrs.name))
                self.assertIsNotNone(scope.find_var(loss1.name))
                self.assertIsNone(scope.find_var(loss2.name))
                weight1 = np.array(
                    scope.find_var(w1_param_attrs.name).get_tensor()
                )
                weight2 = np.array(
                    scope.find_var(w2_param_attrs.name).get_tensor()
                )
                # weight1 is updated by the first optimizer ...
                self.assertFalse(np.array_equal(weight1_init, weight1))
                # ... while weight2 (pruned branch) keeps its initial value.
                np.testing.assert_array_equal(weight2_init, weight2)

    def test_prune_override_use_prune(self):
        '''
        If optimize_ops are provided in the fetch_list, the argument use_prune
        is always overridden to True.
        '''
        exe = fluid.Executor(fluid.CPUPlace())
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        # 1) Do not use_prune: both optimizers' ops are in `program` and run.
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.5)
                # Keep only the first element returned by minimize(); that is
                # what gets placed in fetch_list below.
                train1, _ = sgd_optimizer.minimize(loss1)
                # Reference program containing only the loss1 optimizer.
                cloned_program = program.clone()
                sgd_optimizer.minimize(loss2)
                exe.run(startup_program)
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
                exe.run(
                    program,
                    feed={'x': x_np, 'label': label_np},
                    fetch_list=[loss1.name],
                    use_prune=False,
                )

                weight_without_prune = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )

        # 2) use_prune is intentionally NOT passed: having train1 in the
        #    fetch_list alone should switch pruning on.
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            exe.run(startup_program)
            exe.run(
                program,
                feed={'x': x_np, 'label': label_np},
                fetch_list=[loss1.name, train1],
            )
            weight_with_prune = np.array(
                scope.find_var(w_param_attrs.name).get_tensor()
            )

        # 3) Expected: run the clone that only has the loss1 optimizer.
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            exe.run(startup_program)
            exe.run(
                cloned_program,
                feed={'x': x_np, 'label': label_np},
                fetch_list=[loss1.name],
                use_prune=False,
            )
            weight_expected = np.array(
                scope.find_var(w_param_attrs.name).get_tensor()
            )

        np.testing.assert_array_equal(weight_with_prune, weight_expected)
        self.assertFalse(np.array_equal(weight_without_prune, weight_expected))

    def test_prune_feed_var_in_fetchlist_1(self):
        """Prune when the fed variable (y) is not a leaf of the graph.

        y is fed directly, so the ops that would produce it are expected to
        be pruned; x must therefore never be created in the scope.
        """
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                exe = fluid.Executor(fluid.CPUPlace())
                exe.run(startup_program)
                weight_init = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )

                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
                exe.run(
                    program,
                    feed={y.name: x_np, 'label': label_np},
                    fetch_list=[y.name, loss1.name],
                    use_prune=True,
                )
                self.assertIsNotNone(scope.find_var(loss1.name))
                self.assertIsNone(scope.find_var(loss2.name))
                self.assertIsNone(scope.find_var(x.name))
                weight = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )
                # No optimizer is created in this test, so the weight must
                # keep its initial value.
                np.testing.assert_array_equal(weight_init, weight)

    def test_prune_feed_var_in_fetchlist_2(self):
        """Prune when the fed variable (x) is a leaf of the graph.

        x is both fed and fetched; only the ops needed for loss1 should run,
        so loss2 must not appear in the scope.
        """
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                exe = fluid.Executor(fluid.CPUPlace())
                exe.run(startup_program)
                weight_init = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )

                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
                exe.run(
                    program,
                    feed={x.name: x_np, 'label': label_np},
                    fetch_list=[x.name, loss1.name],
                    use_prune=True,
                )
                self.assertIsNotNone(scope.find_var(loss1.name))
                self.assertIsNone(scope.find_var(loss2.name))
                weight = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )
                # No optimizer is created in this test, so the weight must
                # keep its initial value.
                np.testing.assert_array_equal(weight_init, weight)

if __name__ == '__main__':
    unittest.main()