提交 9caf963b 编写于 作者: X Xu Jingxin

Padding after sampling from buffer

上级 d053cfcc
......@@ -2,3 +2,4 @@ from .clone_object import clone_object
from .use_time_check import use_time_check
from .staleness_check import staleness_check
from .priority import priority
from .padding import padding
from typing import Callable, Any, List
from typing import Callable, Any, List, Union
from ding.worker.buffer import BufferedData
from ding.worker.buffer.utils import fastcopy
......@@ -14,15 +14,15 @@ def clone_object():
data = fastcopy.copy(data)
return chain(data, *args, **kwargs)
def sample(chain: Callable, *args, **kwargs) -> List[BufferedData]:
def sample(chain: Callable, *args, **kwargs) -> Union[List[BufferedData], List[List[BufferedData]]]:
data = chain(*args, **kwargs)
return fastcopy.copy(data)
def _immutable_object(action: str, chain: Callable, *args, **kwargs):
def _clone_object(action: str, chain: Callable, *args, **kwargs):
if action == "push":
return push(chain, *args, **kwargs)
elif action == "sample":
return sample(chain, *args, **kwargs)
return chain(*args, **kwargs)
return _immutable_object
return _clone_object
import random
from typing import Callable, Union, List
from ding.worker.buffer import BufferedData
from ding.worker.buffer.utils import fastcopy
def padding(method="group"):
"""
Overview:
Fill the nested buffer list to the same size as the largest list.
The default method `group` will randomly select data from each group
and fill it into the current group list.
Arguments:
- method (:obj:`str`): Padding method, currently only supports `group`.
"""
def sample(chain: Callable, *args, **kwargs) -> Union[List[BufferedData], List[List[BufferedData]]]:
sampled_data = chain(*args, **kwargs)
if len(sampled_data) == 0 or isinstance(sampled_data[0], BufferedData):
return sampled_data
if method == "group":
max_len = len(max(sampled_data, key=len))
for i, grouped_data in enumerate(sampled_data):
group_len = len(grouped_data)
if group_len == max_len:
continue
for _ in range(max_len - group_len):
sampled_data[i].append(fastcopy.copy(random.choice(grouped_data)))
return sampled_data
def _padding(action: str, chain: Callable, *args, **kwargs):
if action == "sample":
return sample(chain, *args, **kwargs)
return chain(*args, **kwargs)
return _padding
......@@ -2,6 +2,7 @@ import pytest
import torch
from ding.worker.buffer import DequeBuffer
from ding.worker.buffer.middleware import clone_object, use_time_check, staleness_check, priority
from ding.worker.buffer.middleware.padding import padding
@pytest.mark.unittest
......@@ -106,3 +107,15 @@ def test_priority():
assert buffer.count() == N + N - 1
buffer.clear()
assert buffer.count() == 0
@pytest.mark.unittest
def test_padding():
buffer = DequeBuffer(size=10)
buffer.use(padding(method="group"))
for i in range(10):
buffer.push(i, {"group": i & 5}) # [3,3,2,2]
sampled_data = buffer.sample(4, groupby="group")
assert len(sampled_data) == 4
for grouped_data in sampled_data:
assert len(grouped_data) == 3
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册