Unverified commit ebf4fe6e, authored by levi131, committed by GitHub

Lml/prim op pywrapper (#41813)

* native commit for triple grad of sigmoid

* Updated unittests files

* init functional jacobian api

* Updated triple_test func

* Updated gradient_checker & test_script

* finish test with dtype float32

* add float64 test case

* polish code

* use atol=1e-5 with dtype float64

* fix for ci

* set timeout for test_jacobian

* fix dygraph grad to support high differential

* polish API docstring

* Updated gradient checker and some related files

* fix double grad strip error for high differential

* fix double grad strip error for high differential

* Add Sigmoid triple grad tests

* fix dygraph double grad dtype error when calling for high differential scenario

* Updated triple grad tests func

* Use np.random to initialize ddx

* Updated triple_grad_check func

* add todo for gradient checker and refine some comments

* remove additional code

* add test for warning in backward.py

* format python code

* support multi input in triple gradient checker

* Add matmul triple grad kernel

* Updated comments of TODO

* Supported some special tests

* Change code-format to follow CI std

* Updated gradient_checker.py

* Fix conflicts

* Removed unnecessary printing log

* Change code style to follow CI std

* merge upstream

* add primops.py

* add_p

* rm useless files

* add sub_p mul_p div_p

* add sqrt_p and tanh_p

* add reshape_p

* add broadcast_p

* Add python primitive wrappers.

* Jvp rules updated.

* JVP rules done for all the 17 primops.

* quick check and fixes.

* add jvp(op, *args)

* add broadcast_p fill_constant_p matmul_p reduce_p reshape_p transpose_p

* add split_p and concat_p

* add gather_p and scatter_add_p

* add slice_select_p and slice_assign_p

* Add transpose rules.

* add multi input check for add_p, sub_p, mul_p, div_p

* update concat_p

* Linearize and transpose in progress...

* refine gather_p and scatter_add_p

* updated.

* update transpose.

* refine slice_assign_p and slice_select_p

* init commit for lower

* Merged with primitive ops.

* small update

* add rules for orig2prim and prim2orig

* add 9 test for prim ops

* add more test and fix some bug

* add more test

* register proto

* Adding primops test.

* add shape valid check for broadcast_p op, and add keepdim attr into reduce_p op proto

* support multi input and multi output for split_p and concat_p

* Test updated.

* update

* fix slice bug for slice_select_p and slice_assign_p

* updated.

* Ops updated.

* Refactor and bug fixes.

* updated.

* finish orig2prim and prim2orig rules

* dtype for axis attr should be long int

* update dtype for axis attr int64_t

* update for iscan CI

* Update primx.

* Refactor vars in primx.

* update for lower transform

* update primx.py

* update

* Fix linearize and transpose.

* Update is_dot

* Update is_dot

* Update is_dot

* add gradient aggregation, fix add_transpose.

* pass first linearize+transpose test.

* update test

* add_prim_op_pywrapper

* Add primops UT

* Fix set_value and update

* Fix code format and PR-CI-Coverage
Co-authored-by: veyron95 <veyron_wu@163.com>
Co-authored-by: Jiabin Yang <360788950@qq.com>
Co-authored-by: Tongxin Bai <waffle.bai@gmail.com>
Co-authored-by: 0x45f <wangzhen45@baidu.com>
Parent: 56dafc4f
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid import unique_name, core
from paddle.fluid.framework import default_main_program, default_startup_program
from paddle.fluid.layer_helper import LayerHelper
from .primreg import REGISTER_FN
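# Builds a unary primitive op (input 'X' -> output 'Y') with no attributes,
# creating the output variable from the input's dtype when `out` is not given.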
def _simple_unop(helper):
optype = helper.layer_type
x, out = tuple(map(helper.kwargs.get, ('x', 'out')))
if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type=optype, inputs={'X': x}, outputs={'Y': out}, attrs={})
return out
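# Builds a binary primitive op (inputs 'X', 'Y' -> output 'Z') with no
# attributes. The output dtype is inferred from the first input.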
def _simple_binop(helper):
optype = helper.layer_type
x, y, out = tuple(map(helper.kwargs.get, ('x', 'y', 'out')))
if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type=optype, inputs={'X': x,
'Y': y}, outputs={'Z': out}, attrs={})
return out
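# Builds a unary manipulation op, forwarding any of the 'shape', 'axis' and
# 'index' keyword arguments as op attributes.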
def _manipulation_unop(helper):
optype = helper.layer_type
x, out = tuple(map(helper.kwargs.get, ('x', 'out')))
attrs = {
k: helper.kwargs[k]
for k in ('shape', 'axis', 'index') if k in helper.kwargs
}
if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type=optype, inputs={'X': x}, outputs={'Y': out}, attrs=attrs)
return out
# Each primitive op is given a Python constructor for the sake of convenience.
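# For example, `add(x, y)` appends an `add_p` op to the current program and
# returns its output variable, so primitive graphs can be composed like
# ordinary Paddle layers.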
def fill_const(value, shape, dtype, out=None):
attrs = {'value': value, 'shape': shape, 'dtype': dtype}
helper = LayerHelper('fill_constant_p', **locals())
if out is None:
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(type=helper.layer_type, outputs={'Y': out}, attrs=attrs)
return out
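# neg is not a primitive op itself; it is composed from primitives as 0 - x.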
def neg(x, out=None):
zero = fill_const(0.0, x.shape, x.dtype)
return sub(zero, x)
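# set_value wraps the regular Paddle set_value op (not a `_p` primitive); the
# result is written in place, so `out` must be the same variable as `x`.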
def set_value(x, y, axis, starts, ends, strides, out):
assert x is out, "x and out should be the same Tensor in set_value"
attrs = {'axes': axis, 'starts': starts, 'ends': ends, 'steps': strides}
helper = LayerHelper('set_value', **locals())
helper.append_op(
type=helper.layer_type,
inputs={'Input': x,
'ValueTensor': y},
outputs={'Out': out},
attrs=attrs)
return out
@REGISTER_FN('add_p', 'X', 'Y', 'Z')
def add(x, y, out=None):
return _simple_binop(LayerHelper('add_p', **locals()))
@REGISTER_FN('sub_p', 'X', 'Y', 'Z')
def sub(x, y, out=None):
return _simple_binop(LayerHelper('sub_p', **locals()))
@REGISTER_FN('mul_p', 'X', 'Y', 'Z')
def mul(x, y, out=None):
return _simple_binop(LayerHelper('mul_p', **locals()))
@REGISTER_FN('div_p', 'X', 'Y', 'Z')
def div(x, y, out=None):
return _simple_binop(LayerHelper('div_p', **locals()))
@REGISTER_FN('sqrt_p', 'X', 'Y')
def sqrt(x, out=None):
return _simple_unop(LayerHelper('sqrt_p', **locals()))
@REGISTER_FN('tanh_p', 'X', 'Y')
def tanh(x, out=None):
return _simple_unop(LayerHelper('tanh_p', **locals()))
@REGISTER_FN('reshape_p', 'X', 'Y')
def reshape(x, shape, out=None):
return _manipulation_unop(LayerHelper('reshape_p', **locals()))
@REGISTER_FN('broadcast_p', 'X', 'Y')
def broadcast(x, shape, out=None):
return _manipulation_unop(LayerHelper('broadcast_p', **locals()))
@REGISTER_FN('transpose_p', 'X', 'Y')
def transpose(x, axis=None, out=None):
return _manipulation_unop(LayerHelper('transpose_p', **locals()))
@REGISTER_FN('split_p', 'X', 'YS')
def split(x, num_or_sections, axis=0, outs=None):
if isinstance(num_or_sections, (list, tuple)):
n = len(num_or_sections)
else:
assert isinstance(num_or_sections, int)
n = num_or_sections
attrs = {'num_or_sections': num_or_sections, 'axis': axis}
helper = LayerHelper('split_p', **locals())
if outs is None:
outs = [
helper.create_variable_for_type_inference(dtype=x.dtype)
for i in range(n)
]
helper.append_op(
type=helper.layer_type,
inputs={'X': x},
outputs={'YS': outs},
attrs=attrs)
return outs
@REGISTER_FN('concat_p', 'XS', 'Y')
def concat(xs, axis=0, out=None):
assert isinstance(xs, (list, tuple)) and len(xs) > 0
attrs = {'axis': axis}
helper = LayerHelper('concat_p', **locals())
if out is None:
out = helper.create_variable_for_type_inference(dtype=xs[0].dtype)
helper.append_op(
type=helper.layer_type,
inputs={'XS': xs},
outputs={'Y': out},
attrs=attrs)
return out
@REGISTER_FN('reduce_p', 'X', 'Y')
def reduce(x, axis, keepdim=False, out=None):
assert isinstance(axis, (tuple, list))
assert isinstance(keepdim, bool)
attrs = {'axis': axis, 'keepdim': keepdim}
helper = LayerHelper('reduce_p', **locals())
if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type=helper.layer_type,
inputs={'X': x},
outputs={'Y': out},
attrs=attrs)
return out
@REGISTER_FN('matmul_p', 'X', 'Y', 'Z')
def matmul(x, y, out=None):
return _simple_binop(LayerHelper('matmul_p', **locals()))
@REGISTER_FN('slice_select_p', 'X', 'Y')
def slice_select(x, axis, starts, ends, strides, out=None):
    assert isinstance(axis, (list, tuple)), (
        'Argument type error. `axis` is supposed to be a list or tuple '
        f'but found {type(axis)}.')
    assert isinstance(starts, (list, tuple))
    assert isinstance(ends, (list, tuple))
    assert isinstance(strides, (list, tuple))
    assert len(axis) == len(starts) == len(ends) == len(strides)
attrs = {'axis': axis, 'starts': starts, 'ends': ends, 'strides': strides}
helper = LayerHelper('slice_select_p', **locals())
if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type=helper.layer_type,
inputs={'X': x},
outputs={'Y': out},
attrs=attrs)
return out
@REGISTER_FN('slice_assign_p', 'X', 'Y', 'Z')
def slice_assign(x, y, axis, starts, ends, strides, out=None):
assert len(starts) == len(ends) == len(strides) == len(axis)
assert len(y.shape) == len(x.shape)
attrs = {'axis': axis, 'starts': starts, 'ends': ends, 'strides': strides}
helper = LayerHelper('slice_assign_p', **locals())
if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type=helper.layer_type,
inputs={'X': x,
'Y': y},
outputs={'Z': out},
attrs=attrs)
return out
@REGISTER_FN('gather_p', 'X', 'Y')
def gather(x, indextensor, axis, out=None):
attrs = {'axis': axis}
helper = LayerHelper('gather_p', **locals())
if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type=helper.layer_type,
inputs={'X': x,
'IndexTensor': indextensor},
outputs={'Y': out},
attrs=attrs)
return out
@REGISTER_FN('scatter_add_p', 'X', 'Y', 'IndexTensor', 'Z')
def scatter_add(x, y, indextensor, axis, out=None):
assert len(x.shape) == len(y.shape)
assert len(indextensor.shape) == 1
assert y.shape[axis] == indextensor.shape[0]
attrs = {'axis': axis}
helper = LayerHelper('scatter_add_p', **locals())
if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type=helper.layer_type,
inputs={'X': x,
'Y': y,
'IndexTensor': indextensor},
outputs={'Z': out},
attrs=attrs)
return out
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
class Registry(object):
""" A general registry object. """
__slots__ = ['name', 'tab']
def __init__(self, name):
self.name = name
self.tab = {}
def register(self, name, value):
assert name not in self.tab
self.tab[name] = value
def lookup(self, name):
assert name in self.tab, f'No registry entry is found with name: {name}'
return self.tab[name]
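# A minimal usage sketch (the `_demo` registry is hypothetical, for
# illustration only): each registry maps an op-type string to some payload.
#
#   _demo = Registry('demo')
#   _demo.register('add_p', lambda x, y: x + y)
#   fn = _demo.lookup('add_p')  # returns the registered callable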
_primop_fn = Registry('primop_fn')
_orig2prim = Registry('orig2prim')
_prim2orig = Registry('prim2orig')
_primop_jvp = Registry('primop_jvp')
_primop_transpose = Registry('primop_transpose')
_primop_position_argnames = Registry('primop_position_argnames')
def REGISTER_FN(op_type, *position_argnames):
"""Decorator for registering the Python function for a primitive op."""
assert isinstance(op_type, str)
_primop_position_argnames.register(op_type, position_argnames)
def wrapper(f):
_primop_fn.register(op_type, f)
return f
return wrapper
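# Usage sketch, matching how primops.py registers its wrappers:
#
#   @REGISTER_FN('add_p', 'X', 'Y', 'Z')
#   def add(x, y, out=None):
#       return _simple_binop(LayerHelper('add_p', **locals()))
#
# This records the positional argument names ('X', 'Y', 'Z') in
# _primop_position_argnames and the wrapper itself in _primop_fn.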
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.autograd.primops import (
neg, set_value, add, sub, mul, div, sqrt, tanh, reshape, broadcast,
transpose, split, concat, reduce, matmul, slice_select, slice_assign,
gather, scatter_add, fill_const)
class TestPyPrimOps(unittest.TestCase):
""" Test Python wrappers of primitive ops. """
def setUp(self):
paddle.enable_static()
def test_ops(self):
A = np.random.rand(1)
B = np.random.rand(2)
C = np.random.rand(2, 3)
D = np.random.rand(2, 3)
E = np.random.rand(3, 2)
a = paddle.static.data(name='A', shape=A.shape, dtype='float32')
b = paddle.static.data(name='B', shape=B.shape, dtype='float32')
c = paddle.static.data(name='C', shape=C.shape, dtype='float32')
d = paddle.static.data(name='D', shape=D.shape, dtype='float32')
e = paddle.static.data(name='E', shape=E.shape, dtype='float32')
add_1 = add(a, a)
self.assertEqual(add_1.dtype, a.dtype)
self.assertEqual(add_1.shape, a.shape)
add_2 = add(c, d)
self.assertEqual(add_2.dtype, c.dtype)
self.assertEqual(add_2.shape, c.shape)
sub_1 = sub(c, d)
self.assertEqual(sub_1.dtype, c.dtype)
self.assertEqual(sub_1.shape, c.shape)
mul_1 = mul(c, d)
self.assertEqual(mul_1.dtype, c.dtype)
self.assertEqual(mul_1.shape, c.shape)
div_1 = div(c, d)
self.assertEqual(div_1.dtype, c.dtype)
self.assertEqual(div_1.shape, c.shape)
sqrt_1 = sqrt(b)
self.assertEqual(sqrt_1.dtype, b.dtype)
self.assertEqual(sqrt_1.shape, b.shape)
tanh_1 = tanh(d)
self.assertEqual(tanh_1.dtype, d.dtype)
self.assertEqual(tanh_1.shape, d.shape)
reshape_1 = reshape(c, d.shape)
self.assertEqual(reshape_1.dtype, c.dtype)
self.assertEqual(reshape_1.shape, d.shape)
broadcast_1 = broadcast(b, e.shape)
self.assertEqual(broadcast_1.dtype, b.dtype)
self.assertEqual(broadcast_1.shape, e.shape)
transpose_1 = transpose(c, axis=[1, 0])
self.assertEqual(transpose_1.dtype, c.dtype)
self.assertEqual(transpose_1.shape, e.shape)
split_1_0, split_1_1 = split(c, num_or_sections=[1, 2], axis=1)
self.assertEqual(split_1_0.dtype, c.dtype)
self.assertEqual(split_1_0.shape, (2, 1))
self.assertEqual(split_1_1.shape, (2, 2))
concat_1 = concat([c, d], axis=0)
self.assertEqual(concat_1.dtype, c.dtype)
self.assertEqual(concat_1.shape, (4, 3))
reduce_1 = reduce(d, axis=[1])
self.assertEqual(reduce_1.dtype, d.dtype)
self.assertEqual(reduce_1.shape, (2, ))
reduce_2 = reduce(c, axis=[0, 1])
self.assertEqual(reduce_2.dtype, c.dtype)
self.assertEqual(reduce_2.shape, (1, ))
# TODO: reduce + keepdim
matmul_1 = matmul(d, e)
self.assertEqual(matmul_1.dtype, d.dtype)
self.assertEqual(matmul_1.shape, (2, 2))
slice_select_1 = slice_select(
e, axis=[0], starts=[0], ends=[2], strides=[1])
self.assertEqual(slice_select_1.dtype, e.dtype)
self.assertEqual(slice_select_1.shape, (2, 2))
slice_select_2 = slice_select(
d, axis=[0, 1], starts=[0, 1], ends=[2, 3], strides=[1, 2])
self.assertEqual(slice_select_2.dtype, d.dtype)
self.assertEqual(slice_select_2.shape, (2, 1))
y = broadcast(b, [2, 2])
slice_assign_1 = slice_assign(
d, y, axis=[1], starts=[1], ends=[3], strides=[1])
self.assertEqual(slice_assign_1.dtype, d.dtype)
self.assertEqual(slice_assign_1.shape, d.shape)
index = paddle.static.data('index', shape=[5], dtype='int32')
gather_1 = gather(e, index, axis=0)
self.assertEqual(gather_1.dtype, e.dtype)
self.assertEqual(gather_1.shape, (5, 2))
y = paddle.rand([5, 2], dtype='float32')
scatter_add_1 = scatter_add(e, y, index, axis=0)
self.assertEqual(scatter_add_1.dtype, e.dtype)
self.assertEqual(scatter_add_1.shape, e.shape)
fill_const_1 = fill_const(value=10, shape=a.shape, dtype=a.dtype)
self.assertEqual(fill_const_1.shape, a.shape)
self.assertEqual(fill_const_1.dtype, a.dtype)
neg_1 = neg(x=b)
self.assertEqual(neg_1.shape, b.shape)
self.assertEqual(neg_1.dtype, b.dtype)
set_value_1 = set_value(
d, a, axis=[1], starts=[1], ends=[3], strides=[1], out=d)
self.assertEqual(set_value_1.shape, d.shape)
self.assertEqual(set_value_1.dtype, d.dtype)
if __name__ == '__main__':
unittest.main()