#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
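
"""Tests for static-graph program pruning.

Covers Program._prune and Program._prune_with_input, plus the Executor's
automatic pruning path (exe.run(..., use_prune=True)), including caching of
pruned programs and interaction with optimizers.
"""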

import contextlib
import unittest

import numpy as np

import paddle
import paddle.fluid as fluid
import paddle.fluid.framework as framework


class TestPrune(unittest.TestCase):
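    """Unit tests for Program._prune and Program._prune_with_input."""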
    def net(self):
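        """Build a single fc + softmax classifier and return its mean cross-entropy loss."""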
        x = paddle.static.data(name='x', shape=[-1, 2], dtype='float32')
        x.desc.set_need_check_feed(False)
        label = paddle.static.data(name="label", shape=[-1, 1], dtype="int64")
        label.desc.set_need_check_feed(False)
        y = paddle.static.nn.fc(x=[x], size=2, activation="softmax")
        loss = paddle.nn.functional.cross_entropy(
            input=y, label=label, reduction='none', use_softmax=False
        )
        loss = paddle.mean(x=loss)
        return x, y, label, loss

    def test_prune_with_input(self):
        program = framework.Program()
        startup_program = framework.Program()
        block = program.global_block()
        with fluid.program_guard(program, startup_program):
            (x, y, label, loss) = self.net()
        self.assertEqual(len(block.ops), 5)
        self.assertEqual(
            [op.type for op in block.ops],
            [
                "mul",
                "elementwise_add",
                "softmax",
                "softmax_with_cross_entropy",
                "reduce_mean",
            ],
        )
        pruned_program = program._prune_with_input(
            feeded_var_names=[y.name, label.name], targets=[loss]
        )
        self.assertEqual(len(pruned_program.global_block().ops), 2)
        self.assertEqual(
            [op.type for op in pruned_program.global_block().ops],
            ["softmax_with_cross_entropy", "reduce_mean"],
        )

    def test_prune(self):
        program = framework.Program()
        startup_program = framework.Program()
        block = program.global_block()
        with fluid.program_guard(program, startup_program):
            (x, y, label, loss) = self.net()
        self.assertEqual(len(block.ops), 5)
        self.assertEqual(
            [op.type for op in block.ops],
            [
                "mul",
                "elementwise_add",
                "softmax",
                "softmax_with_cross_entropy",
                "reduce_mean",
            ],
        )
        pruned_program = program._prune(targets=[loss])
        self.assertEqual(len(pruned_program.global_block().ops), 5)
        self.assertEqual(
            [op.type for op in pruned_program.global_block().ops],
            [
                "mul",
                "elementwise_add",
                "softmax",
                "softmax_with_cross_entropy",
                "reduce_mean",
            ],
        )

    def test_prune_target_not_list(self):
        program = framework.Program()
        startup_program = framework.Program()
        block = program.global_block()
        with fluid.program_guard(program, startup_program):
            (x, y, label, loss) = self.net()
        self.assertEqual(len(block.ops), 5)
        self.assertEqual(
            [op.type for op in block.ops],
            [
                "mul",
                "elementwise_add",
                "softmax",
                "softmax_with_cross_entropy",
                "reduce_mean",
            ],
        )
        pruned_program = program._prune(targets=loss)
        self.assertEqual(len(pruned_program.global_block().ops), 5)
        self.assertEqual(
            [op.type for op in pruned_program.global_block().ops],
            [
                "mul",
                "elementwise_add",
                "softmax",
                "softmax_with_cross_entropy",
                "reduce_mean",
            ],
        )

    def test_prune_target_none(self):
        program = framework.Program()
        startup_program = framework.Program()
        block = program.global_block()
        with fluid.program_guard(program, startup_program):
            (x, y, label, loss) = self.net()
        self.assertEqual(len(block.ops), 5)
        self.assertEqual(
            [op.type for op in block.ops],
            [
                "mul",
                "elementwise_add",
                "softmax",
                "softmax_with_cross_entropy",
                "reduce_mean",
            ],
        )
        try:
            pruned_program = program._prune(targets=None)
        except ValueError as e:
            self.assertIn(
                "All targets of Program._prune_with_input() can only be Variable or Operator",
                str(e),
            )


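# `mock` stands in for Executor._prune_program: it skips the real pruning and only
# counts how many times pruning was requested, so the caching tests below can assert
# on Executor.prune_called_times. `_mock_guard` swaps it in temporarily and restores
# the original implementation afterwards.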
def mock(self, program, feed, fetch, optimize_ops):
    self.prune_called_times += 1
    return program


@contextlib.contextmanager
def _mock_guard(mock):
    original = fluid.Executor._prune_program
    fluid.Executor._prune_program = mock
    yield
    fluid.Executor._prune_program = original


class TestExecutorRunAutoPrune(unittest.TestCase):
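    """Tests for the Executor's automatic pruning when running with use_prune."""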
    def net1(self):
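        """Single fc branch ('fc_weight') whose output feeds two identical mean cross-entropy losses."""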
        x = paddle.static.data(name='x', shape=[-1, 2], dtype='float32')
        x.desc.set_need_check_feed(False)
        label = paddle.static.data(name="label", shape=[-1, 1], dtype="int64")
        label.desc.set_need_check_feed(False)
        w_param_attrs = fluid.ParamAttr(
            name="fc_weight",
            learning_rate=0.5,
            initializer=paddle.nn.initializer.Constant(1.0),
            trainable=True,
        )
        y = paddle.static.nn.fc(
            x=[x], size=2, activation="softmax", weight_attr=w_param_attrs
        )
        loss1 = paddle.nn.functional.cross_entropy(
            input=y, label=label, reduction='none', use_softmax=False
        )
        loss1 = paddle.mean(x=loss1)
        loss2 = paddle.nn.functional.cross_entropy(
            input=y, label=label, reduction='none', use_softmax=False
        )
        loss2 = paddle.mean(x=loss2)
        loss1.persistable = True
        loss2.persistable = True
        return x, y, label, loss1, loss2, w_param_attrs

    def net2(self):
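        """Two independent fc branches ('fc_weight1', 'fc_weight2'), each with its own mean cross-entropy loss."""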
        x1 = paddle.static.data(name='x1', shape=[-1, 2], dtype='float32')
        x1.desc.set_need_check_feed(False)
        x2 = paddle.static.data(name='x2', shape=[-1, 2], dtype='float32')
        x2.desc.set_need_check_feed(False)
        label = paddle.static.data(name="label", shape=[-1, 1], dtype="int64")
        label.desc.set_need_check_feed(False)
        w1_param_attrs = fluid.ParamAttr(
            name="fc_weight1",
            learning_rate=0.5,
            initializer=paddle.nn.initializer.Constant(1.0),
            trainable=True,
        )
        w2_param_attrs = fluid.ParamAttr(
            name="fc_weight2",
            learning_rate=0.5,
            initializer=paddle.nn.initializer.Constant(1.0),
            trainable=True,
        )
        y1 = paddle.static.nn.fc(
            x=[x1], size=2, activation="softmax", weight_attr=w1_param_attrs
        )
        y2 = paddle.static.nn.fc(
            x=[x2], size=2, activation="softmax", weight_attr=w2_param_attrs
        )
        loss1 = paddle.nn.functional.cross_entropy(
            input=y1, label=label, reduction='none', use_softmax=False
        )
        loss1 = paddle.mean(x=loss1)
        loss2 = paddle.nn.functional.cross_entropy(
            input=y2, label=label, reduction='none', use_softmax=False
        )
        loss2 = paddle.mean(x=loss2)
        return (
            x1,
            x2,
            y1,
            y2,
            label,
            loss1,
            loss2,
            w1_param_attrs,
            w2_param_attrs,
        )

    def test_not_prune(self):
        """
        If use_prune is False, targets that are not fetched will still be computed.
        """
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                exe = fluid.Executor(fluid.CPUPlace())
                exe.run(startup_program)
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
                res = exe.run(
                    program,
                    feed={'x': x_np, 'label': label_np},
                    fetch_list=[loss1.name],
                    use_prune=False,
                )
                self.assertIsNotNone(scope.find_var(loss1.name))
                self.assertIsNotNone(scope.find_var(loss2.name))

    def test_prune_fetches_without_optimizer(self):
        """
260
        Prune operators and variables which are not needed to generate 'fetches'.
261 262 263 264 265 266 267 268 269 270
        """
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                exe = fluid.Executor(fluid.CPUPlace())
                exe.run(startup_program)
                weight_init = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
                res = exe.run(
                    program,
                    feed={'x': x_np, 'label': label_np},
                    fetch_list=[loss1.name],
                    use_prune=True,
                )
                self.assertIsNotNone(scope.find_var(loss1.name))
                self.assertIsNone(scope.find_var(loss2.name))  # loss2 is pruned
                weight = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )
                np.testing.assert_array_equal(
                    weight_init, weight
                )  # weight not changed

    def test_prune_fetches_with_optimizer(self):
        """
292
        Prune operators and operators which are not needed to generate 'fetches'.
293 294 295 296 297 298 299 300 301 302 303 304 305
        In train mode, the operators and operators in backward and optimization should be kept.
        """
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.5)
                sgd_optimizer.minimize(loss1)
                exe = fluid.Executor(fluid.CPUPlace())
                exe.run(startup_program)
                weight_init = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
                res = exe.run(
                    program,
                    feed={'x': x_np, 'label': label_np},
                    fetch_list=[loss1.name],
                    use_prune=True,
                )
                self.assertIsNotNone(scope.find_var(loss1.name))
                self.assertIsNone(scope.find_var(loss2.name))  # loss2 is pruned
                weight = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )
                self.assertFalse(
                    np.array_equal(weight_init, weight)
                )  # weight changed

    def test_prune_compiled_program(self):
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.5)
                sgd_optimizer.minimize(loss1)
                exe = fluid.Executor(fluid.CPUPlace())
                exe.run(startup_program)
                compiled_prog = fluid.CompiledProgram(program)
                weight_init = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
                res = exe.run(
                    compiled_prog,
                    feed={'x': x_np, 'label': label_np},
                    fetch_list=[loss1.name],
                    use_prune=True,
                )
                self.assertIsNotNone(scope.find_var(loss1.name))
                self.assertIsNone(scope.find_var(loss2.name))
                weight = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )
                self.assertFalse(
                    np.array_equal(weight_init, weight)
                )  # weight changed

    def test_prune_feed_without_optimizer(self):
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                exe = fluid.Executor(fluid.CPUPlace())
                exe.run(startup_program)
                weight_init = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
                res = exe.run(
                    program,
                    feed={y.name: x_np, 'label': label_np},
                    fetch_list=[loss1.name],
                    use_prune=True,
                )
                self.assertIsNotNone(scope.find_var(loss1.name))
                self.assertIsNone(scope.find_var(loss2.name))
                weight = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )
                np.testing.assert_array_equal(
                    weight_init, weight
                )  # weight unchanged

    def test_prune_feed_with_optimizer(self):
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.5)
                sgd_optimizer.minimize(loss1)
                exe = fluid.Executor(fluid.CPUPlace())
                exe.run(startup_program)
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
                self.assertRaises(
                    Exception,
                    exe.run,
                    program,
                    feed={y.name: x_np, 'label': label_np},
                    fetch_list=[loss1.name],
                    use_prune=True,
                )
                self.assertIsNotNone(scope.find_var(loss1.name))
                self.assertIsNone(scope.find_var(loss2.name))

    def test_prune_with_cache_program(self):
        '''
        When use_prune=True, the Executor should cache the pruned program.
        If the program, feed, and fetch are unchanged in the next run, the Executor reuses the
        cached pruned program and does not need to call _prune_program() again.
        In this test, we replace Executor._prune_program with a mock that does nothing but
        increment Executor.prune_called_times, and we check that prune_called_times equals 1
        even after calling exe.run() 10 times with the same arguments.
        '''
        with _mock_guard(mock):
            exe = fluid.Executor(fluid.CPUPlace())
            exe.prune_called_times = 0
            program = framework.Program()
            startup_program = framework.Program()
            scope = fluid.Scope()
            with fluid.scope_guard(scope):
                with fluid.program_guard(program, startup_program):
                    (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                    sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.5)
                    sgd_optimizer.minimize(loss1)
                    exe.run(startup_program)
                    x_np = np.random.random(size=(10, 2)).astype('float32')
                    label_np = np.random.randint(1, size=(10, 1)).astype(
                        'int64'
                    )
                    for i in range(10):
                        res = exe.run(
                            program,
                            feed={'x': x_np, 'label': label_np},
                            fetch_list=[loss1.name],
                            use_prune=True,
                        )
                        # Pruning happens only on the first run; later runs hit the cache.
                        self.assertEqual(exe.prune_called_times, 1)

    def test_prune_with_cache_program2(self):
        '''
        When use_prune=True, the Executor should cache the pruned program.
        If the only difference between runs is the optimize_ops in fetch_list,
        the cache keys should differ and yield different pruned programs.
        '''
        with _mock_guard(mock):
            exe = fluid.Executor(fluid.CPUPlace())
            exe.prune_called_times = 0
            program = framework.Program()
            startup_program = framework.Program()
            scope = fluid.Scope()
            with fluid.scope_guard(scope):
                with fluid.program_guard(program, startup_program):
                    (
                        x1,
                        x2,
                        y1,
                        y2,
                        label,
                        loss1,
                        loss2,
                        w1_param_attrs,
                        w2_param_attrs,
                    ) = self.net2()
                    adam_optimizer1 = fluid.optimizer.AdamOptimizer(
                        learning_rate=0.5
                    )
                    train1 = adam_optimizer1.minimize(loss1)
                    adam_optimizer2 = fluid.optimizer.AdamOptimizer(
                        learning_rate=0.5
                    )
                    train2 = adam_optimizer2.minimize(loss2)
                    exe.run(startup_program)
                    x_np = np.random.random(size=(10, 2)).astype('float32')
                    label_np = np.random.randint(1, size=(10, 1)).astype(
                        'int64'
                    )

                    for i in range(10):
                        if i % 2:
                            res = exe.run(
                                program,
                                feed={
                                    'x1': x_np,
                                    'x2': x_np,
                                    'label': label_np,
                                },
                                fetch_list=[loss1, loss2, train1],
                                use_prune=True,
                            )
                        else:
                            res = exe.run(
                                program,
                                feed={
                                    'x1': x_np,
                                    'x2': x_np,
                                    'label': label_np,
                                },
                                fetch_list=[loss1, loss2, train2],
                                use_prune=True,
                            )
                        if i == 0:
                            self.assertEqual(exe.prune_called_times, 1)
                        elif i == 1:
                            self.assertEqual(exe.prune_called_times, 2)
                        else:
                            self.assertEqual(exe.prune_called_times, 2)

    def test_prune_with_cache_compiled_program(self):
        '''
        When use_prune=True, the Executor should cache the pruned program.
        If the program, feed, and fetch are unchanged in the next run, the Executor reuses the
        cached pruned program and does not need to call _prune_program() again.
        In this test, we replace Executor._prune_program with a mock that does nothing but
        increment Executor.prune_called_times, and we check that prune_called_times equals 1
        even after calling exe.run() 10 times with the same arguments.
        '''
        with _mock_guard(mock):
            exe = fluid.Executor(fluid.CPUPlace())
            exe.prune_called_times = 0
            program = framework.Program()
            startup_program = framework.Program()
            scope = fluid.Scope()
            with fluid.scope_guard(scope):
                with fluid.program_guard(program, startup_program):
                    (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                    sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.5)
                    sgd_optimizer.minimize(loss1)
                    exe.run(startup_program)
                    x_np = np.random.random(size=(10, 2)).astype('float32')
                    label_np = np.random.randint(1, size=(10, 1)).astype(
                        'int64'
                    )
                    compiled_prog = fluid.CompiledProgram(program)
                    for i in range(10):
                        res = exe.run(
                            compiled_prog,
                            feed={'x': x_np, 'label': label_np},
                            fetch_list=[loss1.name],
                            use_prune=True,
                        )
                        # Pruning happens only on the first run; later runs hit the cache.
                        self.assertEqual(exe.prune_called_times, 1)

    def test_prune_with_multi_optimizers(self):
        '''
        If there are multiple optimizers in the program, we can run a specific one by
        passing the return value of optimizer.minimize() to fetch_list.
        '''
        exe = fluid.Executor(fluid.CPUPlace())
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        # do not use_prune
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.5)
                train1, _ = sgd_optimizer.minimize(loss1)
                cloned_program = program.clone()
                train2, _ = sgd_optimizer.minimize(loss2)
                exe.run(startup_program)
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
                res = exe.run(
                    program,
                    feed={'x': x_np, 'label': label_np},
                    fetch_list=[loss1.name],
                    use_prune=False,
                )
                weight_without_prune = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )

        scope = fluid.Scope()
        # use_prune
        with fluid.scope_guard(scope):
            exe.run(startup_program)
            res = exe.run(
                program,
                feed={'x': x_np, 'label': label_np},
                fetch_list=[loss1.name, train1],
                use_prune=True,
            )
            weight_with_prune = np.array(
                scope.find_var(w_param_attrs.name).get_tensor()
            )

        # expected
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            exe.run(startup_program)
            exe.run(
                cloned_program,
                feed={'x': x_np, 'label': label_np},
                fetch_list=[loss1.name],
                use_prune=False,
            )
            weight_expected = np.array(
                scope.find_var(w_param_attrs.name).get_tensor()
            )

        np.testing.assert_array_equal(weight_with_prune, weight_expected)
        self.assertFalse(np.array_equal(weight_without_prune, weight_expected))

    def test_prune_program_with_tuple_in_fetch_list(self):
        '''
        If there are multiple optimizers in the program, we can run a specific one by
        passing the return value of optimizer.minimize() to fetch_list.
        '''
        exe = fluid.Executor(fluid.CPUPlace())
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        # do not use_prune
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.5)
                train1 = sgd_optimizer.minimize(loss1)
                cloned_program = program.clone()

                train2 = sgd_optimizer.minimize(loss2)
                exe.run(startup_program)
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')

                res = exe.run(
                    program,
                    feed={'x': x_np, 'label': label_np},
                    fetch_list=[loss1.name],
                    use_prune=False,
                )

                weight_without_prune = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )

        scope = fluid.Scope()
        # use_prune
        with fluid.scope_guard(scope):
            exe.run(startup_program)
            res = exe.run(
                program,
                feed={'x': x_np, 'label': label_np},
                fetch_list=[loss1.name, train1],
                use_prune=True,
            )
            weight_with_prune = np.array(
                scope.find_var(w_param_attrs.name).get_tensor()
            )

        # expected
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            exe.run(startup_program)
            exe.run(
                cloned_program,
                feed={'x': x_np, 'label': label_np},
                fetch_list=[loss1.name],
                use_prune=False,
            )
            weight_expected = np.array(
                scope.find_var(w_param_attrs.name).get_tensor()
            )

        np.testing.assert_array_equal(weight_with_prune, weight_expected)
        self.assertFalse(np.array_equal(weight_without_prune, weight_expected))

    def test_prune_program_partial_parameter_updated(self):
        """
        When running the startup program, all declared parameters are initialized.
        When running the main program with use_prune=True, the pruned parameters still exist in the scope but stay unchanged.
        """
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (
                    x1,
                    x2,
                    y1,
                    y2,
                    label,
                    loss1,
                    loss2,
                    w1_param_attrs,
                    w2_param_attrs,
                ) = self.net2()
                loss1.persistable = True
                loss2.persistable = True
                sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.5)
                train1 = sgd_optimizer.minimize(loss1)
                sgd_optimizer1 = fluid.optimizer.SGD(learning_rate=0.5)
                train2 = sgd_optimizer1.minimize(loss2)
                exe = fluid.Executor(fluid.CPUPlace())
                exe.run(startup_program)
                weight1_init = np.array(
                    scope.find_var(w1_param_attrs.name).get_tensor()
                )
                weight2_init = np.array(
                    scope.find_var(w2_param_attrs.name).get_tensor()
                )
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')

                res = exe.run(
                    program,
                    feed={'x1': x_np, 'label': label_np},
                    fetch_list=[loss1.name, train1],
                    use_prune=True,
                )
                self.assertIsNotNone(scope.find_var(w1_param_attrs.name))
                self.assertIsNotNone(scope.find_var(w2_param_attrs.name))
                self.assertIsNotNone(scope.find_var(loss1.name))
                self.assertIsNone(scope.find_var(loss2.name))
                weight1 = np.array(
                    scope.find_var(w1_param_attrs.name).get_tensor()
                )
                weight2 = np.array(
                    scope.find_var(w2_param_attrs.name).get_tensor()
                )
                self.assertFalse(
                    np.array_equal(weight1_init, weight1)
                )  # weight changed
                np.testing.assert_array_equal(
                    weight2_init, weight2
                )  # weight2 unchanged

    def test_prune_override_use_prune(self):
        '''
        If optimize_ops are provided in the fetch_list, the argument use_prune is always overridden to True.
        '''
        exe = fluid.Executor(fluid.CPUPlace())
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        # do not use_prune
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.5)
                train1, _ = sgd_optimizer.minimize(loss1)
                cloned_program = program.clone()
                train2, _ = sgd_optimizer.minimize(loss2)
                exe.run(startup_program)
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
                res = exe.run(
                    program,
                    feed={'x': x_np, 'label': label_np},
                    fetch_list=[loss1.name],
                    use_prune=False,
                )

                weight_without_prune = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )

        scope = fluid.Scope()
        # use_prune
        with fluid.scope_guard(scope):
            exe.run(startup_program)
            res = exe.run(
                program,
                feed={'x': x_np, 'label': label_np},
                fetch_list=[loss1.name, train1],
            )
            weight_with_prune = np.array(
                scope.find_var(w_param_attrs.name).get_tensor()
            )

        # expected
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            exe.run(startup_program)
            exe.run(
                cloned_program,
                feed={'x': x_np, 'label': label_np},
                fetch_list=[loss1.name],
                use_prune=False,
            )
            weight_expected = np.array(
                scope.find_var(w_param_attrs.name).get_tensor()
            )

        np.testing.assert_array_equal(weight_with_prune, weight_expected)
        self.assertFalse(np.array_equal(weight_without_prune, weight_expected))

    def test_prune_feed_var_in_fetchlist_1(self):
        # the variable to be fed is not leaf
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                exe = fluid.Executor(fluid.CPUPlace())
                exe.run(startup_program)
                weight_init = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
                res = exe.run(
                    program,
                    feed={y.name: x_np, 'label': label_np},
                    fetch_list=[y.name, loss1.name],
                    use_prune=True,
                )
                self.assertIsNotNone(scope.find_var(loss1.name))
                self.assertIsNone(scope.find_var(loss2.name))
                self.assertIsNone(scope.find_var(x.name))
                weight = np.array(
                    scope.find_var(w_param_attrs.name).get_tensor()
                )
                np.testing.assert_array_equal(
                    weight_init, weight
                )  # weight unchanged

    def test_prune_feed_var_in_fetchlist_2(self):
        # the variable to be fed is leaf
        program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(program, startup_program):
                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
                exe = fluid.Executor(fluid.CPUPlace())
                exe.run(startup_program)
                weight_init = np.array(
842 843
                    scope.find_var(w_param_attrs.name).get_tensor()
                )
844 845
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
846 847 848 849 850 851
                res = exe.run(
                    program,
                    feed={x.name: x_np, 'label': label_np},
                    fetch_list=[x.name, loss1.name],
                    use_prune=True,
                )
852 853 854
                self.assertIsNotNone(scope.find_var(loss1.name))
                self.assertIsNone(scope.find_var(loss2.name))
                weight = np.array(
855 856 857 858 859
                    scope.find_var(w_param_attrs.name).get_tensor()
                )
                np.testing.assert_array_equal(
                    weight_init, weight
                )  # weight unchanged
860


if __name__ == '__main__':
    unittest.main()