未验证 提交 b84cace2 编写于 作者: W wawltor 提交者: GitHub

Add TokenEmbedding support for express (#5058)

上级 16c6580d
......@@ -18,6 +18,7 @@ import paddle.nn as nn
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.layers import LinearChainCrf, ViterbiDecoder, LinearChainCrfLoss
from paddlenlp.metrics import ChunkEvaluator
from paddlenlp.embeddings import TokenEmbedding
def parse_decodes(ds, decodes, lens):
......@@ -95,9 +96,18 @@ class ExpressDataset(paddle.io.Dataset):
class BiGRUWithCRF(nn.Layer):
def __init__(self, emb_size, hidden_size, word_num, label_num):
def __init__(self,
emb_size,
hidden_size,
word_num,
label_num,
use_w2v_emb=False):
super(BiGRUWithCRF, self).__init__()
self.word_emb = nn.Embedding(word_num, emb_size)
if use_w2v_emb:
self.word_emb = TokenEmbedding(
extended_vocab_path='./conf/word.dic', unknown_token='OOV')
else:
self.word_emb = nn.Embedding(word_num, emb_size)
self.gru = nn.GRU(emb_size,
hidden_size,
num_layers=2,
......@@ -153,7 +163,7 @@ if __name__ == '__main__':
model = paddle.Model(network)
optimizer = paddle.optimizer.Adam(
learning_rate=0.002, parameters=model.parameters())
learning_rate=0.001, parameters=model.parameters())
crf_loss = LinearChainCrfLoss(network.crf.transitions)
chunk_evaluator = ChunkEvaluator((train_ds.label_num + 2) // 2, 'IOB')
model.prepare(optimizer, crf_loss, chunk_evaluator)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册