From 3e8604babe4ed7f4b1c6db1346135e5233d81c95 Mon Sep 17 00:00:00 2001 From: ShenYuhan Date: Thu, 11 Jun 2020 19:42:53 +0800 Subject: [PATCH] Add --model to parameters. (#661) * Add --model to parameters. --- visualdl/reader/reader.py | 21 +++++++++++++++++++-- visualdl/server/api.py | 15 +++++++++++---- visualdl/server/app.py | 3 ++- visualdl/server/args.py | 8 ++++++++ visualdl/server/lib.py | 8 ++++++++ 5 files changed, 48 insertions(+), 7 deletions(-) diff --git a/visualdl/reader/reader.py b/visualdl/reader/reader.py index c89db3b6..b1bbb650 100644 --- a/visualdl/reader/reader.py +++ b/visualdl/reader/reader.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ======================================================================= +import os from visualdl.io import bfile from visualdl.component import components from visualdl.reader.record_reader import RecordReader @@ -61,6 +62,22 @@ class LogReader(object): self.load_new_data(update=True) self._a_tags = {} + self._model = "" + + @property + def model(self): + return self._model + + @model.setter + def model(self, model_path): + if not os.path.isfile(model_path): + print("Model path %s should be file path, please check this path." % model_path) + else: + if os.path.exists(model_path): + self._model = model_path + else: + print("Model path %s is invalid, please check this path." % model_path) + @property def logdir(self): return self.dir @@ -213,8 +230,8 @@ class LogReader(object): if update is True: self.load_new_data(update=update) components_set = set(self._tags.values()) - if 0 == len(components_set): - return {'scalar'} + components_set.add('scalar') + return components_set def load_new_data(self, update=True): diff --git a/visualdl/server/api.py b/visualdl/server/api.py index 463bcaf1..d68ede19 100644 --- a/visualdl/server/api.py +++ b/visualdl/server/api.py @@ -57,8 +57,9 @@ def try_call(function, *args, **kwargs): class Api(object): - def __init__(self, logdir, cache_timeout): + def __init__(self, logdir, model, cache_timeout): self._reader = LogReader(logdir) + self._reader.model = model # use a memory cache to reduce disk reading frequency. cache = MemCache(timeout=cache_timeout) @@ -144,9 +145,14 @@ class Api(object): key = os.path.join('data/plugin/embeddings/embeddings', run, tag) return self._get_with_retry(key, lib.get_embeddings, run, tag) + @result('application/octet-stream') + def graphs_graph(self): + key = os.path.join('data/plugin/graphs/graph') + return self._get_with_retry(key, lib.get_graph) -def create_api_call(logdir, cache_timeout): - api = Api(logdir, cache_timeout) + +def create_api_call(logdir, model, cache_timeout): + api = Api(logdir, model, cache_timeout) routes = { 'components': (api.components, []), 'runs': (api.runs, []), @@ -163,7 +169,8 @@ def create_api_call(logdir, cache_timeout): 'audio/list': (api.audio_list, ['run', 'tag']), 'audio/audio': (api.audio_audio, ['run', 'tag', 'index']), 'embeddings/embedding': (api.embeddings_embedding, ['run', 'tag', 'reduction', 'dimension']), - 'histogram/histogram': (api.histogram_histogram, ['run', 'tag']) + 'histogram/histogram': (api.histogram_histogram, ['run', 'tag']), + 'graphs/graph': (api.graphs_graph, []) } def call(path: str, args): diff --git a/visualdl/server/app.py b/visualdl/server/app.py index db1f6ae5..aea6ee19 100644 --- a/visualdl/server/app.py +++ b/visualdl/server/app.py @@ -52,7 +52,7 @@ def create_app(args): app.config['BABEL_DEFAULT_LOCALE'] = default_language babel = Babel(app) - api_call = create_api_call(args.logdir, args.cache_timeout) + api_call = create_api_call(args.logdir, args.model, args.cache_timeout) update_util.PbUpdater().start() @@ -100,6 +100,7 @@ def create_app(args): def serve_api(method): data, mimetype = api_call(method, request.args) return make_response(Response(data, mimetype=mimetype)) + return app diff --git a/visualdl/server/args.py b/visualdl/server/args.py index ff42c227..a298d24e 100644 --- a/visualdl/server/args.py +++ b/visualdl/server/args.py @@ -77,6 +77,7 @@ class ParseArgs(object): self.public_path = args.public_path self.api_only = args.api_only self.open_browser = args.open_browser + self.model = args.model def parse_args(): @@ -105,6 +106,13 @@ def parse_args(): default=default_host, action="store", help="api service ip") + parser.add_argument( + "--model", + type=str, + action="store", + dest="model", + default="", + help="model file path") parser.add_argument( "--cache_timeout", action="store", diff --git a/visualdl/server/lib.py b/visualdl/server/lib.py index 371fce30..525734d4 100644 --- a/visualdl/server/lib.py +++ b/visualdl/server/lib.py @@ -155,6 +155,14 @@ def get_histogram(log_reader, run, tag): return results +def get_graph(log_reader): + result = b"" + if log_reader.model: + with open(log_reader.model, "rb") as fp: + result = fp.read() + return result + + def retry(ntimes, function, time2sleep, *args, **kwargs): ''' try to execute `function` `ntimes`, if exception catched, the thread will -- GitLab