未验证 提交 0938c44a 编写于 作者: L littletomatodonkey 提交者: GitHub

add embedding and fix eval when distillation (#3112)

上级 56ab75b0
@@ -52,9 +52,10 @@ Architecture:
       Neck:
         name: SequenceEncoder
         encoder_type: rnn
-        hidden_size: 48
+        hidden_size: 64
       Head:
         name: CTCHead
+        mid_channels: 96
         fc_decay: 0.00001
   Teacher:
     pretrained:
@@ -71,9 +72,10 @@ Architecture:
       Neck:
         name: SequenceEncoder
         encoder_type: rnn
-        hidden_size: 48
+        hidden_size: 64
       Head:
         name: CTCHead
+        mid_channels: 96
         fc_decay: 0.00001
...
...@@ -33,8 +33,14 @@ def get_para_bias_attr(l2_decay, k): ...@@ -33,8 +33,14 @@ def get_para_bias_attr(l2_decay, k):
class CTCHead(nn.Layer):
    """CTC recognition head: projects sequence features to per-timestep
    character logits.

    When ``mid_channels`` is ``None`` a single linear projection is used
    (the original behavior). When ``mid_channels`` is given, the head uses
    two stacked linear layers (``in_channels -> mid_channels ->
    out_channels``) — the intermediate "embedding" added by this change.

    Args:
        in_channels: Channel dimension of the encoder output.
        out_channels: Number of output classes (character set size incl. blank).
        fc_decay: L2 weight-decay coefficient passed to ``get_para_bias_attr``.
        mid_channels: Optional hidden width of the two-layer variant;
            ``None`` keeps the single-FC head.
        **kwargs: Ignored; kept for config-driven construction compatibility.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 fc_decay=0.0004,
                 mid_channels=None,
                 **kwargs):
        super(CTCHead, self).__init__()
        if mid_channels is None:
            # Single projection straight to the class logits.
            weight_attr, bias_attr = get_para_bias_attr(
                l2_decay=fc_decay, k=in_channels)
            self.fc = nn.Linear(
                in_channels,
                out_channels,
                weight_attr=weight_attr,
                bias_attr=bias_attr)
        else:
            # Two-layer head: each Linear gets its own param attrs, with the
            # init/decay scale keyed to that layer's fan-in (k=...).
            weight_attr1, bias_attr1 = get_para_bias_attr(
                l2_decay=fc_decay, k=in_channels)
            self.fc1 = nn.Linear(
                in_channels,
                mid_channels,
                weight_attr=weight_attr1,
                bias_attr=bias_attr1)
            weight_attr2, bias_attr2 = get_para_bias_attr(
                l2_decay=fc_decay, k=mid_channels)
            self.fc2 = nn.Linear(
                mid_channels,
                out_channels,
                weight_attr=weight_attr2,
                bias_attr=bias_attr2)
        self.out_channels = out_channels
        # Remembered so forward() can pick the matching branch.
        self.mid_channels = mid_channels

    def forward(self, x, labels=None):
        """Compute per-timestep class scores.

        Args:
            x: Encoder output; presumably shaped (batch, seq_len, in_channels)
                given the ``axis=2`` softmax — TODO confirm against the neck.
            labels: Unused here; accepted for a uniform head interface.

        Returns:
            Raw logits during training; softmax probabilities (over axis 2)
            during inference.
        """
        if self.mid_channels is None:
            predicts = self.fc(x)
        else:
            predicts = self.fc1(x)
            predicts = self.fc2(predicts)
        if not self.training:
            # Inference-only softmax: CTC loss expects raw logits in training.
            predicts = F.softmax(predicts, axis=2)
        return predicts
@@ -44,8 +44,15 @@ def main():
     # build model
     # for rec algorithm
     if hasattr(post_process_class, 'character'):
-        config['Architecture']["Head"]['out_channels'] = len(
-            getattr(post_process_class, 'character'))
+        char_num = len(getattr(post_process_class, 'character'))
+        if config['Architecture']["algorithm"] in ["Distillation",
+                                                   ]:  # distillation model
+            for key in config['Architecture']["Models"]:
+                config['Architecture']["Models"][key]["Head"][
+                    'out_channels'] = char_num
+        else:  # base rec model
+            config['Architecture']["Head"]['out_channels'] = char_num
     model = build_model(config['Architecture'])
     use_srn = config['Architecture']['algorithm'] == "SRN"
...
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册