Stack Structured Data

When tensors are organized into tree structures, they often need to be stacked together, just as `torch.stack()` does for plain tensors in PyTorch.

Stack With Native PyTorch API

Here is a typical implementation using the native PyTorch API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch

# Batch size: the number of samples generated and stacked in this example.
B = 4


def get_item():
    """Create one random sample as a nested dict.

    The sample contains:
      * ``obs``: a dict holding a 12-element ``scalar`` vector and a
        ``(3, 32, 32)`` ``image``, both drawn from a standard normal;
      * ``action``: a 1-element integer tensor with a value in ``[0, 10)``;
      * ``reward``: a 1-element tensor drawn uniformly from ``[0, 1)``;
      * ``done``: the plain Python boolean ``False``.
    """
    observation = {
        'scalar': torch.randn(12),
        'image': torch.randn(3, 32, 32),
    }
    # Fill keys one by one; insertion order matches the original literal.
    sample = {'obs': observation}
    sample['action'] = torch.randint(0, 10, size=(1,))
    sample['reward'] = torch.rand(1)
    sample['done'] = False
    return sample


# Generate B independent random samples that all share one nested structure.
data = [get_item() for _sample_idx in range(B)]


# execute `stack` op
def stack(data, dim):
    """Recursively stack a list of structurally identical samples.

    The structure of ``data[0]`` drives the recursion:
      * ``torch.Tensor`` leaves are combined with ``torch.stack`` along ``dim``;
      * ``dict`` nodes are stacked key by key, using the keys of ``data[0]``;
      * plain ``bool`` leaves are collected into a 1-D ``torch.bool`` tensor.

    Args:
        data: non-empty list of samples sharing the structure of ``data[0]``.
        dim: dimension along which tensor leaves are stacked.

    Returns:
        A tensor, or a (nested) dict of tensors mirroring ``data[0]``.

    Raises:
        ValueError: if ``data`` is empty.
        TypeError: if a leaf is not a tensor, dict or bool.
    """
    if not data:
        # Fail with a clear message instead of an opaque IndexError on data[0].
        raise ValueError("cannot stack an empty list")
    elem = data[0]
    if isinstance(elem, torch.Tensor):
        return torch.stack(data, dim)
    elif isinstance(elem, dict):
        return {k: stack([item[k] for item in data], dim) for k in elem}
    elif isinstance(elem, bool):
        # torch.tensor with an explicit dtype replaces the legacy
        # torch.BoolTensor constructor; the result is identical.
        return torch.tensor(data, dtype=torch.bool)
    else:
        raise TypeError("not support elem type: {}".format(type(elem)))


stacked_data = stack(data, dim=0)
# validate: every leaf gains a leading batch dimension of size B
print(stacked_data)
assert stacked_data['obs']['scalar'].shape == (B, 12)
assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
# the plain-bool `done` flags become a 1-D bool tensor
assert stacked_data['done'].shape == (B,)
assert stacked_data['done'].dtype == torch.bool

