Unverified commit 64e7c715 authored by Tongxin Bai, committed by GitHub

Symbolic Hessian (#39221)

* [autograd] static Jacobian pass tests.

* [autograd] apply CR suggested changes.

* [autograd] more tests.

* [autograd] add CPUPlace in tests.

* [autograd] bug fixes.

* [autograd] reformatted.

* [autograd] adding Hessian, in progress.

* [autograd] Hessian passes. A double grad bug fixed.

* [autograd] fix renaming conflict in double backward pass.

* [autograd] polish tests.

* fix a bug when using brackets

* debug for ci

* [autograd] fixing Hessian test.

* polish format.
Co-authored-by: levi131 <83750468+levi131@users.noreply.github.com>
Co-authored-by: levi131 <limaolin01@baidu.com>
Parent 984b16fc
......@@ -909,12 +909,32 @@ def vhp(func, inputs, v=None, create_graph=False, allow_unused=False):
class Jacobian(object):
r"""
Object that represents the Jacobian matrix of a multi-input multi-output
function.
Computes the Jacobian matrix of function `func`, which may take one or more
tensors as input and return one or more tensors as output.
When `func` is multi-input and multi-output, i.e.,
func: Callable[[Tensor, ...], [Tensor, ...]]
`func` is treated as a vector-valued function with all its inputs flattened
into a single one-dimensional tensor, or into a two-dimensional tensor whose
first dimension is retained as the batching dimension. The same rule applies
to the function outputs.
Once the Jacobian J is constructed, there are four ways to retrieve the
partial derivatives.
- J[:], retrieving the full matrix.
- J[:, j], retrieving the partial derivatives w.r.t. the j'th input
variable.
- J[i, :], retrieving the partial derivatives w.r.t. the i'th output
variable.
- J[i, j], retrieving the partial derivatives w.r.t. the i'th output
variable and the j'th input variable.
The Jacobian values are lazily evaluated when accessed through indices.
In contrast, slicing access triggers evaluation of the full matrix if it
has not already been computed.
Examples:
.. code-block:: python
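# A minimal illustrative sketch of the four indexing modes described
# above, assuming static graph mode; the function and data below are
# hypothetical, chosen only to keep the example short.
import numpy as np
import paddle

paddle.enable_static()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data('x', shape=[2, 2], dtype='float32')
    func = lambda x: paddle.matmul(x, x)
    J = paddle.autograd.functional.Jacobian(func, x)
    nrow, ncol = J.shape()   # (4, 4): inputs and outputs are flattened
    full = J[:]              # the full Jacobian matrix
    col1 = J[:, 1]           # partial derivatives w.r.t. the 1st input element
    row0 = J[0, :]           # partial derivatives of the 0th output element
    entry = J[2, 3]          # a single partial derivative
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup)
jac = exe.run(main,
              feed={'x': np.array([[1., 2.], [3., 4.]], dtype='float32')},
              fetch_list=[full])[0]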
......@@ -984,7 +1004,10 @@ class Jacobian(object):
return x.reshape(to)
def flatten_all(self, xs):
return paddle.concat([self.flatten(x) for x in xs], axis=-1)
if isinstance(xs, (list, tuple)):
return paddle.concat([self.flatten(x) for x in xs], axis=-1)
else:
return self.flatten(xs)
def shape(self):
return (self.ydim, self.xdim)
......@@ -995,23 +1018,23 @@ class Jacobian(object):
else:
i, j = tup, None
if isinstance(i, slice):
slicing = True
else:
slicing = False
full = isinstance(i, slice)
if slicing:
if full:
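# Row slicing builds the full Jacobian once, caches it under 'full',
# and reuses the cached matrix on later accesses.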
if 'full' not in self.jacobian:
rows = [
self.flatten_all(gradients(self.y[..., i], self.xs))
for i in range(self.ydim)
]
self.jacobian['full'] = paddle.stack(rows)
return self.jacobian['full'][i]
self.jacobian['full'] = full_jacobian = paddle.stack(rows)
else:
full_jacobian = self.jacobian['full']
return full_jacobian[i] if j is None else full_jacobian[i][..., j]
assert 0 <= i < self.ydim, f"Jacobian index i={i} is not valid."
assert (j is None) or (
0 <= j < self.xdim), f"Jacobian index j={j} is not valid."
assert j is None or isinstance(j, slice) or (0 <= j < self.xdim), (
f"Jacobian index j={j} is not valid.")
if 'full' in self.jacobian:
JJ = self.jacobian['full']
else:
......@@ -1024,3 +1047,17 @@ class Jacobian(object):
return JJ[i]
else:
return JJ[i][..., j]
class Hessian(object):
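r"""
Object that represents the Hessian matrix of function `func`, built
symbolically as a Jacobian of a Jacobian: the inner Jacobian's first row,
i.e. the gradient of the first flattened output, is differentiated again
by the outer Jacobian. Indexing and slicing follow the same rules as
`Jacobian`.
"""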
def __init__(self, func, inputs, batch=False):
f_x = lambda xs: Jacobian(func, xs, batch=batch)[0]
self.symbolic = Jacobian(f_x, inputs, batch=batch)
self.xs = inputs
self.batch = batch
def __getitem__(self, tup):
return self.symbolic[tup]
def shape(self):
return self.symbolic.shape()
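Since Hessian just composes two Jacobian passes, its static-graph usage mirrors that of Jacobian. A minimal sketch, with the function and input data assumed purely for illustration (they are not part of this change):

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()
main, startup = fluid.Program(), fluid.Program()
with fluid.program_guard(main, startup):
    x = paddle.static.data('x', shape=[2, 2], dtype='float64')
    pd_f = lambda x: paddle.matmul(x, paddle.transpose(x, [1, 0]))
    H = paddle.autograd.functional.Hessian(pd_f, x)
    nrow, ncol = H.shape()
    full_hessian = H[:]
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup)
hess = exe.run(main,
               feed={'x': np.array([[1., 2.], [2., 1.]], dtype='float64')},
               fetch_list=[full_hessian])[0]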
......@@ -1132,10 +1132,10 @@ def _append_backward_ops_(block,
# So rename here before _addup_repetitive_outputs_.
if program._appending_grad_times > 1:
for op_desc in grad_op_desc:
if not _is_grad_op_(op):
for name in op_desc.input_arg_names():
if name in rename_var_map:
op_desc._rename_input(name, rename_var_map[name])
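# Do not rename inputs that are also inputs of the forward op; those must
# keep referring to the forward variables (avoids the renaming conflict in
# the double backward pass).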
forward_op_inputs = op.desc.input_arg_names()
for name in op_desc.input_arg_names():
if name in rename_var_map and name not in forward_op_inputs:
op_desc._rename_input(name, rename_var_map[name])
for name in op_desc.output_arg_names():
if "@GRAD" not in name:
continue
......
......@@ -66,107 +66,38 @@ def approx_jacobian(f, xs, dtype, eps=1e-5, batch=False):
ds = eps * np.eye(xdim, dtype=dtype)
fprimes_by_x = [(0.5 / eps) * (_f(x + d) - _f(x - d)) for d in ds]
fprimes_by_x = [(0.5 * (_f(x + d) - _f(x - d)) / eps) for d in ds]
fprimes_by_y = np.stack(fprimes_by_x, axis=-1)
return np.transpose(fprimes_by_y, [1, 0, 2]) if batch else fprimes_by_y
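For reference, a minimal NumPy sketch of the central-difference rule used above, checked against an assumed elementwise-square function (purely illustrative):

import numpy as np

def f(x):
    return x * x          # elementwise square, so the Jacobian is diag(2 * x)

x = np.array([1., 2., 3.], dtype='float64')
eps = 1e-6
ds = eps * np.eye(x.size, dtype='float64')
# column j holds the approximate partial derivatives w.r.t. x[j]
jac = np.stack([(f(x + d) - f(x - d)) / (2 * eps) for d in ds], axis=-1)
assert np.allclose(jac, np.diag(2 * x))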
class TestJacobianFloat32(unittest.TestCase):
@classmethod
def setUpClass(self):
paddle.enable_static()
if fluid.core.is_compiled_with_cuda():
self.place = fluid.CUDAPlace(0)
else:
self.place = fluid.CPUPlace()
self.np_dtype = np.float32
self.A = np.array([[1., 2.]]).astype('float32')
self.B = np.array([[1., 2.], [2., 1.]]).astype('float32')
self.C = np.array([[2., 2.], [2., 1.]]).astype('float32')
self.D = np.array(
[[[2., 2.], [2., 1.]], [[1., 2.], [2., 1.]]]).astype('float32')
self.E = np.array(
[[[3., 4.], [2., 3.]], [[2., 1.], [1., 3.]]]).astype('float32')
self.eps = 1e-4
self.rtol = 1e-2
self.atol = 1e-2
def run_test(self, pd_f, np_f, inps, dtype, batch=False):
def make_tensors(inps):
if isinstance(inps, list):
xs = [
paddle.static.data(
f'x{i}', inp.shape, dtype=inp.dtype)
for i, inp in enumerate(inps)
]
else:
xs = paddle.static.data(
name='x', shape=inps.shape, dtype=inps.dtype)
return xs
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
xs = make_tensors(inps)
JJ = paddle.autograd.functional.Jacobian(pd_f, xs, batch=batch)
nrow, ncol = JJ.shape()
full_jacobian = JJ[:]
exe = fluid.Executor(self.place)
exe.run(startup)
if isinstance(inps, list):
feeds = {f'x{i}': x for i, x in enumerate(inps)}
else:
feeds = {'x': inps}
pd_jacobians = exe.run(main, feed=feeds, fetch_list=[full_jacobian])[0]
np_jacobians = approx_jacobian(np_f, inps, dtype, self.eps, batch=batch)
self.assertTrue(
np.allclose(pd_jacobians, np_jacobians, self.rtol, self.atol))
def test_square(self):
def pd_f(x):
return paddle.multiply(x, x)
def np_f(x):
return np.multiply(x, x)
self.run_test(pd_f, np_f, self.A, np.dtype('float32'))
def test_mul(self):
def pd_f(xs):
x, y = xs
return paddle.multiply(x, y)
def np_f(xs):
x, y = xs
return np.multiply(x, y)
self.run_test(pd_f, np_f, [self.B, self.C], np.dtype('float32'))
def test_matmul(self):
def pd_f(xs):
x, y = xs
return paddle.matmul(x, y)
def np_f(xs):
x, y = xs
return np.matmul(x, y)
def make_tensors(inps):
if isinstance(inps, list):
xs = [
paddle.static.data(
f'x{i}', inp.shape, dtype=inp.dtype)
for i, inp in enumerate(inps)
]
else:
xs = paddle.static.data(name='x', shape=inps.shape, dtype=inps.dtype)
return xs
self.run_test(pd_f, np_f, [self.B, self.C], np.dtype('float32'))
def test_batch_matmul(self):
def pd_f(xs):
x, y = xs
return paddle.matmul(x, y)
all_data_shapes = {
'A': [[1., 2.]],
'B': [[1., 2.], [2., 1.]],
'C': [[2., 2.], [2., 1.]],
'D': [[[2., 2.], [2., 1.]], [[1., 2.], [2., 1.]]],
'E': [[[3., 4.], [2., 3.]], [[2., 1.], [1., 3.]]],
}
def np_f(xs):
x, y = xs
return np.matmul(x, y)
self.run_test(
pd_f, np_f, [self.D, self.E], np.dtype('float32'), batch=True)
def prepare_data(test, input_shapes, dtype):
for name, shape in input_shapes.items():
setattr(test, name, np.array(shape, dtype=dtype))
class TestJacobianFloat64(unittest.TestCase):
class TestJacobianFloat32(unittest.TestCase):
@classmethod
def setUpClass(self):
paddle.enable_static()
......@@ -174,31 +105,13 @@ class TestJacobianFloat64(unittest.TestCase):
self.place = fluid.CUDAPlace(0)
else:
self.place = fluid.CPUPlace()
self.np_dtype = np.float32
self.A = np.array([[1., 2.]]).astype('float64')
self.B = np.array([[1., 2.], [2., 1.]]).astype('float64')
self.C = np.array([[2., 2.], [2., 1.]]).astype('float64')
self.D = np.array(
[[[2., 2.], [2., 1.]], [[1., 2.], [2., 1.]]]).astype('float64')
self.E = np.array(
[[[3., 4.], [2., 3.]], [[2., 1.], [1., 3.]]]).astype('float64')
self.eps = 1e-7
self.rtol = 1e-6
self.atol = 1e-6
def run_test_by_fullmatrix(self, pd_f, np_f, inps, dtype, batch=False):
def make_tensors(inps):
if isinstance(inps, list):
xs = [
paddle.static.data(
f'x{i}', inp.shape, dtype=inp.dtype)
for i, inp in enumerate(inps)
]
else:
xs = paddle.static.data(
name='x', shape=inps.shape, dtype=inps.dtype)
return xs
self.dtype = 'float32'
prepare_data(self, all_data_shapes, self.dtype)
self.eps = 1e-4
self.rtol = 1e-2
self.atol = 1e-2
def run_test_by_fullmatrix(self, pd_f, np_f, inps, batch=False):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
......@@ -213,23 +126,12 @@ class TestJacobianFloat64(unittest.TestCase):
else:
feeds = {'x': inps}
pd_jacobians = exe.run(main, feed=feeds, fetch_list=[full_jacobian])[0]
np_jacobians = approx_jacobian(np_f, inps, dtype, self.eps, batch=batch)
np_jacobians = approx_jacobian(
np_f, inps, self.dtype, self.eps, batch=batch)
self.assertTrue(
np.allclose(pd_jacobians, np_jacobians, self.rtol, self.atol))
def run_test_by_rows(self, pd_f, np_f, inps, dtype, batch=False):
def make_tensors(inps):
if isinstance(inps, list):
xs = [
paddle.static.data(
f'x{i}', inp.shape, dtype=inp.dtype)
for i, inp in enumerate(inps)
]
else:
xs = paddle.static.data(
name='x', shape=inps.shape, dtype=inps.dtype)
return xs
def run_test_by_rows(self, pd_f, np_f, inps, batch=False):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
......@@ -244,24 +146,12 @@ class TestJacobianFloat64(unittest.TestCase):
else:
feeds = {'x': inps}
pd_jac = exe.run(main, feed=feeds, fetch_list=[rows])
np_jac = approx_jacobian(np_f, inps, dtype, self.eps, batch=batch)
np_jac = approx_jacobian(np_f, inps, self.dtype, self.eps, batch=batch)
for i in range(nrow):
self.assertTrue(
np.allclose(pd_jac[i], np_jac[i], self.rtol, self.atol))
def run_test_by_entries(self, pd_f, np_f, inps, dtype, batch=False):
def make_tensors(inps):
if isinstance(inps, list):
xs = [
paddle.static.data(
f'x{i}', inp.shape, dtype=inp.dtype)
for i, inp in enumerate(inps)
]
else:
xs = paddle.static.data(
name='x', shape=inps.shape, dtype=inps.dtype)
return xs
def run_test_by_entries(self, pd_f, np_f, inps, batch=False):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
......@@ -276,7 +166,7 @@ class TestJacobianFloat64(unittest.TestCase):
else:
feeds = {'x': inps}
pd_entries = exe.run(main, feed=feeds, fetch_list=[entries])
np_jac = approx_jacobian(np_f, inps, dtype, self.eps, batch=batch)
np_jac = approx_jacobian(np_f, inps, self.dtype, self.eps, batch=batch)
np_entries = [
np_jac[i, ..., j] for i in range(nrow) for j in range(ncol)
]
......@@ -291,9 +181,9 @@ class TestJacobianFloat64(unittest.TestCase):
def np_f(x):
return np.multiply(x, x)
self.run_test_by_fullmatrix(pd_f, np_f, self.A, np.dtype('float64'))
self.run_test_by_rows(pd_f, np_f, self.A, np.dtype('float64'))
self.run_test_by_entries(pd_f, np_f, self.A, np.dtype('float64'))
self.run_test_by_fullmatrix(pd_f, np_f, self.A)
self.run_test_by_rows(pd_f, np_f, self.A)
self.run_test_by_entries(pd_f, np_f, self.A)
def test_mul(self):
def pd_f(xs):
......@@ -304,11 +194,12 @@ class TestJacobianFloat64(unittest.TestCase):
x, y = xs
return np.multiply(x, y)
self.run_test_by_fullmatrix(pd_f, np_f, [self.B, self.C],
np.dtype('float64'))
self.run_test_by_rows(pd_f, np_f, [self.B, self.C], np.dtype('float64'))
self.run_test_by_entries(pd_f, np_f, [self.B, self.C],
np.dtype('float64'))
self.run_test_by_fullmatrix(pd_f, np_f, [self.B, self.C])
self.run_test_by_rows(pd_f, np_f, [self.B, self.C])
self.run_test_by_entries(pd_f, np_f, [self.B, self.C])
def test_matmul(self):
def pd_f(xs):
......@@ -319,11 +210,9 @@ class TestJacobianFloat64(unittest.TestCase):
x, y = xs
return np.matmul(x, y)
self.run_test_by_fullmatrix(pd_f, np_f, [self.B, self.C],
np.dtype('float64'))
self.run_test_by_rows(pd_f, np_f, [self.B, self.C], np.dtype('float64'))
self.run_test_by_entries(pd_f, np_f, [self.B, self.C],
np.dtype('float64'))
self.run_test_by_fullmatrix(pd_f, np_f, [self.B, self.C])
self.run_test_by_rows(pd_f, np_f, [self.B, self.C])
self.run_test_by_entries(pd_f, np_f, [self.B, self.C])
def test_batch_matmul(self):
def pd_f(xs):
......@@ -334,12 +223,85 @@ class TestJacobianFloat64(unittest.TestCase):
x, y = xs
return np.matmul(x, y)
self.run_test_by_fullmatrix(
pd_f, np_f, [self.D, self.E], np.dtype('float64'), batch=True)
self.run_test_by_rows(
pd_f, np_f, [self.D, self.E], np.dtype('float64'), batch=True)
self.run_test_by_entries(
pd_f, np_f, [self.D, self.E], np.dtype('float64'), batch=True)
self.run_test_by_fullmatrix(pd_f, np_f, [self.D, self.E], batch=True)
self.run_test_by_rows(pd_f, np_f, [self.D, self.E], batch=True)
self.run_test_by_entries(pd_f, np_f, [self.D, self.E], batch=True)
class TestJacobianFloat64(TestJacobianFloat32):
@classmethod
def setUpClass(self):
paddle.enable_static()
if fluid.core.is_compiled_with_cuda():
self.place = fluid.CUDAPlace(0)
else:
self.place = fluid.CPUPlace()
self.dtype = 'float64'
prepare_data(self, all_data_shapes, self.dtype)
self.eps = 1e-7
self.rtol = 1e-6
self.atol = 1e-6
class TestHessianFloat64(unittest.TestCase):
@classmethod
def setUpClass(self):
paddle.enable_static()
if fluid.core.is_compiled_with_cuda():
self.place = fluid.CUDAPlace(0)
else:
self.place = fluid.CPUPlace()
self.dtype = 'float64'
prepare_data(self, all_data_shapes, self.dtype)
self.eps = 1e-7
self.rtol = 1e-6
self.atol = 1e-6
def run_test_by_fullmatrix(self, pd_f, inps, np_hess, batch=False):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
xs = make_tensors(inps)
HH = paddle.autograd.functional.Hessian(pd_f, xs, batch=batch)
nrow, ncol = HH.shape()
full_hessian = HH[:]
exe = fluid.Executor(self.place)
exe.run(startup)
if isinstance(inps, list):
feeds = {f'x{i}': x for i, x in enumerate(inps)}
else:
feeds = {'x': inps}
pd_hess = exe.run(main, feed=feeds, fetch_list=[full_hessian])[0]
self.assertTrue(np.allclose(pd_hess, np_hess, self.rtol, self.atol))
def test_square(self):
def pd_f(x):
"""Input is a square matrix."""
return paddle.matmul(x, x.T)
def np_hess(x):
dim = x.shape[0]
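# The symbolic Hessian differentiates the first flattened output,
# y[0, 0] = sum_k x[0, k] ** 2, so the only nonzero second derivatives
# are d2y / dx[0, j] dx[0, k] = 2 * delta_jk, i.e. a 2 * I upper-left block.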
f_xx_upperleft = 2 * np.eye(dim, dtype=self.dtype)
f_xx = np.zeros([dim * dim, dim * dim], dtype=self.dtype)
f_xx[:dim, :dim] = f_xx_upperleft
return f_xx
self.run_test_by_fullmatrix(pd_f, self.B, np_hess(self.B))
def test_batch_square(self):
def pd_f(x):
"""Input is a square matrix."""
return paddle.matmul(x, paddle.transpose(x, [0, 2, 1]))
def np_hess(x):
bat, dim, _ = x.shape
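# Same structure per batch element: y[b, 0, 0] = sum_k x[b, 0, k] ** 2.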
f_xx_upperleft = 2 * np.eye(dim, dtype=self.dtype)
f_xx = np.zeros([bat, dim * dim, dim * dim], dtype=self.dtype)
f_xx[..., :dim, :dim] = f_xx_upperleft
return f_xx
self.run_test_by_fullmatrix(
pd_f, self.E, np_hess(self.E), batch=True)
if __name__ == "__main__":
......