提交 7ff4109d 编写于 作者: HansBug's avatar HansBug 😆

doc(hansbug): add documentation for treetensor.torch.Tensor

上级 3422e98c
......@@ -28,6 +28,11 @@ class Tensor(TreeTorch):
@doc_from(torch.Tensor.numpy)
@method_treelize(return_type=ndarray)
def numpy(self: torch.Tensor) -> np.ndarray:
"""
Returns ``self`` tree tensor as a NumPy ``ndarray``.
This tensor and the returned :class:`treetensor.numpy.ndarray` share the same underlying storage.
Changes to self tensor will be reflected in the ``ndarray`` and vice versa.
"""
return self.numpy()
@doc_from(torch.Tensor.tolist)
......@@ -56,28 +61,78 @@ class Tensor(TreeTorch):
@doc_from(torch.Tensor.cpu)
@method_treelize()
def cpu(self: torch.Tensor, *args, **kwargs):
"""
Returns a copy of this tree tensor in CPU memory.
If this tree tensor is already in CPU memory and on the correct device,
then no copy is performed and the original object is returned.
"""
return self.cpu(*args, **kwargs)
@doc_from(torch.Tensor.cuda)
@method_treelize()
def cuda(self: torch.Tensor, *args, **kwargs):
"""
Returns a copy of this tree tensor in CUDA memory.
If this tree tensor is already in CUDA memory and on the correct device,
then no copy is performed and the original object is returned.
"""
return self.cuda(*args, **kwargs)
@doc_from(torch.Tensor.to)
@method_treelize()
def to(self: torch.Tensor, *args, **kwargs):
"""
Turn the original tree tensor to another format.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.tensor({
... 'a': [[1, 11], [2, 22], [3, 33]],
... 'b': {'x': [[4, 5], [6, 7]]},
... }).to(torch.float64)
<Tensor 0x7ff363bb6518>
├── a --> tensor([[ 1., 11.],
│ [ 2., 22.],
│ [ 3., 33.]], dtype=torch.float64)
└── b --> <Tensor 0x7ff363bb6ef0>
└── x --> tensor([[4., 5.],
[6., 7.]], dtype=torch.float64)
"""
return self.to(*args, **kwargs)
@doc_from(torch.Tensor.numel)
@ireduce(sum)
@method_treelize(return_type=TreeObject)
def numel(self: torch.Tensor):
"""
See :func:`treetensor.torch.numel`
"""
return self.numel()
@property
@doc_from(torch.Tensor.shape)
@method_treelize(return_type=Size)
def shape(self: torch.Tensor):
"""
Get the size of the tensors in the tree.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.tensor({
... 'a': [[1, 11], [2, 22], [3, 33]],
... 'b': {'x': [[4, 5], [6, 7]]},
... }).shape
<Size 0x7ff363bbbd68>
├── a --> torch.Size([3, 2])
└── b --> <Size 0x7ff363bbbcf8>
└── x --> torch.Size([2, 2])
"""
return self.shape
@doc_from(torch.Tensor.all)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册