提交 fd5335c8 编写于 作者: C chengduozh

refine recurrent_op

test=develop
上级 e61d7245
......@@ -157,11 +157,13 @@ class RecurrentBase : public framework::OperatorBase {
const std::vector<std::string> &src_vars,
framework::Scope *dst_scope,
const std::vector<std::string> &dst_vars,
Callback callback) {
Callback callback,
bool is_backward = false) {
PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size());
for (size_t i = 0; i < dst_vars.size(); ++i) {
VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i];
AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback);
AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback,
is_backward);
}
}
......@@ -173,11 +175,13 @@ class RecurrentBase : public framework::OperatorBase {
const std::vector<std::string> &src_vars,
const framework::Scope &dst_scope,
const std::vector<std::string> &dst_vars,
Callback callback) {
Callback callback,
bool is_backward = false) {
PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size());
for (size_t i = 0; i < dst_vars.size(); ++i) {
VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i];
AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback);
AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback,
is_backward);
}
}
......@@ -194,9 +198,13 @@ class RecurrentBase : public framework::OperatorBase {
static void AccessTensor(const framework::Scope &src_scope,
const std::string &src_var_name,
framework::Scope *dst_scope,
const std::string &dst_var_name, Callback callback) {
const std::string &dst_var_name, Callback callback,
bool is_backward = false) {
auto *src_var = src_scope.FindVar(src_var_name);
PADDLE_ENFORCE(src_var != nullptr);
if (is_backward && src_var == nullptr) {
return;
}
PADDLE_ENFORCE(src_var != nullptr, "%s is not found.", src_var_name);
auto &src_tensor = src_var->Get<framework::LoDTensor>();
auto *dst_var = dst_scope->Var(dst_var_name);
......@@ -208,12 +216,16 @@ class RecurrentBase : public framework::OperatorBase {
static void AccessTensor(const framework::Scope &src_scope,
const std::string &src_var_name,
const framework::Scope &dst_scope,
const std::string &dst_var_name, Callback callback) {
const std::string &dst_var_name, Callback callback,
bool is_backward = false) {
auto *dst_var = dst_scope.FindVar(dst_var_name);
if (is_backward && dst_var == nullptr) {
return;
}
auto *src_var = src_scope.FindVar(src_var_name);
PADDLE_ENFORCE(src_var != nullptr);
PADDLE_ENFORCE(src_var != nullptr, "%s is not found.", src_var_name);
auto &src_tensor = src_var->Get<framework::LoDTensor>();
auto *dst_var = dst_scope.FindVar(dst_var_name);
PADDLE_ENFORCE(dst_var != nullptr);
PADDLE_ENFORCE(dst_var != nullptr, "%s is not found.", dst_var_name);
auto *dst_tensor = dst_var->GetMutable<framework::LoDTensor>();
callback(src_tensor, dst_tensor);
}
......@@ -345,7 +357,8 @@ class RecurrentGradOp : public RecurrentBase {
auto dims = framework::vectorize(inside->dims());
dims.erase(dims.begin());
inside->Resize(framework::make_ddim(dims));
});
},
true /*is_backward*/);
auto og_set = List2Set(Inputs(kOutputGrads));
if (VLOG_IS_ON(10)) {
......@@ -454,7 +467,8 @@ class RecurrentGradOp : public RecurrentBase {
auto dst = outside->Slice(seq_offset, seq_offset + 1);
framework::TensorCopy(inside, place, dev_ctx, &dst);
});
},
true /*is_backward*/);
VLOG(5) << "Link outside gradient finished ";
if (step_id + 1 == seq_len) { // at_end
......@@ -467,7 +481,8 @@ class RecurrentGradOp : public RecurrentBase {
outside->Resize(inside.dims());
outside->mutable_data(place, inside.type());
framework::TensorCopy(inside, place, dev_ctx, outside);
});
},
true /*is_backward*/);
VLOG(5) << "Link initialize state gradient finished ";
}
scopes.Next();
......@@ -609,9 +624,10 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
std::vector<std::string> output{kOutputs};
for (auto &s : input) {
PADDLE_ENFORCE(ctx->HasInputs(s));
PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(s)),
"Cannot find the gradient variable %s",
framework::GradVarName(s));
// NOTE(zcd): In some case, some of kInputs doesn't have gradient.
// PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(s)),
// "Cannot find the gradient variable %s",
// framework::GradVarName(s));
}
for (auto &s : output) {
PADDLE_ENFORCE(ctx->HasInputs(s));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册