Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e7df47ec
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e7df47ec
编写于
8月 26, 2021
作者:
W
WeiXin
提交者:
GitHub
8月 26, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support tensor index. (#34824)
* polish code * polish code. * polish code. * polish code. * polish code.
上级
678a259a
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
741 addition
and
50 deletion
+741
-50
paddle/fluid/pybind/imperative.cc
paddle/fluid/pybind/imperative.cc
+1
-1
python/paddle/fluid/dygraph/varbase_patch_methods.py
python/paddle/fluid/dygraph/varbase_patch_methods.py
+47
-17
python/paddle/fluid/tests/unittests/test_variable.py
python/paddle/fluid/tests/unittests/test_variable.py
+461
-5
python/paddle/fluid/variable_index.py
python/paddle/fluid/variable_index.py
+232
-27
未找到文件。
paddle/fluid/pybind/imperative.cc
浏览文件 @
e7df47ec
...
@@ -815,7 +815,7 @@ void BindImperative(py::module *m_ptr) {
...
@@ -815,7 +815,7 @@ void BindImperative(py::module *m_ptr) {
.
def
(
"__init__"
,
&
InitVarBaseFromNumpyWithArgDefault
,
py
::
arg
(
"value"
))
.
def
(
"__init__"
,
&
InitVarBaseFromNumpyWithArgDefault
,
py
::
arg
(
"value"
))
.
def
(
"__init__"
,
&
InitVarBaseFromTensorWithArgDefault
,
py
::
arg
(
"tensor"
))
.
def
(
"__init__"
,
&
InitVarBaseFromTensorWithArgDefault
,
py
::
arg
(
"tensor"
))
.
def
(
"__init__"
,
&
InitVarBaseFromNumpyWithKwargs
)
.
def
(
"__init__"
,
&
InitVarBaseFromNumpyWithKwargs
)
.
def
(
"__setitem__"
,
.
def
(
"__setitem_
varbase_
_"
,
[](
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
,
py
::
handle
_index
,
[](
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
,
py
::
handle
_index
,
py
::
object
&
value_obj
)
{
py
::
object
&
value_obj
)
{
VLOG
(
4
)
<<
"Call __setitem__"
;
VLOG
(
4
)
<<
"Call __setitem__"
;
...
...
python/paddle/fluid/dygraph/varbase_patch_methods.py
浏览文件 @
e7df47ec
...
@@ -22,7 +22,7 @@ import paddle
...
@@ -22,7 +22,7 @@ import paddle
from
..
import
framework
from
..
import
framework
from
..
import
core
from
..
import
core
from
..
import
unique_name
from
..
import
unique_name
from
..framework
import
Variable
,
Parameter
,
ParamBase
,
_getitem_impl_
from
..framework
import
Variable
,
Parameter
,
ParamBase
,
_getitem_impl_
,
_setitem_impl_
from
.base
import
switch_to_static_graph
from
.base
import
switch_to_static_graph
from
.math_op_patch
import
monkey_patch_math_varbase
from
.math_op_patch
import
monkey_patch_math_varbase
from
.parallel
import
scale_loss
from
.parallel
import
scale_loss
...
@@ -543,23 +543,41 @@ def monkey_patch_varbase():
...
@@ -543,23 +543,41 @@ def monkey_patch_varbase():
array
=
array
.
astype
(
dtype
)
array
=
array
.
astype
(
dtype
)
return
array
return
array
def
contain_tensor
(
item
):
if
not
isinstance
(
item
,
tuple
):
item
=
[
item
]
for
slice_item
in
item
:
if
isinstance
(
slice_item
,
slice
):
if
isinstance
(
slice_item
.
start
,
Variable
)
\
or
isinstance
(
slice_item
.
stop
,
Variable
)
\
or
isinstance
(
slice_item
.
step
,
Variable
):
return
True
else
:
if
isinstance
(
slice_item
,
Variable
):
return
True
return
False
def
__getitem__
(
self
,
item
):
def
__getitem__
(
self
,
item
):
def
contain_tensor
(
item
):
def
is_list_tuple
(
index
,
contain_type
):
if
not
isinstance
(
item
,
tuple
):
def
_is_list_tuple
(
item
):
item
=
[
item
]
if
not
(
isinstance
(
item
,
(
list
,
tuple
))
or
type
(
item
)
==
contain_type
):
for
slice_item
in
item
:
return
False
if
isinstance
(
slice_item
,
slice
):
if
isinstance
(
item
,
(
tuple
,
list
)):
if
isinstance
(
slice_item
.
start
,
Variable
)
\
for
s
in
item
:
or
isinstance
(
slice_item
.
stop
,
Variable
)
\
if
not
_is_list_tuple
(
s
):
or
isinstance
(
slice_item
.
step
,
Variable
):
return
False
return
True
return
True
else
:
if
isinstance
(
slice_item
,
Variable
):
return
True
return
False
if
contain_tensor
(
item
):
if
not
isinstance
(
index
,
(
tuple
,
list
)):
return
False
for
s
in
index
:
if
not
_is_list_tuple
(
s
):
return
False
return
True
if
contain_tensor
(
item
)
or
is_list_tuple
(
item
,
int
):
# 1. Call _getitem_impl_ when item contains tensor.
# 1. Call _getitem_impl_ when item contains tensor.
# Why not call a c++ function ? Because item can't be parsed when it contains tensor.
# Why not call a c++ function ? Because item can't be parsed when it contains tensor.
return
_getitem_impl_
(
self
,
item
)
return
_getitem_impl_
(
self
,
item
)
...
@@ -568,6 +586,17 @@ def monkey_patch_varbase():
...
@@ -568,6 +586,17 @@ def monkey_patch_varbase():
# 2. Call c++ func getitem_index_not_tensor to speedup.
# 2. Call c++ func getitem_index_not_tensor to speedup.
return
self
.
_getitem_index_not_tensor
(
item
)
return
self
.
_getitem_index_not_tensor
(
item
)
def
__setitem__
(
self
,
item
,
value
):
if
contain_tensor
(
item
):
# 1. Call _setitem_impl_ when item contains tensor.
# Why not call a c++ function ? Because item can't be parsed when it contains tensor.
return
_setitem_impl_
(
self
,
item
,
value
)
else
:
# 2. Call c++ func __setitem_varbase__ to speedup.
return
self
.
__setitem_varbase__
(
item
,
value
)
for
method_name
,
method
in
(
for
method_name
,
method
in
(
(
"__bool__"
,
__bool__
),
(
"__nonzero__"
,
__nonzero__
),
(
"__bool__"
,
__bool__
),
(
"__nonzero__"
,
__nonzero__
),
(
"_to_static_var"
,
_to_static_var
),
(
"set_value"
,
set_value
),
(
"_to_static_var"
,
_to_static_var
),
(
"set_value"
,
set_value
),
...
@@ -577,7 +606,8 @@ def monkey_patch_varbase():
...
@@ -577,7 +606,8 @@ def monkey_patch_varbase():
(
"__str__"
,
__str__
),
(
"__repr__"
,
__str__
),
(
"__str__"
,
__str__
),
(
"__repr__"
,
__str__
),
(
"__deepcopy__"
,
__deepcopy__
),
(
"__module__"
,
"paddle"
),
(
"__deepcopy__"
,
__deepcopy__
),
(
"__module__"
,
"paddle"
),
(
"__name__"
,
"Tensor"
),
(
"__array__"
,
__array__
),
(
"__name__"
,
"Tensor"
),
(
"__array__"
,
__array__
),
(
"__getitem__"
,
__getitem__
),
(
"item"
,
item
)):
(
"__getitem__"
,
__getitem__
),
(
"item"
,
item
),
(
"__setitem__"
,
__setitem__
)):
setattr
(
core
.
VarBase
,
method_name
,
method
)
setattr
(
core
.
VarBase
,
method_name
,
method
)
# NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class.
# NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class.
...
...
python/paddle/fluid/tests/unittests/test_variable.py
浏览文件 @
e7df47ec
...
@@ -15,6 +15,8 @@
...
@@ -15,6 +15,8 @@
from
__future__
import
print_function
from
__future__
import
print_function
import
unittest
import
unittest
from
functools
import
reduce
import
paddle
import
paddle
from
paddle.fluid.framework
import
default_main_program
,
Program
,
convert_np_dtype_to_dtype_
,
in_dygraph_mode
from
paddle.fluid.framework
import
default_main_program
,
Program
,
convert_np_dtype_to_dtype_
,
in_dygraph_mode
import
paddle
import
paddle
...
@@ -228,21 +230,25 @@ class TestVariable(unittest.TestCase):
...
@@ -228,21 +230,25 @@ class TestVariable(unittest.TestCase):
out2
=
x
[
0
:,
...]
out2
=
x
[
0
:,
...]
out3
=
x
[...,
1
:]
out3
=
x
[...,
1
:]
out4
=
x
[...]
out4
=
x
[...]
out5
=
x
[[
1
,
0
],
[
0
,
0
]]
out6
=
x
[([
1
,
0
],
[
0
,
0
])]
exe
=
paddle
.
static
.
Executor
(
place
)
exe
=
paddle
.
static
.
Executor
(
place
)
result
=
exe
.
run
(
prog
,
fetch_list
=
[
out1
,
out2
,
out3
,
out4
])
result
=
exe
.
run
(
prog
,
fetch_list
=
[
out1
,
out2
,
out3
,
out4
,
out5
,
out6
])
expected
=
[
data
[
0
:,
...,
1
:],
data
[
0
:,
...],
data
[...,
1
:],
data
[...]]
expected
=
[
data
[
0
:,
...,
1
:],
data
[
0
:,
...],
data
[...,
1
:],
data
[...],
data
[[
1
,
0
],
[
0
,
0
]],
data
[([
1
,
0
],
[
0
,
0
])]
]
self
.
assertTrue
((
result
[
0
]
==
expected
[
0
]).
all
())
self
.
assertTrue
((
result
[
0
]
==
expected
[
0
]).
all
())
self
.
assertTrue
((
result
[
1
]
==
expected
[
1
]).
all
())
self
.
assertTrue
((
result
[
1
]
==
expected
[
1
]).
all
())
self
.
assertTrue
((
result
[
2
]
==
expected
[
2
]).
all
())
self
.
assertTrue
((
result
[
2
]
==
expected
[
2
]).
all
())
self
.
assertTrue
((
result
[
3
]
==
expected
[
3
]).
all
())
self
.
assertTrue
((
result
[
3
]
==
expected
[
3
]).
all
())
self
.
assertTrue
((
result
[
4
]
==
expected
[
4
]).
all
())
self
.
assertTrue
((
result
[
5
]
==
expected
[
5
]).
all
())
with
self
.
assertRaises
(
IndexError
):
with
self
.
assertRaises
(
IndexError
):
res
=
x
[[
1
,
0
],
[
0
,
0
]]
with
self
.
assertRaises
(
TypeError
):
res
=
x
[[
1.2
,
0
]]
res
=
x
[[
1.2
,
0
]]
def
_test_slice_index_list_bool
(
self
,
place
):
def
_test_slice_index_list_bool
(
self
,
place
):
...
@@ -472,5 +478,455 @@ class TestVariableSlice(unittest.TestCase):
...
@@ -472,5 +478,455 @@ class TestVariableSlice(unittest.TestCase):
self
.
_test_item_none_and_decrease
(
place
)
self
.
_test_item_none_and_decrease
(
place
)
class
TestListIndex
(
unittest
.
TestCase
):
def
numel
(
self
,
shape
):
return
reduce
(
lambda
x
,
y
:
x
*
y
,
shape
)
def
test_static_graph_list_index
(
self
):
paddle
.
enable_static
()
inps_shape
=
[
3
,
4
,
5
,
2
]
array
=
np
.
arange
(
self
.
numel
(
inps_shape
),
dtype
=
'float32'
).
reshape
(
inps_shape
)
index_shape
=
[
3
,
3
,
2
,
1
]
index
=
np
.
arange
(
self
.
numel
(
index_shape
)).
reshape
(
index_shape
)
for
_
in
range
(
3
):
program
=
paddle
.
static
.
Program
()
index_mod
=
(
index
%
(
array
.
shape
[
0
])).
tolist
()
with
paddle
.
static
.
program_guard
(
program
):
x
=
paddle
.
static
.
data
(
name
=
'x'
,
shape
=
array
.
shape
,
dtype
=
'float32'
)
y
=
x
[
index_mod
]
place
=
paddle
.
fluid
.
CPUPlace
(
)
if
not
paddle
.
fluid
.
core
.
is_compiled_with_cuda
(
)
else
paddle
.
fluid
.
CUDAPlace
(
0
)
prog
=
paddle
.
static
.
default_main_program
()
exe
=
paddle
.
static
.
Executor
(
place
)
exe
.
run
(
paddle
.
static
.
default_startup_program
())
fetch_list
=
[
y
.
name
]
getitem_np
=
array
[
index_mod
]
getitem_pp
=
exe
.
run
(
prog
,
feed
=
{
x
.
name
:
array
},
fetch_list
=
fetch_list
)
self
.
assertTrue
(
np
.
array_equal
(
getitem_np
,
getitem_pp
[
0
]))
array
=
array
[
0
]
index
=
index
[
0
]
def
test_dygraph_list_index
(
self
):
paddle
.
disable_static
()
inps_shape
=
[
3
,
4
,
5
,
3
]
array
=
np
.
arange
(
self
.
numel
(
inps_shape
)).
reshape
(
inps_shape
)
index_shape
=
[
2
,
3
,
4
,
5
,
6
]
index
=
np
.
arange
(
self
.
numel
(
index_shape
)).
reshape
(
index_shape
)
for
_
in
range
(
len
(
inps_shape
)
-
1
):
pt
=
paddle
.
to_tensor
(
array
)
index_mod
=
(
index
%
(
array
.
shape
[
-
1
])).
tolist
()
try
:
getitem_np
=
array
[
index_mod
]
except
:
with
self
.
assertRaises
(
ValueError
):
getitem_pp
=
pt
[
index_mod
]
array
=
array
[
0
]
index
=
index
[
0
]
continue
getitem_pp
=
pt
[
index_mod
]
self
.
assertTrue
(
np
.
array_equal
(
getitem_np
,
getitem_pp
.
numpy
()))
array
=
array
[
0
]
index
=
index
[
0
]
def
test_static_graph_list_index_muti_dim
(
self
):
paddle
.
enable_static
()
inps_shape
=
[
3
,
4
,
5
]
array
=
np
.
arange
(
self
.
numel
(
inps_shape
),
dtype
=
'float32'
).
reshape
(
inps_shape
)
index_shape
=
[
2
,
2
]
index1
=
np
.
arange
(
self
.
numel
(
index_shape
)).
reshape
(
index_shape
)
index2
=
np
.
arange
(
self
.
numel
(
index_shape
)).
reshape
(
index_shape
)
+
2
value_shape
=
[
3
,
2
,
2
,
3
]
value_np
=
np
.
arange
(
self
.
numel
(
value_shape
),
dtype
=
'float32'
).
reshape
(
value_shape
)
+
100
index_mod1
=
(
index1
%
(
min
(
array
.
shape
))).
tolist
()
index_mod2
=
(
index2
%
(
min
(
array
.
shape
))).
tolist
()
program
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
program
):
x
=
paddle
.
static
.
data
(
name
=
'x'
,
shape
=
array
.
shape
,
dtype
=
'float32'
)
value
=
paddle
.
static
.
data
(
name
=
'value'
,
shape
=
value_np
.
shape
,
dtype
=
'float32'
)
index1
=
paddle
.
static
.
data
(
name
=
'index1'
,
shape
=
index1
.
shape
,
dtype
=
'int32'
)
index2
=
paddle
.
static
.
data
(
name
=
'index2'
,
shape
=
index2
.
shape
,
dtype
=
'int32'
)
y
=
x
[
index1
,
index2
]
place
=
paddle
.
fluid
.
CPUPlace
(
)
if
not
paddle
.
fluid
.
core
.
is_compiled_with_cuda
(
)
else
paddle
.
fluid
.
CUDAPlace
(
0
)
prog
=
paddle
.
static
.
default_main_program
()
exe
=
paddle
.
static
.
Executor
(
place
)
exe
.
run
(
paddle
.
static
.
default_startup_program
())
fetch_list
=
[
y
.
name
]
array2
=
array
.
copy
()
y2
=
array2
[
index_mod1
,
index_mod2
]
getitem_pp
=
exe
.
run
(
prog
,
feed
=
{
x
.
name
:
array
,
index1
.
name
:
index_mod1
,
index2
.
name
:
index_mod2
},
fetch_list
=
fetch_list
)
self
.
assertTrue
(
np
.
array_equal
(
y2
,
getitem_pp
[
0
]),
msg
=
'
\n
numpy:{},
\n
paddle:{}'
.
format
(
y2
,
getitem_pp
[
0
]))
def
test_dygraph_list_index_muti_dim
(
self
):
paddle
.
disable_static
()
inps_shape
=
[
3
,
4
,
5
]
array
=
np
.
arange
(
self
.
numel
(
inps_shape
),
dtype
=
'float32'
).
reshape
(
inps_shape
)
index_shape
=
[
2
,
2
]
index1
=
np
.
arange
(
self
.
numel
(
index_shape
)).
reshape
(
index_shape
)
index2
=
np
.
arange
(
self
.
numel
(
index_shape
)).
reshape
(
index_shape
)
+
2
value_shape
=
[
3
,
2
,
2
,
3
]
value_np
=
np
.
arange
(
self
.
numel
(
value_shape
),
dtype
=
'float32'
).
reshape
(
value_shape
)
+
100
index_mod1
=
(
index1
%
(
min
(
array
.
shape
))).
tolist
()
index_mod2
=
(
index2
%
(
min
(
array
.
shape
))).
tolist
()
x
=
paddle
.
to_tensor
(
array
)
index_t1
=
paddle
.
to_tensor
(
index_mod1
)
index_t2
=
paddle
.
to_tensor
(
index_mod2
)
y_np
=
array
[
index_t1
,
index_t2
]
y
=
x
[
index_t1
,
index_t2
]
self
.
assertTrue
(
np
.
array_equal
(
y
.
numpy
(),
y_np
))
def
run_setitem_list_index
(
self
,
array
,
index
,
value_np
):
x
=
paddle
.
static
.
data
(
name
=
'x'
,
shape
=
array
.
shape
,
dtype
=
'float32'
)
value
=
paddle
.
static
.
data
(
name
=
'value'
,
shape
=
value_np
.
shape
,
dtype
=
'float32'
)
x
[
index
]
=
value
y
=
x
place
=
paddle
.
fluid
.
CPUPlace
()
prog
=
paddle
.
static
.
default_main_program
()
exe
=
paddle
.
static
.
Executor
(
place
)
exe
.
run
(
paddle
.
static
.
default_startup_program
())
fetch_list
=
[
y
.
name
]
array2
=
array
.
copy
()
try
:
array2
[
index
]
=
value_np
except
:
with
self
.
assertRaises
(
ValueError
):
setitem_pp
=
exe
.
run
(
prog
,
feed
=
{
x
.
name
:
array
,
value
.
name
:
value_np
},
fetch_list
=
fetch_list
)
return
setitem_pp
=
exe
.
run
(
prog
,
feed
=
{
x
.
name
:
array
,
value
.
name
:
value_np
},
fetch_list
=
fetch_list
)
self
.
assertTrue
(
np
.
array_equal
(
array2
,
setitem_pp
[
0
]),
msg
=
'
\n
numpy:{},
\n
paddle:{}'
.
format
(
array2
,
setitem_pp
[
0
]))
def
test_static_graph_setitem_list_index
(
self
):
paddle
.
enable_static
()
# case 1:
inps_shape
=
[
3
,
4
,
5
,
2
,
3
]
array
=
np
.
arange
(
self
.
numel
(
inps_shape
),
dtype
=
'float32'
).
reshape
(
inps_shape
)
index_shape
=
[
3
,
3
,
1
,
2
]
index
=
np
.
arange
(
self
.
numel
(
index_shape
)).
reshape
(
index_shape
)
value_shape
=
inps_shape
[
3
:]
value_np
=
np
.
arange
(
self
.
numel
(
value_shape
),
dtype
=
'float32'
).
reshape
(
value_shape
)
+
100
for
_
in
range
(
3
):
program
=
paddle
.
static
.
Program
()
index_mod
=
(
index
%
(
min
(
array
.
shape
))).
tolist
()
with
paddle
.
static
.
program_guard
(
program
):
self
.
run_setitem_list_index
(
array
,
index_mod
,
value_np
)
array
=
array
[
0
]
index
=
index
[
0
]
# case 2:
inps_shape
=
[
3
,
4
,
5
,
4
,
3
]
array
=
np
.
arange
(
self
.
numel
(
inps_shape
),
dtype
=
'float32'
).
reshape
(
inps_shape
)
index_shape
=
[
4
,
3
,
2
,
2
]
index
=
np
.
arange
(
self
.
numel
(
index_shape
)).
reshape
(
index_shape
)
value_shape
=
[
3
]
value_np
=
np
.
arange
(
self
.
numel
(
value_shape
),
dtype
=
'float32'
).
reshape
(
value_shape
)
+
100
for
_
in
range
(
4
):
program
=
paddle
.
static
.
Program
()
index_mod
=
(
index
%
(
min
(
array
.
shape
))).
tolist
()
with
paddle
.
static
.
program_guard
(
program
):
self
.
run_setitem_list_index
(
array
,
index_mod
,
value_np
)
array
=
array
[
0
]
index
=
index
[
0
]
# case 3:
inps_shape
=
[
3
,
4
,
5
,
3
,
3
]
array
=
np
.
arange
(
self
.
numel
(
inps_shape
),
dtype
=
'float32'
).
reshape
(
inps_shape
)
index_shape
=
[
4
,
3
,
2
,
2
]
index
=
np
.
arange
(
self
.
numel
(
index_shape
)).
reshape
(
index_shape
)
value_shape
=
[
3
,
2
,
2
,
3
]
value_np
=
np
.
arange
(
self
.
numel
(
value_shape
),
dtype
=
'float32'
).
reshape
(
value_shape
)
+
100
index_mod
=
(
index
%
(
min
(
array
.
shape
))).
tolist
()
self
.
run_setitem_list_index
(
array
,
index_mod
,
value_np
)
def
test_static_graph_tensor_index_setitem_muti_dim
(
self
):
paddle
.
enable_static
()
inps_shape
=
[
3
,
4
,
5
,
4
]
array
=
np
.
arange
(
self
.
numel
(
inps_shape
),
dtype
=
'float32'
).
reshape
(
inps_shape
)
index_shape
=
[
2
,
3
,
4
]
index1
=
np
.
arange
(
self
.
numel
(
index_shape
),
dtype
=
'int32'
).
reshape
(
index_shape
)
index2
=
np
.
arange
(
self
.
numel
(
index_shape
),
dtype
=
'int32'
).
reshape
(
index_shape
)
+
2
value_shape
=
[
4
]
value_np
=
np
.
arange
(
self
.
numel
(
value_shape
),
dtype
=
'float32'
).
reshape
(
value_shape
)
+
100
for
_
in
range
(
3
):
index_mod1
=
index1
%
(
min
(
array
.
shape
))
index_mod2
=
index2
%
(
min
(
array
.
shape
))
array2
=
array
.
copy
()
array2
[
index_mod1
,
index_mod2
]
=
value_np
array3
=
array
.
copy
()
array3
[
index_mod1
]
=
value_np
program
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
program
):
x1
=
paddle
.
static
.
data
(
name
=
'x1'
,
shape
=
array
.
shape
,
dtype
=
'float32'
)
x2
=
paddle
.
static
.
data
(
name
=
'x2'
,
shape
=
array
.
shape
,
dtype
=
'float32'
)
value
=
paddle
.
static
.
data
(
name
=
'value'
,
shape
=
value_np
.
shape
,
dtype
=
'float32'
)
index_1
=
paddle
.
static
.
data
(
name
=
'index_1'
,
shape
=
index1
.
shape
,
dtype
=
'int32'
)
index_2
=
paddle
.
static
.
data
(
name
=
'index_2'
,
shape
=
index2
.
shape
,
dtype
=
'int32'
)
x1
[
index_1
,
index_2
]
=
value
x2
[
index_1
]
=
value
place
=
paddle
.
fluid
.
CPUPlace
(
)
if
not
paddle
.
fluid
.
core
.
is_compiled_with_cuda
(
)
else
paddle
.
fluid
.
CUDAPlace
(
0
)
prog
=
paddle
.
static
.
default_main_program
()
exe
=
paddle
.
static
.
Executor
(
place
)
exe
.
run
(
paddle
.
static
.
default_startup_program
())
fetch_list
=
[
x1
.
name
,
x2
.
name
]
setitem_pp
=
exe
.
run
(
prog
,
feed
=
{
x1
.
name
:
array
,
x2
.
name
:
array
,
value
.
name
:
value_np
,
index_1
.
name
:
index_mod1
,
index_2
.
name
:
index_mod2
},
fetch_list
=
fetch_list
)
self
.
assertTrue
(
np
.
array_equal
(
array2
,
setitem_pp
[
0
]),
msg
=
'
\n
numpy:{},
\n
paddle:{}'
.
format
(
array2
,
setitem_pp
[
0
]))
self
.
assertTrue
(
np
.
array_equal
(
array3
,
setitem_pp
[
1
]),
msg
=
'
\n
numpy:{},
\n
paddle:{}'
.
format
(
array3
,
setitem_pp
[
1
]))
array
=
array
[
0
]
index1
=
index1
[
0
]
index2
=
index2
[
0
]
def
test_static_graph_array_index_muti_dim
(
self
):
paddle
.
enable_static
()
inps_shape
=
[
3
,
4
,
5
,
4
]
array
=
np
.
arange
(
self
.
numel
(
inps_shape
),
dtype
=
'float32'
).
reshape
(
inps_shape
)
index_shape
=
[
2
,
3
,
4
]
index1
=
np
.
arange
(
self
.
numel
(
index_shape
),
dtype
=
'int32'
).
reshape
(
index_shape
)
index2
=
np
.
arange
(
self
.
numel
(
index_shape
),
dtype
=
'int32'
).
reshape
(
index_shape
)
+
2
for
_
in
range
(
3
):
index_mod1
=
index1
%
(
min
(
array
.
shape
))
index_mod2
=
index2
%
(
min
(
array
.
shape
))
array2
=
array
.
copy
()
array2
[
index_mod1
,
index_mod2
]
=
1
y_np1
=
array2
[
index_mod2
,
index_mod1
]
array3
=
array
.
copy
()
array3
[
index_mod1
]
=
2.5
y_np2
=
array3
[
index_mod2
]
program
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
program
):
x1
=
paddle
.
static
.
data
(
name
=
'x1'
,
shape
=
array
.
shape
,
dtype
=
'float32'
)
x2
=
paddle
.
static
.
data
(
name
=
'x2'
,
shape
=
array
.
shape
,
dtype
=
'float32'
)
x1
[
index_mod1
,
index_mod2
]
=
1
x2
[
index_mod1
]
=
2.5
y1
=
x1
[
index_mod2
,
index_mod1
]
y2
=
x2
[
index_mod2
]
place
=
paddle
.
fluid
.
CPUPlace
(
)
if
not
paddle
.
fluid
.
core
.
is_compiled_with_cuda
(
)
else
paddle
.
fluid
.
CUDAPlace
(
0
)
prog
=
paddle
.
static
.
default_main_program
()
exe
=
paddle
.
static
.
Executor
(
place
)
exe
.
run
(
paddle
.
static
.
default_startup_program
())
fetch_list
=
[
x1
.
name
,
x2
.
name
,
y1
.
name
,
y2
.
name
]
setitem_pp
=
exe
.
run
(
prog
,
feed
=
{
x1
.
name
:
array
,
x2
.
name
:
array
},
fetch_list
=
fetch_list
)
self
.
assertTrue
(
np
.
array_equal
(
array2
,
setitem_pp
[
0
]),
msg
=
'
\n
numpy:{},
\n
paddle:{}'
.
format
(
array2
,
setitem_pp
[
0
]))
self
.
assertTrue
(
np
.
array_equal
(
array3
,
setitem_pp
[
1
]),
msg
=
'
\n
numpy:{},
\n
paddle:{}'
.
format
(
array3
,
setitem_pp
[
1
]))
self
.
assertTrue
(
np
.
array_equal
(
y_np1
,
setitem_pp
[
2
]),
msg
=
'
\n
numpy:{},
\n
paddle:{}'
.
format
(
y_np1
,
setitem_pp
[
2
]))
self
.
assertTrue
(
np
.
array_equal
(
y_np2
,
setitem_pp
[
3
]),
msg
=
'
\n
numpy:{},
\n
paddle:{}'
.
format
(
y_np2
,
setitem_pp
[
3
]))
array
=
array
[
0
]
index1
=
index1
[
0
]
index2
=
index2
[
0
]
def
test_dygraph_array_index_muti_dim
(
self
):
paddle
.
disable_static
()
inps_shape
=
[
3
,
4
,
5
,
4
]
array
=
np
.
arange
(
self
.
numel
(
inps_shape
),
dtype
=
'float32'
).
reshape
(
inps_shape
)
index_shape
=
[
2
,
3
,
4
]
index1
=
np
.
arange
(
self
.
numel
(
index_shape
),
dtype
=
'int32'
).
reshape
(
index_shape
)
index2
=
np
.
arange
(
self
.
numel
(
index_shape
),
dtype
=
'int32'
).
reshape
(
index_shape
)
+
2
for
_
in
range
(
3
):
index_mod1
=
index1
%
(
min
(
array
.
shape
))
index_mod2
=
index2
%
(
min
(
array
.
shape
))
index_mod_t1
=
paddle
.
to_tensor
(
index_mod1
)
index_mod_t2
=
paddle
.
to_tensor
(
index_mod2
)
# 2 dim getitem
array1
=
array
.
copy
()
y_np1
=
array1
[
index_mod2
,
index_mod1
]
tensor1
=
paddle
.
to_tensor
(
array
)
y_t1
=
tensor1
[
index_mod_t2
,
index_mod_t1
]
self
.
assertTrue
(
np
.
array_equal
(
y_t1
.
numpy
(),
y_np1
),
msg
=
'
\n
numpy:{},
\n
paddle:{}'
.
format
(
y_np1
,
y_t1
.
numpy
()))
# 1 dim getitem
array2
=
array
.
copy
()
y_np2
=
array2
[
index_mod2
]
tensor2
=
paddle
.
to_tensor
(
array
)
y_t2
=
tensor2
[
index_mod_t2
]
self
.
assertTrue
(
np
.
array_equal
(
y_t2
.
numpy
(),
y_np2
),
msg
=
'
\n
numpy:{},
\n
paddle:{}'
.
format
(
y_np2
,
y_t2
.
numpy
()))
# 2 dim setitem
array1
=
array
.
copy
()
array1
[
index_mod1
,
index_mod2
]
=
1
tensor1
[
index_mod_t1
,
index_mod_t2
]
=
1
self
.
assertTrue
(
np
.
array_equal
(
tensor1
.
numpy
(),
array1
),
msg
=
'
\n
numpy:{},
\n
paddle:{}'
.
format
(
array1
,
tensor1
.
numpy
()))
# 1 dim setitem
array2
=
array
.
copy
()
array2
[
index_mod1
]
=
2.5
tensor2
[
index_mod_t1
]
=
2.5
self
.
assertTrue
(
np
.
array_equal
(
tensor2
.
numpy
(),
array2
),
msg
=
'
\n
numpy:{},
\n
paddle:{}'
.
format
(
array2
,
tensor2
.
numpy
()))
array
=
array
[
0
]
index1
=
index1
[
0
]
index2
=
index2
[
0
]
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/variable_index.py
浏览文件 @
e7df47ec
...
@@ -16,10 +16,172 @@ import sys
...
@@ -16,10 +16,172 @@ import sys
import
numpy
as
np
import
numpy
as
np
from
.
import
unique_name
from
.
import
unique_name
from
.
import
core
from
.
import
core
import
paddle
MAX_INTEGER
=
2
**
31
-
1
MAX_INTEGER
=
2
**
31
-
1
def
is_list_tuple
(
index
,
contain_type
):
def
_is_list_tuple
(
item
):
if
not
(
isinstance
(
item
,
(
list
,
tuple
))
or
type
(
item
)
==
contain_type
):
return
False
if
isinstance
(
item
,
(
tuple
,
list
)):
for
s
in
item
:
if
not
_is_list_tuple
(
s
):
return
False
return
True
if
not
isinstance
(
index
,
(
tuple
,
list
)):
return
False
for
s
in
index
:
if
not
_is_list_tuple
(
s
):
return
False
return
True
def
is_one_dim_list
(
index
,
contain_type
):
if
isinstance
(
index
,
list
):
for
i
in
index
:
if
not
isinstance
(
i
,
contain_type
):
return
False
else
:
return
False
return
True
def
get_list_index_shape
(
var_dims
,
index_dims
):
var_dims_size
=
len
(
var_dims
)
index_dims_size
=
len
(
index_dims
)
out_dims_size
=
var_dims_size
-
index_dims
[
0
]
+
index_dims_size
-
1
out_dims_shape
=
[
1
]
*
out_dims_size
out_dims_shape
[:
index_dims_size
-
1
]
=
index_dims
[
1
:]
out_dims_shape
[
index_dims_size
-
1
:]
=
var_dims
[
index_dims
[
0
]:]
return
out_dims_shape
class
SliceInfo
:
def
__init__
(
self
):
self
.
pre_shape
=
None
self
.
indexes
=
[]
def
update
(
self
,
index
):
if
is_list_tuple
(
index
,
int
)
or
isinstance
(
index
,
(
paddle
.
fluid
.
Variable
,
np
.
ndarray
)):
# convert index to Tensor
if
not
isinstance
(
index
,
paddle
.
fluid
.
Variable
):
index
=
paddle
.
assign
(
index
)
self
.
indexes
.
append
(
index
)
if
self
.
pre_shape
is
None
:
self
.
pre_shape
=
index
.
shape
else
:
if
self
.
pre_shape
!=
index
.
shape
:
# broadcast
cur_shape
=
paddle
.
broadcast_shape
(
self
.
pre_shape
,
index
.
shape
)
for
i
in
range
(
len
(
self
.
indexes
)):
self
.
indexes
[
i
]
=
paddle
.
broadcast_to
(
self
.
indexes
[
i
],
cur_shape
)
self
.
pre_shape
=
self
.
indexes
[
-
1
].
shape
else
:
raise
ValueError
(
"Index should be list/tuple of int or Tensor, but received {}."
.
format
(
index
))
def
shape_stride
(
self
,
shape
):
s
=
[
1
]
*
len
(
shape
)
for
i
in
range
(
len
(
shape
)
-
2
,
-
1
,
-
1
):
s
[
i
]
=
shape
[
i
+
1
]
*
s
[
i
+
1
]
return
s
def
numel
(
self
,
shape
):
return
reduce
(
lambda
x
,
y
:
x
*
y
,
shape
)
def
get_offset_stride
(
self
,
tensor_shape
):
for
index
in
self
.
indexes
:
if
not
isinstance
(
index
,
paddle
.
fluid
.
Variable
):
raise
ValueError
(
"only support list/tensor index, but received {}."
.
format
(
type
(
index
)))
if
len
(
self
.
indexes
)
<=
len
(
tensor_shape
)
or
len
(
self
.
indexes
)
==
1
:
shape
=
paddle
.
stack
(
self
.
indexes
)
axes
=
list
(
range
(
1
,
len
(
self
.
pre_shape
)
+
1
))
+
[
0
,
]
else
:
raise
ValueError
(
"too many indices for tensor: tensor is {}-dimensional, but {} were indexed"
.
format
(
len
(
tensor_shape
),
self
.
pre_shape
[
0
]))
shape_transpose
=
paddle
.
transpose
(
shape
,
axes
)
return
shape_transpose
def
get_item
(
self
,
tensor
):
shape_transpose
=
self
.
get_offset_stride
(
tensor
.
shape
)
index
=
paddle
.
assign
(
shape_transpose
)
return
paddle
.
gather_nd
(
tensor
,
index
)
def
set_item
(
self
,
tensor_origin
,
value
):
if
not
isinstance
(
value
,
paddle
.
fluid
.
Variable
):
value
=
paddle
.
assign
(
value
)
tensor_type
=
None
if
tensor_origin
.
dtype
in
[
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP64
]:
tensor
=
tensor_origin
else
:
tensor_type
=
tensor_origin
.
dtype
tensor
=
tensor_origin
.
astype
(
core
.
VarDesc
.
VarType
.
FP32
)
if
value
.
dtype
!=
tensor
.
dtype
:
value
=
value
.
astype
(
tensor
.
dtype
)
shape_transpose
=
self
.
get_offset_stride
(
tensor_origin
.
shape
)
index
=
paddle
.
assign
(
shape_transpose
)
gather_tensor_shape
=
get_list_index_shape
(
tensor
.
shape
,
[
len
(
self
.
indexes
),
]
+
list
(
self
.
indexes
[
-
1
].
shape
))
value_dims_bd
=
[
1
,
]
*
len
(
gather_tensor_shape
)
value_dims_bd
[
-
len
(
value
.
shape
):]
=
list
(
value
.
shape
)
for
i
in
range
(
len
(
gather_tensor_shape
)):
if
not
(
value_dims_bd
[
i
]
==
gather_tensor_shape
[
i
]
or
value_dims_bd
[
i
]
==
1
):
raise
ValueError
(
"{} can not broadcast into {}"
.
format
(
value
.
shape
,
gather_tensor_shape
))
value_broadcast
=
paddle
.
broadcast_to
(
value
,
gather_tensor_shape
)
value_1d
=
value_broadcast
.
reshape
([
-
1
]
+
gather_tensor_shape
[
len
(
index
.
shape
)
-
1
:])
index_1d
=
index
.
reshape
([
-
1
,
index
.
shape
[
-
1
]])
tensor_stride
=
paddle
.
assign
(
self
.
shape_stride
(
tensor
.
shape
[:
index
.
shape
[
-
1
]]))
inds
=
[]
for
i
in
range
(
index_1d
.
shape
[
0
]):
temp
=
(
index_1d
[
i
]
*
tensor_stride
).
sum
()
inds
.
append
(
temp
)
index_1d
=
paddle
.
stack
(
inds
).
reshape
([
-
1
])
t_reshape
=
tensor
.
reshape
([
-
1
]
+
list
(
tensor
.
shape
[
index
.
shape
[
-
1
]:]))
out
=
paddle
.
scatter
(
t_reshape
,
index_1d
,
value_1d
)
if
tensor_type
is
not
None
:
out
=
out
.
astype
(
tensor_type
)
tensor_origin
[:]
=
out
.
reshape
(
tensor_origin
.
shape
)
return
tensor_origin
def
replace_ellipsis
(
var
,
item
):
def
replace_ellipsis
(
var
,
item
):
from
.framework
import
Variable
from
.framework
import
Variable
# Use slice(None) to replace Ellipsis.
# Use slice(None) to replace Ellipsis.
...
@@ -32,7 +194,9 @@ def replace_ellipsis(var, item):
...
@@ -32,7 +194,9 @@ def replace_ellipsis(var, item):
item
=
list
(
item
)
item
=
list
(
item
)
# Remove Variable to skip bug when counting Ellipsis
# Remove Variable to skip bug when counting Ellipsis
item_remove_var
=
[
ele
for
ele
in
item
if
not
isinstance
(
ele
,
Variable
)]
item_remove_var
=
[
ele
for
ele
in
item
if
not
isinstance
(
ele
,
(
Variable
,
np
.
ndarray
))
]
ell_count
=
item_remove_var
.
count
(
Ellipsis
)
ell_count
=
item_remove_var
.
count
(
Ellipsis
)
if
ell_count
==
0
:
if
ell_count
==
0
:
return
item
return
item
...
@@ -99,6 +263,9 @@ def _getitem_impl_(var, item):
...
@@ -99,6 +263,9 @@ def _getitem_impl_(var, item):
Sliced variable
Sliced variable
"""
"""
from
.framework
import
default_main_program
,
Variable
from
.framework
import
default_main_program
,
Variable
if
isinstance
(
item
,
list
):
if
not
is_one_dim_list
(
item
,
int
):
item
=
tuple
(
item
)
if
not
isinstance
(
item
,
tuple
):
if
not
isinstance
(
item
,
tuple
):
item
=
(
item
,
)
item
=
(
item
,
)
...
@@ -113,6 +280,7 @@ def _getitem_impl_(var, item):
...
@@ -113,6 +280,7 @@ def _getitem_impl_(var, item):
use_strided_slice
=
False
use_strided_slice
=
False
item
,
none_axes
=
replace_none
(
item
)
item
,
none_axes
=
replace_none
(
item
)
item
=
replace_ellipsis
(
var
,
item
)
item
=
replace_ellipsis
(
var
,
item
)
slice_info
=
SliceInfo
()
for
dim
,
slice_item
in
enumerate
(
item
):
for
dim
,
slice_item
in
enumerate
(
item
):
if
is_integer_or_scalar_tensor
(
slice_item
):
if
is_integer_or_scalar_tensor
(
slice_item
):
...
@@ -151,6 +319,11 @@ def _getitem_impl_(var, item):
...
@@ -151,6 +319,11 @@ def _getitem_impl_(var, item):
elif
isinstance
(
slice_item
,
list
):
elif
isinstance
(
slice_item
,
list
):
all_bool
=
True
all_bool
=
True
if
is_list_tuple
(
slice_item
,
int
):
slice_info
.
update
(
slice_item
)
continue
for
i
in
slice_item
:
for
i
in
slice_item
:
if
type
(
i
)
is
int
:
if
type
(
i
)
is
int
:
all_bool
=
False
all_bool
=
False
...
@@ -188,35 +361,43 @@ def _getitem_impl_(var, item):
...
@@ -188,35 +361,43 @@ def _getitem_impl_(var, item):
idx
=
assign
(
np
.
array
(
slice_item
).
astype
(
"int32"
))
idx
=
assign
(
np
.
array
(
slice_item
).
astype
(
"int32"
))
return
index_select
(
var
,
index
=
idx
,
axis
=
0
)
return
index_select
(
var
,
index
=
idx
,
axis
=
0
)
elif
isinstance
(
slice_item
,
Variable
):
elif
isinstance
(
slice_item
,
np
.
ndarray
):
if
len
(
item
)
!=
1
:
slice_info
.
update
(
slice_item
)
raise
IndexError
(
continue
"When index contains a Tensor, its length must be 1, but received {}."
.
elif
isinstance
(
slice_item
,
(
Variable
)):
format
(
len
(
item
)))
if
len
(
item
)
==
1
:
from
..tensor
import
index_select
,
gather_nd
from
..tensor
import
index_select
,
gather_nd
from
.layers.nn
import
where
from
.layers.nn
import
where
if
slice_item
.
dtype
==
core
.
VarDesc
.
VarType
.
BOOL
:
if
slice_item
.
dtype
==
paddle
.
bool
:
if
len
(
slice_item
.
shape
)
>
len
(
var
.
shape
):
if
len
(
slice_item
.
shape
)
>
len
(
var
.
shape
):
raise
IndexError
(
"The dims of bool index doesn't match indexed array, "
"the dims of bool index except to be equal or less "
"than {}, but received {}."
.
format
(
len
(
var
.
shape
),
len
(
slice_item
.
shape
)))
for
i
,
dim_len
in
enumerate
(
slice_item
.
shape
):
if
dim_len
!=
var
.
shape
[
i
]:
raise
IndexError
(
raise
IndexError
(
"The dimension of bool index doesn't match indexed array along "
\
"The dims of bool index doesn't match indexed array, "
"dimension {}, the target dimension is {}, but received {}."
.
"the dims of bool index except to be equal or less "
format
(
i
,
var
.
shape
[
i
],
dim_len
))
"than {}, but received {}."
.
format
(
bool_2_idx
=
where
(
slice_item
==
True
)
len
(
var
.
shape
),
len
(
slice_item
.
shape
)))
return
gather_nd
(
var
,
bool_2_idx
)
for
i
,
dim_len
in
enumerate
(
slice_item
.
shape
):
return
index_select
(
var
,
index
=
slice_item
,
axis
=
0
)
if
dim_len
!=
var
.
shape
[
i
]:
raise
IndexError
(
"The dimension of bool index doesn't match indexed array along "
\
"dimension {}, the target dimension is {}, but received {}."
.
format
(
i
,
var
.
shape
[
i
],
dim_len
))
bool_2_idx
=
where
(
slice_item
==
True
)
return
gather_nd
(
var
,
bool_2_idx
)
else
:
if
len
(
slice_item
.
shape
)
==
1
:
return
index_select
(
var
,
index
=
slice_item
,
axis
=
0
)
else
:
slice_info
.
update
(
slice_item
)
continue
else
:
slice_info
.
update
(
slice_item
)
continue
else
:
else
:
raise
IndexError
(
raise
IndexError
(
"Valid index accept int or slice or ellipsis, but received {}."
.
"Valid index accept int or slice or ellipsis
or list
, but received {}."
.
format
(
slice_item
))
format
(
slice_item
))
axes
.
append
(
dim
)
axes
.
append
(
dim
)
...
@@ -225,6 +406,13 @@ def _getitem_impl_(var, item):
...
@@ -225,6 +406,13 @@ def _getitem_impl_(var, item):
steps
.
append
(
step
)
steps
.
append
(
step
)
use_strided_slice
=
True
if
step
!=
1
else
use_strided_slice
use_strided_slice
=
True
if
step
!=
1
else
use_strided_slice
if
slice_info
.
indexes
:
if
len
(
slice_info
.
indexes
)
!=
len
(
item
):
raise
IndexError
(
"Valid index accept int or slice or ellipsis or list, but received {}."
.
format
(
item
))
return
slice_info
.
get_item
(
var
)
inputs
=
{
'Input'
:
[
var
]}
inputs
=
{
'Input'
:
[
var
]}
attrs
=
{
attrs
=
{
'axes'
:
axes
,
'axes'
:
axes
,
...
@@ -298,7 +486,9 @@ def _setitem_impl_(var, item, value):
...
@@ -298,7 +486,9 @@ def _setitem_impl_(var, item, value):
from
.framework
import
default_main_program
,
Variable
from
.framework
import
default_main_program
,
Variable
inputs
=
{
'Input'
:
var
}
inputs
=
{
'Input'
:
var
}
if
isinstance
(
item
,
list
):
if
not
is_one_dim_list
(
item
,
int
):
item
=
tuple
(
item
)
# 1. Parse item
# 1. Parse item
if
not
isinstance
(
item
,
tuple
):
if
not
isinstance
(
item
,
tuple
):
item
=
(
item
,
)
item
=
(
item
,
)
...
@@ -311,7 +501,7 @@ def _setitem_impl_(var, item, value):
...
@@ -311,7 +501,7 @@ def _setitem_impl_(var, item, value):
item
,
none_axes
=
replace_none
(
item
)
item
,
none_axes
=
replace_none
(
item
)
item
=
replace_ellipsis
(
var
,
item
)
item
=
replace_ellipsis
(
var
,
item
)
slice_info
=
SliceInfo
()
dim
=
0
dim
=
0
for
_
,
slice_item
in
enumerate
(
item
):
for
_
,
slice_item
in
enumerate
(
item
):
if
is_integer_or_scalar_tensor
(
slice_item
):
if
is_integer_or_scalar_tensor
(
slice_item
):
...
@@ -319,6 +509,16 @@ def _setitem_impl_(var, item, value):
...
@@ -319,6 +509,16 @@ def _setitem_impl_(var, item, value):
start
=
slice_item
start
=
slice_item
end
=
slice_item
+
1
if
slice_item
!=
-
1
else
MAX_INTEGER
end
=
slice_item
+
1
if
slice_item
!=
-
1
else
MAX_INTEGER
step
=
1
step
=
1
elif
isinstance
(
slice_item
,
list
):
if
not
is_list_tuple
(
slice_item
,
int
):
raise
TypeError
(
"Only support int or list in index list. But revceived {}."
.
format
(
slice_item
))
slice_info
.
update
(
slice_item
)
continue
elif
isinstance
(
slice_item
,
(
Variable
,
np
.
ndarray
)):
slice_info
.
update
(
slice_item
)
continue
elif
isinstance
(
slice_item
,
slice
):
elif
isinstance
(
slice_item
,
slice
):
start
=
slice_item
.
start
start
=
slice_item
.
start
...
@@ -358,7 +558,12 @@ def _setitem_impl_(var, item, value):
...
@@ -358,7 +558,12 @@ def _setitem_impl_(var, item, value):
steps
.
append
(
step
)
steps
.
append
(
step
)
dim
+=
1
dim
+=
1
if
slice_info
.
indexes
:
if
len
(
slice_info
.
indexes
)
!=
len
(
item
):
raise
IndexError
(
"Valid index accept int or slice or ellipsis or list, but received {}."
.
format
(
item
))
return
slice_info
.
set_item
(
var
,
value
)
attrs
=
{
attrs
=
{
'axes'
:
axes
,
'axes'
:
axes
,
'starts'
:
starts
,
'starts'
:
starts
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录