提交 881fc5df 编写于 作者: C chenguowei01

update

上级 45e5dc07
...@@ -228,8 +228,7 @@ def generate_minibatch(batch_data, label_padding_value=255): ...@@ -228,8 +228,7 @@ def generate_minibatch(batch_data, label_padding_value=255):
elif len(data) > 1: elif len(data) > 1:
if isinstance(data[1], np.ndarray) and len(data[1].shape) > 1: if isinstance(data[1], np.ndarray) and len(data[1].shape) > 1:
# padding the image and label of segmentation # padding the image and label of segmentation during the training
# during the training and evaluating phase
# the data[1] of segmentation is a image array, # the data[1] of segmentation is a image array,
# so len(data[1].shape) > 1 # so len(data[1].shape) > 1
padding_label = np.zeros( padding_label = np.zeros(
......
...@@ -360,7 +360,7 @@ class DeepLabv3p(BaseAPI): ...@@ -360,7 +360,7 @@ class DeepLabv3p(BaseAPI):
pred = pred[0:num_samples] pred = pred[0:num_samples]
for i in range(num_samples): for i in range(num_samples):
one_pred = pred[i].astype('int64') one_pred = pred[i].astype('uint8')
one_label = labels[i] one_label = labels[i]
for info in im_info[i][::-1]: for info in im_info[i][::-1]:
if info[0] == 'resize': if info[0] == 'resize':
...@@ -372,6 +372,7 @@ class DeepLabv3p(BaseAPI): ...@@ -372,6 +372,7 @@ class DeepLabv3p(BaseAPI):
else: else:
raise Exception("Unexpected info '{}' in im_info".format( raise Exception("Unexpected info '{}' in im_info".format(
info[0])) info[0]))
one_pred = one_pred.astype('int64')
one_pred = one_pred[np.newaxis, :, :, np.newaxis] one_pred = one_pred[np.newaxis, :, :, np.newaxis]
one_label = one_label[np.newaxis, np.newaxis, :, :] one_label = one_label[np.newaxis, np.newaxis, :, :]
mask = one_label != self.ignore_index mask = one_label != self.ignore_index
......
...@@ -90,7 +90,7 @@ class Compose(SegTransform): ...@@ -90,7 +90,7 @@ class Compose(SegTransform):
if label is not None: if label is not None:
if not isinstance(label, np.ndarray): if not isinstance(label, np.ndarray):
label = np.asarray(Image.open(label)) label = np.asarray(Image.open(label))
origin_label = label.copy() origin_label = label.copy()
for op in self.transforms: for op in self.transforms:
if isinstance(op, SegTransform): if isinstance(op, SegTransform):
outputs = op(im, im_info, label) outputs = op(im, im_info, label)
...@@ -105,7 +105,7 @@ class Compose(SegTransform): ...@@ -105,7 +105,7 @@ class Compose(SegTransform):
outputs = (im, im_info, label) outputs = (im, im_info, label)
else: else:
outputs = (im, im_info) outputs = (im, im_info)
if type(self.transforms[-1]).__name__ == 'ArrangeSegmenter': if self.transforms[-1].__class__.__name__ == 'ArrangeSegmenter':
if self.transforms[-1].mode == 'eval': if self.transforms[-1].mode == 'eval':
if label is not None: if label is not None:
outputs = (im, im_info, origin_label) outputs = (im, im_info, origin_label)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册