Commit 5e99f31b authored by mapingshuo, committed by Dong Daxiang

add a new interface _prune_with_input (#20022)

* add a default value for _prune interface

* modify document
Parent 6f184775
@@ -3597,7 +3597,7 @@ class Program(object):
        p._copy_dist_param_info_from(self)
        return p

    def _prune(self, feeded_var_names, targets):
    def _prune(self, targets):
        """
        Prune operators and variables which are not needed to generate
        :code:`targets`.
@@ -3611,8 +3611,63 @@ class Program(object):
        Returns:
            Program: A new, pruned program.
        """
        if not isinstance(targets, list):
            targets = [targets]

        targets_idx = []
        for t in targets:
            if not isinstance(t, Operator):
                if isinstance(t, Variable):
                    # After transpiler processing, the op that outputs this
                    # variable may have been changed, so t.op is not reliable
                    # and we need to find the current op that generates this
                    # variable here.
                    t.op = None
                    global_block = self.global_block()
                    for idx, op in enumerate(global_block.ops):
                        if t.name in op.output_arg_names:
                            t.op = op
                            break

                    t = t.op
                    if t is None:
                        raise ValueError(
                            "The target variable must have an "
                            "associated operator that generates it.")
                else:
                    raise ValueError("All targets of prune() can only be "
                                     "Variable or Operator.")

            targets_idx.append([t.block.idx, t.idx])

        res = Program()
        res.desc = core.prune(self.desc, set(), targets_idx)
        res.blocks = [
            Block(res, i) for i in six.moves.range(res.desc.num_blocks())
        ]
        res._sync_with_cpp()
        return res
    def _prune_with_input(self, feeded_var_names, targets):
        """
        Prune operators and variables which are not needed to generate
        :code:`targets`. In addition, prune the operators and variables
        that are only needed to produce the variables named in
        :code:`feeded_var_names`, since those variables will be fed directly.

        Notes: This is a very low-level API. Users should not use this API
        directly. This API is in flux and not stable.

        Args:
            feeded_var_names(list|str): A list of variable names from which
                pruning starts. If it is set to [], this API behaves just
                like _prune().
            targets(list|Variable|Operator): A list of variables or operators
                that the pruned program must still be able to generate.

        Returns:
            Program: A new, pruned program.
        """

        if not isinstance(feeded_var_names, list):
            feeded_var_names = [feeded_var_names]
        if not isinstance(targets, list):
......
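For context, a minimal sketch of how the two pruning calls compare. It mirrors the unit test added at the bottom of this commit; the network, shapes, and variable names are illustrative only, and both methods remain private APIs that may change:

import paddle.fluid as fluid

# Build a tiny classifier in the default main program (illustrative only).
x = fluid.layers.data(name='x', shape=[2], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
y = fluid.layers.fc(input=[x], size=2, act='softmax')
loss = fluid.layers.mean(x=fluid.layers.cross_entropy(input=y, label=label))

main_program = fluid.default_main_program()

# _prune now only needs the targets; it keeps every op required to compute loss.
pruned = main_program._prune(targets=[loss])

# _prune_with_input additionally treats the fed variables as given inputs, so
# the ops that produce y and label (mul, elementwise_add, softmax) are dropped
# and only cross_entropy2 and mean remain.
pruned_from_inputs = main_program._prune_with_input(
    feeded_var_names=[y.name, label.name], targets=[loss])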
@@ -1121,7 +1121,8 @@ def save_inference_model(dirname,
        main_program.desc.flush()

        main_program = main_program._prune(feeded_var_names, target_vars)
        main_program = main_program._prune_with_input(
            feeded_var_names=feeded_var_names, targets=target_vars)
        main_program = main_program._inference_optimize(prune_read_op=True)
        fetch_var_names = [v.name for v in target_vars]
......
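In practice this path is reached through save_inference_model, which now prunes with the fed variable names so that the operators producing the fed inputs are excluded from the saved inference program. A hedged sketch of a typical call, continuing the toy network above (the directory name is a placeholder, and the startup program is assumed to have been run so that parameters exist):

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())  # initialize parameters before saving

fluid.io.save_inference_model(
    dirname="./inference_model",      # output directory (placeholder)
    feeded_var_names=[x.name],        # inputs fed at inference time
    target_vars=[y],                  # outputs the pruned program must produce
    executor=exe,
    main_program=fluid.default_main_program())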
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import unittest

import paddle.fluid as fluid
import paddle.fluid.framework as framework
import paddle.compat as cpt


class TestPrune(unittest.TestCase):
    def net(self):
        x = fluid.layers.data(name='x', shape=[2], dtype='float32')
        label = fluid.layers.data(name="label", shape=[1], dtype="int64")
        y = fluid.layers.fc(input=[x], size=2, act="softmax")
        loss = fluid.layers.cross_entropy(input=y, label=label)
        loss = fluid.layers.mean(x=loss)
        return x, y, label, loss

    def test_prune_with_input(self):
        program = framework.Program()
        startup_program = framework.Program()
        block = program.global_block()
        with fluid.program_guard(program, startup_program):
            (x, y, label, loss) = self.net()
            self.assertEqual(len(block.ops), 5)
            self.assertEqual([op.type for op in block.ops], [
                "mul", "elementwise_add", "softmax", "cross_entropy2", "mean"
            ])
            pruned_program = program._prune_with_input(
                feeded_var_names=[y.name, label.name], targets=[loss])
            self.assertEqual(len(pruned_program.global_block().ops), 2)
            self.assertEqual([op.type for op in pruned_program.global_block().ops],
                             ["cross_entropy2", "mean"])

    def test_prune(self):
        program = framework.Program()
        startup_program = framework.Program()
        block = program.global_block()
        with fluid.program_guard(program, startup_program):
            (x, y, label, loss) = self.net()
            self.assertEqual(len(block.ops), 5)
            self.assertEqual([op.type for op in block.ops], [
                "mul", "elementwise_add", "softmax", "cross_entropy2", "mean"
            ])
            pruned_program = program._prune(targets=[loss])
            self.assertEqual(len(pruned_program.global_block().ops), 5)
            self.assertEqual(
                [op.type for op in pruned_program.global_block().ops],
                ["mul", "elementwise_add", "softmax", "cross_entropy2", "mean"])

    def test_prune_target_not_list(self):
        program = framework.Program()
        startup_program = framework.Program()
        block = program.global_block()
        with fluid.program_guard(program, startup_program):
            (x, y, label, loss) = self.net()
            self.assertEqual(len(block.ops), 5)
            self.assertEqual([op.type for op in block.ops], [
                "mul", "elementwise_add", "softmax", "cross_entropy2", "mean"
            ])
            pruned_program = program._prune(targets=loss)
            self.assertEqual(len(pruned_program.global_block().ops), 5)
            self.assertEqual(
                [op.type for op in pruned_program.global_block().ops],
                ["mul", "elementwise_add", "softmax", "cross_entropy2", "mean"])

    def test_prune_target_none(self):
        program = framework.Program()
        startup_program = framework.Program()
        block = program.global_block()
        with fluid.program_guard(program, startup_program):
            (x, y, label, loss) = self.net()
            self.assertEqual(len(block.ops), 5)
            self.assertEqual([op.type for op in block.ops], [
                "mul", "elementwise_add", "softmax", "cross_entropy2", "mean"
            ])
            try:
                pruned_program = program._prune(targets=None)
            except ValueError as e:
                self.assertEqual(
                    "All targets of prune() can only be Variable or Operator.",
                    cpt.get_exception_message(e))


if __name__ == '__main__':
    unittest.main()