From c2c43bb1bc757bc5ea565ab59a23294f7a172517 Mon Sep 17 00:00:00 2001 From: WenmuZhou <572459439@qq.com> Date: Wed, 10 Aug 2022 14:58:08 +0000 Subject: [PATCH] rename SLANetLoss to SLALoss --- configs/table/SLANet.yml | 2 +- ppocr/losses/__init__.py | 4 ++-- ppocr/losses/table_att_loss.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/configs/table/SLANet.yml b/configs/table/SLANet.yml index acf8d030..4858c71c 100644 --- a/configs/table/SLANet.yml +++ b/configs/table/SLANet.yml @@ -54,7 +54,7 @@ Architecture: loc_reg_num: &loc_reg_num 4 Loss: - name: SLANetLoss + name: SLALoss structure_weight: 1.0 loc_weight: 2.0 loc_loss: smooth_l1 diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 8f3adfcc..3ac766da 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -52,7 +52,7 @@ from .basic_loss import DistanceLoss from .combined_loss import CombinedLoss # table loss -from .table_att_loss import TableAttentionLoss, SLANetLoss +from .table_att_loss import TableAttentionLoss, SLALoss from .table_master_loss import TableMasterLoss # vqa token loss from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss @@ -64,7 +64,7 @@ def build_loss(config): 'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss', - 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'SLANetLoss' + 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'SLALoss' ] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/losses/table_att_loss.py b/ppocr/losses/table_att_loss.py index d97715d5..f1771847 100644 --- a/ppocr/losses/table_att_loss.py +++ b/ppocr/losses/table_att_loss.py @@ -55,9 +55,9 @@ class TableAttentionLoss(nn.Layer): } -class SLANetLoss(nn.Layer): +class SLALoss(nn.Layer): def __init__(self, structure_weight, loc_weight, loc_loss='mse', **kwargs): - super(SLANetLoss, self).__init__() + super(SLALoss, self).__init__() self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='mean') self.structure_weight = structure_weight self.loc_weight = loc_weight -- GitLab