lib.py 5.3 KB
Newer Older
S
superjom 已提交
1
import pprint
S
superjom 已提交
2
import random
S
superjom 已提交
3
import re
S
superjom 已提交
4
import urllib
S
superjom 已提交
5 6 7 8 9
from tempfile import NamedTemporaryFile

import numpy as np
from PIL import Image

S
superjom 已提交
10

S
superjom 已提交
11 12 13 14
def get_modes(storage):
    """Return the run modes recorded in ``storage``.

    Thin pass-through to ``storage.modes()``.
    """
    available_modes = storage.modes()
    return available_modes


15
def get_tags(storage, component):
    """Collect the tags of ``component`` for every mode in ``storage``.

    Modes whose reader reports no tags for ``component`` are left out.

    Returns:
        dict mapping mode -> {tag: {'displayName': tag, 'description': ""}}.
    """
    tags_by_mode = {}
    for run_mode in storage.modes():
        with storage.mode(run_mode) as reader:
            component_tags = reader.tags(component)
            if not component_tags:
                continue
            tags_by_mode[run_mode] = {
                name: {'displayName': name, 'description': ""}
                for name in component_tags
            }
    return tags_by_mode


30 31 32 33
def get_scalar_tags(storage):
    """Return the per-mode scalar tags of ``storage`` (see ``get_tags``)."""
    scalar_component = 'scalar'
    return get_tags(storage, scalar_component)


