Commit 0150535c authored by HansBug 😆

dev(hansbug): upgrade all and any

Parent 980844e9
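A quick sketch of the upgraded behaviour (inferred from the tests and docstrings in this diff, not part of the commit itself): ``ttorch.all`` and ``ttorch.any`` now collapse a whole tree into a single 0-dim tensor by default, return a tree of per-leaf results when called with ``reduce=False``, and keep the tree structure when a ``dim`` argument is given.

    >>> import torch
    >>> import treetensor.torch as ttorch
    >>> t = ttorch.tensor({'a': [True, True], 'b': {'x': [True, False]}})
    >>> ttorch.all(t)                # reduced: one value for the whole tree
    tensor(False)
    >>> ttorch.all(t, reduce=False)  # not reduced: one value per leaf
    <Tensor ...>
    ├── a --> tensor(True)
    └── b --> <Tensor ...>
        └── x --> tensor(False)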
@@ -27,7 +27,7 @@ class TestTorchFuncsReduction:
r4 = ttorch.all({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, True]),
}).all()
})
assert torch.is_tensor(r4)
assert r4 == torch.tensor(True)
assert r4
@@ -35,7 +35,7 @@ class TestTorchFuncsReduction:
r5 = ttorch.all({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, False]),
}).all()
})
assert torch.is_tensor(r5)
assert r5 == torch.tensor(False)
assert not r5
@@ -43,11 +43,36 @@ class TestTorchFuncsReduction:
r6 = ttorch.all({
'a': torch.tensor([False, False, False]),
'b': torch.tensor([False, False, False]),
}).all()
})
assert torch.is_tensor(r6)
assert r6 == torch.tensor(False)
assert not r6
r7 = ttorch.all(ttorch.tensor({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, False]),
}), reduce=False)
assert (r7 == ttorch.tensor({
'a': True, 'b': False
})).all()
r8 = ttorch.all(ttorch.tensor({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, False]),
}), dim=0)
assert (r8 == ttorch.tensor({
'a': True, 'b': False
})).all()
with pytest.warns(UserWarning):
r9 = ttorch.all(ttorch.tensor({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, False]),
}), dim=0, reduce=True)
assert (r9 == ttorch.tensor({
'a': True, 'b': False
})).all()
@choose_mark()
def test_any(self):
r1 = ttorch.any(torch.tensor([True, True, True]))
@@ -68,7 +93,7 @@ class TestTorchFuncsReduction:
r4 = ttorch.any({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, True]),
}).all()
})
assert torch.is_tensor(r4)
assert r4 == torch.tensor(True)
assert r4
@@ -76,7 +101,7 @@ class TestTorchFuncsReduction:
r5 = ttorch.any({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, False]),
}).all()
})
assert torch.is_tensor(r5)
assert r5 == torch.tensor(True)
assert r5
@@ -84,11 +109,36 @@ class TestTorchFuncsReduction:
r6 = ttorch.any({
'a': torch.tensor([False, False, False]),
'b': torch.tensor([False, False, False]),
}).all()
})
assert torch.is_tensor(r6)
assert r6 == torch.tensor(False)
assert not r6
r7 = ttorch.any(ttorch.tensor({
'a': torch.tensor([True, True, False]),
'b': torch.tensor([False, False, False]),
}), reduce=False)
assert (r7 == ttorch.tensor({
'a': True, 'b': False
})).all()
r8 = ttorch.any(ttorch.tensor({
'a': torch.tensor([True, True, False]),
'b': torch.tensor([False, False, False]),
}), dim=0)
assert (r8 == ttorch.tensor({
'a': True, 'b': False
})).all()
with pytest.warns(UserWarning):
r9 = ttorch.any(ttorch.tensor({
'a': torch.tensor([True, True, False]),
'b': torch.tensor([False, False, False]),
}), dim=0, reduce=True)
assert (r9 == ttorch.tensor({
'a': True, 'b': False
})).all()
@choose_mark()
def test_min(self):
t1 = ttorch.min(torch.tensor([1.0, 2.0, 1.5]))
......
import pytest
import torch
import treetensor.torch as ttorch
@@ -24,6 +25,22 @@ class TestTorchTensorReduction:
assert t2.dtype == torch.bool
assert not t2
t3 = ttorch.tensor({
'a': [True, False],
'b': {'x': [[True, True, ], [True, True, ]]}
}).all(reduce=False)
assert (t3 == ttorch.tensor({
'a': False, 'b': {'x': True},
})).all()
t4 = ttorch.tensor({
'a': [True, False],
'b': {'x': [[True, True, ], [True, True, ]]}
}).all(dim=0)
assert (t4 == ttorch.tensor({
'a': False, 'b': {'x': [True, True]},
})).all()
@choose_mark()
def test_any(self):
t1 = ttorch.Tensor({
@@ -42,6 +59,31 @@ class TestTorchTensorReduction:
assert t2.dtype == torch.bool
assert not t2
t3 = ttorch.Tensor({
'a': [True, False],
'b': {'x': [[False, False, ], [False, False, ]]}
}).any(reduce=False)
assert (t3 == ttorch.tensor({
'a': True, 'b': False,
}))
t4 = ttorch.Tensor({
'a': [True, False],
'b': {'x': [[False, False, ], [False, False, ]]}
}).any(dim=0)
assert (t4 == ttorch.tensor({
'a': True, 'b': [False, False],
}))
with pytest.warns(UserWarning):
t5 = ttorch.Tensor({
'a': [True, False],
'b': {'x': [[False, False, ], [False, False, ]]}
}).any(dim=0, reduce=True)
assert (t5 == ttorch.tensor({
'a': True, 'b': [False, False],
}))
@choose_mark()
def test_max(self):
t1 = ttorch.Tensor({
......
@@ -15,7 +15,7 @@ __all__ = [
@doc_from_base()
@func_treelize()
def tensor(*args, **kwargs):
def tensor(data, *args, **kwargs):
"""
In ``treetensor``, you can create a tree tensor with simple data structure.
@@ -36,7 +36,10 @@ def tensor(*args, **kwargs):
└── c --> tensor([[ True, False],
[False, True]])
"""
return torch.tensor(*args, **kwargs)
if torch.is_tensor(data):
return data
else:
return torch.tensor(data, *args, **kwargs)
# noinspection PyShadowingBuiltins
......
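Note on the ``tensor`` change above: a leaf value that is already a ``torch.Tensor`` is now returned untouched instead of being rebuilt with ``torch.tensor``, which copies the data and emits a copy-construct ``UserWarning`` in recent PyTorch versions. A small usage sketch (assumed, not part of the diff):

    >>> leaf = torch.tensor([1.0, 2.0])
    >>> t = ttorch.tensor({'a': leaf, 'b': [3.0, 4.0]})  # 'a' is reused as-is, no copy warning
    >>> t.a
    tensor([1., 2.])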
@@ -11,11 +11,23 @@ __all__ = [
]
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.all)
@func_treelize(return_type=Object)
def _all_r(input, *args, **kwargs):
return input
# noinspection PyShadowingBuiltins
@func_treelize()
def _all_nr(input, *args, **kwargs):
return torch.all(input, *args, **kwargs)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@doc_from_base()
@rmreduce(torch.all)
@func_treelize(return_type=Object)
def all(input, *args, **kwargs):
@auto_reduce(_all_r, _all_nr)
def all(input, *args, reduce=None, **kwargs):
"""
In ``treetensor``, you can get the ``all`` result of a whole tree with this function.
@@ -32,29 +44,39 @@ def all(input, *args, **kwargs):
>>> ttorch.all(ttorch.tensor({'a': [True, True], 'b': {'x': [True, False]}}))
tensor(False)
    .. note::

        In this ``all`` function, the return value should be a tensor with single boolean value.

        If what you need is a tree of boolean tensors, you should do like this

        >>> ttorch.tensor({
        ...     'a': [True, True],
        ...     'b': {'x': [True, False]},
        ... }).map(lambda x: torch.all(x))
        <Tensor 0x7ff363bbc588>
        ├── a --> tensor(True)
        └── b --> <Tensor 0x7ff363bb6438>
            └── x --> tensor(False)
    >>> ttorch.all(ttorch.tensor({'a': [True, True], 'b': {'x': [True, False]}}), reduce=False)
    <Tensor 0x7fcda55652b0>
    ├── a --> tensor(True)
    └── b --> <Tensor 0x7fcda5565208>
        └── x --> tensor(False)

    >>> ttorch.all(ttorch.tensor({'a': [True, True], 'b': {'x': [True, False]}}), dim=0)
    <Tensor 0x7fcda5565780>
    ├── a --> tensor(True)
    └── b --> <Tensor 0x7fcda55656d8>
        └── x --> tensor(False)
"""
return torch.all(input, *args, **kwargs)
pass # pragma: no cover
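Here ``_all_r`` runs ``torch.all`` once over the collected per-leaf values (via ``post_reduce``), while ``_all_nr`` keeps the tree and forwards arguments such as ``dim`` to ``torch.all`` on every leaf. ``auto_reduce`` is the piece that routes between them; the decorator itself is defined elsewhere in this commit, so the following is only a hypothetical sketch of the routing the tests imply (default → reduced, ``reduce=False`` or a ``dim`` argument → non-reduced, ``dim`` combined with ``reduce=True`` → warn and fall back to the non-reduced path):

    import warnings
    from functools import wraps

    def auto_reduce(reduced_fn, non_reduced_fn):  # hypothetical sketch, not the real implementation
        def decorator(placeholder):
            @wraps(placeholder)
            def wrapped(input, *args, reduce=None, **kwargs):
                has_dim = bool(args) or 'dim' in kwargs  # structural args force the per-leaf path
                if has_dim:
                    if reduce:
                        warnings.warn('reduce=True is ignored when dim is given', UserWarning)
                    return non_reduced_fn(input, *args, **kwargs)
                if reduce is False:
                    return non_reduced_fn(input, *args, **kwargs)
                return reduced_fn(input, *args, **kwargs)  # default: collapse the whole tree
            return wrapped
        return decorator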
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.any)
@func_treelize(return_type=Object)
def _any_r(input, *args, **kwargs):
return input
# noinspection PyShadowingBuiltins
@func_treelize()
def _any_nr(input, *args, **kwargs):
return torch.any(input, *args, **kwargs)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@doc_from_base()
@rmreduce(torch.any)
@func_treelize(return_type=Object)
def any(input, *args, **kwargs):
@auto_reduce(_any_r, _any_nr)
def any(input, *args, reduce=None, **kwargs):
"""
In ``treetensor``, you can get the ``any`` result of a whole tree with this function.
@@ -86,7 +108,7 @@ def any(input, *args, **kwargs):
└── b --> <Tensor 0x7ff363bc67f0>
└── x --> tensor(False)
"""
return torch.any(input, *args, **kwargs)
pass # pragma: no cover
# noinspection PyShadowingBuiltins
......
@@ -184,25 +184,45 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.requires_grad_(requires_grad)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.all)
@method_treelize(return_type=Object)
def __all_r(self, *args, **kwargs):
return self
# noinspection PyShadowingBuiltins
@method_treelize()
def __all_nr(self, *args, **kwargs):
return torch.all(self, *args, **kwargs)
# noinspection PyArgumentList
@doc_from_base()
@rmreduce(torch.all)
@method_treelize(return_type=Object)
def all(self: torch.Tensor, *args, **kwargs) -> bool:
@auto_reduce(__all_r, __all_nr)
def all(self: torch.Tensor, *args, reduce=None, **kwargs) -> bool:
"""
See :func:`treetensor.torch.all`
"""
return self.all(*args, **kwargs)
pass # pragma: no cover
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.any)
@method_treelize(return_type=Object)
def __any_r(self, *args, **kwargs):
return self
# noinspection PyShadowingBuiltins
@method_treelize()
def __any_nr(self, *args, **kwargs):
return torch.any(self, *args, **kwargs)
# noinspection PyArgumentList
@doc_from_base()
@rmreduce(torch.any)
@method_treelize(return_type=Object)
def any(self: torch.Tensor, *args, **kwargs) -> bool:
@auto_reduce(__any_r, __any_nr)
def any(self: torch.Tensor, *args, reduce=None, **kwargs) -> bool:
"""
See :func:`treetensor.torch.any`
"""
return self.any(*args, **kwargs)
pass # pragma: no cover
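The method forms behave like the functions above; a short usage sketch mirroring the new tests:

    >>> t = ttorch.tensor({'a': [True, False], 'b': {'x': [[True, True], [True, True]]}})
    >>> t.all()              # collapse the whole tree
    tensor(False)
    >>> t.all(reduce=False)  # keep one result per leaf
    <Tensor ...>
    ├── a --> tensor(False)
    └── b --> <Tensor ...>
        └── x --> tensor(True)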
@doc_from_base()
@rmreduce(torch.max)
@@ -762,7 +782,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
@doc_from_base()
@auto_reduce(__std_r, __std_nr)
@method_treelize()
def std(self, *args, **kwargs):
def std(self, *args, reduce=None, **kwargs):
"""
See :func:`treetensor.torch.std`.
"""
@@ -781,7 +801,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
@doc_from_base()
@auto_reduce(__mean_r, __mean_nr)
@method_treelize()
def mean(self, *args, **kwargs):
def mean(self, *args, reduce=None, **kwargs):
"""
See :func:`treetensor.torch.mean`.
"""
......
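``std`` and ``mean`` now expose the same ``reduce`` keyword explicitly; presumably the default collapses every leaf into one statistic while ``reduce=False`` keeps a per-leaf tree, matching ``all``/``any`` above. An assumed usage sketch (not shown in this diff):

    >>> t = ttorch.tensor({'a': [1.0, 2.0], 'b': {'x': [3.0, 4.0]}})
    >>> t.mean()              # one mean over all leaf values
    tensor(2.5000)
    >>> t.mean(reduce=False)  # a tree of per-leaf means
    <Tensor ...>
    ├── a --> tensor(1.5000)
    └── b --> <Tensor ...>
        └── x --> tensor(3.5000)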