提交 3de05532 编写于 作者: N niuyazhe 提交者: Xu Jingxin

add naive use time count middleware in buffer

上级 368f2c26
from .clone_object import clone_object
from .use_time_check import use_time_check
from typing import Callable, Any, List
from collections import deque
def use_time_check(max_use: int = float("inf")) -> Callable:
"""
Overview:
This middleware aims to check the usage times of data in buffer. If the usage times of a data is
greater than max_use, this data will be removed from buffer as soon as possible.
"""
def push(next: Callable, data: Any, *args, **kwargs) -> None:
if 'meta' in kwargs:
kwargs['meta']['use_count'] = 0
else:
kwargs['meta'] = {'use_count': 0}
return next(data, *args, **kwargs)
def sample(next: Callable, *args, **kwargs) -> List[Any]:
kwargs['return_index'] = True
kwargs['return_meta'] = True
data = next(*args, **kwargs)
for i, (d, idx, meta) in enumerate(data):
meta['use_count'] += 1
if meta['use_count'] >= max_use:
print('max_use trigger') # TODO(nyz)
return data
def _immutable_object(action: str, next: Callable, *args, **kwargs) -> Any:
if action == "push":
return push(next, *args, **kwargs)
elif action == "sample":
return sample(next, *args, **kwargs)
return next(*args, **kwargs)
return _immutable_object
import pytest
import torch
from ding.worker.buffer import Buffer, MemoryStorage
from ding.worker.buffer.middleware import clone_object
from ding.worker.buffer import Buffer, DequeStorage
from ding.worker.buffer.middleware import clone_object, use_time_check
@pytest.mark.unittest
def test_clone_object():
buffer = Buffer(MemoryStorage(maxlen=10)).use(clone_object())
buffer = Buffer(DequeStorage(maxlen=10)).use(clone_object())
# Store a dict, a list, a tensor
arr = [{"key": "v1"}, ["a"], torch.Tensor([1, 2, 3])]
......@@ -34,3 +34,22 @@ def test_clone_object():
assert item[0] == 1
else:
raise Exception("Unexpected type")
@pytest.mark.tmp
def test_use_time_check():
def get_data():
return {'obs': torch.randn(4), 'reward': torch.randn(1), 'info': 'xxx'}
N = 6
buffer = Buffer(DequeStorage(maxlen=10)).use(use_time_check(max_use=2))
for _ in range(N):
buffer.push(get_data())
for i in range(2):
data = buffer.sample(size=N, replace=False)
assert len(data) == N
print('sample i')
buffer.sample(size=6, replace=False)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册