Unverified commit 4d24d352, authored by Weilong Wu, committed by GitHub

[New features] Support batch_jacobian and batch_hessian (#37547)

* native commit for triple grad of sigmoid

* Updated unittests files

* init functional jacobian api

* Updated triple_test func

* Updated gradient_checker & test_script

* finish test with dtype float32

* add float64 test case

* polish code

* use atol=1e-5 with dtype float64

* fix for ci

* set timeout for test_jacobian

* fix dygraph grad to support high differential

* polish API docstring

* Updated gradient checker and some related files

* fix double grad strip error for high differential

* fix double grad strip error for high differential

* Add Sigmoid triple grad tests

* fix dygraph double grad dtype error when calling for high differential scenario

* Updated triple grad tests func

* Use np.random to initialize ddx

* Updated triple_grad_check func

* add todo for gradient checker and refine some comments

* remove additional code

* add test for warning in backward.py

* format python code

* support multi input in triple gradient checker

* Add matmul triple grad kernel

* Updated comments of TODO

* Supported some special tests

* Change code-format to follow CI std

* Updated gradient_checker.py

* Fix conflicts

* Removed unnecessary printing log

* Change code style to follow CI std

* support batch in jacobian and hessian

* add batch jacobian and batch hessian

* Add batch_jacobian test, draft version

* [New features] Add elementwise_mul triple grad kernel (#37152)

* Add elementwise_mul triple grad kernel

* Removed InplaceInferer and polished code

* Add numerical_batch_jacobian, numerical_batch_hessian and tests

* Support batch_jacobian and batch_numerical

* Use pre-commit to check code format

* Update doc, polish code, add unit test

* Reset the TIMEOUT properties of test_jacobian to pass CI
Co-authored-by: levi131 <limaolin01@baidu.com>
Co-authored-by: Jiabin Yang <360788950@qq.com>
Parent d0a89744
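For orientation, here is a minimal dygraph usage sketch of the two APIs added by this PR, adapted from the docstring examples in the diff below (tensor shapes and values are illustrative only):

import paddle

# Inputs are batched along the first dimension (batch_size = 4 here).
x = paddle.ones(shape=(4, 2), dtype='float64')
weight = paddle.ones(shape=(2, 4), dtype='float64')
y = paddle.ones(shape=(4, 2), dtype='float64')
x.stop_gradient = False

def func_jac(x):
    # Output keeps the batch dimension: shape [4, 2].
    return paddle.matmul(paddle.matmul(x, weight), y)

def func_hess(x):
    # batch_hessian requires an output of shape [batch_size, 1].
    return paddle.matmul(x * x, weight)[:, 0:1]

# Batched Jacobian: Tensor of shape [num, batch_size * num] = [2, 8].
bj = paddle.autograd.batch_jacobian(func_jac, x)

# Batched Hessian: Tensor of shape [num, batch_size * num] = [2, 8].
bh = paddle.autograd.batch_hessian(func_hess, x)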
@@ -18,6 +18,7 @@ from .backward_mode import backward # noqa: F401
from .py_layer import PyLayer, PyLayerContext # noqa: F401
from ..framework import set_grad_enabled # noqa: F401
from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
from .functional import vjp, jvp, jacobian, hessian, vhp # noqa: F401
from .functional import jacobian, hessian, batch_jacobian, batch_hessian # noqa: F401
from .functional import vjp, jvp, vhp # noqa: F401
__all__ = ['backward', 'PyLayer', 'PyLayerContext']
@@ -385,6 +385,297 @@ def jacobian(func, inputs, create_graph=False, allow_unused=False):
return jacobian
@framework.dygraph_only
def batch_jacobian(func, inputs, create_graph=False, allow_unused=False):
'''
.. note::
**This API is ONLY available in the imperative mode.**
    This function computes the batch Jacobian matrix of ``func`` with respect to ``inputs``.
    Note that the first dimension of the inputs is the batch size.
Parameters:
        func (function): a Python function that takes a Tensor or a Tensor
            list/tuple as input (whose first dimension is the batch size) and
            returns a Tensor or a Tensor tuple.
        inputs (Tensor|list(Tensor)|tuple(Tensor)): the input Tensor or
            Tensor list/tuple of the function ``func``. Note that
            the first dimension of the inputs is the batch size.
        create_graph (bool, optional): whether to create the gradient graphs
            of the computing process. When it is True, higher order derivatives
            can be computed; when it is False, the gradient graphs of the
            computing process will be discarded. Defaults to ``False``.
        allow_unused (bool, optional): whether to raise an error or return None if
            some Tensors of ``inputs`` are unreachable in the graph. An error would
            be raised if allow_unused=False, and None would be returned as
            their gradients if allow_unused=True. Defaults to ``False``.
Returns:
        Jacobian (Tensor or nested tuple of Tensors): if function ``func``
        takes a Tensor as input and returns a Tensor as output, the Jacobian
        will be a single Tensor containing the Jacobian matrix for the
        linearized input and output. If one of the input and output is
        a Tensor and the other is a Tensor list/tuple, then the Jacobian will
        be a tuple of Tensors. If both the input and the output are Tensor
        lists/tuples, then the Jacobian will be a tuple of tuples of Tensors.
        Note that the first dimension of the inputs is the batch size.
        For example,
        if the input shape and output shape of function ``func`` are [batch_size, num]
        and [batch_size, num] respectively, then the Jacobian will be a Tensor with
        a shape of [num, batch_size * num], where ``Jacobian[i][j]`` will contain
        the Jacobian matrix of the ``i``th column of the output and the ``j``th input
        and will have the same dtype and device as the corresponding input.
        Other situations can be deduced by analogy.
Examples 1:
.. code-block:: python
import paddle
x = paddle.ones(shape=(4, 2), dtype='float64')
weight = paddle.ones(shape=(2, 4), dtype='float64')
y = paddle.ones(shape=(4, 2), dtype='float64')
def func(x):
return paddle.matmul(paddle.matmul(x, weight), y)
x.stop_gradient = False
batch_jacobian = paddle.autograd.batch_jacobian(func, x)
print(batch_jacobian)
# Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[4., 4., 4., 4., 4., 4., 4., 4.],
# [4., 4., 4., 4., 4., 4., 4., 4.]])
Examples 2:
.. code-block:: python
import paddle
x = paddle.ones(shape=(4, 2), dtype='float64')
weight = paddle.ones(shape=(2, 4), dtype='float64')
y = paddle.ones(shape=(4, 2), dtype='float64')
def func(x):
return paddle.matmul(paddle.matmul(x, weight), y), x * x
x.stop_gradient = False
batch_jacobian = paddle.autograd.batch_jacobian(func, x)
print(batch_jacobian)
# (Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[4., 4., 4., 4., 4., 4., 4., 4.],
# [4., 4., 4., 4., 4., 4., 4., 4.]]), Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[2., 0., 2., 0., 2., 0., 2., 0.],
# [0., 2., 0., 2., 0., 2., 0., 2.]]))
Examples 3:
.. code-block:: python
import paddle
x = paddle.ones(shape=(4, 2), dtype='float64')
weight = paddle.ones(shape=(2, 4), dtype='float64')
y = paddle.ones(shape=(4, 2), dtype='float64')
def func(x, y):
return x * y
x.stop_gradient = False
y.stop_gradient = False
batch_jacobian = paddle.autograd.batch_jacobian(func, [x, y])
print(batch_jacobian)
# (Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[1., 0., 1., 0., 1., 0., 1., 0.],
# [0., 1., 0., 1., 0., 1., 0., 1.]]), Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[1., 0., 1., 0., 1., 0., 1., 0.],
# [0., 1., 0., 1., 0., 1., 0., 1.]]))
'''
inputs = _tensors(inputs, "inputs")
outputs = _tensors(func(*inputs), "outputs")
batch_size = inputs[0].shape[0]
for input in inputs:
        assert input.shape[
            0] == batch_size, "The first dimension of every input must equal the batch size!"
for output in outputs:
        assert output.shape[
            0] == batch_size, "The first dimension of every output must equal the batch size!"
fin_size = len(inputs)
fout_size = len(outputs)
flat_outputs = tuple(
reshape(
output, shape=[batch_size, -1]) for output in outputs)
jacobian = tuple()
for i, flat_output in enumerate(flat_outputs):
jac_i = list([] for _ in range(fin_size))
for k in range(flat_output.shape[1]):
row_k = grad(
flat_output[:, k],
inputs,
create_graph=create_graph,
retain_graph=True,
allow_unused=allow_unused)
for j in range(fin_size):
jac_i[j].append(
reshape(
row_k[j], shape=[-1])
if isinstance(row_k[j], paddle.Tensor) else None)
jacobian += (tuple(
_stack_tensor_or_return_none(jac_i_j) for jac_i_j in jac_i), )
if fin_size == 1 and fout_size == 1:
return jacobian[0][0]
elif fin_size == 1 and fout_size != 1:
return tuple(jacobian[i][0] for i in range(fout_size))
elif fin_size != 1 and fout_size == 1:
return jacobian[0]
else:
return jacobian
@framework.dygraph_only
def batch_hessian(func, inputs, create_graph=False, allow_unused=False):
'''
.. note::
**This API is ONLY available in the imperative mode.**
    This function computes the batch Hessian matrix of ``func`` with respect to ``inputs``.
    Note that the first dimension of the inputs is the batch size.
Parameters:
        func (function): a Python function that takes a Tensor or a Tensor
            list/tuple as input (whose first dimension is the batch size) and
            returns a Tensor with shape [batch_size, 1].
        inputs (Tensor|list(Tensor)|tuple(Tensor)): the input Tensor or
            Tensor list/tuple of the function ``func``.
            Note that the first dimension of the inputs is the batch size.
        create_graph (bool, optional): whether to create the gradient graphs
            of the computing process. When it is True, higher order derivatives
            can be computed; when it is False, the gradient graphs of the
            computing process will be discarded. Defaults to ``False``.
        allow_unused (bool, optional): whether to raise an error or return None if
            some Tensors of ``inputs`` are unreachable in the graph. An error would
            be raised if allow_unused=False, and None would be returned as
            their gradients if allow_unused=True. Defaults to ``False``.
Returns:
        Hessian (Tensor or a tuple of tuples of Tensors): if function ``func``
        takes a Tensor as ``inputs``, the Hessian will be a single Tensor containing
        the Hessian matrix for the linearized ``inputs`` Tensor. If function
        ``func`` takes a Tensor list/tuple as ``inputs``, then the Hessian will
        be a tuple of tuples of Tensors. Note that the first dimension of the inputs
        is the batch size, and that the computation first obtains the first-order
        derivatives and then differentiates them with respect to the batched inputs.
        For example,
        if the input shape and output shape of function ``func`` are [batch_size, num]
        and [batch_size, 1] respectively, then the batched Hessian will be a Tensor with
        a shape of [num, batch_size * num].
        Why does the result have this shape?
        Because batch_hessian creates an inner function (a wrapper around paddle.grad())
        that computes the sum of gradients of ``outputs`` with respect to each input;
        this inner function yields the first-order derivatives with shape [batch_size, num].
        batch_jacobian is then called to compute the Jacobian between these first-order
        derivatives and the original inputs. The final result ``Hessian[i][j]`` will contain
        the Jacobian matrix of the ``i``th column of the output (here the output means the
        first-order derivatives) and the ``j``th input, and will have the same dtype and
        device as the corresponding input. Other situations can be deduced by analogy.
Examples 1:
.. code-block:: python
import paddle
x = paddle.ones(shape=(4, 2), dtype='float64')
weight = paddle.ones(shape=(2, 4), dtype='float64')
y = paddle.ones(shape=(4, 2), dtype='float64')
def func(x):
return paddle.matmul(x * x, weight)[:, 0:1]
x.stop_gradient = False
batch_hessian = paddle.autograd.batch_hessian(func, x)
print(batch_hessian)
# Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[2., 0., 2., 0., 2., 0., 2., 0.],
# [0., 2., 0., 2., 0., 2., 0., 2.]])
Examples 2:
.. code-block:: python
import paddle
x = paddle.ones(shape=(4, 2), dtype='float64')
weight = paddle.ones(shape=(2, 4), dtype='float64')
y = paddle.ones(shape=(4, 2), dtype='float64')
def func(x, y):
return paddle.matmul(x * x * y * y, weight)[:, 0:1]
x.stop_gradient = False
y.stop_gradient = False
batch_hessian = paddle.autograd.batch_hessian(func, [x, y])
print(batch_hessian)
# ((Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[2., 0., 2., 0., 2., 0., 2., 0.],
# [0., 2., 0., 2., 0., 2., 0., 2.]]),
# Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[4., 0., 4., 0., 4., 0., 4., 0.],
# [0., 4., 0., 4., 0., 4., 0., 4.]])),
# (Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[4., 0., 4., 0., 4., 0., 4., 0.],
# [0., 4., 0., 4., 0., 4., 0., 4.]]),
# Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[2., 0., 2., 0., 2., 0., 2., 0.],
# [0., 2., 0., 2., 0., 2., 0., 2.]])))
Examples 3:
.. code-block:: python
import paddle
x = paddle.ones(shape=(4, 2), dtype='float64')
weight = paddle.ones(shape=(2, 4), dtype='float64')
y = paddle.ones(shape=(4, 2), dtype='float64')
def func(x, y):
return paddle.matmul(x * x, weight)[:, 0:1]
x.stop_gradient = False
y.stop_gradient = False
batch_hessian = paddle.autograd.batch_hessian(func, [x, y], allow_unused=True)
print(batch_hessian)
# ((Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[2., 0., 2., 0., 2., 0., 2., 0.],
# [0., 2., 0., 2., 0., 2., 0., 2.]]), None), (None, None))
'''
inputs = _tensors(inputs, "inputs")
outputs = func(*inputs)
batch_size = inputs[0].shape[0]
for input in inputs:
        assert input.shape[
            0] == batch_size, "The first dimension of every input must equal the batch size!"
assert isinstance(outputs, paddle.Tensor) and outputs.shape == [
batch_size, 1
], "The function to compute batched Hessian matrix should return a Tensor of shape [batch_size, 1]"
def jac_func(*ins):
grad_inputs = grad(
outputs,
ins,
create_graph=True,
retain_graph=True,
allow_unused=allow_unused)
return tuple(
_replace_none_with_zero_tensor(grad_inputs[i], inputs[i])
for i in range(len(inputs)))
return batch_jacobian(
jac_func, inputs, create_graph=create_graph, allow_unused=allow_unused)
@framework.dygraph_only
def hessian(func, inputs, create_graph=False, allow_unused=False):
'''
......
@@ -6,6 +6,6 @@ foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach(TEST_OP)
set_tests_properties(test_jacobian PROPERTIES TIMEOUT 20)
set_tests_properties(test_jacobian PROPERTIES TIMEOUT 50)
set_tests_properties(test_hessian PROPERTIES TIMEOUT 50)
set_tests_properties(test_vhp PROPERTIES TIMEOUT 50)
@@ -17,7 +17,7 @@ import numpy as np
import paddle
import paddle.compat as cpt
import paddle.nn.functional as F
from utils import _compute_numerical_hessian
from utils import _compute_numerical_hessian, _compute_numerical_batch_hessian
class TestHessian(unittest.TestCase):
@@ -136,5 +136,128 @@ class TestHessianFloat64(TestHessian):
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
class TestBatchHessian(unittest.TestCase):
@classmethod
def setUpClass(self):
self.x_shape = (5, 2)
self.weight_shape = (2, 4)
self.y_shape = (5, 2)
self.dtype = 'float32'
self.np_dtype = np.float32
self.numerical_delta = 1e-2
self.rtol = 1e-3
self.atol = 1e-3
self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype)
self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype)
def test_single_input(self):
def func(x):
return paddle.matmul(x * x, self.weight)[:, 0:1]
numerical_hessian = _compute_numerical_batch_hessian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
hessian = paddle.autograd.batch_hessian(func, self.x, create_graph=True)
assert np.allclose(hessian, numerical_hessian, self.rtol, self.atol)
def test_multi_input(self):
def func(x, y):
return paddle.matmul(x * x * y * y, self.weight)[:, 0:1]
numerical_hessian = _compute_numerical_batch_hessian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
hessian = paddle.autograd.batch_hessian(func, [self.x, self.y])
shape_tensor = paddle.to_tensor(numerical_hessian).astype("float64")
hessian_reshape = np.reshape(hessian, (shape_tensor.shape))
assert np.allclose(hessian_reshape, numerical_hessian, self.rtol,
self.atol)
def test_allow_unused_false(self):
def func(x, y):
return paddle.matmul(x * x, self.weight)[:, 0:1]
try:
self.x.stop_gradient = False
self.y.stop_gradient = False
hessian = paddle.autograd.batch_hessian(func, [self.x, self.y])
except ValueError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("allow_unused") > 0
def test_allow_unused_true(self):
def func(x, y):
return paddle.matmul(x * x, self.weight)[:, 0:1]
numerical_hessian = _compute_numerical_batch_hessian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
hessian = paddle.autograd.batch_hessian(
func, [self.x, self.y], allow_unused=True)
for i in range(len(hessian)):
for j in range(len(hessian[0])):
if i == j == 0:
numerical_hessian = np.stack(
(numerical_hessian[i][j], numerical_hessian[i][j + 1]),
axis=0)
assert np.allclose(hessian[i][j], numerical_hessian,
self.rtol, self.atol)
else:
assert hessian[i][j] is None
def test_create_graph_false(self):
def func(x):
return paddle.matmul(x * x, self.weight)[:, 0:1]
numerical_hessian = _compute_numerical_batch_hessian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
hessian = paddle.autograd.batch_hessian(func, self.x)
assert hessian.stop_gradient == True
assert np.allclose(hessian.numpy(), numerical_hessian, self.rtol,
self.atol)
try:
paddle.grad(hessian, self.x)
except RuntimeError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("has no gradient") > 0
def test_create_graph_true(self):
def func(x):
return paddle.matmul(x * x, self.weight)[:, 0:1]
numerical_hessian = _compute_numerical_batch_hessian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
hessian = paddle.autograd.batch_hessian(func, self.x, create_graph=True)
assert hessian.stop_gradient == False
assert np.allclose(hessian.numpy(), numerical_hessian, self.rtol,
self.atol)
triple_grad = paddle.grad(hessian, self.x)
assert triple_grad is not None
class TestBatchHessianFloat64(TestBatchHessian):
@classmethod
def setUpClass(self):
self.x_shape = (5, 2)
self.weight_shape = (2, 4)
self.y_shape = (5, 2)
self.dtype = 'float64'
self.np_dtype = np.float64
self.numerical_delta = 1e-4
self.rtol = 1e-5
self.atol = 1e-5
self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype)
self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype)
if __name__ == "__main__":
unittest.main()
@@ -16,7 +16,7 @@ import unittest
import numpy as np
import paddle
import paddle.compat as cpt
from utils import _compute_numerical_jacobian
from utils import _compute_numerical_jacobian, _compute_numerical_batch_jacobian
class TestJacobian(unittest.TestCase):
@@ -158,5 +158,162 @@ class TestJacobianFloat64(TestJacobian):
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
class TestJacobianBatch(unittest.TestCase):
@classmethod
def setUpClass(self):
self.x_shape = (4, 2)
self.weight_shape = (2, 4)
self.y_shape = (4, 2)
self.dtype = 'float32'
self.np_dtype = np.float32
self.numerical_delta = 1e-4
self.rtol = 1e-3
self.atol = 1e-3
self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype)
self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype)
def test_batch_single_input_and_batch_single_output(self):
def func(x):
return paddle.matmul(paddle.matmul(x, self.weight), self.y)
numerical_jacobian = _compute_numerical_batch_jacobian(
func, [self.x], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
batch_jacobian = paddle.autograd.batch_jacobian(
func,
self.x, )
        self.assertTrue(
            np.allclose(batch_jacobian.numpy(), numerical_jacobian[0][0],
                        self.rtol, self.atol))
def test_batch_single_input_and_batch_multi_output(self):
def func(x):
return paddle.matmul(paddle.matmul(x, self.weight), self.y), x * x
numerical_jacobian = _compute_numerical_batch_jacobian(
func, [self.x], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
batch_jacobian = paddle.autograd.batch_jacobian(
func,
self.x, )
for i in range(len(batch_jacobian)):
assert np.allclose(batch_jacobian[i].numpy(),
numerical_jacobian[i][0], self.rtol, self.atol)
def test_batch_multi_input_and_batch_single_output(self):
def func(x, y):
return x * y
numerical_jacobian = _compute_numerical_batch_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
batch_jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y])
for j in range(len(batch_jacobian)):
assert np.allclose(batch_jacobian[j].numpy(),
numerical_jacobian[0][j], self.rtol, self.atol)
def test_batch_multi_input_and_batch_multi_output(self):
def func(x, y):
return x * y, x * y
numerical_jacobian = _compute_numerical_batch_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
batch_jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y])
for i in range(len(batch_jacobian)):
assert np.allclose(batch_jacobian[i], numerical_jacobian[i],
self.rtol, self.atol)
def test_allow_unused_false(self):
def func(x, y):
return x * x
try:
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y])
except ValueError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("allow_unused") > 0
def test_allow_unused_true(self):
def func(x, y):
return x * x
numerical_jacobian = _compute_numerical_batch_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.batch_jacobian(
func, [self.x, self.y], allow_unused=True)
assert np.allclose(jacobian[0].numpy(), numerical_jacobian[0][0],
self.rtol, self.atol)
assert jacobian[1] is None
def test_create_graph_false(self):
def func(x, y):
return x * y
numerical_jacobian = _compute_numerical_batch_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y])
for j in range(len(jacobian)):
assert jacobian[j].stop_gradient == True
assert np.allclose(jacobian[j].numpy(), numerical_jacobian[0][j],
self.rtol, self.atol)
try:
paddle.grad(jacobian[0], [self.x, self.y])
except RuntimeError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("has no gradient") > 0
def test_create_graph_true(self):
def func(x, y):
return x * y
numerical_jacobian = _compute_numerical_batch_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.batch_jacobian(
func, [self.x, self.y], create_graph=True)
for j in range(len(jacobian)):
assert jacobian[j].stop_gradient == False
assert np.allclose(jacobian[j].numpy(), numerical_jacobian[0][j],
self.rtol, self.atol)
double_grad = paddle.grad(jacobian[0], [self.x, self.y])
assert double_grad is not None
class TestJacobianBatchFloat64(TestJacobianBatch):
@classmethod
def setUpClass(self):
self.x_shape = (12, 2)
self.weight_shape = (2, 12)
self.y_shape = (12, 2)
self.dtype = 'float64'
self.np_dtype = np.float64
self.numerical_delta = 1e-7
self.rtol = 1e-7
self.atol = 1e-7
self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype)
self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype)
if __name__ == "__main__":
unittest.main()
@@ -107,6 +107,73 @@ def _compute_numerical_hessian(func, xs, delta, np_dtype):
return hessian
def _compute_numerical_batch_jacobian(func, xs, delta, np_dtype):
no_batch_jacobian = _compute_numerical_jacobian(func, xs, delta, np_dtype)
xs = _tensors(xs, "xs")
ys = _tensors(func(*xs), "ys")
fin_size = len(xs)
fout_size = len(ys)
bs = xs[0].shape[0]
bat_jac = []
for i in range(fout_size):
batch_jac_i = []
for j in range(fin_size):
jac = no_batch_jacobian[i][j]
jac_shape = jac.shape
out_size = jac_shape[0] // bs
in_size = jac_shape[1] // bs
jac = np.reshape(jac, (bs, out_size, bs, in_size))
batch_jac_i_j = np.zeros(shape=(out_size, bs, in_size))
for p in range(out_size):
for b in range(bs):
for q in range(in_size):
batch_jac_i_j[p][b][q] = jac[b][p][b][q]
batch_jac_i_j = np.reshape(batch_jac_i_j, (out_size, -1))
batch_jac_i.append(batch_jac_i_j)
bat_jac.append(batch_jac_i)
return bat_jac
def _compute_numerical_batch_hessian(func, xs, delta, np_dtype):
xs = _tensors(xs, "xs")
batch_size = xs[0].shape[0]
fin_size = len(xs)
hessian = []
for b in range(batch_size):
x_l = []
for j in range(fin_size):
x_l.append(paddle.reshape(xs[j][b], shape=[1, -1]))
hes_b = _compute_numerical_hessian(func, x_l, delta, np_dtype)
if fin_size == 1:
hessian.append(hes_b[0][0])
else:
hessian.append(hes_b)
hessian_res = []
for index in range(fin_size):
x_reshape = paddle.reshape(xs[index], shape=[batch_size, -1])
for index_ in range(fin_size):
for i in range(x_reshape.shape[1]):
tmp = []
for j in range(batch_size):
if fin_size == 1:
tmp.extend(hessian[j][i])
else:
tmp.extend(hessian[j][i][index_][index])
hessian_res.append(tmp)
if fin_size == 1:
return hessian_res
hessian_result = []
mid = len(hessian_res) // 2
for i in range(mid):
hessian_result.append(
np.stack(
(hessian_res[i], hessian_res[mid + i]), axis=0))
return hessian_result
def _compute_numerical_vjp(func, xs, v, delta, np_dtype):
xs = _tensors(xs, "xs")
jacobian = np.array(_compute_numerical_jacobian(func, xs, delta, np_dtype))
......
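For reference, a minimal sketch of how the analytic batch_jacobian can be cross-checked against the numerical helper added above, mirroring what the unit tests do (this assumes it runs from the autograd test directory so that utils.py is importable; the shapes, delta and tolerances are illustrative only):

import numpy as np
import paddle
from utils import _compute_numerical_batch_jacobian

x = paddle.rand(shape=(4, 2), dtype='float64')
weight = paddle.rand(shape=(2, 4), dtype='float64')

def func(x):
    return paddle.matmul(x * x, weight)

# Numerical reference (finite differences), nested as [output_idx][input_idx].
numerical = _compute_numerical_batch_jacobian(func, [x], 1e-5, np.float64)

# Analytic batch Jacobian from the new API.
x.stop_gradient = False
analytic = paddle.autograd.batch_jacobian(func, x)

assert np.allclose(analytic.numpy(), numerical[0][0], rtol=1e-5, atol=1e-5)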