Stack Structured Data

When tensors are organized into a tree structure, they often need to be stacked together, just as torch.stack() does for plain tensors in PyTorch.

Stack With Native PyTorch API

Here is a common implementation using the native PyTorch API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch

B = 4  # batch size: number of items generated and stacked below


def get_item():
    """Create one random sample as a nested dict.

    Returns:
        A dict with the keys ``'obs'`` (a nested dict holding a ``(12,)``
        ``'scalar'`` tensor and a ``(3, 32, 32)`` ``'image'`` tensor),
        ``'action'`` (one random integer in ``[0, 10)``), ``'reward'``
        (one random float in ``[0, 1)``) and ``'done'`` (always ``False``).
    """
    # The tensors are drawn in the same order as the keys are inserted.
    observation = dict(
        scalar=torch.randn(12),
        image=torch.randn(3, 32, 32),
    )
    sample = dict(obs=observation)
    sample['action'] = torch.randint(0, 10, size=(1,))
    sample['reward'] = torch.rand(1)
    sample['done'] = False
    return sample


# Build a batch of B independently generated nested-dict samples.
data = [get_item() for _ in range(B)]


# execute `stack` op
def stack(data, dim):
    """Recursively stack a sequence of identically structured items.

    The type of the first item decides how the whole sequence is handled:

    * ``torch.Tensor`` -- the items are combined with ``torch.stack``
      along ``dim``.
    * ``dict`` -- every key of the first item is stacked recursively, so
      the result is a dict with the same keys.
    * ``bool`` -- the flags are packed into a ``torch.bool`` tensor;
      ``dim`` is not used in this case.

    Args:
        data: Non-empty sequence of items that share one structure.
        dim: Dimension handed to ``torch.stack`` for tensor leaves.

    Returns:
        A tensor, or a dict of stacked values mirroring the input structure.

    Raises:
        IndexError: If ``data`` is empty.
        TypeError: If the first item is not a tensor, a dict or a bool.
    """
    # Fail with a readable message instead of a bare "list index out of range".
    if len(data) == 0:
        raise IndexError("cannot stack an empty sequence")
    elem = data[0]
    if isinstance(elem, torch.Tensor):
        return torch.stack(data, dim)
    elif isinstance(elem, dict):
        # Only the first item's keys are used; every other item must provide them too.
        return {k: stack([item[k] for item in data], dim) for k in elem}
    elif isinstance(elem, bool):
        # torch.tensor with an explicit dtype replaces the legacy torch.BoolTensor constructor.
        return torch.tensor(data, dtype=torch.bool)
    else:
        raise TypeError("not support elem type: {}".format(type(elem)))


# Stack the whole batch along a new leading dimension (dim=0).
stacked_data = stack(data, dim=0)
# validate
print(stacked_data)
# Each tensor leaf keeps its per-item shape behind the new batch dimension of size B.
assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
# The Python bool flags are collected into a single bool tensor with one entry per item.
assert stacked_data['done'].shape == (B,)
assert stacked_data['done'].dtype == torch.bool

