Unverified commit 9e764d82, authored by Xiaoxu Chen, committed by GitHub

Enhance vjp/jvp/Jacobian/Hessian APIs to support dynamic graph, static graph, and batched/unbatched modes (#40692)

* modify vjp/jvp to support both dynamic and static graphs

* enhance the Jacobian class to support a batch dimension in the first or last axis

* add unittests for jvp, Jacobian with last-axis batch, and Jacobian with first-axis batch

* fix the incorrect shape when indexing a Jacobian with multiple indices

* enhance the Hessian class to support dynamic graph

* add Hessian class unittest

* bugfix: in jvp's double-backward trick, zeros_like returned stop_gradient=True in static graph

* add API beta warnings

* add white_list for CUDA 11.x Windows CI.

* polish some code snippets and documentation

* set unittest timeout to 100 seconds

* move vjp,jvp,Jacobian,Hessian to incubate

* fix vjp, jvp import paths in sample code

* fix code style error in autograd/__init__ file
Parent ab8c33b1
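A minimal usage sketch of the reworked API in dynamic-graph, unbatched mode. The paddle.incubate.autograd.Jacobian path, the shape property, and the lazy indexing follow this commit's diff; the concrete function, shapes, and variable names below are only illustrative.

import paddle

def func(x):
    return paddle.matmul(x, x)

x = paddle.rand([2, 2], dtype='float64')
x.stop_gradient = False

J = paddle.incubate.autograd.Jacobian(func, x)  # unbatched: shape is (output_size, input_size)
print(J.shape)  # expected (4, 4) for a 2x2 input
full = J[:]     # rows are computed on demand rather than eagerly
row0 = J[0, :]  # a single row can be fetched without building the whole matrix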
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -13,12 +13,18 @@
# limitations under the License.
from ..fluid.dygraph.base import grad # noqa: F401
from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
from ..framework import is_grad_enabled, set_grad_enabled # noqa: F401
from . import backward_mode # noqa: F401
from .backward_mode import backward # noqa: F401
from .py_layer import PyLayer, PyLayerContext, EagerPyLayer, EagerPyLayerContext # noqa: F401
from ..framework import set_grad_enabled, is_grad_enabled # noqa: F401
from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
from .functional import jacobian, hessian, batch_jacobian, batch_hessian # noqa: F401
from .functional import vjp, jvp, vhp # noqa: F401
from .functional import vjp, jvp, Jacobian, Hessian # noqa: F401
from .functional import jacobian, hessian, batch_jacobian, batch_hessian, vhp # noqa: F401
__all__ = ['backward', 'PyLayer', 'PyLayerContext']
__all__ = [ # noqa
'backward',
'PyLayer',
'PyLayerContext',
]
......@@ -6,6 +6,5 @@ foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach(TEST_OP)
set_tests_properties(test_jacobian PROPERTIES TIMEOUT 50)
set_tests_properties(test_hessian PROPERTIES TIMEOUT 50)
set_tests_properties(test_vhp PROPERTIES TIMEOUT 50)
set_tests_properties(test_autograd_functional_dynamic PROPERTIES TIMEOUT 100)
set_tests_properties(test_autograd_functional_static PROPERTIES TIMEOUT 100)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
DEVICES = [paddle.CPUPlace()]
if paddle.is_compiled_with_cuda():
DEVICES.append(paddle.CUDAPlace(0))
def _tensors(ts, name):
if isinstance(ts, (list, tuple)):
assert len(ts) > 0, "{} cannot be empty".format(name)
for each_t in ts:
assert isinstance(
each_t, paddle.Tensor
) or each_t is None, "Elements of {} must be paddle.Tensor or None".format(
name)
return list(ts)
else:
assert isinstance(ts, paddle.Tensor), "{} must be Tensor".format(name)
return [ts]
def _stack_tensor_or_return_none(origin_list):
assert len(origin_list) > 0, "Cannot stack an empty list"
return paddle.stack(
origin_list, axis=0) if isinstance(origin_list[0],
paddle.Tensor) else None
DEFAULT_DTYPE = 'float64'
def _replace_none_with_zero_tensor(t, spec_t):
if t is None:
zero_t = paddle.zeros(shape=spec_t.shape, dtype=spec_t.dtype)
zero_t.stop_gradient = spec_t.stop_gradient
return zero_t
else:
return t
# Numerical tolerances for different dtypes and different orders of derivatives.
# These are empirical values provided by the Paddle Science team.
TOLERANCE = {
"float32": {
"first_order_grad": {
"rtol": 1e-3,
"atol": 1e-3,
"eps": 1e-4
},
"second_order_grad": {
"rtol": 1e-2,
"atol": 1e-2,
"eps": 1e-2
}
},
"float64": {
"first_order_grad": {
"rtol": 1e-7,
"atol": 1e-7,
"eps": 1e-7
},
"second_order_grad": {
"rtol": 1e-5,
"atol": 1e-5,
"eps": 1e-5
}
}
}
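# Illustrative lookup, mirroring how the static and dynamic tests consume this
# table: pick the dtype first, then the derivative order, then the key.
# (The variable names below are only for illustration.)
_fp32_first_order_rtol = TOLERANCE['float32']['first_order_grad']['rtol']  # 1e-3
_fp64_second_order_atol = TOLERANCE['float64']['second_order_grad']['atol']  # 1e-5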
......@@ -12,17 +12,137 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from utils import _compute_numerical_jacobian, _compute_numerical_batch_jacobian
import config
import utils
from utils import (_compute_numerical_batch_jacobian,
_compute_numerical_jacobian)
from paddle.autograd.functional import _as_tensors
paddle.enable_static()
@utils.place(config.DEVICES)
@utils.parameterize((utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'stop_gradient'), (
('tensor_input', utils.reduce, np.random.rand(2, 3), None, False),
('tensor_sequence_input', utils.reduce, np.random.rand(2, 3), None, False),
('v_not_none', utils.reduce, np.random.rand(2, 3), np.random.rand(1),
False),
('xs_stop_gradient', utils.reduce, np.random.rand(2, 3), np.random.rand(1),
True),
('func_matmul', utils.matmul, (np.random.rand(3, 2), np.random.rand(2, 3)),
None, False),
('func_mul', utils.mul, (np.random.rand(3, 3), np.random.rand(3, 3)), None,
False),
('func_out_two', utils.o2, (np.random.rand(10), np.random.rand(10)), None,
False), ))
class TestVJP(unittest.TestCase):
def setUp(self):
self.dtype = str(self.xs[0].dtype) if isinstance(
self.xs, typing.Sequence) else str(self.xs.dtype)
self._rtol = config.TOLERANCE.get(str(self.dtype)).get(
"first_order_grad").get("rtol")
self._atol = config.TOLERANCE.get(str(self.dtype)).get(
"first_order_grad").get("atol")
def _vjp(self):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
feed, static_xs, static_v = gen_static_data_and_feed(
self.xs, self.v, stop_gradient=self.stop_gradient)
ys, xs_grads = paddle.autograd.vjp(self.fun, static_xs, static_v)
exe.run(sp)
return exe.run(mp, feed=feed, fetch_list=[ys, xs_grads])
def _expected_vjp(self):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
feed, static_xs, static_v = gen_static_data_and_feed(self.xs,
self.v, False)
ys = self.fun(*static_xs) if isinstance(
static_xs, typing.Sequence) else self.fun(static_xs)
xs_grads = paddle.static.gradients(ys, static_xs, static_v)
exe.run(sp)
return exe.run(mp, feed=feed, fetch_list=[ys, xs_grads])
def test_vjp(self):
actual = self._vjp()
expected = self._expected_vjp()
self.assertEqual(len(actual), len(expected))
for i in range(len(actual)):
np.testing.assert_allclose(
actual[i], expected[i], rtol=self._rtol, atol=self._atol)
@utils.place(config.DEVICES)
@utils.parameterize(
(utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'expected_exception'), (
('v_shape_not_equal_ys', utils.square, np.random.rand(3),
np.random.rand(1), RuntimeError), ))
class TestVJPException(unittest.TestCase):
def setUp(self):
self.exe = paddle.static.Executor()
def _vjp(self):
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
feed, static_xs, static_v = gen_static_data_and_feed(self.xs,
self.v)
ys, xs_grads = paddle.autograd.vjp(self.fun, static_xs, static_v)
self.exe.run(sp)
return self.exe.run(mp, feed, fetch_list=[ys, xs_grads])
def test_vjp(self):
with self.assertRaises(self.expected_exception):
self._vjp()
def gen_static_data_and_feed(xs, v, stop_gradient=True):
feed = {}
if isinstance(xs, typing.Sequence):
static_xs = []
for i, x in enumerate(xs):
x = paddle.static.data(f"x{i}", x.shape, x.dtype)
x.stop_gradient = stop_gradient
static_xs.append(x)
feed.update({f'x{idx}': value for idx, value in enumerate(xs)})
else:
static_xs = paddle.static.data('x', xs.shape, xs.dtype)
static_xs.stop_gradient = stop_gradient
feed.update({'x': xs})
if isinstance(v, typing.Sequence):
static_v = []
for i, e in enumerate(v):
e = paddle.static.data(f'v{i}', e.shape, e.dtype)
e.stop_gradient = stop_gradient
static_v.append(e)
feed.update({f'v{idx}': value for idx, value in enumerate(v)})
elif v is not None:
static_v = paddle.static.data('v', v.shape, v.dtype)
static_v.stop_gradient = stop_gradient
feed.update({'v': v})
else:
static_v = v
return feed, static_xs, static_v
def approx_jacobian(f, xs, dtype, eps=1e-5, batch=False):
r"""Computes an approximate Jacobian matrix of a multi-valued function
using finite differences.
The function input is required to be an np array or a list of list of np
arrays.
"""
......@@ -106,8 +226,13 @@ class TestJacobianFloat32(unittest.TestCase):
else:
self.place = fluid.CPUPlace()
self.dtype = 'float32'
self.np_dtype = np.float32
prepare_data(self, all_data_shapes, self.dtype)
self.eps = 1e-4
self.eps = config.TOLERANCE.get(self.dtype).get('first_order_grad').get(
'eps')
# self.rtol = config.TOLERANCE.get(self.dtype).get('first_order_grad').get('rtol')
# self.atol = config.TOLERANCE.get(self.dtype).get('first_order_grad').get('atol')
# Don't use the tolerances from config here; they would cause this test case to fail.
self.rtol = 1e-2
self.atol = 1e-2
......@@ -116,8 +241,11 @@ class TestJacobianFloat32(unittest.TestCase):
startup = fluid.Program()
with fluid.program_guard(main, startup):
xs = make_tensors(inps)
JJ = paddle.autograd.functional.Jacobian(pd_f, xs, batch=batch)
nrow, ncol = JJ.shape()
JJ = paddle.autograd.functional.Jacobian(pd_f, xs, is_batched=batch)
if batch:
_, nrow, ncol = JJ.shape
else:
nrow, ncol = JJ.shape
full_jacobian = JJ[:]
exe = fluid.Executor(self.place)
exe.run(startup)
......@@ -128,17 +256,26 @@ class TestJacobianFloat32(unittest.TestCase):
pd_jacobians = exe.run(main, feed=feeds, fetch_list=[full_jacobian])[0]
np_jacobians = approx_jacobian(
np_f, inps, self.dtype, self.eps, batch=batch)
self.assertTrue(
np.allclose(pd_jacobians, np_jacobians, self.rtol, self.atol))
if batch:
np_jacobians = utils._np_transpose_matrix_format(
np_jacobians, utils.MatrixFormat.NBM, utils.MatrixFormat.BNM)
np.testing.assert_allclose(pd_jacobians, np_jacobians, self.rtol,
self.atol)
def run_test_by_rows(self, pd_f, np_f, inps, batch=False):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
xs = make_tensors(inps)
JJ = paddle.autograd.functional.Jacobian(pd_f, xs, batch=batch)
nrow, ncol = JJ.shape()
rows = [JJ[i] for i in range(nrow)]
JJ = paddle.autograd.functional.Jacobian(pd_f, xs, is_batched=batch)
if batch:
nbatch, nrow, ncol = JJ.shape
rows = [JJ[:, i, :] for i in range(nrow)]
else:
nrow, ncol = JJ.shape
rows = [JJ[i, :] for i in range(nrow)]
exe = fluid.Executor(self.place)
exe.run(startup)
if isinstance(inps, list):
......@@ -148,17 +285,23 @@ class TestJacobianFloat32(unittest.TestCase):
pd_jac = exe.run(main, feed=feeds, fetch_list=[rows])
np_jac = approx_jacobian(np_f, inps, self.dtype, self.eps, batch=batch)
for i in range(nrow):
self.assertTrue(
np.allclose(pd_jac[i], np_jac[i], self.rtol, self.atol))
np.testing.assert_allclose(pd_jac[i], np_jac[i], self.rtol,
self.atol)
def run_test_by_entries(self, pd_f, np_f, inps, batch=False):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
xs = make_tensors(inps)
JJ = paddle.autograd.functional.Jacobian(pd_f, xs, batch=batch)
nrow, ncol = JJ.shape()
entries = [JJ[i, j] for i in range(nrow) for j in range(ncol)]
JJ = paddle.autograd.functional.Jacobian(pd_f, xs, is_batched=batch)
if batch:
nbatch, nrow, ncol = JJ.shape
entries = [
JJ[:, i, j] for i in range(nrow) for j in range(ncol)
]
else:
nrow, ncol = JJ.shape
entries = [JJ[i, j] for i in range(nrow) for j in range(ncol)]
exe = fluid.Executor(self.place)
exe.run(startup)
if isinstance(inps, list):
......@@ -171,8 +314,7 @@ class TestJacobianFloat32(unittest.TestCase):
np_jac[i, ..., j] for i in range(nrow) for j in range(ncol)
]
for pd_entry, np_entry in zip(pd_entries, np_entries):
self.assertTrue(
np.allclose(pd_entry, np_entry, self.rtol, self.atol))
np.testing.assert_allclose(pd_entry, np_entry, self.rtol, self.atol)
def test_square(self):
def pd_f(x):
......@@ -186,8 +328,7 @@ class TestJacobianFloat32(unittest.TestCase):
self.run_test_by_entries(pd_f, np_f, self.A)
def test_mul(self):
def pd_f(xs):
x, y = xs
def pd_f(x, y):
return paddle.multiply(x, y)
def np_f(xs):
......@@ -202,8 +343,7 @@ class TestJacobianFloat32(unittest.TestCase):
self.run_test_by_entries(pd_f, np_f, [self.B, self.C])
def test_matmul(self):
def pd_f(xs):
x, y = xs
def pd_f(x, y):
return paddle.matmul(x, y)
def np_f(xs):
......@@ -215,8 +355,7 @@ class TestJacobianFloat32(unittest.TestCase):
self.run_test_by_entries(pd_f, np_f, [self.B, self.C])
def test_batch_matmul(self):
def pd_f(xs):
x, y = xs
def pd_f(x, y):
return paddle.matmul(x, y)
def np_f(xs):
......@@ -238,12 +377,15 @@ class TestJacobianFloat64(TestJacobianFloat32):
self.place = fluid.CPUPlace()
self.dtype = 'float64'
prepare_data(self, all_data_shapes, self.dtype)
self.eps = 1e-7
self.rtol = 1e-6
self.atol = 1e-6
self.eps = config.TOLERANCE.get(self.dtype).get('first_order_grad').get(
'eps')
self.rtol = config.TOLERANCE.get(self.dtype).get(
'first_order_grad').get('rtol')
self.atol = config.TOLERANCE.get(self.dtype).get(
'first_order_grad').get('atol')
class TestHessianFloat64(unittest.TestCase):
class TestHessianFloat32(unittest.TestCase):
@classmethod
def setUpClass(self):
paddle.enable_static()
......@@ -251,19 +393,22 @@ class TestHessianFloat64(unittest.TestCase):
self.place = fluid.CUDAPlace(0)
else:
self.place = fluid.CPUPlace()
self.dtype = 'float64'
self.dtype = 'float32'
prepare_data(self, all_data_shapes, self.dtype)
self.eps = 1e-7
self.rtol = 1e-6
self.atol = 1e-6
self.eps = config.TOLERANCE.get(self.dtype).get(
'second_order_grad').get('eps')
self.rtol = config.TOLERANCE.get(self.dtype).get(
'second_order_grad').get('rtol')
self.atol = config.TOLERANCE.get(self.dtype).get(
'second_order_grad').get('atol')
def run_test_by_fullmatrix(self, pd_f, inps, np_hess, batch=False):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
xs = make_tensors(inps)
HH = paddle.autograd.functional.Hessian(pd_f, xs, batch=batch)
nrow, ncol = HH.shape()
HH = paddle.autograd.functional.Hessian(pd_f, xs, is_batched=batch)
nrow, ncol = HH.shape
full_hessian = HH[:]
exe = fluid.Executor(self.place)
exe.run(startup)
......@@ -272,36 +417,38 @@ class TestHessianFloat64(unittest.TestCase):
else:
feeds = {'x': inps}
pd_hess = exe.run(main, feed=feeds, fetch_list=[full_hessian])[0]
self.assertTrue(np.allclose(pd_hess, np_hess, self.rtol, self.atol))
np.testing.assert_allclose(pd_hess, np_hess, self.rtol, self.atol)
def test_square(self):
def pd_f(x):
"""Input is a square matrix."""
return paddle.matmul(x, x.T)
return paddle.matmul(x, x.T).flatten().sum()
def np_hess(x):
dim = x.shape[0]
f_xx_upperleft = 2 * np.eye(dim, dtype=self.dtype)
f_xx = np.zeros([dim * dim, dim * dim], dtype=self.dtype)
f_xx[:dim, :dim] = f_xx_upperleft
return f_xx
upperleft = 2 * np.eye(dim, dtype=self.dtype)
upper = np.concatenate((upperleft, upperleft))
return np.concatenate((upper, upper), axis=1)
self.run_test_by_fullmatrix(pd_f, self.B, np_hess(self.B))
def test_batch_square(self):
def pd_f(x):
"""Input is a square matrix."""
return paddle.matmul(x, paddle.transpose(x, [0, 2, 1]))
def np_hess(x):
bat, dim, _ = x.shape
f_xx_upperleft = 2 * np.eye(dim, dtype=self.dtype)
f_xx = np.zeros([bat, dim * dim, dim * dim], dtype=self.dtype)
f_xx[..., :dim, :dim] = f_xx_upperleft
return f_xx
self.run_test_by_fullmatrix(
pd_f, self.E, np_hess(self.E), batch=True)
class TestHessianFloat64(TestHessianFloat32):
@classmethod
def setUpClass(self):
paddle.enable_static()
if fluid.core.is_compiled_with_cuda():
self.place = fluid.CUDAPlace(0)
else:
self.place = fluid.CPUPlace()
self.dtype = 'float64'
prepare_data(self, all_data_shapes, self.dtype)
self.eps = config.TOLERANCE.get(self.dtype).get(
'second_order_grad').get('eps')
self.rtol = config.TOLERANCE.get(self.dtype).get(
'second_order_grad').get('rtol')
self.atol = config.TOLERANCE.get(self.dtype).get(
'second_order_grad').get('atol')
if __name__ == "__main__":
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.compat as cpt
import paddle.nn.functional as F
from utils import _compute_numerical_hessian, _compute_numerical_batch_hessian
class TestHessian(unittest.TestCase):
@classmethod
def setUpClass(self):
self.shape = (2, 2)
self.dtype = 'float32'
self.np_dtype = np.float32
self.numerical_delta = 1e-2
self.rtol = 1e-2
self.atol = 1e-2
self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
def test_single_input(self):
def func(x):
return paddle.sum(paddle.matmul(x, x))
numerical_hessian = _compute_numerical_hessian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
hessian = paddle.autograd.hessian(func, self.x)
assert np.allclose(hessian.numpy(), numerical_hessian[0][0], self.rtol,
self.atol)
def test_multi_input(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, y))
numerical_hessian = _compute_numerical_hessian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
hessian = paddle.autograd.hessian(func, [self.x, self.y])
for i in range(len(hessian)):
for j in range(len(hessian[0])):
assert np.allclose(hessian[i][j].numpy(),
numerical_hessian[i][j], self.rtol,
self.atol)
def test_allow_unused_false(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, x))
try:
self.x.stop_gradient = False
self.y.stop_gradient = False
hessian = paddle.autograd.hessian(func, [self.x, self.y])
except ValueError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("allow_unused") > 0
def test_allow_unused_true(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, x))
numerical_hessian = _compute_numerical_hessian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
hessian = paddle.autograd.hessian(
func, [self.x, self.y], allow_unused=True)
for i in range(len(hessian)):
for j in range(len(hessian[0])):
if i == j == 0:
assert np.allclose(hessian[i][j].numpy(),
numerical_hessian[i][j], self.rtol,
self.atol)
else:
assert hessian[i][j] is None
def test_create_graph_false(self):
def func(x):
return paddle.sum(paddle.matmul(x, x))
numerical_hessian = _compute_numerical_hessian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
hessian = paddle.autograd.hessian(func, self.x)
assert hessian.stop_gradient == True
assert np.allclose(hessian.numpy(), numerical_hessian[0][0], self.rtol,
self.atol)
try:
paddle.grad(hessian, self.x)
except RuntimeError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("has no gradient") > 0
def test_create_graph_true(self):
def func(x):
return paddle.sum(F.sigmoid(x))
numerical_hessian = _compute_numerical_hessian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
hessian = paddle.autograd.hessian(func, self.x, create_graph=True)
assert hessian.stop_gradient == False
assert np.allclose(hessian.numpy(), numerical_hessian[0][0], self.rtol,
self.atol)
triple_grad = paddle.grad(hessian, self.x)
assert triple_grad is not None
class TestHessianFloat64(TestHessian):
@classmethod
def setUpClass(self):
self.shape = (2, 2)
self.dtype = 'float64'
self.np_dtype = np.float64
self.numerical_delta = 1e-5
self.rtol = 1e-5
self.atol = 1e-5
self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
class TestBatchHessian(unittest.TestCase):
@classmethod
def setUpClass(self):
self.x_shape = (5, 2)
self.weight_shape = (2, 4)
self.y_shape = (5, 2)
self.dtype = 'float32'
self.np_dtype = np.float32
self.numerical_delta = 1e-2
self.rtol = 1e-3
self.atol = 1e-3
self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype)
self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype)
def test_single_input(self):
def func(x):
return paddle.matmul(x * x, self.weight)[:, 0:1]
numerical_hessian = _compute_numerical_batch_hessian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
hessian = paddle.autograd.batch_hessian(func, self.x, create_graph=True)
assert np.allclose(hessian, numerical_hessian, self.rtol, self.atol)
def test_multi_input(self):
def func(x, y):
return paddle.matmul(x * x * y * y, self.weight)[:, 0:1]
numerical_hessian = _compute_numerical_batch_hessian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
hessian = paddle.autograd.batch_hessian(func, [self.x, self.y])
shape_tensor = paddle.to_tensor(numerical_hessian).astype("float64")
hessian_reshape = np.reshape(hessian, (shape_tensor.shape))
assert np.allclose(hessian_reshape, numerical_hessian, self.rtol,
self.atol)
def test_allow_unused_false(self):
def func(x, y):
return paddle.matmul(x * x, self.weight)[:, 0:1]
try:
self.x.stop_gradient = False
self.y.stop_gradient = False
hessian = paddle.autograd.batch_hessian(func, [self.x, self.y])
except ValueError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("allow_unused") > 0
def test_allow_unused_true(self):
def func(x, y):
return paddle.matmul(x * x, self.weight)[:, 0:1]
numerical_hessian = _compute_numerical_batch_hessian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
hessian = paddle.autograd.batch_hessian(
func, [self.x, self.y], allow_unused=True)
for i in range(len(hessian)):
for j in range(len(hessian[0])):
if i == j == 0:
numerical_hessian = np.stack(
(numerical_hessian[i][j], numerical_hessian[i][j + 1]),
axis=0)
assert np.allclose(hessian[i][j], numerical_hessian,
self.rtol, self.atol)
else:
assert hessian[i][j] is None
def test_create_graph_false(self):
def func(x):
return paddle.matmul(x * x, self.weight)[:, 0:1]
numerical_hessian = _compute_numerical_batch_hessian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
hessian = paddle.autograd.batch_hessian(func, self.x)
assert hessian.stop_gradient == True
assert np.allclose(hessian.numpy(), numerical_hessian, self.rtol,
self.atol)
try:
paddle.grad(hessian, self.x)
except RuntimeError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("has no gradient") > 0
def test_create_graph_true(self):
def func(x):
return paddle.matmul(x * x, self.weight)[:, 0:1]
numerical_hessian = _compute_numerical_batch_hessian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
hessian = paddle.autograd.batch_hessian(func, self.x, create_graph=True)
assert hessian.stop_gradient == False
assert np.allclose(hessian.numpy(), numerical_hessian, self.rtol,
self.atol)
triple_grad = paddle.grad(hessian, self.x)
assert triple_grad is not None
class TestBatchHessianFloat64(TestBatchHessian):
@classmethod
def setUpClass(self):
self.x_shape = (5, 2)
self.weight_shape = (2, 4)
self.y_shape = (5, 2)
self.dtype = 'float64'
self.np_dtype = np.float64
self.numerical_delta = 1e-4
self.rtol = 1e-5
self.atol = 1e-5
self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype)
self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.compat as cpt
from utils import _compute_numerical_jacobian, _compute_numerical_batch_jacobian
class TestJacobian(unittest.TestCase):
@classmethod
def setUpClass(self):
self.shape = (4, 4)
self.dtype = 'float32'
self.np_dtype = np.float32
self.numerical_delta = 1e-4
self.rtol = 1e-3
self.atol = 1e-3
self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
def test_single_input_and_single_output(self):
def func(x):
return paddle.matmul(x, x)
numerical_jacobian = _compute_numerical_jacobian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
jacobian = paddle.autograd.jacobian(func, self.x)
assert np.allclose(jacobian.numpy(), numerical_jacobian[0][0],
self.rtol, self.atol)
def test_single_input_and_multi_output(self):
def func(x):
return paddle.matmul(x, x), x * x
numerical_jacobian = _compute_numerical_jacobian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
jacobian = paddle.autograd.jacobian(func, self.x)
for i in range(len(jacobian)):
assert np.allclose(jacobian[i].numpy(), numerical_jacobian[i][0],
self.rtol, self.atol)
def test_multi_input_and_single_output(self):
def func(x, y):
return paddle.matmul(x, y)
numerical_jacobian = _compute_numerical_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.jacobian(func, [self.x, self.y])
for j in range(len(jacobian)):
assert np.allclose(jacobian[j].numpy(), numerical_jacobian[0][j],
self.rtol, self.atol)
def test_multi_input_and_multi_output(self):
def func(x, y):
return paddle.matmul(x, y), x * y
numerical_jacobian = _compute_numerical_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.jacobian(func, [self.x, self.y])
for i in range(len(jacobian)):
for j in range(len(jacobian[0])):
assert np.allclose(jacobian[i][j].numpy(),
numerical_jacobian[i][j], self.rtol,
self.atol)
def test_allow_unused_false(self):
def func(x, y):
return paddle.matmul(x, x)
try:
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.jacobian(func, [self.x, self.y])
except ValueError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("allow_unused") > 0
def test_allow_unused_true(self):
def func(x, y):
return paddle.matmul(x, x)
numerical_jacobian = _compute_numerical_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.jacobian(
func, [self.x, self.y], allow_unused=True)
assert np.allclose(jacobian[0].numpy(), numerical_jacobian[0][0],
self.rtol, self.atol)
assert jacobian[1] is None
def test_create_graph_false(self):
def func(x, y):
return paddle.matmul(x, y)
numerical_jacobian = _compute_numerical_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.jacobian(func, [self.x, self.y])
for j in range(len(jacobian)):
assert jacobian[j].stop_gradient == True
assert np.allclose(jacobian[j].numpy(), numerical_jacobian[0][j],
self.rtol, self.atol)
try:
paddle.grad(jacobian[0], [self.x, self.y])
except RuntimeError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("has no gradient") > 0
def test_create_graph_true(self):
def func(x, y):
return paddle.matmul(x, y)
numerical_jacobian = _compute_numerical_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.jacobian(
func, [self.x, self.y], create_graph=True)
for j in range(len(jacobian)):
assert jacobian[j].stop_gradient == False
assert np.allclose(jacobian[j].numpy(), numerical_jacobian[0][j],
self.rtol, self.atol)
double_grad = paddle.grad(jacobian[0], [self.x, self.y])
assert double_grad is not None
class TestJacobianFloat64(TestJacobian):
@classmethod
def setUpClass(self):
self.shape = (4, 4)
self.dtype = 'float64'
self.np_dtype = np.float64
self.numerical_delta = 1e-7
self.rtol = 1e-7
self.atol = 1e-7
self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
class TestJacobianBatch(unittest.TestCase):
@classmethod
def setUpClass(self):
self.x_shape = (4, 2)
self.weight_shape = (2, 4)
self.y_shape = (4, 2)
self.dtype = 'float32'
self.np_dtype = np.float32
self.numerical_delta = 1e-4
self.rtol = 1e-3
self.atol = 1e-3
self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype)
self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype)
def test_batch_single_input_and_batch_single_output(self):
def func(x):
return paddle.matmul(paddle.matmul(x, self.weight), self.y)
numerical_jacobian = _compute_numerical_batch_jacobian(
func, [self.x], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
batch_jacobian = paddle.autograd.batch_jacobian(
func,
self.x, )
self.assertTrue(
np.allclose(batch_jacobian.numpy().all(), numerical_jacobian[0][0]
.all()))
def test_batch_single_input_and_batch_multi_output(self):
def func(x):
return paddle.matmul(paddle.matmul(x, self.weight), self.y), x * x
numerical_jacobian = _compute_numerical_batch_jacobian(
func, [self.x], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
batch_jacobian = paddle.autograd.batch_jacobian(
func,
self.x, )
for i in range(len(batch_jacobian)):
assert np.allclose(batch_jacobian[i].numpy(),
numerical_jacobian[i][0], self.rtol, self.atol)
def test_batch_multi_input_and_batch_single_output(self):
def func(x, y):
return x * y
numerical_jacobian = _compute_numerical_batch_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
batch_jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y])
for j in range(len(batch_jacobian)):
assert np.allclose(batch_jacobian[j].numpy(),
numerical_jacobian[0][j], self.rtol, self.atol)
def test_batch_multi_input_and_batch_multi_output(self):
def func(x, y):
return x * y, x * y
numerical_jacobian = _compute_numerical_batch_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
batch_jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y])
for i in range(len(batch_jacobian)):
assert np.allclose(batch_jacobian[i], numerical_jacobian[i],
self.rtol, self.atol)
def test_allow_unused_false(self):
def func(x, y):
return x * x
try:
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y])
except ValueError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("allow_unused") > 0
def test_allow_unused_true(self):
def func(x, y):
return x * x
numerical_jacobian = _compute_numerical_batch_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.batch_jacobian(
func, [self.x, self.y], allow_unused=True)
assert np.allclose(jacobian[0].numpy(), numerical_jacobian[0][0],
self.rtol, self.atol)
assert jacobian[1] is None
def test_create_graph_false(self):
def func(x, y):
return x * y
numerical_jacobian = _compute_numerical_batch_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y])
for j in range(len(jacobian)):
assert jacobian[j].stop_gradient == True
assert np.allclose(jacobian[j].numpy(), numerical_jacobian[0][j],
self.rtol, self.atol)
try:
paddle.grad(jacobian[0], [self.x, self.y])
except RuntimeError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("has no gradient") > 0
def test_create_graph_true(self):
def func(x, y):
return x * y
numerical_jacobian = _compute_numerical_batch_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.batch_jacobian(
func, [self.x, self.y], create_graph=True)
for j in range(len(jacobian)):
assert jacobian[j].stop_gradient == False
assert np.allclose(jacobian[j].numpy(), numerical_jacobian[0][j],
self.rtol, self.atol)
double_grad = paddle.grad(jacobian[0], [self.x, self.y])
assert double_grad is not None
class TestJacobianBatchFloat64(TestJacobianBatch):
@classmethod
def setUpClass(self):
self.x_shape = (12, 2)
self.weight_shape = (2, 12)
self.y_shape = (12, 2)
self.dtype = 'float64'
self.np_dtype = np.float64
self.numerical_delta = 1e-7
self.rtol = 1e-7
self.atol = 1e-7
self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype)
self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.compat as cpt
import paddle.nn.functional as F
from utils import _compute_numerical_vhp
class TestVHP(unittest.TestCase):
@classmethod
def setUpClass(self):
self.shape = (2, 2)
self.dtype = 'float32'
self.np_dtype = np.float32
self.numerical_delta = 1e-2
self.rtol = 1e-2
self.atol = 1e-2
self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
self.vx = paddle.rand(shape=self.shape, dtype=self.dtype)
self.vy = paddle.rand(shape=self.shape, dtype=self.dtype)
def test_single_input(self):
def func(x):
return paddle.sum(paddle.matmul(x, x))
numerical_func_output = func(self.x).numpy()
numerical_vhp = _compute_numerical_vhp(
func, self.x, self.vx, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
func_output, vhp = paddle.autograd.vhp(func, self.x, self.vx)
assert np.allclose(func_output.numpy(), numerical_func_output,
self.rtol, self.atol)
assert np.allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol,
self.atol)
def test_multi_input(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, y))
numerical_func_output = func(self.x, self.y).numpy()
numerical_vhp = _compute_numerical_vhp(
func, [self.x, self.y], [self.vx, self.vy], self.numerical_delta,
self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
func_output, vhp = paddle.autograd.vhp(func, [self.x, self.y],
[self.vx, self.vy])
assert np.allclose(func_output.numpy(), numerical_func_output,
self.rtol, self.atol)
for i in range(len(vhp)):
assert np.allclose(vhp[i].numpy(), numerical_vhp[i], self.rtol,
self.atol)
def test_v_default(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, y))
numerical_func_output = func(self.x, self.y).numpy()
vx = paddle.ones(self.vx.shape, dtype=self.vx.dtype)
vy = paddle.ones(self.vy.shape, dtype=self.vy.dtype)
numerical_vhp = _compute_numerical_vhp(func, [self.x, self.y],
[vx, vy], self.numerical_delta,
self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
func_output, vhp = paddle.autograd.vhp(func, [self.x, self.y])
assert np.allclose(func_output.numpy(), numerical_func_output,
self.rtol, self.atol)
for i in range(len(vhp)):
assert np.allclose(vhp[i].numpy(), numerical_vhp[i], self.rtol,
self.atol)
def test_allow_unused_false(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, x))
try:
self.x.stop_gradient = False
self.y.stop_gradient = False
_ = paddle.autograd.vhp(func, [self.x, self.y])
except ValueError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("allow_unused") > 0
def test_allow_unused_true(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, x))
numerical_func_output = func(self.x, self.y).numpy()
numerical_vhp = _compute_numerical_vhp(
func, [self.x, self.y], [self.vx, self.vy], self.numerical_delta,
self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
func_output, vhp = paddle.autograd.vhp(func, [self.x, self.y],
[self.vx, self.vy],
allow_unused=True)
assert np.allclose(func_output.numpy(), numerical_func_output,
self.rtol, self.atol)
assert np.allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol,
self.atol)
assert vhp[1] is None
def test_create_graph_false(self):
def func(x):
return paddle.sum(F.sigmoid(x))
numerical_func_output = func(self.x).numpy()
numerical_vhp = _compute_numerical_vhp(
func, self.x, self.vx, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
func_output, vhp = paddle.autograd.vhp(func, self.x, self.vx)
assert np.allclose(func_output.numpy(), numerical_func_output,
self.rtol, self.atol)
assert vhp[0].stop_gradient == True
assert np.allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol,
self.atol)
try:
paddle.grad(vhp, self.x)
except RuntimeError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("has no gradient") > 0
def test_create_graph_true(self):
def func(x):
return paddle.sum(F.sigmoid(x))
numerical_func_output = func(self.x).numpy()
numerical_vhp = _compute_numerical_vhp(
func, self.x, self.vx, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
func_output, vhp = paddle.autograd.vhp(func,
self.x,
self.vx,
create_graph=True)
assert np.allclose(func_output.numpy(), numerical_func_output,
self.rtol, self.atol)
assert vhp[0].stop_gradient == False
assert np.allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol,
self.atol)
triple_grad = paddle.grad(vhp, self.x)
assert triple_grad is not None
class TestVHPFloat64(TestVHP):
@classmethod
def setUpClass(self):
self.shape = (2, 2)
self.dtype = 'float64'
self.np_dtype = np.float64
self.numerical_delta = 1e-5
self.rtol = 1e-5
self.atol = 1e-5
self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
self.vx = paddle.rand(shape=self.shape, dtype=self.dtype)
self.vy = paddle.rand(shape=self.shape, dtype=self.dtype)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle.autograd.functional import vjp, jvp, _tensors
from paddle import grad, ones_like, zeros_like
def reduce(x):
return paddle.sum(x)
def reduce_dim(x):
return paddle.sum(x, axis=0)
def matmul(x, y):
return paddle.matmul(x, y)
def mul(x, y):
return x * y
def pow(x, y):
return paddle.pow(x, y)
def o2(x, y):
return paddle.multiply(x, y), paddle.matmul(x, y.t())
def unuse(x, y):
return paddle.sum(x)
def nested(x):
def inner(y):
return x * y
return inner
def make_v(f, inputs):
outputs = _tensors(f(*inputs), "outputs")
return [ones_like(x) for x in outputs]
class TestAutogradFunctional(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.RAW_INPUTS = {
'a': [1.0],
'b': [1.0, 2.0],
'c': [3.0, 4.0],
'd': [[2.0], [3.0]],
'A': [[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]],
'B': [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]],
}
def setUp(self):
pass
def gen_input(self, inp, stop_gradient=False):
if isinstance(inp, paddle.Tensor):
return inp
return paddle.to_tensor(
self.RAW_INPUTS[inp], stop_gradient=stop_gradient)
def gen_inputs(self, inputs):
if isinstance(inputs, list):
inputs = [self.gen_input(x) for x in inputs]
else:
inputs = [self.gen_input(inputs)]
return inputs
def gen_test_pairs(self,
func,
inputs,
v=None,
create_graph=False,
allow_unused=False):
def vjp_test():
nonlocal v
xs = self.gen_inputs(inputs)
if v is not None:
v = self.gen_inputs(v)
outputs, inputs_grad = vjp(func,
xs,
v,
create_graph=create_graph,
allow_unused=allow_unused)
else:
outputs, inputs_grad = vjp(func,
xs,
create_graph=create_graph,
allow_unused=allow_unused)
return outputs, inputs_grad
def grad_test():
nonlocal v
xs = self.gen_inputs(inputs)
if v is not None:
v = self.gen_inputs(v)
outputs = func(*xs)
if v is not None:
inputs_grad = grad(
outputs,
xs,
v,
create_graph=create_graph,
allow_unused=allow_unused)
else:
inputs_grad = grad(
outputs,
xs,
create_graph=create_graph,
allow_unused=allow_unused)
return outputs, inputs_grad
return vjp_test, grad_test
def gen_jvp_tests(self,
func,
inputs,
v=None,
create_graph=False,
allow_unused=False):
def jvp_test():
nonlocal v
xs = self.gen_inputs(inputs)
if v is not None:
v = self.gen_inputs(v)
outputs, outputs_grad = jvp(func,
xs,
v,
create_graph=create_graph,
allow_unused=allow_unused)
else:
outputs, outputs_grad = jvp(func,
xs,
create_graph=create_graph,
allow_unused=allow_unused)
return outputs, outputs_grad
return jvp_test
def check_results(self, ref, res):
type_error = 'Result is different than expected in shape or type'
value_error = 'Result is different than expected values'
if ref is None:
self.assertTrue(res is None, type_error)
elif isinstance(ref, paddle.Tensor):
self.assertTrue(isinstance(res, paddle.Tensor), type_error)
self.assertTrue(paddle.allclose(res, ref), value_error)
else:
self.assertTrue(len(res) == len(ref), type_error)
for i in range(len(ref)):
self.check_results(ref[i], res[i])
return True
class TestVJP(TestAutogradFunctional):
def test_vjp_i1o1_no_create_graph(self):
test_cases = [
[reduce, 'A'], #noqa
[reduce_dim, 'A'], #noqa
] #noqa
for f, inputs in test_cases:
vjp, grad = self.gen_test_pairs(f, inputs)
vjp_result, grad_result = vjp(), grad()
self.check_results(grad_result, vjp_result)
def test_vjp_i2o1_no_create_graph(self):
test_cases = [
[matmul, ['A', 'B']], #noqa
[mul, ['b', 'c']], #noqa
] #noqa
for f, inputs in test_cases:
vjp, grad = self.gen_test_pairs(f, inputs)
vjp_result, grad_result = vjp(), grad()
self.check_results(grad_result, vjp_result)
def test_vjp_i2o2_no_create_graph(self):
test_cases = [
[o2, ['A', 'A']], #noqa
] #noqa
for f, inputs in test_cases:
inputs = self.gen_inputs(inputs)
v = make_v(f, inputs)
vjp, grad = self.gen_test_pairs(f, inputs, v=v)
vjp_result, grad_result = vjp(), grad()
self.check_results(grad_result, vjp_result)
def test_vjp_i2o2_omitting_v_no_create_graph(self):
test_cases = [
[o2, ['A', 'A']], #noqa
] #noqa
for f, inputs in test_cases:
inputs = self.gen_inputs(inputs)
vjp, grad = self.gen_test_pairs(f, inputs)
vjp_result, grad_result = vjp(), grad()
self.check_results(grad_result, vjp_result)
def test_vjp_nested_no_create_graph(self):
x = self.gen_input('a')
test_cases = [
[nested(x), 'a'], #noqa
]
for f, inputs in test_cases:
vjp, grad = self.gen_test_pairs(f, inputs)
vjp_result, grad_result = vjp(), grad()
self.check_results(grad_result, vjp_result)
def test_vjp_aliased_input_no_create_graph(self):
x = self.gen_input('a')
ref = self.gen_test_pairs(nested(x), 'a')[0]
aliased = self.gen_test_pairs(nested(x), x)[0]
ref_result, aliased_result = ref(), aliased()
self.check_results(ref_result, aliased_result)
def test_vjp_allowunused_no_create_graph(self):
x, y = self.gen_input('A'), self.gen_input('a')
vjp, grad = self.gen_test_pairs(unuse, [x, y], allow_unused=True)
vjp_result, grad_result = vjp(), grad()
self.check_results(grad_result, vjp_result)
def jac(grad_fn, f, inputs):
assert grad_fn in [vjp, jvp]
if grad_fn is jvp:
vs = [zeros_like(x) for x in inputs]
else:
outputs = f(*inputs)
if isinstance(outputs, paddle.Tensor):
outputs = [outputs]
vs = [zeros_like(y) for y in outputs]
JJ_cols = []
for i, v in enumerate(vs):
v = v.flatten()
for j in range(len(v)):
_v = zeros_like(v).detach()
_v[j] = 1.0
_v = _v.reshape(vs[i].shape)
_vs = vs.copy()
_vs[i] = _v
_, grads = grad_fn(f, inputs, _vs)
d_outs = paddle.concat([d_out.flatten() for d_out in grads])
JJ_cols.append(d_outs)
# JJ is the fully unrolled jacobian
JJ = paddle.stack(JJ_cols)
if grad_fn is vjp:
JJ = JJ.t()
return JJ
class TestJVP(TestAutogradFunctional):
def test_jvp_i1o1_no_create_graph(self):
test_cases = [
[reduce, 'A'], #noqa
[reduce_dim, 'A'], #noqa
] #noqa
for f, inputs in test_cases:
inputs = self.gen_inputs(inputs)
forward_jac = jac(jvp, f, inputs)
reverse_jac = jac(vjp, f, inputs)
self.check_results(forward_jac, reverse_jac)
def test_jvp_i2o1_no_create_graph(self):
test_cases = [ #noqa
[matmul, ['A', 'B']], #noqa
] #noqa
for f, inputs in test_cases:
inputs = self.gen_inputs(inputs)
forward_jac = jac(jvp, f, inputs)
reverse_jac = jac(vjp, f, inputs)
self.check_results(forward_jac, reverse_jac)
def test_jvp_i2o2_no_create_graph(self):
test_cases = [ #noqa
[o2, ['A', 'A']], #noqa
] #noqa
for f, inputs in test_cases:
inputs = self.gen_inputs(inputs)
forward_jac = jac(jvp, f, inputs)
reverse_jac = jac(vjp, f, inputs)
self.check_results(forward_jac, reverse_jac)
def test_jvp_i2o2_omitting_v_no_create_graph(self):
test_cases = [ #noqa
[o2, ['A', 'A']], #noqa
] #noqa
for f, inputs in test_cases:
inputs = self.gen_inputs(inputs)
results_omitting_v = jvp(f, inputs)
v = [ones_like(x) for x in inputs]
results_with_v = jvp(f, inputs, v)
self.check_results(results_omitting_v, results_with_v)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
import enum
import sys
import re
import inspect
import functools
import contextlib
import collections
import numpy as np
import paddle
from paddle.autograd.functional import _tensors
from paddle.autograd.functional import _as_tensors
##########################################################
# Finite Difference Utils
##########################################################
def _product(t):
if isinstance(t, int):
return t
......@@ -25,7 +36,9 @@ def _product(t):
def _get_item(t, idx):
assert isinstance(t, paddle.Tensor), "The first argument t must be Tensor."
assert isinstance(
t,
paddle.fluid.framework.Variable), "The first argument t must be Tensor."
assert isinstance(idx,
int), "The second argument idx must be an int number."
flat_t = paddle.reshape(t, [-1])
......@@ -33,7 +46,9 @@ def _get_item(t, idx):
def _set_item(t, idx, value):
assert isinstance(t, paddle.Tensor), "The first argument t must be Tensor."
assert isinstance(
t,
paddle.fluid.framework.Variable), "The first argument t must be Tensor."
assert isinstance(idx,
int), "The second argument idx must be an int number."
flat_t = paddle.reshape(t, [-1])
......@@ -42,8 +57,8 @@ def _set_item(t, idx, value):
def _compute_numerical_jacobian(func, xs, delta, np_dtype):
xs = _tensors(xs, "xs")
ys = _tensors(func(*xs), "ys")
xs = list(_as_tensors(xs))
ys = list(_as_tensors(func(*xs)))
fin_size = len(xs)
fout_size = len(ys)
jacobian = list([] for _ in range(fout_size))
......@@ -59,11 +74,11 @@ def _compute_numerical_jacobian(func, xs, delta, np_dtype):
orig = _get_item(xs[j], q)
x_pos = orig + delta
xs[j] = _set_item(xs[j], q, x_pos)
ys_pos = _tensors(func(*xs), "ys_pos")
ys_pos = _as_tensors(func(*xs))
x_neg = orig - delta
xs[j] = _set_item(xs[j], q, x_neg)
ys_neg = _tensors(func(*xs), "ys_neg")
ys_neg = _as_tensors(func(*xs))
xs[j] = _set_item(xs[j], q, orig)
......@@ -76,8 +91,8 @@ def _compute_numerical_jacobian(func, xs, delta, np_dtype):
def _compute_numerical_hessian(func, xs, delta, np_dtype):
xs = _tensors(xs, "xs")
ys = _tensors(func(*xs), "ys")
xs = list(_as_tensors(xs))
ys = list(_as_tensors(func(*xs)))
fin_size = len(xs)
hessian = list([] for _ in range(fin_size))
for i in range(fin_size):
......@@ -107,10 +122,22 @@ def _compute_numerical_hessian(func, xs, delta, np_dtype):
return hessian
def _compute_numerical_batch_jacobian(func, xs, delta, np_dtype):
def concat_to_matrix(xs, is_batched=False):
"""Concats a tuple of tuple of Jacobian/Hessian matrix into one matrix"""
rows = []
for i in range(len(xs)):
rows.append(np.concatenate([x for x in xs[i]], -1))
return np.concatenate(rows, 1) if is_batched else np.concatenate(rows, 0)
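# Illustrative check of concat_to_matrix above (shapes are hypothetical): a 2x2
# grid of per-(output, input) Jacobian blocks is stacked column-wise within each
# row of blocks and then row-wise into one dense matrix.
_blocks = ((np.ones((2, 3)), np.ones((2, 4))),
           (np.zeros((5, 3)), np.zeros((5, 4))))
assert concat_to_matrix(_blocks).shape == (7, 7)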
def _compute_numerical_batch_jacobian(func,
xs,
delta,
np_dtype,
merge_batch=True):
no_batch_jacobian = _compute_numerical_jacobian(func, xs, delta, np_dtype)
xs = _tensors(xs, "xs")
ys = _tensors(func(*xs), "ys")
xs = list(_as_tensors(xs))
ys = list(_as_tensors(func(*xs)))
fin_size = len(xs)
fout_size = len(ys)
bs = xs[0].shape[0]
......@@ -128,7 +155,8 @@ def _compute_numerical_batch_jacobian(func, xs, delta, np_dtype):
for b in range(bs):
for q in range(in_size):
batch_jac_i_j[p][b][q] = jac[b][p][b][q]
batch_jac_i_j = np.reshape(batch_jac_i_j, (out_size, -1))
if merge_batch:
batch_jac_i_j = np.reshape(batch_jac_i_j, (out_size, -1))
batch_jac_i.append(batch_jac_i_j)
bat_jac.append(batch_jac_i)
......@@ -136,7 +164,7 @@ def _compute_numerical_batch_jacobian(func, xs, delta, np_dtype):
def _compute_numerical_batch_hessian(func, xs, delta, np_dtype):
xs = _tensors(xs, "xs")
xs = list(_as_tensors(xs))
batch_size = xs[0].shape[0]
fin_size = len(xs)
hessian = []
......@@ -175,8 +203,10 @@ def _compute_numerical_batch_hessian(func, xs, delta, np_dtype):
def _compute_numerical_vjp(func, xs, v, delta, np_dtype):
xs = _tensors(xs, "xs")
xs = _as_tensors(xs)
jacobian = np.array(_compute_numerical_jacobian(func, xs, delta, np_dtype))
if v is None:
v = [paddle.ones_like(x) for x in xs]
flat_v = np.array([v_el.numpy().reshape(-1) for v_el in v])
vjp = [np.zeros((_product(x.shape)), dtype=np_dtype) for x in xs]
for j in range(len(xs)):
......@@ -188,7 +218,7 @@ def _compute_numerical_vjp(func, xs, v, delta, np_dtype):
def _compute_numerical_vhp(func, xs, v, delta, np_dtype):
xs = _tensors(xs, "xs")
xs = list(_as_tensors(xs))
hessian = np.array(_compute_numerical_hessian(func, xs, delta, np_dtype))
flat_v = np.array([v_el.numpy().reshape(-1) for v_el in v])
vhp = [np.zeros((_product(x.shape)), dtype=np_dtype) for x in xs]
......@@ -198,3 +228,166 @@ def _compute_numerical_vhp(func, xs, v, delta, np_dtype):
flat_v)
vhp = [vhp[j].reshape(xs[j].shape) for j in range(len(xs))]
return vhp
##########################################################
# Test cases of different functions.
##########################################################
def reduce(x):
return paddle.sum(x)
def reduce_dim(x):
return paddle.sum(x, axis=0)
def matmul(x, y):
return paddle.matmul(x, y)
def mul(x, y):
return x * y
def pow(x, y):
return paddle.pow(x, y)
def o2(x, y):
return paddle.multiply(x, y), paddle.matmul(x, y.t())
def unuse(x, y):
return paddle.sum(x)
def nested(x):
def inner(y):
return x * y
return inner
def square(x):
return x * x
##########################################################
# Parameterized Test Utils.
##########################################################
TEST_CASE_NAME = 'suffix'
def place(devices, key='place'):
"""A Decorator for a class which will make the class running on different
devices .
Args:
devices (Sequence[Paddle.CUDAPlace|Paddle.CPUPlace]): Device list.
key (str, optional): Defaults to 'place'.
"""
def decorate(cls):
module = sys.modules[cls.__module__].__dict__
raw_classes = {
k: v
for k, v in module.items() if k.startswith(cls.__name__)
}
for raw_name, raw_cls in raw_classes.items():
for d in devices:
test_cls = dict(raw_cls.__dict__)
test_cls.update({key: d})
new_name = raw_name + '.' + d.__class__.__name__
module[new_name] = type(new_name, (raw_cls, ), test_cls)
del module[raw_name]
return cls
return decorate
def parameterize(fields, values=None):
"""Decorator for a unittest class which make the class running on different
test cases.
Args:
fields (Sequence): The field name sequence of test cases.
values (Sequence, optional): The test cases sequence. Defaults to None.
"""
fields = [fields] if isinstance(fields, str) else fields
params = [dict(zip(fields, vals)) for vals in values]
def decorate(cls):
test_cls_module = sys.modules[cls.__module__].__dict__
for i, values in enumerate(params):
test_cls = dict(cls.__dict__)
values = {
k: staticmethod(v) if callable(v) else v
for k, v in values.items()
}
test_cls.update(values)
name = cls.__name__ + str(i)
name = name + '.' + \
values.get('suffix') if values.get('suffix') else name
test_cls_module[name] = type(name, (cls, ), test_cls)
for m in list(cls.__dict__):
if m.startswith("test"):
delattr(cls, m)
return cls
return decorate
##########################################################
# Utils for transposing between different Jacobian/Hessian matrix formats.
##########################################################
# B is batch size, N is row size, M is column size.
MatrixFormat = enum.Enum('MatrixFormat', ('NBM', 'BNM', 'NMB', 'NM'))
def _np_transpose_matrix_format(src, src_format, des_format):
"""Transpose Jacobian/Hessian matrix format."""
supported_format = (MatrixFormat.NBM, MatrixFormat.BNM, MatrixFormat.NMB)
if src_format not in supported_format or des_format not in supported_format:
raise ValueError(
f"Supported Jacobian format is {supported_format}, but got src: {src_format}, des: {des_format}"
)
src_axis = {c: i for i, c in enumerate(src_format.name)}
dst_axis = tuple(src_axis[c] for c in des_format.name)
return np.transpose(src, dst_axis)
def _np_concat_matrix_sequence(src, src_format=MatrixFormat.NM):
"""Convert a sequence of sequence of Jacobian/Hessian matrix into one huge
matrix."""
def concat_col(xs):
if src_format in (MatrixFormat.NBM, MatrixFormat.BNM, MatrixFormat.NM):
return np.concatenate(xs, axis=-1)
else:
return np.concatenate(xs, axis=1)
def concat_row(xs):
if src_format in (MatrixFormat.NBM, MatrixFormat.NM, MatrixFormat.NMB):
return np.concatenate(xs, axis=0)
else:
return np.concatenate(xs, axis=1)
supported_format = (MatrixFormat.NBM, MatrixFormat.BNM, MatrixFormat.NMB,
MatrixFormat.NM)
if src_format not in supported_format:
raise ValueError(
f"Supported Jacobian format is {supported_format}, but got {src_format}"
)
if not isinstance(src, typing.Sequence):
return src
if not isinstance(src[0], typing.Sequence):
src = [src]
return concat_row(tuple(concat_col(xs) for xs in src))
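# Illustrative use of the helpers above (hypothetical shapes): a Jacobian laid
# out as (rows, batch, cols) is moved to the batch-first layout used by the
# is_batched=True Jacobian/Hessian classes.
_j_nbm = np.zeros((3, 4, 5))  # N=3 rows, B=4 batch entries, M=5 cols
_j_bnm = _np_transpose_matrix_format(_j_nbm, MatrixFormat.NBM, MatrixFormat.BNM)
assert _j_bnm.shape == (4, 3, 5)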
......@@ -26,6 +26,7 @@ from .tensor import segment_mean
from .tensor import segment_max
from .tensor import segment_min
from .passes import fuse_resnet_unit_pass
import paddle.incubate.autograd
from . import nn #noqa: F401
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.autograd.functional import Hessian, Jacobian, jvp, vjp
__all__ = [ # noqa
'vjp', 'jvp', 'Jacobian', 'Hessian'
]
......@@ -273,6 +273,7 @@ packages=['paddle',
'paddle.distributed.ps',
'paddle.distributed.ps.utils',
'paddle.incubate',
'paddle.incubate.autograd',
'paddle.incubate.optimizer',
'paddle.incubate.checkpoint',
'paddle.incubate.operators',
......
......@@ -12,55 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
set +x
NIGHTLY_MODE=$1
PRECISION_TEST=$2
WITH_GPU=$3
export PADDLE_ROOT="$(cd "$PWD/../" && pwd )"
if [ ${NIGHTLY_MODE:-OFF} == "ON" ]; then
nightly_label=""
else
nightly_label="(RUN_TYPE=NIGHTLY|RUN_TYPE=DIST:NIGHTLY|RUN_TYPE=EXCLUSIVE:NIGHTLY)"
echo "========================================="
echo "Unittests with nightly labels are only run at night"
echo "========================================="
fi
if disable_ut_quickly=$(python ${PADDLE_ROOT}/tools/get_quick_disable_lt.py); then
echo "========================================="
echo "The following unittests have been disabled:"
echo ${disable_ut_quickly}
echo "========================================="
else
disable_ut_quickly=''
fi
# check added ut
set +e
cp $PADDLE_ROOT/tools/check_added_ut.sh $PADDLE_ROOT/tools/check_added_ut_win.sh
bash $PADDLE_ROOT/tools/check_added_ut_win.sh
rm -rf $PADDLE_ROOT/tools/check_added_ut_win.sh
if [ -f "$PADDLE_ROOT/added_ut" ];then
added_uts=^$(awk BEGIN{RS=EOF}'{gsub(/\n/,"$|^");print}' $PADDLE_ROOT/added_ut)$
ctest -R "(${added_uts})" --output-on-failure -C Release --repeat-until-fail 3;added_ut_error=$?
rm -f $PADDLE_ROOT/added_ut
if [ "$added_ut_error" != 0 ];then
echo "========================================"
echo "Added UT should pass three additional executions"
echo "========================================"
exit 8;
fi
if nvcc --version | grep 11.2; then
echo "Only test added_ut temporarily when running in CI-Windows-inference of CUDA 11.2."
exit 0;
fi
fi
set -e
# /*==================Fixed Disabled Windows GPU MKL unittests==============================*/
# /*================Fixed Disabled Windows CUDA10.x MKL(PR-CI-Windows) unittests===========================*/
# TODO: fix these unittests that are bound to fail
disable_wingpu_test="^test_model$|\
^test_dataloader_early_reset$|\
......@@ -97,7 +50,7 @@ disable_wingpu_test="^test_model$|\
^test_bilinear_interp_op$|\
^disable_wingpu_test$"
# /*==================Fixed Disabled Windows GPU MKL unittests==============================*/
# /*=================Fixed Disabled Windows TRT MKL unittests=======================*/
# TODO: fix these unittest that is bound to fail
disable_win_trt_test="^test_trt_convert_conv2d$|\
^test_trt_convert_conv2d_fusion$|\
......@@ -119,7 +72,13 @@ disable_win_trt_test="^test_trt_convert_conv2d$|\
^test_trt_convert_matmul$|\
^test_trt_convert_scale$"
# /*==================Fixed Disabled Windows GPU inference_api_test unittests==============================*/
# /*=============Fixed Disabled Windows CUDA11.x MKL(PR-CI-Windows-Inference) unittests=================*/
# TODO: fix these unittests that are bound to fail
disable_wingpu11_test="^test_autograd_functional_dynamic$|\
^disable_wingpu_test$"
# /*==========Fixed Disabled Windows CUDA11.x inference_api_test(PR-CI-Windows-Inference) unittests=============*/
disable_win_inference_api_test="^trt_quant_int8_yolov3_r50_test$|\
^test_trt_dynamic_shape_ernie$|\
^test_trt_dynamic_shape_ernie_fp16_ser_deser$|\
......@@ -128,9 +87,8 @@ disable_win_inference_api_test="^trt_quant_int8_yolov3_r50_test$|\
^lite_mul_model_test$|\
^paddle_infer_api_copy_tensor_tester$"
# /*============================================================================*/
# /*==================Fixed Disabled Windows CPU OPENBLAS unittests==============================*/
# /*==========Fixed Disabled Windows CPU OPENBLAS((PR-CI-Windows-OPENBLAS)) unittests==============================*/
# TODO: fix these unittests that are bound to fail
disable_wincpu_test="^jit_kernel_test$|\
^test_analyzer_transformer$|\
......@@ -189,6 +147,58 @@ long_time_test="^test_gru_op$|\
^test_trt_matmul_quant_dequant$|\
^test_strided_slice_op$"
# /*============================================================================*/
set -e
set +x
NIGHTLY_MODE=$1
PRECISION_TEST=$2
WITH_GPU=$3
export PADDLE_ROOT="$(cd "$PWD/../" && pwd )"
if [ ${NIGHTLY_MODE:-OFF} == "ON" ]; then
nightly_label=""
else
nightly_label="(RUN_TYPE=NIGHTLY|RUN_TYPE=DIST:NIGHTLY|RUN_TYPE=EXCLUSIVE:NIGHTLY)"
echo "========================================="
echo "Unittests with nightly labels are only run at night"
echo "========================================="
fi
if disable_ut_quickly=$(python ${PADDLE_ROOT}/tools/get_quick_disable_lt.py); then
echo "========================================="
echo "The following unittests have been disabled:"
echo ${disable_ut_quickly}
echo "========================================="
else
disable_ut_quickly=''
fi
# check added ut
set +e
cp $PADDLE_ROOT/tools/check_added_ut.sh $PADDLE_ROOT/tools/check_added_ut_win.sh
bash $PADDLE_ROOT/tools/check_added_ut_win.sh
rm -rf $PADDLE_ROOT/tools/check_added_ut_win.sh
if [ -f "$PADDLE_ROOT/added_ut" ];then
added_uts=^$(awk BEGIN{RS=EOF}'{gsub(/\n/,"$|^");print}' $PADDLE_ROOT/added_ut)$
ctest -R "(${added_uts})" -E "$disable_wingpu11_test" --output-on-failure -C Release --repeat-until-fail 3;added_ut_error=$?
rm -f $PADDLE_ROOT/added_ut
if [ "$added_ut_error" != 0 ];then
echo "========================================"
echo "Added UT should pass three additional executions"
echo "========================================"
exit 8;
fi
if nvcc --version | grep 11.2; then
echo "Only test added_ut temporarily when running in CI-Windows-inference of CUDA 11.2."
exit 0;
fi
fi
set -e
if [ ${WITH_GPU:-OFF} == "ON" ];then
export CUDA_VISIBLE_DEVICES=0
......