Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
flybirding10011
DI-treetensor
提交
eb3ec123
D
DI-treetensor
项目概览
flybirding10011
/
DI-treetensor
与 Fork 源项目一致
Fork自
OpenDILab开源决策智能平台 / DI-treetensor
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-treetensor
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
eb3ec123
编写于
9月 21, 2021
作者:
HansBug
😆
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dev(hansbug): add abs, sigmoid, sign, clamp, floor, ceil, round and its in-place versions
上级
a56cbc33
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
282 addition
and
10 deletion
+282
-10
docs/source/_libs/docs.py
docs/source/_libs/docs.py
+1
-0
treetensor/common/trees.py
treetensor/common/trees.py
+16
-8
treetensor/common/wrappers.py
treetensor/common/wrappers.py
+10
-0
treetensor/torch/funcs.py
treetensor/torch/funcs.py
+144
-1
treetensor/torch/tensor.py
treetensor/torch/tensor.py
+111
-1
未找到文件。
docs/source/_libs/docs.py
浏览文件 @
eb3ec123
...
...
@@ -36,6 +36,7 @@ def get_origin(obj):
def
print_title
(
title
:
str
,
levelc
=
'='
,
file
=
None
):
title
=
title
.
replace
(
'_'
,
'
\\
_'
)
_print
=
partial
(
print
,
file
=
file
)
_print
(
title
)
_print
(
levelc
*
(
len
(
title
)
+
5
))
...
...
treetensor/common/trees.py
浏览文件 @
eb3ec123
...
...
@@ -149,20 +149,28 @@ def clsmeta(func, allow_dict: bool = False) -> Type[type]:
class
_TempTreeValue
(
TreeValue
):
pass
_types
=
(
TreeValue
,
BaseTree
,
*
((
dict
,)
if
allow_dict
else
()),
)
func_treelize
=
post_process
(
post_process
(
args_mapping
(
lambda
i
,
x
:
TreeValue
(
x
)
if
isinstance
(
x
,
_types
)
else
x
)))(
def
_mapping_func
(
_
,
x
):
if
isinstance
(
x
,
TreeValue
):
return
x
elif
isinstance
(
x
,
BaseTree
):
return
TreeValue
(
x
)
elif
allow_dict
and
isinstance
(
x
,
dict
):
return
TreeValue
(
x
)
else
:
return
x
func_treelize
=
post_process
(
post_process
(
args_mapping
(
_mapping_func
)))(
replaceable_partial
(
original_func_treelize
,
return_type
=
_TempTreeValue
)
)
_wrapped_func
=
func_treelize
()(
func
)
class
_MetaClass
(
type
):
def
__call__
(
cls
,
*
args
,
**
kwargs
):
_result
=
_wrapped_func
(
*
args
,
**
kwargs
)
def
__call__
(
cls
,
data
,
*
args
,
**
kwargs
):
if
isinstance
(
data
,
BaseTree
):
return
type
.
__call__
(
cls
,
data
)
_result
=
_wrapped_func
(
data
,
*
args
,
**
kwargs
)
if
isinstance
(
_result
,
_TempTreeValue
):
return
type
.
__call__
(
cls
,
_result
)
else
:
...
...
treetensor/common/wrappers.py
浏览文件 @
eb3ec123
...
...
@@ -7,6 +7,7 @@ from treevalue import reduce_ as treevalue_reduce
__all__
=
[
'kwreduce'
,
'ireduce'
,
'vreduce'
,
'return_self'
,
]
...
...
@@ -55,3 +56,12 @@ def ireduce(rfunc):
return
_new_func
return
_decorator
def
return_self
(
func
):
@
wraps
(
func
)
def
_new_func
(
self
,
*
args
,
**
kwargs
):
func
(
self
,
*
args
,
**
kwargs
)
return
self
return
_new_func
treetensor/torch/funcs.py
浏览文件 @
eb3ec123
...
...
@@ -7,7 +7,7 @@ from treevalue.tree.common import BaseTree
from
treevalue.utils
import
post_process
from
.tensor
import
Tensor
,
tireduce
from
..common
import
Object
,
ireduce
from
..common
import
Object
,
ireduce
,
return_self
from
..utils
import
replaceable_partial
,
doc_from
,
args_mapping
__all__
=
[
...
...
@@ -23,6 +23,8 @@ __all__ = [
'equal'
,
'tensor'
,
'clone'
,
'dot'
,
'matmul'
,
'mm'
,
'isfinite'
,
'isinf'
,
'isnan'
,
'abs'
,
'abs_'
,
'clamp'
,
'clamp_'
,
'sign'
,
'sigmoid'
,
'sigmoid_'
,
'round'
,
'round_'
,
'floor'
,
'floor_'
,
'ceil'
,
'ceil_'
,
]
func_treelize
=
post_process
(
post_process
(
args_mapping
(
...
...
@@ -1039,3 +1041,144 @@ def isnan(input):
[False, False, True]])
"""
return
torch
.
isnan
(
input
)
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
abs
)
@
func_treelize
()
def
abs
(
input
,
*
args
,
**
kwargs
):
return
torch
.
abs
(
input
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
abs_
)
@
return_self
@
func_treelize
()
def
abs_
(
input
):
return
torch
.
abs_
(
input
)
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
clamp
)
@
func_treelize
()
def
clamp
(
input
,
*
args
,
**
kwargs
):
return
torch
.
clamp
(
input
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
clamp_
)
@
return_self
@
func_treelize
()
def
clamp_
(
input
,
*
args
,
**
kwargs
):
return
torch
.
clamp_
(
input
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
sign
)
@
func_treelize
()
def
sign
(
input
,
*
args
,
**
kwargs
):
return
torch
.
sign
(
input
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
round
)
@
func_treelize
()
def
round
(
input
,
*
args
,
**
kwargs
):
return
torch
.
round
(
input
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
round_
)
@
return_self
@
func_treelize
()
def
round_
(
input
):
return
torch
.
round_
(
input
)
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
floor
)
@
func_treelize
()
def
floor
(
input
,
*
args
,
**
kwargs
):
return
torch
.
floor
(
input
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
floor_
)
@
return_self
@
func_treelize
()
def
floor_
(
input
):
return
torch
.
floor_
(
input
)
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
ceil
)
@
func_treelize
()
def
ceil
(
input
,
*
args
,
**
kwargs
):
return
torch
.
ceil
(
input
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
ceil_
)
@
return_self
@
func_treelize
()
def
ceil_
(
input
):
return
torch
.
ceil_
(
input
)
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
sigmoid
)
@
func_treelize
()
def
sigmoid
(
input
,
*
args
,
**
kwargs
):
"""
Get a tree of new tensors with the sigmoid of the elements of ``input``.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.tensor([1.0, 2.0, -1.5]).sigmoid()
tensor([0.7311, 0.8808, 0.1824])
>>> ttorch.tensor({
... 'a': [1.0, 2.0, -1.5],
... 'b': {'x': [[0.5, 1.2], [-2.5, 0.25]]},
... }).sigmoid()
<Tensor 0x7f973a312820>
├── a --> tensor([0.7311, 0.8808, 0.1824])
└── b --> <Tensor 0x7f973a3128b0>
└── x --> tensor([[0.6225, 0.7685],
[0.0759, 0.5622]])
"""
return
torch
.
sigmoid
(
input
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
sigmoid_
)
@
return_self
@
func_treelize
()
def
sigmoid_
(
input
):
"""
In-place version of :func:`treetensor.torch.sigmoid`.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> t = ttorch.tensor([1.0, 2.0, -1.5])
>>> ttorch.sigmoid_(t)
>>> t
tensor([0.7311, 0.8808, 0.1824])
>>> t = ttorch.tensor({
... 'a': [1.0, 2.0, -1.5],
... 'b': {'x': [[0.5, 1.2], [-2.5, 0.25]]},
... })
>>> ttorch.sigmoid_(t)
>>> t
<Tensor 0x7f68fea8d040>
├── a --> tensor([0.7311, 0.8808, 0.1824])
└── b --> <Tensor 0x7f68fea8ee50>
└── x --> tensor([[0.6225, 0.7685],
[0.0759, 0.5622]])
"""
return
torch
.
sigmoid_
(
input
)
treetensor/torch/tensor.py
浏览文件 @
eb3ec123
...
...
@@ -5,7 +5,7 @@ from treevalue.utils import pre_process
from
.base
import
Torch
from
.size
import
Size
from
..common
import
Object
,
ireduce
,
clsmeta
from
..common
import
Object
,
ireduce
,
clsmeta
,
return_self
from
..numpy
import
ndarray
from
..utils
import
current_names
,
doc_from
...
...
@@ -317,3 +317,113 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
See :func:`treetensor.torch.isnan`.
"""
return
self
.
isnan
()
@
doc_from
(
torch
.
Tensor
.
abs
)
@
method_treelize
()
def
abs
(
self
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.abs`.
"""
return
self
.
abs
(
*
args
,
**
kwargs
)
@
doc_from
(
torch
.
Tensor
.
abs_
)
@
return_self
@
method_treelize
()
def
abs_
(
self
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.abs_`.
"""
return
self
.
abs_
(
*
args
,
**
kwargs
)
@
doc_from
(
torch
.
Tensor
.
clamp
)
@
method_treelize
()
def
clamp
(
self
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.clamp`.
"""
return
self
.
clamp
(
*
args
,
**
kwargs
)
@
doc_from
(
torch
.
Tensor
.
clamp_
)
@
return_self
@
method_treelize
()
def
clamp_
(
self
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.clamp_`.
"""
return
self
.
clamp_
(
*
args
,
**
kwargs
)
@
doc_from
(
torch
.
Tensor
.
sign
)
@
method_treelize
()
def
sign
(
self
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.sign`.
"""
return
self
.
sign
(
*
args
,
**
kwargs
)
@
doc_from
(
torch
.
Tensor
.
sigmoid
)
@
method_treelize
()
def
sigmoid
(
self
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.sigmoid`.
"""
return
self
.
sigmoid
(
*
args
,
**
kwargs
)
@
doc_from
(
torch
.
Tensor
.
sigmoid_
)
@
return_self
@
method_treelize
()
def
sigmoid_
(
self
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.sigmoid_`.
"""
return
self
.
sigmoid_
(
*
args
,
**
kwargs
)
@
doc_from
(
torch
.
Tensor
.
floor
)
@
method_treelize
()
def
floor
(
self
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.floor`.
"""
return
self
.
floor
(
*
args
,
**
kwargs
)
@
doc_from
(
torch
.
Tensor
.
floor_
)
@
return_self
@
method_treelize
()
def
floor_
(
self
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.floor_`.
"""
return
self
.
floor_
(
*
args
,
**
kwargs
)
@
doc_from
(
torch
.
Tensor
.
ceil
)
@
method_treelize
()
def
ceil
(
self
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.ceil`.
"""
return
self
.
ceil
(
*
args
,
**
kwargs
)
@
doc_from
(
torch
.
Tensor
.
ceil_
)
@
return_self
@
method_treelize
()
def
ceil_
(
self
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.ceil_`.
"""
return
self
.
ceil_
(
*
args
,
**
kwargs
)
@
doc_from
(
torch
.
Tensor
.
round
)
@
method_treelize
()
def
round
(
self
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.round`.
"""
return
self
.
round
(
*
args
,
**
kwargs
)
@
doc_from
(
torch
.
Tensor
.
round_
)
@
return_self
@
method_treelize
()
def
round_
(
self
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.round_`.
"""
return
self
.
round_
(
*
args
,
**
kwargs
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录