Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
89e72a1a
M
models
项目概览
PaddlePaddle
/
models
接近 2 年 前同步成功
通知
230
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
89e72a1a
编写于
12月 12, 2020
作者:
J
Jack Zhou
提交者:
GitHub
12月 12, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize TokenEmbedding (#5022)
* 1. move PAD to the end 2. Add Readme.md * optimize readme doc
上级
392ae994
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
112 addition
and
28 deletion
+112
-28
PaddleNLP/paddlenlp/embeddings/README.md
PaddleNLP/paddlenlp/embeddings/README.md
+91
-7
PaddleNLP/paddlenlp/embeddings/token_embedding.py
PaddleNLP/paddlenlp/embeddings/token_embedding.py
+21
-21
未找到文件。
PaddleNLP/paddlenlp/embeddings/README.md
浏览文件 @
89e72a1a
# paddlenlp.embeddings
# paddlenlp.embeddings
##
Embedding快速复用热启
##
TokenEmbedding参数
初定三个模型的Embedding数据,SimNet,word2vec,FastText
| 参数 | 类型 | 属性 |
| ------------ | ------------ | ------------ |
| embedding_name |
**string**
| 预训练embedding名称,可通过paddlenlp.embeddings.list_embedding_name查询。 |
| unknown_token |
**string**
| unknown token。 |
| unknown_token_vector |
**list**
或者
**np.array**
| 用来初始化unknown token对应的vector。默认为None(以正态分布方式初始化vector)|
| extended_vocab_path |
**string**
| 扩展词表的文件名路径。词表格式为一行一个词。 |
| trainable |
**bool**
| 是否可训练。True表示Embedding可以更新参数,False为不可更新。 |
使用LAC切词+大规模中文语料快速训练多个中文的embedding,注意筛选高质量词表
## 初始化
```
python
import
paddle
from
paddlenlp.embeddings
import
TokenEmbedding
,
list_embedding_name
paddle
.
set_device
(
"cpu"
)
*
SimNet 大搜数据中文
# 查看预训练embedding名称:
*
word2vec 中英文
print
(
list_embedding_name
())
# ['w2v.baidu_encyclopedia.target.word-word.dim300']
*
fasttext 中英文
## 再提供Fleet的word2vec训练入口
# 初始化TokenEmbedding, 预训练embedding没下载时会自动下载并加载数据
\ No newline at end of file
token_embedding
=
TokenEmbedding
(
embedding_name
=
"w2v.baidu_encyclopedia.target.word-word.dim300"
)
# 查看token_embedding详情
print
(
token_embedding
)
Object
type
:
<
paddlenlp
.
embeddings
.
token_embedding
.
TokenEmbedding
object
at
0x7f67fd192e30
>
Unknown
index
:
1
Unknown
token
:
[
UNK
]
Padding
index
:
0
Padding
token
:
[
PAD
]
Parameter
containing
:
Tensor
(
shape
=
[
636015
,
300
],
dtype
=
float32
,
place
=
CPUPlace
,
stop_gradient
=
False
,
[[
0.
,
0.
,
0.
,
...,
0.
,
0.
,
0.
],
[
0.00372404
,
0.01534354
,
0.01341010
,
...,
-
0.00605236
,
-
0.02150303
,
0.02372430
],
[
-
0.24200200
,
0.13931701
,
0.07378800
,
...,
0.14103900
,
0.05592300
,
-
0.08004800
],
...,
[
0.01615800
,
-
0.00266300
,
-
0.00628300
,
...,
0.01484100
,
0.00196600
,
-
0.01032000
],
[
0.01705700
,
0.00040400
,
-
0.01222000
,
...,
0.02837200
,
0.02402500
,
-
0.00814800
],
[
0.02628800
,
-
0.00008300
,
-
0.00393500
,
...,
0.00654000
,
0.00024600
,
-
0.00662600
]])
```
## 查询embedding结果
```
python
test_token_embedding
=
token_embedding
.
search
(
"中国"
)
print
(
test_token_embedding
)
[[
0.260801
0.1047
0.129453
-
0.257317
-
0.16152
0.19567
-
0.074868
0.361168
0.245882
-
0.219141
-
0.388083
0.235189
0.029316
0.154215
-
0.354343
0.017746
0.009028
0.01197
-
0.121429
0.096542
0.009255
...,
-
0.260592
-
0.019668
-
0.063312
-
0.094939
0.657352
0.247547
-
0.161621
0.289043
-
0.284084
0.205076
0.059885
0.055871
0.159309
0.062181
0.123634
0.282932
0.140399
-
0.076253
-
0.087103
0.07262
]]
```
## 计算词向量cosine相似度
```
python
score
=
token_embedding
.
cosine_sim
(
"中国"
,
"美国"
)
print
(
score
)
# 0.49586025
```
## 计算词向量内积
```
python
score
=
token_embedding
.
dot
(
"中国"
,
"美国"
)
print
(
score
)
# 8.611071
```
## 训练
```
python
in_words
=
paddle
.
to_tensor
([
0
,
2
,
3
])
input_embeddings
=
token_embedding
(
in_words
)
linear
=
paddle
.
nn
.
Linear
(
token_embedding
.
embedding_dim
,
20
)
input_fc
=
linear
(
input_embeddings
)
print
(
input_fc
)
Tensor
(
shape
=
[
3
,
20
],
dtype
=
float32
,
place
=
CPUPlace
,
stop_gradient
=
False
,
[[
0.
,
0.
,
0.
,
...,
0.
,
0.
,
0.
],
[
-
0.23473957
,
0.17878169
,
0.07215232
,
...,
0.03698236
,
0.14291850
,
0.05136518
],
[
-
0.42466098
,
0.15017235
,
-
0.04780108
,
...,
-
0.04995505
,
0.15847842
,
0.00025209
]])
```
## 切词
```
python
from
paddlenlp.data
import
JiebaTokenizer
tokenizer
=
JiebaTokenizer
(
vocab
=
token_embedding
.
vocab
)
words
=
tokenizer
.
cut
(
"中国人民"
)
print
(
words
)
# ['中国人', '民']
tokens
=
tokenizer
.
encode
(
"中国人民"
)
print
(
tokens
)
# [12532, 1336]
```
PaddleNLP/paddlenlp/embeddings/token_embedding.py
浏览文件 @
89e72a1a
...
@@ -86,13 +86,13 @@ class TokenEmbedding(nn.Embedding):
...
@@ -86,13 +86,13 @@ class TokenEmbedding(nn.Embedding):
def
_init_without_extend_vocab
(
self
,
vector_np
,
pad_vector
,
unk_vector
):
def
_init_without_extend_vocab
(
self
,
vector_np
,
pad_vector
,
unk_vector
):
self
.
_idx_to_word
=
list
(
vector_np
[
'vocab'
])
self
.
_idx_to_word
=
list
(
vector_np
[
'vocab'
])
self
.
_idx_to_word
.
insert
(
PAD_IDX
,
PAD_TOKEN
)
self
.
_idx_to_word
.
append
(
self
.
unknown_token
)
self
.
_idx_to_word
.
insert
(
UNK_IDX
,
self
.
unknown_token
)
self
.
_idx_to_word
.
append
(
PAD_TOKEN
)
self
.
_word_to_idx
=
self
.
_construct_word_to_idx
(
self
.
_idx_to_word
)
self
.
_word_to_idx
=
self
.
_construct_word_to_idx
(
self
.
_idx_to_word
)
# insert unk, pad embedding
# insert unk, pad embedding
embedding_table
=
np
.
insert
(
embedding_table
=
np
.
append
(
vector_np
[
'embedding'
],
[
0
],
[
pad_vector
,
unk_vector
],
vector_np
[
'embedding'
],
[
unk_vector
,
pad_vector
],
axis
=
0
)
axis
=
0
).
astype
(
paddle
.
get_default_dtype
())
return
embedding_table
return
embedding_table
def
_read_vocab_list_from_file
(
self
,
extended_vocab_path
):
def
_read_vocab_list_from_file
(
self
,
extended_vocab_path
):
...
@@ -114,23 +114,12 @@ class TokenEmbedding(nn.Embedding):
...
@@ -114,23 +114,12 @@ class TokenEmbedding(nn.Embedding):
extend_vocab_set
=
set
(
extend_vocab_list
)
extend_vocab_set
=
set
(
extend_vocab_list
)
# update idx_to_word
# update idx_to_word
self
.
_idx_to_word
=
extend_vocab_list
self
.
_idx_to_word
=
extend_vocab_list
self
.
_word_to_idx
=
self
.
_construct_word_to_idx
(
self
.
_idx_to_word
)
embedding_table
=
np
.
random
.
normal
(
embedding_table
=
np
.
random
.
normal
(
scale
=
0.02
,
scale
=
0.02
,
size
=
(
len
(
self
.
_idx_to_word
),
size
=
(
len
(
self
.
_idx_to_word
),
self
.
embedding_dim
)).
astype
(
paddle
.
get_default_dtype
())
self
.
embedding_dim
)).
astype
(
paddle
.
get_default_dtype
())
self
.
_idx_to_word
.
append
(
PAD_TOKEN
)
embedding_table
=
np
.
append
(
embedding_table
,
[
pad_vector
],
axis
=
0
)
if
self
.
unknown_token
not
in
extend_vocab_set
:
self
.
_idx_to_word
.
append
(
self
.
unknown_token
)
embedding_table
=
np
.
append
(
embedding_table
,
[
unk_vector
],
axis
=
0
)
self
.
_word_to_idx
=
self
.
_construct_word_to_idx
(
self
.
_idx_to_word
)
else
:
self
.
_word_to_idx
=
self
.
_construct_word_to_idx
(
self
.
_idx_to_word
)
unk_idx
=
self
.
_word_to_idx
[
self
.
unknown_token
]
embedding_table
[
unk_idx
]
=
unk_vector
pretrained_idx_to_word
=
list
(
vector_np
[
'vocab'
])
pretrained_idx_to_word
=
list
(
vector_np
[
'vocab'
])
pretrained_word_to_idx
=
self
.
_construct_word_to_idx
(
pretrained_word_to_idx
=
self
.
_construct_word_to_idx
(
pretrained_idx_to_word
)
pretrained_idx_to_word
)
...
@@ -165,6 +154,18 @@ class TokenEmbedding(nn.Embedding):
...
@@ -165,6 +154,18 @@ class TokenEmbedding(nn.Embedding):
pretrained_embedding_table
[
pretrained_vocab_subtract_index
],
pretrained_embedding_table
[
pretrained_vocab_subtract_index
],
axis
=
0
)
axis
=
0
)
if
self
.
unknown_token
not
in
extend_vocab_set
:
self
.
_idx_to_word
.
append
(
self
.
unknown_token
)
self
.
_word_to_idx
[
self
.
unknown_token
]
=
len
(
self
.
_idx_to_word
)
-
1
embedding_table
=
np
.
append
(
embedding_table
,
[
unk_vector
],
axis
=
0
)
else
:
unk_idx
=
self
.
_word_to_idx
[
self
.
unknown_token
]
embedding_table
[
unk_idx
]
=
unk_vector
self
.
_idx_to_word
.
append
(
PAD_TOKEN
)
self
.
_word_to_idx
[
PAD_TOKEN
]
=
len
(
self
.
_idx_to_word
)
-
1
embedding_table
=
np
.
append
(
embedding_table
,
[
pad_vector
],
axis
=
0
)
logger
.
info
(
"Finish extending vocab."
)
logger
.
info
(
"Finish extending vocab."
)
return
embedding_table
return
embedding_table
...
@@ -221,13 +222,12 @@ class TokenEmbedding(nn.Embedding):
...
@@ -221,13 +222,12 @@ class TokenEmbedding(nn.Embedding):
def
__repr__
(
self
):
def
__repr__
(
self
):
s
=
"Object type: {}
\
s
=
"Object type: {}
\
\n
Padding index: {}
\
\n
Padding token: {}
\
\n
Unknown index: {}
\
\n
Unknown index: {}
\
\n
Unknown token: {}
\
\n
Unknown token: {}
\
\n
Padding index: {}
\
\n
Padding token: {}
\
\n
{}"
.
format
(
\n
{}"
.
format
(
super
(
TokenEmbedding
,
self
).
__repr__
(),
super
(
TokenEmbedding
,
self
).
__repr__
(),
self
.
_word_to_idx
[
PAD_TOKEN
],
PAD_TOKEN
,
self
.
_word_to_idx
[
self
.
unknown_token
],
self
.
unknown_token
,
self
.
_word_to_idx
[
self
.
unknown_token
],
self
.
unknown_token
,
self
.
weight
)
self
.
_word_to_idx
[
PAD_TOKEN
],
PAD_TOKEN
,
self
.
weight
)
return
s
return
s
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录