lib.py 7.1 KB
Newer Older
J
Jeff Wang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2017 VisualDL Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================

O
Oraoto 已提交
16
from __future__ import absolute_import
Y
Yan Chunwei 已提交
17
import sys
18
import time
S
superjom 已提交
19
import numpy as np
20 21
from visualdl.server.log import logger
from visualdl.utils.string_util import encode_tag, decode_tag
22

S
superjom 已提交
23

24
def get_components(log_reader):
    """
    Return the component names known to `log_reader`, plus 'graph', as a list.

    The reader is asked for a refreshed view (`update=True`). 'graph' is added
    directly to the set object the reader returns, so that set is modified.
    """
    available = log_reader.components(update=True)
    # 'graph' is always reported, whatever the reader returned.
    available.add('graph')
    return [name for name in available]
S
superjom 已提交
28

S
superjom 已提交
29

30 31
def get_runs(log_reader):
    """Return whatever `log_reader.runs()` reports, unchanged."""
    runs = log_reader.runs()
    return runs
32 33


34 35
def get_tags(log_reader):
    """Return whatever `log_reader.tags()` reports, unchanged."""
    tags = log_reader.tags()
    return tags
S
superjom 已提交
36 37


38 39 40 41 42 43 44 45 46
def get_logs(log_reader, component):
    """
    Group the encoded tags stored for `component` by run.

    Each key of the component's reservoir is split at its last '/': the part
    before it is the run, the part after it is the tag, which is passed
    through `encode_tag`.

    :return: dict mapping run -> list of encoded tags, in key order.
    """
    reservoir_keys = log_reader.data_manager.get_reservoir(component).keys
    tags_by_run = {}
    for key in reservoir_keys:
        cut = key.rfind('/')
        run_name = key[:cut]
        encoded = encode_tag(key[cut + 1:])
        # Create the run's list on first sight, then append to it.
        tags_by_run.setdefault(run_name, []).append(encoded)
    return tags_by_run
50 51


52 53
def get_scalar_tags(log_reader):
    """Return the run -> tags mapping of the "scalar" component."""
    scalar_tags = get_logs(log_reader, "scalar")
    return scalar_tags
54 55


56 57 58 59 60 61
def get_scalar(log_reader, run, tag):
    """
    Return the scalar records stored for `run` and `tag`.

    New data is loaded first; `tag` is decoded with `decode_tag` before the
    lookup.

    :return: list of [timestamp, id, value] lists, one per record.
    """
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("scalar")
    points = []
    for record in reservoir.get_items(run, decode_tag(tag)):
        points.append([record.timestamp, record.id, record.value])
    return points
62 63


64 65
def get_image_tags(log_reader):
    """Return the run -> tags mapping of the "image" component."""
    image_tags = get_logs(log_reader, "image")
    return image_tags
66 67


68 69 70 71 72 73 74 75 76
def get_image_tag_steps(log_reader, run, tag):
    """
    List the steps recorded for an image tag.

    New data is loaded first; `tag` is decoded with `decode_tag` before the
    lookup.

    :return: list of {"step": id, "wallTime": timestamp} dicts, one per record.
    """
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("image")
    steps = []
    for record in reservoir.get_items(run, decode_tag(tag)):
        steps.append({"step": record.id, "wallTime": record.timestamp})
    return steps
77 78


79 80 81 82 83
def get_individual_image(log_reader, run, tag, step_index):
    """
    Return the encoded image string of one image record.

    New data is loaded first; `tag` is decoded with `decode_tag`, and
    `step_index` indexes into the records found for `run` and that tag.
    """
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("image")
    records = reservoir.get_items(run, decode_tag(tag))
    selected = records[step_index]
    return selected.image.encoded_image_string
84 85


86 87
def get_audio_tags(log_reader):
    """Return the run -> tags mapping of the "audio" component."""
    audio_tags = get_logs(log_reader, "audio")
    return audio_tags
88 89


90 91 92 93 94 95 96 97 98
def get_audio_tag_steps(log_reader, run, tag):
    """
    List the steps recorded for an audio tag.

    New data is loaded first; `tag` is decoded with `decode_tag` before the
    lookup.

    :return: list of {"step": id, "wallTime": timestamp} dicts, one per record.
    """
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("audio")
    records = reservoir.get_items(run, decode_tag(tag))
    return [dict(step=record.id, wallTime=record.timestamp)
            for record in records]
