Stack Structured Data

When tensors are organized in a tree structure, they often need to be stacked together, just as torch.stack() does for plain tensors in torch.

Stack With Native PyTorch API

Here is a typical implementation using the native PyTorch API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch

# Batch size: the number of items generated and then stacked together.
B = 4


def get_item():
    """Build one randomly generated sample as a nested dict.

    Returns:
        dict: A mapping with four keys:

        * ``'obs'``: a dict holding a ``'scalar'`` tensor of shape ``(12,)``
          and an ``'image'`` tensor of shape ``(3, 32, 32)``.
        * ``'action'``: an integer tensor of shape ``(1,)`` drawn from
          ``[0, 10)``.
        * ``'reward'``: a tensor of shape ``(1,)`` drawn from ``[0, 1)``.
        * ``'done'``: the Python bool ``False``.
    """
    # The tensors are drawn in the same order as the keys are listed, so a
    # seeded run produces the same values as building the dict in one literal.
    observation = {
        'scalar': torch.randn(12),
        'image': torch.randn(3, 32, 32),
    }
    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)

    item = {'obs': observation, 'action': action, 'reward': reward}
    item['done'] = False
    return item


# Build a batch of B separately generated items (one get_item() call each).
data = [get_item() for _ in range(B)]


# execute `stack` op
def stack(data, dim):
    """Recursively stack a batch of identically structured items.

    The type of the first item decides how the whole batch is handled:

    * ``torch.Tensor``: all items are joined with ``torch.stack`` along ``dim``.
    * ``dict``: every key of the first item is stacked recursively, giving a
      dict with the same keys.
    * ``bool``: the items are packed into a 1-D boolean tensor; ``dim`` is
      not used for this case.

    Args:
        data: Non-empty sequence of items sharing one structure.
        dim: Dimension handed to ``torch.stack`` for tensor leaves.

    Returns:
        A tensor, or a (possibly nested) dict of tensors mirroring the
        structure of the first item.

    Raises:
        ValueError: If ``data`` is empty.
        TypeError: If the first item is not a tensor, dict or bool.
    """
    # Fail early with a clear message instead of an IndexError from data[0].
    # len() is used rather than truthiness so that array-like inputs are
    # never evaluated as booleans.
    if len(data) == 0:
        raise ValueError("cannot stack an empty sequence")

    elem = data[0]
    if isinstance(elem, torch.Tensor):
        return torch.stack(data, dim)
    if isinstance(elem, dict):
        return {k: stack([item[k] for item in data], dim) for k in elem}
    if isinstance(elem, bool):
        # torch.tensor with an explicit dtype replaces the legacy typed
        # constructor torch.BoolTensor.
        return torch.tensor(data, dtype=torch.bool)
    raise TypeError(f"not support elem type: {type(elem)}")


# Stack the whole batch; tensor leaves are stacked along dimension 0.
stacked_data = stack(data, dim=0)

