Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
f2a56c6a
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f2a56c6a
编写于
11月 15, 2021
作者:
Z
zyfncg
提交者:
GitHub
11月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bug of indexing with ellipsis (#37182)
上级
10cc040d
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
19 addition
and
3 deletion
+19
-3
paddle/fluid/pybind/imperative.cc
paddle/fluid/pybind/imperative.cc
+8
-1
python/paddle/fluid/tests/unittests/test_var_base.py
python/paddle/fluid/tests/unittests/test_var_base.py
+5
-0
python/paddle/fluid/tests/unittests/test_variable.py
python/paddle/fluid/tests/unittests/test_variable.py
+6
-2
未找到文件。
paddle/fluid/pybind/imperative.cc
浏览文件 @
f2a56c6a
...
@@ -549,13 +549,20 @@ static void ParseIndexingSlice(
...
@@ -549,13 +549,20 @@ static void ParseIndexingSlice(
// specified_dims is the number of dimensions which indexed by Interger,
// specified_dims is the number of dimensions which indexed by Interger,
// Slices.
// Slices.
int
specified_dims
=
0
;
int
specified_dims
=
0
;
int
ell_count
=
0
;
for
(
int
dim
=
0
;
dim
<
size
;
++
dim
)
{
for
(
int
dim
=
0
;
dim
<
size
;
++
dim
)
{
PyObject
*
slice_item
=
PyTuple_GetItem
(
index
,
dim
);
PyObject
*
slice_item
=
PyTuple_GetItem
(
index
,
dim
);
if
(
PyCheckInteger
(
slice_item
)
||
PySlice_Check
(
slice_item
))
{
if
(
PyCheckInteger
(
slice_item
)
||
PySlice_Check
(
slice_item
))
{
specified_dims
++
;
specified_dims
++
;
}
else
if
(
slice_item
==
Py_Ellipsis
)
{
ell_count
++
;
}
}
}
}
PADDLE_ENFORCE_LE
(
ell_count
,
1
,
platform
::
errors
::
InvalidArgument
(
"An index can only have a single ellipsis ('...')"
));
for
(
int
i
=
0
,
dim
=
0
;
i
<
size
;
++
i
)
{
for
(
int
i
=
0
,
dim
=
0
;
i
<
size
;
++
i
)
{
PyObject
*
slice_item
=
PyTuple_GetItem
(
index
,
i
);
PyObject
*
slice_item
=
PyTuple_GetItem
(
index
,
i
);
...
@@ -660,7 +667,7 @@ static void ParseIndexingSlice(
...
@@ -660,7 +667,7 @@ static void ParseIndexingSlice(
}
}
// valid_index is the number of dimensions exclude None index
// valid_index is the number of dimensions exclude None index
const
int
valid_indexs
=
size
-
none_axes
->
size
();
const
int
valid_indexs
=
size
-
none_axes
->
size
()
-
ell_count
;
PADDLE_ENFORCE_EQ
(
valid_indexs
<=
rank
,
true
,
PADDLE_ENFORCE_EQ
(
valid_indexs
<=
rank
,
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Too many indices (%d) for tensor of dimension %d."
,
"Too many indices (%d) for tensor of dimension %d."
,
...
...
python/paddle/fluid/tests/unittests/test_var_base.py
浏览文件 @
f2a56c6a
...
@@ -702,6 +702,11 @@ class TestVarBase(unittest.TestCase):
...
@@ -702,6 +702,11 @@ class TestVarBase(unittest.TestCase):
assert_getitem_ellipsis_index
(
var_fp32
,
np_fp32_value
)
assert_getitem_ellipsis_index
(
var_fp32
,
np_fp32_value
)
assert_getitem_ellipsis_index
(
var_int
,
np_int_value
)
assert_getitem_ellipsis_index
(
var_int
,
np_int_value
)
# test 1 dim tensor
var_one_dim
=
paddle
.
to_tensor
([
1
,
2
,
3
,
4
])
self
.
assertTrue
(
np
.
array_equal
(
var_one_dim
[...,
0
].
numpy
(),
np
.
array
([
1
])))
def
_test_none_index
(
self
):
def
_test_none_index
(
self
):
shape
=
(
8
,
64
,
5
,
256
)
shape
=
(
8
,
64
,
5
,
256
)
np_value
=
np
.
random
.
random
(
shape
).
astype
(
'float32'
)
np_value
=
np
.
random
.
random
(
shape
).
astype
(
'float32'
)
...
...
python/paddle/fluid/tests/unittests/test_variable.py
浏览文件 @
f2a56c6a
...
@@ -226,19 +226,22 @@ class TestVariable(unittest.TestCase):
...
@@ -226,19 +226,22 @@ class TestVariable(unittest.TestCase):
prog
=
paddle
.
static
.
Program
()
prog
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
prog
):
with
paddle
.
static
.
program_guard
(
prog
):
x
=
paddle
.
assign
(
data
)
x
=
paddle
.
assign
(
data
)
y
=
paddle
.
assign
([
1
,
2
,
3
,
4
])
out1
=
x
[
0
:,
...,
1
:]
out1
=
x
[
0
:,
...,
1
:]
out2
=
x
[
0
:,
...]
out2
=
x
[
0
:,
...]
out3
=
x
[...,
1
:]
out3
=
x
[...,
1
:]
out4
=
x
[...]
out4
=
x
[...]
out5
=
x
[[
1
,
0
],
[
0
,
0
]]
out5
=
x
[[
1
,
0
],
[
0
,
0
]]
out6
=
x
[([
1
,
0
],
[
0
,
0
])]
out6
=
x
[([
1
,
0
],
[
0
,
0
])]
out7
=
y
[...,
0
]
exe
=
paddle
.
static
.
Executor
(
place
)
exe
=
paddle
.
static
.
Executor
(
place
)
result
=
exe
.
run
(
prog
,
fetch_list
=
[
out1
,
out2
,
out3
,
out4
,
out5
,
out6
])
result
=
exe
.
run
(
prog
,
fetch_list
=
[
out1
,
out2
,
out3
,
out4
,
out5
,
out6
,
out7
])
expected
=
[
expected
=
[
data
[
0
:,
...,
1
:],
data
[
0
:,
...],
data
[...,
1
:],
data
[...],
data
[
0
:,
...,
1
:],
data
[
0
:,
...],
data
[...,
1
:],
data
[...],
data
[[
1
,
0
],
[
0
,
0
]],
data
[([
1
,
0
],
[
0
,
0
])]
data
[[
1
,
0
],
[
0
,
0
]],
data
[([
1
,
0
],
[
0
,
0
])]
,
np
.
array
([
1
])
]
]
self
.
assertTrue
((
result
[
0
]
==
expected
[
0
]).
all
())
self
.
assertTrue
((
result
[
0
]
==
expected
[
0
]).
all
())
...
@@ -247,6 +250,7 @@ class TestVariable(unittest.TestCase):
...
@@ -247,6 +250,7 @@ class TestVariable(unittest.TestCase):
self
.
assertTrue
((
result
[
3
]
==
expected
[
3
]).
all
())
self
.
assertTrue
((
result
[
3
]
==
expected
[
3
]).
all
())
self
.
assertTrue
((
result
[
4
]
==
expected
[
4
]).
all
())
self
.
assertTrue
((
result
[
4
]
==
expected
[
4
]).
all
())
self
.
assertTrue
((
result
[
5
]
==
expected
[
5
]).
all
())
self
.
assertTrue
((
result
[
5
]
==
expected
[
5
]).
all
())
self
.
assertTrue
((
result
[
6
]
==
expected
[
6
]).
all
())
with
self
.
assertRaises
(
IndexError
):
with
self
.
assertRaises
(
IndexError
):
res
=
x
[[
1.2
,
0
]]
res
=
x
[[
1.2
,
0
]]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录