import torch
from treevalue import TreeValue
from treevalue.utils import post_process

from .base import doc_from_base, func_treelize, auto_tensor
from ..base import rmreduce, post_reduce, auto_reduce
from ...common import Object

__all__ = [
    'all', 'any',
    'min', 'max', 'sum', 'mean', 'std',
    'masked_select',
]


# Every public reduction in this module is assembled from three pieces:
#   * ``_<name>_r``  -- the "reduce" variant, passed first to ``auto_reduce``;
#   * ``_<name>_nr`` -- the "no reduce" variant, passed second, which calls the
#     matching ``torch`` function from inside a ``func_treelize`` wrapper;
#   * ``<name>``     -- the public entry point. Its own body is only ``pass``;
#     it supplies the signature and the docstring given to ``doc_from_base``.
#
# NOTE(review): the rule ``auto_reduce`` uses to choose between the two
# variants lives in ``..base`` and is not visible from this file. The docstring
# examples suggest that a plain call reduces the whole tree to one result,
# while ``reduce=False`` or extra arguments such as ``dim=...`` produce a
# per-leaf tree -- confirm against ``auto_reduce``.
#
# The public names deliberately shadow the Python builtins ``all``, ``any``,
# ``min``, ``max`` and ``sum`` (hence the ``PyShadowingBuiltins`` markers), so
# code in this module must not rely on those builtins by bare name after the
# corresponding definition.


# Reduce variant of ``all``: the function itself hands ``input`` back
# unchanged; ``torch.all`` is supplied to ``post_reduce``.
# NOTE(review): presumably ``post_reduce`` applies ``torch.all`` across the
# collected leaves -- confirm in ``..base``.
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.all)
@func_treelize(return_type=Object)
def _all_r(input, *args, **kwargs):
    return input


# Non-reduce variant of ``all``: forwards every argument to ``torch.all``.
# noinspection PyShadowingBuiltins
@func_treelize()
def _all_nr(input, *args, **kwargs):
    return torch.all(input, *args, **kwargs)


# noinspection PyShadowingBuiltins,PyUnusedLocal
@doc_from_base()
@auto_reduce(_all_r, _all_nr)
def all(input, *args, reduce=None, **kwargs):
    """
    In ``treetensor``, you can get the ``all`` result of a whole tree with this function.

    Example::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> ttorch.all(torch.tensor([True, True]))  # the same as torch.all
        tensor(True)

        >>> ttorch.all(ttorch.tensor({'a': [True, True], 'b': {'x': [True, True]}}))
        tensor(True)

        >>> ttorch.all(ttorch.tensor({'a': [True, True], 'b': {'x': [True, False]}}))
        tensor(False)

        >>> ttorch.all(ttorch.tensor({'a': [True, True], 'b': {'x': [True, False]}}), reduce=False)
        ├── a --> tensor(True)
        └── b -->
            └── x --> tensor(False)

        >>> ttorch.all(ttorch.tensor({'a': [True, True], 'b': {'x': [True, False]}}), dim=0)
        ├── a --> tensor(True)
        └── b -->
            └── x --> tensor(False)
    """
    pass  # pragma: no cover


# Reduce variant of ``any``: returns ``input`` unchanged; ``torch.any`` is
# supplied to ``post_reduce`` (same layout as ``_all_r``).
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.any)
@func_treelize(return_type=Object)
def _any_r(input, *args, **kwargs):
    return input


# Non-reduce variant of ``any``: forwards every argument to ``torch.any``.
# noinspection PyShadowingBuiltins
@func_treelize()
def _any_nr(input, *args, **kwargs):
    return torch.any(input, *args, **kwargs)


# noinspection PyShadowingBuiltins,PyUnusedLocal
@doc_from_base()
@auto_reduce(_any_r, _any_nr)
def any(input, *args, reduce=None, **kwargs):
    """
    In ``treetensor``, you can get the ``any`` result of a whole tree with this function.

    Example::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> ttorch.any(torch.tensor([False, False]))  # the same as torch.any
        tensor(False)

        >>> ttorch.any(ttorch.tensor({'a': [True, False], 'b': {'x': [False, False]}}))
        tensor(True)

        >>> ttorch.any(ttorch.tensor({'a': [False, False], 'b': {'x': [False, False]}}))
        tensor(False)

        >>> ttorch.any(ttorch.tensor({'a': [True, False], 'b': {'x': [False, False]}}), reduce=False)
        ├── a --> tensor(True)
        └── b -->
            └── x --> tensor(False)

        >>> ttorch.any(ttorch.tensor({'a': [False, False], 'b': {'x': [False, False]}}), dim=0)
        ├── a --> tensor(False)
        └── b -->
            └── x --> tensor(False)
    """
    pass  # pragma: no cover


