未验证 提交 6dffdeb6 编写于 作者: Y Yan Chunwei 提交者: GitHub

feature/add cache to reduce disk reading frequency (#169)

上级 4f41b19e
...@@ -4,10 +4,12 @@ import os ...@@ -4,10 +4,12 @@ import os
import random import random
import subprocess import subprocess
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from scipy.stats import norm from scipy.stats import norm
from visualdl import ROOT, LogWriter from visualdl import ROOT, LogWriter
from visualdl.server.log import logger as log
logdir = './scratch_log' logdir = './scratch_log'
...@@ -92,3 +94,20 @@ with logw.mode("train") as logger: ...@@ -92,3 +94,20 @@ with logw.mode("train") as logger:
data = np.random.random(shape).flatten() data = np.random.random(shape).flatten()
image0.add_sample(shape, list(data)) image0.add_sample(shape, list(data))
image0.finish_sampling() image0.finish_sampling()
def download_graph_image():
    '''
    Download a pre-rendered graph image and save it as ``graph.jpg`` in
    ``logdir``.

    This is a scratch demo: it does not generate an ONNX proto, it just
    downloads an image that was generated before, to show how the graph
    frontend works. For real cases, refer to the README.
    '''
    from contextlib import closing
    try:
        # Python 3 keeps urlopen in urllib.request.
        from urllib.request import urlopen
    except ImportError:
        # Python 2 fallback: the location the original code relied on.
        from urllib import urlopen

    image_url = "https://github.com/PaddlePaddle/VisualDL/blob/develop/demo/mxnet/super_resolution_graph.png?raw=true"
    log.warning('download graph demo from {}'.format(image_url))
    # closing() guarantees the response is closed on both Python 2 and 3,
    # even if read() raises.
    with closing(urlopen(image_url)) as response:
        graph_image = response.read()
    with open(os.path.join(logdir, 'graph.jpg'), 'wb') as f:
        f.write(graph_image)
    log.warning('graph ready!')
download_graph_image()
...@@ -26,6 +26,7 @@ def readlines(name): ...@@ -26,6 +26,7 @@ def readlines(name):
VERSION_NUMBER = read('VERSION_NUMBER') VERSION_NUMBER = read('VERSION_NUMBER')
LICENSE = readlines('LICENSE')[0].strip() LICENSE = readlines('LICENSE')[0].strip()
# use memcache to reduce disk read frequency.
install_requires = ['Flask', 'numpy', 'Pillow', 'protobuf', 'scipy'] install_requires = ['Flask', 'numpy', 'Pillow', 'protobuf', 'scipy']
execute_requires = ['npm', 'node', 'bash'] execute_requires = ['npm', 'node', 'bash']
......
...@@ -25,3 +25,4 @@ function(py_test TARGET_NAME) ...@@ -25,3 +25,4 @@ function(py_test TARGET_NAME)
endfunction() endfunction()
py_test(test_summary SRCS test_storage.py) py_test(test_summary SRCS test_storage.py)
py_test(test_cache SRCS cache.py)
import time
class MemCache(object):
    '''
    A dict-backed cache to help hold some temporary data in memory.

    Every value is stored together with the time it was set. When
    ``timeout`` is positive, ``get`` treats an entry that is at least
    ``timeout`` seconds old as missing. A non-positive ``timeout`` (the
    default is ``-1``) disables expiry.
    '''

    class Record(object):
        '''A cached value plus the timestamp at which it was stored.'''

        def __init__(self, value):
            self.time = time.time()
            self.value = value

        def clear(self):
            '''Drop the reference to the stored value.'''
            self.value = None

        def expired(self, timeout):
            '''
            Return True when ``timeout`` is positive and at least
            ``timeout`` seconds have passed since the record was created.
            '''
            return timeout > 0 and time.time() - self.time >= timeout

    def __init__(self, timeout=-1):
        self._timeout = timeout
        self._data = {}

    def set(self, key, value):
        '''Store ``value`` under ``key`` in a fresh record (age restarts).'''
        self._data[key] = MemCache.Record(value)

    def get(self, key):
        '''Return the value cached for ``key``, or None if absent/expired.'''
        rcd = self._data.get(key)
        if rcd is None:
            return None
        # do not delete the key to accelerate speed; only the value is
        # released, the (now empty) record stays in the dict.
        if rcd.expired(self._timeout):
            rcd.clear()
            return None
        return rcd.value
if __name__ == '__main__':
    import unittest

    class TestMemCacheTest(unittest.TestCase):
        '''Self-test for MemCache, run when this file is executed directly.'''

        def setUp(self):
            self.cache = MemCache(timeout=1)

        def test_expire(self):
            # Expiry is tracked per record, so query the stored record;
            # MemCache itself has no ``expired`` method.
            self.cache.set("message", "hello")
            record = self.cache._data["message"]
            self.assertFalse(record.expired(1))
            # Backdate the record instead of sleeping so the test is
            # fast and deterministic.
            record.time -= 4
            self.assertTrue(record.expired(1))

        def test_have_key(self):
            self.cache.set('message', 'hello')
            self.assertTrue(self.cache.get('message'))
            # Age the entry past the 1 second timeout without sleeping.
            self.cache._data['message'].time -= 1.1
            self.assertFalse(self.cache.get('message'))
            self.assertTrue(self.cache.get("message") is None)

    unittest.main()
...@@ -4,10 +4,10 @@ import unittest ...@@ -4,10 +4,10 @@ import unittest
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from visualdl import LogReader, LogWriter
pprint.pprint(sys.path) pprint.pprint(sys.path)
from visualdl import LogWriter, LogReader
class StorageTest(unittest.TestCase): class StorageTest(unittest.TestCase):
......
import pprint
import re import re
import sys import sys
import time import time
...@@ -7,6 +6,7 @@ from tempfile import NamedTemporaryFile ...@@ -7,6 +6,7 @@ from tempfile import NamedTemporaryFile
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from log import logger from log import logger
...@@ -90,7 +90,6 @@ def get_image_tags(storage): ...@@ -90,7 +90,6 @@ def get_image_tags(storage):
def get_image_tag_steps(storage, mode, tag): def get_image_tag_steps(storage, mode, tag):
print 'image_tag_steps,mode,tag:', mode, tag
# remove suffix '/x' # remove suffix '/x'
res = re.search(r".*/([0-9]+$)", tag) res = re.search(r".*/([0-9]+$)", tag)
sample_index = 0 sample_index = 0
...@@ -211,3 +210,14 @@ def retry(ntimes, function, time2sleep, *args, **kwargs): ...@@ -211,3 +210,14 @@ def retry(ntimes, function, time2sleep, *args, **kwargs):
error_info = '\n'.join(map(str, sys.exc_info())) error_info = '\n'.join(map(str, sys.exc_info()))
logger.error("Unexpected error: %s" % error_info) logger.error("Unexpected error: %s" % error_info)
time.sleep(time2sleep) time.sleep(time2sleep)
def cache_get(cache):
    '''
    Build a read-through getter bound to ``cache``.

    ``cache`` must provide ``get(key)`` (returning None on a miss) and
    ``set(key, value)``.

    The returned handler has the signature
    ``_handler(key, func, *args, **kwargs)``: it returns the value cached
    under ``key`` if there is one; otherwise it calls
    ``func(*args, **kwargs)``, stores the result under ``key`` and
    returns it.
    '''

    def _handler(key, func, *args, **kwargs):
        data = cache.get(key)
        if data is not None:
            return data
        # Pass key as a logging argument so formatting is deferred and a
        # non-string key (e.g. a tuple) cannot break the message.
        logger.warning('update cache %s', key)
        data = func(*args, **kwargs)
        # None is the miss signal of cache.get, so a None result could
        # never be served from the cache; do not store it.
        if data is not None:
            cache.set(key, data)
        return data

    return _handler
...@@ -17,6 +17,7 @@ from visualdl.server import lib ...@@ -17,6 +17,7 @@ from visualdl.server import lib
from visualdl.server.log import logger from visualdl.server.log import logger
from visualdl.server.mock import data as mock_data from visualdl.server.mock import data as mock_data
from visualdl.server.mock import data as mock_tags from visualdl.server.mock import data as mock_tags
from visualdl.python.cache import MemCache
from visualdl.python.storage import (LogWriter, LogReader) from visualdl.python.storage import (LogWriter, LogReader)
app = Flask(__name__, static_url_path="") app = Flask(__name__, static_url_path="")
...@@ -33,7 +34,7 @@ def try_call(function, *args, **kwargs): ...@@ -33,7 +34,7 @@ def try_call(function, *args, **kwargs):
res = lib.retry(error_retry_times, function, error_sleep_time, *args, res = lib.retry(error_retry_times, function, error_sleep_time, *args,
**kwargs) **kwargs)
if not res: if not res:
raise exceptions.IOError("server IO error, will retry latter.") logger.error("server temporary error, will retry latter.")
return res return res
...@@ -70,6 +71,14 @@ def parse_args(): ...@@ -70,6 +71,14 @@ def parse_args():
action="store", action="store",
dest="logdir", dest="logdir",
help="log file directory") help="log file directory")
parser.add_argument(
"--cache_timeout",
action="store",
dest="cache_timeout",
type=float,
default=20,
help="memory cache timeout duration in seconds, default 20",
)
args = parser.parse_args() args = parser.parse_args()
if not args.logdir: if not args.logdir:
parser.print_help() parser.print_help()
...@@ -86,8 +95,11 @@ log_reader = LogReader(args.logdir) ...@@ -86,8 +95,11 @@ log_reader = LogReader(args.logdir)
# manually putting the graph's image at this path also works. # manually putting the graph's image at this path also works.
graph_image_path = os.path.join(args.logdir, 'graph.jpg') graph_image_path = os.path.join(args.logdir, 'graph.jpg')
# use a memory cache to reduce disk reading frequency.
CACHE = MemCache(timeout=args.cache_timeout)
cache_get = lib.cache_get(CACHE)
# return data
# status, msg, data # status, msg, data
def gen_result(status, msg, data): def gen_result(status, msg, data):
""" """
...@@ -126,33 +138,32 @@ def logdir(): ...@@ -126,33 +138,32 @@ def logdir():
@app.route('/data/runs') @app.route('/data/runs')
def runs(): def runs():
result = gen_result(0, "", lib.get_modes(log_reader)) data = cache_get('/data/runs', lib.get_modes, log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json') return Response(json.dumps(result), mimetype='application/json')
@app.route("/data/plugin/scalars/tags") @app.route("/data/plugin/scalars/tags")
def scalar_tags(): def scalar_tags():
mode = request.args.get('mode') data = cache_get("/data/plugin/scalars/tags", try_call,
is_debug = bool(request.args.get('debug')) lib.get_scalar_tags, log_reader)
result = try_call(lib.get_scalar_tags, log_reader) result = gen_result(0, "", data)
result = gen_result(0, "", result)
return Response(json.dumps(result), mimetype='application/json') return Response(json.dumps(result), mimetype='application/json')
@app.route("/data/plugin/images/tags") @app.route("/data/plugin/images/tags")
def image_tags(): def image_tags():
mode = request.args.get('run') data = cache_get("/data/plugin/images/tags", try_call, lib.get_image_tags,
result = try_call(lib.get_image_tags, log_reader) log_reader)
result = gen_result(0, "", result) result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json') return Response(json.dumps(result), mimetype='application/json')
@app.route("/data/plugin/histograms/tags") @app.route("/data/plugin/histograms/tags")
def histogram_tags(): def histogram_tags():
mode = request.args.get('run') data = cache_get("/data/plugin/histograms/tags", try_call,
# hack to avlid IO conflicts lib.get_histogram_tags, log_reader)
result = try_call(lib.get_histogram_tags, log_reader) result = gen_result(0, "", data)
result = gen_result(0, "", result)
return Response(json.dumps(result), mimetype='application/json') return Response(json.dumps(result), mimetype='application/json')
...@@ -160,8 +171,9 @@ def histogram_tags(): ...@@ -160,8 +171,9 @@ def histogram_tags():
def scalars(): def scalars():
run = request.args.get('run') run = request.args.get('run')
tag = request.args.get('tag') tag = request.args.get('tag')
result = try_call(lib.get_scalar, log_reader, run, tag) key = os.path.join('/data/plugin/scalars/scalars', run, tag)
result = gen_result(0, "", result) data = cache_get(key, try_call, lib.get_scalar, log_reader, run, tag)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json') return Response(json.dumps(result), mimetype='application/json')
...@@ -169,9 +181,11 @@ def scalars(): ...@@ -169,9 +181,11 @@ def scalars():
def images(): def images():
mode = request.args.get('run') mode = request.args.get('run')
tag = request.args.get('tag') tag = request.args.get('tag')
key = os.path.join('/data/plugin/images/images', mode, tag)
result = try_call(lib.get_image_tag_steps, log_reader, mode, tag) data = cache_get(key, try_call, lib.get_image_tag_steps, log_reader, mode,
result = gen_result(0, "", result) tag)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json') return Response(json.dumps(result), mimetype='application/json')
...@@ -181,12 +195,13 @@ def individual_image(): ...@@ -181,12 +195,13 @@ def individual_image():
mode = request.args.get('run') mode = request.args.get('run')
tag = request.args.get('tag') # include a index tag = request.args.get('tag') # include a index
step_index = int(request.args.get('index')) # index of step step_index = int(request.args.get('index')) # index of step
offset = 0
imagefile = try_call(lib.get_invididual_image, log_reader, mode, tag, key = os.path.join('/data/plugin/images/individualImage', mode, tag,
step_index) str(step_index))
data = cache_get(key, try_call, lib.get_invididual_image, log_reader, mode,
tag, step_index)
response = send_file( response = send_file(
imagefile, as_attachment=True, attachment_filename='img.png') data, as_attachment=True, attachment_filename='img.png')
return response return response
...@@ -194,8 +209,9 @@ def individual_image(): ...@@ -194,8 +209,9 @@ def individual_image():
def histogram(): def histogram():
run = request.args.get('run') run = request.args.get('run')
tag = request.args.get('tag') tag = request.args.get('tag')
result = try_call(lib.get_histogram, log_reader, run, tag) key = os.path.join('/data/plugin/histograms/histograms', run, tag)
result = gen_result(0, "", result) data = cache_get(key, try_call, lib.get_histogram, log_reader, run, tag)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json') return Response(json.dumps(result), mimetype='application/json')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册