Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_41840029
PaddleOCR
提交
f1048e29
P
PaddleOCR
项目概览
weixin_41840029
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f1048e29
编写于
10月 20, 2020
作者:
D
dyning
提交者:
GitHub
10月 20, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #970 from WenmuZhou/dygraph
解决crnn训练时对labels进行合并的bug
上级
52b40f36
a88ce7a5
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
6 addition
and
57 deletion
+6
-57
ppocr/modeling/necks/rnn.py
ppocr/modeling/necks/rnn.py
+4
-56
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+2
-1
未找到文件。
ppocr/modeling/necks/rnn.py
浏览文件 @
f1048e29
...
...
@@ -21,18 +21,6 @@ from paddle import nn
from
ppocr.modeling.heads.rec_ctc_head
import
get_para_bias_attr
class
EncoderWithReshape
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
**
kwargs
):
super
().
__init__
()
self
.
out_channels
=
in_channels
def
forward
(
self
,
x
):
B
,
C
,
H
,
W
=
x
.
shape
x
=
x
.
reshape
((
B
,
C
,
-
1
))
x
=
x
.
transpose
([
0
,
2
,
1
])
# (NTC)(batch, width, channels)
return
x
class
Im2Seq
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
**
kwargs
):
super
().
__init__
()
...
...
@@ -40,9 +28,8 @@ class Im2Seq(nn.Layer):
def
forward
(
self
,
x
):
B
,
C
,
H
,
W
=
x
.
shape
assert
H
==
1
x
=
x
.
transpose
((
0
,
2
,
3
,
1
))
x
=
x
.
reshape
((
-
1
,
C
))
x
=
x
.
reshape
((
B
,
-
1
,
W
))
x
=
x
.
transpose
((
0
,
2
,
1
))
# (NTC)(batch, width, channels)
return
x
...
...
@@ -50,49 +37,10 @@ class EncoderWithRNN(nn.Layer):
def
__init__
(
self
,
in_channels
,
hidden_size
):
super
(
EncoderWithRNN
,
self
).
__init__
()
self
.
out_channels
=
hidden_size
*
2
# self.lstm1_fw = nn.LSTMCell(
# in_channels,
# hidden_size,
# weight_ih_attr=ParamAttr(name='lstm_st1_fc1_w'),
# bias_ih_attr=ParamAttr(name='lstm_st1_fc1_b'),
# weight_hh_attr=ParamAttr(name='lstm_st1_out1_w'),
# bias_hh_attr=ParamAttr(name='lstm_st1_out1_b'),
# )
# self.lstm1_bw = nn.LSTMCell(
# in_channels,
# hidden_size,
# weight_ih_attr=ParamAttr(name='lstm_st1_fc2_w'),
# bias_ih_attr=ParamAttr(name='lstm_st1_fc2_b'),
# weight_hh_attr=ParamAttr(name='lstm_st1_out2_w'),
# bias_hh_attr=ParamAttr(name='lstm_st1_out2_b'),
# )
# self.lstm2_fw = nn.LSTMCell(
# hidden_size,
# hidden_size,
# weight_ih_attr=ParamAttr(name='lstm_st2_fc1_w'),
# bias_ih_attr=ParamAttr(name='lstm_st2_fc1_b'),
# weight_hh_attr=ParamAttr(name='lstm_st2_out1_w'),
# bias_hh_attr=ParamAttr(name='lstm_st2_out1_b'),
# )
# self.lstm2_bw = nn.LSTMCell(
# hidden_size,
# hidden_size,
# weight_ih_attr=ParamAttr(name='lstm_st2_fc2_w'),
# bias_ih_attr=ParamAttr(name='lstm_st2_fc2_b'),
# weight_hh_attr=ParamAttr(name='lstm_st2_out2_w'),
# bias_hh_attr=ParamAttr(name='lstm_st2_out2_b'),
# )
self
.
lstm
=
nn
.
LSTM
(
in_channels
,
hidden_size
,
direction
=
'bidirectional'
,
num_layers
=
2
)
def
forward
(
self
,
x
):
# fw_x, _ = self.lstm1_fw(x)
# fw_x, _ = self.lstm2_fw(fw_x)
#
# # bw
# bw_x, _ = self.lstm1_bw(x)
# bw_x, _ = self.lstm2_bw(bw_x)
# x = paddle.concat([fw_x, bw_x], axis=2)
x
,
_
=
self
.
lstm
(
x
)
return
x
...
...
@@ -118,13 +66,13 @@ class EncoderWithFC(nn.Layer):
class
SequenceEncoder
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
encoder_type
,
hidden_size
=
48
,
**
kwargs
):
super
(
SequenceEncoder
,
self
).
__init__
()
self
.
encoder_reshape
=
EncoderWithReshape
(
in_channels
)
self
.
encoder_reshape
=
Im2Seq
(
in_channels
)
self
.
out_channels
=
self
.
encoder_reshape
.
out_channels
if
encoder_type
==
'reshape'
:
self
.
only_reshape
=
True
else
:
support_encoder_dict
=
{
'reshape'
:
EncoderWithReshape
,
'reshape'
:
Im2Seq
,
'fc'
:
EncoderWithFC
,
'rnn'
:
EncoderWithRNN
}
...
...
ppocr/postprocess/rec_postprocess.py
浏览文件 @
f1048e29
...
...
@@ -70,6 +70,7 @@ class BaseRecLabelDecode(object):
if
text_index
[
batch_idx
][
idx
]
in
ignored_tokens
:
continue
if
is_remove_duplicate
:
# only for predict
if
idx
>
0
and
text_index
[
batch_idx
][
idx
-
1
]
==
text_index
[
batch_idx
][
idx
]:
continue
...
...
@@ -107,7 +108,7 @@ class CTCLabelDecode(BaseRecLabelDecode):
text
=
self
.
decode
(
preds_idx
,
preds_prob
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
)
label
=
self
.
decode
(
label
,
is_remove_duplicate
=
False
)
return
text
,
label
def
add_special_char
(
self
,
dict_character
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录