Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
cf5bb9f1
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cf5bb9f1
编写于
6月 17, 2021
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix ernie_gen* bug.
上级
a7c06ac9
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
168 addition
and
157 deletion
+168
-157
modules/text/text_generation/ernie_gen/module.py
modules/text/text_generation/ernie_gen/module.py
+70
-67
modules/text/text_generation/ernie_gen_acrostic_poetry/module.py
.../text/text_generation/ernie_gen_acrostic_poetry/module.py
+26
-24
modules/text/text_generation/ernie_gen_couplet/module.py
modules/text/text_generation/ernie_gen_couplet/module.py
+24
-22
modules/text/text_generation/ernie_gen_lover_words/module.py
modules/text/text_generation/ernie_gen_lover_words/module.py
+24
-22
modules/text/text_generation/ernie_gen_poetry/module.py
modules/text/text_generation/ernie_gen_poetry/module.py
+24
-22
未找到文件。
modules/text/text_generation/ernie_gen/module.py
浏览文件 @
cf5bb9f1
...
@@ -43,7 +43,7 @@ from .model import StackModel
...
@@ -43,7 +43,7 @@ from .model import StackModel
type
=
"nlp/text_generation"
,
type
=
"nlp/text_generation"
,
)
)
class
ErnieGen
(
hub
.
Module
):
class
ErnieGen
(
hub
.
Module
):
def
_
initialize
(
self
):
def
_
_init__
(
self
):
"""
"""
initialize with the necessary elements
initialize with the necessary elements
"""
"""
...
@@ -109,6 +109,7 @@ class ErnieGen(hub.Module):
...
@@ -109,6 +109,7 @@ class ErnieGen(hub.Module):
last_ppl(float): last model ppl.
last_ppl(float): last model ppl.
}
}
"""
"""
paddle
.
disable_static
()
paddle
.
set_device
(
'gpu'
)
if
use_gpu
else
paddle
.
set_device
(
'cpu'
)
paddle
.
set_device
(
'gpu'
)
if
use_gpu
else
paddle
.
set_device
(
'cpu'
)
if
init_ckpt_path
is
not
None
:
if
init_ckpt_path
is
not
None
:
...
@@ -118,7 +119,8 @@ class ErnieGen(hub.Module):
...
@@ -118,7 +119,8 @@ class ErnieGen(hub.Module):
train_dataset
=
self
.
_load_dataset
(
train_path
)
train_dataset
=
self
.
_load_dataset
(
train_path
)
attn_id
=
self
.
tokenizer
.
vocab
[
'[MASK]'
]
attn_id
=
self
.
tokenizer
.
vocab
[
'[MASK]'
]
trans_func
=
convert_example
(
tokenizer
=
self
.
tokenizer
,
trans_func
=
convert_example
(
tokenizer
=
self
.
tokenizer
,
attn_id
=
attn_id
,
attn_id
=
attn_id
,
tgt_type_id
=
1
,
tgt_type_id
=
1
,
max_encode_len
=
max_encode_len
,
max_encode_len
=
max_encode_len
,
...
@@ -137,7 +139,8 @@ class ErnieGen(hub.Module):
...
@@ -137,7 +139,8 @@ class ErnieGen(hub.Module):
Pad
(
axis
=
0
,
pad_val
=
self
.
tokenizer
.
pad_token_id
),
# attn_ids
Pad
(
axis
=
0
,
pad_val
=
self
.
tokenizer
.
pad_token_id
),
# attn_ids
Pad
(
axis
=
0
,
pad_val
=
self
.
tokenizer
.
pad_token_id
),
# tgt_labels
Pad
(
axis
=
0
,
pad_val
=
self
.
tokenizer
.
pad_token_id
),
# tgt_labels
):
after_padding
(
fn
(
samples
))
):
after_padding
(
fn
(
samples
))
train_data_loader
=
DataLoader
(
dataset
=
train_dataset
,
train_data_loader
=
DataLoader
(
dataset
=
train_dataset
,
batch_sampler
=
train_batch_sampler
,
batch_sampler
=
train_batch_sampler
,
collate_fn
=
batchify_fn
,
collate_fn
=
batchify_fn
,
num_workers
=
0
,
num_workers
=
0
,
...
@@ -146,11 +149,8 @@ class ErnieGen(hub.Module):
...
@@ -146,11 +149,8 @@ class ErnieGen(hub.Module):
if
dev_path
:
if
dev_path
:
dev_dataset
=
self
.
_load_dataset
(
dev_path
)
dev_dataset
=
self
.
_load_dataset
(
dev_path
)
dev_dataset
=
dev_dataset
.
map
(
trans_func
)
dev_dataset
=
dev_dataset
.
map
(
trans_func
)
dev_data_loader
=
DataLoader
(
dataset
=
dev_dataset
,
dev_data_loader
=
DataLoader
(
batch_size
=
batch_size
,
dataset
=
dev_dataset
,
batch_size
=
batch_size
,
collate_fn
=
batchify_fn
,
num_workers
=
0
,
return_list
=
True
)
collate_fn
=
batchify_fn
,
num_workers
=
0
,
return_list
=
True
)
label_num
=
self
.
model
.
word_emb
.
weight
.
shape
[
0
]
label_num
=
self
.
model
.
word_emb
.
weight
.
shape
[
0
]
train_model
=
StackModel
(
self
.
model
)
train_model
=
StackModel
(
self
.
model
)
...
@@ -158,7 +158,8 @@ class ErnieGen(hub.Module):
...
@@ -158,7 +158,8 @@ class ErnieGen(hub.Module):
# Generate parameter names needed to perform weight decay.
# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
# All bias and LayerNorm parameters are excluded.
decay_params
=
[
p
.
name
for
n
,
p
in
self
.
model
.
named_parameters
()
if
not
any
(
nd
in
n
for
nd
in
[
"bias"
,
"norm"
])]
decay_params
=
[
p
.
name
for
n
,
p
in
self
.
model
.
named_parameters
()
if
not
any
(
nd
in
n
for
nd
in
[
"bias"
,
"norm"
])]
optimizer
=
paddle
.
optimizer
.
AdamW
(
learning_rate
=
lr_scheduler
,
optimizer
=
paddle
.
optimizer
.
AdamW
(
learning_rate
=
lr_scheduler
,
parameters
=
self
.
model
.
parameters
(),
parameters
=
self
.
model
.
parameters
(),
weight_decay
=
weight_decay
,
weight_decay
=
weight_decay
,
grad_clip
=
nn
.
ClipGradByGlobalNorm
(
1.0
),
grad_clip
=
nn
.
ClipGradByGlobalNorm
(
1.0
),
...
@@ -174,8 +175,8 @@ class ErnieGen(hub.Module):
...
@@ -174,8 +175,8 @@ class ErnieGen(hub.Module):
(
src_ids
,
src_tids
,
src_pids
,
tgt_ids
,
tgt_tids
,
tgt_pids
,
attn_ids
,
mask_src_2_src
,
mask_tgt_2_srctgt
,
(
src_ids
,
src_tids
,
src_pids
,
tgt_ids
,
tgt_tids
,
tgt_pids
,
attn_ids
,
mask_src_2_src
,
mask_tgt_2_srctgt
,
mask_attn_2_srctgtattn
,
tgt_labels
,
_
)
=
batch
mask_attn_2_srctgtattn
,
tgt_labels
,
_
)
=
batch
if
label_smooth
>
0.
:
if
label_smooth
>
0.
:
tgt_labels
=
nn
.
functional
.
label_smooth
(
nn
.
functional
.
one_hot
(
tgt_labels
,
label_num
),
tgt_labels
=
nn
.
functional
.
label_smooth
(
epsilon
=
label_smooth
)
nn
.
functional
.
one_hot
(
tgt_labels
,
label_num
),
epsilon
=
label_smooth
)
tgt_pos
=
paddle
.
nonzero
(
attn_ids
==
attn_id
)
tgt_pos
=
paddle
.
nonzero
(
attn_ids
==
attn_id
)
loss
=
train_model
(
src_ids
,
src_tids
,
src_pids
,
tgt_ids
,
tgt_tids
,
tgt_pids
,
attn_ids
,
mask_src_2_src
,
loss
=
train_model
(
src_ids
,
src_tids
,
src_pids
,
tgt_ids
,
tgt_tids
,
tgt_pids
,
attn_ids
,
mask_src_2_src
,
...
@@ -189,8 +190,8 @@ class ErnieGen(hub.Module):
...
@@ -189,8 +190,8 @@ class ErnieGen(hub.Module):
if
global_step
%
log_interval
==
0
and
paddle
.
distributed
.
get_rank
()
==
0
:
if
global_step
%
log_interval
==
0
and
paddle
.
distributed
.
get_rank
()
==
0
:
loss_np
=
loss
.
numpy
()
loss_np
=
loss
.
numpy
()
ppl
=
np
.
exp
(
loss_np
)
ppl
=
np
.
exp
(
loss_np
)
logger
.
info
(
'[step %d / %d]train loss %.5f, ppl %.5f, elr %.3e'
%
logger
.
info
(
'[step %d / %d]train loss %.5f, ppl %.5f, elr %.3e'
%
(
global_step
,
max_steps
,
loss_np
,
(
global_step
,
max_steps
,
loss_np
,
ppl
,
lr_scheduler
.
get_lr
()))
ppl
,
lr_scheduler
.
get_lr
()))
if
save_dir
and
global_step
%
save_interval
==
0
and
global_step
>
0
:
if
save_dir
and
global_step
%
save_interval
==
0
and
global_step
>
0
:
loss_np
=
loss
.
numpy
()
loss_np
=
loss
.
numpy
()
ppl
=
np
.
exp
(
loss_np
)
ppl
=
np
.
exp
(
loss_np
)
...
@@ -213,8 +214,8 @@ class ErnieGen(hub.Module):
...
@@ -213,8 +214,8 @@ class ErnieGen(hub.Module):
if
global_step
%
save_interval
!=
0
:
if
global_step
%
save_interval
!=
0
:
loss_np
=
loss
.
numpy
()
loss_np
=
loss
.
numpy
()
ppl
=
np
.
exp
(
loss_np
)
ppl
=
np
.
exp
(
loss_np
)
logger
.
info
(
'[final step %d]train loss %.5f, ppl %.5f, elr %.3e'
%
logger
.
info
(
'[final step %d]train loss %.5f, ppl %.5f, elr %.3e'
%
(
global_step
,
loss_np
,
ppl
,
(
global_step
,
loss_np
,
ppl
,
lr_scheduler
.
get_lr
()))
lr_scheduler
.
get_lr
()))
if
save_dir
:
if
save_dir
:
save_name
=
"step_%s_ppl_%.5f.pdparams"
%
(
global_step
,
ppl
)
save_name
=
"step_%s_ppl_%.5f.pdparams"
%
(
global_step
,
ppl
)
save_path
=
os
.
path
.
join
(
save_dir
,
save_name
)
save_path
=
os
.
path
.
join
(
save_dir
,
save_name
)
...
@@ -304,7 +305,8 @@ class ErnieGen(hub.Module):
...
@@ -304,7 +305,8 @@ class ErnieGen(hub.Module):
for
data
in
data_loader
:
for
data
in
data_loader
:
(
src_ids
,
src_tids
,
src_pids
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
raw_tgt_labels
)
=
data
# never use target when infer
(
src_ids
,
src_tids
,
src_pids
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
raw_tgt_labels
)
=
data
# never use target when infer
# Use greedy_search_infilling or beam_search_infilling to get predictions
# Use greedy_search_infilling or beam_search_infilling to get predictions
output_ids
=
beam_search_infilling
(
model
,
output_ids
=
beam_search_infilling
(
model
,
src_ids
,
src_ids
,
src_tids
,
src_tids
,
eos_id
=
eos_id
,
eos_id
=
eos_id
,
...
@@ -359,7 +361,8 @@ class ErnieGen(hub.Module):
...
@@ -359,7 +361,8 @@ class ErnieGen(hub.Module):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
module
=
ErnieGen
()
module
=
ErnieGen
()
result
=
module
.
finetune
(
train_path
=
'test_data/train.txt'
,
result
=
module
.
finetune
(
train_path
=
'test_data/train.txt'
,
dev_path
=
'test_data/dev.txt'
,
dev_path
=
'test_data/dev.txt'
,
max_steps
=
30
,
max_steps
=
30
,
batch_size
=
2
,
batch_size
=
2
,
...
...
modules/text/text_generation/ernie_gen_acrostic_poetry/module.py
浏览文件 @
cf5bb9f1
...
@@ -39,7 +39,7 @@ from ernie_gen_acrostic_poetry.decode import beam_search_infilling
...
@@ -39,7 +39,7 @@ from ernie_gen_acrostic_poetry.decode import beam_search_infilling
type
=
"nlp/text_generation"
,
type
=
"nlp/text_generation"
,
)
)
class
ErnieGen
(
hub
.
NLPPredictionModule
):
class
ErnieGen
(
hub
.
NLPPredictionModule
):
def
_
initialize
(
self
,
line
=
4
,
word
=
7
):
def
_
_init__
(
self
,
line
=
4
,
word
=
7
):
"""
"""
initialize with the necessary elements
initialize with the necessary elements
"""
"""
...
@@ -73,14 +73,16 @@ class ErnieGen(hub.NLPPredictionModule):
...
@@ -73,14 +73,16 @@ class ErnieGen(hub.NLPPredictionModule):
Returns:
Returns:
results(list): the poetry continuations.
results(list): the poetry continuations.
"""
"""
paddle
.
disable_static
()
if
texts
and
isinstance
(
texts
,
list
)
and
all
(
texts
)
and
all
([
isinstance
(
text
,
str
)
for
text
in
texts
]):
if
texts
and
isinstance
(
texts
,
list
)
and
all
(
texts
)
and
all
([
isinstance
(
text
,
str
)
for
text
in
texts
]):
predicted_data
=
texts
predicted_data
=
texts
else
:
else
:
raise
ValueError
(
"The input texts should be a list with nonempty string elements."
)
raise
ValueError
(
"The input texts should be a list with nonempty string elements."
)
for
i
,
text
in
enumerate
(
texts
):
for
i
,
text
in
enumerate
(
texts
):
if
len
(
text
)
>
self
.
line
:
if
len
(
text
)
>
self
.
line
:
logger
.
warning
(
'The input text: %s, contains more than %i characters, which will be cut off'
%
logger
.
warning
(
(
text
,
self
.
line
))
'The input text: %s, contains more than %i characters, which will be cut off'
%
(
text
,
self
.
line
))
texts
[
i
]
=
text
[:
self
.
line
]
texts
[
i
]
=
text
[:
self
.
line
]
for
char
in
text
:
for
char
in
text
:
...
@@ -104,7 +106,8 @@ class ErnieGen(hub.NLPPredictionModule):
...
@@ -104,7 +106,8 @@ class ErnieGen(hub.NLPPredictionModule):
encode_text
=
self
.
tokenizer
.
encode
(
text
)
encode_text
=
self
.
tokenizer
.
encode
(
text
)
src_ids
=
paddle
.
to_tensor
(
encode_text
[
'input_ids'
]).
unsqueeze
(
0
)
src_ids
=
paddle
.
to_tensor
(
encode_text
[
'input_ids'
]).
unsqueeze
(
0
)
src_sids
=
paddle
.
to_tensor
(
encode_text
[
'token_type_ids'
]).
unsqueeze
(
0
)
src_sids
=
paddle
.
to_tensor
(
encode_text
[
'token_type_ids'
]).
unsqueeze
(
0
)
output_ids
=
beam_search_infilling
(
self
.
model
,
output_ids
=
beam_search_infilling
(
self
.
model
,
src_ids
,
src_ids
,
src_sids
,
src_sids
,
eos_id
=
self
.
tokenizer
.
vocab
[
'[SEP]'
],
eos_id
=
self
.
tokenizer
.
vocab
[
'[SEP]'
],
...
@@ -130,10 +133,8 @@ class ErnieGen(hub.NLPPredictionModule):
...
@@ -130,10 +133,8 @@ class ErnieGen(hub.NLPPredictionModule):
"""
"""
Add the command config options
Add the command config options
"""
"""
self
.
arg_config_group
.
add_argument
(
'--use_gpu'
,
self
.
arg_config_group
.
add_argument
(
type
=
ast
.
literal_eval
,
'--use_gpu'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether use GPU for prediction"
)
default
=
False
,
help
=
"whether use GPU for prediction"
)
self
.
arg_config_group
.
add_argument
(
'--beam_width'
,
type
=
int
,
default
=
5
,
help
=
"the beam search width"
)
self
.
arg_config_group
.
add_argument
(
'--beam_width'
,
type
=
int
,
default
=
5
,
help
=
"the beam search width"
)
...
@@ -142,7 +143,8 @@ class ErnieGen(hub.NLPPredictionModule):
...
@@ -142,7 +143,8 @@ class ErnieGen(hub.NLPPredictionModule):
"""
"""
Run as a command
Run as a command
"""
"""
self
.
parser
=
argparse
.
ArgumentParser
(
description
=
'Run the %s module.'
%
self
.
name
,
self
.
parser
=
argparse
.
ArgumentParser
(
description
=
'Run the %s module.'
%
self
.
name
,
prog
=
'hub run %s'
%
self
.
name
,
prog
=
'hub run %s'
%
self
.
name
,
usage
=
'%(prog)s'
,
usage
=
'%(prog)s'
,
add_help
=
True
)
add_help
=
True
)
...
...
modules/text/text_generation/ernie_gen_couplet/module.py
浏览文件 @
cf5bb9f1
...
@@ -39,7 +39,7 @@ from ernie_gen_couplet.decode import beam_search_infilling
...
@@ -39,7 +39,7 @@ from ernie_gen_couplet.decode import beam_search_infilling
type
=
"nlp/text_generation"
,
type
=
"nlp/text_generation"
,
)
)
class
ErnieGen
(
hub
.
NLPPredictionModule
):
class
ErnieGen
(
hub
.
NLPPredictionModule
):
def
_
initialize
(
self
):
def
_
_init__
(
self
):
"""
"""
initialize with the necessary elements
initialize with the necessary elements
"""
"""
...
@@ -67,6 +67,8 @@ class ErnieGen(hub.NLPPredictionModule):
...
@@ -67,6 +67,8 @@ class ErnieGen(hub.NLPPredictionModule):
Returns:
Returns:
results(list): the right rolls.
results(list): the right rolls.
"""
"""
paddle
.
disable_static
()
if
texts
and
isinstance
(
texts
,
list
)
and
all
(
texts
)
and
all
([
isinstance
(
text
,
str
)
for
text
in
texts
]):
if
texts
and
isinstance
(
texts
,
list
)
and
all
(
texts
)
and
all
([
isinstance
(
text
,
str
)
for
text
in
texts
]):
predicted_data
=
texts
predicted_data
=
texts
else
:
else
:
...
@@ -93,7 +95,8 @@ class ErnieGen(hub.NLPPredictionModule):
...
@@ -93,7 +95,8 @@ class ErnieGen(hub.NLPPredictionModule):
encode_text
=
self
.
tokenizer
.
encode
(
text
)
encode_text
=
self
.
tokenizer
.
encode
(
text
)
src_ids
=
paddle
.
to_tensor
(
encode_text
[
'input_ids'
]).
unsqueeze
(
0
)
src_ids
=
paddle
.
to_tensor
(
encode_text
[
'input_ids'
]).
unsqueeze
(
0
)
src_sids
=
paddle
.
to_tensor
(
encode_text
[
'token_type_ids'
]).
unsqueeze
(
0
)
src_sids
=
paddle
.
to_tensor
(
encode_text
[
'token_type_ids'
]).
unsqueeze
(
0
)
output_ids
=
beam_search_infilling
(
self
.
model
,
output_ids
=
beam_search_infilling
(
self
.
model
,
src_ids
,
src_ids
,
src_sids
,
src_sids
,
eos_id
=
self
.
tokenizer
.
vocab
[
'[SEP]'
],
eos_id
=
self
.
tokenizer
.
vocab
[
'[SEP]'
],
...
@@ -119,10 +122,8 @@ class ErnieGen(hub.NLPPredictionModule):
...
@@ -119,10 +122,8 @@ class ErnieGen(hub.NLPPredictionModule):
"""
"""
Add the command config options
Add the command config options
"""
"""
self
.
arg_config_group
.
add_argument
(
'--use_gpu'
,
self
.
arg_config_group
.
add_argument
(
type
=
ast
.
literal_eval
,
'--use_gpu'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether use GPU for prediction"
)
default
=
False
,
help
=
"whether use GPU for prediction"
)
self
.
arg_config_group
.
add_argument
(
'--beam_width'
,
type
=
int
,
default
=
5
,
help
=
"the beam search width"
)
self
.
arg_config_group
.
add_argument
(
'--beam_width'
,
type
=
int
,
default
=
5
,
help
=
"the beam search width"
)
...
@@ -131,7 +132,8 @@ class ErnieGen(hub.NLPPredictionModule):
...
@@ -131,7 +132,8 @@ class ErnieGen(hub.NLPPredictionModule):
"""
"""
Run as a command
Run as a command
"""
"""
self
.
parser
=
argparse
.
ArgumentParser
(
description
=
'Run the %s module.'
%
self
.
name
,
self
.
parser
=
argparse
.
ArgumentParser
(
description
=
'Run the %s module.'
%
self
.
name
,
prog
=
'hub run %s'
%
self
.
name
,
prog
=
'hub run %s'
%
self
.
name
,
usage
=
'%(prog)s'
,
usage
=
'%(prog)s'
,
add_help
=
True
)
add_help
=
True
)
...
...
modules/text/text_generation/ernie_gen_lover_words/module.py
浏览文件 @
cf5bb9f1
...
@@ -39,7 +39,7 @@ from ernie_gen_lover_words.decode import beam_search_infilling
...
@@ -39,7 +39,7 @@ from ernie_gen_lover_words.decode import beam_search_infilling
type
=
"nlp/text_generation"
,
type
=
"nlp/text_generation"
,
)
)
class
ErnieGen
(
hub
.
NLPPredictionModule
):
class
ErnieGen
(
hub
.
NLPPredictionModule
):
def
_
initialize
(
self
):
def
_
_init__
(
self
):
"""
"""
initialize with the necessary elements
initialize with the necessary elements
"""
"""
...
@@ -67,6 +67,8 @@ class ErnieGen(hub.NLPPredictionModule):
...
@@ -67,6 +67,8 @@ class ErnieGen(hub.NLPPredictionModule):
Returns:
Returns:
results(list): the poetry continuations.
results(list): the poetry continuations.
"""
"""
paddle
.
disable_static
()
if
texts
and
isinstance
(
texts
,
list
)
and
all
(
texts
)
and
all
([
isinstance
(
text
,
str
)
for
text
in
texts
]):
if
texts
and
isinstance
(
texts
,
list
)
and
all
(
texts
)
and
all
([
isinstance
(
text
,
str
)
for
text
in
texts
]):
predicted_data
=
texts
predicted_data
=
texts
else
:
else
:
...
@@ -85,7 +87,8 @@ class ErnieGen(hub.NLPPredictionModule):
...
@@ -85,7 +87,8 @@ class ErnieGen(hub.NLPPredictionModule):
encode_text
=
self
.
tokenizer
.
encode
(
text
)
encode_text
=
self
.
tokenizer
.
encode
(
text
)
src_ids
=
paddle
.
to_tensor
(
encode_text
[
'input_ids'
]).
unsqueeze
(
0
)
src_ids
=
paddle
.
to_tensor
(
encode_text
[
'input_ids'
]).
unsqueeze
(
0
)
src_sids
=
paddle
.
to_tensor
(
encode_text
[
'token_type_ids'
]).
unsqueeze
(
0
)
src_sids
=
paddle
.
to_tensor
(
encode_text
[
'token_type_ids'
]).
unsqueeze
(
0
)
output_ids
=
beam_search_infilling
(
self
.
model
,
output_ids
=
beam_search_infilling
(
self
.
model
,
src_ids
,
src_ids
,
src_sids
,
src_sids
,
eos_id
=
self
.
tokenizer
.
vocab
[
'[SEP]'
],
eos_id
=
self
.
tokenizer
.
vocab
[
'[SEP]'
],
...
@@ -111,10 +114,8 @@ class ErnieGen(hub.NLPPredictionModule):
...
@@ -111,10 +114,8 @@ class ErnieGen(hub.NLPPredictionModule):
"""
"""
Add the command config options
Add the command config options
"""
"""
self
.
arg_config_group
.
add_argument
(
'--use_gpu'
,
self
.
arg_config_group
.
add_argument
(
type
=
ast
.
literal_eval
,
'--use_gpu'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether use GPU for prediction"
)
default
=
False
,
help
=
"whether use GPU for prediction"
)
self
.
arg_config_group
.
add_argument
(
'--beam_width'
,
type
=
int
,
default
=
5
,
help
=
"the beam search width"
)
self
.
arg_config_group
.
add_argument
(
'--beam_width'
,
type
=
int
,
default
=
5
,
help
=
"the beam search width"
)
...
@@ -123,7 +124,8 @@ class ErnieGen(hub.NLPPredictionModule):
...
@@ -123,7 +124,8 @@ class ErnieGen(hub.NLPPredictionModule):
"""
"""
Run as a command
Run as a command
"""
"""
self
.
parser
=
argparse
.
ArgumentParser
(
description
=
'Run the %s module.'
%
self
.
name
,
self
.
parser
=
argparse
.
ArgumentParser
(
description
=
'Run the %s module.'
%
self
.
name
,
prog
=
'hub run %s'
%
self
.
name
,
prog
=
'hub run %s'
%
self
.
name
,
usage
=
'%(prog)s'
,
usage
=
'%(prog)s'
,
add_help
=
True
)
add_help
=
True
)
...
...
modules/text/text_generation/ernie_gen_poetry/module.py
浏览文件 @
cf5bb9f1
...
@@ -39,7 +39,7 @@ from ernie_gen_poetry.decode import beam_search_infilling
...
@@ -39,7 +39,7 @@ from ernie_gen_poetry.decode import beam_search_infilling
type
=
"nlp/text_generation"
,
type
=
"nlp/text_generation"
,
)
)
class
ErnieGen
(
hub
.
NLPPredictionModule
):
class
ErnieGen
(
hub
.
NLPPredictionModule
):
def
_
initialize
(
self
):
def
_
_init__
(
self
):
"""
"""
initialize with the necessary elements
initialize with the necessary elements
"""
"""
...
@@ -67,6 +67,8 @@ class ErnieGen(hub.NLPPredictionModule):
...
@@ -67,6 +67,8 @@ class ErnieGen(hub.NLPPredictionModule):
Returns:
Returns:
results(list): the poetry continuations.
results(list): the poetry continuations.
"""
"""
paddle
.
disable_static
()
if
texts
and
isinstance
(
texts
,
list
)
and
all
(
texts
)
and
all
([
isinstance
(
text
,
str
)
for
text
in
texts
]):
if
texts
and
isinstance
(
texts
,
list
)
and
all
(
texts
)
and
all
([
isinstance
(
text
,
str
)
for
text
in
texts
]):
predicted_data
=
texts
predicted_data
=
texts
else
:
else
:
...
@@ -102,7 +104,8 @@ class ErnieGen(hub.NLPPredictionModule):
...
@@ -102,7 +104,8 @@ class ErnieGen(hub.NLPPredictionModule):
encode_text
=
self
.
tokenizer
.
encode
(
text
)
encode_text
=
self
.
tokenizer
.
encode
(
text
)
src_ids
=
paddle
.
to_tensor
(
encode_text
[
'input_ids'
]).
unsqueeze
(
0
)
src_ids
=
paddle
.
to_tensor
(
encode_text
[
'input_ids'
]).
unsqueeze
(
0
)
src_sids
=
paddle
.
to_tensor
(
encode_text
[
'token_type_ids'
]).
unsqueeze
(
0
)
src_sids
=
paddle
.
to_tensor
(
encode_text
[
'token_type_ids'
]).
unsqueeze
(
0
)
output_ids
=
beam_search_infilling
(
self
.
model
,
output_ids
=
beam_search_infilling
(
self
.
model
,
src_ids
,
src_ids
,
src_sids
,
src_sids
,
eos_id
=
self
.
tokenizer
.
vocab
[
'[SEP]'
],
eos_id
=
self
.
tokenizer
.
vocab
[
'[SEP]'
],
...
@@ -128,10 +131,8 @@ class ErnieGen(hub.NLPPredictionModule):
...
@@ -128,10 +131,8 @@ class ErnieGen(hub.NLPPredictionModule):
"""
"""
Add the command config options
Add the command config options
"""
"""
self
.
arg_config_group
.
add_argument
(
'--use_gpu'
,
self
.
arg_config_group
.
add_argument
(
type
=
ast
.
literal_eval
,
'--use_gpu'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether use GPU for prediction"
)
default
=
False
,
help
=
"whether use GPU for prediction"
)
self
.
arg_config_group
.
add_argument
(
'--beam_width'
,
type
=
int
,
default
=
5
,
help
=
"the beam search width"
)
self
.
arg_config_group
.
add_argument
(
'--beam_width'
,
type
=
int
,
default
=
5
,
help
=
"the beam search width"
)
...
@@ -140,7 +141,8 @@ class ErnieGen(hub.NLPPredictionModule):
...
@@ -140,7 +141,8 @@ class ErnieGen(hub.NLPPredictionModule):
"""
"""
Run as a command
Run as a command
"""
"""
self
.
parser
=
argparse
.
ArgumentParser
(
description
=
'Run the %s module.'
%
self
.
name
,
self
.
parser
=
argparse
.
ArgumentParser
(
description
=
'Run the %s module.'
%
self
.
name
,
prog
=
'hub run %s'
%
self
.
name
,
prog
=
'hub run %s'
%
self
.
name
,
usage
=
'%(prog)s'
,
usage
=
'%(prog)s'
,
add_help
=
True
)
add_help
=
True
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录