diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc
index 75659298ea764d21f4d289cb980f5401171cf1da..62d449ccd2ea8c873629a6dade5fce2fac167aed 100644
--- a/paddle/fluid/imperative/basic_engine.cc
+++ b/paddle/fluid/imperative/basic_engine.cc
@@ -314,6 +314,68 @@ static std::shared_ptr<NameVarMap<VariableWrapper>> CallGradientHooks(
   return tmp_ins_ptr;
 }
+
+static bool IsInputCanInplace(const std::shared_ptr<VariableWrapper>& var) {
+  auto* inner_var = var->MutableVar();
+  if (inner_var->IsInitialized() && inner_var->IsType<framework::LoDTensor>()) {
+    auto tensor = inner_var->GetMutable<framework::LoDTensor>();
+    if (tensor->IsInitialized()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+static void PerformBackwardInplace(const std::string& op_type,
+                                   const NameVarMap<VariableWrapper>& ins,
+                                   NameVarMap<VariableWrapper>* outs) {
+  auto& infer_inplace =
+      paddle::framework::OpInfoMap::Instance().Get(op_type).infer_inplace_;
+
+  if (infer_inplace) {
+    auto in_to_outs = infer_inplace(true);
+    for (auto& pair : in_to_outs) {
+      framework::LoDTensor *in_tensor = nullptr, *out_tensor = nullptr;
+      for (auto& p : ins) {
+        if (p.first == pair.first) {
+          // has at least one var
+          if (p.second.size() > 0 && p.second[0]) {
+            auto& in_var = p.second[0];
+            VLOG(10) << p.first << " use_count: " << in_var.use_count();
+            // the refcount of var to be inplaced should be 1
+            if (in_var.use_count() == 1) {
+              if (IsInputCanInplace(in_var)) {
+                in_tensor =
+                    in_var->MutableVar()->GetMutable<framework::LoDTensor>();
+              }
+            }
+          }
+        }
+      }
+      if (!in_tensor) {
+        continue;
+      }
+      for (auto& p : *outs) {
+        if (p.first == pair.second) {
+          if (p.second.size() > 0 && p.second[0]) {
+            auto& out_var = p.second[0];
+            if (out_var->Type() == framework::proto::VarType::LOD_TENSOR) {
+              out_tensor =
+                  out_var->MutableVar()->GetMutable<framework::LoDTensor>();
+            }
+          }
+        }
+      }
+      if (!out_tensor) {
+        continue;
+      }
+      out_tensor->ShareBufferWith(*in_tensor);
+      out_tensor->Resize(in_tensor->dims());
+      VLOG(4) << "Inplace performed in op " << op_type << ": " << pair.second
+              << " -> " << pair.first;
+    }
+  }
+}
+
 void BasicEngine::Execute() {
   if (init_nodes_.empty()) {
     return;
   }
@@ -483,6 +545,10 @@ void BasicEngine::Execute() {
        */
       auto tmp_ins_ptr = CallGradientHooks(bwd_ins, cur_op.Type());
 
+      if (!tmp_ins_ptr) {
+        PerformBackwardInplace(cur_op.Type(), bwd_ins, &tmp_outs);
+      }
+
       {
         VLOG(3) << "Start to execute grad op " << cur_op.Type();
         try {
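
Reviewer note: the `in_var.use_count() == 1` guard is the safety condition for the whole pass. If the backward engine is the only owner of the input grad var, no hook, retained gradient, or user reference can observe its buffer, so the output may alias and overwrite it via ShareBufferWith. Below is a minimal standalone C++ sketch of that refcount-gated aliasing idea; the Tensor type and names are hypothetical illustrations, not Paddle's API.

    #include <cassert>
    #include <iostream>
    #include <memory>
    #include <vector>

    // Hypothetical stand-in for a tensor that owns a float buffer.
    struct Tensor {
      std::shared_ptr<std::vector<float>> buffer;
      // Alias this tensor's storage to another tensor's storage.
      void ShareBufferWith(const Tensor& other) { buffer = other.buffer; }
    };

    int main() {
      auto grad_in = std::make_shared<Tensor>();
      grad_in->buffer = std::make_shared<std::vector<float>>(4, 1.0f);

      // Case 1: the engine is the sole owner (use_count() == 1), so
      // overwriting the input's buffer is unobservable elsewhere and
      // the output may reuse it in place.
      if (grad_in.use_count() == 1) {
        Tensor grad_out;
        grad_out.ShareBufferWith(*grad_in);  // output aliases input storage
        (*grad_out.buffer)[0] = 42.0f;       // write through the alias
        std::cout << "inplace ok, value seen via input: "
                  << (*grad_in->buffer)[0] << "\n";  // prints 42
      }

      // Case 2: another owner (think: a gradient hook or a var the user
      // retained) keeps the var alive, so in-place reuse would leak the
      // mutation to that owner; the engine must allocate a fresh output.
      auto retained = grad_in;  // use_count() is now 2
      assert(grad_in.use_count() != 1);
      std::cout << "aliased elsewhere, skip inplace\n";
      return 0;
    }

This also mirrors why the patch only runs PerformBackwardInplace when CallGradientHooks returns null: once hooks produce substituted inputs, the original ownership assumption no longer holds.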