未验证 提交 ec6479a2 编写于 作者: 走神的阿圆's avatar 走神的阿圆 提交者: GitHub

Reacquire walks when acquire components info from frontend. (#639)

上级 ad6eff68
...@@ -23,7 +23,7 @@ copyright = '2020, PaddlePaddle' ...@@ -23,7 +23,7 @@ copyright = '2020, PaddlePaddle'
author = 'PaddlePaddle' author = 'PaddlePaddle'
# The full version, including alpha/beta/rc tags # The full version, including alpha/beta/rc tags
release = '2.0.0.beta1' release = '2.0.0.beta4'
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------
......
...@@ -199,9 +199,11 @@ class LogReader(object): ...@@ -199,9 +199,11 @@ class LogReader(object):
self.add_remain() self.add_remain()
return self._tags return self._tags
def components(self): def components(self, update=False):
"""Get components type used by vdl. """Get components type used by vdl.
""" """
if update is True:
self.load_new_data(update=update)
return list(set(self._tags.values())) return list(set(self._tags.values()))
def load_new_data(self, update=True): def load_new_data(self, update=True):
...@@ -210,5 +212,5 @@ class LogReader(object): ...@@ -210,5 +212,5 @@ class LogReader(object):
Make sure all readers for every vdl log file are registered, load all Make sure all readers for every vdl log file are registered, load all
remain data. remain data.
""" """
self.register_readers(update=True) self.register_readers(update=update)
self.add_remain() self.add_remain()
...@@ -22,7 +22,7 @@ from visualdl.utils.string_util import encode_tag, decode_tag ...@@ -22,7 +22,7 @@ from visualdl.utils.string_util import encode_tag, decode_tag
def get_components(log_reader): def get_components(log_reader):
return log_reader.components() return log_reader.components(update=True)
def get_runs(log_reader): def get_runs(log_reader):
......
# Copyright (c) 2017 VisualDL Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================
from __future__ import print_function
import pprint
import unittest
from visualdl import LogReader, LogWriter
from . import lib
from .storage_mock import add_histogram, add_image, add_scalar
_retry_counter = 0
class LibTest(unittest.TestCase):
def setUp(self):
dir = "./tmp/mock"
writer = LogWriter(dir, sync_cycle=30)
add_scalar(writer, "train", "layer/scalar0/min", 1000, 1)
add_scalar(writer, "test", "layer/scalar0/min", 1000, 10)
add_scalar(writer, "valid", "layer/scalar0/min", 1000, 10)
add_scalar(writer, "train", "layer/scalar0/max", 1000, 1)
add_scalar(writer, "test", "layer/scalar0/max", 1000, 10)
add_scalar(writer, "valid", "layer/scalar0/max", 1000, 10)
add_image(writer, "train", "layer/image0", 7, 10, 1)
add_image(writer, "test", "layer/image0", 7, 10, 3)
add_image(writer, "train", "layer/image1", 7, 10, 1, shape=[30, 30, 2])
add_image(writer, "test", "layer/image1", 7, 10, 1, shape=[30, 30, 2])
add_histogram(writer, "train", "layer/histogram0", 100)
add_histogram(writer, "test", "layer/histogram0", 100)
self.reader = LogReader(dir)
def test_retry(self):
ntimes = 7
time2sleep = 1
def func():
global _retry_counter
if _retry_counter < 5:
_retry_counter += 1
raise
return _retry_counter
lib.retry(ntimes, func, time2sleep)
self.assertEqual(_retry_counter, 5)
def test_modes(self):
modes = lib.get_modes(self.reader)
self.assertEqual(
sorted(modes), sorted(["default", "train", "test", "valid"]))
def test_scalar(self):
tags = lib.get_scalar_tags(self.reader)
print('scalar tags:')
pprint.pprint(tags)
self.assertEqual(len(tags), 3)
self.assertEqual(
sorted(tags.keys()), sorted("train test valid".split()))
def test_image(self):
tags = lib.get_image_tags(self.reader)
self.assertEqual(len(tags), 2)
tags = lib.get_image_tag_steps(self.reader, 'train', 'layer/image0/0')
pprint.pprint(tags)
image = lib.get_invididual_image(self.reader, "train",
'layer/image0/0', 2)
print(image)
def test_histogram(self):
tags = lib.get_histogram_tags(self.reader)
self.assertEqual(len(tags), 2)
res = lib.get_histogram(self.reader, "train", "layer/histogram0")
pprint.pprint(res)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2017 VisualDL Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================
import random
import numpy as np
def add_scalar(writer, mode, tag, num_steps, skip):
with writer.mode(mode) as my_writer:
scalar = my_writer.scalar(tag)
for i in range(num_steps):
if i % skip == 0:
scalar.add_record(i, random.random())
def add_image(writer,
mode,
tag,
num_samples,
num_passes,
step_cycle,
shape=[50, 50, 3]):
with writer.mode(mode) as writer_:
image_writer = writer_.image(tag, num_samples, step_cycle)
for pass_ in range(num_passes):
image_writer.start_sampling()
for ins in range(2 * num_samples):
data = np.random.random(shape) * 256
data = np.ndarray.flatten(data)
image_writer.add_sample(shape, list(data))
image_writer.finish_sampling()
def add_histogram(writer, mode, tag, num_buckets):
with writer.mode(mode) as writer:
histogram = writer.histogram(tag, num_buckets)
for i in range(10):
histogram.add_record(i, np.random.normal(
0.1 + i * 0.01, size=1000))
...@@ -13,4 +13,4 @@ ...@@ -13,4 +13,4 @@
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
vdl_version = '2.0.0-beta.3' vdl_version = '2.0.0-beta.4'
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册