The output should look like the following, and all of the assertion statements pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
{'obs': {'scalar': tensor([[-1.4607, -0.3098, -0.2855,  1.0892,  2.0129, -0.5830, -0.3645, -2.2169,
          0.5652, -2.1109, -0.4848,  0.4675],
        [ 0.3150, -0.0179, -0.9303, -0.7602,  0.1695,  0.2604, -0.9208,  0.3807,
         -1.2486,  0.3504,  0.4016,  0.1729],
        [-0.1870, -2.3231, -0.3629,  0.2551,  1.7746,  0.5267,  1.4649, -0.2815,
          1.9678,  1.0912, -1.8584,  0.4913],
        [ 0.4994,  1.3094, -1.3254,  0.0055,  0.5832, -0.0751, -0.9258, -0.3123,
          1.1580,  0.6586, -0.2181,  1.2473]]), 'image': tensor([[[[ 1.3486e+00,  3.1630e-04,  9.4858e-01,  ..., -6.0048e-01,
            7.0951e-02,  1.3416e+00],
          [ 1.4403e-01,  8.8430e-01, -9.1637e-01,  ...,  1.6808e+00,
           -4.7636e-01, -4.3116e-01],
          [-7.3847e-01,  2.4302e-02,  5.9704e-02,  ...,  5.3134e-01,
           -8.9657e-01,  4.3325e-01],
          ...,
          [ 1.5193e-01,  1.5124e+00, -2.3019e+00,  ..., -1.2970e+00,
            6.9906e-01,  1.1957e-01],
          [ 2.9094e-01,  1.2388e-01, -3.5267e-01,  ..., -2.4005e+00,
            5.6672e-02, -2.3968e-01],
          [-1.1331e+00, -1.3203e+00, -1.5843e+00,  ...,  8.7908e-01,
           -8.1334e-02,  7.1196e-01]],

         [[ 1.7904e+00,  8.6618e-01, -8.4076e-01,  ..., -1.4151e-01,
            9.5637e-02, -1.9439e-01],
          [ 5.3955e-02,  1.3949e+00,  1.0767e+00,  ..., -2.8176e-01,
           -2.7050e+00,  4.1822e-01],
          [ 1.3015e+00,  1.6543e+00,  1.5777e+00,  ...,  1.7040e+00,
           -1.9235e-02,  2.7200e-01],
          ...,
          [ 2.0429e+00, -6.4446e-01,  1.1346e-01,  ...,  7.0581e-02,
           -1.3147e-01, -4.4320e-01],
          [ 1.6657e+00,  1.0031e+00, -1.3084e+00,  ...,  1.3864e+00,
            6.0901e-01,  1.4652e+00],
          [ 4.2022e-01,  1.0194e+00, -6.8863e-01,  ...,  1.2872e+00,
           -5.5808e-01,  7.4237e-01]],

         [[-5.3672e-01, -5.7924e-01, -2.8592e-01,  ...,  6.3599e-01,
            1.0248e+00,  8.6069e-01],
          [ 1.4439e+00, -7.5969e-01,  2.2018e+00,  ...,  3.3398e-01,
           -1.2543e+00,  2.2208e+00],
          [-1.8504e+00,  6.6175e-01,  1.6769e-01,  ...,  5.2623e-01,
           -3.3059e+00,  1.0604e+00],
          ...,
          [ 6.3397e-01,  5.0834e-01, -9.7249e-01,  ..., -4.1031e-01,
            6.0315e-01,  1.3101e+00],
          [-3.4543e-03, -1.8361e-01,  2.6744e-01,  ...,  5.3730e-01,
            2.3319e-01,  6.9069e-01],
          [ 2.0245e+00,  8.1553e-02,  3.6976e-01,  ..., -2.0310e-01,
           -1.0247e+00, -8.4923e-01]]],


        [[[ 3.9333e-01, -1.1791e-03,  8.4090e-01,  ...,  8.5771e-01,
           -2.7729e-01,  7.7209e-01],
          [-5.4817e-01, -1.1096e-01, -1.2877e+00,  ..., -1.2695e+00,
           -9.2796e-01,  1.6938e+00],
          [ 7.5265e-01, -6.2671e-01, -1.1036e+00,  ...,  3.6110e-01,
           -7.1477e-01, -7.4430e-01],
          ...,
          [ 9.6595e-01, -8.9932e-02,  4.7592e-01,  ..., -1.0272e+00,
            1.6230e-01,  2.0065e+00],
          [ 4.9604e-01,  6.2594e-01, -1.3647e+00,  ...,  2.4532e-01,
           -2.9389e-02, -1.2436e+00],
          [ 2.3471e+00, -1.5749e-01, -4.6400e-01,  ..., -1.1049e+00,
           -1.8053e+00, -3.7563e-01]],

         [[-6.3709e-01,  8.3775e-02,  1.0332e-01,  ..., -6.0955e-02,
           -2.6844e+00,  5.5119e-01],
          [-7.4883e-01,  2.0417e-01,  1.1604e+00,  ..., -2.1945e-01,
           -5.6994e-01, -4.5779e-01],
          [ 1.0382e+00,  7.7947e-01,  3.1680e-01,  ...,  6.7417e-01,
            2.2994e-01,  1.4514e+00],
          ...,
          [ 2.0677e+00, -8.1980e-01,  1.4386e+00,  ...,  1.5733e-01,
           -6.9231e-01,  3.8869e-02],
          [-1.5365e-01, -2.4951e-01, -3.2119e-01,  ..., -1.7675e+00,
            1.1691e+00,  2.8378e-01],
          [ 2.3755e-01,  2.0935e+00,  5.7434e-01,  ...,  1.0033e+00,
            1.5294e+00, -3.5352e-01]],

         [[ 1.9285e-01,  6.6657e-01,  8.9605e-01,  ...,  5.6345e-01,
           -1.2235e+00,  3.7019e-01],
          [-6.8642e-01, -6.1073e-01, -3.5424e-02,  ...,  2.0499e-01,
            8.2370e-01,  1.8997e+00],
          [ 6.8858e-01,  1.8094e+00,  8.1681e-02,  ..., -3.4173e-02,
            6.5707e-01, -6.1192e-01],
          ...,
          [ 1.7091e+00, -1.1555e+00,  1.8712e-01,  ...,  9.6459e-01,
            2.4414e+00, -1.3544e+00],
          [ 9.4608e-01,  4.4452e-01, -1.4537e+00,  ...,  2.8130e-01,
           -8.9218e-01, -2.9847e-01],
          [-5.9471e-01, -8.3986e-01,  5.9665e-01,  ..., -8.3943e-01,
           -5.7018e-01,  7.2707e-01]]],


        [[[ 1.6993e+00, -3.2606e-01, -3.3249e-01,  ...,  1.1195e-01,
           -1.4968e+00, -3.2666e-01],
          [ 1.1441e+00, -1.8613e+00,  5.4776e-01,  ..., -7.8962e-01,
            3.7409e-01,  2.0089e+00],
          [ 1.1385e+00,  1.1585e+00,  5.5855e-01,  ..., -1.5946e+00,
           -1.2094e-01, -8.8019e-02],
          ...,
          [-3.6718e-01,  2.0840e-01, -1.9935e+00,  ..., -7.2289e-01,
           -1.4150e+00, -1.9174e-01],
          [-1.1182e+00, -1.7001e+00,  1.3512e-01,  ..., -3.6640e-02,
            2.9737e+00,  1.8341e+00],
          [-4.4741e+00, -7.6293e-01,  4.3407e-01,  ...,  5.5492e-01,
            2.8019e-01, -4.5269e-01]],

         [[-4.7407e-01,  2.0214e+00, -2.9583e+00,  ..., -1.6253e+00,
           -1.7875e+00,  7.7868e-01],
          [-1.7513e-02, -1.1043e+00,  1.4510e+00,  ..., -1.8358e+00,
            9.4757e-01, -6.1726e-01],
          [ 9.5177e-01, -9.2181e-01, -7.8360e-01,  ..., -8.4698e-01,
            5.8719e-01, -1.9356e-01],
          ...,
          [-1.1373e+00,  5.5234e-01,  1.2291e+00,  ...,  2.2830e-01,
           -8.8716e-01, -7.7818e-01],
          [-1.3319e-01,  6.0197e-01,  5.0437e-02,  ..., -7.1805e-01,
            4.6809e-01, -1.0781e+00],
          [ 1.7074e-01, -5.8314e-01,  8.7072e-01,  ...,  7.0366e-01,
           -6.4285e-01, -9.3610e-01]],

         [[-1.2632e+00,  6.2555e-01, -8.2867e-01,  ..., -2.0814e+00,
            1.7938e+00, -2.2382e-01],
          [-7.4899e-01, -1.6356e+00,  3.5935e-01,  ..., -1.8370e-01,
            3.2226e-01, -1.4253e+00],
          [-2.7651e+00, -1.2550e+00,  1.2908e+00,  ..., -7.9643e-01,
           -4.0456e-01, -1.6549e+00],
          ...,
          [ 2.2757e-01, -7.9969e-01, -1.1071e+00,  ...,  2.2011e-01,
            1.4292e-01, -5.0016e-01],
          [-1.0346e+00, -4.1360e-01, -9.2221e-01,  ...,  1.6145e-01,
            2.2885e-01, -5.7004e-01],
          [-2.4836e-01, -2.0064e-01,  5.4564e-01,  ..., -8.5175e-01,
            5.3611e-01, -1.3287e+00]]],


        [[[ 7.8754e-01, -1.8022e+00, -8.5686e-02,  ..., -1.2964e+00,
           -3.3757e-02,  3.9055e-01],
          [ 6.6596e-01, -9.5347e-01,  3.5121e-01,  ..., -1.2046e+00,
            8.3156e-01,  2.2563e-01],
          [-3.9194e-01, -1.0662e-01, -4.4437e-01,  ...,  3.3369e-01,
            8.4527e-02, -1.3097e+00],
          ...,
          [ 8.8662e-01,  7.6803e-01,  3.3451e-01,  ...,  3.5352e-01,
            1.0496e+00,  7.8214e-01],
          [ 2.9224e-01, -5.0483e-01, -9.6146e-01,  ..., -1.7164e+00,
           -5.2491e-01, -1.1305e+00],
          [ 5.3469e-01,  1.4477e-01,  3.3255e-01,  ..., -2.2812e-01,
            3.2807e-01,  5.1115e-01]],

         [[ 5.6532e-02,  2.6445e-01,  9.3625e-01,  ...,  3.1135e+00,
            1.8264e+00, -1.5435e+00],
          [ 3.7219e-01, -1.9246e-01, -1.0503e-01,  ...,  1.1640e+00,
           -8.2183e-01, -9.0728e-01],
          [ 1.2188e+00, -1.3463e+00, -9.1112e-01,  ...,  4.0153e-01,
           -6.2391e-02, -2.0096e-01],
          ...,
          [ 6.9667e-01,  5.1313e-02,  8.2127e-01,  ..., -5.4653e-01,
            2.9229e-01, -2.1096e-01],
          [ 4.3281e-01,  1.5642e-01, -7.2910e-01,  ...,  1.0747e-01,
           -9.2597e-02,  6.9136e-01],
          [-4.2650e-01,  4.3437e-01, -6.1841e-01,  ..., -2.7288e-01,
           -2.2879e+00, -2.0530e+00]],

         [[ 4.6594e-01, -4.2089e-01,  4.2405e-01,  ..., -2.0306e+00,
           -6.3301e-01,  1.5008e+00],
          [-1.1737e+00,  1.6536e-01,  1.3177e+00,  ...,  1.5043e+00,
            4.4124e-01, -9.6321e-02],
          [-8.3223e-01, -8.3628e-01, -6.8036e-01,  ...,  6.0864e-01,
            4.7836e-01, -5.6300e-01],
          ...,
          [-8.6715e-01,  5.5698e-01, -6.0210e-01,  ..., -3.4453e-01,
            1.9895e+00,  2.0493e+00],
          [-1.3777e-01, -7.3046e-02,  5.0927e-01,  ...,  9.5767e-03,
            2.3491e+00,  7.5771e-01],
          [-1.1936e-01, -6.7813e-01,  7.4842e-01,  ...,  6.5036e-01,
            4.5451e-01, -3.0767e-01]]]])}, 'action': tensor([[5],
        [9],
        [0],
        [1]]), 'reward': tensor([[0.7465],
        [0.3932],
        [0.5916],
        [0.7785]]), 'done': tensor([False, False, False, False])}

