Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
96aade95
M
models
项目概览
PaddlePaddle
/
models
大约 1 年 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
96aade95
编写于
1月 11, 2019
作者:
T
typhoonzero
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
follow comments
上级
fd82711e
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
83 addition
and
86 deletion
+83
-86
fluid/PaddleCV/image_classification/train.py
fluid/PaddleCV/image_classification/train.py
+4
-86
fluid/PaddleCV/image_classification/utils/__init__.py
fluid/PaddleCV/image_classification/utils/__init__.py
+1
-0
fluid/PaddleCV/image_classification/utils/fp16_utils.py
fluid/PaddleCV/image_classification/utils/fp16_utils.py
+78
-0
未找到文件。
fluid/PaddleCV/image_classification/train.py
浏览文件 @
96aade95
...
...
@@ -17,6 +17,7 @@ import functools
import
subprocess
import
utils
from
utils.learning_rate
import
cosine_decay
from
utils.fp16_utils
import
create_master_params_grads
,
master_param_to_train_param
from
utility
import
add_arguments
,
print_arguments
import
models
import
models_name
...
...
@@ -160,62 +161,6 @@ def net_config(image, label, model, args):
return
avg_cost
,
acc_top1
,
acc_top5
def
cast_fp16_to_fp32
(
i
,
o
,
prog
):
prog
.
global_block
().
append_op
(
type
=
"cast"
,
inputs
=
{
"X"
:
i
},
outputs
=
{
"Out"
:
o
},
attrs
=
{
"in_dtype"
:
fluid
.
core
.
VarDesc
.
VarType
.
FP16
,
"out_dtype"
:
fluid
.
core
.
VarDesc
.
VarType
.
FP32
}
)
def
cast_fp32_to_fp16
(
i
,
o
,
prog
):
prog
.
global_block
().
append_op
(
type
=
"cast"
,
inputs
=
{
"X"
:
i
},
outputs
=
{
"Out"
:
o
},
attrs
=
{
"in_dtype"
:
fluid
.
core
.
VarDesc
.
VarType
.
FP32
,
"out_dtype"
:
fluid
.
core
.
VarDesc
.
VarType
.
FP16
}
)
def
copy_to_master_param
(
p
,
block
):
v
=
block
.
vars
.
get
(
p
.
name
,
None
)
if
v
is
None
:
raise
ValueError
(
"no param name %s found!"
%
p
.
name
)
new_p
=
fluid
.
framework
.
Parameter
(
block
=
block
,
shape
=
v
.
shape
,
dtype
=
fluid
.
core
.
VarDesc
.
VarType
.
FP32
,
type
=
v
.
type
,
lod_level
=
v
.
lod_level
,
stop_gradient
=
p
.
stop_gradient
,
trainable
=
p
.
trainable
,
optimize_attr
=
p
.
optimize_attr
,
regularizer
=
p
.
regularizer
,
gradient_clip_attr
=
p
.
gradient_clip_attr
,
error_clip
=
p
.
error_clip
,
name
=
v
.
name
+
".master"
)
return
new_p
def
update_op_role_var
(
params_grads
,
master_params_grads
,
main_prog
):
orig_grad_name_set
=
set
()
for
_
,
g
in
params_grads
:
orig_grad_name_set
.
add
(
g
.
name
)
master_g2p_dict
=
dict
()
for
idx
,
master
in
enumerate
(
master_params_grads
):
orig
=
params_grads
[
idx
]
master_g2p_dict
[
orig
[
1
].
name
]
=
[
master
[
0
].
name
,
master
[
1
].
name
]
for
op
in
main_prog
.
global_block
().
ops
:
for
oname
in
op
.
output_arg_names
:
if
oname
in
orig_grad_name_set
:
# rename
print
(
"setting to "
,
master_g2p_dict
[
oname
])
op
.
_set_attr
(
"op_role_var"
,
master_g2p_dict
[
oname
])
def
build_program
(
is_train
,
main_prog
,
startup_prog
,
args
):
image_shape
=
[
int
(
m
)
for
m
in
args
.
image_shape
.
split
(
","
)]
model_name
=
args
.
model
...
...
@@ -249,38 +194,11 @@ def build_program(is_train, main_prog, startup_prog, args):
optimizer
=
optimizer_setting
(
params
)
if
args
.
fp16
:
master_params_grads
=
[]
params_grads
=
optimizer
.
backward
(
avg_cost
)
tmp_role
=
main_prog
.
_current_role
OpRole
=
fluid
.
core
.
op_proto_and_checker_maker
.
OpRole
main_prog
.
_current_role
=
OpRole
.
Backward
for
p
,
g
in
params_grads
:
master_param
=
copy_to_master_param
(
p
,
main_prog
.
global_block
())
startup_master_param
=
startup_prog
.
global_block
().
_clone_variable
(
master_param
)
startup_p
=
startup_prog
.
global_block
().
var
(
p
.
name
)
cast_fp16_to_fp32
(
startup_p
,
startup_master_param
,
startup_prog
)
if
g
.
name
.
startswith
(
"batch_norm"
):
if
args
.
scale_loss
>
1
:
scaled_g
=
g
/
float
(
args
.
scale_loss
)
else
:
scaled_g
=
g
master_params_grads
.
append
([
p
,
scaled_g
])
continue
master_grad
=
fluid
.
layers
.
cast
(
g
,
"float32"
)
if
args
.
scale_loss
>
1
:
master_grad
=
master_grad
/
float
(
args
.
scale_loss
)
master_params_grads
.
append
([
master_param
,
master_grad
])
main_prog
.
_current_role
=
tmp_role
master_params_grads
=
create_master_params_grads
(
params_grads
,
main_prog
,
startup_prog
,
args
.
scale_loss
)
optimizer
.
apply_gradients
(
master_params_grads
)
for
idx
,
m_p_g
in
enumerate
(
master_params_grads
):
train_p
,
train_g
=
params_grads
[
idx
]
if
train_p
.
name
.
startswith
(
"batch_norm"
):
continue
with
main_prog
.
_optimized_guard
([
m_p_g
[
0
],
m_p_g
[
1
]]):
cast_fp32_to_fp16
(
m_p_g
[
0
],
train_p
,
main_prog
)
master_param_to_train_param
(
master_params_grads
,
params_grads
,
main_prog
)
else
:
optimizer
.
minimize
(
avg_cost
)
...
...
fluid/PaddleCV/image_classification/utils/__init__.py
浏览文件 @
96aade95
from
.learning_rate
import
cosine_decay
,
lr_warmup
from
.fp16_utils
import
create_master_params_grads
,
master_param_to_train_param
fluid/PaddleCV/image_classification/utils/fp16_utils.py
0 → 100644
浏览文件 @
96aade95
from
__future__
import
print_function
import
paddle
import
paddle.fluid
as
fluid
def
cast_fp16_to_fp32
(
i
,
o
,
prog
):
prog
.
global_block
().
append_op
(
type
=
"cast"
,
inputs
=
{
"X"
:
i
},
outputs
=
{
"Out"
:
o
},
attrs
=
{
"in_dtype"
:
fluid
.
core
.
VarDesc
.
VarType
.
FP16
,
"out_dtype"
:
fluid
.
core
.
VarDesc
.
VarType
.
FP32
}
)
def
cast_fp32_to_fp16
(
i
,
o
,
prog
):
prog
.
global_block
().
append_op
(
type
=
"cast"
,
inputs
=
{
"X"
:
i
},
outputs
=
{
"Out"
:
o
},
attrs
=
{
"in_dtype"
:
fluid
.
core
.
VarDesc
.
VarType
.
FP32
,
"out_dtype"
:
fluid
.
core
.
VarDesc
.
VarType
.
FP16
}
)
def
copy_to_master_param
(
p
,
block
):
v
=
block
.
vars
.
get
(
p
.
name
,
None
)
if
v
is
None
:
raise
ValueError
(
"no param name %s found!"
%
p
.
name
)
new_p
=
fluid
.
framework
.
Parameter
(
block
=
block
,
shape
=
v
.
shape
,
dtype
=
fluid
.
core
.
VarDesc
.
VarType
.
FP32
,
type
=
v
.
type
,
lod_level
=
v
.
lod_level
,
stop_gradient
=
p
.
stop_gradient
,
trainable
=
p
.
trainable
,
optimize_attr
=
p
.
optimize_attr
,
regularizer
=
p
.
regularizer
,
gradient_clip_attr
=
p
.
gradient_clip_attr
,
error_clip
=
p
.
error_clip
,
name
=
v
.
name
+
".master"
)
return
new_p
def
create_master_params_grads
(
params_grads
,
main_prog
,
startup_prog
,
scale_loss
):
master_params_grads
=
[]
tmp_role
=
main_prog
.
_current_role
OpRole
=
fluid
.
core
.
op_proto_and_checker_maker
.
OpRole
main_prog
.
_current_role
=
OpRole
.
Backward
for
p
,
g
in
params_grads
:
# create master parameters
master_param
=
copy_to_master_param
(
p
,
main_prog
.
global_block
())
startup_master_param
=
startup_prog
.
global_block
().
_clone_variable
(
master_param
)
startup_p
=
startup_prog
.
global_block
().
var
(
p
.
name
)
cast_fp16_to_fp32
(
startup_p
,
startup_master_param
,
startup_prog
)
# cast fp16 gradients to fp32 before apply gradients
if
g
.
name
.
startswith
(
"batch_norm"
):
if
scale_loss
>
1
:
scaled_g
=
g
/
float
(
scale_loss
)
else
:
scaled_g
=
g
master_params_grads
.
append
([
p
,
scaled_g
])
continue
master_grad
=
fluid
.
layers
.
cast
(
g
,
"float32"
)
if
scale_loss
>
1
:
master_grad
=
master_grad
/
float
(
scale_loss
)
master_params_grads
.
append
([
master_param
,
master_grad
])
main_prog
.
_current_role
=
tmp_role
return
master_params_grads
def
master_param_to_train_param
(
master_params_grads
,
params_grads
,
main_prog
):
for
idx
,
m_p_g
in
enumerate
(
master_params_grads
):
train_p
,
_
=
params_grads
[
idx
]
if
train_p
.
name
.
startswith
(
"batch_norm"
):
continue
with
main_prog
.
_optimized_guard
([
m_p_g
[
0
],
m_p_g
[
1
]]):
cast_fp32_to_fp16
(
m_p_g
[
0
],
train_p
,
main_prog
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录