# Copyright (c) 2017 VisualDL Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================

from __future__ import absolute_import
import sys
import time
import os
import numpy as np
from visualdl.server.log import logger
from visualdl.io import bfile
from visualdl.utils.string_util import encode_tag, decode_tag


# Per-component flag used by get_logs(): MODIFY_PREFIX[component] becomes True
# once the VISUALDL_RUN_PREFIX rewrite has been applied for that component.
MODIFY_PREFIX = {}
# Run display names that get_logs() has already handled during prefix
# rewriting; names listed here are skipped on later passes.
MODIFIED_RUNS = []


走神的阿圆's avatar
走神的阿圆 已提交
30 31 32 33
def s2ms(timestamp):
    """Scale ``timestamp`` by 1000 unless it is already at least 2000000000.

    Values below 2000000000 are multiplied by 1000 (seconds to milliseconds,
    going by the function name); larger values are returned unchanged.
    """
    if timestamp < 2000000000:
        return timestamp * 1000
    return timestamp


34
def get_components(log_reader):
    """Return the reader's components as a list, always including 'graph'.

    ``log_reader.components(update=True)`` is expected to return a set-like
    object; 'graph' is added to that same object before it is listed.
    """
    found = log_reader.components(update=True)
    found.add('graph')
    return [name for name in found]
S
superjom 已提交
38

S
superjom 已提交
39

40
def get_runs(log_reader):
    """List the reader's runs, using the display name where one is registered.

    Each entry of ``log_reader.runs()`` is replaced by its value in
    ``log_reader.tags2name`` when present, and kept unchanged otherwise.
    """
    return [
        log_reader.tags2name.get(run, run)
        for run in log_reader.runs()
    ]
48 49


50 51
def get_tags(log_reader):
    """Return whatever ``log_reader.tags()`` reports, unchanged."""
    all_tags = log_reader.tags()
    return all_tags
S
superjom 已提交
52 53


54 55 56 57 58 59 60 61 62
def get_logs(log_reader, component):
    """Collect the runs and encoded tags stored for ``component``.

    Every key of the component's reservoir is split at its last '/' into a
    run part and a tag part. Runs seen for the first time are registered in
    ``log_reader.tags2name`` / ``log_reader.name2tags`` under their own name.

    When the ``VISUALDL_RUN_PREFIX`` environment variable is set, the first
    call for each component also rewrites run display names: everything up to
    and including the first occurrence of the prefix is dropped from the
    name, and both lookup tables on ``log_reader`` are updated accordingly.

    Args:
        log_reader: Reader exposing ``data_manager``, ``tags2name`` and
            ``name2tags``.
        component: Name of the reservoir to inspect (e.g. "scalar").

    Returns:
        dict: ``{'runs': [...], 'tags': [...]}`` where ``tags[i]`` is the
        list of encoded tags belonging to display name ``runs[i]``.
    """
    reservoir_keys = log_reader.data_manager.get_reservoir(component).keys
    tags_by_run = {}
    for item in reservoir_keys:
        index = item.rfind('/')
        run = item[0:index]
        tag = encode_tag(item[index + 1:])
        tags_by_run.setdefault(run, []).append(tag)
        # A run seen for the first time is displayed under its own name.
        if run not in log_reader.tags2name:
            log_reader.tags2name[run] = run
            log_reader.name2tags[run] = run

    # Re-key by display name; runs sharing a display name collapse into one
    # entry, the last one winning.
    fake_tags = {}
    for run, run_tags in tags_by_run.items():
        fake_tags[log_reader.tags2name.get(run, run)] = run_tags

    run2tag = {'runs': list(fake_tags), 'tags': list(fake_tags.values())}

    run_prefix = os.getenv('VISUALDL_RUN_PREFIX')
    global MODIFY_PREFIX, MODIFIED_RUNS
    MODIFY_PREFIX.setdefault(component, False)
    if run_prefix and not MODIFY_PREFIX[component]:
        # The rewrite runs only once per component.
        MODIFY_PREFIX[component] = True
        # Iterate over a snapshot because name2tags is mutated in the loop.
        for key, value in log_reader.name2tags.copy().items():
            if key in MODIFIED_RUNS:
                continue
            index = key.find(run_prefix)
            if index != -1:
                temp_key = key[index + len(run_prefix):]

                log_reader.name2tags.pop(key)
                log_reader.name2tags.update({temp_key: value})

                log_reader.tags2name.pop(value)
                log_reader.tags2name.update({value: temp_key})

                # name2tags may contain runs that have no data for this
                # component; only rename entries actually listed here instead
                # of letting list.index() raise ValueError.
                if key in run2tag['runs']:
                    run2tag['runs'][run2tag['runs'].index(key)] = temp_key
            else:
                temp_key = key

            MODIFIED_RUNS.append(temp_key)

    return run2tag
