diff --git a/treetensor/torch/funcs.py b/treetensor/torch/funcs.py index de688e73717bfd7b877ff86b6c55ef107d14295f..302fa100d3af3da439c4cdc1cb204a65062d70c4 100644 --- a/treetensor/torch/funcs.py +++ b/treetensor/torch/funcs.py @@ -30,12 +30,54 @@ func_treelize = post_process(post_process(args_mapping( @doc_from(torch.zeros) @func_treelize() def zeros(*args, **kwargs): + """ + In ``treetensor``, you can use ``zeros`` to create a tree of tensors with all zeros. + + Example:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.zeros(2, 3) # the same as torch.zeros(2, 3) + torch.tensor([[0.0, 0.0, 0.0], + [0.0, 0.0, 0.0]]) + + >>> ttorch.zeros({ + >>> 'a': (2, 3), + >>> 'b': (4, ), + >>> }) + ttorch.tensor({ + 'a': torch.tensor([[0.0, 0.0, 0.0], + [0.0, 0.0, 0.0]]), + 'b': torch.tensor([0.0, 0.0, 0.0, 0.0]), + }) + """ return torch.zeros(*args, **kwargs) @doc_from(torch.zeros_like) @func_treelize() def zeros_like(input_, *args, **kwargs): + """ + In ``treetensor``, you can use ``zeros_like`` to create a tree of tensors with all zeros like another tree. + + Example:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.zeros_like(torch.randn(2, 3)) # the same as torch.zeros_like(torch.randn(2, 3)) + torch.tensor([[0.0, 0.0, 0.0], + [0.0, 0.0, 0.0]]) + + >>> ttorch.zeros_like({ + >>> 'a': torch.randn(2, 3), + >>> 'b': torch.randn(4, ), + >>> }) + ttorch.tensor({ + 'a': torch.tensor([[0.0, 0.0, 0.0], + [0.0, 0.0, 0.0]]), + 'b': torch.tensor([0.0, 0.0, 0.0, 0.0]), + }) + """ return torch.zeros_like(input_, *args, **kwargs) @@ -110,21 +152,36 @@ def all(input_, *args, **kwargs): >>> import torch >>> import treetensor.torch as ttorch - >>> all(torch.tensor([True, True])) # the same as torch.all + >>> ttorch.all(torch.tensor([True, True])) # the same as torch.all torch.tensor(True) - >>> all(ttorch.tensor({ + >>> ttorch.all(ttorch.tensor({ >>> 'a': [True, True], >>> 'b': [True, True], >>> })) torch.tensor(True) - >>> all(Tensor({ + >>> ttorch.all(ttorch.tensor({ >>> 'a': [True, True], >>> 'b': [True, False], >>> })) torch.tensor(False) + .. note:: + + In this ``all`` function, the return value should be a tensor with single boolean value. + + If what you need is a tree of boolean tensors, you should do like this + + >>> ttorch.tensor({ + >>> 'a': [True, True], + >>> 'b': [True, False], + >>> }).map(torch.all) + ttorch.tensor({ + 'a': torch.tensor(True), + 'b': torch.tensor(False), + }) + """ return torch.all(input_, *args, **kwargs) @@ -133,6 +190,44 @@ def all(input_, *args, **kwargs): @tireduce(torch.any) @func_treelize(return_type=TreeObject) def any(input_, *args, **kwargs): + """ + In ``treetensor``, you can get the ``any`` result of a whole tree with this function. + + Example:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.any(torch.tensor([False, False])) # the same as torch.any + torch.tensor(False) + + >>> ttorch.any(ttorch.tensor({ + >>> 'a': [True, False], + >>> 'b': [False, False], + >>> })) + torch.tensor(True) + + >>> ttorch.any(ttorch.tensor({ + >>> 'a': [False, False], + >>> 'b': [False, False], + >>> })) + torch.tensor(False) + + .. note:: + + In this ``any`` function, the return value should be a tensor with single boolean value. + + If what you need is a tree of boolean tensors, you should do like this + + >>> ttorch.tensor({ + >>> 'a': [True, False], + >>> 'b': [False, False], + >>> }).map(torch.any) + ttorch.tensor({ + 'a': torch.tensor(True), + 'b': torch.tensor(False), + }) + + """ return torch.any(input_, *args, **kwargs)