未验证 提交 9637d963 编写于 作者: C Chengmo 提交者: GitHub

update index sample (#27839)

* update index sample
上级 6d63cd2b
......@@ -15,6 +15,8 @@
from __future__ import print_function
import unittest
import paddle
import paddle.fluid as fluid
import numpy as np
from op_test import OpTest
......@@ -98,9 +100,7 @@ class TestCase4(TestIndexSampleOp):
class TestIndexSampleShape(unittest.TestCase):
def test_shape(self):
import paddle.fluid as fluid
import paddle
paddle.enable_static()
# create x value
x_shape = (2, 5)
x_type = "float64"
......@@ -124,5 +124,22 @@ class TestIndexSampleShape(unittest.TestCase):
res = exe.run(feed=feed, fetch_list=[output])
class TestIndexSampleDynamic(unittest.TestCase):
def test_result(self):
with fluid.dygraph.guard():
x = paddle.to_tensor(
[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0]],
dtype='float32')
index = paddle.to_tensor(
[[0, 1, 2], [1, 2, 3], [0, 0, 0]], dtype='int32')
out_z1 = paddle.index_sample(x, index)
except_output = np.array(
[[1.0, 2.0, 3.0], [6.0, 7.0, 8.0], [9.0, 9.0, 9.0]])
assert out_z1.numpy().all() == except_output.all()
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
......@@ -570,9 +570,6 @@ def where(condition, x, y, name=None):
def index_sample(x, index):
"""
:alias_main: paddle.index_sample
:alias: paddle.index_sample,paddle.tensor.index_sample,paddle.tensor.search.index_sample
**IndexSample Layer**
IndexSample OP returns the element of the specified location of X,
......@@ -595,13 +592,13 @@ def index_sample(x, index):
[6, 8, 10]]
Args:
x (Variable): The source input tensor with 2-D shape. Supported data type is
x (Tensor): The source input tensor with 2-D shape. Supported data type is
int32, int64, float32, float64.
index (Variable): The index input tensor with 2-D shape, first dimension should be same with X.
index (Tensor): The index input tensor with 2-D shape, first dimension should be same with X.
Data type is int32 or int64.
Returns:
output (Variable): The output is a tensor with the same shape as index.
output (Tensor): The output is a tensor with the same shape as index.
Examples:
......@@ -609,7 +606,6 @@ def index_sample(x, index):
import paddle
paddle.disable_static()
x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0]], dtype='float32')
......@@ -644,8 +640,10 @@ def index_sample(x, index):
# [ 800 700]
# [1200 1100]]
"""
if in_dygraph_mode():
return core.ops.index_sample(x, index)
helper = LayerHelper("index_sample", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
'paddle.tensor.search.index_sample')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册