未验证 提交 3e8604ba 编写于 作者: 走神的阿圆's avatar 走神的阿圆 提交者: GitHub

Add --model to parameters. (#661)

* Add --model to parameters.
上级 2b9e8c65
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================
import os
from visualdl.io import bfile
from visualdl.component import components
from visualdl.reader.record_reader import RecordReader
......@@ -61,6 +62,22 @@ class LogReader(object):
self.load_new_data(update=True)
self._a_tags = {}
self._model = ""
@property
def model(self):
return self._model
@model.setter
def model(self, model_path):
if not os.path.isfile(model_path):
print("Model path %s should be file path, please check this path." % model_path)
else:
if os.path.exists(model_path):
self._model = model_path
else:
print("Model path %s is invalid, please check this path." % model_path)
@property
def logdir(self):
return self.dir
......@@ -213,8 +230,8 @@ class LogReader(object):
if update is True:
self.load_new_data(update=update)
components_set = set(self._tags.values())
if 0 == len(components_set):
return {'scalar'}
components_set.add('scalar')
return components_set
def load_new_data(self, update=True):
......
......@@ -57,8 +57,9 @@ def try_call(function, *args, **kwargs):
class Api(object):
def __init__(self, logdir, cache_timeout):
def __init__(self, logdir, model, cache_timeout):
self._reader = LogReader(logdir)
self._reader.model = model
# use a memory cache to reduce disk reading frequency.
cache = MemCache(timeout=cache_timeout)
......@@ -144,9 +145,14 @@ class Api(object):
key = os.path.join('data/plugin/embeddings/embeddings', run, tag)
return self._get_with_retry(key, lib.get_embeddings, run, tag)
@result('application/octet-stream')
def graphs_graph(self):
key = os.path.join('data/plugin/graphs/graph')
return self._get_with_retry(key, lib.get_graph)
def create_api_call(logdir, cache_timeout):
api = Api(logdir, cache_timeout)
def create_api_call(logdir, model, cache_timeout):
api = Api(logdir, model, cache_timeout)
routes = {
'components': (api.components, []),
'runs': (api.runs, []),
......@@ -163,7 +169,8 @@ def create_api_call(logdir, cache_timeout):
'audio/list': (api.audio_list, ['run', 'tag']),
'audio/audio': (api.audio_audio, ['run', 'tag', 'index']),
'embeddings/embedding': (api.embeddings_embedding, ['run', 'tag', 'reduction', 'dimension']),
'histogram/histogram': (api.histogram_histogram, ['run', 'tag'])
'histogram/histogram': (api.histogram_histogram, ['run', 'tag']),
'graphs/graph': (api.graphs_graph, [])
}
def call(path: str, args):
......
......@@ -52,7 +52,7 @@ def create_app(args):
app.config['BABEL_DEFAULT_LOCALE'] = default_language
babel = Babel(app)
api_call = create_api_call(args.logdir, args.cache_timeout)
api_call = create_api_call(args.logdir, args.model, args.cache_timeout)
update_util.PbUpdater().start()
......@@ -100,6 +100,7 @@ def create_app(args):
def serve_api(method):
data, mimetype = api_call(method, request.args)
return make_response(Response(data, mimetype=mimetype))
return app
......
......@@ -77,6 +77,7 @@ class ParseArgs(object):
self.public_path = args.public_path
self.api_only = args.api_only
self.open_browser = args.open_browser
self.model = args.model
def parse_args():
......@@ -105,6 +106,13 @@ def parse_args():
default=default_host,
action="store",
help="api service ip")
parser.add_argument(
"--model",
type=str,
action="store",
dest="model",
default="",
help="model file path")
parser.add_argument(
"--cache_timeout",
action="store",
......
......@@ -155,6 +155,14 @@ def get_histogram(log_reader, run, tag):
return results
def get_graph(log_reader):
result = b""
if log_reader.model:
with open(log_reader.model, "rb") as fp:
result = fp.read()
return result
def retry(ntimes, function, time2sleep, *args, **kwargs):
'''
try to execute `function` `ntimes`, if exception catched, the thread will
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册