Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
6f8e0ab5
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6f8e0ab5
编写于
6月 18, 2021
作者:
K
kinghuin
提交者:
GitHub
6月 18, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix ernie_gen evaluation bug
上级
cf5bb9f1
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
79 addition
and
80 deletion
+79
-80
modules/text/text_generation/ernie_gen/README.md
modules/text/text_generation/ernie_gen/README.md
+2
-2
modules/text/text_generation/ernie_gen/module.py
modules/text/text_generation/ernie_gen/module.py
+68
-69
modules/text/text_generation/ernie_gen_acrostic_poetry/README.md
.../text/text_generation/ernie_gen_acrostic_poetry/README.md
+2
-2
modules/text/text_generation/ernie_gen_couplet/README.md
modules/text/text_generation/ernie_gen_couplet/README.md
+2
-2
modules/text/text_generation/ernie_gen_lover_words/README.md
modules/text/text_generation/ernie_gen_lover_words/README.md
+2
-2
modules/text/text_generation/ernie_gen_poetry/README.md
modules/text/text_generation/ernie_gen_poetry/README.md
+3
-3
未找到文件。
modules/text/text_generation/ernie_gen/README.md
浏览文件 @
6f8e0ab5
...
...
@@ -170,9 +170,9 @@ https://github.com/PaddlePaddle/ERNIE/blob/repro/ernie-gen/
### 依赖
paddlepaddle >=
1.8.2
paddlepaddle >=
2.0.0
paddlehub >=
1.7
.0
paddlehub >=
2.0
.0
paddlenlp >= 2.0.0
...
...
modules/text/text_generation/ernie_gen/module.py
浏览文件 @
6f8e0ab5
...
...
@@ -42,7 +42,7 @@ from .model import StackModel
author_email
=
""
,
type
=
"nlp/text_generation"
,
)
class
ErnieGen
(
hub
.
Module
):
class
ErnieGen
():
def
__init__
(
self
):
"""
initialize with the necessary elements
...
...
@@ -59,25 +59,25 @@ class ErnieGen(hub.Module):
return
self
.
_model
def
finetune
(
self
,
train_path
,
dev_path
=
None
,
save_dir
=
"ernie_gen_result"
,
init_ckpt_path
=
None
,
use_gpu
=
True
,
max_steps
=
500
,
batch_size
=
8
,
max_encode_len
=
50
,
max_decode_len
=
50
,
learning_rate
=
5e-5
,
warmup_proportion
=
0.1
,
weight_decay
=
0.1
,
noise_prob
=
0
,
label_smooth
=
0
,
beam_width
=
5
,
length_penalty
=
1.0
,
log_interval
=
100
,
save_interval
=
200
,
self
,
train_path
,
dev_path
=
None
,
save_dir
=
"ernie_gen_result"
,
init_ckpt_path
=
None
,
use_gpu
=
True
,
max_steps
=
500
,
batch_size
=
8
,
max_encode_len
=
50
,
max_decode_len
=
50
,
learning_rate
=
5e-5
,
warmup_proportion
=
0.1
,
weight_decay
=
0.1
,
noise_prob
=
0
,
label_smooth
=
0
,
beam_width
=
5
,
length_penalty
=
1.0
,
log_interval
=
100
,
save_interval
=
200
,
):
"""
finetune with the specified dataset.
...
...
@@ -119,13 +119,12 @@ class ErnieGen(hub.Module):
train_dataset
=
self
.
_load_dataset
(
train_path
)
attn_id
=
self
.
tokenizer
.
vocab
[
'[MASK]'
]
trans_func
=
convert_example
(
tokenizer
=
self
.
tokenizer
,
attn_id
=
attn_id
,
tgt_type_id
=
1
,
max_encode_len
=
max_encode_len
,
max_decode_len
=
max_decode_len
,
noise_prob
=
noise_prob
)
trans_func
=
convert_example
(
tokenizer
=
self
.
tokenizer
,
attn_id
=
attn_id
,
tgt_type_id
=
1
,
max_encode_len
=
max_encode_len
,
max_decode_len
=
max_decode_len
,
noise_prob
=
noise_prob
)
train_dataset
=
train_dataset
.
map
(
trans_func
)
train_batch_sampler
=
paddle
.
io
.
BatchSampler
(
train_dataset
,
batch_size
=
batch_size
,
shuffle
=
True
)
...
...
@@ -139,18 +138,20 @@ class ErnieGen(hub.Module):
Pad
(
axis
=
0
,
pad_val
=
self
.
tokenizer
.
pad_token_id
),
# attn_ids
Pad
(
axis
=
0
,
pad_val
=
self
.
tokenizer
.
pad_token_id
),
# tgt_labels
):
after_padding
(
fn
(
samples
))
train_data_loader
=
DataLoader
(
dataset
=
train_dataset
,
batch_sampler
=
train_batch_sampler
,
collate_fn
=
batchify_fn
,
num_workers
=
0
,
return_list
=
True
)
train_data_loader
=
DataLoader
(
dataset
=
train_dataset
,
batch_sampler
=
train_batch_sampler
,
collate_fn
=
batchify_fn
,
num_workers
=
0
,
return_list
=
True
)
if
dev_path
:
dev_dataset
=
self
.
_load_dataset
(
dev_path
)
dev_dataset
=
dev_dataset
.
map
(
trans_func
)
dev_data_loader
=
DataLoader
(
dataset
=
dev_dataset
,
batch_size
=
batch_size
,
collate_fn
=
batchify_fn
,
num_workers
=
0
,
return_list
=
True
)
dev_data_loader
=
DataLoader
(
dataset
=
dev_dataset
,
batch_size
=
batch_size
,
collate_fn
=
batchify_fn
,
num_workers
=
0
,
return_list
=
True
)
label_num
=
self
.
model
.
word_emb
.
weight
.
shape
[
0
]
train_model
=
StackModel
(
self
.
model
)
...
...
@@ -158,12 +159,11 @@ class ErnieGen(hub.Module):
# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
decay_params
=
[
p
.
name
for
n
,
p
in
self
.
model
.
named_parameters
()
if
not
any
(
nd
in
n
for
nd
in
[
"bias"
,
"norm"
])]
optimizer
=
paddle
.
optimizer
.
AdamW
(
learning_rate
=
lr_scheduler
,
parameters
=
self
.
model
.
parameters
(),
weight_decay
=
weight_decay
,
grad_clip
=
nn
.
ClipGradByGlobalNorm
(
1.0
),
apply_decay_param_fun
=
lambda
x
:
x
in
decay_params
)
optimizer
=
paddle
.
optimizer
.
AdamW
(
learning_rate
=
lr_scheduler
,
parameters
=
self
.
model
.
parameters
(),
weight_decay
=
weight_decay
,
grad_clip
=
nn
.
ClipGradByGlobalNorm
(
1.0
),
apply_decay_param_fun
=
lambda
x
:
x
in
decay_params
)
rouge1
=
Rouge1
()
rouge2
=
Rouge2
()
...
...
@@ -175,8 +175,8 @@ class ErnieGen(hub.Module):
(
src_ids
,
src_tids
,
src_pids
,
tgt_ids
,
tgt_tids
,
tgt_pids
,
attn_ids
,
mask_src_2_src
,
mask_tgt_2_srctgt
,
mask_attn_2_srctgtattn
,
tgt_labels
,
_
)
=
batch
if
label_smooth
>
0.
:
tgt_labels
=
nn
.
functional
.
label_smooth
(
nn
.
functional
.
one_hot
(
tgt_labels
,
label_num
),
epsilon
=
label_smooth
)
tgt_labels
=
nn
.
functional
.
label_smooth
(
nn
.
functional
.
one_hot
(
tgt_labels
,
label_num
),
epsilon
=
label_smooth
)
tgt_pos
=
paddle
.
nonzero
(
attn_ids
==
attn_id
)
loss
=
train_model
(
src_ids
,
src_tids
,
src_pids
,
tgt_ids
,
tgt_tids
,
tgt_pids
,
attn_ids
,
mask_src_2_src
,
...
...
@@ -190,8 +190,8 @@ class ErnieGen(hub.Module):
if
global_step
%
log_interval
==
0
and
paddle
.
distributed
.
get_rank
()
==
0
:
loss_np
=
loss
.
numpy
()
ppl
=
np
.
exp
(
loss_np
)
logger
.
info
(
'[step %d / %d]train loss %.5f, ppl %.5f, elr %.3e'
%
(
global_step
,
max_steps
,
loss_np
,
ppl
,
lr_scheduler
.
get_lr
()))
logger
.
info
(
'[step %d / %d]train loss %.5f, ppl %.5f, elr %.3e'
%
(
global_step
,
max_steps
,
loss_np
,
ppl
,
lr_scheduler
.
get_lr
()))
if
save_dir
and
global_step
%
save_interval
==
0
and
global_step
>
0
:
loss_np
=
loss
.
numpy
()
ppl
=
np
.
exp
(
loss_np
)
...
...
@@ -214,8 +214,8 @@ class ErnieGen(hub.Module):
if
global_step
%
save_interval
!=
0
:
loss_np
=
loss
.
numpy
()
ppl
=
np
.
exp
(
loss_np
)
logger
.
info
(
'[final step %d]train loss %.5f, ppl %.5f, elr %.3e'
%
(
global_step
,
loss_np
,
ppl
,
lr_scheduler
.
get_lr
()))
logger
.
info
(
'[final step %d]train loss %.5f, ppl %.5f, elr %.3e'
%
(
global_step
,
loss_np
,
ppl
,
lr_scheduler
.
get_lr
()))
if
save_dir
:
save_name
=
"step_%s_ppl_%.5f.pdparams"
%
(
global_step
,
ppl
)
save_path
=
os
.
path
.
join
(
save_dir
,
save_name
)
...
...
@@ -291,6 +291,7 @@ class ErnieGen(hub.Module):
def
_evaluate
(
self
,
model
,
data_loader
,
tokenizer
,
rouge1
,
rouge2
,
attn_id
,
max_decode_len
,
max_encode_len
,
beam_width
,
length_penalty
):
paddle
.
disable_static
()
model
.
eval
()
vocab
=
tokenizer
.
vocab
...
...
@@ -305,21 +306,20 @@ class ErnieGen(hub.Module):
for
data
in
data_loader
:
(
src_ids
,
src_tids
,
src_pids
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
raw_tgt_labels
)
=
data
# never use target when infer
# Use greedy_search_infilling or beam_search_infilling to get predictions
output_ids
=
beam_search_infilling
(
model
,
src_ids
,
src_tids
,
eos_id
=
eos_id
,
sos_id
=
sos_id
,
attn_id
=
attn_id
,
pad_id
=
pad_id
,
unk_id
=
unk_id
,
vocab_size
=
vocab_size
,
max_decode_len
=
max_decode_len
,
max_encode_len
=
max_encode_len
,
beam_width
=
beam_width
,
length_penalty
=
length_penalty
,
tgt_type_id
=
1
)
output_ids
=
beam_search_infilling
(
model
,
src_ids
,
src_tids
,
eos_id
=
eos_id
,
sos_id
=
sos_id
,
attn_id
=
attn_id
,
pad_id
=
pad_id
,
unk_id
=
unk_id
,
vocab_size
=
vocab_size
,
max_decode_len
=
max_decode_len
,
max_encode_len
=
max_encode_len
,
beam_width
=
beam_width
,
length_penalty
=
length_penalty
,
tgt_type_id
=
1
)
for
ids
in
output_ids
.
tolist
():
if
eos_id
in
ids
:
...
...
@@ -361,11 +361,10 @@ class ErnieGen(hub.Module):
if
__name__
==
"__main__"
:
module
=
ErnieGen
()
result
=
module
.
finetune
(
train_path
=
'test_data/train.txt'
,
dev_path
=
'test_data/dev.txt'
,
max_steps
=
30
,
batch_size
=
2
,
log_interval
=
10
,
save_interval
=
20
)
result
=
module
.
finetune
(
train_path
=
'test_data/train.txt'
,
dev_path
=
'test_data/dev.txt'
,
max_steps
=
30
,
batch_size
=
2
,
log_interval
=
10
,
save_interval
=
20
)
module
.
export
(
params_path
=
result
[
'last_save_path'
],
module_name
=
"ernie_gen_test"
,
author
=
"test"
)
modules/text/text_generation/ernie_gen_acrostic_poetry/README.md
浏览文件 @
6f8e0ab5
...
...
@@ -99,9 +99,9 @@ https://github.com/PaddlePaddle/ERNIE/blob/repro/ernie-gen/
### 依赖
paddlepaddle >=
1.8.2
paddlepaddle >=
2.0.0
paddlehub >=
1.7
.0
paddlehub >=
2.0
.0
paddlenlp >= 2.0.0
...
...
modules/text/text_generation/ernie_gen_couplet/README.md
浏览文件 @
6f8e0ab5
...
...
@@ -87,9 +87,9 @@ https://github.com/PaddlePaddle/ERNIE/blob/repro/ernie-gen/
### 依赖
paddlepaddle >=
1.8.2
paddlepaddle >=
2.0.0
paddlehub >=
1.7
.0
paddlehub >=
2.0
.0
paddlenlp >= 2.0.0
...
...
modules/text/text_generation/ernie_gen_lover_words/README.md
浏览文件 @
6f8e0ab5
...
...
@@ -87,9 +87,9 @@ https://github.com/PaddlePaddle/ERNIE/blob/repro/ernie-gen/
### 依赖
paddlepaddle >=
1.8.2
paddlepaddle >=
2.0.0
paddlehub >=
1.7
.0
paddlehub >=
2.0
.0
paddlenlp >= 2.0.0
...
...
modules/text/text_generation/ernie_gen_poetry/README.md
浏览文件 @
6f8e0ab5
...
...
@@ -87,11 +87,11 @@ https://github.com/PaddlePaddle/ERNIE/blob/repro/ernie-gen/
### 依赖
paddlepaddle >=
1.8.2
paddlepaddle >=
2.0.0
paddlehub >=
1.7
.0
paddlehub >=
2.0
.0
PaddleNLP
>= 2.0.0
paddlenlp
>= 2.0.0
## 更新历史
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录