Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
OneFlow-Benchmark
提交
6efb07fe
O
OneFlow-Benchmark
项目概览
Oneflow-Inc
/
OneFlow-Benchmark
上一次同步 2 年多
通知
0
Star
92
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
OneFlow-Benchmark
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
6efb07fe
编写于
2月 15, 2020
作者:
M
mir-of
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix flow.function in of_cnn_infer_benchmarks.py
上级
bfc921b0
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
17 addition
and
16 deletion
+17
-16
cnn_benchmark/of_cnn_infer_benchmarks.py
cnn_benchmark/of_cnn_infer_benchmarks.py
+17
-16
未找到文件。
cnn_benchmark/of_cnn_infer_benchmarks.py
浏览文件 @
6efb07fe
...
...
@@ -114,24 +114,28 @@ model_dict = {
"alexnet"
:
alexnet_model
.
alexnet
,
}
func_config
=
flow
.
FunctionConfig
()
func_config
.
default_data_type
(
flow
.
float
)
flow
.
config
.
gpu_device_num
(
args
.
gpu_num_per_node
)
if
args
.
use_tensorrt
:
func_config
.
use_tensorrt
()
if
args
.
use_xla_jit
:
func_config
.
use_xla_jit
()
if
args
.
precision
==
"float16"
:
if
not
args
.
use_tensorrt
:
func_config
.
enable_auto_mixed_precision
()
else
:
func_config
.
tensorrt
.
use_fp16
()
@
flow
.
function
@
flow
.
function
(
func_config
)
def
InferenceNet
():
total_device_num
=
args
.
node_num
*
args
.
gpu_num_per_node
batch_size
=
total_device_num
*
args
.
batch_size_per_device
if
args
.
use_tensorrt
:
flow
.
config
.
use_tensorrt
()
if
args
.
use_xla_jit
:
flow
.
config
.
use_xla_jit
()
if
args
.
precision
==
"float16"
:
if
not
args
.
use_tensorrt
:
flow
.
config
.
enable_auto_mixed_precision
()
else
:
flow
.
config
.
tensorrt
.
use_fp16
()
if
args
.
data_dir
:
assert
os
.
path
.
exists
(
args
.
data_dir
)
print
(
"Loading data from {}"
.
format
(
args
.
data_dir
))
...
...
@@ -159,12 +163,9 @@ def main():
print
(
"{} = {}"
.
format
(
arg
,
getattr
(
args
,
arg
)))
print
(
"-"
.
ljust
(
66
,
"-"
))
print
(
"Time stamp: {}"
.
format
(
str
(
datetime
.
now
().
strftime
(
"%Y-%m-%d-%H:%M:%S"
))))
flow
.
config
.
default_data_type
(
flow
.
float
)
flow
.
config
.
gpu_device_num
(
args
.
gpu_num_per_node
)
flow
.
env
.
grpc_use_no_signal
()
flow
.
env
.
log_dir
(
args
.
log_dir
)
# flow.config.enable_inplace(False)
# flow.config.ctrl_port(12140)
if
args
.
node_num
>
1
:
nodes
=
[]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录