107 108


109 110
def get_scalar_tags(log_reader):
    """Return the run/tag listing of the "scalar" component via get_logs()."""
    run2tag = get_logs(log_reader, "scalar")
    return run2tag
111 112


113
def get_scalar(log_reader, run, tag):
    """Return ``[timestamp, id, value]`` rows for one scalar run/tag.

    ``run`` is first translated through ``log_reader.name2tags`` when it is a
    known display name. New data is loaded before the reservoir is queried,
    and each record's timestamp is passed through ``s2ms``.
    """
    run = log_reader.name2tags.get(run, run)
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("scalar")
    results = []
    for record in reservoir.get_items(run, decode_tag(tag)):
        results.append([s2ms(record.timestamp), record.id, record.value])
    return results
120 121


122 123
def get_image_tags(log_reader):
    """Return the run/tag listing of the "image" component via get_logs()."""
    run2tag = get_logs(log_reader, "image")
    return run2tag
124 125


126
def get_image_tag_steps(log_reader, run, tag):
    """Return one ``{"step", "wallTime"}`` dict per image record of run/tag.

    ``run`` is translated through ``log_reader.name2tags`` when it is a known
    display name; ``wallTime`` is the record timestamp passed through
    ``s2ms``.
    """
    run = log_reader.name2tags.get(run, run)
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("image")
    result = []
    for record in reservoir.get_items(run, decode_tag(tag)):
        result.append({
            "step": record.id,
            "wallTime": s2ms(record.timestamp),
        })
    return result
136 137


138
def get_individual_image(log_reader, run, tag, step_index):
    """Return the encoded image string of the record at ``step_index``.

    ``run`` is translated through ``log_reader.name2tags`` when it is a known
    display name. ``step_index`` indexes directly into the records returned
    for run/tag.
    """
    run = log_reader.name2tags.get(run, run)
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("image")
    records = reservoir.get_items(run, decode_tag(tag))
    selected = records[step_index]
    return selected.image.encoded_image_string
144 145


146 147
def get_audio_tags(log_reader):
    """Return the run/tag listing of the "audio" component via get_logs()."""
    run2tag = get_logs(log_reader, "audio")
    return run2tag
148 149


150
def get_audio_tag_steps(log_reader, run, tag):
    """Return one ``{"step", "wallTime"}`` dict per audio record of run/tag.

    ``run`` is translated through ``log_reader.name2tags`` when it is a known
    display name; ``wallTime`` is the record timestamp passed through
    ``s2ms``.
    """
    run = log_reader.name2tags.get(run, run)
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("audio")
    result = []
    for record in reservoir.get_items(run, decode_tag(tag)):
        result.append({
            "step": record.id,
            "wallTime": s2ms(record.timestamp),
        })
    return result
160 161


162
def get_individual_audio(log_reader, run, tag, step_index):
    """Return the encoded audio string of the record at ``step_index``.

    ``run`` is translated through ``log_reader.name2tags`` when it is a known
    display name. ``step_index`` indexes directly into the records returned
    for run/tag.
    """
    run = log_reader.name2tags.get(run, run)
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("audio")
    records = reservoir.get_items(run, decode_tag(tag))
    selected = records[step_index]
    return selected.audio.encoded_audio_string
169 170


171 172 173 174
def get_embeddings_tags(log_reader):
    """Return the run/tag listing of the "embeddings" component via get_logs()."""
    run2tag = get_logs(log_reader, "embeddings")
    return run2tag


