Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
OneFlow-Benchmark
提交
68901d68
O
OneFlow-Benchmark
项目概览
Oneflow-Inc
/
OneFlow-Benchmark
上一次同步 2 年多
通知
1
Star
92
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
OneFlow-Benchmark
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
68901d68
编写于
6月 28, 2020
作者:
M
mir-of
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
use argpartition
上级
a72650a5
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
18 addition
and
5 deletion
+18
-5
cnn_benchmark/config.py
cnn_benchmark/config.py
+2
-0
cnn_benchmark/of_cnn_train_val.py
cnn_benchmark/of_cnn_train_val.py
+14
-4
cnn_benchmark/util.py
cnn_benchmark/util.py
+2
-1
未找到文件。
cnn_benchmark/config.py
浏览文件 @
68901d68
...
...
@@ -66,6 +66,7 @@ def get_parser(parser=None):
parser
.
add_argument
(
"--val_batch_size_per_device"
,
type
=
int
,
default
=
8
)
# for data process
parser
.
add_argument
(
"--num_classes"
,
type
=
int
,
default
=
1000
,
help
=
"num of pic classes"
)
parser
.
add_argument
(
"--num_examples"
,
type
=
int
,
default
=
1281167
,
help
=
"train pic number"
)
parser
.
add_argument
(
"--num_val_examples"
,
type
=
int
,
...
...
@@ -78,6 +79,7 @@ def get_parser(parser=None):
default
=
'NHWC'
,
help
=
"NCHW or NHWC"
)
parser
.
add_argument
(
'--image-shape'
,
type
=
int_list
,
default
=
[
3
,
224
,
224
],
help
=
'the image shape feed into the network'
)
parser
.
add_argument
(
'--label-smoothing'
,
type
=
float
,
default
=
0.1
,
help
=
'label smoothing factor'
)
# snapshot
parser
.
add_argument
(
"--model_save_dir"
,
type
=
str
,
...
...
cnn_benchmark/of_cnn_train_val.py
浏览文件 @
68901d68
...
...
@@ -14,7 +14,6 @@ from job_function_util import get_train_config, get_val_config
import
resnet_model
parser
=
configs
.
get_parser
()
args
=
parser
.
parse_args
()
configs
.
print_args
(
args
)
...
...
@@ -41,6 +40,14 @@ if args.use_boxing_v2:
flow
.
config
.
collective_boxing
.
nccl_fusion_all_reduce_use_buffer
(
False
)
def
label_smoothing
(
labels
,
classes
,
eta
,
dtype
):
assert
classes
>
0
assert
eta
>=
0.0
and
eta
<
1.0
return
flow
.
one_hot
(
labels
,
depth
=
classes
,
dtype
=
dtype
,
on_value
=
1
-
eta
+
eta
/
classes
,
off_value
=
eta
/
classes
)
@
flow
.
global_function
(
get_train_config
(
args
))
def
TrainNet
():
if
args
.
train_data_dir
:
...
...
@@ -54,9 +61,12 @@ def TrainNet():
logits
=
model_dict
[
args
.
model
](
images
,
need_transpose
=
False
if
args
.
train_data_dir
else
True
)
loss
=
flow
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
labels
,
logits
,
name
=
"softmax_loss"
)
loss
=
flow
.
math
.
reduce_mean
(
loss
)
# loss = flow.nn.sparse_softmax_cross_entropy_with_logits(
# labels, logits, name="softmax_loss")
# loss = flow.math.reduce_mean(loss)
one_hot_labels
=
label_smoothing
(
labels
,
args
.
num_classes
,
args
.
label_smoothing
,
logits
.
dtype
)
loss
=
flow
.
nn
.
softmax_cross_entropy_with_logits
(
one_hot_labels
,
logits
,
name
=
"softmax_loss"
)
flow
.
losses
.
add_loss
(
loss
)
predictions
=
flow
.
nn
.
softmax
(
logits
)
outputs
=
{
"loss"
:
loss
,
"predictions"
:
predictions
,
"labels"
:
labels
}
...
...
cnn_benchmark/util.py
浏览文件 @
68901d68
...
...
@@ -48,6 +48,7 @@ class Summary(object):
def
__init__
(
self
,
log_dir
,
config
,
filename
=
'summary.csv'
):
self
.
_filename
=
filename
self
.
_log_dir
=
log_dir
if
not
os
.
path
.
exists
(
log_dir
):
os
.
makedirs
(
log_dir
)
self
.
_metrics
=
pd
.
DataFrame
({
"epoch"
:
0
,
"iter"
:
0
,
"legend"
:
"cfg"
,
"note"
:
str
(
config
)},
index
=
[
0
])
def
scalar
(
self
,
legend
,
value
,
epoch
,
step
=-
1
):
...
...
@@ -84,7 +85,7 @@ class StopWatch(object):
def
match_top_k
(
predictions
,
labels
,
top_k
=
1
):
max_k_preds
=
predictions
.
argsort
(
axis
=
1
)[:,
-
top_k
:][:,
::
-
1
]
max_k_preds
=
np
.
argpartition
(
predictions
.
ndarray
(),
-
top_k
)[:,
-
top_k
:
]
match_array
=
np
.
logical_or
.
reduce
(
max_k_preds
==
labels
.
reshape
((
-
1
,
1
)),
axis
=
1
)
num_matched
=
match_array
.
sum
()
return
num_matched
,
match_array
.
shape
[
0
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录