From 4d24d35231eba598a7a9cd9f4a66111ec59cc09d Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Mon, 29 Nov 2021 11:31:38 +0800 Subject: [PATCH] [New features] Support batch_jacobian and batch_hessian (#37547) * native commit for triple grad of sigmod * Updated unittests files * init functional jacobian api * Updated trible_test func * Updated gradient_checker & test_script * finish test with dtype float32 * add float64 test case * polish code * use atol=1e-5 with dtype float64 * fix for ci * set timeout for test_jacobian * fix dygraph grad to support high differential * polish API docstring * Updated gradient checker and some related files * fix double grad strip error for high differential * fix double grad strip error for high differential * Add Sigmoid triple grad tests * fix dygraph double grad dtype error when calling for high differential senario * Updated triple grad teses func * Use np.random to initialize ddx * Updated triple_grad_check func * add todo for gradient checker and refine some comments * remove additional code * add test for warnging in backward.py * format python code * support multi input in triple gradient checker * Add matmul triple grad kernel * Updated comments of TODO * Supported some special tests * Change code-format to follow CI std * Updated gradient_checker.py * Fix conflicts * Removed unnecessary printing log * Change code style to follow CI std * support batch in jacobian and hessian * add batch jacobian and batch hessian * Add batch_jacobian test, draft version * [New features] Add elementwise_mul triple grad kernel (#37152) * Add elementwise_mul triple grad kernel * Removed InplaceInferer and polished code * Add numerical_batch_jacobian,numerical_batch_hessian and tests * Support batch_jacobian and batch_numerical * Use pre-commit to check code format * Update doc, polish code, add unit test * Reset the TIMEOUT properties of test_jacobian to pass CI Co-authored-by: levi131 Co-authored-by: Jiabin Yang <360788950@qq.com> --- python/paddle/autograd/__init__.py | 3 +- python/paddle/autograd/functional.py | 291 ++++++++++++++++++ .../tests/unittests/autograd/CMakeLists.txt | 2 +- .../tests/unittests/autograd/test_hessian.py | 125 +++++++- .../tests/unittests/autograd/test_jacobian.py | 159 +++++++++- .../fluid/tests/unittests/autograd/utils.py | 67 ++++ 6 files changed, 643 insertions(+), 4 deletions(-) diff --git a/python/paddle/autograd/__init__.py b/python/paddle/autograd/__init__.py index bbfb9f22fc1..661bc3485b3 100644 --- a/python/paddle/autograd/__init__.py +++ b/python/paddle/autograd/__init__.py @@ -18,6 +18,7 @@ from .backward_mode import backward # noqa: F401 from .py_layer import PyLayer, PyLayerContext # noqa: F401 from ..framework import set_grad_enabled # noqa: F401 from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401 -from .functional import vjp, jvp, jacobian, hessian, vhp # noqa: F401 +from .functional import jacobian, hessian, batch_jacobian, batch_hessian # noqa: F401 +from .functional import vjp, jvp, vhp # noqa: F401 __all__ = ['backward', 'PyLayer', 'PyLayerContext'] diff --git a/python/paddle/autograd/functional.py b/python/paddle/autograd/functional.py index c6235877f5b..2e5adfa5dfb 100644 --- a/python/paddle/autograd/functional.py +++ b/python/paddle/autograd/functional.py @@ -385,6 +385,297 @@ def jacobian(func, inputs, create_graph=False, allow_unused=False): return jacobian +@framework.dygraph_only +def batch_jacobian(func, inputs, create_graph=False, allow_unused=False): + ''' + .. 
note::
+        **This API is ONLY available in imperative mode.**
+
+    This function computes the batch Jacobian matrix of `func` with respect to `inputs`.
+    Note that the first dimension of the inputs is the batch size.
+
+    Parameters:
+        func (function): a Python function that takes a Tensor or a Tensor
+            list/tuple as inputs (the first dimension is the batch size) and
+            returns a Tensor or a Tensor tuple.
+        inputs (Tensor|list(Tensor)|tuple(Tensor)): the input Tensor or
+            Tensor list/tuple of the function ``func``. Note that
+            the first dimension of the inputs is the batch size.
+        create_graph (bool, optional): whether to create the gradient graphs
+            of the computing process. When it is True, higher order derivatives
+            can be computed; when it is False, the gradient graphs of the
+            computing process are discarded. Defaults to ``False``.
+        allow_unused (bool, optional): whether to raise an error or return None if
+            some Tensors of `inputs` are unreachable in the graph. An error is
+            raised if allow_unused=False, and None is returned as
+            their gradients if allow_unused=True. Defaults to ``False``.
+    Returns:
+        Jacobian (Tensor or nested tuple of Tensors): if the function ``func``
+        takes a Tensor as input and returns a Tensor as output, the Jacobian
+        will be a single Tensor containing the Jacobian matrix for the
+        linearized input and output. If one of the inputs or outputs is
+        a Tensor and the other is a Tensor list/tuple, then the Jacobian will
+        be a tuple of Tensors. If both the inputs and outputs are Tensor
+        lists/tuples, then the Jacobian will be a tuple of tuples of Tensors.
+        Note that the first dimension of the inputs is the batch size.
+
+        For example, if the input shape and output shape of the function ``func``
+        are [batch_size, num] and [batch_size, num] respectively, then the
+        Jacobian will be a Tensor with a shape of [num, batch_size * num],
+        where ``Jacobian[i][j]`` will contain the Jacobian matrix of the
+        ``i``th column of the output and the ``j``th input, and will have the
+        same dtype and device as the corresponding input.
+        Other situations can be deduced by analogy.
+
+    Example 1:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.ones(shape=(4, 2), dtype='float64')
+            weight = paddle.ones(shape=(2, 4), dtype='float64')
+            y = paddle.ones(shape=(4, 2), dtype='float64')
+
+            def func(x):
+                return paddle.matmul(paddle.matmul(x, weight), y)
+
+            x.stop_gradient = False
+            batch_jacobian = paddle.autograd.batch_jacobian(func, x)
+            print(batch_jacobian)
+            # Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
+            #        [[4., 4., 4., 4., 4., 4., 4., 4.],
+            #         [4., 4., 4., 4., 4., 4., 4., 4.]])
+
+    Example 2:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.ones(shape=(4, 2), dtype='float64')
+            weight = paddle.ones(shape=(2, 4), dtype='float64')
+            y = paddle.ones(shape=(4, 2), dtype='float64')
+
+            def func(x):
+                return paddle.matmul(paddle.matmul(x, weight), y), x * x
+
+            x.stop_gradient = False
+            batch_jacobian = paddle.autograd.batch_jacobian(func, x)
+            print(batch_jacobian)
+            # (Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
+            #        [[4., 4., 4., 4., 4., 4., 4., 4.],
+            #         [4., 4., 4., 4., 4., 4., 4., 4.]]),
+            #  Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
+            #        [[2., 0., 2., 0., 2., 0., 2., 0.],
+            #         [0., 2., 0., 2., 0., 2., 0., 2.]]))
+    Example 3:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.ones(shape=(4, 2), dtype='float64')
+            weight = paddle.ones(shape=(2, 4), dtype='float64')
+            y = paddle.ones(shape=(4, 2), dtype='float64')
+
+            def func(x, y):
+                return x * y
+
+            x.stop_gradient = False
+            y.stop_gradient = False
+            batch_jacobian = paddle.autograd.batch_jacobian(func, [x, y])
+            print(batch_jacobian)
+            # (Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
+            #        [[1., 0., 1., 0., 1., 0., 1., 0.],
+            #         [0., 1., 0., 1., 0., 1., 0., 1.]]),
+            #  Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
+            #        [[1., 0., 1., 0., 1., 0., 1., 0.],
+            #         [0., 1., 0., 1., 0., 1., 0., 1.]]))
+
+    '''
+    inputs = _tensors(inputs, "inputs")
+    outputs = _tensors(func(*inputs), "outputs")
+    batch_size = inputs[0].shape[0]
+    for input in inputs:
+        assert input.shape[
+            0] == batch_size, "The first dimension of each input must equal the batch size!"
+    for output in outputs:
+        assert output.shape[
+            0] == batch_size, "The first dimension of each output must equal the batch size!"
+    fin_size = len(inputs)
+    fout_size = len(outputs)
+    flat_outputs = tuple(
+        reshape(
+            output, shape=[batch_size, -1]) for output in outputs)
+    jacobian = tuple()
+    for i, flat_output in enumerate(flat_outputs):
+        jac_i = list([] for _ in range(fin_size))
+        for k in range(flat_output.shape[1]):
+            row_k = grad(
+                flat_output[:, k],
+                inputs,
+                create_graph=create_graph,
+                retain_graph=True,
+                allow_unused=allow_unused)
+            for j in range(fin_size):
+                jac_i[j].append(
+                    reshape(
+                        row_k[j], shape=[-1])
+                    if isinstance(row_k[j], paddle.Tensor) else None)
+        jacobian += (tuple(
+            _stack_tensor_or_return_none(jac_i_j) for jac_i_j in jac_i), )
+    if fin_size == 1 and fout_size == 1:
+        return jacobian[0][0]
+    elif fin_size == 1 and fout_size != 1:
+        return tuple(jacobian[i][0] for i in range(fout_size))
+    elif fin_size != 1 and fout_size == 1:
+        return jacobian[0]
+    else:
+        return jacobian
+
+
+@framework.dygraph_only
+def batch_hessian(func, inputs, create_graph=False, allow_unused=False):
+    '''
+    .. note::
+        **This API is ONLY available in imperative mode.**
+
+    This function computes the batch Hessian matrix of `func` with respect to `inputs`.
+    Note that the first dimension of the inputs is the batch size.
+
+    Parameters:
+        func (function): a Python function that takes a Tensor or a Tensor
+            list/tuple as inputs (the first dimension is the batch size) and
+            returns a Tensor with shape [batch_size, 1].
+        inputs (Tensor|list(Tensor)|tuple(Tensor)): the input Tensor or
+            Tensor list/tuple of the function ``func``. Note that
+            the first dimension of the inputs is the batch size.
+        create_graph (bool, optional): whether to create the gradient graphs
+            of the computing process. When it is True, higher order derivatives
+            can be computed; when it is False, the gradient graphs of the
+            computing process are discarded. Defaults to ``False``.
+        allow_unused (bool, optional): whether to raise an error or return None if
+            some Tensors of `inputs` are unreachable in the graph. An error is
+            raised if allow_unused=False, and None is returned as
+            their gradients if allow_unused=True. Defaults to ``False``.
+    Returns:
+        Hessian (Tensor or a tuple of tuples of Tensors): if the function ``func``
+        takes a Tensor as ``inputs``, the Hessian will be a single Tensor containing
+        the Hessian matrix for the linearized ``inputs`` Tensor. If the function
+        ``func`` takes a Tensor list/tuple as ``inputs``, then the Hessian will
+        be a tuple of tuples of Tensors.
+        Note that the first dimension of the inputs is the batch size, and that
+        the Hessian is obtained by first computing the first order derivatives
+        and then differentiating them with respect to the batched inputs.
+
+        For example, if the input shape and output shape of the function ``func``
+        are [batch_size, num] and [batch_size, 1] respectively, then the batched
+        Hessian will be a Tensor with a shape of [num, batch_size * num].
+
+        Why is this the final shape? Internally, batch_hessian creates an inner
+        function (a wrapper around paddle.grad()) that computes the sum of the
+        gradients of `outputs` with respect to each of the `inputs`. This inner
+        function yields the first order derivatives, with shape [batch_size, num].
+        batch_jacobian is then called to compute the Jacobian between these first
+        order derivatives and the original inputs. The final result ``Hessian[i][j]``
+        will contain the Jacobian matrix of the ``i``th column of the first order
+        derivatives and the ``j``th input, and will have the same dtype and device
+        as the corresponding input. Other situations can be deduced by analogy.
+
+    Example 1:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.ones(shape=(4, 2), dtype='float64')
+            weight = paddle.ones(shape=(2, 4), dtype='float64')
+            y = paddle.ones(shape=(4, 2), dtype='float64')
+
+            def func(x):
+                return paddle.matmul(x * x, weight)[:, 0:1]
+
+            x.stop_gradient = False
+            batch_hessian = paddle.autograd.batch_hessian(func, x)
+            print(batch_hessian)
+            # Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
+            #        [[2., 0., 2., 0., 2., 0., 2., 0.],
+            #         [0., 2., 0., 2., 0., 2., 0., 2.]])
+
+    Example 2:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.ones(shape=(4, 2), dtype='float64')
+            weight = paddle.ones(shape=(2, 4), dtype='float64')
+            y = paddle.ones(shape=(4, 2), dtype='float64')
+
+            def func(x, y):
+                return paddle.matmul(x * x * y * y, weight)[:, 0:1]
+
+            x.stop_gradient = False
+            y.stop_gradient = False
+            batch_hessian = paddle.autograd.batch_hessian(func, [x, y])
+            print(batch_hessian)
+            # ((Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
+            #         [[2., 0., 2., 0., 2., 0., 2., 0.],
+            #          [0., 2., 0., 2., 0., 2., 0., 2.]]),
+            #   Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
+            #         [[4., 0., 4., 0., 4., 0., 4., 0.],
+            #          [0., 4., 0., 4., 0., 4., 0., 4.]])),
+            #  (Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
+            #         [[4., 0., 4., 0., 4., 0., 4., 0.],
+            #          [0., 4., 0., 4., 0., 4., 0., 4.]]),
+            #   Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
+            #         [[2., 0., 2., 0., 2., 0., 2., 0.],
+            #          [0., 2., 0., 2., 0., 2., 0., 2.]])))
+    Example 3:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.ones(shape=(4, 2), dtype='float64')
+            weight = paddle.ones(shape=(2, 4), dtype='float64')
+            y = paddle.ones(shape=(4, 2), dtype='float64')
+
+            def func(x, y):
+                return paddle.matmul(x * x, weight)[:, 0:1]
+
+            x.stop_gradient = False
+            y.stop_gradient = False
+            batch_hessian = paddle.autograd.batch_hessian(func, [x, y], allow_unused=True)
+            print(batch_hessian)
+            # ((Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
+            #         [[2., 0., 2., 0., 2., 0., 2., 0.],
+            #          [0., 2., 0., 2., 0., 2., 0., 2.]]), None), (None, None))
+
+    '''
+    inputs = _tensors(inputs, "inputs")
+    outputs = func(*inputs)
+    batch_size = inputs[0].shape[0]
+    for input in inputs:
+        assert input.shape[
+            0] == batch_size, "The first dimension of each input must equal the batch size!"
+    assert isinstance(outputs, paddle.Tensor) and outputs.shape == [
+        batch_size, 1
+    ], "The function to compute the batched Hessian matrix should return a Tensor of shape [batch_size, 1]"
+
+    def jac_func(*ins):
+        grad_inputs = grad(
+            outputs,
+            ins,
+            create_graph=True,
+            retain_graph=True,
+            allow_unused=allow_unused)
+        return tuple(
+            _replace_none_with_zero_tensor(grad_inputs[i], inputs[i])
+            for i in range(len(inputs)))
+
+    return batch_jacobian(
+        jac_func, inputs, create_graph=create_graph, allow_unused=allow_unused)
+
+
 @framework.dygraph_only
 def hessian(func, inputs, create_graph=False, allow_unused=False):
     '''
diff --git a/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt b/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt
index 30d87e2c9b2..6d9625483ea 100644
--- a/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt
@@ -6,6 +6,6 @@ foreach(TEST_OP ${TEST_OPS})
   py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
 endforeach(TEST_OP)
 
-set_tests_properties(test_jacobian PROPERTIES TIMEOUT 20)
+set_tests_properties(test_jacobian PROPERTIES TIMEOUT 50)
 set_tests_properties(test_hessian PROPERTIES TIMEOUT 50)
 set_tests_properties(test_vhp PROPERTIES TIMEOUT 50)
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_hessian.py b/python/paddle/fluid/tests/unittests/autograd/test_hessian.py
index 1aa0d94de16..7b3bd9fd559 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_hessian.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_hessian.py
@@ -17,7 +17,7 @@ import numpy as np
 import paddle
 import paddle.compat as cpt
 import paddle.nn.functional as F
-from utils import _compute_numerical_hessian
+from utils import _compute_numerical_hessian, _compute_numerical_batch_hessian
 
 
 class TestHessian(unittest.TestCase):
@@ -136,5 +136,128 @@ class TestHessianFloat64(TestHessian):
         self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
 
 
+class TestBatchHessian(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        self.x_shape = (5, 2)
+        self.weight_shape = (2, 4)
+        self.y_shape = (5, 2)
+        self.dtype = 'float32'
+        self.np_dtype = np.float32
+        self.numerical_delta = 1e-2
+        self.rtol = 1e-3
+        self.atol = 1e-3
+        self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype)
+        self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype)
+        self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype)
+
+    def test_single_input(self):
+        def func(x):
+            return paddle.matmul(x * x, self.weight)[:, 0:1]
+
+        numerical_hessian = _compute_numerical_batch_hessian(
+            func, self.x, self.numerical_delta, self.np_dtype)
+        self.x.stop_gradient = False
+        hessian 
= paddle.autograd.batch_hessian(func, self.x, create_graph=True) + assert np.allclose(hessian, numerical_hessian, self.rtol, self.atol) + + def test_multi_input(self): + def func(x, y): + return paddle.matmul(x * x * y * y, self.weight)[:, 0:1] + + numerical_hessian = _compute_numerical_batch_hessian( + func, [self.x, self.y], self.numerical_delta, self.np_dtype) + + self.x.stop_gradient = False + self.y.stop_gradient = False + hessian = paddle.autograd.batch_hessian(func, [self.x, self.y]) + + shape_tensor = paddle.to_tensor(numerical_hessian).astype("float64") + hessian_reshape = np.reshape(hessian, (shape_tensor.shape)) + assert np.allclose(hessian_reshape, numerical_hessian, self.rtol, + self.atol) + + def test_allow_unused_false(self): + def func(x, y): + return paddle.matmul(x * x, self.weight)[:, 0:1] + + try: + self.x.stop_gradient = False + self.y.stop_gradient = False + hessian = paddle.autograd.batch_hessian(func, [self.x, self.y]) + except ValueError as e: + error_msg = cpt.get_exception_message(e) + assert error_msg.find("allow_unused") > 0 + + def test_allow_unused_true(self): + def func(x, y): + return paddle.matmul(x * x, self.weight)[:, 0:1] + + numerical_hessian = _compute_numerical_batch_hessian( + func, [self.x, self.y], self.numerical_delta, self.np_dtype) + self.x.stop_gradient = False + self.y.stop_gradient = False + hessian = paddle.autograd.batch_hessian( + func, [self.x, self.y], allow_unused=True) + + for i in range(len(hessian)): + for j in range(len(hessian[0])): + if i == j == 0: + numerical_hessian = np.stack( + (numerical_hessian[i][j], numerical_hessian[i][j + 1]), + axis=0) + assert np.allclose(hessian[i][j], numerical_hessian, + self.rtol, self.atol) + else: + assert hessian[i][j] is None + + def test_create_graph_false(self): + def func(x): + return paddle.matmul(x * x, self.weight)[:, 0:1] + + numerical_hessian = _compute_numerical_batch_hessian( + func, self.x, self.numerical_delta, self.np_dtype) + self.x.stop_gradient = False + hessian = paddle.autograd.batch_hessian(func, self.x) + assert hessian.stop_gradient == True + assert np.allclose(hessian.numpy(), numerical_hessian, self.rtol, + self.atol) + try: + paddle.grad(hessian, self.x) + except RuntimeError as e: + error_msg = cpt.get_exception_message(e) + assert error_msg.find("has no gradient") > 0 + + def test_create_graph_true(self): + def func(x): + return paddle.matmul(x * x, self.weight)[:, 0:1] + + numerical_hessian = _compute_numerical_batch_hessian( + func, self.x, self.numerical_delta, self.np_dtype) + self.x.stop_gradient = False + hessian = paddle.autograd.batch_hessian(func, self.x, create_graph=True) + assert hessian.stop_gradient == False + assert np.allclose(hessian.numpy(), numerical_hessian, self.rtol, + self.atol) + triple_grad = paddle.grad(hessian, self.x) + assert triple_grad is not None + + +class TestBatchHessianFloat64(TestBatchHessian): + @classmethod + def setUpClass(self): + self.x_shape = (5, 2) + self.weight_shape = (2, 4) + self.y_shape = (5, 2) + self.dtype = 'float64' + self.np_dtype = np.float64 + self.numerical_delta = 1e-4 + self.rtol = 1e-5 + self.atol = 1e-5 + self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype) + self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype) + self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/test_jacobian.py b/python/paddle/fluid/tests/unittests/autograd/test_jacobian.py index 2f0b8c7cad3..335ea4e519b 
100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_jacobian.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_jacobian.py @@ -16,7 +16,7 @@ import unittest import numpy as np import paddle import paddle.compat as cpt -from utils import _compute_numerical_jacobian +from utils import _compute_numerical_jacobian, _compute_numerical_batch_jacobian class TestJacobian(unittest.TestCase): @@ -158,5 +158,162 @@ class TestJacobianFloat64(TestJacobian): self.y = paddle.rand(shape=self.shape, dtype=self.dtype) +class TestJacobianBatch(unittest.TestCase): + @classmethod + def setUpClass(self): + self.x_shape = (4, 2) + self.weight_shape = (2, 4) + self.y_shape = (4, 2) + self.dtype = 'float32' + self.np_dtype = np.float32 + self.numerical_delta = 1e-4 + self.rtol = 1e-3 + self.atol = 1e-3 + self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype) + self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype) + self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype) + + def test_batch_single_input_and_batch_single_output(self): + def func(x): + return paddle.matmul(paddle.matmul(x, self.weight), self.y) + + numerical_jacobian = _compute_numerical_batch_jacobian( + func, [self.x], self.numerical_delta, self.np_dtype) + + self.x.stop_gradient = False + batch_jacobian = paddle.autograd.batch_jacobian( + func, + self.x, ) + + self.assertTrue( + np.allclose(batch_jacobian.numpy().all(), numerical_jacobian[0][0] + .all())) + + def test_batch_single_input_and_batch_multi_output(self): + def func(x): + return paddle.matmul(paddle.matmul(x, self.weight), self.y), x * x + + numerical_jacobian = _compute_numerical_batch_jacobian( + func, [self.x], self.numerical_delta, self.np_dtype) + + self.x.stop_gradient = False + batch_jacobian = paddle.autograd.batch_jacobian( + func, + self.x, ) + + for i in range(len(batch_jacobian)): + assert np.allclose(batch_jacobian[i].numpy(), + numerical_jacobian[i][0], self.rtol, self.atol) + + def test_batch_multi_input_and_batch_single_output(self): + def func(x, y): + return x * y + + numerical_jacobian = _compute_numerical_batch_jacobian( + func, [self.x, self.y], self.numerical_delta, self.np_dtype) + + self.x.stop_gradient = False + self.y.stop_gradient = False + batch_jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y]) + + for j in range(len(batch_jacobian)): + assert np.allclose(batch_jacobian[j].numpy(), + numerical_jacobian[0][j], self.rtol, self.atol) + + def test_batch_multi_input_and_batch_multi_output(self): + def func(x, y): + return x * y, x * y + + numerical_jacobian = _compute_numerical_batch_jacobian( + func, [self.x, self.y], self.numerical_delta, self.np_dtype) + + self.x.stop_gradient = False + self.y.stop_gradient = False + batch_jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y]) + + for i in range(len(batch_jacobian)): + assert np.allclose(batch_jacobian[i], numerical_jacobian[i], + self.rtol, self.atol) + + def test_allow_unused_false(self): + def func(x, y): + return x * x + + try: + self.x.stop_gradient = False + self.y.stop_gradient = False + jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y]) + except ValueError as e: + error_msg = cpt.get_exception_message(e) + assert error_msg.find("allow_unused") > 0 + + def test_allow_unused_true(self): + def func(x, y): + return x * x + + numerical_jacobian = _compute_numerical_batch_jacobian( + func, [self.x, self.y], self.numerical_delta, self.np_dtype) + self.x.stop_gradient = False + self.y.stop_gradient = False + jacobian = 
paddle.autograd.batch_jacobian( + func, [self.x, self.y], allow_unused=True) + + assert np.allclose(jacobian[0].numpy(), numerical_jacobian[0][0], + self.rtol, self.atol) + assert jacobian[1] is None + + def test_create_graph_false(self): + def func(x, y): + return x * y + + numerical_jacobian = _compute_numerical_batch_jacobian( + func, [self.x, self.y], self.numerical_delta, self.np_dtype) + self.x.stop_gradient = False + self.y.stop_gradient = False + jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y]) + for j in range(len(jacobian)): + assert jacobian[j].stop_gradient == True + assert np.allclose(jacobian[j].numpy(), numerical_jacobian[0][j], + self.rtol, self.atol) + try: + paddle.grad(jacobian[0], [self.x, self.y]) + except RuntimeError as e: + error_msg = cpt.get_exception_message(e) + assert error_msg.find("has no gradient") > 0 + + def test_create_graph_true(self): + def func(x, y): + return x * y + + numerical_jacobian = _compute_numerical_batch_jacobian( + func, [self.x, self.y], self.numerical_delta, self.np_dtype) + self.x.stop_gradient = False + self.y.stop_gradient = False + jacobian = paddle.autograd.batch_jacobian( + func, [self.x, self.y], create_graph=True) + for j in range(len(jacobian)): + assert jacobian[j].stop_gradient == False + assert np.allclose(jacobian[j].numpy(), numerical_jacobian[0][j], + self.rtol, self.atol) + double_grad = paddle.grad(jacobian[0], [self.x, self.y]) + assert double_grad is not None + + +class TestJacobianBatchFloat64(TestJacobianBatch): + @classmethod + def setUpClass(self): + self.x_shape = (12, 2) + self.weight_shape = (2, 12) + self.y_shape = (12, 2) + self.dtype = 'float64' + self.np_dtype = np.float64 + self.numerical_delta = 1e-7 + self.rtol = 1e-7 + self.atol = 1e-7 + self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype) + self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype) + self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/utils.py b/python/paddle/fluid/tests/unittests/autograd/utils.py index 402e89ae476..b06ce6ed7cc 100644 --- a/python/paddle/fluid/tests/unittests/autograd/utils.py +++ b/python/paddle/fluid/tests/unittests/autograd/utils.py @@ -107,6 +107,73 @@ def _compute_numerical_hessian(func, xs, delta, np_dtype): return hessian +def _compute_numerical_batch_jacobian(func, xs, delta, np_dtype): + no_batch_jacobian = _compute_numerical_jacobian(func, xs, delta, np_dtype) + xs = _tensors(xs, "xs") + ys = _tensors(func(*xs), "ys") + fin_size = len(xs) + fout_size = len(ys) + bs = xs[0].shape[0] + bat_jac = [] + for i in range(fout_size): + batch_jac_i = [] + for j in range(fin_size): + jac = no_batch_jacobian[i][j] + jac_shape = jac.shape + out_size = jac_shape[0] // bs + in_size = jac_shape[1] // bs + jac = np.reshape(jac, (bs, out_size, bs, in_size)) + batch_jac_i_j = np.zeros(shape=(out_size, bs, in_size)) + for p in range(out_size): + for b in range(bs): + for q in range(in_size): + batch_jac_i_j[p][b][q] = jac[b][p][b][q] + batch_jac_i_j = np.reshape(batch_jac_i_j, (out_size, -1)) + batch_jac_i.append(batch_jac_i_j) + bat_jac.append(batch_jac_i) + + return bat_jac + + +def _compute_numerical_batch_hessian(func, xs, delta, np_dtype): + xs = _tensors(xs, "xs") + batch_size = xs[0].shape[0] + fin_size = len(xs) + hessian = [] + for b in range(batch_size): + x_l = [] + for j in range(fin_size): + x_l.append(paddle.reshape(xs[j][b], shape=[1, -1])) + hes_b = 
_compute_numerical_hessian(func, x_l, delta, np_dtype) + if fin_size == 1: + hessian.append(hes_b[0][0]) + else: + hessian.append(hes_b) + + hessian_res = [] + for index in range(fin_size): + x_reshape = paddle.reshape(xs[index], shape=[batch_size, -1]) + for index_ in range(fin_size): + for i in range(x_reshape.shape[1]): + tmp = [] + for j in range(batch_size): + if fin_size == 1: + tmp.extend(hessian[j][i]) + else: + tmp.extend(hessian[j][i][index_][index]) + hessian_res.append(tmp) + if fin_size == 1: + return hessian_res + + hessian_result = [] + mid = len(hessian_res) // 2 + for i in range(mid): + hessian_result.append( + np.stack( + (hessian_res[i], hessian_res[mid + i]), axis=0)) + return hessian_result + + def _compute_numerical_vjp(func, xs, v, delta, np_dtype): xs = _tensors(xs, "xs") jacobian = np.array(_compute_numerical_jacobian(func, xs, delta, np_dtype)) -- GitLab
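
A minimal usage sketch combining the two new APIs added by this patch in dygraph
mode (the function, tensor names, and shapes below are illustrative assumptions,
not taken from the PR):

    import paddle

    # Batched input: batch_size=4 samples, each with num=3 features.
    x = paddle.rand(shape=(4, 3), dtype='float64')
    w = paddle.rand(shape=(3, 1), dtype='float64')
    x.stop_gradient = False

    def per_sample_scalar(x):
        # Returns shape [batch_size, 1], as batch_hessian requires.
        return paddle.matmul(paddle.tanh(x), w)

    # Batch Jacobian of the per-sample scalar w.r.t. the batched input.
    # Per the batch_jacobian docstring, the result has shape [1, batch_size * num].
    jac = paddle.autograd.batch_jacobian(per_sample_scalar, x)

    # Batch Hessian of the per-sample scalar: shape [num, batch_size * num].
    hess = paddle.autograd.batch_hessian(per_sample_scalar, x)

    print(jac.shape)   # expected: [1, 12]
    print(hess.shape)  # expected: [3, 12]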