diff --git a/python/paddle/fluid/tests/unittests/test_index_sample_op.py b/python/paddle/fluid/tests/unittests/test_index_sample_op.py index bd71ca0c1c9e795a529fb12cab5c12a7478c9ba4..f640c0531192d65a686e0f21be5bedb9eb0497fb 100644 --- a/python/paddle/fluid/tests/unittests/test_index_sample_op.py +++ b/python/paddle/fluid/tests/unittests/test_index_sample_op.py @@ -15,6 +15,8 @@ from __future__ import print_function import unittest +import paddle +import paddle.fluid as fluid import numpy as np from op_test import OpTest @@ -98,9 +100,7 @@ class TestCase4(TestIndexSampleOp): class TestIndexSampleShape(unittest.TestCase): def test_shape(self): - import paddle.fluid as fluid - import paddle - + paddle.enable_static() # create x value x_shape = (2, 5) x_type = "float64" @@ -124,5 +124,22 @@ class TestIndexSampleShape(unittest.TestCase): res = exe.run(feed=feed, fetch_list=[output]) +class TestIndexSampleDynamic(unittest.TestCase): + def test_result(self): + with fluid.dygraph.guard(): + x = paddle.to_tensor( + [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0]], + dtype='float32') + index = paddle.to_tensor( + [[0, 1, 2], [1, 2, 3], [0, 0, 0]], dtype='int32') + out_z1 = paddle.index_sample(x, index) + + except_output = np.array( + [[1.0, 2.0, 3.0], [6.0, 7.0, 8.0], [9.0, 9.0, 9.0]]) + assert out_z1.numpy().all() == except_output.all() + + if __name__ == "__main__": + paddle.enable_static() unittest.main() diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 19d8fc58b0e7e7162c777ac1a56c3b9c5ac08283..7f722d1957b96bc4fae8414ce882c7e05489b165 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -570,9 +570,6 @@ def where(condition, x, y, name=None): def index_sample(x, index): """ - :alias_main: paddle.index_sample - :alias: paddle.index_sample,paddle.tensor.index_sample,paddle.tensor.search.index_sample - **IndexSample Layer** IndexSample OP returns the element of the specified location of X, @@ -595,13 +592,13 @@ def index_sample(x, index): [6, 8, 10]] Args: - x (Variable): The source input tensor with 2-D shape. Supported data type is + x (Tensor): The source input tensor with 2-D shape. Supported data type is int32, int64, float32, float64. - index (Variable): The index input tensor with 2-D shape, first dimension should be same with X. + index (Tensor): The index input tensor with 2-D shape, first dimension should be same with X. Data type is int32 or int64. Returns: - output (Variable): The output is a tensor with the same shape as index. + output (Tensor): The output is a tensor with the same shape as index. Examples: @@ -609,7 +606,6 @@ def index_sample(x, index): import paddle - paddle.disable_static() x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]], dtype='float32') @@ -644,8 +640,10 @@ def index_sample(x, index): # [ 800 700] # [1200 1100]] - """ + if in_dygraph_mode(): + return core.ops.index_sample(x, index) + helper = LayerHelper("index_sample", **locals()) check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'], 'paddle.tensor.search.index_sample')