Customized Operations For Different Fields

Here is another example of custom operations, implemented with both the native torch API and the treetensor API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import copy

import torch

import treetensor.torch as ttorch

T, B = 3, 4


def with_nativetensor(batch_):
    """Process a batch of nested sample dicts with the native torch API only.

    Every sample is handled field by field:

    * ``'a'`` is cast to float and every second slice along its first
      dimension (``[::2]``) is collected.
    * ``'b'`` is cast to float, transformed with ``x ** 2 + 1`` and reduced
      to its mean.
    * ``'c'`` is a nested dict: its ``'d'`` entry is cast to float and stored
      back, and a new ``'noise'`` entry of shape ``(3, 4, 5)`` drawn from
      ``torch.randn`` is added.

    Any other key (top-level or inside ``'c'``) is reported with an
    ``ignore keys`` message. The ``'a'`` and ``'b'`` entries stored in the
    samples are left untouched; only the ``'c'`` sub-dicts are modified
    in place.

    :param batch_: Sequence of sample dicts.
    :return: Tuple ``(batch_, mean_b, even_index_a)`` where ``mean_b`` is the
        average of the per-sample ``'b'`` means and ``even_index_a`` is the
        collected ``'a'`` slices stacked along a new leading dimension.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            if k == 'a':
                even_index_a_list.append(v.float()[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                # Iterate over a snapshot of the keys: the sub-dict is
                # updated while it is being walked.
                for k1 in list(v):
                    if k1 == 'd':
                        # Store the converted tensor back; rebinding a local
                        # name alone would silently drop the cast.
                        v[k1] = v[k1].float()
                    else:
                        print('ignore keys: {}'.format(k1))
                # The noise lives next to 'd' inside the nested 'c' dict.
                v['noise'] = torch.randn(size=(3, 4, 5))
            else:
                print('ignore keys: {}'.format(k))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Process a batch of nested sample dicts with the treetensor API.

    The samples are wrapped with ``ttorch.tensor``, stacked into a single
    tree, cast to float as a whole, and then updated field by field:
    ``b`` becomes ``b ** 2 + 1`` and a ``noise`` entry drawn from
    ``ttorch.randn`` is attached under ``c``. Finally the stacked tree is
    split back along dimension 0 into chunks of size 1.

    :param batch_: Iterable of sample dicts accepted by ``ttorch.tensor``.
    :return: Tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the
        result of ``ttorch.split``, ``mean_b`` is the mean of the transformed
        ``b`` field and ``even_index_a`` is ``a[:, ::2]`` of the stacked tree.
    """
    samples = [ttorch.tensor(b) for b in batch_]
    # Take the batch size from the input itself instead of a module-level
    # constant, so the noise always matches the stacked leading dimension.
    batch_size = len(samples)
    batch_ = ttorch.stack(samples)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Build one random sample.

    :return: Dict with a ``(T, 8)`` uniform random tensor under ``'a'``, a
        ``(6,)`` uniform random tensor under ``'b'`` and, under ``'c'``, a
        nested dict whose ``'d'`` entry holds one random integer drawn from
        ``[0, 10)``.
    """
    # Fields are filled in the same order as they appear in the result so
    # the random number generator is consumed identically.
    sample = {}
    sample['a'] = torch.rand(size=(T, 8))
    sample['b'] = torch.rand(size=(6,))
    sample['c'] = dict(d=torch.randint(0, 10, size=(1,)))
    return sample


if __name__ == "__main__":
    # Run both implementations on the same random batch and check that they
    # agree. Each one receives its own deep copy of the input.
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Compare the largest absolute element-wise difference. Taking abs()
    # only after max() would let large negative differences go unnoticed.
    assert torch.abs(even_index_a0 - even_index_a1).max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following, and all the assertions should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
[{'a': tensor([[0.2528, 0.4177, 0.5927, 0.7439, 0.2610, 0.9414, 0.3309, 0.8179],
        [0.2426, 0.5285, 0.7025, 0.1007, 0.7717, 0.2163, 0.5488, 0.6384],
        [0.2763, 0.1325, 0.2971, 0.6536, 0.6680, 0.6659, 0.1535, 0.6603]]), 'b': tensor([0.8131, 0.1540, 0.3999, 0.9620, 0.0247, 0.5228]), 'c': {'d': tensor([3])}}, {'a': tensor([[0.0930, 0.9039, 0.7044, 0.9361, 0.9592, 0.0232, 0.8527, 0.8502],
        [0.9800, 0.5632, 0.7258, 0.5329, 0.4701, 0.6252, 0.6134, 0.4340],
        [0.6840, 0.1951, 0.2539, 0.0763, 0.3272, 0.9452, 0.6882, 0.1001]]), 'b': tensor([0.8683, 0.5246, 0.5320, 0.4137, 0.8252, 0.5588]), 'c': {'d': tensor([3])}}, {'a': tensor([[0.5771, 0.9907, 0.1408, 0.6211, 0.2264, 0.6072, 0.5538, 0.1231],
        [0.1719, 0.5936, 0.1351, 0.8051, 0.5816, 0.8645, 0.4902, 0.6489],
        [0.1467, 0.7644, 0.3178, 0.4616, 0.6666, 0.2651, 0.4480, 0.3887]]), 'b': tensor([0.4217, 0.0332, 0.2061, 0.5351, 0.7510, 0.2466]), 'c': {'d': tensor([7])}}, {'a': tensor([[0.9133, 0.8223, 0.1066, 0.8299, 0.9051, 0.4557, 0.0785, 0.0093],
        [0.8969, 0.3639, 0.0384, 0.0895, 0.5555, 0.6507, 0.6078, 0.4365],
        [0.1915, 0.2788, 0.9265, 0.0536, 0.4257, 0.9293, 0.6852, 0.8060]]), 'b': tensor([0.3471, 0.9971, 0.1674, 0.5123, 0.2763, 0.9870]), 'c': {'d': tensor([0])}}]



(<Tensor 0x7f17f2932e50>
├── 'a' --> tensor([[[0.2528, 0.4177, 0.5927, 0.7439, 0.2610, 0.9414, 0.3309, 0.8179],
│                    [0.2426, 0.5285, 0.7025, 0.1007, 0.7717, 0.2163, 0.5488, 0.6384],
│                    [0.2763, 0.1325, 0.2971, 0.6536, 0.6680, 0.6659, 0.1535, 0.6603]]])
├── 'b' --> tensor([[1.6611, 1.0237, 1.1599, 1.9255, 1.0006, 1.2733]])
└── 'c' --> <Tensor 0x7f17f2932e10>
    ├── 'd' --> tensor([[3.]])
    └── 'noise' --> tensor([[[[ 0.6194, -0.8291, -1.1650,  0.0764,  0.2556],
                              [-1.4103,  1.0407, -0.1023,  0.3174,  0.4812],
                              [ 0.6355, -0.9907, -0.1500, -1.4863,  0.3427],
                              [-0.1761,  0.6150, -1.6094,  0.4111,  1.8947]],
                    
                             [[ 0.8455,  1.8923,  0.9399, -1.4162, -0.9732],
                              [-1.6211,  2.3156,  1.2249,  0.8859, -0.5506],
                              [ 0.1431, -0.6609, -2.5998,  0.3779,  0.0424],
                              [-1.0839, -0.2232, -0.3266,  0.8982, -0.1456]],
                    
                             [[-1.4125,  0.6481,  1.9928,  0.3429, -0.6100],
                              [ 0.9297,  0.6323,  0.8506, -0.6363, -0.6004],
                              [ 0.5461,  0.0813, -0.8113, -0.0896,  0.6426],
                              [-0.0986, -0.3976, -0.1211, -1.0975, -0.0321]]]])
, <Tensor 0x7f17f2932ed0>
├── 'a' --> tensor([[[0.0930, 0.9039, 0.7044, 0.9361, 0.9592, 0.0232, 0.8527, 0.8502],
│                    [0.9800, 0.5632, 0.7258, 0.5329, 0.4701, 0.6252, 0.6134, 0.4340],
│                    [0.6840, 0.1951, 0.2539, 0.0763, 0.3272, 0.9452, 0.6882, 0.1001]]])
├── 'b' --> tensor([[1.7539, 1.2752, 1.2830, 1.1712, 1.6810, 1.3123]])
└── 'c' --> <Tensor 0x7f17f2932e90>
    ├── 'd' --> tensor([[3.]])
    └── 'noise' --> tensor([[[[ 2.4444,  1.1720,  0.8632,  0.4761,  1.1405],
                              [-0.6512,  0.1339,  1.3094, -2.2030, -1.2528],
                              [-0.6839, -1.4079, -1.0550,  0.2930, -0.3839],
                              [-0.7287,  1.4465,  0.3683, -0.1231, -0.1292]],
                    
                             [[-2.0101,  1.7442, -0.0237,  0.9152,  0.0610],
                              [-0.2823,  0.9592, -0.6154, -0.9371, -0.7711],
                              [ 2.1444, -1.3623, -0.5261,  1.2471, -1.3095],
                              [-1.0510, -0.4227, -0.1627,  1.2090, -0.0536]],
                    
                             [[-0.3722,  0.4047, -0.2497, -1.1743,  2.9188],
                              [-0.2216, -1.2383,  2.9361, -0.0036, -0.8625],
                              [-1.2923,  0.1095, -0.5218,  1.1455, -0.3562],
                              [-1.0939,  0.1056,  0.3968,  0.5477,  0.5038]]]])
, <Tensor 0x7f17f2932f50>
├── 'a' --> tensor([[[0.5771, 0.9907, 0.1408, 0.6211, 0.2264, 0.6072, 0.5538, 0.1231],
│                    [0.1719, 0.5936, 0.1351, 0.8051, 0.5816, 0.8645, 0.4902, 0.6489],
│                    [0.1467, 0.7644, 0.3178, 0.4616, 0.6666, 0.2651, 0.4480, 0.3887]]])
├── 'b' --> tensor([[1.1779, 1.0011, 1.0425, 1.2864, 1.5640, 1.0608]])
└── 'c' --> <Tensor 0x7f17f2932f10>
    ├── 'd' --> tensor([[7.]])
    └── 'noise' --> tensor([[[[ 0.6417,  0.7936,  1.7457, -0.3449,  0.0586],
                              [ 0.5191, -0.5138,  1.5452, -0.3986,  0.8913],
                              [-1.1714, -0.5297,  0.2113, -0.9917, -1.0074],
                              [-0.5967, -0.9461, -0.6241, -1.3767, -0.9106]],
                    
                             [[ 0.2880, -0.9860, -1.7877, -1.0784,  0.0903],
                              [-0.7338,  1.2487,  0.9152, -0.1813, -0.6848],
                              [ 0.8098, -2.6129,  0.6112,  0.5383, -0.5587],
                              [-1.0908, -0.5596,  0.2072,  2.1724,  0.2787]],
                    
                             [[ 1.5288, -0.3257, -1.9534,  2.3294,  0.4874],
                              [-0.2468,  0.0650, -0.8965,  0.9325,  0.1155],
                              [ 0.3582, -1.1133,  0.1037,  0.2556,  1.1160],
                              [-0.0853,  2.2644, -1.0572,  0.4531,  1.0887]]]])
, <Tensor 0x7f17f2932fd0>
├── 'a' --> tensor([[[0.9133, 0.8223, 0.1066, 0.8299, 0.9051, 0.4557, 0.0785, 0.0093],
│                    [0.8969, 0.3639, 0.0384, 0.0895, 0.5555, 0.6507, 0.6078, 0.4365],
│                    [0.1915, 0.2788, 0.9265, 0.0536, 0.4257, 0.9293, 0.6852, 0.8060]]])
├── 'b' --> tensor([[1.1205, 1.9943, 1.0280, 1.2624, 1.0764, 1.9742]])
└── 'c' --> <Tensor 0x7f17f2932f90>
    ├── 'd' --> tensor([[0.]])
    └── 'noise' --> tensor([[[[-0.6496,  0.9187, -0.8898,  1.7309, -0.3362],
                              [-0.8324, -0.0810,  0.5725,  0.9787,  1.0387],
                              [ 0.5681,  0.6451, -0.0196,  1.3154,  0.1205],
                              [-0.9782,  0.6803,  0.6962, -0.9089,  0.2468]],
                    
                             [[-0.1351,  0.5176, -0.2932, -0.0736,  0.2269],
                              [-0.2297,  0.9318, -0.8866, -0.4340,  1.6622],
                              [-0.1321,  1.8283, -1.5091,  0.7979,  0.5349],
                              [-2.2566,  0.9935, -0.3138, -2.2293, -0.1792]],
                    
                             [[ 0.6219,  0.1924, -0.2116, -1.1552,  0.4944],
                              [ 0.1520, -1.2640, -0.4181,  0.8976,  0.3313],
                              [-1.1501,  0.6571, -0.2058,  0.6300, -0.6956],
                              [ 1.2968,  0.3220, -0.1916,  0.8898, -0.8460]]]])
)
mean0 & mean1: tensor(1.3379) tensor(1.3379)


even_index_a0: tensor([[[0.2528, 0.4177, 0.5927, 0.7439, 0.2610, 0.9414, 0.3309, 0.8179],
         [0.2763, 0.1325, 0.2971, 0.6536, 0.6680, 0.6659, 0.1535, 0.6603]],

        [[0.0930, 0.9039, 0.7044, 0.9361, 0.9592, 0.0232, 0.8527, 0.8502],
         [0.6840, 0.1951, 0.2539, 0.0763, 0.3272, 0.9452, 0.6882, 0.1001]],

        [[0.5771, 0.9907, 0.1408, 0.6211, 0.2264, 0.6072, 0.5538, 0.1231],
         [0.1467, 0.7644, 0.3178, 0.4616, 0.6666, 0.2651, 0.4480, 0.3887]],

        [[0.9133, 0.8223, 0.1066, 0.8299, 0.9051, 0.4557, 0.0785, 0.0093],
         [0.1915, 0.2788, 0.9265, 0.0536, 0.4257, 0.9293, 0.6852, 0.8060]]])
even_index_a1: tensor([[[0.2528, 0.4177, 0.5927, 0.7439, 0.2610, 0.9414, 0.3309, 0.8179],
         [0.2763, 0.1325, 0.2971, 0.6536, 0.6680, 0.6659, 0.1535, 0.6603]],

        [[0.0930, 0.9039, 0.7044, 0.9361, 0.9592, 0.0232, 0.8527, 0.8502],
         [0.6840, 0.1951, 0.2539, 0.0763, 0.3272, 0.9452, 0.6882, 0.1001]],

        [[0.5771, 0.9907, 0.1408, 0.6211, 0.2264, 0.6072, 0.5538, 0.1231],
         [0.1467, 0.7644, 0.3178, 0.4616, 0.6666, 0.2651, 0.4480, 0.3887]],

        [[0.9133, 0.8223, 0.1066, 0.8299, 0.9051, 0.4557, 0.0785, 0.0093],
         [0.1915, 0.2788, 0.9265, 0.0536, 0.4257, 0.9293, 0.6852, 0.8060]]])
<Size 0x7f17f2932d50>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7f17f2932d10>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.