Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
aceec7fb
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
aceec7fb
编写于
4月 25, 2021
作者:
L
liym27
提交者:
GitHub
4月 25, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[slice] Support index is Tensor for slice in dynamic mode (#32435)
上级
25e723e7
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
94 addition
and
3 deletion
+94
-3
paddle/fluid/pybind/imperative.cc
paddle/fluid/pybind/imperative.cc
+1
-1
python/paddle/fluid/dygraph/varbase_patch_methods.py
python/paddle/fluid/dygraph/varbase_patch_methods.py
+28
-2
python/paddle/fluid/tests/unittests/test_var_base.py
python/paddle/fluid/tests/unittests/test_var_base.py
+65
-0
未找到文件。
paddle/fluid/pybind/imperative.cc
浏览文件 @
aceec7fb
...
...
@@ -746,7 +746,7 @@ void BindImperative(py::module *m_ptr) {
// inplace operator for the VarBase self.
self
->
BumpInplaceVersion
();
})
.
def
(
"_
_getitem__
"
,
.
def
(
"_
getitem_index_not_tensor
"
,
[](
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
,
py
::
handle
_index
)
{
std
::
vector
<
int
>
slice_axes
,
slice_starts
,
slice_ends
,
slice_strides
,
decrease_axis
,
infer_flags
;
...
...
python/paddle/fluid/dygraph/varbase_patch_methods.py
浏览文件 @
aceec7fb
...
...
@@ -21,7 +21,7 @@ import paddle
from
..
import
framework
from
..
import
core
from
..
import
unique_name
from
..framework
import
Variable
,
Parameter
,
ParamBase
from
..framework
import
Variable
,
Parameter
,
ParamBase
,
_getitem_impl_
from
.base
import
switch_to_static_graph
from
.math_op_patch
import
monkey_patch_math_varbase
from
.parallel
import
scale_loss
...
...
@@ -437,6 +437,31 @@ def monkey_patch_varbase():
def
__array__
(
self
,
dtype
=
None
):
return
self
.
numpy
().
astype
(
dtype
)
def
__getitem__
(
self
,
item
):
def
contain_tensor
(
item
):
if
not
isinstance
(
item
,
tuple
):
item
=
[
item
]
for
slice_item
in
item
:
if
isinstance
(
slice_item
,
slice
):
if
isinstance
(
slice_item
.
start
,
Variable
)
\
or
isinstance
(
slice_item
.
stop
,
Variable
)
\
or
isinstance
(
slice_item
.
step
,
Variable
):
return
True
else
:
if
isinstance
(
slice_item
,
Variable
):
return
True
return
False
if
contain_tensor
(
item
):
# 1. Call _getitem_impl_ when item contains tensor.
# Why not call a c++ function ? Because item can't be parsed when it contains tensor.
return
_getitem_impl_
(
self
,
item
)
else
:
# 2. Call c++ func getitem_index_not_tensor to speedup.
return
self
.
_getitem_index_not_tensor
(
item
)
for
method_name
,
method
in
(
(
"__bool__"
,
__bool__
),
(
"__nonzero__"
,
__nonzero__
),
(
"_to_static_var"
,
_to_static_var
),
(
"set_value"
,
set_value
),
...
...
@@ -445,7 +470,8 @@ def monkey_patch_varbase():
(
"gradient"
,
gradient
),
(
"register_hook"
,
register_hook
),
(
"__str__"
,
__str__
),
(
"__repr__"
,
__str__
),
(
"__deepcopy__"
,
__deepcopy__
),
(
"__module__"
,
"paddle"
),
(
"__name__"
,
"Tensor"
),
(
"__array__"
,
__array__
)):
(
"__name__"
,
"Tensor"
),
(
"__array__"
,
__array__
),
(
"__getitem__"
,
__getitem__
)):
setattr
(
core
.
VarBase
,
method_name
,
method
)
# NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class.
...
...
python/paddle/fluid/tests/unittests/test_var_base.py
浏览文件 @
aceec7fb
...
...
@@ -473,6 +473,70 @@ class TestVarBase(unittest.TestCase):
np
.
array_equal
(
local_out
[
15
],
tensor_array
[::
-
1
,
::
-
1
,
::
-
1
]))
self
.
assertTrue
(
np
.
array_equal
(
local_out
[
16
],
tensor_array
[
-
4
:
4
]))
def
_test_slice_for_tensor_attr
(
self
):
tensor_array
=
np
.
array
(
[[[
1
,
2
,
3
],
[
4
,
5
,
6
],
[
7
,
8
,
9
]],
[[
10
,
11
,
12
],
[
13
,
14
,
15
],
[
16
,
17
,
18
]],
[[
19
,
20
,
21
],
[
22
,
23
,
24
],
[
25
,
26
,
27
]]]).
astype
(
'float32'
)
var
=
paddle
.
to_tensor
(
tensor_array
)
one
=
paddle
.
ones
(
shape
=
[
1
],
dtype
=
"int32"
)
two
=
paddle
.
full
(
shape
=
[
1
],
fill_value
=
2
,
dtype
=
"int32"
)
negative_one
=
paddle
.
full
(
shape
=
[
1
],
fill_value
=-
1
,
dtype
=
"int32"
)
four
=
paddle
.
full
(
shape
=
[
1
],
fill_value
=
4
,
dtype
=
"int32"
)
var
=
fluid
.
dygraph
.
to_variable
(
tensor_array
)
var1
=
var
[
0
,
one
,
one
]
var2
=
var
[
one
:]
var3
=
var
[
0
:
one
]
var4
=
var
[::
negative_one
]
var5
=
var
[
one
,
one
:,
one
:]
var_reshape
=
fluid
.
layers
.
reshape
(
var
,
[
3
,
negative_one
,
3
])
var6
=
var_reshape
[:,
:,
negative_one
]
var7
=
var
[:,
:,
:
negative_one
]
var8
=
var
[:
one
,
:
one
,
:
1
]
var9
=
var
[:
-
1
,
:
negative_one
,
:
negative_one
]
var10
=
var
[::
negative_one
,
:
one
,
:
negative_one
]
var11
=
var
[:
negative_one
,
::
-
1
,
negative_one
:]
var12
=
var
[
one
:
2
,
2
:,
::
negative_one
]
var13
=
var
[
two
:
10
,
2
:,
-
2
:
negative_one
]
var14
=
var
[
1
:
negative_one
,
0
:
2
,
::
negative_one
]
var15
=
var
[::
negative_one
,
::
-
1
,
::
negative_one
]
var16
=
var
[
-
4
:
4
]
vars
=
[
var
,
var1
,
var2
,
var3
,
var4
,
var5
,
var6
,
var7
,
var8
,
var9
,
var10
,
var11
,
var12
,
var13
,
var14
,
var15
,
var16
]
local_out
=
[
var
.
numpy
()
for
var
in
vars
]
self
.
assertTrue
(
np
.
array_equal
(
local_out
[
1
],
tensor_array
[
0
,
1
,
1
:
2
]))
self
.
assertTrue
(
np
.
array_equal
(
local_out
[
2
],
tensor_array
[
1
:]))
self
.
assertTrue
(
np
.
array_equal
(
local_out
[
3
],
tensor_array
[
0
:
1
]))
self
.
assertTrue
(
np
.
array_equal
(
local_out
[
4
],
tensor_array
[::
-
1
]))
self
.
assertTrue
(
np
.
array_equal
(
local_out
[
5
],
tensor_array
[
1
,
1
:,
1
:]))
self
.
assertTrue
(
np
.
array_equal
(
local_out
[
6
],
tensor_array
.
reshape
((
3
,
-
1
,
3
))[:,
:,
-
1
]))
self
.
assertTrue
(
np
.
array_equal
(
local_out
[
7
],
tensor_array
[:,
:,
:
-
1
]))
self
.
assertTrue
(
np
.
array_equal
(
local_out
[
8
],
tensor_array
[:
1
,
:
1
,
:
1
]))
self
.
assertTrue
(
np
.
array_equal
(
local_out
[
9
],
tensor_array
[:
-
1
,
:
-
1
,
:
-
1
]))
self
.
assertTrue
(
np
.
array_equal
(
local_out
[
10
],
tensor_array
[::
-
1
,
:
1
,
:
-
1
]))
self
.
assertTrue
(
np
.
array_equal
(
local_out
[
11
],
tensor_array
[:
-
1
,
::
-
1
,
-
1
:]))
self
.
assertTrue
(
np
.
array_equal
(
local_out
[
12
],
tensor_array
[
1
:
2
,
2
:,
::
-
1
]))
self
.
assertTrue
(
np
.
array_equal
(
local_out
[
13
],
tensor_array
[
2
:
10
,
2
:,
-
2
:
-
1
]))
self
.
assertTrue
(
np
.
array_equal
(
local_out
[
14
],
tensor_array
[
1
:
-
1
,
0
:
2
,
::
-
1
]))
self
.
assertTrue
(
np
.
array_equal
(
local_out
[
15
],
tensor_array
[::
-
1
,
::
-
1
,
::
-
1
]))
self
.
assertTrue
(
np
.
array_equal
(
local_out
[
16
],
tensor_array
[
-
4
:
4
]))
def
_test_for_var
(
self
):
np_value
=
np
.
random
.
random
((
30
,
100
,
100
)).
astype
(
'float32'
)
w
=
fluid
.
dygraph
.
to_variable
(
np_value
)
...
...
@@ -483,6 +547,7 @@ class TestVarBase(unittest.TestCase):
def
test_slice
(
self
):
with
fluid
.
dygraph
.
guard
():
self
.
_test_slice
()
self
.
_test_slice_for_tensor_attr
()
self
.
_test_for_var
()
var
=
fluid
.
dygraph
.
to_variable
(
self
.
array
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录