Stack Structured Data

When tensors are organized in a tree structure, they often need to be stacked together, just as torch.stack() does for plain tensors in PyTorch.

Stack With Native PyTorch API

Here is a common implementation using the native PyTorch API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch

# Batch size: the number of samples that are generated and then stacked.
B = 4


def get_item():
    """Create one randomly generated sample as a nested dict.

    Returns:
        A dict with the keys ``'obs'``, ``'action'``, ``'reward'`` and
        ``'done'`` (in that order). ``'obs'`` is itself a dict holding a
        ``'scalar'`` tensor of shape ``(12,)`` and an ``'image'`` tensor of
        shape ``(3, 32, 32)``. ``'action'`` is an integer tensor of shape
        ``(1,)`` drawn from ``[0, 10)``, ``'reward'`` is a tensor of shape
        ``(1,)`` drawn from ``[0, 1)``, and ``'done'`` is the plain Python
        bool ``False``.
    """
    # The random draws happen in the same order as the keys are listed, so a
    # fixed seed always yields the same sample.
    observation = {
        'scalar': torch.randn(12),
        'image': torch.randn(3, 32, 32),
    }
    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)
    return dict(obs=observation, action=action, reward=reward, done=False)


# Build the batch: B samples, each a nested dict returned by get_item().
data = [get_item() for _ in range(B)]


# execute `stack` op
def stack(data, dim):
    """Recursively stack a batch of identically structured items.

    The type of the first element decides how the whole batch is handled:

    * ``torch.Tensor``: the items are combined with ``torch.stack`` along
      ``dim``.
    * ``dict``: every key of the first item is stacked recursively, giving a
      dict with the same keys.
    * ``bool``: the flags are packed into a 1-D ``torch.bool`` tensor;
      ``dim`` is not used in this case.

    Args:
        data: Non-empty list of items that share one structure.
        dim: Dimension handed to ``torch.stack`` for tensor leaves.

    Returns:
        A tensor, or a (possibly nested) dict of tensors that mirrors the
        structure of the first item.

    Raises:
        TypeError: If the first element is not a tensor, a dict or a bool.
    """
    elem = data[0]
    if isinstance(elem, torch.Tensor):
        return torch.stack(data, dim)
    if isinstance(elem, dict):
        return {k: stack([item[k] for item in data], dim) for k in elem}
    if isinstance(elem, bool):
        # torch.tensor with an explicit dtype is the recommended way to build
        # a tensor from existing data, replacing the legacy BoolTensor
        # constructor.
        return torch.tensor(data, dtype=torch.bool)
    raise TypeError("not support elem type: {}".format(type(elem)))


# Recursively stack the B nested samples along dim 0.
stacked_data = stack(data, dim=0)
# validate: print the nested result, then check the stacked leaves
print(stacked_data)
# Every tensor leaf has B as its leading dimension after stacking.
assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
# `done` is a plain Python bool in each item, so it becomes a 1-D bool tensor.
assert stacked_data['done'].shape == (B,)
assert stacked_data['done'].dtype == torch.bool

