Commit fee73135 (unverified)
Authored April 20, 2022 by Olatunji Ruwase
Committed April 20, 2022 by GitHub
Use cuda events to improve timing for multi-stream execution (#1881)
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent: 66aae13d
Showing 1 changed file with 32 additions and 20 deletions

deepspeed/utils/timer.py (+32 -20)
@@ -20,42 +20,55 @@ except ImportError:
     pass
 
 
+class CudaEventTimer(object):
+    def __init__(self, start_event: torch.cuda.Event, end_event: torch.cuda.Event):
+        self.start_event = start_event
+        self.end_event = end_event
+
+    def get_elapsed_msec(self):
+        torch.cuda.current_stream().wait_event(self.end_event)
+        self.end_event.synchronize()
+        return self.start_event.elapsed_time(self.end_event)
+
+
 class SynchronizedWallClockTimer:
     """Group of timers. Borrowed from Nvidia Megatron code"""
     class Timer:
         """Timer."""
         def __init__(self, name):
             self.name_ = name
-            self.elapsed_ = 0.0
             self.started_ = False
-            self.start_time = time.time()
-            self.records = []
+            self.event_timers = []
+            self.start_event = None
+            self.elapsed_records = None
 
         def start(self):
             """Start the timer."""
-            assert not self.started_, "timer has already been started"
-            torch.cuda.synchronize()
-            self.start_time = time.time()
+            assert not self.started_, f"{self.name} timer has already been started"
+            self.start_event = torch.cuda.Event(enable_timing=True)
+            self.start_event.record()
             self.started_ = True
 
         def stop(self, reset=False, record=False):
             """Stop the timer."""
             assert self.started_, "timer is not started"
-            torch.cuda.synchronize()
-            if reset:
-                self.elapsed_ = time.time() - self.start_time
-            else:
-                self.elapsed_ += time.time() - self.start_time
+            end_event = torch.cuda.Event(enable_timing=True)
+            end_event.record()
+            self.event_timers.append(CudaEventTimer(self.start_event, end_event))
+            self.start_event = None
             self.started_ = False
-            if record:
-                self.records.append(self.elapsed_)
+
+        def _get_elapsed_msec(self):
+            self.elapsed_records = [et.get_elapsed_msec() for et in self.event_timers]
+            self.event_timers.clear()
+            return sum(self.elapsed_records)
 
         def reset(self):
             """Reset timer."""
-            self.elapsed_ = 0.0
             self.started_ = False
-            self.acc_ = 0.0
-            self.cnt_ = 0
+            self.start_event = None
+            self.elapsed_records = None
+            self.event_timers.clear()
 
         def elapsed(self, reset=True):
             """Calculate the elapsed time."""
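
The pattern in this hunk is that `stop()` no longer blocks to compute a duration. It records an end marker, stores the (start, end) pair, and the durations are resolved later in one pass. The following is a minimal CPU-only sketch of that deferred pattern, not the DeepSpeed implementation: `FakeEvent` and `DeferredTimer` are hypothetical names, and `time.perf_counter()` stands in for `torch.cuda.Event`, so no stream synchronization is modeled.

```python
import time


class FakeEvent:
    """Hypothetical stand-in for a GPU event: it only stores a host timestamp."""

    def __init__(self):
        self.timestamp = None

    def record(self):
        self.timestamp = time.perf_counter()

    def elapsed_time(self, end):
        # Milliseconds between this event and `end`, mirroring the unit
        # returned by the elapsed_time call in the diff.
        return (end.timestamp - self.timestamp) * 1000.0


class DeferredTimer:
    """Stores (start, end) pairs on stop() and resolves them only on request."""

    def __init__(self):
        self.pending = []
        self.start_event = None
        self.elapsed_records = None

    def start(self):
        assert self.start_event is None, "timer has already been started"
        self.start_event = FakeEvent()
        self.start_event.record()

    def stop(self):
        assert self.start_event is not None, "timer is not started"
        end_event = FakeEvent()
        end_event.record()
        # No duration is computed here; the pair is queued for later.
        self.pending.append((self.start_event, end_event))
        self.start_event = None

    def elapsed_msec(self):
        # Resolve every queued pair, keep the per-interval values, then clear.
        self.elapsed_records = [s.elapsed_time(e) for s, e in self.pending]
        self.pending.clear()
        return sum(self.elapsed_records)


timer = DeferredTimer()
for _ in range(3):
    timer.start()
    time.sleep(0.01)
    timer.stop()
total_ms = timer.elapsed_msec()
print(f"{len(timer.elapsed_records)} intervals, {total_ms:.2f} ms total")
```

With real CUDA events the queued pairs matter more than they do here, because an end event may not have completed on the device when `stop()` returns; deferring the read is what lets the host avoid a full device synchronization on every start and stop.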
@@ -64,7 +77,7 @@ class SynchronizedWallClockTimer:
             if self.started_:
                 self.stop()
             # Get the elapsed time.
-            elapsed_ = self.elapsed_
+            elapsed_ = self._get_elapsed_msec()
             # Reset the elapsed time
             if reset:
                 self.reset()
@@ -74,7 +87,7 @@ class SynchronizedWallClockTimer:
             return elapsed_
 
         def mean(self):
-            return trim_mean(self.records, 0.1)
+            return trim_mean(self.elapsed_records, 0.1)
 
     def __init__(self):
         self.timers = {}
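
`mean()` passes the per-interval records to `trim_mean` with a cut of `0.1`. The diff does not show where `trim_mean` comes from, so the sketch below only assumes it is a symmetric trimmed mean in the style of `scipy.stats.trim_mean`: sort the values, drop that fraction from each end, and average the rest. `trimmed_mean` and the sample timings are hypothetical.

```python
def trimmed_mean(values, proportion_to_cut):
    """Mean after dropping the lowest and highest `proportion_to_cut` fraction."""
    ordered = sorted(values)
    cut = int(proportion_to_cut * len(ordered))  # items removed from each end
    kept = ordered[cut:len(ordered) - cut]
    return sum(kept) / len(kept)


# Hypothetical per-interval timings in milliseconds, with one slow outlier.
step_times_ms = [10.0, 10.2, 9.9, 10.1, 10.0, 9.8, 10.3, 10.0, 10.1, 55.0]

plain = sum(step_times_ms) / len(step_times_ms)
robust = trimmed_mean(step_times_ms, 0.1)
print(f"plain mean: {plain:.3f} ms, trimmed mean: {robust:.3f} ms")
```

Trimming makes the reported figure less sensitive to a single stalled interval, which is a common reason to prefer it for step timings.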
@@ -102,8 +115,7 @@ class SynchronizedWallClockTimer:
         string = f"rank={torch.distributed.get_rank()} time (ms)"
         for name in names:
             if name in self.timers:
-                elapsed_time = (self.timers[name].elapsed(reset=reset) * 1000.0 /
-                                normalizer)
+                elapsed_time = (self.timers[name].elapsed(reset=reset) / normalizer)
                 string += " | {}: {:.2f}".format(name, elapsed_time)
 
         log_dist(string, ranks=ranks or [0])