Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
wmsofts
DI-treetensor
提交
37d38b34
D
DI-treetensor
项目概览
wmsofts
/
DI-treetensor
与 Fork 源项目一致
Fork自
OpenDILab开源决策智能平台 / DI-treetensor
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-treetensor
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
37d38b34
编写于
3月 17, 2022
作者:
HansBug
😆
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dev(hansbug): add documentation for new functions
上级
41907116
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
127 addition
and
1 deletion
+127
-1
docs/source/api_doc/numpy/funcs.rst.py
docs/source/api_doc/numpy/funcs.rst.py
+1
-1
test/numpy/test_funcs.py
test/numpy/test_funcs.py
+94
-0
treetensor/numpy/funcs.py
treetensor/numpy/funcs.py
+32
-0
未找到文件。
docs/source/api_doc/numpy/funcs.rst.py
浏览文件 @
37d38b34
...
...
@@ -73,7 +73,7 @@ with the following command and find its documentation.
print_title
(
f
"Description From Numpy v
{
_short_version
}
"
,
levelc
=
'-'
,
file
=
p_func
)
current_module
(
np
.
__name__
,
file
=
p_func
)
_origin_doc
=
_doc_process
(
_origin
.
__doc__
or
""
)
_origin_doc
=
_doc_process
(
_origin
.
__doc__
or
""
)
.
lstrip
()
_doc_lines
=
_origin_doc
.
splitlines
()
_first_line
,
_other_lines
=
_doc_lines
[
0
],
_doc_lines
[
1
:]
if
_first_line
.
strip
():
...
...
test/numpy/test_funcs.py
浏览文件 @
37d38b34
...
...
@@ -127,3 +127,97 @@ class TestNumpyFuncs:
'd'
:
True
,
}
})
def
test_zeros
(
self
):
zs
=
tnp
.
zeros
((
2
,
3
))
assert
isinstance
(
zs
,
np
.
ndarray
)
assert
np
.
allclose
(
zs
,
np
.
zeros
((
2
,
3
)))
zs
=
tnp
.
zeros
({
'a'
:
(
2
,
3
),
'c'
:
{
'x'
:
(
3
,
4
)}})
assert
tnp
.
allclose
(
zs
,
tnp
.
ndarray
({
'a'
:
np
.
zeros
((
2
,
3
)),
'c'
:
{
'x'
:
np
.
zeros
((
3
,
4
))}
}))
def
test_ones
(
self
):
zs
=
tnp
.
ones
((
2
,
3
))
assert
isinstance
(
zs
,
np
.
ndarray
)
assert
np
.
allclose
(
zs
,
np
.
ones
((
2
,
3
)))
zs
=
tnp
.
ones
({
'a'
:
(
2
,
3
),
'c'
:
{
'x'
:
(
3
,
4
)}})
assert
tnp
.
allclose
(
zs
,
tnp
.
ndarray
({
'a'
:
np
.
ones
((
2
,
3
)),
'c'
:
{
'x'
:
np
.
zeros
((
3
,
4
))}
}))
def
test_stack
(
self
):
a
=
np
.
array
([
1
,
2
,
3
])
b
=
np
.
array
([
2
,
3
,
4
])
nd
=
tnp
.
stack
((
a
,
b
))
assert
isinstance
(
nd
,
np
.
ndarray
)
assert
np
.
allclose
(
nd
,
np
.
array
([[
1
,
2
,
3
],
[
2
,
3
,
4
]]))
a
=
tnp
.
array
({
'a'
:
[
1
,
2
,
3
],
'c'
:
{
'x'
:
[
11
,
22
,
33
]},
})
b
=
tnp
.
array
({
'a'
:
[
2
,
3
,
4
],
'c'
:
{
'x'
:
[
22
,
33
,
44
]},
})
nd
=
tnp
.
stack
((
a
,
b
))
assert
tnp
.
allclose
(
nd
,
tnp
.
array
({
'a'
:
[[
1
,
2
,
3
],
[
2
,
3
,
4
]],
'c'
:
{
'x'
:
[[
11
,
22
,
33
],
[
22
,
33
,
44
]]},
}))
def
test_concatenate
(
self
):
a
=
np
.
array
([[
1
,
2
],
[
3
,
4
]])
b
=
np
.
array
([[
5
,
6
]])
nd
=
tnp
.
concatenate
((
a
,
b
),
axis
=
0
)
assert
isinstance
(
nd
,
np
.
ndarray
)
assert
np
.
allclose
(
nd
,
np
.
array
([[
1
,
2
],
[
3
,
4
],
[
5
,
6
]]))
a
=
tnp
.
array
({
'a'
:
[[
1
,
2
],
[
3
,
4
]],
'c'
:
{
'x'
:
[[
11
,
22
],
[
33
,
44
]]},
})
b
=
tnp
.
array
({
'a'
:
[[
5
,
6
]],
'c'
:
{
'x'
:
[[
55
,
66
]]},
})
nd
=
tnp
.
concatenate
((
a
,
b
),
axis
=
0
)
assert
tnp
.
allclose
(
nd
,
tnp
.
array
({
'a'
:
[[
1
,
2
],
[
3
,
4
],
[
5
,
6
]],
'c'
:
{
'x'
:
[[
11
,
22
],
[
33
,
44
],
[
55
,
66
]]},
}))
def
test_split
(
self
):
x
=
np
.
arange
(
9.0
)
ns
=
tnp
.
split
(
x
,
3
)
assert
len
(
ns
)
==
3
assert
isinstance
(
ns
[
0
],
np
.
ndarray
)
assert
np
.
allclose
(
ns
[
0
],
np
.
array
([
0.0
,
1.0
,
2.0
]))
assert
isinstance
(
ns
[
1
],
np
.
ndarray
)
assert
np
.
allclose
(
ns
[
1
],
np
.
array
([
3.0
,
4.0
,
5.0
]))
assert
isinstance
(
ns
[
2
],
np
.
ndarray
)
assert
np
.
allclose
(
ns
[
2
],
np
.
array
([
6.0
,
7.0
,
8.0
]))
xx
=
tnp
.
arange
(
tnp
.
ndarray
({
'a'
:
9.0
,
'c'
:
{
'x'
:
18.0
}}))
ns
=
tnp
.
split
(
xx
,
3
)
assert
len
(
ns
)
==
3
assert
tnp
.
allclose
(
ns
[
0
],
tnp
.
array
({
'a'
:
[
0.0
,
1.0
,
2.0
],
'c'
:
{
'x'
:
[
0.0
,
1.0
,
2.0
,
3.0
,
4.0
,
5.0
]},
}))
assert
tnp
.
allclose
(
ns
[
1
],
tnp
.
array
({
'a'
:
[
3.0
,
4.0
,
5.0
],
'c'
:
{
'x'
:
[
6.0
,
7.0
,
8.0
,
9.0
,
10.0
,
11.0
]},
}))
assert
tnp
.
allclose
(
ns
[
2
],
tnp
.
array
({
'a'
:
[
6.0
,
7.0
,
8.0
],
'c'
:
{
'x'
:
[
12.0
,
13.0
,
14.0
,
15.0
,
16.0
,
17.0
]},
}))
treetensor/numpy/funcs.py
浏览文件 @
37d38b34
...
...
@@ -13,6 +13,8 @@ from ..utils import replaceable_partial, doc_from, args_mapping
__all__
=
[
'all'
,
'any'
,
'array'
,
'equal'
,
'array_equal'
,
'stack'
,
'concatenate'
,
'split'
,
'zeros'
,
'ones'
,
]
func_treelize
=
post_process
(
post_process
(
args_mapping
(
...
...
@@ -71,3 +73,33 @@ def array(p_object, *args, **kwargs):
})
"""
return
np
.
array
(
p_object
,
*
args
,
**
kwargs
)
@
doc_from
(
np
.
stack
)
@
func_treelize
(
subside
=
True
)
def
stack
(
arrays
,
*
args
,
**
kwargs
):
return
np
.
stack
(
arrays
,
*
args
,
**
kwargs
)
@
doc_from
(
np
.
concatenate
)
@
func_treelize
(
subside
=
True
)
def
concatenate
(
arrays
,
*
args
,
**
kwargs
):
return
np
.
concatenate
(
arrays
,
*
args
,
**
kwargs
)
@
doc_from
(
np
.
split
)
@
func_treelize
(
rise
=
True
)
def
split
(
ary
,
*
args
,
**
kwargs
):
return
np
.
split
(
ary
,
*
args
,
**
kwargs
)
@
doc_from
(
np
.
zeros
)
@
func_treelize
()
def
zeros
(
shape
,
*
args
,
**
kwargs
):
return
np
.
zeros
(
shape
,
*
args
,
**
kwargs
)
@
doc_from
(
np
.
ones
)
@
func_treelize
()
def
ones
(
shape
,
*
args
,
**
kwargs
):
return
np
.
ones
(
shape
,
*
args
,
**
kwargs
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录