未验证 提交 bb22e59c 编写于 作者: R Ruibiao Chen 提交者: GitHub

Skip inplace for coalesce_tensor_op outputs (#44795)

* Skip inplace for coalesce_tensor_op outputs

* Fix typos

* Add UTs

* Fix typos
上级 756f01db
......@@ -284,6 +284,28 @@ void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) {
}
void InterpreterCore::BuildInplace() {
// NOTE(Ruibiao): coalesce_tensor_op outputs a FusedOutput Tensor and a list
// of Output Tensors which are sliced from the FusedOutput. These outputs
// sholud not be the outvar of the in-place var-pair since memory reuse
// between FusedOutput and Output Tensors is assumed. For the following
// example:
// fused_var, var1, var2, var3 = coalesce_tensor(var1, var2, var3)
// var1 = sum(var4, var5)
// ...
//
// After running coalesce_tensor_op, var1 is assumed to share the buffer
// slices from fused_var. However, if sum_op is in-place, then var1 would
// re-share the buffer with var4 instead of fused_var.
std::set<std::string> skip_inplace_outvars;
for (Instruction& instr : vec_instruction_) {
OperatorBase* op = instr.OpBase();
if (op->Type() == "coalesce_tensor") {
const std::vector<std::string>& outputs =
op->OutputVars(/*has_intermediate=*/false);
skip_inplace_outvars.insert(outputs.begin(), outputs.end());
}
}
for (size_t i = 0; i < vec_instruction_.size(); ++i) {
auto& instr = vec_instruction_[i];
auto* op_base = instr.OpBase();
......@@ -309,17 +331,20 @@ void InterpreterCore::BuildInplace() {
if (BuildInplaceCheckVarIsOnlyInput(iter->second[0])) {
auto iterout = outputs.find(pair.second);
if (iterout != outputs.end() && !iterout->second.empty()) {
auto invar =
local_scope_->FindVar(var_scope_.GetNameById(iter->second[0]));
auto outvar = local_scope_->FindVar(
var_scope_.GetNameById(iterout->second[0]));
const std::string& invar_name =
var_scope_.GetNameById(iter->second[0]);
const std::string& outvar_name =
var_scope_.GetNameById(iterout->second[0]);
auto invar = local_scope_->FindVar(invar_name);
auto outvar = local_scope_->FindVar(outvar_name);
if (invar && outvar && invar->IsType<LoDTensor>() &&
outvar->IsType<LoDTensor>()) {
outvar->IsType<LoDTensor>() &&
skip_inplace_outvars.find(outvar_name) ==
skip_inplace_outvars.end()) {
instr.AddInplace(invar, outvar);
VLOG(3) << "inplace " << vec_instruction_[i].OpBase()->Type()
<< " " << var_scope_.GetNameById(iter->second[0])
<< " -> " << var_scope_.GetNameById(iterout->second[0])
<< std::endl;
VLOG(3) << "inplace " << op_base->Type() << " " << invar_name
<< " -> " << outvar_name;
}
}
}
......
......@@ -1751,3 +1751,7 @@ py_test_modules(
set_tests_properties(test_add_reader_dependency_for_interpretercore
PROPERTIES TIMEOUT 120)
py_test_modules(
test_eager_deletion_padding_rnn_for_interpretercore MODULES
test_eager_deletion_padding_rnn ENVS FLAGS_CONVERT_GRAPH_TO_PROGRAM=true)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册