Customized Operations For Different Fields
Here is another example of custom operations, implemented with both the native torch API and the treetensor API.
import copy
import torch
import treetensor.torch as ttorch
# Leading dimension of field 'a' in each generated sample.
T = 3
# Number of samples in the demo batch.
B = 4
def with_nativetensor(batch_):
    """Apply the per-field operations to a batch using only native torch calls.

    Each sample is a nested dict. Field ``a`` is cast to float and sliced with
    ``[::2]``; field ``b`` is cast to float and replaced by ``b ** 2 + 1``;
    field ``c.d`` is cast to float; and a fresh ``(3, 4, 5)`` gaussian tensor
    is attached as ``c.noise``. Unknown keys are reported and left untouched.

    Args:
        batch_: list of sample dicts. The dicts are modified in place.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the
        same list with the transformed fields, ``mean_b`` is the average of
        the per-sample means of the transformed ``b``, and ``even_index_a``
        stacks the ``a[::2]`` slices of every sample along a new dim 0.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        # Iterate over a snapshot so the results can be written back into
        # the sample; rebinding the loop variable alone would discard them.
        for k, v in list(sample.items()):
            if k == 'a':
                v = v.float()
                sample[k] = v
                even_index_a_list.append(v[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in list(v.items()):
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))
    for sample in batch_:
        # 'd' lives inside the nested dict 'c', so the noise has to be
        # attached under 'c' (a top-level 'd' key never exists).
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))
    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a
def with_treetensor(batch_):
    """Apply the per-field operations to a batch using the treetensor API.

    The samples are converted to tree tensors and stacked, so that every
    operation is applied once to the whole batch instead of per sample.

    Args:
        batch_: list of nested sample dicts containing the fields ``a``,
            ``b`` and ``c``.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the
        stacked batch split back along dim 0 into chunks of size 1,
        ``mean_b`` is the mean of the transformed ``b`` field, and
        ``even_index_a`` is ``a[:, ::2]`` of the stacked batch.
    """
    # Take the batch size from the input itself rather than from the
    # module-level constant, so the noise always matches the stacked batch.
    batch_size = len(batch_)
    batch_ = ttorch.stack([ttorch.tensor(b) for b in batch_])
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    # Dim 0 is the batch dim after stacking, so the step slice goes on dim 1.
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a
def get_data():
    """Create one random sample as a nested dict of tensors.

    Returns:
        A dict with a uniform random ``a`` of shape ``(T, 8)``, a uniform
        random ``b`` of shape ``(6,)``, and a nested dict ``c`` whose ``d``
        holds one random integer drawn from ``[0, 10)``.
    """
    # Draw the tensors in the same order as the fields appear in the result.
    field_a = torch.rand(size=(T, 8))
    field_b = torch.rand(size=(6,))
    field_d = torch.randint(0, 10, size=(1,))
    nested = dict(d=field_d)
    return dict(a=field_a, b=field_b, c=nested)
if __name__ == "__main__":
    # Build B independent random samples and run both implementations on
    # separate deep copies so neither sees the other's modifications.
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    # Both implementations must agree on the mean of the transformed 'b'.
    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Take the absolute value BEFORE the max: with max-then-abs, a large
    # negative difference would be masked by any non-negative entry.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    # Both results must still contain one entry per sample.
    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
|
The output should look like the following, and all the assertions should pass.
1 |
|
The implementation with the treetensor API is much simpler and clearer.