lib.py 7.1 KB
Newer Older
J
Jeff Wang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2017 VisualDL Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================

O
Oraoto 已提交
16
from __future__ import absolute_import
Y
Yan Chunwei 已提交
17
import sys
18
import time
S
superjom 已提交
19
import numpy as np
20
from visualdl.server.log import logger
走神的阿圆's avatar
走神的阿圆 已提交
21
from visualdl.io import bfile
22
from visualdl.utils.string_util import encode_tag, decode_tag
23

S
superjom 已提交
24

25
def get_components(log_reader):
    """Return the component names known to ``log_reader`` plus ``'graph'``.

    The reader is queried with ``update=True``. ``'graph'`` is added to the
    object the reader returns, so it is always present in the result.
    """
    names = log_reader.components(update=True)
    names.add('graph')
    result = list(names)
    return result
S
superjom 已提交
29

S
superjom 已提交
30

31 32
def get_runs(log_reader):
    """Return whatever ``log_reader.runs()`` reports, unchanged."""
    runs = log_reader.runs()
    return runs
33 34


35 36
def get_tags(log_reader):
    """Return whatever ``log_reader.tags()`` reports, unchanged."""
    tags = log_reader.tags()
    return tags
S
superjom 已提交
37 38


39 40 41 42 43 44 45 46 47
def get_logs(log_reader, component):
    """Group the tags of one component's reservoir by run.

    Every reservoir key is split on its LAST ``'/'``. The part before it is
    the run name and the part after it is the tag. The tag is passed through
    ``encode_tag`` before being returned.

    Args:
        log_reader: Reader whose ``data_manager`` holds per-component
            reservoirs.
        component: Component name, e.g. ``"scalar"`` or ``"image"``.

    Returns:
        dict: ``{run: [encoded_tag, ...]}``, tags listed in reservoir key
        order.
    """
    tags = {}
    for key in log_reader.data_manager.get_reservoir(component).keys:
        # rpartition splits on the last '/', so run names may contain '/'.
        # A key with no '/' yields run == '' and the whole key as the tag;
        # the previous rfind/slice logic silently chopped the key's last
        # character off the run name in that case.
        run, _, tag = key.rpartition('/')
        tags.setdefault(run, []).append(encode_tag(tag))
    return tags
51 52


53 54
def get_scalar_tags(log_reader):
    """Return ``{run: [encoded_tag, ...]}`` for the scalar component."""
    return get_logs(log_reader, component="scalar")
55 56


57 58 59 60 61 62
def get_scalar(log_reader, run, tag):
    """Return all scalar records of ``run``/``tag``.

    New data is loaded first. ``tag`` is the encoded (client-facing) form and
    is decoded before the reservoir lookup.

    Returns:
        list: one ``[timestamp, id, value]`` list per record.
    """
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("scalar")
    return [[rec.timestamp, rec.id, rec.value]
            for rec in reservoir.get_items(run, decode_tag(tag))]
63 64


65 66
def get_image_tags(log_reader):
    """Return ``{run: [encoded_tag, ...]}`` for the image component."""
    return get_logs(log_reader, component="image")
67 68


69 70 71 72 73 74 75 76 77
def get_image_tag_steps(log_reader, run, tag):
    """List the steps recorded for an image tag.

    New data is loaded first, and ``tag`` is decoded before the lookup.

    Returns:
        list: ``{"step": id, "wallTime": timestamp}`` dicts, one per record.
    """
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("image")
    steps = []
    for record in reservoir.get_items(run, decode_tag(tag)):
        steps.append({"step": record.id, "wallTime": record.timestamp})
    return steps
78 79


80 81 82 83 84
def get_individual_image(log_reader, run, tag, step_index):
    """Return ``image.encoded_image_string`` of one image record.

    ``step_index`` is a position in the record list returned by the reservoir,
    not a step id. Presumably the returned value is the encoded image
    payload; confirm its type against the record schema.
    """
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("image")
    records = reservoir.get_items(run, decode_tag(tag))
    record = records[step_index]
    return record.image.encoded_image_string
85 86


87 88
def get_audio_tags(log_reader):
    """Return ``{run: [encoded_tag, ...]}`` for the audio component."""
    return get_logs(log_reader, component="audio")
89 90


91 92 93 94 95 96 97 98 99
def get_audio_tag_steps(log_reader, run, tag):
    """List the steps recorded for an audio tag.

    New data is loaded first, and ``tag`` is decoded before the lookup.

    Returns:
        list: ``{"step": id, "wallTime": timestamp}`` dicts, one per record.
    """
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("audio")
    steps = []
    for record in reservoir.get_items(run, decode_tag(tag)):
        steps.append({"step": record.id, "wallTime": record.timestamp})
    return steps
