From 12d9c71c8378ca87ac247ece5691f15454b9a109 Mon Sep 17 00:00:00 2001 From: yoonlee666 Date: Thu, 21 May 2020 22:08:36 +0800 Subject: [PATCH] delete dropout when prob equals 0 and adjust bert ci script --- example/bert_clue/dataset.py | 4 ++-- example/bert_clue/run_pretrain.py | 11 ++++++----- mindspore/nn/layer/basic.py | 3 +++ .../networks/models/bert/bert_tdt_lossscale.py | 16 +++++++++------- 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/example/bert_clue/dataset.py b/example/bert_clue/dataset.py index d54f2a666..9dbe7b8ce 100644 --- a/example/bert_clue/dataset.py +++ b/example/bert_clue/dataset.py @@ -42,7 +42,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e if enable_data_sink == "true": new_size = data_sink_steps * bert_net_cfg.batch_size ds.set_dataset_size(new_size) - repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size()) + new_repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size()) type_cast_op = C.TypeCast(mstype.int32) ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op) @@ -55,4 +55,4 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e ds = ds.repeat(repeat_count) logger.info("data size: {}".format(ds.get_dataset_size())) logger.info("repeatcount: {}".format(ds.get_repeat_count())) - return ds + return ds, new_repeat_count diff --git a/example/bert_clue/run_pretrain.py b/example/bert_clue/run_pretrain.py index 6b8127dda..c587d41bc 100644 --- a/example/bert_clue/run_pretrain.py +++ b/example/bert_clue/run_pretrain.py @@ -24,7 +24,7 @@ from mindspore import context from mindspore.train.model import Model from mindspore.train.parallel_utils import ParallelMode from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell -from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig +from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig, TimeMonitor from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR @@ -87,8 +87,9 @@ def run_pretrain(): rank = 0 device_num = 1 - ds = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle, args_opt.enable_data_sink, - args_opt.data_sink_steps, args_opt.data_dir, args_opt.schema_dir) + ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle, + args_opt.enable_data_sink, args_opt.data_sink_steps, + args_opt.data_dir, args_opt.schema_dir) netwithloss = BertNetworkWithLoss(bert_net_cfg, True) @@ -112,7 +113,7 @@ def run_pretrain(): else: raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]". format(cfg.optimizer)) - callback = [LossCallBack()] + callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()] if args_opt.enable_save_ckpt == "true": config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps, keep_checkpoint_max=args_opt.save_checkpoint_num) @@ -133,6 +134,6 @@ def run_pretrain(): netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer) model = Model(netwithgrads) - model.train(ds.get_repeat_count(), ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true")) + model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true")) if __name__ == '__main__': run_pretrain() diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index 9b12d17c0..24d547c8b 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -99,6 +99,9 @@ class Dropout(Cell): out, _ = self.dropout(x) return out + if self.keep_prob == 1: + return x + shape = self.get_shape(x) dtype = P.DType()(x) keep_prob = self.cast(self.keep_prob, dtype) diff --git a/tests/st/networks/models/bert/bert_tdt_lossscale.py b/tests/st/networks/models/bert/bert_tdt_lossscale.py index f8768cce7..65679b9d5 100644 --- a/tests/st/networks/models/bert/bert_tdt_lossscale.py +++ b/tests/st/networks/models/bert/bert_tdt_lossscale.py @@ -26,7 +26,7 @@ from mindspore import context from mindspore import log as logger from mindspore.common.tensor import Tensor from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell -from mindspore.nn.optim import Momentum +from mindspore.nn.optim import Lamb from mindspore.train.callback import Callback from mindspore.train.loss_scale_manager import DynamicLossScaleManager from mindspore.train.model import Model @@ -73,7 +73,7 @@ def get_config(version='base', batch_size=1): max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, - use_relative_positions=True, + use_relative_positions=False, input_mask_from_dataset=True, token_type_ids_from_dataset=True, dtype=mstype.float32, @@ -138,7 +138,9 @@ def test_bert_tdt(): batch_size = int(os.getenv('BATCH_SIZE', '16')) config = get_config(version=version, batch_size=batch_size) netwithloss = BertNetworkWithLoss(config, True) - optimizer = Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9) + optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size()*ds.get_repeat_count(), + start_learning_rate=5e-5, end_learning_rate=1e-9, + power=10.0, warmup_steps=0, weight_decay=0.01) scale_window = 3 scale_manager = DynamicLossScaleManager(2 ** 16, 2, scale_window) netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer, @@ -169,10 +171,10 @@ def test_bert_tdt(): # assertion occurs while the loss value, overflow state or loss_scale value is wrong loss_value = np.array(callback.loss_list) - expect_loss_value = [12.191826, 11.966009, 11.972208, 11.98216, 11.973932, 12.611078, 12.17554, 12.840299, - 12.403329, 12.621632] + expect_loss_value = [12.207201, 11.980862, 11.984737, 11.879344, 11.832838, 12.411388, + 12.009449, 12.621273, 12.223175, 12.427313] print("loss value: {}".format(loss_value)) - assert np.allclose(loss_value, expect_loss_value, 0.00001, 0.00001) + assert np.allclose(loss_value, expect_loss_value, 0, 0.0005) overflow = np.array(callback.overflow_list) expect_overflow = [True, True, False, False, False, True, False, False, False, True] @@ -182,7 +184,7 @@ def test_bert_tdt(): loss_scale = np.array(callback.lossscale_list) expect_loss_scale = [32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0] print("loss scale: {}".format(loss_scale)) - assert np.allclose(loss_scale, expect_loss_scale, 0.00001, 0.00001) + assert np.allclose(loss_scale, expect_loss_scale, 0, 0) if __name__ == '__main__': -- GitLab