Unverified · Commit d43655ba authored by Tongxin Bai, committed by GitHub

[autograd] static Jacobian pass tests. (#39007)

* [autograd] static Jacobian pass tests.

* [autograd] apply CR suggested changes.

* [autograd] more tests.

* [autograd] add CPUPlace in tests.

* [autograd] bug fixes.

* [autograd] reformatted.
Parent c00303ec
@@ -14,6 +14,7 @@
import contextlib
import paddle
from paddle.static import gradients
from ..fluid import framework
from ..fluid.dygraph import grad
from ..tensor.creation import assign
@@ -904,3 +905,122 @@ def vhp(func, inputs, v=None, create_graph=False, allow_unused=False):
    vhp = grad_fn(jac, xs, v)
    outputs, vhp = return_fn(outputs), return_fn(vhp)
    return outputs, vhp

class Jacobian(object):
    r"""
    Object that represents the Jacobian matrix of a multi-input multi-output
    function.

    The Jacobian values are lazily evaluated if accessed through indices.
    In contrast, slicing access would trigger evaluating the full matrix
    if it's not already computed.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid
            import numpy as np

            paddle.enable_static()

            def func(xs):
                x, y = xs
                return paddle.matmul(x, y)

            main = fluid.Program()
            startup = fluid.Program()
            with fluid.program_guard(main, startup):
                x = paddle.static.data(name='x', shape=[2, 2], dtype='float32')
                JJ = paddle.autograd.functional.Jacobian(func, [x, x])
                nrow, ncol = JJ.shape()
                full_jacobian = JJ[:]
            place = fluid.CUDAPlace(0)
            exe = fluid.Executor(place)
            exe.run(startup)

            feeds = {'x': np.array([[2., 2.], [2., 1.]]).astype('float32')}
            jacobian = exe.run(main, feed=feeds, fetch_list=[full_jacobian])[0]
            print(jacobian)
            # [[4. 2. 2. 0. 4. 2. 2. 0.]
            #  [2. 3. 0. 2. 2. 3. 0. 2.]
            #  [2. 0. 3. 2. 2. 0. 3. 2.]
            #  [0. 2. 2. 2. 0. 2. 2. 2.]]
    """
    def __init__(self, func, inputs, batch=False):
        r"""Constructs a Jacobian matrix.

        Parameters:
            func (Callable): a Python function that takes as input a Tensor
                or a Tensor list and outputs a Tensor or a Tensor list.
            inputs (Tensor|list[Tensor]): a Tensor or a list of Tensors as
                `func`'s input.
            batch (bool): if True, the 0th axis is considered the batch
                dimension, on both input and output.
        """

        def enable_grads(inputs):
            if isinstance(inputs, (list, tuple)):
                for x in inputs:
                    x.stop_gradient = False
            else:
                assert isinstance(inputs, paddle.fluid.framework.Variable), (
                    f"Expecting {inputs} to be paddle.fluid.framework.Variable,"
                    f" however it's found to be a(n) {type(inputs)}.")
                inputs.stop_gradient = False
            return inputs

        self.batch = batch
        self.xs = enable_grads(inputs)
        ys = func(inputs)
        if not isinstance(ys, list):
            ys = [ys]
        # Flatten the outputs into a row vector, or a [batch, ydim] matrix
        # when batch is True.
        self.y = self.flatten_all(ys)
        self.ydim = self.y.shape[-1]
        self.xdim = self.flatten_all(inputs).shape[-1]
        self.bdim = self.y.shape[0]
        # Cache of evaluated rows keyed by row index; the key 'full' holds
        # the fully evaluated Jacobian.
        self.jacobian = {}

    def flatten(self, x):
        to = [x.shape[0], -1] if self.batch else [-1]
        return x.reshape(to)

    def flatten_all(self, xs):
        return paddle.concat([self.flatten(x) for x in xs], axis=-1)

    def shape(self):
        return (self.ydim, self.xdim)
    def __getitem__(self, tup):
        if hasattr(tup, '__iter__'):
            i, j = tup
        else:
            i, j = tup, None

        slicing = isinstance(i, slice)

        if slicing:
            # Slicing evaluates and caches the full Jacobian.
            if 'full' not in self.jacobian:
                rows = [
                    self.flatten_all(gradients(self.y[..., i], self.xs))
                    for i in range(self.ydim)
                ]
                self.jacobian['full'] = paddle.stack(rows)
            return self.jacobian['full'][i]

        assert 0 <= i < self.ydim, f"Jacobian index i={i} is not valid."
        assert (j is None) or (
            0 <= j < self.xdim), f"Jacobian index j={j} is not valid."
        if 'full' in self.jacobian:
            JJ = self.jacobian['full']
        else:
            JJ = self.jacobian
            # Lazily evaluate and cache row i, but only when the full matrix
            # hasn't already been computed.
            if i not in self.jacobian:
                self.jacobian[i] = self.flatten_all(
                    gradients(self.y[..., i], self.xs))
        if j is None:
            return JJ[i]
        else:
            return JJ[i][..., j]
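
# A minimal sketch of row- and entry-wise access (illustrative only; it
# assumes the same static program, executor, and feeds as the docstring
# example above):
#
#     JJ = paddle.autograd.functional.Jacobian(func, [x, x])
#     row0 = JJ[0]        # lazily evaluates row 0 only, then caches it
#     entry = JJ[1, 2]    # evaluates (or reuses) row 1, then takes column 2
#     outs = exe.run(main, feed=feeds, fetch_list=[row0, entry])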
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from utils import _compute_numerical_jacobian, _compute_numerical_batch_jacobian

def approx_jacobian(f, xs, dtype, eps=1e-5, batch=False):
    r"""Computes an approximate Jacobian matrix of a multi-valued function
    using central finite differences.

    The function input is required to be an np.ndarray or a list of
    np.ndarrays.
    """

    def flatten(x):
        if len(x.shape) > 0:
            to = [x.shape[0], -1] if batch else [-1]
            return x.reshape(to)
        else:
            return x

    def flatten_all(xs):
        if isinstance(xs, list):
            flattened = np.concatenate([flatten(x) for x in xs], axis=-1)
        else:
            flattened = flatten(xs)
        return flattened

    def x_like(x, orig_x):
        return x.reshape(orig_x.shape)

    def _f(x):
        if multi_inps:
            # `splits` ends at the total width, so np.split leaves a trailing
            # empty chunk; zipping with `xs` silently drops it.
            _xs = np.split(x, splits, axis=-1)
            _xs = [x_like(_x, _o) for _x, _o in zip(_xs, xs)]
            outs = f(_xs)
        else:
            outs = f(x)
        return flatten_all(outs)

    multi_inps = False if isinstance(xs, np.ndarray) else True
    x = flatten_all(xs)
    xdim = x.shape[-1]
    splits = []

    if multi_inps:
        split = 0
        for inp in xs:
            split += flatten(inp).shape[-1]
            splits.append(split)

    # Central differences: perturb one flattened coordinate at a time.
    ds = eps * np.eye(xdim, dtype=dtype)

    fprimes_by_x = [(0.5 / eps) * (_f(x + d) - _f(x - d)) for d in ds]
    fprimes_by_y = np.stack(fprimes_by_x, axis=-1)
    return np.transpose(fprimes_by_y, [1, 0, 2]) if batch else fprimes_by_y
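
# A quick sanity check for approx_jacobian (an illustrative sketch, not part
# of the test suite): for the elementwise square f(x) = x * x the exact
# Jacobian is diag(2 * x), which central differences should reproduce closely.
def _demo_approx_jacobian():
    x = np.array([1., 2., 3.]).astype('float32')
    jac = approx_jacobian(lambda v: v * v, x, np.dtype('float32'), eps=1e-4)
    assert np.allclose(jac, np.diag(2. * x), atol=1e-2)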

class TestJacobianFloat32(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        paddle.enable_static()
        if fluid.core.is_compiled_with_cuda():
            cls.place = fluid.CUDAPlace(0)
        else:
            cls.place = fluid.CPUPlace()
        cls.np_dtype = np.float32
        cls.A = np.array([[1., 2.]]).astype('float32')
        cls.B = np.array([[1., 2.], [2., 1.]]).astype('float32')
        cls.C = np.array([[2., 2.], [2., 1.]]).astype('float32')
        cls.D = np.array(
            [[[2., 2.], [2., 1.]], [[1., 2.], [2., 1.]]]).astype('float32')
        cls.E = np.array(
            [[[3., 4.], [2., 3.]], [[2., 1.], [1., 3.]]]).astype('float32')
        cls.eps = 1e-4
        cls.rtol = 1e-2
        cls.atol = 1e-2

    def run_test(self, pd_f, np_f, inps, dtype, batch=False):
        def make_tensors(inps):
            if isinstance(inps, list):
                xs = [
                    paddle.static.data(
                        f'x{i}', inp.shape, dtype=inp.dtype)
                    for i, inp in enumerate(inps)
                ]
            else:
                xs = paddle.static.data(
                    name='x', shape=inps.shape, dtype=inps.dtype)
            return xs

        main = fluid.Program()
        startup = fluid.Program()
        with fluid.program_guard(main, startup):
            xs = make_tensors(inps)
            JJ = paddle.autograd.functional.Jacobian(pd_f, xs, batch=batch)
            nrow, ncol = JJ.shape()
            full_jacobian = JJ[:]
        exe = fluid.Executor(self.place)
        exe.run(startup)
        if isinstance(inps, list):
            feeds = {f'x{i}': x for i, x in enumerate(inps)}
        else:
            feeds = {'x': inps}
        pd_jacobians = exe.run(main, feed=feeds, fetch_list=[full_jacobian])[0]
        np_jacobians = approx_jacobian(np_f, inps, dtype, self.eps, batch=batch)
        self.assertTrue(
            np.allclose(pd_jacobians, np_jacobians, self.rtol, self.atol))

    def test_square(self):
        def pd_f(x):
            return paddle.multiply(x, x)

        def np_f(x):
            return np.multiply(x, x)

        self.run_test(pd_f, np_f, self.A, np.dtype('float32'))

    def test_mul(self):
        def pd_f(xs):
            x, y = xs
            return paddle.multiply(x, y)

        def np_f(xs):
            x, y = xs
            return np.multiply(x, y)

        self.run_test(pd_f, np_f, [self.B, self.C], np.dtype('float32'))

    def test_matmul(self):
        def pd_f(xs):
            x, y = xs
            return paddle.matmul(x, y)

        def np_f(xs):
            x, y = xs
            return np.matmul(x, y)

        self.run_test(pd_f, np_f, [self.B, self.C], np.dtype('float32'))

    def test_batch_matmul(self):
        def pd_f(xs):
            x, y = xs
            return paddle.matmul(x, y)

        def np_f(xs):
            x, y = xs
            return np.matmul(x, y)

        self.run_test(
            pd_f, np_f, [self.D, self.E], np.dtype('float32'), batch=True)

class TestJacobianFloat64(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        paddle.enable_static()
        if fluid.core.is_compiled_with_cuda():
            cls.place = fluid.CUDAPlace(0)
        else:
            cls.place = fluid.CPUPlace()
        cls.np_dtype = np.float64
        cls.A = np.array([[1., 2.]]).astype('float64')
        cls.B = np.array([[1., 2.], [2., 1.]]).astype('float64')
        cls.C = np.array([[2., 2.], [2., 1.]]).astype('float64')
        cls.D = np.array(
            [[[2., 2.], [2., 1.]], [[1., 2.], [2., 1.]]]).astype('float64')
        cls.E = np.array(
            [[[3., 4.], [2., 3.]], [[2., 1.], [1., 3.]]]).astype('float64')
        cls.eps = 1e-7
        cls.rtol = 1e-6
        cls.atol = 1e-6

    def run_test_by_fullmatrix(self, pd_f, np_f, inps, dtype, batch=False):
        def make_tensors(inps):
            if isinstance(inps, list):
                xs = [
                    paddle.static.data(
                        f'x{i}', inp.shape, dtype=inp.dtype)
                    for i, inp in enumerate(inps)
                ]
            else:
                xs = paddle.static.data(
                    name='x', shape=inps.shape, dtype=inps.dtype)
            return xs

        main = fluid.Program()
        startup = fluid.Program()
        with fluid.program_guard(main, startup):
            xs = make_tensors(inps)
            JJ = paddle.autograd.functional.Jacobian(pd_f, xs, batch=batch)
            nrow, ncol = JJ.shape()
            full_jacobian = JJ[:]
        exe = fluid.Executor(self.place)
        exe.run(startup)
        if isinstance(inps, list):
            feeds = {f'x{i}': x for i, x in enumerate(inps)}
        else:
            feeds = {'x': inps}
        pd_jacobians = exe.run(main, feed=feeds, fetch_list=[full_jacobian])[0]
        np_jacobians = approx_jacobian(np_f, inps, dtype, self.eps, batch=batch)
        self.assertTrue(
            np.allclose(pd_jacobians, np_jacobians, self.rtol, self.atol))

    def run_test_by_rows(self, pd_f, np_f, inps, dtype, batch=False):
        def make_tensors(inps):
            if isinstance(inps, list):
                xs = [
                    paddle.static.data(
                        f'x{i}', inp.shape, dtype=inp.dtype)
                    for i, inp in enumerate(inps)
                ]
            else:
                xs = paddle.static.data(
                    name='x', shape=inps.shape, dtype=inps.dtype)
            return xs

        main = fluid.Program()
        startup = fluid.Program()
        with fluid.program_guard(main, startup):
            xs = make_tensors(inps)
            JJ = paddle.autograd.functional.Jacobian(pd_f, xs, batch=batch)
            nrow, ncol = JJ.shape()
            rows = [JJ[i] for i in range(nrow)]
        exe = fluid.Executor(self.place)
        exe.run(startup)
        if isinstance(inps, list):
            feeds = {f'x{i}': x for i, x in enumerate(inps)}
        else:
            feeds = {'x': inps}
        pd_jac = exe.run(main, feed=feeds, fetch_list=[rows])
        np_jac = approx_jacobian(np_f, inps, dtype, self.eps, batch=batch)
        for i in range(nrow):
            self.assertTrue(
                np.allclose(pd_jac[i], np_jac[i], self.rtol, self.atol))

    def run_test_by_entries(self, pd_f, np_f, inps, dtype, batch=False):
        def make_tensors(inps):
            if isinstance(inps, list):
                xs = [
                    paddle.static.data(
                        f'x{i}', inp.shape, dtype=inp.dtype)
                    for i, inp in enumerate(inps)
                ]
            else:
                xs = paddle.static.data(
                    name='x', shape=inps.shape, dtype=inps.dtype)
            return xs

        main = fluid.Program()
        startup = fluid.Program()
        with fluid.program_guard(main, startup):
            xs = make_tensors(inps)
            JJ = paddle.autograd.functional.Jacobian(pd_f, xs, batch=batch)
            nrow, ncol = JJ.shape()
            entries = [JJ[i, j] for i in range(nrow) for j in range(ncol)]
        exe = fluid.Executor(self.place)
        exe.run(startup)
        if isinstance(inps, list):
            feeds = {f'x{i}': x for i, x in enumerate(inps)}
        else:
            feeds = {'x': inps}
        pd_entries = exe.run(main, feed=feeds, fetch_list=[entries])
        np_jac = approx_jacobian(np_f, inps, dtype, self.eps, batch=batch)
        np_entries = [
            np_jac[i, ..., j] for i in range(nrow) for j in range(ncol)
        ]
        for pd_entry, np_entry in zip(pd_entries, np_entries):
            self.assertTrue(
                np.allclose(pd_entry, np_entry, self.rtol, self.atol))

    def test_square(self):
        def pd_f(x):
            return paddle.multiply(x, x)

        def np_f(x):
            return np.multiply(x, x)

        self.run_test_by_fullmatrix(pd_f, np_f, self.A, np.dtype('float64'))
        self.run_test_by_rows(pd_f, np_f, self.A, np.dtype('float64'))
        self.run_test_by_entries(pd_f, np_f, self.A, np.dtype('float64'))

    def test_mul(self):
        def pd_f(xs):
            x, y = xs
            return paddle.multiply(x, y)

        def np_f(xs):
            x, y = xs
            return np.multiply(x, y)

        self.run_test_by_fullmatrix(pd_f, np_f, [self.B, self.C],
                                    np.dtype('float64'))
        self.run_test_by_rows(pd_f, np_f, [self.B, self.C], np.dtype('float64'))
        self.run_test_by_entries(pd_f, np_f, [self.B, self.C],
                                 np.dtype('float64'))

    def test_matmul(self):
        def pd_f(xs):
            x, y = xs
            return paddle.matmul(x, y)

        def np_f(xs):
            x, y = xs
            return np.matmul(x, y)

        self.run_test_by_fullmatrix(pd_f, np_f, [self.B, self.C],
                                    np.dtype('float64'))
        self.run_test_by_rows(pd_f, np_f, [self.B, self.C], np.dtype('float64'))
        self.run_test_by_entries(pd_f, np_f, [self.B, self.C],
                                 np.dtype('float64'))

    def test_batch_matmul(self):
        def pd_f(xs):
            x, y = xs
            return paddle.matmul(x, y)

        def np_f(xs):
            x, y = xs
            return np.matmul(x, y)

        self.run_test_by_fullmatrix(
            pd_f, np_f, [self.D, self.E], np.dtype('float64'), batch=True)
        self.run_test_by_rows(
            pd_f, np_f, [self.D, self.E], np.dtype('float64'), batch=True)
        self.run_test_by_entries(
            pd_f, np_f, [self.D, self.E], np.dtype('float64'), batch=True)

if __name__ == "__main__":
    unittest.main()
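
# Layout note for the batch case (an illustrative sketch; the arrays mirror
# the shapes of D and E above): with batch=True the Jacobian is arranged as
# [ydim, batch, xdim], which is what approx_jacobian's final
# transpose([1, 0, 2]) produces.
#
#     xs = [np.random.rand(2, 2, 2).astype('float64') for _ in range(2)]
#     jac = approx_jacobian(lambda vs: np.matmul(vs[0], vs[1]), xs,
#                           np.dtype('float64'), eps=1e-7, batch=True)
#     assert jac.shape == (4, 2, 8)   # ydim=4, batch=2, xdim=8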