Repository: magicwindyyd / mindspore (forked from MindSpore / mindspore; up to date with the fork source)
Commit 301b01e4
Authored on Aug 19, 2020
Author: chenhaozhe

sync some bugfix of bert scripts to branch r0.5

Parent: d44cf6a0
Changes: 3

Showing 3 changed files with 12 additions and 9 deletions (+12 -9)

model_zoo/bert/run_pretrain.py  +10 -7
model_zoo/bert/scripts/run_distribute_pretrain.sh  +1 -1
model_zoo/bert/scripts/run_standalone_pretrain.sh  +1 -1

model_zoo/bert/run_pretrain.py
@@ -18,6 +18,7 @@ python run_pretrain.py
 """
 import os
+import math
 import argparse
 import numpy
 import mindspore.communication.management as D
@@ -44,15 +45,16 @@ class LossCallBack(Callback):
     Args:
         per_print_times (int): Print loss every times. Default: 1.
     """
-    def __init__(self, per_print_times=1):
+    def __init__(self, data_epoch_size=1):
         super(LossCallBack, self).__init__()
-        if not isinstance(per_print_times, int) or per_print_times < 0:
-            raise ValueError("print_step must be int and >= 0")
-        self._per_print_times = per_print_times
+        if not isinstance(data_epoch_size, int) or data_epoch_size < 0:
+            raise ValueError("data_epoch_size must be int and >= 0")
+        self._data_epoch_size = data_epoch_size
     def step_end(self, run_context):
         cb_params = run_context.original_args()
-        print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
-                                                           str(cb_params.net_outputs)))
+        percent, epoch = math.modf(cb_params.cur_epoch_num / self._data_epoch_size)
+        print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}"
+              .format(epoch, "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)))
 def run_pretrain():
     """pre-train bert_clue"""
@@ -120,6 +122,7 @@ def run_pretrain():
     ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle,
                                                args_opt.enable_data_sink, args_opt.data_sink_steps,
                                                args_opt.data_dir, args_opt.schema_dir)
+    data_epoch_size = new_repeat_count // args_opt.epoch_size  # Epoch nums in one dataset.
     if args_opt.train_steps > 0:
         new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
     netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
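
The arithmetic added in this hunk can be tried in isolation. The sketch below is only an illustration: the helper name `plan_repeat_count` and the sample numbers are hypothetical, and it does not reproduce how `create_bert_dataset` derives `new_repeat_count`, which this diff does not show. Only the two expressions mirror the hunk.

```python
def plan_repeat_count(new_repeat_count, epoch_size, train_steps, data_sink_steps):
    """Hypothetical helper mirroring the two expressions in the hunk above."""
    # Computed before the cap below, so train_steps does not affect it.
    data_epoch_size = new_repeat_count // epoch_size
    if train_steps > 0:
        # Cap the repeat count by the number of sink iterations train_steps allows.
        new_repeat_count = min(new_repeat_count, train_steps // data_sink_steps)
    return data_epoch_size, new_repeat_count


# Made-up values, chosen only to exercise both branches.
print(plan_repeat_count(400, 40, 2500, 100))  # cap applies
print(plan_repeat_count(400, 40, -1, 100))    # train_steps <= 0: no cap
```

Because `data_epoch_size` is taken before the `train_steps` cap, it stays the same whether or not the cap shortens the run. Note that the floor division yields 0 whenever `new_repeat_count` is smaller than `epoch_size`.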
@@ -144,7 +147,7 @@ def run_pretrain():
     else:
         raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]".
                          format(cfg.optimizer))
-    callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()]
+    callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack(data_epoch_size)]
     if args_opt.enable_save_ckpt == "true":
         config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                      keep_checkpoint_max=args_opt.save_checkpoint_num)
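
With `data_epoch_size` now passed to `LossCallBack`, the log line reports a whole epoch count plus a fraction of the current epoch. A minimal sketch of that split, assuming plain integers in place of the MindSpore callback parameters; the helper name `split_progress` and the sample values are hypothetical:

```python
import math


def split_progress(cur_epoch_num, data_epoch_size):
    """Hypothetical helper using the same math.modf call as LossCallBack.step_end."""
    # math.modf returns (fractional part, integer part), both as floats.
    percent, epoch = math.modf(cur_epoch_num / data_epoch_size)
    return epoch, "%.3f" % percent


print(split_progress(25, 10))  # (2.0, '0.500')
print(split_progress(10, 10))  # (1.0, '0.000')
```

Two details follow from this. The epoch is printed as a float (for example `2.0`), since `math.modf` returns floats. And a `data_epoch_size` of 0 passes the `< 0` check in `__init__` but would raise `ZeroDivisionError` at the first `step_end` call.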

model_zoo/bert/scripts/run_distribute_pretrain.sh
@@ -54,7 +54,7 @@ do
     export GLOG_log_dir=${CUR_DIR}/ms_log
     export GLOG_logtostderr=0
     env > env.log
-    taskset -c $cmdopt python ../run_pretrain.py \
+    taskset -c $cmdopt nohup python ../run_pretrain.py \
     --distribute="true" \
     --epoch_size=$EPOCH_SIZE \
     --device_id=$DEVICE_ID \

model_zoo/bert/scripts/run_standalone_pretrain.sh
@@ -29,7 +29,7 @@ mkdir -p ms_log
 CUR_DIR=`pwd`
 export GLOG_log_dir=${CUR_DIR}/ms_log
 export GLOG_logtostderr=0
-python run_pretrain.py \
+nohup python run_pretrain.py \
     --distribute="false" \
     --epoch_size=$EPOCH_SIZE \
     --device_id=$DEVICE_ID \