Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
OneFlow-Benchmark
提交
6f7ce3f6
O
OneFlow-Benchmark
项目概览
Oneflow-Inc
/
OneFlow-Benchmark
上一次同步 接近 3 年
通知
1
Star
92
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
OneFlow-Benchmark
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
6f7ce3f6
编写于
2月 08, 2020
作者:
S
ShawnXuan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix dali cpu mode
上级
288e19c4
变更
3
隐藏空白更改
内联
并排
Showing
3 changed files
with
15 additions
and
20 deletions
+15
-20
cnn_benchmark/config.py
cnn_benchmark/config.py
+2
-1
cnn_benchmark/dali.py
cnn_benchmark/dali.py
+8
-15
cnn_benchmark/of_cnn_train_val.py
cnn_benchmark/of_cnn_train_val.py
+5
-4
未找到文件。
cnn_benchmark/config.py
浏览文件 @
6f7ce3f6
...
...
@@ -57,7 +57,8 @@ def get_parser(parser=None):
parser
.
add_argument
(
'--data_train_idx'
,
type
=
str
,
default
=
''
,
help
=
'the index of training data'
)
parser
.
add_argument
(
'--data_val'
,
type
=
str
,
help
=
'the validation data'
)
parser
.
add_argument
(
'--data_val_idx'
,
type
=
str
,
default
=
''
,
help
=
'the index of validation data'
)
parser
.
add_argument
(
"--num_examples"
,
type
=
int
,
default
=
1281167
,
help
=
"imagenet pic number"
)
parser
.
add_argument
(
"--num_examples"
,
type
=
int
,
default
=
1281167
,
help
=
"train pic number"
)
parser
.
add_argument
(
"--num_val_examples"
,
type
=
int
,
default
=
50000
,
help
=
"validation pic number"
)
## snapshot
parser
.
add_argument
(
"--model_save_dir"
,
type
=
str
,
...
...
cnn_benchmark/dali.py
浏览文件 @
6f7ce3f6
...
...
@@ -54,7 +54,6 @@ class HybridTrainPipe(Pipeline):
dali_device
=
"cpu"
if
dali_cpu
else
"mixed"
dali_resize_device
=
"cpu"
if
dali_cpu
else
"gpu"
print
(
dali_device
,
dali_resize_device
)
if
args
.
dali_fuse_decoder
:
self
.
decode
=
ops
.
ImageDecoderRandomCrop
(
device
=
dali_device
,
output_type
=
types
.
RGB
,
device_memory_padding
=
nvjpeg_padding
,
...
...
@@ -68,8 +67,8 @@ class HybridTrainPipe(Pipeline):
self
.
resize
=
ops
.
RandomResizedCrop
(
device
=
dali_resize_device
,
size
=
crop_shape
)
#self.cmnp = ops.CropMirrorNormalize(device=
dali_resize_device, #
"gpu",
self
.
cmnp
=
ops
.
CropMirrorNormalize
(
device
=
"gpu"
,
#self.cmnp = ops.CropMirrorNormalize(device="gpu",
self
.
cmnp
=
ops
.
CropMirrorNormalize
(
device
=
dali_resize_device
,
#
"gpu",
output_dtype
=
types
.
FLOAT16
if
dtype
==
'float16'
else
types
.
FLOAT
,
output_layout
=
output_layout
,
crop
=
crop_shape
,
pad_output
=
pad_output
,
image_type
=
types
.
RGB
,
mean
=
args
.
rgb_mean
,
std
=
args
.
rgb_std
)
...
...
@@ -81,7 +80,7 @@ class HybridTrainPipe(Pipeline):
images
=
self
.
decode
(
self
.
jpegs
)
images
=
self
.
resize
(
images
)
output
=
self
.
cmnp
(
images
.
gpu
()
,
mirror
=
rng
)
output
=
self
.
cmnp
(
images
,
mirror
=
rng
)
return
[
output
,
self
.
labels
]
...
...
@@ -102,10 +101,9 @@ class HybridValPipe(Pipeline):
self
.
decode
=
ops
.
ImageDecoder
(
device
=
"mixed"
,
output_type
=
types
.
RGB
,
device_memory_padding
=
nvjpeg_padding
,
host_memory_padding
=
nvjpeg_padding
)
print
(
dali_device
)
self
.
resize
=
ops
.
Resize
(
device
=
dali_device
,
resize_shorter
=
resize_shp
)
if
resize_shp
else
None
#self.cmnp = ops.CropMirrorNormalize(device=
dali_device,#
"gpu",
self
.
cmnp
=
ops
.
CropMirrorNormalize
(
device
=
"gpu"
,
#self.cmnp = ops.CropMirrorNormalize(device="gpu",
self
.
cmnp
=
ops
.
CropMirrorNormalize
(
device
=
dali_device
,
#
"gpu",
output_dtype
=
types
.
FLOAT16
if
dtype
==
'float16'
else
types
.
FLOAT
,
output_layout
=
output_layout
,
crop
=
crop_shape
,
pad_output
=
pad_output
,
image_type
=
types
.
RGB
,
mean
=
args
.
rgb_mean
,
std
=
args
.
rgb_std
)
...
...
@@ -115,7 +113,7 @@ class HybridValPipe(Pipeline):
images
=
self
.
decode
(
self
.
jpegs
)
if
self
.
resize
:
images
=
self
.
resize
(
images
)
output
=
self
.
cmnp
(
images
.
gpu
()
)
output
=
self
.
cmnp
(
images
)
return
[
output
,
self
.
labels
]
...
...
@@ -279,7 +277,7 @@ class DALIGenericIterator(object):
print
(
"DALI iterator does not support resetting while epoch is not finished. Ignoring..."
)
def
get_rec_iter
(
args
,
dali_cpu
=
False
,
todo
=
True
):
def
get_rec_iter
(
args
,
train_batch_size
,
val_batch_size
,
dali_cpu
=
False
,
todo
=
True
):
# TBD dali_cpu only not work
if
todo
:
gpus
=
[
0
]
...
...
@@ -295,11 +293,6 @@ def get_rec_iter(args, dali_cpu=False, todo=True):
# the input_layout w.r.t. the model is the output_layout of the image pipeline
output_layout
=
types
.
NHWC
if
args
.
input_layout
==
'NHWC'
else
types
.
NCHW
total_device_num
=
args
.
num_nodes
*
args
.
gpu_num_per_node
train_batch_size
=
total_device_num
*
args
.
batch_size_per_device
val_batch_size
=
total_device_num
*
args
.
val_batch_size_per_device
print
(
train_batch_size
,
val_batch_size
)
trainpipes
=
[
HybridTrainPipe
(
args
=
args
,
batch_size
=
train_batch_size
,
num_threads
=
num_threads
,
...
...
@@ -355,7 +348,7 @@ if __name__ == '__main__':
parser
=
configs
.
get_parser
()
args
=
parser
.
parse_args
()
print_args
(
args
)
train_data_iter
,
val_data_iter
=
get_rec_iter
(
args
,
True
)
train_data_iter
,
val_data_iter
=
get_rec_iter
(
args
,
256
,
500
,
True
)
for
epoch
in
range
(
args
.
num_epochs
):
tic
=
time
.
time
()
print
(
'Starting epoch {}'
.
format
(
epoch
))
...
...
cnn_benchmark/of_cnn_train_val.py
浏览文件 @
6f7ce3f6
...
...
@@ -30,6 +30,7 @@ epoch_size = math.ceil(args.num_examples / train_batch_size)
num_train_batches
=
epoch_size
*
args
.
num_epochs
num_warmup_batches
=
epoch_size
*
args
.
warmup_epochs
decay_batches
=
num_train_batches
-
num_warmup_batches
num_val_steps
=
args
.
num_val_examples
/
val_batch_size
summary
=
Summary
(
args
.
log_dir
,
args
)
timer
=
StopWatch
()
...
...
@@ -143,7 +144,7 @@ def train_callback(epoch, step):
def
do_predictions
(
epoch
,
predict_step
,
predictions
):
acc_acc
(
predict_step
,
predictions
)
if
predict_step
+
1
==
args
.
val_step_num
:
if
predict_step
+
1
==
num_val_steps
:
assert
main
.
total
>
0
summary
.
scalar
(
'top1_accuracy'
,
main
.
correct
/
main
.
total
,
epoch
)
#summary.scalar('top1_correct', main.correct, epoch)
...
...
@@ -166,7 +167,7 @@ def main():
snapshot
=
Snapshot
(
args
.
model_save_dir
,
args
.
model_load_dir
)
train_data_iter
,
val_data_iter
=
get_rec_iter
(
args
,
True
)
train_data_iter
,
val_data_iter
=
get_rec_iter
(
args
,
train_batch_size
,
val_batch_size
,
True
)
timer
.
start
()
for
epoch
in
range
(
args
.
num_epochs
):
tic
=
time
.
time
()
...
...
@@ -186,8 +187,8 @@ def main():
for
i
,
batches
in
enumerate
(
val_data_iter
):
assert
len
(
batches
)
==
1
images
,
labels
=
batches
[
0
]
#
InferenceNet(images, labels.astype(np.int32)).async_get(predict_callback(epoch, i))
acc_acc
(
i
,
InferenceNet
(
images
,
labels
.
astype
(
np
.
int32
)).
get
())
InferenceNet
(
images
,
labels
.
astype
(
np
.
int32
)).
async_get
(
predict_callback
(
epoch
,
i
))
#
acc_acc(i, InferenceNet(images, labels.astype(np.int32)).get())
assert
main
.
total
>
0
top1_accuracy
=
main
.
correct
/
main
.
total
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录