Unverified commit 731caea3, authored by Zhen Wang, committed by GitHub

[Cherry-pick]Fix the double grad bug for the star gan. (#25655) (#25964)

* Fix the double grad bug for the star gan. (#25655)

* update the retain_graph parameter doc. test=develop
Parent 2a7efefe
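For context, the change below threads a new `retain_graph` flag from the Python `backward()` API down to the imperative `BasicEngine`, so the backward op trace is not freed while a graph that still needs to be traversed (for example, the double-grad graph built for a StarGAN-style gradient penalty) is alive. A minimal sketch of that use case, assuming a Paddle 1.8-style dygraph setup; the network, shapes, and variable names are illustrative and not taken from this commit:

```python
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    linear = fluid.dygraph.Linear(4, 1)
    x = fluid.dygraph.to_variable(np.random.rand(8, 4).astype('float32'))
    x.stop_gradient = False

    out = linear(x)
    # Build a double-grad graph, as a gradient penalty does.
    grads = fluid.dygraph.grad(outputs=out, inputs=x, create_graph=True)[0]
    gp_loss = fluid.layers.reduce_mean(fluid.layers.square(grads))
    loss = fluid.layers.reduce_mean(out) + gp_loss

    # Backward over a graph containing double-grad ops; the fix keeps the
    # backward trace alive so this pass can complete.
    loss.backward()
```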
...
@@ -33,8 +33,10 @@
 namespace paddle {
 namespace imperative {

-void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) {
+void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy,
+                       bool retain_graph) {
   backward_strategy_ = strategy;
+  retain_graph_ = retain_graph;
   init_node_ = var->GradVarBase()->GradNode();
   var->GradVarBase()->ClearGradNode();
...
@@ -224,8 +226,10 @@ void BasicEngine::Execute() {
       need_accu_var_list_.clear();

       VLOG(3) << "Remove op after op " << cur_op.Type() << " runs";
-      cur_op.ClearBackwardTrace();
+      if (!retain_graph_) {
+        cur_op.ClearBackwardTrace();
+      }
     }

     // Step 3: Collect ready ops
     for (auto& grad_pending_node : shared_cur_node->GradPendingNodes()) {
...
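To make the control-flow change above concrete, here is a toy Python analogue (illustration only, not Paddle code; all names are hypothetical) of the cleanup step in `BasicEngine::Execute`: with the patch, an op's backward trace is dropped only when the graph is not being retained.

```python
class ToyOp:
    """Stand-in for an imperative backward op holding a backward trace."""

    def __init__(self, name):
        self.name = name
        self.trace = "inputs/outputs needed to run backward again"

    def run_backward(self):
        print("running backward of", self.name)

    def clear_backward_trace(self):
        self.trace = None


def execute(ops, retain_graph=False):
    for op in ops:
        op.run_backward()
        # Before the patch the trace was always cleared here; now it survives
        # when retain_graph is requested, so a later pass can reuse the graph.
        if not retain_graph:
            op.clear_backward_trace()


ops = [ToyOp("matmul_grad"), ToyOp("elementwise_add_grad")]
execute(ops, retain_graph=True)
assert all(op.trace is not None for op in ops)
```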
...
@@ -30,7 +30,8 @@ class OpBase;
 class BasicEngine : public Engine {
  public:
-  void Init(VarBase* var, const detail::BackwardStrategy& strategy);
+  void Init(VarBase* var, const detail::BackwardStrategy& strategy,
+            bool retain_graph = false);

   void Execute() override;
...
@@ -51,6 +52,7 @@ class BasicEngine : public Engine {
       accumulators_;
   std::vector<std::pair<GradientAccumulator*, std::shared_ptr<VariableWrapper>>>
       need_accu_var_list_;
+  bool retain_graph_;
 };

 }  // namespace imperative
...
...@@ -694,11 +694,11 @@ void BindImperative(py::module *m_ptr) { ...@@ -694,11 +694,11 @@ void BindImperative(py::module *m_ptr) {
.def("_run_backward", .def("_run_backward",
[](imperative::VarBase &self, [](imperative::VarBase &self,
const imperative::detail::BackwardStrategy &bckst, const imperative::detail::BackwardStrategy &bckst,
const imperative::Tracer &tracer) { const imperative::Tracer &tracer, bool retain_graph) {
// TODO(jiabin): when we impl more backward execution we can // TODO(jiabin): when we impl more backward execution we can
// select them // select them
auto *engine = tracer.GetEngine(); auto *engine = tracer.GetEngine();
engine->Init(&self, bckst); engine->Init(&self, bckst, retain_graph);
VLOG(3) << "Start backward"; VLOG(3) << "Start backward";
engine->Execute(); engine->Execute();
VLOG(3) << "Finish backward"; VLOG(3) << "Finish backward";
......
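For reference, the patched binding above is what the Python-side `backward()` wrapper (last hunk below) calls into. A hedged sketch of that low-level call, using only names that appear in this diff; `_run_backward` is an internal hook rather than a public API:

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import framework

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(np.ones([2, 2], dtype='float32'))
    x.stop_gradient = False
    loss = fluid.layers.reduce_sum(x * x)

    strategy = fluid.dygraph.BackwardStrategy()
    strategy.sort_sum_gradient = False
    # The third argument is the new retain_graph flag that the binding
    # forwards to BasicEngine::Init.
    loss._run_backward(strategy, framework._dygraph_tracer(), True)
    print(x.gradient())
```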
...
@@ -73,7 +73,7 @@ def monkey_patch_varbase():
                 framework._current_expected_place())

     @framework.dygraph_only
-    def backward(self, backward_strategy=None):
+    def backward(self, backward_strategy=None, retain_graph=False):
         """
         **Notes**:
             **This API is ONLY available in Dygraph mode**
...
@@ -82,6 +82,10 @@
         Args:
             backward_strategy( :ref:`api_fluid_dygraph_BackwardStrategy` ): The Backward Strategy to run backward
+            retain_graph(bool, optional): If False, the graph used to compute grads will be freed. If you would
+                like to add more ops to the built graph after calling this method(`backward`), set the parameter
+                `retain_graph` to True, then the grads will be retained. Thus, setting it to False is much more memory-efficient.
+                Defaults to False.

         Returns:
             NoneType: None
...
@@ -113,7 +117,8 @@
                 backward_strategy = BackwardStrategy()
                 backward_strategy.sort_sum_gradient = False

-            self._run_backward(backward_strategy, framework._dygraph_tracer())
+            self._run_backward(backward_strategy,
+                               framework._dygraph_tracer(), retain_graph)
         else:
             raise ValueError(
                 "Variable.backward() is only available in DyGraph mode")
...
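Finally, a minimal usage sketch of the documented behaviour above, assuming a fluid dygraph setup (variable names are illustrative): the first `backward` keeps the shared graph alive via `retain_graph=True`, so ops added afterwards can still be differentiated through it.

```python
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(np.ones([2, 4], dtype='float32'))
    x.stop_gradient = False
    y = fluid.layers.relu(x)                 # shared part of the graph

    loss1 = fluid.layers.reduce_sum(y)
    # Keep the backward trace of the shared ops so it can be reused below.
    loss1.backward(retain_graph=True)

    # Add more ops on top of the already-built graph and run backward again;
    # without retain_graph=True the first call would have freed the trace.
    loss2 = fluid.layers.reduce_mean(y * y)
    loss2.backward()

    print(x.gradient())                      # accumulated over both passes
```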