提交 ec068cb5 编写于 作者: X Xu Jingxin

Support sample by grouped meta key

上级 12db1bad
......@@ -71,7 +71,9 @@ class Buffer:
indices: Optional[List[str]] = None,
replace: bool = False,
sample_range: Optional[slice] = None,
ignore_insufficient: bool = False
ignore_insufficient: bool = False,
groupby: str = None,
rolling_window: int = None
) -> List[BufferedData]:
"""
Overview:
......@@ -83,6 +85,8 @@ class Buffer:
- sample_range (:obj:`slice`): Sample range slice.
- ignore_insufficient (:obj:`bool`): If ignore_insufficient is true, sampling more than buffer size
with no repetition will not cause an exception.
- groupby (:obj:`str`): Groupby key in meta.
- rolling_window (:obj:`int`): Return batches of window size.
Returns:
- sample_data (:obj:`List[BufferedData]`):
A list of data with length ``size``.
......
from typing import Any, Iterable, List, Optional, Union
from collections import defaultdict, deque
from ding.worker.buffer import Buffer, apply_middleware, BufferedData
from ding.worker.buffer.utils import fastcopy
import itertools
......@@ -13,12 +14,20 @@ class DequeBuffer(Buffer):
def __init__(self, size: int) -> None:
super().__init__()
self.storage = deque(maxlen=size)
# Meta index is a dict which use deque as values
self.meta_index = {}
@apply_middleware("push")
def push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
index = uuid.uuid1().hex
if meta is None:
meta = {}
buffered = BufferedData(data=data, index=index, meta=meta)
self.storage.append(buffered)
# Add meta index
for key in self.meta_index:
self.meta_index[key].append(meta[key] if key in meta else None)
return buffered
@apply_middleware("sample")
......@@ -29,25 +38,33 @@ class DequeBuffer(Buffer):
replace: bool = False,
sample_range: Optional[slice] = None,
ignore_insufficient: bool = False,
groupby: str = None,
rolling_window: int = None
) -> List[BufferedData]:
storage = self.storage
if sample_range:
storage = list(itertools.islice(self.storage, sample_range.start, sample_range.stop, sample_range.step))
# Size and indices
assert size or indices, "One of size and indices must not be empty."
if (size and indices) and (size != len(indices)):
raise AssertionError("Size and indices length must be equal.")
if not size:
size = len(indices)
# Indices and groupby
assert not (indices and groupby), "Cannot use groupby and indicex at the same time."
value_error = None
sampled_data = []
if indices:
indices_set = set(indices)
hashed_data = filter(lambda item: item.index in indices_set, self.storage)
hashed_data = filter(lambda item: item.index in indices_set, storage)
hashed_data = map(lambda item: (item.index, item), hashed_data)
hashed_data = dict(hashed_data)
# Re-sample and return in indices order
sampled_data = [hashed_data[index] for index in indices]
elif groupby:
sampled_data = self._sample_by_group(size=size, groupby=groupby, storage=storage, replace=replace)
else:
if replace:
sampled_data = random.choices(storage, k=size)
......@@ -66,7 +83,10 @@ class DequeBuffer(Buffer):
else:
raise ValueError("There are less than {} data in buffer({})".format(size, self.count()))
sampled_data = self._independence(sampled_data)
if groupby:
sampled_data = [self._independence(data) for data in sampled_data]
else:
sampled_data = self._independence(sampled_data)
return sampled_data
......@@ -109,6 +129,34 @@ class DequeBuffer(Buffer):
buffered_samples[i] = fastcopy.copy(buffered)
return buffered_samples
def _sample_by_group(self,
size: int,
groupby: str,
storage: deque,
replace: bool = False) -> List[List[BufferedData]]:
if groupby not in self.meta_index:
self._create_index(groupby)
meta_indices = list(set(self.meta_index[groupby]))
sampled_groups = []
if replace:
sampled_groups = random.choices(meta_indices, k=size)
else:
try:
sampled_groups = random.sample(meta_indices, k=size)
except ValueError as e:
pass
sampled_data = defaultdict(list)
for buffered in storage:
meta_value = buffered.meta[groupby] if groupby in buffered.meta else None
if meta_value in sampled_groups:
sampled_data[buffered.meta[groupby]].append(buffered)
return sampled_data.values()
def _create_index(self, meta_key: str):
self.meta_index[meta_key] = deque(maxlen=self.storage.maxlen)
for data in self.storage:
self.meta_index[meta_key].append(data.meta[meta_key] if meta_key in data.meta else None)
def __iter__(self) -> deque:
return iter(self.storage)
......
......@@ -187,3 +187,30 @@ def test_independence():
assert len(sampled_data) == 2
sampled_data[0].data["key"] = "new"
assert sampled_data[1].data["key"] == "origin"
@pytest.mark.unittest
def test_groupby():
buffer = DequeBuffer(size=3)
buffer.push("a", {"group": 1})
buffer.push("b", {"group": 2})
buffer.push("c", {"group": 2})
sampled_data = buffer.sample(2, groupby="group")
assert len(sampled_data) == 2
group1 = sampled_data[0] if len(sampled_data[0]) == 1 else sampled_data[1]
group2 = sampled_data[0] if len(sampled_data[0]) == 2 else sampled_data[1]
# Group1 should contain a
assert "a" == group1[0].data
# Group2 should contain b and c
data = [buffered.data for buffered in group2] # ["b", "c"]
assert "b" in data
assert "c" in data
# Push new data and swap out a
buffer.push("d", {"group": 2})
sampled_data = buffer.sample(1, groupby="group")
assert len(sampled_data) == 1
assert len(sampled_data[0]) == 3
data = [buffered.data for buffered in sampled_data[0]]
assert "d" in data
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册