Unverified commit f6ee202f, authored by WangZhen and committed by GitHub

Add support for the forward- and reverse-mode higher-order automatic differentiation mechanism (#41919)

* Updated triple_grad_check func

* add todo for gradient checker and refine some comments

* remove additional code

* add test for warning in backward.py

* format python code

* support multi input in triple gradient checker

* Add matmul triple grad kernel

* Updated comments of TODO

* Supported some special tests

* Change code-format to follow CI std

* Updated gradient_checker.py

* Fix conflicts

* Removed unnecessary printing log

* Change code style to follow CI std

* merge upstream

* add primops.py

* add_p

* rm useless files

* add sub_p mul_p div_p

* add sqrt_p and tanh_p

* add reshape_p

* add broadcast_p

* Add python primitive wrappers.

* Jvp rules updated.

* JVP rules done for all the 17 primops.

* quick check and fixes.

* add jvp(op, *args)

* add broadcast_p fill_constant_p matmul_p reduce_p reshape_p transpose_p

* add split_p and concat_p

* add gather_p and scatter_add_p

* add slice_select_p and slice_assign_p

* Add transpose rules.

* add multi input check for add_p, sub_p, mul_p, div_p

* update concat_p

* Linearize and transpose in progress..

* refine gather_p and scatter_add_p

* updated.

* update transpose.

* refine slice_assign_p and slice_select_p

* init commit for lower

* Merged with primitive ops.

* small update

* add rules for orig2prim and prim2orig

* add 9 tests for prim ops

* add more tests and fix some bugs

* add more tests

* register proto

* Adding primops test.

* add shape valid check for broadcast_p op, and add keepdim attr into reduce_p op proto

* support multi input and multi output for split_p and concat_p

* Test updated.

* update

* fix slice bug for slice_select_p and slice_assign_p

* updated.

* Ops updated.

* Refactor and bug fixes.

* updated.

* finish orig2prim and prim2orig rules

* dtype for axis attr should be long int

* update dtype for axis attr int64_t

* update for iscan CI

* Update primx.

* Refactor vars in primx.

* update for lower transform

* add more shape and dtype check

* update primx.py

* change IndexTensor into int32 dtype

* update

* Fix linearize and transpose.

* Update is_dot

* Update is_dot

* Update is_dot

* add gradient aggregation, fix add_transpose.

* pass first linearize+transpose test.

* update test

* refactor op registration and primx.

* update rule for slice_assign

* try test lower

* update orig2prim and prim2orig

* pass simple lower pass

* update

* Update input types in the unit test.

* orig2prim segfault.

* 50% for adam.minimize

* test updated.

* temp fix errors in removing vars.

* primx updated.

* update for matmul_v2 and reshape2 orig2prim

* update for minimize

* Refine primrules

* Remove some code

* supporting unused and unreachable vars.

* update for use prim2orig in minimize

* fix gather and scatter_add transpose.

* Add rules UT

* update scatter_add

* Refine UT code

* fix NoneType check in topo

* Update gather_p pywrapper.

* remove useless print

* Merge tongxin PR and refine code

* readd some test

* rm useless print

* polish code.

* fix bug in minimize

* add get_input_var_list and get_output_var_list and use it in lower

* Fix scatter_add_p prim2orig

* Update code and fix orig2prim/prim2orig UT

* delete vars after block.desc._remove

* Improve ops and vars clean up logics.

* fix some bug in linearize and lower

* update tanh transpose.

* use set instead of list for var2remove

* test updated.

* polish code.

* fix dot2bar delete.

* merge tx/ad

* add indextensor_dot for gather and scatter_add

* add sorted for set

* Fix scale_orig2prim params

* fix some syntax bugs

* add global_lower_update list

* Better handling of unused vars.

* update tests.

* Fix elementwise_sub orig2prim

* support none for transpose rule

* Merge and add transform UT

* fix a bug in transpose

* Fix transpose and UT

* a hacky fix for concat op

* Fix executor place

* Refine variable name

* Add elementwise_mul orig2prim and support p_norm when p=1

* Add sqrt orig2prim rule and UT

* merge wz test

* rename files, add enable_prim, disable_prim, prim_enabled, delete global_lower_update

* fix a bug in test_ad_transform_trans

* revert modify in framework.py

* add paddle.fluid.incubate.ad_transform to python/setup.py.in

* Fix remove vars error

* Fix p_norm_orig2prim

* merge wz

* Modify the code directory

* Add utils.py and remove get_input/output_vars functions

* Update maolin code

* Rename UT and refine test_ad_transform_primops

* Fix div_p jvp rule

* Add higher derivatives UT

* Move UT to autograd dir

* Fix comments

* import paddle in primops.py

* Add some error message for assert

* Refine UT class name and refine some comments in primreg.py

* update minimize of paddle/optimizer to support the new autograd

* resolve circular import between backward.py and optimizer.py

* fill gradients and minimize unittest

* Replace `assert isinstance` with `raise TypeError`

* Add some assert message for primx.py

* Polish variable name

* Add some assert message

* add some docstring

* refine some name

* update the format of English documents

* Split test_transform.py to two files to avoid ci error

* fix the document format of enable_prim/disable_prim/prim2orig/prim_enabled

* polish test_gradients_and_minimize

* add default value for prim_enabled api doc

* Remove some UT to avoid windows ci error

* Increase the time limit of test_gradients_and_minimize

* Fix UT time limit
Co-authored-by: veyron95 <veyron_wu@163.com>
Co-authored-by: Jiabin Yang <360788950@qq.com>
Co-authored-by: levi131 <limaolin01@baidu.com>
Co-authored-by: Tongxin Bai <waffle.bai@gmail.com>
Co-authored-by: Xiaoxu Chen <chenxx_id@163.com>
Co-authored-by: levi131 <83750468+levi131@users.noreply.github.com>
Parent b9342a80
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
class Registry(object):
""" A general registry object. """
__slots__ = ['name', 'tab']
def __init__(self, name):
self.name = name
self.tab = {}
def register(self, name, value):
assert name not in self.tab
self.tab[name] = value
def lookup(self, name):
assert name in self.tab, f'No registry entry is found with name: {name}'
return self.tab[name]
_primop_fn = Registry('primop_fn')
_orig2prim = Registry('orig2prim')
_prim2orig = Registry('prim2orig')
_primop_jvp = Registry('primop_jvp')
_primop_transpose = Registry('primop_transpose')
_primop_position_argnames = Registry('primop_position_argnames')
def REGISTER_FN(op_type, *position_argnames):
"""Decorator for registering the Python function for a primitive op."""
assert isinstance(op_type, str)
_primop_position_argnames.register(op_type, position_argnames)
def wrapper(f):
_primop_fn.register(op_type, f)
return f
return wrapper
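A minimal usage sketch of the registries above, assuming `primreg` is importable as `paddle.incubate.autograd.primreg` once this change is installed; the `my_add_p` op type and the wrapper body are placeholders for illustration, not part of the diff:
from paddle.incubate.autograd import primreg

@primreg.REGISTER_FN('my_add_p', 'X', 'Y', 'Z')
def my_add(x, y, out=None):
    # A real wrapper would append a 'my_add_p' op to the current block via
    # LayerHelper; only the registration side matters for this sketch.
    return None

# The decorator records both the wrapper function and the positional
# argument names under the op type.
assert primreg._primop_fn.lookup('my_add_p') is my_add
assert primreg._primop_position_argnames.lookup('my_add_p') == ('X', 'Y', 'Z')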
......@@ -32,6 +32,7 @@ try:
from collections.abc import Sequence
except:
from collections import Sequence
__all__ = [
'append_backward',
'gradients',
......@@ -2113,6 +2114,11 @@ def gradients(targets, inputs, target_gradients=None, no_grad_set=None):
check_type(target_gradients, 'target_gradients', (
framework.Variable, list, tuple, type(None)), 'paddle.static.gradients')
from ..incubate.autograd.primx import _gradients
from ..incubate.autograd.utils import prim_enabled
if prim_enabled():
return _gradients(targets, inputs, target_gradients)
outs = calc_gradient(targets, inputs, target_gradients, no_grad_set)
return _as_list(outs)
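With primitive mode enabled, nested `paddle.static.gradients` calls produce higher-order derivatives through the primitive rules, and `prim2orig` lowers the program back to original ops before execution. A condensed sketch mirroring the unit tests below (shapes, values, and the CPU place are illustrative; assumes this change is installed):
import numpy as np
import paddle
from paddle.incubate.autograd.primx import prim2orig
from paddle.incubate.autograd.utils import enable_prim, disable_prim

paddle.enable_static()
enable_prim()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name='x', shape=[1], dtype='float32')
    y = paddle.multiply(x, x)
    grad1, = paddle.static.gradients([y], [x])      # routed to _gradients()
    grad2, = paddle.static.gradients([grad1], [x])  # second-order derivative
    prim2orig(main.block(0))                        # lower prim ops before running
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup)
outs = exe.run(main, feed={'x': np.array([3.], dtype='float32')},
               fetch_list=[grad2.name])             # d2(x*x)/dx2 == 2
disable_prim()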
......
......@@ -8,3 +8,4 @@ endforeach(TEST_OP)
set_tests_properties(test_autograd_functional_dynamic PROPERTIES TIMEOUT 160)
set_tests_properties(test_autograd_functional_static PROPERTIES TIMEOUT 160)
set_tests_properties(test_gradients_and_minimize PROPERTIES TIMEOUT 60)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.incubate.autograd.primx import prim2orig
from paddle.incubate.autograd.utils import enable_prim, disable_prim, prim_enabled
paddle.enable_static()
class TestGradients(unittest.TestCase):
def test_third_order(self):
enable_prim()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
x = paddle.static.data(name='x', shape=[1], dtype='float32')
x2 = paddle.multiply(x, x)
x3 = paddle.multiply(x2, x)
x4 = paddle.multiply(x3, x)
grad1, = paddle.static.gradients([x4], [x])
grad2, = paddle.static.gradients([grad1], [x])
grad3, = paddle.static.gradients([grad2], [x])
prim2orig(main.block(0))
feed = {x.name: np.array([2.]).astype('float32')}
fetch_list = [grad3.name]
result = [np.array([48.])]
place = paddle.CPUPlace()
if paddle.device.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup)
outs = exe.run(main, feed=feed, fetch_list=fetch_list)
np.testing.assert_allclose(outs, result)
disable_prim()
def test_fourth_order(self):
enable_prim()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
x = paddle.static.data(name='x', shape=[1], dtype='float32')
x2 = paddle.multiply(x, x)
x3 = paddle.multiply(x2, x)
x4 = paddle.multiply(x3, x)
x5 = paddle.multiply(x4, x)
out = paddle.sqrt(x5 + x4)
grad1, = paddle.static.gradients([out], [x])
grad2, = paddle.static.gradients([grad1], [x])
grad3, = paddle.static.gradients([grad2], [x])
grad4, = paddle.static.gradients([grad3], [x])
prim2orig(main.block(0))
feed = {x.name: np.array([2.]).astype('float32'), }
fetch_list = [grad4.name]
# (3*(-5*x^2-16*x-16))/(16*(x+1)^3.5)
result = [np.array([-0.27263762711])]
place = paddle.CPUPlace()
if paddle.device.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup)
outs = exe.run(main, feed=feed, fetch_list=fetch_list)
np.testing.assert_allclose(outs, result)
disable_prim()
class TestMinimize(unittest.TestCase):
def model(self, x, w, bias, opt):
paddle.seed(0)
place = paddle.CPUPlace()
if paddle.device.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
input_x = paddle.static.data('x', x.shape, dtype=x.dtype)
input_x.stop_gradient = False
params_w = paddle.static.create_parameter(
shape=w.shape, dtype=w.dtype, is_bias=False)
params_bias = paddle.static.create_parameter(
shape=bias.shape, dtype=bias.dtype, is_bias=True)
y = paddle.tanh(paddle.matmul(input_x, params_w) + params_bias)
loss = paddle.norm(y, p=2)
_, grads = opt.minimize(loss)
if prim_enabled():
prim2orig(main.block(0))
exe.run(startup)
grads = exe.run(main,
feed={'x': x,
'w': w,
'bias': bias},
fetch_list=grads)
return grads
def test_adam(self):
x = np.random.rand(2, 20)
w = np.random.rand(20, 2)
bias = np.random.rand(2)
enable_prim()
prim_grads = self.model(x, w, bias, paddle.optimizer.Adam(0.01))
disable_prim()
orig_grads = self.model(x, w, bias, paddle.optimizer.Adam(0.01))
for orig, prim in zip(orig_grads, prim_grads):
np.testing.assert_allclose(orig, prim)
def test_sgd(self):
x = np.random.rand(2, 20)
w = np.random.rand(20, 2)
bias = np.random.rand(2)
enable_prim()
prim_grads = self.model(x, w, bias, paddle.optimizer.SGD(0.01))
disable_prim()
orig_grads = self.model(x, w, bias, paddle.optimizer.SGD(0.01))
for orig, prim in zip(orig_grads, prim_grads):
np.testing.assert_allclose(orig, prim)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers.utils import flatten
from paddle.incubate.autograd.primrules import _orig2prim, _prim2orig, _jvp, _transpose
paddle.enable_static()
############################ Test linearize rules ############################
class TestAddPJVPAndTranspose(unittest.TestCase):
def setUp(self):
self.main_program = paddle.static.Program()
self.startup_program = paddle.static.Program()
self.layer_help = LayerHelper('TestPrim2Orig')
with paddle.static.program_guard(self.main_program,
self.startup_program):
self.init_data()
def init_data(self):
# Set prim op
self.op_type = 'add_p'
X = paddle.static.data(name='X', shape=[2, 2], dtype='float')
Y = paddle.static.data(name='Y', shape=[2, 2], dtype='float')
self.prim_input = {'X': X, 'Y': Y}
self.prim_output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[2, 2], dtype='float')
Y_DOT = paddle.static.data(name='Y_DOT', shape=[2, 2], dtype='float')
self.jvp_args = (X_DOT, Y_DOT)
self.jvp_out_shape_map = {0: self.prim_output['Z']}
# Set transpose
check_dot = lambda v: True
Z_BAR = paddle.static.data(name='Z_BAR', shape=[2, 2], dtype='float')
self.transpose_args = (check_dot, Z_BAR)
self.transpose_out_shape_map = {0: X, 1: Y}
self.all_ops = [
# prim op:
'add_p',
# jvp op:
'add_p',
# transpose op:
]
def test_op(self):
with paddle.static.program_guard(self.main_program,
self.startup_program):
op = self.layer_help.append_op(
type=self.op_type,
inputs=self.prim_input,
outputs=self.prim_output,
attrs=self.prim_attrs)
jvp_out = _jvp(op, *self.jvp_args)
jvp_out = flatten(jvp_out)
for k, v in self.jvp_out_shape_map.items():
self.assertEqual(jvp_out[k].shape, v.shape)
# Some prim ops don't have a transpose rule
if hasattr(self, 'transpose_args'):
transpose_out = _transpose(op, *self.transpose_args)
transpose_out = flatten(transpose_out)
for k, v in self.transpose_out_shape_map.items():
self.assertEqual(transpose_out[k].shape, v.shape)
all_ops = [op.type for op in self.main_program.block(0).ops]
self.assertEqual(sorted(all_ops), sorted(self.all_ops))
class TestSubPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'sub_p'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
Y = paddle.static.data(name='Y', shape=[5, 6], dtype='int64')
self.prim_input = {'X': X, 'Y': Y}
self.prim_output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
Y_DOT = paddle.static.data(name='Y_DOT', shape=[5, 6], dtype='int64')
self.jvp_args = (X_DOT, Y_DOT)
self.jvp_out_shape_map = {0: self.prim_output['Z']}
# Set transpose
check_dot = lambda v: True
Z_BAR = paddle.static.data(name='Z_BAR', shape=[5, 6], dtype='int64')
self.transpose_args = (check_dot, Z_BAR)
self.transpose_out_shape_map = {0: X, 1: Y}
self.all_ops = [
# prim op:
'sub_p',
# jvp op:
'sub_p',
# transpose op:
'fill_constant_p',
'sub_p'
]
class TestMulPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'mul_p'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
Y = paddle.static.data(name='Y', shape=[5, 6], dtype='int64')
self.prim_input = {'X': X, 'Y': Y}
self.prim_output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
Y_DOT = paddle.static.data(name='Y_DOT', shape=[5, 6], dtype='int64')
self.jvp_args = (X_DOT, Y_DOT)
self.jvp_out_shape_map = {0: self.prim_output['Z']}
# Set transpose
check_dot = lambda v: v is X
Z_BAR = paddle.static.data(name='Z_BAR', shape=[5, 6], dtype='int64')
self.transpose_args = (check_dot, Z_BAR)
self.transpose_out_shape_map = {0: X, }
self.all_ops = [
# prim op:
'mul_p',
# jvp op:
'mul_p',
'mul_p',
'add_p',
# transpose op:
'mul_p'
]
class TestDivPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'div_p'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
Y = paddle.static.data(name='Y', shape=[5, 6], dtype='int64')
self.prim_input = {'X': X, 'Y': Y}
self.prim_output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
Y_DOT = paddle.static.data(name='Y_DOT', shape=[5, 6], dtype='int64')
self.jvp_args = (X_DOT, Y_DOT)
self.jvp_out_shape_map = {0: self.prim_output['Z']}
# Set transpose
check_dot = lambda v: v is X
Z_BAR = paddle.static.data(name='Z_BAR', shape=[5, 6], dtype='int64')
self.transpose_args = (check_dot, Z_BAR)
self.transpose_out_shape_map = {0: X, }
self.all_ops = [
# prim op:
'div_p',
# jvp op:
'div_p',
'div_p',
'mul_p',
'mul_p',
'sub_p',
# transpose op:
'div_p'
]
class TestSqrtPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'sqrt_p'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
self.prim_input = {'X': X, }
self.prim_output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
self.jvp_args = (X_DOT, )
self.jvp_out_shape_map = {0: self.prim_output['Y']}
self.all_ops = [
# prim op:
'sqrt_p',
# jvp op:
'div_p',
'mul_p',
'fill_constant_p',
# 'sqrt_p',
# transpose op:
]
class TestTanhPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'tanh_p'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
self.prim_input = {'X': X, }
self.prim_output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
self.jvp_args = (X_DOT, )
self.jvp_out_shape_map = {0: self.prim_output['Y']}
self.all_ops = [
# prim op:
'tanh_p',
# jvp op:
'mul_p',
'sub_p',
'fill_constant_p',
'mul_p',
# transpose op:
]
class TestReshapePJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'reshape_p'
X = paddle.static.data(name='X', shape=[8, 8], dtype='int64')
self.prim_input = {'X': X, }
self.prim_output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {'shape': [2, 32]}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[8, 8], dtype='int64')
self.jvp_args = (X_DOT, )
self.jvp_out_shape_map = {0: self.prim_output['Y']}
# Set transpose
check_dot = lambda v: v is X
Y_BAR = paddle.static.data(name='Y_BAR', shape=[2, 32], dtype='int64')
self.transpose_args = (check_dot, Y_BAR)
self.transpose_out_shape_map = {0: X, }
self.all_ops = [
# prim op:
'reshape_p',
# jvp op:
'reshape_p',
# transpose op:
'reshape_p',
]
class TestBroadcastPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'broadcast_p'
X = paddle.static.data(name='X', shape=[10, 1], dtype='int64')
self.prim_input = {'X': X, }
self.prim_output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {'shape': [2, 10, 7]}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[10, 7], dtype='int64')
self.jvp_args = (X_DOT, )
self.jvp_out_shape_map = {0: self.prim_output['Y']}
# Set transpose
check_dot = lambda v: v is X
Y_BAR = paddle.static.data(
name='Y_BAR', shape=[2, 10, 7], dtype='int64')
self.transpose_args = (check_dot, Y_BAR)
self.transpose_out_shape_map = {0: X, }
self.all_ops = [
# prim op:
'broadcast_p',
# jvp op:
'broadcast_p',
# transpose op:
'reduce_p',
'reshape_p'
]
class TestTransposePJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'transpose_p'
X = paddle.static.data(name='X', shape=[2, 3, 4, 5], dtype='int64')
self.prim_input = {'X': X, }
self.prim_output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {'axis': [0, 2, 3, 1]}
# Set JVP
X_DOT = paddle.static.data(
name='X_DOT', shape=[2, 3, 4, 5], dtype='int64')
self.jvp_args = (X_DOT, )
self.jvp_out_shape_map = {0: self.prim_output['Y']}
# Set transpose
check_dot = lambda v: v is X
Y_BAR = paddle.static.data(
name='Y_BAR', shape=[2, 4, 5, 3], dtype='int64')
self.transpose_args = (check_dot, Y_BAR)
self.transpose_out_shape_map = {0: X, }
self.all_ops = [
# prim op:
'transpose_p',
# jvp op:
'transpose_p',
# transpose op:
'transpose_p',
]
class TestSplitPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'split_p'
X = paddle.static.data(name='X', shape=[2, 7, 10], dtype='int64')
self.prim_input = {'X': X, }
self.prim_output = {
'YS': [
self.layer_help.create_variable_for_type_inference(
dtype=X.dtype) for i in range(4)
]
}
self.prim_attrs = {'num_or_sections': [2, 3, 4, 1], 'axis': 2}
# Set JVP
X_DOT = paddle.static.data(
name='X_DOT', shape=[2, 7, 10], dtype='int64')
self.jvp_args = (X_DOT, )
self.jvp_out_shape_map = {
0: self.prim_output['YS'][0],
1: self.prim_output['YS'][1],
2: self.prim_output['YS'][2],
3: self.prim_output['YS'][3],
}
# Set transpose
check_dot = lambda v: v is X
YS_BAR = [
paddle.static.data(
name='Y_BAR1', shape=[2, 7, 2], dtype='int64'),
paddle.static.data(
name='Y_BAR2', shape=[2, 7, 3], dtype='int64'),
paddle.static.data(
name='Y_BAR3', shape=[2, 7, 4], dtype='int64'),
paddle.static.data(
name='Y_BAR4', shape=[2, 7, 1], dtype='int64'),
]
self.transpose_args = (check_dot, YS_BAR)
self.transpose_out_shape_map = {0: X, }
self.all_ops = [
# prim op:
'split_p',
# jvp op:
'split_p',
# transpose op:
'concat_p',
]
class TestConcatPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'concat_p'
X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64')
Y = paddle.static.data(name='Y', shape=[3, 2, 5], dtype='float64')
Z = paddle.static.data(name='Z', shape=[3, 3, 5], dtype='float64')
self.prim_input = {'XS': [X, Y, Z], }
self.prim_output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {'axis': 1}
# Set JVP
XS_DOT = [
paddle.static.data(
name='X_DOT1', shape=[3, 9, 5], dtype='float64'),
paddle.static.data(
name='X_DOT2', shape=[3, 2, 5], dtype='float64'),
paddle.static.data(
name='X_DOT3', shape=[3, 3, 5], dtype='float64'),
]
self.jvp_args = (XS_DOT, )
self.jvp_out_shape_map = {0: self.prim_output['Y']}
# Set transpose
check_dot = lambda v: v is X or v is Y or v is Z
Y_BAR = paddle.static.data(
name='Y_BAR', shape=[3, 14, 5], dtype='float64')
self.transpose_args = (check_dot, Y_BAR)
self.transpose_out_shape_map = {
0: X,
1: Y,
2: Z,
}
self.all_ops = [
# prim op:
'concat_p',
# jvp op:
'concat_p',
# transpose op:
'split_p',
]
class TestReducePJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'reduce_p'
X = paddle.static.data(name='X', shape=[2, 3, 4, 5], dtype='float64')
self.prim_input = {'X': X}
self.prim_output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {'axis': [2], 'keepdim': False}
# Set JVP
X_DOT = paddle.static.data(
name='X_DOT1', shape=[2, 3, 4, 5], dtype='float64')
self.jvp_args = (X_DOT, )
self.jvp_out_shape_map = {0: self.prim_output['Y']}
# Set transpose
check_dot = lambda v: v is X
Y_BAR = paddle.static.data(
name='Y_BAR', shape=[2, 3, 5], dtype='float64')
self.transpose_args = (check_dot, Y_BAR)
self.transpose_out_shape_map = {0: X, }
self.all_ops = [
# prim op:
'reduce_p',
# jvp op:
'reduce_p',
# transpose op:
'reshape_p',
'broadcast_p',
]
class TestMatmulPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'matmul_p'
X = paddle.static.data(name='X', shape=[2, 3], dtype='float64')
Y = paddle.static.data(name='Y', shape=[3, 4], dtype='float64')
self.prim_input = {'X': X, 'Y': Y}
self.prim_output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[2, 3], dtype='float64')
Y_DOT = paddle.static.data(name='Y_DOT', shape=[3, 4], dtype='float64')
self.jvp_args = (X_DOT, Y_DOT)
self.jvp_out_shape_map = {0: self.prim_output['Z']}
# Set transpose
check_dot = lambda v: v is X
Z_BAR = paddle.static.data(name='Z_BAR', shape=[2, 4], dtype='float64')
self.transpose_args = (check_dot, Z_BAR)
self.transpose_out_shape_map = {0: X, }
self.all_ops = [
# prim op:
'matmul_p',
# jvp op:
'matmul_p',
'matmul_p',
'add_p',
# transpose op:
'matmul_p',
'transpose_p',
]
class TestSliceSelectPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'slice_select_p'
X = paddle.static.data(name='X', shape=[3, 20], dtype='float64')
self.prim_input = {'X': X, }
self.prim_output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {
'axis': [1],
'starts': [0],
'ends': [20],
'strides': [2]
}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[3, 20], dtype='float64')
self.jvp_args = (X_DOT, )
self.jvp_out_shape_map = {0: self.prim_output['Y']}
# Set transpose
check_dot = lambda v: v is X
Y_BAR = paddle.static.data(name='Y_BAR', shape=[3, 10], dtype='float64')
self.transpose_args = (check_dot, Y_BAR)
self.transpose_out_shape_map = {0: X, }
self.all_ops = [
# prim op:
'slice_select_p',
# jvp op:
'slice_select_p',
# transpose op:
'slice_assign_p',
'fill_constant_p',
]
class TestSliceAssignPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'slice_assign_p'
X = paddle.static.data(name='X', shape=[3, 20], dtype='float64')
Y = paddle.static.data(name='Y', shape=[3, 5], dtype='float64')
self.prim_input = {'X': X, 'Y': Y}
self.prim_output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {
'axis': [1],
'starts': [0],
'ends': [10],
'strides': [2]
}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[3, 20], dtype='float64')
Y_DOT = paddle.static.data(name='Y_DOT', shape=[3, 5], dtype='float64')
self.jvp_args = (X_DOT, Y_DOT)
self.jvp_out_shape_map = {0: self.prim_output['Z']}
# Set transpose
check_dot = lambda v: v is X or v is Y
Z_BAR = paddle.static.data(name='Z_BAR', shape=[3, 20], dtype='float64')
self.transpose_args = (check_dot, Z_BAR)
self.transpose_out_shape_map = {0: X, 1: Y}
self.all_ops = [
# prim op:
'slice_assign_p',
# jvp op:
'slice_assign_p',
# transpose op:
'slice_assign_p',
'slice_select_p',
'fill_constant_p'
]
class TestGatherPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'gather_p'
X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
IndexTensor = paddle.static.data(
name='IndexTensor', shape=[3], dtype='int32')
self.prim_input = {'X': X, 'IndexTensor': IndexTensor}
self.prim_output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {'axis': 1}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[9, 5], dtype='float64')
self.jvp_args = (
X_DOT,
IndexTensor, )
self.jvp_out_shape_map = {0: self.prim_output['Y']}
# Set transpose
check_dot = lambda v: v is X
Y_BAR = paddle.static.data(name='Y_BAR', shape=[9, 3], dtype='float64')
self.transpose_args = (check_dot, Y_BAR)
self.transpose_out_shape_map = {0: X, }
self.all_ops = [
# prim op:
'gather_p',
# jvp op:
'gather_p',
# transpose op:
'scatter_add_p',
'fill_constant_p',
]
class TestScatterAddPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'scatter_add_p'
X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
Y = paddle.static.data(name='Y', shape=[9, 3], dtype='float64')
IndexTensor = paddle.static.data(
name='IndexTensor', shape=[3], dtype='int32')
self.prim_input = {'X': X, 'Y': Y, 'IndexTensor': IndexTensor}
self.prim_output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {'axis': 1}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[9, 5], dtype='float64')
Y_DOT = paddle.static.data(name='Y_DOT', shape=[9, 3], dtype='float64')
self.jvp_args = (X_DOT, Y_DOT)
self.jvp_out_shape_map = {0: self.prim_output['Z']}
# Set transpose
check_dot = lambda v: v is X or v is Y
Z_BAR = paddle.static.data(name='Z_BAR', shape=[9, 5], dtype='float64')
self.transpose_args = (check_dot, Z_BAR)
self.transpose_out_shape_map = {0: X, 1: Y}
self.all_ops = [
# prim op:
'scatter_add_p',
# jvp op:
'scatter_add_p',
# transpose op:
'scatter_add_p',
'fill_constant_p',
'gather_p'
]
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers.utils import flatten
from paddle.incubate.autograd.primrules import _orig2prim, _prim2orig, _jvp, _transpose
paddle.enable_static()
############################ Test orig2prim rules ############################
class TestElementWiseAddOrig2Prim(unittest.TestCase):
def setUp(self):
self.main_program = paddle.static.Program()
self.startup_program = paddle.static.Program()
self.layer_help = LayerHelper('TestOrig2Prim')
with paddle.static.program_guard(self.main_program,
self.startup_program):
self.init_data()
def init_data(self):
self.op_type = 'elementwise_add'
X = paddle.static.data(name='X', shape=[2, 2], dtype='float')
Y = paddle.static.data(name='Y', shape=[2, 2], dtype='float')
self.input = {'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, Y)
self.all_ops = ['elementwise_add', 'add_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
def test_op(self):
with paddle.static.program_guard(self.main_program,
self.startup_program):
op = self.layer_help.append_op(
type=self.op_type,
inputs=self.input,
outputs=self.output,
attrs=self.attrs)
prim_out = _orig2prim(op, *self.orig2prim_args)
all_ops = [op.type for op in self.main_program.block(0).ops]
self.assertEqual(sorted(all_ops), sorted(self.all_ops))
prim_out = flatten(prim_out)
for k, v in self.out_map.items():
self.assertEqual(prim_out[k].shape, v.shape)
class TestSqrtOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'sqrt'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
self.input = {'X': X, }
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, )
self.all_ops = ['sqrt', 'sqrt_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestElementWiseMulOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'elementwise_mul'
X = paddle.static.data(name='X', shape=[8, 8], dtype='float')
Y = paddle.static.data(name='Y', shape=[8, 8], dtype='float')
self.input = {'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, Y)
self.all_ops = ['elementwise_mul', 'mul_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestMatmulV2Orig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'matmul_v2'
X = paddle.static.data(name='X', shape=[3, 4], dtype='float')
Y = paddle.static.data(name='Y', shape=[4, 3], dtype='float')
self.input = {'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'trans_x': True, 'trans_y': True}
self.orig2prim_args = (X, Y)
self.all_ops = ['matmul_v2', 'transpose_p', 'transpose_p', 'matmul_p']
self.out_map = {0: self.output['Out']}
class TestTanhOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'tanh'
X = paddle.static.data(name='X', shape=[3, 4], dtype='float')
self.input = {'X': X, }
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, )
self.all_ops = ['tanh', 'tanh_p']
self.out_map = {0: self.output['Out']}
class TestReshape2Orig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'reshape2'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
self.input = {'X': X, }
self.output = {
'Out': X,
'XShape':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'shape': [6, 5]}
self.orig2prim_args = (
None,
None,
X, )
self.all_ops = ['reshape2', 'reshape_p', 'fill_constant_p']
# Do not check XShape
self.out_map = {0: self.output['Out']}
class TestConcatOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'concat'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
Y = paddle.static.data(name='Y', shape=[3, 6], dtype='int64')
self.input = {'X': [X, Y], }
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'axis': 0}
self.orig2prim_args = (
None,
(X, Y), )
self.all_ops = ['concat', 'concat_p']
self.out_map = {0: self.output['Out']}
class TestSliceOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'slice'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
self.input = {'Input': X, }
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {
'axes': [0],
'starts': [1],
'ends': [4],
}
self.orig2prim_args = (None, None, X, None, None)
self.all_ops = ['slice', 'slice_select_p']
self.out_map = {0: self.output['Out']}
class TestFillZerosLikeOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'fill_zeros_like'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
self.input = {'X': X, }
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, )
self.all_ops = ['fill_zeros_like', 'fill_constant_p']
self.out_map = {0: self.output['Out']}
class TestSumOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'sum'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
Y = paddle.static.data(name='Y', shape=[5, 6], dtype='int64')
self.input = {'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = ((X, Y), )
self.all_ops = ['sum', 'add_p']
self.out_map = {0: self.output['Out']}
class TestPNormOrig2Prim1(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'p_norm'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
self.input = {'X': X, }
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {
'porder': 1,
'asvector': True,
}
self.orig2prim_args = (X, )
self.all_ops = ['p_norm', 'reshape_p', 'sqrt_p', 'reduce_p', 'mul_p']
self.out_map = {0: self.output['Out']}
class TestPNormOrig2Prim2(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'p_norm'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
self.input = {'X': X, }
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {
'porder': 2,
'asvector': True,
}
self.orig2prim_args = (X, )
self.all_ops = ['p_norm', 'reshape_p', 'sqrt_p', 'reduce_p', 'mul_p']
self.out_map = {0: self.output['Out']}
class TestIndexSelectOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'index_select'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
Index = paddle.static.data(name='Index', shape=[2], dtype='int32')
self.input = {'X': X, 'Index': Index}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'dim': 0, }
self.orig2prim_args = (
Index,
X, )
self.all_ops = ['index_select', 'gather_p']
self.out_map = {0: self.output['Out']}
class TestElementwiseSubOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'elementwise_sub'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int32')
Y = paddle.static.data(name='Y', shape=[6], dtype='int32')
self.input = {'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'dim': 0, }
self.orig2prim_args = (
X,
Y, )
self.all_ops = ['elementwise_sub', 'broadcast_p', 'sub_p']
self.out_map = {0: self.output['Out']}
class TestScaleOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'scale'
X = paddle.static.data(name='X', shape=[10, 7], dtype='int32')
self.input = {'X': X, }
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'scale': 2.0, 'bias': 1.0, 'bias_after_scale': True}
self.orig2prim_args = (
None,
X, )
self.all_ops = [
'scale', 'fill_constant_p', 'fill_constant_p', 'mul_p', 'add_p'
]
self.out_map = {0: self.output['Out']}
class TestAssignOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'assign'
X = paddle.static.data(name='X', shape=[10, 7], dtype='int32')
self.input = {'X': X, }
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, )
self.all_ops = ['assign', 'fill_constant_p', 'add_p']
self.out_map = {0: self.output['Out']}
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers.utils import flatten
from paddle.incubate.autograd.primrules import _orig2prim, _prim2orig, _jvp, _transpose
paddle.enable_static()
############################ Test prim2orig rules ############################
class TestAddPPrim2Orig(unittest.TestCase):
def setUp(self):
self.main_program = paddle.static.Program()
self.startup_program = paddle.static.Program()
self.layer_help = LayerHelper('TestPrim2Orig')
with paddle.static.program_guard(self.main_program,
self.startup_program):
self.init_data()
def init_data(self):
self.op_type = 'add_p'
X = paddle.static.data(name='X', shape=[2, 2], dtype='float')
Y = paddle.static.data(name='Y', shape=[2, 2], dtype='float')
self.input = {'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (X, Y)
self.all_ops = ['add_p', 'elementwise_add']
# { prim_op_output_var: orig_op_out_index }
self.out_map = {self.output['Z']: 0}
def test_op(self):
with paddle.static.program_guard(self.main_program,
self.startup_program):
op = self.layer_help.append_op(
type=self.op_type,
inputs=self.input,
outputs=self.output,
attrs=self.attrs)
orig_out = _prim2orig(op, *self.prim2orig_args)
all_ops = [op.type for op in self.main_program.block(0).ops]
self.assertEqual(sorted(all_ops), sorted(self.all_ops))
orig_out = flatten(orig_out)
for k, v in self.out_map.items():
self.assertEqual(k.shape, orig_out[v].shape)
class TestSubPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'sub_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')
self.input = {'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (X, Y)
self.all_ops = ['sub_p', 'elementwise_sub']
self.out_map = {self.output['Z']: 0}
class TestMulPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'mul_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')
self.input = {'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (X, Y)
self.all_ops = ['mul_p', 'elementwise_mul']
self.out_map = {self.output['Z']: 0}
class TestDivPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'div_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')
self.input = {'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (X, Y)
self.all_ops = ['div_p', 'elementwise_div']
self.out_map = {self.output['Z']: 0}
class TestSqrtPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'sqrt_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
self.input = {'X': X, }
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (X, )
self.all_ops = ['sqrt_p', 'sqrt']
self.out_map = {self.output['Y']: 0}
class TestTanhPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'tanh_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
self.input = {'X': X, }
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (X, )
self.all_ops = ['tanh_p', 'tanh']
self.out_map = {self.output['Y']: 0}
class TestReshapePPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'reshape_p'
X = paddle.static.data(name='X', shape=[2, 8], dtype='float64')
self.input = {'X': X, }
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'shape': [4, 4]}
self.prim2orig_args = (X, )
self.all_ops = ['reshape_p', 'reshape2']
self.out_map = {self.output['Y']: 0}
class TestBroadcastPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'broadcast_p'
X = paddle.static.data(name='X', shape=[2, 8], dtype='float64')
self.input = {'X': X, }
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'shape': [10, 2, 8]}
self.prim2orig_args = (X, )
self.all_ops = ['broadcast_p', 'expand_v2']
self.out_map = {self.output['Y']: 0}
class TestTransposePPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'transpose_p'
X = paddle.static.data(name='X', shape=[7, 8, 9, 10], dtype='float64')
self.input = {'X': X, }
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'axis': [1, 2, 0, 3]}
self.prim2orig_args = (X, )
self.all_ops = ['transpose_p', 'transpose2']
self.out_map = {self.output['Y']: 0}
class TestSplitPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'split_p'
X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64')
self.input = {'X': X, }
self.output = {
'YS': [
self.layer_help.create_variable_for_type_inference(
dtype=X.dtype) for i in range(3)
]
}
self.attrs = {'num_or_sections': [2, 3, 4], 'axis': 1}
self.prim2orig_args = (X, )
self.all_ops = ['split_p', 'split']
self.out_map = {
self.output['YS'][0]: 0,
self.output['YS'][1]: 1,
self.output['YS'][2]: 2,
}
class TestConcatPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'concat_p'
X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64')
Y = paddle.static.data(name='Y', shape=[2, 9, 5], dtype='float64')
Z = paddle.static.data(name='Z', shape=[1, 9, 5], dtype='float64')
self.input = {'XS': [X, Y, Z], }
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'axis': 0}
self.prim2orig_args = ((X, Y, Z), )
self.all_ops = ['concat_p', 'concat']
self.out_map = {self.output['Y']: 0}
class TestReducePPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'reduce_p'
X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64')
self.input = {'X': X}
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'axis': [1], 'keepdim': True}
self.prim2orig_args = (X, )
self.all_ops = ['reduce_p', 'reduce_sum']
self.out_map = {self.output['Y']: 0}
class TestMatmulPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'matmul_p'
X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
Y = paddle.static.data(name='Y', shape=[5, 9], dtype='float64')
self.input = {'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (X, Y)
self.all_ops = ['matmul_p', 'matmul_v2']
self.out_map = {self.output['Z']: 0}
class TestSliceSelectPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'slice_select_p'
X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
self.input = {'X': X, }
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'axis': [0], 'starts': [1], 'ends': [8], 'strides': [2]}
self.prim2orig_args = (X, )
self.all_ops = ['slice_select_p', 'strided_slice']
self.out_map = {self.output['Y']: 0}
class TestSliceAssignPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'slice_assign_p'
X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
Y = paddle.static.data(name='Y', shape=[9, 3], dtype='float64')
self.input = {'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'axis': [1], 'starts': [0], 'ends': [3], 'strides': [1]}
self.prim2orig_args = (X, Y)
self.all_ops = ['slice_assign_p', 'assign', 'set_value']
self.out_map = {self.output['Z']: 0}
class TestGatherPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'gather_p'
X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
IndexTensor = paddle.static.data(
name='IndexTensor', shape=[3], dtype='int32')
self.input = {'X': X, 'IndexTensor': IndexTensor}
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'axis': 0, }
self.prim2orig_args = (
IndexTensor,
X, )
self.all_ops = ['gather_p', 'gather']
self.out_map = {self.output['Y']: 0}
class TestScatterAddPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'scatter_add_p'
X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
Y = paddle.static.data(name='Y', shape=[3, 5], dtype='float64')
IndexTensor = paddle.static.data(
name='IndexTensor', shape=[3], dtype='int32')
self.input = {'X': X, 'Y': Y, 'IndexTensor': IndexTensor}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'axis': 0, }
self.prim2orig_args = (IndexTensor, X, Y)
self.all_ops = [
'scatter_add_p', 'fill_any_like', 'scatter', 'elementwise_add'
]
self.out_map = {self.output['Z']: 0}
class TestFillConstantPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'fill_constant_p'
self.input = {}
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(paddle.int32)
}
self.attrs = {'value': 10, 'shape': [5, 5], 'dtype': paddle.int32}
self.prim2orig_args = ()
self.all_ops = ['fill_constant_p', 'fill_constant']
self.out_map = {self.output['Y']: 0}
if __name__ == '__main__':
unittest.main()
......@@ -14,12 +14,13 @@
import unittest
import numpy as np
import paddle
from paddle.autograd.primops import (
from paddle.incubate.autograd.primops import (
neg, set_value, add, sub, mul, div, sqrt, tanh, reshape, broadcast,
transpose, split, concat, reduce, matmul, slice_select, slice_assign,
gather, scatter_add, fill_const)
from paddle.incubate.autograd.primx import Transform, topo_path, orig2prim, prim2orig, _gradients
from paddle.incubate.autograd.utils import enable_prim, disable_prim, prim_enabled
class TestPyPrimOps(unittest.TestCase):
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.incubate.autograd.primx import Transform, orig2prim, prim2orig
from paddle.fluid.layers.utils import flatten
paddle.enable_static()
class TestAutoGradTransformForAdd(unittest.TestCase):
def setUp(self):
self.main_program = paddle.static.Program()
self.startup_program = paddle.static.Program()
with paddle.static.program_guard(self.main_program,
self.startup_program):
self.init_data()
def init_data(self):
# { input_index: input_shape }
self.xs_shape_map = {0: (20, 40), 1: (20, 40)}
# { output_index: output_shape }
self.ys_shape_map = {0: (20, 40)}
X0 = paddle.static.data(
name='X0', shape=self.xs_shape_map[0], dtype='float32')
X0.stop_gradient = False
X1 = paddle.static.data(
name='X1', shape=self.xs_shape_map[1], dtype='float32')
X1.stop_gradient = False
A = paddle.tanh(X0)
B = paddle.tanh(X1)
Y = paddle.add(A, B)
self.orig_xs = [X0, X1]
self.orig_ys = [Y, ]
self.orig_ops = ['tanh', 'tanh', 'elementwise_add']
self.orig2prim_ops = ['tanh_p', 'tanh_p', 'add_p']
self.linearize_ops = self.orig2prim_ops + [
# call fill_const() in linearize() function
'fill_constant_p',
'fill_constant_p',
# linearized op
'mul_p',
'sub_p',
'fill_constant_p',
'mul_p',
'mul_p',
'sub_p',
'fill_constant_p',
'mul_p',
'add_p',
]
self.transpose_ops = self.orig2prim_ops + [
# call fill_const() in transpose() function
'fill_constant_p',
# linearized op after remove path
'fill_constant_p',
'fill_constant_p',
'mul_p',
'sub_p',
'fill_constant_p',
'mul_p',
'sub_p',
'fill_constant_p',
# transposed op
'mul_p',
'mul_p'
]
self.prim2orig_ops = [
'tanh', 'tanh', 'elementwise_add', 'fill_constant', 'fill_constant',
'fill_constant', 'elementwise_mul', 'elementwise_sub',
'fill_constant', 'elementwise_mul', 'elementwise_sub',
'fill_constant', 'elementwise_mul', 'elementwise_mul'
]
def test_run(self):
# Must run within program_guard(); otherwise prim ops will be appended to another block
with paddle.static.program_guard(self.main_program,
self.startup_program):
ad = Transform(self.main_program.block(0))
orig_ops = [op.type for op in self.main_program.block(0).ops]
self.assertEqual(sorted(orig_ops), sorted(self.orig_ops))
# Test orig2prim
orig2prim(block=self.main_program.block(0))
orig2prim_ops = [op.type for op in self.main_program.block(0).ops]
self.assertEqual(sorted(orig2prim_ops), sorted(self.orig2prim_ops))
# Test linearize
xs_dot, ys_dot = ad.linearize(self.orig_xs, self.orig_ys)
linearize_ops = [op.type for op in self.main_program.block(0).ops]
self.assertEqual(sorted(linearize_ops), sorted(self.linearize_ops))
flatten_xs_dot = flatten(xs_dot)
for k, v in self.xs_shape_map.items():
self.assertEqual(flatten_xs_dot[k].shape, v)
flatten_ys_dot = flatten(ys_dot)
for k, v in self.ys_shape_map.items():
self.assertEqual(flatten_ys_dot[k].shape, v)
# Test transpose
ys_bar, xs_bar = ad.transpose(ys_dot, xs_dot, retain_fwd=False)
transpose_ops = [op.type for op in self.main_program.block(0).ops]
self.assertEqual(sorted(transpose_ops), sorted(self.transpose_ops))
flatten_xs_bar = flatten(xs_bar)
for k, v in self.xs_shape_map.items():
# The transpose result may contain None, e.g. for the gather op
if flatten_xs_bar[k] is not None:
self.assertEqual(flatten_xs_bar[k].shape, v)
flatten_ys_bar = flatten(ys_bar)
for k, v in self.ys_shape_map.items():
self.assertEqual(flatten_ys_bar[k].shape, v)
# Test prim2orig
prim2orig(block=self.main_program.block(0))
prim2orig_ops = [op.type for op in self.main_program.block(0).ops]
self.assertEqual(sorted(prim2orig_ops), sorted(self.prim2orig_ops))
class TestAutoGradTransformForMatmul(TestAutoGradTransformForAdd):
def init_data(self):
# { input_index: input_shape }
self.xs_shape_map = {0: (100, 2), 1: (5, 2)}
# { output_index: output_shape }
self.ys_shape_map = {0: (100, 5)}
X0 = paddle.static.data(
'X0', shape=self.xs_shape_map[0], dtype='float32')
X0.stop_gradient = False
X1 = paddle.static.data(
'X1', shape=self.xs_shape_map[1], dtype='float32')
X1.stop_gradient = False
A = paddle.reshape(X1, [2, 5])
B = paddle.scale(A, scale=2.0, bias=2.0)
Y = paddle.matmul(X0, B)
self.orig_xs = [X0, X1]
self.orig_ys = [Y, ]
self.orig_ops = ['reshape2', 'scale', 'matmul_v2']
self.orig2prim_ops = [
'reshape_p', 'fill_constant_p', 'fill_constant_p',
'fill_constant_p', 'mul_p', 'add_p', 'matmul_p'
]
self.linearize_ops = self.orig2prim_ops + [
# call fill_const() in linearize() function
'fill_constant_p',
'fill_constant_p',
# linearized op
'reshape_p',
'mul_p',
# 'mul_p',  # JVP rules handle `None` inputs, so some ops are not appended
# 'add_p',
# 'add_p',
'matmul_p',
'matmul_p',
'add_p'
]
self.transpose_ops = self.orig2prim_ops + [
# call fill_const() in transpose() function
'fill_constant_p',
# linearized op after remove path
'fill_constant_p',
'fill_constant_p',
'mul_p',
# transposed op
'transpose_p',
'matmul_p',
'transpose_p',
'matmul_p',
# 'mul_p',
'reshape_p',
]
self.prim2orig_ops = [
'reshape2',
'fill_constant',
'fill_constant',
'fill_constant',
'elementwise_mul',
'elementwise_add',
'matmul_v2',
'fill_constant',
'fill_constant',
'fill_constant',
'elementwise_mul',
'transpose2',
'matmul_v2',
'transpose2',
'matmul_v2',
# 'elementwise_mul',
'reshape2',
]
class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd):
def init_data(self):
# { input_index: input_shape }
self.xs_shape_map = {0: (7, 8, 9), 1: (8, 1), 2: (7, 8, 9), 3: (3, )}
# { output_index: output_shape }
self.ys_shape_map = {0: (3, 16, 9)}
X0 = paddle.static.data(
'X0', shape=self.xs_shape_map[0], dtype='float32')
X0.stop_gradient = False
X1 = paddle.static.data(
'X1', shape=self.xs_shape_map[1], dtype='float32')
X1.stop_gradient = False
X2 = paddle.static.data(
'X2', shape=self.xs_shape_map[2], dtype='float32')
X2.stop_gradient = False
X3 = paddle.static.data('X3', shape=self.xs_shape_map[3], dtype='int32')
X3.stop_gradient = False
A = paddle.add(X0, X1) # (7, 8, 9)
B = paddle.norm(x=A, p=2) # (1, )
C = paddle.subtract(X2, B) # (7, 8, 9)
D = paddle.concat(x=(A, C), axis=1) # (7, 16, 9)
Y = paddle.index_select(D, X3, axis=0) # (3, 16, 9)
self.orig_xs = [X0, X1, X2, X3]
self.orig_ys = [Y, ]
self.orig_ops = [
'elementwise_add', 'p_norm', 'elementwise_sub', 'concat',
'index_select'
]
self.orig2prim_ops = [
'broadcast_p', 'add_p', 'reshape_p', 'mul_p', 'reduce_p', 'sqrt_p',
'broadcast_p', 'sub_p', 'concat_p', 'gather_p'
]
self.linearize_ops = self.orig2prim_ops + [
# call fill_const() in linearize() function
'fill_constant_p',
'fill_constant_p',
'fill_constant_p',
'fill_constant_p',
# linearized op
'broadcast_p',
'add_p',
'reshape_p',
'mul_p',
'mul_p',
'add_p',
'reduce_p',
'fill_constant_p',  # 'sqrt_p' is not appended when applying the JVP rule for sqrt_p
'mul_p',
'div_p',
'broadcast_p',
'sub_p',
'concat_p',
'gather_p'
]
self.transpose_ops = self.orig2prim_ops + [
# call fill_const() in transpose() function
'fill_constant_p',
# linearized op after remove path
'fill_constant_p',
'fill_constant_p',
'fill_constant_p',
'fill_constant_p',
'fill_constant_p',
'mul_p',
# transposed op
'reduce_p',
'reshape_p',
'reshape_p',
'mul_p',
'mul_p',
'reshape_p',
'broadcast_p',
'div_p',
'reduce_p',
'reshape_p',
'fill_constant_p',
'sub_p',
'split_p',
'fill_constant_p',
'scatter_add_p',
'add_p', # The output of the op is used by multiple subsequent ops
'add_p',
]
self.prim2orig_ops = [
'expand_v2', 'elementwise_add', 'reshape2', 'elementwise_mul',
'reduce_sum', 'sqrt', 'expand_v2', 'elementwise_sub', 'concat',
'gather', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant',
'elementwise_mul', 'reduce_sum', 'reshape2', 'reshape2',
'elementwise_mul', 'elementwise_mul', 'reshape2', 'expand_v2',
'elementwise_div', 'reduce_sum', 'reshape2', 'fill_constant',
'elementwise_sub', 'split', 'fill_constant', 'fill_any_like',
'elementwise_add', 'scatter', 'elementwise_add', 'elementwise_add'
]
if __name__ == '__main__':
unittest.main()
......@@ -12,7 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.autograd.functional import Hessian, Jacobian, jvp, vjp
from .primx import prim2orig
from .utils import enable_prim, disable_prim, prim_enabled
__all__ = [ # noqa
- 'vjp', 'jvp', 'Jacobian', 'Hessian'
+ 'vjp',
+ 'jvp',
+ 'Jacobian',
+ 'Hessian',
+ 'prim2orig',
+ 'enable_prim',
+ 'disable_prim',
+ 'prim_enabled'
]
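For orientation, the newly exported names are meant to be used together in static mode; a minimal sketch that follows the prim2orig docstring further below (not a complete training example):

import paddle
from paddle.incubate.autograd import enable_prim, prim_enabled, prim2orig

paddle.enable_static()
enable_prim()                # switch on the primitive-op based autodiff
x = paddle.ones(shape=[2, 2], dtype='float32')
x.stop_gradient = False
y = x * x
dy_dx = paddle.static.gradients(y, x)
if prim_enabled():
    prim2orig()              # lower primitive ops back to original ops before execution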
......@@ -13,8 +13,6 @@
# limitations under the License.
import paddle
from paddle.fluid import unique_name, core
from paddle.fluid.framework import default_main_program, default_startup_program
from paddle.fluid.layer_helper import LayerHelper
from .primreg import REGISTER_FN
......@@ -136,7 +134,9 @@ def split(x, num_or_sections, axis=0, outs=None):
if isinstance(num_or_sections, (list, tuple)):
n = len(num_or_sections)
else:
- assert isinstance(num_or_sections, int)
+ if not isinstance(num_or_sections, int):
+ raise TypeError(
+ f'num_or_sections must be int, but got {type(num_or_sections)}.')
n = num_or_sections
attrs = {'num_or_sections': num_or_sections, 'axis': axis}
......@@ -157,7 +157,8 @@ def split(x, num_or_sections, axis=0, outs=None):
@REGISTER_FN('concat_p', 'XS', 'Y')
def concat(xs, axis=0, out=None):
- assert isinstance(xs, (list, tuple)) and len(xs) > 0
+ if isinstance(xs, paddle.fluid.framework.Variable):
+ xs = [xs]
attrs = {'axis': axis}
helper = LayerHelper('concat_p', **locals())
if out is None:
......@@ -172,9 +173,10 @@ def concat(xs, axis=0, out=None):
@REGISTER_FN('reduce_p', 'X', 'Y')
def reduce(x, axis, keepdim=False, out=None):
- assert isinstance(axis, (tuple, list))
- assert isinstance(keepdim, bool)
+ if not isinstance(axis, (tuple, list)):
+ raise TypeError(f'axis must be tuple or list, but got {type(axis)}')
+ if not isinstance(keepdim, bool):
+ raise TypeError(f'keepdim must be bool, but got {type(keepdim)}')
attrs = {'axis': axis, 'keepdim': keepdim}
helper = LayerHelper('reduce_p', **locals())
......@@ -196,12 +198,20 @@ def matmul(x, y, out=None):
@REGISTER_FN('slice_select_p', 'X', 'Y')
def slice_select(x, axis, starts, ends, strides, out=None):
- assert isinstance(axis, (list, tuple)), (
- f'Argument type error. `axis` is supposed to be int, list or'
- f' tuple but found {type(axis)}.')
- assert isinstance(starts, (list, tuple))
- assert isinstance(ends, (list, tuple))
- assert len(axis) == len(starts) == len(ends) == len(strides)
+ if not isinstance(axis, (list, tuple)):
+ raise TypeError(f'Argument type error. `axis` is supposed to be list or'
+ f' tuple but found {type(axis)}.')
+ if not isinstance(starts, (list, tuple)):
+ raise TypeError(
+ f'Argument type error. `starts` is supposed to be list or'
+ f' tuple but found {type(starts)}.')
+ if not isinstance(ends, (list, tuple)):
+ raise TypeError(f'Argument type error. `ends` is supposed to be list or'
+ f' tuple but found {type(ends)}.')
+ assert len(axis) == len(starts) == len(ends) == len(strides), (
+ f'len(axis), len(starts), len(ends) and len(strides) should be equal, '
+ f'but len(axis)={len(axis)}, len(starts)={len(starts)}, '
+ f'len(ends)={len(ends)} and len(strides)={len(strides)}')
attrs = {'axis': axis, 'starts': starts, 'ends': ends, 'strides': strides}
helper = LayerHelper('slice_select_p', **locals())
......@@ -217,8 +227,13 @@ def slice_select(x, axis, starts, ends, strides, out=None):
@REGISTER_FN('slice_assign_p', 'X', 'Y', 'Z')
def slice_assign(x, y, axis, starts, ends, strides, out=None):
- assert len(starts) == len(ends) == len(strides) == len(axis)
- assert len(y.shape) == len(x.shape)
+ assert len(starts) == len(ends) == len(strides) == len(axis), (
+ f'len(starts), len(ends), len(strides) and len(axis) should be equal, '
+ f'but len(starts)={len(starts)}, len(ends)={len(ends)}, '
+ f'len(strides)={len(strides)} and len(axis)={len(axis)}')
+ assert len(y.shape) == len(x.shape), (
+ f'len(y.shape) should be equal to len(x.shape), '
+ f'but len(y.shape)={len(y.shape)} and len(x.shape)={len(x.shape)}.')
attrs = {'axis': axis, 'starts': starts, 'ends': ends, 'strides': strides}
helper = LayerHelper('slice_assign_p', **locals())
......@@ -233,7 +248,7 @@ def slice_assign(x, y, axis, starts, ends, strides, out=None):
return out
- @REGISTER_FN('gather_p', 'X', 'Y')
+ @REGISTER_FN('gather_p', 'X', 'IndexTensor', 'Y')
def gather(x, indextensor, axis, out=None):
attrs = {'axis': axis}
helper = LayerHelper('gather_p', **locals())
......@@ -250,9 +265,16 @@ def gather(x, indextensor, axis, out=None):
@REGISTER_FN('scatter_add_p', 'X', 'Y', 'IndexTensor', 'Z')
def scatter_add(x, y, indextensor, axis, out=None):
- assert len(x.shape) == len(y.shape)
- assert len(indextensor.shape) == 1
- assert y.shape[axis] == indextensor.shape[0]
+ assert len(x.shape) == len(y.shape), (
+ f'len(x.shape) should be equal to len(y.shape), '
+ f'but len(x.shape)={len(x.shape)} and len(y.shape)={len(y.shape)}.')
+ assert len(
+ indextensor.shape
+ ) == 1, f'len(indextensor.shape) must be equal to 1, but got {len(indextensor.shape)}.'
+ assert y.shape[axis] == indextensor.shape[0], (
+ f'y.shape[axis] should be equal to indextensor.shape[0], '
+ f'but y.shape[axis]={y.shape[axis]} and '
+ f'indextensor.shape[0]={indextensor.shape[0]}.')
attrs = {'axis': axis}
helper = LayerHelper('scatter_add_p', **locals())
if out is None:
......
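As a quick illustration of the wrappers above: the primitive ops are built inside a static program and must be lowered with prim2orig before they can be executed. A minimal sketch, with the primops module path assumed from this diff:

import paddle
from paddle.incubate.autograd import prim2orig
from paddle.incubate.autograd import primops   # module location assumed

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data('x', shape=[2, 3], dtype='float32')
    y = primops.reduce(x, axis=[0, 1])          # appends a reduce_p op to the block
    prim2orig(main.current_block())             # rewrites reduce_p into executable original ops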
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
class Registry(object):
""" A general registry object. """
__slots__ = ['name', 'tab']
def __init__(self, name):
self.name = name
self.tab = {}
def register(self, name, value):
assert name not in self.tab, f'name "{name}" should not be registered before.'
self.tab[name] = value
def lookup(self, name):
return self.tab.get(name)
_primop_fn = Registry('primop_fn')
_orig2prim = Registry('orig2prim')
_prim2orig = Registry('prim2orig')
_primop_jvp = Registry('primop_jvp')
_primop_transpose = Registry('primop_transpose')
_primop_position_argnames = Registry('primop_position_argnames')
def lookup_fn(optype):
return _primop_fn.lookup(optype)
def lookup_orig2prim(optype):
return _orig2prim.lookup(optype)
def lookup_prim2orig(optype):
return _prim2orig.lookup(optype)
def lookup_jvp(optype):
return _primop_jvp.lookup(optype)
def lookup_transpose(optype):
return _primop_transpose.lookup(optype)
def op_position_inputs(op):
"""
Returns the position inputs of `op` as registered with REGISTER_FN.
Args:
op(Operator): The op that needs to get the inputs
Returns:
Tensor(s): Inputs of the op
Examples:
.. code-block:: python
@REGISTER_FN('div_p', 'X', 'Y', 'Z')
def div(x, y, out=None):
return _simple_binop(LayerHelper('div_p', **locals()))
The registered inputs are ['X', 'Y'] for div_p and accordingly this
function will return inputs in the order of X then Y.
"""
args = _primop_position_argnames.lookup(op.type)
assert args is not None, 'args should not be None in op_position_inputs().'
*input_names, _ = args
inputs = []
for name in input_names:
vars = list(map(op.block.var, op.input(name)))
assert len(
vars
) >= 0, f'len(vars) should be greater than or equal to 0, but len(vars)={len(vars)}.'
if len(vars) > 1:
inputs.append(vars)
else:
inputs.append(vars[0])
return inputs
def op_position_output(op):
"""
Returns the output of `op` as registered with REGISTER_FN.
Args:
op(Operator): The op that needs to get the output
Returns:
Tensor(s): Output of the op
Examples:
.. code-block:: python
@REGISTER_FN('div_p', 'X', 'Y', 'Z')
def div(x, y, out=None):
return _simple_binop(LayerHelper('div_p', **locals()))
The registered output is ['Z'] for div_p and accordingly this
function will return output Z.
"""
args = _primop_position_argnames.lookup(op.type)
assert args is not None, 'args should not be None in op_position_output().'
*_, output_name = args
outvars = list(map(op.block.var, op.output(output_name)))
assert len(
outvars
) >= 0, f'len(outvars) should be greater than or equal to 0, but len(outvars)={len(outvars)}.'
if len(outvars) > 1:
output = outvars
else:
output = outvars[0]
return output
def REGISTER_FN(op_type, *position_argnames):
"""
Decorator for registering the Python function for a primitive op.
Args:
op_type(str): The op name
position_argnames(list[str]): Input and output names of the op
Returns:
wrapper: Inner wrapper function
Examples:
.. code-block:: python
@REGISTER_FN('tanh_p', 'X', 'Y')
def tanh(x, out=None):
return _simple_unop(LayerHelper('tanh_p', **locals()))
"""
if not isinstance(op_type, str):
raise TypeError(f'op_type must be str, but got {type(op_type)}.')
_primop_position_argnames.register(op_type, position_argnames)
def wrapper(f):
_primop_fn.register(op_type, f)
return f
return wrapper
def REGISTER_ORIG2PRIM(op_type):
"""
Decorator for registering the lower function that lowers an original op into a sequence of primitive ops.
Args:
op_type(str): The op name
Returns:
wrapper: Inner wrapper function
Examples:
.. code-block:: python
@REGISTER_ORIG2PRIM('tanh')
def tanh_orig2prim(op):
x, = get_input_var_list(op)
return primops.tanh(x)
"""
if not isinstance(op_type, str):
raise TypeError(f'op_type must be str, but got {type(op_type)}.')
def wrapper(f):
def _lower(op, *args, **kwargs):
assert op.type == op_type, f'op.type should be equal to op_type, but op.type is {op.type} and op_type is {op_type}'
return f(op, *args, **kwargs)
_orig2prim.register(op_type, _lower)
return wrapper
def REGISTER_PRIM2ORIG(op_type):
"""
Decorator for registering the lower function that lowers a primitive op into a sequence of original ops.
Args:
op_type(str): The op name
Returns:
wrapper: Inner wrapper function
Examples:
.. code-block:: python
@REGISTER_PRIM2ORIG('tanh_p')
def tanh_prim2orig(op):
x, = get_input_var_list(op)
return paddle.tanh(x)
"""
if not isinstance(op_type, str):
raise TypeError(f'op_type must be str, but got {type(op_type)}.')
def wrapper(f):
def _lower(op, *args, **kwargs):
assert op.type == op_type, f'op.type should be equal to op_type, but op.type is {op.type} and op_type is {op_type}'
return f(op, *args, **kwargs)
_prim2orig.register(op_type, _lower)
return wrapper
def REGISTER_JVP(op_type):
"""
Decorator for registering the JVP function for a primitive op.
Args:
op_type(str): The op name
Returns:
wrapper: Inner wrapper function
Examples:
.. code-block:: python
@REGISTER_JVP('add_p')
def add_jvp(op, x_dot, y_dot):
return primops.add(x_dot, y_dot)
"""
if not isinstance(op_type, str):
raise TypeError(f'op_type must be str, but got {type(op_type)}.')
def wrapper(f):
def _jvp(op, *args, **kwargs):
assert op.type == op_type, f'op.type should be equal to op_type, but op.type is {op.type} and op_type is {op_type}'
return f(op, *args, **kwargs)
_primop_jvp.register(op_type, _jvp)
return f
return wrapper
def REGISTER_TRANSPOSE(op_type):
"""
Decorator for registering the transpose function for a primitive op
that denotes a linear operation in the forward AD graph.
Args:
op_type(str): The op name
Returns:
wrapper: Inner wrapper function
Examples:
.. code-block:: python
@REGISTER_TRANSPOSE('add_p')
def add_transpose(op, z_bar):
return z_bar, z_bar
"""
if not isinstance(op_type, str):
raise TypeError(f'op_type must be str, but got {type(op_type)}.')
def wrapper(f):
def _transpose(op, dot_checker, *args, **kwargs):
assert op.type == op_type, f'op.type should be equal to op_type, but op.type is {op.type} and op_type is {op_type}'
return f(op, dot_checker, *args, **kwargs)
_primop_transpose.register(op_type, _transpose)
return f
return wrapper
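Taken together, the REGISTER_* decorators above populate the module-level registries, and the lookup_* helpers are how the lowering and AD passes retrieve the registered rules later. A minimal sketch of that round trip using the Registry class directly (module path assumed from this diff):

from paddle.incubate.autograd.primreg import Registry  # path assumed

demo = Registry('demo')
demo.register('tanh_p', lambda x: x)       # roughly what REGISTER_FN does under the hood
assert demo.lookup('tanh_p') is not None   # roughly what lookup_fn('tanh_p') does
assert demo.lookup('unknown_p') is None    # unregistered op types simply return None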
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from .primreg import REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG, REGISTER_JVP, REGISTER_TRANSPOSE
from .primreg import (lookup_fn, lookup_orig2prim, lookup_prim2orig, lookup_jvp,
lookup_transpose, op_position_inputs, op_position_output)
from .primops import (neg, add, sub, mul, div, sqrt, tanh, reshape, broadcast,
transpose, split, concat, reduce, matmul, slice_select,
slice_assign, gather, scatter_add, fill_const, set_value)
from .utils import get_input_var_list, get_output_var_list, INT_DTYPE_2_STRING
def _orig2prim(op, *args):
_lowerrule = lookup_orig2prim(op.type)
return _lowerrule(op, *args)
def _prim2orig(op, *args):
_lowerrule = lookup_prim2orig(op.type)
return _lowerrule(op, *args)
def _jvp(op, *args):
_jvprule = lookup_jvp(op.type)
return _jvprule(op, *args)
def _transpose(op, dot_checker, *args):
_transposerule = lookup_transpose(op.type)
return _transposerule(op, dot_checker, *args)
def linear_jvp(op, *args, **kwargs):
fn = lookup_fn(op.type)
out_dot = fn(*args, **kwargs)
return out_dot
## Register orig2prim lower rules
"""
These original ops are fully supported:
elementwise_add
elementwise_sub
elementwise_mul
tanh
fill_zeros_like
sum
index_select
scale
assign
sqrt
These original ops are partially supported:
matmul_v2
reshape2
concat
slice
p_norm
"""
@REGISTER_ORIG2PRIM('elementwise_add')
def elementwise_add_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
if op.attr('Scale_x') - 1.0 > 1e-5:
scale_x = fill_const(
shape=x.shape, dtype=x.dtype, value=op.attr('Scale_x'))
x = mul(x, scale_x)
if op.attr('Scale_y') - 1.0 > 1e-5:
scale_y = fill_const(
shape=y.shape, dtype=y.dtype, value=op.attr('Scale_y'))
y = mul(y, scale_y)
z = add(x, y)
if op.attr('Scale_out') - 1.0 > 1e-5:
scale_out = fill_const(
shape=z.shape, dtype=z.dtype, value=op.attr('Scale_out'))
z = mul(z, scale_out)
return z
@REGISTER_ORIG2PRIM('elementwise_sub')
def elementwise_sub_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
if op.attr('Scale_x') - 1.0 > 1e-5:
scale_x = fill_const(
shape=x.shape, dtype=x.dtype, value=op.attr('Scale_x'))
x = mul(x, scale_x)
if op.attr('Scale_y') - 1.0 > 1e-5:
scale_y = fill_const(
shape=y.shape, dtype=y.dtype, value=op.attr('Scale_y'))
y = mul(y, scale_y)
z = sub(x, y)
if op.attr('Scale_out') - 1.0 > 1e-5:
scale_out = fill_const(
shape=z.shape, dtype=z.dtype, value=op.attr('Scale_out'))
z = mul(z, scale_out)
return z
@REGISTER_ORIG2PRIM('elementwise_mul')
def elementwise_mul_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
if op.attr('Scale_x') - 1.0 > 1e-5:
scale_x = fill_const(
shape=x.shape, dtype=x.dtype, value=op.attr('Scale_x'))
x = mul(x, scale_x)
if op.attr('Scale_y') - 1.0 > 1e-5:
scale_y = fill_const(
shape=y.shape, dtype=y.dtype, value=op.attr('Scale_y'))
y = mul(y, scale_y)
z = mul(x, y)
if op.attr('Scale_out') - 1.0 > 1e-5:
scale_out = fill_const(
shape=z.shape, dtype=z.dtype, value=op.attr('Scale_out'))
z = mul(z, scale_out)
return z
@REGISTER_ORIG2PRIM('tanh')
def tanh_orig2prim(op, x):
return tanh(x)
@REGISTER_ORIG2PRIM('fill_zeros_like')
def fill_zeros_like_orig2prim(op, x):
return fill_const(value=0.0, shape=x.shape, dtype=x.dtype)
@REGISTER_ORIG2PRIM('sum')
def sum_orig2prim(op, xs):
x0 = xs[0]
for x in xs[1:]:
x0 = add(x0, x)
return x0
@REGISTER_ORIG2PRIM('index_select')
def index_select_orig2prim(op, index_t, x):
return gather(x, indextensor=index_t, axis=op.attr('dim'))
@REGISTER_ORIG2PRIM('scale')
def scale_orig2prim(op, scale_t, x):
if scale_t is None:
scale_t = fill_const(
shape=x.shape, dtype=x.dtype, value=op.attr('scale'))
bias_t = fill_const(shape=x.shape, dtype=x.dtype, value=op.attr('bias'))
if op.attr('bias_after_scale'):
return add(mul(x, scale_t), bias_t)
else:
return mul(add(x, bias_t), scale_t)
@REGISTER_ORIG2PRIM('assign')
def assign_orig2prim(op, x):
zero_t = fill_const(shape=x.shape, dtype=x.dtype, value=0.0)
return add(x, zero_t)
@REGISTER_ORIG2PRIM('sqrt')
def sqrt_orig2prim(op, x):
return sqrt(x)
@REGISTER_ORIG2PRIM('matmul_v2')
def matmul_v2_orig2prim(op, x, y):
def trans(shape):
ret = [i for i in range(len(shape))]
ret[-1], ret[-2] = ret[-2], ret[-1]
return ret
assert len(x.shape) < 4 and len(
y.shape) < 4, 'Do not support multi batchsize dimensions currently.'
if len(x.shape) == 1:
x = broadcast(x, shape=[1, x.shape[0]])
if len(y.shape) == 1:
y = broadcast(y, shape=[y.shape[0], 1])
if op.attr('trans_x'):
x = transpose(x, axis=trans(x.shape))
if op.attr('trans_y'):
y = transpose(y, axis=trans(y.shape))
return matmul(x, y)
## NOTE(lml): The second output of reshape2, Xshape, is only used in reshape2_grad and is meaningless in the new autograd mechanism, so we use a zero tensor instead.
@REGISTER_ORIG2PRIM('reshape2')
def reshape2_orig2prim(op, shape_t, shape_tl, x):
assert shape_t is None, 'Can not lower reshape2 into prim ops with shapetensor.'
assert shape_tl is None, 'Can not lower reshape2 into prim ops with shapetensorlist.'
y, xshape = get_output_var_list(op)
return reshape(
x, shape=y.shape), fill_const(
shape=xshape.shape, dtype=xshape.dtype, value=0.0)
@REGISTER_ORIG2PRIM('concat')
def concat_orig2prim(op, axis_t, xs):
assert axis_t is None, 'Can not lower concat into prim ops with axistensor.'
return concat(xs, axis=op.attr('axis'))
@REGISTER_ORIG2PRIM('slice')
def slice_orig2prim(op, ends_t, ends_tl, x, starts_t, starts_tl):
assert starts_t is None, 'Can not lower slice into prim ops with startstensor.'
assert ends_t is None, 'Can not lower slice into prim ops with endstensor.'
assert starts_tl is None, 'Can not lower slice into prim ops with startstensorlist.'
assert ends_tl is None, 'Can not lower slice into prim ops with endstensorlist.'
starts = op.attr('starts')
ends = op.attr('ends')
strides = [1 for _ in starts]
axis = op.attr('axes')
y = slice_select(x, starts=starts, ends=ends, strides=strides, axis=axis)
if op.attr('decrease_axis'):
y = reshape(y, shape=get_output_var_list(op)[0].shape)
return y
@REGISTER_ORIG2PRIM('p_norm')
def p_norm_orig2prim(op, x):
def num_el(shape):
n = 1
for s in shape:
n = n * s
return n
assert op.attr(
'asvector'), 'Only support lower pnorm when asvector=True currently'
if len(x.shape) > 1:
x = reshape(x, shape=[num_el(x.shape)])
if abs(op.attr('porder') - 2.0) < 1e-5:
return sqrt(reduce(mul(x, x), axis=[0]))
elif abs(op.attr('porder') - 1.0) < 1e-5:
return reduce(sqrt(mul(x, x)), axis=[0])
else:
raise RuntimeError('Only support lower l2/l1 norm currently')
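# In formula terms, the lowering above computes, for a flattened x with asvector=True:
#   porder=2: ||x||_2 = sqrt(sum_i x_i * x_i)  ->  sqrt(reduce(mul(x, x), axis=[0]))
#   porder=1: ||x||_1 = sum_i |x_i|            ->  reduce(sqrt(mul(x, x)), axis=[0])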
## Register prim2orig lower rules
@REGISTER_PRIM2ORIG('add_p')
def add_prim2orig(op, x, y):
return paddle.add(x, y)
@REGISTER_PRIM2ORIG('sub_p')
def sub_prim2orig(op, x, y):
return paddle.subtract(x, y)
@REGISTER_PRIM2ORIG('mul_p')
def mul_prim2orig(op, x, y):
return paddle.multiply(x, y)
@REGISTER_PRIM2ORIG('div_p')
def div_prim2orig(op, x, y):
return paddle.divide(x, y)
@REGISTER_PRIM2ORIG('sqrt_p')
def sqrt_prim2orig(op, x):
return paddle.sqrt(x)
@REGISTER_PRIM2ORIG('tanh_p')
def tanh_prim2orig(op, x):
return paddle.tanh(x)
@REGISTER_PRIM2ORIG('reshape_p')
def reshape_prim2orig(op, x):
return paddle.reshape(x, shape=op.attr('shape'))
@REGISTER_PRIM2ORIG('broadcast_p')
def broadcast_prim2orig(op, x):
return paddle.broadcast_to(x, shape=op.attr('shape'))
@REGISTER_PRIM2ORIG('transpose_p')
def transpose_prim2orig(op, x):
return paddle.transpose(x, perm=op.attr('axis'))
@REGISTER_PRIM2ORIG('split_p')
def split_prim2orig(op, x):
num_or_sections = op.attr('num_or_sections')
if len(num_or_sections) == 1:
num_or_sections = num_or_sections[0]
return paddle.split(
x, num_or_sections=num_or_sections, axis=op.attr('axis'))
@REGISTER_PRIM2ORIG('concat_p')
def concat_prim2orig(op, xs):
return paddle.concat(xs, axis=op.attr('axis'))
@REGISTER_PRIM2ORIG('reduce_p')
def reduce_prim2orig(op, x):
return paddle.sum(x, axis=op.attr('axis'), keepdim=op.attr('keepdim'))
@REGISTER_PRIM2ORIG('matmul_p')
def matmul_prim2orig(op, x, y):
return paddle.matmul(x, y)
@REGISTER_PRIM2ORIG('slice_select_p')
def slice_select_prim2orig(op, x):
return paddle.strided_slice(
x,
axes=op.attr('axis'),
starts=op.attr('starts'),
ends=op.attr('ends'),
strides=op.attr('strides'))
@REGISTER_PRIM2ORIG('slice_assign_p')
def slice_assign_prim2orig(op, x, y):
x_copy = paddle.assign(x)
return set_value(
x_copy,
y,
axis=op.attr('axis'),
starts=op.attr('starts'),
ends=op.attr('ends'),
strides=op.attr('strides'),
out=x_copy)
@REGISTER_PRIM2ORIG('gather_p')
def gather_prim2orig(op, index_t, x):
return paddle.gather(x, index_t, axis=op.attr('axis'))
@REGISTER_PRIM2ORIG('scatter_add_p')
def scatter_add_prim2orig(op, index_t, x, y):
assert op.attr('axis') == 0, 'Only support axis==0 currently'
zeros = paddle.zeros_like(x=x, dtype=x.dtype)
tmp = paddle.scatter(x=zeros, index=index_t, updates=y, overwrite=False)
return paddle.add(x, tmp)
@REGISTER_PRIM2ORIG('fill_constant_p')
def fill_constant_prim2orig(op):
return paddle.full(
shape=op.attr('shape'),
fill_value=op.attr('value'),
dtype=INT_DTYPE_2_STRING[op.attr('dtype')])
## Register linearize rules
@REGISTER_JVP('add_p')
def add_jvp(op, x_dot, y_dot):
if x_dot is None:
return y_dot
elif y_dot is None:
return x_dot
else:
return linear_jvp(op, x_dot, y_dot)
@REGISTER_JVP('sub_p')
def sub_jvp(op, x_dot, y_dot):
if x_dot is None:
return neg(y_dot)
elif y_dot is None:
return x_dot
else:
return linear_jvp(op, x_dot, y_dot)
@REGISTER_JVP('mul_p')
def mul_jvp(op, x_dot, y_dot):
if x_dot is None and y_dot is None:
return None
x, y = op_position_inputs(op)
if x_dot is None:
return mul(x, y_dot)
elif y_dot is None:
return mul(x_dot, y)
else:
t1, t2 = mul(x_dot, y), mul(x, y_dot)
z_dot = add(t1, t2)
return z_dot
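# The branches above follow the product rule: (x * y)_dot = x_dot * y + x * y_dot.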
@REGISTER_JVP('div_p')
def div_jvp(op, x_dot, y_dot):
if x_dot is None and y_dot is None:
return None
x, y = op_position_inputs(op)
if y_dot is None:
return div(x_dot, y)
elif x_dot is None:
return neg(div(mul(x, y_dot), mul(y, y)))
else:
t1 = div(x_dot, y)
t2 = div(mul(x, y_dot), mul(y, y))
return sub(t1, t2)
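# The branches above follow the quotient rule: (x / y)_dot = x_dot / y - x * y_dot / y**2.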
@REGISTER_JVP('sqrt_p')
def sqrt_jvp(op, x_dot):
if x_dot is None:
return None
y = op_position_output(op)
c2 = fill_const(value=2.0, shape=y.shape, dtype=y.dtype)
y_dot = div(x_dot, mul(c2, y))
return y_dot
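# Since d/dx sqrt(x) = 1 / (2 * sqrt(x)), the rule above reuses y = sqrt(x): y_dot = x_dot / (2 * y).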
@REGISTER_JVP('tanh_p')
def tanh_jvp(op, x_dot):
if x_dot is None:
return None
y = op_position_output(op)
c1 = fill_const(value=1.0, shape=y.shape, dtype=y.dtype)
y_dot = mul(x_dot, sub(c1, mul(y, y)))
return y_dot
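# Since d/dx tanh(x) = 1 - tanh(x)**2, the rule above reuses y = tanh(x): y_dot = x_dot * (1 - y * y).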
@REGISTER_JVP('reshape_p')
def reshape_jvp(op, x_dot):
if x_dot is None:
return None
shape = op.attr('shape')
return linear_jvp(op, x_dot, shape=shape)
@REGISTER_JVP('broadcast_p')
def broadcast_jvp(op, x_dot):
if x_dot is None:
return None
shape = op.attr('shape')
return linear_jvp(op, x_dot, shape=shape)
@REGISTER_JVP('transpose_p')
def transpose_jvp(op, x_dot):
if x_dot is None:
return None
axis = op.attr('axis')
return linear_jvp(op, x_dot, axis=axis)
@REGISTER_JVP('split_p')
def split_jvp(op, x_dot):
if x_dot is None:
return None
num_or_sections = op.attr('num_or_sections')
axis = op.attr('axis')
return linear_jvp(op, x_dot, num_or_sections=num_or_sections, axis=axis)
@REGISTER_JVP('concat_p')
def concat_jvp(op, xs_dot):
if xs_dot is None:
return None
axis = op.attr('axis')
return linear_jvp(op, xs_dot, axis=axis)
@REGISTER_JVP('reduce_p')
def reduce_jvp(op, x_dot):
if x_dot is None:
return None
axis = op.attr('axis')
keepdim = op.attr('keepdim')
return linear_jvp(op, x_dot, axis=axis, keepdim=keepdim)
@REGISTER_JVP('matmul_p')
def matmul_jvp(op, x_dot, y_dot):
if x_dot is None and y_dot is None:
return None
x, y = op_position_inputs(op)
if x_dot is None:
return matmul(x, y_dot)
elif y_dot is None:
return matmul(x_dot, y)
else:
t1 = matmul(x, y_dot)
t2 = matmul(x_dot, y)
return add(t1, t2)
@REGISTER_JVP('slice_select_p')
def slice_select_jvp(op, x_dot):
if x_dot is None:
return x_dot
axis = op.attr('axis')
starts = op.attr('starts')
ends = op.attr('ends')
strides = op.attr('strides')
return linear_jvp(
op, x_dot, axis=axis, starts=starts, ends=ends, strides=strides)
@REGISTER_JVP('slice_assign_p')
def slice_assign_jvp(op, x_dot, y_dot):
if x_dot is None:
assert y_dot is None, 'y_dot must be None.'
return None
else:
assert y_dot is not None, 'y_dot should not be None.'
axis = op.attr('axis')
starts = op.attr('starts')
ends = op.attr('ends')
strides = op.attr('strides')
return linear_jvp(
op, x_dot, y_dot, axis=axis, starts=starts, ends=ends, strides=strides)
@REGISTER_JVP('gather_p')
def gather_jvp(op, x_dot, indextensor):
if x_dot is None:
return None
_, indextensor = op_position_inputs(op)
axis = op.attr('axis')
return linear_jvp(op, x_dot, indextensor, axis=axis)
@REGISTER_JVP('scatter_add_p')
def scatter_add_jvp(op, x_dot, y_dot):
if x_dot is None:
return None
_, _, indextensor = op_position_inputs(op)
axis = op.attr('axis')
return linear_jvp(op, x_dot, y_dot, indextensor, axis=axis)
## Register transpose rules
@REGISTER_TRANSPOSE('add_p')
def add_transpose(op, check_dot, z_bar):
x, y = op_position_inputs(op)
assert check_dot(x) or check_dot(y), (
f'(check_dot(x) or check_dot(y)) must be True, '
f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.')
x_bar = z_bar if check_dot(x) else None
y_bar = z_bar if check_dot(y) else None
return x_bar, y_bar
@REGISTER_TRANSPOSE('sub_p')
def sub_transpose(op, check_dot, z_bar):
x, y = op_position_inputs(op)
assert check_dot(x) or check_dot(y), (
f'(check_dot(x) or check_dot(y)) must be True, '
f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.')
x_bar = z_bar if check_dot(x) else None
y_bar = neg(z_bar) if check_dot(y) else None
return x_bar, y_bar
@REGISTER_TRANSPOSE('mul_p')
def mul_transpose(op, check_dot, z_bar):
x, y = op_position_inputs(op)
assert check_dot(x) ^ check_dot(y), (
f'(check_dot(x) ^ check_dot(y)) must be True, '
f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.')
if check_dot(x):
return mul(z_bar, y), None
else:
return None, mul(x, z_bar)
@REGISTER_TRANSPOSE('div_p')
def div_transpose(op, check_dot, z_bar):
x, y = op_position_inputs(op)
assert not check_dot(y), 'check_dot(y) must be False'
x_bar = div(z_bar, y) if check_dot(x) else None
return x_bar, None
@REGISTER_TRANSPOSE('reshape_p')
def reshape_transpose(op, check_dot, y_bar):
x, = op_position_inputs(op)
assert check_dot(x), 'check_dot(x) must be True'
return reshape(y_bar, shape=x.shape)
@REGISTER_TRANSPOSE('broadcast_p')
def broadcast_transpose(op, check_dot, y_bar):
x, = op_position_inputs(op)
assert check_dot(x), 'check_dot(x) must be True'
bat = len(y_bar.shape) - len(x.shape)
axis = list(range(bat))
keepdim = [(bat + i) for i, s in enumerate(x.shape) if s == 1]
axis += keepdim
# TODO: Change it. keepdim boolean
out = reduce(y_bar, axis=axis, keepdim=False)
return reshape(out, x.shape)
@REGISTER_TRANSPOSE('transpose_p')
def transpose_transpose(op, check_dot, y_bar):
x, = op_position_inputs(op)
assert check_dot(x), 'check_dot(x) must be True'
axis = op.attr('axis')
reordered = sorted((k, i) for i, k in enumerate(axis))
axis = [i for k, i in reordered]
return transpose(y_bar, axis=axis)
@REGISTER_TRANSPOSE('split_p')
def split_transpose(op, check_dot, ys_bar):
x, = op_position_inputs(op)
assert check_dot(x), 'check_dot(x) must be True'
return concat(ys_bar, axis=op.attr('axis'))
@REGISTER_TRANSPOSE('concat_p')
def concat_transpose(op, check_dot, y_bar):
xs, = op_position_inputs(op)
for x in xs:
assert check_dot(x), 'check_dot(x) must be True'
axis = op.attr('axis')
sections = [x.shape[axis] for x in xs]
return split(y_bar, num_or_sections=sections, axis=axis)
@REGISTER_TRANSPOSE('reduce_p')
def reduce_transpose(op, check_dot, y_bar):
x, = op_position_inputs(op)
assert check_dot(x), 'check_dot(x) must be True'
axes = op.attr('axis')
shape = tuple(1 if i in axes else size for i, size in enumerate(x.shape))
t = reshape(y_bar, shape=shape)
return broadcast(t, shape=x.shape)
@REGISTER_TRANSPOSE('matmul_p')
def matmul_transpose(op, check_dot, z_bar):
x, y = op_position_inputs(op)
assert check_dot(x) ^ check_dot(y), (
f'(check_dot(x) ^ check_dot(y)) must be True, '
f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.')
# TODO: replace it. this is hacky
axis = [1, 0] if len(x.shape) == 2 else [0, 2, 1]
if check_dot(x):
return matmul(z_bar, transpose(y, axis=axis)), None
else:
return None, matmul(transpose(x, axis=axis), z_bar)
@REGISTER_TRANSPOSE('slice_select_p')
def slice_select_transpose(op, check_dot, y_bar):
x, = op_position_inputs(op)
assert check_dot(x), 'check_dot(x) must be True'
zeros = fill_const(value=0.0, shape=x.shape, dtype=x.dtype)
axis = op.attr('axis')
starts = op.attr('starts')
ends = op.attr('ends')
strides = op.attr('strides')
return slice_assign(
zeros, y_bar, axis=axis, starts=starts, ends=ends, strides=strides)
@REGISTER_TRANSPOSE('slice_assign_p')
def slice_assign_transpose(op, check_dot, z_bar):
x, y = op_position_inputs(op)
assert check_dot(x) and check_dot(y), (
f'(check_dot(x) and check_dot(y)) must be True, '
f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.')
zeros = fill_const(value=0.0, shape=y.shape, dtype=y.dtype)
axis = op.attr('axis')
starts = op.attr('starts')
ends = op.attr('ends')
strides = op.attr('strides')
x_bar = slice_assign(
z_bar, zeros, axis=axis, starts=starts, ends=ends, strides=strides)
y_bar = slice_select(
z_bar, axis=axis, starts=starts, ends=ends, strides=strides)
return x_bar, y_bar
@REGISTER_TRANSPOSE('gather_p')
def gather_transpose(op, check_dot, y_bar):
x, indextensor = op_position_inputs(op)
assert check_dot(x), 'check_dot(x) must be True'
axis = op.attr('axis')
zeros = fill_const(0.0, x.shape, x.dtype)
x_bar = scatter_add(zeros, y_bar, indextensor, axis=axis)
indextensor_bar = None
return x_bar, indextensor_bar
@REGISTER_TRANSPOSE('scatter_add_p')
def scatter_add_transpose(op, check_dot, z_bar):
x, y, indextensor = op_position_inputs(op)
assert check_dot(x) and check_dot(y), (
f'(check_dot(x) and check_dot(y)) must be True, '
f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.')
axis = op.attr('axis')
zeros = fill_const(value=0.0, shape=y.shape, dtype=y.dtype)
x_bar = scatter_add(z_bar, zeros, indextensor, axis=axis)
y_bar = gather(z_bar, indextensor, axis=axis)
indextensor_bar = None
return x_bar, y_bar, indextensor_bar
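As I read the rules above, a transpose rule for an op that is linear in its dot inputs is tied to its JVP rule by the usual transposition identity: <z_bar, f(x_dot)> equals <f_T(z_bar), x_dot>. A small numeric sanity check of that identity for mul_p with a constant y, in plain NumPy and for illustration only:

import numpy as np

rng = np.random.default_rng(0)
x_dot, y, z_bar = rng.standard_normal((3, 4))
lhs = np.sum(z_bar * (x_dot * y))   # <z_bar, mul_jvp(x_dot)> with y held constant
rhs = np.sum((z_bar * y) * x_dot)   # <mul_transpose(z_bar), x_dot>
assert np.allclose(lhs, rhs)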
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid import framework as framework
from paddle.fluid.framework import default_main_program
from paddle.fluid.framework import Operator
from paddle import compat as cpt
from .primops import fill_const, add
from .primreg import op_position_inputs, op_position_output, lookup_orig2prim, lookup_prim2orig
from .primrules import _orig2prim, _prim2orig, _jvp, _transpose
from .utils import get_input_var_list, get_output_var_list, to_tensors, flatten, flatten_and_remove_none
from collections import OrderedDict
def topo_path(xs, ys, block=None):
""" Returns the list of ops on the path from `xs` to `ys` in topological
order.
TODO(Tongxin): supporting control flow and nested blocks.
Args:
xs: a list|tuple of vars as source
ys: a list|tuple of vars as sink
block: the program block containing the path, optional
Returns:
(path, unused_xs, unreached_ys): a tuple comprised of the resulting op
path, the unused variables in `xs`, and the unreached variables in `ys`
"""
if block is None:
block = default_main_program().current_block()
path = []
backpath = []
reached_vars = OrderedDict()
used_vars = OrderedDict()
# Initialize reached vars
for x in xs:
assert x is None or x.block == block, f'x is not None and x.block != block'
reached_vars[id(x)] = x
# Reaching test, returning whether an op is reached from the given input
reaching = lambda op: any(id(v) in reached_vars for v in flatten_and_remove_none(get_input_var_list(op)))
# block.ops are supposedly in the order that preserves correct data
# dependence.
# Forward pass to identify all reached variables and ops
for op in block.ops:
if reaching(op):
path.append(op)
for var in flatten_and_remove_none(get_output_var_list(op)):
reached_vars[id(var)] = var
used_vars = OrderedDict((id(y), y) for y in ys if id(y) in reached_vars)
back_reaching = lambda op: any(id(out) in used_vars for out in flatten_and_remove_none(get_output_var_list(op)))
# Backward pass to find all used variables
for op in reversed(path):
if back_reaching(op):
backpath.append(op)
for var in flatten_and_remove_none(get_input_var_list(op)):
used_vars[id(var)] = var
unused_xs = [x for x in xs if id(x) not in used_vars]
unreached_ys = [y for y in ys if id(y) not in reached_vars]
return list(reversed(backpath)), unused_xs, unreached_ys
def output_vars_on_path(path):
""" Returns the output variables of all the ops on the path from `xs`
to `ys`.
Args:
path: a list of ops on which to find the output variables
Returns:
vars: the output vars
"""
vars = OrderedDict()
for op in path:
for out in flatten_and_remove_none(get_output_var_list(op)):
vars[id(out)] = out
return vars
class VarMap(object):
""" A general map data structure for linking variables to variables.
An example is linking variables to their gradients.
"""
__slots__ = ['name', 'varset', 'tab']
def __init__(self, name, varset):
self.name = name
self.varset = varset
self.tab = OrderedDict()
def add(self, key_var, value_var):
self.tab[id(key_var)] = id(value_var)
def add_rec(self, key_vars, value_vars):
if value_vars is None:
return
if isinstance(key_vars, paddle.fluid.framework.Variable):
if not isinstance(value_vars, paddle.fluid.framework.Variable):
raise TypeError(
f'value_vars must be Variable, but got {type(value_vars)}')
self.tab[id(key_vars)] = id(value_vars)
else:
assert len(key_vars) == len(value_vars), (
f'len(key_vars) should be equal to len(value_vars), '
f'but len(key_vars)={len(key_vars)} and len(value_vars)={len(value_vars)}.'
)
for key_var, value_var in zip(key_vars, value_vars):
self.add_rec(key_var, value_var)
def lookup(self, key_var):
value_id = self.tab.get(id(key_var))
if value_id is not None:
return self.varset.get(value_id)
else:
return None
def delete(self, key_var):
varid = id(key_var)
if varid in self.tab:
del self.tab[id(key_var)]
def delete_keyvars(self, key_vars):
for var in key_vars:
varid = id(var)
if varid in self.tab:
del self.tab[varid]
def delete_valuevars(self, value_vars):
ids = [id(v) for v in value_vars]
keys = [k for k, v in self.tab.items() if v in ids]
for k in keys:
del self.tab[k]
def contain_var(self, key_var):
return self.tab.__contains__(id(key_var))
def contain_value(self, value_var):
return id(value_var) in self.tab.values()
class Transform(object):
""" An object that maintains the state of transformations applied to a
primitive program. """
def __init__(self, block):
self.block = block
self.vars = self.init_vars(block)
self.var2dot = VarMap('var2dot', self.vars)
self.dot2bar = VarMap('dot2var', self.vars)
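# var2dot maps each primal variable to its forward gradient (its "dot");
# dot2bar maps each dot variable to its reverse gradient (its "bar").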
def init_vars(self, block):
vars = OrderedDict()
for _, var in block.vars.items():
vars[id(var)] = var
return vars
def add_vars(self, new_vars):
self.vars.update({id(v): v for v in new_vars if v is not None})
def add_vars_rec(self, new_vars):
if new_vars is None:
return
if isinstance(new_vars, paddle.fluid.framework.Variable):
self.vars.update({id(new_vars): new_vars})
return
if not isinstance(new_vars, list):
raise TypeError(f'new_vars must be list, but got {type(new_vars)}')
for var in new_vars:
self.add_vars_rec(var)
def erase_ops(self, ordered_indexes):
block = self.block
for op_index in reversed(ordered_indexes):
block.desc._remove_op(op_index, op_index + 1)
# remove from block.ops
for op_index in reversed(ordered_indexes):
del block.ops[op_index]
block._sync_with_cpp()
def erase_dots(self, vars_to_erase):
for var in vars_to_erase:
if id(var) in self.vars:
del self.vars[id(var)]
self.dot2bar.delete_keyvars(vars_to_erase)
self.var2dot.delete_valuevars(vars_to_erase)
block = self.block
for var in vars_to_erase:
name = var.name
block.desc._remove_var(cpt.to_bytes(name))
del block.vars[name]
block._sync_with_cpp()
def var2dot_rec(self, vars):
""" Lookup var2dot recursively."""
if isinstance(vars, paddle.fluid.framework.Variable):
dot = self.var2dot.lookup(vars)
return dot
dots = [self.var2dot_rec(var) for var in vars]
return dots
def dot2bar_rec(self, dots):
if isinstance(dots, paddle.fluid.framework.Variable):
bar = self.dot2bar.lookup(dots)
assert bar is not None, 'bar must be not None'
return bar
bars = [self.dot2bar_rec(dot) for dot in dots]
return bars
def linearize(self, xs, ys, xs_dot=None):
""" Performs the linearization transform, a.k.a, forward mode AD
transform, on a primitive lowered program.
Args:
xs: a list of input variables
ys: a list of output variables
xs_dot: optional, a list of gradient input variables. The list size
must be equal to `len(xs)`. The shape and dtype of each element
must be the same as in `xs`
Returns:
(xs_dot, ys_dot): a tuple of two lists. `xs_dot` is the list of
gradient inputs of the resulting linearized program. `ys_dot` is
the list of gradient outputs of the resulting linearized program
"""
if xs_dot is None:
xs_dot = [fill_const(1.0, shape=x.shape, dtype=x.dtype) for x in xs]
self.add_vars(xs_dot)
else:
assert len(xs) == len(xs_dot), (
f'len(xs) should be equal to len(xs_dot), '
f'but len(xs)={len(xs)} and len(xs_dot)={len(xs_dot)}')
for x, dot in zip(xs, xs_dot):
assert x.dtype == dot.dtype, (
f'x.dtype should be equal to dot.dtype, '
f'but x.dtype={x.dtype} and dot.dtype={dot.dtype}')
assert x.shape == dot.shape, (
f'x.shape should be equal to dot.shape, '
f'but x.shape={x.shape} and dot.shape={dot.shape}')
self.var2dot.add(x, dot)
path, unused_xs, _ = topo_path(xs, ys, self.block)
# No need to track unused inputs
for x in unused_xs:
self.var2dot.delete(x)
for op in path:
# An input var may not be on the input-output path, which implies
# there may be None's in `ins_dot`. In this case we place
# the original input in the position of the otherwise forward
# gradient.
ins = op_position_inputs(op)
jvp_ins = self.var2dot_rec(ins)
# apply op's forward ad rule
outs_dot = _jvp(op, *jvp_ins)
self.add_vars_rec(outs_dot)
outs = op_position_output(op)
self.var2dot.add_rec(outs, outs_dot)
ys_dot = [self.var2dot.lookup(y) for y in ys]
return xs_dot, ys_dot
def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False):
""" Performs the transpose transform, a.k.a, reverse mode AD
transform, on a linearized primitive program.
Note, `transpose` is supposed to be used in couple with `linearize`.
Args:
ys_dot: a list of outputs of the linearized program.
xs_dot: a list of inputs of the linearized program.
ys_bar: optional, a list of inputs of the resulting transposed
program. The list size must be equal to `len(ys_dot)`. The shape
and dtype of each element must be the same as in `ys_dot`
Returns:
(ys_bar, xs_bar): a tuple of two lists. `ys_bar` is the list of
inputs of the resulting transposed program. `xs_bar` is
the list of outputs of the resulting transposed program
"""
assert all(v is not None for v in xs_dot), f'`xs_dot` includes None.'
assert all(v is not None for v in ys_dot), f'`ys_dot` includes None.'
if ys_bar is None:
ys_bar = []
for y in ys_dot:
ys_bar.append(fill_const(1.0, shape=y.shape, dtype=y.dtype))
self.add_vars(ys_bar)
else:
assert len(ys_dot) == len(ys_bar), (
f'len(ys_dot) should be equal to len(ys_bar), '
f'but len(ys_dot)={len(ys_dot)} and len(ys_bar)={len(ys_bar)}')
for y_dot, y_bar in zip(ys_dot, ys_bar):
assert y_dot.shape == y_bar.shape, (
f'y_dot.shape should be equal to y_bar.shape, '
f'but y_dot.shape={y_dot.shape} and y_bar.shape={y_bar.shape}'
)
assert y_dot.dtype == y_bar.dtype, (
f'y_dot.dtype should be equal to y_bar.dtype, '
f'but y_dot.dtype={y_dot.dtype} and y_bar.dtype={y_bar.dtype}'
)
for dot, bar in zip(ys_dot, ys_bar):
self.dot2bar.add(dot, bar)
# find all the relevant forward gradients
path, unused_xs_dot, _ = topo_path(xs_dot, ys_dot, self.block)
# No need to track unused inputs
for dot in unused_xs_dot:
self.dot2bar.delete(dot)
dotvars = output_vars_on_path(path)
dotvars.update((id(var), var) for var in xs_dot)
is_dot = lambda v: id(v) in dotvars
for op in reversed(path):
out = op_position_output(op)
out_bar_rec = self.dot2bar_rec(out)
ins_bar_rec = _transpose(op, is_dot, out_bar_rec)
# TODO(Tongxin): this is hacky. Tuple implies the Transpose rule
# returns multiple entities. There should be better ways to handle
# outputs.
if isinstance(ins_bar_rec, tuple):
ins_bar_rec = list(ins_bar_rec)
else:
ins_bar_rec = [ins_bar_rec]
self.add_vars_rec(ins_bar_rec)
ins_bar = flatten(ins_bar_rec)
ins = flatten(op_position_inputs(op))
assert len(ins) == len(ins_bar), (
f'len(ins) should be equal to len(ins_bar), '
f'but len(ins)={len(ins)} and len(ins_bar)={len(ins_bar)}')
for dot, bar in zip(ins, ins_bar):
if bar is not None:
# aggregate gradient
grad = self.dot2bar.lookup(dot)
if grad is None:
self.dot2bar.add(dot, bar)
else:
grad = add(grad, bar)
self.add_vars([grad])
self.dot2bar.add(dot, grad)
xs_bar = [self.dot2bar.lookup(x) for x in xs_dot]
if not retain_fwd and len(path) > 0:
vars_to_remove = set()
for op in path:
vars_to_remove.update(
flatten_and_remove_none(get_output_var_list(op)))
op_indexes = []
block = self.block
for i, op in enumerate(block.ops):
if op in path:
op_indexes.append(i)
path.pop(0)
if len(path) == 0:
break
self.erase_ops(op_indexes)
self.erase_dots(vars_to_remove)
return ys_bar, xs_bar
def _lower(block, reverse):
# Some functions which are only used in _lower.
def bind(args, to_bind, value_table):
for i in range(len(args)):
if isinstance(args[i], list):
bind(args[i], to_bind, value_table)
elif args[i] is not None and args[i].name in to_bind:
args[i] = value_table[to_bind[args[i].name]]
def bind_name(names, to_bind):
return_list = []
for name in names:
if isinstance(name, list):
return_list.append(bind_name(name, to_bind))
else:
return_list.append(to_bind[name] if name in to_bind else name)
return return_list
def expand_nested_list(xs):
return_list = []
for x in xs:
if isinstance(x, list):
return_list = return_list + expand_nested_list(x)
else:
return_list.append(x)
return return_list
# Step1: Do some preparatory work for lower
lower_fn = _prim2orig if reverse else _orig2prim
lookup_fn = lookup_prim2orig if reverse else lookup_orig2prim
if block is None:
program = default_main_program()
assert program.num_blocks == 1, "The lower transform is designed to process only one block."
block = program.current_block()
value_table = {}
to_bind = {}
to_bind_rev = {}
for var in block.desc.all_vars():
value_table[var.name()] = block.var(var.name())
ops_to_remove = []
vars_to_remove = set()
# Step2: Process all ops in the target block
for op_idx in range(len(block.ops)):
op = block.ops[op_idx]
ops_to_remove.append(op_idx)
if lookup_fn(op.type) is not None:
input_args = get_input_var_list(op)
bind(input_args, to_bind, value_table)
for orig_out, new_out in zip(
expand_nested_list(get_output_var_list(op)),
expand_nested_list(to_tensors(lower_fn(op, *input_args)))):
assert not (orig_out is None) ^ (
new_out is None), "orig_out and new_out should match."
vars_to_remove.add(new_out.name)
value_table[new_out.name] = new_out
to_bind[orig_out.name] = new_out.name
to_bind_rev[new_out.name] = orig_out.name
else:
inputs = {}
for i in range(len(op.input_names)):
inputs[op.input_names[i]] = bind_name(
op.input(op.input_names[i]), to_bind)
outputs = {}
for i in range(len(op.output_names)):
outputs[op.output_names[i]] = op.output(op.output_names[i])
attrs = {}
for name in sorted(op.attr_names):
attrs[name] = op.attr(name)
from paddle.fluid.dygraph.base import param_guard
new_op_desc = block.desc.append_op()
with param_guard(inputs), param_guard(outputs):
op = Operator(
block=block,
desc=new_op_desc,
type=op.type,
inputs=inputs,
outputs=outputs,
attrs=attrs)
block.ops.append(op)
# Step3: Do some post-processing work
for op_idx in reversed(ops_to_remove):
block.desc._remove_op(op_idx, op_idx + 1)
del block.ops[op_idx]
block._sync_with_cpp()
for op_idx in range(len(block.ops)):
op = block.ops[op_idx]
for in_name in op.input_arg_names:
if in_name in to_bind_rev:
op._rename_input(in_name, to_bind_rev[in_name])
for out_name in op.output_arg_names:
if out_name in to_bind_rev:
op._rename_output(out_name, to_bind_rev[out_name])
for var_name in sorted(vars_to_remove):
assert var_name in to_bind_rev, 'var_name "{}" is not in to_bind_rev.'.format(
var_name)
if var_name != to_bind_rev[var_name]:
block.desc._remove_var(cpt.to_bytes(var_name))
del block.vars[var_name]
block._sync_with_cpp()
@framework.static_only
def orig2prim(block=None):
"""
.. note::
**This API is ONLY available in the static mode.**
All operators in the target block are processed as follows.
If it is an original operator, it will be transformed into
one or a series of automatic differential basic operators with
equivalent function.
Args:
block(paddle.fluid.framework.Block|None, optional): The
target block to process on. Default None, and will
process on the current block of main program.
Returns:
None
"""
_lower(block, reverse=False)
@framework.static_only
def prim2orig(block=None):
"""
.. note::
**ONLY available in the static mode.**
All operators in the target block are processed as follows.
If it is an automatic differential basic operator, it will be
transformed into one or a series of original operators with
equivalent function to support execution.
Args:
block(paddle.fluid.framework.Block|None, optional): The
target block to process on. Default None, and will
process on the current block of main program.
Examples:
.. code-block:: python
import paddle
from paddle.incubate.autograd import enable_prim, prim_enabled, prim2orig
paddle.enable_static()
enable_prim()
x = paddle.ones(shape=[2, 2], dtype='float32')
x.stop_gradient = False
y = x * x
dy_dx = paddle.static.gradients(y, x)
if prim_enabled():
prim2orig()
"""
_lower(block, reverse=True)
def _gradients(ys, xs, ys_bar=None):
""" A drop-in replacement of paddle.gradients but instead computing
on primitive ops.
Args:
ys: the target tensor or tensors
xs: the input tensor or tensors
ys_bar: the optional gradient tensors of `ys`
Returns:
xs_bar: a list of gradients of the input `xs`
"""
ys, xs = to_tensors(ys), to_tensors(xs)
block = ys[0].block
# TODO(Tongxin) without any prior knowledge about whether the program
# is completely lowered to primitive ops, it's mandatory to run the lowering
# pass once and again. This is obviously inefficient and needs to be
# optimized.
orig2prim(block)
ad = Transform(block)
xs_dot, ys_dot = ad.linearize(xs, ys)
if any(var is None for var in ys_dot):
assert False, f'Gradients cannot be computed. The given output `ys` does not depend on input `xs`.'
ys_bar, xs_bar = ad.transpose(ys_dot, xs_dot, ys_bar)
# remove xs_dot and their constructor ops
op_indexes = []
for var in xs_dot:
if var is not None:
op_index = block.ops.index(var.op)
assert op_index >= 0, f'op_index should be greater than or equal to 0, but op_index={op_index}.'
op_indexes.append(op_index)
ad.erase_ops(sorted(op_indexes))
ad.erase_dots(xs_dot)
return xs_bar
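Putting the pieces of this file together, the intended flow mirrors `_gradients` above and the transform unit tests; a minimal sketch, not a verified end-to-end example:

import paddle
from paddle.incubate.autograd.primx import orig2prim, prim2orig, Transform

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data('x', shape=[2, 2], dtype='float32')
    x.stop_gradient = False
    y = paddle.tanh(x)
    block = main.current_block()
    orig2prim(block)                                 # original ops -> primitive ops
    ad = Transform(block)
    xs_dot, ys_dot = ad.linearize([x], [y])          # forward-mode (JVP) program
    ys_bar, xs_bar = ad.transpose(ys_dot, xs_dot)    # reverse-mode (transpose) program
    prim2orig(block)                                 # primitive ops -> executable original ops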
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid import framework as framework
class PrimOption(object):
def __init__(self):
self.enable_prim = False
def get_status(self):
return self.enable_prim
def set_status(self, flag):
self.enable_prim = flag
prim_option = PrimOption()
@framework.static_only
def prim_enabled():
"""
.. note::
**ONLY available in the static mode.**
Shows whether the automatic differentiation mechanism based on
automatic differential basic operators is ON. Defaults to OFF.
Returns:
flag(bool): Whether the automatic differentiation mechanism based on automatic differential basic operators is ON.
Examples:
.. code-block:: python
import paddle
from paddle.incubate.autograd import enable_prim, disable_prim, prim_enabled
paddle.enable_static()
enable_prim()
print(prim_enabled()) # True
disable_prim()
print(prim_enabled()) # False
"""
return prim_option.get_status()
@framework.static_only
def enable_prim():
"""
.. note::
**ONLY available in the static mode.**
Turns ON automatic differentiation mechanism based on automatic
differential basic operators.
Examples:
.. code-block:: python
import paddle
from paddle.incubate.autograd import enable_prim, prim_enabled
paddle.enable_static()
enable_prim()
print(prim_enabled()) # True
"""
prim_option.set_status(True)
@framework.static_only
def disable_prim():
"""
.. note::
**ONLY available in the static mode.**
Turns OFF automatic differentiation mechanism based on automatic
differential basic operators.
Examples:
.. code-block:: python
import paddle
from paddle.incubate.autograd import enable_prim, disable_prim, prim_enabled
paddle.enable_static()
enable_prim()
print(prim_enabled()) # True
disable_prim()
print(prim_enabled()) # False
"""
prim_option.set_status(False)
INT_DTYPE_2_STRING = {
int(0): 'bool',
int(1): 'int16',
int(2): 'int32',
int(3): 'int64',
int(4): 'float16',
int(5): 'float32',
int(6): 'float64',
int(20): 'uint8',
int(21): 'int8',
int(23): 'complex64',
int(24): 'complex128',
}
def get_var_block(block, names):
assert isinstance(names, list)
if len(names) == 0:
return None
elif len(names) == 1:
return block.var(names[0])
else:
return [block.var(name) for name in names]
def get_input_var_list(op):
if op.input_names is None:
return []
else:
return [
get_var_block(op.block, op.input(n)) for n in sorted(op.input_names)
]
def get_output_var_list(op):
if op.output_names is None:
return []
else:
return [
get_var_block(op.block, op.output(n))
for n in sorted(op.output_names)
]
def to_tensors(xs):
if isinstance(xs, paddle.fluid.framework.Variable):
return [xs]
else:
return xs
def flatten(inp):
if inp is None or isinstance(inp, paddle.fluid.framework.Variable):
return [inp]
flattened = []
for part in inp:
flattened += flatten(part)
return flattened
def flatten_and_remove_none(inp):
flattened = flatten(inp)
return [var for var in flattened if var is not None]
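# For illustration: flatten([x, [y, None]]) returns [x, y, None],
# while flatten_and_remove_none([x, [y, None]]) returns [x, y].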
......@@ -47,6 +47,45 @@ from paddle.fluid.framework import _in_legacy_dygraph, _in_eager_without_dygraph
__all__ = []
@framework.static_only
def append_backward_new(loss_list,
parameter_list=None,
no_grad_set=None,
callbacks=None,
checkpoints=None,
distop_context=None):
from paddle.incubate.autograd.primx import orig2prim, Transform
program = default_main_program()
assert program.num_blocks == 1, "The append_backward_new interface is designed to process only one block."
block = program.current_block()
orig2prim(block)
ad = Transform(block)
if parameter_list is None:
parameter_list = program.global_block().all_parameters()
param_dot, loss_dot = ad.linearize(parameter_list, loss_list)
loss_bar, param_bar = ad.transpose(loss_dot, param_dot)
# remove param_dot and their constructor ops
op_indexes = []
for var in param_dot:
if var is not None:
op_index = block.ops.index(var.op)
assert op_index >= 0
op_indexes.append(op_index)
ad.erase_ops(sorted(op_indexes))
ad.erase_dots(param_dot)
if len(parameter_list) == 1:
params_and_grads = [(parameter_list, param_bar)]
else:
params_and_grads = []
for i, param in enumerate(parameter_list):
params_and_grads.append((param, param_bar[i]))
return params_and_grads
class Optimizer(object):
r"""Optimizer Base class.
......@@ -880,8 +919,13 @@ class Optimizer(object):
parameter_list = parameters if parameters \
else self._parameter_list
with program_guard(program, startup_program):
- params_grads = append_backward(loss, parameter_list,
- act_no_grad_set, callbacks)
+ from paddle.incubate.autograd.utils import prim_enabled
+ if prim_enabled():
+ params_grads = append_backward_new(
+ [loss], parameter_list, act_no_grad_set, callbacks)
+ else:
+ params_grads = append_backward(loss, parameter_list,
+ act_no_grad_set, callbacks)
# Note: since we can't use all_reduce_op now,
# dgc_op should be the last op of one grad.
self._append_dgc_ops(params_grads)
......
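With the change above, enabling the primitive-op machinery reroutes Optimizer.backward through append_backward_new. A minimal sketch of the intended user-facing flow; the layer and op choices here are illustrative and not taken from this patch:

import paddle
from paddle.incubate.autograd import enable_prim, prim_enabled, prim2orig

paddle.enable_static()
enable_prim()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data('x', shape=[2, 4], dtype='float32')
    y = paddle.static.nn.fc(x, size=4)       # creates trainable parameters
    loss = paddle.norm(paddle.tanh(y), p=2)
    opt = paddle.optimizer.Adam(learning_rate=0.01)
    opt.minimize(loss)        # backward is built by append_backward_new when prim is enabled
    if prim_enabled():
        prim2orig()           # lower primitive ops before executing the program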
......@@ -368,6 +368,7 @@ packages=['paddle',
'paddle.incubate.nn.functional',
'paddle.incubate.nn.layer',
'paddle.incubate.optimizer.functional',
'paddle.incubate.autograd',
'paddle.incubate.distributed',
'paddle.incubate.distributed.models',
'paddle.incubate.distributed.models.moe',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册