Customized Operations For Different Fields

Here is another example of custom operations, implemented with both the native torch API and the treetensor API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import copy

import torch

import treetensor.torch as ttorch

# T: number of rows in each sample's 'a' tensor (see get_data).
T = 3
# B: number of samples generated for the demo batch.
B = 4


def with_nativetensor(batch_):
    """Apply the per-field operations using only the native torch API.

    Each sample in ``batch_`` is a dict that is modified in place:

    * ``'a'`` is cast to float; every second slice along its first dimension
      is collected.
    * ``'b'`` is cast to float and replaced by ``b ** 2 + 1``; its mean is
      collected.
    * ``'c'`` is a nested dict: its ``'d'`` entry is cast to float, other
      nested keys are reported and left alone, and a fresh ``'noise'`` tensor
      of shape ``(3, 4, 5)`` is attached.
    * any other top-level key is reported and left alone.

    Args:
        batch_: sequence of sample dicts as described above.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the
        same (now modified) sequence, ``mean_b`` is the average of the
        per-sample means of the transformed ``'b'``, and ``even_index_a`` is
        the collected ``'a'`` slices stacked along a new leading dimension.

    Raises:
        ZeroDivisionError: if no sample contains a ``'b'`` entry (this
            includes an empty batch).
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        # Iterate over a snapshot because entries are reassigned on the way.
        for k, v in list(sample.items()):
            if k == 'a':
                v = v.float()
                # Store the converted tensor; rebinding ``v`` alone would
                # leave the sample untouched.
                sample[k] = v
                even_index_a_list.append(v[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in list(v.items()):
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
                # The noise belongs inside the nested 'c' dict; no top-level
                # key is ever named 'd', so keying on 'd' here never fired.
                v['noise'] = torch.randn(size=(3, 4, 5))
            else:
                print('ignore keys: {}'.format(k))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the same per-field operations using the treetensor API.

    The samples are converted to tree tensors and stacked into one batched
    tree, so each operation is written once for the whole batch instead of
    once per sample and per key.

    Args:
        batch_: iterable of nested dicts of tensors with fields ``a``, ``b``
            and ``c`` (all samples presumably share one structure, since
            they are stacked together).

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the
        result of splitting the processed tree back into chunks of size 1
        along dim 0, ``mean_b`` is the mean of the transformed ``b`` field,
        and ``even_index_a`` is every second slice of ``a`` along dim 1.
    """
    trees = [ttorch.tensor(b) for b in batch_]
    # Size the noise from the actual number of samples rather than a module
    # constant, so batches of any length get a matching leading dimension.
    batch_size = len(trees)
    batch_ = ttorch.stack(trees)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    # Dim 0 is the stacked batch dimension, so the stride applies to dim 1.
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Create one random sample for the demo.

    Returns:
        A dict with ``'a'`` (uniform random, shape ``(T, 8)``), ``'b'``
        (uniform random, shape ``(6,)``) and ``'c'``, a nested dict whose
        ``'d'`` entry is a single random integer drawn from ``[0, 10)``.
    """
    sample = {}
    sample['a'] = torch.rand(size=(T, 8))
    sample['b'] = torch.rand(size=(6,))
    sample['c'] = {'d': torch.randint(0, 10, size=(1,))}
    return sample


if __name__ == "__main__":
    # Feed both implementations independent deep copies of the same batch.
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    # Both implementations must agree on the mean of the transformed 'b'.
    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Take the absolute value BEFORE the max: with the order reversed, a
    # large negative difference would be masked by a zero/positive maximum.
    assert torch.abs(even_index_a0 - even_index_a1).max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    # Each implementation returns one entry per sample.
    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following, and all the assertions should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
[{'a': tensor([[0.4891, 0.4983, 0.3518, 0.3453, 0.0062, 0.1869, 0.2393, 0.2685],
        [0.7944, 0.0033, 0.2901, 0.3635, 0.8537, 0.7105, 0.4277, 0.5171],
        [0.6045, 0.5730, 0.0281, 0.1397, 0.0980, 0.1078, 0.9350, 0.8590]]), 'b': tensor([0.8431, 0.4037, 0.4630, 0.6901, 0.8469, 0.8049]), 'c': {'d': tensor([1])}}, {'a': tensor([[0.3522, 0.0802, 0.7094, 0.5263, 0.7597, 0.8375, 0.3589, 0.5269],
        [0.8205, 0.3569, 0.1815, 0.6871, 0.1704, 0.7794, 0.2382, 0.5780],
        [0.3023, 0.4060, 0.6800, 0.9777, 0.7968, 0.7725, 0.5000, 0.6039]]), 'b': tensor([0.4178, 0.1959, 0.6307, 0.4192, 0.5912, 0.4633]), 'c': {'d': tensor([1])}}, {'a': tensor([[0.0300, 0.9757, 0.0152, 0.4424, 0.7109, 0.1433, 0.2944, 0.4652],
        [0.4663, 0.5075, 0.5755, 0.3119, 0.2193, 0.4129, 0.9700, 0.2495],
        [0.6488, 0.5735, 0.0649, 0.0577, 0.4324, 0.4917, 0.7781, 0.5089]]), 'b': tensor([0.9035, 0.7308, 0.0986, 0.7287, 0.1370, 0.2584]), 'c': {'d': tensor([6])}}, {'a': tensor([[0.7394, 0.7904, 0.2096, 0.4392, 0.8947, 0.0067, 0.0271, 0.7526],
        [0.8193, 0.4045, 0.2795, 0.3559, 0.5538, 0.6154, 0.8504, 0.7521],
        [0.8375, 0.7244, 0.8103, 0.6262, 0.1807, 0.3922, 0.1483, 0.8142]]), 'b': tensor([0.1548, 0.3610, 0.9881, 0.6237, 0.6998, 0.3502]), 'c': {'d': tensor([1])}}]



(<Tensor 0x7f996e898f90>
├── 'a' --> tensor([[[0.4891, 0.4983, 0.3518, 0.3453, 0.0062, 0.1869, 0.2393, 0.2685],
│                    [0.7944, 0.0033, 0.2901, 0.3635, 0.8537, 0.7105, 0.4277, 0.5171],
│                    [0.6045, 0.5730, 0.0281, 0.1397, 0.0980, 0.1078, 0.9350, 0.8590]]])
├── 'b' --> tensor([[1.7108, 1.1630, 1.2144, 1.4762, 1.7172, 1.6479]])
└── 'c' --> <Tensor 0x7f996e898f50>
    ├── 'd' --> tensor([[1.]])
    └── 'noise' --> tensor([[[[ 1.0692,  0.5142,  1.2165,  0.7112,  0.1560],
                              [-1.0644, -1.2981,  1.5262, -0.6554,  0.5670],
                              [-0.9887, -0.9253,  1.0349, -0.2842,  0.1049],
                              [ 0.5471,  0.0200, -1.4839,  1.3312,  0.6531]],
                    
                             [[ 0.7144, -1.1904,  0.3611, -0.1194, -0.6089],
                              [-0.3058, -0.0769,  1.0874, -0.5698,  0.4647],
                              [-1.2264, -0.2936, -0.4243,  0.2722,  0.8972],
                              [-0.5769,  0.3364,  0.4670, -1.3205, -1.0617]],
                    
                             [[ 0.7087,  0.8764,  1.8436,  3.3415, -0.9340],
                              [ 0.4428, -1.1900, -0.8580,  1.1566,  0.1243],
                              [-2.5001,  1.7994, -0.6013, -0.4448, -0.4889],
                              [ 1.3573,  0.0114,  0.3192, -0.4993, -0.0766]]]])
, <Tensor 0x7f996e89f050>
├── 'a' --> tensor([[[0.3522, 0.0802, 0.7094, 0.5263, 0.7597, 0.8375, 0.3589, 0.5269],
│                    [0.8205, 0.3569, 0.1815, 0.6871, 0.1704, 0.7794, 0.2382, 0.5780],
│                    [0.3023, 0.4060, 0.6800, 0.9777, 0.7968, 0.7725, 0.5000, 0.6039]]])
├── 'b' --> tensor([[1.1745, 1.0384, 1.3978, 1.1757, 1.3495, 1.2147]])
└── 'c' --> <Tensor 0x7f996e898fd0>
    ├── 'd' --> tensor([[1.]])
    └── 'noise' --> tensor([[[[-0.8616,  1.0088,  0.4363,  0.3854, -0.4530],
                              [-1.3458,  0.2346,  0.0978, -1.1845, -0.2845],
                              [ 1.4661,  0.0099,  1.9448,  0.9461, -0.6596],
                              [ 0.5714, -0.4801,  0.3029,  1.3462, -0.5059]],
                    
                             [[-0.8858, -0.2870, -0.0291,  0.0954,  0.0958],
                              [-1.1399,  1.2167, -0.1784, -1.1622, -0.8205],
                              [-1.1587, -0.3150, -0.6488, -0.4852,  1.7651],
                              [-0.5875, -0.0582, -1.3466,  0.5439, -1.4773]],
                    
                             [[ 0.9459, -1.6996, -0.1797, -1.3600,  1.0270],
                              [-0.0406,  0.7433,  0.5458,  1.1187,  2.1452],
                              [ 0.5953,  0.8925, -1.6194, -1.1096, -0.1076],
                              [-0.4481, -0.3585, -0.2967, -1.3786, -0.2491]]]])
, <Tensor 0x7f996e89f0d0>
├── 'a' --> tensor([[[0.0300, 0.9757, 0.0152, 0.4424, 0.7109, 0.1433, 0.2944, 0.4652],
│                    [0.4663, 0.5075, 0.5755, 0.3119, 0.2193, 0.4129, 0.9700, 0.2495],
│                    [0.6488, 0.5735, 0.0649, 0.0577, 0.4324, 0.4917, 0.7781, 0.5089]]])
├── 'b' --> tensor([[1.8163, 1.5341, 1.0097, 1.5310, 1.0188, 1.0668]])
└── 'c' --> <Tensor 0x7f996e89f090>
    ├── 'd' --> tensor([[6.]])
    └── 'noise' --> tensor([[[[-0.0845, -0.2490,  0.4574, -0.2055, -0.5052],
                              [ 0.6392,  1.4586, -2.1247, -0.0794,  0.9077],
                              [-1.6232,  1.1221, -0.2364, -1.6875, -2.4838],
                              [-0.3776, -0.6827,  1.1125, -1.3828,  0.2339]],
                    
                             [[ 0.9376, -0.8476, -0.8220, -0.9632,  1.1160],
                              [ 0.3020,  0.9054,  0.5701, -0.7925, -0.4109],
                              [-1.2571, -0.3713, -0.7708,  0.1949, -0.5856],
                              [ 1.4660,  0.1360, -1.9473, -1.5097, -1.0084]],
                    
                             [[ 0.7180,  2.8065, -0.1474,  1.2503,  0.3180],
                              [-2.3330, -1.0297,  0.7406,  0.8737, -1.4619],
                              [ 1.4348,  1.0801,  1.4215,  0.2860,  1.4047],
                              [-0.7142, -0.1309, -3.4383, -0.4981, -0.7635]]]])
, <Tensor 0x7f996e89f150>
├── 'a' --> tensor([[[0.7394, 0.7904, 0.2096, 0.4392, 0.8947, 0.0067, 0.0271, 0.7526],
│                    [0.8193, 0.4045, 0.2795, 0.3559, 0.5538, 0.6154, 0.8504, 0.7521],
│                    [0.8375, 0.7244, 0.8103, 0.6262, 0.1807, 0.3922, 0.1483, 0.8142]]])
├── 'b' --> tensor([[1.0240, 1.1303, 1.9763, 1.3890, 1.4897, 1.1226]])
└── 'c' --> <Tensor 0x7f996e89f110>
    ├── 'd' --> tensor([[1.]])
    └── 'noise' --> tensor([[[[-0.1892, -0.5689,  0.6290,  1.0417, -0.5784],
                              [ 1.6671,  0.6265, -0.6924,  0.9414,  0.0544],
                              [ 0.2755,  1.3875,  0.5984,  0.5777, -0.2418],
                              [ 1.7944, -0.2855, -0.3323,  1.2764,  0.1819]],
                    
                             [[ 2.8290, -1.5367,  1.9294,  1.1419,  0.4531],
                              [ 0.2037,  0.6561,  0.5935,  0.2346, -1.0272],
                              [ 0.2232,  0.3071, -0.6290, -0.0665,  0.7090],
                              [-1.1359, -0.2510, -1.1360, -1.6798, -0.4487]],
                    
                             [[-1.5467,  1.3900,  1.6637, -0.8699, -1.5268],
                              [-1.0286,  0.6390,  0.0227,  0.0939, -0.5320],
                              [-0.7108, -0.8937, -0.2330, -0.5127, -1.1275],
                              [ 0.1164,  0.7993, -0.6485,  0.1880, -1.3724]]]])
)
mean0 & mean1: tensor(1.3495) tensor(1.3495)


even_index_a0: tensor([[[0.4891, 0.4983, 0.3518, 0.3453, 0.0062, 0.1869, 0.2393, 0.2685],
         [0.6045, 0.5730, 0.0281, 0.1397, 0.0980, 0.1078, 0.9350, 0.8590]],

        [[0.3522, 0.0802, 0.7094, 0.5263, 0.7597, 0.8375, 0.3589, 0.5269],
         [0.3023, 0.4060, 0.6800, 0.9777, 0.7968, 0.7725, 0.5000, 0.6039]],

        [[0.0300, 0.9757, 0.0152, 0.4424, 0.7109, 0.1433, 0.2944, 0.4652],
         [0.6488, 0.5735, 0.0649, 0.0577, 0.4324, 0.4917, 0.7781, 0.5089]],

        [[0.7394, 0.7904, 0.2096, 0.4392, 0.8947, 0.0067, 0.0271, 0.7526],
         [0.8375, 0.7244, 0.8103, 0.6262, 0.1807, 0.3922, 0.1483, 0.8142]]])
even_index_a1: tensor([[[0.4891, 0.4983, 0.3518, 0.3453, 0.0062, 0.1869, 0.2393, 0.2685],
         [0.6045, 0.5730, 0.0281, 0.1397, 0.0980, 0.1078, 0.9350, 0.8590]],

        [[0.3522, 0.0802, 0.7094, 0.5263, 0.7597, 0.8375, 0.3589, 0.5269],
         [0.3023, 0.4060, 0.6800, 0.9777, 0.7968, 0.7725, 0.5000, 0.6039]],

        [[0.0300, 0.9757, 0.0152, 0.4424, 0.7109, 0.1433, 0.2944, 0.4652],
         [0.6488, 0.5735, 0.0649, 0.0577, 0.4324, 0.4917, 0.7781, 0.5089]],

        [[0.7394, 0.7904, 0.2096, 0.4392, 0.8947, 0.0067, 0.0271, 0.7526],
         [0.8375, 0.7244, 0.8103, 0.6262, 0.1807, 0.3922, 0.1483, 0.8142]]])
<Size 0x7f996e898e90>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7f996e898e50>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.