提交 bdca4b37 编写于 作者: Y Yang Yang

change api based on design doc

上级 e0cee58c
......@@ -46,7 +46,7 @@ bool IsTarget(const OpDesc& op_desc) {
return false;
}
void Prune(const ProgramDesc& input, ProgramDesc& output, int block_id) {
void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) {
// TODO(tonyyang-svail):
// - will change to use multiple blocks for RNN op and Cond Op
......@@ -99,8 +99,10 @@ void Prune(const ProgramDesc& input, ProgramDesc& output, int block_id) {
*op_field->Add() = input.blocks(block_id).ops(i);
}
}
}
// return should_run;
void Prune(const ProgramDesc& input, ProgramDesc& output) {
prune_impl(input, output, 0);
}
} // namespace framework
......
......@@ -20,7 +20,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
void Prune(const ProgramDesc& input, ProgramDesc& output, int block_id);
void Prune(const ProgramDesc& input, ProgramDesc& output);
} // namespace framework
} // namespace paddle
......@@ -68,11 +68,11 @@ TEST(Prune, one_operator) {
f::ProgramDesc *pdesc = program.Proto();
f::ProgramDesc pruned;
Prune(*pdesc, pruned, 0);
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0);
pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true);
Prune(*pdesc, pruned, 0);
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1);
}
......@@ -91,7 +91,7 @@ TEST(Prune, forward) {
for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) {
f::ProgramDesc pruned;
pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true);
Prune(*pdesc, pruned, 0);
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1);
}
}
......@@ -111,7 +111,7 @@ TEST(Prune, multi_input_op) {
pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true);
f::ProgramDesc pruned;
Prune(*pdesc, pruned, 0);
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4);
}
......@@ -128,7 +128,7 @@ TEST(Prune, multi_output_op) {
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
f::ProgramDesc pruned;
Prune(*pdesc, pruned, 0);
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2);
}
......@@ -146,6 +146,6 @@ TEST(Prune, multi_target) {
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
f::ProgramDesc pruned;
Prune(*pdesc, pruned, 0);
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册