提交 bdca4b37 编写于 作者: Y Yang Yang

change api based on design doc

上级 e0cee58c
...@@ -46,7 +46,7 @@ bool IsTarget(const OpDesc& op_desc) { ...@@ -46,7 +46,7 @@ bool IsTarget(const OpDesc& op_desc) {
return false; return false;
} }
void Prune(const ProgramDesc& input, ProgramDesc& output, int block_id) { void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) {
// TODO(tonyyang-svail): // TODO(tonyyang-svail):
// - will change to use multiple blocks for RNN op and Cond Op // - will change to use multiple blocks for RNN op and Cond Op
...@@ -99,8 +99,10 @@ void Prune(const ProgramDesc& input, ProgramDesc& output, int block_id) { ...@@ -99,8 +99,10 @@ void Prune(const ProgramDesc& input, ProgramDesc& output, int block_id) {
*op_field->Add() = input.blocks(block_id).ops(i); *op_field->Add() = input.blocks(block_id).ops(i);
} }
} }
}
// return should_run; void Prune(const ProgramDesc& input, ProgramDesc& output) {
prune_impl(input, output, 0);
} }
} // namespace framework } // namespace framework
......
...@@ -20,7 +20,7 @@ limitations under the License. */ ...@@ -20,7 +20,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void Prune(const ProgramDesc& input, ProgramDesc& output, int block_id); void Prune(const ProgramDesc& input, ProgramDesc& output);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -68,11 +68,11 @@ TEST(Prune, one_operator) { ...@@ -68,11 +68,11 @@ TEST(Prune, one_operator) {
f::ProgramDesc *pdesc = program.Proto(); f::ProgramDesc *pdesc = program.Proto();
f::ProgramDesc pruned; f::ProgramDesc pruned;
Prune(*pdesc, pruned, 0); Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0);
pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true); pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true);
Prune(*pdesc, pruned, 0); Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1);
} }
...@@ -91,7 +91,7 @@ TEST(Prune, forward) { ...@@ -91,7 +91,7 @@ TEST(Prune, forward) {
for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) { for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) {
f::ProgramDesc pruned; f::ProgramDesc pruned;
pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true); pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true);
Prune(*pdesc, pruned, 0); Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1);
} }
} }
...@@ -111,7 +111,7 @@ TEST(Prune, multi_input_op) { ...@@ -111,7 +111,7 @@ TEST(Prune, multi_input_op) {
pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true); pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true);
f::ProgramDesc pruned; f::ProgramDesc pruned;
Prune(*pdesc, pruned, 0); Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4);
} }
...@@ -128,7 +128,7 @@ TEST(Prune, multi_output_op) { ...@@ -128,7 +128,7 @@ TEST(Prune, multi_output_op) {
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true); pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
f::ProgramDesc pruned; f::ProgramDesc pruned;
Prune(*pdesc, pruned, 0); Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2);
} }
...@@ -146,6 +146,6 @@ TEST(Prune, multi_target) { ...@@ -146,6 +146,6 @@ TEST(Prune, multi_target) {
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true); pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
f::ProgramDesc pruned; f::ProgramDesc pruned;
Prune(*pdesc, pruned, 0); Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册