未验证 提交 1f93582c 编写于 作者: L levi131 提交者: GitHub

Add functional autograd API:hessian (#36108)

* init functional jacobian api

* finish test with dtype float32

* add float64 test case

* polish code

* use atol=1e-5 with dtype float64

* fix for ci

* set timeout for test_jacobian

* init hessian API

* save status

* polish API docstring

* modify docstring

* add utils.py

* save status

* fix dygraph double grad dtype error when calling for high differential senario

* reinvoke ci

* test_hessian.py is ok

* polish hessian API

* init vhp

* Revert "init vhp"

This reverts commit cbd4d3b66abe82b0ac10721b9eddeb7d82e0a1c8.

* add test for partial_engine.cc

* modify numerical_delta with dtype float32

* merge fix for dtype float64

* spell fix

* polish code

* rm _stop_gradient_pre_process
Co-authored-by: NJiabinYang <360788950@qq.com>
上级 a9ea41c5
......@@ -18,6 +18,6 @@ from .backward_mode import backward # noqa: F401
from .py_layer import PyLayer, PyLayerContext # noqa: F401
from ..framework import set_grad_enabled # noqa: F401
from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
from .functional import jacobian # noqa: F401
from .functional import jacobian, hessian # noqa: F401
__all__ = ['backward', 'PyLayer', 'PyLayerContext']
......@@ -13,34 +13,10 @@
# limitations under the License.
from paddle.fluid import framework
from .utils import _check_tensors, _stack_tensor_or_return_none, _replace_none_with_zero_tensor
import paddle
def _check_tensors(in_out_list, name):
assert in_out_list is not None, "{} should not be None".format(name)
if isinstance(in_out_list, (list, tuple)):
assert len(in_out_list) > 0, "{} connot be empyt".format(name)
for each_var in in_out_list:
assert isinstance(
each_var,
paddle.Tensor), "Elements of {} must be paddle.Tensor".format(
name)
return in_out_list
else:
assert isinstance(
in_out_list,
paddle.Tensor), "{} must be Tensor or list of Tensor".format(name)
return [in_out_list]
def _stack_tensor_or_return_none(origin_list):
assert len(origin_list) > 0, "Can't not stack an empty list"
return paddle.stack(
origin_list, axis=0) if isinstance(origin_list[0],
paddle.Tensor) else None
@framework.dygraph_only
def jacobian(func, inputs, create_graph=False, allow_unused=False):
'''
......@@ -183,3 +159,129 @@ def jacobian(func, inputs, create_graph=False, allow_unused=False):
return jacobian[0]
else:
return jacobian
@framework.dygraph_only
def hessian(func, inputs, create_graph=False, allow_unused=False):
'''
.. note::
**This API is ONLY available in imperative mode.**
This API computes the Hessian matrix of `func` with respect to `inputs`.
Parameters:
func (function): a Python function that takes a Tensor or a Tensor
list/tuple as inputs and returns a Tensor with a single element.
inputs (Tensor|list(Tensor)|tuple(Tensor)): the input Tensor or
Tensor list/tuple of the function ``func``.
create_graph (bool, optional): whether to create the gradient graphs
of the computing process. When it is True, higher order derivatives
are supported to compute; when it is False, the gradient graphs of
the computing process would be discarded. Defaults to ``False``.
allow_unused (bool, optional): whether to raise error or return None if
some Tensors of `inputs` are unreachable in the graph. Error would
be raised if allow_unused=False, and None would be returned as
their gradients if allow_unused=True. Default False.
Returns:
Hessian (Tensor or a tuple of tuple of Tensors): if function ``func``
takes a Tensor as ``inputs``, Hessian will be a single Tensor containing
the Hessian matrix for the linearized ``inputs`` Tensor. If function
``func`` takes a Tensor list/tuple as ``inputs``, then the Hessian will
be a tuple of tuple of Tensors where ``Hessian[i][j]`` will contain the
Hessian matrix of the ``i``th input and ``j``th input with size ``m * n``.
Here ``m`` and ``n`` denote the number of elements of the ``i`` th input
and the ``j`` th input respectively.
Examples 1:
.. code-block:: python
import paddle
def func(x):
return paddle.sum(paddle.matmul(x, x))
x = paddle.ones(shape=[2, 2], dtype='float32')
x.stop_gradient = False
hessian = paddle.autograd.hessian(func, x)
print(hessian)
# Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[2., 1., 1., 0.],
# [1., 0., 2., 1.],
# [1., 2., 0., 1.],
# [0., 1., 1., 2.]])
Examples 2:
.. code-block:: python
import paddle
def func(x, y):
return paddle.sum(paddle.matmul(x, y))
x = paddle.ones(shape=[2, 2], dtype='float32')
y = paddle.ones(shape=[2, 2], dtype='float32')
x.stop_gradient = False
y.stop_gradient = False
hessian = paddle.autograd.hessian(func, [x, y])
print(hessian)
# ((Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.]]),
# Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[1., 1., 0., 0.],
# [0., 0., 1., 1.],
# [1., 1., 0., 0.],
# [0., 0., 1., 1.]])),
# (Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[1., 0., 1., 0.],
# [1., 0., 1., 0.],
# [0., 1., 0., 1.],
# [0., 1., 0., 1.]]),
# Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.]])))
Examples 3:
.. code-block:: python
import paddle
def func(x, y):
return paddle.sum(paddle.matmul(x, x))
x = paddle.ones(shape=[2, 2], dtype='float32')
y = paddle.ones(shape=[2, 2], dtype='float32')
x.stop_gradient = False
y.stop_gradient = False
hessian = paddle.autograd.hessian(func, [x, y], allow_unused=True)
print(hessian)
# ((Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[2., 1., 1., 0.],
# [1., 0., 2., 1.],
# [1., 2., 0., 1.],
# [0., 1., 1., 2.]]), None), (None, None))
'''
inputs = _check_tensors(inputs, "inputs")
outputs = func(*inputs)
assert isinstance(outputs, paddle.Tensor) and outputs.shape == [
1
], "The function to compute Hessian matrix should return a Tensor with a single element"
def jac_func(*ins):
grad_inputs = paddle.grad(
outputs,
ins,
create_graph=True,
retain_graph=True,
allow_unused=allow_unused)
return tuple(
_replace_none_with_zero_tensor(grad_inputs[i], inputs[i])
for i in range(len(inputs)))
return jacobian(
jac_func, inputs, create_graph=create_graph, allow_unused=allow_unused)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
def _check_tensors(in_out_list, name):
assert in_out_list is not None, "{} should not be None".format(name)
if isinstance(in_out_list, (list, tuple)):
assert len(in_out_list) > 0, "{} connot be empyt".format(name)
for each_var in in_out_list:
assert isinstance(
each_var,
paddle.Tensor), "Elements of {} must be paddle.Tensor".format(
name)
return list(in_out_list)
else:
assert isinstance(
in_out_list,
paddle.Tensor), "{} must be Tensor or list of Tensor".format(name)
return [in_out_list]
def _stack_tensor_or_return_none(origin_list):
assert len(origin_list) > 0, "Can't not stack an empty list"
return paddle.stack(
origin_list, axis=0) if isinstance(origin_list[0],
paddle.Tensor) else None
def _replace_none_with_zero_tensor(t, spec_t):
if t is None:
zero_t = paddle.zeros(shape=spec_t.shape, dtype=spec_t.dtype)
zero_t.stop_gradient = spec_t.stop_gradient
return zero_t
else:
return t
......@@ -7,3 +7,4 @@ foreach(TEST_OP ${TEST_OPS})
endforeach(TEST_OP)
set_tests_properties(test_jacobian PROPERTIES TIMEOUT 20)
set_tests_properties(test_hessian PROPERTIES TIMEOUT 20)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.compat as cpt
from utils import _compute_numerical_hessian
class TestHessian(unittest.TestCase):
@classmethod
def setUpClass(self):
self.shape = (2, 2)
self.dtype = 'float32'
self.np_dtype = np.float32
self.numerical_delta = 1e-2
self.rtol = 1e-2
self.atol = 1e-2
self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
def test_single_input(self):
def func(x):
return paddle.sum(paddle.matmul(x, x))
numerical_hessian = _compute_numerical_hessian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
hessian = paddle.autograd.hessian(func, self.x)
assert np.allclose(hessian.numpy(), numerical_hessian[0][0], self.rtol,
self.atol)
def test_multi_input(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, y))
numerical_hessian = _compute_numerical_hessian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
hessian = paddle.autograd.hessian(func, [self.x, self.y])
for i in range(len(hessian)):
for j in range(len(hessian[0])):
assert np.allclose(hessian[i][j].numpy(),
numerical_hessian[i][j], self.rtol,
self.atol)
def test_allow_unused_false(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, x))
try:
self.x.stop_gradient = False
self.y.stop_gradient = False
hessian = paddle.autograd.hessian(func, [self.x, self.y])
except ValueError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("allow_unused") > 0
def test_allow_unused_true(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, x))
numerical_hessian = _compute_numerical_hessian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
hessian = paddle.autograd.hessian(
func, [self.x, self.y], allow_unused=True)
for i in range(len(hessian)):
for j in range(len(hessian[0])):
if i == j == 0:
assert np.allclose(hessian[i][j].numpy(),
numerical_hessian[i][j], self.rtol,
self.atol)
else:
assert hessian[i][j] is None
def test_create_graph_false(self):
def func(x):
return paddle.sum(paddle.matmul(x, x))
numerical_hessian = _compute_numerical_hessian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
hessian = paddle.autograd.hessian(func, self.x)
assert hessian.stop_gradient == True
assert np.allclose(hessian.numpy(), numerical_hessian[0][0], self.rtol,
self.atol)
try:
paddle.grad(hessian, self.x)
except RuntimeError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("has no gradient") > 0
# TODO(levi): enable this test case when matmul_grad_grad_grad is ok
def _test_create_graph_true(self):
def func(x):
return paddle.sum(paddle.matmul(x, x))
numerical_hessian = _compute_numerical_hessian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
hessian = paddle.autograd.hessian(func, self.x, create_graph=True)
assert hessian.stop_gradient == False
assert np.allclose(hessian.numpy(), numerical_hessian[0][0], self.rtol,
self.atol)
triple_grad = paddle.grad(hessian, self.x)
assert triple_grad is not None
class TestHessianFloat64(TestHessian):
@classmethod
def setUpClass(self):
self.shape = (2, 2)
self.dtype = 'float64'
self.np_dtype = np.float64
self.numerical_delta = 1e-5
self.rtol = 1e-5
self.atol = 1e-5
self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
if __name__ == "__main__":
unittest.main()
......@@ -16,65 +16,7 @@ import unittest
import numpy as np
import paddle
import paddle.compat as cpt
from paddle.autograd.functional import _check_tensors
def _product(t):
if isinstance(t, int):
return t
else:
return np.product(t)
def _get_item(t, idx):
assert isinstance(t, paddle.Tensor), "The first argument t must be Tensor."
assert isinstance(idx,
int), "The second argument idx must be an int number."
flat_t = paddle.reshape(t, [-1])
return flat_t.__getitem__(idx)
def _set_item(t, idx, value):
assert isinstance(t, paddle.Tensor), "The first argument t must be Tensor."
assert isinstance(idx,
int), "The second argument idx must be an int number."
flat_t = paddle.reshape(t, [-1])
flat_t.__setitem__(idx, value)
return paddle.reshape(flat_t, t.shape)
def _compute_numerical_jacobian(func, xs, delta, np_dtype):
xs = _check_tensors(xs, "xs")
ys = _check_tensors(func(*xs), "ys")
fin_size = len(xs)
fout_size = len(ys)
jacobian = list([] for _ in range(fout_size))
for i in range(fout_size):
jac_i = list([] for _ in range(fin_size))
for j in range(fin_size):
jac_i[j] = np.zeros(
(_product(ys[i].shape), _product(xs[j].shape)), dtype=np_dtype)
jacobian[i] = jac_i
for j in range(fin_size):
for q in range(_product(xs[j].shape)):
orig = _get_item(xs[j], q)
x_pos = orig + delta
xs[j] = _set_item(xs[j], q, x_pos)
ys_pos = _check_tensors(func(*xs), "ys_pos")
x_neg = orig - delta
xs[j] = _set_item(xs[j], q, x_neg)
ys_neg = _check_tensors(func(*xs), "ys_neg")
xs[j] = _set_item(xs[j], q, orig)
for i in range(fout_size):
for p in range(_product(ys[i].shape)):
y_pos = _get_item(ys_pos[i], p)
y_neg = _get_item(ys_neg[i], p)
jacobian[i][j][p][q] = (y_pos - y_neg) / delta / 2.
return jacobian
from utils import _compute_numerical_jacobian
class TestJacobian(unittest.TestCase):
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from paddle.autograd.functional import _check_tensors
def _product(t):
if isinstance(t, int):
return t
else:
return np.product(t)
def _get_item(t, idx):
assert isinstance(t, paddle.Tensor), "The first argument t must be Tensor."
assert isinstance(idx,
int), "The second argument idx must be an int number."
flat_t = paddle.reshape(t, [-1])
return flat_t.__getitem__(idx)
def _set_item(t, idx, value):
assert isinstance(t, paddle.Tensor), "The first argument t must be Tensor."
assert isinstance(idx,
int), "The second argument idx must be an int number."
flat_t = paddle.reshape(t, [-1])
flat_t.__setitem__(idx, value)
return paddle.reshape(flat_t, t.shape)
def _compute_numerical_jacobian(func, xs, delta, np_dtype):
xs = _check_tensors(xs, "xs")
ys = _check_tensors(func(*xs), "ys")
fin_size = len(xs)
fout_size = len(ys)
jacobian = list([] for _ in range(fout_size))
for i in range(fout_size):
jac_i = list([] for _ in range(fin_size))
for j in range(fin_size):
jac_i[j] = np.zeros(
(_product(ys[i].shape), _product(xs[j].shape)), dtype=np_dtype)
jacobian[i] = jac_i
for j in range(fin_size):
for q in range(_product(xs[j].shape)):
orig = _get_item(xs[j], q)
x_pos = orig + delta
xs[j] = _set_item(xs[j], q, x_pos)
ys_pos = _check_tensors(func(*xs), "ys_pos")
x_neg = orig - delta
xs[j] = _set_item(xs[j], q, x_neg)
ys_neg = _check_tensors(func(*xs), "ys_neg")
xs[j] = _set_item(xs[j], q, orig)
for i in range(fout_size):
for p in range(_product(ys[i].shape)):
y_pos = _get_item(ys_pos[i], p)
y_neg = _get_item(ys_neg[i], p)
jacobian[i][j][p][q] = (y_pos - y_neg) / delta / 2.
return jacobian
def _compute_numerical_hessian(func, xs, delta, np_dtype):
xs = _check_tensors(xs, "xs")
ys = _check_tensors(func(*xs), "ys")
fin_size = len(xs)
hessian = list([] for _ in range(fin_size))
for i in range(fin_size):
hessian_i = list([] for _ in range(fin_size))
for j in range(fin_size):
hessian_i[j] = np.zeros(
(_product(xs[i].shape), _product(xs[j].shape)), dtype=np_dtype)
hessian[i] = hessian_i
for i in range(fin_size):
for p in range(_product(xs[i].shape)):
for j in range(fin_size):
for q in range(_product(xs[j].shape)):
orig = _get_item(xs[j], q)
x_pos = orig + delta
xs[j] = _set_item(xs[j], q, x_pos)
jacobian_pos = _compute_numerical_jacobian(func, xs, delta,
np_dtype)
x_neg = orig - delta
xs[j] = _set_item(xs[j], q, x_neg)
jacobian_neg = _compute_numerical_jacobian(func, xs, delta,
np_dtype)
xs[j] = _set_item(xs[j], q, orig)
hessian[i][j][p][q] = (
jacobian_pos[0][i][0][p] - jacobian_neg[0][i][0][p]
) / delta / 2.
return hessian
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册