The output should look like the following (the values are random and will differ between runs), and all of the assertions pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
{'obs': {'scalar': tensor([[-1.9811, -0.6310,  0.2687,  0.1140,  1.3347,  0.8463,  0.5271,  0.1108,
          1.2042,  1.3970, -0.0295,  1.0527],
        [ 0.4061,  0.3583,  0.3792, -0.5670,  1.1174,  1.3263, -0.8318, -0.0783,
         -0.0136, -0.7235,  1.1906, -0.2105],
        [ 0.9149, -0.6170,  0.5386, -0.0205, -0.2088,  0.6321,  0.0632,  1.6475,
          1.1266, -1.4222, -0.0841, -0.7570],
        [ 1.8852,  1.0335,  0.5187, -0.3776, -1.5329,  0.2858,  0.0705,  0.9547,
          1.2295, -0.5275,  0.7283, -0.7945]]), 'image': tensor([[[[-1.0217e+00, -7.8143e-01, -4.3289e-01,  ..., -4.9537e-02,
           -1.7112e+00,  3.9130e-02],
          [-5.6892e-01, -4.2825e-01, -1.6657e-01,  ..., -9.4172e-01,
           -7.4692e-01,  8.9978e-01],
          [ 1.8623e-01,  1.8211e-01, -1.0610e+00,  ...,  1.4546e-01,
           -1.6602e-01,  8.1264e-01],
          ...,
          [-5.3276e-01,  2.7405e-01,  3.1004e-01,  ..., -4.1659e-01,
           -1.7300e-01, -6.5661e-01],
          [-1.7153e-01,  2.1620e-01,  7.9388e-01,  ..., -1.4514e-01,
           -1.5142e+00, -1.4285e-01],
          [-8.7724e-01, -3.4188e-01, -9.0337e-01,  ..., -1.2792e+00,
            1.4045e+00, -6.9548e-01]],

         [[-1.6186e-01, -1.0006e+00,  1.2906e-01,  ...,  9.9842e-01,
            1.2895e+00,  1.5779e+00],
          [-1.3438e+00, -1.6090e+00,  1.2038e+00,  ..., -1.2569e+00,
           -7.0812e-01,  2.1606e-01],
          [ 1.3344e+00,  6.9511e-01, -7.9401e-01,  ..., -1.6805e-01,
            9.0465e-01, -6.9544e-01],
          ...,
          [-1.4468e-02,  3.4023e+00, -5.0362e-01,  ...,  1.1120e+00,
            4.2084e-01, -5.4911e-01],
          [-1.7984e+00, -1.1071e-01, -2.1617e-02,  ..., -1.6131e-01,
           -4.1380e-01, -1.1293e+00],
          [ 1.2276e-01,  6.6592e-01,  7.7478e-01,  ..., -1.5599e+00,
            7.5058e-01,  5.0642e-02]],

         [[-7.4318e-01,  2.7006e+00,  7.4177e-01,  ..., -1.7956e-01,
           -8.8795e-01, -4.8795e-01],
          [-5.2797e-01,  1.5856e+00,  5.3728e-01,  ..., -1.7127e-01,
           -1.7857e-02,  2.5123e+00],
          [ 1.4399e+00,  1.0639e+00,  1.4504e+00,  ..., -3.7635e-01,
            1.8044e-01,  6.8953e-01],
          ...,
          [ 8.5867e-01, -6.1371e-01, -5.6900e-01,  ..., -2.7395e-01,
           -7.6884e-01,  5.3553e-01],
          [ 9.1688e-01,  9.7921e-01,  4.4553e-01,  ..., -2.9864e-01,
            5.2475e-01, -3.0736e-01],
          [-5.1459e-01, -5.0015e-01, -6.2014e-01,  ..., -8.9271e-01,
            6.9432e-01, -5.6116e-01]]],


        [[[-5.0897e-02,  8.2480e-01,  4.3396e-01,  ...,  1.1813e-01,
            1.7422e+00,  1.9401e+00],
          [-7.1134e-01, -9.0490e-01, -2.7712e-01,  ...,  7.6666e-01,
            4.9568e-01, -1.6780e-01],
          [-1.2148e+00,  2.0147e+00,  1.0314e+00,  ..., -1.5945e+00,
            4.4528e-01, -9.9011e-01],
          ...,
          [-4.1808e-01,  2.3181e+00,  4.3920e-01,  ..., -4.8859e-01,
            1.2778e+00, -5.4779e-01],
          [-2.9996e-01,  6.1118e-01,  1.1470e+00,  ...,  4.4889e-01,
            1.2352e+00,  9.0062e-01],
          [-3.2604e-02, -8.1552e-01,  5.1272e-01,  ...,  9.1129e-01,
           -4.6579e-01,  5.2802e-02]],

         [[-2.8821e-01,  9.9290e-02, -7.1423e-02,  ...,  2.0243e+00,
            9.4038e-01, -1.5481e+00],
          [-4.7209e-01, -6.2841e-01,  9.2118e-02,  ...,  1.9647e+00,
            6.7564e-01, -1.3165e+00],
          [ 1.2051e+00, -5.3143e-01, -9.7679e-03,  ...,  1.5663e+00,
            2.5792e-01,  4.2365e-02],
          ...,
          [-1.6373e-01,  1.5249e+00, -5.5395e-01,  ..., -5.2294e-01,
            7.5198e-01,  6.1379e-01],
          [-2.0124e+00, -1.5621e+00,  9.0538e-01,  ..., -2.6934e-02,
           -1.3232e+00, -2.5675e+00],
          [ 1.4431e-03, -9.7911e-01, -2.5921e-01,  ...,  6.9748e-01,
           -8.3841e-01,  7.7708e-01]],

         [[ 1.0238e+00, -9.5894e-02,  3.3228e-01,  ...,  1.6909e+00,
           -7.9180e-01,  9.3506e-02],
          [ 5.6678e-01,  6.8133e-01, -9.3976e-01,  ..., -7.3359e-02,
           -2.8835e-02,  2.8456e+00],
          [ 3.0378e-01,  1.6101e+00,  9.2667e-01,  ...,  1.0352e+00,
           -3.3934e-01, -4.2686e-01],
          ...,
          [ 1.9312e+00, -7.7472e-01, -6.5101e-01,  ..., -5.5219e-01,
            1.2199e+00,  2.7379e+00],
          [ 2.2563e-01,  5.1150e-01,  1.0954e+00,  ...,  1.3035e+00,
           -1.4855e-02, -1.6582e+00],
          [-5.8985e-01,  4.0435e-01,  7.5526e-01,  ..., -1.3463e-01,
            5.7227e-01,  1.2230e+00]]],


        [[[ 2.0728e+00, -1.1643e+00,  1.6030e+00,  ..., -1.3063e+00,
           -2.0675e+00, -1.0917e+00],
          [ 1.2941e+00, -1.6117e+00,  1.7384e+00,  ...,  4.3916e-01,
           -1.2751e+00, -5.7244e-01],
          [ 3.9854e-01, -8.7394e-01,  2.0281e+00,  ...,  3.8379e-02,
           -6.1575e-01,  5.4254e-01],
          ...,
          [-4.5115e-01, -3.6644e-01, -9.4586e-01,  ..., -1.2554e-01,
            2.8711e-01,  3.1956e-01],
          [-9.4513e-01,  5.3026e-01,  7.1446e-01,  ...,  2.9875e+00,
           -8.9192e-01,  8.4269e-01],
          [ 1.1424e+00,  5.4257e-01,  3.2919e-02,  ..., -1.7755e-01,
            2.2702e-01,  6.6473e-01]],

         [[-1.0433e+00,  8.8607e-01,  1.3225e+00,  ..., -5.8009e-01,
            8.2273e-02,  8.8970e-02],
          [ 5.2234e-01, -4.7690e-01,  1.2335e-01,  ..., -6.9824e-01,
           -1.2020e+00, -1.1763e+00],
          [-1.2462e+00, -1.1291e-02, -4.2222e-02,  ...,  3.3486e-01,
            6.2738e-02,  2.0671e+00],
          ...,
          [-6.5995e-01,  1.1704e+00,  9.1433e-01,  ...,  2.5771e-02,
            1.4814e+00,  1.2737e-01],
          [ 3.4179e-01,  1.4259e+00,  2.5998e-01,  ..., -8.1194e-01,
            1.0044e+00, -4.6482e-01],
          [ 1.0553e+00, -3.0865e-01, -5.2562e-01,  ..., -3.1224e-01,
           -1.4488e+00,  3.5115e-01]],

         [[-6.5305e-01,  3.2041e+00, -8.7197e-01,  ...,  8.3033e-01,
            4.8570e-01,  1.7511e-01],
          [-2.5058e-01,  1.7956e+00, -4.1616e-01,  ...,  3.6998e-02,
           -1.6247e-01, -1.0354e+00],
          [ 1.0436e+00,  1.7690e-01, -5.7779e-01,  ..., -1.2321e+00,
            8.5321e-01,  1.5080e+00],
          ...,
          [ 1.5087e+00,  5.7879e-02, -1.6408e+00,  ...,  4.7006e-01,
           -1.7988e-01, -9.3684e-01],
          [ 1.5043e+00,  9.5380e-01, -6.9563e-01,  ...,  8.1419e-01,
           -9.7050e-01, -1.5380e-01],
          [-5.3725e-01, -1.7450e-01,  1.1852e+00,  ...,  1.7710e-01,
           -6.3294e-01,  1.6872e-01]]],


        [[[ 1.9603e-01, -1.6891e+00,  8.0873e-01,  ..., -4.4775e-01,
           -2.1522e-01, -1.1590e-01],
          [-8.3213e-01, -2.9785e-02,  4.2612e-01,  ...,  3.4570e-01,
           -1.9253e+00,  1.4409e+00],
          [-5.4972e-01,  5.6047e-01, -5.7543e-01,  ...,  1.1894e+00,
           -3.8629e-01,  6.9385e-02],
          ...,
          [ 5.0974e-01, -1.1303e+00,  9.5037e-01,  ...,  2.6223e-02,
           -7.7417e-01, -1.6434e-01],
          [ 1.0510e+00, -1.0833e-01, -2.6753e+00,  ...,  1.2547e+00,
            4.2441e-01, -1.6332e+00],
          [-8.3794e-01,  2.2824e+00, -1.0109e+00,  ..., -1.7651e+00,
           -1.0836e+00,  2.7319e-01]],

         [[-2.4120e+00, -4.3214e-01, -6.3075e-01,  ...,  2.8971e-01,
           -5.6491e-02, -1.0041e+00],
          [ 8.2335e-02, -3.8088e-02, -5.1951e-01,  ...,  4.5253e-01,
            4.6697e-01, -2.0568e+00],
          [-5.8801e-01, -1.3690e+00,  8.1746e-01,  ..., -4.3432e-01,
           -5.1702e-01, -1.7009e+00],
          ...,
          [-3.7792e-01, -3.5522e-03,  2.4129e+00,  ..., -6.4406e-01,
            1.3251e+00, -9.4023e-02],
          [-8.4223e-02,  5.7053e-02, -1.2574e-01,  ...,  6.5649e-01,
           -1.4324e+00,  1.6781e+00],
          [-1.7209e-01, -9.8804e-02, -3.6555e-01,  ..., -1.0103e+00,
            1.9645e+00,  3.9248e-01]],

         [[-2.2872e+00, -5.5936e-01,  2.5653e-01,  ...,  5.0004e-01,
            1.0525e+00, -1.7588e-01],
          [ 1.1948e-01, -2.0766e+00, -4.7760e-01,  ...,  1.0470e+00,
            7.1625e-01, -1.6456e-01],
          [-2.1393e+00,  3.4796e-01, -4.4012e-01,  ..., -6.8943e-01,
           -1.5674e-01, -5.8238e-01],
          ...,
          [-1.4445e+00,  2.2739e+00,  5.8728e-02,  ..., -8.4974e-01,
           -9.6034e-01,  6.3553e-01],
          [ 9.1072e-01,  1.0182e+00, -5.1382e-03,  ...,  3.3224e-02,
           -2.5886e-02, -1.9414e+00],
          [-2.1205e-01, -6.2315e-01, -1.1816e+00,  ...,  2.4435e-01,
            1.0656e+00,  5.7412e-01]]]])}, 'action': tensor([[1],
        [2],
        [5],
        [8]]), 'reward': tensor([[0.4963],
        [0.6468],
        [0.5695],
        [0.7422]]), 'done': tensor([False, False, False, False])}

