From 6b051b651ae72305d9877fd3cd094028c21bdddb Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 2 Oct 2017 14:24:03 -0700 Subject: [PATCH] optimize code --- paddle/operators/recurrent_op.cc | 38 ++++++++++++---------- paddle/operators/recurrent_op.h | 4 +-- paddle/operators/rnn/recurrent_op_utils.cc | 8 ++--- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc index 016e2043fd..bcd6a3410a 100644 --- a/paddle/operators/recurrent_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -32,24 +32,25 @@ void RecurrentAlgorithm::Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const { auto* input0 = scope.FindVar(arg_->inlinks[0]); PADDLE_ENFORCE_NOT_NULL(input0); - seq_len_ = input0->GetMutable()->dims()[0]; - PADDLE_ENFORCE_GT(seq_len_, 0); + size_t seq_len = input0->GetMutable()->dims()[0]; + PADDLE_ENFORCE_GT(seq_len, 0); - CreateScopes(scope); + CreateScopes(scope, seq_len); auto& step_scopes = GetStepScopes(scope); - rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_); + rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len); InitMemories(step_scopes[0]); - for (size_t i = 0; i < seq_len_; i++) { - if (i > 0) { - rnn::LinkMemories(step_scopes, arg_->memories, i, -1); + for (size_t step_id = 0; step_id < seq_len; step_id++) { + if (step_id > 0) { + rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1); } - (*stepnet_)->Run(*step_scopes[i], dev_ctx); + (*stepnet_)->Run(*step_scopes[step_id], dev_ctx); } - rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_); + rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len); } -void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { +void RecurrentAlgorithm::CreateScopes(const Scope& scope, + size_t seq_len) const { // TODO(superjom) Only two scopes are needed for inference, this case will be // supported later. auto step_scopes_var = scope.FindVar(arg_->step_scopes); @@ -60,8 +61,8 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { PADDLE_ENFORCE_NOT_NULL(stepnet_); PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "stepnet_ op has no outputs"); - if (seq_len_ > step_scopes->size()) { - for (size_t i = step_scopes->size(); i < seq_len_; ++i) { + if (seq_len > step_scopes->size()) { + for (size_t i = step_scopes->size(); i < seq_len; ++i) { auto& step_scope = scope.NewScope(); // create step net's temp inputs @@ -144,17 +145,18 @@ class RecurrentAlgorithmProtoAndCheckerMaker void RecurrentGradientAlgorithm::Run( const Scope& scope, const platform::DeviceContext& dev_ctx) const { - seq_len_ = - scope.FindVar(arg_->inlinks[0])->GetMutable()->dims()[0]; + auto* input0 = scope.FindVar(arg_->inlinks[0]); + PADDLE_ENFORCE_NOT_NULL(input0); + size_t seq_len = input0->GetMutable()->dims()[0]; auto step_scopes = GetStepScopes(scope); - rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_); - for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) { - if (static_cast(step_id) != seq_len_ - 1) { + rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len); + for (int step_id = seq_len - 1; step_id >= 0; --step_id) { + if (step_id != seq_len - 1) { rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1); } (*stepnet_)->Run(*step_scopes[step_id], dev_ctx); } - rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_); + rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len); LinkBootMemoryGradients(step_scopes[0]); } diff --git a/paddle/operators/recurrent_op.h b/paddle/operators/recurrent_op.h index 752025e42c..253d7e3284 100644 --- a/paddle/operators/recurrent_op.h +++ b/paddle/operators/recurrent_op.h @@ -48,7 +48,7 @@ class RecurrentAlgorithm { * NOTE the scopes are reused in both the forward and backward, so just * create once and expand its size if more steps need. */ - void CreateScopes(const framework::Scope& scope) const; + void CreateScopes(const framework::Scope& scope, size_t seq_len) const; const std::vector& GetStepScopes( const framework::Scope& scope) const { @@ -61,7 +61,6 @@ class RecurrentAlgorithm { private: std::unique_ptr* stepnet_; rnn::Argument* arg_; - mutable size_t seq_len_; }; class RecurrentGradientAlgorithm { @@ -97,7 +96,6 @@ class RecurrentGradientAlgorithm { private: rnn::Argument* arg_; - mutable size_t seq_len_; std::unique_ptr* stepnet_; }; diff --git a/paddle/operators/rnn/recurrent_op_utils.cc b/paddle/operators/rnn/recurrent_op_utils.cc index a02994f99d..a37d21d480 100644 --- a/paddle/operators/rnn/recurrent_op_utils.cc +++ b/paddle/operators/rnn/recurrent_op_utils.cc @@ -53,12 +53,12 @@ void ConcatOutputs(const std::vector& step_scopes, const std::vector& outlinks, const size_t seq_len) { for (size_t i = 0; i < outlinks.size(); i++) { - auto output_var = step_scopes[0]->parent().FindVar(outlinks[i]); + auto* output_var = step_scopes[0]->parent().FindVar(outlinks[i]); PADDLE_ENFORCE_NOT_NULL(output_var, "output link [%s] is not in scope.", outlinks[i]); LoDTensor* output = output_var->GetMutable(); - auto step_scope_var = step_scopes[0]->FindVar(outlinks[i]); + auto* step_scope_var = step_scopes[0]->FindVar(outlinks[i]); PADDLE_ENFORCE_NOT_NULL(step_scope_var, "%s not in scope", outlinks[i]); f::DDim step_dims = step_scope_var->template GetMutable()->dims(); @@ -89,8 +89,8 @@ void LinkMemories(const std::vector& scopes, step_id + offset, scopes.size(), "offset [%d] is out of range, it must be less than (%d - %d)", offset, scopes.size(), step_id); - auto scope = scopes[step_id]; - auto linked_scope = scopes[step_id + offset]; + auto* scope = scopes[step_id]; + auto* linked_scope = scopes[step_id + offset]; for (auto& attr : memories) { auto mem = scope->FindVar(attr.pre_var)->GetMutable(); auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable(); -- GitLab