Customized Operations For Different Fields

Here is another example of custom operations, implemented with both the native torch API and the treetensor API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import copy

import torch

import treetensor.torch as ttorch

# T: leading dimension of the 'a' tensor created by get_data (shape (T, 8)).
# B: batch size, i.e. the number of samples generated in the __main__ section.
T, B = 3, 4


def with_nativetensor(batch_):
    """
    Process a batch of nested dict samples using only the native torch API.

    For every sample, the even-indexed rows of 'a' (cast to float) are
    collected, and the mean of ``b ** 2 + 1`` (with 'b' cast to float) is
    collected. Top-level keys other than 'a', 'b', 'c' and keys under 'c'
    other than 'd' are reported with ``print``. Finally a standard-normal
    tensor of shape ``(3, 4, 5)`` is stored as ``sample['c']['noise']`` for
    every sample that has a 'c' entry.

    The float casts and the transformed 'b' only feed the returned
    statistics; they are not written back into the samples.

    :param batch_: List of dict samples. Must be non-empty and every sample \
        needs 'a' and 'b' tensors, otherwise the final averaging / stacking \
        raises.
    :return: Tuple ``(batch_, mean_b, even_index_a)``: the same list object \
        (with the noise attached in place), the average of the per-sample \
        means of ``b ** 2 + 1``, and the stacked even-indexed rows of 'a'.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            if k == 'a':
                even_index_a_list.append(v.float()[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                # Only 'd' is an expected nested key; anything else is reported.
                for k1 in v:
                    if k1 != 'd':
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))

    # The noise belongs under the nested 'c' dict (where 'd' lives). Matching
    # the top-level key against 'd' could never succeed, so no noise was added.
    # This runs after the scan above so the new key is not reported as ignored.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """
    Process a batch of nested dict samples using the treetensor API.

    The samples are converted to tree tensors, stacked into one batched tree,
    cast to float, and then handled field by field: 'b' is replaced by
    ``b ** 2 + 1`` and a random 'noise' tensor is attached under 'c'. The
    batched tree is finally split back into chunks of size 1 along dim 0.

    :param batch_: Iterable of dict samples with the fields 'a', 'b' and a \
        nested 'c'.
    :return: Tuple ``(batch_, mean_b, even_index_a)``: the result of \
        ``ttorch.split`` on the processed tree, the mean of the transformed \
        'b' field, and ``a[:, ::2]`` of the batched 'a' field.
    """
    trees = [ttorch.tensor(sample) for sample in batch_]
    # Take the batch size from the input itself instead of the module-level B,
    # so the noise tensor always matches the number of stacked samples.
    batch_size = len(trees)

    batch_ = ttorch.stack(trees)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    # Dim 0 is the batch dimension here, so the stride-2 slice applies to dim 1.
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """
    Build one random sample as a nested dict.

    :return: Dict with 'a' (uniform random, shape ``(T, 8)``), 'b' (uniform \
        random, shape ``(6,)``) and 'c', a nested dict whose 'd' is a random \
        integer tensor of shape ``(1,)`` drawn from ``[0, 10)``.
    """
    # Tensors are created in the order a, b, d to keep the RNG consumption order.
    sample = dict(
        a=torch.rand(size=(T, 8)),
        b=torch.rand(size=(6,)),
    )
    sample['c'] = dict(d=torch.randint(0, 10, size=(1,)))
    return sample


if __name__ == "__main__":
    # Each implementation receives its own deep copy of the same input batch.
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    # Both implementations must agree on the mean of the transformed 'b'.
    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Compare by the largest absolute element-wise difference. Taking abs()
    # after max() would let large negative differences pass unnoticed.
    assert torch.abs(even_index_a0 - even_index_a1).max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    # Both results must contain one entry per sample.
    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following, and all the assertions should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
[{'a': tensor([[0.4959, 0.7429, 0.3304, 0.3691, 0.2602, 0.0217, 0.8756, 0.2599],
        [0.4695, 0.0128, 0.3597, 0.3050, 0.5799, 0.4451, 0.2758, 0.9532],
        [0.4889, 0.1749, 0.6407, 0.1816, 0.6972, 0.6399, 0.5626, 0.0948]]), 'b': tensor([0.6954, 0.2287, 0.6890, 0.9973, 0.1034, 0.9159]), 'c': {'d': tensor([4])}}, {'a': tensor([[0.6206, 0.3504, 0.5032, 0.3248, 0.7497, 0.3002, 0.2543, 0.1592],
        [0.1785, 0.7591, 0.3378, 0.8237, 0.6426, 0.2236, 0.3023, 0.7932],
        [0.0331, 0.4403, 0.9965, 0.7609, 0.7818, 0.3898, 0.0921, 0.2410]]), 'b': tensor([0.1389, 0.0367, 0.2335, 0.9589, 0.2701, 0.7552]), 'c': {'d': tensor([2])}}, {'a': tensor([[0.0787, 0.1154, 0.9513, 0.5076, 0.4105, 0.8078, 0.4976, 0.7081],
        [0.4414, 0.3664, 0.6787, 0.2895, 0.4783, 0.3019, 0.8096, 0.3731],
        [0.3586, 0.8682, 0.0797, 0.9027, 0.5423, 0.4816, 0.4864, 0.5681]]), 'b': tensor([0.3378, 0.7603, 0.1430, 0.2937, 0.0280, 0.7042]), 'c': {'d': tensor([1])}}, {'a': tensor([[0.7616, 0.8400, 0.8517, 0.3064, 0.9760, 0.8378, 0.3704, 0.7836],
        [0.5377, 0.6733, 0.5208, 0.8555, 0.8346, 0.0799, 0.4971, 0.3771],
        [0.5743, 0.6112, 0.0894, 0.4634, 0.2029, 0.5474, 0.6798, 0.5255]]), 'b': tensor([0.9623, 0.0927, 0.7332, 0.0350, 0.9289, 0.6587]), 'c': {'d': tensor([7])}}]



(<Tensor 0x7fbed975de90>
├── 'a' --> tensor([[[0.4959, 0.7429, 0.3304, 0.3691, 0.2602, 0.0217, 0.8756, 0.2599],
│                    [0.4695, 0.0128, 0.3597, 0.3050, 0.5799, 0.4451, 0.2758, 0.9532],
│                    [0.4889, 0.1749, 0.6407, 0.1816, 0.6972, 0.6399, 0.5626, 0.0948]]])
├── 'b' --> tensor([[1.4836, 1.0523, 1.4748, 1.9947, 1.0107, 1.8388]])
└── 'c' --> <Tensor 0x7fbed975de50>
    ├── 'd' --> tensor([[4.]])
    └── 'noise' --> tensor([[[[ 1.2098, -1.1480, -3.0458, -0.1071, -1.1395],
                              [ 1.6361,  1.3693,  0.1426,  0.1702,  0.5653],
                              [ 0.1225, -0.6091,  0.4540,  2.5245,  0.2460],
                              [-1.1039,  1.3527, -0.2761, -0.3291, -0.0366]],
                    
                             [[-1.4353,  0.4943,  0.9781, -0.0853,  0.4413],
                              [ 0.5125,  0.6533,  0.3187, -0.6209, -1.3550],
                              [-0.9656, -0.1417, -1.7543, -0.9261, -0.4421],
                              [ 0.3440, -0.2400,  1.1507, -0.2864,  1.7387]],
                    
                             [[ 1.6043,  1.0504,  0.4217,  0.6256,  0.1901],
                              [ 2.0248, -0.5585, -0.2970, -2.2687, -2.1299],
                              [-1.3107, -0.1643,  0.5487, -0.1616,  0.6969],
                              [ 0.1701,  2.2368, -0.8534, -0.0991, -0.2608]]]])
, <Tensor 0x7fbed975df10>
├── 'a' --> tensor([[[0.6206, 0.3504, 0.5032, 0.3248, 0.7497, 0.3002, 0.2543, 0.1592],
│                    [0.1785, 0.7591, 0.3378, 0.8237, 0.6426, 0.2236, 0.3023, 0.7932],
│                    [0.0331, 0.4403, 0.9965, 0.7609, 0.7818, 0.3898, 0.0921, 0.2410]]])
├── 'b' --> tensor([[1.0193, 1.0013, 1.0545, 1.9196, 1.0730, 1.5704]])
└── 'c' --> <Tensor 0x7fbed975ded0>
    ├── 'd' --> tensor([[2.]])
    └── 'noise' --> tensor([[[[-0.8716,  0.3982, -0.4103,  1.0145,  0.2366],
                              [ 0.1163, -1.5097,  0.8837,  0.6072,  1.3204],
                              [-0.7921, -0.3152, -0.4015, -1.6218,  0.1952],
                              [-0.0809, -1.2087, -1.4225,  0.4099,  0.5188]],
                    
                             [[ 0.1643,  0.2732,  0.2644, -2.1391, -0.7504],
                              [ 0.9021,  0.2958,  0.0264,  0.2811, -0.6561],
                              [-2.3868, -1.0131, -0.3284,  0.2162, -1.9766],
                              [-0.4745,  2.6729, -0.1767,  2.0280, -0.6456]],
                    
                             [[ 0.8487,  0.9021,  2.1113, -0.8008,  0.9089],
                              [ 0.4937, -1.3294, -1.5853,  0.6095, -0.2106],
                              [ 0.8631, -1.4816, -1.3472, -0.9508, -0.1390],
                              [-0.8438, -1.1405, -0.5679,  0.0308,  0.2206]]]])
, <Tensor 0x7fbed975df90>
├── 'a' --> tensor([[[0.0787, 0.1154, 0.9513, 0.5076, 0.4105, 0.8078, 0.4976, 0.7081],
│                    [0.4414, 0.3664, 0.6787, 0.2895, 0.4783, 0.3019, 0.8096, 0.3731],
│                    [0.3586, 0.8682, 0.0797, 0.9027, 0.5423, 0.4816, 0.4864, 0.5681]]])
├── 'b' --> tensor([[1.1141, 1.5781, 1.0204, 1.0863, 1.0008, 1.4959]])
└── 'c' --> <Tensor 0x7fbed975df50>
    ├── 'd' --> tensor([[1.]])
    └── 'noise' --> tensor([[[[-0.8661,  2.1075, -0.8831,  0.0351,  0.8010],
                              [ 0.2314, -1.7201,  1.1340,  0.0306,  0.2898],
                              [-0.2021,  0.1031, -1.1276,  0.9147,  1.1242],
                              [ 0.1915, -1.2295, -0.0290,  0.8289, -2.2935]],
                    
                             [[ 0.3192,  0.1709,  0.6136,  1.0401,  1.6776],
                              [ 0.6266, -0.4175, -0.2612, -0.3719, -1.8482],
                              [ 0.4590, -0.8227,  1.4691, -0.5683, -1.0901],
                              [-0.4111, -0.8091, -1.1143, -0.8473, -0.5122]],
                    
                             [[-1.6365, -1.0635, -1.2169,  0.3649, -1.6512],
                              [ 1.0775, -1.9654,  0.0070,  0.6785,  0.0523],
                              [-1.9686, -0.6915, -0.3784, -0.1928, -0.5714],
                              [-1.6885,  0.6324,  0.3805,  0.8414, -1.1992]]]])
, <Tensor 0x7fbed9764050>
├── 'a' --> tensor([[[0.7616, 0.8400, 0.8517, 0.3064, 0.9760, 0.8378, 0.3704, 0.7836],
│                    [0.5377, 0.6733, 0.5208, 0.8555, 0.8346, 0.0799, 0.4971, 0.3771],
│                    [0.5743, 0.6112, 0.0894, 0.4634, 0.2029, 0.5474, 0.6798, 0.5255]]])
├── 'b' --> tensor([[1.9260, 1.0086, 1.5375, 1.0012, 1.8628, 1.4339]])
└── 'c' --> <Tensor 0x7fbed975dfd0>
    ├── 'd' --> tensor([[7.]])
    └── 'noise' --> tensor([[[[ 0.2398, -1.1715, -0.5465,  0.5210, -1.9378],
                              [ 1.8088,  1.1304,  1.1054, -0.1605,  0.1069],
                              [-0.3182,  1.1913, -0.8778, -0.2761, -1.9641],
                              [-1.0521,  0.8548,  0.5549,  0.7722,  0.4085]],
                    
                             [[-0.8001,  0.5689,  1.0019, -0.0062, -0.9012],
                              [ 0.1281,  0.4304,  0.6782, -1.2555, -0.8217],
                              [ 0.3796,  0.6830, -2.9541, -0.6279,  0.0266],
                              [-1.2134, -1.5165,  1.7040, -0.9219,  0.5670]],
                    
                             [[ 0.0077, -0.3625, -0.4056,  0.2789,  0.5481],
                              [-1.1109, -0.3867, -0.9889,  0.7654, -0.1190],
                              [ 1.6151, -0.7024, -0.0453,  1.0142,  2.8909],
                              [ 0.7332,  1.2631,  1.1594,  0.5931, -1.5527]]]])
)
mean0 & mean1: tensor(1.3566) tensor(1.3566)


even_index_a0: tensor([[[0.4959, 0.7429, 0.3304, 0.3691, 0.2602, 0.0217, 0.8756, 0.2599],
         [0.4889, 0.1749, 0.6407, 0.1816, 0.6972, 0.6399, 0.5626, 0.0948]],

        [[0.6206, 0.3504, 0.5032, 0.3248, 0.7497, 0.3002, 0.2543, 0.1592],
         [0.0331, 0.4403, 0.9965, 0.7609, 0.7818, 0.3898, 0.0921, 0.2410]],

        [[0.0787, 0.1154, 0.9513, 0.5076, 0.4105, 0.8078, 0.4976, 0.7081],
         [0.3586, 0.8682, 0.0797, 0.9027, 0.5423, 0.4816, 0.4864, 0.5681]],

        [[0.7616, 0.8400, 0.8517, 0.3064, 0.9760, 0.8378, 0.3704, 0.7836],
         [0.5743, 0.6112, 0.0894, 0.4634, 0.2029, 0.5474, 0.6798, 0.5255]]])
even_index_a1: tensor([[[0.4959, 0.7429, 0.3304, 0.3691, 0.2602, 0.0217, 0.8756, 0.2599],
         [0.4889, 0.1749, 0.6407, 0.1816, 0.6972, 0.6399, 0.5626, 0.0948]],

        [[0.6206, 0.3504, 0.5032, 0.3248, 0.7497, 0.3002, 0.2543, 0.1592],
         [0.0331, 0.4403, 0.9965, 0.7609, 0.7818, 0.3898, 0.0921, 0.2410]],

        [[0.0787, 0.1154, 0.9513, 0.5076, 0.4105, 0.8078, 0.4976, 0.7081],
         [0.3586, 0.8682, 0.0797, 0.9027, 0.5423, 0.4816, 0.4864, 0.5681]],

        [[0.7616, 0.8400, 0.8517, 0.3064, 0.9760, 0.8378, 0.3704, 0.7836],
         [0.5743, 0.6112, 0.0894, 0.4634, 0.2029, 0.5474, 0.6798, 0.5255]]])
<Size 0x7fbed975dd90>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7fbed975dd50>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.