99 100


101 102 103 104
def get_individual_audio(log_reader, run, tag, step_index):
    """
    Return the encoded audio string of one audio record.

    New data is loaded first; `tag` is decoded with `decode_tag`, and
    `step_index` indexes into the records found for `run` and that tag.
    """
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("audio")
    records = reservoir.get_items(run, decode_tag(tag))
    return records[step_index].audio.encoded_audio_string
107 108


109 110 111 112
def get_embeddings_tags(log_reader):
    """Return the run -> tags mapping of the "embeddings" component."""
    embeddings_tags = get_logs(log_reader, "embeddings")
    return embeddings_tags


113 114 115 116
def get_histogram_tags(log_reader):
    """Return the run -> tags mapping of the "histogram" component."""
    histogram_tags = get_logs(log_reader, "histogram")
    return histogram_tags


走神的阿圆's avatar
走神的阿圆 已提交
117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
def get_pr_curve_tags(log_reader):
    """Return the run -> tags mapping of the "pr_curve" component."""
    pr_curve_tags = get_logs(log_reader, "pr_curve")
    return pr_curve_tags


def get_pr_curve(log_reader, run, tag):
    """
    Return the PR-curve records stored for `run` and `tag`.

    New data is loaded first; `tag` is decoded with `decode_tag` before the
    lookup.

    :return: one list per record, laid out as
        [timestamp, id, precision, recall, TP, FP, TN, FN, thresholds],
        where the curve fields are copied into plain lists and `thresholds`
        holds k / n for k = 1..n, n being the number of precision entries.
    """
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("pr_curve")
    records = reservoir.get_items(run, decode_tag(tag))

    def _as_row(record):
        # Flatten one record into the list layout described above.
        curve = record.pr_curve
        count = len(curve.precision)
        thresholds = [float(k) / count for k in range(1, count + 1)]
        return [
            record.timestamp,
            record.id,
            list(curve.precision),
            list(curve.recall),
            list(curve.TP),
            list(curve.FP),
            list(curve.TN),
            list(curve.FN),
            thresholds,
        ]

    return [_as_row(record) for record in records]


def get_pr_curve_step(log_reader, run, tag=None):
    """
    Return [timestamp, id] for every PR-curve record of `run` and `tag`.

    When `tag` is None, the first pr_curve tag listed for `run` is used.
    New data is loaded before the records are read, and the tag is decoded
    with `decode_tag` before the lookup.
    """
    if tag is None:
        # Fall back to the first tag known for this run.
        tag = get_pr_curve_tags(log_reader)[run][0]
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("pr_curve")
    steps = []
    for record in reservoir.get_items(run, decode_tag(tag)):
        steps.append([record.timestamp, record.id])
    return steps


151
def get_embeddings(log_reader, run, tag, reduction, dimension=2):
    """
    Reduce the embeddings stored for `run` and `tag` to `dimension` dimensions.

    Only the first record found for the run/tag pair is used; its labels and
    vectors are collected and the vectors are reduced with the requested
    algorithm.

    :param log_reader: reader whose data manager holds the "embeddings"
        reservoir.
    :param run: run name to look up.
    :param tag: encoded tag; decoded with `decode_tag` before the lookup.
    :param reduction: 'tsne' or 'pca'.
    :param dimension: number of output dimensions (default 2).
    :return: {"embedding": nested list of reduced vectors, "labels": labels}.
    :raises ValueError: if `reduction` is neither 'tsne' nor 'pca'.
    """
    # Fail fast with a clear message. Previously an unknown reduction went
    # through all the loading work and then crashed with UnboundLocalError
    # because no branch assigned the result.
    if reduction not in ('tsne', 'pca'):
        raise ValueError(
            "Unsupported reduction %r; expected 'tsne' or 'pca'" %
            (reduction, ))

    log_reader.load_new_data()
    records = log_reader.data_manager.get_reservoir("embeddings").get_items(
        run, decode_tag(tag))

    labels = []
    vectors = []
    for item in records[0].embeddings.embeddings:
        labels.append(item.label)
        vectors.append(item.vectors)
    vectors = np.array(vectors)

    if reduction == 'tsne':
        # Imported lazily so the t-SNE module is only loaded when requested.
        import visualdl.server.tsne as tsne
        low_dim_embs = tsne.tsne(
            vectors, dimension, initial_dims=50, perplexity=30.0)
    else:
        low_dim_embs = simple_pca(vectors, dimension)

    return {"embedding": low_dim_embs.tolist(), "labels": labels}
