From 2ca0e118613d6125966e6b13d2e7b4f55c575104 Mon Sep 17 00:00:00 2001
From: Leo Chen
Date: Fri, 24 Apr 2020 16:57:56 +0800
Subject: [PATCH] support fetch the feed var when use_prune=True, test=develop
 (#24110)

---
 python/paddle/fluid/framework.py              | 21 ++++--
 .../fluid/tests/unittests/test_prune.py       | 53 +++++++++++++++++++
 2 files changed, 69 insertions(+), 5 deletions(-)

diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 0ebf7888451..3994fa99f0c 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -4237,6 +4237,14 @@ class Program(object):
                     raise ValueError(
                         "All targets of Program._prune_with_input() can only be "
                         "Variable or Operator, but received %s." % type(t))
+
+                # NOTE(zhiqiu): For a variable to be fed in fetch_list, there are two cases:
+                # (1) the variable is a leaf, and no op generates it;
+                # (2) the variable is not a leaf, and the op that generates it will be pruned.
+                # In both cases, we can simply skip searching for its target op.
+                if name in feeded_var_names:
+                    continue
+
                 # After transpiler processing, the op that output this
                 # variable maybe has been changed, so t.op is not reliable
                 # and we need to find the current op that generate this
@@ -4253,11 +4261,14 @@
                         else:
                             target_op = op
                             break
-                t = target_op
-                if t is None:
-                    raise ValueError("The target variable must have an "
-                                     "associated operator that generates it.")
-                targets_idx.append([t.block.idx, t.idx])
+                if target_op is None:
+                    raise ValueError(
+                        "The target variable used for pruning should have an "
+                        "associated operator that generates it.")
+                else:
+                    targets_idx.append([target_op.block.idx, target_op.idx])
+            else:
+                targets_idx.append([t.block.idx, t.idx])
 
         res = Program()
         res.desc, pruned_origin_block_id_map = core.prune(self.desc,
diff --git a/python/paddle/fluid/tests/unittests/test_prune.py b/python/paddle/fluid/tests/unittests/test_prune.py
index a778b80acb0..3755d92858a 100644
--- a/python/paddle/fluid/tests/unittests/test_prune.py
+++ b/python/paddle/fluid/tests/unittests/test_prune.py
@@ -725,6 +725,59 @@ class TestExecutorRunAutoPrune(unittest.TestCase):
         self.assertTrue(np.array_equal(weight_with_prune, weight_expected))
         self.assertFalse(np.array_equal(weight_without_prune, weight_expected))
 
+    def test_prune_feed_var_in_fetchlist_1(self):
+        # the variable to be fed is not a leaf
+        program = framework.Program()
+        startup_program = framework.Program()
+        scope = fluid.Scope()
+        with fluid.scope_guard(scope):
+            with fluid.program_guard(program, startup_program):
+                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
+                exe = fluid.Executor(fluid.CPUPlace())
+                exe.run(startup_program)
+                weight_init = np.array(
+                    scope.find_var(w_param_attrs.name).get_tensor())
+                x_np = np.random.random(size=(10, 2)).astype('float32')
+                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
+                res = exe.run(program,
+                              feed={y.name: x_np,
+                                    'label': label_np},
+                              fetch_list=[y.name, loss1.name],
+                              use_prune=True)
+                self.assertIsNotNone(scope.find_var(loss1.name))
+                self.assertIsNone(scope.find_var(loss2.name))
+                self.assertIsNone(scope.find_var(x.name))
+                weight = np.array(
+                    scope.find_var(w_param_attrs.name).get_tensor())
+                self.assertTrue(np.array_equal(weight_init,
+                                               weight))  # weight unchanged
+
+    def test_prune_feed_var_in_fetchlist_2(self):
+        # the variable to be fed is a leaf
+        program = framework.Program()
+        startup_program = framework.Program()
+        scope = fluid.Scope()
+        with fluid.scope_guard(scope):
+            with fluid.program_guard(program, startup_program):
+                (x, y, label, loss1, loss2, w_param_attrs) = self.net1()
+                exe = fluid.Executor(fluid.CPUPlace())
+                exe.run(startup_program)
+                weight_init = np.array(
+                    scope.find_var(w_param_attrs.name).get_tensor())
+                x_np = np.random.random(size=(10, 2)).astype('float32')
+                label_np = np.random.randint(1, size=(10, 1)).astype('int64')
+                res = exe.run(program,
+                              feed={x.name: x_np,
+                                    'label': label_np},
+                              fetch_list=[x.name, loss1.name],
+                              use_prune=True)
+                self.assertIsNotNone(scope.find_var(loss1.name))
+                self.assertIsNone(scope.find_var(loss2.name))
+                weight = np.array(
+                    scope.find_var(w_param_attrs.name).get_tensor())
+                self.assertTrue(np.array_equal(weight_init,
+                                               weight))  # weight unchanged
+
 
 if __name__ == '__main__':
     unittest.main()
--
GitLab
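
A minimal usage sketch of the behavior this patch enables: a variable that is fed may now also appear in fetch_list when use_prune=True, instead of pruning failing because no op generates it. The toy network below (fluid.data, fluid.layers.fc, fluid.layers.cross_entropy, and the names x, label, loss) is assumed for illustration and is not taken from the patch; only the exe.run call with use_prune=True mirrors the new tests.

    import numpy as np
    import paddle.fluid as fluid

    main_prog, startup_prog = fluid.Program(), fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        # illustrative network; names and layers are assumptions, not from the patch
        x = fluid.data(name='x', shape=[None, 2], dtype='float32')
        label = fluid.data(name='label', shape=[None, 1], dtype='int64')
        y = fluid.layers.fc(input=x, size=2, act='softmax')
        loss = fluid.layers.mean(
            fluid.layers.cross_entropy(input=y, label=label))

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup_prog)

    # 'x' is both fed and fetched; with this change, pruning skips looking
    # for an op that generates a fed variable instead of raising an error.
    res = exe.run(main_prog,
                  feed={'x': np.random.random((10, 2)).astype('float32'),
                        'label': np.random.randint(2, size=(10, 1)).astype('int64')},
                  fetch_list=[x.name, loss.name],
                  use_prune=True)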