Unverified commit 4c0ad772 authored by levi131, committed by GitHub

Lml/vhp (#36146)

* init functional jacobian api

* finish test with dtype float32

* add float64 test case

* polish code

* use atol=1e-5 with dtype float64

* fix for ci

* set timeout for test_jacobian

* init hessian API

* save status

* polish API docstring

* modify docstring

* add utils.py

* save status

* fix dygraph double grad dtype error when calling for high differential scenario

* reinvoke ci

* test_hessian.py is ok

* polish hessian API

* init vhp

* Revert "init vhp"

This reverts commit cbd4d3b66abe82b0ac10721b9eddeb7d82e0a1c8.

* init vhp

* finish vhp API logically

* add test for partial_engine.cc

* modify numerical_delta with dtype float32

* merge fix for dtype float64

* spell fix

* save status

* polish code

* rm _stop_gradient_pre_process

* save status

* add example for vhp interface

* add _compute_numerical_vjp and _compute_numerical_vhp

* test is ok

* vhp is ok

* add testVHPFloat64

* modify for comments

* modify format

* modify format

* save status

* test_vhp is ok

* finish code polish

* small fix for the case where v is None
Co-authored-by: JiabinYang <360788950@qq.com>
Parent b7f76647
@@ -18,6 +18,6 @@ from .backward_mode import backward # noqa: F401
from .py_layer import PyLayer, PyLayerContext # noqa: F401
from ..framework import set_grad_enabled # noqa: F401
from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
-from .functional import vjp, jvp, jacobian, hessian  # noqa: F401
+from .functional import vjp, jvp, jacobian, hessian, vhp  # noqa: F401
__all__ = ['backward', 'PyLayer', 'PyLayerContext']
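With this change, ``vhp`` is importable alongside the other functional autograd APIs. A minimal usage sketch, assuming a Paddle build that contains this commit (the function and variable names below are illustrative, not from the diff):

import paddle
from paddle.autograd import vhp  # exported by the line added above

def func(x):
    # vhp requires a scalar-valued function of its inputs
    return paddle.sum(paddle.matmul(x, x))

x = paddle.ones(shape=[2, 2], dtype='float32')
x.stop_gradient = False  # mark x as differentiable
func_out, vhp_out = vhp(func, x)  # v defaults to an all-ones tensor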
@@ -247,9 +247,9 @@ def jvp(func, inputs, v=None, create_graph=False, allow_unused=False):
def jacobian(func, inputs, create_graph=False, allow_unused=False):
'''
.. note::
-**This API is ONLY available in imperative mode.**
+**This API is ONLY available in the imperative mode.**
-This API computes the Jacobian matrix of `func` with respect to `inputs`.
+This function computes the Jacobian matrix of `func` with respect to `inputs`.
Parameters:
func (function): a Python function that takes a Tensor or a Tensor
@@ -389,9 +389,9 @@ def jacobian(func, inputs, create_graph=False, allow_unused=False):
def hessian(func, inputs, create_graph=False, allow_unused=False):
'''
.. note::
-**This API is ONLY available in imperative mode.**
+**This API is ONLY available in the imperative mode.**
-This API computes the Hessian matrix of `func` with respect to `inputs`.
+This function computes the Hessian matrix of `func` with respect to `inputs`.
Parameters:
func (function): a Python function that takes a Tensor or a Tensor
@@ -509,3 +509,107 @@ def hessian(func, inputs, create_graph=False, allow_unused=False):
return jacobian(
jac_func, inputs, create_graph=create_graph, allow_unused=allow_unused)
@framework.dygraph_only
def vhp(func, inputs, v=None, create_graph=False, allow_unused=False):
'''
.. note::
**This API is ONLY available in the imperative mode.**
This function computes the product between a vector ``v`` and the
Hessian matrix of `func` with respect to `inputs`.
Parameters:
func (function): a Python function that takes a Tensor or a Tensor
list/tuple as inputs and returns a Tensor with a single element.
inputs (Tensor|list(Tensor)|tuple(Tensor)): the input Tensor or
Tensor list/tuple of the function ``func``.
v (Tensor|list(Tensor)|tuple(Tensor)|None, optional): the vector used
to compute the vector-Hessian product. ``v`` should have the same shape
and dtype as ``inputs``. If ``v`` is None, it will be set to a
Tensor|list(Tensor) with all elements 1. Defaults to ``None``.
create_graph (bool, optional): whether to create the gradient graphs
of the computing process. When it is True, higher-order derivatives
can be computed; when it is False, the gradient graphs of
the computing process are discarded. Defaults to ``False``.
allow_unused (bool, optional): whether to raise an error or return None
if some Tensors of ``inputs`` are unreachable in the graph. An error is
raised if allow_unused=False, and None is returned as their gradients
if allow_unused=True. Defaults to ``False``.
Returns:
output (tuple): a tuple with:
func_output (Tensor): the output of ``func(inputs)``
vhp (list(Tensor)): the result of the vector-Hessian product,
with the same shape and dtype as the inputs.
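In mathematical terms, for a scalar-valued ``func`` and flattened inputs
:math:`x`, the quantity computed is

.. math::

    \mathrm{vhp}(f, x, v) = v^{\top} H_f(x), \qquad
    H_f(x)_{ij} = \frac{\partial^2 f(x)}{\partial x_i \partial x_j}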
Examples 1:
.. code-block:: python
import paddle
def func(x):
return paddle.sum(paddle.matmul(x, x))
x = paddle.ones(shape=[2, 2], dtype='float32')
x.stop_gradient = False
vx = paddle.ones(shape=[2, 2], dtype='float32') * 2
vhp_rslt = paddle.autograd.vhp(func, x, v=vx)
print(vhp_rslt)
# (Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
# [8.]),
# Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[8., 8.],
# [8., 8.]]))
Examples 2:
.. code-block:: python
import paddle
def func(x):
return paddle.sum(paddle.matmul(x, x))
x = paddle.ones(shape=[2, 2], dtype='float32')
x.stop_gradient = False
vhp_rslt = paddle.autograd.vhp(func, x)
print(vhp_rslt)
# (Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
# [8.]),
# Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[4., 4.],
# [4., 4.]]))
Examples 3:
.. code-block:: python
import paddle
def func(x, y):
return paddle.sum(paddle.matmul(x, x))
x = paddle.ones(shape=[2, 2], dtype='float32')
x.stop_gradient = False
y = paddle.ones(shape=[2, 2], dtype='float32')
y.stop_gradient = False
vx = paddle.ones(shape=[2, 2], dtype='float32') * 2
vy = paddle.ones(shape=[2, 2], dtype='float32') * 3
vhp_rslt = paddle.autograd.vhp(func, [x, y], v=[vx, vy], allow_unused=True)
print(vhp_rslt)
# (Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
# [8.]),
# [Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[8., 8.],
# [8., 8.]]), None])
'''
xs = _tensors(inputs, "inputs")
if v is not None:
v = _tensors(v, "v")
with gradient_scope(
xs, v, create_graph=create_graph,
allow_unused=allow_unused) as [xs, v, grad_fn, return_fn]:
outputs = func(*xs)
ys = _tensors(outputs, "outputs")
assert len(ys) == 1 and isinstance(
ys[0], paddle.Tensor
) and ys[0].shape == [
1
], "The function to compute vhp should return a Tensor with a single element"
jac = grad_fn(ys, xs, create_graph=True)
vhp = grad_fn(jac, xs, v)
outputs, vhp = return_fn(outputs), return_fn(vhp)
return outputs, vhp
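The body above is the usual double-backward recipe: one backward pass produces the gradient with ``create_graph=True`` so that it stays differentiable, and a second backward pass, weighted by ``v``, contracts the Hessian with that vector. A minimal standalone sketch of the same idea using plain ``paddle.grad`` (an illustrative example, not code from this diff):

import paddle

def func(x):
    return paddle.sum(paddle.matmul(x, x))

x = paddle.ones(shape=[2, 2], dtype='float32')
x.stop_gradient = False
v = paddle.ones(shape=[2, 2], dtype='float32') * 2

y = func(x)                                      # scalar output
grad_x, = paddle.grad(y, x, create_graph=True)   # first backward, kept on the graph
hvp, = paddle.grad(grad_x, x, grad_outputs=v)    # second backward weighted by v: v^T H
print(hvp)  # [[8., 8.], [8., 8.]], matching Examples 1 above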
@@ -25,9 +25,7 @@ def _tensors(ts, name):
name)
return list(ts)
else:
-assert isinstance(
-    ts, paddle.Tensor
-) or ts is None, "{} must be Tensor or list of Tensor".format(name)
+assert isinstance(ts, paddle.Tensor), "{} must be Tensor".format(name)
return [ts]
@@ -8,3 +8,4 @@ endforeach(TEST_OP)
set_tests_properties(test_jacobian PROPERTIES TIMEOUT 20)
set_tests_properties(test_hessian PROPERTIES TIMEOUT 50)
+set_tests_properties(test_vhp PROPERTIES TIMEOUT 50)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.compat as cpt
import paddle.nn.functional as F
from utils import _compute_numerical_vhp
class TestVHP(unittest.TestCase):
@classmethod
def setUpClass(self):
self.shape = (2, 2)
self.dtype = 'float32'
self.np_dtype = np.float32
self.numerical_delta = 1e-2
self.rtol = 1e-2
self.atol = 1e-2
self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
self.vx = paddle.rand(shape=self.shape, dtype=self.dtype)
self.vy = paddle.rand(shape=self.shape, dtype=self.dtype)
def test_single_input(self):
def func(x):
return paddle.sum(paddle.matmul(x, x))
numerical_func_output = func(self.x).numpy()
numerical_vhp = _compute_numerical_vhp(
func, self.x, self.vx, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
func_output, vhp = paddle.autograd.vhp(func, self.x, self.vx)
assert np.allclose(func_output.numpy(), numerical_func_output,
self.rtol, self.atol)
assert np.allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol,
self.atol)
def test_multi_input(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, y))
numerical_func_output = func(self.x, self.y).numpy()
numerical_vhp = _compute_numerical_vhp(
func, [self.x, self.y], [self.vx, self.vy], self.numerical_delta,
self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
func_output, vhp = paddle.autograd.vhp(func, [self.x, self.y],
[self.vx, self.vy])
assert np.allclose(func_output.numpy(), numerical_func_output,
self.rtol, self.atol)
for i in range(len(vhp)):
assert np.allclose(vhp[i].numpy(), numerical_vhp[i], self.rtol,
self.atol)
def test_v_default(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, y))
numerical_func_output = func(self.x, self.y).numpy()
vx = paddle.ones(self.vx.shape, dtype=self.vx.dtype)
vy = paddle.ones(self.vy.shape, dtype=self.vy.dtype)
numerical_vhp = _compute_numerical_vhp(func, [self.x, self.y],
[vx, vy], self.numerical_delta,
self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
func_output, vhp = paddle.autograd.vhp(func, [self.x, self.y])
assert np.allclose(func_output.numpy(), numerical_func_output,
self.rtol, self.atol)
for i in range(len(vhp)):
assert np.allclose(vhp[i].numpy(), numerical_vhp[i], self.rtol,
self.atol)
def test_allow_unused_false(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, x))
try:
self.x.stop_gradient = False
self.y.stop_gradient = False
_ = paddle.autograd.vhp(func, [self.x, self.y])
except ValueError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("allow_unused") > 0
def test_allow_unused_true(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, x))
numerical_func_output = func(self.x, self.y).numpy()
numerical_vhp = _compute_numerical_vhp(
func, [self.x, self.y], [self.vx, self.vy], self.numerical_delta,
self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
func_output, vhp = paddle.autograd.vhp(func, [self.x, self.y],
[self.vx, self.vy],
allow_unused=True)
assert np.allclose(func_output.numpy(), numerical_func_output,
self.rtol, self.atol)
assert np.allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol,
self.atol)
assert vhp[1] is None
def test_create_graph_false(self):
def func(x):
return paddle.sum(F.sigmoid(x))
numerical_func_output = func(self.x).numpy()
numerical_vhp = _compute_numerical_vhp(
func, self.x, self.vx, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
func_output, vhp = paddle.autograd.vhp(func, self.x, self.vx)
assert np.allclose(func_output.numpy(), numerical_func_output,
self.rtol, self.atol)
assert vhp[0].stop_gradient == True
assert np.allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol,
self.atol)
try:
paddle.grad(vhp, self.x)
except RuntimeError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("has no gradient") > 0
def test_create_graph_true(self):
def func(x):
return paddle.sum(F.sigmoid(x))
numerical_func_output = func(self.x).numpy()
numerical_vhp = _compute_numerical_vhp(
func, self.x, self.vx, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
func_output, vhp = paddle.autograd.vhp(func,
self.x,
self.vx,
create_graph=True)
assert np.allclose(func_output.numpy(), numerical_func_output,
self.rtol, self.atol)
assert vhp[0].stop_gradient == False
assert np.allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol,
self.atol)
triple_grad = paddle.grad(vhp, self.x)
assert triple_grad is not None
class TestVHPFloat64(TestVHP):
@classmethod
def setUpClass(self):
self.shape = (2, 2)
self.dtype = 'float64'
self.np_dtype = np.float64
self.numerical_delta = 1e-5
self.rtol = 1e-5
self.atol = 1e-5
self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
self.vx = paddle.rand(shape=self.shape, dtype=self.dtype)
self.vy = paddle.rand(shape=self.shape, dtype=self.dtype)
if __name__ == "__main__":
unittest.main()
@@ -105,3 +105,29 @@ def _compute_numerical_hessian(func, xs, delta, np_dtype):
jacobian_pos[0][i][0][p] - jacobian_neg[0][i][0][p]
) / delta / 2.
return hessian
def _compute_numerical_vjp(func, xs, v, delta, np_dtype):
xs = _tensors(xs, "xs")
jacobian = np.array(_compute_numerical_jacobian(func, xs, delta, np_dtype))
flat_v = np.array([v_el.numpy().reshape(-1) for v_el in v])
vjp = [np.zeros((_product(x.shape)), dtype=np_dtype) for x in xs]
for j in range(len(xs)):
for q in range(_product(xs[j].shape)):
vjp[j][q] = np.sum(jacobian[:, j, :, q].reshape(flat_v.shape) *
flat_v)
vjp = [vjp[j].reshape(xs[j].shape) for j in range(len(xs))]
return vjp
def _compute_numerical_vhp(func, xs, v, delta, np_dtype):
xs = _tensors(xs, "xs")
hessian = np.array(_compute_numerical_hessian(func, xs, delta, np_dtype))
flat_v = np.array([v_el.numpy().reshape(-1) for v_el in v])
vhp = [np.zeros((_product(x.shape)), dtype=np_dtype) for x in xs]
for j in range(len(xs)):
for q in range(_product(xs[j].shape)):
vhp[j][q] = np.sum(hessian[:, j, :, q].reshape(flat_v.shape) *
flat_v)
vhp = [vhp[j].reshape(xs[j].shape) for j in range(len(xs))]
return vhp
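As a cross-check on what ``_compute_numerical_vhp`` approximates: since the Hessian is symmetric, the vector-Hessian product equals a central difference of the gradient taken along the direction ``v``. A tiny self-contained NumPy illustration for ``f(x) = sum(x ** 2)``, whose Hessian is ``2 * I`` (an assumed toy example, independent of the helpers above):

import numpy as np

x = np.random.rand(2, 2).astype(np.float64)
v = np.random.rand(2, 2).astype(np.float64)
delta = 1e-5

grad = lambda z: 2.0 * z  # analytic gradient of f(x) = sum(x ** 2)
# central difference of the gradient along v approximates H v
hvp_numeric = (grad(x + delta * v) - grad(x - delta * v)) / (2.0 * delta)
assert np.allclose(hvp_numeric, 2.0 * v)  # H = 2I, so H v = 2 v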