Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
79058d35
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
79058d35
编写于
6月 12, 2020
作者:
H
huangdongrun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add support for parameter
support for tensor setitem add support for tensor assgin
上级
c55b81e9
变更
7
展开全部
隐藏空白更改
内联
并排
Showing
7 changed file
with
701 addition
and
362 deletion
+701
-362
mindspore/ccsrc/ir/tensor.cc
mindspore/ccsrc/ir/tensor.cc
+17
-0
mindspore/ccsrc/ir/tensor.h
mindspore/ccsrc/ir/tensor.h
+3
-0
mindspore/common/parameter.py
mindspore/common/parameter.py
+2
-0
mindspore/common/tensor.py
mindspore/common/tensor.py
+2
-0
mindspore/ops/composite/multitype_ops/_compile_utils.py
mindspore/ops/composite/multitype_ops/_compile_utils.py
+345
-42
mindspore/ops/composite/multitype_ops/setitem_impl.py
mindspore/ops/composite/multitype_ops/setitem_impl.py
+12
-216
tests/st/pynative/test_tensor_index.py
tests/st/pynative/test_tensor_index.py
+320
-104
未找到文件。
mindspore/ccsrc/ir/tensor.cc
浏览文件 @
79058d35
...
...
@@ -92,6 +92,10 @@ Tensor &Tensor::operator=(const Tensor &tensor) {
}
return
*
this
;
}
Tensor
&
Tensor
::
AssignValue
(
const
Tensor
&
tensor
)
{
*
this
=
tensor
;
return
*
this
;
}
bool
Tensor
::
operator
==
(
const
Tensor
&
tensor
)
const
{
return
(
MetaTensor
::
operator
==
(
tensor
)
&&
data_
==
tensor
.
data_
);
...
...
@@ -470,6 +474,19 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
>>> data.set_dtype(mindspore.int32)
mindspore.int32
)mydelimiter"
)
.
def
(
"assign_value"
,
&
Tensor
::
AssignValue
,
R"mydelimiter(
Assign another tensor value to this.
Arg:
value (:class:`mindspore.tensor`): The value tensor.
Examples:
>>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
>>> data2 = mindspore.Tensor(np.ones((2, 2), np.float32))
>>> data.assign_value(data2)
>>> data.shape
(2, 2)
)mydelimiter"
)
.
def
(
"__str__"
,
&
Tensor
::
ToString
)
.
def
(
"__repr__"
,
&
Tensor
::
ToStringRepr
)
.
def
(
py
::
pickle
(
...
...
mindspore/ccsrc/ir/tensor.h
浏览文件 @
79058d35
...
...
@@ -173,6 +173,9 @@ class Tensor : public MetaTensor {
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
bool
ValueEqual
(
const
Tensor
&
other
)
const
;
// assgin value to this tensor
Tensor
&
AssignValue
(
const
Tensor
&
tensor
);
bool
operator
==
(
const
Value
&
other
)
const
override
{
if
(
other
.
isa
<
Tensor
>
())
{
auto
other_
=
static_cast
<
const
Tensor
&>
(
other
);
...
...
mindspore/common/parameter.py
浏览文件 @
79058d35
...
...
@@ -203,6 +203,8 @@ class Parameter:
return
self
.
default_input
/
other
def
__setitem__
(
self
,
index
,
value
):
default_input
=
self
.
default_input
default_input
[
index
]
=
value
return
self
def
set_parameter_data
(
self
,
data
):
...
...
mindspore/common/tensor.py
浏览文件 @
79058d35
...
...
@@ -150,6 +150,8 @@ class Tensor(Tensor_):
return
out
def
__setitem__
(
self
,
index
,
value
):
out
=
tensor_operator_registry
.
get
(
'__setitem__'
)(
self
,
index
,
value
)
self
.
assign_value
(
out
)
return
self
def
__gt__
(
self
,
other
):
...
...
mindspore/ops/composite/multitype_ops/_compile_utils.py
浏览文件 @
79058d35
此差异已折叠。
点击以展开。
mindspore/ops/composite/multitype_ops/setitem_impl.py
浏览文件 @
79058d35
...
...
@@ -16,10 +16,8 @@
"""Implementation for setitem."""
from
.
import
_compile_utils
as
compile_utils
from
.
import
_constexpr_utils
as
const_utils
from
...
import
functional
as
F
from
...composite
import
base
from
....common
import
dtype
as
mstype
setitem
=
base
.
MultitypeFuncGraph
(
'setitem'
)
...
...
@@ -139,11 +137,7 @@ def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
Outputs:
Tensor, element type and shape is same as data.
"""
index_dtype
=
F
.
dtype
(
index
)
tensor_dtype
=
const_utils
.
get_index_tensor_dtype
(
index_dtype
)
if
tensor_dtype
==
const_utils
.
INT_
:
return
_tensor_setitem_by_int_tensor_with_tensor
(
data
,
index
,
value_tensor
)
return
_tensor_setitem_by_bool_tensor_with_tensor
(
data
,
index
,
value_tensor
)
return
compile_utils
.
tensor_setitem_by_tensor_with_tensor
(
data
,
index
,
value_tensor
)
@
setitem
.
register
(
"Tensor"
,
"Tensor"
,
"Number"
)
...
...
@@ -166,11 +160,7 @@ def _tensor_setitem_by_tensor_with_number(data, index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
index_dtype
=
F
.
dtype
(
index
)
tensor_dtype
=
const_utils
.
get_index_tensor_dtype
(
index_dtype
)
if
tensor_dtype
==
const_utils
.
BOOL_
:
return
_tensor_setitem_by_bool_tensor_with_scalar
(
data
,
index
,
value
)
return
_tensor_setitem_by_int_tensor_with_scalar
(
data
,
index
,
value
)
return
compile_utils
.
tensor_setitem_by_tensor_with_number
(
data
,
index
,
value
)
@
setitem
.
register
(
"Tensor"
,
"Tuple"
,
"Number"
)
...
...
@@ -191,24 +181,7 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
indexes_types
=
compile_utils
.
hyper_map
(
F
.
typeof
,
tuple_index
)
index_elements_type
=
const_utils
.
tuple_index_elements_type
(
indexes_types
,
const_utils
.
TENSOR_SETITEM
)
if
index_elements_type
==
const_utils
.
NO_TENSOR
:
return
_tensor_assgin_number
(
data
,
tuple_index
,
value
)
if
index_elements_type
==
const_utils
.
ALL_TENSOR
:
indices
=
compile_utils
.
generate_indices_from_tuple_of_tensor
(
data
,
tuple_index
,
const_utils
.
TENSOR_SETITEM
)
else
:
indices
=
compile_utils
.
generate_indices_from_tuple_of_mixed_tensors
(
data
,
tuple_index
,
const_utils
.
TENSOR_SETITEM
)
updates
=
compile_utils
.
generate_updates_from_scalar
(
data
,
indices
,
value
,
const_utils
.
SET_ITEM_BY_TUPLE_OF_TENSOR
)
return
F
.
scatter_nd_update
(
data
,
indices
,
updates
)
return
compile_utils
.
tensor_setitem_by_tuple_with_number
(
data
,
tuple_index
,
value
)
@
setitem
.
register
(
"Tensor"
,
"Tuple"
,
"Tensor"
)
...
...
@@ -229,24 +202,7 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
indexes_types
=
compile_utils
.
hyper_map
(
F
.
typeof
,
tuple_index
)
index_elements_type
=
const_utils
.
tuple_index_elements_type
(
indexes_types
,
const_utils
.
TENSOR_SETITEM
)
if
index_elements_type
==
const_utils
.
NO_TENSOR
:
return
_tensor_assgin_tensor
(
data
,
tuple_index
,
value
)
if
index_elements_type
==
const_utils
.
ALL_TENSOR
:
indices
=
compile_utils
.
generate_indices_from_tuple_of_tensor
(
data
,
tuple_index
,
const_utils
.
TENSOR_SETITEM
)
else
:
indices
=
compile_utils
.
generate_indices_from_tuple_of_mixed_tensors
(
data
,
tuple_index
,
const_utils
.
TENSOR_SETITEM
)
updates
=
compile_utils
.
generate_updates_from_tensor
(
data
,
indices
,
value
,
const_utils
.
SET_ITEM_BY_TUPLE_OF_TENSOR
)
return
F
.
scatter_nd_update
(
data
,
indices
,
updates
)
return
compile_utils
.
tensor_setitem_by_tuple_with_tensor
(
data
,
tuple_index
,
value
)
@
setitem
.
register
(
"Tensor"
,
"Tuple"
,
"Tuple"
)
...
...
@@ -268,22 +224,7 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
indexes_types
=
compile_utils
.
hyper_map
(
F
.
typeof
,
tuple_index
)
index_elements_type
=
const_utils
.
tuple_index_elements_type
(
indexes_types
,
const_utils
.
TENSOR_SETITEM
)
if
index_elements_type
==
const_utils
.
ALL_TENSOR
:
indices
=
compile_utils
.
generate_indices_from_tuple_of_tensor
(
data
,
tuple_index
,
const_utils
.
TENSOR_SETITEM
)
else
:
indices
=
compile_utils
.
generate_indices_from_tuple_of_mixed_tensors
(
data
,
tuple_index
,
const_utils
.
TENSOR_SETITEM
)
updates
=
compile_utils
.
generate_updates_from_tuple
(
data
,
indices
,
value
,
const_utils
.
SET_ITEM_BY_TUPLE_OF_TENSOR
)
return
F
.
scatter_nd_update
(
data
,
indices
,
updates
)
return
compile_utils
.
tensor_setitem_by_tuple_with_tuple
(
data
,
tuple_index
,
value
)
@
setitem
.
register
(
"Tensor"
,
"Tensor"
,
"Tuple"
)
...
...
@@ -299,12 +240,7 @@ def _tensor_setitem_by_tensor_v2(data, index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
index_dtype
=
F
.
dtype
(
index
)
check_dtype
=
const_utils
.
check_index_tensor_dtype
(
index_dtype
,
const_utils
.
TENSOR_SETITEM
)
result
=
None
if
check_dtype
:
result
=
_tensor_setitem_by_tensor_with_tuple
(
data
,
index
,
value
)
return
result
return
compile_utils
.
tensor_setitem_by_tensor_with_tuple
(
data
,
index
,
value
)
@
setitem
.
register
(
"Tensor"
,
"Slice"
,
"Tensor"
)
...
...
@@ -326,7 +262,7 @@ def _tensor_setitem_with_slice_v3(data, input_slice, value):
Outputs:
Tensor, element type and shape is same as data.
"""
return
_tensor_assgin
_tensor
(
data
,
input_slice
,
value
)
return
compile_utils
.
tensor_setitem_by_slice_with
_tensor
(
data
,
input_slice
,
value
)
@
setitem
.
register
(
"Tensor"
,
"Slice"
,
"Number"
)
...
...
@@ -348,168 +284,28 @@ def _tensor_setitem_with_slice_v1(data, input_slice, value):
Outputs:
Tensor, element type and shape is same as data.
"""
return
_tensor_assgin_number
(
data
,
input_slice
,
value
)
def
_tensor_assgin_number
(
data
,
input_slice
,
value
):
"""Givens a scalar assign to tensor by slice"""
check_result
=
const_utils
.
check_tensor_setitem_index
(
input_slice
)
result
=
None
if
check_result
:
data_shape
=
F
.
shape
(
data
)
indices
=
const_utils
.
slice2indices
(
input_slice
,
data_shape
)
is_tuple_int
=
const_utils
.
tuple_element_is_int
(
input_slice
)
if
is_tuple_int
:
indices
=
const_utils
.
integer_to_indices
(
input_slice
,
data_shape
)
result
=
_tensor_indices_number
(
data
,
data_shape
,
input_slice
,
indices
,
value
)
return
result
return
compile_utils
.
tensor_setitem_by_slice_with_number
(
data
,
input_slice
,
value
)
@
setitem
.
register
(
"Tensor"
,
"Number"
,
"Number"
)
def
_tensor_setitem_with_int_v1
(
data
,
index
,
value
):
"""Syntax: A[1] = 3"""
data_shape
=
F
.
shape
(
data
)
indices
=
const_utils
.
integer_to_indices
(
index
,
data_shape
)
return
_tensor_indices_number
(
data
,
data_shape
,
index
,
indices
,
value
)
return
compile_utils
.
tensor_setitem_by_number_with_number
(
data
,
index
,
value
)
@
setitem
.
register
(
"Tensor"
,
"Number"
,
"Tensor"
)
def
_tensor_setitem_with_int_v2
(
data
,
index
,
value
):
"""Syntax: A[1] = Tensor"""
data_shape
=
F
.
shape
(
data
)
indices
=
const_utils
.
integer_to_indices
(
index
,
data_shape
)
return
_tensor_indices_tensor
(
data
,
data_shape
,
index
,
indices
,
value
)
return
compile_utils
.
tensor_setitem_by_number_with_tensor
(
data
,
index
,
value
)
@
setitem
.
register
(
"Tensor"
,
"Ellipsis"
,
"Number"
)
def
_tensor_setitem_with_ellipsis_v1
(
data
,
index
,
value
):
"""Syntax: A[...] = number."""
data_shape
=
F
.
shape
(
data
)
data_dtype
=
F
.
dtype
(
data
)
return
F
.
fill
(
data_dtype
,
data_shape
,
value
)
return
compile_utils
.
tensor_setitem_by_ellipsis_with_number
(
data
,
index
,
value
)
@
setitem
.
register
(
"Tensor"
,
"Ellipsis"
,
"Tensor"
)
def
_tensor_setitem_with_ellipsis_v2
(
data
,
index
,
value
):
"""Syntax: A[...] = Tensor."""
result
=
None
data_shape
=
F
.
shape
(
data
)
data_dtype
=
F
.
dtype
(
data
)
data_size
=
F
.
size
(
data
)
value_shape
=
F
.
shape
(
value
)
value_size
=
F
.
size
(
value
)
check_result
=
const_utils
.
check_ellipsis_shape_size
(
data_shape
,
value_shape
,
data_size
,
value_size
)
if
check_result
:
if
data_size
==
value_size
:
result
=
F
.
reshape
(
value
,
data_shape
)
result
=
F
.
cast
(
result
,
data_dtype
)
elif
value_size
==
1
:
param1
=
F
.
fill
(
data_dtype
,
data_shape
,
1
)
param2
=
F
.
cast
(
value
,
data_dtype
)
result
=
F
.
tensor_mul
(
param1
,
param2
)
return
result
def
_tensor_assgin_tensor
(
data
,
input_slice
,
value
):
"""Assigns a tensor value to the tensor by slice."""
result
=
None
check_result
=
const_utils
.
check_tensor_setitem_index
(
input_slice
)
if
check_result
:
data_shape
=
F
.
shape
(
data
)
indices
=
const_utils
.
slice2indices
(
input_slice
,
data_shape
)
is_tuple_int
=
const_utils
.
tuple_element_is_int
(
input_slice
)
if
is_tuple_int
:
indices
=
const_utils
.
integer_to_indices
(
input_slice
,
data_shape
)
result
=
_tensor_indices_tensor
(
data
,
data_shape
,
input_slice
,
indices
,
value
)
return
result
def
_tensor_indices_tensor
(
data
,
data_shape
,
index
,
indices
,
value
):
"""Assigns a tensor value to the tensor."""
data_size
=
F
.
size
(
data
)
data_dtype
=
F
.
dtype
(
data
)
indices_size
=
F
.
size
(
indices
)
indices_size
=
const_utils
.
check_indices
(
indices_size
,
index
)
update
=
F
.
fill
(
mstype
.
int32
,
(
indices_size
,),
1
)
condition_1d
=
F
.
scatter_nd
(
indices
,
update
,
(
data_size
,))
condition
=
F
.
reshape
(
condition_1d
,
data_shape
)
condition
=
F
.
cast
(
condition
,
mstype
.
bool_
)
value_fill
=
None
value_size
=
F
.
size
(
value
)
value_size
=
const_utils
.
check_indices_value_size
(
indices_size
,
value_size
)
if
value_size
==
1
:
value_fill
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
1
)
value
=
F
.
cast
(
value
,
data_dtype
)
value_fill
=
F
.
tensor_mul
(
value_fill
,
value
)
elif
value_size
>
1
:
value_fill
=
F
.
reshape
(
value
,
(
indices_size
,))
value_1d
=
F
.
scatter_nd
(
indices
,
value_fill
,
(
data_size
,))
u
=
F
.
reshape
(
value_1d
,
data_shape
)
return
F
.
select
(
condition
,
u
,
data
)
def
_tensor_indices_number
(
data
,
data_shape
,
index
,
indices
,
value
):
"""Assigns a scalar value to the tensor."""
data_size
=
F
.
size
(
data
)
data_dtype
=
F
.
dtype
(
data
)
indices_size
=
F
.
size
(
indices
)
indices_size
=
const_utils
.
check_indices
(
indices_size
,
index
)
update
=
F
.
fill
(
mstype
.
int32
,
(
indices_size
,),
1
)
condition_1d
=
F
.
scatter_nd
(
indices
,
update
,
(
data_size
,))
condition
=
F
.
reshape
(
condition_1d
,
data_shape
)
condition
=
F
.
cast
(
condition
,
mstype
.
bool_
)
value_fill
=
F
.
fill
(
data_dtype
,
(
indices_size
,),
value
)
value_1d
=
F
.
scatter_nd
(
indices
,
value_fill
,
(
data_size
,))
u
=
F
.
reshape
(
value_1d
,
data_shape
)
return
F
.
select
(
condition
,
u
,
data
)
def
_tensor_setitem_by_tensor_with_tuple
(
data
,
index
,
value
):
"""Set a tensor item by a tensor with a tuple."""
updates
=
compile_utils
.
generate_updates_from_tuple
(
data
,
index
,
value
,
const_utils
.
SET_ITEM_BY_ONE_TENSOR
)
result
=
F
.
scatter_update
(
data
,
index
,
updates
)
return
result
def
_tensor_setitem_by_int_tensor_with_scalar
(
data
,
index
,
value
):
"""Set a tensor item by a int tensor with a scalar."""
updates
=
compile_utils
.
generate_updates_from_scalar
(
data
,
index
,
value
,
const_utils
.
SET_ITEM_BY_ONE_TENSOR
)
return
F
.
scatter_update
(
data
,
index
,
updates
)
def
_tensor_setitem_by_bool_tensor_with_scalar
(
data
,
index
,
value
):
"""Set a tensor item by a bool tensor with a scalar."""
index_shape
=
F
.
shape
(
index
)
shape
=
F
.
shape
(
data
)
shape
=
const_utils
.
check_equal
(
shape
,
index_shape
,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape."
)
dtype
=
F
.
dtype
(
data
)
u
=
F
.
fill
(
dtype
,
shape
,
value
)
return
F
.
select
(
index
,
u
,
data
)
def
_tensor_setitem_by_int_tensor_with_tensor
(
data
,
index
,
value
):
"""Set a tensor item by a int tensor with a tensor."""
updates
=
compile_utils
.
generate_updates_from_tensor
(
data
,
index
,
value
,
const_utils
.
SET_ITEM_BY_ONE_TENSOR
)
return
F
.
scatter_update
(
data
,
index
,
updates
)
def
_tensor_setitem_by_bool_tensor_with_tensor
(
data
,
index
,
value
):
"""Set a tensor item by a bool tensor with a tensor."""
index_shape
=
F
.
shape
(
index
)
data_shape
=
F
.
shape
(
data
)
data_shape
=
const_utils
.
check_equal
(
data_shape
,
index_shape
,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape."
)
size
=
F
.
size
(
value
)
size
=
const_utils
.
check_equal
(
1
,
size
,
"When assign value is a tensor, its size should be {}, but current size is {}."
)
dtype
=
F
.
dtype
(
data
)
u_cast
=
F
.
cast
(
value
,
dtype
)
one_data
=
F
.
ones_like
(
data
)
u
=
F
.
tensor_mul
(
one_data
,
u_cast
)
result
=
F
.
select
(
index
,
u
,
data
)
return
result
return
compile_utils
.
tensor_setitem_by_ellipsis_with_tensor
(
data
,
index
,
value
)
tests/st/pynative/test_tensor_index.py
浏览文件 @
79058d35
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录