#!/usr/bin/env python
import json
import os
import time
import sys
from argparse import ArgumentParser

from flask import (Flask, Response, redirect, request, send_file,
                   send_from_directory)

import visualdl
import visualdl.server
import visualdl.server.graph as vdl_graph
from visualdl.server import lib
from visualdl.server.log import logger
from visualdl.server.mock import data as mock_data
from visualdl.server.mock import data as mock_tags
from visualdl.python.cache import MemCache
from visualdl.python.storage import (LogWriter, LogReader)

# `exceptions` is an optional import: a missing module is tolerated, but
# anything other than an import failure is allowed to propagate.
try:
    import exceptions
except ImportError:
    pass

app = Flask(__name__, static_url_path="")
# set static expires in a short time to reduce browser's memory usage.
app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 30

error_retry_times = 3
error_sleep_time = 2  # seconds

SERVER_DIR = os.path.join(visualdl.ROOT, 'server')


def try_call(function, *args, **kwargs):
    """Invoke ``function`` through ``lib.retry`` and return its result.

    ``lib.retry`` is called with ``error_retry_times`` and
    ``error_sleep_time``; presumably these are the attempt count and the
    pause (in seconds) between attempts -- confirm against ``lib.retry``.

    :param function: callable to invoke.
    :param args: positional arguments forwarded to ``function``.
    :param kwargs: keyword arguments forwarded to ``function``.
    :return: whatever ``lib.retry`` returns; an error is logged when that
        value is falsy.
    """
    res = lib.retry(error_retry_times, function, error_sleep_time, *args,
                    **kwargs)
    if not res:
        logger.error("server temporary error, will retry latter.")
    return res


def parse_args():
    """Parse the command-line options of the VisualDL server.

    Prints the help text and exits with status -1 when ``--logdir`` is
    given but empty.

    :return: the ``argparse`` namespace with ``port``, ``host``,
        ``model_pb``, ``logdir`` and ``cache_timeout`` attributes.
    """
    parser = ArgumentParser(
        description="VisualDL, a tool to visualize deep learning.")
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default=8040,
        action="store",
        dest="port",
        help="api service port")
    parser.add_argument(
        "-t",
        "--host",
        type=str,
        default="0.0.0.0",
        action="store",
        help="api service ip")
    parser.add_argument(
        "-m",
        "--model_pb",
        type=str,
        action="store",
        help="model proto in ONNX format")
    parser.add_argument(
        "--logdir",
        required=True,
        action="store",
        dest="logdir",
        help="log file directory")
    parser.add_argument(
        "--cache_timeout",
        action="store",
        dest="cache_timeout",
        type=float,
        default=20,
        help="memory cache timeout duration in seconds, default 20", )
    args = parser.parse_args()
    # `required=True` only guarantees the flag is present; an explicitly
    # empty value (--logdir "") still has to be rejected here.
    if not args.logdir:
        parser.print_help()
        sys.exit(-1)
    return args


args = parse_args()
server_path = os.path.abspath(os.path.dirname(sys.argv[0]))
static_file_path = os.path.join(SERVER_DIR, "./dist")
mock_data_path = os.path.join(SERVER_DIR, "./mock_data/")

log_reader = LogReader(args.logdir)

# manually putting the graph's image at this path also works.
graph_image_path = os.path.join(args.logdir, 'graph.jpg')

# use a memory cache to reduce disk reading frequency.
CACHE = MemCache(timeout=args.cache_timeout)
cache_get = lib.cache_get(CACHE)


# status, msg, data
def gen_result(status, msg, data):
    """Build the response envelope shared by the JSON endpoints.

    :param status: status code stored under the ``status`` key.
    :param msg: message stored under the ``msg`` key.
    :param data: payload stored under the ``data`` key.
    :return: dict with the keys ``status``, ``msg`` and ``data``.
    """
    return {'status': status, 'msg': msg, 'data': data}


def _json_response(data):
    """Wrap ``data`` in a success envelope and return it as a JSON response."""
    result = gen_result(0, "", data)
    return Response(json.dumps(result), mimetype='application/json')


def _cache_key(*parts):
    """Join string ``parts`` into a cache key with ``/`` separators.

    ``os.path.join`` is deliberately not used: it discards every earlier
    component when a later one starts with ``/``, so a tag such as
    ``/loss`` would map all runs and endpoints onto the same key.
    """
    return '/'.join(parts)


