Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-treetensor
提交
455fd137
D
DI-treetensor
项目概览
OpenDILab开源决策智能平台
/
DI-treetensor
9 个月 前同步成功
通知
40
Star
172
Fork
11
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-treetensor
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
455fd137
编写于
2月 13, 2023
作者:
HansBug
😆
1
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dev(hansbug): add pshape
上级
7d50bf4e
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
153 addition
and
2 deletion
+153
-2
requirements.txt
requirements.txt
+1
-1
test/common/constraints/test_shape.py
test/common/constraints/test_shape.py
+11
-0
test/torch/constraints/__init__.py
test/torch/constraints/__init__.py
+0
-0
test/torch/constraints/test_shape.py
test/torch/constraints/test_shape.py
+97
-0
treetensor/common/constraints/shape.py
treetensor/common/constraints/shape.py
+8
-1
treetensor/torch/__init__.py
treetensor/torch/__init__.py
+3
-0
treetensor/torch/constraints/__init__.py
treetensor/torch/constraints/__init__.py
+6
-0
treetensor/torch/constraints/shape.py
treetensor/torch/constraints/shape.py
+16
-0
treetensor/torch/tensor.py
treetensor/torch/tensor.py
+11
-0
未找到文件。
requirements.txt
浏览文件 @
455fd137
treevalue
>=1.4.
3
treevalue
>=1.4.
5
torch
>=1.1.0,<=1.12.1
hbutils
>=0.6.13
numpy
test/common/constraints/test_shape.py
浏览文件 @
455fd137
...
...
@@ -22,6 +22,17 @@ class TestCommonConstraintsShape:
assert
c1
.
prefix
==
(
2
,
3
,
4
)
assert
repr
(
c1
)
==
'<ShapePrefixConstraint (2, 3, 4)>'
assert
len
(
c1
)
==
3
assert
c1
[
0
]
==
2
assert
c1
[
1
]
==
3
assert
c1
[
2
]
==
4
with
pytest
.
raises
(
IndexError
):
_
=
c1
[
3
]
assert
c1
[
-
1
]
==
4
assert
c1
[
-
2
]
==
3
assert
c1
[
-
3
]
==
2
assert
c1
[
1
:]
==
(
3
,
4
)
c1
.
validate
(
np
.
random
.
rand
(
2
,
3
,
4
))
c1
.
validate
(
np
.
random
.
rand
(
2
,
3
,
4
,
5
))
with
pytest
.
raises
(
ValueError
):
...
...
test/torch/constraints/__init__.py
0 → 100644
浏览文件 @
455fd137
test/torch/constraints/test_shape.py
0 → 100644
浏览文件 @
455fd137
import
numpy
as
np
import
pytest
import
torch
import
treetensor.torch
as
ttorch
from
treetensor.torch
import
TensorShapePrefixConstraint
,
shape_prefix
# noinspection DuplicatedCode
@
pytest
.
mark
.
unittest
class
TestCommonConstraintsShape
:
def
test_shape_prefix
(
self
):
c1
=
shape_prefix
(
2
,
3
,
4
)
assert
isinstance
(
c1
,
TensorShapePrefixConstraint
)
assert
c1
.
prefix
==
(
2
,
3
,
4
)
assert
repr
(
c1
)
==
'<TensorShapePrefixConstraint (2, 3, 4)>'
assert
len
(
c1
)
==
3
assert
c1
[
0
]
==
2
assert
c1
[
1
]
==
3
assert
c1
[
2
]
==
4
with
pytest
.
raises
(
IndexError
):
_
=
c1
[
3
]
assert
c1
[
-
1
]
==
4
assert
c1
[
-
2
]
==
3
assert
c1
[
-
3
]
==
2
assert
c1
[
1
:]
==
(
3
,
4
)
with
pytest
.
raises
(
TypeError
):
c1
.
validate
(
np
.
random
.
rand
(
2
,
3
,
4
))
with
pytest
.
raises
(
TypeError
):
c1
.
validate
(
np
.
random
.
rand
(
2
,
3
,
4
,
5
))
with
pytest
.
raises
(
TypeError
):
c1
.
validate
(
np
.
random
.
rand
(
2
,
3
))
with
pytest
.
raises
(
TypeError
):
c1
.
validate
(
np
.
random
.
rand
(
2
,
3
,
3
))
with
pytest
.
raises
(
TypeError
):
c1
.
validate
(
np
.
random
.
rand
(
2
,
3
,
3
,
4
))
with
pytest
.
raises
(
TypeError
):
c1
.
validate
([
2
,
3
,
4
,
5
])
c1
.
validate
(
torch
.
randn
(
2
,
3
,
4
))
c1
.
validate
(
torch
.
randn
(
2
,
3
,
4
,
5
))
with
pytest
.
raises
(
ValueError
):
c1
.
validate
(
torch
.
randn
(
2
,
3
))
with
pytest
.
raises
(
ValueError
):
c1
.
validate
(
torch
.
randn
(
2
,
3
,
3
))
with
pytest
.
raises
(
ValueError
):
c1
.
validate
(
torch
.
randn
(
2
,
3
,
3
,
4
))
with
pytest
.
raises
(
TypeError
):
c1
.
validate
([
2
,
3
,
4
,
5
])
assert
c1
==
shape_prefix
(
2
,
3
,
4
)
assert
not
c1
!=
shape_prefix
(
2
,
3
,
4
)
assert
c1
>=
shape_prefix
(
2
,
3
,
4
)
assert
c1
<=
shape_prefix
(
2
,
3
,
4
)
assert
not
c1
>
shape_prefix
(
2
,
3
,
4
)
assert
not
c1
<
shape_prefix
(
2
,
3
,
4
)
assert
not
c1
==
shape_prefix
(
2
,
3
)
assert
c1
!=
shape_prefix
(
2
,
3
)
assert
c1
>=
shape_prefix
(
2
,
3
)
assert
not
c1
<=
shape_prefix
(
2
,
3
)
assert
c1
>
shape_prefix
(
2
,
3
)
assert
not
c1
<
shape_prefix
(
2
,
3
)
assert
not
c1
==
shape_prefix
(
2
,
3
,
4
,
5
)
assert
c1
!=
shape_prefix
(
2
,
3
,
4
,
5
)
assert
not
c1
>=
shape_prefix
(
2
,
3
,
4
,
5
)
assert
c1
<=
shape_prefix
(
2
,
3
,
4
,
5
)
assert
not
c1
>
shape_prefix
(
2
,
3
,
4
,
5
)
assert
c1
<
shape_prefix
(
2
,
3
,
4
,
5
)
assert
not
c1
==
shape_prefix
(
2
,
3
,
3
)
assert
c1
!=
shape_prefix
(
2
,
3
,
3
)
assert
not
c1
>=
shape_prefix
(
2
,
3
,
3
)
assert
not
c1
<=
shape_prefix
(
2
,
3
,
3
)
assert
not
c1
>
shape_prefix
(
2
,
3
,
3
)
assert
not
c1
<
shape_prefix
(
2
,
3
,
3
)
assert
not
c1
>=
np
.
ndarray
assert
not
c1
>
np
.
ndarray
assert
c1
>=
torch
.
Tensor
assert
c1
>
torch
.
Tensor
def
test_pshape
(
self
):
tt
=
ttorch
.
tensor
({
'a'
:
[[
0.8479
,
1.0074
,
0.2725
],
[
1.1674
,
1.0784
,
0.0655
]],
'b'
:
{
'x'
:
[[
0.2644
,
0.7268
,
0.2781
,
0.6469
],
[
2.0015
,
0.4448
,
0.8814
,
1.0063
],
[
0.1847
,
0.5864
,
0.4417
,
0.2117
]]},
})
assert
tt
.
pshape
is
None
tt2
=
tt
.
with_constraints
(
shape_prefix
(
2
,
3
),
clear
=
False
)
assert
tt2
.
pshape
==
(
2
,
3
)
treetensor/common/constraints/shape.py
浏览文件 @
455fd137
from
collections.abc
import
Sequence
from
typing
import
Type
,
TypeVar
,
Optional
from
treevalue.tree
import
ValueConstraint
...
...
@@ -9,7 +10,7 @@ __all__ = [
]
class
ShapePrefixConstraint
(
ValueConstraint
):
class
ShapePrefixConstraint
(
ValueConstraint
,
Sequence
):
__type__
:
Optional
[
type
]
=
None
def
__init__
(
self
,
*
prefix
):
...
...
@@ -20,6 +21,12 @@ class ShapePrefixConstraint(ValueConstraint):
def
prefix
(
self
):
return
self
.
__prefix
def
__getitem__
(
self
,
index
):
return
self
.
__prefix
[
index
]
def
__len__
(
self
)
->
int
:
return
len
(
self
.
__prefix
)
def
_validate_value
(
self
,
instance
):
if
self
.
__type__
and
not
isinstance
(
instance
,
self
.
__type__
):
raise
TypeError
(
f
'Invalid type,
{
self
.
__type__
.
__name__
!
r
}
expected but
{
instance
!
r
}
found.'
)
...
...
treetensor/torch/__init__.py
浏览文件 @
455fd137
...
...
@@ -5,6 +5,8 @@ from typing import Iterable
import
torch
from
.constraints
import
*
from
.constraints
import
__all__
as
_constraints_all
from
.funcs
import
*
from
.funcs
import
__all__
as
_funcs_all
from
.funcs.base
import
get_func_from_torch
...
...
@@ -17,6 +19,7 @@ from .tensor import __all__ as _tensor_all
from
..config.meta
import
__VERSION__
__all__
=
[
*
_constraints_all
,
*
_funcs_all
,
*
_size_all
,
*
_tensor_all
,
...
...
treetensor/torch/constraints/__init__.py
0 → 100644
浏览文件 @
455fd137
from
.shape
import
*
from
.shape
import
__all__
as
_shape_all
__all__
=
[
*
_shape_all
]
treetensor/torch/constraints/shape.py
0 → 100644
浏览文件 @
455fd137
import
torch
from
...common.constraints
import
ShapePrefixConstraint
from
...common.constraints
import
shape_prefix
as
_origin_shape_prefix
__all__
=
[
'TensorShapePrefixConstraint'
,
'shape_prefix'
,
]
class
TensorShapePrefixConstraint
(
ShapePrefixConstraint
):
__type__
=
torch
.
Tensor
def
shape_prefix
(
*
shape
):
return
_origin_shape_prefix
(
*
shape
,
type_
=
TensorShapePrefixConstraint
)
treetensor/torch/tensor.py
浏览文件 @
455fd137
from
typing
import
Tuple
,
Optional
import
numpy
as
np
import
torch
as
pytorch
from
hbutils.reflection
import
post_process
from
treevalue
import
method_treelize
,
TreeValue
,
typetrans
from
.base
import
Torch
,
rmreduce
,
post_reduce
,
auto_reduce
from
.constraints
import
TensorShapePrefixConstraint
from
.size
import
Size
from
.stream
import
stream_call
from
..common
import
Object
,
ireduce
,
clsmeta
,
return_self
,
auto_tree
,
get_tree_proxy
...
...
@@ -116,6 +119,14 @@ class Tensor(Torch, metaclass=_TensorMeta):
else
:
return
tree
@
property
def
pshape
(
self
)
->
Optional
[
Tuple
[
int
,
...]]:
constraint
=
self
.
constraint
.
access_first
(
TensorShapePrefixConstraint
)
if
constraint
:
return
constraint
.
prefix
else
:
return
None
@
property
def
torch
(
self
):
"""
...
...
HansBug
😆
@HansBug
mentioned in commit
9dd4461b
·
2月 15, 2023
mentioned in commit
9dd4461b
mentioned in commit 9dd4461b29e524f2b65bd40c984311f09b48dc9d
开关提交列表
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录