# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import typing
import unittest

import config
import numpy as np
import utils
from utils import matmul, mul, nested, o2, reduce, reduce_dim

import paddle
import paddle.nn.functional as F
from paddle.incubate.autograd.utils import as_tensors


def make_v(f, inputs):
    outputs = as_tensors(f(*inputs))
    return [paddle.ones_like(x) for x in outputs]


class TestAutogradFunctional(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.RAW_INPUTS = {
            'a': [1.0],
            'b': [1.0, 2.0],
            'c': [3.0, 4.0],
            'd': [[2.0], [3.0]],
            'A': [[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]],
            'B': [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]],
        }

    def setUp(self):
        pass

    def gen_input(self, inp, stop_gradient=False):
        if isinstance(inp, paddle.Tensor):
            return inp
        return paddle.to_tensor(
            self.RAW_INPUTS[inp], stop_gradient=stop_gradient
        )

    def gen_inputs(self, inputs):
        if isinstance(inputs, list):
            inputs = [self.gen_input(x) for x in inputs]
        else:
            inputs = [self.gen_input(inputs)]
        return inputs

    def gen_test_pairs(
        self, func, inputs, v=None, create_graph=False, allow_unused=False
    ):
        def vjp_test():
            nonlocal v
            xs = self.gen_inputs(inputs)
            if v is not None:
                v = self.gen_inputs(v)
                outputs, inputs_grad = paddle.incubate.autograd.vjp(
                    func, xs, v
                )
            else:
                outputs, inputs_grad = paddle.incubate.autograd.vjp(func, xs)
            return outputs, inputs_grad

        def grad_test():
            nonlocal v
            xs = self.gen_inputs(inputs)
            if v is not None:
                v = self.gen_inputs(v)
            outputs = func(*xs)
            if v is not None:
                inputs_grad = paddle.grad(
                    outputs,
                    xs,
                    v,
                    create_graph=create_graph,
                    allow_unused=allow_unused,
                )
            else:
                inputs_grad = paddle.grad(
                    outputs,
                    xs,
                    create_graph=create_graph,
                    allow_unused=allow_unused,
                )
            return outputs, inputs_grad

        return vjp_test, grad_test

    def gen_jvp_tests(
        self, func, inputs, v=None, create_graph=False, allow_unused=False
    ):
        def jvp_test():
            nonlocal v
            xs = self.gen_inputs(inputs)
            if v is not None:
                v = self.gen_inputs(v)
                outputs, outputs_grad = paddle.incubate.autograd.jvp(
                    func,
                    xs,
                    v,
                    create_graph=create_graph,
                    allow_unused=allow_unused,
                )
            else:
                outputs, outputs_grad = paddle.incubate.autograd.jvp(
                    func,
                    xs,
                    create_graph=create_graph,
                    allow_unused=allow_unused,
                )
            return outputs, outputs_grad

        return jvp_test

    def check_results(self, ref, res):
        type_error = 'Result is different than expected in shape or type'
        value_error = 'Result is different than expected values'
        if ref is None:
            self.assertTrue(res is None, type_error)
        elif isinstance(ref, paddle.Tensor):
            self.assertTrue(isinstance(res, paddle.Tensor), type_error)
            # Attach the value-mismatch message so failures are descriptive
            # (value_error was previously defined but never used).
            np.testing.assert_allclose(res, ref, err_msg=value_error)
        else:
            self.assertTrue(len(res) == len(ref), type_error)
            for i in range(len(ref)):
                self.check_results(ref[i], res[i])
        return True
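# The tests below exercise paddle.incubate.autograd.vjp against an
# equivalent paddle.grad call for functions of varying arity: iNoM means
# N inputs and M outputs. Each case runs the vjp_test/grad_test pair built
# by gen_test_pairs and checks that both paths agree.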
class TestVJP(TestAutogradFunctional):
    def func_vjp_i1o1(self):
        test_cases = [
            [reduce, 'A'],  # noqa
            [reduce_dim, 'A'],  # noqa
        ]  # noqa
        for f, inputs in test_cases:
            vjp, grad = self.gen_test_pairs(f, inputs)
            vjp_result, grad_result = vjp(), grad()
            self.check_results(grad_result, vjp_result)

    def func_vjp_i2o1(self):
        test_cases = [
            [matmul, ['A', 'B']],  # noqa
            [mul, ['b', 'c']],  # noqa
        ]  # noqa
        for f, inputs in test_cases:
            vjp, grad = self.gen_test_pairs(f, inputs)
            vjp_result, grad_result = vjp(), grad()
            self.check_results(grad_result, vjp_result)

    def func_vjp_i2o2(self):
        test_cases = [
            [o2, ['A', 'A']],  # noqa
        ]  # noqa
        for f, inputs in test_cases:
            inputs = self.gen_inputs(inputs)
            v = make_v(f, inputs)
            vjp, grad = self.gen_test_pairs(f, inputs, v=v)
            vjp_result, grad_result = vjp(), grad()
            self.check_results(grad_result, vjp_result)

    def func_vjp_i2o2_omitting_v(self):
        test_cases = [
            [o2, ['A', 'A']],  # noqa
        ]  # noqa
        for f, inputs in test_cases:
            inputs = self.gen_inputs(inputs)
            vjp, grad = self.gen_test_pairs(f, inputs)
            vjp_result, grad_result = vjp(), grad()
            self.check_results(grad_result, vjp_result)

    def func_vjp_nested(self):
        x = self.gen_input('a')
        test_cases = [
            [nested(x), 'a'],  # noqa
        ]
        for f, inputs in test_cases:
            vjp, grad = self.gen_test_pairs(f, inputs)
            vjp_result, grad_result = vjp(), grad()
            self.check_results(grad_result, vjp_result)

    def func_vjp_aliased_input(self):
        x = self.gen_input('a')
        ref = self.gen_test_pairs(nested(x), 'a')[0]
        aliased = self.gen_test_pairs(nested(x), x)[0]
        ref_result, aliased_result = ref(), aliased()
        self.check_results(ref_result, aliased_result)

    def test_all_cases(self):
        self.func_vjp_i1o1()
        self.func_vjp_i2o1()
        self.func_vjp_i2o2()
        self.func_vjp_i2o2_omitting_v()
        self.func_vjp_nested()
        self.func_vjp_aliased_input()

    def test_input_single_tensor(self):
        self.assertIsInstance(
            paddle.incubate.autograd.vjp(paddle.tanh, paddle.rand((3, 4)))[1],
            paddle.fluid.framework.Variable,
        )


@utils.place(config.DEVICES)
@utils.parameterize(
    (utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'expected_exception'),
    (
        (
            'v_shape_not_equal_ys',
            utils.square,
            np.random.rand(3),
            np.random.rand(1),
            RuntimeError,
        ),
    ),
)
class TestVJPException(unittest.TestCase):
    def test_vjp(self):
        with self.assertRaises(self.expected_exception):
            paddle.incubate.autograd.vjp(
                self.fun, paddle.to_tensor(self.xs), paddle.to_tensor(self.v)
            )
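# Helper that materializes the full Jacobian of f at `inputs` by probing
# with one-hot vectors: jvp with a one-hot tangent recovers one Jacobian
# column per call, while vjp with a one-hot cotangent recovers one row per
# call (transposed at the end so both modes yield the same matrix). TestJVP
# uses this to check forward mode against reverse mode.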
def jac(grad_fn, f, inputs):
    assert grad_fn in [
        paddle.incubate.autograd.vjp,
        paddle.incubate.autograd.jvp,
    ]
    if grad_fn is paddle.incubate.autograd.jvp:
        vs = [paddle.zeros_like(x) for x in inputs]
    else:
        outputs = f(*inputs)
        if isinstance(outputs, paddle.Tensor):
            outputs = [outputs]
        vs = [paddle.zeros_like(y) for y in outputs]
    JJ_cols = []
    for i, v in enumerate(vs):
        v = v.flatten()
        for j in range(len(v)):
            _v = paddle.zeros_like(v).detach()
            _v[j] = 1.0
            _v = _v.reshape(vs[i].shape)
            _vs = vs.copy()
            _vs[i] = _v
            _, grads = grad_fn(f, inputs, _vs)
            d_outs = paddle.concat([d_out.flatten() for d_out in grads])
            JJ_cols.append(d_outs)
    # JJ is the fully unrolled jacobian
    JJ = paddle.stack(JJ_cols)
    if grad_fn is paddle.incubate.autograd.vjp:
        JJ = JJ.t()
    return JJ


class TestJVP(TestAutogradFunctional):
    def func_jvp_i1o1(self):
        test_cases = [
            [reduce, 'A'],  # noqa
            [reduce_dim, 'A'],  # noqa
        ]  # noqa
        for f, inputs in test_cases:
            inputs = self.gen_inputs(inputs)
            forward_jac = jac(paddle.incubate.autograd.jvp, f, inputs)
            reverse_jac = jac(paddle.incubate.autograd.vjp, f, inputs)
            self.check_results(forward_jac, reverse_jac)

    def func_jvp_i2o1(self):
        test_cases = [  # noqa
            [matmul, ['A', 'B']],  # noqa
        ]  # noqa
        for f, inputs in test_cases:
            inputs = self.gen_inputs(inputs)
            forward_jac = jac(paddle.incubate.autograd.jvp, f, inputs)
            reverse_jac = jac(paddle.incubate.autograd.vjp, f, inputs)
            self.check_results(forward_jac, reverse_jac)

    def func_jvp_i2o2(self):
        test_cases = [  # noqa
            [o2, ['A', 'A']],  # noqa
        ]  # noqa
        for f, inputs in test_cases:
            inputs = self.gen_inputs(inputs)
            forward_jac = jac(paddle.incubate.autograd.jvp, f, inputs)
            reverse_jac = jac(paddle.incubate.autograd.vjp, f, inputs)
            self.check_results(forward_jac, reverse_jac)

    def func_jvp_i2o2_omitting_v(self):
        test_cases = [  # noqa
            [o2, ['A', 'A']],  # noqa
        ]  # noqa
        for f, inputs in test_cases:
            inputs = self.gen_inputs(inputs)
            results_omitting_v = paddle.incubate.autograd.jvp(f, inputs)
            v = [paddle.ones_like(x) for x in inputs]
            results_with_v = paddle.incubate.autograd.jvp(f, inputs, v)
            self.check_results(results_omitting_v, results_with_v)

    def test_all_cases(self):
        self.func_jvp_i1o1()
        self.func_jvp_i2o1()
        self.func_jvp_i2o2()
        self.func_jvp_i2o2_omitting_v()


@utils.place(config.DEVICES)
@utils.parameterize(
    (utils.TEST_CASE_NAME, 'func', 'xs'),
    (
        ('1d_in_1d_out', utils.square, np.array([2.0, 3.0])),
        ('3d_in_3d_out', utils.square, np.random.rand(2, 3, 4)),
        ('single_in_single_out', utils.square, np.random.rand(2, 3)),
        (
            'multi_in_single_out',
            paddle.matmul,
            (np.random.rand(2, 2), np.random.rand(2, 2)),
        ),
    ),
)
class TestJacobianNoBatch(unittest.TestCase):
    def setUp(self):
        self._dtype = (
            self.xs[0].dtype
            if isinstance(self.xs, typing.Sequence)
            else self.xs.dtype
        )
        self._eps = (
            config.TOLERANCE.get(str(self._dtype))
            .get("first_order_grad")
            .get("eps")
        )
        self._rtol = (
            config.TOLERANCE.get(str(self._dtype))
            .get("first_order_grad")
            .get("rtol")
        )
        self._atol = (
            config.TOLERANCE.get(str(self._dtype))
            .get("first_order_grad")
            .get("atol")
        )

    def test_jacobian(self):
        xs = (
            [paddle.to_tensor(x) for x in self.xs]
            if isinstance(self.xs, typing.Sequence)
            else paddle.to_tensor(self.xs)
        )
        self._actual = paddle.incubate.autograd.Jacobian(self.func, xs, False)
        self._expected = self._get_expected()

        Index = collections.namedtuple('Index', ('type', 'value'))
        indexes = (
            Index('all', (slice(0, None, None), slice(0, None, None))),
            Index('row', (0, slice(0, None, None))),
            Index('col', (slice(0, None, None), 0)),
            Index('multi-row', (slice(0, 2, 1), slice(0, None, None))),
        )
        self.assertEqual(self._actual[:].numpy().dtype, self._expected.dtype)
        for index in indexes:
            np.testing.assert_allclose(
                self._actual.__getitem__(index.value),
                self._expected.__getitem__(index.value),
                rtol=self._rtol,
                atol=self._atol,
                err_msg=f'Testcase {index.type} index not passed, value is {index.value}',
            )

    def _get_expected(self):
        xs = (
            [paddle.to_tensor(x) for x in self.xs]
            if isinstance(self.xs, typing.Sequence)
            else paddle.to_tensor(self.xs)
        )
        jac = utils._compute_numerical_jacobian(
            self.func, xs, self._eps, self._dtype
        )
        return utils._np_concat_matrix_sequence(jac, utils.MatrixFormat.NM)
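# Batch-first Jacobian: is_batched=True treats the leading axis as a batch
# dimension, so the numerical reference is transposed from NBM to BNM
# (batch-first) layout before comparison.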
@utils.place(config.DEVICES)
@utils.parameterize(
    (utils.TEST_CASE_NAME, 'func', 'xs'),
    (
        (
            '1d_in_1d_out',
            utils.square,
            np.array([[1.0, 2.0, 3.0], [3.0, 4.0, 3.0]]),
        ),
        ('3d_in_3d_out', utils.square, np.random.rand(2, 3, 4)),
        ('multi_in_single_out', utils.square, np.random.rand(2, 3)),
    ),
)
class TestJacobianBatchFirst(unittest.TestCase):
    def setUp(self):
        self._dtype = (
            self.xs[0].dtype
            if isinstance(self.xs, typing.Sequence)
            else self.xs.dtype
        )
        self._eps = (
            config.TOLERANCE.get(str(self._dtype))
            .get("first_order_grad")
            .get("eps")
        )
        self._rtol = (
            config.TOLERANCE.get(str(self._dtype))
            .get("first_order_grad")
            .get("rtol")
        )
        self._atol = (
            config.TOLERANCE.get(str(self._dtype))
            .get("first_order_grad")
            .get("atol")
        )

    def test_jacobian(self):
        xs = (
            [paddle.to_tensor(x) for x in self.xs]
            if isinstance(self.xs, typing.Sequence)
            else paddle.to_tensor(self.xs)
        )
        self._actual = paddle.incubate.autograd.Jacobian(self.func, xs, True)
        self._expected = self._get_expected()

        Index = collections.namedtuple('Index', ('type', 'value'))
        indexes = (
            Index(
                'all',
                (
                    slice(0, None, None),
                    slice(0, None, None),
                    slice(0, None, None),
                ),
            ),
            Index('row', (slice(0, None, None), 0, slice(0, None, None))),
            Index('col', (slice(0, None, None), slice(0, None, None), 0)),
            Index(
                'batch',
                (slice(0, 2, None), slice(0, None, None), slice(0, None, None)),
            ),
            Index(
                'multi_row',
                (slice(0, 1, None), slice(0, 2, 1), slice(0, None, None)),
            ),
        )
        self.assertEqual(self._actual[:].numpy().dtype, self._expected.dtype)
        for index in indexes:
            np.testing.assert_allclose(
                self._actual.__getitem__(index.value),
                self._expected.__getitem__(index.value),
                rtol=self._rtol,
                atol=self._atol,
                err_msg=f'Testcase {index.type} index not passed, value is {index.value}',
            )

    def _get_expected(self):
        xs = (
            [paddle.to_tensor(x) for x in self.xs]
            if isinstance(self.xs, typing.Sequence)
            else paddle.to_tensor(self.xs)
        )
        jac = utils._compute_numerical_batch_jacobian(
            self.func, xs, self._eps, self._dtype, False
        )
        jac = utils._np_concat_matrix_sequence(jac, utils.MatrixFormat.NBM)
        return utils._np_transpose_matrix_format(
            jac, utils.MatrixFormat.NBM, utils.MatrixFormat.BNM
        )
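# Hessian tests compare paddle.incubate.autograd.Hessian against a
# finite-difference reference from utils._compute_numerical_hessian.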
class TestHessianNoBatch(unittest.TestCase):
    @classmethod
    def setUpClass(self):
        self.shape = (2, 2)
        self.dtype = 'float32'
        self.np_dtype = np.float32
        self.numerical_delta = (
            config.TOLERANCE.get(self.dtype).get("second_order_grad").get("eps")
        )
        self.rtol = (
            config.TOLERANCE.get(self.dtype)
            .get("second_order_grad")
            .get("rtol")
        )
        self.atol = (
            config.TOLERANCE.get(self.dtype)
            .get("second_order_grad")
            .get("atol")
        )
        self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
        self.y = paddle.rand(shape=self.shape, dtype=self.dtype)

    def func_single_input(self):
        def func(x):
            return paddle.sum(paddle.matmul(x, x))

        numerical_hessian = utils._compute_numerical_hessian(
            func, self.x, self.numerical_delta, self.np_dtype
        )
        numerical_hessian = utils._np_concat_matrix_sequence(numerical_hessian)
        self.x.stop_gradient = False
        hessian = paddle.incubate.autograd.Hessian(func, self.x)
        np.testing.assert_allclose(
            hessian[:].numpy(), numerical_hessian, self.rtol, self.atol
        )

    def func_multi_input(self):
        def func(x, y):
            return paddle.sum(paddle.matmul(x, y))

        numerical_hessian = utils._compute_numerical_hessian(
            func, [self.x, self.y], self.numerical_delta, self.np_dtype
        )
        numerical_hessian = utils._np_concat_matrix_sequence(numerical_hessian)
        self.x.stop_gradient = False
        self.y.stop_gradient = False
        hessian = paddle.incubate.autograd.Hessian(func, [self.x, self.y])
        np.testing.assert_allclose(
            hessian[:].numpy(),
            numerical_hessian,
            rtol=self.rtol,
            atol=self.atol,
        )

    def func_allow_unused_true(self):
        def func(x, y):
            return paddle.sum(paddle.matmul(x, x))

        numerical_hessian = utils._compute_numerical_hessian(
            func, [self.x, self.y], self.numerical_delta, self.np_dtype
        )
        numerical_hessian = utils._np_concat_matrix_sequence(numerical_hessian)
        self.x.stop_gradient = False
        self.y.stop_gradient = False
        hessian = paddle.incubate.autograd.Hessian(func, [self.x, self.y])
        np.testing.assert_allclose(
            hessian[:].numpy(), numerical_hessian, self.rtol, self.atol
        )

    def func_create_graph_true(self):
        def func(x):
            return paddle.sum(F.sigmoid(x))

        numerical_hessian = utils._compute_numerical_hessian(
            func, self.x, self.numerical_delta, self.np_dtype
        )
        numerical_hessian = utils._np_concat_matrix_sequence(numerical_hessian)
        self.x.stop_gradient = False
        hessian = paddle.incubate.autograd.Hessian(func, self.x)
        assert not hessian[:].stop_gradient
        np.testing.assert_allclose(
            hessian[:].numpy(), numerical_hessian, self.rtol, self.atol
        )

    def func_out_not_single(self):
        def func(x):
            return x * x

        with self.assertRaises(RuntimeError):
            paddle.incubate.autograd.Hessian(func, paddle.ones([3]))

    def test_all_cases(self):
        self.setUpClass()
        self.func_single_input()
        self.func_multi_input()
        self.func_allow_unused_true()
        self.func_create_graph_true()
        self.func_out_not_single()
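# Batch-first Hessian: func maps a (nbatch, nrow) input to a (nbatch, 1)
# output, and results are rearranged between NBM and BNM layouts before
# comparison with the numerical reference.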
class TestHessianBatchFirst(unittest.TestCase):
    @classmethod
    def setUpClass(self):
        self.x_shape = (5, 2)
        self.weight_shape = (2, 4)
        self.y_shape = (5, 2)
        self.nbatch, self.nrow = 5, 2
        self.dtype = 'float32'
        self.np_dtype = np.float32
        self.numerical_delta = (
            config.TOLERANCE.get(self.dtype).get('second_order_grad').get('eps')
        )
        self.rtol = (
            config.TOLERANCE.get(self.dtype)
            .get('second_order_grad')
            .get('rtol')
        )
        self.atol = (
            config.TOLERANCE.get(self.dtype)
            .get('second_order_grad')
            .get('atol')
        )
        self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype)
        self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype)
        self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype)

    def func_single_input(self):
        def func(x):
            return paddle.matmul(x * x, self.weight)[:, 0:1]

        expected = utils._compute_numerical_batch_hessian(
            func, self.x, self.numerical_delta, self.np_dtype
        )

        H = paddle.incubate.autograd.Hessian(func, self.x, is_batched=True)
        actual = utils._np_transpose_matrix_format(
            H[:].numpy(), utils.MatrixFormat.BNM, utils.MatrixFormat.NBM
        )
        actual = actual.reshape((H.shape[1], -1))

        np.testing.assert_allclose(actual, expected, self.rtol, self.atol)

    def func_multi_input(self):
        def func(x, y):
            return paddle.matmul(x * x * y * y, self.weight)[:, 0:1]

        xs_len = 2
        expected = utils._compute_numerical_batch_hessian(
            func, [self.x, self.y], self.numerical_delta, self.np_dtype
        )
        expected = np.reshape(
            np.array(expected),
            (xs_len, xs_len, self.nrow, self.nbatch, self.nrow),
        )
        expected = [[n for n in row] for row in expected]
        expected = utils._np_concat_matrix_sequence(expected)

        self.x.stop_gradient = False
        self.y.stop_gradient = False
        H = paddle.incubate.autograd.Hessian(
            func, [self.x, self.y], is_batched=True
        )
        actual = utils._np_transpose_matrix_format(
            H[:].numpy(), utils.MatrixFormat.BNM, utils.MatrixFormat.NBM
        )

        np.testing.assert_allclose(actual, expected, self.rtol, self.atol)

    def func_allow_unused(self):
        def func(x, y):
            return paddle.matmul(x * x, self.weight)[:, 0:1]

        xs_len = 2
        expected = utils._compute_numerical_batch_hessian(
            func, [self.x, self.y], self.numerical_delta, self.np_dtype
        )
        expected = np.reshape(
            np.array(expected),
            (xs_len, xs_len, self.nrow, self.nbatch, self.nrow),
        )
        expected = [[n for n in row] for row in expected]
        expected = utils._np_concat_matrix_sequence(expected)
        expected = utils._np_transpose_matrix_format(
            expected, utils.MatrixFormat.NBM, utils.MatrixFormat.BNM
        )

        actual = paddle.incubate.autograd.Hessian(
            func, [self.x, self.y], is_batched=True
        )[:]

        np.testing.assert_allclose(
            actual, expected, rtol=self.rtol, atol=self.atol
        )

    def func_stop_gradient(self):
        def func(x):
            return paddle.matmul(x * x, self.weight)[:, 0:1]

        expected = utils._compute_numerical_batch_hessian(
            func, self.x, self.numerical_delta, self.np_dtype
        )

        x = self.x.clone()
        x.stop_gradient = True
        H = paddle.incubate.autograd.Hessian(func, self.x, is_batched=True)[:]
        actual = utils._np_transpose_matrix_format(
            H[:].numpy(), utils.MatrixFormat.BNM, utils.MatrixFormat.NBM
        )
        actual = actual.reshape((H.shape[1], -1))

        np.testing.assert_allclose(actual, expected, self.rtol, self.atol)

    def func_out_not_single(self):
        def func(x):
            return x * x

        with self.assertRaises(RuntimeError):
            paddle.incubate.autograd.Hessian(
                func, paddle.ones((3, 3)), is_batched=True
            )

    def test_all_cases(self):
        self.setUpClass()
        self.func_single_input()
        self.func_multi_input()
        self.func_allow_unused()
        self.func_stop_gradient()
        self.func_out_not_single()


if __name__ == "__main__":
    np.random.seed(2022)
    unittest.main()