diff --git a/dygraph/datasets/ade.py b/dygraph/datasets/ade.py index c65c3994b708c7b7ebc84cf44d2067ae2478aafd..6a220adb91430a96d443556cc97bd1e770b90bc6 100644 --- a/dygraph/datasets/ade.py +++ b/dygraph/datasets/ade.py @@ -14,6 +14,9 @@ import os +import numpy as np +from PIL import Image + from .dataset import Dataset from utils.download import download_file_and_uncompress @@ -74,3 +77,21 @@ class ADE20K(Dataset): img_path = os.path.join(img_dir, img_files[i]) grt_path = os.path.join(grt_dir, grt_files[i]) self.file_list.append([img_path, grt_path]) + + def __getitem__(self, idx): + image_path, grt_path = self.file_list[idx] + if self.mode == 'test': + im, im_info, _ = self.transforms(im=image_path) + im = im[np.newaxis, ...] + return im, im_info, image_path + elif self.mode == 'val': + im, im_info, _ = self.transforms(im=image_path) + im = im[np.newaxis, ...] + label = np.asarray(Image.open(grt_path)) + label = label - 1 + label = label[np.newaxis, np.newaxis, :, :] + return im, im_info, label + else: + im, im_info, label = self.transforms(im=image_path, label=grt_path) + label = label - 1 + return im, label