From 4c0ad7727efd5cf9d1d1bac3364f0ae487359e5c Mon Sep 17 00:00:00 2001 From: levi131 <83750468+levi131@users.noreply.github.com> Date: Mon, 18 Oct 2021 16:10:52 +0800 Subject: [PATCH] Lml/vhp (#36146) * init functional jacobian api * finish test with dtype float32 * add float64 test case * polish code * use atol=1e-5 with dtype float64 * fix for ci * set timeout for test_jacobian * init hessian API * save status * polish API docstring * modify docstring * add utils.py * save status * fix dygraph double grad dtype error when calling for high differential senario * reinvoke ci * test_hessian.py is ok * polish hessian API * init vhp * Revert "init vhp" This reverts commit cbd4d3b66abe82b0ac10721b9eddeb7d82e0a1c8. * init vhp * finish vhp API logically * add test for partial_engine.cc * modify numerical_delta with dtype float32 * merge fix for dtype float64 * spell fix * save status * polish code * rm _stop_gradient_pre_process * save status * add example for vhp interface * add _compute_numerical_vjp and _compute_numerical_vhp * test is ok * vhp is ok * add testVHPFloat64 * modify for comments * modify format * modify format * save status * test_vhp is ok * finish code polish * small modify for v is None Co-authored-by: JiabinYang <360788950@qq.com> --- python/paddle/autograd/__init__.py | 2 +- python/paddle/autograd/functional.py | 112 ++++++++++- python/paddle/autograd/utils.py | 4 +- .../tests/unittests/autograd/CMakeLists.txt | 1 + .../tests/unittests/autograd/test_vhp.py | 182 ++++++++++++++++++ .../fluid/tests/unittests/autograd/utils.py | 26 +++ 6 files changed, 319 insertions(+), 8 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/autograd/test_vhp.py diff --git a/python/paddle/autograd/__init__.py b/python/paddle/autograd/__init__.py index cffc18e95e5..bbfb9f22fc1 100644 --- a/python/paddle/autograd/__init__.py +++ b/python/paddle/autograd/__init__.py @@ -18,6 +18,6 @@ from .backward_mode import backward # noqa: F401 from .py_layer import PyLayer, PyLayerContext # noqa: F401 from ..framework import set_grad_enabled # noqa: F401 from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401 -from .functional import vjp, jvp, jacobian, hessian # noqa: F401 +from .functional import vjp, jvp, jacobian, hessian, vhp # noqa: F401 __all__ = ['backward', 'PyLayer', 'PyLayerContext'] diff --git a/python/paddle/autograd/functional.py b/python/paddle/autograd/functional.py index 66ae1562edb..c6235877f5b 100644 --- a/python/paddle/autograd/functional.py +++ b/python/paddle/autograd/functional.py @@ -247,9 +247,9 @@ def jvp(func, inputs, v=None, create_graph=False, allow_unused=False): def jacobian(func, inputs, create_graph=False, allow_unused=False): ''' .. note:: - **This API is ONLY available in imperative mode.** + **This API is ONLY available in the imperative mode.** - This API computes the Jacobian matrix of `func` with respect to `inputs`. + This function computes the Jacobian matrix of `func` with respect to `inputs`. Parameters: func (function): a Python function that takes a Tensor or a Tensor @@ -389,9 +389,9 @@ def jacobian(func, inputs, create_graph=False, allow_unused=False): def hessian(func, inputs, create_graph=False, allow_unused=False): ''' .. note:: - **This API is ONLY available in imperative mode.** + **This API is ONLY available in the imperative mode.** - This API computes the Hessian matrix of `func` with respect to `inputs`. + This function computes the Hessian matrix of `func` with respect to `inputs`. Parameters: func (function): a Python function that takes a Tensor or a Tensor @@ -509,3 +509,107 @@ def hessian(func, inputs, create_graph=False, allow_unused=False): return jacobian( jac_func, inputs, create_graph=create_graph, allow_unused=allow_unused) + + +@framework.dygraph_only +def vhp(func, inputs, v=None, create_graph=False, allow_unused=False): + ''' + .. note:: + **This API is ONLY available in the imperative mode.** + + This function computes the product between a vector ``v`` and the + Hessian matrix of `func` with respect to `inputs`. + + Parameters: + func (function): a Python function that takes a Tensor or a Tensor + list/tuple as inputs and returns a Tensor with a single element. + inputs (Tensor|list(Tensor)|tuple(Tensor)): the input Tensor or + Tensor list/tuple of the function ``func``. + v (Tensor|list(Tensor)|tuple(Tensor)|None, optional): the vector used + to compute vector hessian product. ``v`` should have same shape + and dtype with ``inputs``. If ``v`` is None, it will be set as + Tensor|list(Tensor) with all elements 1. Defaults to "None". + create_graph (bool, optional): whether to create the gradient graphs + of the computing process. When it is True, higher order derivatives + are supported to compute; when it is False, the gradient graphs of + the computing process would be discarded. Defaults to ``False``. + allow_unused (bool, optional): whether to raise error or return None if + some Tensors of `inputs` are unreachable in the graph. Error would + be raised if allow_unused=False, and None would be returned as + their gradients if allow_unused=True. Default False. + Returns: + output (tuple): tuple with: + func_output (Tensor): output of ``func(inputs)`` + vhp (list(Tensor)): result of the vector hessian product + with the same shape and dtype as the inputs. + Examples 1: + .. code-block:: python + import paddle + def func(x): + return paddle.sum(paddle.matmul(x, x)) + + x = paddle.ones(shape=[2, 2], dtype='float32') + x.stop_gradient = False + vx = paddle.ones(shape=[2, 2], dtype='float32') * 2 + vhp_rslt = paddle.autograd.vhp(func, x, v=vx) + print(vhp_rslt) + # (Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False, + # [8.]), + # Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + # [[8., 8.], + # [8., 8.]])) + + Examples 2: + .. code-block:: python + import paddle + def func(x): + return paddle.sum(paddle.matmul(x, x)) + + x = paddle.ones(shape=[2, 2], dtype='float32') + x.stop_gradient = False + vhp_rslt = paddle.autograd.vhp(func, x) + print(vhp_rslt) + # (Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False, + # [8.]), + # Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + # [[4., 4.], + # [4., 4.]])) + + Examples 3: + .. code-block:: python + import paddle + def func(x, y): + return paddle.sum(paddle.matmul(x, x)) + + x = paddle.ones(shape=[2, 2], dtype='float32') + x.stop_gradient = False + y = paddle.ones(shape=[2, 2], dtype='float32') + y.stop_gradient = False + vx = paddle.ones(shape=[2, 2], dtype='float32') * 2 + vy = paddle.ones(shape=[2, 2], dtype='float32') * 3 + vhp_rslt = paddle.autograd.vhp(func, [x, y], v=[vx, vy], allow_unused=True) + print(vhp_rslt) + # (Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False, + # [8.]), + # [Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + # [[8., 8.], + # [8., 8.]]), None]) + ''' + xs = _tensors(inputs, "inputs") + if v is not None: + v = _tensors(v, "v") + + with gradient_scope( + xs, v, create_graph=create_graph, + allow_unused=allow_unused) as [xs, v, grad_fn, return_fn]: + outputs = func(*xs) + ys = _tensors(outputs, "outputs") + assert len(ys) == 1 and isinstance( + ys[0], paddle.Tensor + ) and ys[0].shape == [ + 1 + ], "The function to compute vhp should return a Tensor with a single element" + jac = grad_fn(ys, xs, create_graph=True) + vhp = grad_fn(jac, xs, v) + outputs, vhp = return_fn(outputs), return_fn(vhp) + return outputs, vhp diff --git a/python/paddle/autograd/utils.py b/python/paddle/autograd/utils.py index 81fe19c1688..710c9ee18df 100644 --- a/python/paddle/autograd/utils.py +++ b/python/paddle/autograd/utils.py @@ -25,9 +25,7 @@ def _tensors(ts, name): name) return list(ts) else: - assert isinstance( - ts, paddle.Tensor - ) or ts is None, "{} must be Tensor or list of Tensor".format(name) + assert isinstance(ts, paddle.Tensor), "{} must be Tensor".format(name) return [ts] diff --git a/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt b/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt index 369134c8989..30d87e2c9b2 100644 --- a/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt @@ -8,3 +8,4 @@ endforeach(TEST_OP) set_tests_properties(test_jacobian PROPERTIES TIMEOUT 20) set_tests_properties(test_hessian PROPERTIES TIMEOUT 50) +set_tests_properties(test_vhp PROPERTIES TIMEOUT 50) diff --git a/python/paddle/fluid/tests/unittests/autograd/test_vhp.py b/python/paddle/fluid/tests/unittests/autograd/test_vhp.py new file mode 100644 index 00000000000..09b25203e04 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/autograd/test_vhp.py @@ -0,0 +1,182 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle +import paddle.compat as cpt +import paddle.nn.functional as F +from utils import _compute_numerical_vhp + + +class TestVHP(unittest.TestCase): + @classmethod + def setUpClass(self): + self.shape = (2, 2) + self.dtype = 'float32' + self.np_dtype = np.float32 + self.numerical_delta = 1e-2 + self.rtol = 1e-2 + self.atol = 1e-2 + self.x = paddle.rand(shape=self.shape, dtype=self.dtype) + self.y = paddle.rand(shape=self.shape, dtype=self.dtype) + self.vx = paddle.rand(shape=self.shape, dtype=self.dtype) + self.vy = paddle.rand(shape=self.shape, dtype=self.dtype) + + def test_single_input(self): + def func(x): + return paddle.sum(paddle.matmul(x, x)) + + numerical_func_output = func(self.x).numpy() + numerical_vhp = _compute_numerical_vhp( + func, self.x, self.vx, self.numerical_delta, self.np_dtype) + + self.x.stop_gradient = False + func_output, vhp = paddle.autograd.vhp(func, self.x, self.vx) + assert np.allclose(func_output.numpy(), numerical_func_output, + self.rtol, self.atol) + assert np.allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol, + self.atol) + + def test_multi_input(self): + def func(x, y): + return paddle.sum(paddle.matmul(x, y)) + + numerical_func_output = func(self.x, self.y).numpy() + numerical_vhp = _compute_numerical_vhp( + func, [self.x, self.y], [self.vx, self.vy], self.numerical_delta, + self.np_dtype) + + self.x.stop_gradient = False + self.y.stop_gradient = False + func_output, vhp = paddle.autograd.vhp(func, [self.x, self.y], + [self.vx, self.vy]) + assert np.allclose(func_output.numpy(), numerical_func_output, + self.rtol, self.atol) + for i in range(len(vhp)): + assert np.allclose(vhp[i].numpy(), numerical_vhp[i], self.rtol, + self.atol) + + def test_v_default(self): + def func(x, y): + return paddle.sum(paddle.matmul(x, y)) + + numerical_func_output = func(self.x, self.y).numpy() + vx = paddle.ones(self.vx.shape, dtype=self.vx.dtype) + vy = paddle.ones(self.vy.shape, dtype=self.vy.dtype) + numerical_vhp = _compute_numerical_vhp(func, [self.x, self.y], + [vx, vy], self.numerical_delta, + self.np_dtype) + + self.x.stop_gradient = False + self.y.stop_gradient = False + func_output, vhp = paddle.autograd.vhp(func, [self.x, self.y]) + assert np.allclose(func_output.numpy(), numerical_func_output, + self.rtol, self.atol) + for i in range(len(vhp)): + assert np.allclose(vhp[i].numpy(), numerical_vhp[i], self.rtol, + self.atol) + + def test_allow_unused_false(self): + def func(x, y): + return paddle.sum(paddle.matmul(x, x)) + + try: + self.x.stop_gradient = False + self.y.stop_gradient = False + _ = paddle.autograd.vhp(func, [self.x, self.y]) + except ValueError as e: + error_msg = cpt.get_exception_message(e) + assert error_msg.find("allow_unused") > 0 + + def test_allow_unused_true(self): + def func(x, y): + return paddle.sum(paddle.matmul(x, x)) + + numerical_func_output = func(self.x, self.y).numpy() + numerical_vhp = _compute_numerical_vhp( + func, [self.x, self.y], [self.vx, self.vy], self.numerical_delta, + self.np_dtype) + + self.x.stop_gradient = False + self.y.stop_gradient = False + func_output, vhp = paddle.autograd.vhp(func, [self.x, self.y], + [self.vx, self.vy], + allow_unused=True) + assert np.allclose(func_output.numpy(), numerical_func_output, + self.rtol, self.atol) + assert np.allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol, + self.atol) + assert vhp[1] is None + + def test_create_graph_false(self): + def func(x): + return paddle.sum(F.sigmoid(x)) + + numerical_func_output = func(self.x).numpy() + numerical_vhp = _compute_numerical_vhp( + func, self.x, self.vx, self.numerical_delta, self.np_dtype) + + self.x.stop_gradient = False + func_output, vhp = paddle.autograd.vhp(func, self.x, self.vx) + assert np.allclose(func_output.numpy(), numerical_func_output, + self.rtol, self.atol) + assert vhp[0].stop_gradient == True + assert np.allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol, + self.atol) + try: + paddle.grad(vhp, self.x) + except RuntimeError as e: + error_msg = cpt.get_exception_message(e) + assert error_msg.find("has no gradient") > 0 + + def test_create_graph_true(self): + def func(x): + return paddle.sum(F.sigmoid(x)) + + numerical_func_output = func(self.x).numpy() + numerical_vhp = _compute_numerical_vhp( + func, self.x, self.vx, self.numerical_delta, self.np_dtype) + + self.x.stop_gradient = False + func_output, vhp = paddle.autograd.vhp(func, + self.x, + self.vx, + create_graph=True) + assert np.allclose(func_output.numpy(), numerical_func_output, + self.rtol, self.atol) + assert vhp[0].stop_gradient == False + assert np.allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol, + self.atol) + triple_grad = paddle.grad(vhp, self.x) + assert triple_grad is not None + + +class TestVHPFloat64(TestVHP): + @classmethod + def setUpClass(self): + self.shape = (2, 2) + self.dtype = 'float64' + self.np_dtype = np.float64 + self.numerical_delta = 1e-5 + self.rtol = 1e-5 + self.atol = 1e-5 + self.x = paddle.rand(shape=self.shape, dtype=self.dtype) + self.y = paddle.rand(shape=self.shape, dtype=self.dtype) + self.vx = paddle.rand(shape=self.shape, dtype=self.dtype) + self.vy = paddle.rand(shape=self.shape, dtype=self.dtype) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/utils.py b/python/paddle/fluid/tests/unittests/autograd/utils.py index 3087e932051..402e89ae476 100644 --- a/python/paddle/fluid/tests/unittests/autograd/utils.py +++ b/python/paddle/fluid/tests/unittests/autograd/utils.py @@ -105,3 +105,29 @@ def _compute_numerical_hessian(func, xs, delta, np_dtype): jacobian_pos[0][i][0][p] - jacobian_neg[0][i][0][p] ) / delta / 2. return hessian + + +def _compute_numerical_vjp(func, xs, v, delta, np_dtype): + xs = _tensors(xs, "xs") + jacobian = np.array(_compute_numerical_jacobian(func, xs, delta, np_dtype)) + flat_v = np.array([v_el.numpy().reshape(-1) for v_el in v]) + vjp = [np.zeros((_product(x.shape)), dtype=np_dtype) for x in xs] + for j in range(len(xs)): + for q in range(_product(xs[j].shape)): + vjp[j][q] = np.sum(jacobian[:, j, :, q].reshape(flat_v.shape) * + flat_v) + vjp = [vjp[j].reshape(xs[j].shape) for j in range(len(xs))] + return vjp + + +def _compute_numerical_vhp(func, xs, v, delta, np_dtype): + xs = _tensors(xs, "xs") + hessian = np.array(_compute_numerical_hessian(func, xs, delta, np_dtype)) + flat_v = np.array([v_el.numpy().reshape(-1) for v_el in v]) + vhp = [np.zeros((_product(x.shape)), dtype=np_dtype) for x in xs] + for j in range(len(xs)): + for q in range(_product(xs[j].shape)): + vhp[j][q] = np.sum(hessian[:, j, :, q].reshape(flat_v.shape) * + flat_v) + vhp = [vhp[j].reshape(xs[j].shape) for j in range(len(xs))] + return vhp -- GitLab