diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index d4abfeca44be483cca4d6a94bc6edac09b80a6ee..b1f2b6df091ab12ab65aa673d0645f44012c64ae 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -3597,7 +3597,7 @@ class Program(object): p._copy_dist_param_info_from(self) return p - def _prune(self, feeded_var_names, targets): + def _prune(self, targets): """ Prune operators and variables which are not needed to generate :code:`targets`. @@ -3611,8 +3611,63 @@ class Program(object): Returns: Program: A new, pruned program. + """ + + if not isinstance(targets, list): + targets = [targets] + targets_idx = [] + for t in targets: + if not isinstance(t, Operator): + if isinstance(t, Variable): + # After transpiler processing, the op that output this + # variable maybe has been changed, so t.op is not reliable + # and we need to find the current op that generate this + # variable here. + t.op = None + global_block = self.global_block() + for idx, op in enumerate(global_block.ops): + if t.name in op.output_arg_names: + t.op = op + break + + t = t.op + if t is None: + raise ValueError( + "The target variable must have an " + "associated operator that generates it.") + else: + raise ValueError("All targets of prune() can only be " + "Variable or Operator.") + + targets_idx.append([t.block.idx, t.idx]) + res = Program() + res.desc = core.prune(self.desc, set(), targets_idx) + res.blocks = [ + Block(res, i) for i in six.moves.range(res.desc.num_blocks()) + ] + res._sync_with_cpp() + return res + + def _prune_with_input(self, feeded_var_names, targets): + """ + Prune operators and variables which are not needed to generate + :code:`targets`. Prune operators and variables which are needed + to generate feeded_var + + Notes: This is a very low level API. Users should not use this API + directly. This API is in flux and not stable. + + Args: + feeded_var_names(list|str): A list of variable names from where + pruning start. If it is set as [], this API works just like _prune() + targets(list|Variable|Operator): A list of variables or operators + need to be pruned + + Returns: + Program: A new, pruned program. """ + if not isinstance(feeded_var_names, list): feeded_var_names = [feeded_var_names] if not isinstance(targets, list): diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 3f8318688f35f0cfe7ce557437cea8ad21973db2..9787d32741af47ca497638843ecf1b94177949f3 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -1121,7 +1121,8 @@ def save_inference_model(dirname, main_program.desc.flush() - main_program = main_program._prune(feeded_var_names, target_vars) + main_program = main_program._prune_with_input( + feeded_var_names=feeded_var_names, targets=target_vars) main_program = main_program._inference_optimize(prune_read_op=True) fetch_var_names = [v.name for v in target_vars] diff --git a/python/paddle/fluid/tests/unittests/test_prune.py b/python/paddle/fluid/tests/unittests/test_prune.py new file mode 100644 index 0000000000000000000000000000000000000000..dd7e1153b2c006a7313ae28bbf96b7c2baa6117e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_prune.py @@ -0,0 +1,100 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import paddle.fluid as fluid +import paddle.fluid.framework as framework +import paddle.compat as cpt + + +class TestPrune(unittest.TestCase): + def net(self): + x = fluid.layers.data(name='x', shape=[2], dtype='float32') + label = fluid.layers.data(name="label", shape=[1], dtype="int64") + y = fluid.layers.fc(input=[x], size=2, act="softmax") + loss = fluid.layers.cross_entropy(input=y, label=label) + loss = fluid.layers.mean(x=loss) + return x, y, label, loss + + def test_prune_with_input(self): + program = framework.Program() + startup_program = framework.Program() + block = program.global_block() + with fluid.program_guard(program, startup_program): + (x, y, label, loss) = self.net() + self.assertEqual(len(block.ops), 5) + self.assertEqual([op.type for op in block.ops], [ + "mul", "elementwise_add", "softmax", "cross_entropy2", "mean" + ]) + pruned_program = program._prune_with_input( + feeded_var_names=[y.name, label.name], targets=[loss]) + self.assertEqual(len(pruned_program.global_block().ops), 2) + self.assertEqual([op.type for op in pruned_program.global_block().ops], + ["cross_entropy2", "mean"]) + + def test_prune(self): + program = framework.Program() + startup_program = framework.Program() + block = program.global_block() + with fluid.program_guard(program, startup_program): + (x, y, label, loss) = self.net() + self.assertEqual(len(block.ops), 5) + self.assertEqual([op.type for op in block.ops], [ + "mul", "elementwise_add", "softmax", "cross_entropy2", "mean" + ]) + pruned_program = program._prune(targets=[loss]) + self.assertEqual(len(pruned_program.global_block().ops), 5) + self.assertEqual( + [op.type for op in pruned_program.global_block().ops], + ["mul", "elementwise_add", "softmax", "cross_entropy2", "mean"]) + + def test_prune_target_not_list(self): + program = framework.Program() + startup_program = framework.Program() + block = program.global_block() + with fluid.program_guard(program, startup_program): + (x, y, label, loss) = self.net() + self.assertEqual(len(block.ops), 5) + self.assertEqual([op.type for op in block.ops], [ + "mul", "elementwise_add", "softmax", "cross_entropy2", "mean" + ]) + pruned_program = program._prune(targets=loss) + self.assertEqual(len(pruned_program.global_block().ops), 5) + self.assertEqual( + [op.type for op in pruned_program.global_block().ops], + ["mul", "elementwise_add", "softmax", "cross_entropy2", "mean"]) + + def test_prune_target_none(self): + program = framework.Program() + startup_program = framework.Program() + block = program.global_block() + with fluid.program_guard(program, startup_program): + (x, y, label, loss) = self.net() + self.assertEqual(len(block.ops), 5) + self.assertEqual([op.type for op in block.ops], [ + "mul", "elementwise_add", "softmax", "cross_entropy2", "mean" + ]) + try: + pruned_program = program._prune(targets=None) + except ValueError as e: + self.assertEqual( + "All targets of prune() can only be Variable or Operator.", + cpt.get_exception_message(e)) + + +if __name__ == '__main__': + unittest.main()