未验证 提交 81cfbddc 编写于 作者: X xiongkun 提交者: GitHub

fix recurrent_grad tmp variable@GRAD don't exsit in VariableScope (#37061)

上级 f5caf9c5
......@@ -172,7 +172,8 @@ std::vector<OperatorBase*> create_all_ops(const framework::BlockDesc& block) {
}
std::tuple<VariableValueMap, VariableIdMap> build_variable_map(
const VariableNameMap& var_name_map, VariableScope* var_scope) {
const VariableNameMap& var_name_map, VariableScope* var_scope,
bool enforce_exist = true) {
VariableValueMap name2var;
VariableIdMap name2id;
for (auto& item : var_name_map) {
......@@ -181,6 +182,11 @@ std::tuple<VariableValueMap, VariableIdMap> build_variable_map(
vars.reserve(item.second.size());
for (auto& var_name : item.second) {
if (!enforce_exist && !var_scope->HasVar(var_name)) {
// skip the non-exist variable: such as recurrent_grad
VLOG(4) << var_name << " don't exist in variable scope, skip it!";
continue;
}
auto var_id = var_scope->VarId(var_name);
auto* in_var = var_scope->Var(var_id);
vars.push_back(in_var);
......@@ -436,13 +442,15 @@ void build_op_func_list(const platform::Place& place,
VariableValueMap ins_map;
VariableIdMap ins_name2id;
bool enforce_exist = true;
if (op->Type() == "recurrent_grad") enforce_exist = false;
std::tie(ins_map, ins_name2id) =
build_variable_map(inputs_names, var_scope);
build_variable_map(inputs_names, var_scope, enforce_exist);
VariableValueMap outs_map;
VariableIdMap outs_name2id;
std::tie(outs_map, outs_name2id) =
build_variable_map(outputs_names, var_scope);
build_variable_map(outputs_names, var_scope, enforce_exist);
// step 2: build OpFuncNode
OpFuncNode op_func_node;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册