34
def get_scalar(storage, mode, tag, num_records=300):
    """Return the records of scalar ``tag`` in ``mode``, down-sampled if long.

    Each record is a ``(timestamp, id, value)`` tuple. When more than
    ``num_records`` records exist, exactly ``num_records`` of them are picked
    at evenly spaced positions, always keeping the first and the last record,
    and returned in their original order.

    Args:
        storage: the log storage to read from.
        mode: the run mode to read.
        tag: the scalar tag to read.
        num_records: maximum number of records to return; must be > 1.

    Returns:
        list of ``(timestamp, id, value)`` tuples.
    """
    assert num_records > 1

    with storage.mode(mode) as reader:
        scalar = reader.scalar(tag)

        records = scalar.records()
        ids = scalar.ids()
        timestamps = scalar.timestamps()

        # list() so len() and indexing also work on Python 3, where zip is lazy.
        data = list(zip(timestamps, ids, records))
        data_size = len(data)

        if data_size <= num_records:
            return data

        # Walk back from the newest record in ``intervals`` evenly spaced
        # steps, then add the oldest record. Integer arithmetic keeps every
        # offset exact; a float step could round the last offset down to
        # ``data_size - 1`` and select data[0] twice.
        intervals = num_records - 1
        sampled = [
            data[data_size - 1 - (k * data_size) // intervals]
            for k in range(intervals)
        ]
        sampled.append(data[0])
        sampled.reverse()
        return sampled
S
superjom 已提交
63 64


S
superjom 已提交
65
def get_image_tags(storage):
    """Collect one caption entry per image sample for every mode.

    A tag with at most one sample is listed under its own name; a tag with
    several samples is expanded into ``'<tag>/<i>'`` captions, one per sample.

    Returns:
        dict mapping mode -> {caption: {'displayName': caption,
        'description': "", 'samples': 1}}. Modes with no image tags are
        omitted.
    """
    result = {}

    for mode in storage.modes():
        with storage.mode(mode) as reader:
            tags = reader.tags('image')
            if not tags:
                continue
            result[mode] = {}
            for tag in tags:
                # The sample count is loop-invariant: query it once per tag.
                num_samples = reader.image(tag).num_samples()
                for i in range(max(1, num_samples)):
                    caption = tag if num_samples <= 1 else '%s/%d' % (tag, i)
                    result[mode][caption] = {
                        'displayName': caption,
                        'description': "",
                        'samples': 1,
                    }
    return result


def get_image_tag_steps(storage, mode, tag):
    """List the recorded steps of image ``tag`` in ``mode``.

    ``tag`` may carry a ``'/<index>'`` suffix (the caption form produced by
    ``get_image_tags``) selecting which sample's record is inspected; without
    it sample 0 is used.

    Returns:
        list of dicts with 'height', 'width', 'step', 'wall_time' and a
        url-encoded 'query'. Records whose shape is empty are skipped.
    """
    try:
        from urllib import urlencode  # Python 2
    except ImportError:
        from urllib.parse import urlencode  # Python 3

    print('image_tag_steps,mode,tag: %s %s' % (mode, tag))

    # Split an optional '/<sample index>' suffix off the tag. The query keeps
    # the original tag so get_invididual_image can re-parse the suffix.
    origin_tag = tag
    sample_index = 0
    match = re.search(r".*/([0-9]+$)", tag)
    if match:
        tag = tag[:tag.rfind('/')]
        sample_index = int(match.groups()[0])

    steps = []
    # ``image`` comes from ``reader``, so iterate while the reader is open.
    with storage.mode(mode) as reader:
        image = reader.image(tag)
        for step_index in range(image.num_records()):
            record = image.record(step_index, sample_index)
            shape = record.shape()
            # TODO(ChunweiYan) remove this trick, some shape will be empty
            if not shape:
                continue
            # NOTE(review): 'sample' is always 0 here; the selected sample
            # travels in the '/<index>' suffix of 'tag' instead.
            query = urlencode({
                'sample': 0,
                'index': step_index,
                'tag': origin_tag,
                'run': mode,
            })
            steps.append({
                'height': shape[0],
                'width': shape[1],
                'step': record.step_id(),
                'wall_time': image.timestamp(step_index),
                'query': query,
            })
    return steps


S
superjom 已提交
122
def get_invididual_image(storage, mode, tag, step_index, max_size=80):
    """Render one image record of ``tag`` in ``mode`` as a PNG temp file.

    ``tag`` may carry a ``'/<index>'`` suffix selecting the sample; without
    it sample 0 is used. If the longer side exceeds ``max_size`` pixels the
    image is scaled down, preserving its aspect ratio, to at most
    ``max_size`` on that side.

    Returns:
        a ``NamedTemporaryFile`` holding the PNG, rewound to offset 0. It is
        created with the default ``delete=True``, so it is removed when the
        caller closes it.
    """
    # Split an optional '/<sample index>' suffix off the tag.
    offset = 0
    match = re.search(r".*/([0-9]+$)", tag)
    if match:
        offset = int(match.groups()[0])
        tag = tag[:tag.rfind('/')]

    with storage.mode(mode) as reader:
        image = reader.image(tag)
        record = image.record(step_index, offset)

        shape = record.shape()
        # Collapse a trailing single channel so the array is 2-D (grayscale).
        if len(shape) == 3 and shape[2] == 1:
            shape = [shape[0], shape[1]]
        data = np.array(record.data(), dtype='uint8').reshape(shape)
        height, width = shape[0], shape[1]

        png_file = NamedTemporaryFile(mode='w+b', suffix='.png')
        try:
            with Image.fromarray(data) as im:
                size = max(height, width)
                if size > max_size:
                    scale = max_size * 1. / size
                    # PIL's resize takes (width, height), i.e. (cols, rows).
                    scaled_size = (max(1, int(width * scale)),
                                   max(1, int(height * scale)))
                    im = im.resize(scaled_size)
                im.save(png_file, format='PNG')
        except Exception:
            # Do not leak the temporary file if decoding or saving fails.
            png_file.close()
            raise
        png_file.seek(0, 0)
        return png_file
S
superjom 已提交
148 149


150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181
def get_histogram_tags(storage):
    """Return the per-mode histogram tags of ``storage`` (see ``get_tags``)."""
    histogram_component = 'histogram'
    return get_tags(storage, histogram_component)


def get_histogram(storage, mode, tag):
    """Return every readable record of histogram ``tag`` in ``mode``.

    Returns:
        list of ``[timestamp, step, buckets]`` where ``buckets`` is a list of
        ``[left, right, frequency]`` triples, one per record instance.
        Records that raise while being fetched are skipped.
    """
    with storage.mode(mode) as reader:
        histogram = reader.histogram(tag)
        res = []

        for i in range(histogram.num_records()):
            try:
                # some bug with protobuf, some times may overflow
                record = histogram.record(i)
            except Exception:
                # Best-effort skip of unreadable records; catching Exception
                # (not everything) lets KeyboardInterrupt/SystemExit through.
                continue

            timestamp = record.timestamp()
            step = record.step()
            buckets = []
            for j in range(record.num_instances()):
                instance = record.instance(j)
                buckets.append(
                    [instance.left(), instance.right(), instance.frequency()])
            res.append([timestamp, step, buckets])
        return res


S
superjom 已提交
182
if __name__ == '__main__':
    # Manual smoke test against a mock log directory.
    # NOTE(review): ``storage`` is not imported in this file, so this demo
    # raises NameError as written. It presumably refers to the module that
    # provides ``LogReader`` -- TODO import it before running.
    reader = storage.LogReader('./tmp/mock')

    image_tags = get_image_tags(reader)
    pprint.pprint(image_tags)

    steps = get_image_tag_steps(reader, 'train', 'layer1/layer2/image0/0')
    pprint.pprint(steps)

    image = get_invididual_image(reader, "train", 'layer1/layer2/image0/0', 2)
    print(image)