Commit 301b01e4 authored by chenhaozhe

sync some bugfix of bert scripts to branch r0.5

Parent: d44cf6a0
@@ -18,6 +18,7 @@ python run_pretrain.py
 """
 import os
+import math
 import argparse
 import numpy
 import mindspore.communication.management as D
@@ -44,15 +45,16 @@ class LossCallBack(Callback):
     Args:
         per_print_times (int): Print loss every times. Default: 1.
     """
-    def __init__(self, per_print_times=1):
+    def __init__(self, data_epoch_size=1):
         super(LossCallBack, self).__init__()
-        if not isinstance(per_print_times, int) or per_print_times < 0:
-            raise ValueError("print_step must be int and >= 0")
-        self._per_print_times = per_print_times
+        if not isinstance(data_epoch_size, int) or data_epoch_size < 0:
+            raise ValueError("data_epoch_size must be int and >= 0")
+        self._data_epoch_size = data_epoch_size

     def step_end(self, run_context):
         cb_params = run_context.original_args()
-        print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
-                                                           str(cb_params.net_outputs)))
+        percent, epoch = math.modf(cb_params.cur_epoch_num / self._data_epoch_size)
+        print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}"
+              .format(epoch, "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)))

 def run_pretrain():
     """pre-train bert_clue"""
@@ -120,6 +122,7 @@ def run_pretrain():
     ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle,
                                                args_opt.enable_data_sink, args_opt.data_sink_steps,
                                                args_opt.data_dir, args_opt.schema_dir)
+    data_epoch_size = new_repeat_count // args_opt.epoch_size  # Epoch nums in one dataset.
     if args_opt.train_steps > 0:
         new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
     netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
@@ -144,7 +147,7 @@ def run_pretrain():
     else:
         raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]".
                          format(cfg.optimizer))
-    callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()]
+    callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack(data_epoch_size)]
     if args_opt.enable_save_ckpt == "true":
         config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                      keep_checkpoint_max=args_opt.save_checkpoint_num)
...
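
The data_epoch_size bookkeeping above is plain integer arithmetic; a sketch with hypothetical numbers (in the script, new_repeat_count is whatever create_bert_dataset actually returns):

# Illustration only; these values are made up.
epoch_size = 10        # epochs requested on the command line
new_repeat_count = 40  # sink-mode repeats reported by the dataset helper

# Sink-mode repeats that make up one requested epoch.
data_epoch_size = new_repeat_count // epoch_size
assert data_epoch_size == 4

# LossCallBack(data_epoch_size) can then map the sink epoch counter back
# to a requested-epoch index plus a percentage, as sketched earlier.
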
@@ -54,7 +54,7 @@ do
     export GLOG_log_dir=${CUR_DIR}/ms_log
     export GLOG_logtostderr=0
     env > env.log
-    taskset -c $cmdopt python ../run_pretrain.py \
+    taskset -c $cmdopt nohup python ../run_pretrain.py \
         --distribute="true" \
         --epoch_size=$EPOCH_SIZE \
         --device_id=$DEVICE_ID \
...
@@ -29,7 +29,7 @@ mkdir -p ms_log
 CUR_DIR=`pwd`
 export GLOG_log_dir=${CUR_DIR}/ms_log
 export GLOG_logtostderr=0
-python run_pretrain.py \
+nohup python run_pretrain.py \
     --distribute="false" \
     --epoch_size=$EPOCH_SIZE \
     --device_id=$DEVICE_ID \
...
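
Both launch scripts now prefix the training command with nohup, so the job keeps running if the terminal that started it hangs up (stdout goes to nohup.out unless redirected). That is the scripts' own mechanism; purely as an illustration of the same idea from Python rather than a shell, a detached spawn could look like the sketch below. The log file name and argument list are hypothetical, not part of the commit.

import subprocess

# Roughly equivalent in spirit to `nohup python run_pretrain.py ...`:
# start the trainer in a new session so it ignores terminal hangups,
# and send its output to a log file instead of the terminal.
with open("pretrain.log", "w") as log:
    subprocess.Popen(
        ["python", "run_pretrain.py", "--distribute=false"],
        stdout=log,
        stderr=subprocess.STDOUT,
        start_new_session=True,  # detach from the controlling terminal
    )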