Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9637d963
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9637d963
编写于
10月 13, 2020
作者:
C
Chengmo
提交者:
GitHub
10月 13, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update index sample (#27839)
* update index sample
上级
6d63cd2b
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
26 addition
and
11 deletion
+26
-11
python/paddle/fluid/tests/unittests/test_index_sample_op.py
python/paddle/fluid/tests/unittests/test_index_sample_op.py
+20
-3
python/paddle/tensor/search.py
python/paddle/tensor/search.py
+6
-8
未找到文件。
python/paddle/fluid/tests/unittests/test_index_sample_op.py
浏览文件 @
9637d963
...
...
@@ -15,6 +15,8 @@
from
__future__
import
print_function
import
unittest
import
paddle
import
paddle.fluid
as
fluid
import
numpy
as
np
from
op_test
import
OpTest
...
...
@@ -98,9 +100,7 @@ class TestCase4(TestIndexSampleOp):
class
TestIndexSampleShape
(
unittest
.
TestCase
):
def
test_shape
(
self
):
import
paddle.fluid
as
fluid
import
paddle
paddle
.
enable_static
()
# create x value
x_shape
=
(
2
,
5
)
x_type
=
"float64"
...
...
@@ -124,5 +124,22 @@ class TestIndexSampleShape(unittest.TestCase):
res
=
exe
.
run
(
feed
=
feed
,
fetch_list
=
[
output
])
class
TestIndexSampleDynamic
(
unittest
.
TestCase
):
def
test_result
(
self
):
with
fluid
.
dygraph
.
guard
():
x
=
paddle
.
to_tensor
(
[[
1.0
,
2.0
,
3.0
,
4.0
],
[
5.0
,
6.0
,
7.0
,
8.0
],
[
9.0
,
10.0
,
11.0
,
12.0
]],
dtype
=
'float32'
)
index
=
paddle
.
to_tensor
(
[[
0
,
1
,
2
],
[
1
,
2
,
3
],
[
0
,
0
,
0
]],
dtype
=
'int32'
)
out_z1
=
paddle
.
index_sample
(
x
,
index
)
except_output
=
np
.
array
(
[[
1.0
,
2.0
,
3.0
],
[
6.0
,
7.0
,
8.0
],
[
9.0
,
9.0
,
9.0
]])
assert
out_z1
.
numpy
().
all
()
==
except_output
.
all
()
if
__name__
==
"__main__"
:
paddle
.
enable_static
()
unittest
.
main
()
python/paddle/tensor/search.py
浏览文件 @
9637d963
...
...
@@ -570,9 +570,6 @@ def where(condition, x, y, name=None):
def
index_sample
(
x
,
index
):
"""
:alias_main: paddle.index_sample
:alias: paddle.index_sample,paddle.tensor.index_sample,paddle.tensor.search.index_sample
**IndexSample Layer**
IndexSample OP returns the element of the specified location of X,
...
...
@@ -595,13 +592,13 @@ def index_sample(x, index):
[6, 8, 10]]
Args:
x (
Variable
): The source input tensor with 2-D shape. Supported data type is
x (
Tensor
): The source input tensor with 2-D shape. Supported data type is
int32, int64, float32, float64.
index (
Variable
): The index input tensor with 2-D shape, first dimension should be same with X.
index (
Tensor
): The index input tensor with 2-D shape, first dimension should be same with X.
Data type is int32 or int64.
Returns:
output (
Variable
): The output is a tensor with the same shape as index.
output (
Tensor
): The output is a tensor with the same shape as index.
Examples:
...
...
@@ -609,7 +606,6 @@ def index_sample(x, index):
import paddle
paddle.disable_static()
x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0]], dtype='float32')
...
...
@@ -644,8 +640,10 @@ def index_sample(x, index):
# [ 800 700]
# [1200 1100]]
"""
if
in_dygraph_mode
():
return
core
.
ops
.
index_sample
(
x
,
index
)
helper
=
LayerHelper
(
"index_sample"
,
**
locals
())
check_variable_and_dtype
(
x
,
'x'
,
[
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'paddle.tensor.search.index_sample'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录