Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
OneFlow-Benchmark
提交
6832aa45
O
OneFlow-Benchmark
项目概览
Oneflow-Inc
/
OneFlow-Benchmark
上一次同步 接近 3 年
通知
1
Star
92
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
OneFlow-Benchmark
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
6832aa45
编写于
3月 21, 2020
作者:
S
ShawnXuan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
oneflow validation
上级
93d3d896
变更
2
隐藏空白更改
内联
并排
Showing
2 changed files
with
105 additions
and
0 deletions
+105
-0
cnn_e2e/of_cnn_val.py
cnn_e2e/of_cnn_val.py
+76
-0
of_val.sh
of_val.sh
+29
-0
未找到文件。
cnn_e2e/of_cnn_val.py
0 → 100755
浏览文件 @
6832aa45
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
time
import
math
import
numpy
as
np
import
config
as
configs
# Parse the command line at import time: the module-level constants, the
# InferenceNet job function and main() below all read their settings from
# `args`.
parser = configs.get_parser()
args = parser.parse_args()
# Hand the parsed arguments to the config module's print helper (presumably
# echoes them to stdout -- print_args is defined outside this file).
configs.print_args(args)
from
util
import
Snapshot
,
Summary
,
InitNodes
,
Metric
import
ofrecord_util
from
job_function_util
import
get_train_config
,
get_val_config
import
oneflow
as
flow
#import vgg_model
import
resnet_model
#import alexnet_model
# ---------------------------------------------------------------------------
# Sizes derived from the command line.
# Global batch sizes are the per-device batch size times the device count
# summed over all nodes.
# ---------------------------------------------------------------------------
total_device_num = args.num_nodes * args.gpu_num_per_node
train_batch_size = total_device_num * args.batch_size_per_device
val_batch_size = total_device_num * args.val_batch_size_per_device

C, H, W = args.image_shape

# Training iterations per epoch are rounded up; validation steps are truncated,
# so a trailing partial validation batch is not evaluated.
epoch_size = math.ceil(args.num_examples / train_batch_size)
num_val_steps = int(args.num_val_examples / val_batch_size)

# Maps the --model argument to the function that builds the network.
# Only resnet50 is available: the vgg16 and alexnet entries stay disabled
# together with their commented-out imports.
#   "vgg16": vgg_model.vgg16
#   "alexnet": alexnet_model.alexnet
model_dict = dict(resnet50=resnet_model.resnet50)

# Process-wide OneFlow configuration, applied once at import time.
flow.config.gpu_device_num(args.gpu_num_per_node)
flow.config.enable_debug_mode(True)
@flow.function(get_val_config(args))
def InferenceNet():
    """Build the inference job: load one batch, run the model, apply softmax.

    The input comes from the ImageNet validation loader when
    ``args.val_data_dir`` is set (the directory must exist), otherwise from
    the synthetic-data loader.

    Returns:
        dict: ``"predictions"`` holds the softmax of the logits produced by
        the model selected with ``args.model``; ``"labels"`` holds the labels
        returned by the data loader.
    """
    if not args.val_data_dir:
        print("Loading synthetic data.")
        labels, images = ofrecord_util.load_synthetic(args)
    else:
        assert os.path.exists(args.val_data_dir)
        print("Loading data from {}".format(args.val_data_dir))
        labels, images = ofrecord_util.load_imagenet_for_validation(args)

    build_network = model_dict[args.model]
    logits = build_network(images)
    return {"predictions": flow.nn.softmax(logits), "labels": labels}
def main():
    """Run validation once per saved epoch snapshot.

    After node and environment setup, each epoch index in
    ``range(args.num_epochs)`` gets:

    * a ``Snapshot`` built for ``<args.model_load_dir>/snapshot_epoch_<n>``,
      with ``n`` counted from 1;
    * a fresh ``Metric`` that accumulates over ``num_val_steps`` batches;
    * ``num_val_steps`` launches of ``InferenceNet``, each result handed
      asynchronously to the metric callback for that (epoch, step).

    Raises:
        AssertionError: if ``args.model_load_dir`` is empty.
    """
    InitNodes(args)
    assert args.model_load_dir, 'must have model load dir'

    flow.env.grpc_use_no_signal()
    flow.env.log_dir(args.log_dir)

    summary = Summary(args.log_dir, args)

    for epoch in range(args.num_epochs):
        # Snapshot directories are numbered from 1, epoch indices from 0.
        snapshot_name = 'snapshot_epoch_{}'.format(epoch + 1)
        epoch_model_dir = os.path.join(args.model_load_dir, snapshot_name)

        # The object is never used afterwards; it is constructed for its side
        # effect (presumably restoring the weights from epoch_model_dir --
        # Snapshot is defined in util, confirm there).
        restored = Snapshot(args.model_save_dir, epoch_model_dir)

        # NOTE(review): 'validataion' is misspelled, but it is a runtime label
        # passed to Metric, so it is kept byte-for-byte here.
        val_metric = Metric(
            desc='validataion',
            calculate_batches=num_val_steps,
            summary=summary,
            save_summary_steps=num_val_steps,
            batch_size=val_batch_size,
        )

        for step in range(num_val_steps):
            # Launch the job first, then attach the callback for this step.
            pending = InferenceNet()
            pending.async_get(val_metric.metric_cb(epoch, step))


# Script entry point: run validation only when executed directly.
if __name__ == "__main__":
    main()
of_val.sh
0 → 100755
浏览文件 @
6832aa45
# Validate saved ResNet-50 snapshots with cnn_e2e/of_cnn_val.py.

# Clean up files matching core.* (presumably core dumps from earlier runs).
rm -rf core.*

# Root directory of the dataset; the train/ and validation/ subdirectories
# are passed to the script below.
DATA_ROOT=/mnt/13_nfs/xuan/ImageNet/ofrecord
#DATA_ROOT=/dataset/ImageNet/ofrecord
#DATA_ROOT=/dataset/imagenet-mxnet

# NOTE(review): --num_nodes=1 is passed together with two --node_ips;
# confirm that the extra address is intentional.
# The $DATA_ROOT expansions are quoted so a root path containing spaces or
# glob characters is still passed as a single argument.
#python3 cnn_benchmark/of_cnn_train_val.py \
#gdb --args \
#nvprof -f -o resnet.nvvp \
python3 cnn_e2e/of_cnn_val.py \
    --model_load_dir=output/models \
    --train_data_dir="$DATA_ROOT/train" \
    --train_data_part_num=256 \
    --val_data_dir="$DATA_ROOT/validation" \
    --val_data_part_num=256 \
    --num_nodes=1 \
    --node_ips='11.11.1.13,11.11.1.14' \
    --gpu_num_per_node=4 \
    --optimizer="momentum-cosine-decay" \
    --learning_rate=0.256 \
    --loss_print_every_n_iter=20 \
    --batch_size_per_device=32 \
    --val_batch_size_per_device=125 \
    --model="resnet50"

# Options that are currently disabled:
#--use_fp16 true \
#--weight_l2=3.0517578125e-05 \
#--num_examples=1024 \
#--optimizer="momentum-decay" \
#--data_dir="/mnt/13_nfs/xuan/ImageNet/ofrecord/train"
#--data_dir="/mnt/dataset/xuan/ImageNet/ofrecord/train"
#--warmup_iter_num=10000 \
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录