未验证 提交 12d29f4d 编写于 作者: H Huihuang Zheng 提交者: GitHub

Change TensorCopy in recurrent_op to ShareDataWith (#19319)

上级 da127d11
...@@ -220,27 +220,39 @@ void RecurrentOp::RunImpl(const framework::Scope &scope, ...@@ -220,27 +220,39 @@ void RecurrentOp::RunImpl(const framework::Scope &scope,
} }
} }
// Every inputs are linked now, execute! // Link inside::output -> outside::output
// outside::output[seq_offset: seq_offset + 1] = inside::output
executor.CreateVariables(ctx->prog_, &cur_scope, ctx->block_id_);
if (i > 0) {
LinkTensorWithCallback(scope, Outputs(kOutputs), cur_scope,
Outputs(kOutputs),
[&](const framework::LoDTensor &src_tensor,
framework::LoDTensor *dst_tensor) {
framework::Tensor src_slice =
src_tensor.Slice(seq_offset, seq_offset + 1);
dst_tensor->ShareDataWith(src_slice);
});
}
// Linked now, execute!
executor.RunPreparedContext(ctx.get(), &cur_scope, executor.RunPreparedContext(ctx.get(), &cur_scope,
false /*create_local_scope*/, false /*create_local_scope*/,
true /*create_vars*/, true /* keep_kids */); false /*create_vars*/, true /* keep_kids */);
if (i == 0) {
// Copy inside::output -> outside::output LinkTensorWithCallback(
// outside::output[seq_offset: seq_offset + 1] = inside::output
this->LinkTensorWithCallback(
cur_scope, Outputs(kOutputs), scope, Outputs(kOutputs), cur_scope, Outputs(kOutputs), scope, Outputs(kOutputs),
[&](const framework::LoDTensor &src_tensor, [&](const framework::LoDTensor &src_tensor,
framework::LoDTensor *dst_tensor) { framework::LoDTensor *dst_tensor) {
if (i == 0) { // create output tensor at begin // create output tensor at begin
dst_tensor->Resize(PrependDims(seq_len, src_tensor.dims())); dst_tensor->Resize(PrependDims(seq_len, src_tensor.dims()));
dst_tensor->mutable_data(place, src_tensor.type()); dst_tensor->mutable_data(place, src_tensor.type());
}
auto dst_out = dst_tensor->Slice(seq_offset, seq_offset + 1); auto dst_out = dst_tensor->Slice(seq_offset, seq_offset + 1);
// Explicit copy output since the local RNN scope can be destroyed // Explicit copy output since the local RNN scope can be destroyed
// early. // early.
framework::TensorCopy(src_tensor, place, dev_ctx, &dst_out); framework::TensorCopy(src_tensor, place, dev_ctx, &dst_out);
}); });
}
scopes.Next(); scopes.Next();
} }
...@@ -322,23 +334,42 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope, ...@@ -322,23 +334,42 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope,
for (size_t i = 0; i < ex_state_grads.size(); ++i) { for (size_t i = 0; i < ex_state_grads.size(); ++i) {
auto &cur_grad = cur_state_grads[i]; auto &cur_grad = cur_state_grads[i];
auto &ex_grad = ex_state_grads[i]; auto &ex_grad = ex_state_grads[i];
auto &ex_tensor = auto &ex_grad_tensor =
ex_scope.FindVar(ex_grad)->Get<framework::LoDTensor>(); ex_scope.FindVar(ex_grad)->Get<framework::LoDTensor>();
VLOG(10) << " RNN link " << cur_grad << " from " << ex_grad; VLOG(10) << " RNN link " << cur_grad << " from " << ex_grad;
auto *cur_grad_var = cur_scope.Var(cur_grad); auto *cur_grad_var = cur_scope.Var(cur_grad);
auto cur_grad_tensor = framework::LoDTensor *cur_grad_tensor =
cur_grad_var->GetMutable<framework::LoDTensor>(); cur_grad_var->GetMutable<framework::LoDTensor>();
framework::TensorCopy(ex_tensor, place, dev_ctx, cur_grad_tensor); cur_grad_tensor->ShareDataWith(ex_grad_tensor);
}
} }
} }
// Link inside::output -> outside::output
// outside::output[seq_offset: seq_offset + 1] = inside::output
executor.CreateVariables(ctx->prog_, &cur_scope, ctx->block_id_);
if (step_id > 0) {
LinkTensorWithCallback(scope, Outputs(kInputGrads), cur_scope,
GradVarLists(Inputs(kInputs)),
[&](const framework::LoDTensor &src_tensor,
framework::LoDTensor *dst_tensor) {
if (src_tensor.memory_size() ==
0) { // Inside Gradient is not created.
return;
}
framework::Tensor src_slice =
src_tensor.Slice(seq_offset, seq_offset + 1);
dst_tensor->ShareDataWith(src_slice);
},
true /*is_backward*/);
} }
VLOG(5) << "Recurrent memory linking finished "; VLOG(5) << "Recurrent memory linking finished ";
// Run step block with cur_scope // Run step block with cur_scope
executor.RunPreparedContext(ctx.get(), &cur_scope, executor.RunPreparedContext(ctx.get(), &cur_scope,
false /*create_local_scope*/, false /*create_local_scope*/,
true /*create_vars*/, true /* keep_kids */); false /*create_vars*/, true /* keep_kids */);
VLOG(5) << "executor.Run finished "; VLOG(5) << "executor.Run finished ";
...@@ -393,21 +424,23 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope, ...@@ -393,21 +424,23 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope,
// Copy input gradient from inside to outside // Copy input gradient from inside to outside
// outside::input_grad[seq_offset: seq_offset + 1] = inside::input_grad // outside::input_grad[seq_offset: seq_offset + 1] = inside::input_grad
if (step_id == 0) {
LinkTensorWithCallback( LinkTensorWithCallback(
cur_scope, GradVarLists(Inputs(kInputs)), scope, Outputs(kInputGrads), cur_scope, GradVarLists(Inputs(kInputs)), scope, Outputs(kInputGrads),
[&](const framework::LoDTensor &inside, framework::LoDTensor *outside) { [&](const framework::LoDTensor &inside,
framework::LoDTensor *outside) {
if (inside.memory_size() == 0) { // IG is not created. if (inside.memory_size() == 0) { // IG is not created.
return; return;
} }
if (step_id == 0) { // alloc memory // Alloc outside memory
outside->Resize(PrependDims(seq_len, inside.dims())); outside->Resize(PrependDims(seq_len, inside.dims()));
outside->mutable_data(place, inside.type()); outside->mutable_data(place, inside.type());
}
auto dst = outside->Slice(seq_offset, seq_offset + 1); auto dst = outside->Slice(seq_offset, seq_offset + 1);
framework::TensorCopy(inside, place, dev_ctx, &dst); framework::TensorCopy(inside, place, dev_ctx, &dst);
}, },
true /*is_backward*/); true /*is_backward*/);
}
VLOG(5) << "Link outside gradient finished "; VLOG(5) << "Link outside gradient finished ";
if (has_state) { if (has_state) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册