diff --git a/python/paddle/autograd/__init__.py b/python/paddle/autograd/__init__.py
index dfbb3cfb45f2be3cb468932d4b3b9f22ed3ad81b..f4a0122759dc5ddd3b98f5a7c6404d040637f837 100644
--- a/python/paddle/autograd/__init__.py
+++ b/python/paddle/autograd/__init__.py
@@ -18,6 +18,6 @@ from .backward_mode import backward  # noqa: F401
 from .py_layer import PyLayer, PyLayerContext  # noqa: F401
 from ..framework import set_grad_enabled  # noqa: F401
 from ..fluid.dygraph.base import no_grad_ as no_grad  # noqa: F401
-from .functional import jacobian  # noqa: F401
+from .functional import jacobian, hessian  # noqa: F401
 
 __all__ = ['backward', 'PyLayer', 'PyLayerContext']
diff --git a/python/paddle/autograd/functional.py b/python/paddle/autograd/functional.py
index c1b4dd9e3a2db86290fc3b2ecda5360b209ebec1..a5665631c937f80a3d4469c796d9c64a5aa754d5 100644
--- a/python/paddle/autograd/functional.py
+++ b/python/paddle/autograd/functional.py
@@ -13,34 +13,10 @@
 # limitations under the License.
 
 from paddle.fluid import framework
+from .utils import _check_tensors, _stack_tensor_or_return_none, _replace_none_with_zero_tensor
 import paddle
 
 
-def _check_tensors(in_out_list, name):
-    assert in_out_list is not None, "{} should not be None".format(name)
-
-    if isinstance(in_out_list, (list, tuple)):
-        assert len(in_out_list) > 0, "{} connot be empyt".format(name)
-        for each_var in in_out_list:
-            assert isinstance(
-                each_var,
-                paddle.Tensor), "Elements of {} must be paddle.Tensor".format(
-                    name)
-        return in_out_list
-    else:
-        assert isinstance(
-            in_out_list,
-            paddle.Tensor), "{} must be Tensor or list of Tensor".format(name)
-        return [in_out_list]
-
-
-def _stack_tensor_or_return_none(origin_list):
-    assert len(origin_list) > 0, "Can't not stack an empty list"
-    return paddle.stack(
-        origin_list, axis=0) if isinstance(origin_list[0],
-                                           paddle.Tensor) else None
-
-
 @framework.dygraph_only
 def jacobian(func, inputs, create_graph=False, allow_unused=False):
     '''
@@ -183,3 +159,129 @@ def jacobian(func, inputs, create_graph=False, allow_unused=False):
         return jacobian[0]
     else:
         return jacobian
+
+
+@framework.dygraph_only
+def hessian(func, inputs, create_graph=False, allow_unused=False):
+    '''
+    .. note::
+        **This API is ONLY available in imperative mode.**
+
+    This API computes the Hessian matrix of `func` with respect to `inputs`.
+
+    Parameters:
+        func (function): a Python function that takes a Tensor or a Tensor
+            list/tuple as inputs and returns a Tensor with a single element.
+        inputs (Tensor|list(Tensor)|tuple(Tensor)): the input Tensor or
+            Tensor list/tuple of the function ``func``.
+        create_graph (bool, optional): whether to create the gradient graphs
+            of the computing process. When it is True, higher-order derivatives
+            can be computed; when it is False, the gradient graphs of the
+            computing process are discarded. Defaults to ``False``.
+        allow_unused (bool, optional): whether to raise an error or return
+            None if some Tensors of ``inputs`` are unreachable in the graph.
+            An error is raised if allow_unused=False, and None is returned as
+            the gradient of the unreachable Tensors if allow_unused=True.
+            Defaults to ``False``.
+    Returns:
+        Hessian (Tensor or a tuple of tuple of Tensors): if function ``func``
+        takes a Tensor as ``inputs``, Hessian will be a single Tensor containing
+        the Hessian matrix for the linearized ``inputs`` Tensor.
+        If function ``func`` takes a Tensor list/tuple as ``inputs``, then the
+        Hessian will be a tuple of tuple of Tensors where ``Hessian[i][j]``
+        will contain the Hessian matrix of the ``i`` th input and ``j`` th
+        input with size ``m * n``. Here ``m`` and ``n`` denote the number of
+        elements of the ``i`` th input and the ``j`` th input respectively.
+
+    Examples 1:
+        .. code-block:: python
+
+            import paddle
+
+            def func(x):
+                return paddle.sum(paddle.matmul(x, x))
+
+            x = paddle.ones(shape=[2, 2], dtype='float32')
+            x.stop_gradient = False
+            hessian = paddle.autograd.hessian(func, x)
+            print(hessian)
+            # Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            #        [[2., 1., 1., 0.],
+            #         [1., 0., 2., 1.],
+            #         [1., 2., 0., 1.],
+            #         [0., 1., 1., 2.]])
+
+    Examples 2:
+        .. code-block:: python
+
+            import paddle
+
+            def func(x, y):
+                return paddle.sum(paddle.matmul(x, y))
+
+            x = paddle.ones(shape=[2, 2], dtype='float32')
+            y = paddle.ones(shape=[2, 2], dtype='float32')
+            x.stop_gradient = False
+            y.stop_gradient = False
+            hessian = paddle.autograd.hessian(func, [x, y])
+            print(hessian)
+            # ((Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            #          [[0., 0., 0., 0.],
+            #           [0., 0., 0., 0.],
+            #           [0., 0., 0., 0.],
+            #           [0., 0., 0., 0.]]),
+            #   Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            #          [[1., 1., 0., 0.],
+            #           [0., 0., 1., 1.],
+            #           [1., 1., 0., 0.],
+            #           [0., 0., 1., 1.]])),
+            #  (Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            #          [[1., 0., 1., 0.],
+            #           [1., 0., 1., 0.],
+            #           [0., 1., 0., 1.],
+            #           [0., 1., 0., 1.]]),
+            #   Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            #          [[0., 0., 0., 0.],
+            #           [0., 0., 0., 0.],
+            #           [0., 0., 0., 0.],
+            #           [0., 0., 0., 0.]])))
+
+    Examples 3:
+        .. code-block:: python
+
+            import paddle
+
+            def func(x, y):
+                return paddle.sum(paddle.matmul(x, x))
+
+            x = paddle.ones(shape=[2, 2], dtype='float32')
+            y = paddle.ones(shape=[2, 2], dtype='float32')
+            x.stop_gradient = False
+            y.stop_gradient = False
+            hessian = paddle.autograd.hessian(func, [x, y], allow_unused=True)
+            print(hessian)
+            # ((Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            #          [[2., 1., 1., 0.],
+            #           [1., 0., 2., 1.],
+            #           [1., 2., 0., 1.],
+            #           [0., 1., 1., 2.]]), None), (None, None))
+
+    '''
+    inputs = _check_tensors(inputs, "inputs")
+    outputs = func(*inputs)
+    assert isinstance(outputs, paddle.Tensor) and outputs.shape == [
+        1
+    ], "The function to compute the Hessian matrix should return a Tensor with a single element"
+
+    # The Hessian is the Jacobian of the gradient function: take first-order
+    # gradients with create_graph=True so jacobian() can differentiate them again.
+    def jac_func(*ins):
+        grad_inputs = paddle.grad(
+            outputs,
+            ins,
+            create_graph=True,
+            retain_graph=True,
+            allow_unused=allow_unused)
+        return tuple(
+            _replace_none_with_zero_tensor(grad_inputs[i], inputs[i])
+            for i in range(len(inputs)))
+
+    return jacobian(
+        jac_func, inputs, create_graph=create_graph, allow_unused=allow_unused)
diff --git a/python/paddle/autograd/utils.py b/python/paddle/autograd/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d437f7d82d36112b7ffb34f063217021001770aa
--- /dev/null
+++ b/python/paddle/autograd/utils.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+
+
+def _check_tensors(in_out_list, name):
+    assert in_out_list is not None, "{} should not be None".format(name)
+
+    if isinstance(in_out_list, (list, tuple)):
+        assert len(in_out_list) > 0, "{} cannot be empty".format(name)
+        for each_var in in_out_list:
+            assert isinstance(
+                each_var,
+                paddle.Tensor), "Elements of {} must be paddle.Tensor".format(
+                    name)
+        return list(in_out_list)
+    else:
+        assert isinstance(
+            in_out_list,
+            paddle.Tensor), "{} must be Tensor or list of Tensor".format(name)
+        return [in_out_list]
+
+
+def _stack_tensor_or_return_none(origin_list):
+    assert len(origin_list) > 0, "Cannot stack an empty list"
+    return paddle.stack(
+        origin_list, axis=0) if isinstance(origin_list[0],
+                                           paddle.Tensor) else None
+
+
+def _replace_none_with_zero_tensor(t, spec_t):
+    if t is None:
+        zero_t = paddle.zeros(shape=spec_t.shape, dtype=spec_t.dtype)
+        zero_t.stop_gradient = spec_t.stop_gradient
+        return zero_t
+    else:
+        return t
diff --git a/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt b/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt
index 7f7a232fcefa649596849c7f9d31d343bf100172..1e9d433ebce8e160d77ccfeb67a9a2f039aae5d6 100644
--- a/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt
@@ -7,3 +7,4 @@ foreach(TEST_OP ${TEST_OPS})
 endforeach(TEST_OP)
 
 set_tests_properties(test_jacobian PROPERTIES TIMEOUT 20)
+set_tests_properties(test_hessian PROPERTIES TIMEOUT 20)
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_hessian.py b/python/paddle/fluid/tests/unittests/autograd/test_hessian.py
new file mode 100644
index 0000000000000000000000000000000000000000..120a6c853e8d897a17bde479556a19f560860a99
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/autograd/test_hessian.py
@@ -0,0 +1,140 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import paddle
+import paddle.compat as cpt
+from utils import _compute_numerical_hessian
+
+
+class TestHessian(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        self.shape = (2, 2)
+        self.dtype = 'float32'
+        self.np_dtype = np.float32
+        self.numerical_delta = 1e-2
+        self.rtol = 1e-2
+        self.atol = 1e-2
+        self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
+        self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
+
+    def test_single_input(self):
+        def func(x):
+            return paddle.sum(paddle.matmul(x, x))
+
+        numerical_hessian = _compute_numerical_hessian(
+            func, self.x, self.numerical_delta, self.np_dtype)
+
+        self.x.stop_gradient = False
+        hessian = paddle.autograd.hessian(func, self.x)
+        assert np.allclose(hessian.numpy(), numerical_hessian[0][0], self.rtol,
+                           self.atol)
+
+    def test_multi_input(self):
+        def func(x, y):
+            return paddle.sum(paddle.matmul(x, y))
+
+        numerical_hessian = _compute_numerical_hessian(
+            func, [self.x, self.y], self.numerical_delta, self.np_dtype)
+
+        self.x.stop_gradient = False
+        self.y.stop_gradient = False
+        hessian = paddle.autograd.hessian(func, [self.x, self.y])
+        for i in range(len(hessian)):
+            for j in range(len(hessian[0])):
+                assert np.allclose(hessian[i][j].numpy(),
+                                   numerical_hessian[i][j], self.rtol,
+                                   self.atol)
+
+    def test_allow_unused_false(self):
+        def func(x, y):
+            return paddle.sum(paddle.matmul(x, x))
+
+        try:
+            self.x.stop_gradient = False
+            self.y.stop_gradient = False
+            hessian = paddle.autograd.hessian(func, [self.x, self.y])
+        except ValueError as e:
+            error_msg = cpt.get_exception_message(e)
+            assert error_msg.find("allow_unused") > 0
+
+    def test_allow_unused_true(self):
+        def func(x, y):
+            return paddle.sum(paddle.matmul(x, x))
+
+        numerical_hessian = _compute_numerical_hessian(
+            func, [self.x, self.y], self.numerical_delta, self.np_dtype)
+        self.x.stop_gradient = False
+        self.y.stop_gradient = False
+        hessian = paddle.autograd.hessian(
+            func, [self.x, self.y], allow_unused=True)
+        for i in range(len(hessian)):
+            for j in range(len(hessian[0])):
+                if i == j == 0:
+                    assert np.allclose(hessian[i][j].numpy(),
+                                       numerical_hessian[i][j], self.rtol,
+                                       self.atol)
+                else:
+                    assert hessian[i][j] is None
+
+    def test_create_graph_false(self):
+        def func(x):
+            return paddle.sum(paddle.matmul(x, x))
+
+        numerical_hessian = _compute_numerical_hessian(
+            func, self.x, self.numerical_delta, self.np_dtype)
+        self.x.stop_gradient = False
+        hessian = paddle.autograd.hessian(func, self.x)
+        assert hessian.stop_gradient == True
+        assert np.allclose(hessian.numpy(), numerical_hessian[0][0], self.rtol,
+                           self.atol)
+        try:
+            paddle.grad(hessian, self.x)
+        except RuntimeError as e:
+            error_msg = cpt.get_exception_message(e)
+            assert error_msg.find("has no gradient") > 0
+
+    # TODO(levi): enable this test case when matmul_grad_grad_grad is ok
+    def _test_create_graph_true(self):
+        def func(x):
+            return paddle.sum(paddle.matmul(x, x))
+
+        numerical_hessian = _compute_numerical_hessian(
+            func, self.x, self.numerical_delta, self.np_dtype)
+        self.x.stop_gradient = False
+        hessian = paddle.autograd.hessian(func, self.x, create_graph=True)
+        assert hessian.stop_gradient == False
+        assert np.allclose(hessian.numpy(), numerical_hessian[0][0], self.rtol,
+                           self.atol)
+        triple_grad = paddle.grad(hessian, self.x)
+        assert triple_grad is not None
+
+
+class TestHessianFloat64(TestHessian):
+    @classmethod
+    def setUpClass(self):
+        self.shape = (2, 2)
+        self.dtype = 'float64'
+        self.np_dtype = np.float64
+        self.numerical_delta = 1e-5
+        self.rtol = 1e-5
+        self.atol = 1e-5
+        self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
+        self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_jacobian.py b/python/paddle/fluid/tests/unittests/autograd/test_jacobian.py
index 2722d2c83b130e62ca15e2c09d1d072c4916b81a..2f0b8c7cad3e5e872543a9380ee6efb4c9719b92 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_jacobian.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_jacobian.py
@@ -16,65 +16,7 @@ import unittest
 import numpy as np
 import paddle
 import paddle.compat as cpt
-from paddle.autograd.functional import _check_tensors
-
-
-def _product(t):
-    if isinstance(t, int):
-        return t
-    else:
-        return np.product(t)
-
-
-def _get_item(t, idx):
-    assert isinstance(t, paddle.Tensor), "The first argument t must be Tensor."
-    assert isinstance(idx,
-                      int), "The second argument idx must be an int number."
-    flat_t = paddle.reshape(t, [-1])
-    return flat_t.__getitem__(idx)
-
-
-def _set_item(t, idx, value):
-    assert isinstance(t, paddle.Tensor), "The first argument t must be Tensor."
-    assert isinstance(idx,
-                      int), "The second argument idx must be an int number."
-    flat_t = paddle.reshape(t, [-1])
-    flat_t.__setitem__(idx, value)
-    return paddle.reshape(flat_t, t.shape)
-
-
-def _compute_numerical_jacobian(func, xs, delta, np_dtype):
-    xs = _check_tensors(xs, "xs")
-    ys = _check_tensors(func(*xs), "ys")
-    fin_size = len(xs)
-    fout_size = len(ys)
-    jacobian = list([] for _ in range(fout_size))
-    for i in range(fout_size):
-        jac_i = list([] for _ in range(fin_size))
-        for j in range(fin_size):
-            jac_i[j] = np.zeros(
-                (_product(ys[i].shape), _product(xs[j].shape)), dtype=np_dtype)
-        jacobian[i] = jac_i
-
-    for j in range(fin_size):
-        for q in range(_product(xs[j].shape)):
-            orig = _get_item(xs[j], q)
-            x_pos = orig + delta
-            xs[j] = _set_item(xs[j], q, x_pos)
-            ys_pos = _check_tensors(func(*xs), "ys_pos")
-
-            x_neg = orig - delta
-            xs[j] = _set_item(xs[j], q, x_neg)
-            ys_neg = _check_tensors(func(*xs), "ys_neg")
-
-            xs[j] = _set_item(xs[j], q, orig)
-
-            for i in range(fout_size):
-                for p in range(_product(ys[i].shape)):
-                    y_pos = _get_item(ys_pos[i], p)
-                    y_neg = _get_item(ys_neg[i], p)
-                    jacobian[i][j][p][q] = (y_pos - y_neg) / delta / 2.
-    return jacobian
+from utils import _compute_numerical_jacobian
 
 
 class TestJacobian(unittest.TestCase):
diff --git a/python/paddle/fluid/tests/unittests/autograd/utils.py b/python/paddle/fluid/tests/unittests/autograd/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0aadef4a809f3f224f87d270b443d8ef8b057cff
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/autograd/utils.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+from paddle.autograd.functional import _check_tensors
+
+
+def _product(t):
+    if isinstance(t, int):
+        return t
+    else:
+        return np.product(t)
+
+
+def _get_item(t, idx):
+    assert isinstance(t, paddle.Tensor), "The first argument t must be Tensor."
+    assert isinstance(idx,
+                      int), "The second argument idx must be an int number."
+    flat_t = paddle.reshape(t, [-1])
+    return flat_t.__getitem__(idx)
+
+
+def _set_item(t, idx, value):
+    assert isinstance(t, paddle.Tensor), "The first argument t must be Tensor."
+    assert isinstance(idx,
+                      int), "The second argument idx must be an int number."
+    flat_t = paddle.reshape(t, [-1])
+    flat_t.__setitem__(idx, value)
+    return paddle.reshape(flat_t, t.shape)
+
+
+def _compute_numerical_jacobian(func, xs, delta, np_dtype):
+    xs = _check_tensors(xs, "xs")
+    ys = _check_tensors(func(*xs), "ys")
+    fin_size = len(xs)
+    fout_size = len(ys)
+    jacobian = list([] for _ in range(fout_size))
+    for i in range(fout_size):
+        jac_i = list([] for _ in range(fin_size))
+        for j in range(fin_size):
+            jac_i[j] = np.zeros(
+                (_product(ys[i].shape), _product(xs[j].shape)), dtype=np_dtype)
+        jacobian[i] = jac_i
+
+    for j in range(fin_size):
+        for q in range(_product(xs[j].shape)):
+            orig = _get_item(xs[j], q)
+            x_pos = orig + delta
+            xs[j] = _set_item(xs[j], q, x_pos)
+            ys_pos = _check_tensors(func(*xs), "ys_pos")
+
+            x_neg = orig - delta
+            xs[j] = _set_item(xs[j], q, x_neg)
+            ys_neg = _check_tensors(func(*xs), "ys_neg")
+
+            xs[j] = _set_item(xs[j], q, orig)
+
+            for i in range(fout_size):
+                for p in range(_product(ys[i].shape)):
+                    y_pos = _get_item(ys_pos[i], p)
+                    y_neg = _get_item(ys_neg[i], p)
+                    # Central difference: df/dx ~ (f(x + delta) - f(x - delta)) / (2 * delta)
+                    jacobian[i][j][p][q] = (y_pos - y_neg) / delta / 2.
+    return jacobian
+
+
+def _compute_numerical_hessian(func, xs, delta, np_dtype):
+    xs = _check_tensors(xs, "xs")
+    ys = _check_tensors(func(*xs), "ys")
+    fin_size = len(xs)
+    hessian = list([] for _ in range(fin_size))
+    for i in range(fin_size):
+        hessian_i = list([] for _ in range(fin_size))
+        for j in range(fin_size):
+            hessian_i[j] = np.zeros(
+                (_product(xs[i].shape), _product(xs[j].shape)), dtype=np_dtype)
+        hessian[i] = hessian_i
+
+    for i in range(fin_size):
+        for p in range(_product(xs[i].shape)):
+            for j in range(fin_size):
+                for q in range(_product(xs[j].shape)):
+                    orig = _get_item(xs[j], q)
+                    x_pos = orig + delta
+                    xs[j] = _set_item(xs[j], q, x_pos)
+                    jacobian_pos = _compute_numerical_jacobian(func, xs, delta,
+                                                               np_dtype)
+                    x_neg = orig - delta
+                    xs[j] = _set_item(xs[j], q, x_neg)
+                    jacobian_neg = _compute_numerical_jacobian(func, xs, delta,
+                                                               np_dtype)
+                    xs[j] = _set_item(xs[j], q, orig)
+                    # Central difference of the first-order gradient gives the
+                    # second-order derivative d2f / (dx_i[p] dx_j[q]).
+                    hessian[i][j][p][q] = (
+                        jacobian_pos[0][i][0][p] - jacobian_neg[0][i][0][p]
+                    ) / delta / 2.
+    return hessian
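
A minimal usage sketch of the new API follows. It is not part of the patch: it assumes a Paddle build that already contains this change and runs in imperative (dygraph) mode, and it reuses `func` from the Examples 1 docstring above. Beyond the documented 4 x 4 shape, it checks the symmetry that the Hessian of a twice continuously differentiable scalar function must satisfy:

import numpy as np
import paddle

def func(x):
    # Returns a single-element Tensor, as paddle.autograd.hessian requires.
    return paddle.sum(paddle.matmul(x, x))

x = paddle.ones(shape=[2, 2], dtype='float32')
x.stop_gradient = False  # mark x as differentiable
hess = paddle.autograd.hessian(func, x)

# x has 4 elements, so the flattened Hessian is 4 x 4 ...
assert hess.shape == [4, 4]
# ... and symmetric, since mixed partial derivatives commute.
assert np.allclose(hess.numpy(), hess.numpy().T)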