未验证 提交 b84cace2 编写于 作者: W wawltor 提交者: GitHub

Add TokenEmbedding support to express (#5058)

上级 16c6580d
...@@ -18,6 +18,7 @@ import paddle.nn as nn ...@@ -18,6 +18,7 @@ import paddle.nn as nn
from paddlenlp.data import Stack, Tuple, Pad from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.layers import LinearChainCrf, ViterbiDecoder, LinearChainCrfLoss from paddlenlp.layers import LinearChainCrf, ViterbiDecoder, LinearChainCrfLoss
from paddlenlp.metrics import ChunkEvaluator from paddlenlp.metrics import ChunkEvaluator
from paddlenlp.embeddings import TokenEmbedding
def parse_decodes(ds, decodes, lens): def parse_decodes(ds, decodes, lens):
...@@ -95,9 +96,18 @@ class ExpressDataset(paddle.io.Dataset): ...@@ -95,9 +96,18 @@ class ExpressDataset(paddle.io.Dataset):
class BiGRUWithCRF(nn.Layer): class BiGRUWithCRF(nn.Layer):
def __init__(self, emb_size, hidden_size, word_num, label_num): def __init__(self,
emb_size,
hidden_size,
word_num,
label_num,
use_w2v_emb=False):
super(BiGRUWithCRF, self).__init__() super(BiGRUWithCRF, self).__init__()
self.word_emb = nn.Embedding(word_num, emb_size) if use_w2v_emb:
self.word_emb = TokenEmbedding(
extended_vocab_path='./conf/word.dic', unknown_token='OOV')
else:
self.word_emb = nn.Embedding(word_num, emb_size)
self.gru = nn.GRU(emb_size, self.gru = nn.GRU(emb_size,
hidden_size, hidden_size,
num_layers=2, num_layers=2,
...@@ -153,7 +163,7 @@ if __name__ == '__main__': ...@@ -153,7 +163,7 @@ if __name__ == '__main__':
model = paddle.Model(network) model = paddle.Model(network)
optimizer = paddle.optimizer.Adam( optimizer = paddle.optimizer.Adam(
learning_rate=0.002, parameters=model.parameters()) learning_rate=0.001, parameters=model.parameters())
crf_loss = LinearChainCrfLoss(network.crf.transitions) crf_loss = LinearChainCrfLoss(network.crf.transitions)
chunk_evaluator = ChunkEvaluator((train_ds.label_num + 2) // 2, 'IOB') chunk_evaluator = ChunkEvaluator((train_ds.label_num + 2) // 2, 'IOB')
model.prepare(optimizer, crf_loss, chunk_evaluator) model.prepare(optimizer, crf_loss, chunk_evaluator)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册