提交 eb3ec123 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): add abs, sigmoid, sign, clamp, floor, ceil, round and its in-place versions

上级 a56cbc33
......@@ -36,6 +36,7 @@ def get_origin(obj):
def print_title(title: str, levelc='=', file=None):
title = title.replace('_', '\\_')
_print = partial(print, file=file)
_print(title)
_print(levelc * (len(title) + 5))
......
......@@ -149,20 +149,28 @@ def clsmeta(func, allow_dict: bool = False) -> Type[type]:
class _TempTreeValue(TreeValue):
pass
_types = (
TreeValue, BaseTree,
*((dict,) if allow_dict else ()),
)
func_treelize = post_process(post_process(args_mapping(
lambda i, x: TreeValue(x) if isinstance(x, _types) else x)))(
def _mapping_func(_, x):
if isinstance(x, TreeValue):
return x
elif isinstance(x, BaseTree):
return TreeValue(x)
elif allow_dict and isinstance(x, dict):
return TreeValue(x)
else:
return x
func_treelize = post_process(post_process(args_mapping(_mapping_func)))(
replaceable_partial(original_func_treelize, return_type=_TempTreeValue)
)
_wrapped_func = func_treelize()(func)
class _MetaClass(type):
def __call__(cls, *args, **kwargs):
_result = _wrapped_func(*args, **kwargs)
def __call__(cls, data, *args, **kwargs):
if isinstance(data, BaseTree):
return type.__call__(cls, data)
_result = _wrapped_func(data, *args, **kwargs)
if isinstance(_result, _TempTreeValue):
return type.__call__(cls, _result)
else:
......
......@@ -7,6 +7,7 @@ from treevalue import reduce_ as treevalue_reduce
__all__ = [
'kwreduce', 'ireduce', 'vreduce',
'return_self',
]
......@@ -55,3 +56,12 @@ def ireduce(rfunc):
return _new_func
return _decorator
def return_self(func):
@wraps(func)
def _new_func(self, *args, **kwargs):
func(self, *args, **kwargs)
return self
return _new_func
......@@ -7,7 +7,7 @@ from treevalue.tree.common import BaseTree
from treevalue.utils import post_process
from .tensor import Tensor, tireduce
from ..common import Object, ireduce
from ..common import Object, ireduce, return_self
from ..utils import replaceable_partial, doc_from, args_mapping
__all__ = [
......@@ -23,6 +23,8 @@ __all__ = [
'equal', 'tensor', 'clone',
'dot', 'matmul', 'mm',
'isfinite', 'isinf', 'isnan',
'abs', 'abs_', 'clamp', 'clamp_', 'sign', 'sigmoid', 'sigmoid_',
'round', 'round_', 'floor', 'floor_', 'ceil', 'ceil_',
]
func_treelize = post_process(post_process(args_mapping(
......@@ -1039,3 +1041,144 @@ def isnan(input):
[False, False, True]])
"""
return torch.isnan(input)
# noinspection PyShadowingBuiltins
@doc_from(torch.abs)
@func_treelize()
def abs(input, *args, **kwargs):
return torch.abs(input, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from(torch.abs_)
@return_self
@func_treelize()
def abs_(input):
return torch.abs_(input)
# noinspection PyShadowingBuiltins
@doc_from(torch.clamp)
@func_treelize()
def clamp(input, *args, **kwargs):
return torch.clamp(input, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from(torch.clamp_)
@return_self
@func_treelize()
def clamp_(input, *args, **kwargs):
return torch.clamp_(input, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from(torch.sign)
@func_treelize()
def sign(input, *args, **kwargs):
return torch.sign(input, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from(torch.round)
@func_treelize()
def round(input, *args, **kwargs):
return torch.round(input, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from(torch.round_)
@return_self
@func_treelize()
def round_(input):
return torch.round_(input)
# noinspection PyShadowingBuiltins
@doc_from(torch.floor)
@func_treelize()
def floor(input, *args, **kwargs):
return torch.floor(input, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from(torch.floor_)
@return_self
@func_treelize()
def floor_(input):
return torch.floor_(input)
# noinspection PyShadowingBuiltins
@doc_from(torch.ceil)
@func_treelize()
def ceil(input, *args, **kwargs):
return torch.ceil(input, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from(torch.ceil_)
@return_self
@func_treelize()
def ceil_(input):
return torch.ceil_(input)
# noinspection PyShadowingBuiltins
@doc_from(torch.sigmoid)
@func_treelize()
def sigmoid(input, *args, **kwargs):
"""
Get a tree of new tensors with the sigmoid of the elements of ``input``.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.tensor([1.0, 2.0, -1.5]).sigmoid()
tensor([0.7311, 0.8808, 0.1824])
>>> ttorch.tensor({
... 'a': [1.0, 2.0, -1.5],
... 'b': {'x': [[0.5, 1.2], [-2.5, 0.25]]},
... }).sigmoid()
<Tensor 0x7f973a312820>
├── a --> tensor([0.7311, 0.8808, 0.1824])
└── b --> <Tensor 0x7f973a3128b0>
└── x --> tensor([[0.6225, 0.7685],
[0.0759, 0.5622]])
"""
return torch.sigmoid(input, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from(torch.sigmoid_)
@return_self
@func_treelize()
def sigmoid_(input):
"""
In-place version of :func:`treetensor.torch.sigmoid`.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> t = ttorch.tensor([1.0, 2.0, -1.5])
>>> ttorch.sigmoid_(t)
>>> t
tensor([0.7311, 0.8808, 0.1824])
>>> t = ttorch.tensor({
... 'a': [1.0, 2.0, -1.5],
... 'b': {'x': [[0.5, 1.2], [-2.5, 0.25]]},
... })
>>> ttorch.sigmoid_(t)
>>> t
<Tensor 0x7f68fea8d040>
├── a --> tensor([0.7311, 0.8808, 0.1824])
└── b --> <Tensor 0x7f68fea8ee50>
└── x --> tensor([[0.6225, 0.7685],
[0.0759, 0.5622]])
"""
return torch.sigmoid_(input)
......@@ -5,7 +5,7 @@ from treevalue.utils import pre_process
from .base import Torch
from .size import Size
from ..common import Object, ireduce, clsmeta
from ..common import Object, ireduce, clsmeta, return_self
from ..numpy import ndarray
from ..utils import current_names, doc_from
......@@ -317,3 +317,113 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
See :func:`treetensor.torch.isnan`.
"""
return self.isnan()
@doc_from(torch.Tensor.abs)
@method_treelize()
def abs(self, *args, **kwargs):
"""
See :func:`treetensor.torch.abs`.
"""
return self.abs(*args, **kwargs)
@doc_from(torch.Tensor.abs_)
@return_self
@method_treelize()
def abs_(self, *args, **kwargs):
"""
See :func:`treetensor.torch.abs_`.
"""
return self.abs_(*args, **kwargs)
@doc_from(torch.Tensor.clamp)
@method_treelize()
def clamp(self, *args, **kwargs):
"""
See :func:`treetensor.torch.clamp`.
"""
return self.clamp(*args, **kwargs)
@doc_from(torch.Tensor.clamp_)
@return_self
@method_treelize()
def clamp_(self, *args, **kwargs):
"""
See :func:`treetensor.torch.clamp_`.
"""
return self.clamp_(*args, **kwargs)
@doc_from(torch.Tensor.sign)
@method_treelize()
def sign(self, *args, **kwargs):
"""
See :func:`treetensor.torch.sign`.
"""
return self.sign(*args, **kwargs)
@doc_from(torch.Tensor.sigmoid)
@method_treelize()
def sigmoid(self, *args, **kwargs):
"""
See :func:`treetensor.torch.sigmoid`.
"""
return self.sigmoid(*args, **kwargs)
@doc_from(torch.Tensor.sigmoid_)
@return_self
@method_treelize()
def sigmoid_(self, *args, **kwargs):
"""
See :func:`treetensor.torch.sigmoid_`.
"""
return self.sigmoid_(*args, **kwargs)
@doc_from(torch.Tensor.floor)
@method_treelize()
def floor(self, *args, **kwargs):
"""
See :func:`treetensor.torch.floor`.
"""
return self.floor(*args, **kwargs)
@doc_from(torch.Tensor.floor_)
@return_self
@method_treelize()
def floor_(self, *args, **kwargs):
"""
See :func:`treetensor.torch.floor_`.
"""
return self.floor_(*args, **kwargs)
@doc_from(torch.Tensor.ceil)
@method_treelize()
def ceil(self, *args, **kwargs):
"""
See :func:`treetensor.torch.ceil`.
"""
return self.ceil(*args, **kwargs)
@doc_from(torch.Tensor.ceil_)
@return_self
@method_treelize()
def ceil_(self, *args, **kwargs):
"""
See :func:`treetensor.torch.ceil_`.
"""
return self.ceil_(*args, **kwargs)
@doc_from(torch.Tensor.round)
@method_treelize()
def round(self, *args, **kwargs):
"""
See :func:`treetensor.torch.round`.
"""
return self.round(*args, **kwargs)
@doc_from(torch.Tensor.round_)
@return_self
@method_treelize()
def round_(self, *args, **kwargs):
"""
See :func:`treetensor.torch.round_`.
"""
return self.round_(*args, **kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册