Unverified commit 712bfb17, authored by Zeng Jinle, committed by GitHub

fix recurrent_op,test=develop (#17433)

Parent: 5babcd02
@@ -140,9 +140,9 @@ class RecordSkipMemoryOptVarsPass : public ir::Pass {
     // fail since "states" and "ex_states" cannot be found in main block.
     // When memory optimization is enabled, "states", "ex_states" and their
     // gradient should be skipped.
-    auto& ex_states =
+    auto ex_states =
         boost::get<std::vector<std::string>>(op_desc->GetAttr("ex_states"));
-    auto& states =
+    auto states =
         boost::get<std::vector<std::string>>(op_desc->GetAttr("states"));
     if (op_type == "recurrent") {
       UpdateSkipVarSet(skip_vars, {ex_states, states});
@@ -154,7 +154,7 @@ class RecordSkipMemoryOptVarsPass : public ir::Pass {
       UpdateSkipVarSet(
           skip_vars,
           {ToGradVarName(op_desc->Input("parameters")),
-           ToGradVarName(op_desc->Input("input")), ex_states, states,
+           ToGradVarName(op_desc->Input("inputs")), ex_states, states,
            ToGradVarName(ex_states), ToGradVarName(states)});
     }
   }
...
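The hunk above makes the memory-optimization pass skip the recurrent op's "states" and "ex_states" variables together with the gradients of its "parameters" and "inputs", and corrects the queried input name from "input" to "inputs". Below is a minimal, self-contained C++ sketch, not taken from the patch, of how such a skip set can be assembled; the ToGradVarNames helper and the variable names are illustrative and only assume PaddlePaddle's conventional "@GRAD" gradient suffix.

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Illustrative stand-in for a ToGradVarName-style helper:
// append the conventional "@GRAD" suffix to every variable name.
std::vector<std::string> ToGradVarNames(const std::vector<std::string>& names) {
  std::vector<std::string> grads;
  grads.reserve(names.size());
  for (const auto& n : names) grads.push_back(n + "@GRAD");
  return grads;
}

int main() {
  // Hypothetical attribute/input values of a recurrent grad op.
  std::vector<std::string> ex_states = {"h_pre"};
  std::vector<std::string> states = {"h"};
  std::vector<std::string> inputs = {"x"};
  std::vector<std::string> parameters = {"W"};

  std::unordered_set<std::string> skip_vars;
  for (const auto& group : {ex_states, states, ToGradVarNames(ex_states),
                            ToGradVarNames(states), ToGradVarNames(inputs),
                            ToGradVarNames(parameters)}) {
    skip_vars.insert(group.begin(), group.end());
  }

  for (const auto& v : skip_vars) std::cout << v << "\n";
  // Expected members: h_pre, h, h_pre@GRAD, h@GRAD, x@GRAD, W@GRAD
  return 0;
}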
@@ -508,6 +508,7 @@ class RecurrentGradOp : public RecurrentBase {
     for (auto *sub_scope : *step_scopes) {
       const_cast<framework::Scope &>(scope).DeleteScope(sub_scope);
     }
+    step_scopes->clear();
   }

  private:
...
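The single added line clears step_scopes after every sub-scope has been deleted, so the list no longer holds pointers to freed scopes. The following sketch is generic C++ rather than Paddle code (Scope here is a placeholder type); it shows the pattern and why a second pass over an uncleared list would be unsafe.

#include <vector>

struct Scope {};  // placeholder type; not paddle::framework::Scope

void ReleaseStepScopes(std::vector<Scope*>* step_scopes) {
  for (Scope* s : *step_scopes) {
    delete s;  // analogous to scope.DeleteScope(sub_scope)
  }
  // Without this clear(), the vector keeps dangling pointers; running this
  // function again over the same list would delete already-freed memory.
  step_scopes->clear();
}

int main() {
  std::vector<Scope*> step_scopes = {new Scope, new Scope};
  ReleaseStepScopes(&step_scopes);
  ReleaseStepScopes(&step_scopes);  // safe: the list is now empty
  return 0;
}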
@@ -232,15 +232,8 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
         for arg in op_desc.input_arg_names():
             if core.grad_var_suffix() in arg and arg in no_grad_set:
                 x_in = _strip_grad_suffix_(arg)
-                x_in_var_desc = op_desc.block().find_var_recursive(
-                    cpt.to_bytes(x_in))
-                assert x_in_var_desc is not None, "Variable {} not found".format(
-                    x_in)
-                dtype = x_in_var_desc.dtype()
-                to_insert.append(
-                    (_create_op_desc_("fill_zeros_like2", {"X": [x_in]},
-                                      {"Out": [arg]}, {"dtype": dtype}), idx))
+                to_insert.append((_create_op_desc_(
+                    "fill_zeros_like", {"X": [x_in]}, {"Out": [arg]}, {}), idx))
     list([op_descs.insert(p[1], p[0]) for p in reversed(to_insert)])
...
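In the retained last line of the Python hunk, the pending fill_zeros_like descriptors are inserted back into op_descs in reverse order, so each recorded index still points at the intended position (assuming the indices were recorded in increasing order, as the surrounding loop does). A small generic C++ sketch of that reversed-insertion trick follows; the op names are made up and this is not Paddle code.

#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  std::vector<std::string> ops = {"op0", "op1", "op2", "op3"};

  // Pending insertions recorded as (position, new op), like the Python
  // to_insert list of (_create_op_desc_(...), idx) pairs.
  std::vector<std::pair<size_t, std::string>> to_insert = {
      {1, "fill_zeros_like_a"}, {3, "fill_zeros_like_b"}};

  // Inserting from the largest index backwards keeps the smaller indices
  // valid, because elements only shift at or after each insertion point.
  for (auto it = to_insert.rbegin(); it != to_insert.rend(); ++it) {
    ops.insert(ops.begin() + it->first, it->second);
  }

  for (const auto& op : ops) std::cout << op << "\n";
  // op0, fill_zeros_like_a, op1, op2, fill_zeros_like_b, op3
  return 0;
}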