提交 3d698d0f 编写于 作者: X Xu Jingxin

Fix sampling with indices: ensure the returned size equals the requested sample size or the size of the given indices.

上级 1572fd3e
......@@ -52,7 +52,7 @@ class Buffer:
self.middleware = []
@abstractmethod
def push(self, data: Any, meta: Optional[dict] = None) -> Any:
def push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
"""
Overview:
Push data and it's meta information in buffer.
......@@ -60,7 +60,7 @@ class Buffer:
- data (:obj:`Any`): The data which will be pushed into buffer.
- meta (:obj:`dict`): Meta information, e.g. priority, count, staleness.
Returns:
- index (:obj:`Any`): The index of pushed data.
- buffered_data (:obj:`BufferedData`): The pushed data.
"""
raise NotImplementedError
......
from typing import Any, Iterable, List, Optional, Union
from collections import deque
from ding.worker.buffer import Buffer, apply_middleware, BufferedData
import itertools
import random
import uuid
import logging
from ding.worker.buffer import Buffer, apply_middleware, BufferedData
class DequeBuffer(Buffer):
......@@ -14,10 +14,11 @@ class DequeBuffer(Buffer):
self.storage = deque(maxlen=size)
@apply_middleware("push")
def push(self, data: Any, meta: Optional[dict] = None) -> str:
def push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
index = uuid.uuid1().hex
self.storage.append(BufferedData(data=data, index=index, meta=meta))
return index
buffered = BufferedData(data=data, index=index, meta=meta)
self.storage.append(buffered)
return buffered
@apply_middleware("sample")
def sample(
......@@ -40,18 +41,12 @@ class DequeBuffer(Buffer):
value_error = None
sampled_data = []
if indices:
sampled_data = list(filter(lambda item: item.index in indices, self.storage))
# for the same indices
if len(indices) != len(set(indices)):
sampled_data_no_same = sampled_data
sampled_data = [sampled_data_no_same[0]]
j = 0
for i in range(1, len(indices)):
if indices[i - 1] == indices[i]:
sampled_data.append(copy.deepcopy(sampled_data_no_same[j]))
else:
sampled_data.append(sampled_data_no_same[j])
j += 1
indices_set = set(indices)
hashed_data = filter(lambda item: item.index in indices_set, self.storage)
hashed_data = map(lambda item: (item.index, item), hashed_data)
hashed_data = dict(hashed_data)
# Re-sample and return in indices order
sampled_data = [hashed_data[index] for index in indices]
else:
if replace:
sampled_data = random.choices(storage, k=size)
......@@ -68,10 +63,7 @@ class DequeBuffer(Buffer):
.format(self.count(), size)
)
else:
if value_error:
raise ValueError("Some errors in sample operation") from value_error
else:
raise ValueError("There are less than {} data in buffer({})".format(size, self.count()))
raise ValueError("There are less than {} data in buffer({})".format(size, self.count()))
return sampled_data
......
from typing import Callable, Any, List
from ding.worker.buffer import BufferedData
import torch
import numpy as np
class FastCopy:
    """
    A recursive, dispatch-based copier, roughly 5x faster than ``copy.deepcopy``
    for plain data. The idea comes from this article:
    https://newbedev.com/what-is-a-fast-pythonic-way-to-deepcopy-just-data-from-a-python-dict-or-list.

    Only types registered in ``self.dispatch`` are copied recursively; any other
    object is returned as-is (shared by reference).

    NOTE: dispatch uses exact ``type(...)`` lookup, so subclasses of the
    registered types are NOT copied — they are returned by reference.
    """

    def __init__(self):
        # Map each copyable type to its dedicated copy routine.
        dispatch = {}
        dispatch[list] = self._copy_list
        dispatch[dict] = self._copy_dict
        dispatch[torch.Tensor] = self._copy_tensor
        dispatch[np.ndarray] = self._copy_ndarray
        dispatch[BufferedData] = self._copy_buffereddata
        self.dispatch = dispatch

    def _copy_list(self, l: List) -> List:
        # BUGFIX: return annotation was `-> dict`; this method returns a list.
        # Shallow-copy the list, then replace registered element types in place.
        ret = l.copy()
        for idx, item in enumerate(ret):
            cp = self.dispatch.get(type(item))
            if cp is not None:
                ret[idx] = cp(item)
        return ret

    def _copy_dict(self, d: dict) -> dict:
        # Shallow-copy the dict, then replace registered value types in place.
        # Keys are never copied.
        ret = d.copy()
        for key, value in ret.items():
            cp = self.dispatch.get(type(value))
            if cp is not None:
                ret[key] = cp(value)
        return ret

    def _copy_tensor(self, t: torch.Tensor) -> torch.Tensor:
        return t.clone()

    def _copy_ndarray(self, a: np.ndarray) -> np.ndarray:
        return np.copy(a)

    def _copy_buffereddata(self, d: BufferedData) -> BufferedData:
        # Copy the payload and meta recursively; the index is an immutable str.
        return BufferedData(data=self.copy(d.data), index=d.index, meta=self.copy(d.meta))

    def copy(self, sth: Any) -> Any:
        """
        Overview:
            Copy an arbitrary object: dispatch on its exact type, falling back
            to returning the object itself when no handler is registered.
        """
        cp = self.dispatch.get(type(sth))
        if cp is None:
            return sth
        else:
            return cp(sth)
