Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-treetensor
提交
6a3a7dd0
D
DI-treetensor
项目概览
OpenDILab开源决策智能平台
/
DI-treetensor
大约 1 年 前同步成功
通知
43
Star
172
Fork
11
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-treetensor
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
6a3a7dd0
编写于
8月 14, 2022
作者:
HansBug
😆
1
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dev(hansbug): add support for CUDA stream test
上级
8178b497
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
68 addition
and
2 deletion
+68
-2
test/torch/test_stream.py
test/torch/test_stream.py
+61
-0
treetensor/torch/stream.py
treetensor/torch/stream.py
+7
-2
未找到文件。
test/torch/test_stream.py
0 → 100644
浏览文件 @
6a3a7dd0
import
unittest
import
pytest
import
torch
import
torch.cuda
import
treetensor.torch
as
ttorch
_CUDA_OK
=
torch
.
cuda
.
is_available
()
N
,
M
,
T
=
200
,
2
,
50
S1
,
S2
,
S3
=
32
,
128
,
512
# noinspection DuplicatedCode
@
pytest
.
mark
.
unittest
class
TestTorchStream
:
def
test_simple
(
self
):
a
=
ttorch
.
randn
({
f
'a
{
i
}
'
:
(
S1
,
S2
)
for
i
in
range
(
N
)})
b
=
ttorch
.
randn
({
f
'a
{
i
}
'
:
(
S2
,
S3
)
for
i
in
range
(
N
)})
c
=
ttorch
.
matmul
(
a
,
b
)
for
i
in
range
(
N
):
assert
torch
.
isclose
(
c
[
f
'a
{
i
}
'
],
torch
.
matmul
(
a
[
f
'a
{
i
}
'
],
b
[
f
'a
{
i
}
'
])
).
all
(),
f
'Not match on item
{
f
"a
{
i
}
"!r
}
.'
@
unittest
.
skipUnless
(
_CUDA_OK
,
'CUDA required'
)
def
test_simple_with_cuda
(
self
):
a
=
ttorch
.
randn
({
f
'a
{
i
}
'
:
(
S1
,
S2
)
for
i
in
range
(
N
)},
device
=
'cuda'
)
b
=
ttorch
.
randn
({
f
'a
{
i
}
'
:
(
S2
,
S3
)
for
i
in
range
(
N
)},
device
=
'cuda'
)
torch
.
cuda
.
synchronize
()
c
=
ttorch
.
matmul
(
a
,
b
)
torch
.
cuda
.
synchronize
()
for
i
in
range
(
N
):
assert
torch
.
isclose
(
c
[
f
'a
{
i
}
'
],
torch
.
matmul
(
a
[
f
'a
{
i
}
'
],
b
[
f
'a
{
i
}
'
])
).
all
(),
f
'Not match on item
{
f
"a
{
i
}
"!r
}
.'
@
unittest
.
skipUnless
(
not
_CUDA_OK
,
'No CUDA required'
)
def
test_stream_without_cuda
(
self
):
with
pytest
.
raises
(
AssertionError
):
ttorch
.
stream
(
10
)
@
unittest
.
skipUnless
(
_CUDA_OK
,
'CUDA required'
)
def
test_stream_with_cuda
(
self
):
a
=
ttorch
.
randn
({
f
'a
{
i
}
'
:
(
S1
,
S2
)
for
i
in
range
(
N
)},
device
=
'cuda'
)
b
=
ttorch
.
randn
({
f
'a
{
i
}
'
:
(
S2
,
S3
)
for
i
in
range
(
N
)},
device
=
'cuda'
)
ttorch
.
stream
(
4
)
torch
.
cuda
.
synchronize
()
c
=
ttorch
.
matmul
(
a
,
b
)
torch
.
cuda
.
synchronize
()
for
i
in
range
(
N
):
assert
torch
.
isclose
(
c
[
f
'a
{
i
}
'
],
torch
.
matmul
(
a
[
f
'a
{
i
}
'
],
b
[
f
'a
{
i
}
'
])
).
all
(),
f
'Not match on item
{
f
"a
{
i
}
"!r
}
.'
treetensor/torch/stream.py
浏览文件 @
6a3a7dd0
import
random
import
itertools
from
typing
import
Optional
,
List
import
torch
...
...
@@ -27,9 +27,14 @@ def stream(cnt):
_global_streams
=
_stream_pool
[:
cnt
]
_stream_count
=
itertools
.
count
()
def
stream_call
(
func
,
*
args
,
**
kwargs
):
if
_global_streams
is
not
None
:
with
torch
.
cuda
.
stream
(
random
.
choice
(
_global_streams
)):
_stream_index
=
next
(
_stream_count
)
%
len
(
_global_streams
)
_stream
=
_global_streams
[
_stream_index
]
with
torch
.
cuda
.
stream
(
_stream
):
return
func
(
*
args
,
**
kwargs
)
else
:
return
func
(
*
args
,
**
kwargs
)
HansBug
😆
@HansBug
mentioned in commit
abb72c1c
·
9月 20, 2022
mentioned in commit
abb72c1c
mentioned in commit abb72c1c1a8e4b3e9ba8a7cf206acd3d05abab6d
开关提交列表
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录