# validate: show the result, then check the batch axis of size B on each leaf
print(stacked_data)
obs, done = stacked_data['obs'], stacked_data['done']
assert obs['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
assert done.shape == (B,)
assert done.dtype == torch.bool

The output should look like the following, and all the assertion statements should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
{'obs': {'scalar': tensor([[ 2.4406, -0.2326,  0.3159,  0.7629,  0.2003, -1.0960,  1.2239,  1.3413,
         -0.1980,  0.6515, -0.8431, -0.4854],
        [ 0.5706, -0.5807, -0.5703,  0.8545,  2.6043, -0.4146,  0.1899, -0.8225,
         -1.2208, -1.4324,  1.1359,  0.1634],
        [ 0.4165,  0.4536,  1.2908,  0.5528,  0.7488, -0.0271,  0.5562, -1.0044,
         -0.9215, -0.5183, -0.2449, -0.2035],
        [-1.3134, -0.3603, -0.4322, -1.0080, -0.7638, -1.3092,  0.1570, -0.5483,
          1.2298,  0.3452,  0.6044, -1.3629]]), 'image': tensor([[[[ 6.2303e-01,  3.7907e-01,  9.7125e-01,  ...,  3.7841e-01,
           -1.5278e+00, -1.2475e+00],
          [-5.1246e-01, -5.1847e-01,  3.8174e-01,  ..., -4.4257e-01,
            7.6749e-01,  2.1446e-01],
          [-1.1490e+00,  1.2335e+00, -9.0089e-01,  ...,  4.7820e-01,
           -6.2536e-01,  8.9593e-01],
          ...,
          [ 2.9318e-01,  2.3394e-01,  1.2432e+00,  ...,  2.6247e+00,
            8.2226e-01,  7.2769e-01],
          [-2.4359e-01, -5.6062e-01,  5.7370e-01,  ...,  3.2803e-01,
            6.7294e-01,  2.4750e-01],
          [-8.8385e-01, -1.5329e+00, -4.7090e-01,  ...,  9.9067e-01,
           -1.7397e+00,  5.3907e-01]],

         [[ 1.5124e+00, -5.3371e-01, -6.6217e-02,  ...,  1.2442e-01,
           -1.7005e+00, -8.2398e-02],
          [ 2.1172e+00,  2.0715e+00, -1.6335e+00,  ..., -1.1585e+00,
            9.6128e-02,  1.5859e+00],
          [ 7.3566e-02, -3.5602e-01, -9.1113e-01,  ..., -1.9116e-01,
            6.8653e-01,  4.5308e-01],
          ...,
          [ 7.9157e-01,  1.2348e+00,  5.7500e-01,  ...,  3.7405e-01,
            7.8314e-01, -1.0558e+00],
          [-9.2993e-01, -1.5717e+00, -3.4220e-01,  ...,  1.5684e+00,
            6.3501e-01, -9.3426e-01],
          [-8.4097e-01,  1.4319e+00,  3.8686e-01,  ...,  8.5413e-01,
            8.1418e-01,  3.9343e-01]],

         [[ 2.6349e-01,  1.1747e+00, -6.5949e-01,  ..., -1.3048e-01,
            1.0476e+00,  6.0574e-01],
          [ 7.8026e-01, -1.4457e+00,  5.2158e-02,  ..., -1.0862e+00,
            6.5111e-01,  6.1646e-02],
          [-1.3347e+00,  1.0432e+00, -7.0669e-02,  ..., -4.0402e-01,
            6.0034e-01, -2.5421e-01],
          ...,
          [-5.2352e-01, -2.5789e-01, -9.1947e-01,  ..., -1.7019e+00,
           -1.4448e+00,  7.0648e-01],
          [-5.4992e-01, -1.2591e+00, -2.3283e-01,  ...,  3.2222e-01,
            8.0506e-01,  1.8760e-02],
          [-1.8659e+00,  1.8008e+00, -7.2791e-01,  ..., -6.4481e-01,
            1.0983e+00, -1.0431e+00]]],


        [[[ 1.1491e-01, -2.4217e+00, -2.7502e-01,  ..., -9.4207e-01,
           -9.1544e-01,  2.1274e+00],
          [-2.6817e-02,  2.1101e+00, -1.8743e+00,  ..., -2.4443e+00,
           -9.9995e-01, -1.0752e-02],
          [-1.0347e+00,  1.9678e+00,  5.0758e-01,  ...,  1.4647e+00,
           -4.4884e-01,  6.2556e-01],
          ...,
          [ 1.9945e-01, -2.6560e-01, -1.8475e+00,  ...,  2.1487e+00,
           -1.2312e-01,  9.5470e-01],
          [ 2.0637e+00,  6.8295e-02, -1.1399e+00,  ...,  5.9918e-01,
           -3.6063e-01,  4.8671e-02],
          [ 2.4535e-01, -1.0596e+00, -8.4563e-01,  ...,  3.9147e-01,
            1.1692e+00,  4.7972e-01]],

         [[-2.3187e-01, -4.5913e-01,  1.6952e+00,  ..., -1.2197e-01,
           -4.7615e-01, -8.3907e-01],
          [ 2.7508e-01, -1.7677e-01,  7.2943e-01,  ...,  1.0739e+00,
            1.2985e-01, -1.0517e+00],
          [ 1.1066e-01, -1.0491e+00, -6.9424e-01,  ..., -1.3292e+00,
           -9.1646e-01, -7.7565e-01],
          ...,
          [-1.0676e+00,  4.3200e-01,  1.9565e-01,  ...,  4.6007e-01,
            5.8693e-01,  1.0253e+00],
          [ 1.5163e+00,  4.9878e-01, -1.0802e+00,  ...,  3.9201e-01,
           -1.9970e-01,  2.3266e-01],
          [-1.3408e+00, -3.6473e-02, -1.6834e-02,  ..., -4.2041e-01,
           -1.2722e-01, -9.4917e-01]],

         [[ 7.9040e-01,  2.1561e+00, -4.5664e-01,  ...,  5.7663e-01,
            3.3143e-01,  2.2916e+00],
          [ 2.0930e-01,  4.4103e-01,  3.3097e-01,  ..., -1.6918e+00,
            3.8298e-01, -2.2074e-01],
          [ 5.9926e-01, -2.7160e-01, -2.3685e-01,  ...,  2.6498e-01,
           -1.1259e+00, -1.6201e+00],
          ...,
          [ 1.1701e+00,  4.6029e-01,  1.3971e+00,  ..., -2.7892e-01,
            1.2556e+00, -1.1946e-01],
          [-1.9068e-01, -1.6099e+00, -4.5448e-01,  ..., -8.1447e-01,
           -1.6314e+00, -1.0942e+00],
          [ 3.6074e-01, -9.7278e-01, -1.6456e-01,  ...,  1.5826e-01,
            7.1130e-01, -3.4926e-01]]],


        [[[ 5.6811e-01,  1.0578e+00,  6.9692e-02,  ...,  4.0743e-01,
           -1.2546e+00, -1.7926e+00],
          [ 4.0721e-01,  2.0046e+00, -3.5012e-01,  ..., -1.7195e-01,
           -3.1115e-01,  9.2187e-03],
          [-1.3385e+00,  3.3045e-01, -5.7660e-01,  ...,  1.2781e+00,
           -3.6560e-02, -9.0471e-01],
          ...,
          [ 6.4635e-01,  1.4771e+00,  1.4980e-01,  ..., -1.1994e+00,
            2.2912e-01,  2.8611e-02],
          [-1.9118e+00, -2.1137e-01,  1.1292e+00,  ..., -4.0436e-01,
           -5.7048e-01, -3.2108e-01],
          [-1.6804e+00,  1.4075e-01, -1.7652e+00,  ...,  1.8895e+00,
           -4.8921e-01, -6.4779e-01]],

         [[-1.7040e-01,  4.2093e-01,  1.2370e+00,  ..., -1.6426e+00,
            1.4285e-01,  2.8646e-01],
          [-9.4651e-02, -1.4542e+00, -9.9900e-02,  ..., -4.5496e-01,
           -5.9735e-01,  5.0281e-01],
          [ 3.3284e-01, -3.1849e-01,  2.0653e-01,  ...,  6.2924e-01,
           -2.0909e-01, -9.3633e-01],
          ...,
          [-2.1661e-01, -1.8968e+00,  7.0910e-01,  ...,  9.1273e-01,
           -3.3205e-01, -1.8705e+00],
          [ 8.4481e-01, -5.0406e-01,  1.2868e+00,  ...,  1.6211e+00,
            8.4574e-01,  2.8751e+00],
          [ 1.0029e+00, -9.6006e-01, -8.6510e-01,  ...,  1.3225e+00,
            1.3517e+00, -1.1379e+00]],

         [[ 1.9638e-01, -8.4702e-01, -2.2450e-01,  ...,  6.8658e-01,
           -1.4613e+00, -1.1684e+00],
          [-1.5022e+00,  4.5041e-02,  1.6958e+00,  ..., -7.1676e-01,
            8.7583e-01,  1.5728e-01],
          [ 4.6696e-01, -3.6100e-01, -1.2866e-01,  ..., -1.1120e+00,
           -2.1483e-01,  2.5440e-01],
          ...,
          [-3.0620e-01,  4.8296e-01, -9.9256e-01,  ...,  1.3296e+00,
            5.3094e-01,  2.0354e-01],
          [-1.2197e+00, -2.7189e-01,  2.1092e-01,  ..., -1.0298e-02,
            8.6254e-01,  3.5636e-01],
          [ 8.2118e-01, -1.2405e+00, -1.0597e+00,  ...,  4.1031e-01,
           -2.2876e+00,  2.4251e-03]]],


        [[[ 8.5753e-01,  9.0316e-01, -1.3872e+00,  ...,  1.4939e+00,
            8.1304e-01, -3.1781e-01],
          [ 1.9352e+00, -1.8050e-01, -5.1657e-01,  ..., -7.1810e-01,
           -7.9894e-01, -9.6086e-02],
          [ 1.6380e+00,  9.9091e-01,  4.1572e-01,  ..., -3.0379e-01,
           -1.8658e+00,  1.0138e+00],
          ...,
          [-1.8735e+00, -9.2775e-01, -6.4842e-01,  ..., -7.8636e-01,
            4.7128e-01,  1.3403e+00],
          [-4.1145e-01, -1.5986e+00, -3.3623e-01,  ..., -1.6311e+00,
           -1.6678e+00, -8.4054e-02],
          [ 5.9511e-01, -7.6527e-01,  5.2863e-02,  ..., -8.2568e-01,
           -3.6506e-01, -7.5336e-01]],

         [[-7.8909e-02,  3.2635e-01,  2.2626e-01,  ..., -9.1411e-01,
           -2.1093e-01,  2.5400e-01],
          [-5.5747e-02, -7.9989e-01,  2.7386e-01,  ...,  1.3085e+00,
            8.0231e-01, -1.5837e+00],
          [ 5.2100e-01,  1.1549e+00, -1.0860e+00,  ...,  5.2940e-01,
           -9.0651e-02, -8.0337e-01],
          ...,
          [-1.3527e+00, -2.3600e+00, -7.8898e-01,  ..., -4.9543e-01,
           -2.0347e+00,  1.4247e+00],
          [ 1.4348e+00,  4.5750e-01,  8.8953e-02,  ..., -6.5069e-02,
           -4.9114e-01, -4.1113e-02],
          [-1.2637e+00, -1.1843e+00, -1.4069e-01,  ...,  1.1163e+00,
            1.4006e+00, -1.4831e+00]],

         [[ 1.6917e+00, -9.0099e-01,  3.9538e-01,  ..., -1.0712e+00,
           -2.3853e-01,  9.1446e-02],
          [-1.3546e+00,  3.3503e-01,  3.0341e-01,  ..., -1.7419e+00,
            3.7567e-01,  2.1281e+00],
          [ 1.2696e+00,  1.0401e+00, -2.0788e+00,  ...,  1.1066e+00,
            1.5788e-01,  6.8015e-01],
          ...,
          [-8.1503e-01, -1.8368e-01, -7.7326e-01,  ..., -9.1290e-01,
           -1.1748e-01,  6.9413e-01],
          [-8.8045e-01,  7.1265e-01,  6.7957e-01,  ..., -3.4913e-01,
            4.1585e-01, -2.2507e-01],
          [ 4.8887e-01,  4.1726e-01,  2.3335e-01,  ..., -1.0586e-01,
           -7.4991e-01,  5.6071e-01]]]])}, 'action': tensor([[0],
        [1],
        [2],
        [2]]), 'reward': tensor([[0.4662],
        [0.4125],
        [0.8629],
        [0.4453]]), 'done': tensor([False, False, False, False])}

We can see that the function for processing the structure (such as the function stack) has to be implemented entirely by hand. The code is not very clear, and because the supported types are hard-coded, supporting more data types (such as integers) requires special modifications to the function.

Stack With TreeTensor API

The same workflow can be implemented with the treetensor API, as shown in the code below.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

import treetensor.torch as ttorch

# Batch size: the number of items generated and then stacked together.
B = 4


def get_item():
    """Build one randomly generated sample as a nested dict.

    Returns:
        dict: A mapping with four keys:

        * ``'obs'``: a dict holding a ``'scalar'`` tensor of shape ``(12,)``
          and an ``'image'`` tensor of shape ``(3, 32, 32)``.
        * ``'action'``: an integer tensor of shape ``(1,)`` drawn from
          ``[0, 10)``.
        * ``'reward'``: a tensor of shape ``(1,)`` drawn from ``[0, 1)``.
        * ``'done'``: the Python bool ``False``.
    """
    # The tensors are drawn in the same order as the keys are listed, so a
    # seeded run produces the same values as building the dict in one literal.
    observation = {
        'scalar': torch.randn(12),
        'image': torch.randn(3, 32, 32),
    }
    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)

    item = {'obs': observation, 'action': action, 'reward': reward}
    item['done'] = False
    return item


