Unverified · Commit 6d30efe8 · authored by Yuang Peng · committed by GitHub

feat(data): support custom dataset cache (#1584)

Parent f15f1935
......@@ -122,9 +122,9 @@ python -m yolox.tools.train -n yolox-s -d 8 -b 64 --fp16 -o [--cache]
* -d: number of gpu devices
* -b: total batch size, the recommended number for -b is num-gpu * 8
* --fp16: mixed precision training
* --cache: caching imgs into RAM to accelerate training, which needs large system RAM (see the example below).
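For example, a hypothetical invocation that enables the RAM cache (after this commit, `--cache` also accepts `disk` to use the on-disk cache instead):
```shell
python -m yolox.tools.train -n yolox-s -d 8 -b 64 --fp16 -o --cache ram
```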
When using -f, the above commands are equivalent to:
```shell
......@@ -140,7 +140,8 @@ We also support multi-nodes training. Just add the following args:
* --num\_machines: num of your total training nodes
* --machine\_rank: specify the rank of each node
Suppose you want to train YOLOX on 2 machines, and your master machine's IP is 123.123.123.123, using port 12312 over TCP.
On master machine, run
```shell
python tools/train.py -n yolox-s -b 128 --dist-url tcp://123.123.123.123:12312 --num_machines 2 --machine_rank 0
......@@ -163,7 +164,8 @@ python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o [--cache] --logger wandb w
An example wandb dashboard is available [here](https://wandb.ai/manan-goel/yolox-nano/runs/3pzfeom0)
**Others**
See more information with the following command:
```shell
python -m yolox.tools.train --help
......@@ -202,6 +204,7 @@ python -m yolox.tools.eval -n yolox-s -c yolox_s.pth -b 1 -d 1 --conf 0.001 --f
<summary>Tutorials</summary>
* [Training on custom data](docs/train_custom_data.md)
* [Caching for custom data](docs/cache.md)
* [Manipulating training image size](docs/manipulate_training_image_size.md)
* [Freezing model](docs/freeze_module.md)
......
# Cache Custom Data
The caching feature is tailored to users with ample RAM. We also offer the option to cache data to disk, but disk performance varies and may not guarantee an optimal experience. RAM caching for a custom dataset is also more straightforward to implement than disk caching: with a few simple modifications, training can run nearly twice as fast as with an uncached dataset.
This page explains how to cache your own custom data with YOLOX.
## 0. Before you start
**Step1** Clone this repo and follow the [README](../README.md) to install YOLOX.
**Step2** Read the [Training on custom data](./train_custom_data.md) tutorial to understand how to prepare your custom data.
## 1. Inherit from `CacheDataset`
**Step1** Create a custom dataset class that inherits from `CacheDataset`. Note that whether you inherit from `Dataset` or `CacheDataset`, the `__init__()` method of your custom dataset should accept the keyword arguments `input_dimension`, `cache`, and `cache_type`. Also call `super().__init__()`, passing `input_dimension`, `num_imgs`, `cache`, and `cache_type`, where `num_imgs` is the size of the dataset.
**Step2** Implement the abstract method `read_img(self, index, use_cache=True)` of the parent class and decorate it with `@cache_read_img`. The function takes an `index` and returns an `image`; the returned image is what gets cached. It is recommended to put all repetitive, fixed post-processing of the image into this function, which reduces per-image post-processing time during training.
```python
# CustomDataset.py
from yolox.data.datasets import CacheDataset, cache_read_img

class CustomDataset(CacheDataset):
    def __init__(self, input_dimension, cache, cache_type, *args, **kwargs):
        # Get the required keyword arguments of super().__init__();
        # num_imgs must be the size of your dataset, computed before the call
        num_imgs = ...  # e.g. the number of entries in your annotation file
        super().__init__(
            input_dimension=input_dimension,
            num_imgs=num_imgs,
            cache=cache,
            cache_type=cache_type
        )
        # ...

    @cache_read_img(use_cache=True)
    def read_img(self, index, use_cache=True):
        # get image ...
        # (optional) repetitive and fixed post-processing operations for image
        return image
```
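The `@cache_read_img` decorator is what makes the cache transparent: once the cache is populated, `read_img(index)` returns the cached image (deep-copied for the RAM cache, `np.load`-ed for the disk cache), and the function body itself only runs while the cache is being built (`use_cache=False`) or when caching is disabled.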
## 2. Create your Exp file and return your custom dataset
**Step1** Create a new class that inherits from the `Exp` class provided by `yolox_base.py`. Override the `get_dataset()` and `get_eval_dataset()` methods to return an instance of your custom dataset.
**Step2** Implement your own `get_evaluator` method to return an instance of your custom evaluator.
```python
# CustomExp.py
from yolox.exp import Exp as MyExp
from CustomDataset import CustomDataset  # the dataset defined in section 1

class Exp(MyExp):
    def get_dataset(self, cache, cache_type: str = "ram"):
        return CustomDataset(
            input_dimension=self.input_size,
            cache=cache,
            cache_type=cache_type
        )

    def get_eval_dataset(self):
        return CustomDataset(
            input_dimension=self.input_size,
        )

    def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False):
        # CustomEvaluator is your own evaluator from Step2
        return CustomEvaluator(
            dataloader=self.get_eval_loader(batch_size, is_distributed, testdev=testdev, legacy=legacy),
            img_size=self.test_size,
            confthre=self.test_conf,
            nmsthre=self.nmsthre,
            num_classes=self.num_classes,
            testdev=testdev,
        )
```
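With the Exp file in place, training runs the same way as for the built-in datasets. A hypothetical launch command (the Exp file name comes from this example; `--cache ram` or `--cache disk` selects the cache type via the updated `tools/train.py`):
```shell
python -m yolox.tools.train -f CustomExp.py -d 8 -b 64 --fp16 -o --cache ram
```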
**(Optional)** `get_data_loader` and `get_eval_loader` now have default implementations in `yolox_base.py` and generally do not need to be overridden. If you must override `get_data_loader`, add the following code at its beginning; `wait_for_the_master` ensures the dataset (and its cache) is built by the master process first while the other ranks wait.
```python
# CustomExp.py
from yolox.exp import Exp as MyExp
from yolox.utils import wait_for_the_master

class Exp(MyExp):
    def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img: str = None):
        if self.dataset is None:
            with wait_for_the_master():
                assert cache_img is None
                self.dataset = self.get_dataset(cache=False, cache_type=cache_img)
        # ...
```
## 3. Cache to Disk
It's important to note that the `cache_type` can be `"ram"` or `"disk"`, depending on where you want to cache your dataset. If you choose `"disk"`, you need to pass in additional parameters to `super().__init__()` of `CustomDataset`: `data_dir`, `cache_dir_name`, `path_filename`.
- `data_dir`: the root directory of the dataset, e.g. `/path/to/COCO`.
- `cache_dir_name`: the name of the directory to cache to disk, for example `"custom_cache"`, then the files cached to disk will be saved under `/path/to/COCO/custom_cache`.
- `path_filename`: a list of paths to the data relative to `data_dir`, e.g. if you have data `/path/to/COCO/train/1.jpg` and `/path/to/COCO/train/2.jpg`, then `path_filename = ['train/1.jpg', 'train/2.jpg']` (see the sketch below).
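Putting the pieces together, below is a minimal sketch of a `CustomDataset` that supports disk caching. The dataset root, file list, and cache directory name are illustrative placeholders, and the placeholder `read_img` body must be replaced with your real image loading:
```python
# CustomDataset.py -- illustrative sketch, not a drop-in implementation
import numpy as np

from yolox.data.datasets import CacheDataset, cache_read_img


class CustomDataset(CacheDataset):

    def __init__(self, input_dimension, cache, cache_type, *args, **kwargs):
        data_dir = "/path/to/COCO"                      # dataset root (placeholder)
        path_filename = ["train/1.jpg", "train/2.jpg"]  # paths relative to data_dir (placeholder)
        self.num_imgs = len(path_filename)
        super().__init__(
            input_dimension=input_dimension,
            num_imgs=self.num_imgs,
            data_dir=data_dir,
            cache_dir_name="custom_cache",  # cached .npy files land in /path/to/COCO/custom_cache
            path_filename=path_filename,
            cache=cache,
            cache_type=cache_type,          # "ram" or "disk"
        )

    @cache_read_img(use_cache=True)
    def read_img(self, index, use_cache=True):
        # Replace this placeholder with real loading + resizing + fixed post-processing.
        return np.zeros((480, 640, 3), dtype=np.uint8)
```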
# encoding: utf-8
import os
import torch
import torch.distributed as dist
from yolox.data import get_yolox_datadir
from yolox.exp import Exp as MyExp
......@@ -24,115 +21,40 @@ class Exp(MyExp):
self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
-    def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img=False):
-        from yolox.data import (
-            VOCDetection,
-            TrainTransform,
-            YoloBatchSampler,
-            DataLoader,
-            InfiniteSampler,
-            MosaicDetection,
-            worker_init_reset_seed,
-        )
-        from yolox.utils import (
-            wait_for_the_master,
-            get_local_rank,
-        )
-        local_rank = get_local_rank()
-
-        with wait_for_the_master(local_rank):
-            dataset = VOCDetection(
-                data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"),
-                image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
-                img_size=self.input_size,
-                preproc=TrainTransform(
-                    max_labels=50,
-                    flip_prob=self.flip_prob,
-                    hsv_prob=self.hsv_prob),
-                cache=cache_img,
-            )
-
-        dataset = MosaicDetection(
-            dataset,
-            mosaic=not no_aug,
-            img_size=self.input_size,
-            preproc=TrainTransform(
-                max_labels=120,
-                flip_prob=self.flip_prob,
-                hsv_prob=self.hsv_prob),
-            degrees=self.degrees,
-            translate=self.translate,
-            mosaic_scale=self.mosaic_scale,
-            mixup_scale=self.mixup_scale,
-            shear=self.shear,
-            enable_mixup=self.enable_mixup,
-            mosaic_prob=self.mosaic_prob,
-            mixup_prob=self.mixup_prob,
-        )
-
-        self.dataset = dataset
-
-        if is_distributed:
-            batch_size = batch_size // dist.get_world_size()
-
-        sampler = InfiniteSampler(
-            len(self.dataset), seed=self.seed if self.seed else 0
-        )
-
-        batch_sampler = YoloBatchSampler(
-            sampler=sampler,
-            batch_size=batch_size,
-            drop_last=False,
-            mosaic=not no_aug,
-        )
-
-        dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
-        dataloader_kwargs["batch_sampler"] = batch_sampler
-
-        # Make sure each process has different random seed, especially for 'fork' method
-        dataloader_kwargs["worker_init_fn"] = worker_init_reset_seed
-
-        train_loader = DataLoader(self.dataset, **dataloader_kwargs)
-
-        return train_loader
-
-    def get_eval_loader(self, batch_size, is_distributed, testdev=False, legacy=False):
-        from yolox.data import VOCDetection, ValTransform
-
-        valdataset = VOCDetection(
-            data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"),
-            image_sets=[('2007', 'test')],
-            img_size=self.test_size,
-            preproc=ValTransform(legacy=legacy),
-        )
-
-        if is_distributed:
-            batch_size = batch_size // dist.get_world_size()
-            sampler = torch.utils.data.distributed.DistributedSampler(
-                valdataset, shuffle=False
-            )
-        else:
-            sampler = torch.utils.data.SequentialSampler(valdataset)
-
-        dataloader_kwargs = {
-            "num_workers": self.data_num_workers,
-            "pin_memory": True,
-            "sampler": sampler,
-        }
-        dataloader_kwargs["batch_size"] = batch_size
-        val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)
-
-        return val_loader
-
-    def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False):
-        from yolox.evaluators import VOCEvaluator
-
-        val_loader = self.get_eval_loader(batch_size, is_distributed, testdev, legacy)
-        evaluator = VOCEvaluator(
-            dataloader=val_loader,
-            img_size=self.test_size,
-            confthre=self.test_conf,
-            nmsthre=self.nmsthre,
-            num_classes=self.num_classes,
-        )
-
-        return evaluator
+    def get_dataset(self, cache: bool, cache_type: str = "ram"):
+        from yolox.data import VOCDetection, TrainTransform
+
+        return VOCDetection(
+            data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"),
+            image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
+            img_size=self.input_size,
+            preproc=TrainTransform(
+                max_labels=50,
+                flip_prob=self.flip_prob,
+                hsv_prob=self.hsv_prob),
+            cache=cache,
+            cache_type=cache_type,
+        )
+
+    def get_eval_dataset(self, **kwargs):
+        from yolox.data import VOCDetection, ValTransform
+        legacy = kwargs.get("legacy", False)
+
+        return VOCDetection(
+            data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"),
+            image_sets=[('2007', 'test')],
+            img_size=self.test_size,
+            preproc=ValTransform(legacy=legacy),
+        )
+
+    def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False):
+        from yolox.evaluators import VOCEvaluator
+
+        return VOCEvaluator(
+            dataloader=self.get_eval_loader(batch_size, is_distributed,
+                                            testdev=testdev, legacy=legacy),
+            img_size=self.test_size,
+            confthre=self.test_conf,
+            nmsthre=self.nmsthre,
+            num_classes=self.num_classes,
+        )
......@@ -131,7 +131,7 @@ if __name__ == "__main__":
assert num_gpu <= get_num_devices()
if args.cache is not None:
-        exp.create_cache_dataset(args.cache)
+        exp.dataset = exp.get_dataset(cache=True, cache_type=args.cache)
dist_url = "auto" if args.dist_url is None else args.dist_url
launch(
......
......@@ -4,6 +4,6 @@
from .coco import COCODataset
from .coco_classes import COCO_CLASSES
-from .datasets_wrapper import ConcatDataset, Dataset, MixConcatDataset
+from .datasets_wrapper import CacheDataset, ConcatDataset, Dataset, MixConcatDataset
from .mosaicdetection import MosaicDetection
from .voc import VOCDetection
......@@ -3,18 +3,13 @@
# Copyright (c) Megvii, Inc. and its affiliates.
 import copy
 import os
-import random
-from multiprocessing.pool import ThreadPool
-
-import psutil
-from loguru import logger
-from tqdm import tqdm

 import cv2
 import numpy as np
 from pycocotools.coco import COCO

 from ..dataloading import get_yolox_datadir
-from .datasets_wrapper import Dataset
+from .datasets_wrapper import CacheDataset, cache_read_img
def remove_useless_info(coco):
......@@ -36,7 +31,7 @@ def remove_useless_info(coco):
anno.pop("segmentation", None)
-class COCODataset(Dataset):
+class COCODataset(CacheDataset):
"""
COCO dataset class.
"""
......@@ -60,7 +55,6 @@ class COCODataset(Dataset):
img_size (int): target image size after pre-processing
preproc: data augmentation strategy
"""
-        super().__init__(img_size)
if data_dir is None:
data_dir = os.path.join(get_yolox_datadir(), "COCO")
self.data_dir = data_dir
......@@ -77,85 +71,21 @@ class COCODataset(Dataset):
self.img_size = img_size
self.preproc = preproc
self.annotations = self._load_coco_annotations()
-        self.imgs = None
-        self.cache = cache
-        self.cache_type = cache_type
-
-        if self.cache:
-            self._cache_images()
-
-    def _cache_images(self):
-        mem = psutil.virtual_memory()
-        mem_required = self.cal_cache_ram()
-        gb = 1 << 30
-
-        if self.cache_type == "ram" and mem_required > mem.available:
-            self.cache = False
-        else:
-            logger.info(
-                f"{mem_required / gb:.1f}GB RAM required, "
-                f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB RAM available, "
-                f"Since the first thing we do is cache, "
-                f"there is no guarantee that the remaining memory space is sufficient"
-            )
-
-        if self.imgs is None:
-            if self.cache_type == 'ram':
-                self.imgs = [None] * self.num_imgs
-                logger.info("You are using cached images in RAM to accelerate training!")
-            else:   # 'disk'
-                self.cache_dir = os.path.join(
-                    self.data_dir,
-                    f"{self.name}_cache{self.img_size[0]}x{self.img_size[1]}"
-                )
-                if not os.path.exists(self.cache_dir):
-                    os.mkdir(self.cache_dir)
-                    logger.warning(
-                        f"\n*******************************************************************\n"
-                        f"You are using cached images in DISK to accelerate training.\n"
-                        f"This requires large DISK space.\n"
-                        f"Make sure you have {mem_required / gb:.1f} "
-                        f"available DISK space for training COCO.\n"
-                        f"*******************************************************************\\n"
-                    )
-                else:
-                    logger.info("Found disk cache!")
-                    return
-
-            logger.info(
-                "Caching images for the first time. "
-                "This might take about 15 minutes for COCO"
-            )
-
-            num_threads = min(8, max(1, os.cpu_count() - 1))
-            b = 0
-            load_imgs = ThreadPool(num_threads).imap(self.load_resized_img, range(self.num_imgs))
-            pbar = tqdm(enumerate(load_imgs), total=self.num_imgs)
-            for i, x in pbar:   # x = self.load_resized_img(self, i)
-                if self.cache_type == 'ram':
-                    self.imgs[i] = x
-                else:   # 'disk'
-                    cache_filename = f'{self.annotations[i]["filename"].split(".")[0]}.npy'
-                    np.save(os.path.join(self.cache_dir, cache_filename), x)
-                b += x.nbytes
-                pbar.desc = f'Caching images ({b / gb:.1f}/{mem_required / gb:.1f}GB {self.cache})'
-            pbar.close()
-
-    def cal_cache_ram(self):
-        cache_bytes = 0
-        num_samples = min(self.num_imgs, 32)
-        for _ in range(num_samples):
-            img = self.load_resized_img(random.randint(0, self.num_imgs - 1))
-            cache_bytes += img.nbytes
-        mem_required = cache_bytes * self.num_imgs / num_samples
-        return mem_required
+        path_filename = [os.path.join(name, anno[3]) for anno in self.annotations]
+        super().__init__(
+            input_dimension=img_size,
+            num_imgs=self.num_imgs,
+            data_dir=data_dir,
+            cache_dir_name=f"cache_{name}",
+            path_filename=path_filename,
+            cache=cache,
+            cache_type=cache_type
+        )
def __len__(self):
return self.num_imgs
-    def __del__(self):
-        del self.imgs
def _load_coco_annotations(self):
return [self.load_anno_from_ids(_ids) for _ids in self.ids]
......@@ -220,20 +150,18 @@ class COCODataset(Dataset):
return img
+    @cache_read_img(use_cache=True)
+    def read_img(self, index):
+        return self.load_resized_img(index)
+
     def pull_item(self, index):
         id_ = self.ids[index]
-        label, origin_image_size, _, filename = self.annotations[index]
-        if self.cache and self.cache_type == 'ram':
-            img = self.imgs[index]
-        elif self.cache and self.cache_type == 'disk':
-            img = np.load(os.path.join(self.cache_dir, f"{filename.split('.')[0]}.npy"))
-        else:
-            img = self.load_resized_img(index)
+        label, origin_image_size, _, _ = self.annotations[index]
+        img = self.read_img(index)

-        return copy.deepcopy(img), copy.deepcopy(label), origin_image_size, np.array([id_])
+        return img, copy.deepcopy(label), origin_image_size, np.array([id_])

-    @Dataset.mosaic_getitem
+    @CacheDataset.mosaic_getitem
def __getitem__(self, index):
"""
One image / label pair for the given index is picked up and pre-processed.
......
......@@ -3,7 +3,17 @@
# Copyright (c) Megvii, Inc. and its affiliates.
import bisect
-from functools import wraps
+import copy
+import os
+import random
+from abc import ABCMeta, abstractmethod
+from functools import partial, wraps
+from multiprocessing.pool import ThreadPool
+
+import psutil
+from loguru import logger
+from tqdm import tqdm
+
+import numpy as np
from torch.utils.data.dataset import ConcatDataset as torchConcatDataset
from torch.utils.data.dataset import Dataset as torchDataset
......@@ -112,3 +122,179 @@ class Dataset(torchDataset):
return ret_val
return wrapper
class CacheDataset(Dataset, metaclass=ABCMeta):
    """ This class is a subclass of the base :class:`yolox.data.datasets.Dataset`,
    that enables caching images to RAM or disk.

    Args:
        input_dimension (tuple): (width, height) tuple with default dimensions of the network
        num_imgs (int): dataset size
        data_dir (str): the root directory of the dataset, e.g. `/path/to/COCO`.
        cache_dir_name (str): the name of the directory to cache to disk,
            e.g. `"custom_cache"`. The files cached to disk will be saved
            under `/path/to/COCO/custom_cache`.
        path_filename (list[str]): a list of paths to the data relative to the `data_dir`,
            e.g. if you have data `/path/to/COCO/train/1.jpg`, `/path/to/COCO/train/2.jpg`,
            then `path_filename = ['train/1.jpg', 'train/2.jpg']`.
        cache (bool): whether to cache the images to RAM or disk.
        cache_type (str): the type of cache,
            "ram" : Caching imgs to RAM for fast training.
            "disk": Caching imgs to disk for fast training.
    """

    def __init__(
        self,
        input_dimension,
        num_imgs=None,
        data_dir=None,
        cache_dir_name=None,
        path_filename=None,
        cache=False,
        cache_type="ram",
    ):
        super().__init__(input_dimension)
        self.cache = cache
        self.cache_type = cache_type

        if self.cache and self.cache_type == "disk":
            self.cache_dir = os.path.join(data_dir, cache_dir_name)
            self.path_filename = path_filename

        if self.cache and self.cache_type == "ram":
            self.imgs = None

        if self.cache:
            self.cache_images(
                num_imgs=num_imgs,
                data_dir=data_dir,
                cache_dir_name=cache_dir_name,
                path_filename=path_filename,
            )

    def __del__(self):
        if self.cache and self.cache_type == "ram":
            del self.imgs

    @abstractmethod
    def read_img(self, index):
        """
        Given index, return the corresponding image

        Args:
            index (int): image index
        """
        raise NotImplementedError

    def cache_images(
        self,
        num_imgs=None,
        data_dir=None,
        cache_dir_name=None,
        path_filename=None,
    ):
        assert num_imgs is not None, "num_imgs must be specified as the size of the dataset"
        if self.cache_type == "disk":
            assert data_dir is not None and cache_dir_name is not None \
                and path_filename is not None, \
                "data_dir, cache_dir_name and path_filename must be specified " \
                "if cache_type is disk"
            self.path_filename = path_filename

        mem = psutil.virtual_memory()
        mem_required = self.cal_cache_occupy(num_imgs)
        gb = 1 << 30

        if self.cache_type == "ram":
            if mem_required > mem.available:
                self.cache = False
            else:
                logger.info(
                    f"{mem_required / gb:.1f}GB RAM required, "
                    f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB RAM available, "
                    f"Since the first thing we do is cache, "
                    f"there is no guarantee that the remaining memory space is sufficient"
                )

        if self.cache and self.imgs is None:
            if self.cache_type == 'ram':
                self.imgs = [None] * num_imgs
                logger.info("You are using cached images in RAM to accelerate training!")
            else:   # 'disk'
                if not os.path.exists(self.cache_dir):
                    os.mkdir(self.cache_dir)
                    logger.warning(
                        f"\n*******************************************************************\n"
                        f"You are using cached images in DISK to accelerate training.\n"
                        f"This requires large DISK space.\n"
                        f"Make sure you have {mem_required / gb:.1f} "
                        f"available DISK space for training your dataset.\n"
                        f"*******************************************************************\n"
                    )
                else:
                    logger.info(f"Found disk cache at {self.cache_dir}")
                    return

            logger.info(
                "Caching images...\n"
                "This might take some time for your dataset"
            )

            num_threads = min(8, max(1, os.cpu_count() - 1))
            b = 0
            load_imgs = ThreadPool(num_threads).imap(
                partial(self.read_img, use_cache=False),
                range(num_imgs)
            )
            pbar = tqdm(enumerate(load_imgs), total=num_imgs)
            for i, x in pbar:   # x = self.read_img(self, i, use_cache=False)
                if self.cache_type == 'ram':
                    self.imgs[i] = x
                else:   # 'disk'
                    cache_filename = f'{self.path_filename[i].split(".")[0]}.npy'
                    cache_path_filename = os.path.join(self.cache_dir, cache_filename)
                    os.makedirs(os.path.dirname(cache_path_filename), exist_ok=True)
                    np.save(cache_path_filename, x)
                b += x.nbytes
                pbar.desc = \
                    f'Caching images ({b / gb:.1f}/{mem_required / gb:.1f}GB {self.cache_type})'
            pbar.close()

    def cal_cache_occupy(self, num_imgs):
        cache_bytes = 0
        num_samples = min(num_imgs, 32)
        for _ in range(num_samples):
            img = self.read_img(index=random.randint(0, num_imgs - 1), use_cache=False)
            cache_bytes += img.nbytes
        mem_required = cache_bytes * num_imgs / num_samples
        return mem_required


def cache_read_img(use_cache=True):
    def decorator(read_img_fn):
        """
        Decorate the read_img function to cache the image

        Args:
            read_img_fn: read_img function
            use_cache (bool, optional): For the decorated read_img function,
                whether to read the image from cache.
                Defaults to True.
        """
        @wraps(read_img_fn)
        def wrapper(self, index, use_cache=use_cache):
            cache = self.cache and use_cache
            if cache:
                if self.cache_type == "ram":
                    img = self.imgs[index]
                    img = copy.deepcopy(img)
                elif self.cache_type == "disk":
                    img = np.load(
                        os.path.join(
                            self.cache_dir, f"{self.path_filename[index].split('.')[0]}.npy"))
                else:
                    raise ValueError(f"Unknown cache type: {self.cache_type}")
            else:
                img = read_img_fn(self, index)
            return img
        return wrapper
    return decorator
......@@ -10,14 +10,13 @@ import os
import os.path
import pickle
import xml.etree.ElementTree as ET
-from loguru import logger
import cv2
import numpy as np
from yolox.evaluators.voc_eval import voc_eval
-from .datasets_wrapper import Dataset
+from .datasets_wrapper import CacheDataset, cache_read_img
from .voc_classes import VOC_CLASSES
......@@ -80,7 +79,7 @@ class AnnotationTransform(object):
return res, img_info
-class VOCDetection(Dataset):
+class VOCDetection(CacheDataset):
"""
VOC Detection Dataset Object
......@@ -108,8 +107,8 @@ class VOCDetection(Dataset):
         target_transform=AnnotationTransform(),
         dataset_name="VOC0712",
         cache=False,
+        cache_type="ram",
     ):
-        super().__init__(img_size)
         self.root = data_dir
         self.image_set = image_sets
         self.img_size = img_size
......@@ -131,66 +130,29 @@ class VOCDetection(Dataset):
os.path.join(rootpath, "ImageSets", "Main", name + ".txt")
):
self.ids.append((rootpath, line.strip()))
+        self.num_imgs = len(self.ids)
+
         self.annotations = self._load_coco_annotations()
-        self.imgs = None
-
-        if cache:
-            self._cache_images()
-
-    def __len__(self):
-        return len(self.ids)
-
-    def _load_coco_annotations(self):
-        return [self.load_anno_from_ids(_ids) for _ids in range(len(self.ids))]
-
-    def _cache_images(self):
-        logger.warning(
-            "\n********************************************************************************\n"
-            "You are using cached images in RAM to accelerate training.\n"
-            "This requires large system RAM.\n"
-            "Make sure you have 60G+ RAM and 19G available disk space for training VOC.\n"
-            "********************************************************************************\n"
-        )
-        max_h = self.img_size[0]
-        max_w = self.img_size[1]
-        cache_file = os.path.join(self.root, f"img_resized_cache_{self.name}.array")
-        if not os.path.exists(cache_file):
-            logger.info(
-                "Caching images for the first time. This might take about 3 minutes for VOC"
-            )
-            self.imgs = np.memmap(
-                cache_file,
-                shape=(len(self.ids), max_h, max_w, 3),
-                dtype=np.uint8,
-                mode="w+",
-            )
-            from tqdm import tqdm
-            from multiprocessing.pool import ThreadPool
-
-            NUM_THREADs = min(8, os.cpu_count())
-            loaded_images = ThreadPool(NUM_THREADs).imap(
-                lambda x: self.load_resized_img(x),
-                range(len(self.annotations)),
-            )
-            pbar = tqdm(enumerate(loaded_images), total=len(self.annotations))
-            for k, out in pbar:
-                self.imgs[k][: out.shape[0], : out.shape[1], :] = out.copy()
-            self.imgs.flush()
-            pbar.close()
-        else:
-            logger.warning(
-                "You are using cached imgs! Make sure your dataset is not changed!!\n"
-                "Everytime the self.input_size is changed in your exp file, you need to delete\n"
-                "the cached data and re-generate them.\n"
-            )
-            logger.info("Loading cached imgs...")
-            self.imgs = np.memmap(
-                cache_file,
-                shape=(len(self.ids), max_h, max_w, 3),
-                dtype=np.uint8,
-                mode="r+",
-            )
+        path_filename = [
+            (self._imgpath % self.ids[i]).split(self.root + "/")[1]
+            for i in range(self.num_imgs)
+        ]
+        super().__init__(
+            input_dimension=img_size,
+            num_imgs=self.num_imgs,
+            data_dir=self.root,
+            cache_dir_name=f"cache_{self.name}",
+            path_filename=path_filename,
+            cache=cache,
+            cache_type=cache_type
+        )
+
+    def __len__(self):
+        return self.num_imgs
+
+    def _load_coco_annotations(self):
+        return [self.load_anno_from_ids(_ids) for _ids in range(self.num_imgs)]
def load_anno_from_ids(self, index):
img_id = self.ids[index]
......@@ -227,6 +189,10 @@ class VOCDetection(Dataset):
return img
+    @cache_read_img(use_cache=True)
+    def read_img(self, index, use_cache=True):
+        return self.load_resized_img(index)
def pull_item(self, index):
"""Returns the original image and target at an index for mixup
......@@ -238,17 +204,12 @@ class VOCDetection(Dataset):
Return:
img, target
"""
-        if self.imgs is not None:
-            target, img_info, resized_info = self.annotations[index]
-            pad_img = self.imgs[index]
-            img = pad_img[: resized_info[0], : resized_info[1], :].copy()
-        else:
-            img = self.load_resized_img(index)
-            target, img_info, _ = self.annotations[index]
+        target, img_info, _ = self.annotations[index]
+        img = self.read_img(index)
return img, target, img_info, index
-    @Dataset.mosaic_getitem
+    @CacheDataset.mosaic_getitem
def __getitem__(self, index):
img, target, img_info, img_id = self.pull_item(index)
......
......@@ -90,8 +90,8 @@ class COCOEvaluator:
nmsthre: float,
num_classes: int,
testdev: bool = False,
-        per_class_AP: bool = False,
-        per_class_AR: bool = False,
+        per_class_AP: bool = True,
+        per_class_AR: bool = True,
):
"""
Args:
......@@ -101,8 +101,8 @@ class COCOEvaluator:
confthre: confidence threshold ranging from 0 to 1, which
is defined in the config file.
             nmsthre: IoU threshold of non-max suppression ranging from 0 to 1.
-            per_class_AP: Show per class AP during evaluation or not. Default to False.
-            per_class_AR: Show per class AR during evaluation or not. Default to False.
+            per_class_AP: Show per class AP during evaluation or not. Default to True.
+            per_class_AR: Show per class AR during evaluation or not. Default to True.
"""
self.dataloader = dataloader
self.img_size = img_size
......@@ -188,6 +188,9 @@ class COCOEvaluator:
statistics = torch.cuda.FloatTensor([inference_time, nms_time, n_samples])
if distributed:
+            # different process/device might have different speed,
+            # to make sure the process will not be stuck, sync func is used here.
+            synchronize()
data_list = gather(data_list, dst=0)
output_data = gather(output_data, dst=0)
data_list = list(itertools.chain(*data_list))
......
......@@ -22,11 +22,16 @@ class BaseExp(metaclass=ABCMeta):
         self.output_dir = "./YOLOX_outputs"
         self.print_interval = 100
         self.eval_interval = 10
+        self.dataset = None

     @abstractmethod
     def get_model(self) -> Module:
         pass

+    @abstractmethod
+    def get_dataset(self, cache: bool = False, cache_type: str = "ram"):
+        pass
+
     @abstractmethod
     def get_data_loader(
         self, batch_size: int, is_distributed: bool
......
......@@ -106,23 +106,6 @@ class Exp(BaseExp):
self.test_conf = 0.01
# nms threshold
self.nmsthre = 0.65
-        self.cache_dataset = None
-        self.dataset = None
-
-    def create_cache_dataset(self, cache_type: str = "ram"):
-        from yolox.data import COCODataset, TrainTransform
-        self.cache_dataset = COCODataset(
-            data_dir=self.data_dir,
-            json_file=self.train_ann,
-            img_size=self.input_size,
-            preproc=TrainTransform(
-                max_labels=50,
-                flip_prob=self.flip_prob,
-                hsv_prob=self.hsv_prob
-            ),
-            cache=True,
-            cache_type=cache_type,
-        )
def get_model(self):
from yolox.models import YOLOX, YOLOPAFPN, YOLOXHead
......@@ -144,6 +127,30 @@ class Exp(BaseExp):
self.model.train()
return self.model
    def get_dataset(self, cache: bool = False, cache_type: str = "ram"):
        """
        Get dataset according to cache and cache_type parameters.

        Args:
            cache (bool): Whether to cache imgs to RAM or disk.
            cache_type (str, optional): Defaults to "ram".
                "ram" : Caching imgs to RAM for fast training.
                "disk": Caching imgs to disk for fast training.
        """
        from yolox.data import COCODataset, TrainTransform

        return COCODataset(
            data_dir=self.data_dir,
            json_file=self.train_ann,
            img_size=self.input_size,
            preproc=TrainTransform(
                max_labels=50,
                flip_prob=self.flip_prob,
                hsv_prob=self.hsv_prob
            ),
            cache=cache,
            cache_type=cache_type,
        )
def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img: str = None):
"""
Get dataloader according to cache_img parameter.
......@@ -155,7 +162,6 @@ class Exp(BaseExp):
None: Do not use cache, in this case cache_data is also None.
"""
from yolox.data import (
-            COCODataset,
TrainTransform,
YoloBatchSampler,
DataLoader,
......@@ -165,25 +171,16 @@ class Exp(BaseExp):
)
from yolox.utils import wait_for_the_master
-        with wait_for_the_master():
-            if self.cache_dataset is None:
-                assert cache_img is None, "cache is True, but cache_dataset is None"
-                dataset = COCODataset(
-                    data_dir=self.data_dir,
-                    json_file=self.train_ann,
-                    img_size=self.input_size,
-                    preproc=TrainTransform(
-                        max_labels=50,
-                        flip_prob=self.flip_prob,
-                        hsv_prob=self.hsv_prob),
-                    cache=False,
-                    cache_type=cache_img,
-                )
-            else:
-                dataset = self.cache_dataset
+        # if cache is True, we will create self.dataset before launch
+        # else we will create self.dataset after launch
+        if self.dataset is None:
+            with wait_for_the_master():
+                assert cache_img is None, \
+                    "cache_img must be None if you didn't create self.dataset before launch"
+                self.dataset = self.get_dataset(cache=False, cache_type=cache_img)

         self.dataset = MosaicDetection(
-            dataset,
+            dataset=self.dataset,
             mosaic=not no_aug,
             img_size=self.input_size,
             preproc=TrainTransform(
......@@ -298,10 +295,12 @@ class Exp(BaseExp):
)
return scheduler
-    def get_eval_loader(self, batch_size, is_distributed, testdev=False, legacy=False):
+    def get_eval_dataset(self, **kwargs):
         from yolox.data import COCODataset, ValTransform
+        testdev = kwargs.get("testdev", False)
+        legacy = kwargs.get("legacy", False)

-        valdataset = COCODataset(
+        return COCODataset(
             data_dir=self.data_dir,
             json_file=self.val_ann if not testdev else self.test_ann,
             name="val2017" if not testdev else "test2017",
......@@ -309,6 +308,9 @@ class Exp(BaseExp):
             preproc=ValTransform(legacy=legacy),
         )

+    def get_eval_loader(self, batch_size, is_distributed, **kwargs):
+        valdataset = self.get_eval_dataset(**kwargs)

         if is_distributed:
             batch_size = batch_size // dist.get_world_size()
             sampler = torch.utils.data.distributed.DistributedSampler(
......@@ -330,16 +332,15 @@ class Exp(BaseExp):
def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False):
from yolox.evaluators import COCOEvaluator
-        val_loader = self.get_eval_loader(batch_size, is_distributed, testdev, legacy)
-        evaluator = COCOEvaluator(
-            dataloader=val_loader,
+        return COCOEvaluator(
+            dataloader=self.get_eval_loader(batch_size, is_distributed,
+                                            testdev=testdev, legacy=legacy),
             img_size=self.test_size,
             confthre=self.test_conf,
             nmsthre=self.nmsthre,
             num_classes=self.num_classes,
-            testdev=testdev,
         )
-        return evaluator
def get_trainer(self, args):
from yolox.core import Trainer
......