Customized Operations For Different Fields

Here is another example of custom operations, implemented with both the native torch API and the treetensor API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import copy

import torch

import treetensor.torch as ttorch

# T: number of rows in field 'a' of each sample (see get_data).
# B: number of samples in the batch built by the script entry point.
T, B = 3, 4


def with_nativetensor(batch_):
    """Process a batch of nested dicts with plain torch calls.

    Each element of ``batch_`` is a dict that may hold the keys ``'a'``,
    ``'b'`` and ``'c'`` (``'c'`` itself being a dict that may hold ``'d'``).
    Other keys are reported with a printed message and left untouched.

    The samples are modified in place:

    * ``'a'`` is cast to float,
    * ``'b'`` is cast to float and replaced by ``b ** 2 + 1``,
    * ``'c'['d']`` is cast to float,
    * ``'c'['noise']`` is set to a fresh ``randn`` tensor of shape (3, 4, 5).

    :param batch_: list of sample dicts (mutated in place).
    :return: tuple ``(batch_, mean_b, even_index_a)`` where ``mean_b`` is the
        average over samples of each transformed ``'b'`` mean, and
        ``even_index_a`` stacks ``a[::2]`` of every sample along a new dim 0.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        # Only values of existing keys are reassigned while iterating, so the
        # dict size never changes during the loop.
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                sample[k] = v
                even_index_a_list.append(v[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        # Store the converted tensor; rebinding the loop
                        # variable alone would leave the dict unchanged.
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))

    # 'noise' is a new key, so it is added in a separate pass after the
    # iteration over the 'c' sub-dict has finished.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Process a batch of nested dicts through the treetensor API.

    The list of sample dicts is converted to tree tensors, stacked into a
    single tree, cast to float, and then field ``b`` is replaced by
    ``b ** 2 + 1`` and a ``c.noise`` field of random values is attached.
    Finally the stacked tree is split back along dim 0 into pieces of size 1.

    :param batch_: list of sample dicts with fields ``a``, ``b`` and ``c``.
    :return: tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the
        result of ``ttorch.split``, ``mean_b`` is the mean of the transformed
        ``b`` field, and ``even_index_a`` is ``a[:, ::2]`` of the stacked tree.
    """
    # Take the batch size from the input itself instead of the module-level
    # constant, so the noise tensor matches batches of any length.
    batch_size = len(batch_)
    batch_ = [ttorch.tensor(b) for b in batch_]
    batch_ = ttorch.stack(batch_)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    # Dim 0 is the stacked batch dim, so the step-2 slice applies to dim 1.
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Create one random sample as a nested dict.

    The sample holds a ``(T, 8)`` uniform tensor under ``'a'``, a ``(6,)``
    uniform tensor under ``'b'``, and under ``'c'`` a dict whose ``'d'`` entry
    is a single random integer drawn from ``[0, 10)``.
    """
    # Fields are generated in the order a, b, c.d.
    sample = {}
    sample['a'] = torch.rand(size=(T, 8))
    sample['b'] = torch.rand(size=(6,))
    nested = {'d': torch.randint(0, 10, size=(1,))}
    sample['c'] = nested
    return sample


if __name__ == "__main__":
    # Build B independent samples and give each implementation its own deep
    # copy, so in-place changes made by one cannot leak into the other.
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    # Both implementations must agree on the mean of the transformed 'b'.
    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Compare element-wise: take the absolute value before the max so that
    # negative differences cannot be hidden by a matching element.
    assert torch.abs(even_index_a0 - even_index_a1).max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following (the values are random, so the exact numbers will differ between runs), and all the assertions should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
[{'a': tensor([[0.6370, 0.5287, 0.8719, 0.8738, 0.8669, 0.0233, 0.4351, 0.3231],
        [0.0673, 0.8459, 0.8704, 0.0062, 0.2567, 0.2812, 0.9117, 0.9412],
        [0.0644, 0.1503, 0.6570, 0.0957, 0.1638, 0.7742, 0.9270, 0.6866]]), 'b': tensor([0.9541, 0.0955, 0.8896, 0.1709, 0.4578, 0.9287]), 'c': {'d': tensor([5])}}, {'a': tensor([[0.6236, 0.2665, 0.3845, 0.1008, 0.1871, 0.0125, 0.1407, 0.8828],
        [0.4517, 0.4068, 0.1777, 0.8897, 0.2485, 0.1900, 0.8993, 0.3462],
        [0.8976, 0.8408, 0.2361, 0.9718, 0.7077, 0.0350, 0.1291, 0.7675]]), 'b': tensor([0.6229, 0.7557, 0.2653, 0.8749, 0.8205, 0.2778]), 'c': {'d': tensor([1])}}, {'a': tensor([[0.0803, 0.0747, 0.2994, 0.3551, 0.2325, 0.2029, 0.0963, 0.3583],
        [0.4929, 0.8611, 0.0243, 0.5447, 0.4929, 0.4975, 0.6029, 0.9135],
        [0.5109, 0.3885, 0.7017, 0.5907, 0.3846, 0.3243, 0.7258, 0.9102]]), 'b': tensor([0.6293, 0.8360, 0.0842, 0.0263, 0.5059, 0.5398]), 'c': {'d': tensor([0])}}, {'a': tensor([[0.5590, 0.5996, 0.9398, 0.4532, 0.7030, 0.0217, 0.5259, 0.0566],
        [0.1255, 0.7498, 0.4983, 0.3020, 0.4932, 0.7532, 0.5629, 0.0898],
        [0.2531, 0.5237, 0.3917, 0.0051, 0.9483, 0.1334, 0.5323, 0.3156]]), 'b': tensor([0.2050, 0.8358, 0.5785, 0.6644, 0.5952, 0.6730]), 'c': {'d': tensor([4])}}]



(<Tensor 0x7f7dffccc090>
├── 'a' --> tensor([[[0.6370, 0.5287, 0.8719, 0.8738, 0.8669, 0.0233, 0.4351, 0.3231],
│                    [0.0673, 0.8459, 0.8704, 0.0062, 0.2567, 0.2812, 0.9117, 0.9412],
│                    [0.0644, 0.1503, 0.6570, 0.0957, 0.1638, 0.7742, 0.9270, 0.6866]]])
├── 'b' --> tensor([[1.9103, 1.0091, 1.7913, 1.0292, 1.2096, 1.8625]])
└── 'c' --> <Tensor 0x7f7dffccc050>
    ├── 'd' --> tensor([[5.]])
    └── 'noise' --> tensor([[[[ 0.4562, -0.4098,  0.9975, -1.5102,  0.4956],
                              [ 0.9847,  1.0742,  0.5543,  2.2826,  2.0131],
                              [ 0.0873, -1.0024,  0.0923,  0.2751, -0.4338],
                              [ 0.6638,  1.0220,  1.3232, -0.6142, -0.3566]],
                    
                             [[ 1.2246,  0.8735,  0.0900,  1.1970,  0.5007],
                              [-0.7250, -0.9558,  0.1337, -0.0758, -0.2821],
                              [-0.1152, -0.4845,  0.7222,  1.7570, -1.8673],
                              [ 1.2254, -0.2409, -1.1709, -0.3074, -0.6063]],
                    
                             [[-0.6470,  0.3742, -0.8198,  0.9947,  0.2193],
                              [-1.5520,  1.0254, -0.6323,  0.4202,  0.3086],
                              [-1.8731, -0.8634,  1.8733, -1.1940,  0.7242],
                              [-0.9276,  0.4131, -0.9818, -0.3848, -0.2841]]]])
, <Tensor 0x7f7dffccc110>
├── 'a' --> tensor([[[0.6236, 0.2665, 0.3845, 0.1008, 0.1871, 0.0125, 0.1407, 0.8828],
│                    [0.4517, 0.4068, 0.1777, 0.8897, 0.2485, 0.1900, 0.8993, 0.3462],
│                    [0.8976, 0.8408, 0.2361, 0.9718, 0.7077, 0.0350, 0.1291, 0.7675]]])
├── 'b' --> tensor([[1.3880, 1.5711, 1.0704, 1.7655, 1.6732, 1.0772]])
└── 'c' --> <Tensor 0x7f7dffccc0d0>
    ├── 'd' --> tensor([[1.]])
    └── 'noise' --> tensor([[[[ 1.6693,  1.3292,  1.1779,  0.1647, -0.3271],
                              [-0.7005,  0.3124,  1.1802, -0.6941, -0.7028],
                              [ 1.1857, -0.5633,  1.8195,  0.3823, -0.7769],
                              [-0.3180,  0.7928, -0.4512, -0.6213, -0.2267]],
                    
                             [[-1.6616, -0.0282, -0.1339,  1.1325, -0.4535],
                              [ 1.3824, -1.2600, -0.5668, -0.0589, -1.6768],
                              [ 1.0283, -0.5619, -0.5195,  0.4863, -0.2197],
                              [-0.0302,  1.4302,  1.4545,  0.4620,  0.2496]],
                    
                             [[ 0.0375,  1.7652,  0.1893,  0.2907, -0.3880],
                              [ 0.8968, -0.6625, -2.0683,  1.4648,  0.5475],
                              [-0.8635,  0.5427,  0.6779,  0.0429, -0.8031],
                              [-0.0240,  0.2764,  1.1809,  1.8446,  0.5535]]]])
, <Tensor 0x7f7dffccc190>
├── 'a' --> tensor([[[0.0803, 0.0747, 0.2994, 0.3551, 0.2325, 0.2029, 0.0963, 0.3583],
│                    [0.4929, 0.8611, 0.0243, 0.5447, 0.4929, 0.4975, 0.6029, 0.9135],
│                    [0.5109, 0.3885, 0.7017, 0.5907, 0.3846, 0.3243, 0.7258, 0.9102]]])
├── 'b' --> tensor([[1.3961, 1.6988, 1.0071, 1.0007, 1.2559, 1.2914]])
└── 'c' --> <Tensor 0x7f7dffccc150>
    ├── 'd' --> tensor([[0.]])
    └── 'noise' --> tensor([[[[ 2.6231e-01, -7.3377e-01, -6.7766e-01,  6.3407e-01, -1.0650e+00],
                              [-1.9808e+00,  1.7385e+00,  1.3419e-03, -1.0606e+00, -7.2179e-02],
                              [ 2.5939e-01,  9.2158e-01,  1.1401e+00,  1.2967e+00, -2.1993e+00],
                              [-6.4287e-01, -5.5683e-01,  1.0608e+00,  4.6332e-01, -1.7852e+00]],
                    
                             [[-6.8630e-01, -3.5486e-01, -7.1593e-01, -1.0573e+00, -1.9016e+00],
                              [ 6.6695e-01,  2.2377e-01, -1.5779e+00,  6.3296e-01, -1.1480e+00],
                              [-3.4535e-01, -9.3572e-01, -9.4721e-01, -6.5567e-02, -9.0092e-01],
                              [-5.1379e-01,  9.2360e-01,  7.6557e-01,  3.7062e-01, -6.3338e-01]],
                    
                             [[ 1.4979e-01, -4.6819e-01,  5.6967e-01,  3.8858e-01,  2.0449e-01],
                              [-4.2331e-01,  2.6956e+00,  1.0970e+00,  1.1910e-01,  2.9090e-01],
                              [ 6.7524e-01, -3.5661e-01, -2.0881e-01,  1.7711e+00, -5.2344e-01],
                              [ 5.0039e-01,  2.7028e-01,  4.6420e-01,  1.1946e+00, -3.8231e-01]]]])
, <Tensor 0x7f7dffccc210>
├── 'a' --> tensor([[[0.5590, 0.5996, 0.9398, 0.4532, 0.7030, 0.0217, 0.5259, 0.0566],
│                    [0.1255, 0.7498, 0.4983, 0.3020, 0.4932, 0.7532, 0.5629, 0.0898],
│                    [0.2531, 0.5237, 0.3917, 0.0051, 0.9483, 0.1334, 0.5323, 0.3156]]])
├── 'b' --> tensor([[1.0420, 1.6985, 1.3346, 1.4415, 1.3543, 1.4530]])
└── 'c' --> <Tensor 0x7f7dffccc1d0>
    ├── 'd' --> tensor([[4.]])
    └── 'noise' --> tensor([[[[ 6.4419e-01, -2.7867e-01, -2.0447e-01,  2.0649e-01, -9.7412e-01],
                              [-8.1947e-01,  1.6657e-01,  1.0360e+00, -5.2035e-01,  1.5048e+00],
                              [-1.2174e+00, -1.0475e+00, -7.3149e-01, -4.9862e-01, -4.1088e-03],
                              [-1.7250e+00, -9.8784e-01, -1.0926e+00, -1.4733e-01,  1.2436e+00]],
                    
                             [[ 1.8212e+00, -9.0900e-02,  7.0699e-01, -1.1273e+00,  1.4671e-01],
                              [-6.3425e-01, -4.8812e-01,  9.4056e-01, -1.3953e+00, -1.1121e+00],
                              [-9.9047e-01,  6.5654e-01,  1.0493e-03, -2.7606e-01,  1.4477e+00],
                              [-1.2463e+00, -1.3972e-01,  1.0830e+00,  3.8693e-01, -6.2379e-01]],
                    
                             [[ 3.6846e-01,  7.0519e-01,  1.0511e+00,  1.2688e+00, -3.8694e-01],
                              [-3.0683e-01,  1.1013e-02,  7.5671e-01, -3.4247e-01,  1.0267e-01],
                              [-1.2282e-01, -5.6973e-01, -5.5059e-01, -8.8279e-01, -8.9531e-01],
                              [ 4.0072e-01, -1.1599e+00,  9.7310e-01, -9.1617e-01, -1.5341e+00]]]])
)
mean0 & mean1: tensor(1.3888) tensor(1.3888)


even_index_a0: tensor([[[0.6370, 0.5287, 0.8719, 0.8738, 0.8669, 0.0233, 0.4351, 0.3231],
         [0.0644, 0.1503, 0.6570, 0.0957, 0.1638, 0.7742, 0.9270, 0.6866]],

        [[0.6236, 0.2665, 0.3845, 0.1008, 0.1871, 0.0125, 0.1407, 0.8828],
         [0.8976, 0.8408, 0.2361, 0.9718, 0.7077, 0.0350, 0.1291, 0.7675]],

        [[0.0803, 0.0747, 0.2994, 0.3551, 0.2325, 0.2029, 0.0963, 0.3583],
         [0.5109, 0.3885, 0.7017, 0.5907, 0.3846, 0.3243, 0.7258, 0.9102]],

        [[0.5590, 0.5996, 0.9398, 0.4532, 0.7030, 0.0217, 0.5259, 0.0566],
         [0.2531, 0.5237, 0.3917, 0.0051, 0.9483, 0.1334, 0.5323, 0.3156]]])
even_index_a1: tensor([[[0.6370, 0.5287, 0.8719, 0.8738, 0.8669, 0.0233, 0.4351, 0.3231],
         [0.0644, 0.1503, 0.6570, 0.0957, 0.1638, 0.7742, 0.9270, 0.6866]],

        [[0.6236, 0.2665, 0.3845, 0.1008, 0.1871, 0.0125, 0.1407, 0.8828],
         [0.8976, 0.8408, 0.2361, 0.9718, 0.7077, 0.0350, 0.1291, 0.7675]],

        [[0.0803, 0.0747, 0.2994, 0.3551, 0.2325, 0.2029, 0.0963, 0.3583],
         [0.5109, 0.3885, 0.7017, 0.5907, 0.3846, 0.3243, 0.7258, 0.9102]],

        [[0.5590, 0.5996, 0.9398, 0.4532, 0.7030, 0.0217, 0.5259, 0.0566],
         [0.2531, 0.5237, 0.3917, 0.0051, 0.9483, 0.1334, 0.5323, 0.3156]]])
<Size 0x7f7dffcc5bd0>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7f7e000317d0>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.