From 81cfbddc1a6e6eb9d1f6db829d1a5e2509cf1d7e Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 10 Nov 2021 16:34:32 +0800 Subject: [PATCH] fix recurrent_grad tmp variable@GRAD don't exsit in VariableScope (#37061) --- .../framework/new_executor/interpretercore_util.cc | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index 7b52cc991e1..acf0b4b30c7 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -172,7 +172,8 @@ std::vector create_all_ops(const framework::BlockDesc& block) { } std::tuple build_variable_map( - const VariableNameMap& var_name_map, VariableScope* var_scope) { + const VariableNameMap& var_name_map, VariableScope* var_scope, + bool enforce_exist = true) { VariableValueMap name2var; VariableIdMap name2id; for (auto& item : var_name_map) { @@ -181,6 +182,11 @@ std::tuple build_variable_map( vars.reserve(item.second.size()); for (auto& var_name : item.second) { + if (!enforce_exist && !var_scope->HasVar(var_name)) { + // skip the non-exist variable: such as recurrent_grad + VLOG(4) << var_name << " don't exist in variable scope, skip it!"; + continue; + } auto var_id = var_scope->VarId(var_name); auto* in_var = var_scope->Var(var_id); vars.push_back(in_var); @@ -436,13 +442,15 @@ void build_op_func_list(const platform::Place& place, VariableValueMap ins_map; VariableIdMap ins_name2id; + bool enforce_exist = true; + if (op->Type() == "recurrent_grad") enforce_exist = false; std::tie(ins_map, ins_name2id) = - build_variable_map(inputs_names, var_scope); + build_variable_map(inputs_names, var_scope, enforce_exist); VariableValueMap outs_map; VariableIdMap outs_name2id; std::tie(outs_map, outs_name2id) = - build_variable_map(outputs_names, var_scope); + build_variable_map(outputs_names, var_scope, enforce_exist); // step 2: build OpFuncNode OpFuncNode op_func_node; -- GitLab