提交 c8221cce 编写于 作者: W wanghua

modify pre-traning script

上级 29dc8048
......@@ -4,20 +4,26 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base](
## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download the zhwiki dataset from <https://dumps.wikimedia.org/zhwiki> for pre-training. Extract and clean text in the dataset with [WikiExtractor](https://github.com/attardi/wiliextractor). Convert the dataset to TFRecord format and move the files to a specified path.
- Download the zhwiki dataset from <https://dumps.wikimedia.org/zhwiki> for pre-training. Extract and clean text in the dataset with [WikiExtractor](https://github.com/attardi/wil
kiextractor). Convert the dataset to TFRecord format and move the files to a specified path.
- Download the CLUE dataset from <https://www.cluebenchmarks.com> for fine-tuning and evaluation.
> Notes:
If you are running a fine-tuning or evaluation task, prepare the corresponding checkpoint file.
## Running the Example
### Pre-Training
- Set options in `config.py`. Make sure the 'DATA_DIR'(path to the dataset) and 'SCHEMA_DIR'(path to the json schema file) are set to your own path. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about dataset and the json schema file.
- Set options in `config.py`, including lossscale, optimizer and network. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about dataset and the json schema file.
- Run `run_pretrain.py` for pre-training of BERT-base and BERT-NEZHA model.
- Run `run_standalone_pretrain.sh` for non-distributed pre-training of BERT-base and BERT-NEZHA model.
``` bash
python run_pretrain.py --backend=ms
``` bash
sh run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_PATH
```
- Run `run_distribute_pretrain.sh` for distributed pre-training of BERT-base and BERT-NEZHA model.
``` bash
sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH MINDSPORE_PATH
```
### Fine-Tuning
- Set options in `finetune_config.py`. Make sure the 'data_file', 'schema_file' and 'ckpt_file' are set to your own path, set the 'pre_training_ckpt' to save the checkpoint files generated.
......@@ -40,30 +46,42 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base](
## Usage
### Pre-Training
```
usage: run_pretrain.py [--backend BACKEND]
optional parameters:
--backend, BACKEND MindSpore backend: ms
usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [----device_num N] [--device_id N]
[--enable_task_sink ENABLE_TASK_SINK] [--enable_loop_sink ENABLE_LOOP_SINK]
[--enable_mem_reuse ENABLE_MEM_REUSE] [--enable_save_ckpt ENABLE_SAVE_CKPT]
[--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
[--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] [--checkpoint_path CHECKPOINT_PATH]
[--save_checkpoint_steps N] [--save_checkpoint_num N]
[--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR]
options:
--distribute pre_training by serveral devices: "true"(training by more than 1 device) | "false", default is "false"
--epoch_size epoch size: N, default is 1
--device_num number of used devices: N, default is 1
--device_id device id: N, default is 0
--enable_task_sink enable task sink: "true" | "false", default is "true"
--enable_loop_sink enable loop sink: "true" | "false", default is "true"
--enable_mem_reuse enable memory reuse: "true" | "false", default is "true"
--enable_save_ckpt enable save checkpoint: "true" | "false", default is "true"
--enable_lossscale enable lossscale: "true" | "false", default is "true"
--do_shuffle enable shuffle: "true" | "false", default is "true"
--enable_data_sink enable data sink: "true" | "false", default is "true"
--data_sink_steps set data sink steps: N, default is 1
--checkpoint_path path to save checkpoint files: PATH, default is ""
--save_checkpoint_steps steps for saving checkpoint files: N, default is 1000
--save_checkpoint_num number for saving checkpoint files: N, default is 1
--data_dir path to dataset directory: PATH, default is ""
--schema_dir path to schema.json file, PATH, default is ""
```
## Options and Parameters
It contains of parameters of BERT model and options for training, which is set in file `config.py`, `finetune_config.py` and `evaluation_config.py` respectively.
### Options:
```
Pre-Training:
bert_network version of BERT model: base | large, default is base
epoch_size repeat counts of training: N, default is 40
dataset_sink_mode use dataset sink mode or not: True | False, default is True
do_shuffle shuffle the dataset or not: True | False, default is True
do_train_with_lossscale use lossscale or not: True | False, default is True
loss_scale_value initial value of loss scale: N, default is 2^32
scale_factor factor used to update loss scale: N, default is 2
scale_window steps for once updatation of loss scale: N, default is 1000
save_checkpoint_steps steps to save a checkpoint: N, default is 2000
keep_checkpoint_max numbers to save checkpoint: N, default is 1
init_ckpt checkpoint file to load: PATH, default is ""
data_dir dataset file to load: PATH, default is "/your/path/cn-wiki-128"
schema_dir dataset schema file to load: PATH, default is "your/path/datasetSchema.json"
scale_window steps for once updatation of loss scale: N, default is 1000
optimizer optimizer used in the network: AdamWerigtDecayDynamicLR | Lamb | Momentum, default is "Lamb"
Fine-Tuning:
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in dataset.py, run_pretrain.py
"""
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from mindspore.model_zoo.Bert_NEZHA import BertConfig
cfg = edict({
'bert_network': 'base',
'loss_scale_value': 2**32,
'scale_factor': 2,
'scale_window': 1000,
'optimizer': 'Lamb',
'AdamWeightDecayDynamicLR': edict({
'learning_rate': 3e-5,
'end_learning_rate': 0.0,
'power': 5.0,
'weight_decay': 1e-5,
'eps': 1e-6,
}),
'Lamb': edict({
'start_learning_rate': 3e-5,
'end_learning_rate': 0.0,
'power': 10.0,
'warmup_steps': 10000,
'weight_decay': 0.01,
'eps': 1e-6,
'decay_filter': lambda x: False,
}),
'Momentum': edict({
'learning_rate': 2e-5,
'momentum': 0.9,
}),
})
if cfg.bert_network == 'base':
bert_net_cfg = BertConfig(
batch_size=16,
seq_length=128,
vocab_size=21136,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=False,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
)
else:
bert_net_cfg = BertConfig(
batch_size=16,
seq_length=128,
vocab_size=21136,
hidden_size=1024,
num_hidden_layers=12,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=True,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
)
......@@ -12,47 +12,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
NEZHA (NEural contextualiZed representation for CHinese lAnguage understanding) is the Chinese pretrained language
model currently based on BERT developed by Huawei.
1. Prepare data
Following the data preparation as in BERT, run command as below to get dataset for training:
python ./create_pretraining_data.py \
--input_file=./sample_text.txt \
--output_file=./examples.tfrecord \
--vocab_file=./your/path/vocab.txt \
--do_lower_case=True \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=5
2. Pretrain
First, prepare the distributed training environment, then adjust configurations in config.py, finally run train.py.
Data operations, will be used in run_pretrain.py
"""
import os
import numpy as np
from config import bert_train_cfg, bert_net_cfg
import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore import context
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell
from mindspore.nn.optim import Lamb
_current_dir = os.path.dirname(os.path.realpath(__file__))
from mindspore import log as logger
from config import bert_net_cfg
def create_train_dataset(batch_size):
def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", enable_data_sink="true",
data_sink_steps=1, data_dir=None, schema_dir=None):
"""create train dataset"""
# apply repeat operations
repeat_count = bert_train_cfg.epoch_size
ds = de.TFRecordDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR,
repeat_count = epoch_size
files = os.listdir(data_dir)
data_files = []
for file_name in files:
data_files.append(data_dir+file_name)
ds = de.TFRecordDataset(data_files, schema_dir,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
shard_equal_rows=True)
ori_dataset_size = ds.get_dataset_size()
new_size = ori_dataset_size
if enable_data_sink == "true":
new_size = data_sink_steps * bert_net_cfg.batch_size
ds.set_dataset_size(new_size)
repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size())
type_cast_op = C.TypeCast(mstype.int32)
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
......@@ -61,36 +51,8 @@ def create_train_dataset(batch_size):
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
ds = ds.repeat(repeat_count)
logger.info("data size: {}".format(ds.get_dataset_size()))
logger.info("repeatcount: {}".format(ds.get_repeat_count()))
return ds
def weight_variable(shape):
"""weight variable"""
np.random.seed(1)
ones = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32)
return Tensor(ones)
def train_bert():
"""train bert"""
context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend")
context.set_context(enable_task_sink=True)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
ds = create_train_dataset(bert_net_cfg.batch_size)
netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=bert_train_cfg.decay_steps,
start_learning_rate=bert_train_cfg.start_learning_rate,
end_learning_rate=bert_train_cfg.end_learning_rate, power=bert_train_cfg.power,
warmup_steps=bert_train_cfg.num_warmup_steps, decay_filter=lambda x: False)
netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
netwithgrads.set_train(True)
model = Model(netwithgrads)
config_ck = CheckpointConfig(save_checkpoint_steps=bert_train_cfg.save_checkpoint_steps,
keep_checkpoint_max=bert_train_cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix=bert_train_cfg.checkpoint_prefix, config=config_ck)
model.train(ds.get_repeat_count(), ds, callbacks=[LossMonitor(), ckpoint_cb], dataset_sink_mode=False)
if __name__ == '__main__':
train_bert()
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH MINDSPORE_PATH"
echo "for example: sh run_distribute_pretrain.sh 8 40 /path/zh-wiki/ /path/Schema.json /path/hccl.json /path/mindspore"
echo "It is better to use absolute path."
echo "=============================================================================================================="
EPOCH_SIZE=$2
DATA_DIR=$3
SCHEMA_DIR=$4
MINDSPORE_PATH=$6
export PYTHONPATH=$MINDSPORE_PATH/build/package:$PYTHONPATH
export MINDSPORE_HCCL_CONFIG_PATH=$5
export RANK_SIZE=$1
for((i=0;i<RANK_SIZE;i++))
do
export DEVICE_ID=$i
start=`expr $i \* 12`
end=`expr $start \+ 11`
cmdopt=$start"-"$end
rm -rf LOG$i
mkdir ./LOG$i
cp *.py ./LOG$i
cd ./LOG$i || exit
export RANK_ID=$i
echo "start training for rank $i, device $DEVICE_ID"
env > env.log
taskset -c $cmdopt python ../run_pretrain.py \
--distribute="true" \
--epoch_size=$EPOCH_SIZE \
--device_id=$DEVICE_ID \
--device_num=$RANK_SIZE \
--enable_task_sink="true" \
--enable_loop_sink="true" \
--enable_mem_reuse="true" \
--enable_save_ckpt="true" \
--enable_lossscale="true" \
--do_shuffle="true" \
--enable_data_sink="true" \
--data_sink_steps=1 \
--checkpoint_path="" \
--save_checkpoint_steps=1000 \
--save_checkpoint_num=1 \
--data_dir=$DATA_DIR \
--schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
cd ../
done
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################pre_train bert example on zh-wiki########################
python run_pretrain.py
"""
import os
import argparse
import mindspore.communication.management as D
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.parallel_utils import ParallelMode
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig
from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR
from dataset import create_bert_dataset
from config import cfg, bert_net_cfg
_current_dir = os.path.dirname(os.path.realpath(__file__))
class LossCallBack(Callback):
"""
Monitor the loss in training.
If the loss in NAN or INF terminating training.
Note:
if per_print_times is 0 do not print loss.
Args:
per_print_times (int): Print loss every times. Default: 1.
"""
def __init__(self, per_print_times=1):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0")
self._per_print_times = per_print_times
def step_end(self, run_context):
cb_params = run_context.original_args()
with open("./loss.log", "a+") as f:
f.write("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
str(cb_params.net_outputs)))
f.write('\n')
def run_pretrain():
"""pre-train bert_clue"""
parser = argparse.ArgumentParser(description='bert pre_training')
parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--enable_task_sink", type=str, default="true", help="Enable task sink, default is true.")
parser.add_argument("--enable_loop_sink", type=str, default="true", help="Enable loop sink, default is true.")
parser.add_argument("--enable_mem_reuse", type=str, default="true", help="Enable mem reuse, default is true.")
parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
parser.add_argument("--enable_lossscale", type=str, default="true", help="Use lossscale or not, default is not.")
parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.")
parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path")
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
"default is 1000.")
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_task_sink=(args_opt.enable_task_sink == "true"),
enable_loop_sink=(args_opt.enable_loop_sink == "true"),
enable_mem_reuse=(args_opt.enable_mem_reuse == "true"))
context.set_context(reserve_class_name_in_scope=False)
if args_opt.distribute == "true":
device_num = args_opt.device_num
context.reset_auto_parallel_context()
context.set_context(enable_hccl=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
device_num=device_num)
D.init()
rank = args_opt.device_id % device_num
else:
context.set_context(enable_hccl=False)
rank = 0
device_num = 1
ds = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle, args_opt.enable_data_sink,
args_opt.data_sink_steps, args_opt.data_dir, args_opt.schema_dir)
netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
if cfg.optimizer == 'Lamb':
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
start_learning_rate=cfg.Lamb.start_learning_rate, end_learning_rate=cfg.Lamb.end_learning_rate,
power=cfg.Lamb.power, warmup_steps=cfg.Lamb.warmup_steps, weight_decay=cfg.Lamb.weight_decay,
eps=cfg.Lamb.eps, decay_filter=cfg.Lamb.decay_filter)
elif cfg.optimizer == 'Momentum':
optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
momentum=cfg.Momentum.momentum)
elif cfg.optimizer == 'AdamWeightDecayDynamicLR':
optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(),
decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
power=cfg.AdamWeightDecayDynamicLR.power,
weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
eps=cfg.AdamWeightDecayDynamicLR.eps)
else:
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]".
format(cfg.optimizer))
callback = [LossCallBack()]
if args_opt.enable_save_ckpt == "true":
config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
keep_checkpoint_max=args_opt.save_checkpoint_num)
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', config=config_ck)
callback.append(ckpoint_cb)
if args_opt.checkpoint_path:
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(netwithloss, param_dict)
if args_opt.enable_lossscale == "true":
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
scale_factor=cfg.scale_factor,
scale_window=cfg.scale_window)
netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
scale_update_cell=update_cell)
else:
netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
model = Model(netwithgrads)
model.train(ds.get_repeat_count(), ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"))
if __name__ == '__main__':
run_pretrain()
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -13,45 +14,33 @@
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py
"""
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "sh run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_PATH"
echo "for example: sh run_standalone_pretrain.sh 0 40 /path/zh-wiki/ /path/Schema.json /path/mindspore"
echo "=============================================================================================================="
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from mindspore.model_zoo.Bert_NEZHA import BertConfig
bert_train_cfg = edict({
'epoch_size': 10,
'num_warmup_steps': 0,
'start_learning_rate': 1e-4,
'end_learning_rate': 0.0,
'decay_steps': 1000,
'power': 10.0,
'save_checkpoint_steps': 2000,
'keep_checkpoint_max': 10,
'checkpoint_prefix': "checkpoint_bert",
# please add your own dataset path
'DATA_DIR': "/your/path/examples.tfrecord",
# please add your own dataset schema path
'SCHEMA_DIR': "/your/path/datasetSchema.json"
})
bert_net_cfg = BertConfig(
batch_size=16,
seq_length=128,
vocab_size=21136,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=True,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
)
DEVICE_ID=$1
EPOCH_SIZE=$2
DATA_DIR=$3
SCHEMA_DIR=$4
MINDSPORE_PATH=$5
export PYTHONPATH=$MINDSPORE_PATH/build/package:$PYTHONPATH
python run_pretrain.py \
--distribute="false" \
--epoch_size=$EPOCH_SIZE \
--device_id=$DEVICE_ID \
--enable_task_sink="true" \
--enable_loop_sink="true" \
--enable_mem_reuse="true" \
--enable_save_ckpt="true" \
--enable_lossscale="true" \
--do_shuffle="true" \
--enable_data_sink="true" \
--data_sink_steps=1 \
--checkpoint_path="" \
--save_checkpoint_steps=1000 \
--save_checkpoint_num=1 \
--data_dir=$DATA_DIR \
--schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册