未验证 提交 9637d963 编写于 作者: C Chengmo 提交者: GitHub

update index sample (#27839)

* update index sample
上级 6d63cd2b
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import paddle
import paddle.fluid as fluid
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
...@@ -98,9 +100,7 @@ class TestCase4(TestIndexSampleOp): ...@@ -98,9 +100,7 @@ class TestCase4(TestIndexSampleOp):
class TestIndexSampleShape(unittest.TestCase): class TestIndexSampleShape(unittest.TestCase):
def test_shape(self): def test_shape(self):
import paddle.fluid as fluid paddle.enable_static()
import paddle
# create x value # create x value
x_shape = (2, 5) x_shape = (2, 5)
x_type = "float64" x_type = "float64"
...@@ -124,5 +124,22 @@ class TestIndexSampleShape(unittest.TestCase): ...@@ -124,5 +124,22 @@ class TestIndexSampleShape(unittest.TestCase):
res = exe.run(feed=feed, fetch_list=[output]) res = exe.run(feed=feed, fetch_list=[output])
class TestIndexSampleDynamic(unittest.TestCase):
def test_result(self):
with fluid.dygraph.guard():
x = paddle.to_tensor(
[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0]],
dtype='float32')
index = paddle.to_tensor(
[[0, 1, 2], [1, 2, 3], [0, 0, 0]], dtype='int32')
out_z1 = paddle.index_sample(x, index)
except_output = np.array(
[[1.0, 2.0, 3.0], [6.0, 7.0, 8.0], [9.0, 9.0, 9.0]])
assert out_z1.numpy().all() == except_output.all()
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static()
unittest.main() unittest.main()
...@@ -570,9 +570,6 @@ def where(condition, x, y, name=None): ...@@ -570,9 +570,6 @@ def where(condition, x, y, name=None):
def index_sample(x, index): def index_sample(x, index):
""" """
:alias_main: paddle.index_sample
:alias: paddle.index_sample,paddle.tensor.index_sample,paddle.tensor.search.index_sample
**IndexSample Layer** **IndexSample Layer**
IndexSample OP returns the element of the specified location of X, IndexSample OP returns the element of the specified location of X,
...@@ -595,13 +592,13 @@ def index_sample(x, index): ...@@ -595,13 +592,13 @@ def index_sample(x, index):
[6, 8, 10]] [6, 8, 10]]
Args: Args:
x (Variable): The source input tensor with 2-D shape. Supported data type is x (Tensor): The source input tensor with 2-D shape. Supported data type is
int32, int64, float32, float64. int32, int64, float32, float64.
index (Variable): The index input tensor with 2-D shape, first dimension should be same with X. index (Tensor): The index input tensor with 2-D shape, first dimension should be same with X.
Data type is int32 or int64. Data type is int32 or int64.
Returns: Returns:
output (Variable): The output is a tensor with the same shape as index. output (Tensor): The output is a tensor with the same shape as index.
Examples: Examples:
...@@ -609,7 +606,6 @@ def index_sample(x, index): ...@@ -609,7 +606,6 @@ def index_sample(x, index):
import paddle import paddle
paddle.disable_static()
x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0], x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0], [5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0]], dtype='float32') [9.0, 10.0, 11.0, 12.0]], dtype='float32')
...@@ -644,8 +640,10 @@ def index_sample(x, index): ...@@ -644,8 +640,10 @@ def index_sample(x, index):
# [ 800 700] # [ 800 700]
# [1200 1100]] # [1200 1100]]
""" """
if in_dygraph_mode():
return core.ops.index_sample(x, index)
helper = LayerHelper("index_sample", **locals()) helper = LayerHelper("index_sample", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'], check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
'paddle.tensor.search.index_sample') 'paddle.tensor.search.index_sample')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册