diff --git a/model_zoo/official/nlp/bert/run_pretrain.py b/model_zoo/official/nlp/bert/run_pretrain.py
index b8fab8ad2ec2365fe914d4f8000512ecc1e36595..bab9fa2f5830f2bbb33efa46d85032bb5ac0ed65 100644
--- a/model_zoo/official/nlp/bert/run_pretrain.py
+++ b/model_zoo/official/nlp/bert/run_pretrain.py
@@ -64,7 +64,6 @@ def run_pretrain():
     args_opt = parser.parse_args()
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
     context.set_context(reserve_class_name_in_scope=False)
-    context.set_context(variable_memory_max_size="30GB")
     ckpt_save_dir = args_opt.save_checkpoint_path
     if args_opt.distribute == "true":
         if args_opt.device_target == 'Ascend':
@@ -99,47 +98,49 @@ def run_pretrain():
             logger.warning('Gpu only support fp32 temporarily, run with fp32.')
             bert_net_cfg.compute_type = mstype.float32
 
+    ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
+    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
 
-    ds = create_bert_dataset(1, device_num, rank, args_opt.do_shuffle,
-                             args_opt.enable_data_sink, args_opt.data_sink_steps,
-                             args_opt.data_dir, args_opt.schema_dir)
-    new_repeat_count = args_opt.epoch_size
+    new_repeat_count = args_opt.epoch_size * ds.get_dataset_size() // args_opt.data_sink_steps
     if args_opt.train_steps > 0:
-        new_repeat_count = min(args_opt.epoch_size, args_opt.train_steps // args_opt.data_sink_steps)
-    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
+        new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
+    else:
+        args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size()
 
     if cfg.optimizer == 'Lamb':
         lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
                                        end_learning_rate=cfg.Lamb.end_learning_rate,
                                        warmup_steps=cfg.Lamb.warmup_steps,
-                                       decay_steps=ds.get_dataset_size() * new_repeat_count,
+                                       decay_steps=args_opt.train_steps,
                                        power=cfg.Lamb.power)
         params = net_with_loss.trainable_params()
         decay_params = list(filter(cfg.Lamb.decay_filter, params))
         other_params = list(filter(lambda x: x not in decay_params, params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
-                        {'params': other_params}]
+                        {'params': other_params},
+                        {'order_params': params}]
         optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
     elif cfg.optimizer == 'Momentum':
-        optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
+        optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
                              momentum=cfg.Momentum.momentum)
     elif cfg.optimizer == 'AdamWeightDecay':
         lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                        end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                        warmup_steps=cfg.AdamWeightDecay.warmup_steps,
-                                       decay_steps=ds.get_dataset_size() * new_repeat_count,
+                                       decay_steps=args_opt.train_steps,
                                        power=cfg.AdamWeightDecay.power)
         params = net_with_loss.trainable_params()
         decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
         other_params = list(filter(lambda x: x not in decay_params, params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
-                        {'params': other_params, 'weight_decay': 0.0}]
+                        {'params': other_params, 'weight_decay': 0.0},
+                        {'order_params': params}]
         optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
     else:
         raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
                          format(cfg.optimizer))
 
-    callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()]
+    callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()]
     if args_opt.enable_save_ckpt == "true":
         config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                      keep_checkpoint_max=args_opt.save_checkpoint_num)
@@ -148,19 +149,22 @@ def run_pretrain():
     if args_opt.load_checkpoint_path:
         param_dict = load_checkpoint(args_opt.load_checkpoint_path)
-        load_param_into_net(netwithloss, param_dict)
+        load_param_into_net(net_with_loss, param_dict)
 
     if args_opt.enable_lossscale == "true":
         update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                  scale_factor=cfg.scale_factor,
                                                  scale_window=cfg.scale_window)
-        netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
-                                                         scale_update_cell=update_cell)
+        net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
+                                                           scale_update_cell=update_cell)
     else:
-        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
+        net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
+
+    model = Model(net_with_grads)
+    model.train(new_repeat_count, ds, callbacks=callback,
+                dataset_sink_mode=(args_opt.enable_data_sink == "true"), sink_size=args_opt.data_sink_steps)
+
-    model = Model(netwithgrads)
-    model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"))
 
 
 if __name__ == '__main__':
     numpy.random.seed(0)
     run_pretrain()
diff --git a/model_zoo/official/nlp/bert/src/dataset.py b/model_zoo/official/nlp/bert/src/dataset.py
index 097b2c1e8954c9fbeac6ff97fb6f1d651b43bb0d..5b922b9f0b90aa6ff87589a894d470736ca9c598 100644
--- a/model_zoo/official/nlp/bert/src/dataset.py
+++ b/model_zoo/official/nlp/bert/src/dataset.py
@@ -23,11 +23,9 @@ from mindspore import log as logger
 from .config import bert_net_cfg
 
 
-def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", enable_data_sink="true",
-                        data_sink_steps=1, data_dir=None, schema_dir=None):
+def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None, schema_dir=None):
     """create train dataset"""
     # apply repeat operations
-    repeat_count = epoch_size
     files = os.listdir(data_dir)
     data_files = []
     for file_name in files:
@@ -40,11 +38,6 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
                             num_shards=device_num, shard_id=rank, shard_equal_rows=True)
     ori_dataset_size = ds.get_dataset_size()
     print('origin dataset size: ', ori_dataset_size)
-    new_size = ori_dataset_size
-    if enable_data_sink == "true":
-        new_size = data_sink_steps * bert_net_cfg.batch_size
-    ds.set_dataset_size(new_size)
-    new_repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size())
     type_cast_op = C.TypeCast(mstype.int32)
     ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
     ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
@@ -55,8 +48,8 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
     # apply batch operations
     ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
     logger.info("data size: {}".format(ds.get_dataset_size()))
-    logger.info("repeatcount: {}".format(ds.get_repeat_count()))
-    return ds, new_repeat_count
+    logger.info("repeat count: {}".format(ds.get_repeat_count()))
+    return ds
 
 
 def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy",
diff --git a/model_zoo/official/nlp/bert/src/utils.py b/model_zoo/official/nlp/bert/src/utils.py
index 775931b23af169fb6dc1b5a927536d4a66a8e642..6e8ea6ed643251aac402aa19d4e11ba57b147fa8 100644
--- a/model_zoo/official/nlp/bert/src/utils.py
+++ b/model_zoo/official/nlp/bert/src/utils.py
@@ -18,6 +18,7 @@ Functional Cells used in Bert finetune and evaluation.
 """
 
 import os
+import numpy as np
 import mindspore.nn as nn
 from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor
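
Note: the step accounting introduced in run_pretrain.py above reduces to the arithmetic sketched below. This is a minimal standalone illustration, not part of the patch; the concrete values for epoch_size, data_sink_steps, and the dataset size are hypothetical examples, and train_steps = 0 stands for the "--train_steps not given" case.

# Sketch of the new sink-based step accounting (illustrative values only).
epoch_size = 40          # hypothetical --epoch_size
data_sink_steps = 100    # hypothetical --data_sink_steps
dataset_size = 2500      # hypothetical ds.get_dataset_size() after batching
train_steps = 0          # hypothetical --train_steps; 0 means "not given"

# Each sink iteration consumes data_sink_steps batches, so the repeat count
# passed to Model.train() becomes total batches // data_sink_steps.
new_repeat_count = epoch_size * dataset_size // data_sink_steps
if train_steps > 0:
    new_repeat_count = min(new_repeat_count, train_steps // data_sink_steps)
else:
    train_steps = epoch_size * dataset_size  # reused as decay_steps for the LR schedules

print(new_repeat_count, train_steps)  # -> 1000 sink iterations, 100000 decay steps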