Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
60c5adaa
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
60c5adaa
编写于
9月 06, 2021
作者:
W
WeiXin
提交者:
GitHub
9月 06, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support numpy dtype and polish code of list index. (#35404)
* support numpy dtype and polish code of list index. * polish code.
上级
5675042d
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
76 addition
and
11 deletion
+76
-11
paddle/fluid/pybind/imperative.cc
paddle/fluid/pybind/imperative.cc
+12
-5
python/paddle/fluid/dygraph/varbase_patch_methods.py
python/paddle/fluid/dygraph/varbase_patch_methods.py
+28
-6
python/paddle/fluid/tests/unittests/test_var_base.py
python/paddle/fluid/tests/unittests/test_var_base.py
+36
-0
未找到文件。
paddle/fluid/pybind/imperative.cc
浏览文件 @
60c5adaa
...
...
@@ -331,7 +331,14 @@ GetVarBaseListFromPyHandle(const py::handle &handle) {
return
result
;
}
static
bool
IsNumpyType
(
PyObject
*
obj
)
{
// It is not a good way to judge the type of obj by its type'name. Maybe using
// `PyArray_IsScalar` will be better. However, this interface cannot be used
// by including pybind11, and it needs to compile with numpy.
auto
type_name
=
std
::
string
(
Py_TYPE
(
obj
)
->
tp_name
);
return
type_name
==
"numpy.int64"
||
type_name
==
"numpy.longlong"
||
type_name
==
"numpy.int32"
||
type_name
==
"numpy.int16"
;
}
static
imperative
::
NameVarBaseMap
ConvertToNameVarBaseMap
(
const
PyNameVarBaseMap
&
map
)
{
imperative
::
NameVarBaseMap
result
;
...
...
@@ -372,7 +379,7 @@ static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length,
if
(
r
->
step
==
Py_None
)
{
*
step
=
1
;
}
else
{
if
(
PyCheckInteger
(
r
->
step
))
{
if
(
PyCheckInteger
(
r
->
step
)
||
IsNumpyType
(
r
->
step
)
)
{
*
step
=
PyLong_AsLong
(
r
->
step
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
...
...
@@ -384,7 +391,7 @@ static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length,
if
(
r
->
start
==
Py_None
)
{
*
start
=
*
step
<
0
?
length
-
1
:
0
;
}
else
{
if
(
PyCheckInteger
(
r
->
start
))
{
if
(
PyCheckInteger
(
r
->
start
)
||
IsNumpyType
(
r
->
start
)
)
{
*
start
=
PyLong_AsLong
(
r
->
start
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
...
...
@@ -398,7 +405,7 @@ static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length,
if
(
r
->
stop
==
Py_None
)
{
*
stop
=
*
step
<
0
?
-
1
:
length
;
}
else
{
if
(
PyCheckInteger
(
r
->
stop
))
{
if
(
PyCheckInteger
(
r
->
stop
)
||
IsNumpyType
(
r
->
stop
)
)
{
*
stop
=
PyLong_AsLong
(
r
->
stop
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
...
...
@@ -456,7 +463,7 @@ static void ParseIndexingSlice(
infer_flags
->
push_back
(
1
);
int
dim_len
=
shape
[
dim
];
if
(
PyCheckInteger
(
slice_item
))
{
if
(
PyCheckInteger
(
slice_item
)
||
IsNumpyType
(
slice_item
)
)
{
// integer, PyLong_AsLong supports both int and long
int
start
=
static_cast
<
int
>
(
PyLong_AsLong
(
slice_item
));
auto
s_t
=
start
;
...
...
python/paddle/fluid/dygraph/varbase_patch_methods.py
浏览文件 @
60c5adaa
...
...
@@ -544,7 +544,7 @@ def monkey_patch_varbase():
return
array
def
contain_tensor
(
item
):
if
not
isinstance
(
item
,
tuple
):
if
not
isinstance
(
item
,
(
tuple
,
list
)
):
item
=
[
item
]
for
slice_item
in
item
:
...
...
@@ -554,20 +554,21 @@ def monkey_patch_varbase():
or
isinstance
(
slice_item
.
step
,
Variable
):
return
True
else
:
if
isinstance
(
slice_item
,
Variable
):
if
isinstance
(
slice_item
,
Variable
)
and
Variable
.
dtype
!=
paddle
.
bool
:
return
True
return
False
def
__getitem__
(
self
,
item
):
def
is_list_tuple
(
index
,
contain_type
):
def
_is_list_tuple
(
item
):
if
not
(
isinstance
(
item
,
(
list
,
tuple
))
or
type
(
item
)
==
contain_type
):
return
False
if
isinstance
(
item
,
(
tuple
,
list
)):
for
s
in
item
:
if
not
_is_list_tuple
(
s
):
return
False
else
:
if
type
(
item
)
!=
contain_type
:
return
False
return
True
if
not
isinstance
(
index
,
(
tuple
,
list
)):
...
...
@@ -599,7 +600,28 @@ def monkey_patch_varbase():
return
False
if
contain_tensor_or_list
(
item
):
def
is_combine_index
(
item
):
var_type
=
None
item_type
=
None
if
isinstance
(
item
,
(
tuple
,
list
)):
for
slice_item
in
item
:
if
item_type
is
None
:
item_type
=
type
(
slice_item
)
else
:
if
type
(
slice_item
)
!=
item_type
:
return
True
if
isinstance
(
slice_item
,
Variable
):
if
var_type
is
None
:
var_type
=
slice_item
.
dtype
else
:
if
var_type
!=
slice_item
.
dtype
:
return
True
return
False
return
False
if
contain_tensor_or_list
(
item
)
and
not
is_combine_index
(
item
):
# To reuse code with static graph,
# Call _setitem_impl_ when item contains tensor or list.
return
_setitem_impl_
(
self
,
item
,
value
)
...
...
python/paddle/fluid/tests/unittests/test_var_base.py
浏览文件 @
60c5adaa
...
...
@@ -779,6 +779,40 @@ class TestVarBase(unittest.TestCase):
for
i
,
e
in
enumerate
(
w
):
self
.
assertTrue
(
np
.
array_equal
(
e
.
numpy
(),
np_value
[
i
]))
def
_test_numpy_index
(
self
):
array
=
np
.
arange
(
120
).
reshape
([
4
,
5
,
6
])
t
=
paddle
.
to_tensor
(
array
)
self
.
assertTrue
(
np
.
array_equal
(
t
[
np
.
longlong
(
0
)].
numpy
(),
array
[
0
]))
self
.
assertTrue
(
np
.
array_equal
(
t
[
np
.
longlong
(
0
):
np
.
longlong
(
4
):
np
.
longlong
(
2
)]
.
numpy
(),
array
[
0
:
4
:
2
]))
self
.
assertTrue
(
np
.
array_equal
(
t
[
np
.
int64
(
0
)].
numpy
(),
array
[
0
]))
self
.
assertTrue
(
np
.
array_equal
(
t
[
np
.
int32
(
1
):
np
.
int32
(
4
):
np
.
int32
(
2
)].
numpy
(),
array
[
1
:
4
:
2
]))
self
.
assertTrue
(
np
.
array_equal
(
t
[
np
.
int16
(
0
):
np
.
int16
(
4
):
np
.
int16
(
2
)].
numpy
(),
array
[
0
:
4
:
2
]))
def
_test_list_index
(
self
):
# case1:
array
=
np
.
arange
(
120
).
reshape
([
6
,
5
,
4
])
x
=
paddle
.
to_tensor
(
array
)
py_idx
=
[[
0
,
2
,
0
,
1
,
3
],
[
0
,
0
,
1
,
2
,
0
]]
idx
=
[
paddle
.
to_tensor
(
py_idx
[
0
]),
paddle
.
to_tensor
(
py_idx
[
1
])]
self
.
assertTrue
(
np
.
array_equal
(
x
[
idx
].
numpy
(),
array
[
py_idx
]))
self
.
assertTrue
(
np
.
array_equal
(
x
[
py_idx
].
numpy
(),
array
[
py_idx
]))
# case2:
tensor_x
=
paddle
.
to_tensor
(
np
.
zeros
(
12
).
reshape
(
2
,
6
).
astype
(
np
.
float32
))
tensor_y1
=
paddle
.
zeros
([
1
])
+
2
tensor_y2
=
paddle
.
zeros
([
1
])
+
5
tensor_x
[:,
tensor_y1
:
tensor_y2
]
=
42
res
=
tensor_x
.
numpy
()
exp
=
np
.
array
([[
0.
,
0.
,
42.
,
42.
,
42.
,
0.
],
[
0.
,
0.
,
42.
,
42.
,
42.
,
0.
]])
self
.
assertTrue
(
np
.
array_equal
(
res
,
exp
))
def
test_slice
(
self
):
with
fluid
.
dygraph
.
guard
():
self
.
_test_slice
()
...
...
@@ -787,6 +821,8 @@ class TestVarBase(unittest.TestCase):
self
.
_test_for_getitem_ellipsis_index
()
self
.
_test_none_index
()
self
.
_test_bool_index
()
self
.
_test_numpy_index
()
self
.
_test_list_index
()
var
=
fluid
.
dygraph
.
to_variable
(
self
.
array
)
self
.
assertTrue
(
np
.
array_equal
(
var
[
1
,
:].
numpy
(),
self
.
array
[
1
,
:]))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录