Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-treetensor
提交
e756aaba
D
DI-treetensor
项目概览
OpenDILab开源决策智能平台
/
DI-treetensor
大约 1 年 前同步成功
通知
44
Star
172
Fork
11
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-treetensor
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
e756aaba
编写于
3月 10, 2022
作者:
HansBug
😆
提交者:
GitHub
3月 10, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #6 from opendilab/dev/np2tensor
dev(hansbug): add tensor method for treetensor.numpy.ndarray
上级
e18010fe
bb3d9bab
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
37 addition
and
0 deletion
+37
-0
test/numpy/test_array.py
test/numpy/test_array.py
+21
-0
treetensor/numpy/array.py
treetensor/numpy/array.py
+16
-0
未找到文件。
test/numpy/test_array.py
浏览文件 @
e756aaba
import
numpy
as
np
import
pytest
import
torch
import
treetensor.numpy
as
tnp
import
treetensor.torch
as
ttorch
from
treetensor.common
import
Object
...
...
@@ -233,3 +235,22 @@ class TestNumpyArray:
'd'
:
[
0
,
0
,
0.0
],
}
})
def
test_tensor
(
self
):
assert
(
self
.
_DEMO_1
.
tensor
()
==
ttorch
.
Tensor
({
'a'
:
ttorch
.
Tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]),
'b'
:
ttorch
.
Tensor
([
1
,
3
,
5
,
7
]),
'x'
:
{
'c'
:
ttorch
.
Tensor
([[
11
],
[
23
]]),
'd'
:
ttorch
.
Tensor
([
3
,
9
,
11.0
])
}
})).
all
()
assert
(
self
.
_DEMO_1
.
tensor
(
dtype
=
torch
.
float64
)
==
ttorch
.
Tensor
({
'a'
:
ttorch
.
Tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
torch
.
float64
),
'b'
:
ttorch
.
Tensor
([
1
,
3
,
5
,
7
],
dtype
=
torch
.
float64
),
'x'
:
{
'c'
:
ttorch
.
Tensor
([[
11
],
[
23
]],
dtype
=
torch
.
float64
),
'd'
:
ttorch
.
Tensor
([
3
,
9
,
11.0
],
dtype
=
torch
.
float64
),
}
})).
all
()
treetensor/numpy/array.py
浏览文件 @
e756aaba
from
functools
import
lru_cache
import
numpy
import
torch
from
treevalue
import
method_treelize
from
.base
import
TreeNumpy
...
...
@@ -12,6 +15,12 @@ __all__ = [
_ArrayProxy
,
_InstanceArrayProxy
=
get_tree_proxy
(
numpy
.
ndarray
)
@
lru_cache
()
def
_get_tensor_class
(
args0
):
from
..torch
import
Tensor
return
Tensor
(
args0
)
class
_BaseArrayMeta
(
clsmeta
(
numpy
.
asarray
,
allow_dict
=
True
)):
pass
...
...
@@ -92,6 +101,13 @@ class ndarray(TreeNumpy, metaclass=_ArrayMeta):
def
any
(
self
:
numpy
.
ndarray
,
*
args
,
**
kwargs
):
return
self
.
any
(
*
args
,
**
kwargs
)
@
method_treelize
(
return_type
=
_get_tensor_class
)
def
tensor
(
self
:
numpy
.
ndarray
,
*
args
,
**
kwargs
):
tensor_
:
torch
.
Tensor
=
torch
.
from_numpy
(
self
)
if
args
or
kwargs
:
tensor_
=
tensor_
.
to
(
*
args
,
**
kwargs
)
return
tensor_
@
method_treelize
()
def
__eq__
(
self
,
other
):
"""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录