diff --git a/paddle/framework/prune.cc b/paddle/framework/prune.cc index b08e0116b7db6a37ebc45f0976968f0f5c79e950..95833692925af4477fe575d6bd908a2ce7653c1b 100644 --- a/paddle/framework/prune.cc +++ b/paddle/framework/prune.cc @@ -46,7 +46,7 @@ bool IsTarget(const OpDesc& op_desc) { return false; } -void Prune(const ProgramDesc& input, ProgramDesc& output, int block_id) { +void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) { // TODO(tonyyang-svail): // - will change to use multiple blocks for RNN op and Cond Op @@ -99,8 +99,10 @@ void Prune(const ProgramDesc& input, ProgramDesc& output, int block_id) { *op_field->Add() = input.blocks(block_id).ops(i); } } +} - // return should_run; +void Prune(const ProgramDesc& input, ProgramDesc& output) { + prune_impl(input, output, 0); } } // namespace framework diff --git a/paddle/framework/prune.h b/paddle/framework/prune.h index 1c74d3b763b778cfef176a67003a66cb008e42b6..9414ac64f9491c07aabb216a4c81dfe6e78e8043 100644 --- a/paddle/framework/prune.h +++ b/paddle/framework/prune.h @@ -20,7 +20,7 @@ limitations under the License. */ namespace paddle { namespace framework { -void Prune(const ProgramDesc& input, ProgramDesc& output, int block_id); +void Prune(const ProgramDesc& input, ProgramDesc& output); } // namespace framework } // namespace paddle diff --git a/paddle/framework/prune_test.cc b/paddle/framework/prune_test.cc index dc066facb2f0c66dfb821d14acd1fbe994e5ce84..a8faf1891ed4aaa77aa29ea971bb3f6ce12e1145 100644 --- a/paddle/framework/prune_test.cc +++ b/paddle/framework/prune_test.cc @@ -68,11 +68,11 @@ TEST(Prune, one_operator) { f::ProgramDesc *pdesc = program.Proto(); f::ProgramDesc pruned; - Prune(*pdesc, pruned, 0); + Prune(*pdesc, pruned); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0); pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true); - Prune(*pdesc, pruned, 0); + Prune(*pdesc, pruned); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1); } @@ -91,7 +91,7 @@ TEST(Prune, forward) { for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) { f::ProgramDesc pruned; pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true); - Prune(*pdesc, pruned, 0); + Prune(*pdesc, pruned); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1); } } @@ -111,7 +111,7 @@ TEST(Prune, multi_input_op) { pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true); f::ProgramDesc pruned; - Prune(*pdesc, pruned, 0); + Prune(*pdesc, pruned); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4); } @@ -128,7 +128,7 @@ TEST(Prune, multi_output_op) { pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true); f::ProgramDesc pruned; - Prune(*pdesc, pruned, 0); + Prune(*pdesc, pruned); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2); } @@ -146,6 +146,6 @@ TEST(Prune, multi_target) { pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true); f::ProgramDesc pruned; - Prune(*pdesc, pruned, 0); + Prune(*pdesc, pruned); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3); }