未验证 提交 56dd7653 编写于 作者: H Huihuang Zheng 提交者: GitHub

Delete useless ex-scope in recurrent op (#19426)

上级 b8aa37d5
...@@ -54,20 +54,6 @@ static void ClearStepScopes(const platform::DeviceContext &dev_ctx, ...@@ -54,20 +54,6 @@ static void ClearStepScopes(const platform::DeviceContext &dev_ctx,
step_scopes->clear(); step_scopes->clear();
} }
// StepScopes manages scopes inside RNN.
// StepScopes::CurScope() get the current scope
// StepScopes::ExScope() get the ex-scope, or scope in previous time step.
// StepScopes::Next() move to next time step.
//
// if is_train = False, then
// there are two scopes for the RNN and just support forward.
// else
// the len(scopes) == seq_len
//
// if is_backward = True, then
// reversely access scopes
// else
// access scopes from begin to end.
StepScopes::StepScopes(const platform::DeviceContext &dev_ctx, StepScopes::StepScopes(const platform::DeviceContext &dev_ctx,
const framework::Scope &parent, StepScopeVar *scopes, const framework::Scope &parent, StepScopeVar *scopes,
bool is_train, size_t seq_len, bool is_backward) bool is_train, size_t seq_len, bool is_backward)
...@@ -76,8 +62,8 @@ StepScopes::StepScopes(const platform::DeviceContext &dev_ctx, ...@@ -76,8 +62,8 @@ StepScopes::StepScopes(const platform::DeviceContext &dev_ctx,
is_train_(is_train), is_train_(is_train),
is_backward_(is_backward) { is_backward_(is_backward) {
size_t num_step_scopes = is_train ? seq_len : 2; size_t num_step_scopes = is_train ? seq_len : 2;
PADDLE_ENFORCE(is_train || !is_backward, PADDLE_ENFORCE_EQ(is_train || !is_backward, true,
"Cannot backward when is not training"); "Cannot backward when is not training");
if (!is_backward_) { if (!is_backward_) {
ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&parent), scopes); ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&parent), scopes);
scopes->reserve(static_cast<size_t>(num_step_scopes)); scopes->reserve(static_cast<size_t>(num_step_scopes));
...@@ -94,12 +80,22 @@ framework::Scope &StepScopes::ExScope() { ...@@ -94,12 +80,22 @@ framework::Scope &StepScopes::ExScope() {
return scope; return scope;
} }
void StepScopes::Next() { void StepScopes::BackwardNext(const platform::DeviceContext &dev_ctx,
if (is_backward_) { framework::Scope *parent_scope) {
--counter_; PADDLE_ENFORCE_EQ(is_backward_, true,
} else { "Cannot get backward next scope when is forward");
++counter_; if (counter_ + 2 == scopes_->size()) {
parent_scope->DeleteScope((*scopes_)[counter_ + 1]);
scopes_->pop_back();
VLOG(3) << "Deleted scope at " << counter_ + 1;
} }
--counter_;
}
void StepScopes::ForwardNext() {
PADDLE_ENFORCE_EQ(is_backward_, false,
"Cannot get forward next scope when is backward");
++counter_;
} }
framework::Scope &StepScopes::GetScope(size_t scope_id) const { framework::Scope &StepScopes::GetScope(size_t scope_id) const {
...@@ -125,11 +121,11 @@ int64_t RecurrentBase::GetSequenceLength(const framework::Scope &scope) const { ...@@ -125,11 +121,11 @@ int64_t RecurrentBase::GetSequenceLength(const framework::Scope &scope) const {
// Dim format SEQ_LEN, BATCH_SIZE, ... // Dim format SEQ_LEN, BATCH_SIZE, ...
int64_t seq_len = -1; int64_t seq_len = -1;
auto &all_inputs = Inputs(kInputs); auto &all_inputs = Inputs(kInputs);
PADDLE_ENFORCE(!all_inputs.empty()); PADDLE_ENFORCE_EQ(!all_inputs.empty(), true);
for (auto &iname : all_inputs) { for (auto &iname : all_inputs) {
auto *var = scope.FindVar(iname); auto *var = scope.FindVar(iname);
PADDLE_ENFORCE(var != nullptr); PADDLE_ENFORCE_NOT_NULL(var);
PADDLE_ENFORCE(var->IsType<framework::LoDTensor>()); PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(), true);
auto &dim = var->Get<framework::LoDTensor>().dims(); auto &dim = var->Get<framework::LoDTensor>().dims();
if (seq_len == -1) { if (seq_len == -1) {
seq_len = dim[0]; seq_len = dim[0];
...@@ -254,7 +250,7 @@ void RecurrentOp::RunImpl(const framework::Scope &scope, ...@@ -254,7 +250,7 @@ void RecurrentOp::RunImpl(const framework::Scope &scope,
}); });
} }
scopes.Next(); scopes.ForwardNext();
} }
} }
...@@ -262,7 +258,7 @@ StepScopes RecurrentOp::CreateStepScopes(const platform::DeviceContext &dev_ctx, ...@@ -262,7 +258,7 @@ StepScopes RecurrentOp::CreateStepScopes(const platform::DeviceContext &dev_ctx,
const framework::Scope &scope, const framework::Scope &scope,
size_t seq_len) const { size_t seq_len) const {
auto *var = scope.FindVar(Output(kStepScopes)); auto *var = scope.FindVar(Output(kStepScopes));
PADDLE_ENFORCE(var != nullptr); PADDLE_ENFORCE_NOT_NULL(var);
return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(), return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(),
Attr<bool>(kIsTrain), seq_len); Attr<bool>(kIsTrain), seq_len);
} }
...@@ -459,11 +455,11 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope, ...@@ -459,11 +455,11 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope,
VLOG(5) << "Link initialize state gradient finished "; VLOG(5) << "Link initialize state gradient finished ";
} }
} }
scopes.Next(); scopes.BackwardNext(dev_ctx, const_cast<framework::Scope *>(&scope));
} }
// Delete the scope of StepScopes // Delete the scope of StepScopes
auto *var = scope.FindVar(Input(kStepScopes)); auto *var = scope.FindVar(Input(kStepScopes));
PADDLE_ENFORCE(var != nullptr); PADDLE_ENFORCE_NOT_NULL(var);
auto *step_scopes = var->GetMutable<StepScopeVar>(); auto *step_scopes = var->GetMutable<StepScopeVar>();
ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&scope), step_scopes); ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&scope), step_scopes);
} }
...@@ -472,7 +468,7 @@ StepScopes RecurrentGradOp::CreateStepScopes( ...@@ -472,7 +468,7 @@ StepScopes RecurrentGradOp::CreateStepScopes(
const platform::DeviceContext &dev_ctx, const framework::Scope &scope, const platform::DeviceContext &dev_ctx, const framework::Scope &scope,
size_t seq_len) const { size_t seq_len) const {
auto *var = scope.FindVar(Input(kStepScopes)); auto *var = scope.FindVar(Input(kStepScopes));
PADDLE_ENFORCE(var != nullptr); PADDLE_ENFORCE_NOT_NULL(var);
return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(), return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(),
Attr<bool>(kIsTrain), seq_len, true /*is_backward*/); Attr<bool>(kIsTrain), seq_len, true /*is_backward*/);
} }
...@@ -491,6 +487,7 @@ std::unordered_set<std::string> RecurrentGradOp::LocalVarNames( ...@@ -491,6 +487,7 @@ std::unordered_set<std::string> RecurrentGradOp::LocalVarNames(
const framework::Scope &scope) const { const framework::Scope &scope) const {
return this->List2Set(scope.LocalVarNames()); return this->List2Set(scope.LocalVarNames());
} }
std::vector<std::string> RecurrentGradOp::GradVarLists( std::vector<std::string> RecurrentGradOp::GradVarLists(
const std::vector<std::string> &var_names) { const std::vector<std::string> &var_names) {
std::vector<std::string> retv; std::vector<std::string> retv;
...@@ -627,25 +624,25 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase { ...@@ -627,25 +624,25 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
0, "The Attr(%s) should be empty.", RecurrentBase::kStates); 0, "The Attr(%s) should be empty.", RecurrentBase::kStates);
} }
PADDLE_ENFORCE(ctx->HasInputs(RecurrentBase::kInputs), PADDLE_ENFORCE_EQ(ctx->HasInputs(RecurrentBase::kInputs), true,
"The input(%s) should not be empty.", "The input(%s) should not be empty.",
RecurrentBase::kInputs); RecurrentBase::kInputs);
PADDLE_ENFORCE(ctx->HasInputs(RecurrentBase::kOutputs), PADDLE_ENFORCE_EQ(ctx->HasInputs(RecurrentBase::kOutputs), true,
"The input(%s) should not be empty.", "The input(%s) should not be empty.",
RecurrentBase::kOutputs); RecurrentBase::kOutputs);
// In some case the kInitialStates is empty. // In some case the kInitialStates is empty.
if (ctx->HasInputs(RecurrentBase::kInitialStates)) { if (ctx->HasInputs(RecurrentBase::kInitialStates)) {
PADDLE_ENFORCE(ctx->HasOutputs( PADDLE_ENFORCE_EQ(ctx->HasOutputs(framework::GradVarName(
framework::GradVarName(RecurrentBase::kInitialStates)), RecurrentBase::kInitialStates)),
"The output of(%s) should not be empty.", true, "The output of(%s) should not be empty.",
framework::GradVarName(RecurrentBase::kInitialStates)); framework::GradVarName(RecurrentBase::kInitialStates));
ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kInitialStates), ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kInitialStates),
ctx->GetInputsDim(RecurrentBase::kInitialStates)); ctx->GetInputsDim(RecurrentBase::kInitialStates));
} }
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(
ctx->HasOutputs(framework::GradVarName(RecurrentBase::kInputs)), ctx->HasOutputs(framework::GradVarName(RecurrentBase::kInputs)), true,
"The output of(%s) should not be empty.", "The output of(%s) should not be empty.",
framework::GradVarName(RecurrentBase::kInputs)); framework::GradVarName(RecurrentBase::kInputs));
ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kInputs), ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kInputs),
...@@ -653,9 +650,9 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase { ...@@ -653,9 +650,9 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
// In some case the kParameters is empty. // In some case the kParameters is empty.
if (ctx->HasInputs(RecurrentBase::kParameters)) { if (ctx->HasInputs(RecurrentBase::kParameters)) {
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(
ctx->HasOutputs(framework::GradVarName(RecurrentBase::kParameters)), ctx->HasOutputs(framework::GradVarName(RecurrentBase::kParameters)),
"The output of(%s) should not be empty.", true, "The output of(%s) should not be empty.",
framework::GradVarName(RecurrentBase::kParameters)); framework::GradVarName(RecurrentBase::kParameters));
ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kParameters), ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kParameters),
ctx->GetInputsDim(RecurrentBase::kParameters)); ctx->GetInputsDim(RecurrentBase::kParameters));
......
...@@ -25,20 +25,17 @@ limitations under the License. */ ...@@ -25,20 +25,17 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// StepScopes manages scopes inside RNN. // StepScopes manages the scopes inside Recurrent Op.
// StepScopes::CurScope() get the current scope
// StepScopes::ExScope() get the ex-scope, or scope in previous time step.
// StepScopes::Next() move to next time step.
// //
// if is_train = False, then // if is_train = False, then
// there are two scopes for the RNN and just support forward. // there are two scopes for the RNN and just support forward
// else // else
// the len(scopes) == seq_len // the len(scopes) == seq_len
// //
// if is_backward = True, then // if is_backward = True, then
// reversely access scopes // reversely access scopes, delete useless ex-scope
// else // else
// access scopes from begin to end. // access scopes from beginning to end
class StepScopes { class StepScopes {
public: public:
StepScopes(const platform::DeviceContext &dev_ctx, StepScopes(const platform::DeviceContext &dev_ctx,
...@@ -46,11 +43,19 @@ class StepScopes { ...@@ -46,11 +43,19 @@ class StepScopes {
std::vector<framework::Scope *> *scopes, bool is_train, std::vector<framework::Scope *> *scopes, bool is_train,
size_t seq_len, bool is_backward = false); size_t seq_len, bool is_backward = false);
// Get the current scope
framework::Scope &CurScope(); framework::Scope &CurScope();
// Get the ex-scope, which is the scope in previous time step
framework::Scope &ExScope(); framework::Scope &ExScope();
void Next(); // Move to next time step when forwarding
void ForwardNext();
// Delete ex-scope after using it, then move to next time step when
// backwarding
void BackwardNext(const platform::DeviceContext &dev_ctx,
framework::Scope *parent_scope);
private: private:
framework::Scope &GetScope(size_t scope_id) const; framework::Scope &GetScope(size_t scope_id) const;
...@@ -154,7 +159,7 @@ class RecurrentBase : public framework::OperatorBase { ...@@ -154,7 +159,7 @@ class RecurrentBase : public framework::OperatorBase {
if (is_backward && src_var == nullptr) { if (is_backward && src_var == nullptr) {
return; return;
} }
PADDLE_ENFORCE(src_var != nullptr, "%s is not found.", src_var_name); PADDLE_ENFORCE_NOT_NULL(src_var, "%s is not found.", src_var_name);
auto &src_tensor = src_var->Get<framework::LoDTensor>(); auto &src_tensor = src_var->Get<framework::LoDTensor>();
auto *dst_var = dst_scope->Var(dst_var_name); auto *dst_var = dst_scope->Var(dst_var_name);
...@@ -173,9 +178,9 @@ class RecurrentBase : public framework::OperatorBase { ...@@ -173,9 +178,9 @@ class RecurrentBase : public framework::OperatorBase {
return; return;
} }
auto *src_var = src_scope.FindVar(src_var_name); auto *src_var = src_scope.FindVar(src_var_name);
PADDLE_ENFORCE(src_var != nullptr, "%s is not found.", src_var_name); PADDLE_ENFORCE_NOT_NULL(src_var, "%s is not found.", src_var_name);
auto &src_tensor = src_var->Get<framework::LoDTensor>(); auto &src_tensor = src_var->Get<framework::LoDTensor>();
PADDLE_ENFORCE(dst_var != nullptr, "%s is not found.", dst_var_name); PADDLE_ENFORCE_NOT_NULL(dst_var, "%s is not found.", dst_var_name);
auto *dst_tensor = dst_var->GetMutable<framework::LoDTensor>(); auto *dst_tensor = dst_var->GetMutable<framework::LoDTensor>();
callback(src_tensor, dst_tensor); callback(src_tensor, dst_tensor);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册