Customized Operations for Different Fields
Here is another example of custom operations, implemented with both the native torch API and the treetensor API.
import copy

import torch
import treetensor.torch as ttorch

T, B = 3, 4


def with_nativetensor(batch_):
    """Apply the per-field custom operations using only the native torch API.

    Each element of ``batch_`` is a nested dict with the keys ``'a'``, ``'b'``
    and ``'c'`` (``'c'`` itself being a dict that holds ``'d'``). The samples
    are updated in place:

    * ``'a'`` is cast to float and its even-indexed slices (``v[::2]``) are
      collected.
    * ``'b'`` is cast to float and replaced by ``b ** 2 + 1``; its mean is
      collected.
    * ``'c'['d']`` is cast to float and ``'c'['noise']`` is filled with
      ``torch.randn(size=(3, 4, 5))``.

    Any other key is reported with an ``ignore keys`` message.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)``: the updated list of
        samples, the average of the per-sample means of ``'b'``, and the
        collected ``'a'`` slices stacked along a new leading dimension.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                # Re-assigning an existing key does not resize the dict, so it
                # is safe to do while iterating over ``sample.items()``.
                sample[k] = v
                even_index_a_list.append(v[::2])
            elif k == 'b':
                v = torch.pow(v.float(), 2) + 1.0
                # Store the transformed value so the returned batch carries it.
                sample[k] = v
                mean_b_list.append(v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        # Write the converted tensor back; rebinding the loop
                        # variable alone would discard the conversion.
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
                # The noise belongs to the nested 'c' dict. It is added after
                # the inner loop because inserting a new key while iterating
                # over ``v`` would raise a RuntimeError.
                v['noise'] = torch.randn(size=(3, 4, 5))
            else:
                print('ignore keys: {}'.format(k))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the same per-field custom operations using the treetensor API.

    The samples are converted to tree tensors and stacked into one batched
    tree, so every field can be processed with a single call.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)``: the processed batch split
        back into chunks of size 1 along dimension 0, the mean of the
        transformed ``b`` field, and the even-indexed slices of ``a`` taken
        along dimension 1.
    """
    # Size the noise from the actual input instead of the module-level ``B``
    # so the function also works for batches of another length.
    batch_size = len(batch_)
    batch_ = [ttorch.tensor(b) for b in batch_]
    batch_ = ttorch.stack(batch_)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Build one random sample with the nested layout used by this example."""
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,))
        }
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    # Both implementations modify their input, so each gets its own deep copy.
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Compare the largest absolute difference. Taking ``abs`` after ``max``
    # would let large negative differences slip through unnoticed.
    assert torch.abs(even_index_a0 - even_index_a1).max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look similar to the following (the exact values differ between runs because the inputs are random), and all the assertions should pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.3313, 0.0882, 0.6579, 0.3501, 0.3909, 0.9383, 0.1561, 0.0703], [0.9961, 0.9322, 0.4953, 0.6302, 0.3226, 0.7668, 0.9302, 0.4677], [0.5480, 0.5045, 0.5743, 0.8292, 0.3024, 0.1904, 0.6685, 0.9490]]), 'b': tensor([0.2765, 0.8172, 0.9692, 0.5797, 0.9039, 0.0055]), 'c': {'d': tensor([1])}}, {'a': tensor([[0.6972, 0.5427, 0.0417, 0.8488, 0.4968, 0.6170, 0.8957, 0.2799], [0.0127, 0.9730, 0.6826, 0.2333, 0.2428, 0.0609, 0.3715, 0.0104], [0.3351, 0.7599, 0.1900, 0.9579, 0.9783, 0.0326, 0.1451, 0.7471]]), 'b': tensor([0.9764, 0.1874, 0.1346, 0.8179, 0.3964, 0.2178]), 'c': {'d': tensor([1])}}, {'a': tensor([[0.7080, 0.8097, 0.9506, 0.2266, 0.4898, 0.4191, 0.6291, 0.4642], [0.8838, 0.5935, 0.3174, 0.6670, 0.9232, 0.3604, 0.4453, 0.5772], [0.2580, 0.4739, 0.1686, 0.3093, 0.0614, 0.4859, 0.4071, 0.2028]]), 'b': tensor([0.6880, 0.8413, 0.1673, 0.8797, 0.0595, 0.2570]), 'c': {'d': tensor([7])}}, {'a': tensor([[0.8292, 0.1015, 0.0075, 0.1469, 0.4491, 0.2319, 0.8218, 0.0663], [0.3731, 0.4277, 0.7806, 0.3766, 0.4146, 0.2895, 0.2871, 0.9064], [0.4393, 0.0466, 0.6933, 0.3058, 0.1305, 0.2063, 0.5092, 0.3820]]), 'b': tensor([0.7516, 0.2438, 0.9755, 0.6713, 0.4114, 0.0472]), 'c': {'d': tensor([2])}}] (<Tensor 0x7f8fe5cebfd0> ├── 'a' --> tensor([[[0.3313, 0.0882, 0.6579, 0.3501, 0.3909, 0.9383, 0.1561, 0.0703], │ [0.9961, 0.9322, 0.4953, 0.6302, 0.3226, 0.7668, 0.9302, 0.4677], │ [0.5480, 0.5045, 0.5743, 0.8292, 0.3024, 0.1904, 0.6685, 0.9490]]]) ├── 'b' --> tensor([[1.0764, 1.6678, 1.9394, 1.3360, 1.8170, 1.0000]]) └── 'c' --> <Tensor 0x7f8fe5cebf90> ├── 'd' --> 
tensor([[1.]]) └── 'noise' --> tensor([[[[-0.0621, -0.7930, 0.7569, -1.0191, -0.3130], [-0.4686, -0.1089, -0.9081, -1.0446, 0.4272], [ 1.0152, -0.3799, 0.9694, 1.8244, 0.5693], [-1.1199, 0.7588, -0.7775, 0.7499, 0.1142]], [[ 0.5598, -0.2789, 1.1207, -1.2933, 0.1816], [-0.5194, 0.0579, -0.4397, 0.5926, -0.9488], [-1.6616, -1.4776, -2.0146, 0.5458, -1.4844], [ 0.5248, 0.5729, 0.5522, -0.2959, 1.0412]], [[ 1.9960, 0.7783, -1.4010, 0.5961, 0.2503], [-1.6178, -0.7815, 0.1777, 0.3704, -0.1290], [ 0.5481, 0.1488, 1.8858, 0.2975, -0.7352], [-0.4605, 0.6365, -0.3557, -0.0358, -0.7355]]]]) , <Tensor 0x7f8fe5cf2090> ├── 'a' --> tensor([[[0.6972, 0.5427, 0.0417, 0.8488, 0.4968, 0.6170, 0.8957, 0.2799], │ [0.0127, 0.9730, 0.6826, 0.2333, 0.2428, 0.0609, 0.3715, 0.0104], │ [0.3351, 0.7599, 0.1900, 0.9579, 0.9783, 0.0326, 0.1451, 0.7471]]]) ├── 'b' --> tensor([[1.9533, 1.0351, 1.0181, 1.6690, 1.1571, 1.0474]]) └── 'c' --> <Tensor 0x7f8fe5cf2050> ├── 'd' --> tensor([[1.]]) └── 'noise' --> tensor([[[[-0.0456, 0.3777, 0.1782, -0.0794, 1.8114], [ 0.4518, 0.2652, 1.5227, 0.3376, -0.0074], [ 0.3699, 0.7519, 1.3634, 0.3915, -0.4720], [ 1.6225, 0.4938, 0.2811, -1.6052, 0.1634]], [[-0.8232, 1.2443, 0.6446, 0.0268, 0.2494], [ 1.1742, -0.0170, 1.2971, 1.3382, 0.6993], [-1.4969, 1.2580, -0.3939, -1.3453, 0.6606], [-0.1392, 0.5910, 0.4448, 0.6361, 0.0247]], [[ 1.0457, 0.7729, 0.7605, 0.1507, 0.0617], [ 0.9171, 0.3275, -0.8157, 1.2837, 1.1959], [-0.5456, 1.5890, -1.9056, 2.2679, -0.6888], [ 1.2832, -1.3582, 1.5985, -1.1178, -1.9899]]]]) , <Tensor 0x7f8fe5cf2110> ├── 'a' --> tensor([[[0.7080, 0.8097, 0.9506, 0.2266, 0.4898, 0.4191, 0.6291, 0.4642], │ [0.8838, 0.5935, 0.3174, 0.6670, 0.9232, 0.3604, 0.4453, 0.5772], │ [0.2580, 0.4739, 0.1686, 0.3093, 0.0614, 0.4859, 0.4071, 0.2028]]]) ├── 'b' --> tensor([[1.4734, 1.7078, 1.0280, 1.7738, 1.0035, 1.0661]]) └── 'c' --> <Tensor 0x7f8fe5cf20d0> ├── 'd' --> tensor([[7.]]) └── 'noise' --> tensor([[[[-1.3803, -0.9311, 0.8792, -0.4538, -0.4213], [ 
1.0475, -0.0281, 0.3050, -0.1078, -0.1914], [-1.8961, -1.8887, -0.3832, -0.9808, 0.3060], [-0.3629, -1.4063, -0.2881, 0.5847, 0.7290]], [[-0.0598, -0.5249, -0.6472, -0.1638, -1.7160], [-0.6548, 0.7673, 0.9549, -1.3975, 1.1515], [-0.8544, 0.2624, 1.9701, 2.6244, 0.5877], [ 0.4577, -1.2086, -0.6552, 0.3812, 0.4401]], [[-1.3914, -0.6348, -0.5649, 1.3535, 0.0401], [-1.5634, -0.3972, -1.2206, 0.5179, -0.5741], [ 0.9142, 0.6368, 0.7496, 1.7236, -0.4995], [-0.4428, -1.6781, 0.7597, 1.8426, -0.5079]]]]) , <Tensor 0x7f8fe5cf2190> ├── 'a' --> tensor([[[0.8292, 0.1015, 0.0075, 0.1469, 0.4491, 0.2319, 0.8218, 0.0663], │ [0.3731, 0.4277, 0.7806, 0.3766, 0.4146, 0.2895, 0.2871, 0.9064], │ [0.4393, 0.0466, 0.6933, 0.3058, 0.1305, 0.2063, 0.5092, 0.3820]]]) ├── 'b' --> tensor([[1.5649, 1.0594, 1.9516, 1.4506, 1.1692, 1.0022]]) └── 'c' --> <Tensor 0x7f8fe5cf2150> ├── 'd' --> tensor([[2.]]) └── 'noise' --> tensor([[[[-1.5034, 2.4455, -1.2657, -0.1664, 1.2356], [-1.0091, -1.7266, -0.4433, -1.2735, 0.6320], [ 0.9520, 0.5521, 1.5311, 0.1378, -0.6468], [-0.9342, -0.3182, 0.5041, 1.3123, 1.4483]], [[-0.1654, -0.1329, -0.0582, 0.7893, -0.5403], [ 1.5499, 0.1751, -0.1191, -0.1223, -0.5811], [ 0.9971, 0.1710, -0.7453, -0.1878, -1.1388], [-1.9537, -0.2508, 0.3788, -0.8452, -0.5041]], [[ 0.9284, 0.5913, 1.1612, -0.1337, 0.3932], [ 0.0346, -0.2545, -2.0326, -1.4314, -0.3497], [-0.6397, 1.2387, 1.2089, -0.1503, 0.1883], [-0.0752, 0.6694, 1.0501, 1.3362, -0.5575]]]]) ) mean0 & mean1: tensor(1.3736) tensor(1.3736) even_index_a0: tensor([[[0.3313, 0.0882, 0.6579, 0.3501, 0.3909, 0.9383, 0.1561, 0.0703], [0.5480, 0.5045, 0.5743, 0.8292, 0.3024, 0.1904, 0.6685, 0.9490]], [[0.6972, 0.5427, 0.0417, 0.8488, 0.4968, 0.6170, 0.8957, 0.2799], [0.3351, 0.7599, 0.1900, 0.9579, 0.9783, 0.0326, 0.1451, 0.7471]], [[0.7080, 0.8097, 0.9506, 0.2266, 0.4898, 0.4191, 0.6291, 0.4642], [0.2580, 0.4739, 0.1686, 0.3093, 0.0614, 0.4859, 0.4071, 0.2028]], [[0.8292, 0.1015, 0.0075, 0.1469, 0.4491, 0.2319, 0.8218, 0.0663], 
[0.4393, 0.0466, 0.6933, 0.3058, 0.1305, 0.2063, 0.5092, 0.3820]]]) even_index_a1: tensor([[[0.3313, 0.0882, 0.6579, 0.3501, 0.3909, 0.9383, 0.1561, 0.0703], [0.5480, 0.5045, 0.5743, 0.8292, 0.3024, 0.1904, 0.6685, 0.9490]], [[0.6972, 0.5427, 0.0417, 0.8488, 0.4968, 0.6170, 0.8957, 0.2799], [0.3351, 0.7599, 0.1900, 0.9579, 0.9783, 0.0326, 0.1451, 0.7471]], [[0.7080, 0.8097, 0.9506, 0.2266, 0.4898, 0.4191, 0.6291, 0.4642], [0.2580, 0.4739, 0.1686, 0.3093, 0.0614, 0.4859, 0.4071, 0.2028]], [[0.8292, 0.1015, 0.0075, 0.1469, 0.4491, 0.2319, 0.8218, 0.0663], [0.4393, 0.0466, 0.6933, 0.3058, 0.1305, 0.2063, 0.5092, 0.3820]]]) <Size 0x7f8fe5cebed0> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7f8fe5cebe90> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation with the treetensor API is much simpler and clearer.