From dca9b6c5b06656c611cd5b3ca2740fd0b24e2c44 Mon Sep 17 00:00:00 2001
From: mapingshuo
Date: Thu, 5 Sep 2019 19:04:40 +0800
Subject: [PATCH] add feed_var_names to Prune interface (#19589)

* Fix bug: add feed_vars to the prune function

---
 paddle/fluid/framework/prune.cc      | 23 ++++++++++++++++-------
 paddle/fluid/framework/prune.h       |  6 +++++-
 paddle/fluid/framework/prune_test.cc | 21 +++++++++++++--------
 paddle/fluid/pybind/pybind.cc        |  4 +++-
 python/paddle/fluid/framework.py     | 12 ++++++++++--
 python/paddle/fluid/io.py            |  2 +-
 6 files changed, 48 insertions(+), 20 deletions(-)

diff --git a/paddle/fluid/framework/prune.cc b/paddle/fluid/framework/prune.cc
index 0afcd85fe7c..a2df5cfcf93 100644
--- a/paddle/fluid/framework/prune.cc
+++ b/paddle/fluid/framework/prune.cc
@@ -68,7 +68,8 @@ bool HasSubBlock(const proto::OpDesc& op_desc) {
 // the child block to help pruning
 void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
                 int block_id, int parent_block_id,
-                std::set<std::string>* dependent_vars) {
+                std::set<std::string>* dependent_vars,
+                const std::set<std::string> feed_var_names) {
   auto& block = input.blocks(block_id);
   auto& ops = block.ops();
 
@@ -94,7 +95,9 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
       // insert its input to the dependency graph
       for (auto& var : op_desc.inputs()) {
         for (auto& argu : var.arguments()) {
-          dependent_vars->insert(argu);
+          if (feed_var_names.count(argu) == 0) {
+            dependent_vars->insert(argu);
+          }
         }
       }
       should_run.push_back(true);
@@ -127,18 +130,22 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
       std::set<std::string> sub_block_dependent_vars;
       for (auto& var : op->inputs()) {
         for (auto& argu : var.arguments()) {
-          sub_block_dependent_vars.insert(argu);
+          if (feed_var_names.count(argu) == 0) {
+            sub_block_dependent_vars.insert(argu);
+          }
         }
       }
       for (auto& var : op->outputs()) {
         for (auto& argu : var.arguments()) {
-          sub_block_dependent_vars.insert(argu);
+          if (feed_var_names.count(argu) == 0) {
+            sub_block_dependent_vars.insert(argu);
+          }
         }
       }
       // GetSubBlockIndex(*op) is the idx of the sub_block in the input desc
       // output_block_id is the idx of the current block in the output desc
       prune_impl(input, output, GetSubBlockIndex(*op), output_block_id,
-                 &sub_block_dependent_vars);
+                 &sub_block_dependent_vars, feed_var_names);
     }
   }
 }
@@ -178,10 +185,12 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
 }
 
 // TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies
-void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output) {
+void Prune(const proto::ProgramDesc& input,
+           const std::set<std::string>& feed_var_names,
+           proto::ProgramDesc* output) {
   std::set<std::string> dependent_vars;
   output->clear_blocks();
-  prune_impl(input, output, 0, -1, &dependent_vars);
+  prune_impl(input, output, 0, -1, &dependent_vars, feed_var_names);
 }
 
 }  // namespace framework
 }  // namespace paddle
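The core change above is in the backward dependency walk: an op is kept only if it is a target or produces a variable that a kept op needs, and a variable that will be fed at run time must no longer pull its producer into the pruned program. Below is a minimal Python sketch of that walk, for illustration only; the real implementation operates on proto::ProgramDesc, and the toy op dicts and variable names ("a", "b", "c") are made up:

# Simplified model of the dependency walk in prune_impl (illustration only).
feed_var_names = {"a"}
ops = [
    {"type": "feed",  "inputs": [],    "outputs": ["a"], "is_target": False},
    {"type": "mul",   "inputs": ["a"], "outputs": ["b"], "is_target": False},
    {"type": "scale", "inputs": ["b"], "outputs": ["c"], "is_target": True},
]

dependent_vars = set()
should_run = []
for op in reversed(ops):
    if op["is_target"] or any(v in dependent_vars for v in op["outputs"]):
        # Keep the op and record its inputs as dependencies, but skip fed
        # variables. This skip is what the patch adds: without it, "a"
        # would enter dependent_vars and keep the feed op alive.
        for v in op["inputs"]:
            if v not in feed_var_names:
                dependent_vars.add(v)
        should_run.append(True)
    else:
        should_run.append(False)
should_run.reverse()
print(should_run)  # [False, True, True]: the producer of fed "a" is pruned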
diff --git a/paddle/fluid/framework/prune.h b/paddle/fluid/framework/prune.h
index 1be7cd25d09..3caa6cde4f4 100644
--- a/paddle/fluid/framework/prune.h
+++ b/paddle/fluid/framework/prune.h
@@ -14,13 +14,17 @@ limitations under the License. */
 
 #pragma once
 
+#include <set>
+#include <string>
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
 namespace framework {
 
-void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output);
+void Prune(const proto::ProgramDesc& input,
+           const std::set<std::string>& feed_var_names,
+           proto::ProgramDesc* output);
 
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/prune_test.cc b/paddle/fluid/framework/prune_test.cc
index 8af7d2d510d..210e61a4dec 100644
--- a/paddle/fluid/framework/prune_test.cc
+++ b/paddle/fluid/framework/prune_test.cc
@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/prune.h"
 
 #include <gtest/gtest.h>
+#include <set>
 
 #include "paddle/fluid/framework/attribute.h"
@@ -58,12 +59,13 @@ TEST(Prune, one_operator) {
 
   f::proto::ProgramDesc *pdesc = program.Proto();
   f::proto::ProgramDesc pruned;
-
-  f::Prune(*pdesc, &pruned);
+  std::set<std::string> feed_var_names = {};
+  f::Prune(*pdesc, feed_var_names, &pruned);
   PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0);
 
+  feed_var_names.insert("a");
   pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true);
-  f::Prune(*pdesc, &pruned);
+  f::Prune(*pdesc, feed_var_names, &pruned);
   PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1);
 }
@@ -81,11 +83,11 @@ TEST(Prune, forward) {
                block);
 
   f::proto::ProgramDesc *pdesc = program.Proto();
-
+  std::set<std::string> feed_var_names = {"a"};
   for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) {
     f::proto::ProgramDesc pruned;
     pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true);
-    f::Prune(*pdesc, &pruned);
+    f::Prune(*pdesc, feed_var_names, &pruned);
     PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1);
   }
 }
@@ -107,7 +109,8 @@ TEST(Prune, multi_input_op) {
   pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true);
 
   f::proto::ProgramDesc pruned;
-  f::Prune(*pdesc, &pruned);
+  std::set<std::string> feed_var_names = {"a0", "a1", "a2"};
+  f::Prune(*pdesc, feed_var_names, &pruned);
   PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4);
 }
@@ -126,7 +129,8 @@ TEST(Prune, multi_output_op) {
   pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
 
   f::proto::ProgramDesc pruned;
-  f::Prune(*pdesc, &pruned);
+  std::set<std::string> feed_var_names = {"a"};
+  f::Prune(*pdesc, feed_var_names, &pruned);
   PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2);
 }
@@ -146,6 +150,7 @@ TEST(Prune, multi_target) {
   pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
 
   f::proto::ProgramDesc pruned;
-  f::Prune(*pdesc, &pruned);
+  std::set<std::string> feed_var_names = {"a"};
+  f::Prune(*pdesc, feed_var_names, &pruned);
   PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3);
 }
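The tests above exercise framework::Prune directly. From Python, the same contract is reached through the core.prune binding changed below: it now takes the feed variable names as a set, alongside the list of [block_idx, op_idx] pairs naming the target ops. A hedged sketch of a direct call, where `main` is assumed to be an already-built fluid.Program and [0, 2] is a hypothetical index pair:

from paddle.fluid import core

# `main` is an existing fluid.Program (construction elided). "x" is the
# variable that will be fed at run time; [0, 2] marks op 2 of block 0 as
# the target, mirroring the targets_idx pairs built in Program._prune.
pruned_desc = core.prune(main.desc, set(["x"]), [[0, 2]])
print(pruned_desc.block(0).op_size())  # number of ops surviving the prune

In practice user code does not call core.prune itself; Program._prune (changed below) builds the index pairs from Operator or Variable targets and wraps the returned desc in a new Program.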
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 2b6ea4575ae..ad9501af6b9 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -749,13 +749,15 @@ All parameter, weight, gradient are variables in Paddle.
 #endif
   m.def("prune", [](const ProgramDesc &origin,
+                    const std::set<std::string> &feeded_var_names,
                     const std::vector<std::array<size_t, 2>> &targets) {
     ProgramDesc prog_with_targets(origin);
+
     for (const auto &t : targets) {
       prog_with_targets.MutableBlock(t[0])->Op(t[1])->SetIsTarget(true);
     }
     proto::ProgramDesc pruned_desc;
-    Prune(*prog_with_targets.Proto(), &pruned_desc);
+    Prune(*prog_with_targets.Proto(), feeded_var_names, &pruned_desc);
     return new ProgramDesc(pruned_desc);
   });
   m.def("empty_var_name",
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 5e2c3394520..805d5381588 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -3247,7 +3247,7 @@ class Program(object):
             p._copy_dist_param_info_from(self)
         return p
 
-    def _prune(self, targets):
+    def _prune(self, feeded_var_names, targets):
         """
         Prune operators and variables which are not needed to generate
         :code:`targets`.
@@ -3263,8 +3263,16 @@ class Program(object):
             Program: A new, pruned program.
         """
 
+        if not isinstance(feeded_var_names, list):
+            feeded_var_names = [feeded_var_names]
         if not isinstance(targets, list):
             targets = [targets]
+
+        for var in feeded_var_names:
+            if not isinstance(var, six.string_types):
+                raise ValueError("All feeded_var_names of prune() can only be "
+                                 "str.")
+
         targets_idx = []
         for t in targets:
             if not isinstance(t, Operator):
@@ -3291,7 +3299,7 @@ class Program(object):
                 targets_idx.append([t.block.idx, t.idx])
 
         res = Program()
-        res.desc = core.prune(self.desc, targets_idx)
+        res.desc = core.prune(self.desc, set(feeded_var_names), targets_idx)
         res.blocks = [
             Block(res, i) for i in six.moves.range(res.desc.num_blocks())
         ]
diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py
index 4412010d7f3..6076b336c39 100644
--- a/python/paddle/fluid/io.py
+++ b/python/paddle/fluid/io.py
@@ -1080,7 +1080,7 @@ def save_inference_model(dirname,
 
     main_program.desc.flush()
 
-    main_program = main_program._prune(targets=target_vars)
+    main_program = main_program._prune(feeded_var_names, target_vars)
     main_program = main_program._inference_optimize(prune_read_op=True)
     fetch_var_names = [v.name for v in target_vars]
-- 
GitLab
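End to end, this change is what makes save_inference_model honor feeded_var_names: an operator whose only purpose is to produce a fed variable no longer survives pruning, so the saved inference program starts from the fed input. A minimal usage sketch against the fluid 1.x API; the network, variable names, and output directory are made-up examples:

import paddle.fluid as fluid

main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.fc(input=x, size=1)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup)

# 'x' is declared as fed, so the pruned program keeps only the ops needed
# to compute `y` from 'x', not any op that originally produced 'x'.
fluid.io.save_inference_model(
    dirname='./infer_model',
    feeded_var_names=['x'],
    target_vars=[y],
    executor=exe,
    main_program=main)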