未验证 提交 cbaf2477 编写于 作者: A adaxiadaxi 提交者: GitHub

优化ernie_gen系列输入文本检查 (#833)

上级 95f3a2a4
......@@ -17,7 +17,7 @@ $ hub run ernie_gen_acrostic_poetry --input_text="我喜欢你" --use_gpu True -
**参数**
* input_text: 诗歌的藏头,长度不超过line值
* input_text: 诗歌的藏头,长度不应超过4,否则将被截断
* use\_gpu: 是否使用 GPU;**若使用GPU,请先设置CUDA\_VISIBLE\_DEVICES环境变量**
* beam\_width: beam search宽度,决定每个藏头输出的下文数目。
......@@ -109,3 +109,7 @@ paddlehub >= 1.7.0
* 1.0.0
初始发布
* 1.0.1
完善API的输入文本检查
......@@ -35,7 +35,7 @@ from ernie_gen_acrostic_poetry.model.modeling_ernie_gen import ErnieModelForGene
@moduleinfo(
name="ernie_gen_acrostic_poetry",
version="1.0.0",
version="1.0.1",
summary=
"ERNIE-GEN is a multi-flow language generation framework for both pre-training and fine-tuning. This module has fine-tuned for poetry generation task.",
author="adaxiadaxi",
......@@ -52,6 +52,7 @@ class ErnieGen(hub.NLPPredictionModule):
if word not in [5, 7]:
raise ValueError("The word could only be 5 or 7.")
self.line = line
assets_path = os.path.join(self.directory, "assets")
gen_checkpoint_path = os.path.join(
assets_path, "ernie_gen_acrostic_poetry_L%sW%s" % (line, word))
......@@ -90,6 +91,27 @@ class ErnieGen(hub.NLPPredictionModule):
Returns:
results(list): the poetry continuations.
"""
if texts and isinstance(texts, list) and all(texts) and all(
[isinstance(text, str) for text in texts]):
predicted_data = texts
else:
raise ValueError(
"The input texts should be a list with nonempty string elements."
)
for i, text in enumerate(texts):
if len(text) > self.line:
logger.warning(
'The input text: %s, contains more than %i characters, which will be cut off'
% (text, self.line))
texts[i] = text[:self.line]
for char in text:
if not '\u4e00' <= char <= '\u9fff':
logger.warning(
'The input text: %s, contains non-Chinese characters, which may result in magic output'
% text)
break
if use_gpu and "CUDA_VISIBLE_DEVICES" not in os.environ:
use_gpu = False
logger.warning(
......@@ -100,12 +122,6 @@ class ErnieGen(hub.NLPPredictionModule):
else:
place = fluid.CPUPlace()
if texts and isinstance(texts, list):
predicted_data = texts
else:
raise ValueError(
"The input data is inconsistent with expectations.")
with fluid.dygraph.guard(place):
self.model.eval()
results = []
......
......@@ -101,3 +101,7 @@ paddlehub >= 1.7.0
* 1.0.1
修复windows中的编码问题
* 1.0.2
完善API的输入文本检查
......@@ -35,7 +35,7 @@ from ernie_gen_couplet.model.modeling_ernie_gen import ErnieModelForGeneration
@moduleinfo(
name="ernie_gen_couplet",
version="1.0.1",
version="1.0.2",
summary=
"ERNIE-GEN is a multi-flow language generation framework for both pre-training and fine-tuning. This module has fine-tuned for couplet generation task.",
author="baidu-nlp",
......@@ -84,6 +84,21 @@ class ErnieGen(hub.NLPPredictionModule):
Returns:
results(list): the right rolls.
"""
if texts and isinstance(texts, list) and all(texts) and all(
[isinstance(text, str) for text in texts]):
predicted_data = texts
else:
raise ValueError(
"The input texts should be a list with nonempty string elements."
)
for i, text in enumerate(texts):
for char in text:
if not '\u4e00' <= char <= '\u9fff':
logger.warning(
'The input text: %s, contains non-Chinese characters, which may result in magic output'
% text)
break
if use_gpu and "CUDA_VISIBLE_DEVICES" not in os.environ:
use_gpu = False
logger.warning(
......@@ -94,12 +109,6 @@ class ErnieGen(hub.NLPPredictionModule):
else:
place = fluid.CPUPlace()
if texts and isinstance(texts, list):
predicted_data = texts
else:
raise ValueError(
"The input data is inconsistent with expectations.")
with fluid.dygraph.guard(place):
self.model.eval()
results = []
......
......@@ -97,3 +97,7 @@ paddlehub >= 1.7.0
* 1.0.0
初始发布
* 1.0.1
完善API的输入文本检查
......@@ -35,7 +35,7 @@ from ernie_gen_lover_words.model.modeling_ernie_gen import ErnieModelForGenerati
@moduleinfo(
name="ernie_gen_lover_words",
version="1.0.0",
version="1.0.1",
summary=
"ERNIE-GEN is a multi-flow language generation framework for both pre-training and fine-tuning. This module has fine-tuned for lover's words generation task.",
author="adaxiadaxi",
......@@ -84,6 +84,14 @@ class ErnieGen(hub.NLPPredictionModule):
Returns:
results(list): the poetry continuations.
"""
if texts and isinstance(texts, list) and all(texts) and all(
[isinstance(text, str) for text in texts]):
predicted_data = texts
else:
raise ValueError(
"The input texts should be a list with nonempty string elements."
)
if use_gpu and "CUDA_VISIBLE_DEVICES" not in os.environ:
use_gpu = False
logger.warning(
......@@ -94,12 +102,6 @@ class ErnieGen(hub.NLPPredictionModule):
else:
place = fluid.CPUPlace()
if texts and isinstance(texts, list):
predicted_data = texts
else:
raise ValueError(
"The input data is inconsistent with expectations.")
with fluid.dygraph.guard(place):
self.model.eval()
results = []
......
......@@ -101,3 +101,7 @@ paddlehub >= 1.7.0
* 1.0.1
修复windows中的编码问题
* 1.0.2
完善API的输入文本检查
......@@ -35,7 +35,7 @@ from ernie_gen_poetry.model.modeling_ernie_gen import ErnieModelForGeneration
@moduleinfo(
name="ernie_gen_poetry",
version="1.0.1",
version="1.0.2",
summary=
"ERNIE-GEN is a multi-flow language generation framework for both pre-training and fine-tuning. This module has fine-tuned for poetry generation task.",
author="baidu-nlp",
......@@ -84,6 +84,32 @@ class ErnieGen(hub.NLPPredictionModule):
Returns:
results(list): the poetry continuations.
"""
if texts and isinstance(texts, list) and all(texts) and all(
[isinstance(text, str) for text in texts]):
predicted_data = texts
else:
raise ValueError(
"The input texts should be a list with nonempty string elements."
)
for i, text in enumerate(texts):
if ',' not in text or '。' not in text:
logger.warning(
"The input text: %s, does not contain ',' or '。', which is not a complete verse and may result in magic output"
% text)
else:
front, rear = text[:-1].split(',')
if len(front) != len(rear):
logger.warning(
"The input text: %s, is no antithetical parallelism, which may result in magic output"
% text)
for char in text:
if not '\u4e00' <= char <= '\u9fff' and char not in [',', '。']:
logger.warning(
"The input text: %s, contains characters not Chinese or ‘,’ '。', which may result in magic output"
% text)
break
if use_gpu and "CUDA_VISIBLE_DEVICES" not in os.environ:
use_gpu = False
logger.warning(
......@@ -94,12 +120,6 @@ class ErnieGen(hub.NLPPredictionModule):
else:
place = fluid.CPUPlace()
if texts and isinstance(texts, list):
predicted_data = texts
else:
raise ValueError(
"The input data is inconsistent with expectations.")
with fluid.dygraph.guard(place):
self.model.eval()
results = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册