未验证 提交 f6ee202f 编写于 作者: W WangZhen 提交者: GitHub

Add support for forward and reverse high-order automatic differentiation mechanism (#41919)

* Updated triple_grad_check func

* add todo for gradient checker and refine some comments

* remove additional code

* add test for warnging in backward.py

* format python code

* support multi input in triple gradient checker

* Add matmul triple grad kernel

* Updated comments of TODO

* Supported some special tests

* Change code-format to follow CI std

* Updated gradient_checker.py

* Fix conflicts

* Removed unnecessary printing log

* Change code style to follow CI std

* merge upstream

* add priops.py

* add_p

* rm useless files

* add sub_p mul_p div_p

* add sqrt_p and tanh_p

* add reshape_p

* add broadcast_p

* Add python primitive wrappers.

* Jvp rules updated.

* JVP rules done for all the 17 primops.

* quick check and fixes.

* add jvp(op, *args)

* add broadcast_p fill_constant_p matmul_p reduce_p reshape_p transpose_p

* add split_p and concat_p

* add gather_p and scatter_add_p

* add slice_select_p and slice_assign_p

* Add transpose rules.

* add multi input check for add_p, sub_p, mul_p, div_p

* update concat_p

* Linearize and transpose in progress..

* refine gather_p and scatter_add_p

* updated.

* update transpose.

* refine slice_assign_p and slice_select_p

* init commit for lower

* Merged with primitive ops.

* small update

* add rules for orig2prim and prim2orig

* add 9 test for prim ops

* add more test and fix some bug

* add more test

* register proto

* Adding primops test.

* add shape valid check for broadcast_p op, and add keepdim attr into reduce_p op proto

* support multi input and multi output for split_p and concat_p

* Test updated.

* update

* fix slice bug for slice_select_p and slice_assign_p

* updated.

* Ops updated.

* Refactor and bug fixes.

* updated.

* finish orig2prim and prim2orig rules

* dtype for axis attr should be long int

* update dtype for axis attr int64_t

* update for iscan CI

* Update primx.

* Refactor vars in primx.

* update for lower transform

* add more shape and dtype check

* update primx.py

* change IndexTensor into int32 dtype

* update

* Fix linearize and transpose.

* Update is_dot

* Update is_dot

* Update is_dot

* add gradient aggregation, fix add_transpose.

* pass first linearize+transpose test.

* update test

* refactor op registration and primx.

* update rule for slice_assign

* try test lower

* update orig2prim and prim2orig

* pass simple lower pass

* update

* Update input types in the unit test.

* orig2prim segfault.

* 50% for adam.minimize

* test updated.

* temp fix erros in removing vars.

* primx updated.

* update for matmul_v2 and reshape2 orig2prim

* update for minimize

* Refine primrules

* Remove some code

* supporting unused and unreachable vars.

* update for use prim2orig in minimize

* fix gather and scatter_add transpose.

* Add rules UT

* update scatter_add

* Refine UT code

* fix nonetype check in topo

* Update gather_p pywrapper.

* remove useless print

* Merge tongxin PR and refine code

* readd some test

* rm useless print

* polish code.

* fix bug in minimize

* add get_input_var_list and get_output_var_list and use it in lower

* Fix scatter_add_p prim2orig

* Update code and fix orig2prim/prim2orig UT

* delete vars after block.desc._remove

* Improve ops and vars clean up logics.

* fix some bug in linearize and lower

* update tanh transpose.

* use set instead of list for var2remove

* test updated.

* polish code.

* fix dot2bar delete.

* merge tx/ad

* add indextensor_dot for gather and scatter_add

* add sorted for set

* Fix scale_orig2prim params

* fix some syntax bug

* add golbal_lower_update list

* Better handling of unused vars.

* update tests.

* Fix elementwise_sub orig2prim

* support none for transpose rule

* Merge and add transform UT

* fix a bug in transpose

* Fix transpose and UT

* a hacky fix for cancat op

* Fix exector place

* Refine variable name

* Add elementwise_mul orig2prim and support p_norm when p=1

* Add sqrt orig2prim rule and UT

* merge wz test

* rename files, add enable_prim, disable_prim, prim_enabled, delete global_lower_update

* fix a bug in test_ad_transform_trans

* revert modify in framework.py

* add paddle.fluid.incubate.ad_transform to  python/setup.py.in

* Fix remove vars error

* Fix p_norm_orig2prim

* merge wz

* Modify the code directory

* Add utils.py and remove get_input/output_vars functions

* Update maolin code

* Rename UT and refine test_ad_transform_primops

* Fix div_p jvp rule

* Add higher derivatives UT

* Remove UT to autograd dir

* Fix comments

* import paddle in primops.py

* Add some error message for assert

* Refine UT class name and refine some comments in primreg.py

* update minimize of paddle/optimizer for supporting new autograd

* resolve cicular importing between backward.py and optimizer.py

* fill gradients and minimize unittest

* Replace `assert isinstance` with `raise TypeError`

* Add some assert message for primx.py

* Polish variable name

* Add some assert message

* add some docstring

* refine some name

* update the format of english documents

* Split test_transform.py to two files to avoid ci error

* fix the document format of enable_prim/disable_prim/prim2orig/prim_enabled

* polish test_gradients_and_minimize

* add default value for prim_enabled api doc

* Remove some UT to avoid windows ci error

* Enlarge test_gradients_and_minimize limit time

* Fix ut limit time
Co-authored-by: Nveyron95 <veyron_wu@163.com>
Co-authored-by: NJiabin Yang <360788950@qq.com>
Co-authored-by: Nlevi131 <limaolin01@baidu.com>
Co-authored-by: NTongxin Bai <waffle.bai@gmail.com>
Co-authored-by: NXiaoxu Chen <chenxx_id@163.com>
Co-authored-by: Nlevi131 <83750468+levi131@users.noreply.github.com>
上级 b9342a80
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
class Registry(object):
""" A general registry object. """
__slots__ = ['name', 'tab']
def __init__(self, name):
self.name = name
self.tab = {}
def register(self, name, value):
assert name not in self.tab
self.tab[name] = value
def lookup(self, name):
assert name in self.tab, f'No registry entry is found with name: {name}'
return self.tab[name]
_primop_fn = Registry('primop_fn')
_orig2prim = Registry('orig2prim')
_prim2orig = Registry('prim2orig')
_primop_jvp = Registry('primop_jvp')
_primop_transpose = Registry('primop_transpose')
_primop_position_argnames = Registry('primop_position_argnames')
def REGISTER_FN(op_type, *position_argnames):
"""Decorator for registering the Python function for a primitive op."""
assert isinstance(op_type, str)
_primop_position_argnames.register(op_type, position_argnames)
def wrapper(f):
_primop_fn.register(op_type, f)
return f
return wrapper
......@@ -32,6 +32,7 @@ try:
from collections.abc import Sequence
except:
from collections import Sequence
__all__ = [
'append_backward',
'gradients',
......@@ -2113,6 +2114,11 @@ def gradients(targets, inputs, target_gradients=None, no_grad_set=None):
check_type(target_gradients, 'target_gradients', (
framework.Variable, list, tuple, type(None)), 'paddle.static.gradients')
from ..incubate.autograd.primx import _gradients
from ..incubate.autograd.utils import prim_enabled
if prim_enabled():
return _gradients(targets, inputs, target_gradients)
outs = calc_gradient(targets, inputs, target_gradients, no_grad_set)
return _as_list(outs)
......
......@@ -8,3 +8,4 @@ endforeach(TEST_OP)
set_tests_properties(test_autograd_functional_dynamic PROPERTIES TIMEOUT 160)
set_tests_properties(test_autograd_functional_static PROPERTIES TIMEOUT 160)
set_tests_properties(test_gradients_and_minimize PROPERTIES TIMEOUT 60)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.incubate.autograd.primx import prim2orig
from paddle.incubate.autograd.utils import enable_prim, disable_prim, prim_enabled
paddle.enable_static()
class TestGradients(unittest.TestCase):
def test_third_order(self):
enable_prim()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
x = paddle.static.data(name='x', shape=[1], dtype='float32')
x2 = paddle.multiply(x, x)
x3 = paddle.multiply(x2, x)
x4 = paddle.multiply(x3, x)
grad1, = paddle.static.gradients([x4], [x])
grad2, = paddle.static.gradients([grad1], [x])
grad3, = paddle.static.gradients([grad2], [x])
prim2orig(main.block(0))
feed = {x.name: np.array([2.]).astype('float32')}
fetch_list = [grad3.name]
result = [np.array([48.])]
place = paddle.CPUPlace()
if paddle.device.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup)
outs = exe.run(main, feed=feed, fetch_list=fetch_list)
np.allclose(outs, result)
disable_prim()
def test_fourth_order(self):
enable_prim()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
x = paddle.static.data(name='x', shape=[1], dtype='float32')
x2 = paddle.multiply(x, x)
x3 = paddle.multiply(x2, x)
x4 = paddle.multiply(x3, x)
x5 = paddle.multiply(x4, x)
out = paddle.sqrt(x5 + x4)
grad1, = paddle.static.gradients([out], [x])
grad2, = paddle.static.gradients([grad1], [x])
grad3, = paddle.static.gradients([grad2], [x])
grad4, = paddle.static.gradients([grad3], [x])
prim2orig(main.block(0))
feed = {x.name: np.array([2.]).astype('float32'), }
fetch_list = [grad4.name]
# (3*(-5*x^2-16*x-16))/(16*(x+1)^3.5)
result = [np.array([-0.27263762711])]
place = paddle.CPUPlace()
if paddle.device.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup)
outs = exe.run(main, feed=feed, fetch_list=fetch_list)
np.allclose(outs, result)
disable_prim()
class TestMinimize(unittest.TestCase):
def model(self, x, w, bias, opt):
paddle.seed(0)
place = paddle.CPUPlace()
if paddle.device.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
input_x = paddle.static.data('x', x.shape, dtype=x.dtype)
input_x.stop_gradient = False
params_w = paddle.static.create_parameter(
shape=w.shape, dtype=w.dtype, is_bias=False)
params_bias = paddle.static.create_parameter(
shape=bias.shape, dtype=bias.dtype, is_bias=True)
y = paddle.tanh(paddle.matmul(input_x, params_w) + params_bias)
loss = paddle.norm(y, p=2)
opt = opt
_, grads = opt.minimize(loss)
if prim_enabled():
prim2orig(main.block(0))
exe.run(startup)
grads = exe.run(main,
feed={'x': x,
'w': w,
'bias': bias},
fetch_list=grads)
return grads
def test_adam(self):
x = np.random.rand(2, 20)
w = np.random.rand(20, 2)
bias = np.random.rand(2)
enable_prim()
prim_grads = self.model(x, w, bias, paddle.optimizer.Adam(0.01))
disable_prim()
orig_grads = self.model(x, w, bias, paddle.optimizer.Adam(0.01))
for orig, prim in zip(orig_grads, prim_grads):
np.testing.assert_allclose(orig, prim)
def test_sgd(self):
x = np.random.rand(2, 20)
w = np.random.rand(20, 2)
bias = np.random.rand(2)
enable_prim()
prim_grads = self.model(x, w, bias, paddle.optimizer.SGD(0.01))
disable_prim()
orig_grads = self.model(x, w, bias, paddle.optimizer.SGD(0.01))
for orig, prim in zip(orig_grads, prim_grads):
np.testing.assert_allclose(orig, prim)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers.utils import flatten
from paddle.incubate.autograd.primrules import _orig2prim, _prim2orig, _jvp, _transpose
paddle.enable_static()
############################ Test orig2prim rules ############################
class TestElementWiseAddOrig2Prim(unittest.TestCase):
def setUp(self):
self.main_program = paddle.static.Program()
self.startup_program = paddle.static.Program()
self.layer_help = LayerHelper('TestOrig2Prim')
with paddle.static.program_guard(self.main_program,
self.startup_program):
self.init_data()
def init_data(self):
self.op_type = 'elementwise_add'
X = paddle.static.data(name='X', shape=[2, 2], dtype='float')
Y = paddle.static.data(name='Y', shape=[2, 2], dtype='float')
self.input = {'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, Y)
self.all_ops = ['elementwise_add', 'add_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
def test_op(self):
with paddle.static.program_guard(self.main_program,
self.startup_program):
op = self.layer_help.append_op(
type=self.op_type,
inputs=self.input,
outputs=self.output,
attrs=self.attrs)
prim_out = _orig2prim(op, *self.orig2prim_args)
all_ops = [op.type for op in self.main_program.block(0).ops]
self.assertEqual(sorted(all_ops), sorted(self.all_ops))
prim_out = flatten(prim_out)
for k, v in self.out_map.items():
self.assertEqual(prim_out[k].shape, v.shape)
class TestSqrtOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'sqrt'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
self.input = {'X': X, }
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, )
self.all_ops = ['sqrt', 'sqrt_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestElementWiseMulOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'elementwise_mul'
X = paddle.static.data(name='X', shape=[8, 8], dtype='float')
Y = paddle.static.data(name='Y', shape=[8, 8], dtype='float')
self.input = {'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, Y)
self.all_ops = ['elementwise_mul', 'mul_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestMatmulV2Orig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'matmul_v2'
X = paddle.static.data(name='X', shape=[3, 4], dtype='float')
Y = paddle.static.data(name='Y', shape=[4, 3], dtype='float')
self.input = {'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'trans_x': True, 'trans_y': True}
self.orig2prim_args = (X, Y)
self.all_ops = ['matmul_v2', 'transpose_p', 'transpose_p', 'matmul_p']
self.out_map = {0: self.output['Out']}
class TestTanhOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'tanh'
X = paddle.static.data(name='X', shape=[3, 4], dtype='float')
self.input = {'X': X, }
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, )
self.all_ops = ['tanh', 'tanh_p']
self.out_map = {0: self.output['Out']}
class TestReshape2Orig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'reshape2'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
self.input = {'X': X, }
self.output = {
'Out': X,
'XShape':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'shape': [6, 5]}
self.orig2prim_args = (
None,
None,
X, )
self.all_ops = ['reshape2', 'reshape_p', 'fill_constant_p']
# Do not checke XShape
self.out_map = {0: self.output['Out']}
class TestConcatOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'concat'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
Y = paddle.static.data(name='Y', shape=[3, 6], dtype='int64')
self.input = {'X': [X, Y], }
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'axis': 0}
self.orig2prim_args = (
None,
(X, Y), )
self.all_ops = ['concat', 'concat_p']
self.out_map = {0: self.output['Out']}
class TestSliceOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'slice'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
self.input = {'Input': X, }
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {
'axes': [0],
'starts': [1],
'ends': [4],
}
self.orig2prim_args = (None, None, X, None, None)
self.all_ops = ['slice', 'slice_select_p']
self.out_map = {0: self.output['Out']}
class TestFillZerosLikeOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'fill_zeros_like'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
self.input = {'X': X, }
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, )
self.all_ops = ['fill_zeros_like', 'fill_constant_p']
self.out_map = {0: self.output['Out']}
class TestSumOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'sum'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
Y = paddle.static.data(name='Y', shape=[5, 6], dtype='int64')
self.input = {'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = ((X, Y), )
self.all_ops = ['sum', 'add_p']
self.out_map = {0: self.output['Out']}
class TestPNormOrig2Prim1(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'p_norm'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
self.input = {'X': X, }
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {
'porder': 1,
'asvector': True,
}
self.orig2prim_args = (X, )
self.all_ops = ['p_norm', 'reshape_p', 'sqrt_p', 'reduce_p', 'mul_p']
self.out_map = {0: self.output['Out']}
class TestPNormOrig2Prim2(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'p_norm'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
self.input = {'X': X, }
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {
'porder': 2,
'asvector': True,
}
self.orig2prim_args = (X, )
self.all_ops = ['p_norm', 'reshape_p', 'sqrt_p', 'reduce_p', 'mul_p']
self.out_map = {0: self.output['Out']}
class TestIndexSelectOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'index_select'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
Index = paddle.static.data(name='Index', shape=[2], dtype='int32')
self.input = {'X': X, 'Index': Index}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'dim': 0, }
self.orig2prim_args = (
Index,
X, )
self.all_ops = ['index_select', 'gather_p']
self.out_map = {0: self.output['Out']}
class TestElementwiseSubOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'elementwise_sub'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int32')
Y = paddle.static.data(name='Y', shape=[6], dtype='int32')
self.input = {'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'dim': 0, }
self.orig2prim_args = (
X,
Y, )
self.all_ops = ['elementwise_sub', 'broadcast_p', 'sub_p']
self.out_map = {0: self.output['Out']}
class TestScaleOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'scale'
X = paddle.static.data(name='X', shape=[10, 7], dtype='int32')
self.input = {'X': X, }
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'scale': 2.0, 'bias': 1.0, 'bias_after_scale': True}
self.orig2prim_args = (
None,
X, )
self.all_ops = [
'scale', 'fill_constant_p', 'fill_constant_p', 'mul_p', 'add_p'
]
self.out_map = {0: self.output['Out']}
class TestAssignOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'assign'
X = paddle.static.data(name='X', shape=[10, 7], dtype='int32')
self.input = {'X': X, }
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, )
self.all_ops = ['assign', 'fill_constant_p', 'add_p']
self.out_map = {0: self.output['Out']}
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers.utils import flatten
from paddle.incubate.autograd.primrules import _orig2prim, _prim2orig, _jvp, _transpose
paddle.enable_static()
############################ Test prim2orig rules ############################
class TestAddPPrim2Orig(unittest.TestCase):
def setUp(self):
self.main_program = paddle.static.Program()
self.startup_program = paddle.static.Program()
self.layer_help = LayerHelper('TestPrim2Orig')
with paddle.static.program_guard(self.main_program,
self.startup_program):
self.init_data()
def init_data(self):
self.op_type = 'add_p'
X = paddle.static.data(name='X', shape=[2, 2], dtype='float')
Y = paddle.static.data(name='Y', shape=[2, 2], dtype='float')
self.input = {'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (X, Y)
self.all_ops = ['add_p', 'elementwise_add']
# { prim_op_output_var: orign_op_out_index }
self.out_map = {self.output['Z']: 0}
def test_op(self):
with paddle.static.program_guard(self.main_program,
self.startup_program):
op = self.layer_help.append_op(
type=self.op_type,
inputs=self.input,
outputs=self.output,
attrs=self.attrs)
orig_out = _prim2orig(op, *self.prim2orig_args)
all_ops = [op.type for op in self.main_program.block(0).ops]
self.assertEqual(sorted(all_ops), sorted(self.all_ops))
orig_out = flatten(orig_out)
for k, v in self.out_map.items():
self.assertEqual(k.shape, orig_out[v].shape)
class TestSubPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'sub_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')
self.input = {'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (X, Y)
self.all_ops = ['sub_p', 'elementwise_sub']
self.out_map = {self.output['Z']: 0}
class TestMulPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'mul_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')
self.input = {'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (X, Y)
self.all_ops = ['mul_p', 'elementwise_mul']
self.out_map = {self.output['Z']: 0}
class TestDivPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'div_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')
self.input = {'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (X, Y)
self.all_ops = ['div_p', 'elementwise_div']
self.out_map = {self.output['Z']: 0}
class TestSqrtPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'sqrt_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
self.input = {'X': X, }
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (X, )
self.all_ops = ['sqrt_p', 'sqrt']
self.out_map = {self.output['Y']: 0}
class TestTanhPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'tanh_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
self.input = {'X': X, }
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (X, )
self.all_ops = ['tanh_p', 'tanh']
self.out_map = {self.output['Y']: 0}
class TestReshapePPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'reshape_p'
X = paddle.static.data(name='X', shape=[2, 8], dtype='float64')
self.input = {'X': X, }
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'shape': [4, 4]}
self.prim2orig_args = (X, )
self.all_ops = ['reshape_p', 'reshape2']
self.out_map = {self.output['Y']: 0}
class TestBroadcastPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'broadcast_p'
X = paddle.static.data(name='X', shape=[2, 8], dtype='float64')
self.input = {'X': X, }
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'shape': [10, 2, 8]}
self.prim2orig_args = (X, )
self.all_ops = ['broadcast_p', 'expand_v2']
self.out_map = {self.output['Y']: 0}
class TestTransposePPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'transpose_p'
X = paddle.static.data(name='X', shape=[7, 8, 9, 10], dtype='float64')
self.input = {'X': X, }
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'axis': [1, 2, 0, 3]}
self.prim2orig_args = (X, )
self.all_ops = ['transpose_p', 'transpose2']
self.out_map = {self.output['Y']: 0}
class TestSplitPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'split_p'
X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64')
self.input = {'X': X, }
self.output = {
'YS': [
self.layer_help.create_variable_for_type_inference(
dtype=X.dtype) for i in range(3)
]
}
self.attrs = {'num_or_sections': [2, 3, 4], 'axis': 1}
self.prim2orig_args = (X, )
self.all_ops = ['split_p', 'split']
self.out_map = {
self.output['YS'][0]: 0,
self.output['YS'][1]: 1,
self.output['YS'][2]: 2,
}
class TestConcatPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'concat_p'
X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64')
Y = paddle.static.data(name='Y', shape=[2, 9, 5], dtype='float64')
Z = paddle.static.data(name='Z', shape=[1, 9, 5], dtype='float64')
self.input = {'XS': [X, Y, Z], }
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'axis': 0}
self.prim2orig_args = ((X, Y, Z), )
self.all_ops = ['concat_p', 'concat']
self.out_map = {self.output['Y']: 0}
class TestReducePPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'reduce_p'
X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64')
self.input = {'X': X}
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'axis': [1], 'keepdim': True}
self.prim2orig_args = (X, )
self.all_ops = ['reduce_p', 'reduce_sum']
self.out_map = {self.output['Y']: 0}
class TestMatmulPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'matmul_p'
X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
Y = paddle.static.data(name='Y', shape=[5, 9], dtype='float64')
self.input = {'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (X, Y)
self.all_ops = ['matmul_p', 'matmul_v2']
self.out_map = {self.output['Z']: 0}
class TestSliceSelectPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'slice_select_p'
X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
self.input = {'X': X, }
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'axis': [0], 'starts': [1], 'ends': [8], 'strides': [2]}
self.prim2orig_args = (X, )
self.all_ops = ['slice_select_p', 'strided_slice']
self.out_map = {self.output['Y']: 0}
class TestSliceAssignPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'slice_assign_p'
X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
Y = paddle.static.data(name='Y', shape=[9, 3], dtype='float64')
self.input = {'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'axis': [1], 'starts': [0], 'ends': [3], 'strides': [1]}
self.prim2orig_args = (X, Y)
self.all_ops = ['slice_assign_p', 'assign', 'set_value']
self.out_map = {self.output['Z']: 0}
class TestGatherPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'gather_p'
X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
IndexTensor = paddle.static.data(
name='IndexTensor', shape=[3], dtype='int32')
self.input = {'X': X, 'IndexTensor': IndexTensor}
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'axis': 0, }
self.prim2orig_args = (
IndexTensor,
X, )
self.all_ops = ['gather_p', 'gather']
self.out_map = {self.output['Y']: 0}
class TestScatterAddPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'scatter_add_p'
X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
Y = paddle.static.data(name='Y', shape=[3, 5], dtype='float64')
IndexTensor = paddle.static.data(
name='IndexTensor', shape=[3], dtype='int32')
self.input = {'X': X, 'Y': Y, 'IndexTensor': IndexTensor}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'axis': 0, }
self.prim2orig_args = (IndexTensor, X, Y)
self.all_ops = [
'scatter_add_p', 'fill_any_like', 'scatter', 'elementwise_add'
]
self.out_map = {self.output['Z']: 0}
class TestFillConstantPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'fill_constant_p'
self.input = {}
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(paddle.int32)
}
self.attrs = {'value': 10, 'shape': [5, 5], 'dtype': paddle.int32}
self.prim2orig_args = ()
self.all_ops = ['fill_constant_p', 'fill_constant']
self.out_map = {self.output['Y']: 0}
if __name__ == '__main__':
unittest.main()
......@@ -14,12 +14,13 @@
import unittest
import numpy as np
import paddle
from paddle.autograd.primops import (
from paddle.incubate.autograd.primops import (
neg, set_value, add, sub, mul, div, sqrt, tanh, reshape, broadcast,
transpose, split, concat, reduce, matmul, slice_select, slice_assign,
gather, scatter_add, fill_const)
from paddle.incubate.autograd.primx import Transform, topo_path, orig2prim, prim2orig, _gradients
from paddle.incubate.autograd.utils import enable_prim, disable_prim, prim_enabled
class TestPyPrimOps(unittest.TestCase):
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.incubate.autograd.primx import Transform, orig2prim, prim2orig
from paddle.fluid.layers.utils import flatten
paddle.enable_static()
class TestAutoGradTransformForAdd(unittest.TestCase):
def setUp(self):
self.main_program = paddle.static.Program()
self.startup_program = paddle.static.Program()
with paddle.static.program_guard(self.main_program,
self.startup_program):
self.init_data()
def init_data(self):
# { input_index: input_shape }
self.xs_shape_map = {0: (20, 40), 1: (20, 40)}
# { output_index: output_shape }
self.ys_shape_map = {0: (20, 40)}
X0 = paddle.static.data(
name='X0', shape=self.xs_shape_map[0], dtype='float32')
X0.stop_gradient = False
X1 = paddle.static.data(
name='X1', shape=self.xs_shape_map[1], dtype='float32')
X1.stop_gradient = False
A = paddle.tanh(X0)
B = paddle.tanh(X1)
Y = paddle.add(A, B)
self.orig_xs = [X0, X1]
self.orig_ys = [Y, ]
self.orig_ops = ['tanh', 'tanh', 'elementwise_add']
self.orig2prim_ops = ['tanh_p', 'tanh_p', 'add_p']
self.linearize_ops = self.orig2prim_ops + [
# call fill_const() in linearize() function
'fill_constant_p',
'fill_constant_p',
# linearized op
'mul_p',
'sub_p',
'fill_constant_p',
'mul_p',
'mul_p',
'sub_p',
'fill_constant_p',
'mul_p',
'add_p',
]
self.transpose_ops = self.orig2prim_ops + [
# call fill_const() in transpose() function
'fill_constant_p',
# linearized op after remove path
'fill_constant_p',
'fill_constant_p',
'mul_p',
'sub_p',
'fill_constant_p',
'mul_p',
'sub_p',
'fill_constant_p',
# transposed op
'mul_p',
'mul_p'
]
self.prim2orig_ops = [
'tanh', 'tanh', 'elementwise_add', 'fill_constant', 'fill_constant',
'fill_constant', 'elementwise_mul', 'elementwise_sub',
'fill_constant', 'elementwise_mul', 'elementwise_sub',
'fill_constant', 'elementwise_mul', 'elementwise_mul'
]
def test_run(self):
# Must using with program_guard(), otherwise prim ops will append other block
with paddle.static.program_guard(self.main_program,
self.startup_program):
ad = Transform(self.main_program.block(0))
orig_ops = [op.type for op in self.main_program.block(0).ops]
self.assertEqual(sorted(orig_ops), sorted(self.orig_ops))
# Test orig2prim
orig2prim(block=self.main_program.block(0))
orig2prim_ops = [op.type for op in self.main_program.block(0).ops]
self.assertEqual(sorted(orig2prim_ops), sorted(self.orig2prim_ops))
# Test linearize
xs_dot, ys_dot = ad.linearize(self.orig_xs, self.orig_ys)
linearize_ops = [op.type for op in self.main_program.block(0).ops]
self.assertEqual(sorted(linearize_ops), sorted(self.linearize_ops))
flatten_xs_dot = flatten(xs_dot)
for k, v in self.xs_shape_map.items():
self.assertEqual(flatten_xs_dot[k].shape, v)
flatten_ys_dot = flatten(ys_dot)
for k, v in self.ys_shape_map.items():
self.assertEqual(flatten_ys_dot[k].shape, v)
# Test transpose
ys_bar, xs_bar = ad.transpose(ys_dot, xs_dot, retain_fwd=False)
transpose_ops = [op.type for op in self.main_program.block(0).ops]
self.assertEqual(sorted(transpose_ops), sorted(self.transpose_ops))
flatten_xs_bar = flatten(xs_bar)
for k, v in self.xs_shape_map.items():
# There may be None in the result of transpose like gather op
if flatten_xs_bar[k] is not None:
self.assertEqual(flatten_xs_bar[k].shape, v)
flatten_ys_bar = flatten(ys_bar)
for k, v in self.ys_shape_map.items():
self.assertEqual(flatten_ys_bar[k].shape, v)
# Test prim2orig
prim2orig(block=self.main_program.block(0))
prim2orig_ops = [op.type for op in self.main_program.block(0).ops]
self.assertEqual(sorted(prim2orig_ops), sorted(self.prim2orig_ops))
class TestAutoGradTransformForMatmul(TestAutoGradTransformForAdd):
def init_data(self):
# { input_index: input_shape }
self.xs_shape_map = {0: (100, 2), 1: (5, 2)}
# { output_index: output_shape }
self.ys_shape_map = {0: (100, 5)}
X0 = paddle.static.data(
'X0', shape=self.xs_shape_map[0], dtype='float32')
X0.stop_gradient = False
X1 = paddle.static.data(
'X1', shape=self.xs_shape_map[1], dtype='float32')
X1.stop_gradient = False
A = paddle.reshape(X1, [2, 5])
B = paddle.scale(A, scale=2.0, bias=2.0)
Y = paddle.matmul(X0, B)
self.orig_xs = [X0, X1]
self.orig_ys = [Y, ]
self.orig_ops = ['reshape2', 'scale', 'matmul_v2']
self.orig2prim_ops = [
'reshape_p', 'fill_constant_p', 'fill_constant_p',
'fill_constant_p', 'mul_p', 'add_p', 'matmul_p'
]
self.linearize_ops = self.orig2prim_ops + [
# call fill_const() in linearize() function
'fill_constant_p',
'fill_constant_p',
# linearized op
'reshape_p',
'mul_p',
# 'mul_p', # JVP rules handle `None` input, some op will not be appended
# 'add_p',
# 'add_p',
'matmul_p',
'matmul_p',
'add_p'
]
self.transpose_ops = self.orig2prim_ops + [
# call fill_const() in transpose() function
'fill_constant_p',
# linearized op after remove path
'fill_constant_p',
'fill_constant_p',
'mul_p',
# transposed op
'transpose_p',
'matmul_p',
'transpose_p',
'matmul_p',
# 'mul_p',
'reshape_p',
]
self.prim2orig_ops = [
'reshape2',
'fill_constant',
'fill_constant',
'fill_constant',
'elementwise_mul',
'elementwise_add',
'matmul_v2',
'fill_constant',
'fill_constant',
'fill_constant',
'elementwise_mul',
'transpose2',
'matmul_v2',
'transpose2',
'matmul_v2',
# 'elementwise_mul',
'reshape2',
]
class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd):
def init_data(self):
# { input_index: input_shape }
self.xs_shape_map = {0: (7, 8, 9), 1: (8, 1), 2: (7, 8, 9), 3: (3, )}
# { output_index: output_shape }
self.ys_shape_map = {0: (3, 16, 9)}
X0 = paddle.static.data(
'X0', shape=self.xs_shape_map[0], dtype='float32')
X0.stop_gradient = False
X1 = paddle.static.data(
'X1', shape=self.xs_shape_map[1], dtype='float32')
X1.stop_gradient = False
X2 = paddle.static.data(
'X2', shape=self.xs_shape_map[2], dtype='float32')
X2.stop_gradient = False
X3 = paddle.static.data('X3', shape=self.xs_shape_map[3], dtype='int32')
X3.stop_gradient = False
A = paddle.add(X0, X1) # (7, 8, 9)
B = paddle.norm(x=A, p=2) # (1, )
C = paddle.subtract(X2, B) # (7, 8, 9)
D = paddle.concat(x=(A, C), axis=1) # (7, 16, 9)
Y = paddle.index_select(D, X3, axis=0) # (3, 16, 9)
self.orig_xs = [X0, X1, X2, X3]
self.orig_ys = [Y, ]
self.orig_ops = [
'elementwise_add', 'p_norm', 'elementwise_sub', 'concat',
'index_select'
]
self.orig2prim_ops = [
'broadcast_p', 'add_p', 'reshape_p', 'mul_p', 'reduce_p', 'sqrt_p',
'broadcast_p', 'sub_p', 'concat_p', 'gather_p'
]
self.linearize_ops = self.orig2prim_ops + [
# call fill_const() in linearize() function
'fill_constant_p',
'fill_constant_p',
'fill_constant_p',
'fill_constant_p',
# linearized op
'broadcast_p',
'add_p',
'reshape_p',
'mul_p',
'mul_p',
'add_p',
'reduce_p',
'fill_constant_p', # 'sqrt_p', Will not append sqrt_p op when apply JVP for sqrt_p
'mul_p',
'div_p',
'broadcast_p',
'sub_p',
'concat_p',
'gather_p'
]
self.transpose_ops = self.orig2prim_ops + [
# call fill_const() in transpose() function
'fill_constant_p',
# linearized op after remove path
'fill_constant_p',
'fill_constant_p',
'fill_constant_p',
'fill_constant_p',
'fill_constant_p',
'mul_p',
# transposed op
'reduce_p',
'reshape_p',
'reshape_p',
'mul_p',
'mul_p',
'reshape_p',
'broadcast_p',
'div_p',
'reduce_p',
'reshape_p',
'fill_constant_p',
'sub_p',
'split_p',
'fill_constant_p',
'scatter_add_p',
'add_p', # The output of the op is used by multiple subsequent ops
'add_p',
]
self.prim2orig_ops = [
'expand_v2', 'elementwise_add', 'reshape2', 'elementwise_mul',
'reduce_sum', 'sqrt', 'expand_v2', 'elementwise_sub', 'concat',
'gather', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant',
'elementwise_mul', 'reduce_sum', 'reshape2', 'reshape2',
'elementwise_mul', 'elementwise_mul', 'reshape2', 'expand_v2',
'elementwise_div', 'reduce_sum', 'reshape2', 'fill_constant',
'elementwise_sub', 'split', 'fill_constant', 'fill_any_like',
'elementwise_add', 'scatter', 'elementwise_add', 'elementwise_add'
]
if __name__ == '__main__':
unittest.main()
......@@ -12,7 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.autograd.functional import Hessian, Jacobian, jvp, vjp
from .primx import prim2orig
from .utils import enable_prim, disable_prim, prim_enabled
__all__ = [ # noqa
'vjp', 'jvp', 'Jacobian', 'Hessian'
'vjp',
'jvp',
'Jacobian',
'Hessian',
'prim2orig',
'enable_prim',
'disable_prim',
'prim_enabled'
]
......@@ -13,8 +13,6 @@
# limitations under the License.
import paddle
from paddle.fluid import unique_name, core
from paddle.fluid.framework import default_main_program, default_startup_program
from paddle.fluid.layer_helper import LayerHelper
from .primreg import REGISTER_FN
......@@ -136,7 +134,9 @@ def split(x, num_or_sections, axis=0, outs=None):
if isinstance(num_or_sections, (list, tuple)):
n = len(num_or_sections)
else:
assert isinstance(num_or_sections, int)
if not isinstance(num_or_sections, int):
raise TypeError(
f'num_or_sections must be int, but got {type(num_or_sections)}.')
n = num_or_sections
attrs = {'num_or_sections': num_or_sections, 'axis': axis}
......@@ -157,7 +157,8 @@ def split(x, num_or_sections, axis=0, outs=None):
@REGISTER_FN('concat_p', 'XS', 'Y')
def concat(xs, axis=0, out=None):
assert isinstance(xs, (list, tuple)) and len(xs) > 0
if isinstance(xs, paddle.fluid.framework.Variable):
xs = [xs]
attrs = {'axis': axis}
helper = LayerHelper('concat_p', **locals())
if out is None:
......@@ -172,9 +173,10 @@ def concat(xs, axis=0, out=None):
@REGISTER_FN('reduce_p', 'X', 'Y')
def reduce(x, axis, keepdim=False, out=None):
assert isinstance(axis, (tuple, list))
assert isinstance(keepdim, bool)
if not isinstance(axis, (tuple, list)):
raise TypeError(f'axis must be tuple or list, but got {type(axis)}')
if not isinstance(keepdim, bool):
raise TypeError(f'keepdim must be bool, but got {type(keepdim)}')
attrs = {'axis': axis, 'keepdim': keepdim}
helper = LayerHelper('reduce_p', **locals())
......@@ -196,12 +198,20 @@ def matmul(x, y, out=None):
@REGISTER_FN('slice_select_p', 'X', 'Y')
def slice_select(x, axis, starts, ends, strides, out=None):
assert isinstance(axis, (list, tuple)), (
f'Argument type error. `axis` is supposed to be int, list or'
if not isinstance(axis, (list, tuple)):
raise TypeError(f'Argument type error. `axis` is supposed to be list or'
f' tuple but found {type(axis)}.')
assert isinstance(starts, (list, tuple))
assert isinstance(ends, (list, tuple))
assert len(axis) == len(starts) == len(ends) == len(strides)
if not isinstance(starts, (list, tuple)):
raise TypeError(
f'Argument type error. `starts` is supposed to be list or'
f' tuple but found {type(starts)}.')
if not isinstance(ends, (list, tuple)):
raise TypeError(f'Argument type error. `ends` is supposed to be list or'
f' tuple but found {type(ends)}.')
assert len(axis) == len(starts) == len(ends) == len(strides), (
f'len(axis), len(starts), len(ends) and len(strides) should be equal, '
f'but len(axis)={len(axis)}, len(starts)={len(starts)}, '
f'len(ends)={len(ends)} and len(strides)={len(strides)}')
attrs = {'axis': axis, 'starts': starts, 'ends': ends, 'strides': strides}
helper = LayerHelper('slice_select_p', **locals())
......@@ -217,8 +227,13 @@ def slice_select(x, axis, starts, ends, strides, out=None):
@REGISTER_FN('slice_assign_p', 'X', 'Y', 'Z')
def slice_assign(x, y, axis, starts, ends, strides, out=None):
assert len(starts) == len(ends) == len(strides) == len(axis)
assert len(y.shape) == len(x.shape)
assert len(starts) == len(ends) == len(strides) == len(axis), (
f'len(starts), len(ends), len(strides) and len(axis) should be equal, '
f'but len(starts)={len(starts)}, len(ends)={len(ends)}, '
f'len(strides)={len(strides)} and len(axis)={len(axis)}')
assert len(y.shape) == len(x.shape), (
f'len(y.shape) should be equal to len(x.shape), '
f'but len(y.shape)={len(y.shape)} and len(x.shape)={len(x.shape)}.')
attrs = {'axis': axis, 'starts': starts, 'ends': ends, 'strides': strides}
helper = LayerHelper('slice_assign_p', **locals())
......@@ -233,7 +248,7 @@ def slice_assign(x, y, axis, starts, ends, strides, out=None):
return out
@REGISTER_FN('gather_p', 'X', 'Y')
@REGISTER_FN('gather_p', 'X', 'IndexTensor', 'Y')
def gather(x, indextensor, axis, out=None):
attrs = {'axis': axis}
helper = LayerHelper('gather_p', **locals())
......@@ -250,9 +265,16 @@ def gather(x, indextensor, axis, out=None):
@REGISTER_FN('scatter_add_p', 'X', 'Y', 'IndexTensor', 'Z')
def scatter_add(x, y, indextensor, axis, out=None):
assert len(x.shape) == len(y.shape)
assert len(indextensor.shape) == 1
assert y.shape[axis] == indextensor.shape[0]
assert len(x.shape) == len(y.shape), (
f'len(x.shape) should be equal to len(y.shape), '
f'but len(x.shape)={len(x.shape)} and len(y.shape)={len(y.shape)}.')
assert len(
indextensor.shape
) == 1, f'len(indextensor.shape) must be equal to 1, but got {len(indextensor.shape)}.'
assert y.shape[axis] == indextensor.shape[0], (
f'y.shape[axis] should be equal to indextensor.shape[0], '
f'but y.shape[axis]={y.shape[axis]} and '
f'indextensor.shape[0]={indextensor.shape[0]}.')
attrs = {'axis': axis}
helper = LayerHelper('scatter_add_p', **locals())
if out is None:
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
class Registry(object):
""" A general registry object. """
__slots__ = ['name', 'tab']
def __init__(self, name):
self.name = name
self.tab = {}
def register(self, name, value):
assert name not in self.tab, f'name "{name}" should not be registered before.'
self.tab[name] = value
def lookup(self, name):
return self.tab.get(name)
_primop_fn = Registry('primop_fn')
_orig2prim = Registry('orig2prim')
_prim2orig = Registry('prim2orig')
_primop_jvp = Registry('primop_jvp')
_primop_transpose = Registry('primop_transpose')
_primop_position_argnames = Registry('primop_position_argnames')
def lookup_fn(optype):
return _primop_fn.lookup(optype)
def lookup_orig2prim(optype):
return _orig2prim.lookup(optype)
def lookup_prim2orig(optype):
return _prim2orig.lookup(optype)
def lookup_jvp(optype):
return _primop_jvp.lookup(optype)
def lookup_transpose(optype):
return _primop_transpose.lookup(optype)
def op_position_inputs(op):
"""
Returns the position inputs of `op` as registered with REGISTER_FN.
Args:
op(Operator): The op that needs to get the inputs
Returns:
Tensor(s): Inputs of the op
Examples:
.. code-block:: python
@REGISTER_FN('div_p', 'X', 'Y', 'Z')
def div(x, y, out=None):
return _simple_binop(LayerHelper('div_p', **locals()))
The registered inputs are ['X', 'Y'] for div_p and accordingly this
function will return inputs in the order of X then Y.
"""
args = _primop_position_argnames.lookup(op.type)
assert args is not None, 'args should not be None in op_position_inputs().'
*input_names, _ = args
inputs = []
for name in input_names:
vars = list(map(op.block.var, op.input(name)))
assert len(
vars
) >= 0, f'len(vars) should be greater than or equal to 0, but len(vars)={len(vars)}.'
if len(vars) > 1:
inputs.append(vars)
else:
inputs.append(vars[0])
return inputs
def op_position_output(op):
"""
Returns the output of `op` as registered with REGISTER_FN.
Args:
op(Operator): The op that needs to get the output
Returns:
Tensor(s): Output of the op
Examples:
.. code-block:: python
@REGISTER_FN('div_p', 'X', 'Y', 'Z')
def div(x, y, out=None):
return _simple_binop(LayerHelper('div_p', **locals()))
The registered output is ['Z'] for div_p and accordingly this
function will return output Z.
"""
args = _primop_position_argnames.lookup(op.type)
assert args is not None, 'args should not be None in op_position_output().'
*_, output_name = args
outvars = list(map(op.block.var, op.output(output_name)))
assert len(
outvars
) >= 0, f'len(outvars) should be greater than or equal to 0, but len(outvars)={len(outvars)}.'
if len(outvars) > 1:
output = outvars
else:
output = outvars[0]
return output
def REGISTER_FN(op_type, *position_argnames):
"""
Decorator for registering the Python function for a primitive op.
Args:
op_type(str): The op name
position_argnames(list[str]): Input and ouput names of the op
Returns:
wrapper: Inner wrapper function
Examples:
.. code-block:: python
@REGISTER_FN('tanh_p', 'X', 'Y')
def tanh(x, out=None):
return _simple_unop(LayerHelper('tanh_p', **locals()))
"""
if not isinstance(op_type, str):
raise TypeError(f'op_type must be str, but got {type(op_type)}.')
_primop_position_argnames.register(op_type, position_argnames)
def wrapper(f):
_primop_fn.register(op_type, f)
return f
return wrapper
def REGISTER_ORIG2PRIM(op_type):
"""
Decorator for registering the lower function for an original op into sequence of primitive ops.
Args:
op_type(str): The op name
Returns:
wrapper: Inner wrapper function
Examples:
.. code-block:: python
@REGISTER_ORIG2PRIM('tanh')
def tanh_orig2prim(op):
x, = get_input_var_list(op)
return primops.tanh(x)
"""
if not isinstance(op_type, str):
raise TypeError(f'op_type must be str, but got {type(op_type)}.')
def wrapper(f):
def _lower(op, *args, **kwargs):
assert op.type == op_type, f'op.type should be equal to op_type, but op.type is {op.type} and op_type is {op_type}'
return f(op, *args, **kwargs)
_orig2prim.register(op_type, _lower)
return wrapper
def REGISTER_PRIM2ORIG(op_type):
"""
Decorator for registering the lower function for an primitive op into sequence of original ops.
Args:
op_type(str): The op name
Returns:
wrapper: Inner wrapper function
Examples:
.. code-block:: python
@REGISTER_PRIM2ORIG('tanh_p')
def tanh_prim2orig(op):
x, = get_input_var_list(op)
return paddle.tanh(x)
"""
if not isinstance(op_type, str):
raise TypeError(f'op_type must be str, but got {type(op_type)}.')
def wrapper(f):
def _lower(op, *args, **kwargs):
assert op.type == op_type, f'op.type should be equal to op_type, but op.type is {op.type} and op_type is {op_type}'
return f(op, *args, **kwargs)
_prim2orig.register(op_type, _lower)
return wrapper
def REGISTER_JVP(op_type):
"""
Decorator for registering the JVP function for a primitive op.
Args:
op_type(str): The op name
Returns:
wrapper: Inner wrapper function
Examples:
.. code-block:: python
@REGISTER_JVP('add_p')
def add_jvp(op, x_dot, y_dot):
return primops.add(x_dot, y_dot)
"""
if not isinstance(op_type, str):
raise TypeError(f'op_type must be str, but got {type(op_type)}.')
def wrapper(f):
def _jvp(op, *args, **kwargs):
assert op.type == op_type, f'op.type should be equal to op_type, but op.type is {op.type} and op_type is {op_type}'
return f(op, *args, **kwargs)
_primop_jvp.register(op_type, _jvp)
return f
return wrapper
def REGISTER_TRANSPOSE(op_type):
"""
Decorator for registering the transpose function for a primitive op
that denotes a linear operation in the forward AD graph.
Args:
op_type(str): The op name
Returns:
wrapper: Inner wrapper function
Examples:
.. code-block:: python
@REGISTER_TRANSPOSE('add_p')
def add_transpose(op, z_bar):
return z_bar, z_bar
"""
if not isinstance(op_type, str):
raise TypeError(f'op_type must be str, but got {type(op_type)}.')
def wrapper(f):
def _transpose(op, dot_checker, *args, **kwargs):
assert op.type == op_type, f'op.type should be equal to op_type, but op.type is {op.type} and op_type is {op_type}'
return f(op, dot_checker, *args, **kwargs)
_primop_transpose.register(op_type, _transpose)
return f
return wrapper
此差异已折叠。
此差异已折叠。
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid import framework as framework
class PrimOption(object):
def __init__(self):
self.enable_prim = False
def get_status(self):
return self.enable_prim
def set_status(self, flag):
self.enable_prim = flag
prim_option = PrimOption()
@framework.static_only
def prim_enabled():
"""
.. note::
**ONLY available in the static mode.**
Shows whether the automatic differentiation mechanism based on
automatic differential basic operators is ON. Defaults to OFF.
Returns:
flag(bool): Whether the automatic differentiation mechanism based on automatic differential basic operators is ON.
Examples:
.. code-block:: python
import paddle
from paddle.incubate.autograd import enable_prim, disable_prim, prim_enabled
paddle.enable_static()
enable_prim()
print(prim_enabled()) # True
disable_prim()
print(prim_enabled()) # False
"""
return prim_option.get_status()
@framework.static_only
def enable_prim():
"""
.. note::
**ONLY available in the static mode.**
Turns ON automatic differentiation mechanism based on automatic
differential basic operators.
Examples:
.. code-block:: python
import paddle
from paddle.incubate.autograd import enable_prim, prim_enabled
paddle.enable_static()
enable_prim()
print(prim_enabled()) # True
"""
prim_option.set_status(True)
@framework.static_only
def disable_prim():
"""
.. note::
**ONLY available in the static mode.**
Turns OFF automatic differentiation mechanism based on automatic
differential basic operators.
Examples:
.. code-block:: python
import paddle
from paddle.incubate.autograd import enable_prim, disable_prim, prim_enabled
paddle.enable_static()
enable_prim()
print(prim_enabled()) # True
disable_prim()
print(prim_enabled()) # False
"""
prim_option.set_status(False)
INT_DTYPE_2_STRING = {
int(0): 'bool',
int(1): 'int16',
int(2): 'int32',
int(3): 'int64',
int(4): 'float16',
int(5): 'float32',
int(6): 'float64',
int(20): 'uint8',
int(21): 'int8',
int(23): 'complex64',
int(24): 'complex128',
}
def get_var_block(block, names):
assert isinstance(names, list)
if len(names) == 0:
return None
elif len(names) == 1:
return block.var(names[0])
else:
return [block.var(name) for name in names]
def get_input_var_list(op):
if op.input_names is None:
return []
else:
return [
get_var_block(op.block, op.input(n)) for n in sorted(op.input_names)
]
def get_output_var_list(op):
if op.output_names is None:
return []
else:
return [
get_var_block(op.block, op.output(n))
for n in sorted(op.output_names)
]
def to_tensors(xs):
if isinstance(xs, paddle.fluid.framework.Variable):
return [xs]
else:
return xs
def flatten(inp):
if inp is None or isinstance(inp, paddle.fluid.framework.Variable):
return [inp]
flattened = []
for part in inp:
flattened += flatten(part)
return flattened
def flatten_and_remove_none(inp):
flattened = flatten(inp)
return [var for var in flattened if var is not None]
......@@ -47,6 +47,45 @@ from paddle.fluid.framework import _in_legacy_dygraph, _in_eager_without_dygraph
__all__ = []
@framework.static_only
def append_backward_new(loss_list,
parameter_list=None,
no_grad_set=None,
callbacks=None,
checkpoints=None,
distop_context=None):
from paddle.incubate.autograd.primx import orig2prim, Transform
program = default_main_program()
assert program.num_blocks == 1, "The append_backward_new interface is designed to process only one block."
block = program.current_block()
orig2prim(block)
ad = Transform(block)
if parameter_list is None:
parameter_list = program.global_block().all_parameters()
param_dot, loss_dot = ad.linearize(parameter_list, loss_list)
loss_bar, param_bar = ad.transpose(loss_dot, param_dot)
# remove param_dot and their constructor ops
op_indexes = []
for var in param_dot:
if var is not None:
op_index = block.ops.index(var.op)
assert op_index >= 0
op_indexes.append(op_index)
ad.erase_ops(sorted(op_indexes))
ad.erase_dots(param_dot)
if len(parameter_list) == 1:
params_and_grads = [(parameter_list, param_bar)]
else:
params_and_grads = []
for i, param in enumerate(parameter_list):
params_and_grads.append((param, param_bar[i]))
return params_and_grads
class Optimizer(object):
r"""Optimizer Base class.
......@@ -880,6 +919,11 @@ class Optimizer(object):
parameter_list = parameters if parameters \
else self._parameter_list
with program_guard(program, startup_program):
from paddle.incubate.autograd.utils import prim_enabled
if prim_enabled():
params_grads = append_backward_new(
[loss], parameter_list, act_no_grad_set, callbacks)
else:
params_grads = append_backward(loss, parameter_list,
act_no_grad_set, callbacks)
# Note: since we can't use all_reduce_op now,
......
......@@ -368,6 +368,7 @@ packages=['paddle',
'paddle.incubate.nn.functional',
'paddle.incubate.nn.layer',
'paddle.incubate.optimizer.functional',
'paddle.incubate.autograd',
'paddle.incubate.distributed',
'paddle.incubate.distributed.models',
'paddle.incubate.distributed.models.moe',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册