Customized Operations For Different Fields
Here is another example of custom operations applied to different fields of a batch, implemented with both the native torch API and the treetensor API.
import copy

import torch
import treetensor.torch as ttorch

T, B = 3, 4


def with_nativetensor(batch_):
    """Apply the per-field operations to a list of nested dicts using plain torch.

    Each sample dict in ``batch_`` is modified in place:

    * ``'a'`` is cast to float; its even-indexed rows (``a[::2]``) are collected.
    * ``'b'`` is cast to float and replaced by ``b ** 2 + 1``.
    * ``'c'['d']`` is cast to float.
    * ``'c'['noise']`` is set to fresh standard-normal noise of shape (3, 4, 5).

    Any other key (at the top level or inside ``'c'``) is reported with a
    ``print`` and left untouched.

    Args:
        batch_: list of sample dicts (see ``get_data`` for the layout used here).

    Returns:
        tuple: ``(batch_, mean_b, even_index_a)``.
        ``mean_b`` is the average of the per-sample means of the transformed ``'b'``.
        ``even_index_a`` stacks every sample's ``a[::2]`` along a new leading dimension.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            # Re-assigning an existing key while iterating is safe: the dict's
            # size and key set do not change.
            if k == 'a':
                v = v.float()
                sample[k] = v
                even_index_a_list.append(v[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))

    # Noise lives under 'c' (mirroring ``batch_.c.noise`` in the treetensor
    # version). It is added in a separate pass because inserting a new key
    # into ``sample['c']`` while iterating over it would raise RuntimeError.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the same per-field operations with the treetensor API.

    The samples are stacked into one tree tensor and every leaf is cast to
    float. ``b`` is replaced by ``b ** 2 + 1`` and noise is attached as
    ``c.noise``. The result is then split back into per-sample tree tensors,
    each keeping a leading dimension of size 1.

    Args:
        batch_: list of sample dicts (see ``get_data`` for the layout used here).

    Returns:
        tuple: ``(batch_, mean_b, even_index_a)``, matching ``with_nativetensor``.
    """
    # Size the noise from the actual batch instead of the global ``B`` so
    # the function works for any batch length.
    batch_size = len(batch_)
    batch_ = ttorch.stack([ttorch.tensor(b) for b in batch_])
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Build one random sample.

    The sample has ``'a'`` as a float (T, 8) tensor, ``'b'`` as a float (6,)
    tensor, and ``'c'['d']`` as an int (1,) tensor with values in [0, 10).
    """
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,))
        }
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Compare the largest *absolute* difference; ``abs(diff.max())`` would
    # let arbitrarily large negative differences slip through.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look like the following (the tensor values are random, so they will differ from run to run), and all the assertions should pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.4891, 0.4983, 0.3518, 0.3453, 0.0062, 0.1869, 0.2393, 0.2685], [0.7944, 0.0033, 0.2901, 0.3635, 0.8537, 0.7105, 0.4277, 0.5171], [0.6045, 0.5730, 0.0281, 0.1397, 0.0980, 0.1078, 0.9350, 0.8590]]), 'b': tensor([0.8431, 0.4037, 0.4630, 0.6901, 0.8469, 0.8049]), 'c': {'d': tensor([1])}}, {'a': tensor([[0.3522, 0.0802, 0.7094, 0.5263, 0.7597, 0.8375, 0.3589, 0.5269], [0.8205, 0.3569, 0.1815, 0.6871, 0.1704, 0.7794, 0.2382, 0.5780], [0.3023, 0.4060, 0.6800, 0.9777, 0.7968, 0.7725, 0.5000, 0.6039]]), 'b': tensor([0.4178, 0.1959, 0.6307, 0.4192, 0.5912, 0.4633]), 'c': {'d': tensor([1])}}, {'a': tensor([[0.0300, 0.9757, 0.0152, 0.4424, 0.7109, 0.1433, 0.2944, 0.4652], [0.4663, 0.5075, 0.5755, 0.3119, 0.2193, 0.4129, 0.9700, 0.2495], [0.6488, 0.5735, 0.0649, 0.0577, 0.4324, 0.4917, 0.7781, 0.5089]]), 'b': tensor([0.9035, 0.7308, 0.0986, 0.7287, 0.1370, 0.2584]), 'c': {'d': tensor([6])}}, {'a': tensor([[0.7394, 0.7904, 0.2096, 0.4392, 0.8947, 0.0067, 0.0271, 0.7526], [0.8193, 0.4045, 0.2795, 0.3559, 0.5538, 0.6154, 0.8504, 0.7521], [0.8375, 0.7244, 0.8103, 0.6262, 0.1807, 0.3922, 0.1483, 0.8142]]), 'b': tensor([0.1548, 0.3610, 0.9881, 0.6237, 0.6998, 0.3502]), 'c': {'d': tensor([1])}}] (<Tensor 0x7f996e898f90> ├── 'a' --> tensor([[[0.4891, 0.4983, 0.3518, 0.3453, 0.0062, 0.1869, 0.2393, 0.2685], │ [0.7944, 0.0033, 0.2901, 0.3635, 0.8537, 0.7105, 0.4277, 0.5171], │ [0.6045, 0.5730, 0.0281, 0.1397, 0.0980, 0.1078, 0.9350, 0.8590]]]) ├── 'b' --> tensor([[1.7108, 1.1630, 1.2144, 1.4762, 1.7172, 1.6479]]) └── 'c' --> <Tensor 0x7f996e898f50> ├── 'd' --> 
tensor([[1.]]) └── 'noise' --> tensor([[[[ 1.0692, 0.5142, 1.2165, 0.7112, 0.1560], [-1.0644, -1.2981, 1.5262, -0.6554, 0.5670], [-0.9887, -0.9253, 1.0349, -0.2842, 0.1049], [ 0.5471, 0.0200, -1.4839, 1.3312, 0.6531]], [[ 0.7144, -1.1904, 0.3611, -0.1194, -0.6089], [-0.3058, -0.0769, 1.0874, -0.5698, 0.4647], [-1.2264, -0.2936, -0.4243, 0.2722, 0.8972], [-0.5769, 0.3364, 0.4670, -1.3205, -1.0617]], [[ 0.7087, 0.8764, 1.8436, 3.3415, -0.9340], [ 0.4428, -1.1900, -0.8580, 1.1566, 0.1243], [-2.5001, 1.7994, -0.6013, -0.4448, -0.4889], [ 1.3573, 0.0114, 0.3192, -0.4993, -0.0766]]]]) , <Tensor 0x7f996e89f050> ├── 'a' --> tensor([[[0.3522, 0.0802, 0.7094, 0.5263, 0.7597, 0.8375, 0.3589, 0.5269], │ [0.8205, 0.3569, 0.1815, 0.6871, 0.1704, 0.7794, 0.2382, 0.5780], │ [0.3023, 0.4060, 0.6800, 0.9777, 0.7968, 0.7725, 0.5000, 0.6039]]]) ├── 'b' --> tensor([[1.1745, 1.0384, 1.3978, 1.1757, 1.3495, 1.2147]]) └── 'c' --> <Tensor 0x7f996e898fd0> ├── 'd' --> tensor([[1.]]) └── 'noise' --> tensor([[[[-0.8616, 1.0088, 0.4363, 0.3854, -0.4530], [-1.3458, 0.2346, 0.0978, -1.1845, -0.2845], [ 1.4661, 0.0099, 1.9448, 0.9461, -0.6596], [ 0.5714, -0.4801, 0.3029, 1.3462, -0.5059]], [[-0.8858, -0.2870, -0.0291, 0.0954, 0.0958], [-1.1399, 1.2167, -0.1784, -1.1622, -0.8205], [-1.1587, -0.3150, -0.6488, -0.4852, 1.7651], [-0.5875, -0.0582, -1.3466, 0.5439, -1.4773]], [[ 0.9459, -1.6996, -0.1797, -1.3600, 1.0270], [-0.0406, 0.7433, 0.5458, 1.1187, 2.1452], [ 0.5953, 0.8925, -1.6194, -1.1096, -0.1076], [-0.4481, -0.3585, -0.2967, -1.3786, -0.2491]]]]) , <Tensor 0x7f996e89f0d0> ├── 'a' --> tensor([[[0.0300, 0.9757, 0.0152, 0.4424, 0.7109, 0.1433, 0.2944, 0.4652], │ [0.4663, 0.5075, 0.5755, 0.3119, 0.2193, 0.4129, 0.9700, 0.2495], │ [0.6488, 0.5735, 0.0649, 0.0577, 0.4324, 0.4917, 0.7781, 0.5089]]]) ├── 'b' --> tensor([[1.8163, 1.5341, 1.0097, 1.5310, 1.0188, 1.0668]]) └── 'c' --> <Tensor 0x7f996e89f090> ├── 'd' --> tensor([[6.]]) └── 'noise' --> tensor([[[[-0.0845, -0.2490, 0.4574, -0.2055, 
-0.5052], [ 0.6392, 1.4586, -2.1247, -0.0794, 0.9077], [-1.6232, 1.1221, -0.2364, -1.6875, -2.4838], [-0.3776, -0.6827, 1.1125, -1.3828, 0.2339]], [[ 0.9376, -0.8476, -0.8220, -0.9632, 1.1160], [ 0.3020, 0.9054, 0.5701, -0.7925, -0.4109], [-1.2571, -0.3713, -0.7708, 0.1949, -0.5856], [ 1.4660, 0.1360, -1.9473, -1.5097, -1.0084]], [[ 0.7180, 2.8065, -0.1474, 1.2503, 0.3180], [-2.3330, -1.0297, 0.7406, 0.8737, -1.4619], [ 1.4348, 1.0801, 1.4215, 0.2860, 1.4047], [-0.7142, -0.1309, -3.4383, -0.4981, -0.7635]]]]) , <Tensor 0x7f996e89f150> ├── 'a' --> tensor([[[0.7394, 0.7904, 0.2096, 0.4392, 0.8947, 0.0067, 0.0271, 0.7526], │ [0.8193, 0.4045, 0.2795, 0.3559, 0.5538, 0.6154, 0.8504, 0.7521], │ [0.8375, 0.7244, 0.8103, 0.6262, 0.1807, 0.3922, 0.1483, 0.8142]]]) ├── 'b' --> tensor([[1.0240, 1.1303, 1.9763, 1.3890, 1.4897, 1.1226]]) └── 'c' --> <Tensor 0x7f996e89f110> ├── 'd' --> tensor([[1.]]) └── 'noise' --> tensor([[[[-0.1892, -0.5689, 0.6290, 1.0417, -0.5784], [ 1.6671, 0.6265, -0.6924, 0.9414, 0.0544], [ 0.2755, 1.3875, 0.5984, 0.5777, -0.2418], [ 1.7944, -0.2855, -0.3323, 1.2764, 0.1819]], [[ 2.8290, -1.5367, 1.9294, 1.1419, 0.4531], [ 0.2037, 0.6561, 0.5935, 0.2346, -1.0272], [ 0.2232, 0.3071, -0.6290, -0.0665, 0.7090], [-1.1359, -0.2510, -1.1360, -1.6798, -0.4487]], [[-1.5467, 1.3900, 1.6637, -0.8699, -1.5268], [-1.0286, 0.6390, 0.0227, 0.0939, -0.5320], [-0.7108, -0.8937, -0.2330, -0.5127, -1.1275], [ 0.1164, 0.7993, -0.6485, 0.1880, -1.3724]]]]) ) mean0 & mean1: tensor(1.3495) tensor(1.3495) even_index_a0: tensor([[[0.4891, 0.4983, 0.3518, 0.3453, 0.0062, 0.1869, 0.2393, 0.2685], [0.6045, 0.5730, 0.0281, 0.1397, 0.0980, 0.1078, 0.9350, 0.8590]], [[0.3522, 0.0802, 0.7094, 0.5263, 0.7597, 0.8375, 0.3589, 0.5269], [0.3023, 0.4060, 0.6800, 0.9777, 0.7968, 0.7725, 0.5000, 0.6039]], [[0.0300, 0.9757, 0.0152, 0.4424, 0.7109, 0.1433, 0.2944, 0.4652], [0.6488, 0.5735, 0.0649, 0.0577, 0.4324, 0.4917, 0.7781, 0.5089]], [[0.7394, 0.7904, 0.2096, 0.4392, 0.8947, 0.0067, 
0.0271, 0.7526], [0.8375, 0.7244, 0.8103, 0.6262, 0.1807, 0.3922, 0.1483, 0.8142]]]) even_index_a1: tensor([[[0.4891, 0.4983, 0.3518, 0.3453, 0.0062, 0.1869, 0.2393, 0.2685], [0.6045, 0.5730, 0.0281, 0.1397, 0.0980, 0.1078, 0.9350, 0.8590]], [[0.3522, 0.0802, 0.7094, 0.5263, 0.7597, 0.8375, 0.3589, 0.5269], [0.3023, 0.4060, 0.6800, 0.9777, 0.7968, 0.7725, 0.5000, 0.6039]], [[0.0300, 0.9757, 0.0152, 0.4424, 0.7109, 0.1433, 0.2944, 0.4652], [0.6488, 0.5735, 0.0649, 0.0577, 0.4324, 0.4917, 0.7781, 0.5089]], [[0.7394, 0.7904, 0.2096, 0.4392, 0.8947, 0.0067, 0.0271, 0.7526], [0.8375, 0.7244, 0.8103, 0.6262, 0.1807, 0.3922, 0.1483, 0.8142]]]) <Size 0x7f996e898e90> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7f996e898e50> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation with the treetensor API is much simpler and clearer.