100 101


102 103 104 105
def get_individual_audio(log_reader, run, tag, step_index):
    """Return ``audio.encoded_audio_string`` of one audio record.

    ``step_index`` is a position in the record list returned by the reservoir,
    not a step id. Presumably the returned value is the encoded audio
    payload; confirm its type against the record schema.
    """
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("audio")
    records = reservoir.get_items(run, decode_tag(tag))
    record = records[step_index]
    return record.audio.encoded_audio_string
108 109


110 111 112 113
def get_embeddings_tags(log_reader):
    """Return ``{run: [encoded_tag, ...]}`` for the embeddings component."""
    return get_logs(log_reader, component="embeddings")


114 115 116 117
def get_histogram_tags(log_reader):
    """Return ``{run: [encoded_tag, ...]}`` for the histogram component."""
    return get_logs(log_reader, component="histogram")


走神的阿圆's avatar
走神的阿圆 已提交
118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
def get_pr_curve_tags(log_reader):
    """Return ``{run: [encoded_tag, ...]}`` for the pr_curve component."""
    return get_logs(log_reader, component="pr_curve")


def get_pr_curve(log_reader, run, tag):
    """Return every precision-recall record of ``run``/``tag``.

    New data is loaded first, and ``tag`` is decoded before the lookup.

    Returns:
        list: per record, ``[timestamp, id, precision, recall, TP, FP, TN,
        FN, thresholds]``. The curve fields are plain lists. ``thresholds``
        holds ``k / n`` for ``k = 1..n``, where ``n`` is the number of
        precision points.
    """
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("pr_curve")
    curves = []
    for record in reservoir.get_items(run, decode_tag(tag)):
        curve = record.pr_curve
        n_points = len(curve.precision)
        # Evenly spaced thresholds over (0, 1]: 1/n, 2/n, ..., n/n.
        thresholds = [float(k) / n_points for k in range(1, n_points + 1)]
        curves.append([
            record.timestamp,
            record.id,
            list(curve.precision),
            list(curve.recall),
            list(curve.TP),
            list(curve.FP),
            list(curve.TN),
            list(curve.FN),
            thresholds,
        ])
    return curves


def get_pr_curve_step(log_reader, run, tag=None):
    """Return ``[timestamp, id]`` for every pr_curve record of ``run``/``tag``.

    Args:
        log_reader: Reader providing ``load_new_data`` and ``data_manager``.
        run: Run name.
        tag: Encoded tag. When ``None``, the first pr_curve tag listed for
            ``run`` is used.

    Returns:
        list: one ``[timestamp, id]`` list per record.

    Raises:
        KeyError: if ``tag`` is ``None`` and ``run`` has no pr_curve tags.
    """
    # Load before resolving the default tag. get_pr_curve_tags reads the
    # reservoir keys without loading, so doing the lookup first (as before)
    # could miss a run/tag that only arrives with the newest data and raise
    # KeyError on that run's first request.
    log_reader.load_new_data()
    if tag is None:
        tag = get_pr_curve_tags(log_reader)[run][0]
    records = log_reader.data_manager.get_reservoir("pr_curve").get_items(
        run, decode_tag(tag))
    return [[record.timestamp, record.id] for record in records]


152
def get_embeddings(log_reader, run, tag, reduction, dimension=2):
    """Reduce the embedding vectors of ``run``/``tag`` to ``dimension`` dims.

    Only the first record returned by the reservoir is used.

    Args:
        log_reader: Reader providing ``load_new_data`` and ``data_manager``.
        run: Run name.
        tag: Encoded tag; decoded before the reservoir lookup.
        reduction: ``'tsne'`` or ``'pca'``.
        dimension: Target number of dimensions (default 2).

    Returns:
        dict: ``{"embedding": <nested list of reduced vectors>,
        "labels": <list of labels>}``, with labels in the same order as the
        vectors.

    Raises:
        ValueError: if ``reduction`` is neither ``'tsne'`` nor ``'pca'``.
    """
    if reduction not in ('tsne', 'pca'):
        # Previously an unknown method fell through both branches below and
        # surfaced as an UnboundLocalError on ``low_dim_embs``.
        raise ValueError(
            "Unsupported reduction %r; expected 'tsne' or 'pca'" % (reduction,))

    log_reader.load_new_data()
    records = log_reader.data_manager.get_reservoir("embeddings").get_items(
        run, decode_tag(tag))

    labels = []
    vectors = []
    for item in records[0].embeddings.embeddings:
        labels.append(item.label)
        vectors.append(item.vectors)
    vectors = np.array(vectors)

    if reduction == 'tsne':
        # Imported lazily; only needed for the t-SNE path.
        import visualdl.server.tsne as tsne
        low_dim_embs = tsne.tsne(
            vectors, dimension, initial_dims=50, perplexity=30.0)
    else:
        low_dim_embs = simple_pca(vectors, dimension)

    return {"embedding": low_dim_embs.tolist(), "labels": labels}
