Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
832f1c74
M
models
项目概览
PaddlePaddle
/
models
1 年多 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
832f1c74
编写于
9月 16, 2020
作者:
Z
Zhen Wang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add pure fp16 training.
上级
08f3c0be
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
36 addition
and
14 deletion
+36
-14
PaddleCV/image_classification/build_model.py
PaddleCV/image_classification/build_model.py
+13
-3
PaddleCV/image_classification/dali.py
PaddleCV/image_classification/dali.py
+13
-7
PaddleCV/image_classification/scripts/train/ResNet50_fp16.sh
PaddleCV/image_classification/scripts/train/ResNet50_fp16.sh
+4
-2
PaddleCV/image_classification/train.py
PaddleCV/image_classification/train.py
+4
-1
PaddleCV/image_classification/utils/utility.py
PaddleCV/image_classification/utils/utility.py
+2
-1
未找到文件。
PaddleCV/image_classification/build_model.py
浏览文件 @
832f1c74
...
...
@@ -14,7 +14,7 @@
import
paddle
import
paddle.fluid
as
fluid
import
utils.utility
as
utility
from
paddle.fluid.contrib.mixed_precision.fp16_utils
import
cast_net_to_fp16
def
_calc_label_smoothing_loss
(
softmax_out
,
label
,
class_dim
,
epsilon
):
"""Calculate label smoothing loss
...
...
@@ -44,7 +44,12 @@ def _basic_model(data, model, args, is_train):
data_format
=
args
.
data_format
)
else
:
net_out
=
model
.
net
(
input
=
image
,
class_dim
=
args
.
class_dim
)
softmax_out
=
fluid
.
layers
.
softmax
(
net_out
,
use_cudnn
=
False
)
if
args
.
use_pure_fp16
:
cast_net_to_fp16
(
fluid
.
default_main_program
())
net_out_fp32
=
fluid
.
layers
.
cast
(
x
=
net_out
,
dtype
=
"float32"
)
softmax_out
=
fluid
.
layers
.
softmax
(
net_out_fp32
,
use_cudnn
=
False
)
else
:
softmax_out
=
fluid
.
layers
.
softmax
(
net_out
,
use_cudnn
=
False
)
if
is_train
and
args
.
use_label_smoothing
:
cost
=
_calc_label_smoothing_loss
(
softmax_out
,
label
,
args
.
class_dim
,
...
...
@@ -104,7 +109,12 @@ def _mixup_model(data, model, args, is_train):
data_format
=
args
.
data_format
)
else
:
net_out
=
model
.
net
(
input
=
image
,
class_dim
=
args
.
class_dim
)
softmax_out
=
fluid
.
layers
.
softmax
(
net_out
,
use_cudnn
=
False
)
if
args
.
use_pure_fp16
:
cast_net_to_fp16
(
fluid
.
default_main_program
())
net_out_fp32
=
fluid
.
layers
.
cast
(
x
=
net_out
,
dtype
=
"float32"
)
softmax_out
=
fluid
.
layers
.
softmax
(
net_out_fp32
,
use_cudnn
=
False
)
else
:
softmax_out
=
fluid
.
layers
.
softmax
(
net_out
,
use_cudnn
=
False
)
if
not
args
.
use_label_smoothing
:
loss_a
=
fluid
.
layers
.
cross_entropy
(
input
=
softmax_out
,
label
=
y_a
)
loss_b
=
fluid
.
layers
.
cross_entropy
(
input
=
softmax_out
,
label
=
y_b
)
...
...
PaddleCV/image_classification/dali.py
浏览文件 @
832f1c74
...
...
@@ -43,7 +43,8 @@ class HybridTrainPipe(Pipeline):
num_shards
=
1
,
random_shuffle
=
True
,
num_threads
=
4
,
seed
=
42
):
seed
=
42
,
output_dtype
=
types
.
FLOAT
):
super
(
HybridTrainPipe
,
self
).
__init__
(
batch_size
,
num_threads
,
device_id
,
seed
=
seed
)
self
.
input
=
ops
.
FileReader
(
...
...
@@ -68,7 +69,7 @@ class HybridTrainPipe(Pipeline):
device
=
'gpu'
,
resize_x
=
crop
,
resize_y
=
crop
,
interp_type
=
interp
)
self
.
cmnp
=
ops
.
CropMirrorNormalize
(
device
=
"gpu"
,
output_dtype
=
types
.
FLOAT
,
output_dtype
=
output_dtype
,
output_layout
=
types
.
NCHW
,
crop
=
(
crop
,
crop
),
image_type
=
types
.
RGB
,
...
...
@@ -104,7 +105,8 @@ class HybridValPipe(Pipeline):
num_shards
=
1
,
random_shuffle
=
False
,
num_threads
=
4
,
seed
=
42
):
seed
=
42
,
output_dtype
=
types
.
FLOAT
):
super
(
HybridValPipe
,
self
).
__init__
(
batch_size
,
num_threads
,
device_id
,
seed
=
seed
)
self
.
input
=
ops
.
FileReader
(
...
...
@@ -118,7 +120,7 @@ class HybridValPipe(Pipeline):
device
=
"gpu"
,
resize_shorter
=
resize_shorter
,
interp_type
=
interp
)
self
.
cmnp
=
ops
.
CropMirrorNormalize
(
device
=
"gpu"
,
output_dtype
=
types
.
FLOAT
,
output_dtype
=
output_dtype
,
output_layout
=
types
.
NCHW
,
crop
=
(
crop
,
crop
),
image_type
=
types
.
RGB
,
...
...
@@ -159,6 +161,7 @@ def build(settings, mode='train'):
min_area
=
settings
.
lower_scale
lower
=
settings
.
lower_ratio
upper
=
settings
.
upper_ratio
output_dtype
=
types
.
FLOAT16
if
settings
.
use_pure_fp16
else
types
.
FLOAT
interp
=
settings
.
interpolation
or
1
# default to linear
interp_map
=
{
...
...
@@ -188,7 +191,8 @@ def build(settings, mode='train'):
interp
,
mean
,
std
,
device_id
=
device_id
)
device_id
=
device_id
,
output_dtype
=
output_dtype
)
pipe
.
build
()
return
DALIGenericIterator
(
pipe
,
[
'feed_image'
,
'feed_label'
],
...
...
@@ -221,7 +225,8 @@ def build(settings, mode='train'):
device_id
,
shard_id
,
num_shards
,
seed
=
42
+
shard_id
)
seed
=
42
+
shard_id
,
output_dtype
=
output_dtype
)
pipe
.
build
()
pipelines
=
[
pipe
]
sample_per_shard
=
len
(
pipe
)
//
num_shards
...
...
@@ -248,7 +253,8 @@ def build(settings, mode='train'):
device_id
,
idx
,
num_shards
,
seed
=
42
+
idx
)
seed
=
42
+
idx
,
output_dtype
=
output_dtype
)
pipe
.
build
()
pipelines
.
append
(
pipe
)
sample_per_shard
=
len
(
pipelines
[
0
])
...
...
PaddleCV/image_classification/scripts/train/ResNet50_fp16.sh
浏览文件 @
832f1c74
...
...
@@ -7,7 +7,8 @@ export FLAGS_cudnn_batchnorm_spatial_persistent=1
DATA_DIR
=
"Your image dataset path, e.g. /work/datasets/ILSVRC2012/"
DATA_FORMAT
=
"NHWC"
USE_FP16
=
true
#whether to use float16
USE_AMP
=
true
#whether to use amp
USE_PURE_FP16
=
false
USE_DALI
=
true
if
${
USE_DALI
}
;
then
...
...
@@ -24,7 +25,8 @@ python train.py \
--print_step
=
10
\
--model_save_dir
=
output/
\
--lr_strategy
=
piecewise_decay
\
--use_fp16
=
${
USE_FP16
}
\
--use_amp
=
${
USE_AMP
}
\
--use_pure_fp16
=
${
USE_PURE_FP16
}
\
--scale_loss
=
128.0
\
--use_dynamic_loss_scaling
=
true
\
--data_format
=
${
DATA_FORMAT
}
\
...
...
PaddleCV/image_classification/train.py
浏览文件 @
832f1c74
...
...
@@ -29,6 +29,7 @@ import reader
from
utils
import
*
import
models
from
build_model
import
create_model
from
paddle.fluid.contrib.mixed_precision.fp16_utils
import
cast_parameters_to_fp16
logging
.
basicConfig
(
level
=
logging
.
INFO
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -72,7 +73,7 @@ def build_program(is_train, main_prog, startup_prog, args):
global_lr
.
persistable
=
True
loss_out
.
append
(
global_lr
)
if
args
.
use_
fp16
:
if
args
.
use_
amp
:
optimizer
=
fluid
.
contrib
.
mixed_precision
.
decorate
(
optimizer
,
init_loss_scaling
=
args
.
scale_loss
,
...
...
@@ -192,6 +193,8 @@ def train(args):
place
=
fluid
.
CUDAPlace
(
gpu_id
)
if
args
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup_prog
)
if
args
.
use_pure_fp16
:
cast_parameters_to_fp16
(
exe
,
train_prog
)
trainer_id
=
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
,
0
))
...
...
PaddleCV/image_classification/utils/utility.py
浏览文件 @
832f1c74
...
...
@@ -139,7 +139,8 @@ def parse_args():
# SWITCH
add_arg
(
'validate'
,
bool
,
True
,
"whether to validate when training."
)
add_arg
(
'use_fp16'
,
bool
,
False
,
"Whether to enable half precision training with fp16."
)
add_arg
(
'use_amp'
,
bool
,
False
,
"Whether to enable mixed precision training with fp16."
)
add_arg
(
'use_pure_fp16'
,
bool
,
False
,
"Whether to enable all half precision training with fp16."
)
add_arg
(
'scale_loss'
,
float
,
1.0
,
"The value of scale_loss for fp16."
)
add_arg
(
'use_dynamic_loss_scaling'
,
bool
,
True
,
"Whether to use dynamic loss scaling."
)
add_arg
(
'data_format'
,
str
,
"NCHW"
,
"Tensor data format when training."
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录