未验证 提交 33f48a92 编写于 作者: Y Yuang Peng 提交者: GitHub

feat(data): support cache ram of COCO dataset (#1562)

feat(data): support cache ram of COCO dataset
上级 11c2a1f8
...@@ -8,6 +8,7 @@ torchvision ...@@ -8,6 +8,7 @@ torchvision
thop thop
ninja ninja
tabulate tabulate
psutil
# verified versions # verified versions
# pycocotools corresponds to https://github.com/ppwwyyxx/cocoapi # pycocotools corresponds to https://github.com/ppwwyyxx/cocoapi
......
...@@ -3,7 +3,7 @@ line_length = 100 ...@@ -3,7 +3,7 @@ line_length = 100
multi_line_output = 3 multi_line_output = 3
balanced_wrapping = True balanced_wrapping = True
known_standard_library = setuptools known_standard_library = setuptools
known_third_party = tqdm,loguru,tabulate known_third_party = tqdm,loguru,tabulate,psutil
known_data_processing = cv2,numpy,scipy,PIL,matplotlib known_data_processing = cv2,numpy,scipy,PIL,matplotlib
known_datasets = pycocotools known_datasets = pycocotools
known_deeplearning = torch,torchvision,caffe2,onnx,apex,timm,thop,torch2trt,tensorrt,openvino,onnxruntime known_deeplearning = torch,torchvision,caffe2,onnx,apex,timm,thop,torch2trt,tensorrt,openvino,onnxruntime
......
...@@ -67,10 +67,10 @@ def make_parser(): ...@@ -67,10 +67,10 @@ def make_parser():
) )
parser.add_argument( parser.add_argument(
"--cache", "--cache",
dest="cache", type=str,
default=False, nargs="?",
action="store_true", const="ram",
help="Caching imgs to RAM for fast training.", help="Caching imgs to ram/disk for fast training.",
) )
parser.add_argument( parser.add_argument(
"-o", "-o",
...@@ -130,6 +130,9 @@ if __name__ == "__main__": ...@@ -130,6 +130,9 @@ if __name__ == "__main__":
num_gpu = get_num_devices() if args.devices is None else args.devices num_gpu = get_num_devices() if args.devices is None else args.devices
assert num_gpu <= get_num_devices() assert num_gpu <= get_num_devices()
if args.cache is not None:
exp.create_cache_dataset(args.cache)
dist_url = "auto" if args.dist_url is None else args.dist_url dist_url = "auto" if args.dist_url is None else args.dist_url
launch( launch(
main, main,
......
...@@ -26,6 +26,7 @@ from yolox.utils import ( ...@@ -26,6 +26,7 @@ from yolox.utils import (
gpu_mem_usage, gpu_mem_usage,
is_parallel, is_parallel,
load_ckpt, load_ckpt,
mem_usage,
occupy_mem, occupy_mem,
save_checkpoint, save_checkpoint,
setup_logger, setup_logger,
...@@ -250,10 +251,12 @@ class Trainer: ...@@ -250,10 +251,12 @@ class Trainer:
["{}: {:.3f}s".format(k, v.avg) for k, v in time_meter.items()] ["{}: {:.3f}s".format(k, v.avg) for k, v in time_meter.items()]
) )
mem_str = "gpu mem: {:.0f}Mb, mem: {:.1f}Gb".format(gpu_mem_usage(), mem_usage())
logger.info( logger.info(
"{}, mem: {:.0f}Mb, {}, {}, lr: {:.3e}".format( "{}, {}, {}, {}, lr: {:.3e}".format(
progress_str, progress_str,
gpu_mem_usage(), mem_str,
time_str, time_str,
loss_str, loss_str,
self.meter["lr"].latest, self.meter["lr"].latest,
......
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates. # Copyright (c) Megvii, Inc. and its affiliates.
import copy
import os import os
import random
from multiprocessing.pool import ThreadPool
import psutil
from loguru import logger from loguru import logger
from tqdm import tqdm
import cv2 import cv2
import numpy as np import numpy as np
...@@ -45,6 +49,7 @@ class COCODataset(Dataset): ...@@ -45,6 +49,7 @@ class COCODataset(Dataset):
img_size=(416, 416), img_size=(416, 416),
preproc=None, preproc=None,
cache=False, cache=False,
cache_type="ram",
): ):
""" """
COCO dataset initialization. Annotation data are read into memory by COCO API. COCO dataset initialization. Annotation data are read into memory by COCO API.
...@@ -64,74 +69,95 @@ class COCODataset(Dataset): ...@@ -64,74 +69,95 @@ class COCODataset(Dataset):
self.coco = COCO(os.path.join(self.data_dir, "annotations", self.json_file)) self.coco = COCO(os.path.join(self.data_dir, "annotations", self.json_file))
remove_useless_info(self.coco) remove_useless_info(self.coco)
self.ids = self.coco.getImgIds() self.ids = self.coco.getImgIds()
self.num_imgs = len(self.ids)
self.class_ids = sorted(self.coco.getCatIds()) self.class_ids = sorted(self.coco.getCatIds())
self.cats = self.coco.loadCats(self.coco.getCatIds()) self.cats = self.coco.loadCats(self.coco.getCatIds())
self._classes = tuple([c["name"] for c in self.cats]) self._classes = tuple([c["name"] for c in self.cats])
self.imgs = None
self.name = name self.name = name
self.img_size = img_size self.img_size = img_size
self.preproc = preproc self.preproc = preproc
self.annotations = self._load_coco_annotations() self.annotations = self._load_coco_annotations()
if cache: self.imgs = None
self.cache = cache
self.cache_type = cache_type
if self.cache:
self._cache_images() self._cache_images()
def __len__(self): def _cache_images(self):
return len(self.ids) mem = psutil.virtual_memory()
mem_required = self.cal_cache_ram()
gb = 1 << 30
def __del__(self): if self.cache_type == "ram" and mem_required > mem.available:
del self.imgs self.cache = False
else:
logger.info(
f"{mem_required / gb:.1f}GB RAM required, "
f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB RAM available, "
f"Since the first thing we do is cache, "
f"there is no guarantee that the remaining memory space is sufficient"
)
def _load_coco_annotations(self): if self.cache and self.imgs is None:
return [self.load_anno_from_ids(_ids) for _ids in self.ids] if self.cache_type == 'ram':
self.imgs = [None] * self.num_imgs
logger.info("You are using cached images in RAM to accelerate training!")
else: # 'disk'
self.cache_dir = os.path.join(
self.data_dir,
f"{self.name}_cache{self.img_size[0]}x{self.img_size[1]}"
)
if not os.path.exists(self.cache_dir):
os.mkdir(self.cache_dir)
logger.warning(
f"\n*******************************************************************\n"
f"You are using cached images in DISK to accelerate training.\n"
f"This requires large DISK space.\n"
f"Make sure you have {mem_required / gb:.1f} "
f"available DISK space for training COCO.\n"
f"*******************************************************************\\n"
)
else:
logger.info("Found disk cache!")
return
def _cache_images(self):
logger.warning(
"\n********************************************************************************\n"
"You are using cached images in RAM to accelerate training.\n"
"This requires large system RAM.\n"
"Make sure you have 200G+ RAM and 136G available disk space for training COCO.\n"
"********************************************************************************\n"
)
max_h = self.img_size[0]
max_w = self.img_size[1]
cache_file = os.path.join(self.data_dir, f"img_resized_cache_{self.name}.array")
if not os.path.exists(cache_file):
logger.info( logger.info(
"Caching images for the first time. This might take about 20 minutes for COCO" "Caching images for the first time. "
"This might take about 15 minutes for COCO"
) )
self.imgs = np.memmap(
cache_file,
shape=(len(self.ids), max_h, max_w, 3),
dtype=np.uint8,
mode="w+",
)
from tqdm import tqdm
from multiprocessing.pool import ThreadPool
NUM_THREADs = min(8, os.cpu_count()) num_threads = min(8, max(1, os.cpu_count() - 1))
loaded_images = ThreadPool(NUM_THREADs).imap( b = 0
lambda x: self.load_resized_img(x), load_imgs = ThreadPool(num_threads).imap(self.load_resized_img, range(self.num_imgs))
range(len(self.annotations)), pbar = tqdm(enumerate(load_imgs), total=self.num_imgs)
) for i, x in pbar: # x = self.load_resized_img(self, i)
pbar = tqdm(enumerate(loaded_images), total=len(self.annotations)) if self.cache_type == 'ram':
for k, out in pbar: self.imgs[i] = x
self.imgs[k][: out.shape[0], : out.shape[1], :] = out.copy() else: # 'disk'
self.imgs.flush() cache_filename = f'{self.annotations[i]["filename"].split(".")[0]}.npy'
np.save(os.path.join(self.cache_dir, cache_filename), x)
b += x.nbytes
pbar.desc = f'Caching images ({b / gb:.1f}/{mem_required / gb:.1f}GB {self.cache})'
pbar.close() pbar.close()
else:
logger.warning(
"You are using cached imgs! Make sure your dataset is not changed!!\n"
"Everytime the self.input_size is changed in your exp file, you need to delete\n"
"the cached data and re-generate them.\n"
)
logger.info("Loading cached imgs...") def cal_cache_ram(self):
self.imgs = np.memmap( cache_bytes = 0
cache_file, num_samples = min(self.num_imgs, 32)
shape=(len(self.ids), max_h, max_w, 3), for _ in range(num_samples):
dtype=np.uint8, img = self.load_resized_img(random.randint(0, self.num_imgs - 1))
mode="r+", cache_bytes += img.nbytes
) mem_required = cache_bytes * self.num_imgs / num_samples
return mem_required
def __len__(self):
return self.num_imgs
def __del__(self):
del self.imgs
def _load_coco_annotations(self):
return [self.load_anno_from_ids(_ids) for _ids in self.ids]
def load_anno_from_ids(self, id_): def load_anno_from_ids(self, id_):
im_ann = self.coco.loadImgs(id_)[0] im_ann = self.coco.loadImgs(id_)[0]
...@@ -152,7 +178,6 @@ class COCODataset(Dataset): ...@@ -152,7 +178,6 @@ class COCODataset(Dataset):
num_objs = len(objs) num_objs = len(objs)
res = np.zeros((num_objs, 5)) res = np.zeros((num_objs, 5))
for ix, obj in enumerate(objs): for ix, obj in enumerate(objs):
cls = self.class_ids.index(obj["category_id"]) cls = self.class_ids.index(obj["category_id"])
res[ix, 0:4] = obj["clean_bbox"] res[ix, 0:4] = obj["clean_bbox"]
...@@ -197,15 +222,16 @@ class COCODataset(Dataset): ...@@ -197,15 +222,16 @@ class COCODataset(Dataset):
def pull_item(self, index): def pull_item(self, index):
id_ = self.ids[index] id_ = self.ids[index]
label, origin_image_size, _, filename = self.annotations[index]
res, img_info, resized_info, _ = self.annotations[index] if self.cache_type == 'ram':
if self.imgs is not None: img = self.imgs[index]
pad_img = self.imgs[index] elif self.cache_type == 'disk':
img = pad_img[: resized_info[0], : resized_info[1], :].copy() img = np.load(os.path.join(self.cache_dir, f"{filename.split('.')[0]}.npy"))
else: else:
img = self.load_resized_img(index) img = self.load_resized_img(index)
return img, res.copy(), img_info, np.array([id_]) return copy.deepcopy(img), copy.deepcopy(label), origin_image_size, np.array([id_])
@Dataset.mosaic_getitem @Dataset.mosaic_getitem
def __getitem__(self, index): def __getitem__(self, index):
......
...@@ -106,6 +106,23 @@ class Exp(BaseExp): ...@@ -106,6 +106,23 @@ class Exp(BaseExp):
self.test_conf = 0.01 self.test_conf = 0.01
# nms threshold # nms threshold
self.nmsthre = 0.65 self.nmsthre = 0.65
self.cache_dataset = None
self.dataset = None
def create_cache_dataset(self, cache_type: str = "ram"):
    """Build the training COCO dataset with image caching enabled.

    The resulting dataset is stored on ``self.cache_dataset``; nothing is returned.

    Args:
        cache_type (str): Where cached images are kept, either ``"ram"`` or
            ``"disk"``. Defaults to ``"ram"``.

    Raises:
        ValueError: If ``cache_type`` is neither ``"ram"`` nor ``"disk"``.
    """
    # Fail fast on unknown modes: the dataset's caching code only special-cases
    # "ram" and treats every other value as disk caching, so a typo on the
    # command line would otherwise silently start writing a disk cache.
    if cache_type not in ("ram", "disk"):
        raise ValueError(
            f"cache_type must be 'ram' or 'disk', got {cache_type!r}"
        )

    # Imported lazily, matching the other methods of this experiment class.
    from yolox.data import COCODataset, TrainTransform

    self.cache_dataset = COCODataset(
        data_dir=self.data_dir,
        json_file=self.train_ann,
        img_size=self.input_size,
        preproc=TrainTransform(
            max_labels=50,
            flip_prob=self.flip_prob,
            hsv_prob=self.hsv_prob
        ),
        cache=True,
        cache_type=cache_type,
    )
def get_model(self): def get_model(self):
from yolox.models import YOLOX, YOLOPAFPN, YOLOXHead from yolox.models import YOLOX, YOLOPAFPN, YOLOXHead
...@@ -127,7 +144,16 @@ class Exp(BaseExp): ...@@ -127,7 +144,16 @@ class Exp(BaseExp):
self.model.train() self.model.train()
return self.model return self.model
def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img=False): def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img: str = None):
"""
Get dataloader according to cache_img parameter.
Args:
no_aug (bool, optional): Whether to turn off mosaic data augmentation. Defaults to False.
cache_img (str, optional): Cache mode; equivalent to the dataset's cache_type. Defaults to None.
"ram" : Cache images in RAM for fast training.
"disk": Cache images on disk for fast training.
None: Do not use a cache; in this case cache_dataset must also be None.
"""
from yolox.data import ( from yolox.data import (
COCODataset, COCODataset,
TrainTransform, TrainTransform,
...@@ -140,18 +166,23 @@ class Exp(BaseExp): ...@@ -140,18 +166,23 @@ class Exp(BaseExp):
from yolox.utils import wait_for_the_master from yolox.utils import wait_for_the_master
with wait_for_the_master(): with wait_for_the_master():
dataset = COCODataset( if self.cache_dataset is None:
data_dir=self.data_dir, assert cache_img is None, "cache is True, but cache_dataset is None"
json_file=self.train_ann, dataset = COCODataset(
img_size=self.input_size, data_dir=self.data_dir,
preproc=TrainTransform( json_file=self.train_ann,
max_labels=50, img_size=self.input_size,
flip_prob=self.flip_prob, preproc=TrainTransform(
hsv_prob=self.hsv_prob), max_labels=50,
cache=cache_img, flip_prob=self.flip_prob,
) hsv_prob=self.hsv_prob),
cache=False,
cache_type=cache_img,
)
else:
dataset = self.cache_dataset
dataset = MosaicDetection( self.dataset = MosaicDetection(
dataset, dataset,
mosaic=not no_aug, mosaic=not no_aug,
img_size=self.input_size, img_size=self.input_size,
...@@ -169,8 +200,6 @@ class Exp(BaseExp): ...@@ -169,8 +200,6 @@ class Exp(BaseExp):
mixup_prob=self.mixup_prob, mixup_prob=self.mixup_prob,
) )
self.dataset = dataset
if is_distributed: if is_distributed:
batch_size = batch_size // dist.get_world_size() batch_size = batch_size // dist.get_world_size()
......
...@@ -5,6 +5,7 @@ import functools ...@@ -5,6 +5,7 @@ import functools
import os import os
import time import time
from collections import defaultdict, deque from collections import defaultdict, deque
import psutil
import numpy as np import numpy as np
...@@ -16,6 +17,7 @@ __all__ = [ ...@@ -16,6 +17,7 @@ __all__ = [
"get_total_and_free_memory_in_Mb", "get_total_and_free_memory_in_Mb",
"occupy_mem", "occupy_mem",
"gpu_mem_usage", "gpu_mem_usage",
"mem_usage"
] ]
...@@ -51,6 +53,15 @@ def gpu_mem_usage(): ...@@ -51,6 +53,15 @@ def gpu_mem_usage():
return mem_usage_bytes / (1024 * 1024) return mem_usage_bytes / (1024 * 1024)
def mem_usage():
    """
    Return the machine's used memory, as reported by psutil, in GB (2**30 bytes).
    """
    bytes_per_gb = 1 << 30
    return psutil.virtual_memory().used / bytes_per_gb
class AverageMeter: class AverageMeter:
"""Track a series of values and provide access to smoothed values over a """Track a series of values and provide access to smoothed values over a
window or the global series average. window or the global series average.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册