Customized Operations for Different Fields
Here is another example of custom operations, implemented with both the native torch API and the treetensor API.
"""Compare per-field custom operations written with native torch vs. treetensor.

Each sample is a nested dict ``{'a': Tensor, 'b': Tensor, 'c': {'d': Tensor}}``
(see :func:`get_data`). Both implementations apply the same operations:

* cast the fields to float;
* replace ``b`` with ``b ** 2 + 1`` and report its mean over the batch;
* take every other row of ``a`` (``a[::2]``) for each sample;
* attach a random ``noise`` tensor of shape ``(3, 4, 5)`` under ``c``.
"""
import copy

import torch
import treetensor.torch as ttorch

T, B = 3, 4


def with_nativetensor(batch_):
    """Apply the per-field operations with plain torch and Python loops.

    :param batch_: List of sample dicts shaped like the output of
        :func:`get_data`. The dicts are modified in place, so pass a copy
        if the original data must be kept.
    :return: ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the
        transformed list of dicts, ``mean_b`` is the mean of the transformed
        ``b`` values, and ``even_index_a`` stacks ``a[::2]`` of every sample
        along a new leading dimension.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        # Only values of existing keys are reassigned in this pass, so
        # updating the dict while iterating over it is safe.
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                sample[k] = v
                even_index_a_list.append(v[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))

    # Add the noise in a second pass: inserting a new key into the nested
    # dict while iterating over it would change the dict's size mid-loop.
    # The noise lives under 'c' (a top-level key), not under 'd'.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))

    # Every sample has the same number of ``b`` elements, so the mean of the
    # per-sample means equals the mean over the whole batch.
    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the same per-field operations with the treetensor API.

    :param batch_: List of sample dicts shaped like the output of
        :func:`get_data`.
    :return: ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the result
        of ``ttorch.split`` (one tree per sample, each keeping a leading
        dimension of size 1), ``mean_b`` is the mean of the transformed ``b``
        values, and ``even_index_a`` is ``a[:, ::2]`` of the stacked batch.
    """
    batch_size = len(batch_)
    # Convert each nested dict into a tree tensor, then stack the trees
    # leaf by leaf along a new leading batch dimension.
    batch_ = ttorch.stack([ttorch.tensor(b) for b in batch_])
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    # Size the noise from the actual batch rather than the module-level ``B``.
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Return one random sample as a nested dict of tensors.

    ``a`` has shape ``(T, 8)`` and ``b`` has shape ``(6,)``; both come from
    ``torch.rand``. ``c.d`` is a one-element integer tensor from
    ``torch.randint(0, 10, ...)``.
    """
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,))
        }
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    # Each implementation gets its own deep copy, so in-place changes made by
    # one cannot affect the other.
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Compare the largest *absolute* difference. Taking abs() of the signed
    # max would let large negative differences slip through.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert all('noise' in sample['c'] for sample in batch0)
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look like the following (the random values will differ between runs), and all the assertions should pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.9257, 0.4266, 0.8127, 0.9392, 0.2294, 0.8336, 0.7605, 0.9144], [0.0884, 0.3489, 0.4901, 0.6910, 0.8835, 0.0469, 0.2055, 0.1559], [0.9612, 0.0284, 0.2305, 0.2726, 0.2317, 0.6156, 0.5317, 0.9245]]), 'b': tensor([0.3955, 0.9165, 0.3444, 0.2769, 0.7824, 0.0054]), 'c': {'d': tensor([7])}}, {'a': tensor([[0.3599, 0.8606, 0.2802, 0.7488, 0.1605, 0.6629, 0.0725, 0.8811], [0.8521, 0.8738, 0.1401, 0.9934, 0.8279, 0.9412, 0.4467, 0.4728], [0.1503, 0.4891, 0.8713, 0.5262, 0.4487, 0.4566, 0.2726, 0.5312]]), 'b': tensor([0.5027, 0.8534, 0.4864, 0.6362, 0.3039, 0.9323]), 'c': {'d': tensor([2])}}, {'a': tensor([[0.5323, 0.4024, 0.1017, 0.1310, 0.8797, 0.5557, 0.9597, 0.3834], [0.2552, 0.9185, 0.1508, 0.9390, 0.4087, 0.4922, 0.4747, 0.8337], [0.8026, 0.3843, 0.5243, 0.1324, 0.9962, 0.7263, 0.5565, 0.1281]]), 'b': tensor([0.3252, 0.8908, 0.5887, 0.1604, 0.2188, 0.1202]), 'c': {'d': tensor([9])}}, {'a': tensor([[0.2808, 0.5135, 0.2673, 0.1119, 0.1685, 0.3233, 0.4603, 0.2135], [0.3917, 0.5572, 0.5263, 0.4482, 0.3145, 0.9395, 0.6414, 0.6244], [0.6007, 0.0662, 0.8816, 0.1150, 0.7659, 0.0525, 0.6582, 0.0748]]), 'b': tensor([0.8992, 0.9398, 0.5766, 0.4901, 0.8376, 0.2124]), 'c': {'d': tensor([4])}}] (<Tensor 0x7fb1d42a0f90> ├── 'a' --> tensor([[[0.9257, 0.4266, 0.8127, 0.9392, 0.2294, 0.8336, 0.7605, 0.9144], │ [0.0884, 0.3489, 0.4901, 0.6910, 0.8835, 0.0469, 0.2055, 0.1559], │ [0.9612, 0.0284, 0.2305, 0.2726, 0.2317, 0.6156, 0.5317, 0.9245]]]) ├── 'b' --> tensor([[1.1564, 1.8400, 1.1186, 1.0767, 1.6121, 1.0000]]) └── 'c' --> <Tensor 0x7fb1d42a0f50> ├── 'd' --> 
tensor([[7.]]) └── 'noise' --> tensor([[[[ 0.3579, -0.6408, 0.9409, -1.4434, 0.4311], [-1.1070, 0.4677, 1.0348, 0.3771, -1.8838], [-2.8118, -1.0437, -0.9295, -1.2843, -0.3687], [ 0.2445, -0.9614, -0.4345, 1.2057, -1.0109]], [[ 0.8215, -1.5672, 0.3532, -0.4932, -0.1338], [-0.7677, 2.4574, 1.1538, -1.0093, 0.3429], [-0.5387, -0.2103, -0.6103, -0.0457, -0.1621], [-0.8051, 1.4865, -1.0652, 0.6543, -0.6310]], [[-0.7165, 0.4523, 0.3861, 0.4409, 0.6224], [-0.8699, 0.6322, 0.5477, -1.9075, 0.3583], [ 0.5514, 0.1888, 0.1900, 1.1220, -1.4074], [ 1.1672, 0.4753, 1.9815, -0.7217, 1.5522]]]]) , <Tensor 0x7fb1d42a7050> ├── 'a' --> tensor([[[0.3599, 0.8606, 0.2802, 0.7488, 0.1605, 0.6629, 0.0725, 0.8811], │ [0.8521, 0.8738, 0.1401, 0.9934, 0.8279, 0.9412, 0.4467, 0.4728], │ [0.1503, 0.4891, 0.8713, 0.5262, 0.4487, 0.4566, 0.2726, 0.5312]]]) ├── 'b' --> tensor([[1.2527, 1.7283, 1.2365, 1.4048, 1.0923, 1.8691]]) └── 'c' --> <Tensor 0x7fb1d42a0fd0> ├── 'd' --> tensor([[2.]]) └── 'noise' --> tensor([[[[ 1.0493, 1.0225, -1.2554, -0.0225, -0.5568], [ 0.8382, 1.3356, 0.7029, -0.1023, 0.1161], [ 0.8582, 0.0081, -1.4857, 0.1620, 0.0081], [ 0.1524, 1.2400, -1.3443, 0.3112, 0.0590]], [[-0.2105, 1.4063, -0.7215, -0.8753, -1.2693], [-0.0757, -2.0197, 1.2485, -0.0855, -1.3478], [-1.1685, 0.0979, 0.5640, 0.0359, -0.6250], [ 1.0069, -0.0556, -1.4091, -0.5358, 0.5204]], [[ 0.8484, 2.1689, -3.0154, 0.1979, 0.7170], [ 0.8199, -0.0033, 1.9158, -0.2894, -0.4480], [-0.8988, -0.9611, -0.4975, -0.8224, -0.3678], [ 0.3620, -0.8372, -0.9729, 0.3145, -0.1403]]]]) , <Tensor 0x7fb1d42a70d0> ├── 'a' --> tensor([[[0.5323, 0.4024, 0.1017, 0.1310, 0.8797, 0.5557, 0.9597, 0.3834], │ [0.2552, 0.9185, 0.1508, 0.9390, 0.4087, 0.4922, 0.4747, 0.8337], │ [0.8026, 0.3843, 0.5243, 0.1324, 0.9962, 0.7263, 0.5565, 0.1281]]]) ├── 'b' --> tensor([[1.1057, 1.7935, 1.3466, 1.0257, 1.0479, 1.0144]]) └── 'c' --> <Tensor 0x7fb1d42a7090> ├── 'd' --> tensor([[9.]]) └── 'noise' --> tensor([[[[-1.3456, -0.2045, -2.2172, 1.5476, 
-0.8317], [ 1.1984, -0.4774, -1.4471, 0.8426, 0.0722], [ 0.2040, -0.5946, -0.7240, 1.0343, -1.3400], [ 0.2246, -0.8020, 2.1181, 1.0530, -0.5405]], [[-0.4242, 1.0609, 1.3618, -1.0542, -1.9754], [-1.0875, -0.3019, 1.0196, 0.2886, -1.7909], [-0.1305, -0.9066, 1.0717, 0.0406, 0.1890], [-1.2663, -0.1013, -0.1627, 1.5421, 0.1341]], [[ 1.5481, -0.5824, -0.4398, -0.3064, -0.8305], [-0.1520, 0.0735, -1.7671, -1.7186, -0.4571], [ 0.0132, -0.4691, 0.5066, 1.5784, 0.6487], [-0.1768, -1.6254, 1.4006, 0.8015, -1.0122]]]]) , <Tensor 0x7fb1d42a7150> ├── 'a' --> tensor([[[0.2808, 0.5135, 0.2673, 0.1119, 0.1685, 0.3233, 0.4603, 0.2135], │ [0.3917, 0.5572, 0.5263, 0.4482, 0.3145, 0.9395, 0.6414, 0.6244], │ [0.6007, 0.0662, 0.8816, 0.1150, 0.7659, 0.0525, 0.6582, 0.0748]]]) ├── 'b' --> tensor([[1.8086, 1.8831, 1.3325, 1.2402, 1.7015, 1.0451]]) └── 'c' --> <Tensor 0x7fb1d42a7110> ├── 'd' --> tensor([[4.]]) └── 'noise' --> tensor([[[[ 1.9825, -1.0632, -0.1681, 0.3218, -0.5951], [ 0.0681, -0.7200, -1.4380, -1.4024, 0.9112], [ 0.0235, -0.0749, -1.0119, 2.9069, -1.0269], [ 0.0671, 1.7382, -0.1251, 0.2279, -1.8300]], [[-0.2455, 0.3796, 1.0014, 2.2101, -0.6025], [ 0.5893, 0.2538, -0.6225, 0.0841, 0.8188], [ 1.2263, 0.2263, -0.0100, -0.9343, 0.3980], [ 0.1762, 0.4784, -0.3005, -0.1692, 0.2834]], [[ 0.8008, 0.8531, 0.1050, -0.1336, -1.3390], [-1.8002, -0.1440, 1.2495, 0.1844, 1.5687], [-1.2012, 0.1109, -0.5724, 0.0432, -1.4930], [-1.5599, 1.4164, 0.6822, -1.3556, 0.9604]]]]) ) mean0 & mean1: tensor(1.3639) tensor(1.3639) even_index_a0: tensor([[[0.9257, 0.4266, 0.8127, 0.9392, 0.2294, 0.8336, 0.7605, 0.9144], [0.9612, 0.0284, 0.2305, 0.2726, 0.2317, 0.6156, 0.5317, 0.9245]], [[0.3599, 0.8606, 0.2802, 0.7488, 0.1605, 0.6629, 0.0725, 0.8811], [0.1503, 0.4891, 0.8713, 0.5262, 0.4487, 0.4566, 0.2726, 0.5312]], [[0.5323, 0.4024, 0.1017, 0.1310, 0.8797, 0.5557, 0.9597, 0.3834], [0.8026, 0.3843, 0.5243, 0.1324, 0.9962, 0.7263, 0.5565, 0.1281]], [[0.2808, 0.5135, 0.2673, 0.1119, 0.1685, 0.3233, 
0.4603, 0.2135], [0.6007, 0.0662, 0.8816, 0.1150, 0.7659, 0.0525, 0.6582, 0.0748]]]) even_index_a1: tensor([[[0.9257, 0.4266, 0.8127, 0.9392, 0.2294, 0.8336, 0.7605, 0.9144], [0.9612, 0.0284, 0.2305, 0.2726, 0.2317, 0.6156, 0.5317, 0.9245]], [[0.3599, 0.8606, 0.2802, 0.7488, 0.1605, 0.6629, 0.0725, 0.8811], [0.1503, 0.4891, 0.8713, 0.5262, 0.4487, 0.4566, 0.2726, 0.5312]], [[0.5323, 0.4024, 0.1017, 0.1310, 0.8797, 0.5557, 0.9597, 0.3834], [0.8026, 0.3843, 0.5243, 0.1324, 0.9962, 0.7263, 0.5565, 0.1281]], [[0.2808, 0.5135, 0.2673, 0.1119, 0.1685, 0.3233, 0.4603, 0.2135], [0.6007, 0.0662, 0.8816, 0.1150, 0.7659, 0.0525, 0.6582, 0.0748]]]) <Size 0x7fb1d42a0e90> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7fb1d42a0e50> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation with the treetensor API is much simpler and clearer.