diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc
index 60bc88ca7237c44dc63aa98e0064ab59addd707c..cef3af06401a3190c86ec26724e2dfd9f0702d22 100644
--- a/paddle/fluid/imperative/basic_engine.cc
+++ b/paddle/fluid/imperative/basic_engine.cc
@@ -33,8 +33,10 @@
 namespace paddle {
 namespace imperative {
 
-void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) {
+void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy,
+                       bool retain_graph) {
   backward_strategy_ = strategy;
+  retain_graph_ = retain_graph;
   init_node_ = var->GradVarBase()->GradNode();
   var->GradVarBase()->ClearGradNode();
 
@@ -224,7 +226,9 @@ void BasicEngine::Execute() {
       need_accu_var_list_.clear();
 
      VLOG(3) << "Remove op after op " << cur_op.Type() << " runs";
-      cur_op.ClearBackwardTrace();
+      if (!retain_graph_) {
+        cur_op.ClearBackwardTrace();
+      }
     }
 
     // Step 3: Collect ready ops
diff --git a/paddle/fluid/imperative/basic_engine.h b/paddle/fluid/imperative/basic_engine.h
index 2d517bb43d39f0321fe0a42718f20b9c457d01bb..4d25d81235098cca37491b1d8e43b481adc2fd0a 100644
--- a/paddle/fluid/imperative/basic_engine.h
+++ b/paddle/fluid/imperative/basic_engine.h
@@ -30,7 +30,8 @@
 class OpBase;
 
 class BasicEngine : public Engine {
  public:
-  void Init(VarBase* var, const detail::BackwardStrategy& strategy);
+  void Init(VarBase* var, const detail::BackwardStrategy& strategy,
+            bool retain_graph = false);
 
   void Execute() override;
 
@@ -51,6 +52,7 @@ class BasicEngine : public Engine {
       accumulators_;
  std::vector<std::pair<GradientAccumulator*, std::shared_ptr<VariableWrapper>>>
      need_accu_var_list_;
+  bool retain_graph_;
 };
 
 }  // namespace imperative
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index cc25e0fda160767a8f1d3deb2627f77f49ac2ea8..f32fd5192be90a24b2c18049e57c9d649bec8257 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -694,11 +694,11 @@
       .def("_run_backward",
            [](imperative::VarBase &self,
               const imperative::detail::BackwardStrategy &bckst,
-              const imperative::Tracer &tracer) {
+              const imperative::Tracer &tracer, bool retain_graph) {
              // TODO(jiabin): when we impl more backward execution we can
              // select them
              auto *engine = tracer.GetEngine();
-             engine->Init(&self, bckst);
+             engine->Init(&self, bckst, retain_graph);
              VLOG(3) << "Start backward";
              engine->Execute();
              VLOG(3) << "Finish backward";
diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py
index 5da400391e3e3778104c59e8a8c3f93c7e405e5d..16ca96a6cc87e6baae713b4a4d32cd9fbc2a148c 100644
--- a/python/paddle/fluid/dygraph/varbase_patch_methods.py
+++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py
@@ -73,7 +73,7 @@ def monkey_patch_varbase():
                                framework._current_expected_place())
 
     @framework.dygraph_only
-    def backward(self, backward_strategy=None):
+    def backward(self, backward_strategy=None, retain_graph=False):
         """
         **Notes**:
             **This API is ONLY available in Dygraph mode**
@@ -82,6 +82,10 @@ def monkey_patch_varbase():
 
         Args:
             backward_strategy( :ref:`api_fluid_dygraph_BackwardStrategy` ): The Backward Strategy to run backward
+            retain_graph(bool, optional): If False, the graph used to compute grads will be freed. If you would
+                like to add more ops to the built graph after calling this method (`backward`), set the parameter
+                `retain_graph` to True, then the graph will be retained. Thus, setting it to False is much more
+                memory-efficient. Defaults to False.
 
         Returns:
             NoneType: None
 
@@ -113,7 +117,8 @@ def monkey_patch_varbase():
                 backward_strategy = BackwardStrategy()
                 backward_strategy.sort_sum_gradient = False
 
-            self._run_backward(backward_strategy, framework._dygraph_tracer())
+            self._run_backward(backward_strategy,
+                               framework._dygraph_tracer(), retain_graph)
         else:
             raise ValueError(
                 "Variable.backward() is only available in DyGraph mode")
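
A minimal usage sketch of the new flag (not part of the patch; assumes the `paddle.fluid` dygraph API of this era, with illustrative variable names and shapes): the first `backward(retain_graph=True)` skips `ClearBackwardTrace()`, so a second `backward()` over the same graph still finds its grad ops, while the default `retain_graph=False` frees the trace after the pass.

```python
import numpy as np
import paddle.fluid as fluid

# Sketch only: names, shapes, and the two losses are illustrative.
with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(np.ones([2, 2], dtype='float32'))
    x.stop_gradient = False

    y = x * 2                               # shared subgraph reused by both losses
    loss1 = fluid.layers.reduce_sum(y)
    loss2 = fluid.layers.reduce_sum(y * 3)

    # retain_graph=True: the backward trace of the shared subgraph is kept,
    # so it can be traversed again by a later backward pass.
    loss1.backward(retain_graph=True)

    # Second backward over the same graph; with the default retain_graph=False
    # the trace is freed once this pass finishes.
    loss2.backward()

    print(x.gradient())                     # gradients accumulated from both passes
```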