提交 c2c43bb1 编写于 作者: 文幕地方's avatar 文幕地方

rename SLANetLoss to SLALoss

上级 731688c2
......@@ -54,7 +54,7 @@ Architecture:
loc_reg_num: &loc_reg_num 4
Loss:
name: SLANetLoss
name: SLALoss
structure_weight: 1.0
loc_weight: 2.0
loc_loss: smooth_l1
......
......@@ -52,7 +52,7 @@ from .basic_loss import DistanceLoss
from .combined_loss import CombinedLoss
# table loss
from .table_att_loss import TableAttentionLoss, SLANetLoss
from .table_att_loss import TableAttentionLoss, SLALoss
from .table_master_loss import TableMasterLoss
# vqa token loss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
......@@ -64,7 +64,7 @@ def build_loss(config):
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'SLANetLoss'
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'SLALoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
......
......@@ -55,9 +55,9 @@ class TableAttentionLoss(nn.Layer):
}
class SLANetLoss(nn.Layer):
class SLALoss(nn.Layer):
def __init__(self, structure_weight, loc_weight, loc_loss='mse', **kwargs):
super(SLANetLoss, self).__init__()
super(SLALoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='mean')
self.structure_weight = structure_weight
self.loc_weight = loc_weight
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册