From 8bfd62ffb72e44235cc5baca67d698832b53c2ad Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Fri, 27 Mar 2020 05:59:23 -0500 Subject: [PATCH] Expose dygraph.grad api (#23124) * expose dygraph.grad api, test=develop, test=document_fix * add more parameter in dygraph.grad API, test=develop * add only_inputs=True parameter, test=develop * follow comments, test=develop, test=document_fix * fix typo, test=develop, test=document_fix --- paddle/fluid/imperative/basic_engine.cc | 12 +- paddle/fluid/imperative/op_base.h | 14 ++ .../fluid/imperative/partial_grad_engine.cc | 71 ++++--- paddle/fluid/imperative/partial_grad_engine.h | 7 +- paddle/fluid/pybind/imperative.cc | 9 +- python/paddle/fluid/dygraph/base.py | 173 ++++++++++++++++-- .../unittests/test_imperative_double_grad.py | 34 ++-- ...perative_star_gan_with_gradient_penalty.py | 5 +- 8 files changed, 256 insertions(+), 69 deletions(-) diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc index 9a9283a65a..a991ce689a 100644 --- a/paddle/fluid/imperative/basic_engine.cc +++ b/paddle/fluid/imperative/basic_engine.cc @@ -83,7 +83,6 @@ void BasicEngine::CheckBackwardInputs(const OpBase& op) { } if (tensor && !tensor->IsInitialized()) { - // if grad var has OverridedStopGradient skip this Op VLOG(6) << "Set ungenerated Grad: " << var->Name() << " as zero"; auto* dev_ctx = platform::DeviceContextPool::Instance().Get(op.place()); tensor->mutable_data(op.place(), var->DataType()); @@ -139,16 +138,7 @@ void BasicEngine::PrepareDeps() { q.pop(); for (auto& cur_op : *cur_node) { - PADDLE_ENFORCE_NE( - cur_op.GetInsMap().empty() && cur_op.GetOutsMap().empty(), true, - platform::errors::NotFound( - "Inputs and outputs of %s do not exist. " - "This may be because you call \"backward()\" twice for the same " - "subgraph. Please try to call \"stop_gradient = True\" or " - "\"detach()\" if you use some same vars between two " - "\"backward()\" " - "calls.", - cur_op.Type())); + cur_op.EnforceHasInOut(); PrepareGradAccumulators(cur_op); } diff --git a/paddle/fluid/imperative/op_base.h b/paddle/fluid/imperative/op_base.h index b3044cef49..079177acc5 100644 --- a/paddle/fluid/imperative/op_base.h +++ b/paddle/fluid/imperative/op_base.h @@ -119,6 +119,20 @@ class OpBase { void SetPlace(const platform::Place& place) { place_ = place; } + void EnforceHasInOut() const { + PADDLE_ENFORCE_NE( + ins_.empty() && outs_.empty(), true, + platform::errors::NotFound( + "Inputs and outputs of %s do not exist. This may be because:\n" + "1. You use some output variables of the previous batch as the " + "inputs of the current batch. Please try to call \"stop_gradient " + "= True\" or \"detach()\" for these variables.\n" + "2. You calculate backward twice for the same subgraph without " + "setting retain_graph=True. Please set retain_graph=True in the " + "first backward call.\n\n", + Type())); + } + static size_t GenerateUniqueId() { static std::atomic unique_id{0}; return unique_id.fetch_add(1); diff --git a/paddle/fluid/imperative/partial_grad_engine.cc b/paddle/fluid/imperative/partial_grad_engine.cc index 597a9093cd..135d54d1b7 100644 --- a/paddle/fluid/imperative/partial_grad_engine.cc +++ b/paddle/fluid/imperative/partial_grad_engine.cc @@ -57,16 +57,15 @@ namespace imperative { static void GetGraphInfoBetweenTargets( std::unordered_set *input_target_grads, std::unordered_set *output_targets, - std::unordered_set *startup_ops_ptr, - std::unordered_map> - *pending_ops_ptr, - std::unordered_map *op_deps_ptr, + std::unordered_set *startup_ops_ptr, + std::unordered_map> *pending_ops_ptr, + std::unordered_map *op_deps_ptr, std::unordered_set *related_grad_vars_ptr, const std::unordered_set &no_grad_var_grad) { /** * Step 1. Find the candidate startup grad ops, prepared for following BFS. */ - std::queue> q; + std::queue> q; std::unordered_set visited; for (auto iter = output_targets->begin(); iter != output_targets->end();) { auto *output_target = *iter; @@ -98,9 +97,8 @@ static void GetGraphInfoBetweenTargets( * not all input_target_grads would be found. */ std::unordered_set found_input_target_grads; - std::unordered_set endpoint_ops; - std::unordered_map> - preceding_ops; + std::unordered_set endpoint_ops; + std::unordered_map> preceding_ops; while (!q.empty()) { auto op_node_pair = q.front(); q.pop(); @@ -153,8 +151,7 @@ static void GetGraphInfoBetweenTargets( auto &target_vars = *related_grad_vars_ptr; target_vars = *input_target_grads; - std::queue> - op_queue; + std::queue> op_queue; for (auto &endpoint_op : endpoint_ops) { op_queue.emplace(endpoint_op, nullptr); } @@ -238,7 +235,7 @@ static void GetGraphInfoBetweenTargets( for (auto iter = output_targets->begin(); iter != output_targets->end();) { auto &grad_node = (*iter)->GradVarBase()->GradNode(); bool is_valid = std::find_if(grad_node->begin(), grad_node->end(), - [&](const OpBase &op) { + [&](OpBase &op) { // NOLINT return startup_ops.count(&op) > 0; }) != grad_node->end(); if (is_valid) { @@ -518,12 +515,13 @@ class PartialGradTask { const std::vector> &output_grads, const std::vector> &no_grad_vars, const platform::Place &place, - const detail::BackwardStrategy &strategy, bool create_graph); + const detail::BackwardStrategy &strategy, bool create_graph, + bool retain_graph, bool allow_unused, bool only_inputs); std::vector> Run(); private: - void RunEachOp(const OpBase *op); + void RunEachOp(OpBase *op); void PrepareInitialReadyVarsMap(const OpBase *op); @@ -536,10 +534,9 @@ class PartialGradTask { } private: - std::unordered_set startup_ops_; - std::unordered_map> - pending_ops_; - std::unordered_map op_deps_; + std::unordered_set startup_ops_; + std::unordered_map> pending_ops_; + std::unordered_map op_deps_; ReadyGradVarInfoMap ready_grad_vars_; @@ -562,6 +559,9 @@ class PartialGradTask { platform::Place place_; bool create_graph_; + bool retain_graph_; + bool allow_unused_; + bool only_inputs_; detail::BackwardStrategy strategy_; }; @@ -571,12 +571,19 @@ PartialGradTask::PartialGradTask( const std::vector> &output_grads, const std::vector> &no_grad_vars, const platform::Place &place, const detail::BackwardStrategy &strategy, - bool create_graph) { + bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs) { input_targets_ = input_targets; place_ = place; create_graph_ = create_graph; + retain_graph_ = retain_graph; + allow_unused_ = allow_unused; + only_inputs_ = only_inputs; strategy_ = strategy; + PADDLE_ENFORCE_EQ(only_inputs_, true, + platform::errors::Unimplemented( + "only_inputs=False is not supported yet")); + for (auto &var : no_grad_vars) { if (var && var->GradVarBase()) { no_grad_var_grad_.insert(var->GradVarBase()->SharedVar().get()); @@ -738,7 +745,7 @@ PartialGradTask::PartialGradTask( std::vector> PartialGradTask::Run() { VLOG(10) << "Startup op number " << startup_ops_.size(); - std::queue q; + std::queue q; for (auto *op : startup_ops_) { q.push(op); } @@ -746,8 +753,13 @@ std::vector> PartialGradTask::Run() { while (!q.empty()) { auto *op = q.front(); q.pop(); + VLOG(10) << "Start to run " << op->Type(); + op->EnforceHasInOut(); RunEachOp(op); + if (!retain_graph_) { + op->ClearBackwardTrace(); + } VLOG(10) << "End to run " << op->Type(); auto iter = pending_ops_.find(op); @@ -773,7 +785,7 @@ std::vector> PartialGradTask::Run() { return CreateResult(); } -void PartialGradTask::RunEachOp(const OpBase *op) { +void PartialGradTask::RunEachOp(OpBase *op) { // Prepare new inputs NameVarMap tmp_ins; for (auto &input_pair : op->GetInsMap()) { @@ -960,7 +972,8 @@ void PartialGradTask::PrepareInitialGradientAccumulators(const OpBase *op) { std::vector> PartialGradTask::CreateResult() { std::vector> result; result.reserve(input_targets_.size()); - for (auto &input_target : input_targets_) { + for (size_t i = 0; i < input_targets_.size(); ++i) { + auto &input_target = input_targets_[i]; PADDLE_ENFORCE_NOT_NULL( input_target->GradVarBase(), platform::errors::InvalidArgument("input should have gradient")); @@ -971,6 +984,12 @@ std::vector> PartialGradTask::CreateResult() { ready_var->SetOverridedStopGradient(!create_graph_); result.emplace_back(std::move(ready_var)); } else { // return None if it does not appear in the graph + PADDLE_ENFORCE_EQ(allow_unused_, true, + platform::errors::InvalidArgument( + "The %d-th input does not appear in the backward " + "graph. Please check the input variable or set " + "allow_unused=True to get None result.", + i)); result.emplace_back(); } } @@ -995,14 +1014,17 @@ PartialGradEngine::PartialGradEngine( const std::vector> &output_grads, const std::vector> &no_grad_vars, const platform::Place &place, const detail::BackwardStrategy &strategy, - bool create_graph) + bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs) : input_targets_(input_targets), output_targets_(output_targets), output_grads_(output_grads), no_grad_vars_(no_grad_vars), place_(place), strategy_(strategy), - create_graph_(create_graph) {} + create_graph_(create_graph), + retain_graph_(retain_graph), + allow_unused_(allow_unused), + only_inputs_(only_inputs) {} std::vector> PartialGradEngine::GetResult() const { return results_; @@ -1017,7 +1039,8 @@ void PartialGradEngine::Clear() { void PartialGradEngine::Execute() { PartialGradTask task(input_targets_, output_targets_, output_grads_, - no_grad_vars_, place_, strategy_, create_graph_); + no_grad_vars_, place_, strategy_, create_graph_, + retain_graph_, allow_unused_, only_inputs_); VLOG(10) << "Starts to execute PartialGradEngine"; results_ = task.Run(); Clear(); diff --git a/paddle/fluid/imperative/partial_grad_engine.h b/paddle/fluid/imperative/partial_grad_engine.h index fde4703ad4..1a0fdcca4f 100644 --- a/paddle/fluid/imperative/partial_grad_engine.h +++ b/paddle/fluid/imperative/partial_grad_engine.h @@ -32,8 +32,8 @@ class PartialGradEngine : public Engine { const std::vector> &output_grads, const std::vector> &no_grad_vars, const platform::Place &place, - const detail::BackwardStrategy &strategy, - bool create_graph); + const detail::BackwardStrategy &strategy, bool create_graph, + bool retain_graph, bool allow_unused, bool only_inputs); void Execute() override; @@ -50,6 +50,9 @@ class PartialGradEngine : public Engine { platform::Place place_; detail::BackwardStrategy strategy_; bool create_graph_; + bool retain_graph_; + bool allow_unused_; + bool only_inputs_; std::vector> results_; }; diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 932d1b3992..b334e009bf 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -783,10 +783,11 @@ void BindImperative(py::module *m_ptr) { const std::vector> &no_grad_vars, const platform::Place &place, const imperative::detail::BackwardStrategy &strategy, - bool create_graph) { - imperative::PartialGradEngine engine(input_targets, output_targets, - output_grads, no_grad_vars, place, - strategy, create_graph); + bool create_graph, bool retain_graph, bool allow_unused, + bool only_inputs) { + imperative::PartialGradEngine engine( + input_targets, output_targets, output_grads, no_grad_vars, place, + strategy, create_graph, retain_graph, allow_unused, only_inputs); engine.Execute(); return engine.GetResult(); }, diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py index 22b95dfd5f..2c55dc8951 100644 --- a/python/paddle/fluid/dygraph/base.py +++ b/python/paddle/fluid/dygraph/base.py @@ -23,6 +23,7 @@ import objgraph __all__ = [ 'no_grad', + 'grad', 'guard', 'enable_dygraph', 'disable_dygraph', @@ -254,9 +255,145 @@ def _print_debug_msg(parameter_list, limit=5, is_test=False): def grad(outputs, inputs, grad_outputs=None, - no_grad_set=None, + retain_graph=None, create_graph=False, + only_inputs=True, + allow_unused=False, + no_grad_vars=None, backward_strategy=None): + ''' + .. note:: + **This API is ONLY available in Dygraph mode.** + + This API computes the sum of gradients of `outputs` with respect to each `inputs` . + + Parameters: + outputs (Variable|list(Variable)|tuple(Variable)): the output Variable or + Variable list/tuple of the graph to compute gradients. + inputs (Variable|list(Variable)|tuple(Variable)): the input Variable or + Variable list/tuple of the graph to compute gradients. The returned + values of this API are the gradients of `inputs` . + grad_outputs (Variable|list(Variable|None)|tuple(Variable|None), optional): + initial gradient values of `outputs` . If `grad_outputs` is None, + the initial gradient values of `outputs` would be Tensors filled with 1; + if `grad_outputs` is not None, it must have the same length as `outputs` , + and in this case, the initial gradient value of the i-th `outputs` would + be: (1) a Tensor filled with 1 when the i-th element of `grad_outputs` + is None; (2) the i-th element of `grad_outputs` when the i-th element of + `grad_outputs` is a Variable. Default None. + retain_graph (bool, optional): whether to retain the forward graph which + is used to calculate the gradient. When it is True, the graph would + be retained, in which way users can calculate backward twice for the + same graph. When it is False, the graph would be freed. Default None, + which means it is equal to `create_graph` . + create_graph (bool, optional): whether to create the gradient graphs of + the computing process. When it is True, higher order derivatives are + supported to compute; when it is False, the gradient graphs of the + computing process would be discarded. Default False. + only_inputs (bool, optional): whether to only compute the gradients of + `inputs` . If it is False, the gradients of all remaining leaf + Variables in the graph would be also computed and accumulated. + If it is True, only the gradients of `inputs` would be computed. + Default True. only_inputs=False is under development, and it is + not supported yet. + allow_unused (bool, optional): whether to raise error or return None if some + Variables of `inputs` are unreachable in the graph. If some Variables of + `inputs` are unreachable in the graph (i.e., their gradients are None), + error would be raised if allow_unused=False, or None would be returned as + their gradients if allow_unused=True. Default False. + no_grad_vars (Variable|list(Variable)|tuple(Variable)|set(Variable), optional): + the Variables whose gradients are not needed to compute. Default None. + backward_strategy (BackwardStrategy, optional): The backward strategy to + compute gradients. See :ref:`api_fluid_dygraph_BackwardStrategy` for + details. Default None. + + Returns: + tuple: a tuple of Variables, whose length is the same as the Variable number + inside `inputs`, and the i-th returned Variable is the sum of gradients of + `outputs` with respect to the i-th `inputs`. + + Examples 1: + .. code-block:: python + + import paddle.fluid as fluid + + def test_dygraph_grad(create_graph): + with fluid.dygraph.guard(): + x = fluid.layers.ones(shape=[1], dtype='float32') + x.stop_gradient = False + y = x * x + + # Since y = x * x, dx = 2 * x + dx = fluid.dygraph.grad( + outputs=[y], + inputs=[x], + create_graph=create_graph, + retain_graph=True)[0] + + z = y + dx + + # If create_graph = False, the gradient of dx + # would not be backpropagated. Therefore, + # z = x * x + dx, and x.gradient() = 2 * x = 2.0 + + # If create_graph = True, the gradient of dx + # would be backpropagated. Therefore, + # z = x * x + dx = x * x + 2 * x, and + # x.gradient() = 2 * x + 2 = 4.0 + + z.backward() + return x.gradient() + + print(test_dygraph_grad(create_graph=False)) # [2.] + print(test_dygraph_grad(create_graph=True)) # [4.] + + Examples 2: + .. code-block:: python + + import paddle.fluid as fluid + + fluid.enable_dygraph() + + def test_dygraph_grad(grad_outputs=None): + x = fluid.layers.fill_constant(shape=[1], value=2.0, dtype='float32') + x.stop_gradient = False + + y1 = x * x + y2 = x * 3 + + # If grad_outputs=None, dy1 = [1], dy2 = [1]. + # If grad_outputs=[g1, g2], then: + # - dy1 = [1] if g1 is None else g1 + # - dy2 = [1] if g2 is None else g2 + + # Since y1 = x * x, dx = 2 * x * dy1. + # Since y2 = x * 3, dx = 3 * dy2. + # Therefore, the final result would be: + # dx = 2 * x * dy1 + 3 * dy2 = 4 * dy1 + 3 * dy2. + + dx = fluid.dygraph.grad( + outputs=[y1, y2], + inputs=[x], + grad_outputs=grad_outputs)[0] + + return dx.numpy() + + THREE = fluid.layers.fill_constant(shape=[1], value=3.0, dtype='float32') + FOUR = fluid.layers.fill_constant(shape=[1], value=4.0, dtype='float32') + + # dy1 = [1], dy2 = [1] + print(test_dygraph_grad(None)) # [7.] + + # dy1 = [1], dy2 = [4] + print(test_dygraph_grad([None, FOUR])) # [16.] + + # dy1 = [4], dy2 = [1] + print(test_dygraph_grad([FOUR, None])) # [19.] + + # dy1 = [3], dy2 = [4] + print(test_dygraph_grad([THREE, FOUR])) # [24.] + ''' + def check_in_out(in_out_list, name): assert in_out_list is not None, "{} should not be None".format(name) @@ -294,18 +431,18 @@ def grad(outputs, assert len(grad_outputs) == len( outputs), "The length of grad_outputs must be equal to outputs" - if no_grad_set is None: - no_grad_set = [] - elif isinstance(no_grad_set, core.VarBase): - no_grad_set = [no_grad_set] - elif isinstance(no_grad_set, (list, tuple, set)): - no_grad_set = list(no_grad_set) - for var in no_grad_set: + if no_grad_vars is None: + no_grad_vars = [] + elif isinstance(no_grad_vars, core.VarBase): + no_grad_vars = [no_grad_vars] + elif isinstance(no_grad_vars, (list, tuple, set)): + no_grad_vars = list(no_grad_vars) + for var in no_grad_vars: assert isinstance( - var, core.VarBase), "no_grad_set can only contains Variable" + var, core.VarBase), "no_grad_vars can only contains Variable" else: raise AssertionError( - "no_grad_set must be None, Variable or list/tuple/set of Variables") + "no_grad_vars must be None, Variable or list/tuple/set of Variables") if backward_strategy is None: backward_strategy = core.BackwardStrategy() @@ -315,10 +452,22 @@ def grad(outputs, assert isinstance(create_graph, bool), "create_graph must be True or False" + if retain_graph is None: + retain_graph = create_graph + + assert isinstance(retain_graph, + bool), "retain_graph must be None, True or False" + + assert isinstance(allow_unused, bool), "allow_unused must be True or False" + + assert isinstance(only_inputs, bool), "only_inputs must be True or False" + assert only_inputs, "only_inputs=False is not supported yet" + place = core.Place() place.set_place(framework._current_expected_place()) - return core.dygraph_partial_grad(inputs, outputs, grad_outputs, no_grad_set, - place, backward_strategy, create_graph) + return core.dygraph_partial_grad( + inputs, outputs, grad_outputs, no_grad_vars, place, backward_strategy, + create_graph, retain_graph, allow_unused, only_inputs) @framework.dygraph_only diff --git a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py index a0e1de32cd..0fa1556a02 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py @@ -17,7 +17,6 @@ from paddle.fluid.wrapped_decorator import wrap_decorator import unittest from unittest import TestCase import numpy as np -from paddle.fluid.dygraph.base import grad def _dygraph_guard_(func): @@ -48,16 +47,20 @@ class TestDygraphDoubleGrad(TestCase): outputs, inputs, grad_outputs=None, - no_grad_set=None, - create_graph=False): + no_grad_vars=None, + retain_graph=None, + create_graph=False, + allow_unused=False): backward_strategy = fluid.dygraph.BackwardStrategy() backward_strategy.sort_sum_gradient = self.sort_sum_gradient - return grad( + return fluid.dygraph.grad( outputs=outputs, inputs=inputs, grad_outputs=grad_outputs, - no_grad_set=no_grad_set, + no_grad_vars=no_grad_vars, + retain_graph=retain_graph, create_graph=create_graph, + allow_unused=allow_unused, backward_strategy=backward_strategy) @dygraph_guard @@ -84,10 +87,11 @@ class TestDygraphDoubleGrad(TestCase): [random_var(shape)], [random_var(shape)]) with self.assertRaises(AssertionError): - self.grad([random_var(shape)], [random_var(shape)], no_grad_set=[1]) + self.grad( + [random_var(shape)], [random_var(shape)], no_grad_vars=[1]) with self.assertRaises(AssertionError): - self.grad([random_var(shape)], [random_var(shape)], no_grad_set=1) + self.grad([random_var(shape)], [random_var(shape)], no_grad_vars=1) @dygraph_guard def test_simple_example(self): @@ -96,17 +100,20 @@ class TestDygraphDoubleGrad(TestCase): y = x + 1 for create_graph in [False, True]: - dx, = self.grad([x], [x], create_graph=create_graph) + dx, = self.grad( + [x], [x], create_graph=create_graph, retain_graph=True) self.assertEqual(dx.shape, x.shape) self.assertTrue(np.all(dx.numpy() == 1)) self.assertNotEqual(dx.stop_gradient, create_graph) - dx_mul_2, = self.grad([y, x], [x], create_graph=create_graph) + dx_mul_2, = self.grad( + [y, x], [x], create_graph=create_graph, retain_graph=True) self.assertEqual(dx_mul_2.shape, x.shape) self.assertTrue(np.all(dx_mul_2.numpy() == 2)) self.assertNotEqual(dx_mul_2.stop_gradient, create_graph) - none_grad, = self.grad([x], [y], create_graph=create_graph) + none_grad, = self.grad( + [x], [y], create_graph=create_graph, allow_unused=True) self.assertTrue(none_grad is None) grad_with_none_and_not_none, = self.grad( @@ -160,7 +167,8 @@ class TestDygraphDoubleGrad(TestCase): outputs=[y, z], inputs=[x], grad_outputs=[grad_y, grad_z], - create_graph=create_graph) + create_graph=create_graph, + retain_graph=True) grad_y_np = ones_grad_y if grad_y is None else grad_y.numpy( ) @@ -216,7 +224,7 @@ class TestDygraphDoubleGrad(TestCase): self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) @dygraph_guard - def test_example_with_gradient_accumulation_and_no_grad_set(self): + def test_example_with_gradient_accumulation_and_no_grad_vars(self): x = random_var(self.shape) x_np = x.numpy() numel = x_np.size @@ -231,7 +239,7 @@ class TestDygraphDoubleGrad(TestCase): del y1, z, w dx_actual, = self.grad( - [w_mean], [x], create_graph=True, no_grad_set=[y2]) + [w_mean], [x], create_graph=True, no_grad_vars=[y2]) self.assertFalse(y2.stop_gradient) self.assertFalse(dx_actual.stop_gradient) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py b/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py index 6122cc6ab8..7bd026daa9 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py @@ -368,11 +368,10 @@ def loss_cls(cls, label, cfg): def calc_gradients(outputs, inputs, no_grad_set): if fluid.in_dygraph_mode(): - from paddle.fluid.dygraph.base import grad - return grad( + return fluid.dygraph.grad( outputs=outputs, inputs=inputs, - no_grad_set=no_grad_set, + no_grad_vars=no_grad_set, create_graph=True) else: return fluid.gradients( -- GitLab