From a4bc9da8b6b3c662cd3d2921754b4c69e7559247 Mon Sep 17 00:00:00 2001 From: WenmuZhou <572459439@qq.com> Date: Wed, 21 Sep 2022 17:50:50 +0800 Subject: [PATCH] fix bug --- configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml | 4 ++-- .../kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml | 10 ++++++---- ppocr/losses/distillation_loss.py | 5 +++++ ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py | 2 +- ppstructure/kie/predict_kie_token_ser_re.py | 5 ++++- 5 files changed, 18 insertions(+), 8 deletions(-) diff --git a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml index 811c7d2d..e65af0a0 100644 --- a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml +++ b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml @@ -83,7 +83,7 @@ Train: order: 'hwc' - ToCHWImage: - KeepKeys: - keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order + keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order loader: shuffle: True drop_last: False @@ -122,7 +122,7 @@ Eval: order: 'hwc' - ToCHWImage: - KeepKeys: - keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order + keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order loader: shuffle: False drop_last: False diff --git a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml index 0bd42901..cbf8cbb7 100644 --- a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml +++ b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml @@ -57,14 +57,16 @@ Loss: mode: "l2" model_name_pairs: - ["Student", "Teacher"] - key: hidden_states_5 + key: hidden_states + index: 5 name: "loss_5" - DistillationVQADistanceLoss: weight: 0.5 mode: "l2" model_name_pairs: - ["Student", "Teacher"] - key: hidden_states_8 + key: hidden_states + index: 8 name: "loss_8" @@ -126,7 +128,7 @@ Train: order: 'hwc' - ToCHWImage: - KeepKeys: - keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order + keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order loader: shuffle: True drop_last: False @@ -166,7 +168,7 @@ Eval: order: 'hwc' - ToCHWImage: - KeepKeys: - keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order + keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order loader: shuffle: False drop_last: False diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py index 87fed623..4bfbed75 100644 --- a/ppocr/losses/distillation_loss.py +++ b/ppocr/losses/distillation_loss.py @@ -417,11 +417,13 @@ class DistillationVQADistanceLoss(DistanceLoss): mode="l2", model_name_pairs=[], key=None, + index=None, name="loss_distance", **kargs): super().__init__(mode=mode, **kargs) assert isinstance(model_name_pairs, list) self.key = key + self.index = index self.model_name_pairs = model_name_pairs self.name = name + "_l2" @@ -434,6 +436,9 @@ class DistillationVQADistanceLoss(DistanceLoss): if self.key is not None: out1 = out1[self.key] out2 = out2[self.key] + if self.index is not None: + out1 = out1[:, self.index, :, :] + out2 = out2[:, self.index, :, :] if attention_mask is not None: max_len = attention_mask.shape[-1] out1 = out1[:, :max_len] diff --git a/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py b/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py index a6011acf..64f7d761 100644 --- a/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py +++ b/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py @@ -32,7 +32,7 @@ class VQAReTokenLayoutLMPostProcess(object): return self._infer(pred_relations, *args, **kwargs) def _metric(self, pred_relations, label): - return pred_relations, label[6], label[5] + return pred_relations, label[-1], label[-2] def _infer(self, pred_relations, *args, **kwargs): ser_results = kwargs['ser_results'] diff --git a/ppstructure/kie/predict_kie_token_ser_re.py b/ppstructure/kie/predict_kie_token_ser_re.py index b4eace4b..278e08da 100644 --- a/ppstructure/kie/predict_kie_token_ser_re.py +++ b/ppstructure/kie/predict_kie_token_ser_re.py @@ -64,7 +64,10 @@ class SerRePredictor(object): for output_tensor in self.output_tensors: output = output_tensor.copy_to_cpu() outputs.append(output) - preds = dict(loss=outputs[0], pred_relations=outputs[1]) + preds = dict( + loss=outputs[1], + pred_relations=outputs[2], + hidden_states=outputs[0], ) post_result = self.postprocess_op( preds, -- GitLab