提交 3d4e8268 编写于 作者: Z Zeng Jinle 提交者: Tao Luo

fix recurrent fwd bug when no backward and scope clear (#17460)

上级 14f22362
......@@ -37,6 +37,20 @@ constexpr char kInitStateGrads[] = "initial_states" GRAD_SUFFIX;
using StepScopeVar = std::vector<framework::Scope *>;
static void ClearStepScopes(const platform::DeviceContext &dev_ctx,
framework::Scope *parent_scope,
StepScopeVar *step_scopes) {
if (step_scopes->empty()) return;
dev_ctx.Wait();
for (auto *sub_scope : *step_scopes) {
parent_scope->DeleteScope(sub_scope);
}
step_scopes->clear();
}
// StepScopes manages scopes inside RNN.
// StepScopes::CurScope() get the current scope
// StepScopes::ExScope() get the ex-scope, or scope in previous time step.
......@@ -53,7 +67,8 @@ using StepScopeVar = std::vector<framework::Scope *>;
// access scopes from begin to end.
class StepScopes {
public:
StepScopes(const framework::Scope &parent, StepScopeVar *scopes,
StepScopes(const platform::DeviceContext &dev_ctx,
const framework::Scope &parent, StepScopeVar *scopes,
bool is_train, size_t seq_len, bool is_backward = false)
: counter_(is_backward ? seq_len - 1 : 0UL),
scopes_(scopes),
......@@ -63,7 +78,7 @@ class StepScopes {
PADDLE_ENFORCE(is_train || !is_backward,
"Cannot backward when is not training");
if (!is_backward_) {
PADDLE_ENFORCE(scopes->empty());
ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&parent), scopes);
scopes->reserve(static_cast<size_t>(num_step_scopes));
for (size_t i = 0; i < num_step_scopes; ++i) {
scopes->emplace_back(&parent.NewScope());
......@@ -244,14 +259,15 @@ class RecurrentOp : public RecurrentBase {
const platform::Place &place) const override {
bool has_state = Attr<bool>(kHasStates);
auto seq_len = static_cast<size_t>(this->GetSequenceLength(scope));
VLOG(3) << "Static RNN input sequence length = " << seq_len;
StepScopes scopes = CreateStepScopes(scope, seq_len);
auto reverse = Attr<bool>(kReverse);
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
VLOG(3) << "Static RNN input sequence length = " << seq_len;
StepScopes scopes = CreateStepScopes(dev_ctx, scope, seq_len);
auto reverse = Attr<bool>(kReverse);
framework::Executor executor(place);
auto *block = Attr<framework::BlockDesc *>(kStepBlock);
......@@ -316,11 +332,12 @@ class RecurrentOp : public RecurrentBase {
}
private:
StepScopes CreateStepScopes(const framework::Scope &scope,
StepScopes CreateStepScopes(const platform::DeviceContext &dev_ctx,
const framework::Scope &scope,
size_t seq_len) const {
auto *var = scope.FindVar(Output(kStepScopes));
PADDLE_ENFORCE(var != nullptr);
return StepScopes(scope, var->GetMutable<StepScopeVar>(),
return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(),
Attr<bool>(kIsTrain), seq_len);
}
};
......@@ -338,17 +355,18 @@ class RecurrentGradOp : public RecurrentBase {
const platform::Place &place) const override {
bool has_state = Attr<bool>(kHasStates);
const size_t seq_len = static_cast<size_t>(GetSequenceLength(scope));
StepScopes scopes = CreateStepScopes(scope, seq_len);
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
StepScopes scopes = CreateStepScopes(dev_ctx, scope, seq_len);
auto reverse = Attr<bool>(kReverse);
framework::Executor executor(place);
auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
for (size_t step_id = 0; step_id < seq_len; ++step_id) {
size_t seq_offset = reverse ? step_id : seq_len - step_id - 1;
VLOG(3) << "Recurrent backward operate at the time step " << seq_offset;
......@@ -501,22 +519,20 @@ class RecurrentGradOp : public RecurrentBase {
scopes.Next();
}
// Delete the scope of StepScopes
dev_ctx.Wait();
auto *var = scope.FindVar(Input(kStepScopes));
PADDLE_ENFORCE(var != nullptr);
auto step_scopes = var->GetMutable<StepScopeVar>();
for (auto *sub_scope : *step_scopes) {
const_cast<framework::Scope &>(scope).DeleteScope(sub_scope);
}
step_scopes->clear();
auto *step_scopes = var->GetMutable<StepScopeVar>();
ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&scope),
step_scopes);
}
private:
StepScopes CreateStepScopes(const framework::Scope &scope,
StepScopes CreateStepScopes(const platform::DeviceContext &dev_ctx,
const framework::Scope &scope,
size_t seq_len) const {
auto *var = scope.FindVar(Input(kStepScopes));
PADDLE_ENFORCE(var != nullptr);
return StepScopes(scope, var->GetMutable<StepScopeVar>(),
return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(),
Attr<bool>(kIsTrain), seq_len, true /*is_backward*/);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册