提交 97524b9d 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1823 support vm for ConfusionMatrix

Merge pull request !1823 from jiangjinsheng/vm_ConfusionMatrix
...@@ -237,3 +237,4 @@ from .basic_lstm_cell import _basic_lstm_cell_tbe ...@@ -237,3 +237,4 @@ from .basic_lstm_cell import _basic_lstm_cell_tbe
from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe
from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe
from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe
from .confusion_matrix import _confusion_matrix_tbe
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ConfusionMatrix op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
confusion_matrix_op_info = TBERegOp("ConfusionMatrix") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("confusion_matrix.so") \
.compute_cost(10) \
.kernel_name("confusion_matrix") \
.partial_flag(True) \
.attr("num_classes", "required", "int", "all") \
.attr("dtype", "required", "str", "all") \
.input(0, "labels", False, "required", "all") \
.input(1, "predictions", False, "required", "all") \
.input(2, "weights", False, "optional", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
.get_op_info()
@op_info_register(confusion_matrix_op_info)
def _confusion_matrix_tbe():
"""ConfusionMatrix TBE register"""
return
...@@ -73,7 +73,8 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, ...@@ -73,7 +73,8 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl, TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
ApplyProximalAdagrad, SparseApplyProximalAdagrad, ApplyProximalAdagrad, SparseApplyProximalAdagrad,
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell) ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell)
from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey, CheckBprop from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
CheckValid, MakeRefKey, CheckBprop, ConfusionMatrix)
from . import _quant_ops from . import _quant_ops
from ._quant_ops import * from ._quant_ops import *
from .thor_ops import * from .thor_ops import *
...@@ -287,7 +288,8 @@ __all__ = [ ...@@ -287,7 +288,8 @@ __all__ = [
"BesselI1e", "BesselI1e",
"Atan", "Atan",
"Atanh", "Atanh",
"BasicLSTMCell" "BasicLSTMCell",
"ConfusionMatrix"
] ]
__all__.extend(_quant_ops.__all__) __all__.extend(_quant_ops.__all__)
......
...@@ -366,3 +366,50 @@ class CheckBprop(PrimitiveWithInfer): ...@@ -366,3 +366,50 @@ class CheckBprop(PrimitiveWithInfer):
raise TypeError(f"{tips}, the dtype of {i}th output should be {ydtype}," raise TypeError(f"{tips}, the dtype of {i}th output should be {ydtype},"
f" but got {xdtype}.") f" but got {xdtype}.")
return xdtypes return xdtypes
class ConfusionMatrix(PrimitiveWithInfer):
r"""
Calculate the confusion matrix from labels and predictions.
Args:
num_classes (int): The num of classes.
dtype (str): Data type of confusion matrix. Default: 'int32'.
Inputs:
- **labels** (Tensor) - real labels, tensor of 1-D. the dtype must be non-negative Integer.
- **predictions** (Tensor) - the labels from prediction, tensor of 1-D.
the shape same as `labels` and the dtype must be non-negative Integer.
- **weights** (Tensor) - tensor of 1-D. the shape same as `predictions`.
Outputs:
Tensor, the confusion matrix, with shape (`num_classes`, `num_classes`).
Examples:
>>> confusion_matrix = P.ConfusionMatrix(4)
>>> labels = Tensor([0, 1, 1, 3], mindspore.int32)
>>> predictions = Tensor([1, 2, 1, 3], mindspore.int32)
>>> confusion_matrix(labels, predictions)
"""
@prim_attr_register
def __init__(self, num_classes, dtype="int32"):
validator.check_value_type("num_classes", num_classes, [int], self.name)
validator.check_value_type("dtype", dtype, [str], self.name)
def infer_shape(self, labels, predictions, weights=None):
validator.check('labels dimension', len(labels), '', 1, Rel.EQ, self.name)
validator.check('labels shape', labels, 'predictions shape', predictions, Rel.EQ, self.name)
if weights is not None:
validator.check('labels shape', labels, 'weights shape', weights, Rel.EQ, self.name)
ret = (self.num_classes, self.num_classes)
return ret
def infer_dtype(self, labels, predictions, weights=None):
validator.check_subclass('labels', labels, mstype.tensor, self.name)
validator.check_subclass('predictions', predictions, mstype.tensor, self.name)
if weights is not None:
validator.check_subclass('weights', weights, mstype.tensor, self.name)
args = {"labels": labels, "predictions": predictions}
validator.check_tensor_type_same(args, (mstype.number_type), self.name)
return labels
...@@ -285,6 +285,16 @@ class SpaceToBatchNDNet(Cell): ...@@ -285,6 +285,16 @@ class SpaceToBatchNDNet(Cell):
def construct(self, x): def construct(self, x):
return self.space_to_batch_nd(x) return self.space_to_batch_nd(x)
class ConfusionMatrixNet(Cell):
def __init__(self):
super(ConfusionMatrixNet, self).__init__()
self.confusion_matrix = P.ConfusionMatrix(4, "int32")
def construct(self, x, y):
return self.confusion_matrix(x, y)
test_case_array_ops = [ test_case_array_ops = [
('CustNet1', { ('CustNet1', {
'block': CustNet1(), 'block': CustNet1(),
...@@ -325,6 +335,9 @@ test_case_array_ops = [ ...@@ -325,6 +335,9 @@ test_case_array_ops = [
('BatchToSpaceNDNet', { ('BatchToSpaceNDNet', {
'block': BatchToSpaceNDNet(), 'block': BatchToSpaceNDNet(),
'desc_inputs': [Tensor(np.random.rand(4, 1, 1, 1).astype(np.float16))]}), 'desc_inputs': [Tensor(np.random.rand(4, 1, 1, 1).astype(np.float16))]}),
('ConfusionMatrixNet', {
'block': ConfusionMatrixNet(),
'desc_inputs': [Tensor([0, 1, 1, 3], ms.int32), Tensor([0, 1, 1, 3], ms.int32)]}),
] ]
test_case_lists = [test_case_array_ops] test_case_lists = [test_case_array_ops]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册