diff --git a/python/paddle/autograd/__init__.py b/python/paddle/autograd/__init__.py
index f4a0122759dc5ddd3b98f5a7c6404d040637f837..cffc18e95e5ab3ff3bbb9fea6cca0e1579272866 100644
--- a/python/paddle/autograd/__init__.py
+++ b/python/paddle/autograd/__init__.py
@@ -18,6 +18,6 @@ from .backward_mode import backward  # noqa: F401
 from .py_layer import PyLayer, PyLayerContext  # noqa: F401
 from ..framework import set_grad_enabled  # noqa: F401
 from ..fluid.dygraph.base import no_grad_ as no_grad  # noqa: F401
-from .functional import jacobian, hessian  # noqa: F401
+from .functional import vjp, jvp, jacobian, hessian  # noqa: F401
 
 __all__ = ['backward', 'PyLayer', 'PyLayerContext']
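With the export above, `vjp` and `jvp` become reachable as `paddle.autograd.vjp` and `paddle.autograd.jvp` (they are deliberately left out of `__all__`). A minimal usage sketch, not part of the patch, assuming a dygraph (imperative) session; the function and shapes are illustrative only:

import paddle
from paddle.autograd import vjp, jvp  # exposed by the import added above

def func(x):
    return paddle.matmul(x, x)

x = paddle.ones([2, 2])
v = paddle.to_tensor([[1.0, 0.0], [0.0, 0.0]])

out, vjp_result = vjp(func, x, v)  # reverse mode: cotangent v pulled back through func
out, jvp_result = jvp(func, x, v)  # forward mode: tangent v pushed forward through func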
diff --git a/python/paddle/autograd/functional.py b/python/paddle/autograd/functional.py
index a5665631c937f80a3d4469c796d9c64a5aa754d5..688e04335ebb700135c41dc6c483afcfdad63115 100644
--- a/python/paddle/autograd/functional.py
+++ b/python/paddle/autograd/functional.py
@@ -12,9 +12,239 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from paddle.fluid import framework
-from .utils import _check_tensors, _stack_tensor_or_return_none, _replace_none_with_zero_tensor
+import contextlib
 import paddle
+from ..fluid import framework
+from ..fluid.dygraph import grad
+from ..nn.initializer import assign
+from ..tensor import reshape, zeros_like, to_tensor
+from .utils import _check_tensors, _stack_tensor_or_return_none, _replace_none_with_zero_tensor
+
+
+def to_tensorlist(tl):
+    if not isinstance(tl, list):
+        if isinstance(tl, tuple):
+            tl = list(tl)
+        else:
+            tl = [tl]
+    for t in tl:
+        assert isinstance(t, paddle.Tensor) or t is None, (
+            f'{t} is expected to be paddle.Tensor or None, but found {type(t)}.'
+        )
+    return tl
+
+
+@contextlib.contextmanager
+def gradient_scope(*var_lists, create_graph=False, allow_unused=False):
+    def grad_fn(ys, xs, v, create_graph=create_graph):
+        assert len(ys) == len(v), (
+            f'`v` is expected to be of the same size as the output. '
+            f'Here the output is {ys}, and `v` is {v}.')
+        if allow_unused:
+            ys = [
+                to_tensor(
+                    [0.0], stop_gradient=False) if y is None else y for y in ys
+            ]
+        return grad(
+            ys, xs, v, create_graph=create_graph, allow_unused=allow_unused)
+
+    def return_fn(out):
+        if isinstance(out, paddle.Tensor):
+            if not create_graph:
+                out = out.detach()
+            return out
+        if isinstance(out, list):
+            return list(return_fn(x) for x in out)
+        elif isinstance(out, tuple):
+            return tuple(return_fn(x) for x in out)
+        else:
+            assert out is None
+            return out
+
+    def process(vl):
+        out = []
+        # If v is treated as constant in the outer scope, its gradient is
+        # guaranteed not to be taken beyond this scope. Within this scope,
+        # however, v's gradient may be computed. We only need to detach v
+        # in this case.
+        # Otherwise, v's gradient is valid, and is subject to update beyond
+        # this scope. In this case we must not confuse the gradient in the
+        # outer scope with the inner one's. Moreover, we need to make sure
+        # that the result from the inner scope can flow back to the outer
+        # scope. This can be satisfied by extending the original variable
+        # with a duplication operation v1 = v so that v still maintains the
+        # complete lineage.
+        for v in vl:
+            if v is None:
+                out.append(v)
+                continue
+            if create_graph and not v.stop_gradient:
+                v = assign(v)
+            else:
+                v = v.detach()
+                v.stop_gradient = False
+            out.append(v)
+        return out
+
+    try:
+        var_lists = [process(vl) for vl in var_lists]
+        bundle = var_lists + [grad_fn, return_fn]
+        yield bundle
+    finally:
+        pass
+
+
+@framework.dygraph_only
+def vjp(func, inputs, v=None, create_graph=False, allow_unused=False):
+    r"""Computes the Vector-Jacobian product, a functional form of
+    reverse mode automatic differentiation.
+
+    Args:
+        func(Callable): `func` takes as input a tensor or a list
+            of tensors and returns a tensor or a list of tensors.
+        inputs(list[Tensor]|Tensor): used as positional arguments
+            to evaluate `func`. `inputs` is accepted as one tensor
+            or a list of tensors.
+        v(list[Tensor]|Tensor, optional): the cotangent vector
+            involved in the VJP computation. `v` matches the size
+            and shape of `func`'s output. Default value is None,
+            which is equivalent to an all-ones vector with the same
+            size and shape as `func`'s output.
+        create_graph(bool, optional): if `True`, gradients can
+            be evaluated on the results. If `False`, taking gradients
+            on the results is invalid. Default value is False.
+        allow_unused(bool, optional): in case that some Tensors of
+            `inputs` do not contribute to the computation of the output.
+            If `allow_unused` is False, an error will be raised.
+            Otherwise, the gradients of the said inputs are returned
+            as None. Default value is False.
+
+    Returns:
+        output(tuple):
+            func_out: the output of `func(inputs)`
+            vjp(list[Tensor]|Tensor): the pullback results of `v` on `func`
+
+    Examples:
+        .. code-block:: python
+
+            def func(x):
+                return paddle.matmul(x, x)
+
+            x = paddle.ones(shape=[2, 2], dtype='float32')
+            output, inputs_grad = vjp(func, x)
+            print(inputs_grad)
+            # [Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            #        [[4., 4.],
+            #         [4., 4.]])]
+
+            v = paddle.to_tensor([[1.0, 0.0], [0.0, 0.0]])
+            output, inputs_grad = vjp(func, x, v)
+            print(inputs_grad)
+            # [Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            #        [[2., 1.],
+            #         [1., 0.]])]
+
+            output, inputs_grad = vjp(func, x, v, create_graph=True)
+            print(inputs_grad)
+            # [Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
+            #        [[2., 1.],
+            #         [1., 0.]])]
+
+            y = paddle.ones(shape=[2, 2], dtype='float32')
+            def func_unused(x, y):
+                return paddle.matmul(x, x)
+
+            output, inputs_grad = vjp(func_unused, [x, y], v)
+            # ValueError: (InvalidArgument) The 1-th input does not appear in the backward graph.
+            # Please check the input variable or set allow_unused=True to get None result.
+            # [Hint: Expected allow_unused_ == true, but received allow_unused_:0 != true:1.]
+
+            output, inputs_grad = vjp(func_unused, [x, y], v, allow_unused=True)
+            print(inputs_grad)
+            # [Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            #        [[2., 1.],
+            #         [1., 0.]]), None]
+    """
+    xs, v = to_tensorlist(inputs), to_tensorlist(v)
+
+    with gradient_scope(
+            xs, v, create_graph=create_graph,
+            allow_unused=allow_unused) as [xs, v, grad_fn, return_fn]:
+        outputs = func(*xs)
+        ys = to_tensorlist(outputs)
+        grads = grad_fn(ys, xs, v)
+        outputs, grads = return_fn(outputs), return_fn(grads)
+
+    return outputs, grads
+
+
+@framework.dygraph_only
+def jvp(func, inputs, v=None, create_graph=False, allow_unused=False):
+    r"""
+    Computes the Jacobian-Vector product for a function at the given
+    inputs and a vector in the tangent space induced by the inputs.
+
+    .. note::
+        **This API is ONLY available in imperative mode.**
+
+    Args:
+        func(Callable): `func` takes as input a tensor or a list
+            of tensors and returns a tensor or a list of tensors.
+        inputs(list[Tensor]|Tensor): used as positional arguments
+            to evaluate `func`. `inputs` is accepted as one tensor
+            or a list of tensors.
+        v(list[Tensor]|Tensor, optional): the tangent vector
+            involved in the JVP computation. `v` matches the size
+            and shape of `inputs`. `v` is optional if `func` returns
+            a single tensor. Default value is None, which is
+            equivalent to an all-ones vector with the same size and
+            shape as `inputs`.
+        create_graph(bool, optional): if `True`, gradients can
+            be evaluated on the results. If `False`, taking gradients
+            on the results is invalid. Default value is False.
+        allow_unused(bool, optional): in case that some Tensors of
+            `inputs` do not contribute to the computation of the output.
+            If `allow_unused` is False, an error will be raised.
+            Otherwise, the gradients of the said inputs are returned
+            as None. Default value is False.
+
+    Returns:
+        output(tuple):
+            func_out: the output of `func(inputs)`
+            jvp(list[Tensor]|Tensor): the pushforward results of `v` on `func`
+
+    Examples:
+        .. code-block:: python
+
+            def func(x):
+                return paddle.matmul(x, x)
+
+            x = paddle.ones(shape=[2, 2], dtype='float32')
+
+            output, inputs_grad = jvp(func, x)
+            print(inputs_grad)
+            # [Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            #        [[2., 2.],
+            #         [2., 2.]])]
+
+            v = paddle.to_tensor([[1.0, 0.0], [0.0, 0.0]])
+            output, inputs_grad = jvp(func, x, v)
+            print(inputs_grad)
+            # [Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            #        [[1., 1.],
+            #         [0., 0.]])]
+
+    """
+    xs, v = to_tensorlist(inputs), to_tensorlist(v)
+
+    with gradient_scope(
+            xs, v, create_graph=create_graph,
+            allow_unused=allow_unused) as [xs, v, grad_fn, return_fn]:
+        outputs = func(*xs)
+        ys = to_tensorlist(outputs)
+        ys_grad = [zeros_like(y) for y in ys]
+        xs_grad = grad_fn(ys, xs, ys_grad, create_graph=True)
+        ys_grad = grad_fn(xs_grad, ys_grad, v)
+        outputs, ys_grad = return_fn(outputs), return_fn(ys_grad)
+
+    return outputs, ys_grad
 
 
 @framework.dygraph_only
@@ -60,7 +290,7 @@ def jacobian(func, inputs, create_graph=False, allow_unused=False):
 
             def func(x):
                 return paddle.matmul(x, x)
-            
+
             x = paddle.ones(shape=[2, 2], dtype='float32')
             x.stop_gradient = False
             jacobian = paddle.autograd.jacobian(func, x)
@@ -78,7 +308,7 @@ def jacobian(func, inputs, create_graph=False, allow_unused=False):
 
             def func(x, y):
                 return paddle.matmul(x, y)
-            
+
             x = paddle.ones(shape=[2, 2], dtype='float32')
             y = paddle.ones(shape=[2, 2], dtype='float32') * 2
             x.stop_gradient = False
@@ -131,14 +361,12 @@ def jacobian(func, inputs, create_graph=False, allow_unused=False):
     outputs = _check_tensors(func(*inputs), "outputs")
     fin_size = len(inputs)
     fout_size = len(outputs)
-    flat_outputs = tuple(
-        paddle.reshape(
-            output, shape=[-1]) for output in outputs)
+    flat_outputs = tuple(reshape(output, shape=[-1]) for output in outputs)
     jacobian = tuple()
     for i, flat_output in enumerate(flat_outputs):
         jac_i = list([] for _ in range(fin_size))
         for k in range(len(flat_output)):
-            row_k = paddle.grad(
+            row_k = grad(
                 flat_output[k],
                 inputs,
                 create_graph=create_graph,
@@ -146,7 +374,7 @@ def jacobian(func, inputs, create_graph=False, allow_unused=False):
                 allow_unused=allow_unused)
             for j in range(fin_size):
                 jac_i[j].append(
-                    paddle.reshape(
+                    reshape(
                         row_k[j], shape=[-1]) if isinstance(row_k[j],
                                                             paddle.Tensor) else None)
         jacobian += (tuple(
@@ -273,7 +501,7 @@ def hessian(func, inputs, create_graph=False, allow_unused=False):
     ], "The function to compute Hessian matrix should return a Tensor with a single element"
 
     def jac_func(*ins):
-        grad_inputs = paddle.grad(
+        grad_inputs = grad(
             outputs,
             ins,
             create_graph=True,
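For readers unfamiliar with the trick used in `jvp` above: forward-mode products are obtained from two reverse-mode passes, since dygraph `grad` only provides VJPs. The sketch below, which is not part of the patch, illustrates that idea only; the helper name `double_backward_jvp` is made up for this example, and it assumes the operators involved support second-order gradients in imperative mode.

import paddle

def double_backward_jvp(func, x, v):
    # Illustrative only: JVP of a single-input, single-output `func` via two VJPs.
    x = x.detach()
    x.stop_gradient = False
    y = func(x)
    # Pass 1: with a dummy cotangent u kept in the graph, build g(u) = J^T u,
    # which is linear in u (create_graph=True keeps u differentiable).
    u = paddle.zeros_like(y)
    u.stop_gradient = False
    vjp_u, = paddle.grad([y], [x], [u], create_graph=True)
    # Pass 2: differentiate <v, J^T u> with respect to u, which yields J v.
    jvp_v, = paddle.grad([vjp_u], [u], [v])
    return y, jvp_v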
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_vjp_jvp.py b/python/paddle/fluid/tests/unittests/autograd/test_vjp_jvp.py
new file mode 100644
index 0000000000000000000000000000000000000000..86331d36a3ca8201b42ef02cc1b8edd29180cce0
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/autograd/test_vjp_jvp.py
@@ -0,0 +1,294 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import paddle
+
+from paddle.autograd.functional import vjp, jvp, to_tensorlist
+from paddle import grad, ones_like, zeros_like
+
+
+def reduce(x):
+    return paddle.sum(x)
+
+
+def reduce_dim(x):
+    return paddle.sum(x, axis=0)
+
+
+def matmul(x, y):
+    return paddle.matmul(x, y)
+
+
+def mul(x, y):
+    return x * y
+
+
+def pow(x, y):
+    return paddle.pow(x, y)
+
+
+def o2(x, y):
+    return paddle.multiply(x, y), paddle.matmul(x, y.t())
+
+
+def unuse(x, y):
+    return paddle.sum(x)
+
+
+def nested(x):
+    def inner(y):
+        return x * y
+
+    return inner
+
+
+def make_v(f, inputs):
+    outputs = to_tensorlist(f(*inputs))
+    return [ones_like(x) for x in outputs]
+
+
+class TestAutogradFunctional(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.RAW_INPUTS = {
+            'a': [1.0],
+            'b': [1.0, 2.0],
+            'c': [3.0, 4.0],
+            'd': [[2.0], [3.0]],
+            'A': [[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]],
+            'B': [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]],
+        }
+
+    def setUp(self):
+        pass
+
+    def gen_input(self, inp, stop_gradient=False):
+        if isinstance(inp, paddle.Tensor):
+            return inp
+        return paddle.to_tensor(
+            self.RAW_INPUTS[inp], stop_gradient=stop_gradient)
+
+    def gen_inputs(self, inputs):
+        if isinstance(inputs, list):
+            inputs = [self.gen_input(x) for x in inputs]
+        else:
+            inputs = [self.gen_input(inputs)]
+        return inputs
+
+    def gen_test_pairs(self,
+                       func,
+                       inputs,
+                       v=None,
+                       create_graph=False,
+                       allow_unused=False):
+        def vjp_test():
+            nonlocal v
+            xs = self.gen_inputs(inputs)
+            if v is not None:
+                v = self.gen_inputs(v)
+                outputs, inputs_grad = vjp(func,
+                                           xs,
+                                           v,
+                                           create_graph=create_graph,
+                                           allow_unused=allow_unused)
+            else:
+                outputs, inputs_grad = vjp(func,
+                                           xs,
+                                           create_graph=create_graph,
+                                           allow_unused=allow_unused)
+            return outputs, inputs_grad
+
+        def grad_test():
+            nonlocal v
+            xs = self.gen_inputs(inputs)
+            if v is not None:
+                v = self.gen_inputs(v)
+            outputs = func(*xs)
+            if v is not None:
+                inputs_grad = grad(
+                    outputs,
+                    xs,
+                    v,
+                    create_graph=create_graph,
+                    allow_unused=allow_unused)
+            else:
+                inputs_grad = grad(
+                    outputs,
+                    xs,
+                    create_graph=create_graph,
+                    allow_unused=allow_unused)
+            return outputs, inputs_grad
+
+        return vjp_test, grad_test
+
+    def gen_jvp_tests(self,
+                      func,
+                      inputs,
+                      v=None,
+                      create_graph=False,
+                      allow_unused=False):
+        def jvp_test():
+            nonlocal v
+            xs = self.gen_inputs(inputs)
+            if v is not None:
+                v = self.gen_inputs(v)
+                outputs, outputs_grad = jvp(func,
+                                            xs,
+                                            v,
+                                            create_graph=create_graph,
+                                            allow_unused=allow_unused)
+            else:
+                outputs, outputs_grad = jvp(func,
+                                            xs,
+                                            create_graph=create_graph,
+                                            allow_unused=allow_unused)
+            return outputs, outputs_grad
+
+        return jvp_test
+
+    def check_results(self, ref, res):
+        type_error = 'Result is different than expected in shape or type'
+        value_error = 'Result is different than expected values'
+        if ref is None:
+            self.assertTrue(res is None, type_error)
+        elif isinstance(ref, paddle.Tensor):
+            self.assertTrue(isinstance(res, paddle.Tensor), type_error)
+            self.assertTrue(paddle.allclose(res, ref), value_error)
+        else:
+            self.assertTrue(len(res) == len(ref), type_error)
+            for i in range(len(ref)):
+                self.check_results(ref[i], res[i])
+        return True
+
+
+class TestVJP(TestAutogradFunctional):
+    def test_vjp_i1o1_no_create_graph(self):
+        test_cases = [
+            [reduce, 'A'],  #noqa
+            [reduce_dim, 'A'],  #noqa
+        ]  #noqa
+        for f, inputs in test_cases:
+            vjp, grad = self.gen_test_pairs(f, inputs)
+            vjp_result, grad_result = vjp(), grad()
+            self.check_results(grad_result, vjp_result)
+
+    def test_vjp_i2o1_no_create_graph(self):
+        test_cases = [
+            [matmul, ['A', 'B']],  #noqa
+            [mul, ['b', 'c']],  #noqa
+        ]  #noqa
+        for f, inputs in test_cases:
+            vjp, grad = self.gen_test_pairs(f, inputs)
+            vjp_result, grad_result = vjp(), grad()
+            self.check_results(grad_result, vjp_result)
+
+    def test_vjp_i2o2_no_create_graph(self):
+        test_cases = [
+            [o2, ['A', 'A']],  #noqa
+        ]  #noqa
+        for f, inputs in test_cases:
+            inputs = self.gen_inputs(inputs)
+            v = make_v(f, inputs)
+            vjp, grad = self.gen_test_pairs(f, inputs, v=v)
+            vjp_result, grad_result = vjp(), grad()
+            self.check_results(grad_result, vjp_result)
+
+    def test_vjp_nested_no_create_graph(self):
+        x = self.gen_input('a')
+        test_cases = [
+            [nested(x), 'a'],  #noqa
+        ]
+        for f, inputs in test_cases:
+            vjp, grad = self.gen_test_pairs(f, inputs)
+            vjp_result, grad_result = vjp(), grad()
+            self.check_results(grad_result, vjp_result)
+
+    def test_vjp_aliased_input_no_create_graph(self):
+        x = self.gen_input('a')
+        ref = self.gen_test_pairs(nested(x), 'a')[0]
+        aliased = self.gen_test_pairs(nested(x), x)[0]
+        ref_result, aliased_result = ref(), aliased()
+        self.check_results(ref_result, aliased_result)
+
+    def test_vjp_allowunused_no_create_graph(self):
+        x, y = self.gen_input('A'), self.gen_input('a')
+        vjp, grad = self.gen_test_pairs(unuse, [x, y], allow_unused=True)
+        vjp_result, grad_result = vjp(), grad()
+        self.check_results(grad_result, vjp_result)
+
+
+def jac(grad_fn, f, inputs):
+    assert grad_fn in [vjp, jvp]
+    if grad_fn is jvp:
+        vs = [zeros_like(x) for x in inputs]
+    else:
+        outputs = f(*inputs)
+        if isinstance(outputs, paddle.Tensor):
+            outputs = [outputs]
+        vs = [zeros_like(y) for y in outputs]
+    JJ_cols = []
+    for i, v in enumerate(vs):
+        v = v.flatten()
+        for j in range(len(v)):
+            _v = zeros_like(v).detach()
+            _v[j] = 1.0
+            _v = _v.reshape(vs[i].shape)
+            _vs = vs.copy()
+            _vs[i] = _v
+            _, grads = grad_fn(f, inputs, _vs)
+            d_outs = paddle.concat([d_out.flatten() for d_out in grads])
+            JJ_cols.append(d_outs)
+    # JJ is the fully unrolled jacobian
+    JJ = paddle.stack(JJ_cols)
+    if grad_fn is vjp:
+        JJ = JJ.t()
+    return JJ
+
+
+class TestJVP(TestAutogradFunctional):
+    def test_jvp_i1o1_no_create_graph(self):
+        test_cases = [
+            [reduce, 'A'],  #noqa
+            [reduce_dim, 'A'],  #noqa
+        ]  #noqa
+        for f, inputs in test_cases:
+            inputs = self.gen_inputs(inputs)
+            forward_jac = jac(jvp, f, inputs)
+            reverse_jac = jac(vjp, f, inputs)
+            self.check_results(forward_jac, reverse_jac)
+
+    def test_jvp_i2o1_no_create_graph(self):
+        test_cases = [  #noqa
+            [matmul, ['A', 'B']],  #noqa
+        ]  #noqa
+        for f, inputs in test_cases:
+            inputs = self.gen_inputs(inputs)
+            forward_jac = jac(jvp, f, inputs)
+            reverse_jac = jac(vjp, f, inputs)
+            self.check_results(forward_jac, reverse_jac)
+
+    def test_jvp_i2o2_no_create_graph(self):
+        test_cases = [  #noqa
+            [o2, ['A', 'A']],  #noqa
+        ]  #noqa
+        for f, inputs in test_cases:
+            inputs = self.gen_inputs(inputs)
+            forward_jac = jac(jvp, f, inputs)
+            reverse_jac = jac(vjp, f, inputs)
+            self.check_results(forward_jac, reverse_jac)
+
+
+if __name__ == "__main__":
+    unittest.main()
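The TestJVP cases above rest on the idea that unrolling `jvp` column by column and `vjp` row by row must yield the same Jacobian matrix. A standalone sketch of that consistency check, not part of the patch, written against the public `paddle.autograd` API; the function, shapes, and helper name `unrolled_jacobian` are illustrative, and the one-hot probing mirrors the `jac` helper above under the assumption that the ops used support the required first- and second-order gradients:

import paddle
from paddle.autograd import vjp, jvp

def f(x):
    return paddle.tanh(paddle.sum(x, axis=0))

x = paddle.rand([3, 2])

def unrolled_jacobian(grad_fn):
    # Probe one basis vector at a time: jvp yields Jacobian columns,
    # vjp yields Jacobian rows; stacking (and transposing the jvp case)
    # makes the two constructions directly comparable.
    probe = x if grad_fn is jvp else f(x)
    pieces = []
    for k in range(int(paddle.numel(probe))):
        e = paddle.zeros_like(probe).flatten()
        e[k] = 1.0
        e = e.reshape(probe.shape)
        _, (g, ) = grad_fn(f, x, e)
        pieces.append(g.flatten())
    J = paddle.stack(pieces)
    return J.t() if grad_fn is jvp else J

# The forward-mode and reverse-mode constructions should agree numerically.
assert paddle.allclose(unrolled_jacobian(jvp), unrolled_jacobian(vjp))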