Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
FL1623863129
YOLOX
提交
33f48a92
Y
YOLOX
项目概览
FL1623863129
/
YOLOX
与 Fork 源项目一致
从无法访问的项目Fork
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Y
YOLOX
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
33f48a92
编写于
12月 24, 2022
作者:
Y
Yuang Peng
提交者:
GitHub
12月 24, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(data): support cache ram of COCO dataset (#1562)
feat(data): support cache ram of COCO dataset
上级
11c2a1f8
变更
7
隐藏空白更改
内联
并排
Showing
7 changed files
with
152 additions
and
79 deletions
+152
-79
requirements.txt
requirements.txt
+1
-0
setup.cfg
setup.cfg
+1
-1
tools/train.py
tools/train.py
+7
-4
yolox/core/trainer.py
yolox/core/trainer.py
+5
-2
yolox/data/datasets/coco.py
yolox/data/datasets/coco.py
+84
-58
yolox/exp/yolox_base.py
yolox/exp/yolox_base.py
+43
-14
yolox/utils/metric.py
yolox/utils/metric.py
+11
-0
未找到文件。
requirements.txt
浏览文件 @
33f48a92
...
...
@@ -8,6 +8,7 @@ torchvision
thop
ninja
tabulate
psutil
# verified versions
# pycocotools corresponds to https://github.com/ppwwyyxx/cocoapi
...
...
setup.cfg
浏览文件 @
33f48a92
...
...
@@ -3,7 +3,7 @@ line_length = 100
multi_line_output = 3
balanced_wrapping = True
known_standard_library = setuptools
known_third_party = tqdm,loguru,tabulate
known_third_party = tqdm,loguru,tabulate
,psutil
known_data_processing = cv2,numpy,scipy,PIL,matplotlib
known_datasets = pycocotools
known_deeplearning = torch,torchvision,caffe2,onnx,apex,timm,thop,torch2trt,tensorrt,openvino,onnxruntime
...
...
tools/train.py
浏览文件 @
33f48a92
...
...
@@ -67,10 +67,10 @@ def make_parser():
)
parser
.
add_argument
(
"--cache"
,
dest
=
"cache"
,
default
=
False
,
action
=
"store_true
"
,
help
=
"Caching imgs to
RAM
for fast training."
,
type
=
str
,
nargs
=
"?"
,
const
=
"ram
"
,
help
=
"Caching imgs to
ram/disk
for fast training."
,
)
parser
.
add_argument
(
"-o"
,
...
...
@@ -130,6 +130,9 @@ if __name__ == "__main__":
num_gpu
=
get_num_devices
()
if
args
.
devices
is
None
else
args
.
devices
assert
num_gpu
<=
get_num_devices
()
if
args
.
cache
is
not
None
:
exp
.
create_cache_dataset
(
args
.
cache
)
dist_url
=
"auto"
if
args
.
dist_url
is
None
else
args
.
dist_url
launch
(
main
,
...
...
yolox/core/trainer.py
浏览文件 @
33f48a92
...
...
@@ -26,6 +26,7 @@ from yolox.utils import (
gpu_mem_usage
,
is_parallel
,
load_ckpt
,
mem_usage
,
occupy_mem
,
save_checkpoint
,
setup_logger
,
...
...
@@ -250,10 +251,12 @@ class Trainer:
[
"{}: {:.3f}s"
.
format
(
k
,
v
.
avg
)
for
k
,
v
in
time_meter
.
items
()]
)
mem_str
=
"gpu mem: {:.0f}Mb, mem: {:.1f}Gb"
.
format
(
gpu_mem_usage
(),
mem_usage
())
logger
.
info
(
"{},
mem: {:.0f}Mb
, {}, {}, lr: {:.3e}"
.
format
(
"{},
{}
, {}, {}, lr: {:.3e}"
.
format
(
progress_str
,
gpu_mem_usage
()
,
mem_str
,
time_str
,
loss_str
,
self
.
meter
[
"lr"
].
latest
,
...
...
yolox/data/datasets/coco.py
浏览文件 @
33f48a92
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import
copy
import
os
import
random
from
multiprocessing.pool
import
ThreadPool
import
psutil
from
loguru
import
logger
from
tqdm
import
tqdm
import
cv2
import
numpy
as
np
...
...
@@ -45,6 +49,7 @@ class COCODataset(Dataset):
img_size
=
(
416
,
416
),
preproc
=
None
,
cache
=
False
,
cache_type
=
"ram"
,
):
"""
COCO dataset initialization. Annotation data are read into memory by COCO API.
...
...
@@ -64,74 +69,95 @@ class COCODataset(Dataset):
self
.
coco
=
COCO
(
os
.
path
.
join
(
self
.
data_dir
,
"annotations"
,
self
.
json_file
))
remove_useless_info
(
self
.
coco
)
self
.
ids
=
self
.
coco
.
getImgIds
()
self
.
num_imgs
=
len
(
self
.
ids
)
self
.
class_ids
=
sorted
(
self
.
coco
.
getCatIds
())
self
.
cats
=
self
.
coco
.
loadCats
(
self
.
coco
.
getCatIds
())
self
.
_classes
=
tuple
([
c
[
"name"
]
for
c
in
self
.
cats
])
self
.
imgs
=
None
self
.
name
=
name
self
.
img_size
=
img_size
self
.
preproc
=
preproc
self
.
annotations
=
self
.
_load_coco_annotations
()
if
cache
:
self
.
imgs
=
None
self
.
cache
=
cache
self
.
cache_type
=
cache_type
if
self
.
cache
:
self
.
_cache_images
()
def
__len__
(
self
):
return
len
(
self
.
ids
)
def
_cache_images
(
self
):
mem
=
psutil
.
virtual_memory
()
mem_required
=
self
.
cal_cache_ram
()
gb
=
1
<<
30
def
__del__
(
self
):
del
self
.
imgs
if
self
.
cache_type
==
"ram"
and
mem_required
>
mem
.
available
:
self
.
cache
=
False
else
:
logger
.
info
(
f
"
{
mem_required
/
gb
:.
1
f
}
GB RAM required, "
f
"
{
mem
.
available
/
gb
:.
1
f
}
/
{
mem
.
total
/
gb
:.
1
f
}
GB RAM available, "
f
"Since the first thing we do is cache, "
f
"there is no guarantee that the remaining memory space is sufficient"
)
def
_load_coco_annotations
(
self
):
return
[
self
.
load_anno_from_ids
(
_ids
)
for
_ids
in
self
.
ids
]
if
self
.
cache
and
self
.
imgs
is
None
:
if
self
.
cache_type
==
'ram'
:
self
.
imgs
=
[
None
]
*
self
.
num_imgs
logger
.
info
(
"You are using cached images in RAM to accelerate training!"
)
else
:
# 'disk'
self
.
cache_dir
=
os
.
path
.
join
(
self
.
data_dir
,
f
"
{
self
.
name
}
_cache
{
self
.
img_size
[
0
]
}
x
{
self
.
img_size
[
1
]
}
"
)
if
not
os
.
path
.
exists
(
self
.
cache_dir
):
os
.
mkdir
(
self
.
cache_dir
)
logger
.
warning
(
f
"
\n
*******************************************************************
\n
"
f
"You are using cached images in DISK to accelerate training.
\n
"
f
"This requires large DISK space.
\n
"
f
"Make sure you have
{
mem_required
/
gb
:.
1
f
}
"
f
"available DISK space for training COCO.
\n
"
f
"*******************************************************************
\\
n"
)
else
:
logger
.
info
(
"Found disk cache!"
)
return
def
_cache_images
(
self
):
logger
.
warning
(
"
\n
********************************************************************************
\n
"
"You are using cached images in RAM to accelerate training.
\n
"
"This requires large system RAM.
\n
"
"Make sure you have 200G+ RAM and 136G available disk space for training COCO.
\n
"
"********************************************************************************
\n
"
)
max_h
=
self
.
img_size
[
0
]
max_w
=
self
.
img_size
[
1
]
cache_file
=
os
.
path
.
join
(
self
.
data_dir
,
f
"img_resized_cache_
{
self
.
name
}
.array"
)
if
not
os
.
path
.
exists
(
cache_file
):
logger
.
info
(
"Caching images for the first time. This might take about 20 minutes for COCO"
"Caching images for the first time. "
"This might take about 15 minutes for COCO"
)
self
.
imgs
=
np
.
memmap
(
cache_file
,
shape
=
(
len
(
self
.
ids
),
max_h
,
max_w
,
3
),
dtype
=
np
.
uint8
,
mode
=
"w+"
,
)
from
tqdm
import
tqdm
from
multiprocessing.pool
import
ThreadPool
NUM_THREADs
=
min
(
8
,
os
.
cpu_count
())
loaded_images
=
ThreadPool
(
NUM_THREADs
).
imap
(
lambda
x
:
self
.
load_resized_img
(
x
),
range
(
len
(
self
.
annotations
)),
)
pbar
=
tqdm
(
enumerate
(
loaded_images
),
total
=
len
(
self
.
annotations
))
for
k
,
out
in
pbar
:
self
.
imgs
[
k
][:
out
.
shape
[
0
],
:
out
.
shape
[
1
],
:]
=
out
.
copy
()
self
.
imgs
.
flush
()
num_threads
=
min
(
8
,
max
(
1
,
os
.
cpu_count
()
-
1
))
b
=
0
load_imgs
=
ThreadPool
(
num_threads
).
imap
(
self
.
load_resized_img
,
range
(
self
.
num_imgs
))
pbar
=
tqdm
(
enumerate
(
load_imgs
),
total
=
self
.
num_imgs
)
for
i
,
x
in
pbar
:
# x = self.load_resized_img(self, i)
if
self
.
cache_type
==
'ram'
:
self
.
imgs
[
i
]
=
x
else
:
# 'disk'
cache_filename
=
f
'
{
self
.
annotations
[
i
][
"filename"
].
split
(
"."
)[
0
]
}
.npy'
np
.
save
(
os
.
path
.
join
(
self
.
cache_dir
,
cache_filename
),
x
)
b
+=
x
.
nbytes
pbar
.
desc
=
f
'Caching images (
{
b
/
gb
:.
1
f
}
/
{
mem_required
/
gb
:.
1
f
}
GB
{
self
.
cache
}
)'
pbar
.
close
()
else
:
logger
.
warning
(
"You are using cached imgs! Make sure your dataset is not changed!!
\n
"
"Everytime the self.input_size is changed in your exp file, you need to delete
\n
"
"the cached data and re-generate them.
\n
"
)
logger
.
info
(
"Loading cached imgs..."
)
self
.
imgs
=
np
.
memmap
(
cache_file
,
shape
=
(
len
(
self
.
ids
),
max_h
,
max_w
,
3
),
dtype
=
np
.
uint8
,
mode
=
"r+"
,
)
def cal_cache_ram(self):
    """Estimate how many bytes are needed to cache every resized image.

    Up to 32 randomly chosen images are loaded through
    ``self.load_resized_img`` and their ``nbytes`` are summed. That total is
    then scaled linearly from the sample size up to ``self.num_imgs``.

    Returns:
        The estimated cache size in bytes, or ``0`` when the dataset holds
        no images.
    """
    num_samples = min(self.num_imgs, 32)
    if num_samples <= 0:
        # An empty dataset has nothing to sample; returning early avoids the
        # division by zero in the extrapolation below.
        return 0

    cache_bytes = 0
    for _ in range(num_samples):
        # Sampling is with replacement: the same index may be drawn twice.
        img = self.load_resized_img(random.randint(0, self.num_imgs - 1))
        cache_bytes += img.nbytes

    # Extrapolate the sampled footprint to the whole dataset.
    mem_required = cache_bytes * self.num_imgs / num_samples
    return mem_required
def __len__(self):
    """Return the dataset size, i.e. the value stored in ``self.num_imgs``."""
    total = self.num_imgs
    return total
def __del__(self):
    """Drop the ``imgs`` attribute when the dataset object is finalized.

    The finalizer can run on an instance whose ``imgs`` attribute was never
    assigned. A bare ``del self.imgs`` would then raise ``AttributeError``
    from inside the finalizer, so a missing attribute is treated as already
    released.
    """
    try:
        del self.imgs
    except AttributeError:
        # Nothing was cached on this instance; there is nothing to release.
        pass
def _load_coco_annotations(self):
    """Collect one annotation record per image id.

    Calls ``self.load_anno_from_ids`` for every entry of ``self.ids`` and
    returns the results as a list in the same order.
    """
    records = []
    for image_id in self.ids:
        records.append(self.load_anno_from_ids(image_id))
    return records
def
load_anno_from_ids
(
self
,
id_
):
im_ann
=
self
.
coco
.
loadImgs
(
id_
)[
0
]
...
...
@@ -152,7 +178,6 @@ class COCODataset(Dataset):
num_objs
=
len
(
objs
)
res
=
np
.
zeros
((
num_objs
,
5
))
for
ix
,
obj
in
enumerate
(
objs
):
cls
=
self
.
class_ids
.
index
(
obj
[
"category_id"
])
res
[
ix
,
0
:
4
]
=
obj
[
"clean_bbox"
]
...
...
@@ -197,15 +222,16 @@ class COCODataset(Dataset):
def
pull_item
(
self
,
index
):
id_
=
self
.
ids
[
index
]
label
,
origin_image_size
,
_
,
filename
=
self
.
annotations
[
index
]
res
,
img_info
,
resized_info
,
_
=
self
.
annotations
[
index
]
if
self
.
imgs
is
not
None
:
pad_img
=
self
.
imgs
[
index
]
img
=
pad_img
[:
resized_info
[
0
],
:
resized_info
[
1
],
:].
copy
(
)
if
self
.
cache_type
==
'ram'
:
img
=
self
.
imgs
[
index
]
elif
self
.
cache_type
==
'disk'
:
img
=
np
.
load
(
os
.
path
.
join
(
self
.
cache_dir
,
f
"
{
filename
.
split
(
'.'
)[
0
]
}
.npy"
)
)
else
:
img
=
self
.
load_resized_img
(
index
)
return
img
,
res
.
copy
(),
img_info
,
np
.
array
([
id_
])
return
copy
.
deepcopy
(
img
),
copy
.
deepcopy
(
label
),
origin_image_size
,
np
.
array
([
id_
])
@
Dataset
.
mosaic_getitem
def
__getitem__
(
self
,
index
):
...
...
yolox/exp/yolox_base.py
浏览文件 @
33f48a92
...
...
@@ -106,6 +106,23 @@ class Exp(BaseExp):
self
.
test_conf
=
0.01
# nms threshold
self
.
nmsthre
=
0.65
self
.
cache_dataset
=
None
self
.
dataset
=
None
def create_cache_dataset(self, cache_type: str = "ram"):
    """Build the training ``COCODataset`` with caching on and keep it on the exp.

    The dataset is created from ``self.data_dir``, ``self.train_ann`` and
    ``self.input_size`` with a ``TrainTransform`` preprocessor, and the
    result is stored in ``self.cache_dataset``.

    Args:
        cache_type: Passed through unchanged as the dataset's ``cache_type``
            argument. Defaults to ``"ram"``.
    """
    # Imported here, as in the original, rather than at module import time.
    from yolox.data import COCODataset, TrainTransform

    train_preproc = TrainTransform(
        max_labels=50,
        flip_prob=self.flip_prob,
        hsv_prob=self.hsv_prob,
    )
    dataset_kwargs = dict(
        data_dir=self.data_dir,
        json_file=self.train_ann,
        img_size=self.input_size,
        preproc=train_preproc,
        cache=True,
        cache_type=cache_type,
    )
    self.cache_dataset = COCODataset(**dataset_kwargs)
def
get_model
(
self
):
from
yolox.models
import
YOLOX
,
YOLOPAFPN
,
YOLOXHead
...
...
@@ -127,7 +144,16 @@ class Exp(BaseExp):
self
.
model
.
train
()
return
self
.
model
def
get_data_loader
(
self
,
batch_size
,
is_distributed
,
no_aug
=
False
,
cache_img
=
False
):
def
get_data_loader
(
self
,
batch_size
,
is_distributed
,
no_aug
=
False
,
cache_img
:
str
=
None
):
"""
Get dataloader according to cache_img parameter.
Args:
no_aug (bool, optional): Whether to turn off mosaic data enhancement. Defaults to False.
cache_img (str, optional): cache_img is equivalent to cache_type. Defaults to None.
"ram" : Caching imgs to ram for fast training.
"disk": Caching imgs to disk for fast training.
None: Do not use cache, in this case cache_data is also None.
"""
from
yolox.data
import
(
COCODataset
,
TrainTransform
,
...
...
@@ -140,18 +166,23 @@ class Exp(BaseExp):
from
yolox.utils
import
wait_for_the_master
with
wait_for_the_master
():
dataset
=
COCODataset
(
data_dir
=
self
.
data_dir
,
json_file
=
self
.
train_ann
,
img_size
=
self
.
input_size
,
preproc
=
TrainTransform
(
max_labels
=
50
,
flip_prob
=
self
.
flip_prob
,
hsv_prob
=
self
.
hsv_prob
),
cache
=
cache_img
,
)
if
self
.
cache_dataset
is
None
:
assert
cache_img
is
None
,
"cache is True, but cache_dataset is None"
dataset
=
COCODataset
(
data_dir
=
self
.
data_dir
,
json_file
=
self
.
train_ann
,
img_size
=
self
.
input_size
,
preproc
=
TrainTransform
(
max_labels
=
50
,
flip_prob
=
self
.
flip_prob
,
hsv_prob
=
self
.
hsv_prob
),
cache
=
False
,
cache_type
=
cache_img
,
)
else
:
dataset
=
self
.
cache_dataset
dataset
=
MosaicDetection
(
self
.
dataset
=
MosaicDetection
(
dataset
,
mosaic
=
not
no_aug
,
img_size
=
self
.
input_size
,
...
...
@@ -169,8 +200,6 @@ class Exp(BaseExp):
mixup_prob
=
self
.
mixup_prob
,
)
self
.
dataset
=
dataset
if
is_distributed
:
batch_size
=
batch_size
//
dist
.
get_world_size
()
...
...
yolox/utils/metric.py
浏览文件 @
33f48a92
...
...
@@ -5,6 +5,7 @@ import functools
import
os
import
time
from
collections
import
defaultdict
,
deque
import
psutil
import
numpy
as
np
...
...
@@ -16,6 +17,7 @@ __all__ = [
"get_total_and_free_memory_in_Mb"
,
"occupy_mem"
,
"gpu_mem_usage"
,
"mem_usage"
]
...
...
@@ -51,6 +53,15 @@ def gpu_mem_usage():
return
mem_usage_bytes
/
(
1024
*
1024
)
def mem_usage():
    """
    Compute the memory usage for the current machine (GB).

    Reads ``psutil.virtual_memory().used`` and divides it by ``1 << 30``.
    """
    bytes_per_gb = 1 << 30
    return psutil.virtual_memory().used / bytes_per_gb
class
AverageMeter
:
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录