Unverified commit 0cb413d3, authored by Leo Chen, committed by GitHub

add backward inplace for dygraph (#35412)

* add backward inplace for dygraph

* fix bug

* support gradient accumulation
Parent commit: abe70d3e
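The idea behind backward inplace: a grad op's output grad tensor can share the buffer of one of its input grad tensors instead of allocating new memory, but only when nothing else still holds a reference to that input. A minimal, self-contained sketch of that refcount rule, using made-up names (Buffer, Var, TryBackwardInplace) rather than Paddle's real types:

#include <iostream>
#include <memory>
#include <vector>

// Toy stand-ins: a "tensor" is a shared buffer, and a "variable" owns it
// through a shared_ptr, loosely mirroring VariableWrapper holding a LoDTensor.
using Buffer = std::vector<float>;
using Var = std::shared_ptr<Buffer>;

// The output may reuse the input's buffer only when the input holds data and
// nothing else still references it (use_count() == 1).
bool TryBackwardInplace(const Var& in, Var* out) {
  if (in && in.use_count() == 1 && !in->empty()) {
    *out = in;  // share the buffer instead of allocating a new one
    return true;
  }
  return false;
}

int main() {
  Var grad_in = std::make_shared<Buffer>(4, 1.0f);
  Var grad_out;
  // grad_in is the only owner, so sharing succeeds and prints 1.
  std::cout << TryBackwardInplace(grad_in, &grad_out) << std::endl;

  // Now grad_out also owns the buffer (use_count() == 2), so a second
  // attempt refuses to share and prints 0.
  Var other_out;
  std::cout << TryBackwardInplace(grad_in, &other_out) << std::endl;
  return 0;
}

In the actual diff below, the same check appears as `in_var.use_count() == 1` on the VariableWrapper, and the sharing is done with LoDTensor::ShareBufferWith.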
@@ -314,6 +314,68 @@ static std::shared_ptr<NameVarMap<VariableWrapper>> CallGradientHooks(
  return tmp_ins_ptr;
}

static bool IsInputCanInplace(const std::shared_ptr<VariableWrapper>& var) {
  auto* inner_var = var->MutableVar();
  if (inner_var->IsInitialized() && inner_var->IsType<framework::LoDTensor>()) {
    auto tensor = inner_var->GetMutable<framework::LoDTensor>();
    if (tensor->IsInitialized()) {
      return true;
    }
  }
  return false;
}
static void PerformBackwardInplace(const std::string& op_type,
                                   const NameVarMap<VariableWrapper>& ins,
                                   NameVarMap<VariableWrapper>* outs) {
  auto& infer_inplace =
      paddle::framework::OpInfoMap::Instance().Get(op_type).infer_inplace_;

  if (infer_inplace) {
    auto in_to_outs = infer_inplace(true);
    for (auto& pair : in_to_outs) {
      framework::LoDTensor *in_tensor = nullptr, *out_tensor = nullptr;
      for (auto& p : ins) {
        if (p.first == pair.first) {
          // has at least one var
          if (p.second.size() > 0 && p.second[0]) {
            auto& in_var = p.second[0];
            VLOG(10) << p.first << " use_count: " << in_var.use_count();
            // the refcount of var to be inplaced should be 1
            if (in_var.use_count() == 1) {
              if (IsInputCanInplace(in_var)) {
                in_tensor =
                    in_var->MutableVar()->GetMutable<framework::LoDTensor>();
              }
            }
          }
        }
      }
      if (!in_tensor) {
        continue;
      }
      for (auto& p : *outs) {
        if (p.first == pair.second) {
          if (p.second.size() > 0 && p.second[0]) {
            auto& out_var = p.second[0];
            if (out_var->Type() == framework::proto::VarType::LOD_TENSOR) {
              out_tensor =
                  out_var->MutableVar()->GetMutable<framework::LoDTensor>();
            }
          }
        }
      }
      if (!out_tensor) {
        continue;
      }
      out_tensor->ShareBufferWith(*in_tensor);
      out_tensor->Resize(in_tensor->dims());
      VLOG(4) << "Inplace performed in op " << op_type << ": " << pair.second
              << " -> " << pair.first;
    }
  }
}
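The second hunk wires the new pass into BasicEngine::Execute(): PerformBackwardInplace runs only when CallGradientHooks returns no replacement inputs (the `!tmp_ins_ptr` guard), presumably because once hooks have substituted their own copies for the original grad inputs, sharing buffers with those originals would no longer affect the tensors the grad op actually reads.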
void BasicEngine::Execute() {
  if (init_nodes_.empty()) {
    return;
@@ -483,6 +545,10 @@ void BasicEngine::Execute() {
       */
      auto tmp_ins_ptr = CallGradientHooks(bwd_ins, cur_op.Type());
      if (!tmp_ins_ptr) {
        PerformBackwardInplace(cur_op.Type(), bwd_ins, &tmp_outs);
      }

      {
        VLOG(3) << "Start to execute grad op " << cur_op.Type();
        try {