未验证 提交 3e8604ba 编写于 作者: 走神的阿圆's avatar 走神的阿圆 提交者: GitHub

Add --model to parameters. (#661)

* Add --model to parameters.
上级 2b9e8c65
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
import os
from visualdl.io import bfile from visualdl.io import bfile
from visualdl.component import components from visualdl.component import components
from visualdl.reader.record_reader import RecordReader from visualdl.reader.record_reader import RecordReader
...@@ -61,6 +62,22 @@ class LogReader(object): ...@@ -61,6 +62,22 @@ class LogReader(object):
self.load_new_data(update=True) self.load_new_data(update=True)
self._a_tags = {} self._a_tags = {}
self._model = ""
@property
def model(self):
return self._model
@model.setter
def model(self, model_path):
if not os.path.isfile(model_path):
print("Model path %s should be file path, please check this path." % model_path)
else:
if os.path.exists(model_path):
self._model = model_path
else:
print("Model path %s is invalid, please check this path." % model_path)
@property @property
def logdir(self): def logdir(self):
return self.dir return self.dir
...@@ -213,8 +230,8 @@ class LogReader(object): ...@@ -213,8 +230,8 @@ class LogReader(object):
if update is True: if update is True:
self.load_new_data(update=update) self.load_new_data(update=update)
components_set = set(self._tags.values()) components_set = set(self._tags.values())
if 0 == len(components_set): components_set.add('scalar')
return {'scalar'}
return components_set return components_set
def load_new_data(self, update=True): def load_new_data(self, update=True):
......
...@@ -57,8 +57,9 @@ def try_call(function, *args, **kwargs): ...@@ -57,8 +57,9 @@ def try_call(function, *args, **kwargs):
class Api(object): class Api(object):
def __init__(self, logdir, cache_timeout): def __init__(self, logdir, model, cache_timeout):
self._reader = LogReader(logdir) self._reader = LogReader(logdir)
self._reader.model = model
# use a memory cache to reduce disk reading frequency. # use a memory cache to reduce disk reading frequency.
cache = MemCache(timeout=cache_timeout) cache = MemCache(timeout=cache_timeout)
...@@ -144,9 +145,14 @@ class Api(object): ...@@ -144,9 +145,14 @@ class Api(object):
key = os.path.join('data/plugin/embeddings/embeddings', run, tag) key = os.path.join('data/plugin/embeddings/embeddings', run, tag)
return self._get_with_retry(key, lib.get_embeddings, run, tag) return self._get_with_retry(key, lib.get_embeddings, run, tag)
@result('application/octet-stream')
def graphs_graph(self):
key = os.path.join('data/plugin/graphs/graph')
return self._get_with_retry(key, lib.get_graph)
def create_api_call(logdir, cache_timeout):
api = Api(logdir, cache_timeout) def create_api_call(logdir, model, cache_timeout):
api = Api(logdir, model, cache_timeout)
routes = { routes = {
'components': (api.components, []), 'components': (api.components, []),
'runs': (api.runs, []), 'runs': (api.runs, []),
...@@ -163,7 +169,8 @@ def create_api_call(logdir, cache_timeout): ...@@ -163,7 +169,8 @@ def create_api_call(logdir, cache_timeout):
'audio/list': (api.audio_list, ['run', 'tag']), 'audio/list': (api.audio_list, ['run', 'tag']),
'audio/audio': (api.audio_audio, ['run', 'tag', 'index']), 'audio/audio': (api.audio_audio, ['run', 'tag', 'index']),
'embeddings/embedding': (api.embeddings_embedding, ['run', 'tag', 'reduction', 'dimension']), 'embeddings/embedding': (api.embeddings_embedding, ['run', 'tag', 'reduction', 'dimension']),
'histogram/histogram': (api.histogram_histogram, ['run', 'tag']) 'histogram/histogram': (api.histogram_histogram, ['run', 'tag']),
'graphs/graph': (api.graphs_graph, [])
} }
def call(path: str, args): def call(path: str, args):
......
...@@ -52,7 +52,7 @@ def create_app(args): ...@@ -52,7 +52,7 @@ def create_app(args):
app.config['BABEL_DEFAULT_LOCALE'] = default_language app.config['BABEL_DEFAULT_LOCALE'] = default_language
babel = Babel(app) babel = Babel(app)
api_call = create_api_call(args.logdir, args.cache_timeout) api_call = create_api_call(args.logdir, args.model, args.cache_timeout)
update_util.PbUpdater().start() update_util.PbUpdater().start()
...@@ -100,6 +100,7 @@ def create_app(args): ...@@ -100,6 +100,7 @@ def create_app(args):
def serve_api(method): def serve_api(method):
data, mimetype = api_call(method, request.args) data, mimetype = api_call(method, request.args)
return make_response(Response(data, mimetype=mimetype)) return make_response(Response(data, mimetype=mimetype))
return app return app
......
...@@ -77,6 +77,7 @@ class ParseArgs(object): ...@@ -77,6 +77,7 @@ class ParseArgs(object):
self.public_path = args.public_path self.public_path = args.public_path
self.api_only = args.api_only self.api_only = args.api_only
self.open_browser = args.open_browser self.open_browser = args.open_browser
self.model = args.model
def parse_args(): def parse_args():
...@@ -105,6 +106,13 @@ def parse_args(): ...@@ -105,6 +106,13 @@ def parse_args():
default=default_host, default=default_host,
action="store", action="store",
help="api service ip") help="api service ip")
parser.add_argument(
"--model",
type=str,
action="store",
dest="model",
default="",
help="model file path")
parser.add_argument( parser.add_argument(
"--cache_timeout", "--cache_timeout",
action="store", action="store",
......
...@@ -155,6 +155,14 @@ def get_histogram(log_reader, run, tag): ...@@ -155,6 +155,14 @@ def get_histogram(log_reader, run, tag):
return results return results
def get_graph(log_reader):
result = b""
if log_reader.model:
with open(log_reader.model, "rb") as fp:
result = fp.read()
return result
def retry(ntimes, function, time2sleep, *args, **kwargs): def retry(ntimes, function, time2sleep, *args, **kwargs):
''' '''
try to execute `function` `ntimes`, if exception catched, the thread will try to execute `function` `ntimes`, if exception catched, the thread will
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册