未验证 提交 e4017e5c 编写于 作者: W wangchaochaohu 提交者: GitHub

refine the index_select Op for API 2.0 test=develop (#25296)

上级 c10dcff1
...@@ -83,7 +83,7 @@ class TestIndexSelectAPI(unittest.TestCase): ...@@ -83,7 +83,7 @@ class TestIndexSelectAPI(unittest.TestCase):
x = fluid.layers.data(name='x', shape=[-1, 4]) x = fluid.layers.data(name='x', shape=[-1, 4])
index = fluid.layers.data( index = fluid.layers.data(
name='index', shape=[3], dtype='int32', append_batch_size=False) name='index', shape=[3], dtype='int32', append_batch_size=False)
z = paddle.index_select(x, index, dim=1) z = paddle.index_select(x, index, axis=1)
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': self.data_x, res, = exe.run(feed={'x': self.data_x,
'index': self.data_index}, 'index': self.data_index},
...@@ -124,7 +124,7 @@ class TestIndexSelectAPI(unittest.TestCase): ...@@ -124,7 +124,7 @@ class TestIndexSelectAPI(unittest.TestCase):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(self.data_x) x = fluid.dygraph.to_variable(self.data_x)
index = fluid.dygraph.to_variable(self.data_index) index = fluid.dygraph.to_variable(self.data_index)
z = paddle.index_select(x, index, dim=1) z = paddle.index_select(x, index, axis=1)
np_z = z.numpy() np_z = z.numpy()
expect_out = np.array([[1.0, 2.0, 2.0], [5.0, 6.0, 6.0], expect_out = np.array([[1.0, 2.0, 2.0], [5.0, 6.0, 6.0],
[9.0, 10.0, 10.0]]) [9.0, 10.0, 10.0]])
......
...@@ -63,9 +63,9 @@ def argmax(input, axis=None, dtype=None, out=None, keepdims=False, name=None): ...@@ -63,9 +63,9 @@ def argmax(input, axis=None, dtype=None, out=None, keepdims=False, name=None):
Variable that meets the requirements to store the result of operation. Variable that meets the requirements to store the result of operation.
if out is None, a new Varibale will be create to store the result. Defalut is None. if out is None, a new Varibale will be create to store the result. Defalut is None.
keepdims(bool, optional): Keep the axis that do the select max. keepdims(bool, optional): Keep the axis that do the select max.
name(str, optional): The name of output variable, normally there is no need for user to set this this property. name(str, optional): The default value is None. Normally there is no
Default value is None, the framework set the name of output variable. need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns: Returns:
Variable: A Tensor with data type int64. Variable: A Tensor with data type int64.
...@@ -135,7 +135,7 @@ def argmax(input, axis=None, dtype=None, out=None, keepdims=False, name=None): ...@@ -135,7 +135,7 @@ def argmax(input, axis=None, dtype=None, out=None, keepdims=False, name=None):
return out return out
def index_select(input, index, dim=0): def index_select(x, index, axis=0, name=None):
""" """
:alias_main: paddle.index_select :alias_main: paddle.index_select
:alias: paddle.index_select,paddle.tensor.index_select,paddle.tensor.search.index_select :alias: paddle.index_select,paddle.tensor.index_select,paddle.tensor.search.index_select
...@@ -146,56 +146,60 @@ def index_select(input, index, dim=0): ...@@ -146,56 +146,60 @@ def index_select(input, index, dim=0):
size as the length of `index`; other dimensions have the same size as in the `input` tensor. size as the length of `index`; other dimensions have the same size as in the `input` tensor.
Args: Args:
input (Variable): The input tensor variable. x (Variable): The input tensor variable.The dtype of x can be one of float32, float64, int32, int64.
index (Variable): The 1-D tensor containing the indices to index. index (Variable): The 1-D tensor containing the indices to index.the dtype of index can be int32 or int64.
dim (int): The dimension in which we index. axis (int, optional): The dimension in which we index. Default: if None, the axis is 0.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns: Returns:
Variable: A Tensor with same data type as `input`. Variable: A Tensor with same data type as `input`.
Raises:
TypeError: x must be a Variable and the dtype of x must be one of float32, float64, int32 and int64.
TypeError: index must be a Variable adn the dtype of index must be int32 or int64.
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
import paddle.fluid as fluid
import numpy as np import numpy as np
paddle.enable_imperative() # Now we are in imperative mode
data = np.array([[1.0, 2.0, 3.0, 4.0], data = np.array([[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0], [5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0]]) [9.0, 10.0, 11.0, 12.0]])
data_index = np.array([0, 1, 1]).astype('int32') data_index = np.array([0, 1, 1]).astype('int32')
with fluid.dygraph.guard(): x = paddle.imperative.to_variable(data)
x = fluid.dygraph.to_variable(data) index = paddle.imperative.to_variable(data_index)
index = fluid.dygraph.to_variable(data_index) out_z1 = paddle.index_select(x=x, index=index)
out_z1 = paddle.index_select(x, index) #[[1. 2. 3. 4.]
print(out_z1.numpy()) # [5. 6. 7. 8.]
#[[1. 2. 3. 4.] # [5. 6. 7. 8.]]
# [5. 6. 7. 8.] out_z2 = paddle.index_select(x=x, index=index, axis=1)
# [5. 6. 7. 8.]] #[[ 1. 2. 2.]
out_z2 = paddle.index_select(x, index, dim=1) # [ 5. 6. 6.]
print(out_z2.numpy()) # [ 9. 10. 10.]]
#[[ 1. 2. 2.]
# [ 5. 6. 6.]
# [ 9. 10. 10.]]
""" """
helper = LayerHelper("index_select", **locals())
if in_dygraph_mode(): if in_dygraph_mode():
return core.ops.index_select(input, index, 'dim', dim) return core.ops.index_select(x, index, 'dim', axis)
check_variable_and_dtype(input, 'x', helper = LayerHelper("index_select", **locals())
['float32', 'float64', 'int32', 'int64'], check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
'paddle.tensor.search.index_sample') 'paddle.tensor.search.index_select')
check_variable_and_dtype(index, 'index', ['int32', 'int64'], check_variable_and_dtype(index, 'index', ['int32', 'int64'],
'paddle.tensor.search.index_sample') 'paddle.tensor.search.index_select')
out = helper.create_variable_for_type_inference(input.dtype) out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op( helper.append_op(
type='index_select', type='index_select',
inputs={'X': input, inputs={'X': x,
'Index': index}, 'Index': index},
outputs={'Out': out}, outputs={'Out': out},
attrs={'dim': dim}) attrs={'dim': axis})
return out return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册