175 176 177 178
def get_histogram_tags(log_reader):
    """Return the run/tag listing of the "histogram" component via get_logs()."""
    run2tag = get_logs(log_reader, "histogram")
    return run2tag


走神的阿圆's avatar
走神的阿圆 已提交
179 180 181 182 183
def get_pr_curve_tags(log_reader):
    """Return the run/tag listing of the "pr_curve" component via get_logs()."""
    run2tag = get_logs(log_reader, "pr_curve")
    return run2tag


def get_pr_curve(log_reader, run, tag):
    """Return one row per PR-curve record of run/tag.

    Each row is ``[timestamp, id, precision, recall, TP, FP, TN, FN,
    thresholds]``. The timestamp is passed through ``s2ms``, the curve fields
    are copied into plain lists, and ``thresholds`` holds the evenly spaced
    values ``1/n, 2/n, ..., 1.0`` with ``n = len(precision)``.

    ``run`` is translated through ``log_reader.name2tags`` when it is a known
    display name.
    """
    run = log_reader.name2tags.get(run, run)
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("pr_curve")
    results = []
    for record in reservoir.get_items(run, decode_tag(tag)):
        curve = record.pr_curve
        count = len(curve.precision)
        thresholds = [float(step) / count for step in range(1, count + 1)]
        row = [
            s2ms(record.timestamp),
            record.id,
            list(curve.precision),
            list(curve.recall),
            list(curve.TP),
            list(curve.FP),
            list(curve.TN),
            list(curve.FN),
            thresholds,
        ]
        results.append(row)
    return results


def get_pr_curve_step(log_reader, run, tag=None):
    """Return ``[timestamp, id]`` pairs for the PR-curve records of ``run``.

    The ``tag`` argument is not used as passed: it is always replaced by the
    first pr_curve tag listed for ``run`` (looked up by display name in the
    result of ``get_pr_curve_tags``). Timestamps are passed through ``s2ms``.
    """
    display_run = run
    run = log_reader.name2tags.get(run, run)
    run2tag = get_pr_curve_tags(log_reader)
    position = run2tag['runs'].index(display_run)
    # Whatever the caller supplied, the first tag of the run is used.
    tag = run2tag['tags'][position][0]
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("pr_curve")
    results = []
    for record in reservoir.get_items(run, decode_tag(tag)):
        results.append([s2ms(record.timestamp), record.id])
    return results


217
def get_embeddings(log_reader, run, tag, reduction, dimension=2):
    """Reduce the stored embeddings of run/tag to ``dimension`` dimensions.

    Only the first record returned for run/tag is used. ``run`` is translated
    through ``log_reader.name2tags`` when it is a known display name.

    Args:
        log_reader: Reader exposing ``name2tags``, ``load_new_data`` and
            ``data_manager``.
        run: Run name (display name or stored name).
        tag: Encoded tag name.
        reduction: Either 'tsne' or 'pca'.
        dimension: Number of output dimensions.

    Returns:
        dict: ``{"embedding": <nested list>, "labels": <list of labels>}``.

    Raises:
        ValueError: If ``reduction`` is neither 'tsne' nor 'pca'.
    """
    # Reject unsupported methods up front; previously this fell through both
    # branches and failed with an UnboundLocalError at the return statement.
    if reduction not in ('tsne', 'pca'):
        raise ValueError(
            "Unsupported reduction %r, expected 'tsne' or 'pca'" % (reduction,))

    run = log_reader.name2tags[run] if run in log_reader.name2tags else run
    log_reader.load_new_data()
    records = log_reader.data_manager.get_reservoir("embeddings").get_items(
        run, decode_tag(tag))

    labels = []
    vectors = []
    for item in records[0].embeddings.embeddings:
        labels.append(item.label)
        vectors.append(item.vectors)
    vectors = np.array(vectors)

    if reduction == 'tsne':
        # Imported lazily so the t-SNE module is only loaded when requested.
        import visualdl.server.tsne as tsne
        low_dim_embs = tsne.tsne(
            vectors, dimension, initial_dims=50, perplexity=30.0)
    else:
        low_dim_embs = simple_pca(vectors, dimension)

    return {"embedding": low_dim_embs.tolist(), "labels": labels}
