Stack Structured Data

When tensors are organized in a tree structure, they often need to be stacked together, just as torch.stack() does for plain tensors.

Stack With Native PyTorch API

Here is a typical implementation using the native PyTorch API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch

# Batch size: the number of sample items that are generated and then stacked.
B = 4


def get_item():
    """Create one randomly filled sample as a nested dict.

    The result holds an ``'obs'`` sub-dict (a 12-element ``'scalar'`` tensor
    and a 3x32x32 ``'image'`` tensor), a one-element integer ``'action'``
    drawn from ``[0, 10)``, a one-element ``'reward'`` drawn with
    ``torch.rand``, and a plain Python ``False`` under ``'done'``.
    """
    # The random draws are made in the same order as the keys appear in the
    # returned dict, so a fixed seed yields the same sample.
    observation = {
        'scalar': torch.randn(12),
        'image': torch.randn(3, 32, 32),
    }
    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)
    return {
        'obs': observation,
        'action': action,
        'reward': reward,
        'done': False,
    }


# Build the batch: a list of B independently generated nested-dict samples.
data = [get_item() for _ in range(B)]


# execute `stack` op
def stack(data, dim):
    """Recursively stack a list of identically structured items.

    The type of the first item decides how the whole list is handled:

    * ``torch.Tensor`` -- the list is passed to ``torch.stack`` along ``dim``.
    * ``dict`` -- for every key of the first item, the values under that key
      are gathered from all items and stacked recursively.
    * ``bool`` -- the list is packed into a ``torch.BoolTensor`` (``dim`` is
      not used in this case).

    Args:
        data: Non-empty list of items sharing the same structure.
        dim: Dimension forwarded to ``torch.stack`` for tensor leaves.

    Returns:
        A tensor, or a dict with the same keys as the first item whose values
        are the stacked results.

    Raises:
        TypeError: If the first item is not a tensor, dict, or bool.
    """
    first = data[0]

    if isinstance(first, torch.Tensor):
        return torch.stack(data, dim)

    if isinstance(first, dict):
        stacked = {}
        for key in first:
            # Collect this field from every sample, then recurse into it.
            stacked[key] = stack([sample[key] for sample in data], dim)
        return stacked

    if isinstance(first, bool):
        return torch.BoolTensor(data)

    raise TypeError("not support elem type: {}".format(type(first)))


# Stack the B samples field by field along dimension 0.
stacked_data = stack(data, dim=0)
# validate
print(stacked_data)
# Each tensor field is expected to gain a leading dimension of size B.
assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
# The Python bool flags are expected to end up in a 1-D bool tensor of length B.
assert stacked_data['done'].shape == (B,)
assert stacked_data['done'].dtype == torch.bool

