未验证 提交 f1048e29 编写于 作者: D dyning 提交者: GitHub

Merge pull request #970 from WenmuZhou/dygraph

修复 CRNN 训练时解码 labels 会被错误去重(合并相邻重复字符)的 bug
...@@ -21,18 +21,6 @@ from paddle import nn ...@@ -21,18 +21,6 @@ from paddle import nn
from ppocr.modeling.heads.rec_ctc_head import get_para_bias_attr from ppocr.modeling.heads.rec_ctc_head import get_para_bias_attr
class EncoderWithReshape(nn.Layer):
def __init__(self, in_channels, **kwargs):
super().__init__()
self.out_channels = in_channels
def forward(self, x):
B, C, H, W = x.shape
x = x.reshape((B, C, -1))
x = x.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
return x
class Im2Seq(nn.Layer): class Im2Seq(nn.Layer):
def __init__(self, in_channels, **kwargs): def __init__(self, in_channels, **kwargs):
super().__init__() super().__init__()
...@@ -40,9 +28,8 @@ class Im2Seq(nn.Layer): ...@@ -40,9 +28,8 @@ class Im2Seq(nn.Layer):
def forward(self, x): def forward(self, x):
B, C, H, W = x.shape B, C, H, W = x.shape
assert H == 1 x = x.reshape((B, -1, W))
x = x.transpose((0, 2, 3, 1)) x = x.transpose((0, 2, 1)) # (NTC)(batch, width, channels)
x = x.reshape((-1, C))
return x return x
...@@ -50,49 +37,10 @@ class EncoderWithRNN(nn.Layer): ...@@ -50,49 +37,10 @@ class EncoderWithRNN(nn.Layer):
def __init__(self, in_channels, hidden_size): def __init__(self, in_channels, hidden_size):
super(EncoderWithRNN, self).__init__() super(EncoderWithRNN, self).__init__()
self.out_channels = hidden_size * 2 self.out_channels = hidden_size * 2
# self.lstm1_fw = nn.LSTMCell(
# in_channels,
# hidden_size,
# weight_ih_attr=ParamAttr(name='lstm_st1_fc1_w'),
# bias_ih_attr=ParamAttr(name='lstm_st1_fc1_b'),
# weight_hh_attr=ParamAttr(name='lstm_st1_out1_w'),
# bias_hh_attr=ParamAttr(name='lstm_st1_out1_b'),
# )
# self.lstm1_bw = nn.LSTMCell(
# in_channels,
# hidden_size,
# weight_ih_attr=ParamAttr(name='lstm_st1_fc2_w'),
# bias_ih_attr=ParamAttr(name='lstm_st1_fc2_b'),
# weight_hh_attr=ParamAttr(name='lstm_st1_out2_w'),
# bias_hh_attr=ParamAttr(name='lstm_st1_out2_b'),
# )
# self.lstm2_fw = nn.LSTMCell(
# hidden_size,
# hidden_size,
# weight_ih_attr=ParamAttr(name='lstm_st2_fc1_w'),
# bias_ih_attr=ParamAttr(name='lstm_st2_fc1_b'),
# weight_hh_attr=ParamAttr(name='lstm_st2_out1_w'),
# bias_hh_attr=ParamAttr(name='lstm_st2_out1_b'),
# )
# self.lstm2_bw = nn.LSTMCell(
# hidden_size,
# hidden_size,
# weight_ih_attr=ParamAttr(name='lstm_st2_fc2_w'),
# bias_ih_attr=ParamAttr(name='lstm_st2_fc2_b'),
# weight_hh_attr=ParamAttr(name='lstm_st2_out2_w'),
# bias_hh_attr=ParamAttr(name='lstm_st2_out2_b'),
# )
self.lstm = nn.LSTM( self.lstm = nn.LSTM(
in_channels, hidden_size, direction='bidirectional', num_layers=2) in_channels, hidden_size, direction='bidirectional', num_layers=2)
def forward(self, x): def forward(self, x):
# fw_x, _ = self.lstm1_fw(x)
# fw_x, _ = self.lstm2_fw(fw_x)
#
# # bw
# bw_x, _ = self.lstm1_bw(x)
# bw_x, _ = self.lstm2_bw(bw_x)
# x = paddle.concat([fw_x, bw_x], axis=2)
x, _ = self.lstm(x) x, _ = self.lstm(x)
return x return x
...@@ -118,13 +66,13 @@ class EncoderWithFC(nn.Layer): ...@@ -118,13 +66,13 @@ class EncoderWithFC(nn.Layer):
class SequenceEncoder(nn.Layer): class SequenceEncoder(nn.Layer):
def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs): def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs):
super(SequenceEncoder, self).__init__() super(SequenceEncoder, self).__init__()
self.encoder_reshape = EncoderWithReshape(in_channels) self.encoder_reshape = Im2Seq(in_channels)
self.out_channels = self.encoder_reshape.out_channels self.out_channels = self.encoder_reshape.out_channels
if encoder_type == 'reshape': if encoder_type == 'reshape':
self.only_reshape = True self.only_reshape = True
else: else:
support_encoder_dict = { support_encoder_dict = {
'reshape': EncoderWithReshape, 'reshape': Im2Seq,
'fc': EncoderWithFC, 'fc': EncoderWithFC,
'rnn': EncoderWithRNN 'rnn': EncoderWithRNN
} }
......
...@@ -70,6 +70,7 @@ class BaseRecLabelDecode(object): ...@@ -70,6 +70,7 @@ class BaseRecLabelDecode(object):
if text_index[batch_idx][idx] in ignored_tokens: if text_index[batch_idx][idx] in ignored_tokens:
continue continue
if is_remove_duplicate: if is_remove_duplicate:
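# only for predict: skip an index equal to the previous one (collapse consecutive duplicates); ground-truth labels are decoded with is_remove_duplicate=False so repeated characters are kept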
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]: batch_idx][idx]:
continue continue
...@@ -107,7 +108,7 @@ class CTCLabelDecode(BaseRecLabelDecode): ...@@ -107,7 +108,7 @@ class CTCLabelDecode(BaseRecLabelDecode):
text = self.decode(preds_idx, preds_prob) text = self.decode(preds_idx, preds_prob)
if label is None: if label is None:
return text return text
label = self.decode(label) label = self.decode(label, is_remove_duplicate=False)
return text, label return text, label
def add_special_char(self, dict_character): def add_special_char(self, dict_character):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册