diff --git a/python/paddle/autograd/functional.py b/python/paddle/autograd/functional.py index 8dda4811d1b26355ce48736fef1fefc8527687d5..aa3e99978b72a0412e8199a6d4a5b51506a9ee3d 100644 --- a/python/paddle/autograd/functional.py +++ b/python/paddle/autograd/functional.py @@ -17,6 +17,7 @@ import typing import paddle from paddle.fluid import framework +from paddle.autograd.utils import as_tensors def vjp(func, xs, v=None): @@ -346,10 +347,16 @@ class _Jacobian(object): """ def __init__(self, func, xs): - self._xs = _separate(xs) - self._ys = func(*_as_tensors(self._xs)) - self._flatten_xs = self._flatten(_as_tensors(self._xs)) - self._flatten_ys = self._flatten(_as_tensors(self._ys)) + # Skip separating in prim mode temporarily, as detach and clone are not + # primitive operators. + if not paddle.fluid._non_static_mode( + ) and paddle.incubate.autograd.prim_enabled(): + self._xs = xs + else: + self._xs = _separate(xs) + self._ys = func(*as_tensors(self._xs)) + self._flatten_xs = self._flatten(as_tensors(self._xs)) + self._flatten_ys = self._flatten(as_tensors(self._ys)) self._cache = {} @property @@ -385,9 +392,13 @@ class _Jacobian(object): return self._cached_evaluate( indexes[self._lazy_axis])[other_indexes] lazy_indexes = self._lazy_indexes(indexes) - part_jac = paddle.stack( + # Using concat and reshape to replace stack operator temporarily, as + # it is not a primitive operator. + shape = list(self.shape) + shape[self._lazy_axis] = len(lazy_indexes) + part_jac = paddle.concat( [self._cached_evaluate(i) for i in lazy_indexes], - axis=self._lazy_axis) + axis=self._lazy_axis).reshape(shape) return part_jac[self._shifted_indexes(indexes, len(lazy_indexes))] def _cached_evaluate(self, k): @@ -449,7 +460,7 @@ class _JacobianBatchLast(_Jacobian): def _flatten(self, xs): return paddle.concat( - tuple(x.reshape((-1, x.shape[-1])) for x in _as_tensors(xs)), 0) + tuple(x.reshape((-1, x.shape[-1])) for x in as_tensors(xs)), 0) def _evaluate(self, row): return self._flatten(_grad(self._flatten_ys[row, :], self._xs)) @@ -475,7 +486,7 @@ class _JacobianBatchFirst(_Jacobian): def _flatten(self, xs): return paddle.concat( - tuple(x.reshape((x.shape[0], -1)) for x in _as_tensors(xs)), 1) + tuple(x.reshape((x.shape[0], -1)) for x in as_tensors(xs)), 1) def _evaluate(self, row_index): return self._flatten(_grad(self._flatten_ys[:, row_index], self._xs)) @@ -526,10 +537,6 @@ def _multi_index(indexes, shape): return tuple(positive_indexes) -def _as_tensors(xs): - return (xs, ) if isinstance(xs, framework.Variable) else xs - - def _stack_tensor_or_return_none(origin_list): assert len(origin_list) > 0, "Can't not stack an empty list" return paddle.stack(origin_list, axis=0) if isinstance( @@ -683,7 +690,7 @@ def _check_v_shape(v, refs): if v is None: return - v, refs = _as_tensors(v), _as_tensors(refs) + v, refs = as_tensors(v), as_tensors(refs) if len(refs) != len(v): raise RuntimeError(f"The argument v is a tuple of invalid length:" f"should be {len(refs)} but got {len(v)}.") @@ -805,8 +812,8 @@ def jacobian(func, inputs, create_graph=False, allow_unused=False): # [0., 0., 0., 2.]]), None)) ''' - inputs = _as_tensors(inputs) - outputs = _as_tensors(func(*inputs)) + inputs = as_tensors(inputs) + outputs = as_tensors(func(*inputs)) fin_size = len(inputs) fout_size = len(outputs) flat_outputs = tuple( @@ -942,8 +949,8 @@ def batch_jacobian(func, inputs, create_graph=False, allow_unused=False): ''' - inputs = _as_tensors(inputs) - outputs = _as_tensors(func(*inputs)) + inputs = as_tensors(inputs) + outputs = 
as_tensors(func(*inputs)) batch_size = inputs[0].shape[0] for input in inputs: @@ -1103,7 +1110,7 @@ def batch_hessian(func, inputs, create_graph=False, allow_unused=False): # [0., 2., 0., 2., 0., 2., 0., 2.]]), None), (None, None)) ''' - inputs = _as_tensors(inputs) + inputs = as_tensors(inputs) outputs = func(*inputs) batch_size = inputs[0].shape[0] for input in inputs: @@ -1234,7 +1241,7 @@ def hessian(func, inputs, create_graph=False, allow_unused=False): # [0., 1., 1., 2.]]), None), (None, None)) ''' - inputs = _as_tensors(inputs) + inputs = as_tensors(inputs) outputs = func(*inputs) assert isinstance(outputs, paddle.Tensor) and outputs.shape == [ 1 @@ -1339,12 +1346,12 @@ def vhp(func, inputs, v=None, create_graph=False, allow_unused=False): # [[8., 8.], # [8., 8.]]), None]) ''' - xs = _as_tensors(inputs) + xs = as_tensors(inputs) if v is not None: - v = _as_tensors(v) + v = as_tensors(v) xs, v = _separate(xs), _separate(v) outputs = func(*xs) - ys = _as_tensors(outputs) + ys = as_tensors(outputs) assert len(ys) == 1 and isinstance( ys[0], framework.Variable ) and ys[0].shape == [ diff --git a/python/paddle/autograd/utils.py b/python/paddle/autograd/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6b8865f4d7df012e902e5e24db57a32bec708074 --- /dev/null +++ b/python/paddle/autograd/utils.py @@ -0,0 +1,26 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import typing + +from paddle.fluid import framework + + +def as_tensors(xs): + if isinstance(xs, framework.Variable): + return (xs, ) + elif isinstance(xs, typing.Sequence): + return tuple(xs) + else: + return xs diff --git a/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt b/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt index c2ccad7dd24f0510ce332a5ff0527b51e8bbd69d..832ecc61ee19062194b66dcced271f586f3b4bdb 100644 --- a/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt @@ -5,10 +5,19 @@ file( string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") set(GC_ENVS FLAGS_eager_delete_tensor_gb=0.0) +if(WIN32) + # TODO: Fix these unittests failed on Windows + list(REMOVE_ITEM TEST_OPS test_autograd_functional_prim) + list(REMOVE_ITEM TEST_OPS test_primapi) +endif() + foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS}) endforeach() -set_tests_properties(test_autograd_functional_dynamic PROPERTIES TIMEOUT 160) +set_tests_properties(test_autograd_functional_dynamic PROPERTIES TIMEOUT 200) set_tests_properties(test_autograd_functional_static PROPERTIES TIMEOUT 160) set_tests_properties(test_gradients_and_minimize PROPERTIES TIMEOUT 60) +if(NOT WIN32) + set_tests_properties(test_autograd_functional_prim PROPERTIES TIMEOUT 60) +endif() diff --git a/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_dynamic.py b/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_dynamic.py index 6e097b6335bcc77ebd7e6da826b68873f12bbfe3..a98b509f963c7c074427c0ec05a0a52b8b203a1c 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_dynamic.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_dynamic.py @@ -21,7 +21,7 @@ import paddle import paddle.fluid as fluid import paddle.compat as cpt import paddle.nn.functional as F -from paddle.autograd.functional import _as_tensors +from paddle.autograd.utils import as_tensors from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph, _in_eager_without_dygraph_check import config @@ -33,7 +33,7 @@ from utils import matmul, mul, nested, o2, pow, reduce, reduce_dim, unuse def make_v(f, inputs): - outputs = _as_tensors(f(*inputs)) + outputs = as_tensors(f(*inputs)) return [paddle.ones_like(x) for x in outputs] diff --git a/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_prim.py b/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_prim.py new file mode 100644 index 0000000000000000000000000000000000000000..f75460df6b52dcf44e2a426920ebd155d66cc76b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_prim.py @@ -0,0 +1,149 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import typing +import unittest + +import numpy as np +import paddle + +import config +import utils + + +@utils.place(config.DEVICES) +@utils.parameterize((utils.TEST_CASE_NAME, 'fun', 'args', 'dtype'), ( + ('unary_float32', paddle.tanh, (np.random.rand(2, 3), ), 'float32'), + ('binary_float32', paddle.matmul, + (np.random.rand(2, 3), np.random.rand(3, 2)), 'float32'), + ('unary_float64', paddle.tanh, (np.random.rand(2, 3), ), 'float64'), + ('binary_float64', paddle.matmul, + (np.random.rand(2, 3), np.random.rand(3, 2)), 'float64'), +)) +class TestJacobianPrim(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.args = [arg.astype(cls.dtype) for arg in cls.args] + cls._rtol = config.TOLERANCE.get( + cls.dtype).get('first_order_grad').get('rtol') + cls._atol = config.TOLERANCE.get( + cls.dtype).get('first_order_grad').get('atol') + + def setUp(self): + paddle.enable_static() + paddle.incubate.autograd.enable_prim() + + def tearDown(self): + paddle.incubate.autograd.disable_prim() + paddle.disable_static() + + def test_jacobian_prim(self): + + def wrapper(fun, args): + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + static_args = [ + paddle.static.data(f'arg{i}', arg.shape, self.dtype) + for i, arg in enumerate(args) + ] + for arg in static_args: + arg.stop_gradient = False + jac = paddle.incubate.autograd.Jacobian(fun, static_args)[:] + if paddle.incubate.autograd.prim_enabled(): + paddle.incubate.autograd.prim2orig() + exe = paddle.static.Executor() + exe.run(sp) + [jac] = exe.run(mp, + feed={f'arg{i}': arg + for i, arg in enumerate(args)}, + fetch_list=[jac]) + return jac + + paddle.incubate.autograd.enable_prim() + prim_jac = wrapper(self.fun, self.args) + paddle.incubate.autograd.disable_prim() + orig_jac = wrapper(self.fun, self.args) + + np.testing.assert_allclose(orig_jac, + prim_jac, + rtol=self._rtol, + atol=self._atol) + + +@utils.place(config.DEVICES) +@utils.parameterize((utils.TEST_CASE_NAME, 'fun', 'args', 'dtype'), ( + ('unary_float32', paddle.tanh, (np.random.rand(1), ), 'float32'), + ('binary_float32', paddle.multiply, + (np.random.rand(1), np.random.rand(1)), 'float32'), + ('unary_float64', paddle.tanh, (np.random.rand(1), ), 'float64'), + ('binary_float64', paddle.multiply, + (np.random.rand(1), np.random.rand(1)), 'float64'), +)) +class TestHessianPrim(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.args = [arg.astype(cls.dtype) for arg in cls.args] + cls._rtol = config.TOLERANCE.get( + cls.dtype).get('second_order_grad').get('rtol') + cls._atol = config.TOLERANCE.get( + cls.dtype).get('second_order_grad').get('atol') + + def setUp(self): + paddle.enable_static() + paddle.incubate.autograd.enable_prim() + + def tearDown(self): + paddle.incubate.autograd.disable_prim() + paddle.disable_static() + + def test_jacobian_prim(self): + + def wrapper(fun, args): + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + static_args = [ + paddle.static.data(f'arg{i}', arg.shape, self.dtype) + for i, arg in enumerate(args) + ] + for arg in static_args: + arg.stop_gradient = False + hessian = paddle.incubate.autograd.Hessian(fun, static_args)[:] + if paddle.incubate.autograd.prim_enabled(): + paddle.incubate.autograd.prim2orig() + exe = paddle.static.Executor() + exe.run(sp) + [hessian + ] = exe.run(mp, + feed={f'arg{i}': arg + for i, arg in enumerate(args)}, + fetch_list=[hessian]) + return hessian + + paddle.incubate.autograd.enable_prim() + 
prim_jac = wrapper(self.fun, self.args) + paddle.incubate.autograd.disable_prim() + orig_jac = wrapper(self.fun, self.args) + + np.testing.assert_allclose(orig_jac, + prim_jac, + rtol=self._rtol, + atol=self._atol) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_static.py b/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_static.py index 06d3bb5eb2495f5fab76b32da18f0d3d168fed46..4e01ad5382c91617c1829d0c417bed2af5899843 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_static.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_static.py @@ -23,7 +23,6 @@ import config import utils from utils import (_compute_numerical_batch_jacobian, _compute_numerical_jacobian) -from paddle.autograd.functional import _as_tensors paddle.enable_static() @@ -58,7 +57,7 @@ class TestVJP(unittest.TestCase): sp = paddle.static.Program() mp = paddle.static.Program() with paddle.static.program_guard(mp, sp): - feed, static_xs, static_v = gen_static_data_and_feed( + feed, static_xs, static_v = utils.gen_static_data_and_feed( self.xs, self.v, stop_gradient=self.stop_gradient) ys, xs_grads = paddle.autograd.vjp(self.fun, static_xs, static_v) exe.run(sp) @@ -69,7 +68,7 @@ class TestVJP(unittest.TestCase): sp = paddle.static.Program() mp = paddle.static.Program() with paddle.static.program_guard(mp, sp): - feed, static_xs, static_v = gen_static_data_and_feed( + feed, static_xs, static_v = utils.gen_static_data_and_feed( self.xs, self.v, False) ys = self.fun(*static_xs) if isinstance( static_xs, typing.Sequence) else self.fun(static_xs) @@ -102,7 +101,7 @@ class TestVJPException(unittest.TestCase): sp = paddle.static.Program() mp = paddle.static.Program() with paddle.static.program_guard(mp, sp): - feed, static_xs, static_v = gen_static_data_and_feed( + feed, static_xs, static_v = utils.gen_static_data_and_feed( self.xs, self.v) ys, xs_grads = paddle.autograd.vjp(self.fun, static_xs, static_v) self.exe.run(sp) @@ -113,37 +112,6 @@ class TestVJPException(unittest.TestCase): self._vjp() -def gen_static_data_and_feed(xs, v, stop_gradient=True): - feed = {} - if isinstance(xs, typing.Sequence): - static_xs = [] - for i, x in enumerate(xs): - x = paddle.static.data(f"x{i}", x.shape, x.dtype) - x.stop_gradient = stop_gradient - static_xs.append(x) - feed.update({f'x{idx}': value for idx, value in enumerate(xs)}) - else: - static_xs = paddle.static.data('x', xs.shape, xs.dtype) - static_xs.stop_gradient = stop_gradient - feed.update({'x': xs}) - - if isinstance(v, typing.Sequence): - static_v = [] - for i, e in enumerate(v): - e = paddle.static.data(f'v{idx}', v.shape, v.dtype) - e.stop_gradient = stop_gradient - static_v.append(e) - feed.update({f'v{idx}': value for idx, value in v}) - elif v is not None: - static_v = paddle.static.data('v', v.shape, v.dtype) - static_v.stop_gradient = stop_gradient - feed.update({'v': v}) - else: - static_v = v - - return feed, static_xs, static_v - - def approx_jacobian(f, xs, dtype, eps=1e-5, batch=False): r"""Computes an approximate Jacobian matrix of a multi-valued function using finite differences. 
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py new file mode 100644 index 0000000000000000000000000000000000000000..0137f4103fbb30c16607440486da7f5e861d2c99 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py @@ -0,0 +1,129 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing +import unittest + +import numpy as np +import paddle +from paddle.incubate.autograd import primapi + +import config +import utils + + +@utils.place(config.DEVICES) +@utils.parameterize( + (utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'), + (('matmul', paddle.matmul, + (np.random.rand(2, 3), np.random.rand(3, 2)), None, 'float32'), + ('multiply', paddle.multiply, + (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'), + ('add', paddle.add, + (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'), + ('input_not_sequence', paddle.tanh, + (np.random.rand(5, 5), ), None, 'float64'), + ('input_gradients_not_none', paddle.matmul, + (np.random.rand(3, 3), np.random.rand(3, 3)), + (np.random.rand(3, 3), np.random.rand(3, 3)), 'float64'))) +class TestForwardGradients(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.xs = tuple(x.astype(cls.dtype) for x in cls.xs) + cls._rtol = config.TOLERANCE.get(str( + cls.dtype)).get("first_order_grad").get("rtol") + cls._atol = config.TOLERANCE.get(str( + cls.dtype)).get("first_order_grad").get("atol") + + def setUp(self): + paddle.enable_static() + paddle.incubate.autograd.enable_prim() + + def tearDown(self): + paddle.incubate.autograd.disable_prim() + paddle.disable_static() + + def test_forward_gradients(self): + + def expected(): + paddle.incubate.autograd.disable_prim() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + feed, static_xs, static_v = utils.gen_static_data_and_feed( + self.xs, self.v, stop_gradient=False) + _, ys_grad = paddle.autograd.jvp(self.fun, static_xs, static_v) + exe = paddle.static.Executor() + exe.run(sp) + out = exe.run(mp, feed=feed, fetch_list=ys_grad) + paddle.incubate.autograd.enable_prim() + return out + + def actual(): + paddle.incubate.autograd.enable_prim() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + feed, static_xs, static_v = utils.gen_static_data_and_feed( + self.xs, self.v, stop_gradient=False) + ys = self.fun(*static_xs) if isinstance( + static_xs, typing.Sequence) else self.fun(static_xs) + ys_grad = primapi.forward_gradients(ys, static_xs, static_v) + paddle.incubate.autograd.prim2orig(mp.block(0)) + exe = paddle.static.Executor() + exe.run(sp) + out = exe.run(mp, feed=feed, fetch_list=ys_grad) + paddle.incubate.autograd.disable_prim() + return out + + actual = actual() + expected = expected() + self.assertEqual(type(actual), type(expected)) + np.testing.assert_allclose(np.concatenate(actual), + 
np.concatenate(expected), + rtol=self._rtol, + atol=self._atol) + + def test_prim_disabled(self): + paddle.incubate.autograd.disable_prim() + sp = paddle.static.Program() + mp = paddle.static.Program() + with self.assertRaises(RuntimeError): + with paddle.static.program_guard(mp, sp): + feed, static_xs, static_v = utils.gen_static_data_and_feed( + self.xs, self.v, stop_gradient=False) + ys = self.fun(*static_xs) if isinstance( + static_xs, typing.Sequence) else self.fun(static_xs) + ys_grad = primapi.forward_gradients(ys, static_xs, static_v) + paddle.incubate.autograd.prim2orig(mp.block(0)) + exe = paddle.static.Executor() + exe.run(sp) + exe.run(mp, feed=feed, fetch_list=ys_grad) + paddle.incubate.autograd.enable_prim() + + def test_illegal_param(self): + paddle.incubate.autograd.enable_prim() + with self.assertRaises(TypeError): + primapi.forward_gradients(1, paddle.static.data('inputs', + shape=[1])) + + with self.assertRaises(TypeError): + primapi.forward_gradients(paddle.static.data('targets', shape=[1]), + 1) + paddle.incubate.autograd.disable_prim() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/utils.py b/python/paddle/fluid/tests/unittests/autograd/utils.py index 4105ea2672be019c31d83492c3cf880da828722a..8a0e51f60f47bfa48e322f336852d248626b2836 100644 --- a/python/paddle/fluid/tests/unittests/autograd/utils.py +++ b/python/paddle/fluid/tests/unittests/autograd/utils.py @@ -22,7 +22,7 @@ import contextlib import collections import numpy as np import paddle -from paddle.autograd.functional import _as_tensors +from paddle.autograd.utils import as_tensors ########################################################## @@ -57,8 +57,8 @@ def _set_item(t, idx, value): def _compute_numerical_jacobian(func, xs, delta, np_dtype): - xs = list(_as_tensors(xs)) - ys = list(_as_tensors(func(*xs))) + xs = list(as_tensors(xs)) + ys = list(as_tensors(func(*xs))) fin_size = len(xs) fout_size = len(ys) jacobian = list([] for _ in range(fout_size)) @@ -74,11 +74,11 @@ def _compute_numerical_jacobian(func, xs, delta, np_dtype): orig = _get_item(xs[j], q) x_pos = orig + delta xs[j] = _set_item(xs[j], q, x_pos) - ys_pos = _as_tensors(func(*xs)) + ys_pos = as_tensors(func(*xs)) x_neg = orig - delta xs[j] = _set_item(xs[j], q, x_neg) - ys_neg = _as_tensors(func(*xs)) + ys_neg = as_tensors(func(*xs)) xs[j] = _set_item(xs[j], q, orig) @@ -91,8 +91,8 @@ def _compute_numerical_jacobian(func, xs, delta, np_dtype): def _compute_numerical_hessian(func, xs, delta, np_dtype): - xs = list(_as_tensors(xs)) - ys = list(_as_tensors(func(*xs))) + xs = list(as_tensors(xs)) + ys = list(as_tensors(func(*xs))) fin_size = len(xs) hessian = list([] for _ in range(fin_size)) for i in range(fin_size): @@ -136,8 +136,8 @@ def _compute_numerical_batch_jacobian(func, np_dtype, merge_batch=True): no_batch_jacobian = _compute_numerical_jacobian(func, xs, delta, np_dtype) - xs = list(_as_tensors(xs)) - ys = list(_as_tensors(func(*xs))) + xs = list(as_tensors(xs)) + ys = list(as_tensors(func(*xs))) fin_size = len(xs) fout_size = len(ys) bs = xs[0].shape[0] @@ -164,7 +164,7 @@ def _compute_numerical_batch_jacobian(func, def _compute_numerical_batch_hessian(func, xs, delta, np_dtype): - xs = list(_as_tensors(xs)) + xs = list(as_tensors(xs)) batch_size = xs[0].shape[0] fin_size = len(xs) hessian = [] @@ -202,7 +202,7 @@ def _compute_numerical_batch_hessian(func, xs, delta, np_dtype): def _compute_numerical_vjp(func, xs, v, delta, np_dtype): - xs = _as_tensors(xs) + xs = 
as_tensors(xs) jacobian = np.array(_compute_numerical_jacobian(func, xs, delta, np_dtype)) if v is None: v = [paddle.ones_like(x) for x in xs] @@ -217,7 +217,7 @@ def _compute_numerical_vjp(func, xs, v, delta, np_dtype): def _compute_numerical_vhp(func, xs, v, delta, np_dtype): - xs = list(_as_tensors(xs)) + xs = list(as_tensors(xs)) hessian = np.array(_compute_numerical_hessian(func, xs, delta, np_dtype)) flat_v = np.array([v_el.numpy().reshape(-1) for v_el in v]) vhp = [np.zeros((_product(x.shape)), dtype=np_dtype) for x in xs] @@ -391,3 +391,37 @@ def _np_concat_matrix_sequence(src, src_format=MatrixFormat.NM): if not isinstance(src[0], typing.Sequence): src = [src] return concat_row(tuple(concat_col(xs) for xs in src)) + + +########################################################## +# Utils for generating test data. +########################################################## +def gen_static_data_and_feed(xs, v, stop_gradient=True): + feed = {} + if isinstance(xs, typing.Sequence): + static_xs = [] + for i, x in enumerate(xs): + x = paddle.static.data(f"x{i}", x.shape, x.dtype) + x.stop_gradient = stop_gradient + static_xs.append(x) + feed.update({f'x{idx}': value for idx, value in enumerate(xs)}) + else: + static_xs = paddle.static.data('x', xs.shape, xs.dtype) + static_xs.stop_gradient = stop_gradient + feed.update({'x': xs}) + + if isinstance(v, typing.Sequence): + static_v = [] + for i, e in enumerate(v): + e = paddle.static.data(f'v{i}', e.shape, e.dtype) + e.stop_gradient = stop_gradient + static_v.append(e) + feed.update({f'v{i}': value for i, value in enumerate(v)}) + elif v is not None: + static_v = paddle.static.data('v', v.shape, v.dtype) + static_v.stop_gradient = stop_gradient + feed.update({'v': v}) + else: + static_v = v + + return feed, static_xs, static_v diff --git a/python/paddle/incubate/autograd/primapi.py b/python/paddle/incubate/autograd/primapi.py new file mode 100644 index 0000000000000000000000000000000000000000..75a70b09731f2db652b214fc61eaea0117064d9d --- /dev/null +++ b/python/paddle/incubate/autograd/primapi.py @@ -0,0 +1,93 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing + +import paddle.autograd.utils as tensor_utils +import paddle.incubate.autograd.utils as prim_utils +from paddle.fluid import framework +from paddle.incubate.autograd import primx + + +@framework.static_only +def forward_gradients(targets, inputs, input_gradients=None): + """Forward mode of automatic differentiation. + + .. note:: + **ONLY available in the static mode and primitive operators.** + + Args: + targets: The target tensor or tensors + inputs: The input tensor or tensors + input_gradients: The gradient Tensor or Tensors of inputs which has + the same shape with inputs, Defaults to None, in this case is + equivalent to all ones . + + Returns: + target_gradients (Tensor|Sequence[Tensor]): The gradients for targets. + + Examples: + + .. 
code-block:: python
+
+            import numpy as np
+            import paddle
+
+            paddle.enable_static()
+            paddle.incubate.autograd.enable_prim()
+
+            startup_program = paddle.static.Program()
+            main_program = paddle.static.Program()
+
+            with paddle.static.program_guard(main_program, startup_program):
+                x = paddle.static.data('x', shape=[1], dtype='float32')
+                y = x * x
+                y_grad = paddle.incubate.autograd.forward_gradients(y, x)
+                paddle.incubate.autograd.prim2orig()
+
+            exe = paddle.static.Executor()
+            exe.run(startup_program)
+            y_grad = exe.run(main_program, feed={'x': np.array([2.]).astype('float32')}, fetch_list=[y_grad])
+            print(y_grad)
+            # [array([4.], dtype=float32)]
+
+            paddle.incubate.autograd.disable_prim()
+            paddle.disable_static()
+    """
+    if not prim_utils.prim_enabled():
+        raise RuntimeError('forward_gradients must be running on primitive '
+                           'operators, use enable_prim to turn it on.')
+
+    if not isinstance(targets, (framework.Variable, typing.Sequence)):
+        raise TypeError(f'Expected targets is Tensor|Sequence[Tensor], '
+                        f'but got {type(targets)}.')
+
+    if not isinstance(inputs, (framework.Variable, typing.Sequence)):
+        raise TypeError(f'Expected inputs is Tensor|Sequence[Tensor], '
+                        f'but got {type(inputs)}.')
+
+    ys, xs, xs_dot = tensor_utils.as_tensors(targets), tensor_utils.as_tensors(
+        inputs), tensor_utils.as_tensors(input_gradients)
+
+    block = framework.default_main_program().current_block()
+    if any(x.block != block for x in xs + ys):
+        raise RuntimeError(
+            'Variable in inputs and targets should exist in current block of '
+            'main program.')
+
+    primx.orig2prim(block)
+    ad = primx.Transform(ys[0].block)
+    _, ys_dot = ad.linearize(xs, ys, xs_dot)
+
+    return ys_dot[0] if isinstance(targets, framework.Variable) else ys_dot
diff --git a/python/paddle/incubate/autograd/primrules.py b/python/paddle/incubate/autograd/primrules.py
index 7983032f1a1248eca4c16027af6e3255ee7d04db..24e48e8c5425838d65c69e25de22f61ba2488f38 100644
--- a/python/paddle/incubate/autograd/primrules.py
+++ b/python/paddle/incubate/autograd/primrules.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import typing import paddle @@ -656,10 +657,14 @@ def split_transpose(op, check_dot, ys_bar): @REGISTER_TRANSPOSE('concat_p') def concat_transpose(op, check_dot, y_bar): xs, = op_position_inputs(op) + if not isinstance(xs, typing.Sequence): + xs = [xs] for x in xs: assert check_dot(x), 'check_dot(x) must be True' axis = op.attr('axis') sections = [x.shape[axis] for x in xs] + if len(sections) == 1: + return y_bar return split(y_bar, num_or_sections=sections, axis=axis) diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py index 5ee45116e66d85e23803e364623c6123c49e0cd3..d5037dcf64994bca69ec452a217e983a2e10d526 100644 --- a/python/paddle/incubate/autograd/primx.py +++ b/python/paddle/incubate/autograd/primx.py @@ -20,8 +20,9 @@ from paddle import compat as cpt from .primops import fill_const, add from .primreg import op_position_inputs, op_position_output, lookup_orig2prim, lookup_prim2orig from .primrules import _orig2prim, _prim2orig, _jvp, _transpose -from .utils import get_input_var_list, get_output_var_list, to_tensors, flatten, flatten_and_remove_none +from .utils import get_input_var_list, get_output_var_list, flatten, flatten_and_remove_none from collections import OrderedDict +from paddle.autograd.utils import as_tensors def topo_path(xs, ys, block=None): @@ -457,7 +458,7 @@ def _lower(block, reverse): for orig_out, new_out in zip( expand_nested_list(get_output_var_list(op)), - expand_nested_list(to_tensors(lower_fn(op, *input_args)))): + expand_nested_list(as_tensors(lower_fn(op, *input_args)))): assert not (orig_out is None) ^ ( new_out is None), "orig_out and new_out should match." vars_to_remove.add(new_out.name) @@ -591,7 +592,7 @@ def _gradients(ys, xs, ys_bar=None): xs_bar: a list gradients of input `xs` """ - ys, xs = to_tensors(ys), to_tensors(xs) + ys, xs, ys_bar = as_tensors(ys), as_tensors(xs), as_tensors(ys_bar) block = default_main_program().current_block() for el in xs + ys: assert el is None or el.block == block, f'variable in xs and ys should be None or in current block of main program' diff --git a/python/paddle/incubate/autograd/utils.py b/python/paddle/incubate/autograd/utils.py index 44bbd32bc9c322984513b80eaee69a5c91950bb8..9d6a8c4f6a36dc325065d6bb1a581b00810c4bb5 100644 --- a/python/paddle/incubate/autograd/utils.py +++ b/python/paddle/incubate/autograd/utils.py @@ -158,13 +158,6 @@ def get_output_var_list(op): ] -def to_tensors(xs): - if isinstance(xs, paddle.fluid.framework.Variable): - return [xs] - else: - return xs - - def flatten(inp): if inp is None or isinstance(inp, paddle.fluid.framework.Variable): return [inp]
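
For reference, a minimal end-to-end sketch of the prim-mode Jacobian path this patch enables, condensed from the new test_autograd_functional_prim.py; the paddle.tanh target and the [2, 3] input shape are illustrative assumptions, not part of the patch:

    import numpy as np
    import paddle

    paddle.enable_static()
    paddle.incubate.autograd.enable_prim()

    main_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    with paddle.static.program_guard(main_program, startup_program):
        x = paddle.static.data('x', shape=[2, 3], dtype='float32')
        x.stop_gradient = False
        # Build the full Jacobian; with prim enabled, _Jacobian now uses
        # concat + reshape instead of stack, as changed in functional.py.
        jac = paddle.incubate.autograd.Jacobian(paddle.tanh, x)[:]
        # Lower the remaining primitive operators before execution.
        paddle.incubate.autograd.prim2orig()

    exe = paddle.static.Executor()
    exe.run(startup_program)
    [jac_value] = exe.run(main_program,
                          feed={'x': np.random.rand(2, 3).astype('float32')},
                          fetch_list=[jac])
    print(jac_value.shape)  # (6, 6): flattened 2x3 output by flattened 2x3 input

    paddle.incubate.autograd.disable_prim()
    paddle.disable_static()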