diff --git a/dygraph/datasets/dataset.py b/dygraph/datasets/dataset.py index b604f0e62aa78fce3f4591e6599afb377a291d8d..908e90b4f4159e997446d0da40374ccde79abf9b 100644 --- a/dygraph/datasets/dataset.py +++ b/dygraph/datasets/dataset.py @@ -15,6 +15,8 @@ import os import paddle.fluid as fluid +import numpy as np +from PIL import Image class Dataset(fluid.io.Dataset): @@ -85,12 +87,18 @@ class Dataset(fluid.io.Dataset): def __getitem__(self, idx): image_path, grt_path = self.file_list[idx] - im, im_info, label = self.transforms(im=image_path, label=grt_path) if self.mode == 'train': + im, im_info, label = self.transforms(im=image_path, label=grt_path) return im, label elif self.mode == 'eval': - return im, label + im, im_info, _ = self.transforms(im=image_path) + im = im[np.newaxis, ...] + label = np.asarray(Image.open(grt_path)) + label = label[np.newaxis, np.newaxis, :, :] + return im, im_info, label if self.mode == 'test': + im, im_info, _ = self.transforms(im=image_path) + im = im[np.newaxis, ...] return im, im_info, image_path def __len__(self): diff --git a/dygraph/infer.py b/dygraph/infer.py index f5caf7a435d3083f7d84106024096684a9d4f3b8..0b25a48ff9c2c3ffbe9532d48c95564173364b2c 100644 --- a/dygraph/infer.py +++ b/dygraph/infer.py @@ -98,19 +98,20 @@ def infer(model, test_dataset=None, model_dir=None, save_dir='output'): logging.info("Start to predict...") for im, im_info, im_path in tqdm.tqdm(test_dataset): - im = im[np.newaxis, ...] im = to_variable(im) pred, _ = model(im, mode='test') pred = pred.numpy() pred = np.squeeze(pred).astype('uint8') - keys = list(im_info.keys()) - for k in keys[::-1]: - if k == 'shape_before_resize': - h, w = im_info[k][0], im_info[k][1] + for info in im_info[::-1]: + if info[0] == 'resize': + h, w = info[1][0], info[1][1] pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST) - elif k == 'shape_before_padding': - h, w = im_info[k][0], im_info[k][1] + elif info[0] == 'padding': + h, w = info[1][0], info[1][1] pred = pred[0:h, 0:w] + else: + raise Exception("Unexpected info '{}' in im_info".format( + info[0])) im_file = im_path.replace(test_dataset.data_dir, '') if im_file[0] == '/': diff --git a/dygraph/train.py b/dygraph/train.py index 709a66bb8c7f55dac0a83e5435a42893eb4d2e9a..70b61aaf839af9e4a6d44046037e7db703a8abcc 100644 --- a/dygraph/train.py +++ b/dygraph/train.py @@ -230,10 +230,8 @@ def train(model, mean_iou, mean_acc = evaluate( model, eval_dataset, - places=places, model_dir=current_save_dir, num_classes=num_classes, - batch_size=batch_size, ignore_index=ignore_index, epoch_id=epoch + 1) if mean_iou > best_mean_iou: diff --git a/dygraph/transforms/transforms.py b/dygraph/transforms/transforms.py index f2b24fbad48b53930d4ba1b16a9a08ee6ae3c10b..935a2c0f8670eaa24b148844aa727efe6942e666 100644 --- a/dygraph/transforms/transforms.py +++ b/dygraph/transforms/transforms.py @@ -13,28 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .functional import * import random +from collections import OrderedDict + import numpy as np from PIL import Image import cv2 -from collections import OrderedDict - - -class Compose: - """根据数据预处理/增强算子对输入数据进行操作。 - 所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。 - - Args: - transforms (list): 数据预处理/增强算子。 - to_rgb (bool): 是否转化为rgb通道格式 - Raises: - TypeError: transforms不是list对象 - ValueError: transforms元素个数小于1。 +from .functional import * - """ +class Compose: def __init__(self, transforms, to_rgb=True): if not isinstance(transforms, list): raise TypeError('The transforms must be a list!') @@ -45,20 +34,8 @@ class Compose: self.to_rgb = to_rgb def __call__(self, im, im_info=None, label=None): - """ - Args: - im (str/np.ndarray): 图像路径/图像np.ndarray数据。 - im_info (dict): 存储与图像相关的信息,dict中的字段如下: - - shape_before_resize (tuple): 图像resize之前的大小(h, w)。 - - shape_before_padding (tuple): 图像padding之前的大小(h, w)。 - label (str/np.ndarray): 标注图像路径/标注图像np.ndarray数据。 - - Returns: - tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。 - """ - if im_info is None: - im_info = dict() + im_info = list() if isinstance(im, str): im = cv2.imread(im).astype('float32') if isinstance(label, str): @@ -82,28 +59,10 @@ class Compose: class RandomHorizontalFlip: - """以一定的概率对图像进行水平翻转。当存在标注图像时,则同步进行翻转。 - - Args: - prob (float): 随机水平翻转的概率。默认值为0.5。 - - """ - def __init__(self, prob=0.5): self.prob = prob def __call__(self, im, im_info=None, label=None): - """ - Args: - im (np.ndarray): 图像np.ndarray数据。 - im_info (dict): 存储与图像相关的信息。 - label (np.ndarray): 标注图像np.ndarray数据。 - - Returns: - tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; - 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 - 存储与图像相关信息的字典和标注图像np.ndarray数据。 - """ if random.random() < self.prob: im = horizontal_flip(im) if label is not None: @@ -115,27 +74,10 @@ class RandomHorizontalFlip: class RandomVerticalFlip: - """以一定的概率对图像进行垂直翻转。当存在标注图像时,则同步进行翻转。 - - Args: - prob (float): 随机垂直翻转的概率。默认值为0.1。 - """ - def __init__(self, prob=0.1): self.prob = prob def __call__(self, im, im_info=None, label=None): - """ - Args: - im (np.ndarray): 图像np.ndarray数据。 - im_info (dict): 存储与图像相关的信息。 - label (np.ndarray): 标注图像np.ndarray数据。 - - Returns: - tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; - 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 - 存储与图像相关信息的字典和标注图像np.ndarray数据。 - """ if random.random() < self.prob: im = vertical_flip(im) if label is not None: @@ -147,25 +89,6 @@ class RandomVerticalFlip: class Resize: - """调整图像大小(resize)。 - - - 当目标大小(target_size)类型为int时,根据插值方式, - 将图像resize为[target_size, target_size]。 - - 当目标大小(target_size)类型为list或tuple时,根据插值方式, - 将图像resize为target_size。 - 注意:当插值方式为“RANDOM”时,则随机选取一种插值方式进行resize。 - - Args: - target_size (int/list/tuple): 短边目标长度。默认为608。 - interp (str): resize的插值方式,与opencv的插值方式对应,取值范围为 - ['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM']。默认为"LINEAR"。 - - Raises: - TypeError: 形参数据类型不满足需求。 - ValueError: 插值方式不在['NEAREST', 'LINEAR', 'CUBIC', - 'AREA', 'LANCZOS4', 'RANDOM']中。 - """ - # The interpolation mode interp_dict = { 'NEAREST': cv2.INTER_NEAREST, @@ -193,26 +116,9 @@ class Resize: self.target_size = target_size def __call__(self, im, im_info=None, label=None): - """ - Args: - im (np.ndarray): 图像np.ndarray数据。 - im_info (dict, 可选): 存储与图像相关的信息。 - label (np.ndarray): 标注图像np.ndarray数据。 - - Returns: - tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; - 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 - 存储与图像相关信息的字典和标注图像np.ndarray数据。 - 其中,im_info跟新字段为: - -shape_before_resize (tuple): 保存resize之前图像的形状(h, w)。 - - Raises: - TypeError: 形参数据类型不满足需求。 - ValueError: 数据长度不匹配。 - """ if im_info is None: - im_info = OrderedDict() - im_info['shape_before_resize'] = im.shape[:2] + im_info = list() + im_info.append(('resize', im.shape[:2])) if not isinstance(im, np.ndarray): raise TypeError("Resize: image type is not numpy.") if len(im.shape) != 3: @@ -232,33 +138,14 @@ class Resize: class ResizeByLong: - """对图像长边resize到固定值,短边按比例进行缩放。当存在标注图像时,则同步进行处理。 - - Args: - long_size (int): resize后图像的长边大小。 - """ - def __init__(self, long_size): self.long_size = long_size def __call__(self, im, im_info=None, label=None): - """ - Args: - im (np.ndarray): 图像np.ndarray数据。 - im_info (dict): 存储与图像相关的信息。 - label (np.ndarray): 标注图像np.ndarray数据。 - - Returns: - tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; - 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 - 存储与图像相关信息的字典和标注图像np.ndarray数据。 - 其中,im_info新增字段为: - -shape_before_resize (tuple): 保存resize之前图像的形状(h, w)。 - """ if im_info is None: - im_info = OrderedDict() + im_info = list() - im_info['shape_before_resize'] = im.shape[:2] + im_info.append(('resize', im.shape[:2])) im = resize_long(im, self.long_size) if label is not None: label = resize_long(label, self.long_size, cv2.INTER_NEAREST) @@ -270,16 +157,6 @@ class ResizeByLong: class ResizeRangeScaling: - """对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。 - - Args: - min_value (int): 图像长边resize后的最小值。默认值400。 - max_value (int): 图像长边resize后的最大值。默认值600。 - - Raises: - ValueError: min_value大于max_value - """ - def __init__(self, min_value=400, max_value=600): if min_value > max_value: raise ValueError('min_value must be less than max_value, ' @@ -289,17 +166,6 @@ class ResizeRangeScaling: self.max_value = max_value def __call__(self, im, im_info=None, label=None): - """ - Args: - im (np.ndarray): 图像np.ndarray数据。 - im_info (dict): 存储与图像相关的信息。 - label (np.ndarray): 标注图像np.ndarray数据。 - - Returns: - tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; - 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 - 存储与图像相关信息的字典和标注图像np.ndarray数据。 - """ if self.min_value == self.max_value: random_size = self.max_value else: @@ -316,18 +182,6 @@ class ResizeRangeScaling: class ResizeStepScaling: - """对图像按照某一个比例resize,这个比例以scale_step_size为步长 - 在[min_scale_factor, max_scale_factor]随机变动。当存在标注图像时,则同步进行处理。 - - Args: - min_scale_factor(float), resize最小尺度。默认值0.75。 - max_scale_factor (float), resize最大尺度。默认值1.25。 - scale_step_size (float), resize尺度范围间隔。默认值0.25。 - - Raises: - ValueError: min_scale_factor大于max_scale_factor - """ - def __init__(self, min_scale_factor=0.75, max_scale_factor=1.25, @@ -342,17 +196,6 @@ class ResizeStepScaling: self.scale_step_size = scale_step_size def __call__(self, im, im_info=None, label=None): - """ - Args: - im (np.ndarray): 图像np.ndarray数据。 - im_info (dict): 存储与图像相关的信息。 - label (np.ndarray): 标注图像np.ndarray数据。 - - Returns: - tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; - 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 - 存储与图像相关信息的字典和标注图像np.ndarray数据。 - """ if self.min_scale_factor == self.max_scale_factor: scale_factor = self.min_scale_factor @@ -382,18 +225,6 @@ class ResizeStepScaling: class Normalize: - """对图像进行标准化。 - 1.尺度缩放到 [0,1]。 - 2.对图像进行减均值除以标准差操作。 - - Args: - mean (list): 图像数据集的均值。默认值[0.5, 0.5, 0.5]。 - std (list): 图像数据集的标准差。默认值[0.5, 0.5, 0.5]。 - - Raises: - ValueError: mean或std不是list对象。std包含0。 - """ - def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): self.mean = mean self.std = std @@ -404,18 +235,6 @@ class Normalize: raise ValueError('{}: std is invalid!'.format(self)) def __call__(self, im, im_info=None, label=None): - """ - Args: - im (np.ndarray): 图像np.ndarray数据。 - im_info (dict): 存储与图像相关的信息。 - label (np.ndarray): 标注图像np.ndarray数据。 - - Returns: - tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; - 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 - 存储与图像相关信息的字典和标注图像np.ndarray数据。 - """ - mean = np.array(self.mean)[np.newaxis, np.newaxis, :] std = np.array(self.std)[np.newaxis, np.newaxis, :] im = normalize(im, mean, std) @@ -427,19 +246,6 @@ class Normalize: class Padding: - """对图像或标注图像进行padding,padding方向为右和下。 - 根据提供的值对图像或标注图像进行padding操作。 - - Args: - target_size (int|list|tuple): padding后图像的大小。 - im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。 - label_padding_value (int): 标注图像padding的值。默认值为255。 - - Raises: - TypeError: target_size不是int|list|tuple。 - ValueError: target_size为list|tuple时元素个数不等于2。 - """ - def __init__(self, target_size, im_padding_value=[127.5, 127.5, 127.5], @@ -458,25 +264,9 @@ class Padding: self.label_padding_value = label_padding_value def __call__(self, im, im_info=None, label=None): - """ - Args: - im (np.ndarray): 图像np.ndarray数据。 - im_info (dict): 存储与图像相关的信息。 - label (np.ndarray): 标注图像np.ndarray数据。 - - Returns: - tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; - 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 - 存储与图像相关信息的字典和标注图像np.ndarray数据。 - 其中,im_info新增字段为: - -shape_before_padding (tuple): 保存padding之前图像的形状(h, w)。 - - Raises: - ValueError: 输入图像im或label的形状大于目标值 - """ if im_info is None: - im_info = OrderedDict() - im_info['shape_before_padding'] = im.shape[:2] + im_info = list() + im_info.append(('padding', im.shape[:2])) im_height, im_width = im.shape[0], im.shape[1] if isinstance(self.target_size, int): @@ -516,18 +306,6 @@ class Padding: class RandomPaddingCrop: - """对图像和标注图进行随机裁剪,当所需要的裁剪尺寸大于原图时,则进行padding操作。 - - Args: - crop_size (int|list|tuple): 裁剪图像大小。默认为512。 - im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。 - label_padding_value (int): 标注图像padding的值。默认值为255。 - - Raises: - TypeError: crop_size不是int/list/tuple。 - ValueError: target_size为list/tuple时元素个数不等于2。 - """ - def __init__(self, crop_size=512, im_padding_value=[127.5, 127.5, 127.5], @@ -546,17 +324,6 @@ class RandomPaddingCrop: self.label_padding_value = label_padding_value def __call__(self, im, im_info=None, label=None): - """ - Args: - im (np.ndarray): 图像np.ndarray数据。 - im_info (dict): 存储与图像相关的信息。 - label (np.ndarray): 标注图像np.ndarray数据。 - - Returns: - tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; - 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 - 存储与图像相关信息的字典和标注图像np.ndarray数据。 - """ if isinstance(self.crop_size, int): crop_width = self.crop_size crop_height = self.crop_size @@ -612,27 +379,10 @@ class RandomPaddingCrop: class RandomBlur: - """以一定的概率对图像进行高斯模糊。 - - Args: - prob (float): 图像模糊概率。默认为0.1。 - """ - def __init__(self, prob=0.1): self.prob = prob def __call__(self, im, im_info=None, label=None): - """ - Args: - im (np.ndarray): 图像np.ndarray数据。 - im_info (dict): 存储与图像相关的信息。 - label (np.ndarray): 标注图像np.ndarray数据。 - - Returns: - tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; - 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 - 存储与图像相关信息的字典和标注图像np.ndarray数据。 - """ if self.prob <= 0: n = 0 elif self.prob >= 1: @@ -655,17 +405,6 @@ class RandomBlur: class RandomRotation: - """对图像进行随机旋转。 - 在不超过最大旋转角度的情况下,图像进行随机旋转,当存在标注图像时,同步进行, - 并对旋转后的图像和标注图像进行相应的padding。 - - Args: - max_rotation (float): 最大旋转角度。默认为15度。 - im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。 - label_padding_value (int): 标注图像padding的值。默认为255。 - - """ - def __init__(self, max_rotation=15, im_padding_value=[127.5, 127.5, 127.5], @@ -675,17 +414,6 @@ class RandomRotation: self.label_padding_value = label_padding_value def __call__(self, im, im_info=None, label=None): - """ - Args: - im (np.ndarray): 图像np.ndarray数据。 - im_info (dict): 存储与图像相关的信息。 - label (np.ndarray): 标注图像np.ndarray数据。 - - Returns: - tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; - 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 - 存储与图像相关信息的字典和标注图像np.ndarray数据。 - """ if self.max_rotation > 0: (h, w) = im.shape[:2] do_rotation = np.random.uniform(-self.max_rotation, @@ -724,30 +452,11 @@ class RandomRotation: class RandomScaleAspect: - """裁剪并resize回原始尺寸的图像和标注图像。 - 按照一定的面积比和宽高比对图像进行裁剪,并reszie回原始图像的图像,当存在标注图时,同步进行。 - - Args: - min_scale (float):裁取图像占原始图像的面积比,取值[0,1],为0时则返回原图。默认为0.5。 - aspect_ratio (float): 裁取图像的宽高比范围,非负值,为0时返回原图。默认为0.33。 - """ - def __init__(self, min_scale=0.5, aspect_ratio=0.33): self.min_scale = min_scale self.aspect_ratio = aspect_ratio def __call__(self, im, im_info=None, label=None): - """ - Args: - im (np.ndarray): 图像np.ndarray数据。 - im_info (dict): 存储与图像相关的信息。 - label (np.ndarray): 标注图像np.ndarray数据。 - - Returns: - tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; - 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 - 存储与图像相关信息的字典和标注图像np.ndarray数据。 - """ if self.min_scale != 0 and self.aspect_ratio != 0: img_height = im.shape[0] img_width = im.shape[1] @@ -784,22 +493,6 @@ class RandomScaleAspect: class RandomDistort: - """对图像进行随机失真。 - - 1. 对变换的操作顺序进行随机化操作。 - 2. 按照1中的顺序以一定的概率对图像进行随机像素内容变换。 - - Args: - brightness_range (float): 明亮度因子的范围。默认为0.5。 - brightness_prob (float): 随机调整明亮度的概率。默认为0.5。 - contrast_range (float): 对比度因子的范围。默认为0.5。 - contrast_prob (float): 随机调整对比度的概率。默认为0.5。 - saturation_range (float): 饱和度因子的范围。默认为0.5。 - saturation_prob (float): 随机调整饱和度的概率。默认为0.5。 - hue_range (int): 色调因子的范围。默认为18。 - hue_prob (float): 随机调整色调的概率。默认为0.5。 - """ - def __init__(self, brightness_range=0.5, brightness_prob=0.5, @@ -819,17 +512,6 @@ class RandomDistort: self.hue_prob = hue_prob def __call__(self, im, im_info=None, label=None): - """ - Args: - im (np.ndarray): 图像np.ndarray数据。 - im_info (dict): 存储与图像相关的信息。 - label (np.ndarray): 标注图像np.ndarray数据。 - - Returns: - tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; - 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 - 存储与图像相关信息的字典和标注图像np.ndarray数据。 - """ brightness_lower = 1 - self.brightness_range brightness_upper = 1 + self.brightness_range contrast_lower = 1 - self.contrast_range diff --git a/dygraph/val.py b/dygraph/val.py index 41d0d33485d1052bef3b1c4d70b546cdf89d3922..ca36a6fe1ca169d30f2dbd06ff58da62b507ff4f 100644 --- a/dygraph/val.py +++ b/dygraph/val.py @@ -16,8 +16,10 @@ import argparse import os import math -from paddle.fluid.dygraph.base import to_variable import numpy as np +import tqdm +import cv2 +from paddle.fluid.dygraph.base import to_variable import paddle.fluid as fluid from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.io import DataLoader @@ -61,12 +63,6 @@ def parse_args(): nargs=2, default=[512, 512], type=int) - parser.add_argument( - '--batch_size', - dest='batch_size', - help='Mini batch size', - type=int, - default=2) parser.add_argument( '--model_dir', dest='model_dir', @@ -79,10 +75,8 @@ def parse_args(): def evaluate(model, eval_dataset=None, - places=None, model_dir=None, num_classes=None, - batch_size=2, ignore_index=255, epoch_id=None): ckpt_path = os.path.join(model_dir, 'model') @@ -90,15 +84,7 @@ def evaluate(model, model.set_dict(para_state_dict) model.eval() - batch_sampler = BatchSampler( - eval_dataset, batch_size=batch_size, shuffle=False, drop_last=False) - loader = DataLoader( - eval_dataset, - batch_sampler=batch_sampler, - places=places, - return_list=True, - ) - total_steps = len(batch_sampler) + total_steps = len(eval_dataset) conf_mat = ConfusionMatrix(num_classes, streaming=True) logging.info( @@ -106,15 +92,26 @@ def evaluate(model, len(eval_dataset), total_steps)) timer = Timer() timer.start() - for step, data in enumerate(loader): - images = data[0] - labels = data[1].astype('int64') - pred, _ = model(images, mode='eval') - - pred = pred.numpy() - labels = labels.numpy() - mask = labels != ignore_index - conf_mat.calculate(pred=pred, label=labels, ignore=mask) + for step, (im, im_info, label) in enumerate(eval_dataset): + im = to_variable(im) + pred, _ = model(im, mode='eval') + pred = pred.numpy().astype('float32') + pred = np.squeeze(pred) + for info in im_info[::-1]: + if info[0] == 'resize': + h, w = info[1][0], info[1][1] + pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST) + elif info[0] == 'padding': + h, w = info[1][0], info[1][1] + pred = pred[0:h, 0:w] + else: + raise Exception("Unexpected info '{}' in im_info".format( + info[0])) + pred = pred[np.newaxis, :, :, np.newaxis] + pred = pred.astype('int64') + mask = label != ignore_index + + conf_mat.calculate(pred=pred, label=label, ignore=mask) _, iou = conf_mat.mean_iou() time_step = timer.elapsed_time() @@ -163,10 +160,8 @@ def main(args): evaluate( model, eval_dataset, - places=places, model_dir=args.model_dir, - num_classes=eval_dataset.num_classes, - batch_size=args.batch_size) + num_classes=eval_dataset.num_classes) if __name__ == '__main__':