data.collate_fn¶
collate function¶
default_collate¶
- Overview:
Put each data field into a tensor with outer dimension batch size.
- Example:
>>> # a list of B tensors shaped (m, n) --> a tensor shaped (B, m, n)
>>> a = [torch.zeros(2, 3) for _ in range(4)]
>>> default_collate(a).shape
torch.Size([4, 2, 3])
>>>
>>> # a list of B lists, each containing m elements --> a list of m tensors, each shaped (B, )
>>> a = [[0 for __ in range(3)] for _ in range(4)]
>>> default_collate(a)
[tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0])]
>>>
>>> # a list of B dicts whose values are tensors shaped (m, n) -->
>>> # a dict whose values are tensors shaped (B, m, n)
>>> a = [{i: torch.zeros(i, i + 1) for i in range(2, 4)} for _ in range(4)]
>>> print(a[0][2].shape, a[0][3].shape)
torch.Size([2, 3]) torch.Size([3, 4])
>>> b = default_collate(a)
>>> print(b[2].shape, b[3].shape)
torch.Size([4, 2, 3]) torch.Size([4, 3, 4])
- Arguments:
batch (Sequence): a data sequence whose length is the batch size; each element is one piece of data
- Returns:
ret (Union[torch.Tensor, Mapping, Sequence]): the collated data, with the batch size folded into each data field. The return type depends on the original element type and can be torch.Tensor, Mapping, or Sequence.
timestep_collate¶
- Overview:
Put each timestepped data field into a tensor with outer dimension batch size, using default_collate. In short, this process can be represented as: [len=B, ele={dict_key: [len=T, ele=Tensor(any_dims)]}] -> {dict_key: Tensor([T, B, any_dims])}
- Arguments:
batch (List[Dict[str, Any]]): a list of dicts with length B; each element is {some_key: some_seq} ('prev_state' should be a key in the dict), where some_seq is a sequence with length T whose elements are torch.Tensors of any shape.
- Returns:
ret (Dict[str, Union[torch.Tensor, list]]): the collated data, with timestep and batch size folded into each data field. Because default_collate is used, the timestep becomes the first dimension, so the final shape is (T, B, dim1, dim2, ...).
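The shape transformation above can be sketched with a minimal re-implementation. Note this is an illustrative assumption of how the described behavior could work (the name timestep_collate_sketch and its internals are hypothetical), not the library's actual code:

```python
import torch

def timestep_collate_sketch(batch):
    # batch: list of B dicts, each value a T-length sequence of tensors.
    result = {}
    for k in batch[0].keys():
        # Stack each sample's T-length sequence -> (T, any_dims),
        # then stack over the batch -> (B, T, ...), then move T to the front.
        per_sample = [torch.stack(sample[k]) for sample in batch]  # B x (T, ...)
        result[k] = torch.stack(per_sample).transpose(0, 1).contiguous()  # (T, B, ...)
    return result

B, T = 4, 5
batch = [{'obs': [torch.zeros(3, 2) for _ in range(T)]} for _ in range(B)]
out = timestep_collate_sketch(batch)
print(out['obs'].shape)  # torch.Size([5, 4, 3, 2])
```

As described above, the timestep dimension T ends up first and the batch dimension B second.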
diff_shape_collate¶
- Overview:
Similar to default_collate, put each data field into a tensor with outer dimension batch size. The main difference is that diff_shape_collate allows elements in the batch to be None, which is quite common in StarCraft observations.
- Arguments:
batch (Sequence): a data sequence whose length is the batch size; each element is one piece of data
- Returns:
ret (Union[torch.Tensor, Mapping, Sequence]): the collated data, with the batch size folded into each data field. The return type depends on the original element type and can be torch.Tensor, Mapping, or Sequence.
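The None-handling difference can be illustrated with a hypothetical sketch (the function name and logic here are assumptions about the described behavior, not the library's implementation):

```python
import torch

def diff_shape_collate_sketch(batch):
    # Like default_collate, but if any element of a field is None,
    # that field is returned as a plain list instead of a stacked tensor.
    elem = batch[0]
    if any(x is None for x in batch):
        return list(batch)
    if isinstance(elem, torch.Tensor):
        return torch.stack(batch)
    if isinstance(elem, dict):
        return {k: diff_shape_collate_sketch([d[k] for d in batch]) for k in elem}
    return batch

batch = [
    {'x': torch.zeros(2, 3), 'mask': torch.ones(4)},
    {'x': torch.zeros(2, 3), 'mask': None},  # None entry, e.g. a missing observation
]
out = diff_shape_collate_sketch(batch)
print(out['x'].shape)  # torch.Size([2, 2, 3])
print(out['mask'])     # [tensor([1., 1., 1., 1.]), None]
```

Fields without None are stacked normally; the field containing None is passed through as a list so no shape mismatch occurs.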
default_decollate¶
- Overview:
Split the collated data along its batch dimension into batch_size pieces; this is the reverse operation of default_collate.
- Arguments:
batch (Union[torch.Tensor, Sequence, Mapping]): can refer to the Returns of default_collate
ignore (List[str]): a list of names to be ignored; only effective if the input batch is a dict. If a key is in this list, its value stays the same with no decollation.
- Returns:
ret (List[Any]): a list with B elements.
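The reverse operation, including the ignore handling, can be sketched as follows (a minimal assumption-based sketch; the name default_decollate_sketch and the internals are illustrative, not the library's exact code):

```python
import torch

def default_decollate_sketch(batch, ignore=None):
    # Reverse of default_collate: split the batch dimension into B pieces.
    # Keys listed in `ignore` are copied into every element unchanged.
    ignore = ignore or []
    if isinstance(batch, torch.Tensor):
        return list(torch.unbind(batch, dim=0))  # (B, ...) -> B tensors of (...)
    if isinstance(batch, dict):
        fields = {k: v if k in ignore else default_decollate_sketch(v) for k, v in batch.items()}
        B = max(len(v) for k, v in fields.items() if k not in ignore)
        return [
            {k: (v if k in ignore else v[i]) for k, v in fields.items()}
            for i in range(B)
        ]
    raise TypeError(type(batch))

collated = {'obs': torch.zeros(4, 2, 3), 'info': ['meta']}
samples = default_decollate_sketch(collated, ignore=['info'])
print(len(samples), samples[0]['obs'].shape)  # 4 torch.Size([2, 3])
```

Each of the B returned elements holds one sample's slice of every decollated field, while ignored keys keep their collated value.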