Stack Structured Data

When tensors are organized in a tree structure, they often need to be stacked together, just as torch.stack() does for plain tensors.

Stack With Native PyTorch API

Here is a common implementation using the native PyTorch API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch

# Number of samples to generate and stack (the batch size).
B = 4


def get_item():
    """Build one randomly filled sample as a nested dict.

    Returns:
        dict: ``'obs'`` maps to a nested dict holding a ``(12,)`` tensor
        under ``'scalar'`` and a ``(3, 32, 32)`` tensor under ``'image'``
        (both drawn with ``torch.randn``); ``'action'`` is a ``(1,)``
        integer tensor drawn from ``[0, 10)``; ``'reward'`` is a ``(1,)``
        tensor drawn with ``torch.rand``; ``'done'`` is the plain Python
        bool ``False``.
    """
    # Draw the values in key order so that seeded runs stay reproducible.
    observation = {
        'scalar': torch.randn(12),
        'image': torch.randn(3, 32, 32),
    }
    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)
    return {'obs': observation, 'action': action, 'reward': reward, 'done': False}


# Build a batch of B independently generated samples.
data = [get_item() for _ in range(B)]


# execute `stack` op
def stack(data, dim):
    """Recursively stack a list of identically structured items.

    Args:
        data (list): Non-empty list whose elements are all ``torch.Tensor``,
            all ``dict`` (values are handled recursively, key by key), or
            all ``bool``. The type of the first element selects the branch.
        dim (int): Dimension passed to ``torch.stack`` for tensor leaves.
            It is not used for ``bool`` leaves, which always form a 1-D
            tensor.

    Returns:
        torch.Tensor or dict: A stacked tensor for tensor/bool input, or a
        dict with the keys of the first element whose values are the
        stacked results.

    Raises:
        ValueError: If ``data`` is empty.
        TypeError: If the first element is of an unsupported type.
    """
    if len(data) == 0:
        # Fail with a clear message instead of an opaque IndexError on data[0].
        raise ValueError("cannot stack an empty sequence")

    elem = data[0]
    if isinstance(elem, torch.Tensor):
        return torch.stack(data, dim)
    if isinstance(elem, dict):
        # Gather each key across the batch, then recurse into the sub-structure.
        return {k: stack([item[k] for item in data], dim) for k in elem}
    if isinstance(elem, bool):
        # torch.tensor with an explicit dtype replaces the legacy
        # torch.BoolTensor constructor.
        return torch.tensor(data, dtype=torch.bool)
    raise TypeError("not support elem type: {}".format(type(elem)))


# Stack the B sample dicts along a new leading dimension.
stacked_data = stack(data, dim=0)
# validate
print(stacked_data)
# Every tensor leaf gains a leading dimension of size B.
assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
# The Python bools under 'done' are collected into a 1-D bool tensor.
assert stacked_data['done'].shape == (B,)
assert stacked_data['done'].dtype == torch.bool

