Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
3ba31ec1
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3ba31ec1
编写于
4月 20, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 20, 2020
浏览文件
操作
浏览文件
下载
差异文件
!419 Tensor assign with bool Tensor
Merge pull request !419 from candanzg/tensor_assign_bool_index
上级
b554a868
3f087dba
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
343 addition
and
1 deletion
+343
-1
mindspore/_extends/parse/resources.py
mindspore/_extends/parse/resources.py
+1
-1
mindspore/ops/composite/multitype_ops/__init__.py
mindspore/ops/composite/multitype_ops/__init__.py
+2
-0
mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
+45
-0
mindspore/ops/composite/multitype_ops/setitem_impl.py
mindspore/ops/composite/multitype_ops/setitem_impl.py
+194
-0
mindspore/ops/functional.py
mindspore/ops/functional.py
+5
-0
tests/ut/python/ops/test_tensor_slice.py
tests/ut/python/ops/test_tensor_slice.py
+96
-0
未找到文件。
mindspore/_extends/parse/resources.py
浏览文件 @
3ba31ec1
...
@@ -83,6 +83,7 @@ convert_object_map = {
...
@@ -83,6 +83,7 @@ convert_object_map = {
T
.
mul
:
multitype_ops
.
mul
,
T
.
mul
:
multitype_ops
.
mul
,
T
.
truediv
:
multitype_ops
.
div
,
T
.
truediv
:
multitype_ops
.
div
,
T
.
getitem
:
multitype_ops
.
getitem
,
T
.
getitem
:
multitype_ops
.
getitem
,
T
.
setitem
:
multitype_ops
.
setitem
,
T
.
floordiv
:
multitype_ops
.
floordiv
,
T
.
floordiv
:
multitype_ops
.
floordiv
,
T
.
mod
:
multitype_ops
.
mod
,
T
.
mod
:
multitype_ops
.
mod
,
T
.
pow
:
multitype_ops
.
pow_
,
T
.
pow
:
multitype_ops
.
pow_
,
...
@@ -118,7 +119,6 @@ convert_object_map = {
...
@@ -118,7 +119,6 @@ convert_object_map = {
T
.
iter
:
M
.
ms_iter
,
T
.
iter
:
M
.
ms_iter
,
T
.
next
:
M
.
ms_next
,
T
.
next
:
M
.
ms_next
,
T
.
hasnext
:
M
.
hasnext
,
T
.
hasnext
:
M
.
hasnext
,
T
.
setitem
:
M
.
setitem
,
T
.
make_tuple
:
F
.
make_tuple
,
T
.
make_tuple
:
F
.
make_tuple
,
T
.
make_dict
:
F
.
make_dict
,
T
.
make_dict
:
F
.
make_dict
,
...
...
mindspore/ops/composite/multitype_ops/__init__.py
浏览文件 @
3ba31ec1
...
@@ -23,6 +23,7 @@ from .pow_impl import pow_
...
@@ -23,6 +23,7 @@ from .pow_impl import pow_
from
.floordiv_impl
import
floordiv
from
.floordiv_impl
import
floordiv
from
.mod_impl
import
mod
from
.mod_impl
import
mod
from
.getitem_impl
import
getitem
from
.getitem_impl
import
getitem
from
.setitem_impl
import
setitem
from
.zeros_like_impl
import
zeros_like
from
.zeros_like_impl
import
zeros_like
from
.ones_like_impl
import
ones_like
from
.ones_like_impl
import
ones_like
from
.equal_impl
import
equal
from
.equal_impl
import
equal
...
@@ -55,6 +56,7 @@ __all__ = [
...
@@ -55,6 +56,7 @@ __all__ = [
'greater_equal'
,
'greater_equal'
,
'negative'
,
'negative'
,
'getitem'
,
'getitem'
,
'setitem'
,
'logical_and'
,
'logical_and'
,
'logical_or'
,
'logical_or'
,
'logical_not'
'logical_not'
...
...
mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
0 → 100644
浏览文件 @
3ba31ec1
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""constexpr util"""
from
...primitive
import
constexpr
@
constexpr
def
is_same_type
(
inst
,
type_
):
"""
Check whether an object is an instance of a target type.
Inputs:
inst (mindspore.dtype): Inspected type.
type_ (mindspore.dtype): Target type.
Outputs:
bool, the check result.
"""
return
inst
==
type_
@
constexpr
def
error_msg
(
msg
=
""
,
format_values
=
""
):
"""
Used to throw exception information.
Inputs:
msg (str): information content.
"""
raise
ValueError
(
msg
.
format
(
*
format_values
))
mindspore/ops/composite/multitype_ops/setitem_impl.py
0 → 100644
浏览文件 @
3ba31ec1
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implementation for setitem."""
from
...composite
import
base
from
....common
import
dtype
as
mstype
from
...
import
functional
as
F
from
.
import
_multitype_ops_util
as
mult_util
setitem
=
base
.
MultitypeFuncGraph
(
'setitem'
)
@
setitem
.
register
(
"List"
,
"Number"
,
"String"
)
def
_list_setitem_with_string
(
data
,
number_index
,
value
):
"""
Assign value to list.
Inputs:
data (list): Data of type lis.
number_index (Number): Index of data.
value (String): Value given.
Outputs:
List, type is same as the element type of data.
"""
return
F
.
list_setitem
(
data
,
number_index
,
value
)
@
setitem
.
register
(
"List"
,
"Number"
,
"Number"
)
def
_list_setitem_with_number
(
data
,
number_index
,
value
):
"""
Assign value to list.
Inputs:
data (list): Data of type lis.
number_index (Number): Index of data.
value (Number): Value given.
Outputs:
List, type is same as the element type of data.
"""
return
F
.
list_setitem
(
data
,
number_index
,
value
)
@
setitem
.
register
(
"List"
,
"Number"
,
"Tensor"
)
def
_list_setitem_with_Tensor
(
data
,
number_index
,
value
):
"""
Assign value to list.
Inputs:
data (list): Data of type lis.
number_index (Number): Index of data.
value (Tensor): Value given.
Outputs:
List, type is same as the element type of data.
"""
return
F
.
list_setitem
(
data
,
number_index
,
value
)
@
setitem
.
register
(
"List"
,
"Number"
,
"List"
)
def
_list_setitem_with_List
(
data
,
number_index
,
value
):
"""
Assign value to list.
Inputs:
data (list): Data of type lis.
number_index (Number): Index of data.
value (List): Value given.
Outputs:
List, type is same as the element type of data.
"""
return
F
.
list_setitem
(
data
,
number_index
,
value
)
@
setitem
.
register
(
"Dictionary"
,
"String"
,
"Tensor"
)
def
_dict_setitem_with_tensor
(
data
,
key
,
value
):
"""
Assign value to dictionary.
Inputs:
data (Dictionary): Data of type dict.
key (str): Key of the data.
value (Tensor): Value given.
Outputs:
Dict, type is as same as the element type of data.
"""
return
F
.
dict_setitem
(
data
,
key
,
value
)
@
setitem
.
register
(
"Dictionary"
,
"String"
,
"Number"
)
def
_dict_setitem_with_number
(
data
,
key
,
value
):
"""
Assign value to dictionary.
Inputs:
data (Dictionary): Data of type dict.
key (str): Key of the data.
value (Number): Value given.
Outputs:
Dict, type is as same as the element type of data.
"""
return
F
.
dict_setitem
(
data
,
key
,
value
)
@
setitem
.
register
(
"Tensor"
,
"Tensor"
,
"Tensor"
)
def
_tensor_setitem_by_tensor_v1
(
data
,
index
,
value_tensor
):
"""
Tensor assignment.
Note:
Syntax support: A[B] = U and A[A>n] = U.
Restraint condition: 1) A, U is a Tensor, and B is a bool Tensor.
2) A.shape == B.shape
3) U.size == 1
4) n is a number
Inputs:
data (Tensor): Assigned tensor.
index (Tensor): Tensor of bool type.
value_tensor (Tensor): Tensor with size 1.
Outputs:
Tensor, element type and shape is same as data.
"""
index_dtype
=
F
.
dtype
(
index
)
index_shape
=
F
.
shape
(
index
)
is_bool
=
mult_util
.
is_same_type
(
index_dtype
,
mstype
.
bool_
)
if
not
is_bool
:
return
mult_util
.
error_msg
(
"The tensor index should be a bool type tensor. {} type tensor is not supported yet."
,
(
index_dtype
,))
data_shape
=
F
.
shape
(
data
)
if
index_shape
!=
data_shape
:
return
mult_util
.
error_msg
(
"The tensor(shape={}) and tensor index(shape={}) should be the same shape."
,
(
data_shape
,
index_shape
))
size
=
F
.
size
(
value_tensor
)
if
size
!=
1
:
return
mult_util
.
error_msg
(
"When assign value is a tensor, its size should be 1, but current size is {}."
,
(
size
,))
dtype
=
F
.
dtype
(
data
)
u_cast
=
F
.
cast
(
value_tensor
,
dtype
)
one_data
=
F
.
ones_like
(
data
)
u
=
F
.
tensor_mul
(
one_data
,
u_cast
)
return
F
.
select
(
index
,
u
,
data
)
@
setitem
.
register
(
"Tensor"
,
"Tensor"
,
"Number"
)
def
_tensor_setitem_by_tensor_v2
(
data
,
index
,
value
):
"""
Tensor assignment.
Note:
Syntax support: A[B] = u and A[A>n] = u.
Restraint condition: 1) A is a Tensor, and B is a bool Tensor.
2) A.shape == B.shape
3) u is a scalar
4) n is a number
Inputs:
data (Tensor): Assigned tensor.
index (Tensor): Tensor of bool type.
value_tensor (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
index_dtype
=
F
.
dtype
(
index
)
index_shape
=
F
.
shape
(
index
)
is_bool
=
mult_util
.
is_same_type
(
index_dtype
,
mstype
.
bool_
)
if
not
is_bool
:
return
mult_util
.
error_msg
(
"The tensor index should be a bool type tensor. {} type tensor is not supported yet."
,
(
index_dtype
,))
shape
=
F
.
shape
(
data
)
if
index_shape
!=
shape
:
return
mult_util
.
error_msg
(
"The tensor(shape={}) and tensor index(shape={}) should be the same shape."
,
(
shape
,
index_shape
))
dtype
=
F
.
dtype
(
data
)
u
=
F
.
fill
(
dtype
,
shape
,
value
)
return
F
.
select
(
index
,
u
,
data
)
mindspore/ops/functional.py
浏览文件 @
3ba31ec1
...
@@ -31,6 +31,9 @@ dtype = P.DType()
...
@@ -31,6 +31,9 @@ dtype = P.DType()
issubclass_
=
P
.
IsSubClass
()
issubclass_
=
P
.
IsSubClass
()
isinstance_
=
P
.
IsInstance
()
isinstance_
=
P
.
IsInstance
()
fill
=
P
.
Fill
()
fill
=
P
.
Fill
()
select
=
P
.
Select
()
size
=
P
.
Size
()
ones_like
=
P
.
OnesLike
()
shape
=
P
.
Shape
()
shape
=
P
.
Shape
()
rank
=
P
.
Rank
()
rank
=
P
.
Rank
()
reshape
=
P
.
Reshape
()
reshape
=
P
.
Reshape
()
...
@@ -68,7 +71,9 @@ scalar_cast = P.ScalarCast()
...
@@ -68,7 +71,9 @@ scalar_cast = P.ScalarCast()
tuple_setitem
=
Primitive
(
'tuple_setitem'
)
tuple_setitem
=
Primitive
(
'tuple_setitem'
)
tuple_getitem
=
Primitive
(
'tuple_getitem'
)
tuple_getitem
=
Primitive
(
'tuple_getitem'
)
list_getitem
=
Primitive
(
'list_getitem'
)
list_getitem
=
Primitive
(
'list_getitem'
)
list_setitem
=
Primitive
(
'list_setitem'
)
dict_getitem
=
Primitive
(
'dict_getitem'
)
dict_getitem
=
Primitive
(
'dict_getitem'
)
dict_setitem
=
Primitive
(
'dict_setitem'
)
tuple_div
=
Primitive
(
"tuple_div"
)
tuple_div
=
Primitive
(
"tuple_div"
)
tuple_len
=
Primitive
(
"tuple_len"
)
tuple_len
=
Primitive
(
"tuple_len"
)
tuple_reversed
=
Primitive
(
"tuple_reversed"
)
tuple_reversed
=
Primitive
(
"tuple_reversed"
)
...
...
tests/ut/python/ops/test_tensor_slice.py
浏览文件 @
3ba31ec1
...
@@ -18,6 +18,7 @@ import pytest
...
@@ -18,6 +18,7 @@ import pytest
from
mindspore
import
Tensor
from
mindspore
import
Tensor
from
mindspore
import
context
from
mindspore
import
context
from
mindspore
import
dtype
as
mstype
from
mindspore.nn
import
Cell
from
mindspore.nn
import
Cell
from
....mindspore_test_framework.mindspore_test
import
mindspore_test
from
....mindspore_test_framework.mindspore_test
import
mindspore_test
...
@@ -79,7 +80,102 @@ class NetWorkReduceToScalar(Cell):
...
@@ -79,7 +80,102 @@ class NetWorkReduceToScalar(Cell):
return
ret
return
ret
class
TensorAssignWithBoolTensorIndex
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithBoolTensorIndex
,
self
).
__init__
()
self
.
t
=
Tensor
(
np
.
arange
(
6
).
reshape
([
2
,
3
]),
dtype
=
mstype
.
float64
)
def
construct
(
self
,
a
,
b
,
c
,
u_tensor
,
_scalar
):
a
[
c
]
=
u_scalar
a
[
b
]
=
u_tensor
z
=
a
+
self
.
t
return
z
class
TensorAssignWithBoolTensorIndexError
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithBoolTensorIndexError
,
self
).
__init__
()
def
construct
(
self
,
a
,
b
,
c
,
u_tensor
):
a
[
b
][
c
]
=
u_tensor
return
a
class
TensorAssignWithBoolTensorIndex2
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithBoolTensorIndex2
,
self
).
__init__
()
self
.
t
=
Tensor
(
np
.
arange
(
6
).
reshape
([
2
,
3
]),
dtype
=
mstype
.
float64
)
def
construct
(
self
,
a
,
u_tensor
,
_scalar
):
a
[
a
>
8
]
=
u_tensor
a
[
a
>=
6
]
=
u_scalar
a
[
a
<
3
]
=
u_scalar
a
[
a
<=
5
]
=
u_tensor
a
[
a
==
5
]
=
u_scalar
z
=
a
+
self
.
t
return
z
class
TensorAssignWithBoolTensorIndex2Error
(
Cell
):
def
__init__
(
self
):
super
(
TensorAssignWithBoolTensorIndex2Error
,
self
).
__init__
()
def
construct
(
self
,
a
,
u_tensor
):
a
[
a
>
8
][
a
>
5
]
=
u_tensor
return
a
a
=
np
.
random
.
uniform
(
1
,
10
,[
2
,
3
])
b
=
a
>
5
c
=
a
<
3
Ta
=
Tensor
(
a
)
Tb
=
Tensor
(
b
)
Tc
=
Tensor
(
c
)
Td
=
Tensor
([
True
,
True
])
u_tensor
=
Tensor
([
1
])
u_tensor_error
=
Tensor
([
1
,
2
])
u_scalar
=
5
def
test_tensor_assign_bool_index
():
net1
=
TensorAssignWithBoolTensorIndex
()
net2
=
TensorAssignWithBoolTensorIndex2
()
net1
(
Ta
,
Tb
,
Tc
,
u_tensor
,
u_scalar
)
with
pytest
.
raises
(
ValueError
):
net1
(
Ta
,
Td
,
Tc
,
u_tensor
,
u_scalar
)
with
pytest
.
raises
(
ValueError
):
net1
(
Ta
,
u_tensor
,
Tc
,
u_tensor
,
u_scalar
)
with
pytest
.
raises
(
ValueError
):
net1
(
Ta
,
Tb
,
Td
,
u_tensor
,
u_scalar
)
with
pytest
.
raises
(
ValueError
):
net1
(
Ta
,
Tb
,
Ta
,
u_tensor
,
u_scalar
)
with
pytest
.
raises
(
ValueError
):
net1
(
Ta
,
Tb
,
Tc
,
u_tensor_error
,
u_scalar
)
#net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
with
pytest
.
raises
(
ValueError
):
net2
(
Ta
,
u_tensor_error
,
u_scalar
)
net3
=
TensorAssignWithBoolTensorIndexError
()
with
pytest
.
raises
(
AttributeError
):
net3
(
Ta
,
Tb
,
Tc
,
u_tensor
)
with
pytest
.
raises
(
AttributeError
):
net3
(
Ta
,
Tb
,
Tc
,
u_scalar
)
net4
=
TensorAssignWithBoolTensorIndex2Error
()
with
pytest
.
raises
(
AttributeError
):
net4
(
Ta
,
u_tensor
)
with
pytest
.
raises
(
AttributeError
):
net4
(
Ta
,
u_scalar
)
test_cases
=
[
test_cases
=
[
(
'TensorAssignWithBoolTensorIndex'
,
{
'block'
:
TensorAssignWithBoolTensorIndex
(),
'desc_inputs'
:
[
Ta
,
Tb
,
Tc
,
u_tensor
,
u_scalar
],
}),
(
'TensorAssignWithBoolTensorIndex2'
,
{
'block'
:
TensorAssignWithBoolTensorIndex2
(),
'desc_inputs'
:
[
Ta
,
u_tensor
,
u_scalar
],
}),
(
'SlicePositive'
,
{
(
'SlicePositive'
,
{
'block'
:
NetWorkSlicePositive
(),
'block'
:
NetWorkSlicePositive
(),
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
6
,
8
,
10
],
np
.
int32
))],
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
6
,
8
,
10
],
np
.
int32
))],
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录