提交 ca149e3f 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): upgrade max, min, sum

上级 0150535c
......@@ -145,10 +145,24 @@ class TestTorchFuncsReduction:
assert isinstance(t1, torch.Tensor)
assert t1 == torch.tensor(1.0)
assert ttorch.isclose(ttorch.min(ttorch.tensor({
tt0 = ttorch.tensor({
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
})), ttorch.tensor(0.9), atol=1e-4)
})
assert ttorch.isclose(ttorch.min(tt0), ttorch.tensor(0.9), atol=1e-4).all()
tt1 = ttorch.min(tt0, reduce=False)
assert ttorch.isclose(tt1, ttorch.tensor({
'a': 1.0, 'b': 0.9,
}), atol=1e-4).all()
tt2_a, tt2_b = ttorch.min(tt0, dim=0)
assert ttorch.isclose(tt2_a, ttorch.tensor({
'a': 1.0, 'b': [1.3, 0.9],
}), atol=1e-4).all()
assert (tt2_b == ttorch.tensor({
'a': 0, 'b': [1, 0],
})).all()
@choose_mark()
def test_max(self):
......@@ -156,18 +170,40 @@ class TestTorchFuncsReduction:
assert isinstance(t1, torch.Tensor)
assert t1 == torch.tensor(2.0)
assert ttorch.isclose(ttorch.max(ttorch.tensor({
tt0 = ttorch.tensor({
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
})), ttorch.tensor(2.5), atol=1e-4)
})
assert ttorch.isclose(ttorch.max(tt0), ttorch.tensor(2.5), atol=1e-4)
tt1 = ttorch.max(tt0, reduce=False)
assert ttorch.isclose(tt1, ttorch.tensor({
'a': 2.0, 'b': 2.5,
}), atol=1e-4).all()
tt2_a, tt2_b = ttorch.max(tt0, dim=0)
assert ttorch.isclose(tt2_a, ttorch.tensor({
'a': 2.0, 'b': [1.8, 2.5],
}), atol=1e-4).all()
assert (tt2_b == ttorch.tensor({
'a': 1, 'b': [0, 1],
})).all()
@choose_mark()
def test_sum(self):
assert ttorch.sum(torch.tensor([1.0, 2.0, 1.5])) == torch.tensor(4.5)
assert (ttorch.sum(ttorch.tensor({
tt0 = ttorch.tensor({
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
})) == torch.tensor(11.0)).all()
})
assert ttorch.isclose(ttorch.sum(tt0), torch.tensor(11.0), atol=1e-4).all()
assert ttorch.isclose(ttorch.sum(tt0, reduce=False), ttorch.tensor({
'a': 4.5, 'b': {'x': 6.5},
}), atol=1e-4).all()
assert ttorch.isclose(ttorch.sum(tt0, dim=0), ttorch.tensor({
'a': 4.5, 'b': {'x': [3.1, 3.4]},
}), atol=1e-4).all()
@choose_mark()
def test_mean(self):
......
......@@ -86,30 +86,63 @@ class TestTorchTensorReduction:
@choose_mark()
def test_max(self):
t1 = ttorch.Tensor({
t0 = ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}).max()
})
t1 = t0.max()
assert isinstance(t1, torch.Tensor)
assert t1.tolist() == 3
assert (t1 == torch.tensor(3)).all()
t2 = t0.max(reduce=False)
assert (t2 == ttorch.tensor({'a': 2, 'b': {'x': 3}})).all()
t3_a, t3_b = t0.max(dim=0)
assert (t3_a == ttorch.tensor({
'a': 2, 'b': {'x': [2, 3]},
})).all()
assert (t3_b == ttorch.tensor({
'a': 1, 'b': {'x': [1, 0]},
})).all()
@choose_mark()
def test_min(self):
t1 = ttorch.Tensor({
t0 = ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}).min()
})
t1 = t0.min()
assert isinstance(t1, torch.Tensor)
assert t1.tolist() == -1
assert (t1 == torch.tensor(-1)).all()
t2 = t0.min(reduce=False)
assert (t2 == ttorch.tensor({'a': 1, 'b': {'x': -1}})).all()
t3_a, t3_b = t0.min(dim=0)
assert (t3_a == ttorch.tensor({
'a': 1, 'b': {'x': [0, -1]},
})).all()
assert (t3_b == ttorch.tensor({
'a': 0, 'b': {'x': [0, 1]},
})).all()
@choose_mark()
def test_sum(self):
t1 = ttorch.Tensor({
t0 = ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}).sum()
})
t1 = t0.sum()
assert isinstance(t1, torch.Tensor)
assert t1.tolist() == 7
assert (t1 == ttorch.tensor(7)).all()
t2 = t0.sum(reduce=False)
assert (t2 == ttorch.tensor({'a': 3, 'b': {'x': 4}})).all()
t3 = t0.sum(dim=0)
assert (t3 == ttorch.tensor({
'a': 3, 'b': {'x': [2, 2]},
})).all()
@choose_mark()
def test_mean(self):
......
......@@ -11,5 +11,13 @@ class Torch(BaseTreeStruct):
pass
def auto_torch(value, cls: Type[Torch]):
return typetrans(value, cls) if isinstance(value, TreeValue) else value
# noinspection PyArgumentList
def auto_torch(v, cls: Type[Torch]):
if isinstance(v, TreeValue):
return typetrans(v, cls)
elif isinstance(v, (tuple, list, set)):
return type(v)((auto_torch(item, cls) for item in v))
elif isinstance(v, dict):
return type(v)({key: auto_torch(value, cls) for key, value in v.items()})
else:
return v
......@@ -117,7 +117,8 @@ def cat(tensors, *args, **kwargs):
# noinspection PyShadowingNames
@doc_from_base()
@post_process(lambda r: tuple(map(auto_tensor, r)))
@post_process(lambda r: tuple(r))
@post_process(auto_tensor)
@func_treelize(return_type=TreeValue, rise=dict(template=[None]))
@post_process(lambda r: list(r))
def split(tensor, split_size_or_sections, *args, **kwargs):
......@@ -207,7 +208,8 @@ def split(tensor, split_size_or_sections, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from_base()
@post_process(lambda r: tuple(map(auto_tensor, r)))
@post_process(lambda r: tuple(r))
@post_process(auto_tensor)
@func_treelize(return_type=TreeValue, rise=dict(template=[None]))
@post_process(lambda r: list(r))
def chunk(input, chunks, *args, **kwargs):
......
import torch
from treevalue import TreeValue
from treevalue.utils import post_process
from .base import doc_from_base, func_treelize
from .base import doc_from_base, func_treelize, auto_tensor
from ..base import rmreduce, post_reduce, auto_reduce
from ...common import Object
......@@ -93,29 +95,39 @@ def any(input, *args, reduce=None, **kwargs):
>>> ttorch.any(ttorch.tensor({'a': [False, False], 'b': {'x': [False, False]}}))
tensor(False)
.. note::
In this ``any`` function, the return value should be a tensor with single boolean value.
If what you need is a tree of boolean tensors, you should do like this
>>> ttorch.any(ttorch.tensor({'a': [True, False], 'b': {'x': [False, False]}}), reduce=False)
<Tensor 0x7fd45b52d518>
├── a --> tensor(True)
└── b --> <Tensor 0x7fd45b52d470>
└── x --> tensor(False)
>>> ttorch.tensor({
>>> 'a': [True, False],
>>> 'b': {'x': [False, False]},
>>> }).map(lambda x: torch.any(x))
<Tensor 0x7ff363bc6898>
├── a --> tensor(True)
└── b --> <Tensor 0x7ff363bc67f0>
└── x --> tensor(False)
>>> ttorch.any(ttorch.tensor({'a': [False, False], 'b': {'x': [False, False]}}), dim=0)
<Tensor 0x7fd45b534128>
├── a --> tensor(False)
└── b --> <Tensor 0x7fd45b534080>
└── x --> tensor(False)
"""
pass # pragma: no cover
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.min)
@func_treelize(return_type=Object)
def _min_r(input, *args, **kwargs):
return input
# noinspection PyShadowingBuiltins
@post_process(auto_tensor)
@func_treelize(return_type=TreeValue, rise=True)
def _min_nr(input, *args, **kwargs):
return torch.min(input, *args, **kwargs)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@doc_from_base()
@rmreduce(torch.min)
@func_treelize(return_type=Object)
def min(input, *args, **kwargs):
@auto_reduce(_min_r, _min_nr)
def min(input, *args, reduce=None, **kwargs):
"""
In ``treetensor``, you can get the ``min`` result of a whole tree with this function.
......@@ -132,29 +144,52 @@ def min(input, *args, **kwargs):
... }))
tensor(0.9000)
.. note::
>>> ttorch.min(ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }), reduce=False)
<Tensor 0x7fd45b5913c8>
├── a --> tensor(1.)
└── b --> <Tensor 0x7fd45b5912e8>
└── x --> tensor(0.9000)
In this ``min`` function, the return value should be a tensor with single value.
>>> ttorch.min(ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }), dim=0)
torch.return_types.min(
values=<Tensor 0x7fd45b52d2e8>
├── a --> tensor(1.)
└── b --> <Tensor 0x7fd45b52d208>
└── x --> tensor([1.3000, 0.9000])
,
indices=<Tensor 0x7fd45b591cc0>
├── a --> tensor(0)
└── b --> <Tensor 0x7fd45b52d3c8>
└── x --> tensor([1, 0])
)
"""
pass # pragma: no cover
If what you need is a tree of tensors, you should do like this
>>> ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }).map(lambda x: torch.min(x))
<Tensor 0x7ff363bbb2b0>
├── a --> tensor(1.)
└── b --> <Tensor 0x7ff363bbb0b8>
└── x --> tensor(0.9000)
"""
return torch.min(input, *args, **kwargs)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.max)
@func_treelize(return_type=Object)
def _max_r(input, *args, **kwargs):
return input
# noinspection PyShadowingBuiltins
@post_process(auto_tensor)
@func_treelize(return_type=TreeValue, rise=True)
def _max_nr(input, *args, **kwargs):
return torch.max(input, *args, **kwargs)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@doc_from_base()
@rmreduce(torch.max)
@func_treelize(return_type=Object)
def max(input, *args, **kwargs):
@auto_reduce(_max_r, _max_nr)
def max(input, *args, reduce=None, **kwargs):
"""
In ``treetensor``, you can get the ``max`` result of a whole tree with this function.
......@@ -171,29 +206,51 @@ def max(input, *args, **kwargs):
... }))
tensor(2.5000)
.. note::
>>> ttorch.max(ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }), reduce=False)
<Tensor 0x7fd45b52d940>
├── a --> tensor(2.)
└── b --> <Tensor 0x7fd45b52d908>
└── x --> tensor(2.5000)
In this ``max`` function, the return value should be a tensor with single value.
>>> ttorch.max(ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }), dim=0)
torch.return_types.max(
values=<Tensor 0x7fd45b5345f8>
├── a --> tensor(2.)
└── b --> <Tensor 0x7fd45b5345c0>
└── x --> tensor([1.8000, 2.5000])
,
indices=<Tensor 0x7fd45b5346d8>
├── a --> tensor(1)
└── b --> <Tensor 0x7fd45b5346a0>
└── x --> tensor([0, 1])
)
"""
pass # pragma: no cover
If what you need is a tree of tensors, you should do like this
>>> ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }).map(lambda x: torch.max(x))
<Tensor 0x7ff363bc6b00>
├── a --> tensor(2.)
└── b --> <Tensor 0x7ff363bc6c18>
└── x --> tensor(2.5000)
"""
return torch.max(input, *args, **kwargs)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.sum)
@func_treelize(return_type=Object)
def _sum_r(input, *args, **kwargs):
return input
# noinspection PyShadowingBuiltins
@func_treelize()
def _sum_nr(input, *args, **kwargs):
return torch.sum(input, *args, **kwargs)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@doc_from_base()
@rmreduce(torch.sum)
@func_treelize(return_type=Object)
def sum(input, *args, **kwargs):
@auto_reduce(_sum_r, _sum_nr)
def sum(input, *args, reduce=None, **kwargs):
"""
In ``treetensor``, you can get the ``sum`` result of a whole tree with this function.
......@@ -210,22 +267,25 @@ def sum(input, *args, **kwargs):
... }))
tensor(11.)
.. note::
In this ``sum`` function, the return value should be a tensor with single value.
If what you need is a tree of tensors, you should do like this
>>> ttorch.sum(ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }), reduce=False)
<Tensor 0x7fd45b534898>
├── a --> tensor(4.5000)
└── b --> <Tensor 0x7fd45b5344e0>
└── x --> tensor(6.5000)
>>> ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }).map(lambda x: torch.sum(x))
<Tensor 0x7ff363bbbda0>
├── a --> tensor(4.5000)
└── b --> <Tensor 0x7ff363bbbcf8>
└── x --> tensor(6.5000)
>>> ttorch.sum(ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }), dim=0)
<Tensor 0x7f3640703128>
├── a --> tensor(4.5000)
└── b --> <Tensor 0x7f3640703080>
└── x --> tensor([3.1000, 3.4000])
"""
return torch.sum(input, *args, **kwargs)
pass # pragma: no cover
# noinspection PyShadowingBuiltins,PyUnusedLocal
......
......@@ -224,32 +224,65 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
pass # pragma: no cover
@doc_from_base()
@rmreduce(torch.max)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.max)
@method_treelize(return_type=Object)
def max(self: torch.Tensor, *args, **kwargs):
def __max_r(self, *args, **kwargs):
return self
# noinspection PyShadowingBuiltins
@post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r))
@method_treelize(return_type=TreeValue, rise=True)
def __max_nr(self, *args, **kwargs):
return torch.max(self, *args, **kwargs)
@doc_from_base()
@auto_reduce(__max_r, __max_nr)
def max(self: torch.Tensor, *args, reduce=None, **kwargs):
"""
See :func:`treetensor.torch.max`
"""
return self.max(*args, **kwargs)
pass # pragma: no cover
@doc_from_base()
@rmreduce(torch.min)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.min)
@method_treelize(return_type=Object)
def min(self: torch.Tensor, *args, **kwargs):
def __min_r(self, *args, **kwargs):
return self
# noinspection PyShadowingBuiltins
@post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r))
@method_treelize(return_type=TreeValue, rise=True)
def __min_nr(self, *args, **kwargs):
return torch.min(self, *args, **kwargs)
@doc_from_base()
@auto_reduce(__min_r, __min_nr)
def min(self: torch.Tensor, *args, reduce=None, **kwargs):
"""
See :func:`treetensor.torch.min`
"""
return self.min(*args, **kwargs)
pass # pragma: no cover
@doc_from_base()
@rmreduce(torch.sum)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.sum)
@method_treelize(return_type=Object)
def sum(self: torch.Tensor, *args, **kwargs):
def __sum_r(self, *args, **kwargs):
return self
# noinspection PyShadowingBuiltins
@post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r))
@method_treelize(return_type=TreeValue, rise=True)
def __sum_nr(self, *args, **kwargs):
return torch.sum(self, *args, **kwargs)
@doc_from_base()
@auto_reduce(__sum_r, __sum_nr)
def sum(self: torch.Tensor, *args, reduce=None, **kwargs):
"""
See :func:`treetensor.torch.sum`
"""
return self.sum(*args, **kwargs)
pass # pragma: no cover
@method_treelize()
def __eq__(self, other):
......@@ -681,7 +714,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
return self.log10_(*args, **kwargs)
@doc_from_base()
@post_process(lambda r: tuple(map(replaceable_partial(auto_torch, cls=Tensor), r)))
@post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r))
@method_treelize(return_type=TreeValue, rise=dict(template=[None]))
@post_process(lambda r: list(r))
def split(self, split_size, *args, **kwargs):
......@@ -691,7 +724,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
return self.split(split_size, *args, **kwargs)
@doc_from_base()
@post_process(lambda r: tuple(map(replaceable_partial(auto_torch, cls=Tensor), r)))
@post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r))
@method_treelize(return_type=TreeValue, rise=dict(template=[None]))
@post_process(lambda r: list(r))
def chunk(self, chunks, *args, **kwargs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册