The output should look like the following, and all of the assertion statements should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
{'obs': {'scalar': tensor([[ 1.2487,  0.0659, -0.8828, -0.7956, -0.1911,  0.0722, -1.4615,  0.7627,
          0.2051,  1.3918, -0.6095, -2.5266],
        [ 1.6848, -0.7302, -0.2342, -0.3704,  0.1854,  1.3675, -0.7894, -0.5639,
          0.6334, -1.2212,  0.0083,  1.0401],
        [-0.0731, -1.5188, -3.2280, -0.1229, -1.1559, -0.7541,  0.9519, -1.5153,
          0.4552,  0.2878,  0.4917, -0.7240],
        [-1.7298, -1.5880, -0.2057,  0.0480, -0.3015,  1.5555, -0.4237, -0.4012,
          0.1069, -0.4387,  0.7484,  0.8109]]), 'image': tensor([[[[ 5.4568e-01,  1.3125e+00,  8.5267e-01,  ...,  3.4564e-01,
            8.3070e-01, -6.8391e-01],
          [ 2.9995e-01, -3.7075e-01,  3.1960e-01,  ..., -1.4002e-01,
            5.2041e-01,  2.0944e-01],
          [-7.9993e-01, -1.1783e+00,  3.0799e-01,  ...,  9.0114e-01,
           -5.5500e-01, -1.6128e+00],
          ...,
          [-1.9486e+00, -2.4392e-01,  3.5795e-01,  ..., -1.3657e-01,
           -7.0286e-01, -3.2742e-02],
          [ 1.4715e-02,  1.6891e-02,  8.4609e-02,  ...,  3.1909e-01,
           -9.8957e-01, -5.2522e-01],
          [ 1.3045e+00, -9.6447e-01,  1.7717e+00,  ..., -8.4638e-01,
           -2.1045e+00, -2.6723e-01]],

         [[ 1.2287e+00, -8.2976e-01,  2.9527e-01,  ..., -1.0946e+00,
           -2.2675e+00,  1.1860e+00],
          [ 1.4456e+00, -2.9175e-01, -1.4244e-01,  ...,  1.2039e+00,
           -7.1352e-01,  1.7125e-01],
          [ 1.1094e+00,  9.3555e-01,  2.1929e-01,  ...,  2.2428e-02,
           -1.0324e+00,  8.3560e-01],
          ...,
          [-1.8602e-01, -2.1478e-01,  3.0988e-01,  ...,  5.0561e-01,
            7.3534e-01,  1.4172e+00],
          [-8.7334e-01,  1.6744e+00, -1.0960e+00,  ..., -3.3990e-01,
           -2.8963e-01,  1.3335e+00],
          [ 2.4212e+00,  7.5342e-01,  8.1316e-01,  ...,  2.9624e-01,
            2.3937e-01,  3.1217e-01]],

         [[-7.5525e-01, -1.7512e+00, -4.9703e-01,  ...,  9.1071e-01,
           -6.9380e-01, -3.8026e-01],
          [ 6.8297e-01, -1.8199e+00,  7.0351e-02,  ...,  2.6667e-01,
           -1.3006e-01, -5.0755e-01],
          [ 2.2669e+00, -7.2633e-01,  1.9552e+00,  ..., -6.2078e-02,
            4.7782e-01,  2.8568e-01],
          ...,
          [ 5.0042e-01,  3.7822e-02,  7.5141e-01,  ...,  3.7158e-01,
            7.4444e-01,  5.1530e-01],
          [ 8.1599e-01, -7.0243e-01, -3.4992e-01,  ...,  2.0992e+00,
            2.0398e-01,  3.9508e-01],
          [-1.4976e+00, -1.0547e+00,  1.3596e+00,  ..., -3.0970e+00,
            2.3750e+00, -1.1632e-01]]],


        [[[-3.2651e-01, -1.8138e+00, -5.4320e-01,  ...,  8.7521e-01,
            1.0433e+00,  1.7235e+00],
          [-1.3835e+00, -3.3954e-01, -2.7084e-01,  ..., -7.2559e-01,
           -1.0555e-01, -3.5252e-01],
          [ 1.1560e+00,  1.0947e+00, -3.4015e-01,  ...,  4.2856e-01,
           -2.6645e-01,  4.6230e-01],
          ...,
          [-1.3601e+00, -4.2079e-01, -1.9801e+00,  ...,  2.6100e-01,
            2.5009e-01,  5.0118e-01],
          [-1.4183e+00, -3.8030e-01,  3.9685e-01,  ..., -1.2996e+00,
            5.4177e-01,  4.0098e-01],
          [-1.4642e+00,  2.2874e+00, -3.6630e-01,  ..., -9.5352e-01,
           -1.8468e+00, -1.7926e+00]],

         [[ 2.9176e-01,  5.8598e-01,  6.4524e-01,  ...,  2.2106e-01,
           -1.2959e+00,  1.6307e+00],
          [-1.8381e+00,  1.6662e+00,  9.9973e-01,  ...,  1.0389e+00,
            6.2232e-01, -1.2866e+00],
          [ 2.1080e-01,  1.8014e+00, -3.3409e-02,  ...,  3.1470e-01,
            9.7042e-01,  1.8585e+00],
          ...,
          [-2.9094e-01, -2.5811e-01,  1.4689e-01,  ...,  1.3053e+00,
            1.0435e+00,  1.9542e+00],
          [ 1.5792e+00,  6.1428e-03, -1.5544e+00,  ..., -8.4026e-01,
            1.3295e+00, -1.3172e-01],
          [-4.5052e-01,  5.5177e-03, -9.5845e-01,  ..., -6.7349e-02,
            5.0406e-01,  4.4078e-01]],

         [[-9.8311e-01, -3.7986e-01,  4.1895e-01,  ..., -2.1160e-01,
           -1.3994e+00,  1.1588e-01],
          [ 1.8306e-01, -3.6433e-01, -5.4773e-01,  ...,  9.3740e-01,
            9.2961e-04, -2.5617e-01],
          [-5.2898e-01, -6.1652e-01, -1.3020e+00,  ...,  1.3490e+00,
           -4.5652e-01,  4.8356e-01],
          ...,
          [-9.2123e-01, -3.5525e-01,  2.6412e-01,  ..., -1.0599e+00,
           -1.9792e-01, -1.2992e+00],
          [ 1.5468e-01,  8.4751e-01, -5.2128e-01,  ...,  8.8275e-01,
           -6.3472e-01, -3.1174e-01],
          [ 1.1747e-01, -7.6513e-01, -2.2058e-01,  ...,  2.4015e+00,
            1.5834e+00,  4.2212e-01]]],


        [[[-7.2425e-01, -4.5502e-01, -4.7149e-01,  ..., -6.7861e-02,
           -1.5309e-01,  2.1381e+00],
          [-1.9982e-02,  2.7995e-01,  2.0240e+00,  ..., -1.0835e+00,
            2.3037e-01, -2.8825e+00],
          [-7.9583e-01, -1.6658e+00,  6.0063e-01,  ..., -1.6351e+00,
            1.1359e+00,  1.2969e+00],
          ...,
          [ 2.2538e+00,  5.3171e-01, -1.1684e+00,  ..., -4.4470e-01,
            1.6598e-01, -3.7410e-01],
          [ 3.9328e-02, -4.8167e-01, -7.7380e-01,  ..., -3.2484e-01,
           -1.9947e+00, -1.5201e+00],
          [-1.7360e+00,  1.1486e-01,  4.0050e-01,  ..., -4.0868e-01,
            4.5551e-01, -6.2560e-01]],

         [[ 1.4916e-01,  2.6327e-01, -9.2421e-01,  ...,  1.1232e+00,
            9.2454e-01, -2.1548e-01],
          [-1.7798e-01, -2.3534e-01, -2.1099e-01,  ...,  2.6555e-02,
           -1.6403e+00, -1.3546e+00],
          [ 8.5559e-01, -2.4245e+00, -5.5505e-01,  ...,  1.4165e-01,
            2.2297e-01,  6.7547e-02],
          ...,
          [ 5.1256e-01, -1.2606e-01,  1.5935e+00,  ..., -1.1555e+00,
            2.7255e-02, -1.4003e+00],
          [ 5.3431e-01,  2.5034e-01,  8.9313e-02,  ..., -7.1464e-01,
            2.2875e+00,  1.5020e+00],
          [ 7.5106e-01,  1.2033e+00,  4.0816e-01,  ..., -2.4958e-01,
           -7.7272e-02, -1.8377e+00]],

         [[-5.0873e-01,  4.5417e-01,  2.9624e-01,  ..., -1.1058e+00,
           -4.1511e-01, -2.3333e+00],
          [ 6.3320e-03, -6.8507e-01, -1.2379e+00,  ..., -1.5970e-01,
            1.9783e+00,  5.9012e-01],
          [ 2.0505e+00, -3.7271e-01, -1.1901e+00,  ..., -8.5279e-01,
           -1.8106e+00, -9.7263e-01],
          ...,
          [ 2.2055e+00,  6.8701e-01, -6.8200e-01,  ...,  1.8404e+00,
           -1.2277e+00, -6.3110e-01],
          [-8.2342e-01,  4.4454e-01,  1.2279e+00,  ...,  1.0266e+00,
           -6.7878e-02, -5.8478e-02],
          [ 4.9169e-01, -8.1805e-01,  1.9280e+00,  ...,  2.0880e-01,
           -8.1046e-01, -1.0432e+00]]],


        [[[-3.7968e-01,  7.8225e-02,  1.5999e+00,  ..., -9.8021e-01,
           -1.0688e+00,  1.6284e+00],
          [-3.5402e-01, -4.3988e-01,  1.8506e-02,  ...,  1.5246e+00,
           -4.4886e-01,  5.6454e-01],
          [ 1.0891e+00,  8.7219e-01,  1.1726e-01,  ...,  1.0126e+00,
            2.3321e-01, -4.7287e-01],
          ...,
          [ 9.2974e-02, -7.4575e-01, -8.8682e-01,  ..., -1.6220e-01,
           -1.5445e-02,  1.1572e+00],
          [-3.1327e-01,  1.2644e-01,  4.2322e-01,  ...,  2.1216e+00,
           -2.4136e-01,  1.2086e-02],
          [ 3.8632e-01,  1.5433e-01, -3.3871e-01,  ...,  5.7449e-01,
           -1.3997e-01, -1.6376e-01]],

         [[ 4.3074e-01, -1.2000e+00,  1.0080e-02,  ..., -7.8164e-01,
           -1.3571e+00, -9.4236e-01],
          [-8.0036e-01,  1.1922e+00,  1.4388e+00,  ...,  1.9770e-01,
           -2.2703e-01,  3.9754e-01],
          [ 1.2028e+00, -1.0379e+00,  7.7509e-01,  ...,  1.2225e+00,
           -1.2627e+00,  1.5380e+00],
          ...,
          [ 1.6637e+00,  1.7794e-01, -4.3641e-01,  ..., -5.8911e-01,
           -1.3908e-01,  7.6500e-01],
          [ 5.1329e-01,  9.9871e-01, -1.3560e+00,  ..., -1.6761e-01,
            1.0929e+00,  1.0061e+00],
          [ 1.1049e-01, -2.2302e-01,  2.8102e-02,  ...,  1.2264e-01,
            3.8269e-02,  7.3134e-01]],

         [[-6.9656e-01,  3.1748e-01, -4.1338e-01,  ..., -1.6892e-02,
            1.8131e+00,  1.6852e+00],
          [ 1.1389e+00,  6.5516e-01, -4.1091e-01,  ..., -2.0759e-01,
           -8.8667e-01, -8.5072e-01],
          [-1.8422e+00,  4.2830e-01, -3.1311e-02,  ...,  9.6866e-01,
           -6.4141e-01,  1.0588e+00],
          ...,
          [-3.1157e-01, -2.9272e-01,  1.2738e+00,  ..., -1.1907e+00,
            2.7738e-01, -2.4553e-01],
          [-6.4724e-01,  5.8602e-01,  5.5032e-01,  ..., -1.7885e-01,
           -2.2955e-01,  1.4179e-01],
          [ 8.7774e-02, -1.1011e+00, -1.1919e-01,  ...,  1.1613e+00,
            3.6297e-01,  2.2333e-01]]]])}, 'action': tensor([[5],
        [0],
        [7],
        [8]]), 'reward': tensor([[0.4868],
        [0.3705],
        [0.3860],
        [0.0417]]), 'done': tensor([False, False, False, False])}

