Commit e15cfa24 authored by panfengfeng

fix get_dataset_size changes

Parent 0f34238a
@@ -92,7 +92,7 @@ if __name__ == "__main__":
network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
-repeat_size = cfg.epoch_size
+repeat_size = 1
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
......
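The same fix repeats across the model scripts below: the dataset is now built for a single pass, so get_dataset_size() means batches per epoch, and the epoch count is handed to model.train instead of being baked into the pipeline via repeat. A minimal sketch of the resulting step arithmetic, assuming the usual MindSpore convention that model.train receives the epoch count (numbers are illustrative, not from the repo):

# Illustrative MNIST-like sizes, not taken from the commit.
num_samples = 60000
batch_size = 32
epoch_size = 10

steps_per_epoch = num_samples // batch_size   # 1875: what get_dataset_size() now returns
total_steps = steps_per_epoch * epoch_size    # 18750 optimizer steps overall

# Before the commit the dataset itself was repeated cfg.epoch_size times;
# afterwards it is repeated once (repeat_size = 1) and the epoch count goes
# to model.train, so get_dataset_size() keeps meaning one epoch.
print(steps_per_epoch, total_steps)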
@@ -84,7 +84,7 @@ if __name__ == "__main__":
network = AlexNet(cfg.num_classes)
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
-repeat_size = cfg.epoch_size
+repeat_size = 1
# when batch_size=32, steps is 1562
lr = Tensor(get_lr(0, cfg.learning_rate, cfg.epoch_size, 1562))
opt = nn.Momentum(network.trainable_params(), lr, cfg.momentum)
......
@@ -125,7 +125,7 @@ if __name__ == '__main__':
model = Model(net, loss_fn=ls, optimizer=opt, metrics={'acc'})
if args_opt.do_train:
-dataset = create_dataset(epoch_size)
+dataset = create_dataset()
batch_num = dataset.get_dataset_size()
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10", directory="./", config=config_ck)
......
@@ -77,7 +77,7 @@ if __name__ == '__main__':
model = Model(network, loss, opt, {'acc': Accuracy()})
print("============== Starting Training ==============")
-ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, cfg.num_epochs)
+ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, 1)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck)
......
@@ -28,12 +28,12 @@ from mindspore.train.parallel_utils import ParallelMode
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR
+from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
from mindspore import log as logger
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
from src.dataset import create_bert_dataset
from src.config import cfg, bert_net_cfg
-from src.utils import LossCallBack
+from src.utils import LossCallBack, BertLearningRate
_current_dir = os.path.dirname(os.path.realpath(__file__))
@@ -64,7 +64,6 @@ def run_pretrain():
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
context.set_context(reserve_class_name_in_scope=False)
-context.set_context(variable_memory_max_size="30GB")
ckpt_save_dir = args_opt.save_checkpoint_path
if args_opt.distribute == "true":
if args_opt.device_target == 'Ascend':
@@ -99,35 +98,49 @@ def run_pretrain():
logger.warning('Gpu only support fp32 temporarily, run with fp32.')
bert_net_cfg.compute_type = mstype.float32
-ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle,
-                                           args_opt.enable_data_sink, args_opt.data_sink_steps,
-                                           args_opt.data_dir, args_opt.schema_dir)
-netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
+ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
+net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
+new_repeat_count = args_opt.epoch_size * ds.get_dataset_size() // args_opt.data_sink_steps
+if args_opt.train_steps > 0:
+    new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
+else:
+    args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size()
if cfg.optimizer == 'Lamb':
-optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size() * new_repeat_count,
-                 start_learning_rate=cfg.Lamb.start_learning_rate, end_learning_rate=cfg.Lamb.end_learning_rate,
-                 power=cfg.Lamb.power, warmup_steps=cfg.Lamb.warmup_steps, weight_decay=cfg.Lamb.weight_decay,
-                 eps=cfg.Lamb.eps)
+lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
+                               end_learning_rate=cfg.Lamb.end_learning_rate,
+                               warmup_steps=cfg.Lamb.warmup_steps,
+                               decay_steps=args_opt.train_steps,
+                               power=cfg.Lamb.power)
+params = net_with_loss.trainable_params()
+decay_params = list(filter(cfg.Lamb.decay_filter, params))
+other_params = list(filter(lambda x: x not in decay_params, params))
+group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
+                {'params': other_params},
+                {'order_params': params}]
+optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
elif cfg.optimizer == 'Momentum':
-optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
+optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
momentum=cfg.Momentum.momentum)
-elif cfg.optimizer == 'AdamWeightDecayDynamicLR':
-    optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(),
-                                         decay_steps=ds.get_dataset_size() * new_repeat_count,
-                                         learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
-                                         end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
-                                         power=cfg.AdamWeightDecayDynamicLR.power,
-                                         weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
-                                         eps=cfg.AdamWeightDecayDynamicLR.eps,
-                                         warmup_steps=cfg.AdamWeightDecayDynamicLR.warmup_steps)
+elif cfg.optimizer == 'AdamWeightDecay':
+    lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
+                                   end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
+                                   warmup_steps=cfg.AdamWeightDecay.warmup_steps,
+                                   decay_steps=args_opt.train_steps,
+                                   power=cfg.AdamWeightDecay.power)
+    params = net_with_loss.trainable_params()
+    decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
+    other_params = list(filter(lambda x: x not in decay_params, params))
+    group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
+                    {'params': other_params, 'weight_decay': 0.0},
+                    {'order_params': params}]
+    optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
else:
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]".
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
format(cfg.optimizer))
-callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()]
+callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()]
if args_opt.enable_save_ckpt == "true":
config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
keep_checkpoint_max=args_opt.save_checkpoint_num)
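Both new optimizer branches build the same three-group parameter list. A plain-Python sketch of that filtering, with a stub class standing in for MindSpore Parameters (only the .name attribute matters here):

# Stub with a .name attribute, standing in for mindspore.Parameter.
class FakeParam:
    def __init__(self, name):
        self.name = name

params = [FakeParam("encoder.layer0.weight"),
          FakeParam("encoder.layer0.bias"),
          FakeParam("encoder.layernorm.gamma")]

# Same shape as cfg.Lamb.decay_filter / cfg.AdamWeightDecay.decay_filter.
decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()
decay_params = list(filter(decay_filter, params))
other_params = [p for p in params if p not in decay_params]

group_params = [{'params': decay_params, 'weight_decay': 0.01},
                {'params': other_params, 'weight_decay': 0.0},
                {'order_params': params}]
# Only the plain weight decays; biases and layer-norm parameters are exempt.
assert [p.name for p in decay_params] == ["encoder.layer0.weight"]

The 'order_params' entry keeps the optimizer's parameter ordering identical to the network's, which MindSpore's grouped optimizers expect.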
@@ -136,19 +149,22 @@ def run_pretrain():
if args_opt.load_checkpoint_path:
param_dict = load_checkpoint(args_opt.load_checkpoint_path)
-load_param_into_net(netwithloss, param_dict)
+load_param_into_net(net_with_loss, param_dict)
if args_opt.enable_lossscale == "true":
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
scale_factor=cfg.scale_factor,
scale_window=cfg.scale_window)
-netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
-                                                 scale_update_cell=update_cell)
+net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
+                                                   scale_update_cell=update_cell)
else:
-netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
+net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
-model = Model(netwithgrads)
-model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"))
+model = Model(net_with_grads)
+model.train(new_repeat_count, ds, callbacks=callback,
+            dataset_sink_mode=(args_opt.enable_data_sink == "true"), sink_size=args_opt.data_sink_steps)
if __name__ == '__main__':
numpy.random.seed(0)
run_pretrain()
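With data sinking enabled, the first argument to model.train counts sink runs rather than epochs, which is what the new_repeat_count expression above computes. A worked example with made-up sizes:

# Hypothetical sizes, chosen only to make the arithmetic visible.
epoch_size = 40
dataset_size = 5000        # ds.get_dataset_size(): batches per epoch
data_sink_steps = 100
train_steps = 100000       # optional cap; <= 0 means "derive from epochs"

new_repeat_count = epoch_size * dataset_size // data_sink_steps   # 2000 sink runs
if train_steps > 0:
    new_repeat_count = min(new_repeat_count, train_steps // data_sink_steps)  # 1000

# Each sink run trains data_sink_steps batches, so the total batch count is:
print(new_repeat_count * data_sink_steps)   # 100000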
@@ -277,8 +277,8 @@ class RelaPosMatrixGenerator(nn.Cell):
def __init__(self, length, max_relative_position):
super(RelaPosMatrixGenerator, self).__init__()
self._length = length
-self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
-self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
+self._max_relative_position = max_relative_position
+self._min_relative_position = -max_relative_position
self.range_length = -length + 1
self.tile = P.Tile()
@@ -336,9 +336,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
max_relative_position=max_relative_position)
self.reshape = P.Reshape()
-self.one_hot = P.OneHot()
-self.on_value = Tensor(1.0, mstype.float32)
-self.off_value = Tensor(0.0, mstype.float32)
+self.one_hot = nn.OneHot(depth=self.vocab_size)
self.shape = P.Shape()
self.gather = P.GatherV2() # index_select
self.matmul = P.BatchMatMul()
@@ -350,7 +348,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
if self.use_one_hot_embeddings:
flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
one_hot_relative_positions_matrix = self.one_hot(
-flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value)
+flat_relative_positions_matrix)
embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
embeddings = self.reshape(embeddings, my_shape)
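nn.OneHot wraps the raw P.OneHot primitive and fixes the depth plus the 1.0/0.0 on/off values at construction time, which is why the three extra call arguments disappear above. A small sketch of the two styles (shapes and values illustrative only):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

indices = Tensor(np.array([0, 2, 1]), mstype.int32)

# Old style: the primitive takes depth and on/off values at every call.
out_prim = P.OneHot()(indices, 3, Tensor(1.0, mstype.float32), Tensor(0.0, mstype.float32))

# New style: depth is baked in; defaults give on=1.0 / off=0.0.
out_cell = nn.OneHot(depth=3)(indices)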
@@ -372,11 +370,9 @@ class SaturateCast(nn.Cell):
def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
super(SaturateCast, self).__init__()
np_type = mstype.dtype_to_nptype(dst_type)
-min_type = np.finfo(np_type).min
-max_type = np.finfo(np_type).max
-self.tensor_min_type = Tensor([min_type], dtype=src_type)
-self.tensor_max_type = Tensor([max_type], dtype=src_type)
+self.tensor_min_type = float(np.finfo(np_type).min)
+self.tensor_max_type = float(np.finfo(np_type).max)
self.min_op = P.Minimum()
self.max_op = P.Maximum()
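The clamping bounds themselves are unchanged; they are now stored as Python floats, since MindSpore's binary ops such as P.Minimum/P.Maximum accept a scalar operand. For reference, the bounds the new code would compute for a float16 destination type:

import numpy as np

# Same computation as in __init__ above, for dst_type=float16.
np_type = np.float16
tensor_min_type = float(np.finfo(np_type).min)   # -65504.0
tensor_max_type = float(np.finfo(np_type).max)   #  65504.0
print(tensor_min_type, tensor_max_type)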
@@ -442,7 +438,7 @@ class BertAttention(nn.Cell):
self.has_attention_mask = has_attention_mask
self.use_relative_positions = use_relative_positions
-self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
+self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
self.reshape = P.Reshape()
self.shape_from_2d = (-1, from_tensor_width)
self.shape_to_2d = (-1, to_tensor_width)
@@ -471,7 +467,7 @@ class BertAttention(nn.Cell):
self.trans_shape = (0, 2, 1, 3)
self.trans_shape_relative = (2, 0, 1, 3)
self.trans_shape_position = (1, 2, 0, 3)
-self.multiply_data = Tensor([-10000.0,], dtype=compute_type)
+self.multiply_data = -10000.0
self.batch_num = batch_size * num_attention_heads
self.matmul = P.BatchMatMul()
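multiply_data is the usual additive attention-mask constant: masked positions get a -10000 bias before softmax so their weight collapses to roughly zero. A numpy sketch of the effect (not the BertAttention code itself):

import numpy as np

scores = np.array([2.0, 1.0, 0.5], dtype=np.float32)
mask = np.array([1.0, 1.0, 0.0], dtype=np.float32)   # 0 marks a masked position

adder = (1.0 - mask) * -10000.0                      # same role as multiply_data
logits = scores + adder
probs = np.exp(logits) / np.exp(logits).sum()
print(probs)   # the masked position ends up ~0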
......
@@ -24,20 +24,22 @@ cfg = edict({
'scale_factor': 2,
'scale_window': 1000,
'optimizer': 'Lamb',
-'AdamWeightDecayDynamicLR': edict({
+'AdamWeightDecay': edict({
'learning_rate': 3e-5,
'end_learning_rate': 1e-10,
'power': 5.0,
'weight_decay': 1e-5,
+'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
'eps': 1e-6,
'warmup_steps': 10000,
}),
'Lamb': edict({
-'start_learning_rate': 3e-5,
+'learning_rate': 3e-5,
'end_learning_rate': 1e-10,
'power': 10.0,
'warmup_steps': 10000,
'weight_decay': 0.01,
+'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
'eps': 1e-6,
}),
'Momentum': edict({
......
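The edict config gives attribute-style access, which is what lets the training script read cfg.Lamb.learning_rate and cfg.AdamWeightDecay.decay_filter directly. A trimmed stand-in (easydict assumed, as in the source):

from easydict import EasyDict as edict

cfg = edict({
    'optimizer': 'Lamb',
    'Lamb': edict({
        'learning_rate': 3e-5,
        'end_learning_rate': 1e-10,
        'power': 10.0,
        'warmup_steps': 10000,
        'weight_decay': 0.01,
        'eps': 1e-6,
    }),
})

# Attribute access instead of cfg['Lamb']['learning_rate'].
assert cfg[cfg.optimizer].learning_rate == 3e-5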
@@ -23,11 +23,9 @@ from mindspore import log as logger
from .config import bert_net_cfg
-def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", enable_data_sink="true",
-                        data_sink_steps=1, data_dir=None, schema_dir=None):
+def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None, schema_dir=None):
"""create train dataset"""
-# apply repeat operations
-repeat_count = epoch_size
files = os.listdir(data_dir)
data_files = []
for file_name in files:
@@ -36,15 +34,10 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
ds = de.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
shard_equal_rows=True)
shuffle=de.Shuffle.FILES if do_shuffle == "true" else False,
num_shards=device_num, shard_id=rank, shard_equal_rows=True)
-ori_dataset_size = ds.get_dataset_size()
-print('origin dataset size: ', ori_dataset_size)
-new_size = ori_dataset_size
-if enable_data_sink == "true":
-    new_size = data_sink_steps * bert_net_cfg.batch_size
-ds.set_dataset_size(new_size)
-new_repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size())
type_cast_op = C.TypeCast(mstype.int32)
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
@@ -54,10 +47,9 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
# apply batch operations
ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
-ds = ds.repeat(max(new_repeat_count, repeat_count))
logger.info("data size: {}".format(ds.get_dataset_size()))
logger.info("repeatcount: {}".format(ds.get_repeat_count()))
return ds, new_repeat_count
logger.info("repeat count: {}".format(ds.get_repeat_count()))
return ds
def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy",
......
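A hedged usage sketch for the new signature; the paths are placeholders, and create_bert_dataset is the function defined above. The key point is that get_dataset_size() now reports one pass over the shard, and callers derive any repeat or sink counts from it instead of patching the size with set_dataset_size():

# Placeholder paths, illustrative shard layout.
ds = create_bert_dataset(device_num=8, rank=0, do_shuffle="true",
                         data_dir="/path/to/tfrecords", schema_dir="")

steps_per_epoch = ds.get_dataset_size()   # batches in one pass over this shard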
@@ -18,11 +18,13 @@ Functional Cells used in Bert finetune and evaluation.
"""
import os
+import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.train.callback import Callback
+from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
class CrossEntropyCalculation(nn.Cell):
@@ -123,3 +125,25 @@ def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, pre
max_num = int(num)
load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename)
return load_finetune_checkpoint_path
+class BertLearningRate(LearningRateSchedule):
+    """
+    Warmup-decay learning rate for Bert network.
+    """
+    def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
+        super(BertLearningRate, self).__init__()
+        self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
+        self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
+        self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
+        self.greater = P.Greater()
+        self.one = Tensor(np.array([1.0]).astype(np.float32))
+        self.cast = P.Cast()
+
+    def construct(self, global_step):
+        is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
+        warmup_lr = self.warmup_lr(global_step)
+        decay_lr = self.decay_lr(global_step)
+        lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
+        return lr
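A quick way to see the schedule hand off from warmup to polynomial decay is to call the cell directly with a float32 step tensor (PyNative mode assumed; hyperparameter values are illustrative):

import numpy as np
from mindspore import Tensor

lr_schedule = BertLearningRate(learning_rate=3e-5, end_learning_rate=1e-10,
                               warmup_steps=100, decay_steps=1000, power=1.0)

for step in (0, 50, 100, 500, 1000):
    step_t = Tensor(np.array([step]).astype(np.float32))
    print(step, lr_schedule(step_t))   # ramps up through step 100, then decays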