# Generate B raw items, one get_item() call each.
data = [get_item() for _ in range(B)]

# execute `stack` op: pass each item through ttorch.tensor, then stack the batch
data = list(map(ttorch.tensor, data))
stacked_data = ttorch.stack(data, dim=0)

# validate: show the result, then check the batch axis of size B on each leaf
print(stacked_data)
obs = stacked_data.obs
assert obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
done = stacked_data.done
assert done.shape == (B,)
assert done.dtype == torch.bool

The output should look like the following, and all the assertion statements should pass as well.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
<Tensor 0x7fd3dd9ec410>
├── 'action' --> tensor([[5],
│                        [8],
│                        [7],
│                        [0]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7fd3ddab9610>
│   ├── 'image' --> tensor([[[[-9.5575e-01,  3.0435e-01,  9.6322e-01,  ...,  8.6874e-01,
│   │                           7.9456e-01, -1.6487e+00],
│   │                         [ 1.7612e+00,  1.2713e+00, -1.2655e+00,  ..., -9.6476e-02,
│   │                           7.8624e-01, -1.2777e+00],
│   │                         [ 1.7776e-02,  1.4540e+00, -3.9624e-02,  ...,  2.3202e+00,
│   │                          -8.3055e-01,  1.7327e+00],
│   │                         ...,
│   │                         [-4.6596e-01, -1.1270e-01, -2.3538e+00,  ...,  1.1291e+00,
│   │                          -4.2642e-01, -1.0940e+00],
│   │                         [-4.8986e-02, -1.3458e+00,  2.5237e-02,  ...,  1.2630e+00,
│   │                           1.2025e+00,  7.1687e-01],
│   │                         [-1.1953e+00,  5.4740e-01,  4.3433e-01,  ...,  1.2152e+00,
│   │                          -1.8748e-01, -1.3608e+00]],
│   │               
│   │                        [[-6.8711e-01, -1.0439e+00,  5.1751e-01,  ...,  1.3581e+00,
│   │                           8.9694e-01, -8.0825e-01],
│   │                         [ 2.6084e+00, -4.8190e-01,  9.9855e-01,  ...,  6.8717e-01,
│   │                           1.6449e-01,  1.0780e+00],
│   │                         [-7.9782e-01,  1.5154e+00,  5.3518e-01,  ...,  2.2233e+00,
│   │                           5.8943e-01, -1.0916e+00],
│   │                         ...,
│   │                         [ 1.1781e+00,  1.8864e+00,  5.7001e-01,  ...,  4.0592e-01,
│   │                           6.7409e-01, -5.1948e-01],
│   │                         [ 6.6307e-01,  1.0433e+00, -5.7977e-01,  ..., -1.0678e+00,
│   │                          -7.2121e-01,  4.2973e-01],
│   │                         [-8.5520e-02, -2.1345e-01, -4.5899e-01,  ..., -1.0698e+00,
│   │                          -1.1945e-02,  3.8756e-01]],
│   │               
│   │                        [[-1.1346e+00,  2.3990e-01, -5.1426e-01,  ...,  4.5785e-01,
│   │                           6.5396e-01,  8.0841e-02],
│   │                         [-7.1712e-01, -1.4010e+00,  1.7004e-01,  ..., -1.1046e+00,
│   │                           6.3607e-01, -1.6358e+00],
│   │                         [-1.4234e-01, -7.7116e-01,  8.2450e-01,  ...,  1.8531e-01,
│   │                          -1.9638e+00, -2.3151e-01],
│   │                         ...,
│   │                         [-6.9609e-01,  8.5523e-01,  5.3044e-01,  ..., -3.1698e+00,
│   │                          -4.8980e-01,  6.0223e-01],
│   │                         [ 7.5207e-01, -6.5455e-01,  1.3681e+00,  ...,  1.7894e-02,
│   │                          -1.6630e+00,  3.8090e-01],
│   │                         [-8.9992e-01,  8.1974e-01, -9.9610e-01,  ...,  2.9078e-01,
│   │                           1.3935e+00, -9.3316e-01]]],
│   │               
│   │               
│   │                       [[[ 6.5041e-01, -2.2996e-01, -5.8957e-01,  ..., -5.9129e-01,
│   │                           7.8139e-01,  1.9445e-01],
│   │                         [ 5.6126e-01,  8.7148e-01,  1.0681e-01,  ..., -7.4448e-01,
│   │                          -4.9118e-01,  1.5250e-01],
│   │                         [ 9.7633e-01, -1.3232e+00, -6.3940e-01,  ...,  1.2407e+00,
│   │                          -1.4975e+00,  3.7790e-01],
│   │                         ...,
│   │                         [ 5.9587e-01,  1.2750e+00, -1.8124e-01,  ...,  1.7338e+00,
│   │                           3.1224e-01, -9.6997e-01],
│   │                         [-5.5004e-01,  2.8512e-01,  2.1334e+00,  ...,  2.3273e-01,
│   │                          -2.1724e+00, -7.6146e-01],
│   │                         [-3.7288e-01, -1.2362e+00,  6.6173e-01,  ..., -1.6114e+00,
│   │                           1.2461e+00,  1.3608e-01]],
│   │               
│   │                        [[-1.1628e+00,  1.0934e-01, -1.1415e+00,  ..., -3.6852e-01,
│   │                           1.0359e-01, -9.8663e-01],
│   │                         [ 8.3429e-01,  7.3595e-01, -3.1934e-01,  ...,  1.0849e+00,
│   │                          -7.4927e-01, -1.4136e+00],
│   │                         [ 1.4151e+00,  1.1724e+00,  1.9125e+00,  ...,  1.3181e+00,
│   │                           4.3245e-01,  1.9049e-01],
│   │                         ...,
│   │                         [-1.1057e+00,  1.1725e+00,  4.5918e-02,  ...,  1.5341e-01,
│   │                          -4.8413e-01, -2.4371e-01],
│   │                         [ 8.1626e-01,  1.4059e+00,  1.6425e+00,  ...,  2.8085e-02,
│   │                           1.0743e+00,  1.2348e-01],
│   │                         [ 1.7799e-01,  1.1224e-01, -1.1090e+00,  ...,  1.2087e+00,
│   │                           1.0709e+00,  7.2011e-01]],
│   │               
│   │                        [[-9.4435e-01, -1.0160e+00, -2.5757e-02,  ...,  4.0155e-01,
│   │                          -5.7387e-01,  1.2495e+00],
│   │                         [ 5.9113e-01, -7.5923e-01,  6.2401e-01,  ...,  1.1711e+00,
│   │                           3.5431e-01,  1.1396e+00],
│   │                         [ 1.4228e+00, -6.1841e-02,  5.4501e-01,  ..., -9.4527e-01,
│   │                           1.3437e+00,  1.7718e-01],
│   │                         ...,
│   │                         [-9.9858e-01,  9.8054e-02,  7.6392e-01,  ..., -7.0249e-01,
│   │                          -1.4028e-01, -4.7191e-01],
│   │                         [-8.9457e-01, -3.7717e-01,  1.5827e-01,  ..., -4.6045e-02,
│   │                          -9.5846e-01, -7.0488e-01],
│   │                         [-1.4611e+00,  1.6067e+00,  1.5282e+00,  ..., -9.0787e-01,
│   │                          -1.5427e+00,  2.9319e-03]]],
│   │               
│   │               
│   │                       [[[-1.6290e-01,  7.0655e-01,  1.4779e+00,  ...,  3.0759e-01,
│   │                           1.8696e+00, -8.3246e-01],
│   │                         [-1.1398e-01,  1.3689e+00, -3.1618e-01,  ..., -1.7842e+00,
│   │                           6.2403e-01, -1.2755e+00],
│   │                         [-5.5212e-01, -5.8304e-02,  3.9067e-01,  ...,  6.1067e-01,
│   │                          -4.2736e-01, -1.0182e+00],
│   │                         ...,
│   │                         [-2.0004e+00, -6.3674e-01, -1.2921e+00,  ...,  3.4109e-02,
│   │                           2.3377e-01,  7.2915e-01],
│   │                         [ 1.6659e+00,  8.0701e-01, -5.8565e-01,  ..., -1.6547e+00,
│   │                           5.1596e-01,  3.1387e-01],
│   │                         [ 1.0859e+00, -5.2274e-01,  2.8633e-01,  ..., -8.1373e-01,
│   │                          -5.5644e-01,  2.3325e-01]],
│   │               
│   │                        [[ 1.6157e+00, -1.3032e+00, -3.8967e-01,  ...,  1.2310e+00,
│   │                          -4.9855e-01,  2.9868e-01],
│   │                         [-7.5919e-02,  8.1238e-01,  2.2405e-01,  ..., -1.2111e+00,
│   │                          -1.2701e+00, -4.6776e-01],
│   │                         [-5.6995e-01, -9.9782e-01, -9.8708e-01,  ..., -7.0472e-02,
│   │                          -1.2038e+00,  1.6894e-01],
│   │                         ...,
│   │                         [ 1.3762e+00,  4.0974e-01, -1.1003e+00,  ...,  2.7685e-01,
│   │                          -7.4412e-01,  6.4947e-01],
│   │                         [-7.6426e-01, -2.1287e+00, -7.5558e-01,  ..., -4.8787e-01,
│   │                           3.9942e-01, -1.3180e-01],
│   │                         [-1.1687e+00,  9.5125e-01,  8.0836e-01,  ..., -1.5050e+00,
│   │                           7.1248e-01, -6.0008e-01]],
│   │               
│   │                        [[ 4.7317e-01, -4.8063e-01,  8.4052e-01,  ..., -1.4147e+00,
│   │                          -3.8455e-01,  1.9304e-01],
│   │                         [ 9.9553e-01, -4.4561e-01,  1.3067e+00,  ...,  3.5021e-01,
│   │                          -1.4267e+00, -4.6562e-01],
│   │                         [-4.1792e-01, -1.6641e+00,  6.5348e-01,  ...,  1.0325e+00,
│   │                          -1.5290e+00,  6.5137e-01],
│   │                         ...,
│   │                         [-7.0210e-01,  9.2405e-01,  1.9414e+00,  ...,  1.9439e+00,
│   │                          -1.6814e-01, -1.0692e+00],
│   │                         [-6.0007e-01,  9.4871e-01, -1.5166e+00,  ...,  7.9249e-01,
│   │                           1.1005e+00, -2.3878e-02],
│   │                         [ 1.2261e+00,  3.8003e-01, -7.2930e-01,  ..., -1.1425e+00,
│   │                          -3.2733e-01,  1.3253e+00]]],
│   │               
│   │               
│   │                       [[[-4.0302e-01, -4.8846e-01, -2.8228e+00,  ..., -9.8832e-01,
│   │                          -1.3734e+00, -3.6885e-01],
│   │                         [-1.3313e+00, -2.5525e-01, -1.0839e+00,  ..., -5.2812e-01,
│   │                          -1.1874e-01, -4.4591e-01],
│   │                         [-6.0366e-01,  4.5567e-01, -1.6828e+00,  ...,  2.6629e+00,
│   │                           5.6302e-01,  1.1166e+00],
│   │                         ...,
│   │                         [ 1.9959e-01, -1.4992e-01,  7.9302e-02,  ..., -1.4324e-01,
│   │                           7.2170e-01, -6.6362e-01],
│   │                         [-1.4764e+00, -4.9337e-01, -3.7185e-01,  ..., -4.5096e-01,
│   │                           4.0913e-01,  2.2368e-01],
│   │                         [-9.8589e-01,  4.9522e-01,  2.2190e-01,  ...,  1.6660e-01,
│   │                           1.8852e+00, -1.6943e+00]],
│   │               
│   │                        [[ 1.2006e+00,  1.4942e+00, -1.1296e-02,  ...,  8.8106e-02,
│   │                           1.5460e+00,  2.7187e-01],
│   │                         [-3.3818e-01, -1.3860e+00, -5.4836e-01,  ..., -1.8566e-01,
│   │                          -1.4503e-01,  1.1828e+00],
│   │                         [ 7.8731e-01,  3.6832e-01, -1.6402e+00,  ..., -8.6128e-01,
│   │                           7.8643e-01, -2.3991e+00],
│   │                         ...,
│   │                         [-1.2632e+00,  1.2817e+00,  1.3419e+00,  ...,  6.3172e-01,
│   │                           7.3576e-01, -9.1797e-01],
│   │                         [ 7.5540e-01,  1.7704e-01,  4.0129e-01,  ..., -9.0250e-01,
│   │                          -2.4268e-01,  5.8245e-01],
│   │                         [ 7.6338e-01,  9.9988e-02, -1.2155e+00,  ..., -1.0944e+00,
│   │                           4.3858e-01, -7.4171e-01]],
│   │               
│   │                        [[ 1.2930e+00, -2.1058e-01,  9.4332e-01,  ..., -1.0701e+00,
│   │                           6.1732e-01, -7.9441e-01],
│   │                         [ 4.7757e-01,  3.7257e-01,  6.8763e-01,  ...,  1.5797e-01,
│   │                          -1.1610e+00,  4.0394e-01],
│   │                         [-7.1361e-01, -3.4692e-01,  2.1726e+00,  ...,  1.1077e+00,
│   │                          -1.1899e+00, -8.1069e-01],
│   │                         ...,
│   │                         [ 1.6562e+00, -8.7608e-01,  7.4942e-02,  ...,  1.0803e+00,
│   │                           9.3535e-01, -5.4753e-01],
│   │                         [ 9.0557e-01,  1.1473e+00,  1.2133e+00,  ...,  6.7832e-01,
│   │                           6.7635e-01,  1.2330e+00],
│   │                         [-1.1059e+00,  6.8749e-01, -1.6309e+00,  ...,  4.8102e-01,
│   │                          -2.0025e+00, -1.1603e-01]]]])
│   └── 'scalar' --> tensor([[-0.2778,  1.0373,  0.6297, -0.9248, -1.2144,  0.5902, -0.1981, -1.3492,
│                              2.3159, -2.0238,  1.9589,  1.0304],
│                            [-0.3628,  0.2890,  1.0781,  0.2064,  0.1827, -0.0800,  1.5294, -0.3934,
│                             -0.9041, -0.3042, -0.1004,  0.1302],
│                            [ 0.2546,  0.9138, -0.6433,  0.4390, -0.6012, -0.6795, -0.3616, -0.6075,
│                              0.8750,  1.2899, -0.5521, -0.4684],
│                            [-0.6567,  0.2819, -0.7508,  0.4828,  0.2418,  0.6956, -0.3993, -0.0043,
│                              1.8764, -0.0331, -1.0780,  0.2885]])
└── 'reward' --> tensor([[0.2299],
                         [0.8472],
                         [0.9215],
                         [0.7098]])

This code looks much simpler and clearer.