The output should look like the following, and all of the assertion statements should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
{'obs': {'scalar': tensor([[ 0.0684,  0.5900,  0.7682,  0.9767, -0.1953,  1.0044,  0.8785,  1.3354,
          0.0172,  0.5675,  1.4946, -0.7357],
        [ 0.6647,  1.3862, -0.1558, -2.2917, -0.5581,  0.5269, -0.7996,  0.1080,
         -2.0253, -0.3329,  2.2568,  1.0306],
        [-0.5710, -0.5738, -0.1536, -0.5098,  2.1608, -0.1507, -1.7210, -1.4670,
          0.2884,  2.2689, -0.7243,  1.1279],
        [ 2.6774,  0.4147,  0.0980, -2.2475, -0.2817, -0.8208,  0.5228, -0.2241,
         -0.3283,  1.3230, -0.0200,  0.3638]]), 'image': tensor([[[[-4.7554e-01, -7.8893e-01,  6.1923e-01,  ..., -4.1218e-01,
           -3.2837e-01,  1.3920e-01],
          [ 8.4497e-01, -3.7064e-01,  2.1724e-01,  ..., -4.1546e-02,
           -2.1753e+00, -3.2035e+00],
          [-1.5135e+00,  5.0581e-01, -9.3895e-01,  ...,  1.1093e+00,
            4.7493e-03,  2.1973e-01],
          ...,
          [ 6.6788e-01, -6.5239e-01,  2.6905e-02,  ..., -7.0750e-01,
            1.0511e+00,  2.9874e-01],
          [-7.4152e-01, -1.4118e+00, -1.0775e+00,  ...,  1.8392e-01,
            3.1325e-01, -3.4358e-01],
          [ 1.1440e+00,  9.1121e-02, -2.1180e+00,  ...,  1.6172e+00,
           -8.7952e-01,  1.1219e+00]],

         [[ 8.4060e-01, -9.2141e-01, -8.0702e-02,  ..., -3.7814e-01,
           -8.3275e-01, -1.8125e+00],
          [-1.6505e-01,  1.6130e-01,  9.7864e-01,  ...,  4.4833e-01,
           -3.7006e-02,  1.1522e+00],
          [ 2.3753e-02, -4.2124e-01,  6.9132e-01,  ...,  2.1129e+00,
           -3.1379e-01,  1.3926e-01],
          ...,
          [ 3.2114e-01,  9.8572e-01,  1.2752e+00,  ...,  8.5630e-01,
            1.1347e+00, -6.3490e-01],
          [-7.0824e-01,  8.0937e-01, -3.2197e-01,  ...,  9.5789e-01,
            3.9017e-01,  5.1746e-01],
          [-4.6875e-01,  5.3735e-01, -2.4773e-01,  ..., -1.4322e+00,
           -2.3715e-01,  1.9818e+00]],

         [[-3.6825e-01,  5.7923e-01, -4.3449e-01,  ...,  1.6023e-01,
            9.2854e-02, -1.5878e+00],
          [-2.1778e-01, -1.0311e+00,  1.1699e+00,  ...,  7.0137e-01,
           -6.3166e-01,  4.4683e-01],
          [-8.8696e-01,  1.8437e+00,  8.1871e-01,  ...,  8.3734e-01,
           -2.2155e-01,  1.1899e+00],
          ...,
          [-1.7772e+00,  6.2551e-01, -9.4602e-01,  ...,  3.9921e-01,
            1.3871e+00, -1.8961e+00],
          [ 6.4974e-01, -7.0245e-01,  1.3713e+00,  ...,  2.7675e+00,
            3.5925e-02, -7.6517e-02],
          [ 8.0325e-01,  3.5958e-01,  5.5824e-02,  ...,  4.0694e-01,
           -1.2469e-01,  5.6514e-01]]],


        [[[-5.3634e-01, -1.4410e+00, -2.4831e-01,  ..., -6.9434e-01,
            1.8838e+00,  8.4324e-01],
          [-6.1656e-02,  3.7122e-01,  1.3806e+00,  ..., -1.4493e+00,
           -2.6998e-01,  1.2833e+00],
          [ 4.4306e-01, -4.2877e-01, -2.4539e+00,  ..., -2.3633e+00,
           -1.2441e+00, -9.9403e-01],
          ...,
          [ 5.2854e-01,  8.4586e-01, -6.2156e-01,  ..., -1.4302e+00,
           -5.5827e-01, -1.1264e+00],
          [-1.4087e+00,  6.6113e-01,  4.8471e-01,  ...,  1.4252e+00,
            7.8959e-01, -8.9088e-01],
          [-1.3083e+00, -9.0095e-01, -1.7217e+00,  ...,  1.3912e+00,
           -1.4644e+00, -3.7026e-01]],

         [[ 1.4495e+00, -1.5762e+00, -3.7459e-01,  ..., -1.2449e+00,
            6.0080e-01,  4.2218e-01],
          [-4.5309e-01, -2.0804e-01,  9.5235e-01,  ...,  2.7149e-01,
            6.2514e-01,  1.8583e+00],
          [ 1.9473e-01,  6.3637e-02, -2.0580e-01,  ...,  2.8095e-01,
            1.6792e+00, -1.3339e+00],
          ...,
          [-4.0345e-01,  2.7497e-01, -8.5166e-01,  ...,  9.4178e-02,
            3.5693e-01,  1.3158e-01],
          [ 2.6755e-01, -4.4425e-01, -1.6471e+00,  ..., -1.0925e-01,
           -1.8283e+00,  6.6745e-01],
          [-1.0658e+00,  1.0233e+00, -1.8240e-01,  ...,  7.8965e-01,
           -1.0671e-01, -4.7651e-01]],

         [[ 8.5835e-01, -4.0449e-01,  1.5172e+00,  ...,  5.2368e-01,
           -6.5865e-01,  8.2022e-02],
          [-8.1244e-01,  1.8165e+00, -3.0293e-03,  ..., -4.4460e-01,
           -4.7775e-01, -5.9445e-01],
          [ 9.0924e-01,  1.0830e+00,  6.1081e-03,  ..., -5.4020e-01,
           -7.3527e-01, -2.7708e-01],
          ...,
          [ 1.5517e+00, -2.5811e+00,  1.0571e-01,  ...,  2.3180e-01,
            7.2628e-02,  8.7154e-02],
          [-8.5432e-01,  1.7800e+00,  1.7164e-01,  ..., -9.7936e-01,
           -1.0239e-02, -9.5831e-01],
          [-1.9329e+00, -1.2611e+00, -8.6700e-01,  ...,  1.0201e+00,
            3.3375e-02,  7.7483e-01]]],


        [[[-3.6356e-01,  1.0173e+00,  4.7969e-01,  ...,  2.6483e-01,
            4.0317e-03,  1.0106e+00],
          [ 8.4685e-01, -7.2727e-01, -1.0917e+00,  ..., -2.0248e-01,
           -2.0098e-01, -1.8366e+00],
          [-2.8834e-01,  5.7593e-01,  1.1658e-01,  ..., -4.5401e-01,
            5.0389e-01, -3.6475e-01],
          ...,
          [-5.5608e-02,  8.4242e-01,  9.5908e-01,  ...,  2.6776e+00,
            3.8514e-01,  5.5698e-01],
          [ 8.3457e-01,  1.5143e+00, -1.4124e-01,  ..., -8.0907e-01,
            3.7562e-01, -3.2794e-02],
          [ 1.1167e+00, -1.0661e+00, -8.2759e-01,  ..., -1.1748e-01,
           -1.7745e+00,  8.4779e-01]],

         [[-9.9380e-01, -8.0596e-01, -8.9938e-01,  ..., -6.1964e-01,
            8.7959e-01,  1.2139e+00],
          [-1.1894e+00, -2.6745e-01,  5.5676e-02,  ..., -1.7458e+00,
            2.3760e-01,  1.0560e+00],
          [-9.2423e-01,  4.7977e-01,  7.0519e-01,  ...,  4.9398e-01,
           -7.4705e-01,  5.1814e-01],
          ...,
          [ 5.4402e-01,  1.9101e+00, -3.1822e-02,  ..., -5.5197e-02,
           -1.6194e+00, -4.4549e-01],
          [ 3.7633e-01,  5.8654e-01, -1.8900e+00,  ..., -5.5845e-01,
            1.9100e-01, -9.7450e-01],
          [ 1.7141e+00, -1.1306e-01, -1.7087e+00,  ...,  1.1209e+00,
            1.0622e-01, -4.0193e-01]],

         [[ 6.7631e-01, -7.2222e-01, -3.9454e-02,  ...,  1.5154e+00,
           -2.6778e-01, -6.3725e-03],
          [-8.5461e-01, -3.1623e-01, -2.8823e+00,  ..., -6.1211e-01,
            2.5968e+00,  8.5955e-01],
          [-1.3389e+00, -3.0963e-01,  6.0191e-01,  ..., -1.9449e+00,
            1.4050e+00, -5.9887e-02],
          ...,
          [-1.4205e+00, -1.6149e+00, -1.5146e+00,  ..., -2.1635e+00,
            1.4557e+00, -1.5180e+00],
          [ 4.4207e-01, -6.5645e-01, -1.8877e+00,  ...,  1.1305e+00,
            1.7926e+00, -1.8223e+00],
          [-9.8722e-01, -7.6968e-01, -6.0449e-01,  ...,  1.3348e+00,
            5.2510e-01, -3.4434e-01]]],


        [[[ 1.8359e+00,  1.1313e-01, -5.6412e-01,  ..., -6.9564e-01,
            1.1334e-01,  5.4327e-01],
          [ 1.0112e+00,  5.2960e-01, -2.3631e+00,  ..., -3.1935e-01,
           -1.6985e+00,  6.6485e-01],
          [-1.3515e-02, -4.0961e-01,  3.0934e-01,  ..., -2.4570e-01,
            4.6244e-02,  1.6408e+00],
          ...,
          [ 7.0827e-01, -4.8415e-01, -2.9063e-01,  ...,  4.2145e-01,
            4.1186e-01,  1.5838e+00],
          [-9.2618e-01, -1.6249e+00, -3.9775e-01,  ..., -6.1304e-02,
            1.2185e+00, -1.6106e+00],
          [ 6.8409e-01,  1.8687e+00, -8.2373e-01,  ..., -1.2352e+00,
           -7.1869e-01, -7.1370e-01]],

         [[-3.9507e-01,  2.2633e+00,  6.4439e-01,  ...,  1.0119e+00,
            5.0173e-01,  3.1536e-01],
          [ 3.7320e-01,  2.7209e+00,  9.7314e-02,  ..., -4.3044e-02,
            3.4794e-01, -5.1050e-01],
          [-2.4836e+00, -7.2880e-01,  2.5734e-01,  ..., -1.4909e-01,
            3.1253e-01,  5.9529e-01],
          ...,
          [ 6.1399e-01,  2.3557e-01, -9.2694e-01,  ..., -7.8869e-01,
            1.3819e-01,  4.1621e-01],
          [ 1.1132e-01, -7.6805e-01,  1.1877e+00,  ...,  6.4226e-01,
           -9.8559e-01,  9.3547e-02],
          [-4.3711e-02, -2.1611e+00, -4.7005e-01,  ...,  1.0812e+00,
           -5.9348e-01,  1.7551e+00]],

         [[ 7.2081e-01, -5.9835e-01,  5.3106e-02,  ...,  1.2680e+00,
            9.9266e-01,  9.8181e-01],
          [ 7.2749e-01, -1.2197e-02,  7.1126e-01,  ...,  4.6707e-01,
            2.1191e+00, -1.6978e+00],
          [-1.6090e+00, -5.9794e-01,  8.9397e-01,  ..., -2.5594e-02,
           -3.8599e-01,  8.3139e-01],
          ...,
          [-1.7583e+00,  1.6991e+00,  3.2186e-01,  ...,  1.4940e+00,
           -5.7106e-01, -5.6885e-01],
          [ 7.3534e-02,  6.2721e-01, -4.2231e-02,  ..., -6.6588e-01,
           -2.1823e-01, -6.3178e-01],
          [-6.1241e-02, -8.7040e-01,  1.7454e-01,  ..., -1.7895e+00,
           -3.5881e-02, -1.1140e+00]]]])}, 'action': tensor([[6],
        [8],
        [1],
        [6]]), 'reward': tensor([[0.9677],
        [0.6753],
        [0.3547],
        [0.8751]]), 'done': tensor([False, False, False, False])}

