From 64e7c7158f59cdf79673b8fd23488c8088c7ef36 Mon Sep 17 00:00:00 2001 From: Tongxin Bai Date: Sat, 29 Jan 2022 14:34:16 +0800 Subject: [PATCH] Symbolic Hessian (#39221) * [autograd] static Jacobian pass tests. * [autograd] apply CR suggested changes. * [autograd] more tests. * [autograd] add CPUPlace in tests. * [autograd] bug fixes. * [autograd] reformatted. * [autograd] adding Hessian, in progress. * [autograd] Hessian passes. A double grad bug fixed. * [autograd] fix renaming conflict in double backward pass. * [autograd] polish test.s * fix a bug when using brackets * debug for ci * [autograd] fixing Hessian test. * polish format. Co-authored-by: levi131 <83750468+levi131@users.noreply.github.com> Co-authored-by: levi131 --- python/paddle/autograd/functional.py | 67 +++- python/paddle/fluid/backward.py | 8 +- ...bian_static.py => test_autograd_static.py} | 288 ++++++++---------- 3 files changed, 181 insertions(+), 182 deletions(-) rename python/paddle/fluid/tests/unittests/autograd/{test_jacobian_static.py => test_autograd_static.py} (53%) diff --git a/python/paddle/autograd/functional.py b/python/paddle/autograd/functional.py index b9cceafebaa..c663d37e7f2 100644 --- a/python/paddle/autograd/functional.py +++ b/python/paddle/autograd/functional.py @@ -909,12 +909,32 @@ def vhp(func, inputs, v=None, create_graph=False, allow_unused=False): class Jacobian(object): r""" - Object that represents the Jacobian matrix of a muli-input multi-output - function. + Computes the Jacobian matrix of function `func`, which may take as input + single or multiple tensor typed arguments and output a single tensor or + multiple tensors. + + In case `func` is multi-input and multi-output, i.e., + + func: Callable[[Tensor, ...], [Tensor, ...]] + + `func` is treated as a vector valued function with all its inputs flattened + into a single one dimensional tensor, or a two dimensional tensor with the + first dimension retained as the batching dimension. The same rule applies to + the function outputs. + + Once the Jacobian J is constructed, there are four ways to retrieve the + partial derivatives. + + - J[:], retrieving the full matrix. + + - J[:, j], retrieving the partial derivatives w.r.t. the j'th input + variable. + + - J[i, :], retrieving the partial derivatives w.r.t. the i'th output + variable. - The Jacobian values are lazily evaluated if accessed through indices. - In contrast, slicing access would trigger evaluating the full matrix - if it's not already computed. + - J[i, j], retrieving the partial derivatives w.r.t. the i'th output + variable and the j'th input variable. Examples: .. 
code-block:: python @@ -984,7 +1004,10 @@ class Jacobian(object): return x.reshape(to) def flatten_all(self, xs): - return paddle.concat([self.flatten(x) for x in xs], axis=-1) + if isinstance(xs, (list, tuple)): + return paddle.concat([self.flatten(x) for x in xs], axis=-1) + else: + return self.flatten(xs) def shape(self): return (self.ydim, self.xdim) @@ -995,23 +1018,23 @@ class Jacobian(object): else: i, j = tup, None - if isinstance(i, slice): - slicing = True - else: - slicing = False + full = isinstance(i, slice) - if slicing: + if full: if 'full' not in self.jacobian: rows = [ self.flatten_all(gradients(self.y[..., i], self.xs)) for i in range(self.ydim) ] - self.jacobian['full'] = paddle.stack(rows) - return self.jacobian['full'][i] + self.jacobian['full'] = full_jacobian = paddle.stack(rows) + else: + full_jacobian = self.jacobian['full'] + + return full_jacobian[i] if j is None else full_jacobian[i][..., j] assert 0 <= i < self.ydim, f"Jacobian index i={i} is not valid." - assert (j is None) or ( - 0 <= j < self.xdim), f"Jacobian index j={j} is not valid." + assert j is None or isinstance(j, slice) or (0 <= j < self.xdim), ( + f"Jacobian index j={j} is not valid.") if 'full' in self.jacobian: JJ = self.jacobian['full'] else: @@ -1024,3 +1047,17 @@ class Jacobian(object): return JJ[i] else: return JJ[i][..., j] + + +class Hessian(object): + def __init__(self, func, inputs, batch=False): + f_x = lambda xs: Jacobian(func, xs, batch=batch)[0] + self.symbolic = Jacobian(f_x, inputs, batch=batch) + self.xs = inputs + self.batch = batch + + def __getitem__(self, tup): + return self.symbolic[tup] + + def shape(self): + return self.symbolic.shape() diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 4805994b7aa..1637b33723b 100755 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -1132,10 +1132,10 @@ def _append_backward_ops_(block, # So rename here before _addup_repetitive_outputs_. 
if program._appending_grad_times > 1: for op_desc in grad_op_desc: - if not _is_grad_op_(op): - for name in op_desc.input_arg_names(): - if name in rename_var_map: - op_desc._rename_input(name, rename_var_map[name]) + forward_op_inputs = op.desc.input_arg_names() + for name in op_desc.input_arg_names(): + if name in rename_var_map and name not in forward_op_inputs: + op_desc._rename_input(name, rename_var_map[name]) for name in op_desc.output_arg_names(): if "@GRAD" not in name: continue diff --git a/python/paddle/fluid/tests/unittests/autograd/test_jacobian_static.py b/python/paddle/fluid/tests/unittests/autograd/test_autograd_static.py similarity index 53% rename from python/paddle/fluid/tests/unittests/autograd/test_jacobian_static.py rename to python/paddle/fluid/tests/unittests/autograd/test_autograd_static.py index 28fc6932b07..60dc9d06b8a 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_jacobian_static.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_autograd_static.py @@ -66,107 +66,38 @@ def approx_jacobian(f, xs, dtype, eps=1e-5, batch=False): ds = eps * np.eye(xdim, dtype=dtype) - fprimes_by_x = [(0.5 / eps) * (_f(x + d) - _f(x - d)) for d in ds] + fprimes_by_x = [(0.5 * (_f(x + d) - _f(x - d)) / eps) for d in ds] fprimes_by_y = np.stack(fprimes_by_x, axis=-1) return np.transpose(fprimes_by_y, [1, 0, 2]) if batch else fprimes_by_y -class TestJacobianFloat32(unittest.TestCase): - @classmethod - def setUpClass(self): - paddle.enable_static() - if fluid.core.is_compiled_with_cuda(): - self.place = fluid.CUDAPlace(0) - else: - self.place = fluid.CPUPlace() - self.np_dtype = np.float32 - self.A = np.array([[1., 2.]]).astype('float32') - self.B = np.array([[1., 2.], [2., 1.]]).astype('float32') - self.C = np.array([[2., 2.], [2., 1.]]).astype('float32') - self.D = np.array( - [[[2., 2.], [2., 1.]], [[1., 2.], [2., 1.]]]).astype('float32') - self.E = np.array( - [[[3., 4.], [2., 3.]], [[2., 1.], [1., 3.]]]).astype('float32') - self.eps = 1e-4 - self.rtol = 1e-2 - self.atol = 1e-2 - - def run_test(self, pd_f, np_f, inps, dtype, batch=False): - def make_tensors(inps): - if isinstance(inps, list): - xs = [ - paddle.static.data( - f'x{i}', inp.shape, dtype=inp.dtype) - for i, inp in enumerate(inps) - ] - else: - xs = paddle.static.data( - name='x', shape=inps.shape, dtype=inps.dtype) - return xs - - main = fluid.Program() - startup = fluid.Program() - with fluid.program_guard(main, startup): - xs = make_tensors(inps) - JJ = paddle.autograd.functional.Jacobian(pd_f, xs, batch=batch) - nrow, ncol = JJ.shape() - full_jacobian = JJ[:] - exe = fluid.Executor(self.place) - exe.run(startup) - if isinstance(inps, list): - feeds = {f'x{i}': x for i, x in enumerate(inps)} - else: - feeds = {'x': inps} - pd_jacobians = exe.run(main, feed=feeds, fetch_list=[full_jacobian])[0] - np_jacobians = approx_jacobian(np_f, inps, dtype, self.eps, batch=batch) - self.assertTrue( - np.allclose(pd_jacobians, np_jacobians, self.rtol, self.atol)) - - def test_square(self): - def pd_f(x): - return paddle.multiply(x, x) - - def np_f(x): - return np.multiply(x, x) - - self.run_test(pd_f, np_f, self.A, np.dtype('float32')) - - def test_mul(self): - def pd_f(xs): - x, y = xs - return paddle.multiply(x, y) - - def np_f(xs): - x, y = xs - return np.multiply(x, y) - - self.run_test(pd_f, np_f, [self.B, self.C], np.dtype('float32')) - - def test_matmul(self): - def pd_f(xs): - x, y = xs - return paddle.matmul(x, y) - - def np_f(xs): - x, y = xs - return np.matmul(x, y) +def make_tensors(inps): + if 
isinstance(inps, list): + xs = [ + paddle.static.data( + f'x{i}', inp.shape, dtype=inp.dtype) + for i, inp in enumerate(inps) + ] + else: + xs = paddle.static.data(name='x', shape=inps.shape, dtype=inps.dtype) + return xs - self.run_test(pd_f, np_f, [self.B, self.C], np.dtype('float32')) - def test_batch_matmul(self): - def pd_f(xs): - x, y = xs - return paddle.matmul(x, y) +all_data_shapes = { + 'A': [[1., 2.]], + 'B': [[1., 2.], [2., 1.]], + 'C': [[2., 2.], [2., 1.]], + 'D': [[[2., 2.], [2., 1.]], [[1., 2.], [2., 1.]]], + 'E': [[[3., 4.], [2., 3.]], [[2., 1.], [1., 3.]]], +} - def np_f(xs): - x, y = xs - return np.matmul(x, y) - self.run_test( - pd_f, np_f, [self.D, self.E], np.dtype('float32'), batch=True) +def prepare_data(test, input_shapes, dtype): + for name, shape in input_shapes.items(): + setattr(test, name, np.array(shape, dtype=dtype)) -class TestJacobianFloat64(unittest.TestCase): +class TestJacobianFloat32(unittest.TestCase): @classmethod def setUpClass(self): paddle.enable_static() @@ -174,31 +105,13 @@ class TestJacobianFloat64(unittest.TestCase): self.place = fluid.CUDAPlace(0) else: self.place = fluid.CPUPlace() - self.np_dtype = np.float32 - self.A = np.array([[1., 2.]]).astype('float64') - self.B = np.array([[1., 2.], [2., 1.]]).astype('float64') - self.C = np.array([[2., 2.], [2., 1.]]).astype('float64') - self.D = np.array( - [[[2., 2.], [2., 1.]], [[1., 2.], [2., 1.]]]).astype('float64') - self.E = np.array( - [[[3., 4.], [2., 3.]], [[2., 1.], [1., 3.]]]).astype('float64') - self.eps = 1e-7 - self.rtol = 1e-6 - self.atol = 1e-6 - - def run_test_by_fullmatrix(self, pd_f, np_f, inps, dtype, batch=False): - def make_tensors(inps): - if isinstance(inps, list): - xs = [ - paddle.static.data( - f'x{i}', inp.shape, dtype=inp.dtype) - for i, inp in enumerate(inps) - ] - else: - xs = paddle.static.data( - name='x', shape=inps.shape, dtype=inps.dtype) - return xs + self.dtype = 'float32' + prepare_data(self, all_data_shapes, self.dtype) + self.eps = 1e-4 + self.rtol = 1e-2 + self.atol = 1e-2 + def run_test_by_fullmatrix(self, pd_f, np_f, inps, batch=False): main = fluid.Program() startup = fluid.Program() with fluid.program_guard(main, startup): @@ -213,23 +126,12 @@ class TestJacobianFloat64(unittest.TestCase): else: feeds = {'x': inps} pd_jacobians = exe.run(main, feed=feeds, fetch_list=[full_jacobian])[0] - np_jacobians = approx_jacobian(np_f, inps, dtype, self.eps, batch=batch) + np_jacobians = approx_jacobian( + np_f, inps, self.dtype, self.eps, batch=batch) self.assertTrue( np.allclose(pd_jacobians, np_jacobians, self.rtol, self.atol)) - def run_test_by_rows(self, pd_f, np_f, inps, dtype, batch=False): - def make_tensors(inps): - if isinstance(inps, list): - xs = [ - paddle.static.data( - f'x{i}', inp.shape, dtype=inp.dtype) - for i, inp in enumerate(inps) - ] - else: - xs = paddle.static.data( - name='x', shape=inps.shape, dtype=inps.dtype) - return xs - + def run_test_by_rows(self, pd_f, np_f, inps, batch=False): main = fluid.Program() startup = fluid.Program() with fluid.program_guard(main, startup): @@ -244,24 +146,12 @@ class TestJacobianFloat64(unittest.TestCase): else: feeds = {'x': inps} pd_jac = exe.run(main, feed=feeds, fetch_list=[rows]) - np_jac = approx_jacobian(np_f, inps, dtype, self.eps, batch=batch) + np_jac = approx_jacobian(np_f, inps, self.dtype, self.eps, batch=batch) for i in range(nrow): self.assertTrue( np.allclose(pd_jac[i], np_jac[i], self.rtol, self.atol)) - def run_test_by_entries(self, pd_f, np_f, inps, dtype, batch=False): - def 
make_tensors(inps): - if isinstance(inps, list): - xs = [ - paddle.static.data( - f'x{i}', inp.shape, dtype=inp.dtype) - for i, inp in enumerate(inps) - ] - else: - xs = paddle.static.data( - name='x', shape=inps.shape, dtype=inps.dtype) - return xs - + def run_test_by_entries(self, pd_f, np_f, inps, batch=False): main = fluid.Program() startup = fluid.Program() with fluid.program_guard(main, startup): @@ -276,7 +166,7 @@ class TestJacobianFloat64(unittest.TestCase): else: feeds = {'x': inps} pd_entries = exe.run(main, feed=feeds, fetch_list=[entries]) - np_jac = approx_jacobian(np_f, inps, dtype, self.eps, batch=batch) + np_jac = approx_jacobian(np_f, inps, self.dtype, self.eps, batch=batch) np_entries = [ np_jac[i, ..., j] for i in range(nrow) for j in range(ncol) ] @@ -291,9 +181,9 @@ class TestJacobianFloat64(unittest.TestCase): def np_f(x): return np.multiply(x, x) - self.run_test_by_fullmatrix(pd_f, np_f, self.A, np.dtype('float64')) - self.run_test_by_rows(pd_f, np_f, self.A, np.dtype('float64')) - self.run_test_by_entries(pd_f, np_f, self.A, np.dtype('float64')) + self.run_test_by_fullmatrix(pd_f, np_f, self.A) + self.run_test_by_rows(pd_f, np_f, self.A) + self.run_test_by_entries(pd_f, np_f, self.A) def test_mul(self): def pd_f(xs): @@ -304,11 +194,12 @@ class TestJacobianFloat64(unittest.TestCase): x, y = xs return np.multiply(x, y) - self.run_test_by_fullmatrix(pd_f, np_f, [self.B, self.C], - np.dtype('float64')) - self.run_test_by_rows(pd_f, np_f, [self.B, self.C], np.dtype('float64')) - self.run_test_by_entries(pd_f, np_f, [self.B, self.C], - np.dtype('float64')) + self.run_test_by_fullmatrix( + pd_f, + np_f, + [self.B, self.C], ) + self.run_test_by_rows(pd_f, np_f, [self.B, self.C]) + self.run_test_by_entries(pd_f, np_f, [self.B, self.C]) def test_matmul(self): def pd_f(xs): @@ -319,11 +210,9 @@ class TestJacobianFloat64(unittest.TestCase): x, y = xs return np.matmul(x, y) - self.run_test_by_fullmatrix(pd_f, np_f, [self.B, self.C], - np.dtype('float64')) - self.run_test_by_rows(pd_f, np_f, [self.B, self.C], np.dtype('float64')) - self.run_test_by_entries(pd_f, np_f, [self.B, self.C], - np.dtype('float64')) + self.run_test_by_fullmatrix(pd_f, np_f, [self.B, self.C]) + self.run_test_by_rows(pd_f, np_f, [self.B, self.C]) + self.run_test_by_entries(pd_f, np_f, [self.B, self.C]) def test_batch_matmul(self): def pd_f(xs): @@ -334,12 +223,85 @@ class TestJacobianFloat64(unittest.TestCase): x, y = xs return np.matmul(x, y) - self.run_test_by_fullmatrix( - pd_f, np_f, [self.D, self.E], np.dtype('float64'), batch=True) - self.run_test_by_rows( - pd_f, np_f, [self.D, self.E], np.dtype('float64'), batch=True) - self.run_test_by_entries( - pd_f, np_f, [self.D, self.E], np.dtype('float64'), batch=True) + self.run_test_by_fullmatrix(pd_f, np_f, [self.D, self.E], batch=True) + self.run_test_by_rows(pd_f, np_f, [self.D, self.E], batch=True) + self.run_test_by_entries(pd_f, np_f, [self.D, self.E], batch=True) + + +class TestJacobianFloat64(TestJacobianFloat32): + @classmethod + def setUpClass(self): + paddle.enable_static() + if fluid.core.is_compiled_with_cuda(): + self.place = fluid.CUDAPlace(0) + else: + self.place = fluid.CPUPlace() + self.dtype = 'float64' + prepare_data(self, all_data_shapes, self.dtype) + self.eps = 1e-7 + self.rtol = 1e-6 + self.atol = 1e-6 + + +class TestHessianFloat64(unittest.TestCase): + @classmethod + def setUpClass(self): + paddle.enable_static() + if fluid.core.is_compiled_with_cuda(): + self.place = fluid.CUDAPlace(0) + else: + self.place = fluid.CPUPlace() + 
self.dtype = 'float64' + prepare_data(self, all_data_shapes, self.dtype) + self.eps = 1e-7 + self.rtol = 1e-6 + self.atol = 1e-6 + + def run_test_by_fullmatrix(self, pd_f, inps, np_hess, batch=False): + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + xs = make_tensors(inps) + HH = paddle.autograd.functional.Hessian(pd_f, xs, batch=batch) + nrow, ncol = HH.shape() + full_hessian = HH[:] + exe = fluid.Executor(self.place) + exe.run(startup) + if isinstance(inps, list): + feeds = {f'x{i}': x for i, x in enumerate(inps)} + else: + feeds = {'x': inps} + pd_hess = exe.run(main, feed=feeds, fetch_list=[full_hessian])[0] + self.assertTrue(np.allclose(pd_hess, np_hess, self.rtol, self.atol)) + + def test_square(self): + def pd_f(x): + """Input is a square matrix.""" + return paddle.matmul(x, x.T) + + def np_hess(x): + dim = x.shape[0] + f_xx_upperleft = 2 * np.eye(dim, dtype=self.dtype) + f_xx = np.zeros([dim * dim, dim * dim], dtype=self.dtype) + f_xx[:dim, :dim] = f_xx_upperleft + return f_xx + + self.run_test_by_fullmatrix(pd_f, self.B, np_hess(self.B)) + + def test_batch_square(self): + def pd_f(x): + """Input is a square matrix.""" + return paddle.matmul(x, paddle.transpose(x, [0, 2, 1])) + + def np_hess(x): + bat, dim, _ = x.shape + f_xx_upperleft = 2 * np.eye(dim, dtype=self.dtype) + f_xx = np.zeros([bat, dim * dim, dim * dim], dtype=self.dtype) + f_xx[..., :dim, :dim] = f_xx_upperleft + return f_xx + + self.run_test_by_fullmatrix( + pd_f, self.E, np_hess(self.E), batch=True) if __name__ == "__main__": -- GitLab
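
Usage note (not part of the patch): a minimal sketch of how the new static-graph Jacobian is driven, assembled from the docstring added in functional.py and the pattern used in test_autograd_static.py. The elementwise-square function, the concrete input value, and the variable names are illustrative choices, not code from the PR.

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()

    main, startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(main, startup):
        x = paddle.static.data(name='x', shape=[1, 2], dtype='float32')
        # Jacobian of an elementwise square; indexing follows the new docstring:
        # JJ[:] is the full matrix, JJ[i, :] the row for the i'th output entry.
        JJ = paddle.autograd.functional.Jacobian(
            lambda x: paddle.multiply(x, x), x, batch=False)
        nrow, ncol = JJ.shape()
        full_jacobian = JJ[:]
        first_row = JJ[0, :]

    place = (fluid.CUDAPlace(0)
             if fluid.core.is_compiled_with_cuda() else fluid.CPUPlace())
    exe = fluid.Executor(place)
    exe.run(startup)
    jac, row0 = exe.run(main,
                        feed={'x': np.array([[1., 2.]], dtype='float32')},
                        fetch_list=[full_jacobian, first_row])
    # For f(x) = x * x on a flattened input, the Jacobian is
    # diag(2 * x.flatten()), i.e. [[2., 0.], [0., 4.]] here.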
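
A second sketch under the same caveats: the new Hessian class used as in TestHessianFloat64.test_square. Per the patch, Hessian differentiates Jacobian(func, xs)[0] a second time, so the result is the Hessian of the first flattened output entry. The 2x2 input value below is illustrative.

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()

    main, startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(main, startup):
        x = paddle.static.data(name='x', shape=[2, 2], dtype='float64')
        # Hessian of the first output entry of f(x) = x @ x^T.
        HH = paddle.autograd.functional.Hessian(
            lambda x: paddle.matmul(x, paddle.transpose(x, [1, 0])), x)
        nrow, ncol = HH.shape()
        full_hessian = HH[:]

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup)
    hess = exe.run(main,
                   feed={'x': np.array([[1., 2.], [2., 1.]], dtype='float64')},
                   fetch_list=[full_hessian])[0]
    # Only the block for the first output entry is non-zero: 2 * I in the
    # top-left dim x dim corner of the (dim*dim) x (dim*dim) matrix, matching
    # np_hess in the patched test.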