提交 0343756e 作者: littletomatodonkey

fix metric

上级 b48f7609
...@@ -95,17 +95,17 @@ Loss: ...@@ -95,17 +95,17 @@ Loss:
model_name_pairs: model_name_pairs:
- ["Student", "Teacher"] - ["Student", "Teacher"]
key: backbone_out key: backbone_out
PostProcess: PostProcess:
name: DistillationCTCLabelDecode name: DistillationCTCLabelDecode
model_name: "Student" model_name: ["Student", "Teacher"]
key: head_out key: head_out
Metric: Metric:
name: RecMetric name: DistillationMetric
base_metric_name: RecMetric
main_indicator: acc main_indicator: acc
key: "Student"
Train: Train:
dataset: dataset:
......
...@@ -22,9 +22,8 @@ from paddle.nn import SmoothL1Loss ...@@ -22,9 +22,8 @@ from paddle.nn import SmoothL1Loss
class CELoss(nn.Layer): class CELoss(nn.Layer):
def __init__(self, name="loss_ce", epsilon=None): def __init__(self, epsilon=None):
super().__init__() super().__init__()
self.name = name
if epsilon is not None and (epsilon <= 0 or epsilon >= 1): if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
epsilon = None epsilon = None
self.epsilon = epsilon self.epsilon = epsilon
...@@ -52,9 +51,7 @@ class CELoss(nn.Layer): ...@@ -52,9 +51,7 @@ class CELoss(nn.Layer):
else: else:
soft_label = False soft_label = False
loss = F.cross_entropy(x, label=label, soft_label=soft_label) loss = F.cross_entropy(x, label=label, soft_label=soft_label)
return loss
loss_dict[self.name] = paddle.mean(loss)
return loss_dict
class DMLLoss(nn.Layer): class DMLLoss(nn.Layer):
...@@ -62,11 +59,10 @@ class DMLLoss(nn.Layer): ...@@ -62,11 +59,10 @@ class DMLLoss(nn.Layer):
DMLLoss DMLLoss
""" """
def __init__(self, act=None, name="loss_dml"): def __init__(self, act=None):
super().__init__() super().__init__()
if act is not None: if act is not None:
assert act in ["softmax", "sigmoid"] assert act in ["softmax", "sigmoid"]
self.name = name
if act == "softmax": if act == "softmax":
self.act = nn.Softmax(axis=-1) self.act = nn.Softmax(axis=-1)
elif act == "sigmoid": elif act == "sigmoid":
...@@ -75,7 +71,6 @@ class DMLLoss(nn.Layer): ...@@ -75,7 +71,6 @@ class DMLLoss(nn.Layer):
self.act = None self.act = None
def forward(self, out1, out2): def forward(self, out1, out2):
loss_dict = {}
if self.act is not None: if self.act is not None:
out1 = self.act(out1) out1 = self.act(out1)
out2 = self.act(out2) out2 = self.act(out2)
...@@ -85,18 +80,16 @@ class DMLLoss(nn.Layer): ...@@ -85,18 +80,16 @@ class DMLLoss(nn.Layer):
loss = (F.kl_div( loss = (F.kl_div(
log_out1, out2, reduction='batchmean') + F.kl_div( log_out1, out2, reduction='batchmean') + F.kl_div(
log_out2, log_out1, reduction='batchmean')) / 2.0 log_out2, log_out1, reduction='batchmean')) / 2.0
loss_dict[self.name] = loss return loss
return loss_dict
class DistanceLoss(nn.Layer): class DistanceLoss(nn.Layer):
""" """
DistanceLoss: DistanceLoss:
mode: loss mode mode: loss mode
name: loss key in the output dict
""" """
def __init__(self, mode="l2", name="loss_dist", **kargs): def __init__(self, mode="l2", **kargs):
super().__init__() super().__init__()
assert mode in ["l1", "l2", "smooth_l1"] assert mode in ["l1", "l2", "smooth_l1"]
if mode == "l1": if mode == "l1":
...@@ -106,7 +99,5 @@ class DistanceLoss(nn.Layer): ...@@ -106,7 +99,5 @@ class DistanceLoss(nn.Layer):
elif mode == "smooth_l1": elif mode == "smooth_l1":
self.loss_func = nn.SmoothL1Loss(**kargs) self.loss_func = nn.SmoothL1Loss(**kargs)
self.name = "{}_{}".format(name, mode)
def forward(self, x, y): def forward(self, x, y):
return {self.name: self.loss_func(x, y)} return self.loss_func(x, y)
...@@ -26,10 +26,11 @@ class DistillationDMLLoss(DMLLoss): ...@@ -26,10 +26,11 @@ class DistillationDMLLoss(DMLLoss):
def __init__(self, model_name_pairs=[], act=None, key=None, def __init__(self, model_name_pairs=[], act=None, key=None,
name="loss_dml"): name="loss_dml"):
super().__init__(act=act, name=name) super().__init__(act=act)
assert isinstance(model_name_pairs, list) assert isinstance(model_name_pairs, list)
self.key = key self.key = key
self.model_name_pairs = model_name_pairs self.model_name_pairs = model_name_pairs
self.name = name
def forward(self, predicts, batch): def forward(self, predicts, batch):
loss_dict = dict() loss_dict = dict()
...@@ -42,8 +43,8 @@ class DistillationDMLLoss(DMLLoss): ...@@ -42,8 +43,8 @@ class DistillationDMLLoss(DMLLoss):
loss = super().forward(out1, out2) loss = super().forward(out1, out2)
if isinstance(loss, dict): if isinstance(loss, dict):
for key in loss: for key in loss:
loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[ loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
key] idx)] = loss[key]
else: else:
loss_dict["{}_{}".format(self.name, idx)] = loss loss_dict["{}_{}".format(self.name, idx)] = loss
return loss_dict return loss_dict
...@@ -82,10 +83,11 @@ class DistillationDistanceLoss(DistanceLoss): ...@@ -82,10 +83,11 @@ class DistillationDistanceLoss(DistanceLoss):
key=None, key=None,
name="loss_distance", name="loss_distance",
**kargs): **kargs):
super().__init__(mode=mode, name=name, **kargs) super().__init__(mode=mode, **kargs)
assert isinstance(model_name_pairs, list) assert isinstance(model_name_pairs, list)
self.key = key self.key = key
self.model_name_pairs = model_name_pairs self.model_name_pairs = model_name_pairs
self.name = name + "_l2"
def forward(self, predicts, batch): def forward(self, predicts, batch):
loss_dict = dict() loss_dict = dict()
...@@ -101,5 +103,6 @@ class DistillationDistanceLoss(DistanceLoss): ...@@ -101,5 +103,6 @@ class DistillationDistanceLoss(DistanceLoss):
loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[ loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[
key] key]
else: else:
loss_dict["{}_{}".format(self.name, idx)] = loss loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
idx)] = loss
return loss_dict return loss_dict
...@@ -19,20 +19,23 @@ from __future__ import unicode_literals ...@@ -19,20 +19,23 @@ from __future__ import unicode_literals
import copy import copy
__all__ = ['build_metric'] __all__ = ["build_metric"]
from .det_metric import DetMetric
from .rec_metric import RecMetric
from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric
def build_metric(config):
from .det_metric import DetMetric
from .rec_metric import RecMetric
from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric
support_dict = ['DetMetric', 'RecMetric', 'ClsMetric', 'E2EMetric'] def build_metric(config):
support_dict = [
"DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric"
]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop("name")
assert module_name in support_dict, Exception( assert module_name in support_dict, Exception(
'metric only support {}'.format(support_dict)) "metric only support {}".format(support_dict))
module_class = eval(module_name)(**config) module_class = eval(module_name)(**config)
return module_class return module_class
...@@ -135,19 +135,25 @@ class DistillationCTCLabelDecode(CTCLabelDecode): ...@@ -135,19 +135,25 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
character_dict_path=None, character_dict_path=None,
character_type='ch', character_type='ch',
use_space_char=False, use_space_char=False,
model_name="student", model_name=["student"],
key=None, key=None,
**kwargs): **kwargs):
super(DistillationCTCLabelDecode, self).__init__( super(DistillationCTCLabelDecode, self).__init__(
character_dict_path, character_type, use_space_char) character_dict_path, character_type, use_space_char)
if not isinstance(model_name, list):
model_name = [model_name]
self.model_name = model_name self.model_name = model_name
self.key = key self.key = key
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
pred = preds[self.model_name] output = dict()
if self.key is not None: for name in self.model_name:
pred = pred[self.key] pred = preds[name]
return super().__call__(pred, label=label, *args, **kwargs) if self.key is not None:
pred = pred[self.key]
output[name] = super().__call__(pred, label=label, *args, **kwargs)
return output
class AttnLabelDecode(BaseRecLabelDecode): class AttnLabelDecode(BaseRecLabelDecode):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册