from ding.worker.buffer.utils import fastcopy
def clone_object():
......@@ -61,13 +9,12 @@ def clone_object():
try this middleware when you need to keep the object unchanged in buffer, and modify
the object after sampling it (usuallly in multiple threads)
"""
fastcopy = FastCopy()
def push(chain: Callable, data: Any, *args, **kwargs) -> Any:
def push(chain: Callable, data: Any, *args, **kwargs) -> BufferedData:
data = fastcopy.copy(data)
return chain(data, *args, **kwargs)
def sample(chain: Callable, *args, **kwargs) -> List[Any]:
def sample(chain: Callable, *args, **kwargs) -> List[BufferedData]:
data = chain(*args, **kwargs)
return fastcopy.copy(data)
......
from collections import defaultdict
from typing import Callable, Any, List, Dict, Optional
import copy
import numpy as np
from ding.utils import SumSegmentTree, MinSegmentTree
from ding.worker.buffer.buffer import BufferedData
class PriorityExperienceReplay:
......@@ -33,7 +35,7 @@ class PriorityExperienceReplay:
self.delta_anneal = (1 - self.IS_weight_power_factor) / self.IS_weight_anneal_train_iter
self.pivot = 0
def push(self, chain: Callable, data: Any, meta: Optional[dict] = None, *args, **kwargs) -> Any:
def push(self, chain: Callable, data: Any, meta: Optional[dict] = None, *args, **kwargs) -> BufferedData:
if meta is None:
meta = {'priority': self.max_priority}
else:
......@@ -41,12 +43,13 @@ class PriorityExperienceReplay:
meta['priority'] = self.max_priority
meta['priority_idx'] = self.pivot
self._update_tree(meta['priority'], self.pivot)
index = chain(data, meta=meta, *args, **kwargs)
buffered = chain(data, meta=meta, *args, **kwargs)
index = buffered.index
self.buffer_idx[self.pivot] = index
self.pivot = (self.pivot + 1) % self.buffer_size
return index
return buffered
def sample(self, chain: Callable, size: int, *args, **kwargs) -> List[Any]:
def sample(self, chain: Callable, size: int, *args, **kwargs) -> List[BufferedData]:
# Divide [0, 1) into size intervals on average
intervals = np.array([i * 1.0 / size for i in range(size)])
# Uniformly sample within each interval
......@@ -55,7 +58,7 @@ class PriorityExperienceReplay:
mass *= self.sum_tree.reduce()
indices = [self.sum_tree.find_prefixsum_idx(m) for m in mass]
indices = [self.buffer_idx[i] for i in indices]
# sample with indices
# Sample with indices
data = chain(indices=indices, *args, **kwargs)
if self.IS_weight:
# Calculate max weight for normalizing IS
......
from .fast_copy import FastCopy, fastcopy
import torch
import numpy as np
from typing import Any, List
from ding.worker.buffer.buffer import BufferedData
class FastCopy:
    """
    A recursive, dispatch-based copier, roughly 5x faster than ``copy.deepcopy``
    for plain data. The idea comes from this article:
    https://newbedev.com/what-is-a-fast-pythonic-way-to-deepcopy-just-data-from-a-python-dict-or-list.

    Only types registered in ``self.dispatch`` are copied recursively; any other
    object is returned as-is (shared by reference).

    NOTE: dispatch uses exact ``type(...)`` lookup, so subclasses of the
    registered types are NOT copied — they are returned by reference.
    """

    def __init__(self):
        # Map each copyable type to its dedicated copy routine.
        dispatch = {}
        dispatch[list] = self._copy_list
        dispatch[dict] = self._copy_dict
        dispatch[torch.Tensor] = self._copy_tensor
        dispatch[np.ndarray] = self._copy_ndarray
        dispatch[BufferedData] = self._copy_buffereddata
        self.dispatch = dispatch

    def _copy_list(self, l: List) -> List:
        # BUGFIX: return annotation was `-> dict`; this method returns a list.
        # Shallow-copy the list, then replace registered element types in place.
        ret = l.copy()
        for idx, item in enumerate(ret):
            cp = self.dispatch.get(type(item))
            if cp is not None:
                ret[idx] = cp(item)
        return ret

    def _copy_dict(self, d: dict) -> dict:
        # Shallow-copy the dict, then replace registered value types in place.
        # Keys are never copied.
        ret = d.copy()
        for key, value in ret.items():
            cp = self.dispatch.get(type(value))
            if cp is not None:
                ret[key] = cp(value)
        return ret

    def _copy_tensor(self, t: torch.Tensor) -> torch.Tensor:
        return t.clone()

    def _copy_ndarray(self, a: np.ndarray) -> np.ndarray:
        return np.copy(a)

    def _copy_buffereddata(self, d: BufferedData) -> BufferedData:
        # Copy the payload and meta recursively; the index is an immutable str.
        return BufferedData(data=self.copy(d.data), index=d.index, meta=self.copy(d.meta))

    def copy(self, sth: Any) -> Any:
        """
        Overview:
            Copy an arbitrary object: dispatch on its exact type, falling back
            to returning the object itself when no handler is registered.
        """
        cp = self.dispatch.get(type(sth))
        if cp is None:
            return sth
        else:
            return cp(sth)


# Module-level singleton shared by buffer middleware (e.g. clone_object).
fastcopy = FastCopy()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册