未验证 提交 f5a37518 编写于 作者: C chengduo 提交者: GitHub

Refine recurrent_op (#16027)

* refine recurrent_op
test=develop

* remove unnecessary code
test=develop
上级 6057f362
...@@ -157,11 +157,13 @@ class RecurrentBase : public framework::OperatorBase { ...@@ -157,11 +157,13 @@ class RecurrentBase : public framework::OperatorBase {
const std::vector<std::string> &src_vars, const std::vector<std::string> &src_vars,
framework::Scope *dst_scope, framework::Scope *dst_scope,
const std::vector<std::string> &dst_vars, const std::vector<std::string> &dst_vars,
Callback callback) { Callback callback,
bool is_backward = false) {
PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size()); PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size());
for (size_t i = 0; i < dst_vars.size(); ++i) { for (size_t i = 0; i < dst_vars.size(); ++i) {
VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i]; VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i];
AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback); AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback,
is_backward);
} }
} }
...@@ -173,11 +175,13 @@ class RecurrentBase : public framework::OperatorBase { ...@@ -173,11 +175,13 @@ class RecurrentBase : public framework::OperatorBase {
const std::vector<std::string> &src_vars, const std::vector<std::string> &src_vars,
const framework::Scope &dst_scope, const framework::Scope &dst_scope,
const std::vector<std::string> &dst_vars, const std::vector<std::string> &dst_vars,
Callback callback) { Callback callback,
bool is_backward = false) {
PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size()); PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size());
for (size_t i = 0; i < dst_vars.size(); ++i) { for (size_t i = 0; i < dst_vars.size(); ++i) {
VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i]; VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i];
AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback); AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback,
is_backward);
} }
} }
...@@ -194,9 +198,13 @@ class RecurrentBase : public framework::OperatorBase { ...@@ -194,9 +198,13 @@ class RecurrentBase : public framework::OperatorBase {
static void AccessTensor(const framework::Scope &src_scope, static void AccessTensor(const framework::Scope &src_scope,
const std::string &src_var_name, const std::string &src_var_name,
framework::Scope *dst_scope, framework::Scope *dst_scope,
const std::string &dst_var_name, Callback callback) { const std::string &dst_var_name, Callback callback,
bool is_backward = false) {
auto *src_var = src_scope.FindVar(src_var_name); auto *src_var = src_scope.FindVar(src_var_name);
PADDLE_ENFORCE(src_var != nullptr); if (is_backward && src_var == nullptr) {
return;
}
PADDLE_ENFORCE(src_var != nullptr, "%s is not found.", src_var_name);
auto &src_tensor = src_var->Get<framework::LoDTensor>(); auto &src_tensor = src_var->Get<framework::LoDTensor>();
auto *dst_var = dst_scope->Var(dst_var_name); auto *dst_var = dst_scope->Var(dst_var_name);
...@@ -208,12 +216,16 @@ class RecurrentBase : public framework::OperatorBase { ...@@ -208,12 +216,16 @@ class RecurrentBase : public framework::OperatorBase {
static void AccessTensor(const framework::Scope &src_scope, static void AccessTensor(const framework::Scope &src_scope,
const std::string &src_var_name, const std::string &src_var_name,
const framework::Scope &dst_scope, const framework::Scope &dst_scope,
const std::string &dst_var_name, Callback callback) { const std::string &dst_var_name, Callback callback,
bool is_backward = false) {
auto *dst_var = dst_scope.FindVar(dst_var_name);
if (is_backward && dst_var == nullptr) {
return;
}
auto *src_var = src_scope.FindVar(src_var_name); auto *src_var = src_scope.FindVar(src_var_name);
PADDLE_ENFORCE(src_var != nullptr); PADDLE_ENFORCE(src_var != nullptr, "%s is not found.", src_var_name);
auto &src_tensor = src_var->Get<framework::LoDTensor>(); auto &src_tensor = src_var->Get<framework::LoDTensor>();
auto *dst_var = dst_scope.FindVar(dst_var_name); PADDLE_ENFORCE(dst_var != nullptr, "%s is not found.", dst_var_name);
PADDLE_ENFORCE(dst_var != nullptr);
auto *dst_tensor = dst_var->GetMutable<framework::LoDTensor>(); auto *dst_tensor = dst_var->GetMutable<framework::LoDTensor>();
callback(src_tensor, dst_tensor); callback(src_tensor, dst_tensor);
} }
...@@ -345,7 +357,8 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -345,7 +357,8 @@ class RecurrentGradOp : public RecurrentBase {
auto dims = framework::vectorize(inside->dims()); auto dims = framework::vectorize(inside->dims());
dims.erase(dims.begin()); dims.erase(dims.begin());
inside->Resize(framework::make_ddim(dims)); inside->Resize(framework::make_ddim(dims));
}); },
true /*is_backward*/);
auto og_set = List2Set(Inputs(kOutputGrads)); auto og_set = List2Set(Inputs(kOutputGrads));
if (VLOG_IS_ON(10)) { if (VLOG_IS_ON(10)) {
...@@ -454,7 +467,8 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -454,7 +467,8 @@ class RecurrentGradOp : public RecurrentBase {
auto dst = outside->Slice(seq_offset, seq_offset + 1); auto dst = outside->Slice(seq_offset, seq_offset + 1);
framework::TensorCopy(inside, place, dev_ctx, &dst); framework::TensorCopy(inside, place, dev_ctx, &dst);
}); },
true /*is_backward*/);
VLOG(5) << "Link outside gradient finished "; VLOG(5) << "Link outside gradient finished ";
if (step_id + 1 == seq_len) { // at_end if (step_id + 1 == seq_len) { // at_end
...@@ -467,7 +481,8 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -467,7 +481,8 @@ class RecurrentGradOp : public RecurrentBase {
outside->Resize(inside.dims()); outside->Resize(inside.dims());
outside->mutable_data(place, inside.type()); outside->mutable_data(place, inside.type());
framework::TensorCopy(inside, place, dev_ctx, outside); framework::TensorCopy(inside, place, dev_ctx, outside);
}); },
true /*is_backward*/);
VLOG(5) << "Link initialize state gradient finished "; VLOG(5) << "Link initialize state gradient finished ";
} }
scopes.Next(); scopes.Next();
...@@ -608,10 +623,8 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase { ...@@ -608,10 +623,8 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
std::vector<std::string> input{kInputs, kInitialStates}; std::vector<std::string> input{kInputs, kInitialStates};
std::vector<std::string> output{kOutputs}; std::vector<std::string> output{kOutputs};
for (auto &s : input) { for (auto &s : input) {
// NOTE(zcd): In some case, some of kInputs doesn't have gradient.
PADDLE_ENFORCE(ctx->HasInputs(s)); PADDLE_ENFORCE(ctx->HasInputs(s));
PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(s)),
"Cannot find the gradient variable %s",
framework::GradVarName(s));
} }
for (auto &s : output) { for (auto &s : output) {
PADDLE_ENFORCE(ctx->HasInputs(s)); PADDLE_ENFORCE(ctx->HasInputs(s));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册