提交 bbb07bd3 编写于 作者: C chenguowei01

update dataset and compute iou in origin images

上级 dc837a2e
......@@ -14,10 +14,12 @@
import os
from paddle.fluid.io import Dataset
import paddle.fluid as fluid
import numpy as np
from PIL import Image
class Dataset(Dataset):
class Dataset(fluid.io.Dataset):
def __init__(self,
data_dir,
num_classes,
......@@ -85,12 +87,18 @@ class Dataset(Dataset):
def __getitem__(self, idx):
image_path, grt_path = self.file_list[idx]
im, im_info, label = self.transforms(im=image_path, label=grt_path)
if self.mode == 'train':
im, im_info, label = self.transforms(im=image_path, label=grt_path)
return im, label
elif self.mode == 'eval':
return im, label
im, im_info, _ = self.transforms(im=image_path)
im = im[np.newaxis, ...]
label = np.asarray(Image.open(grt_path))
label = label[np.newaxis, np.newaxis, :, :]
return im, im_info, label
if self.mode == 'test':
im, im_info, _ = self.transforms(im=image_path)
im = im[np.newaxis, ...]
return im, im_info, image_path
def __len__(self):
......
......@@ -98,7 +98,6 @@ def infer(model, test_dataset=None, model_dir=None, save_dir='output'):
logging.info("Start to predict...")
for im, im_info, im_path in tqdm.tqdm(test_dataset):
im = im[np.newaxis, ...]
im = to_variable(im)
pred, _ = model(im, mode='test')
pred = pred.numpy()
......
......@@ -230,10 +230,8 @@ def train(model,
mean_iou, mean_acc = evaluate(
model,
eval_dataset,
places=places,
model_dir=current_save_dir,
num_classes=num_classes,
batch_size=batch_size,
ignore_index=ignore_index,
epoch_id=epoch + 1)
if mean_iou > best_mean_iou:
......
......@@ -16,8 +16,10 @@ import argparse
import os
import math
from paddle.fluid.dygraph.base import to_variable
import numpy as np
import tqdm
import cv2
from paddle.fluid.dygraph.base import to_variable
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
......@@ -61,12 +63,6 @@ def parse_args():
nargs=2,
default=[512, 512],
type=int)
parser.add_argument(
'--batch_size',
dest='batch_size',
help='Mini batch size',
type=int,
default=2)
parser.add_argument(
'--model_dir',
dest='model_dir',
......@@ -79,10 +75,8 @@ def parse_args():
def evaluate(model,
eval_dataset=None,
places=None,
model_dir=None,
num_classes=None,
batch_size=2,
ignore_index=255,
epoch_id=None):
ckpt_path = os.path.join(model_dir, 'model')
......@@ -90,15 +84,7 @@ def evaluate(model,
model.set_dict(para_state_dict)
model.eval()
batch_sampler = BatchSampler(
eval_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
loader = DataLoader(
eval_dataset,
batch_sampler=batch_sampler,
places=places,
return_list=True,
)
total_steps = len(batch_sampler)
total_steps = len(eval_dataset)
conf_mat = ConfusionMatrix(num_classes, streaming=True)
logging.info(
......@@ -106,15 +92,25 @@ def evaluate(model,
len(eval_dataset), total_steps))
timer = Timer()
timer.start()
for step, data in enumerate(loader):
images = data[0]
labels = data[1].astype('int64')
pred, _ = model(images, mode='eval')
for step, (im, im_info, label) in enumerate(eval_dataset):
im = to_variable(im)
pred, _ = model(im, mode='eval')
pred = pred.numpy()
labels = labels.numpy()
mask = labels != ignore_index
conf_mat.calculate(pred=pred, label=labels, ignore=mask)
pred = np.squeeze(pred).astype('uint8')
for info in im_info[::-1]:
if info[0] == 'resize':
h, w = info[1][0], info[1][1]
pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST)
elif info[0] == 'padding':
h, w = info[1][0], info[1][1]
pred = pred[0:h, 0:w]
else:
raise Exception("Unexpected info '{}' in im_info".format(
info[0]))
pred = pred[np.newaxis, :, :, np.newaxis]
mask = label != ignore_index
conf_mat.calculate(pred=pred, label=label, ignore=mask)
_, iou = conf_mat.mean_iou()
time_step = timer.elapsed_time()
......@@ -163,10 +159,8 @@ def main(args):
evaluate(
model,
eval_dataset,
places=places,
model_dir=args.model_dir,
num_classes=eval_dataset.num_classes,
batch_size=args.batch_size)
num_classes=eval_dataset.num_classes)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册