#!/usr/bin/env python
# Copyright (c) 2017 VisualDL Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================
import functools
import json
import os

from flask import request

from visualdl import LogReader
from visualdl.python.cache import MemCache
from visualdl.reader.graph_reader import GraphReader
from visualdl.server import lib
from visualdl.server.client_manager import ClientManager
from visualdl.server.log import logger

error_retry_times = 3
error_sleep_time = 2  # seconds


def gen_result(data=None, status=0, msg=''):
    """Build the response envelope shared by the API endpoints.

    Args:
        data: Payload of the response.
        status: Result code. In this module 0 means success, -1 a failed
            endpoint call and 1 an unknown API path.
        msg: Human-readable message, empty on success.

    Returns:
        A dict with the keys ``status``, ``msg`` and ``data``, in that order
        (the order is visible in the JSON that callers serialize).
    """
    envelope = dict(status=status, msg=msg)
    envelope['data'] = data
    return envelope


def result(mimetype='application/json', headers=None):
    """Turn an ``Api`` method into a handler returning ``(body, mimetype, headers)``.

    The decorated method is called as ``func(self, *args, **kwargs)``. If it
    raises, the exception is caught and logged, and its text is reported as
    ``msg`` with ``status = -1``. Otherwise ``status`` is 0 and ``msg`` is empty.

    Args:
        mimetype: Response MIME type. For ``'application/json'`` the body is the
            JSON-encoded :func:`gen_result` envelope (status, msg, data). For any
            other type, the method's return value is used as the body unchanged.
        headers: Extra response headers, or a callable that receives the ``Api``
            instance and returns them. ``None`` means no extra headers.

    Returns:
        The decorator.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            data = None
            status = 0
            msg = ''
            try:
                data = func(self, *args, **kwargs)
            except Exception as e:
                msg = '{}'.format(e)
                status = -1
                # The error is reported to the client below. Log it as well so
                # failures are visible on the server, not only in responses.
                logger.error('API call {} failed: {}'.format(func.__name__, msg))
            if mimetype == 'application/json':
                data = json.dumps(gen_result(data, status, msg))
            # NOTE(review): for non-JSON mimetypes a failure leaves ``data`` as
            # None and ``msg`` is not part of the returned body.
            if callable(headers):
                headers_output = headers(self)
            else:
                headers_output = headers
            return data, mimetype, headers_output

        return wrapper

    return decorator


def try_call(function, *args, **kwargs):
    """Call ``function(*args, **kwargs)`` through ``lib.retry``.

    ``lib.retry`` is invoked with ``error_retry_times`` attempts and
    ``error_sleep_time`` seconds between them. An error is logged when the
    call yields ``None``.

    Returns:
        Whatever ``lib.retry`` returns.
    """
    res = lib.retry(error_retry_times, function, error_sleep_time, *args,
                    **kwargs)
    # Only ``None`` is treated as failure. Empty containers or strings (e.g. a
    # run with no tags yet) are valid results and must not be logged as
    # server errors, which the previous ``if not res`` check did.
    # NOTE(review): assumes lib.retry returns None once its retries are
    # exhausted -- confirm against visualdl.server.lib.
    if res is None:
        logger.error("Internal server error. Retry later.")
    return res


class Api(object):
    """Implementation of the VisualDL HTTP data API.

    Every public endpoint method is wrapped with :func:`result`, so it returns
    a ``(body, mimetype, headers)`` tuple and reports exceptions in the
    response instead of raising.

    Log data is read through ``self._reader`` (a ``LogReader``) and goes
    through ``lib.cache_get`` over a ``MemCache``. Graph endpoints instead use
    the ``GraphReader`` that ``ClientManager`` returns for the requesting
    client's IP address, and do not go through that cache.

    The router passes ``None`` for an absent request parameter. Endpoints with
    optional parameters therefore treat ``None`` as "use the default".
    """

    def __init__(self, logdir, model, cache_timeout):
        """Create the readers and the response cache.

        Args:
            logdir: Log directory handed to both ``LogReader`` and
                ``GraphReader``.
            model: Optional model file path (falsy for none). It is attached
                to the log reader, and a path containing ``'vdlgraph'`` is also
                loaded into the graph reader.
            cache_timeout: Timeout passed to ``MemCache``.
        """
        self._reader = LogReader(logdir)
        self._graph_reader = GraphReader(logdir)
        self._graph_reader.set_displayname(self._reader)
        if model:
            if 'vdlgraph' in model:
                self._graph_reader.set_input_graph(model)
            self._reader.model = model
            self.model_name = os.path.basename(model)
        else:
            self.model_name = ''
        self.graph_reader_client_manager = ClientManager(self._graph_reader)
        # use a memory cache to reduce disk reading frequency.
        cache = MemCache(timeout=cache_timeout)
        self._cache = lib.cache_get(cache)

    # ----- internal helpers ----------------------------------------------

    @staticmethod
    def _cache_key(prefix, *parts):
        """Build an unambiguous cache key from ``prefix`` and request ``parts``.

        The parts are JSON-encoded. With ``os.path.join``, (run='a',
        tag='b/c') and (run='a/b', tag='c') produced the same key. A part
        starting with '/' also discarded the prefix, so different endpoints
        could share one cache entry.

        Raises:
            TypeError: if any part is None (a required parameter is missing).
        """
        if any(part is None for part in parts):
            raise TypeError(
                'missing required request parameter for {}'.format(prefix))
        return '{}/{}'.format(prefix, json.dumps(list(parts)))

    @staticmethod
    def _parse_bool(value, default):
        """Parse an optional boolean request flag.

        Returns ``default`` when ``value`` is None, True when it equals
        'true' case-insensitively, and False for any other string.
        """
        if value is None:
            return default
        return value.lower() == 'true'

    def _client_graph_reader(self):
        """Return the graph reader ClientManager holds for the client's IP."""
        return self.graph_reader_client_manager.get_data(request.remote_addr)

    def _get(self, key, func, *args, **kwargs):
        """Fetch ``key`` through the cache, computed by ``func(self._reader, ...)``."""
        return self._cache(key, func, self._reader, *args, **kwargs)

    def _get_with_retry(self, key, func, *args, **kwargs):
        """Like :meth:`_get`, but computing the value via :func:`try_call`."""
        return self._cache(key, try_call, func, self._reader, *args, **kwargs)

    def _get_with_reader(self, key, func, reader, *args, **kwargs):
        """Like :meth:`_get`, but with an explicit ``reader``."""
        return self._cache(key, func, reader, *args, **kwargs)

    # ----- generic data --------------------------------------------------

    @result()
    def components(self):
        return self._get('data/components', lib.get_components)

    @result()
    def runs(self):
        return self._get('data/runs', lib.get_runs)

    @result()
    def graph_runs(self):
        """List the runs of the requesting client's graph reader."""
        return lib.get_graph_runs(self._client_graph_reader())

    @result()
    def tags(self):
        return self._get('data/tags', lib.get_tags)

    @result()
    def logs(self):
        return self._get('data/logs', lib.get_logs)

    # ----- per-plugin tag listings ---------------------------------------

    @result()
    def scalar_tags(self):
        return self._get_with_retry('data/plugin/scalars/tags',
                                    lib.get_scalar_tags)

    @result()
    def image_tags(self):
        return self._get_with_retry('data/plugin/images/tags',
                                    lib.get_image_tags)

    @result()
    def text_tags(self):
        return self._get_with_retry('data/plugin/text/tags', lib.get_text_tags)

    @result()
    def audio_tags(self):
        return self._get_with_retry('data/plugin/audio/tags',
                                    lib.get_audio_tags)

    @result()
    def embedding_tags(self):
        return self._get_with_retry('data/plugin/embeddings/tags',
                                    lib.get_embeddings_tags)

    @result()
    def pr_curve_tags(self):
        return self._get_with_retry('data/plugin/pr_curves/tags',
                                    lib.get_pr_curve_tags)

    @result()
    def roc_curve_tags(self):
        return self._get_with_retry('data/plugin/roc_curves/tags',
                                    lib.get_roc_curve_tags)

    # ----- hyper-parameters ----------------------------------------------

    @result()
    def hparam_importance(self):
        return self._get_with_retry('data/plugin/hparams/importance',
                                    lib.get_hparam_importance)

    @result()
    def hparam_indicator(self):
        return self._get_with_retry('data/plugin/hparams/indicators',
                                    lib.get_hparam_indicator)

    @result()
    def hparam_list(self):
        return self._get_with_retry('data/plugin/hparams/list',
                                    lib.get_hparam_list)

    @result()
    def hparam_metric(self, run, metric):
        key = self._cache_key('data/plugin/hparams/metric', run, metric)
        return self._get_with_retry(key, lib.get_hparam_metric, run, metric)

    @result('text/csv')
    def hparam_data(self, type='tsv'):
        """Export hparam data in format ``type`` (default 'tsv')."""
        if type is None:
            type = 'tsv'
        key = self._cache_key('data/plugin/hparams/data', type)
        return self._get_with_retry(key, lib.get_hparam_data, type)

    # ----- scalars -------------------------------------------------------

    @result()
    def scalar_list(self, run, tag):
        key = self._cache_key('data/plugin/scalars/scalars', run, tag)
        return self._get_with_retry(key, lib.get_scalar, run, tag)

    @result('text/csv')
    def scalar_data(self, run, tag, type='tsv'):
        """Export one scalar series in format ``type`` (default 'tsv')."""
        if type is None:
            type = 'tsv'
        key = self._cache_key('data/plugin/scalars/data', run, tag, type)
        return self._get_with_retry(key, lib.get_scalar_data, run, tag, type)

    # ----- images / text / audio -----------------------------------------

    @result()
    def image_list(self, mode, tag):
        key = self._cache_key('data/plugin/images/images', mode, tag)
        return self._get_with_retry(key, lib.get_image_tag_steps, mode, tag)

    @result('image/png')
    def image_image(self, mode, tag, index=0):
        """Return the image at step position ``index`` (default 0)."""
        index = 0 if index is None else int(index)
        key = self._cache_key('data/plugin/images/individualImage', mode, tag,
                              index)
        return self._get_with_retry(key, lib.get_individual_image, mode, tag,
                                    index)

    @result()
    def text_list(self, mode, tag):
        key = self._cache_key('data/plugin/text/text', mode, tag)
        return self._get_with_retry(key, lib.get_text_tag_steps, mode, tag)

    @result('text/plain')
    def text_text(self, mode, tag, index=0):
        """Return the text record at position ``index`` (default 0)."""
        index = 0 if index is None else int(index)
        key = self._cache_key('data/plugin/text/individualText', mode, tag,
                              index)
        return self._get_with_retry(key, lib.get_individual_text, mode, tag,
                                    index)

    @result()
    def audio_list(self, run, tag):
        key = self._cache_key('data/plugin/audio/audio', run, tag)
        return self._get_with_retry(key, lib.get_audio_tag_steps, run, tag)

    @result('audio/wav')
    def audio_audio(self, run, tag, index=0):
        """Return the audio record at position ``index`` (default 0)."""
        index = 0 if index is None else int(index)
        key = self._cache_key('data/plugin/audio/individualAudio', run, tag,
                              index)
        return self._get_with_retry(key, lib.get_individual_audio, run, tag,
                                    index)

    # ----- embeddings ----------------------------------------------------

    @result()
    def embedding_embedding(self,
                            run,
                            tag='default',
                            reduction='pca',
                            dimension=2):
        """Return embedding ``tag`` of ``run`` reduced with ``reduction``."""
        if tag is None:
            tag = 'default'
        if reduction is None:
            reduction = 'pca'
        dimension = 2 if dimension is None else int(dimension)
        # ``tag`` must be part of the key. It used to be omitted, so different
        # tags of one run with equal reduction/dimension shared a cache entry.
        key = self._cache_key('data/plugin/embeddings/embeddings', run, tag,
                              dimension, reduction)
        return self._get_with_retry(key, lib.get_embeddings, run, tag,
                                    reduction, dimension)

    @result()
    def embedding_list(self):
        return self._get_with_retry('data/plugin/embeddings/list',
                                    lib.get_embeddings_list)

    @result('text/tab-separated-values')
    def embedding_metadata(self, name):
        key = self._cache_key('data/plugin/embeddings/metadata', name)
        return self._get_with_retry(key, lib.get_embedding_labels, name)

    @result('application/octet-stream')
    def embedding_tensor(self, name):
        key = self._cache_key('data/plugin/embeddings/tensor', name)
        return self._get_with_retry(key, lib.get_embedding_tensors, name)

    # ----- histograms / PR & ROC curves ----------------------------------

    @result()
    def histogram_tags(self):
        return self._get_with_retry('data/plugin/histogram/tags',
                                    lib.get_histogram_tags)

    @result()
    def histogram_list(self, run, tag):
        key = self._cache_key('data/plugin/histogram/histogram', run, tag)
        return self._get_with_retry(key, lib.get_histogram, run, tag)

    @result()
    def pr_curves_pr_curve(self, run, tag):
        key = self._cache_key('data/plugin/pr_curves/pr_curve', run, tag)
        return self._get_with_retry(key, lib.get_pr_curve, run, tag)

    @result()
    def roc_curves_roc_curve(self, run, tag):
        key = self._cache_key('data/plugin/roc_curves/roc_curve', run, tag)
        return self._get_with_retry(key, lib.get_roc_curve, run, tag)

    @result()
    def pr_curves_steps(self, run):
        key = self._cache_key('data/plugin/pr_curves/steps', run)
        return self._get_with_retry(key, lib.get_pr_curve_step, run)

    @result()
    def roc_curves_steps(self, run):
        key = self._cache_key('data/plugin/roc_curves/steps', run)
        return self._get_with_retry(key, lib.get_roc_curve_step, run)

    # ----- graphs --------------------------------------------------------

    @result('application/octet-stream', lambda s: {
        "Content-Disposition": 'attachment; filename="%s"' % s.model_name
    } if s.model_name else None)
    def graph_static_graph(self):
        """Return the static model graph, as an attachment named after it."""
        return self._get_with_retry('data/plugin/graphs/static_graph',
                                    lib.get_static_graph)

    @result()
    def graph_graph(self, run, expand_all, refresh):
        """Return the graph of ``run``.

        ``expand_all`` defaults to False and ``refresh`` to True when absent.
        """
        graph_reader = self._client_graph_reader()
        expand_all = self._parse_bool(expand_all, default=False)
        refresh = self._parse_bool(refresh, default=True)
        return lib.get_graph(
            graph_reader, run, expand_all=expand_all, refresh=refresh)

    @result()
    def graph_upload(self):
        """Load an uploaded ``file`` into the client's graph reader.

        Files whose name contains 'pdmodel' or 'vdlgraph' are loaded as that
        format. A request without ``file`` or with any other name is ignored.
        """
        graph_reader = self._client_graph_reader()
        file_handle = request.files.get('file')
        if file_handle is None:
            return None
        # Guard against uploads that carry no filename at all.
        filename = file_handle.filename or ''
        if 'pdmodel' in filename:
            graph_reader.set_input_graph(file_handle.stream.read(), 'pdmodel')
        elif 'vdlgraph' in filename:
            graph_reader.set_input_graph(file_handle.stream.read(), 'vdlgraph')
        return None

    @result()
    def graph_manipulate(self, run, nodeid, expand, keep_state):
        """Expand/collapse ``nodeid``. Absent flags default to False."""
        graph_reader = self._client_graph_reader()
        expand = self._parse_bool(expand, default=False)
        keep_state = self._parse_bool(keep_state, default=False)
        return lib.get_graph(graph_reader, run, nodeid, expand, keep_state)

    @result()
    def graph_search(self, run, nodeid, keep_state, is_node):
        """Search the graph of ``run``. Absent flags default to False."""
        graph_reader = self._client_graph_reader()
        keep_state = self._parse_bool(keep_state, default=False)
        is_node = self._parse_bool(is_node, default=False)
        return lib.get_graph_search(graph_reader, run, nodeid, keep_state,
                                    is_node)

    @result()
    def graph_get_all_nodes(self, run):
        return lib.get_graph_all_nodes(self._client_graph_reader(), run)


