Commit cf5ea891 (unverified)
Authored Apr 22, 2021 by Olatunji Ruwase; committed by GitHub on Apr 22, 2021
Parent: 669028f0

Add nvme unit/perf tests (#993)
4 changed files, with 356 additions and 24 deletions:

- csrc/aio/py_test/ds_aio_basic.py (+15, -15)
- csrc/aio/py_test/ds_aio_handle.py (+6, -6)
- csrc/aio/py_test/test_ds_aio.py (+0, -3)
- tests/unit/test_aio.py (+335, -0, new file)
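The common thread across all four files is that the async I/O bindings are now obtained by loading the op at run time via AsyncIOBuilder().load(), instead of importing aio_read/aio_write/aio_handle from deepspeed.ops.aio directly. Below is a minimal sketch of that pattern, using only calls that appear in this commit; the file path and sizes are illustrative choices, not part of the diff.

```python
import os
import torch
import deepspeed
from deepspeed.ops.aio import AsyncIOBuilder

# Guard exactly as the new unit tests do: the async_io op must be compatible on this system.
assert deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME], 'async-io op is not compatible'

# Illustrative parameters (not from the commit): 1 MB blocks, queue depth 2, 2 I/O threads.
block_size, queue_depth, io_parallel = 1024**2, 2, 2
single_submit, overlap_events = False, False

# Hypothetical 16 MB test file.
test_file = '/tmp/aio_example.bin'
with open(test_file, 'wb') as f:
    f.write(os.urandom(16 * block_size))

# JIT-load the extension and create a handle, as the perf scripts and tests now do.
aio = AsyncIOBuilder().load()
h = aio.aio_handle(block_size, queue_depth, single_submit, overlap_events, io_parallel)

# Blocking parallel read into a pinned CPU buffer; sync_pread returns 1 on success.
buffer = torch.empty(16 * block_size, dtype=torch.uint8, device='cpu').pin_memory()
assert h.sync_pread(buffer, test_file) == 1
```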
csrc/aio/py_test/ds_aio_basic.py

```diff
@@ -8,7 +8,7 @@ Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
 import torch
 import os
 import time
-from deepspeed.ops.aio import aio_read, aio_write
+from deepspeed.ops.aio import AsyncIOBuilder
 from multiprocessing import Pool, Barrier
 from test_ds_aio_utils import report_results, task_log, task_barrier
@@ -56,7 +56,7 @@ def post_basic(pool_params):
 def main_basic_read(pool_params):
     args, tid, ctxt = pool_params
     start_time = time.time()
-    aio_read(ctxt['buffer'],
+    AsyncIOBuilder().load().aio_read(ctxt['buffer'],
              ctxt['file'],
              args.block_size,
              args.queue_depth,
@@ -72,7 +72,7 @@ def main_basic_read(pool_params):
 def main_basic_write(pool_params):
     args, tid, ctxt = pool_params
     start_time = time.time()
-    aio_write(ctxt['buffer'],
+    AsyncIOBuilder().load().aio_write(ctxt['buffer'],
              ctxt['file'],
              args.block_size,
              args.queue_depth,
```
csrc/aio/py_test/ds_aio_handle.py

```diff
@@ -8,8 +8,8 @@ Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
 import torch
 import os
 import time
-from deepspeed.ops.aio import aio_handle
 from multiprocessing import Pool, Barrier
+from deepspeed.ops.aio import AsyncIOBuilder
 from test_ds_aio_utils import report_results, task_log, task_barrier
@@ -29,7 +29,7 @@ def pre_handle(args, tid, read_op):
     )
     io_parallel = args.io_parallel if args.io_parallel else 1
-    handle = aio_handle(args.block_size,
+    handle = AsyncIOBuilder().load().aio_handle(args.block_size,
                         args.queue_depth,
                         args.single_submit,
                         args.overlap_events,
```
csrc/aio/py_test/test_ds_aio.py

```diff
@@ -12,12 +12,9 @@ import time
 import sys
 from multiprocessing import Pool
 import multiprocessing as mp
-from deepspeed.ops.aio import aio_read, aio_write, aio_handle
 from ds_aio_basic import aio_basic_multiprocessing
 from ds_aio_handle import aio_handle_multiprocessing

 GB_DIVISOR = 1024**3


 def parse_arguments():
     parser = argparse.ArgumentParser()
```
tests/unit/test_aio.py (new file, mode 100755)

```python
import pytest
import os
import filecmp
import torch
import deepspeed
import torch.distributed as dist
from common import distributed_test
from deepspeed.ops.aio import AsyncIOBuilder

MEGA_BYTE = 1024**2
BLOCK_SIZE = MEGA_BYTE
QUEUE_DEPTH = 2
IO_SIZE = 16 * MEGA_BYTE
IO_PARALLEL = 2


def _skip_if_no_aio():
    if not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]:
        pytest.skip('Skip tests since async-io is not compatible')


def _do_ref_write(tmpdir, index=0):
    file_suffix = f'{dist.get_rank()}_{index}'
    ref_file = os.path.join(tmpdir, f'_py_random_{file_suffix}.pt')
    ref_buffer = os.urandom(IO_SIZE)
    with open(ref_file, 'wb') as f:
        f.write(ref_buffer)

    return ref_file, ref_buffer


def _get_test_file_and_buffer(tmpdir, ref_buffer, cuda_device, index=0):
    file_suffix = f'{dist.get_rank()}_{index}'
    test_file = os.path.join(tmpdir, f'_aio_write_random_{file_suffix}.pt')
    if cuda_device:
        test_buffer = torch.cuda.ByteTensor(list(ref_buffer))
    else:
        test_buffer = torch.ByteTensor(list(ref_buffer)).pin_memory()

    return test_file, test_buffer


def _validate_handle_state(handle, single_submit, overlap_events):
    assert handle.get_single_submit() == single_submit
    assert handle.get_overlap_events() == overlap_events
    assert handle.get_thread_count() == IO_PARALLEL
    assert handle.get_block_size() == BLOCK_SIZE
    assert handle.get_queue_depth() == QUEUE_DEPTH


@pytest.mark.parametrize('single_submit, overlap_events',
                         [(False, False), (False, True), (True, False), (True, True)])
def test_parallel_read(tmpdir, single_submit, overlap_events):
    _skip_if_no_aio()

    @distributed_test(world_size=[2])
    def _test_parallel_read(single_submit, overlap_events):
        ref_file, _ = _do_ref_write(tmpdir)

        aio_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device='cpu').pin_memory()

        h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE,
                                               QUEUE_DEPTH,
                                               single_submit,
                                               overlap_events,
                                               IO_PARALLEL)

        _validate_handle_state(h, single_submit, overlap_events)

        read_status = h.sync_pread(aio_buffer, ref_file)
        assert read_status == 1

        with open(ref_file, 'rb') as f:
            ref_buffer = list(f.read())
        assert ref_buffer == aio_buffer.tolist()

    _test_parallel_read(single_submit, overlap_events)


@pytest.mark.parametrize('single_submit, overlap_events, cuda_device',
                         [(False, False, False), (False, True, False),
                          (True, False, False), (True, True, False),
                          (False, False, True), (True, True, True)])
def test_async_read(tmpdir, single_submit, overlap_events, cuda_device):
    _skip_if_no_aio()

    @distributed_test(world_size=[2])
    def _test_async_read(single_submit, overlap_events, cuda_device):
        ref_file, _ = _do_ref_write(tmpdir)

        if cuda_device:
            aio_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device='cuda')
        else:
            aio_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device='cpu').pin_memory()

        h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE,
                                               QUEUE_DEPTH,
                                               single_submit,
                                               overlap_events,
                                               IO_PARALLEL)

        _validate_handle_state(h, single_submit, overlap_events)

        read_status = h.async_pread(aio_buffer, ref_file)
        assert read_status == 0

        wait_status = h.wait()
        assert wait_status == 1

        with open(ref_file, 'rb') as f:
            ref_buffer = list(f.read())
        assert ref_buffer == aio_buffer.tolist()

    _test_async_read(single_submit, overlap_events, cuda_device)


@pytest.mark.parametrize('single_submit, overlap_events',
                         [(False, False), (False, True), (True, False), (True, True)])
def test_parallel_write(tmpdir, single_submit, overlap_events):
    _skip_if_no_aio()

    @distributed_test(world_size=[2])
    def _test_parallel_write(single_submit, overlap_events):
        ref_file, ref_buffer = _do_ref_write(tmpdir)
        aio_file, aio_buffer = _get_test_file_and_buffer(tmpdir, ref_buffer, False)

        h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE,
                                               QUEUE_DEPTH,
                                               single_submit,
                                               overlap_events,
                                               IO_PARALLEL)

        _validate_handle_state(h, single_submit, overlap_events)

        write_status = h.sync_pwrite(aio_buffer, aio_file)
        assert write_status == 1

        assert os.path.isfile(aio_file)

        filecmp.clear_cache()
        assert filecmp.cmp(ref_file, aio_file, shallow=False)

    _test_parallel_write(single_submit, overlap_events)


@pytest.mark.parametrize('single_submit, overlap_events, cuda_device',
                         [(False, False, False), (False, True, False),
                          (True, False, False), (True, True, False),
                          (False, False, True), (True, True, True)])
def test_async_write(tmpdir, single_submit, overlap_events, cuda_device):
    _skip_if_no_aio()

    @distributed_test(world_size=[2])
    def _test_async_write(single_submit, overlap_events, cuda_device):
        ref_file, ref_buffer = _do_ref_write(tmpdir)
        aio_file, aio_buffer = _get_test_file_and_buffer(tmpdir, ref_buffer, cuda_device)

        h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE,
                                               QUEUE_DEPTH,
                                               single_submit,
                                               overlap_events,
                                               IO_PARALLEL)

        _validate_handle_state(h, single_submit, overlap_events)

        write_status = h.async_pwrite(aio_buffer, aio_file)
        assert write_status == 0

        wait_status = h.wait()
        assert wait_status == 1

        assert os.path.isfile(aio_file)

        filecmp.clear_cache()
        assert filecmp.cmp(ref_file, aio_file, shallow=False)

    _test_async_write(single_submit, overlap_events, cuda_device)


@pytest.mark.parametrize('async_queue, cuda_device',
                         [(2, False), (4, False), (2, True), (4, True)])
def test_async_queue_read(tmpdir, async_queue, cuda_device):
    _skip_if_no_aio()

    @distributed_test(world_size=[2])
    def _test_async_queue_read(async_queue, cuda_device):
        ref_files = []
        for i in range(async_queue):
            f, _ = _do_ref_write(tmpdir, i)
            ref_files.append(f)

        aio_buffers = []
        for i in range(async_queue):
            if cuda_device:
                buf = torch.empty(IO_SIZE, dtype=torch.uint8, device='cuda')
            else:
                buf = torch.empty(IO_SIZE, dtype=torch.uint8, device='cpu').pin_memory()
            aio_buffers.append(buf)

        single_submit = True
        overlap_events = True
        h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE,
                                               QUEUE_DEPTH,
                                               single_submit,
                                               overlap_events,
                                               IO_PARALLEL)

        _validate_handle_state(h, single_submit, overlap_events)

        for i in range(async_queue):
            read_status = h.async_pread(aio_buffers[i], ref_files[i])
            assert read_status == 0

        wait_status = h.wait()
        assert wait_status == async_queue

        for i in range(async_queue):
            with open(ref_files[i], 'rb') as f:
                ref_buffer = list(f.read())
            assert ref_buffer == aio_buffers[i].tolist()

    _test_async_queue_read(async_queue, cuda_device)


@pytest.mark.parametrize('async_queue, cuda_device',
                         [(2, False), (7, False), (2, True), (7, True)])
def test_async_queue_write(tmpdir, async_queue, cuda_device):
    _skip_if_no_aio()

    @distributed_test(world_size=[2])
    def _test_async_queue_write(async_queue, cuda_device):
        ref_files = []
        ref_buffers = []
        for i in range(async_queue):
            f, buf = _do_ref_write(tmpdir, i)
            ref_files.append(f)
            ref_buffers.append(buf)

        aio_files = []
        aio_buffers = []
        for i in range(async_queue):
            f, buf = _get_test_file_and_buffer(tmpdir, ref_buffers[i], cuda_device, i)
            aio_files.append(f)
            aio_buffers.append(buf)

        single_submit = True
        overlap_events = True
        h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE,
                                               QUEUE_DEPTH,
                                               single_submit,
                                               overlap_events,
                                               IO_PARALLEL)

        _validate_handle_state(h, single_submit, overlap_events)

        for i in range(async_queue):
            read_status = h.async_pwrite(aio_buffers[i], aio_files[i])
            assert read_status == 0

        wait_status = h.wait()
        assert wait_status == async_queue

        for i in range(async_queue):
            assert os.path.isfile(aio_files[i])

            filecmp.clear_cache()
            assert filecmp.cmp(ref_files[i], aio_files[i], shallow=False)

    _test_async_queue_write(async_queue, cuda_device)
```
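Assuming the async_io op is compatible on the machine, these tests would typically be collected with pytest from the tests/unit directory (the file imports the local `common` helper for `distributed_test`), e.g. `pytest test_aio.py`; each test then spawns a 2-rank process group via `@distributed_test(world_size=[2])`. The exact invocation is an assumption based on the imports above, not something stated in the commit.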