Customized Operations For Different Fields

Here is another example of custom operations, implemented with both the native torch API and the treetensor API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import copy

import torch

import treetensor.torch as ttorch

# T: size of the first dimension of each sample's 'a' tensor (see get_data).
# B: number of samples in the demo batch.
T, B = 3, 4


def with_nativetensor(batch_):
    """Process a list of nested sample dicts using only plain ``torch`` calls.

    Entries of every sample are handled according to their key:

    * ``'a'``: cast to float; every second row (``[::2]``) is collected.
    * ``'b'``: cast to float and transformed as ``v ** 2 + 1.0``; the mean of
      the transformed tensor is collected.
    * ``'c'``: a nested dict. Its ``'d'`` entry is replaced by its float
      version and a ``'noise'`` tensor of shape ``(3, 4, 5)`` is attached.

    Any other key (at the top level or inside ``'c'``) is reported with an
    ``ignore keys`` message and left untouched.

    :param batch_: list of sample dicts; the nested ``'c'`` dicts are
        modified in place.
    :return: tuple ``(batch_, mean_b, even_index_a)`` where ``mean_b`` is the
        average of the per-sample ``'b'`` means and ``even_index_a`` stacks
        the collected ``'a'`` rows along a new leading dimension.
    :raises ZeroDivisionError: if no sample contains a ``'b'`` entry
        (for example when ``batch_`` is empty).
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            if k == 'a':
                even_index_a_list.append(v.float()[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                # Iterate over a snapshot so the nested dict can be updated
                # safely while its entries are being visited.
                for k1, v1 in list(v.items()):
                    if k1 == 'd':
                        # Store the converted tensor back; rebinding the loop
                        # variable alone would discard the conversion.
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
                # Attach the noise to the nested 'c' dict. 'd' only exists
                # inside 'c', so checking the top-level keys for 'd' would
                # never add it.
                v['noise'] = torch.randn(size=(3, 4, 5))
            else:
                print('ignore keys: {}'.format(k))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Process a list of nested sample dicts using the treetensor API.

    The samples are converted to tree tensors, stacked along a new leading
    dimension and cast to float. Field ``b`` is then replaced by
    ``b ** 2 + 1.0`` and a ``noise`` field with one ``(3, 4, 5)`` block per
    sample is attached under ``c``.

    :param batch_: list of sample dicts with fields ``a``, ``b`` and a nested
        ``c``.
    :return: tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the
        stacked tree split back into chunks of size 1 along dimension 0,
        ``mean_b`` is the mean of the transformed ``b`` field and
        ``even_index_a`` holds every second row of ``a`` for each sample.
    """
    # Size the noise from the actual input rather than a module constant, so
    # batches of any length get a matching leading dimension.
    batch_size = len(batch_)
    stacked = ttorch.stack([ttorch.tensor(b) for b in batch_]).float()
    stacked.b = ttorch.pow(stacked.b, 2) + 1.0
    stacked.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = stacked.b.mean()
    even_index_a = stacked.a[:, ::2]
    batch_ = ttorch.split(stacked, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Create one random sample as a nested dict.

    :return: dict with a ``(T, 8)`` uniform random tensor under ``'a'``, a
        ``(6,)`` uniform random tensor under ``'b'`` and, under ``'c'``, a
        nested dict whose ``'d'`` entry is a single random integer in
        ``[0, 10)``.
    """
    # The tensors are created in the same order as the keys appear so the
    # random number stream is consumed identically on every call.
    field_a = torch.rand(size=(T, 8))
    field_b = torch.rand(size=(6,))
    nested = {'d': torch.randint(0, 10, size=(1,))}
    return {'a': field_a, 'b': field_b, 'c': nested}


if __name__ == "__main__":
    # Run both implementations on independent deep copies of the same batch
    # so that neither sees the other's in-place modifications.
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    # Both implementations must agree on the mean of the transformed 'b'.
    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Compare the largest absolute element-wise difference. The absolute
    # value has to be taken before the max; otherwise large negative
    # differences would be hidden behind a maximum close to zero.
    assert torch.abs(even_index_a0 - even_index_a1).max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    # Each implementation returns one entry per sample.
    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following, and all the assertions should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
[{'a': tensor([[0.3313, 0.0882, 0.6579, 0.3501, 0.3909, 0.9383, 0.1561, 0.0703],
        [0.9961, 0.9322, 0.4953, 0.6302, 0.3226, 0.7668, 0.9302, 0.4677],
        [0.5480, 0.5045, 0.5743, 0.8292, 0.3024, 0.1904, 0.6685, 0.9490]]), 'b': tensor([0.2765, 0.8172, 0.9692, 0.5797, 0.9039, 0.0055]), 'c': {'d': tensor([1])}}, {'a': tensor([[0.6972, 0.5427, 0.0417, 0.8488, 0.4968, 0.6170, 0.8957, 0.2799],
        [0.0127, 0.9730, 0.6826, 0.2333, 0.2428, 0.0609, 0.3715, 0.0104],
        [0.3351, 0.7599, 0.1900, 0.9579, 0.9783, 0.0326, 0.1451, 0.7471]]), 'b': tensor([0.9764, 0.1874, 0.1346, 0.8179, 0.3964, 0.2178]), 'c': {'d': tensor([1])}}, {'a': tensor([[0.7080, 0.8097, 0.9506, 0.2266, 0.4898, 0.4191, 0.6291, 0.4642],
        [0.8838, 0.5935, 0.3174, 0.6670, 0.9232, 0.3604, 0.4453, 0.5772],
        [0.2580, 0.4739, 0.1686, 0.3093, 0.0614, 0.4859, 0.4071, 0.2028]]), 'b': tensor([0.6880, 0.8413, 0.1673, 0.8797, 0.0595, 0.2570]), 'c': {'d': tensor([7])}}, {'a': tensor([[0.8292, 0.1015, 0.0075, 0.1469, 0.4491, 0.2319, 0.8218, 0.0663],
        [0.3731, 0.4277, 0.7806, 0.3766, 0.4146, 0.2895, 0.2871, 0.9064],
        [0.4393, 0.0466, 0.6933, 0.3058, 0.1305, 0.2063, 0.5092, 0.3820]]), 'b': tensor([0.7516, 0.2438, 0.9755, 0.6713, 0.4114, 0.0472]), 'c': {'d': tensor([2])}}]



(<Tensor 0x7f8fe5cebfd0>
├── 'a' --> tensor([[[0.3313, 0.0882, 0.6579, 0.3501, 0.3909, 0.9383, 0.1561, 0.0703],
│                    [0.9961, 0.9322, 0.4953, 0.6302, 0.3226, 0.7668, 0.9302, 0.4677],
│                    [0.5480, 0.5045, 0.5743, 0.8292, 0.3024, 0.1904, 0.6685, 0.9490]]])
├── 'b' --> tensor([[1.0764, 1.6678, 1.9394, 1.3360, 1.8170, 1.0000]])
└── 'c' --> <Tensor 0x7f8fe5cebf90>
    ├── 'd' --> tensor([[1.]])
    └── 'noise' --> tensor([[[[-0.0621, -0.7930,  0.7569, -1.0191, -0.3130],
                              [-0.4686, -0.1089, -0.9081, -1.0446,  0.4272],
                              [ 1.0152, -0.3799,  0.9694,  1.8244,  0.5693],
                              [-1.1199,  0.7588, -0.7775,  0.7499,  0.1142]],
                    
                             [[ 0.5598, -0.2789,  1.1207, -1.2933,  0.1816],
                              [-0.5194,  0.0579, -0.4397,  0.5926, -0.9488],
                              [-1.6616, -1.4776, -2.0146,  0.5458, -1.4844],
                              [ 0.5248,  0.5729,  0.5522, -0.2959,  1.0412]],
                    
                             [[ 1.9960,  0.7783, -1.4010,  0.5961,  0.2503],
                              [-1.6178, -0.7815,  0.1777,  0.3704, -0.1290],
                              [ 0.5481,  0.1488,  1.8858,  0.2975, -0.7352],
                              [-0.4605,  0.6365, -0.3557, -0.0358, -0.7355]]]])
, <Tensor 0x7f8fe5cf2090>
├── 'a' --> tensor([[[0.6972, 0.5427, 0.0417, 0.8488, 0.4968, 0.6170, 0.8957, 0.2799],
│                    [0.0127, 0.9730, 0.6826, 0.2333, 0.2428, 0.0609, 0.3715, 0.0104],
│                    [0.3351, 0.7599, 0.1900, 0.9579, 0.9783, 0.0326, 0.1451, 0.7471]]])
├── 'b' --> tensor([[1.9533, 1.0351, 1.0181, 1.6690, 1.1571, 1.0474]])
└── 'c' --> <Tensor 0x7f8fe5cf2050>
    ├── 'd' --> tensor([[1.]])
    └── 'noise' --> tensor([[[[-0.0456,  0.3777,  0.1782, -0.0794,  1.8114],
                              [ 0.4518,  0.2652,  1.5227,  0.3376, -0.0074],
                              [ 0.3699,  0.7519,  1.3634,  0.3915, -0.4720],
                              [ 1.6225,  0.4938,  0.2811, -1.6052,  0.1634]],
                    
                             [[-0.8232,  1.2443,  0.6446,  0.0268,  0.2494],
                              [ 1.1742, -0.0170,  1.2971,  1.3382,  0.6993],
                              [-1.4969,  1.2580, -0.3939, -1.3453,  0.6606],
                              [-0.1392,  0.5910,  0.4448,  0.6361,  0.0247]],
                    
                             [[ 1.0457,  0.7729,  0.7605,  0.1507,  0.0617],
                              [ 0.9171,  0.3275, -0.8157,  1.2837,  1.1959],
                              [-0.5456,  1.5890, -1.9056,  2.2679, -0.6888],
                              [ 1.2832, -1.3582,  1.5985, -1.1178, -1.9899]]]])
, <Tensor 0x7f8fe5cf2110>
├── 'a' --> tensor([[[0.7080, 0.8097, 0.9506, 0.2266, 0.4898, 0.4191, 0.6291, 0.4642],
│                    [0.8838, 0.5935, 0.3174, 0.6670, 0.9232, 0.3604, 0.4453, 0.5772],
│                    [0.2580, 0.4739, 0.1686, 0.3093, 0.0614, 0.4859, 0.4071, 0.2028]]])
├── 'b' --> tensor([[1.4734, 1.7078, 1.0280, 1.7738, 1.0035, 1.0661]])
└── 'c' --> <Tensor 0x7f8fe5cf20d0>
    ├── 'd' --> tensor([[7.]])
    └── 'noise' --> tensor([[[[-1.3803, -0.9311,  0.8792, -0.4538, -0.4213],
                              [ 1.0475, -0.0281,  0.3050, -0.1078, -0.1914],
                              [-1.8961, -1.8887, -0.3832, -0.9808,  0.3060],
                              [-0.3629, -1.4063, -0.2881,  0.5847,  0.7290]],
                    
                             [[-0.0598, -0.5249, -0.6472, -0.1638, -1.7160],
                              [-0.6548,  0.7673,  0.9549, -1.3975,  1.1515],
                              [-0.8544,  0.2624,  1.9701,  2.6244,  0.5877],
                              [ 0.4577, -1.2086, -0.6552,  0.3812,  0.4401]],
                    
                             [[-1.3914, -0.6348, -0.5649,  1.3535,  0.0401],
                              [-1.5634, -0.3972, -1.2206,  0.5179, -0.5741],
                              [ 0.9142,  0.6368,  0.7496,  1.7236, -0.4995],
                              [-0.4428, -1.6781,  0.7597,  1.8426, -0.5079]]]])
, <Tensor 0x7f8fe5cf2190>
├── 'a' --> tensor([[[0.8292, 0.1015, 0.0075, 0.1469, 0.4491, 0.2319, 0.8218, 0.0663],
│                    [0.3731, 0.4277, 0.7806, 0.3766, 0.4146, 0.2895, 0.2871, 0.9064],
│                    [0.4393, 0.0466, 0.6933, 0.3058, 0.1305, 0.2063, 0.5092, 0.3820]]])
├── 'b' --> tensor([[1.5649, 1.0594, 1.9516, 1.4506, 1.1692, 1.0022]])
└── 'c' --> <Tensor 0x7f8fe5cf2150>
    ├── 'd' --> tensor([[2.]])
    └── 'noise' --> tensor([[[[-1.5034,  2.4455, -1.2657, -0.1664,  1.2356],
                              [-1.0091, -1.7266, -0.4433, -1.2735,  0.6320],
                              [ 0.9520,  0.5521,  1.5311,  0.1378, -0.6468],
                              [-0.9342, -0.3182,  0.5041,  1.3123,  1.4483]],
                    
                             [[-0.1654, -0.1329, -0.0582,  0.7893, -0.5403],
                              [ 1.5499,  0.1751, -0.1191, -0.1223, -0.5811],
                              [ 0.9971,  0.1710, -0.7453, -0.1878, -1.1388],
                              [-1.9537, -0.2508,  0.3788, -0.8452, -0.5041]],
                    
                             [[ 0.9284,  0.5913,  1.1612, -0.1337,  0.3932],
                              [ 0.0346, -0.2545, -2.0326, -1.4314, -0.3497],
                              [-0.6397,  1.2387,  1.2089, -0.1503,  0.1883],
                              [-0.0752,  0.6694,  1.0501,  1.3362, -0.5575]]]])
)
mean0 & mean1: tensor(1.3736) tensor(1.3736)


even_index_a0: tensor([[[0.3313, 0.0882, 0.6579, 0.3501, 0.3909, 0.9383, 0.1561, 0.0703],
         [0.5480, 0.5045, 0.5743, 0.8292, 0.3024, 0.1904, 0.6685, 0.9490]],

        [[0.6972, 0.5427, 0.0417, 0.8488, 0.4968, 0.6170, 0.8957, 0.2799],
         [0.3351, 0.7599, 0.1900, 0.9579, 0.9783, 0.0326, 0.1451, 0.7471]],

        [[0.7080, 0.8097, 0.9506, 0.2266, 0.4898, 0.4191, 0.6291, 0.4642],
         [0.2580, 0.4739, 0.1686, 0.3093, 0.0614, 0.4859, 0.4071, 0.2028]],

        [[0.8292, 0.1015, 0.0075, 0.1469, 0.4491, 0.2319, 0.8218, 0.0663],
         [0.4393, 0.0466, 0.6933, 0.3058, 0.1305, 0.2063, 0.5092, 0.3820]]])
even_index_a1: tensor([[[0.3313, 0.0882, 0.6579, 0.3501, 0.3909, 0.9383, 0.1561, 0.0703],
         [0.5480, 0.5045, 0.5743, 0.8292, 0.3024, 0.1904, 0.6685, 0.9490]],

        [[0.6972, 0.5427, 0.0417, 0.8488, 0.4968, 0.6170, 0.8957, 0.2799],
         [0.3351, 0.7599, 0.1900, 0.9579, 0.9783, 0.0326, 0.1451, 0.7471]],

        [[0.7080, 0.8097, 0.9506, 0.2266, 0.4898, 0.4191, 0.6291, 0.4642],
         [0.2580, 0.4739, 0.1686, 0.3093, 0.0614, 0.4859, 0.4071, 0.2028]],

        [[0.8292, 0.1015, 0.0075, 0.1469, 0.4491, 0.2319, 0.8218, 0.0663],
         [0.4393, 0.0466, 0.6933, 0.3058, 0.1305, 0.2063, 0.5092, 0.3820]]])
<Size 0x7f8fe5cebed0>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7f8fe5cebe90>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.