From 82220e54c8dff88ced77d043e066e79de395673d Mon Sep 17 00:00:00 2001 From: HansBug Date: Thu, 30 Sep 2021 23:01:18 +0800 Subject: [PATCH] dev(hansbug): add operation show --- docs/source/best_practice/operation/index.rst | 19 +++++ .../best_practice/operation/operation.demo.py | 81 +++++++++++++++++++ docs/source/index.rst | 1 + 3 files changed, 101 insertions(+) create mode 100644 docs/source/best_practice/operation/index.rst create mode 100644 docs/source/best_practice/operation/operation.demo.py diff --git a/docs/source/best_practice/operation/index.rst b/docs/source/best_practice/operation/index.rst new file mode 100644 index 000000000..ce325a62c --- /dev/null +++ b/docs/source/best_practice/operation/index.rst @@ -0,0 +1,19 @@ +Customized Operations For Different Fields +============================================ + +Here is another example of the custom operations implemented \ +with both native torch API and treetensor API. + +.. literalinclude:: operation.demo.py + :language: python + :linenos: + +The output should be like below, and all the assertions can \ +be passed. + +.. literalinclude:: operation.demo.py.txt + :language: text + :linenos: + +The implement with treetensor API is much simpler and clearer. + diff --git a/docs/source/best_practice/operation/operation.demo.py b/docs/source/best_practice/operation/operation.demo.py new file mode 100644 index 000000000..6455e0dc2 --- /dev/null +++ b/docs/source/best_practice/operation/operation.demo.py @@ -0,0 +1,81 @@ +import copy + +import torch + +import treetensor.torch as ttorch + +T, B = 3, 4 + + +def with_nativetensor(batch_): + mean_b_list = [] + even_index_a_list = [] + for i in range(len(batch_)): + for k, v in batch_[i].items(): + if k == 'a': + v = v.float() + even_index_a_list.append(v[::2]) + elif k == 'b': + v = v.float() + transformed_v = torch.pow(v, 2) + 1.0 + mean_b_list.append(transformed_v.mean()) + elif k == 'c': + for k1, v1 in v.items(): + if k1 == 'd': + v1 = v1.float() + else: + print('ignore keys: {}'.format(k1)) + else: + print('ignore keys: {}'.format(k)) + for i in range(len(batch_)): + for k in batch_[i].keys(): + if k == 'd': + batch_[i][k]['noise'] = torch.randn(size=(3, 4, 5)) + + mean_b = sum(mean_b_list) / len(mean_b_list) + even_index_a = torch.stack(even_index_a_list, dim=0) + return batch_, mean_b, even_index_a + + +def with_treetensor(batch_): + batch_ = [ttorch.tensor(b) for b in batch_] + batch_ = ttorch.stack(batch_) + batch_ = batch_.float() + batch_.b = ttorch.pow(batch_.b, 2) + 1.0 + batch_.c.noise = ttorch.randn(size=(B, 3, 4, 5)) + mean_b = batch_.b.mean() + even_index_a = batch_.a[:, ::2] + batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0) + return batch_, mean_b, even_index_a + + +def get_data(): + return { + 'a': torch.rand(size=(T, 8)), + 'b': torch.rand(size=(6,)), + 'c': { + 'd': torch.randint(0, 10, size=(1,)) + } + } + + +if __name__ == "__main__": + batch = [get_data() for _ in range(B)] + batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch)) + batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch)) + print(batch0) + print('\n\n') + print(batch1) + + assert torch.abs(mean0 - mean1) < 1e-6 + print('mean0 & mean1:', mean0, mean1) + print('\n') + + assert torch.abs((even_index_a0 - even_index_a1).max()) < 1e-6 + print('even_index_a0:', even_index_a0) + print('even_index_a1:', even_index_a1) + + assert len(batch0) == B + assert len(batch1) == B + assert isinstance(batch1[0], ttorch.Tensor) + print(batch1[0].shape) diff --git a/docs/source/index.rst b/docs/source/index.rst index 500a26750..4502618c4 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -20,6 +20,7 @@ module. :caption: Best Practice best_practice/stack/index + best_practice/operation/index .. toctree:: :maxdepth: 2 -- GitLab