Customized Operations For Different Fields¶
Here is another example of custom operations, implemented with both the native torch API and the treetensor API.
"""Compare per-field custom operations written with native torch and with treetensor.

Both implementations receive a list of ``B`` nested sample dicts shaped like
the output of ``get_data`` and return ``(batch, mean_b, even_index_a)``.
"""
import copy

import torch
import treetensor.torch as ttorch

T, B = 3, 4


def with_nativetensor(batch_):
    """Process a list of nested sample dicts using only native torch calls.

    For every sample:
      * ``'a'`` is cast to float and every second row (``v[::2]``) is collected;
      * ``'b'`` is cast to float, transformed by ``x ** 2 + 1`` and its mean
        is collected;
      * a ``(3, 4, 5)`` gaussian ``'noise'`` tensor is stored under ``'c'``.

    Unknown keys are reported with ``print`` and otherwise ignored.

    :param batch_: List of sample dicts; it is modified in place (the
        ``'noise'`` entry is added under ``'c'``).
    :return: Tuple ``(batch_, mean_b, even_index_a)`` where ``mean_b`` is the
        average of the per-sample means of the transformed ``'b'`` and
        ``even_index_a`` stacks the selected rows of ``'a'`` along a new
        leading dimension.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                even_index_a_list.append(v[::2])
            elif k == 'b':
                v = v.float()
                transformed_v = torch.pow(v, 2) + 1.0
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        # NOTE(review): the converted tensor is not written
                        # back, so 'd' keeps its integer dtype in the result.
                        v1 = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))

    # The noise lives inside the nested 'c' dict (next to 'd'), so the check
    # must be made on the top-level key 'c'; 'd' never appears at this level.
    # Done in a separate pass so no dict is resized while being iterated.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Process the same list of nested sample dicts using the treetensor API.

    The samples are converted to tree tensors, stacked along a new leading
    dimension, cast to float and then handled field by field via attribute
    access (``batch_.b``, ``batch_.c.noise``, ``batch_.a``).

    :param batch_: List of sample dicts.
    :return: Tuple ``(batch_, mean_b, even_index_a)``; ``batch_`` is the result
        of splitting the stacked tree back into chunks of size 1 along
        dimension 0.
    """
    batch_ = [ttorch.tensor(b) for b in batch_]
    batch_ = ttorch.stack(batch_)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(B, 3, 4, 5))
    mean_b = batch_.b.mean()
    # Dimension 0 is now the stacked sample axis, so the per-sample row
    # selection ``v[::2]`` becomes ``[:, ::2]``.
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Build one random sample: ``'a'`` (T, 8), ``'b'`` (6,) and nested ``'c'.'d'`` (1,)."""
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,))
        }
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    # Deep copies keep the two implementations from sharing mutated inputs.
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Largest absolute element-wise difference; taking abs() after max()
    # would let large negative deviations slip through.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look like the following (the exact values differ between runs because the inputs are random), and all the assertions should pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.0530, 0.1316, 0.3868, 0.0582, 0.0592, 0.2844, 0.7060, 0.7104], [0.2715, 0.6086, 0.3717, 0.9783, 0.8932, 0.2700, 0.0806, 0.9643], [0.6430, 0.4201, 0.3315, 0.5621, 0.4383, 0.7077, 0.8652, 0.7994]]), 'b': tensor([0.7424, 0.9573, 0.3920, 0.1052, 0.1581, 0.8148]), 'c': {'d': tensor([2])}}, {'a': tensor([[0.0653, 0.4086, 0.6234, 0.7911, 0.9723, 0.0044, 0.9084, 0.6511], [0.4518, 0.1293, 0.2198, 0.2321, 0.9656, 0.4900, 0.3904, 0.9074], [0.9259, 0.2673, 0.2197, 0.2338, 0.8283, 0.3773, 0.8736, 0.9157]]), 'b': tensor([0.0709, 0.8475, 0.8364, 0.9020, 0.4540, 0.5760]), 'c': {'d': tensor([8])}}, {'a': tensor([[0.1737, 0.6593, 0.8535, 0.4953, 0.6238, 0.9004, 0.8182, 0.2689], [0.2865, 0.4794, 0.7387, 0.0603, 0.9837, 0.6522, 0.6537, 0.0974], [0.0411, 0.4282, 0.6774, 0.0315, 0.7757, 0.9708, 0.9896, 0.7082]]), 'b': tensor([0.0826, 0.7268, 0.3619, 0.9916, 0.2159, 0.2764]), 'c': {'d': tensor([2])}}, {'a': tensor([[0.7990, 0.3954, 0.2579, 0.3021, 0.7232, 0.1034, 0.9367, 0.8425], [0.7894, 0.4812, 0.3679, 0.4769, 0.5864, 0.6116, 0.4470, 0.3827], [0.6414, 0.4988, 0.6573, 0.6941, 0.3207, 0.3858, 0.7914, 0.9150]]), 'b': tensor([0.1669, 0.1758, 0.1928, 0.0075, 0.8306, 0.6041]), 'c': {'d': tensor([8])}}] (<Tensor 0x7fe654722fd0> ├── 'a' --> tensor([[[0.0530, 0.1316, 0.3868, 0.0582, 0.0592, 0.2844, 0.7060, 0.7104], │ [0.2715, 0.6086, 0.3717, 0.9783, 0.8932, 0.2700, 0.0806, 0.9643], │ [0.6430, 0.4201, 0.3315, 0.5621, 0.4383, 0.7077, 0.8652, 0.7994]]]) ├── 'b' --> tensor([[1.5512, 1.9164, 1.1537, 1.0111, 1.0250, 1.6638]]) └── 'c' --> <Tensor 0x7fe654722f90> ├── 'd' --> 
tensor([[2.]]) └── 'noise' --> tensor([[[[-0.6145, -0.1458, 0.1173, 0.3199, -0.2685], [-0.4438, -0.9441, -1.3548, -1.6339, -0.1808], [-0.1598, -0.7564, 0.5177, 1.6621, -0.7301], [-0.0091, -0.0693, 0.3701, -0.1008, -0.9677]], [[-0.0215, -0.3856, -0.2832, -0.6427, -1.1976], [ 0.1676, -0.7570, 1.3808, -0.6569, 0.7185], [-1.7909, -0.6165, -0.3450, -0.2981, 1.4836], [ 1.1538, -0.4351, -0.3297, 0.2331, -1.3635]], [[ 0.8167, 1.4640, 0.4096, -0.2440, -0.9645], [-1.0080, -1.1316, -0.0838, 0.5443, -1.8280], [-0.2243, 0.3282, -0.7281, -0.8862, -1.0531], [ 0.4967, -1.6835, 2.8434, -1.4404, 1.2524]]]]) , <Tensor 0x7fe654729090> ├── 'a' --> tensor([[[0.0653, 0.4086, 0.6234, 0.7911, 0.9723, 0.0044, 0.9084, 0.6511], │ [0.4518, 0.1293, 0.2198, 0.2321, 0.9656, 0.4900, 0.3904, 0.9074], │ [0.9259, 0.2673, 0.2197, 0.2338, 0.8283, 0.3773, 0.8736, 0.9157]]]) ├── 'b' --> tensor([[1.0050, 1.7183, 1.6996, 1.8136, 1.2061, 1.3318]]) └── 'c' --> <Tensor 0x7fe654729050> ├── 'd' --> tensor([[8.]]) └── 'noise' --> tensor([[[[ 0.6687, -0.1380, -0.2344, -1.0965, 0.9932], [-0.4782, -1.5013, 0.1346, 2.0136, -0.0120], [-0.8172, 0.3474, 0.2281, 0.6967, 0.3457], [ 0.8758, -1.9384, -0.5588, 0.4844, -1.3965]], [[ 0.5175, 1.3308, 1.5126, 1.8909, 0.1056], [ 0.6862, 1.4825, -1.2959, -1.5295, 1.0123], [-0.2036, 1.4430, 0.3562, -0.0324, -0.8127], [-0.0664, -1.1613, 2.0566, -0.5090, -1.0022]], [[-0.0094, 0.4440, 1.0630, -0.1183, -0.8398], [-0.7594, -0.3139, 0.7612, 0.7844, 0.7298], [-0.0850, -1.3562, -1.3477, -1.9326, -0.1287], [-0.3684, -0.1948, -0.2106, 0.1244, -0.6444]]]]) , <Tensor 0x7fe654729110> ├── 'a' --> tensor([[[0.1737, 0.6593, 0.8535, 0.4953, 0.6238, 0.9004, 0.8182, 0.2689], │ [0.2865, 0.4794, 0.7387, 0.0603, 0.9837, 0.6522, 0.6537, 0.0974], │ [0.0411, 0.4282, 0.6774, 0.0315, 0.7757, 0.9708, 0.9896, 0.7082]]]) ├── 'b' --> tensor([[1.0068, 1.5282, 1.1310, 1.9833, 1.0466, 1.0764]]) └── 'c' --> <Tensor 0x7fe6547290d0> ├── 'd' --> tensor([[2.]]) └── 'noise' --> tensor([[[[ 0.4160, 2.0009, -1.7171, 
0.5402, -2.3210], [-0.2561, 0.4562, -1.0994, 1.2536, -1.3823], [ 0.4414, -0.6419, 0.4306, -0.0867, 0.3658], [ 0.3535, -0.7061, 1.8489, -0.3642, -0.4975]], [[-0.0270, -0.0539, 1.0774, -0.4132, 0.1036], [ 0.0773, 0.7107, -0.3725, -0.3476, -0.0902], [-1.1531, 0.4045, 0.6975, 0.1649, -1.1100], [-0.2848, -0.0567, 1.1481, -1.3258, -1.5258]], [[ 0.6759, -0.1135, -0.0723, -0.0190, -0.2226], [-1.6966, 0.0369, 1.4791, -0.0274, 0.5058], [ 0.1358, -0.6024, 2.6037, -0.1705, 1.1308], [-0.6299, -1.1563, -1.3002, -1.5118, -1.8600]]]]) , <Tensor 0x7fe654729190> ├── 'a' --> tensor([[[0.7990, 0.3954, 0.2579, 0.3021, 0.7232, 0.1034, 0.9367, 0.8425], │ [0.7894, 0.4812, 0.3679, 0.4769, 0.5864, 0.6116, 0.4470, 0.3827], │ [0.6414, 0.4988, 0.6573, 0.6941, 0.3207, 0.3858, 0.7914, 0.9150]]]) ├── 'b' --> tensor([[1.0279, 1.0309, 1.0372, 1.0001, 1.6899, 1.3649]]) └── 'c' --> <Tensor 0x7fe654729150> ├── 'd' --> tensor([[8.]]) └── 'noise' --> tensor([[[[-0.0519, 0.5795, 0.9674, 1.4432, -0.5475], [-0.5406, 0.6173, -0.9155, -0.7560, 1.3850], [-1.7609, 0.5079, 0.0522, 0.4386, 0.0854], [ 0.0806, -2.1930, 0.7717, 0.3758, -0.9107]], [[-0.3009, 1.7650, 0.0811, 1.0014, -1.2704], [-0.5530, -0.2263, 0.4376, 0.9503, -0.3118], [-0.8690, -1.4592, -1.3738, 0.0543, 1.7000], [ 1.1483, -1.3858, 0.6293, -0.1860, -1.5392]], [[-0.8780, 0.2691, -1.2989, -0.3519, 0.7185], [-1.4287, 0.4811, 1.0223, -0.3136, -2.3332], [ 0.9769, -0.4040, -1.3943, -0.1041, 0.7485], [ 0.2314, -0.1190, -0.1268, -0.3399, -0.9422]]]]) ) mean0 & mean1: tensor(1.3341) tensor(1.3341) even_index_a0: tensor([[[0.0530, 0.1316, 0.3868, 0.0582, 0.0592, 0.2844, 0.7060, 0.7104], [0.6430, 0.4201, 0.3315, 0.5621, 0.4383, 0.7077, 0.8652, 0.7994]], [[0.0653, 0.4086, 0.6234, 0.7911, 0.9723, 0.0044, 0.9084, 0.6511], [0.9259, 0.2673, 0.2197, 0.2338, 0.8283, 0.3773, 0.8736, 0.9157]], [[0.1737, 0.6593, 0.8535, 0.4953, 0.6238, 0.9004, 0.8182, 0.2689], [0.0411, 0.4282, 0.6774, 0.0315, 0.7757, 0.9708, 0.9896, 0.7082]], [[0.7990, 0.3954, 0.2579, 0.3021, 0.7232, 
0.1034, 0.9367, 0.8425], [0.6414, 0.4988, 0.6573, 0.6941, 0.3207, 0.3858, 0.7914, 0.9150]]]) even_index_a1: tensor([[[0.0530, 0.1316, 0.3868, 0.0582, 0.0592, 0.2844, 0.7060, 0.7104], [0.6430, 0.4201, 0.3315, 0.5621, 0.4383, 0.7077, 0.8652, 0.7994]], [[0.0653, 0.4086, 0.6234, 0.7911, 0.9723, 0.0044, 0.9084, 0.6511], [0.9259, 0.2673, 0.2197, 0.2338, 0.8283, 0.3773, 0.8736, 0.9157]], [[0.1737, 0.6593, 0.8535, 0.4953, 0.6238, 0.9004, 0.8182, 0.2689], [0.0411, 0.4282, 0.6774, 0.0315, 0.7757, 0.9708, 0.9896, 0.7082]], [[0.7990, 0.3954, 0.2579, 0.3021, 0.7232, 0.1034, 0.9367, 0.8425], [0.6414, 0.4988, 0.6573, 0.6941, 0.3207, 0.3858, 0.7914, 0.9150]]]) <Size 0x7fe6548e9f90> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7fe654a9a110> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation with the treetensor API is much simpler and clearer.