PaddlePaddle / PaddleOCR

Commit ca9ea622
Authored Oct 20, 2020 by WenmuZhou
Add im2seq implementation
Parent: bdad0cef

Showing 1 changed file with 4 additions and 56 deletions.

ppocr/modeling/necks/rnn.py  +4  -56

@@ -21,18 +21,6 @@ from paddle import nn
 from ppocr.modeling.heads.rec_ctc_head import get_para_bias_attr


-class EncoderWithReshape(nn.Layer):
-    def __init__(self, in_channels, **kwargs):
-        super().__init__()
-        self.out_channels = in_channels
-
-    def forward(self, x):
-        B, C, H, W = x.shape
-        x = x.reshape((B, C, -1))
-        x = x.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
-        return x
-
-
 class Im2Seq(nn.Layer):
     def __init__(self, in_channels, **kwargs):
         super().__init__()
@@ -40,9 +28,8 @@ class Im2Seq(nn.Layer):

     def forward(self, x):
         B, C, H, W = x.shape
-        assert H == 1
-        x = x.transpose((0, 2, 3, 1))
-        x = x.reshape((-1, C))
+        x = x.reshape((B, -1, W))
+        x = x.transpose((0, 2, 1))  # (NTC)(batch, width, channels)
         return x

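For orientation, here is a minimal shape check of the new Im2Seq.forward reconstructed above, run standalone. It assumes a PaddlePaddle 2.x runtime; the batch size, channel count, and width (2, 96, 25) are made-up illustrative values, not anything prescribed by the commit.

import paddle

# Illustrative backbone feature map for text recognition: N=2, C=96, H=1, W=25 (assumed sizes).
x = paddle.randn([2, 96, 1, 25])

B, C, H, W = x.shape
x = x.reshape((B, -1, W))   # collapse C*H into one axis: (2, 96, 25)
x = x.transpose((0, 2, 1))  # (NTC)(batch, width, channels): (2, 25, 96)

print(x.shape)  # [2, 25, 96]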
@@ -50,49 +37,10 @@ class EncoderWithRNN(nn.Layer):
     def __init__(self, in_channels, hidden_size):
         super(EncoderWithRNN, self).__init__()
         self.out_channels = hidden_size * 2
-        # self.lstm1_fw = nn.LSTMCell(
-        #     in_channels,
-        #     hidden_size,
-        #     weight_ih_attr=ParamAttr(name='lstm_st1_fc1_w'),
-        #     bias_ih_attr=ParamAttr(name='lstm_st1_fc1_b'),
-        #     weight_hh_attr=ParamAttr(name='lstm_st1_out1_w'),
-        #     bias_hh_attr=ParamAttr(name='lstm_st1_out1_b'),
-        # )
-        # self.lstm1_bw = nn.LSTMCell(
-        #     in_channels,
-        #     hidden_size,
-        #     weight_ih_attr=ParamAttr(name='lstm_st1_fc2_w'),
-        #     bias_ih_attr=ParamAttr(name='lstm_st1_fc2_b'),
-        #     weight_hh_attr=ParamAttr(name='lstm_st1_out2_w'),
-        #     bias_hh_attr=ParamAttr(name='lstm_st1_out2_b'),
-        # )
-        # self.lstm2_fw = nn.LSTMCell(
-        #     hidden_size,
-        #     hidden_size,
-        #     weight_ih_attr=ParamAttr(name='lstm_st2_fc1_w'),
-        #     bias_ih_attr=ParamAttr(name='lstm_st2_fc1_b'),
-        #     weight_hh_attr=ParamAttr(name='lstm_st2_out1_w'),
-        #     bias_hh_attr=ParamAttr(name='lstm_st2_out1_b'),
-        # )
-        # self.lstm2_bw = nn.LSTMCell(
-        #     hidden_size,
-        #     hidden_size,
-        #     weight_ih_attr=ParamAttr(name='lstm_st2_fc2_w'),
-        #     bias_ih_attr=ParamAttr(name='lstm_st2_fc2_b'),
-        #     weight_hh_attr=ParamAttr(name='lstm_st2_out2_w'),
-        #     bias_hh_attr=ParamAttr(name='lstm_st2_out2_b'),
-        # )
         self.lstm = nn.LSTM(
             in_channels, hidden_size, direction='bidirectional', num_layers=2)

     def forward(self, x):
-        # fw_x, _ = self.lstm1_fw(x)
-        # fw_x, _ = self.lstm2_fw(fw_x)
-        #
-        # # bw
-        # bw_x, _ = self.lstm1_bw(x)
-        # bw_x, _ = self.lstm2_bw(bw_x)
-        # x = paddle.concat([fw_x, bw_x], axis=2)
         x, _ = self.lstm(x)
         return x

@@ -118,13 +66,13 @@ class EncoderWithFC(nn.Layer):
 class SequenceEncoder(nn.Layer):
     def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs):
         super(SequenceEncoder, self).__init__()
-        self.encoder_reshape = EncoderWithReshape(in_channels)
+        self.encoder_reshape = Im2Seq(in_channels)
         self.out_channels = self.encoder_reshape.out_channels
         if encoder_type == 'reshape':
             self.only_reshape = True
         else:
             support_encoder_dict = {
-                'reshape': EncoderWithReshape,
+                'reshape': Im2Seq,
                 'fc': EncoderWithFC,
                 'rnn': EncoderWithRNN
             }
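Below is a hedged usage sketch of how the neck fits together after this change: SequenceEncoder built with encoder_type='rnn' reshapes the feature map with Im2Seq and then runs EncoderWithRNN's bidirectional LSTM. The import path is the file modified by this commit, but SequenceEncoder.forward is outside the hunks shown above, so its assumed behaviour (apply encoder_reshape, then the selected encoder, as in later versions of this file) and the tensor sizes are illustrative assumptions, not something the diff guarantees.

import paddle
from ppocr.modeling.necks.rnn import SequenceEncoder  # the file changed in this commit

# Dummy backbone output: batch 2, 96 channels, height 1, width 25 (assumed sizes).
feat = paddle.randn([2, 96, 1, 25])

# hidden_size defaults to 48; the bidirectional LSTM therefore yields 48 * 2 = 96 channels.
neck = SequenceEncoder(in_channels=96, encoder_type='rnn', hidden_size=48)
out = neck(feat)

print(out.shape)  # expected [2, 25, 96]: (batch, width, hidden_size * 2)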