diff --git a/README.md b/README.md index 8bd105e6d62390468987212ecd1fef2065b96177..a8a9ba0122c546a2e38890219c5f60780783faf9 100644 --- a/README.md +++ b/README.md @@ -122,9 +122,9 @@ python -m yolox.tools.train -n yolox-s -d 8 -b 64 --fp16 -o [--cache] * -d: number of gpu devices * -b: total batch size, the recommended number for -b is num-gpu * 8 * --fp16: mixed precision training -* --cache: caching imgs into RAM to accelarate training, which need large system RAM. +* --cache: caching imgs into RAM to accelarate training, which need large system RAM. + - When using -f, the above commands are equivalent to: ```shell @@ -140,7 +140,8 @@ We also support multi-nodes training. Just add the following args: * --num\_machines: num of your total training nodes * --machine\_rank: specify the rank of each node -Suppose you want to train YOLOX on 2 machines, and your master machines's IP is 123.123.123.123, use port 12312 and TCP. +Suppose you want to train YOLOX on 2 machines, and your master machines's IP is 123.123.123.123, use port 12312 and TCP. + On master machine, run ```shell python tools/train.py -n yolox-s -b 128 --dist-url tcp://123.123.123.123:12312 --num_machines 2 --machine_rank 0 @@ -163,7 +164,8 @@ python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o [--cache] --logger wandb w An example wandb dashboard is available [here](https://wandb.ai/manan-goel/yolox-nano/runs/3pzfeom0) -**Others** +**Others** + See more information with the following command: ```shell python -m yolox.tools.train --help @@ -202,6 +204,7 @@ python -m yolox.tools.eval -n yolox-s -c yolox_s.pth -b 1 -d 1 --conf 0.001 --f Tutorials * [Training on custom data](docs/train_custom_data.md) +* [Caching for custom data](docs/cache.md) * [Manipulating training image size](docs/manipulate_training_image_size.md) * [Freezing model](docs/freeze_module.md) diff --git a/docs/cache.md b/docs/cache.md new file mode 100755 index 0000000000000000000000000000000000000000..66aded7cb7e83ef4f568912cd7d5c751a74006b9 --- /dev/null +++ b/docs/cache.md @@ -0,0 +1,97 @@ +# Cache Custom Data + +The caching feature is specifically tailored for users with ample memory resources. However, we still offer the option to cache data to disk, but disk performance can vary and may not guarantee optimal user experience. Implementing custom dataset RAM caching is also more straightforward and user-friendly compared to disk caching. With a few simple modifications, users can expect to see a significant increase in training speed, with speeds nearly double that of non-cached datasets. + +This page explains how to cache your own custom data with YOLOX. + +## 0. Before you start + +**Step1** Clone this repo and follow the [README](../README.md) to install YOLOX. + +**Stpe2** Read the [Training on custom data](./train_custom_data.md) tutorial to understand how to prepare your custom data. + +## 1. Inheirit from `CacheDataset` + + +**Step1** Create a custom dataset that inherits from the `CacheDataset` class. Note that whether inheriting from `Dataset` or `CacheDataset `, the `__init__()` method of your custom dataset should take the following keyword arguments: `input_dimension`, `cache`, and `cache_type`. Also, call `super().__init__()` and pass in `input_dimension`, `num_imgs`, `cache`, and `cache_type` as input, where `num_imgs` is the size of the dataset. + +**Step2** Implement the abstract function `read_img(self, index, use_cache=True)` of parent class and decorate it with `@cache_read_img`. This function takes an `index` as input and returns an `image`, and the returned image will be used for caching. It is recommended to put all repetitive and fixed post-processing operations on the image in this function to reduce the post-processing time of the image during training. + +```python +# CustomDataset.py +from yolox.data.datasets import CacheDataset, cache_read_img + +class CustomDataset(CacheDataset): + def __init__(self, input_dimension, cache, cache_type, *args, **kwargs): + # Get the required keyword arguments of super().__init__() + super().__init__( + input_dimension=input_dimension, + num_imgs=num_imgs, + cache=cache, + cache_type=cache_type + ) + # ... + + @cache_read_img + def read_img(self, index, use_cache=True): + # get image ... + # (optional) repetitive and fixed post-processing operations for image + return image +``` + +## 2. Create your Exp file and return your custom dataset + +**Step1** Create a new class that inherits from the `Exp` class provided by the `yolox_base.py`. Override the `get_dataset()` and `get_eval_dataset()` method to return an instance of your custom dataset. + +**Step2** Implement your own `get_evaluator` method to return an instance of your custom evaluator. + +```python +# CustomeExp.py +from yolox.exp import Exp as MyExp + +class Exp(MyExp): + def get_dataset(self, cache, cache_type: str = "ram"): + return CustomDataset( + input_dimension=self.input_size, + cache=cache, + cache_type=cache_type + ) + + def get_eval_dataset(self): + return CustomDataset( + input_dimension=self.input_size, + ) + + def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False): + return CustomEvaluator( + dataloader=self.get_eval_loader(batch_size, is_distributed, testdev=testdev, legacy=legacy), + img_size=self.test_size, + confthre=self.test_conf, + nmsthre=self.nmsthre, + num_classes=self.num_classes, + testdev=testdev, + ) +``` + +**(Optional)** `get_data_loader` and `get_eval_loader` are now a default behavior in `yolox_base.py` and generally do not need to be changed. If you have to change `get_data_loader`, you need to add the following code at the beginning. + +```python +# CustomeExp.py +from yolox.exp import Exp as MyExp + +class Exp(MyExp): + def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img: str = None): + if self.dataset is None: + with wait_for_the_master(): + assert cache_img is None + self.dataset = self.get_dataset(cache=False, cache_type=cache_img) + # ... + +``` + +## 3. Cache to Disk +It's important to note that the `cache_type` can be `"ram"` or `"disk"`, depending on where you want to cache your dataset. If you choose `"disk"`, you need to pass in additional parameters to `super().__init__()` of `CustomDataset`: `data_dir`, `cache_dir_name`, `path_filename`. + +- `data_dir`: the root directory of the dataset, e.g. `/path/to/COCO`. +- `cache_dir_name`: the name of the directory to cache to disk, for example `"custom_cache"`, then the files cached to disk will be saved under `/path/to/COCO/custom_cache`. +- `path_filename`: a list of paths to the data relative to the `data_dir`, e.g. if you have data `/path/to/COCO/train/1.jpg`, `/path/to/COCO/train/2.jpg`, then `path_filename = ['train/1.jpg', ' train/2.jpg']`. diff --git a/exps/example/yolox_voc/yolox_voc_s.py b/exps/example/yolox_voc/yolox_voc_s.py index e5cdb61036d5e580eced65f93bed7398613d7172..379ba9ac79adb4c0fa8677088f2c6eaafda38046 100644 --- a/exps/example/yolox_voc/yolox_voc_s.py +++ b/exps/example/yolox_voc/yolox_voc_s.py @@ -1,9 +1,6 @@ # encoding: utf-8 import os -import torch -import torch.distributed as dist - from yolox.data import get_yolox_datadir from yolox.exp import Exp as MyExp @@ -24,115 +21,40 @@ class Exp(MyExp): self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0] - def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img=False): - from yolox.data import ( - VOCDetection, - TrainTransform, - YoloBatchSampler, - DataLoader, - InfiniteSampler, - MosaicDetection, - worker_init_reset_seed, - ) - from yolox.utils import ( - wait_for_the_master, - get_local_rank, - ) - local_rank = get_local_rank() + def get_dataset(self, cache: bool, cache_type: str = "ram"): + from yolox.data import VOCDetection, TrainTransform - with wait_for_the_master(local_rank): - dataset = VOCDetection( - data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"), - image_sets=[('2007', 'trainval'), ('2012', 'trainval')], - img_size=self.input_size, - preproc=TrainTransform( - max_labels=50, - flip_prob=self.flip_prob, - hsv_prob=self.hsv_prob), - cache=cache_img, - ) - - dataset = MosaicDetection( - dataset, - mosaic=not no_aug, + return VOCDetection( + data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"), + image_sets=[('2007', 'trainval'), ('2012', 'trainval')], img_size=self.input_size, preproc=TrainTransform( - max_labels=120, + max_labels=50, flip_prob=self.flip_prob, hsv_prob=self.hsv_prob), - degrees=self.degrees, - translate=self.translate, - mosaic_scale=self.mosaic_scale, - mixup_scale=self.mixup_scale, - shear=self.shear, - enable_mixup=self.enable_mixup, - mosaic_prob=self.mosaic_prob, - mixup_prob=self.mixup_prob, - ) - - self.dataset = dataset - - if is_distributed: - batch_size = batch_size // dist.get_world_size() - - sampler = InfiniteSampler( - len(self.dataset), seed=self.seed if self.seed else 0 + cache=cache, + cache_type=cache_type, ) - batch_sampler = YoloBatchSampler( - sampler=sampler, - batch_size=batch_size, - drop_last=False, - mosaic=not no_aug, - ) - - dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True} - dataloader_kwargs["batch_sampler"] = batch_sampler - - # Make sure each process has different random seed, especially for 'fork' method - dataloader_kwargs["worker_init_fn"] = worker_init_reset_seed - - train_loader = DataLoader(self.dataset, **dataloader_kwargs) - - return train_loader - - def get_eval_loader(self, batch_size, is_distributed, testdev=False, legacy=False): + def get_eval_dataset(self, **kwargs): from yolox.data import VOCDetection, ValTransform + legacy = kwargs.get("legacy", False) - valdataset = VOCDetection( + return VOCDetection( data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"), image_sets=[('2007', 'test')], img_size=self.test_size, preproc=ValTransform(legacy=legacy), ) - if is_distributed: - batch_size = batch_size // dist.get_world_size() - sampler = torch.utils.data.distributed.DistributedSampler( - valdataset, shuffle=False - ) - else: - sampler = torch.utils.data.SequentialSampler(valdataset) - - dataloader_kwargs = { - "num_workers": self.data_num_workers, - "pin_memory": True, - "sampler": sampler, - } - dataloader_kwargs["batch_size"] = batch_size - val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs) - - return val_loader - def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False): from yolox.evaluators import VOCEvaluator - val_loader = self.get_eval_loader(batch_size, is_distributed, testdev, legacy) - evaluator = VOCEvaluator( - dataloader=val_loader, + return VOCEvaluator( + dataloader=self.get_eval_loader(batch_size, is_distributed, + testdev=testdev, legacy=legacy), img_size=self.test_size, confthre=self.test_conf, nmsthre=self.nmsthre, num_classes=self.num_classes, ) - return evaluator diff --git a/tools/train.py b/tools/train.py index 3166db5a4548b0d907731e4ffc20a14168f9b2cd..999ce8f776934868f8008e57b1caf22e7f7649d7 100644 --- a/tools/train.py +++ b/tools/train.py @@ -131,7 +131,7 @@ if __name__ == "__main__": assert num_gpu <= get_num_devices() if args.cache is not None: - exp.create_cache_dataset(args.cache) + exp.dataset = exp.get_dataset(cache=True, cache_type=args.cache) dist_url = "auto" if args.dist_url is None else args.dist_url launch( diff --git a/yolox/data/datasets/__init__.py b/yolox/data/datasets/__init__.py index dee2c9f482d7c3bf5d3b7609c71f5d93455bd6c9..0b6fd8ec4cecffe94d80084b57f3b966e4f01def 100644 --- a/yolox/data/datasets/__init__.py +++ b/yolox/data/datasets/__init__.py @@ -4,6 +4,6 @@ from .coco import COCODataset from .coco_classes import COCO_CLASSES -from .datasets_wrapper import ConcatDataset, Dataset, MixConcatDataset +from .datasets_wrapper import CacheDataset, ConcatDataset, Dataset, MixConcatDataset from .mosaicdetection import MosaicDetection from .voc import VOCDetection diff --git a/yolox/data/datasets/coco.py b/yolox/data/datasets/coco.py index 6da23c29581f8ce7beb21774bfb6dfaa108ecca4..8d19047a2bdef1c2a1af544d484cb2eee3af8aaa 100644 --- a/yolox/data/datasets/coco.py +++ b/yolox/data/datasets/coco.py @@ -3,18 +3,13 @@ # Copyright (c) Megvii, Inc. and its affiliates. import copy import os -import random -from multiprocessing.pool import ThreadPool -import psutil -from loguru import logger -from tqdm import tqdm import cv2 import numpy as np from pycocotools.coco import COCO from ..dataloading import get_yolox_datadir -from .datasets_wrapper import Dataset +from .datasets_wrapper import CacheDataset, cache_read_img def remove_useless_info(coco): @@ -36,7 +31,7 @@ def remove_useless_info(coco): anno.pop("segmentation", None) -class COCODataset(Dataset): +class COCODataset(CacheDataset): """ COCO dataset class. """ @@ -60,7 +55,6 @@ class COCODataset(Dataset): img_size (int): target image size after pre-processing preproc: data augmentation strategy """ - super().__init__(img_size) if data_dir is None: data_dir = os.path.join(get_yolox_datadir(), "COCO") self.data_dir = data_dir @@ -77,85 +71,21 @@ class COCODataset(Dataset): self.img_size = img_size self.preproc = preproc self.annotations = self._load_coco_annotations() - self.imgs = None - self.cache = cache - self.cache_type = cache_type - if self.cache: - self._cache_images() - - def _cache_images(self): - mem = psutil.virtual_memory() - mem_required = self.cal_cache_ram() - gb = 1 << 30 - - if self.cache_type == "ram" and mem_required > mem.available: - self.cache = False - else: - logger.info( - f"{mem_required / gb:.1f}GB RAM required, " - f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB RAM available, " - f"Since the first thing we do is cache, " - f"there is no guarantee that the remaining memory space is sufficient" - ) - - if self.imgs is None: - if self.cache_type == 'ram': - self.imgs = [None] * self.num_imgs - logger.info("You are using cached images in RAM to accelerate training!") - else: # 'disk' - self.cache_dir = os.path.join( - self.data_dir, - f"{self.name}_cache{self.img_size[0]}x{self.img_size[1]}" - ) - if not os.path.exists(self.cache_dir): - os.mkdir(self.cache_dir) - logger.warning( - f"\n*******************************************************************\n" - f"You are using cached images in DISK to accelerate training.\n" - f"This requires large DISK space.\n" - f"Make sure you have {mem_required / gb:.1f} " - f"available DISK space for training COCO.\n" - f"*******************************************************************\\n" - ) - else: - logger.info("Found disk cache!") - return - - logger.info( - "Caching images for the first time. " - "This might take about 15 minutes for COCO" - ) - - num_threads = min(8, max(1, os.cpu_count() - 1)) - b = 0 - load_imgs = ThreadPool(num_threads).imap(self.load_resized_img, range(self.num_imgs)) - pbar = tqdm(enumerate(load_imgs), total=self.num_imgs) - for i, x in pbar: # x = self.load_resized_img(self, i) - if self.cache_type == 'ram': - self.imgs[i] = x - else: # 'disk' - cache_filename = f'{self.annotations[i]["filename"].split(".")[0]}.npy' - np.save(os.path.join(self.cache_dir, cache_filename), x) - b += x.nbytes - pbar.desc = f'Caching images ({b / gb:.1f}/{mem_required / gb:.1f}GB {self.cache})' - pbar.close() - - def cal_cache_ram(self): - cache_bytes = 0 - num_samples = min(self.num_imgs, 32) - for _ in range(num_samples): - img = self.load_resized_img(random.randint(0, self.num_imgs - 1)) - cache_bytes += img.nbytes - mem_required = cache_bytes * self.num_imgs / num_samples - return mem_required + path_filename = [os.path.join(name, anno[3]) for anno in self.annotations] + super().__init__( + input_dimension=img_size, + num_imgs=self.num_imgs, + data_dir=data_dir, + cache_dir_name=f"cache_{name}", + path_filename=path_filename, + cache=cache, + cache_type=cache_type + ) def __len__(self): return self.num_imgs - def __del__(self): - del self.imgs - def _load_coco_annotations(self): return [self.load_anno_from_ids(_ids) for _ids in self.ids] @@ -220,20 +150,18 @@ class COCODataset(Dataset): return img + @cache_read_img(use_cache=True) + def read_img(self, index): + return self.load_resized_img(index) + def pull_item(self, index): id_ = self.ids[index] - label, origin_image_size, _, filename = self.annotations[index] - - if self.cache and self.cache_type == 'ram': - img = self.imgs[index] - elif self.cache and self.cache_type == 'disk': - img = np.load(os.path.join(self.cache_dir, f"{filename.split('.')[0]}.npy")) - else: - img = self.load_resized_img(index) + label, origin_image_size, _, _ = self.annotations[index] + img = self.read_img(index) - return copy.deepcopy(img), copy.deepcopy(label), origin_image_size, np.array([id_]) + return img, copy.deepcopy(label), origin_image_size, np.array([id_]) - @Dataset.mosaic_getitem + @CacheDataset.mosaic_getitem def __getitem__(self, index): """ One image / label pair for the given index is picked up and pre-processed. diff --git a/yolox/data/datasets/datasets_wrapper.py b/yolox/data/datasets/datasets_wrapper.py index f85a121d5a655560568f839019ffd4ac7309e287..c45fe380f5b7ac1c40452ff3903da651fe324225 100644 --- a/yolox/data/datasets/datasets_wrapper.py +++ b/yolox/data/datasets/datasets_wrapper.py @@ -3,7 +3,17 @@ # Copyright (c) Megvii, Inc. and its affiliates. import bisect -from functools import wraps +import copy +import os +import random +from abc import ABCMeta, abstractmethod +from functools import partial, wraps +from multiprocessing.pool import ThreadPool +import psutil +from loguru import logger +from tqdm import tqdm + +import numpy as np from torch.utils.data.dataset import ConcatDataset as torchConcatDataset from torch.utils.data.dataset import Dataset as torchDataset @@ -112,3 +122,179 @@ class Dataset(torchDataset): return ret_val return wrapper + + +class CacheDataset(Dataset, metaclass=ABCMeta): + """ This class is a subclass of the base :class:`yolox.data.datasets.Dataset`, + that enables cache images to ram or disk. + + Args: + input_dimension (tuple): (width,height) tuple with default dimensions of the network + num_imgs (int): datset size + data_dir (str): the root directory of the dataset, e.g. `/path/to/COCO`. + cache_dir_name (str): the name of the directory to cache to disk, + e.g. `"custom_cache"`. The files cached to disk will be saved + under `/path/to/COCO/custom_cache`. + path_filename (str): a list of paths to the data relative to the `data_dir`, + e.g. if you have data `/path/to/COCO/train/1.jpg`, `/path/to/COCO/train/2.jpg`, + then `path_filename = ['train/1.jpg', ' train/2.jpg']`. + cache (bool): whether to cache the images to ram or disk. + cache_type (str): the type of cache, + "ram" : Caching imgs to ram for fast training. + "disk": Caching imgs to disk for fast training. + """ + + def __init__( + self, + input_dimension, + num_imgs=None, + data_dir=None, + cache_dir_name=None, + path_filename=None, + cache=False, + cache_type="ram", + ): + super().__init__(input_dimension) + self.cache = cache + self.cache_type = cache_type + + if self.cache and self.cache_type == "disk": + self.cache_dir = os.path.join(data_dir, cache_dir_name) + self.path_filename = path_filename + + if self.cache and self.cache_type == "ram": + self.imgs = None + + if self.cache: + self.cache_images( + num_imgs=num_imgs, + data_dir=data_dir, + cache_dir_name=cache_dir_name, + path_filename=path_filename, + ) + + def __del__(self): + if self.cache and self.cache_type == "ram": + del self.imgs + + @abstractmethod + def read_img(self, index): + """ + Given index, return the corresponding image + + Args: + index (int): image index + """ + raise NotImplementedError + + def cache_images( + self, + num_imgs=None, + data_dir=None, + cache_dir_name=None, + path_filename=None, + ): + assert num_imgs is not None, "num_imgs must be specified as the size of the dataset" + if self.cache_type == "disk": + assert (data_dir and cache_dir_name and path_filename) is not None, \ + "data_dir, cache_name and path_filename must be specified if cache_type is disk" + self.path_filename = path_filename + + mem = psutil.virtual_memory() + mem_required = self.cal_cache_occupy(num_imgs) + gb = 1 << 30 + + if self.cache_type == "ram": + if mem_required > mem.available: + self.cache = False + else: + logger.info( + f"{mem_required / gb:.1f}GB RAM required, " + f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB RAM available, " + f"Since the first thing we do is cache, " + f"there is no guarantee that the remaining memory space is sufficient" + ) + + if self.cache and self.imgs is None: + if self.cache_type == 'ram': + self.imgs = [None] * num_imgs + logger.info("You are using cached images in RAM to accelerate training!") + else: # 'disk' + if not os.path.exists(self.cache_dir): + os.mkdir(self.cache_dir) + logger.warning( + f"\n*******************************************************************\n" + f"You are using cached images in DISK to accelerate training.\n" + f"This requires large DISK space.\n" + f"Make sure you have {mem_required / gb:.1f} " + f"available DISK space for training your dataset.\n" + f"*******************************************************************\\n" + ) + else: + logger.info(f"Found disk cache at {self.cache_dir}") + return + + logger.info( + "Caching images...\n" + "This might take some time for your dataset" + ) + + num_threads = min(8, max(1, os.cpu_count() - 1)) + b = 0 + load_imgs = ThreadPool(num_threads).imap( + partial(self.read_img, use_cache=False), + range(num_imgs) + ) + pbar = tqdm(enumerate(load_imgs), total=num_imgs) + for i, x in pbar: # x = self.read_img(self, i, use_cache=False) + if self.cache_type == 'ram': + self.imgs[i] = x + else: # 'disk' + cache_filename = f'{self.path_filename[i].split(".")[0]}.npy' + cache_path_filename = os.path.join(self.cache_dir, cache_filename) + os.makedirs(os.path.dirname(cache_path_filename), exist_ok=True) + np.save(cache_path_filename, x) + b += x.nbytes + pbar.desc = \ + f'Caching images ({b / gb:.1f}/{mem_required / gb:.1f}GB {self.cache_type})' + pbar.close() + + def cal_cache_occupy(self, num_imgs): + cache_bytes = 0 + num_samples = min(num_imgs, 32) + for _ in range(num_samples): + img = self.read_img(index=random.randint(0, num_imgs - 1), use_cache=False) + cache_bytes += img.nbytes + mem_required = cache_bytes * num_imgs / num_samples + return mem_required + + +def cache_read_img(use_cache=True): + def decorator(read_img_fn): + """ + Decorate the read_img function to cache the image + + Args: + read_img_fn: read_img function + use_cache (bool, optional): For the decorated read_img function, + whether to read the image from cache. + Defaults to True. + """ + @wraps(read_img_fn) + def wrapper(self, index, use_cache=use_cache): + cache = self.cache and use_cache + if cache: + if self.cache_type == "ram": + img = self.imgs[index] + img = copy.deepcopy(img) + elif self.cache_type == "disk": + img = np.load( + os.path.join( + self.cache_dir, f"{self.path_filename[index].split('.')[0]}.npy")) + else: + raise ValueError(f"Unknown cache type: {self.cache_type}") + else: + img = read_img_fn(self, index) + return img + return wrapper + return decorator diff --git a/yolox/data/datasets/voc.py b/yolox/data/datasets/voc.py index 666d72da36ccc9851290c7d8fcc6a36267c635f9..62b539d871bce582e5e41ef682b3760361d1406f 100644 --- a/yolox/data/datasets/voc.py +++ b/yolox/data/datasets/voc.py @@ -10,14 +10,13 @@ import os import os.path import pickle import xml.etree.ElementTree as ET -from loguru import logger import cv2 import numpy as np from yolox.evaluators.voc_eval import voc_eval -from .datasets_wrapper import Dataset +from .datasets_wrapper import CacheDataset, cache_read_img from .voc_classes import VOC_CLASSES @@ -80,7 +79,7 @@ class AnnotationTransform(object): return res, img_info -class VOCDetection(Dataset): +class VOCDetection(CacheDataset): """ VOC Detection Dataset Object @@ -108,8 +107,8 @@ class VOCDetection(Dataset): target_transform=AnnotationTransform(), dataset_name="VOC0712", cache=False, + cache_type="ram", ): - super().__init__(img_size) self.root = data_dir self.image_set = image_sets self.img_size = img_size @@ -131,66 +130,29 @@ class VOCDetection(Dataset): os.path.join(rootpath, "ImageSets", "Main", name + ".txt") ): self.ids.append((rootpath, line.strip())) + self.num_imgs = len(self.ids) self.annotations = self._load_coco_annotations() - self.imgs = None - if cache: - self._cache_images() - def __len__(self): - return len(self.ids) - - def _load_coco_annotations(self): - return [self.load_anno_from_ids(_ids) for _ids in range(len(self.ids))] - - def _cache_images(self): - logger.warning( - "\n********************************************************************************\n" - "You are using cached images in RAM to accelerate training.\n" - "This requires large system RAM.\n" - "Make sure you have 60G+ RAM and 19G available disk space for training VOC.\n" - "********************************************************************************\n" + path_filename = [ + (self._imgpath % self.ids[i]).split(self.root + "/")[1] + for i in range(self.num_imgs) + ] + super().__init__( + input_dimension=img_size, + num_imgs=self.num_imgs, + data_dir=self.root, + cache_dir_name=f"cache_{self.name}", + path_filename=path_filename, + cache=cache, + cache_type=cache_type ) - max_h = self.img_size[0] - max_w = self.img_size[1] - cache_file = os.path.join(self.root, f"img_resized_cache_{self.name}.array") - if not os.path.exists(cache_file): - logger.info( - "Caching images for the first time. This might take about 3 minutes for VOC" - ) - self.imgs = np.memmap( - cache_file, - shape=(len(self.ids), max_h, max_w, 3), - dtype=np.uint8, - mode="w+", - ) - from tqdm import tqdm - from multiprocessing.pool import ThreadPool - NUM_THREADs = min(8, os.cpu_count()) - loaded_images = ThreadPool(NUM_THREADs).imap( - lambda x: self.load_resized_img(x), - range(len(self.annotations)), - ) - pbar = tqdm(enumerate(loaded_images), total=len(self.annotations)) - for k, out in pbar: - self.imgs[k][: out.shape[0], : out.shape[1], :] = out.copy() - self.imgs.flush() - pbar.close() - else: - logger.warning( - "You are using cached imgs! Make sure your dataset is not changed!!\n" - "Everytime the self.input_size is changed in your exp file, you need to delete\n" - "the cached data and re-generate them.\n" - ) + def __len__(self): + return self.num_imgs - logger.info("Loading cached imgs...") - self.imgs = np.memmap( - cache_file, - shape=(len(self.ids), max_h, max_w, 3), - dtype=np.uint8, - mode="r+", - ) + def _load_coco_annotations(self): + return [self.load_anno_from_ids(_ids) for _ids in range(self.num_imgs)] def load_anno_from_ids(self, index): img_id = self.ids[index] @@ -227,6 +189,10 @@ class VOCDetection(Dataset): return img + @cache_read_img + def read_img(self, index, use_cache=True): + return self.load_resized_img(index) + def pull_item(self, index): """Returns the original image and target at an index for mixup @@ -238,17 +204,12 @@ class VOCDetection(Dataset): Return: img, target """ - if self.imgs is not None: - target, img_info, resized_info = self.annotations[index] - pad_img = self.imgs[index] - img = pad_img[: resized_info[0], : resized_info[1], :].copy() - else: - img = self.load_resized_img(index) - target, img_info, _ = self.annotations[index] + target, img_info, _ = self.annotations[index] + img = self.read_img(index) return img, target, img_info, index - @Dataset.mosaic_getitem + @CacheDataset.mosaic_getitem def __getitem__(self, index): img, target, img_info, img_id = self.pull_item(index) diff --git a/yolox/evaluators/coco_evaluator.py b/yolox/evaluators/coco_evaluator.py index 080fe9f135e6f18dc2064f3a8d3be0eb1bad1b94..e218c745624e5330dbae37dcac60f83052bf2f31 100644 --- a/yolox/evaluators/coco_evaluator.py +++ b/yolox/evaluators/coco_evaluator.py @@ -90,8 +90,8 @@ class COCOEvaluator: nmsthre: float, num_classes: int, testdev: bool = False, - per_class_AP: bool = False, - per_class_AR: bool = False, + per_class_AP: bool = True, + per_class_AR: bool = True, ): """ Args: @@ -101,8 +101,8 @@ class COCOEvaluator: confthre: confidence threshold ranging from 0 to 1, which is defined in the config file. nmsthre: IoU threshold of non-max supression ranging from 0 to 1. - per_class_AP: Show per class AP during evalution or not. Default to False. - per_class_AR: Show per class AR during evalution or not. Default to False. + per_class_AP: Show per class AP during evalution or not. Default to True. + per_class_AR: Show per class AR during evalution or not. Default to True. """ self.dataloader = dataloader self.img_size = img_size @@ -188,6 +188,9 @@ class COCOEvaluator: statistics = torch.cuda.FloatTensor([inference_time, nms_time, n_samples]) if distributed: + # different process/device might have different speed, + # to make sure the process will not be stucked, sync func is used here. + synchronize() data_list = gather(data_list, dst=0) output_data = gather(output_data, dst=0) data_list = list(itertools.chain(*data_list)) diff --git a/yolox/exp/base_exp.py b/yolox/exp/base_exp.py index e26ae079ce9324e1d7b8759b9f0c71a385accca1..5429391b1876912c3bc4b4eacc75e5ec9e6ac982 100644 --- a/yolox/exp/base_exp.py +++ b/yolox/exp/base_exp.py @@ -22,11 +22,16 @@ class BaseExp(metaclass=ABCMeta): self.output_dir = "./YOLOX_outputs" self.print_interval = 100 self.eval_interval = 10 + self.dataset = None @abstractmethod def get_model(self) -> Module: pass + @abstractmethod + def get_dataset(self, cache: bool = False, cache_type: str = "ram"): + pass + @abstractmethod def get_data_loader( self, batch_size: int, is_distributed: bool diff --git a/yolox/exp/yolox_base.py b/yolox/exp/yolox_base.py index bfd9d7c2e6bfb6385a6bf4b314be6fb4fa5163f0..c467c6e59fbc9f10f3e335259f41544896f3092b 100644 --- a/yolox/exp/yolox_base.py +++ b/yolox/exp/yolox_base.py @@ -106,23 +106,6 @@ class Exp(BaseExp): self.test_conf = 0.01 # nms threshold self.nmsthre = 0.65 - self.cache_dataset = None - self.dataset = None - - def create_cache_dataset(self, cache_type: str = "ram"): - from yolox.data import COCODataset, TrainTransform - self.cache_dataset = COCODataset( - data_dir=self.data_dir, - json_file=self.train_ann, - img_size=self.input_size, - preproc=TrainTransform( - max_labels=50, - flip_prob=self.flip_prob, - hsv_prob=self.hsv_prob - ), - cache=True, - cache_type=cache_type, - ) def get_model(self): from yolox.models import YOLOX, YOLOPAFPN, YOLOXHead @@ -144,6 +127,30 @@ class Exp(BaseExp): self.model.train() return self.model + def get_dataset(self, cache: bool = False, cache_type: str = "ram"): + """ + Get dataset according to cache and cache_type parameters. + Args: + cache (bool): Whether to cache imgs to ram or disk. + cache_type (str, optional): Defaults to "ram". + "ram" : Caching imgs to ram for fast training. + "disk": Caching imgs to disk for fast training. + """ + from yolox.data import COCODataset, TrainTransform + + return COCODataset( + data_dir=self.data_dir, + json_file=self.train_ann, + img_size=self.input_size, + preproc=TrainTransform( + max_labels=50, + flip_prob=self.flip_prob, + hsv_prob=self.hsv_prob + ), + cache=cache, + cache_type=cache_type, + ) + def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img: str = None): """ Get dataloader according to cache_img parameter. @@ -155,7 +162,6 @@ class Exp(BaseExp): None: Do not use cache, in this case cache_data is also None. """ from yolox.data import ( - COCODataset, TrainTransform, YoloBatchSampler, DataLoader, @@ -165,25 +171,16 @@ class Exp(BaseExp): ) from yolox.utils import wait_for_the_master - with wait_for_the_master(): - if self.cache_dataset is None: - assert cache_img is None, "cache is True, but cache_dataset is None" - dataset = COCODataset( - data_dir=self.data_dir, - json_file=self.train_ann, - img_size=self.input_size, - preproc=TrainTransform( - max_labels=50, - flip_prob=self.flip_prob, - hsv_prob=self.hsv_prob), - cache=False, - cache_type=cache_img, - ) - else: - dataset = self.cache_dataset + # if cache is True, we will create self.dataset before launch + # else we will create self.dataset after launch + if self.dataset is None: + with wait_for_the_master(): + assert cache_img is None, \ + "cache_img must be None if you didn't create self.dataset before launch" + self.dataset = self.get_dataset(cache=False, cache_type=cache_img) self.dataset = MosaicDetection( - dataset, + dataset=self.dataset, mosaic=not no_aug, img_size=self.input_size, preproc=TrainTransform( @@ -298,10 +295,12 @@ class Exp(BaseExp): ) return scheduler - def get_eval_loader(self, batch_size, is_distributed, testdev=False, legacy=False): + def get_eval_dataset(self, **kwargs): from yolox.data import COCODataset, ValTransform + testdev = kwargs.get("testdev", False) + legacy = kwargs.get("legacy", False) - valdataset = COCODataset( + return COCODataset( data_dir=self.data_dir, json_file=self.val_ann if not testdev else self.test_ann, name="val2017" if not testdev else "test2017", @@ -309,6 +308,9 @@ class Exp(BaseExp): preproc=ValTransform(legacy=legacy), ) + def get_eval_loader(self, batch_size, is_distributed, **kwargs): + valdataset = self.get_eval_dataset(**kwargs) + if is_distributed: batch_size = batch_size // dist.get_world_size() sampler = torch.utils.data.distributed.DistributedSampler( @@ -330,16 +332,15 @@ class Exp(BaseExp): def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False): from yolox.evaluators import COCOEvaluator - val_loader = self.get_eval_loader(batch_size, is_distributed, testdev, legacy) - evaluator = COCOEvaluator( - dataloader=val_loader, + return COCOEvaluator( + dataloader=self.get_eval_loader(batch_size, is_distributed, + testdev=testdev, legacy=legacy), img_size=self.test_size, confthre=self.test_conf, nmsthre=self.nmsthre, num_classes=self.num_classes, testdev=testdev, ) - return evaluator def get_trainer(self, args): from yolox.core import Trainer