PaddlePaddle / PaddleClas · Commit dccd7ed9 (unverified)

Authored Mar 02, 2021 by Wei Shengyu · Committed via GitHub on Mar 02, 2021
Parents: e02a35ac, 1df66418

Merge pull request #619 from huangxu96/cp_fp16_training

[Cherry-pick]new usage of amp training. (#564)

Showing 9 changed files with 90 additions and 254 deletions (+90 -254)
Changed files:

- configs/ResNet/ResNet50_fp16.yaml (+14 -9)
- ppcls/data/imaug/operators.py (+16 -3)
- ppcls/modeling/architectures/resnet.py (+12 -8)
- ppcls/modeling/loss.py (+8 -14)
- ppcls/optimizer/optimizer.py (+4 -1)
- tools/static/dali.py (+5 -1)
- tools/static/optimizer.py (+0 -171)
- tools/static/program.py (+23 -38)
- tools/static/train.py (+8 -9)
configs/ResNet/ResNet50_fp16.yml → configs/ResNet/ResNet50_fp16.yaml (+14 -9)

```diff
@@ -11,21 +11,23 @@ validate: True
 valid_interval: 1
 epochs: 120
 topk: 5
-image_shape: [3, 224, 224]
 is_distributed: True
 
-# mixed precision training
-use_amp: True
-use_pure_fp16: False
-multi_precision: False
-scale_loss: 128.0
-use_dynamic_loss_scaling: True
-data_format: "NCHW"
-image_shape: [3, 224, 224]
+use_dali: True
+use_gpu: True
+data_format: "NHWC"
+image_channel: &image_channel 4
+image_shape: [*image_channel, 224, 224]
 
 use_mix: False
 ls_epsilon: -1
 
+# mixed precision training
+AMP:
+    scale_loss: 128.0
+    use_dynamic_loss_scaling: True
+    use_pure_fp16: &use_pure_fp16 True
+
 LEARNING_RATE:
     function: 'Piecewise'
     params:
@@ -37,6 +39,7 @@ OPTIMIZER:
     function: 'Momentum'
     params:
         momentum: 0.9
+        multi_precision: *use_pure_fp16
     regularizer:
         function: 'L2'
         factor: 0.000100
@@ -61,6 +64,8 @@ TRAIN:
             mean: [0.485, 0.456, 0.406]
             std: [0.229, 0.224, 0.225]
             order: ''
+            output_fp16: *use_pure_fp16
+            channel_num: *image_channel
         - ToCHWImage:
 
 VALID:
```
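The renamed config leans on YAML anchors so that a single switch drives every consumer of the pure-FP16 setting: `&use_pure_fp16` and `&image_channel` are defined once and referenced as `*use_pure_fp16` / `*image_channel` in the OPTIMIZER and TRAIN sections above. A minimal sketch of how a standard YAML loader resolves those aliases (illustrative only; this is not the PaddleClas config loader, and the keys are a stripped-down subset of the file above):

```python
import yaml

cfg_text = """
image_channel: &image_channel 4
image_shape: [*image_channel, 224, 224]

AMP:
    use_pure_fp16: &use_pure_fp16 True

OPTIMIZER:
    params:
        multi_precision: *use_pure_fp16

TRAIN:
    transforms:
        - NormalizeImage:
            output_fp16: *use_pure_fp16
            channel_num: *image_channel
"""

cfg = yaml.safe_load(cfg_text)
# Aliases are resolved at load time, so every consumer sees plain values.
assert cfg["image_shape"] == [4, 224, 224]
assert cfg["OPTIMIZER"]["params"]["multi_precision"] is True
assert cfg["TRAIN"]["transforms"][0]["NormalizeImage"]["channel_num"] == 4
```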
ppcls/data/imaug/operators.py (+16 -3)

```diff
@@ -195,14 +195,18 @@ class NormalizeImage(object):
     """ normalize image such as substract mean, divide std
     """
 
-    def __init__(self, scale=None, mean=None, std=None, order='chw'):
+    def __init__(self, scale=None, mean=None, std=None, order='chw', output_fp16=False, channel_num=3):
         if isinstance(scale, str):
             scale = eval(scale)
+        assert channel_num in [3, 4], "channel number of input image should be set to 3 or 4."
+        self.channel_num = channel_num
+        self.output_dtype = 'float16' if output_fp16 else 'float32'
         self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+        self.order = order
         mean = mean if mean is not None else [0.485, 0.456, 0.406]
         std = std if std is not None else [0.229, 0.224, 0.225]
 
-        shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
+        shape = (3, 1, 1) if self.order == 'chw' else (1, 1, 3)
         self.mean = np.array(mean).reshape(shape).astype('float32')
         self.std = np.array(std).reshape(shape).astype('float32')
 
@@ -213,7 +217,16 @@ class NormalizeImage(object):
         assert isinstance(img,
                           np.ndarray), "invalid input 'img' in NormalizeImage"
-        return (img.astype('float32') * self.scale - self.mean) / self.std
+        img = (img.astype('float32') * self.scale - self.mean) / self.std
+        if self.channel_num == 4:
+            img_h = img.shape[1] if self.order == 'chw' else img.shape[0]
+            img_w = img.shape[2] if self.order == 'chw' else img.shape[1]
+            pad_zeros = np.zeros((1, img_h, img_w)) if self.order == 'chw' \
+                else np.zeros((img_h, img_w, 1))
+            img = np.concatenate((img, pad_zeros), axis=0) if self.order == 'chw' \
+                else np.concatenate((img, pad_zeros), axis=2)
+        return img.astype(self.output_dtype)
 
 
 class ToCHWImage(object):
```
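What the two new arguments do to a sample can be shown with NumPy alone. The sketch below mirrors the padding-and-cast branch added above (values and shapes are made up; the real operator is configured through the yaml's `channel_num: *image_channel` and `output_fp16: *use_pure_fp16` entries): with `channel_num=4` a zero channel is appended so the NHWC FP16 pipeline receives a 4-channel tensor, and with `output_fp16=True` the result is emitted as float16.

```python
import numpy as np

# Stand-ins for NormalizeImage's state (illustrative values only).
scale = np.float32(1.0 / 255.0)
mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3)).astype('float32')
std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3)).astype('float32')

img = np.random.randint(0, 256, size=(224, 224, 3)).astype('uint8')  # HWC input
img = (img.astype('float32') * scale - mean) / std

# channel_num == 4 with order != 'chw': append one zero channel on the last axis.
img_h, img_w = img.shape[0], img.shape[1]
pad_zeros = np.zeros((img_h, img_w, 1))
img = np.concatenate((img, pad_zeros), axis=2)

# output_fp16 == True: hand the network a float16 tensor.
img = img.astype('float16')
print(img.shape, img.dtype)  # (224, 224, 4) float16
```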
ppcls/modeling/architectures/resnet.py (+12 -8)

```diff
@@ -277,14 +277,18 @@ class ResNet(nn.Layer):
             bias_attr=ParamAttr(name="fc_0.b_0"))
 
     def forward(self, inputs):
-        y = self.conv(inputs)
-        y = self.pool2d_max(y)
-        for block in self.block_list:
-            y = block(y)
-        y = self.pool2d_avg(y)
-        y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels])
-        y = self.out(y)
-        return y
+        with paddle.static.amp.fp16_guard():
+            if self.data_format == "NHWC":
+                inputs = paddle.tensor.transpose(inputs, [0, 2, 3, 1])
+                inputs.stop_gradient = True
+            y = self.conv(inputs)
+            y = self.pool2d_max(y)
+            for block in self.block_list:
+                y = block(y)
+            y = self.pool2d_avg(y)
+            y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels])
+            y = self.out(y)
+            return y
 
 
 def ResNet18(**args):
```
ppcls/modeling/loss.py (+8 -14)

```diff
@@ -42,17 +42,14 @@ class Loss(object):
         soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
         return soft_target
 
-    def _crossentropy(self, input, target, use_pure_fp16=False):
+    def _crossentropy(self, input, target):
         if self._label_smoothing:
             target = self._labelsmoothing(target)
             input = -F.log_softmax(input, axis=-1)
             cost = paddle.sum(target * input, axis=-1)
         else:
             cost = F.cross_entropy(input=input, label=target)
-        if use_pure_fp16:
-            avg_cost = paddle.sum(cost)
-        else:
-            avg_cost = paddle.mean(cost)
+        avg_cost = paddle.mean(cost)
         return avg_cost
 
     def _kldiv(self, input, target, name=None):
@@ -81,8 +78,8 @@ class CELoss(Loss):
     def __init__(self, class_dim=1000, epsilon=None):
         super(CELoss, self).__init__(class_dim, epsilon)
 
-    def __call__(self, input, target, use_pure_fp16=False):
-        cost = self._crossentropy(input, target, use_pure_fp16)
+    def __call__(self, input, target):
+        cost = self._crossentropy(input, target)
         return cost
 
@@ -94,14 +91,11 @@ class MixCELoss(Loss):
     def __init__(self, class_dim=1000, epsilon=None):
         super(MixCELoss, self).__init__(class_dim, epsilon)
 
-    def __call__(self, input, target0, target1, lam, use_pure_fp16=False):
-        cost0 = self._crossentropy(input, target0, use_pure_fp16)
-        cost1 = self._crossentropy(input, target1, use_pure_fp16)
+    def __call__(self, input, target0, target1, lam):
+        cost0 = self._crossentropy(input, target0)
+        cost1 = self._crossentropy(input, target1)
         cost = lam * cost0 + (1.0 - lam) * cost1
-        if use_pure_fp16:
-            avg_cost = paddle.sum(cost)
-        else:
-            avg_cost = paddle.mean(cost)
+        avg_cost = paddle.mean(cost)
         return avg_cost
```
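The deleted `use_pure_fp16` branch swapped `paddle.mean(cost)` for `paddle.sum(cost)`. That appears to pair with the old `tools/static/optimizer.py` (removed later in this commit), whose Momentum set `rescale_grad = 1.0 / (batch_size / len(fluid.cuda_places()))`, i.e. one over the per-device batch size, so the gradients of the summed loss came out as if the mean had been taken. With that optimizer gone, always returning the mean is sufficient. A quick NumPy check of the identity the old pairing relied on (illustrative numbers):

```python
import numpy as np

per_sample_loss = np.random.rand(32).astype('float32')  # one loss per sample
batch_size = per_sample_loss.size

# Summed loss scaled by 1/batch_size equals the mean loss, which is what the
# old rescale_grad setting achieved implicitly on the gradient side.
np.testing.assert_allclose(
    per_sample_loss.sum() / batch_size, per_sample_loss.mean(), rtol=1e-5)
```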
ppcls/optimizer/optimizer.py (+4 -1)

```diff
@@ -74,19 +74,22 @@ class Momentum(object):
                  momentum,
                  parameter_list=None,
                  regularization=None,
+                 multi_precision=False,
                  **args):
         super(Momentum, self).__init__()
         self.learning_rate = learning_rate
         self.momentum = momentum
         self.parameter_list = parameter_list
         self.regularization = regularization
+        self.multi_precision = multi_precision
 
     def __call__(self):
         opt = paddle.optimizer.Momentum(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
             parameters=self.parameter_list,
-            weight_decay=self.regularization)
+            weight_decay=self.regularization,
+            multi_precision=self.multi_precision)
         return opt
```
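The new `multi_precision` flag is passed straight through to `paddle.optimizer.Momentum`, where it tells the optimizer to keep FP32 master weights for FP16 parameters and apply the updates in FP32. A minimal dygraph sketch showing where the flag sits (toy layer and data, not PaddleClas code; with FP32 parameters the flag is simply a no-op):

```python
import paddle

# Toy layer; in pure-FP16 training the parameters themselves are later cast
# to float16, which is when multi_precision matters.
model = paddle.nn.Linear(8, 2)

opt = paddle.optimizer.Momentum(
    learning_rate=0.1,
    momentum=0.9,
    parameters=model.parameters(),
    weight_decay=paddle.regularizer.L2Decay(1e-4),
    multi_precision=True)  # keep FP32 master weights for FP16 parameters

x = paddle.randn([4, 8])
loss = model(x).mean()
loss.backward()
opt.step()
opt.clear_grad()
```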
tools/static/dali.py (+5 -1)

```diff
@@ -176,7 +176,11 @@ def build(config, mode='train'):
         2: types.INTERP_CUBIC,  # cv2.INTER_CUBIC
         4: types.INTERP_LANCZOS3,  # XXX use LANCZOS3 for cv2.INTER_LANCZOS4
     }
-    output_dtype = types.FLOAT16 if config.get("use_pure_fp16", False) else types.FLOAT
+    output_dtype = (types.FLOAT16
+                    if 'AMP' in config and config.AMP.get("use_pure_fp16", False)
+                    else types.FLOAT)
+
     assert interp in interp_map, "interpolation method not supported by DALI"
     interp = interp_map[interp]
     pad_output = False
```
tools/static/optimizer.py (deleted, 100644 → 0, -171)

```python
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

import paddle
import paddle.fluid as fluid
import paddle.regularizer as regularizer

__all__ = ['OptimizerBuilder']


class L1Decay(object):
    """
    L1 Weight Decay Regularization, which encourages the weights to be sparse.

    Args:
        factor(float): regularization coeff. Default:0.0.
    """

    def __init__(self, factor=0.0):
        super(L1Decay, self).__init__()
        self.factor = factor

    def __call__(self):
        reg = regularizer.L1Decay(self.factor)
        return reg


class L2Decay(object):
    """
    L2 Weight Decay Regularization, which encourages the weights to be sparse.

    Args:
        factor(float): regularization coeff. Default:0.0.
    """

    def __init__(self, factor=0.0):
        super(L2Decay, self).__init__()
        self.factor = factor

    def __call__(self):
        reg = regularizer.L2Decay(self.factor)
        return reg


class Momentum(object):
    """
    Simple Momentum optimizer with velocity state.

    Args:
        learning_rate (float|Variable) - The learning rate used to update parameters.
            Can be a float value or a Variable with one float value as data element.
        momentum (float) - Momentum factor.
        regularization (WeightDecayRegularizer, optional) - The strategy of regularization.
    """

    def __init__(self,
                 learning_rate,
                 momentum,
                 parameter_list=None,
                 regularization=None,
                 config=None,
                 **args):
        super(Momentum, self).__init__()
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.parameter_list = parameter_list
        self.regularization = regularization
        self.multi_precision = config.get('multi_precision', False)
        self.rescale_grad = (
            1.0 / (config['TRAIN']['batch_size'] / len(fluid.cuda_places()))
            if config.get('use_pure_fp16', False) else 1.0)

    def __call__(self):
        opt = fluid.contrib.optimizer.Momentum(
            learning_rate=self.learning_rate,
            momentum=self.momentum,
            regularization=self.regularization,
            multi_precision=self.multi_precision,
            rescale_grad=self.rescale_grad)
        return opt


class RMSProp(object):
    """
    Root Mean Squared Propagation (RMSProp) is an unpublished, adaptive learning rate method.

    Args:
        learning_rate (float|Variable) - The learning rate used to update parameters.
            Can be a float value or a Variable with one float value as data element.
        momentum (float) - Momentum factor.
        rho (float) - rho value in equation.
        epsilon (float) - avoid division by zero, default is 1e-6.
        regularization (WeightDecayRegularizer, optional) - The strategy of regularization.
    """

    def __init__(self,
                 learning_rate,
                 momentum,
                 rho=0.95,
                 epsilon=1e-6,
                 parameter_list=None,
                 regularization=None,
                 **args):
        super(RMSProp, self).__init__()
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.rho = rho
        self.epsilon = epsilon
        self.parameter_list = parameter_list
        self.regularization = regularization

    def __call__(self):
        opt = paddle.optimizer.RMSProp(
            learning_rate=self.learning_rate,
            momentum=self.momentum,
            rho=self.rho,
            epsilon=self.epsilon,
            parameters=self.parameter_list,
            weight_decay=self.regularization)
        return opt


class OptimizerBuilder(object):
    """
    Build optimizer

    Args:
        function(str): optimizer name of learning rate
        params(dict): parameters used for init the class
        regularizer (dict): parameters used for create regularization
    """

    def __init__(self,
                 config=None,
                 function='Momentum',
                 params={'momentum': 0.9},
                 regularizer=None):
        self.function = function
        self.params = params
        self.config = config
        # create regularizer
        if regularizer is not None:
            mod = sys.modules[__name__]
            reg_func = regularizer['function'] + 'Decay'
            del regularizer['function']
            reg = getattr(mod, reg_func)(**regularizer)()
            self.params['regularization'] = reg

    def __call__(self, learning_rate, parameter_list=None):
        mod = sys.modules[__name__]
        opt = getattr(mod, self.function)
        return opt(learning_rate=learning_rate,
                   parameter_list=parameter_list,
                   config=self.config,
                   **self.params)()
```
tools/static/program.py (+23 -38)

```diff
@@ -21,12 +21,10 @@ import time
 import numpy as np
 from collections import OrderedDict
 
-from optimizer import OptimizerBuilder
+from ppcls.optimizer import OptimizerBuilder
 
 import paddle
 import paddle.nn.functional as F
-from paddle import fluid
-from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_model_to_fp16
 
 from ppcls.optimizer.learning_rate import LearningRateBuilder
 from ppcls.modeling import architectures
@@ -83,11 +81,9 @@ def create_model(architecture, image, classes_num, config, is_train):
     Returns:
         out(variable): model output variable
     """
-    use_pure_fp16 = config.get("use_pure_fp16", False)
     name = architecture["name"]
     params = architecture.get("params", {})
-    data_format = "NCHW"
     if "data_format" in config:
         params["data_format"] = config["data_format"]
         data_format = config["data_format"]
@@ -100,16 +96,8 @@ def create_model(architecture, image, classes_num, config, is_train):
     if "is_test" in params:
         params['is_test'] = not is_train
     model = architectures.__dict__[name](class_dim=classes_num, **params)
 
-    if use_pure_fp16 and not config.get("use_dali", False):
-        image = image.astype('float16')
-    if data_format == "NHWC":
-        image = paddle.tensor.transpose(image, [0, 2, 3, 1])
-        image.stop_gradient = True
     out = model(image)
-    if config.get("use_pure_fp16", False):
-        cast_model_to_fp16(paddle.static.default_main_program())
-        out = out.astype('float32')
     return out
@@ -119,8 +107,7 @@ def create_loss(out,
                 classes_num=1000,
                 epsilon=None,
                 use_mix=False,
-                use_distillation=False,
-                use_pure_fp16=False):
+                use_distillation=False):
     """
     Create a loss for optimization, such as:
         1. CrossEnotry loss
@@ -137,7 +124,6 @@ def create_loss(out,
         classes_num(int): num of classes
         epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
         use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
-        use_pure_fp16(bool): whether to use pure fp16 data as training parameter
     Returns:
         loss(variable): loss variable
@@ -162,10 +148,10 @@ def create_loss(out,
     if use_mix:
         loss = MixCELoss(class_dim=classes_num, epsilon=epsilon)
-        return loss(out, feed_y_a, feed_y_b, feed_lam, use_pure_fp16)
+        return loss(out, feed_y_a, feed_y_b, feed_lam)
     else:
         loss = CELoss(class_dim=classes_num, epsilon=epsilon)
-        return loss(out, target, use_pure_fp16)
+        return loss(out, target)
 
 
 def create_metric(out,
@@ -239,9 +225,8 @@ def create_fetchs(out,
         fetchs(dict): dict of model outputs(included loss and measures)
     """
     fetchs = OrderedDict()
-    use_pure_fp16 = config.get("use_pure_fp16", False)
     loss = create_loss(out, feeds, architecture, classes_num, epsilon, use_mix,
-                       use_distillation, use_pure_fp16)
+                       use_distillation)
     fetchs['loss'] = (loss, AverageMeter('loss', '7.4f', need_avg=True))
     if not use_mix:
         metric = create_metric(out, feeds, architecture, topk, classes_num,
@@ -285,7 +270,7 @@ def create_optimizer(config):
     # create optimizer instance
     opt_config = config['OPTIMIZER']
-    opt = OptimizerBuilder(config, **opt_config)
+    opt = OptimizerBuilder(**opt_config)
     return opt(lr), lr
@@ -304,11 +289,11 @@ def create_strategy(config):
     exec_strategy = paddle.static.ExecutionStrategy()
 
     exec_strategy.num_threads = 1
-    exec_strategy.num_iteration_per_drop_scope = 10000 if config.get(
-        'use_pure_fp16', False) else 10
+    exec_strategy.num_iteration_per_drop_scope = (10000 if 'AMP' in config and
+        config.AMP.get("use_pure_fp16", False) else 10)
 
-    fuse_op = config.get('use_amp', False) or config.get('use_pure_fp16', False)
+    fuse_op = True if 'AMP' in config else False
 
     fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op)
     fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', fuse_op)
     fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops', fuse_op)
@@ -369,14 +354,17 @@ def dist_optimizer(config, optimizer):
 def mixed_precision_optimizer(config, optimizer):
-    use_amp = config.get('use_amp', False)
-    scale_loss = config.get('scale_loss', 1.0)
-    use_dynamic_loss_scaling = config.get('use_dynamic_loss_scaling', False)
-    if use_amp:
-        optimizer = fluid.contrib.mixed_precision.decorate(
-            optimizer,
-            init_loss_scaling=scale_loss,
-            use_dynamic_loss_scaling=use_dynamic_loss_scaling)
+    if 'AMP' in config:
+        amp_cfg = config.AMP if config.AMP else dict()
+        scale_loss = amp_cfg.get('scale_loss', 1.0)
+        use_dynamic_loss_scaling = amp_cfg.get('use_dynamic_loss_scaling', False)
+        use_pure_fp16 = amp_cfg.get('use_pure_fp16', False)
+        optimizer = paddle.static.amp.decorate(
+            optimizer,
+            init_loss_scaling=scale_loss,
+            use_dynamic_loss_scaling=use_dynamic_loss_scaling,
+            use_pure_fp16=use_pure_fp16,
+            use_fp16_guard=True)
 
     return optimizer
@@ -407,15 +395,11 @@ def build(config, main_prog, startup_prog, is_train=True, is_distributed=True):
         use_dali = config.get('use_dali', False)
         use_distillation = config.get('use_distillation')
 
-        image_dtype = "float32"
-        if config["ARCHITECTURE"]["name"] == "ResNet50" and config.get("use_pure_fp16", False) \
-                and config.get("use_dali", False):
-            image_dtype = "float16"
         feeds = create_feeds(
             config.image_shape,
             use_mix=use_mix,
             use_dali=use_dali,
-            dtype=image_dtype)
+            dtype="float32")
 
         if use_dali and use_mix:
             import dali
             feeds = dali.mix(feeds, config, is_train)
@@ -432,13 +416,14 @@ def build(config, main_prog, startup_prog, is_train=True, is_distributed=True):
             config=config,
             use_distillation=use_distillation)
 
         lr_scheduler = None
+        optimizer = None
         if is_train:
             optimizer, lr_scheduler = create_optimizer(config)
             optimizer = mixed_precision_optimizer(config, optimizer)
             if is_distributed:
                 optimizer = dist_optimizer(config, optimizer)
             optimizer.minimize(fetchs['loss'][0])
-        return fetchs, lr_scheduler, feeds
+        return fetchs, lr_scheduler, feeds, optimizer
 
 
 def compile(config, program, loss_name=None, share_prog=None):
```
tools/static/train.py (+8 -9)

```diff
@@ -26,8 +26,6 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
 from sys import version_info
 
 import paddle
-import paddle.fluid as fluid
-from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_parameters_to_fp16
 from paddle.distributed import fleet
 
 from ppcls.data import Reader
@@ -67,9 +65,7 @@ def main(args):
     # assign the place
     use_gpu = config.get("use_gpu", True)
     # amp related config
-    use_amp = config.get('use_amp', False)
-    use_pure_fp16 = config.get('use_pure_fp16', False)
-    if use_amp or use_pure_fp16:
+    if 'AMP' in config:
         AMP_RELATED_FLAGS_SETTING = {
             'FLAGS_cudnn_exhaustive_search': 1,
             'FLAGS_conv_workspace_size_limit': 1500,
@@ -97,7 +93,7 @@ def main(args):
     best_top1_acc = 0.0  # best top1 acc record
 
-    train_fetchs, lr_scheduler, train_feeds = program.build(
+    train_fetchs, lr_scheduler, train_feeds, optimizer = program.build(
         config,
         train_prog,
         startup_prog,
@@ -106,7 +102,7 @@ def main(args):
     if config.validate:
         valid_prog = paddle.static.Program()
-        valid_fetchs, _, valid_feeds = program.build(
+        valid_fetchs, _, valid_feeds, _ = program.build(
             config,
             valid_prog,
             startup_prog,
@@ -119,11 +115,14 @@ def main(args):
     exe = paddle.static.Executor(place)
     # Parameter initialization
     exe.run(startup_prog)
-    if config.get("use_pure_fp16", False):
-        cast_parameters_to_fp16(place, train_prog, fluid.global_scope())
 
     # load pretrained models or checkpoints
     init_model(config, train_prog, exe)
+    if 'AMP' in config and config.AMP.get("use_pure_fp16", False):
+        optimizer.amp_init(
+            place,
+            scope=paddle.static.global_scope(),
+            test_program=valid_prog if config.validate else None)
 
     if not config.get("is_distributed", True) and not use_xpu:
         compiled_train_prog = program.compile(
             config, train_prog, loss_name=train_fetchs["loss"][0].name)
```
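Taken together, the intended static-graph flow after this commit is: build the network with the heavy part under `paddle.static.amp.fp16_guard()` (resnet.py), wrap the optimizer with `paddle.static.amp.decorate(..., use_pure_fp16=..., use_fp16_guard=True)` (program.py), run the startup program, and then call `optimizer.amp_init(...)` to cast the parameters before the first step (train.py). The sketch below is a self-contained toy version of that flow, not PaddleClas code; it assumes a GPU build of PaddlePaddle 2.0+ (pure FP16 needs CUDA), and the network, shapes and hyper-parameters are made up.

```python
import paddle
import paddle.nn.functional as F

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    image = paddle.static.data('image', shape=[None, 3, 32, 32], dtype='float32')
    label = paddle.static.data('label', shape=[None, 1], dtype='int64')

    # Only ops created inside fp16_guard are rewritten to FP16 when the
    # decorated optimizer is built with use_fp16_guard=True.
    with paddle.static.amp.fp16_guard():
        conv = paddle.static.nn.conv2d(image, num_filters=8, filter_size=3, act='relu')
        pool = F.adaptive_avg_pool2d(conv, 1)
    logits = paddle.static.nn.fc(paddle.flatten(pool, 1, -1), size=10)
    loss = paddle.mean(F.cross_entropy(logits, label))

    optimizer = paddle.optimizer.Momentum(
        learning_rate=0.1, momentum=0.9, multi_precision=True)
    optimizer = paddle.static.amp.decorate(
        optimizer,
        init_loss_scaling=128.0,
        use_dynamic_loss_scaling=True,
        use_pure_fp16=True,
        use_fp16_guard=True)
    optimizer.minimize(loss)

place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup_prog)
# Cast the FP32 parameters held in the scope to FP16 before the first train
# step, mirroring the new optimizer.amp_init(...) call in tools/static/train.py.
optimizer.amp_init(place, scope=paddle.static.global_scope())
```

In PaddleClas itself the guard sits inside `ResNet.forward`, `mixed_precision_optimizer` does the decorating, and `main` in train.py calls `amp_init` right after `init_model`, as shown in the hunks above.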