Commit 6b051b65 authored by qiaolongfei

optimize code

Parent 32f5c9dd
@@ -32,24 +32,25 @@ void RecurrentAlgorithm::Run(const Scope& scope,
                              const platform::DeviceContext& dev_ctx) const {
   auto* input0 = scope.FindVar(arg_->inlinks[0]);
   PADDLE_ENFORCE_NOT_NULL(input0);
-  seq_len_ = input0->GetMutable<LoDTensor>()->dims()[0];
-  PADDLE_ENFORCE_GT(seq_len_, 0);
-  CreateScopes(scope);
+  size_t seq_len = input0->GetMutable<LoDTensor>()->dims()[0];
+  PADDLE_ENFORCE_GT(seq_len, 0);
+  CreateScopes(scope, seq_len);
   auto& step_scopes = GetStepScopes(scope);
-  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
+  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len);
   InitMemories(step_scopes[0]);
-  for (size_t i = 0; i < seq_len_; i++) {
-    if (i > 0) {
-      rnn::LinkMemories(step_scopes, arg_->memories, i, -1);
+  for (size_t step_id = 0; step_id < seq_len; step_id++) {
+    if (step_id > 0) {
+      rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1);
     }
-    (*stepnet_)->Run(*step_scopes[i], dev_ctx);
+    (*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
   }
-  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_);
+  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len);
 }
 
-void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
+void RecurrentAlgorithm::CreateScopes(const Scope& scope,
+                                      size_t seq_len) const {
   // TODO(superjom) Only two scopes are needed for inference, this case will be
   // supported later.
   auto step_scopes_var = scope.FindVar(arg_->step_scopes);
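The core of the change is visible in this first hunk: the sequence length used to be cached in a `mutable size_t seq_len_` member that the `const` method `Run` overwrote on every call; it is now a local `seq_len` passed explicitly to `CreateScopes`. A minimal sketch of why that matters, using hypothetical stand-in types rather than the actual Paddle classes:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical stand-ins for illustration only; not Paddle types.
struct Before {
  // A mutable member written inside a const method: two threads calling
  // Run() on the same instance would race on seq_len_.
  mutable size_t seq_len_ = 0;

  void Run(const std::vector<int>& input) const {
    seq_len_ = input.size();  // hidden state mutation inside a const call
    Step(seq_len_);
  }
  void Step(size_t seq_len) const { std::cout << seq_len << " steps\n"; }
};

struct After {
  // seq_len lives on Run()'s stack and is threaded through explicitly,
  // so the object stays genuinely immutable during a const call.
  void Run(const std::vector<int>& input) const {
    size_t seq_len = input.size();
    Step(seq_len);
  }
  void Step(size_t seq_len) const { std::cout << seq_len << " steps\n"; }
};

int main() {
  std::vector<int> input{1, 2, 3};
  Before{}.Run(input);  // works, but Run() is const in name only
  After{}.Run(input);   // re-entrant: no per-call state on the object
}
```

Besides thread-safety, the new `CreateScopes(scope, seq_len)` signature states its dependency outright instead of relying on a member the caller must remember to set first.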
@@ -60,8 +61,8 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
   PADDLE_ENFORCE_NOT_NULL(stepnet_);
   PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "stepnet_ op has no outputs");
 
-  if (seq_len_ > step_scopes->size()) {
-    for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
+  if (seq_len > step_scopes->size()) {
+    for (size_t i = step_scopes->size(); i < seq_len; ++i) {
       auto& step_scope = scope.NewScope();
 
       // create step net's temp inputs
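`CreateScopes` keeps a grow-only cache: step scopes survive across calls, and the loop above allocates only the missing tail when a longer sequence arrives (per the TODO, inference could eventually get by with just two scopes). A self-contained sketch of the same reuse pattern; `StepScope` and `EnsureScopes` are hypothetical names, not Paddle API:

```cpp
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

// Hypothetical stand-in for a per-timestep scope.
struct StepScope {
  explicit StepScope(size_t id) : id(id) {}
  size_t id;
};

// Grow-only cache: reuse scopes created by earlier calls and append
// new ones only when the current sequence is longer than any seen so far.
void EnsureScopes(std::vector<std::unique_ptr<StepScope>>* scopes,
                  size_t seq_len) {
  if (seq_len > scopes->size()) {
    for (size_t i = scopes->size(); i < seq_len; ++i) {
      scopes->push_back(std::make_unique<StepScope>(i));
    }
  }
}

int main() {
  std::vector<std::unique_ptr<StepScope>> scopes;
  EnsureScopes(&scopes, 3);  // allocates scopes 0..2
  EnsureScopes(&scopes, 2);  // no-op: a shorter sequence reuses the cache
  EnsureScopes(&scopes, 5);  // allocates only scopes 3 and 4
  std::cout << scopes.size() << "\n";  // prints 5
}
```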
@@ -144,17 +145,18 @@ class RecurrentAlgorithmProtoAndCheckerMaker
 
 void RecurrentGradientAlgorithm::Run(
     const Scope& scope, const platform::DeviceContext& dev_ctx) const {
-  seq_len_ =
-      scope.FindVar(arg_->inlinks[0])->GetMutable<LoDTensor>()->dims()[0];
+  auto* input0 = scope.FindVar(arg_->inlinks[0]);
+  PADDLE_ENFORCE_NOT_NULL(input0);
+  size_t seq_len = input0->GetMutable<LoDTensor>()->dims()[0];
   auto step_scopes = GetStepScopes(scope);
-  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
-  for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
-    if (static_cast<size_t>(step_id) != seq_len_ - 1) {
+  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len);
+  for (int step_id = seq_len - 1; step_id >= 0; --step_id) {
+    if (step_id != seq_len - 1) {
       rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
     }
     (*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
   }
-  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_);
+  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len);
   LinkBootMemoryGradients(step_scopes[0]);
 }
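The gradient pass mirrors the forward pass but walks the steps in reverse, and `LinkMemories` is called with a `+1` offset so each step reads state gradients from its successor. Note that the loop counter stays a signed `int` even though `seq_len` is a `size_t`; a small sketch of why, in plain C++ with no Paddle types:

```cpp
#include <cstddef>
#include <iostream>

int main() {
  const size_t seq_len = 3;

  // Signed counter, as in RecurrentGradientAlgorithm::Run above: the loop
  // ends cleanly once step_id drops below zero.
  for (int step_id = static_cast<int>(seq_len) - 1; step_id >= 0; --step_id) {
    if (step_id != static_cast<int>(seq_len) - 1) {
      // Not the last step: state gradients flow in from step_id + 1,
      // which is what the +1 offset to LinkMemories expresses.
    }
    std::cout << "backward step " << step_id << "\n";
  }

  // An unsigned counter would loop forever: for a size_t, `step_id >= 0`
  // is always true, and --step_id wraps from 0 around to SIZE_MAX.
}
```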
@@ -48,7 +48,7 @@ class RecurrentAlgorithm {
    * NOTE the scopes are reused in both the forward and backward, so just
    * create once and expand its size if more steps need.
    */
-  void CreateScopes(const framework::Scope& scope) const;
+  void CreateScopes(const framework::Scope& scope, size_t seq_len) const;
 
   const std::vector<framework::Scope*>& GetStepScopes(
       const framework::Scope& scope) const {
@@ -61,7 +61,6 @@ class RecurrentAlgorithm {
  private:
   std::unique_ptr<framework::OperatorBase>* stepnet_;
   rnn::Argument* arg_;
-  mutable size_t seq_len_;
 };
 
 class RecurrentGradientAlgorithm {
@@ -97,7 +96,6 @@ class RecurrentGradientAlgorithm {
  private:
   rnn::Argument* arg_;
-  mutable size_t seq_len_;
   std::unique_ptr<framework::OperatorBase>* stepnet_;
 };
@@ -53,12 +53,12 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
                    const std::vector<std::string>& outlinks,
                    const size_t seq_len) {
   for (size_t i = 0; i < outlinks.size(); i++) {
-    auto output_var = step_scopes[0]->parent().FindVar(outlinks[i]);
+    auto* output_var = step_scopes[0]->parent().FindVar(outlinks[i]);
     PADDLE_ENFORCE_NOT_NULL(output_var, "output link [%s] is not in scope.",
                             outlinks[i]);
     LoDTensor* output = output_var->GetMutable<LoDTensor>();
-    auto step_scope_var = step_scopes[0]->FindVar(outlinks[i]);
+    auto* step_scope_var = step_scopes[0]->FindVar(outlinks[i]);
     PADDLE_ENFORCE_NOT_NULL(step_scope_var, "%s not in scope", outlinks[i]);
     f::DDim step_dims =
         step_scope_var->template GetMutable<LoDTensor>()->dims();
@@ -89,8 +89,8 @@ void LinkMemories(const std::vector<Scope*>& scopes,
                  step_id + offset, scopes.size(),
                  "offset [%d] is out of range, it must be less than (%d - %d)",
                  offset, scopes.size(), step_id);
-  auto scope = scopes[step_id];
-  auto linked_scope = scopes[step_id + offset];
+  auto* scope = scopes[step_id];
+  auto* linked_scope = scopes[step_id + offset];
   for (auto& attr : memories) {
     auto mem = scope->FindVar(attr.pre_var)->GetMutable<LoDTensor>();
     auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable<LoDTensor>();
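The hunks in the RNN helper functions above change nothing but `auto` to `auto*` where the initializer is already a pointer. Behavior is identical; the gain is readability plus a small compile-time guarantee. A tiny illustration:

```cpp
#include <iostream>

int main() {
  int x = 42;
  int* p = &x;

  auto a = p;   // deduces int*, but the declaration hides the pointer-ness
  auto* b = p;  // same type, and the * documents it at a glance

  // auto* also acts as a guard: if the right-hand side stopped returning
  // a pointer, `auto* b` would fail to compile, while plain `auto a`
  // would silently become a copy.
  std::cout << *a + *b << "\n";  // prints 84
}
```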