未验证 提交 52f26263 编写于 作者: 走神的阿圆's avatar 走神的阿圆 提交者: GitHub

add_maxmin_value (#779)

上级 fcd6d434
......@@ -153,6 +153,25 @@ class Reservoir(object):
with self._mutex:
self._buckets[key].add_item(item)
def _add_scalar_item(self, key, item):
"""Add a new scalar item to reservoir buckets with given tag as key.
If bucket with key has not yet reached full size, each item will be
added.
If bucket with key is full, each item will be added with same
probability.
Add new item to buckets will always valid because self._buckets is a
collection.defaultdict.
Args:
key: Tag of one bucket to add new item.
item: New item to add to bucket.
"""
with self._mutex:
self._buckets[key].add_scalar_item(item)
def add_item(self, run, tag, item):
"""Add a new item to reservoir buckets with given tag as key.
......@@ -166,6 +185,19 @@ class Reservoir(object):
key = run + "/" + tag
self._add_item(key, item)
def add_scalar_item(self, run, tag, item):
"""Add a new scalar item to reservoir buckets with given tag as key.
For usage habits of VisualDL, actually call self._add_items()
Args:
run: Identity of one tablet.
tag: Identity of one record in tablet.
item: New item to add to bucket.
"""
key = run + "/" + tag
self._add_scalar_item(key, item)
def _cut_tail(self, key):
with self._mutex:
self._buckets[key].cut_tail()
......@@ -210,6 +242,9 @@ class _ReservoirBucket(object):
self._mutex = threading.Lock()
self._num_items_index = 0
self.max_scalar = None
self.min_scalar = None
def add_item(self, item):
""" Add an item to bucket, replacing an old item with probability.
......@@ -229,6 +264,39 @@ class _ReservoirBucket(object):
self._items.append(item)
self._num_items_index += 1
def add_scalar_item(self, item):
""" Add an scalar item to bucket, replacing an old item with probability.
Use reservoir sampling to add a new item to sampling bucket,
each item in a steam has same probability stay in the bucket.
Args:
item: The item to add to reservoir bucket.
"""
with self._mutex:
if not self.max_scalar or self.max_scalar.value < item.value:
self.max_scalar = item
if not self.min_scalar or self.min_scalar.value > item.value:
self.min_scalar = item
if len(self._items) < self._max_size or self._max_size == 0:
self._items.append(item)
else:
if item.id == self.min_scalar.id or item.id == self.max_scalar.id:
r = self._random.randint(1, self._max_size - 1)
else:
r = self._random.randint(1, self._num_items_index)
if r < self._max_size:
if self._items[r].id == self.min_scalar.id or self._items[r].id == self.max_scalar.id:
if r - 1 > 0:
r = r - 1
elif r + 1 < self._max_size:
r = r + 1
self._items.pop(r)
self._items.append(item)
self._num_items_index += 1
@property
def items(self):
"""Get self._items
......@@ -325,7 +393,10 @@ class DataManager(object):
item: The item to add to reservoir bucket.
"""
with self._mutex:
self._reservoirs[plugin].add_item(run, tag, item)
if 'scalar' == plugin:
self._reservoirs[plugin].add_scalar_item(run, tag, item)
else:
self._reservoirs[plugin].add_item(run, tag, item)
def get_keys(self):
"""Get all plugin buckets name.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册