# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""mIoU (mean Intersection-over-Union) metric."""
import numpy as np
from mindspore.nn.metrics.metric import Metric


def confuse_matrix(target, pred, n):
    """Build an ``n x n`` confusion matrix from flattened labels and predictions.

    Args:
        target (numpy.ndarray): Ground-truth class ids. Entries outside
            ``[0, n)`` (e.g. an ignore label) are excluded from the matrix.
        pred (numpy.ndarray): Predicted class ids, element-aligned with
            ``target``.
        n (int): Number of classes.

    Returns:
        numpy.ndarray: ``(n, n)`` integer matrix whose entry ``[i, j]`` counts
        elements with ground truth ``i`` predicted as ``j``.

    Raises:
        ValueError: If a prediction for a non-excluded element is outside
            ``[0, n)``.
    """
    valid = (target >= 0) & (target < n)
    target_valid = target[valid].astype(int)
    # Cast predictions too: np.bincount rejects float input.
    pred_valid = pred[valid].astype(int)
    # An out-of-range prediction would either be counted in the wrong cell of
    # the flattened n*n histogram or make it longer than n**2 (breaking the
    # reshape below), so reject it explicitly.
    if pred_valid.size and (pred_valid.min() < 0 or pred_valid.max() >= n):
        raise ValueError('Predicted class ids must be in [0, {}), but got values in [{}, {}]'.format(
            n, pred_valid.min(), pred_valid.max()))
    return np.bincount(n * target_valid + pred_valid, minlength=n ** 2).reshape(n, n)


def iou(hist):
    """Return the mean IoU over the classes present in a confusion matrix.

    A class counts as present when its union (row sum + column sum - diagonal)
    is non-zero. Absent classes contribute neither to the numerator nor to the
    denominator of the mean.

    Args:
        hist (numpy.ndarray): Square confusion matrix, as built by
            ``confuse_matrix``.

    Returns:
        float: Mean IoU of the present classes, or NaN if no class is present.
    """
    intersection = np.diag(hist)
    union = hist.sum(1) + hist.sum(0) - intersection
    present = np.count_nonzero(union)
    if present == 0:
        # Nothing to average; return NaN explicitly instead of dividing by 0.
        return np.nan
    per_class = intersection / np.where(union > 0, union, 1)
    return np.sum(per_class) / present


class MiouPrecision(Metric):
    """Calculate mIoU precision for semantic segmentation.

    Each ``update`` computes the mIoU of one batch. ``eval`` returns the mean of
    those per-batch values, ignoring NaN entries. The confusion matrix summed
    over all batches since the last ``clear`` is kept in ``self._hist``.

    Args:
        num_class (int): Number of classes. Default: 21.

    Raises:
        TypeError: If ``num_class`` is not an int (bool is rejected).
        ValueError: If ``num_class`` is less than 1.
    """

    def __init__(self, num_class=21):
        super(MiouPrecision, self).__init__()
        # bool is a subclass of int, but True/False are not meaningful class counts.
        if not isinstance(num_class, int) or isinstance(num_class, bool):
            raise TypeError('num_class should be integer type, but got {}'.format(type(num_class)))
        if num_class < 1:
            raise ValueError('num_class must be at least 1, but got {}'.format(num_class))
        self._num_class = num_class
        self.clear()

    def clear(self):
        """Reset the accumulated confusion matrix and the per-batch mIoU list."""
        self._hist = np.zeros((self._num_class, self._num_class))
        self._mIoU = []

    def update(self, *inputs):
        """Add one batch of predictions and labels.

        Args:
            inputs: ``(y_pred, y)``, the predicted class ids and the ground-truth
                labels. Both are passed through ``self._convert_data`` and
                flattened, so they must contain the same number of elements.

        Raises:
            ValueError: If the number of inputs is not 2, if the prediction and
                label sizes differ, or if a prediction is outside
                ``[0, num_class)``.
        """
        if len(inputs) != 2:
            raise ValueError('Need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
        pred = self._convert_data(inputs[0]).flatten()
        label = self._convert_data(inputs[1]).flatten()
        if label.size != pred.size:
            raise ValueError('Prediction and label must contain the same number of elements, '
                             'but got len(pred) = {:d}, len(gt) = {:d}'.format(pred.size, label.size))
        batch_hist = confuse_matrix(label, pred, self._num_class)
        # Accumulate the dataset-level confusion matrix. The reported metric
        # is still the mean of the per-batch mIoU values.
        self._hist += batch_hist
        self._mIoU.append(iou(batch_hist))

    def eval(self):
        """
        Computes the mIoU categorical accuracy.

        Returns:
            float: Mean of the per-batch mIoU values, ignoring NaN entries.
            NaN if ``update`` has not been called since the last ``clear``.
        """
        if not self._mIoU:
            # np.nanmean([]) would emit a "Mean of empty slice" warning.
            mIoU = np.nan
        else:
            mIoU = np.nanmean(self._mIoU)
        print('mIoU = {}'.format(mIoU))
        return mIoU