172 173


174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191
def get_histogram(log_reader, run, tag):
    """
    Return the histogram records stored for `run` and `tag`.

    New data is loaded first; `tag` is decoded with `decode_tag` before the
    lookup.

    :return: one [timestamp, id, buckets] list per record, where `buckets`
        is a list of [left_edge, right_edge, count] built from consecutive
        entries of `bin_edges` and the matching entry of `hist`.
    """
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("histogram")
    records = reservoir.get_items(run, decode_tag(tag))

    results = []
    for record in records:
        histogram = record.histogram
        counts = histogram.hist
        edges = histogram.bin_edges
        # Bucket i spans edges[i]..edges[i + 1].
        buckets = [[edges[i], edges[i + 1], count]
                   for i, count in enumerate(counts)]
        results.append([record.timestamp, record.id, buckets])
    return results


192 193 194 195 196 197 198 199
def get_graph(log_reader):
    """
    Return the raw bytes of the file at `log_reader.model`.

    An empty bytes object is returned when `log_reader.model` is falsy.
    """
    if not log_reader.model:
        return b""
    with open(log_reader.model, "rb") as model_file:
        return model_file.read()


200 201 202 203 204
def retry(ntimes, function, time2sleep, *args, **kwargs):
    '''
    Call `function(*args, **kwargs)` up to `ntimes` times and return the
    result of the first call that does not raise.

    Every time an exception is caught it is logged and the thread sleeps
    `time2sleep` seconds before the next attempt. If every attempt raises
    (or `ntimes` is not positive), the function falls through and returns
    None.
    '''
    attempts_left = ntimes
    while attempts_left > 0:
        attempts_left -= 1
        try:
            return function(*args, **kwargs)
        except Exception:
            # Log the (type, value, traceback) triple, one entry per line.
            error_info = '\n'.join(str(part) for part in sys.exc_info())
            logger.error("Unexpected error: %s" % error_info)
            time.sleep(time2sleep)
212

T
Thuan Nguyen 已提交
213

214 215 216 217 218 219 220 221 222
def cache_get(cache):
    """
    Build a lookup helper bound to `cache`.

    The returned callable takes `(key, func, *args, **kwargs)`. It returns
    `cache.get(key)` when that is not None; otherwise it logs a warning,
    computes `func(*args, **kwargs)`, stores the result with
    `cache.set(key, ...)` and returns it.
    """

    def _handler(key, func, *args, **kwargs):
        cached = cache.get(key)
        if cached is not None:
            return cached
        # Cache miss: compute the value and remember it.
        logger.warning('update cache %s' % key)
        fresh = func(*args, **kwargs)
        cache.set(key, fresh)
        return fresh

    return _handler
J
Jeff Wang 已提交
225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247


def simple_pca(x, dimension):
    """
    A simple PCA implementation to do the dimension reduction.

    :param x: 2-D array-like of shape (samples, features). It is not
        modified; the computation runs on a float copy.
    :param dimension: number of principal components to keep.
    :return: array of shape (samples, dimension) holding the centered data
        projected onto the components with the largest eigenvalues. The sign
        of each component is arbitrary.
    """
    # Work on a float copy: centering in place would modify the caller's
    # array and raises a casting error for integer input.
    data = np.array(x, dtype=float)

    # Center the data.
    data -= np.mean(data, axis=0)

    # Compute the covariance matrix. np.cov returns a 0-d array for a single
    # feature, so promote it to 2-D for the eigen-decomposition.
    cov = np.atleast_2d(np.cov(data, rowvar=False))

    # The covariance matrix is symmetric, so use eigh: it always returns
    # real eigenvalues/eigenvectors, whereas eig may return complex ones.
    eigvals, eigvecs = np.linalg.eigh(cov)

    # Sort the eigenvalues from high to low.
    order = np.argsort(eigvals)[::-1]

    # Drop the eigenvectors with low eigenvalues.
    eigvecs = eigvecs[:, order[:dimension]]

    return np.dot(data, eigvecs)