Unverified commit 2f4c089b authored by andyjpaddle, committed by GitHub

Add diff op (#37441)

* add diff op, test=develop

* rm some notes, test=develop

* update diff doc

* update sample code

* fix diff api params and example code, test=develop
Parent 0c8b9994
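For context, a minimal usage sketch of the API this commit adds (it mirrors the docstring examples further down and assumes a Paddle build that includes this change):

import paddle

x = paddle.to_tensor([1, 4, 5, 2])
out = paddle.diff(x)  # first-order forward difference along the last axis
print(out)  # [3, 1, -3]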
@@ -223,6 +223,7 @@ from .tensor.math import trunc # noqa: F401
from .tensor.math import digamma # noqa: F401
from .tensor.math import neg # noqa: F401
from .tensor.math import lgamma # noqa: F401
from .tensor.math import diff # noqa: F401
from .tensor.random import multinomial # noqa: F401
from .tensor.random import standard_normal # noqa: F401
@@ -531,5 +532,6 @@ __all__ = [ # noqa
'broadcast_tensors',
'einsum',
'set_flags',
'get_flags'
'get_flags',
'diff'
]
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
class TestDiffOp(unittest.TestCase):
def set_args(self):
self.input = np.array([1, 4, 5, 2]).astype('float32')
self.n = 1
self.axis = -1
self.prepend = None
self.append = None
def get_output(self):
if self.prepend is not None and self.append is not None:
self.output = np.diff(
self.input,
n=self.n,
axis=self.axis,
prepend=self.prepend,
append=self.append)
elif self.prepend is not None:
self.output = np.diff(
self.input, n=self.n, axis=self.axis, prepend=self.prepend)
elif self.append is not None:
self.output = np.diff(
self.input, n=self.n, axis=self.axis, append=self.append)
else:
self.output = np.diff(self.input, n=self.n, axis=self.axis)
def setUp(self):
self.set_args()
self.get_output()
self.places = [paddle.CPUPlace()]
if core.is_compiled_with_cuda():
self.places.append(paddle.CUDAPlace(0))
def test_dygraph(self):
for place in self.places:
paddle.disable_static(place)
x = paddle.to_tensor(self.input, place=place)
if self.prepend is not None:
self.prepend = paddle.to_tensor(self.prepend, place=place)
if self.append is not None:
self.append = paddle.to_tensor(self.append, place=place)
out = paddle.diff(
x,
n=self.n,
axis=self.axis,
prepend=self.prepend,
append=self.append)
            self.assertTrue((out.numpy() == self.output).all())
def test_static(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
with fluid.program_guard(fluid.Program(), fluid.Program()):
x = paddle.fluid.data(
name="input",
shape=self.input.shape,
dtype=self.input.dtype)
                prepend = None
                append = None
                if self.prepend is not None:
                    prepend = paddle.fluid.data(
                        name="prepend",
                        shape=self.prepend.shape,
                        dtype=self.prepend.dtype)
                if self.append is not None:
                    append = paddle.fluid.data(
                        name="append",
                        shape=self.append.shape,
                        dtype=self.append.dtype)
                exe = fluid.Executor(place)
                out = paddle.diff(
                    x, n=self.n, axis=self.axis, prepend=prepend, append=append)
                # Only feed prepend/append when the matching placeholders were
                # created above; feeding names absent from the program would
                # only trigger executor warnings.
                feed = {"input": self.input}
                if self.prepend is not None:
                    feed["prepend"] = self.prepend
                if self.append is not None:
                    feed["append"] = self.append
                fetches = exe.run(fluid.default_main_program(),
                                  feed=feed,
                                  fetch_list=[out])
                self.assertTrue((fetches[0] == self.output).all())
def test_grad(self):
for place in self.places:
x = paddle.to_tensor(self.input, place=place, stop_gradient=False)
if self.prepend is not None:
self.prepend = paddle.to_tensor(self.prepend, place=place)
if self.append is not None:
self.append = paddle.to_tensor(self.append, place=place)
out = paddle.diff(
x,
n=self.n,
axis=self.axis,
prepend=self.prepend,
append=self.append)
            try:
                out.backward()
                x_grad = x.grad
            except Exception:
                raise RuntimeError("Check Diff Gradient Failed")
class TestDiffOpAxis(TestDiffOp):
def set_args(self):
self.input = np.array([[1, 4, 5, 2], [1, 5, 4, 2]]).astype('float32')
self.n = 1
self.axis = 0
self.prepend = None
self.append = None
class TestDiffOpNDim(TestDiffOp):
def set_args(self):
self.input = np.random.rand(10, 10).astype('float32')
self.n = 1
self.axis = -1
self.prepend = None
self.append = None
class TestDiffOpBool(TestDiffOp):
def set_args(self):
self.input = np.array([0, 1, 1, 0, 1, 0]).astype('bool')
self.n = 1
self.axis = -1
self.prepend = None
self.append = None
class TestDiffOpPrepend(TestDiffOp):
def set_args(self):
self.input = np.array([[1, 4, 5, 2], [1, 5, 4, 2]]).astype('float32')
self.n = 1
self.axis = -1
self.prepend = np.array([[2, 3, 4], [1, 3, 5]]).astype('float32')
self.append = None
class TestDiffOpPrependAxis(TestDiffOp):
def set_args(self):
self.input = np.array([[1, 4, 5, 2], [1, 5, 4, 2]]).astype('float32')
self.n = 1
self.axis = 0
self.prepend = np.array(
[[0, 2, 3, 4], [1, 3, 5, 7], [2, 5, 8, 0]]).astype('float32')
self.append = None
class TestDiffOpAppend(TestDiffOp):
def set_args(self):
self.input = np.array([[1, 4, 5, 2], [1, 5, 4, 2]]).astype('float32')
self.n = 1
self.axis = -1
self.prepend = None
self.append = np.array([[2, 3, 4], [1, 3, 5]]).astype('float32')
class TestDiffOpAppendAxis(TestDiffOp):
def set_args(self):
self.input = np.array([[1, 4, 5, 2], [1, 5, 4, 2]]).astype('float32')
self.n = 1
self.axis = 0
self.prepend = None
self.append = np.array([[2, 3, 4, 1]]).astype('float32')
class TestDiffOpPreAppend(TestDiffOp):
def set_args(self):
self.input = np.array([[1, 4, 5, 2], [1, 5, 4, 2]]).astype('float32')
self.n = 1
self.axis = -1
self.prepend = np.array([[0, 4], [5, 9]]).astype('float32')
self.append = np.array([[2, 3, 4], [1, 3, 5]]).astype('float32')
class TestDiffOpPreAppendAxis(TestDiffOp):
def set_args(self):
self.input = np.array([[1, 4, 5, 2], [1, 5, 4, 2]]).astype('float32')
self.n = 1
self.axis = 0
self.prepend = np.array([[0, 4, 5, 9], [5, 9, 2, 3]]).astype('float32')
self.append = np.array([[2, 3, 4, 7], [1, 3, 5, 6]]).astype('float32')
if __name__ == '__main__':
unittest.main()
@@ -189,6 +189,7 @@ from .math import digamma # noqa: F401
from .math import neg # noqa: F401
from .math import lgamma # noqa: F401
from .math import diagonal # noqa: F401
from .math import diff # noqa: F401
from .random import multinomial # noqa: F401
from .random import standard_normal # noqa: F401
@@ -400,7 +401,8 @@ tensor_method_func = [ #noqa
'uniform_',
'multi_dot',
'solve',
'triangular_solve'
'triangular_solve',
'diff'
]
# this list is used in math_op_patch.py for magic_method bind
@@ -2611,3 +2611,166 @@ def atan2(x, y, name=None):
helper.append_op(
type='atan2', inputs=inputs, outputs={'Out': out})
return out
def diff(x, n=1, axis=-1, prepend=None, append=None, name=None):
r"""
Computes the n-th forward difference along the given axis.
    The first-order difference is computed using the following formula:
.. math::
out[i] = x[i+1] - x[i]
Higher-order differences are computed by using paddle.diff() recursively.
Only n=1 is currently supported.
    Args:
        x(Tensor): The input tensor to compute the forward difference on.
        n(int, optional): The number of times to recursively compute the difference.
            Only n=1 is currently supported. Default: 1.
        axis(int, optional): The axis to compute the difference along. Default: -1.
        prepend(Tensor, optional): The tensor to prepend to x along axis before computing the difference.
            Its number of dimensions must match that of x, and its shape must
            match x's shape except along axis. Default: None.
        append(Tensor, optional): The tensor to append to x along axis before computing the difference.
            Its number of dimensions must match that of x, and its shape must
            match x's shape except along axis. Default: None.
        name(str, optional): A name for this operation. If None, the name is
            generated automatically.

    Returns:
        Tensor: The output tensor with the same dtype as x.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([1, 4, 5, 2])
out = paddle.diff(x)
print(out)
# out:
# [3, 1, -3]
y = paddle.to_tensor([7, 9])
out = paddle.diff(x, append=y)
print(out)
# out:
# [3, 1, -3, 5, 2]
z = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
out = paddle.diff(z, axis=0)
print(out)
# out:
# [[3, 3, 3]]
out = paddle.diff(z, axis=1)
print(out)
# out:
# [[1, 1], [1, 1]]
"""
    # Normalize a negative axis and clamp it into the valid range.
    if axis < 0:
        axis = axis + len(x.shape)
    if axis > len(x.shape):
        axis = len(x.shape)
    if axis < 0:
        axis = 0
dtype = x.dtype
axes = [axis]
    infer_flags = [1] * len(axes)
if in_dygraph_mode():
has_pend = False
input_list = []
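        # If prepend and/or append are given, concatenate them with x along
        # axis before computing the difference.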
if prepend is not None and append is not None:
input_list = [prepend, x, append]
has_pend = True
elif prepend is not None:
input_list = [prepend, x]
has_pend = True
elif append is not None:
input_list = [x, append]
has_pend = True
if has_pend:
new_input = _C_ops.concat(input_list, 'axis', axis)
else:
new_input = x
attrs_1 = ()
attrs_2 = ()
dim_len = new_input.shape[axis]
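        # Slice the 'front' view new_input[..., :-1] and the 'back' view
        # new_input[..., 1:] along axis; their elementwise difference is the result.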
starts_1 = [0]
attrs_1 += ('starts', starts_1)
ends_1 = [dim_len - 1]
attrs_1 += ('ends', ends_1)
input_front = _C_ops.slice(new_input, None, None, 'axes', axes, \
'infer_flags', infer_flags, *attrs_1)
starts_2 = [1]
attrs_2 += ('starts', starts_2)
ends_2 = [dim_len]
attrs_2 += ('ends', ends_2)
input_back = _C_ops.slice(new_input, None, None, 'axes', axes, \
'infer_flags', infer_flags, *attrs_2)
        if x.dtype == paddle.bool:
            # Subtraction is not defined for bool tensors, so for bool input
            # the forward difference is computed as elementwise XOR.
            out = _C_ops.logical_xor(input_back, input_front)
        else:
            out = layers.elementwise_sub(input_back, input_front, axis=axis)
        return out
else:
check_variable_and_dtype(x, 'x', ['float32', 'float64', 'bool', 'int32', 'int64'], 'diff')
check_type(axis, 'axis', (int), 'diff')
helper = LayerHelper('diff', **locals())
has_pend = False
input_list = []
if prepend is not None and append is not None:
input_list = [prepend, x, append]
has_pend = True
elif prepend is not None:
input_list = [prepend, x]
has_pend = True
elif append is not None:
input_list = [x, append]
has_pend = True
if has_pend:
new_input = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='concat', inputs={'X': input_list}, outputs={'Out': [new_input]}, attrs={'axis': axis}
)
else:
new_input = x
dim_len = new_input.shape[axis]
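        # Same slice-and-subtract scheme as the dygraph branch, expressed by
        # appending slice ops to the program.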
attrs_1 = {'axes': axes}
starts_1 = [0]
ends_1 = [dim_len - 1]
attrs_1['starts'] = starts_1
attrs_1['ends'] = ends_1
input_front = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='slice', inputs={'Input': new_input}, attrs=attrs_1, outputs={'Out': input_front}
)
attrs_2 = {'axes': axes}
starts_2 = [1]
ends_2 = [dim_len]
attrs_2['starts'] = starts_2
attrs_2['ends'] = ends_2
input_back = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='slice', inputs={'Input': new_input}, attrs=attrs_2, outputs={'Out': input_back}
)
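        # As in the dygraph branch, bool inputs use logical_xor in place of
        # subtraction.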
if dtype == paddle.bool:
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='logical_xor', inputs={"X": input_back, "Y": input_front}, outputs={"Out": out}
)
else:
out = layers.elementwise_sub(input_back, input_front, axis=axis)
return out
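For reference, a minimal NumPy sketch of the same slice-and-subtract scheme (diff_reference is a hypothetical name for illustration; first-order only, bool/XOR handling omitted):

import numpy as np

def diff_reference(x, axis=-1, prepend=None, append=None):
    # Concatenate optional prepend/append along axis, then subtract the
    # shifted views: out = x[1:] - x[:-1] taken along axis.
    parts = [p for p in (prepend, x, append) if p is not None]
    x = np.concatenate(parts, axis=axis)
    front = np.take(x, range(0, x.shape[axis] - 1), axis=axis)
    back = np.take(x, range(1, x.shape[axis]), axis=axis)
    return back - front

print(diff_reference(np.array([1, 4, 5, 2])))  # [ 3  1 -3]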