The output should look like the following, and all of the assertion statements should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
{'obs': {'scalar': tensor([[-0.9653,  0.2794,  0.6762, -0.5648,  0.1323, -0.2394, -0.7473, -0.8947,
         -0.4574,  1.4945, -0.1608, -0.1810],
        [ 0.0704, -1.1610, -0.2324, -0.0225,  1.0857, -1.0716, -0.9738, -1.6584,
          1.3913,  0.5232,  1.2952, -0.2125],
        [-0.4179, -0.1554,  0.1348, -1.3415,  1.3966,  2.3113,  0.6406,  1.4914,
         -0.2567, -0.6691, -1.3460,  0.0437],
        [ 1.1704, -0.1859, -1.8404, -2.0382,  0.8331,  0.8798,  0.2999,  0.3354,
          0.3455,  0.6102, -0.0515,  0.3824]]), 'image': tensor([[[[-2.1510e-01,  1.2466e+00, -4.2766e-01,  ..., -1.3210e+00,
            9.8189e-01, -1.5625e-01],
          [-1.5442e+00,  7.5932e-01, -6.6570e-01,  ...,  1.3680e+00,
            8.6966e-01,  1.0230e+00],
          [-5.6185e-01,  2.1383e-01,  2.3674e-02,  ..., -9.6401e-02,
           -8.8759e-01, -7.6741e-01],
          ...,
          [-7.1287e-01, -1.8947e+00, -1.1646e+00,  ...,  2.5116e-01,
           -7.6691e-01, -6.3039e-01],
          [-4.0780e-01, -4.8236e-01,  1.3168e-01,  ...,  6.7044e-01,
            1.0694e-01, -1.3816e+00],
          [-3.2997e-01, -2.9355e-01,  1.8204e+00,  ..., -1.2730e-01,
           -1.6598e+00, -2.5747e-01]],

         [[-1.7536e-01, -1.4343e-01, -1.3264e+00,  ...,  1.1017e+00,
           -4.0140e-01, -7.4133e-01],
          [-1.0264e+00, -1.5133e-02, -1.6098e+00,  ..., -1.8364e+00,
            1.1302e-01, -1.6570e+00],
          [ 6.6805e-01, -5.9831e-01, -2.7080e-01,  ..., -2.0288e+00,
           -7.9704e-01,  9.1810e-01],
          ...,
          [-1.8158e+00,  8.4483e-01, -1.2064e+00,  ...,  1.8743e+00,
            7.3134e-01,  5.9182e-01],
          [-3.5575e-01, -7.5124e-01, -1.3447e+00,  ...,  1.1492e+00,
           -9.7551e-02, -3.7493e-01],
          [ 5.3087e-01, -8.4311e-01,  1.0880e+00,  ...,  1.4205e-02,
           -5.7899e-01,  1.4183e+00]],

         [[-8.9715e-01, -2.1203e+00, -2.9776e-01,  ...,  6.7714e-01,
            1.3840e+00, -1.5458e-01],
          [ 2.8555e-01,  3.4980e-01, -1.1201e+00,  ...,  1.7664e-01,
           -5.0757e-01,  1.2395e+00],
          [ 5.9821e-01, -5.4228e-01,  1.1660e-01,  ..., -4.3247e-01,
            2.5472e-01,  6.2107e-01],
          ...,
          [-4.8011e-01,  1.7328e+00,  1.5797e-01,  ..., -7.2143e-01,
           -6.4900e-01,  3.0915e-01],
          [ 5.3445e-01, -2.8293e-01,  4.5129e-01,  ...,  6.8931e-01,
            1.3022e+00, -2.3826e+00],
          [ 5.5744e-01, -8.0627e-01, -7.6569e-01,  ..., -5.8473e-01,
            5.6329e-02,  1.8209e+00]]],


        [[[ 1.7417e+00, -5.2286e-01, -5.7421e-01,  ...,  3.8276e-01,
            5.9477e-01, -6.7463e-01],
          [ 1.0149e+00, -2.8901e-01,  2.7958e-01,  ..., -1.7723e+00,
            9.1930e-01, -2.0273e-01],
          [-2.1279e+00,  8.9646e-02, -3.4087e-01,  ...,  1.2326e-01,
           -1.8085e+00,  1.3335e+00],
          ...,
          [-2.6689e+00, -3.9941e-01, -2.7378e-01,  ...,  1.1493e-01,
            2.7423e-01, -4.4539e-02],
          [-2.5137e+00, -8.8941e-01, -3.4811e-01,  ..., -4.2482e-01,
            3.4490e-01, -4.6287e-01],
          [ 2.2572e-01,  4.2749e-01,  2.8765e-01,  ...,  4.6673e-01,
           -6.3409e-01,  1.8812e+00]],

         [[-3.9139e-01,  2.3599e+00, -2.6551e-01,  ..., -4.4410e-01,
            1.8228e-01, -9.6349e-01],
          [ 9.5802e-01, -3.8816e-01, -3.3238e-02,  ..., -3.4134e-01,
           -1.0318e+00,  1.5829e+00],
          [-4.7366e-01,  7.2926e-01, -2.0550e-02,  ...,  1.0060e+00,
            3.8402e-01, -1.5309e+00],
          ...,
          [ 4.7439e-02,  5.1075e-01, -3.5242e-01,  ...,  1.3405e-01,
           -2.6594e-01,  1.3726e+00],
          [ 2.3895e-01,  2.1416e-01,  8.2366e-01,  ..., -4.1398e-01,
           -9.8388e-03, -1.8883e-01],
          [-9.2074e-01,  5.4405e-01, -5.3393e-01,  ...,  1.0737e+00,
           -7.0875e-01,  1.5392e+00]],

         [[-1.0356e+00,  7.9282e-01, -7.0456e-01,  ..., -1.9830e+00,
           -1.6769e+00, -1.3357e+00],
          [ 1.3427e+00,  1.3562e+00, -1.7503e+00,  ...,  7.3884e-01,
           -6.9080e-01, -1.4351e+00],
          [ 1.4125e-03,  1.7199e+00, -1.4906e+00,  ..., -1.1759e+00,
           -1.9094e+00, -7.7813e-01],
          ...,
          [-1.7033e+00, -1.1843e-01, -5.1239e-02,  ...,  2.2637e+00,
           -6.1274e-01, -5.4958e-01],
          [ 6.2504e-01,  5.6857e-01, -1.9228e+00,  ..., -4.7543e-01,
           -8.3564e-01, -6.0720e-02],
          [-9.4151e-01,  6.6396e-01,  4.0825e-01,  ...,  9.1352e-01,
            1.7016e+00, -8.9271e-01]]],


        [[[-1.8035e-02, -4.1038e-01,  4.6528e-01,  ...,  1.3320e+00,
           -2.2202e-02, -5.6734e-01],
          [ 1.5609e+00, -1.3475e+00, -6.8490e-01,  ..., -1.4487e+00,
           -1.7813e+00, -4.8816e-01],
          [-2.4170e-01, -1.9875e-01, -4.7081e-01,  ...,  5.5468e-01,
            1.5089e+00, -1.7039e+00],
          ...,
          [-2.8433e-01, -1.1331e-01,  3.8021e-01,  ..., -5.2810e-01,
            8.6127e-01,  1.0419e-01],
          [ 4.7962e-01, -3.5945e-01, -1.2636e+00,  ...,  1.0710e-01,
            2.0815e-01, -2.3217e-01],
          [-6.8797e-01, -9.7932e-01, -1.7100e-01,  ...,  2.1285e-01,
            2.0813e+00, -1.1733e-01]],

         [[ 1.6367e+00,  1.7202e-01, -2.6958e-01,  ...,  1.4795e+00,
            1.3829e+00,  1.0713e+00],
          [ 9.4590e-01, -7.3242e-01,  3.0893e-01,  ...,  1.3215e+00,
           -2.4057e+00,  1.4025e+00],
          [-2.9022e-01, -6.9762e-02,  9.9333e-01,  ...,  1.5793e+00,
           -8.0237e-01,  1.0870e+00],
          ...,
          [ 1.5575e+00,  9.3970e-02,  1.5473e+00,  ..., -1.4530e+00,
            1.3246e+00,  5.5236e-01],
          [-2.7165e-01, -6.5844e-01,  2.5957e-01,  ...,  8.3361e-01,
            4.2603e-01, -9.0773e-01],
          [-4.8695e-01, -9.4238e-01, -5.6357e-02,  ..., -1.1497e+00,
           -1.8601e+00,  2.8369e-01]],

         [[-2.2039e-01, -3.2542e-01,  1.3395e+00,  ..., -1.4653e-01,
            5.0896e-01,  1.4688e+00],
          [-3.5080e-01, -1.0138e+00,  1.0627e-01,  ...,  2.0036e-01,
            5.5636e-01,  1.2772e+00],
          [ 3.9791e-01,  8.3535e-01, -1.5602e+00,  ..., -1.4974e+00,
           -1.2321e-01, -5.0440e-01],
          ...,
          [ 3.0972e-01,  7.1728e-01,  3.0622e-01,  ...,  5.4236e-01,
            1.5220e+00, -3.0230e-01],
          [-1.0495e+00,  1.5131e+00, -9.7914e-01,  ...,  1.4621e+00,
            6.8744e-01,  1.3985e+00],
          [-2.1085e+00, -5.4347e-01, -5.1850e-01,  ...,  3.2091e+00,
           -1.6795e+00,  5.8800e-01]]],


        [[[-1.6973e+00,  1.4300e+00, -8.4846e-01,  ..., -2.9478e-01,
            8.3313e-01, -8.2913e-01],
          [-7.3278e-01,  4.6506e-02,  2.6129e-01,  ..., -1.6276e+00,
           -1.2436e+00, -2.6591e-01],
          [ 6.9983e-01,  8.3662e-01, -2.2950e-01,  ..., -1.3067e+00,
           -3.1848e-01,  1.2753e+00],
          ...,
          [-1.5203e+00, -2.3257e+00,  7.1926e-01,  ..., -3.5604e-01,
           -7.6428e-01,  2.8162e-01],
          [ 1.1173e+00,  1.5596e-01,  3.2165e-02,  ..., -1.7895e+00,
            1.5967e+00, -1.7299e+00],
          [ 1.2477e+00,  2.9796e-01,  4.2369e-02,  ...,  8.9054e-01,
            8.2655e-01,  6.7353e-01]],

         [[-2.6982e+00,  2.2135e+00, -1.3709e+00,  ..., -2.3079e-01,
            1.9187e+00,  3.7689e-01],
          [-5.6415e-01, -1.3818e+00,  7.2952e-01,  ..., -3.4655e-01,
           -1.6244e+00, -6.8353e-01],
          [-7.6714e-01, -5.7190e-02, -1.6892e-01,  ...,  1.0527e+00,
            6.3541e-01,  1.2930e+00],
          ...,
          [ 9.4460e-01, -2.5814e-01, -3.7332e-01,  ..., -3.0850e-01,
           -1.8962e+00, -1.5938e-01],
          [-6.6908e-01,  2.8770e-02,  2.4292e-02,  ..., -5.9262e-01,
            7.8031e-01, -9.3813e-01],
          [-4.1122e-01,  7.5999e-01,  2.5141e-01,  ..., -1.9138e+00,
           -4.3660e-01, -7.7377e-02]],

         [[ 1.6859e-01,  7.6200e-01,  8.0001e-02,  ...,  8.0774e-01,
            4.5494e-01, -2.2144e-01],
          [-2.8631e-01,  3.2273e-01, -1.0323e+00,  ..., -2.0587e-01,
            1.3576e+00, -4.7716e-01],
          [-8.4524e-02,  9.3652e-01, -1.1412e+00,  ..., -1.6467e+00,
            1.6534e-01, -7.1401e-01],
          ...,
          [-4.5254e-01,  1.1064e+00,  1.9975e-01,  ..., -5.8709e-01,
           -3.1495e-01, -1.4034e+00],
          [-9.1621e-01, -1.0344e+00, -1.1503e+00,  ..., -8.7694e-01,
           -2.8796e-01,  6.2142e-01],
          [ 1.0341e+00,  1.2766e+00, -6.4451e-01,  ..., -2.4392e-01,
           -6.7707e-01,  5.8827e-01]]]])}, 'action': tensor([[7],
        [2],
        [2],
        [7]]), 'reward': tensor([[0.3620],
        [0.3212],
        [0.8445],
        [0.6204]]), 'done': tensor([False, False, False, False])}

