Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
OneFlow-Benchmark
提交
7d38d810
O
OneFlow-Benchmark
项目概览
Oneflow-Inc
/
OneFlow-Benchmark
上一次同步 2 年多
通知
1
Star
92
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
OneFlow-Benchmark
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
7d38d810
编写于
9月 09, 2020
作者:
S
ShawnXuan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
bert add xla option
上级
6673c2dc
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
13 addition
and
9 deletion
+13
-9
LanguageModeling/BERT/config.py
LanguageModeling/BERT/config.py
+7
-5
LanguageModeling/BERT/util.py
LanguageModeling/BERT/util.py
+6
-4
未找到文件。
LanguageModeling/BERT/config.py
浏览文件 @
7d38d810
...
...
@@ -48,14 +48,16 @@ def get_parser(parser=None):
help
=
'node/machine number for training'
)
parser
.
add_argument
(
'--node_ips'
,
type
=
str_list
,
default
=
[
'192.168.1.13'
,
'192.168.1.14'
],
help
=
'nodes ip list for training, devided by ",", length >= num_nodes'
)
# train
parser
.
add_argument
(
"--learning_rate"
,
type
=
float
,
default
=
1e-4
,
help
=
"Learning rate"
)
parser
.
add_argument
(
"--weight_decay_rate"
,
type
=
float
,
default
=
0.01
,
help
=
"weight decay rate"
)
parser
.
add_argument
(
"--warmup_proportion"
,
type
=
float
,
default
=
0.1
)
parser
.
add_argument
(
'--use_fp16'
,
type
=
str2bool
,
nargs
=
'?'
,
default
=
'False'
,
const
=
True
,
parser
.
add_argument
(
'--use_fp16'
,
type
=
str2bool
,
nargs
=
'?'
,
default
=
'False'
,
const
=
True
,
help
=
'use use fp16 or not'
)
parser
.
add_argument
(
'--use_xla'
,
type
=
str2bool
,
nargs
=
'?'
,
const
=
True
,
help
=
'Whether to use use xla'
)
# log and resore/save
parser
.
add_argument
(
"--loss_print_every_n_iter"
,
type
=
int
,
default
=
10
,
required
=
False
,
help
=
"print loss every n iteration"
)
...
...
@@ -68,7 +70,7 @@ def get_parser(parser=None):
help
=
"save model snapshot for last iteration"
)
parser
.
add_argument
(
"--model_load_dir"
,
type
=
str
,
default
=
None
,
help
=
"model load directory"
)
parser
.
add_argument
(
"--log_dir"
,
type
=
str
,
default
=
"./output"
,
help
=
"log info save directory"
)
# bert backbone
parser
.
add_argument
(
'--do_lower_case'
,
type
=
str2bool
,
nargs
=
'?'
,
const
=
True
,
default
=
'True'
)
parser
.
add_argument
(
"--seq_length"
,
type
=
int
,
default
=
512
)
...
...
@@ -81,7 +83,7 @@ def get_parser(parser=None):
parser
.
add_argument
(
"--attention_probs_dropout_prob"
,
type
=
float
,
default
=
0.1
)
parser
.
add_argument
(
"--hidden_dropout_prob"
,
type
=
float
,
default
=
0.1
)
parser
.
add_argument
(
"--hidden_size_per_head"
,
type
=
int
,
default
=
64
)
return
parser
...
...
LanguageModeling/BERT/util.py
浏览文件 @
7d38d810
...
...
@@ -131,7 +131,7 @@ class Metric(object):
self
.
metric_dict
[
key
]
=
0.0
self
.
metric_dict
[
'throughput'
]
=
0.0
self
.
num_samples
=
0.0
def
update_and_save
(
self
,
key
,
value
,
step
,
**
kwargs
):
self
.
metric_dict
[
key
]
=
value
if
self
.
save_summary
:
...
...
@@ -164,14 +164,16 @@ class Metric(object):
def
CreateOptimizer
(
args
):
warmup_batches
=
int
(
args
.
iter_num
*
args
.
warmup_proportion
)
lr_warmup
=
flow
.
optimizer
.
warmup
.
linear
(
warmup_batches
,
0
)
lr_scheduler
=
flow
.
optimizer
.
PolynomialSchduler
(
args
.
learning_rate
,
args
.
iter_num
,
0.0
,
lr_scheduler
=
flow
.
optimizer
.
PolynomialSchduler
(
args
.
learning_rate
,
args
.
iter_num
,
0.0
,
warmup
=
lr_warmup
)
return
flow
.
optimizer
.
AdamW
(
lr_scheduler
,
epsilon
=
1e-6
,
weight_decay
=
args
.
weight_decay_rate
,
return
flow
.
optimizer
.
AdamW
(
lr_scheduler
,
epsilon
=
1e-6
,
weight_decay
=
args
.
weight_decay_rate
,
weight_decay_excludes
=
[
"bias"
,
"LayerNorm"
,
"layer_norm"
],
grad_clipping
=
flow
.
optimizer
.
grad_clipping
.
by_global_norm
(
1.0
))
def
GetFunctionConfig
(
args
):
config
=
flow
.
function_config
()
config
.
enable_auto_mixed_precision
(
args
.
use_fp16
)
if
args
.
use_xla
:
config
.
use_xla_jit
(
True
)
return
config
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录