From 0938c44a0a4a6a3231d7c030ecf9a442bd817d79 Mon Sep 17 00:00:00 2001
From: littletomatodonkey <2120160898@bit.edu.cn>
Date: Thu, 17 Jun 2021 13:29:49 +0800
Subject: [PATCH] add embedding and fix eval when distillation (#3112)

---
 ...c_chinese_lite_train_distillation_v2.1.yml |  6 ++-
 ppocr/modeling/heads/rec_ctc_head.py          | 46 +++++++++++++++----
 tools/eval.py                                 | 11 ++++-
 3 files changed, 50 insertions(+), 13 deletions(-)

diff --git a/configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml b/configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml
index 6b60ae08..791b34cf 100644
--- a/configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml
+++ b/configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml
@@ -52,9 +52,10 @@ Architecture:
       Neck:
         name: SequenceEncoder
         encoder_type: rnn
-        hidden_size: 48
+        hidden_size: 64
       Head:
         name: CTCHead
+        mid_channels: 96
         fc_decay: 0.00001
   Teacher:
     pretrained:
@@ -71,9 +72,10 @@ Architecture:
       Neck:
         name: SequenceEncoder
         encoder_type: rnn
-        hidden_size: 48
+        hidden_size: 64
       Head:
         name: CTCHead
+        mid_channels: 96
         fc_decay: 0.00001
 
 
diff --git a/ppocr/modeling/heads/rec_ctc_head.py b/ppocr/modeling/heads/rec_ctc_head.py
index 481f93e4..b54322da 100755
--- a/ppocr/modeling/heads/rec_ctc_head.py
+++ b/ppocr/modeling/heads/rec_ctc_head.py
@@ -33,19 +33,47 @@ def get_para_bias_attr(l2_decay, k):
 
 
 class CTCHead(nn.Layer):
-    def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 fc_decay=0.0004,
+                 mid_channels=None,
+                 **kwargs):
         super(CTCHead, self).__init__()
-        weight_attr, bias_attr = get_para_bias_attr(
-            l2_decay=fc_decay, k=in_channels)
-        self.fc = nn.Linear(
-            in_channels,
-            out_channels,
-            weight_attr=weight_attr,
-            bias_attr=bias_attr)
+        if mid_channels is None:
+            weight_attr, bias_attr = get_para_bias_attr(
+                l2_decay=fc_decay, k=in_channels)
+            self.fc = nn.Linear(
+                in_channels,
+                out_channels,
+                weight_attr=weight_attr,
+                bias_attr=bias_attr)
+        else:
+            weight_attr1, bias_attr1 = get_para_bias_attr(
+                l2_decay=fc_decay, k=in_channels)
+            self.fc1 = nn.Linear(
+                in_channels,
+                mid_channels,
+                weight_attr=weight_attr1,
+                bias_attr=bias_attr1)
+
+            weight_attr2, bias_attr2 = get_para_bias_attr(
+                l2_decay=fc_decay, k=mid_channels)
+            self.fc2 = nn.Linear(
+                mid_channels,
+                out_channels,
+                weight_attr=weight_attr2,
+                bias_attr=bias_attr2)
         self.out_channels = out_channels
+        self.mid_channels = mid_channels
 
     def forward(self, x, labels=None):
-        predicts = self.fc(x)
+        if self.mid_channels is None:
+            predicts = self.fc(x)
+        else:
+            predicts = self.fc1(x)
+            predicts = self.fc2(predicts)
+
         if not self.training:
             predicts = F.softmax(predicts, axis=2)
         return predicts
diff --git a/tools/eval.py b/tools/eval.py
index 66eb315f..f5e8cd55 100755
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -44,8 +44,15 @@ def main():
     # build model
     # for rec algorithm
     if hasattr(post_process_class, 'character'):
-        config['Architecture']["Head"]['out_channels'] = len(
-            getattr(post_process_class, 'character'))
+        char_num = len(getattr(post_process_class, 'character'))
+        if config['Architecture']["algorithm"] in ["Distillation",
+                                                   ]:  # distillation model
+            for key in config['Architecture']["Models"]:
+                config['Architecture']["Models"][key]["Head"][
+                    'out_channels'] = char_num
+        else:  # base rec model
+            config['Architecture']["Head"]['out_channels'] = char_num
+
     model = build_model(config['Architecture'])
     use_srn = config['Architecture']['algorithm'] == "SRN"
 
-- 
GitLab