提交 8701ca24 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!950 add bert base version finetune and evaluation script in bert_clue dir

Merge pull request !950 from yoonlee666/master
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
CRF script.
'''
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
import mindspore.common.dtype as mstype
class CRF(nn.Cell):
'''
Conditional Random Field
Args:
tag_to_index: The dict for tag to index mapping with extra "<START>" and "<STOP>"sign.
batch_size: Batch size, i.e., the length of the first dimension.
seq_length: Sequence length, i.e., the length of the second dimention.
is_training: Specifies whether to use training mode.
Returns:
Training mode: Tensor, total loss.
Evaluation mode: Tuple, the index for each step with the highest score; Tuple, the index for the last
step with the highest score.
'''
def __init__(self, tag_to_index, batch_size=1, seq_length=128, is_training=True):
super(CRF, self).__init__()
self.target_size = len(tag_to_index)
self.is_training = is_training
self.tag_to_index = tag_to_index
self.batch_size = batch_size
self.seq_length = seq_length
self.START_TAG = "<START>"
self.STOP_TAG = "<STOP>"
self.START_VALUE = Tensor(self.target_size-2, dtype=mstype.int32)
self.STOP_VALUE = Tensor(self.target_size-1, dtype=mstype.int32)
transitions = np.random.normal(size=(self.target_size, self.target_size)).astype(np.float32)
transitions[tag_to_index[self.START_TAG], :] = -10000
transitions[:, tag_to_index[self.STOP_TAG]] = -10000
self.transitions = Parameter(Tensor(transitions), name="transition_matrix")
self.cat = P.Concat(axis=-1)
self.argmax = P.ArgMaxWithValue(axis=-1)
self.log = P.Log()
self.exp = P.Exp()
self.sum = P.ReduceSum()
self.tile = P.Tile()
self.reduce_sum = P.ReduceSum(keep_dims=True)
self.reshape = P.Reshape()
self.expand = P.ExpandDims()
self.mean = P.ReduceMean()
init_alphas = np.ones(shape=(self.batch_size, self.target_size)) * -10000.0
init_alphas[:, self.tag_to_index[self.START_TAG]] = 0.
self.init_alphas = Tensor(init_alphas, dtype=mstype.float32)
self.cast = P.Cast()
self.reduce_max = P.ReduceMax(keep_dims=True)
self.on_value = Tensor(1.0, dtype=mstype.float32)
self.off_value = Tensor(0.0, dtype=mstype.float32)
self.onehot = P.OneHot()
def log_sum_exp(self, logits):
'''
Compute the log_sum_exp score for normalization factor.
'''
max_score = self.reduce_max(logits, -1) #16 5 5
score = self.log(self.reduce_sum(self.exp(logits - max_score), -1))
score = max_score + score
return score
def _realpath_score(self, features, label):
'''
Compute the emission and transition score for the real path.
'''
label = label * 1
concat_A = self.tile(self.reshape(self.START_VALUE, (1,)), (self.batch_size,))
concat_A = self.reshape(concat_A, (self.batch_size, 1))
labels = self.cat((concat_A, label))
onehot_label = self.onehot(label, self.target_size, self.on_value, self.off_value)
emits = features * onehot_label
labels = self.onehot(labels, self.target_size, self.on_value, self.off_value)
label1 = labels[:, 1:, :]
label2 = labels[:, :self.seq_length, :]
label1 = self.expand(label1, 3)
label2 = self.expand(label2, 2)
label_trans = label1 * label2
transitions = self.expand(self.expand(self.transitions, 0), 0)
trans = transitions * label_trans
score = self.sum(emits, (1, 2)) + self.sum(trans, (1, 2, 3))
stop_value_index = labels[:, (self.seq_length-1):self.seq_length, :]
stop_value = self.transitions[(self.target_size-1):self.target_size, :]
stop_score = stop_value * self.reshape(stop_value_index, (self.batch_size, self.target_size))
score = score + self.sum(stop_score, 1)
score = self.reshape(score, (self.batch_size, -1))
return score
def _normalization_factor(self, features):
'''
Compute the total score for all the paths.
'''
forward_var = self.init_alphas
forward_var = self.expand(forward_var, 1)
for idx in range(self.seq_length):
feat = features[:, idx:(idx+1), :]
emit_score = self.reshape(feat, (self.batch_size, self.target_size, 1))
next_tag_var = emit_score + self.transitions + forward_var
forward_var = self.log_sum_exp(next_tag_var)
forward_var = self.reshape(forward_var, (self.batch_size, 1, self.target_size))
terminal_var = forward_var + self.reshape(self.transitions[(self.target_size-1):self.target_size, :], (1, -1))
alpha = self.log_sum_exp(terminal_var)
alpha = self.reshape(alpha, (self.batch_size, -1))
return alpha
def _decoder(self, features):
'''
Viterbi decode for evaluation.
'''
backpointers = ()
forward_var = self.init_alphas
for idx in range(self.seq_length):
feat = features[:, idx:(idx+1), :]
feat = self.reshape(feat, (self.batch_size, self.target_size))
bptrs_t = ()
next_tag_var = self.expand(forward_var, 1) + self.transitions
best_tag_id, best_tag_value = self.argmax(next_tag_var)
bptrs_t += (best_tag_id,)
forward_var = best_tag_value + feat
backpointers += (bptrs_t,)
terminal_var = forward_var + self.reshape(self.transitions[(self.target_size-1):self.target_size, :], (1, -1))
best_tag_id, _ = self.argmax(terminal_var)
return backpointers, best_tag_id
def construct(self, features, label):
if self.is_training:
forward_score = self._normalization_factor(features)
gold_score = self._realpath_score(features, label)
return_value = self.mean(forward_score - gold_score)
else:
path_list, tag = self._decoder(features)
return_value = path_list, tag
return return_value
def postprocess(backpointers, best_tag_id):
'''
Do postprocess
'''
best_tag_id = best_tag_id.asnumpy()
batch_size = len(best_tag_id)
best_path = []
for i in range(batch_size):
best_path.append([])
best_local_id = best_tag_id[i]
best_path[-1].append(best_local_id)
for bptrs_t in reversed(backpointers):
bptrs_t = bptrs_t[0].asnumpy()
local_idx = bptrs_t[i]
best_local_id = local_idx[best_local_id]
best_path[-1].append(best_local_id)
# Pop off the start tag (we dont want to return that to the caller)
best_path[-1].pop()
best_path[-1].reverse()
return best_path
......@@ -16,21 +16,21 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base](
- Run `run_standalone_pretrain.sh` for non-distributed pre-training of BERT-base and BERT-NEZHA model.
``` bash
sh run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_PATH
sh run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR
```
- Run `run_distribute_pretrain.sh` for distributed pre-training of BERT-base and BERT-NEZHA model.
``` bash
sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH MINDSPORE_PATH
sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH
```
### Fine-Tuning
- Set options in `finetune_config.py`. Make sure the 'data_file', 'schema_file' and 'ckpt_file' are set to your own path, set the 'pre_training_ckpt' to save the checkpoint files generated.
- Set options in `finetune_config.py`. Make sure the 'data_file', 'schema_file' and 'pre_training_file' are set to your own path. Set the 'pre_training_ckpt' to a saved checkpoint file generated after pre-training.
- Run `finetune.py` for fine-tuning of BERT-base and BERT-NEZHA model.
```bash
python finetune.py --backend=ms
python finetune.py
```
### Evaluation
......@@ -39,7 +39,7 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base](
- Run `evaluation.py` for evaluation of BERT-base and BERT-NEZHA model.
```bash
python evaluation.py --backend=ms
python evaluation.py
```
## Usage
......@@ -76,28 +76,33 @@ options:
It contains of parameters of BERT model and options for training, which is set in file `config.py`, `finetune_config.py` and `evaluation_config.py` respectively.
### Options:
```
Pre-Training:
config.py:
bert_network version of BERT model: base | nezha, default is base
loss_scale_value initial value of loss scale: N, default is 2^32
scale_factor factor used to update loss scale: N, default is 2
scale_window steps for once updatation of loss scale: N, default is 1000
optimizer optimizer used in the network: AdamWerigtDecayDynamicLR | Lamb | Momentum, default is "Lamb"
Fine-Tuning:
task task type: NER | XNLI | LCQMC | SENTI
data_file dataset file to load: PATH, default is "/your/path/cn-wiki-128"
schema_file dataset schema file to load: PATH, default is "/your/path/datasetSchema.json"
epoch_num repeat counts of training: N, default is 40
finetune_config.py:
task task type: NER | XNLI | LCQMC | SENTIi | OTHERS
num_labels number of labels to do classification
data_file dataset file to load: PATH, default is "/your/path/train.tfrecord"
schema_file dataset schema file to load: PATH, default is "/your/path/schema.json"
epoch_num repeat counts of training: N, default is 5
ckpt_prefix prefix used to save checkpoint files: PREFIX, default is "bert"
ckpt_dir path to save checkpoint files: PATH, default is None
pre_training_ckpt checkpoint file to load: PATH, default is "/your/path/pre_training.ckpt"
optimizer optimizer used in the network: AdamWeigtDecayDynamicLR | Lamb | Momentum, default is "Lamb"
use_crf whether to use crf for evaluation. use_crf takes effect only when task type is NER, default is False
optimizer optimizer used in fine-tune network: AdamWeigtDecayDynamicLR | Lamb | Momentum, default is "Lamb"
Evaluation:
task task type: NER | XNLI | LCQMC | SENTI
evaluation_config.py:
task task type: NER | XNLI | LCQMC | SENTI | OTHERS
num_labels number of labels to do classsification
data_file dataset file to load: PATH, default is "/your/path/evaluation.tfrecord"
schema_file dataset schema file to load: PATH, default is "/your/path/schema.json"
finetune_ckpt checkpoint file to load: PATH, default is "/your/path/your.ckpt"
use_crf whether to use crf for evaluation. use_crf takes effect only when task type is NER, default is False
clue_benchmark whether to use clue benchmark. clue_benchmark takes effect only when task type is NER, default is False
```
### Parameters:
......@@ -124,25 +129,24 @@ Parameters for dataset and network (Pre-Training/Fine-Tuning/Evaluation):
Parameters for optimizer:
AdamWeightDecayDynamicLR:
decay_steps steps of the learning rate decay: N, default is 12276*3
learning_rate value of learning rate: Q, default is 1e-5
end_learning_rate value of end learning rate: Q, default is 0.0
power power: Q, default is 10.0
warmup_steps steps of the learning rate warm up: N, default is 2100
weight_decay weight decay: Q, default is 1e-5
eps term added to the denominator to improve numerical stability: Q, default is 1e-6
decay_steps steps of the learning rate decay: N
learning_rate value of learning rate: Q
end_learning_rate value of end learning rate: Q, must be positive
power power: Q
warmup_steps steps of the learning rate warm up: N
weight_decay weight decay: Q
eps term added to the denominator to improve numerical stability: Q
Lamb:
decay_steps steps of the learning rate decay: N, default is 12276*3
learning_rate value of learning rate: Q, default is 1e-5
end_learning_rate value of end learning rate: Q, default is 0.0
power power: Q, default is 5.0
warmup_steps steps of the learning rate warm up: N, default is 2100
weight_decay weight decay: Q, default is 1e-5
decay_filter function to determine whether to apply weight decay on parameters: FUNCTION, default is lambda x: False
decay_steps steps of the learning rate decay: N
learning_rate value of learning rate: Q
end_learning_rate value of end learning rate: Q
power power: Q
warmup_steps steps of the learning rate warm up: N
weight_decay weight decay: Q
Momentum:
learning_rate value of learning rate: Q, default is 2e-5
momentum momentum for the moving average: Q, default is 0.9
learning_rate value of learning rate: Q
momentum momentum for the moving average: Q
```
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''bert clue evaluation'''
import json
import numpy as np
from evaluation_config import cfg
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from CRF import postprocess
import tokenization
from sample_process import label_generation, process_one_example_p
vocab_file = "./vocab.txt"
tokenizer_ = tokenization.FullTokenizer(vocab_file=vocab_file)
def process(model, text, sequence_length):
"""
process text.
"""
data = [text]
features = []
res = []
ids = []
for i in data:
feature = process_one_example_p(tokenizer_, i, max_seq_len=sequence_length)
features.append(feature)
input_ids, input_mask, token_type_id = feature
input_ids = Tensor(np.array(input_ids), mstype.int32)
input_mask = Tensor(np.array(input_mask), mstype.int32)
token_type_id = Tensor(np.array(token_type_id), mstype.int32)
if cfg.use_crf:
backpointers, best_tag_id = model.predict(input_ids, input_mask, token_type_id, Tensor(1))
best_path = postprocess(backpointers, best_tag_id)
logits = []
for ele in best_path:
logits.extend(ele)
ids = logits
else:
logits = model.predict(input_ids, input_mask, token_type_id, Tensor(1))
ids = logits.asnumpy()
ids = np.argmax(ids, axis=-1)
ids = list(ids)
res = label_generation(text, ids)
return res
def submit(model, path, sequence_length):
"""
submit task
"""
data = []
for line in open(path):
if not line.strip():
continue
oneline = json.loads(line.strip())
res = process(model, oneline["text"], sequence_length)
print("text", oneline["text"])
print("res:", res)
data.append(json.dumps({"label": res}, ensure_ascii=False))
open("ner_predict.json", "w").write("\n".join(data))
......@@ -30,6 +30,7 @@ cfg = edict({
'power': 5.0,
'weight_decay': 1e-5,
'eps': 1e-6,
'warmup_steps': 10000,
}),
'Lamb': edict({
'start_learning_rate': 3e-5,
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Bert evaluation script.
"""
import os
import numpy as np
from evaluation_config import cfg, bert_net_cfg
from utils import BertNER, BertCLS
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore.common.tensor import Tensor
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from CRF import postprocess
from cluener_evaluation import submit
from finetune_config import tag_to_index
class Accuracy():
'''
calculate accuracy
'''
def __init__(self):
self.acc_num = 0
self.total_num = 0
def update(self, logits, labels):
labels = labels.asnumpy()
labels = np.reshape(labels, -1)
logits = logits.asnumpy()
logit_id = np.argmax(logits, axis=-1)
self.acc_num += np.sum(labels == logit_id)
self.total_num += len(labels)
print("=========================accuracy is ", self.acc_num / self.total_num)
class F1():
'''
calculate F1 score
'''
def __init__(self):
self.TP = 0
self.FP = 0
self.FN = 0
def update(self, logits, labels):
'''
update F1 score
'''
labels = labels.asnumpy()
labels = np.reshape(labels, -1)
if cfg.use_crf:
backpointers, best_tag_id = logits
best_path = postprocess(backpointers, best_tag_id)
logit_id = []
for ele in best_path:
logit_id.extend(ele)
else:
logits = logits.asnumpy()
logit_id = np.argmax(logits, axis=-1)
logit_id = np.reshape(logit_id, -1)
pos_eva = np.isin(logit_id, [i for i in range(1, cfg.num_labels)])
pos_label = np.isin(labels, [i for i in range(1, cfg.num_labels)])
self.TP += np.sum(pos_eva&pos_label)
self.FP += np.sum(pos_eva&(~pos_label))
self.FN += np.sum((~pos_eva)&pos_label)
def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
'''
get dataset
'''
ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask",
"segment_ids", "label_ids"])
type_cast_op = C.TypeCast(mstype.int32)
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
ds = ds.map(input_columns="label_ids", operations=type_cast_op)
ds = ds.repeat(repeat_count)
# apply shuffle operation
buffer_size = 960
ds = ds.shuffle(buffer_size=buffer_size)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
return ds
def bert_predict(Evaluation):
'''
prediction function
'''
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
dataset = get_dataset(bert_net_cfg.batch_size, 1)
if cfg.use_crf:
net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels=len(tag_to_index), use_crf=True,
tag_to_index=tag_to_index, dropout_prob=0.0)
else:
net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels)
net_for_pretraining.set_train(False)
param_dict = load_checkpoint(cfg.finetune_ckpt)
load_param_into_net(net_for_pretraining, param_dict)
model = Model(net_for_pretraining)
return model, dataset
def test_eval():
'''
evaluation function
'''
task_type = BertNER if cfg.task == "NER" else BertCLS
model, dataset = bert_predict(task_type)
if cfg.clue_benchmark:
submit(model, cfg.data_file, bert_net_cfg.seq_length)
else:
callback = F1() if cfg.task == "NER" else Accuracy()
columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
for data in dataset.create_dict_iterator():
input_data = []
for i in columns_list:
input_data.append(Tensor(data[i]))
input_ids, input_mask, token_type_id, label_ids = input_data
logits = model.predict(input_ids, input_mask, token_type_id, label_ids)
callback.update(logits, label_ids)
print("==============================================================")
if cfg.task == "NER":
print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP)))
print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN)))
print("F1 {:.6f} ".format(2*callback.TP / (2*callback.TP + callback.FP + callback.FP)))
else:
print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
callback.acc_num / callback.total_num))
print("==============================================================")
if __name__ == "__main__":
num_labels = cfg.num_labels
test_eval()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
config settings, will be used in finetune.py
"""
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from mindspore.model_zoo.Bert_NEZHA import BertConfig
cfg = edict({
'task': 'NER',
'num_labels': 41,
'data_file': '/your/path/evaluation.tfrecord',
'schema_file': '/your/path/schema.json',
'finetune_ckpt': '/your/path/your.ckpt',
'use_crf': False,
'clue_benchmark': False,
})
bert_net_cfg = BertConfig(
batch_size=16 if not cfg.clue_benchmark else 1,
seq_length=128,
vocab_size=21128,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=False,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
Bert finetune script.
'''
import os
from utils import BertFinetuneCell, BertCLS, BertNER
from finetune_config import cfg, bert_net_cfg, tag_to_index
import mindspore.common.dtype as mstype
import mindspore.communication.management as D
from mindspore import context
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum
from mindspore.train.model import Model
from mindspore.train.callback import Callback
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.train.serialization import load_checkpoint, load_param_into_net
class LossCallBack(Callback):
'''
Monitor the loss in training.
If the loss is NAN or INF, terminate training.
Note:
If per_print_times is 0, do not print loss.
Args:
per_print_times (int): Print loss every times. Default: 1.
'''
def __init__(self, per_print_times=1):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be in and >= 0.")
self._per_print_times = per_print_times
def step_end(self, run_context):
cb_params = run_context.original_args()
with open("./loss.log", "a+") as f:
f.write("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
str(cb_params.net_outputs)))
f.write("\n")
def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
'''
get dataset
'''
ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask",
"segment_ids", "label_ids"])
type_cast_op = C.TypeCast(mstype.int32)
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
ds = ds.map(input_columns="label_ids", operations=type_cast_op)
ds = ds.repeat(repeat_count)
# apply shuffle operation
buffer_size = 960
ds = ds.shuffle(buffer_size=buffer_size)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
return ds
def test_train():
'''
finetune function
pytest -s finetune.py::test_train
'''
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid,
enable_mem_reuse=True, enable_task_sink=True)
#BertCLSTrain for classification
#BertNERTrain for sequence labeling
if cfg.task == 'NER':
if cfg.use_crf:
netwithloss = BertNER(bert_net_cfg, True, num_labels=len(tag_to_index), use_crf=True,
tag_to_index=tag_to_index, dropout_prob=0.1)
else:
netwithloss = BertNER(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
else:
netwithloss = BertCLS(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
dataset = get_dataset(bert_net_cfg.batch_size, cfg.epoch_num)
# optimizer
steps_per_epoch = dataset.get_dataset_size()
if cfg.optimizer == 'AdamWeightDecayDynamicLR':
optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(),
decay_steps=steps_per_epoch * cfg.epoch_num,
learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
power=cfg.AdamWeightDecayDynamicLR.power,
warmup_steps=steps_per_epoch,
weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
eps=cfg.AdamWeightDecayDynamicLR.eps)
elif cfg.optimizer == 'Lamb':
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=steps_per_epoch * cfg.epoch_num,
start_learning_rate=cfg.Lamb.start_learning_rate, end_learning_rate=cfg.Lamb.end_learning_rate,
power=cfg.Lamb.power, warmup_steps=steps_per_epoch, decay_filter=cfg.Lamb.decay_filter)
elif cfg.optimizer == 'Momentum':
optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
momentum=cfg.Momentum.momentum)
else:
raise Exception("Optimizer not supported.")
# load checkpoint into network
ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
ckpoint_cb = ModelCheckpoint(prefix=cfg.ckpt_prefix, directory=cfg.ckpt_dir, config=ckpt_config)
param_dict = load_checkpoint(cfg.pre_training_ckpt)
load_param_into_net(netwithloss, param_dict)
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
netwithgrads = BertFinetuneCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
model = Model(netwithgrads)
model.train(cfg.epoch_num, dataset, callbacks=[LossCallBack(), ckpoint_cb])
D.release()
if __name__ == "__main__":
test_train()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
config settings, will be used in finetune.py
"""
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from mindspore.model_zoo.Bert_NEZHA import BertConfig
cfg = edict({
'task': 'NER',
'num_labels': 41,
'data_file': '/your/path/train.tfrecord',
'schema_file': '/your/path/schema.json',
'epoch_num': 5,
'ckpt_prefix': 'bert',
'ckpt_dir': None,
'pre_training_ckpt': '/your/path/pre_training.ckpt',
'use_crf': False,
'optimizer': 'Lamb',
'AdamWeightDecayDynamicLR': edict({
'learning_rate': 2e-5,
'end_learning_rate': 1e-7,
'power': 1.0,
'weight_decay': 1e-5,
'eps': 1e-6,
}),
'Lamb': edict({
'start_learning_rate': 2e-5,
'end_learning_rate': 1e-7,
'power': 1.0,
'decay_filter': lambda x: False,
}),
'Momentum': edict({
'learning_rate': 2e-5,
'momentum': 0.9,
}),
})
bert_net_cfg = BertConfig(
batch_size=16,
seq_length=128,
vocab_size=21128,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=False,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
)
tag_to_index = {
"O": 0,
"S_address": 1,
"B_address": 2,
"M_address": 3,
"E_address": 4,
"S_book": 5,
"B_book": 6,
"M_book": 7,
"E_book": 8,
"S_company": 9,
"B_company": 10,
"M_company": 11,
"E_company": 12,
"S_game": 13,
"B_game": 14,
"M_game": 15,
"E_game": 16,
"S_government": 17,
"B_government": 18,
"M_government": 19,
"E_government": 20,
"S_movie": 21,
"B_movie": 22,
"M_movie": 23,
"E_movie": 24,
"S_name": 25,
"B_name": 26,
"M_name": 27,
"E_name": 28,
"S_organization": 29,
"B_organization": 30,
"M_organization": 31,
"E_organization": 32,
"S_position": 33,
"B_position": 34,
"M_position": 35,
"E_position": 36,
"S_scene": 37,
"B_scene": 38,
"M_scene": 39,
"E_scene": 40,
"<START>": 41,
"<STOP>": 42
}
......@@ -112,7 +112,8 @@ def run_pretrain():
end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
power=cfg.AdamWeightDecayDynamicLR.power,
weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
eps=cfg.AdamWeightDecayDynamicLR.eps)
eps=cfg.AdamWeightDecayDynamicLR.eps,
warmup_steps=cfg.AdamWeightDecayDynamicLR.warmup_steps)
else:
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]".
format(cfg.optimizer))
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""process txt"""
import re
import json
def process_one_example_p(tokenizer, text, max_seq_len=128):
"""process one testline"""
textlist = list(text)
tokens = []
for _, word in enumerate(textlist):
token = tokenizer.tokenize(word)
tokens.extend(token)
if len(tokens) >= max_seq_len - 1:
tokens = tokens[0:(max_seq_len - 2)]
ntokens = []
segment_ids = []
label_ids = []
ntokens.append("[CLS]")
segment_ids.append(0)
for _, token in enumerate(tokens):
ntokens.append(token)
segment_ids.append(0)
ntokens.append("[SEP]")
segment_ids.append(0)
input_ids = tokenizer.convert_tokens_to_ids(ntokens)
input_mask = [1] * len(input_ids)
while len(input_ids) < max_seq_len:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
label_ids.append(0)
ntokens.append("**NULL**")
assert len(input_ids) == max_seq_len
assert len(input_mask) == max_seq_len
assert len(segment_ids) == max_seq_len
feature = (input_ids, input_mask, segment_ids)
return feature
def label_generation(text, probs):
"""generate label"""
data = [text]
probs = [probs]
result = []
label2id = json.loads(open("./label2id.json").read())
id2label = [k for k, v in label2id.items()]
for index, prob in enumerate(probs):
for v in prob[1:len(data[index]) + 1]:
result.append(id2label[int(v)])
labels = {}
start = None
index = 0
for _, t in zip("".join(data), result):
if re.search("^[BS]", t):
if start is not None:
label = result[index - 1][2:]
if labels.get(label):
te_ = text[start:index]
labels[label][te_] = [[start, index - 1]]
else:
te_ = text[start:index]
labels[label] = {te_: [[start, index - 1]]}
start = index
if re.search("^O", t):
if start is not None:
label = result[index - 1][2:]
if labels.get(label):
te_ = text[start:index]
labels[label][te_] = [[start, index - 1]]
else:
te_ = text[start:index]
labels[label] = {te_: [[start, index - 1]]}
start = None
index += 1
if start is not None:
label = result[start][2:]
if labels.get(label):
te_ = text[start:index]
labels[label][te_] = [[start, index - 1]]
else:
te_ = text[start:index]
labels[label] = {te_: [[start, index - 1]]}
return labels
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
Functional Cells used in Bert finetune and evaluation.
'''
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.communication.management import get_group_size
from mindspore import context
from mindspore.model_zoo.Bert_NEZHA.bert_model import BertModel
from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import ClipGradients
from CRF import CRF
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * reciprocal(scale)
class BertFinetuneCell(nn.Cell):
"""
Especifically defined for finetuning where only four inputs tensor are needed.
"""
def __init__(self, network, optimizer, scale_update_cell=None):
super(BertFinetuneCell, self).__init__(auto_prefix=False)
self.network = network
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation('grad',
get_by_list=True,
sens_param=True)
self.reducer_flag = False
self.allreduce = P.AllReduce()
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = None
if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean")
degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.clip_gradients = ClipGradients()
self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_before_grad = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()
self.hyper_map = C.HyperMap()
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale")
def construct(self,
input_ids,
input_mask,
token_type_id,
label_ids,
sens=None):
weights = self.weights
init = self.alloc_status()
loss = self.network(input_ids,
input_mask,
token_type_id,
label_ids)
if sens is None:
scaling_sens = self.loss_scale
else:
scaling_sens = sens
grads = self.grad(self.network, weights)(input_ids,
input_mask,
token_type_id,
label_ids,
self.cast(scaling_sens,
mstype.float32))
clear_before_grad = self.clear_before_grad(init)
F.control_depend(loss, init)
self.depend_parameter_use(clear_before_grad, scaling_sens)
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
if self.reducer_flag:
grads = self.grad_reducer(grads)
flag = self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
if self.is_distributed:
flag_reduce = self.allreduce(flag_sum)
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
F.control_depend(grads, flag)
F.control_depend(flag, flag_sum)
overflow = cond
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, cond)
if overflow:
succ = False
else:
succ = self.optimizer(grads)
ret = (loss, cond)
return F.depend(ret, succ)
class BertCLSModel(nn.Cell):
"""
This class is responsible for classification task evaluation, i.e. XNLI(num_labels=3),
LCQMC(num_labels=2), Chnsenti(num_labels=2). The returned output represents the final
logits as the results of log_softmax is propotional to that of softmax.
"""
def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
super(BertCLSModel, self).__init__()
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
self.cast = P.Cast()
self.weight_init = TruncatedNormal(config.initializer_range)
self.log_softmax = P.LogSoftmax(axis=-1)
self.dtype = config.dtype
self.num_labels = num_labels
self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
has_bias=True).to_float(config.compute_type)
self.dropout = nn.Dropout(1 - dropout_prob)
def construct(self, input_ids, input_mask, token_type_id):
_, pooled_output, _ = \
self.bert(input_ids, token_type_id, input_mask)
cls = self.cast(pooled_output, self.dtype)
cls = self.dropout(cls)
logits = self.dense_1(cls)
logits = self.cast(logits, self.dtype)
log_probs = self.log_softmax(logits)
return log_probs
class BertNERModel(nn.Cell):
"""
This class is responsible for sequence labeling task evaluation, i.e. NER(num_labels=11).
The returned output represents the final logits as the results of log_softmax is propotional to that of softmax.
"""
def __init__(self, config, is_training, num_labels=11, use_crf=False, dropout_prob=0.0,
use_one_hot_embeddings=False):
super(BertNERModel, self).__init__()
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
self.cast = P.Cast()
self.weight_init = TruncatedNormal(config.initializer_range)
self.log_softmax = P.LogSoftmax(axis=-1)
self.dtype = config.dtype
self.num_labels = num_labels
self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
has_bias=True).to_float(config.compute_type)
self.dropout = nn.Dropout(1 - dropout_prob)
self.reshape = P.Reshape()
self.shape = (-1, config.hidden_size)
self.use_crf = use_crf
self.origin_shape = (config.batch_size, config.seq_length, self.num_labels)
def construct(self, input_ids, input_mask, token_type_id):
sequence_output, _, _ = \
self.bert(input_ids, token_type_id, input_mask)
seq = self.dropout(sequence_output)
seq = self.reshape(seq, self.shape)
logits = self.dense_1(seq)
logits = self.cast(logits, self.dtype)
if self.use_crf:
return_value = self.reshape(logits, self.origin_shape)
else:
return_value = self.log_softmax(logits)
return return_value
class CrossEntropyCalculation(nn.Cell):
"""
Cross Entropy loss
"""
def __init__(self, is_training=True):
super(CrossEntropyCalculation, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.reduce_sum = P.ReduceSum()
self.reduce_mean = P.ReduceMean()
self.reshape = P.Reshape()
self.last_idx = (-1,)
self.neg = P.Neg()
self.cast = P.Cast()
self.is_training = is_training
def construct(self, logits, label_ids, num_labels):
if self.is_training:
label_ids = self.reshape(label_ids, self.last_idx)
one_hot_labels = self.onehot(label_ids, num_labels, self.on_value, self.off_value)
per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx))
loss = self.reduce_mean(per_example_loss, self.last_idx)
return_value = self.cast(loss, mstype.float32)
else:
return_value = logits * 1.0
return return_value
class BertCLS(nn.Cell):
"""
Train interface for classification finetuning task.
"""
def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
super(BertCLS, self).__init__()
self.bert = BertCLSModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings)
self.loss = CrossEntropyCalculation(is_training)
self.num_labels = num_labels
def construct(self, input_ids, input_mask, token_type_id, label_ids):
log_probs = self.bert(input_ids, input_mask, token_type_id)
loss = self.loss(log_probs, label_ids, self.num_labels)
return loss
class BertNER(nn.Cell):
"""
Train interface for sequence labeling finetuning task.
"""
def __init__(self, config, is_training, num_labels=11, use_crf=False, tag_to_index=None, dropout_prob=0.0,
use_one_hot_embeddings=False):
super(BertNER, self).__init__()
self.bert = BertNERModel(config, is_training, num_labels, use_crf, dropout_prob, use_one_hot_embeddings)
if use_crf:
if not tag_to_index:
raise Exception("The dict for tag-index mapping should be provided for CRF.")
self.loss = CRF(tag_to_index, config.batch_size, config.seq_length, is_training)
else:
self.loss = CrossEntropyCalculation(is_training)
self.num_labels = num_labels
self.use_crf = use_crf
def construct(self, input_ids, input_mask, token_type_id, label_ids):
logits = self.bert(input_ids, input_mask, token_type_id)
if self.use_crf:
loss = self.loss(logits, label_ids)
else:
loss = self.loss(logits, label_ids, self.num_labels)
return loss
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册