diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc
index 6c2f74e2712b0e7ccdce60e2b2c53ee529b52c5c..de1246883f1019bc3e6adabadbc9e071926eb772 100644
--- a/paddle/fluid/imperative/basic_engine.cc
+++ b/paddle/fluid/imperative/basic_engine.cc
@@ -33,8 +33,10 @@ namespace paddle {
 namespace imperative {
 
-void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) {
+void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy,
+                       bool retain_graph) {
   backward_strategy_ = strategy;
+  retain_graph_ = retain_graph;
   init_node_ = var->GradVarBase()->GradNode();
   var->GradVarBase()->ClearGradNode();
 
@@ -226,7 +228,9 @@ void BasicEngine::Execute() {
       need_accu_var_list_.clear();
 
       VLOG(3) << "Remove op after op " << cur_op.Type() << " runs";
-      cur_op.ClearBackwardTrace();
+      if (!retain_graph_) {
+        cur_op.ClearBackwardTrace();
+      }
     }
 
     // Step 3: Collect ready ops
diff --git a/paddle/fluid/imperative/basic_engine.h b/paddle/fluid/imperative/basic_engine.h
index 2d517bb43d39f0321fe0a42718f20b9c457d01bb..4d25d81235098cca37491b1d8e43b481adc2fd0a 100644
--- a/paddle/fluid/imperative/basic_engine.h
+++ b/paddle/fluid/imperative/basic_engine.h
@@ -30,7 +30,8 @@ class OpBase;
 
 class BasicEngine : public Engine {
  public:
-  void Init(VarBase* var, const detail::BackwardStrategy& strategy);
+  void Init(VarBase* var, const detail::BackwardStrategy& strategy,
+            bool retain_graph = false);
 
   void Execute() override;
 
@@ -51,6 +52,7 @@ class BasicEngine : public Engine {
       accumulators_;
   std::vector<std::pair<GradientAccumulator*, std::shared_ptr<VariableWrapper>>>
       need_accu_var_list_;
+  bool retain_graph_;
 };
 
 }  // namespace imperative
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 626f6b1ecc217039b2e587413f26bc1ba688d27d..82941c58280560b1c09b149da01ef3d6e8a3f8e0 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -721,11 +721,11 @@ void BindImperative(py::module *m_ptr) {
       .def("_run_backward",
            [](imperative::VarBase &self,
               const imperative::detail::BackwardStrategy &bckst,
-              const imperative::Tracer &tracer) {
+              const imperative::Tracer &tracer, bool retain_graph) {
              // TODO(jiabin): when we impl more backward execution we can
              // select them
              auto *engine = tracer.GetEngine();
-             engine->Init(&self, bckst);
+             engine->Init(&self, bckst, retain_graph);
              VLOG(3) << "Start backward";
              engine->Execute();
              VLOG(3) << "Finish backward";
diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py
index 2e41a8ff417b3083d96d0a9bd1fa453c8fddc014..d509fcc38e771bf5a5bacb63602966a871c7c885 100644
--- a/python/paddle/fluid/dygraph/varbase_patch_methods.py
+++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py
@@ -124,7 +124,7 @@ def monkey_patch_varbase():
                                      framework._current_expected_place())
 
     @framework.dygraph_only
-    def backward(self, backward_strategy=None):
+    def backward(self, backward_strategy=None, retain_graph=False):
         """
         **Notes**:
             **This API is ONLY available in Dygraph mode**
@@ -133,6 +133,10 @@ def monkey_patch_varbase():
 
         Args:
             backward_strategy( :ref:`api_fluid_dygraph_BackwardStrategy` ): The Backward Strategy to run backward
+            retain_graph(bool, optional): If False, the graph used to compute grads will be freed after the
+                backward pass. If you would like to add more ops to the built graph after calling this method
+                (`backward`), set `retain_graph` to True, then the graph will be retained. Thus, setting it to
+                False is much more memory-efficient. Defaults to False.
 
         Returns:
             NoneType: None
 
@@ -164,7 +168,8 @@ def monkey_patch_varbase():
                 backward_strategy = BackwardStrategy()
                 backward_strategy.sort_sum_gradient = False
 
-            self._run_backward(backward_strategy, framework._dygraph_tracer())
+            self._run_backward(backward_strategy,
+                               framework._dygraph_tracer(), retain_graph)
         else:
             raise ValueError(
                 "Variable.backward() is only available in DyGraph mode")
diff --git a/python/paddle/fluid/tests/unittests/test_retain_graph.py b/python/paddle/fluid/tests/unittests/test_retain_graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc50cf197f63e6082ea1d3fdbff1891f500e5b9a
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_retain_graph.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+import unittest
+
+paddle.enable_imperative()
+SEED = 2020
+np.random.seed(SEED)
+fluid.default_main_program().random_seed = SEED
+
+
+class Generator(fluid.dygraph.Layer):
+    def __init__(self):
+        super(Generator, self).__init__()
+        self.conv1 = paddle.nn.Conv2D(3, 3, 3, 1)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = fluid.layers.tanh(x)
+        return x
+
+
+class Discriminator(fluid.dygraph.Layer):
+    def __init__(self):
+        super(Discriminator, self).__init__()
+        self.convd = paddle.nn.Conv2D(6, 3, 1)
+
+    def forward(self, x):
+        x = self.convd(x)
+        return x
+
+
+class TestRetainGraph(unittest.TestCase):
+    def cal_gradient_penalty(self,
+                             netD,
+                             real_data,
+                             fake_data,
+                             edge_data=None,
+                             type='mixed',
+                             constant=1.0,
+                             lambda_gp=10.0):
+        if lambda_gp > 0.0:
+            if type == 'real':
+                interpolatesv = real_data
+            elif type == 'fake':
+                interpolatesv = fake_data
+            elif type == 'mixed':
+                alpha = paddle.rand((real_data.shape[0], 1))
+                alpha = paddle.expand(
+                    alpha, [1, np.prod(real_data.shape) // real_data.shape[0]])
+                alpha = paddle.reshape(alpha, real_data.shape)
+                interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
+            else:
+                raise NotImplementedError('{} not implemented'.format(type))
+            interpolatesv.stop_gradient = False
+            real_data.stop_gradient = True
+            fake_AB = paddle.concat((real_data.detach(), interpolatesv), 1)
+            disc_interpolates = netD(fake_AB)
+
+            outs = paddle.fill_constant(disc_interpolates.shape,
+                                        disc_interpolates.dtype, 1.0)
+            gradients = paddle.imperative.grad(
+                outputs=disc_interpolates,
+                inputs=fake_AB,
+                grad_outputs=outs,
+                create_graph=True,
+                retain_graph=True,
+                only_inputs=True)
+
+            gradients = paddle.reshape(gradients[0], [real_data.shape[0], -1])
+
+            gradient_penalty = paddle.reduce_mean(
+                (paddle.norm(gradients + 1e-16, 2, 1) - constant)**2
+            ) * lambda_gp  # added eps
+            return gradient_penalty, gradients
+        else:
+            return 0.0, None
+
+    def test_retain(self):
+        g = Generator()
+        d = Discriminator()
+
+        optim_g = paddle.optimizer.Adam(parameter_list=g.parameters())
+        optim_d = paddle.optimizer.Adam(parameter_list=d.parameters())
+
+        gan_criterion = paddle.nn.MSELoss()
+        l1_criterion = paddle.nn.L1Loss()
+
+        A = np.random.rand(2, 3, 32, 32).astype('float32')
+        B = np.random.rand(2, 3, 32, 32).astype('float32')
+
+        realA = paddle.imperative.to_variable(A)
+        realB = paddle.imperative.to_variable(B)
+        fakeB = g(realA)
+
+        optim_d.clear_gradients()
+        fake_AB = paddle.concat((realA, fakeB), 1)
+        G_pred_fake = d(fake_AB.detach())
+
+        false_target = paddle.fill_constant(G_pred_fake.shape, 'float32', 0.0)
+
+        G_gradient_penalty, _ = self.cal_gradient_penalty(
+            d, realA, fakeB, lambda_gp=10.0)
+        loss_d = gan_criterion(G_pred_fake, false_target) + G_gradient_penalty
+
+        loss_d.backward(retain_graph=True)
+        optim_d.minimize(loss_d)
+
+        optim_g.clear_gradients()
+        fake_AB = paddle.concat((realA, fakeB), 1)
+        G_pred_fake = d(fake_AB)
+        true_target = paddle.fill_constant(G_pred_fake.shape, 'float32', 1.0)
+
+        loss_g = l1_criterion(fakeB, realB) + gan_criterion(G_pred_fake,
+                                                            true_target)
+
+        loss_g.backward()
+        optim_g.minimize(loss_g)
+
+
+if __name__ == '__main__':
+    unittest.main()
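
Usage note (not part of the patch): a minimal sketch of the `retain_graph` flag that this change adds to `VarBase.backward`. It reuses only the imperative-mode APIs already exercised by the new unit test (`paddle.enable_imperative`, `paddle.imperative.to_variable`); the tensor names and the toy computation are illustrative assumptions, not code from this PR.

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.enable_imperative()

    # Toy graph: y = sum(x * x). Names x/y are illustrative only.
    x = paddle.imperative.to_variable(np.ones([2, 3], dtype='float32'))
    x.stop_gradient = False
    y = fluid.layers.reduce_sum(x * x)

    # First backward keeps the backward trace alive, i.e. ClearBackwardTrace
    # is skipped because retain_graph_ is set in BasicEngine::Init.
    y.backward(retain_graph=True)

    # Second backward over the same graph; gradients accumulate into x.gradient().
    # With the default retain_graph=False, the trace would already be freed here.
    y.backward()

Defaulting `retain_graph` to False keeps the existing single-backward behavior and its memory profile; only callers that traverse the graph more than once (such as the WGAN-GP style gradient penalty in the new test) pay the cost of retaining the trace.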