diff --git a/python/paddle/autograd/primops.py b/python/paddle/autograd/primops.py
new file mode 100644
index 0000000000000000000000000000000000000000..66f641e54467c6791ced0e11cda14afb29e5c316
--- /dev/null
+++ b/python/paddle/autograd/primops.py
@@ -0,0 +1,267 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle.fluid import unique_name, core
+from paddle.fluid.framework import default_main_program, default_startup_program
+from paddle.fluid.layer_helper import LayerHelper
+from .primreg import REGISTER_FN
+
+
+def _simple_unop(helper):
+    optype = helper.layer_type
+    x, out = tuple(map(helper.kwargs.get, ('x', 'out')))
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+    helper.append_op(type=optype, inputs={'X': x}, outputs={'Y': out}, attrs={})
+    return out
+
+
+def _simple_binop(helper):
+    optype = helper.layer_type
+    x, y, out = tuple(map(helper.kwargs.get, ('x', 'y', 'out')))
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+    helper.append_op(
+        type=optype, inputs={'X': x,
+                             'Y': y}, outputs={'Z': out}, attrs={})
+    return out
+
+
+def _manipulation_unop(helper):
+    optype = helper.layer_type
+    x, out = tuple(map(helper.kwargs.get, ('x', 'out')))
+
+    attrs = {
+        k: helper.kwargs[k]
+        for k in ('shape', 'axis', 'index') if k in helper.kwargs
+    }
+
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+    helper.append_op(
+        type=optype, inputs={'X': x}, outputs={'Y': out}, attrs=attrs)
+    return out
+
+
+# Each primitive op is given a Python constructor for the sake of convenience.
+def fill_const(value, shape, dtype, out=None):
+    attrs = {'value': value, 'shape': shape, 'dtype': dtype}
+    helper = LayerHelper('fill_constant_p', **locals())
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype)
+    helper.append_op(type=helper.layer_type, outputs={'Y': out}, attrs=attrs)
+    return out
+
+
+def neg(x, out=None):
+    zero = fill_const(0.0, x.shape, x.dtype)
+    return sub(zero, x, out=out)
+
+
+def set_value(x, y, axis, starts, ends, strides, out):
+    assert x is out, "x and out should be the same Tensor in set_value"
+    attrs = {'axes': axis, 'starts': starts, 'ends': ends, 'steps': strides}
+    helper = LayerHelper('set_value', **locals())
+    helper.append_op(
+        type=helper.layer_type,
+        inputs={'Input': x,
+                'ValueTensor': y},
+        outputs={'Out': out},
+        attrs=attrs)
+    return out
+
+
+@REGISTER_FN('add_p', 'X', 'Y', 'Z')
+def add(x, y, out=None):
+    return _simple_binop(LayerHelper('add_p', **locals()))
+
+
+@REGISTER_FN('sub_p', 'X', 'Y', 'Z')
+def sub(x, y, out=None):
+    return _simple_binop(LayerHelper('sub_p', **locals()))
+
+
+@REGISTER_FN('mul_p', 'X', 'Y', 'Z')
+def mul(x, y, out=None):
+    return _simple_binop(LayerHelper('mul_p', **locals()))
+
+
+@REGISTER_FN('div_p', 'X', 'Y', 'Z')
+def div(x, y, out=None):
+    return _simple_binop(LayerHelper('div_p', **locals()))
+
+
+@REGISTER_FN('sqrt_p', 'X', 'Y')
+def sqrt(x, out=None):
+    return _simple_unop(LayerHelper('sqrt_p', **locals()))
+
+
+@REGISTER_FN('tanh_p', 'X', 'Y')
+def tanh(x, out=None):
+    return _simple_unop(LayerHelper('tanh_p', **locals()))
+
+
+@REGISTER_FN('reshape_p', 'X', 'Y')
+def reshape(x, shape, out=None):
+    return _manipulation_unop(LayerHelper('reshape_p', **locals()))
+
+
+@REGISTER_FN('broadcast_p', 'X', 'Y')
+def broadcast(x, shape, out=None):
+    return _manipulation_unop(LayerHelper('broadcast_p', **locals()))
+
+
+@REGISTER_FN('transpose_p', 'X', 'Y')
+def transpose(x, axis=None, out=None):
+    return _manipulation_unop(LayerHelper('transpose_p', **locals()))
+
+
+@REGISTER_FN('split_p', 'X', 'YS')
+def split(x, num_or_sections, axis=0, outs=None):
+    if isinstance(num_or_sections, (list, tuple)):
+        n = len(num_or_sections)
+    else:
+        assert isinstance(num_or_sections, int)
+        n = num_or_sections
+
+    attrs = {'num_or_sections': num_or_sections, 'axis': axis}
+
+    helper = LayerHelper('split_p', **locals())
+    if outs is None:
+        outs = [
+            helper.create_variable_for_type_inference(dtype=x.dtype)
+            for i in range(n)
+        ]
+    helper.append_op(
+        type=helper.layer_type,
+        inputs={'X': x},
+        outputs={'YS': outs},
+        attrs=attrs)
+    return outs
+
+
+@REGISTER_FN('concat_p', 'XS', 'Y')
+def concat(xs, axis=0, out=None):
+    assert isinstance(xs, (list, tuple)) and len(xs) > 0
+    attrs = {'axis': axis}
+    helper = LayerHelper('concat_p', **locals())
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype=xs[0].dtype)
+    helper.append_op(
+        type=helper.layer_type,
+        inputs={'XS': xs},
+        outputs={'Y': out},
+        attrs=attrs)
+    return out
+
+
+@REGISTER_FN('reduce_p', 'X', 'Y')
+def reduce(x, axis, keepdim=False, out=None):
+    assert isinstance(axis, (tuple, list))
+    assert isinstance(keepdim, bool)
+
+    attrs = {'axis': axis, 'keepdim': keepdim}
+
+    helper = LayerHelper('reduce_p', **locals())
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+    helper.append_op(
+        type=helper.layer_type,
+        inputs={'X': x},
+        outputs={'Y': out},
+        attrs=attrs)
+    return out
+
+
+@REGISTER_FN('matmul_p', 'X', 'Y', 'Z')
+def matmul(x, y, out=None):
+    return _simple_binop(LayerHelper('matmul_p', **locals()))
+
+
+@REGISTER_FN('slice_select_p', 'X', 'Y')
+def slice_select(x, axis, starts, ends, strides, out=None):
+    assert isinstance(axis, (list, tuple)), (
+        f'Argument type error. `axis` is supposed to be a list or tuple '
+        f'but found {type(axis)}.')
+    assert isinstance(starts, (list, tuple))
+    assert isinstance(ends, (list, tuple))
+    assert len(axis) == len(starts) == len(ends) == len(strides)
+
+    attrs = {'axis': axis, 'starts': starts, 'ends': ends, 'strides': strides}
+    helper = LayerHelper('slice_select_p', **locals())
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    helper.append_op(
+        type=helper.layer_type,
+        inputs={'X': x},
+        outputs={'Y': out},
+        attrs=attrs)
+    return out
+
+
+@REGISTER_FN('slice_assign_p', 'X', 'Y', 'Z')
+def slice_assign(x, y, axis, starts, ends, strides, out=None):
+    assert len(starts) == len(ends) == len(strides) == len(axis)
+    assert len(y.shape) == len(x.shape)
+
+    attrs = {'axis': axis, 'starts': starts, 'ends': ends, 'strides': strides}
+    helper = LayerHelper('slice_assign_p', **locals())
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    helper.append_op(
+        type=helper.layer_type,
+        inputs={'X': x,
+                'Y': y},
+        outputs={'Z': out},
+        attrs=attrs)
+    return out
+
+
+@REGISTER_FN('gather_p', 'X', 'IndexTensor', 'Y')
+def gather(x, indextensor, axis, out=None):
+    attrs = {'axis': axis}
+    helper = LayerHelper('gather_p', **locals())
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    helper.append_op(
+        type=helper.layer_type,
+        inputs={'X': x,
+                'IndexTensor': indextensor},
+        outputs={'Y': out},
+        attrs=attrs)
+    return out
+
+
+@REGISTER_FN('scatter_add_p', 'X', 'Y', 'IndexTensor', 'Z')
+def scatter_add(x, y, indextensor, axis, out=None):
+    assert len(x.shape) == len(y.shape)
+    assert len(indextensor.shape) == 1
+    assert y.shape[axis] == indextensor.shape[0]
+    attrs = {'axis': axis}
+    helper = LayerHelper('scatter_add_p', **locals())
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    helper.append_op(
+        type=helper.layer_type,
+        inputs={'X': x,
+                'Y': y,
+                'IndexTensor': indextensor},
+        outputs={'Z': out},
+        attrs=attrs)
+    return out
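For reference, a minimal sketch (not part of this patch) of how the constructors above compose under a static graph. The names `x`, `w`, `main` are illustrative, and this assumes the `*_p` operators are registered on the C++ side so that `append_op` succeeds, as the unit tests below assume:

```python
import paddle
from paddle.autograd.primops import fill_const, matmul, tanh

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    # Static placeholder input; shapes and dtypes are inferred at build time.
    x = paddle.static.data(name='x', shape=[2, 3], dtype='float32')
    w = fill_const(value=0.5, shape=[3, 4], dtype=x.dtype)
    # Each constructor call appends one primitive op to `main`.
    y = tanh(matmul(x, w))

print([op.type for op in main.block(0).ops])
# Expected to contain something like: ['fill_constant_p', 'matmul_p', 'tanh_p']
```

Note the design choice: the public constructors only validate arguments and build attrs, then delegate op creation to the `_simple_unop` / `_simple_binop` / `_manipulation_unop` helpers via `LayerHelper(optype, **locals())`.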
""" + __slots__ = ['name', 'tab'] + + def __init__(self, name): + self.name = name + self.tab = {} + + def register(self, name, value): + assert name not in self.tab + self.tab[name] = value + + def lookup(self, name): + assert name in self.tab, f'No registry entry is found with name: {name}' + return self.tab[name] + + +_primop_fn = Registry('primop_fn') +_orig2prim = Registry('orig2prim') +_prim2orig = Registry('prim2orig') +_primop_jvp = Registry('primop_jvp') +_primop_transpose = Registry('primop_transpose') +_primop_position_argnames = Registry('primop_position_argnames') + + +def REGISTER_FN(op_type, *position_argnames): + """Decorator for registering the Python function for a primitive op.""" + + assert isinstance(op_type, str) + + _primop_position_argnames.register(op_type, position_argnames) + + def wrapper(f): + _primop_fn.register(op_type, f) + return f + + return wrapper diff --git a/python/paddle/fluid/tests/unittests/test_primops.py b/python/paddle/fluid/tests/unittests/test_primops.py new file mode 100644 index 0000000000000000000000000000000000000000..cbf77c26666118e8f1cca238dedd347f756636ff --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_primops.py @@ -0,0 +1,147 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np + +import paddle +from paddle.autograd.primops import ( + neg, set_value, add, sub, mul, div, sqrt, tanh, reshape, broadcast, + transpose, split, concat, reduce, matmul, slice_select, slice_assign, + gather, scatter_add, fill_const) + + +class TestPyPrimOps(unittest.TestCase): + """ Test Python wrappers of primitive ops. 
""" + + def setUp(self): + paddle.enable_static() + + def test_ops(self): + A = np.random.rand(1) + B = np.random.rand(2) + C = np.random.rand(2, 3) + D = np.random.rand(2, 3) + E = np.random.rand(3, 2) + + a = paddle.static.data(name='A', shape=A.shape, dtype='float32') + b = paddle.static.data(name='B', shape=B.shape, dtype='float32') + c = paddle.static.data(name='C', shape=C.shape, dtype='float32') + d = paddle.static.data(name='D', shape=D.shape, dtype='float32') + e = paddle.static.data(name='E', shape=E.shape, dtype='float32') + + add_1 = add(a, a) + self.assertEqual(add_1.dtype, a.dtype) + self.assertEqual(add_1.shape, a.shape) + + add_2 = add(c, d) + self.assertEqual(add_2.dtype, c.dtype) + self.assertEqual(add_2.shape, c.shape) + + sub_1 = sub(c, d) + self.assertEqual(sub_1.dtype, c.dtype) + self.assertEqual(sub_1.shape, c.shape) + + mul_1 = mul(c, d) + self.assertEqual(mul_1.dtype, c.dtype) + self.assertEqual(mul_1.shape, c.shape) + + div_1 = div(c, d) + self.assertEqual(div_1.dtype, c.dtype) + self.assertEqual(div_1.shape, c.shape) + + sqrt_1 = sqrt(b) + self.assertEqual(sqrt_1.dtype, b.dtype) + self.assertEqual(sqrt_1.shape, b.shape) + + tanh_1 = tanh(d) + self.assertEqual(tanh_1.dtype, d.dtype) + self.assertEqual(tanh_1.shape, d.shape) + + reshape_1 = reshape(c, d.shape) + self.assertEqual(reshape_1.dtype, c.dtype) + self.assertEqual(reshape_1.shape, d.shape) + + broadcast_1 = broadcast(b, e.shape) + self.assertEqual(broadcast_1.dtype, b.dtype) + self.assertEqual(broadcast_1.shape, e.shape) + + transpose_1 = transpose(c, axis=[1, 0]) + self.assertEqual(transpose_1.dtype, c.dtype) + self.assertEqual(transpose_1.shape, e.shape) + + split_1_0, split_1_1 = split(c, num_or_sections=[1, 2], axis=1) + self.assertEqual(split_1_0.dtype, c.dtype) + self.assertEqual(split_1_0.shape, (2, 1)) + self.assertEqual(split_1_1.shape, (2, 2)) + + concat_1 = concat([c, d], axis=0) + self.assertEqual(concat_1.dtype, c.dtype) + self.assertEqual(concat_1.shape, (4, 3)) + + reduce_1 = reduce(d, axis=[1]) + self.assertEqual(reduce_1.dtype, d.dtype) + self.assertEqual(reduce_1.shape, (2, )) + + reduce_2 = reduce(c, axis=[0, 1]) + self.assertEqual(reduce_2.dtype, c.dtype) + self.assertEqual(reduce_2.shape, (1, )) + # TODO: reduce + keepdim + + matmul_1 = matmul(d, e) + self.assertEqual(matmul_1.dtype, d.dtype) + self.assertEqual(matmul_1.shape, (2, 2)) + + slice_select_1 = slice_select( + e, axis=[0], starts=[0], ends=[2], strides=[1]) + self.assertEqual(slice_select_1.dtype, e.dtype) + self.assertEqual(slice_select_1.shape, (2, 2)) + + slice_select_2 = slice_select( + d, axis=[0, 1], starts=[0, 1], ends=[2, 3], strides=[1, 2]) + self.assertEqual(slice_select_2.dtype, d.dtype) + self.assertEqual(slice_select_2.shape, (2, 1)) + + y = broadcast(b, [2, 2]) + slice_assign_1 = slice_assign( + d, y, axis=[1], starts=[1], ends=[3], strides=[1]) + self.assertEqual(slice_assign_1.dtype, d.dtype) + self.assertEqual(slice_assign_1.shape, d.shape) + + index = paddle.static.data('index', shape=[5], dtype='int32') + gather_1 = gather(e, index, axis=0) + self.assertEqual(gather_1.dtype, e.dtype) + self.assertEqual(gather_1.shape, (5, 2)) + + y = paddle.rand([5, 2], dtype='float32') + scatter_add_1 = scatter_add(e, y, index, axis=0) + self.assertEqual(scatter_add_1.dtype, e.dtype) + self.assertEqual(scatter_add_1.shape, e.shape) + + fill_const_1 = fill_const(value=10, shape=a.shape, dtype=a.dtype) + self.assertEqual(fill_const_1.shape, a.shape) + self.assertEqual(fill_const_1.dtype, a.dtype) + + neg_1 = neg(x=b) + 
+        self.assertEqual(neg_1.shape, b.shape)
+        self.assertEqual(neg_1.dtype, b.dtype)
+
+        set_value_1 = set_value(
+            d, a, axis=[1], starts=[1], ends=[3], strides=[1], out=d)
+        self.assertEqual(set_value_1.shape, d.shape)
+        self.assertEqual(set_value_1.dtype, d.dtype)
+
+
+if __name__ == '__main__':
+    unittest.main()
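A hedged sketch of the `reduce + keepdim` case flagged by the TODO inside `test_ops` above. Assuming `reduce_p` keeps reduced axes as size-1 dimensions when `keepdim=True`, the fragment below (meant to sit inside `test_ops`, reusing its `d` of static shape `(2, 3)`) would cover it:

```python
# Assumption: reduce_p with keepdim=True retains axis 1 as a size-1 dim,
# so reducing shape (2, 3) over axis 1 infers shape (2, 1).
reduce_3 = reduce(d, axis=[1], keepdim=True)
self.assertEqual(reduce_3.dtype, d.dtype)
self.assertEqual(reduce_3.shape, (2, 1))
```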