提交 c2c43bb1 编写于 作者: 文幕地方's avatar 文幕地方

rename SLANetLoss to SLALoss

上级 731688c2
...@@ -54,7 +54,7 @@ Architecture: ...@@ -54,7 +54,7 @@ Architecture:
loc_reg_num: &loc_reg_num 4 loc_reg_num: &loc_reg_num 4
Loss: Loss:
name: SLANetLoss name: SLALoss
structure_weight: 1.0 structure_weight: 1.0
loc_weight: 2.0 loc_weight: 2.0
loc_loss: smooth_l1 loc_loss: smooth_l1
......
...@@ -52,7 +52,7 @@ from .basic_loss import DistanceLoss ...@@ -52,7 +52,7 @@ from .basic_loss import DistanceLoss
from .combined_loss import CombinedLoss from .combined_loss import CombinedLoss
# table loss # table loss
from .table_att_loss import TableAttentionLoss, SLANetLoss from .table_att_loss import TableAttentionLoss, SLALoss
from .table_master_loss import TableMasterLoss from .table_master_loss import TableMasterLoss
# vqa token loss # vqa token loss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
...@@ -64,7 +64,7 @@ def build_loss(config): ...@@ -64,7 +64,7 @@ def build_loss(config):
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', 'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss', 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'SLANetLoss' 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'SLALoss'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
......
...@@ -55,9 +55,9 @@ class TableAttentionLoss(nn.Layer): ...@@ -55,9 +55,9 @@ class TableAttentionLoss(nn.Layer):
} }
class SLANetLoss(nn.Layer): class SLALoss(nn.Layer):
def __init__(self, structure_weight, loc_weight, loc_loss='mse', **kwargs): def __init__(self, structure_weight, loc_weight, loc_loss='mse', **kwargs):
super(SLANetLoss, self).__init__() super(SLALoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='mean') self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='mean')
self.structure_weight = structure_weight self.structure_weight = structure_weight
self.loc_weight = loc_weight self.loc_weight = loc_weight
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册