Unverified commit f8aa0f06 authored by 走神的阿圆, committed by GitHub

Add python interface for starting visualdl service. (#607)

* Support starting vdl in python.

* Fix null pointer error.

* Add introduction to python startup method in README.md.

* Fix blank picture caused by null scalar response.
Parent c9f70b16
...@@ -75,7 +75,7 @@ VisualDL's graph supports displaying paddle programs and is also compatible with ONNX (Open Neural Ne
Use the following commands to quickly try out VisualDL.
```
# Install; using a virtual environment or anaconda is recommended.
pip install --upgrade visualdl
# Run an example; vdl_create_scratch_log will create a test log
...@@ -230,17 +230,24 @@ int main() {
```
## Launch the Board
Once log data has been generated during training, you can launch the board to preview the visualized information in real time.
### Launch from the command line
```
visualdl --logdir <some log dir>
```
The board also supports a few parameters for remote access:
- `--host` set the IP
- `--port` set the port
- `-m / --model_pb` specify a model file in ONNX format
### Launch from a Python script
```python
>>> from visualdl.server import app
>>> app.run(logdir="SOME_LOG_DIR")
```
`app.run()` supports all of the command-line startup parameters; in addition, you can pass `open_browser=True` to open the browser automatically.
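For reference, a minimal sketch with explicit options (the keyword names are taken from the `run()` signature added in this commit; the log directory and other values below are placeholders):

```python
from visualdl.server import app

# run() starts the server in a child process and returns that process's pid,
# so the calling script is not blocked while the board is serving.
pid = app.run(
    logdir="./vdl_log",   # placeholder: directory containing VisualDL logs
    host="127.0.0.1",
    port=8080,
    cache_timeout=20,     # seconds before the in-memory response cache expires
    language="en",
    open_browser=True)    # open the default browser once the server responds
```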
### Contribute
VisualDL is developed by [PaddlePaddle](http://www.paddlepaddle.org/) and
......
...@@ -242,7 +242,7 @@ int main() {
## Launch Visual DL
After some logs have been generated during training, users can launch the Visual DL application to see real-time data visualization by:
### Startup from the command line
```
visualdl --logdir <some log dir>
...@@ -254,6 +254,13 @@ visualDL also supports the following optional parameters:
- `--port` set port
- `-m / --model_pb` specify an ONNX-format model file to view the graph
### Startup from a Python script
```python
>>> from visualdl.server import app
>>> app.run(logdir="SOME_LOG_DIR")
```
`app.run()` supports all of the parameters available for command-line startup; in addition, you can specify `open_browser=True` to open the browser automatically.
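Note that, per the implementation added in this commit, `app.run()` starts the board in a separate process and returns that process's pid, so the call does not block the rest of your script.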
### Contribute
......
...@@ -32,7 +32,7 @@ VisualDL has three feature modules.
Any change to the code in the ```server``` folder can be picked up by running the following command
```
python visualdl/server/app.py --logdir={LOG_DIR} --port=8080
```
to restart the Flask server
......
...@@ -32,7 +32,7 @@ All backend and sdk logic is under the visualdl sub directory
For any code change in the ```server``` folder, simply run
```
python visualdl/server/app.py --logdir={LOG_DIR} --port=8080
```
to restart the Flask server
......
...@@ -58,4 +58,4 @@ cd $CURRENT_DIR
echo "Development server ready on http://$HOST:$FRONTEND_PORT"
# Run the visualDL with local PATH
python ${SCRIPT_DIR}/../visualdl/server/app.py "$ORIGINAL_ARGS"
...@@ -103,9 +103,9 @@ packages = [
libraries = ['core.so']
if platform == 'win32':
libraries = ['core.pyd', 'libprotobuf.dll', 'zlib.dll']
scripts = ['visualdl/server/app.py', 'demo/vdl_create_scratch_log']
if platform == 'win32':
scripts.append('visualdl/server/visualDL.bat')
...@@ -128,4 +128,5 @@ setup(
packages=packages,
ext_modules=[Extension('_foo', ['stub.cc'])],
scripts=scripts,
cmdclass=cmdclass,
entry_points={'console_scripts': ['visualdl=visualdl.server.app:main']})
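With the added `console_scripts` entry point, installing the package also generates a `visualdl` command that dispatches to `visualdl.server.app:main`, so the board can be launched without referring to the script path directly.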
#!/usr/bin/env python
# Copyright (c) 2017 VisualDL Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================
import json
import os
import time
import sys
import multiprocessing
import threading
import re
import webbrowser
import requests
from argparse import ArgumentParser
from flask import (Flask, Response, redirect, request, send_file,
send_from_directory)
import visualdl
import visualdl.server
from visualdl.server import lib
from visualdl.server.log import logger
from visualdl.server.mock import data as mock_data
from visualdl.server.mock import data as mock_tags
from visualdl.python.cache import MemCache
from visualdl.python.storage import (LogWriter, LogReader)
try:
import exceptions
except:
pass
error_retry_times = 3
error_sleep_time = 2 # seconds
SERVER_DIR = os.path.join(visualdl.ROOT, 'server')
support_language = ["en", "zh"]
default_language = support_language[0]
server_path = os.path.abspath(os.path.dirname(sys.argv[0]))
static_file_path = os.path.join(SERVER_DIR, "./dist")
mock_data_path = os.path.join(SERVER_DIR, "./mock_data/")
class ParseArgs(object):
def __init__(self, logdir, host="0.0.0.0", port=8040, model_pb="", cache_timeout=20, language=default_language):
self.logdir = logdir
self.host = host
self.port = port
self.model_pb = model_pb
self.cache_timeout = cache_timeout
self.language = language
def try_call(function, *args, **kwargs):
res = lib.retry(error_retry_times, function, error_sleep_time, *args,
**kwargs)
if not res:
logger.error("server temporary error, will retry latter.")
return res
def parse_args():
"""
:return:
"""
parser = ArgumentParser(
description="VisualDL, a tool to visualize deep learning.")
parser.add_argument(
"-p",
"--port",
type=int,
default=8040,
action="store",
dest="port",
help="api service port")
parser.add_argument(
"-t",
"--host",
type=str,
default="0.0.0.0",
action="store",
help="api service ip")
parser.add_argument(
"-m",
"--model_pb",
type=str,
action="store",
help="model proto in ONNX format or in Paddle framework format")
parser.add_argument(
"--logdir",
required=True,
action="store",
dest="logdir",
help="log file directory")
parser.add_argument(
"--cache_timeout",
action="store",
dest="cache_timeout",
type=float,
default=20,
help="memory cache timeout duration in seconds, default 20",
)
parser.add_argument(
"-L",
"--language",
type=str,
default=default_language,
action="store",
help="set the default language")
args = parser.parse_args()
if not args.logdir:
parser.print_help()
sys.exit(-1)
return args
# status, msg, data
def gen_result(status, msg, data):
"""
:param status:
:param msg:
:return:
"""
result = dict()
result['status'] = status
result['msg'] = msg
result['data'] = data
return result
def create_app(args):
app = Flask(__name__, static_url_path="")
# set static expires in a short time to reduce browser's memory usage.
app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 30
log_reader = LogReader(args.logdir)
# manually putting the graph's image at this path also works.
graph_image_path = os.path.join(args.logdir, 'graph.jpg')
# use a memory cache to reduce disk reading frequency.
CACHE = MemCache(timeout=args.cache_timeout)
cache_get = lib.cache_get(CACHE)
@app.route("/")
def index():
language = args.language
if not language in support_language:
language = default_language
if language == default_language:
return redirect('/app/index', code=302)
return redirect('/app/' + language + '/index', code=302)
@app.route('/app/<path:filename>')
def serve_static(filename):
return send_from_directory(
os.path.join(server_path, static_file_path), filename if re.search(r'\..+$', filename) else filename + '.html')
@app.route('/graphs/image')
def serve_graph():
return send_file(os.path.join(os.getcwd(), graph_image_path))
@app.route('/api/logdir')
def logdir():
result = gen_result(0, "", {"logdir": args.logdir})
return Response(json.dumps(result), mimetype='application/json')
@app.route('/api/runs')
def runs():
data = cache_get('/data/runs', lib.get_modes, log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route('/api/language')
def language():
data = args.language
if not data in support_language:
data = default_language
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route("/api/scalars/tags")
def scalar_tags():
data = cache_get("/data/plugin/scalars/tags", try_call,
lib.get_scalar_tags, log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route("/api/images/tags")
def image_tags():
data = cache_get("/data/plugin/images/tags", try_call, lib.get_image_tags,
log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route("/api/audio/tags")
def audio_tags():
data = cache_get("/data/plugin/audio/tags", try_call, lib.get_audio_tags,
log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route("/api/histograms/tags")
def histogram_tags():
data = cache_get("/data/plugin/histograms/tags", try_call,
lib.get_histogram_tags, log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route("/api/texts/tags")
def texts_tags():
data = cache_get("/data/plugin/texts/tags", try_call,
lib.get_texts_tags, log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route('/api/scalars/list')
def scalars():
run = request.args.get('run')
tag = request.args.get('tag')
key = os.path.join('/data/plugin/scalars/scalars', run, tag)
data = cache_get(key, try_call, lib.get_scalar, log_reader, run, tag)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route('/api/images/list')
def images():
mode = request.args.get('run')
tag = request.args.get('tag')
key = os.path.join('/data/plugin/images/images', mode, tag)
data = cache_get(key, try_call, lib.get_image_tag_steps, log_reader, mode,
tag)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route('/api/images/image')
def individual_image():
mode = request.args.get('run')
tag = request.args.get('tag') # include a index
step_index = int(request.args.get('index')) # index of step
key = os.path.join('/data/plugin/images/individualImage', mode, tag,
str(step_index))
data = cache_get(key, try_call, lib.get_invididual_image, log_reader, mode,
tag, step_index)
response = send_file(
data, as_attachment=True, attachment_filename='img.png')
return response
@app.route('/api/histograms/list')
def histogram():
run = request.args.get('run')
tag = request.args.get('tag')
key = os.path.join('/data/plugin/histograms/histograms', run, tag)
data = cache_get(key, try_call, lib.get_histogram, log_reader, run, tag)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route('/api/texts/list')
def texts():
run = request.args.get('run')
tag = request.args.get('tag')
key = os.path.join('/data/plugin/texts/texts', run, tag)
data = cache_get(key, try_call, lib.get_texts, log_reader, run, tag)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route('/api/embeddings/embedding')
def embeddings():
run = request.args.get('run')
dimension = request.args.get('dimension')
reduction = request.args.get('reduction')
key = os.path.join('/data/plugin/embeddings/embeddings', run, dimension, reduction)
data = cache_get(key, try_call, lib.get_embeddings, log_reader, run, reduction, int(dimension))
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route('/api/audio/list')
def audio():
mode = request.args.get('run')
tag = request.args.get('tag')
key = os.path.join('/data/plugin/audio/audio', mode, tag)
data = cache_get(key, try_call, lib.get_audio_tag_steps, log_reader, mode,
tag)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route('/api/audio/audio')
def individual_audio():
mode = request.args.get('run')
tag = request.args.get('tag') # include a index
step_index = int(request.args.get('index')) # index of step
key = os.path.join('/data/plugin/audio/individualAudio', mode, tag,
str(step_index))
data = cache_get(key, try_call, lib.get_individual_audio, log_reader, mode,
tag, step_index)
response = send_file(
data, as_attachment=True, attachment_filename='audio.wav')
return response
return app
def _open_browser(app, index_url):
while True:
try:
requests.get(index_url)
break
except Exception as e:
time.sleep(0.5)
webbrowser.open(index_url)
def _run(logdir, host="127.0.0.1", port=8080, model_pb="", cache_timeout=20, language=default_language, open_browser=False):
args = ParseArgs(logdir=logdir, host=host, port=port, model_pb=model_pb, cache_timeout=cache_timeout, language=language)
logger.info(" port=" + str(args.port))
app = create_app(args)
index_url = "http://" + host + ":" + str(port)
if open_browser:
threading.Thread(target=_open_browser, kwargs={"app": app, "index_url": index_url}).start()
app.run(debug=False, host=args.host, port=args.port, threaded=True)
def run(logdir, host="127.0.0.1", port=8080, model_pb="", cache_timeout=20, language=default_language, open_browser=False):
kwarg = {
"logdir": logdir,
"host": host,
"port": port,
"model_pb": model_pb,
"cache_timeout": cache_timeout,
"language": language,
"open_browser": open_browser
}
p = multiprocessing.Process(target=_run, kwargs=kwarg)
p.start()
return p.pid
def main():
args = parse_args()
logger.info(" port=" + str(args.port))
app = create_app(args)
app.run(debug=False, host=args.host, port=args.port, threaded=True)
...@@ -48,8 +48,8 @@ class Reservoir(object):
raise ValueError("Max_size must be nonnegative integer.")
self._max_size = max_size
self._buckets = collections.defaultdict(
lambda: _ReservoirBucket(max_size=self._max_size,
random_instance=random.Random(seed))
)
self._mutex = threading.Lock()
...@@ -117,7 +117,7 @@ class Reservoir(object):
def get_items(self, mode, tag):
"""Get items with tag 'mode_tag'
For VisualDL's usage pattern, this actually calls self._get_items().
Args:
...@@ -162,6 +162,23 @@ class Reservoir(object):
key = mode + "_" + tag
self._add_item(key, item)
def _cut_tail(self, key):
with self._mutex:
self._buckets[key].cut_tail()
def cut_tail(self, mode, tag):
"""Pop the last item in reservoir buckets.
Sometimes the tail of the retrieved data is abnormal 0. This
method is used to handle this problem.
Args:
mode: Identity of one tablet.
tag: Identity of one record in tablet.
"""
key = mode + "_" + tag
self._cut_tail(key)
class _ReservoirBucket(object):
"""Data manager for sampling data, use reservoir sampling.
...@@ -222,6 +239,16 @@ class _ReservoirBucket(object):
with self._mutex:
return self._num_items_index
def cut_tail(self):
"""Pop the last item in reservoir buckets.
Sometimes the tail of the retrieved data is abnormal 0. This
method is used to handle this problem.
"""
with self._mutex:
self._items.pop()
self._num_items_index -= 1
class DataManager(object):
"""Data manager for all plugins.
...@@ -238,8 +265,8 @@ class DataManager(object):
self._image_reservoir = Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["image"])
self._reservoirs = {"scalar": self._scalar_reservoir,
"histogram": self._histogram_reservoir,
"image": self._image_reservoir}
self._mutex = threading.Lock()
def get_reservoir(self, plugin):
...@@ -299,7 +326,6 @@ if __name__ == '__main__':
print(b)
c = d.get_reservoir("scalar").get_num_items_index("train", "loss")
print(c)
print("***")
b = d.get_reservoir("scalar").get_items("train", "accu")
print(b)
print(d.get_reservoir("scalar").get_num_items_index("train", "accu"))
...@@ -312,4 +338,4 @@ if __name__ == '__main__':
print(c)
print(d.get_reservoir("scalar").exist_in_keys("train", "loss"))
print(d.get_reservoir("scalar").exist_in_keys("train", "loss2"))
\ No newline at end of file
...@@ -65,15 +65,25 @@ def get_scalar(storage, mode, tag, num_records=300):
else:
num_items_index = data_reservoir.get_num_items_index(mode, tag)
if num_items_index != scalar.size():
try:
records = scalar.records(num_items_index)
ids = scalar.ids(num_items_index)
timestamps = scalar.timestamps(num_items_index)
except Exception:
error_info = '\n'.join(map(str, sys.exc_info()))
logger.error("Unexpected error: %s" % error_info)
return data_reservoir.get_items(mode, tag)
data = list(zip(timestamps, ids, records))
for index in range(len(data)):
data_reservoir.add_item(mode=mode, tag=tag, item=data[index])
results = data_reservoir.get_items(mode, tag)
# TODO(Superjomn) some bug here, sometimes there are zero here.
if results[-1][-1] == 0:
data_reservoir.cut_tail(mode=mode, tag=tag)
results = data_reservoir.get_items(mode, tag)
return results
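(This is where the `cut_tail()` method added above is consumed: when the newest scalar record reads back as 0, the blank-chart symptom mentioned in the commit message, the tail item is dropped before the data is returned.)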
def get_image_tags(storage):
......
#!/usr/bin/env python
# Copyright (c) 2017 VisualDL Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================
import json
import os
import time
import sys
import re
from argparse import ArgumentParser
from flask import (Flask, Response, redirect, request, send_file,
send_from_directory)
from flask_babel import Babel
import visualdl
import visualdl.server
import visualdl.server.graph as vdl_graph
import visualdl.server.model as model
from visualdl.server import lib
from visualdl.server.log import logger
from visualdl.server.mock import data as mock_data
from visualdl.server.mock import data as mock_tags
from visualdl.python.cache import MemCache
from visualdl.python.storage import (LogWriter, LogReader)
try:
import exceptions
except:
pass
support_language = ["en", "zh"]
default_language = support_language[0]
app = Flask(__name__, static_url_path="")
# set static expires in a short time to reduce browser's memory usage.
app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 30
app.config['BABEL_DEFAULT_LOCALE'] = default_language
babel = Babel(app)
error_retry_times = 3
error_sleep_time = 2 # seconds
SERVER_DIR = os.path.join(visualdl.ROOT, 'server')
def try_call(function, *args, **kwargs):
res = lib.retry(error_retry_times, function, error_sleep_time, *args,
**kwargs)
if not res:
logger.error("Internal server error. Retry later.")
return res
def parse_args():
"""
:return:
"""
parser = ArgumentParser(
description="VisualDL, a tool to visualize deep learning.")
parser.add_argument(
"-p",
"--port",
type=int,
default=8040,
action="store",
dest="port",
help="api service port")
parser.add_argument(
"-t",
"--host",
type=str,
default="0.0.0.0",
action="store",
help="api service ip")
parser.add_argument(
"-m",
"--model_pb",
type=str,
action="store",
help="model proto in ONNX format or in Paddle framework format")
parser.add_argument(
"--logdir",
required=True,
action="store",
dest="logdir",
help="log file directory")
parser.add_argument(
"--cache_timeout",
action="store",
dest="cache_timeout",
type=float,
default=20,
help="memory cache timeout duration in seconds, default 20",
)
parser.add_argument(
"-L",
"--language",
type=str,
action="store",
help="set the default language")
args = parser.parse_args()
if not args.logdir:
parser.print_help()
sys.exit(-1)
return args
args = parse_args()
server_path = os.path.abspath(os.path.dirname(sys.argv[0]))
static_file_path = os.path.join(SERVER_DIR, "./dist")
mock_data_path = os.path.join(SERVER_DIR, "./mock_data/")
log_reader = LogReader(args.logdir)
# manually putting the graph's image at this path also works.
graph_image_path = os.path.join(args.logdir, 'graph.jpg')
# use a memory cache to reduce disk reading frequency.
CACHE = MemCache(timeout=args.cache_timeout)
cache_get = lib.cache_get(CACHE)
# status, msg, data
def gen_result(status, msg, data):
"""
:param status:
:param msg:
:return:
"""
result = dict()
result['status'] = status
result['msg'] = msg
result['data'] = data
return result
@babel.localeselector
def get_locale():
language = args.language
if not language or not language in support_language:
language = request.accept_languages.best_match(support_language)
return language
@app.route("/")
def index():
language = get_locale()
if language == default_language:
return redirect('/app/index', code=302)
return redirect('/app/' + language + '/index', code=302)
@app.route('/app/<path:filename>')
def serve_static(filename):
return send_from_directory(
os.path.join(server_path, static_file_path), filename if re.search(r'\..+$', filename) else filename + '.html')
@app.route('/graphs/image')
def serve_graph():
return send_file(os.path.join(os.getcwd(), graph_image_path))
@app.route('/api/logdir')
def logdir():
result = gen_result(0, "", {"logdir": args.logdir})
return Response(json.dumps(result), mimetype='application/json')
@app.route('/api/runs')
def runs():
data = cache_get('/data/runs', lib.get_modes, log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route('/api/language')
def language():
data = get_locale()
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route("/api/scalars/tags")
def scalar_tags():
data = cache_get("/data/plugin/scalars/tags", try_call,
lib.get_scalar_tags, log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route("/api/images/tags")
def image_tags():
data = cache_get("/data/plugin/images/tags", try_call, lib.get_image_tags,
log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route("/api/audio/tags")
def audio_tags():
data = cache_get("/data/plugin/audio/tags", try_call, lib.get_audio_tags,
log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route("/api/histograms/tags")
def histogram_tags():
data = cache_get("/data/plugin/histograms/tags", try_call,
lib.get_histogram_tags, log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route("/api/texts/tags")
def texts_tags():
data = cache_get("/data/plugin/texts/tags", try_call,
lib.get_texts_tags, log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route('/api/scalars/list')
def scalars():
run = request.args.get('run')
tag = request.args.get('tag')
key = os.path.join('/data/plugin/scalars/scalars', run, tag)
data = cache_get(key, try_call, lib.get_scalar, log_reader, run, tag)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route('/api/images/list')
def images():
mode = request.args.get('run')
tag = request.args.get('tag')
key = os.path.join('/data/plugin/images/images', mode, tag)
data = cache_get(key, try_call, lib.get_image_tag_steps, log_reader, mode,
tag)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route('/api/images/image')
def individual_image():
mode = request.args.get('run')
tag = request.args.get('tag') # include a index
step_index = int(request.args.get('index')) # index of step
key = os.path.join('/data/plugin/images/individualImage', mode, tag,
str(step_index))
data = cache_get(key, try_call, lib.get_invididual_image, log_reader, mode,
tag, step_index)
response = send_file(
data, as_attachment=True, attachment_filename='img.png')
return response
@app.route('/api/histograms/list')
def histogram():
run = request.args.get('run')
tag = request.args.get('tag')
key = os.path.join('/data/plugin/histograms/histograms', run, tag)
data = cache_get(key, try_call, lib.get_histogram, log_reader, run, tag)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route('/api/texts/list')
def texts():
run = request.args.get('run')
tag = request.args.get('tag')
key = os.path.join('/data/plugin/texts/texts', run, tag)
data = cache_get(key, try_call, lib.get_texts, log_reader, run, tag)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route('/api/embeddings/embedding')
def embeddings():
run = request.args.get('run')
dimension = request.args.get('dimension')
reduction = request.args.get('reduction')
key = os.path.join('/data/plugin/embeddings/embeddings', run, dimension, reduction)
data = cache_get(key, try_call, lib.get_embeddings, log_reader, run, reduction, int(dimension))
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route('/api/audio/list')
def audio():
mode = request.args.get('run')
tag = request.args.get('tag')
key = os.path.join('/data/plugin/audio/audio', mode, tag)
data = cache_get(key, try_call, lib.get_audio_tag_steps, log_reader, mode,
tag)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route('/api/audio/audio')
def individual_audio():
mode = request.args.get('run')
tag = request.args.get('tag') # include a index
step_index = int(request.args.get('index')) # index of step
key = os.path.join('/data/plugin/audio/individualAudio', mode, tag,
str(step_index))
data = cache_get(key, try_call, lib.get_individual_audio, log_reader, mode,
tag, step_index)
response = send_file(
data, as_attachment=True, attachment_filename='audio.wav')
return response
@app.route('/api/graphs/graph')
def graph():
if model.is_onnx_model(args.model_pb):
json_str = vdl_graph.draw_onnx_graph(args.model_pb)
elif model.is_paddle_model(args.model_pb):
json_str = vdl_graph.draw_paddle_graph(args.model_pb)
else:
json_str = {}
data = {'data': json_str}
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
if __name__ == '__main__':
logger.info(" port=" + str(args.port))
app.run(debug=False, host=args.host, port=args.port, threaded=True)