Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
flybirding10011
DI-treetensor
提交
f45bd500
D
DI-treetensor
项目概览
flybirding10011
/
DI-treetensor
与 Fork 源项目一致
Fork自
OpenDILab开源决策智能平台 / DI-treetensor
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-treetensor
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
f45bd500
编写于
9月 16, 2021
作者:
HansBug
😆
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dev(hansbug): add __str__ and __repr__ support for Tensor
上级
8c14515a
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
170 addition
and
7 deletion
+170
-7
treetensor/common/trees.py
treetensor/common/trees.py
+81
-3
treetensor/torch/funcs.py
treetensor/torch/funcs.py
+89
-4
未找到文件。
treetensor/common/trees.py
浏览文件 @
f45bd500
import
builtins
import
io
import
os
from
abc
import
ABCMeta
from
functools
import
partial
from
typing
import
Optional
,
Tuple
,
Callable
from
treevalue
import
general_tree_value
from
treevalue
import
general_tree_value
,
TreeValue
from
treevalue.tree.tree.tree
import
get_data_property
__all__
=
[
'BaseTreeStruct'
,
"TreeObject"
,
'BaseTreeStruct'
,
"TreeObject"
,
'print_tree'
,
]
def
_tree_title
(
node
:
TreeValue
):
_tree
=
get_data_property
(
node
)
return
"<{cls} {id}>"
.
format
(
cls
=
node
.
__class__
.
__name__
,
id
=
hex
(
id
(
_tree
.
actual
())),
)
def
print_tree
(
tree
:
TreeValue
,
repr_
:
Callable
=
str
,
ascii_
:
bool
=
False
,
file
=
None
):
print_to_file
=
partial
(
builtins
.
print
,
file
=
file
)
node_ids
=
{}
if
ascii_
:
_HORI
,
_VECT
,
_CROS
,
_SROS
=
'|'
,
'-'
,
'+'
,
'+'
else
:
_HORI
,
_VECT
,
_CROS
,
_SROS
=
'
\u2502
'
,
'
\u2500
'
,
'
\u251c
'
,
'
\u2514
'
def
_print_layer
(
node
,
path
:
Tuple
[
str
,
...],
prefixes
:
Tuple
[
str
,
...],
current_key
:
Optional
[
str
]
=
None
,
is_last_key
:
bool
=
True
):
# noinspection PyShadowingBuiltins
def
print
(
*
args
,
pid
:
Optional
[
int
]
=
-
1
,
**
kwargs
,
):
if
pid
is
not
None
:
print_to_file
(
prefixes
[
pid
],
end
=
''
)
print_to_file
(
*
args
,
**
kwargs
)
_need_iter
=
True
if
isinstance
(
node
,
TreeValue
):
_node_id
=
id
(
get_data_property
(
node
).
actual
())
_content
=
f
'<
{
node
.
__class__
.
__name__
}
{
hex
(
_node_id
)
}
>'
if
_node_id
in
node_ids
.
keys
():
_str_old_path
=
'.'
.
join
((
'<root>'
,
*
node_ids
[
_node_id
]))
_content
=
f
'
{
_content
}{
os
.
linesep
}
(The same address as
{
_str_old_path
}
)'
_need_iter
=
False
else
:
node_ids
[
_node_id
]
=
path
_need_iter
=
True
else
:
_content
=
repr_
(
node
)
_need_iter
=
False
if
current_key
:
_key_arrow
=
f
'
{
current_key
}
--> '
_appended_prefix
=
(
_HORI
if
_need_iter
and
len
(
node
)
>
0
else
' '
)
+
' '
*
(
len
(
_key_arrow
)
-
1
)
for
index
,
line
in
enumerate
(
_content
.
splitlines
()):
if
index
==
0
:
print
(
f
'
{
_CROS
if
not
is_last_key
else
_SROS
}{
_VECT
*
2
}
{
_key_arrow
}
'
,
pid
=-
2
,
end
=
''
)
else
:
print
(
_appended_prefix
,
end
=
''
)
print
(
line
,
pid
=
None
)
else
:
print
(
_content
)
if
_need_iter
:
_length
=
len
(
node
)
for
index
,
(
key
,
value
)
in
enumerate
(
sorted
(
node
)):
_is_last_line
=
index
+
1
>=
_length
_new_prefixes
=
(
*
prefixes
,
prefixes
[
-
1
]
+
f
'
{
_HORI
if
not
_is_last_line
else
" "
}
'
)
_new_path
=
(
*
path
,
key
)
_print_layer
(
value
,
_new_path
,
_new_prefixes
,
key
,
_is_last_line
)
if
isinstance
(
tree
,
TreeValue
):
_print_layer
(
tree
,
(),
(
''
,
''
,))
else
:
print
(
repr_
(
tree
),
file
=
file
)
class
BaseTreeStruct
(
general_tree_value
(),
metaclass
=
ABCMeta
):
"""
Overview:
Base structure of all the trees in ``treetensor``.
"""
pass
def
__repr__
(
self
):
with
io
.
StringIO
()
as
sfile
:
print_tree
(
self
,
repr_
=
repr
,
ascii_
=
False
,
file
=
sfile
)
return
sfile
.
getvalue
()
def
__str__
(
self
):
return
self
.
__repr__
()
class
TreeObject
(
BaseTreeStruct
):
...
...
treetensor/torch/funcs.py
浏览文件 @
f45bd500
...
...
@@ -23,8 +23,8 @@ __all__ = [
'empty'
,
'empty_like'
,
'all'
,
'any'
,
'min'
,
'max'
,
'sum'
,
'eq'
,
'
equal
'
,
'tensor'
,
'eq'
,
'
ne'
,
'lt'
,
'le'
,
'gt'
,
'ge
'
,
'
equal'
,
'
tensor'
,
]
func_treelize
=
post_process
(
post_process
(
args_mapping
(
...
...
@@ -446,15 +446,100 @@ def sum(input, *args, **kwargs):
@
doc_from
(
torch
.
eq
)
@
func_treelize
()
def
eq
(
input
,
other
,
*
args
,
**
kwargs
):
"""
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.eq(
>>> torch.tensor([[1, 2], [3, 4]]),
>>> torch.tensor([[1, 1], [4, 4]]),
>>> )
torch.tensor([[ True, False],
[False, True]])
>>> ttorch.eq(
>>> ttorch.tensor({
>>> 'a': [[1, 2], [3, 4]],
>>> 'b': [1.0, 1.5, 2.0],
>>> }),
>>> ttorch.tensor({
>>> 'a': [[1, 1], [4, 4]],
>>> 'b': [1.3, 1.2, 2.0],
>>> }),
>>> )
"""
return
torch
.
eq
(
input
,
other
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
ne
)
@
func_treelize
()
def
ne
(
input
,
other
,
*
args
,
**
kwargs
):
return
torch
.
ne
(
input
,
other
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
lt
)
@
func_treelize
()
def
lt
(
input
,
other
,
*
args
,
**
kwargs
):
return
torch
.
lt
(
input
,
other
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
le
)
@
func_treelize
()
def
le
(
input
,
other
,
*
args
,
**
kwargs
):
return
torch
.
le
(
input
,
other
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
gt
)
@
func_treelize
()
def
gt
(
input
,
other
,
*
args
,
**
kwargs
):
return
torch
.
gt
(
input
,
other
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
ge
)
@
func_treelize
()
def
ge
(
input
,
other
,
*
args
,
**
kwargs
):
return
torch
.
ge
(
input
,
other
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins,PyArgumentList
@
doc_from
(
torch
.
equal
)
@
ireduce
(
builtins
.
all
)
@
func_treelize
()
def
equal
(
input
,
other
,
*
args
,
**
kwargs
):
return
torch
.
equal
(
input
,
other
,
*
args
,
**
kwargs
)
def
equal
(
input
,
other
):
"""
In ``treetensor``, you can get the equality of the two tree tensors.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.equal(
>>> torch.tensor([1, 2, 3]),
>>> torch.tensor([1, 2, 3]),
>>> ) # the same as torch.equal
True
>>> ttorch.equal(
>>> ttorch.tensor({
>>> 'a': torch.tensor([1, 2, 3]),
>>> 'b': torch.tensor([[4, 5], [6, 7]]),
>>> }),
>>> ttorch.tensor({
>>> 'a': torch.tensor([1, 2, 3]),
>>> 'b': torch.tensor([[4, 5], [6, 7]]),
>>> }),
>>> )
True
"""
return
torch
.
equal
(
input
,
other
)
@
doc_from
(
torch
.
tensor
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录