Stack Structured Data

When tensors are organized into tree structures (for example, nested dicts), they often need to be stacked together, just as `torch.stack()` does for plain tensors in PyTorch.

Stack With Native PyTorch API

Here is a typical implementation using the native PyTorch API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch

# number of samples produced by get_item() and stacked along dim 0
B = 4


def get_item():
    """Create one random sample as a nested dict.

    The sample holds an ``obs`` sub-dict with a 12-element ``scalar`` tensor
    and a ``(3, 32, 32)`` ``image`` tensor, a 1-element integer ``action`` in
    ``[0, 10)``, a 1-element ``reward`` in ``[0, 1)``, and a plain Python
    ``bool`` ``done`` flag.
    """
    # random draws happen in the same order as the original literal:
    # scalar, image, action, reward
    observation = {
        'scalar': torch.randn(12),
        'image': torch.randn(3, 32, 32),
    }
    return dict(
        obs=observation,
        action=torch.randint(0, 10, size=(1,)),
        reward=torch.rand(1),
        done=False,
    )


data = [get_item() for _ in range(B)]  # B independent samples, each a nested dict from get_item()


# execute `stack` op: recursively stack a list of identically structured items
def stack(data, dim):
    """Stack a list of identically structured items along ``dim``.

    Each item may be a ``torch.Tensor``, a Python ``bool``, or a ``dict``
    whose values are (recursively) any of these. The type of ``data[0]``
    selects how the whole list is combined:

    * tensors are merged with ``torch.stack(data, dim)``;
    * bools are collected into a 1-D ``torch.bool`` tensor (``dim`` is not
      used for this case);
    * dicts are stacked key by key, giving a dict with the same keys.

    Args:
        data: non-empty list of items that share the same structure.
        dim: dimension along which tensor leaves are stacked.

    Returns:
        The stacked structure: a tensor, or a (nested) dict of tensors.

    Raises:
        ValueError: if ``data`` is empty (there is no element to inspect).
        KeyError: if the dict items do not all have exactly the same keys.
        TypeError: if the leaf type is not supported.
    """
    if not data:
        raise ValueError("cannot stack an empty sequence")
    elem = data[0]
    if isinstance(elem, torch.Tensor):
        return torch.stack(data, dim)
    elif isinstance(elem, dict):
        keys = elem.keys()
        # refuse to silently drop keys that only some items carry
        for item in data:
            mismatched = item.keys() ^ keys
            if mismatched:
                raise KeyError(f"dict items have mismatched keys: {mismatched}")
        return {k: stack([item[k] for item in data], dim) for k in keys}
    elif isinstance(elem, bool):
        # modern factory instead of the legacy torch.BoolTensor constructor
        return torch.tensor(data, dtype=torch.bool)
    else:
        raise TypeError(f"not support elem type: {type(elem)}")


stacked_data = stack(data, dim=0)
# validate
print(stacked_data)
assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
assert stacked_data['done'].shape == (B,)
assert stacked_data['done'].dtype == torch.bool

The output should look like the following, and all the assertions pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
{'obs': {'scalar': tensor([[ 4.3669e-01,  8.9021e-01,  5.3170e-01,  3.8850e-01,  2.5766e-01,
          1.4229e+00, -4.6344e-01, -5.7613e-01,  1.2565e+00,  8.8336e-01,
          2.8976e-02, -4.5492e-01],
        [-4.6314e-01,  1.0236e+00, -8.2004e-01,  7.4699e-01, -7.9703e-01,
          6.0406e-01,  7.6891e-02, -8.1411e-02,  7.3719e-01,  2.2208e+00,
         -1.5854e-01,  4.3602e-01],
        [-2.2084e+00,  5.5808e-01,  1.2126e+00, -3.7734e-01,  2.0676e-01,
          1.0220e+00, -1.4313e+00, -1.2664e+00, -4.8975e-01, -2.6101e+00,
          1.3745e-01, -1.1013e+00],
        [-7.1217e-01, -9.4820e-01, -4.0018e-01,  2.6759e-02,  7.9237e-01,
         -2.6086e-03, -9.0878e-01,  2.6325e-01, -2.6493e-01, -7.7794e-01,
          1.2291e+00,  1.0078e+00]]), 'image': tensor([[[[-0.3275,  0.4182, -1.4739,  ...,  0.6414, -0.5135,  0.5664],
          [ 1.3142,  1.1921,  2.0937,  ...,  1.3473, -0.1149, -1.0577],
          [ 1.5199, -0.1670, -0.5960,  ..., -0.4828, -0.2021, -0.9806],
          ...,
          [-0.0788,  1.2865,  1.0771,  ...,  1.7751, -2.3545, -0.7276],
          [ 0.3823,  1.0907,  1.3590,  ...,  0.7005, -0.5024, -1.1501],
          [ 0.6898, -0.0465,  1.0417,  ...,  0.5936,  1.2882, -0.1992]],

         [[ 0.3578,  1.8518, -0.0595,  ..., -0.1942,  1.1433, -1.0692],
          [-0.1811, -1.1885,  1.1485,  ...,  0.0635,  0.9963,  0.2163],
          [ 0.3372,  0.8306,  1.0547,  ..., -0.6834,  0.0721, -0.5750],
          ...,
          [ 0.8159,  0.2718, -0.7497,  ...,  0.9155,  0.3946,  0.0981],
          [ 0.1430, -0.0116, -0.7891,  ...,  0.9564, -0.6220, -2.7529],
          [-1.4215,  0.8091, -0.2880,  ...,  0.0422, -1.5128,  1.2283]],

         [[ 0.0434, -1.1959,  0.2484,  ..., -0.5701,  0.3529, -0.1997],
          [ 0.8784,  0.4898,  1.3395,  ...,  0.2363, -0.6968,  0.0892],
          [-0.6362,  0.6688, -0.8044,  ...,  0.2682, -0.3687, -0.8847],
          ...,
          [ 0.4926,  0.8782,  0.3203,  ..., -1.4656,  0.0237,  0.5293],
          [ 1.3433,  0.8925, -0.6735,  ..., -1.2687, -0.5520, -0.3015],
          [ 0.5007,  1.3277, -0.7515,  ..., -0.8983, -0.3882,  0.4258]]],


        [[[ 0.6151,  0.7920,  1.0805,  ..., -0.9930, -0.5093,  0.0165],
          [ 0.7579, -0.6801, -0.1999,  ...,  0.9805,  0.2783, -0.7201],
          [-0.0452,  0.1348, -0.0970,  ..., -1.1582, -1.4129, -0.3204],
          ...,
          [ 0.2511, -0.1183, -0.3363,  ...,  1.2248, -0.7386, -0.3060],
          [-0.4246,  0.4046, -1.6528,  ..., -0.0469,  1.3259, -1.5073],
          [-1.4124,  0.6013,  0.2309,  ..., -0.9346,  1.8208,  0.0704]],

         [[ 0.2008, -0.6656, -0.1388,  ...,  0.0065,  0.3486,  0.8605],
          [ 0.3872, -1.1523, -1.2916,  ..., -0.7043,  0.8579,  0.5939],
          [ 0.6237, -0.3722,  1.1874,  ...,  0.3611, -0.0383, -0.4131],
          ...,
          [ 1.1506, -1.5977,  1.7848,  ...,  0.7646, -0.0908, -0.0747],
          [ 0.1997, -1.1890,  0.2889,  ...,  1.6946, -0.6036, -0.7187],
          [-0.8289, -0.1977, -0.1009,  ..., -0.7609,  0.5790, -0.5853]],

         [[-0.2818, -1.0831,  0.6289,  ..., -1.1969,  0.4769,  0.2612],
          [-0.4751, -1.2800, -0.4016,  ..., -0.1577,  0.5338,  1.3439],
          [ 0.3587,  0.2927,  0.3256,  ...,  1.5164,  0.5147,  0.1062],
          ...,
          [ 0.5571, -0.0127,  1.9707,  ...,  0.0044, -0.6523,  0.3054],
          [-0.7017, -0.6082,  0.3631,  ..., -0.1921, -0.9695, -0.0937],
          [ 0.9148, -1.2189,  0.8631,  ...,  1.3444, -2.0685, -0.1052]]],


        [[[ 0.5203, -1.6085,  0.7390,  ...,  1.3364, -0.9773,  0.6808],
          [-0.9380,  0.0376, -0.2978,  ..., -1.3102, -0.0484, -0.8849],
          [ 0.3203,  0.0055,  0.6669,  ...,  0.5124,  1.2998, -0.8741],
          ...,
          [ 0.6549,  0.7704, -0.6505,  ...,  1.3443,  0.4464, -0.6298],
          [ 0.1148,  0.1203, -0.0356,  ..., -0.0665, -1.3116, -0.1249],
          [ 0.3115, -0.7385, -2.1397,  ..., -1.5787,  0.9624,  0.8732]],

         [[-0.8747,  0.3767,  0.1666,  ..., -1.3713, -2.1499,  1.5090],
          [ 0.3312,  1.4856, -1.0122,  ...,  0.3993,  1.1563, -1.1390],
          [-2.3825,  2.3870, -0.2334,  ..., -0.6649, -1.5679,  0.8052],
          ...,
          [-1.3047, -0.3681,  1.7008,  ..., -0.6408, -0.1296, -1.1401],
          [-0.7674,  0.5899,  0.2221,  ...,  0.8101, -0.0342, -1.0104],
          [-0.6057, -0.5307,  0.9368,  ...,  0.8637,  0.2015, -1.3168]],

         [[ 0.0608,  0.4613, -0.1552,  ..., -0.4030,  2.0336,  0.0781],
          [-0.7815, -1.2312, -1.6037,  ..., -1.2078, -0.4956,  1.4761],
          [-0.4462, -0.4926,  1.7665,  ..., -0.0586, -1.4150,  0.4706],
          ...,
          [-0.0281, -0.1416, -1.0383,  ...,  0.0424, -1.0266,  0.0681],
          [ 1.4225, -0.3949,  1.8961,  ...,  1.0405, -0.4605,  0.2169],
          [-0.2689,  0.2854,  1.0440,  ..., -1.2716,  0.7755, -0.8281]]],


        [[[-0.5578, -1.0346,  1.1742,  ..., -1.8231, -0.7132, -1.2237],
          [ 0.0875, -0.6741,  0.3862,  ...,  0.1413, -0.2122, -0.0280],
          [-0.2127,  1.4064, -0.2287,  ...,  0.9333, -1.1013, -1.4483],
          ...,
          [-0.9176, -1.4950,  0.3122,  ...,  0.3414,  0.2909, -0.4952],
          [-0.8250, -0.8918,  0.9429,  ..., -0.9363, -0.5420, -0.4425],
          [ 0.8443, -0.3708,  0.5456,  ..., -1.3212, -1.2376, -0.8042]],

         [[-0.0065, -0.4843, -0.5712,  ..., -1.7102,  0.2239, -0.6678],
          [-0.8297, -0.5499,  0.4607,  ...,  0.0193, -0.4139,  0.8594],
          [-0.4408, -0.9924,  0.7317,  ...,  1.7710, -0.5421, -0.1837],
          ...,
          [-0.6618, -0.1726, -1.4053,  ..., -0.9396,  0.0144,  1.6487],
          [ 1.2997, -0.2583, -0.5821,  ...,  0.0840,  0.2057,  2.1852],
          [ 0.7864, -0.5424, -0.1111,  ..., -0.4962, -2.0916, -0.5218]],

         [[-0.0748, -0.3309,  0.6856,  ..., -1.4056, -0.9241,  0.8368],
          [-0.5009, -0.8774, -1.9881,  ..., -0.1075, -0.0656,  0.6465],
          [-0.4951,  0.9669,  0.7483,  ..., -0.7964, -0.8651,  0.8845],
          ...,
          [ 0.9168, -2.2085,  0.5690,  ..., -0.3359, -0.7335,  0.5488],
          [ 0.4435, -0.6064,  1.1689,  ..., -1.0362, -0.2478, -0.2087],
          [ 1.9028,  1.1593, -1.0306,  ...,  2.2122, -0.3239,  0.8489]]]])}, 'action': tensor([[1],
        [5],
        [1],
        [3]]), 'reward': tensor([[0.5848],
        [0.1354],
        [0.3738],
        [0.8513]]), 'done': tensor([False, False, False, False])}

As we can see, the structure-processing function has to be implemented entirely by hand (like the function `stack` above). Such code is not very clear, and because every supported type is hard-coded, supporting more data types (such as integers) requires special modifications to the function.

Stack With TreeTensor API

The same workflow can be implemented with the treetensor API, as shown in the code below.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

import treetensor.torch as ttorch

# number of samples produced by get_item() and stacked along dim 0
B = 4


def get_item():
    """Create one random sample as a nested dict.

    The sample holds an ``obs`` sub-dict with a 12-element ``scalar`` tensor
    and a ``(3, 32, 32)`` ``image`` tensor, a 1-element integer ``action`` in
    ``[0, 10)``, a 1-element ``reward`` in ``[0, 1)``, and a plain Python
    ``bool`` ``done`` flag.
    """
    # random draws happen in the same order as the original literal:
    # scalar, image, action, reward
    observation = {
        'scalar': torch.randn(12),
        'image': torch.randn(3, 32, 32),
    }
    return dict(
        obs=observation,
        action=torch.randint(0, 10, size=(1,)),
        reward=torch.rand(1),
        done=False,
    )


data = [get_item() for _ in range(B)]  # B independent samples, each a nested dict from get_item()
# execute `stack` op
# wrap every nested dict as a treetensor tree, then stack all trees with a
# single call -- no hand-written recursion over the structure is needed
data = [ttorch.tensor(d) for d in data]
stacked_data = ttorch.stack(data, dim=0)
# validate
print(stacked_data)
# leaves are reached by attribute access instead of dict indexing; each
# tensor leaf gains a leading batch dimension of size B
assert stacked_data.obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
# the Python bool `done` flags end up as a 1-D bool tensor
assert stacked_data.done.shape == (B,)
assert stacked_data.done.dtype == torch.bool

The output should look like the following, and all the assertions pass as well.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
<Tensor 0x7f6dab8b72d0>
├── 'action' --> tensor([[0],
│                        [6],
│                        [2],
│                        [7]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7f6dab982650>
│   ├── 'image' --> tensor([[[[-7.0929e-01, -9.7217e-01, -1.2975e-02,  ...,  2.1097e+00,
│   │                           9.8190e-02,  7.5606e-01],
│   │                         [ 1.9071e-01,  2.6015e+00,  5.5273e-01,  ..., -1.8416e+00,
│   │                           5.5823e-01, -2.0857e-01],
│   │                         [ 3.1263e-01, -9.2959e-01, -2.0500e-02,  ...,  1.0514e+00,
│   │                          -9.1237e-02, -4.0738e-01],
│   │                         ...,
│   │                         [ 5.9616e-01,  2.0198e-01,  4.5488e-01,  ..., -1.1430e+00,
│   │                           1.6472e+00,  1.1657e+00],
│   │                         [ 1.3439e+00,  8.5232e-01,  1.1845e+00,  ..., -1.4170e+00,
│   │                          -2.4567e-01,  6.2718e-01],
│   │                         [-3.5228e-01,  2.9214e-01,  7.2495e-02,  ..., -2.5119e-01,
│   │                           1.5822e-01, -7.6209e-01]],
│   │               
│   │                        [[ 8.0268e-01, -1.9287e-01, -4.6442e-01,  ..., -1.1711e+00,
│   │                          -4.1858e-01, -3.3726e-01],
│   │                         [ 2.8643e-01, -1.2490e-01,  5.7117e-02,  ...,  8.0224e-02,
│   │                           3.4846e-01,  1.0001e+00],
│   │                         [-9.4231e-01, -6.1459e-01,  1.4439e+00,  ..., -1.6833e+00,
│   │                           4.3204e-01,  1.9055e+00],
│   │                         ...,
│   │                         [-2.1811e-01, -1.8737e-01, -1.7026e-01,  ...,  1.4128e+00,
│   │                           1.3461e+00, -6.5178e-02],
│   │                         [ 4.7062e-01, -9.3386e-01,  1.6096e-01,  ..., -9.8426e-01,
│   │                          -8.4429e-01, -1.4373e+00],
│   │                         [-3.4172e-01,  2.8421e-01, -9.7381e-01,  ..., -1.6720e+00,
│   │                           1.4061e+00, -3.8971e-01]],
│   │               
│   │                        [[-3.7199e-01, -8.6740e-01, -2.0878e+00,  ...,  3.6950e-02,
│   │                          -1.0812e+00,  9.6393e-01],
│   │                         [-5.6934e-01,  2.2135e+00,  5.9485e-01,  ...,  2.0860e-01,
│   │                           1.0534e+00, -5.9086e-01],
│   │                         [-6.1829e-01, -5.6516e-01, -4.9411e-01,  ..., -2.2105e-01,
│   │                           5.7419e-01,  1.4643e+00],
│   │                         ...,
│   │                         [-6.6794e-01,  3.6151e-01,  1.9617e-01,  ...,  8.9301e-01,
│   │                          -5.7946e-01, -4.8280e-01],
│   │                         [-7.2684e-02,  4.7881e-01, -8.2127e-01,  ...,  5.1355e-01,
│   │                          -7.9445e-01, -1.5814e-01],
│   │                         [ 1.0026e+00,  7.7249e-02, -4.9655e-01,  ...,  1.5630e+00,
│   │                           6.1333e-01,  4.4560e-02]]],
│   │               
│   │               
│   │                       [[[ 1.7293e-01,  1.5962e+00,  3.7038e-01,  ...,  1.1744e+00,
│   │                           9.1051e-01,  4.1325e-01],
│   │                         [ 7.5254e-01, -7.6255e-01, -2.1491e+00,  ..., -3.7040e-01,
│   │                          -1.4315e+00, -1.5967e+00],
│   │                         [-1.2312e+00,  4.5415e-01,  1.8938e-01,  ..., -1.5609e+00,
│   │                          -6.7308e-01, -4.4229e-01],
│   │                         ...,
│   │                         [-2.3806e-01, -8.5905e-01, -1.3252e+00,  ...,  8.8494e-01,
│   │                           1.5893e+00,  1.1286e+00],
│   │                         [-7.2676e-01,  2.4444e+00,  1.4938e+00,  ..., -1.1569e-01,
│   │                          -5.1522e-01,  2.3293e-01],
│   │                         [ 2.7554e-01, -3.6180e-01, -3.5846e-01,  ..., -1.5911e+00,
│   │                          -1.0920e+00, -3.1720e-01]],
│   │               
│   │                        [[ 7.4778e-02, -1.3204e+00,  2.1470e-01,  ..., -1.5994e+00,
│   │                           5.9122e-01, -9.0373e-01],
│   │                         [-6.2726e-01, -4.0204e-01,  2.5627e-01,  ...,  2.0521e+00,
│   │                          -3.4887e-01,  4.8551e-01],
│   │                         [-8.0137e-01,  2.5164e-01, -2.9974e-01,  ..., -1.7212e+00,
│   │                           1.1569e-01,  5.8473e-01],
│   │                         ...,
│   │                         [ 3.3794e-01,  2.5298e+00,  3.0920e-01,  ...,  4.1635e-01,
│   │                           2.2498e+00,  3.0156e-04],
│   │                         [ 6.7903e-01,  1.6308e+00,  1.8596e+00,  ...,  6.8641e-01,
│   │                          -2.0860e+00, -7.3169e-01],
│   │                         [-8.9653e-01,  1.0549e+00,  1.4606e+00,  ..., -5.3254e-01,
│   │                           2.1875e+00,  8.2325e-01]],
│   │               
│   │                        [[-1.3178e-01,  1.7179e+00, -1.4787e+00,  ...,  7.1023e-01,
│   │                          -4.1943e-01,  2.3317e-01],
│   │                         [-1.3295e-01, -7.7733e-01,  1.4723e+00,  ..., -7.4537e-01,
│   │                          -3.2538e-02, -1.7051e+00],
│   │                         [-9.5274e-02,  1.6361e-01, -2.1062e+00,  ...,  2.1729e+00,
│   │                          -1.9650e+00,  6.6619e-01],
│   │                         ...,
│   │                         [-3.5468e-01,  4.5415e-01, -3.6189e-01,  ...,  1.1417e+00,
│   │                           5.9284e-01,  8.4571e-01],
│   │                         [ 1.0969e-01, -1.2662e+00, -2.4143e-01,  ..., -1.3031e+00,
│   │                           4.2810e-01, -1.8673e+00],
│   │                         [-8.6853e-01,  6.7554e-02,  1.7818e-01,  ..., -4.1682e-01,
│   │                          -1.0642e+00,  6.7293e-01]]],
│   │               
│   │               
│   │                       [[[-1.8116e+00,  2.4847e-01, -7.0697e-01,  ..., -1.6679e-01,
│   │                          -6.6617e-01,  9.5506e-01],
│   │                         [-2.2718e-01,  2.2259e-01,  4.9405e-01,  ..., -1.1066e-01,
│   │                          -1.5705e+00,  1.8300e+00],
│   │                         [-1.2835e-01,  6.4524e-01, -2.8214e-01,  ...,  2.7977e+00,
│   │                          -1.8072e+00,  7.5596e-01],
│   │                         ...,
│   │                         [-2.6377e-01,  1.1243e-01,  6.4680e-01,  ...,  4.5359e-01,
│   │                          -2.3662e-02,  5.6789e-01],
│   │                         [ 1.0021e+00, -2.7510e-01, -2.0072e+00,  ...,  1.3054e-01,
│   │                          -5.4631e-01,  5.4375e-02],
│   │                         [ 1.8649e+00, -4.8946e-01,  7.0974e-01,  ...,  8.6659e-02,
│   │                           8.7556e-02,  5.4467e-01]],
│   │               
│   │                        [[-6.7881e-01, -5.7728e-01, -2.2985e+00,  ..., -1.0671e+00,
│   │                           5.8919e-02, -2.0369e-01],
│   │                         [-2.6757e+00,  6.4680e-01,  3.4915e-01,  ..., -1.5500e+00,
│   │                          -7.6875e-01,  2.1688e-01],
│   │                         [ 2.6456e-02,  6.9047e-01,  1.3072e+00,  ...,  1.1214e+00,
│   │                           1.6742e+00, -2.5208e+00],
│   │                         ...,
│   │                         [ 6.5759e-01, -1.2418e+00,  2.5567e-01,  ..., -5.3737e-01,
│   │                          -9.9298e-01, -1.2697e+00],
│   │                         [ 3.7382e-01,  8.0992e-01, -9.9138e-02,  ..., -1.6591e+00,
│   │                           1.2628e-02,  3.3264e-01],
│   │                         [-1.1315e-01, -3.2351e-01,  4.3002e-01,  ..., -1.2894e+00,
│   │                          -8.9081e-02, -1.7620e-01]],
│   │               
│   │                        [[ 1.8263e+00, -2.2592e+00, -3.3927e-02,  ..., -1.2444e+00,
│   │                           5.3591e-01,  1.0444e+00],
│   │                         [-1.1673e+00,  5.1561e-01, -5.4102e-01,  ..., -1.1459e+00,
│   │                           1.1770e-01, -6.4547e-01],
│   │                         [ 1.3625e+00,  4.5128e-01,  1.1528e+00,  ..., -1.0714e+00,
│   │                          -9.3821e-02, -9.4100e-01],
│   │                         ...,
│   │                         [-7.5836e-01, -2.7440e-01,  1.2628e+00,  ...,  1.0015e+00,
│   │                           1.1752e+00, -1.1488e-01],
│   │                         [ 4.7207e-01,  1.5958e+00,  3.6174e-01,  ..., -3.3027e-02,
│   │                          -8.4638e-01, -5.4286e-02],
│   │                         [-6.0204e-01,  9.0042e-02,  4.3473e-01,  ..., -8.5500e-02,
│   │                           3.2565e-01, -1.7850e+00]]],
│   │               
│   │               
│   │                       [[[-3.6217e-01,  9.0754e-01, -1.0629e-01,  ..., -1.6553e+00,
│   │                          -2.8224e-01, -3.9628e-02],
│   │                         [-5.2766e-01,  2.5577e-01,  5.0644e-01,  ..., -4.4013e-01,
│   │                          -8.0330e-01,  6.3605e-01],
│   │                         [-2.4126e-02,  4.5952e-01, -2.2054e+00,  ..., -2.3739e-01,
│   │                           1.3734e+00, -6.0692e-01],
│   │                         ...,
│   │                         [-7.5242e-01,  4.3804e-02, -3.9001e-01,  ..., -4.0766e-01,
│   │                           1.3963e-01, -1.3210e+00],
│   │                         [-1.9754e+00, -4.9701e-02,  6.6845e-01,  ..., -6.2150e-01,
│   │                           2.4794e-01,  4.7469e-01],
│   │                         [-3.7290e-02,  7.6249e-01,  6.4762e-02,  ...,  7.0506e-01,
│   │                           1.4366e+00,  1.2019e-02]],
│   │               
│   │                        [[-1.8096e+00, -1.4150e+00,  1.3770e+00,  ...,  6.9775e-01,
│   │                          -1.7277e+00, -2.2242e-01],
│   │                         [-2.4699e-01, -7.5911e-01, -2.2286e+00,  ..., -3.5520e-01,
│   │                           2.6527e-01,  1.4078e+00],
│   │                         [ 3.5618e-01,  1.5643e-01, -1.2414e-01,  ...,  1.9738e-01,
│   │                           5.0234e-01, -1.4401e+00],
│   │                         ...,
│   │                         [-1.3732e-01, -1.5700e+00,  5.5436e-01,  ...,  5.4962e-01,
│   │                          -1.7558e+00, -6.9809e-01],
│   │                         [ 4.7354e-01,  1.0241e+00, -4.5289e-02,  ..., -1.0261e+00,
│   │                          -3.4457e-01, -1.7164e-01],
│   │                         [ 9.6393e-01,  2.6452e-01,  3.9929e-01,  ..., -2.5072e-01,
│   │                          -5.8897e-01, -1.9762e-01]],
│   │               
│   │                        [[ 1.8298e+00,  9.5147e-02,  1.9860e-01,  ..., -7.6044e-01,
│   │                           1.1379e+00, -8.2668e-02],
│   │                         [-9.6574e-01, -1.1040e+00,  1.1724e-01,  ...,  1.8336e+00,
│   │                           1.2860e+00,  1.3923e+00],
│   │                         [ 3.5227e-01, -4.0235e-01, -6.5789e-01,  ...,  5.9040e-01,
│   │                           1.0622e+00, -7.8706e-01],
│   │                         ...,
│   │                         [ 8.2754e-01,  6.3649e-01, -8.0216e-01,  ...,  2.0213e-01,
│   │                          -2.6091e-01,  6.6371e-01],
│   │                         [-8.9857e-02,  9.5119e-01, -1.2953e+00,  ...,  3.1626e-01,
│   │                           2.1883e+00,  1.4886e+00],
│   │                         [-1.1685e+00,  6.0690e-01, -4.0226e-01,  ...,  3.3574e-01,
│   │                           2.0042e+00, -2.0040e-01]]]])
│   └── 'scalar' --> tensor([[-0.8852,  0.2552,  0.9167,  1.6804,  0.2647,  0.4875,  0.4769,  1.1374,
│                              0.3430, -0.0883,  0.0560,  0.0412],
│                            [-0.3992,  0.0289,  0.5185,  0.0122, -1.0255,  0.6371,  1.6496, -0.0538,
│                             -0.1678, -0.5983, -0.6269,  0.1011],
│                            [-1.7059,  1.1215, -0.5887, -0.0272,  2.9538, -1.3454,  1.8133, -1.1440,
│                              1.1386,  0.0786, -0.7860, -1.9610],
│                            [ 0.9267, -1.6123,  0.4862,  0.1678, -0.1714,  0.9942, -2.0951, -0.3835,
│                             -0.2612, -0.0263,  2.3000,  0.2585]])
└── 'reward' --> tensor([[0.8413],
                         [0.8528],
                         [0.0342],
                         [0.2741]])

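This code is much simpler and clearer: the nested structure is handled by `ttorch.stack` itself, so no hand-written recursive function is needed.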