未验证 提交 2392f4bc 编写于 作者: C chenjian 提交者: GitHub

add available component tab control (#1157)

* add available component tab control

* fix string format
上级 07fc8488
...@@ -88,6 +88,18 @@ class ProfilerReader(object): ...@@ -88,6 +88,18 @@ class ProfilerReader(object):
else: else:
return None return None
def component_tabs(self, update=False):
"""Get component tabs used by vdl frontend.
"""
component_tabs = set()
if not self.logdir:
return component_tabs
if update is True:
self.runs(update=update)
if self.walks:
component_tabs.add('profiler')
return component_tabs
def profile_runs(self, update=False): def profile_runs(self, update=False):
"""Get profile run files. """Get profile run files.
......
...@@ -59,6 +59,14 @@ class ProfilerApi(object): ...@@ -59,6 +59,14 @@ class ProfilerApi(object):
lang = lang.lower() lang = lang.lower()
return self._reader.get_descriptions(lang) return self._reader.get_descriptions(lang)
def component_tabs(self):
'''
Get all component tabs supported by readers in Api.
'''
tabs = set()
tabs.update(self._reader.component_tabs(update=True))
return tabs
@result() @result()
def overview_environment(self, run, worker, span): def overview_environment(self, run, worker, span):
run_manager = self._reader.get_run_manager(run) run_manager = self._reader.get_run_manager(run)
...@@ -382,7 +390,8 @@ def create_profiler_api_call(logdir): ...@@ -382,7 +390,8 @@ def create_profiler_api_call(logdir):
'comparison/phase_table_inner': (api.comparison_phase_table_inner, [ 'comparison/phase_table_inner': (api.comparison_phase_table_inner, [
"base_run", "base_worker", "base_span", "exp_run", "exp_worker", "base_run", "base_worker", "base_span", "exp_run", "exp_worker",
"exp_span", "phase_name" "exp_span", "phase_name"
]) ]),
'component_tabs': (api.component_tabs, [])
} }
def call(path: str, args): def call(path: str, args):
......
...@@ -60,6 +60,18 @@ class GraphReader(object): ...@@ -60,6 +60,18 @@ class GraphReader(object):
def logdir(self): def logdir(self):
return self.dir return self.dir
def component_tabs(self, update=False):
"""Get component tabs used by vdl frontend.
"""
component_tabs = set()
if not self.logdir:
return component_tabs
if update is True:
self.runs(update=update)
if self.walks:
component_tabs.add('dynamic_graph')
return component_tabs
def get_all_walk(self): def get_all_walk(self):
flush_walks = {} flush_walks = {}
for dir in self.dir: for dir in self.dir:
......
...@@ -12,15 +12,16 @@ ...@@ -12,15 +12,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
import collections import collections
from functools import partial # noqa: F401 from functools import partial # noqa: F401
from visualdl.io import bfile
from visualdl.component import components from visualdl.component import components
from visualdl.io import bfile
from visualdl.proto import record_pb2
from visualdl.reader.record_reader import RecordReader from visualdl.reader.record_reader import RecordReader
from visualdl.server.data_manager import default_data_manager from visualdl.server.data_manager import default_data_manager
from visualdl.proto import record_pb2 from visualdl.utils.string_util import decode_tag
from visualdl.utils.string_util import decode_tag, encode_tag from visualdl.utils.string_util import encode_tag
def is_VDLRecord_file(path, check=False): def is_VDLRecord_file(path, check=False):
...@@ -74,10 +75,12 @@ class LogReader(object): ...@@ -74,10 +75,12 @@ class LogReader(object):
# {'run': {'scalar': {'tag1': data, 'tag2': data}}} # {'run': {'scalar': {'tag1': data, 'tag2': data}}}
self._log_datas = collections.defaultdict( self._log_datas = collections.defaultdict(
lambda: collections.defaultdict(lambda: collections.defaultdict(list))) lambda: collections.defaultdict(lambda: collections.defaultdict(
list)))
if file_path: if file_path:
self._log_data = collections.defaultdict(lambda: collections.defaultdict(list)) self._log_data = collections.defaultdict(lambda: collections.
defaultdict(list))
self.get_file_reader(file_path=file_path) self.get_file_reader(file_path=file_path)
remain = self.get_remain() remain = self.get_remain()
self.read_log_data(remain=remain) self.read_log_data(remain=remain)
...@@ -100,12 +103,15 @@ class LogReader(object): ...@@ -100,12 +103,15 @@ class LogReader(object):
self._model = model_path self._model = model_path
with bfile.BFile(model_path, 'rb') as bfp: with bfile.BFile(model_path, 'rb') as bfp:
if not bfp.isfile(model_path): if not bfp.isfile(model_path):
print("Model path %s should be file path, please check this path." % model_path) print(
"Model path %s should be file path, please check this path."
% model_path)
else: else:
if bfile.exists(model_path): if bfile.exists(model_path):
self._model = model_path self._model = model_path
else: else:
print("Model path %s is invalid, please check this path." % model_path) print("Model path %s is invalid, please check this path." %
model_path)
@property @property
def logdir(self): def logdir(self):
...@@ -130,15 +136,16 @@ class LogReader(object): ...@@ -130,15 +136,16 @@ class LogReader(object):
return log_tags return log_tags
def get_log_data(self, component, run, tag): def get_log_data(self, component, run, tag):
if (run in self._log_datas.keys() and if (run in self._log_datas.keys()
component in self._log_datas[run].keys() and and component in self._log_datas[run].keys()
tag in self._log_datas[run][component].keys()): and tag in self._log_datas[run][component].keys()):
return self._log_datas[run][component][tag] return self._log_datas[run][component][tag]
else: else:
file_path = bfile.join(run, self.walks[run]) file_path = bfile.join(run, self.walks[run])
reader = self._get_file_reader(file_path=file_path, update=False) reader = self._get_file_reader(file_path=file_path, update=False)
remain = self.get_remain(reader=reader) remain = self.get_remain(reader=reader)
data = self.read_log_data(remain=remain, update=False)[component][tag] data = self.read_log_data(
remain=remain, update=False)[component][tag]
data = self.parsing_from_proto(component, data) data = self.parsing_from_proto(component, data)
self._log_datas[run][component][tag] = data self._log_datas[run][component][tag] = data
return data return data
...@@ -303,7 +310,8 @@ class LogReader(object): ...@@ -303,7 +310,8 @@ class LogReader(object):
""" """
if self.reader is None and reader is None: if self.reader is None and reader is None:
raise RuntimeError("Please specify log path!") raise RuntimeError("Please specify log path!")
return self.reader.get_remain() if reader is None else reader.get_remain() return self.reader.get_remain(
) if reader is None else reader.get_remain()
def read_log_data(self, remain, update=True): def read_log_data(self, remain, update=True):
"""Parse data from log file without sampling. """Parse data from log file without sampling.
...@@ -311,7 +319,8 @@ class LogReader(object): ...@@ -311,7 +319,8 @@ class LogReader(object):
Args: Args:
remain: Raw data from log file. remain: Raw data from log file.
""" """
_log_data = collections.defaultdict(lambda: collections.defaultdict(list)) _log_data = collections.defaultdict(lambda: collections.defaultdict(
list))
for item in remain: for item in remain:
component, dir, tag, record = self.parse_from_bin(item) component, dir, tag, record = self.parse_from_bin(item)
_log_data[component][tag].append(record) _log_data[component][tag].append(record)
...@@ -343,6 +352,20 @@ class LogReader(object): ...@@ -343,6 +352,20 @@ class LogReader(object):
return components_set return components_set
def component_tabs(self, update=False):
"""Get component tabs used by vdl frontend.
"""
component_tabs = set()
if not self.logdir:
return component_tabs
if update is True:
self.load_new_data(update=update)
for component in set(self._tags.values()):
if component == 'meta_data':
continue
component_tabs.add(component)
return component_tabs
def load_new_data(self, update=True): def load_new_data(self, update=True):
"""Load remain data. """Load remain data.
......
...@@ -99,13 +99,19 @@ class Api(object): ...@@ -99,13 +99,19 @@ class Api(object):
def _get_with_retry(self, key, func, *args, **kwargs): def _get_with_retry(self, key, func, *args, **kwargs):
return self._cache(key, try_call, func, self._reader, *args, **kwargs) return self._cache(key, try_call, func, self._reader, *args, **kwargs)
def _get_with_reader(self, key, func, reader, *args, **kwargs):
return self._cache(key, func, reader, *args, **kwargs)
@result() @result()
def components(self): def components(self):
return self._get('data/components', lib.get_components) return self._get('data/components', lib.get_components)
def component_tabs(self):
'''
Get all component tabs supported by readers in Api.
'''
tabs = set()
tabs.update(self._reader.component_tabs(update=True))
tabs.update(self._graph_reader.component_tabs(update=True))
return tabs
@result() @result()
def runs(self): def runs(self):
return self._get('data/runs', lib.get_runs) return self._get('data/runs', lib.get_runs)
...@@ -380,6 +386,23 @@ class Api(object): ...@@ -380,6 +386,23 @@ class Api(object):
return lib.get_graph_all_nodes(graph_reader, run) return lib.get_graph_all_nodes(graph_reader, run)
@result()
def get_component_tabs(*apis, vdl_args, request_args):
'''
Get component tabs in all apis, so tabs can be presented according to existed data in frontend.
'''
all_tabs = set()
if vdl_args.component_tabs:
return list(vdl_args.component_tabs)
if vdl_args.logdir:
for api in apis:
all_tabs.update(api('component_tabs', request_args))
all_tabs.add('static_graph')
else:
return ['static_graph', 'x2paddle', 'fastdeploy_server']
return list(all_tabs)
def create_api_call(logdir, model, cache_timeout): def create_api_call(logdir, model, cache_timeout):
api = Api(logdir, model, cache_timeout) api = Api(logdir, model, cache_timeout)
routes = { routes = {
...@@ -426,7 +449,8 @@ def create_api_call(logdir, model, cache_timeout): ...@@ -426,7 +449,8 @@ def create_api_call(logdir, model, cache_timeout):
'hparams/data': (api.hparam_data, ['type']), 'hparams/data': (api.hparam_data, ['type']),
'hparams/indicators': (api.hparam_indicator, []), 'hparams/indicators': (api.hparam_indicator, []),
'hparams/list': (api.hparam_list, []), 'hparams/list': (api.hparam_list, []),
'hparams/metric': (api.hparam_metric, ['run', 'metric']) 'hparams/metric': (api.hparam_metric, ['run', 'metric']),
'component_tabs': (api.component_tabs, [])
} }
def call(path: str, args): def call(path: str, args):
......
...@@ -35,6 +35,7 @@ from visualdl import __version__ ...@@ -35,6 +35,7 @@ from visualdl import __version__
from visualdl.component.inference.model_convert_server import create_model_convert_api_call from visualdl.component.inference.model_convert_server import create_model_convert_api_call
from visualdl.component.profiler.profiler_server import create_profiler_api_call from visualdl.component.profiler.profiler_server import create_profiler_api_call
from visualdl.server.api import create_api_call from visualdl.server.api import create_api_call
from visualdl.server.api import get_component_tabs
from visualdl.server.args import parse_args from visualdl.server.args import parse_args
from visualdl.server.args import ParseArgs from visualdl.server.args import ParseArgs
from visualdl.server.log import info from visualdl.server.log import info
...@@ -152,6 +153,16 @@ def create_app(args): # noqa: C901 ...@@ -152,6 +153,16 @@ def create_app(args): # noqa: C901
return make_response( return make_response(
Response(data, mimetype=mimetype, headers=headers)) Response(data, mimetype=mimetype, headers=headers))
@app.route(api_path + '/component_tabs')
def component_tabs():
data, mimetype, headers = get_component_tabs(
api_call,
profiler_api_call,
vdl_args=args,
request_args=request.args)
return make_response(
Response(data, mimetype=mimetype, headers=headers))
@app.route(check_live_path) @app.route(check_live_path)
def check_live(): def check_live():
return '', 204 return '', 204
......
...@@ -12,13 +12,13 @@ ...@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
import sys
import socket import socket
import sys
from argparse import ArgumentParser from argparse import ArgumentParser
from visualdl import __version__ from visualdl import __version__
from visualdl.server.log import (init_logger, logger) from visualdl.server.log import init_logger
from visualdl.server.log import logger
default_host = None default_host = None
default_port = 8040 default_port = 8040
...@@ -45,6 +45,7 @@ class DefaultArgs(object): ...@@ -45,6 +45,7 @@ class DefaultArgs(object):
self.theme = args.get('theme', None) self.theme = args.get('theme', None)
self.dest = args.get('dest', '') self.dest = args.get('dest', '')
self.behavior = args.get('behavior', '') self.behavior = args.get('behavior', '')
self.component_tabs = args.get('component_tabs', None)
def get_host(host=default_host, port=default_port): def get_host(host=default_host, port=default_port):
...@@ -73,6 +74,20 @@ def validate_args(args): ...@@ -73,6 +74,20 @@ def validate_args(args):
logger.error('Theme {} is not support.'.format(args.theme)) logger.error('Theme {} is not support.'.format(args.theme))
sys.exit(-1) sys.exit(-1)
# input unsupported component tab name
supported_tabs = [
'scalar', 'image', 'text', 'embeddings', 'audio', 'histogram',
'hyper_parameters', 'static_graph', 'dynamic_graph', 'pr_curve',
'roc_curve', 'profiler', 'x2paddle', 'fastdeploy_server'
]
if args.component_tabs is not None:
for component_tab in args.component_tabs:
if component_tab not in supported_tabs:
logger.error(
'Component_tab {} is not support. Please choose tabs in {}'
.format(component_tab, supported_tabs))
sys.exit(-1)
def format_args(args): def format_args(args):
# set default public path according to API mode option # set default public path according to API mode option
...@@ -112,6 +127,7 @@ class ParseArgs(object): ...@@ -112,6 +127,7 @@ class ParseArgs(object):
self.theme = args.theme self.theme = args.theme
self.dest = args.dest self.dest = args.dest
self.behavior = args.behavior self.behavior = args.behavior
self.component_tabs = args.component_tabs
def parse_args(): def parse_args():
...@@ -125,10 +141,14 @@ def parse_args(): ...@@ -125,10 +141,14 @@ def parse_args():
) )
parser.add_argument( parser.add_argument(
"--logdir", "--logdir", action="store", nargs="+", help="log file directory")
parser.add_argument(
"--component_tabs",
action="store", action="store",
nargs="+", nargs="+",
help="log file directory") help="component tabs presented in html page.")
parser.add_argument( parser.add_argument(
"-p", "-p",
"--port", "--port",
...@@ -155,7 +175,8 @@ def parse_args(): ...@@ -155,7 +175,8 @@ def parse_args():
dest="cache_timeout", dest="cache_timeout",
type=float, type=float,
default=default_cache_timeout, default=default_cache_timeout,
help="memory cache timeout duration in seconds (default: %(default)s)", ) help="memory cache timeout duration in seconds (default: %(default)s)",
)
parser.add_argument( parser.add_argument(
"-L", "-L",
"--language", "--language",
...@@ -169,27 +190,23 @@ def parse_args(): ...@@ -169,27 +190,23 @@ def parse_args():
action="store", action="store",
dest="public_path", dest="public_path",
default=None, default=None,
help="set public path" help="set public path")
)
parser.add_argument( parser.add_argument(
"--api-only", "--api-only",
action="store_true", action="store_true",
dest="api_only", dest="api_only",
default=False, default=False,
help="serve api only" help="serve api only")
)
parser.add_argument( parser.add_argument(
"--verbose", "--verbose",
"-v", "-v",
action="count", action="count",
default=0, default=0,
help="set log level, use -vvv... to get more information" help="set log level, use -vvv... to get more information")
)
parser.add_argument( parser.add_argument(
"--version", "--version",
action="version", action="version",
version="%(prog)s {}".format(__version__) version="%(prog)s {}".format(__version__))
)
parser.add_argument( parser.add_argument(
"--product", "--product",
type=str, type=str,
...@@ -201,25 +218,16 @@ def parse_args(): ...@@ -201,25 +218,16 @@ def parse_args():
action="store_false", action="store_false",
dest="telemetry", dest="telemetry",
default=True, default=True,
help="disable telemetry" help="disable telemetry")
)
parser.add_argument( parser.add_argument(
"--theme", "--theme",
action="store", action="store",
dest="theme", dest="theme",
default=None, default=None,
choices=support_themes, choices=support_themes,
help="set theme" help="set theme")
) parser.add_argument('dest', nargs='?', help='set destination for log')
parser.add_argument( parser.add_argument("behavior", nargs='?')
'dest',
nargs='?',
help='set destination for log'
)
parser.add_argument(
"behavior",
nargs='?'
)
args = parser.parse_args() args = parser.parse_args()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册