From 0150535ccaf49d31fcfdd292b0b8b8c3c9bfd939 Mon Sep 17 00:00:00 2001 From: HansBug Date: Tue, 28 Sep 2021 14:45:20 +0800 Subject: [PATCH] dev(hansbug): upgrade all and any --- test/torch/funcs/test_reduction.py | 62 ++++++++++++++++++++++++++--- test/torch/tensor/test_reduction.py | 42 +++++++++++++++++++ treetensor/torch/funcs/construct.py | 7 +++- treetensor/torch/funcs/reduction.py | 62 +++++++++++++++++++---------- treetensor/torch/tensor.py | 40 ++++++++++++++----- 5 files changed, 175 insertions(+), 38 deletions(-) diff --git a/test/torch/funcs/test_reduction.py b/test/torch/funcs/test_reduction.py index c06e40004..7e341c896 100644 --- a/test/torch/funcs/test_reduction.py +++ b/test/torch/funcs/test_reduction.py @@ -27,7 +27,7 @@ class TestTorchFuncsReduction: r4 = ttorch.all({ 'a': torch.tensor([True, True, True]), 'b': torch.tensor([True, True, True]), - }).all() + }) assert torch.is_tensor(r4) assert r4 == torch.tensor(True) assert r4 @@ -35,7 +35,7 @@ class TestTorchFuncsReduction: r5 = ttorch.all({ 'a': torch.tensor([True, True, True]), 'b': torch.tensor([True, True, False]), - }).all() + }) assert torch.is_tensor(r5) assert r5 == torch.tensor(False) assert not r5 @@ -43,11 +43,36 @@ class TestTorchFuncsReduction: r6 = ttorch.all({ 'a': torch.tensor([False, False, False]), 'b': torch.tensor([False, False, False]), - }).all() + }) assert torch.is_tensor(r6) assert r6 == torch.tensor(False) assert not r6 + r7 = ttorch.all(ttorch.tensor({ + 'a': torch.tensor([True, True, True]), + 'b': torch.tensor([True, True, False]), + }), reduce=False) + assert (r7 == ttorch.tensor({ + 'a': True, 'b': False + })).all() + + r8 = ttorch.all(ttorch.tensor({ + 'a': torch.tensor([True, True, True]), + 'b': torch.tensor([True, True, False]), + }), dim=0) + assert (r8 == ttorch.tensor({ + 'a': True, 'b': False + })).all() + + with pytest.warns(UserWarning): + r9 = ttorch.all(ttorch.tensor({ + 'a': torch.tensor([True, True, True]), + 'b': torch.tensor([True, True, False]), + }), dim=0, reduce=True) + assert (r9 == ttorch.tensor({ + 'a': True, 'b': False + })).all() + @choose_mark() def test_any(self): r1 = ttorch.any(torch.tensor([True, True, True])) @@ -68,7 +93,7 @@ class TestTorchFuncsReduction: r4 = ttorch.any({ 'a': torch.tensor([True, True, True]), 'b': torch.tensor([True, True, True]), - }).all() + }) assert torch.is_tensor(r4) assert r4 == torch.tensor(True) assert r4 @@ -76,7 +101,7 @@ class TestTorchFuncsReduction: r5 = ttorch.any({ 'a': torch.tensor([True, True, True]), 'b': torch.tensor([True, True, False]), - }).all() + }) assert torch.is_tensor(r5) assert r5 == torch.tensor(True) assert r5 @@ -84,11 +109,36 @@ class TestTorchFuncsReduction: r6 = ttorch.any({ 'a': torch.tensor([False, False, False]), 'b': torch.tensor([False, False, False]), - }).all() + }) assert torch.is_tensor(r6) assert r6 == torch.tensor(False) assert not r6 + r7 = ttorch.any(ttorch.tensor({ + 'a': torch.tensor([True, True, False]), + 'b': torch.tensor([False, False, False]), + }), reduce=False) + assert (r7 == ttorch.tensor({ + 'a': True, 'b': False + })).all() + + r8 = ttorch.any(ttorch.tensor({ + 'a': torch.tensor([True, True, False]), + 'b': torch.tensor([False, False, False]), + }), dim=0) + assert (r8 == ttorch.tensor({ + 'a': True, 'b': False + })).all() + + with pytest.warns(UserWarning): + r9 = ttorch.any(ttorch.tensor({ + 'a': torch.tensor([True, True, False]), + 'b': torch.tensor([False, False, False]), + }), dim=0, reduce=True) + assert (r9 == ttorch.tensor({ + 'a': True, 'b': False + })).all() + @choose_mark() def test_min(self): t1 = ttorch.min(torch.tensor([1.0, 2.0, 1.5])) diff --git a/test/torch/tensor/test_reduction.py b/test/torch/tensor/test_reduction.py index 4dae46317..9867a9f7e 100644 --- a/test/torch/tensor/test_reduction.py +++ b/test/torch/tensor/test_reduction.py @@ -1,3 +1,4 @@ +import pytest import torch import treetensor.torch as ttorch @@ -24,6 +25,22 @@ class TestTorchTensorReduction: assert t2.dtype == torch.bool assert not t2 + t3 = ttorch.tensor({ + 'a': [True, False], + 'b': {'x': [[True, True, ], [True, True, ]]} + }).all(reduce=False) + assert (t3 == ttorch.tensor({ + 'a': False, 'b': {'x': True}, + })).all() + + t4 = ttorch.tensor({ + 'a': [True, False], + 'b': {'x': [[True, True, ], [True, True, ]]} + }).all(dim=0) + assert (t4 == ttorch.tensor({ + 'a': False, 'b': {'x': [True, True]}, + })).all() + @choose_mark() def test_any(self): t1 = ttorch.Tensor({ @@ -42,6 +59,31 @@ class TestTorchTensorReduction: assert t2.dtype == torch.bool assert not t2 + t3 = ttorch.Tensor({ + 'a': [True, False], + 'b': {'x': [[False, False, ], [False, False, ]]} + }).any(reduce=False) + assert (t3 == ttorch.tensor({ + 'a': True, 'b': False, + })) + + t4 = ttorch.Tensor({ + 'a': [True, False], + 'b': {'x': [[False, False, ], [False, False, ]]} + }).any(dim=0) + assert (t4 == ttorch.tensor({ + 'a': True, 'b': [False, False], + })) + + with pytest.warns(UserWarning): + t5 = ttorch.Tensor({ + 'a': [True, False], + 'b': {'x': [[False, False, ], [False, False, ]]} + }).any(dim=0, reduce=True) + assert (t5 == ttorch.tensor({ + 'a': True, 'b': [False, False], + })) + @choose_mark() def test_max(self): t1 = ttorch.Tensor({ diff --git a/treetensor/torch/funcs/construct.py b/treetensor/torch/funcs/construct.py index 54c280987..f8a7d10a9 100644 --- a/treetensor/torch/funcs/construct.py +++ b/treetensor/torch/funcs/construct.py @@ -15,7 +15,7 @@ __all__ = [ @doc_from_base() @func_treelize() -def tensor(*args, **kwargs): +def tensor(data, *args, **kwargs): """ In ``treetensor``, you can create a tree tensor with simple data structure. @@ -36,7 +36,10 @@ def tensor(*args, **kwargs): └── c --> tensor([[ True, False], [False, True]]) """ - return torch.tensor(*args, **kwargs) + if torch.is_tensor(data): + return data + else: + return torch.tensor(data, *args, **kwargs) # noinspection PyShadowingBuiltins diff --git a/treetensor/torch/funcs/reduction.py b/treetensor/torch/funcs/reduction.py index aba5570c5..badd65c6b 100644 --- a/treetensor/torch/funcs/reduction.py +++ b/treetensor/torch/funcs/reduction.py @@ -11,11 +11,23 @@ __all__ = [ ] +# noinspection PyShadowingBuiltins,PyUnusedLocal +@post_reduce(torch.all) +@func_treelize(return_type=Object) +def _all_r(input, *args, **kwargs): + return input + + # noinspection PyShadowingBuiltins +@func_treelize() +def _all_nr(input, *args, **kwargs): + return torch.all(input, *args, **kwargs) + + +# noinspection PyShadowingBuiltins,PyUnusedLocal @doc_from_base() -@rmreduce(torch.all) -@func_treelize(return_type=Object) -def all(input, *args, **kwargs): +@auto_reduce(_all_r, _all_nr) +def all(input, *args, reduce=None, **kwargs): """ In ``treetensor``, you can get the ``all`` result of a whole tree with this function. @@ -32,29 +44,39 @@ def all(input, *args, **kwargs): >>> ttorch.all(ttorch.tensor({'a': [True, True], 'b': {'x': [True, False]}})) tensor(False) - .. note:: - - In this ``all`` function, the return value should be a tensor with single boolean value. + >>> ttorch.all(ttorch.tensor({'a': [True, True], 'b': {'x': [True, False]}}), reduce=False) + + ├── a --> tensor(True) + └── b --> + └── x --> tensor(False) - If what you need is a tree of boolean tensors, you should do like this + >>> ttorch.all(ttorch.tensor({'a': [True, True], 'b': {'x': [True, False]}}), dim=0) + + ├── a --> tensor(True) + └── b --> + └── x --> tensor(False) - >>> ttorch.tensor({ - ... 'a': [True, True], - ... 'b': {'x': [True, False]}, - ... }).map(lambda x: torch.all(x)) - - ├── a --> tensor(True) - └── b --> - └── x --> tensor(False) """ - return torch.all(input, *args, **kwargs) + pass # pragma: no cover + + +# noinspection PyShadowingBuiltins,PyUnusedLocal +@post_reduce(torch.any) +@func_treelize(return_type=Object) +def _any_r(input, *args, **kwargs): + return input # noinspection PyShadowingBuiltins +@func_treelize() +def _any_nr(input, *args, **kwargs): + return torch.any(input, *args, **kwargs) + + +# noinspection PyShadowingBuiltins,PyUnusedLocal @doc_from_base() -@rmreduce(torch.any) -@func_treelize(return_type=Object) -def any(input, *args, **kwargs): +@auto_reduce(_any_r, _any_nr) +def any(input, *args, reduce=None, **kwargs): """ In ``treetensor``, you can get the ``any`` result of a whole tree with this function. @@ -86,7 +108,7 @@ def any(input, *args, **kwargs): └── b --> └── x --> tensor(False) """ - return torch.any(input, *args, **kwargs) + pass # pragma: no cover # noinspection PyShadowingBuiltins diff --git a/treetensor/torch/tensor.py b/treetensor/torch/tensor.py index 63ecd85a3..51191c71e 100644 --- a/treetensor/torch/tensor.py +++ b/treetensor/torch/tensor.py @@ -184,25 +184,45 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.requires_grad_(requires_grad) + # noinspection PyShadowingBuiltins,PyUnusedLocal + @post_reduce(torch.all) + @method_treelize(return_type=Object) + def __all_r(self, *args, **kwargs): + return self + + # noinspection PyShadowingBuiltins + @method_treelize() + def __all_nr(self, *args, **kwargs): + return torch.all(self, *args, **kwargs) + # noinspection PyArgumentList @doc_from_base() - @rmreduce(torch.all) - @method_treelize(return_type=Object) - def all(self: torch.Tensor, *args, **kwargs) -> bool: + @auto_reduce(__all_r, __all_nr) + def all(self: torch.Tensor, *args, reduce=None, **kwargs) -> bool: """ See :func:`treetensor.torch.all` """ - return self.all(*args, **kwargs) + pass # pragma: no cover + + # noinspection PyShadowingBuiltins,PyUnusedLocal + @post_reduce(torch.any) + @method_treelize(return_type=Object) + def __any_r(self, *args, **kwargs): + return self + + # noinspection PyShadowingBuiltins + @method_treelize() + def __any_nr(self, *args, **kwargs): + return torch.any(self, *args, **kwargs) # noinspection PyArgumentList @doc_from_base() - @rmreduce(torch.any) - @method_treelize(return_type=Object) - def any(self: torch.Tensor, *args, **kwargs) -> bool: + @auto_reduce(__any_r, __any_nr) + def any(self: torch.Tensor, *args, reduce=None, **kwargs) -> bool: """ See :func:`treetensor.torch.any` """ - return self.any(*args, **kwargs) + pass # pragma: no cover @doc_from_base() @rmreduce(torch.max) @@ -762,7 +782,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): @doc_from_base() @auto_reduce(__std_r, __std_nr) @method_treelize() - def std(self, *args, **kwargs): + def std(self, *args, reduce=None, **kwargs): """ See :func:`treetensor.torch.std`. """ @@ -781,7 +801,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): @doc_from_base() @auto_reduce(__mean_r, __mean_nr) @method_treelize() - def mean(self, *args, **kwargs): + def mean(self, *args, reduce=None, **kwargs): """ See :func:`treetensor.torch.mean`. """ -- GitLab