From ca149e3fce098e2bfd8be14f37f8f2f89a5c85ed Mon Sep 17 00:00:00 2001 From: HansBug Date: Tue, 28 Sep 2021 15:43:07 +0800 Subject: [PATCH] dev(hansbug): upgrade max, min, sum --- test/torch/funcs/test_reduction.py | 48 ++++++- test/torch/tensor/test_reduction.py | 51 ++++++-- treetensor/torch/base/torch.py | 12 +- treetensor/torch/funcs/operation.py | 6 +- treetensor/torch/funcs/reduction.py | 186 ++++++++++++++++++---------- treetensor/torch/tensor.py | 61 ++++++--- 6 files changed, 268 insertions(+), 96 deletions(-) diff --git a/test/torch/funcs/test_reduction.py b/test/torch/funcs/test_reduction.py index 7e341c896..943e9bd87 100644 --- a/test/torch/funcs/test_reduction.py +++ b/test/torch/funcs/test_reduction.py @@ -145,10 +145,24 @@ class TestTorchFuncsReduction: assert isinstance(t1, torch.Tensor) assert t1 == torch.tensor(1.0) - assert ttorch.isclose(ttorch.min(ttorch.tensor({ + tt0 = ttorch.tensor({ 'a': [1.0, 2.0, 1.5], 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, - })), ttorch.tensor(0.9), atol=1e-4) + }) + assert ttorch.isclose(ttorch.min(tt0), ttorch.tensor(0.9), atol=1e-4).all() + + tt1 = ttorch.min(tt0, reduce=False) + assert ttorch.isclose(tt1, ttorch.tensor({ + 'a': 1.0, 'b': 0.9, + }), atol=1e-4).all() + + tt2_a, tt2_b = ttorch.min(tt0, dim=0) + assert ttorch.isclose(tt2_a, ttorch.tensor({ + 'a': 1.0, 'b': [1.3, 0.9], + }), atol=1e-4).all() + assert (tt2_b == ttorch.tensor({ + 'a': 0, 'b': [1, 0], + })).all() @choose_mark() def test_max(self): @@ -156,18 +170,40 @@ class TestTorchFuncsReduction: assert isinstance(t1, torch.Tensor) assert t1 == torch.tensor(2.0) - assert ttorch.isclose(ttorch.max(ttorch.tensor({ + tt0 = ttorch.tensor({ 'a': [1.0, 2.0, 1.5], 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, - })), ttorch.tensor(2.5), atol=1e-4) + }) + assert ttorch.isclose(ttorch.max(tt0), ttorch.tensor(2.5), atol=1e-4) + + tt1 = ttorch.max(tt0, reduce=False) + assert ttorch.isclose(tt1, ttorch.tensor({ + 'a': 2.0, 'b': 2.5, + }), atol=1e-4).all() + + tt2_a, tt2_b = ttorch.max(tt0, dim=0) + assert ttorch.isclose(tt2_a, ttorch.tensor({ + 'a': 2.0, 'b': [1.8, 2.5], + }), atol=1e-4).all() + assert (tt2_b == ttorch.tensor({ + 'a': 1, 'b': [0, 1], + })).all() @choose_mark() def test_sum(self): assert ttorch.sum(torch.tensor([1.0, 2.0, 1.5])) == torch.tensor(4.5) - assert (ttorch.sum(ttorch.tensor({ + + tt0 = ttorch.tensor({ 'a': [1.0, 2.0, 1.5], 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, - })) == torch.tensor(11.0)).all() + }) + assert ttorch.isclose(ttorch.sum(tt0), torch.tensor(11.0), atol=1e-4).all() + assert ttorch.isclose(ttorch.sum(tt0, reduce=False), ttorch.tensor({ + 'a': 4.5, 'b': {'x': 6.5}, + }), atol=1e-4).all() + assert ttorch.isclose(ttorch.sum(tt0, dim=0), ttorch.tensor({ + 'a': 4.5, 'b': {'x': [3.1, 3.4]}, + }), atol=1e-4).all() @choose_mark() def test_mean(self): diff --git a/test/torch/tensor/test_reduction.py b/test/torch/tensor/test_reduction.py index 9867a9f7e..f32a06cdb 100644 --- a/test/torch/tensor/test_reduction.py +++ b/test/torch/tensor/test_reduction.py @@ -86,30 +86,63 @@ class TestTorchTensorReduction: @choose_mark() def test_max(self): - t1 = ttorch.Tensor({ + t0 = ttorch.Tensor({ 'a': [1, 2], 'b': {'x': [[0, 3], [2, -1]]} - }).max() + }) + t1 = t0.max() assert isinstance(t1, torch.Tensor) - assert t1.tolist() == 3 + assert (t1 == torch.tensor(3)).all() + + t2 = t0.max(reduce=False) + assert (t2 == ttorch.tensor({'a': 2, 'b': {'x': 3}})).all() + + t3_a, t3_b = t0.max(dim=0) + assert (t3_a == ttorch.tensor({ + 'a': 2, 'b': {'x': [2, 3]}, + })).all() + assert (t3_b == ttorch.tensor({ + 'a': 1, 'b': {'x': [1, 0]}, + })).all() @choose_mark() def test_min(self): - t1 = ttorch.Tensor({ + t0 = ttorch.Tensor({ 'a': [1, 2], 'b': {'x': [[0, 3], [2, -1]]} - }).min() + }) + t1 = t0.min() assert isinstance(t1, torch.Tensor) - assert t1.tolist() == -1 + assert (t1 == torch.tensor(-1)).all() + + t2 = t0.min(reduce=False) + assert (t2 == ttorch.tensor({'a': 1, 'b': {'x': -1}})).all() + + t3_a, t3_b = t0.min(dim=0) + assert (t3_a == ttorch.tensor({ + 'a': 1, 'b': {'x': [0, -1]}, + })).all() + assert (t3_b == ttorch.tensor({ + 'a': 0, 'b': {'x': [0, 1]}, + })).all() @choose_mark() def test_sum(self): - t1 = ttorch.Tensor({ + t0 = ttorch.Tensor({ 'a': [1, 2], 'b': {'x': [[0, 3], [2, -1]]} - }).sum() + }) + t1 = t0.sum() assert isinstance(t1, torch.Tensor) - assert t1.tolist() == 7 + assert (t1 == ttorch.tensor(7)).all() + + t2 = t0.sum(reduce=False) + assert (t2 == ttorch.tensor({'a': 3, 'b': {'x': 4}})).all() + + t3 = t0.sum(dim=0) + assert (t3 == ttorch.tensor({ + 'a': 3, 'b': {'x': [2, 2]}, + })).all() @choose_mark() def test_mean(self): diff --git a/treetensor/torch/base/torch.py b/treetensor/torch/base/torch.py index ca17e8b76..c195b233d 100644 --- a/treetensor/torch/base/torch.py +++ b/treetensor/torch/base/torch.py @@ -11,5 +11,13 @@ class Torch(BaseTreeStruct): pass -def auto_torch(value, cls: Type[Torch]): - return typetrans(value, cls) if isinstance(value, TreeValue) else value +# noinspection PyArgumentList +def auto_torch(v, cls: Type[Torch]): + if isinstance(v, TreeValue): + return typetrans(v, cls) + elif isinstance(v, (tuple, list, set)): + return type(v)((auto_torch(item, cls) for item in v)) + elif isinstance(v, dict): + return type(v)({key: auto_torch(value, cls) for key, value in v.items()}) + else: + return v diff --git a/treetensor/torch/funcs/operation.py b/treetensor/torch/funcs/operation.py index 542b687ae..6aef1fd41 100644 --- a/treetensor/torch/funcs/operation.py +++ b/treetensor/torch/funcs/operation.py @@ -117,7 +117,8 @@ def cat(tensors, *args, **kwargs): # noinspection PyShadowingNames @doc_from_base() -@post_process(lambda r: tuple(map(auto_tensor, r))) +@post_process(lambda r: tuple(r)) +@post_process(auto_tensor) @func_treelize(return_type=TreeValue, rise=dict(template=[None])) @post_process(lambda r: list(r)) def split(tensor, split_size_or_sections, *args, **kwargs): @@ -207,7 +208,8 @@ def split(tensor, split_size_or_sections, *args, **kwargs): # noinspection PyShadowingBuiltins @doc_from_base() -@post_process(lambda r: tuple(map(auto_tensor, r))) +@post_process(lambda r: tuple(r)) +@post_process(auto_tensor) @func_treelize(return_type=TreeValue, rise=dict(template=[None])) @post_process(lambda r: list(r)) def chunk(input, chunks, *args, **kwargs): diff --git a/treetensor/torch/funcs/reduction.py b/treetensor/torch/funcs/reduction.py index badd65c6b..857ba6802 100644 --- a/treetensor/torch/funcs/reduction.py +++ b/treetensor/torch/funcs/reduction.py @@ -1,6 +1,8 @@ import torch +from treevalue import TreeValue +from treevalue.utils import post_process -from .base import doc_from_base, func_treelize +from .base import doc_from_base, func_treelize, auto_tensor from ..base import rmreduce, post_reduce, auto_reduce from ...common import Object @@ -93,29 +95,39 @@ def any(input, *args, reduce=None, **kwargs): >>> ttorch.any(ttorch.tensor({'a': [False, False], 'b': {'x': [False, False]}})) tensor(False) - .. note:: - - In this ``any`` function, the return value should be a tensor with single boolean value. - - If what you need is a tree of boolean tensors, you should do like this + >>> ttorch.any(ttorch.tensor({'a': [True, False], 'b': {'x': [False, False]}}), reduce=False) + + ├── a --> tensor(True) + └── b --> + └── x --> tensor(False) - >>> ttorch.tensor({ - >>> 'a': [True, False], - >>> 'b': {'x': [False, False]}, - >>> }).map(lambda x: torch.any(x)) - - ├── a --> tensor(True) - └── b --> - └── x --> tensor(False) + >>> ttorch.any(ttorch.tensor({'a': [False, False], 'b': {'x': [False, False]}}), dim=0) + + ├── a --> tensor(False) + └── b --> + └── x --> tensor(False) """ pass # pragma: no cover +# noinspection PyShadowingBuiltins,PyUnusedLocal +@post_reduce(torch.min) +@func_treelize(return_type=Object) +def _min_r(input, *args, **kwargs): + return input + + # noinspection PyShadowingBuiltins +@post_process(auto_tensor) +@func_treelize(return_type=TreeValue, rise=True) +def _min_nr(input, *args, **kwargs): + return torch.min(input, *args, **kwargs) + + +# noinspection PyShadowingBuiltins,PyUnusedLocal @doc_from_base() -@rmreduce(torch.min) -@func_treelize(return_type=Object) -def min(input, *args, **kwargs): +@auto_reduce(_min_r, _min_nr) +def min(input, *args, reduce=None, **kwargs): """ In ``treetensor``, you can get the ``min`` result of a whole tree with this function. @@ -132,29 +144,52 @@ def min(input, *args, **kwargs): ... })) tensor(0.9000) - .. note:: + >>> ttorch.min(ttorch.tensor({ + ... 'a': [1.0, 2.0, 1.5], + ... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, + ... }), reduce=False) + + ├── a --> tensor(1.) + └── b --> + └── x --> tensor(0.9000) - In this ``min`` function, the return value should be a tensor with single value. + >>> ttorch.min(ttorch.tensor({ + ... 'a': [1.0, 2.0, 1.5], + ... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, + ... }), dim=0) + torch.return_types.min( + values= + ├── a --> tensor(1.) + └── b --> + └── x --> tensor([1.3000, 0.9000]) + , + indices= + ├── a --> tensor(0) + └── b --> + └── x --> tensor([1, 0]) + ) + """ + pass # pragma: no cover - If what you need is a tree of tensors, you should do like this - >>> ttorch.tensor({ - ... 'a': [1.0, 2.0, 1.5], - ... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, - ... }).map(lambda x: torch.min(x)) - - ├── a --> tensor(1.) - └── b --> - └── x --> tensor(0.9000) - """ - return torch.min(input, *args, **kwargs) +# noinspection PyShadowingBuiltins,PyUnusedLocal +@post_reduce(torch.max) +@func_treelize(return_type=Object) +def _max_r(input, *args, **kwargs): + return input # noinspection PyShadowingBuiltins +@post_process(auto_tensor) +@func_treelize(return_type=TreeValue, rise=True) +def _max_nr(input, *args, **kwargs): + return torch.max(input, *args, **kwargs) + + +# noinspection PyShadowingBuiltins,PyUnusedLocal @doc_from_base() -@rmreduce(torch.max) -@func_treelize(return_type=Object) -def max(input, *args, **kwargs): +@auto_reduce(_max_r, _max_nr) +def max(input, *args, reduce=None, **kwargs): """ In ``treetensor``, you can get the ``max`` result of a whole tree with this function. @@ -171,29 +206,51 @@ def max(input, *args, **kwargs): ... })) tensor(2.5000) - .. note:: + >>> ttorch.max(ttorch.tensor({ + ... 'a': [1.0, 2.0, 1.5], + ... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, + ... }), reduce=False) + + ├── a --> tensor(2.) + └── b --> + └── x --> tensor(2.5000) - In this ``max`` function, the return value should be a tensor with single value. + >>> ttorch.max(ttorch.tensor({ + ... 'a': [1.0, 2.0, 1.5], + ... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, + ... }), dim=0) + torch.return_types.max( + values= + ├── a --> tensor(2.) + └── b --> + └── x --> tensor([1.8000, 2.5000]) + , + indices= + ├── a --> tensor(1) + └── b --> + └── x --> tensor([0, 1]) + ) + """ + pass # pragma: no cover - If what you need is a tree of tensors, you should do like this - >>> ttorch.tensor({ - ... 'a': [1.0, 2.0, 1.5], - ... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, - ... }).map(lambda x: torch.max(x)) - - ├── a --> tensor(2.) - └── b --> - └── x --> tensor(2.5000) - """ - return torch.max(input, *args, **kwargs) +# noinspection PyShadowingBuiltins,PyUnusedLocal +@post_reduce(torch.sum) +@func_treelize(return_type=Object) +def _sum_r(input, *args, **kwargs): + return input # noinspection PyShadowingBuiltins +@func_treelize() +def _sum_nr(input, *args, **kwargs): + return torch.sum(input, *args, **kwargs) + + +# noinspection PyShadowingBuiltins,PyUnusedLocal @doc_from_base() -@rmreduce(torch.sum) -@func_treelize(return_type=Object) -def sum(input, *args, **kwargs): +@auto_reduce(_sum_r, _sum_nr) +def sum(input, *args, reduce=None, **kwargs): """ In ``treetensor``, you can get the ``sum`` result of a whole tree with this function. @@ -210,22 +267,25 @@ def sum(input, *args, **kwargs): ... })) tensor(11.) - .. note:: - - In this ``sum`` function, the return value should be a tensor with single value. - - If what you need is a tree of tensors, you should do like this + >>> ttorch.sum(ttorch.tensor({ + ... 'a': [1.0, 2.0, 1.5], + ... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, + ... }), reduce=False) + + ├── a --> tensor(4.5000) + └── b --> + └── x --> tensor(6.5000) - >>> ttorch.tensor({ - ... 'a': [1.0, 2.0, 1.5], - ... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, - ... }).map(lambda x: torch.sum(x)) - - ├── a --> tensor(4.5000) - └── b --> - └── x --> tensor(6.5000) + >>> ttorch.sum(ttorch.tensor({ + ... 'a': [1.0, 2.0, 1.5], + ... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, + ... }), dim=0) + + ├── a --> tensor(4.5000) + └── b --> + └── x --> tensor([3.1000, 3.4000]) """ - return torch.sum(input, *args, **kwargs) + pass # pragma: no cover # noinspection PyShadowingBuiltins,PyUnusedLocal diff --git a/treetensor/torch/tensor.py b/treetensor/torch/tensor.py index 51191c71e..19fe43ebf 100644 --- a/treetensor/torch/tensor.py +++ b/treetensor/torch/tensor.py @@ -224,32 +224,65 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ pass # pragma: no cover - @doc_from_base() - @rmreduce(torch.max) + # noinspection PyShadowingBuiltins,PyUnusedLocal + @post_reduce(torch.max) @method_treelize(return_type=Object) - def max(self: torch.Tensor, *args, **kwargs): + def __max_r(self, *args, **kwargs): + return self + + # noinspection PyShadowingBuiltins + @post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r)) + @method_treelize(return_type=TreeValue, rise=True) + def __max_nr(self, *args, **kwargs): + return torch.max(self, *args, **kwargs) + + @doc_from_base() + @auto_reduce(__max_r, __max_nr) + def max(self: torch.Tensor, *args, reduce=None, **kwargs): """ See :func:`treetensor.torch.max` """ - return self.max(*args, **kwargs) + pass # pragma: no cover - @doc_from_base() - @rmreduce(torch.min) + # noinspection PyShadowingBuiltins,PyUnusedLocal + @post_reduce(torch.min) @method_treelize(return_type=Object) - def min(self: torch.Tensor, *args, **kwargs): + def __min_r(self, *args, **kwargs): + return self + + # noinspection PyShadowingBuiltins + @post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r)) + @method_treelize(return_type=TreeValue, rise=True) + def __min_nr(self, *args, **kwargs): + return torch.min(self, *args, **kwargs) + + @doc_from_base() + @auto_reduce(__min_r, __min_nr) + def min(self: torch.Tensor, *args, reduce=None, **kwargs): """ See :func:`treetensor.torch.min` """ - return self.min(*args, **kwargs) + pass # pragma: no cover - @doc_from_base() - @rmreduce(torch.sum) + # noinspection PyShadowingBuiltins,PyUnusedLocal + @post_reduce(torch.sum) @method_treelize(return_type=Object) - def sum(self: torch.Tensor, *args, **kwargs): + def __sum_r(self, *args, **kwargs): + return self + + # noinspection PyShadowingBuiltins + @post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r)) + @method_treelize(return_type=TreeValue, rise=True) + def __sum_nr(self, *args, **kwargs): + return torch.sum(self, *args, **kwargs) + + @doc_from_base() + @auto_reduce(__sum_r, __sum_nr) + def sum(self: torch.Tensor, *args, reduce=None, **kwargs): """ See :func:`treetensor.torch.sum` """ - return self.sum(*args, **kwargs) + pass # pragma: no cover @method_treelize() def __eq__(self, other): @@ -681,7 +714,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): return self.log10_(*args, **kwargs) @doc_from_base() - @post_process(lambda r: tuple(map(replaceable_partial(auto_torch, cls=Tensor), r))) + @post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r)) @method_treelize(return_type=TreeValue, rise=dict(template=[None])) @post_process(lambda r: list(r)) def split(self, split_size, *args, **kwargs): @@ -691,7 +724,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): return self.split(split_size, *args, **kwargs) @doc_from_base() - @post_process(lambda r: tuple(map(replaceable_partial(auto_torch, cls=Tensor), r))) + @post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r)) @method_treelize(return_type=TreeValue, rise=dict(template=[None])) @post_process(lambda r: list(r)) def chunk(self, chunks, *args, **kwargs): -- GitLab