Unverified commit 56dd7653, authored by Huihuang Zheng, committed by GitHub

Delete useless ex-scope in recurrent op (#19426)

Parent b8aa37d5
@@ -54,20 +54,6 @@ static void ClearStepScopes(const platform::DeviceContext &dev_ctx,
step_scopes->clear();
}
-// StepScopes manages scopes inside RNN.
-// StepScopes::CurScope() get the current scope
-// StepScopes::ExScope() get the ex-scope, or scope in previous time step.
-// StepScopes::Next() move to next time step.
-//
-// if is_train = False, then
-//   there are two scopes for the RNN and just support forward.
-// else
-//   the len(scopes) == seq_len
-//
-// if is_backward = True, then
-//   reversely access scopes
-// else
-//   access scopes from begin to end.
StepScopes::StepScopes(const platform::DeviceContext &dev_ctx,
const framework::Scope &parent, StepScopeVar *scopes,
bool is_train, size_t seq_len, bool is_backward)
@@ -76,8 +62,8 @@ StepScopes::StepScopes(const platform::DeviceContext &dev_ctx,
is_train_(is_train),
is_backward_(is_backward) {
size_t num_step_scopes = is_train ? seq_len : 2;
-  PADDLE_ENFORCE(is_train || !is_backward,
-                 "Cannot backward when is not training");
+  PADDLE_ENFORCE_EQ(is_train || !is_backward, true,
+                    "Cannot backward when is not training");
if (!is_backward_) {
ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&parent), scopes);
scopes->reserve(static_cast<size_t>(num_step_scopes));
@@ -94,12 +80,22 @@ framework::Scope &StepScopes::ExScope() {
return scope;
}
-void StepScopes::Next() {
-  if (is_backward_) {
-    --counter_;
-  } else {
-    ++counter_;
+void StepScopes::BackwardNext(const platform::DeviceContext &dev_ctx,
+                              framework::Scope *parent_scope) {
+  PADDLE_ENFORCE_EQ(is_backward_, true,
+                    "Cannot get backward next scope when is forward");
+  if (counter_ + 2 == scopes_->size()) {
+    parent_scope->DeleteScope((*scopes_)[counter_ + 1]);
+    scopes_->pop_back();
+    VLOG(3) << "Deleted scope at " << counter_ + 1;
   }
+  --counter_;
+}
+
+void StepScopes::ForwardNext() {
+  PADDLE_ENFORCE_EQ(is_backward_, false,
+                    "Cannot get forward next scope when is backward");
+  ++counter_;
}
framework::Scope &StepScopes::GetScope(size_t scope_id) const {
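[Reviewer note, not part of the diff] The deletion condition in BackwardNext is easier to see in isolation: counter_ starts at seq_len - 1 and walks down, and counter_ + 2 == scopes_->size() holds exactly when the scope at counter_ + 1 belongs to a step whose gradient is already finished, so no later step will read it. A standalone, runnable C++ sketch of that bookkeeping, with plain ints standing in for scopes (names are illustrative, not Paddle API):

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> scopes = {0, 1, 2, 3};             // one entry per time step
  int counter = static_cast<int>(scopes.size()) - 1;  // backward starts at the last step
  while (counter >= 0) {
    // ... the gradient of step `counter` runs here, reading scopes[counter]
    // and, for counter > 0, the ex-scope scopes[counter - 1] ...
    if (counter + 2 == static_cast<int>(scopes.size())) {
      std::printf("free scope of step %d\n", scopes[counter + 1]);
      scopes.pop_back();  // mirrors DeleteScope(...) + pop_back() above
    }
    --counter;  // move one step back, as BackwardNext() does
  }
  return 0;
}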
@@ -125,11 +121,11 @@ int64_t RecurrentBase::GetSequenceLength(const framework::Scope &scope) const {
// Dim format SEQ_LEN, BATCH_SIZE, ...
int64_t seq_len = -1;
auto &all_inputs = Inputs(kInputs);
-  PADDLE_ENFORCE(!all_inputs.empty());
+  PADDLE_ENFORCE_EQ(!all_inputs.empty(), true);
for (auto &iname : all_inputs) {
auto *var = scope.FindVar(iname);
-    PADDLE_ENFORCE(var != nullptr);
-    PADDLE_ENFORCE(var->IsType<framework::LoDTensor>());
+    PADDLE_ENFORCE_NOT_NULL(var);
+    PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(), true);
auto &dim = var->Get<framework::LoDTensor>().dims();
if (seq_len == -1) {
seq_len = dim[0];
@@ -254,7 +250,7 @@ void RecurrentOp::RunImpl(const framework::Scope &scope,
});
}
-    scopes.Next();
+    scopes.ForwardNext();
}
}
@@ -262,7 +258,7 @@ StepScopes RecurrentOp::CreateStepScopes(const platform::DeviceContext &dev_ctx,
const framework::Scope &scope,
size_t seq_len) const {
auto *var = scope.FindVar(Output(kStepScopes));
-  PADDLE_ENFORCE(var != nullptr);
+  PADDLE_ENFORCE_NOT_NULL(var);
return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(),
Attr<bool>(kIsTrain), seq_len);
}
@@ -459,11 +455,11 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope,
VLOG(5) << "Link initialize state gradient finished ";
}
}
-    scopes.Next();
+    scopes.BackwardNext(dev_ctx, const_cast<framework::Scope *>(&scope));
}
// Delete the scope of StepScopes
auto *var = scope.FindVar(Input(kStepScopes));
-  PADDLE_ENFORCE(var != nullptr);
+  PADDLE_ENFORCE_NOT_NULL(var);
auto *step_scopes = var->GetMutable<StepScopeVar>();
ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&scope), step_scopes);
}
@@ -472,7 +468,7 @@ StepScopes RecurrentGradOp::CreateStepScopes(
const platform::DeviceContext &dev_ctx, const framework::Scope &scope,
size_t seq_len) const {
auto *var = scope.FindVar(Input(kStepScopes));
-  PADDLE_ENFORCE(var != nullptr);
+  PADDLE_ENFORCE_NOT_NULL(var);
return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(),
Attr<bool>(kIsTrain), seq_len, true /*is_backward*/);
}
@@ -491,6 +487,7 @@ std::unordered_set<std::string> RecurrentGradOp::LocalVarNames(
const framework::Scope &scope) const {
return this->List2Set(scope.LocalVarNames());
}
std::vector<std::string> RecurrentGradOp::GradVarLists(
const std::vector<std::string> &var_names) {
std::vector<std::string> retv;
@@ -627,25 +624,25 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
0, "The Attr(%s) should be empty.", RecurrentBase::kStates);
}
-    PADDLE_ENFORCE(ctx->HasInputs(RecurrentBase::kInputs),
-                   "The input(%s) should not be empty.",
-                   RecurrentBase::kInputs);
-    PADDLE_ENFORCE(ctx->HasInputs(RecurrentBase::kOutputs),
-                   "The input(%s) should not be empty.",
-                   RecurrentBase::kOutputs);
+    PADDLE_ENFORCE_EQ(ctx->HasInputs(RecurrentBase::kInputs), true,
+                      "The input(%s) should not be empty.",
+                      RecurrentBase::kInputs);
+    PADDLE_ENFORCE_EQ(ctx->HasInputs(RecurrentBase::kOutputs), true,
+                      "The input(%s) should not be empty.",
+                      RecurrentBase::kOutputs);
// In some case the kInitialStates is empty.
if (ctx->HasInputs(RecurrentBase::kInitialStates)) {
-      PADDLE_ENFORCE(ctx->HasOutputs(
-                         framework::GradVarName(RecurrentBase::kInitialStates)),
-                     "The output of(%s) should not be empty.",
-                     framework::GradVarName(RecurrentBase::kInitialStates));
+      PADDLE_ENFORCE_EQ(ctx->HasOutputs(framework::GradVarName(
+                            RecurrentBase::kInitialStates)),
+                        true, "The output of(%s) should not be empty.",
+                        framework::GradVarName(RecurrentBase::kInitialStates));
ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kInitialStates),
ctx->GetInputsDim(RecurrentBase::kInitialStates));
}
-    PADDLE_ENFORCE(
-        ctx->HasOutputs(framework::GradVarName(RecurrentBase::kInputs)),
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutputs(framework::GradVarName(RecurrentBase::kInputs)), true,
         "The output of(%s) should not be empty.",
         framework::GradVarName(RecurrentBase::kInputs));
ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kInputs),
@@ -653,9 +650,9 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
// In some case the kParameters is empty.
if (ctx->HasInputs(RecurrentBase::kParameters)) {
-      PADDLE_ENFORCE(
+      PADDLE_ENFORCE_EQ(
           ctx->HasOutputs(framework::GradVarName(RecurrentBase::kParameters)),
-          "The output of(%s) should not be empty.",
+          true, "The output of(%s) should not be empty.",
           framework::GradVarName(RecurrentBase::kParameters));
ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kParameters),
ctx->GetInputsDim(RecurrentBase::kParameters));
......
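[Reviewer note, not part of the diff] Aside from the scope lifetime changes, the edits in this file follow one mechanical pattern: boolean checks move from the bare PADDLE_ENFORCE form to PADDLE_ENFORCE_EQ(cond, true, ...) or PADDLE_ENFORCE_NOT_NULL(ptr, ...), presumably so that a failed check reports the compared values rather than just the condition. Schematically:

// Before:
//   PADDLE_ENFORCE(ptr != nullptr, "msg", args...);
//   PADDLE_ENFORCE(cond, "msg", args...);
// After:
//   PADDLE_ENFORCE_NOT_NULL(ptr, "msg", args...);
//   PADDLE_ENFORCE_EQ(cond, true, "msg", args...);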
@@ -25,20 +25,17 @@ limitations under the License. */
namespace paddle {
namespace operators {
-// StepScopes manages scopes inside RNN.
-// StepScopes::CurScope() get the current scope
-// StepScopes::ExScope() get the ex-scope, or scope in previous time step.
-// StepScopes::Next() move to next time step.
+// StepScopes manages the scopes inside Recurrent Op.
 //
 // if is_train = False, then
-//   there are two scopes for the RNN and just support forward.
+//   there are two scopes for the RNN and just support forward
 // else
 //   the len(scopes) == seq_len
 //
 // if is_backward = True, then
-//   reversely access scopes
+//   reversely access scopes, delete useless ex-scope
 // else
-//   access scopes from begin to end.
+//   access scopes from beginning to end
class StepScopes {
public:
StepScopes(const platform::DeviceContext &dev_ctx,
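[Reviewer note, not part of the diff] The "two scopes" case in the comment above is the usual inference-time trick: when is_train is false, CurScope() and ExScope() ping-pong between two buffers instead of keeping seq_len scopes alive. A standalone, runnable sketch of that reuse, with ints standing in for scopes and an assumed scope_id % 2 indexing rule:

#include <cstdio>

int main() {
  const int seq_len = 5;
  int scopes[2] = {0, 0};                 // the two reused "scopes"
  for (int t = 0; t < seq_len; ++t) {
    int &cur = scopes[t % 2];             // plays the role of CurScope() at step t
    const int &ex = scopes[(t + 1) % 2];  // plays the role of ExScope(), step t - 1
    cur = (t == 0) ? 1 : ex + 1;          // new state computed from previous state
    std::printf("step %d -> state %d\n", t, cur);
  }
  return 0;
}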
@@ -46,11 +43,19 @@ class StepScopes {
std::vector<framework::Scope *> *scopes, bool is_train,
size_t seq_len, bool is_backward = false);
+  // Get the current scope
   framework::Scope &CurScope();
+  // Get the ex-scope, which is the scope in previous time step
   framework::Scope &ExScope();
-  void Next();
+  // Move to next time step when forwarding
+  void ForwardNext();
+
+  // Delete ex-scope after using it, then move to next time step when
+  // backwarding
+  void BackwardNext(const platform::DeviceContext &dev_ctx,
+                    framework::Scope *parent_scope);
private:
framework::Scope &GetScope(size_t scope_id) const;
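[Reviewer note, not part of the diff] Putting the declarations together, the intended call pattern for this class is roughly the following; this is a simplified outline, since the real RunImpl also links tensors and executes the step block on every iteration:

// Forward (RecurrentOp::RunImpl):
//   StepScopes scopes = CreateStepScopes(dev_ctx, scope, seq_len);
//   for (size_t i = 0; i < seq_len; ++i) {
//     framework::Scope &cur = scopes.CurScope();  // scope of step i
//     /* run one step of the sub-block inside `cur` */
//     scopes.ForwardNext();
//   }
//
// Backward (RecurrentGradOp::RunImpl): iterate the steps in reverse and call
//   scopes.BackwardNext(dev_ctx, const_cast<framework::Scope *>(&scope));
// after each step, which also frees the ex-scope once nothing can read it.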
@@ -154,7 +159,7 @@ class RecurrentBase : public framework::OperatorBase {
if (is_backward && src_var == nullptr) {
return;
}
-    PADDLE_ENFORCE(src_var != nullptr, "%s is not found.", src_var_name);
+    PADDLE_ENFORCE_NOT_NULL(src_var, "%s is not found.", src_var_name);
auto &src_tensor = src_var->Get<framework::LoDTensor>();
auto *dst_var = dst_scope->Var(dst_var_name);
@@ -173,9 +178,9 @@ class RecurrentBase : public framework::OperatorBase {
return;
}
auto *src_var = src_scope.FindVar(src_var_name);
-    PADDLE_ENFORCE(src_var != nullptr, "%s is not found.", src_var_name);
+    PADDLE_ENFORCE_NOT_NULL(src_var, "%s is not found.", src_var_name);
auto &src_tensor = src_var->Get<framework::LoDTensor>();
-    PADDLE_ENFORCE(dst_var != nullptr, "%s is not found.", dst_var_name);
+    PADDLE_ENFORCE_NOT_NULL(dst_var, "%s is not found.", dst_var_name);
auto *dst_tensor = dst_var->GetMutable<framework::LoDTensor>();
callback(src_tensor, dst_tensor);
}
......