Unverified · Commit 047b8b69 · Authored by: Jack Zhou · Committed by: GitHub

Add English vectors for the embedding

* add more English embedding names

* fix doc bug

* delete useless description

* add comments for TokenEmbedding

* add embedding model info
Parent 69ccb4c4
- [Embedding Model Overview](#embedding-model-overview)
- [Chinese Word Vectors](#chinese-word-vectors)
- [English Word Vectors](#english-word-vectors)
  - [GloVe](#glove)
  - [FastText](#fasttext)
- [Model Information](#model-information)
- [Acknowledgements](#acknowledgements)
- [References](#references)
# Embedding Model Overview
PaddleNLP provides a number of open-source pretrained embedding models. To load one, simply pass the name of the pretrained model when constructing `paddlenlp.embeddings.TokenEmbedding`. The models listed below are the pretrained embedding models supported by PaddleNLP; their names are used as the argument to `paddlenlp.embeddings.TokenEmbedding`. The naming convention is \${training model}.\${corpus}.\${vector type}.\${co-occurrence type}.dim\${dimension}. Three training models are available: Word2Vec (w2v, trained with the skip-gram model), GloVe (glove), and FastText (fasttext).
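
For example, a minimal usage sketch (assuming paddlenlp is installed and the chosen model can be downloaded on first use):

```python
from paddlenlp.embeddings import TokenEmbedding

# Load a pretrained model by its name; the weights are downloaded and cached
# the first time the name is used.
embedding = TokenEmbedding(embedding_name="w2v.baidu_encyclopedia.target.word-word.dim300")

# Look up the vectors of a few words; the result is a numpy array of shape (3, 300).
vectors = embedding.search(["中国", "旅游", "苹果"])
print(vectors.shape)
```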
@@ -42,11 +51,91 @@
## English Word Vectors
### GloVe
| Corpus | 25-dim | 50-dim | 100-dim | 200-dim | 300-dim |
| ------------------- | ------ | ------ | ------ | ------ | ------ |
| Wiki2014 + GigaWord | N/A | glove.wiki2014-gigaword.target.word-word.dim50.en | glove.wiki2014-gigaword.target.word-word.dim100.en | glove.wiki2014-gigaword.target.word-word.dim200.en | glove.wiki2014-gigaword.target.word-word.dim300.en |
| Twitter | glove.twitter.target.word-word.dim25.en | glove.twitter.target.word-word.dim50.en | glove.twitter.target.word-word.dim100.en | glove.twitter.target.word-word.dim200.en | N/A |
### FastText
| Corpus | Name |
|------|------|
| Wiki2017 | fasttext.wiki-news.target.word-word.dim300.en |
| Crawl | fasttext.crawl.target.word-word.dim300.en |
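
The English models are loaded the same way, by passing one of the names from the tables above. A small sketch (again assuming the model can be downloaded in your environment):

```python
from paddlenlp.embeddings import TokenEmbedding

# 50-dimensional GloVe vectors trained on Wiki2014 + GigaWord.
glove = TokenEmbedding(embedding_name="glove.wiki2014-gigaword.target.word-word.dim50.en")

# Cosine similarity between two English words from the GloVe vocabulary.
print(glove.cosine_sim("king", "queen"))
```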
## Model Information
| Model | File Size | Vocabulary Size |
|-----|---------|---------|
| w2v.baidu_encyclopedia.target.word-word.dim300 | 678.21 MB | 635965 |
| w2v.baidu_encyclopedia.target.word-character.char1-1.dim300 | 679.15 MB | 636038 |
| w2v.baidu_encyclopedia.target.word-character.char1-2.dim300 | 679.30 MB | 636038 |
| w2v.baidu_encyclopedia.target.word-character.char1-4.dim300 | 679.51 MB | 636038 |
| w2v.baidu_encyclopedia.target.word-ngram.1-2.dim300 | 679.48 MB | 635977 |
| w2v.baidu_encyclopedia.target.word-ngram.1-3.dim300 | 671.27 MB | 628669 |
| w2v.baidu_encyclopedia.target.word-ngram.2-2.dim300 | 7.28 GB | 6969069 |
| w2v.baidu_encyclopedia.target.word-wordLR.dim300 | 678.22 MB | 635958 |
| w2v.baidu_encyclopedia.target.word-wordPosition.dim300 | 679.32 MB | 636038 |
| w2v.baidu_encyclopedia.target.bigram-char.dim300 | 679.29 MB | 635976 |
| w2v.baidu_encyclopedia.context.word-word.dim300 | 677.74 MB | 635952 |
| w2v.baidu_encyclopedia.context.word-character.char1-1.dim300 | 678.65 MB | 636200 |
| w2v.baidu_encyclopedia.context.word-character.char1-2.dim300 | 844.23 MB | 792631 |
| w2v.baidu_encyclopedia.context.word-character.char1-4.dim300 | 1.16 GB | 1117461 |
| w2v.baidu_encyclopedia.context.word-ngram.1-2.dim300 | 7.25 GB | 6967598 |
| w2v.baidu_encyclopedia.context.word-ngram.1-3.dim300 | 5.21 GB | 5000001 |
| w2v.baidu_encyclopedia.context.word-ngram.2-2.dim300 | 7.26 GB | 6968998 |
| w2v.baidu_encyclopedia.context.word-wordLR.dim300 | 1.32 GB | 1271031 |
| w2v.baidu_encyclopedia.context.word-wordPosition.dim300 | 6.47 GB | 6293920 |
| w2v.wiki.target.bigram-char.dim300 | 375.98 MB | 352274 |
| w2v.wiki.target.word-char.dim300 | 375.52 MB | 352223 |
| w2v.wiki.target.word-word.dim300 | 374.95 MB | 352219 |
| w2v.wiki.target.word-bigram.dim300 | 375.72 MB | 352219 |
| w2v.people_daily.target.bigram-char.dim300 | 379.96 MB | 356055 |
| w2v.people_daily.target.word-char.dim300 | 379.45 MB | 355998 |
| w2v.people_daily.target.word-word.dim300 | 378.93 MB | 355989 |
| w2v.people_daily.target.word-bigram.dim300 | 379.68 MB | 355991 |
| w2v.weibo.target.bigram-char.dim300 | 208.24 MB | 195199 |
| w2v.weibo.target.word-char.dim300 | 208.03 MB | 195204 |
| w2v.weibo.target.word-word.dim300 | 207.94 MB | 195204 |
| w2v.weibo.target.word-bigram.dim300 | 208.19 MB | 195204 |
| w2v.sogou.target.bigram-char.dim300 | 389.81 MB | 365112 |
| w2v.sogou.target.word-char.dim300 | 389.89 MB | 365078 |
| w2v.sogou.target.word-word.dim300 | 388.66 MB | 364992 |
| w2v.sogou.target.word-bigram.dim300 | 388.66 MB | 364994 |
| w2v.zhihu.target.bigram-char.dim300 | 277.35 MB | 259755 |
| w2v.zhihu.target.word-char.dim300 | 277.40 MB | 259940 |
| w2v.zhihu.target.word-word.dim300 | 276.98 MB | 259871 |
| w2v.zhihu.target.word-bigram.dim300 | 277.53 MB | 259885 |
| w2v.financial.target.bigram-char.dim300 | 499.52 MB | 467163 |
| w2v.financial.target.word-char.dim300 | 499.17 MB | 467343 |
| w2v.financial.target.word-word.dim300 | 498.94 MB | 467324 |
| w2v.financial.target.word-bigram.dim300 | 499.54 MB | 467331 |
| w2v.literature.target.bigram-char.dim300 | 200.69 MB | 187975 |
| w2v.literature.target.word-char.dim300 | 200.44 MB | 187980 |
| w2v.literature.target.word-word.dim300 | 200.28 MB | 187961 |
| w2v.literature.target.word-bigram.dim300 | 200.59 MB | 187962 |
| w2v.sikuquanshu.target.word-word.dim300 | 20.70 MB | 19529 |
| w2v.sikuquanshu.target.word-bigram.dim300 | 20.77 MB | 19529 |
| w2v.mixed-large.target.word-char.dim300 | 1.35 GB | 1292552 |
| w2v.mixed-large.target.word-word.dim300 | 1.35 GB | 1292483 |
| glove.wiki2014-gigaword.target.word-word.dim50.en | 73.45 MB | 400002 |
| glove.wiki2014-gigaword.target.word-word.dim100.en | 143.30 MB | 400002 |
| glove.wiki2014-gigaword.target.word-word.dim200.en | 282.97 MB | 400002 |
| glove.wiki2014-gigaword.target.word-word.dim300.en | 422.83 MB | 400002 |
| glove.twitter.target.word-word.dim25.en | 116.92 MB | 1193516 |
| glove.twitter.target.word-word.dim50.en | 221.64 MB | 1193516 |
| glove.twitter.target.word-word.dim100.en | 431.08 MB | 1193516 |
| glove.twitter.target.word-word.dim200.en | 848.56 MB | 1193516 |
| fasttext.wiki-news.target.word-word.dim300.en | 541.63 MB | 999996 |
| fasttext.crawl.target.word-word.dim300.en | 1.19 GB | 2000002 |
## Acknowledgements
- Thanks to [Chinese-Word-Vectors](https://github.com/Embedding/Chinese-Word-Vectors) for the pretrained Chinese Word2Vec embeddings, to the [GloVe Project](https://nlp.stanford.edu/projects/glove) for the pretrained English GloVe embeddings, and to the [FastText Project](https://fasttext.cc/docs/en/english-vectors.html) for the pretrained English FastText embeddings.
## References
- Li, Shen, et al. "Analogical reasoning on chinese morphological and semantic relations." arXiv preprint arXiv:1805.06504 (2018).
- Qiu, Yuanyuan, et al. "Revisiting correlations between intrinsic and extrinsic evaluations of word embeddings." Chinese Computational Linguistics and Natural Language Processing Based on Naturally Annotated Big Data. Springer, Cham, 2018. 209-221.
- Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. GloVe: Global Vectors for Word Representation.
- T. Mikolov, E. Grave, P. Bojanowski, C. Puhrsch, A. Joulin. Advances in Pre-Training Distributed Word Representations.
@@ -2,7 +2,7 @@
## Introduction
PaddleNLP ships with several public pretrained embeddings. Users can load a pretrained embedding through the `paddlenlp.embeddings.TokenEmbedding` interface to improve training results. The text-classification example below demonstrates the improvement that `paddlenlp.embeddings.TokenEmbedding` brings to training.
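
As background for the quick start below, the following sketch (a simplified illustration only, not the actual train.py; the vocabulary size and model structure are placeholders) shows the basic idea: `paddlenlp.embeddings.TokenEmbedding` is a subclass of `paddle.nn.Embedding`, so it can stand in for a randomly initialized embedding layer inside a model.

```python
import paddle
import paddle.nn as nn
from paddlenlp.embeddings import TokenEmbedding


class BoWClassifier(nn.Layer):
    """A tiny bag-of-words classifier used only to illustrate the embedding swap."""

    def __init__(self, num_classes=2, use_token_embedding=True):
        super().__init__()
        if use_token_embedding:
            # Pretrained vectors loaded by name.
            self.embedder = TokenEmbedding(
                embedding_name="w2v.baidu_encyclopedia.target.word-word.dim300")
            emb_dim = self.embedder.embedding_dim
        else:
            # Randomly initialized vectors; 10000 is a placeholder vocabulary size.
            emb_dim = 300
            self.embedder = nn.Embedding(num_embeddings=10000, embedding_dim=emb_dim)
        self.fc = nn.Linear(emb_dim, num_classes)

    def forward(self, token_ids):
        embedded = self.embedder(token_ids)      # (batch_size, seq_len, emb_dim)
        pooled = paddle.mean(embedded, axis=1)   # simple bag-of-words pooling
        return self.fc(pooled)                   # (batch_size, num_classes)
```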
## Quick Start
@@ -13,17 +13,14 @@
This project requires PaddlePaddle 2.0 or later. Please follow the [installation guide](http://www.paddlepaddle.org/#quick-start) to install it.
* Environment requirements
  - python >= 3.6
  - paddlepaddle-gpu >= 2.0.0-rc1

```
pip install paddlenlp==2.0.0b
```
### Download the Vocabulary
@@ -35,24 +32,27 @@ wget https://paddlenlp.bj.bcebos.com/data/dict.txt
### Start Training
We use the public Chinese sentiment classification dataset ChnSentiCorp as the example dataset. Run the commands below to train the model on the training set (train.tsv) and evaluate it on the validation set (dev.tsv).
Run on CPU:
```
# Use paddlenlp.embeddings.TokenEmbedding
python train.py --vocab_path='./dict.txt' --use_gpu=False --lr=5e-4 --batch_size=64 --epochs=20 --use_token_embedding=True --vdl_dir='./vdl_dir'

# Use paddle.nn.Embedding
python train.py --vocab_path='./dict.txt' --use_gpu=False --lr=1e-4 --batch_size=64 --epochs=20 --use_token_embedding=False --vdl_dir='./vdl_dir'
```
Run on GPU:
```
export CUDA_VISIBLE_DEVICES=0
# Use paddlenlp.embeddings.TokenEmbedding
python train.py --vocab_path='./dict.txt' --use_gpu=True --lr=5e-4 --batch_size=64 --epochs=20 --use_token_embedding=True --vdl_dir='./vdl_dir'

# Use paddle.nn.Embedding
python train.py --vocab_path='./dict.txt' --use_gpu=True --lr=1e-4 --batch_size=64 --epochs=20 --use_token_embedding=False --vdl_dir='./vdl_dir'
```
The parameters above mean:
@@ -62,7 +62,7 @@
* `lr`: Learning rate. Defaults to 5e-4.
* `batch_size`: Batch size. Defaults to 64.
* `epochs`: Number of training epochs. Defaults to 5.
* `use_token_embedding`: Whether to use `paddlenlp.embeddings.TokenEmbedding`. Defaults to True.
* `vdl_dir`: VisualDL log directory; VisualDL information produced during training is saved there. Defaults to `./vdl_dir`.
The script also provides the following parameters:
@@ -76,14 +76,14 @@
We recommend using VisualDL to compare experiments. The command below starts VisualDL; the directory passed to the logdir argument must be the same as the `vdl_dir` specified when training was started. (For more on VisualDL, see the [VisualDL user guide](https://github.com/PaddlePaddle/VisualDL#2-launch-panel).)
```
visualdl --logdir ./vdl_dir --port 8888 --host 0.0.0.0
```
### Training Results Comparison
Open `ip:8888` in a Chrome browser, where ip is the IP address of the machine running VisualDL.
The comparison of the example experiments is shown below: blue is the run that uses `paddlenlp.embeddings.TokenEmbedding`, and green is the run whose embedding is not initialized from a pretrained model. The run with `paddlenlp.embeddings.TokenEmbedding` shows an upward validation accuracy trend that converges to around 0.90 and stays relatively stable afterwards without obvious overfitting, while the run without it shows a downward validation accuracy trend that converges to around 0.86. The example experiment shows that using `paddlenlp.embeddings.TokenEmbedding` improves training results.
Eval Acc:
@@ -95,8 +95,10 @@ Eval Acc:
| paddlenlp.embeddings.TokenEmbedding | 0.9082 |
## Acknowledgements
- Thanks to [Chinese-Word-Vectors](https://github.com/Embedding/Chinese-Word-Vectors) for the pretrained Chinese Word2Vec embeddings, to the [GloVe Project](https://nlp.stanford.edu/projects/glove) for the pretrained English GloVe embeddings, and to the [FastText Project](https://fasttext.cc/docs/en/english-vectors.html) for the pretrained English FastText embeddings.
## References
- Li, Shen, et al. "Analogical reasoning on chinese morphological and semantic relations." arXiv preprint arXiv:1805.06504 (2018).
- Qiu, Yuanyuan, et al. "Revisiting correlations between intrinsic and extrinsic evaluations of word embeddings." Chinese Computational Linguistics and Natural Language Processing Based on Naturally Annotated Big Data. Springer, Cham, 2018. 209-221.
- Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. GloVe: Global Vectors for Word Representation.
- T. Mikolov, E. Grave, P. Bojanowski, C. Puhrsch, A. Joulin. Advances in Pre-Training Distributed Word Representations.
@@ -83,5 +83,18 @@ EMBEDDING_NAME_LIST = [
"w2v.sikuquanshu.target.word-bigram.dim300", "w2v.sikuquanshu.target.word-bigram.dim300",
# Mix-large # Mix-large
"w2v.mixed-large.target.word-char.dim300", "w2v.mixed-large.target.word-char.dim300",
"w2v.mixed-large.target.word-word.dim300" "w2v.mixed-large.target.word-word.dim300",
# GloVe
"glove.wiki2014-gigaword.target.word-word.dim50.en",
"glove.wiki2014-gigaword.target.word-word.dim100.en",
"glove.wiki2014-gigaword.target.word-word.dim200.en",
"glove.wiki2014-gigaword.target.word-word.dim300.en",
"glove.twitter.target.word-word.dim25.en",
"glove.twitter.target.word-word.dim50.en",
"glove.twitter.target.word-word.dim100.en",
"glove.twitter.target.word-word.dim200.en",
# FastText
"fasttext.wiki-news.target.word-word.dim300.en",
"fasttext.crawl.target.word-word.dim300.en"
] ]
@@ -33,10 +33,33 @@ __all__ = ['list_embedding_name', 'TokenEmbedding']
def list_embedding_name():
    """
    List the names of all pretrained embedding models that paddlenlp provides.
    """
    return list(EMBEDDING_NAME_LIST)
class TokenEmbedding(nn.Embedding):
    """
    A `TokenEmbedding` loads a pretrained embedding model provided by paddlenlp when
    the embedding name is specified. It can also load an extended vocabulary by
    specifying `extended_vocab_path`.
    Args:
        embedding_name (object: `str`, optional, defaults to `w2v.baidu_encyclopedia.target.word-word.dim300`):
            The name of the pretrained embedding model. Use
            `paddlenlp.embeddings.list_embedding_name()` to list the embedding models
            that paddlenlp already provides.
        unknown_token (object: `str`, optional, defaults to `[UNK]`):
            The token used to represent unknown words.
        unknown_token_vector (object: `list`, optional, defaults to `None`):
            The initial vector of the unknown token. If `None`, the vector is
            initialized from a normal distribution.
        extended_vocab_path (object: `str`, optional, defaults to `None`):
            The file path of the extended vocabulary.
        trainable (object: `bool`, optional, defaults to `True`):
            Whether the embedding weight is trainable.
    """
    def __init__(self,
                 embedding_name=EMBEDDING_NAME_LIST[0],
                 unknown_token=UNK_TOKEN,
@@ -49,7 +72,7 @@ class TokenEmbedding(nn.Embedding):
            url = osp.join(EMBEDDING_URL_ROOT, embedding_name + ".tar.gz")
            get_path_from_url(url, EMBEDDING_HOME)
        logger.info("Loading token embedding...")
        vector_np = np.load(vector_path)
        self.embedding_dim = vector_np['embedding'].shape[1]
        self.unknown_token = unknown_token
@@ -81,7 +104,7 @@ class TokenEmbedding(nn.Embedding):
        self.weight.set_value(embedding_table)
        self.set_trainable(trainable)
        logger.info("Finish loading embedding vector.")
        s = "Token Embedding info:\
\nUnknown index: {}\
\nUnknown token: {}\
\nPadding index: {}\
@@ -92,6 +115,9 @@ class TokenEmbedding(nn.Embedding):
        logger.info(s)
    def _init_without_extend_vocab(self, vector_np, pad_vector, unk_vector):
        """
        Construct the index-to-word list, the word-to-index dict and the embedding weight.
        """
        self._idx_to_word = list(vector_np['vocab'])
        self._idx_to_word.append(self.unknown_token)
        self._idx_to_word.append(PAD_TOKEN)
@@ -113,6 +139,10 @@ class TokenEmbedding(nn.Embedding):
    def _extend_vocab(self, extended_vocab_path, vector_np, pad_vector,
                      unk_vector):
        """
        Construct the index-to-word list, the word-to-index dict and the embedding weight
        using the extended vocabulary.
        """
        logger.info("Start extending vocab.")
        extend_vocab_list = self._read_vocab_list_from_file(extended_vocab_path)
        extend_vocab_set = set(extend_vocab_list)
@@ -182,18 +212,37 @@ class TokenEmbedding(nn.Embedding):
        return embedding_table
    def set_trainable(self, trainable):
        """
        Set whether the embedding weight is trainable.
        Args:
            trainable (object: `bool`, required):
                Whether the embedding weight can be trained.
        """
        self.weight.stop_gradient = not trainable
    def search(self, words):
        """
        Get the vectors of the specified words.
        Args:
            words (object: `list` or `str` or `int`, required): The words to look up.
        Returns:
            word_vector (object: `numpy.array`): The vectors of the specified words.
        """
        idx_list = self.get_idx_list_from_words(words)
        idx_tensor = paddle.to_tensor(idx_list)
        return self(idx_tensor).numpy()
    def get_idx_from_word(self, word):
        """
        Get the index of the specified word by looking it up in the word_to_idx dict.
        """
        return get_idx_from_word(word, self.vocab.token_to_idx,
                                 self.unknown_token)
    def get_idx_list_from_words(self, words):
        """
        Get the indices of the specified words by looking them up in the word_to_idx dict.
        """
        if isinstance(words, str):
            idx_list = [self.get_idx_from_word(words)]
        elif isinstance(words, int):
@@ -217,23 +266,50 @@ class TokenEmbedding(nn.Embedding):
        return calc_kernel(embedding_a, embedding_b)
    def dot(self, word_a, word_b):
        """
        Calculate the scalar (dot) product of two words' vectors.
        Args:
            word_a (object: `str`, required): The first word.
            word_b (object: `str`, required): The second word.
        Returns:
            The scalar product of the two word vectors.
        """
        dot = self._dot_np
        return self._calc_word(word_a, word_b, lambda x, y: dot(x, y))
    def cosine_sim(self, word_a, word_b):
        """
        Calculate the cosine similarity of two words' vectors.
        Args:
            word_a (object: `str`, required): The first word.
            word_b (object: `str`, required): The second word.
        Returns:
            The cosine similarity of the two word vectors.
        """
        dot = self._dot_np
        return self._calc_word(
            word_a, word_b,
            lambda x, y: dot(x, y) / (np.sqrt(dot(x, x)) * np.sqrt(dot(y, y))))
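
    # Usage sketch for the two similarity helpers above (illustrative only; the example
    # words are assumed to be in the loaded vocabulary):
    #
    #     embedding = TokenEmbedding("w2v.baidu_encyclopedia.target.word-word.dim300")
    #     embedding.cosine_sim("苹果", "香蕉")   # cosine similarity of the two word vectors
    #     embedding.dot("苹果", "香蕉")          # dot product of the two word vectors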
    def _construct_word_to_idx(self, idx_to_word):
        """
        Construct the word-to-index dict.
        Args:
            idx_to_word (object: `list`, required): The index-to-word list.
        Returns:
            word_to_idx (object: `dict`): The word-to-index dict built from idx_to_word.
        """
        word_to_idx = {}
        for i, word in enumerate(idx_to_word):
            word_to_idx[word] = i
        return word_to_idx
    def __repr__(self):
        """
        Returns:
            info (object: `str`): The token embedding information.
        """
        info = "Object type: {}\
\nUnknown index: {}\
\nUnknown token: {}\
\nPadding index: {}\
@@ -242,4 +318,4 @@ class TokenEmbedding(nn.Embedding):
            super(TokenEmbedding, self).__repr__(),
            self._word_to_idx[self.unknown_token], self.unknown_token,
            self._word_to_idx[PAD_TOKEN], PAD_TOKEN, self.weight)
        return info