Commit 12d9c71c authored by yoonlee666

delete dropout when prob equals 0 and adjust bert ci script

Parent 2224fa09
@@ -42,7 +42,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
     if enable_data_sink == "true":
         new_size = data_sink_steps * bert_net_cfg.batch_size
         ds.set_dataset_size(new_size)
-        repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size())
+        new_repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size())
     type_cast_op = C.TypeCast(mstype.int32)
     ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
     ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
@@ -55,4 +55,4 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
     ds = ds.repeat(repeat_count)
     logger.info("data size: {}".format(ds.get_dataset_size()))
     logger.info("repeatcount: {}".format(ds.get_repeat_count()))
-    return ds
+    return ds, new_repeat_count
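The renamed value is simply the requested repeat count rescaled so that roughly the same total amount of data is consumed once the dataset size has been pinned to data_sink_steps * batch_size. A minimal sketch of that arithmetic, with illustrative numbers that are not taken from the CI configuration:

    # Illustrative values only; none of these come from the actual CI config.
    ori_dataset_size = 10000          # ds.get_dataset_size() before resizing
    repeat_count = 40                 # originally requested repeat/epoch count
    data_sink_steps = 100
    batch_size = 16                   # stands in for bert_net_cfg.batch_size

    new_size = data_sink_steps * batch_size                        # size of one sink cycle
    new_repeat_count = int(repeat_count * ori_dataset_size // new_size)

    print(new_repeat_count)           # 250 sink cycles instead of 40 raw epochs
    assert new_repeat_count * new_size <= repeat_count * ori_dataset_size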
@@ -24,7 +24,7 @@ from mindspore import context
 from mindspore.train.model import Model
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
-from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig
+from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
 from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR
@@ -87,8 +87,9 @@ def run_pretrain():
         rank = 0
         device_num = 1
-    ds = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle, args_opt.enable_data_sink,
-                             args_opt.data_sink_steps, args_opt.data_dir, args_opt.schema_dir)
+    ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle,
+                                               args_opt.enable_data_sink, args_opt.data_sink_steps,
+                                               args_opt.data_dir, args_opt.schema_dir)
     netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
@@ -112,7 +113,7 @@ def run_pretrain():
     else:
         raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]".
                          format(cfg.optimizer))
-    callback = [LossCallBack()]
+    callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()]
     if args_opt.enable_save_ckpt == "true":
         config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                      keep_checkpoint_max=args_opt.save_checkpoint_num)
@@ -133,6 +134,6 @@ def run_pretrain():
     netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
     model = Model(netwithgrads)
-    model.train(ds.get_repeat_count(), ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"))
+    model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"))

 if __name__ == '__main__':
     run_pretrain()
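Training now runs for the rescaled new_repeat_count returned by create_bert_dataset, and a TimeMonitor(ds.get_dataset_size()) callback is added to report step timing. Roughly, TimeMonitor behaves like the sketch below, written against MindSpore's generic Callback hooks; this is an approximation for orientation, not the library's actual implementation, and the real printed format differs.

    import time
    from mindspore.train.callback import Callback

    class SimpleTimeMonitor(Callback):
        """Rough stand-in for TimeMonitor(data_size): report average per-step time."""
        def __init__(self, data_size):
            super(SimpleTimeMonitor, self).__init__()
            self.data_size = data_size        # steps per epoch / per sink cycle
            self.epoch_start = None

        def epoch_begin(self, run_context):
            self.epoch_start = time.time()

        def epoch_end(self, run_context):
            per_step_ms = (time.time() - self.epoch_start) * 1000.0 / self.data_size
            print("per-step time: {:.3f} ms".format(per_step_ms))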
@@ -99,6 +99,9 @@ class Dropout(Cell):
             out, _ = self.dropout(x)
             return out
+        if self.keep_prob == 1:
+            return x
+
         shape = self.get_shape(x)
         dtype = P.DType()(x)
         keep_prob = self.cast(self.keep_prob, dtype)
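This branch is the core of the commit title: when the drop probability is 0, i.e. keep_prob == 1, dropout is an identity mapping, so the input can be returned directly instead of building a mask and launching the dropout kernel. In plain NumPy terms the idea looks like this (a sketch of the technique, not the MindSpore operator):

    import numpy as np

    def dropout_forward(x, keep_prob):
        """Inverted dropout with the keep_prob == 1 shortcut from this patch."""
        if keep_prob == 1:
            return x                                  # identity: nothing to drop, no mask, no scaling
        mask = (np.random.rand(*x.shape) < keep_prob).astype(x.dtype)
        return x * mask / keep_prob                   # rescale so the expected value is unchanged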
@@ -26,7 +26,7 @@ from mindspore import context
 from mindspore import log as logger
 from mindspore.common.tensor import Tensor
 from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell
-from mindspore.nn.optim import Momentum
+from mindspore.nn.optim import Lamb
 from mindspore.train.callback import Callback
 from mindspore.train.loss_scale_manager import DynamicLossScaleManager
 from mindspore.train.model import Model
@@ -73,7 +73,7 @@ def get_config(version='base', batch_size=1):
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02,
-                use_relative_positions=True,
+                use_relative_positions=False,
                 input_mask_from_dataset=True,
                 token_type_ids_from_dataset=True,
                 dtype=mstype.float32,
@@ -138,7 +138,9 @@ def test_bert_tdt():
     batch_size = int(os.getenv('BATCH_SIZE', '16'))
     config = get_config(version=version, batch_size=batch_size)
     netwithloss = BertNetworkWithLoss(config, True)
-    optimizer = Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9)
+    optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size()*ds.get_repeat_count(),
+                     start_learning_rate=5e-5, end_learning_rate=1e-9,
+                     power=10.0, warmup_steps=0, weight_decay=0.01)
     scale_window = 3
     scale_manager = DynamicLossScaleManager(2 ** 16, 2, scale_window)
     netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
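The CI test switches from a fixed-rate Momentum optimizer to Lamb with a polynomial learning-rate decay spread over decay_steps = dataset_size * repeat_count. As a hedged sketch of the schedule those arguments describe, assuming the conventional polynomial-decay formula (the exact MindSpore implementation may differ in details such as clamping or warmup handling):

    def poly_decay_lr(step, decay_steps, start_lr=5e-5, end_lr=1e-9, power=10.0, warmup_steps=0):
        """Assumed polynomial-decay shape for the Lamb arguments above."""
        if warmup_steps and step < warmup_steps:
            return start_lr * step / warmup_steps     # linear warmup (unused here: warmup_steps=0)
        frac = min(step, decay_steps) / decay_steps
        return (start_lr - end_lr) * (1.0 - frac) ** power + end_lr

With power=10.0 the rate stays near start_learning_rate for most of the short CI run and only collapses toward end_learning_rate in the final steps.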
@@ -169,10 +171,10 @@ def test_bert_tdt():
     # assertion occurs while the loss value, overflow state or loss_scale value is wrong
     loss_value = np.array(callback.loss_list)
-    expect_loss_value = [12.191826, 11.966009, 11.972208, 11.98216, 11.973932, 12.611078, 12.17554, 12.840299,
-                         12.403329, 12.621632]
+    expect_loss_value = [12.207201, 11.980862, 11.984737, 11.879344, 11.832838, 12.411388,
+                         12.009449, 12.621273, 12.223175, 12.427313]
     print("loss value: {}".format(loss_value))
-    assert np.allclose(loss_value, expect_loss_value, 0.00001, 0.00001)
+    assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)

     overflow = np.array(callback.overflow_list)
     expect_overflow = [True, True, False, False, False, True, False, False, False, True]
@@ -182,7 +184,7 @@ def test_bert_tdt():
     loss_scale = np.array(callback.lossscale_list)
     expect_loss_scale = [32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0]
     print("loss scale: {}".format(loss_scale))
-    assert np.allclose(loss_scale, expect_loss_scale, 0.00001, 0.00001)
+    assert np.allclose(loss_scale, expect_loss_scale, 0, 0)

 if __name__ == '__main__':
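The tolerance change is worth spelling out: np.allclose(a, b, rtol, atol) passes when |a - b| <= atol + rtol * |b| element-wise, so (0, 0.0005) is a pure absolute tolerance of 5e-4 on each loss value, while (0, 0) demands exact equality of the loss-scale values. A small example with hypothetical measured losses:

    import numpy as np

    expected = np.array([12.207201, 11.980862])
    actual = np.array([12.2075, 11.9805])                    # hypothetical run, ~3e-4 off
    print(np.allclose(actual, expected, 0, 0.0005))          # True: every |diff| is under 5e-4
    print(np.allclose(actual, expected, 0.00001, 0.00001))   # False under the old, tighter check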