未验证 提交 14b53057 编写于 作者: C chenjian 提交者: GitHub

Add gradio app support for module (#2132)

* add gradio app support for module

* add gradio app support for module

* add gradio app support for module

* fix

* add ernie_vilg demo

* upgrade

* update module version
上级 54d5642c
......@@ -400,6 +400,9 @@ DiscoDiffusion Prompt 技巧资料:https://docs.google.com/document/d/1l8s7uS2
image = Image.open(BytesIO(base64.b64decode(result)))
image.save('result_{}.png'.format(i))
- ### gradio app 支持
从paddlehub 2.3.1开始支持使用链接 http://127.0.0.1:8866/gradio/ernie_vilg 在浏览器中访问ernie_vilg的gradio app。
## 六、更新历史
......@@ -415,6 +418,9 @@ DiscoDiffusion Prompt 技巧资料:https://docs.google.com/document/d/1l8s7uS2
移除分辨率参数,移除默认 AK 和 SK
* 1.3.0
新增对gradio app的支持
```shell
$ hub install ernie_vilg == 1.2.0
$ hub install ernie_vilg == 1.3.0
```
......@@ -207,6 +207,10 @@
- 关于PaddleHub Serving更多信息参考:[服务部署](../../../../docs/docs_ch/tutorial/serving.md)
- ## gradio app 支持
从paddlehub 2.3.1开始支持使用链接 http://127.0.0.1:8866/gradio/jieba_paddle 在浏览器中访问jieba_paddle的gradio app。
## 五、更新历史
......@@ -218,6 +222,9 @@
移除 fluid api
* 1.1.0
新增对gradio app的支持
- ```shell
$ hub install jieba_paddle==1.0.1
$ hub install jieba_paddle==1.1.0
```
......@@ -14,7 +14,7 @@ from paddlehub.module.module import serving
@moduleinfo(
name="jieba_paddle",
version="1.0.1",
version="1.1.0",
summary=
"jieba_paddle is a chineses tokenizer using BiGRU base on the PaddlePaddle deeplearning framework. More information please refer to https://github.com/fxsjy/jieba.",
author="baidu-paddle",
......@@ -54,6 +54,24 @@ class JiebaPaddle(hub.Module):
return seg_list
def create_gradio_app(self):
import gradio as gr
def inference(text):
results = self.cut(sentence=text)
return results
title = "jieba_paddle"
description = "jieba_paddle is a word segmentation model based on paddlepaddle deep learning framework."
examples = [['今天是个好日子']]
app = gr.Interface(inference,
"text", [gr.outputs.Textbox(label="words")],
title=title,
description=description,
examples=examples)
return app
def check_dependency(self):
"""
Check jieba tool dependency.
......
......@@ -11,18 +11,18 @@ import os
import numpy as np
import six
from paddle.inference import Config
from paddle.inference import create_predictor
from .custom import Customization
from .processor import load_kv_dict
from .processor import parse_result
from .processor import word_to_ids
from paddle.inference import Config
from paddle.inference import create_predictor
from paddlehub.utils.utils import sys_stdin_encoding
from paddlehub.utils.parser import txt_parser
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
from paddlehub.module.module import serving
from paddlehub.utils.parser import txt_parser
from paddlehub.utils.utils import sys_stdin_encoding
class DataFormatError(Exception):
......@@ -40,6 +40,7 @@ class DataFormatError(Exception):
author_email="paddle-dev@baidu.com",
type="nlp/lexical_analysis")
class LAC:
def __init__(self, user_dict=None):
"""
initialize with the necessary elements
......@@ -66,8 +67,8 @@ class LAC:
"""
predictor config setting
"""
model = self.default_pretrained_model_path+'.pdmodel'
params = self.default_pretrained_model_path+'.pdiparams'
model = self.default_pretrained_model_path + '.pdmodel'
params = self.default_pretrained_model_path + '.pdiparams'
cpu_config = Config(model, params)
cpu_config.disable_glog_info()
cpu_config.disable_gpu()
......
......@@ -12,19 +12,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
import socket
import threading
import time
import traceback
from multiprocessing import Process
from threading import Lock
from flask import Flask, request
import requests
from flask import Flask
from flask import redirect
from flask import request
from flask import Response
from paddlehub.serving.model_service.base_model_service import cv_module_info
from paddlehub.serving.model_service.base_model_service import nlp_module_info
from paddlehub.serving.model_service.base_model_service import v2_module_info
from paddlehub.utils import utils, log
from paddlehub.utils import log
from paddlehub.utils import utils
filename = 'HubServing-%s.log' % time.strftime("%Y_%m_%d", time.localtime())
_gradio_apps = {} # Used to store all launched gradio apps
_lock = Lock() # Used to prevent parallel requests to launch a server twice
def package_result(status: str, msg: str, data: dict):
'''
......@@ -55,6 +66,54 @@ def package_result(status: str, msg: str, data: dict):
return {"status": status, "msg": msg, "results": data}
def create_gradio_app(module_info: dict):
'''
Create a gradio app and launch a server for users.
Args:
module_info(dict): Module info include module name, method name and
other info.
Return:
int: port number, if server has been successful.
Exception:
Raise a exception if server can not been launched.
'''
module_name = module_info['module_name']
port = None
with _lock:
if module_name not in _gradio_apps:
try:
serving_method = getattr(module_info["module"], 'create_gradio_app')
except Exception:
raise RuntimeError('Module {} is not supported for gradio app.'.format(module_name))
def get_free_tcp_port():
tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
tcp.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
tcp.bind(('localhost', 0))
addr, port = tcp.getsockname()
tcp.close()
return port
port = get_free_tcp_port()
app = serving_method()
process = Process(target=app.launch, kwargs={'server_port': port})
process.start()
def check_alive():
nonlocal port
while True:
try:
requests.get('http://localhost:{}/'.format(port))
break
except Exception:
time.sleep(1)
check_alive()
_gradio_apps[module_name] = port
return port
def predict_v2(module_info: dict, input: dict):
'''
......@@ -159,6 +218,47 @@ def create_app(init_flag: bool = False, configs: dict = None):
results = predict_v2(module_info, inputs)
return results
@app_instance.route('/gradio/<module_name>', methods=["GET", "POST"])
def gradio_app(module_name: str):
if module_name in v2_module_info.modules:
module_info = v2_module_info.get_module_info(module_name)
module_info['module_name'] = module_name
else:
msg = "Module {} is not supported for gradio app.".format(module_name)
return package_result("111", msg, "")
create_gradio_app(module_info)
return redirect("/gradio/{}/app".format(module_name), code=302)
@app_instance.route("/gradio/<module_name>/<path:path>", methods=["GET", "POST"])
def request_gradio_app(module_name: str, path: str):
'''
Gradio app server url interface. We route urls for gradio app to gradio server.
Args:
module_name(str): Module name for gradio app.
path(str): All resource path from gradio server.
Returns:
Any thing from gradio server.
'''
port = _gradio_apps[module_name]
if path == 'app':
proxy_url = request.url.replace(request.host_url + 'gradio/{}/app'.format(module_name),
'http://localhost:{}/'.format(port))
else:
proxy_url = request.url.replace(request.host_url + 'gradio/{}/'.format(module_name),
'http://localhost:{}/'.format(port))
resp = requests.request(method=request.method,
url=proxy_url,
headers={key: value
for (key, value) in request.headers if key != 'Host'},
data=request.get_data(),
cookies=request.cookies,
allow_redirects=False)
headers = [(name, value) for (name, value) in resp.raw.headers.items()]
response = Response(resp.content, resp.status_code, headers)
return response
return app_instance
......
......@@ -17,3 +17,4 @@ tqdm
visualdl >= 2.0.0
# gunicorn not support windows
gunicorn >= 19.10.0; sys_platform != "win32"
gradio
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册