diff --git a/python/paddle/autograd/functional.py b/python/paddle/autograd/functional.py
index 17c7ad5b18af5f3e9b3b07f1f5c98fbf6ab56c49..66ae1562edb68a81231b70c477395814970cdd63 100644
--- a/python/paddle/autograd/functional.py
+++ b/python/paddle/autograd/functional.py
@@ -23,10 +23,11 @@ from .utils import _tensors, _stack_tensor_or_return_none, _replace_none_with_ze
 
 @contextlib.contextmanager
 def gradient_scope(*var_lists, create_graph=False, allow_unused=False):
-    def grad_fn(ys, xs, v, create_graph=create_graph):
-        assert len(ys) == len(v), (
-            f'`v` is expected to be of the same size as the output. '
-            f'Here the output is {ys}, and `v` is {v}.')
+    def grad_fn(ys, xs, v=None, create_graph=create_graph):
+        if v is not None:
+            assert len(ys) == len(v), (
+                f'The argument {v} is expected to be of the same size as the output. '
+                f'Here the output is {ys}, and `v` is {v}.')
         if allow_unused:
             ys = [
                 to_tensor(
@@ -49,6 +50,8 @@ def gradient_scope(*var_lists, create_graph=False, allow_unused=False):
         return out
 
     def process(vl):
+        if vl is None:
+            return None
         out = []
         # If v is treated as constant in the outer scope, its gradient is guaranteed
         # not to be taken beyond this scope. Within this scope, however, v's gradient
@@ -151,7 +154,9 @@ def vjp(func, inputs, v=None, create_graph=False, allow_unused=False):
            #            [[2., 1.],
            #             [1., 0.]]), None]
    """
-    xs, v = _tensors(inputs, "inputs"), _tensors(v, "v")
+    xs = _tensors(inputs, "inputs")
+    if v is not None:
+        v = _tensors(v, "v")
 
     with gradient_scope(
             xs, v, create_graph=create_graph,
@@ -221,7 +226,9 @@ def jvp(func, inputs, v=None, create_graph=False, allow_unused=False):
            #          [0., 0.]])]
 
    """
-    xs, v = _tensors(inputs, "inputs"), _tensors(v, "v")
+    xs = _tensors(inputs, "inputs")
+    if v is not None:
+        v = _tensors(v, "v")
 
     with gradient_scope(
             xs, v, create_graph=create_graph,
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_vjp_jvp.py b/python/paddle/fluid/tests/unittests/autograd/test_vjp_jvp.py
index f3680ab2a62238f2660a43801a0a82656c90c8c0..c228ad79321d4376ed43cc430ee8f939cd5f4b7c 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_vjp_jvp.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_vjp_jvp.py
@@ -205,6 +205,16 @@ class TestVJP(TestAutogradFunctional):
             vjp_result, grad_result = vjp(), grad()
             self.check_results(grad_result, vjp_result)
 
+    def test_vjp_i2o2_omitting_v_no_create_graph(self):
+        test_cases = [
+            [o2, ['A', 'A']],  #noqa
+        ]  #noqa
+        for f, inputs in test_cases:
+            inputs = self.gen_inputs(inputs)
+            vjp, grad = self.gen_test_pairs(f, inputs)
+            vjp_result, grad_result = vjp(), grad()
+            self.check_results(grad_result, vjp_result)
+
     def test_vjp_nested_no_create_graph(self):
         x = self.gen_input('a')
         test_cases = [
@@ -289,6 +299,17 @@ class TestJVP(TestAutogradFunctional):
             reverse_jac = jac(vjp, f, inputs)
             self.check_results(forward_jac, reverse_jac)
 
+    def test_jvp_i2o2_omitting_v_no_create_graph(self):
+        test_cases = [  #noqa
+            [o2, ['A', 'A']],  #noqa
+        ]  #noqa
+        for f, inputs in test_cases:
+            inputs = self.gen_inputs(inputs)
+            results_omitting_v = jvp(f, inputs)
+            v = [ones_like(x) for x in inputs]
+            results_with_v = jvp(f, inputs, v)
+            self.check_results(results_omitting_v, results_with_v)
+
 
 if __name__ == "__main__":
     unittest.main()
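
Usage sketch (not part of the patch): the snippet below illustrates the behavior this change enables, calling vjp and jvp with the v argument omitted. It is a minimal sketch, assuming a Paddle 2.x build running in the default dynamic-graph mode where paddle.autograd.functional.vjp and jvp return an (outputs, results) pair as in the docstring examples above; the names func and x are illustrative, not taken from the patch.

import paddle
from paddle.autograd.functional import jvp, vjp


def func(x):
    # A simple function whose Jacobian is easy to reason about.
    return paddle.matmul(x, x)


x = paddle.ones(shape=[2, 2], dtype='float32')
x.stop_gradient = False  # make x a differentiable leaf (defensive; may be optional)

# With this patch, v may be omitted instead of triggering a failure in
# _tensors()/grad_fn(). The new tests compare the omitted-v path against an
# explicit all-ones v, so each pair of calls below is expected to agree.
_, vjp_omitting_v = vjp(func, x)
_, vjp_with_v = vjp(func, x, paddle.ones_like(x))

_, jvp_omitting_v = jvp(func, x)
_, jvp_with_v = jvp(func, x, paddle.ones_like(x))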