From ebf4fe6e5b849f74decbe0327c0c989adbab2e0d Mon Sep 17 00:00:00 2001
From: levi131 <83750468+levi131@users.noreply.github.com>
Date: Sat, 16 Apr 2022 10:36:53 +0800
Subject: [PATCH] Lml/prim op pywrapper (#41813)

* native commit for triple grad of sigmoid
* Updated unittests files
* init functional jacobian api
* Updated triple_test func
* Updated gradient_checker & test_script
* finish test with dtype float32
* add float64 test case
* polish code
* use atol=1e-5 with dtype float64
* fix for ci
* set timeout for test_jacobian
* fix dygraph grad to support high differential
* polish API docstring
* Updated gradient checker and some related files
* fix double grad strip error for high differential
* fix double grad strip error for high differential
* Add Sigmoid triple grad tests
* fix dygraph double grad dtype error when calling for high differential scenario
* Updated triple grad tests func
* Use np.random to initialize ddx
* Updated triple_grad_check func
* add todo for gradient checker and refine some comments
* remove additional code
* add test for warning in backward.py
* format python code
* support multi input in triple gradient checker
* Add matmul triple grad kernel
* Updated comments of TODO
* Supported some special tests
* Change code-format to follow CI std
* Updated gradient_checker.py
* Fix conflicts
* Removed unnecessary printing log
* Change code style to follow CI std
* merge upstream
* add primops.py
* add_p
* rm useless files
* add sub_p mul_p div_p
* add sqrt_p and tanh_p
* add reshape_p
* add broadcast_p
* Add python primitive wrappers.
* Jvp rules updated.
* JVP rules done for all the 17 primops.
* quick check and fixes.
* add jvp(op, *args)
* add broadcast_p fill_constant_p matmul_p reduce_p reshape_p transpose_p
* add split_p and concat_p
* add gather_p and scatter_add_p
* add slice_select_p and slice_assign_p
* Add transpose rules.
* add multi input check for add_p, sub_p, mul_p, div_p
* update concat_p
* Linearize and transpose in progress..
* refine gather_p and scatter_add_p
* updated.
* update transpose.
* refine slice_assign_p and slice_select_p
* init commit for lower
* Merged with primitive ops.
* small update
* add rules for orig2prim and prim2orig
* add 9 tests for prim ops
* add more test and fix some bug
* add more test
* register proto
* Adding primops test.
* add shape valid check for broadcast_p op, and add keepdim attr into reduce_p op proto
* support multi input and multi output for split_p and concat_p
* Test updated.
* update
* fix slice bug for slice_select_p and slice_assign_p
* updated.
* Ops updated.
* Refactor and bug fixes.
* updated.
* finish orig2prim and prim2orig rules
* dtype for axis attr should be long int
* update dtype for axis attr int64_t
* update for iscan CI
* Update primx.
* Refactor vars in primx.
* update for lower transform
* update primx.py
* update
* Fix linearize and transpose.
* Update is_dot
* Update is_dot
* Update is_dot
* add gradient aggregation, fix add_transpose.
* pass first linearize+transpose test.
* update test
* add_prim_op_pywrapper
* Add primops UT
* Fix set_value and update
* Fix code format and PR-CI-Coverage

Co-authored-by: veyron95
Co-authored-by: Jiabin Yang <360788950@qq.com>
Co-authored-by: Tongxin Bai
Co-authored-by: 0x45f
---
 python/paddle/autograd/primops.py             | 267 ++++++++++++++++++
 python/paddle/autograd/primreg.py             |  54 ++++
 .../fluid/tests/unittests/test_primops.py     | 147 ++++++++++
 3 files changed, 468 insertions(+)
 create mode 100644 python/paddle/autograd/primops.py
 create mode 100644 python/paddle/autograd/primreg.py
 create mode 100644 python/paddle/fluid/tests/unittests/test_primops.py

diff --git a/python/paddle/autograd/primops.py b/python/paddle/autograd/primops.py
new file mode 100644
index 0000000000..66f641e544
--- /dev/null
+++ b/python/paddle/autograd/primops.py
@@ -0,0 +1,267 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle.fluid import unique_name, core
+from paddle.fluid.framework import default_main_program, default_startup_program
+from paddle.fluid.layer_helper import LayerHelper
+from .primreg import REGISTER_FN
+
+
+# Common constructor body for unary primops with one input 'X' and one
+# output 'Y'.
+def _simple_unop(helper):
+    optype = helper.layer_type
+    x, out = tuple(map(helper.kwargs.get, ('x', 'out')))
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+    helper.append_op(type=optype, inputs={'X': x}, outputs={'Y': out}, attrs={})
+    return out
+
+
+# Common constructor body for binary primops with inputs 'X', 'Y' and
+# output 'Z'.
+def _simple_binop(helper):
+    optype = helper.layer_type
+    x, y, out = tuple(map(helper.kwargs.get, ('x', 'y', 'out')))
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+    helper.append_op(
+        type=optype, inputs={'X': x, 'Y': y}, outputs={'Z': out}, attrs={})
+    return out
+
+
+# Common constructor body for unary manipulation primops, forwarding any of
+# the 'shape', 'axis' and 'index' keyword arguments as op attributes.
+def _manipulation_unop(helper):
+    optype = helper.layer_type
+    x, out = tuple(map(helper.kwargs.get, ('x', 'out')))
+
+    attrs = {
+        k: helper.kwargs[k]
+        for k in ('shape', 'axis', 'index') if k in helper.kwargs
+    }
+
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+    helper.append_op(
+        type=optype, inputs={'X': x}, outputs={'Y': out}, attrs=attrs)
+    return out
+
+
+# Each primitive op is given a Python constructor for the sake of convenience.
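+# For example, the following sketch (illustrative only; the unit tests added
+# in this patch exercise every constructor this way) appends a single
+# primitive op to the current default program:
+#
+#     paddle.enable_static()
+#     x = paddle.static.data(name='x', shape=[2, 3], dtype='float32')
+#     y = add(x, x)  # y is a new Variable produced by an 'add_p' op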
+def fill_const(value, shape, dtype, out=None):
+    attrs = {'value': value, 'shape': shape, 'dtype': dtype}
+    helper = LayerHelper('fill_constant_p', **locals())
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype)
+    helper.append_op(type=helper.layer_type, outputs={'Y': out}, attrs=attrs)
+    return out
+
+
+def neg(x, out=None):
+    zero = fill_const(0.0, x.shape, x.dtype)
+    return sub(zero, x)
+
+
+def set_value(x, y, axis, starts, ends, strides, out):
+    assert x is out, "x and out should be the same Tensor in set_value"
+    attrs = {'axes': axis, 'starts': starts, 'ends': ends, 'steps': strides}
+    helper = LayerHelper('set_value', **locals())
+    helper.append_op(
+        type=helper.layer_type,
+        inputs={'Input': x, 'ValueTensor': y},
+        outputs={'Out': out},
+        attrs=attrs)
+    return out
+
+
+@REGISTER_FN('add_p', 'X', 'Y', 'Z')
+def add(x, y, out=None):
+    return _simple_binop(LayerHelper('add_p', **locals()))
+
+
+@REGISTER_FN('sub_p', 'X', 'Y', 'Z')
+def sub(x, y, out=None):
+    return _simple_binop(LayerHelper('sub_p', **locals()))
+
+
+@REGISTER_FN('mul_p', 'X', 'Y', 'Z')
+def mul(x, y, out=None):
+    return _simple_binop(LayerHelper('mul_p', **locals()))
+
+
+@REGISTER_FN('div_p', 'X', 'Y', 'Z')
+def div(x, y, out=None):
+    return _simple_binop(LayerHelper('div_p', **locals()))
+
+
+@REGISTER_FN('sqrt_p', 'X', 'Y')
+def sqrt(x, out=None):
+    return _simple_unop(LayerHelper('sqrt_p', **locals()))
+
+
+@REGISTER_FN('tanh_p', 'X', 'Y')
+def tanh(x, out=None):
+    return _simple_unop(LayerHelper('tanh_p', **locals()))
+
+
+@REGISTER_FN('reshape_p', 'X', 'Y')
+def reshape(x, shape, out=None):
+    return _manipulation_unop(LayerHelper('reshape_p', **locals()))
+
+
+@REGISTER_FN('broadcast_p', 'X', 'Y')
+def broadcast(x, shape, out=None):
+    return _manipulation_unop(LayerHelper('broadcast_p', **locals()))
+
+
+@REGISTER_FN('transpose_p', 'X', 'Y')
+def transpose(x, axis=None, out=None):
+    return _manipulation_unop(LayerHelper('transpose_p', **locals()))
+
+
+@REGISTER_FN('split_p', 'X', 'YS')
+def split(x, num_or_sections, axis=0, outs=None):
+    if isinstance(num_or_sections, (list, tuple)):
+        n = len(num_or_sections)
+    else:
+        assert isinstance(num_or_sections, int)
+        n = num_or_sections
+
+    attrs = {'num_or_sections': num_or_sections, 'axis': axis}
+
+    helper = LayerHelper('split_p', **locals())
+    if outs is None:
+        outs = [
+            helper.create_variable_for_type_inference(dtype=x.dtype)
+            for i in range(n)
+        ]
+    helper.append_op(
+        type=helper.layer_type,
+        inputs={'X': x},
+        outputs={'YS': outs},
+        attrs=attrs)
+    return outs
+
+
+@REGISTER_FN('concat_p', 'XS', 'Y')
+def concat(xs, axis=0, out=None):
+    assert isinstance(xs, (list, tuple)) and len(xs) > 0
+    attrs = {'axis': axis}
+    helper = LayerHelper('concat_p', **locals())
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype=xs[0].dtype)
+    helper.append_op(
+        type=helper.layer_type,
+        inputs={'XS': xs},
+        outputs={'Y': out},
+        attrs=attrs)
+    return out
+
+
+@REGISTER_FN('reduce_p', 'X', 'Y')
+def reduce(x, axis, keepdim=False, out=None):
+    assert isinstance(axis, (tuple, list))
+    assert isinstance(keepdim, bool)
+
+    attrs = {'axis': axis, 'keepdim': keepdim}
+
+    helper = LayerHelper('reduce_p', **locals())
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+    helper.append_op(
+        type=helper.layer_type,
+        inputs={'X': x},
+        outputs={'Y': out},
+        attrs=attrs)
+    return out
+
+
+@REGISTER_FN('matmul_p', 'X', 'Y', 'Z')
+def matmul(x, y, out=None):
+    return _simple_binop(LayerHelper('matmul_p', **locals()))
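+
+
+# The slicing/indexing constructors below take `axis`, `starts`, `ends` and
+# `strides` as Python lists with one entry per sliced axis. For example
+# (an illustrative sketch, mirroring the unit tests added in this patch):
+#
+#     # select rows 0 and 1 of a (3, 2) tensor `e`:
+#     y = slice_select(e, axis=[0], starts=[0], ends=[2], strides=[1])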
+@REGISTER_FN('slice_select_p', 'X', 'Y')
+def slice_select(x, axis, starts, ends, strides, out=None):
+    assert isinstance(axis, (list, tuple)), (
+        f'Argument type error. `axis` is supposed to be a list or a'
+        f' tuple but found {type(axis)}.')
+    assert isinstance(starts, (list, tuple))
+    assert isinstance(ends, (list, tuple))
+    assert len(axis) == len(starts) == len(ends) == len(strides)
+
+    attrs = {'axis': axis, 'starts': starts, 'ends': ends, 'strides': strides}
+    helper = LayerHelper('slice_select_p', **locals())
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    helper.append_op(
+        type=helper.layer_type,
+        inputs={'X': x},
+        outputs={'Y': out},
+        attrs=attrs)
+    return out
+
+
+@REGISTER_FN('slice_assign_p', 'X', 'Y', 'Z')
+def slice_assign(x, y, axis, starts, ends, strides, out=None):
+    assert len(starts) == len(ends) == len(strides) == len(axis)
+    assert len(y.shape) == len(x.shape)
+
+    attrs = {'axis': axis, 'starts': starts, 'ends': ends, 'strides': strides}
+    helper = LayerHelper('slice_assign_p', **locals())
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    helper.append_op(
+        type=helper.layer_type,
+        inputs={'X': x, 'Y': y},
+        outputs={'Z': out},
+        attrs=attrs)
+    return out
+
+
+@REGISTER_FN('gather_p', 'X', 'IndexTensor', 'Y')
+def gather(x, indextensor, axis, out=None):
+    attrs = {'axis': axis}
+    helper = LayerHelper('gather_p', **locals())
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    helper.append_op(
+        type=helper.layer_type,
+        inputs={'X': x, 'IndexTensor': indextensor},
+        outputs={'Y': out},
+        attrs=attrs)
+    return out
+
+
+@REGISTER_FN('scatter_add_p', 'X', 'Y', 'IndexTensor', 'Z')
+def scatter_add(x, y, indextensor, axis, out=None):
+    assert len(x.shape) == len(y.shape)
+    assert len(indextensor.shape) == 1
+    assert y.shape[axis] == indextensor.shape[0]
+    attrs = {'axis': axis}
+    helper = LayerHelper('scatter_add_p', **locals())
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    helper.append_op(
+        type=helper.layer_type,
+        inputs={'X': x, 'Y': y, 'IndexTensor': indextensor},
+        outputs={'Z': out},
+        attrs=attrs)
+    return out
diff --git a/python/paddle/autograd/primreg.py b/python/paddle/autograd/primreg.py
new file mode 100644
index 0000000000..cffb4bc050
--- /dev/null
+++ b/python/paddle/autograd/primreg.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+
+
+class Registry(object):
+    """ A general registry object.
""" + __slots__ = ['name', 'tab'] + + def __init__(self, name): + self.name = name + self.tab = {} + + def register(self, name, value): + assert name not in self.tab + self.tab[name] = value + + def lookup(self, name): + assert name in self.tab, f'No registry entry is found with name: {name}' + return self.tab[name] + + +_primop_fn = Registry('primop_fn') +_orig2prim = Registry('orig2prim') +_prim2orig = Registry('prim2orig') +_primop_jvp = Registry('primop_jvp') +_primop_transpose = Registry('primop_transpose') +_primop_position_argnames = Registry('primop_position_argnames') + + +def REGISTER_FN(op_type, *position_argnames): + """Decorator for registering the Python function for a primitive op.""" + + assert isinstance(op_type, str) + + _primop_position_argnames.register(op_type, position_argnames) + + def wrapper(f): + _primop_fn.register(op_type, f) + return f + + return wrapper diff --git a/python/paddle/fluid/tests/unittests/test_primops.py b/python/paddle/fluid/tests/unittests/test_primops.py new file mode 100644 index 0000000000..cbf77c2666 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_primops.py @@ -0,0 +1,147 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np + +import paddle +from paddle.autograd.primops import ( + neg, set_value, add, sub, mul, div, sqrt, tanh, reshape, broadcast, + transpose, split, concat, reduce, matmul, slice_select, slice_assign, + gather, scatter_add, fill_const) + + +class TestPyPrimOps(unittest.TestCase): + """ Test Python wrappers of primitive ops. 
""" + + def setUp(self): + paddle.enable_static() + + def test_ops(self): + A = np.random.rand(1) + B = np.random.rand(2) + C = np.random.rand(2, 3) + D = np.random.rand(2, 3) + E = np.random.rand(3, 2) + + a = paddle.static.data(name='A', shape=A.shape, dtype='float32') + b = paddle.static.data(name='B', shape=B.shape, dtype='float32') + c = paddle.static.data(name='C', shape=C.shape, dtype='float32') + d = paddle.static.data(name='D', shape=D.shape, dtype='float32') + e = paddle.static.data(name='E', shape=E.shape, dtype='float32') + + add_1 = add(a, a) + self.assertEqual(add_1.dtype, a.dtype) + self.assertEqual(add_1.shape, a.shape) + + add_2 = add(c, d) + self.assertEqual(add_2.dtype, c.dtype) + self.assertEqual(add_2.shape, c.shape) + + sub_1 = sub(c, d) + self.assertEqual(sub_1.dtype, c.dtype) + self.assertEqual(sub_1.shape, c.shape) + + mul_1 = mul(c, d) + self.assertEqual(mul_1.dtype, c.dtype) + self.assertEqual(mul_1.shape, c.shape) + + div_1 = div(c, d) + self.assertEqual(div_1.dtype, c.dtype) + self.assertEqual(div_1.shape, c.shape) + + sqrt_1 = sqrt(b) + self.assertEqual(sqrt_1.dtype, b.dtype) + self.assertEqual(sqrt_1.shape, b.shape) + + tanh_1 = tanh(d) + self.assertEqual(tanh_1.dtype, d.dtype) + self.assertEqual(tanh_1.shape, d.shape) + + reshape_1 = reshape(c, d.shape) + self.assertEqual(reshape_1.dtype, c.dtype) + self.assertEqual(reshape_1.shape, d.shape) + + broadcast_1 = broadcast(b, e.shape) + self.assertEqual(broadcast_1.dtype, b.dtype) + self.assertEqual(broadcast_1.shape, e.shape) + + transpose_1 = transpose(c, axis=[1, 0]) + self.assertEqual(transpose_1.dtype, c.dtype) + self.assertEqual(transpose_1.shape, e.shape) + + split_1_0, split_1_1 = split(c, num_or_sections=[1, 2], axis=1) + self.assertEqual(split_1_0.dtype, c.dtype) + self.assertEqual(split_1_0.shape, (2, 1)) + self.assertEqual(split_1_1.shape, (2, 2)) + + concat_1 = concat([c, d], axis=0) + self.assertEqual(concat_1.dtype, c.dtype) + self.assertEqual(concat_1.shape, (4, 3)) + + reduce_1 = reduce(d, axis=[1]) + self.assertEqual(reduce_1.dtype, d.dtype) + self.assertEqual(reduce_1.shape, (2, )) + + reduce_2 = reduce(c, axis=[0, 1]) + self.assertEqual(reduce_2.dtype, c.dtype) + self.assertEqual(reduce_2.shape, (1, )) + # TODO: reduce + keepdim + + matmul_1 = matmul(d, e) + self.assertEqual(matmul_1.dtype, d.dtype) + self.assertEqual(matmul_1.shape, (2, 2)) + + slice_select_1 = slice_select( + e, axis=[0], starts=[0], ends=[2], strides=[1]) + self.assertEqual(slice_select_1.dtype, e.dtype) + self.assertEqual(slice_select_1.shape, (2, 2)) + + slice_select_2 = slice_select( + d, axis=[0, 1], starts=[0, 1], ends=[2, 3], strides=[1, 2]) + self.assertEqual(slice_select_2.dtype, d.dtype) + self.assertEqual(slice_select_2.shape, (2, 1)) + + y = broadcast(b, [2, 2]) + slice_assign_1 = slice_assign( + d, y, axis=[1], starts=[1], ends=[3], strides=[1]) + self.assertEqual(slice_assign_1.dtype, d.dtype) + self.assertEqual(slice_assign_1.shape, d.shape) + + index = paddle.static.data('index', shape=[5], dtype='int32') + gather_1 = gather(e, index, axis=0) + self.assertEqual(gather_1.dtype, e.dtype) + self.assertEqual(gather_1.shape, (5, 2)) + + y = paddle.rand([5, 2], dtype='float32') + scatter_add_1 = scatter_add(e, y, index, axis=0) + self.assertEqual(scatter_add_1.dtype, e.dtype) + self.assertEqual(scatter_add_1.shape, e.shape) + + fill_const_1 = fill_const(value=10, shape=a.shape, dtype=a.dtype) + self.assertEqual(fill_const_1.shape, a.shape) + self.assertEqual(fill_const_1.dtype, a.dtype) + + neg_1 = neg(x=b) + 
+        self.assertEqual(neg_1.shape, b.shape)
+        self.assertEqual(neg_1.dtype, b.dtype)
+
+        set_value_1 = set_value(
+            d, a, axis=[1], starts=[1], ends=[3], strides=[1], out=d)
+        self.assertEqual(set_value_1.shape, d.shape)
+        self.assertEqual(set_value_1.dtype, d.dtype)
+
+
+if __name__ == '__main__':
+    unittest.main()
--
GitLab