PaddlePaddle / PaddleHub

Commit d79f6c49 (unverified)
Authored on Oct 19, 2021 by houj04; committed via GitHub on Oct 19, 2021
add xpu and npu support for ernie series. (#1639)
Parent: 08c56a7d
Showing 4 changed files with 220 additions and 37 deletions.
modules/text/text_generation/ernie_gen_acrostic_poetry/module.py  (+54, -9)
modules/text/text_generation/ernie_gen_couplet/module.py  (+55, -10)
modules/text/text_generation/ernie_gen_lover_words/module.py  (+55, -9)
modules/text/text_generation/ernie_gen_poetry/module.py  (+56, -9)
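For context, a minimal usage sketch of the new use_device parameter (illustrative only, not part of this commit; it assumes PaddleHub is installed, the ernie_gen_couplet module has been downloaded, and the target accelerator is exposed through one of the environment variables checked on init):

# Sketch, not part of this commit: exercising the use_device parameter added to
# the ernie_gen_* modules. On init each module detects its device from
# FLAGS_selected_npus, CUDA_VISIBLE_DEVICES or XPU_VISIBLE_DEVICES (in that
# order), and generate() raises RuntimeError if use_device disagrees with it.
import paddlehub as hub

module = hub.Module(name="ernie_gen_couplet")  # any of the four ernie_gen_* modules works the same way

texts = ["人增福寿年增岁"]  # hypothetical input: the left roll of a couplet

# Old-style call, unchanged by this commit: CPU unless use_gpu=True and CUDA_VISIBLE_DEVICES is set.
results = module.generate(texts=texts, use_gpu=False, beam_width=5)

# New in this commit: name the device explicitly; this overrides the use_gpu flag.
# With FLAGS_selected_npus set, "npu" matches the device detected on init.
results = module.generate(texts=texts, beam_width=5, use_device="npu")

print(results)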
modules/text/text_generation/ernie_gen_acrostic_poetry/module.py
@@ -60,8 +60,36 @@ class ErnieGen(hub.NLPPredictionModule):
         self.rev_dict[self.tokenizer.vocab['[UNK]']] = ''  # replace [PAD]
         self.rev_lookup = np.vectorize(lambda i: self.rev_dict[i])

+        # detect npu
+        npu_id = self._get_device_id("FLAGS_selected_npus")
+        if npu_id != -1:
+            # use npu
+            self.use_device = "npu"
+        else:
+            # detect gpu
+            gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
+            if gpu_id != -1:
+                # use gpu
+                self.use_device = "gpu"
+            else:
+                # detect xpu
+                xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
+                if xpu_id != -1:
+                    # use xpu
+                    self.use_device = "xpu"
+                else:
+                    self.use_device = "cpu"
+
+    def _get_device_id(self, places):
+        try:
+            places = os.environ[places]
+            id = int(places)
+        except:
+            id = -1
+        return id
+
     @serving
-    def generate(self, texts, use_gpu=False, beam_width=5):
+    def generate(self, texts, use_gpu=False, beam_width=5, use_device=None):
         """
         Get the continuation of the input poetry.

@@ -69,6 +97,7 @@ class ErnieGen(hub.NLPPredictionModule):
             texts(list): the front part of a poetry.
             use_gpu(bool): whether use gpu to predict or not
             beam_width(int): the beam search width.
+            use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.

         Returns:
             results(list): the poetry continuations.

@@ -91,13 +120,25 @@ class ErnieGen(hub.NLPPredictionModule):
                     'The input text: %s, contains non-Chinese characters, which may result in magic output'
                     % text)
                 break

-        if use_gpu and "CUDA_VISIBLE_DEVICES" not in os.environ:
-            use_gpu = False
-            logger.warning(
-                "use_gpu has been set False as you didn't set the environment variable CUDA_VISIBLE_DEVICES while using use_gpu=True"
-            )
-        paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu')
+        if use_device is not None:
+            # check 'use_device' match 'device on init'
+            if use_device != self.use_device:
+                raise RuntimeError(
+                    "the 'use_device' parameter when calling generate, does not match internal device found on init.")
+        else:
+            # use_device is None, follow use_gpu flag
+            if use_gpu == False:
+                use_device = "cpu"
+            elif use_gpu == True and self.use_device != 'gpu':
+                use_device = "cpu"
+                logger.warning(
+                    "use_gpu has been set False as you didn't set the environment variable CUDA_VISIBLE_DEVICES while using use_gpu=True"
+                )
+            else:
+                # use_gpu and self.use_device are both true
+                use_device = "gpu"
+        paddle.set_device(use_device)

         self.model.eval()
         results = []

@@ -135,8 +176,11 @@ class ErnieGen(hub.NLPPredictionModule):
         """
         self.arg_config_group.add_argument(
             '--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU for prediction")
         self.arg_config_group.add_argument('--beam_width', type=int, default=5, help="the beam search width")
+        self.arg_config_group.add_argument(
+            '--use_device', choices=["cpu", "gpu", "xpu", "npu"],
+            help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.")

     @runnable
     def run_cmd(self, argvs):

@@ -164,7 +208,8 @@ class ErnieGen(hub.NLPPredictionModule):
             self.parser.print_help()
             return None

-        results = self.generate(texts=input_data, use_gpu=args.use_gpu, beam_width=args.beam_width)
+        results = self.generate(
+            texts=input_data, use_gpu=args.use_gpu, beam_width=args.beam_width, use_device=args.use_device)

         return results
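For illustration, a small standalone sketch (not from the repository; names are hypothetical) of how the init-time detection added above resolves the device, mirroring the priority npu > gpu > xpu > cpu:

import os

def detect_device():
    # Mirrors the logic added in __init__ above: the first environment variable
    # whose value parses as an integer id other than -1 wins, in this order.
    def get_device_id(var):
        try:
            return int(os.environ[var])
        except (KeyError, ValueError):
            return -1

    if get_device_id("FLAGS_selected_npus") != -1:
        return "npu"
    if get_device_id("CUDA_VISIBLE_DEVICES") != -1:
        return "gpu"
    if get_device_id("XPU_VISIBLE_DEVICES") != -1:
        return "xpu"
    return "cpu"

os.environ.pop("FLAGS_selected_npus", None)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
print(detect_device())  # -> "gpu"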
modules/text/text_generation/ernie_gen_couplet/module.py
@@ -54,8 +54,36 @@ class ErnieGen(hub.NLPPredictionModule):
         self.rev_dict[self.tokenizer.vocab['[UNK]']] = ''  # replace [PAD]
         self.rev_lookup = np.vectorize(lambda i: self.rev_dict[i])

+        # detect npu
+        npu_id = self._get_device_id("FLAGS_selected_npus")
+        if npu_id != -1:
+            # use npu
+            self.use_device = "npu"
+        else:
+            # detect gpu
+            gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
+            if gpu_id != -1:
+                # use gpu
+                self.use_device = "gpu"
+            else:
+                # detect xpu
+                xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
+                if xpu_id != -1:
+                    # use xpu
+                    self.use_device = "xpu"
+                else:
+                    self.use_device = "cpu"
+
+    def _get_device_id(self, places):
+        try:
+            places = os.environ[places]
+            id = int(places)
+        except:
+            id = -1
+        return id
+
     @serving
-    def generate(self, texts, use_gpu=False, beam_width=5):
+    def generate(self, texts, use_gpu=False, beam_width=5, use_device=None):
         """
         Get the right rolls from the left rolls.

@@ -63,6 +91,7 @@ class ErnieGen(hub.NLPPredictionModule):
             texts(list): the left rolls.
             use_gpu(bool): whether use gpu to predict or not
             beam_width(int): the beam search width.
+            use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.

         Returns:
             results(list): the right rolls.

@@ -80,13 +109,25 @@ class ErnieGen(hub.NLPPredictionModule):
                     'The input text: %s, contains non-Chinese characters, which may result in magic output'
                     % text)
                 break

-        if use_gpu and "CUDA_VISIBLE_DEVICES" not in os.environ:
-            use_gpu = False
-            logger.warning(
-                "use_gpu has been set False as you didn't set the environment variable CUDA_VISIBLE_DEVICES while using use_gpu=True"
-            )
-        paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu')
+        if use_device is not None:
+            # check 'use_device' match 'device on init'
+            if use_device != self.use_device:
+                raise RuntimeError(
+                    "the 'use_device' parameter when calling generate, does not match internal device found on init.")
+        else:
+            # use_device is None, follow use_gpu flag
+            if use_gpu == False:
+                use_device = "cpu"
+            elif use_gpu == True and self.use_device != 'gpu':
+                use_device = "cpu"
+                logger.warning(
+                    "use_gpu has been set False as you didn't set the environment variable CUDA_VISIBLE_DEVICES while using use_gpu=True"
+                )
+            else:
+                # use_gpu and self.use_device are both true
+                use_device = "gpu"
+        paddle.set_device(use_device)

         self.model.eval()
         results = []

@@ -124,8 +165,11 @@ class ErnieGen(hub.NLPPredictionModule):
         """
         self.arg_config_group.add_argument(
             '--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU for prediction")
         self.arg_config_group.add_argument('--beam_width', type=int, default=5, help="the beam search width")
+        self.arg_config_group.add_argument(
+            '--use_device', choices=["cpu", "gpu", "xpu", "npu"],
+            help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.")

     @runnable
     def run_cmd(self, argvs):

@@ -153,7 +197,8 @@ class ErnieGen(hub.NLPPredictionModule):
             self.parser.print_help()
             return None

-        results = self.generate(texts=input_data, use_gpu=args.use_gpu, beam_width=args.beam_width)
+        results = self.generate(
+            texts=input_data, use_gpu=args.use_gpu, beam_width=args.beam_width, use_device=args.use_device)

         return results
modules/text/text_generation/ernie_gen_lover_words/module.py
@@ -54,8 +54,36 @@ class ErnieGen(hub.NLPPredictionModule):
         self.rev_dict[self.tokenizer.vocab['[UNK]']] = ''  # replace [PAD]
         self.rev_lookup = np.vectorize(lambda i: self.rev_dict[i])

+        # detect npu
+        npu_id = self._get_device_id("FLAGS_selected_npus")
+        if npu_id != -1:
+            # use npu
+            self.use_device = "npu"
+        else:
+            # detect gpu
+            gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
+            if gpu_id != -1:
+                # use gpu
+                self.use_device = "gpu"
+            else:
+                # detect xpu
+                xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
+                if xpu_id != -1:
+                    # use xpu
+                    self.use_device = "xpu"
+                else:
+                    self.use_device = "cpu"
+
+    def _get_device_id(self, places):
+        try:
+            places = os.environ[places]
+            id = int(places)
+        except:
+            id = -1
+        return id
+
     @serving
-    def generate(self, texts, use_gpu=False, beam_width=5):
+    def generate(self, texts, use_gpu=False, beam_width=5, use_device=None):
         """
         Get the continuation of the input poetry.

@@ -63,6 +91,7 @@ class ErnieGen(hub.NLPPredictionModule):
             texts(list): the front part of a poetry.
             use_gpu(bool): whether use gpu to predict or not
             beam_width(int): the beam search width.
+            use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.

         Returns:
             results(list): the poetry continuations.

@@ -74,12 +103,25 @@ class ErnieGen(hub.NLPPredictionModule):
         else:
             raise ValueError("The input texts should be a list with nonempty string elements.")

-        if use_gpu and "CUDA_VISIBLE_DEVICES" not in os.environ:
-            use_gpu = False
-            logger.warning(
-                "use_gpu has been set False as you didn't set the environment variable CUDA_VISIBLE_DEVICES while using use_gpu=True"
-            )
-        paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu')
+        if use_device is not None:
+            # check 'use_device' match 'device on init'
+            if use_device != self.use_device:
+                raise RuntimeError(
+                    "the 'use_device' parameter when calling generate, does not match internal device found on init.")
+        else:
+            # use_device is None, follow use_gpu flag
+            if use_gpu == False:
+                use_device = "cpu"
+            elif use_gpu == True and self.use_device != 'gpu':
+                use_device = "cpu"
+                logger.warning(
+                    "use_gpu has been set False as you didn't set the environment variable CUDA_VISIBLE_DEVICES while using use_gpu=True"
+                )
+            else:
+                # use_gpu and self.use_device are both true
+                use_device = "gpu"
+        paddle.set_device(use_device)

         self.model.eval()
         results = []
         for text in predicted_data:

@@ -116,8 +158,11 @@ class ErnieGen(hub.NLPPredictionModule):
         """
         self.arg_config_group.add_argument(
             '--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU for prediction")
         self.arg_config_group.add_argument('--beam_width', type=int, default=5, help="the beam search width")
+        self.arg_config_group.add_argument(
+            '--use_device', choices=["cpu", "gpu", "xpu", "npu"],
+            help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.")

     @runnable
     def run_cmd(self, argvs):

@@ -145,7 +190,8 @@ class ErnieGen(hub.NLPPredictionModule):
             self.parser.print_help()
             return None

-        results = self.generate(texts=input_data, use_gpu=args.use_gpu, beam_width=args.beam_width)
+        results = self.generate(
+            texts=input_data, use_gpu=args.use_gpu, beam_width=args.beam_width, use_device=args.use_device)

         return results
modules/text/text_generation/ernie_gen_poetry/module.py
@@ -54,8 +54,36 @@ class ErnieGen(hub.NLPPredictionModule):
         self.rev_dict[self.tokenizer.vocab['[UNK]']] = ''  # replace [PAD]
         self.rev_lookup = np.vectorize(lambda i: self.rev_dict[i])

+        # detect npu
+        npu_id = self._get_device_id("FLAGS_selected_npus")
+        if npu_id != -1:
+            # use npu
+            self.use_device = "npu"
+        else:
+            # detect gpu
+            gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
+            if gpu_id != -1:
+                # use gpu
+                self.use_device = "gpu"
+            else:
+                # detect xpu
+                xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
+                if xpu_id != -1:
+                    # use xpu
+                    self.use_device = "xpu"
+                else:
+                    self.use_device = "cpu"
+
+    def _get_device_id(self, places):
+        try:
+            places = os.environ[places]
+            id = int(places)
+        except:
+            id = -1
+        return id
+
     @serving
-    def generate(self, texts, use_gpu=False, beam_width=5):
+    def generate(self, texts, use_gpu=False, beam_width=5, use_device=None):
         """
         Get the continuation of the input poetry.

@@ -63,6 +91,7 @@ class ErnieGen(hub.NLPPredictionModule):
             texts(list): the front part of a poetry.
             use_gpu(bool): whether use gpu to predict or not
             beam_width(int): the beam search width.
+            use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.

         Returns:
             results(list): the poetry continuations.

@@ -91,12 +120,26 @@ class ErnieGen(hub.NLPPredictionModule):
                     % text)
                 break

-        if use_gpu and "CUDA_VISIBLE_DEVICES" not in os.environ:
-            use_gpu = False
-            logger.warning(
-                "use_gpu has been set False as you didn't set the environment variable CUDA_VISIBLE_DEVICES while using use_gpu=True"
-            )
-        paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu')
+        if use_device is not None:
+            # check 'use_device' match 'device on init'
+            if use_device != self.use_device:
+                raise RuntimeError(
+                    "the 'use_device' parameter when calling generate, does not match internal device found on init.")
+        else:
+            # use_device is None, follow use_gpu flag
+            if use_gpu == False:
+                use_device = "cpu"
+            elif use_gpu == True and self.use_device != 'gpu':
+                use_device = "cpu"
+                logger.warning(
+                    "use_gpu has been set False as you didn't set the environment variable CUDA_VISIBLE_DEVICES while using use_gpu=True"
+                )
+            else:
+                # use_gpu and self.use_device are both true
+                use_device = "gpu"
+        paddle.set_device(use_device)

         self.model.eval()
         results = []
         for text in predicted_data:

@@ -133,8 +176,11 @@ class ErnieGen(hub.NLPPredictionModule):
         """
         self.arg_config_group.add_argument(
             '--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU for prediction")
         self.arg_config_group.add_argument('--beam_width', type=int, default=5, help="the beam search width")
+        self.arg_config_group.add_argument(
+            '--use_device', choices=["cpu", "gpu", "xpu", "npu"],
+            help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.")

     @runnable
     def run_cmd(self, argvs):

@@ -162,7 +208,8 @@ class ErnieGen(hub.NLPPredictionModule):
             self.parser.print_help()
             return None

-        results = self.generate(texts=input_data, use_gpu=args.use_gpu, beam_width=args.beam_width)
+        results = self.generate(
+            texts=input_data, use_gpu=args.use_gpu, beam_width=args.beam_width, use_device=args.use_device)

         return results