173 174


175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192
def get_histogram(log_reader, run, tag):
    """Return every histogram record of ``run``/``tag``.

    New data is loaded first, and ``tag`` is decoded before the lookup.

    Returns:
        list: per record, ``[timestamp, id, buckets]``. ``buckets`` holds
        ``[left_edge, right_edge, count]`` for each entry of ``hist``, with
        edges taken from ``bin_edges[i]`` and ``bin_edges[i + 1]``.
    """
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("histogram")
    records = reservoir.get_items(run, decode_tag(tag))

    results = []
    for record in records:
        counts = record.histogram.hist
        edges = record.histogram.bin_edges
        buckets = [[edges[i], edges[i + 1], count]
                   for i, count in enumerate(counts)]
        results.append([record.timestamp, record.id, buckets])

    return results


193 194 195
def get_graph(log_reader):
    """Return the contents of ``log_reader.model`` read via ``bfile.BFile``.

    Returns ``b""`` when ``log_reader.model`` is falsy (no model configured).
    """
    model_path = log_reader.model
    if not model_path:
        return b""
    with bfile.BFile(model_path, 'rb') as bfp:
        return bfp.read_file(model_path)


201 202 203 204 205
def retry(ntimes, function, time2sleep, *args, **kwargs):
    """Call ``function(*args, **kwargs)`` up to ``ntimes`` times.

    Any ``Exception`` raised by ``function`` is logged. If attempts remain,
    the calling thread then sleeps ``time2sleep`` seconds before trying again.

    Returns:
        The return value of the first successful call, or ``None`` if every
        attempt raised (or ``ntimes`` <= 0).
    """
    for attempt in range(ntimes):
        try:
            return function(*args, **kwargs)
        except Exception:
            error_info = '\n'.join(map(str, sys.exc_info()))
            logger.error("Unexpected error: %s", error_info)
            # Sleeping after the final failed attempt only delays the caller.
            if attempt < ntimes - 1:
                time.sleep(time2sleep)
213

T
Thuan Nguyen 已提交
214

215 216 217 218 219 220 221 222 223
def cache_get(cache):
    """Build a read-through accessor over ``cache``.

    The returned ``_handler(key, func, *args, **kwargs)`` returns
    ``cache.get(key)`` whenever that is not ``None``. Otherwise it logs a
    warning, computes ``func(*args, **kwargs)``, stores the result with
    ``cache.set(key, ...)`` and returns it. A stored ``None`` is therefore
    never treated as a hit.
    """
    def _handler(key, func, *args, **kwargs):
        cached = cache.get(key)
        if cached is not None:
            return cached
        logger.warning('update cache %s' % key)
        fresh = func(*args, **kwargs)
        cache.set(key, fresh)
        return fresh

    return _handler
J
Jeff Wang 已提交
226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248


def simple_pca(x, dimension):
    """
    A simple PCA implementation to do the dimension reduction.

    Args:
        x: 2-D array-like, rows are samples and columns are features.
        dimension: Number of principal components to keep.

    Returns:
        numpy.ndarray of shape (n_samples, min(dimension, n_features)): the
        centered data projected onto the axes with the largest variance.

    The input is not modified.
    """
    # Work on a float copy. The previous in-place ``x -= mean`` mutated the
    # caller's array and raised a casting error for integer input.
    x = np.asarray(x, dtype=float)
    x = x - np.mean(x, axis=0)

    # Feature covariance. np.cov returns a 0-d array for a single feature;
    # atleast_2d keeps that case valid for the eigen-solver.
    cov = np.atleast_2d(np.cov(x, rowvar=False))

    # The covariance matrix is symmetric, so use eigh. It guarantees real
    # eigenvalues/eigenvectors, whereas eig may return complex values from
    # round-off that would then leak into the projected output.
    eigvals, eigvecs = np.linalg.eigh(cov)

    # Sort the eigvals from high to low and keep the top ``dimension`` axes.
    order = np.argsort(eigvals)[::-1]
    eigvecs = eigvecs[:, order[:dimension]]

    return np.dot(x, eigvecs)