提交 12db1bad 编写于 作者: X Xu Jingxin

Make sure each sampled record in the buffer is independent of the others

上级 3d698d0f
from typing import Any, Iterable, List, Optional, Union
from collections import deque
from collections import defaultdict, deque
from ding.worker.buffer import Buffer, apply_middleware, BufferedData
from ding.worker.buffer.utils import fastcopy
import itertools
import random
import uuid
......@@ -65,6 +66,8 @@ class DequeBuffer(Buffer):
else:
raise ValueError("There are less than {} data in buffer({})".format(size, self.count()))
sampled_data = self._independence(sampled_data)
return sampled_data
@apply_middleware("update")
......@@ -93,6 +96,19 @@ class DequeBuffer(Buffer):
def clear(self) -> None:
    """
    Overview:
        Empty the buffer by delegating to ``self.storage.clear()``.
    """
    storage = self.storage
    storage.clear()
def _independence(self, buffered_samples: List[BufferedData]) -> List[BufferedData]:
    """
    Overview:
        Guarantee that no two entries of the sampled list are the same record object.
        The first entry seen for a given ``index`` is kept untouched; every later entry
        carrying the same ``index`` is replaced, in place, by ``fastcopy.copy`` of itself.
        This is not the same as cloning every sample: first occurrences are returned
        uncopied, so mutating one of them may still change the data held in the buffer.
    Arguments:
        - buffered_samples (:obj:`List[BufferedData]`): Sampled records; the list is \
            modified in place.
    Returns:
        - buffered_samples (:obj:`List[BufferedData]`): The same list object that was passed in.
    """
    seen_indices = set()
    for position, record in enumerate(buffered_samples):
        record_index = record.index
        if record_index in seen_indices:
            # Repeated occurrence: swap in a copy so it no longer aliases the earlier entry.
            buffered_samples[position] = fastcopy.copy(record)
        else:
            seen_indices.add(record_index)
    return buffered_samples
def __iter__(self) -> deque:
return iter(self.storage)
......
......@@ -165,3 +165,25 @@ def test_ignore_insufficient():
buffer.sample(3, ignore_insufficient=False)
data = buffer.sample(3, ignore_insufficient=True)
assert len(data) == 0
@pytest.mark.unittest
def test_independence():
    """
    Overview:
        Sampling the same stored record twice must yield two records whose ``data`` can be
        modified separately. Checked for duplicates produced by ``replace=True`` and for
        duplicates produced by passing the same index twice in ``indices``.
    """
    # Case 1: a one-item buffer sampled twice with replacement.
    replace_buffer = DequeBuffer(size=1)
    replace_buffer.push({"key": "origin"})
    drawn = replace_buffer.sample(2, replace=True)
    assert len(drawn) == 2
    drawn[0].data["key"] = "new"
    assert drawn[1].data["key"] == "origin"

    # Case 2: a one-item buffer sampled by explicitly repeating its index.
    index_buffer = DequeBuffer(size=1)
    pushed = index_buffer.push({"key": "origin"})
    repeated_indices = [pushed.index] * 2
    drawn = index_buffer.sample(indices=repeated_indices)
    assert len(drawn) == 2
    drawn[0].data["key"] = "new"
    assert drawn[1].data["key"] == "origin"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册