diff --git a/configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_distill.yml b/configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_distill.yml new file mode 100644 index 0000000000000000000000000000000000000000..f613ee52b467f279e4bbbd33dca3c58862f4715a --- /dev/null +++ b/configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_distill.yml @@ -0,0 +1,231 @@ +Global: + debug: false + use_gpu: true + epoch_num: 200 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec_dkd_400w_svtr_ctc_lcnet_blank_dkd0.1/ + save_epoch_step: 40 + eval_batch_step: + - 0 + - 2000 + cal_metric_during_train: true + pretrained_model: null + checkpoints: ./output/rec_dkd_400w_svtr_ctc_lcnet_blank_dkd0.1/latest + save_inference_dir: null + use_visualdl: false + infer_img: doc/imgs_words/ch/word_1.jpg + character_dict_path: ppocr/utils/ppocr_keys_v1.txt + max_text_length: 25 + infer_mode: false + use_space_char: true + distributed: true + save_res_path: ./output/rec/predicts_ppocrv3.txt +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Cosine + learning_rate: 0.001 + warmup_epoch: 2 + regularizer: + name: L2 + factor: 3.0e-05 +Architecture: + model_type: rec + name: DistillationModel + algorithm: Distillation + Models: + Teacher: + pretrained: + freeze_params: true + return_all_feats: true + model_type: rec + algorithm: SVTR + Transform: null + Backbone: + name: SVTRNet + img_size: + - 48 + - 320 + out_char_num: 40 + out_channels: 192 + patch_merging: Conv + embed_dim: + - 64 + - 128 + - 256 + depth: + - 3 + - 6 + - 3 + num_heads: + - 2 + - 4 + - 8 + mixer: + - Conv + - Conv + - Conv + - Conv + - Conv + - Conv + - Global + - Global + - Global + - Global + - Global + - Global + local_mixer: + - - 5 + - 5 + - - 5 + - 5 + - - 5 + - 5 + last_stage: false + prenorm: true + Head: + name: MultiHead + head_list: + - CTCHead: + Neck: + name: svtr + dims: 120 + depth: 2 + hidden_dims: 120 + kernel_size: [1, 3] + use_guide: True + Head: + fc_decay: 0.00001 + - NRTRHead: + nrtr_dim: 384 + max_text_length: *max_text_length + Student: + pretrained: + freeze_params: false + return_all_feats: true + model_type: rec + algorithm: SVTR + Transform: null + Backbone: + name: PPLCNetV3 + scale: 0.95 + Head: + name: MultiHead + head_list: + - CTCHead: + Neck: + name: svtr + dims: 120 + depth: 2 + hidden_dims: 120 + kernel_size: [1, 3] + use_guide: True + Head: + fc_decay: 0.00001 + - NRTRHead: + nrtr_dim: 384 + max_text_length: *max_text_length +Loss: + name: CombinedLoss + loss_config_list: + - DistillationDKDLoss: + weight: 0.1 + model_name_pairs: + - - Student + - Teacher + key: head_out + multi_head: true + alpha: 1.0 + beta: 2.0 + dis_head: gtc + name: dkd + - DistillationCTCLoss: + weight: 1.0 + model_name_list: + - Student + key: head_out + multi_head: true + - DistillationNRTRLoss: + weight: 1.0 + smoothing: false + model_name_list: + - Student + key: head_out + multi_head: true + - DistillCTCLogits: + weight: 1.0 + reduction: mean + model_name_pairs: + - - Student + - Teacher + key: head_out +PostProcess: + name: DistillationCTCLabelDecode + model_name: + - Student + key: head_out + multi_head: true +Metric: + name: DistillationMetric + base_metric_name: RecMetric + main_indicator: acc + key: Student + ignore_space: false +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ + label_file_list: + - ./train_data/train_list.txt + ratio_list: + - 1.0 + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - RecAug: + - MultiLabelEncode: + gtc_encode: NRTRLabelEncode + - KeepKeys: + keep_keys: + - image + - label_ctc + - label_gtc + - length + - valid_ratio + loader: + shuffle: true + batch_size_per_card: 128 + drop_last: true + num_workers: 8 + use_shared_memory: true +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data + label_file_list: + - ./train_data/val_list.txt + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - MultiLabelEncode: + gtc_encode: NRTRLabelEncode + - RecResizeImg: + image_shape: [3, 48, 320] + - KeepKeys: + keep_keys: + - image + - label_ctc + - label_gtc + - length + - valid_ratio + loader: + shuffle: false + drop_last: false + batch_size_per_card: 128 + num_workers: 4 +profiler_options: null diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py index 58410b4db2157074c2cb0f7db590c84021e10ace..9ad854cd120c996e2c18c61f00718e5826b25372 100644 --- a/ppocr/losses/basic_loss.py +++ b/ppocr/losses/basic_loss.py @@ -165,3 +165,79 @@ class LossFromOutput(nn.Layer): elif self.reduction == 'sum': loss = paddle.sum(loss) return {'loss': loss} + + +class KLDivLoss(nn.Layer): + """ + KLDivLoss + """ + + def __init__(self): + super().__init__() + + def _kldiv(self, x, target, mask=None): + eps = 1.0e-10 + loss = target * (paddle.log(target + eps) - x) + if mask is not None: + loss = loss.flatten(0, 1).sum(axis=1) + loss = loss.masked_select(mask).mean() + else: + # batch mean loss + loss = paddle.sum(loss) / loss.shape[0] + return loss + + def forward(self, logits_s, logits_t, mask=None): + log_out_s = F.log_softmax(logits_s, axis=-1) + out_t = F.softmax(logits_t, axis=-1) + loss = self._kldiv(log_out_s, out_t, mask) + return loss + + +class DKDLoss(nn.Layer): + """ + KLDivLoss + """ + + def __init__(self, temperature=1.0, alpha=1.0, beta=1.0): + super().__init__() + self.temperature = temperature + self.alpha = alpha + self.beta = beta + + def _cat_mask(self, t, mask1, mask2): + t1 = (t * mask1).sum(axis=1, keepdim=True) + t2 = (t * mask2).sum(axis=1, keepdim=True) + rt = paddle.concat([t1, t2], axis=1) + return rt + + def _kl_div(self, x, label, mask=None): + y = (label * (paddle.log(label + 1e-10) - x)).sum(axis=1) + if mask is not None: + y = y.masked_select(mask).mean() + else: + y = y.mean() + return y + + def forward(self, logits_student, logits_teacher, target, mask=None): + gt_mask = F.one_hot( + target.reshape([-1]), num_classes=logits_student.shape[-1]) + other_mask = 1 - gt_mask + logits_student = logits_student.flatten(0, 1) + logits_teacher = logits_teacher.flatten(0, 1) + pred_student = F.softmax(logits_student / self.temperature, axis=1) + pred_teacher = F.softmax(logits_teacher / self.temperature, axis=1) + pred_student = self._cat_mask(pred_student, gt_mask, other_mask) + pred_teacher = self._cat_mask(pred_teacher, gt_mask, other_mask) + log_pred_student = paddle.log(pred_student) + tckd_loss = self._kl_div(log_pred_student, + pred_teacher) * (self.temperature**2) + pred_teacher_part2 = F.softmax( + logits_teacher / self.temperature - 1000.0 * gt_mask, axis=1) + log_pred_student_part2 = F.log_softmax( + logits_student / self.temperature - 1000.0 * gt_mask, axis=1) + nckd_loss = self._kl_div(log_pred_student_part2, + pred_teacher_part2) * (self.temperature**2) + + loss = self.alpha * tckd_loss + self.beta * nckd_loss + + return loss diff --git a/ppocr/losses/combined_loss.py b/ppocr/losses/combined_loss.py index 8d697d544b51899cdafeff94be2ecce067b907a2..a520f10ffb6a83b444fb98c7d461bbcfaf4ce14d 100644 --- a/ppocr/losses/combined_loss.py +++ b/ppocr/losses/combined_loss.py @@ -20,9 +20,9 @@ from .center_loss import CenterLoss from .ace_loss import ACELoss from .rec_sar_loss import SARLoss -from .distillation_loss import DistillationCTCLoss -from .distillation_loss import DistillationSARLoss -from .distillation_loss import DistillationDMLLoss +from .distillation_loss import DistillationCTCLoss, DistillCTCLogits +from .distillation_loss import DistillationSARLoss, DistillationNRTRLoss +from .distillation_loss import DistillationDMLLoss, DistillationKLDivLoss, DistillationDKDLoss from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss from .distillation_loss import DistillationVQASerTokenLayoutLMLoss, DistillationSERDMLLoss from .distillation_loss import DistillationLossFromOutput diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py index 4bfbed75a338e2bd3bca0b80d16028030bf2f0b5..dac662e5d7b107703c6c9e9b95624b18f0ebbb1b 100644 --- a/ppocr/losses/distillation_loss.py +++ b/ppocr/losses/distillation_loss.py @@ -14,12 +14,14 @@ import paddle import paddle.nn as nn +import paddle.nn.functional as F import numpy as np import cv2 from .rec_ctc_loss import CTCLoss from .rec_sar_loss import SARLoss -from .basic_loss import DMLLoss +from .rec_ce_loss import CELoss +from .basic_loss import DMLLoss, KLDivLoss, DKDLoss from .basic_loss import DistanceLoss from .basic_loss import LossFromOutput from .det_db_loss import DBLoss @@ -102,11 +104,220 @@ class DistillationDMLLoss(DMLLoss): if self.key is not None: out1 = out1[self.key] out2 = out2[self.key] + if self.maps_name is None: + if self.multi_head: + # for nrtr dml loss + max_len = batch[3].max() + tgt = batch[2][:, 1:2 + max_len] + tgt = tgt.reshape([-1]) + non_pad_mask = paddle.not_equal( + tgt, paddle.zeros( + tgt.shape, dtype=tgt.dtype)) + loss = super().forward(out1[self.dis_head], + out2[self.dis_head], non_pad_mask) + else: + loss = super().forward(out1, out2) + if isinstance(loss, dict): + for key in loss: + loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1], + idx)] = loss[key] + else: + loss_dict["{}_{}".format(self.name, idx)] = loss + else: + outs1 = self._slice_out(out1) + outs2 = self._slice_out(out2) + for _c, k in enumerate(outs1.keys()): + loss = super().forward(outs1[k], outs2[k]) + if isinstance(loss, dict): + for key in loss: + loss_dict["{}_{}_{}_{}_{}".format(key, pair[ + 0], pair[1], self.maps_name, idx)] = loss[key] + else: + loss_dict["{}_{}_{}".format(self.name, self.maps_name[ + _c], idx)] = loss + + loss_dict = _sum_loss(loss_dict) + + return loss_dict + + +class DistillationKLDivLoss(KLDivLoss): + """ + """ + def __init__(self, + model_name_pairs=[], + key=None, + multi_head=False, + dis_head='ctc', + maps_name=None, + name="kl_div"): + super().__init__() + assert isinstance(model_name_pairs, list) + self.key = key + self.multi_head = multi_head + self.dis_head = dis_head + self.model_name_pairs = self._check_model_name_pairs(model_name_pairs) + self.name = name + self.maps_name = self._check_maps_name(maps_name) + + def _check_model_name_pairs(self, model_name_pairs): + if not isinstance(model_name_pairs, list): + return [] + elif isinstance(model_name_pairs[0], list) and isinstance( + model_name_pairs[0][0], str): + return model_name_pairs + else: + return [model_name_pairs] + + def _check_maps_name(self, maps_name): + if maps_name is None: + return None + elif type(maps_name) == str: + return [maps_name] + elif type(maps_name) == list: + return [maps_name] + else: + return None + + def _slice_out(self, outs): + new_outs = {} + for k in self.maps_name: + if k == "thrink_maps": + new_outs[k] = outs[:, 0, :, :] + elif k == "threshold_maps": + new_outs[k] = outs[:, 1, :, :] + elif k == "binary_maps": + new_outs[k] = outs[:, 2, :, :] + else: + continue + return new_outs + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, pair in enumerate(self.model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + if self.key is not None: + out1 = out1[self.key] + out2 = out2[self.key] if self.maps_name is None: if self.multi_head: + # for nrtr dml loss + max_len = batch[3].max() + tgt = batch[2][:, 1:2 + max_len] + tgt = tgt.reshape([-1]) + non_pad_mask = paddle.not_equal( + tgt, paddle.zeros( + tgt.shape, dtype=tgt.dtype)) loss = super().forward(out1[self.dis_head], - out2[self.dis_head]) + out2[self.dis_head], non_pad_mask) + else: + loss = super().forward(out1, out2) + if isinstance(loss, dict): + for key in loss: + loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1], + idx)] = loss[key] + else: + loss_dict["{}_{}".format(self.name, idx)] = loss + else: + outs1 = self._slice_out(out1) + outs2 = self._slice_out(out2) + for _c, k in enumerate(outs1.keys()): + loss = super().forward(outs1[k], outs2[k]) + if isinstance(loss, dict): + for key in loss: + loss_dict["{}_{}_{}_{}_{}".format(key, pair[ + 0], pair[1], self.maps_name, idx)] = loss[key] + else: + loss_dict["{}_{}_{}".format(self.name, self.maps_name[ + _c], idx)] = loss + + loss_dict = _sum_loss(loss_dict) + + return loss_dict + + +class DistillationDKDLoss(DKDLoss): + """ + """ + + def __init__(self, + model_name_pairs=[], + key=None, + multi_head=False, + dis_head='ctc', + maps_name=None, + name="dkd", + temperature=1.0, + alpha=1.0, + beta=1.0): + super().__init__(temperature, alpha, beta) + assert isinstance(model_name_pairs, list) + self.key = key + self.multi_head = multi_head + self.dis_head = dis_head + self.model_name_pairs = self._check_model_name_pairs(model_name_pairs) + self.name = name + self.maps_name = self._check_maps_name(maps_name) + + def _check_model_name_pairs(self, model_name_pairs): + if not isinstance(model_name_pairs, list): + return [] + elif isinstance(model_name_pairs[0], list) and isinstance( + model_name_pairs[0][0], str): + return model_name_pairs + else: + return [model_name_pairs] + + def _check_maps_name(self, maps_name): + if maps_name is None: + return None + elif type(maps_name) == str: + return [maps_name] + elif type(maps_name) == list: + return [maps_name] + else: + return None + + def _slice_out(self, outs): + new_outs = {} + for k in self.maps_name: + if k == "thrink_maps": + new_outs[k] = outs[:, 0, :, :] + elif k == "threshold_maps": + new_outs[k] = outs[:, 1, :, :] + elif k == "binary_maps": + new_outs[k] = outs[:, 2, :, :] + else: + continue + return new_outs + + def forward(self, predicts, batch): + loss_dict = dict() + + for idx, pair in enumerate(self.model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + if self.key is not None: + out1 = out1[self.key] + out2 = out2[self.key] + if self.maps_name is None: + if self.multi_head: + # for nrtr dml loss + max_len = batch[3].max() + tgt = batch[2][:, 1:2 + + max_len] # [batch_size, max_len + 1] + + tgt = tgt.reshape([-1]) # batch_size * (max_len + 1) + non_pad_mask = paddle.not_equal( + tgt, paddle.zeros( + tgt.shape, + dtype=tgt.dtype)) # batch_size * (max_len + 1) + + loss = super().forward( + out1[self.dis_head], out2[self.dis_head], tgt, + non_pad_mask) # [batch_size, max_len + 1, num_char] else: loss = super().forward(out1, out2) if isinstance(loss, dict): @@ -199,6 +410,40 @@ class DistillationSARLoss(SARLoss): return loss_dict +class DistillationNRTRLoss(CELoss): + def __init__(self, + model_name_list=[], + key=None, + multi_head=False, + smoothing=True, + name="loss_nrtr", + **kwargs): + super().__init__(smoothing=smoothing) + self.model_name_list = model_name_list + self.key = key + self.name = name + self.multi_head = multi_head + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, model_name in enumerate(self.model_name_list): + out = predicts[model_name] + if self.key is not None: + out = out[self.key] + if self.multi_head: + assert 'gtc' in out, 'multi head has multi out' + loss = super().forward(out['gtc'], batch[:1] + batch[2:]) + else: + loss = super().forward(out, batch) + if isinstance(loss, dict): + for key in loss: + loss_dict["{}_{}_{}".format(self.name, model_name, + idx)] = loss[key] + else: + loss_dict["{}_{}".format(self.name, model_name)] = loss + return loss_dict + + class DistillationDBLoss(DBLoss): def __init__(self, model_name_list=[], @@ -459,3 +704,212 @@ class DistillationVQADistanceLoss(DistanceLoss): loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1], idx)] = loss return loss_dict + + +class CTCDKDLoss(nn.Layer): + """ + KLDivLoss + """ + + def __init__(self, temperature=0.5, alpha=1.0, beta=1.0): + super().__init__() + self.temperature = temperature + self.alpha = alpha + self.beta = beta + self.eps = 1e-6 + self.t = temperature + self.act = nn.Softmax(axis=-1) + self.use_log = True + + def kl_loss(self, p1, p2): # predict, label + loss = paddle.multiply( + p2, paddle.log((p2 + self.eps) / (p1 + self.eps) + self.eps)) + bs = loss.shape[0] + loss = paddle.sum(loss) / bs + return loss + + def _cat_mask(self, t, mask1, mask2): + t1 = (t * mask1).sum(axis=1, keepdim=True) + t2 = (t * mask2).sum(axis=1, keepdim=True) + rt = paddle.concat([t1, t2], axis=1) + return rt + + def multi_label_mask(self, targets): + + targets = targets.astype("int32") + res = F.one_hot(targets, num_classes=11465) + mask = paddle.clip(paddle.sum(res, axis=1), 0, 1) + mask[:, 0] = 0 # ingore ctc blank label + return mask + + def forward(self, logits_student, logits_teacher, targets, mask=None): + + gt_mask = self.multi_label_mask(targets) + other_mask = paddle.ones_like(gt_mask) - gt_mask + + pred_student = F.softmax(logits_student / self.temperature, axis=-1) + pred_teacher = F.softmax(logits_teacher / self.temperature, axis=-1) + + # differents with dkd + pred_student = paddle.mean(pred_student, axis=1) + pred_teacher = paddle.mean(pred_teacher, axis=1) + + pred_student = self._cat_mask(pred_student, gt_mask, other_mask) + pred_teacher = self._cat_mask(pred_teacher, gt_mask, other_mask) + + # differents with dkd + tckd_loss = self.kl_loss(pred_student, pred_teacher) + + gt_mask_ex = paddle.expand_as(gt_mask.unsqueeze(axis=1), logits_teacher) + pred_teacher_part2 = F.softmax( + logits_teacher / self.temperature - 1000.0 * gt_mask_ex, axis=-1) + pred_student_part2 = F.softmax( + logits_student / self.temperature - 1000.0 * gt_mask_ex, axis=-1) + # differents with dkd + pred_teacher_part2 = paddle.mean(pred_teacher_part2, axis=1) + pred_student_part2 = paddle.mean(pred_student_part2, axis=1) + + # differents with dkd + nckd_loss = self.kl_loss(pred_student_part2, pred_teacher_part2) + loss = self.alpha * tckd_loss + self.beta * nckd_loss + return loss + + +class KLCTCLogits(nn.Layer): + def __init__(self, weight=1.0, reduction='mean', mode="mean"): + super().__init__() + self.weight = weight + self.reduction = reduction + self.eps = 1e-6 + self.t = 0.5 + self.act = nn.Softmax(axis=-1) + self.use_log = True + self.mode = mode + self.ctc_dkd_loss = CTCDKDLoss() + + def kl_loss(self, p1, p2): # predict, label + loss = paddle.multiply( + p2, paddle.log((p2 + self.eps) / (p1 + self.eps) + self.eps)) + bs = loss.shape[0] + loss = paddle.sum(loss) / bs + return loss + + def forward_meanmax(self, stu_out, tea_out): + + stu_out = paddle.mean(F.softmax(stu_out / self.t, axis=-1), axis=1) + tea_out = paddle.mean(F.softmax(tea_out / self.t, axis=-1), axis=1) + loss = self.kl_loss(stu_out, tea_out) + + return loss + + def forward_meanlog(self, stu_out, tea_out): + stu_out = paddle.mean(F.softmax(stu_out / self.t, axis=-1), axis=1) + tea_out = paddle.mean(F.softmax(tea_out / self.t, axis=-1), axis=1) + if self.use_log is True: + # for recognition distillation, log is needed for feature map + log_out1 = paddle.log(stu_out) + log_out2 = paddle.log(tea_out) + loss = ( + self._kldiv(log_out1, tea_out) + self._kldiv(log_out2, stu_out) + ) / 2.0 + + return loss + + def forward_sum(self, stu_out, tea_out): + stu_out = paddle.sum(F.softmax(stu_out / self.t, axis=-1), axis=1) + tea_out = paddle.sum(F.softmax(tea_out / self.t, axis=-1), axis=1) + stu_out = paddle.log(stu_out) + bs = stu_out.shape[0] + loss = tea_out * (paddle.log(tea_out + self.eps) - stu_out) + loss = paddle.sum(loss, axis=1) / loss.shape[0] + return loss + + def _kldiv(self, x, target): + eps = 1.0e-10 + loss = target * (paddle.log(target + eps) - x) + loss = paddle.sum(paddle.mean(loss, axis=1)) / loss.shape[0] + return loss + + def forward(self, stu_out, tea_out, targets=None): + if self.mode == "log": + return self.forward_log(stu_out, tea_out) + elif self.mode == "mean": + blank_mask = paddle.ones_like(stu_out) + blank_mask.stop_gradient = True + blank_mask[:, :, 0] = -1 + stu_out *= blank_mask + tea_out *= blank_mask + return self.forward_meanmax(stu_out, tea_out) + elif self.mode == "sum": + return self.forward_sum(stu_out, tea_out) + elif self.mode == "meanlog": + blank_mask = paddle.ones_like(stu_out) + blank_mask.stop_gradient = True + blank_mask[:, :, 0] = -1 + stu_out *= blank_mask + tea_out *= blank_mask + return self.forward_meanlog(stu_out, tea_out) + elif self.mode == "ctcdkd": + # ingore ctc blank logits + blank_mask = paddle.ones_like(stu_out) + blank_mask.stop_gradient = True + blank_mask[:, :, 0] = -1 + stu_out *= blank_mask + tea_out *= blank_mask + return self.ctc_dkd_loss(stu_out, tea_out, targets) + else: + raise ValueError("error!!!!!!") + + def forward_log(self, out1, out2): + if self.act is not None: + out1 = self.act(out1) + 1e-10 + out2 = self.act(out2) + 1e-10 + if self.use_log is True: + # for recognition distillation, log is needed for feature map + log_out1 = paddle.log(out1) + log_out2 = paddle.log(out2) + loss = ( + self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0 + + return loss + + +class DistillCTCLogits(KLCTCLogits): + def __init__(self, + model_name_pairs=[], + key=None, + name="ctc_logits", + reduction="mean"): + super().__init__(reduction=reduction) + self.model_name_pairs = self._check_model_name_pairs(model_name_pairs) + self.key = key + self.name = name + + def _check_model_name_pairs(self, model_name_pairs): + if not isinstance(model_name_pairs, list): + return [] + elif isinstance(model_name_pairs[0], list) and isinstance( + model_name_pairs[0][0], str): + return model_name_pairs + else: + return [model_name_pairs] + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, pair in enumerate(self.model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + + if self.key is not None: + out1 = out1[self.key]['ctc'] + out2 = out2[self.key]['ctc'] + + ctc_label = batch[1] + loss = super().forward(out1, out2, ctc_label) + if isinstance(loss, dict): + for key in loss: + loss_dict["{}_{}_{}".format(self.name, model_name, + idx)] = loss[key] + else: + loss_dict["{}_{}".format(self.name, idx)] = loss + return loss_dict