提交 3fa39034 编写于 作者: C chenguowei01

update confusion matrix calculation

上级 e9a2881d
...@@ -49,7 +49,6 @@ def evaluate(model, ...@@ -49,7 +49,6 @@ def evaluate(model,
for iter, (im, im_info, label) in tqdm.tqdm( for iter, (im, im_info, label) in tqdm.tqdm(
enumerate(eval_dataset), total=total_iters): enumerate(eval_dataset), total=total_iters):
im = to_variable(im) im = to_variable(im)
# pred, _ = model(im)
logits = model(im) logits = model(im)
pred = paddle.argmax(logits[0], axis=1) pred = paddle.argmax(logits[0], axis=1)
pred = pred.numpy().astype('float32') pred = pred.numpy().astype('float32')
...@@ -68,7 +67,7 @@ def evaluate(model, ...@@ -68,7 +67,7 @@ def evaluate(model,
pred = pred.astype('int64') pred = pred.astype('int64')
mask = label != ignore_index mask = label != ignore_index
# To-DO Test Execution Time # To-DO Test Execution Time
conf_mat.calculate(pred=pred, label=label, ignore=mask) conf_mat.calculate(pred=pred, label=label, mask=mask)
_, iou = conf_mat.mean_iou() _, iou = conf_mat.mean_iou()
time_iter = timer.elapsed_time() time_iter = timer.elapsed_time()
......
...@@ -29,18 +29,32 @@ class ConfusionMatrix(object): ...@@ -29,18 +29,32 @@ class ConfusionMatrix(object):
self.num_classes = num_classes self.num_classes = num_classes
self.streaming = streaming self.streaming = streaming
def calculate(self, pred, label, ignore=None): def calculate(self, pred, label, mask):
"""
Calculate confusion matrix
Args:
pred (np.ndarray): The prediction of input image by model.
label (np.ndarray): The ground truth of input image.
mask (np.ndarray): The mask which pixel is valid. The dtype should be bool.
"""
# If not in streaming mode, clear matrix everytime when call `calculate` # If not in streaming mode, clear matrix everytime when call `calculate`
if not self.streaming: if not self.streaming:
self.zero_matrix() self.zero_matrix()
label = np.transpose(label, (0, 2, 3, 1)) pred = np.squeeze(pred)
ignore = np.transpose(ignore, (0, 2, 3, 1)) label = np.squeeze(label)
mask = np.array(ignore) == 1 mask = np.squeeze(mask)
if not pred.shape == label.shape == mask.shape:
raise ValueError(
'Shape of `pred`, `label` and `mask` should be equal, '
'but there are {}, {} and {}.'.format(pred.shape, label.shape,
mask.shape))
label = np.asarray(label)[mask] label = label[mask]
pred = np.asarray(pred)[mask] pred = pred[mask]
one = np.ones_like(pred) one = np.ones_like(pred).astype('int64')
# Accumuate ([row=label, col=pred], 1) into sparse # Accumuate ([row=label, col=pred], 1) into sparse
spm = csr_matrix((one, (label, pred)), spm = csr_matrix((one, (label, pred)),
shape=(self.num_classes, self.num_classes)) shape=(self.num_classes, self.num_classes))
......
...@@ -18,8 +18,8 @@ import paddle ...@@ -18,8 +18,8 @@ import paddle
from paddle.distributed import ParallelEnv from paddle.distributed import ParallelEnv
import paddleseg import paddleseg
from paddleseg.cvlibs import manager from paddleseg.cvlibs import manager, Config
from paddleseg.utils import get_environ_info, Config, logger from paddleseg.utils import get_environ_info, logger
from paddleseg.core import evaluate from paddleseg.core import evaluate
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册