Unverified commit 03d651b4 authored by kinghuin, committed by GitHub

Optimize BigruCRF example (#5017)

* optimize lac

* formatted

* optimize lac

* optimize lac
Parent ad4720ec
......@@ -25,8 +25,8 @@
We provide a small number of samples to illustrate the input data format. Run the following commands to download and extract the sample dataset:
```bash
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/lexical_analysis-dataset-2.0.0.tar.gz
tar xvf lexical_analysis-dataset-2.0.0.tar.gz
wget --no-check-certificate https://paddlenlp.bj.bcebos.com/data/lexical_analysis_dataset_tiny.tar.gz
tar xvf lexical_analysis_dataset_tiny.tar.gz
```
Users can organize their own training data according to their actual application scenario. Apart from the fixed `text_a\tlabel` header on the first line, each subsequent line consists of two tab-separated columns: the first column is UTF-8-encoded Chinese text with characters separated by `\002`, and the second column is the label for each character, also separated by `\002`. We use the IOB2 tagging scheme, where X-B marks the first character of a word of type X, X-I marks a continuation character of a word of type X, and O marks characters that are not of interest (in practice, O does not appear in joint POS and named-entity tagging). An example follows:
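For illustration, here is a minimal Python sketch that assembles one line in this format; the sentence and tags below are made-up sample values (chosen to match tags that appear in tag.dic), not taken from the dataset:
```python
# Illustrative only: build one training line in the format described above.
CHAR_DELIMITER = "\002"  # the same character separator used by this example's data.py

chars = ["百", "度", "是", "一", "家", "公", "司"]                  # hypothetical sentence
tags = ["ORG-B", "ORG-I", "v-B", "m-B", "q-B", "n-B", "n-I"]       # one tag per character
assert len(chars) == len(tags)

header = "text_a\tlabel"  # fixed first line of the data file
line = CHAR_DELIMITER.join(chars) + "\t" + CHAR_DELIMITER.join(tags)
print(header)
print(line)
```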
......@@ -59,27 +59,21 @@ export CUDA_VISIBLE_DEVICES=0,1 # multi-GPU training is supported
```bash
python -m paddle.distributed.launch train.py \
--base_path ./data \
--word_dict_path ./conf/word.dic \
--label_dict_path ./conf/tag.dic \
--word_rep_dict_path ./conf/q2b.dic \
--root ./lexical_analysis_dataset_tiny \
--model_save_dir ./save_dir \
--epochs 10 \
--batch_size 32 \
--use_gpu True
```
Here base_path is the folder containing the dataset, word_dict_path is the path of the word dictionary for the input text, label_dict_path is the path of the label dictionary, and word_rep_dict_path is the path of the dictionary used to normalize special characters in the input text.
Here root is the folder containing the dataset.
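As a quick sanity check, the sketch below (based on the new LacDataset.__init__ in this change, which loads word.dic, tag.dic and q2b.dic from the dataset folder) verifies that the extracted directory contains the expected dictionary files; the names of the train/test split files are not shown in this diff, so they are not checked:
```python
import os

# Confirm the dictionaries that LacDataset expects are present in the dataset folder.
root = "./lexical_analysis_dataset_tiny"
for name in ("word.dic", "tag.dic", "q2b.dic"):
    path = os.path.join(root, name)
    print(path, "found" if os.path.exists(path) else "MISSING")
```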
### 2.4 Model Evaluation
You can evaluate on the test set by loading a model checkpoint saved during training. Launch evaluation as follows:
```bash
python eval.py --base_path ./data \
--word_dict_path ./conf/word.dic \
--label_dict_path ./conf/tag.dic \
--word_rep_dict_path ./conf/q2b.dic \
python eval.py --root ./lexical_analysis_dataset_tiny \
--init_checkpoint ./save_dir/final \
--batch_size 32 \
--use_gpu True
......@@ -92,23 +86,12 @@ python eval.py --base_path ./data \
You can run model prediction on unlabeled data:
```bash
python predict.py --base_path ./data \
--word_dict_path ./conf/word.dic \
--label_dict_path ./conf/tag.dic \
--word_rep_dict_path ./conf/q2b.dic \
python predict.py --root ./lexical_analysis_dataset_tiny \
--init_checkpoint ./save_dir/final \
--batch_size 32 \
--use_gpu True
```
### Pretrained Models
We provide a model pretrained on a large-scale dataset:
| Model | Precision | Recall | F1-score |
| :--------------: | :-------: | :----: | :------: |
| [BiGRU + CRF](link) | 89.2% | 89.4% | 89.3% |
### How to Contribute Code
......
 
、 ,
。 .
— -
~ ~
‖ |
… .
‘ '
’ '
“ "
” "
〔 (
〕 )
〈 <
〉 >
「 '
」 '
『 "
』 "
〖 [
〗 ]
【 [
】 ]
∶ :
$ $
! !
" "
# #
% %
& &
' '
( (
) )
* *
+ +
, ,
- -
. .
/ /
0 0
1 1
2 2
3 3
4 4
5 5
6 6
7 7
8 8
9 9
: :
; ;
< <
= =
> >
? ?
@ @
A a
B b
C c
D d
E e
F f
G g
H h
I i
J j
K k
L l
M m
N n
O o
P p
Q q
R r
S s
T t
U u
V v
W w
X x
Y y
Z z
[ [
\ \
] ]
^ ^
_ _
` `
a a
b b
c c
d d
e e
f f
g g
h h
i i
j j
k k
l l
m m
n n
o o
p p
q q
r r
s s
t t
u u
v v
w w
x x
y y
z z
{ {
| |
} }
 ̄ ~
〝 "
〞 "
﹐ ,
﹑ ,
﹒ .
﹔ ;
﹕ :
﹖ ?
﹗ !
﹙ (
﹚ )
﹛ {
﹜ }
﹝ [
﹞ ]
﹟ #
﹠ &
﹡ *
﹢ +
﹣ -
﹤ <
﹥ >
﹦ =
﹨ \
﹩ $
﹪ %
﹫ @
,
A a
B b
C c
D d
E e
F f
G g
H h
I i
J j
K k
L l
M m
N n
O o
P p
Q q
R r
S s
T t
U u
V v
W w
X x
Y y
Z z
0 a-B
1 a-I
2 ad-B
3 ad-I
4 an-B
5 an-I
6 c-B
7 c-I
8 d-B
9 d-I
10 f-B
11 f-I
12 m-B
13 m-I
14 n-B
15 n-I
16 nr-B
17 nr-I
18 ns-B
19 ns-I
20 nt-B
21 nt-I
22 nw-B
23 nw-I
24 nz-B
25 nz-I
26 p-B
27 p-I
28 q-B
29 q-I
30 r-B
31 r-I
32 s-B
33 s-I
34 t-B
35 t-I
36 u-B
37 u-I
38 v-B
39 v-I
40 vd-B
41 vd-I
42 vn-B
43 vn-I
44 w-B
45 w-I
46 xc-B
47 xc-I
48 PER-B
49 PER-I
50 LOC-B
51 LOC-I
52 ORG-B
53 ORG-I
54 TIME-B
55 TIME-I
56 O
......@@ -26,118 +26,6 @@ import numpy as np
CHAR_DELIMITER = "\002"
def load_kv_dict(dict_path,
delimiter="\t",
key_func=None,
value_func=None,
reverse=False):
"""
Load key-value dict from file
"""
vocab = {}
for line in open(dict_path, "r", encoding='utf8'):
terms = line.strip("\n").split(delimiter)
if len(terms) != 2:
continue
if reverse:
value, key = terms
else:
key, value = terms
if key in vocab:
raise KeyError("key duplicated with [%s]" % (key))
if key_func:
key = key_func(key)
if value_func:
value = value_func(value)
vocab[key] = value
return vocab
def convert_tokens_to_ids(tokens, vocab, oov_replace=None, token_replace=None):
"""convert tokens to token indexs"""
token_ids = []
oov_replace_token = vocab.get(oov_replace) if oov_replace else None
for token in tokens:
if token_replace:
token = token_replace.get(token, token)
token_id = vocab.get(token, oov_replace_token)
token_ids.append(token_id)
return token_ids
def batch_padding_fn(max_seq_len):
def pad_batch_to_max_seq_len(batch):
batch_max_seq_len = min(
max([len(sample[0]) for sample in batch]), max_seq_len)
batch_word_ids = []
batch_label_ids = []
batch_lens = []
for i, sample in enumerate(batch):
sample_word_ids = sample[0][:batch_max_seq_len]
sample_words_len = len(sample_word_ids)
sample_word_ids += [
0 for _ in range(batch_max_seq_len - sample_words_len)
]
batch_word_ids.append(sample_word_ids)
if len(sample) == 2:
sampel_label_ids = sample[1][:batch_max_seq_len] + [
0 for _ in range(batch_max_seq_len - sample_words_len)
]
batch_label_ids.append(sampel_label_ids)
batch_lens.append(np.int64(sample_words_len))
if batch_label_ids:
return batch_word_ids, batch_lens, batch_label_ids
else:
return batch_word_ids, batch_lens
return pad_batch_to_max_seq_len
def parse_lac_result(words, preds, lengths, word_vocab, label_vocab):
""" parse padding result """
batch_out = []
id2word_dict = dict(zip(word_vocab.values(), word_vocab.keys()))
id2label_dict = dict(zip(label_vocab.values(), label_vocab.keys()))
for sent_index in range(len(lengths)):
sent = [
id2word_dict[index]
for index in words[sent_index][:lengths[sent_index] - 1]
]
tags = [
id2label_dict[index]
for index in preds[sent_index][:lengths[sent_index] - 1]
]
sent_out = []
tags_out = []
parital_word = ""
for ind, tag in enumerate(tags):
# for the first word
if parital_word == "":
parital_word = sent[ind]
tags_out.append(tag.split('-')[0])
continue
# for the beginning of word
if tag.endswith("-B") or (tag == "O" and tags[ind - 1] != "O"):
sent_out.append(parital_word)
tags_out.append(tag.split('-')[0])
parital_word = sent[ind]
continue
parital_word += sent[ind]
# append the last word, except for len(tags)=0
if len(sent_out) < len(tags_out):
sent_out.append(parital_word)
batch_out.append([sent_out, tags_out])
return batch_out
class LacDataset(paddle.io.Dataset):
"""Load the dataset and convert all the texts to ids.
......@@ -149,17 +37,17 @@ class LacDataset(paddle.io.Dataset):
mode (str, optional): The load mode, "train", "test" or "infer". Defaults to 'train', meaning load the train dataset.
"""
def __init__(self,
base_path,
word_vocab,
label_vocab,
word_replace_dict,
mode='train'):
def __init__(self, base_path, mode='train'):
self.mode = mode
self.base_path = base_path
self.word_vocab = word_vocab
self.label_vocab = label_vocab
self.word_replace_dict = word_replace_dict
word_dict_path = os.path.join(self.base_path, 'word.dic')
label_dict_path = os.path.join(self.base_path, 'tag.dic')
word_rep_dict_path = os.path.join(self.base_path, 'q2b.dic')
self.word_vocab = self._load_kv_dict(
word_dict_path, value_func=np.int64, reverse=True)
self.label_vocab = self._load_kv_dict(
label_dict_path, value_func=np.int64, reverse=True)
self.word_replace_dict = self._load_kv_dict(word_rep_dict_path)
# Calculate the vocab size and the number of labels; note: vocab values start from 0.
self.vocab_size = max(self.word_vocab.values()) + 1
......@@ -179,9 +67,12 @@ class LacDataset(paddle.io.Dataset):
def __getitem__(self, index):
if self.mode == "infer":
return [self.word_ids[index]]
return [self.word_ids[index], len(self.word_ids[index])]
else:
return [self.word_ids[index], self.label_ids[index]]
return [
self.word_ids[index], len(self.word_ids[index]),
self.label_ids[index]
]
def _read_file(self):
self.word_ids = []
......@@ -198,7 +89,7 @@ class LacDataset(paddle.io.Dataset):
words, labels = line.split("\t")
words = words.split(CHAR_DELIMITER)
tmp_word_ids = convert_tokens_to_ids(
tmp_word_ids = self._convert_tokens_to_ids(
words,
self.word_vocab,
oov_replace="OOV",
......@@ -206,7 +97,7 @@ class LacDataset(paddle.io.Dataset):
self.word_ids.append(tmp_word_ids)
if self.mode != "infer":
tmp_label_ids = convert_tokens_to_ids(
tmp_label_ids = self._convert_tokens_to_ids(
labels.split(CHAR_DELIMITER),
self.label_vocab,
oov_replace="O")
......@@ -217,3 +108,88 @@ class LacDataset(paddle.io.Dataset):
tmp_word_ids, tmp_label_ids)
self.total += 1
def _load_kv_dict(self,
dict_path,
delimiter="\t",
key_func=None,
value_func=None,
reverse=False):
"""
Load key-value dict from file
"""
vocab = {}
for line in open(dict_path, "r", encoding='utf8'):
terms = line.strip("\n").split(delimiter)
if len(terms) != 2:
continue
if reverse:
value, key = terms
else:
key, value = terms
if key in vocab:
raise KeyError("key duplicated with [%s]" % (key))
if key_func:
key = key_func(key)
if value_func:
value = value_func(value)
vocab[key] = value
return vocab
def _convert_tokens_to_ids(self,
tokens,
vocab,
oov_replace=None,
token_replace=None):
"""convert tokens to token indexs"""
token_ids = []
oov_replace_token = vocab.get(oov_replace) if oov_replace else None
for token in tokens:
if token_replace:
token = token_replace.get(token, token)
token_id = vocab.get(token, oov_replace_token)
token_ids.append(token_id)
return token_ids
def parse_lac_result(words, preds, lengths, word_vocab, label_vocab):
""" parse padding result """
batch_out = []
id2word_dict = dict(zip(word_vocab.values(), word_vocab.keys()))
id2label_dict = dict(zip(label_vocab.values(), label_vocab.keys()))
for sent_index in range(len(lengths)):
sent = [
id2word_dict[index]
for index in words[sent_index][:lengths[sent_index] - 1]
]
tags = [
id2label_dict[index]
for index in preds[sent_index][:lengths[sent_index] - 1]
]
sent_out = []
tags_out = []
parital_word = ""
for ind, tag in enumerate(tags):
# for the first word
if parital_word == "":
parital_word = sent[ind]
tags_out.append(tag.split('-')[0])
continue
# for the beginning of word
if tag.endswith("-B") or (tag == "O" and tags[ind - 1] != "O"):
sent_out.append(parital_word)
tags_out.append(tag.split('-')[0])
parital_word = sent[ind]
continue
parital_word += sent[ind]
# append the last word, except for len(tags)=0
if len(sent_out) < len(tags_out):
sent_out.append(parital_word)
batch_out.append([sent_out, tags_out])
return batch_out
......@@ -20,33 +20,21 @@ import argparse
import paddle
import numpy as np
from paddlenlp.data import Pad, Tuple, Stack
from paddlenlp.metrics import ChunkEvaluator
from data import load_kv_dict, batch_padding_fn, LacDataset
from model import BiGruCrf, ViterbiDecoder, ChunkEvaluator
from data import LacDataset
from model import BiGruCrf
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--base_path", type=str, default=None,
help="The folder where the dataset is located.")
parser.add_argument("--word_dict_path", type=str, default=None,
help="The path of the word dictionary.")
parser.add_argument("--label_dict_path", type=str, default=None,
help="The path of the label dictionary.")
parser.add_argument("--word_rep_dict_path", type=str, default=None,
help="The path of the word replacement Dictionary")
parser.add_argument("--init_checkpoint", type=str, default=None,
help="Path to init model.")
parser.add_argument("--batch_size", type=int, default=300,
help="The number of sequences contained in a mini-batch.")
parser.add_argument("--max_seq_len", type=int, default=64,
help="Number of words of the longest seqence.")
parser.add_argument("--use_gpu", type=ast.literal_eval,
default=True, help="If set, use GPU for training.")
parser.add_argument("--emb_dim", type=int,
default=128,
help="The dimension in which a word is embedded.")
parser.add_argument("--hidden_size", type=int, default=128,
help="The number of hidden nodes in the GRU layer.")
parser.add_argument("--root", type=str, default=None, help="The folder where the dataset is located.")
parser.add_argument("--init_checkpoint", type=str, default=None, help="Path to init model.")
parser.add_argument("--batch_size", type=int, default=300, help="The number of sequences contained in a mini-batch.")
parser.add_argument("--max_seq_len", type=int, default=64, help="Number of words of the longest seqence.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="If set, use GPU for training.")
parser.add_argument("--emb_dim", type=int, default=128, help="The dimension in which a word is embedded.")
parser.add_argument("--hidden_size", type=int, default=128, help="The number of hidden nodes in the GRU layer.")
args = parser.parse_args()
# yapf: enable
......@@ -55,14 +43,13 @@ def evaluate(args):
place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
paddle.set_device("gpu" if args.use_gpu else "cpu")
# Load vocab to create dataset.
word_vocab = load_kv_dict(
args.word_dict_path, value_func=np.int64, reverse=True)
label_vocab = load_kv_dict(
args.label_dict_path, value_func=np.int64, reverse=True)
word_rep_dict = load_kv_dict(args.word_rep_dict_path)
test_dataset = LacDataset(
args.base_path, word_vocab, label_vocab, word_rep_dict, mode='test')
# create dataset.
test_dataset = LacDataset(args.root, mode='test')
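# Note: paddlenlp.data.Tuple applies the i-th batchify function to the i-th field of
# every sample in the batch. Here it pads word_ids and label_ids to the longest
# sequence in the batch with pad_val=0 and stacks the sequence lengths into one array.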
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=0), # word_ids
Stack(), # length
Pad(axis=0, pad_val=0), # label_ids
): fn(samples)
# Create sampler for dataloader
test_sampler = paddle.io.BatchSampler(
......@@ -75,7 +62,7 @@ def evaluate(args):
batch_sampler=test_sampler,
places=place,
return_list=True,
collate_fn=batch_padding_fn(args.max_seq_len))
collate_fn=batchify_fn)
# Define the model network and metric evaluator
network = BiGruCrf(args.emb_dim, args.hidden_size, test_dataset.vocab_size,
......
......@@ -88,107 +88,3 @@ class BiGruCrf(nn.Layer):
emission = self.fc(bigru_output)
_, prediction = self.viterbi_decoder(emission, lengths)
return emission, lengths, prediction
class ChunkEvaluator(paddle.metric.Metric):
"""ChunkEvaluator computes the precision, recall and F1-score for chunk detection.
It is often used in sequence tagging tasks, such as Named Entity Recognition (NER).
Args:
num_chunk_types (int): The number of chunk types.
chunk_scheme (str): Indicates the tagging scheme used here. The value must
be IOB, IOE, IOBES or plain.
excluded_chunk_types (list, optional): Indicates the chunk types that shouldn't
be taken into account. It should be a list of chunk type ids (integer).
Default None.
"""
def __init__(self, num_chunk_types, chunk_scheme,
excluded_chunk_types=None):
super(ChunkEvaluator, self).__init__()
self.num_chunk_types = num_chunk_types
self.chunk_scheme = chunk_scheme
self.excluded_chunk_types = excluded_chunk_types
self.num_infer_chunks = 0
self.num_label_chunks = 0
self.num_correct_chunks = 0
def compute(self, inputs, lengths, predictions, labels):
precision, recall, f1_score, num_infer_chunks, num_label_chunks, num_correct_chunks = paddle.metric.chunk_eval(
predictions,
labels,
chunk_scheme=self.chunk_scheme,
num_chunk_types=self.num_chunk_types,
excluded_chunk_types=self.excluded_chunk_types,
seq_length=lengths)
return num_infer_chunks, num_label_chunks, num_correct_chunks
def _is_number_or_matrix(self, var):
def _is_number_(var):
return isinstance(
var, int) or isinstance(var, np.int64) or isinstance(
var, float) or (isinstance(var, np.ndarray) and
var.shape == (1, ))
return _is_number_(var) or isinstance(var, np.ndarray)
def update(self, num_infer_chunks, num_label_chunks, num_correct_chunks):
"""
This function takes (num_infer_chunks, num_label_chunks, num_correct_chunks) as input,
to accumulate and update the corresponding status of the ChunkEvaluator object. The update method is as follows:
.. math::
\\\\ \\begin{array}{l}{\\text { self. num_infer_chunks }+=\\text { num_infer_chunks }} \\\\ {\\text { self. num_label_chunks }+=\\text { num_label_chunks }} \\\\ {\\text { self. num_correct_chunks }+=\\text { num_correct_chunks }}\\end{array} \\\\
Args:
num_infer_chunks(int|numpy.array): The number of chunks in Inference on the given minibatch.
num_label_chunks(int|numpy.array): The number of chunks in Label on the given mini-batch.
num_correct_chunks(int|float|numpy.array): The number of chunks both in Inference and Label on the
given mini-batch.
"""
if not self._is_number_or_matrix(num_infer_chunks):
raise ValueError(
"The 'num_infer_chunks' must be a number(int) or a numpy ndarray."
)
if not self._is_number_or_matrix(num_label_chunks):
raise ValueError(
"The 'num_label_chunks' must be a number(int, float) or a numpy ndarray."
)
if not self._is_number_or_matrix(num_correct_chunks):
raise ValueError(
"The 'num_correct_chunks' must be a number(int, float) or a numpy ndarray."
)
self.num_infer_chunks += num_infer_chunks
self.num_label_chunks += num_label_chunks
self.num_correct_chunks += num_correct_chunks
def accumulate(self):
"""
This function returns the mean precision, recall and f1 score for all accumulated minibatches.
Returns:
float: mean precision, recall and f1 score.
"""
precision = float(
self.num_correct_chunks
) / self.num_infer_chunks if self.num_infer_chunks else 0
recall = float(self.num_correct_chunks
) / self.num_label_chunks if self.num_label_chunks else 0
f1_score = float(2 * precision * recall) / (
precision + recall) if self.num_correct_chunks else 0
return precision, recall, f1_score
def reset(self):
"""
Reset function empties the evaluation memory for previous mini-batches.
"""
self.num_infer_chunks = 0
self.num_label_chunks = 0
self.num_correct_chunks = 0
def name(self):
"""
Return name of metric instance.
"""
return "precision", "recall", "f1"
......@@ -19,16 +19,15 @@ import argparse
import numpy as np
import paddle
from paddlenlp.data import Pad, Tuple, Stack
from paddlenlp.metrics import ChunkEvaluator
from data import load_kv_dict, batch_padding_fn, LacDataset, parse_lac_result
from model import BiGruCrf, ViterbiDecoder, ChunkEvaluator
from data import LacDataset, parse_lac_result
from model import BiGruCrf
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--base_path", type=str, default=None, help="The folder where the dataset is located.")
parser.add_argument("--word_dict_path", type=str, default=None, help="The path of the word dictionary.")
parser.add_argument("--label_dict_path", type=str, default=None, help="The path of the label dictionary.")
parser.add_argument("--word_rep_dict_path", type=str, default=None, help="The path of the word replacement Dictionary")
parser.add_argument("--root", type=str, default=None, help="The folder where the dataset is located.")
parser.add_argument("--init_checkpoint", type=str, default=None, help="Path to init model.")
parser.add_argument("--batch_size", type=int, default=300, help="The number of sequences contained in a mini-batch.")
parser.add_argument("--max_seq_len", type=int, default=64, help="Number of words of the longest seqence.")
......@@ -43,14 +42,13 @@ def infer(args):
place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
paddle.set_device("gpu" if args.use_gpu else "cpu")
# Load vocab to create dataset.
word_vocab = load_kv_dict(
args.word_dict_path, value_func=np.int64, reverse=True)
label_vocab = load_kv_dict(
args.label_dict_path, value_func=np.int64, reverse=True)
word_rep_dict = load_kv_dict(args.word_rep_dict_path)
infer_dataset = LacDataset(
args.base_path, word_vocab, label_vocab, word_rep_dict, mode='infer')
# create dataset.
infer_dataset = LacDataset(args.root, mode='infer')
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=0), # word_ids
Stack(), # length
): fn(samples)
# Create sampler for dataloader
infer_sampler = paddle.io.BatchSampler(
......@@ -63,7 +61,7 @@ def infer(args):
batch_sampler=infer_sampler,
places=place,
return_list=True,
collate_fn=batch_padding_fn(args.max_seq_len))
collate_fn=batchify_fn)
# Define the model network
network = BiGruCrf(args.emb_dim, args.hidden_size, infer_dataset.vocab_size,
......@@ -82,7 +80,8 @@ def infer(args):
[pred for batch_pred in crf_decodes for pred in batch_pred])
results = parse_lac_result(infer_dataset.word_ids, preds, lengths,
word_vocab, label_vocab)
infer_dataset.word_vocab,
infer_dataset.label_vocab)
sent_tags = []
for sent, tags in results:
......
......@@ -20,17 +20,15 @@ import argparse
import numpy as np
import paddle
from data import load_kv_dict, batch_padding_fn, LacDataset
from data import LacDataset
from model import BiGruCrf
from paddlenlp.data import Pad, Tuple, Stack
from paddlenlp.layers.crf import LinearChainCrfLoss, ViterbiDecoder
from paddlenlp.metrics.chunk_evaluator import ChunkEvaluator
from paddlenlp.metrics import ChunkEvaluator
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--base_path", type=str, default=None, help="The folder where the dataset is located.")
parser.add_argument("--word_dict_path", type=str, default=None, help="The path of the word dictionary.")
parser.add_argument("--label_dict_path", type=str, default=None, help="The path of the label dictionary.")
parser.add_argument("--word_rep_dict_path", type=str, default=None, help="The path of the word replacement Dictionary")
parser.add_argument("--root", type=str, default=None, help="The folder where the dataset is located.")
parser.add_argument("--init_checkpoint", type=str, default=None, help="Path to init model.")
parser.add_argument("--model_save_dir", type=str, default=None, help="The model will be saved in this path.")
parser.add_argument("--epochs", type=int, default=10, help="Corpus iteration num.")
......@@ -52,16 +50,15 @@ def train(args):
place = paddle.CPUPlace()
paddle.set_device("cpu")
# Load vocab to create dataset.
word_vocab = load_kv_dict(
args.word_dict_path, value_func=np.int64, reverse=True)
label_vocab = load_kv_dict(
args.label_dict_path, value_func=np.int64, reverse=True)
word_rep_dict = load_kv_dict(args.word_rep_dict_path)
train_dataset = LacDataset(
args.base_path, word_vocab, label_vocab, word_rep_dict, mode='train')
test_dataset = LacDataset(
args.base_path, word_vocab, label_vocab, word_rep_dict, mode='test')
# create dataset.
train_dataset = LacDataset(args.root, mode='train')
test_dataset = LacDataset(args.root, mode='test')
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=0), # word_ids
Stack(), # length
Pad(axis=0, pad_val=0), # label_ids
): fn(samples)
# Create sampler for dataloader
train_sampler = paddle.io.DistributedBatchSampler(
......@@ -74,7 +71,7 @@ def train(args):
batch_sampler=train_sampler,
places=place,
return_list=True,
collate_fn=batch_padding_fn(args.max_seq_len))
collate_fn=batchify_fn)
test_sampler = paddle.io.BatchSampler(
dataset=test_dataset,
......@@ -86,7 +83,7 @@ def train(args):
batch_sampler=test_sampler,
places=place,
return_list=True,
collate_fn=batch_padding_fn(args.max_seq_len))
collate_fn=batchify_fn)
# Define the model network and its loss
network = BiGruCrf(args.emb_dim, args.hidden_size, train_dataset.vocab_size,
......@@ -105,17 +102,18 @@ def train(args):
model.load(args.init_checkpoint)
# Start training
callback = paddle.callbacks.ProgBarLogger(log_freq=10, verbose=3)
model.fit(train_data=train_loader,
eval_data=test_loader,
batch_size=args.batch_size,
epochs=args.epochs,
eval_freq=1,
log_freq=1,
log_freq=10,
save_dir=args.model_save_dir,
save_freq=1,
verbose=2,
drop_last=True,
shuffle=True)
shuffle=True,
callbacks=callback)
if __name__ == "__main__":
......
......@@ -113,12 +113,6 @@ def parse_args():
return args
def set_seed(args):
random.seed(args.seed + paddle.distributed.get_rank())
np.random.seed(args.seed + paddle.distributed.get_rank())
paddle.seed(args.seed + paddle.distributed.get_rank())
def evaluate(model, loss_fct, metric, data_loader, label_num):
model.eval()
metric.reset()
......@@ -251,8 +245,6 @@ def do_train(args):
if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env()
set_seed(args)
train_dataset, dev_dataset = ppnlp.datasets.MSRA_NER.get_datasets(
["train", "dev"])
tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
......
......@@ -49,7 +49,6 @@ class LinearChainCrf(nn.Layer):
dtype='float32')
self.with_start_stop_tag = with_start_stop_tag
self.max_seq_len = 0
self._initial_alpha = None
self._start_tensor = None
self._stop_tensor = None
......@@ -226,16 +225,14 @@ class LinearChainCrf(nn.Layer):
return self._batch_index
def _get_seq_index(self, length):
if self._seq_index is None or length > self.max_seq_len:
self.max_seq_len = length
self._seq_index = paddle.arange(end=self.max_seq_len, dtype="int64")
if self._seq_index is None or length > self._seq_index.shape[0]:
self._seq_index = paddle.arange(end=length, dtype="int64")
return self._seq_index[:length]
def _get_batch_seq_index(self, batch_size, length):
if self._batch_seq_index is None or length > self.max_seq_len:
self.max_seq_len = length
if self._batch_seq_index is None or length > self._batch_seq_index.shape[1]:
self._batch_seq_index = paddle.cumsum(
paddle.ones([batch_size, self.max_seq_len + 2], "int64"),
paddle.ones([batch_size, length + 2], "int64"),
axis=1) - 1
if self.with_start_stop_tag:
return self._batch_seq_index[:, :length + 2]
......
......@@ -239,7 +239,7 @@ class BertPretrainedModel(PretrainedModel):
"bert-base-chinese":
"http://paddlenlp.bj.bcebos.com/models/transformers/bert/bert-base-chinese.pdparams",
"bert-base-multilingual-cased":
"http://paddlenlp.bj.bcebos.com/models/transformers/bert/bert-base-multilingual-cased.pdparamss",
"http://paddlenlp.bj.bcebos.com/models/transformers/bert/bert-base-multilingual-cased.pdparams",
"bert-large-cased":
"http://paddlenlp.bj.bcebos.com/models/transformers/bert/bert-large-cased.pdparams",
"bert-wwm-chinese":
......