Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
de0cb386
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
de0cb386
编写于
11月 22, 2021
作者:
Z
zyfncg
提交者:
GitHub
11月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bug of indexing tensor with None (#37400)
上级
31344ab7
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
23 addition
and
32 deletion
+23
-32
paddle/fluid/pybind/imperative.cc
paddle/fluid/pybind/imperative.cc
+3
-25
python/paddle/fluid/tests/unittests/test_set_value_op.py
python/paddle/fluid/tests/unittests/test_set_value_op.py
+8
-0
python/paddle/fluid/tests/unittests/test_var_base.py
python/paddle/fluid/tests/unittests/test_var_base.py
+3
-1
python/paddle/fluid/tests/unittests/test_variable.py
python/paddle/fluid/tests/unittests/test_variable.py
+4
-2
python/paddle/fluid/variable_index.py
python/paddle/fluid/variable_index.py
+5
-4
未找到文件。
paddle/fluid/pybind/imperative.cc
浏览文件 @
de0cb386
...
@@ -562,7 +562,7 @@ static void ParseIndexingSlice(
...
@@ -562,7 +562,7 @@ static void ParseIndexingSlice(
PADDLE_ENFORCE_LE
(
ell_count
,
1
,
PADDLE_ENFORCE_LE
(
ell_count
,
1
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"An index can only have a single ellipsis ('...')"
));
"An index can only have a single ellipsis ('...')"
));
int
none_count
=
0
;
for
(
int
i
=
0
,
dim
=
0
;
i
<
size
;
++
i
)
{
for
(
int
i
=
0
,
dim
=
0
;
i
<
size
;
++
i
)
{
PyObject
*
slice_item
=
PyTuple_GetItem
(
index
,
i
);
PyObject
*
slice_item
=
PyTuple_GetItem
(
index
,
i
);
...
@@ -608,7 +608,8 @@ static void ParseIndexingSlice(
...
@@ -608,7 +608,8 @@ static void ParseIndexingSlice(
}
else
if
(
slice_item
==
Py_Ellipsis
)
{
}
else
if
(
slice_item
==
Py_Ellipsis
)
{
dim
+=
rank
-
specified_dims
;
dim
+=
rank
-
specified_dims
;
}
else
if
(
slice_item
==
Py_None
)
{
}
else
if
(
slice_item
==
Py_None
)
{
none_axes
->
push_back
(
dim
);
none_axes
->
push_back
(
dim
+
none_count
);
none_count
++
;
}
else
if
(
PyList_Check
(
slice_item
))
{
}
else
if
(
PyList_Check
(
slice_item
))
{
*
list_select_flag
=
true
;
*
list_select_flag
=
true
;
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
...
@@ -1214,29 +1215,6 @@ void BindImperative(py::module *m_ptr) {
...
@@ -1214,29 +1215,6 @@ void BindImperative(py::module *m_ptr) {
axis
-=
len
;
axis
-=
len
;
}
}
// Deal with cases that there are more than one
// prefix none index, For example:
// [None, None, :, :, None]
// the none_axes int the return of ParseIndexingSlice is:
// [0, 0, 2 ]
// according to the interface of "unsqueeze2",
// we should convert it to:
// [0, 0, 4 ]
int
prefix_zero_cnt
=
0
;
for
(
const
auto
&
axis
:
none_axes
)
{
if
(
axis
==
0
)
{
prefix_zero_cnt
++
;
}
else
{
break
;
}
}
if
(
prefix_zero_cnt
>
0
)
{
int
none_axes_num
=
static_cast
<
int
>
(
none_axes
.
size
());
for
(
int
i
=
prefix_zero_cnt
;
i
<
none_axes_num
;
++
i
)
{
none_axes
[
i
]
+=
prefix_zero_cnt
;
}
}
imperative
::
NameVarBaseMap
ins
=
{{
"X"
,
{
out
}}};
imperative
::
NameVarBaseMap
ins
=
{{
"X"
,
{
out
}}};
framework
::
AttributeMap
attrs
=
{{
"axes"
,
none_axes
}};
framework
::
AttributeMap
attrs
=
{{
"axes"
,
none_axes
}};
auto
new_out
=
std
::
shared_ptr
<
imperative
::
VarBase
>
(
auto
new_out
=
std
::
shared_ptr
<
imperative
::
VarBase
>
(
...
...
python/paddle/fluid/tests/unittests/test_set_value_op.py
浏览文件 @
de0cb386
...
@@ -408,6 +408,14 @@ class TestSetValueItemNone9(TestSetValueApi):
...
@@ -408,6 +408,14 @@ class TestSetValueItemNone9(TestSetValueApi):
self
.
data
[
None
,
:,
1
,
...,
None
]
=
np
.
zeros
(
self
.
shape
)[
0
,
0
,
:,
None
]
self
.
data
[
None
,
:,
1
,
...,
None
]
=
np
.
zeros
(
self
.
shape
)[
0
,
0
,
:,
None
]
class
TestSetValueItemNone10
(
TestSetValueApi
):
def
_call_setitem
(
self
,
x
):
x
[...,
None
,
:,
None
]
=
np
.
zeros
(
self
.
shape
)[...,
None
,
:,
None
]
def
_get_answer
(
self
):
self
.
data
[...,
None
,
:,
None
]
=
np
.
zeros
(
self
.
shape
)[...,
None
,
:,
None
]
# 1.5 item is list or Tensor of bol
# 1.5 item is list or Tensor of bol
class
TestSetValueItemBool1
(
TestSetValueApi
):
class
TestSetValueItemBool1
(
TestSetValueApi
):
def
_call_setitem
(
self
,
x
):
def
_call_setitem
(
self
,
x
):
...
...
python/paddle/fluid/tests/unittests/test_var_base.py
浏览文件 @
de0cb386
...
@@ -723,6 +723,7 @@ class TestVarBase(unittest.TestCase):
...
@@ -723,6 +723,7 @@ class TestVarBase(unittest.TestCase):
var_tensor
[
None
].
numpy
(),
var_tensor
[
None
].
numpy
(),
var_tensor
[
0
,
0
,
None
,
0
,
0
,
None
].
numpy
(),
var_tensor
[
0
,
0
,
None
,
0
,
0
,
None
].
numpy
(),
var_tensor
[
None
,
None
,
0
,
...,
None
].
numpy
(),
var_tensor
[
None
,
None
,
0
,
...,
None
].
numpy
(),
var_tensor
[...,
None
,
:,
None
].
numpy
(),
var_tensor
[
0
,
1
:
10
:
2
,
None
,
None
,
...].
numpy
(),
var_tensor
[
0
,
1
:
10
:
2
,
None
,
None
,
...].
numpy
(),
]
]
...
@@ -738,11 +739,12 @@ class TestVarBase(unittest.TestCase):
...
@@ -738,11 +739,12 @@ class TestVarBase(unittest.TestCase):
np
.
array_equal
(
var
[
8
],
np_value
[
0
,
0
,
None
,
0
,
0
,
None
]))
np
.
array_equal
(
var
[
8
],
np_value
[
0
,
0
,
None
,
0
,
0
,
None
]))
self
.
assertTrue
(
self
.
assertTrue
(
np
.
array_equal
(
var
[
9
],
np_value
[
None
,
None
,
0
,
...,
None
]))
np
.
array_equal
(
var
[
9
],
np_value
[
None
,
None
,
0
,
...,
None
]))
self
.
assertTrue
(
np
.
array_equal
(
var
[
10
],
np_value
[...,
None
,
:,
None
]))
# TODO(zyfncg) there is a bug of dimensions when slice step > 1 and
# TODO(zyfncg) there is a bug of dimensions when slice step > 1 and
# indexs has int type
# indexs has int type
# self.assertTrue(
# self.assertTrue(
# np.array_equal(var[1
0
], np_value[0, 1:10:2, None, None, ...]))
# np.array_equal(var[1
1
], np_value[0, 1:10:2, None, None, ...]))
def
_test_bool_index
(
self
):
def
_test_bool_index
(
self
):
shape
=
(
4
,
2
,
5
,
64
)
shape
=
(
4
,
2
,
5
,
64
)
...
...
python/paddle/fluid/tests/unittests/test_variable.py
浏览文件 @
de0cb386
...
@@ -436,13 +436,15 @@ class TestVariableSlice(unittest.TestCase):
...
@@ -436,13 +436,15 @@ class TestVariableSlice(unittest.TestCase):
out1
=
x
[
0
:,
None
]
out1
=
x
[
0
:,
None
]
out2
=
x
[
None
,
1
:]
out2
=
x
[
None
,
1
:]
out3
=
x
[
None
]
out3
=
x
[
None
]
out4
=
x
[...,
None
,
:,
None
]
outs
=
[
out0
,
out1
,
out2
,
out3
]
outs
=
[
out0
,
out1
,
out2
,
out3
,
out4
]
exe
=
paddle
.
static
.
Executor
(
place
)
exe
=
paddle
.
static
.
Executor
(
place
)
result
=
exe
.
run
(
prog
,
fetch_list
=
outs
)
result
=
exe
.
run
(
prog
,
fetch_list
=
outs
)
expected
=
[
expected
=
[
data
[
0
:,
None
,
1
:],
data
[
0
:,
None
],
data
[
None
,
1
:],
data
[
None
]
data
[
0
:,
None
,
1
:],
data
[
0
:,
None
],
data
[
None
,
1
:],
data
[
None
],
data
[...,
None
,
:,
None
]
]
]
for
i
in
range
(
len
(
outs
)):
for
i
in
range
(
len
(
outs
)):
self
.
assertEqual
(
outs
[
i
].
shape
,
expected
[
i
].
shape
)
self
.
assertEqual
(
outs
[
i
].
shape
,
expected
[
i
].
shape
)
...
...
python/paddle/fluid/variable_index.py
浏览文件 @
de0cb386
...
@@ -204,7 +204,8 @@ def replace_ellipsis(var, item):
...
@@ -204,7 +204,8 @@ def replace_ellipsis(var, item):
# Remove Variable to skip bug when counting Ellipsis
# Remove Variable to skip bug when counting Ellipsis
item_remove_var
=
[
item_remove_var
=
[
ele
for
ele
in
item
if
not
isinstance
(
ele
,
(
Variable
,
np
.
ndarray
))
ele
for
ele
in
item
if
not
isinstance
(
ele
,
(
Variable
,
np
.
ndarray
))
and
ele
is
not
None
]
]
ell_count
=
item_remove_var
.
count
(
Ellipsis
)
ell_count
=
item_remove_var
.
count
(
Ellipsis
)
if
ell_count
==
0
:
if
ell_count
==
0
:
...
@@ -218,7 +219,7 @@ def replace_ellipsis(var, item):
...
@@ -218,7 +219,7 @@ def replace_ellipsis(var, item):
return
item
[:
-
1
]
return
item
[:
-
1
]
else
:
else
:
item
[
ell_idx
:
ell_idx
+
1
]
=
[
slice
(
None
)]
*
(
item
[
ell_idx
:
ell_idx
+
1
]
=
[
slice
(
None
)]
*
(
len
(
var
.
shape
)
-
len
(
item
)
+
1
)
len
(
var
.
shape
)
-
len
(
item
)
+
item
.
count
(
None
)
+
1
)
return
item
return
item
...
@@ -298,8 +299,8 @@ def _getitem_impl_(var, item):
...
@@ -298,8 +299,8 @@ def _getitem_impl_(var, item):
use_strided_slice
=
False
use_strided_slice
=
False
item
=
replace_ndarray
(
item
)
item
=
replace_ndarray
(
item
)
item
,
none_axes
=
replace_none
(
item
)
item
=
replace_ellipsis
(
var
,
item
)
item
=
replace_ellipsis
(
var
,
item
)
item
,
none_axes
=
replace_none
(
item
)
slice_info
=
SliceInfo
()
slice_info
=
SliceInfo
()
for
dim
,
slice_item
in
enumerate
(
item
):
for
dim
,
slice_item
in
enumerate
(
item
):
...
@@ -517,8 +518,8 @@ def _setitem_impl_(var, item, value):
...
@@ -517,8 +518,8 @@ def _setitem_impl_(var, item, value):
steps
=
[]
steps
=
[]
item
=
replace_ndarray
(
item
)
item
=
replace_ndarray
(
item
)
item
,
none_axes
=
replace_none
(
item
)
item
=
replace_ellipsis
(
var
,
item
)
item
=
replace_ellipsis
(
var
,
item
)
item
,
none_axes
=
replace_none
(
item
)
slice_info
=
SliceInfo
()
slice_info
=
SliceInfo
()
dim
=
0
dim
=
0
for
_
,
slice_item
in
enumerate
(
item
):
for
_
,
slice_item
in
enumerate
(
item
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录