From abc4be007ede8d9d5bc8b1d2fea24466679f9a20 Mon Sep 17 00:00:00 2001
From: Double_V
Date: Wed, 17 May 2023 16:45:04 +0800
Subject: [PATCH] add nrtr dml distill loss (#9968)

* support min_area_rect crop

* add check_install

* fix requirement.txt

* fix check_install

* add lanms-neo for drrg

* fix

* fix doc

* fix

* support set gpu_id when inference

* fix #8855

* fix #8855

* opt slim doc

* fix doc bug

* add v4_rec_distill config

* delete debug

* fix comment

* fix comment

* add dml nrtr distill loss
---
 ppocr/losses/distillation_loss.py | 233 ++++++++++++++++++++++++++++++
 1 file changed, 233 insertions(+)

diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py
index dac662e5..5812544e 100644
--- a/ppocr/losses/distillation_loss.py
+++ b/ppocr/losses/distillation_loss.py
@@ -96,6 +96,96 @@ class DistillationDMLLoss(DMLLoss):
                 continue
         return new_outs
 
+    def forward(self, predicts, batch):
+        loss_dict = dict()
+        for idx, pair in enumerate(self.model_name_pairs):
+            out1 = predicts[pair[0]]
+            out2 = predicts[pair[1]]
+            if self.key is not None:
+                out1 = out1[self.key]
+                out2 = out2[self.key]
+            if self.maps_name is None:
+                if self.multi_head:
+                    loss = super().forward(out1[self.dis_head],
+                                           out2[self.dis_head])
+                else:
+                    loss = super().forward(out1, out2)
+                if isinstance(loss, dict):
+                    for key in loss:
+                        loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
+                                                       idx)] = loss[key]
+                else:
+                    loss_dict["{}_{}".format(self.name, idx)] = loss
+            else:
+                outs1 = self._slice_out(out1)
+                outs2 = self._slice_out(out2)
+                for _c, k in enumerate(outs1.keys()):
+                    loss = super().forward(outs1[k], outs2[k])
+                    if isinstance(loss, dict):
+                        for key in loss:
+                            loss_dict["{}_{}_{}_{}_{}".format(key, pair[
+                                0], pair[1], self.maps_name, idx)] = loss[key]
+                    else:
+                        loss_dict["{}_{}_{}".format(self.name, self.maps_name[
+                            _c], idx)] = loss
+
+        loss_dict = _sum_loss(loss_dict)
+
+        return loss_dict
+
+
+class DistillationKLDivLoss(KLDivLoss):
+    """Distillation wrapper that applies KLDivLoss to paired model outputs.
+    """
+
+    def __init__(self,
+                 model_name_pairs=[],
+                 key=None,
+                 multi_head=False,
+                 dis_head='ctc',
+                 maps_name=None,
+                 name="kl_div"):
+        super().__init__()
+        assert isinstance(model_name_pairs, list)
+        self.key = key
+        self.multi_head = multi_head
+        self.dis_head = dis_head
+        self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
+        self.name = name
+        self.maps_name = self._check_maps_name(maps_name)
+
+    def _check_model_name_pairs(self, model_name_pairs):
+        if not model_name_pairs or not isinstance(model_name_pairs, list):
+            return []
+        elif isinstance(model_name_pairs[0], list) and isinstance(
+                model_name_pairs[0][0], str):
+            return model_name_pairs
+        else:
+            return [model_name_pairs]
+
+    def _check_maps_name(self, maps_name):
+        if maps_name is None:
+            return None
+        elif type(maps_name) == str:
+            return [maps_name]
+        elif type(maps_name) == list:
+            return maps_name  # already a list of map names
+        else:
+            return None
+
+    def _slice_out(self, outs):
+        new_outs = {}
+        for k in self.maps_name:
+            if k == "thrink_maps":  # spelling matches the key used across ppocr
+                new_outs[k] = outs[:, 0, :, :]
+            elif k == "threshold_maps":
+                new_outs[k] = outs[:, 1, :, :]
+            elif k == "binary_maps":
+                new_outs[k] = outs[:, 2, :, :]
+            else:
+                continue
+        return new_outs
+
     def forward(self, predicts, batch):
         loss_dict = dict()
         for idx, pair in enumerate(self.model_name_pairs):
             out1 = predicts[pair[0]]
             out2 = predicts[pair[1]]
             if self.key is not None:
                 out1 = out1[self.key]
                 out2 = out2[self.key]
             if self.maps_name is None:
                 if self.multi_head:
                     loss = super().forward(out1[self.dis_head],
                                            out2[self.dis_head])
                 else:
                     loss = super().forward(out1, out2)
                 if isinstance(loss, dict):
                     for key in loss:
                         loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
                                                        idx)] = loss[key]
                 else:
                     loss_dict["{}_{}".format(self.name, idx)] = loss
             else:
                 outs1 = self._slice_out(out1)
                 outs2 = self._slice_out(out2)
                 for _c, k in enumerate(outs1.keys()):
                     loss = super().forward(outs1[k], outs2[k])
                     if isinstance(loss, dict):
                         for key in loss:
                             loss_dict["{}_{}_{}_{}_{}".format(key, pair[
                                 0], pair[1], self.maps_name, idx)] = loss[key]
                     else:
                         loss_dict["{}_{}_{}".format(self.name, self.maps_name[
                             _c], idx)] = loss
 
         loss_dict = _sum_loss(loss_dict)
 
         return loss_dict
@@ -141,6 +231,149 @@ class DistillationDMLLoss(DMLLoss):
+class DistillationDKDLoss(DKDLoss):
+    """Decoupled knowledge distillation (DKD) loss on paired model outputs.
+    """
+
+    def __init__(self,
+                 model_name_pairs=[],
+                 key=None,
+                 multi_head=False,
+                 dis_head='ctc',
+                 maps_name=None,
+                 name="dkd",
+                 temperature=1.0,
+                 alpha=1.0,
+                 beta=1.0):
+        super().__init__(temperature, alpha, beta)
+        assert isinstance(model_name_pairs, list)
+        self.key = key
+        self.multi_head = multi_head
+        self.dis_head = dis_head
+        self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
+        self.name = name
+        self.maps_name = self._check_maps_name(maps_name)
+
+    def _check_model_name_pairs(self, model_name_pairs):
+        if not model_name_pairs or not isinstance(model_name_pairs, list):
+            return []
+        elif isinstance(model_name_pairs[0], list) and isinstance(
+                model_name_pairs[0][0], str):
+            return model_name_pairs
+        else:
+            return [model_name_pairs]
+
+    def _check_maps_name(self, maps_name):
+        if maps_name is None:
+            return None
+        elif type(maps_name) == str:
+            return [maps_name]
+        elif type(maps_name) == list:
+            return maps_name  # already a list of map names
+        else:
+            return None
+
+    def _slice_out(self, outs):
+        new_outs = {}
+        for k in self.maps_name:
+            if k == "thrink_maps":  # spelling matches the key used across ppocr
+                new_outs[k] = outs[:, 0, :, :]
+            elif k == "threshold_maps":
+                new_outs[k] = outs[:, 1, :, :]
+            elif k == "binary_maps":
+                new_outs[k] = outs[:, 2, :, :]
+            else:
+                continue
+        return new_outs
+
+    def forward(self, predicts, batch):
+        loss_dict = dict()
+
+        for idx, pair in enumerate(self.model_name_pairs):
+            out1 = predicts[pair[0]]
+            out2 = predicts[pair[1]]
+            if self.key is not None:
+                out1 = out1[self.key]
+                out2 = out2[self.key]
+            if self.maps_name is None:
+                if self.multi_head:
+                    # for nrtr dml loss: drop the padded target positions
+                    max_len = batch[3].max()
+                    tgt = batch[2][:, 1:2 +
+                                   max_len]  # [batch_size, max_len + 1]
+
+                    tgt = tgt.reshape([-1])  # batch_size * (max_len + 1)
+                    non_pad_mask = paddle.not_equal(
+                        tgt, paddle.zeros(
+                            tgt.shape,
+                            dtype=tgt.dtype))  # batch_size * (max_len + 1)
+
+                    loss = super().forward(
+                        out1[self.dis_head], out2[self.dis_head], tgt,
+                        non_pad_mask)  # [batch_size, max_len + 1, num_char]
+                else:
+                    loss = super().forward(out1, out2)
+                if isinstance(loss, dict):
+                    for key in loss:
+                        loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
+                                                       idx)] = loss[key]
+                else:
+                    loss_dict["{}_{}".format(self.name, idx)] = loss
+            else:
+                outs1 = self._slice_out(out1)
+                outs2 = self._slice_out(out2)
+                for _c, k in enumerate(outs1.keys()):
+                    loss = super().forward(outs1[k], outs2[k])
+                    if isinstance(loss, dict):
+                        for key in loss:
+                            loss_dict["{}_{}_{}_{}_{}".format(key, pair[
+                                0], pair[1], self.maps_name, idx)] = loss[key]
+                    else:
+                        loss_dict["{}_{}_{}".format(self.name, self.maps_name[
+                            _c], idx)] = loss
+
+        loss_dict = _sum_loss(loss_dict)
+
+        return loss_dict
+
+
+class DistillationNRTRDMLLoss(DistillationDMLLoss):
+    """DML for NRTR heads: masks padded positions and calls DMLLoss.forward
+    directly (DistillationDMLLoss.forward expects predicts and batch)."""
+
+    def forward(self, predicts, batch):
+        loss_dict = dict()
+        for idx, pair in enumerate(self.model_name_pairs):
+            out1 = predicts[pair[0]]
+            out2 = predicts[pair[1]]
+            if self.key is not None:
+                out1 = out1[self.key]
+                out2 = out2[self.key]
+
+            if self.multi_head:
+                # for nrtr dml loss: build the non-pad mask from the targets
+                max_len = batch[3].max()
+                tgt = batch[2][:, 1:2 + max_len]
+                tgt = tgt.reshape([-1])
+                non_pad_mask = paddle.not_equal(
+                    tgt, paddle.zeros(
+                        tgt.shape, dtype=tgt.dtype))
+                loss = super(DistillationDMLLoss, self).forward(
+                    out1[self.dis_head], out2[self.dis_head], non_pad_mask)
+            else:
+                loss = super(DistillationDMLLoss, self).forward(out1, out2)
+            if isinstance(loss, dict):
+                for key in loss:
+                    loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
+                                                   idx)] = loss[key]
+            else:
+                loss_dict["{}_{}".format(self.name, idx)] = loss
+
+        loss_dict = _sum_loss(loss_dict)
+
+        return loss_dict
+
+
 class DistillationKLDivLoss(KLDivLoss):
     """
     """
--
GitLab
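For readers who want the core idea of the new NRTR DML distillation without the class plumbing: the patch flattens the padded target ids, marks every position whose id is non-zero (zero being the padding id) with paddle.not_equal, and averages a symmetric KL between the student and teacher head distributions over those positions only. Below is a minimal, self-contained sketch of that masked-DML computation. The helper name masked_dml, the softmax placement, and the eps guard are illustrative assumptions of this sketch, not PaddleOCR API; the actual reduction lives in DMLLoss in ppocr/losses/basic_loss.py, which this patch does not show.

import paddle
import paddle.nn.functional as F


def masked_dml(logits_s, logits_t, tgt, eps=1e-10):
    """Symmetric KL (DML) between two heads, averaged over non-pad positions.

    logits_s, logits_t: [B, L, C] raw outputs of the paired heads.
    tgt: [B, L] target ids, with 0 used as the padding id (the same
    convention the patch applies via paddle.not_equal against zeros).
    """
    num_classes = logits_s.shape[-1]
    p_s = F.softmax(logits_s.reshape([-1, num_classes]), axis=-1)  # [B*L, C]
    p_t = F.softmax(logits_t.reshape([-1, num_classes]), axis=-1)

    # Non-pad mask built exactly as in the patch: True where tgt != 0.
    tgt = tgt.reshape([-1])
    non_pad_mask = paddle.not_equal(
        tgt, paddle.zeros(tgt.shape, dtype=tgt.dtype))

    # Per-position KL in both directions; DML averages the two halves.
    kl_ts = (p_t * (paddle.log(p_t + eps) - paddle.log(p_s + eps))).sum(axis=-1)
    kl_st = (p_s * (paddle.log(p_s + eps) - paddle.log(p_t + eps))).sum(axis=-1)
    per_pos = 0.5 * (kl_ts + kl_st)  # [B*L]

    # Average only over real (non-padded) time steps.
    return paddle.masked_select(per_pos, non_pad_mask).mean()


if __name__ == "__main__":
    # Toy check: batch of 2, max length 4, 10 character classes.
    logits_s = paddle.randn([2, 4, 10])
    logits_t = paddle.randn([2, 4, 10])
    tgt = paddle.to_tensor([[3, 7, 0, 0], [5, 2, 9, 0]], dtype="int64")
    print(masked_dml(logits_s, logits_t, tgt))  # scalar loss

In the patch itself, DistillationNRTRDMLLoss derives this mask from batch[2] (the padded targets) and batch[3] (the valid lengths) and passes it down to the base loss, so padded time steps contribute nothing to the distillation gradient.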