As we can see, the function that processes the structure (such as the stack function above) has to be implemented entirely by hand. The resulting code is not very clear, and because the supported types are hard-coded, supporting more data types (such as integers) requires special modifications to the function.

Stack With TreeTensor API

The same workflow can be implemented with the treetensor API, as shown in the code below.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

import treetensor.torch as ttorch

# Number of samples to generate and stack (the batch size).
B = 4


def get_item():
    """Build one randomly filled sample as a nested dict.

    Returns:
        dict: ``'obs'`` maps to a nested dict holding a ``(12,)`` tensor
        under ``'scalar'`` and a ``(3, 32, 32)`` tensor under ``'image'``
        (both drawn with ``torch.randn``); ``'action'`` is a ``(1,)``
        integer tensor drawn from ``[0, 10)``; ``'reward'`` is a ``(1,)``
        tensor drawn with ``torch.rand``; ``'done'`` is the plain Python
        bool ``False``.
    """
    # Draw the values in key order so that seeded runs stay reproducible.
    observation = {
        'scalar': torch.randn(12),
        'image': torch.randn(3, 32, 32),
    }
    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)
    return {'obs': observation, 'action': action, 'reward': reward, 'done': False}


# Build a batch of B independently generated samples.
data = [get_item() for _ in range(B)]
# execute `stack` op
# Wrap each nested dict with ttorch.tensor, then stack the whole batch with a
# single ttorch.stack call.
data = [ttorch.tensor(d) for d in data]
stacked_data = ttorch.stack(data, dim=0)
# validate
print(stacked_data)
# Leaves are reached by attribute access; each gains a leading dimension of size B.
assert stacked_data.obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
# The Python bools under 'done' end up as a 1-D bool tensor.
assert stacked_data.done.shape == (B,)
assert stacked_data.done.dtype == torch.bool

