Unverified commit 64e7c715, authored by Tongxin Bai, committed by GitHub

Symbolic Hessian (#39221)

* [autograd] static Jacobian pass tests.

* [autograd] apply CR suggested changes.

* [autograd] more tests.

* [autograd] add CPUPlace in tests.

* [autograd] bug fixes.

* [autograd] reformatted.

* [autograd] adding Hessian, in progress.

* [autograd] Hessian passes. A double grad bug fixed.

* [autograd] fix renaming conflict in double backward pass.

* [autograd] polish tests.

* fix a bug when using brackets

* debug for ci

* [autograd] fixing Hessian test.

* polish format.
Co-authored-by: levi131 <83750468+levi131@users.noreply.github.com>
Co-authored-by: levi131 <limaolin01@baidu.com>
Parent 984b16fc
@@ -909,12 +909,32 @@ def vhp(func, inputs, v=None, create_graph=False, allow_unused=False):
 class Jacobian(object):
     r"""
-    Object that represents the Jacobian matrix of a muli-input multi-output
-    function.
-
-    The Jacobian values are lazily evaluated if accessed through indices.
-    In contrast, slicing access would trigger evaluating the full matrix
-    if it's not already computed.
+    Computes the Jacobian matrix of function `func`, which may take as input
+    single or multiple tensor typed arguments and output a single tensor or
+    multiple tensors.
+
+    In case `func` is multi-input and multi-output, i.e.,
+
+    func: Callable[[Tensor, ...], [Tensor, ...]]
+
+    `func` is treated as a vector valued function with all its inputs flattened
+    into a single one dimensional tensor, or a two dimensional tensor with the
+    first dimension retained as the batching dimension. The same rule applies to
+    the function outputs.
+
+    Once the Jacobian J is constructed, there are four ways to retrieve the
+    partial derivatives.
+
+    - J[:], retrieving the full matrix.
+
+    - J[:, j], retrieving the partial derivatives w.r.t. the j'th input
+      variable.
+
+    - J[i, :], retrieving the partial derivatives w.r.t. the i'th output
+      variable.
+
+    - J[i, j], retrieving the partial derivatives w.r.t. the i'th output
+      variable and the j'th input variable.

     Examples:
         .. code-block:: python
@@ -984,7 +1004,10 @@ class Jacobian(object):
         return x.reshape(to)

     def flatten_all(self, xs):
-        return paddle.concat([self.flatten(x) for x in xs], axis=-1)
+        if isinstance(xs, (list, tuple)):
+            return paddle.concat([self.flatten(x) for x in xs], axis=-1)
+        else:
+            return self.flatten(xs)

     def shape(self):
         return (self.ydim, self.xdim)
@@ -995,23 +1018,23 @@ class Jacobian(object):
         else:
             i, j = tup, None

-        if isinstance(i, slice):
-            slicing = True
-        else:
-            slicing = False
+        full = isinstance(i, slice)

-        if slicing:
+        if full:
             if 'full' not in self.jacobian:
                 rows = [
                     self.flatten_all(gradients(self.y[..., i], self.xs))
                     for i in range(self.ydim)
                 ]
-                self.jacobian['full'] = paddle.stack(rows)
-            return self.jacobian['full'][i]
+                self.jacobian['full'] = full_jacobian = paddle.stack(rows)
+            else:
+                full_jacobian = self.jacobian['full']
+            return full_jacobian[i] if j is None else full_jacobian[i][..., j]

         assert 0 <= i < self.ydim, f"Jacobian index i={i} is not valid."
-        assert (j is None) or (
-            0 <= j < self.xdim), f"Jacobian index j={j} is not valid."
+        assert j is None or isinstance(j, slice) or (0 <= j < self.xdim), (
+            f"Jacobian index j={j} is not valid.")
         if 'full' in self.jacobian:
             JJ = self.jacobian['full']
         else:
@@ -1024,3 +1047,17 @@ class Jacobian(object):
             return JJ[i]
         else:
             return JJ[i][..., j]
+
+
+class Hessian(object):
+    def __init__(self, func, inputs, batch=False):
+        f_x = lambda xs: Jacobian(func, xs, batch=batch)[0]
+        self.symbolic = Jacobian(f_x, inputs, batch=batch)
+        self.xs = inputs
+        self.batch = batch
+
+    def __getitem__(self, tup):
+        return self.symbolic[tup]
+
+    def shape(self):
+        return self.symbolic.shape()
@@ -1132,9 +1132,9 @@ def _append_backward_ops_(block,
         # So rename here before _addup_repetitive_outputs_.
         if program._appending_grad_times > 1:
             for op_desc in grad_op_desc:
-                if not _is_grad_op_(op):
-                    for name in op_desc.input_arg_names():
-                        if name in rename_var_map:
-                            op_desc._rename_input(name, rename_var_map[name])
+                forward_op_inputs = op.desc.input_arg_names()
+                for name in op_desc.input_arg_names():
+                    if name in rename_var_map and name not in forward_op_inputs:
+                        op_desc._rename_input(name, rename_var_map[name])
                 for name in op_desc.output_arg_names():
                     if "@GRAD" not in name:
......
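This hunk appears to be the "fix renaming conflict in double backward pass" mentioned in the commit message: instead of guarding the rename on whether the forward op is itself a grad op, the new code renames a grad-op input only when it is not also an input of the corresponding forward op, so appending the backward pass a second time no longer renames variables that must keep their forward names.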
@@ -66,33 +66,12 @@ def approx_jacobian(f, xs, dtype, eps=1e-5, batch=False):
     ds = eps * np.eye(xdim, dtype=dtype)
-    fprimes_by_x = [(0.5 / eps) * (_f(x + d) - _f(x - d)) for d in ds]
+    fprimes_by_x = [(0.5 * (_f(x + d) - _f(x - d)) / eps) for d in ds]
     fprimes_by_y = np.stack(fprimes_by_x, axis=-1)
     return np.transpose(fprimes_by_y, [1, 0, 2]) if batch else fprimes_by_y


-class TestJacobianFloat32(unittest.TestCase):
-    @classmethod
-    def setUpClass(self):
-        paddle.enable_static()
-        if fluid.core.is_compiled_with_cuda():
-            self.place = fluid.CUDAPlace(0)
-        else:
-            self.place = fluid.CPUPlace()
-        self.np_dtype = np.float32
-        self.A = np.array([[1., 2.]]).astype('float32')
-        self.B = np.array([[1., 2.], [2., 1.]]).astype('float32')
-        self.C = np.array([[2., 2.], [2., 1.]]).astype('float32')
-        self.D = np.array(
-            [[[2., 2.], [2., 1.]], [[1., 2.], [2., 1.]]]).astype('float32')
-        self.E = np.array(
-            [[[3., 4.], [2., 3.]], [[2., 1.], [1., 3.]]]).astype('float32')
-        self.eps = 1e-4
-        self.rtol = 1e-2
-        self.atol = 1e-2
-
-    def run_test(self, pd_f, np_f, inps, dtype, batch=False):
-        def make_tensors(inps):
-            if isinstance(inps, list):
-                xs = [
-                    paddle.static.data(
+def make_tensors(inps):
+    if isinstance(inps, list):
+        xs = [
+            paddle.static.data(
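As an aside, `approx_jacobian` (top of the hunk above) is the finite-difference oracle the tests compare against; the rewritten `fprimes_by_x` line is still the same central-difference quotient (f(x+d) - f(x-d)) / (2*eps), only grouped differently. A self-contained numpy sketch of the idea follows; the helper name `central_diff_jacobian` is hypothetical, not from the repository.

```python
import numpy as np

def central_diff_jacobian(g, x, eps=1e-5):
    """Approximate the Jacobian of g at a 1-D point x by central differences."""
    x = np.asarray(x, dtype='float64')
    cols = []
    for d in eps * np.eye(x.size):
        cols.append((g(x + d) - g(x - d)) / (2 * eps))  # one column per input
    return np.stack(cols, axis=-1)

# Example: g(x) = x * x has Jacobian diag(2x), so this prints ~[[2, 0], [0, 4]].
print(central_diff_jacobian(lambda x: x * x, np.array([1., 2.])))
```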
@@ -100,73 +79,25 @@ class TestJacobianFloat32(unittest.TestCase):
-                    for i, inp in enumerate(inps)
-                ]
-            else:
-                xs = paddle.static.data(
-                    name='x', shape=inps.shape, dtype=inps.dtype)
-            return xs
-
-        main = fluid.Program()
-        startup = fluid.Program()
-        with fluid.program_guard(main, startup):
-            xs = make_tensors(inps)
-            JJ = paddle.autograd.functional.Jacobian(pd_f, xs, batch=batch)
-            nrow, ncol = JJ.shape()
-            full_jacobian = JJ[:]
-        exe = fluid.Executor(self.place)
-        exe.run(startup)
-        if isinstance(inps, list):
-            feeds = {f'x{i}': x for i, x in enumerate(inps)}
-        else:
-            feeds = {'x': inps}
-        pd_jacobians = exe.run(main, feed=feeds, fetch_list=[full_jacobian])[0]
-        np_jacobians = approx_jacobian(np_f, inps, dtype, self.eps, batch=batch)
-        self.assertTrue(
-            np.allclose(pd_jacobians, np_jacobians, self.rtol, self.atol))
-
-    def test_square(self):
-        def pd_f(x):
-            return paddle.multiply(x, x)
-
-        def np_f(x):
-            return np.multiply(x, x)
-
-        self.run_test(pd_f, np_f, self.A, np.dtype('float32'))
-
-    def test_mul(self):
-        def pd_f(xs):
-            x, y = xs
-            return paddle.multiply(x, y)
-
-        def np_f(xs):
-            x, y = xs
-            return np.multiply(x, y)
-
-        self.run_test(pd_f, np_f, [self.B, self.C], np.dtype('float32'))
-
-    def test_matmul(self):
-        def pd_f(xs):
-            x, y = xs
-            return paddle.matmul(x, y)
-
-        def np_f(xs):
-            x, y = xs
-            return np.matmul(x, y)
-
-        self.run_test(pd_f, np_f, [self.B, self.C], np.dtype('float32'))
-
-    def test_batch_matmul(self):
-        def pd_f(xs):
-            x, y = xs
-            return paddle.matmul(x, y)
-
-        def np_f(xs):
-            x, y = xs
-            return np.matmul(x, y)
-
-        self.run_test(
-            pd_f, np_f, [self.D, self.E], np.dtype('float32'), batch=True)
-
-
-class TestJacobianFloat64(unittest.TestCase):
+            for i, inp in enumerate(inps)
+        ]
+    else:
+        xs = paddle.static.data(name='x', shape=inps.shape, dtype=inps.dtype)
+    return xs
+
+
+all_data_shapes = {
+    'A': [[1., 2.]],
+    'B': [[1., 2.], [2., 1.]],
+    'C': [[2., 2.], [2., 1.]],
+    'D': [[[2., 2.], [2., 1.]], [[1., 2.], [2., 1.]]],
+    'E': [[[3., 4.], [2., 3.]], [[2., 1.], [1., 3.]]],
+}
+
+
+def prepare_data(test, input_shapes, dtype):
+    for name, shape in input_shapes.items():
+        setattr(test, name, np.array(shape, dtype=dtype))
+
+
+class TestJacobianFloat32(unittest.TestCase):
     @classmethod
     def setUpClass(self):
         paddle.enable_static()
@@ -174,31 +105,13 @@ class TestJacobianFloat64(unittest.TestCase):
             self.place = fluid.CUDAPlace(0)
         else:
             self.place = fluid.CPUPlace()
-        self.np_dtype = np.float32
-        self.A = np.array([[1., 2.]]).astype('float64')
-        self.B = np.array([[1., 2.], [2., 1.]]).astype('float64')
-        self.C = np.array([[2., 2.], [2., 1.]]).astype('float64')
-        self.D = np.array(
-            [[[2., 2.], [2., 1.]], [[1., 2.], [2., 1.]]]).astype('float64')
-        self.E = np.array(
-            [[[3., 4.], [2., 3.]], [[2., 1.], [1., 3.]]]).astype('float64')
-        self.eps = 1e-7
-        self.rtol = 1e-6
-        self.atol = 1e-6
-
-    def run_test_by_fullmatrix(self, pd_f, np_f, inps, dtype, batch=False):
-        def make_tensors(inps):
-            if isinstance(inps, list):
-                xs = [
-                    paddle.static.data(
-                        f'x{i}', inp.shape, dtype=inp.dtype)
-                    for i, inp in enumerate(inps)
-                ]
-            else:
-                xs = paddle.static.data(
-                    name='x', shape=inps.shape, dtype=inps.dtype)
-            return xs
-
+        self.dtype = 'float32'
+        prepare_data(self, all_data_shapes, self.dtype)
+        self.eps = 1e-4
+        self.rtol = 1e-2
+        self.atol = 1e-2
+
+    def run_test_by_fullmatrix(self, pd_f, np_f, inps, batch=False):
         main = fluid.Program()
         startup = fluid.Program()
         with fluid.program_guard(main, startup):
@@ -213,23 +126,12 @@ class TestJacobianFloat64(unittest.TestCase):
         else:
             feeds = {'x': inps}
         pd_jacobians = exe.run(main, feed=feeds, fetch_list=[full_jacobian])[0]
-        np_jacobians = approx_jacobian(np_f, inps, dtype, self.eps, batch=batch)
+        np_jacobians = approx_jacobian(
+            np_f, inps, self.dtype, self.eps, batch=batch)
         self.assertTrue(
             np.allclose(pd_jacobians, np_jacobians, self.rtol, self.atol))

-    def run_test_by_rows(self, pd_f, np_f, inps, dtype, batch=False):
-        def make_tensors(inps):
-            if isinstance(inps, list):
-                xs = [
-                    paddle.static.data(
-                        f'x{i}', inp.shape, dtype=inp.dtype)
-                    for i, inp in enumerate(inps)
-                ]
-            else:
-                xs = paddle.static.data(
-                    name='x', shape=inps.shape, dtype=inps.dtype)
-            return xs
-
+    def run_test_by_rows(self, pd_f, np_f, inps, batch=False):
         main = fluid.Program()
         startup = fluid.Program()
         with fluid.program_guard(main, startup):
@@ -244,24 +146,12 @@ class TestJacobianFloat64(unittest.TestCase):
         else:
             feeds = {'x': inps}
         pd_jac = exe.run(main, feed=feeds, fetch_list=[rows])
-        np_jac = approx_jacobian(np_f, inps, dtype, self.eps, batch=batch)
+        np_jac = approx_jacobian(np_f, inps, self.dtype, self.eps, batch=batch)
         for i in range(nrow):
             self.assertTrue(
                 np.allclose(pd_jac[i], np_jac[i], self.rtol, self.atol))

-    def run_test_by_entries(self, pd_f, np_f, inps, dtype, batch=False):
-        def make_tensors(inps):
-            if isinstance(inps, list):
-                xs = [
-                    paddle.static.data(
-                        f'x{i}', inp.shape, dtype=inp.dtype)
-                    for i, inp in enumerate(inps)
-                ]
-            else:
-                xs = paddle.static.data(
-                    name='x', shape=inps.shape, dtype=inps.dtype)
-            return xs
-
+    def run_test_by_entries(self, pd_f, np_f, inps, batch=False):
         main = fluid.Program()
         startup = fluid.Program()
         with fluid.program_guard(main, startup):
@@ -276,7 +166,7 @@ class TestJacobianFloat64(unittest.TestCase):
         else:
             feeds = {'x': inps}
         pd_entries = exe.run(main, feed=feeds, fetch_list=[entries])
-        np_jac = approx_jacobian(np_f, inps, dtype, self.eps, batch=batch)
+        np_jac = approx_jacobian(np_f, inps, self.dtype, self.eps, batch=batch)
         np_entries = [
             np_jac[i, ..., j] for i in range(nrow) for j in range(ncol)
         ]
@@ -291,9 +181,9 @@ class TestJacobianFloat64(unittest.TestCase):
         def np_f(x):
             return np.multiply(x, x)

-        self.run_test_by_fullmatrix(pd_f, np_f, self.A, np.dtype('float64'))
-        self.run_test_by_rows(pd_f, np_f, self.A, np.dtype('float64'))
-        self.run_test_by_entries(pd_f, np_f, self.A, np.dtype('float64'))
+        self.run_test_by_fullmatrix(pd_f, np_f, self.A)
+        self.run_test_by_rows(pd_f, np_f, self.A)
+        self.run_test_by_entries(pd_f, np_f, self.A)

     def test_mul(self):
         def pd_f(xs):
@@ -304,11 +194,12 @@ class TestJacobianFloat64(unittest.TestCase):
             x, y = xs
             return np.multiply(x, y)

-        self.run_test_by_fullmatrix(pd_f, np_f, [self.B, self.C],
-                                    np.dtype('float64'))
-        self.run_test_by_rows(pd_f, np_f, [self.B, self.C], np.dtype('float64'))
-        self.run_test_by_entries(pd_f, np_f, [self.B, self.C],
-                                 np.dtype('float64'))
+        self.run_test_by_fullmatrix(
+            pd_f,
+            np_f,
+            [self.B, self.C], )
+        self.run_test_by_rows(pd_f, np_f, [self.B, self.C])
+        self.run_test_by_entries(pd_f, np_f, [self.B, self.C])

     def test_matmul(self):
         def pd_f(xs):
@@ -319,11 +210,9 @@ class TestJacobianFloat64(unittest.TestCase):
             x, y = xs
             return np.matmul(x, y)

-        self.run_test_by_fullmatrix(pd_f, np_f, [self.B, self.C],
-                                    np.dtype('float64'))
-        self.run_test_by_rows(pd_f, np_f, [self.B, self.C], np.dtype('float64'))
-        self.run_test_by_entries(pd_f, np_f, [self.B, self.C],
-                                 np.dtype('float64'))
+        self.run_test_by_fullmatrix(pd_f, np_f, [self.B, self.C])
+        self.run_test_by_rows(pd_f, np_f, [self.B, self.C])
+        self.run_test_by_entries(pd_f, np_f, [self.B, self.C])

     def test_batch_matmul(self):
         def pd_f(xs):
@@ -334,12 +223,85 @@ class TestJacobianFloat64(unittest.TestCase):
             x, y = xs
             return np.matmul(x, y)

-        self.run_test_by_fullmatrix(
-            pd_f, np_f, [self.D, self.E], np.dtype('float64'), batch=True)
-        self.run_test_by_rows(
-            pd_f, np_f, [self.D, self.E], np.dtype('float64'), batch=True)
-        self.run_test_by_entries(
-            pd_f, np_f, [self.D, self.E], np.dtype('float64'), batch=True)
+        self.run_test_by_fullmatrix(pd_f, np_f, [self.D, self.E], batch=True)
+        self.run_test_by_rows(pd_f, np_f, [self.D, self.E], batch=True)
+        self.run_test_by_entries(pd_f, np_f, [self.D, self.E], batch=True)
+
+
+class TestJacobianFloat64(TestJacobianFloat32):
+    @classmethod
+    def setUpClass(self):
+        paddle.enable_static()
+        if fluid.core.is_compiled_with_cuda():
+            self.place = fluid.CUDAPlace(0)
+        else:
+            self.place = fluid.CPUPlace()
+        self.dtype = 'float64'
+        prepare_data(self, all_data_shapes, self.dtype)
+        self.eps = 1e-7
+        self.rtol = 1e-6
+        self.atol = 1e-6
+
+
+class TestHessianFloat64(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        paddle.enable_static()
+        if fluid.core.is_compiled_with_cuda():
+            self.place = fluid.CUDAPlace(0)
+        else:
+            self.place = fluid.CPUPlace()
+        self.dtype = 'float64'
+        prepare_data(self, all_data_shapes, self.dtype)
+        self.eps = 1e-7
+        self.rtol = 1e-6
+        self.atol = 1e-6
+
+    def run_test_by_fullmatrix(self, pd_f, inps, np_hess, batch=False):
+        main = fluid.Program()
+        startup = fluid.Program()
+        with fluid.program_guard(main, startup):
+            xs = make_tensors(inps)
+            HH = paddle.autograd.functional.Hessian(pd_f, xs, batch=batch)
+            nrow, ncol = HH.shape()
+            full_hessian = HH[:]
+        exe = fluid.Executor(self.place)
+        exe.run(startup)
+        if isinstance(inps, list):
+            feeds = {f'x{i}': x for i, x in enumerate(inps)}
+        else:
+            feeds = {'x': inps}
+        pd_hess = exe.run(main, feed=feeds, fetch_list=[full_hessian])[0]
+        self.assertTrue(np.allclose(pd_hess, np_hess, self.rtol, self.atol))
+
+    def test_square(self):
+        def pd_f(x):
+            """Input is a square matrix."""
+            return paddle.matmul(x, x.T)
+
+        def np_hess(x):
+            dim = x.shape[0]
+            f_xx_upperleft = 2 * np.eye(dim, dtype=self.dtype)
+            f_xx = np.zeros([dim * dim, dim * dim], dtype=self.dtype)
+            f_xx[:dim, :dim] = f_xx_upperleft
+            return f_xx
+
+        self.run_test_by_fullmatrix(pd_f, self.B, np_hess(self.B))
+
+    def test_batch_square(self):
+        def pd_f(x):
+            """Input is a square matrix."""
+            return paddle.matmul(x, paddle.transpose(x, [0, 2, 1]))
+
+        def np_hess(x):
+            bat, dim, _ = x.shape
+            f_xx_upperleft = 2 * np.eye(dim, dtype=self.dtype)
+            f_xx = np.zeros([bat, dim * dim, dim * dim], dtype=self.dtype)
+            f_xx[..., :dim, :dim] = f_xx_upperleft
+            return f_xx
+
+        self.run_test_by_fullmatrix(
+            pd_f, self.E, np_hess(self.E), batch=True)


 if __name__ == "__main__":
......
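A note on the analytic oracle in test_square: the Hessian class differentiates only the first output component, since its constructor takes row 0 of the Jacobian (`Jacobian(func, xs)[0]`). For f(x) = x @ x.T that component is f_00 = sum_k x[0, k]**2; its gradient with respect to the row-major flattened input is 2 * x[0, k] on the first `dim` entries and zero elsewhere, so differentiating once more gives 2 * I in the upper-left dim x dim block and zeros everywhere else, which is exactly the matrix np_hess builds. test_batch_square repeats the same block once per batch element.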