From 94a7696924ec897390c7eb753514657cef8ec49a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 1 Apr 2020 11:04:48 +0800 Subject: [PATCH] refactor ImageNet GitOrigin-RevId: f7774e0ffc5de7ffb3ea5eba5ddb9809b9d049dd --- .../megengine/data/dataset/vision/imagenet.py | 113 ++++++++++-------- 1 file changed, 60 insertions(+), 53 deletions(-) diff --git a/python_module/megengine/data/dataset/vision/imagenet.py b/python_module/megengine/data/dataset/vision/imagenet.py index 9a5abff76..82fcaca10 100644 --- a/python_module/megengine/data/dataset/vision/imagenet.py +++ b/python_module/megengine/data/dataset/vision/imagenet.py @@ -24,7 +24,7 @@ from ....core.serialization import load, save from ....distributed.util import is_distributed from ....logger import get_logger from .folder import ImageFolder -from .utils import _default_dataset_root, untar, untargz +from .utils import _default_dataset_root, calculate_md5, untar, untargz logger = get_logger(__name__) @@ -33,40 +33,28 @@ class ImageNet(ImageFolder): r""" Load ImageNet from raw files or folder, expected folder looks like - raw files situation (optional): - root/ILSVRC2012_img_train.tar - root/ILSVRC2012_img_val.tar - root/ILSVRC2012_devkit_t12.tar.gz - - image folder situation (required): - root/train/cls/xxx.${img_ext} - root/val/cls/xxx.${img_ext} - root/ILSVRC2012_devkit_t12/data/meta.mat - root/ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt - - If the required folders don't exist, raw files are required to get extracted and processed. + ${root}/ + | [REQUIRED TAR FILES] + |- ILSVRC2012_img_train.tar + |- ILSVRC2012_img_val.tar + |- ILSVRC2012_devkit_t12.tar.gz + | [OPTIONAL IMAGE FOLDERS] + |- train/cls/xxx.${img_ext} + |- val/cls/xxx.${img_ext} + |- ILSVRC2012_devkit_t12/data/meta.mat + |- ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt + + If the image folders don't exist, raw tar files are required to get extracted and processed. """ raw_file_meta = { "train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"), "val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"), "devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"), - } - """ - raw files of ImageNet (train, val, devkit) - """ + } # ImageNet raw files default_train_dir = "train" - """ - directory of train data - """ default_val_dir = "val" - """ - directory of val data - """ default_devkit_dir = "ILSVRC2012_devkit_t12" - """ - directory of devkit - """ def __init__(self, root: str = None, train: bool = True, **kwargs): r""" @@ -97,13 +85,16 @@ class ImageNet(ImageFolder): else: self.root = root - self.devkit_dir = os.path.join(self.root, self.default_devkit_dir) - if not os.path.exists(self.root): raise FileNotFoundError("dir %s does not exist" % self.root) + + self.devkit_dir = os.path.join(self.root, self.default_devkit_dir) if not os.path.exists(self.devkit_dir): logger.warning("devkit directory %s does not exists" % self.devkit_dir) + self._prepare_devkit() + + self.train = train if train: self.target_folder = os.path.join(self.root, self.default_train_dir) @@ -125,7 +116,7 @@ class ImageNet(ImageFolder): "extracting raw file shouldn't be done in distributed mode, use single process instead" ) else: - self.parse(train) + self._prepare_train() if train else self._prepare_val() super().__init__(self.target_folder, **kwargs) @@ -180,14 +171,13 @@ class ImageNet(ImageFolder): ] ) - def organize_val_data(self): + def _organize_val_data(self): id2wnid = self.meta[0] val_idcs = self.valid_ground_truth val_wnids = [id2wnid[idx] for idx in val_idcs] - raw_val_dir = os.path.join(self.root, "ILSVRC2012_img_val") val_images = sorted( - [os.path.join(raw_val_dir, image) for image in os.listdir(raw_val_dir)] + [os.path.join(self.target_folder, image) for image in os.listdir(self.target_folder)] ) logger.debug("mkdir for val set wnids") @@ -203,24 +193,41 @@ class ImageNet(ImageFolder): ), ) - def parse(self, train): - if train: - logger.info("process train raw file.. this may take several hours") - untar( - os.path.join(self.root, self.raw_file_meta["train"][0]), - self.target_folder, - ) - paths = [ - os.path.join(self.target_folder, child_dir) - for child_dir in os.listdir(self.target_folder) - ] - for path in tqdm(paths): - untar(path, os.path.splitext(path)[0], remove=True) - else: - logger.info("process devkit file..") - untargz(os.path.join(self.root, self.raw_file_meta["devkit"][0])) - logger.info("process valid raw file.. this may take 10-20 minutes") - raw_val_dir = os.path.join(self.root, "ILSVRC2012_img_val") - os.makedirs(raw_val_dir, exist_ok=True) - untar(os.path.join(self.root, self.raw_file_meta["val"][0]), raw_val_dir) - self.organize_val_data() + def _prepare_val(self): + assert not self.train + raw_filename, checksum = self.raw_file_meta["val"] + raw_file = os.path.join(self.root, raw_filename) + logger.info("checksum valid tar file {} ..".format(raw_file)) + assert calculate_md5(raw_file) == checksum, \ + "checksum mismatch, {} may be damaged".format(raw_file) + logger.info("extract valid tar file.. this may take 10-20 minutes") + untar(os.path.join(self.root, raw_file), self.target_folder) + self._organize_val_data() + + def _prepare_train(self): + assert self.train + raw_filename, checksum = self.raw_file_meta["train"] + raw_file = os.path.join(self.root, raw_filename) + logger.info("checksum train tar file {} ..".format(raw_file)) + assert calculate_md5(raw_file) == checksum, \ + "checksum mismatch, {} may be damaged".format(raw_file) + logger.info("extract train tar file.. this may take several hours") + untar( + os.path.join(self.root, raw_file), + self.target_folder, + ) + paths = [ + os.path.join(self.target_folder, child_dir) + for child_dir in os.listdir(self.target_folder) + ] + for path in tqdm(paths): + untar(path, os.path.splitext(path)[0], remove=True) + + def _prepare_devkit(self): + raw_filename, checksum = self.raw_file_meta["val"] + raw_file = os.path.join(self.root, raw_filename) + logger.info("checksum devkit tar file {} ..".format(raw_file)) + assert calculate_md5(raw_file) == checksum, \ + "checksum mismatch, {} may be damaged".format(raw_file) + logger.info("extract devkit file..") + untargz(os.path.join(self.root, self.raw_file_meta["devkit"][0])) -- GitLab