提交 c6d261b2 编写于 作者: Y yoonlee666

add bert script to master

上级 930a1fb0
......@@ -14,42 +14,44 @@
# ============================================================================
"""
network config setting, will be used in main.py
network config setting, will be used in train.py
"""
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from mindspore.model_zoo.Bert_NEZHA import BertConfig
bert_cfg = edict({
bert_train_cfg = edict({
'epoch_size': 10,
'num_warmup_steps': 0,
'start_learning_rate': 1e-4,
'end_learning_rate': 1,
'end_learning_rate': 0.0,
'decay_steps': 1000,
'power': 10.0,
'save_checkpoint_steps': 2000,
'keep_checkpoint_max': 10,
'checkpoint_prefix': "checkpoint_bert",
'DATA_DIR' = "/your/path/examples.tfrecord"
'SCHEMA_DIR' = "/your/path/datasetSchema.json"
'bert_config': BertConfig(
batch_size=16,
seq_length=128,
vocab_size=21136,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=True,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
)
# please add your own dataset path
'DATA_DIR': "/your/path/examples.tfrecord",
# please add your own dataset schema path
'SCHEMA_DIR': "/your/path/datasetSchema.json"
})
bert_net_cfg = BertConfig(
batch_size=16,
seq_length=128,
vocab_size=21136,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=True,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
)
......@@ -14,7 +14,8 @@
# ============================================================================
"""
NEZHA (NEural contextualiZed representation for CHinese lAnguage understanding) is the Chinese pretrained language model currently based on BERT developed by Huawei.
NEZHA (NEural contextualiZed representation for CHinese lAnguage understanding) is the Chinese pretrained language
model currently based on BERT developed by Huawei.
1. Prepare data
Following the data preparation as in BERT, run command as below to get dataset for training:
python ./create_pretraining_data.py \
......@@ -28,35 +29,29 @@ Following the data preparation as in BERT, run command as below to get dataset f
--random_seed=12345 \
--dupe_factor=5
2. Pretrain
First, prepare the distributed training environment, then adjust configurations in config.py, finally run main.py.
First, prepare the distributed training environment, then adjust configurations in config.py, finally run train.py.
"""
import os
import pytest
import numpy as np
from numpy import allclose
from config import bert_cfg as cfg
import mindspore.common.dtype as mstype
from config import bert_train_cfg, bert_net_cfg
import mindspore.dataset.engine.datasets as de
import mindspore._c_dataengine as deMap
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.callback import Callback
from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell
from mindspore.nn.optim import Lamb
from mindspore import log as logger
_current_dir = os.path.dirname(os.path.realpath(__file__))
DATA_DIR = [cfg.DATA_DIR]
SCHEMA_DIR = cfg.SCHEMA_DIR
def me_de_train_dataset(batch_size):
"""test me de train dataset"""
def create_train_dataset(batch_size):
"""create train dataset"""
# apply repeat operations
repeat_count = cfg.epoch_size
ds = de.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids",
"next_sentence_labels", "masked_lm_positions",
"masked_lm_ids", "masked_lm_weights"])
repeat_count = bert_train_cfg.epoch_size
ds = de.StorageDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
type_cast_op = deMap.TypeCastOp("int32")
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
......@@ -69,43 +64,32 @@ def me_de_train_dataset(batch_size):
ds = ds.repeat(repeat_count)
return ds
def weight_variable(shape):
"""weight variable"""
np.random.seed(1)
ones = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32)
return Tensor(ones)
class ModelCallback(Callback):
def __init__(self):
super(ModelCallback, self).__init__()
self.loss_list = []
def step_end(self, run_context):
cb_params = run_context.original_args()
self.loss_list.append(cb_params.net_outputs.asnumpy()[0])
logger.info("epoch: {}, outputs are {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs)))
def test_bert_tdt():
"""test bert tdt"""
def train_bert():
"""train bert"""
context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend")
context.set_context(enable_task_sink=True)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
parallel_callback = ModelCallback()
ds = me_de_train_dataset(cfg.bert_config.batch_size)
config = cfg.bert_config
netwithloss = BertNetworkWithLoss(config, True)
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=cfg.decay_steps, start_learning_rate=cfg.start_learning_rate,
end_learning_rate=cfg.end_learning_rate, power=cfg.power, warmup_steps=cfg.num_warmup_steps, decay_filter=lambda x: False)
ds = create_train_dataset(bert_net_cfg.batch_size)
netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=bert_train_cfg.decay_steps,
start_learning_rate=bert_train_cfg.start_learning_rate,
end_learning_rate=bert_train_cfg.end_learning_rate, power=bert_train_cfg.power,
warmup_steps=bert_train_cfg.num_warmup_steps, decay_filter=lambda x: False)
netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
netwithgrads.set_train(True)
model = Model(netwithgrads)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix=cfg.checkpoint_prefix, config=config_ck)
model.train(ds.get_repeat_count(), ds, callbacks=[parallel_callback, ckpoint_cb], dataset_sink_mode=False)
config_ck = CheckpointConfig(save_checkpoint_steps=bert_train_cfg.save_checkpoint_steps,
keep_checkpoint_max=bert_train_cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix=bert_train_cfg.checkpoint_prefix, config=config_ck)
model.train(ds.get_repeat_count(), ds, callbacks=[LossMonitor(), ckpoint_cb], dataset_sink_mode=False)
if __name__ == '__main__':
test_bert_tdt()
train_bert()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册