# Reduce variant of ``min``: returns ``input`` unchanged; ``torch.min`` is
# supplied to ``post_reduce``.
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.min)
@func_treelize(return_type=Object)
def _min_r(input, *args, **kwargs):
    return input


# Non-reduce variant of ``min``: forwards every argument to ``torch.min``.
# Unlike ``_all_nr`` it is treelized with ``return_type=TreeValue, rise=True``
# and post-processed with ``auto_tensor``.
# NOTE(review): the ``dim=0`` example in the ``min`` docstring shows a
# ``torch.return_types.min`` whose ``values``/``indices`` fields are trees, so
# ``rise=True`` presumably lifts that named tuple out of the tree -- confirm
# against the treevalue documentation.
# noinspection PyShadowingBuiltins
@post_process(auto_tensor)
@func_treelize(return_type=TreeValue, rise=True)
def _min_nr(input, *args, **kwargs):
    return torch.min(input, *args, **kwargs)


# noinspection PyShadowingBuiltins,PyUnusedLocal
@doc_from_base()
@auto_reduce(_min_r, _min_nr)
def min(input, *args, reduce=None, **kwargs):
    """
    In ``treetensor``, you can get the ``min`` result of a whole tree with this function.

    Example::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> ttorch.min(torch.tensor([1.0, 2.0, 1.5]))  # the same as torch.min
        tensor(1.)

        >>> ttorch.min(ttorch.tensor({
        ...     'a': [1.0, 2.0, 1.5],
        ...     'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
        ... }))
        tensor(0.9000)

        >>> ttorch.min(ttorch.tensor({
        ...     'a': [1.0, 2.0, 1.5],
        ...     'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
        ... }), reduce=False)
        ├── a --> tensor(1.)
        └── b -->
            └── x --> tensor(0.9000)

        >>> ttorch.min(ttorch.tensor({
        ...     'a': [1.0, 2.0, 1.5],
        ...     'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
        ... }), dim=0)
        torch.return_types.min(
        values=
        ├── a --> tensor(1.)
        └── b -->
            └── x --> tensor([1.3000, 0.9000])
        ,
        indices=
        ├── a --> tensor(0)
        └── b -->
            └── x --> tensor([1, 0])
        )
    """
    pass  # pragma: no cover


# Reduce variant of ``max``: returns ``input`` unchanged; ``torch.max`` is
# supplied to ``post_reduce``.
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.max)
@func_treelize(return_type=Object)
def _max_r(input, *args, **kwargs):
    return input


# Non-reduce variant of ``max``: forwards every argument to ``torch.max``,
# using the same ``return_type=TreeValue, rise=True`` + ``auto_tensor``
# wrapping as ``_min_nr``.
# noinspection PyShadowingBuiltins
@post_process(auto_tensor)
@func_treelize(return_type=TreeValue, rise=True)
def _max_nr(input, *args, **kwargs):
    return torch.max(input, *args, **kwargs)


# noinspection PyShadowingBuiltins,PyUnusedLocal
@doc_from_base()
@auto_reduce(_max_r, _max_nr)
def max(input, *args, reduce=None, **kwargs):
    """
    In ``treetensor``, you can get the ``max`` result of a whole tree with this function.

    Example::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> ttorch.max(torch.tensor([1.0, 2.0, 1.5]))  # the same as torch.max
        tensor(2.)

        >>> ttorch.max(ttorch.tensor({
        ...     'a': [1.0, 2.0, 1.5],
        ...     'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
        ... }))
        tensor(2.5000)

        >>> ttorch.max(ttorch.tensor({
        ...     'a': [1.0, 2.0, 1.5],
        ...     'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
        ... }), reduce=False)
        ├── a --> tensor(2.)
        └── b -->
            └── x --> tensor(2.5000)

        >>> ttorch.max(ttorch.tensor({
        ...     'a': [1.0, 2.0, 1.5],
        ...     'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
        ... }), dim=0)
        torch.return_types.max(
        values=
        ├── a --> tensor(2.)
        └── b -->
            └── x --> tensor([1.8000, 2.5000])
        ,
        indices=
        ├── a --> tensor(1)
        └── b -->
            └── x --> tensor([0, 1])
        )
    """
    pass  # pragma: no cover


# Reduce variant of ``sum``: returns ``input`` unchanged; ``torch.sum`` is
# supplied to ``post_reduce``.
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.sum)
@func_treelize(return_type=Object)
def _sum_r(input, *args, **kwargs):
    return input


