提交 b3b71e1d 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3724 modify readme and timemoniter steps

Merge pull request !3724 from wanghua/r0.6
# TinyBERT Example
## Description
This example implements general distill and task distill of [BERT-base](https://github.com/google-research/bert)(the base version of BERT model).
[TinyBERT](https://github.com/huawei-noah/Pretrained-Model/tree/master/TinyBERT) is 7.5x smalller and 9.4x faster on inference than [BERT-base](https://github.com/google-research/bert)(the base version of BERT model) and achieves competitive performances in the tasks of natural language understanding. It performs a novel transformer distillation at both the pre-training and task-specific learning stages.
## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
......
......@@ -87,8 +87,10 @@ def run_general_distill():
if args_opt.enable_data_sink == "true":
repeat_count = args_opt.epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps
time_monitor_steps = args_opt.data_sink_steps
else:
repeat_count = args_opt.epoch_size
time_monitor_steps = dataset_size
lr_schedule = BertLearningRate(learning_rate=common_cfg.AdamWeightDecay.learning_rate,
end_learning_rate=common_cfg.AdamWeightDecay.end_learning_rate,
......@@ -104,10 +106,10 @@ def run_general_distill():
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=common_cfg.AdamWeightDecay.eps)
callback = [TimeMonitor(dataset_size), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
args_opt.save_ckpt_step,
args_opt.max_ckpt_num,
save_ckpt_dir)]
callback = [TimeMonitor(time_monitor_steps), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
args_opt.save_ckpt_step,
args_opt.max_ckpt_num,
save_ckpt_dir)]
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value,
scale_factor=common_cfg.scale_factor,
......
......@@ -92,8 +92,10 @@ def run_predistill():
dataset_size = dataset.get_dataset_size()
if args_opt.enable_data_sink == 'true':
repeat_count = args_opt.td_phase1_epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps
time_monitor_steps = args_opt.data_sink_steps
else:
repeat_count = args_opt.td_phase1_epoch_size
time_monitor_steps = dataset_size
optimizer_cfg = cfg.optimizer_cfg
......@@ -110,10 +112,10 @@ def run_predistill():
{'order_params': params}]
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
callback = [TimeMonitor(dataset_size), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
args_opt.save_ckpt_step,
args_opt.max_ckpt_num,
td_phase1_save_ckpt_dir)]
callback = [TimeMonitor(time_monitor_steps), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
args_opt.save_ckpt_step,
args_opt.max_ckpt_num,
td_phase1_save_ckpt_dir)]
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
scale_factor=cfg.scale_factor,
scale_window=cfg.scale_window)
......@@ -147,8 +149,10 @@ def run_task_distill(ckpt_file):
dataset_size = train_dataset.get_dataset_size()
if args_opt.enable_data_sink == 'true':
repeat_count = args_opt.td_phase2_epoch_size * train_dataset.get_dataset_size() // args_opt.data_sink_steps
time_monitor_steps = args_opt.data_sink_steps
else:
repeat_count = args_opt.td_phase2_epoch_size
time_monitor_steps = dataset_size
optimizer_cfg = cfg.optimizer_cfg
......@@ -170,14 +174,14 @@ def run_task_distill(ckpt_file):
device_num, rank, args_opt.do_shuffle,
args_opt.eval_data_dir, args_opt.schema_dir)
if args_opt.do_eval.lower() == "true":
callback = [TimeMonitor(dataset_size), LossCallBack(),
callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
ModelSaveCkpt(netwithloss.bert,
args_opt.save_ckpt_step,
args_opt.max_ckpt_num,
td_phase2_save_ckpt_dir),
EvalCallBack(netwithloss.bert, eval_dataset)]
else:
callback = [TimeMonitor(dataset_size), LossCallBack(),
callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
ModelSaveCkpt(netwithloss.bert,
args_opt.save_ckpt_step,
args_opt.max_ckpt_num,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册