Customized Operations For Different Fields
Here is another example of custom operations implemented with both the native torch API and the treetensor API.
import copy

import torch
import treetensor.torch as ttorch

T, B = 3, 4


def with_nativetensor(batch_):
    """Apply the per-field transforms to a list of nested dicts with plain torch.

    Every sample in ``batch_`` is mutated in place:

    * ``'a'`` is cast to float and its even-indexed rows (``v[::2]``) are
      collected;
    * ``'b'`` is cast to float and replaced by ``b ** 2 + 1``;
    * ``'c'['d']`` is cast to float, and ``'c'['noise']`` is set to a fresh
      ``(3, 4, 5)`` tensor from ``torch.randn``.

    Any other key is reported with a print and otherwise left untouched.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)``: the mutated batch, the
        average of the per-sample means of the transformed ``'b'`` values,
        and the collected ``'a'`` rows stacked along a new leading dim.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        # Assigning to an *existing* key while iterating items() is allowed:
        # the dict is not resized. Rebinding ``v`` alone would be a no-op.
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                sample[k] = v
                even_index_a_list.append(v[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))
    # 'noise' belongs under the nested 'c' dict (mirrors ``batch_.c.noise``
    # in the treetensor version). It is added after the loop above so the
    # nested dict is never resized while it is being iterated.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))
    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Treetensor counterpart of ``with_nativetensor``.

    The samples are converted to treetensors and stacked, which adds a
    leading batch dim to every leaf. They are then cast to float and
    transformed field by field. Finally they are split back into one tree
    per sample, each leaf keeping a leading dim of size 1.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)``: the split trees, the mean
        of the transformed ``b`` leaf, and ``a[:, ::2]``.
    """
    # Remember the real batch size; the global ``B`` may not match the input.
    batch_size = len(batch_)
    batch_ = [ttorch.tensor(b) for b in batch_]
    batch_ = ttorch.stack(batch_)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Return one random sample: a nested dict of torch tensors.

    ``'a'`` is ``(T, 8)`` uniform floats, ``'b'`` is ``(6,)`` uniform floats
    and ``'c'['d']`` is a ``(1,)`` integer tensor drawn from ``[0, 10)``.
    """
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,))
        }
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Compare the largest *absolute* difference; abs(max(diff)) would let
    # large negative differences pass unnoticed.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    # Both implementations now attach noise under the 'c' field.
    assert all('noise' in sample['c'] for sample in batch0)
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look like the following, and all the assertions should pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.4959, 0.7429, 0.3304, 0.3691, 0.2602, 0.0217, 0.8756, 0.2599], [0.4695, 0.0128, 0.3597, 0.3050, 0.5799, 0.4451, 0.2758, 0.9532], [0.4889, 0.1749, 0.6407, 0.1816, 0.6972, 0.6399, 0.5626, 0.0948]]), 'b': tensor([0.6954, 0.2287, 0.6890, 0.9973, 0.1034, 0.9159]), 'c': {'d': tensor([4])}}, {'a': tensor([[0.6206, 0.3504, 0.5032, 0.3248, 0.7497, 0.3002, 0.2543, 0.1592], [0.1785, 0.7591, 0.3378, 0.8237, 0.6426, 0.2236, 0.3023, 0.7932], [0.0331, 0.4403, 0.9965, 0.7609, 0.7818, 0.3898, 0.0921, 0.2410]]), 'b': tensor([0.1389, 0.0367, 0.2335, 0.9589, 0.2701, 0.7552]), 'c': {'d': tensor([2])}}, {'a': tensor([[0.0787, 0.1154, 0.9513, 0.5076, 0.4105, 0.8078, 0.4976, 0.7081], [0.4414, 0.3664, 0.6787, 0.2895, 0.4783, 0.3019, 0.8096, 0.3731], [0.3586, 0.8682, 0.0797, 0.9027, 0.5423, 0.4816, 0.4864, 0.5681]]), 'b': tensor([0.3378, 0.7603, 0.1430, 0.2937, 0.0280, 0.7042]), 'c': {'d': tensor([1])}}, {'a': tensor([[0.7616, 0.8400, 0.8517, 0.3064, 0.9760, 0.8378, 0.3704, 0.7836], [0.5377, 0.6733, 0.5208, 0.8555, 0.8346, 0.0799, 0.4971, 0.3771], [0.5743, 0.6112, 0.0894, 0.4634, 0.2029, 0.5474, 0.6798, 0.5255]]), 'b': tensor([0.9623, 0.0927, 0.7332, 0.0350, 0.9289, 0.6587]), 'c': {'d': tensor([7])}}] (<Tensor 0x7fbed975de90> ├── 'a' --> tensor([[[0.4959, 0.7429, 0.3304, 0.3691, 0.2602, 0.0217, 0.8756, 0.2599], │ [0.4695, 0.0128, 0.3597, 0.3050, 0.5799, 0.4451, 0.2758, 0.9532], │ [0.4889, 0.1749, 0.6407, 0.1816, 0.6972, 0.6399, 0.5626, 0.0948]]]) ├── 'b' --> tensor([[1.4836, 1.0523, 1.4748, 1.9947, 1.0107, 1.8388]]) └── 'c' --> <Tensor 0x7fbed975de50> ├── 'd' --> 
tensor([[4.]]) └── 'noise' --> tensor([[[[ 1.2098, -1.1480, -3.0458, -0.1071, -1.1395], [ 1.6361, 1.3693, 0.1426, 0.1702, 0.5653], [ 0.1225, -0.6091, 0.4540, 2.5245, 0.2460], [-1.1039, 1.3527, -0.2761, -0.3291, -0.0366]], [[-1.4353, 0.4943, 0.9781, -0.0853, 0.4413], [ 0.5125, 0.6533, 0.3187, -0.6209, -1.3550], [-0.9656, -0.1417, -1.7543, -0.9261, -0.4421], [ 0.3440, -0.2400, 1.1507, -0.2864, 1.7387]], [[ 1.6043, 1.0504, 0.4217, 0.6256, 0.1901], [ 2.0248, -0.5585, -0.2970, -2.2687, -2.1299], [-1.3107, -0.1643, 0.5487, -0.1616, 0.6969], [ 0.1701, 2.2368, -0.8534, -0.0991, -0.2608]]]]) , <Tensor 0x7fbed975df10> ├── 'a' --> tensor([[[0.6206, 0.3504, 0.5032, 0.3248, 0.7497, 0.3002, 0.2543, 0.1592], │ [0.1785, 0.7591, 0.3378, 0.8237, 0.6426, 0.2236, 0.3023, 0.7932], │ [0.0331, 0.4403, 0.9965, 0.7609, 0.7818, 0.3898, 0.0921, 0.2410]]]) ├── 'b' --> tensor([[1.0193, 1.0013, 1.0545, 1.9196, 1.0730, 1.5704]]) └── 'c' --> <Tensor 0x7fbed975ded0> ├── 'd' --> tensor([[2.]]) └── 'noise' --> tensor([[[[-0.8716, 0.3982, -0.4103, 1.0145, 0.2366], [ 0.1163, -1.5097, 0.8837, 0.6072, 1.3204], [-0.7921, -0.3152, -0.4015, -1.6218, 0.1952], [-0.0809, -1.2087, -1.4225, 0.4099, 0.5188]], [[ 0.1643, 0.2732, 0.2644, -2.1391, -0.7504], [ 0.9021, 0.2958, 0.0264, 0.2811, -0.6561], [-2.3868, -1.0131, -0.3284, 0.2162, -1.9766], [-0.4745, 2.6729, -0.1767, 2.0280, -0.6456]], [[ 0.8487, 0.9021, 2.1113, -0.8008, 0.9089], [ 0.4937, -1.3294, -1.5853, 0.6095, -0.2106], [ 0.8631, -1.4816, -1.3472, -0.9508, -0.1390], [-0.8438, -1.1405, -0.5679, 0.0308, 0.2206]]]]) , <Tensor 0x7fbed975df90> ├── 'a' --> tensor([[[0.0787, 0.1154, 0.9513, 0.5076, 0.4105, 0.8078, 0.4976, 0.7081], │ [0.4414, 0.3664, 0.6787, 0.2895, 0.4783, 0.3019, 0.8096, 0.3731], │ [0.3586, 0.8682, 0.0797, 0.9027, 0.5423, 0.4816, 0.4864, 0.5681]]]) ├── 'b' --> tensor([[1.1141, 1.5781, 1.0204, 1.0863, 1.0008, 1.4959]]) └── 'c' --> <Tensor 0x7fbed975df50> ├── 'd' --> tensor([[1.]]) └── 'noise' --> tensor([[[[-0.8661, 2.1075, -0.8831, 0.0351, 
0.8010], [ 0.2314, -1.7201, 1.1340, 0.0306, 0.2898], [-0.2021, 0.1031, -1.1276, 0.9147, 1.1242], [ 0.1915, -1.2295, -0.0290, 0.8289, -2.2935]], [[ 0.3192, 0.1709, 0.6136, 1.0401, 1.6776], [ 0.6266, -0.4175, -0.2612, -0.3719, -1.8482], [ 0.4590, -0.8227, 1.4691, -0.5683, -1.0901], [-0.4111, -0.8091, -1.1143, -0.8473, -0.5122]], [[-1.6365, -1.0635, -1.2169, 0.3649, -1.6512], [ 1.0775, -1.9654, 0.0070, 0.6785, 0.0523], [-1.9686, -0.6915, -0.3784, -0.1928, -0.5714], [-1.6885, 0.6324, 0.3805, 0.8414, -1.1992]]]]) , <Tensor 0x7fbed9764050> ├── 'a' --> tensor([[[0.7616, 0.8400, 0.8517, 0.3064, 0.9760, 0.8378, 0.3704, 0.7836], │ [0.5377, 0.6733, 0.5208, 0.8555, 0.8346, 0.0799, 0.4971, 0.3771], │ [0.5743, 0.6112, 0.0894, 0.4634, 0.2029, 0.5474, 0.6798, 0.5255]]]) ├── 'b' --> tensor([[1.9260, 1.0086, 1.5375, 1.0012, 1.8628, 1.4339]]) └── 'c' --> <Tensor 0x7fbed975dfd0> ├── 'd' --> tensor([[7.]]) └── 'noise' --> tensor([[[[ 0.2398, -1.1715, -0.5465, 0.5210, -1.9378], [ 1.8088, 1.1304, 1.1054, -0.1605, 0.1069], [-0.3182, 1.1913, -0.8778, -0.2761, -1.9641], [-1.0521, 0.8548, 0.5549, 0.7722, 0.4085]], [[-0.8001, 0.5689, 1.0019, -0.0062, -0.9012], [ 0.1281, 0.4304, 0.6782, -1.2555, -0.8217], [ 0.3796, 0.6830, -2.9541, -0.6279, 0.0266], [-1.2134, -1.5165, 1.7040, -0.9219, 0.5670]], [[ 0.0077, -0.3625, -0.4056, 0.2789, 0.5481], [-1.1109, -0.3867, -0.9889, 0.7654, -0.1190], [ 1.6151, -0.7024, -0.0453, 1.0142, 2.8909], [ 0.7332, 1.2631, 1.1594, 0.5931, -1.5527]]]]) ) mean0 & mean1: tensor(1.3566) tensor(1.3566) even_index_a0: tensor([[[0.4959, 0.7429, 0.3304, 0.3691, 0.2602, 0.0217, 0.8756, 0.2599], [0.4889, 0.1749, 0.6407, 0.1816, 0.6972, 0.6399, 0.5626, 0.0948]], [[0.6206, 0.3504, 0.5032, 0.3248, 0.7497, 0.3002, 0.2543, 0.1592], [0.0331, 0.4403, 0.9965, 0.7609, 0.7818, 0.3898, 0.0921, 0.2410]], [[0.0787, 0.1154, 0.9513, 0.5076, 0.4105, 0.8078, 0.4976, 0.7081], [0.3586, 0.8682, 0.0797, 0.9027, 0.5423, 0.4816, 0.4864, 0.5681]], [[0.7616, 0.8400, 0.8517, 0.3064, 0.9760, 0.8378, 
0.3704, 0.7836], [0.5743, 0.6112, 0.0894, 0.4634, 0.2029, 0.5474, 0.6798, 0.5255]]]) even_index_a1: tensor([[[0.4959, 0.7429, 0.3304, 0.3691, 0.2602, 0.0217, 0.8756, 0.2599], [0.4889, 0.1749, 0.6407, 0.1816, 0.6972, 0.6399, 0.5626, 0.0948]], [[0.6206, 0.3504, 0.5032, 0.3248, 0.7497, 0.3002, 0.2543, 0.1592], [0.0331, 0.4403, 0.9965, 0.7609, 0.7818, 0.3898, 0.0921, 0.2410]], [[0.0787, 0.1154, 0.9513, 0.5076, 0.4105, 0.8078, 0.4976, 0.7081], [0.3586, 0.8682, 0.0797, 0.9027, 0.5423, 0.4816, 0.4864, 0.5681]], [[0.7616, 0.8400, 0.8517, 0.3064, 0.9760, 0.8378, 0.3704, 0.7836], [0.5743, 0.6112, 0.0894, 0.4634, 0.2029, 0.5474, 0.6798, 0.5255]]]) <Size 0x7fbed975dd90> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7fbed975dd50> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation with the treetensor API is much simpler and clearer.