We can see that the function that processes the structure (here, the function stack) has to be implemented in full by hand. The resulting code is not very clear, and because the supported types are hard-coded, supporting more data types (such as integers) requires modifying the function specifically for each of them.

Stack With TreeTensor API

The same workflow can be implemented with the treetensor API, as shown in the code below.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

import treetensor.torch as ttorch

B = 4  # batch size: number of items generated and stacked below


def get_item():
    """Create one random sample as a nested dict.

    Returns:
        A dict with the keys ``'obs'`` (a nested dict holding a ``(12,)``
        ``'scalar'`` tensor and a ``(3, 32, 32)`` ``'image'`` tensor),
        ``'action'`` (one random integer in ``[0, 10)``), ``'reward'``
        (one random float in ``[0, 1)``) and ``'done'`` (always ``False``).
    """
    # The tensors are drawn in the same order as the keys are inserted.
    observation = dict(
        scalar=torch.randn(12),
        image=torch.randn(3, 32, 32),
    )
    sample = dict(obs=observation)
    sample['action'] = torch.randint(0, 10, size=(1,))
    sample['reward'] = torch.rand(1)
    sample['done'] = False
    return sample


# Build a batch of B independently generated nested-dict samples.
data = [get_item() for _ in range(B)]
# execute `stack` op: convert each nested dict with ttorch.tensor, then stack the batch in one call
data = [ttorch.tensor(d) for d in data]
stacked_data = ttorch.stack(data, dim=0)
# validate
print(stacked_data)
# Leaves are reached by attribute access; each has a leading batch dimension of size B.
assert stacked_data.obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
# The Python bool flags end up as a bool tensor with one entry per item.
assert stacked_data.done.shape == (B,)
assert stacked_data.done.dtype == torch.bool

