openvinotoolkit / mmaction2

Commit 2cdc9036
Authored on Jun 28, 2020 by linjintao

Add docstring for FP16 part

Parent: 7adb6410
Showing 3 changed files with 74 additions and 30 deletions (+74 -30)

mmaction/core/fp16/decorators.py    +34 -30
mmaction/core/fp16/hooks.py         +30 -0
mmaction/core/fp16/utils.py         +10 -0
mmaction/core/fp16/decorators.py
...
@@ -19,21 +19,23 @@ def auto_fp16(apply_to=None, out_fp32=False):
             `None` indicates all arguments.
         out_fp32 (bool): Whether to convert the output back to fp32.
-    :Example:
-        class MyModule1(nn.Module)
-            # Convert x and y to fp16
-            @auto_fp16()
-            def forward(self, x, y):
-                pass
-        class MyModule2(nn.Module):
-            # convert pred to fp16
-            @auto_fp16(apply_to=('pred', ))
-            def do_something(self, pred, others):
-                pass
+    Example:
+        >>> import torch.nn as nn
+        >>> class MyModule1(nn.Module):
+        >>>
+        >>>     # Convert x and y to fp16
+        >>>     @auto_fp16()
+        >>>     def forward(self, x, y):
+        >>>         pass
+        >>> import torch.nn as nn
+        >>> class MyModule2(nn.Module):
+        >>>
+        >>>     # convert pred to fp16
+        >>>     @auto_fp16(apply_to=('pred', ))
+        >>>     def do_something(self, pred, others):
+        >>>         pass
     """

     def auto_fp16_wrapper(old_func):
...
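The added docstring shows @auto_fp16 applied to a module's forward method. As a rough, self-contained sketch of what such a decorator does (illustrative only, not the mmaction implementation), the wrapper below casts any fp32 tensors among the positional arguments named in apply_to to fp16 before calling the wrapped method; the real decorator additionally handles keyword arguments, nested containers, and the out_fp32 option.

import functools

import torch
import torch.nn as nn


def simple_auto_fp16(apply_to=None):
    # Sketch of an auto_fp16-style decorator: cast the selected fp32 tensor
    # arguments to fp16 before calling the wrapped method.
    def wrapper(old_func):
        @functools.wraps(old_func)
        def new_func(self, *args, **kwargs):
            arg_names = old_func.__code__.co_varnames[1:1 + len(args)]
            new_args = []
            for name, value in zip(arg_names, args):
                if ((apply_to is None or name in apply_to)
                        and isinstance(value, torch.Tensor)
                        and value.dtype == torch.float32):
                    value = value.half()  # fp32 -> fp16
                new_args.append(value)
            return old_func(self, *new_args, **kwargs)
        return new_func
    return wrapper


class ToyModule(nn.Module):

    @simple_auto_fp16()
    def forward(self, x, y):
        return x.dtype, y.dtype


print(ToyModule()(torch.ones(2), torch.ones(2)))  # (torch.float16, torch.float16)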
@@ -97,21 +99,23 @@ def force_fp32(apply_to=None, out_fp16=False):
             `None` indicates all arguments.
         out_fp16 (bool): Whether to convert the output back to fp16.
-    :Example:
-        class MyModule1(nn.Module)
-            # Convert x and y to fp32
-            @force_fp32()
-            def loss(self, x, y):
-                pass
-        class MyModule2(nn.Module):
-            # convert pred to fp32
-            @force_fp32(apply_to=('pred', ))
-            def post_process(self, pred, others):
-                pass
+    Example:
+        >>> import torch.nn as nn
+        >>> class MyModule1(nn.Module):
+        >>>
+        >>>     # Convert x and y to fp32
+        >>>     @force_fp32()
+        >>>     def loss(self, x, y):
+        >>>         pass
+        >>> import torch.nn as nn
+        >>> class MyModule2(nn.Module):
+        >>>
+        >>>     # convert pred to fp32
+        >>>     @force_fp32(apply_to=('pred', ))
+        >>>     def post_process(self, pred, others):
+        >>>         pass
     """

     def force_fp32_wrapper(old_func):
...
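Used together, the two decorators let the compute-heavy forward pass run in fp16 while numerically sensitive code such as loss computation stays in fp32. A hypothetical usage sketch follows; the import path, the fp16_enabled attribute, and the method names are assumptions for illustration, not part of this commit.

import torch.nn as nn

# Assumed import path; the decorators live in mmaction/core/fp16/decorators.py.
from mmaction.core.fp16 import auto_fp16, force_fp32


class ToyHead(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)
        self.fp16_enabled = True  # assumption: flag that activates the decorators

    @auto_fp16()
    def forward(self, x):
        # Runs in fp16 once the model has been converted with wrap_fp16_model.
        return self.fc(x)

    @force_fp32(apply_to=('cls_score', ))
    def loss(self, cls_score, labels):
        # Numerically sensitive loss is computed in fp32.
        return nn.functional.cross_entropy(cls_score, labels)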
mmaction/core/fp16/hooks.py
...
@@ -38,6 +38,11 @@ class Fp16OptimizerHook(OptimizerHook):
         self.distributed = distributed

     def before_run(self, runner):
+        """Preparing steps before Mixed Precision Training.
+
+        1. Make a master copy of fp32 weights for optimization.
+        2. Convert the main model from fp32 to fp16.
+        """
         # keep a copy of fp32 weights
         runner.optimizer.param_groups = copy.deepcopy(
             runner.optimizer.param_groups)
...
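The new before_run docstring summarizes the two preparation steps. A minimal stand-alone sketch of those steps on a plain nn.Linear (illustrative only; the real hook operates on the runner's model and optimizer):

import copy

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 1. Keep a master copy of the fp32 weights for the optimizer to update.
optimizer.param_groups = copy.deepcopy(optimizer.param_groups)

# 2. Convert the model itself to fp16 for forward/backward computation.
model.half()

print(model.weight.dtype)                             # torch.float16
print(optimizer.param_groups[0]['params'][0].dtype)   # torch.float32 (master copy)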
@@ -58,6 +63,14 @@ class Fp16OptimizerHook(OptimizerHook):
             fp16_param.data.copy_(fp32_param.data)

     def after_train_iter(self, runner):
+        """Backward optimization steps for Mixed Precision Training.
+
+        1. Scale the loss by a scale factor.
+        2. Backward the loss to obtain the gradients (fp16).
+        3. Copy gradients from the model to the fp32 weight copy.
+        4. Scale the gradients back and update the fp32 weight copy.
+        5. Copy back the params from fp32 weight copy to the fp16 model.
+        """
         # clear grads of last iteration
         runner.model.zero_grad()
         runner.optimizer.zero_grad()
...
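The five steps documented for after_train_iter are the standard static loss-scaling recipe. Below is a minimal sketch for a single fp16 parameter with an fp32 master copy, assuming a fixed loss scale of 512 (the hook's actual scale factor is configurable):

import torch

loss_scale = 512.0
fp16_param = torch.ones(3, dtype=torch.float16, requires_grad=True)
fp32_param = fp16_param.detach().float().clone()        # fp32 master copy
optimizer = torch.optim.SGD([fp32_param], lr=0.1)

loss = (fp16_param.float() ** 2).sum()

scaled_loss = loss * loss_scale              # 1. scale the loss
scaled_loss.backward()                       # 2. backward -> fp16 gradients
fp32_param.grad = fp16_param.grad.float()    # 3. copy grads to the fp32 copy
fp32_param.grad.div_(loss_scale)             # 4. unscale, then update fp32 copy
optimizer.step()
fp16_param.data.copy_(fp32_param.data)       # 5. copy params back to fp16 model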
@@ -85,6 +98,14 @@ class Fp16OptimizerHook(OptimizerHook):

 def wrap_fp16_model(model):
+    """Wrap the FP32 model to FP16.
+
+    1. Convert FP32 model to FP16.
+    2. Remain some necessary layers to be FP32, e.g., normalization layers.
+
+    Args:
+        model (nn.Module): Model in FP32.
+    """
     # convert model to fp16
     model.half()
     # patch the normalization layers to make it work in fp32 mode
...
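The wrap_fp16_model docstring says the whole model is cast to FP16 while normalization layers stay in FP32. A short sketch of that effect (the real helper additionally patches the norm layers' forward, as the next hunk shows):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

model.half()                                  # 1. convert everything to fp16
for m in model.modules():
    if isinstance(m, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
        m.float()                             # 2. keep normalization layers in fp32

print(model[0].weight.dtype)                  # torch.float16
print(model[1].weight.dtype)                  # torch.float32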
@@ -96,6 +117,15 @@ def wrap_fp16_model(model):

 def patch_norm_fp32(module):
+    """Recursively convert normalization layers from FP16 to FP32.
+
+    Args:
+        module (nn.Module): The modules to be converted in FP16.
+
+    Returns:
+        nn.Module: The converted module, the normalization layers have been
+            converted to FP32.
+    """
     if isinstance(module, (_BatchNorm, nn.GroupNorm)):
         module.float()
         module.forward = patch_forward_method(module.forward, torch.half,
...
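The call to patch_forward_method is truncated in the hunk above, so the following is only a sketch of what such a helper plausibly does: cast the fp16 activations to fp32 before the re-floated norm layer runs, then cast the output back to fp16. The signature and behaviour shown here are assumptions.

import torch
import torch.nn as nn


def patched_forward(old_forward, src_type, dst_type):
    # Sketch only: cast inputs from src_type to dst_type before the fp32
    # layer runs, then cast its output back to src_type.
    def new_forward(x):
        return old_forward(x.to(dst_type)).to(src_type)
    return new_forward


bn = nn.BatchNorm2d(8).float()                      # norm layer kept in fp32
bn.forward = patched_forward(bn.forward, torch.half, torch.float)

out = bn(torch.randn(2, 8, 4, 4).half())            # fp16 activations in ...
print(out.dtype)                                     # ... fp16 activations out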
mmaction/core/fp16/utils.py
...
@@ -5,6 +5,16 @@ import torch

 def cast_tensor_type(inputs, src_type, dst_type):
+    """Recursively convert Tensor in inputs from src_type to dst_type.
+
+    Args:
+        inputs: Inputs that to be casted.
+        src_type (torch.dtype): Source type..
+        dst_type (torch.dtype): Destination type.
+
+    Returns:
+        The same type with inputs, but all contained Tensors have been cast.
+    """
     if isinstance(inputs, torch.Tensor):
         return inputs.to(dst_type)
     elif isinstance(inputs, str):
...
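The hunk is cut off after the str branch, so the sketch below fills in plausible container handling (the dict/list/tuple branches are assumptions) to show how a recursive cast like the one documented above behaves:

import torch


def cast_tensor_type_sketch(inputs, src_type, dst_type):
    # Sketch consistent with the docstring above; src_type is kept only for
    # signature parity with the function in the diff.
    if isinstance(inputs, torch.Tensor):
        return inputs.to(dst_type)
    elif isinstance(inputs, str):
        return inputs
    elif isinstance(inputs, dict):
        return {k: cast_tensor_type_sketch(v, src_type, dst_type)
                for k, v in inputs.items()}
    elif isinstance(inputs, (list, tuple)):
        return type(inputs)(
            cast_tensor_type_sketch(v, src_type, dst_type) for v in inputs)
    return inputs


batch = {'imgs': torch.randn(2, 3), 'scores': [torch.randn(4)], 'meta': 'train'}
out = cast_tensor_type_sketch(batch, torch.float, torch.half)
print(out['imgs'].dtype, out['scores'][0].dtype, out['meta'])  # float16 float16 train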