From bdca4b37c434b26b2c6ae300899a1c562a82e133 Mon Sep 17 00:00:00 2001 From: Yang Yang Date: Tue, 17 Oct 2017 02:58:08 +0000 Subject: [PATCH] change api based on design doc --- paddle/framework/prune.cc | 6 ++++-- paddle/framework/prune.h | 2 +- paddle/framework/prune_test.cc | 12 ++++++------ 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/paddle/framework/prune.cc b/paddle/framework/prune.cc index b08e0116b..958336929 100644 --- a/paddle/framework/prune.cc +++ b/paddle/framework/prune.cc @@ -46,7 +46,7 @@ bool IsTarget(const OpDesc& op_desc) { return false; } -void Prune(const ProgramDesc& input, ProgramDesc& output, int block_id) { +void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) { // TODO(tonyyang-svail): // - will change to use multiple blocks for RNN op and Cond Op @@ -99,8 +99,10 @@ void Prune(const ProgramDesc& input, ProgramDesc& output, int block_id) { *op_field->Add() = input.blocks(block_id).ops(i); } } +} - // return should_run; +void Prune(const ProgramDesc& input, ProgramDesc& output) { + prune_impl(input, output, 0); } } // namespace framework diff --git a/paddle/framework/prune.h b/paddle/framework/prune.h index 1c74d3b76..9414ac64f 100644 --- a/paddle/framework/prune.h +++ b/paddle/framework/prune.h @@ -20,7 +20,7 @@ limitations under the License. */ namespace paddle { namespace framework { -void Prune(const ProgramDesc& input, ProgramDesc& output, int block_id); +void Prune(const ProgramDesc& input, ProgramDesc& output); } // namespace framework } // namespace paddle diff --git a/paddle/framework/prune_test.cc b/paddle/framework/prune_test.cc index dc066facb..a8faf1891 100644 --- a/paddle/framework/prune_test.cc +++ b/paddle/framework/prune_test.cc @@ -68,11 +68,11 @@ TEST(Prune, one_operator) { f::ProgramDesc *pdesc = program.Proto(); f::ProgramDesc pruned; - Prune(*pdesc, pruned, 0); + Prune(*pdesc, pruned); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0); pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true); - Prune(*pdesc, pruned, 0); + Prune(*pdesc, pruned); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1); } @@ -91,7 +91,7 @@ TEST(Prune, forward) { for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) { f::ProgramDesc pruned; pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true); - Prune(*pdesc, pruned, 0); + Prune(*pdesc, pruned); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1); } } @@ -111,7 +111,7 @@ TEST(Prune, multi_input_op) { pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true); f::ProgramDesc pruned; - Prune(*pdesc, pruned, 0); + Prune(*pdesc, pruned); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4); } @@ -128,7 +128,7 @@ TEST(Prune, multi_output_op) { pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true); f::ProgramDesc pruned; - Prune(*pdesc, pruned, 0); + Prune(*pdesc, pruned); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2); } @@ -146,6 +146,6 @@ TEST(Prune, multi_target) { pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true); f::ProgramDesc pruned; - Prune(*pdesc, pruned, 0); + Prune(*pdesc, pruned); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3); } -- GitLab