提交 02ebf03c 编写于 作者: Y yoonlee666 提交者: 高东海

add bert script to master

上级 56ab3a1d
...@@ -14,42 +14,44 @@ ...@@ -14,42 +14,44 @@
# ============================================================================ # ============================================================================
""" """
network config setting, will be used in main.py network config setting, will be used in train.py
""" """
from easydict import EasyDict as edict from easydict import EasyDict as edict
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore.model_zoo.Bert_NEZHA import BertConfig from mindspore.model_zoo.Bert_NEZHA import BertConfig
bert_cfg = edict({ bert_train_cfg = edict({
'epoch_size': 10, 'epoch_size': 10,
'num_warmup_steps': 0, 'num_warmup_steps': 0,
'start_learning_rate': 1e-4, 'start_learning_rate': 1e-4,
'end_learning_rate': 1, 'end_learning_rate': 0.0,
'decay_steps': 1000, 'decay_steps': 1000,
'power': 10.0, 'power': 10.0,
'save_checkpoint_steps': 2000, 'save_checkpoint_steps': 2000,
'keep_checkpoint_max': 10, 'keep_checkpoint_max': 10,
'checkpoint_prefix': "checkpoint_bert", 'checkpoint_prefix': "checkpoint_bert",
'DATA_DIR' = "/your/path/examples.tfrecord" # please add your own dataset path
'SCHEMA_DIR' = "/your/path/datasetSchema.json" 'DATA_DIR': "/your/path/examples.tfrecord",
'bert_config': BertConfig( # please add your own dataset schema path
batch_size=16, 'SCHEMA_DIR': "/your/path/datasetSchema.json"
seq_length=128,
vocab_size=21136,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=True,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
)
}) })
bert_net_cfg = BertConfig(
batch_size=16,
seq_length=128,
vocab_size=21136,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=True,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
)
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
# ============================================================================ # ============================================================================
""" """
NEZHA (NEural contextualiZed representation for CHinese lAnguage understanding) is the Chinese pretrained language model currently based on BERT developed by Huawei. NEZHA (NEural contextualiZed representation for CHinese lAnguage understanding) is the Chinese pretrained language
model currently based on BERT developed by Huawei.
1. Prepare data 1. Prepare data
Following the data preparation as in BERT, run command as below to get dataset for training: Following the data preparation as in BERT, run command as below to get dataset for training:
python ./create_pretraining_data.py \ python ./create_pretraining_data.py \
...@@ -28,35 +29,29 @@ Following the data preparation as in BERT, run command as below to get dataset f ...@@ -28,35 +29,29 @@ Following the data preparation as in BERT, run command as below to get dataset f
--random_seed=12345 \ --random_seed=12345 \
--dupe_factor=5 --dupe_factor=5
2. Pretrain 2. Pretrain
First, prepare the distributed training environment, then adjust configurations in config.py, finally run main.py. First, prepare the distributed training environment, then adjust configurations in config.py, finally run train.py.
""" """
import os import os
import pytest
import numpy as np import numpy as np
from numpy import allclose from config import bert_train_cfg, bert_net_cfg
from config import bert_cfg as cfg
import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de import mindspore.dataset.engine.datasets as de
import mindspore._c_dataengine as deMap import mindspore._c_dataengine as deMap
from mindspore import context from mindspore import context
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.callback import Callback from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell
from mindspore.nn.optim import Lamb from mindspore.nn.optim import Lamb
from mindspore import log as logger
_current_dir = os.path.dirname(os.path.realpath(__file__)) _current_dir = os.path.dirname(os.path.realpath(__file__))
DATA_DIR = [cfg.DATA_DIR]
SCHEMA_DIR = cfg.SCHEMA_DIR
def me_de_train_dataset(batch_size): def create_train_dataset(batch_size):
"""test me de train dataset""" """create train dataset"""
# apply repeat operations # apply repeat operations
repeat_count = cfg.epoch_size repeat_count = bert_train_cfg.epoch_size
ds = de.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids", ds = de.StorageDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR,
"next_sentence_labels", "masked_lm_positions", columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_ids", "masked_lm_weights"]) "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
type_cast_op = deMap.TypeCastOp("int32") type_cast_op = deMap.TypeCastOp("int32")
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op) ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
...@@ -69,43 +64,32 @@ def me_de_train_dataset(batch_size): ...@@ -69,43 +64,32 @@ def me_de_train_dataset(batch_size):
ds = ds.repeat(repeat_count) ds = ds.repeat(repeat_count)
return ds return ds
def weight_variable(shape): def weight_variable(shape):
"""weight variable""" """weight variable"""
np.random.seed(1) np.random.seed(1)
ones = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32) ones = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32)
return Tensor(ones) return Tensor(ones)
def train_bert():
class ModelCallback(Callback): """train bert"""
def __init__(self):
super(ModelCallback, self).__init__()
self.loss_list = []
def step_end(self, run_context):
cb_params = run_context.original_args()
self.loss_list.append(cb_params.net_outputs.asnumpy()[0])
logger.info("epoch: {}, outputs are {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs)))
def test_bert_tdt():
"""test bert tdt"""
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend") context.set_context(device_target="Ascend")
context.set_context(enable_task_sink=True) context.set_context(enable_task_sink=True)
context.set_context(enable_loop_sink=True) context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True) context.set_context(enable_mem_reuse=True)
parallel_callback = ModelCallback() ds = create_train_dataset(bert_net_cfg.batch_size)
ds = me_de_train_dataset(cfg.bert_config.batch_size) netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
config = cfg.bert_config optimizer = Lamb(netwithloss.trainable_params(), decay_steps=bert_train_cfg.decay_steps,
netwithloss = BertNetworkWithLoss(config, True) start_learning_rate=bert_train_cfg.start_learning_rate,
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=cfg.decay_steps, start_learning_rate=cfg.start_learning_rate, end_learning_rate=bert_train_cfg.end_learning_rate, power=bert_train_cfg.power,
end_learning_rate=cfg.end_learning_rate, power=cfg.power, warmup_steps=cfg.num_warmup_steps, decay_filter=lambda x: False) warmup_steps=bert_train_cfg.num_warmup_steps, decay_filter=lambda x: False)
netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer) netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
netwithgrads.set_train(True) netwithgrads.set_train(True)
model = Model(netwithgrads) model = Model(netwithgrads)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, keep_checkpoint_max=cfg.keep_checkpoint_max) config_ck = CheckpointConfig(save_checkpoint_steps=bert_train_cfg.save_checkpoint_steps,
ckpoint_cb = ModelCheckpoint(prefix=cfg.checkpoint_prefix, config=config_ck) keep_checkpoint_max=bert_train_cfg.keep_checkpoint_max)
model.train(ds.get_repeat_count(), ds, callbacks=[parallel_callback, ckpoint_cb], dataset_sink_mode=False) ckpoint_cb = ModelCheckpoint(prefix=bert_train_cfg.checkpoint_prefix, config=config_ck)
model.train(ds.get_repeat_count(), ds, callbacks=[LossMonitor(), ckpoint_cb], dataset_sink_mode=False)
if __name__ == '__main__': if __name__ == '__main__':
test_bert_tdt() train_bert()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册