Customized Operations For Different Fields
Here is another example of custom operations, implemented with both the native torch API and the treetensor API.
import copy

import torch
import treetensor.torch as ttorch

T, B = 3, 4


def with_nativetensor(batch_):
    """Process a batch field by field using only the native torch API.

    Args:
        batch_: a list of dicts as produced by ``get_data``. The list is
            iterated and each element's ``items()`` are dispatched on key.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the
        input list, ``mean_b`` is the average over samples of
        ``mean(b ** 2 + 1)``, and ``even_index_a`` stacks ``a[::2]`` of every
        sample along a new leading dimension.

    Raises:
        ZeroDivisionError: if no sample contains a ``'b'`` field (for
            example, when ``batch_`` is empty).
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                even_index_a_list.append(v[::2])
            elif k == 'b':
                v = v.float()
                transformed_v = torch.pow(v, 2) + 1.0
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        # NOTE(review): this only rebinds the local name; the
                        # converted tensor is never stored back into the dict.
                        v1 = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))
    for sample in batch_:
        for k in sample:
            # NOTE(review): only a top-level key named 'd' matches here; the
            # dicts built by get_data() have 'd' nested under 'c', so no noise
            # is attached for that data. Left as-is to keep the printed
            # result unchanged -- confirm whether 'c' was intended.
            if k == 'd':
                sample[k]['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Process the same batch using the treetensor API.

    Args:
        batch_: a list of dicts as produced by ``get_data``.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the
        result of ``ttorch.split`` on the stacked batch with
        ``split_size_or_sections=1`` along dim 0, ``mean_b`` is the mean of
        the transformed ``b`` field, and ``even_index_a`` is
        ``batch_.a[:, ::2]`` taken before the split.
    """
    # Wrap every sample, then merge them into one batched tree.
    batch_ = [ttorch.tensor(b) for b in batch_]
    batch_ = ttorch.stack(batch_)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(B, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Build one random sample: fields 'a', 'b' and a nested 'c' -> 'd'."""
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,))
        }
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    # Each implementation gets its own deep copy so neither sees the other's
    # in-place modifications.
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Compare the largest absolute element-wise difference; taking abs() of
    # the signed max would let large negative mismatches slip through.
    assert torch.abs(even_index_a0 - even_index_a1).max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look like the following, and all the assertions should pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.2528, 0.4177, 0.5927, 0.7439, 0.2610, 0.9414, 0.3309, 0.8179], [0.2426, 0.5285, 0.7025, 0.1007, 0.7717, 0.2163, 0.5488, 0.6384], [0.2763, 0.1325, 0.2971, 0.6536, 0.6680, 0.6659, 0.1535, 0.6603]]), 'b': tensor([0.8131, 0.1540, 0.3999, 0.9620, 0.0247, 0.5228]), 'c': {'d': tensor([3])}}, {'a': tensor([[0.0930, 0.9039, 0.7044, 0.9361, 0.9592, 0.0232, 0.8527, 0.8502], [0.9800, 0.5632, 0.7258, 0.5329, 0.4701, 0.6252, 0.6134, 0.4340], [0.6840, 0.1951, 0.2539, 0.0763, 0.3272, 0.9452, 0.6882, 0.1001]]), 'b': tensor([0.8683, 0.5246, 0.5320, 0.4137, 0.8252, 0.5588]), 'c': {'d': tensor([3])}}, {'a': tensor([[0.5771, 0.9907, 0.1408, 0.6211, 0.2264, 0.6072, 0.5538, 0.1231], [0.1719, 0.5936, 0.1351, 0.8051, 0.5816, 0.8645, 0.4902, 0.6489], [0.1467, 0.7644, 0.3178, 0.4616, 0.6666, 0.2651, 0.4480, 0.3887]]), 'b': tensor([0.4217, 0.0332, 0.2061, 0.5351, 0.7510, 0.2466]), 'c': {'d': tensor([7])}}, {'a': tensor([[0.9133, 0.8223, 0.1066, 0.8299, 0.9051, 0.4557, 0.0785, 0.0093], [0.8969, 0.3639, 0.0384, 0.0895, 0.5555, 0.6507, 0.6078, 0.4365], [0.1915, 0.2788, 0.9265, 0.0536, 0.4257, 0.9293, 0.6852, 0.8060]]), 'b': tensor([0.3471, 0.9971, 0.1674, 0.5123, 0.2763, 0.9870]), 'c': {'d': tensor([0])}}] (<Tensor 0x7f17f2932e50> ├── 'a' --> tensor([[[0.2528, 0.4177, 0.5927, 0.7439, 0.2610, 0.9414, 0.3309, 0.8179], │ [0.2426, 0.5285, 0.7025, 0.1007, 0.7717, 0.2163, 0.5488, 0.6384], │ [0.2763, 0.1325, 0.2971, 0.6536, 0.6680, 0.6659, 0.1535, 0.6603]]]) ├── 'b' --> tensor([[1.6611, 1.0237, 1.1599, 1.9255, 1.0006, 1.2733]]) └── 'c' --> <Tensor 0x7f17f2932e10> ├── 'd' --> 
tensor([[3.]]) └── 'noise' --> tensor([[[[ 0.6194, -0.8291, -1.1650, 0.0764, 0.2556], [-1.4103, 1.0407, -0.1023, 0.3174, 0.4812], [ 0.6355, -0.9907, -0.1500, -1.4863, 0.3427], [-0.1761, 0.6150, -1.6094, 0.4111, 1.8947]], [[ 0.8455, 1.8923, 0.9399, -1.4162, -0.9732], [-1.6211, 2.3156, 1.2249, 0.8859, -0.5506], [ 0.1431, -0.6609, -2.5998, 0.3779, 0.0424], [-1.0839, -0.2232, -0.3266, 0.8982, -0.1456]], [[-1.4125, 0.6481, 1.9928, 0.3429, -0.6100], [ 0.9297, 0.6323, 0.8506, -0.6363, -0.6004], [ 0.5461, 0.0813, -0.8113, -0.0896, 0.6426], [-0.0986, -0.3976, -0.1211, -1.0975, -0.0321]]]]) , <Tensor 0x7f17f2932ed0> ├── 'a' --> tensor([[[0.0930, 0.9039, 0.7044, 0.9361, 0.9592, 0.0232, 0.8527, 0.8502], │ [0.9800, 0.5632, 0.7258, 0.5329, 0.4701, 0.6252, 0.6134, 0.4340], │ [0.6840, 0.1951, 0.2539, 0.0763, 0.3272, 0.9452, 0.6882, 0.1001]]]) ├── 'b' --> tensor([[1.7539, 1.2752, 1.2830, 1.1712, 1.6810, 1.3123]]) └── 'c' --> <Tensor 0x7f17f2932e90> ├── 'd' --> tensor([[3.]]) └── 'noise' --> tensor([[[[ 2.4444, 1.1720, 0.8632, 0.4761, 1.1405], [-0.6512, 0.1339, 1.3094, -2.2030, -1.2528], [-0.6839, -1.4079, -1.0550, 0.2930, -0.3839], [-0.7287, 1.4465, 0.3683, -0.1231, -0.1292]], [[-2.0101, 1.7442, -0.0237, 0.9152, 0.0610], [-0.2823, 0.9592, -0.6154, -0.9371, -0.7711], [ 2.1444, -1.3623, -0.5261, 1.2471, -1.3095], [-1.0510, -0.4227, -0.1627, 1.2090, -0.0536]], [[-0.3722, 0.4047, -0.2497, -1.1743, 2.9188], [-0.2216, -1.2383, 2.9361, -0.0036, -0.8625], [-1.2923, 0.1095, -0.5218, 1.1455, -0.3562], [-1.0939, 0.1056, 0.3968, 0.5477, 0.5038]]]]) , <Tensor 0x7f17f2932f50> ├── 'a' --> tensor([[[0.5771, 0.9907, 0.1408, 0.6211, 0.2264, 0.6072, 0.5538, 0.1231], │ [0.1719, 0.5936, 0.1351, 0.8051, 0.5816, 0.8645, 0.4902, 0.6489], │ [0.1467, 0.7644, 0.3178, 0.4616, 0.6666, 0.2651, 0.4480, 0.3887]]]) ├── 'b' --> tensor([[1.1779, 1.0011, 1.0425, 1.2864, 1.5640, 1.0608]]) └── 'c' --> <Tensor 0x7f17f2932f10> ├── 'd' --> tensor([[7.]]) └── 'noise' --> tensor([[[[ 0.6417, 0.7936, 1.7457, -0.3449, 
0.0586], [ 0.5191, -0.5138, 1.5452, -0.3986, 0.8913], [-1.1714, -0.5297, 0.2113, -0.9917, -1.0074], [-0.5967, -0.9461, -0.6241, -1.3767, -0.9106]], [[ 0.2880, -0.9860, -1.7877, -1.0784, 0.0903], [-0.7338, 1.2487, 0.9152, -0.1813, -0.6848], [ 0.8098, -2.6129, 0.6112, 0.5383, -0.5587], [-1.0908, -0.5596, 0.2072, 2.1724, 0.2787]], [[ 1.5288, -0.3257, -1.9534, 2.3294, 0.4874], [-0.2468, 0.0650, -0.8965, 0.9325, 0.1155], [ 0.3582, -1.1133, 0.1037, 0.2556, 1.1160], [-0.0853, 2.2644, -1.0572, 0.4531, 1.0887]]]]) , <Tensor 0x7f17f2932fd0> ├── 'a' --> tensor([[[0.9133, 0.8223, 0.1066, 0.8299, 0.9051, 0.4557, 0.0785, 0.0093], │ [0.8969, 0.3639, 0.0384, 0.0895, 0.5555, 0.6507, 0.6078, 0.4365], │ [0.1915, 0.2788, 0.9265, 0.0536, 0.4257, 0.9293, 0.6852, 0.8060]]]) ├── 'b' --> tensor([[1.1205, 1.9943, 1.0280, 1.2624, 1.0764, 1.9742]]) └── 'c' --> <Tensor 0x7f17f2932f90> ├── 'd' --> tensor([[0.]]) └── 'noise' --> tensor([[[[-0.6496, 0.9187, -0.8898, 1.7309, -0.3362], [-0.8324, -0.0810, 0.5725, 0.9787, 1.0387], [ 0.5681, 0.6451, -0.0196, 1.3154, 0.1205], [-0.9782, 0.6803, 0.6962, -0.9089, 0.2468]], [[-0.1351, 0.5176, -0.2932, -0.0736, 0.2269], [-0.2297, 0.9318, -0.8866, -0.4340, 1.6622], [-0.1321, 1.8283, -1.5091, 0.7979, 0.5349], [-2.2566, 0.9935, -0.3138, -2.2293, -0.1792]], [[ 0.6219, 0.1924, -0.2116, -1.1552, 0.4944], [ 0.1520, -1.2640, -0.4181, 0.8976, 0.3313], [-1.1501, 0.6571, -0.2058, 0.6300, -0.6956], [ 1.2968, 0.3220, -0.1916, 0.8898, -0.8460]]]]) ) mean0 & mean1: tensor(1.3379) tensor(1.3379) even_index_a0: tensor([[[0.2528, 0.4177, 0.5927, 0.7439, 0.2610, 0.9414, 0.3309, 0.8179], [0.2763, 0.1325, 0.2971, 0.6536, 0.6680, 0.6659, 0.1535, 0.6603]], [[0.0930, 0.9039, 0.7044, 0.9361, 0.9592, 0.0232, 0.8527, 0.8502], [0.6840, 0.1951, 0.2539, 0.0763, 0.3272, 0.9452, 0.6882, 0.1001]], [[0.5771, 0.9907, 0.1408, 0.6211, 0.2264, 0.6072, 0.5538, 0.1231], [0.1467, 0.7644, 0.3178, 0.4616, 0.6666, 0.2651, 0.4480, 0.3887]], [[0.9133, 0.8223, 0.1066, 0.8299, 0.9051, 0.4557, 0.0785, 
0.0093], [0.1915, 0.2788, 0.9265, 0.0536, 0.4257, 0.9293, 0.6852, 0.8060]]]) even_index_a1: tensor([[[0.2528, 0.4177, 0.5927, 0.7439, 0.2610, 0.9414, 0.3309, 0.8179], [0.2763, 0.1325, 0.2971, 0.6536, 0.6680, 0.6659, 0.1535, 0.6603]], [[0.0930, 0.9039, 0.7044, 0.9361, 0.9592, 0.0232, 0.8527, 0.8502], [0.6840, 0.1951, 0.2539, 0.0763, 0.3272, 0.9452, 0.6882, 0.1001]], [[0.5771, 0.9907, 0.1408, 0.6211, 0.2264, 0.6072, 0.5538, 0.1231], [0.1467, 0.7644, 0.3178, 0.4616, 0.6666, 0.2651, 0.4480, 0.3887]], [[0.9133, 0.8223, 0.1066, 0.8299, 0.9051, 0.4557, 0.0785, 0.0093], [0.1915, 0.2788, 0.9265, 0.0536, 0.4257, 0.9293, 0.6852, 0.8060]]]) <Size 0x7f17f2932d50> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7f17f2932d10> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation with the treetensor API is much simpler and clearer.