diff --git a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml index 811c7d2d6f16344a3d6ad060fec1a1966241d81b..e65af0a064418f1f21725f6b9e249a8be8391f41 100644 --- a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml +++ b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml @@ -83,7 +83,7 @@ Train: order: 'hwc' - ToCHWImage: - KeepKeys: - keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order + keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order loader: shuffle: True drop_last: False @@ -122,7 +122,7 @@ Eval: order: 'hwc' - ToCHWImage: - KeepKeys: - keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order + keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order loader: shuffle: False drop_last: False diff --git a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml index 0bd42901ebd3a37eb29ce854b1e434dc356d9643..cbf8cbb76d5df67d1ccb5c34089abd8cf5bcdfcf 100644 --- a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml +++ b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml @@ -57,14 +57,16 @@ Loss: mode: "l2" model_name_pairs: - ["Student", "Teacher"] - key: hidden_states_5 + key: hidden_states + index: 5 name: "loss_5" - DistillationVQADistanceLoss: weight: 0.5 mode: "l2" model_name_pairs: - ["Student", "Teacher"] - key: hidden_states_8 + key: hidden_states + index: 8 name: "loss_8" @@ -126,7 +128,7 @@ Train: order: 'hwc' - ToCHWImage: - KeepKeys: - keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order + keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order loader: shuffle: True drop_last: False @@ -166,7 +168,7 @@ Eval: order: 'hwc' - ToCHWImage: - KeepKeys: - keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order + keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order loader: shuffle: False drop_last: False diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py index 87fed6235d73aef2695cd6db95662e615d52c94c..4bfbed75a338e2bd3bca0b80d16028030bf2f0b5 100644 --- a/ppocr/losses/distillation_loss.py +++ b/ppocr/losses/distillation_loss.py @@ -417,11 +417,13 @@ class DistillationVQADistanceLoss(DistanceLoss): mode="l2", model_name_pairs=[], key=None, + index=None, name="loss_distance", **kargs): super().__init__(mode=mode, **kargs) assert isinstance(model_name_pairs, list) self.key = key + self.index = index self.model_name_pairs = model_name_pairs self.name = name + "_l2" @@ -434,6 +436,9 @@ class DistillationVQADistanceLoss(DistanceLoss): if self.key is not None: out1 = out1[self.key] out2 = out2[self.key] + if self.index is not None: + out1 = out1[:, self.index, :, :] + out2 = out2[:, self.index, :, :] if attention_mask is not None: max_len = attention_mask.shape[-1] out1 = out1[:, :max_len] diff --git a/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py b/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py index a6011acf8d68f65adfc84e134c9cc0e733dd68ea..64f7d761950249eaef2946e09365dbaab4d94c6c 100644 --- a/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py +++ b/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py @@ -32,7 +32,7 @@ class VQAReTokenLayoutLMPostProcess(object): return self._infer(pred_relations, *args, **kwargs) def _metric(self, pred_relations, label): - return pred_relations, label[6], label[5] + return pred_relations, label[-1], label[-2] def _infer(self, pred_relations, *args, **kwargs): ser_results = kwargs['ser_results'] diff --git a/ppstructure/kie/predict_kie_token_ser_re.py b/ppstructure/kie/predict_kie_token_ser_re.py index b4eace4b5ee15ccf64a03e96dafcb1cfb021e656..278e08da918ab8f77062b444becd399b4ea2c0b6 100644 --- a/ppstructure/kie/predict_kie_token_ser_re.py +++ b/ppstructure/kie/predict_kie_token_ser_re.py @@ -64,7 +64,10 @@ class SerRePredictor(object): for output_tensor in self.output_tensors: output = output_tensor.copy_to_cpu() outputs.append(output) - preds = dict(loss=outputs[0], pred_relations=outputs[1]) + preds = dict( + loss=outputs[1], + pred_relations=outputs[2], + hidden_states=outputs[0], ) post_result = self.postprocess_op( preds,