From 01740e0eef6e4fe307803927a5fb5f1572d41c62 Mon Sep 17 00:00:00 2001 From: jm12138 <2286040843@qq.com> Date: Thu, 29 Dec 2022 10:12:03 +0800 Subject: [PATCH] ernie zeus add gradio app (#2151) * ernie zeus add gradio app * rm multiprocessing * update version --- .../text/text_generation/ernie_zeus/README.md | 45 +++++++++++++- .../text/text_generation/ernie_zeus/module.py | 62 ++++++++++++++++++- 2 files changed, 104 insertions(+), 3 deletions(-) diff --git a/modules/text/text_generation/ernie_zeus/README.md b/modules/text/text_generation/ernie_zeus/README.md index c59a01fb..fb14cab4 100644 --- a/modules/text/text_generation/ernie_zeus/README.md +++ b/modules/text/text_generation/ernie_zeus/README.md @@ -289,7 +289,44 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注 - **返回** - text(str): 段落摘要。 -## 四、更新历史 + + +## 四、服务部署 + +- PaddleHub Serving可以部署一个在线文本生成服务。 + +- ### 第一步:启动PaddleHub Serving + + - 运行启动命令: + - ```shell + $ hub serving start -m ernie_zeus + ``` + + - 这样就完成了一个文本生成的在线服务API的部署,默认端口号为8866。 + +- ### 第二步:发送预测请求 + + - 配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果。 + + - ```python + import requests + import json + + # 发送HTTP请求 + # 参数参考自定义文本生成接口 + data = {'text': '巨大的白色城堡'} + headers = {"Content-type": "application/json"} + url = "http://127.0.0.1:8866/predict/ernie_zeus" + r = requests.post(url=url, headers=headers, data=json.dumps(data)) + + # 获取返回结果 + print(r.json()["results"]) + +- ### gradio app 支持 + 从paddlehub 2.3.1开始支持使用链接 http://127.0.0.1:8866/gradio/ernie_zeus 在浏览器中访问ernie_zeus的gradio app。 + + +## 五、更新历史 * 1.0.0 @@ -299,6 +336,10 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注 移除默认 AK 和 SK +* 1.2.0 + + 添加 Serving 和 Gradio APP + ```shell - $ hub install ernie_zeus == 1.1.0 + $ hub install ernie_zeus == 1.2.0 ``` diff --git a/modules/text/text_generation/ernie_zeus/module.py b/modules/text/text_generation/ernie_zeus/module.py index 81921f3b..377bf75f 100644 --- a/modules/text/text_generation/ernie_zeus/module.py +++ b/modules/text/text_generation/ernie_zeus/module.py @@ -6,6 +6,7 @@ import requests from paddlehub.module.module import moduleinfo from paddlehub.module.module import runnable +from paddlehub.module.module import serving def get_access_token(ak: str = None, sk: str = None) -> str: @@ -46,12 +47,13 @@ def get_access_token(ak: str = None, sk: str = None) -> str: author='paddlepaddle', author_email='', summary='ernie_zeus', - version='1.1.0') + version='1.2.0') class ERNIEZeus: def __init__(self, ak: str = None, sk: str = None) -> None: self.access_token = get_access_token(ak, sk) + @serving def custom_generation(self, text: str, min_dec_len: int = 1, @@ -379,3 +381,61 @@ class ERNIEZeus: kwargs.pop(k) return func(**kwargs) + + def create_gradio_app(self): + import gradio as gr + + def inference(task: str, + text: str, + min_dec_len: int = 2, + seq_len: int = 512, + topp: float = 0.9, + penalty_score: float = 1.0): + + func = getattr(self, task) + try: + result = func(text, min_dec_len, seq_len, topp, penalty_score) + return result + except Exception as error: + return str(error) + + examples = [ + [ + 'text_summarization', + '外媒7月18日报道,阿联酋政府当日证实该国将建设首个核电站,以应对不断上涨的用电需求。分析称阿联酋作为世界第三大石油出口国,更愿意将该能源用于出口,而非发电。首座核反应堆预计在2017年运行。cntv李婉然编译报道', + 4, 512, 0.3, 1.0 + ], + ['copywriting_generation', '芍药香氛的沐浴乳', 8, 512, 0.9, 1.2], + ['novel_continuation', '昆仑山可以说是天下龙脉的根源,所有的山脉都可以看作是昆仑的分支。这些分出来的枝枝杈杈,都可以看作是一条条独立的龙脉。', 2, 512, 0.9, 1.2], + ['answer_generation', '做生意的基本原则是什么?', 2, 512, 0.5, 1.2], + ['couplet_continuation', '天增岁月人增寿', 2, 512, 1.0, 1.0], + ['composition_generation', '拔河比赛', 128, 512, 0.9, 1.2], + ['text_cloze', '她有着一双[MASK]的眼眸。', 1, 512, 0.3, 1.2], + ] + + text = gr.Textbox( + label="input_text", + placeholder="Please enter Chinese text.", + ) + task = gr.Dropdown(label="task", + choices=[ + 'text_summarization', 'copywriting_generation', 'novel_continuation', + 'answer_generation', 'couplet_continuation', 'composition_generation', 'text_cloze' + ], + value='text_summarization') + + min_dec_len = gr.Slider(minimum=1, maximum=511, value=1, label="min_dec_len", step=1, interactive=True) + seq_len = gr.Slider(minimum=2, maximum=512, value=128, label="seq_len", step=1, interactive=True) + topp = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, label="topp", step=0.01, interactive=True) + penalty_score = gr.Slider(minimum=1.0, + maximum=2.0, + value=1.0, + label="penalty_score", + step=0.01, + interactive=True) + text_gen = gr.Text(label="generated_text") + interface = gr.Interface(inference, [task, text, min_dec_len, seq_len, topp, penalty_score], [text_gen], + examples=examples, + allow_flagging='never', + title='ERNIE-Zeus') + return interface -- GitLab