data.collate_fn

collate function

default_collate

Overview:

Stack each data field into a tensor whose outermost dimension is the batch size.

Example:
>>> # a list with B tensors shaped (m, n) -->> a tensor shaped (B, m, n)
>>> a = [torch.zeros(2,3) for _ in range(4)]
>>> default_collate(a).shape
torch.Size([4, 2, 3])
>>>
>>> # a list with B lists, each list contains m elements -->> a list of m tensors, each with shape (B, )
>>> a = [[0 for __ in range(3)] for _ in range(4)]
>>> default_collate(a)
[tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0])]
>>>
>>> # a list with B dicts, whose values are tensors shaped (m, n) -->>
>>> # a dict whose values are tensors with shape (B, m, n)
>>> a = [{i: torch.zeros(i,i+1) for i in range(2, 4)} for _ in range(4)]
>>> print(a[0][2].shape, a[0][3].shape)
torch.Size([2, 3]) torch.Size([3, 4])
>>> b = default_collate(a)
>>> print(b[2].shape, b[3].shape)
torch.Size([4, 2, 3]) torch.Size([4, 3, 4])

Arguments:
  • batch (Sequence): a data sequence whose length is the batch size; each element is one piece of data.

Returns:
  • ret (Union[torch.Tensor, Mapping, Sequence]): the collated data, with a batch dimension added to each data field. The return type depends on the type of the original elements and can be torch.Tensor, Mapping, or Sequence.
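
The recursive idea behind the three doctest cases above can be sketched as follows. This is a hypothetical, simplified re-implementation (the name collate_sketch is illustrative, not part of this module) that only handles tensors, Python numbers, strings, mappings and sequences; it is not the library's actual code.

```python
from collections.abc import Mapping, Sequence

import torch


def collate_sketch(batch):
    """Hypothetical sketch of recursive collation; not the library's implementation."""
    elem = batch[0]
    if isinstance(elem, torch.Tensor):
        # B tensors of shape (m, n) -> one tensor of shape (B, m, n)
        return torch.stack(list(batch), dim=0)
    if isinstance(elem, bool):  # check bool before int: bool is a subclass of int
        return torch.tensor(batch, dtype=torch.bool)
    if isinstance(elem, int):
        return torch.tensor(batch, dtype=torch.int64)
    if isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float32)
    if isinstance(elem, str):
        # strings cannot be stacked, so they stay a plain list
        return list(batch)
    if isinstance(elem, Mapping):
        # collate each key separately across the batch
        return {key: collate_sketch([d[key] for d in batch]) for key in elem}
    if isinstance(elem, Sequence):
        # transpose: B sequences of length m -> m groups of B items, each collated
        return [collate_sketch(list(group)) for group in zip(*batch)]
    raise TypeError(f"unsupported element type: {type(elem)}")


print(collate_sketch([torch.zeros(2, 3) for _ in range(4)]).shape)  # torch.Size([4, 2, 3])
print(collate_sketch([[0, 0, 0] for _ in range(4)]))
```

The transpose in the Sequence branch is what turns "B lists of m elements" into "m tensors of shape (B,)", matching the second doctest case.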

timestep_collate

Overview:

Collate time-stepped data with default_collate so that each data field becomes a tensor whose first two dimensions are the time step and the batch size. In short, the transformation is: [len=B, ele={dict_key: [len=T, ele=Tensor(any_dims)]}] -> {dict_key: Tensor([T, B, any_dims])}

Arguments:
  • batch (List[Dict[str, Any]]): a list of B dicts, each of the form {some_key: some_seq}, where 'prev_state' must be one of the keys; some_seq is a sequence of length T whose elements are torch.Tensor objects of any shape.

Returns:
  • ret (Dict[str, Union[torch.Tensor, list]]): the collated data, with time-step and batch dimensions added to each data field. Because default_collate transposes each length-T sequence, the time step becomes the first dimension, so the final shape is (T, B, dim1, dim2, ...).
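
The shape transformation [B x {key: T x Tensor}] -> {key: Tensor(T, B, ...)} can be illustrated with the hypothetical function below (timestep_collate_sketch is an illustrative name, not this module's API). The documentation above only says that 'prev_state' must be present; setting it aside as a plain list, as this sketch does, is an assumption.

```python
import torch


def timestep_collate_sketch(batch):
    """Hypothetical sketch of the (T, B, *dims) transformation.

    Input: B dicts, each {key: sequence of T tensors}, including 'prev_state'.
    Output: {key: tensor of shape (T, B, *dims)}; 'prev_state' kept as a plain list
    (an assumption of this sketch, not documented behavior).
    """
    # collect 'prev_state' without mutating the caller's dicts
    prev_state = [sample["prev_state"] for sample in batch]
    out = {}
    for key in batch[0]:
        if key == "prev_state":
            continue
        # zip(*...) transposes B sequences of length T into T groups of B tensors,
        # which is why the time step ends up as the first dimension
        per_step = [torch.stack(step, dim=0) for step in zip(*(s[key] for s in batch))]
        out[key] = torch.stack(per_step, dim=0)  # (T, B, *dims)
    out["prev_state"] = prev_state
    return out


B, T = 4, 5
batch = [{"obs": [torch.zeros(3) for _ in range(T)], "prev_state": None} for _ in range(B)]
print(timestep_collate_sketch(batch)["obs"].shape)  # torch.Size([5, 4, 3])
```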

diff_shape_collate

Overview:

Like default_collate, this stacks each data field into a tensor whose outermost dimension is the batch size. The main difference is that diff_shape_collate allows elements of the batch to be None, which is common in StarCraft observations.

Arguments:
  • batch (Sequence): a data sequence whose length is the batch size; each element is one piece of data.

Returns:
  • ret (Union[torch.Tensor, Mapping, Sequence]): the collated data, with a batch dimension added to each data field. The return type depends on the type of the original elements and can be torch.Tensor, Mapping, or Sequence.
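
One way to tolerate None entries is sketched below with a hypothetical function, diff_shape_collate_sketch. It assumes that any field containing None, or (as the name diff_shape_collate suggests) tensors of differing shapes, is returned as a plain list instead of being stacked; the library's exact rules may differ.

```python
from collections.abc import Mapping, Sequence

import torch


def diff_shape_collate_sketch(batch):
    """Hypothetical sketch: a recursive collate that tolerates None and ragged tensors.

    Assumption: a field with None anywhere in the batch, or with tensors of
    differing shapes, is returned as a plain list rather than stacked.
    """
    if any(elem is None for elem in batch):
        return list(batch)
    elem = batch[0]
    if isinstance(elem, torch.Tensor):
        if len({e.shape for e in batch}) != 1:
            return list(batch)  # differing shapes cannot be stacked
        return torch.stack(list(batch), dim=0)
    if isinstance(elem, (int, float)):
        return torch.tensor(batch)
    if isinstance(elem, Mapping):
        return {key: diff_shape_collate_sketch([d[key] for d in batch]) for key in elem}
    if isinstance(elem, Sequence) and not isinstance(elem, str):
        return [diff_shape_collate_sketch(list(group)) for group in zip(*batch)]
    return list(batch)  # other leaves (e.g. strings) stay a plain list


obs = [{"units": torch.zeros(2, 3), "mask": None}, {"units": torch.ones(2, 3), "mask": torch.ones(5)}]
out = diff_shape_collate_sketch(obs)
print(out["units"].shape)  # torch.Size([2, 2, 3])
print(out["mask"])
```

Returning a list keeps the batch usable downstream while leaving it to the consumer to handle the missing or ragged entries.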

default_decollate

Overview:

Split collated data along its batch dimension into a list of B individual samples; this is the reverse operation of default_collate.

Arguments:
  • batch (Union[torch.Tensor, Sequence, Mapping]): collated data, in the format described under Returns of default_collate.

  • ignore (List[str]): a list of keys to skip; it only takes effect when the input batch is a dict. The value of any key in this list is left unchanged and is not decollated.

Returns:
  • ret (List[Any]): a list with B elements.
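
The reverse operation can be sketched as below. decollate_sketch is a hypothetical name, and the treatment of ignored keys (every sample receives the same unchanged value) is an assumption of this sketch rather than documented behavior.

```python
from collections.abc import Mapping, Sequence

import torch


def decollate_sketch(batch, ignore=()):
    """Hypothetical sketch: split batched data into a list of B samples.

    Assumption: keys in `ignore` are not split; each sample gets the unchanged value.
    """
    if isinstance(batch, torch.Tensor):
        # unbind drops the leading batch dimension: (B, *dims) -> B tensors of shape (*dims)
        return list(torch.unbind(batch, dim=0))
    if isinstance(batch, Mapping):
        split = {k: (v if k in ignore else decollate_sketch(v, ignore)) for k, v in batch.items()}
        # the batch size is taken from the first non-ignored key
        B = next(len(v) for k, v in split.items() if k not in ignore)
        return [{k: (split[k] if k in ignore else split[k][i]) for k in split} for i in range(B)]
    if isinstance(batch, Sequence) and not isinstance(batch, str):
        # each field yields B items; zip regroups them into B per-sample lists
        return [list(sample) for sample in zip(*(decollate_sketch(f, ignore) for f in batch))]
    raise TypeError(f"unsupported batch type: {type(batch)}")


collated = {"obs": torch.zeros(4, 3), "action": [torch.arange(4), torch.ones(4)], "prev_state": "h"}
samples = decollate_sketch(collated, ignore=("prev_state",))
print(len(samples), samples[0]["obs"].shape)  # 4 torch.Size([3])
```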