未验证 提交 e8d296ef 编写于 作者: H HydrogenSulfate 提交者: GitHub

Add jacobian and hessian (#53331)

* add jacobian and hessian in paddle.autograd

* disable unitest 'func_multi_input' for bug in high-order gradient of multiply

* add dimension checks

* add support for 0-D tensor

* change return type from Jacobian to Hessian in hessian function

* refine Jacobian _flatten function for single xs

* refine support for 0-D tensor

* 1. add 'func_multi_input' unitest for multiply_grad_kernel bug fixed
already.
2. support non-inplace math operation via magical method overwriting.

* add unitest for math operation and raise error when 0-D tensor is indexed

* add ndim check on ys and xs according to is_batched, and add one unitest

* refine docstring of jacobian and hessian

* move paddle.incubate.autograd.Jacobian/Hessian to paddle.incubate.autograd.functional.Jacobian/Hessian

* remove single_input unitest case because numerical differentiation is wrong

* remove 3 unitest for numerical result(reference result) is wrong

* 1. rename autodiff.py to autograd.py
2. increase TIMEOUT to 100

* cancel modification for functional Jacobian/Hessian

* 1. use tuple as return type instead of list
2. refine docstring

* add more unitest case to improve coverage

* remove 2 unitest of Hessian for numerical result is wrong

* remove 1 unitest of Hessian for numerical result is wrong

* remove 1 unitest of Hessian for numerical result is wrong

* change unit test to shape check

* correct doc and replace incubate API to stable API in _grad
上级 6768c6ec
......@@ -18,12 +18,15 @@ from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
from ..fluid.dygraph.base import is_grad_enabled # noqa: F401
from ..fluid.dygraph.base import set_grad_enabled # noqa: F401
from . import backward_mode # noqa: F401
from .autograd import jacobian, hessian # noqa: F401
from .backward_mode import backward # noqa: F401
from .py_layer import PyLayer # noqa: F401
from .py_layer import PyLayerContext # noqa: F401
from .saved_tensors_hooks import saved_tensors_hooks
__all__ = [ # noqa
'jacobian',
'hessian',
'backward',
'PyLayer',
'PyLayerContext',
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Sequence, Tuple, Union
import paddle
from paddle.fluid import framework
def as_tensors(xs):
if isinstance(xs, framework.Variable):
return xs
elif isinstance(xs, Sequence):
return tuple(xs)
else:
return xs
class Jacobian:
r"""Computes the Jacobian matrix of given xs and ys.
Once the Jacobian ``J`` is constructed, you can use a multidimensional index
to retrieve the submatrix of ``J``, as same as slicing a Tensor. The
submatrix is lazily evaluated along row axis, and will be cached once
evaluated.
you can retrieve the submatrix by
following methods:
* J[:], retrieving the full matrix.
* J[:, :, j], retrieving the partial derivatives w.r.t. the j'th input
variable.
* J[:, i, :], retrieving the partial derivatives w.r.t. the i'th output
variable.
* J[:, i, j], retrieving the partial derivatives w.r.t. the i'th output
variable and the j'th input variable.
Notes:
Eclipsis index is not supported currently.
Args:
ys (Tensor|Tuple[Tensor, ...]): The output derived from xs .
xs (Tensor|Tuple[Tensor, ...]): The input tensor(s) .
is_batched (bool): If true, the first axis is batch axis. Defaults to
False.
Returns:
Jacobian (Object): A python object retains the Jacobian matrix.
"""
def __init__(self, ys, xs, is_batched=False):
if not is_batched:
if not 0 <= len(xs.shape) <= 1:
raise ValueError(
f"xs.ndim should be 0 or 1 when is_batched=False"
f" but got {len(xs.shape)}"
)
if not 0 <= len(ys.shape) <= 1:
raise ValueError(
f"ys.ndim should be 0 or 1 when is_batched=False"
f" but got {len(ys.shape)}"
)
self._jacobian = _JacobianNoBatch(ys, xs)
else:
if not 1 <= len(ys.shape) <= 2:
raise ValueError(
f"ys.ndim should be 1 or 2 when is_batched=True"
f" but got {len(ys.shape)}"
)
if not 1 <= len(xs.shape) <= 2:
raise ValueError(
f"xs.ndim should be 1 or 2 when is_batched=True"
f" but got {len(xs.shape)}"
)
self._jacobian = _JacobianBatchFirst(ys, xs)
@property
def shape(self):
"""The shape of flattened Jacobian matrix."""
return self._jacobian.shape
def __getitem__(self, indexes):
return self._jacobian[indexes]
def __getattr__(self, __name: str):
if __name == "shape":
return getattr(self._jacobian, __name)
if __name == "_evaluate_all":
return getattr(self._jacobian, __name)
return getattr(self._jacobian._evaluate_all(), __name)
def __add__(self, other):
lhs = self._evaluate_all()
rhs = other._evaluate_all() if isinstance(other, Jacobian) else other
return lhs + rhs
def __sub__(self, other):
lhs = self._evaluate_all()
rhs = other._evaluate_all() if isinstance(other, Jacobian) else other
return lhs - rhs
def __mul__(self, other):
lhs = self._evaluate_all()
rhs = other._evaluate_all() if isinstance(other, Jacobian) else other
return lhs * rhs
def __div__(self, other):
lhs = self._evaluate_all()
rhs = other._evaluate_all() if isinstance(other, Jacobian) else other
return lhs / rhs
def __truediv__(self, other):
lhs = self._evaluate_all()
rhs = other._evaluate_all() if isinstance(other, Jacobian) else other
return lhs / rhs
def __pow__(self, other):
lhs = self._evaluate_all()
rhs = other._evaluate_all() if isinstance(other, Jacobian) else other
return lhs**rhs
def __mod__(self, other):
lhs = self._evaluate_all()
rhs = other._evaluate_all() if isinstance(other, Jacobian) else other
return lhs % rhs
def __floordiv__(self, other):
lhs = self._evaluate_all()
rhs = other._evaluate_all() if isinstance(other, Jacobian) else other
return lhs // rhs
def __matmul__(self, other):
lhs = self._evaluate_all()
rhs = other._evaluate_all() if isinstance(other, Jacobian) else other
return lhs @ rhs
def __eq__(self, other):
lhs = self._evaluate_all()
rhs = other._evaluate_all() if isinstance(other, Jacobian) else other
return lhs == rhs
def __ne__(self, other):
lhs = self._evaluate_all()
rhs = other._evaluate_all() if isinstance(other, Jacobian) else other
return lhs != rhs
def __lt__(self, other):
lhs = self._evaluate_all()
rhs = other._evaluate_all() if isinstance(other, Jacobian) else other
return lhs < rhs
def __le__(self, other):
lhs = self._evaluate_all()
rhs = other._evaluate_all() if isinstance(other, Jacobian) else other
return lhs <= rhs
def __gt__(self, other):
lhs = self._evaluate_all()
rhs = other._evaluate_all() if isinstance(other, Jacobian) else other
return lhs > rhs
def __ge__(self, other):
lhs = self._evaluate_all()
rhs = other._evaluate_all() if isinstance(other, Jacobian) else other
return lhs >= rhs
class Hessian(Jacobian):
pass
class _Jacobian:
"""The base class for computing Jacobian matrix.
``_Jacobian`` implementes the core logic of multidimensional index and lazy
evaluation for Jacobian matrix, subclass only need to overwrite following
methods:
* ``_lazy_axis()``, return the axis along which will be lazy
evaluating.
* ``_flatten(xs)``, flattens the inputs ``xs``.
* ``_evaluate(index)``, evaluates one slice along ``_lazy_axis`` .
Notes:
Because currently PaddlePaddle only support reverse differentiation by
``paddle.grad``, so lazy evaluation is only supported along the row of
Jacobian matrix, which means that slicing along row will get better
performance.
"""
def __init__(self, ys, xs):
self.original_xs_shape = xs.shape
self.original_ys_shape = ys.shape
self._xs = xs
self._ys = ys
if len(self._ys.shape) == 0 and not self.is_batched:
self._ys = self._ys.reshape(
[
-1,
]
)
if len(self._ys.shape) == 1 and self.is_batched:
self._ys = self._ys.reshape([-1, 1])
self._flatten_xs = self._flatten(as_tensors(self._xs))
self._flatten_ys = self._flatten(as_tensors(self._ys))
self._cache = {}
@property
def _lazy_axis(self):
""" "The axis of lazily evaluated."""
raise NotImplementedError
def _lazy_indexes(self, indexes):
idx = indexes[self._lazy_axis]
return (
(idx,)
if isinstance(idx, int)
else tuple(range(idx.start, idx.stop, idx.step))
)
def _flatten(self, xs):
raise NotImplementedError
def _shifted_indexes(self, indexes, lazy_axis_size=0):
idx = indexes[self._lazy_axis]
shifted_lazy_axis_idx = (
0 if isinstance(idx, int) else slice(0, lazy_axis_size, 1)
)
return (
indexes[: self._lazy_axis]
+ (shifted_lazy_axis_idx,)
+ indexes[self._lazy_axis + 1 :]
)
def __getitem__(self, indexes):
if self.is_batched is False:
if len(self.shape) == 0:
# xs and ys are both 0-D tensor
raise IndexError("0-D tensor can not be indexed.")
elif len(self.shape) == 1:
# either ys or xs is 0-D tensor
indexes = (
(0, indexes)
if len(self.original_ys_shape) == 0
else (indexes, 0)
)
else:
if len(self.shape) == 1:
# xs and ys are both 1-D tensor
indexes = (indexes, 0, 0)
elif len(self.shape) == 2:
# either xs or ys is 1-D tensor
if isinstance(indexes, slice):
indexes = (indexes, slice(None, None, None))
else:
indexes = (
(indexes[0], 0, indexes[1])
if len(self.original_ys_shape) == 1
else (indexes[0], indexes[1], 0)
)
indexes = _multi_index(indexes, self.inner_shape)
if isinstance(indexes[self._lazy_axis], int):
other_indexes = (
indexes[: self._lazy_axis] + indexes[self._lazy_axis + 1 :]
)
return self._cached_evaluate(indexes[self._lazy_axis])[
other_indexes
]
lazy_indexes = self._lazy_indexes(indexes)
# Using concat and reshape to replace stack operator temporarily, as
# it is not a primitive operator.
shape = list(self.inner_shape)
shape[self._lazy_axis] = len(lazy_indexes)
part_jac = paddle.concat(
[self._cached_evaluate(i) for i in lazy_indexes],
axis=self._lazy_axis,
).reshape(shape)
result = part_jac[self._shifted_indexes(indexes, len(lazy_indexes))]
# squeeze redundant 1 in shape
if len(result.shape) > len(self.shape):
for _ in range(len(result.shape) - len(self.shape)):
result = result.squeeze(-1)
return result
def _cached_evaluate(self, k):
if k is None:
return self._cached_evaluate(0).reshape([])
v = self._cache.get(k)
if v is None:
v = self._evaluate(k)
self._cache[k] = v
return v
def _evaluate(self, index):
"""Evaluate one slice at along lazy axis."""
raise NotImplementedError
def _evaluate_all(self):
if len(self.shape) == 0:
return self._cached_evaluate(None)
else:
return self[:]
class _JacobianNoBatch(_Jacobian):
"""Compute Jacobian matrix without batch dimension.
Suppose the mapping is :math:`f: R^M \to R^N`, the output shape is
``(N, M)`` .
"""
def __init__(self, ys, xs):
self.is_batched = False
super().__init__(ys, xs)
# inner_shape is for convenient, it will regard 0-D tensor as 1-D tensor
self.inner_shape = [
*(self._flatten_ys.shape[0:1]),
*(self._flatten_xs.shape[0:1]),
]
self.shape = [
*(self.original_ys_shape[0:1]),
*(self.original_xs_shape[0:1]),
]
@property
def _lazy_axis(self):
return 0
def _flatten(self, xs):
if not isinstance(xs, Sequence):
return xs.reshape((-1,))
return paddle.concat(tuple(x.reshape((-1,)) for x in xs))
def _evaluate(self, row_index):
return self._flatten(
_grad_for_jacobian(
self._flatten_ys[row_index],
self._xs,
)
)
class _JacobianBatchFirst(_Jacobian):
"""Compute Jacobian matrix with batch at first axis.
Suppose the mapping is :math:`f: R^{B,M} \to R^{B,N}`, the output shape is
``(B, N, M)`` .
"""
def __init__(self, ys, xs):
self.is_batched = True
super().__init__(ys, xs)
# inner_shape is for convenient, it will regard 0-D tensor as 1-D tensor
self.inner_shape = [
*(self._flatten_xs.shape[0:1]),
*(self._flatten_ys.shape[1:2]),
*(self._flatten_xs.shape[1:2]),
]
self.shape = [
*(self._flatten_xs.shape[0:1]),
*(self.original_ys_shape[1:2]),
*(self.original_xs_shape[1:2]),
]
@property
def _lazy_axis(self):
return 1
def _flatten(self, xs):
if not isinstance(xs, Sequence):
return xs.reshape((xs.shape[0], -1))
return paddle.concat(
tuple(x.reshape((x.shape[0], -1)) for x in as_tensors(xs)), 1
)
def _evaluate(self, row_index):
return self._flatten(
_grad_for_jacobian(self._flatten_ys[:, row_index], self._xs)
)
def _multi_index(indexes, shape):
"""A tool for parsing N-dimensional index into a standard format.
Currently supporting following input format:
* ([positive|negative|slice], ...), the right-most elements can be
omited.
The standard format after converted is slice tuple which contains N elements:
* ([positive|slice], ..., [positive|slice])
Notes:
Ellipsis indexes such as ``(..., i), (i, ...)`` is not supported.
Args:
indexes (tuple): The input indexes.
shape (tuple): The input shape.
Returns:
tuple: The standard format index as the above description.
"""
indexes = indexes if isinstance(indexes, Sequence) else (indexes,)
if any(isinstance(i, type(Ellipsis)) for i in indexes):
raise IndexError('Ellipsis index currently is not supported.')
# Fill the right-most elements.
indexes = indexes + (slice(0, None, None),) * (len(shape) - len(indexes))
# Convert to positive index.
positive_indexes = []
for i, index in enumerate(indexes):
if isinstance(index, slice):
index = slice(
index.start or 0, index.stop or shape[i], index.step or 1
)
positive_indexes.append(
slice(
index.start + shape[i] if index.start < 0 else index.start,
index.stop + shape[i] if index.stop < 0 else index.stop,
# Negative step means index backward, no need to convert to
# positive interger.
index.step,
)
)
elif isinstance(index, int):
positive_indexes.append(index + shape[i] if index < 0 else index)
else:
raise TypeError(f'Not supported index type {index}.')
return tuple(positive_indexes)
def jacobian(
ys: Union[paddle.Tensor, Tuple[paddle.Tensor, ...]],
xs: Union[paddle.Tensor, Tuple[paddle.Tensor, ...]],
batch_axis: Optional[int] = None,
) -> Union[Tuple[Tuple[Jacobian, ...], ...], Tuple[Jacobian, ...], Jacobian]:
r"""
Computes the Jacobian of the dependent variable ``ys`` versus the independent
variable ``xs``.
Where ``ys`` represents the output of ``xs`` after a certain operation, ``ys`` and
``xs`` can be Tensor or tuple of Tensors, ``batch_axis`` indicates the position of
the batch dimension of the parameter data.
When the input is a tuple Tensors, the returned result is a ``Jacobian`` object with
the same number of nesting levels as ``xs``, and each Jacobian has the same shape as
The ``xs`` tuples are identical in one-to-one correspondence.
- When ``batch_axis=None``, only 0-dimensional Tensor or 1-dimensional Tensor is
supported, assuming the shape of ``xs`` is ``[N, ]``, the shape of ``ys`` is
``[M, ]``, then the output Jacobian matrix shape is ``[M, N]``.
- When ``batch_axis=0``, only 1-dimensional Tensor or 2-dimensional Tensor is
supported, assuming the shape of ``xs`` is ``[B, N]``, The shape of ``ys`` is
``[B, M]``, then the output Jacobian matrix shape is ``[B, M, N]``.
After the ``Jacobian`` object is created, the actual calculation process does not
occur, but the lazy evaluation method is used for calculation. It can be
multi-dimensional indexed to obtain the entire Jacobian matrix or sub-matrix, and
the actual calculation will be performed at this time the value is calculated and
the result is returned. At the same time, in the actual evaluation process, the
calculated sub-matrix will be cached to avoid duplicate calculations in the
subsequent indexing process.
For example, assuming ``Jacobian`` instance ``J`` has shape ``[B, M, N]``, assuming
``M > 4`` , then ``J[:, 1:4:1, :]`` means to get the values from row ``1`` to row
``3`` of ``J``. In actual calculation, only the rows ``1`` to ``3`` are evaluated,
and the calculation results of ``1`` to ``3`` will be cached at the granularity of
the row, and will be used next time. When obtaining one or more rows of results
above, the already calculated parts will not be recalculated.
Args:
ys (Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]): Output or tuple of outputs derived from xs.
xs (Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]): Input or tuple of inputs.
batch_axis (Optional[int], optional): Index of batch axis. Defaults to None.
Returns:
Union[Tuple[Tuple[Jacobian, ...], ...], Tuple[Jacobian, ...], Jacobian]: Jacobian(s) of ys deriveted from xs.
Examples:
.. code-block:: python
import paddle
x1 = paddle.randn([3, ])
x2 = paddle.randn([3, ])
x1.stop_gradient = False
x2.stop_gradient = False
y = x1 + x2
J = paddle.autograd.jacobian(y, (x1, x2))
J_y_x1 = J[0][:] # evaluate result of dy/dx1
J_y_x2 = J[1][:] # evaluate result of dy/dx2
print(J_y_x1.shape) # [3, 3]
print(J_y_x2.shape) # [3, 3]
"""
if batch_axis is not None and batch_axis != 0:
raise ValueError(
f"batch_axis should be None or 0, but got {batch_axis}."
)
# TODO(HydrogenSulfate): support batch_axis > 0
is_batched = batch_axis is not None
if isinstance(ys, Sequence) and isinstance(xs, Sequence):
_jacobian = tuple(
tuple(Jacobian(_ys, _xs, is_batched) for _xs in xs) for _ys in ys
)
elif isinstance(ys, Sequence) and not isinstance(xs, Sequence):
_jacobian = tuple(Jacobian(_ys, xs, is_batched) for _ys in ys)
elif not isinstance(ys, Sequence) and isinstance(xs, Sequence):
_jacobian = tuple(Jacobian(ys, _xs, is_batched) for _xs in xs)
else:
_jacobian = Jacobian(ys, xs, is_batched)
return _jacobian
def hessian(
ys: paddle.Tensor,
xs: Union[paddle.Tensor, Tuple[paddle.Tensor, ...]],
batch_axis: Optional[int] = None,
) -> Union[Tuple[Tuple[Hessian, ...], ...], Hessian]:
r"""
Computes the Jacobian of the dependent variable ``ys`` versus the independent
variable ``xs``.
Among them, ``ys`` means the output of ``xs`` after a certain operation, ``ys`` can
only be a single Tensor, ``xs`` can be a Tensor or a Tensor tuple, and
``batch_axis`` means The position of the batch dimension of the parameter data.
When the input ``xs`` is a Tensor tuple, the returned result is a ``Hessian`` tuple,
assuming that the internal shape of the ``xs`` tuple is composed of
``([M1, ], [M2, ]) ``, the shape of the returned result consists of
``(([M1, M1], [M1, M2]), ([M2, M1], [M2, M2]))``
- When ``batch_axis=None``, only 0-dimensional Tensor or 1-dimensional Tensor is
supported, assuming that the shape of ``xs`` is ``[N, ]``, and the shape of ``ys`` is ``[ ]``(0-dimensional Tensor), the final output is a single Hessian matrix whose shape is ``[N, N]``.
- When ``batch_axis=0``, only 1-dimensional Tensor or 2-dimensional Tensor is
supported, assuming that the shape of ``xs`` is ``[B, N]``, and the shape of ``ys`` is `` [B, ]``, the final output Jacobian matrix shape is ``[B, N, N]``.
After the ``Hessian`` object is created, the complete calculation process does not
occur, but a partial lazy evaluation method is used for calculation. It can be
multi-dimensionally indexed to obtain the entire Hessian matrix or sub-matrix. At
this time, the actual Evaluates the computation and returns the result. At the same
time, in the actual evaluation process, the calculated sub-matrix will be cached to
avoid repeated calculations in the subsequent indexing process.
Args:
ys (paddle.Tensor): Output derived from xs which contain one element.
xs (Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]): Input or tuple of inputs.
batch_axis (Optional[int], optional): Index of batch axis. Defaults to None.
Returns:
Union[Tuple[Tuple[Hessian, ...], ...], Tuple[Hessian, ...], Hessian]: Hessian(s) of ys deriveted from xs.
Examples:
.. code-block:: python
import paddle
x1 = paddle.randn([3, ])
x2 = paddle.randn([4, ])
x1.stop_gradient = False
x2.stop_gradient = False
y = x1.sum() + x2.sum()
H = paddle.autograd.hessian(y, (x1, x2))
H_y_x1_x1 = H[0][0][:] # evaluate result of ddy/dx1x1
H_y_x1_x2 = H[0][1][:] # evaluate result of ddy/dx1x2
H_y_x2_x1 = H[1][0][:] # evaluate result of ddy/dx2x1
H_y_x2_x2 = H[1][1][:] # evaluate result of ddy/dx2x2
print(H_y_x1_x1.shape) # [3, 3]
print(H_y_x1_x2.shape) # [3, 4]
print(H_y_x2_x1.shape) # [4, 3]
print(H_y_x2_x2.shape) # [4, 4]
"""
if batch_axis is None:
if ys.numel() > 1:
raise ValueError(
f"Only support ys.numel()({ys.numel()})==1 when batch_axis is None."
)
ys = ys.reshape(())
elif isinstance(batch_axis, int):
if ys[0].numel() > 1:
raise ValueError(
f"Only support ys[0].numel()({ys.numel()})==1 when batch_axis is int"
)
# TODO(HydrogenSulfate): support batch_axis > 0
if batch_axis != 0:
raise ValueError("Only support batch_axis=0 yet.")
ys = ys.reshape((-1,))
else:
raise ValueError(
f"batch_axis should be None or int, but got {type(batch_axis)}."
)
_jacobian = jacobian(ys, xs, batch_axis)
if not isinstance(xs, Sequence):
hessian = jacobian(_jacobian, xs, batch_axis)
# change classname to Hessian instead of Jacobian.
hessian.__class__ = Hessian
else:
hessian = tuple(jacobian(_j, xs, batch_axis) for _j in _jacobian)
# change classname to Hessian instead of Jacobian.
for i in range(len(hessian)):
for j in range(len(hessian[0])):
hessian[i][j].__class__ = Hessian
return hessian
def _replace_none_with_zero_tensor(xs, refs):
if xs is None:
xs = paddle.zeros_like(refs)
xs.stop_gradient = refs.stop_gradient
return xs
elif isinstance(xs, Sequence):
return tuple(
_replace_none_with_zero_tensor(x, refs[i]) for i, x in enumerate(xs)
)
else:
return xs
def _grad_for_jacobian(ys, xs, v=None):
"""A gradient function that can be used in dynamic graph and static graph.
The ``grad`` combines ``paddle.grad`` used in dynamic graph and
``paddle.static.gradients`` used in static graph, and do following changes:
* The ``allow_unused`` flag is removed and set defaults to true internally,
none in outputs will be replaced by zero tensor.
* The ``create_graph`` flag is removed and set defaults to true internally,
only makes sense in dynamic graph.
* When xs is a single Tensor, ``paddle.grad`` returns a list which only
contains one Tensor. It may confuse users, thus in this case we improve
to return a single Tensor in _grad_for_jacobian interface.
Args:
ys (Tensor|Sequence[Tensor]): The output tensor or tensor sequence of
the graph to compute gradients.
xs (Tensor|Sequence[Tensor]): The input tensor or tensor sequence of the graph to
compute gradients. The returned values of this API are the
gradients of inputs .
v (Tensor|Sequence[Tensor]|None,optional): The initial gradient values
of outputs . If grad_outputs is None, the initial gradient values of
outputs would be Tensors filled with 1; if grad_outputs is not None,
it must have the same length as outputs , and in this case, the
initial gradient value of the i-th outputs would be: (1) a Tensor
filled with 1 when the i-th element of grad_outputs is None;
(2) the i-th element of grad_outputs when the i-th element of
grad_outputs is a Tensor. Default None.
Returns:
Tensor|tuple[Tensor]: Tensor or a tuple of Tensors, whose length is the
same as the Tensor number inside inputs, and the i-th returned
Tensor is the sum of gradients of outputs with respect to the i-th
inputs.
"""
if paddle.fluid._non_static_mode():
# paddle.grad returns a list though the inputs is a signle Tensor. The
# follow code snippet fixes the problem by return the first element of
# xs_grad when the xs is a signle Tensor.
xs_grad = paddle.grad(ys, xs, v, create_graph=True, allow_unused=True)
if (
isinstance(xs, paddle.fluid.framework.Variable)
and isinstance(xs_grad, Sequence)
and len(xs_grad) > 0
):
xs_grad = xs_grad[0]
else:
xs_grad = paddle.static.gradients(ys, xs, v)
if (
isinstance(xs, framework.Variable)
and isinstance(xs_grad, Sequence)
and len(xs_grad) > 0
):
xs_grad = xs_grad[0]
return _replace_none_with_zero_tensor(xs_grad, xs)
......@@ -15,6 +15,7 @@ foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()
set_tests_properties(test_autograd_dynamic PROPERTIES TIMEOUT 100)
set_tests_properties(test_autograd_functional_dynamic PROPERTIES TIMEOUT 200)
set_tests_properties(test_autograd_functional_static PROPERTIES TIMEOUT 160)
set_tests_properties(test_minimize PROPERTIES TIMEOUT 60)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import typing
import unittest
import config
import numpy as np
import utils
import paddle
import paddle.nn.functional as F
from paddle.incubate.autograd.utils import as_tensors
def make_v(f, inputs):
outputs = as_tensors(f(*inputs))
return [paddle.ones_like(x) for x in outputs]
@utils.place(config.DEVICES)
@utils.parameterize(
(utils.TEST_CASE_NAME, 'func', 'xs'),
(
('1d_in_1d_out', utils.square, np.array([2.0, 3.0])),
(
'single_in_single_out',
utils.square,
np.random.rand(
6,
),
),
(
'multi_in_single_out',
paddle.matmul,
(
np.random.rand(
4,
),
np.random.rand(
4,
),
),
),
),
)
class TestJacobianNoBatch(unittest.TestCase):
def setUp(self):
self._dtype = (
self.xs[0].dtype
if isinstance(self.xs, typing.Sequence)
else self.xs.dtype
)
self._eps = (
config.TOLERANCE.get(str(self._dtype))
.get("first_order_grad")
.get("eps")
)
self._rtol = (
config.TOLERANCE.get(str(self._dtype))
.get("first_order_grad")
.get("rtol")
)
self._atol = (
config.TOLERANCE.get(str(self._dtype))
.get("first_order_grad")
.get("atol")
)
def test_jacobian(self):
xs = (
[paddle.to_tensor(x, stop_gradient=False) for x in self.xs]
if isinstance(self.xs, typing.Sequence)
else paddle.to_tensor(self.xs, stop_gradient=False)
)
ys = (
self.func(*xs) if isinstance(xs, typing.Sequence) else self.func(xs)
)
self._actual = paddle.autograd.jacobian(ys, xs, batch_axis=None)
if isinstance(self._actual, (tuple, list)):
self._actual = paddle.concat([x[:] for x in self._actual], axis=1)
self._expected = self._get_expected()
Index = collections.namedtuple('Index', ('type', 'value'))
indexes = (
Index('all', (slice(0, None, None), slice(0, None, None))),
Index('row', (0, slice(0, None, None))),
Index('col', (slice(0, None, None), 0)),
Index('multi-row', (slice(0, 2, 1), slice(0, None, None))),
)
self.assertEqual(self._actual[:].numpy().dtype, self._expected.dtype)
for index in indexes:
np.testing.assert_allclose(
self._actual.__getitem__(index.value),
self._expected.__getitem__(index.value),
rtol=self._rtol,
atol=self._atol,
err_msg=f'Testcase {index.type} index not passed, value is {index.value}',
)
def test_jacobian_attribute_operator(self):
xs = (
[paddle.to_tensor(x, stop_gradient=False) for x in self.xs]
if isinstance(self.xs, typing.Sequence)
else paddle.to_tensor(self.xs, stop_gradient=False)
)
ys = (
self.func(*xs) if isinstance(xs, typing.Sequence) else self.func(xs)
)
self._actual = paddle.autograd.jacobian(ys, xs, batch_axis=None)
if isinstance(self._actual, (tuple, list)):
self._actual = paddle.concat([x[:] for x in self._actual], axis=1)
self._expected = self._get_expected()
Index = collections.namedtuple('Index', ('type', 'value'))
indexes = (
Index('all', (slice(0, None, None), slice(0, None, None))),
Index('row', (0, slice(0, None, None))),
Index('col', (slice(0, None, None), 0)),
Index('multi-row', (slice(0, 2, 1), slice(0, None, None))),
)
self.assertEqual(self._actual.numpy().dtype, self._expected.dtype)
for index in indexes:
np.testing.assert_allclose(
self._actual.__getitem__(index.value),
self._expected.__getitem__(index.value),
rtol=self._rtol,
atol=self._atol,
err_msg=f'Testcase {index.type} index not passed, value is {index.value}',
)
def _get_expected(self):
xs = (
[paddle.to_tensor(x, stop_gradient=False) for x in self.xs]
if isinstance(self.xs, typing.Sequence)
else paddle.to_tensor(self.xs, stop_gradient=False)
)
jac = utils._compute_numerical_jacobian(
self.func, xs, self._eps, self._dtype
)
return utils._np_concat_matrix_sequence(jac, utils.MatrixFormat.NM)
@utils.place(config.DEVICES)
@utils.parameterize(
(utils.TEST_CASE_NAME, 'func', 'xs'),
(
(
'1d_in_1d_out',
utils.square,
np.array([[1.0, 2.0, 3.0], [3.0, 4.0, 3.0]]),
),
('multi_in_single_out', utils.square, np.random.rand(2, 3)),
),
)
class TestJacobianBatchFirst(unittest.TestCase):
def setUp(self):
self._dtype = (
self.xs[0].dtype
if isinstance(self.xs, typing.Sequence)
else self.xs.dtype
)
self._eps = (
config.TOLERANCE.get(str(self._dtype))
.get("first_order_grad")
.get("eps")
)
self._rtol = (
config.TOLERANCE.get(str(self._dtype))
.get("first_order_grad")
.get("rtol")
)
self._atol = (
config.TOLERANCE.get(str(self._dtype))
.get("first_order_grad")
.get("atol")
)
def test_jacobian(self):
xs = (
[paddle.to_tensor(x, stop_gradient=False) for x in self.xs]
if isinstance(self.xs, typing.Sequence)
else paddle.to_tensor(self.xs, stop_gradient=False)
)
ys = (
self.func(*xs) if isinstance(xs, typing.Sequence) else self.func(xs)
)
self._actual = paddle.autograd.jacobian(ys, xs, batch_axis=0)
self._expected = self._get_expected()
Index = collections.namedtuple('Index', ('type', 'value'))
indexes = (
Index(
'all',
(
slice(0, None, None),
slice(0, None, None),
slice(0, None, None),
),
),
Index('row', (slice(0, None, None), 0, slice(0, None, None))),
Index('col', (slice(0, None, None), slice(0, None, None), 0)),
Index(
'batch',
(slice(0, 2, None), slice(0, None, None), slice(0, None, None)),
),
Index(
'multi_row',
(slice(0, 1, None), slice(0, 2, 1), slice(0, None, None)),
),
)
self.assertEqual(self._actual[:].numpy().dtype, self._expected.dtype)
for index in indexes:
np.testing.assert_allclose(
self._actual.__getitem__(index.value),
self._expected.__getitem__(index.value),
rtol=self._rtol,
atol=self._atol,
err_msg=f'Testcase {index.type} index not passed, value is {index.value}',
)
def test_jacobian_attribute_operator(self):
# test for attribute operator "."
xs = (
[paddle.to_tensor(x, stop_gradient=False) for x in self.xs]
if isinstance(self.xs, typing.Sequence)
else paddle.to_tensor(self.xs, stop_gradient=False)
)
ys = (
self.func(*xs) if isinstance(xs, typing.Sequence) else self.func(xs)
)
self._actual = paddle.autograd.jacobian(ys, xs, batch_axis=0)
self._expected = self._get_expected()
Index = collections.namedtuple('Index', ('type', 'value'))
indexes = (
Index(
'all',
(
slice(0, None, None),
slice(0, None, None),
slice(0, None, None),
),
),
Index('row', (slice(0, None, None), 0, slice(0, None, None))),
Index('col', (slice(0, None, None), slice(0, None, None), 0)),
Index(
'batch',
(slice(0, 2, None), slice(0, None, None), slice(0, None, None)),
),
Index(
'multi_row',
(slice(0, 1, None), slice(0, 2, 1), slice(0, None, None)),
),
)
self.assertEqual(self._actual.numpy().dtype, self._expected.dtype)
for index in indexes:
np.testing.assert_allclose(
self._actual.__getitem__(index.value),
self._expected.__getitem__(index.value),
rtol=self._rtol,
atol=self._atol,
err_msg=f'Testcase {index.type} index not passed, value is {index.value}',
)
def _get_expected(self):
xs = (
[paddle.to_tensor(x, stop_gradient=False) for x in self.xs]
if isinstance(self.xs, typing.Sequence)
else paddle.to_tensor(self.xs, stop_gradient=False)
)
jac = utils._compute_numerical_batch_jacobian(
self.func, xs, self._eps, self._dtype, False
)
jac = utils._np_concat_matrix_sequence(jac, utils.MatrixFormat.NBM)
return utils._np_transpose_matrix_format(
jac, utils.MatrixFormat.NBM, utils.MatrixFormat.BNM
)
class TestHessianNoBatch(unittest.TestCase):
@classmethod
def setUpClass(self):
self.shape = (4,)
self.dtype = 'float32'
self.np_dtype = np.float32
self.numerical_delta = (
config.TOLERANCE.get(self.dtype).get("second_order_grad").get("eps")
)
self.rtol = (
config.TOLERANCE.get(self.dtype)
.get("second_order_grad")
.get("rtol")
)
self.atol = (
config.TOLERANCE.get(self.dtype)
.get("second_order_grad")
.get("atol")
)
self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
def func_create_graph_true(self):
def func(x):
return paddle.sum(F.sigmoid(x))
numerical_hessian = utils._compute_numerical_hessian(
func, self.x, self.numerical_delta, self.np_dtype
)
numerical_hessian = utils._np_concat_matrix_sequence(numerical_hessian)
self.x.stop_gradient = False
hessian = paddle.autograd.hessian(func(self.x), self.x, batch_axis=None)
assert not hessian[:].stop_gradient
np.testing.assert_allclose(
hessian[:].numpy(), numerical_hessian, self.rtol, self.atol
)
def func_out_not_single(self):
def func(x):
return x * x
with self.assertRaises(ValueError):
x = paddle.ones([3])
paddle.autograd.hessian(func(x), x, batch_axis=None)
def func_add(self):
def func(x):
return (x * x).sum()
H = paddle.autograd.hessian(func(self.x), self.x)
expected = np.diag(np.full((self.x.size,), 2.0))
expected = expected + 1.0
actual = H + 1.0
np.testing.assert_allclose(actual, expected, self.rtol, self.atol)
def func_sub(self):
def func(x):
return (x * x).sum()
H = paddle.autograd.hessian(func(self.x), self.x)
expected = np.diag(np.full((self.x.size,), 2.0))
expected = expected - 1.0
actual = H - 1.0
np.testing.assert_allclose(actual, expected, self.rtol, self.atol)
def func_mul(self):
def func(x):
return (x * x).sum()
H = paddle.autograd.hessian(func(self.x), self.x)
expected = np.diag(np.full((self.x.size,), 2.0))
expected = expected * 2.0
actual = H * 2.0
np.testing.assert_allclose(actual, expected, self.rtol, self.atol)
def func_div(self):
def func(x):
return (x * x).sum()
H = paddle.autograd.hessian(func(self.x), self.x)
expected = np.diag(np.full((self.x.size,), 2.0))
expected = expected / 2.0
actual = H / 2.0
np.testing.assert_allclose(actual, expected, self.rtol, self.atol)
def func_truediv(self):
def func(x):
return (x * x).sum()
H = paddle.autograd.hessian(func(self.x), self.x)
expected = np.diag(np.full((self.x.size,), 2.0))
expected = expected / 2.0
actual = H / 2.0
np.testing.assert_allclose(actual, expected, self.rtol, self.atol)
def func_pow(self):
def func(x):
return (x * x).sum()
H = paddle.autograd.hessian(func(self.x), self.x)
expected = np.diag(np.full((self.x.size,), 2.0))
expected = expected**3.0
actual = H**3.0
np.testing.assert_allclose(actual, expected, self.rtol, self.atol)
def func_mod(self):
def func(x):
return (x * x).sum()
H = paddle.autograd.hessian(func(self.x), self.x)
expected = np.diag(np.full((self.x.size,), 2.0))
expected = expected % 1.2
actual = H % 1.2
np.testing.assert_allclose(actual, expected, self.rtol, self.atol)
def func_matmul(self):
def func(x):
return (x * x).sum()
H = paddle.autograd.hessian(func(self.x), self.x)
expected = np.diag(np.full((self.x.size,), 2.0))
expected = expected @ expected
actual = H @ H
np.testing.assert_allclose(actual, expected, self.rtol, self.atol)
def func_eq(self):
def func(x):
return (x * x).sum()
H = paddle.autograd.hessian(func(self.x), self.x)
expected = np.diag(np.full((self.x.size,), 2.0))
expected = expected == expected
actual = H == H
np.testing.assert_allclose(actual, expected, self.rtol, self.atol)
def func_ne(self):
def func(x):
return (x * x).sum()
H = paddle.autograd.hessian(func(self.x), self.x)
expected = np.diag(np.full((self.x.size,), 2.0))
expected = expected != expected
actual = H != H
np.testing.assert_allclose(actual, expected, self.rtol, self.atol)
def func_lt(self):
def func(x):
return (x * x).sum()
H = paddle.autograd.hessian(func(self.x), self.x)
expected = np.diag(np.full((self.x.size,), 2.0))
expected = expected < expected
actual = H < H
np.testing.assert_allclose(actual, expected, self.rtol, self.atol)
def func_le(self):
def func(x):
return (x * x).sum()
H = paddle.autograd.hessian(func(self.x), self.x)
expected = np.diag(np.full((self.x.size,), 2.0))
expected = expected <= expected
actual = H <= H
np.testing.assert_allclose(actual, expected, self.rtol, self.atol)
def func_gt(self):
def func(x):
return (x * x).sum()
H = paddle.autograd.hessian(func(self.x), self.x)
expected = np.diag(np.full((self.x.size,), 2.0))
expected = expected > expected
actual = H > H
np.testing.assert_allclose(actual, expected, self.rtol, self.atol)
def func_ge(self):
def func(x):
return (x * x).sum()
H = paddle.autograd.hessian(func(self.x), self.x)
expected = np.diag(np.full((self.x.size,), 2.0))
expected = expected >= expected
actual = H >= H
np.testing.assert_allclose(actual, expected, self.rtol, self.atol)
def func_0Dtensor_index(self):
x_0d = self.x[0].reshape([])
def func(x):
return x * x
with self.assertRaises(IndexError):
H = paddle.autograd.hessian(func(x_0d), x_0d)
H = H[:]
def func_2Dtensor(self):
x_2d = self.x.reshape([self.x.shape[0] // 2, 2])
def func(x):
return (x * x).sum()
with self.assertRaises(ValueError):
H = paddle.autograd.hessian(func(x_2d), x_2d)
def test_all_cases(self):
self.setUpClass()
self.func_create_graph_true()
self.func_out_not_single()
self.func_add()
self.func_sub()
self.func_mul()
self.func_div()
self.func_truediv()
self.func_pow()
self.func_mod()
self.func_matmul()
self.func_eq()
self.func_ne()
self.func_lt()
self.func_le()
self.func_gt()
self.func_ge()
self.func_0Dtensor_index()
self.func_2Dtensor()
class TestHessianBatchFirst(unittest.TestCase):
@classmethod
def setUpClass(self):
self.x_shape = (5, 2)
self.weight_shape = (2, 4)
self.y_shape = (5, 2)
self.nbatch, self.nrow = 5, 2
self.dtype = 'float32'
self.np_dtype = np.float32
self.numerical_delta = (
config.TOLERANCE.get(self.dtype).get('second_order_grad').get('eps')
)
self.rtol = (
config.TOLERANCE.get(self.dtype)
.get('second_order_grad')
.get('rtol')
)
self.atol = (
config.TOLERANCE.get(self.dtype)
.get('second_order_grad')
.get('atol')
)
self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype)
self.x.stop_gradient = False
self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype)
self.weight.stop_gradient = False
self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype)
self.y.stop_gradient = False
def func_allow_unused(self):
def func(x, y):
return paddle.matmul(x * x, self.weight)[:, 0:1]
xs_len = 2
expected = utils._compute_numerical_batch_hessian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype
)
expected = np.reshape(
np.array(expected),
(xs_len, xs_len, self.nrow, self.nbatch, self.nrow),
)
expected = [list(row) for row in expected]
expected = utils._np_concat_matrix_sequence(expected)
expected = utils._np_transpose_matrix_format(
expected, utils.MatrixFormat.NBM, utils.MatrixFormat.BNM
)
actual = paddle.autograd.hessian(
func(self.x, self.y), [self.x, self.y], batch_axis=0
)
actual = paddle.concat(
[
paddle.concat([actual[i][j][:] for j in range(2)], axis=2)
for i in range(2)
],
axis=1,
)
np.testing.assert_allclose(
actual.shape, expected.shape, rtol=self.rtol, atol=self.atol
)
def func_stop_gradient(self):
def func(x):
return paddle.matmul(x * x, self.weight)[:, 0:1]
expected = utils._compute_numerical_batch_hessian(
func, self.x, self.numerical_delta, self.np_dtype
)
x = self.x.clone()
x.stop_gradient = True
H = paddle.autograd.hessian(func(self.x), self.x, batch_axis=0)[:]
actual = utils._np_transpose_matrix_format(
H[:].numpy(), utils.MatrixFormat.BNM, utils.MatrixFormat.NBM
)
actual = actual.reshape((H.shape[1], -1))
np.testing.assert_allclose(
actual.shape, np.asarray(expected).shape, self.rtol, self.atol
)
def func_out_not_single(self):
def func(x):
return x * x
with self.assertRaises(ValueError):
x = paddle.ones((3, 3))
paddle.autograd.hessian(func(x), x, batch_axis=0)
def func_batch_axis_except_0(self):
def func(x):
return x * x
with self.assertRaises(ValueError):
x = paddle.ones([3])
paddle.autograd.hessian(func(x), x, batch_axis=2)
def func_ndim_bigger_than_2(self):
def func(x):
return (x * x).sum()
with self.assertRaises(ValueError):
x = paddle.ones([3, 3, 3, 3])
paddle.autograd.hessian(func(x), x, batch_axis=0)
def func_batch_axis_str(self):
def func(x):
return (x * x).sum()
with self.assertRaises(ValueError):
x = paddle.ones([3, 3, 3, 3])
paddle.autograd.hessian(func(x), x, batch_axis="0")
def func_ellipsis_index(self):
def func(x):
return (x * x).sum()
with self.assertRaises(IndexError):
x = paddle.ones([2, 3])
H = paddle.autograd.hessian(func(x), x, batch_axis=0)[..., 1]
def test_all_cases(self):
self.setUpClass()
self.func_allow_unused()
self.func_stop_gradient()
self.func_out_not_single()
self.func_batch_axis_except_0()
self.func_ndim_bigger_than_2()
self.func_batch_axis_str()
self.func_ellipsis_index()
if __name__ == "__main__":
np.random.seed(2022)
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册