Commit 0343756e authored by littletomatodonkey

fix metric

上级 b48f7609
......@@ -95,17 +95,17 @@ Loss:
model_name_pairs:
- ["Student", "Teacher"]
key: backbone_out
PostProcess:
name: DistillationCTCLabelDecode
model_name: "Student"
model_name: ["Student", "Teacher"]
key: head_out
Metric:
name: RecMetric
name: DistillationMetric
base_metric_name: RecMetric
main_indicator: acc
key: "Student"
Train:
dataset:
......
......@@ -22,9 +22,8 @@ from paddle.nn import SmoothL1Loss
class CELoss(nn.Layer):
def __init__(self, name="loss_ce", epsilon=None):
def __init__(self, epsilon=None):
super().__init__()
self.name = name
if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
epsilon = None
self.epsilon = epsilon
......@@ -52,9 +51,7 @@ class CELoss(nn.Layer):
else:
soft_label = False
loss = F.cross_entropy(x, label=label, soft_label=soft_label)
loss_dict[self.name] = paddle.mean(loss)
return loss_dict
return loss
class DMLLoss(nn.Layer):
......@@ -62,11 +59,10 @@ class DMLLoss(nn.Layer):
DMLLoss
"""
def __init__(self, act=None, name="loss_dml"):
def __init__(self, act=None):
super().__init__()
if act is not None:
assert act in ["softmax", "sigmoid"]
self.name = name
if act == "softmax":
self.act = nn.Softmax(axis=-1)
elif act == "sigmoid":
......@@ -75,7 +71,6 @@ class DMLLoss(nn.Layer):
self.act = None
def forward(self, out1, out2):
loss_dict = {}
if self.act is not None:
out1 = self.act(out1)
out2 = self.act(out2)
......@@ -85,18 +80,16 @@ class DMLLoss(nn.Layer):
loss = (F.kl_div(
log_out1, out2, reduction='batchmean') + F.kl_div(
log_out2, log_out1, reduction='batchmean')) / 2.0
loss_dict[self.name] = loss
return loss_dict
return loss
class DistanceLoss(nn.Layer):
"""
DistanceLoss:
mode: loss mode
name: loss key in the output dict
"""
def __init__(self, mode="l2", name="loss_dist", **kargs):
def __init__(self, mode="l2", **kargs):
super().__init__()
assert mode in ["l1", "l2", "smooth_l1"]
if mode == "l1":
......@@ -106,7 +99,5 @@ class DistanceLoss(nn.Layer):
elif mode == "smooth_l1":
self.loss_func = nn.SmoothL1Loss(**kargs)
self.name = "{}_{}".format(name, mode)
def forward(self, x, y):
return {self.name: self.loss_func(x, y)}
return self.loss_func(x, y)
......@@ -26,10 +26,11 @@ class DistillationDMLLoss(DMLLoss):
def __init__(self, model_name_pairs=[], act=None, key=None,
             name="loss_dml"):
    """DML loss applied between pairs of sub-model outputs in distillation.

    Args:
        model_name_pairs (list): pairs of sub-model names whose outputs
            are compared, e.g. [["Student", "Teacher"]].
        act (str|None): optional activation name forwarded to DMLLoss
            ("softmax" or "sigmoid").
        key (str|None): if set, index used to pick one entry out of each
            sub-model's output dict before computing the loss.
        name (str): prefix for the keys of the returned loss dict.
    """
    # NOTE(review): the mutable default for model_name_pairs is benign
    # here (the list is never mutated), but a None sentinel would be safer.
    super().__init__(act=act)
    assert isinstance(model_name_pairs, list)
    self.key = key
    self.model_name_pairs = model_name_pairs
    self.name = name
def forward(self, predicts, batch):
loss_dict = dict()
......@@ -42,8 +43,8 @@ class DistillationDMLLoss(DMLLoss):
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[
key]
loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
idx)] = loss[key]
else:
loss_dict["{}_{}".format(self.name, idx)] = loss
return loss_dict
......@@ -82,10 +83,11 @@ class DistillationDistanceLoss(DistanceLoss):
key=None,
name="loss_distance",
**kargs):
super().__init__(mode=mode, name=name, **kargs)
super().__init__(mode=mode, **kargs)
assert isinstance(model_name_pairs, list)
self.key = key
self.model_name_pairs = model_name_pairs
self.name = name + "_l2"
def forward(self, predicts, batch):
loss_dict = dict()
......@@ -101,5 +103,6 @@ class DistillationDistanceLoss(DistanceLoss):
loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[
key]
else:
loss_dict["{}_{}".format(self.name, idx)] = loss
loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
idx)] = loss
return loss_dict
......@@ -19,20 +19,23 @@ from __future__ import unicode_literals
import copy
__all__ = ['build_metric']
__all__ = ["build_metric"]
from .det_metric import DetMetric
from .rec_metric import RecMetric
from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric
def build_metric(config):
    """Build a metric object from a configuration dict.

    Args:
        config (dict): metric configuration. Must contain a "name" key
            naming one of the supported metric classes; every remaining
            key is passed through to that class's constructor.

    Returns:
        The instantiated metric object.

    Raises:
        AssertionError: if config["name"] is not a supported metric.
    """
    support_dict = [
        "DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric"
    ]
    # Deep-copy so pop() below does not mutate the caller's config.
    config = copy.deepcopy(config)
    module_name = config.pop("name")
    assert module_name in support_dict, Exception(
        "metric only support {}".format(support_dict))
    # NOTE: eval() resolves the class from the names imported at module
    # top; config comes from a local trusted file, not untrusted input.
    module_class = eval(module_name)(**config)
    return module_class
......@@ -135,19 +135,25 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
character_dict_path=None,
character_type='ch',
use_space_char=False,
model_name="student",
model_name=["student"],
key=None,
**kwargs):
super(DistillationCTCLabelDecode, self).__init__(
character_dict_path, character_type, use_space_char)
if not isinstance(model_name, list):
model_name = [model_name]
self.model_name = model_name
self.key = key
def __call__(self, preds, label=None, *args, **kwargs):
    """Decode the predictions of every sub-model listed in self.model_name.

    Args:
        preds (dict): maps sub-model name -> prediction; when self.key is
            set, each prediction is itself a dict indexed by that key.
        label: optional ground-truth labels forwarded to the base decoder.

    Returns:
        dict: sub-model name -> decoded result from CTCLabelDecode.
    """
    output = dict()
    for name in self.model_name:
        pred = preds[name]
        if self.key is not None:
            # Select the relevant entry, e.g. the "head_out" tensor.
            pred = pred[self.key]
        output[name] = super().__call__(pred, label=label, *args, **kwargs)
    return output
class AttnLabelDecode(BaseRecLabelDecode):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册