未验证 提交 0938c44a 编写于 作者: L littletomatodonkey 提交者: GitHub

add embedding and fix eval when distillation (#3112)

上级 56ab75b0
......@@ -52,9 +52,10 @@ Architecture:
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 48
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00001
Teacher:
pretrained:
......@@ -71,9 +72,10 @@ Architecture:
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 48
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00001
......
......@@ -33,19 +33,47 @@ def get_para_bias_attr(l2_decay, k):
class CTCHead(nn.Layer):
def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs):
def __init__(self,
in_channels,
out_channels,
fc_decay=0.0004,
mid_channels=None,
**kwargs):
super(CTCHead, self).__init__()
weight_attr, bias_attr = get_para_bias_attr(
l2_decay=fc_decay, k=in_channels)
self.fc = nn.Linear(
in_channels,
out_channels,
weight_attr=weight_attr,
bias_attr=bias_attr)
if mid_channels is None:
weight_attr, bias_attr = get_para_bias_attr(
l2_decay=fc_decay, k=in_channels)
self.fc = nn.Linear(
in_channels,
out_channels,
weight_attr=weight_attr,
bias_attr=bias_attr)
else:
weight_attr1, bias_attr1 = get_para_bias_attr(
l2_decay=fc_decay, k=in_channels)
self.fc1 = nn.Linear(
in_channels,
mid_channels,
weight_attr=weight_attr1,
bias_attr=bias_attr1)
weight_attr2, bias_attr2 = get_para_bias_attr(
l2_decay=fc_decay, k=mid_channels)
self.fc2 = nn.Linear(
mid_channels,
out_channels,
weight_attr=weight_attr2,
bias_attr=bias_attr2)
self.out_channels = out_channels
self.mid_channels = mid_channels
def forward(self, x, labels=None):
predicts = self.fc(x)
if self.mid_channels is None:
predicts = self.fc(x)
else:
predicts = self.fc1(x)
predicts = self.fc2(predicts)
if not self.training:
predicts = F.softmax(predicts, axis=2)
return predicts
......@@ -44,8 +44,15 @@ def main():
# build model
# for rec algorithm
if hasattr(post_process_class, 'character'):
config['Architecture']["Head"]['out_channels'] = len(
getattr(post_process_class, 'character'))
char_num = len(getattr(post_process_class, 'character'))
if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model
for key in config['Architecture']["Models"]:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
use_srn = config['Architecture']['algorithm'] == "SRN"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册