diff --git a/visualdl/component/profiler/profiler_reader.py b/visualdl/component/profiler/profiler_reader.py index 79f19543df62410ca4ff9c6e6089a27f0c9db7bb..15985802d716b9beb2f6fbb1cd193780ebb80fa6 100644 --- a/visualdl/component/profiler/profiler_reader.py +++ b/visualdl/component/profiler/profiler_reader.py @@ -88,6 +88,18 @@ class ProfilerReader(object): else: return None + def component_tabs(self, update=False): + """Get component tabs used by vdl frontend. + """ + component_tabs = set() + if not self.logdir: + return component_tabs + if update is True: + self.runs(update=update) + if self.walks: + component_tabs.add('profiler') + return component_tabs + def profile_runs(self, update=False): """Get profile run files. diff --git a/visualdl/component/profiler/profiler_server.py b/visualdl/component/profiler/profiler_server.py index bb497706325b80e8220cd0b81231f9ce2a98c408..408c334991d690eebd3b48d57b628e4024d3239e 100644 --- a/visualdl/component/profiler/profiler_server.py +++ b/visualdl/component/profiler/profiler_server.py @@ -59,6 +59,14 @@ class ProfilerApi(object): lang = lang.lower() return self._reader.get_descriptions(lang) + def component_tabs(self): + ''' + Get all component tabs supported by readers in Api. + ''' + tabs = set() + tabs.update(self._reader.component_tabs(update=True)) + return tabs + @result() def overview_environment(self, run, worker, span): run_manager = self._reader.get_run_manager(run) @@ -382,7 +390,8 @@ def create_profiler_api_call(logdir): 'comparison/phase_table_inner': (api.comparison_phase_table_inner, [ "base_run", "base_worker", "base_span", "exp_run", "exp_worker", "exp_span", "phase_name" - ]) + ]), + 'component_tabs': (api.component_tabs, []) } def call(path: str, args): diff --git a/visualdl/reader/graph_reader.py b/visualdl/reader/graph_reader.py index 90621d0d2a12338f1c526f911ce6799e42efe931..1acc99ed9c6e9b844c78d63a9e709282cd0c8be8 100644 --- a/visualdl/reader/graph_reader.py +++ b/visualdl/reader/graph_reader.py @@ -60,6 +60,18 @@ class GraphReader(object): def logdir(self): return self.dir + def component_tabs(self, update=False): + """Get component tabs used by vdl frontend. + """ + component_tabs = set() + if not self.logdir: + return component_tabs + if update is True: + self.runs(update=update) + if self.walks: + component_tabs.add('dynamic_graph') + return component_tabs + def get_all_walk(self): flush_walks = {} for dir in self.dir: diff --git a/visualdl/reader/reader.py b/visualdl/reader/reader.py index 5dfcd9f917461c9fb5ee40c67c7b87990c2be08f..081df075b872b93c29322ead67da2035a337150f 100644 --- a/visualdl/reader/reader.py +++ b/visualdl/reader/reader.py @@ -12,15 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # ======================================================================= - import collections from functools import partial # noqa: F401 -from visualdl.io import bfile + from visualdl.component import components +from visualdl.io import bfile +from visualdl.proto import record_pb2 from visualdl.reader.record_reader import RecordReader from visualdl.server.data_manager import default_data_manager -from visualdl.proto import record_pb2 -from visualdl.utils.string_util import decode_tag, encode_tag +from visualdl.utils.string_util import decode_tag +from visualdl.utils.string_util import encode_tag def is_VDLRecord_file(path, check=False): @@ -74,10 +75,12 @@ class LogReader(object): # {'run': {'scalar': {'tag1': data, 'tag2': data}}} self._log_datas = collections.defaultdict( - lambda: collections.defaultdict(lambda: collections.defaultdict(list))) + lambda: collections.defaultdict(lambda: collections.defaultdict( + list))) if file_path: - self._log_data = collections.defaultdict(lambda: collections.defaultdict(list)) + self._log_data = collections.defaultdict(lambda: collections. + defaultdict(list)) self.get_file_reader(file_path=file_path) remain = self.get_remain() self.read_log_data(remain=remain) @@ -100,12 +103,15 @@ class LogReader(object): self._model = model_path with bfile.BFile(model_path, 'rb') as bfp: if not bfp.isfile(model_path): - print("Model path %s should be file path, please check this path." % model_path) + print( + "Model path %s should be file path, please check this path." + % model_path) else: if bfile.exists(model_path): self._model = model_path else: - print("Model path %s is invalid, please check this path." % model_path) + print("Model path %s is invalid, please check this path." % + model_path) @property def logdir(self): @@ -130,15 +136,16 @@ class LogReader(object): return log_tags def get_log_data(self, component, run, tag): - if (run in self._log_datas.keys() and - component in self._log_datas[run].keys() and - tag in self._log_datas[run][component].keys()): + if (run in self._log_datas.keys() + and component in self._log_datas[run].keys() + and tag in self._log_datas[run][component].keys()): return self._log_datas[run][component][tag] else: file_path = bfile.join(run, self.walks[run]) reader = self._get_file_reader(file_path=file_path, update=False) remain = self.get_remain(reader=reader) - data = self.read_log_data(remain=remain, update=False)[component][tag] + data = self.read_log_data( + remain=remain, update=False)[component][tag] data = self.parsing_from_proto(component, data) self._log_datas[run][component][tag] = data return data @@ -303,7 +310,8 @@ class LogReader(object): """ if self.reader is None and reader is None: raise RuntimeError("Please specify log path!") - return self.reader.get_remain() if reader is None else reader.get_remain() + return self.reader.get_remain( + ) if reader is None else reader.get_remain() def read_log_data(self, remain, update=True): """Parse data from log file without sampling. @@ -311,7 +319,8 @@ class LogReader(object): Args: remain: Raw data from log file. """ - _log_data = collections.defaultdict(lambda: collections.defaultdict(list)) + _log_data = collections.defaultdict(lambda: collections.defaultdict( + list)) for item in remain: component, dir, tag, record = self.parse_from_bin(item) _log_data[component][tag].append(record) @@ -343,6 +352,20 @@ class LogReader(object): return components_set + def component_tabs(self, update=False): + """Get component tabs used by vdl frontend. + """ + component_tabs = set() + if not self.logdir: + return component_tabs + if update is True: + self.load_new_data(update=update) + for component in set(self._tags.values()): + if component == 'meta_data': + continue + component_tabs.add(component) + return component_tabs + def load_new_data(self, update=True): """Load remain data. diff --git a/visualdl/server/api.py b/visualdl/server/api.py index 195b93602b014f0949387fb6bd24bb4dfbea35f9..2ea87d52fd7b837314cfe72250ab7e59dfefe22f 100644 --- a/visualdl/server/api.py +++ b/visualdl/server/api.py @@ -99,13 +99,19 @@ class Api(object): def _get_with_retry(self, key, func, *args, **kwargs): return self._cache(key, try_call, func, self._reader, *args, **kwargs) - def _get_with_reader(self, key, func, reader, *args, **kwargs): - return self._cache(key, func, reader, *args, **kwargs) - @result() def components(self): return self._get('data/components', lib.get_components) + def component_tabs(self): + ''' + Get all component tabs supported by readers in Api. + ''' + tabs = set() + tabs.update(self._reader.component_tabs(update=True)) + tabs.update(self._graph_reader.component_tabs(update=True)) + return tabs + @result() def runs(self): return self._get('data/runs', lib.get_runs) @@ -380,6 +386,23 @@ class Api(object): return lib.get_graph_all_nodes(graph_reader, run) +@result() +def get_component_tabs(*apis, vdl_args, request_args): + ''' + Get component tabs in all apis, so tabs can be presented according to existed data in frontend. + ''' + all_tabs = set() + if vdl_args.component_tabs: + return list(vdl_args.component_tabs) + if vdl_args.logdir: + for api in apis: + all_tabs.update(api('component_tabs', request_args)) + all_tabs.add('static_graph') + else: + return ['static_graph', 'x2paddle', 'fastdeploy_server'] + return list(all_tabs) + + def create_api_call(logdir, model, cache_timeout): api = Api(logdir, model, cache_timeout) routes = { @@ -426,7 +449,8 @@ def create_api_call(logdir, model, cache_timeout): 'hparams/data': (api.hparam_data, ['type']), 'hparams/indicators': (api.hparam_indicator, []), 'hparams/list': (api.hparam_list, []), - 'hparams/metric': (api.hparam_metric, ['run', 'metric']) + 'hparams/metric': (api.hparam_metric, ['run', 'metric']), + 'component_tabs': (api.component_tabs, []) } def call(path: str, args): diff --git a/visualdl/server/app.py b/visualdl/server/app.py index 06dff8163a396e2ff43ffebc31cb167c34811965..5f9454fa94ed2961a9cabf3407f69bee0e5d1eea 100644 --- a/visualdl/server/app.py +++ b/visualdl/server/app.py @@ -35,6 +35,7 @@ from visualdl import __version__ from visualdl.component.inference.model_convert_server import create_model_convert_api_call from visualdl.component.profiler.profiler_server import create_profiler_api_call from visualdl.server.api import create_api_call +from visualdl.server.api import get_component_tabs from visualdl.server.args import parse_args from visualdl.server.args import ParseArgs from visualdl.server.log import info @@ -152,6 +153,16 @@ def create_app(args): # noqa: C901 return make_response( Response(data, mimetype=mimetype, headers=headers)) + @app.route(api_path + '/component_tabs') + def component_tabs(): + data, mimetype, headers = get_component_tabs( + api_call, + profiler_api_call, + vdl_args=args, + request_args=request.args) + return make_response( + Response(data, mimetype=mimetype, headers=headers)) + @app.route(check_live_path) def check_live(): return '', 204 diff --git a/visualdl/server/args.py b/visualdl/server/args.py index aff593f79f349da0080d982444b96bc6d42f48e7..cb42422c7468e59c8cfc194bf7a8aeb8188ba04c 100644 --- a/visualdl/server/args.py +++ b/visualdl/server/args.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ======================================================================= - -import sys import socket +import sys from argparse import ArgumentParser from visualdl import __version__ -from visualdl.server.log import (init_logger, logger) +from visualdl.server.log import init_logger +from visualdl.server.log import logger default_host = None default_port = 8040 @@ -45,6 +45,7 @@ class DefaultArgs(object): self.theme = args.get('theme', None) self.dest = args.get('dest', '') self.behavior = args.get('behavior', '') + self.component_tabs = args.get('component_tabs', None) def get_host(host=default_host, port=default_port): @@ -73,6 +74,20 @@ def validate_args(args): logger.error('Theme {} is not support.'.format(args.theme)) sys.exit(-1) + # input unsupported component tab name + supported_tabs = [ + 'scalar', 'image', 'text', 'embeddings', 'audio', 'histogram', + 'hyper_parameters', 'static_graph', 'dynamic_graph', 'pr_curve', + 'roc_curve', 'profiler', 'x2paddle', 'fastdeploy_server' + ] + if args.component_tabs is not None: + for component_tab in args.component_tabs: + if component_tab not in supported_tabs: + logger.error( + 'Component_tab {} is not support. Please choose tabs in {}' + .format(component_tab, supported_tabs)) + sys.exit(-1) + def format_args(args): # set default public path according to API mode option @@ -112,6 +127,7 @@ class ParseArgs(object): self.theme = args.theme self.dest = args.dest self.behavior = args.behavior + self.component_tabs = args.component_tabs def parse_args(): @@ -125,10 +141,14 @@ def parse_args(): ) parser.add_argument( - "--logdir", + "--logdir", action="store", nargs="+", help="log file directory") + + parser.add_argument( + "--component_tabs", action="store", nargs="+", - help="log file directory") + help="component tabs presented in html page.") + parser.add_argument( "-p", "--port", @@ -155,7 +175,8 @@ def parse_args(): dest="cache_timeout", type=float, default=default_cache_timeout, - help="memory cache timeout duration in seconds (default: %(default)s)", ) + help="memory cache timeout duration in seconds (default: %(default)s)", + ) parser.add_argument( "-L", "--language", @@ -169,27 +190,23 @@ def parse_args(): action="store", dest="public_path", default=None, - help="set public path" - ) + help="set public path") parser.add_argument( "--api-only", action="store_true", dest="api_only", default=False, - help="serve api only" - ) + help="serve api only") parser.add_argument( "--verbose", "-v", action="count", default=0, - help="set log level, use -vvv... to get more information" - ) + help="set log level, use -vvv... to get more information") parser.add_argument( "--version", action="version", - version="%(prog)s {}".format(__version__) - ) + version="%(prog)s {}".format(__version__)) parser.add_argument( "--product", type=str, @@ -201,25 +218,16 @@ def parse_args(): action="store_false", dest="telemetry", default=True, - help="disable telemetry" - ) + help="disable telemetry") parser.add_argument( "--theme", action="store", dest="theme", default=None, choices=support_themes, - help="set theme" - ) - parser.add_argument( - 'dest', - nargs='?', - help='set destination for log' - ) - parser.add_argument( - "behavior", - nargs='?' - ) + help="set theme") + parser.add_argument('dest', nargs='?', help='set destination for log') + parser.add_argument("behavior", nargs='?') args = parser.parse_args()