提交 02a0e808 编写于 作者: X Xu Jingxin

Init base buffer and storage

上级 f70d3ddb
......@@ -3,3 +3,4 @@ from .learner import *
from .replay_buffer import *
from .coordinator import *
from .adapter import *
from .buffer import *
from .buffer import Buffer
from .naive_buffer import NaiveBuffer
from .storage import Storage
from .memory_storage import MemoryStorage
from abc import abstractmethod
from typing import Any, List
from ding.worker.buffer.storage import Storage
class Buffer:
def __init__(self, storage: Storage) -> None:
self.storage = storage
@abstractmethod
def push(self, data: Any) -> None:
raise NotImplementedError
@abstractmethod
def sample(self, size: int) -> List[Any]:
raise NotImplementedError
@abstractmethod
def clear(self) -> None:
raise NotImplementedError
from typing import Any, List
from collections import deque
from operator import itemgetter
from ding.worker.buffer import Storage
import random
class MemoryStorage(Storage):
def __init__(self, maxlen: int) -> None:
self.storage = deque(maxlen=maxlen)
def append(self, data: Any) -> None:
self.storage.append(data)
def get(self, indices: List[int]) -> List[Any]:
return itemgetter(*indices)(self.storage)
def sample(self, size: int) -> List[Any]:
return random.sample(self.storage, size)
def count(self) -> int:
return len(self.storage)
def clear(self) -> None:
self.storage.clear()
from typing import Any, List
from ding.worker.buffer import Buffer
from ding.worker.buffer.storage import Storage
class NaiveBuffer(Buffer):
def __init__(self, storage: Storage, **kwargs) -> None:
super().__init__(storage, **kwargs)
def push(self, data: Any) -> None:
self.storage.append(data)
def sample(self, size: int) -> List[Any]:
return self.storage.sample(size)
def clear(self) -> None:
self.storage.clear()
from abc import abstractmethod
from typing import Any, List
class Storage:
@abstractmethod
def append(self, data: Any) -> None:
raise NotImplementedError
@abstractmethod
def get(self, indices: List[int]) -> List[Any]:
raise NotImplementedError
@abstractmethod
def sample(self, size: int) -> List[Any]:
raise NotImplementedError
@abstractmethod
def count(self) -> int:
raise NotImplementedError
@abstractmethod
def clear(self) -> None:
raise NotImplementedError
import pytest
from ding.worker.buffer import NaiveBuffer
from ding.worker.buffer import MemoryStorage
@pytest.mark.unittest
def test_naive_push_sample():
storage = MemoryStorage(maxlen=10)
naive_buffer = NaiveBuffer(storage)
for i in range(20):
naive_buffer.push(i)
assert storage.count() == 10
assert len(set(naive_buffer.sample(10))) == 10
assert 0 not in naive_buffer.sample(10)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册