As we can see, the function that processes the structure (such as the stack function above) has to be implemented entirely by hand. The code is not very clear, and because the supported types are hard-coded, supporting additional data types (such as integers) requires modifying the function itself.

Stack With TreeTensor API

The same workflow can be implemented with the TreeTensor API, as shown in the code below.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

import treetensor.torch as ttorch

# Batch size: the number of sample items that are generated and then stacked.
B = 4


def get_item():
    """Return a single random sample organized as a nested dict.

    Layout of the result: ``'obs'`` maps to a dict with a 12-element
    ``'scalar'`` tensor and a 3x32x32 ``'image'`` tensor; ``'action'`` is a
    one-element integer tensor drawn from ``[0, 10)``; ``'reward'`` is a
    one-element tensor from ``torch.rand``; ``'done'`` is the Python bool
    ``False``.
    """
    # Draw the values in the order the keys are listed, so that a fixed seed
    # reproduces the same sample.
    scalar_obs = torch.randn(12)
    image_obs = torch.randn(3, 32, 32)
    chosen_action = torch.randint(0, 10, size=(1,))
    obtained_reward = torch.rand(1)

    sample = {
        'obs': {'scalar': scalar_obs, 'image': image_obs},
        'action': chosen_action,
        'reward': obtained_reward,
        'done': False,
    }
    return sample


