未验证 提交 bd79ae09 编写于 作者: W wanghuancoder 提交者: GitHub

add inplace logic into new_executor (#35618)

* add inplace logic into new_executor, test=develop

* check shape and add inplace FLAGS, test=develop

* refine, test=develop

* refine, test=develop
上级 4d236354
......@@ -35,27 +35,6 @@ namespace paddle {
namespace framework {
namespace details {
// TODO(zjl): support SelectedRows
// Returns a const reference to the LoDTensor stored in `var`.
// Throws InvalidArgument when the variable holds any other type.
static inline const Tensor &GetTensorFromVar(const Variable *var) {
  if (!var->IsType<LoDTensor>()) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Variable must be type of LoDTensor, but received %s.",
        framework::ToTypeName(var->Type())));
  }
  return var->Get<LoDTensor>();
}
// Returns a mutable pointer to the LoDTensor stored in `var`.
// Throws InvalidArgument when the variable holds any other type.
static inline Tensor *GetMutableTensorFromVar(Variable *var) {
  if (!var->IsType<LoDTensor>()) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Variable must be type of LoDTensor, but received %s.",
        framework::ToTypeName(var->Type())));
  }
  return var->GetMutable<LoDTensor>();
}
ShareTensorBufferFunctor::ShareTensorBufferFunctor(
Scope *scope, size_t scope_idx, const std::string &op_type,
const std::vector<const ir::MemOptVarInfo *> &in_var_infos,
......
......@@ -39,6 +39,27 @@ namespace paddle {
namespace framework {
namespace details {
// TODO(zjl): support SelectedRows
// Returns a const reference to the LoDTensor stored in `var`.
// Throws InvalidArgument when the variable holds any other type.
static inline const Tensor &GetTensorFromVar(const Variable *var) {
  if (!var->IsType<LoDTensor>()) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Variable must be type of LoDTensor, but received %s.",
        framework::ToTypeName(var->Type())));
  }
  return var->Get<LoDTensor>();
}
// Returns a mutable pointer to the LoDTensor stored in `var`.
// Throws InvalidArgument when the variable holds any other type.
static inline Tensor *GetMutableTensorFromVar(Variable *var) {
  if (!var->IsType<LoDTensor>()) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Variable must be type of LoDTensor, but received %s.",
        framework::ToTypeName(var->Type())));
  }
  return var->GetMutable<LoDTensor>();
}
// NOTE(paddle-dev): ShareTensorBufferFunctor is responsible for
// performing memory reuse in run-time. ShareTensorBufferOpHandle
// is only a wrapper of ShareTensorBufferFunctor.
......
......@@ -17,6 +17,10 @@
#include <unordered_set>
#include "paddle/fluid/framework/details/share_tensor_buffer_functor.h"
DEFINE_bool(new_executor_use_inplace, true, "Use inplace in new executor");
namespace paddle {
namespace framework {
......@@ -194,6 +198,41 @@ void InterpreterCore::Convert() {
gc_event_.emplace_back(vec_instruction_[i].execution_ctx_.get()->GetPlace(),
platform::GenerateDeviceEventFlag());
}
if (FLAGS_new_executor_use_inplace) {
BuildInplace();
}
}
// Walks every instruction and records (input Variable*, output Variable*)
// pairs that the op's inplace-inference function declares reusable; the
// recorded pairs are stored on the instruction (vec_inplace_in_to_out_)
// for later use at run time.
void InterpreterCore::BuildInplace() {
for (size_t i = 0; i < vec_instruction_.size(); ++i) {
// Ops without a registered inplace-inference function cannot reuse buffers.
if (!vec_instruction_[i]
.kernel_func_.operator_base_->Info()
.infer_inplace_) {
continue;
}
// Query the op for its input-slot -> output-slot inplace mapping; the bool
// argument reflects whether this instruction's device context is a GPU place.
auto in_to_outs =
vec_instruction_[i].kernel_func_.operator_base_->Info().infer_inplace_(
platform::is_gpu_place(vec_instruction_[i].dev_ctx_->GetPlace()));
for (auto& pair : in_to_outs) {
// Resolve the input slot name to this instruction's variable indices.
auto iter = vec_instruction_[i].input_index_.find(pair.first);
if (iter != vec_instruction_[i].input_index_.end()) {
// Only the first variable id of the slot is inspected — assumes a
// single-variable slot; TODO(review): confirm for multi-var inputs.
// size() == 1 presumably means this op is the input's only consumer,
// so reusing its buffer cannot clobber another reader — verify against
// how input_var2op_info_ is built.
if (input_var2op_info_[iter->second[0]].size() == 1) {
auto iterout = vec_instruction_[i].output_index_.find(pair.second);
if (iterout != vec_instruction_[i].output_index_.end()) {
auto invar = global_scope_->var_list[iter->second[0]];
auto outvar = global_scope_->var_list[iterout->second[0]];
// Both variables must exist before recording the inplace pair.
if (invar && outvar) {
vec_instruction_[i].vec_inplace_in_to_out_.emplace_back(invar,
outvar);
}
}
}
}
}
}
}
void InterpreterCore::BuildAndCacheInstructionCtx(
......@@ -265,6 +304,17 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
instr_node.kernel_func_.operator_base_)
->InferShape(instr_node.infershape_ctx_.get());
if (FLAGS_new_executor_use_inplace) {
for (auto& pair : instr_node.vec_inplace_in_to_out_) {
const auto& in = paddle::framework::details::GetTensorFromVar(pair.first);
auto* out =
paddle::framework::details::GetMutableTensorFromVar(pair.second);
if (in.dims() == out->dims()) {
out->ShareBufferWith(in);
}
}
}
instr_node.kernel_func_.compute_func_(*instr_node.execution_ctx_.get());
}
......
......@@ -53,6 +53,8 @@ class InterpreterCore {
const VariableScope& var_scope,
const platform::Place& place);
void BuildInplace();
void RunInstruction(const Instruction& instr_node);
void ExecuteInstructionList(const std::vector<Instruction>& vec_instr,
......
......@@ -522,6 +522,8 @@ struct Instruction {
std::vector<EventInter> output_events_;
platform::DeviceContext* dev_ctx_; // not owned
std::vector<std::pair<Variable*, Variable*>> vec_inplace_in_to_out_;
};
enum class OpFuncType {
......
......@@ -209,6 +209,7 @@ def __bootstrap__():
'sort_sum_gradient',
'max_inplace_grad_add',
'apply_pass_to_program',
'new_executor_use_inplace',
]
if 'Darwin' not in sysstr:
read_env_flags.append('use_pinned_memory')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册