Greenplum / DeepSpeed
Unverified commit 8d3b42c2

Authored Jan 30, 2023 by Bing Xie; committed via GitHub on Jan 30, 2023.

Bing/formatting correction (#2764)

* modify engine.py for formatting
* commit formatting changes on engine.py

Parent: 258d2831
Showing 1 changed file with 22 additions and 5 deletions (+22, -5):

deepspeed/runtime/engine.py
@@ -2009,11 +2009,14 @@ class DeepSpeedEngine(Module):
         return loss

     def is_gradient_accumulation_boundary(self):
-        """Query whether the current micro-batch is at the boundary of
+        """
+        Query whether the current micro-batch is at the boundary of
         gradient accumulation, and thus will trigger gradient reductions and
         an optimizer step.
+
         Returns:
             bool: if the current step is a gradient accumulation boundary.
+
         """
         if self._is_gradient_accumulation_boundary is None:
             return (self.micro_steps + 1) % \
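The hunk above reformats the docstring of is_gradient_accumulation_boundary, the query side of the boundary API. As a minimal sketch of how client code might use the query, assuming an engine returned by deepspeed.initialize; inputs, labels, loss_fn, and log_gradient_norms are illustrative names, not part of the diff:

    # Sketch only: `engine` is assumed to come from deepspeed.initialize(...);
    # `inputs`, `labels`, `loss_fn`, and `log_gradient_norms` are placeholders.
    loss = loss_fn(engine(inputs), labels)
    engine.backward(loss)
    if engine.is_gradient_accumulation_boundary():
        # This micro-batch triggers gradient reduction and an optimizer step,
        # so it is a safe point for boundary-only work such as logging.
        log_gradient_norms(engine)  # hypothetical client helper
    engine.step()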
@@ -2022,7 +2025,8 @@ class DeepSpeedEngine(Module):
         return self._is_gradient_accumulation_boundary

     def set_gradient_accumulation_boundary(self, is_boundary):
-        """Manually overrides the DeepSpeed engine's gradient accumulation boundary state, this is an optional
+        """
+        Manually overrides the DeepSpeed engine's gradient accumulation boundary state, this is an optional
         feature and should be used with care. The state should be set before to the intended
         value before each forward/backward. The final fordward/backward should have the
         boundary state set to True. This style allows client code to only call engine.step() once after all
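This hunk reformats set_gradient_accumulation_boundary, whose docstring describes an override style: set the boundary flag before each forward/backward, mark only the last micro-batch as the boundary, then call engine.step() once. A minimal sketch of that style, assuming an initialized engine; micro_batches and loss_fn are illustrative placeholders:

    # Sketch only: `engine`, `micro_batches`, and `loss_fn` are placeholders
    # supplied by client code.
    for i, (inputs, labels) in enumerate(micro_batches):
        # Only the final micro-batch is a boundary, so gradient reduction and
        # the optimizer step fire exactly once per accumulation window.
        engine.set_gradient_accumulation_boundary(i == len(micro_batches) - 1)
        loss = loss_fn(engine(inputs), labels)
        engine.backward(loss)
    # Per the docstring, step() is called once after all micro-batches.
    engine.step()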
@@ -2714,7 +2718,9 @@ class DeepSpeedEngine(Module):
                         load_lr_scheduler_states=True,
                         load_module_only=False,
                         custom_load_fn=None):
-        """Load training checkpoint
+        """
+        Load training checkpoint
+
         Arguments:
             load_dir: Required. Directory to load the checkpoint from
             tag: Checkpoint tag used as a unique identifier for checkpoint, if not provided will attempt to load tag in 'latest' file
@@ -2723,14 +2729,17 @@ class DeepSpeedEngine(Module):
             load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from Checkpoint.
             load_module_only: Optional. Boolean to load only the model weights from the checkpoint. Ex. warmstarting.
             custom_load_fn: Optional. Custom model load function.
+
         Returns:
             A tuple of ``load_path`` and ``client_state``.
             *``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed.
             *``client_state``: State dictionary used for loading required training states in the client code.
+
         Important: under ZeRO3, one cannot load checkpoint with ``engine.load_checkpoint()`` right
         after ``engine.save_checkpoint()``. It is because ``engine.module`` is partitioned, and
         ``load_checkpoint()`` wants a pristine model. If insisting to do so, please reinitialize engine
         before ``load_checkpoint()``.
+
         """
         if tag is None:
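The load_checkpoint docstring above pins down the return convention: a (load_path, client_state) tuple, with load_path set to None when loading fails, plus a warning not to call it right after save_checkpoint under ZeRO-3. A minimal sketch of that convention, assuming an initialized engine; the "./checkpoints" directory and the "step" key are illustrative:

    # Sketch only: "./checkpoints" and the "step" key are placeholders.
    load_path, client_state = engine.load_checkpoint("./checkpoints")
    if load_path is None:
        raise RuntimeError("checkpoint load failed")  # None signals failure
    # client_state round-trips whatever was passed to save_checkpoint().
    start_step = client_state.get("step", 0)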
@@ -3062,7 +3071,8 @@ class DeepSpeedEngine(Module):
             logger.warning(msg)

     def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True):
         r"""Save training checkpoint
+
         Arguments:
             save_dir: Required. Directory for saving the checkpoint
             tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint, global step is
@@ -3073,6 +3083,7 @@ class DeepSpeedEngine(Module):
         because each process needs to save its master weights and scheduler+optimizer states. This
         method will hang waiting to synchronize with other processes if it's called just for the
         process with rank 0.
+
         """
         if self.zero_optimization_partition_weights():
             # Prepare for checkpoint save by ensuring all parameters are partitioned
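The save_checkpoint docstring carries the main operational warning: every rank must call it, because the save synchronizes master weights and optimizer/scheduler state across processes, and a rank-0-only call hangs. A minimal sketch, assuming an initialized engine; the directory and client_state contents are illustrative:

    # Sketch only: executed by ALL ranks, never guarded by an `if rank == 0`.
    engine.save_checkpoint(
        "./checkpoints",           # illustrative save_dir
        client_state={"step": 0},  # arbitrary state, returned by load_checkpoint
    )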
@@ -3467,17 +3478,23 @@ class DeepSpeedEngine(Module):
         return self.save_16bit_model(save_dir, save_filename)

     def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin"):
-        r"""Save 16bit model weights
+        r"""
+        Save 16bit model weights
+
         This method saves the 16bit model weights at the desired destination.
+
         Arguments:
             save_dir: Required. Directory for saving the model
             save_filename: Optional. Filename to save to. Defaults to ``pytorch_model.bin``
+
         Returns:
             ``True`` when a model has been saved, ``False`` otherwise. It will not be saved if
             stage3_gather_16bit_weights_on_model_save is ``False``.
+
         Important: all processes must call this method and not just the process with rank 0. It is
         because the processes need to work in sync to gather the weights. This method will hang
         waiting to synchronize with other processes if it's called just for the process with rank 0.
+
         """
         path = os.path.join(save_dir, save_filename)
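save_16bit_model is likewise collective: all ranks must participate so the partitioned 16-bit weights can be gathered, and per the docstring it returns False without saving when stage3_gather_16bit_weights_on_model_save is disabled. A minimal sketch, assuming an initialized ZeRO-3 engine; the "./export" directory is illustrative:

    # Sketch only: called on all ranks; "./export" is a placeholder directory.
    saved = engine.save_16bit_model("./export")  # writes ./export/pytorch_model.bin
    if not saved:
        # Enable stage3_gather_16bit_weights_on_model_save in the DeepSpeed
        # config for this call to produce a file.
        print("16-bit model not saved")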