# Build the batch: a list of B independently generated nested-dict samples.
data = [get_item() for _ in range(B)]
# execute `stack` op
# Pass every nested dict through ttorch.tensor, then stack the results with a
# single ttorch.stack call; no hand-written recursion is needed here.
data = [ttorch.tensor(d) for d in data]
stacked_data = ttorch.stack(data, dim=0)
# validate
print(stacked_data)
# Fields of the stacked result are read with attribute access; each one is
# expected to carry a leading dimension of size B.
assert stacked_data.obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
assert stacked_data.done.shape == (B,)
assert stacked_data.done.dtype == torch.bool

The output should look like the following, and all of the assertion statements should pass as well.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
<Tensor 0x7f12678f0650>
├── 'action' --> tensor([[2],
│                        [9],
│                        [0],
│                        [5]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7f12678f0690>
│   ├── 'image' --> tensor([[[[ 2.7174e-01,  8.4075e-01, -7.9458e-02,  ...,  1.3963e-01,
│   │                           8.8294e-02, -1.2869e+00],
│   │                         [ 1.8013e+00,  8.0568e-01, -1.1037e+00,  ...,  1.5109e+00,
│   │                          -1.2097e-01,  6.2731e-01],
│   │                         [-1.5198e+00,  1.5264e+00, -4.1801e-01,  ..., -1.7145e+00,
│   │                          -1.8279e-01,  1.0262e+00],
│   │                         ...,
│   │                         [ 3.1572e-02, -1.0523e+00,  1.1362e+00,  ...,  5.3862e-02,
│   │                           3.1879e-01, -1.7152e+00],
│   │                         [-3.9797e-01,  2.4824e-02,  2.0390e+00,  ...,  5.7118e-01,
│   │                          -1.0629e+00,  2.9664e-01],
│   │                         [ 2.6700e-01,  2.1804e+00, -4.3854e-01,  ...,  2.2275e+00,
│   │                          -3.8434e-01, -3.7832e-01]],
│   │               
│   │                        [[-1.9528e+00,  3.6321e-01, -8.5363e-01,  ...,  1.8319e+00,
│   │                          -6.8025e-01, -2.4421e-02],
│   │                         [ 5.8760e-01, -4.5020e-01,  9.1451e-01,  ...,  4.6670e-01,
│   │                           4.8890e-02,  5.1474e-01],
│   │                         [-1.1205e+00, -1.1165e+00, -6.7862e-01,  ...,  1.6130e+00,
│   │                          -1.2244e-01,  4.3652e-01],
│   │                         ...,
│   │                         [ 7.9428e-01,  1.0063e+00,  2.2393e-01,  ..., -9.2476e-01,
│   │                           5.4630e-01, -8.5017e-01],
│   │                         [-5.4621e-01, -2.5475e-01, -7.3253e-01,  ..., -1.4202e-01,
│   │                           5.7406e-01, -1.3774e+00],
│   │                         [ 7.9886e-01,  4.9150e-01, -5.7629e-02,  ..., -8.6407e-04,
│   │                           2.4093e-01,  8.8798e-01]],
│   │               
│   │                        [[-1.3801e-02, -1.1603e+00, -1.1130e+00,  ..., -4.2448e-01,
│   │                          -8.8471e-01,  7.8828e-01],
│   │                         [-4.6408e-01, -8.5120e-01,  4.1770e-01,  ..., -8.0185e-01,
│   │                           2.4898e-01, -4.3429e-01],
│   │                         [-1.4288e+00,  1.3095e-01,  1.0402e+00,  ..., -5.9340e-01,
│   │                          -1.0735e+00,  7.9782e-01],
│   │                         ...,
│   │                         [ 4.7282e-01, -1.2754e+00, -2.1945e+00,  ..., -4.7047e-01,
│   │                           2.0249e+00,  4.2164e-01],
│   │                         [ 7.2714e-01,  7.7910e-02, -1.0925e-01,  ...,  1.1044e+00,
│   │                           8.1618e-01,  1.2075e+00],
│   │                         [-2.6204e+00,  4.9593e-01,  2.2340e-01,  ..., -1.3694e+00,
│   │                          -1.3057e+00,  1.6505e+00]]],
│   │               
│   │               
│   │                       [[[ 1.5815e+00, -2.4895e+00,  9.5249e-01,  ...,  2.6635e-01,
│   │                          -2.1858e-01,  7.2233e-01],
│   │                         [ 8.7314e-02,  1.3227e+00, -8.0892e-01,  ..., -8.2966e-02,
│   │                          -2.6443e+00, -3.9960e-01],
│   │                         [ 5.7032e-01,  1.9727e-01,  4.9284e-01,  ..., -1.2727e+00,
│   │                          -1.4547e+00,  3.7392e-01],
│   │                         ...,
│   │                         [ 1.4247e+00,  5.1217e-01,  7.2069e-01,  ...,  7.4686e-01,
│   │                           1.6763e-02,  1.8450e+00],
│   │                         [ 1.0928e+00, -2.7781e-01,  2.3649e-01,  ...,  7.0187e-01,
│   │                          -5.5931e-01,  3.0299e-01],
│   │                         [ 1.2083e+00, -3.5083e-01, -2.3508e-01,  ...,  1.6534e-01,
│   │                          -1.1668e+00,  4.7002e-01]],
│   │               
│   │                        [[ 1.2273e+00,  3.8515e-02, -7.8459e-01,  ..., -1.4852e+00,
│   │                          -4.4358e-01, -6.5553e-01],
│   │                         [-1.5152e-01, -5.3045e-01,  4.4792e-01,  ..., -1.6353e+00,
│   │                          -5.4691e-01, -2.5966e-01],
│   │                         [ 4.9432e-01, -1.2479e+00, -9.5712e-01,  ...,  1.0864e-01,
│   │                           9.7499e-01,  2.5438e-01],
│   │                         ...,
│   │                         [ 1.3159e+00, -1.1940e+00,  3.0967e-01,  ..., -9.4624e-01,
│   │                           1.2925e+00, -5.9210e-01],
│   │                         [ 1.0185e+00, -2.5770e-01,  9.5838e-01,  ...,  9.4455e-01,
│   │                          -2.1570e+00, -1.0822e+00],
│   │                         [-1.0953e+00,  1.4239e+00, -5.2768e-01,  ...,  2.6302e+00,
│   │                           5.6827e-01,  9.7741e-01]],
│   │               
│   │                        [[-2.0424e+00,  3.2737e-02, -4.8023e-01,  ..., -1.8149e+00,
│   │                           9.0617e-01,  1.5241e+00],
│   │                         [ 6.4508e-01,  8.7260e-01, -1.3558e-01,  ..., -5.2893e-01,
│   │                           1.0573e+00, -7.9975e-02],
│   │                         [-4.4610e-01, -9.6510e-01, -7.0041e-02,  ..., -2.9133e-01,
│   │                           1.9978e+00, -7.3373e-01],
│   │                         ...,
│   │                         [ 2.7455e-01, -1.5899e+00, -2.2344e-02,  ...,  1.3424e-01,
│   │                           6.4618e-01, -2.9424e-01],
│   │                         [-4.7321e-01,  4.5975e-01,  8.1262e-01,  ...,  5.4087e-01,
│   │                           4.9447e-01, -3.7355e-01],
│   │                         [ 3.5231e-01, -1.3274e-01, -3.5614e-01,  ...,  3.5276e-02,
│   │                           6.6931e-03, -4.8639e-01]]],
│   │               
│   │               
│   │                       [[[ 8.4379e-01, -2.6520e-01,  1.2273e+00,  ...,  5.5299e-02,
│   │                           3.4515e-01,  1.5519e+00],
│   │                         [-5.2636e-02,  1.1631e+00, -1.4064e+00,  ...,  2.0097e+00,
│   │                           6.0648e-01, -5.3019e-01],
│   │                         [ 1.5707e+00,  1.6587e+00,  7.2872e-01,  ...,  6.0731e-01,
│   │                          -6.7997e-03, -3.0308e-01],
│   │                         ...,
│   │                         [-7.0126e-01,  3.0605e-01,  9.9898e-01,  ...,  6.3144e-01,
│   │                           7.8108e-01,  1.2025e+00],
│   │                         [ 1.0146e+00, -1.1620e-01,  1.4512e+00,  ..., -1.2422e+00,
│   │                          -1.2267e+00,  1.5722e-01],
│   │                         [-8.7043e-01, -1.4461e+00, -1.6991e+00,  ..., -2.7868e-01,
│   │                          -8.2987e-01,  2.4401e+00]],
│   │               
│   │                        [[-5.9967e-01,  1.7147e-01,  5.2208e-02,  ...,  1.4748e+00,
│   │                          -8.3516e-01,  1.3266e+00],
│   │                         [ 8.5094e-01, -7.6562e-01, -8.6355e-01,  ..., -1.0500e+00,
│   │                           2.1243e+00, -9.1288e-01],
│   │                         [ 5.6462e-01, -1.5457e-02,  1.1052e-01,  ...,  2.2953e-01,
│   │                          -4.5519e-01,  1.5854e+00],
│   │                         ...,
│   │                         [ 2.7945e-01, -6.0974e-02, -1.0053e+00,  ..., -4.6412e-01,
│   │                           9.2725e-02, -7.9102e-01],
│   │                         [-5.6040e-01,  4.4604e-01, -9.9043e-01,  ..., -3.6201e-01,
│   │                           1.4416e+00,  9.3231e-01],
│   │                         [ 2.9608e+00,  3.6448e-01,  3.0193e-02,  ..., -4.8420e-01,
│   │                          -3.3271e-01,  5.9446e-01]],
│   │               
│   │                        [[ 3.8868e-01,  3.6463e-01,  1.6857e+00,  ..., -6.1569e-01,
│   │                           6.9278e-01,  4.4910e-01],
│   │                         [ 1.9834e-01,  6.9101e-01, -2.0629e-02,  ..., -1.6444e+00,
│   │                          -1.2178e+00,  3.0970e-01],
│   │                         [-9.0467e-01,  1.1019e+00,  3.9728e-01,  ...,  1.8723e-01,
│   │                           1.5122e-01,  4.4419e-01],
│   │                         ...,
│   │                         [-1.4700e+00,  4.9256e-01,  2.3022e-01,  ..., -1.0663e+00,
│   │                           3.7794e-02,  2.8978e-01],
│   │                         [-8.9654e-01, -4.4096e-01, -1.2118e+00,  ...,  2.3538e-01,
│   │                           4.7474e-01, -7.4461e-01],
│   │                         [ 3.8790e-01,  2.6374e-01, -1.7780e+00,  ..., -8.9639e-01,
│   │                           1.2213e+00,  1.1863e+00]]],
│   │               
│   │               
│   │                       [[[ 2.4322e-02, -2.5743e+00,  5.1638e-01,  ..., -3.6933e-02,
│   │                           2.6479e-01, -7.5655e-01],
│   │                         [ 8.8832e-01,  8.5004e-01,  2.9753e-01,  ...,  1.1076e+00,
│   │                           1.4869e+00,  1.3190e+00],
│   │                         [-2.6850e-01,  2.1383e+00,  4.3644e-01,  ...,  4.9856e-01,
│   │                           9.6244e-01,  1.6387e+00],
│   │                         ...,
│   │                         [-2.4553e-01,  8.6521e-01, -5.6526e-01,  ...,  7.3546e-02,
│   │                          -1.1062e+00,  4.7137e-02],
│   │                         [-2.5494e+00, -1.0082e+00, -5.7467e-01,  ..., -4.3138e-01,
│   │                           1.2186e+00, -1.2407e+00],
│   │                         [ 1.1695e+00,  8.7605e-01, -5.7000e-01,  ...,  1.8175e+00,
│   │                          -3.3591e-02,  9.1747e-01]],
│   │               
│   │                        [[-1.0501e+00, -1.6914e+00,  2.3887e+00,  ...,  4.1096e-01,
│   │                          -1.6464e+00,  5.1187e-01],
│   │                         [-1.8721e+00, -4.8365e-01, -2.1771e+00,  ..., -7.3189e-01,
│   │                          -1.0515e+00, -1.8894e-01],
│   │                         [-5.7640e-01,  2.0585e-01, -1.2769e+00,  ..., -1.7404e+00,
│   │                           1.0246e+00,  1.1252e+00],
│   │                         ...,
│   │                         [ 1.5959e+00, -7.5755e-01,  4.9973e-01,  ..., -2.1143e-01,
│   │                          -1.4349e+00, -6.3097e-01],
│   │                         [-5.5512e-01,  1.7481e+00,  1.7209e-02,  ..., -1.2853e+00,
│   │                           1.4605e+00, -1.6332e-01],
│   │                         [-2.9038e-01,  2.8948e-01, -1.5924e-01,  ..., -1.9852e-01,
│   │                          -3.8356e-01, -1.0769e+00]],
│   │               
│   │                        [[ 8.6237e-02, -6.3570e-01, -9.0865e-01,  ...,  6.9019e-01,
│   │                           2.1939e+00, -8.6537e-02],
│   │                         [-1.7389e+00, -1.7428e+00, -3.2349e-02,  ..., -1.7452e+00,
│   │                          -5.3535e-01, -5.4949e-01],
│   │                         [ 2.3114e-01, -5.8270e-01, -2.7907e-01,  ...,  1.2802e-01,
│   │                          -1.1014e+00,  2.9640e+00],
│   │                         ...,
│   │                         [ 8.0411e-01, -5.8283e-01, -8.2478e-01,  ...,  1.1436e+00,
│   │                          -1.5617e+00, -1.3802e-01],
│   │                         [ 1.5495e+00, -5.6759e-01,  8.5886e-01,  ..., -7.8004e-01,
│   │                           1.6617e+00, -3.0001e-01],
│   │                         [-7.1772e-01, -1.1057e+00, -1.2355e+00,  ...,  9.5935e-01,
│   │                           6.8387e-01,  9.0734e-01]]]])
│   └── 'scalar' --> tensor([[-1.8204, -0.1322,  1.0637, -1.3983, -0.2195,  1.7401,  0.8969, -0.8056,
│                             -0.6063,  0.7722, -0.4887,  0.7388],
│                            [ 1.0119, -0.1385, -0.7511,  0.0861, -1.1991, -0.1082,  1.0137,  0.8894,
│                              0.4636,  0.2413,  1.0350, -1.3930],
│                            [ 0.1317, -1.2091, -0.6377,  0.8724,  0.8124,  1.5408,  0.8843,  0.2093,
│                             -0.0585,  0.0530,  0.0937, -1.0890],
│                            [ 0.0135,  0.2899, -0.5063,  0.8530,  0.8375,  0.4573,  0.3635, -0.8134,
│                              0.6480,  0.6351, -0.8952, -0.4923]])
└── 'reward' --> tensor([[0.6362],
                         [0.0962],
                         [0.2362],
                         [0.2291]])

This code looks much simpler and clearer.