diff --git a/configs/table/SLANet.yml b/configs/table/SLANet.yml index acf8d0304558816d206a3c7f37de4aaba301683b..4858c71c28710c2f68faac6397d4f60fba7a4e95 100644 --- a/configs/table/SLANet.yml +++ b/configs/table/SLANet.yml @@ -54,7 +54,7 @@ Architecture: loc_reg_num: &loc_reg_num 4 Loss: - name: SLANetLoss + name: SLALoss structure_weight: 1.0 loc_weight: 2.0 loc_loss: smooth_l1 diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 8f3adfccd46b7cedd3141e1cfce5baba621c8676..3ac766da92cd6d2fdd552ba28fe4d34aa8767b42 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -52,7 +52,7 @@ from .basic_loss import DistanceLoss from .combined_loss import CombinedLoss # table loss -from .table_att_loss import TableAttentionLoss, SLANetLoss +from .table_att_loss import TableAttentionLoss, SLALoss from .table_master_loss import TableMasterLoss # vqa token loss from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss @@ -64,7 +64,7 @@ def build_loss(config): 'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss', - 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'SLANetLoss' + 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'SLALoss' ] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/losses/table_att_loss.py b/ppocr/losses/table_att_loss.py index d97715d5419840431eaa17ea359f4b06afed26d3..f1771847b46b99d8cf2a3ae69e7e990ee02f26a5 100644 --- a/ppocr/losses/table_att_loss.py +++ b/ppocr/losses/table_att_loss.py @@ -55,9 +55,9 @@ class TableAttentionLoss(nn.Layer): } -class SLANetLoss(nn.Layer): +class SLALoss(nn.Layer): def __init__(self, structure_weight, loc_weight, loc_loss='mse', **kwargs): - super(SLANetLoss, self).__init__() + super(SLALoss, self).__init__() self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='mean') self.structure_weight = structure_weight self.loc_weight = loc_weight