The output should look like the following, and all of the assertion statements should pass here as well.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
<Tensor 0x7fd1cdae8410>
├── 'action' --> tensor([[4],
│                        [4],
│                        [7],
│                        [6]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7fd1cdae8450>
│   ├── 'image' --> tensor([[[[ 1.1120e+00,  1.0652e+00, -1.8711e-02,  ..., -1.4274e+00,
│   │                          -8.9094e-01, -1.2164e+00],
│   │                         [-7.3942e-01,  3.3305e-01,  7.6887e-01,  ..., -2.5532e+00,
│   │                           2.1153e+00, -2.1781e-01],
│   │                         [-1.4291e+00,  4.2130e-01, -1.9371e+00,  ...,  1.5247e+00,
│   │                           4.3569e-01, -8.3170e-01],
│   │                         ...,
│   │                         [-1.0854e+00, -4.9489e-01,  6.5617e-01,  ..., -1.2792e-01,
│   │                          -7.5047e-01,  3.9402e-01],
│   │                         [-3.8303e-01,  1.2912e+00,  5.9583e-01,  ...,  5.6579e-01,
│   │                           9.8079e-01,  2.3406e-01],
│   │                         [ 2.9173e-01, -1.1586e+00,  9.4326e-01,  ..., -4.6902e-01,
│   │                           1.7224e+00, -7.0382e-02]],
│   │               
│   │                        [[-1.1771e+00,  2.1115e-01, -1.2133e-01,  ...,  1.0395e-01,
│   │                           3.2609e-02,  6.4315e-01],
│   │                         [-5.6029e-01, -1.6556e-01, -1.5280e-01,  ..., -1.3250e+00,
│   │                           1.8436e+00,  8.9761e-01],
│   │                         [ 1.1372e+00, -3.2635e-01, -2.8189e-01,  ...,  1.0379e+00,
│   │                          -4.7714e-01,  1.1621e+00],
│   │                         ...,
│   │                         [ 2.1081e-01,  1.0286e+00, -1.3434e+00,  ..., -1.5022e+00,
│   │                          -7.8666e-01, -3.8443e-01],
│   │                         [ 6.2958e-01,  4.8181e-01,  4.5918e-01,  ...,  5.5505e-02,
│   │                           6.8336e-01,  5.9625e-01],
│   │                         [ 6.1970e-01,  1.8820e+00,  9.0758e-01,  ..., -1.1916e-01,
│   │                           5.9577e-01, -6.0466e-01]],
│   │               
│   │                        [[ 4.6491e-01, -1.3628e+00, -2.6461e-01,  ...,  1.1364e+00,
│   │                          -2.4842e-01, -6.0490e-01],
│   │                         [ 8.1897e-01, -7.8106e-01, -7.4445e-01,  ...,  1.5070e+00,
│   │                           7.1418e-01, -4.3707e-01],
│   │                         [-1.1498e+00, -7.1178e-01, -7.6401e-01,  ...,  1.1338e+00,
│   │                          -7.2540e-01,  2.2130e-01],
│   │                         ...,
│   │                         [ 1.9028e+00, -2.2978e-01,  7.4112e-01,  ..., -5.0386e-01,
│   │                          -9.6476e-03, -5.2327e-01],
│   │                         [ 9.3059e-01,  8.6720e-03, -5.7161e-01,  ...,  3.1230e+00,
│   │                          -3.3930e-01, -2.9262e+00],
│   │                         [-2.0024e+00, -4.2776e-01,  1.0159e+00,  ...,  7.2031e-01,
│   │                          -8.4929e-01,  6.7192e-01]]],
│   │               
│   │               
│   │                       [[[-7.4946e-01,  6.2085e-03, -9.3155e-01,  ..., -1.2886e+00,
│   │                           8.1089e-01,  1.4471e+00],
│   │                         [-1.7379e-01,  3.3838e-01,  3.5747e-01,  ..., -8.1911e-01,
│   │                          -5.0463e-01,  1.0396e+00],
│   │                         [-6.0061e-02,  2.2164e+00, -1.3818e+00,  ..., -7.1555e-01,
│   │                          -1.5584e-01, -6.2739e-02],
│   │                         ...,
│   │                         [-1.4002e+00,  2.0557e-01,  1.3711e+00,  ..., -5.7885e-01,
│   │                           1.2891e-01, -4.4446e-01],
│   │                         [ 1.2462e+00, -1.1010e-01, -5.6955e-01,  ..., -4.9456e-02,
│   │                           4.5143e-01,  4.3211e-01],
│   │                         [-1.2503e+00, -1.6337e+00, -9.4521e-01,  ..., -8.2510e-01,
│   │                          -5.4567e-02,  1.8762e+00]],
│   │               
│   │                        [[-1.1783e-01, -9.6168e-01, -1.7166e-01,  ...,  3.8140e-01,
│   │                           3.6628e-01,  7.7986e-01],
│   │                         [-8.9988e-01,  3.7112e-01,  4.7687e-01,  ..., -1.3090e+00,
│   │                           1.2695e-02, -8.3217e-01],
│   │                         [ 4.8216e-01,  5.2280e-01, -3.9670e-01,  ..., -6.6901e-01,
│   │                           9.4368e-01, -1.0293e+00],
│   │                         ...,
│   │                         [-2.1858e-01,  7.0027e-01,  1.5163e+00,  ..., -1.0875e+00,
│   │                           1.2678e+00,  7.5929e-01],
│   │                         [-1.9335e-01,  1.4673e-01, -8.8923e-01,  ..., -8.3066e-01,
│   │                          -1.4487e-01,  2.1643e-01],
│   │                         [ 2.8073e-02,  6.3094e-01,  6.3708e-02,  ...,  1.3551e+00,
│   │                           1.6452e+00, -6.6140e-01]],
│   │               
│   │                        [[ 4.9463e-03, -1.9623e-01, -4.9203e-01,  ...,  6.7779e-01,
│   │                          -8.4536e-01, -1.4163e+00],
│   │                         [ 2.0980e+00,  5.9004e-01,  2.0284e-01,  ...,  1.0863e+00,
│   │                          -1.8685e+00, -5.3584e-01],
│   │                         [-4.8914e-01, -9.2731e-01, -3.0998e-01,  ..., -1.7269e-01,
│   │                          -9.2906e-01,  1.1111e+00],
│   │                         ...,
│   │                         [-1.7144e-01, -1.1823e+00,  1.0034e+00,  ...,  8.7030e-01,
│   │                           3.2317e-02, -4.7965e-01],
│   │                         [ 1.6552e-01,  8.0694e-01,  3.4302e-01,  ...,  5.4761e-01,
│   │                           1.4254e-01,  1.8610e+00],
│   │                         [ 8.5434e-01,  2.6667e-02,  1.4795e+00,  ...,  4.7420e-01,
│   │                          -2.6851e-01,  6.3540e-01]]],
│   │               
│   │               
│   │                       [[[ 1.1161e-01,  2.7600e-01, -9.0501e-01,  ...,  8.5278e-01,
│   │                          -5.7198e-01, -1.0580e+00],
│   │                         [ 2.9975e-02, -2.1246e-02, -6.6952e-01,  ..., -1.1615e+00,
│   │                           1.6845e+00,  3.5386e-01],
│   │                         [-1.0669e+00,  1.2259e-01,  8.0363e-01,  ..., -3.4948e-01,
│   │                          -1.5478e+00, -1.0355e+00],
│   │                         ...,
│   │                         [-4.9419e-01,  7.2132e-01,  1.5403e-01,  ...,  5.5404e-01,
│   │                          -4.1560e-01, -2.9129e-01],
│   │                         [ 1.1063e+00,  2.5934e-01, -2.3085e+00,  ...,  7.3961e-01,
│   │                          -5.3484e-02,  1.2675e+00],
│   │                         [-9.7389e-01, -3.4386e-02, -2.0944e+00,  ..., -6.9170e-01,
│   │                          -5.5973e-01,  8.8102e-01]],
│   │               
│   │                        [[-3.2493e-01,  1.7689e+00, -7.4987e-01,  ...,  2.1188e+00,
│   │                           2.8330e-01, -8.6761e-01],
│   │                         [ 1.3041e+00,  2.3690e+00,  1.7224e-01,  ..., -2.3055e-01,
│   │                           1.1557e+00,  1.5670e+00],
│   │                         [-1.0137e+00, -4.6392e-01,  1.0091e+00,  ..., -7.9436e-01,
│   │                           1.3387e+00, -1.4808e+00],
│   │                         ...,
│   │                         [-1.1166e+00, -1.8199e+00, -6.0483e-01,  ..., -1.2928e+00,
│   │                           6.5236e-01, -4.3333e-01],
│   │                         [ 8.1618e-01, -3.3065e-01,  4.6522e-01,  ...,  1.4188e-01,
│   │                          -2.6101e-01, -9.2924e-01],
│   │                         [ 9.4463e-01, -2.1671e-01, -1.5293e+00,  ..., -4.5051e-01,
│   │                           2.0618e-01,  2.0891e+00]],
│   │               
│   │                        [[-3.7359e-01, -1.6675e+00,  4.2610e-01,  ..., -4.6969e-01,
│   │                          -8.1186e-01,  4.1536e-01],
│   │                         [ 5.8391e-01, -3.0582e-01,  2.1031e-01,  ..., -2.3754e-01,
│   │                           1.9116e+00, -2.0607e+00],
│   │                         [ 1.3587e+00, -1.0052e-01,  1.1926e+00,  ...,  7.2726e-01,
│   │                          -3.3619e-01,  7.4082e-01],
│   │                         ...,
│   │                         [ 1.7260e+00, -4.0726e-01, -3.2122e-01,  ...,  3.7223e-01,
│   │                          -1.2357e-01, -7.4051e-01],
│   │                         [ 6.4691e-01,  1.6219e+00, -1.4099e+00,  ..., -1.1589e+00,
│   │                          -1.7679e-01, -1.4607e-01],
│   │                         [ 3.5256e-01,  1.6297e+00, -1.7844e+00,  ..., -3.6691e-01,
│   │                           8.3400e-02, -3.7971e-01]]],
│   │               
│   │               
│   │                       [[[ 1.6298e+00,  1.0045e+00, -7.1409e-01,  ..., -1.1313e+00,
│   │                          -2.6969e-02,  1.2998e+00],
│   │                         [ 1.4353e+00,  2.5983e+00,  9.2646e-01,  ..., -1.0482e+00,
│   │                          -1.8050e+00, -6.6460e-01],
│   │                         [ 6.2836e-01, -3.0142e-03, -4.2502e-01,  ..., -7.5916e-01,
│   │                          -3.7153e-01, -6.8018e-01],
│   │                         ...,
│   │                         [ 5.9002e-01, -8.7105e-01,  1.0054e-01,  ..., -1.5220e-02,
│   │                           1.0148e+00, -6.5861e-01],
│   │                         [ 1.5776e-01,  9.4041e-01, -4.1541e-01,  ..., -2.4471e-01,
│   │                           1.1865e+00, -9.9986e-01],
│   │                         [ 8.7208e-01,  2.3051e+00, -2.3872e-01,  ..., -3.0373e-01,
│   │                          -1.4570e+00, -4.5589e-01]],
│   │               
│   │                        [[ 1.5979e-01,  6.4252e-01,  3.4811e-01,  ...,  6.7810e-01,
│   │                          -1.9442e+00,  1.6983e-01],
│   │                         [ 1.3424e+00, -2.7060e-01,  1.0343e+00,  ..., -5.8269e-02,
│   │                           5.3868e-03,  2.1271e-01],
│   │                         [ 3.5006e-01,  1.2138e+00,  3.3074e-01,  ...,  1.6759e+00,
│   │                           8.2739e-01,  1.3642e+00],
│   │                         ...,
│   │                         [ 4.7256e-01, -5.4217e-01,  1.7902e+00,  ..., -6.1768e-01,
│   │                           1.7588e+00,  2.2403e+00],
│   │                         [ 5.0427e-01,  1.3930e-01, -1.2398e+00,  ...,  9.6154e-01,
│   │                          -1.4100e+00,  2.9219e-01],
│   │                         [-1.4728e+00,  1.3727e-01, -4.6274e-01,  ..., -4.9321e-01,
│   │                          -2.6927e-01, -1.8901e-01]],
│   │               
│   │                        [[-1.5526e-01,  6.0333e-01,  4.2962e-01,  ...,  1.5859e+00,
│   │                          -1.6241e+00,  4.1703e-01],
│   │                         [-7.8257e-03, -1.8631e+00,  4.5330e-01,  ...,  6.7645e-01,
│   │                           4.3932e-01,  1.2232e+00],
│   │                         [ 5.2341e-01, -4.7767e-01, -7.4467e-01,  ..., -4.2431e-01,
│   │                           1.0481e+00,  1.8991e-01],
│   │                         ...,
│   │                         [ 1.8289e+00,  6.3172e-01, -1.8334e+00,  ..., -1.9872e+00,
│   │                           1.1775e+00, -2.6506e-01],
│   │                         [ 2.1269e+00, -6.5587e-01, -5.3907e-02,  ...,  2.2962e+00,
│   │                          -1.9859e-01, -1.2717e+00],
│   │                         [ 7.0784e-02,  9.8301e-01,  6.3765e-01,  ...,  5.1097e-01,
│   │                          -5.3653e-01, -7.0412e-01]]]])
│   └── 'scalar' --> tensor([[ 1.2278, -0.2813,  0.4079, -0.9899,  0.2032,  0.0373,  1.2151,  0.8677,
│                             -0.3586, -0.0770, -0.9867, -0.5494],
│                            [-1.3885,  0.0707,  0.2566,  0.6897, -0.0413,  0.7399, -0.4054, -0.4191,
│                             -1.0761, -0.2202,  0.7571,  0.3978],
│                            [ 0.9820, -2.1144, -0.2150, -1.6446, -0.7589, -0.9285, -0.4003, -0.7832,
│                             -0.4763,  0.9101, -0.5856, -1.1109],
│                            [ 0.1793, -0.2907, -1.1559,  1.3444,  0.8474, -2.4334,  0.2661, -0.3551,
│                              0.8259,  0.9258,  2.5820, -0.3945]])
└── 'reward' --> tensor([[0.1670],
                         [0.3284],
                         [0.8451],
                         [0.9060]])

This code looks much simpler and clearer.