提交 003eefaa 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): complete new version of std and mean

上级 4980f767
......@@ -722,59 +722,25 @@ class TestTorchFuncsMath:
[1.2041, 0.5740, math.nan]]},
}), rtol=1e-4, atol=1e-4, equal_nan=True).all()
@choose_mark()
def test_std(self):
t1 = torch.tensor([[25.5133, 24.2050, 8.1067],
[22.7316, -17.8863, -37.9171]]).std()
assert isinstance(t1, torch.Tensor)
assert ttorch.isclose(t1, torch.tensor(26.3619), atol=1e-4).all()
tt1 = ttorch.tensor({
'a': [[-48.6580, 30.9506, -16.1800],
[37.6667, 10.3850, -5.7679]],
'b': {'x': [[-17.9371, 8.4873, -49.0445, 4.7368],
[21.3990, -11.2385, -15.9331, -41.6838],
[-7.1814, -38.1301, -2.2320, 10.1392]]},
}).std()
assert ttorch.isclose(tt1, ttorch.tensor({
'a': 32.0483,
'b': {'x': 22.1754},
}), atol=1e-4).all()
@choose_mark()
def test_mean(self):
t1 = torch.tensor([[11.8069, 16.7822, -11.8583],
[-10.0426, 38.7326, 30.0298]]).mean()
assert isinstance(t1, torch.Tensor)
assert ttorch.isclose(t1, torch.tensor(12.5751), atol=1e-4).all()
tt1 = ttorch.tensor({
'a': [[-29.3862, 10.3668, -19.8407],
[11.3299, -0.7511, -13.8404]],
'b': {'x': [[-25.1722, 22.6307, -9.3588, -6.8217],
[-31.4652, 6.6465, 36.9483, -4.0487],
[-17.2146, 24.0029, 35.4574, -29.2970]]},
}).mean()
assert ttorch.isclose(tt1, ttorch.tensor({
'a': -7.0203,
'b': {'x': 0.1923},
}), atol=1e-4).all()
@choose_mark()
def test_dist(self):
t1 = torch.tensor([-0.6566, 1.2243, 1.5018, -0.1492, 0.8947]).dist(
t1 = ttorch.dist(
torch.tensor([-0.6566, 1.2243, 1.5018, -0.1492, 0.8947]),
torch.tensor([0.5898, 0.6839, 0.0388, 0.4649, 0.7964]),
)
assert isinstance(t1, torch.Tensor)
assert ttorch.isclose(t1, torch.tensor(2.0911), atol=1e-4).all()
tt1 = ttorch.tensor({
'a': [-0.5491, 1.5006, -0.0483, 1.2282, -1.4837],
'b': {'x': [-1.8414, 1.2913, 0.0943, 0.3473, 1.2717, 0.6013]},
}).dist(ttorch.tensor({
'a': [0.1389, -0.7804, -1.3048, -1.1066, 1.3225],
'b': {'x': [1.4873, 0.2218, -0.1063, -0.8726, -0.6756, 0.4805]},
}))
tt1 = ttorch.dist(
ttorch.tensor({
'a': [-0.5491, 1.5006, -0.0483, 1.2282, -1.4837],
'b': {'x': [-1.8414, 1.2913, 0.0943, 0.3473, 1.2717, 0.6013]},
}), ttorch.tensor({
'a': [0.1389, -0.7804, -1.3048, -1.1066, 1.3225],
'b': {'x': [1.4873, 0.2218, -0.1063, -0.8726, -0.6756, 0.4805]},
})
)
assert ttorch.isclose(tt1, ttorch.tensor({
'a': 4.5366,
'b': {'x': 4.1904}
......@@ -782,19 +748,19 @@ class TestTorchFuncsMath:
@choose_mark()
def test_norm(self):
t1 = torch.tensor([[0.0363, -1.7385, 1.0669, 2.6967],
[0.0848, 0.2735, 0.3538, 0.2271],
[-0.1014, 1.1351, -0.5761, -1.2671]]).norm()
t1 = ttorch.norm(torch.tensor([[0.0363, -1.7385, 1.0669, 2.6967],
[0.0848, 0.2735, 0.3538, 0.2271],
[-0.1014, 1.1351, -0.5761, -1.2671]]))
assert isinstance(t1, torch.Tensor)
assert ttorch.isclose(t1, torch.tensor(3.8638), atol=1e-4).all()
tt1 = ttorch.tensor({
tt1 = ttorch.norm(ttorch.tensor({
'a': [[-0.5012, 2.0900, 0.0151],
[-0.5035, 0.2144, 0.8370]],
'b': {'x': [[0.3911, 0.3557, -2.2156, 0.3653],
[-0.3503, 1.2182, -0.2364, -0.2854],
[-1.5770, -0.7349, 0.8391, -0.2845]]},
}).norm()
}))
assert ttorch.isclose(tt1, ttorch.tensor({
'a': 2.3706,
'b': {'x': 3.2982},
......
import pytest
import torch
import treetensor.torch as ttorch
......@@ -118,6 +119,84 @@ class TestTorchFuncsReduction:
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
})) == torch.tensor(11.0)).all()
@choose_mark()
def test_mean(self):
t0 = torch.tensor([[26.6598, 27.8008, -59.4753],
[-79.1833, 3.3349, 20.1665]])
t1 = ttorch.mean(t0)
assert isinstance(t1, torch.Tensor)
assert ttorch.isclose(t1, torch.tensor(-10.1161), atol=1e-4).all()
t2 = ttorch.mean(t0, dim=1)
assert isinstance(t2, torch.Tensor)
assert ttorch.isclose(t2, torch.tensor([-1.6716, -18.5606]), atol=1e-4).all()
tt0 = ttorch.tensor({
'a': [[25.2702, 37.4206, -37.1401],
[-7.7245, -91.3234, -27.9402]],
'b': {'x': [[3.2028, -14.0720, 18.1739, 8.5944],
[41.7761, 36.9908, -20.5495, 5.6480],
[-9.3438, -0.7416, 47.2113, 6.9325]]},
})
tt1 = ttorch.mean(tt0)
assert isinstance(tt1, torch.Tensor)
assert ttorch.isclose(tt1, torch.tensor(1.2436), atol=1e-4).all()
tt2 = ttorch.mean(tt0, reduce=False)
assert ttorch.isclose(tt2, ttorch.tensor({
'a': -16.9062,
'b': {'x': 10.3186},
}), atol=1e-4).all()
tt3 = ttorch.mean(tt0, dim=1)
assert ttorch.isclose(tt3, ttorch.tensor({
'a': [8.5169, -42.3294],
'b': {'x': [3.9748, 15.9663, 11.0146]}
}), atol=1e-4).all()
with pytest.warns(UserWarning):
tt4 = ttorch.mean(tt0, dim=1, reduce=True)
assert ttorch.isclose(tt4, ttorch.tensor({
'a': [8.5169, -42.3294],
'b': {'x': [3.9748, 15.9663, 11.0146]}
}), atol=1e-4).all()
@choose_mark()
def test_std(self):
t0 = torch.tensor([[25.5133, 24.2050, 8.1067],
[22.7316, -17.8863, -37.9171]])
t1 = ttorch.std(t0)
assert isinstance(t1, torch.Tensor)
assert ttorch.isclose(t1, torch.tensor(26.3619), atol=1e-4).all()
t2 = ttorch.std(t0, dim=1)
assert isinstance(t2, torch.Tensor)
assert ttorch.isclose(t2, torch.tensor([9.6941, 30.9012]), atol=1e-4).all()
tt0 = ttorch.tensor({
'a': [[-48.6580, 30.9506, -16.1800],
[37.6667, 10.3850, -5.7679]],
'b': {'x': [[-17.9371, 8.4873, -49.0445, 4.7368],
[21.3990, -11.2385, -15.9331, -41.6838],
[-7.1814, -38.1301, -2.2320, 10.1392]]},
})
tt1 = ttorch.std(tt0)
assert isinstance(tt1, torch.Tensor)
assert ttorch.isclose(tt1, torch.tensor(25.6854), atol=1e-4).all()
tt2 = ttorch.std(tt0, reduce=False)
assert ttorch.isclose(tt2, ttorch.tensor({
'a': 32.0483,
'b': {'x': 22.1754},
}), atol=1e-4).all()
tt3 = ttorch.std(tt0, dim=1)
assert ttorch.isclose(tt3, ttorch.tensor({
'a': [40.0284, 21.9536],
'b': {'x': [26.4519, 25.9011, 20.5223]},
}), atol=1e-4).all()
with pytest.warns(UserWarning):
tt4 = ttorch.std(tt0, dim=1, reduce=True)
assert ttorch.isclose(tt4, ttorch.tensor({
'a': [40.0284, 21.9536],
'b': {'x': [26.4519, 25.9011, 20.5223]},
}), atol=1e-4).all()
@choose_mark()
def test_masked_select(self):
tx = torch.tensor([[0.0481, 0.1741, 0.9820, -0.6354],
......
......@@ -890,63 +890,21 @@ class TestTorchTensorMath:
[1.2041, 0.5740, math.nan]]},
}), rtol=1e-4, atol=1e-4, equal_nan=True).all()
@choose_mark()
def test_std(self):
t1 = ttorch.std(torch.tensor([[25.5133, 24.2050, 8.1067],
[22.7316, -17.8863, -37.9171]]))
assert isinstance(t1, torch.Tensor)
assert ttorch.isclose(t1, torch.tensor(26.3619), atol=1e-4).all()
tt1 = ttorch.std(ttorch.tensor({
'a': [[-48.6580, 30.9506, -16.1800],
[37.6667, 10.3850, -5.7679]],
'b': {'x': [[-17.9371, 8.4873, -49.0445, 4.7368],
[21.3990, -11.2385, -15.9331, -41.6838],
[-7.1814, -38.1301, -2.2320, 10.1392]]},
}))
assert ttorch.isclose(tt1, ttorch.tensor({
'a': 32.0483,
'b': {'x': 22.1754},
}), atol=1e-4).all()
@choose_mark()
def test_mean(self):
t1 = ttorch.mean(torch.tensor([[11.8069, 16.7822, -11.8583],
[-10.0426, 38.7326, 30.0298]]))
assert isinstance(t1, torch.Tensor)
assert ttorch.isclose(t1, torch.tensor(12.5751), atol=1e-4).all()
tt1 = ttorch.mean(ttorch.tensor({
'a': [[-29.3862, 10.3668, -19.8407],
[11.3299, -0.7511, -13.8404]],
'b': {'x': [[-25.1722, 22.6307, -9.3588, -6.8217],
[-31.4652, 6.6465, 36.9483, -4.0487],
[-17.2146, 24.0029, 35.4574, -29.2970]]},
}))
assert ttorch.isclose(tt1, ttorch.tensor({
'a': -7.0203,
'b': {'x': 0.1923},
}), atol=1e-4).all()
@choose_mark()
def test_dist(self):
t1 = ttorch.dist(
torch.tensor([-0.6566, 1.2243, 1.5018, -0.1492, 0.8947]),
t1 = torch.tensor([-0.6566, 1.2243, 1.5018, -0.1492, 0.8947]).dist(
torch.tensor([0.5898, 0.6839, 0.0388, 0.4649, 0.7964]),
)
assert isinstance(t1, torch.Tensor)
assert ttorch.isclose(t1, torch.tensor(2.0911), atol=1e-4).all()
tt1 = ttorch.dist(
ttorch.tensor({
'a': [-0.5491, 1.5006, -0.0483, 1.2282, -1.4837],
'b': {'x': [-1.8414, 1.2913, 0.0943, 0.3473, 1.2717, 0.6013]},
}),
ttorch.tensor({
'a': [0.1389, -0.7804, -1.3048, -1.1066, 1.3225],
'b': {'x': [1.4873, 0.2218, -0.1063, -0.8726, -0.6756, 0.4805]},
})
)
tt1 = ttorch.tensor({
'a': [-0.5491, 1.5006, -0.0483, 1.2282, -1.4837],
'b': {'x': [-1.8414, 1.2913, 0.0943, 0.3473, 1.2717, 0.6013]},
}).dist(ttorch.tensor({
'a': [0.1389, -0.7804, -1.3048, -1.1066, 1.3225],
'b': {'x': [1.4873, 0.2218, -0.1063, -0.8726, -0.6756, 0.4805]},
}))
assert ttorch.isclose(tt1, ttorch.tensor({
'a': 4.5366,
'b': {'x': 4.1904}
......@@ -954,19 +912,19 @@ class TestTorchTensorMath:
@choose_mark()
def test_norm(self):
t1 = ttorch.norm(torch.tensor([[0.0363, -1.7385, 1.0669, 2.6967],
[0.0848, 0.2735, 0.3538, 0.2271],
[-0.1014, 1.1351, -0.5761, -1.2671]]))
t1 = torch.tensor([[0.0363, -1.7385, 1.0669, 2.6967],
[0.0848, 0.2735, 0.3538, 0.2271],
[-0.1014, 1.1351, -0.5761, -1.2671]]).norm()
assert isinstance(t1, torch.Tensor)
assert ttorch.isclose(t1, torch.tensor(3.8638), atol=1e-4).all()
tt1 = ttorch.norm(ttorch.tensor({
tt1 = ttorch.tensor({
'a': [[-0.5012, 2.0900, 0.0151],
[-0.5035, 0.2144, 0.8370]],
'b': {'x': [[0.3911, 0.3557, -2.2156, 0.3653],
[-0.3503, 1.2182, -0.2364, -0.2854],
[-1.5770, -0.7349, 0.8391, -0.2845]]},
}))
}).norm()
assert ttorch.isclose(tt1, ttorch.tensor({
'a': 2.3706,
'b': {'x': 3.2982},
......
......@@ -69,6 +69,70 @@ class TestTorchTensorReduction:
assert isinstance(t1, torch.Tensor)
assert t1.tolist() == 7
@choose_mark()
def test_mean(self):
t0 = torch.tensor([[26.6598, 27.8008, -59.4753],
[-79.1833, 3.3349, 20.1665]])
t1 = t0.mean()
assert isinstance(t1, torch.Tensor)
assert ttorch.isclose(t1, torch.tensor(-10.1161), atol=1e-4).all()
t2 = t0.mean(dim=1)
assert isinstance(t2, torch.Tensor)
assert ttorch.isclose(t2, torch.tensor([-1.6716, -18.5606]), atol=1e-4).all()
tt0 = ttorch.tensor({
'a': [[25.2702, 37.4206, -37.1401],
[-7.7245, -91.3234, -27.9402]],
'b': {'x': [[3.2028, -14.0720, 18.1739, 8.5944],
[41.7761, 36.9908, -20.5495, 5.6480],
[-9.3438, -0.7416, 47.2113, 6.9325]]},
})
tt1 = tt0.mean()
assert isinstance(tt1, torch.Tensor)
assert ttorch.isclose(tt1, torch.tensor(1.2436), atol=1e-4).all()
tt2 = tt0.mean(reduce=False)
assert ttorch.isclose(tt2, ttorch.tensor({
'a': -16.9062,
'b': {'x': 10.3186},
}), atol=1e-4).all()
tt3 = tt0.mean(dim=1)
assert ttorch.isclose(tt3, ttorch.tensor({
'a': [8.5169, -42.3294],
'b': {'x': [3.9748, 15.9663, 11.0146]}
}), atol=1e-4).all()
@choose_mark()
def test_std(self):
t0 = torch.tensor([[25.5133, 24.2050, 8.1067],
[22.7316, -17.8863, -37.9171]])
t1 = t0.std()
assert isinstance(t1, torch.Tensor)
assert ttorch.isclose(t1, torch.tensor(26.3619), atol=1e-4).all()
t2 = t0.std(dim=1)
assert isinstance(t2, torch.Tensor)
assert ttorch.isclose(t2, torch.tensor([9.6941, 30.9012]), atol=1e-4).all()
tt0 = ttorch.tensor({
'a': [[-48.6580, 30.9506, -16.1800],
[37.6667, 10.3850, -5.7679]],
'b': {'x': [[-17.9371, 8.4873, -49.0445, 4.7368],
[21.3990, -11.2385, -15.9331, -41.6838],
[-7.1814, -38.1301, -2.2320, 10.1392]]},
})
tt1 = tt0.std()
assert isinstance(tt1, torch.Tensor)
assert ttorch.isclose(tt1, torch.tensor(25.6854), atol=1e-4).all()
tt2 = tt0.std(reduce=False)
assert ttorch.isclose(tt2, ttorch.tensor({
'a': 32.0483,
'b': {'x': 22.1754},
}), atol=1e-4).all()
tt3 = tt0.std(dim=1)
assert ttorch.isclose(tt3, ttorch.tensor({
'a': [40.0284, 21.9536],
'b': {'x': [26.4519, 25.9011, 20.5223]},
}), atol=1e-4).all()
@choose_mark()
def test_masked_select(self):
tx = torch.tensor([[0.0481, 0.1741, 0.9820, -0.6354],
......
from .reduce import *
from .reduce import __all__ as _reduce_all
from .torch import *
from .torch import __all__ as _torch_all
__all__ = [
*_reduce_all,
*_torch_all,
]
import warnings
from functools import wraps
from typing import Optional
import torch
from ...common import ireduce
__all__ = ['rmreduce', 'post_reduce', 'auto_reduce']
def _reduce_func(rfunc):
rfunc = rfunc or (lambda x: x)
def _new_func(ts):
return rfunc(torch.cat(tuple(map(lambda x: x.view((-1,)), ts))))
return _new_func
def rmreduce(rfunc=None):
return ireduce(_reduce_func(rfunc))
def post_reduce(rfunc=None, prefunc=None):
rfunc = rfunc or (lambda x, *args, **kwargs: x)
def _decorator(func):
func = rmreduce(prefunc)(func)
# noinspection PyUnusedLocal,PyShadowingBuiltins
@wraps(func)
def _new_func(input, *args, **kwargs):
result = func(input, *args, **kwargs)
return rfunc(result, *args, **kwargs)
return _new_func
return _decorator
def _default_auto_determine(*args, **kwargs):
return False if args or kwargs else None
def _default_auto_condition(*args, **kwargs):
return not args and not kwargs
def auto_reduce(rfunc, nrfunc, determine=None, condition=None):
determine = determine or _default_auto_determine
condition = condition or _default_auto_condition
def _decorator(func):
# noinspection PyUnusedLocal,PyShadowingBuiltins
@wraps(func)
def _new_func(input, *args, reduce: Optional[bool] = None, **kwargs):
_determine = determine(*args, **kwargs)
if _determine is not None:
if reduce is not None:
if not _determine and reduce:
warnings.warn(UserWarning(
f'Reduce forbidden for this case of function {func.__name__}, '
f'enablement of reduce option will be ignored.'), stacklevel=2)
elif _determine and not reduce:
warnings.warn(UserWarning(
f'Reduce must be processed for this case of function {func.__name__}, '
f'disablement of reduce option will be ignored.'), stacklevel=2)
reduce = not not _determine
_reduce = condition(*args, **kwargs) if reduce is None else not not reduce
return (rfunc if _reduce else nrfunc)(input, *args, **kwargs)
return _new_func
return _decorator
from typing import Type
from treevalue import typetrans, TreeValue
from treevalue import TreeValue, typetrans
from ..common import BaseTreeStruct
from ...common import BaseTreeStruct
__all__ = ['Torch']
__all__ = ['Torch', 'auto_torch']
class Torch(BaseTreeStruct):
pass
def _auto_torch(value, cls: Type[Torch]):
def auto_torch(value, cls: Type[Torch]):
return typetrans(value, cls) if isinstance(value, TreeValue) else value
......@@ -4,7 +4,7 @@ from treevalue import func_treelize as original_func_treelize
from treevalue.tree.common import BaseTree
from treevalue.utils import post_process
from ..base import _auto_torch
from ..base import auto_torch
from ..tensor import Tensor
from ...utils import doc_from_base as original_doc_from_base
from ...utils import replaceable_partial, args_mapping
......@@ -14,4 +14,4 @@ func_treelize = post_process(post_process(args_mapping(
replaceable_partial(original_func_treelize, return_type=Tensor)
)
doc_from_base = replaceable_partial(original_doc_from_base, base=torch)
auto_tensor = replaceable_partial(_auto_torch, cls=Tensor)
auto_tensor = replaceable_partial(auto_torch, cls=Tensor)
......@@ -9,7 +9,7 @@ __all__ = [
'add', 'sub', 'mul', 'div', 'pow', 'neg', 'neg_',
'exp', 'exp_', 'exp2', 'exp2_', 'sqrt', 'sqrt_',
'log', 'log_', 'log2', 'log2_', 'log10', 'log10_',
'mean', 'std', 'dist', 'norm',
'dist', 'norm',
]
......@@ -1079,95 +1079,6 @@ def log10_(input):
return torch.log10_(input)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def std(input, *args, **kwargs):
"""
Returns the standard-deviation of all elements in the ``input`` tensor.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> t = torch.randn((2, 3)) * 30
>>> t
tensor([[ 25.5133, 24.2050, 8.1067],
[ 22.7316, -17.8863, -37.9171]])
>>> ttorch.std(t)
tensor(26.3619)
>>> tt = ttorch.randn({
... 'a': (2, 3),
... 'b': {'x': (3, 4)},
... }) * 30
>>> tt
<Tensor 0x7f7c7288ca58>
├── a --> tensor([[-48.6580, 30.9506, -16.1800],
│ [ 37.6667, 10.3850, -5.7679]])
└── b --> <Tensor 0x7f7c7288c978>
└── x --> tensor([[-17.9371, 8.4873, -49.0445, 4.7368],
[ 21.3990, -11.2385, -15.9331, -41.6838],
[ -7.1814, -38.1301, -2.2320, 10.1392]])
>>> ttorch.std(tt)
<Tensor 0x7f7c7288c470>
├── a --> tensor(32.0483)
└── b --> <Tensor 0x7f7c7288c3c8>
└── x --> tensor(22.1754)
.. note::
Reduction will not be processed in :func:`treetensor.torch.std`.
It means the result should be a tree of tensors instead of one tensor.
"""
return torch.std(input, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def mean(input, *args, **kwargs):
"""
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> t = torch.randn((2, 3)) * 30
>>> t
tensor([[ 11.8069, 16.7822, -11.8583],
[-10.0426, 38.7326, 30.0298]])
>>> ttorch.mean(t)
tensor(12.5751)
>>> tt = ttorch.randn({
... 'a': (2, 3),
... 'b': {'x': (3, 4)},
... }) * 30
>>> tt
<Tensor 0x7f95f684f6a0>
├── a --> tensor([[-29.3862, 10.3668, -19.8407],
│ [ 11.3299, -0.7511, -13.8404]])
└── b --> <Tensor 0x7f95f684f828>
└── x --> tensor([[-25.1722, 22.6307, -9.3588, -6.8217],
[-31.4652, 6.6465, 36.9483, -4.0487],
[-17.2146, 24.0029, 35.4574, -29.2970]])
>>> ttorch.mean(tt)
<Tensor 0x7f95f6849e80>
├── a --> tensor(-7.0203)
└── b --> <Tensor 0x7f95f6849470>
└── x --> tensor(0.1923)
.. note::
Reduction will not be processed in :func:`treetensor.torch.std`.
It means the result should be a tree of tensors instead of one tensor.
"""
return torch.mean(input, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
......
import torch
from .base import doc_from_base, func_treelize
from ..tensor import tireduce
from ...common import Object, ireduce
from ..base import rmreduce, post_reduce, auto_reduce
from ...common import Object
__all__ = [
'all', 'any',
'min', 'max', 'sum',
'min', 'max', 'sum', 'mean', 'std',
'masked_select',
]
# noinspection PyShadowingBuiltins
@doc_from_base()
@tireduce(torch.all)
@rmreduce(torch.all)
@func_treelize(return_type=Object)
def all(input, *args, **kwargs):
"""
......@@ -52,7 +52,7 @@ def all(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from_base()
@tireduce(torch.any)
@rmreduce(torch.any)
@func_treelize(return_type=Object)
def any(input, *args, **kwargs):
"""
......@@ -91,7 +91,7 @@ def any(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from_base()
@tireduce(torch.min)
@rmreduce(torch.min)
@func_treelize(return_type=Object)
def min(input, *args, **kwargs):
"""
......@@ -130,7 +130,7 @@ def min(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from_base()
@tireduce(torch.max)
@rmreduce(torch.max)
@func_treelize(return_type=Object)
def max(input, *args, **kwargs):
"""
......@@ -169,7 +169,7 @@ def max(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from_base()
@tireduce(torch.sum)
@rmreduce(torch.sum)
@func_treelize(return_type=Object)
def sum(input, *args, **kwargs):
"""
......@@ -206,9 +206,129 @@ def sum(input, *args, **kwargs):
return torch.sum(input, *args, **kwargs)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.mean)
@func_treelize(return_type=Object)
def _mean_r(input, *args, **kwargs):
return input
# noinspection PyShadowingBuiltins
@func_treelize()
def _mean_nr(input, *args, **kwargs):
return torch.mean(input, *args, **kwargs)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@doc_from_base()
@auto_reduce(_mean_r, _mean_nr)
def mean(input, *args, reduce=None, **kwargs):
"""
Returns the mean value of all elements in the ``input`` tensor.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> t = torch.randn((2, 3)) * 30
>>> t
tensor([[ 26.6598, 27.8008, -59.4753],
[-79.1833, 3.3349, 20.1665]])
>>> ttorch.mean(t)
tensor(-10.1161)
>>> tt = ttorch.randn({
... 'a': (2, 3),
... 'b': {'x': (3, 4)},
... }) * 30
>>> tt
<Tensor 0x7f2f5b9f6cf8>
├── a --> tensor([[ 25.2702, 37.4206, -37.1401],
│ [ -7.7245, -91.3234, -27.9402]])
└── b --> <Tensor 0x7f2f5b9f6c18>
└── x --> tensor([[ 3.2028, -14.0720, 18.1739, 8.5944],
[ 41.7761, 36.9908, -20.5495, 5.6480],
[ -9.3438, -0.7416, 47.2113, 6.9325]])
>>> ttorch.mean(tt)
tensor(1.2436)
>>> ttorch.mean(tt, reduce=False)
<Tensor 0x7f1321caf080>
├── a --> tensor(-16.9062)
└── b --> <Tensor 0x7f1321caf048>
└── x --> tensor(10.3186)
>>> ttorch.mean(tt, dim=1)
<Tensor 0x7f63dbbc9828>
├── a --> tensor([ 8.5169, -42.3294])
└── b --> <Tensor 0x7f63dbbc9780>
└── x --> tensor([ 3.9748, 15.9663, 11.0146])
"""
pass # pragma: no cover
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.std)
@func_treelize(return_type=Object)
def _std_r(input, *args, **kwargs):
return input
# noinspection PyShadowingBuiltins
@func_treelize()
def _std_nr(input, *args, **kwargs):
return torch.std(input, *args, **kwargs)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@doc_from_base()
@auto_reduce(_std_r, _std_nr)
def std(input, *args, reduce=None, **kwargs):
"""
Returns the standard-deviation of all elements in the ``input`` tensor.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> t = torch.randn((2, 3)) * 30
>>> t
tensor([[ 25.5133, 24.2050, 8.1067],
[ 22.7316, -17.8863, -37.9171]])
>>> ttorch.std(t)
tensor(26.3619)
>>> tt = ttorch.randn({
... 'a': (2, 3),
... 'b': {'x': (3, 4)},
... }) * 30
>>> tt
<Tensor 0x7f7c7288ca58>
├── a --> tensor([[-48.6580, 30.9506, -16.1800],
│ [ 37.6667, 10.3850, -5.7679]])
└── b --> <Tensor 0x7f7c7288c978>
└── x --> tensor([[-17.9371, 8.4873, -49.0445, 4.7368],
[ 21.3990, -11.2385, -15.9331, -41.6838],
[ -7.1814, -38.1301, -2.2320, 10.1392]])
>>> ttorch.std()
tensor(25.6854)
>>> ttorch.std(tt, reduce=False)
<Tensor 0x7f7c7288c470>
├── a --> tensor(32.0483)
└── b --> <Tensor 0x7f7c7288c3c8>
└── x --> tensor(22.1754)
>>> ttorch.std(tt, dim=1)
<Tensor 0x7f1321ca1c50>
├── a --> tensor([40.0284, 21.9536])
└── b --> <Tensor 0x7f1321ca1fd0>
└── x --> tensor([26.4519, 25.9011, 20.5223])
"""
pass # pragma: no cover
# noinspection PyShadowingBuiltins
@doc_from_base()
@ireduce(torch.cat, piter=tuple)
@rmreduce()
@func_treelize(return_type=Object)
def masked_select(input, mask, *args, **kwargs):
"""
......
import numpy as np
import torch
from treevalue import method_treelize, TreeValue
from treevalue.utils import pre_process, post_process
from treevalue.utils import post_process
from .base import Torch, _auto_torch
from .base import Torch, auto_torch, rmreduce, post_reduce, auto_reduce
from .size import Size
from ..common import Object, ireduce, clsmeta, return_self
from ..numpy import ndarray
......@@ -14,8 +14,6 @@ __all__ = [
'Tensor'
]
_reduce_tensor_wrap = pre_process(lambda it: ((torch.tensor([*it]),), {}))
tireduce = pre_process(lambda rfunc: ((_reduce_tensor_wrap(rfunc),), {}))(ireduce)
doc_from_base = replaceable_partial(original_doc_from_base, base=torch.Tensor)
......@@ -176,9 +174,19 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.shape
@doc_from_base()
@return_self
@method_treelize()
def requires_grad_(self, requires_grad=True):
"""
Change if autograd should record operations on this tensor:
sets this tensor’s ``requires_grad`` attribute in-place. Returns this tensor.
"""
return self.requires_grad_(requires_grad)
# noinspection PyArgumentList
@doc_from_base()
@tireduce(torch.all)
@rmreduce(torch.all)
@method_treelize(return_type=Object)
def all(self: torch.Tensor, *args, **kwargs) -> bool:
"""
......@@ -188,7 +196,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
# noinspection PyArgumentList
@doc_from_base()
@tireduce(torch.any)
@rmreduce(torch.any)
@method_treelize(return_type=Object)
def any(self: torch.Tensor, *args, **kwargs) -> bool:
"""
......@@ -197,7 +205,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
return self.any(*args, **kwargs)
@doc_from_base()
@tireduce(torch.max)
@rmreduce(torch.max)
@method_treelize(return_type=Object)
def max(self: torch.Tensor, *args, **kwargs):
"""
......@@ -206,7 +214,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
return self.max(*args, **kwargs)
@doc_from_base()
@tireduce(torch.min)
@rmreduce(torch.min)
@method_treelize(return_type=Object)
def min(self: torch.Tensor, *args, **kwargs):
"""
......@@ -215,7 +223,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
return self.min(*args, **kwargs)
@doc_from_base()
@tireduce(torch.sum)
@rmreduce(torch.sum)
@method_treelize(return_type=Object)
def sum(self: torch.Tensor, *args, **kwargs):
"""
......@@ -653,7 +661,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
return self.log10_(*args, **kwargs)
@doc_from_base()
@post_process(lambda r: tuple(map(replaceable_partial(_auto_torch, cls=Tensor), r)))
@post_process(lambda r: tuple(map(replaceable_partial(auto_torch, cls=Tensor), r)))
@method_treelize(return_type=TreeValue, rise=dict(template=[None]))
@post_process(lambda r: list(r))
def split(self, split_size, *args, **kwargs):
......@@ -663,7 +671,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
return self.split(split_size, *args, **kwargs)
@doc_from_base()
@post_process(lambda r: tuple(map(replaceable_partial(_auto_torch, cls=Tensor), r)))
@post_process(lambda r: tuple(map(replaceable_partial(auto_torch, cls=Tensor), r)))
@method_treelize(return_type=TreeValue, rise=dict(template=[None]))
@post_process(lambda r: list(r))
def chunk(self, chunks, *args, **kwargs):
......@@ -733,7 +741,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
return self.index_select(dim, index)
@doc_from_base()
@ireduce(torch.cat, piter=tuple)
@rmreduce()
@method_treelize()
def masked_select(self, mask):
"""
......@@ -741,21 +749,43 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.masked_select(mask)
# noinspection PyUnusedLocal
@post_reduce(torch.std)
@method_treelize(return_type=Object)
def __std_r(self, *args, **kwargs):
return self
@method_treelize()
def __std_nr(self, *args, **kwargs):
return torch.std(self, *args, **kwargs)
@doc_from_base()
@auto_reduce(__std_r, __std_nr)
@method_treelize()
def std(self, *args, **kwargs):
"""
See :func:`treetensor.torch.std`.
"""
return self.std(*args, **kwargs)
pass # pragma: no cover
# noinspection PyUnusedLocal
@post_reduce(torch.mean)
@method_treelize(return_type=Object)
def __mean_r(self, *args, **kwargs):
return self
@method_treelize()
def __mean_nr(self, *args, **kwargs):
return torch.mean(self, *args, **kwargs)
@doc_from_base()
@auto_reduce(__mean_r, __mean_nr)
@method_treelize()
def mean(self, *args, **kwargs):
"""
See :func:`treetensor.torch.mean`.
"""
return self.mean(*args, **kwargs)
pass # pragma: no cover
@doc_from_base()
@method_treelize()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册