提交 9c67db8b 编写于 作者: N niuyazhe

polish(nyz): add return index in push and copy same data in sample

上级 f764de31
......@@ -52,13 +52,15 @@ class Buffer:
self.middleware = []
@abstractmethod
def push(self, data: Any, meta: Optional[dict] = None) -> Any:
    """
    Overview:
        Push data and its meta information into the buffer.
    Arguments:
        - data (:obj:`Any`): The data which will be pushed into buffer.
        - meta (:obj:`dict`): Meta information, e.g. priority, count, staleness.
    Returns:
        - index (:obj:`Any`): The index of the pushed data, so callers can \
            reference the stored entry later (e.g. for update or delete).
    """
    raise NotImplementedError
......@@ -68,7 +70,7 @@ class Buffer:
size: Optional[int] = None,
indices: Optional[List[str]] = None,
replace: bool = False,
range: Optional[slice] = None,
sample_range: Optional[slice] = None,
ignore_insufficient: bool = False
) -> List[BufferedData]:
"""
......@@ -78,7 +80,7 @@ class Buffer:
- size (:obj:`Optional[int]`): The number of the data that will be sampled.
- indices (:obj:`Optional[List[str]]`): Sample with multiple indices.
- replace (:obj:`bool`): If use replace is true, you may receive duplicated data from the buffer.
- range (:obj:`slice`): Range slice.
- sample_range (:obj:`slice`): Sample range slice.
- ignore_insufficient (:obj:`bool`): If ignore_insufficient is true, sampling more than buffer size
with no repetition will not cause an exception.
Returns:
......
import copy
import enum
import itertools
import logging
import random
import uuid
from collections import deque
from typing import Any, Iterable, List, Optional, Tuple, Union

from ding.worker.buffer import Buffer, apply_middleware, BufferedData
class DequeBuffer(Buffer):
......@@ -14,9 +16,10 @@ class DequeBuffer(Buffer):
self.storage = deque(maxlen=size)
@apply_middleware("push")
def push(self, data: Any, meta: Optional[dict] = None) -> str:
    """
    Overview:
        Append ``data`` with its ``meta`` to the deque storage and return the
        newly generated index of the stored entry.
    Arguments:
        - data (:obj:`Any`): The data which will be pushed into buffer.
        - meta (:obj:`Optional[dict]`): Meta information, e.g. priority, count, staleness.
    Returns:
        - index (:obj:`str`): A ``uuid1`` hex string uniquely identifying the pushed data.
    """
    # uuid1 (time-based) keeps indices unique across rapid successive pushes.
    index = uuid.uuid1().hex
    self.storage.append(BufferedData(data=data, index=index, meta=meta))
    return index
@apply_middleware("sample")
def sample(
......@@ -24,12 +27,12 @@ class DequeBuffer(Buffer):
size: Optional[int] = None,
indices: Optional[List[str]] = None,
replace: bool = False,
range: Optional[slice] = None,
ignore_insufficient: bool = False
sample_range: Optional[slice] = None,
ignore_insufficient: bool = False,
) -> List[BufferedData]:
storage = self.storage
if range:
storage = list(itertools.islice(self.storage, range.start, range.stop, range.step))
if sample_range:
storage = list(itertools.islice(self.storage, sample_range.start, sample_range.stop, sample_range.step))
assert size or indices, "One of size and indices must not be empty."
if (size and indices) and (size != len(indices)):
raise AssertionError("Size and indices length must be equal.")
......@@ -40,6 +43,17 @@ class DequeBuffer(Buffer):
sampled_data = []
if indices:
sampled_data = list(filter(lambda item: item.index in indices, self.storage))
# for the same indices
if len(indices) != len(set(indices)):
sampled_data_no_same = sampled_data
sampled_data = [sampled_data_no_same[0]]
j = 0
for i in range(1, len(indices)):
if indices[i - 1] == indices[i]:
sampled_data.append(copy.deepcopy(sampled_data_no_same[j]))
else:
sampled_data.append(sampled_data_no_same[j])
j += 1
else:
if replace:
sampled_data = random.choices(storage, k=size)
......@@ -49,8 +63,17 @@ class DequeBuffer(Buffer):
except ValueError as e:
value_error = e
if not ignore_insufficient and (value_error or len(sampled_data) != size):
raise ValueError("There are less than {} data in buffer".format(size))
if value_error or len(sampled_data) != size:
if ignore_insufficient:
logging.warning(
"Sample operation is ignored due to data insufficient, current buffer count is {} while sample size is {}"
.format(self.count(), size)
)
else:
if value_error:
raise ValueError("Some errors in sample operation") from value_error
else:
raise ValueError("There are less than {} data in buffer({})".format(size, self.count()))
return sampled_data
......
......@@ -63,7 +63,7 @@ def clone_object():
"""
fastcopy = FastCopy()
def push(chain: Callable, data: Any, *args, **kwargs) -> Any:
    # Deep-copy the payload before it enters the buffer so that later
    # in-place mutation by the caller cannot corrupt the stored data,
    # then forward the index returned by the rest of the middleware chain.
    data = fastcopy.copy(data)
    return chain(data, *args, **kwargs)
