Commit 83e17ceb authored by yoonlee666

update bert scripts

Parent 09083fe2
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
NEZHA (NEural contextualiZed representation for CHinese lAnguage understanding) is the Chinese pretrained language
model currently based on BERT developed by Huawei.
1. Prepare data
Following the data preparation as in BERT, run command as below to get dataset for training:
python ./create_pretraining_data.py \
--input_file=./sample_text.txt \
--output_file=./examples.tfrecord \
--vocab_file=./your/path/vocab.txt \
--do_lower_case=True \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=5
2. Pretrain
First, prepare the distributed training environment, then adjust the configuration in config.py, and finally run train.py.
"""
import os
import numpy as np
from config import bert_train_cfg, bert_net_cfg
import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell
from mindspore.nn.optim import Lamb
_current_dir = os.path.dirname(os.path.realpath(__file__))
def create_train_dataset(batch_size):
"""create train dataset"""
# apply repeat operations
repeat_count = bert_train_cfg.epoch_size
ds = de.StorageDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
type_cast_op = C.TypeCast(mstype.int32)
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op)
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_count)
return ds
def weight_variable(shape):
"""weight variable"""
np.random.seed(1)
ones = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32)
return Tensor(ones)
def train_bert():
"""train bert"""
context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend")
ds = create_train_dataset(bert_net_cfg.batch_size)
netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=bert_train_cfg.decay_steps,
start_learning_rate=bert_train_cfg.start_learning_rate,
end_learning_rate=bert_train_cfg.end_learning_rate, power=bert_train_cfg.power,
warmup_steps=bert_train_cfg.num_warmup_steps, decay_filter=lambda x: False)
netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
netwithgrads.set_train(True)
model = Model(netwithgrads)
config_ck = CheckpointConfig(save_checkpoint_steps=bert_train_cfg.save_checkpoint_steps,
keep_checkpoint_max=bert_train_cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix=bert_train_cfg.checkpoint_prefix, config=config_ck)
model.train(ds.get_repeat_count(), ds, callbacks=[LossMonitor(), ckpoint_cb], dataset_sink_mode=False)
if __name__ == '__main__':
train_bert()
# Bert NEZHA
`NEZHA` (**NE**ural contextuali**Z**ed representation for C**H**inese l**A**nguage understanding) is the Chinese pretrained language model currently based on BERT developed by Huawei.
# BERT Example
## Description
This example implements pre-training of [BERT-base](https://github.com/google-research/bert) (the base version of the BERT model) and [BERT-NEZHA](https://github.com/huawei-noah/Pretrained-Language-Model) (a Chinese pretrained language model developed by Huawei, which introduces Functional Relative Positional Encoding as an effective positional encoding scheme).
## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download the zhwiki dataset for pre-training. Extract and clean text in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). Convert the dataset to TFRecord format and move the files to a specified path.
## Running the Example
### Pre-Training
- Set options in `config.py`, including the loss scale, optimizer and network settings. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about the dataset and the JSON schema file; a minimal data-loading sketch follows the run commands below.
- Run `run_standalone_pretrain.sh` for non-distributed pre-training of the BERT-base and BERT-NEZHA models.
``` bash
sh scripts/run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR
```
- Run `run_distribute_pretrain.sh` for distributed pre-training of the BERT-base and BERT-NEZHA models.
``` bash
sh scripts/run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH
```
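Before launching the scripts, it can help to verify that the converted TFRecord files and the schema file load correctly. The following is a minimal sketch adapted from `src/dataset.py` in this example; the data directory and schema path are placeholders.
```python
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C

DATA_DIR = "/path/zh-wiki/"        # placeholder: directory containing the converted *.tfrecord files
SCHEMA_FILE = "/path/Schema.json"  # placeholder: dataset schema file

data_files = [os.path.join(DATA_DIR, f) for f in os.listdir(DATA_DIR) if "tfrecord" in f]
ds = de.TFRecordDataset(data_files, SCHEMA_FILE,
                        columns_list=["input_ids", "input_mask", "segment_ids",
                                      "next_sentence_labels", "masked_lm_positions",
                                      "masked_lm_ids", "masked_lm_weights"])
# Cast the integer columns to int32, as the pre-training script does.
type_cast_op = C.TypeCast(mstype.int32)
for column in ["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
               "masked_lm_positions", "masked_lm_ids"]:
    ds = ds.map(input_columns=column, operations=type_cast_op)
print("number of samples:", ds.get_dataset_size())
```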
## Usage
### Pre-Training
```
usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [--device_num N] [--device_id N]
[--enable_save_ckpt ENABLE_SAVE_CKPT]
[--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
[--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] [--save_checkpoint_path SAVE_CHECKPOINT_PATH] [--load_checkpoint_path LOAD_CHECKPOINT_PATH]
[--save_checkpoint_steps N] [--save_checkpoint_num N]
[--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR]
options:
--distribute pre-training on several devices: "true" (training on more than 1 device) | "false", default is "false"
--epoch_size epoch size: N, default is 1
--device_num number of used devices: N, default is 1
--device_id device id: N, default is 0
--enable_save_ckpt enable save checkpoint: "true" | "false", default is "true"
--enable_lossscale enable lossscale: "true" | "false", default is "true"
--do_shuffle enable shuffle: "true" | "false", default is "true"
--enable_data_sink enable data sink: "true" | "false", default is "true"
--data_sink_steps set data sink steps: N, default is 1
--save_checkpoint_path path to save checkpoint files: PATH, default is ""
--load_checkpoint_path path to a checkpoint file to load before training: PATH, default is ""
--save_checkpoint_steps steps for saving checkpoint files: N, default is 1000
--save_checkpoint_num number for saving checkpoint files: N, default is 1
--data_dir path to dataset directory: PATH, default is ""
--schema_dir path to schema.json file: PATH, default is ""
```
## Options and Parameters
The parameters of the BERT model and the options for training are set in `config.py`.
### Options:
```
config.py:
bert_network version of BERT model: base | nezha | large, default is base
loss_scale_value initial value of loss scale: N, default is 2^32
scale_factor factor used to update loss scale: N, default is 2
scale_window steps between loss scale updates: N, default is 1000
optimizer optimizer used in the network: AdamWeightDecayDynamicLR | Lamb | Momentum, default is "Lamb"
```
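For reference, the sketch below shows how these options are consumed; the names are taken from `src/config.py` and `run_pretrain.py` in this example, only the Lamb branch is shown, and `decay_steps` is a placeholder value (the script computes it from the dataset size).
```python
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.optim import Lamb
from src import BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell
from src.config import cfg, bert_net_cfg

# Network with the pre-training loss (masked LM + next sentence prediction).
netwithloss = BertNetworkWithLoss(bert_net_cfg, True)

# Assuming cfg.optimizer == 'Lamb'; Momentum and AdamWeightDecayDynamicLR are
# constructed analogously in run_pretrain.py.
decay_steps = 10000  # placeholder; run_pretrain.py uses dataset_size * repeat_count
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=decay_steps,
                 start_learning_rate=cfg.Lamb.start_learning_rate,
                 end_learning_rate=cfg.Lamb.end_learning_rate,
                 power=cfg.Lamb.power, warmup_steps=cfg.Lamb.warmup_steps,
                 weight_decay=cfg.Lamb.weight_decay, eps=cfg.Lamb.eps)

# The loss-scale options above feed the dynamic loss-scale update cell.
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                         scale_factor=cfg.scale_factor,
                                         scale_window=cfg.scale_window)
netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                 scale_update_cell=update_cell)
```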
### Parameters:
```
Parameters for dataset and network (Pre-Training):
batch_size batch size of input dataset: N, default is 16
seq_length length of input sequence: N, default is 128
vocab_size size of the vocabulary: N, must be consistent with the dataset you use. Default is 21136
hidden_size size of bert encoder layers: N, default is 768
num_hidden_layers number of hidden layers: N, default is 12
num_attention_heads number of attention heads: N, default is 12
intermediate_size size of intermediate layer: N, default is 3072
hidden_act activation function used: ACTIVATION, default is "gelu"
hidden_dropout_prob dropout probability for BertOutput: Q, default is 0.1
attention_probs_dropout_prob dropout probability for BertAttention: Q, default is 0.1
max_position_embeddings maximum length of sequences: N, default is 512
type_vocab_size size of token type vocab: N, default is 16
initializer_range initialization value of TruncatedNormal: Q, default is 0.02
use_relative_positions use relative positions or not: True | False, default is False
input_mask_from_dataset use the input mask loaded from the dataset or not: True | False, default is True
token_type_ids_from_dataset use the token type ids loaded from dataset or not: True | False, default is True
dtype data type of input: mstype.float16 | mstype.float32, default is mstype.float32
compute_type compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float16
Parameters for optimizer:
AdamWeightDecayDynamicLR:
decay_steps steps of the learning rate decay: N
learning_rate value of learning rate: Q
end_learning_rate value of end learning rate: Q, must be positive
power power: Q
warmup_steps steps of the learning rate warm up: N
weight_decay weight decay: Q
eps term added to the denominator to improve numerical stability: Q
Lamb:
decay_steps steps of the learning rate decay: N
learning_rate value of learning rate: Q
end_learning_rate value of end learning rate: Q
power power: Q
warmup_steps steps of the learning rate warm up: N
weight_decay weight decay: Q
Momentum:
learning_rate value of learning rate: Q
momentum momentum for the moving average: Q
```
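As a reference for how the network parameters above are grouped together, below is a minimal sketch constructing a `BertConfig` with the BERT-base defaults from `src/config.py`; the import path assumes the `src` package layout of this example.
```python
import mindspore.common.dtype as mstype
from src.bert_model import BertConfig  # assumed location of BertConfig in this example

bert_net_cfg = BertConfig(
    batch_size=32,
    seq_length=128,
    vocab_size=21128,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=2,
    initializer_range=0.02,
    use_relative_positions=False,       # True for the NEZHA configuration
    input_mask_from_dataset=True,
    token_type_ids_from_dataset=True,
    dtype=mstype.float32,
    compute_type=mstype.float16,
)
```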
- `Bert_NEZHA`: the NEZHA model source, the same as `mindspore.model_zoo.Bert_NEZHA`.
- `Bert_NEZHA_cnwiki`: the NEZHA pre-training example using data from cnwiki.
\ No newline at end of file
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################pre_train bert example on zh-wiki########################
python run_pretrain.py
"""
import os
import argparse
import numpy
import mindspore.communication.management as D
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.parallel_utils import ParallelMode
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR
from mindspore import log as logger
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
from src.dataset import create_bert_dataset
from src.config import cfg, bert_net_cfg
from src.utils import LossCallBack
_current_dir = os.path.dirname(os.path.realpath(__file__))
def run_pretrain():
"""pre-train bert_clue"""
parser = argparse.ArgumentParser(description='bert pre_training')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
parser.add_argument("--enable_lossscale", type=str, default="true", help="Use lossscale or not, default is not.")
parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.")
parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
"default is 1000.")
parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1, "
"meaning run all steps according to epoch number.")
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
context.set_context(reserve_class_name_in_scope=False)
context.set_context(variable_memory_max_size="30GB")
ckpt_save_dir = args_opt.save_checkpoint_path
if args_opt.distribute == "true":
if args_opt.device_target == 'Ascend':
D.init('hccl')
device_num = args_opt.device_num
rank = args_opt.device_id % device_num
else:
D.init('nccl')
device_num = D.get_group_size()
rank = D.get_rank()
ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/'
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
device_num=device_num)
from mindspore.parallel._auto_parallel_context import auto_parallel_context
if bert_net_cfg.num_hidden_layers == 12:
if bert_net_cfg.use_relative_positions:
auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217])
else:
auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205])
elif bert_net_cfg.num_hidden_layers == 24:
if bert_net_cfg.use_relative_positions:
auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421])
else:
auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397])
else:
rank = 0
device_num = 1
if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
logger.warning('GPU only supports fp32 temporarily, running with fp32.')
bert_net_cfg.compute_type = mstype.float32
ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle,
args_opt.enable_data_sink, args_opt.data_sink_steps,
args_opt.data_dir, args_opt.schema_dir)
if args_opt.train_steps > 0:
new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
if cfg.optimizer == 'Lamb':
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size() * new_repeat_count,
start_learning_rate=cfg.Lamb.start_learning_rate, end_learning_rate=cfg.Lamb.end_learning_rate,
power=cfg.Lamb.power, warmup_steps=cfg.Lamb.warmup_steps, weight_decay=cfg.Lamb.weight_decay,
eps=cfg.Lamb.eps)
elif cfg.optimizer == 'Momentum':
optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
momentum=cfg.Momentum.momentum)
elif cfg.optimizer == 'AdamWeightDecayDynamicLR':
optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(),
decay_steps=ds.get_dataset_size() * new_repeat_count,
learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
power=cfg.AdamWeightDecayDynamicLR.power,
weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
eps=cfg.AdamWeightDecayDynamicLR.eps,
warmup_steps=cfg.AdamWeightDecayDynamicLR.warmup_steps)
else:
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]".
format(cfg.optimizer))
callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()]
if args_opt.enable_save_ckpt == "true":
config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
keep_checkpoint_max=args_opt.save_checkpoint_num)
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', directory=ckpt_save_dir, config=config_ck)
callback.append(ckpoint_cb)
if args_opt.load_checkpoint_path:
param_dict = load_checkpoint(args_opt.load_checkpoint_path)
load_param_into_net(netwithloss, param_dict)
if args_opt.enable_lossscale == "true":
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
scale_factor=cfg.scale_factor,
scale_window=cfg.scale_window)
netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
scale_update_cell=update_cell)
else:
netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
model = Model(netwithgrads)
model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"))
if __name__ == '__main__':
numpy.random.seed(0)
run_pretrain()
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "bash run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH"
echo "for example: bash run_distribute_pretrain.sh 8 40 /path/zh-wiki/ /path/Schema.json /path/hccl.json"
echo "It is better to use absolute path."
echo "=============================================================================================================="
EPOCH_SIZE=$2
DATA_DIR=$3
SCHEMA_DIR=$4
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export RANK_TABLE_FILE=$5
export RANK_SIZE=$1
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "the number of logical core" $cores
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
core_gap=`expr $avg_core_per_rank \- 1`
echo "avg_core_per_rank" $avg_core_per_rank
echo "core_gap" $core_gap
for((i=0;i<RANK_SIZE;i++))
do
start=`expr $i \* $avg_core_per_rank`
export DEVICE_ID=$i
export RANK_ID=$i
export DEPLOY_MODE=0
export GE_USE_STATIC_MEMORY=1
end=`expr $start \+ $core_gap`
cmdopt=$start"-"$end
rm -rf LOG$i
mkdir ./LOG$i
cp *.py ./LOG$i
cd ./LOG$i || exit
echo "start training for rank $i, device $DEVICE_ID"
mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
env > env.log
taskset -c $cmdopt python ${PROJECT_DIR}/../run_pretrain.py \
--distribute="true" \
--epoch_size=$EPOCH_SIZE \
--device_id=$DEVICE_ID \
--device_num=$RANK_SIZE \
--enable_save_ckpt="true" \
--enable_lossscale="true" \
--do_shuffle="true" \
--enable_data_sink="true" \
--data_sink_steps=100 \
--load_checkpoint_path="" \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \
--data_dir=$DATA_DIR \
--schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
cd ../
done
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -13,45 +14,31 @@
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py
"""
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "bash run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR"
echo "for example: bash run_distribute_pretrain.sh 8 40 /path/zh-wiki/ /path/Schema.json"
echo "It is better to use absolute path."
echo "=============================================================================================================="
RANK_SIZE=$1
EPOCH_SIZE=$2
DATA_DIR=$3
SCHEMA_DIR=$4
mpirun --allow-run-as-root -n $RANK_SIZE \
python run_pretrain.py \
--device_target="GPU" \
--distribute="true" \
--epoch_size=$EPOCH_SIZE \
--enable_save_ckpt="true" \
--enable_lossscale="false" \
--do_shuffle="true" \
--enable_data_sink="true" \
--data_sink_steps=1 \
--load_checkpoint_path="" \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \
--data_dir=$DATA_DIR \
--schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from mindspore.model_zoo.Bert_NEZHA import BertConfig
bert_train_cfg = edict({
'epoch_size': 10,
'num_warmup_steps': 0,
'start_learning_rate': 1e-4,
'end_learning_rate': 0.0,
'decay_steps': 1000,
'power': 10.0,
'save_checkpoint_steps': 2000,
'keep_checkpoint_max': 10,
'checkpoint_prefix': "checkpoint_bert",
# please add your own dataset path
'DATA_DIR': "/your/path/examples.tfrecord",
# please add your own dataset schema path
'SCHEMA_DIR': "/your/path/datasetSchema.json"
})
bert_net_cfg = BertConfig(
batch_size=16,
seq_length=128,
vocab_size=21136,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=True,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
)
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "bash run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR"
echo "for example: bash run_standalone_pretrain.sh 0 40 /path/zh-wiki/ /path/Schema.json"
echo "=============================================================================================================="
DEVICE_ID=$1
EPOCH_SIZE=$2
DATA_DIR=$3
SCHEMA_DIR=$4
mkdir -p ms_log
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_pretrain.py \
--distribute="false" \
--epoch_size=$EPOCH_SIZE \
--device_id=$DEVICE_ID \
--enable_save_ckpt="true" \
--enable_lossscale="true" \
--do_shuffle="true" \
--enable_data_sink="true" \
--data_sink_steps=1 \
--load_checkpoint_path="" \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \
--data_dir=$DATA_DIR \
--schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "bash run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR"
echo "for example: bash run_standalone_pretrain.sh 0 40 /path/zh-wiki/ /path/Schema.json"
echo "=============================================================================================================="
DEVICE_ID=$1
EPOCH_SIZE=$2
DATA_DIR=$3
SCHEMA_DIR=$4
export CUDA_VISIBLE_DEVICES=$DEVICE_ID
mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python run_pretrain.py \
--device_target="GPU" \
--distribute="false" \
--epoch_size=$EPOCH_SIZE \
--enable_save_ckpt="true" \
--enable_lossscale="false" \
--do_shuffle="true" \
--enable_data_sink="true" \
--data_sink_steps=1 \
--load_checkpoint_path="" \
--save_checkpoint_path="" \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \
--data_dir=$DATA_DIR \
--schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
......@@ -27,48 +27,38 @@ from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.communication.management import get_group_size
from mindspore import context
from mindspore.ops import _selected_ops
from .bert_model import BertModel
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
clip_grad = C.MultitypeFuncGraph("clip_grad")
class ClipGradients(nn.Cell):
# pylint: disable=consider-using-in
@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
"""
Clip gradients.
Inputs:
grads (tuple[Tensor]): Gradients.
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
clip_value (float): Specifies how much to clip.
grad (tuple[Tensor]): Gradients.
Outputs:
tuple[Tensor], clipped gradients.
"""
def __init__(self):
super(ClipGradients, self).__init__()
self.clip_by_norm = nn.ClipByNorm()
self.cast = P.Cast()
self.dtype = P.DType()
def construct(self,
grads,
clip_type,
clip_value):
if clip_type != 0 and clip_type != 1:
return grads
new_grads = ()
for grad in grads:
dt = self.dtype(grad)
if clip_type == 0:
t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
self.cast(F.tuple_to_array((clip_value,)), dt))
else:
t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
new_grads = new_grads + (t,)
return new_grads
if clip_type != 0 and clip_type != 1:
return grad
dt = F.dtype(grad)
if clip_type == 0:
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
F.cast(F.tuple_to_array((clip_value,)), dt))
else:
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
return new_grad
class GetMaskedLMOutput(nn.Cell):
......@@ -92,7 +82,7 @@ class GetMaskedLMOutput(nn.Cell):
config.hidden_size,
weight_init=weight_init,
activation=config.hidden_act).to_float(config.compute_type)
self.layernorm = nn.LayerNorm(config.hidden_size).to_float(config.compute_type)
self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type)
self.output_bias = Parameter(
initializer(
'zero',
......@@ -141,10 +131,10 @@ class GetNextSentenceOutput(nn.Cell):
"""
def __init__(self, config):
super(GetNextSentenceOutput, self).__init__()
self.log_softmax = P.LogSoftmax()
self.weight_init = TruncatedNormal(config.initializer_range)
self.log_softmax = _selected_ops.LogSoftmax()
weight_init = TruncatedNormal(config.initializer_range)
self.dense = nn.Dense(config.hidden_size, 2,
weight_init=self.weight_init, has_bias=True).to_float(config.compute_type)
weight_init=weight_init, has_bias=True).to_float(config.compute_type)
self.dtype = config.dtype
self.cast = P.Cast()
......@@ -294,8 +284,8 @@ class BertTrainOneStepCell(nn.Cell):
degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.clip_gradients = ClipGradients()
self.cast = P.Cast()
self.hyper_map = C.HyperMap()
def set_sens(self, value):
self.sens = value
......@@ -327,11 +317,10 @@ class BertTrainOneStepCell(nn.Cell):
masked_lm_weights,
self.cast(F.tuple_to_array((self.sens,)),
mstype.float32))
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
succ = self.optimizer(grads)
return F.depend(loss, succ)
......@@ -370,13 +359,12 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = None
self.grad_reducer = F.identity
self.degree = 1
if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean")
degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.clip_gradients = ClipGradients()
self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
......@@ -391,7 +379,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale")
self.add_flags(has_effect=True)
@C.add_flags(has_effect=True)
def construct(self,
input_ids,
input_mask,
......@@ -426,11 +415,10 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
masked_lm_weights,
self.cast(scaling_sens,
mstype.float32))
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
# apply grad reducer on grads
grads = self.grad_reducer(grads)
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
if self.is_distributed:
......@@ -446,5 +434,5 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
succ = False
else:
succ = self.optimizer(grads)
ret = (loss, cond)
ret = (loss, cond, scaling_sens)
return F.depend(ret, succ)
......@@ -25,6 +25,7 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from .fused_layer_norm import FusedLayerNorm
class BertConfig:
......@@ -77,7 +78,8 @@ class BertConfig:
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float32):
compute_type=mstype.float32,
enable_fused_layernorm=False):
self.batch_size = batch_size
self.seq_length = seq_length
self.vocab_size = vocab_size
......@@ -96,6 +98,7 @@ class BertConfig:
self.use_relative_positions = use_relative_positions
self.dtype = dtype
self.compute_type = compute_type
self.enable_fused_layernorm = enable_fused_layernorm
class EmbeddingLookup(nn.Cell):
......@@ -190,11 +193,11 @@ class EmbeddingPostprocessor(nn.Cell):
self.array_mul = P.MatMul()
self.reshape = P.Reshape()
self.shape = tuple(embedding_shape)
self.layernorm = nn.LayerNorm(embedding_size)
self.layernorm = nn.LayerNorm((embedding_size,))
self.dropout = nn.Dropout(1 - dropout_prob)
self.gather = P.GatherV2()
self.use_relative_positions = use_relative_positions
self.slice = P.Slice()
self.slice = P.StridedSlice()
self.full_position_embeddings = Parameter(initializer
(TruncatedNormal(initializer_range),
[max_position_embeddings,
......@@ -216,7 +219,7 @@ class EmbeddingPostprocessor(nn.Cell):
output += token_type_embeddings
if not self.use_relative_positions:
_, seq, width = self.shape
position_embeddings = self.slice(self.full_position_embeddings, [0, 0], [seq, width])
position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1))
position_embeddings = self.reshape(position_embeddings, (1, seq, width))
output += position_embeddings
output = self.layernorm(output)
......@@ -240,19 +243,25 @@ class BertOutput(nn.Cell):
out_channels,
initializer_range=0.02,
dropout_prob=0.1,
compute_type=mstype.float32):
compute_type=mstype.float32,
enable_fused_layernorm=False):
super(BertOutput, self).__init__()
self.dense = nn.Dense(in_channels, out_channels,
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
self.dropout = nn.Dropout(1 - dropout_prob)
self.dropout_prob = dropout_prob
self.add = P.TensorAdd()
self.layernorm = nn.LayerNorm(out_channels).to_float(compute_type)
if compute_type == mstype.float16:
self.layernorm = FusedLayerNorm((out_channels,),
use_batch_norm=enable_fused_layernorm).to_float(compute_type)
else:
self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
self.cast = P.Cast()
def construct(self, hidden_status, input_tensor):
output = self.dense(hidden_status)
output = self.dropout(output)
output = self.add(output, input_tensor)
output = self.add(input_tensor, output)
output = self.layernorm(output)
return output
......@@ -481,12 +490,13 @@ class BertAttention(nn.Cell):
self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head)
self.cast_compute_type = SaturateCast(dst_type=compute_type)
self._generate_relative_positions_embeddings = \
RelaPosEmbeddingsGenerator(length=to_seq_length,
depth=size_per_head,
max_relative_position=16,
initializer_range=initializer_range,
use_one_hot_embeddings=use_one_hot_embeddings)
if self.use_relative_positions:
self._generate_relative_positions_embeddings = \
RelaPosEmbeddingsGenerator(length=to_seq_length,
depth=size_per_head,
max_relative_position=16,
initializer_range=initializer_range,
use_one_hot_embeddings=use_one_hot_embeddings)
def construct(self, from_tensor, to_tensor, attention_mask):
# reshape 2d/3d input tensors to 2d
......@@ -529,7 +539,7 @@ class BertAttention(nn.Cell):
self.trans_shape_position)
attention_scores = attention_scores + key_position_scores_r_t
attention_scores = self.multiply(attention_scores, self.scores_mul)
attention_scores = self.multiply(self.scores_mul, attention_scores)
if self.has_attention_mask:
attention_mask = self.expand_dims(attention_mask, 1)
......@@ -606,7 +616,8 @@ class BertSelfAttention(nn.Cell):
initializer_range=0.02,
hidden_dropout_prob=0.1,
use_relative_positions=False,
compute_type=mstype.float32):
compute_type=mstype.float32,
enable_fused_layernorm=False):
super(BertSelfAttention, self).__init__()
if hidden_size % num_attention_heads != 0:
raise ValueError("The hidden size (%d) is not a multiple of the number "
......@@ -634,7 +645,8 @@ class BertSelfAttention(nn.Cell):
out_channels=hidden_size,
initializer_range=initializer_range,
dropout_prob=hidden_dropout_prob,
compute_type=compute_type)
compute_type=compute_type,
enable_fused_layernorm=enable_fused_layernorm)
self.reshape = P.Reshape()
self.shape = (-1, hidden_size)
......@@ -676,7 +688,8 @@ class BertEncoderCell(nn.Cell):
hidden_dropout_prob=0.1,
use_relative_positions=False,
hidden_act="gelu",
compute_type=mstype.float32):
compute_type=mstype.float32,
enable_fused_layernorm=False):
super(BertEncoderCell, self).__init__()
self.attention = BertSelfAttention(
batch_size=batch_size,
......@@ -688,7 +701,8 @@ class BertEncoderCell(nn.Cell):
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
use_relative_positions=use_relative_positions,
compute_type=compute_type)
compute_type=compute_type,
enable_fused_layernorm=enable_fused_layernorm)
self.intermediate = nn.Dense(in_channels=hidden_size,
out_channels=intermediate_size,
activation=hidden_act,
......@@ -697,7 +711,8 @@ class BertEncoderCell(nn.Cell):
out_channels=hidden_size,
initializer_range=initializer_range,
dropout_prob=hidden_dropout_prob,
compute_type=compute_type)
compute_type=compute_type,
enable_fused_layernorm=enable_fused_layernorm)
def construct(self, hidden_states, attention_mask):
# self-attention
......@@ -744,7 +759,8 @@ class BertTransformer(nn.Cell):
use_relative_positions=False,
hidden_act="gelu",
compute_type=mstype.float32,
return_all_encoders=False):
return_all_encoders=False,
enable_fused_layernorm=False):
super(BertTransformer, self).__init__()
self.return_all_encoders = return_all_encoders
......@@ -761,7 +777,8 @@ class BertTransformer(nn.Cell):
hidden_dropout_prob=hidden_dropout_prob,
use_relative_positions=use_relative_positions,
hidden_act=hidden_act,
compute_type=compute_type)
compute_type=compute_type,
enable_fused_layernorm=enable_fused_layernorm)
layers.append(layer)
self.layers = nn.CellList(layers)
......@@ -802,21 +819,20 @@ class CreateAttentionMaskFromInputMask(nn.Cell):
if not self.input_mask_from_dataset:
self.input_mask = initializer(
"ones", [config.batch_size, config.seq_length], mstype.int32)
"ones", [config.batch_size, config.seq_length], mstype.int32).to_tensor()
self.cast = P.Cast()
self.reshape = P.Reshape()
self.shape = (config.batch_size, 1, config.seq_length)
self.broadcast_ones = initializer(
"ones", [config.batch_size, config.seq_length, 1], mstype.float32)
"ones", [config.batch_size, config.seq_length, 1], mstype.float32).to_tensor()
self.batch_matmul = P.BatchMatMul()
def construct(self, input_mask):
if not self.input_mask_from_dataset:
input_mask = self.input_mask
input_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
attention_mask = self.batch_matmul(self.broadcast_ones, input_mask)
attention_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
return attention_mask
......@@ -854,7 +870,7 @@ class BertModel(nn.Cell):
if not self.token_type_ids_from_dataset:
self.token_type_ids = initializer(
"zeros", [self.batch_size, self.seq_length], mstype.int32)
"zeros", [self.batch_size, self.seq_length], mstype.int32).to_tensor()
self.bert_embedding_lookup = EmbeddingLookup(
vocab_size=config.vocab_size,
......@@ -888,7 +904,8 @@ class BertModel(nn.Cell):
use_relative_positions=config.use_relative_positions,
hidden_act=config.hidden_act,
compute_type=config.compute_type,
return_all_encoders=True)
return_all_encoders=True,
enable_fused_layernorm=config.enable_fused_layernorm)
self.cast = P.Cast()
self.dtype = config.dtype
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in dataset.py, run_pretrain.py
"""
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from .bert_model import BertConfig
cfg = edict({
'bert_network': 'base',
'loss_scale_value': 65536,
'scale_factor': 2,
'scale_window': 1000,
'optimizer': 'Lamb',
'AdamWeightDecayDynamicLR': edict({
'learning_rate': 3e-5,
'end_learning_rate': 1e-10,
'power': 5.0,
'weight_decay': 1e-5,
'eps': 1e-6,
'warmup_steps': 10000,
}),
'Lamb': edict({
'start_learning_rate': 3e-5,
'end_learning_rate': 1e-10,
'power': 10.0,
'warmup_steps': 10000,
'weight_decay': 0.01,
'eps': 1e-6,
}),
'Momentum': edict({
'learning_rate': 2e-5,
'momentum': 0.9,
}),
})
'''
Three network configurations are included:
base: Google BERT-base (the base version of the BERT model).
nezha: BERT-NEZHA (a Chinese pretrained language model developed by Huawei, which introduces
Functional Relative Positional Encoding as an effective positional encoding scheme).
large: the large version of the BERT model.
'''
if cfg.bert_network == 'base':
bert_net_cfg = BertConfig(
batch_size=32,
seq_length=128,
vocab_size=21128,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=False,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16
)
if cfg.bert_network == 'nezha':
bert_net_cfg = BertConfig(
batch_size=32,
seq_length=128,
vocab_size=21128,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=True,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16
)
if cfg.bert_network == 'large':
bert_net_cfg = BertConfig(
batch_size=16,
seq_length=512,
vocab_size=30522,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=False,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
enable_fused_layernorm=True
)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in run_pretrain.py
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore import log as logger
from .config import bert_net_cfg
def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", enable_data_sink="true",
data_sink_steps=1, data_dir=None, schema_dir=None):
"""create train dataset"""
# apply repeat operations
repeat_count = epoch_size
files = os.listdir(data_dir)
data_files = []
for file_name in files:
if "tfrecord" in file_name:
data_files.append(os.path.join(data_dir, file_name))
ds = de.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
shard_equal_rows=True)
ori_dataset_size = ds.get_dataset_size()
print('origin dataset size: ', ori_dataset_size)
new_size = ori_dataset_size
if enable_data_sink == "true":
new_size = data_sink_steps * bert_net_cfg.batch_size
ds.set_dataset_size(new_size)
new_repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size())
type_cast_op = C.TypeCast(mstype.int32)
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op)
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
# apply batch operations
ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
ds = ds.repeat(max(new_repeat_count, repeat_count))
logger.info("data size: {}".format(ds.get_dataset_size()))
logger.info("repeatcount: {}".format(ds.get_repeat_count()))
return ds, new_repeat_count
def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy",
data_file_path=None, schema_file_path=None):
"""create finetune or evaluation dataset"""
type_cast_op = C.TypeCast(mstype.int32)
ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None,
columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"])
if assessment_method == "Spearman_correlation":
type_cast_op_float = C.TypeCast(mstype.float32)
ds = ds.map(input_columns="label_ids", operations=type_cast_op_float)
else:
ds = ds.map(input_columns="label_ids", operations=type_cast_op)
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
ds = ds.repeat(repeat_count)
# apply shuffle operation
buffer_size = 960
ds = ds.shuffle(buffer_size=buffer_size)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
return ds
def create_classification_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy",
data_file_path=None, schema_file_path=None):
"""create finetune or evaluation dataset"""
type_cast_op = C.TypeCast(mstype.int32)
ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None,
columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"])
if assessment_method == "Spearman_correlation":
type_cast_op_float = C.TypeCast(mstype.float32)
ds = ds.map(input_columns="label_ids", operations=type_cast_op_float)
else:
ds = ds.map(input_columns="label_ids", operations=type_cast_op)
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
ds = ds.repeat(repeat_count)
# apply shuffle operation
buffer_size = 960
ds = ds.shuffle(buffer_size=buffer_size)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
return ds
def create_squad_dataset(batch_size=1, repeat_count=1, data_file_path=None, schema_file_path=None, is_training=True):
"""create finetune or evaluation dataset"""
type_cast_op = C.TypeCast(mstype.int32)
if is_training:
ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None,
columns_list=["input_ids", "input_mask", "segment_ids",
"start_positions", "end_positions",
"unique_ids", "is_impossible"])
ds = ds.map(input_columns="start_positions", operations=type_cast_op)
ds = ds.map(input_columns="end_positions", operations=type_cast_op)
else:
ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None,
columns_list=["input_ids", "input_mask", "segment_ids", "unique_ids"])
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
ds = ds.repeat(repeat_count)
# apply shuffle operation
buffer_size = 960
ds = ds.shuffle(buffer_size=buffer_size)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
return ds
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""fused layernorm"""
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.ops.primitive import constexpr
import mindspore.common.dtype as mstype
from mindspore.nn.cell import Cell
import numpy as np
__all__ = ['FusedLayerNorm']
@constexpr
def get_shape_for_norm(x_shape, begin_norm_axis):
print("input_shape: ", x_shape)
norm_shape = x_shape[begin_norm_axis:]
output_shape = (1, -1, 1, int(np.prod(norm_shape)))
print("output_shape: ", output_shape)
return output_shape
class FusedLayerNorm(Cell):
r"""
Applies Layer Normalization over a mini-batch of inputs.
Layer normalization is widely used in recurrent neural networks. It applies
normalization over a mini-batch of inputs for each single training case as described
in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike batch
normalization, layer normalization performs exactly the same computation at training and
testing times. It can be described using the following formula. It is applied across all channels
and pixels of each sample, rather than across the batch.
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
Args:
normalized_shape (Union(tuple[int], list[int])): The normalization is performed over axis
`begin_norm_axis ... R - 1`.
begin_norm_axis (int): The first normalization dimension: normalization will be performed along dimensions
`begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
begin_params_axis (int): The first parameter (beta, gamma) dimension: scale and centering parameters
will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'zeros'.
use_batch_norm (bool): Whether to use batch normalization for the computation. Default: False.
Inputs:
- **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.
Outputs:
Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
Examples:
>>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
>>> shape1 = x.shape[1:]
>>> m = FusedLayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
>>> m(x)
"""
def __init__(self,
normalized_shape,
begin_norm_axis=-1,
begin_params_axis=-1,
gamma_init='ones',
beta_init='zeros',
use_batch_norm=False):
super(FusedLayerNorm, self).__init__()
if not isinstance(normalized_shape, (tuple, list)):
raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
.format(normalized_shape, type(normalized_shape)))
self.normalized_shape = normalized_shape
self.begin_norm_axis = begin_norm_axis
self.begin_params_axis = begin_params_axis
self.gamma = Parameter(initializer(
gamma_init, normalized_shape), name="gamma")
self.beta = Parameter(initializer(
beta_init, normalized_shape), name="beta")
self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis)
self.batch_norm = P.BatchNorm(is_training=True, epsilon=1e-5)
self.use_batch_norm = use_batch_norm
def construct(self, input_x):
if self.use_batch_norm and self.training:
ones = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 1.0)
zeros = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 0.0)
shape_x = F.shape(input_x)
norm_shape = get_shape_for_norm(shape_x, self.begin_norm_axis)
input_x = F.reshape(input_x, norm_shape)
output, _, _, _, _, _ = self.batch_norm(input_x, ones, zeros, None, None)
output = F.reshape(output, shape_x)
y = output * self.gamma + self.beta
else:
y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
return y
def extend_repr(self):
"""Display instance object as string."""
s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma={}, beta={}'.format(
self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
return s
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Functional Cells used in Bert finetune and evaluation.
"""
import os
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore import log as logger
from mindspore.train.callback import Callback
class CrossEntropyCalculation(nn.Cell):
"""
Cross Entropy loss
"""
def __init__(self, is_training=True):
super(CrossEntropyCalculation, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.reduce_sum = P.ReduceSum()
self.reduce_mean = P.ReduceMean()
self.reshape = P.Reshape()
self.last_idx = (-1,)
self.neg = P.Neg()
self.cast = P.Cast()
self.is_training = is_training
def construct(self, logits, label_ids, num_labels):
if self.is_training:
label_ids = self.reshape(label_ids, self.last_idx)
one_hot_labels = self.onehot(label_ids, num_labels, self.on_value, self.off_value)
per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx))
loss = self.reduce_mean(per_example_loss, self.last_idx)
return_value = self.cast(loss, mstype.float32)
else:
return_value = logits * 1.0
return return_value
def make_directory(path: str):
"""Make directory."""
if path is None or not isinstance(path, str) or path.strip() == "":
logger.error("The path(%r) is invalid type.", path)
raise TypeError("Input path is invaild type")
# convert the relative paths
path = os.path.realpath(path)
logger.debug("The abs path is %r", path)
# check whether the path exists and is writable
if os.path.exists(path):
real_path = path
else:
# All exceptions need to be caught because create directory maybe have some limit(permissions)
logger.debug("The directory(%s) doesn't exist, will create it", path)
try:
os.makedirs(path, exist_ok=True)
real_path = path
except PermissionError as e:
logger.error("No write permission on the directory(%r), error = %r", path, e)
raise TypeError("No write permission on the directory.")
return real_path
class LossCallBack(Callback):
"""
Monitor the loss in training.
If the loss is NAN or INF, training is terminated.
Note:
If per_print_times is 0, the loss is not printed.
Args:
per_print_times (int): Print the loss every `per_print_times` steps. Default: 1.
"""
def __init__(self, per_print_times=1):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0")
self._per_print_times = per_print_times
def step_end(self, run_context):
cb_params = run_context.original_args()
print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
str(cb_params.net_outputs)))
def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, prefix):
"""
Find the ckpt finetune generated and load it into eval network.
"""
files = os.listdir(load_finetune_checkpoint_dir)
pre_len = len(prefix)
max_num = 0
for filename in files:
name_ext = os.path.splitext(filename)
if name_ext[-1] != ".ckpt":
continue
#steps_per_epoch = ds.get_dataset_size()
if filename.find(prefix) == 0 and not filename[pre_len].isalpha():
index = filename[pre_len:].find("-")
if index == 0 and max_num == 0:
load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename)
elif index not in (0, -1):
name_split = name_ext[-2].split('_')
if (steps_per_epoch != int(name_split[len(name_split)-1])) \
or (epoch_num != int(filename[pre_len + index + 1:pre_len + index + 2])):
continue
num = filename[pre_len + 1:pre_len + index]
if int(num) > max_num:
max_num = int(num)
load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename)
return load_finetune_checkpoint_path