未验证 提交 9e764d82 编写于 作者: X Xiaoxu Chen 提交者: GitHub

Enhance vjp/jvp/Jacobian/Hessian API for supporting dynamic, static graph and...

Enhance vjp/jvp/Jacobian/Hessian API for supporting dynamic, static graph and batched, unbatched mode (#40692)

* modify vjp/jvp for both dynamic and static graph

* enforce jacobian class for supporting first/last batch

* add unittest for jvp, jacobian withlast batch, jacobian with first batch

* fix the incorrect shape when multi-index Jacobian

* enforce Hessian class for supporting dynamic graph

* add Hessian class unittest

* bugfix, jvp double_backward_trick zeros_like return stop_gradient=True in static graph

* add API beta warnnings

* add white_list for cuda11.x ci windows.

* optimize some code snippets and documments

* set unittest timeout to 100 seconds

* move vjp,jvp,Jacobian,Hessian to incubate

* fix vjp,vjp import path of sample code

* fix code style error of augtograd/__init__ file
上级 ab8c33b1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -13,12 +13,18 @@ ...@@ -13,12 +13,18 @@
# limitations under the License. # limitations under the License.
from ..fluid.dygraph.base import grad # noqa: F401 from ..fluid.dygraph.base import grad # noqa: F401
from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
from ..framework import is_grad_enabled, set_grad_enabled # noqa: F401
from . import backward_mode # noqa: F401 from . import backward_mode # noqa: F401
from .backward_mode import backward # noqa: F401 from .backward_mode import backward # noqa: F401
from .py_layer import PyLayer, PyLayerContext, EagerPyLayer, EagerPyLayerContext # noqa: F401 from .py_layer import PyLayer, PyLayerContext, EagerPyLayer, EagerPyLayerContext # noqa: F401
from ..framework import set_grad_enabled, is_grad_enabled # noqa: F401 from ..framework import set_grad_enabled, is_grad_enabled # noqa: F401
from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401 from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
from .functional import jacobian, hessian, batch_jacobian, batch_hessian # noqa: F401 from .functional import vjp, jvp, Jacobian, Hessian # noqa: F401
from .functional import vjp, jvp, vhp # noqa: F401 from .functional import jacobian, hessian, batch_jacobian, batch_hessian, vhp # noqa: F401
__all__ = ['backward', 'PyLayer', 'PyLayerContext'] __all__ = [ # noqa
'backward',
'PyLayer',
'PyLayerContext',
]
...@@ -6,6 +6,5 @@ foreach(TEST_OP ${TEST_OPS}) ...@@ -6,6 +6,5 @@ foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach(TEST_OP) endforeach(TEST_OP)
set_tests_properties(test_jacobian PROPERTIES TIMEOUT 50) set_tests_properties(test_autograd_functional_dynamic PROPERTIES TIMEOUT 100)
set_tests_properties(test_hessian PROPERTIES TIMEOUT 50) set_tests_properties(test_autograd_functional_static PROPERTIES TIMEOUT 100)
set_tests_properties(test_vhp PROPERTIES TIMEOUT 50)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import paddle import paddle
DEVICES = [paddle.CPUPlace()]
if paddle.is_compiled_with_cuda():
DEVICES.append(paddle.CUDAPlace(0))
def _tensors(ts, name): DEFAULT_DTYPE = 'float64'
if isinstance(ts, (list, tuple)):
assert len(ts) > 0, "{} connot be empty".format(name)
for each_t in ts:
assert isinstance(
each_t, paddle.Tensor
) or each_t is None, "Elements of {} must be paddle.Tensor or None".format(
name)
return list(ts)
else:
assert isinstance(ts, paddle.Tensor), "{} must be Tensor".format(name)
return [ts]
def _stack_tensor_or_return_none(origin_list):
assert len(origin_list) > 0, "Can't not stack an empty list"
return paddle.stack(
origin_list, axis=0) if isinstance(origin_list[0],
paddle.Tensor) else None
def _replace_none_with_zero_tensor(t, spec_t): # The numerical tolerance of different dtype of different order different
if t is None: # derivative. It's a empirical value provided by Paddle Science team.
zero_t = paddle.zeros(shape=spec_t.shape, dtype=spec_t.dtype) TOLERANCE = {
zero_t.stop_gradient = spec_t.stop_gradient "float32": {
return zero_t "first_order_grad": {
else: "rtol": 1e-3,
return t "atol": 1e-3,
"eps": 1e-4
},
"second_order_grad": {
"rtol": 1e-2,
"atol": 1e-2,
"eps": 1e-2
}
},
"float64": {
"first_order_grad": {
"rtol": 1e-7,
"atol": 1e-7,
"eps": 1e-7
},
"second_order_grad": {
"rtol": 1e-5,
"atol": 1e-5,
"eps": 1e-5
}
}
}
...@@ -12,17 +12,137 @@ ...@@ -12,17 +12,137 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import typing
import unittest import unittest
import numpy as np import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from utils import _compute_numerical_jacobian, _compute_numerical_batch_jacobian
import config
import utils
from utils import (_compute_numerical_batch_jacobian,
_compute_numerical_jacobian)
from paddle.autograd.functional import _as_tensors
paddle.enable_static()
@utils.place(config.DEVICES)
@utils.parameterize((utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'stop_gradient'), (
('tensor_input', utils.reduce, np.random.rand(2, 3), None, False),
('tensor_sequence_input', utils.reduce, np.random.rand(2, 3), None, False),
('v_not_none', utils.reduce, np.random.rand(2, 3), np.random.rand(1),
False),
('xs_stop_gradient', utils.reduce, np.random.rand(2, 3), np.random.rand(1),
True),
('func_mutmul', utils.matmul, (np.random.rand(3, 2), np.random.rand(2, 3)),
None, False),
('func_mul', utils.mul, (np.random.rand(3, 3), np.random.rand(3, 3)), None,
False),
('func_out_two', utils.o2, (np.random.rand(10), np.random.rand(10)), None,
False), ))
class TestVJP(unittest.TestCase):
def setUp(self):
self.dtype = str(self.xs[0].dtype) if isinstance(
self.xs, typing.Sequence) else str(self.xs.dtype)
self._rtol = config.TOLERANCE.get(str(self.dtype)).get(
"first_order_grad").get("rtol")
self._atol = config.TOLERANCE.get(str(self.dtype)).get(
"first_order_grad").get("atol")
def _vjp(self):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
feed, static_xs, static_v = gen_static_data_and_feed(
self.xs, self.v, stop_gradient=self.stop_gradient)
ys, xs_grads = paddle.autograd.vjp(self.fun, static_xs, static_v)
exe.run(sp)
return exe.run(mp, feed=feed, fetch_list=[ys, xs_grads])
def _expected_vjp(self):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
feed, static_xs, static_v = gen_static_data_and_feed(self.xs,
self.v, False)
ys = self.fun(*static_xs) if isinstance(
static_xs, typing.Sequence) else self.fun(static_xs)
xs_grads = paddle.static.gradients(ys, static_xs, static_v)
exe.run(sp)
return exe.run(mp, feed=feed, fetch_list=[ys, xs_grads])
def test_vjp(self):
actual = self._vjp()
expected = self._expected_vjp()
self.assertEqual(len(actual), len(expected))
for i in range(len(actual)):
np.testing.assert_allclose(
actual[i], expected[i], rtol=self._rtol, atol=self._atol)
@utils.place(config.DEVICES)
@utils.parameterize(
(utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'expected_exception'), (
('v_shape_not_equal_ys', utils.square, np.random.rand(3),
np.random.rand(1), RuntimeError), ))
class TestVJPException(unittest.TestCase):
def setUp(self):
self.exe = paddle.static.Executor()
def _vjp(self):
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
feed, static_xs, static_v = gen_static_data_and_feed(self.xs,
self.v)
ys, xs_grads = paddle.autograd.vjp(self.fun, static_xs, static_v)
self.exe.run(sp)
return self.exe.run(mp, feed, fetch_list=[ys, xs_grads])
def test_vjp(self):
with self.assertRaises(self.expected_exception):
self._vjp()
def gen_static_data_and_feed(xs, v, stop_gradient=True):
feed = {}
if isinstance(xs, typing.Sequence):
static_xs = []
for i, x in enumerate(xs):
x = paddle.static.data(f"x{i}", x.shape, x.dtype)
x.stop_gradient = stop_gradient
static_xs.append(x)
feed.update({f'x{idx}': value for idx, value in enumerate(xs)})
else:
static_xs = paddle.static.data('x', xs.shape, xs.dtype)
static_xs.stop_gradient = stop_gradient
feed.update({'x': xs})
if isinstance(v, typing.Sequence):
static_v = []
for i, e in enumerate(v):
e = paddle.static.data(f'v{idx}', v.shape, v.dtype)
e.stop_gradient = stop_gradient
static_v.append(e)
feed.update({f'v{idx}': value for idx, value in v})
elif v is not None:
static_v = paddle.static.data('v', v.shape, v.dtype)
static_v.stop_gradient = stop_gradient
feed.update({'v': v})
else:
static_v = v
return feed, static_xs, static_v
def approx_jacobian(f, xs, dtype, eps=1e-5, batch=False): def approx_jacobian(f, xs, dtype, eps=1e-5, batch=False):
r"""Computes an approximate Jacobian matrix of a multi-valued function r"""Computes an approximate Jacobian matrix of a multi-valued function
using finite differences. using finite differences.
The function input is required to be an np array or a list of list of np The function input is required to be an np array or a list of list of np
arrays. arrays.
""" """
...@@ -106,8 +226,13 @@ class TestJacobianFloat32(unittest.TestCase): ...@@ -106,8 +226,13 @@ class TestJacobianFloat32(unittest.TestCase):
else: else:
self.place = fluid.CPUPlace() self.place = fluid.CPUPlace()
self.dtype = 'float32' self.dtype = 'float32'
self.np_dtype = np.float32
prepare_data(self, all_data_shapes, self.dtype) prepare_data(self, all_data_shapes, self.dtype)
self.eps = 1e-4 self.eps = config.TOLERANCE.get(self.dtype).get('first_order_grad').get(
'eps')
# self.rtol = config.TOLERANCE.get(self.dtype).get('first_order_grad').get('rtol')
# self.atol = config.TOLERANCE.get(self.dtype).get('first_order_grad').get('atol')
# Do't use tolerance in config, which will cause this test case failed.
self.rtol = 1e-2 self.rtol = 1e-2
self.atol = 1e-2 self.atol = 1e-2
...@@ -116,8 +241,11 @@ class TestJacobianFloat32(unittest.TestCase): ...@@ -116,8 +241,11 @@ class TestJacobianFloat32(unittest.TestCase):
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
xs = make_tensors(inps) xs = make_tensors(inps)
JJ = paddle.autograd.functional.Jacobian(pd_f, xs, batch=batch) JJ = paddle.autograd.functional.Jacobian(pd_f, xs, is_batched=batch)
nrow, ncol = JJ.shape() if batch:
_, nrow, ncol = JJ.shape
else:
nrow, ncol = JJ.shape
full_jacobian = JJ[:] full_jacobian = JJ[:]
exe = fluid.Executor(self.place) exe = fluid.Executor(self.place)
exe.run(startup) exe.run(startup)
...@@ -128,17 +256,26 @@ class TestJacobianFloat32(unittest.TestCase): ...@@ -128,17 +256,26 @@ class TestJacobianFloat32(unittest.TestCase):
pd_jacobians = exe.run(main, feed=feeds, fetch_list=[full_jacobian])[0] pd_jacobians = exe.run(main, feed=feeds, fetch_list=[full_jacobian])[0]
np_jacobians = approx_jacobian( np_jacobians = approx_jacobian(
np_f, inps, self.dtype, self.eps, batch=batch) np_f, inps, self.dtype, self.eps, batch=batch)
self.assertTrue( if batch:
np.allclose(pd_jacobians, np_jacobians, self.rtol, self.atol)) np_jacobians = utils._np_transpose_matrix_format(
np_jacobians, utils.MatrixFormat.NBM, utils.MatrixFormat.BNM)
np.testing.assert_allclose(pd_jacobians, np_jacobians, self.rtol,
self.atol)
def run_test_by_rows(self, pd_f, np_f, inps, batch=False): def run_test_by_rows(self, pd_f, np_f, inps, batch=False):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
xs = make_tensors(inps) xs = make_tensors(inps)
JJ = paddle.autograd.functional.Jacobian(pd_f, xs, batch=batch) JJ = paddle.autograd.functional.Jacobian(pd_f, xs, is_batched=batch)
nrow, ncol = JJ.shape() if batch:
rows = [JJ[i] for i in range(nrow)] nbatch, nrow, ncol = JJ.shape
rows = [JJ[:, i, :] for i in range(nrow)]
else:
nrow, ncol = JJ.shape
rows = [JJ[i, :] for i in range(nrow)]
exe = fluid.Executor(self.place) exe = fluid.Executor(self.place)
exe.run(startup) exe.run(startup)
if isinstance(inps, list): if isinstance(inps, list):
...@@ -148,17 +285,23 @@ class TestJacobianFloat32(unittest.TestCase): ...@@ -148,17 +285,23 @@ class TestJacobianFloat32(unittest.TestCase):
pd_jac = exe.run(main, feed=feeds, fetch_list=[rows]) pd_jac = exe.run(main, feed=feeds, fetch_list=[rows])
np_jac = approx_jacobian(np_f, inps, self.dtype, self.eps, batch=batch) np_jac = approx_jacobian(np_f, inps, self.dtype, self.eps, batch=batch)
for i in range(nrow): for i in range(nrow):
self.assertTrue( np.testing.assert_allclose(pd_jac[i], np_jac[i], self.rtol,
np.allclose(pd_jac[i], np_jac[i], self.rtol, self.atol)) self.atol)
def run_test_by_entries(self, pd_f, np_f, inps, batch=False): def run_test_by_entries(self, pd_f, np_f, inps, batch=False):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
xs = make_tensors(inps) xs = make_tensors(inps)
JJ = paddle.autograd.functional.Jacobian(pd_f, xs, batch=batch) JJ = paddle.autograd.functional.Jacobian(pd_f, xs, is_batched=batch)
nrow, ncol = JJ.shape() if batch:
entries = [JJ[i, j] for i in range(nrow) for j in range(ncol)] nbatch, nrow, ncol = JJ.shape
entries = [
JJ[:, i, j] for i in range(nrow) for j in range(ncol)
]
else:
nrow, ncol = JJ.shape
entries = [JJ[i, j] for i in range(nrow) for j in range(ncol)]
exe = fluid.Executor(self.place) exe = fluid.Executor(self.place)
exe.run(startup) exe.run(startup)
if isinstance(inps, list): if isinstance(inps, list):
...@@ -171,8 +314,7 @@ class TestJacobianFloat32(unittest.TestCase): ...@@ -171,8 +314,7 @@ class TestJacobianFloat32(unittest.TestCase):
np_jac[i, ..., j] for i in range(nrow) for j in range(ncol) np_jac[i, ..., j] for i in range(nrow) for j in range(ncol)
] ]
for pd_entry, np_entry in zip(pd_entries, np_entries): for pd_entry, np_entry in zip(pd_entries, np_entries):
self.assertTrue( np.testing.assert_allclose(pd_entry, np_entry, self.rtol, self.atol)
np.allclose(pd_entry, np_entry, self.rtol, self.atol))
def test_square(self): def test_square(self):
def pd_f(x): def pd_f(x):
...@@ -186,8 +328,7 @@ class TestJacobianFloat32(unittest.TestCase): ...@@ -186,8 +328,7 @@ class TestJacobianFloat32(unittest.TestCase):
self.run_test_by_entries(pd_f, np_f, self.A) self.run_test_by_entries(pd_f, np_f, self.A)
def test_mul(self): def test_mul(self):
def pd_f(xs): def pd_f(x, y):
x, y = xs
return paddle.multiply(x, y) return paddle.multiply(x, y)
def np_f(xs): def np_f(xs):
...@@ -202,8 +343,7 @@ class TestJacobianFloat32(unittest.TestCase): ...@@ -202,8 +343,7 @@ class TestJacobianFloat32(unittest.TestCase):
self.run_test_by_entries(pd_f, np_f, [self.B, self.C]) self.run_test_by_entries(pd_f, np_f, [self.B, self.C])
def test_matmul(self): def test_matmul(self):
def pd_f(xs): def pd_f(x, y):
x, y = xs
return paddle.matmul(x, y) return paddle.matmul(x, y)
def np_f(xs): def np_f(xs):
...@@ -215,8 +355,7 @@ class TestJacobianFloat32(unittest.TestCase): ...@@ -215,8 +355,7 @@ class TestJacobianFloat32(unittest.TestCase):
self.run_test_by_entries(pd_f, np_f, [self.B, self.C]) self.run_test_by_entries(pd_f, np_f, [self.B, self.C])
def test_batch_matmul(self): def test_batch_matmul(self):
def pd_f(xs): def pd_f(x, y):
x, y = xs
return paddle.matmul(x, y) return paddle.matmul(x, y)
def np_f(xs): def np_f(xs):
...@@ -238,12 +377,15 @@ class TestJacobianFloat64(TestJacobianFloat32): ...@@ -238,12 +377,15 @@ class TestJacobianFloat64(TestJacobianFloat32):
self.place = fluid.CPUPlace() self.place = fluid.CPUPlace()
self.dtype = 'float64' self.dtype = 'float64'
prepare_data(self, all_data_shapes, self.dtype) prepare_data(self, all_data_shapes, self.dtype)
self.eps = 1e-7 self.eps = config.TOLERANCE.get(self.dtype).get('first_order_grad').get(
self.rtol = 1e-6 'eps')
self.atol = 1e-6 self.rtol = config.TOLERANCE.get(self.dtype).get(
'first_order_grad').get('rtol')
self.atol = config.TOLERANCE.get(self.dtype).get(
'first_order_grad').get('atol')
class TestHessianFloat64(unittest.TestCase): class TestHessianFloat32(unittest.TestCase):
@classmethod @classmethod
def setUpClass(self): def setUpClass(self):
paddle.enable_static() paddle.enable_static()
...@@ -251,19 +393,22 @@ class TestHessianFloat64(unittest.TestCase): ...@@ -251,19 +393,22 @@ class TestHessianFloat64(unittest.TestCase):
self.place = fluid.CUDAPlace(0) self.place = fluid.CUDAPlace(0)
else: else:
self.place = fluid.CPUPlace() self.place = fluid.CPUPlace()
self.dtype = 'float64' self.dtype = 'float32'
prepare_data(self, all_data_shapes, self.dtype) prepare_data(self, all_data_shapes, self.dtype)
self.eps = 1e-7 self.eps = config.TOLERANCE.get(self.dtype).get(
self.rtol = 1e-6 'second_order_grad').get('eps')
self.atol = 1e-6 self.rtol = config.TOLERANCE.get(self.dtype).get(
'second_order_grad').get('rtol')
self.atol = config.TOLERANCE.get(self.dtype).get(
'second_order_grad').get('atol')
def run_test_by_fullmatrix(self, pd_f, inps, np_hess, batch=False): def run_test_by_fullmatrix(self, pd_f, inps, np_hess, batch=False):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
xs = make_tensors(inps) xs = make_tensors(inps)
HH = paddle.autograd.functional.Hessian(pd_f, xs, batch=batch) HH = paddle.autograd.functional.Hessian(pd_f, xs, is_batched=batch)
nrow, ncol = HH.shape() nrow, ncol = HH.shape
full_hessian = HH[:] full_hessian = HH[:]
exe = fluid.Executor(self.place) exe = fluid.Executor(self.place)
exe.run(startup) exe.run(startup)
...@@ -272,36 +417,38 @@ class TestHessianFloat64(unittest.TestCase): ...@@ -272,36 +417,38 @@ class TestHessianFloat64(unittest.TestCase):
else: else:
feeds = {'x': inps} feeds = {'x': inps}
pd_hess = exe.run(main, feed=feeds, fetch_list=[full_hessian])[0] pd_hess = exe.run(main, feed=feeds, fetch_list=[full_hessian])[0]
self.assertTrue(np.allclose(pd_hess, np_hess, self.rtol, self.atol)) np.testing.assert_allclose(pd_hess, np_hess, self.rtol, self.atol)
def test_square(self): def test_square(self):
def pd_f(x): def pd_f(x):
"""Input is a square matrix.""" """Input is a square matrix."""
return paddle.matmul(x, x.T) return paddle.matmul(x, x.T).flatten().sum()
def np_hess(x): def np_hess(x):
dim = x.shape[0] dim = x.shape[0]
f_xx_upperleft = 2 * np.eye(dim, dtype=self.dtype) upperleft = 2 * np.eye(dim, dtype=self.dtype)
f_xx = np.zeros([dim * dim, dim * dim], dtype=self.dtype) upper = np.concatenate((upperleft, upperleft))
f_xx[:dim, :dim] = f_xx_upperleft return np.concatenate((upper, upper), axis=1)
return f_xx
self.run_test_by_fullmatrix(pd_f, self.B, np_hess(self.B)) self.run_test_by_fullmatrix(pd_f, self.B, np_hess(self.B))
def test_batch_square(self):
def pd_f(x):
"""Input is a square matrix."""
return paddle.matmul(x, paddle.transpose(x, [0, 2, 1]))
def np_hess(x):
bat, dim, _ = x.shape
f_xx_upperleft = 2 * np.eye(dim, dtype=self.dtype)
f_xx = np.zeros([bat, dim * dim, dim * dim], dtype=self.dtype)
f_xx[..., :dim, :dim] = f_xx_upperleft
return f_xx
self.run_test_by_fullmatrix( class TestHessianFloat64(TestHessianFloat32):
pd_f, self.E, np_hess(self.E), batch=True) @classmethod
def setUpClass(self):
paddle.enable_static()
if fluid.core.is_compiled_with_cuda():
self.place = fluid.CUDAPlace(0)
else:
self.place = fluid.CPUPlace()
self.dtype = 'float64'
prepare_data(self, all_data_shapes, self.dtype)
self.eps = config.TOLERANCE.get(self.dtype).get(
'second_order_grad').get('eps')
self.rtol = config.TOLERANCE.get(self.dtype).get(
'second_order_grad').get('rtol')
self.atol = config.TOLERANCE.get(self.dtype).get(
'second_order_grad').get('atol')
if __name__ == "__main__": if __name__ == "__main__":
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.compat as cpt
import paddle.nn.functional as F
from utils import _compute_numerical_hessian, _compute_numerical_batch_hessian
class TestHessian(unittest.TestCase):
@classmethod
def setUpClass(self):
self.shape = (2, 2)
self.dtype = 'float32'
self.np_dtype = np.float32
self.numerical_delta = 1e-2
self.rtol = 1e-2
self.atol = 1e-2
self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
def test_single_input(self):
def func(x):
return paddle.sum(paddle.matmul(x, x))
numerical_hessian = _compute_numerical_hessian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
hessian = paddle.autograd.hessian(func, self.x)
assert np.allclose(hessian.numpy(), numerical_hessian[0][0], self.rtol,
self.atol)
def test_multi_input(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, y))
numerical_hessian = _compute_numerical_hessian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
hessian = paddle.autograd.hessian(func, [self.x, self.y])
for i in range(len(hessian)):
for j in range(len(hessian[0])):
assert np.allclose(hessian[i][j].numpy(),
numerical_hessian[i][j], self.rtol,
self.atol)
def test_allow_unused_false(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, x))
try:
self.x.stop_gradient = False
self.y.stop_gradient = False
hessian = paddle.autograd.hessian(func, [self.x, self.y])
except ValueError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("allow_unused") > 0
def test_allow_unused_true(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, x))
numerical_hessian = _compute_numerical_hessian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
hessian = paddle.autograd.hessian(
func, [self.x, self.y], allow_unused=True)
for i in range(len(hessian)):
for j in range(len(hessian[0])):
if i == j == 0:
assert np.allclose(hessian[i][j].numpy(),
numerical_hessian[i][j], self.rtol,
self.atol)
else:
assert hessian[i][j] is None
def test_create_graph_false(self):
def func(x):
return paddle.sum(paddle.matmul(x, x))
numerical_hessian = _compute_numerical_hessian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
hessian = paddle.autograd.hessian(func, self.x)
assert hessian.stop_gradient == True
assert np.allclose(hessian.numpy(), numerical_hessian[0][0], self.rtol,
self.atol)
try:
paddle.grad(hessian, self.x)
except RuntimeError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("has no gradient") > 0
def test_create_graph_true(self):
def func(x):
return paddle.sum(F.sigmoid(x))
numerical_hessian = _compute_numerical_hessian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
hessian = paddle.autograd.hessian(func, self.x, create_graph=True)
assert hessian.stop_gradient == False
assert np.allclose(hessian.numpy(), numerical_hessian[0][0], self.rtol,
self.atol)
triple_grad = paddle.grad(hessian, self.x)
assert triple_grad is not None
class TestHessianFloat64(TestHessian):
@classmethod
def setUpClass(self):
self.shape = (2, 2)
self.dtype = 'float64'
self.np_dtype = np.float64
self.numerical_delta = 1e-5
self.rtol = 1e-5
self.atol = 1e-5
self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
class TestBatchHessian(unittest.TestCase):
@classmethod
def setUpClass(self):
self.x_shape = (5, 2)
self.weight_shape = (2, 4)
self.y_shape = (5, 2)
self.dtype = 'float32'
self.np_dtype = np.float32
self.numerical_delta = 1e-2
self.rtol = 1e-3
self.atol = 1e-3
self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype)
self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype)
def test_single_input(self):
def func(x):
return paddle.matmul(x * x, self.weight)[:, 0:1]
numerical_hessian = _compute_numerical_batch_hessian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
hessian = paddle.autograd.batch_hessian(func, self.x, create_graph=True)
assert np.allclose(hessian, numerical_hessian, self.rtol, self.atol)
def test_multi_input(self):
def func(x, y):
return paddle.matmul(x * x * y * y, self.weight)[:, 0:1]
numerical_hessian = _compute_numerical_batch_hessian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
hessian = paddle.autograd.batch_hessian(func, [self.x, self.y])
shape_tensor = paddle.to_tensor(numerical_hessian).astype("float64")
hessian_reshape = np.reshape(hessian, (shape_tensor.shape))
assert np.allclose(hessian_reshape, numerical_hessian, self.rtol,
self.atol)
def test_allow_unused_false(self):
def func(x, y):
return paddle.matmul(x * x, self.weight)[:, 0:1]
try:
self.x.stop_gradient = False
self.y.stop_gradient = False
hessian = paddle.autograd.batch_hessian(func, [self.x, self.y])
except ValueError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("allow_unused") > 0
def test_allow_unused_true(self):
def func(x, y):
return paddle.matmul(x * x, self.weight)[:, 0:1]
numerical_hessian = _compute_numerical_batch_hessian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
hessian = paddle.autograd.batch_hessian(
func, [self.x, self.y], allow_unused=True)
for i in range(len(hessian)):
for j in range(len(hessian[0])):
if i == j == 0:
numerical_hessian = np.stack(
(numerical_hessian[i][j], numerical_hessian[i][j + 1]),
axis=0)
assert np.allclose(hessian[i][j], numerical_hessian,
self.rtol, self.atol)
else:
assert hessian[i][j] is None
def test_create_graph_false(self):
def func(x):
return paddle.matmul(x * x, self.weight)[:, 0:1]
numerical_hessian = _compute_numerical_batch_hessian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
hessian = paddle.autograd.batch_hessian(func, self.x)
assert hessian.stop_gradient == True
assert np.allclose(hessian.numpy(), numerical_hessian, self.rtol,
self.atol)
try:
paddle.grad(hessian, self.x)
except RuntimeError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("has no gradient") > 0
def test_create_graph_true(self):
def func(x):
return paddle.matmul(x * x, self.weight)[:, 0:1]
numerical_hessian = _compute_numerical_batch_hessian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
hessian = paddle.autograd.batch_hessian(func, self.x, create_graph=True)
assert hessian.stop_gradient == False
assert np.allclose(hessian.numpy(), numerical_hessian, self.rtol,
self.atol)
triple_grad = paddle.grad(hessian, self.x)
assert triple_grad is not None
class TestBatchHessianFloat64(TestBatchHessian):
@classmethod
def setUpClass(self):
self.x_shape = (5, 2)
self.weight_shape = (2, 4)
self.y_shape = (5, 2)
self.dtype = 'float64'
self.np_dtype = np.float64
self.numerical_delta = 1e-4
self.rtol = 1e-5
self.atol = 1e-5
self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype)
self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.compat as cpt
from utils import _compute_numerical_jacobian, _compute_numerical_batch_jacobian
class TestJacobian(unittest.TestCase):
@classmethod
def setUpClass(self):
self.shape = (4, 4)
self.dtype = 'float32'
self.np_dtype = np.float32
self.numerical_delta = 1e-4
self.rtol = 1e-3
self.atol = 1e-3
self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
def test_single_input_and_single_output(self):
def func(x):
return paddle.matmul(x, x)
numerical_jacobian = _compute_numerical_jacobian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
jacobian = paddle.autograd.jacobian(func, self.x)
assert np.allclose(jacobian.numpy(), numerical_jacobian[0][0],
self.rtol, self.atol)
def test_single_input_and_multi_output(self):
def func(x):
return paddle.matmul(x, x), x * x
numerical_jacobian = _compute_numerical_jacobian(
func, self.x, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
jacobian = paddle.autograd.jacobian(func, self.x)
for i in range(len(jacobian)):
assert np.allclose(jacobian[i].numpy(), numerical_jacobian[i][0],
self.rtol, self.atol)
def test_multi_input_and_single_output(self):
def func(x, y):
return paddle.matmul(x, y)
numerical_jacobian = _compute_numerical_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.jacobian(func, [self.x, self.y])
for j in range(len(jacobian)):
assert np.allclose(jacobian[j].numpy(), numerical_jacobian[0][j],
self.rtol, self.atol)
def test_multi_input_and_multi_output(self):
def func(x, y):
return paddle.matmul(x, y), x * y
numerical_jacobian = _compute_numerical_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.jacobian(func, [self.x, self.y])
for i in range(len(jacobian)):
for j in range(len(jacobian[0])):
assert np.allclose(jacobian[i][j].numpy(),
numerical_jacobian[i][j], self.rtol,
self.atol)
def test_allow_unused_false(self):
def func(x, y):
return paddle.matmul(x, x)
try:
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.jacobian(func, [self.x, self.y])
except ValueError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("allow_unused") > 0
def test_allow_unused_true(self):
def func(x, y):
return paddle.matmul(x, x)
numerical_jacobian = _compute_numerical_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.jacobian(
func, [self.x, self.y], allow_unused=True)
assert np.allclose(jacobian[0].numpy(), numerical_jacobian[0][0],
self.rtol, self.atol)
assert jacobian[1] is None
def test_create_graph_false(self):
def func(x, y):
return paddle.matmul(x, y)
numerical_jacobian = _compute_numerical_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.jacobian(func, [self.x, self.y])
for j in range(len(jacobian)):
assert jacobian[j].stop_gradient == True
assert np.allclose(jacobian[j].numpy(), numerical_jacobian[0][j],
self.rtol, self.atol)
try:
paddle.grad(jacobian[0], [self.x, self.y])
except RuntimeError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("has no gradient") > 0
def test_create_graph_true(self):
def func(x, y):
return paddle.matmul(x, y)
numerical_jacobian = _compute_numerical_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.jacobian(
func, [self.x, self.y], create_graph=True)
for j in range(len(jacobian)):
assert jacobian[j].stop_gradient == False
assert np.allclose(jacobian[j].numpy(), numerical_jacobian[0][j],
self.rtol, self.atol)
double_grad = paddle.grad(jacobian[0], [self.x, self.y])
assert double_grad is not None
class TestJacobianFloat64(TestJacobian):
@classmethod
def setUpClass(self):
self.shape = (4, 4)
self.dtype = 'float64'
self.np_dtype = np.float64
self.numerical_delta = 1e-7
self.rtol = 1e-7
self.atol = 1e-7
self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
class TestJacobianBatch(unittest.TestCase):
@classmethod
def setUpClass(self):
self.x_shape = (4, 2)
self.weight_shape = (2, 4)
self.y_shape = (4, 2)
self.dtype = 'float32'
self.np_dtype = np.float32
self.numerical_delta = 1e-4
self.rtol = 1e-3
self.atol = 1e-3
self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype)
self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype)
def test_batch_single_input_and_batch_single_output(self):
def func(x):
return paddle.matmul(paddle.matmul(x, self.weight), self.y)
numerical_jacobian = _compute_numerical_batch_jacobian(
func, [self.x], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
batch_jacobian = paddle.autograd.batch_jacobian(
func,
self.x, )
self.assertTrue(
np.allclose(batch_jacobian.numpy().all(), numerical_jacobian[0][0]
.all()))
def test_batch_single_input_and_batch_multi_output(self):
def func(x):
return paddle.matmul(paddle.matmul(x, self.weight), self.y), x * x
numerical_jacobian = _compute_numerical_batch_jacobian(
func, [self.x], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
batch_jacobian = paddle.autograd.batch_jacobian(
func,
self.x, )
for i in range(len(batch_jacobian)):
assert np.allclose(batch_jacobian[i].numpy(),
numerical_jacobian[i][0], self.rtol, self.atol)
def test_batch_multi_input_and_batch_single_output(self):
def func(x, y):
return x * y
numerical_jacobian = _compute_numerical_batch_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
batch_jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y])
for j in range(len(batch_jacobian)):
assert np.allclose(batch_jacobian[j].numpy(),
numerical_jacobian[0][j], self.rtol, self.atol)
def test_batch_multi_input_and_batch_multi_output(self):
def func(x, y):
return x * y, x * y
numerical_jacobian = _compute_numerical_batch_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
batch_jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y])
for i in range(len(batch_jacobian)):
assert np.allclose(batch_jacobian[i], numerical_jacobian[i],
self.rtol, self.atol)
def test_allow_unused_false(self):
def func(x, y):
return x * x
try:
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y])
except ValueError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("allow_unused") > 0
def test_allow_unused_true(self):
def func(x, y):
return x * x
numerical_jacobian = _compute_numerical_batch_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.batch_jacobian(
func, [self.x, self.y], allow_unused=True)
assert np.allclose(jacobian[0].numpy(), numerical_jacobian[0][0],
self.rtol, self.atol)
assert jacobian[1] is None
def test_create_graph_false(self):
def func(x, y):
return x * y
numerical_jacobian = _compute_numerical_batch_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y])
for j in range(len(jacobian)):
assert jacobian[j].stop_gradient == True
assert np.allclose(jacobian[j].numpy(), numerical_jacobian[0][j],
self.rtol, self.atol)
try:
paddle.grad(jacobian[0], [self.x, self.y])
except RuntimeError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("has no gradient") > 0
def test_create_graph_true(self):
def func(x, y):
return x * y
numerical_jacobian = _compute_numerical_batch_jacobian(
func, [self.x, self.y], self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
jacobian = paddle.autograd.batch_jacobian(
func, [self.x, self.y], create_graph=True)
for j in range(len(jacobian)):
assert jacobian[j].stop_gradient == False
assert np.allclose(jacobian[j].numpy(), numerical_jacobian[0][j],
self.rtol, self.atol)
double_grad = paddle.grad(jacobian[0], [self.x, self.y])
assert double_grad is not None
class TestJacobianBatchFloat64(TestJacobianBatch):
@classmethod
def setUpClass(self):
self.x_shape = (12, 2)
self.weight_shape = (2, 12)
self.y_shape = (12, 2)
self.dtype = 'float64'
self.np_dtype = np.float64
self.numerical_delta = 1e-7
self.rtol = 1e-7
self.atol = 1e-7
self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype)
self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.compat as cpt
import paddle.nn.functional as F
from utils import _compute_numerical_vhp
class TestVHP(unittest.TestCase):
@classmethod
def setUpClass(self):
self.shape = (2, 2)
self.dtype = 'float32'
self.np_dtype = np.float32
self.numerical_delta = 1e-2
self.rtol = 1e-2
self.atol = 1e-2
self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
self.vx = paddle.rand(shape=self.shape, dtype=self.dtype)
self.vy = paddle.rand(shape=self.shape, dtype=self.dtype)
def test_single_input(self):
def func(x):
return paddle.sum(paddle.matmul(x, x))
numerical_func_output = func(self.x).numpy()
numerical_vhp = _compute_numerical_vhp(
func, self.x, self.vx, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
func_output, vhp = paddle.autograd.vhp(func, self.x, self.vx)
assert np.allclose(func_output.numpy(), numerical_func_output,
self.rtol, self.atol)
assert np.allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol,
self.atol)
def test_multi_input(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, y))
numerical_func_output = func(self.x, self.y).numpy()
numerical_vhp = _compute_numerical_vhp(
func, [self.x, self.y], [self.vx, self.vy], self.numerical_delta,
self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
func_output, vhp = paddle.autograd.vhp(func, [self.x, self.y],
[self.vx, self.vy])
assert np.allclose(func_output.numpy(), numerical_func_output,
self.rtol, self.atol)
for i in range(len(vhp)):
assert np.allclose(vhp[i].numpy(), numerical_vhp[i], self.rtol,
self.atol)
def test_v_default(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, y))
numerical_func_output = func(self.x, self.y).numpy()
vx = paddle.ones(self.vx.shape, dtype=self.vx.dtype)
vy = paddle.ones(self.vy.shape, dtype=self.vy.dtype)
numerical_vhp = _compute_numerical_vhp(func, [self.x, self.y],
[vx, vy], self.numerical_delta,
self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
func_output, vhp = paddle.autograd.vhp(func, [self.x, self.y])
assert np.allclose(func_output.numpy(), numerical_func_output,
self.rtol, self.atol)
for i in range(len(vhp)):
assert np.allclose(vhp[i].numpy(), numerical_vhp[i], self.rtol,
self.atol)
def test_allow_unused_false(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, x))
try:
self.x.stop_gradient = False
self.y.stop_gradient = False
_ = paddle.autograd.vhp(func, [self.x, self.y])
except ValueError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("allow_unused") > 0
def test_allow_unused_true(self):
def func(x, y):
return paddle.sum(paddle.matmul(x, x))
numerical_func_output = func(self.x, self.y).numpy()
numerical_vhp = _compute_numerical_vhp(
func, [self.x, self.y], [self.vx, self.vy], self.numerical_delta,
self.np_dtype)
self.x.stop_gradient = False
self.y.stop_gradient = False
func_output, vhp = paddle.autograd.vhp(func, [self.x, self.y],
[self.vx, self.vy],
allow_unused=True)
assert np.allclose(func_output.numpy(), numerical_func_output,
self.rtol, self.atol)
assert np.allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol,
self.atol)
assert vhp[1] is None
def test_create_graph_false(self):
def func(x):
return paddle.sum(F.sigmoid(x))
numerical_func_output = func(self.x).numpy()
numerical_vhp = _compute_numerical_vhp(
func, self.x, self.vx, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
func_output, vhp = paddle.autograd.vhp(func, self.x, self.vx)
assert np.allclose(func_output.numpy(), numerical_func_output,
self.rtol, self.atol)
assert vhp[0].stop_gradient == True
assert np.allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol,
self.atol)
try:
paddle.grad(vhp, self.x)
except RuntimeError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("has no gradient") > 0
def test_create_graph_true(self):
def func(x):
return paddle.sum(F.sigmoid(x))
numerical_func_output = func(self.x).numpy()
numerical_vhp = _compute_numerical_vhp(
func, self.x, self.vx, self.numerical_delta, self.np_dtype)
self.x.stop_gradient = False
func_output, vhp = paddle.autograd.vhp(func,
self.x,
self.vx,
create_graph=True)
assert np.allclose(func_output.numpy(), numerical_func_output,
self.rtol, self.atol)
assert vhp[0].stop_gradient == False
assert np.allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol,
self.atol)
triple_grad = paddle.grad(vhp, self.x)
assert triple_grad is not None
class TestVHPFloat64(TestVHP):
@classmethod
def setUpClass(self):
self.shape = (2, 2)
self.dtype = 'float64'
self.np_dtype = np.float64
self.numerical_delta = 1e-5
self.rtol = 1e-5
self.atol = 1e-5
self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
self.vx = paddle.rand(shape=self.shape, dtype=self.dtype)
self.vy = paddle.rand(shape=self.shape, dtype=self.dtype)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle.autograd.functional import vjp, jvp, _tensors
from paddle import grad, ones_like, zeros_like
def reduce(x):
return paddle.sum(x)
def reduce_dim(x):
return paddle.sum(x, axis=0)
def matmul(x, y):
return paddle.matmul(x, y)
def mul(x, y):
return x * y
def pow(x, y):
return paddle.pow(x, y)
def o2(x, y):
return paddle.multiply(x, y), paddle.matmul(x, y.t())
def unuse(x, y):
return paddle.sum(x)
def nested(x):
def inner(y):
return x * y
return inner
def make_v(f, inputs):
outputs = _tensors(f(*inputs), "outputs")
return [ones_like(x) for x in outputs]
class TestAutogradFunctional(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.RAW_INPUTS = {
'a': [1.0],
'b': [1.0, 2.0],
'c': [3.0, 4.0],
'd': [[2.0], [3.0]],
'A': [[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]],
'B': [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]],
}
def setUp(self):
pass
def gen_input(self, inp, stop_gradient=False):
if isinstance(inp, paddle.Tensor):
return inp
return paddle.to_tensor(
self.RAW_INPUTS[inp], stop_gradient=stop_gradient)
def gen_inputs(self, inputs):
if isinstance(inputs, list):
inputs = [self.gen_input(x) for x in inputs]
else:
inputs = [self.gen_input(inputs)]
return inputs
def gen_test_pairs(self,
func,
inputs,
v=None,
create_graph=False,
allow_unused=False):
def vjp_test():
nonlocal v
xs = self.gen_inputs(inputs)
if v is not None:
v = self.gen_inputs(v)
outputs, inputs_grad = vjp(func,
xs,
v,
create_graph=create_graph,
allow_unused=allow_unused)
else:
outputs, inputs_grad = vjp(func,
xs,
create_graph=create_graph,
allow_unused=allow_unused)
return outputs, inputs_grad
def grad_test():
nonlocal v
xs = self.gen_inputs(inputs)
if v is not None:
v = self.gen_inputs(v)
outputs = func(*xs)
if v is not None:
inputs_grad = grad(
outputs,
xs,
v,
create_graph=create_graph,
allow_unused=allow_unused)
else:
inputs_grad = grad(
outputs,
xs,
create_graph=create_graph,
allow_unused=allow_unused)
return outputs, inputs_grad
return vjp_test, grad_test
def gen_jvp_tests(self,
func,
inputs,
v=None,
create_graph=False,
allow_unused=False):
def jvp_test():
nonlocal v
xs = self.gen_inputs(inputs)
if v is not None:
v = self.gen_inputs(v)
outputs, outputs_grad = jvp(func,
xs,
v,
create_graph=create_graph,
allow_unused=allow_unused)
else:
outputs, outputs_grad = jvp(func,
xs,
create_graph=create_graph,
allow_unused=allow_unused)
return outputs, outputs_grad
return jvp_test
def check_results(self, ref, res):
type_error = 'Result is different than expected in shape or type'
value_error = 'Result is different than expected values'
if ref is None:
self.assertTrue(res is None, type_error)
elif isinstance(ref, paddle.Tensor):
self.assertTrue(isinstance(res, paddle.Tensor), type_error)
self.assertTrue(paddle.allclose(res, ref), value_error)
else:
self.assertTrue(len(res) == len(ref), type_error)
for i in range(len(ref)):
self.check_results(ref[i], res[i])
return True
class TestVJP(TestAutogradFunctional):
def test_vjp_i1o1_no_create_graph(self):
test_cases = [
[reduce, 'A'], #noqa
[reduce_dim, 'A'], #noqa
] #noqa
for f, inputs in test_cases:
vjp, grad = self.gen_test_pairs(f, inputs)
vjp_result, grad_result = vjp(), grad()
self.check_results(grad_result, vjp_result)
def test_vjp_i2o1_no_create_graph(self):
test_cases = [
[matmul, ['A', 'B']], #noqa
[mul, ['b', 'c']], #noqa
] #noqa
for f, inputs in test_cases:
vjp, grad = self.gen_test_pairs(f, inputs)
vjp_result, grad_result = vjp(), grad()
self.check_results(grad_result, vjp_result)
def test_vjp_i2o2_no_create_graph(self):
test_cases = [
[o2, ['A', 'A']], #noqa
] #noqa
for f, inputs in test_cases:
inputs = self.gen_inputs(inputs)
v = make_v(f, inputs)
vjp, grad = self.gen_test_pairs(f, inputs, v=v)
vjp_result, grad_result = vjp(), grad()
self.check_results(grad_result, vjp_result)
def test_vjp_i2o2_omitting_v_no_create_graph(self):
test_cases = [
[o2, ['A', 'A']], #noqa
] #noqa
for f, inputs in test_cases:
inputs = self.gen_inputs(inputs)
vjp, grad = self.gen_test_pairs(f, inputs)
vjp_result, grad_result = vjp(), grad()
self.check_results(grad_result, vjp_result)
def test_vjp_nested_no_create_graph(self):
x = self.gen_input('a')
test_cases = [
[nested(x), 'a'], #noqa
]
for f, inputs in test_cases:
vjp, grad = self.gen_test_pairs(f, inputs)
vjp_result, grad_result = vjp(), grad()
self.check_results(grad_result, vjp_result)
def test_vjp_aliased_input_no_create_graph(self):
x = self.gen_input('a')
ref = self.gen_test_pairs(nested(x), 'a')[0]
aliased = self.gen_test_pairs(nested(x), x)[0]
ref_result, aliased_result = ref(), aliased()
self.check_results(ref_result, aliased_result)
def test_vjp_allowunused_no_create_graph(self):
x, y = self.gen_input('A'), self.gen_input('a')
vjp, grad = self.gen_test_pairs(unuse, [x, y], allow_unused=True)
vjp_result, grad_result = vjp(), grad()
self.check_results(grad_result, vjp_result)
def jac(grad_fn, f, inputs):
assert grad_fn in [vjp, jvp]
if grad_fn is jvp:
vs = [zeros_like(x) for x in inputs]
else:
outputs = f(*inputs)
if isinstance(outputs, paddle.Tensor):
outputs = [outputs]
vs = [zeros_like(y) for y in outputs]
JJ_cols = []
for i, v in enumerate(vs):
v = v.flatten()
for j in range(len(v)):
_v = zeros_like(v).detach()
_v[j] = 1.0
_v = _v.reshape(vs[i].shape)
_vs = vs.copy()
_vs[i] = _v
_, grads = grad_fn(f, inputs, vs)
d_outs = paddle.concat([d_out.flatten() for d_out in grads])
JJ_cols.append(d_outs)
# JJ is the fully unrolled jacobian
JJ = paddle.stack(JJ_cols)
if grad_fn is vjp:
JJ = JJ.t()
return JJ
class TestJVP(TestAutogradFunctional):
def test_jvp_i1o1_no_create_graph(self):
test_cases = [
[reduce, 'A'], #noqa
[reduce_dim, 'A'], #noqa
] #noqa
for f, inputs in test_cases:
inputs = self.gen_inputs(inputs)
forward_jac = jac(jvp, f, inputs)
reverse_jac = jac(vjp, f, inputs)
self.check_results(forward_jac, reverse_jac)
def test_jvp_i2o1_no_create_graph(self):
test_cases = [ #noqa
[matmul, ['A', 'B']], #noqa
] #noqa
for f, inputs in test_cases:
inputs = self.gen_inputs(inputs)
forward_jac = jac(jvp, f, inputs)
reverse_jac = jac(vjp, f, inputs)
self.check_results(forward_jac, reverse_jac)
def test_jvp_i2o2_no_create_graph(self):
test_cases = [ #noqa
[o2, ['A', 'A']], #noqa
] #noqa
for f, inputs in test_cases:
inputs = self.gen_inputs(inputs)
forward_jac = jac(jvp, f, inputs)
reverse_jac = jac(vjp, f, inputs)
self.check_results(forward_jac, reverse_jac)
def test_jvp_i2o2_omitting_v_no_create_graph(self):
test_cases = [ #noqa
[o2, ['A', 'A']], #noqa
] #noqa
for f, inputs in test_cases:
inputs = self.gen_inputs(inputs)
results_omitting_v = jvp(f, inputs)
v = [ones_like(x) for x in inputs]
results_with_v = jvp(f, inputs, v)
self.check_results(results_omitting_v, results_with_v)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import typing
import enum
import sys
import re
import inspect
import functools
import contextlib
import collections
import numpy as np import numpy as np
import paddle import paddle
from paddle.autograd.functional import _tensors from paddle.autograd.functional import _as_tensors
##########################################################
# Finite Difference Utils
##########################################################
def _product(t): def _product(t):
if isinstance(t, int): if isinstance(t, int):
return t return t
...@@ -25,7 +36,9 @@ def _product(t): ...@@ -25,7 +36,9 @@ def _product(t):
def _get_item(t, idx): def _get_item(t, idx):
assert isinstance(t, paddle.Tensor), "The first argument t must be Tensor." assert isinstance(
t,
paddle.fluid.framework.Variable), "The first argument t must be Tensor."
assert isinstance(idx, assert isinstance(idx,
int), "The second argument idx must be an int number." int), "The second argument idx must be an int number."
flat_t = paddle.reshape(t, [-1]) flat_t = paddle.reshape(t, [-1])
...@@ -33,7 +46,9 @@ def _get_item(t, idx): ...@@ -33,7 +46,9 @@ def _get_item(t, idx):
def _set_item(t, idx, value): def _set_item(t, idx, value):
assert isinstance(t, paddle.Tensor), "The first argument t must be Tensor." assert isinstance(
t,
paddle.fluid.framework.Variable), "The first argument t must be Tensor."
assert isinstance(idx, assert isinstance(idx,
int), "The second argument idx must be an int number." int), "The second argument idx must be an int number."
flat_t = paddle.reshape(t, [-1]) flat_t = paddle.reshape(t, [-1])
...@@ -42,8 +57,8 @@ def _set_item(t, idx, value): ...@@ -42,8 +57,8 @@ def _set_item(t, idx, value):
def _compute_numerical_jacobian(func, xs, delta, np_dtype): def _compute_numerical_jacobian(func, xs, delta, np_dtype):
xs = _tensors(xs, "xs") xs = list(_as_tensors(xs))
ys = _tensors(func(*xs), "ys") ys = list(_as_tensors(func(*xs)))
fin_size = len(xs) fin_size = len(xs)
fout_size = len(ys) fout_size = len(ys)
jacobian = list([] for _ in range(fout_size)) jacobian = list([] for _ in range(fout_size))
...@@ -59,11 +74,11 @@ def _compute_numerical_jacobian(func, xs, delta, np_dtype): ...@@ -59,11 +74,11 @@ def _compute_numerical_jacobian(func, xs, delta, np_dtype):
orig = _get_item(xs[j], q) orig = _get_item(xs[j], q)
x_pos = orig + delta x_pos = orig + delta
xs[j] = _set_item(xs[j], q, x_pos) xs[j] = _set_item(xs[j], q, x_pos)
ys_pos = _tensors(func(*xs), "ys_pos") ys_pos = _as_tensors(func(*xs))
x_neg = orig - delta x_neg = orig - delta
xs[j] = _set_item(xs[j], q, x_neg) xs[j] = _set_item(xs[j], q, x_neg)
ys_neg = _tensors(func(*xs), "ys_neg") ys_neg = _as_tensors(func(*xs))
xs[j] = _set_item(xs[j], q, orig) xs[j] = _set_item(xs[j], q, orig)
...@@ -76,8 +91,8 @@ def _compute_numerical_jacobian(func, xs, delta, np_dtype): ...@@ -76,8 +91,8 @@ def _compute_numerical_jacobian(func, xs, delta, np_dtype):
def _compute_numerical_hessian(func, xs, delta, np_dtype): def _compute_numerical_hessian(func, xs, delta, np_dtype):
xs = _tensors(xs, "xs") xs = list(_as_tensors(xs))
ys = _tensors(func(*xs), "ys") ys = list(_as_tensors(func(*xs)))
fin_size = len(xs) fin_size = len(xs)
hessian = list([] for _ in range(fin_size)) hessian = list([] for _ in range(fin_size))
for i in range(fin_size): for i in range(fin_size):
...@@ -107,10 +122,22 @@ def _compute_numerical_hessian(func, xs, delta, np_dtype): ...@@ -107,10 +122,22 @@ def _compute_numerical_hessian(func, xs, delta, np_dtype):
return hessian return hessian
def _compute_numerical_batch_jacobian(func, xs, delta, np_dtype): def concat_to_matrix(xs, is_batched=False):
"""Concats a tuple of tuple of Jacobian/Hessian matrix into one matrix"""
rows = []
for i in range(len(xs)):
rows.append(np.concatenate([x for x in xs[i]], -1))
return np.concatenate(rows, 1) if is_batched else np.concatenate(rows, 0)
def _compute_numerical_batch_jacobian(func,
xs,
delta,
np_dtype,
merge_batch=True):
no_batch_jacobian = _compute_numerical_jacobian(func, xs, delta, np_dtype) no_batch_jacobian = _compute_numerical_jacobian(func, xs, delta, np_dtype)
xs = _tensors(xs, "xs") xs = list(_as_tensors(xs))
ys = _tensors(func(*xs), "ys") ys = list(_as_tensors(func(*xs)))
fin_size = len(xs) fin_size = len(xs)
fout_size = len(ys) fout_size = len(ys)
bs = xs[0].shape[0] bs = xs[0].shape[0]
...@@ -128,7 +155,8 @@ def _compute_numerical_batch_jacobian(func, xs, delta, np_dtype): ...@@ -128,7 +155,8 @@ def _compute_numerical_batch_jacobian(func, xs, delta, np_dtype):
for b in range(bs): for b in range(bs):
for q in range(in_size): for q in range(in_size):
batch_jac_i_j[p][b][q] = jac[b][p][b][q] batch_jac_i_j[p][b][q] = jac[b][p][b][q]
batch_jac_i_j = np.reshape(batch_jac_i_j, (out_size, -1)) if merge_batch:
batch_jac_i_j = np.reshape(batch_jac_i_j, (out_size, -1))
batch_jac_i.append(batch_jac_i_j) batch_jac_i.append(batch_jac_i_j)
bat_jac.append(batch_jac_i) bat_jac.append(batch_jac_i)
...@@ -136,7 +164,7 @@ def _compute_numerical_batch_jacobian(func, xs, delta, np_dtype): ...@@ -136,7 +164,7 @@ def _compute_numerical_batch_jacobian(func, xs, delta, np_dtype):
def _compute_numerical_batch_hessian(func, xs, delta, np_dtype): def _compute_numerical_batch_hessian(func, xs, delta, np_dtype):
xs = _tensors(xs, "xs") xs = list(_as_tensors(xs))
batch_size = xs[0].shape[0] batch_size = xs[0].shape[0]
fin_size = len(xs) fin_size = len(xs)
hessian = [] hessian = []
...@@ -175,8 +203,10 @@ def _compute_numerical_batch_hessian(func, xs, delta, np_dtype): ...@@ -175,8 +203,10 @@ def _compute_numerical_batch_hessian(func, xs, delta, np_dtype):
def _compute_numerical_vjp(func, xs, v, delta, np_dtype): def _compute_numerical_vjp(func, xs, v, delta, np_dtype):
xs = _tensors(xs, "xs") xs = _as_tensors(xs)
jacobian = np.array(_compute_numerical_jacobian(func, xs, delta, np_dtype)) jacobian = np.array(_compute_numerical_jacobian(func, xs, delta, np_dtype))
if v is None:
v = [paddle.ones_like(x) for x in xs]
flat_v = np.array([v_el.numpy().reshape(-1) for v_el in v]) flat_v = np.array([v_el.numpy().reshape(-1) for v_el in v])
vjp = [np.zeros((_product(x.shape)), dtype=np_dtype) for x in xs] vjp = [np.zeros((_product(x.shape)), dtype=np_dtype) for x in xs]
for j in range(len(xs)): for j in range(len(xs)):
...@@ -188,7 +218,7 @@ def _compute_numerical_vjp(func, xs, v, delta, np_dtype): ...@@ -188,7 +218,7 @@ def _compute_numerical_vjp(func, xs, v, delta, np_dtype):
def _compute_numerical_vhp(func, xs, v, delta, np_dtype): def _compute_numerical_vhp(func, xs, v, delta, np_dtype):
xs = _tensors(xs, "xs") xs = list(_as_tensors(xs))
hessian = np.array(_compute_numerical_hessian(func, xs, delta, np_dtype)) hessian = np.array(_compute_numerical_hessian(func, xs, delta, np_dtype))
flat_v = np.array([v_el.numpy().reshape(-1) for v_el in v]) flat_v = np.array([v_el.numpy().reshape(-1) for v_el in v])
vhp = [np.zeros((_product(x.shape)), dtype=np_dtype) for x in xs] vhp = [np.zeros((_product(x.shape)), dtype=np_dtype) for x in xs]
...@@ -198,3 +228,166 @@ def _compute_numerical_vhp(func, xs, v, delta, np_dtype): ...@@ -198,3 +228,166 @@ def _compute_numerical_vhp(func, xs, v, delta, np_dtype):
flat_v) flat_v)
vhp = [vhp[j].reshape(xs[j].shape) for j in range(len(xs))] vhp = [vhp[j].reshape(xs[j].shape) for j in range(len(xs))]
return vhp return vhp
##########################################################
# TestCases of different function.
##########################################################
def reduce(x):
return paddle.sum(x)
def reduce_dim(x):
return paddle.sum(x, axis=0)
def matmul(x, y):
return paddle.matmul(x, y)
def mul(x, y):
return x * y
def pow(x, y):
return paddle.pow(x, y)
def o2(x, y):
return paddle.multiply(x, y), paddle.matmul(x, y.t())
def unuse(x, y):
return paddle.sum(x)
def nested(x):
def inner(y):
return x * y
return inner
def square(x):
return x * x
##########################################################
# Parameterized Test Utils.
##########################################################
TEST_CASE_NAME = 'suffix'
def place(devices, key='place'):
"""A Decorator for a class which will make the class running on different
devices .
Args:
devices (Sequence[Paddle.CUDAPlace|Paddle.CPUPlace]): Device list.
key (str, optional): Defaults to 'place'.
"""
def decorate(cls):
module = sys.modules[cls.__module__].__dict__
raw_classes = {
k: v
for k, v in module.items() if k.startswith(cls.__name__)
}
for raw_name, raw_cls in raw_classes.items():
for d in devices:
test_cls = dict(raw_cls.__dict__)
test_cls.update({key: d})
new_name = raw_name + '.' + d.__class__.__name__
module[new_name] = type(new_name, (raw_cls, ), test_cls)
del module[raw_name]
return cls
return decorate
def parameterize(fields, values=None):
"""Decorator for a unittest class which make the class running on different
test cases.
Args:
fields (Sequence): The feild name sequence of test cases.
values (Sequence, optional): The test cases sequence. Defaults to None.
"""
fields = [fields] if isinstance(fields, str) else fields
params = [dict(zip(fields, vals)) for vals in values]
def decorate(cls):
test_cls_module = sys.modules[cls.__module__].__dict__
for i, values in enumerate(params):
test_cls = dict(cls.__dict__)
values = {
k: staticmethod(v) if callable(v) else v
for k, v in values.items()
}
test_cls.update(values)
name = cls.__name__ + str(i)
name = name + '.' + \
values.get('suffix') if values.get('suffix') else name
test_cls_module[name] = type(name, (cls, ), test_cls)
for m in list(cls.__dict__):
if m.startswith("test"):
delattr(cls, m)
return cls
return decorate
##########################################################
# Utils for transpose different Jacobian/Hessian matrix format.
##########################################################
# B is batch size, N is row size, M is column size.
MatrixFormat = enum.Enum('MatrixFormat', ('NBM', 'BNM', 'NMB', 'NM'))
def _np_transpose_matrix_format(src, src_format, des_format):
"""Transpose Jacobian/Hessian matrix format."""
supported_format = (MatrixFormat.NBM, MatrixFormat.BNM, MatrixFormat.NMB)
if src_format not in supported_format or des_format not in supported_format:
raise ValueError(
f"Supported Jacobian format is {supported_format}, but got src: {src_format}, des: {des_format}"
)
src_axis = {c: i for i, c in enumerate(src_format.name)}
dst_axis = tuple(src_axis[c] for c in des_format.name)
return np.transpose(src, dst_axis)
def _np_concat_matrix_sequence(src, src_format=MatrixFormat.NM):
"""Convert a sequence of sequence of Jacobian/Hessian matrix into one huge
matrix."""
def concat_col(xs):
if src_format in (MatrixFormat.NBM, MatrixFormat.BNM, MatrixFormat.NM):
return np.concatenate(xs, axis=-1)
else:
return np.concatenate(xs, axis=1)
def concat_row(xs):
if src_format in (MatrixFormat.NBM, MatrixFormat.NM, MatrixFormat.NMB):
return np.concatenate(xs, axis=0)
else:
return np.concatenate(xs, axis=1)
supported_format = (MatrixFormat.NBM, MatrixFormat.BNM, MatrixFormat.NMB,
MatrixFormat.NM)
if src_format not in supported_format:
raise ValueError(
f"Supported Jacobian format is {supported_format}, but got {src_format}"
)
if not isinstance(src, typing.Sequence):
return src
if not isinstance(src[0], typing.Sequence):
src = [src]
return concat_row(tuple(concat_col(xs) for xs in src))
...@@ -26,6 +26,7 @@ from .tensor import segment_mean ...@@ -26,6 +26,7 @@ from .tensor import segment_mean
from .tensor import segment_max from .tensor import segment_max
from .tensor import segment_min from .tensor import segment_min
from .passes import fuse_resnet_unit_pass from .passes import fuse_resnet_unit_pass
import paddle.incubate.autograd
from . import nn #noqa: F401 from . import nn #noqa: F401
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.autograd.functional import Hessian, Jacobian, jvp, vjp
__all__ = [ # noqa
'vjp', 'jvp', 'Jacobian', 'Hessian'
]
...@@ -273,6 +273,7 @@ packages=['paddle', ...@@ -273,6 +273,7 @@ packages=['paddle',
'paddle.distributed.ps', 'paddle.distributed.ps',
'paddle.distributed.ps.utils', 'paddle.distributed.ps.utils',
'paddle.incubate', 'paddle.incubate',
'paddle.incubate.autograd',
'paddle.incubate.optimizer', 'paddle.incubate.optimizer',
'paddle.incubate.checkpoint', 'paddle.incubate.checkpoint',
'paddle.incubate.operators', 'paddle.incubate.operators',
......
...@@ -12,55 +12,8 @@ ...@@ -12,55 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
set -e
set +x
NIGHTLY_MODE=$1
PRECISION_TEST=$2
WITH_GPU=$3
export PADDLE_ROOT="$(cd "$PWD/../" && pwd )"
if [ ${NIGHTLY_MODE:-OFF} == "ON" ]; then
nightly_label=""
else
nightly_label="(RUN_TYPE=NIGHTLY|RUN_TYPE=DIST:NIGHTLY|RUN_TYPE=EXCLUSIVE:NIGHTLY)"
echo "========================================="
echo "Unittests with nightly labels are only run at night"
echo "========================================="
fi
if disable_ut_quickly=$(python ${PADDLE_ROOT}/tools/get_quick_disable_lt.py); then
echo "========================================="
echo "The following unittests have been disabled:"
echo ${disable_ut_quickly}
echo "========================================="
else
disable_ut_quickly=''
fi
# check added ut
set +e # /*================Fixed Disabled Windows CUDA10.x MKL(PR-CI-Windows) unittests===========================*/
cp $PADDLE_ROOT/tools/check_added_ut.sh $PADDLE_ROOT/tools/check_added_ut_win.sh
bash $PADDLE_ROOT/tools/check_added_ut_win.sh
rm -rf $PADDLE_ROOT/tools/check_added_ut_win.sh
if [ -f "$PADDLE_ROOT/added_ut" ];then
added_uts=^$(awk BEGIN{RS=EOF}'{gsub(/\n/,"$|^");print}' $PADDLE_ROOT/added_ut)$
ctest -R "(${added_uts})" --output-on-failure -C Release --repeat-until-fail 3;added_ut_error=$?
rm -f $PADDLE_ROOT/added_ut
if [ "$added_ut_error" != 0 ];then
echo "========================================"
echo "Added UT should pass three additional executions"
echo "========================================"
exit 8;
fi
if nvcc --version | grep 11.2; then
echo "Only test added_ut temporarily when running in CI-Windows-inference of CUDA 11.2."
exit 0;
fi
fi
set -e
# /*==================Fixed Disabled Windows GPU MKL unittests==============================*/
# TODO: fix these unittest that is bound to fail # TODO: fix these unittest that is bound to fail
disable_wingpu_test="^test_model$|\ disable_wingpu_test="^test_model$|\
^test_dataloader_early_reset$|\ ^test_dataloader_early_reset$|\
...@@ -97,7 +50,7 @@ disable_wingpu_test="^test_model$|\ ...@@ -97,7 +50,7 @@ disable_wingpu_test="^test_model$|\
^test_bilinear_interp_op$|\ ^test_bilinear_interp_op$|\
^disable_wingpu_test$" ^disable_wingpu_test$"
# /*==================Fixed Disabled Windows GPU MKL unittests==============================*/ # /*=================Fixed Disabled Windows TRT MKL unittests=======================*/
# TODO: fix these unittest that is bound to fail # TODO: fix these unittest that is bound to fail
disable_win_trt_test="^test_trt_convert_conv2d$|\ disable_win_trt_test="^test_trt_convert_conv2d$|\
^test_trt_convert_conv2d_fusion$|\ ^test_trt_convert_conv2d_fusion$|\
...@@ -119,7 +72,13 @@ disable_win_trt_test="^test_trt_convert_conv2d$|\ ...@@ -119,7 +72,13 @@ disable_win_trt_test="^test_trt_convert_conv2d$|\
^test_trt_convert_matmul$|\ ^test_trt_convert_matmul$|\
^test_trt_convert_scale$" ^test_trt_convert_scale$"
# /*==================Fixed Disabled Windows GPU inference_api_test unittests==============================*/ # /*=============Fixed Disabled Windows CUDA11.x MKL(PR-CI-Windows-Inference) unittests=================*/
# TODO: fix these unittest that is bound to fail
disable_wingpu11_test="^test_autograd_functional_dynamic$|\
^disable_wingpu_test$"
# /*==========Fixed Disabled Windows CUDA11.x inference_api_test(PR-CI-Windows-Inference) unittests=============*/
disable_win_inference_api_test="^trt_quant_int8_yolov3_r50_test$|\ disable_win_inference_api_test="^trt_quant_int8_yolov3_r50_test$|\
^test_trt_dynamic_shape_ernie$|\ ^test_trt_dynamic_shape_ernie$|\
^test_trt_dynamic_shape_ernie_fp16_ser_deser$|\ ^test_trt_dynamic_shape_ernie_fp16_ser_deser$|\
...@@ -128,9 +87,8 @@ disable_win_inference_api_test="^trt_quant_int8_yolov3_r50_test$|\ ...@@ -128,9 +87,8 @@ disable_win_inference_api_test="^trt_quant_int8_yolov3_r50_test$|\
^lite_mul_model_test$|\ ^lite_mul_model_test$|\
^paddle_infer_api_copy_tensor_tester$" ^paddle_infer_api_copy_tensor_tester$"
# /*============================================================================*/
# /*==================Fixed Disabled Windows CPU OPENBLAS unittests==============================*/ # /*==========Fixed Disabled Windows CPU OPENBLAS((PR-CI-Windows-OPENBLAS)) unittests==============================*/
# TODO: fix these unittest that is bound to fail # TODO: fix these unittest that is bound to fail
disable_wincpu_test="^jit_kernel_test$|\ disable_wincpu_test="^jit_kernel_test$|\
^test_analyzer_transformer$|\ ^test_analyzer_transformer$|\
...@@ -189,6 +147,58 @@ long_time_test="^test_gru_op$|\ ...@@ -189,6 +147,58 @@ long_time_test="^test_gru_op$|\
^test_trt_matmul_quant_dequant$|\ ^test_trt_matmul_quant_dequant$|\
^test_strided_slice_op$" ^test_strided_slice_op$"
# /*============================================================================*/
set -e
set +x
NIGHTLY_MODE=$1
PRECISION_TEST=$2
WITH_GPU=$3
export PADDLE_ROOT="$(cd "$PWD/../" && pwd )"
if [ ${NIGHTLY_MODE:-OFF} == "ON" ]; then
nightly_label=""
else
nightly_label="(RUN_TYPE=NIGHTLY|RUN_TYPE=DIST:NIGHTLY|RUN_TYPE=EXCLUSIVE:NIGHTLY)"
echo "========================================="
echo "Unittests with nightly labels are only run at night"
echo "========================================="
fi
if disable_ut_quickly=$(python ${PADDLE_ROOT}/tools/get_quick_disable_lt.py); then
echo "========================================="
echo "The following unittests have been disabled:"
echo ${disable_ut_quickly}
echo "========================================="
else
disable_ut_quickly=''
fi
# check added ut
set +e
cp $PADDLE_ROOT/tools/check_added_ut.sh $PADDLE_ROOT/tools/check_added_ut_win.sh
bash $PADDLE_ROOT/tools/check_added_ut_win.sh
rm -rf $PADDLE_ROOT/tools/check_added_ut_win.sh
if [ -f "$PADDLE_ROOT/added_ut" ];then
added_uts=^$(awk BEGIN{RS=EOF}'{gsub(/\n/,"$|^");print}' $PADDLE_ROOT/added_ut)$
ctest -R "(${added_uts})" -E "$disable_wingpu11_test" --output-on-failure -C Release --repeat-until-fail 3;added_ut_error=$?
rm -f $PADDLE_ROOT/added_ut
if [ "$added_ut_error" != 0 ];then
echo "========================================"
echo "Added UT should pass three additional executions"
echo "========================================"
exit 8;
fi
if nvcc --version | grep 11.2; then
echo "Only test added_ut temporarily when running in CI-Windows-inference of CUDA 11.2."
exit 0;
fi
fi
set -e
if [ ${WITH_GPU:-OFF} == "ON" ];then if [ ${WITH_GPU:-OFF} == "ON" ];then
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册