......
......@@ -16,6 +16,7 @@ class PriorityExperienceReplay:
IS_weight_anneal_train_iter: int = int(1e5)
) -> None:
self.buffer = buffer
self.buffer_idx = {}
self.buffer_size = buffer_size
self.IS_weight = IS_weight
self.priority_power_factor = priority_power_factor
......@@ -32,7 +33,7 @@ class PriorityExperienceReplay:
self.delta_anneal = (1 - self.IS_weight_power_factor) / self.IS_weight_anneal_train_iter
self.pivot = 0
def push(self, chain: Callable, data: Any, meta: Optional[dict] = None, *args, **kwargs) -> None:
def push(self, chain: Callable, data: Any, meta: Optional[dict] = None, *args, **kwargs) -> Any:
if meta is None:
meta = {'priority': self.max_priority}
else:
......@@ -40,8 +41,10 @@ class PriorityExperienceReplay:
meta['priority'] = self.max_priority
meta['priority_idx'] = self.pivot
self._update_tree(meta['priority'], self.pivot)
index = chain(data, meta=meta, *args, **kwargs)
self.buffer_idx[self.pivot] = index
self.pivot = (self.pivot + 1) % self.buffer_size
return chain(data, meta=meta, *args, **kwargs)
return index
def sample(self, chain: Callable, size: int, *args, **kwargs) -> List[Any]:
# Divide [0, 1) into size intervals on average
......@@ -51,8 +54,9 @@ class PriorityExperienceReplay:
# Rescale to [0, S), where S is the sum of all datas' priority (root value of sum tree)
mass *= self.sum_tree.reduce()
indices = [self.sum_tree.find_prefixsum_idx(m) for m in mass]
# TODO sample with indices
data = chain(size, *args, **kwargs)
indices = [self.buffer_idx[i] for i in indices]
# sample with indices
data = chain(indices=indices, *args, **kwargs)
if self.IS_weight:
# Calculate max weight for normalizing IS
sum_tree_root = self.sum_tree.reduce()
......@@ -83,6 +87,7 @@ class PriorityExperienceReplay:
priority_idx = meta['priority_idx']
self.sum_tree[priority_idx] = self.sum_tree.neutral_element
self.min_tree[priority_idx] = self.min_tree.neutral_element
self.buffer_idx.pop(priority_idx)
return chain(index, *args, **kwargs)
def clear(self, chain: Callable) -> None:
......@@ -91,6 +96,7 @@ class PriorityExperienceReplay:
self.sum_tree = SumSegmentTree(capacity)
if self.IS_weight:
self.min_tree = MinSegmentTree(capacity)
self.buffer_idx = {}
self.pivot = 0
chain()
......@@ -105,14 +111,15 @@ class PriorityExperienceReplay:
'IS_weight_power_factor': self.IS_weight_power_factor,
'sumtree': self.sumtree,
'mintree': self.mintree,
'buffer_idx': self.buffer_idx,
}
def load_state_dict(self, _state_dict: Dict, deepcopy: bool = False) -> None:
    """
    Overview:
        Restore middleware state by assigning every entry of ``_state_dict``
        as an attribute of ``self`` under its own key.
    Arguments:
        - _state_dict (:obj:`Dict`): Mapping from attribute name to value \
            (the counterpart of ``state_dict``).
        - deepcopy (:obj:`bool`): Whether to deep-copy each value so the \
            restored state does not alias objects in the given dict.
    """
    for k, v in _state_dict.items():
        # ``'{}'.format(k)`` in the original is an identity transform; use ``k`` directly.
        value = copy.deepcopy(v) if deepcopy else v
        setattr(self, k, value)
def priority(*per_args, **per_kwargs):
......
......@@ -9,7 +9,7 @@ def staleness_check(buffer_: 'Buffer', max_staleness: int = float("inf")) -> Cal
If data's staleness is greater(>) than max_staleness, this data will be removed from buffer as soon as possible.
"""
def push(next: Callable, data: Any, *args, **kwargs) -> Any:
    # NOTE(review): parameter ``next`` shadows the builtin; kept as-is for
    # interface compatibility with existing middleware chains.
    # Every push through this middleware must carry the collect-iteration
    # stamp so staleness can be computed at sample time.
    assert 'meta' in kwargs and 'train_iter_data_collected' in kwargs[
        'meta'], "staleness_check middleware must push data with meta={'train_iter_data_collected': <iter>}"
    return next(data, *args, **kwargs)
......
......@@ -8,7 +8,7 @@ def use_time_check(buffer_: 'Buffer', max_use: int = float("inf")) -> Callable:
greater than or equal to max_use, this data will be removed from buffer as soon as possible.
"""
def push(chain: Callable, data: Any, meta: dict = None, *args, **kwargs) -> None:
def push(chain: Callable, data: Any, meta: dict = None, *args, **kwargs) -> Any:
if meta:
meta["use_count"] = 0
else:
......
......@@ -76,8 +76,8 @@ def test_naive_push_sample():
buffer.clear()
for i in range(10):
buffer.push(i)
assert len(buffer.sample(5, range=slice(5, 10))) == 5
assert 0 not in [item.data for item in buffer.sample(5, range=slice(5, 10))]
assert len(buffer.sample(5, sample_range=slice(5, 10))) == 5
assert 0 not in [item.data for item in buffer.sample(5, sample_range=slice(5, 10))]
@pytest.mark.unittest
......@@ -162,6 +162,6 @@ def test_ignore_insufficient():
buffer.push(i)
with pytest.raises(ValueError):
buffer.sample(3)
buffer.sample(3, ignore_insufficient=False)
data = buffer.sample(3, ignore_insufficient=True)
assert len(data) == 0
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册