We can see that the structure-processing function (such as the stack function above) has to be implemented entirely by hand. This code is not very clear, and because the supported types are hard-coded, supporting more data types (such as integers) requires special modifications to the function.

Stack With TreeTensor API

The same workflow can be implemented with the TreeTensor API, as shown in the code below.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

import treetensor.torch as ttorch

# Batch size: the number of samples that are generated and then stacked.
B = 4


def get_item():
    """Create one randomly generated sample as a nested dict.

    Returns:
        A dict with the keys ``'obs'``, ``'action'``, ``'reward'`` and
        ``'done'`` (in that order). ``'obs'`` is itself a dict holding a
        ``'scalar'`` tensor of shape ``(12,)`` and an ``'image'`` tensor of
        shape ``(3, 32, 32)``. ``'action'`` is an integer tensor of shape
        ``(1,)`` drawn from ``[0, 10)``, ``'reward'`` is a tensor of shape
        ``(1,)`` drawn from ``[0, 1)``, and ``'done'`` is the plain Python
        bool ``False``.
    """
    # The random draws happen in the same order as the keys are listed, so a
    # fixed seed always yields the same sample.
    observation = {
        'scalar': torch.randn(12),
        'image': torch.randn(3, 32, 32),
    }
    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)
    return dict(obs=observation, action=action, reward=reward, done=False)


