From d43655ba9402fe7021e4fb11dff6b6eecdf559fb Mon Sep 17 00:00:00 2001 From: Tongxin Bai Date: Mon, 24 Jan 2022 22:52:39 +0800 Subject: [PATCH] [autograd] static Jacobian pass tests. (#39007) * [autograd] static Jacobian pass tests. * [autograd] apply CR suggested changes. * [autograd] more tests. * [autograd] add CPUPlace in tests. * [autograd] bug fixes. * [autograd] reformatted. --- python/paddle/autograd/functional.py | 120 ++++++ .../autograd/test_jacobian_static.py | 346 ++++++++++++++++++ 2 files changed, 466 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/autograd/test_jacobian_static.py diff --git a/python/paddle/autograd/functional.py b/python/paddle/autograd/functional.py index 2e5adfa5df..b9cceafeba 100644 --- a/python/paddle/autograd/functional.py +++ b/python/paddle/autograd/functional.py @@ -14,6 +14,7 @@ import contextlib import paddle +from paddle.static import gradients from ..fluid import framework from ..fluid.dygraph import grad from ..tensor.creation import assign @@ -904,3 +905,122 @@ def vhp(func, inputs, v=None, create_graph=False, allow_unused=False): vhp = grad_fn(jac, xs, v) outputs, vhp = return_fn(outputs), return_fn(vhp) return outputs, vhp + + +class Jacobian(object): + r""" + Object that represents the Jacobian matrix of a muli-input multi-output + function. + + The Jacobian values are lazily evaluated if accessed through indices. + In contrast, slicing access would trigger evaluating the full matrix + if it's not already computed. + + Examples: + .. code-block:: python + import paddle + import numpy as np + + def func(xs): + x, y = xs + return paddle.matmul(x, y) + + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + x = paddle.static.data(name='x', shape=[2, 2], dtype='float32') + JJ = paddle.autograd.functional.Jacobian(func, [x, x]) + nrow, ncol = JJ.shape() + full_jacobian = JJ[:] + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + exe.run(startup) + + feeds = {'x': np.array([[2., 2.], [2., 1.]]).astype('float32')} + jacobian = exe.run(main, feed=feeds, fetch_list=[full_jacobian])[0] + print(jacobian) + # [[4. 2. 2. 0. 4. 2. 2. 0.] + # [2. 3. 0. 2. 2. 3. 0. 2.] + # [2. 0. 3. 2. 2. 0. 3. 2.] + # [0. 2. 2. 2. 0. 2. 2. 2.]] + """ + + def __init__(self, func, inputs, batch=False): + r"""Constructing a Jacobian matrix. + + Parameters: + func (Callable): a Python function that takes as input a Tensor + or a Tensor list and outputs a Tensor or a Tensor list. + inputs (Tensor|list[Tensor]): a Tensor or a list of Tensors as + `func`'s input. + batch (bool): if True the 0'th axis is considered the batch + dimension, both on input and output. + """ + + def enable_grads(inputs): + if isinstance(inputs, (list, tuple)): + for x in inputs: + x.stop_gradient = False + else: + assert isinstance(inputs, paddle.fluid.framework.Variable), ( + f"Expecting {inputs} to be paddle.fluid.framework.Variable," + f" however it's found to be a(n) {type(inputs)}.") + inputs.stop_gradient = False + return inputs + + self.batch = batch + self.xs = enable_grads(inputs) + ys = func(inputs) + if not isinstance(ys, list): + ys = [ys] + self.y = self.flatten_all(ys) + self.ydim = self.y.shape[-1] + self.xdim = self.flatten_all(inputs).shape[-1] + self.bdim = self.y.shape[0] + self.jacobian = {} + + def flatten(self, x): + to = [x.shape[0], -1] if self.batch else [-1] + return x.reshape(to) + + def flatten_all(self, xs): + return paddle.concat([self.flatten(x) for x in xs], axis=-1) + + def shape(self): + return (self.ydim, self.xdim) + + def __getitem__(self, tup): + if hasattr(tup, '__iter__'): + i, j = tup + else: + i, j = tup, None + + if isinstance(i, slice): + slicing = True + else: + slicing = False + + if slicing: + if 'full' not in self.jacobian: + rows = [ + self.flatten_all(gradients(self.y[..., i], self.xs)) + for i in range(self.ydim) + ] + self.jacobian['full'] = paddle.stack(rows) + return self.jacobian['full'][i] + + assert 0 <= i < self.ydim, f"Jacobian index i={i} is not valid." + assert (j is None) or ( + 0 <= j < self.xdim), f"Jacobian index j={j} is not valid." + if 'full' in self.jacobian: + JJ = self.jacobian['full'] + else: + JJ = self.jacobian + if i not in self.jacobian: + self.jacobian[i] = self.flatten_all( + gradients(self.y[..., i], self.xs)) + + if j is None: + return JJ[i] + else: + return JJ[i][..., j] diff --git a/python/paddle/fluid/tests/unittests/autograd/test_jacobian_static.py b/python/paddle/fluid/tests/unittests/autograd/test_jacobian_static.py new file mode 100644 index 0000000000..28fc6932b0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/autograd/test_jacobian_static.py @@ -0,0 +1,346 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +from utils import _compute_numerical_jacobian, _compute_numerical_batch_jacobian + + +def approx_jacobian(f, xs, dtype, eps=1e-5, batch=False): + r"""Computes an approximate Jacobian matrix of a multi-valued function + using finite differences. + + The function input is required to be an np array or a list of list of np + arrays. + """ + + def flatten(x): + if len(x.shape) > 0: + to = [x.shape[0], -1] if batch else [-1] + return x.reshape(to) + else: + return x + + def flatten_all(xs): + if isinstance(xs, list): + flattened = np.concatenate([flatten(x) for x in xs], axis=-1) + else: + flattened = flatten(xs) + return flattened + + def x_like(x, orig_x): + return x.reshape(orig_x.shape) + + def _f(x): + if multi_inps: + _xs = np.split(x, splits, axis=-1) + _xs = [x_like(_x, _o) for _x, _o in zip(_xs, xs)] + outs = f(_xs) + else: + outs = f(x) + return flatten_all(outs) + + multi_inps = False if isinstance(xs, np.ndarray) else True + x = flatten_all(xs) + xdim = x.shape[-1] + splits = [] + + if multi_inps: + split = 0 + for inp in xs: + split += flatten(inp).shape[-1] + splits.append(split) + + ds = eps * np.eye(xdim, dtype=dtype) + + fprimes_by_x = [(0.5 / eps) * (_f(x + d) - _f(x - d)) for d in ds] + fprimes_by_y = np.stack(fprimes_by_x, axis=-1) + return np.transpose(fprimes_by_y, [1, 0, 2]) if batch else fprimes_by_y + + +class TestJacobianFloat32(unittest.TestCase): + @classmethod + def setUpClass(self): + paddle.enable_static() + if fluid.core.is_compiled_with_cuda(): + self.place = fluid.CUDAPlace(0) + else: + self.place = fluid.CPUPlace() + self.np_dtype = np.float32 + self.A = np.array([[1., 2.]]).astype('float32') + self.B = np.array([[1., 2.], [2., 1.]]).astype('float32') + self.C = np.array([[2., 2.], [2., 1.]]).astype('float32') + self.D = np.array( + [[[2., 2.], [2., 1.]], [[1., 2.], [2., 1.]]]).astype('float32') + self.E = np.array( + [[[3., 4.], [2., 3.]], [[2., 1.], [1., 3.]]]).astype('float32') + self.eps = 1e-4 + self.rtol = 1e-2 + self.atol = 1e-2 + + def run_test(self, pd_f, np_f, inps, dtype, batch=False): + def make_tensors(inps): + if isinstance(inps, list): + xs = [ + paddle.static.data( + f'x{i}', inp.shape, dtype=inp.dtype) + for i, inp in enumerate(inps) + ] + else: + xs = paddle.static.data( + name='x', shape=inps.shape, dtype=inps.dtype) + return xs + + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + xs = make_tensors(inps) + JJ = paddle.autograd.functional.Jacobian(pd_f, xs, batch=batch) + nrow, ncol = JJ.shape() + full_jacobian = JJ[:] + exe = fluid.Executor(self.place) + exe.run(startup) + if isinstance(inps, list): + feeds = {f'x{i}': x for i, x in enumerate(inps)} + else: + feeds = {'x': inps} + pd_jacobians = exe.run(main, feed=feeds, fetch_list=[full_jacobian])[0] + np_jacobians = approx_jacobian(np_f, inps, dtype, self.eps, batch=batch) + self.assertTrue( + np.allclose(pd_jacobians, np_jacobians, self.rtol, self.atol)) + + def test_square(self): + def pd_f(x): + return paddle.multiply(x, x) + + def np_f(x): + return np.multiply(x, x) + + self.run_test(pd_f, np_f, self.A, np.dtype('float32')) + + def test_mul(self): + def pd_f(xs): + x, y = xs + return paddle.multiply(x, y) + + def np_f(xs): + x, y = xs + return np.multiply(x, y) + + self.run_test(pd_f, np_f, [self.B, self.C], np.dtype('float32')) + + def test_matmul(self): + def pd_f(xs): + x, y = xs + return paddle.matmul(x, y) + + def np_f(xs): + x, y = xs + return np.matmul(x, y) + + self.run_test(pd_f, np_f, [self.B, self.C], np.dtype('float32')) + + def test_batch_matmul(self): + def pd_f(xs): + x, y = xs + return paddle.matmul(x, y) + + def np_f(xs): + x, y = xs + return np.matmul(x, y) + + self.run_test( + pd_f, np_f, [self.D, self.E], np.dtype('float32'), batch=True) + + +class TestJacobianFloat64(unittest.TestCase): + @classmethod + def setUpClass(self): + paddle.enable_static() + if fluid.core.is_compiled_with_cuda(): + self.place = fluid.CUDAPlace(0) + else: + self.place = fluid.CPUPlace() + self.np_dtype = np.float32 + self.A = np.array([[1., 2.]]).astype('float64') + self.B = np.array([[1., 2.], [2., 1.]]).astype('float64') + self.C = np.array([[2., 2.], [2., 1.]]).astype('float64') + self.D = np.array( + [[[2., 2.], [2., 1.]], [[1., 2.], [2., 1.]]]).astype('float64') + self.E = np.array( + [[[3., 4.], [2., 3.]], [[2., 1.], [1., 3.]]]).astype('float64') + self.eps = 1e-7 + self.rtol = 1e-6 + self.atol = 1e-6 + + def run_test_by_fullmatrix(self, pd_f, np_f, inps, dtype, batch=False): + def make_tensors(inps): + if isinstance(inps, list): + xs = [ + paddle.static.data( + f'x{i}', inp.shape, dtype=inp.dtype) + for i, inp in enumerate(inps) + ] + else: + xs = paddle.static.data( + name='x', shape=inps.shape, dtype=inps.dtype) + return xs + + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + xs = make_tensors(inps) + JJ = paddle.autograd.functional.Jacobian(pd_f, xs, batch=batch) + nrow, ncol = JJ.shape() + full_jacobian = JJ[:] + exe = fluid.Executor(self.place) + exe.run(startup) + if isinstance(inps, list): + feeds = {f'x{i}': x for i, x in enumerate(inps)} + else: + feeds = {'x': inps} + pd_jacobians = exe.run(main, feed=feeds, fetch_list=[full_jacobian])[0] + np_jacobians = approx_jacobian(np_f, inps, dtype, self.eps, batch=batch) + self.assertTrue( + np.allclose(pd_jacobians, np_jacobians, self.rtol, self.atol)) + + def run_test_by_rows(self, pd_f, np_f, inps, dtype, batch=False): + def make_tensors(inps): + if isinstance(inps, list): + xs = [ + paddle.static.data( + f'x{i}', inp.shape, dtype=inp.dtype) + for i, inp in enumerate(inps) + ] + else: + xs = paddle.static.data( + name='x', shape=inps.shape, dtype=inps.dtype) + return xs + + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + xs = make_tensors(inps) + JJ = paddle.autograd.functional.Jacobian(pd_f, xs, batch=batch) + nrow, ncol = JJ.shape() + rows = [JJ[i] for i in range(nrow)] + exe = fluid.Executor(self.place) + exe.run(startup) + if isinstance(inps, list): + feeds = {f'x{i}': x for i, x in enumerate(inps)} + else: + feeds = {'x': inps} + pd_jac = exe.run(main, feed=feeds, fetch_list=[rows]) + np_jac = approx_jacobian(np_f, inps, dtype, self.eps, batch=batch) + for i in range(nrow): + self.assertTrue( + np.allclose(pd_jac[i], np_jac[i], self.rtol, self.atol)) + + def run_test_by_entries(self, pd_f, np_f, inps, dtype, batch=False): + def make_tensors(inps): + if isinstance(inps, list): + xs = [ + paddle.static.data( + f'x{i}', inp.shape, dtype=inp.dtype) + for i, inp in enumerate(inps) + ] + else: + xs = paddle.static.data( + name='x', shape=inps.shape, dtype=inps.dtype) + return xs + + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + xs = make_tensors(inps) + JJ = paddle.autograd.functional.Jacobian(pd_f, xs, batch=batch) + nrow, ncol = JJ.shape() + entries = [JJ[i, j] for i in range(nrow) for j in range(ncol)] + exe = fluid.Executor(self.place) + exe.run(startup) + if isinstance(inps, list): + feeds = {f'x{i}': x for i, x in enumerate(inps)} + else: + feeds = {'x': inps} + pd_entries = exe.run(main, feed=feeds, fetch_list=[entries]) + np_jac = approx_jacobian(np_f, inps, dtype, self.eps, batch=batch) + np_entries = [ + np_jac[i, ..., j] for i in range(nrow) for j in range(ncol) + ] + for pd_entry, np_entry in zip(pd_entries, np_entries): + self.assertTrue( + np.allclose(pd_entry, np_entry, self.rtol, self.atol)) + + def test_square(self): + def pd_f(x): + return paddle.multiply(x, x) + + def np_f(x): + return np.multiply(x, x) + + self.run_test_by_fullmatrix(pd_f, np_f, self.A, np.dtype('float64')) + self.run_test_by_rows(pd_f, np_f, self.A, np.dtype('float64')) + self.run_test_by_entries(pd_f, np_f, self.A, np.dtype('float64')) + + def test_mul(self): + def pd_f(xs): + x, y = xs + return paddle.multiply(x, y) + + def np_f(xs): + x, y = xs + return np.multiply(x, y) + + self.run_test_by_fullmatrix(pd_f, np_f, [self.B, self.C], + np.dtype('float64')) + self.run_test_by_rows(pd_f, np_f, [self.B, self.C], np.dtype('float64')) + self.run_test_by_entries(pd_f, np_f, [self.B, self.C], + np.dtype('float64')) + + def test_matmul(self): + def pd_f(xs): + x, y = xs + return paddle.matmul(x, y) + + def np_f(xs): + x, y = xs + return np.matmul(x, y) + + self.run_test_by_fullmatrix(pd_f, np_f, [self.B, self.C], + np.dtype('float64')) + self.run_test_by_rows(pd_f, np_f, [self.B, self.C], np.dtype('float64')) + self.run_test_by_entries(pd_f, np_f, [self.B, self.C], + np.dtype('float64')) + + def test_batch_matmul(self): + def pd_f(xs): + x, y = xs + return paddle.matmul(x, y) + + def np_f(xs): + x, y = xs + return np.matmul(x, y) + + self.run_test_by_fullmatrix( + pd_f, np_f, [self.D, self.E], np.dtype('float64'), batch=True) + self.run_test_by_rows( + pd_f, np_f, [self.D, self.E], np.dtype('float64'), batch=True) + self.run_test_by_entries( + pd_f, np_f, [self.D, self.E], np.dtype('float64'), batch=True) + + +if __name__ == "__main__": + unittest.main() -- GitLab