Commit dca9b6c5 authored by mapingshuo, committed by Dong Daxiang

add feed_var_names to Prune interface (#19589)

* Fix bug: add feed_vars to the prune function
Parent commit: f45cb1c2
......@@ -68,7 +68,8 @@ bool HasSubBlock(const proto::OpDesc& op_desc) {
// the child block to help pruning
void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
int block_id, int parent_block_id,
std::set<std::string>* dependent_vars) {
std::set<std::string>* dependent_vars,
const std::set<std::string> feed_var_names) {
auto& block = input.blocks(block_id);
auto& ops = block.ops();
......@@ -94,7 +95,9 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
// insert its input to the dependency graph
for (auto& var : op_desc.inputs()) {
for (auto& argu : var.arguments()) {
dependent_vars->insert(argu);
if (feed_var_names.count(argu) == 0) {
dependent_vars->insert(argu);
}
}
}
should_run.push_back(true);
......@@ -127,18 +130,22 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
std::set<std::string> sub_block_dependent_vars;
for (auto& var : op->inputs()) {
for (auto& argu : var.arguments()) {
sub_block_dependent_vars.insert(argu);
if (feed_var_names.count(argu) == 0) {
sub_block_dependent_vars.insert(argu);
}
}
}
for (auto& var : op->outputs()) {
for (auto& argu : var.arguments()) {
sub_block_dependent_vars.insert(argu);
if (feed_var_names.count(argu) == 0) {
sub_block_dependent_vars.insert(argu);
}
}
}
// GetSubBlockIndex(*op) is the idx of the sub_block in the input desc
// output_block_id is the idx of the current block in the output desc
prune_impl(input, output, GetSubBlockIndex(*op), output_block_id,
&sub_block_dependent_vars);
&sub_block_dependent_vars, feed_var_names);
}
}
}
......@@ -178,10 +185,12 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
}
// TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies
// NOTE(review): rendered diff text — the first signature line below is the
// pre-change version and the following three lines are its replacement.
void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output) {
// New entry point: prune `input` down to the ops required by the target ops,
// treating names in `feed_var_names` as externally supplied, so their
// producer ops are not pulled into the dependency set (see prune_impl's
// feed_var_names.count(...) guards elsewhere in this diff).
void Prune(const proto::ProgramDesc& input,
const std::set<std::string>& feed_var_names,
proto::ProgramDesc* output) {
// Start with an empty dependency set; prune_impl populates it while walking
// block 0 (the -1 presumably means "no parent block" — confirm in prune_impl).
std::set<std::string> dependent_vars;
output->clear_blocks();
// Old call (pre-change) followed by the new call forwarding feed_var_names.
prune_impl(input, output, 0, -1, &dependent_vars);
prune_impl(input, output, 0, -1, &dependent_vars, feed_var_names);
}
} // namespace framework
} // namespace paddle
......@@ -14,13 +14,17 @@ limitations under the License. */
#pragma once
#include <set>
#include <string>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output);
void Prune(const proto::ProgramDesc& input,
const std::set<std::string>& feed_var_names,
proto::ProgramDesc* output);
} // namespace framework
} // namespace paddle
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/prune.h"
#include <gtest/gtest.h>
#include <set>
#include <string>
#include "paddle/fluid/framework/attribute.h"
......@@ -58,12 +59,13 @@ TEST(Prune, one_operator) {
f::proto::ProgramDesc *pdesc = program.Proto();
f::proto::ProgramDesc pruned;
f::Prune(*pdesc, &pruned);
std::set<std::string> feed_var_names = {};
f::Prune(*pdesc, feed_var_names, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0);
feed_var_names.insert("a");
pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true);
f::Prune(*pdesc, &pruned);
f::Prune(*pdesc, feed_var_names, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1);
}
......@@ -81,11 +83,11 @@ TEST(Prune, forward) {
block);
f::proto::ProgramDesc *pdesc = program.Proto();
std::set<std::string> feed_var_names = {"a"};
for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) {
f::proto::ProgramDesc pruned;
pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true);
f::Prune(*pdesc, &pruned);
f::Prune(*pdesc, feed_var_names, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1);
}
}
......@@ -107,7 +109,8 @@ TEST(Prune, multi_input_op) {
pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true);
f::proto::ProgramDesc pruned;
f::Prune(*pdesc, &pruned);
std::set<std::string> feed_var_names = {"a0", "a1", "a2"};
f::Prune(*pdesc, feed_var_names, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4);
}
......@@ -126,7 +129,8 @@ TEST(Prune, multi_output_op) {
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
f::proto::ProgramDesc pruned;
f::Prune(*pdesc, &pruned);
std::set<std::string> feed_var_names = {"a"};
f::Prune(*pdesc, feed_var_names, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2);
}
......@@ -146,6 +150,7 @@ TEST(Prune, multi_target) {
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
f::proto::ProgramDesc pruned;
f::Prune(*pdesc, &pruned);
std::set<std::string> feed_var_names = {"a"};
f::Prune(*pdesc, feed_var_names, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3);
}
......@@ -749,13 +749,15 @@ All parameter, weight, gradient are variables in Paddle.
#endif
// Python binding: core.prune(program_desc, feeded_var_names, targets).
// `targets` is a list of [block_idx, op_idx] pairs; each referenced op is
// flagged as a prune target before framework::Prune is invoked.
m.def("prune", [](const ProgramDesc &origin,
const std::set<std::string> &feeded_var_names,
const std::vector<std::array<size_t, 2>> &targets) {
// Operate on a copy so the caller's ProgramDesc is left unmodified.
ProgramDesc prog_with_targets(origin);
for (const auto &t : targets) {
prog_with_targets.MutableBlock(t[0])->Op(t[1])->SetIsTarget(true);
}
proto::ProgramDesc pruned_desc;
// NOTE(review): diff juxtaposition — old call followed by the new call that
// forwards the feed-variable names into the prune pass.
Prune(*prog_with_targets.Proto(), &pruned_desc);
Prune(*prog_with_targets.Proto(), feeded_var_names, &pruned_desc);
// Hand a freshly constructed ProgramDesc to the Python side.
return new ProgramDesc(pruned_desc);
});
m.def("empty_var_name",
......
......@@ -3247,7 +3247,7 @@ class Program(object):
p._copy_dist_param_info_from(self)
return p
def _prune(self, targets):
def _prune(self, feeded_var_names, targets):
"""
Prune operators and variables which are not needed to generate
:code:`targets`.
......@@ -3263,8 +3263,16 @@ class Program(object):
Program: A new, pruned program.
"""
if not isinstance(feeded_var_names, list):
feeded_var_names = [feeded_var_names]
if not isinstance(targets, list):
targets = [targets]
for var in feeded_var_names:
if not isinstance(var, six.string_types):
raise ValueError("All feeded_var_names of prune() can only be "
"str.")
targets_idx = []
for t in targets:
if not isinstance(t, Operator):
......@@ -3291,7 +3299,7 @@ class Program(object):
targets_idx.append([t.block.idx, t.idx])
res = Program()
res.desc = core.prune(self.desc, targets_idx)
res.desc = core.prune(self.desc, set(feeded_var_names), targets_idx)
res.blocks = [
Block(res, i) for i in six.moves.range(res.desc.num_blocks())
]
......
......@@ -1080,7 +1080,7 @@ def save_inference_model(dirname,
main_program.desc.flush()
main_program = main_program._prune(targets=target_vars)
main_program = main_program._prune(feeded_var_names, target_vars)
main_program = main_program._inference_optimize(prune_read_op=True)
fetch_var_names = [v.name for v in target_vars]
......
Markdown is supported.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment.