Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
OneFlow-Benchmark
提交
6999f340
O
OneFlow-Benchmark
项目概览
Oneflow-Inc
/
OneFlow-Benchmark
上一次同步 接近 3 年
通知
1
Star
92
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
OneFlow-Benchmark
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
6999f340
编写于
5月 29, 2020
作者:
S
ShawnXuan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
new dataloader and boxing v2 params
上级
080de09d
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
18 addition
and
7 deletion
+18
-7
cnn_e2e/of_cnn_train_val.py
cnn_e2e/of_cnn_train_val.py
+14
-4
cnn_e2e/resnet_model.py
cnn_e2e/resnet_model.py
+4
-3
未找到文件。
cnn_e2e/of_cnn_train_val.py
浏览文件 @
6999f340
...
...
@@ -39,18 +39,25 @@ model_dict = {
flow
.
config
.
gpu_device_num
(
args
.
gpu_num_per_node
)
flow
.
config
.
enable_debug_mode
(
True
)
if
args
.
use_boxing_v2
:
flow
.
config
.
collective_boxing
.
nccl_fusion_threshold_mb
(
8
)
flow
.
config
.
collective_boxing
.
nccl_fusion_all_reduce_use_buffer
(
False
)
@
flow
.
function
(
get_train_config
(
args
))
def
TrainNet
():
if
args
.
train_data_dir
:
assert
os
.
path
.
exists
(
args
.
train_data_dir
)
print
(
"Loading data from {}"
.
format
(
args
.
train_data_dir
))
(
labels
,
images
)
=
ofrecord_util
.
load_imagenet_for_training2
(
args
)
if
args
.
use_new_dataloader
:
(
labels
,
images
)
=
ofrecord_util
.
load_imagenet_for_training2
(
args
)
else
:
(
labels
,
images
)
=
ofrecord_util
.
load_imagenet_for_training
(
args
)
# note: images.shape = (N C H W) in cc's new dataloader(load_imagenet_for_training2)
else
:
print
(
"Loading synthetic data."
)
(
labels
,
images
)
=
ofrecord_util
.
load_synthetic
(
args
)
logits
=
model_dict
[
args
.
model
](
images
)
logits
=
model_dict
[
args
.
model
](
images
,
need_transpose
=
not
args
.
use_new_dataloader
)
loss
=
flow
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
labels
,
logits
,
name
=
"softmax_loss"
)
loss
=
flow
.
math
.
reduce_mean
(
loss
)
flow
.
losses
.
add_loss
(
loss
)
...
...
@@ -64,12 +71,15 @@ def InferenceNet():
if
args
.
val_data_dir
:
assert
os
.
path
.
exists
(
args
.
val_data_dir
)
print
(
"Loading data from {}"
.
format
(
args
.
val_data_dir
))
(
labels
,
images
)
=
ofrecord_util
.
load_imagenet_for_validation2
(
args
)
if
args
.
use_new_dataloader
:
(
labels
,
images
)
=
ofrecord_util
.
load_imagenet_for_validation2
(
args
)
else
:
(
labels
,
images
)
=
ofrecord_util
.
load_imagenet_for_validation
(
args
)
else
:
print
(
"Loading synthetic data."
)
(
labels
,
images
)
=
ofrecord_util
.
load_synthetic
(
args
)
logits
=
model_dict
[
args
.
model
](
images
)
logits
=
model_dict
[
args
.
model
](
images
,
need_transpose
=
not
args
.
use_new_dataloader
)
predictions
=
flow
.
nn
.
softmax
(
logits
)
outputs
=
{
"predictions"
:
predictions
,
"labels"
:
labels
}
return
outputs
...
...
cnn_e2e/resnet_model.py
浏览文件 @
6999f340
...
...
@@ -26,7 +26,7 @@ def _conv2d(
):
weight
=
flow
.
get_variable
(
name
+
"-weight"
,
shape
=
(
filters
,
input
.
s
tatic_s
hape
[
1
],
kernel_size
,
kernel_size
),
shape
=
(
filters
,
input
.
shape
[
1
],
kernel_size
,
kernel_size
),
dtype
=
input
.
dtype
,
initializer
=
weight_initializer
,
regularizer
=
weight_regularizer
,
...
...
@@ -125,10 +125,11 @@ def resnet_stem(input):
return
pool1
def
resnet50
(
images
,
trainable
=
True
):
def
resnet50
(
images
,
trainable
=
True
,
need_transpose
=
False
):
# note: images.shape = (N C H W) in cc's new dataloader, transpose is not needed anymore
# images = flow.transpose(images, name="transpose", perm=[0, 3, 1, 2])
if
need_transpose
:
images
=
flow
.
transpose
(
images
,
name
=
"transpose"
,
perm
=
[
0
,
3
,
1
,
2
])
with
flow
.
deprecated
.
variable_scope
(
"Resnet"
):
stem
=
resnet_stem
(
images
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录