# Non-reduce variant of ``sum``: forwards every argument to ``torch.sum``.
# noinspection PyShadowingBuiltins
@func_treelize()
def _sum_nr(input, *args, **kwargs):
    return torch.sum(input, *args, **kwargs)


# noinspection PyShadowingBuiltins,PyUnusedLocal
@doc_from_base()
@auto_reduce(_sum_r, _sum_nr)
def sum(input, *args, reduce=None, **kwargs):
    """
    In ``treetensor``, you can get the ``sum`` result of a whole tree with this function.

    Example::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> ttorch.sum(torch.tensor([1.0, 2.0, 1.5]))  # the same as torch.sum
        tensor(4.5000)

        >>> ttorch.sum(ttorch.tensor({
        ...     'a': [1.0, 2.0, 1.5],
        ...     'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
        ... }))
        tensor(11.)

        >>> ttorch.sum(ttorch.tensor({
        ...     'a': [1.0, 2.0, 1.5],
        ...     'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
        ... }), reduce=False)
        ├── a --> tensor(4.5000)
        └── b -->
            └── x --> tensor(6.5000)

        >>> ttorch.sum(ttorch.tensor({
        ...     'a': [1.0, 2.0, 1.5],
        ...     'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
        ... }), dim=0)
        ├── a --> tensor(4.5000)
        └── b -->
            └── x --> tensor([3.1000, 3.4000])
    """
    pass  # pragma: no cover


# Reduce variant of ``mean``: returns ``input`` unchanged; ``torch.mean`` is
# supplied to ``post_reduce``.
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.mean)
@func_treelize(return_type=Object)
def _mean_r(input, *args, **kwargs):
    return input


# Non-reduce variant of ``mean``: forwards every argument to ``torch.mean``.
# noinspection PyShadowingBuiltins
@func_treelize()
def _mean_nr(input, *args, **kwargs):
    return torch.mean(input, *args, **kwargs)


# noinspection PyShadowingBuiltins,PyUnusedLocal
@doc_from_base()
@auto_reduce(_mean_r, _mean_nr)
def mean(input, *args, reduce=None, **kwargs):
    """
    Returns the mean value of all elements in the ``input`` tensor.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> t = torch.randn((2, 3)) * 30
        >>> t
        tensor([[ 26.6598,  27.8008, -59.4753],
                [-79.1833,   3.3349,  20.1665]])
        >>> ttorch.mean(t)
        tensor(-10.1161)

        >>> tt = ttorch.randn({
        ...     'a': (2, 3),
        ...     'b': {'x': (3, 4)},
        ... }) * 30
        >>> tt
        ├── a --> tensor([[ 25.2702,  37.4206, -37.1401],
        │                 [ -7.7245, -91.3234, -27.9402]])
        └── b -->
            └── x --> tensor([[  3.2028, -14.0720,  18.1739,   8.5944],
                              [ 41.7761,  36.9908, -20.5495,   5.6480],
                              [ -9.3438,  -0.7416,  47.2113,   6.9325]])
        >>> ttorch.mean(tt)
        tensor(1.2436)
        >>> ttorch.mean(tt, reduce=False)
        ├── a --> tensor(-16.9062)
        └── b -->
            └── x --> tensor(10.3186)
        >>> ttorch.mean(tt, dim=1)
        ├── a --> tensor([  8.5169, -42.3294])
        └── b -->
            └── x --> tensor([ 3.9748, 15.9663, 11.0146])
    """
    pass  # pragma: no cover


# Reduce variant of ``std``: returns ``input`` unchanged; ``torch.std`` is
# supplied to ``post_reduce``.
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.std)
@func_treelize(return_type=Object)
def _std_r(input, *args, **kwargs):
    return input


# Non-reduce variant of ``std``: forwards every argument to ``torch.std``.
# noinspection PyShadowingBuiltins
@func_treelize()
def _std_nr(input, *args, **kwargs):
    return torch.std(input, *args, **kwargs)


# noinspection PyShadowingBuiltins,PyUnusedLocal
@doc_from_base()
@auto_reduce(_std_r, _std_nr)
def std(input, *args, reduce=None, **kwargs):
    """
    Returns the standard-deviation of all elements in the ``input`` tensor.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> t = torch.randn((2, 3)) * 30
        >>> t
        tensor([[ 25.5133,  24.2050,   8.1067],
                [ 22.7316, -17.8863, -37.9171]])
        >>> ttorch.std(t)
        tensor(26.3619)

        >>> tt = ttorch.randn({
        ...     'a': (2, 3),
        ...     'b': {'x': (3, 4)},
        ... }) * 30
        >>> tt
        ├── a --> tensor([[-48.6580,  30.9506, -16.1800],
        │                 [ 37.6667,  10.3850,  -5.7679]])
        └── b -->
            └── x --> tensor([[-17.9371,   8.4873, -49.0445,   4.7368],
                              [ 21.3990, -11.2385, -15.9331, -41.6838],
                              [ -7.1814, -38.1301,  -2.2320,  10.1392]])
        >>> ttorch.std(tt)
        tensor(25.6854)
        >>> ttorch.std(tt, reduce=False)
        ├── a --> tensor(32.0483)
        └── b -->
            └── x --> tensor(22.1754)
        >>> ttorch.std(tt, dim=1)
        ├── a --> tensor([40.0284, 21.9536])
        └── b -->
            └── x --> tensor([26.4519, 25.9011, 20.5223])
    """
    pass  # pragma: no cover


