Commit ab4db2ac authored by littletomatodonkey

support dict output for basemodel

Parent e5d3a2d8
@@ -39,6 +39,7 @@ Architecture:
     Student:
       pretrained: null
       freeze_params: false
+      return_all_feats: true
       model_type: rec
       algorithm: CRNN
       Transform:
@@ -57,6 +58,7 @@ Architecture:
     Teacher:
       pretrained: null
       freeze_params: false
+      return_all_feats: true
       model_type: rec
       algorithm: CRNN
       Transform:
@@ -80,18 +82,26 @@ Loss:
   - DistillationCTCLoss:
       weight: 1.0
       model_name_list: ["Student", "Teacher"]
-      key: null
+      key: head_out
   - DistillationDMLLoss:
       weight: 1.0
       act: "softmax"
       model_name_pairs:
       - ["Student", "Teacher"]
-      key: null
+      key: head_out
+  - DistillationDistanceLoss:
+      weight: 1.0
+      mode: "l2"
+      model_name_pairs:
+      - ["Student", "Teacher"]
+      key: backbone_out
 
 PostProcess:
   name: DistillationCTCLabelDecode
   model_name: "Student"
-  key_out: null
+  key: head_out
 
 Metric:
   name: RecMetric
...
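The key fields above refer to entries in the dictionary that each sub-model now returns when return_all_feats is true: the CTC loss, the DML loss and the label decoder read the head output, while the new distance loss compares backbone features. A minimal sketch of that selection, with purely illustrative tensor shapes (not values from the repository):

import paddle

# toy stand-in for one sub-model's dict output; shapes are assumptions
student_out = {
    "backbone_out": paddle.rand([1, 512, 1, 80]),
    "neck_out": paddle.rand([80, 1, 64]),
    "head_out": paddle.rand([1, 80, 6625]),
}

ctc_and_dml_input = student_out["head_out"]        # key: head_out
distance_loss_input = student_out["backbone_out"]  # key: backbone_out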
@@ -97,6 +97,7 @@ class DistanceLoss(nn.Layer):
     """
     def __init__(self, mode="l2", name="loss_dist", **kargs):
+        super().__init__()
         assert mode in ["l1", "l2", "smooth_l1"]
         if mode == "l1":
             self.loss_func = nn.L1Loss(**kargs)
...
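The one-line fix above matters because DistanceLoss subclasses nn.Layer: the base constructor must run before a sub-layer such as self.loss_func is assigned, otherwise Paddle cannot register it and raises an error. A minimal sketch of the idiom, assuming Paddle 2.x; the l2/smooth_l1 branches mirror the asserted modes but are not copied from the repository file:

import paddle
import paddle.nn as nn

class DistanceLossSketch(nn.Layer):
    def __init__(self, mode="l2", **kargs):
        super().__init__()  # must run before sub-layers are assigned
        assert mode in ["l1", "l2", "smooth_l1"]
        if mode == "l1":
            self.loss_func = nn.L1Loss(**kargs)
        elif mode == "l2":
            self.loss_func = nn.MSELoss(**kargs)
        else:
            self.loss_func = nn.SmoothL1Loss(**kargs)

    def forward(self, x, y):
        return self.loss_func(x, y)

print(DistanceLossSketch(mode="l2")(paddle.rand([2, 8]), paddle.rand([2, 8])))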
@@ -17,6 +17,7 @@ import paddle.nn as nn
 from .distillation_loss import DistillationCTCLoss
 from .distillation_loss import DistillationDMLLoss
+from .distillation_loss import DistillationDistanceLoss
 
 
 class CombinedLoss(nn.Layer):
...
@@ -17,6 +17,7 @@ import paddle.nn as nn
 from .rec_ctc_loss import CTCLoss
 from .basic_loss import DMLLoss
+from .basic_loss import DistanceLoss
 
 
 class DistillationDMLLoss(DMLLoss):
@@ -69,3 +70,36 @@ class DistillationCTCLoss(CTCLoss):
             else:
                 loss_dict["{}_{}".format(self.name, model_name)] = loss
         return loss_dict
+
+
+class DistillationDistanceLoss(DistanceLoss):
+    """
+    """
+
+    def __init__(self,
+                 mode="l2",
+                 model_name_pairs=[],
+                 key=None,
+                 name="loss_distance",
+                 **kargs):
+        super().__init__(mode=mode, name=name)
+        assert isinstance(model_name_pairs, list)
+        self.key = key
+        self.model_name_pairs = model_name_pairs
+
+    def forward(self, predicts, batch):
+        loss_dict = dict()
+        for idx, pair in enumerate(self.model_name_pairs):
+            out1 = predicts[pair[0]]
+            out2 = predicts[pair[1]]
+            if self.key is not None:
+                out1 = out1[self.key]
+                out2 = out2[self.key]
+            loss = super().forward(out1, out2)
+            if isinstance(loss, dict):
+                for key in loss:
+                    loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[
+                        key]
+            else:
+                loss_dict["{}_{}".format(self.name, idx)] = loss
+        return loss_dict
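DistillationDistanceLoss walks over model_name_pairs, pulls each model's output from predicts, narrows it with key when one is configured, and stores one distance per pair under a name suffixed with the pair index. A hedged standalone sketch of that bookkeeping (toy shapes; F.mse_loss stands in for DistanceLoss with mode="l2"):

import paddle
import paddle.nn.functional as F

predicts = {
    "Student": {"backbone_out": paddle.rand([1, 512, 1, 80])},
    "Teacher": {"backbone_out": paddle.rand([1, 512, 1, 80])},
}
model_name_pairs = [["Student", "Teacher"]]
name, key = "loss_distance", "backbone_out"

loss_dict = {}
for idx, pair in enumerate(model_name_pairs):
    out1, out2 = predicts[pair[0]][key], predicts[pair[1]][key]
    loss_dict["{}_{}".format(name, idx)] = F.mse_loss(out1, out2)

print(list(loss_dict.keys()))  # ['loss_distance_0']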
@@ -67,14 +67,23 @@ class BaseModel(nn.Layer):
             config["Head"]['in_channels'] = in_channels
             self.head = build_head(config["Head"])
 
+        self.return_all_feats = config.get("return_all_feats", False)
+
     def forward(self, x, data=None):
+        y = dict()
         if self.use_transform:
             x = self.transform(x)
         x = self.backbone(x)
+        y["backbone_out"] = x
         if self.use_neck:
             x = self.neck(x)
+        y["neck_out"] = x
         if data is None:
             x = self.head(x)
         else:
             x = self.head(x, data)
-        return x
+        y["head_out"] = x
+        if self.return_all_feats:
+            return y
+        else:
+            return x
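With this change BaseModel keeps its old behaviour by default (return_all_feats defaults to False and only the head output comes back), but can expose backbone_out, neck_out and head_out for the distillation losses above. A minimal mock of the new contract; the Linear layers stand in for real transform/backbone/neck/head modules and only the control flow mirrors BaseModel:

import paddle
import paddle.nn as nn

class TinyModel(nn.Layer):
    def __init__(self, return_all_feats=False):
        super().__init__()
        self.backbone = nn.Linear(32, 64)
        self.neck = nn.Linear(64, 48)
        self.head = nn.Linear(48, 10)
        self.return_all_feats = return_all_feats

    def forward(self, x):
        y = dict()
        x = self.backbone(x)
        y["backbone_out"] = x
        x = self.neck(x)
        y["neck_out"] = x
        x = self.head(x)
        y["head_out"] = x
        return y if self.return_all_feats else x

x = paddle.rand([4, 32])
print(type(TinyModel()(x)))                       # plain tensor, old behaviour
print(list(TinyModel(return_all_feats=True)(x)))  # ['backbone_out', 'neck_out', 'head_out']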
@@ -136,17 +136,17 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
                  character_type='ch',
                  use_space_char=False,
                  model_name="student",
-                 key_out=None,
+                 key=None,
                  **kwargs):
         super(DistillationCTCLabelDecode, self).__init__(
             character_dict_path, character_type, use_space_char)
         self.model_name = model_name
-        self.key_out = key_out
+        self.key = key
 
     def __call__(self, preds, label=None, *args, **kwargs):
         pred = preds[self.model_name]
-        if self.key_out is not None:
-            pred = pred[self.key_out]
+        if self.key is not None:
+            pred = pred[self.key]
         return super().__call__(pred, label=label, *args, **kwargs)
...
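The rename from key_out to key keeps the post-processing config consistent with the loss config: the decoder first picks the configured sub-model from preds, then the configured dict entry, and hands a plain logits tensor to the usual CTC decoding. A short sketch of that two-level selection with toy inputs (the shape is an assumption):

import paddle

preds = {
    "Student": {"head_out": paddle.rand([1, 80, 6625])},
    "Teacher": {"head_out": paddle.rand([1, 80, 6625])},
}

model_name, key = "Student", "head_out"  # values from the config hunk above
pred = preds[model_name]
if key is not None:
    pred = pred[key]
# pred is now an ordinary prediction tensor; CTCLabelDecode proceeds unchanged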