未验证 提交 5f1a8e46 编写于 作者: Skr.B's avatar Skr.B 提交者: GitHub

【PaddlePaddle Hackathon 3 No.16】为 Paddle 新增 API paddle.take (#44741)

上级 871e3329
......@@ -280,6 +280,7 @@ from .tensor.math import outer # noqa: F401
from .tensor.math import heaviside # noqa: F401
from .tensor.math import frac # noqa: F401
from .tensor.math import sgn # noqa: F401
from .tensor.math import take # noqa: F401
from .tensor.random import bernoulli # noqa: F401
from .tensor.random import poisson # noqa: F401
......@@ -656,4 +657,5 @@ __all__ = [ # noqa
'tril_indices',
'sgn',
'triu_indices',
'take',
]
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
class TestTakeAPI(unittest.TestCase):
def set_mode(self):
self.mode = 'raise'
def set_dtype(self):
self.input_dtype = 'float64'
self.index_dtype = 'int64'
def set_input(self):
self.input_shape = [3, 4]
self.index_shape = [2, 3]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-4, 2).reshape(self.index_shape).astype(
self.index_dtype)
def setUp(self):
self.set_mode()
self.set_dtype()
self.set_input()
self.place = fluid.CUDAPlace(
0) if core.is_compiled_with_cuda() else fluid.CPUPlace()
def test_static_graph(self):
paddle.enable_static()
startup_program = Program()
train_program = Program()
with program_guard(startup_program, train_program):
x = fluid.data(name='input',
dtype=self.input_dtype,
shape=self.input_shape)
index = fluid.data(name='index',
dtype=self.index_dtype,
shape=self.index_shape)
out = paddle.take(x, index, mode=self.mode)
exe = fluid.Executor(self.place)
st_result = exe.run(fluid.default_main_program(),
feed={
'input': self.input_np,
'index': self.index_np
},
fetch_list=out)
np.testing.assert_allclose(
st_result[0],
np.take(self.input_np, self.index_np, mode=self.mode))
def test_dygraph(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.input_np)
index = paddle.to_tensor(self.index_np)
dy_result = paddle.take(x, index, mode=self.mode)
np.testing.assert_allclose(
np.take(self.input_np, self.index_np, mode=self.mode),
dy_result.numpy())
class TestTakeInt32(TestTakeAPI):
"""Test take API with data type int32"""
def set_dtype(self):
self.input_dtype = 'int32'
self.index_dtype = 'int64'
class TestTakeInt64(TestTakeAPI):
"""Test take API with data type int64"""
def set_dtype(self):
self.input_dtype = 'int64'
self.index_dtype = 'int64'
class TestTakeFloat32(TestTakeAPI):
"""Test take API with data type float32"""
def set_dtype(self):
self.input_dtype = 'float32'
self.index_dtype = 'int64'
class TestTakeTypeError(TestTakeAPI):
"""Test take Type Error"""
def test_static_type_error(self):
"""Argument 'index' must be Tensor"""
paddle.enable_static()
with program_guard(Program()):
x = fluid.data(name='input',
dtype=self.input_dtype,
shape=self.input_shape)
self.assertRaises(TypeError, paddle.take, x, self.index_np,
self.mode)
def test_dygraph_type_error(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.input_np)
self.assertRaises(TypeError, paddle.take, x, self.index_np, self.mode)
def test_static_dtype_error(self):
"""Data type of argument 'index' must be in [paddle.int32, paddle.int64]"""
paddle.enable_static()
with program_guard(Program()):
x = fluid.data(name='input',
dtype='float64',
shape=self.input_shape)
index = fluid.data(name='index',
dtype='float32',
shape=self.index_shape)
self.assertRaises(TypeError, paddle.take, x, index, self.mode)
def test_dygraph_dtype_error(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.input_np)
index = paddle.to_tensor(self.index_np, dtype='float32')
self.assertRaises(TypeError, paddle.take, x, index, self.mode)
class TestTakeModeRaisePos(unittest.TestCase):
"""Test positive index out of range error"""
def set_mode(self):
self.mode = 'raise'
def set_dtype(self):
self.input_dtype = 'float64'
self.index_dtype = 'int64'
def set_input(self):
self.input_shape = [3, 4]
self.index_shape = [5, 6]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-10, 20).reshape(self.index_shape).astype(
self.index_dtype) # positive indices are out of range
def setUp(self):
self.set_mode()
self.set_dtype()
self.set_input()
self.place = fluid.CUDAPlace(
0) if core.is_compiled_with_cuda() else fluid.CPUPlace()
def test_static_index_error(self):
"""When the index is out of range,
an error is reported directly through `paddle.index_select`"""
paddle.enable_static()
with program_guard(Program()):
x = fluid.data(name='input',
dtype=self.input_dtype,
shape=self.input_shape)
index = fluid.data(name='index',
dtype=self.index_dtype,
shape=self.index_shape)
self.assertRaises(ValueError, paddle.index_select, x, index)
def test_dygraph_index_error(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.input_np)
index = paddle.to_tensor(self.index_np, dtype=self.index_dtype)
self.assertRaises(ValueError, paddle.index_select, x, index)
class TestTakeModeRaiseNeg(TestTakeModeRaisePos):
"""Test negative index out of range error"""
def set_mode(self):
self.mode = 'raise'
def set_dtype(self):
self.input_dtype = 'float64'
self.index_dtype = 'int64'
def set_input(self):
self.input_shape = [3, 4]
self.index_shape = [5, 6]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-20, 10).reshape(self.index_shape).astype(
self.index_dtype) # negative indices are out of range
def setUp(self):
self.set_mode()
self.set_dtype()
self.set_input()
self.place = fluid.CUDAPlace(
0) if core.is_compiled_with_cuda() else fluid.CPUPlace()
class TestTakeModeWrap(TestTakeAPI):
"""Test take index out of range mode"""
def set_mode(self):
self.mode = 'wrap'
def set_input(self):
self.input_shape = [3, 4]
self.index_shape = [5, 8]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-20, 20).reshape(self.index_shape).astype(
self.index_dtype) # Both ends of the index are out of bounds
class TestTakeModeClip(TestTakeAPI):
"""Test take index out of range mode"""
def set_mode(self):
self.mode = 'clip'
def set_input(self):
self.input_shape = [3, 4]
self.index_shape = [5, 8]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-20, 20).reshape(self.index_shape).astype(
self.index_dtype) # Both ends of the index are out of bounds
if __name__ == "__main__":
unittest.main()
......@@ -234,6 +234,7 @@ from .math import outer # noqa: F401
from .math import heaviside # noqa: F401
from .math import frac # noqa: F401
from .math import sgn # noqa: F401
from .math import take # noqa: F401
from .random import multinomial # noqa: F401
from .random import standard_normal # noqa: F401
......@@ -280,8 +281,8 @@ from .array import create_array # noqa: F401
from .einsum import einsum # noqa: F401
#this list used in math_op_patch.py for _binary_creator_
tensor_method_func = [ #noqa
# this list used in math_op_patch.py for _binary_creator_
tensor_method_func = [ # noqa
'matmul',
'dot',
'cov',
......@@ -505,11 +506,12 @@ tensor_method_func = [ #noqa
'put_along_axis_',
'exponential_',
'heaviside',
'take',
'bucketize',
'sgn',
]
#this list used in math_op_patch.py for magic_method bind
# this list used in math_op_patch.py for magic_method bind
magic_method_func = [
('__and__', 'bitwise_and'),
('__or__', 'bitwise_or'),
......
......@@ -4748,7 +4748,6 @@ def frac(x, name=None):
type="trunc", inputs=inputs, attrs=attrs, outputs={"Out": y})
return _elementwise_op(LayerHelper(op_type, **locals()))
def sgn(x, name=None):
"""
For complex tensor, this API returns a new tensor whose elements have the same angles as the corresponding
......@@ -4789,3 +4788,105 @@ def sgn(x, name=None):
return paddle.as_complex(output)
else:
return paddle.sign(x)
def take(x, index, mode='raise', name=None):
"""
Returns a new tensor with the elements of input tensor x at the given index.
The input tensor is treated as if it were viewed as a 1-D tensor.
The result takes the same shape as the index.
Args:
x (Tensor): An N-D Tensor, its data type should be int32, int64, float32, float64.
index (Tensor): An N-D Tensor, its data type should be int32, int64.
mode (str, optional): Specifies how out-of-bounds index will behave. the candicates are ``'raise'``, ``'wrap'`` and ``'clip'``.
- ``'raise'``: raise an error (default);
- ``'wrap'``: wrap around;
- ``'clip'``: clip to the range. ``'clip'`` mode means that all indices that are too large are replaced by the index that addresses the last element. Note that this disables indexing with negative numbers.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor, Tensor with the same shape as index, the data type is the same with input.
Examples:
.. code-block:: python
import paddle
x_int = paddle.arange(0, 12).reshape([3, 4])
x_float = x_int.astype(paddle.float64)
idx_pos = paddle.arange(4, 10).reshape([2, 3]) # positive index
idx_neg = paddle.arange(-2, 4).reshape([2, 3]) # negative index
idx_err = paddle.arange(-2, 13).reshape([3, 5]) # index out of range
paddle.take(x_int, idx_pos)
# Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
# [[4, 5, 6],
# [7, 8, 9]])
paddle.take(x_int, idx_neg)
# Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
# [[10, 11, 0 ],
# [1 , 2 , 3 ]])
paddle.take(x_float, idx_pos)
# Tensor(shape=[2, 3], dtype=float64, place=Place(cpu), stop_gradient=True,
# [[4., 5., 6.],
# [7., 8., 9.]])
x_int.take(idx_pos)
# Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
# [[4, 5, 6],
# [7, 8, 9]])
paddle.take(x_int, idx_err, mode='wrap')
# Tensor(shape=[3, 5], dtype=int32, place=Place(cpu), stop_gradient=True,
# [[10, 11, 0 , 1 , 2 ],
# [3 , 4 , 5 , 6 , 7 ],
# [8 , 9 , 10, 11, 0 ]])
paddle.take(x_int, idx_err, mode='clip')
# Tensor(shape=[3, 5], dtype=int32, place=Place(cpu), stop_gradient=True,
# [[0 , 0 , 0 , 1 , 2 ],
# [3 , 4 , 5 , 6 , 7 ],
# [8 , 9 , 10, 11, 11]])
"""
if mode not in ['raise', 'wrap', 'clip']:
raise ValueError(
"'mode' in 'take' should be 'raise', 'wrap', 'clip', but received {}.".format(mode))
if paddle.in_dynamic_mode():
if not isinstance(index, (paddle.Tensor, Variable)):
raise TypeError(
"The type of 'index' must be Tensor, but got {}".format(type(index)))
if index.dtype not in [paddle.int32, paddle.int64]:
raise TypeError(
"The data type of 'index' must be one of ['int32', 'int64'], but got {}".format(
index.dtype))
else:
check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'take')
input_1d = x.flatten()
index_1d = index.flatten()
max_index = input_1d.shape[-1]
if mode == 'raise':
# This processing enables 'take' to handle negative indexes within the correct range.
index_1d = paddle.where(index_1d < 0, index_1d + max_index, index_1d)
elif mode == 'wrap':
# The out of range indices are constrained by taking the remainder.
index_1d = paddle.where(index_1d < 0,
index_1d % max_index, index_1d)
index_1d = paddle.where(index_1d >= max_index,
index_1d % max_index, index_1d)
elif mode == 'clip':
# 'clip' mode disables indexing with negative numbers.
index_1d = clip(index_1d, 0, max_index - 1)
out = input_1d.index_select(index_1d).reshape(index.shape)
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册