From f8aa0f06d7c5562d531bc97ee8d0fcb307197acb Mon Sep 17 00:00:00 2001 From: ShenYuhan Date: Thu, 26 Mar 2020 10:31:16 +0800 Subject: [PATCH] Add python interface for starting visualdl service . (#607) * * Support starting vdl in python. * Fix null point error. * Add introduction to python startup method in README.md. * Fix blank picture cause of scalar null response. --- README.cn.md | 11 +- README.md | 9 +- docs/develop/how_to_dev_backend_cn.md | 2 +- docs/develop/how_to_dev_backend_en.md | 2 +- scripts/start_dev_server.sh | 2 +- setup.py | 7 +- visualdl/server/app.py | 364 ++++++++++++++++++++++++++ visualdl/server/data_manager.py | 40 ++- visualdl/server/lib.py | 18 +- visualdl/server/visualdl | 352 ------------------------- 10 files changed, 435 insertions(+), 372 deletions(-) create mode 100644 visualdl/server/app.py delete mode 100644 visualdl/server/visualdl diff --git a/README.cn.md b/README.cn.md index c859ea7f..ac51c01d 100644 --- a/README.cn.md +++ b/README.cn.md @@ -75,7 +75,7 @@ VisualDL的graph支持paddle program的展示,同时兼容 ONNX(Open Neural Ne 请使用下面的命令,来快速测试 VisualDL。 ``` -# 安装,建議是在虚拟环境或anaconda下。 +# 安装,建议是在虚拟环境或anaconda下。 pip install --upgrade visualdl # 运行一个例子,vdl_create_scratch_log 将创建测试日志 @@ -230,17 +230,24 @@ int main() { ``` ## 启动Board 当训练过程中已经产生了日志数据,就可以启动board进行实时预览可视化信息 +### 在命令行中启动 ``` visualdl --logdir ``` -board 还支持一下参数来实现远程的访问: +board 还支持一些参数来实现远程的访问: - `--host` 设定IP - `--port` 设定端口 - `-m / --model_pb` 指定 ONNX 格式的模型文件 +### 在Python脚本中启动 +```python +>>> from visualdl.server import app +>>> app.run(logdir="SOME_LOG_DIR") +``` +`app.run()`支持命令行启动的所有参数,除此之外,还可以通过指定`open_browser=True`,自动打开浏览器。 ### 贡献 VisualDL 是由 [PaddlePaddle](http://www.paddlepaddle.org/) 和 diff --git a/README.md b/README.md index 419c6cde..f7363a67 100644 --- a/README.md +++ b/README.md @@ -242,7 +242,7 @@ int main() { ## Launch Visual DL After some logs have been generated during training, users can launch Visual DL application to see real-time data visualization by: - +### Startup in command line ``` visualdl --logdir @@ -254,6 +254,13 @@ visualDL also supports following optional parameters: - `--port` set port - `-m / --model_pb` specify ONNX format for model file to view graph +### Startup in python script +```python +>>> from visualdl.server import app + +>>> app.run(logdir="SOME_LOG_DIR") +``` +`app.run()` support all parameters for command line startup, in addition, you can also specify `pen_browser=True` to open browser automatically。 ### Contribute diff --git a/docs/develop/how_to_dev_backend_cn.md b/docs/develop/how_to_dev_backend_cn.md index 7bd536f2..a589ba33 100644 --- a/docs/develop/how_to_dev_backend_cn.md +++ b/docs/develop/how_to_dev_backend_cn.md @@ -32,7 +32,7 @@ VisualDL有三个功能模块. 任何在 ```server``` 文件夹里代码的改动,都可以通过运行以下命令 ``` -python visualdl/server/visualdl --logdir={LOG_DIR} --port=8080 +python visualdl/server/app.py --logdir={LOG_DIR} --port=8080 ``` 来重启 Flask 服务器 diff --git a/docs/develop/how_to_dev_backend_en.md b/docs/develop/how_to_dev_backend_en.md index cf41993e..03ac4b92 100644 --- a/docs/develop/how_to_dev_backend_en.md +++ b/docs/develop/how_to_dev_backend_en.md @@ -32,7 +32,7 @@ All backend and sdk logic is under visualdl sub directory Any code changes in ```server``` folder, simply run ``` -python visualdl/server/visualdl --logdir={LOG_DIR} --port=8080 +python visualdl/server/app.py --logdir={LOG_DIR} --port=8080 ``` to restart flask server diff --git a/scripts/start_dev_server.sh b/scripts/start_dev_server.sh index f8e12729..1af126a7 100755 --- a/scripts/start_dev_server.sh +++ b/scripts/start_dev_server.sh @@ -58,4 +58,4 @@ cd $CURRENT_DIR echo "Development server ready on http://$HOST:$FRONTEND_PORT" # Run the visualDL with local PATH -python ${SCRIPT_DIR}/../visualdl/server/visualdl "$ORIGINAL_ARGS" +python ${SCRIPT_DIR}/../visualdl/server/app.py "$ORIGINAL_ARGS" diff --git a/setup.py b/setup.py index acc225bc..b547d149 100644 --- a/setup.py +++ b/setup.py @@ -103,9 +103,9 @@ packages = [ libraries = ['core.so'] if platform == 'win32': - libraries = ['core.pyd', 'libprotobuf.dll'] + libraries = ['core.pyd', 'libprotobuf.dll', 'zlib.dll'] -scripts = ['visualdl/server/visualdl', 'demo/vdl_create_scratch_log'] +scripts = ['visualdl/server/app.py', 'demo/vdl_create_scratch_log'] if platform == 'win32': scripts.append('visualdl/server/visualDL.bat') @@ -128,4 +128,5 @@ setup( packages=packages, ext_modules=[Extension('_foo', ['stub.cc'])], scripts=scripts, - cmdclass=cmdclass) + cmdclass=cmdclass, + entry_points={'console_scripts': ['visualdl=visualdl.server.app:main']}) diff --git a/visualdl/server/app.py b/visualdl/server/app.py new file mode 100644 index 00000000..5ea3f389 --- /dev/null +++ b/visualdl/server/app.py @@ -0,0 +1,364 @@ +#!/user/bin/env python + +# Copyright (c) 2017 VisualDL Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ======================================================================= + +import json +import os +import time +import sys +import multiprocessing +import threading +import re +import webbrowser +import requests +from argparse import ArgumentParser + +from flask import (Flask, Response, redirect, request, send_file, + send_from_directory) + +import visualdl +import visualdl.server +from visualdl.server import lib +from visualdl.server.log import logger +from visualdl.server.mock import data as mock_data +from visualdl.server.mock import data as mock_tags +from visualdl.python.cache import MemCache +from visualdl.python.storage import (LogWriter, LogReader) + +try: + import exceptions +except: + pass + +error_retry_times = 3 +error_sleep_time = 2 # seconds + +SERVER_DIR = os.path.join(visualdl.ROOT, 'server') + +support_language = ["en", "zh"] +default_language = support_language[0] + +server_path = os.path.abspath(os.path.dirname(sys.argv[0])) +static_file_path = os.path.join(SERVER_DIR, "./dist") +mock_data_path = os.path.join(SERVER_DIR, "./mock_data/") + + +class ParseArgs(object): + def __init__(self, logdir, host="0.0.0.0", port=8040, model_pb="", cache_timeout=20, language=default_language): + self.logdir = logdir + self.host = host + self.port = port + self.model_pb = model_pb + self.cache_timeout = cache_timeout + self.language = language + + +def try_call(function, *args, **kwargs): + res = lib.retry(error_retry_times, function, error_sleep_time, *args, + **kwargs) + if not res: + logger.error("server temporary error, will retry latter.") + return res + + +def parse_args(): + """ + :return: + """ + parser = ArgumentParser( + description="VisualDL, a tool to visualize deep learning.") + parser.add_argument( + "-p", + "--port", + type=int, + default=8040, + action="store", + dest="port", + help="api service port") + parser.add_argument( + "-t", + "--host", + type=str, + default="0.0.0.0", + action="store", + help="api service ip") + parser.add_argument( + "-m", + "--model_pb", + type=str, + action="store", + help="model proto in ONNX format or in Paddle framework format") + parser.add_argument( + "--logdir", + required=True, + action="store", + dest="logdir", + help="log file directory") + parser.add_argument( + "--cache_timeout", + action="store", + dest="cache_timeout", + type=float, + default=20, + help="memory cache timeout duration in seconds, default 20", + ) + parser.add_argument( + "-L", + "--language", + type=str, + default=default_language, + action="store", + help="set the default language") + + args = parser.parse_args() + if not args.logdir: + parser.print_help() + sys.exit(-1) + return args + + +# status, msg, data +def gen_result(status, msg, data): + """ + :param status: + :param msg: + :return: + """ + result = dict() + result['status'] = status + result['msg'] = msg + result['data'] = data + return result + + +def create_app(args): + app = Flask(__name__, static_url_path="") + # set static expires in a short time to reduce browser's memory usage. + app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 30 + + log_reader = LogReader(args.logdir) + + # mannully put graph's image on this path also works. + graph_image_path = os.path.join(args.logdir, 'graph.jpg') + # use a memory cache to reduce disk reading frequency. + CACHE = MemCache(timeout=args.cache_timeout) + cache_get = lib.cache_get(CACHE) + + @app.route("/") + def index(): + language = args.language + if not language in support_language: + language = default_language + if language == default_language: + return redirect('/app/index', code=302) + return redirect('/app/' + language + '/index', code=302) + + @app.route('/app/') + def serve_static(filename): + return send_from_directory( + os.path.join(server_path, static_file_path), filename if re.search(r'\..+$', filename) else filename + '.html') + + @app.route('/graphs/image') + def serve_graph(): + return send_file(os.path.join(os.getcwd(), graph_image_path)) + + @app.route('/api/logdir') + def logdir(): + result = gen_result(0, "", {"logdir": args.logdir}) + return Response(json.dumps(result), mimetype='application/json') + + @app.route('/api/runs') + def runs(): + data = cache_get('/data/runs', lib.get_modes, log_reader) + result = gen_result(0, "", data) + return Response(json.dumps(result), mimetype='application/json') + + @app.route('/api/language') + def language(): + data = args.language + if not data in support_language: + data = default_language + result = gen_result(0, "", data) + return Response(json.dumps(result), mimetype='application/json') + + @app.route("/api/scalars/tags") + def scalar_tags(): + data = cache_get("/data/plugin/scalars/tags", try_call, + lib.get_scalar_tags, log_reader) + result = gen_result(0, "", data) + return Response(json.dumps(result), mimetype='application/json') + + @app.route("/api/images/tags") + def image_tags(): + data = cache_get("/data/plugin/images/tags", try_call, lib.get_image_tags, + log_reader) + result = gen_result(0, "", data) + return Response(json.dumps(result), mimetype='application/json') + + @app.route("/api/audio/tags") + def audio_tags(): + data = cache_get("/data/plugin/audio/tags", try_call, lib.get_audio_tags, + log_reader) + result = gen_result(0, "", data) + return Response(json.dumps(result), mimetype='application/json') + + @app.route("/api/histograms/tags") + def histogram_tags(): + data = cache_get("/data/plugin/histograms/tags", try_call, + lib.get_histogram_tags, log_reader) + result = gen_result(0, "", data) + return Response(json.dumps(result), mimetype='application/json') + + @app.route("/api/texts/tags") + def texts_tags(): + data = cache_get("/data/plugin/texts/tags", try_call, + lib.get_texts_tags, log_reader) + result = gen_result(0, "", data) + return Response(json.dumps(result), mimetype='application/json') + + @app.route('/api/scalars/list') + def scalars(): + run = request.args.get('run') + tag = request.args.get('tag') + key = os.path.join('/data/plugin/scalars/scalars', run, tag) + data = cache_get(key, try_call, lib.get_scalar, log_reader, run, tag) + result = gen_result(0, "", data) + return Response(json.dumps(result), mimetype='application/json') + + @app.route('/api/images/list') + def images(): + mode = request.args.get('run') + tag = request.args.get('tag') + key = os.path.join('/data/plugin/images/images', mode, tag) + + data = cache_get(key, try_call, lib.get_image_tag_steps, log_reader, mode, + tag) + result = gen_result(0, "", data) + + return Response(json.dumps(result), mimetype='application/json') + + @app.route('/api/images/image') + def individual_image(): + mode = request.args.get('run') + tag = request.args.get('tag') # include a index + step_index = int(request.args.get('index')) # index of step + + key = os.path.join('/data/plugin/images/individualImage', mode, tag, + str(step_index)) + data = cache_get(key, try_call, lib.get_invididual_image, log_reader, mode, + tag, step_index) + response = send_file( + data, as_attachment=True, attachment_filename='img.png') + return response + + @app.route('/api/histograms/list') + def histogram(): + run = request.args.get('run') + tag = request.args.get('tag') + key = os.path.join('/data/plugin/histograms/histograms', run, tag) + data = cache_get(key, try_call, lib.get_histogram, log_reader, run, tag) + result = gen_result(0, "", data) + return Response(json.dumps(result), mimetype='application/json') + + @app.route('/api/texts/list') + def texts(): + run = request.args.get('run') + tag = request.args.get('tag') + key = os.path.join('/data/plugin/texts/texts', run, tag) + data = cache_get(key, try_call, lib.get_texts, log_reader, run, tag) + result = gen_result(0, "", data) + return Response(json.dumps(result), mimetype='application/json') + + @app.route('/api/embeddings/embedding') + def embeddings(): + run = request.args.get('run') + dimension = request.args.get('dimension') + reduction = request.args.get('reduction') + key = os.path.join('/data/plugin/embeddings/embeddings', run, dimension, reduction) + data = cache_get(key, try_call, lib.get_embeddings, log_reader, run, reduction, int(dimension)) + result = gen_result(0, "", data) + return Response(json.dumps(result), mimetype='application/json') + + @app.route('/api/audio/list') + def audio(): + mode = request.args.get('run') + tag = request.args.get('tag') + key = os.path.join('/data/plugin/audio/audio', mode, tag) + + data = cache_get(key, try_call, lib.get_audio_tag_steps, log_reader, mode, + tag) + result = gen_result(0, "", data) + + return Response(json.dumps(result), mimetype='application/json') + + @app.route('/api/audio/audio') + def individual_audio(): + mode = request.args.get('run') + tag = request.args.get('tag') # include a index + step_index = int(request.args.get('index')) # index of step + + key = os.path.join('/data/plugin/audio/individualAudio', mode, tag, + str(step_index)) + data = cache_get(key, try_call, lib.get_individual_audio, log_reader, mode, + tag, step_index) + response = send_file( + data, as_attachment=True, attachment_filename='audio.wav') + return response + + return app + + +def _open_browser(app, index_url): + while True: + try: + requests.get(index_url) + break + except Exception as e: + time.sleep(0.5) + webbrowser.open(index_url) + + +def _run(logdir, host="127.0.0.1", port=8080, model_pb="", cache_timeout=20, language=default_language, open_browser=False): + args = ParseArgs(logdir=logdir, host=host, port=port, model_pb=model_pb, cache_timeout=cache_timeout, language=language) + logger.info(" port=" + str(args.port)) + app = create_app(args) + index_url = "http://" + host + ":" + str(port) + if open_browser: + threading.Thread(target=_open_browser, kwargs={"app": app, "index_url": index_url}).start() + app.run(debug=False, host=args.host, port=args.port, threaded=True) + + +def run(logdir, host="127.0.0.1", port=8080, model_pb="", cache_timeout=20, language=default_language, open_browser=False): + kwarg = { + "logdir": logdir, + "host": host, + "port": port, + "model_pb": model_pb, + "cache_timeout": cache_timeout, + "language": language, + "open_browser": open_browser + } + + p = multiprocessing.Process(target=_run, kwargs=kwarg) + p.start() + return p.pid + + +def main(): + args = parse_args() + logger.info(" port=" + str(args.port)) + app = create_app(args=parse_args()) + app.run(debug=False, host=args.host, port=args.port, threaded=True) diff --git a/visualdl/server/data_manager.py b/visualdl/server/data_manager.py index 426b2506..0ff48920 100644 --- a/visualdl/server/data_manager.py +++ b/visualdl/server/data_manager.py @@ -48,8 +48,8 @@ class Reservoir(object): raise ValueError("Max_size must be nonnegative integer.") self._max_size = max_size self._buckets = collections.defaultdict( - lambda : _ReservoirBucket(max_size=self._max_size, - random_instance=random.Random(seed)) + lambda: _ReservoirBucket(max_size=self._max_size, + random_instance=random.Random(seed)) ) self._mutex = threading.Lock() @@ -117,7 +117,7 @@ class Reservoir(object): def get_items(self, mode, tag): """Get items with tag 'mode_tag' - + For usage habits of VisualDL, actually call self._get_items() Args: @@ -162,6 +162,23 @@ class Reservoir(object): key = mode + "_" + tag self._add_item(key, item) + def _cut_tail(self, key): + with self._mutex: + self._buckets[key].cut_tail() + + def cut_tail(self, mode, tag): + """Pop the last item in reservoir buckets. + + Sometimes the tail of the retrieved data is abnormal 0. This + method is used to handle this problem. + + Args: + mode: Identity of one tablet. + tag: Identity of one record in tablet. + """ + key = mode + "_" + tag + self._cut_tail(key) + class _ReservoirBucket(object): """Data manager for sampling data, use reservoir sampling. @@ -222,6 +239,16 @@ class _ReservoirBucket(object): with self._mutex: return self._num_items_index + def cut_tail(self): + """Pop the last item in reservoir buckets. + + Sometimes the tail of the retrieved data is abnormal 0. This + method is used to handle this problem. + """ + with self._mutex: + self._items.pop() + self._num_items_index -= 1 + class DataManager(object): """Data manager for all plugin. @@ -238,8 +265,8 @@ class DataManager(object): self._image_reservoir = Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["image"]) self._reservoirs = {"scalar": self._scalar_reservoir, - "histogram": self._histogram_reservoir, - "image": self._image_reservoir} + "histogram": self._histogram_reservoir, + "image": self._image_reservoir} self._mutex = threading.Lock() def get_reservoir(self, plugin): @@ -299,7 +326,6 @@ if __name__ == '__main__': print(b) c = d.get_reservoir("scalar").get_num_items_index("train", "loss") print(c) - print("***") b = d.get_reservoir("scalar").get_items("train", "accu") print(b) print(d.get_reservoir("scalar").get_num_items_index("train", "accu")) @@ -312,4 +338,4 @@ if __name__ == '__main__': print(c) print(d.get_reservoir("scalar").exist_in_keys("train", "loss")) - print(d.get_reservoir("scalar").exist_in_keys("train", "loss2")) \ No newline at end of file + print(d.get_reservoir("scalar").exist_in_keys("train", "loss2")) diff --git a/visualdl/server/lib.py b/visualdl/server/lib.py index de186418..71b54931 100644 --- a/visualdl/server/lib.py +++ b/visualdl/server/lib.py @@ -65,15 +65,25 @@ def get_scalar(storage, mode, tag, num_records=300): else: num_items_index = data_reservoir.get_num_items_index(mode, tag) if num_items_index != scalar.size(): - records = scalar.records(num_items_index) - ids = scalar.ids(num_items_index) - timestamps = scalar.timestamps(num_items_index) + try: + records = scalar.records(num_items_index) + ids = scalar.ids(num_items_index) + timestamps = scalar.timestamps(num_items_index) + except Exception: + error_info = '\n'.join(map(str, sys.exc_info())) + logger.error("Unexpected error: %s" % error_info) + return data_reservoir.get_items(mode, tag) data = list(zip(timestamps, ids, records)) for index in range(len(data)): data_reservoir.add_item(mode=mode, tag=tag, item=data[index]) - return data_reservoir.get_items(mode, tag) + results = data_reservoir.get_items(mode, tag) + # TODO(Superjomn) some bug here, sometimes there are zero here. + if results[-1][-1] == 0: + data_reservoir.cut_tail(mode=mode, tag=tag) + results = data_reservoir.get_items(mode, tag) + return results def get_image_tags(storage): diff --git a/visualdl/server/visualdl b/visualdl/server/visualdl deleted file mode 100644 index a0947c92..00000000 --- a/visualdl/server/visualdl +++ /dev/null @@ -1,352 +0,0 @@ -#!/user/bin/env python - -# Copyright (c) 2017 VisualDL Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ======================================================================= - -import json -import os -import time -import sys -import re -from argparse import ArgumentParser - -from flask import (Flask, Response, redirect, request, send_file, - send_from_directory) -from flask_babel import Babel - -import visualdl -import visualdl.server -import visualdl.server.graph as vdl_graph -import visualdl.server.model as model -from visualdl.server import lib -from visualdl.server.log import logger -from visualdl.server.mock import data as mock_data -from visualdl.server.mock import data as mock_tags -from visualdl.python.cache import MemCache -from visualdl.python.storage import (LogWriter, LogReader) - -try: - import exceptions -except: - pass - -support_language = ["en", "zh"] -default_language = support_language[0] - -app = Flask(__name__, static_url_path="") -# set static expires in a short time to reduce browser's memory usage. -app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 30 -app.config['BABEL_DEFAULT_LOCALE'] = default_language -babel = Babel(app) - -error_retry_times = 3 -error_sleep_time = 2 # seconds - -SERVER_DIR = os.path.join(visualdl.ROOT, 'server') - - -def try_call(function, *args, **kwargs): - res = lib.retry(error_retry_times, function, error_sleep_time, *args, - **kwargs) - if not res: - logger.error("Internal server error. Retry later.") - return res - - -def parse_args(): - """ - :return: - """ - parser = ArgumentParser( - description="VisualDL, a tool to visualize deep learning.") - parser.add_argument( - "-p", - "--port", - type=int, - default=8040, - action="store", - dest="port", - help="api service port") - parser.add_argument( - "-t", - "--host", - type=str, - default="0.0.0.0", - action="store", - help="api service ip") - parser.add_argument( - "-m", - "--model_pb", - type=str, - action="store", - help="model proto in ONNX format or in Paddle framework format") - parser.add_argument( - "--logdir", - required=True, - action="store", - dest="logdir", - help="log file directory") - parser.add_argument( - "--cache_timeout", - action="store", - dest="cache_timeout", - type=float, - default=20, - help="memory cache timeout duration in seconds, default 20", - ) - parser.add_argument( - "-L", - "--language", - type=str, - action="store", - help="set the default language") - - args = parser.parse_args() - if not args.logdir: - parser.print_help() - sys.exit(-1) - return args - - -args = parse_args() -server_path = os.path.abspath(os.path.dirname(sys.argv[0])) -static_file_path = os.path.join(SERVER_DIR, "./dist") -mock_data_path = os.path.join(SERVER_DIR, "./mock_data/") - -log_reader = LogReader(args.logdir) - -# mannully put graph's image on this path also works. -graph_image_path = os.path.join(args.logdir, 'graph.jpg') -# use a memory cache to reduce disk reading frequency. -CACHE = MemCache(timeout=args.cache_timeout) -cache_get = lib.cache_get(CACHE) - - -# status, msg, data -def gen_result(status, msg, data): - """ - :param status: - :param msg: - :return: - """ - result = dict() - result['status'] = status - result['msg'] = msg - result['data'] = data - return result - - -@babel.localeselector -def get_locale(): - language = args.language - if not language or not language in support_language: - language = request.accept_languages.best_match(support_language) - return language - - -@app.route("/") -def index(): - language = get_locale() - if language == default_language: - return redirect('/app/index', code=302) - return redirect('/app/' + language + '/index', code=302) - - -@app.route('/app/') -def serve_static(filename): - return send_from_directory( - os.path.join(server_path, static_file_path), filename if re.search(r'\..+$', filename) else filename + '.html') - - -@app.route('/graphs/image') -def serve_graph(): - return send_file(os.path.join(os.getcwd(), graph_image_path)) - - -@app.route('/api/logdir') -def logdir(): - result = gen_result(0, "", {"logdir": args.logdir}) - return Response(json.dumps(result), mimetype='application/json') - - -@app.route('/api/runs') -def runs(): - data = cache_get('/data/runs', lib.get_modes, log_reader) - result = gen_result(0, "", data) - return Response(json.dumps(result), mimetype='application/json') - - -@app.route('/api/language') -def language(): - data = get_locale() - result = gen_result(0, "", data) - return Response(json.dumps(result), mimetype='application/json') - - -@app.route("/api/scalars/tags") -def scalar_tags(): - data = cache_get("/data/plugin/scalars/tags", try_call, - lib.get_scalar_tags, log_reader) - result = gen_result(0, "", data) - return Response(json.dumps(result), mimetype='application/json') - - -@app.route("/api/images/tags") -def image_tags(): - data = cache_get("/data/plugin/images/tags", try_call, lib.get_image_tags, - log_reader) - result = gen_result(0, "", data) - return Response(json.dumps(result), mimetype='application/json') - - -@app.route("/api/audio/tags") -def audio_tags(): - data = cache_get("/data/plugin/audio/tags", try_call, lib.get_audio_tags, - log_reader) - result = gen_result(0, "", data) - return Response(json.dumps(result), mimetype='application/json') - - -@app.route("/api/histograms/tags") -def histogram_tags(): - data = cache_get("/data/plugin/histograms/tags", try_call, - lib.get_histogram_tags, log_reader) - result = gen_result(0, "", data) - return Response(json.dumps(result), mimetype='application/json') - - -@app.route("/api/texts/tags") -def texts_tags(): - data = cache_get("/data/plugin/texts/tags", try_call, - lib.get_texts_tags, log_reader) - result = gen_result(0, "", data) - return Response(json.dumps(result), mimetype='application/json') - - -@app.route('/api/scalars/list') -def scalars(): - run = request.args.get('run') - tag = request.args.get('tag') - key = os.path.join('/data/plugin/scalars/scalars', run, tag) - data = cache_get(key, try_call, lib.get_scalar, log_reader, run, tag) - result = gen_result(0, "", data) - return Response(json.dumps(result), mimetype='application/json') - - -@app.route('/api/images/list') -def images(): - mode = request.args.get('run') - tag = request.args.get('tag') - key = os.path.join('/data/plugin/images/images', mode, tag) - - data = cache_get(key, try_call, lib.get_image_tag_steps, log_reader, mode, - tag) - result = gen_result(0, "", data) - - return Response(json.dumps(result), mimetype='application/json') - - -@app.route('/api/images/image') -def individual_image(): - mode = request.args.get('run') - tag = request.args.get('tag') # include a index - step_index = int(request.args.get('index')) # index of step - - key = os.path.join('/data/plugin/images/individualImage', mode, tag, - str(step_index)) - data = cache_get(key, try_call, lib.get_invididual_image, log_reader, mode, - tag, step_index) - response = send_file( - data, as_attachment=True, attachment_filename='img.png') - return response - - -@app.route('/api/histograms/list') -def histogram(): - run = request.args.get('run') - tag = request.args.get('tag') - key = os.path.join('/data/plugin/histograms/histograms', run, tag) - data = cache_get(key, try_call, lib.get_histogram, log_reader, run, tag) - result = gen_result(0, "", data) - return Response(json.dumps(result), mimetype='application/json') - - -@app.route('/api/texts/list') -def texts(): - run = request.args.get('run') - tag = request.args.get('tag') - key = os.path.join('/data/plugin/texts/texts', run, tag) - data = cache_get(key, try_call, lib.get_texts, log_reader, run, tag) - result = gen_result(0, "", data) - return Response(json.dumps(result), mimetype='application/json') - - -@app.route('/api/embeddings/embedding') -def embeddings(): - run = request.args.get('run') - dimension = request.args.get('dimension') - reduction = request.args.get('reduction') - key = os.path.join('/data/plugin/embeddings/embeddings', run, dimension, reduction) - data = cache_get(key, try_call, lib.get_embeddings, log_reader, run, reduction, int(dimension)) - result = gen_result(0, "", data) - return Response(json.dumps(result), mimetype='application/json') - - -@app.route('/api/audio/list') -def audio(): - mode = request.args.get('run') - tag = request.args.get('tag') - key = os.path.join('/data/plugin/audio/audio', mode, tag) - - data = cache_get(key, try_call, lib.get_audio_tag_steps, log_reader, mode, - tag) - result = gen_result(0, "", data) - - return Response(json.dumps(result), mimetype='application/json') - - -@app.route('/api/audio/audio') -def individual_audio(): - mode = request.args.get('run') - tag = request.args.get('tag') # include a index - step_index = int(request.args.get('index')) # index of step - - key = os.path.join('/data/plugin/audio/individualAudio', mode, tag, - str(step_index)) - data = cache_get(key, try_call, lib.get_individual_audio, log_reader, mode, - tag, step_index) - response = send_file( - data, as_attachment=True, attachment_filename='audio.wav') - return response - - -@app.route('/api/graphs/graph') -def graph(): - if model.is_onnx_model(args.model_pb): - json_str = vdl_graph.draw_onnx_graph(args.model_pb) - elif model.is_paddle_model(args.model_pb): - json_str = vdl_graph.draw_paddle_graph(args.model_pb) - else: - json_str = {} - data = {'data': json_str} - - result = gen_result(0, "", data) - return Response(json.dumps(result), mimetype='application/json') - - -if __name__ == '__main__': - logger.info(" port=" + str(args.port)) - - app.run(debug=False, host=args.host, port=args.port, threaded=True) -- GitLab