Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
2005cc3e
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2005cc3e
编写于
10月 15, 2021
作者:
S
stephon
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add amp train
上级
6fc27265
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
172 addition
and
5 deletion
+172
-5
configs/det/det_mv3_db_amp.yml
configs/det/det_mv3_db_amp.yml
+135
-0
tools/program.py
tools/program.py
+20
-5
tools/train.py
tools/train.py
+17
-0
未找到文件。
configs/det/det_mv3_db_amp.yml
0 → 100644
浏览文件 @
2005cc3e
Global
:
use_gpu
:
true
epoch_num
:
1200
log_smooth_window
:
20
print_batch_step
:
10
save_model_dir
:
./output/db_mv3/
save_epoch_step
:
1200
# evaluation is run every 2000 iterations
eval_batch_step
:
[
0
,
2000
]
cal_metric_during_train
:
False
pretrained_model
:
./pretrain_models/MobileNetV3_large_x0_5_pretrained
checkpoints
:
save_inference_dir
:
use_visualdl
:
False
infer_img
:
doc/imgs_en/img_10.jpg
save_res_path
:
./output/det_db/predicts_db.txt
AMP
:
scale_loss
:
1024.0
use_dynamic_loss_scaling
:
True
Architecture
:
model_type
:
det
algorithm
:
DB
Transform
:
Backbone
:
name
:
MobileNetV3
scale
:
0.5
model_name
:
large
Neck
:
name
:
DBFPN
out_channels
:
256
Head
:
name
:
DBHead
k
:
50
Loss
:
name
:
DBLoss
balance_loss
:
true
main_loss_type
:
DiceLoss
alpha
:
5
beta
:
10
ohem_ratio
:
3
Optimizer
:
name
:
Adam
beta1
:
0.9
beta2
:
0.999
lr
:
learning_rate
:
0.001
regularizer
:
name
:
'
L2'
factor
:
0
PostProcess
:
name
:
DBPostProcess
thresh
:
0.3
box_thresh
:
0.6
max_candidates
:
1000
unclip_ratio
:
1.5
Metric
:
name
:
DetMetric
main_indicator
:
hmean
Train
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data/icdar2015/text_localization/
label_file_list
:
-
./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list
:
[
1.0
]
transforms
:
-
DecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
DetLabelEncode
:
# Class handling label
-
IaaAugment
:
augmenter_args
:
-
{
'
type'
:
Fliplr
,
'
args'
:
{
'
p'
:
0.5
}
}
-
{
'
type'
:
Affine
,
'
args'
:
{
'
rotate'
:
[
-10
,
10
]
}
}
-
{
'
type'
:
Resize
,
'
args'
:
{
'
size'
:
[
0.5
,
3
]
}
}
-
EastRandomCropData
:
size
:
[
640
,
640
]
max_tries
:
50
keep_ratio
:
true
-
MakeBorderMap
:
shrink_ratio
:
0.4
thresh_min
:
0.3
thresh_max
:
0.7
-
MakeShrinkMap
:
shrink_ratio
:
0.4
min_text_size
:
8
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
hwc'
-
ToCHWImage
:
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
threshold_map'
,
'
threshold_mask'
,
'
shrink_map'
,
'
shrink_mask'
]
# the order of the dataloader list
loader
:
shuffle
:
True
drop_last
:
False
batch_size_per_card
:
16
num_workers
:
8
use_shared_memory
:
False
Eval
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data/icdar2015/text_localization/
label_file_list
:
-
./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms
:
-
DecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
DetLabelEncode
:
# Class handling label
-
DetResizeForTest
:
image_shape
:
[
736
,
1280
]
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
hwc'
-
ToCHWImage
:
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
shape'
,
'
polys'
,
'
ignore_tags'
]
loader
:
shuffle
:
False
drop_last
:
False
batch_size_per_card
:
1
# must be 1
num_workers
:
8
use_shared_memory
:
False
tools/program.py
浏览文件 @
2005cc3e
...
...
@@ -226,12 +226,27 @@ def train(config,
images
=
batch
[
0
]
if
use_srn
:
model_average
=
True
# use amp
if
scaler
:
with
paddle
.
amp
.
auto_cast
():
if
model_type
==
'table'
or
extra_input
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
else
:
preds
=
model
(
images
)
else
:
if
model_type
==
'table'
or
extra_input
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
else
:
preds
=
model
(
images
)
loss
=
loss_class
(
preds
,
batch
)
avg_loss
=
loss
[
'loss'
]
if
scaler
:
scaled_avg_loss
=
scaler
.
scale
(
avg_loss
)
scaled_avg_loss
.
backward
()
scaler
.
minimize
(
optimizer
,
scaled_avg_loss
)
else
:
avg_loss
.
backward
()
optimizer
.
step
()
optimizer
.
clear_grad
()
...
...
tools/train.py
浏览文件 @
2005cc3e
...
...
@@ -102,6 +102,23 @@ def main(config, device, logger, vdl_writer):
if
valid_dataloader
is
not
None
:
logger
.
info
(
'valid dataloader has {} iters'
.
format
(
len
(
valid_dataloader
)))
use_amp
=
True
if
"AMP"
in
config
else
False
if
use_amp
:
AMP_RELATED_FLAGS_SETTING
=
{
'FLAGS_cudnn_batchnorm_spatial_persistent'
:
1
,
'FLAGS_max_inplace_grad_add'
:
8
,
}
paddle
.
fluid
.
set_flags
(
AMP_RELATED_FLAGS_SETTING
)
scale_loss
=
config
[
"AMP"
].
get
(
"scale_loss"
,
1.0
)
use_dynamic_loss_scaling
=
config
[
"AMP"
].
get
(
"use_dynamic_loss_scaling"
,
False
)
scaler
=
paddle
.
amp
.
GradScaler
(
init_loss_scaling
=
scale_loss
,
use_dynamic_loss_scaling
=
use_dynamic_loss_scaling
)
else
:
scaler
=
None
# start train
program
.
train
(
config
,
train_dataloader
,
valid_dataloader
,
device
,
model
,
loss_class
,
optimizer
,
lr_scheduler
,
post_process_class
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录