Customized Operations For Different Fields

Here is another example of custom operations, implemented with both the native torch API and the treetensor API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import copy

import torch

import treetensor.torch as ttorch

# T: size of the first dimension of field 'a' in every sample built by get_data().
# B: number of samples in the demo batch built in the __main__ section.
T, B = 3, 4


def with_nativetensor(batch_):
    """Process a batch of nested dict samples using only the native torch API.

    Each sample is a dict that may hold the fields ``'a'``, ``'b'`` and a
    nested dict ``'c'`` (with an inner field ``'d'``). Unknown keys at either
    level are reported with a printed message and otherwise skipped.

    :param batch_: Sequence of sample dicts. The dicts are modified in place: \
        a ``'noise'`` tensor of shape ``(3, 4, 5)`` is attached under ``'c'``. \
        The ``'a'``, ``'b'`` and ``'d'`` tensors themselves are left untouched.
    :return: A tuple ``(batch_, mean_b, even_index_a)`` where ``mean_b`` is the \
        average over samples of ``mean(b ** 2 + 1)`` and ``even_index_a`` stacks, \
        along a new leading dimension, every second row (0, 2, ...) of each \
        sample's float-cast ``'a'``.
    :raises ZeroDivisionError: If no sample provides a ``'b'`` field.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            if k == 'a':
                even_index_a_list.append(v.float()[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                # Only 'd' is a recognised inner field; report everything else.
                for k1 in v:
                    if k1 != 'd':
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))

    # Attach the noise in a second pass so the nested dict is never resized
    # while it is being iterated. The noise belongs under 'c' ('d' is only an
    # inner key of 'c', never a top-level key of a sample).
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Process a batch of nested dict samples using the treetensor API.

    The samples are converted to tree tensors, stacked along a new leading
    dimension, cast to float and transformed as one object, then split back
    into per-sample pieces.

    :param batch_: Iterable of nested dict samples with the fields ``'a'``, \
        ``'b'`` and ``'c'`` (the latter being a nested dict).
    :return: A tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the \
        result of splitting the stacked tree into chunks of size 1 along \
        dimension 0, ``mean_b`` is the mean of ``b ** 2 + 1`` over the stacked \
        batch, and ``even_index_a`` is ``a[:, ::2]`` of the stacked batch.
    """
    samples = [ttorch.tensor(b) for b in batch_]
    # Size the noise from the batch actually received rather than from a
    # module-level constant, so batches of any length stay consistent.
    batch_size = len(samples)
    batch_ = ttorch.stack(samples)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Create one random sample as a nested dict of torch tensors.

    :return: A dict with ``'a'`` (uniform random, shape ``(T, 8)``), ``'b'`` \
        (uniform random, shape ``(6,)``) and ``'c'``, a nested dict whose \
        ``'d'`` entry is a random integer tensor of shape ``(1,)`` drawn \
        from ``[0, 10)``.
    """
    # Draw the tensors in the same order as the fields appear in the sample.
    field_a = torch.rand(size=(T, 8))
    field_b = torch.rand(size=(6,))
    field_d = torch.randint(0, 10, size=(1,))

    nested = {'d': field_d}
    sample = {'a': field_a, 'b': field_b, 'c': nested}
    return sample


if __name__ == "__main__":
    # Build B random samples; each implementation receives its own deep copy
    # so that in-place changes made by one cannot affect the other.
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    # Both implementations must agree on the mean of the transformed 'b'.
    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Compare element-wise: take the absolute value BEFORE the max, otherwise
    # large negative differences are hidden whenever the largest difference
    # happens to be close to zero.
    assert torch.abs(even_index_a0 - even_index_a1).max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following, and all the assertions should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
[{'a': tensor([[0.0530, 0.1316, 0.3868, 0.0582, 0.0592, 0.2844, 0.7060, 0.7104],
        [0.2715, 0.6086, 0.3717, 0.9783, 0.8932, 0.2700, 0.0806, 0.9643],
        [0.6430, 0.4201, 0.3315, 0.5621, 0.4383, 0.7077, 0.8652, 0.7994]]), 'b': tensor([0.7424, 0.9573, 0.3920, 0.1052, 0.1581, 0.8148]), 'c': {'d': tensor([2])}}, {'a': tensor([[0.0653, 0.4086, 0.6234, 0.7911, 0.9723, 0.0044, 0.9084, 0.6511],
        [0.4518, 0.1293, 0.2198, 0.2321, 0.9656, 0.4900, 0.3904, 0.9074],
        [0.9259, 0.2673, 0.2197, 0.2338, 0.8283, 0.3773, 0.8736, 0.9157]]), 'b': tensor([0.0709, 0.8475, 0.8364, 0.9020, 0.4540, 0.5760]), 'c': {'d': tensor([8])}}, {'a': tensor([[0.1737, 0.6593, 0.8535, 0.4953, 0.6238, 0.9004, 0.8182, 0.2689],
        [0.2865, 0.4794, 0.7387, 0.0603, 0.9837, 0.6522, 0.6537, 0.0974],
        [0.0411, 0.4282, 0.6774, 0.0315, 0.7757, 0.9708, 0.9896, 0.7082]]), 'b': tensor([0.0826, 0.7268, 0.3619, 0.9916, 0.2159, 0.2764]), 'c': {'d': tensor([2])}}, {'a': tensor([[0.7990, 0.3954, 0.2579, 0.3021, 0.7232, 0.1034, 0.9367, 0.8425],
        [0.7894, 0.4812, 0.3679, 0.4769, 0.5864, 0.6116, 0.4470, 0.3827],
        [0.6414, 0.4988, 0.6573, 0.6941, 0.3207, 0.3858, 0.7914, 0.9150]]), 'b': tensor([0.1669, 0.1758, 0.1928, 0.0075, 0.8306, 0.6041]), 'c': {'d': tensor([8])}}]



(<Tensor 0x7fe654722fd0>
├── 'a' --> tensor([[[0.0530, 0.1316, 0.3868, 0.0582, 0.0592, 0.2844, 0.7060, 0.7104],
│                    [0.2715, 0.6086, 0.3717, 0.9783, 0.8932, 0.2700, 0.0806, 0.9643],
│                    [0.6430, 0.4201, 0.3315, 0.5621, 0.4383, 0.7077, 0.8652, 0.7994]]])
├── 'b' --> tensor([[1.5512, 1.9164, 1.1537, 1.0111, 1.0250, 1.6638]])
└── 'c' --> <Tensor 0x7fe654722f90>
    ├── 'd' --> tensor([[2.]])
    └── 'noise' --> tensor([[[[-0.6145, -0.1458,  0.1173,  0.3199, -0.2685],
                              [-0.4438, -0.9441, -1.3548, -1.6339, -0.1808],
                              [-0.1598, -0.7564,  0.5177,  1.6621, -0.7301],
                              [-0.0091, -0.0693,  0.3701, -0.1008, -0.9677]],
                    
                             [[-0.0215, -0.3856, -0.2832, -0.6427, -1.1976],
                              [ 0.1676, -0.7570,  1.3808, -0.6569,  0.7185],
                              [-1.7909, -0.6165, -0.3450, -0.2981,  1.4836],
                              [ 1.1538, -0.4351, -0.3297,  0.2331, -1.3635]],
                    
                             [[ 0.8167,  1.4640,  0.4096, -0.2440, -0.9645],
                              [-1.0080, -1.1316, -0.0838,  0.5443, -1.8280],
                              [-0.2243,  0.3282, -0.7281, -0.8862, -1.0531],
                              [ 0.4967, -1.6835,  2.8434, -1.4404,  1.2524]]]])
, <Tensor 0x7fe654729090>
├── 'a' --> tensor([[[0.0653, 0.4086, 0.6234, 0.7911, 0.9723, 0.0044, 0.9084, 0.6511],
│                    [0.4518, 0.1293, 0.2198, 0.2321, 0.9656, 0.4900, 0.3904, 0.9074],
│                    [0.9259, 0.2673, 0.2197, 0.2338, 0.8283, 0.3773, 0.8736, 0.9157]]])
├── 'b' --> tensor([[1.0050, 1.7183, 1.6996, 1.8136, 1.2061, 1.3318]])
└── 'c' --> <Tensor 0x7fe654729050>
    ├── 'd' --> tensor([[8.]])
    └── 'noise' --> tensor([[[[ 0.6687, -0.1380, -0.2344, -1.0965,  0.9932],
                              [-0.4782, -1.5013,  0.1346,  2.0136, -0.0120],
                              [-0.8172,  0.3474,  0.2281,  0.6967,  0.3457],
                              [ 0.8758, -1.9384, -0.5588,  0.4844, -1.3965]],
                    
                             [[ 0.5175,  1.3308,  1.5126,  1.8909,  0.1056],
                              [ 0.6862,  1.4825, -1.2959, -1.5295,  1.0123],
                              [-0.2036,  1.4430,  0.3562, -0.0324, -0.8127],
                              [-0.0664, -1.1613,  2.0566, -0.5090, -1.0022]],
                    
                             [[-0.0094,  0.4440,  1.0630, -0.1183, -0.8398],
                              [-0.7594, -0.3139,  0.7612,  0.7844,  0.7298],
                              [-0.0850, -1.3562, -1.3477, -1.9326, -0.1287],
                              [-0.3684, -0.1948, -0.2106,  0.1244, -0.6444]]]])
, <Tensor 0x7fe654729110>
├── 'a' --> tensor([[[0.1737, 0.6593, 0.8535, 0.4953, 0.6238, 0.9004, 0.8182, 0.2689],
│                    [0.2865, 0.4794, 0.7387, 0.0603, 0.9837, 0.6522, 0.6537, 0.0974],
│                    [0.0411, 0.4282, 0.6774, 0.0315, 0.7757, 0.9708, 0.9896, 0.7082]]])
├── 'b' --> tensor([[1.0068, 1.5282, 1.1310, 1.9833, 1.0466, 1.0764]])
└── 'c' --> <Tensor 0x7fe6547290d0>
    ├── 'd' --> tensor([[2.]])
    └── 'noise' --> tensor([[[[ 0.4160,  2.0009, -1.7171,  0.5402, -2.3210],
                              [-0.2561,  0.4562, -1.0994,  1.2536, -1.3823],
                              [ 0.4414, -0.6419,  0.4306, -0.0867,  0.3658],
                              [ 0.3535, -0.7061,  1.8489, -0.3642, -0.4975]],
                    
                             [[-0.0270, -0.0539,  1.0774, -0.4132,  0.1036],
                              [ 0.0773,  0.7107, -0.3725, -0.3476, -0.0902],
                              [-1.1531,  0.4045,  0.6975,  0.1649, -1.1100],
                              [-0.2848, -0.0567,  1.1481, -1.3258, -1.5258]],
                    
                             [[ 0.6759, -0.1135, -0.0723, -0.0190, -0.2226],
                              [-1.6966,  0.0369,  1.4791, -0.0274,  0.5058],
                              [ 0.1358, -0.6024,  2.6037, -0.1705,  1.1308],
                              [-0.6299, -1.1563, -1.3002, -1.5118, -1.8600]]]])
, <Tensor 0x7fe654729190>
├── 'a' --> tensor([[[0.7990, 0.3954, 0.2579, 0.3021, 0.7232, 0.1034, 0.9367, 0.8425],
│                    [0.7894, 0.4812, 0.3679, 0.4769, 0.5864, 0.6116, 0.4470, 0.3827],
│                    [0.6414, 0.4988, 0.6573, 0.6941, 0.3207, 0.3858, 0.7914, 0.9150]]])
├── 'b' --> tensor([[1.0279, 1.0309, 1.0372, 1.0001, 1.6899, 1.3649]])
└── 'c' --> <Tensor 0x7fe654729150>
    ├── 'd' --> tensor([[8.]])
    └── 'noise' --> tensor([[[[-0.0519,  0.5795,  0.9674,  1.4432, -0.5475],
                              [-0.5406,  0.6173, -0.9155, -0.7560,  1.3850],
                              [-1.7609,  0.5079,  0.0522,  0.4386,  0.0854],
                              [ 0.0806, -2.1930,  0.7717,  0.3758, -0.9107]],
                    
                             [[-0.3009,  1.7650,  0.0811,  1.0014, -1.2704],
                              [-0.5530, -0.2263,  0.4376,  0.9503, -0.3118],
                              [-0.8690, -1.4592, -1.3738,  0.0543,  1.7000],
                              [ 1.1483, -1.3858,  0.6293, -0.1860, -1.5392]],
                    
                             [[-0.8780,  0.2691, -1.2989, -0.3519,  0.7185],
                              [-1.4287,  0.4811,  1.0223, -0.3136, -2.3332],
                              [ 0.9769, -0.4040, -1.3943, -0.1041,  0.7485],
                              [ 0.2314, -0.1190, -0.1268, -0.3399, -0.9422]]]])
)
mean0 & mean1: tensor(1.3341) tensor(1.3341)


even_index_a0: tensor([[[0.0530, 0.1316, 0.3868, 0.0582, 0.0592, 0.2844, 0.7060, 0.7104],
         [0.6430, 0.4201, 0.3315, 0.5621, 0.4383, 0.7077, 0.8652, 0.7994]],

        [[0.0653, 0.4086, 0.6234, 0.7911, 0.9723, 0.0044, 0.9084, 0.6511],
         [0.9259, 0.2673, 0.2197, 0.2338, 0.8283, 0.3773, 0.8736, 0.9157]],

        [[0.1737, 0.6593, 0.8535, 0.4953, 0.6238, 0.9004, 0.8182, 0.2689],
         [0.0411, 0.4282, 0.6774, 0.0315, 0.7757, 0.9708, 0.9896, 0.7082]],

        [[0.7990, 0.3954, 0.2579, 0.3021, 0.7232, 0.1034, 0.9367, 0.8425],
         [0.6414, 0.4988, 0.6573, 0.6941, 0.3207, 0.3858, 0.7914, 0.9150]]])
even_index_a1: tensor([[[0.0530, 0.1316, 0.3868, 0.0582, 0.0592, 0.2844, 0.7060, 0.7104],
         [0.6430, 0.4201, 0.3315, 0.5621, 0.4383, 0.7077, 0.8652, 0.7994]],

        [[0.0653, 0.4086, 0.6234, 0.7911, 0.9723, 0.0044, 0.9084, 0.6511],
         [0.9259, 0.2673, 0.2197, 0.2338, 0.8283, 0.3773, 0.8736, 0.9157]],

        [[0.1737, 0.6593, 0.8535, 0.4953, 0.6238, 0.9004, 0.8182, 0.2689],
         [0.0411, 0.4282, 0.6774, 0.0315, 0.7757, 0.9708, 0.9896, 0.7082]],

        [[0.7990, 0.3954, 0.2579, 0.3021, 0.7232, 0.1034, 0.9367, 0.8425],
         [0.6414, 0.4988, 0.6573, 0.6941, 0.3207, 0.3858, 0.7914, 0.9150]]])
<Size 0x7fe6548e9f90>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7fe654a9a110>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.