diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 04d1f94d75241c611a357a0db2aff068e459a683..7d9d3d5fef14a81064ff229b7abfd090988f3e50 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -209,6 +209,23 @@ void InterpreterCore::Convert() { } } /* Reports whether the variable at var_index has exactly one qualifying consumer. When vec_meta_info_[var_index].vardesc_ is null, every entry of input_var2op_info_[var_index] qualifies; otherwise only the instructions for which OpInOutInfo::IsInArgBufferNeeded(variable name) is true are counted. Returns true only when that count equals 1. */+bool InterpreterCore::BuildInplaceCheckVarIsOnlyInput(size_t var_index) { + if (!global_scope_->vec_meta_info_[var_index].vardesc_) { + return input_var2op_info_[var_index].size() == 1; + } else { /* Each entry of input_var2op_info_[var_index] is used as an index into vec_instruction_; a fresh OpInOutInfo is built from the operator_base_ of that instruction on every iteration. */+ int is_input_cnt = 0; + for (auto inst_id : input_var2op_info_[var_index]) { + OpInOutInfo info; + info.Build(vec_instruction_[inst_id].kernel_func_.operator_base_); + if (info.IsInArgBufferNeeded( + global_scope_->vec_meta_info_[var_index].vardesc_->Name())) { + is_input_cnt++; + } + } + return is_input_cnt == 1; + } +} + void InterpreterCore::BuildInplace() { for (size_t i = 0; i < vec_instruction_.size(); ++i) { if (!vec_instruction_[i] @@ -224,7 +241,7 @@ void InterpreterCore::BuildInplace() { for (auto& pair : in_to_outs) { auto iter = vec_instruction_[i].input_index_.find(pair.first); if (iter != vec_instruction_[i].input_index_.end()) { /* The removed line accepted the input only when input_var2op_info_ held exactly one entry for it; the added line delegates that decision to BuildInplaceCheckVarIsOnlyInput. */- if (input_var2op_info_[iter->second[0]].size() == 1) { + if (BuildInplaceCheckVarIsOnlyInput(iter->second[0])) { auto iterout = vec_instruction_[i].output_index_.find(pair.second); if (iterout != vec_instruction_[i].output_index_.end()) { auto invar = global_scope_->var_list[iter->second[0]]; @@ -232,6 +249,15 @@ void InterpreterCore::BuildInplace() { if (invar && outvar) { vec_instruction_[i].vec_inplace_in_to_out_.emplace_back(invar, outvar); /* Logs the operator type plus the input and output variable names at verbosity 3 once the pair is recorded. NOTE(review): vardesc_ is dereferenced here without a check, yet BuildInplaceCheckVarIsOnlyInput can return true for a variable whose vardesc_ is null - confirm it cannot be null on this path or guard both Name() calls. NOTE(review): the trailing std::endl is presumably redundant if VLOG is the glog macro, which ends each record itself - verify and drop. */+ VLOG(3) << "inplace " + << vec_instruction_[i].kernel_func_.operator_base_->Type() + << " " + << global_scope_->vec_meta_info_[iter->second[0]] + .vardesc_->Name() + << " -> " + << global_scope_->vec_meta_info_[iterout->second[0]] + .vardesc_->Name() + << std::endl; } } } diff --git 
a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index 974fcb0a24ac8b960e9230130e0a896c8e90c09d..e594f9ca8b54b5fee6dae88191cc1a58db87ff3e 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -57,6 +57,8 @@ class InterpreterCore { void BuildInplace(); /* New member declared right after BuildInplace(): takes a variable index and returns bool. NOTE(review): the definition is not part of this header hunk - presumably this is the check BuildInplace() uses to decide whether an input variable may be reused in place; confirm against interpretercore.cc. */+ bool BuildInplaceCheckVarIsOnlyInput(size_t var_index); + void RunInstruction(const Instruction& instr_node); /* NOTE(review): the element type of std::vector appears to have been dropped from this extract (likely angle brackets stripped during extraction) - restore it from the real header before applying. */void ExecuteInstructionList(const std::vector& vec_instr);