# Copyright (c) 2022 VisualDL Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================
import gradio as gr
import numpy as np
from .http_client_manager import get_metric_data
from .http_client_manager import HttpClientManager
from .http_client_manager import metrics_table_head
from .visualizer import visualize_detection
from .visualizer import visualize_face_alignment
from .visualizer import visualize_face_detection
from .visualizer import visualize_headpose
from .visualizer import visualize_keypoint_detection
from .visualizer import visualize_matting
from .visualizer import visualize_ocr
from .visualizer import visualize_segmentation
# Single shared client used by every callback of the gradio app below.
_http_manager = HttpClientManager()

# Maps the task name offered in the UI radio group to the function that turns
# an (input image, raw model output) pair into a visualisation. The insertion
# order is the order in which the choices are listed in the UI.
supported_tasks = {
    'detection': visualize_detection,
    'facedet': visualize_face_detection,
    'keypointdetection': visualize_keypoint_detection,
    'segmentation': visualize_segmentation,
    'matting': visualize_matting,
    'ocr': visualize_ocr,
    'facealignment': visualize_face_alignment,
    'headpose': visualize_headpose,
    # Placeholder entry: the inference callback skips visualisation entirely
    # when this task is selected.
    'unspecified': lambda x: str(x),
}
def create_gradio_client_app():  # noqa:C901
    """Build the gradio UI used to send requests to a fastdeploy server.

    The page contains the server/model address fields, a "component" tab
    (per-variable image/text inputs plus visualised outputs), a "raw" tab
    (free-form payload), a metrics table and a status line. All callbacks
    report problems through the status textbox instead of raising.

    Returns:
        The assembled ``gr.Blocks`` object; it is not launched here.
    """
    css = """
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button {
color: white;
border-color: black;
background: black;
}
input[type='range'] {
accent-color: black;
}
.dark input[type='range'] {
accent-color: #dfdfdf;
}
#gallery {
min-height: 22rem;
margin-bottom: 15px;
margin-left: auto;
margin-right: auto;
border-bottom-right-radius: .5rem !important;
border-bottom-left-radius: .5rem !important;
}
#gallery>div>.h-full {
min-height: 20rem;
}
.details:hover {
text-decoration: underline;
}
.gr-button {
white-space: nowrap;
}
.gr-button:focus {
border-color: rgb(147 197 253 / var(--tw-border-opacity));
outline: none;
box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
--tw-border-opacity: 1;
--tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) \
var(--tw-ring-offset-color);
--tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px + var(--tw-ring-offset-width)) var(--tw-ring-color);
--tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
--tw-ring-opacity: .5;
}
.footer {
margin-bottom: 45px;
margin-top: 35px;
text-align: center;
border-bottom: 1px solid #e5e5e5;
}
.footer>p {
font-size: .8rem;
display: inline-block;
padding: 0 10px;
transform: translateY(10px);
background: white;
}
.dark .footer {
border-color: #303030;
}
.dark .footer>p {
background: #0b0f19;
}
.prompt h4{
margin: 1.25em 0 .25em 0;
font-weight: bold;
font-size: 115%;
}
"""
    # Number of input (and output) variable slots pre-created in the UI.
    # Components cannot be added after the page is built, so a fixed number
    # of hidden accordions is created up front and revealed on demand.
    max_io_slots = 6

    block = gr.Blocks(css=css)

    # NOTE(review): the container nesting below was reconstructed from the
    # statement order; confirm it against the intended page layout.
    with block:
        gr.HTML("""
FastDeploy Client
The client is used for creating requests to fastdeploy server.
""")
        with gr.Group():
            with gr.Box():
                with gr.Column():
                    with gr.Row():
                        server_addr_text = gr.Textbox(
                            label="服务ip",
                            show_label=True,
                            max_lines=1,
                            placeholder="localhost",
                        )
                        server_http_port_text = gr.Textbox(
                            label="推理服务端口",
                            show_label=True,
                            max_lines=1,
                            placeholder="8000",
                        )
                        server_metric_port_text = gr.Textbox(
                            label="性能服务端口",
                            show_label=True,
                            max_lines=1,
                            placeholder="8002",
                        )
                    with gr.Row():
                        model_name_text = gr.Textbox(
                            label="模型名称",
                            show_label=True,
                            max_lines=1,
                            placeholder="yolov5",
                        )
                        model_version_text = gr.Textbox(
                            label="模型版本",
                            show_label=True,
                            max_lines=1,
                            placeholder="1",
                        )
            with gr.Box():
                with gr.Tab("组件形式"):
                    check_button = gr.Button("获取模型输入输出")
                    # Hidden until the model metadata has been fetched.
                    component_format_column = gr.Column(visible=False)
                    with component_format_column:
                        task_radio = gr.Radio(
                            choices=list(supported_tasks.keys()),
                            value='unspecified',
                            label='任务类型',
                            visible=True)
                        gr.Markdown("根据模型需要,挑选文本框或者图像框进行输入")
                        with gr.Row():
                            with gr.Column():
                                gr.Markdown("模型输入")
                                input_accordions = []
                                input_name_texts = []
                                input_images = []
                                input_texts = []
                                for i in range(max_io_slots):
                                    accordion = gr.Accordion(
                                        "输入变量 {}".format(i),
                                        open=True,
                                        visible=False)
                                    with accordion:
                                        input_name_text = gr.Textbox(
                                            label="变量名", interactive=False)
                                        input_image = gr.Image(type='numpy')
                                        input_text = gr.Textbox(
                                            label="文本框", max_lines=1000)
                                    input_accordions.append(accordion)
                                    input_name_texts.append(input_name_text)
                                    input_images.append(input_image)
                                    input_texts.append(input_text)
                            with gr.Column():
                                gr.Markdown("模型输出")
                                output_accordions = []
                                output_name_texts = []
                                output_images = []
                                output_texts = []
                                for i in range(max_io_slots):
                                    accordion = gr.Accordion(
                                        "输出变量 {}".format(i),
                                        open=True,
                                        visible=False)
                                    with accordion:
                                        output_name_text = gr.Textbox(
                                            label="变量名", interactive=False)
                                        output_text = gr.Textbox(
                                            label="服务返回的原数据",
                                            interactive=False,
                                            show_label=True)
                                        output_image = gr.Image(
                                            interactive=False)
                                    output_accordions.append(accordion)
                                    output_name_texts.append(output_name_text)
                                    output_images.append(output_image)
                                    output_texts.append(output_text)
                        component_submit_button = gr.Button("提交请求")
                with gr.Tab("原始形式"):
                    gr.Markdown("模型输入")
                    raw_payload_text = gr.Textbox(
                        label="负载数据", max_lines=10000)
                    with gr.Column():
                        gr.Markdown("输出")
                        output_raw_text = gr.Textbox(
                            label="服务返回的原始数据", interactive=False)
                    raw_submit_button = gr.Button("提交请求")
            with gr.Box():
                with gr.Column():
                    gr.Markdown("服务性能统计(每次提交请求会自动更新数据,您也可以手动点击更新)")
                    output_html_table = gr.HTML(
                        label="metrics",
                        interactive=False,
                        show_label=False,
                        value=metrics_table_head.format('', ''))
                    update_metric_button = gr.Button("更新统计数据")
            status_text = gr.Textbox(
                label="status",
                show_label=True,
                max_lines=1,
                interactive=False)

        all_input_output_components = input_accordions + input_name_texts + input_images + \
            input_texts + output_accordions + output_name_texts + output_images + output_texts

        def get_input_output_name(server_ip, server_port, model_name,
                                  model_version):
            """Fetch the model metadata and reveal one slot per variable.

            Returns a dict keyed by UI component. On failure only the status
            textbox is updated with the error text.
            """
            try:
                server_addr = server_ip + ':' + server_port
                input_metas, output_metas = _http_manager.get_model_meta(
                    server_addr, model_name, model_version)
            except Exception as e:
                return {status_text: str(e)}
            input_metas = list(input_metas)
            output_metas = list(output_metas)

            # Reset every slot first so nothing from a previous model leaks
            # into the new view.
            results = {
                component: None
                for component in all_input_output_components
            }
            results[component_format_column] = gr.update(visible=True)
            for accordion in input_accordions + output_accordions:
                results[accordion] = gr.update(visible=False)
            results[status_text] = 'GetInputOutputName Successful'

            # zip() stops at the number of available slots, so a model with
            # more variables than slots no longer raises IndexError.
            for accordion, name_text, meta in zip(
                    input_accordions, input_name_texts, input_metas):
                results[accordion] = gr.update(visible=True)
                results[name_text] = meta['name']
            for accordion, name_text, meta in zip(
                    output_accordions, output_name_texts, output_metas):
                results[accordion] = gr.update(visible=True)
                results[name_text] = meta['name']
            if len(input_metas) > max_io_slots or len(
                    output_metas) > max_io_slots:
                results[status_text] = (
                    'GetInputOutputName Successful (only the first {} '
                    'inputs/outputs can be shown)'.format(max_io_slots))
            return results

        def component_inference(*args):
            """Run inference with the values entered in the component tab.

            ``args`` follows the ``inputs`` list registered for the submit
            button: five address/model fields, then the variable names, the
            images and the texts of every slot, and finally the task type.
            """
            server_ip, http_port, metric_port, model_name, model_version = \
                args[:5]
            task_type = args[-1]
            start = 5
            names = args[start:start + len(input_name_texts)]
            start += len(input_name_texts)
            images = args[start:start + len(input_images)]
            start += len(input_images)
            texts = args[start:start + len(input_texts)]

            if not (server_ip and http_port and model_name and model_version):
                return {
                    status_text:
                    'Please input server addr, model name and model version.'
                }
            # Built only after validation so that a missing field cannot
            # trigger a TypeError on the concatenation.
            server_addr = server_ip + ':' + http_port

            inputs = {}
            for input_name, image, text in zip(names, images, texts):
                if not input_name:
                    continue
                if image is not None:
                    inputs[input_name] = np.array([image])
                # When both are supplied the text takes precedence.
                if text:
                    inputs[input_name] = np.array(
                        [[text.encode('utf-8')]], dtype=np.object_)
            try:
                infer_results = _http_manager.infer(
                    server_addr, model_name, model_version, inputs)
                results = {status_text: 'Inference Successful'}
                # zip() caps the outputs at the number of UI slots instead of
                # failing the whole request with an IndexError.
                for name_text, raw_text, image_out, (output_name, data) in zip(
                        output_name_texts, output_texts, output_images,
                        infer_results.items()):
                    results[name_text] = output_name
                    results[raw_text] = str(data)
                    if task_type != 'unspecified':
                        try:
                            # Visualisation is always drawn on the image of
                            # the first input slot.
                            results[image_out] = supported_tasks[task_type](
                                images[0], data)
                        except Exception:
                            # Best effort: an output that does not fit the
                            # selected task simply gets no picture.
                            results[image_out] = None
                if metric_port:
                    results[output_html_table] = get_metric_data(
                        server_ip, metric_port)
                return results
            except Exception as e:
                return {status_text: 'Error: {}'.format(e)}

        def raw_inference(*args):
            """Send the raw payload text to the server and show the reply.

            ``args`` follows the ``inputs`` list registered for the raw
            submit button: five address/model fields, then the payload.
            """
            server_ip, http_port, metric_port, model_name, model_version = \
                args[:5]
            payload_text = args[5]
            try:
                # Inside the try so that a missing field is reported in the
                # status line rather than raised.
                server_addr = server_ip + ':' + http_port
                result = _http_manager.raw_infer(server_addr, model_name,
                                                 model_version, payload_text)
                results = {
                    status_text: 'Get response from server',
                    output_raw_text: result
                }
                if server_ip and metric_port:
                    results[output_html_table] = get_metric_data(
                        server_ip, metric_port)
                return results
            except Exception as e:
                return {status_text: 'Error: {}'.format(e)}

        def update_metric(server_ip, metrics_port):
            """Refresh the metrics table from the server's metrics port."""
            if not (server_ip and metrics_port):
                return {
                    status_text: 'Please input server ip and metrics_port.'
                }
            try:
                html_table = get_metric_data(server_ip, metrics_port)
            except Exception as e:
                return {status_text: 'Error: {}'.format(e)}
            return {
                output_html_table: html_table,
                status_text: "Successfully update metrics."
            }

        # Every callback returns a dict keyed by component, so each component
        # it may update has to be listed in the matching ``outputs``.
        check_button.click(
            fn=get_input_output_name,
            inputs=[
                server_addr_text, server_http_port_text, model_name_text,
                model_version_text
            ],
            outputs=[
                *all_input_output_components, check_button,
                component_format_column, status_text
            ])
        component_submit_button.click(
            fn=component_inference,
            inputs=[
                server_addr_text, server_http_port_text,
                server_metric_port_text, model_name_text, model_version_text,
                *input_name_texts, *input_images, *input_texts, task_radio
            ],
            outputs=[
                *output_name_texts, *output_images, *output_texts, status_text,
                output_html_table
            ])
        raw_submit_button.click(
            fn=raw_inference,
            inputs=[
                server_addr_text, server_http_port_text,
                server_metric_port_text, model_name_text, model_version_text,
                raw_payload_text
            ],
            outputs=[output_raw_text, status_text, output_html_table])
        update_metric_button.click(
            fn=update_metric,
            inputs=[server_addr_text, server_metric_port_text],
            outputs=[output_html_table, status_text])
    return block