# Build the batch: B samples, each a nested dict returned by get_item().
data = [get_item() for _ in range(B)]
# execute `stack` op
# Each nested dict is first wrapped with ttorch.tensor; the wrapped items are
# then stacked with a single ttorch.stack call, with no hand-written recursion.
data = [ttorch.tensor(d) for d in data]
stacked_data = ttorch.stack(data, dim=0)
# validate: leaves are reached by attribute access here (e.g. obs.image)
# instead of dict keys
print(stacked_data)
assert stacked_data.obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
# The plain Python bool `done` of each item ends up as a 1-D bool tensor.
assert stacked_data.done.shape == (B,)
assert stacked_data.done.dtype == torch.bool

The output should look like the following, and all of the assertion statements should pass as well.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
<Tensor 0x7f2997fff590>
├── 'action' --> tensor([[4],
│                        [8],
│                        [4],
│                        [8]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7f2997fff5d0>
│   ├── 'image' --> tensor([[[[-0.0672, -0.0630,  0.0620,  ..., -0.1392,  1.3751,  0.4971],
│   │                         [-1.0882, -0.5762,  0.4322,  ...,  0.0836,  0.9350,  1.3846],
│   │                         [-1.2985, -0.7872, -0.2924,  ...,  1.1470,  0.9763,  0.1230],
│   │                         ...,
│   │                         [ 0.3916, -0.7103,  0.6722,  ..., -1.4107,  0.4783, -0.9401],
│   │                         [ 1.3308, -0.2335, -1.0654,  ...,  0.0858, -0.2349,  0.3034],
│   │                         [ 0.8558, -0.9990,  0.2647,  ...,  0.1819, -1.3422, -0.3963]],
│   │               
│   │                        [[-0.1613, -0.4208, -0.0185,  ..., -1.1834,  0.5306,  0.2790],
│   │                         [ 0.7737,  1.0139,  1.2693,  ..., -0.4637, -1.7979, -1.1924],
│   │                         [ 0.5818,  0.7808, -0.3680,  ..., -1.0213, -0.0658, -0.6323],
│   │                         ...,
│   │                         [-0.1475,  0.0759, -0.8877,  ..., -1.0376,  0.8723,  0.3218],
│   │                         [-0.0751, -1.1105, -0.4679,  ..., -1.2987, -0.0368,  2.1064],
│   │                         [ 1.4135, -1.5006, -1.6955,  ...,  0.3030,  0.0042, -0.4974]],
│   │               
│   │                        [[ 1.0542,  0.8837, -1.2863,  ..., -0.0581, -0.4548,  0.5833],
│   │                         [-0.6228, -1.1552,  1.0618,  ..., -1.2627, -0.9395, -1.1910],
│   │                         [ 0.9514,  1.3390,  1.3661,  ..., -0.9148, -0.3406, -0.1843],
│   │                         ...,
│   │                         [-1.4357,  0.9096,  0.6402,  ..., -1.5467,  0.7123,  1.2526],
│   │                         [ 1.4199, -0.4602, -1.8541,  ...,  0.1535, -2.0045,  1.8305],
│   │                         [-0.1067,  0.0878,  0.5743,  ..., -3.4279, -0.3968,  0.3653]]],
│   │               
│   │               
│   │                       [[[-0.8043, -0.2095,  0.5204,  ..., -0.6405,  0.0799, -1.1175],
│   │                         [ 0.6695, -0.2450, -0.1862,  ...,  2.3209, -0.6252,  1.0478],
│   │                         [-1.2586,  0.1570, -0.3718,  ..., -0.1745, -0.3558,  0.4878],
│   │                         ...,
│   │                         [-0.1055,  0.2066, -0.4919,  ..., -0.5088, -0.5332,  1.4573],
│   │                         [ 1.3801, -0.1064,  1.5342,  ...,  0.4528,  0.1134, -1.6186],
│   │                         [-1.2885,  1.5293,  1.0278,  ..., -0.2581, -1.8196,  0.0419]],
│   │               
│   │                        [[ 0.1767, -1.2691,  1.4488,  ...,  0.0537,  2.4332,  1.3084],
│   │                         [-1.9883, -0.8405, -1.9822,  ...,  0.1984, -1.3904, -0.1282],
│   │                         [-0.5508,  0.4290,  1.0892,  ..., -1.2228,  0.4547, -0.4385],
│   │                         ...,
│   │                         [-0.8020, -0.3676, -0.1572,  ...,  0.7562,  1.1367, -0.6621],
│   │                         [-0.0750, -0.0172, -1.2188,  ...,  2.9413, -1.2360,  1.6817],
│   │                         [-0.1110,  0.5100,  0.3097,  ...,  1.0586, -0.5567,  1.6480]],
│   │               
│   │                        [[ 0.9641,  0.8322,  0.8819,  ...,  1.1606, -1.1696, -1.2609],
│   │                         [ 0.1653,  1.5093, -0.0313,  ...,  0.0242,  0.3498, -0.0463],
│   │                         [-0.2704,  0.1059,  0.5695,  ...,  0.6060,  0.8625,  0.2072],
│   │                         ...,
│   │                         [ 0.5223, -1.2731, -1.6267,  ...,  1.6531, -0.2693, -0.2709],
│   │                         [ 1.4091,  0.9915, -1.2171,  ...,  1.6000, -1.4802, -1.2880],
│   │                         [ 0.7806, -0.5546, -0.7595,  ...,  0.5932, -0.8286, -0.4641]]],
│   │               
│   │               
│   │                       [[[-0.5078, -0.1389,  0.3135,  ..., -0.5817,  1.5611, -0.9665],
│   │                         [ 0.3361,  0.7957, -2.2603,  ..., -1.2018, -1.6205, -0.5193],
│   │                         [-0.2796, -0.2151,  0.3970,  ..., -0.3267,  0.4383, -0.8802],
│   │                         ...,
│   │                         [ 0.3649, -1.2698, -0.3297,  ..., -0.4642,  0.9421, -0.8969],
│   │                         [-3.5398, -1.3365,  0.2732,  ..., -0.2043, -1.0875, -1.7606],
│   │                         [ 1.3389,  0.8844, -0.8726,  ..., -0.2136, -0.7577,  0.1528]],
│   │               
│   │                        [[ 0.1266, -1.1817, -1.6102,  ..., -1.6661,  0.3754,  0.8016],
│   │                         [-0.3275, -0.4856, -0.2517,  ...,  0.8218,  0.3492, -0.5772],
│   │                         [-0.5595,  0.1046,  1.6680,  ..., -1.0508, -1.2683, -0.6408],
│   │                         ...,
│   │                         [ 0.5485, -0.4472,  0.3330,  ..., -2.3137,  1.3527,  1.6108],
│   │                         [-1.8751, -0.9017, -1.3693,  ...,  0.7377,  1.7705, -1.0012],
│   │                         [-0.1600, -1.1750, -0.9609,  ...,  0.0267, -1.7438, -0.1357]],
│   │               
│   │                        [[ 0.0295, -0.1752, -0.0933,  ..., -0.7064,  1.7305,  0.0484],
│   │                         [ 0.0977, -2.0690, -0.0565,  ...,  0.5252,  0.5882,  1.8323],
│   │                         [-0.8063, -0.1968,  0.0058,  ..., -0.1629,  0.8157,  0.5958],
│   │                         ...,
│   │                         [-1.7675,  0.3520,  0.7394,  ...,  0.9167, -1.4397,  0.8258],
│   │                         [-0.0533,  0.8308, -0.4375,  ..., -1.8468, -0.8035,  1.5208],
│   │                         [-1.1298,  0.2310,  1.8792,  ..., -0.8094,  1.5368, -0.9962]]],
│   │               
│   │               
│   │                       [[[-0.1918, -0.0811,  1.4200,  ...,  1.8283,  0.7468, -0.5494],
│   │                         [-1.1889, -0.0111,  0.6391,  ..., -2.0235,  1.2109,  0.3940],
│   │                         [-0.8798, -0.3754, -0.3869,  ..., -0.4266,  0.9059, -0.2676],
│   │                         ...,
│   │                         [-0.9199,  0.3943, -0.8301,  ..., -0.6845,  0.5190,  0.0398],
│   │                         [-1.3283, -0.9408, -0.6897,  ..., -0.2246,  1.3044, -1.4520],
│   │                         [-1.4113, -0.2010, -1.0671,  ..., -0.9592,  0.6844, -0.8724]],
│   │               
│   │                        [[-0.1867, -0.3295,  0.9064,  ..., -0.5670, -0.9720,  0.1669],
│   │                         [-0.6180,  0.7009,  0.5636,  ..., -0.2222,  0.2423,  0.1320],
│   │                         [-0.0569,  0.1350,  0.3491,  ...,  0.1231, -0.8858, -0.5849],
│   │                         ...,
│   │                         [-0.1763, -3.0825, -0.7967,  ...,  0.9897,  1.0332, -2.8303],
│   │                         [ 0.0180, -0.3289, -0.3358,  ...,  0.6526, -2.3824, -0.3013],
│   │                         [ 0.9976, -0.7704,  1.9511,  ..., -0.1970,  0.0279, -0.4749]],
│   │               
│   │                        [[ 1.4741,  1.1849,  0.0474,  ..., -0.2876,  1.6320, -1.6171],
│   │                         [-0.1468, -1.1539,  1.0065,  ...,  0.6058, -0.3940, -0.5912],
│   │                         [ 1.4676, -0.4917,  1.8070,  ..., -0.4347,  0.2397,  0.5888],
│   │                         ...,
│   │                         [-1.6730, -0.9058,  0.7252,  ..., -0.5221, -1.4716, -0.9461],
│   │                         [-1.1886,  0.3819,  0.4536,  ...,  0.9436, -0.2282, -0.3439],
│   │                         [-0.1948,  0.0660,  0.8384,  ...,  0.2731,  0.0443,  0.7630]]]])
│   └── 'scalar' --> tensor([[-1.2252,  0.0368, -0.0747, -0.8726, -0.3240, -0.3707, -0.6999,  0.2730,
│                              0.5353, -0.3152,  0.0718, -0.0288],
│                            [ 0.4700,  0.5138,  0.2613, -2.4029, -0.3593,  0.2235, -0.7172, -0.5973,
│                              0.3618,  1.2508, -0.2310,  0.2328],
│                            [ 0.2828,  0.5045,  0.7388,  1.9964, -1.3182, -0.6180,  0.5437,  0.8925,
│                             -1.1228, -0.5329, -0.7316,  1.1478],
│                            [ 1.6263,  0.0425, -0.8926,  1.0294, -0.4377, -0.4261,  1.8836,  0.8233,
│                              2.0130,  1.2278, -0.7956,  1.4801]])
└── 'reward' --> tensor([[0.2219],
                         [0.3270],
                         [0.7693],
                         [0.9634]])

This code looks much simpler and clearer.