提交 3fa39034 编写于 作者: C chenguowei01

update confusion matrix calculation

上级 e9a2881d
......@@ -49,7 +49,6 @@ def evaluate(model,
for iter, (im, im_info, label) in tqdm.tqdm(
enumerate(eval_dataset), total=total_iters):
im = to_variable(im)
# pred, _ = model(im)
logits = model(im)
pred = paddle.argmax(logits[0], axis=1)
pred = pred.numpy().astype('float32')
......@@ -68,7 +67,7 @@ def evaluate(model,
pred = pred.astype('int64')
mask = label != ignore_index
# To-DO Test Execution Time
conf_mat.calculate(pred=pred, label=label, ignore=mask)
conf_mat.calculate(pred=pred, label=label, mask=mask)
_, iou = conf_mat.mean_iou()
time_iter = timer.elapsed_time()
......
......@@ -29,18 +29,32 @@ class ConfusionMatrix(object):
self.num_classes = num_classes
self.streaming = streaming
def calculate(self, pred, label, ignore=None):
def calculate(self, pred, label, mask):
"""
Calculate confusion matrix
Args:
pred (np.ndarray): The prediction of input image by model.
label (np.ndarray): The ground truth of input image.
mask (np.ndarray): The mask which pixel is valid. The dtype should be bool.
"""
# If not in streaming mode, clear matrix everytime when call `calculate`
if not self.streaming:
self.zero_matrix()
label = np.transpose(label, (0, 2, 3, 1))
ignore = np.transpose(ignore, (0, 2, 3, 1))
mask = np.array(ignore) == 1
pred = np.squeeze(pred)
label = np.squeeze(label)
mask = np.squeeze(mask)
label = np.asarray(label)[mask]
pred = np.asarray(pred)[mask]
one = np.ones_like(pred)
if not pred.shape == label.shape == mask.shape:
raise ValueError(
'Shape of `pred`, `label` and `mask` should be equal, '
'but there are {}, {} and {}.'.format(pred.shape, label.shape,
mask.shape))
label = label[mask]
pred = pred[mask]
one = np.ones_like(pred).astype('int64')
# Accumuate ([row=label, col=pred], 1) into sparse
spm = csr_matrix((one, (label, pred)),
shape=(self.num_classes, self.num_classes))
......
......@@ -18,8 +18,8 @@ import paddle
from paddle.distributed import ParallelEnv
import paddleseg
from paddleseg.cvlibs import manager
from paddleseg.utils import get_environ_info, Config, logger
from paddleseg.cvlibs import manager, Config
from paddleseg.utils import get_environ_info, logger
from paddleseg.core import evaluate
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册