The output should look like the following, and all of the assertion statements pass as well.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
<Tensor 0x7f4c0c9d3390>
├── 'action' --> tensor([[5],
│                        [1],
│                        [2],
│                        [3]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7f4c0caa0510>
│   ├── 'image' --> tensor([[[[ 8.1806e-02, -2.2467e+00,  3.0448e-01,  ...,  1.0109e+00,
│   │                          -2.3672e+00, -1.0418e+00],
│   │                         [-3.7578e-01, -1.6058e+00, -1.1329e+00,  ..., -1.7447e+00,
│   │                          -1.9063e-01, -5.8661e-01],
│   │                         [-8.5886e-01,  3.5040e-01,  7.7742e-01,  ..., -3.0636e-01,
│   │                          -5.2193e-01, -1.3168e+00],
│   │                         ...,
│   │                         [ 2.0698e-01,  6.2280e-01,  7.6723e-01,  ...,  7.4431e-02,
│   │                          -2.7281e-01,  1.0101e+00],
│   │                         [ 2.7868e-01, -2.4541e-02,  9.9881e-04,  ..., -1.4403e+00,
│   │                          -2.1614e+00, -7.3145e-01],
│   │                         [-7.9699e-01, -4.8809e-01,  7.3654e-01,  ..., -8.2619e-01,
│   │                           4.2870e-01, -1.6029e+00]],
│   │               
│   │                        [[ 3.9483e-02,  5.3317e-01, -4.3117e-01,  ...,  2.0204e-01,
│   │                           4.4214e-01, -6.3987e-01],
│   │                         [ 6.3943e-01, -1.6869e+00, -8.6172e-02,  ...,  1.4202e+00,
│   │                          -1.4649e+00, -8.4035e-01],
│   │                         [-1.8227e+00, -7.9037e-01, -1.0401e+00,  ...,  1.5514e+00,
│   │                           2.8870e-01,  4.6759e-02],
│   │                         ...,
│   │                         [ 1.5269e+00,  7.6670e-01,  2.3740e-01,  ...,  4.0158e-01,
│   │                          -4.9379e-01,  4.3831e-02],
│   │                         [ 1.0369e+00, -3.2564e-01, -1.3360e-01,  ...,  5.9646e-02,
│   │                           1.6869e-01,  1.5293e+00],
│   │                         [ 3.8367e-01,  1.6027e+00,  6.2689e-02,  ..., -1.9582e+00,
│   │                           3.6963e-01,  8.6440e-01]],
│   │               
│   │                        [[ 5.9077e-01,  7.9733e-01,  1.2124e+00,  ...,  6.2363e-01,
│   │                          -1.1896e+00, -8.6018e-02],
│   │                         [ 1.3771e+00, -6.6635e-02,  1.9537e+00,  ..., -6.0661e-01,
│   │                          -1.0195e+00,  1.0428e+00],
│   │                         [-2.4200e+00, -8.0138e-01,  9.2924e-01,  ..., -1.1254e+00,
│   │                           1.6013e+00, -1.1984e+00],
│   │                         ...,
│   │                         [-8.2102e-01, -7.2953e-01, -2.3210e-01,  ..., -4.9438e-01,
│   │                           1.2480e+00,  1.0954e+00],
│   │                         [ 9.7913e-01, -3.4900e-01, -1.3639e+00,  ..., -1.3129e+00,
│   │                          -7.1410e-01, -3.9435e-01],
│   │                         [-3.7570e-01,  2.3576e+00,  1.2298e+00,  ..., -1.4837e-01,
│   │                           6.6909e-01, -9.9855e-01]]],
│   │               
│   │               
│   │                       [[[-1.1887e-01, -5.4122e-01,  9.8354e-01,  ..., -2.5723e+00,
│   │                           2.6096e-02, -1.3758e+00],
│   │                         [-1.9846e-01, -1.9940e-01,  4.4985e-01,  ...,  1.0873e+00,
│   │                          -2.3991e-01, -1.8733e+00],
│   │                         [-1.2441e+00, -4.9923e-03,  1.8964e+00,  ...,  1.8982e-01,
│   │                          -2.3382e-01, -5.2221e-01],
│   │                         ...,
│   │                         [ 3.9414e-01,  7.2855e-01,  8.2830e-01,  ..., -4.7457e-01,
│   │                           2.0843e+00, -6.6788e-01],
│   │                         [ 2.2242e-01, -1.9587e+00,  6.4669e-01,  ...,  1.5607e-01,
│   │                           6.3033e-01, -2.1415e+00],
│   │                         [-6.3034e-01, -7.3532e-02, -1.0174e-02,  ...,  1.1528e+00,
│   │                          -9.8129e-01,  6.2031e-01]],
│   │               
│   │                        [[ 2.0721e+00, -7.2781e-01, -5.2350e-01,  ..., -1.0972e+00,
│   │                           7.3554e-01, -1.4081e+00],
│   │                         [-9.3098e-01, -2.0797e+00, -1.6015e-02,  ...,  7.0262e-02,
│   │                           7.1007e-01, -1.2796e+00],
│   │                         [ 1.6380e+00, -6.2613e-01,  7.5009e-01,  ...,  1.2406e+00,
│   │                          -1.9939e-01, -3.9735e-02],
│   │                         ...,
│   │                         [-2.4547e-01, -9.1112e-01, -2.1860e+00,  ...,  1.0461e+00,
│   │                          -1.0128e+00,  1.1512e+00],
│   │                         [-6.0626e-01,  1.6284e+00,  1.8072e+00,  ..., -6.6619e-01,
│   │                           7.4477e-01, -6.9145e-01],
│   │                         [-1.1699e+00,  1.4026e-01,  3.1296e-01,  ...,  7.8478e-01,
│   │                          -1.3807e+00, -5.1316e-01]],
│   │               
│   │                        [[-6.7703e-01,  7.8114e-01, -5.4215e-01,  ..., -5.7191e-01,
│   │                           1.1967e-01,  2.5206e-01],
│   │                         [-3.0890e-01,  7.4094e-01, -6.9470e-01,  ...,  1.6260e+00,
│   │                          -4.8081e-01, -2.5796e-02],
│   │                         [ 2.1091e-01, -6.5780e-01, -7.3678e-01,  ..., -8.8445e-02,
│   │                           7.4454e-01, -5.0253e-01],
│   │                         ...,
│   │                         [ 1.6098e+00,  9.9627e-01,  1.4506e-01,  ...,  7.4524e-01,
│   │                          -5.8891e-01,  4.1910e-01],
│   │                         [ 2.2642e-01, -1.2556e+00, -2.9519e-01,  ..., -8.1372e-01,
│   │                          -5.2390e-01, -6.9954e-01],
│   │                         [-2.3675e+00, -3.6602e-01, -2.8567e-01,  ...,  8.6267e-01,
│   │                          -1.2013e+00, -1.2936e+00]]],
│   │               
│   │               
│   │                       [[[-3.4263e-01,  1.3159e-01,  3.6555e-01,  ...,  1.9203e+00,
│   │                           2.4527e-01,  1.3194e+00],
│   │                         [-4.6519e-01, -1.7776e+00, -5.0229e-01,  ...,  1.1426e+00,
│   │                          -1.5107e-01, -1.0286e+00],
│   │                         [ 5.9309e-01, -1.5659e+00,  1.2265e+00,  ...,  2.4337e-01,
│   │                          -5.2158e-01,  1.0018e+00],
│   │                         ...,
│   │                         [ 4.6418e-01, -1.5557e+00, -5.9542e-01,  ..., -8.2674e-01,
│   │                          -6.8531e-01, -5.2215e-01],
│   │                         [ 1.5428e+00, -8.4671e-01, -5.9348e-01,  ...,  2.0883e+00,
│   │                           6.5298e-02,  3.0662e+00],
│   │                         [ 1.5043e+00, -3.8282e-01, -7.2771e-01,  ..., -1.4024e-01,
│   │                          -2.1325e-01,  1.2760e-01]],
│   │               
│   │                        [[ 1.7851e+00,  1.8010e-01,  3.1680e-02,  ...,  1.9732e-01,
│   │                           4.7090e-01,  2.3229e-01],
│   │                         [-1.8809e+00,  2.1519e+00,  1.3604e+00,  ...,  2.1301e+00,
│   │                           8.4879e-01,  8.7548e-01],
│   │                         [ 2.5895e+00, -4.6578e-01,  1.3226e+00,  ...,  2.5155e+00,
│   │                          -3.6891e-01,  1.6837e+00],
│   │                         ...,
│   │                         [ 1.4207e+00,  5.8059e-01,  6.3880e-01,  ...,  1.2239e+00,
│   │                          -4.8423e-01, -7.4884e-02],
│   │                         [ 1.1180e-01,  1.3894e+00, -6.6145e-01,  ...,  5.9361e-01,
│   │                          -1.8604e+00, -4.4077e-01],
│   │                         [-1.9552e-01,  1.5031e-01,  8.7746e-01,  ..., -6.1427e-01,
│   │                          -4.1117e-01, -1.9545e+00]],
│   │               
│   │                        [[-5.9359e-02, -7.8311e-01,  5.9126e-01,  ...,  7.2944e-01,
│   │                           5.7847e-01,  1.6694e+00],
│   │                         [-3.1214e-01, -2.9955e-01, -1.2139e-01,  ...,  4.0318e-01,
│   │                           3.0014e-01,  9.5816e-02],
│   │                         [-3.4815e-01,  2.3893e+00,  9.2751e-01,  ..., -1.3496e+00,
│   │                          -5.8793e-01, -6.8229e-01],
│   │                         ...,
│   │                         [ 2.1550e+00,  1.2703e+00, -9.0881e-02,  ..., -2.2307e-02,
│   │                           7.3062e-01, -9.6352e-01],
│   │                         [-1.8699e+00,  1.6503e+00,  2.8753e-01,  ..., -3.4384e-01,
│   │                           7.0190e-01,  2.3421e-01],
│   │                         [-9.1674e-02,  3.8146e-01,  1.0559e+00,  ...,  2.3662e-01,
│   │                          -1.2357e-02,  5.7720e-01]]],
│   │               
│   │               
│   │                       [[[-1.2671e-01,  1.2837e+00,  1.9911e+00,  ...,  7.1862e-01,
│   │                          -1.4210e-01,  3.5973e-01],
│   │                         [ 3.1652e-02,  3.3393e-01,  5.9773e-01,  ...,  2.2100e+00,
│   │                           5.9873e-03, -1.9354e-01],
│   │                         [-1.8264e+00,  5.6837e-01,  9.2700e-01,  ..., -7.4586e-01,
│   │                           4.9251e-02,  5.9342e-01],
│   │                         ...,
│   │                         [ 1.1316e+00, -4.7306e-01, -1.2746e+00,  ..., -1.2610e+00,
│   │                          -2.9877e-01, -3.1842e-01],
│   │                         [-5.9445e-01,  6.9342e-01,  2.9963e-01,  ...,  3.0733e-01,
│   │                           6.3283e-01,  1.0808e+00],
│   │                         [-8.3860e-01, -3.1658e-01, -5.6023e-01,  ...,  1.0421e+00,
│   │                          -5.0113e-01, -1.3363e-01]],
│   │               
│   │                        [[-7.2741e-01,  6.9167e-01, -3.0078e-01,  ..., -1.0068e+00,
│   │                          -2.1929e-01, -1.2627e+00],
│   │                         [-1.3662e+00, -8.5414e-01,  1.2524e+00,  ...,  1.2622e+00,
│   │                          -2.1934e-01, -4.3038e-01],
│   │                         [ 1.8290e+00, -8.3026e-01,  3.1194e-03,  ...,  6.0688e-01,
│   │                           9.5129e-01, -1.0989e+00],
│   │                         ...,
│   │                         [ 3.2719e-01,  2.2221e-01, -9.5300e-01,  ..., -3.6650e-01,
│   │                          -3.0580e-01, -1.7415e+00],
│   │                         [-3.5989e-01,  7.0112e-01,  4.5027e-01,  ...,  8.3663e-01,
│   │                          -3.3003e-01,  6.1198e-01],
│   │                         [-6.3113e-01, -4.3319e-01,  4.4508e-01,  ..., -1.4889e+00,
│   │                          -1.6057e+00, -9.3383e-01]],
│   │               
│   │                        [[-1.6817e-01,  8.5331e-01,  1.7388e-01,  ..., -4.7345e-01,
│   │                          -2.9480e-01, -8.1398e-03],
│   │                         [-8.3057e-02, -1.1261e+00,  7.9436e-02,  ..., -2.0992e-01,
│   │                          -2.1231e+00,  1.4430e+00],
│   │                         [-9.0368e-01,  1.6929e+00, -1.1819e-01,  ..., -2.8050e-01,
│   │                          -1.1694e+00, -1.1560e-01],
│   │                         ...,
│   │                         [-3.3417e-01, -1.0609e+00,  3.8088e-01,  ..., -6.8374e-01,
│   │                          -7.8031e-01, -6.2998e-01],
│   │                         [-1.4974e+00,  1.6499e+00, -2.4728e+00,  ..., -1.5842e+00,
│   │                           1.1440e+00, -3.5524e-01],
│   │                         [ 9.5896e-01, -7.5106e-01, -6.6478e-01,  ...,  4.8918e-01,
│   │                          -6.4041e-01, -5.0692e-01]]]])
│   └── 'scalar' --> tensor([[-0.1070,  0.1540,  0.3935, -0.0139, -0.9102, -0.4365, -0.5973, -0.8943,
│                             -0.2632, -0.9520, -0.3040, -1.0482],
│                            [-0.6612,  0.4418, -0.4626,  0.9469,  0.4252,  1.4546,  1.0041, -1.1637,
│                              0.4782,  0.3539, -0.8206, -0.2705],
│                            [-0.2826, -0.5210,  1.2588, -1.4922, -1.2991, -1.0265, -0.3590, -0.3966,
│                             -0.4347,  0.4519, -0.6855, -0.4290],
│                            [-0.9282,  0.4347,  0.5183,  0.1083, -0.8681,  1.6350, -1.0606,  2.5068,
│                              1.5937,  1.4252,  0.5863, -0.1068]])
└── 'reward' --> tensor([[0.1657],
                         [0.0856],
                         [0.7046],
                         [0.6817]])

This code looks much simpler and clearer.