提交 dca9b6c5 编写于 作者: M mapingshuo 提交者: Dong Daxiang

add feed_var_names to Prune interface (#19589)

* Fix bug: add feed_vars to the prune function
上级 f45cb1c2
...@@ -68,7 +68,8 @@ bool HasSubBlock(const proto::OpDesc& op_desc) { ...@@ -68,7 +68,8 @@ bool HasSubBlock(const proto::OpDesc& op_desc) {
// the child block to help pruning // the child block to help pruning
void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
int block_id, int parent_block_id, int block_id, int parent_block_id,
std::set<std::string>* dependent_vars) { std::set<std::string>* dependent_vars,
const std::set<std::string> feed_var_names) {
auto& block = input.blocks(block_id); auto& block = input.blocks(block_id);
auto& ops = block.ops(); auto& ops = block.ops();
...@@ -94,9 +95,11 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, ...@@ -94,9 +95,11 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
// insert its input to the dependency graph // insert its input to the dependency graph
for (auto& var : op_desc.inputs()) { for (auto& var : op_desc.inputs()) {
for (auto& argu : var.arguments()) { for (auto& argu : var.arguments()) {
if (feed_var_names.count(argu) == 0) {
dependent_vars->insert(argu); dependent_vars->insert(argu);
} }
} }
}
should_run.push_back(true); should_run.push_back(true);
} else { } else {
should_run.push_back(false); should_run.push_back(false);
...@@ -127,18 +130,22 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, ...@@ -127,18 +130,22 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
std::set<std::string> sub_block_dependent_vars; std::set<std::string> sub_block_dependent_vars;
for (auto& var : op->inputs()) { for (auto& var : op->inputs()) {
for (auto& argu : var.arguments()) { for (auto& argu : var.arguments()) {
if (feed_var_names.count(argu) == 0) {
sub_block_dependent_vars.insert(argu); sub_block_dependent_vars.insert(argu);
} }
} }
}
for (auto& var : op->outputs()) { for (auto& var : op->outputs()) {
for (auto& argu : var.arguments()) { for (auto& argu : var.arguments()) {
if (feed_var_names.count(argu) == 0) {
sub_block_dependent_vars.insert(argu); sub_block_dependent_vars.insert(argu);
} }
} }
}
// GetSubBlockIndex(*op) is the idx of the sub_block in the input desc // GetSubBlockIndex(*op) is the idx of the sub_block in the input desc
// output_block_id is the idx of the current block in the output desc // output_block_id is the idx of the current block in the output desc
prune_impl(input, output, GetSubBlockIndex(*op), output_block_id, prune_impl(input, output, GetSubBlockIndex(*op), output_block_id,
&sub_block_dependent_vars); &sub_block_dependent_vars, feed_var_names);
} }
} }
} }
...@@ -178,10 +185,12 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, ...@@ -178,10 +185,12 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
} }
// TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies // TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies
void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output) { void Prune(const proto::ProgramDesc& input,
const std::set<std::string>& feed_var_names,
proto::ProgramDesc* output) {
std::set<std::string> dependent_vars; std::set<std::string> dependent_vars;
output->clear_blocks(); output->clear_blocks();
prune_impl(input, output, 0, -1, &dependent_vars); prune_impl(input, output, 0, -1, &dependent_vars, feed_var_names);
} }
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -14,13 +14,17 @@ limitations under the License. */ ...@@ -14,13 +14,17 @@ limitations under the License. */
#pragma once #pragma once
#include <set>
#include <string>
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output); void Prune(const proto::ProgramDesc& input,
const std::set<std::string>& feed_var_names,
proto::ProgramDesc* output);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/prune.h" #include "paddle/fluid/framework/prune.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <set>
#include <string> #include <string>
#include "paddle/fluid/framework/attribute.h" #include "paddle/fluid/framework/attribute.h"
...@@ -58,12 +59,13 @@ TEST(Prune, one_operator) { ...@@ -58,12 +59,13 @@ TEST(Prune, one_operator) {
f::proto::ProgramDesc *pdesc = program.Proto(); f::proto::ProgramDesc *pdesc = program.Proto();
f::proto::ProgramDesc pruned; f::proto::ProgramDesc pruned;
std::set<std::string> feed_var_names = {};
f::Prune(*pdesc, &pruned); f::Prune(*pdesc, feed_var_names, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0);
feed_var_names.insert("a");
pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true); pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true);
f::Prune(*pdesc, &pruned); f::Prune(*pdesc, feed_var_names, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1);
} }
...@@ -81,11 +83,11 @@ TEST(Prune, forward) { ...@@ -81,11 +83,11 @@ TEST(Prune, forward) {
block); block);
f::proto::ProgramDesc *pdesc = program.Proto(); f::proto::ProgramDesc *pdesc = program.Proto();
std::set<std::string> feed_var_names = {"a"};
for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) { for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) {
f::proto::ProgramDesc pruned; f::proto::ProgramDesc pruned;
pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true); pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true);
f::Prune(*pdesc, &pruned); f::Prune(*pdesc, feed_var_names, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1);
} }
} }
...@@ -107,7 +109,8 @@ TEST(Prune, multi_input_op) { ...@@ -107,7 +109,8 @@ TEST(Prune, multi_input_op) {
pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true); pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true);
f::proto::ProgramDesc pruned; f::proto::ProgramDesc pruned;
f::Prune(*pdesc, &pruned); std::set<std::string> feed_var_names = {"a0", "a1", "a2"};
f::Prune(*pdesc, feed_var_names, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4);
} }
...@@ -126,7 +129,8 @@ TEST(Prune, multi_output_op) { ...@@ -126,7 +129,8 @@ TEST(Prune, multi_output_op) {
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true); pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
f::proto::ProgramDesc pruned; f::proto::ProgramDesc pruned;
f::Prune(*pdesc, &pruned); std::set<std::string> feed_var_names = {"a"};
f::Prune(*pdesc, feed_var_names, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2);
} }
...@@ -146,6 +150,7 @@ TEST(Prune, multi_target) { ...@@ -146,6 +150,7 @@ TEST(Prune, multi_target) {
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true); pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
f::proto::ProgramDesc pruned; f::proto::ProgramDesc pruned;
f::Prune(*pdesc, &pruned); std::set<std::string> feed_var_names = {"a"};
f::Prune(*pdesc, feed_var_names, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3);
} }
...@@ -749,13 +749,15 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -749,13 +749,15 @@ All parameter, weight, gradient are variables in Paddle.
#endif #endif
m.def("prune", [](const ProgramDesc &origin, m.def("prune", [](const ProgramDesc &origin,
const std::set<std::string> &feeded_var_names,
const std::vector<std::array<size_t, 2>> &targets) { const std::vector<std::array<size_t, 2>> &targets) {
ProgramDesc prog_with_targets(origin); ProgramDesc prog_with_targets(origin);
for (const auto &t : targets) { for (const auto &t : targets) {
prog_with_targets.MutableBlock(t[0])->Op(t[1])->SetIsTarget(true); prog_with_targets.MutableBlock(t[0])->Op(t[1])->SetIsTarget(true);
} }
proto::ProgramDesc pruned_desc; proto::ProgramDesc pruned_desc;
Prune(*prog_with_targets.Proto(), &pruned_desc); Prune(*prog_with_targets.Proto(), feeded_var_names, &pruned_desc);
return new ProgramDesc(pruned_desc); return new ProgramDesc(pruned_desc);
}); });
m.def("empty_var_name", m.def("empty_var_name",
......
...@@ -3247,7 +3247,7 @@ class Program(object): ...@@ -3247,7 +3247,7 @@ class Program(object):
p._copy_dist_param_info_from(self) p._copy_dist_param_info_from(self)
return p return p
def _prune(self, targets): def _prune(self, feeded_var_names, targets):
""" """
Prune operators and variables which are not needed to generate Prune operators and variables which are not needed to generate
:code:`targets`. :code:`targets`.
...@@ -3263,8 +3263,16 @@ class Program(object): ...@@ -3263,8 +3263,16 @@ class Program(object):
Program: A new, pruned program. Program: A new, pruned program.
""" """
if not isinstance(feeded_var_names, list):
feeded_var_names = [feeded_var_names]
if not isinstance(targets, list): if not isinstance(targets, list):
targets = [targets] targets = [targets]
for var in feeded_var_names:
if not isinstance(var, six.string_types):
raise ValueError("All feeded_var_names of prune() can only be "
"str.")
targets_idx = [] targets_idx = []
for t in targets: for t in targets:
if not isinstance(t, Operator): if not isinstance(t, Operator):
...@@ -3291,7 +3299,7 @@ class Program(object): ...@@ -3291,7 +3299,7 @@ class Program(object):
targets_idx.append([t.block.idx, t.idx]) targets_idx.append([t.block.idx, t.idx])
res = Program() res = Program()
res.desc = core.prune(self.desc, targets_idx) res.desc = core.prune(self.desc, set(feeded_var_names), targets_idx)
res.blocks = [ res.blocks = [
Block(res, i) for i in six.moves.range(res.desc.num_blocks()) Block(res, i) for i in six.moves.range(res.desc.num_blocks())
] ]
......
...@@ -1080,7 +1080,7 @@ def save_inference_model(dirname, ...@@ -1080,7 +1080,7 @@ def save_inference_model(dirname,
main_program.desc.flush() main_program.desc.flush()
main_program = main_program._prune(targets=target_vars) main_program = main_program._prune(feeded_var_names, target_vars)
main_program = main_program._inference_optimize(prune_read_op=True) main_program = main_program._inference_optimize(prune_read_op=True)
fetch_var_names = [v.name for v in target_vars] fetch_var_names = [v.name for v in target_vars]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册