Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
c93b4a17
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1529
Star
32963
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c93b4a17
编写于
11月 06, 2020
作者:
D
dyning
提交者:
GitHub
11月 06, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1123 from WenmuZhou/dygraph_rc
fix some error and make some change
上级
96c91907
a414dd86
变更
15
显示空白变更内容
内联
并排
Showing
15 changed file
with
65 addition
and
73 deletion
+65
-73
configs/det/det_mv3_db.yml
configs/det/det_mv3_db.yml
+2
-2
configs/rec/rec_mv3_none_bilstm_ctc.yml
configs/rec/rec_mv3_none_bilstm_ctc.yml
+8
-11
ppocr/data/lmdb_dataset.py
ppocr/data/lmdb_dataset.py
+0
-4
ppocr/data/simple_dataset.py
ppocr/data/simple_dataset.py
+0
-3
ppocr/metrics/__init__.py
ppocr/metrics/__init__.py
+2
-2
ppocr/metrics/det_metric.py
ppocr/metrics/det_metric.py
+0
-0
ppocr/metrics/rec_metric.py
ppocr/metrics/rec_metric.py
+0
-0
ppocr/modeling/heads/det_db_head.py
ppocr/modeling/heads/det_db_head.py
+2
-2
ppocr/modeling/necks/db_fpn.py
ppocr/modeling/necks/db_fpn.py
+10
-7
ppocr/optimizer/__init__.py
ppocr/optimizer/__init__.py
+2
-3
ppocr/optimizer/learning_rate.py
ppocr/optimizer/learning_rate.py
+23
-21
ppocr/utils/logging.py
ppocr/utils/logging.py
+1
-1
tools/program.py
tools/program.py
+10
-11
tools/train.py
tools/train.py
+4
-5
train.sh
train.sh
+1
-1
未找到文件。
configs/det/det_mv3_db.yml
浏览文件 @
c93b4a17
...
@@ -44,9 +44,9 @@ Optimizer:
...
@@ -44,9 +44,9 @@ Optimizer:
name
:
Adam
name
:
Adam
beta1
:
0.9
beta1
:
0.9
beta2
:
0.999
beta2
:
0.999
l
earning_rate
:
l
r
:
# name: Cosine
# name: Cosine
l
r
:
0.001
l
earning_rate
:
0.001
# warmup_epoch: 0
# warmup_epoch: 0
regularizer
:
regularizer
:
name
:
'
L2'
name
:
'
L2'
...
...
configs/rec/rec_mv3_none_bilstm_ctc.yml
浏览文件 @
c93b4a17
...
@@ -6,7 +6,7 @@ Global:
...
@@ -6,7 +6,7 @@ Global:
save_model_dir
:
./output/rec/mv3_none_bilstm_ctc/
save_model_dir
:
./output/rec/mv3_none_bilstm_ctc/
save_epoch_step
:
3
save_epoch_step
:
3
# evaluation is run every 5000 iterations after the 4000th iteration
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step
:
[
0
,
1
000
]
eval_batch_step
:
[
0
,
2
000
]
# if pretrained_model is saved in static mode, load_static_weights must set to True
# if pretrained_model is saved in static mode, load_static_weights must set to True
cal_metric_during_train
:
True
cal_metric_during_train
:
True
pretrained_model
:
pretrained_model
:
...
@@ -18,22 +18,19 @@ Global:
...
@@ -18,22 +18,19 @@ Global:
character_dict_path
:
character_dict_path
:
character_type
:
en
character_type
:
en
max_text_length
:
25
max_text_length
:
25
loss_type
:
ctc
infer_mode
:
False
infer_mode
:
False
# use_space_char: True
use_space_char
:
False
# use_tps: False
Optimizer
:
Optimizer
:
name
:
Adam
name
:
Adam
beta1
:
0.9
beta1
:
0.9
beta2
:
0.999
beta2
:
0.999
l
earning_rate
:
l
r
:
l
r
:
0.0005
l
earning_rate
:
0.0005
regularizer
:
regularizer
:
name
:
'
L2'
name
:
'
L2'
factor
:
0
.00001
factor
:
0
Architecture
:
Architecture
:
model_type
:
rec
model_type
:
rec
...
@@ -49,7 +46,7 @@ Architecture:
...
@@ -49,7 +46,7 @@ Architecture:
hidden_size
:
96
hidden_size
:
96
Head
:
Head
:
name
:
CTCHead
name
:
CTCHead
fc_decay
:
0
.0004
fc_decay
:
0
Loss
:
Loss
:
name
:
CTCLoss
name
:
CTCLoss
...
@@ -75,8 +72,8 @@ Train:
...
@@ -75,8 +72,8 @@ Train:
-
KeepKeys
:
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
label'
,
'
length'
]
# dataloader will return list in this order
keep_keys
:
[
'
image'
,
'
label'
,
'
length'
]
# dataloader will return list in this order
loader
:
loader
:
shuffle
:
True
batch_size_per_card
:
256
batch_size_per_card
:
256
shuffle
:
False
drop_last
:
True
drop_last
:
True
num_workers
:
8
num_workers
:
8
...
@@ -97,4 +94,4 @@ Eval:
...
@@ -97,4 +94,4 @@ Eval:
shuffle
:
False
shuffle
:
False
drop_last
:
False
drop_last
:
False
batch_size_per_card
:
256
batch_size_per_card
:
256
num_workers
:
2
num_workers
:
4
ppocr/data/lmdb_dataset.py
浏览文件 @
c93b4a17
...
@@ -11,13 +11,9 @@
...
@@ -11,13 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
copy
import
numpy
as
np
import
numpy
as
np
import
os
import
os
import
random
import
paddle
from
paddle.io
import
Dataset
from
paddle.io
import
Dataset
import
time
import
lmdb
import
lmdb
import
cv2
import
cv2
...
...
ppocr/data/simple_dataset.py
浏览文件 @
c93b4a17
...
@@ -11,13 +11,10 @@
...
@@ -11,13 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
copy
import
numpy
as
np
import
numpy
as
np
import
os
import
os
import
random
import
random
import
paddle
from
paddle.io
import
Dataset
from
paddle.io
import
Dataset
import
time
from
.imaug
import
transform
,
create_operators
from
.imaug
import
transform
,
create_operators
...
...
ppocr/metrics/__init__.py
浏览文件 @
c93b4a17
...
@@ -23,8 +23,8 @@ __all__ = ['build_metric']
...
@@ -23,8 +23,8 @@ __all__ = ['build_metric']
def
build_metric
(
config
):
def
build_metric
(
config
):
from
.
DetM
etric
import
DetMetric
from
.
det_m
etric
import
DetMetric
from
.
RecM
etric
import
RecMetric
from
.
rec_m
etric
import
RecMetric
support_dict
=
[
'DetMetric'
,
'RecMetric'
]
support_dict
=
[
'DetMetric'
,
'RecMetric'
]
...
...
ppocr/metrics/
DetM
etric.py
→
ppocr/metrics/
det_m
etric.py
浏览文件 @
c93b4a17
文件已移动
ppocr/metrics/
RecM
etric.py
→
ppocr/metrics/
rec_m
etric.py
浏览文件 @
c93b4a17
文件已移动
ppocr/modeling/heads/det_db_head.py
浏览文件 @
c93b4a17
...
@@ -58,7 +58,7 @@ class Head(nn.Layer):
...
@@ -58,7 +58,7 @@ class Head(nn.Layer):
stride
=
2
,
stride
=
2
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
name
=
name_list
[
2
]
+
'.w_0'
,
name
=
name_list
[
2
]
+
'.w_0'
,
initializer
=
paddle
.
nn
.
initializer
.
Kaiming
Normal
()),
initializer
=
paddle
.
nn
.
initializer
.
Kaiming
Uniform
()),
bias_attr
=
get_bias_attr
(
in_channels
//
4
,
name_list
[
-
1
]
+
"conv2"
))
bias_attr
=
get_bias_attr
(
in_channels
//
4
,
name_list
[
-
1
]
+
"conv2"
))
self
.
conv_bn2
=
nn
.
BatchNorm
(
self
.
conv_bn2
=
nn
.
BatchNorm
(
num_channels
=
in_channels
//
4
,
num_channels
=
in_channels
//
4
,
...
@@ -78,7 +78,7 @@ class Head(nn.Layer):
...
@@ -78,7 +78,7 @@ class Head(nn.Layer):
stride
=
2
,
stride
=
2
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
name
=
name_list
[
4
]
+
'.w_0'
,
name
=
name_list
[
4
]
+
'.w_0'
,
initializer
=
paddle
.
nn
.
initializer
.
Kaiming
Normal
()),
initializer
=
paddle
.
nn
.
initializer
.
Kaiming
Uniform
()),
bias_attr
=
get_bias_attr
(
in_channels
//
4
,
name_list
[
-
1
]
+
"conv3"
),
bias_attr
=
get_bias_attr
(
in_channels
//
4
,
name_list
[
-
1
]
+
"conv3"
),
)
)
...
...
ppocr/modeling/necks/db_fpn.py
浏览文件 @
c93b4a17
...
@@ -26,7 +26,7 @@ class DBFPN(nn.Layer):
...
@@ -26,7 +26,7 @@ class DBFPN(nn.Layer):
def
__init__
(
self
,
in_channels
,
out_channels
,
**
kwargs
):
def
__init__
(
self
,
in_channels
,
out_channels
,
**
kwargs
):
super
(
DBFPN
,
self
).
__init__
()
super
(
DBFPN
,
self
).
__init__
()
self
.
out_channels
=
out_channels
self
.
out_channels
=
out_channels
weight_attr
=
paddle
.
nn
.
initializer
.
Kaiming
Normal
()
weight_attr
=
paddle
.
nn
.
initializer
.
Kaiming
Uniform
()
self
.
in2_conv
=
nn
.
Conv2D
(
self
.
in2_conv
=
nn
.
Conv2D
(
in_channels
=
in_channels
[
0
],
in_channels
=
in_channels
[
0
],
...
@@ -97,17 +97,20 @@ class DBFPN(nn.Layer):
...
@@ -97,17 +97,20 @@ class DBFPN(nn.Layer):
in3
=
self
.
in3_conv
(
c3
)
in3
=
self
.
in3_conv
(
c3
)
in2
=
self
.
in2_conv
(
c2
)
in2
=
self
.
in2_conv
(
c2
)
out4
=
in4
+
F
.
upsample
(
in5
,
scale_factor
=
2
,
mode
=
"nearest"
)
# 1/16
out4
=
in4
+
F
.
upsample
(
out3
=
in3
+
F
.
upsample
(
out4
,
scale_factor
=
2
,
mode
=
"nearest"
)
# 1/8
in5
,
scale_factor
=
2
,
mode
=
"nearest"
,
align_mode
=
1
)
# 1/16
out2
=
in2
+
F
.
upsample
(
out3
,
scale_factor
=
2
,
mode
=
"nearest"
)
# 1/4
out3
=
in3
+
F
.
upsample
(
out4
,
scale_factor
=
2
,
mode
=
"nearest"
,
align_mode
=
1
)
# 1/8
out2
=
in2
+
F
.
upsample
(
out3
,
scale_factor
=
2
,
mode
=
"nearest"
,
align_mode
=
1
)
# 1/4
p5
=
self
.
p5_conv
(
in5
)
p5
=
self
.
p5_conv
(
in5
)
p4
=
self
.
p4_conv
(
out4
)
p4
=
self
.
p4_conv
(
out4
)
p3
=
self
.
p3_conv
(
out3
)
p3
=
self
.
p3_conv
(
out3
)
p2
=
self
.
p2_conv
(
out2
)
p2
=
self
.
p2_conv
(
out2
)
p5
=
F
.
upsample
(
p5
,
scale_factor
=
8
,
mode
=
"nearest"
)
p5
=
F
.
upsample
(
p5
,
scale_factor
=
8
,
mode
=
"nearest"
,
align_mode
=
1
)
p4
=
F
.
upsample
(
p4
,
scale_factor
=
4
,
mode
=
"nearest"
)
p4
=
F
.
upsample
(
p4
,
scale_factor
=
4
,
mode
=
"nearest"
,
align_mode
=
1
)
p3
=
F
.
upsample
(
p3
,
scale_factor
=
2
,
mode
=
"nearest"
)
p3
=
F
.
upsample
(
p3
,
scale_factor
=
2
,
mode
=
"nearest"
,
align_mode
=
1
)
fuse
=
paddle
.
concat
([
p5
,
p4
,
p3
,
p2
],
axis
=
1
)
fuse
=
paddle
.
concat
([
p5
,
p4
,
p3
,
p2
],
axis
=
1
)
return
fuse
return
fuse
ppocr/optimizer/__init__.py
浏览文件 @
c93b4a17
...
@@ -29,7 +29,7 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
...
@@ -29,7 +29,7 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
lr_name
=
lr_config
.
pop
(
'name'
)
lr_name
=
lr_config
.
pop
(
'name'
)
lr
=
getattr
(
learning_rate
,
lr_name
)(
**
lr_config
)()
lr
=
getattr
(
learning_rate
,
lr_name
)(
**
lr_config
)()
else
:
else
:
lr
=
lr_config
[
'l
r
'
]
lr
=
lr_config
[
'l
earning_rate
'
]
return
lr
return
lr
...
@@ -37,8 +37,7 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
...
@@ -37,8 +37,7 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
from
.
import
regularizer
,
optimizer
from
.
import
regularizer
,
optimizer
config
=
copy
.
deepcopy
(
config
)
config
=
copy
.
deepcopy
(
config
)
# step1 build lr
# step1 build lr
lr
=
build_lr_scheduler
(
lr
=
build_lr_scheduler
(
config
.
pop
(
'lr'
),
epochs
,
step_each_epoch
)
config
.
pop
(
'learning_rate'
),
epochs
,
step_each_epoch
)
# step2 build regularization
# step2 build regularization
if
'regularizer'
in
config
and
config
[
'regularizer'
]
is
not
None
:
if
'regularizer'
in
config
and
config
[
'regularizer'
]
is
not
None
:
...
...
ppocr/optimizer/learning_rate.py
浏览文件 @
c93b4a17
...
@@ -17,7 +17,7 @@ from __future__ import division
...
@@ -17,7 +17,7 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
from
__future__
import
unicode_literals
from
__future__
import
unicode_literals
from
paddle.optimizer
import
lr
as
lr_scheduler
from
paddle.optimizer
import
lr
class
Linear
(
object
):
class
Linear
(
object
):
...
@@ -32,7 +32,7 @@ class Linear(object):
...
@@ -32,7 +32,7 @@ class Linear(object):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
l
r
,
l
earning_rate
,
epochs
,
epochs
,
step_each_epoch
,
step_each_epoch
,
end_lr
=
0.0
,
end_lr
=
0.0
,
...
@@ -41,7 +41,7 @@ class Linear(object):
...
@@ -41,7 +41,7 @@ class Linear(object):
last_epoch
=-
1
,
last_epoch
=-
1
,
**
kwargs
):
**
kwargs
):
super
(
Linear
,
self
).
__init__
()
super
(
Linear
,
self
).
__init__
()
self
.
l
r
=
lr
self
.
l
earning_rate
=
learning_rate
self
.
epochs
=
epochs
*
step_each_epoch
self
.
epochs
=
epochs
*
step_each_epoch
self
.
end_lr
=
end_lr
self
.
end_lr
=
end_lr
self
.
power
=
power
self
.
power
=
power
...
@@ -49,18 +49,18 @@ class Linear(object):
...
@@ -49,18 +49,18 @@ class Linear(object):
self
.
warmup_epoch
=
warmup_epoch
*
step_each_epoch
self
.
warmup_epoch
=
warmup_epoch
*
step_each_epoch
def
__call__
(
self
):
def
__call__
(
self
):
learning_rate
=
lr
_scheduler
.
PolynomialLR
(
learning_rate
=
lr
.
PolynomialDecay
(
learning_rate
=
self
.
l
r
,
learning_rate
=
self
.
l
earning_rate
,
decay_steps
=
self
.
epochs
,
decay_steps
=
self
.
epochs
,
end_lr
=
self
.
end_lr
,
end_lr
=
self
.
end_lr
,
power
=
self
.
power
,
power
=
self
.
power
,
last_epoch
=
self
.
last_epoch
)
last_epoch
=
self
.
last_epoch
)
if
self
.
warmup_epoch
>
0
:
if
self
.
warmup_epoch
>
0
:
learning_rate
=
lr
_scheduler
.
LinearL
rWarmup
(
learning_rate
=
lr
.
Linea
rWarmup
(
learning_rate
=
learning_rate
,
learning_rate
=
learning_rate
,
warmup_steps
=
self
.
warmup_epoch
,
warmup_steps
=
self
.
warmup_epoch
,
start_lr
=
0.0
,
start_lr
=
0.0
,
end_lr
=
self
.
l
r
,
end_lr
=
self
.
l
earning_rate
,
last_epoch
=
self
.
last_epoch
)
last_epoch
=
self
.
last_epoch
)
return
learning_rate
return
learning_rate
...
@@ -77,27 +77,29 @@ class Cosine(object):
...
@@ -77,27 +77,29 @@ class Cosine(object):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
l
r
,
l
earning_rate
,
step_each_epoch
,
step_each_epoch
,
epochs
,
epochs
,
warmup_epoch
=
0
,
warmup_epoch
=
0
,
last_epoch
=-
1
,
last_epoch
=-
1
,
**
kwargs
):
**
kwargs
):
super
(
Cosine
,
self
).
__init__
()
super
(
Cosine
,
self
).
__init__
()
self
.
l
r
=
lr
self
.
l
earning_rate
=
learning_rate
self
.
T_max
=
step_each_epoch
*
epochs
self
.
T_max
=
step_each_epoch
*
epochs
self
.
last_epoch
=
last_epoch
self
.
last_epoch
=
last_epoch
self
.
warmup_epoch
=
warmup_epoch
*
step_each_epoch
self
.
warmup_epoch
=
warmup_epoch
*
step_each_epoch
def
__call__
(
self
):
def
__call__
(
self
):
learning_rate
=
lr_scheduler
.
CosineAnnealingLR
(
learning_rate
=
lr
.
CosineAnnealingDecay
(
learning_rate
=
self
.
lr
,
T_max
=
self
.
T_max
,
last_epoch
=
self
.
last_epoch
)
learning_rate
=
self
.
learning_rate
,
T_max
=
self
.
T_max
,
last_epoch
=
self
.
last_epoch
)
if
self
.
warmup_epoch
>
0
:
if
self
.
warmup_epoch
>
0
:
learning_rate
=
lr
_scheduler
.
LinearL
rWarmup
(
learning_rate
=
lr
.
Linea
rWarmup
(
learning_rate
=
learning_rate
,
learning_rate
=
learning_rate
,
warmup_steps
=
self
.
warmup_epoch
,
warmup_steps
=
self
.
warmup_epoch
,
start_lr
=
0.0
,
start_lr
=
0.0
,
end_lr
=
self
.
l
r
,
end_lr
=
self
.
l
earning_rate
,
last_epoch
=
self
.
last_epoch
)
last_epoch
=
self
.
last_epoch
)
return
learning_rate
return
learning_rate
...
@@ -115,7 +117,7 @@ class Step(object):
...
@@ -115,7 +117,7 @@ class Step(object):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
l
r
,
l
earning_rate
,
step_size
,
step_size
,
step_each_epoch
,
step_each_epoch
,
gamma
,
gamma
,
...
@@ -124,23 +126,23 @@ class Step(object):
...
@@ -124,23 +126,23 @@ class Step(object):
**
kwargs
):
**
kwargs
):
super
(
Step
,
self
).
__init__
()
super
(
Step
,
self
).
__init__
()
self
.
step_size
=
step_each_epoch
*
step_size
self
.
step_size
=
step_each_epoch
*
step_size
self
.
l
r
=
lr
self
.
l
earning_rate
=
learning_rate
self
.
gamma
=
gamma
self
.
gamma
=
gamma
self
.
last_epoch
=
last_epoch
self
.
last_epoch
=
last_epoch
self
.
warmup_epoch
=
warmup_epoch
*
step_each_epoch
self
.
warmup_epoch
=
warmup_epoch
*
step_each_epoch
def
__call__
(
self
):
def
__call__
(
self
):
learning_rate
=
lr
_scheduler
.
StepLR
(
learning_rate
=
lr
.
StepDecay
(
learning_rate
=
self
.
l
r
,
learning_rate
=
self
.
l
earning_rate
,
step_size
=
self
.
step_size
,
step_size
=
self
.
step_size
,
gamma
=
self
.
gamma
,
gamma
=
self
.
gamma
,
last_epoch
=
self
.
last_epoch
)
last_epoch
=
self
.
last_epoch
)
if
self
.
warmup_epoch
>
0
:
if
self
.
warmup_epoch
>
0
:
learning_rate
=
lr
_scheduler
.
LinearL
rWarmup
(
learning_rate
=
lr
.
Linea
rWarmup
(
learning_rate
=
learning_rate
,
learning_rate
=
learning_rate
,
warmup_steps
=
self
.
warmup_epoch
,
warmup_steps
=
self
.
warmup_epoch
,
start_lr
=
0.0
,
start_lr
=
0.0
,
end_lr
=
self
.
l
r
,
end_lr
=
self
.
l
earning_rate
,
last_epoch
=
self
.
last_epoch
)
last_epoch
=
self
.
last_epoch
)
return
learning_rate
return
learning_rate
...
@@ -169,12 +171,12 @@ class Piecewise(object):
...
@@ -169,12 +171,12 @@ class Piecewise(object):
self
.
warmup_epoch
=
warmup_epoch
*
step_each_epoch
self
.
warmup_epoch
=
warmup_epoch
*
step_each_epoch
def
__call__
(
self
):
def
__call__
(
self
):
learning_rate
=
lr
_scheduler
.
PiecewiseLR
(
learning_rate
=
lr
.
PiecewiseDecay
(
boundaries
=
self
.
boundaries
,
boundaries
=
self
.
boundaries
,
values
=
self
.
values
,
values
=
self
.
values
,
last_epoch
=
self
.
last_epoch
)
last_epoch
=
self
.
last_epoch
)
if
self
.
warmup_epoch
>
0
:
if
self
.
warmup_epoch
>
0
:
learning_rate
=
lr
_scheduler
.
LinearL
rWarmup
(
learning_rate
=
lr
.
Linea
rWarmup
(
learning_rate
=
learning_rate
,
learning_rate
=
learning_rate
,
warmup_steps
=
self
.
warmup_epoch
,
warmup_steps
=
self
.
warmup_epoch
,
start_lr
=
0.0
,
start_lr
=
0.0
,
...
...
ppocr/utils/logging.py
浏览文件 @
c93b4a17
...
@@ -22,7 +22,7 @@ logger_initialized = {}
...
@@ -22,7 +22,7 @@ logger_initialized = {}
@
functools
.
lru_cache
()
@
functools
.
lru_cache
()
def
get_logger
(
name
=
'
ppocr
'
,
log_file
=
None
,
log_level
=
logging
.
INFO
):
def
get_logger
(
name
=
'
root
'
,
log_file
=
None
,
log_level
=
logging
.
INFO
):
"""Initialize and get a logger by name.
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
logger by adding one or two handlers, otherwise the initialized logger will
...
...
tools/program.py
浏览文件 @
c93b4a17
...
@@ -152,7 +152,6 @@ def train(config,
...
@@ -152,7 +152,6 @@ def train(config,
pre_best_model_dict
,
pre_best_model_dict
,
logger
,
logger
,
vdl_writer
=
None
):
vdl_writer
=
None
):
cal_metric_during_train
=
config
[
'Global'
].
get
(
'cal_metric_during_train'
,
cal_metric_during_train
=
config
[
'Global'
].
get
(
'cal_metric_during_train'
,
False
)
False
)
log_smooth_window
=
config
[
'Global'
][
'log_smooth_window'
]
log_smooth_window
=
config
[
'Global'
][
'log_smooth_window'
]
...
@@ -185,14 +184,13 @@ def train(config,
...
@@ -185,14 +184,13 @@ def train(config,
for
epoch
in
range
(
start_epoch
,
epoch_num
):
for
epoch
in
range
(
start_epoch
,
epoch_num
):
if
epoch
>
0
:
if
epoch
>
0
:
train_
loader
=
build_dataloader
(
config
,
'Train'
,
device
)
train_
dataloader
=
build_dataloader
(
config
,
'Train'
,
device
,
logger
)
for
idx
,
batch
in
enumerate
(
train_dataloader
):
for
idx
,
batch
in
enumerate
(
train_dataloader
):
if
idx
>=
len
(
train_dataloader
):
if
idx
>=
len
(
train_dataloader
):
break
break
lr
=
optimizer
.
get_lr
()
lr
=
optimizer
.
get_lr
()
t1
=
time
.
time
()
t1
=
time
.
time
()
batch
=
[
paddle
.
to_tensor
(
x
)
for
x
in
batch
]
images
=
batch
[
0
]
images
=
batch
[
0
]
preds
=
model
(
images
)
preds
=
model
(
images
)
loss
=
loss_class
(
preds
,
batch
)
loss
=
loss_class
(
preds
,
batch
)
...
@@ -301,11 +299,11 @@ def eval(model, valid_dataloader, post_process_class, eval_class, logger,
...
@@ -301,11 +299,11 @@ def eval(model, valid_dataloader, post_process_class, eval_class, logger,
with
paddle
.
no_grad
():
with
paddle
.
no_grad
():
total_frame
=
0.0
total_frame
=
0.0
total_time
=
0.0
total_time
=
0.0
#
pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
pbar
=
tqdm
(
total
=
len
(
valid_dataloader
),
desc
=
'eval model:'
)
for
idx
,
batch
in
enumerate
(
valid_dataloader
):
for
idx
,
batch
in
enumerate
(
valid_dataloader
):
if
idx
>=
len
(
valid_dataloader
):
if
idx
>=
len
(
valid_dataloader
):
break
break
images
=
paddle
.
to_tensor
(
batch
[
0
])
images
=
batch
[
0
]
start
=
time
.
time
()
start
=
time
.
time
()
preds
=
model
(
images
)
preds
=
model
(
images
)
...
@@ -315,15 +313,15 @@ def eval(model, valid_dataloader, post_process_class, eval_class, logger,
...
@@ -315,15 +313,15 @@ def eval(model, valid_dataloader, post_process_class, eval_class, logger,
total_time
+=
time
.
time
()
-
start
total_time
+=
time
.
time
()
-
start
# Evaluate the results of the current batch
# Evaluate the results of the current batch
eval_class
(
post_result
,
batch
)
eval_class
(
post_result
,
batch
)
#
pbar.update(1)
pbar
.
update
(
1
)
total_frame
+=
len
(
images
)
total_frame
+=
len
(
images
)
if
idx
%
print_batch_step
==
0
and
dist
.
get_rank
()
==
0
:
#
if idx % print_batch_step == 0 and dist.get_rank() == 0:
logger
.
info
(
'tackling images for eval: {}/{}'
.
format
(
#
logger.info('tackling images for eval: {}/{}'.format(
idx
,
len
(
valid_dataloader
)))
#
idx, len(valid_dataloader)))
# Get final metirc,eg. acc or hmean
# Get final metirc,eg. acc or hmean
metirc
=
eval_class
.
get_metric
()
metirc
=
eval_class
.
get_metric
()
#
pbar.close()
pbar
.
close
()
model
.
train
()
model
.
train
()
metirc
[
'fps'
]
=
total_frame
/
total_time
metirc
[
'fps'
]
=
total_frame
/
total_time
return
metirc
return
metirc
...
@@ -354,7 +352,8 @@ def preprocess():
...
@@ -354,7 +352,8 @@ def preprocess():
with
open
(
os
.
path
.
join
(
save_model_dir
,
'config.yml'
),
'w'
)
as
f
:
with
open
(
os
.
path
.
join
(
save_model_dir
,
'config.yml'
),
'w'
)
as
f
:
yaml
.
dump
(
dict
(
config
),
f
,
default_flow_style
=
False
,
sort_keys
=
False
)
yaml
.
dump
(
dict
(
config
),
f
,
default_flow_style
=
False
,
sort_keys
=
False
)
logger
=
get_logger
(
log_file
=
'{}/train.log'
.
format
(
save_model_dir
))
logger
=
get_logger
(
name
=
'root'
,
log_file
=
'{}/train.log'
.
format
(
save_model_dir
))
if
config
[
'Global'
][
'use_visualdl'
]:
if
config
[
'Global'
][
'use_visualdl'
]:
from
visualdl
import
LogWriter
from
visualdl
import
LogWriter
vdl_writer_path
=
'{}/vdl/'
.
format
(
save_model_dir
)
vdl_writer_path
=
'{}/vdl/'
.
format
(
save_model_dir
)
...
...
tools/train.py
浏览文件 @
c93b4a17
...
@@ -36,7 +36,6 @@ from ppocr.optimizer import build_optimizer
...
@@ -36,7 +36,6 @@ from ppocr.optimizer import build_optimizer
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.metrics
import
build_metric
from
ppocr.metrics
import
build_metric
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.utility
import
print_dict
import
tools.program
as
program
import
tools.program
as
program
dist
.
get_world_size
()
dist
.
get_world_size
()
...
@@ -61,7 +60,7 @@ def main(config, device, logger, vdl_writer):
...
@@ -61,7 +60,7 @@ def main(config, device, logger, vdl_writer):
global_config
)
global_config
)
# build model
# build model
#for rec algorithm
#
for rec algorithm
if
hasattr
(
post_process_class
,
'character'
):
if
hasattr
(
post_process_class
,
'character'
):
char_num
=
len
(
getattr
(
post_process_class
,
'character'
))
char_num
=
len
(
getattr
(
post_process_class
,
'character'
))
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
...
@@ -81,10 +80,11 @@ def main(config, device, logger, vdl_writer):
...
@@ -81,10 +80,11 @@ def main(config, device, logger, vdl_writer):
# build metric
# build metric
eval_class
=
build_metric
(
config
[
'Metric'
])
eval_class
=
build_metric
(
config
[
'Metric'
])
# load pretrain model
# load pretrain model
pre_best_model_dict
=
init_model
(
config
,
model
,
logger
,
optimizer
)
pre_best_model_dict
=
init_model
(
config
,
model
,
logger
,
optimizer
)
logger
.
info
(
'train dataloader has {} iters, valid dataloader has {} iters'
.
format
(
len
(
train_dataloader
),
len
(
valid_dataloader
)))
# start train
# start train
program
.
train
(
config
,
train_dataloader
,
valid_dataloader
,
device
,
model
,
program
.
train
(
config
,
train_dataloader
,
valid_dataloader
,
device
,
model
,
loss_class
,
optimizer
,
lr_scheduler
,
post_process_class
,
loss_class
,
optimizer
,
lr_scheduler
,
post_process_class
,
...
@@ -92,8 +92,7 @@ def main(config, device, logger, vdl_writer):
...
@@ -92,8 +92,7 @@ def main(config, device, logger, vdl_writer):
def
test_reader
(
config
,
device
,
logger
):
def
test_reader
(
config
,
device
,
logger
):
loader
=
build_dataloader
(
config
,
'Train'
,
device
)
loader
=
build_dataloader
(
config
,
'Train'
,
device
,
logger
)
# loader = build_dataloader(config, 'Eval', device)
import
time
import
time
starttime
=
time
.
time
()
starttime
=
time
.
time
()
count
=
0
count
=
0
...
...
train.sh
浏览文件 @
c93b4a17
python
-m
paddle.distributed.launch
--selected_gpus
'0,1,2,3,4,5,6,7'
tools/train.py
-c
configs/det/det_mv3_db.yml
python3
-m
paddle.distributed.launch
--selected_gpus
'0,1,2,3,4,5,6,7'
tools/train.py
-c
configs/rec/rec_mv3_none_bilstm_ctc.yml
\ No newline at end of file
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录