From f6ee202fe4cd7499a628c4a5f7dbcdc60c9de2c8 Mon Sep 17 00:00:00 2001
From: WangZhen <23097963+0x45f@users.noreply.github.com>
Date: Wed, 18 May 2022 16:03:04 +0800
Subject: [PATCH] Add support for forward and reverse high-order automatic
 differentiation mechanism (#41919)

* Updated triple_grad_check func
* add todo for gradient checker and refine some comments
* remove additional code
* add test for warning in backward.py
* format python code
* support multi input in triple gradient checker
* Add matmul triple grad kernel
* Updated comments of TODO
* Supported some special tests
* Change code-format to follow CI std
* Updated gradient_checker.py
* Fix conflicts
* Removed unnecessary printing log
* Change code style to follow CI std
* merge upstream
* add priops.py
* add_p
* rm useless files
* add sub_p mul_p div_p
* add sqrt_p and tanh_p
* add reshape_p
* add broadcast_p
* Add Python primitive wrappers.
* JVP rules updated.
* JVP rules done for all the 17 primops.
* quick check and fixes.
* add jvp(op, *args)
* add broadcast_p fill_constant_p matmul_p reduce_p reshape_p transpose_p
* add split_p and concat_p
* add gather_p and scatter_add_p
* add slice_select_p and slice_assign_p
* Add transpose rules.
* add multi input check for add_p, sub_p, mul_p, div_p
* update concat_p
* Linearize and transpose in progress..
* refine gather_p and scatter_add_p
* updated.
* update transpose.
* refine slice_assign_p and slice_select_p
* init commit for lower
* Merged with primitive ops.
* small update
* add rules for orig2prim and prim2orig
* add 9 tests for prim ops
* add more tests and fix some bugs
* add more tests
* register proto
* Adding primops test.
* add shape valid check for broadcast_p op, and add keepdim attr into reduce_p op proto
* support multi input and multi output for split_p and concat_p
* Test updated.
* update
* fix slice bug for slice_select_p and slice_assign_p
* updated.
* Ops updated.
* Refactor and bug fixes.
* updated.
* finish orig2prim and prim2orig rules
* dtype for axis attr should be long int
* update dtype for axis attr int64_t
* update for iscan CI
* Update primx.
* Refactor vars in primx.
* update for lower transform
* add more shape and dtype check
* update primx.py
* change IndexTensor into int32 dtype
* update
* Fix linearize and transpose.
* Update is_dot
* Update is_dot
* Update is_dot
* add gradient aggregation, fix add_transpose.
* pass first linearize+transpose test.
* update test
* refactor op registration and primx.
* update rule for slice_assign
* try test lower
* update orig2prim and prim2orig
* pass simple lower pass
* update
* Update input types in the unit test.
* orig2prim segfault.
* 50% for adam.minimize
* test updated.
* temp fix errors in removing vars.
* primx updated.
* update for matmul_v2 and reshape2 orig2prim
* update for minimize
* Refine primrules
* Remove some code
* supporting unused and unreachable vars.
* update for use prim2orig in minimize
* fix gather and scatter_add transpose.
* Add rules UT
* update scatter_add
* Refine UT code
* fix NoneType check in topo
* Update gather_p pywrapper.
* remove useless print
* Merge tongxin PR and refine code
* re-add some tests
* rm useless print
* polish code.
* fix bug in minimize
* add get_input_var_list and get_output_var_list and use it in lower
* Fix scatter_add_p prim2orig
* Update code and fix orig2prim/prim2orig UT
* delete vars after block.desc._remove
* Improve ops and vars clean-up logic.
* fix some bugs in linearize and lower
* update tanh transpose.
* use set instead of list for var2remove
* test updated.
* polish code.
* fix dot2bar delete.
* merge tx/ad
* add indextensor_dot for gather and scatter_add
* add sorted for set
* Fix scale_orig2prim params
* fix some syntax bugs
* add global_lower_update list
* Better handling of unused vars.
* update tests.
* Fix elementwise_sub orig2prim
* support None for transpose rule
* Merge and add transform UT
* fix a bug in transpose
* Fix transpose and UT
* a hacky fix for concat op
* Fix executor place
* Refine variable name
* Add elementwise_mul orig2prim and support p_norm when p=1
* Add sqrt orig2prim rule and UT
* merge wz test
* rename files, add enable_prim, disable_prim, prim_enabled, delete global_lower_update
* fix a bug in test_ad_transform_trans
* revert modification in framework.py
* add paddle.fluid.incubate.ad_transform to python/setup.py.in
* Fix remove vars error
* Fix p_norm_orig2prim
* merge wz
* Modify the code directory
* Add utils.py and remove get_input/output_vars functions
* Update maolin code
* Rename UT and refine test_ad_transform_primops
* Fix div_p jvp rule
* Add higher derivatives UT
* Remove UT to autograd dir
* Fix comments
* import paddle in primops.py
* Add some error message for assert
* Refine UT class name and refine some comments in primreg.py
* update minimize of paddle/optimizer for supporting new autograd
* resolve circular importing between backward.py and optimizer.py
* fill gradients and minimize unittest
* Replace `assert isinstance` with `raise TypeError`
* Add some assert message for primx.py
* Polish variable name
* Add some assert message
* add some docstring
* refine some name
* update the format of English documents
* Split test_transform.py into two files to avoid CI error
* fix the document format of enable_prim/disable_prim/prim2orig/prim_enabled
* polish test_gradients_and_minimize
* add default value for prim_enabled API doc
* Remove some UT to avoid Windows CI error
* Enlarge test_gradients_and_minimize limit time
* Fix UT limit time

Co-authored-by: veyron95
Co-authored-by: Jiabin Yang <360788950@qq.com>
Co-authored-by: levi131
Co-authored-by: Tongxin Bai
Co-authored-by: Xiaoxu Chen
Co-authored-by: levi131 <83750468+levi131@users.noreply.github.com>
---
 python/paddle/autograd/primreg.py             |  54 --
 python/paddle/fluid/backward.py               |   6 +
 .../tests/unittests/autograd/CMakeLists.txt   |   1 +
 .../autograd/test_gradients_and_minimize.py   | 143 ++++
 .../autograd/test_jvp_and_transpose.py        | 696 +++++++++++++++++
 .../unittests/autograd/test_orig2prim.py      | 360 +++++++++
 .../unittests/autograd/test_prim2orig.py      | 381 +++++++++
 .../unittests/{ => autograd}/test_primops.py  |   5 +-
 .../unittests/autograd/test_transform.py      | 313 ++++++++
 python/paddle/incubate/autograd/__init__.py   |  11 +-
 .../paddle/{ => incubate}/autograd/primops.py |  60 +-
 python/paddle/incubate/autograd/primreg.py    | 289 +++++++
 python/paddle/incubate/autograd/primrules.py  | 724 ++++++++++++++++++
 python/paddle/incubate/autograd/primx.py      | 611 +++++++++++++++
 python/paddle/incubate/autograd/utils.py      | 178 +++++
 python/paddle/optimizer/optimizer.py          |  48 +-
 python/setup.py.in                            |   1 +
 17 files changed, 3803 insertions(+), 78 deletions(-)
 delete mode 100644 python/paddle/autograd/primreg.py
 create mode 100644 python/paddle/fluid/tests/unittests/autograd/test_gradients_and_minimize.py
 create mode 100644 python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py
 create mode 100644 python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
 create mode 100644 python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
 rename python/paddle/fluid/tests/unittests/{ => autograd}/test_primops.py (95%)
 create mode 100644 python/paddle/fluid/tests/unittests/autograd/test_transform.py
 rename python/paddle/{ => incubate}/autograd/primops.py (77%)
 create mode 100644 python/paddle/incubate/autograd/primreg.py
 create mode 100644 python/paddle/incubate/autograd/primrules.py
 create mode 100644 python/paddle/incubate/autograd/primx.py
 create mode 100644 python/paddle/incubate/autograd/utils.py

diff --git a/python/paddle/autograd/primreg.py b/python/paddle/autograd/primreg.py
deleted file mode 100644
index cffb4bc050..0000000000
--- a/python/paddle/autograd/primreg.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import functools
-
-
-class Registry(object):
-    """ A general registry object. """
-    __slots__ = ['name', 'tab']
-
-    def __init__(self, name):
-        self.name = name
-        self.tab = {}
-
-    def register(self, name, value):
-        assert name not in self.tab
-        self.tab[name] = value
-
-    def lookup(self, name):
-        assert name in self.tab, f'No registry entry is found with name: {name}'
-        return self.tab[name]
-
-
-_primop_fn = Registry('primop_fn')
-_orig2prim = Registry('orig2prim')
-_prim2orig = Registry('prim2orig')
-_primop_jvp = Registry('primop_jvp')
-_primop_transpose = Registry('primop_transpose')
-_primop_position_argnames = Registry('primop_position_argnames')
-
-
-def REGISTER_FN(op_type, *position_argnames):
-    """Decorator for registering the Python function for a primitive op."""
-
-    assert isinstance(op_type, str)
-
-    _primop_position_argnames.register(op_type, position_argnames)
-
-    def wrapper(f):
-        _primop_fn.register(op_type, f)
-        return f
-
-    return wrapper
diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py
index bc53c13028..145ecc83cf 100755
--- a/python/paddle/fluid/backward.py
+++ b/python/paddle/fluid/backward.py
@@ -32,6 +32,7 @@ try:
     from collections.abc import Sequence
 except:
     from collections import Sequence
+
 __all__ = [
     'append_backward',
     'gradients',
@@ -2113,6 +2114,11 @@ def gradients(targets, inputs, target_gradients=None, no_grad_set=None):
     check_type(target_gradients, 'target_gradients', (
         framework.Variable, list, tuple, type(None)), 'paddle.static.gradients')
 
+    from ..incubate.autograd.primx import _gradients
+    from ..incubate.autograd.utils import prim_enabled
+    if prim_enabled():
+        return _gradients(targets, inputs, target_gradients)
+
     outs = calc_gradient(targets, inputs, target_gradients, no_grad_set)
     return _as_list(outs)
 
diff --git a/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt b/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt
index 46af5509d2..37216241b8 100644
--- a/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt
@@ -8,3 +8,4 @@ endforeach(TEST_OP)
set_tests_properties(test_autograd_functional_dynamic PROPERTIES TIMEOUT 160) set_tests_properties(test_autograd_functional_static PROPERTIES TIMEOUT 160) +set_tests_properties(test_gradients_and_minimize PROPERTIES TIMEOUT 60) diff --git a/python/paddle/fluid/tests/unittests/autograd/test_gradients_and_minimize.py b/python/paddle/fluid/tests/unittests/autograd/test_gradients_and_minimize.py new file mode 100644 index 0000000000..092ddb4094 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/autograd/test_gradients_and_minimize.py @@ -0,0 +1,143 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np + +import paddle +from paddle.incubate.autograd.primx import prim2orig +from paddle.incubate.autograd.utils import enable_prim, disable_prim, prim_enabled + +paddle.enable_static() + + +class TestGradients(unittest.TestCase): + def test_third_order(self): + enable_prim() + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): + x = paddle.static.data(name='x', shape=[1], dtype='float32') + x2 = paddle.multiply(x, x) + x3 = paddle.multiply(x2, x) + x4 = paddle.multiply(x3, x) + + grad1, = paddle.static.gradients([x4], [x]) + grad2, = paddle.static.gradients([grad1], [x]) + grad3, = paddle.static.gradients([grad2], [x]) + + prim2orig(main.block(0)) + + feed = {x.name: np.array([2.]).astype('float32')} + fetch_list = [grad3.name] + result = [np.array([48.])] + + place = paddle.CPUPlace() + if paddle.device.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + exe.run(startup) + outs = exe.run(main, feed=feed, fetch_list=fetch_list) + np.allclose(outs, result) + disable_prim() + + def test_fourth_order(self): + enable_prim() + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): + x = paddle.static.data(name='x', shape=[1], dtype='float32') + x2 = paddle.multiply(x, x) + x3 = paddle.multiply(x2, x) + x4 = paddle.multiply(x3, x) + x5 = paddle.multiply(x4, x) + out = paddle.sqrt(x5 + x4) + + grad1, = paddle.static.gradients([out], [x]) + grad2, = paddle.static.gradients([grad1], [x]) + grad3, = paddle.static.gradients([grad2], [x]) + grad4, = paddle.static.gradients([grad3], [x]) + + prim2orig(main.block(0)) + + feed = {x.name: np.array([2.]).astype('float32'), } + fetch_list = [grad4.name] + # (3*(-5*x^2-16*x-16))/(16*(x+1)^3.5) + result = [np.array([-0.27263762711])] + + place = paddle.CPUPlace() + if paddle.device.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + exe.run(startup) + outs = exe.run(main, feed=feed, fetch_list=fetch_list) + np.allclose(outs, result) + disable_prim() + + +class TestMinimize(unittest.TestCase): + def model(self, x, w, bias, opt): + paddle.seed(0) + place = paddle.CPUPlace() + if paddle.device.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = 
paddle.static.Executor(place) + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): + input_x = paddle.static.data('x', x.shape, dtype=x.dtype) + input_x.stop_gradient = False + params_w = paddle.static.create_parameter( + shape=w.shape, dtype=w.dtype, is_bias=False) + params_bias = paddle.static.create_parameter( + shape=bias.shape, dtype=bias.dtype, is_bias=True) + y = paddle.tanh(paddle.matmul(input_x, params_w) + params_bias) + loss = paddle.norm(y, p=2) + opt = opt + _, grads = opt.minimize(loss) + if prim_enabled(): + prim2orig(main.block(0)) + exe.run(startup) + grads = exe.run(main, + feed={'x': x, + 'w': w, + 'bias': bias}, + fetch_list=grads) + return grads + + def test_adam(self): + x = np.random.rand(2, 20) + w = np.random.rand(20, 2) + bias = np.random.rand(2) + enable_prim() + prim_grads = self.model(x, w, bias, paddle.optimizer.Adam(0.01)) + disable_prim() + orig_grads = self.model(x, w, bias, paddle.optimizer.Adam(0.01)) + for orig, prim in zip(orig_grads, prim_grads): + np.testing.assert_allclose(orig, prim) + + def test_sgd(self): + x = np.random.rand(2, 20) + w = np.random.rand(20, 2) + bias = np.random.rand(2) + enable_prim() + prim_grads = self.model(x, w, bias, paddle.optimizer.SGD(0.01)) + disable_prim() + orig_grads = self.model(x, w, bias, paddle.optimizer.SGD(0.01)) + for orig, prim in zip(orig_grads, prim_grads): + np.testing.assert_allclose(orig, prim) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py b/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py new file mode 100644 index 0000000000..d6ff931a93 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py @@ -0,0 +1,696 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import paddle +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.layers.utils import flatten +from paddle.incubate.autograd.primrules import _orig2prim, _prim2orig, _jvp, _transpose + +paddle.enable_static() + + +############################ Test linearize rules ############################ +class TestAddPJVPAndTranspose(unittest.TestCase): + def setUp(self): + self.main_program = paddle.static.Program() + self.startup_program = paddle.static.Program() + self.layer_help = LayerHelper('TestPrim2Orig') + + with paddle.static.program_guard(self.main_program, + self.startup_program): + self.init_data() + + def init_data(self): + # Set prim op + self.op_type = 'add_p' + X = paddle.static.data(name='X', shape=[2, 2], dtype='float') + Y = paddle.static.data(name='Y', shape=[2, 2], dtype='float') + self.prim_input = {'X': X, 'Y': Y} + self.prim_output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {} + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[2, 2], dtype='float') + Y_DOT = paddle.static.data(name='Y_DOT', shape=[2, 2], dtype='float') + self.jvp_args = (X_DOT, Y_DOT) + self.jvp_out_shape_map = {0: self.prim_output['Z']} + + # Set transpose + check_dot = lambda v: True + Z_BAR = paddle.static.data(name='Z_BAR', shape=[2, 2], dtype='float') + self.transpose_args = (check_dot, Z_BAR) + self.transpose_out_shape_map = {0: X, 1: Y} + + self.all_ops = [ + # prim op: + 'add_p', + # jvp op: + 'add_p', + # transpose op: + ] + + def test_op(self): + with paddle.static.program_guard(self.main_program, + self.startup_program): + op = self.layer_help.append_op( + type=self.op_type, + inputs=self.prim_input, + outputs=self.prim_output, + attrs=self.prim_attrs) + + jvp_out = _jvp(op, *self.jvp_args) + jvp_out = flatten(jvp_out) + for k, v in self.jvp_out_shape_map.items(): + self.assertEqual(jvp_out[k].shape, v.shape) + + # Some prim ops dont have transpose rule + if hasattr(self, 'transpose_args'): + transpose_out = _transpose(op, *self.transpose_args) + transpose_out = flatten(transpose_out) + for k, v in self.transpose_out_shape_map.items(): + self.assertEqual(transpose_out[k].shape, v.shape) + + all_ops = [op.type for op in self.main_program.block(0).ops] + self.assertEqual(sorted(all_ops), sorted(self.all_ops)) + + +class TestSubPJVPAndTranspose(TestAddPJVPAndTranspose): + def init_data(self): + # Set prim op + self.op_type = 'sub_p' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + Y = paddle.static.data(name='Y', shape=[5, 6], dtype='int64') + self.prim_input = {'X': X, 'Y': Y} + self.prim_output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {} + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') + Y_DOT = paddle.static.data(name='Y_DOT', shape=[5, 6], dtype='int64') + self.jvp_args = (X_DOT, Y_DOT) + self.jvp_out_shape_map = {0: self.prim_output['Z']} + + # Set transpose + check_dot = lambda v: True + Z_BAR = paddle.static.data(name='Z_BAR', shape=[5, 6], dtype='int64') + self.transpose_args = (check_dot, Z_BAR) + self.transpose_out_shape_map = {0: X, 1: Y} + + self.all_ops = [ + # prim op: + 'sub_p', + # jvp op: + 'sub_p', + # transpose op: + 'fill_constant_p', + 'sub_p' + ] + + +class TestMulPJVPAndTranspose(TestAddPJVPAndTranspose): + def init_data(self): + # Set prim op + self.op_type = 'mul_p' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + Y = 
paddle.static.data(name='Y', shape=[5, 6], dtype='int64') + self.prim_input = {'X': X, 'Y': Y} + self.prim_output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {} + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') + Y_DOT = paddle.static.data(name='Y_DOT', shape=[5, 6], dtype='int64') + self.jvp_args = (X_DOT, Y_DOT) + self.jvp_out_shape_map = {0: self.prim_output['Z']} + + # Set transpose + check_dot = lambda v: v is X + Z_BAR = paddle.static.data(name='Z_BAR', shape=[5, 6], dtype='int64') + self.transpose_args = (check_dot, Z_BAR) + self.transpose_out_shape_map = {0: X, } + + self.all_ops = [ + # prim op: + 'mul_p', + # jvp op: + 'mul_p', + 'mul_p', + 'add_p', + # transpose op: + 'mul_p' + ] + + +class TestDivPJVPAndTranspose(TestAddPJVPAndTranspose): + def init_data(self): + # Set prim op + self.op_type = 'div_p' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + Y = paddle.static.data(name='Y', shape=[5, 6], dtype='int64') + self.prim_input = {'X': X, 'Y': Y} + self.prim_output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {} + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') + Y_DOT = paddle.static.data(name='Y_DOT', shape=[5, 6], dtype='int64') + self.jvp_args = (X_DOT, Y_DOT) + self.jvp_out_shape_map = {0: self.prim_output['Z']} + + # Set transpose + check_dot = lambda v: v is X + Z_BAR = paddle.static.data(name='Z_BAR', shape=[5, 6], dtype='int64') + self.transpose_args = (check_dot, Z_BAR) + self.transpose_out_shape_map = {0: X, } + + self.all_ops = [ + # prim op: + 'div_p', + # jvp op: + 'div_p', + 'div_p', + 'mul_p', + 'mul_p', + 'sub_p', + # transpose op: + 'div_p' + ] + + +class TestSqrtPJVPAndTranspose(TestAddPJVPAndTranspose): + def init_data(self): + # Set prim op + self.op_type = 'sqrt_p' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + self.prim_input = {'X': X, } + self.prim_output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {} + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') + self.jvp_args = (X_DOT, ) + self.jvp_out_shape_map = {0: self.prim_output['Y']} + + self.all_ops = [ + # prim op: + 'sqrt_p', + # jvp op: + 'div_p', + 'mul_p', + 'fill_constant_p', + # 'sqrt_p', + # transpose op: + ] + + +class TestTanhPJVPAndTranspose(TestAddPJVPAndTranspose): + def init_data(self): + # Set prim op + self.op_type = 'tanh_p' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + self.prim_input = {'X': X, } + self.prim_output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {} + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') + self.jvp_args = (X_DOT, ) + self.jvp_out_shape_map = {0: self.prim_output['Y']} + + self.all_ops = [ + # prim op: + 'tanh_p', + # jvp op: + 'mul_p', + 'sub_p', + 'fill_constant_p', + 'mul_p', + # transpose op: + ] + + +class TestReshapePJVPAndTranspose(TestAddPJVPAndTranspose): + def init_data(self): + # Set prim op + self.op_type = 'reshape_p' + X = paddle.static.data(name='X', shape=[8, 8], dtype='int64') + self.prim_input = {'X': X, } + self.prim_output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {'shape': [2, 32]} + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[8, 8], 
dtype='int64') + self.jvp_args = (X_DOT, ) + self.jvp_out_shape_map = {0: self.prim_output['Y']} + + # Set transpose + check_dot = lambda v: v is X + Y_BAR = paddle.static.data(name='Y_BAR', shape=[2, 32], dtype='int64') + self.transpose_args = (check_dot, Y_BAR) + self.transpose_out_shape_map = {0: X, } + + self.all_ops = [ + # prim op: + 'reshape_p', + # jvp op: + 'reshape_p', + # transpose op: + 'reshape_p', + ] + + +class TestBroadcastPJVPAndTranspose(TestAddPJVPAndTranspose): + def init_data(self): + # Set prim op + self.op_type = 'broadcast_p' + X = paddle.static.data(name='X', shape=[10, 1], dtype='int64') + self.prim_input = {'X': X, } + self.prim_output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {'shape': [2, 10, 7]} + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[10, 7], dtype='int64') + self.jvp_args = (X_DOT, ) + self.jvp_out_shape_map = {0: self.prim_output['Y']} + + # Set transpose + check_dot = lambda v: v is X + Y_BAR = paddle.static.data( + name='Y_BAR', shape=[2, 10, 7], dtype='int64') + self.transpose_args = (check_dot, Y_BAR) + self.transpose_out_shape_map = {0: X, } + + self.all_ops = [ + # prim op: + 'broadcast_p', + # jvp op: + 'broadcast_p', + # transpose op: + 'reduce_p', + 'reshape_p' + ] + + +class TestTransposePJVPAndTranspose(TestAddPJVPAndTranspose): + def init_data(self): + # Set prim op + self.op_type = 'transpose_p' + X = paddle.static.data(name='X', shape=[2, 3, 4, 5], dtype='int64') + self.prim_input = {'X': X, } + self.prim_output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {'axis': [0, 2, 3, 1]} + + # Set JVP + X_DOT = paddle.static.data( + name='X_DOT', shape=[2, 3, 4, 5], dtype='int64') + self.jvp_args = (X_DOT, ) + self.jvp_out_shape_map = {0: self.prim_output['Y']} + + # Set transpose + check_dot = lambda v: v is X + Y_BAR = paddle.static.data( + name='Y_BAR', shape=[2, 4, 5, 3], dtype='int64') + self.transpose_args = (check_dot, Y_BAR) + self.transpose_out_shape_map = {0: X, } + + self.all_ops = [ + # prim op: + 'transpose_p', + # jvp op: + 'transpose_p', + # transpose op: + 'transpose_p', + ] + + +class TestSplitPJVPAndTranspose(TestAddPJVPAndTranspose): + def init_data(self): + # Set prim op + self.op_type = 'split_p' + X = paddle.static.data(name='X', shape=[2, 7, 10], dtype='int64') + self.prim_input = {'X': X, } + self.prim_output = { + 'YS': [ + self.layer_help.create_variable_for_type_inference( + dtype=X.dtype) for i in range(4) + ] + } + self.prim_attrs = {'num_or_sections': [2, 3, 4, 1], 'axis': 2} + + # Set JVP + X_DOT = paddle.static.data( + name='X_DOT', shape=[2, 7, 10], dtype='int64') + self.jvp_args = (X_DOT, ) + self.jvp_out_shape_map = { + 0: self.prim_output['YS'][0], + 1: self.prim_output['YS'][1], + 2: self.prim_output['YS'][2], + 3: self.prim_output['YS'][3], + } + + # Set transpose + check_dot = lambda v: v is X + YS_BAR = [ + paddle.static.data( + name='Y_BAR1', shape=[2, 7, 2], dtype='int64'), + paddle.static.data( + name='Y_BAR2', shape=[2, 7, 3], dtype='int64'), + paddle.static.data( + name='Y_BAR3', shape=[2, 7, 4], dtype='int64'), + paddle.static.data( + name='Y_BAR4', shape=[2, 7, 1], dtype='int64'), + ] + self.transpose_args = (check_dot, YS_BAR) + self.transpose_out_shape_map = {0: X, } + + self.all_ops = [ + # prim op: + 'split_p', + # jvp op: + 'split_p', + # transpose op: + 'concat_p', + ] + + +class TestConcatPJVPAndTranspose(TestAddPJVPAndTranspose): + def init_data(self): + # 
Set prim op + self.op_type = 'concat_p' + X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64') + Y = paddle.static.data(name='Y', shape=[3, 2, 5], dtype='float64') + Z = paddle.static.data(name='Z', shape=[3, 3, 5], dtype='float64') + self.prim_input = {'XS': [X, Y, Z], } + self.prim_output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {'axis': 1} + + # Set JVP + XS_DOT = [ + paddle.static.data( + name='X_DOT1', shape=[3, 9, 5], dtype='float64'), + paddle.static.data( + name='X_DOT2', shape=[3, 2, 5], dtype='float64'), + paddle.static.data( + name='X_DOT3', shape=[3, 3, 5], dtype='float64'), + ] + self.jvp_args = (XS_DOT, ) + self.jvp_out_shape_map = {0: self.prim_output['Y']} + + # Set transpose + check_dot = lambda v: v is X or v is Y or v is Z + Y_BAR = paddle.static.data( + name='Y_BAR', shape=[3, 14, 5], dtype='float64') + self.transpose_args = (check_dot, Y_BAR) + self.transpose_out_shape_map = { + 0: X, + 1: Y, + 2: Z, + } + + self.all_ops = [ + # prim op: + 'concat_p', + # jvp op: + 'concat_p', + # transpose op: + 'split_p', + ] + + +class TestReducePJVPAndTranspose(TestAddPJVPAndTranspose): + def init_data(self): + # Set prim op + self.op_type = 'reduce_p' + X = paddle.static.data(name='X', shape=[2, 3, 4, 5], dtype='float64') + self.prim_input = {'X': X} + self.prim_output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {'axis': [2], 'keepdim': False} + + # Set JVP + X_DOT = paddle.static.data( + name='X_DOT1', shape=[2, 3, 4, 5], dtype='float64') + self.jvp_args = (X_DOT, ) + self.jvp_out_shape_map = {0: self.prim_output['Y']} + + # Set transpose + check_dot = lambda v: v is X + Y_BAR = paddle.static.data( + name='Y_BAR', shape=[2, 3, 5], dtype='float64') + self.transpose_args = (check_dot, Y_BAR) + self.transpose_out_shape_map = {0: X, } + + self.all_ops = [ + # prim op: + 'reduce_p', + # jvp op: + 'reduce_p', + # transpose op: + 'reshape_p', + 'broadcast_p', + ] + + +class TestMatmulPJVPAndTranspose(TestAddPJVPAndTranspose): + def init_data(self): + # Set prim op + self.op_type = 'matmul_p' + X = paddle.static.data(name='X', shape=[2, 3], dtype='float64') + Y = paddle.static.data(name='Y', shape=[3, 4], dtype='float64') + self.prim_input = {'X': X, 'Y': Y} + self.prim_output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {} + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[2, 3], dtype='float64') + Y_DOT = paddle.static.data(name='Y_DOT', shape=[3, 4], dtype='float64') + self.jvp_args = (X_DOT, Y_DOT) + self.jvp_out_shape_map = {0: self.prim_output['Z']} + + # Set transpose + check_dot = lambda v: v is X + Z_BAR = paddle.static.data(name='Z_BAR', shape=[2, 4], dtype='float64') + self.transpose_args = (check_dot, Z_BAR) + self.transpose_out_shape_map = {0: X, } + + self.all_ops = [ + # prim op: + 'matmul_p', + # jvp op: + 'matmul_p', + 'matmul_p', + 'add_p', + # transpose op: + 'matmul_p', + 'transpose_p', + ] + + +class TestSliceSelectPJVPAndTranspose(TestAddPJVPAndTranspose): + def init_data(self): + # Set prim op + self.op_type = 'slice_select_p' + X = paddle.static.data(name='X', shape=[3, 20], dtype='float64') + self.prim_input = {'X': X, } + self.prim_output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = { + 'axis': [1], + 'starts': [0], + 'ends': [20], + 'strides': [2] + } + + # Set JVP + X_DOT = 
paddle.static.data(name='X_DOT', shape=[3, 20], dtype='float64') + self.jvp_args = (X_DOT, ) + self.jvp_out_shape_map = {0: self.prim_output['Y']} + + # Set transpose + check_dot = lambda v: v is X + Y_BAR = paddle.static.data(name='Y_BAR', shape=[3, 10], dtype='float64') + self.transpose_args = (check_dot, Y_BAR) + self.transpose_out_shape_map = {0: X, } + + self.all_ops = [ + # prim op: + 'slice_select_p', + # jvp op: + 'slice_select_p', + # transpose op: + 'slice_assign_p', + 'fill_constant_p', + ] + + +class TestSliceAssignPJVPAndTranspose(TestAddPJVPAndTranspose): + def init_data(self): + # Set prim op + self.op_type = 'slice_assign_p' + X = paddle.static.data(name='X', shape=[3, 20], dtype='float64') + Y = paddle.static.data(name='Y', shape=[3, 5], dtype='float64') + self.prim_input = {'X': X, 'Y': Y} + self.prim_output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = { + 'axis': [1], + 'starts': [0], + 'ends': [10], + 'strides': [2] + } + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[3, 20], dtype='float64') + Y_DOT = paddle.static.data(name='Y_DOT', shape=[3, 5], dtype='float64') + self.jvp_args = (X_DOT, Y_DOT) + self.jvp_out_shape_map = {0: self.prim_output['Z']} + + # Set transpose + check_dot = lambda v: v is X or v is Y + Z_BAR = paddle.static.data(name='Z_BAR', shape=[3, 20], dtype='float64') + self.transpose_args = (check_dot, Z_BAR) + self.transpose_out_shape_map = {0: X, 1: Y} + + self.all_ops = [ + # prim op: + 'slice_assign_p', + # jvp op: + 'slice_assign_p', + # transpose op: + 'slice_assign_p', + 'slice_select_p', + 'fill_constant_p' + ] + + +class TestGatherPJVPAndTranspose(TestAddPJVPAndTranspose): + def init_data(self): + # Set prim op + self.op_type = 'gather_p' + X = paddle.static.data(name='X', shape=[9, 5], dtype='float64') + IndexTensor = paddle.static.data( + name='IndexTensor', shape=[3], dtype='int32') + self.prim_input = {'X': X, 'IndexTensor': IndexTensor} + self.prim_output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {'axis': 1} + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[9, 5], dtype='float64') + self.jvp_args = ( + X_DOT, + IndexTensor, ) + self.jvp_out_shape_map = {0: self.prim_output['Y']} + + # Set transpose + check_dot = lambda v: v is X + Y_BAR = paddle.static.data(name='Y_BAR', shape=[9, 3], dtype='float64') + self.transpose_args = (check_dot, Y_BAR) + self.transpose_out_shape_map = {0: X, } + + self.all_ops = [ + # prim op: + 'gather_p', + # jvp op: + 'gather_p', + # transpose op: + 'scatter_add_p', + 'fill_constant_p', + ] + + +class TestScatterAddPJVPAndTranspose(TestAddPJVPAndTranspose): + def init_data(self): + # Set prim op + self.op_type = 'scatter_add_p' + X = paddle.static.data(name='X', shape=[9, 5], dtype='float64') + Y = paddle.static.data(name='Y', shape=[9, 3], dtype='float64') + IndexTensor = paddle.static.data( + name='IndexTensor', shape=[3], dtype='int32') + self.prim_input = {'X': X, 'Y': Y, 'IndexTensor': IndexTensor} + self.prim_output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {'axis': 1} + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[9, 5], dtype='float64') + Y_DOT = paddle.static.data(name='Y_DOT', shape=[9, 3], dtype='float64') + self.jvp_args = (X_DOT, Y_DOT) + self.jvp_out_shape_map = {0: self.prim_output['Z']} + + # Set transpose + check_dot = lambda v: v is X or v is Y + Z_BAR = 
paddle.static.data(name='Z_BAR', shape=[9, 5], dtype='float64') + self.transpose_args = (check_dot, Z_BAR) + self.transpose_out_shape_map = {0: X, 1: Y} + + self.all_ops = [ + # prim op: + 'scatter_add_p', + # jvp op: + 'scatter_add_p', + # transpose op: + 'scatter_add_p', + 'fill_constant_p', + 'gather_p' + ] + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py new file mode 100644 index 0000000000..24c8febccf --- /dev/null +++ b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py @@ -0,0 +1,360 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.layers.utils import flatten +from paddle.incubate.autograd.primrules import _orig2prim, _prim2orig, _jvp, _transpose + +paddle.enable_static() + + +############################ Test orig2prim rules ############################ +class TestElementWiseAddOrig2Prim(unittest.TestCase): + def setUp(self): + self.main_program = paddle.static.Program() + self.startup_program = paddle.static.Program() + self.layer_help = LayerHelper('TestOrig2Prim') + + with paddle.static.program_guard(self.main_program, + self.startup_program): + self.init_data() + + def init_data(self): + self.op_type = 'elementwise_add' + X = paddle.static.data(name='X', shape=[2, 2], dtype='float') + Y = paddle.static.data(name='Y', shape=[2, 2], dtype='float') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.orig2prim_args = (X, Y) + self.all_ops = ['elementwise_add', 'add_p'] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + def test_op(self): + with paddle.static.program_guard(self.main_program, + self.startup_program): + op = self.layer_help.append_op( + type=self.op_type, + inputs=self.input, + outputs=self.output, + attrs=self.attrs) + + prim_out = _orig2prim(op, *self.orig2prim_args) + all_ops = [op.type for op in self.main_program.block(0).ops] + + self.assertEqual(sorted(all_ops), sorted(self.all_ops)) + prim_out = flatten(prim_out) + for k, v in self.out_map.items(): + self.assertEqual(prim_out[k].shape, v.shape) + + +class TestSqrtOrig2Prim(TestElementWiseAddOrig2Prim): + def init_data(self): + self.op_type = 'sqrt' + X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') + + self.input = {'X': X, } + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.orig2prim_args = (X, ) + self.all_ops = ['sqrt', 'sqrt_p'] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + +class TestElementWiseMulOrig2Prim(TestElementWiseAddOrig2Prim): + def init_data(self): + self.op_type = 'elementwise_mul' + X = 
paddle.static.data(name='X', shape=[8, 8], dtype='float') + Y = paddle.static.data(name='Y', shape=[8, 8], dtype='float') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.orig2prim_args = (X, Y) + self.all_ops = ['elementwise_mul', 'mul_p'] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + +class TestMatmulV2Orig2Prim(TestElementWiseAddOrig2Prim): + def init_data(self): + self.op_type = 'matmul_v2' + X = paddle.static.data(name='X', shape=[3, 4], dtype='float') + Y = paddle.static.data(name='Y', shape=[4, 3], dtype='float') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'trans_x': True, 'trans_y': True} + + self.orig2prim_args = (X, Y) + self.all_ops = ['matmul_v2', 'transpose_p', 'transpose_p', 'matmul_p'] + self.out_map = {0: self.output['Out']} + + +class TestTanhOrig2Prim(TestElementWiseAddOrig2Prim): + def init_data(self): + self.op_type = 'tanh' + X = paddle.static.data(name='X', shape=[3, 4], dtype='float') + + self.input = {'X': X, } + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.orig2prim_args = (X, ) + self.all_ops = ['tanh', 'tanh_p'] + self.out_map = {0: self.output['Out']} + + +class TestReshape2Orig2Prim(TestElementWiseAddOrig2Prim): + def init_data(self): + self.op_type = 'reshape2' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + + self.input = {'X': X, } + self.output = { + 'Out': X, + 'XShape': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'shape': [6, 5]} + + self.orig2prim_args = ( + None, + None, + X, ) + self.all_ops = ['reshape2', 'reshape_p', 'fill_constant_p'] + # Do not checke XShape + self.out_map = {0: self.output['Out']} + + +class TestConcatOrig2Prim(TestElementWiseAddOrig2Prim): + def init_data(self): + self.op_type = 'concat' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + Y = paddle.static.data(name='Y', shape=[3, 6], dtype='int64') + + self.input = {'X': [X, Y], } + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'axis': 0} + + self.orig2prim_args = ( + None, + (X, Y), ) + self.all_ops = ['concat', 'concat_p'] + self.out_map = {0: self.output['Out']} + + +class TestSliceOrig2Prim(TestElementWiseAddOrig2Prim): + def init_data(self): + self.op_type = 'slice' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + + self.input = {'Input': X, } + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = { + 'axes': [0], + 'starts': [1], + 'ends': [4], + } + + self.orig2prim_args = (None, None, X, None, None) + self.all_ops = ['slice', 'slice_select_p'] + self.out_map = {0: self.output['Out']} + + +class TestFillZerosLikeOrig2Prim(TestElementWiseAddOrig2Prim): + def init_data(self): + self.op_type = 'fill_zeros_like' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + + self.input = {'X': X, } + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.orig2prim_args = (X, ) + self.all_ops = ['fill_zeros_like', 'fill_constant_p'] + self.out_map = {0: self.output['Out']} + + +class TestSumOrig2Prim(TestElementWiseAddOrig2Prim): + def init_data(self): + 
self.op_type = 'sum' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + Y = paddle.static.data(name='Y', shape=[5, 6], dtype='int64') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.orig2prim_args = ((X, Y), ) + self.all_ops = ['sum', 'add_p'] + self.out_map = {0: self.output['Out']} + + +class TestPNormOrig2Prim1(TestElementWiseAddOrig2Prim): + def init_data(self): + self.op_type = 'p_norm' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + + self.input = {'X': X, } + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = { + 'porder': 1, + 'asvector': True, + } + + self.orig2prim_args = (X, ) + self.all_ops = ['p_norm', 'reshape_p', 'sqrt_p', 'reduce_p', 'mul_p'] + self.out_map = {0: self.output['Out']} + + +class TestPNormOrig2Prim2(TestElementWiseAddOrig2Prim): + def init_data(self): + self.op_type = 'p_norm' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + + self.input = {'X': X, } + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = { + 'porder': 2, + 'asvector': True, + } + + self.orig2prim_args = (X, ) + self.all_ops = ['p_norm', 'reshape_p', 'sqrt_p', 'reduce_p', 'mul_p'] + self.out_map = {0: self.output['Out']} + + +class TestIndexSelectOrig2Prim(TestElementWiseAddOrig2Prim): + def init_data(self): + self.op_type = 'index_select' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + Index = paddle.static.data(name='Index', shape=[2], dtype='int32') + + self.input = {'X': X, 'Index': Index} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'dim': 0, } + + self.orig2prim_args = ( + Index, + X, ) + self.all_ops = ['index_select', 'gather_p'] + self.out_map = {0: self.output['Out']} + + +class TestElementwiseSubOrig2Prim(TestElementWiseAddOrig2Prim): + def init_data(self): + self.op_type = 'elementwise_sub' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int32') + Y = paddle.static.data(name='Y', shape=[6], dtype='int32') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'dim': 0, } + + self.orig2prim_args = ( + X, + Y, ) + self.all_ops = ['elementwise_sub', 'broadcast_p', 'sub_p'] + self.out_map = {0: self.output['Out']} + + +class TestScaleOrig2Prim(TestElementWiseAddOrig2Prim): + def init_data(self): + self.op_type = 'scale' + X = paddle.static.data(name='X', shape=[10, 7], dtype='int32') + + self.input = {'X': X, } + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'scale': 2.0, 'bias': 1.0, 'bias_after_scale': True} + + self.orig2prim_args = ( + None, + X, ) + self.all_ops = [ + 'scale', 'fill_constant_p', 'fill_constant_p', 'mul_p', 'add_p' + ] + self.out_map = {0: self.output['Out']} + + +class TestAssignOrig2Prim(TestElementWiseAddOrig2Prim): + def init_data(self): + self.op_type = 'assign' + X = paddle.static.data(name='X', shape=[10, 7], dtype='int32') + + self.input = {'X': X, } + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.orig2prim_args = (X, ) + self.all_ops = ['assign', 'fill_constant_p', 'add_p'] + self.out_map = {0: self.output['Out']} + + +if __name__ == '__main__': + 
unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py b/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py new file mode 100644 index 0000000000..15ab016fc5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py @@ -0,0 +1,381 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.layers.utils import flatten +from paddle.incubate.autograd.primrules import _orig2prim, _prim2orig, _jvp, _transpose + +paddle.enable_static() + + +############################ Test prim2orig rules ############################ +class TestAddPPrim2Orig(unittest.TestCase): + def setUp(self): + self.main_program = paddle.static.Program() + self.startup_program = paddle.static.Program() + self.layer_help = LayerHelper('TestPrim2Orig') + + with paddle.static.program_guard(self.main_program, + self.startup_program): + self.init_data() + + def init_data(self): + self.op_type = 'add_p' + X = paddle.static.data(name='X', shape=[2, 2], dtype='float') + Y = paddle.static.data(name='Y', shape=[2, 2], dtype='float') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.prim2orig_args = (X, Y) + self.all_ops = ['add_p', 'elementwise_add'] + # { prim_op_output_var: orign_op_out_index } + self.out_map = {self.output['Z']: 0} + + def test_op(self): + with paddle.static.program_guard(self.main_program, + self.startup_program): + op = self.layer_help.append_op( + type=self.op_type, + inputs=self.input, + outputs=self.output, + attrs=self.attrs) + + orig_out = _prim2orig(op, *self.prim2orig_args) + all_ops = [op.type for op in self.main_program.block(0).ops] + self.assertEqual(sorted(all_ops), sorted(self.all_ops)) + orig_out = flatten(orig_out) + for k, v in self.out_map.items(): + self.assertEqual(k.shape, orig_out[v].shape) + + +class TestSubPPrim2Orig(TestAddPPrim2Orig): + def init_data(self): + self.op_type = 'sub_p' + X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') + Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.prim2orig_args = (X, Y) + self.all_ops = ['sub_p', 'elementwise_sub'] + self.out_map = {self.output['Z']: 0} + + +class TestMulPPrim2Orig(TestAddPPrim2Orig): + def init_data(self): + self.op_type = 'mul_p' + X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') + Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.prim2orig_args = (X, Y) + self.all_ops = ['mul_p', 'elementwise_mul'] + self.out_map = 
{self.output['Z']: 0} + + +class TestDivPPrim2Orig(TestAddPPrim2Orig): + def init_data(self): + self.op_type = 'div_p' + X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') + Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.prim2orig_args = (X, Y) + self.all_ops = ['div_p', 'elementwise_div'] + self.out_map = {self.output['Z']: 0} + + +class TestSqrtPPrim2Orig(TestAddPPrim2Orig): + def init_data(self): + self.op_type = 'sqrt_p' + X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') + + self.input = {'X': X, } + self.output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.prim2orig_args = (X, ) + self.all_ops = ['sqrt_p', 'sqrt'] + self.out_map = {self.output['Y']: 0} + + +class TestTanhPPrim2Orig(TestAddPPrim2Orig): + def init_data(self): + self.op_type = 'tanh_p' + X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') + + self.input = {'X': X, } + self.output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.prim2orig_args = (X, ) + self.all_ops = ['tanh_p', 'tanh'] + self.out_map = {self.output['Y']: 0} + + +class TestReshapePPrim2Orig(TestAddPPrim2Orig): + def init_data(self): + self.op_type = 'reshape_p' + X = paddle.static.data(name='X', shape=[2, 8], dtype='float64') + + self.input = {'X': X, } + self.output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'shape': [4, 4]} + + self.prim2orig_args = (X, ) + self.all_ops = ['reshape_p', 'reshape2'] + self.out_map = {self.output['Y']: 0} + + +class TestBroadcastPPrim2Orig(TestAddPPrim2Orig): + def init_data(self): + self.op_type = 'broadcast_p' + X = paddle.static.data(name='X', shape=[2, 8], dtype='float64') + + self.input = {'X': X, } + self.output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'shape': [10, 2, 8]} + + self.prim2orig_args = (X, ) + self.all_ops = ['broadcast_p', 'expand_v2'] + self.out_map = {self.output['Y']: 0} + + +class TestTransposePPrim2Orig(TestAddPPrim2Orig): + def init_data(self): + self.op_type = 'transpose_p' + X = paddle.static.data(name='X', shape=[7, 8, 9, 10], dtype='float64') + + self.input = {'X': X, } + self.output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'axis': [1, 2, 0, 3]} + + self.prim2orig_args = (X, ) + self.all_ops = ['transpose_p', 'transpose2'] + self.out_map = {self.output['Y']: 0} + + +class TestSplitPPrim2Orig(TestAddPPrim2Orig): + def init_data(self): + self.op_type = 'split_p' + X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64') + + self.input = {'X': X, } + self.output = { + 'YS': [ + self.layer_help.create_variable_for_type_inference( + dtype=X.dtype) for i in range(3) + ] + } + self.attrs = {'num_or_sections': [2, 3, 4], 'axis': 1} + + self.prim2orig_args = (X, ) + self.all_ops = ['split_p', 'split'] + self.out_map = { + self.output['YS'][0]: 0, + self.output['YS'][1]: 1, + self.output['YS'][2]: 2, + } + + +class TestConcatPPrim2Orig(TestAddPPrim2Orig): + def init_data(self): + self.op_type = 'concat_p' + X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64') + Y = paddle.static.data(name='Y', shape=[2, 9, 5], dtype='float64') + Z = paddle.static.data(name='Z', shape=[1, 9, 5], 
dtype='float64') + + self.input = {'XS': [X, Y, Z], } + self.output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'axis': 0} + + self.prim2orig_args = ((X, Y, Z), ) + self.all_ops = ['concat_p', 'concat'] + self.out_map = {self.output['Y']: 0} + + +class TestReducePPrim2Orig(TestAddPPrim2Orig): + def init_data(self): + self.op_type = 'reduce_p' + X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64') + + self.input = {'X': X} + self.output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'axis': [1], 'keepdim': True} + + self.prim2orig_args = (X, ) + self.all_ops = ['reduce_p', 'reduce_sum'] + self.out_map = {self.output['Y']: 0} + + +class TestMatmulPPrim2Orig(TestAddPPrim2Orig): + def init_data(self): + self.op_type = 'matmul_p' + X = paddle.static.data(name='X', shape=[9, 5], dtype='float64') + Y = paddle.static.data(name='Y', shape=[5, 9], dtype='float64') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.prim2orig_args = (X, Y) + self.all_ops = ['matmul_p', 'matmul_v2'] + self.out_map = {self.output['Z']: 0} + + +class TestSliceSelectPPrim2Orig(TestAddPPrim2Orig): + def init_data(self): + self.op_type = 'slice_select_p' + X = paddle.static.data(name='X', shape=[9, 5], dtype='float64') + + self.input = {'X': X, } + self.output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'axis': [0], 'starts': [1], 'ends': [8], 'strides': [2]} + + self.prim2orig_args = (X, ) + self.all_ops = ['slice_select_p', 'strided_slice'] + self.out_map = {self.output['Y']: 0} + + +class TestSliceAssignPPrim2Orig(TestAddPPrim2Orig): + def init_data(self): + self.op_type = 'slice_assign_p' + X = paddle.static.data(name='X', shape=[9, 5], dtype='float64') + Y = paddle.static.data(name='Y', shape=[9, 3], dtype='float64') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'axis': [1], 'starts': [0], 'ends': [3], 'strides': [1]} + + self.prim2orig_args = (X, Y) + self.all_ops = ['slice_assign_p', 'assign', 'set_value'] + self.out_map = {self.output['Z']: 0} + + +class TestGatherPPrim2Orig(TestAddPPrim2Orig): + def init_data(self): + self.op_type = 'gather_p' + X = paddle.static.data(name='X', shape=[9, 5], dtype='float64') + IndexTensor = paddle.static.data( + name='IndexTensor', shape=[3], dtype='int32') + + self.input = {'X': X, 'IndexTensor': IndexTensor} + self.output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'axis': 0, } + + self.prim2orig_args = ( + IndexTensor, + X, ) + self.all_ops = ['gather_p', 'gather'] + self.out_map = {self.output['Y']: 0} + + +class TestScatterAddPPrim2Orig(TestAddPPrim2Orig): + def init_data(self): + self.op_type = 'scatter_add_p' + X = paddle.static.data(name='X', shape=[9, 5], dtype='float64') + Y = paddle.static.data(name='Y', shape=[3, 5], dtype='float64') + IndexTensor = paddle.static.data( + name='IndexTensor', shape=[3], dtype='int32') + + self.input = {'X': X, 'Y': Y, 'IndexTensor': IndexTensor} + self.output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'axis': 0, } + + self.prim2orig_args = (IndexTensor, X, Y) + self.all_ops = [ + 'scatter_add_p', 'fill_any_like', 'scatter', 'elementwise_add' + ] + 
self.out_map = {self.output['Z']: 0} + + +class TestFillConstantPPrim2Orig(TestAddPPrim2Orig): + def init_data(self): + self.op_type = 'fill_constant_p' + + self.input = {} + self.output = { + 'Y': + self.layer_help.create_variable_for_type_inference(paddle.int32) + } + self.attrs = {'value': 10, 'shape': [5, 5], 'dtype': paddle.int32} + + self.prim2orig_args = () + self.all_ops = ['fill_constant_p', 'fill_constant'] + self.out_map = {self.output['Y']: 0} + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_primops.py b/python/paddle/fluid/tests/unittests/autograd/test_primops.py similarity index 95% rename from python/paddle/fluid/tests/unittests/test_primops.py rename to python/paddle/fluid/tests/unittests/autograd/test_primops.py index cbf77c2666..e6a8c4ec3f 100644 --- a/python/paddle/fluid/tests/unittests/test_primops.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_primops.py @@ -14,12 +14,13 @@ import unittest import numpy as np - import paddle -from paddle.autograd.primops import ( +from paddle.incubate.autograd.primops import ( neg, set_value, add, sub, mul, div, sqrt, tanh, reshape, broadcast, transpose, split, concat, reduce, matmul, slice_select, slice_assign, gather, scatter_add, fill_const) +from paddle.incubate.autograd.primx import Transform, topo_path, orig2prim, prim2orig, _gradients +from paddle.incubate.autograd.utils import enable_prim, disable_prim, prim_enabled class TestPyPrimOps(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/autograd/test_transform.py b/python/paddle/fluid/tests/unittests/autograd/test_transform.py new file mode 100644 index 0000000000..a2b75f5d7b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/autograd/test_transform.py @@ -0,0 +1,313 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import numpy as np + +import paddle +from paddle.incubate.autograd.primx import Transform, orig2prim, prim2orig +from paddle.fluid.layers.utils import flatten + +paddle.enable_static() + + +class TestAutoGradTransformForAdd(unittest.TestCase): + def setUp(self): + self.main_program = paddle.static.Program() + self.startup_program = paddle.static.Program() + + with paddle.static.program_guard(self.main_program, + self.startup_program): + self.init_data() + + def init_data(self): + # { input_index: input_shape } + self.xs_shape_map = {0: (20, 40), 1: (20, 40)} + # { output_index: output_shape } + self.ys_shape_map = {0: (20, 40)} + X0 = paddle.static.data( + name='X0', shape=self.xs_shape_map[0], dtype='float32') + X0.stop_gradient = False + X1 = paddle.static.data( + name='X1', shape=self.xs_shape_map[1], dtype='float32') + X1.stop_gradient = False + + A = paddle.tanh(X0) + B = paddle.tanh(X1) + Y = paddle.add(A, B) + + self.orig_xs = [X0, X1] + self.orig_ys = [Y, ] + + self.orig_ops = ['tanh', 'tanh', 'elementwise_add'] + self.orig2prim_ops = ['tanh_p', 'tanh_p', 'add_p'] + self.linearize_ops = self.orig2prim_ops + [ + # call fill_const() in linearize() function + 'fill_constant_p', + 'fill_constant_p', + # linearized op + 'mul_p', + 'sub_p', + 'fill_constant_p', + 'mul_p', + 'mul_p', + 'sub_p', + 'fill_constant_p', + 'mul_p', + 'add_p', + ] + self.transpose_ops = self.orig2prim_ops + [ + # call fill_const() in transpose() function + 'fill_constant_p', + # linearized op after remove path + 'fill_constant_p', + 'fill_constant_p', + 'mul_p', + 'sub_p', + 'fill_constant_p', + 'mul_p', + 'sub_p', + 'fill_constant_p', + # transposed op + 'mul_p', + 'mul_p' + ] + self.prim2orig_ops = [ + 'tanh', 'tanh', 'elementwise_add', 'fill_constant', 'fill_constant', + 'fill_constant', 'elementwise_mul', 'elementwise_sub', + 'fill_constant', 'elementwise_mul', 'elementwise_sub', + 'fill_constant', 'elementwise_mul', 'elementwise_mul' + ] + + def test_run(self): + # Must using with program_guard(), otherwise prim ops will append other block + with paddle.static.program_guard(self.main_program, + self.startup_program): + ad = Transform(self.main_program.block(0)) + orig_ops = [op.type for op in self.main_program.block(0).ops] + self.assertEqual(sorted(orig_ops), sorted(self.orig_ops)) + + # Test orig2prim + orig2prim(block=self.main_program.block(0)) + orig2prim_ops = [op.type for op in self.main_program.block(0).ops] + self.assertEqual(sorted(orig2prim_ops), sorted(self.orig2prim_ops)) + + # Test linearize + xs_dot, ys_dot = ad.linearize(self.orig_xs, self.orig_ys) + linearize_ops = [op.type for op in self.main_program.block(0).ops] + self.assertEqual(sorted(linearize_ops), sorted(self.linearize_ops)) + flatten_xs_dot = flatten(xs_dot) + for k, v in self.xs_shape_map.items(): + self.assertEqual(flatten_xs_dot[k].shape, v) + flatten_ys_dot = flatten(ys_dot) + for k, v in self.ys_shape_map.items(): + self.assertEqual(flatten_ys_dot[k].shape, v) + + # Test transpose + ys_bar, xs_bar = ad.transpose(ys_dot, xs_dot, retain_fwd=False) + transpose_ops = [op.type for op in self.main_program.block(0).ops] + self.assertEqual(sorted(transpose_ops), sorted(self.transpose_ops)) + flatten_xs_bar = flatten(xs_bar) + for k, v in self.xs_shape_map.items(): + # There may be None in the result of transpose like gather op + if flatten_xs_bar[k] is not None: + self.assertEqual(flatten_xs_bar[k].shape, v) + flatten_ys_bar = flatten(ys_bar) + for k, v in self.ys_shape_map.items(): + 
self.assertEqual(flatten_ys_bar[k].shape, v) + + # Test prim2orig + prim2orig(block=self.main_program.block(0)) + prim2orig_ops = [op.type for op in self.main_program.block(0).ops] + self.assertEqual(sorted(prim2orig_ops), sorted(self.prim2orig_ops)) + + +class TestAutoGradTransformForMatmul(TestAutoGradTransformForAdd): + def init_data(self): + # { input_index: input_shape } + self.xs_shape_map = {0: (100, 2), 1: (5, 2)} + # { output_index: output_shape } + self.ys_shape_map = {0: (100, 5)} + X0 = paddle.static.data( + 'X0', shape=self.xs_shape_map[0], dtype='float32') + X0.stop_gradient = False + X1 = paddle.static.data( + 'X1', shape=self.xs_shape_map[1], dtype='float32') + X1.stop_gradient = False + + A = paddle.reshape(X1, [2, 5]) + B = paddle.scale(A, scale=2.0, bias=2.0) + Y = paddle.matmul(X0, B) + + self.orig_xs = [X0, X1] + self.orig_ys = [Y, ] + + self.orig_ops = ['reshape2', 'scale', 'matmul_v2'] + self.orig2prim_ops = [ + 'reshape_p', 'fill_constant_p', 'fill_constant_p', + 'fill_constant_p', 'mul_p', 'add_p', 'matmul_p' + ] + self.linearize_ops = self.orig2prim_ops + [ + # call fill_const() in linearize() function + 'fill_constant_p', + 'fill_constant_p', + # linearized op + 'reshape_p', + 'mul_p', + # 'mul_p', # JVP rules handle `None` input, some op will not be appended + # 'add_p', + # 'add_p', + 'matmul_p', + 'matmul_p', + 'add_p' + ] + self.transpose_ops = self.orig2prim_ops + [ + # call fill_const() in transpose() function + 'fill_constant_p', + # linearized op after remove path + 'fill_constant_p', + 'fill_constant_p', + 'mul_p', + # transposed op + 'transpose_p', + 'matmul_p', + 'transpose_p', + 'matmul_p', + # 'mul_p', + 'reshape_p', + ] + + self.prim2orig_ops = [ + 'reshape2', + 'fill_constant', + 'fill_constant', + 'fill_constant', + 'elementwise_mul', + 'elementwise_add', + 'matmul_v2', + 'fill_constant', + 'fill_constant', + 'fill_constant', + 'elementwise_mul', + 'transpose2', + 'matmul_v2', + 'transpose2', + 'matmul_v2', + # 'elementwise_mul', + 'reshape2', + ] + + +class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd): + def init_data(self): + # { input_index: input_shape } + self.xs_shape_map = {0: (7, 8, 9), 1: (8, 1), 2: (7, 8, 9), 3: (3, )} + # { output_index: output_shape } + self.ys_shape_map = {0: (3, 16, 9)} + + X0 = paddle.static.data( + 'X0', shape=self.xs_shape_map[0], dtype='float32') + X0.stop_gradient = False + X1 = paddle.static.data( + 'X1', shape=self.xs_shape_map[1], dtype='float32') + X1.stop_gradient = False + X2 = paddle.static.data( + 'X2', shape=self.xs_shape_map[2], dtype='float32') + X2.stop_gradient = False + X3 = paddle.static.data('X3', shape=self.xs_shape_map[3], dtype='int32') + X3.stop_gradient = False + + A = paddle.add(X0, X1) # (7, 8, 9) + B = paddle.norm(x=A, p=2) # (1, ) + C = paddle.subtract(X2, B) # (7, 8, 9) + D = paddle.concat(x=(A, C), axis=1) # (7, 16, 9) + Y = paddle.index_select(D, X3, axis=0) # (3, 16, 9) + + self.orig_xs = [X0, X1, X2, X3] + self.orig_ys = [Y, ] + self.orig_ops = [ + 'elementwise_add', 'p_norm', 'elementwise_sub', 'concat', + 'index_select' + ] + self.orig2prim_ops = [ + 'broadcast_p', 'add_p', 'reshape_p', 'mul_p', 'reduce_p', 'sqrt_p', + 'broadcast_p', 'sub_p', 'concat_p', 'gather_p' + ] + self.linearize_ops = self.orig2prim_ops + [ + # call fill_const() in linearize() function + 'fill_constant_p', + 'fill_constant_p', + 'fill_constant_p', + 'fill_constant_p', + # linearized op + 'broadcast_p', + 'add_p', + 'reshape_p', + 'mul_p', + 'mul_p', + 'add_p', + 'reduce_p', + 
'fill_constant_p', # 'sqrt_p', Will not append sqrt_p op when apply JVP for sqrt_p + 'mul_p', + 'div_p', + 'broadcast_p', + 'sub_p', + 'concat_p', + 'gather_p' + ] + self.transpose_ops = self.orig2prim_ops + [ + # call fill_const() in transpose() function + 'fill_constant_p', + # linearized op after remove path + 'fill_constant_p', + 'fill_constant_p', + 'fill_constant_p', + 'fill_constant_p', + 'fill_constant_p', + 'mul_p', + # transposed op + 'reduce_p', + 'reshape_p', + 'reshape_p', + 'mul_p', + 'mul_p', + 'reshape_p', + 'broadcast_p', + 'div_p', + 'reduce_p', + 'reshape_p', + 'fill_constant_p', + 'sub_p', + 'split_p', + 'fill_constant_p', + 'scatter_add_p', + 'add_p', # The output of the op is used by multiple subsequent ops + 'add_p', + ] + + self.prim2orig_ops = [ + 'expand_v2', 'elementwise_add', 'reshape2', 'elementwise_mul', + 'reduce_sum', 'sqrt', 'expand_v2', 'elementwise_sub', 'concat', + 'gather', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', + 'elementwise_mul', 'reduce_sum', 'reshape2', 'reshape2', + 'elementwise_mul', 'elementwise_mul', 'reshape2', 'expand_v2', + 'elementwise_div', 'reduce_sum', 'reshape2', 'fill_constant', + 'elementwise_sub', 'split', 'fill_constant', 'fill_any_like', + 'elementwise_add', 'scatter', 'elementwise_add', 'elementwise_add' + ] + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/incubate/autograd/__init__.py b/python/paddle/incubate/autograd/__init__.py index 5528bb4d06..a57dac02be 100644 --- a/python/paddle/incubate/autograd/__init__.py +++ b/python/paddle/incubate/autograd/__init__.py @@ -12,7 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. from paddle.autograd.functional import Hessian, Jacobian, jvp, vjp +from .primx import prim2orig +from .utils import enable_prim, disable_prim, prim_enabled __all__ = [ # noqa - 'vjp', 'jvp', 'Jacobian', 'Hessian' + 'vjp', + 'jvp', + 'Jacobian', + 'Hessian', + 'prim2orig', + 'enable_prim', + 'disable_prim', + 'prim_enabled' ] diff --git a/python/paddle/autograd/primops.py b/python/paddle/incubate/autograd/primops.py similarity index 77% rename from python/paddle/autograd/primops.py rename to python/paddle/incubate/autograd/primops.py index 66f641e544..11e0e51cb7 100644 --- a/python/paddle/autograd/primops.py +++ b/python/paddle/incubate/autograd/primops.py @@ -13,8 +13,6 @@ # limitations under the License. 
import paddle -from paddle.fluid import unique_name, core -from paddle.fluid.framework import default_main_program, default_startup_program from paddle.fluid.layer_helper import LayerHelper from .primreg import REGISTER_FN @@ -136,7 +134,9 @@ def split(x, num_or_sections, axis=0, outs=None): if isinstance(num_or_sections, (list, tuple)): n = len(num_or_sections) else: - assert isinstance(num_or_sections, int) + if not isinstance(num_or_sections, int): + raise TypeError( + f'num_or_sections must be int, but got {type(num_or_sections)}.') n = num_or_sections attrs = {'num_or_sections': num_or_sections, 'axis': axis} @@ -157,7 +157,8 @@ def split(x, num_or_sections, axis=0, outs=None): @REGISTER_FN('concat_p', 'XS', 'Y') def concat(xs, axis=0, out=None): - assert isinstance(xs, (list, tuple)) and len(xs) > 0 + if isinstance(xs, paddle.fluid.framework.Variable): + xs = [xs] attrs = {'axis': axis} helper = LayerHelper('concat_p', **locals()) if out is None: @@ -172,9 +173,10 @@ def concat(xs, axis=0, out=None): @REGISTER_FN('reduce_p', 'X', 'Y') def reduce(x, axis, keepdim=False, out=None): - assert isinstance(axis, (tuple, list)) - assert isinstance(keepdim, bool) - + if not isinstance(axis, (tuple, list)): + raise TypeError(f'axis must be tuple or list, but got {type(axis)}') + if not isinstance(keepdim, bool): + raise TypeError(f'keepdim must be bool, but got {type(keepdim)}') attrs = {'axis': axis, 'keepdim': keepdim} helper = LayerHelper('reduce_p', **locals()) @@ -196,12 +198,20 @@ def matmul(x, y, out=None): @REGISTER_FN('slice_select_p', 'X', 'Y') def slice_select(x, axis, starts, ends, strides, out=None): - assert isinstance(axis, (list, tuple)), ( - f'Argument type error. `axis` is supposed to be int, list or' - f' tuple but found {type(axis)}.') - assert isinstance(starts, (list, tuple)) - assert isinstance(ends, (list, tuple)) - assert len(axis) == len(starts) == len(ends) == len(strides) + if not isinstance(axis, (list, tuple)): + raise TypeError(f'Argument type error. `axis` is supposed to be list or' + f' tuple but found {type(axis)}.') + if not isinstance(starts, (list, tuple)): + raise TypeError( + f'Argument type error. `starts` is supposed to be list or' + f' tuple but found {type(starts)}.') + if not isinstance(ends, (list, tuple)): + raise TypeError(f'Argument type error. 
`ends` is supposed to be list or' + f' tuple but found {type(ends)}.') + assert len(axis) == len(starts) == len(ends) == len(strides), ( + f'len(axis), len(starts), len(ends) and len(strides) should be equal, ' + f'but len(axis)={len(axis)}, len(starts)={len(starts)}, ' + f'len(ends)={len(ends)} and len(strides)={len(strides)}') attrs = {'axis': axis, 'starts': starts, 'ends': ends, 'strides': strides} helper = LayerHelper('slice_select_p', **locals()) @@ -217,8 +227,13 @@ def slice_select(x, axis, starts, ends, strides, out=None): @REGISTER_FN('slice_assign_p', 'X', 'Y', 'Z') def slice_assign(x, y, axis, starts, ends, strides, out=None): - assert len(starts) == len(ends) == len(strides) == len(axis) - assert len(y.shape) == len(x.shape) + assert len(starts) == len(ends) == len(strides) == len(axis), ( + f'len(starts), len(ends), len(strides) and len(axis) should be equal, ' + f'but len(starts)={len(starts)}, len(ends)={len(ends)}, ' + f'len(strides)={len(strides)} and len(axis)={len(axis)}') + assert len(y.shape) == len(x.shape), ( + f'len(y.shape) should be equal to len(x.shape), ' + f'but len(y.shape)={len(y.shape)} and len(x.shape)={len(x.shape)}.') attrs = {'axis': axis, 'starts': starts, 'ends': ends, 'strides': strides} helper = LayerHelper('slice_assign_p', **locals()) @@ -233,7 +248,7 @@ def slice_assign(x, y, axis, starts, ends, strides, out=None): return out -@REGISTER_FN('gather_p', 'X', 'Y') +@REGISTER_FN('gather_p', 'X', 'IndexTensor', 'Y') def gather(x, indextensor, axis, out=None): attrs = {'axis': axis} helper = LayerHelper('gather_p', **locals()) @@ -250,9 +265,16 @@ def gather(x, indextensor, axis, out=None): @REGISTER_FN('scatter_add_p', 'X', 'Y', 'IndexTensor', 'Z') def scatter_add(x, y, indextensor, axis, out=None): - assert len(x.shape) == len(y.shape) - assert len(indextensor.shape) == 1 - assert y.shape[axis] == indextensor.shape[0] + assert len(x.shape) == len(y.shape), ( + f'len(x.shape) should be equal to len(y.shape), ' + f'but len(x.shape)={len(x.shape)} and len(y.shape)={len(y.shape)}.') + assert len( + indextensor.shape + ) == 1, f'len(indextensor.shape) must be equal to 1, but got {len(indextensor.shape)}.' + assert y.shape[axis] == indextensor.shape[0], ( + f'y.shape[axis] should be equal to indextensor.shape[0], ' + f'but y.shape[axis]={y.shape[axis]} and ' + f'indextensor.shape[0]={indextensor.shape[0]}.') attrs = {'axis': axis} helper = LayerHelper('scatter_add_p', **locals()) if out is None: diff --git a/python/paddle/incubate/autograd/primreg.py b/python/paddle/incubate/autograd/primreg.py new file mode 100644 index 0000000000..35a0dbcfc2 --- /dev/null +++ b/python/paddle/incubate/autograd/primreg.py @@ -0,0 +1,289 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools + + +class Registry(object): + """ A general registry object. 
""" + __slots__ = ['name', 'tab'] + + def __init__(self, name): + self.name = name + self.tab = {} + + def register(self, name, value): + assert name not in self.tab, f'name "{name}" should not be registered before.' + self.tab[name] = value + + def lookup(self, name): + return self.tab.get(name) + + +_primop_fn = Registry('primop_fn') +_orig2prim = Registry('orig2prim') +_prim2orig = Registry('prim2orig') +_primop_jvp = Registry('primop_jvp') +_primop_transpose = Registry('primop_transpose') +_primop_position_argnames = Registry('primop_position_argnames') + + +def lookup_fn(optype): + return _primop_fn.lookup(optype) + + +def lookup_orig2prim(optype): + return _orig2prim.lookup(optype) + + +def lookup_prim2orig(optype): + return _prim2orig.lookup(optype) + + +def lookup_jvp(optype): + return _primop_jvp.lookup(optype) + + +def lookup_transpose(optype): + return _primop_transpose.lookup(optype) + + +def op_position_inputs(op): + """ + Returns the position inputs of `op` as registered with REGISTER_FN. + + Args: + op(Operator): The op that needs to get the inputs + + Returns: + Tensor(s): Inputs of the op + + Examples: + .. code-block:: python + @REGISTER_FN('div_p', 'X', 'Y', 'Z') + def div(x, y, out=None): + return _simple_binop(LayerHelper('div_p', **locals())) + + The registered inputs are ['X', 'Y'] for div_p and accordingly this + function will return inputs in the order of X then Y. + + """ + args = _primop_position_argnames.lookup(op.type) + assert args is not None, 'args should not be None in op_position_inputs().' + *input_names, _ = args + + inputs = [] + for name in input_names: + vars = list(map(op.block.var, op.input(name))) + assert len( + vars + ) >= 0, f'len(vars) should be greater than or equal to 0, but len(vars)={len(vars)}.' + if len(vars) > 1: + inputs.append(vars) + else: + inputs.append(vars[0]) + + return inputs + + +def op_position_output(op): + """ + Returns the output of `op` as registered with REGISTER_FN. + + Args: + op(Operator): The op that needs to get the output + + Returns: + Tensor(s): Output of the op + + Examples: + .. code-block:: python + @REGISTER_FN('div_p', 'X', 'Y', 'Z') + def div(x, y, out=None): + return _simple_binop(LayerHelper('div_p', **locals())) + + The registered output is ['Z'] for div_p and accordingly this + function will return output Z. + + """ + args = _primop_position_argnames.lookup(op.type) + assert args is not None, 'args should not be None in op_position_output().' + *_, output_name = args + + outvars = list(map(op.block.var, op.output(output_name))) + assert len( + outvars + ) >= 0, f'len(outvars) should be greater than or equal to 0, but len(outvars)={len(outvars)}.' + if len(outvars) > 1: + output = outvars + else: + output = outvars[0] + + return output + + +def REGISTER_FN(op_type, *position_argnames): + """ + Decorator for registering the Python function for a primitive op. + + Args: + op_type(str): The op name + position_argnames(list[str]): Input and ouput names of the op + + Returns: + wrapper: Inner wrapper function + + Examples: + .. 
code-block:: python + @REGISTER_FN('tanh_p', 'X', 'Y') + def tanh(x, out=None): + return _simple_unop(LayerHelper('tanh_p', **locals())) + + """ + + if not isinstance(op_type, str): + raise TypeError(f'op_type must be str, but got {type(op_type)}.') + + _primop_position_argnames.register(op_type, position_argnames) + + def wrapper(f): + _primop_fn.register(op_type, f) + return f + + return wrapper + + +def REGISTER_ORIG2PRIM(op_type): + """ + Decorator for registering the lower function for an original op into sequence of primitive ops. + + Args: + op_type(str): The op name + + Returns: + wrapper: Inner wrapper function + + Examples: + .. code-block:: python + @REGISTER_ORIG2PRIM('tanh') + def tanh_orig2prim(op): + x, = get_input_var_list(op) + return primops.tanh(x) + + """ + if not isinstance(op_type, str): + raise TypeError(f'op_type must be str, but got {type(op_type)}.') + + def wrapper(f): + def _lower(op, *args, **kwargs): + assert op.type == op_type, f'op.type should be equal to op_type, but op.type is {op.type} and op_type is {op_type}' + return f(op, *args, **kwargs) + + _orig2prim.register(op_type, _lower) + + return wrapper + + +def REGISTER_PRIM2ORIG(op_type): + """ + Decorator for registering the lower function for an primitive op into sequence of original ops. + + Args: + op_type(str): The op name + + Returns: + wrapper: Inner wrapper function + + Examples: + .. code-block:: python + @REGISTER_PRIM2ORIG('tanh_p') + def tanh_prim2orig(op): + x, = get_input_var_list(op) + return paddle.tanh(x) + + """ + if not isinstance(op_type, str): + raise TypeError(f'op_type must be str, but got {type(op_type)}.') + + def wrapper(f): + def _lower(op, *args, **kwargs): + assert op.type == op_type, f'op.type should be equal to op_type, but op.type is {op.type} and op_type is {op_type}' + return f(op, *args, **kwargs) + + _prim2orig.register(op_type, _lower) + + return wrapper + + +def REGISTER_JVP(op_type): + """ + Decorator for registering the JVP function for a primitive op. + + Args: + op_type(str): The op name + + Returns: + wrapper: Inner wrapper function + + Examples: + .. code-block:: python + @REGISTER_JVP('add_p') + def add_jvp(op, x_dot, y_dot): + return primops.add(x_dot, y_dot) + + """ + if not isinstance(op_type, str): + raise TypeError(f'op_type must be str, but got {type(op_type)}.') + + def wrapper(f): + def _jvp(op, *args, **kwargs): + assert op.type == op_type, f'op.type should be equal to op_type, but op.type is {op.type} and op_type is {op_type}' + return f(op, *args, **kwargs) + + _primop_jvp.register(op_type, _jvp) + return f + + return wrapper + + +def REGISTER_TRANSPOSE(op_type): + """ + Decorator for registering the transpose function for a primitive op + that denotes a linear operation in the forward AD graph. + + Args: + op_type(str): The op name + + Returns: + wrapper: Inner wrapper function + + Examples: + .. 
code-block:: python + @REGISTER_TRANSPOSE('add_p') + def add_transpose(op, z_bar): + return z_bar, z_bar + + """ + if not isinstance(op_type, str): + raise TypeError(f'op_type must be str, but got {type(op_type)}.') + + def wrapper(f): + def _transpose(op, dot_checker, *args, **kwargs): + assert op.type == op_type, f'op.type should be equal to op_type, but op.type is {op.type} and op_type is {op_type}' + return f(op, dot_checker, *args, **kwargs) + + _primop_transpose.register(op_type, _transpose) + return f + + return wrapper diff --git a/python/paddle/incubate/autograd/primrules.py b/python/paddle/incubate/autograd/primrules.py new file mode 100644 index 0000000000..075fe83e25 --- /dev/null +++ b/python/paddle/incubate/autograd/primrules.py @@ -0,0 +1,724 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle + +from .primreg import REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG, REGISTER_JVP, REGISTER_TRANSPOSE +from .primreg import (lookup_fn, lookup_orig2prim, lookup_prim2orig, lookup_jvp, + lookup_transpose, op_position_inputs, op_position_output) +from .primops import (neg, add, sub, mul, div, sqrt, tanh, reshape, broadcast, + transpose, split, concat, reduce, matmul, slice_select, + slice_assign, gather, scatter_add, fill_const, set_value) +from .utils import get_input_var_list, get_output_var_list, INT_DTYPE_2_STRING + + +def _orig2prim(op, *args): + _lowerrule = lookup_orig2prim(op.type) + return _lowerrule(op, *args) + + +def _prim2orig(op, *args): + _lowerrule = lookup_prim2orig(op.type) + return _lowerrule(op, *args) + + +def _jvp(op, *args): + _jvprule = lookup_jvp(op.type) + return _jvprule(op, *args) + + +def _transpose(op, dot_checker, *args): + _transposerule = lookup_transpose(op.type) + return _transposerule(op, dot_checker, *args) + + +def linear_jvp(op, *args, **kwargs): + fn = lookup_fn(op.type) + out_dot = fn(*args, **kwargs) + return out_dot + + +## Register orig2prim lower rules +""" +These original ops are fully supported: + +elementwise_add +elementwise_sub +elementwise_mul +tanh +fill_zeros_like +sum +index_select +scale +assign +sqrt + +These original ops are partially supported: + +matmul_v2 +reshape2 +concat +slice +p_norm +""" + + +@REGISTER_ORIG2PRIM('elementwise_add') +def elementwise_add_orig2prim(op, x, y): + if x.shape != y.shape: + y = broadcast(y, shape=x.shape) + if op.attr('Scale_x') - 1.0 > 1e-5: + scale_x = fill_const( + shape=x.shape, dtype=x.dtype, value=op.attr('Scale_x')) + x = mul(x, scale_x) + if op.attr('Scale_y') - 1.0 > 1e-5: + scale_y = fill_const( + shape=y.shape, dtype=y.dtype, value=op.attr('Scale_y')) + y = mul(y, scale_y) + z = add(x, y) + if op.attr('Scale_out') - 1.0 > 1e-5: + scale_out = fill_const( + shape=z.shape, dtype=z.dtype, value=op.attr('Scale_out')) + z = mul(z, scale_out) + return z + + +@REGISTER_ORIG2PRIM('elementwise_sub') +def elementwise_sub_orig2prim(op, x, y): + if x.shape != y.shape: + y = broadcast(y, shape=x.shape) + if op.attr('Scale_x') - 1.0 
> 1e-5: + scale_x = fill_const( + shape=x.shape, dtype=x.dtype, value=op.attr('Scale_x')) + x = mul(x, scale_x) + if op.attr('Scale_y') - 1.0 > 1e-5: + scale_y = fill_const( + shape=y.shape, dtype=y.dtype, value=op.attr('Scale_y')) + y = mul(y, scale_y) + z = sub(x, y) + if op.attr('Scale_out') - 1.0 > 1e-5: + scale_out = fill_const( + shape=z.shape, dtype=z.dtype, value=op.attr('Scale_out')) + z = mul(z, scale_out) + return z + + +@REGISTER_ORIG2PRIM('elementwise_mul') +def elementwise_mul_orig2prim(op, x, y): + if x.shape != y.shape: + y = broadcast(y, shape=x.shape) + if op.attr('Scale_x') - 1.0 > 1e-5: + scale_x = fill_const( + shape=x.shape, dtype=x.dtype, value=op.attr('Scale_x')) + x = mul(x, scale_x) + if op.attr('Scale_y') - 1.0 > 1e-5: + scale_y = fill_const( + shape=y.shape, dtype=y.dtype, value=op.attr('Scale_y')) + y = mul(y, scale_y) + z = mul(x, y) + if op.attr('Scale_out') - 1.0 > 1e-5: + scale_out = fill_const( + shape=z.shape, dtype=z.dtype, value=op.attr('Scale_out')) + z = mul(z, scale_out) + return z + + +@REGISTER_ORIG2PRIM('tanh') +def tanh_orig2prim(op, x): + return tanh(x) + + +@REGISTER_ORIG2PRIM('fill_zeros_like') +def fill_zeros_like_orig2prim(op, x): + return fill_const(value=0.0, shape=x.shape, dtype=x.dtype) + + +@REGISTER_ORIG2PRIM('sum') +def sum_orig2prim(op, xs): + x0 = xs[0] + for x in xs[1:]: + x0 = add(x0, x) + return x0 + + +@REGISTER_ORIG2PRIM('index_select') +def index_select_orig2prim(op, index_t, x): + return gather(x, indextensor=index_t, axis=op.attr('dim')) + + +@REGISTER_ORIG2PRIM('scale') +def scale_orig2prim(op, scale_t, x): + if scale_t is None: + scale_t = fill_const( + shape=x.shape, dtype=x.dtype, value=op.attr('scale')) + bias_t = fill_const(shape=x.shape, dtype=x.dtype, value=op.attr('bias')) + if op.attr('bias_after_scale'): + return add(mul(x, scale_t), bias_t) + else: + return mul(add(x, bias_t), scale_t) + + +@REGISTER_ORIG2PRIM('assign') +def assign_orig2prim(op, x): + zero_t = fill_const(shape=x.shape, dtype=x.dtype, value=0.0) + return add(x, zero_t) + + +@REGISTER_ORIG2PRIM('sqrt') +def sqrt_orig2prim(op, x): + return sqrt(x) + + +@REGISTER_ORIG2PRIM('matmul_v2') +def matmul_v2_orig2prim(op, x, y): + def trans(shape): + ret = [i for i in range(len(shape))] + ret[-1], ret[-2] = ret[-2], ret[-1] + return ret + + assert len(x.shape) < 4 and len( + y.shape) < 4, 'Do not support multi batchsize dimensions currently.' + + if len(x.shape) == 1: + x = broadcast(x, shape=[1, x.shape[0]]) + if len(y.shape) == 1: + y = broadcast(y, shape=[y.shape[0], 1]) + if op.attr('trans_x'): + x = transpose(x, axis=trans(x.shape)) + if op.attr('trans_y'): + y = transpose(y, axis=trans(y.shape)) + return matmul(x, y) + + +## NOTE(lml): The second output of reshape2 Xshape, which is only used in reshape2_grad, is meanlingless in new autograd mechanism, thus we use a zero tensor instead. +@REGISTER_ORIG2PRIM('reshape2') +def reshape2_orig2prim(op, shape_t, shape_tl, x): + assert shape_t is None, 'Can not lower reshape2 into prim ops with shapetensor.' + assert shape_tl is None, 'Can not lower reshape2 into prim ops with shapetensorlist.' + y, xshape = get_output_var_list(op) + return reshape( + x, shape=y.shape), fill_const( + shape=xshape.shape, dtype=xshape.dtype, value=0.0) + + +@REGISTER_ORIG2PRIM('concat') +def concat_orig2prim(op, axis_t, xs): + assert axis_t is None, 'Can not lower concat into prim ops with axistensor.' 
+ return concat(xs, axis=op.attr('axis')) + + +@REGISTER_ORIG2PRIM('slice') +def slice_orig2prim(op, ends_t, ends_tl, x, starts_t, starts_tl): + assert starts_t is None, 'Can not lower concat into prim ops with startstensor.' + assert ends_t is None, 'Can not lower concat into prim ops with endstensor.' + assert starts_tl is None, 'Can not lower concat into prim ops with startstensorlist.' + assert ends_tl is None, 'Can not lower concat into prim ops with endstensorlist.' + starts = op.attr('starts') + ends = op.attr('ends') + strides = [1 for _ in starts] + axis = op.attr('axes') + y = slice_select(x, starts=starts, ends=ends, strides=strides, axis=axis) + if op.attr('decrease_axis'): + y = reshape(y, shape=get_output_var_list(op)[0].shape) + return y + + +@REGISTER_ORIG2PRIM('p_norm') +def p_norm_orig2prim(op, x): + def num_el(shape): + n = 1 + for s in shape: + n = n * s + return n + + assert op.attr( + 'asvector'), 'Only support lower pnorm when asvector=True currently' + if len(x.shape) > 1: + x = reshape(x, shape=[num_el(x.shape)]) + + if abs(op.attr('porder') - 2.0) < 1e-5: + return sqrt(reduce(mul(x, x), axis=[0])) + elif abs(op.attr('porder') - 1.0) < 1e-5: + return reduce(sqrt(mul(x, x)), axis=[0]) + else: + raise RuntimeError('Only support lower l2/l1 norm currently') + + +## Register prim2orig lower rules + + +@REGISTER_PRIM2ORIG('add_p') +def add_prim2orig(op, x, y): + return paddle.add(x, y) + + +@REGISTER_PRIM2ORIG('sub_p') +def sub_prim2orig(op, x, y): + return paddle.subtract(x, y) + + +@REGISTER_PRIM2ORIG('mul_p') +def mul_prim2orig(op, x, y): + return paddle.multiply(x, y) + + +@REGISTER_PRIM2ORIG('div_p') +def div_prim2orig(op, x, y): + return paddle.divide(x, y) + + +@REGISTER_PRIM2ORIG('sqrt_p') +def sqrt_prim2orig(op, x): + return paddle.sqrt(x) + + +@REGISTER_PRIM2ORIG('tanh_p') +def tanh_prim2orig(op, x): + return paddle.tanh(x) + + +@REGISTER_PRIM2ORIG('reshape_p') +def reshape_prim2orig(op, x): + return paddle.reshape(x, shape=op.attr('shape')) + + +@REGISTER_PRIM2ORIG('broadcast_p') +def broadcast_prim2orig(op, x): + return paddle.broadcast_to(x, shape=op.attr('shape')) + + +@REGISTER_PRIM2ORIG('transpose_p') +def transpose_prim2orig(op, x): + return paddle.transpose(x, perm=op.attr('axis')) + + +@REGISTER_PRIM2ORIG('split_p') +def split_prim2orig(op, x): + num_or_sections = op.attr('num_or_sections') + if len(num_or_sections) == 1: + num_or_sections = num_or_sections[0] + return paddle.split( + x, num_or_sections=num_or_sections, axis=op.attr('axis')) + + +@REGISTER_PRIM2ORIG('concat_p') +def concat_prim2orig(op, xs): + return paddle.concat(xs, axis=op.attr('axis')) + + +@REGISTER_PRIM2ORIG('reduce_p') +def reduce_prim2orig(op, x): + return paddle.sum(x, axis=op.attr('axis'), keepdim=op.attr('keepdim')) + + +@REGISTER_PRIM2ORIG('matmul_p') +def matmul_prim2orig(op, x, y): + return paddle.matmul(x, y) + + +@REGISTER_PRIM2ORIG('slice_select_p') +def slice_select_prim2orig(op, x): + return paddle.strided_slice( + x, + axes=op.attr('axis'), + starts=op.attr('starts'), + ends=op.attr('ends'), + strides=op.attr('strides')) + + +@REGISTER_PRIM2ORIG('slice_assign_p') +def slice_assign_prim2orig(op, x, y): + x_copy = paddle.assign(x) + return set_value( + x_copy, + y, + axis=op.attr('axis'), + starts=op.attr('starts'), + ends=op.attr('ends'), + strides=op.attr('strides'), + out=x_copy) + + +@REGISTER_PRIM2ORIG('gather_p') +def gather_prim2orig(op, index_t, x): + return paddle.gather(x, index_t, axis=op.attr('axis')) + + +@REGISTER_PRIM2ORIG('scatter_add_p') +def 
scatter_add_prim2orig(op, index_t, x, y): + assert op.attr('axis') == 0, 'Only support axis==0 currently' + zeros = paddle.zeros_like(x=x, dtype=x.dtype) + tmp = paddle.scatter(x=zeros, index=index_t, updates=y, overwrite=False) + return paddle.add(x, tmp) + + +@REGISTER_PRIM2ORIG('fill_constant_p') +def fill_constant_prim2orig(op): + return paddle.full( + shape=op.attr('shape'), + fill_value=op.attr('value'), + dtype=INT_DTYPE_2_STRING[op.attr('dtype')]) + + +## Register linearize rules +@REGISTER_JVP('add_p') +def add_jvp(op, x_dot, y_dot): + if x_dot is None: + return y_dot + elif y_dot is None: + return x_dot + else: + return linear_jvp(op, x_dot, y_dot) + + +@REGISTER_JVP('sub_p') +def sub_jvp(op, x_dot, y_dot): + if x_dot is None: + return neg(y_dot) + elif y_dot is None: + return x_dot + else: + return linear_jvp(op, x_dot, y_dot) + + +@REGISTER_JVP('mul_p') +def mul_jvp(op, x_dot, y_dot): + if x_dot is None and y_dot is None: + return None + x, y = op_position_inputs(op) + if x_dot is None: + return mul(x, y_dot) + elif y_dot is None: + return mul(x_dot, y) + else: + t1, t2 = mul(x_dot, y), mul(x, y_dot) + z_dot = add(t1, t2) + return z_dot + + +@REGISTER_JVP('div_p') +def div_jvp(op, x_dot, y_dot): + if x_dot is None and y_dot is None: + return None + x, y = op_position_inputs(op) + if y_dot is None: + return div(x_dot, y) + elif x_dot is None: + return neg(div(mul(x, y_dot), mul(y, y))) + else: + t1 = div(x_dot, y) + t2 = div(mul(x, y_dot), mul(y, y)) + return sub(t1, t2) + + +@REGISTER_JVP('sqrt_p') +def sqrt_jvp(op, x_dot): + if x_dot is None: + return None + y = op_position_output(op) + c2 = fill_const(value=2.0, shape=y.shape, dtype=y.dtype) + y_dot = div(x_dot, mul(c2, y)) + return y_dot + + +@REGISTER_JVP('tanh_p') +def tanh_jvp(op, x_dot): + if x_dot is None: + return None + y = op_position_output(op) + c1 = fill_const(value=1.0, shape=y.shape, dtype=y.dtype) + y_dot = mul(x_dot, sub(c1, mul(y, y))) + return y_dot + + +@REGISTER_JVP('reshape_p') +def reshape_jvp(op, x_dot): + if x_dot is None: + return None + shape = op.attr('shape') + return linear_jvp(op, x_dot, shape=shape) + + +@REGISTER_JVP('broadcast_p') +def broadcast_jvp(op, x_dot): + if x_dot is None: + return None + shape = op.attr('shape') + return linear_jvp(op, x_dot, shape=shape) + + +@REGISTER_JVP('transpose_p') +def transpose_jvp(op, x_dot): + if x_dot is None: + return None + axis = op.attr('axis') + return linear_jvp(op, x_dot, axis=axis) + + +@REGISTER_JVP('split_p') +def split_jvp(op, x_dot): + if x_dot is None: + return None + num_or_sections = op.attr('num_or_sections') + axis = op.attr('axis') + return linear_jvp(op, x_dot, num_or_sections=num_or_sections, axis=axis) + + +@REGISTER_JVP('concat_p') +def concat_jvp(op, xs_dot): + if xs_dot is None: + return None + axis = op.attr('axis') + return linear_jvp(op, xs_dot, axis=axis) + + +@REGISTER_JVP('reduce_p') +def reduce_jvp(op, x_dot): + if x_dot is None: + return None + axis = op.attr('axis') + keepdim = op.attr('keepdim') + return linear_jvp(op, x_dot, axis=axis, keepdim=keepdim) + + +@REGISTER_JVP('matmul_p') +def matmul_jvp(op, x_dot, y_dot): + if x_dot is None and y_dot is None: + return None + x, y = op_position_inputs(op) + if x_dot is None: + return matmul(x, y_dot) + elif y_dot is None: + return matmul(x_dot, y) + else: + t1 = matmul(x, y_dot) + t2 = matmul(x_dot, y) + return add(t1, t2) + + +@REGISTER_JVP('slice_select_p') +def slice_select_jvp(op, x_dot): + if x_dot is None: + return x_dot + axis = op.attr('axis') + starts = 
op.attr('starts') + ends = op.attr('ends') + strides = op.attr('strides') + return linear_jvp( + op, x_dot, axis=axis, starts=starts, ends=ends, strides=strides) + + +@REGISTER_JVP('slice_assign_p') +def slice_assign_jvp(op, x_dot, y_dot): + if x_dot is None: + assert y_dot is None, 'y_dot must be None.' + return None + else: + assert y_dot is not None, 'y_dot should not be None.' + axis = op.attr('axis') + starts = op.attr('starts') + ends = op.attr('ends') + strides = op.attr('strides') + return linear_jvp( + op, x_dot, y_dot, axis=axis, starts=starts, ends=ends, strides=strides) + + +@REGISTER_JVP('gather_p') +def gather_jvp(op, x_dot, indextensor): + if x_dot is None: + return None + _, indextensor = op_position_inputs(op) + axis = op.attr('axis') + return linear_jvp(op, x_dot, indextensor, axis=axis) + + +@REGISTER_JVP('scatter_add_p') +def scatter_add_jvp(op, x_dot, y_dot): + if x_dot is None: + return None + _, _, indextensor = op_position_inputs(op) + axis = op.attr('axis') + return linear_jvp(op, x_dot, y_dot, indextensor, axis=axis) + + +## Register transpose rules + + +@REGISTER_TRANSPOSE('add_p') +def add_transpose(op, check_dot, z_bar): + x, y = op_position_inputs(op) + assert check_dot(x) or check_dot(y), ( + f'(check_dot(x) or check_dot(y)) must be True, ' + f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.') + x_bar = z_bar if check_dot(x) else None + y_bar = z_bar if check_dot(y) else None + return x_bar, y_bar + + +@REGISTER_TRANSPOSE('sub_p') +def sub_transpose(op, check_dot, z_bar): + x, y = op_position_inputs(op) + assert check_dot(x) or check_dot(y), ( + f'(check_dot(x) or check_dot(y)) must be True, ' + f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.') + x_bar = z_bar if check_dot(x) else None + y_bar = neg(z_bar) if check_dot(y) else None + return x_bar, y_bar + + +@REGISTER_TRANSPOSE('mul_p') +def mul_transpose(op, check_dot, z_bar): + x, y = op_position_inputs(op) + assert check_dot(x) ^ check_dot(y), ( + f'(check_dot(x) ^ check_dot(y)) must be True, ' + f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.') + if check_dot(x): + return mul(z_bar, y), None + else: + return None, mul(x, z_bar) + + +@REGISTER_TRANSPOSE('div_p') +def div_transpose(op, check_dot, z_bar): + x, y = op_position_inputs(op) + assert not check_dot(y), 'check_dot(y) must be False' + x_bar = div(z_bar, y) if check_dot(x) else None + return x_bar, None + + +@REGISTER_TRANSPOSE('reshape_p') +def reshape_transpose(op, check_dot, y_bar): + x, = op_position_inputs(op) + assert check_dot(x), 'check_dot(x) must be True' + return reshape(y_bar, shape=x.shape) + + +@REGISTER_TRANSPOSE('broadcast_p') +def broadcast_transpose(op, check_dot, y_bar): + x, = op_position_inputs(op) + assert check_dot(x), 'check_dot(x) must be True' + bat = len(y_bar.shape) - len(x.shape) + axis = list(range(bat)) + keepdim = [(bat + i) for i, s in enumerate(x.shape) if s == 1] + axis += keepdim + # TODO: Change it. 
keepdim boolean + out = reduce(y_bar, axis=axis, keepdim=False) + return reshape(out, x.shape) + + +@REGISTER_TRANSPOSE('transpose_p') +def transpose_transpose(op, check_dot, y_bar): + x, = op_position_inputs(op) + assert check_dot(x), 'check_dot(x) must be True' + axis = op.attr('axis') + reordered = sorted((k, i) for i, k in enumerate(axis)) + axis = [i for k, i in reordered] + return transpose(y_bar, axis=axis) + + +@REGISTER_TRANSPOSE('split_p') +def split_transpose(op, check_dot, ys_bar): + x, = op_position_inputs(op) + assert check_dot(x), 'check_dot(x) must be True' + return concat(ys_bar, axis=op.attr('axis')) + + +@REGISTER_TRANSPOSE('concat_p') +def concat_transpose(op, check_dot, y_bar): + xs, = op_position_inputs(op) + for x in xs: + assert check_dot(x), 'check_dot(x) must be True' + axis = op.attr('axis') + sections = [x.shape[axis] for x in xs] + return split(y_bar, num_or_sections=sections, axis=axis) + + +@REGISTER_TRANSPOSE('reduce_p') +def reduce_transpose(op, check_dot, y_bar): + x, = op_position_inputs(op) + assert check_dot(x), 'check_dot(x) must be True' + axes = op.attr('axis') + shape = tuple(1 if i in axes else size for i, size in enumerate(x.shape)) + t = reshape(y_bar, shape=shape) + return broadcast(t, shape=x.shape) + + +@REGISTER_TRANSPOSE('matmul_p') +def matmul_transpose(op, check_dot, z_bar): + x, y = op_position_inputs(op) + assert check_dot(x) ^ check_dot(y), ( + f'(check_dot(x) ^ check_dot(y)) must be True, ' + f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.') + # TODO: replace it. this is hacky + axis = [1, 0] if len(x.shape) == 2 else [0, 2, 1] + if check_dot(x): + return matmul(z_bar, transpose(y, axis=axis)), None + else: + return None, matmul(transpose(x, axis=axis), z_bar) + + +@REGISTER_TRANSPOSE('slice_select_p') +def slice_select_transpose(op, check_dot, y_bar): + x, = op_position_inputs(op) + assert check_dot(x), 'check_dot(x) must be True' + zeros = fill_const(value=0.0, shape=x.shape, dtype=x.dtype) + axis = op.attr('axis') + starts = op.attr('starts') + ends = op.attr('ends') + strides = op.attr('strides') + return slice_assign( + zeros, y_bar, axis=axis, starts=starts, ends=ends, strides=strides) + + +@REGISTER_TRANSPOSE('slice_assign_p') +def slice_assign_transpose(op, check_dot, z_bar): + x, y = op_position_inputs(op) + assert check_dot(x) and check_dot(y), ( + f'(check_dot(x) and check_dot(y)) must be True, ' + f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.') + zeros = fill_const(value=0.0, shape=y.shape, dtype=y.dtype) + axis = op.attr('axis') + starts = op.attr('starts') + ends = op.attr('ends') + strides = op.attr('strides') + x_bar = slice_assign( + z_bar, zeros, axis=axis, starts=starts, ends=ends, strides=strides) + y_bar = slice_select( + z_bar, axis=axis, starts=starts, ends=ends, strides=strides) + return x_bar, y_bar + + +@REGISTER_TRANSPOSE('gather_p') +def gather_transpose(op, check_dot, y_bar): + x, indextensor = op_position_inputs(op) + assert check_dot(x), 'check_dot(x) must be True' + axis = op.attr('axis') + zeros = fill_const(0.0, x.shape, x.dtype) + x_bar = scatter_add(zeros, y_bar, indextensor, axis=axis) + indextensor_bar = None + return x_bar, indextensor_bar + + +@REGISTER_TRANSPOSE('scatter_add_p') +def scatter_add_transpose(op, check_dot, z_bar): + x, y, indextensor = op_position_inputs(op) + assert check_dot(x) and check_dot(y), ( + f'(check_dot(x) and check_dot(y)) must be True, ' + f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.') + axis = 
op.attr('axis') + zeros = fill_const(value=0.0, shape=y.shape, dtype=y.dtype) + x_bar = scatter_add(z_bar, zeros, indextensor, axis=axis) + y_bar = gather(z_bar, indextensor, axis=axis) + indextensor_bar = None + return x_bar, y_bar, indextensor_bar diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py new file mode 100644 index 0000000000..7a96974820 --- /dev/null +++ b/python/paddle/incubate/autograd/primx.py @@ -0,0 +1,611 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.fluid import framework as framework +from paddle.fluid.framework import default_main_program +from paddle.fluid.framework import Operator +from paddle import compat as cpt +from .primops import fill_const, add +from .primreg import op_position_inputs, op_position_output, lookup_orig2prim, lookup_prim2orig +from .primrules import _orig2prim, _prim2orig, _jvp, _transpose +from .utils import get_input_var_list, get_output_var_list, to_tensors, flatten, flatten_and_remove_none +from collections import OrderedDict + + +def topo_path(xs, ys, block=None): + """ Returns the list of ops on the path from `xs` to `ys` in topological + order. + + TODO(Tongxin): supporting control flow and nested blocks. + Args: + xs: a list|tuple of vars as source + ys: a list|tuple of vars as sink + block: the program block containing the path, optional + Returns: + (path, unused_xs, unreached_ys): a tuple comprised of the resulting op + path, the unused variables in `xs`, and the unreached variables in `ys` + """ + + if block is None: + block = default_main_program().current_block() + + path = [] + backpath = [] + reached_vars = OrderedDict() + used_vars = OrderedDict() + + # Initialize reached vars + for x in xs: + assert x is None or x.block == block, f'x is not None and x.block != block' + reached_vars[id(x)] = x + + # Reaching test, returning whether an op is reached from the given input + reaching = lambda op: any(id(v) in reached_vars for v in flatten_and_remove_none(get_input_var_list(op))) + + # block.ops are supposedly in the order that preserves correct data + # dependence. 
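+    # Two sweeps follow: the forward sweep over block.ops collects every op
+    # reachable from `xs` together with the variables it produces, and the
+    # backward sweep over that candidate path keeps only the ops whose
+    # outputs actually contribute to `ys`.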
+ # Forward pass to identify all reached variables and ops + for op in block.ops: + if reaching(op): + path.append(op) + for var in flatten_and_remove_none(get_output_var_list(op)): + reached_vars[id(var)] = var + + used_vars = OrderedDict((id(y), y) for y in ys if id(y) in reached_vars) + back_reaching = lambda op: any(id(out) in used_vars for out in flatten_and_remove_none(get_output_var_list(op))) + + # Backward pass to find all used variables + for op in reversed(path): + if back_reaching(op): + backpath.append(op) + for var in flatten_and_remove_none(get_input_var_list(op)): + used_vars[id(var)] = var + + unused_xs = [x for x in xs if id(x) not in used_vars] + unreached_ys = [y for y in ys if id(y) not in reached_vars] + + return list(reversed(backpath)), unused_xs, unreached_ys + + +def output_vars_on_path(path): + """ Returns the output variables of all the ops on the path from `xs` + to `ys`. + + Args: + path: a list of ops on which to find the output variables + + Returns: + vars: the output vars + """ + vars = OrderedDict() + for op in path: + for out in flatten_and_remove_none(get_output_var_list(op)): + vars[id(out)] = out + + return vars + + +class VarMap(object): + """ A general map data structure for linking variables to variables. + + An example is linking variables to their gradients. + """ + + __slots__ = ['name', 'varset', 'tab'] + + def __init__(self, name, varset): + self.name = name + self.varset = varset + self.tab = OrderedDict() + + def add(self, key_var, value_var): + self.tab[id(key_var)] = id(value_var) + + def add_rec(self, key_vars, value_vars): + if value_vars is None: + return + if isinstance(key_vars, paddle.fluid.framework.Variable): + if not isinstance(value_vars, paddle.fluid.framework.Variable): + raise TypeError( + f'value_vars must be Variable, but got {type(value_vars)}') + self.tab[id(key_vars)] = id(value_vars) + else: + assert len(key_vars) == len(value_vars), ( + f'len(key_vars) shoule be equal to len(value_vars), ' + f'but len(key_vars)={len(key_vars)} and len(value_vars)={len(value_vars)}.' + ) + for key_var, value_var in zip(key_vars, value_vars): + self.add_rec(key_var, value_var) + + def lookup(self, key_var): + value_id = self.tab.get(id(key_var)) + if value_id is not None: + return self.varset.get(value_id) + else: + return None + + def delete(self, key_var): + varid = id(key_var) + if varid in self.tab: + del self.tab[id(key_var)] + + def delete_keyvars(self, key_vars): + for var in key_vars: + varid = id(var) + if varid in self.tab: + del self.tab[varid] + + def delete_valuevars(self, value_vars): + ids = [id(v) for v in value_vars] + keys = [k for k, v in self.tab.items() if v in ids] + for k in keys: + del self.tab[k] + + def contain_var(self, key_var): + return self.tab.__contains__(id(key_var)) + + def contain_value(self, value_var): + return id(value_var) in self.tab.values() + + +class Transform(object): + """ An object that maintains the state of transformations applied to a + primitve program. 
""" + + def __init__(self, block): + self.block = block + self.vars = self.init_vars(block) + self.var2dot = VarMap('var2dot', self.vars) + self.dot2bar = VarMap('dot2var', self.vars) + + def init_vars(self, block): + vars = OrderedDict() + for _, var in block.vars.items(): + vars[id(var)] = var + return vars + + def add_vars(self, new_vars): + self.vars.update({id(v): v for v in new_vars if v is not None}) + + def add_vars_rec(self, new_vars): + if new_vars is None: + return + if isinstance(new_vars, paddle.fluid.framework.Variable): + self.vars.update({id(new_vars): new_vars}) + return + if not isinstance(new_vars, list): + raise TypeError(f'new_vars must be list, but got {type(new_vars)}') + for var in new_vars: + self.add_vars_rec(var) + + def erase_ops(self, ordered_indexes): + block = self.block + for op_index in reversed(ordered_indexes): + block.desc._remove_op(op_index, op_index + 1) + + # remove from block.ops + for op_index in reversed(ordered_indexes): + del block.ops[op_index] + + block._sync_with_cpp() + + def erase_dots(self, vars_to_erase): + for var in vars_to_erase: + if id(var) in self.vars: + del self.vars[id(var)] + self.dot2bar.delete_keyvars(vars_to_erase) + self.var2dot.delete_valuevars(vars_to_erase) + block = self.block + for var in vars_to_erase: + name = var.name + block.desc._remove_var(cpt.to_bytes(name)) + del block.vars[name] + block._sync_with_cpp() + + def var2dot_rec(self, vars): + """ Lookup var2dot recursively.""" + if isinstance(vars, paddle.fluid.framework.Variable): + dot = self.var2dot.lookup(vars) + return dot + + dots = [self.var2dot_rec(var) for var in vars] + return dots + + def dot2bar_rec(self, dots): + + if isinstance(dots, paddle.fluid.framework.Variable): + bar = self.dot2bar.lookup(dots) + assert bar is not None, 'bar must be not None' + return bar + + bars = [self.dot2bar_rec(dot) for dot in dots] + return bars + + def linearize(self, xs, ys, xs_dot=None): + """ Performs the linearization transform, a.k.a, forward mode AD + transform, on a primitive lowered program. + + Args: + xs: a list of input variables + ys: a list of output variables + xs_dot: optional, a list of gradient input variables. The list size + must be equal to `len(xs)`. The shape and dtype of each element + must be the same as in `xs` + + Returns: + (xs_dot, ys_dot): a tuple of two lists. `xs_dot` is the list of + gradient inputs of the resulting linearized program. `ys_dot` is + the list gradient outputs of the resulting linearized program + + """ + if xs_dot is None: + xs_dot = [fill_const(1.0, shape=x.shape, dtype=x.dtype) for x in xs] + self.add_vars(xs_dot) + else: + assert len(xs) == len(xs_dot), ( + f'len(xs) should be equal to len(xs_dot), ' + f'but len(xs)={len(xs)} and len(xs_dot)={len(xs_dot)}') + + for x, dot in zip(xs, xs_dot): + assert x.dtype == dot.dtype, ( + f'x.dtype should be equal to dot.dtype, ' + f'but x.dtype={x.dtype} and dot.dtype={dot.dtype}') + assert x.shape == dot.shape, ( + f'x.shape should be equal to dot.shape, ' + f'but x.shape={x.shape} and dot.shape={dot.shape}') + self.var2dot.add(x, dot) + + path, unused_xs, _ = topo_path(xs, ys, self.block) + + # No need to track unused inputs + for x in unused_xs: + self.var2dot.delete(x) + + for op in path: + # An input var may not be on the input-output path, which implies + # there may be None's in `ins_dot`. In this case we place + # the original input in the position of the otherwise forward + # gradient. 
+ ins = op_position_inputs(op) + jvp_ins = self.var2dot_rec(ins) + # apply op's forward ad rule + outs_dot = _jvp(op, *jvp_ins) + self.add_vars_rec(outs_dot) + outs = op_position_output(op) + self.var2dot.add_rec(outs, outs_dot) + + ys_dot = [self.var2dot.lookup(y) for y in ys] + return xs_dot, ys_dot + + def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False): + """ Performs the transpose transform, a.k.a, reverse mode AD + transform, on a linearized primitive program. + + Note, `transpose` is supposed to be used in couple with `linearize`. + + Args: + ys_dot: a list of outputs of the linearized program. + xs_dot: a list of inputs of the linearized program. + ys_bar: optional, a list of inputs of the resulting transposed + program. The list size must be equal to `len(ys_dot)`. The shape + and dtype of each element must be the same as in `ys_dot` + + Returns: + (ys_bar, xs_bar): a tuple of two lists. `ys_bar` is the list of + inputs of the resulting transposed program. `xs_bar` is + the list outputs of the resulting transposed program + + """ + assert all(v is not None for v in xs_dot), f'`xs_dot` includes None.' + assert all(v is not None for v in ys_dot), f'`ys_dot` includes None.' + + if ys_bar is None: + ys_bar = [] + for y in ys_dot: + ys_bar.append(fill_const(1.0, shape=y.shape, dtype=y.dtype)) + self.add_vars(ys_bar) + else: + assert len(ys_dot) == len(ys_bar), ( + f'len(ys_dot) should be equal to len(ys_bar), ' + f'but len(ys_dot)={len(ys_dot)} and len(ys_bar)={len(ys_bar)}') + for y_dot, y_bar in zip(ys_dot, ys_bar): + assert y_dot.shape == y_bar.shape, ( + f'y_dot.shape should be equal to y_bar.shape, ' + f'but y_dot.shape={y_dot.shape} and y_bar.shape={y_bar.shape}' + ) + assert y_dot.dtype == y_bar.dtype, ( + f'y_dot.dtype should be equal to y_bar.dtype, ' + f'but y_dot.dtype={y_dot.dtype} and y_bar.dtype={y_bar.dtype}' + ) + + for dot, bar in zip(ys_dot, ys_bar): + self.dot2bar.add(dot, bar) + + # find all the relevant forward gradients + path, unused_xs_dot, _ = topo_path(xs_dot, ys_dot, self.block) + + # No need to track unused inputs + for dot in unused_xs_dot: + self.dot2bar.delete(dot) + + dotvars = output_vars_on_path(path) + dotvars.update((id(var), var) for var in xs_dot) + + is_dot = lambda v: id(v) in dotvars + + for op in reversed(path): + out = op_position_output(op) + out_bar_rec = self.dot2bar_rec(out) + ins_bar_rec = _transpose(op, is_dot, out_bar_rec) + + # TODO(Tongxin): this is hacky. Tuple implies the Transpose rule + # returns multiple entities. There should be better ways to handle + # outputs. 
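+            # E.g. add_transpose and mul_transpose return a pair
+            # (x_bar, y_bar), while unary rules such as reshape_transpose
+            # return a single variable, so the result is normalized to a
+            # list before being flattened.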
+ if isinstance(ins_bar_rec, tuple): + ins_bar_rec = list(ins_bar_rec) + else: + ins_bar_rec = [ins_bar_rec] + self.add_vars_rec(ins_bar_rec) + + ins_bar = flatten(ins_bar_rec) + ins = flatten(op_position_inputs(op)) + assert len(ins) == len(ins_bar), ( + f'len(ins) should be equal to len(ins_bar), ' + f'but len(ins)={len(ins)} and len(ins_bar)={len(ins_bar)}') + + for dot, bar in zip(ins, ins_bar): + if bar is not None: + # aggregate gradient + grad = self.dot2bar.lookup(dot) + if grad is None: + self.dot2bar.add(dot, bar) + else: + grad = add(grad, bar) + self.add_vars([grad]) + self.dot2bar.add(dot, grad) + + xs_bar = [self.dot2bar.lookup(x) for x in xs_dot] + + if not retain_fwd and len(path) > 0: + vars_to_remove = set() + for op in path: + vars_to_remove.update( + flatten_and_remove_none(get_output_var_list(op))) + + op_indexes = [] + + block = self.block + for i, op in enumerate(block.ops): + if op in path: + op_indexes.append(i) + path.pop(0) + if len(path) == 0: + break + + self.erase_ops(op_indexes) + self.erase_dots(vars_to_remove) + + return ys_bar, xs_bar + + +def _lower(block, reverse): + # Some functions which are only used in _lower. + def bind(args, to_bind, value_table): + for i in range(len(args)): + if isinstance(args[i], list): + bind(args[i], to_bind, value_table) + elif args[i] is not None and args[i].name in to_bind: + args[i] = value_table[to_bind[args[i].name]] + + def bind_name(names, to_bind): + return_list = [] + for name in names: + if isinstance(name, list): + return_list.append(bind_name(name, to_bind)) + else: + return_list.append(to_bind[name] if name in to_bind else name) + return return_list + + def expand_nested_list(xs): + return_list = [] + for x in xs: + if isinstance(x, list): + return_list = return_list + expand_nested_list(x) + else: + return_list.append(x) + return return_list + + # Step1: Do some preparatory work for lower + lower_fn = _prim2orig if reverse else _orig2prim + lookup_fn = lookup_prim2orig if reverse else lookup_orig2prim + if block is None: + program = default_main_program() + assert program.num_blocks == 1, "The lower transform is designed to process only one block." + block = program.current_block() + + value_table = {} + to_bind = {} + to_bind_rev = {} + for var in block.desc.all_vars(): + value_table[var.name()] = block.var(var.name()) + + ops_to_remove = [] + vars_to_remove = set() + + # Step2: Process all ops in the target block + for op_idx in range(len(block.ops)): + op = block.ops[op_idx] + ops_to_remove.append(op_idx) + if lookup_fn(op.type) is not None: + input_args = get_input_var_list(op) + bind(input_args, to_bind, value_table) + + for orig_out, new_out in zip( + expand_nested_list(get_output_var_list(op)), + expand_nested_list(to_tensors(lower_fn(op, *input_args)))): + assert not (orig_out is None) ^ ( + new_out is None), "orig_out and new_out should match." 
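+                    # Record the replacement: `to_bind` redirects later ops'
+                    # inputs to the new variables, `to_bind_rev` lets the
+                    # post-processing step rename the new outputs back to the
+                    # original variable names, and `vars_to_remove` collects
+                    # the temporary names to delete afterwards.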
+ vars_to_remove.add(new_out.name) + value_table[new_out.name] = new_out + to_bind[orig_out.name] = new_out.name + to_bind_rev[new_out.name] = orig_out.name + else: + inputs = {} + for i in range(len(op.input_names)): + inputs[op.input_names[i]] = bind_name( + op.input(op.input_names[i]), to_bind) + + outputs = {} + for i in range(len(op.output_names)): + outputs[op.output_names[i]] = op.output(op.output_names[i]) + + attrs = {} + for name in sorted(op.attr_names): + attrs[name] = op.attr(name) + from paddle.fluid.dygraph.base import param_guard + new_op_desc = block.desc.append_op() + with param_guard(inputs), param_guard(outputs): + op = Operator( + block=block, + desc=new_op_desc, + type=op.type, + inputs=inputs, + outputs=outputs, + attrs=attrs) + block.ops.append(op) + + # Step3: Do some post-processing work + for op_idx in reversed(ops_to_remove): + block.desc._remove_op(op_idx, op_idx + 1) + del block.ops[op_idx] + block._sync_with_cpp() + + for op_idx in range(len(block.ops)): + op = block.ops[op_idx] + for in_name in op.input_arg_names: + if in_name in to_bind_rev: + op._rename_input(in_name, to_bind_rev[in_name]) + + for out_name in op.output_arg_names: + if out_name in to_bind_rev: + op._rename_output(out_name, to_bind_rev[out_name]) + + for var_name in sorted(vars_to_remove): + assert var_name in to_bind_rev, 'var_name "{}" is not in to_bind_rev.'.format( + var_name) + if var_name != to_bind_rev[var_name]: + block.desc._remove_var(cpt.to_bytes(var_name)) + del block.vars[var_name] + block._sync_with_cpp() + + +@framework.static_only +def orig2prim(block=None): + """ + .. note:: + **This API is ONLY available in the static mode.** + + All operators in the target block are processed as follows. + If it is an original operator, it will be transformed into + one or a series of automatic differential basic operators with + equivalent function. + + Args: + block(paddle.fluid.framework.Variable|None, optional): The + target block to process on. Default None, and will + process on the current block of main program. + + Returns: + None + """ + _lower(block, reverse=False) + + +@framework.static_only +def prim2orig(block=None): + """ + .. note:: + **ONLY available in the static mode.** + + All operators in the target block are processed as follows. + If it is an automatic differential basic operator, it will be + transformed into one or a series of original operators with + equivalent function to support execution. + + Args: + block(paddle.static.Variable|None, optional): The + target block to process on. Default None, and will + process on the current block of main program. + + Examples: + + .. code-block:: python + + import paddle + from paddle.incubate.autograd import enable_prim, prim_enabled, prim2orig + + paddle.enable_static() + enable_prim() + + x = paddle.ones(shape=[2, 2], dtype='float32') + x.stop_gradients = False + y = x * x + dy_dx = paddle.static.gradients(y, x) + if prim_enabled(): + prim2orig() + """ + _lower(block, reverse=True) + + +def _gradients(ys, xs, ys_bar=None): + """ A drop-in replacement of paddle.gradients but instead computing + on primitive ops. 
+ + +def _gradients(ys, xs, ys_bar=None): + """ A drop-in replacement for paddle.gradients that computes + gradients on primitive ops. + + Args: + ys: the target tensor or tensors + xs: the input tensor or tensors + ys_bar: the optional gradient tensors of `ys` + + Returns: + xs_bar: a list of gradients of the input `xs` + """ + + ys, xs = to_tensors(ys), to_tensors(xs) + block = ys[0].block + # TODO(Tongxin): without any prior knowledge about whether the program + # is completely lowered to primitive ops, it's mandatory to run the lowering + # pass on every call. This is obviously inefficient and needs to be + # optimized. + orig2prim(block) + + ad = Transform(block) + + xs_dot, ys_dot = ad.linearize(xs, ys) + if any(var is None for var in ys_dot): + assert False, 'Gradients cannot be computed. The given output `ys` does not depend on the input `xs`.' + ys_bar, xs_bar = ad.transpose(ys_dot, xs_dot, ys_bar) + # remove xs_dot and their constructor ops + + op_indexes = [] + for var in xs_dot: + if var is not None: + op_index = block.ops.index(var.op) + assert op_index >= 0, f'op_index should be greater than or equal to 0, but op_index={op_index}.' + op_indexes.append(op_index) + + ad.erase_ops(sorted(op_indexes)) + ad.erase_dots(xs_dot) + + return xs_bar diff --git a/python/paddle/incubate/autograd/utils.py b/python/paddle/incubate/autograd/utils.py new file mode 100644 index 0000000000..ec4f0915ba --- /dev/null +++ b/python/paddle/incubate/autograd/utils.py @@ -0,0 +1,178 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.fluid import framework as framework + + +class PrimOption(object): + def __init__(self): + self.enable_prim = False + + def get_status(self): + return self.enable_prim + + def set_status(self, flag): + self.enable_prim = flag + + +prim_option = PrimOption() + + +@framework.static_only +def prim_enabled(): + """ + .. note:: + **ONLY available in the static mode.** + + Shows whether the automatic differentiation mechanism based on + automatic differential basic operators is ON. Defaults to OFF. + + Returns: + flag(bool): Whether the automatic differentiation mechanism based on automatic differential basic operators is ON. + + Examples: + + .. code-block:: python + + import paddle + from paddle.incubate.autograd import enable_prim, disable_prim, prim_enabled + + paddle.enable_static() + enable_prim() + + print(prim_enabled()) # True + + disable_prim() + + print(prim_enabled()) # False + """ + return prim_option.get_status() + + +@framework.static_only +def enable_prim(): + """ + .. note:: + **ONLY available in the static mode.** + + Turns ON the automatic differentiation mechanism based on automatic + differential basic operators. + + Examples: + + .. code-block:: python + + import paddle + from paddle.incubate.autograd import enable_prim, prim_enabled + + paddle.enable_static() + enable_prim() + + print(prim_enabled()) # True + """ + prim_option.set_status(True)
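Putting the switch together with the lowering APIs: the sketch below extends the docstring example above with executor plumbing so the lowered gradient can actually be fetched. It is a minimal sketch only; the executor/fetch part is an assumption about typical static-graph usage, not something this patch prescribes.

.. code-block:: python

    import paddle
    from paddle.incubate.autograd import enable_prim, prim_enabled, prim2orig

    paddle.enable_static()
    enable_prim()

    x = paddle.ones(shape=[2, 2], dtype='float32')
    x.stop_gradient = False
    y = x * x
    dy_dx = paddle.static.gradients(y, x)
    if prim_enabled():
        prim2orig()  # lower automatic differential basic ops back to executable ops

    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(paddle.static.default_startup_program())
    grad, = exe.run(fetch_list=dy_dx)
    print(grad)  # d(x*x)/dx == 2*x, i.e. all 2.0 since x is all ones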
+ + +@framework.static_only +def disable_prim(): + """ + .. note:: + **ONLY available in the static mode.** + + Turns OFF the automatic differentiation mechanism based on automatic + differential basic operators. + + Examples: + + .. code-block:: python + + import paddle + from paddle.incubate.autograd import enable_prim, disable_prim, prim_enabled + + paddle.enable_static() + enable_prim() + + print(prim_enabled()) # True + + disable_prim() + + print(prim_enabled()) # False + """ + prim_option.set_status(False) + + +INT_DTYPE_2_STRING = { + int(0): 'bool', + int(1): 'int16', + int(2): 'int32', + int(3): 'int64', + int(4): 'float16', + int(5): 'float32', + int(6): 'float64', + int(20): 'uint8', + int(21): 'int8', + int(23): 'complex64', + int(24): 'complex128', +} + + +def get_var_block(block, names): + assert isinstance(names, list) + if len(names) == 0: + return None + elif len(names) == 1: + return block.var(names[0]) + else: + return [block.var(name) for name in names] + + +def get_input_var_list(op): + if op.input_names is None: + return [] + else: + return [ + get_var_block(op.block, op.input(n)) for n in sorted(op.input_names) + ] + + +def get_output_var_list(op): + if op.output_names is None: + return [] + else: + return [ + get_var_block(op.block, op.output(n)) + for n in sorted(op.output_names) + ] + + +def to_tensors(xs): + if isinstance(xs, paddle.fluid.framework.Variable): + return [xs] + else: + return xs + + +def flatten(inp): + if inp is None or isinstance(inp, paddle.fluid.framework.Variable): + return [inp] + flattened = [] + for part in inp: + flattened += flatten(part) + return flattened + + +def flatten_and_remove_none(inp): + flattened = flatten(inp) + return [var for var in flattened if var is not None] diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index 0dfe294c00..9dfec3947e 100644 --- a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -47,6 +47,45 @@ from paddle.fluid.framework import _in_legacy_dygraph, _in_eager_without_dygraph __all__ = [] +@framework.static_only +def append_backward_new(loss_list, + parameter_list=None, + no_grad_set=None, + callbacks=None, + checkpoints=None, + distop_context=None): + from paddle.incubate.autograd.primx import orig2prim, Transform + program = default_main_program() + assert program.num_blocks == 1, "The append_backward_new interface is designed to process only one block." + block = program.current_block() + + orig2prim(block) + ad = Transform(block) + if parameter_list is None: + parameter_list = program.global_block().all_parameters() + param_dot, loss_dot = ad.linearize(parameter_list, loss_list) + loss_bar, param_bar = ad.transpose(loss_dot, param_dot) + + # remove param_dot and their constructor ops + op_indexes = [] + for var in param_dot: + if var is not None: + op_index = block.ops.index(var.op) + assert op_index >= 0 + op_indexes.append(op_index) + + ad.erase_ops(sorted(op_indexes)) + ad.erase_dots(param_dot) + + if len(parameter_list) == 1: + params_and_grads = [(parameter_list, param_bar)] + else: + params_and_grads = [] + for i, param in enumerate(parameter_list): + params_and_grads.append((param, param_bar[i])) + return params_and_grads + + class Optimizer(object): r"""Optimizer Base class.
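The list conventions of these utilities drive the lowering and transpose code above: `get_var_block` returns `None` for no names, a single variable for one name, and a list for several, while `flatten` and `flatten_and_remove_none` normalize any of those forms. Below is a stand-alone re-implementation for illustration only; the `Var` class is a stand-in for a Paddle variable so the snippet runs on its own.

.. code-block:: python

    class Var:
        # Minimal stand-in for paddle.fluid.framework.Variable.
        def __init__(self, name):
            self.name = name

    def flatten(inp):
        # None and single variables become one-element lists;
        # nested lists are flattened recursively.
        if inp is None or isinstance(inp, Var):
            return [inp]
        out = []
        for part in inp:
            out += flatten(part)
        return out

    def flatten_and_remove_none(inp):
        return [v for v in flatten(inp) if v is not None]

    a, b = Var('a'), Var('b')
    print([v.name for v in flatten_and_remove_none([a, [None, [b]]])])  # ['a', 'b']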
@@ -880,8 +919,13 @@ class Optimizer(object): parameter_list = parameters if parameters \ else self._parameter_list with program_guard(program, startup_program): - params_grads = append_backward(loss, parameter_list, - act_no_grad_set, callbacks) + from paddle.incubate.autograd.utils import prim_enabled + if prim_enabled(): + params_grads = append_backward_new( + [loss], parameter_list, act_no_grad_set, callbacks) + else: + params_grads = append_backward(loss, parameter_list, + act_no_grad_set, callbacks) # Note: since we can't use all_reduce_op now, # dgc_op should be the last op of one grad. self._append_dgc_ops(params_grads) diff --git a/python/setup.py.in b/python/setup.py.in index 4cf8bc3fc6..c1a6e3d394 100755 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -368,6 +368,7 @@ packages=['paddle', 'paddle.incubate.nn.functional', 'paddle.incubate.nn.layer', 'paddle.incubate.optimizer.functional', + 'paddle.incubate.autograd', 'paddle.incubate.distributed', 'paddle.incubate.distributed.models', 'paddle.incubate.distributed.models.moe', -- GitLab
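Finally, a hedged end-to-end sketch of the optimizer path this patch wires up: with the switch ON, `Optimizer.minimize` routes backward construction through `append_backward_new`, and `prim2orig` lowers the block before execution. The network, shapes, and choice of SGD below are illustrative assumptions, chosen to use only ops (matmul, tanh, p_norm with p=1) that the patch registers lowering rules for.

.. code-block:: python

    import numpy as np
    import paddle
    from paddle.incubate.autograd import enable_prim, prim_enabled, prim2orig

    paddle.enable_static()
    enable_prim()

    main, startup = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        x = paddle.static.data('x', shape=[2, 4], dtype='float32')
        w = paddle.static.create_parameter(shape=[4, 4], dtype='float32')
        # illustrative loss; assumes orig2prim rules cover matmul, tanh, p_norm(p=1)
        loss = paddle.norm(paddle.tanh(paddle.matmul(x, w)), p=1)
        opt = paddle.optimizer.SGD(learning_rate=0.1)
        opt.minimize(loss)            # uses append_backward_new since prim_enabled() is True
        if prim_enabled():
            prim2orig(main.block(0))  # lower primitive ops before execution

    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(startup)
    exe.run(main, feed={'x': np.random.rand(2, 4).astype('float32')}, fetch_list=[loss])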