提交 8217663f 编写于 作者: HansBug's avatar HansBug 😆

dev, doc, test(hansbug): add squeeze, unsqueeze, where, reshape

上级 b3286b03
......@@ -1655,3 +1655,94 @@ class TestTorchFuncs:
[[18, 21, 17, 12],
[36, 30, 33, 31]]]},
})).all()
@choose_mark()
def test_reshape(self):
t1 = ttorch.reshape(torch.tensor([[1, 2], [3, 4]]), (-1,))
assert isinstance(t1, torch.Tensor)
assert (t1 == ttorch.tensor([1, 2, 3, 4])).all()
t2 = ttorch.reshape(ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': {'x': [[2], [3], [5], [7], [11], [13]]},
}), (-1,))
assert (t2 == ttorch.tensor({
'a': [1, 2, 3, 4],
'b': {'x': [2, 3, 5, 7, 11, 13]},
})).all()
@choose_mark()
def test_squeeze(self):
t1 = torch.randint(100, (2, 1, 2, 1, 2))
assert t1.shape == torch.Size([2, 1, 2, 1, 2])
assert ttorch.squeeze(t1).shape == torch.Size([2, 2, 2])
t2 = ttorch.randint(100, {
'a': (2, 1, 2, 1, 2),
'b': {'x': (2, 1, 1, 3)},
})
assert t2.shape == ttorch.Size({
'a': (2, 1, 2, 1, 2),
'b': {'x': (2, 1, 1, 3)},
})
assert ttorch.squeeze(t2).shape == ttorch.Size({
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
@choose_mark()
def test_unsqueeze(self):
t1 = torch.randint(100, (100,))
assert t1.shape == torch.Size([100])
assert ttorch.unsqueeze(t1, 0).shape == torch.Size([1, 100])
tt1 = ttorch.randint(100, {
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
assert tt1.shape == ttorch.Size({
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
assert ttorch.unsqueeze(tt1, 1).shape == ttorch.Size({
'a': (2, 1, 2, 2),
'b': {'x': (2, 1, 3)},
})
@choose_mark()
def test_where(self):
t1 = ttorch.where(
torch.tensor([[True, False], [False, True]]),
torch.tensor([[2, 8], [16, 4]]),
torch.tensor([[3, 11], [5, 7]]),
)
assert isinstance(t1, torch.Tensor)
assert (t1 == ttorch.tensor([[2, 11],
[5, 4]])).all()
t2 = ttorch.tensor({
'a': [[27, 90, 80],
[12, 59, 5]],
'b': {'x': [[[71, 52, 92, 79],
[48, 4, 13, 96]],
[[72, 89, 44, 62],
[32, 4, 29, 76]],
[[6, 3, 93, 89],
[44, 89, 85, 90]]]},
})
assert (ttorch.where(t2 % 2 == 1, t2,
ttorch.zeros({'a': (2, 3), 'b': {'x': (3, 2, 4)}}, dtype=torch.long)) ==
ttorch.tensor({
'a': [[27, 0, 0],
[0, 59, 5]],
'b': {'x': [[[71, 0, 0, 79],
[0, 0, 13, 0]],
[[0, 89, 0, 0],
[0, 0, 29, 0]],
[[0, 3, 93, 89],
[0, 89, 85, 0]]]},
})).all()
......@@ -1324,3 +1324,139 @@ class TestTorchTensor:
[[73, 81, 11],
[58, 54, 78]]]
})).all()
@choose_mark()
def test_reshape(self):
t1 = torch.tensor([[1, 2], [3, 4]]).reshape((-1,))
assert isinstance(t1, torch.Tensor)
assert (t1 == ttorch.tensor([1, 2, 3, 4])).all()
t2 = ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': {'x': [[2], [3], [5], [7], [11], [13]]},
}).reshape((-1,))
assert (t2 == ttorch.tensor({
'a': [1, 2, 3, 4],
'b': {'x': [2, 3, 5, 7, 11, 13]},
})).all()
@choose_mark()
def test_squeeze(self):
t1 = torch.randint(100, (2, 1, 2, 1, 2))
assert t1.shape == torch.Size([2, 1, 2, 1, 2])
assert t1.squeeze().shape == torch.Size([2, 2, 2])
t2 = ttorch.randint(100, {
'a': (2, 1, 2, 1, 2),
'b': {'x': (2, 1, 1, 3)},
})
assert t2.shape == ttorch.Size({
'a': (2, 1, 2, 1, 2),
'b': {'x': (2, 1, 1, 3)},
})
assert t2.squeeze().shape == ttorch.Size({
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
@choose_mark()
def test_squeeze_(self):
t1 = torch.randint(100, (2, 1, 2, 1, 2))
assert t1.shape == torch.Size([2, 1, 2, 1, 2])
t1r = t1.squeeze_()
assert t1r is t1
assert t1.shape == torch.Size([2, 2, 2])
t2 = ttorch.randint(100, {
'a': (2, 1, 2, 1, 2),
'b': {'x': (2, 1, 1, 3)},
})
assert t2.shape == ttorch.Size({
'a': (2, 1, 2, 1, 2),
'b': {'x': (2, 1, 1, 3)},
})
t2r = t2.squeeze_()
assert t2r is t2
assert t2.shape == ttorch.Size({
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
@choose_mark()
def test_unsqueeze(self):
t1 = torch.randint(100, (100,))
assert t1.shape == torch.Size([100])
assert t1.unsqueeze(0).shape == torch.Size([1, 100])
tt1 = ttorch.randint(100, {
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
assert tt1.shape == ttorch.Size({
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
assert tt1.unsqueeze(1).shape == ttorch.Size({
'a': (2, 1, 2, 2),
'b': {'x': (2, 1, 3)},
})
@choose_mark()
def test_unsqueeze_(self):
t1 = torch.randint(100, (100,))
assert t1.shape == torch.Size([100])
t1r = t1.unsqueeze_(0)
assert t1r is t1
assert t1.shape == torch.Size([1, 100])
tt1 = ttorch.randint(100, {
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
assert tt1.shape == ttorch.Size({
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
tt1r = tt1.unsqueeze_(1)
assert tt1r is tt1
assert tt1.shape == ttorch.Size({
'a': (2, 1, 2, 2),
'b': {'x': (2, 1, 3)},
})
@choose_mark()
def test_where(self):
t1 = torch.tensor([[2, 8], [16, 4]]).where(
torch.tensor([[True, False], [False, True]]),
torch.tensor([[3, 11], [5, 7]]),
)
assert isinstance(t1, torch.Tensor)
assert (t1 == ttorch.tensor([[2, 11],
[5, 4]])).all()
t2 = ttorch.tensor({
'a': [[27, 90, 80],
[12, 59, 5]],
'b': {'x': [[[71, 52, 92, 79],
[48, 4, 13, 96]],
[[72, 89, 44, 62],
[32, 4, 29, 76]],
[[6, 3, 93, 89],
[44, 89, 85, 90]]]},
})
assert (t2.where(t2 % 2 == 1,
ttorch.zeros({'a': (2, 3), 'b': {'x': (3, 2, 4)}}, dtype=torch.long)) ==
ttorch.tensor({
'a': [[27, 0, 0],
[0, 59, 5]],
'b': {'x': [[[71, 0, 0, 79],
[0, 0, 13, 0]],
[[0, 89, 0, 0],
[0, 0, 29, 0]],
[[0, 3, 93, 89],
[0, 89, 85, 0]]]},
})).all()
......@@ -31,7 +31,7 @@ __all__ = [
'add', 'sub', 'mul', 'div', 'pow', 'neg', 'neg_',
'exp', 'exp_', 'exp2', 'exp2_', 'sqrt', 'sqrt_',
'log', 'log_', 'log2', 'log2_', 'log10', 'log10_',
'cat', 'split', 'stack',
'cat', 'split', 'stack', 'reshape', 'where', 'squeeze', 'unsqueeze',
]
func_treelize = post_process(post_process(args_mapping(
......@@ -2443,4 +2443,172 @@ def stack(tensors, *args, **kwargs):
return torch.stack(tensors, *args, **kwargs)
sys.modules[__name__] = module_autoremove(sys.modules[__name__])
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def reshape(input, shape):
"""
Returns a tensor with the same data and number of elements as ``input``,
but with the specified shape. When possible, the returned tensor will be a view of ``input``.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.reshape(torch.tensor([[1, 2], [3, 4]]), (-1, ))
tensor([1, 2, 3, 4])
>>> ttorch.reshape(ttorch.tensor({
... 'a': [[1, 2], [3, 4]],
... 'b': {'x': [[2], [3], [5], [7], [11], [13]]},
... }), (-1, ))
<Tensor 0x7fc9efa3bda0>
├── a --> tensor([1, 2, 3, 4])
└── b --> <Tensor 0x7fc9efa3bcf8>
└── x --> tensor([ 2, 3, 5, 7, 11, 13])
.. note::
If the given ``shape`` is only one tuple, it should make sure that all the tensors
in this tree can be reshaped to the given ``shape``. Or you can give a tree of tuples
to reshape the tensors to different shapes.
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.reshape(ttorch.tensor({
... 'a': [[1, 2], [3, 4]],
... 'b': {'x': [[2], [3], [5], [7], [11], [13]]},
... }), {'a': (4, ), 'b': {'x': (3, 2)}})
<Tensor 0x7fc9efa3bd68>
├── a --> tensor([1, 2, 3, 4])
└── b --> <Tensor 0x7fc9efa3bf28>
└── x --> tensor([[ 2, 3],
[ 5, 7],
[11, 13]])
"""
return torch.reshape(input, shape)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def squeeze(input, *args, **kwargs):
"""
Returns a tensor with all the dimensions of ``input`` of size 1 removed.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> t1 = torch.randint(100, (2, 1, 2, 1, 2))
>>> t1.shape
torch.Size([2, 1, 2, 1, 2])
>>> ttorch.squeeze(t1).shape
torch.Size([2, 2, 2])
>>> tt1 = ttorch.randint(100, {
... 'a': (2, 1, 2, 1, 2),
... 'b': {'x': (2, 1, 1, 3)},
... })
>>> tt1.shape
<Size 0x7fa4c1b05410>
├── a --> torch.Size([2, 1, 2, 1, 2])
└── b --> <Size 0x7fa4c1b05510>
└── x --> torch.Size([2, 1, 1, 3])
>>> ttorch.squeeze(tt1).shape
<Size 0x7fa4c1b9f3d0>
├── a --> torch.Size([2, 2, 2])
└── b --> <Size 0x7fa4c1afe710>
└── x --> torch.Size([2, 3])
"""
return torch.squeeze(input, *args, *kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def unsqueeze(input, dim):
"""
Returns a new tensor with a dimension of size one inserted at the specified position.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> t1 = torch.randint(100, (100, ))
>>> t1.shape
torch.Size([100])
>>> ttorch.unsqueeze(t1, 0).shape
torch.Size([1, 100])
>>> tt1 = ttorch.randint(100, {
... 'a': (2, 2, 2),
... 'b': {'x': (2, 3)},
... })
>>> tt1.shape
<Size 0x7f5d1a5741d0>
├── a --> torch.Size([2, 2, 2])
└── b --> <Size 0x7f5d1a5740b8>
└── x --> torch.Size([2, 3])
>>> ttorch.unsqueeze(tt1, 1).shape
<Size 0x7f5d1a5c98d0>
├── a --> torch.Size([2, 1, 2, 2])
└── b --> <Size 0x7f5d1a5c99b0>
└── x --> torch.Size([2, 1, 3])
"""
return torch.unsqueeze(input, dim)
@doc_from_base()
@func_treelize()
def where(condition, x, y):
"""
Return a tree of tensors of elements selected from either ``x`` or ``y``, depending on ``condition``.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.where(
... torch.tensor([[True, False], [False, True]]),
... torch.tensor([[2, 8], [16, 4]]),
... torch.tensor([[3, 11], [5, 7]]),
... )
tensor([[ 2, 11],
[ 5, 4]])
>>> tt1 = ttorch.randint(1, 99, {'a': (2, 3), 'b': {'x': (3, 2, 4)}})
>>> tt1
<Tensor 0x7f6760ad9908>
├── a --> tensor([[27, 90, 80],
│ [12, 59, 5]])
└── b --> <Tensor 0x7f6760ad9860>
└── x --> tensor([[[71, 52, 92, 79],
[48, 4, 13, 96]],
[[72, 89, 44, 62],
[32, 4, 29, 76]],
[[ 6, 3, 93, 89],
[44, 89, 85, 90]]])
>>> ttorch.where(tt1 % 2 == 1, tt1, 0)
<Tensor 0x7f6760ad9d30>
├── a --> tensor([[27, 0, 0],
│ [ 0, 59, 5]])
└── b --> <Tensor 0x7f6760ad9f98>
└── x --> tensor([[[71, 0, 0, 79],
[ 0, 0, 13, 0]],
[[ 0, 89, 0, 0],
[ 0, 0, 29, 0]],
[[ 0, 3, 93, 89],
[ 0, 89, 85, 0]]])
"""
return torch.where(condition, x, y)
_current_module = sys.modules[__name__]
_current_module = module_autoremove(_current_module)
sys.modules[__name__] = _current_module
......@@ -661,3 +661,55 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
See :func:`treetensor.torch.split`.
"""
return self.split(split_size, *args, **kwargs)
@doc_from_base()
@method_treelize()
def reshape(self, *args, **kwargs):
"""
See :func:`treetensor.torch.reshape`.
"""
return self.reshape(*args, **kwargs)
@doc_from_base()
@method_treelize()
def squeeze(self, *args, **kwargs):
"""
See :func:`treetensor.torch.squeeze`.
"""
return self.squeeze(*args, **kwargs)
@doc_from_base()
@return_self
@method_treelize()
def squeeze_(self, *args, **kwargs):
"""
In-place version of :meth:`Tensor.squeeze'.
"""
return self.squeeze_(*args, **kwargs)
@doc_from_base()
@method_treelize()
def unsqueeze(self, dim):
"""
See :func:`treetensor.torch.unsqueeze`.
"""
return self.unsqueeze(dim)
@doc_from_base()
@return_self
@method_treelize()
def unsqueeze_(self, dim):
"""
In-place version of :meth:`Tensor.unsqueeze'.
"""
return self.unsqueeze_(dim)
@doc_from_base()
@method_treelize()
def where(self, condition, y, *args, **kwargs):
"""
``self.where(condition, y)`` is equivalent to
``treetensor.torch.where(condition, self, y)``.
See :func:`treetensor.torch.where`.
"""
return self.where(condition, y, *args, **kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册