As shown above, the structure-processing function (here, `stack`) must be implemented entirely by hand. Such code is hard to read, and because every supported type is hard-coded, supporting an additional data type (such as integers) requires special modifications to the function itself.

Stack With TreeTensor API

The same workflow can be implemented with the treetensor API, as shown below.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

import treetensor.torch as ttorch

# Batch size: the number of samples generated and stacked in this example.
B = 4


def get_item():
    """Create one random sample as a nested dict.

    The sample contains:
      * ``obs``: a dict holding a 12-element ``scalar`` vector and a
        ``(3, 32, 32)`` ``image``, both drawn from a standard normal;
      * ``action``: a 1-element integer tensor with a value in ``[0, 10)``;
      * ``reward``: a 1-element tensor drawn uniformly from ``[0, 1)``;
      * ``done``: the plain Python boolean ``False``.
    """
    observation = {
        'scalar': torch.randn(12),
        'image': torch.randn(3, 32, 32),
    }
    # Fill keys one by one; insertion order matches the original literal.
    sample = {'obs': observation}
    sample['action'] = torch.randint(0, 10, size=(1,))
    sample['reward'] = torch.rand(1)
    sample['done'] = False
    return sample


data = [get_item() for _ in range(B)]
# execute `stack` op
# wrap each nested dict as a treetensor `Tensor` so that a single
# `ttorch.stack` call stacks the matching leaves of all samples
data = [ttorch.tensor(d) for d in data]
stacked_data = ttorch.stack(data, dim=0)
# validate: every leaf gains a leading batch dimension of size B
print(stacked_data)
assert stacked_data.obs.scalar.shape == (B, 12)
assert stacked_data.obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
assert stacked_data.done.shape == (B,)
assert stacked_data.done.dtype == torch.bool

