未验证 提交 cd856fb1 编写于 作者: J Jason 提交者: GitHub

Merge pull request #227 from wuyefeilin/develop

eval in origin image
......@@ -217,10 +217,18 @@ def generate_minibatch(batch_data, label_padding_value=255):
padding_im = np.zeros(
(im_c, max_shape[1], max_shape[2]), dtype=np.float32)
padding_im[:, :im_h, :im_w] = data[0]
if len(data) > 1:
if len(data) > 2:
# padding the image, label and insert 'padding' into `im_info` of segmentation during evaluating phase.
if len(data[1]) == 0 or 'padding' not in [
data[1][i][0] for i in range(len(data[1]))
]:
data[1].append(('padding', [im_h, im_w]))
padding_batch.append((padding_im, data[1], data[2]))
elif len(data) > 1:
if isinstance(data[1], np.ndarray) and len(data[1].shape) > 1:
# padding the image and label of segmentation
# during the training and evaluating phase
# padding the image and label of segmentation during the training
# the data[1] of segmentation is a image array,
# so len(data[1].shape) > 1
padding_label = np.zeros(
......
......@@ -340,7 +340,8 @@ class DeepLabv3p(BaseAPI):
for step, data in tqdm.tqdm(
enumerate(data_generator()), total=total_steps):
images = np.array([d[0] for d in data])
labels = np.array([d[1] for d in data])
im_info = [d[1] for d in data]
labels = [d[2] for d in data]
num_samples = images.shape[0]
if num_samples < batch_size:
......@@ -358,10 +359,25 @@ class DeepLabv3p(BaseAPI):
if num_samples < batch_size:
pred = pred[0:num_samples]
mask = labels != self.ignore_index
conf_mat.calculate(pred=pred, label=labels, ignore=mask)
for i in range(num_samples):
one_pred = pred[i].astype('uint8')
one_label = labels[i]
for info in im_info[i][::-1]:
if info[0] == 'resize':
w, h = info[1][1], info[1][0]
one_pred = cv2.resize(one_pred, (w, h), cv2.INTER_NEAREST)
elif info[0] == 'padding':
w, h = info[1][1], info[1][0]
one_pred = one_pred[0:h, 0:w]
else:
raise Exception("Unexpected info '{}' in im_info".format(
info[0]))
one_pred = one_pred.astype('int64')
one_pred = one_pred[np.newaxis, :, :, np.newaxis]
one_label = one_label[np.newaxis, np.newaxis, :, :]
mask = one_label != self.ignore_index
conf_mat.calculate(pred=one_pred, label=one_label, ignore=mask)
_, iou = conf_mat.mean_iou()
logging.debug("[EVAL] Epoch={}, Step={}/{}, iou={}".format(
epoch_id, step + 1, total_steps, iou))
......
......@@ -90,6 +90,7 @@ class Compose(SegTransform):
if label is not None:
if not isinstance(label, np.ndarray):
label = np.asarray(Image.open(label))
origin_label = label.copy()
for op in self.transforms:
if isinstance(op, SegTransform):
outputs = op(im, im_info, label)
......@@ -104,6 +105,10 @@ class Compose(SegTransform):
outputs = (im, im_info, label)
else:
outputs = (im, im_info)
if self.transforms[-1].__class__.__name__ == 'ArrangeSegmenter':
if self.transforms[-1].mode == 'eval':
if label is not None:
outputs = (im, im_info, origin_label)
return outputs
def add_augmenters(self, augmenters):
......@@ -1092,9 +1097,12 @@ class ArrangeSegmenter(SegTransform):
'quant'时,返回的tuple为(im,),为图像np.ndarray数据。
"""
im = permute(im, False)
if self.mode == 'train' or self.mode == 'eval':
if self.mode == 'train':
label = label[np.newaxis, :, :]
return (im, label)
if self.mode == 'eval':
label = label[np.newaxis, :, :]
return (im, im_info, label)
elif self.mode == 'test':
return (im, im_info)
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册