提交 d3f1a516 编写于 作者: X Xu Jingxin

Add buffer copy middleware

上级 e9246b74
......@@ -2,8 +2,8 @@ from typing import Any, List
from collections import deque
from operator import itemgetter
from ding.worker.buffer import Storage
import numpy as np
import itertools
import random
class MemoryStorage(Storage):
......@@ -21,7 +21,10 @@ class MemoryStorage(Storage):
storage = self.storage
if range:
storage = list(itertools.islice(self.storage, range.start, range.stop, range.step))
return np.random.choice(storage, size, replace=replace)
if replace:
return random.choices(storage, k=size)
else:
return random.sample(storage, k=size)
def count(self) -> int:
return len(self.storage)
......
from .clone_object import clone_object
from typing import Callable, Any, List
import torch
import numpy as np
class FastCopy:
"""
The idea of this class comes from this article
https://newbedev.com/what-is-a-fast-pythonic-way-to-deepcopy-just-data-from-a-python-dict-or-list.
We use recursive calls to copy each object that needs to be copied, which will be 5x faster
than copy.deepcopy.
"""
def __init__(self):
dispatch = {}
dispatch[list] = self._copy_list
dispatch[dict] = self._copy_dict
dispatch[torch.Tensor] = self._copy_tensor
dispatch[np.ndarray] = self._copy_ndarray
self.dispatch = dispatch
def _copy_list(self, l: List) -> dict:
ret = l.copy()
for idx, item in enumerate(ret):
cp = self.dispatch.get(type(item))
if cp is not None:
ret[idx] = cp(item)
return ret
def _copy_dict(self, d: dict) -> dict:
ret = d.copy()
for key, value in ret.items():
cp = self.dispatch.get(type(value))
if cp is not None:
ret[key] = cp(value)
return ret
def _copy_tensor(self, t: torch.Tensor) -> torch.Tensor:
return t.clone()
def _copy_ndarray(self, a: np.ndarray) -> np.ndarray:
return np.copy(a)
def copy(self, sth: Any) -> Any:
cp = self.dispatch.get(type(sth))
if cp is None:
return sth
else:
return cp(sth)
def clone_object():
"""
This middleware freezes the objects saved in memory buffer as a copy,
try this middleware when you need to keep the object unchanged in buffer, and modify
the object after sampling it (usuallly in multiple threads)
"""
fastcopy = FastCopy()
def push(next: Callable, data: Any, *args, **kwargs) -> None:
data = fastcopy.copy(data)
return next(data, *args, **kwargs)
def sample(next: Callable, *args, **kwargs) -> List[Any]:
data = next(*args, **kwargs)
return fastcopy.copy(data)
def _immutable_object(action: str, next: Callable, *args, **kwargs):
if action == "push":
return push(next, *args, **kwargs)
elif action == "sample":
return sample(next, *args, **kwargs)
return next(*args, **kwargs)
return _immutable_object
import pytest
import torch
from ding.worker.buffer import Buffer, MemoryStorage
from ding.worker.buffer.middlewares import clone_object
@pytest.mark.unittest
def test_clone_object():
buffer = Buffer(MemoryStorage(maxlen=10)).use(clone_object())
# Store a dict, a list, a tensor
arr = [{"key": "v1"}, ["a"], torch.Tensor([1, 2, 3])]
for o in arr:
buffer.push(o)
# Modify it
for item in buffer.sample(len(arr)):
if isinstance(item, dict):
item["key"] = "v2"
elif isinstance(item, list):
item.append("b")
elif isinstance(item, torch.Tensor):
item[0] = 3
else:
raise Exception("Unexpected type")
# Resample it, and check their values
for item in buffer.sample(len(arr)):
if isinstance(item, dict):
assert item["key"] == "v1"
elif isinstance(item, list):
assert len(item) == 1
elif isinstance(item, torch.Tensor):
assert item[0] == 1
else:
raise Exception("Unexpected type")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册