def create_api_call(logdir, model, cache_timeout):
    """Build the dispatcher that maps API paths to ``Api`` endpoint methods.

    Args:
        logdir: Forwarded to :class:`Api`.
        model: Forwarded to :class:`Api`.
        cache_timeout: Forwarded to :class:`Api`.

    Returns:
        A function ``call(path, args)``. It looks up ``path`` in the route
        table and reads each declared parameter with ``args.get(name)``
        (``None`` when absent). The values are passed positionally to the
        endpoint, and its ``(body, mimetype, headers)`` tuple is returned.
        Unknown paths yield a JSON error envelope with ``status=1``.
    """
    backend = Api(logdir, model, cache_timeout)

    # path -> (bound endpoint, request parameter names in positional order).
    # The request names need not match the method's own parameter names.
    routes = {
        # generic data
        'components': (backend.components, ()),
        'runs': (backend.runs, ()),
        'graph_runs': (backend.graph_runs, ()),
        'tags': (backend.tags, ()),
        'logs': (backend.logs, ()),
        # tag listings
        'scalar/tags': (backend.scalar_tags, ()),
        'image/tags': (backend.image_tags, ()),
        'text/tags': (backend.text_tags, ()),
        'audio/tags': (backend.audio_tags, ()),
        'embedding/tags': (backend.embedding_tags, ()),
        'histogram/tags': (backend.histogram_tags, ()),
        'pr-curve/tags': (backend.pr_curve_tags, ()),
        'roc-curve/tags': (backend.roc_curve_tags, ()),
        # per-tag data
        'scalar/list': (backend.scalar_list, ('run', 'tag')),
        'scalar/data': (backend.scalar_data, ('run', 'tag', 'type')),
        'image/list': (backend.image_list, ('run', 'tag')),
        'image/image': (backend.image_image, ('run', 'tag', 'index')),
        'text/list': (backend.text_list, ('run', 'tag')),
        'text/text': (backend.text_text, ('run', 'tag', 'index')),
        'audio/list': (backend.audio_list, ('run', 'tag')),
        'audio/audio': (backend.audio_audio, ('run', 'tag', 'index')),
        'embedding/embedding': (backend.embedding_embedding,
                                ('run', 'tag', 'reduction', 'dimension')),
        'embedding/list': (backend.embedding_list, ()),
        'embedding/tensor': (backend.embedding_tensor, ('name',)),
        'embedding/metadata': (backend.embedding_metadata, ('name',)),
        'histogram/list': (backend.histogram_list, ('run', 'tag')),
        # graphs
        'graph/graph': (backend.graph_graph,
                        ('run', 'expand_all', 'refresh')),
        'graph/static_graph': (backend.graph_static_graph, ()),
        'graph/upload': (backend.graph_upload, ()),
        'graph/search': (backend.graph_search,
                         ('run', 'nodeid', 'keep_state', 'is_node')),
        'graph/get_all_nodes': (backend.graph_get_all_nodes, ('run',)),
        'graph/manipulate': (backend.graph_manipulate,
                             ('run', 'nodeid', 'expand', 'keep_state')),
        # PR / ROC curves
        'pr-curve/list': (backend.pr_curves_pr_curve, ('run', 'tag')),
        'roc-curve/list': (backend.roc_curves_roc_curve, ('run', 'tag')),
        'pr-curve/steps': (backend.pr_curves_steps, ('run',)),
        'roc-curve/steps': (backend.roc_curves_steps, ('run',)),
        # hyper-parameters
        'hparams/importance': (backend.hparam_importance, ()),
        'hparams/data': (backend.hparam_data, ('type',)),
        'hparams/indicators': (backend.hparam_indicator, ()),
        'hparams/list': (backend.hparam_list, ()),
        'hparams/metric': (backend.hparam_metric, ('run', 'metric')),
    }

    def call(path: str, args):
        try:
            endpoint, param_names = routes[path]
        except KeyError:
            not_found = gen_result(status=1, msg='api not found')
            return json.dumps(not_found), 'application/json', None
        return endpoint(*(args.get(name) for name in param_names))

    return call