239 240


241
def get_histogram(log_reader, run, tag):
    """Return ``[timestamp, id, bins]`` rows for the histogram records of run/tag.

    ``bins`` is a list of ``[left_edge, right_edge, count]`` triples, one per
    entry of the record's ``hist``, with edges taken from consecutive entries
    of ``bin_edges``. Timestamps are passed through ``s2ms``. ``run`` is
    translated through ``log_reader.name2tags`` when it is a known display
    name.
    """
    run = log_reader.name2tags.get(run, run)
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("histogram")

    results = []
    for record in reservoir.get_items(run, decode_tag(tag)):
        counts = record.histogram.hist
        edges = record.histogram.bin_edges
        bins = [
            [edges[pos], edges[pos + 1], counts[pos]]
            for pos in range(len(counts))
        ]
        results.append([s2ms(record.timestamp), record.id, bins])

    return results


260 261 262
def get_graph(log_reader):
    """Return the content of the reader's model file, or ``b""`` without one.

    When ``log_reader.model`` is falsy nothing is opened. Otherwise the file
    is opened through ``bfile.BFile`` in 'rb' mode and read with
    ``read_file``.
    """
    if not log_reader.model:
        return b""
    with bfile.BFile(log_reader.model, 'rb') as bfp:
        return bfp.read_file(log_reader.model)


268
def retry(ntimes, function, time2sleep, *args, **kwargs):
    """
    Call `function(*args, **kwargs)` up to `ntimes` times.

    The first successful result is returned. After a failed attempt that is
    not the last one, the error is logged and the thread sleeps `time2sleep`
    seconds. When the last attempt fails, the traceback is printed and None
    is returned.
    """
    last_attempt = ntimes - 1
    for attempt in range(ntimes):
        try:
            return function(*args, **kwargs)
        except Exception:
            if attempt == last_attempt:
                import traceback
                traceback.print_exc()
                continue
            error_info = '\n'.join(str(part) for part in sys.exc_info())
            logger.error("Unexpected error: %s" % error_info)
            time.sleep(time2sleep)
284

T
Thuan Nguyen 已提交
285

286 287 288 289 290 291 292 293 294
def cache_get(cache):
    """Build a read-through lookup function on top of ``cache``.

    The returned handler takes ``(key, func, *args, **kwargs)``. It returns
    ``cache.get(key)`` when that is not None; otherwise it logs a warning,
    computes ``func(*args, **kwargs)``, stores the result with
    ``cache.set(key, ...)`` and returns it.
    """
    def _handler(key, func, *args, **kwargs):
        cached = cache.get(key)
        if cached is not None:
            return cached
        logger.warning('update cache %s' % key)
        fresh = func(*args, **kwargs)
        cache.set(key, fresh)
        return fresh

    return _handler
J
Jeff Wang 已提交
297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319


def simple_pca(x, dimension):
    """
    A simple PCA implementation to do the dimension reduction.

    Args:
        x: 2-D array-like with one sample per row and one feature per column.
            It is not modified.
        dimension: Number of principal components to keep.

    Returns:
        numpy.ndarray: The centered data projected onto the `dimension`
        eigenvectors of the covariance matrix that have the largest
        eigenvalues. The sign of each component is arbitrary.
    """
    # Work on a float array and center out of place: an in-place `x -= mean`
    # would overwrite the caller's data and fails for integer arrays.
    data = np.asarray(x, dtype=np.float64)
    centered = data - np.mean(data, axis=0)

    # Covariance between features; np.cov returns a 0-d array when there is
    # a single feature, so promote it to a 1x1 matrix.
    cov = np.atleast_2d(np.cov(centered, rowvar=False))

    # The covariance matrix is symmetric, so use eigh: unlike eig it always
    # returns real eigenvalues and eigenvectors.
    eigvals, eigvecs = np.linalg.eigh(cov)

    # Sort the eigvals from high to low
    order = np.argsort(eigvals)[::-1]

    # Drop the eigenvectors with low eigenvalues
    eigvecs = eigvecs[:, order[:dimension]]

    return np.dot(centered, eigvecs)