Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
94a76969
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
94a76969
编写于
4月 01, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor ImageNet
GitOrigin-RevId: f7774e0ffc5de7ffb3ea5eba5ddb9809b9d049dd
上级
e677ffc8
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
60 additions
and
53 deletions
+60
-53
python_module/megengine/data/dataset/vision/imagenet.py
python_module/megengine/data/dataset/vision/imagenet.py
+60
-53
未找到文件。
python_module/megengine/data/dataset/vision/imagenet.py
浏览文件 @
94a76969
...
...
@@ -24,7 +24,7 @@ from ....core.serialization import load, save
from
....distributed.util
import
is_distributed
from
....logger
import
get_logger
from
.folder
import
ImageFolder
from
.utils
import
_default_dataset_root
,
untar
,
untargz
from
.utils
import
_default_dataset_root
,
calculate_md5
,
untar
,
untargz
logger
=
get_logger
(
__name__
)
...
...
@@ -33,40 +33,28 @@ class ImageNet(ImageFolder):
r
"""
Load ImageNet from raw files or folder, expected folder looks like
raw files situation (optional):
root/ILSVRC2012_img_train.tar
root/ILSVRC2012_img_val
.tar
root/ILSVRC2012_devkit_t12.tar.gz
image folder situation (required):
root/
train/cls/xxx.${img_ext}
root/
val/cls/xxx.${img_ext}
root/
ILSVRC2012_devkit_t12/data/meta.mat
root/
ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt
If the
required folders don't exist, raw
files are required to get extracted and processed.
${root}/
| [REQUIRED TAR FILES]
|- ILSVRC2012_img_train
.tar
|- ILSVRC2012_img_val.tar
|- ILSVRC2012_devkit_t12.tar.gz
| [OPTIONAL IMAGE FOLDERS]
|-
train/cls/xxx.${img_ext}
|-
val/cls/xxx.${img_ext}
|-
ILSVRC2012_devkit_t12/data/meta.mat
|-
ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt
If the
image folders don't exist, raw tar
files are required to get extracted and processed.
"""
raw_file_meta
=
{
"train"
:
(
"ILSVRC2012_img_train.tar"
,
"1d675b47d978889d74fa0da5fadfb00e"
),
"val"
:
(
"ILSVRC2012_img_val.tar"
,
"29b22e2961454d5413ddabcf34fc5622"
),
"devkit"
:
(
"ILSVRC2012_devkit_t12.tar.gz"
,
"fa75699e90414af021442c21a62c3abf"
),
}
"""
raw files of ImageNet (train, val, devkit)
"""
}
# ImageNet raw files
default_train_dir
=
"train"
"""
directory of train data
"""
default_val_dir
=
"val"
"""
directory of val data
"""
default_devkit_dir
=
"ILSVRC2012_devkit_t12"
"""
directory of devkit
"""
def
__init__
(
self
,
root
:
str
=
None
,
train
:
bool
=
True
,
**
kwargs
):
r
"""
...
...
@@ -97,13 +85,16 @@ class ImageNet(ImageFolder):
else
:
self
.
root
=
root
self
.
devkit_dir
=
os
.
path
.
join
(
self
.
root
,
self
.
default_devkit_dir
)
if
not
os
.
path
.
exists
(
self
.
root
):
raise
FileNotFoundError
(
"dir %s does not exist"
%
self
.
root
)
self
.
devkit_dir
=
os
.
path
.
join
(
self
.
root
,
self
.
default_devkit_dir
)
if
not
os
.
path
.
exists
(
self
.
devkit_dir
):
logger
.
warning
(
"devkit directory %s does not exists"
%
self
.
devkit_dir
)
self
.
_prepare_devkit
()
self
.
train
=
train
if
train
:
self
.
target_folder
=
os
.
path
.
join
(
self
.
root
,
self
.
default_train_dir
)
...
...
@@ -125,7 +116,7 @@ class ImageNet(ImageFolder):
"extracting raw file shouldn't be done in distributed mode, use single process instead"
)
else
:
self
.
parse
(
train
)
self
.
_prepare_train
()
if
train
else
self
.
_prepare_val
(
)
super
().
__init__
(
self
.
target_folder
,
**
kwargs
)
...
...
@@ -180,14 +171,13 @@ class ImageNet(ImageFolder):
]
)
def
organize_val_data
(
self
):
def
_
organize_val_data
(
self
):
id2wnid
=
self
.
meta
[
0
]
val_idcs
=
self
.
valid_ground_truth
val_wnids
=
[
id2wnid
[
idx
]
for
idx
in
val_idcs
]
raw_val_dir
=
os
.
path
.
join
(
self
.
root
,
"ILSVRC2012_img_val"
)
val_images
=
sorted
(
[
os
.
path
.
join
(
raw_val_dir
,
image
)
for
image
in
os
.
listdir
(
raw_val_di
r
)]
[
os
.
path
.
join
(
self
.
target_folder
,
image
)
for
image
in
os
.
listdir
(
self
.
target_folde
r
)]
)
logger
.
debug
(
"mkdir for val set wnids"
)
...
...
@@ -203,24 +193,41 @@ class ImageNet(ImageFolder):
),
)
def parse(self, train):
    """Extract the raw ImageNet archives found under ``self.root``.

    When ``train`` is truthy the train archive is extracted into
    ``self.target_folder`` and every entry created there is then passed to
    ``untar`` again, using the entry path without its extension as the
    destination. Otherwise the devkit archive and the validation archive are
    extracted and ``self.organize_val_data()`` is called.

    Args:
        train: selects the train branch when truthy, the val branch otherwise.
    """
    root = self.root

    if not train:
        logger.info("process devkit file..")
        devkit_archive = os.path.join(root, self.raw_file_meta["devkit"][0])
        untargz(devkit_archive)

        logger.info("process valid raw file.. this may take 10-20 minutes")
        val_dir = os.path.join(root, "ILSVRC2012_img_val")
        os.makedirs(val_dir, exist_ok=True)
        val_archive = os.path.join(root, self.raw_file_meta["val"][0])
        untar(val_archive, val_dir)
        self.organize_val_data()
        return

    logger.info("process train raw file.. this may take several hours")
    train_archive = os.path.join(root, self.raw_file_meta["train"][0])
    untar(train_archive, self.target_folder)

    # Snapshot the directory listing before the nested extraction starts.
    entries = []
    for name in os.listdir(self.target_folder):
        entries.append(os.path.join(self.target_folder, name))

    for entry in tqdm(entries):
        destination, _ext = os.path.splitext(entry)
        # NOTE(review): remove=True presumably deletes the inner archive
        # after extraction -- confirm against .utils.untar.
        untar(entry, destination, remove=True)
def _prepare_val(self):
    """Verify and extract the validation archive, then organize its images.

    Reads the archive file name and its expected checksum from
    ``self.raw_file_meta["val"]``, compares ``calculate_md5`` of the archive
    under ``self.root`` with that checksum, extracts the archive into
    ``self.target_folder`` and finally calls ``self._organize_val_data()``.

    Raises:
        AssertionError: if ``self.train`` is truthy (while assertions are
            enabled), or if the archive checksum does not match.
    """
    assert not self.train
    raw_filename, checksum = self.raw_file_meta["val"]
    raw_file = os.path.join(self.root, raw_filename)

    logger.info("checksum valid tar file {} ..".format(raw_file))
    # Raise explicitly instead of using ``assert`` so the integrity check is
    # not stripped under ``python -O``. AssertionError is kept so callers
    # that already catch it continue to work.
    if calculate_md5(raw_file) != checksum:
        raise AssertionError(
            "checksum mismatch, {} may be damaged".format(raw_file)
        )

    logger.info("extract valid tar file.. this may take 10-20 minutes")
    # raw_file already contains self.root; joining it with the root a second
    # time produced root/root/<file> whenever root was a relative path.
    untar(raw_file, self.target_folder)
    self._organize_val_data()
def _prepare_train(self):
    """Verify and extract the training archive and its nested archives.

    Reads the archive file name and its expected checksum from
    ``self.raw_file_meta["train"]``, compares ``calculate_md5`` of the
    archive under ``self.root`` with that checksum and extracts the archive
    into ``self.target_folder``. Every entry then present in
    ``self.target_folder`` is passed to ``untar`` again, with the entry path
    minus its extension as the destination.

    Raises:
        AssertionError: if ``self.train`` is falsy (while assertions are
            enabled), or if the archive checksum does not match.
    """
    assert self.train
    raw_filename, checksum = self.raw_file_meta["train"]
    raw_file = os.path.join(self.root, raw_filename)

    logger.info("checksum train tar file {} ..".format(raw_file))
    # Raise explicitly instead of using ``assert`` so the integrity check is
    # not stripped under ``python -O``. AssertionError is kept so callers
    # that already catch it continue to work.
    if calculate_md5(raw_file) != checksum:
        raise AssertionError(
            "checksum mismatch, {} may be damaged".format(raw_file)
        )

    logger.info("extract train tar file.. this may take several hours")
    # raw_file already contains self.root; joining it with the root a second
    # time produced root/root/<file> whenever root was a relative path.
    untar(raw_file, self.target_folder)

    # List the extracted entries once, before the nested extraction begins
    # creating new directories next to them.
    paths = [
        os.path.join(self.target_folder, child_dir)
        for child_dir in os.listdir(self.target_folder)
    ]
    for path in tqdm(paths):
        # NOTE(review): remove=True presumably deletes the inner archive
        # after extraction -- confirm against .utils.untar.
        untar(path, os.path.splitext(path)[0], remove=True)
def _prepare_devkit(self):
    """Verify and extract the devkit archive located under ``self.root``.

    Reads the archive file name and its expected checksum from
    ``self.raw_file_meta["devkit"]``, compares ``calculate_md5`` of the
    archive with that checksum and then passes the archive to ``untargz``.

    Raises:
        AssertionError: if the archive checksum does not match.
    """
    # The devkit entry must be used here: looking up "val" verified the
    # validation tar instead and never checked the devkit archive that is
    # actually extracted below.
    raw_filename, checksum = self.raw_file_meta["devkit"]
    raw_file = os.path.join(self.root, raw_filename)

    logger.info("checksum devkit tar file {} ..".format(raw_file))
    # Raise explicitly instead of using ``assert`` so the integrity check is
    # not stripped under ``python -O``. AssertionError is kept so callers
    # that already catch it continue to work.
    if calculate_md5(raw_file) != checksum:
        raise AssertionError(
            "checksum mismatch, {} may be damaged".format(raw_file)
        )

    logger.info("extract devkit file..")
    # Same path the original built from raw_file_meta["devkit"][0].
    untargz(raw_file)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录