Unverified commit 712bfb17, authored by Zeng Jinle, committed by GitHub

fix recurrent_op,test=develop (#17433)

Parent: 5babcd02
@@ -140,9 +140,9 @@ class RecordSkipMemoryOptVarsPass : public ir::Pass {
       // fail since "states" and "ex_states" cannot be found in main block.
       // When memory optimization is enabled, "states", "ex_states" and their
       // gradient should be skipped.
-      auto& ex_states =
+      auto ex_states =
           boost::get<std::vector<std::string>>(op_desc->GetAttr("ex_states"));
-      auto& states =
+      auto states =
           boost::get<std::vector<std::string>>(op_desc->GetAttr("states"));
       if (op_type == "recurrent") {
         UpdateSkipVarSet(skip_vars, {ex_states, states});
@@ -154,7 +154,7 @@ class RecordSkipMemoryOptVarsPass : public ir::Pass {
         UpdateSkipVarSet(
             skip_vars,
             {ToGradVarName(op_desc->Input("parameters")),
-             ToGradVarName(op_desc->Input("input")), ex_states, states,
+             ToGradVarName(op_desc->Input("inputs")), ex_states, states,
              ToGradVarName(ex_states), ToGradVarName(states)});
       }
     }
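The `auto&` to `auto` change above is the substantive fix in this file: `OpDesc::GetAttr` returns the attribute by value, so `boost::get` yields a reference into a temporary that is destroyed at the end of the full expression, leaving the old `auto&` bindings dangling. A minimal standalone sketch of that hazard (hypothetical stand-in types, not Paddle code):

```cpp
#include <cassert>
#include <string>
#include <vector>

// Stand-in for the Attribute variant; returned by value, the way
// OpDesc::GetAttr returns its Attribute.
struct Attr {
  std::vector<std::string> value;
};

// Stand-in for boost::get: forwards a reference, so no lifetime
// extension protects the temporary behind it.
const std::vector<std::string>& Get(const Attr& a) { return a.value; }

Attr GetAttr() { return Attr{{"h@pre", "c@pre"}}; }

int main() {
  // auto& bad = Get(GetAttr());  // compiles, but dangles: the Attr
  //                              // temporary dies after this statement
  auto ok = Get(GetAttr());  // copies the vector out first: safe
  assert(ok.size() == 2);
  return 0;
}
```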
@@ -508,6 +508,7 @@ class RecurrentGradOp : public RecurrentBase {
     for (auto *sub_scope : *step_scopes) {
       const_cast<framework::Scope &>(scope).DeleteScope(sub_scope);
     }
+    step_scopes->clear();
   }

 private:
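The added `step_scopes->clear()` empties the vector after the scopes it points to have been deleted, so no stale pointers to freed scopes remain if the container is traversed again. A minimal sketch of the pattern (hypothetical types, not Paddle code):

```cpp
#include <vector>

struct Scope {};

// After deleting the objects a container points to, clear the container
// too; otherwise a later traversal touches freed memory or double-deletes.
void DropStepScopes(std::vector<Scope*>* step_scopes) {
  for (auto* sub_scope : *step_scopes) {
    delete sub_scope;  // stands in for scope.DeleteScope(sub_scope)
  }
  step_scopes->clear();  // leave no dangling pointers behind
}

int main() {
  std::vector<Scope*> scopes{new Scope, new Scope};
  DropStepScopes(&scopes);
  DropStepScopes(&scopes);  // safe on repeat: the vector is now empty
  return 0;
}
```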
@@ -232,15 +232,8 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
         for arg in op_desc.input_arg_names():
             if core.grad_var_suffix() in arg and arg in no_grad_set:
                 x_in = _strip_grad_suffix_(arg)
-                x_in_var_desc = op_desc.block().find_var_recursive(
-                    cpt.to_bytes(x_in))
-                assert x_in_var_desc is not None, "Variable {} not found".format(
-                    x_in)
-                dtype = x_in_var_desc.dtype()
-                to_insert.append(
-                    (_create_op_desc_("fill_zeros_like2", {"X": [x_in]},
-                                      {"Out": [arg]}, {"dtype": dtype}), idx))
+                to_insert.append((_create_op_desc_(
+                    "fill_zeros_like", {"X": [x_in]}, {"Out": [arg]}, {}), idx))
     list([op_descs.insert(p[1], p[0]) for p in reversed(to_insert)])
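The Python hunk drops the per-variable dtype lookup (and the assert guarding it, which shows the lookup could fail) and emits a plain `fill_zeros_like`, which takes its output dtype from `X` at runtime. A framework-free sketch of the before/after (stand-in helper; the dtype value is an arbitrary placeholder):

```python
# Stand-in for fluid's _create_op_desc_ helper: records an op as a plain
# dict so the removed and retained forms can be compared side by side.
def _create_op_desc_(op_type, inputs, outputs, attrs):
    return {"type": op_type, "inputs": inputs,
            "outputs": outputs, "attrs": attrs}

# Removed form: dtype read from x_in's VarDesc and passed explicitly
# (5 is an arbitrary placeholder for the looked-up dtype enum).
old_op = _create_op_desc_("fill_zeros_like2", {"X": ["x"]},
                          {"Out": ["x@GRAD"]}, {"dtype": 5})

# Retained form: no attrs; fill_zeros_like follows X's dtype at runtime.
new_op = _create_op_desc_("fill_zeros_like", {"X": ["x"]},
                          {"Out": ["x@GRAD"]}, {})

print(old_op["type"], "->", new_op["type"])
```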