Customized Operations For Different Fields

Here is another example of custom operations implemented with both the native torch API and the treetensor API.

import copy

import torch

import treetensor.torch as ttorch

# Number of rows in each sample's ``a`` tensor.
T = 3
# Number of samples generated for the demo batch.
B = 4


def with_nativetensor(batch_):
    """Apply per-field transforms to a list of nested tensor dicts with plain torch.

    For every sample dict in ``batch_``:

    * ``a`` is cast to float (stored back) and its even-indexed rows are collected;
    * ``b`` is cast to float and replaced by ``b ** 2 + 1``;
    * ``c['d']`` is cast to float (stored back);
    * a standard-normal ``(3, 4, 5)`` tensor is attached as ``c['noise']``.

    Any other key, top-level or inside ``c``, is reported with a print and left
    unchanged. The sample dicts are modified in place, so pass a copy if the
    caller still needs the originals.

    Args:
        batch_: list of sample dicts as produced by ``get_data``.

    Returns:
        ``(batch_, mean_b, even_index_a)``:

        * ``batch_`` is the same list with its samples transformed;
        * ``mean_b`` is the per-sample mean of the transformed ``b`` values,
          averaged over samples;
        * ``even_index_a`` stacks the even rows of every ``a`` along a new dim 0.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        # Assigning to existing keys while iterating is safe: the dict's
        # size never changes inside this loop.
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                sample[k] = v
                even_index_a_list.append(v[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        # Store the cast back; rebinding the loop variable alone
                        # would leave the dict unchanged.
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))

    # The noise goes under the nested 'c' dict. It is added in a separate pass so
    # the dict is not resized while it is being iterated above.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply per-field transforms to a list of nested tensor dicts with treetensor.

    The samples are converted to tree tensors and stacked into one batched tree
    with a new leading dim of size ``len(batch_)``. The whole tree is cast to
    float. Then ``b`` is replaced by ``b ** 2 + 1`` and a standard-normal
    ``c.noise`` of shape ``(len(batch_), 3, 4, 5)`` is attached.

    Args:
        batch_: list of sample dicts as produced by ``get_data``.

    Returns:
        ``(batch_, mean_b, even_index_a)``:

        * ``batch_`` is the batched tree split back into per-sample trees, each
          keeping a leading dim of size 1;
        * ``mean_b`` is the mean of all transformed ``b`` values;
        * ``even_index_a`` holds the even-indexed rows of every sample's ``a``.
    """
    batch_ = [ttorch.tensor(b) for b in batch_]
    # Size the noise from the actual number of samples rather than the
    # module-level ``B``, so its leading dim always matches the stacked fields.
    batch_size = len(batch_)
    batch_ = ttorch.stack(batch_)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    # Dim 0 is the batch dim added by ``stack``; slice the rows of ``a`` (dim 1).
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Create one random sample as a nested dict of tensors.

    Fields:
        a: ``(T, 8)`` tensor drawn with ``torch.rand``.
        b: ``(6,)`` tensor drawn with ``torch.rand``.
        c: nested dict whose ``d`` is a one-element tensor from
           ``torch.randint(0, 10, ...)``.
    """
    # Draw in the order a, b, d so the RNG stream is consumed exactly as before.
    field_a = torch.rand(size=(T, 8))
    field_b = torch.rand(size=(6,))
    field_d = torch.randint(0, 10, size=(1,))
    return dict(a=field_a, b=field_b, c=dict(d=field_d))


if __name__ == "__main__":
    # Run both implementations on independent deep copies of the same random
    # batch, so neither run can affect the other's input, then compare results.
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Take the element-wise absolute value *before* reducing: abs(max(diff))
    # would let arbitrarily large negative differences pass.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following (exact values vary because the data are random), and all the assertions pass.

[{'a': tensor([[0.9257, 0.4266, 0.8127, 0.9392, 0.2294, 0.8336, 0.7605, 0.9144],
        [0.0884, 0.3489, 0.4901, 0.6910, 0.8835, 0.0469, 0.2055, 0.1559],
        [0.9612, 0.0284, 0.2305, 0.2726, 0.2317, 0.6156, 0.5317, 0.9245]]), 'b': tensor([0.3955, 0.9165, 0.3444, 0.2769, 0.7824, 0.0054]), 'c': {'d': tensor([7])}}, {'a': tensor([[0.3599, 0.8606, 0.2802, 0.7488, 0.1605, 0.6629, 0.0725, 0.8811],
        [0.8521, 0.8738, 0.1401, 0.9934, 0.8279, 0.9412, 0.4467, 0.4728],
        [0.1503, 0.4891, 0.8713, 0.5262, 0.4487, 0.4566, 0.2726, 0.5312]]), 'b': tensor([0.5027, 0.8534, 0.4864, 0.6362, 0.3039, 0.9323]), 'c': {'d': tensor([2])}}, {'a': tensor([[0.5323, 0.4024, 0.1017, 0.1310, 0.8797, 0.5557, 0.9597, 0.3834],
        [0.2552, 0.9185, 0.1508, 0.9390, 0.4087, 0.4922, 0.4747, 0.8337],
        [0.8026, 0.3843, 0.5243, 0.1324, 0.9962, 0.7263, 0.5565, 0.1281]]), 'b': tensor([0.3252, 0.8908, 0.5887, 0.1604, 0.2188, 0.1202]), 'c': {'d': tensor([9])}}, {'a': tensor([[0.2808, 0.5135, 0.2673, 0.1119, 0.1685, 0.3233, 0.4603, 0.2135],
        [0.3917, 0.5572, 0.5263, 0.4482, 0.3145, 0.9395, 0.6414, 0.6244],
        [0.6007, 0.0662, 0.8816, 0.1150, 0.7659, 0.0525, 0.6582, 0.0748]]), 'b': tensor([0.8992, 0.9398, 0.5766, 0.4901, 0.8376, 0.2124]), 'c': {'d': tensor([4])}}]



(<Tensor 0x7fb1d42a0f90>
├── 'a' --> tensor([[[0.9257, 0.4266, 0.8127, 0.9392, 0.2294, 0.8336, 0.7605, 0.9144],
│                    [0.0884, 0.3489, 0.4901, 0.6910, 0.8835, 0.0469, 0.2055, 0.1559],
│                    [0.9612, 0.0284, 0.2305, 0.2726, 0.2317, 0.6156, 0.5317, 0.9245]]])
├── 'b' --> tensor([[1.1564, 1.8400, 1.1186, 1.0767, 1.6121, 1.0000]])
└── 'c' --> <Tensor 0x7fb1d42a0f50>
    ├── 'd' --> tensor([[7.]])
    └── 'noise' --> tensor([[[[ 0.3579, -0.6408,  0.9409, -1.4434,  0.4311],
                              [-1.1070,  0.4677,  1.0348,  0.3771, -1.8838],
                              [-2.8118, -1.0437, -0.9295, -1.2843, -0.3687],
                              [ 0.2445, -0.9614, -0.4345,  1.2057, -1.0109]],
                    
                             [[ 0.8215, -1.5672,  0.3532, -0.4932, -0.1338],
                              [-0.7677,  2.4574,  1.1538, -1.0093,  0.3429],
                              [-0.5387, -0.2103, -0.6103, -0.0457, -0.1621],
                              [-0.8051,  1.4865, -1.0652,  0.6543, -0.6310]],
                    
                             [[-0.7165,  0.4523,  0.3861,  0.4409,  0.6224],
                              [-0.8699,  0.6322,  0.5477, -1.9075,  0.3583],
                              [ 0.5514,  0.1888,  0.1900,  1.1220, -1.4074],
                              [ 1.1672,  0.4753,  1.9815, -0.7217,  1.5522]]]])
, <Tensor 0x7fb1d42a7050>
├── 'a' --> tensor([[[0.3599, 0.8606, 0.2802, 0.7488, 0.1605, 0.6629, 0.0725, 0.8811],
│                    [0.8521, 0.8738, 0.1401, 0.9934, 0.8279, 0.9412, 0.4467, 0.4728],
│                    [0.1503, 0.4891, 0.8713, 0.5262, 0.4487, 0.4566, 0.2726, 0.5312]]])
├── 'b' --> tensor([[1.2527, 1.7283, 1.2365, 1.4048, 1.0923, 1.8691]])
└── 'c' --> <Tensor 0x7fb1d42a0fd0>
    ├── 'd' --> tensor([[2.]])
    └── 'noise' --> tensor([[[[ 1.0493,  1.0225, -1.2554, -0.0225, -0.5568],
                              [ 0.8382,  1.3356,  0.7029, -0.1023,  0.1161],
                              [ 0.8582,  0.0081, -1.4857,  0.1620,  0.0081],
                              [ 0.1524,  1.2400, -1.3443,  0.3112,  0.0590]],
                    
                             [[-0.2105,  1.4063, -0.7215, -0.8753, -1.2693],
                              [-0.0757, -2.0197,  1.2485, -0.0855, -1.3478],
                              [-1.1685,  0.0979,  0.5640,  0.0359, -0.6250],
                              [ 1.0069, -0.0556, -1.4091, -0.5358,  0.5204]],
                    
                             [[ 0.8484,  2.1689, -3.0154,  0.1979,  0.7170],
                              [ 0.8199, -0.0033,  1.9158, -0.2894, -0.4480],
                              [-0.8988, -0.9611, -0.4975, -0.8224, -0.3678],
                              [ 0.3620, -0.8372, -0.9729,  0.3145, -0.1403]]]])
, <Tensor 0x7fb1d42a70d0>
├── 'a' --> tensor([[[0.5323, 0.4024, 0.1017, 0.1310, 0.8797, 0.5557, 0.9597, 0.3834],
│                    [0.2552, 0.9185, 0.1508, 0.9390, 0.4087, 0.4922, 0.4747, 0.8337],
│                    [0.8026, 0.3843, 0.5243, 0.1324, 0.9962, 0.7263, 0.5565, 0.1281]]])
├── 'b' --> tensor([[1.1057, 1.7935, 1.3466, 1.0257, 1.0479, 1.0144]])
└── 'c' --> <Tensor 0x7fb1d42a7090>
    ├── 'd' --> tensor([[9.]])
    └── 'noise' --> tensor([[[[-1.3456, -0.2045, -2.2172,  1.5476, -0.8317],
                              [ 1.1984, -0.4774, -1.4471,  0.8426,  0.0722],
                              [ 0.2040, -0.5946, -0.7240,  1.0343, -1.3400],
                              [ 0.2246, -0.8020,  2.1181,  1.0530, -0.5405]],
                    
                             [[-0.4242,  1.0609,  1.3618, -1.0542, -1.9754],
                              [-1.0875, -0.3019,  1.0196,  0.2886, -1.7909],
                              [-0.1305, -0.9066,  1.0717,  0.0406,  0.1890],
                              [-1.2663, -0.1013, -0.1627,  1.5421,  0.1341]],
                    
                             [[ 1.5481, -0.5824, -0.4398, -0.3064, -0.8305],
                              [-0.1520,  0.0735, -1.7671, -1.7186, -0.4571],
                              [ 0.0132, -0.4691,  0.5066,  1.5784,  0.6487],
                              [-0.1768, -1.6254,  1.4006,  0.8015, -1.0122]]]])
, <Tensor 0x7fb1d42a7150>
├── 'a' --> tensor([[[0.2808, 0.5135, 0.2673, 0.1119, 0.1685, 0.3233, 0.4603, 0.2135],
│                    [0.3917, 0.5572, 0.5263, 0.4482, 0.3145, 0.9395, 0.6414, 0.6244],
│                    [0.6007, 0.0662, 0.8816, 0.1150, 0.7659, 0.0525, 0.6582, 0.0748]]])
├── 'b' --> tensor([[1.8086, 1.8831, 1.3325, 1.2402, 1.7015, 1.0451]])
└── 'c' --> <Tensor 0x7fb1d42a7110>
    ├── 'd' --> tensor([[4.]])
    └── 'noise' --> tensor([[[[ 1.9825, -1.0632, -0.1681,  0.3218, -0.5951],
                              [ 0.0681, -0.7200, -1.4380, -1.4024,  0.9112],
                              [ 0.0235, -0.0749, -1.0119,  2.9069, -1.0269],
                              [ 0.0671,  1.7382, -0.1251,  0.2279, -1.8300]],
                    
                             [[-0.2455,  0.3796,  1.0014,  2.2101, -0.6025],
                              [ 0.5893,  0.2538, -0.6225,  0.0841,  0.8188],
                              [ 1.2263,  0.2263, -0.0100, -0.9343,  0.3980],
                              [ 0.1762,  0.4784, -0.3005, -0.1692,  0.2834]],
                    
                             [[ 0.8008,  0.8531,  0.1050, -0.1336, -1.3390],
                              [-1.8002, -0.1440,  1.2495,  0.1844,  1.5687],
                              [-1.2012,  0.1109, -0.5724,  0.0432, -1.4930],
                              [-1.5599,  1.4164,  0.6822, -1.3556,  0.9604]]]])
)
mean0 & mean1: tensor(1.3639) tensor(1.3639)


even_index_a0: tensor([[[0.9257, 0.4266, 0.8127, 0.9392, 0.2294, 0.8336, 0.7605, 0.9144],
         [0.9612, 0.0284, 0.2305, 0.2726, 0.2317, 0.6156, 0.5317, 0.9245]],

        [[0.3599, 0.8606, 0.2802, 0.7488, 0.1605, 0.6629, 0.0725, 0.8811],
         [0.1503, 0.4891, 0.8713, 0.5262, 0.4487, 0.4566, 0.2726, 0.5312]],

        [[0.5323, 0.4024, 0.1017, 0.1310, 0.8797, 0.5557, 0.9597, 0.3834],
         [0.8026, 0.3843, 0.5243, 0.1324, 0.9962, 0.7263, 0.5565, 0.1281]],

        [[0.2808, 0.5135, 0.2673, 0.1119, 0.1685, 0.3233, 0.4603, 0.2135],
         [0.6007, 0.0662, 0.8816, 0.1150, 0.7659, 0.0525, 0.6582, 0.0748]]])
even_index_a1: tensor([[[0.9257, 0.4266, 0.8127, 0.9392, 0.2294, 0.8336, 0.7605, 0.9144],
         [0.9612, 0.0284, 0.2305, 0.2726, 0.2317, 0.6156, 0.5317, 0.9245]],

        [[0.3599, 0.8606, 0.2802, 0.7488, 0.1605, 0.6629, 0.0725, 0.8811],
         [0.1503, 0.4891, 0.8713, 0.5262, 0.4487, 0.4566, 0.2726, 0.5312]],

        [[0.5323, 0.4024, 0.1017, 0.1310, 0.8797, 0.5557, 0.9597, 0.3834],
         [0.8026, 0.3843, 0.5243, 0.1324, 0.9962, 0.7263, 0.5565, 0.1281]],

        [[0.2808, 0.5135, 0.2673, 0.1119, 0.1685, 0.3233, 0.4603, 0.2135],
         [0.6007, 0.0662, 0.8816, 0.1150, 0.7659, 0.0525, 0.6582, 0.0748]]])
<Size 0x7fb1d42a0e90>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7fb1d42a0e50>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.