未验证 提交 d0789013 编写于 作者: L Leo Chen 提交者: GitHub

support fetch the feed var when use_prune=True, test=develop (#24112)

上级 b4b31f47
...@@ -4237,6 +4237,14 @@ class Program(object): ...@@ -4237,6 +4237,14 @@ class Program(object):
raise ValueError( raise ValueError(
"All targets of Program._prune_with_input() can only be " "All targets of Program._prune_with_input() can only be "
"Variable or Operator, but received %s." % type(t)) "Variable or Operator, but received %s." % type(t))
# NOTEZ(zhiqiu): For variable to be fed in fetch_list, there two cases:
# (1) the variable is leaf, it has no op that generates it;
# (2) the variable is not leaf, and we need to prune the op that generates it.
# In both cases, wo can just skip target_op of that it.
if name in feeded_var_names:
continue
# After transpiler processing, the op that output this # After transpiler processing, the op that output this
# variable maybe has been changed, so t.op is not reliable # variable maybe has been changed, so t.op is not reliable
# and we need to find the current op that generate this # and we need to find the current op that generate this
...@@ -4253,11 +4261,14 @@ class Program(object): ...@@ -4253,11 +4261,14 @@ class Program(object):
else: else:
target_op = op target_op = op
break break
t = target_op if target_op is None:
if t is None: raise ValueError(
raise ValueError("The target variable must have an " "The target variable used for pruning should have an "
"associated operator that generates it.") "associated operator that generates it.")
targets_idx.append([t.block.idx, t.idx]) else:
targets_idx.append([target_op.block.idx, target_op.idx])
else:
targets_idx.append([t.block.idx, t.idx])
res = Program() res = Program()
res.desc, pruned_origin_block_id_map = core.prune(self.desc, res.desc, pruned_origin_block_id_map = core.prune(self.desc,
......
...@@ -725,6 +725,59 @@ class TestExecutorRunAutoPrune(unittest.TestCase): ...@@ -725,6 +725,59 @@ class TestExecutorRunAutoPrune(unittest.TestCase):
self.assertTrue(np.array_equal(weight_with_prune, weight_expected)) self.assertTrue(np.array_equal(weight_with_prune, weight_expected))
self.assertFalse(np.array_equal(weight_without_prune, weight_expected)) self.assertFalse(np.array_equal(weight_without_prune, weight_expected))
def test_prune_feed_var_in_fetchlist_1(self):
# the variable to be fed is not leaf
program = framework.Program()
startup_program = framework.Program()
scope = fluid.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(program, startup_program):
(x, y, label, loss1, loss2, w_param_attrs) = self.net1()
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)
weight_init = np.array(
scope.find_var(w_param_attrs.name).get_tensor())
x_np = np.random.random(size=(10, 2)).astype('float32')
label_np = np.random.randint(1, size=(10, 1)).astype('int64')
res = exe.run(program,
feed={y.name: x_np,
'label': label_np},
fetch_list=[y.name, loss1.name],
use_prune=True)
self.assertIsNotNone(scope.find_var(loss1.name))
self.assertIsNone(scope.find_var(loss2.name))
self.assertIsNone(scope.find_var(x.name))
weight = np.array(
scope.find_var(w_param_attrs.name).get_tensor())
self.assertTrue(np.array_equal(weight_init,
weight)) # weight unchanged
def test_prune_feed_var_in_fetchlist_2(self):
# the variable to be fed is leaf
program = framework.Program()
startup_program = framework.Program()
scope = fluid.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(program, startup_program):
(x, y, label, loss1, loss2, w_param_attrs) = self.net1()
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)
weight_init = np.array(
scope.find_var(w_param_attrs.name).get_tensor())
x_np = np.random.random(size=(10, 2)).astype('float32')
label_np = np.random.randint(1, size=(10, 1)).astype('int64')
res = exe.run(program,
feed={x.name: x_np,
'label': label_np},
fetch_list=[x.name, loss1.name],
use_prune=True)
self.assertIsNotNone(scope.find_var(loss1.name))
self.assertIsNone(scope.find_var(loss2.name))
weight = np.array(
scope.find_var(w_param_attrs.name).get_tensor())
self.assertTrue(np.array_equal(weight_init,
weight)) # weight unchanged
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册