Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
97524b9d
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
97524b9d
编写于
6月 04, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 04, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1823 support vm for ConfusionMatrix
Merge pull request !1823 from jiangjinsheng/vm_ConfusionMatrix
上级
0fc2da9b
fc4cf5a4
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
128 addition
and
2 deletion
+128
-2
mindspore/ops/_op_impl/tbe/__init__.py
mindspore/ops/_op_impl/tbe/__init__.py
+1
-0
mindspore/ops/_op_impl/tbe/confusion_matrix.py
mindspore/ops/_op_impl/tbe/confusion_matrix.py
+63
-0
mindspore/ops/operations/__init__.py
mindspore/ops/operations/__init__.py
+4
-2
mindspore/ops/operations/other_ops.py
mindspore/ops/operations/other_ops.py
+47
-0
tests/ut/python/ops/test_array_ops.py
tests/ut/python/ops/test_array_ops.py
+13
-0
未找到文件。
mindspore/ops/_op_impl/tbe/__init__.py
浏览文件 @
97524b9d
...
...
@@ -237,3 +237,4 @@ from .basic_lstm_cell import _basic_lstm_cell_tbe
from
.basic_lstm_cell_c_state_grad
import
_basic_lstm_cell_c_state_grad_tbe
from
.basic_lstm_cell_weight_grad
import
_basic_lstm_cell_weight_grad_tbe
from
.basic_lstm_cell_input_grad
import
_basic_lstm_cell_input_grad_tbe
from
.confusion_matrix
import
_confusion_matrix_tbe
mindspore/ops/_op_impl/tbe/confusion_matrix.py
0 → 100644
浏览文件 @
97524b9d
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ConfusionMatrix op"""
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
confusion_matrix_op_info
=
TBERegOp
(
"ConfusionMatrix"
)
\
.
fusion_type
(
"OPAQUE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"confusion_matrix.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"confusion_matrix"
)
\
.
partial_flag
(
True
)
\
.
attr
(
"num_classes"
,
"required"
,
"int"
,
"all"
)
\
.
attr
(
"dtype"
,
"required"
,
"str"
,
"all"
)
\
.
input
(
0
,
"labels"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"predictions"
,
False
,
"required"
,
"all"
)
\
.
input
(
2
,
"weights"
,
False
,
"optional"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
I8_Default
,
DataType
.
I8_Default
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
U8_Default
,
DataType
.
U8_Default
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
I8_Default
,
DataType
.
I8_Default
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
U8_Default
,
DataType
.
U8_Default
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I8_Default
,
DataType
.
I8_Default
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
U8_Default
,
DataType
.
U8_Default
)
\
.
dtype_format
(
DataType
.
I8_Default
,
DataType
.
I8_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
I8_Default
,
DataType
.
I8_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
I8_Default
,
DataType
.
I8_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
I8_Default
,
DataType
.
I8_Default
,
DataType
.
I8_Default
,
DataType
.
I8_Default
)
\
.
dtype_format
(
DataType
.
I8_Default
,
DataType
.
I8_Default
,
DataType
.
U8_Default
,
DataType
.
U8_Default
)
\
.
dtype_format
(
DataType
.
U8_Default
,
DataType
.
U8_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
U8_Default
,
DataType
.
U8_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
U8_Default
,
DataType
.
U8_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
U8_Default
,
DataType
.
U8_Default
,
DataType
.
I8_Default
,
DataType
.
I8_Default
)
\
.
dtype_format
(
DataType
.
U8_Default
,
DataType
.
U8_Default
,
DataType
.
U8_Default
,
DataType
.
U8_Default
)
\
.
get_op_info
()
@
op_info_register
(
confusion_matrix_op_info
)
def
_confusion_matrix_tbe
():
"""ConfusionMatrix TBE register"""
return
mindspore/ops/operations/__init__.py
浏览文件 @
97524b9d
...
...
@@ -73,7 +73,8 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
TopK
,
BinaryCrossEntropy
,
SparseApplyAdagrad
,
LARSUpdate
,
ApplyFtrl
,
SparseApplyFtrl
,
ApplyProximalAdagrad
,
SparseApplyProximalAdagrad
,
ApplyRMSProp
,
ApplyCenteredRMSProp
,
BasicLSTMCell
)
from
.other_ops
import
Assign
,
IOU
,
BoundingBoxDecode
,
BoundingBoxEncode
,
CheckValid
,
MakeRefKey
,
CheckBprop
from
.other_ops
import
(
Assign
,
IOU
,
BoundingBoxDecode
,
BoundingBoxEncode
,
CheckValid
,
MakeRefKey
,
CheckBprop
,
ConfusionMatrix
)
from
.
import
_quant_ops
from
._quant_ops
import
*
from
.thor_ops
import
*
...
...
@@ -287,7 +288,8 @@ __all__ = [
"BesselI1e"
,
"Atan"
,
"Atanh"
,
"BasicLSTMCell"
"BasicLSTMCell"
,
"ConfusionMatrix"
]
__all__
.
extend
(
_quant_ops
.
__all__
)
...
...
mindspore/ops/operations/other_ops.py
浏览文件 @
97524b9d
...
...
@@ -366,3 +366,50 @@ class CheckBprop(PrimitiveWithInfer):
raise
TypeError
(
f
"
{
tips
}
, the dtype of
{
i
}
th output should be
{
ydtype
}
,"
f
" but got
{
xdtype
}
."
)
return
xdtypes
class
ConfusionMatrix
(
PrimitiveWithInfer
):
r
"""
Calculate the confusion matrix from labels and predictions.
Args:
num_classes (int): The num of classes.
dtype (str): Data type of confusion matrix. Default: 'int32'.
Inputs:
- **labels** (Tensor) - real labels, tensor of 1-D. the dtype must be non-negative Integer.
- **predictions** (Tensor) - the labels from prediction, tensor of 1-D.
the shape same as `labels` and the dtype must be non-negative Integer.
- **weights** (Tensor) - tensor of 1-D. the shape same as `predictions`.
Outputs:
Tensor, the confusion matrix, with shape (`num_classes`, `num_classes`).
Examples:
>>> confusion_matrix = P.ConfusionMatrix(4)
>>> labels = Tensor([0, 1, 1, 3], mindspore.int32)
>>> predictions = Tensor([1, 2, 1, 3], mindspore.int32)
>>> confusion_matrix(labels, predictions)
"""
@
prim_attr_register
def
__init__
(
self
,
num_classes
,
dtype
=
"int32"
):
validator
.
check_value_type
(
"num_classes"
,
num_classes
,
[
int
],
self
.
name
)
validator
.
check_value_type
(
"dtype"
,
dtype
,
[
str
],
self
.
name
)
def
infer_shape
(
self
,
labels
,
predictions
,
weights
=
None
):
validator
.
check
(
'labels dimension'
,
len
(
labels
),
''
,
1
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
'labels shape'
,
labels
,
'predictions shape'
,
predictions
,
Rel
.
EQ
,
self
.
name
)
if
weights
is
not
None
:
validator
.
check
(
'labels shape'
,
labels
,
'weights shape'
,
weights
,
Rel
.
EQ
,
self
.
name
)
ret
=
(
self
.
num_classes
,
self
.
num_classes
)
return
ret
def
infer_dtype
(
self
,
labels
,
predictions
,
weights
=
None
):
validator
.
check_subclass
(
'labels'
,
labels
,
mstype
.
tensor
,
self
.
name
)
validator
.
check_subclass
(
'predictions'
,
predictions
,
mstype
.
tensor
,
self
.
name
)
if
weights
is
not
None
:
validator
.
check_subclass
(
'weights'
,
weights
,
mstype
.
tensor
,
self
.
name
)
args
=
{
"labels"
:
labels
,
"predictions"
:
predictions
}
validator
.
check_tensor_type_same
(
args
,
(
mstype
.
number_type
),
self
.
name
)
return
labels
tests/ut/python/ops/test_array_ops.py
浏览文件 @
97524b9d
...
...
@@ -285,6 +285,16 @@ class SpaceToBatchNDNet(Cell):
def
construct
(
self
,
x
):
return
self
.
space_to_batch_nd
(
x
)
class
ConfusionMatrixNet
(
Cell
):
def
__init__
(
self
):
super
(
ConfusionMatrixNet
,
self
).
__init__
()
self
.
confusion_matrix
=
P
.
ConfusionMatrix
(
4
,
"int32"
)
def
construct
(
self
,
x
,
y
):
return
self
.
confusion_matrix
(
x
,
y
)
test_case_array_ops
=
[
(
'CustNet1'
,
{
'block'
:
CustNet1
(),
...
...
@@ -325,6 +335,9 @@ test_case_array_ops = [
(
'BatchToSpaceNDNet'
,
{
'block'
:
BatchToSpaceNDNet
(),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
rand
(
4
,
1
,
1
,
1
).
astype
(
np
.
float16
))]}),
(
'ConfusionMatrixNet'
,
{
'block'
:
ConfusionMatrixNet
(),
'desc_inputs'
:
[
Tensor
([
0
,
1
,
1
,
3
],
ms
.
int32
),
Tensor
([
0
,
1
,
1
,
3
],
ms
.
int32
)]}),
]
test_case_lists
=
[
test_case_array_ops
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录