Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
3ba31ec1
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3ba31ec1
编写于
4月 20, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 20, 2020
浏览文件
操作
浏览文件
下载
差异文件
!419 Tensor assign with bool Tensor
Merge pull request !419 from candanzg/tensor_assign_bool_index
上级
b554a868
3f087dba
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
343 addition
and
1 deletion
+343
-1
mindspore/_extends/parse/resources.py
mindspore/_extends/parse/resources.py
+1
-1
mindspore/ops/composite/multitype_ops/__init__.py
mindspore/ops/composite/multitype_ops/__init__.py
+2
-0
mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
+45
-0
mindspore/ops/composite/multitype_ops/setitem_impl.py
mindspore/ops/composite/multitype_ops/setitem_impl.py
+194
-0
mindspore/ops/functional.py
mindspore/ops/functional.py
+5
-0
tests/ut/python/ops/test_tensor_slice.py
tests/ut/python/ops/test_tensor_slice.py
+96
-0
未找到文件。
mindspore/_extends/parse/resources.py
浏览文件 @
3ba31ec1
...
...
@@ -83,6 +83,7 @@ convert_object_map = {
T
.
mul
:
multitype_ops
.
mul
,
T
.
truediv
:
multitype_ops
.
div
,
T
.
getitem
:
multitype_ops
.
getitem
,
T
.
setitem
:
multitype_ops
.
setitem
,
T
.
floordiv
:
multitype_ops
.
floordiv
,
T
.
mod
:
multitype_ops
.
mod
,
T
.
pow
:
multitype_ops
.
pow_
,
...
...
@@ -118,7 +119,6 @@ convert_object_map = {
T
.
iter
:
M
.
ms_iter
,
T
.
next
:
M
.
ms_next
,
T
.
hasnext
:
M
.
hasnext
,
T
.
setitem
:
M
.
setitem
,
T
.
make_tuple
:
F
.
make_tuple
,
T
.
make_dict
:
F
.
make_dict
,
...
...
mindspore/ops/composite/multitype_ops/__init__.py
浏览文件 @
3ba31ec1
...
...
@@ -23,6 +23,7 @@ from .pow_impl import pow_
from
.floordiv_impl
import
floordiv
from
.mod_impl
import
mod
from
.getitem_impl
import
getitem
from
.setitem_impl
import
setitem
from
.zeros_like_impl
import
zeros_like
from
.ones_like_impl
import
ones_like
from
.equal_impl
import
equal
...
...
@@ -55,6 +56,7 @@ __all__ = [
'greater_equal'
,
'negative'
,
'getitem'
,
'setitem'
,
'logical_and'
,
'logical_or'
,
'logical_not'
...
...
mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
0 → 100644
浏览文件 @
3ba31ec1
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""constexpr util"""
from
...primitive
import
constexpr
@
constexpr
def
is_same_type
(
inst
,
type_
):
"""
Check whether an object is an instance of a target type.
Inputs:
inst (mindspore.dtype): Inspected type.
type_ (mindspore.dtype): Target type.
Outputs:
bool, the check result.
"""
return
inst
==
type_
@
constexpr
def
error_msg
(
msg
=
""
,
format_values
=
""
):
"""
Used to throw exception information.
Inputs:
msg (str): information content.
"""
raise
ValueError
(
msg
.
format
(
*
format_values
))
mindspore/ops/composite/multitype_ops/setitem_impl.py
0 → 100644
浏览文件 @
3ba31ec1
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implementation for setitem."""
from
...composite
import
base
from
....common
import
dtype
as
mstype
from
...
import
functional
as
F
from
.
import
_multitype_ops_util
as
mult_util
setitem
=
base
.
MultitypeFuncGraph
(
'setitem'
)
@
setitem
.
register
(
"List"
,
"Number"
,
"String"
)
def
_list_setitem_with_string
(
data
,
number_index
,
value
):
"""
Assign value to list.
Inputs:
data (list): Data of type lis.
number_index (Number): Index of data.
value (String): Value given.
Outputs:
List, type is same as the element type of data.
"""
return
F
.
list_setitem
(
data
,
number_index
,
value
)
@
setitem
.
register
(
"List"
,
"Number"
,
"Number"
)
def
_list_setitem_with_number
(
data
,
number_index
,
value
):
"""
Assign value to list.
Inputs:
data (list): Data of type lis.
number_index (Number): Index of data.
value (Number): Value given.
Outputs:
List, type is same as the element type of data.
"""
return
F
.
list_setitem
(
data
,
number_index
,
value
)
@
setitem
.
register
(
"List"
,
"Number"
,
"Tensor"
)
def
_list_setitem_with_Tensor
(
data
,
number_index
,
value
):
"""
Assign value to list.
Inputs:
data (list): Data of type lis.
number_index (Number): Index of data.
value (Tensor): Value given.
Outputs:
List, type is same as the element type of data.
"""
return
F
.
list_setitem
(
data
,
number_index
,
value
)
@
setitem
.
register
(
"List"
,
"Number"
,
"List"
)
def
_list_setitem_with_List
(
data
,
number_index
,
value
):
"""
Assign value to list.
Inputs:
data (list): Data of type lis.
number_index (Number): Index of data.
value (List): Value given.
Outputs:
List, type is same as the element type of data.
"""
return
F
.
list_setitem
(
data
,
number_index
,
value
)
@
setitem
.
register
(
"Dictionary"
,
"String"
,
"Tensor"
)
def
_dict_setitem_with_tensor
(
data
,
key
,
value
):
"""
Assign value to dictionary.
Inputs:
data (Dictionary): Data of type dict.
key (str): Key of the data.
value (Tensor): Value given.
Outputs:
Dict, type is as same as the element type of data.
"""
return
F
.
dict_setitem
(
data
,
key
,
value
)
@
setitem
.
register
(
"Dictionary"
,
"String"
,
"Number"
)
def
_dict_setitem_with_number
(
data
,
key
,
value
):
"""
Assign value to dictionary.
Inputs:
data (Dictionary): Data of type dict.
key (str): Key of the data.
value (Number): Value given.
Outputs:
Dict, type is as same as the element type of data.
"""
return
F
.
dict_setitem
(
data
,
key
,
value
)
@
setitem
.
register
(
"Tensor"
,
"Tensor"
,
"Tensor"
)
def
_tensor_setitem_by_tensor_v1
(
data
,
index
,
value_tensor
):
"""
Tensor assignment.
Note:
Syntax support: A[B] = U and A[A>n] = U.
Restraint condition: 1) A, U is a Tensor, and B is a bool Tensor.
2) A.shape == B.shape
3) U.size == 1
4) n is a number
Inputs:
data (Tensor): Assigned tensor.
index (Tensor): Tensor of bool type.
value_tensor (Tensor): Tensor with size 1.
Outputs:
Tensor, element type and shape is same as data.
"""
index_dtype
=
F
.
dtype
(
index
)
index_shape
=
F
.
shape
(
index
)
is_bool
=
mult_util
.
is_same_type
(
index_dtype
,
mstype
.
bool_
)
if
not
is_bool
:
return
mult_util
.
error_msg
(
"The tensor index should be a bool type tensor. {} type tensor is not supported yet."
,
(
index_dtype
,))
data_shape
=
F
.
shape
(
data
)
if
index_shape
!=
data_shape
:
return
mult_util
.
error_msg
(
"The tensor(shape={}) and tensor index(shape={}) should be the same shape."
,
(
data_shape
,
index_shape
))
size
=
F
.
size
(
value_tensor
)
if
size
!=
1
:
return
mult_util
.
error_msg
(
"When assign value is a tensor, its size should be 1, but current size is {}."
,
(
size
,))
dtype
=
F
.
dtype
(
data
)
u_cast
=
F
.
cast
(
value_tensor
,
dtype
)
one_data
=
F
.
ones_like
(
data
)
u
=
F
.
tensor_mul
(
one_data
,
u_cast
)
return
F
.
select
(
index
,
u
,
data
)
@
setitem
.
register
(
"Tensor"
,
"Tensor"
,
"Number"
)
def
_tensor_setitem_by_tensor_v2
(
data
,
index
,
value
):
"""
Tensor assignment.
Note:
Syntax support: A[B] = u and A[A>n] = u.
Restraint condition: 1) A is a Tensor, and B is a bool Tensor.
2) A.shape == B.shape
3) u is a scalar
4) n is a number
Inputs:
data (Tensor): Assigned tensor.
index (Tensor): Tensor of bool type.
value_tensor (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
index_dtype
=
F
.
dtype
(
index
)
index_shape
=
F
.
shape
(
index
)
is_bool
=
mult_util
.
is_same_type
(
index_dtype
,
mstype
.
bool_
)
if
not
is_bool
:
return
mult_util
.
error_msg
(
"The tensor index should be a bool type tensor. {} type tensor is not supported yet."
,
(
index_dtype
,))
shape
=
F
.
shape
(
data
)
if
index_shape
!=
shape
:
return
mult_util
.
error_msg
(
"The tensor(shape={}) and tensor index(shape={}) should be the same shape."
,
(
shape
,
index_shape
))
dtype
=
F
.
dtype
(
data
)
u
=
F
.
fill
(
dtype
,
shape
,
value
)
return
F
.
select
(
index
,
u
,
data
)
mindspore/ops/functional.py
浏览文件 @
3ba31ec1
...
...
@@ -31,6 +31,9 @@ dtype = P.DType()
issubclass_
=
P
.
IsSubClass
()
isinstance_
=
P
.
IsInstance
()
fill
=
P
.
Fill
()
select
=
P
.
Select
()
size
=
P
.
Size
()
ones_like
=
P
.
OnesLike
()
shape
=
P
.
Shape
()
rank
=
P
.
Rank
()
reshape
=
P
.
Reshape
()
...
...
@@ -68,7 +71,9 @@ scalar_cast = P.ScalarCast()
tuple_setitem
=
Primitive
(
'tuple_setitem'
)
tuple_getitem
=
Primitive
(
'tuple_getitem'
)
list_getitem
=
Primitive
(
'list_getitem'
)
list_setitem
=
Primitive
(
'list_setitem'
)
dict_getitem
=
Primitive
(
'dict_getitem'
)
dict_setitem
=
Primitive
(
'dict_setitem'
)
tuple_div
=
Primitive
(
"tuple_div"
)
tuple_len
=
Primitive
(
"tuple_len"
)
tuple_reversed
=
Primitive
(
"tuple_reversed"
)
...
...
tests/ut/python/ops/test_tensor_slice.py
浏览文件 @
3ba31ec1
...
...
@@ -18,6 +18,7 @@ import pytest
from
mindspore
import
Tensor
from
mindspore
import
context
from
mindspore
import
dtype
as
mstype
from
mindspore.nn
import
Cell
from
....mindspore_test_framework.mindspore_test
import
mindspore_test
...
...
@@ -79,7 +80,102 @@ class NetWorkReduceToScalar(Cell):
return
ret
class
TensorAssignWithBoolTensorIndex
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithBoolTensorIndex
,
self
).
__init__
()
self
.
t
=
Tensor
(
np
.
arange
(
6
).
reshape
([
2
,
3
]),
dtype
=
mstype
.
float64
)
def
construct
(
self
,
a
,
b
,
c
,
u_tensor
,
_scalar
):
a
[
c
]
=
u_scalar
a
[
b
]
=
u_tensor
z
=
a
+
self
.
t
return
z
class
TensorAssignWithBoolTensorIndexError
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithBoolTensorIndexError
,
self
).
__init__
()
def
construct
(
self
,
a
,
b
,
c
,
u_tensor
):
a
[
b
][
c
]
=
u_tensor
return
a
class
TensorAssignWithBoolTensorIndex2
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithBoolTensorIndex2
,
self
).
__init__
()
self
.
t
=
Tensor
(
np
.
arange
(
6
).
reshape
([
2
,
3
]),
dtype
=
mstype
.
float64
)
def
construct
(
self
,
a
,
u_tensor
,
_scalar
):
a
[
a
>
8
]
=
u_tensor
a
[
a
>=
6
]
=
u_scalar
a
[
a
<
3
]
=
u_scalar
a
[
a
<=
5
]
=
u_tensor
a
[
a
==
5
]
=
u_scalar
z
=
a
+
self
.
t
return
z
class
TensorAssignWithBoolTensorIndex2Error
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithBoolTensorIndex2Error
,
self
).
__init__
()
def
construct
(
self
,
a
,
u_tensor
):
a
[
a
>
8
][
a
>
5
]
=
u_tensor
return
a
a
=
np
.
random
.
uniform
(
1
,
10
,[
2
,
3
])
b
=
a
>
5
c
=
a
<
3
Ta
=
Tensor
(
a
)
Tb
=
Tensor
(
b
)
Tc
=
Tensor
(
c
)
Td
=
Tensor
([
True
,
True
])
u_tensor
=
Tensor
([
1
])
u_tensor_error
=
Tensor
([
1
,
2
])
u_scalar
=
5
def
test_tensor_assign_bool_index
():
net1
=
TensorAssignWithBoolTensorIndex
()
net2
=
TensorAssignWithBoolTensorIndex2
()
net1
(
Ta
,
Tb
,
Tc
,
u_tensor
,
u_scalar
)
with
pytest
.
raises
(
ValueError
):
net1
(
Ta
,
Td
,
Tc
,
u_tensor
,
u_scalar
)
with
pytest
.
raises
(
ValueError
):
net1
(
Ta
,
u_tensor
,
Tc
,
u_tensor
,
u_scalar
)
with
pytest
.
raises
(
ValueError
):
net1
(
Ta
,
Tb
,
Td
,
u_tensor
,
u_scalar
)
with
pytest
.
raises
(
ValueError
):
net1
(
Ta
,
Tb
,
Ta
,
u_tensor
,
u_scalar
)
with
pytest
.
raises
(
ValueError
):
net1
(
Ta
,
Tb
,
Tc
,
u_tensor_error
,
u_scalar
)
#net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
with
pytest
.
raises
(
ValueError
):
net2
(
Ta
,
u_tensor_error
,
u_scalar
)
net3
=
TensorAssignWithBoolTensorIndexError
()
with
pytest
.
raises
(
AttributeError
):
net3
(
Ta
,
Tb
,
Tc
,
u_tensor
)
with
pytest
.
raises
(
AttributeError
):
net3
(
Ta
,
Tb
,
Tc
,
u_scalar
)
net4
=
TensorAssignWithBoolTensorIndex2Error
()
with
pytest
.
raises
(
AttributeError
):
net4
(
Ta
,
u_tensor
)
with
pytest
.
raises
(
AttributeError
):
net4
(
Ta
,
u_scalar
)
test_cases
=
[
(
'TensorAssignWithBoolTensorIndex'
,
{
'block'
:
TensorAssignWithBoolTensorIndex
(),
'desc_inputs'
:
[
Ta
,
Tb
,
Tc
,
u_tensor
,
u_scalar
],
}),
(
'TensorAssignWithBoolTensorIndex2'
,
{
'block'
:
TensorAssignWithBoolTensorIndex2
(),
'desc_inputs'
:
[
Ta
,
u_tensor
,
u_scalar
],
}),
(
'SlicePositive'
,
{
'block'
:
NetWorkSlicePositive
(),
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
6
,
8
,
10
],
np
.
int32
))],
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录