未验证 提交 12d29f4d 编写于 作者: H Huihuang Zheng 提交者: GitHub

Change TensorCopy in recurrent_op to ShareDataWith (#19319)

上级 da127d11
......@@ -220,27 +220,39 @@ void RecurrentOp::RunImpl(const framework::Scope &scope,
}
}
// Every inputs are linked now, execute!
// Link inside::output -> outside::output
// outside::output[seq_offset: seq_offset + 1] = inside::output
executor.CreateVariables(ctx->prog_, &cur_scope, ctx->block_id_);
if (i > 0) {
LinkTensorWithCallback(scope, Outputs(kOutputs), cur_scope,
Outputs(kOutputs),
[&](const framework::LoDTensor &src_tensor,
framework::LoDTensor *dst_tensor) {
framework::Tensor src_slice =
src_tensor.Slice(seq_offset, seq_offset + 1);
dst_tensor->ShareDataWith(src_slice);
});
}
// Linked now, execute!
executor.RunPreparedContext(ctx.get(), &cur_scope,
false /*create_local_scope*/,
true /*create_vars*/, true /* keep_kids */);
// Copy inside::output -> outside::output
// outside::output[seq_offset: seq_offset + 1] = inside::output
this->LinkTensorWithCallback(
false /*create_vars*/, true /* keep_kids */);
if (i == 0) {
LinkTensorWithCallback(
cur_scope, Outputs(kOutputs), scope, Outputs(kOutputs),
[&](const framework::LoDTensor &src_tensor,
framework::LoDTensor *dst_tensor) {
if (i == 0) { // create output tensor at begin
// create output tensor at begin
dst_tensor->Resize(PrependDims(seq_len, src_tensor.dims()));
dst_tensor->mutable_data(place, src_tensor.type());
}
auto dst_out = dst_tensor->Slice(seq_offset, seq_offset + 1);
// Explicit copy output since the local RNN scope can be destroyed
// early.
framework::TensorCopy(src_tensor, place, dev_ctx, &dst_out);
});
}
scopes.Next();
}
......@@ -322,23 +334,42 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope,
for (size_t i = 0; i < ex_state_grads.size(); ++i) {
auto &cur_grad = cur_state_grads[i];
auto &ex_grad = ex_state_grads[i];
auto &ex_tensor =
auto &ex_grad_tensor =
ex_scope.FindVar(ex_grad)->Get<framework::LoDTensor>();
VLOG(10) << " RNN link " << cur_grad << " from " << ex_grad;
auto *cur_grad_var = cur_scope.Var(cur_grad);
auto cur_grad_tensor =
framework::LoDTensor *cur_grad_tensor =
cur_grad_var->GetMutable<framework::LoDTensor>();
framework::TensorCopy(ex_tensor, place, dev_ctx, cur_grad_tensor);
cur_grad_tensor->ShareDataWith(ex_grad_tensor);
}
}
}
// Link inside::output -> outside::output
// outside::output[seq_offset: seq_offset + 1] = inside::output
executor.CreateVariables(ctx->prog_, &cur_scope, ctx->block_id_);
if (step_id > 0) {
LinkTensorWithCallback(scope, Outputs(kInputGrads), cur_scope,
GradVarLists(Inputs(kInputs)),
[&](const framework::LoDTensor &src_tensor,
framework::LoDTensor *dst_tensor) {
if (src_tensor.memory_size() ==
0) { // Inside Gradient is not created.
return;
}
framework::Tensor src_slice =
src_tensor.Slice(seq_offset, seq_offset + 1);
dst_tensor->ShareDataWith(src_slice);
},
true /*is_backward*/);
}
VLOG(5) << "Recurrent memory linking finished ";
// Run step block with cur_scope
executor.RunPreparedContext(ctx.get(), &cur_scope,
false /*create_local_scope*/,
true /*create_vars*/, true /* keep_kids */);
false /*create_vars*/, true /* keep_kids */);
VLOG(5) << "executor.Run finished ";
......@@ -393,21 +424,23 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope,
// Copy input gradient from inside to outside
// outside::input_grad[seq_offset: seq_offset + 1] = inside::input_grad
if (step_id == 0) {
LinkTensorWithCallback(
cur_scope, GradVarLists(Inputs(kInputs)), scope, Outputs(kInputGrads),
[&](const framework::LoDTensor &inside, framework::LoDTensor *outside) {
[&](const framework::LoDTensor &inside,
framework::LoDTensor *outside) {
if (inside.memory_size() == 0) { // IG is not created.
return;
}
if (step_id == 0) { // alloc memory
// Alloc outside memory
outside->Resize(PrependDims(seq_len, inside.dims()));
outside->mutable_data(place, inside.type());
}
auto dst = outside->Slice(seq_offset, seq_offset + 1);
framework::TensorCopy(inside, place, dev_ctx, &dst);
},
true /*is_backward*/);
}
VLOG(5) << "Link outside gradient finished ";
if (has_state) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册