未验证 提交 6dffdeb6 编写于 作者: Y Yan Chunwei 提交者: GitHub

feature/add cache to reduce disk reading frequency (#169)

上级 4f41b19e
......@@ -4,10 +4,12 @@ import os
import random
import subprocess
import numpy as np
from PIL import Image
from scipy.stats import norm
from visualdl import ROOT, LogWriter
from visualdl.server.log import logger as log
logdir = './scratch_log'
......@@ -92,3 +94,20 @@ with logw.mode("train") as logger:
data = np.random.random(shape).flatten()
image0.add_sample(shape, list(data))
image0.finish_sampling()
def download_graph_image():
'''
This is a scratch demo, it do not generate a ONNX proto, but just download an image
that generated before to show how the graph frontend works.
For real cases, just refer to README.
'''
import urllib
image_url = "https://github.com/PaddlePaddle/VisualDL/blob/develop/demo/mxnet/super_resolution_graph.png?raw=true"
log.warning('download graph demo from {}'.format(image_url))
graph_image = urllib.urlopen(image_url).read()
with open(os.path.join(logdir, 'graph.jpg'), 'wb') as f:
f.write(graph_image)
log.warning('graph ready!')
download_graph_image()
......@@ -26,6 +26,7 @@ def readlines(name):
VERSION_NUMBER = read('VERSION_NUMBER')
LICENSE = readlines('LICENSE')[0].strip()
# use memcache to reduce disk read frequency.
install_requires = ['Flask', 'numpy', 'Pillow', 'protobuf', 'scipy']
execute_requires = ['npm', 'node', 'bash']
......
......@@ -25,3 +25,4 @@ function(py_test TARGET_NAME)
endfunction()
py_test(test_summary SRCS test_storage.py)
py_test(test_cache SRCS cache.py)
import time
class MemCache(object):
class Record:
def __init__(self, value):
self.time = time.time()
self.value = value
def clear(self):
self.value = None
def expired(self, timeout):
return timeout > 0 and time.time() - self.time >= timeout
'''
A global dict to help cache some temporary data.
'''
def __init__(self, timeout=-1):
self._timeout = timeout
self._data = {}
def set(self, key, value):
self._data[key] = MemCache.Record(value)
def get(self, key):
rcd = self._data.get(key, None)
if not rcd: return None
# do not delete the key to accelerate speed
if rcd.expired(self._timeout):
rcd.clear()
return None
return rcd.value
if __name__ == '__main__':
import unittest
class TestMemCacheTest(unittest.TestCase):
def setUp(self):
self.cache = MemCache(timeout=1)
def expire(self):
self.cache.set("message", "hello")
self.assertFalse(self.cache.expired(1))
time.sleep(4)
self.assertTrue(self.cache.expired(1))
def test_have_key(self):
self.cache.set('message', 'hello')
self.assertTrue(self.cache.get('message'))
time.sleep(1.1)
self.assertFalse(self.cache.get('message'))
self.assertTrue(self.cache.get("message") is None)
unittest.main()
......@@ -4,10 +4,10 @@ import unittest
import numpy as np
from PIL import Image
from visualdl import LogReader, LogWriter
pprint.pprint(sys.path)
from visualdl import LogWriter, LogReader
class StorageTest(unittest.TestCase):
......
import pprint
import re
import sys
import time
......@@ -7,6 +6,7 @@ from tempfile import NamedTemporaryFile
import numpy as np
from PIL import Image
from log import logger
......@@ -90,7 +90,6 @@ def get_image_tags(storage):
def get_image_tag_steps(storage, mode, tag):
print 'image_tag_steps,mode,tag:', mode, tag
# remove suffix '/x'
res = re.search(r".*/([0-9]+$)", tag)
sample_index = 0
......@@ -211,3 +210,14 @@ def retry(ntimes, function, time2sleep, *args, **kwargs):
error_info = '\n'.join(map(str, sys.exc_info()))
logger.error("Unexpected error: %s" % error_info)
time.sleep(time2sleep)
def cache_get(cache):
def _handler(key, func, *args, **kwargs):
data = cache.get(key)
if data is None:
logger.warning('update cache %s' % key)
data = func(*args, **kwargs)
cache.set(key, data)
return data
return data
return _handler
......@@ -17,6 +17,7 @@ from visualdl.server import lib
from visualdl.server.log import logger
from visualdl.server.mock import data as mock_data
from visualdl.server.mock import data as mock_tags
from visualdl.python.cache import MemCache
from visualdl.python.storage import (LogWriter, LogReader)
app = Flask(__name__, static_url_path="")
......@@ -33,7 +34,7 @@ def try_call(function, *args, **kwargs):
res = lib.retry(error_retry_times, function, error_sleep_time, *args,
**kwargs)
if not res:
raise exceptions.IOError("server IO error, will retry latter.")
logger.error("server temporary error, will retry latter.")
return res
......@@ -70,6 +71,14 @@ def parse_args():
action="store",
dest="logdir",
help="log file directory")
parser.add_argument(
"--cache_timeout",
action="store",
dest="cache_timeout",
type=float,
default=20,
help="memory cache timeout duration in seconds, default 20",
)
args = parser.parse_args()
if not args.logdir:
parser.print_help()
......@@ -86,8 +95,11 @@ log_reader = LogReader(args.logdir)
# mannully put graph's image on this path also works.
graph_image_path = os.path.join(args.logdir, 'graph.jpg')
# use a memory cache to reduce disk reading frequency.
CACHE = MemCache(timeout=args.cache_timeout)
cache_get = lib.cache_get(CACHE)
# return data
# status, msg, data
def gen_result(status, msg, data):
"""
......@@ -126,33 +138,32 @@ def logdir():
@app.route('/data/runs')
def runs():
result = gen_result(0, "", lib.get_modes(log_reader))
data = cache_get('/data/runs', lib.get_modes, log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route("/data/plugin/scalars/tags")
def scalar_tags():
mode = request.args.get('mode')
is_debug = bool(request.args.get('debug'))
result = try_call(lib.get_scalar_tags, log_reader)
result = gen_result(0, "", result)
data = cache_get("/data/plugin/scalars/tags", try_call,
lib.get_scalar_tags, log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route("/data/plugin/images/tags")
def image_tags():
mode = request.args.get('run')
result = try_call(lib.get_image_tags, log_reader)
result = gen_result(0, "", result)
data = cache_get("/data/plugin/images/tags", try_call, lib.get_image_tags,
log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
@app.route("/data/plugin/histograms/tags")
def histogram_tags():
mode = request.args.get('run')
# hack to avlid IO conflicts
result = try_call(lib.get_histogram_tags, log_reader)
result = gen_result(0, "", result)
data = cache_get("/data/plugin/histograms/tags", try_call,
lib.get_histogram_tags, log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
......@@ -160,8 +171,9 @@ def histogram_tags():
def scalars():
run = request.args.get('run')
tag = request.args.get('tag')
result = try_call(lib.get_scalar, log_reader, run, tag)
result = gen_result(0, "", result)
key = os.path.join('/data/plugin/scalars/scalars', run, tag)
data = cache_get(key, try_call, lib.get_scalar, log_reader, run, tag)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
......@@ -169,9 +181,11 @@ def scalars():
def images():
mode = request.args.get('run')
tag = request.args.get('tag')
key = os.path.join('/data/plugin/images/images', mode, tag)
result = try_call(lib.get_image_tag_steps, log_reader, mode, tag)
result = gen_result(0, "", result)
data = cache_get(key, try_call, lib.get_image_tag_steps, log_reader, mode,
tag)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
......@@ -181,12 +195,13 @@ def individual_image():
mode = request.args.get('run')
tag = request.args.get('tag') # include a index
step_index = int(request.args.get('index')) # index of step
offset = 0
imagefile = try_call(lib.get_invididual_image, log_reader, mode, tag,
step_index)
key = os.path.join('/data/plugin/images/individualImage', mode, tag,
str(step_index))
data = cache_get(key, try_call, lib.get_invididual_image, log_reader, mode,
tag, step_index)
response = send_file(
imagefile, as_attachment=True, attachment_filename='img.png')
data, as_attachment=True, attachment_filename='img.png')
return response
......@@ -194,8 +209,9 @@ def individual_image():
def histogram():
run = request.args.get('run')
tag = request.args.get('tag')
result = try_call(lib.get_histogram, log_reader, run, tag)
result = gen_result(0, "", result)
key = os.path.join('/data/plugin/histograms/histograms', run, tag)
data = cache_get(key, try_call, lib.get_histogram, log_reader, run, tag)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册