@app.route("/")
def index():
    """Redirect the site root to the front-end entry page."""
    return redirect('/static/index.html', code=302)


# The `filename` URL variable must be declared in the rule, otherwise the
# view is called without its required argument.
@app.route('/static/<path:filename>')
def serve_static(filename):
    """Serve ``filename`` from the front-end distribution directory."""
    return send_from_directory(
        os.path.join(server_path, static_file_path), filename)


@app.route('/graphs/image')
def serve_graph():
    """Send the graph image file located at ``graph_image_path``."""
    return send_file(os.path.join(os.getcwd(), graph_image_path))


@app.route('/data/logdir')
def logdir():
    """Return the log directory the server was started with."""
    return _json_response({"logdir": args.logdir})


@app.route('/data/runs')
def runs():
    """Return the (cached) result of ``lib.get_modes`` for the log reader."""
    data = cache_get('/data/runs', lib.get_modes, log_reader)
    return _json_response(data)


@app.route("/data/plugin/scalars/tags")
def scalar_tags():
    """Return the (cached) scalar tags read from the log."""
    data = cache_get("/data/plugin/scalars/tags", try_call,
                     lib.get_scalar_tags, log_reader)
    return _json_response(data)


@app.route("/data/plugin/images/tags")
def image_tags():
    """Return the (cached) image tags read from the log."""
    data = cache_get("/data/plugin/images/tags", try_call, lib.get_image_tags,
                     log_reader)
    return _json_response(data)


@app.route("/data/plugin/histograms/tags")
def histogram_tags():
    """Return the (cached) histogram tags read from the log."""
    data = cache_get("/data/plugin/histograms/tags", try_call,
                     lib.get_histogram_tags, log_reader)
    return _json_response(data)


@app.route('/data/plugin/scalars/scalars')
def scalars():
    """Return scalar data for the ``run`` and ``tag`` query parameters."""
    run = request.args.get('run')
    tag = request.args.get('tag')
    key = _cache_key('/data/plugin/scalars/scalars', run, tag)
    data = cache_get(key, try_call, lib.get_scalar, log_reader, run, tag)
    return _json_response(data)


@app.route('/data/plugin/images/images')
def images():
    """Return image steps for the ``run`` and ``tag`` query parameters."""
    mode = request.args.get('run')
    tag = request.args.get('tag')
    key = _cache_key('/data/plugin/images/images', mode, tag)
    data = cache_get(key, try_call, lib.get_image_tag_steps, log_reader, mode,
                     tag)
    return _json_response(data)


@app.route('/data/plugin/images/individualImage')
def individual_image():
    """Send one image selected by the ``run``, ``tag`` and ``index`` params."""
    mode = request.args.get('run')
    tag = request.args.get('tag')  # include a index
    step_index = int(request.args.get('index'))  # index of step
    key = _cache_key('/data/plugin/images/individualImage', mode, tag,
                     str(step_index))
    data = cache_get(key, try_call, lib.get_invididual_image, log_reader, mode,
                     tag, step_index)
    # NOTE(review): Flask 2.0 renamed `attachment_filename` to
    # `download_name` -- confirm the Flask version this server is pinned to.
    response = send_file(
        data, as_attachment=True, attachment_filename='img.png')
    return response


@app.route('/data/plugin/histograms/histograms')
def histogram():
    """Return histogram data for the ``run`` and ``tag`` query parameters."""
    run = request.args.get('run')
    tag = request.args.get('tag')
    key = _cache_key('/data/plugin/histograms/histograms', run, tag)
    data = cache_get(key, try_call, lib.get_histogram, log_reader, run, tag)
    return _json_response(data)


@app.route('/data/plugin/graphs/graph')
def graph():
    """Return the graph image URL, or an empty URL when no image exists."""
    # TODO(ChunweiYan) need to add a config for whether have graph.
    if graph_image_path is None or not os.path.isfile(graph_image_path):
        data = {'url': ''}
    else:
        data = {'url': '/graphs/image'}
    return _json_response(data)


if __name__ == '__main__':
    logger.info(" port=" + str(args.port))
    if args.model_pb:
        # draw graph
        image_dir = os.path.join(args.logdir, "graphs")
        if not os.path.isdir(image_dir):
            os.mkdir(image_dir)
        # Rebinds the module-level path read by serve_graph() and graph().
        graph_image_path = vdl_graph.draw_graph(args.model_pb, image_dir)
    app.run(debug=False, host=args.host, port=args.port, threaded=True)