diff --git a/configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml b/configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml
index 6b60ae0860959405fd512913d022c02e2e2dae05..791b34cf5785d81a0f1346c0ef1ad4485ed3fee8 100644
--- a/configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml
+++ b/configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml
@@ -52,9 +52,10 @@ Architecture:
       Neck:
         name: SequenceEncoder
         encoder_type: rnn
-        hidden_size: 48
+        hidden_size: 64
       Head:
         name: CTCHead
+        mid_channels: 96
         fc_decay: 0.00001
     Teacher:
       pretrained:
@@ -71,9 +72,10 @@ Architecture:
       Neck:
         name: SequenceEncoder
         encoder_type: rnn
-        hidden_size: 48
+        hidden_size: 64
       Head:
         name: CTCHead
+        mid_channels: 96
         fc_decay: 0.00001
 
 Loss:
diff --git a/ppocr/modeling/heads/rec_ctc_head.py b/ppocr/modeling/heads/rec_ctc_head.py
index 481f93e47e58f8267b23e632df1a1e80733d5944..b54322da01cebff1034a2b89e33015ff120fc727 100755
--- a/ppocr/modeling/heads/rec_ctc_head.py
+++ b/ppocr/modeling/heads/rec_ctc_head.py
@@ -33,19 +33,47 @@ def get_para_bias_attr(l2_decay, k):
 
 
 class CTCHead(nn.Layer):
-    def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 fc_decay=0.0004,
+                 mid_channels=None,
+                 **kwargs):
         super(CTCHead, self).__init__()
-        weight_attr, bias_attr = get_para_bias_attr(
-            l2_decay=fc_decay, k=in_channels)
-        self.fc = nn.Linear(
-            in_channels,
-            out_channels,
-            weight_attr=weight_attr,
-            bias_attr=bias_attr)
+        if mid_channels is None:
+            weight_attr, bias_attr = get_para_bias_attr(
+                l2_decay=fc_decay, k=in_channels)
+            self.fc = nn.Linear(
+                in_channels,
+                out_channels,
+                weight_attr=weight_attr,
+                bias_attr=bias_attr)
+        else:
+            weight_attr1, bias_attr1 = get_para_bias_attr(
+                l2_decay=fc_decay, k=in_channels)
+            self.fc1 = nn.Linear(
+                in_channels,
+                mid_channels,
+                weight_attr=weight_attr1,
+                bias_attr=bias_attr1)
+
+            weight_attr2, bias_attr2 = get_para_bias_attr(
+                l2_decay=fc_decay, k=mid_channels)
+            self.fc2 = nn.Linear(
+                mid_channels,
+                out_channels,
+                weight_attr=weight_attr2,
+                bias_attr=bias_attr2)
         self.out_channels = out_channels
+        self.mid_channels = mid_channels
 
     def forward(self, x, labels=None):
-        predicts = self.fc(x)
+        if self.mid_channels is None:
+            predicts = self.fc(x)
+        else:
+            predicts = self.fc1(x)
+            predicts = self.fc2(predicts)
+
         if not self.training:
             predicts = F.softmax(predicts, axis=2)
         return predicts
diff --git a/tools/eval.py b/tools/eval.py
index 66eb315f9b37ed681f6a899613fa43c1313bc654..f5e8cd5593a0eb12247300b9f52f152655a49c59 100755
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -44,8 +44,15 @@ def main():
     # build model
     # for rec algorithm
     if hasattr(post_process_class, 'character'):
-        config['Architecture']["Head"]['out_channels'] = len(
-            getattr(post_process_class, 'character'))
+        char_num = len(getattr(post_process_class, 'character'))
+        if config['Architecture']["algorithm"] in ["Distillation",
+                                                   ]:  # distillation model
+            for key in config['Architecture']["Models"]:
+                config['Architecture']["Models"][key]["Head"][
+                    'out_channels'] = char_num
+        else:  # base rec model
+            config['Architecture']["Head"]['out_channels'] = char_num
+
     model = build_model(config['Architecture'])
 
     use_srn = config['Architecture']['algorithm'] == "SRN"
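
A minimal smoke test of the patched head (not part of the patch itself): it builds CTCHead with and without mid_channels and checks that both paths emit logits of the same shape. The input width of 128 assumes the new config's hidden_size of 64 feeding a bidirectional RNN neck (2 * 64 channels); the batch size, sequence length, and character count (6625) are illustrative values, not taken from this diff.

import paddle
from ppocr.modeling.heads.rec_ctc_head import CTCHead

# Sequence features from the neck: (batch, time_steps, channels).
x = paddle.rand([8, 25, 128])

# mid_channels=None keeps the original single-fc behavior.
head = CTCHead(in_channels=128, out_channels=6625)

# mid_channels=96 (the new config value) routes through fc1 -> fc2.
head_mid = CTCHead(in_channels=128, out_channels=6625, mid_channels=96)

head.eval()      # outside training, forward applies softmax over axis=2
head_mid.eval()
assert head(x).shape == head_mid(x).shape == [8, 25, 6625]

Since both branches produce identically shaped outputs, mid_channels can be toggled per model in the distillation YAML without changing the loss, post-processing, or the per-model out_channels assignment that eval.py now performs for distillation configs.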