The output should look like the following (again, the values are random and will differ between runs), and all of the assertions pass as well.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
<Tensor 0x7f292b7eb210>
├── 'action' --> tensor([[8],
│                        [3],
│                        [3],
│                        [4]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7f292b8b6490>
│   ├── 'image' --> tensor([[[[-0.4818,  0.0614, -0.2000,  ...,  0.8999, -1.2232,  1.6721],
│   │                         [ 1.9980, -0.8405,  0.6540,  ...,  1.7764, -0.0639, -0.2902],
│   │                         [ 0.1980,  1.1496, -1.2959,  ...,  0.7035, -1.3997, -0.9859],
│   │                         ...,
│   │                         [ 0.2021,  0.2702, -0.3861,  ..., -0.4255,  1.2922, -0.4338],
│   │                         [-1.7705, -1.8797, -0.0472,  ..., -0.5711,  0.6340, -1.9222],
│   │                         [ 0.5078, -0.7376, -0.1259,  ..., -1.2985,  0.7058, -1.7158]],
│   │               
│   │                        [[ 1.4858,  1.4058, -1.6203,  ..., -2.0498,  0.1949,  1.2816],
│   │                         [-0.2632, -0.5347,  0.9677,  ...,  1.0978,  0.5936, -0.4696],
│   │                         [ 0.7017,  0.0058, -0.3265,  ...,  1.1440,  0.7496,  0.7700],
│   │                         ...,
│   │                         [ 0.7236,  2.2621,  1.6976,  ..., -0.2528,  0.0922,  0.0625],
│   │                         [ 0.1047, -1.5386,  0.1846,  ...,  0.3682,  0.9567, -2.1020],
│   │                         [-0.6356, -0.1127, -0.5688,  ...,  0.0242,  0.7062, -0.0134]],
│   │               
│   │                        [[-0.7178, -0.8888, -0.5877,  ...,  1.8334, -0.8736, -0.0279],
│   │                         [ 0.8544,  1.5760, -0.2241,  ..., -0.0134, -0.7586, -0.2082],
│   │                         [-0.7742,  0.0043,  0.5172,  ..., -2.7700,  0.7682, -0.2444],
│   │                         ...,
│   │                         [-1.0856, -0.0300, -0.2643,  ..., -0.2202, -1.0572,  1.3172],
│   │                         [-0.9744, -1.1653,  1.9260,  ..., -0.1276,  1.0593, -0.9854],
│   │                         [ 2.9881,  0.1415,  0.1750,  ..., -0.1766,  0.1721, -0.6121]]],
│   │               
│   │               
│   │                       [[[ 0.8580,  0.9120,  0.7760,  ...,  0.1763,  0.6335,  0.7007],
│   │                         [-1.1302, -1.8961, -0.3683,  ...,  1.2136, -0.2449, -0.7995],
│   │                         [ 0.8152, -0.2098,  1.0887,  ...,  1.7743, -0.8910, -1.2528],
│   │                         ...,
│   │                         [ 1.1765,  1.1983,  1.7004,  ...,  1.7091, -0.5024,  1.0783],
│   │                         [-0.8448,  0.4810, -0.0855,  ...,  1.3065,  0.7230, -0.3457],
│   │                         [-1.9113, -0.6781, -1.7346,  ..., -0.1186,  2.2734,  0.2485]],
│   │               
│   │                        [[ 0.9367, -1.9198,  0.3563,  ..., -0.2069, -0.2165, -0.5298],
│   │                         [ 0.6911,  0.4246, -0.0455,  ...,  0.7091, -0.3573,  0.5733],
│   │                         [-0.0295,  0.0800,  0.4336,  ..., -0.0860, -0.2170, -0.7907],
│   │                         ...,
│   │                         [-1.6877, -1.5251, -0.1875,  ..., -0.5495,  0.2324, -2.0532],
│   │                         [ 0.2008, -0.6604, -0.3721,  ..., -0.7068, -1.4813,  2.1455],
│   │                         [ 0.3628, -1.0283, -1.8975,  ...,  0.4064, -1.6480, -0.5772]],
│   │               
│   │                        [[ 0.1391,  2.7230,  1.6127,  ..., -1.5901,  0.0738, -0.5597],
│   │                         [-0.3550,  0.6466,  0.3285,  ..., -0.4007, -0.2001, -0.1061],
│   │                         [ 1.3476,  2.4513,  0.5282,  ..., -0.6894, -0.4245, -0.7765],
│   │                         ...,
│   │                         [ 2.1364,  2.2842, -1.8798,  ...,  0.1278,  1.3280, -0.7495],
│   │                         [-0.2322, -3.1975, -0.7693,  ..., -0.5218, -0.8755,  1.1567],
│   │                         [ 0.0489, -1.7127, -0.4228,  ..., -1.8188,  0.3711,  1.5242]]],
│   │               
│   │               
│   │                       [[[ 0.1691, -1.2590, -2.0192,  ...,  0.2920,  0.5471, -1.5846],
│   │                         [ 1.4142, -0.0090,  0.0754,  ...,  1.2489, -2.1405,  0.5645],
│   │                         [ 1.0901, -1.4021,  0.4625,  ..., -0.9278,  0.7467,  0.1996],
│   │                         ...,
│   │                         [-0.9103, -0.9424,  0.7065,  ..., -0.0110, -0.7272,  0.7633],
│   │                         [-0.9072,  1.4261, -0.3166,  ..., -0.1316,  0.6759, -0.5940],
│   │                         [ 1.2139, -0.6721,  0.2343,  ...,  0.9316, -0.8617,  2.0422]],
│   │               
│   │                        [[ 0.2848, -0.0595,  0.1320,  ..., -0.6899,  0.7797, -1.8241],
│   │                         [-1.2803, -0.8680, -1.4280,  ...,  0.4864,  1.3518, -0.8249],
│   │                         [ 0.0747,  1.3147,  1.2281,  ...,  0.4065, -0.2759,  0.1069],
│   │                         ...,
│   │                         [ 0.3170,  1.0171,  0.8855,  ..., -0.0045,  0.5486,  0.3613],
│   │                         [ 0.1593, -0.0686, -0.6079,  ...,  0.8011,  2.6053,  0.0667],
│   │                         [-2.7334, -0.0870,  1.1399,  ...,  0.2792,  1.5351,  2.5477]],
│   │               
│   │                        [[-1.9834,  1.3776,  0.5873,  ..., -0.6822, -1.0092,  2.0214],
│   │                         [ 2.0267, -0.6797,  1.9385,  ...,  1.0604,  2.1556,  1.6940],
│   │                         [-1.4925,  0.3790,  1.7753,  ...,  0.9691, -0.2756,  0.3984],
│   │                         ...,
│   │                         [-1.5448,  0.6799, -0.3676,  ..., -0.8264, -0.5285,  1.3893],
│   │                         [-0.0999,  1.6444, -0.5071,  ...,  0.8431,  1.0526,  1.0685],
│   │                         [-1.1875, -1.0998,  0.7626,  ..., -0.7700, -1.6407,  0.2429]]],
│   │               
│   │               
│   │                       [[[ 0.3591, -0.3646, -0.8619,  ..., -0.6444,  1.7634, -0.5952],
│   │                         [-1.9774, -0.7415,  1.4967,  ...,  0.0244, -0.4483,  0.6853],
│   │                         [-0.1830, -0.9443, -0.6225,  ..., -0.3044,  0.4169, -1.3675],
│   │                         ...,
│   │                         [ 0.4516,  0.2624,  0.0766,  ...,  0.1745, -0.2693,  1.5344],
│   │                         [ 0.3985, -0.9245, -0.4268,  ..., -0.8736,  1.8599,  0.0412],
│   │                         [-0.2072,  2.1750, -0.0656,  ..., -0.9639,  0.5582, -0.8620]],
│   │               
│   │                        [[-0.3643, -0.1554,  0.7326,  ...,  0.9781, -0.0116,  1.1131],
│   │                         [-0.2833,  1.1247,  0.1839,  ..., -0.7435,  0.5654,  1.3504],
│   │                         [ 0.9207,  0.0829, -0.2563,  ...,  0.7910, -1.2189,  1.6276],
│   │                         ...,
│   │                         [-0.2156,  0.6012,  0.8204,  ...,  1.0305,  0.1277,  1.6339],
│   │                         [ 1.1937, -0.2664,  0.8933,  ..., -0.0296,  0.0098,  0.0052],
│   │                         [ 1.7989, -0.8900, -0.1160,  ...,  1.2997,  0.2163, -0.8628]],
│   │               
│   │                        [[-0.0758, -0.0892, -0.7936,  ..., -0.3783, -1.3074, -3.3614],
│   │                         [ 0.8473, -1.3261,  0.7216,  ..., -0.8144,  0.5457, -0.3377],
│   │                         [-1.5004, -0.1523,  0.5245,  ..., -0.2075,  1.2055,  1.1126],
│   │                         ...,
│   │                         [-2.2294, -0.0675, -0.7088,  ..., -0.3955, -0.8321,  0.8382],
│   │                         [ 0.2396, -0.8568,  0.6367,  ...,  0.9191, -0.1329,  0.1222],
│   │                         [ 0.1633,  0.0164, -0.5553,  ..., -0.6766,  0.1396, -0.8661]]]])
│   └── 'scalar' --> tensor([[-8.5551e-01,  5.4546e-01,  1.1029e+00, -1.1888e+00,  8.9608e-01,
│                             -4.4977e-02,  5.6021e-01, -2.8788e-01,  4.3632e-01,  5.2148e-01,
│                             -2.4079e+00,  1.6206e-03],
│                            [ 9.8680e-01,  1.1020e+00, -7.8887e-02,  1.1543e-01, -6.2044e-03,
│                             -6.8349e-03, -6.9648e-01, -1.6404e+00,  2.8663e-01, -1.8853e-01,
│                             -6.3422e-01,  9.2385e-01],
│                            [-1.2654e+00,  1.6770e-01, -6.1825e-01,  5.8132e-01, -6.6159e-01,
│                              1.4355e+00,  3.8890e-01, -2.5180e+00, -1.7002e+00, -6.1562e-01,
│                              4.0966e-01,  4.0351e-01],
│                            [ 4.8671e-01, -6.7668e-01,  1.3168e+00,  7.3319e-01, -8.4078e-01,
│                              4.9478e-01,  1.8344e+00,  2.2663e-01, -2.0824e+00, -5.5326e-01,
│                             -1.0055e-01, -1.6759e+00]])
└── 'reward' --> tensor([[0.8328],
                         [0.1868],
                         [0.9141],
                         [0.9749]])

This code looks much simpler and clearer.