Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
ed2f0de9
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ed2f0de9
编写于
1月 22, 2021
作者:
T
tink2123
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
mv model_average to incubate
上级
93670ab5
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
12 addition
and
9 deletion
+12
-9
ppocr/losses/rec_srn_loss.py
ppocr/losses/rec_srn_loss.py
+1
-1
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+2
-2
tools/program.py
tools/program.py
+9
-6
未找到文件。
ppocr/losses/rec_srn_loss.py
浏览文件 @
ed2f0de9
...
@@ -42,6 +42,6 @@ class SRNLoss(nn.Layer):
...
@@ -42,6 +42,6 @@ class SRNLoss(nn.Layer):
cost_gsrm
=
paddle
.
reshape
(
x
=
paddle
.
sum
(
cost_gsrm
),
shape
=
[
1
])
cost_gsrm
=
paddle
.
reshape
(
x
=
paddle
.
sum
(
cost_gsrm
),
shape
=
[
1
])
cost_vsfd
=
paddle
.
reshape
(
x
=
paddle
.
sum
(
cost_vsfd
),
shape
=
[
1
])
cost_vsfd
=
paddle
.
reshape
(
x
=
paddle
.
sum
(
cost_vsfd
),
shape
=
[
1
])
sum_cost
=
cost_word
+
cost_vsfd
*
2.0
+
cost_gsrm
*
0.15
sum_cost
=
cost_word
*
3.0
+
cost_vsfd
+
cost_gsrm
*
0.15
return
{
'loss'
:
sum_cost
,
'word_loss'
:
cost_word
,
'img_loss'
:
cost_vsfd
}
return
{
'loss'
:
sum_cost
,
'word_loss'
:
cost_word
,
'img_loss'
:
cost_vsfd
}
ppocr/postprocess/rec_postprocess.py
浏览文件 @
ed2f0de9
...
@@ -182,12 +182,12 @@ class SRNLabelDecode(BaseRecLabelDecode):
...
@@ -182,12 +182,12 @@ class SRNLabelDecode(BaseRecLabelDecode):
preds_prob
=
np
.
reshape
(
preds_prob
,
[
-
1
,
25
])
preds_prob
=
np
.
reshape
(
preds_prob
,
[
-
1
,
25
])
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
True
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
)
if
label
is
None
:
if
label
is
None
:
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
return
text
return
text
label
=
self
.
decode
(
label
,
is_remove_duplicate
=
True
)
label
=
self
.
decode
(
label
)
return
text
,
label
return
text
,
label
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
...
...
tools/program.py
浏览文件 @
ed2f0de9
...
@@ -174,6 +174,7 @@ def train(config,
...
@@ -174,6 +174,7 @@ def train(config,
best_model_dict
=
{
main_indicator
:
0
}
best_model_dict
=
{
main_indicator
:
0
}
best_model_dict
.
update
(
pre_best_model_dict
)
best_model_dict
.
update
(
pre_best_model_dict
)
train_stats
=
TrainingStats
(
log_smooth_window
,
[
'lr'
])
train_stats
=
TrainingStats
(
log_smooth_window
,
[
'lr'
])
model_average
=
False
model
.
train
()
model
.
train
()
if
'start_epoch'
in
best_model_dict
:
if
'start_epoch'
in
best_model_dict
:
...
@@ -197,6 +198,7 @@ def train(config,
...
@@ -197,6 +198,7 @@ def train(config,
if
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
:
if
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
:
others
=
batch
[
-
4
:]
others
=
batch
[
-
4
:]
preds
=
model
(
images
,
others
)
preds
=
model
(
images
,
others
)
model_average
=
True
else
:
else
:
preds
=
model
(
images
)
preds
=
model
(
images
)
loss
=
loss_class
(
preds
,
batch
)
loss
=
loss_class
(
preds
,
batch
)
...
@@ -242,12 +244,13 @@ def train(config,
...
@@ -242,12 +244,13 @@ def train(config,
# eval
# eval
if
global_step
>
start_eval_step
and
\
if
global_step
>
start_eval_step
and
\
(
global_step
-
start_eval_step
)
%
eval_batch_step
==
0
and
dist
.
get_rank
()
==
0
:
(
global_step
-
start_eval_step
)
%
eval_batch_step
==
0
and
dist
.
get_rank
()
==
0
:
model_average
=
paddle
.
optimizer
.
ModelAverage
(
if
model_average
:
0.15
,
Model_Average
=
paddle
.
incubate
.
optimizer
.
ModelAverage
(
parameters
=
model
.
parameters
(),
0.15
,
min_average_window
=
10000
,
parameters
=
model
.
parameters
(),
max_average_window
=
15625
)
min_average_window
=
10000
,
model_average
.
apply
()
max_average_window
=
15625
)
Model_Average
.
apply
()
cur_metirc
=
eval
(
model
,
valid_dataloader
,
post_process_class
,
cur_metirc
=
eval
(
model
,
valid_dataloader
,
post_process_class
,
eval_class
)
eval_class
)
cur_metirc_str
=
'cur metirc, {}'
.
format
(
', '
.
join
(
cur_metirc_str
=
'cur metirc, {}'
.
format
(
', '
.
join
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录