diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/record_skip_memory_opt_vars_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/record_skip_memory_opt_vars_pass.cc
index 075a1955eb641832dd8cc3c11befd58e798b545b..040b769f89dd6de6cf3585d1e5f83da8fdb700d3 100644
--- a/paddle/fluid/framework/ir/memory_optimize_pass/record_skip_memory_opt_vars_pass.cc
+++ b/paddle/fluid/framework/ir/memory_optimize_pass/record_skip_memory_opt_vars_pass.cc
@@ -140,9 +140,9 @@ class RecordSkipMemoryOptVarsPass : public ir::Pass {
       // fail since "states" and "ex_states" cannot be found in main block.
       // When memory optimization is enabled, "states", "ex_states" and their
       // gradient should be skipped.
-      auto& ex_states =
+      auto ex_states =
           boost::get<std::vector<std::string>>(op_desc->GetAttr("ex_states"));
-      auto& states =
+      auto states =
           boost::get<std::vector<std::string>>(op_desc->GetAttr("states"));
       if (op_type == "recurrent") {
         UpdateSkipVarSet(skip_vars, {ex_states, states});
@@ -154,7 +154,7 @@ class RecordSkipMemoryOptVarsPass : public ir::Pass {
         UpdateSkipVarSet(
             skip_vars,
             {ToGradVarName(op_desc->Input("parameters")),
-             ToGradVarName(op_desc->Input("input")), ex_states, states,
+             ToGradVarName(op_desc->Input("inputs")), ex_states, states,
              ToGradVarName(ex_states), ToGradVarName(states)});
       }
     }
diff --git a/paddle/fluid/operators/recurrent_op.cc b/paddle/fluid/operators/recurrent_op.cc
index 1a2feee11c951cd4a55958df58f3756472f64769..6ead10c9987a7c20266cb0c4af5a9d02721c15f4 100644
--- a/paddle/fluid/operators/recurrent_op.cc
+++ b/paddle/fluid/operators/recurrent_op.cc
@@ -508,6 +508,7 @@ class RecurrentGradOp : public RecurrentBase {
     for (auto *sub_scope : *step_scopes) {
       const_cast<framework::Scope &>(scope).DeleteScope(sub_scope);
     }
+    step_scopes->clear();
   }
 
  private:
diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py
index 41f9016edcb0964b4a95c10e257d10d548306ee8..9030a33f3ef4531694481039330c48c0e7d22b4b 100644
--- a/python/paddle/fluid/backward.py
+++ b/python/paddle/fluid/backward.py
@@ -232,15 +232,8 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
         for arg in op_desc.input_arg_names():
             if core.grad_var_suffix() in arg and arg in no_grad_set:
                 x_in = _strip_grad_suffix_(arg)
-                x_in_var_desc = op_desc.block().find_var_recursive(
-                    cpt.to_bytes(x_in))
-                assert x_in_var_desc is not None, "Variable {} not found".format(
-                    x_in)
-                dtype = x_in_var_desc.dtype()
-
-                to_insert.append(
-                    (_create_op_desc_("fill_zeros_like2", {"X": [x_in]},
-                                      {"Out": [arg]}, {"dtype": dtype}), idx))
+                to_insert.append((_create_op_desc_(
+                    "fill_zeros_like", {"X": [x_in]}, {"Out": [arg]}, {}), idx))
 
     list([op_descs.insert(p[1], p[0]) for p in reversed(to_insert)])