Unverified commit 731caea3, authored by Zhen Wang, committed by GitHub

[Cherry-pick] Fix the double grad bug for StarGAN. (#25655) (#25964)

* Fix the double grad bug for StarGAN. (#25655)

* Update the retain_graph parameter doc. test=develop
Parent 2a7efefe
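The change below threads a new `retain_graph` flag from the dygraph `Variable.backward()` API down into `BasicEngine`, so a caller can keep the traced graph alive when later differentiation still needs it. The motivating case named in the title is StarGAN-style training with a gradient penalty, which builds grad ops on top of the forward graph (double grad) and then runs a backward over the combined graph. A hedged sketch of that pattern, assuming the `paddle.fluid` dygraph API of this release line; the toy discriminator, the penalty form, and all variable names are illustrative and not taken from the patch:

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    # Toy stand-in for a real discriminator network.
    d = fluid.dygraph.Linear(4, 1)
    x = fluid.dygraph.to_variable(np.random.rand(8, 4).astype('float32'))
    x.stop_gradient = False

    out = d(x)
    d_loss = fluid.layers.reduce_mean(out)

    # Double grad: differentiate `out` w.r.t. `x`, keeping the newly created
    # grad ops in the graph so they can be differentiated again.
    grads = fluid.dygraph.grad(
        outputs=out, inputs=x, create_graph=True, retain_graph=True)[0]
    penalty = fluid.layers.reduce_mean(fluid.layers.square(grads))

    # Final backward over the combined graph (forward ops plus the grad ops
    # created above).
    (d_loss + penalty).backward()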
@@ -33,8 +33,10 @@
 namespace paddle {
 namespace imperative {
 
-void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) {
+void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy,
+                       bool retain_graph) {
   backward_strategy_ = strategy;
+  retain_graph_ = retain_graph;
   init_node_ = var->GradVarBase()->GradNode();
   var->GradVarBase()->ClearGradNode();
@@ -224,8 +226,10 @@ void BasicEngine::Execute() {
       need_accu_var_list_.clear();
 
       VLOG(3) << "Remove op after op " << cur_op.Type() << " runs";
-      cur_op.ClearBackwardTrace();
+      if (!retain_graph_) {
+        cur_op.ClearBackwardTrace();
+      }
     }
 
     // Step 3: Collect ready ops
     for (auto& grad_pending_node : shared_cur_node->GradPendingNodes()) {
@@ -30,7 +30,8 @@ class OpBase;
 class BasicEngine : public Engine {
  public:
-  void Init(VarBase* var, const detail::BackwardStrategy& strategy);
+  void Init(VarBase* var, const detail::BackwardStrategy& strategy,
+            bool retain_graph = false);
 
   void Execute() override;
@@ -51,6 +52,7 @@ class BasicEngine : public Engine {
       accumulators_;
   std::vector<std::pair<GradientAccumulator*, std::shared_ptr<VariableWrapper>>>
       need_accu_var_list_;
+  bool retain_graph_;
 };
 
 }  // namespace imperative
@@ -694,11 +694,11 @@ void BindImperative(py::module *m_ptr) {
       .def("_run_backward",
            [](imperative::VarBase &self,
               const imperative::detail::BackwardStrategy &bckst,
-              const imperative::Tracer &tracer) {
+              const imperative::Tracer &tracer, bool retain_graph) {
              // TODO(jiabin): when we impl more backward execution we can
              // select them
              auto *engine = tracer.GetEngine();
-             engine->Init(&self, bckst);
+             engine->Init(&self, bckst, retain_graph);
              VLOG(3) << "Start backward";
              engine->Execute();
              VLOG(3) << "Finish backward";
@@ -73,7 +73,7 @@ def monkey_patch_varbase():
                 framework._current_expected_place())
 
     @framework.dygraph_only
-    def backward(self, backward_strategy=None):
+    def backward(self, backward_strategy=None, retain_graph=False):
         """
         **Notes**:
             **This API is ONLY available in Dygraph mode**
@@ -82,6 +82,10 @@ def monkey_patch_varbase():
 
         Args:
             backward_strategy( :ref:`api_fluid_dygraph_BackwardStrategy` ): The Backward Strategy to run backward
+            retain_graph(bool, optional): If False, the graph used to compute grads will be freed. If you would
+                like to add more ops to the built graph after calling this method (`backward`), set the parameter
+                `retain_graph` to True; the graph will then be retained. Setting it to False is much more
+                memory-efficient. Defaults to False.
 
         Returns:
             NoneType: None
@@ -113,7 +117,8 @@ def monkey_patch_varbase():
             backward_strategy = BackwardStrategy()
             backward_strategy.sort_sum_gradient = False
-            self._run_backward(backward_strategy, framework._dygraph_tracer())
+            self._run_backward(backward_strategy,
+                               framework._dygraph_tracer(), retain_graph)
         else:
             raise ValueError(
                 "Variable.backward() is only available in DyGraph mode")