Unverified commit ec17d938, authored by Jack Zhou, committed by GitHub

Add more embeddings and a sample for the TokenEmbedding

* add all wiki embedding and part of baidu encyclopedia embedding.

* add embedding example

* add people_daily, weibo, sogou pretrained embeddings

* add zhihu, financial, literature embeddings

* Add embedding model readme; add embedding train example and readme

* fix README example

* fix embedding doc
Parent 954f02ca
# Embedding Model Overview
This page summarizes, in table form, the pretrained embedding models (Chinese and English) bundled with PaddleNLP.
PaddleNLP provides a number of open-source pretrained embedding models. To load one, simply pass the name of the pretrained model to `paddlenlp.embeddings.TokenEmbedding`. The tables below list the supported pretrained embedding models; the names shown are used as the argument to `paddlenlp.embeddings.TokenEmbedding`. The naming convention is \${training model}.\${corpus}.\${vector type}.\${co-occurrence type}.dim\${dimension}. Three training models are covered: Word2Vec (w2v, trained with the skip-gram model), GloVe (glove), and FastText (fasttext).
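For reference, here is a minimal sketch of loading one of the models listed below. The chosen model name and the sample word are only illustrations (the word is assumed to be in the vocabulary), and the vector files are downloaded automatically on first use:
```python
import paddle
from paddlenlp.embeddings import TokenEmbedding

# Load a pretrained embedding by its name (downloaded and cached on first use).
token_embedding = TokenEmbedding("w2v.baidu_encyclopedia.target.word-word.dim300")
print(token_embedding.embedding_dim)  # 300 for the models listed below

# TokenEmbedding behaves like a regular embedding layer:
# map a token to its id, then look up the vector.
idx = token_embedding.vocab.token_to_idx["苹果"]       # illustrative word
vectors = token_embedding(paddle.to_tensor([[idx]]))   # shape: [1, 1, 300]
```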
## Chinese Word Vectors
The following pretrained models are provided by [Chinese-Word-Vectors](https://github.com/Embedding/Chinese-Word-Vectors).
For each corpus, several sets of target word vectors are trained with different types of context; the columns from the second one onward correspond to these context types. The context types are:
* Word: the context used to predict the target word during training is a word.
* Word + Ngram: the context is a word or an n-gram, where bigram denotes 2-grams and ngram.1-2 denotes 1-grams or 2-grams.
* Word + Character: the context is a word or a character, where word-character.char1-2 means the context is 1 or 2 characters.
* Word + Character + Ngram: the context is a word, a character, or an n-gram; bigram-char means the context is a 2-gram or 1 character.
| Corpus | Word | Word + Ngram | Word + Character | Word + Character + Ngram |
| ------------------------------------------- | ---- | ---- | ---- | ---- |
| Baidu Encyclopedia 百度百科 | w2v.baidu_encyclopedia.target.word-word.dim300 | w2v.baidu_encyclopedia.target.word-ngram.1-2.dim300 | w2v.baidu_encyclopedia.target.word-character.char1-2.dim300 | w2v.baidu_encyclopedia.target.bigram-char.dim300 |
| Wikipedia_zh 中文维基百科 | w2v.wiki.target.word-word.dim300 | w2v.wiki.target.word-bigram.dim300 | w2v.wiki.target.word-char.dim300 | w2v.wiki.target.bigram-char.dim300 |
| People's Daily News 人民日报 | w2v.people_daily.target.word-word.dim300 | w2v.people_daily.target.word-bigram.dim300 | w2v.people_daily.target.word-char.dim300 | w2v.people_daily.target.bigram-char.dim300 |
| Sogou News 搜狗新闻 | w2v.sogou.target.word-word.dim300 | w2v.sogou.target.word-bigram.dim300 | w2v.sogou.target.word-char.dim300 | w2v.sogou.target.bigram-char.dim300 |
| Financial News 金融新闻 | w2v.financial.target.word-word.dim300 | w2v.financial.target.word-bigram.dim300 | w2v.financial.target.word-char.dim300 | w2v.financial.target.bigram-char.dim300 |
| Zhihu_QA 知乎问答 | w2v.zhihu.target.word-word.dim300 | w2v.zhihu.target.word-bigram.dim300 | w2v.zhihu.target.word-char.dim300 | w2v.zhihu.target.bigram-char.dim300 |
| Weibo 微博 | w2v.weibo.target.word-word.dim300 | w2v.weibo.target.word-bigram.dim300 | w2v.weibo.target.word-char.dim300 | w2v.weibo.target.bigram-char.dim300 |
| Literature 文学作品 | w2v.literature.target.word-word.dim300 | w2v.literature.target.word-bigram.dim300 | w2v.literature.target.word-char.dim300 | w2v.literature.target.bigram-char.dim300 |
| Complete Library in Four Sections 四库全书 | w2v.sikuquanshu.target.word-word.dim300 | w2v.sikuquanshu.target.word-bigram.dim300 | N/A | N/A |
| Mixed-large 综合 | w2v.mixed-large.target.word-word.dim300 | N/A | w2v.mixed-large.target.word-char.dim300 | N/A |
In particular, for the Baidu Encyclopedia corpus, both target and context word vectors are provided under different co-occurrence types:
| Co-occurrence type | Target word vectors | Context word vectors |
| --------------------------- | ------ | ---- |
| Word → Word | w2v.baidu_encyclopedia.target.word-word.dim300 | w2v.baidu_encyclopedia.context.word-word.dim300 |
| Word → Ngram (1-2) | w2v.baidu_encyclopedia.target.word-ngram.1-2.dim300 | N/A |
| Word → Ngram (1-3) | N/A | N/A |
| Ngram (1-2) → Ngram (1-2)| N/A | N/A |
| Word → Character (1) | w2v.baidu_encyclopedia.target.word-character.char1-1.dim300 | w2v.baidu_encyclopedia.context.word-character.char1-1.dim300 |
| Word → Character (1-2) | w2v.baidu_encyclopedia.target.word-character.char1-2.dim300 | w2v.baidu_encyclopedia.context.word-character.char1-2.dim300 |
| Word → Character (1-4) | w2v.baidu_encyclopedia.target.word-character.char1-4.dim300 | w2v.baidu_encyclopedia.context.word-character.char1-4.dim300 |
| Word → Word (left/right) | N/A | N/A |
| Word → Word (distance) | N/A | N/A |
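As a rough sketch of how the target and context vectors are loaded and combined (the model names come from the table above; the example words are only illustrations and are assumed to be in the vocabulary):
```python
import numpy as np
from paddlenlp.embeddings import TokenEmbedding

# Target and context vectors are separate models, loaded in exactly the same way.
target = TokenEmbedding("w2v.baidu_encyclopedia.target.word-word.dim300")
context = TokenEmbedding("w2v.baidu_encyclopedia.context.word-word.dim300")

def vector(emb, word):
    # Fetch the raw vector of a word from the embedding weight matrix.
    return emb.weight.numpy()[emb.vocab.token_to_idx[word]]

# Cosine similarity between one word's target vector and another word's context vector.
a, b = vector(target, "电影"), vector(context, "导演")
print(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
```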
## English Word Vectors
To be added.
# Word Embedding with PaddleNLP
## Introduction
This example focuses on the embedding functionality in PaddleNLP, in particular `paddlenlp.embeddings.TokenEmbedding`, and on how to quickly load and reuse pretrained embeddings. PaddleNLP ships with a number of publicly available pretrained embeddings; loading one through the `paddlenlp.embeddings.TokenEmbedding` API can improve training results. The text classification example below demonstrates the improvement that `paddlenlp.embeddings.TokenEmbedding` brings to training.
## Quick Start
### Installation
* Install PaddlePaddle
This project requires PaddlePaddle 2.0 or later. See the [installation guide](http://www.paddlepaddle.org/#quick-start) for instructions.
* Install PaddleNLP
```shell
pip install paddlenlp
```
* Other dependencies
This project depends on the jieba tokenizer. Install it before running the project, e.g. `pip install -U jieba`.
Python 3.6+ is required. For other environment requirements, see the PaddlePaddle [installation notes](https://www.paddlepaddle.org.cn/install/quick/zh/2.0rc-linux-docker).
### Download the Vocabulary
Download the vocabulary file dict.txt, which is used to build the word-to-id mapping.
```bash
wget https://paddlenlp.bj.bcebos.com/data/dict.txt
```
### Start Training
We use the public Chinese sentiment classification dataset ChnSentiCorp as the example dataset. Run the commands below to train the model on the training set (train.tsv) and evaluate it on the development set (dev.tsv). The training logs are written to use_token_embedding.txt and use_normal_embedding.txt.
Run on CPU:
```
nohup python train.py --vocab_path='./dict.txt' --use_gpu=False --lr=5e-4 --batch_size=64 --epochs=20 --use_token_embedding=True --vdl_dir='./vdl_dir' >use_token_embedding.txt 2>&1 &
nohup python train.py --vocab_path='./dict.txt' --use_gpu=False --lr=5e-4 --batch_size=64 --epochs=20 --use_token_embedding=False --vdl_dir='./vdl_dir'>use_normal_embedding.txt 2>&1 &
```
Run on GPU:
```
export CUDA_VISIBLE_DEVICES=0
nohup python train.py --vocab_path='./dict.txt' --use_gpu=True --lr=5e-4 --batch_size=64 --epochs=20 --use_token_embedding=True --vdl_dir='./vdl_dir' > use_token_embedding.txt 2>&1 &
# If GPU memory is insufficient, wait until the first run finishes before starting this one.
nohup python train.py --vocab_path='./dict.txt' --use_gpu=True --lr=5e-4 --batch_size=64 --epochs=20 --use_token_embedding=False --vdl_dir='./vdl_dir' > use_normal_embedding.txt 2>&1 &
```
The parameters above are:
* `vocab_path`: Path to the vocabulary file.
* `use_gpu`: Whether to use the GPU for training. Default: `False`.
* `lr`: Learning rate. Default: 5e-4.
* `batch_size`: Batch size. Default: 64.
* `epochs`: Number of training epochs. Default: 5.
* `use_token_embedding`: Whether to use PaddleNLP's TokenEmbedding. Default: `True`.
* `vdl_dir`: VisualDL log directory; VisualDL information produced during training is saved here. Default: `./vdl_dir`.
The script also provides the following parameters:
* `save_dir`: Directory in which to save model checkpoints.
* `init_from_ckpt`: Path of the checkpoint from which to resume training.
* `embedding_name`: Name of the pretrained embedding. Default: `w2v.baidu_encyclopedia.target.word-word.dim300`. See [Embedding Model Overview](../../docs/embeddings.md) for the supported pretrained embeddings.
### Launch VisualDL
We recommend using VisualDL to compare experiments. The command below launches VisualDL; the directory passed to `logdir` must be the same as the `vdl_dir` specified when launching training. (For more on VisualDL, see the [VisualDL guide](https://github.com/PaddlePaddle/VisualDL#2-launch-panel).)
```
nohup visualdl --logdir ./vdl_dir --port 8888 --host 0.0.0.0 &
```
### Training Results Comparison
Open `ip:8888` in a Chrome browser, where ip is the IP address of the machine running VisualDL.
The figures below compare the two example runs. The blue curve is the run using `paddlenlp.embeddings.TokenEmbedding`; the green curve is the run using an Embedding without pretrained weights. With `paddlenlp.embeddings.TokenEmbedding`, the validation accuracy trends upward and converges at about 0.90, while without it the validation accuracy trends downward and converges at about 0.86. This example shows that using `paddlenlp.embeddings.TokenEmbedding` improves training results.
Eval Acc:
![eval acc](https://user-images.githubusercontent.com/10826371/102055579-0a743780-3e26-11eb-9025-99ffd06ecb68.png)
Eval Loss:
![eval loss](https://user-images.githubusercontent.com/10826371/102055669-28da3300-3e26-11eb-8034-ee902931b7cf.png)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from paddlenlp.data import JiebaTokenizer, Vocab
import jieba
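# The module-level tokenizer defaults to raw jieba; set_tokenizer() below swaps in a
# vocab-aware JiebaTokenizer once the vocabulary is available.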
tokenizer = jieba
def set_tokenizer(vocab):
global tokenizer
if vocab is not None:
tokenizer = JiebaTokenizer(vocab=vocab)
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = {}
with open(vocab_file, "r", encoding="utf-8") as reader:
tokens = reader.readlines()
for index, token in enumerate(tokens):
token = token.rstrip("\n").split("\t")[0]
vocab[token] = index
return vocab
def convert_tokens_to_ids(tokens, vocab):
""" Converts a token id (or a sequence of id) in a token string
(or a sequence of tokens), using the vocabulary.
"""
ids = []
unk_id = vocab.get('[UNK]', None)
for token in tokens:
wid = vocab.get(token, unk_id)
if wid is not None:  # keep valid id 0; only skip tokens with no mapping at all
ids.append(wid)
return ids
def convert_example(example, vocab, unk_token_id=1, is_test=False):
"""
Builds model inputs from a sequence for sequence classification tasks.
It uses the module-level tokenizer (jieba by default) to tokenize the text.
Args:
example(obj:`list[str]`): List of input data, containing the text and the label if it has one.
vocab(obj:`dict`): The vocabulary.
unk_token_id(obj:`int`, defaults to 1): The unknown token id.
is_test(obj:`bool`, defaults to `False`): Whether the example is a test example without a label.
Returns:
input_ids(obj:`list[int]`): The list of token ids.
valid_length(obj:`int`): The input sequence valid length.
label(obj:`numpy.array`, data type of int64, optional): The input label if not is_test.
"""
input_ids = []
for token in tokenizer.cut(example[0]):
token_id = vocab.get(token, unk_token_id)
input_ids.append(token_id)
valid_length = len(input_ids)
if not is_test:
label = np.array(example[-1], dtype="int64")
return input_ids, valid_length, label
else:
return input_ids, valid_length
def pad_texts_to_max_seq_len(texts, max_seq_len, pad_token_id=0):
"""
Pads each text to the max sequence length if it is shorter than that length;
otherwise truncates it to the max sequence length.
Args:
texts(obj:`list`): Texts which contains a sequence of word ids.
max_seq_len(obj:`int`): Max sequence length.
pad_token_id(obj:`int`, optional, defaults to 0) : The pad token index.
"""
for index, text in enumerate(texts):
seq_len = len(text)
if seq_len < max_seq_len:
padded_tokens = [pad_token_id for _ in range(max_seq_len - seq_len)]
new_text = text + padded_tokens
texts[index] = new_text
elif seq_len > max_seq_len:
new_text = text[:max_seq_len]
texts[index] = new_text
def generate_batch(batch, pad_token_id=0, return_label=True):
"""
Generates a batch whose text will be padded to the max sequence length in the batch.
Args:
batch(obj:`List[Example]`) : One batch, which contains texts, labels and the true sequence lengths.
pad_token_id(obj:`int`, optional, defaults to 0) : The pad token index.
Returns:
batch(:obj:`Tuple[list]`): The batch data which contains texts, seq_lens and labels.
"""
seq_lens = [entry[1] for entry in batch]
batch_max_seq_len = max(seq_lens)
texts = [entry[0] for entry in batch]
pad_texts_to_max_seq_len(texts, batch_max_seq_len, pad_token_id)
if return_label:
labels = [[entry[-1]] for entry in batch]
return texts, seq_lens, labels
else:
return texts, seq_lens
def preprocess_prediction_data(data, vocab):
"""
It processes the prediction data into the same format as used in training.
Args:
data (obj:`List[str]`): The prediction data; each element is a raw text string.
Returns:
examples (obj:`list`): The processed data; each element is a pair of
`word_ids` and `seq_len` (the sequence length).
"""
examples = []
for text in data:
tokens = " ".join(tokenizer.cut(text)).split(' ')
ids = convert_tokens_to_ids(tokens, vocab)
examples.append([ids, len(ids)])
return examples
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import argparse
import os
import os.path as osp
import paddle
import paddle.nn as nn
import paddlenlp as nlp
from paddlenlp.datasets import ChnSentiCorp
from paddlenlp.embeddings import TokenEmbedding
from paddlenlp.data import JiebaTokenizer, Vocab
import data
parser = argparse.ArgumentParser(__doc__)
parser.add_argument(
"--epochs", type=int, default=5, help="Number of epoches for training.")
parser.add_argument(
'--use_gpu',
type=eval,
default=True,
help="Whether use GPU for training, input should be True or False")
parser.add_argument(
"--lr", type=float, default=5e-4, help="Learning rate used to train.")
parser.add_argument(
"--save_dir",
type=str,
default='checkpoints/',
help="Directory to save model checkpoint")
parser.add_argument(
"--batch_size",
type=int,
default=64,
help="Total examples' number of a batch for training.")
parser.add_argument(
"--vocab_path",
type=str,
default="./dict.txt",
help="The directory to dataset.")
parser.add_argument(
"--init_from_ckpt",
type=str,
default=None,
help="The path of checkpoint to be loaded.")
parser.add_argument(
"--use_token_embedding",
type=eval,
default=True,
help="Whether use pretrained embedding")
parser.add_argument(
"--embedding_name",
type=str,
default="w2v.baidu_encyclopedia.target.word-word.dim300",
help="The name of pretrained embedding")
parser.add_argument(
"--vdl_dir", type=str, default="vdl_dir/", help="VisualDL log directory")
args = parser.parse_args()
def create_dataloader(dataset,
trans_fn=None,
mode='train',
batch_size=1,
use_gpu=False,
pad_token_id=0):
"""
Creates a dataloader.
Args:
dataset(obj:`paddle.io.Dataset`): Dataset instance.
trans_fn(obj:`callable`, optional, defaults to `None`): The function applied to each example before batching.
mode(obj:`str`, optional, defaults to obj:`train`): If mode is 'train', it will shuffle the dataset randomly.
batch_size(obj:`int`, optional, defaults to 1): The sample number of a mini-batch.
use_gpu(obj:`bool`, optional, defaults to obj:`False`): Whether to use gpu to run.
pad_token_id(obj:`int`, optional, defaults to 0): The pad token index.
Returns:
dataloader(obj:`paddle.io.DataLoader`): The dataloader which generates batches.
"""
if trans_fn:
dataset = dataset.apply(trans_fn, lazy=True)
shuffle = True if mode == 'train' else False
sampler = paddle.io.BatchSampler(
dataset=dataset, batch_size=batch_size, shuffle=shuffle)
dataloader = paddle.io.DataLoader(
dataset,
batch_sampler=sampler,
return_list=True,
collate_fn=lambda batch: data.generate_batch(batch, pad_token_id=pad_token_id))
return dataloader
class BoWModel(nn.Layer):
"""
This class implements the Bag of Words Classification Network model to classify texts.
At a high level, the model starts by embedding the tokens with a word embedding layer.
Then, we encode these representations with a `BoWEncoder`.
Lastly, we take the output of the encoder to create a final representation,
which is passed through some feed-forward layers to output the logits (`output_layer`).
Args:
vocab_size (obj:`int`): The vocabulary size.
emb_dim (obj:`int`, optional, defaults to 300): The embedding dimension.
hidden_size (obj:`int`, optional, defaults to 128): The hidden size of the first fully connected layer.
fc_hidden_size (obj:`int`, optional, defaults to 96): The hidden size of the second fully connected layer.
num_classes (obj:`int`): The number of classes.
"""
def __init__(self,
vocab_size,
num_classes,
vocab_path,
emb_dim=300,
hidden_size=128,
fc_hidden_size=96,
use_token_embedding=True):
super().__init__()
if use_token_embedding:
self.embedder = TokenEmbedding(
args.embedding_name, extended_vocab_path=vocab_path)
emb_dim = self.embedder.embedding_dim
else:
padding_idx = vocab_size - 1
self.embedder = nn.Embedding(
vocab_size, emb_dim, padding_idx=padding_idx)
self.bow_encoder = nlp.seq2vec.BoWEncoder(emb_dim)
self.fc1 = nn.Linear(self.bow_encoder.get_output_dim(), hidden_size)
self.fc2 = nn.Linear(hidden_size, fc_hidden_size)
self.dropout = nn.Dropout(p=0.3, axis=1)
self.output_layer = nn.Linear(fc_hidden_size, num_classes)
def forward(self, text, seq_len=None):
# Shape: (batch_size, num_tokens, embedding_dim)
embedded_text = self.embedder(text)
# Shape: (batch_size, embedding_dim)
summed = self.bow_encoder(embedded_text)
summed = self.dropout(summed)
encoded_text = paddle.tanh(summed)
# Shape: (batch_size, hidden_size)
fc1_out = paddle.tanh(self.fc1(encoded_text))
# Shape: (batch_size, fc_hidden_size)
fc2_out = paddle.tanh(self.fc2(fc1_out))
# Shape: (batch_size, num_classes)
logits = self.output_layer(fc2_out)
return logits
if __name__ == '__main__':
paddle.set_device('gpu') if args.use_gpu else paddle.set_device('cpu')
# Loads vocab.
if not os.path.exists(args.vocab_path):
raise RuntimeError('The vocabulary file cannot be found at path %s' %
args.vocab_path)
vocab = data.load_vocab(args.vocab_path)
if '[PAD]' not in vocab:
vocab['[PAD]'] = len(vocab)
# Loads dataset.
train_ds, dev_ds, test_ds = ChnSentiCorp.get_datasets(
['train', 'dev', 'test'])
# Constructs the network.
num_classes = len(train_ds.get_labels())
model = BoWModel(
vocab_size=len(vocab),
num_classes=num_classes,
vocab_path=args.vocab_path,
use_token_embedding=args.use_token_embedding)
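# When the pretrained TokenEmbedding is used, adopt its (extended) vocabulary so that the
# tokenizer and the token-to-id mapping stay consistent with the embedding matrix; otherwise
# wrap the raw vocab dict in a paddlenlp Vocab for the tokenizer.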
if args.use_token_embedding:
vocab = model.embedder.vocab
data.set_tokenizer(vocab)
vocab = vocab.token_to_idx
else:
v = Vocab.from_dict(vocab, unk_token="[UNK]", pad_token="[PAD]")
data.set_tokenizer(v)
model = paddle.Model(model)
# Reads data and generates mini-batches.
trans_fn = partial(
data.convert_example,
vocab=vocab,
unk_token_id=vocab['[UNK]'],
is_test=False)
train_loader = create_dataloader(
train_ds,
trans_fn=trans_fn,
batch_size=args.batch_size,
mode='train',
pad_token_id=vocab['[PAD]'])
dev_loader = create_dataloader(
dev_ds,
trans_fn=trans_fn,
batch_size=args.batch_size,
mode='validation',
pad_token_id=vocab['[PAD]'])
test_loader = create_dataloader(
test_ds,
trans_fn=trans_fn,
batch_size=args.batch_size,
mode='test',
pad_token_id=vocab['[PAD]'])
optimizer = paddle.optimizer.Adam(
parameters=model.parameters(), learning_rate=args.lr)
# Defines loss and metric.
criterion = paddle.nn.CrossEntropyLoss()
metric = paddle.metric.Accuracy()
model.prepare(optimizer, criterion, metric)
# Loads pre-trained parameters.
if args.init_from_ckpt:
model.load(args.init_from_ckpt)
print("Loaded checkpoint from %s" % args.init_from_ckpt)
# Starts training and evaluating.
log_dir = 'use_normal_embedding'
if args.use_token_embedding:
log_dir = 'use_token_embedding'
log_dir = osp.join(args.vdl_dir, log_dir)
callback = paddle.callbacks.VisualDL(log_dir=log_dir)
model.fit(train_loader,
dev_loader,
epochs=args.epochs,
save_dir=args.save_dir,
callbacks=callback)
# Finally, evaluates the model on the test set.
results = model.evaluate(test_loader, callbacks=callback)
print("Finally test acc: %.5f" % results['acc'])
@@ -15,13 +15,63 @@
from enum import Enum
import os.path as osp
URL_ROOT = "https://bj.bcebos.com/paddlenlp"
URL_ROOT = "https://paddlenlp.bj.bcebos.com"
EMBEDDING_URL_ROOT = osp.join(URL_ROOT, "models/embeddings")
PAD_TOKEN = '[PAD]'
UNK_TOKEN = '[UNK]'
PAD_IDX = 0
UNK_IDX = 1
EMBEDDING_NAME_LIST = ["w2v.baidu_encyclopedia.target.word-word.dim300"]
EMBEDDING_NAME_LIST = [
# baidu_encyclopedia
"w2v.baidu_encyclopedia.target.word-word.dim300",
"w2v.baidu_encyclopedia.target.word-character.char1-1.dim300",
"w2v.baidu_encyclopedia.target.word-character.char1-2.dim300",
"w2v.baidu_encyclopedia.target.word-character.char1-4.dim300",
"w2v.baidu_encyclopedia.target.word-ngram.1-2.dim300",
"w2v.baidu_encyclopedia.target.bigram-char.dim300",
"w2v.baidu_encyclopedia.context.word-word.dim300",
"w2v.baidu_encyclopedia.context.word-character.char1-1.dim300",
"w2v.baidu_encyclopedia.context.word-character.char1-2.dim300",
"w2v.baidu_encyclopedia.context.word-character.char1-4.dim300",
# wikipedia
"w2v.wiki.target.bigram-char.dim300",
"w2v.wiki.target.word-char.dim300",
"w2v.wiki.target.word-word.dim300",
"w2v.wiki.target.word-bigram.dim300",
# people_daily
"w2v.people_daily.target.bigram-char.dim300",
"w2v.people_daily.target.word-char.dim300",
"w2v.people_daily.target.word-word.dim300",
"w2v.people_daily.target.word-bigram.dim300",
# weibo
"w2v.weibo.target.bigram-char.dim300",
"w2v.weibo.target.word-char.dim300",
"w2v.weibo.target.word-word.dim300",
"w2v.weibo.target.word-bigram.dim300",
# sogou
"w2v.sogou.target.bigram-char.dim300",
"w2v.sogou.target.word-char.dim300",
"w2v.sogou.target.word-word.dim300",
"w2v.sogou.target.word-bigram.dim300",
# zhihu
"w2v.zhihu.target.bigram-char.dim300",
"w2v.zhihu.target.word-char.dim300",
"w2v.zhihu.target.word-word.dim300",
"w2v.zhihu.target.word-bigram.dim300",
# financial
"w2v.financial.target.bigram-char.dim300",
"w2v.financial.target.word-char.dim300",
"w2v.financial.target.word-word.dim300",
"w2v.financial.target.word-bigram.dim300",
# literature
"w2v.literature.target.bigram-char.dim300",
"w2v.literature.target.word-char.dim300",
"w2v.literature.target.word-word.dim300",
"w2v.literature.target.word-bigram.dim300",
# siku
"w2v.sikuquanshu.target.word-word.dim300",
"w2v.sikuquanshu.target.word-bigram.dim300",
# Mix-large
"w2v.mixed-large.target.word-char.dim300",
"w2v.mixed-large.target.word-word.dim300"
]
@@ -24,8 +24,8 @@ from paddle.utils.download import get_path_from_url
from paddlenlp.utils.env import _get_sub_home, MODEL_HOME
from paddlenlp.utils.log import logger
from paddlenlp.data import Vocab, get_idx_from_word
from .constant import EMBEDDING_URL_ROOT, PAD_TOKEN, UNK_TOKEN, PAD_IDX, \
UNK_IDX, EMBEDDING_NAME_LIST
from .constant import EMBEDDING_URL_ROOT, PAD_TOKEN, UNK_TOKEN,\
EMBEDDING_NAME_LIST
EMBEDDING_HOME = _get_sub_home('embeddings', parent_home=MODEL_HOME)
@@ -83,6 +83,15 @@ class TokenEmbedding(nn.Embedding):
self.weight.set_value(embedding_table)
self.set_trainable(trainable)
logger.info("Finish loading embedding vector.")
s = "Token Embedding brief:\
\nUnknown index: {}\
\nUnknown token: {}\
\nPadding index: {}\
\nPadding token: {}\
\nShape: {}".format(
self._word_to_idx[self.unknown_token], self.unknown_token,
self._word_to_idx[PAD_TOKEN], PAD_TOKEN, self.weight.shape)
logger.info(s)
def _init_without_extend_vocab(self, vector_np, pad_vector, unk_vector):
self._idx_to_word = list(vector_np['vocab'])
@@ -100,10 +109,7 @@ class TokenEmbedding(nn.Embedding):
vocab_list = []
with open(extended_vocab_path, "r", encoding="utf-8") as f:
for line in f.readlines():
line = line.strip()
if line == "":
break
vocab = line.split()[0]
vocab = line.rstrip("\n").split("\t")[0]
vocab_list.append(vocab)
return vocab_list
@@ -162,9 +168,12 @@ class TokenEmbedding(nn.Embedding):
unk_idx = self._word_to_idx[self.unknown_token]
embedding_table[unk_idx] = unk_vector
self._idx_to_word.append(PAD_TOKEN)
self._word_to_idx[PAD_TOKEN] = len(self._idx_to_word) - 1
embedding_table = np.append(embedding_table, [pad_vector], axis=0)
if PAD_TOKEN not in extend_vocab_set:
self._idx_to_word.append(PAD_TOKEN)
self._word_to_idx[PAD_TOKEN] = len(self._idx_to_word) - 1
embedding_table = np.append(embedding_table, [pad_vector], axis=0)
else:
embedding_table[self._word_to_idx[PAD_TOKEN]] = pad_vector
logger.info("Finish extending vocab.")
return embedding_table