Unverified commit 03d651b4, authored by kinghuin, committed by GitHub

Optimize BigruCRF example (#5017)

* optimize lac

* formatted

* optimize lac

* optimize lac
Parent commit: ad4720ec
@@ -25,8 +25,8 @@
 We provide a few samples to illustrate the input data format. Run the following commands to download and extract the sample dataset:
 ```bash
-wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/lexical_analysis-dataset-2.0.0.tar.gz
-tar xvf lexical_analysis-dataset-2.0.0.tar.gz
+wget --no-check-certificate https://paddlenlp.bj.bcebos.com/data/lexical_analysis_dataset_tiny.tar.gz
+tar xvf lexical_analysis_dataset_tiny.tar.gz
 ```
 Users can organize their own training data according to the actual application scenario. Apart from the fixed `text_a\tlabel` header on the first line, every following line consists of two tab-separated columns: the first column is UTF-8 encoded Chinese text with characters separated by `\002`, and the second column is the label for each character, also separated by `\002`. We use the IOB2 tagging scheme: X-B marks the first character of a word of type X, X-I marks its continuation, and O marks characters of no interest (in practice, O does not occur in the joint POS and named-entity annotation). For example:
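The dataset's own example lines are truncated from this hunk. As a purely illustrative sketch (the sentence "今天很好" and its tags are invented, not taken from the dataset), the snippet below builds one line in the `text_a\tlabel` format described above and splits it back apart:

```python
# Hypothetical sample line in the text_a\tlabel format; "\002" separates the
# characters of the text and the per-character IOB2 tags.
sample = "text_a\tlabel\n" \
         "今\002天\002很\002好\tTIME-B\002TIME-I\002d-B\002a-B\n"

for line in sample.splitlines()[1:]:   # skip the fixed header line
    text, label = line.split("\t")
    chars = text.split("\002")         # one UTF-8 character per item
    tags = label.split("\002")         # one tag per character
    print(list(zip(chars, tags)))
# [('今', 'TIME-B'), ('天', 'TIME-I'), ('很', 'd-B'), ('好', 'a-B')]
```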
@@ -59,27 +59,21 @@ export CUDA_VISIBLE_DEVICES=0,1  # multi-GPU training is supported
 ```bash
 python -m paddle.distributed.launch train.py \
-        --base_path ./data \
-        --word_dict_path ./conf/word.dic \
-        --label_dict_path ./conf/tag.dic \
-        --word_rep_dict_path ./conf/q2b.dic \
+        --root ./lexical_analysis_dataset_tiny \
         --model_save_dir ./save_dir \
         --epochs 10 \
         --batch_size 32 \
         --use_gpu True
 ```
-Here `base_path` is the folder containing the dataset, `word_dict_path` is the vocabulary of the input text, `label_dict_path` is the label dictionary, and `word_rep_dict_path` is the dictionary used to normalize special characters in the input text.
+Here `root` is the folder containing the dataset.

 ### 2.4 Model evaluation

 Load the model saved during training to evaluate it on the test set:
 ```bash
-python eval.py --base_path ./data \
-        --word_dict_path ./conf/word.dic \
-        --label_dict_path ./conf/tag.dic \
-        --word_rep_dict_path ./conf/q2b.dic \
+python eval.py --root ./lexical_analysis_dataset_tiny \
         --init_checkpoint ./save_dir/final \
         --batch_size 32 \
         --use_gpu True
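For orientation, the `--root` folder is expected to contain the dictionary files that data.py (further down in this commit) loads by name; the names of the train/test/infer split files are not visible in this diff, so the layout below is only a hedged sketch:

```text
lexical_analysis_dataset_tiny/
├── word.dic   # character vocabulary
├── tag.dic    # label dictionary
├── q2b.dic    # full-width to half-width character normalization
└── ...        # train / test / infer text files (names not shown in this diff)
```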
@@ -92,23 +86,12 @@ python eval.py --base_path ./data \
 Prediction can be run on unlabeled data:
 ```bash
-python predict.py --base_path ./data \
-        --word_dict_path ./conf/word.dic \
-        --label_dict_path ./conf/tag.dic \
-        --word_rep_dict_path ./conf/q2b.dic \
+python predict.py --root ./lexical_analysis_dataset_tiny \
         --init_checkpoint ./save_dir/final \
         --batch_size 32 \
         --use_gpu True
 ```
-
-### Pretrained model
-
-We provide a model pretrained on a large-scale dataset:
-
-| Model | Precision | Recall | F1-score |
-| :--------------: | :-------: | :----: | :------: |
-| [BiGRU + CRF](link) | 89.2% | 89.4% | 89.3% |

 ### How to contribute code
......
 
、 ,
。 .
— -
~ ~
‖ |
… .
‘ '
’ '
“ "
” "
〔 (
〕 )
〈 <
〉 >
「 '
」 '
『 "
』 "
〖 [
〗 ]
【 [
】 ]
∶ :
$ $
! !
" "
# #
% %
& &
' '
( (
) )
* *
+ +
, ,
- -
. .
/ /
0 0
1 1
2 2
3 3
4 4
5 5
6 6
7 7
8 8
9 9
: :
; ;
< <
= =
> >
? ?
@ @
A a
B b
C c
D d
E e
F f
G g
H h
I i
J j
K k
L l
M m
N n
O o
P p
Q q
R r
S s
T t
U u
V v
W w
X x
Y y
Z z
[ [
\ \
] ]
^ ^
_ _
` `
a a
b b
c c
d d
e e
f f
g g
h h
i i
j j
k k
l l
m m
n n
o o
p p
q q
r r
s s
t t
u u
v v
w w
x x
y y
z z
{ {
| |
} }
 ̄ ~
〝 "
〞 "
﹐ ,
﹑ ,
﹒ .
﹔ ;
﹕ :
﹖ ?
﹗ !
﹙ (
﹚ )
﹛ {
﹜ {
﹝ [
﹞ ]
﹟ #
﹠ &
﹡ *
﹢ +
﹣ -
﹤ <
﹥ >
﹦ =
﹨ \
﹩ $
﹪ %
﹫ @
,
A a
B b
C c
D d
E e
F f
G g
H h
I i
J j
K k
L l
M m
N n
O o
P p
Q q
R r
S s
T t
U u
V v
W w
X x
Y y
Z z
0 a-B
1 a-I
2 ad-B
3 ad-I
4 an-B
5 an-I
6 c-B
7 c-I
8 d-B
9 d-I
10 f-B
11 f-I
12 m-B
13 m-I
14 n-B
15 n-I
16 nr-B
17 nr-I
18 ns-B
19 ns-I
20 nt-B
21 nt-I
22 nw-B
23 nw-I
24 nz-B
25 nz-I
26 p-B
27 p-I
28 q-B
29 q-I
30 r-B
31 r-I
32 s-B
33 s-I
34 t-B
35 t-I
36 u-B
37 u-I
38 v-B
39 v-I
40 vd-B
41 vd-I
42 vn-B
43 vn-I
44 w-B
45 w-I
46 xc-B
47 xc-I
48 PER-B
49 PER-I
50 LOC-B
51 LOC-I
52 ORG-B
53 ORG-I
54 TIME-B
55 TIME-I
56 O
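The two lists above appear to be the contents of the character-normalization dictionary (q2b.dic) and the label dictionary (tag.dic) used by the scripts in this commit; the file names are inferred from the code, not shown on this page. Each file is a tab-separated `key<TAB>value` listing, and the `_load_kv_dict` helper that data.py now carries (shown below) turns it into a Python dict, optionally swapping key and value. A simplified re-implementation on two hypothetical tag.dic lines:

```python
import numpy as np

def load_kv_dict(lines, delimiter="\t", value_func=None, reverse=False):
    """Simplified version of the diff's _load_kv_dict, reading from a list of lines."""
    vocab = {}
    for line in lines:
        terms = line.rstrip("\n").split(delimiter)
        if len(terms) != 2:
            continue
        key, value = reversed(terms) if reverse else terms
        vocab[key] = value_func(value) if value_func else value
    return vocab

# tag.dic stores "id<TAB>label"; reverse=True flips it into label -> integer id.
print(load_kv_dict(["0\ta-B", "1\ta-I"], value_func=np.int64, reverse=True))
# -> {'a-B': 0, 'a-I': 1}  (values are np.int64)
```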
@@ -26,118 +26,6 @@ import numpy as np
 CHAR_DELIMITER = "\002"

-def load_kv_dict(dict_path,
-                 delimiter="\t",
-                 key_func=None,
-                 value_func=None,
-                 reverse=False):
-    """
-    Load key-value dict from file
-    """
-    vocab = {}
-    for line in open(dict_path, "r", encoding='utf8'):
-        terms = line.strip("\n").split(delimiter)
-        if len(terms) != 2:
-            continue
-        if reverse:
-            value, key = terms
-        else:
-            key, value = terms
-        if key in vocab:
-            raise KeyError("key duplicated with [%s]" % (key))
-        if key_func:
-            key = key_func(key)
-        if value_func:
-            value = value_func(value)
-        vocab[key] = value
-    return vocab

-def convert_tokens_to_ids(tokens, vocab, oov_replace=None, token_replace=None):
-    """convert tokens to token indexs"""
-    token_ids = []
-    oov_replace_token = vocab.get(oov_replace) if oov_replace else None
-    for token in tokens:
-        if token_replace:
-            token = token_replace.get(token, token)
-        token_id = vocab.get(token, oov_replace_token)
-        token_ids.append(token_id)
-    return token_ids

-def batch_padding_fn(max_seq_len):
-    def pad_batch_to_max_seq_len(batch):
-        batch_max_seq_len = min(
-            max([len(sample[0]) for sample in batch]), max_seq_len)
-        batch_word_ids = []
-        batch_label_ids = []
-        batch_lens = []
-        for i, sample in enumerate(batch):
-            sample_word_ids = sample[0][:batch_max_seq_len]
-            sample_words_len = len(sample_word_ids)
-            sample_word_ids += [
-                0 for _ in range(batch_max_seq_len - sample_words_len)
-            ]
-            batch_word_ids.append(sample_word_ids)
-            if len(sample) == 2:
-                sampel_label_ids = sample[1][:batch_max_seq_len] + [
-                    0 for _ in range(batch_max_seq_len - sample_words_len)
-                ]
-                batch_label_ids.append(sampel_label_ids)
-            batch_lens.append(np.int64(sample_words_len))
-        if batch_label_ids:
-            return batch_word_ids, batch_lens, batch_label_ids
-        else:
-            return batch_word_ids, batch_lens
-
-    return pad_batch_to_max_seq_len

-def parse_lac_result(words, preds, lengths, word_vocab, label_vocab):
-    """ parse padding result """
-    batch_out = []
-    id2word_dict = dict(zip(word_vocab.values(), word_vocab.keys()))
-    id2label_dict = dict(zip(label_vocab.values(), label_vocab.keys()))
-    for sent_index in range(len(lengths)):
-        sent = [
-            id2word_dict[index]
-            for index in words[sent_index][:lengths[sent_index] - 1]
-        ]
-        tags = [
-            id2label_dict[index]
-            for index in preds[sent_index][:lengths[sent_index] - 1]
-        ]
-
-        sent_out = []
-        tags_out = []
-        parital_word = ""
-        for ind, tag in enumerate(tags):
-            # for the first word
-            if parital_word == "":
-                parital_word = sent[ind]
-                tags_out.append(tag.split('-')[0])
-                continue
-            # for the beginning of word
-            if tag.endswith("-B") or (tag == "O" and tags[ind - 1] != "O"):
-                sent_out.append(parital_word)
-                tags_out.append(tag.split('-')[0])
-                parital_word = sent[ind]
-                continue
-            parital_word += sent[ind]
-        # append the last word, except for len(tags)=0
-        if len(sent_out) < len(tags_out):
-            sent_out.append(parital_word)
-        batch_out.append([sent_out, tags_out])
-    return batch_out

 class LacDataset(paddle.io.Dataset):
     """Load the dataset and convert all the texts to ids.
@@ -149,17 +37,17 @@ class LacDataset(paddle.io.Dataset):
         mode (str, optional): The load mode, "train", "test" or "infer". Defaults to 'train', meaning load the train dataset.
     """

-    def __init__(self,
-                 base_path,
-                 word_vocab,
-                 label_vocab,
-                 word_replace_dict,
-                 mode='train'):
+    def __init__(self, base_path, mode='train'):
         self.mode = mode
         self.base_path = base_path
-        self.word_vocab = word_vocab
-        self.label_vocab = label_vocab
-        self.word_replace_dict = word_replace_dict
+        word_dict_path = os.path.join(self.base_path, 'word.dic')
+        label_dict_path = os.path.join(self.base_path, 'tag.dic')
+        word_rep_dict_path = os.path.join(self.base_path, 'q2b.dic')
+        self.word_vocab = self._load_kv_dict(
+            word_dict_path, value_func=np.int64, reverse=True)
+        self.label_vocab = self._load_kv_dict(
+            label_dict_path, value_func=np.int64, reverse=True)
+        self.word_replace_dict = self._load_kv_dict(word_rep_dict_path)

         # Calculate vocab size and labels number, note: vocab value strats from 0.
         self.vocab_size = max(self.word_vocab.values()) + 1
@@ -179,9 +67,12 @@ class LacDataset(paddle.io.Dataset):
     def __getitem__(self, index):
         if self.mode == "infer":
-            return [self.word_ids[index]]
+            return [self.word_ids[index], len(self.word_ids[index])]
         else:
-            return [self.word_ids[index], self.label_ids[index]]
+            return [
+                self.word_ids[index], len(self.word_ids[index]),
+                self.label_ids[index]
+            ]

     def _read_file(self):
         self.word_ids = []
@@ -198,7 +89,7 @@ class LacDataset(paddle.io.Dataset):
                 words, labels = line.split("\t")
                 words = words.split(CHAR_DELIMITER)
-                tmp_word_ids = convert_tokens_to_ids(
+                tmp_word_ids = self._convert_tokens_to_ids(
                     words,
                     self.word_vocab,
                     oov_replace="OOV",
@@ -206,7 +97,7 @@ class LacDataset(paddle.io.Dataset):
                 self.word_ids.append(tmp_word_ids)

                 if self.mode != "infer":
-                    tmp_label_ids = convert_tokens_to_ids(
+                    tmp_label_ids = self._convert_tokens_to_ids(
                         labels.split(CHAR_DELIMITER),
                         self.label_vocab,
                         oov_replace="O")
@@ -217,3 +108,88 @@ class LacDataset(paddle.io.Dataset):
                         tmp_word_ids, tmp_label_ids)
                 self.total += 1

+    def _load_kv_dict(self,
+                      dict_path,
+                      delimiter="\t",
+                      key_func=None,
+                      value_func=None,
+                      reverse=False):
+        """
+        Load key-value dict from file
+        """
+        vocab = {}
+        for line in open(dict_path, "r", encoding='utf8'):
+            terms = line.strip("\n").split(delimiter)
+            if len(terms) != 2:
+                continue
+            if reverse:
+                value, key = terms
+            else:
+                key, value = terms
+            if key in vocab:
+                raise KeyError("key duplicated with [%s]" % (key))
+            if key_func:
+                key = key_func(key)
+            if value_func:
+                value = value_func(value)
+            vocab[key] = value
+        return vocab
+
+    def _convert_tokens_to_ids(self,
+                               tokens,
+                               vocab,
+                               oov_replace=None,
+                               token_replace=None):
+        """convert tokens to token indexs"""
+        token_ids = []
+        oov_replace_token = vocab.get(oov_replace) if oov_replace else None
+        for token in tokens:
+            if token_replace:
+                token = token_replace.get(token, token)
+            token_id = vocab.get(token, oov_replace_token)
+            token_ids.append(token_id)
+        return token_ids
+
+
+def parse_lac_result(words, preds, lengths, word_vocab, label_vocab):
+    """ parse padding result """
+    batch_out = []
+    id2word_dict = dict(zip(word_vocab.values(), word_vocab.keys()))
+    id2label_dict = dict(zip(label_vocab.values(), label_vocab.keys()))
+    for sent_index in range(len(lengths)):
+        sent = [
+            id2word_dict[index]
+            for index in words[sent_index][:lengths[sent_index] - 1]
+        ]
+        tags = [
+            id2label_dict[index]
+            for index in preds[sent_index][:lengths[sent_index] - 1]
+        ]
+
+        sent_out = []
+        tags_out = []
+        parital_word = ""
+        for ind, tag in enumerate(tags):
+            # for the first word
+            if parital_word == "":
+                parital_word = sent[ind]
+                tags_out.append(tag.split('-')[0])
+                continue
+            # for the beginning of word
+            if tag.endswith("-B") or (tag == "O" and tags[ind - 1] != "O"):
+                sent_out.append(parital_word)
+                tags_out.append(tag.split('-')[0])
+                parital_word = sent[ind]
+                continue
+            parital_word += sent[ind]
+        # append the last word, except for len(tags)=0
+        if len(sent_out) < len(tags_out):
+            sent_out.append(parital_word)
+        batch_out.append([sent_out, tags_out])
+    return batch_out
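To make the B/I merging in `parse_lac_result` concrete, here is a simplified trace of its inner loop on an invented sentence (it mirrors the function above but drops the id-to-token lookups, the padding handling, and the O-tag branch):

```python
# Merge per-character IOB2 tags back into words, as parse_lac_result does.
sent = ["百", "度", "在", "北", "京"]                # invented example
tags = ["ORG-B", "ORG-I", "p-B", "LOC-B", "LOC-I"]

sent_out, tags_out, partial_word = [], [], ""
for ch, tag in zip(sent, tags):
    if partial_word == "":                 # first character of the sentence
        partial_word = ch
        tags_out.append(tag.split("-")[0])
    elif tag.endswith("-B"):               # a new word starts: flush the old one
        sent_out.append(partial_word)
        tags_out.append(tag.split("-")[0])
        partial_word = ch
    else:                                  # X-I: continuation of the current word
        partial_word += ch
sent_out.append(partial_word)              # flush the last word

print(list(zip(sent_out, tags_out)))
# [('百度', 'ORG'), ('在', 'p'), ('北京', 'LOC')]
```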
@@ -20,33 +20,21 @@ import argparse
 import paddle
 import numpy as np

-from data import load_kv_dict, batch_padding_fn, LacDataset
-from model import BiGruCrf, ViterbiDecoder, ChunkEvaluator
+from paddlenlp.data import Pad, Tuple, Stack
+from paddlenlp.metrics import ChunkEvaluator
+from data import LacDataset
+from model import BiGruCrf

 # yapf: disable
 parser = argparse.ArgumentParser(__doc__)
-parser.add_argument("--base_path", type=str, default=None,
-                    help="The folder where the dataset is located.")
-parser.add_argument("--word_dict_path", type=str, default=None,
-                    help="The path of the word dictionary.")
-parser.add_argument("--label_dict_path", type=str, default=None,
-                    help="The path of the label dictionary.")
-parser.add_argument("--word_rep_dict_path", type=str, default=None,
-                    help="The path of the word replacement Dictionary")
-parser.add_argument("--init_checkpoint", type=str, default=None,
-                    help="Path to init model.")
-parser.add_argument("--batch_size", type=int, default=300,
-                    help="The number of sequences contained in a mini-batch.")
-parser.add_argument("--max_seq_len", type=int, default=64,
-                    help="Number of words of the longest seqence.")
-parser.add_argument("--use_gpu", type=ast.literal_eval,
-                    default=True, help="If set, use GPU for training.")
-parser.add_argument("--emb_dim", type=int,
-                    default=128,
-                    help="The dimension in which a word is embedded.")
-parser.add_argument("--hidden_size", type=int, default=128,
-                    help="The number of hidden nodes in the GRU layer.")
+parser.add_argument("--root", type=str, default=None, help="The folder where the dataset is located.")
+parser.add_argument("--init_checkpoint", type=str, default=None, help="Path to init model.")
+parser.add_argument("--batch_size", type=int, default=300, help="The number of sequences contained in a mini-batch.")
+parser.add_argument("--max_seq_len", type=int, default=64, help="Number of words of the longest seqence.")
+parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="If set, use GPU for training.")
+parser.add_argument("--emb_dim", type=int, default=128, help="The dimension in which a word is embedded.")
+parser.add_argument("--hidden_size", type=int, default=128, help="The number of hidden nodes in the GRU layer.")
 args = parser.parse_args()
 # yapf: enable
@@ -55,14 +43,13 @@ def evaluate(args):
     place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
     paddle.set_device("gpu" if args.use_gpu else "cpu")

-    # Load vocab to create dataset.
-    word_vocab = load_kv_dict(
-        args.word_dict_path, value_func=np.int64, reverse=True)
-    label_vocab = load_kv_dict(
-        args.label_dict_path, value_func=np.int64, reverse=True)
-    word_rep_dict = load_kv_dict(args.word_rep_dict_path)
-    test_dataset = LacDataset(
-        args.base_path, word_vocab, label_vocab, word_rep_dict, mode='test')
+    # create dataset.
+    test_dataset = LacDataset(args.root, mode='test')
+    batchify_fn = lambda samples, fn=Tuple(
+        Pad(axis=0, pad_val=0),  # word_ids
+        Stack(),  # length
+        Pad(axis=0, pad_val=0),  # label_ids
+    ): fn(samples)

     # Create sampler for dataloader
     test_sampler = paddle.io.BatchSampler(
@@ -75,7 +62,7 @@ def evaluate(args):
         batch_sampler=test_sampler,
         places=place,
         return_list=True,
-        collate_fn=batch_padding_fn(args.max_seq_len))
+        collate_fn=batchify_fn)

     # Define the model network and metric evaluator
     network = BiGruCrf(args.emb_dim, args.hidden_size, test_dataset.vocab_size,
......
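The `batchify_fn` built from `paddlenlp.data.Tuple`, `Pad` and `Stack` replaces the old hand-written `batch_padding_fn`. Assuming paddlenlp is installed, the sketch below shows what it does to a toy batch laid out like the new `LacDataset.__getitem__` output (the numbers are made up):

```python
import numpy as np
from paddlenlp.data import Pad, Stack, Tuple

batchify_fn = lambda samples, fn=Tuple(
    Pad(axis=0, pad_val=0),  # word_ids, padded to the longest sample in the batch
    Stack(),                 # length
    Pad(axis=0, pad_val=0),  # label_ids, padded the same way
): fn(samples)

batch = [
    ([11, 12, 13], 3, [1, 2, 3]),   # made-up (word_ids, length, label_ids) samples
    ([21, 22], 2, [4, 5]),
]
word_ids, lengths, label_ids = batchify_fn(batch)
print(np.asarray(word_ids))   # [[11 12 13] [21 22  0]]
print(np.asarray(lengths))    # [3 2]
print(np.asarray(label_ids))  # [[1 2 3] [4 5 0]]
```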
@@ -88,107 +88,3 @@ class BiGruCrf(nn.Layer):
         emission = self.fc(bigru_output)
         _, prediction = self.viterbi_decoder(emission, lengths)
         return emission, lengths, prediction

-class ChunkEvaluator(paddle.metric.Metric):
-    """ChunkEvaluator computes the precision, recall and F1-score for chunk detection.
-    It is often used in sequence tagging tasks, such as Named Entity Recognition(NER).
-
-    Args:
-        num_chunk_types (int): The number of chunk types.
-        chunk_scheme (str): Indicate the tagging schemes used here. The value must
-            be IOB, IOE, IOBES or plain.
-        excluded_chunk_types (list, optional): Indicate the chunk types shouldn't
-            be taken into account. It should be a list of chunk type ids(integer).
-            Default None.
-    """
-
-    def __init__(self, num_chunk_types, chunk_scheme,
-                 excluded_chunk_types=None):
-        super(ChunkEvaluator, self).__init__()
-        self.num_chunk_types = num_chunk_types
-        self.chunk_scheme = chunk_scheme
-        self.excluded_chunk_types = excluded_chunk_types
-        self.num_infer_chunks = 0
-        self.num_label_chunks = 0
-        self.num_correct_chunks = 0
-
-    def compute(self, inputs, lengths, predictions, labels):
-        precision, recall, f1_score, num_infer_chunks, num_label_chunks, num_correct_chunks = paddle.metric.chunk_eval(
-            predictions,
-            labels,
-            chunk_scheme=self.chunk_scheme,
-            num_chunk_types=self.num_chunk_types,
-            excluded_chunk_types=self.excluded_chunk_types,
-            seq_length=lengths)
-        return num_infer_chunks, num_label_chunks, num_correct_chunks
-
-    def _is_number_or_matrix(self, var):
-        def _is_number_(var):
-            return isinstance(
-                var, int) or isinstance(var, np.int64) or isinstance(
-                    var, float) or (isinstance(var, np.ndarray) and
-                                    var.shape == (1, ))
-
-        return _is_number_(var) or isinstance(var, np.ndarray)
-
-    def update(self, num_infer_chunks, num_label_chunks, num_correct_chunks):
-        """
-        This function takes (num_infer_chunks, num_label_chunks, num_correct_chunks) as input,
-        to accumulate and update the corresponding status of the ChunkEvaluator object. The update method is as follows:
-
-        .. math::
-                   \\\\ \\begin{array}{l}{\\text { self. num_infer_chunks }+=\\text { num_infer_chunks }} \\\\ {\\text { self. num_Label_chunks }+=\\text { num_label_chunks }} \\\\ {\\text { self. num_correct_chunks }+=\\text { num_correct_chunks }}\\end{array} \\\\
-
-        Args:
-            num_infer_chunks(int|numpy.array): The number of chunks in Inference on the given minibatch.
-            num_label_chunks(int|numpy.array): The number of chunks in Label on the given mini-batch.
-            num_correct_chunks(int|float|numpy.array): The number of chunks both in Inference and Label on the
-                given mini-batch.
-        """
-        if not self._is_number_or_matrix(num_infer_chunks):
-            raise ValueError(
-                "The 'num_infer_chunks' must be a number(int) or a numpy ndarray."
-            )
-        if not self._is_number_or_matrix(num_label_chunks):
-            raise ValueError(
-                "The 'num_label_chunks' must be a number(int, float) or a numpy ndarray."
-            )
-        if not self._is_number_or_matrix(num_correct_chunks):
-            raise ValueError(
-                "The 'num_correct_chunks' must be a number(int, float) or a numpy ndarray."
-            )
-        self.num_infer_chunks += num_infer_chunks
-        self.num_label_chunks += num_label_chunks
-        self.num_correct_chunks += num_correct_chunks
-
-    def accumulate(self):
-        """
-        This function returns the mean precision, recall and f1 score for all accumulated minibatches.
-
-        Returns:
-            float: mean precision, recall and f1 score.
-        """
-        precision = float(
-            self.num_correct_chunks
-        ) / self.num_infer_chunks if self.num_infer_chunks else 0
-        recall = float(self.num_correct_chunks
-                       ) / self.num_label_chunks if self.num_label_chunks else 0
-        f1_score = float(2 * precision * recall) / (
-            precision + recall) if self.num_correct_chunks else 0
-        return precision, recall, f1_score
-
-    def reset(self):
-        """
-        Reset function empties the evaluation memory for previous mini-batches.
-        """
-        self.num_infer_chunks = 0
-        self.num_label_chunks = 0
-        self.num_correct_chunks = 0
-
-    def name(self):
-        """
-        Return name of metric instance.
-        """
-        return "precision", "recall", "f1"
@@ -19,16 +19,15 @@ import argparse
 import numpy as np
 import paddle

-from data import load_kv_dict, batch_padding_fn, LacDataset, parse_lac_result
-from model import BiGruCrf, ViterbiDecoder, ChunkEvaluator
+from paddlenlp.data import Pad, Tuple, Stack
+from paddlenlp.metrics import ChunkEvaluator
+from data import LacDataset, parse_lac_result
+from model import BiGruCrf

 # yapf: disable
 parser = argparse.ArgumentParser(__doc__)
-parser.add_argument("--base_path", type=str, default=None, help="The folder where the dataset is located.")
-parser.add_argument("--word_dict_path", type=str, default=None, help="The path of the word dictionary.")
-parser.add_argument("--label_dict_path", type=str, default=None, help="The path of the label dictionary.")
-parser.add_argument("--word_rep_dict_path", type=str, default=None, help="The path of the word replacement Dictionary")
+parser.add_argument("--root", type=str, default=None, help="The folder where the dataset is located.")
 parser.add_argument("--init_checkpoint", type=str, default=None, help="Path to init model.")
 parser.add_argument("--batch_size", type=int, default=300, help="The number of sequences contained in a mini-batch.")
 parser.add_argument("--max_seq_len", type=int, default=64, help="Number of words of the longest seqence.")
@@ -43,14 +42,13 @@ def infer(args):
     place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
     paddle.set_device("gpu" if args.use_gpu else "cpu")

-    # Load vocab to create dataset.
-    word_vocab = load_kv_dict(
-        args.word_dict_path, value_func=np.int64, reverse=True)
-    label_vocab = load_kv_dict(
-        args.label_dict_path, value_func=np.int64, reverse=True)
-    word_rep_dict = load_kv_dict(args.word_rep_dict_path)
-    infer_dataset = LacDataset(
-        args.base_path, word_vocab, label_vocab, word_rep_dict, mode='infer')
+    # create dataset.
+    infer_dataset = LacDataset(args.root, mode='infer')
+
+    batchify_fn = lambda samples, fn=Tuple(
+        Pad(axis=0, pad_val=0),  # word_ids
+        Stack(),  # length
+    ): fn(samples)

     # Create sampler for dataloader
     infer_sampler = paddle.io.BatchSampler(
@@ -63,7 +61,7 @@ def infer(args):
         batch_sampler=infer_sampler,
         places=place,
         return_list=True,
-        collate_fn=batch_padding_fn(args.max_seq_len))
+        collate_fn=batchify_fn)

     # Define the model network
     network = BiGruCrf(args.emb_dim, args.hidden_size, infer_dataset.vocab_size,
@@ -82,7 +80,8 @@ def infer(args):
         [pred for batch_pred in crf_decodes for pred in batch_pred])

     results = parse_lac_result(infer_dataset.word_ids, preds, lengths,
-                               word_vocab, label_vocab)
+                               infer_dataset.word_vocab,
+                               infer_dataset.label_vocab)

     sent_tags = []
     for sent, tags in results:
......
@@ -20,17 +20,15 @@ import argparse
 import numpy as np
 import paddle

-from data import load_kv_dict, batch_padding_fn, LacDataset
+from data import LacDataset
 from model import BiGruCrf
+from paddlenlp.data import Pad, Tuple, Stack
 from paddlenlp.layers.crf import LinearChainCrfLoss, ViterbiDecoder
-from paddlenlp.metrics.chunk_evaluator import ChunkEvaluator
+from paddlenlp.metrics import ChunkEvaluator

 # yapf: disable
 parser = argparse.ArgumentParser(__doc__)
-parser.add_argument("--base_path", type=str, default=None, help="The folder where the dataset is located.")
-parser.add_argument("--word_dict_path", type=str, default=None, help="The path of the word dictionary.")
-parser.add_argument("--label_dict_path", type=str, default=None, help="The path of the label dictionary.")
-parser.add_argument("--word_rep_dict_path", type=str, default=None, help="The path of the word replacement Dictionary")
+parser.add_argument("--root", type=str, default=None, help="The folder where the dataset is located.")
 parser.add_argument("--init_checkpoint", type=str, default=None, help="Path to init model.")
 parser.add_argument("--model_save_dir", type=str, default=None, help="The model will be saved in this path.")
 parser.add_argument("--epochs", type=int, default=10, help="Corpus iteration num.")
@@ -52,16 +50,15 @@ def train(args):
     place = paddle.CPUPlace()
     paddle.set_device("cpu")

-    # Load vocab to create dataset.
-    word_vocab = load_kv_dict(
-        args.word_dict_path, value_func=np.int64, reverse=True)
-    label_vocab = load_kv_dict(
-        args.label_dict_path, value_func=np.int64, reverse=True)
-    word_rep_dict = load_kv_dict(args.word_rep_dict_path)
-    train_dataset = LacDataset(
-        args.base_path, word_vocab, label_vocab, word_rep_dict, mode='train')
-    test_dataset = LacDataset(
-        args.base_path, word_vocab, label_vocab, word_rep_dict, mode='test')
+    # create dataset.
+    train_dataset = LacDataset(args.root, mode='train')
+    test_dataset = LacDataset(args.root, mode='test')
+
+    batchify_fn = lambda samples, fn=Tuple(
+        Pad(axis=0, pad_val=0),  # word_ids
+        Stack(),  # length
+        Pad(axis=0, pad_val=0),  # label_ids
+    ): fn(samples)

     # Create sampler for dataloader
     train_sampler = paddle.io.DistributedBatchSampler(
@@ -74,7 +71,7 @@ def train(args):
         batch_sampler=train_sampler,
         places=place,
         return_list=True,
-        collate_fn=batch_padding_fn(args.max_seq_len))
+        collate_fn=batchify_fn)

     test_sampler = paddle.io.BatchSampler(
         dataset=test_dataset,
@@ -86,7 +83,7 @@ def train(args):
         batch_sampler=test_sampler,
         places=place,
         return_list=True,
-        collate_fn=batch_padding_fn(args.max_seq_len))
+        collate_fn=batchify_fn)

     # Define the model netword and its loss
     network = BiGruCrf(args.emb_dim, args.hidden_size, train_dataset.vocab_size,
@@ -105,17 +102,18 @@ def train(args):
         model.load(args.init_checkpoint)

     # Start training
+    callback = paddle.callbacks.ProgBarLogger(log_freq=10, verbose=3)
     model.fit(train_data=train_loader,
               eval_data=test_loader,
               batch_size=args.batch_size,
               epochs=args.epochs,
               eval_freq=1,
-              log_freq=1,
+              log_freq=10,
               save_dir=args.model_save_dir,
               save_freq=1,
-              verbose=2,
               drop_last=True,
-              shuffle=True)
+              shuffle=True,
+              callbacks=callback)

 if __name__ == "__main__":
......
@@ -113,12 +113,6 @@ def parse_args():
     return args

-def set_seed(args):
-    random.seed(args.seed + paddle.distributed.get_rank())
-    np.random.seed(args.seed + paddle.distributed.get_rank())
-    paddle.seed(args.seed + paddle.distributed.get_rank())

 def evaluate(model, loss_fct, metric, data_loader, label_num):
     model.eval()
     metric.reset()
@@ -251,8 +245,6 @@ def do_train(args):
     if paddle.distributed.get_world_size() > 1:
         paddle.distributed.init_parallel_env()

-    set_seed(args)

     train_dataset, dev_dataset = ppnlp.datasets.MSRA_NER.get_datasets(
         ["train", "dev"])
     tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
......
@@ -49,7 +49,6 @@ class LinearChainCrf(nn.Layer):
             dtype='float32')
         self.with_start_stop_tag = with_start_stop_tag
-        self.max_seq_len = 0

         self._initial_alpha = None
         self._start_tensor = None
         self._stop_tensor = None
@@ -226,16 +225,14 @@ class LinearChainCrf(nn.Layer):
         return self._batch_index

     def _get_seq_index(self, length):
-        if self._seq_index is None or length > self.max_seq_len:
-            self.max_seq_len = length
-            self._seq_index = paddle.arange(end=self.max_seq_len, dtype="int64")
+        if self._seq_index is None or length > self._seq_index.shape[0]:
+            self._seq_index = paddle.arange(end=length, dtype="int64")
         return self._seq_index[:length]

     def _get_batch_seq_index(self, batch_size, length):
-        if self._batch_seq_index is None or length > self.max_seq_len:
-            self.max_seq_len = length
+        if self._batch_seq_index is None or length > self._batch_seq_index.shape[1]:
             self._batch_seq_index = paddle.cumsum(
-                paddle.ones([batch_size, self.max_seq_len + 2], "int64"),
+                paddle.ones([batch_size, length + 2], "int64"),
                 axis=1) - 1
         if self.with_start_stop_tag:
             return self._batch_seq_index[:, :length + 2]
......
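The cached index tensors are easier to follow with a concrete shape in mind; assuming a working paddle install, the same expression as in the new `_get_batch_seq_index` gives, for a toy batch:

```python
import paddle

batch_size, length = 2, 3
# One row of positions 0 .. length+1 per sample; the +2 leaves room for the
# start and stop tags when with_start_stop_tag is set.
batch_seq_index = paddle.cumsum(
    paddle.ones([batch_size, length + 2], "int64"), axis=1) - 1
print(batch_seq_index.numpy())
# [[0 1 2 3 4]
#  [0 1 2 3 4]]
```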
@@ -239,7 +239,7 @@ class BertPretrainedModel(PretrainedModel):
         "bert-base-chinese":
         "http://paddlenlp.bj.bcebos.com/models/transformers/bert/bert-base-chinese.pdparams",
         "bert-base-multilingual-cased":
-        "http://paddlenlp.bj.bcebos.com/models/transformers/bert/bert-base-multilingual-cased.pdparamss",
+        "http://paddlenlp.bj.bcebos.com/models/transformers/bert/bert-base-multilingual-cased.pdparams",
         "bert-large-cased":
         "http://paddlenlp.bj.bcebos.com/models/transformers/bert/bert-large-cased.pdparams",
         "bert-wwm-chinese":
......