Customized Operations For Different Fields¶
Here is another example of custom operations, implemented with both the native torch API and the treetensor API.
import copy

import torch
import treetensor.torch as ttorch

T, B = 3, 4


def with_nativetensor(batch_):
    """Process a batch of nested dicts using only the native torch API.

    Args:
        batch_: list of samples; each sample is a dict with tensor values
            under ``'a'`` and ``'b'`` and a nested dict under ``'c'``
            (see ``get_data``). The samples are modified in place.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the
        input list with a ``'noise'`` tensor added under each ``'c'``,
        ``mean_b`` is the average over samples of ``mean(b ** 2 + 1)``, and
        ``even_index_a`` stacks the even-indexed rows of every ``'a'``.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                # Keep every second row along the first dimension.
                even_index_a_list.append(v[::2])
            elif k == 'b':
                v = v.float()
                transformed_v = torch.pow(v, 2) + 1.0
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        # NOTE: like the conversions above, this only rebinds
                        # the local name; the value stored in the sample is
                        # left unchanged.
                        v1 = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))

    # Attach the noise in a separate pass so the dicts are not modified
    # while they are being iterated above.
    for sample in batch_:
        for k in sample:
            # The nested dict lives under 'c' ('d' is one of its children and
            # never appears as a top-level key), so 'c' is the key to match.
            if k == 'c':
                sample[k]['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Process a batch of nested dicts using the treetensor API.

    Args:
        batch_: list of samples with the same layout as produced by
            ``get_data``.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the
        result of splitting the stacked tree back along dim 0 into chunks of
        size 1, ``mean_b`` is the mean of ``b ** 2 + 1`` over the stacked
        batch, and ``even_index_a`` holds the even-indexed rows of ``'a'``
        for every sample.
    """
    # Take the batch size from the input rather than the module-level ``B``
    # so the function also works for batches of a different length.
    batch_size = len(batch_)

    batch_ = [ttorch.tensor(b) for b in batch_]
    batch_ = ttorch.stack(batch_)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    # Dim 0 is the stacked batch dimension; take every second row of dim 1.
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Build one random sample: tensors 'a' and 'b' plus a nested dict 'c'."""
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,)),
        },
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    # Each implementation gets its own deep copy so that in-place changes
    # made by one cannot leak into the other.
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Take the absolute value before the max so that large negative
    # differences are caught as well.
    assert torch.abs(even_index_a0 - even_index_a1).max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look similar to the following (the exact values are random and will differ from run to run), and all the assertions should pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.6370, 0.5287, 0.8719, 0.8738, 0.8669, 0.0233, 0.4351, 0.3231], [0.0673, 0.8459, 0.8704, 0.0062, 0.2567, 0.2812, 0.9117, 0.9412], [0.0644, 0.1503, 0.6570, 0.0957, 0.1638, 0.7742, 0.9270, 0.6866]]), 'b': tensor([0.9541, 0.0955, 0.8896, 0.1709, 0.4578, 0.9287]), 'c': {'d': tensor([5])}}, {'a': tensor([[0.6236, 0.2665, 0.3845, 0.1008, 0.1871, 0.0125, 0.1407, 0.8828], [0.4517, 0.4068, 0.1777, 0.8897, 0.2485, 0.1900, 0.8993, 0.3462], [0.8976, 0.8408, 0.2361, 0.9718, 0.7077, 0.0350, 0.1291, 0.7675]]), 'b': tensor([0.6229, 0.7557, 0.2653, 0.8749, 0.8205, 0.2778]), 'c': {'d': tensor([1])}}, {'a': tensor([[0.0803, 0.0747, 0.2994, 0.3551, 0.2325, 0.2029, 0.0963, 0.3583], [0.4929, 0.8611, 0.0243, 0.5447, 0.4929, 0.4975, 0.6029, 0.9135], [0.5109, 0.3885, 0.7017, 0.5907, 0.3846, 0.3243, 0.7258, 0.9102]]), 'b': tensor([0.6293, 0.8360, 0.0842, 0.0263, 0.5059, 0.5398]), 'c': {'d': tensor([0])}}, {'a': tensor([[0.5590, 0.5996, 0.9398, 0.4532, 0.7030, 0.0217, 0.5259, 0.0566], [0.1255, 0.7498, 0.4983, 0.3020, 0.4932, 0.7532, 0.5629, 0.0898], [0.2531, 0.5237, 0.3917, 0.0051, 0.9483, 0.1334, 0.5323, 0.3156]]), 'b': tensor([0.2050, 0.8358, 0.5785, 0.6644, 0.5952, 0.6730]), 'c': {'d': tensor([4])}}] (<Tensor 0x7f7dffccc090> ├── 'a' --> tensor([[[0.6370, 0.5287, 0.8719, 0.8738, 0.8669, 0.0233, 0.4351, 0.3231], │ [0.0673, 0.8459, 0.8704, 0.0062, 0.2567, 0.2812, 0.9117, 0.9412], │ [0.0644, 0.1503, 0.6570, 0.0957, 0.1638, 0.7742, 0.9270, 0.6866]]]) ├── 'b' --> tensor([[1.9103, 1.0091, 1.7913, 1.0292, 1.2096, 1.8625]]) └── 'c' --> <Tensor 0x7f7dffccc050> ├── 'd' --> 
tensor([[5.]]) └── 'noise' --> tensor([[[[ 0.4562, -0.4098, 0.9975, -1.5102, 0.4956], [ 0.9847, 1.0742, 0.5543, 2.2826, 2.0131], [ 0.0873, -1.0024, 0.0923, 0.2751, -0.4338], [ 0.6638, 1.0220, 1.3232, -0.6142, -0.3566]], [[ 1.2246, 0.8735, 0.0900, 1.1970, 0.5007], [-0.7250, -0.9558, 0.1337, -0.0758, -0.2821], [-0.1152, -0.4845, 0.7222, 1.7570, -1.8673], [ 1.2254, -0.2409, -1.1709, -0.3074, -0.6063]], [[-0.6470, 0.3742, -0.8198, 0.9947, 0.2193], [-1.5520, 1.0254, -0.6323, 0.4202, 0.3086], [-1.8731, -0.8634, 1.8733, -1.1940, 0.7242], [-0.9276, 0.4131, -0.9818, -0.3848, -0.2841]]]]) , <Tensor 0x7f7dffccc110> ├── 'a' --> tensor([[[0.6236, 0.2665, 0.3845, 0.1008, 0.1871, 0.0125, 0.1407, 0.8828], │ [0.4517, 0.4068, 0.1777, 0.8897, 0.2485, 0.1900, 0.8993, 0.3462], │ [0.8976, 0.8408, 0.2361, 0.9718, 0.7077, 0.0350, 0.1291, 0.7675]]]) ├── 'b' --> tensor([[1.3880, 1.5711, 1.0704, 1.7655, 1.6732, 1.0772]]) └── 'c' --> <Tensor 0x7f7dffccc0d0> ├── 'd' --> tensor([[1.]]) └── 'noise' --> tensor([[[[ 1.6693, 1.3292, 1.1779, 0.1647, -0.3271], [-0.7005, 0.3124, 1.1802, -0.6941, -0.7028], [ 1.1857, -0.5633, 1.8195, 0.3823, -0.7769], [-0.3180, 0.7928, -0.4512, -0.6213, -0.2267]], [[-1.6616, -0.0282, -0.1339, 1.1325, -0.4535], [ 1.3824, -1.2600, -0.5668, -0.0589, -1.6768], [ 1.0283, -0.5619, -0.5195, 0.4863, -0.2197], [-0.0302, 1.4302, 1.4545, 0.4620, 0.2496]], [[ 0.0375, 1.7652, 0.1893, 0.2907, -0.3880], [ 0.8968, -0.6625, -2.0683, 1.4648, 0.5475], [-0.8635, 0.5427, 0.6779, 0.0429, -0.8031], [-0.0240, 0.2764, 1.1809, 1.8446, 0.5535]]]]) , <Tensor 0x7f7dffccc190> ├── 'a' --> tensor([[[0.0803, 0.0747, 0.2994, 0.3551, 0.2325, 0.2029, 0.0963, 0.3583], │ [0.4929, 0.8611, 0.0243, 0.5447, 0.4929, 0.4975, 0.6029, 0.9135], │ [0.5109, 0.3885, 0.7017, 0.5907, 0.3846, 0.3243, 0.7258, 0.9102]]]) ├── 'b' --> tensor([[1.3961, 1.6988, 1.0071, 1.0007, 1.2559, 1.2914]]) └── 'c' --> <Tensor 0x7f7dffccc150> ├── 'd' --> tensor([[0.]]) └── 'noise' --> tensor([[[[ 2.6231e-01, -7.3377e-01, -6.7766e-01, 
6.3407e-01, -1.0650e+00], [-1.9808e+00, 1.7385e+00, 1.3419e-03, -1.0606e+00, -7.2179e-02], [ 2.5939e-01, 9.2158e-01, 1.1401e+00, 1.2967e+00, -2.1993e+00], [-6.4287e-01, -5.5683e-01, 1.0608e+00, 4.6332e-01, -1.7852e+00]], [[-6.8630e-01, -3.5486e-01, -7.1593e-01, -1.0573e+00, -1.9016e+00], [ 6.6695e-01, 2.2377e-01, -1.5779e+00, 6.3296e-01, -1.1480e+00], [-3.4535e-01, -9.3572e-01, -9.4721e-01, -6.5567e-02, -9.0092e-01], [-5.1379e-01, 9.2360e-01, 7.6557e-01, 3.7062e-01, -6.3338e-01]], [[ 1.4979e-01, -4.6819e-01, 5.6967e-01, 3.8858e-01, 2.0449e-01], [-4.2331e-01, 2.6956e+00, 1.0970e+00, 1.1910e-01, 2.9090e-01], [ 6.7524e-01, -3.5661e-01, -2.0881e-01, 1.7711e+00, -5.2344e-01], [ 5.0039e-01, 2.7028e-01, 4.6420e-01, 1.1946e+00, -3.8231e-01]]]]) , <Tensor 0x7f7dffccc210> ├── 'a' --> tensor([[[0.5590, 0.5996, 0.9398, 0.4532, 0.7030, 0.0217, 0.5259, 0.0566], │ [0.1255, 0.7498, 0.4983, 0.3020, 0.4932, 0.7532, 0.5629, 0.0898], │ [0.2531, 0.5237, 0.3917, 0.0051, 0.9483, 0.1334, 0.5323, 0.3156]]]) ├── 'b' --> tensor([[1.0420, 1.6985, 1.3346, 1.4415, 1.3543, 1.4530]]) └── 'c' --> <Tensor 0x7f7dffccc1d0> ├── 'd' --> tensor([[4.]]) └── 'noise' --> tensor([[[[ 6.4419e-01, -2.7867e-01, -2.0447e-01, 2.0649e-01, -9.7412e-01], [-8.1947e-01, 1.6657e-01, 1.0360e+00, -5.2035e-01, 1.5048e+00], [-1.2174e+00, -1.0475e+00, -7.3149e-01, -4.9862e-01, -4.1088e-03], [-1.7250e+00, -9.8784e-01, -1.0926e+00, -1.4733e-01, 1.2436e+00]], [[ 1.8212e+00, -9.0900e-02, 7.0699e-01, -1.1273e+00, 1.4671e-01], [-6.3425e-01, -4.8812e-01, 9.4056e-01, -1.3953e+00, -1.1121e+00], [-9.9047e-01, 6.5654e-01, 1.0493e-03, -2.7606e-01, 1.4477e+00], [-1.2463e+00, -1.3972e-01, 1.0830e+00, 3.8693e-01, -6.2379e-01]], [[ 3.6846e-01, 7.0519e-01, 1.0511e+00, 1.2688e+00, -3.8694e-01], [-3.0683e-01, 1.1013e-02, 7.5671e-01, -3.4247e-01, 1.0267e-01], [-1.2282e-01, -5.6973e-01, -5.5059e-01, -8.8279e-01, -8.9531e-01], [ 4.0072e-01, -1.1599e+00, 9.7310e-01, -9.1617e-01, -1.5341e+00]]]]) ) mean0 & mean1: tensor(1.3888) tensor(1.3888) 
even_index_a0: tensor([[[0.6370, 0.5287, 0.8719, 0.8738, 0.8669, 0.0233, 0.4351, 0.3231], [0.0644, 0.1503, 0.6570, 0.0957, 0.1638, 0.7742, 0.9270, 0.6866]], [[0.6236, 0.2665, 0.3845, 0.1008, 0.1871, 0.0125, 0.1407, 0.8828], [0.8976, 0.8408, 0.2361, 0.9718, 0.7077, 0.0350, 0.1291, 0.7675]], [[0.0803, 0.0747, 0.2994, 0.3551, 0.2325, 0.2029, 0.0963, 0.3583], [0.5109, 0.3885, 0.7017, 0.5907, 0.3846, 0.3243, 0.7258, 0.9102]], [[0.5590, 0.5996, 0.9398, 0.4532, 0.7030, 0.0217, 0.5259, 0.0566], [0.2531, 0.5237, 0.3917, 0.0051, 0.9483, 0.1334, 0.5323, 0.3156]]]) even_index_a1: tensor([[[0.6370, 0.5287, 0.8719, 0.8738, 0.8669, 0.0233, 0.4351, 0.3231], [0.0644, 0.1503, 0.6570, 0.0957, 0.1638, 0.7742, 0.9270, 0.6866]], [[0.6236, 0.2665, 0.3845, 0.1008, 0.1871, 0.0125, 0.1407, 0.8828], [0.8976, 0.8408, 0.2361, 0.9718, 0.7077, 0.0350, 0.1291, 0.7675]], [[0.0803, 0.0747, 0.2994, 0.3551, 0.2325, 0.2029, 0.0963, 0.3583], [0.5109, 0.3885, 0.7017, 0.5907, 0.3846, 0.3243, 0.7258, 0.9102]], [[0.5590, 0.5996, 0.9398, 0.4532, 0.7030, 0.0217, 0.5259, 0.0566], [0.2531, 0.5237, 0.3917, 0.0051, 0.9483, 0.1334, 0.5323, 0.3156]]]) <Size 0x7f7dffcc5bd0> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7f7e000317d0> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation with the treetensor API is much simpler and clearer.