diff --git a/configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml b/configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml index b27eb2f9c966b66f6569d600b1ae7e79bd1da0dc..5f0846fa1ae1403ec043eb0dfb5f50b0f9edde4f 100644 --- a/configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml +++ b/configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml @@ -88,7 +88,7 @@ Loss: - DistillationDMLLoss: model_name_pairs: - ["Student", "Student2"] - maps_name: ["thrink_maps"] + maps_name: "thrink_maps" weight: 1.0 act: "softmax" model_name_pairs: ["Student", "Student2"] @@ -96,7 +96,7 @@ Loss: - DistillationDBLoss: weight: 1.0 model_name_list: ["Student", "Student2"] - key: maps + # key: maps name: DBLoss balance_loss: true main_loss_type: DiceLoss diff --git a/ppocr/losses/combined_loss.py b/ppocr/losses/combined_loss.py index f37c4db128e5b0cb3ef9a9fe249402ba37d5ceeb..f10efa31e27f44968760486a71631a36f83679f0 100644 --- a/ppocr/losses/combined_loss.py +++ b/ppocr/losses/combined_loss.py @@ -50,11 +50,11 @@ class CombinedLoss(nn.Layer): if isinstance(loss, paddle.Tensor): loss = {"loss_{}_{}".format(str(loss), idx): loss} weight = self.loss_weight[idx] - for key in loss: + for key in loss.keys(): if key == "loss": loss_all += loss[key] * weight else: - loss["{}_{}".format(key, idx)] = loss[key] + loss_dict["{}_{}".format(key, idx)] = loss[key] # loss[f"{key}_{idx}"] = loss[key] loss_dict.update(loss) loss_dict["loss"] = loss_all diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py index d4e4a8a2a4d0908ad4819b13629b0bf37de31133..43356c6f6ead790670114a3230d0af91424c6a0d 100644 --- a/ppocr/losses/distillation_loss.py +++ b/ppocr/losses/distillation_loss.py @@ -24,7 +24,6 @@ from .det_db_loss import DBLoss from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss - def _sum_loss(loss_dict): if "loss" in loss_dict.keys(): return loss_dict @@ -51,9 +50,17 @@ class DistillationDMLLoss(DMLLoss): super().__init__(act=act) assert isinstance(model_name_pairs, list) self.key = key - self.model_name_pairs = model_name_pairs + self.model_name_pairs = self._check_model_name_pairs(model_name_pairs) self.name = name self.maps_name = maps_name + + def _check_model_name_pairs(self, model_name_pairs): + if not isinstance(model_name_pairs, list): + return [] + elif isinstance(model_name_pairs[0], list) and isinstance(model_name_pairs[0][0], str): + return model_name_pairs + else: + return [model_name_pairs] def _check_maps_name(self, maps_name): if maps_name is None: @@ -69,13 +76,14 @@ class DistillationDMLLoss(DMLLoss): new_outs = {} for k in self.maps_name: if k == "thrink_maps": - new_outs[k] = paddle.slice(outs, axes=1, starts=0, ends=1) + new_outs[k] = paddle.slice(outs, axes=[1], starts=[0], ends=[1]) elif k == "threshold_maps": - new_outs[k] = paddle.slice(outs, axes=1, starts=1, ends=2) + new_outs[k] = paddle.slice(outs, axes=[1], starts=[1], ends=[2]) elif k == "binary_maps": - new_outs[k] = paddle.slice(outs, axes=1, starts=2, ends=3) + new_outs[k] = paddle.slice(outs, axes=[1], starts=[2], ends=[3]) else: continue + return new_outs def forward(self, predicts, batch): loss_dict = dict() @@ -104,7 +112,7 @@ class DistillationDMLLoss(DMLLoss): loss_dict["{}_{}_{}_{}_{}".format(key, pair[ 0], pair[1], map_name, idx)] = loss[key] else: - loss_dict["{}_{}_{}".format(self.name, map_name, + loss_dict["{}_{}_{}".format(self.name, self.maps_name, idx)] = loss loss_dict = _sum_loss(loss_dict) @@ -151,7 +159,7 @@ class DistillationDBLoss(DBLoss): self.name = name self.key = None - def forward(self, preicts, batch): + def forward(self, predicts, batch): loss_dict = {} for idx, model_name in enumerate(self.model_name_list): out = predicts[model_name] diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index f2ac65c4f9f212c7da9a1128b05d7bd6e0625b15..654ddf39d23590fbaf7f7b9b57f38cc86a1b6669 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -34,7 +34,8 @@ def build_post_process(config, global_config=None): support_dict = [ 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess', - 'DistillationCTCLabelDecode', 'TableLabelDecode', 'DistillationDBPostProcess' + 'DistillationCTCLabelDecode', 'TableLabelDecode', + 'DistillationDBPostProcess' ] config = copy.deepcopy(config) diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py index 4561b4642ad8d21c14c2f1de7d360ae455d4b57b..f2b2fc69efc46d72c047895d6fe1a4c12f5663e9 100755 --- a/ppocr/postprocess/db_postprocess.py +++ b/ppocr/postprocess/db_postprocess.py @@ -200,12 +200,9 @@ class DistillationDBPostProcess(DBPostProcess): use_dilation=False, score_mode="fast", **kwargs): - super(DistillationDBPostProcess, self).__init__(thresh, - box_thresh, - max_candidates, - unclip_ratio, - use_dilation, - score_mode) + super(DistillationDBPostProcess, self).__init__( + thresh, box_thresh, max_candidates, unclip_ratio, use_dilation, + score_mode) if not isinstance(model_name, list): model_name = [model_name] self.model_name = model_name @@ -221,10 +218,3 @@ class DistillationDBPostProcess(DBPostProcess): results[name] = super().__call__(pred, shape_list=label) return results - - - - - - -