Unverified commit cea50868, authored by Zhen Wang, committed by GitHub

Fix the double grad bug for StarGAN. (#25655)

* fix the double grad bug for the star gan. test=develop

* update the retain_graph parameter doc. test=develop

* add the unit test for the retain_graph parameter. test=develop
Parent e8caffbb
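In short: BasicEngine::Execute() used to clear every op's backward trace unconditionally, so a graph could only be traversed by a single backward pass. StarGAN-style training runs two passes over a shared graph (a gradient-penalty term built with paddle.imperative.grad, then a second loss through the same generator forward), which is where the double-grad bug surfaced. This commit threads a retain_graph flag from Variable.backward() down to the engine and only clears the trace when the flag is False. Below is a minimal sketch of the reuse pattern the flag enables, written against the same imperative API as the new unit test further down; the module and variable names here are illustrative and not part of the commit:

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_imperative()

g = paddle.nn.Conv2D(3, 3, 1)   # stand-in "generator" layer
d = paddle.nn.Conv2D(3, 1, 1)   # stand-in "discriminator" layer

x = paddle.imperative.to_variable(np.random.rand(2, 3, 8, 8).astype('float32'))
fake = g(x)  # generator forward; this sub-graph is shared by both losses below

loss_d = fluid.layers.reduce_mean(d(fake))
# First backward over the shared graph: retain_graph=True tells the engine
# not to clear the backward trace, so the graph can be traversed again.
loss_d.backward(retain_graph=True)

loss_g = fluid.layers.reduce_mean(fake)
# Second backward through the same generator sub-graph. Without the fix the
# trace had already been freed by the first backward and this pass failed.
loss_g.backward()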
@@ -33,8 +33,10 @@
 namespace paddle {
 namespace imperative {
-void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) {
+void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy,
+                       bool retain_graph) {
   backward_strategy_ = strategy;
+  retain_graph_ = retain_graph;
   init_node_ = var->GradVarBase()->GradNode();
   var->GradVarBase()->ClearGradNode();
@@ -226,7 +228,9 @@ void BasicEngine::Execute() {
     need_accu_var_list_.clear();
     VLOG(3) << "Remove op after op " << cur_op.Type() << " runs";
-    cur_op.ClearBackwardTrace();
+    if (!retain_graph_) {
+      cur_op.ClearBackwardTrace();
+    }
   }
   // Step 3: Collect ready ops
......
@@ -30,7 +30,8 @@ class OpBase;
 class BasicEngine : public Engine {
  public:
-  void Init(VarBase* var, const detail::BackwardStrategy& strategy);
+  void Init(VarBase* var, const detail::BackwardStrategy& strategy,
+            bool retain_graph = false);
   void Execute() override;
@@ -51,6 +52,7 @@ class BasicEngine : public Engine {
       accumulators_;
   std::vector<std::pair<GradientAccumulator*, std::shared_ptr<VariableWrapper>>>
       need_accu_var_list_;
+  bool retain_graph_;
 };
 }  // namespace imperative
......
@@ -721,11 +721,11 @@ void BindImperative(py::module *m_ptr) {
       .def("_run_backward",
            [](imperative::VarBase &self,
               const imperative::detail::BackwardStrategy &bckst,
-              const imperative::Tracer &tracer) {
+              const imperative::Tracer &tracer, bool retain_graph) {
              // TODO(jiabin): when we impl more backward execution we can
              // select them
              auto *engine = tracer.GetEngine();
-             engine->Init(&self, bckst);
+             engine->Init(&self, bckst, retain_graph);
              VLOG(3) << "Start backward";
              engine->Execute();
              VLOG(3) << "Finish backward";
......
@@ -124,7 +124,7 @@ def monkey_patch_varbase():
                 framework._current_expected_place())
 
     @framework.dygraph_only
-    def backward(self, backward_strategy=None):
+    def backward(self, backward_strategy=None, retain_graph=False):
         """
         **Notes**:
             **This API is ONLY available in Dygraph mode**
@@ -133,6 +133,10 @@ def monkey_patch_varbase():
         Args:
             backward_strategy( :ref:`api_fluid_dygraph_BackwardStrategy` ): The Backward Strategy to run backward
+            retain_graph(bool, optional): If False, the graph used to compute grads will be freed. If you would
+                like to add more ops to the built graph after calling this method (`backward`), set the
+                parameter `retain_graph` to True; the graph will then be retained. Thus, setting it to False
+                is much more memory-efficient. Defaults to False.
         Returns:
             NoneType: None
@@ -164,7 +168,8 @@ def monkey_patch_varbase():
                 backward_strategy = BackwardStrategy()
                 backward_strategy.sort_sum_gradient = False
-            self._run_backward(backward_strategy, framework._dygraph_tracer())
+            self._run_backward(backward_strategy,
+                               framework._dygraph_tracer(), retain_graph)
         else:
             raise ValueError(
                 "Variable.backward() is only available in DyGraph mode")
......
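As the new docstring notes, retain_graph trades memory for the ability to run another backward pass over the same graph; the keyword simply propagates from Variable.backward() through _run_backward() into BasicEngine::Init(). A minimal usage sketch follows, where loss_a and loss_b are illustrative names for two losses that share part of their forward graph (not code from this commit):

loss_a.backward(retain_graph=True)   # keep the backward trace alive
loss_b.backward()                    # reuses the shared part of the graph
# With the default retain_graph=False, loss_a.backward() would have freed
# the trace and the second call could not traverse the shared sub-graph.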
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.fluid as fluid
import unittest

paddle.enable_imperative()
SEED = 2020
np.random.seed(SEED)
fluid.default_main_program().random_seed = SEED


class Generator(fluid.dygraph.Layer):
    def __init__(self):
        super(Generator, self).__init__()
        self.conv1 = paddle.nn.Conv2D(3, 3, 3, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = fluid.layers.tanh(x)
        return x


class Discriminator(fluid.dygraph.Layer):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.convd = paddle.nn.Conv2D(6, 3, 1)

    def forward(self, x):
        x = self.convd(x)
        return x


class TestRetainGraph(unittest.TestCase):
    def cal_gradient_penalty(self,
                             netD,
                             real_data,
                             fake_data,
                             edge_data=None,
                             type='mixed',
                             constant=1.0,
                             lambda_gp=10.0):
        if lambda_gp > 0.0:
            if type == 'real':
                interpolatesv = real_data
            elif type == 'fake':
                interpolatesv = fake_data
            elif type == 'mixed':
                alpha = paddle.rand((real_data.shape[0], 1))
                alpha = paddle.expand(
                    alpha, [1, np.prod(real_data.shape) // real_data.shape[0]])
                alpha = paddle.reshape(alpha, real_data.shape)
                interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
            else:
                raise NotImplementedError('{} not implemented'.format(type))
            interpolatesv.stop_gradient = False
            real_data.stop_gradient = True
            fake_AB = paddle.concat((real_data.detach(), interpolatesv), 1)
            disc_interpolates = netD(fake_AB)
            outs = paddle.fill_constant(disc_interpolates.shape,
                                        disc_interpolates.dtype, 1.0)
            gradients = paddle.imperative.grad(
                outputs=disc_interpolates,
                inputs=fake_AB,
                grad_outputs=outs,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)
            gradients = paddle.reshape(gradients[0], [real_data.shape[0], -1])
            gradient_penalty = paddle.reduce_mean(
                (paddle.norm(gradients + 1e-16, 2, 1) - constant)**
                2) * lambda_gp  # added eps
            return gradient_penalty, gradients
        else:
            return 0.0, None

    def test_retain(self):
        g = Generator()
        d = Discriminator()

        optim_g = paddle.optimizer.Adam(parameter_list=g.parameters())
        optim_d = paddle.optimizer.Adam(parameter_list=d.parameters())

        gan_criterion = paddle.nn.MSELoss()
        l1_criterion = paddle.nn.L1Loss()

        A = np.random.rand(2, 3, 32, 32).astype('float32')
        B = np.random.rand(2, 3, 32, 32).astype('float32')
        realA = paddle.imperative.to_variable(A)
        realB = paddle.imperative.to_variable(B)
        fakeB = g(realA)

        # Discriminator update.
        optim_d.clear_gradients()
        fake_AB = paddle.concat((realA, fakeB), 1)
        G_pred_fake = d(fake_AB.detach())
        false_target = paddle.fill_constant(G_pred_fake.shape, 'float32', 0.0)
        G_gradient_penalty, _ = self.cal_gradient_penalty(
            d, realA, fakeB, lambda_gp=10.0)
        loss_d = gan_criterion(G_pred_fake, false_target) + G_gradient_penalty
        # Keep the graph: fakeB's forward trace is traversed again below by
        # loss_g.backward().
        loss_d.backward(retain_graph=True)
        optim_d.minimize(loss_d)

        # Generator update.
        optim_g.clear_gradients()
        fake_AB = paddle.concat((realA, fakeB), 1)
        G_pred_fake = d(fake_AB)
        true_target = paddle.fill_constant(G_pred_fake.shape, 'float32', 1.0)
        loss_g = l1_criterion(fakeB, realB) + gan_criterion(G_pred_fake,
                                                            true_target)
        # Second backward through the generator graph; this needs
        # retain_graph=True on the previous backward.
        loss_g.backward()
        optim_g.minimize(loss_g)


if __name__ == '__main__':
    unittest.main()