# Reduce variant of ``masked_select``. Unlike the ``post_reduce`` based
# variants above, the function itself already calls ``torch.masked_select``
# and the wrapper is ``rmreduce()``.
# NOTE(review): the ``masked_select`` docstring shows the reduced result as a
# single 1-D tensor holding the selections of every leaf, so ``rmreduce``
# presumably joins the per-leaf results -- confirm in ``..base``.
# noinspection PyShadowingBuiltins,PyUnusedLocal
@rmreduce()
@func_treelize(return_type=Object)
def _masked_select_r(input, mask, *args, **kwargs):
    return torch.masked_select(input, mask, *args, **kwargs)


# Non-reduce variant of ``masked_select``: forwards every argument to
# ``torch.masked_select``.
# noinspection PyShadowingBuiltins
@func_treelize()
def _masked_select_nr(input, mask, *args, **kwargs):
    return torch.masked_select(input, mask, *args, **kwargs)


# Third argument given to ``auto_reduce`` for ``masked_select``.
# Returns ``False`` when any positional argument beyond ``mask`` or any keyword
# argument other than ``out`` is supplied, and ``None`` otherwise. ``out`` has
# its own keyword-only slot so that it never lands in ``kwargs``.
# NOTE(review): presumably ``False`` forces the non-reduced path and ``None``
# means "no forced decision" -- confirm against ``auto_reduce`` in ``..base``,
# including which arguments it actually passes to this function.
# noinspection PyUnusedLocal
def _ms_determine(mask, *args, out=None, **kwargs):
    return False if args or kwargs else None


# Fourth argument given to ``auto_reduce`` for ``masked_select``.
# Returns ``True`` only when nothing besides ``mask`` and (optionally) ``out``
# is supplied, i.e. when both ``args`` and ``kwargs`` are empty.
# NOTE(review): presumably this provides the default for ``reduce`` when the
# caller leaves it as ``None`` -- confirm against ``auto_reduce``.
# noinspection PyUnusedLocal
def _ms_condition(mask, *args, out=None, **kwargs):
    return not args and not kwargs


# noinspection PyShadowingBuiltins,PyUnusedLocal
@doc_from_base()
@auto_reduce(_masked_select_r, _masked_select_nr, _ms_determine, _ms_condition)
def masked_select(input, mask, *args, reduce=None, **kwargs):
    """
    Returns a new 1-D tensor which indexes the ``input`` tensor
    according to the boolean mask ``mask`` which is a BoolTensor.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> t = torch.randn(3, 4)
        >>> t
        tensor([[ 0.0481,  0.1741,  0.9820, -0.6354],
                [ 0.8108, -0.7126,  0.1329,  1.0868],
                [-1.8267,  1.3676, -1.4490, -2.0224]])
        >>> ttorch.masked_select(t, t > 0.3)
        tensor([0.9820, 0.8108, 1.0868, 1.3676])

        >>> tt = ttorch.randn({
        ...     'a': (2, 3),
        ...     'b': {'x': (3, 4)},
        ... })
        >>> tt
        ├── a --> tensor([[ 1.1799,  0.4652, -1.7895],
        │                 [ 0.0423,  1.0866,  1.3533]])
        └── b -->
            └── x --> tensor([[ 0.8139, -0.6732,  0.0065,  0.9073],
                              [ 0.0596, -2.0621, -0.1598, -1.0793],
                              [-0.0496,  2.1392,  0.6403,  0.4041]])
        >>> ttorch.masked_select(tt, tt > 0.3)
        tensor([1.1799, 0.4652, 1.0866, 1.3533, 0.8139, 0.9073, 2.1392, 0.6403, 0.4041])
        >>> ttorch.masked_select(tt, tt > 0.3, reduce=False)
        ├── a --> tensor([1.1799, 0.4652, 1.0866, 1.3533])
        └── b -->
            └── x --> tensor([0.8139, 0.9073, 2.1392, 0.6403, 0.4041])
    """
    pass  # pragma: no cover