Commit 6fdf3809 authored by chenhaozhe

fix bert scripts to adapt to the new concept of repeat count in minddata

Parent ad651f38
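With the new MindData semantics, create_bert_dataset no longer repeats or resizes the dataset; run_pretrain.py now derives the number of sink iterations itself and passes sink_size to model.train. Below is a minimal, pure-Python sketch of that derivation; the function name compute_new_repeat_count and the sample numbers are illustrative only and are not part of this commit.

def compute_new_repeat_count(epoch_size, dataset_size, data_sink_steps, train_steps=0):
    """Mirror of the repeat-count logic added to run_pretrain.py.

    dataset_size is the number of batches per epoch reported by
    ds.get_dataset_size(); with data sinking enabled, each call into the
    device consumes data_sink_steps batches.
    """
    new_repeat_count = epoch_size * dataset_size // data_sink_steps
    if train_steps > 0:
        # An explicit step budget caps the number of sink iterations.
        new_repeat_count = min(new_repeat_count, train_steps // data_sink_steps)
    else:
        train_steps = epoch_size * dataset_size
    return new_repeat_count, train_steps

# Hypothetical numbers: 40 epochs over 10000 batches, sinking 100 steps at a time,
# give 4000 sink iterations and a 400000-step optimizer budget.
print(compute_new_repeat_count(40, 10000, 100))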
@@ -64,7 +64,6 @@ def run_pretrain():
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
context.set_context(reserve_class_name_in_scope=False)
context.set_context(variable_memory_max_size="30GB")
ckpt_save_dir = args_opt.save_checkpoint_path
if args_opt.distribute == "true":
if args_opt.device_target == 'Ascend':
@@ -99,47 +98,49 @@ def run_pretrain():
logger.warning('Gpu only support fp32 temporarily, run with fp32.')
bert_net_cfg.compute_type = mstype.float32
ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
ds = create_bert_dataset(1, device_num, rank, args_opt.do_shuffle,
args_opt.enable_data_sink, args_opt.data_sink_steps,
args_opt.data_dir, args_opt.schema_dir)
new_repeat_count = args_opt.epoch_size
new_repeat_count = args_opt.epoch_size * ds.get_dataset_size() // args_opt.data_sink_steps
if args_opt.train_steps > 0:
new_repeat_count = min(args_opt.epoch_size, args_opt.train_steps // args_opt.data_sink_steps)
netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
else:
args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size()
if cfg.optimizer == 'Lamb':
lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
end_learning_rate=cfg.Lamb.end_learning_rate,
warmup_steps=cfg.Lamb.warmup_steps,
decay_steps=ds.get_dataset_size() * new_repeat_count,
decay_steps=args_opt.train_steps,
power=cfg.Lamb.power)
params = net_with_loss.trainable_params()
decay_params = list(filter(cfg.Lamb.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
{'params': other_params}]
{'params': other_params},
{'order_params': params}]
optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
elif cfg.optimizer == 'Momentum':
optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
momentum=cfg.Momentum.momentum)
elif cfg.optimizer == 'AdamWeightDecay':
lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
warmup_steps=cfg.AdamWeightDecay.warmup_steps,
decay_steps=ds.get_dataset_size() * new_repeat_count,
decay_steps=args_opt.train_steps,
power=cfg.AdamWeightDecay.power)
params = net_with_loss.trainable_params()
decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0}]
{'params': other_params, 'weight_decay': 0.0},
{'order_params': params}]
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
else:
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
format(cfg.optimizer))
callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()]
callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()]
if args_opt.enable_save_ckpt == "true":
config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
keep_checkpoint_max=args_opt.save_checkpoint_num)
@@ -148,19 +149,22 @@ def run_pretrain():
if args_opt.load_checkpoint_path:
param_dict = load_checkpoint(args_opt.load_checkpoint_path)
load_param_into_net(netwithloss, param_dict)
load_param_into_net(net_with_loss, param_dict)
if args_opt.enable_lossscale == "true":
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
scale_factor=cfg.scale_factor,
scale_window=cfg.scale_window)
netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
scale_update_cell=update_cell)
net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
scale_update_cell=update_cell)
else:
netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
model = Model(net_with_grads)
model.train(new_repeat_count, ds, callbacks=callback,
dataset_sink_mode=(args_opt.enable_data_sink == "true"), sink_size=args_opt.data_sink_steps)
model = Model(netwithgrads)
model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"))
if __name__ == '__main__':
numpy.random.seed(0)
run_pretrain()
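Both optimizer branches above now pass decay_steps=args_opt.train_steps to BertLearningRate, so the decay horizon is the total optimizer step budget instead of ds.get_dataset_size() * new_repeat_count. The snippet below is a rough, pure-Python illustration of a warmup-then-polynomial-decay schedule of that shape; it is an assumption about the general behaviour, not the exact BertLearningRate implementation.

def warmup_poly_decay_lr(step, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
    """Illustrative warmup followed by polynomial decay, evaluated per global step."""
    if step < warmup_steps:
        # Linear warmup from zero up to the base learning rate.
        return learning_rate * (step + 1) / warmup_steps
    # Polynomial decay from learning_rate down to end_learning_rate over decay_steps.
    progress = min(step, decay_steps) / decay_steps
    return (learning_rate - end_learning_rate) * (1 - progress) ** power + end_learning_rate

# With decay_steps equal to the whole train_steps budget (hypothetical values below),
# the schedule reaches end_learning_rate exactly when training finishes.
print(warmup_poly_decay_lr(400000, 3e-4, 1e-7, 10000, 400000, 1.0))  # ~1e-7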
@@ -23,11 +23,9 @@ from mindspore import log as logger
from .config import bert_net_cfg
def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", enable_data_sink="true",
data_sink_steps=1, data_dir=None, schema_dir=None):
def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None, schema_dir=None):
"""create train dataset"""
# apply repeat operations
repeat_count = epoch_size
files = os.listdir(data_dir)
data_files = []
for file_name in files:
@@ -40,11 +38,6 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
num_shards=device_num, shard_id=rank, shard_equal_rows=True)
ori_dataset_size = ds.get_dataset_size()
print('origin dataset size: ', ori_dataset_size)
new_size = ori_dataset_size
if enable_data_sink == "true":
new_size = data_sink_steps * bert_net_cfg.batch_size
ds.set_dataset_size(new_size)
new_repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size())
type_cast_op = C.TypeCast(mstype.int32)
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
@@ -55,8 +48,8 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
# apply batch operations
ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
logger.info("data size: {}".format(ds.get_dataset_size()))
logger.info("repeatcount: {}".format(ds.get_repeat_count()))
return ds, new_repeat_count
logger.info("repeat count: {}".format(ds.get_repeat_count()))
return ds
def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy",
......
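The removed set_dataset_size block above used to shrink the logical dataset to data_sink_steps * batch_size rows and then inflate the repeat count to compensate, whereas the new code keeps the dataset at its natural size and lets model.train(..., sink_size=data_sink_steps) do the slicing. The following pure-Python check sketches why both schemes consume the same number of batches; all constants are made-up illustration values.

BATCH_SIZE = 32          # hypothetical bert_net_cfg.batch_size
DATA_SINK_STEPS = 100    # hypothetical args_opt.data_sink_steps
EPOCH_SIZE = 40          # hypothetical args_opt.epoch_size
ORI_ROWS = 320000        # hypothetical row count of the MindRecord files

# Old scheme (removed): truncate the dataset to data_sink_steps * batch_size rows,
# then scale the repeat count so the same amount of data is still consumed.
old_rows_per_epoch = DATA_SINK_STEPS * BATCH_SIZE
old_repeat_count = EPOCH_SIZE * ORI_ROWS // old_rows_per_epoch
old_total_batches = old_repeat_count * DATA_SINK_STEPS

# New scheme: keep the natural dataset size; run_pretrain.py computes the number of
# sink iterations and model.train slices data_sink_steps batches per iteration.
batches_per_epoch = ORI_ROWS // BATCH_SIZE   # ds.get_dataset_size() after .batch()
new_repeat_count = EPOCH_SIZE * batches_per_epoch // DATA_SINK_STEPS
new_total_batches = new_repeat_count * DATA_SINK_STEPS

assert old_total_batches == new_total_batches == 400000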
@@ -18,6 +18,7 @@ Functional Cells used in Bert finetune and evaluation.
"""
import os
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
......