提交 df4a2f6a 编写于 作者: A andyjpaddle

update rec_sar_head

上级 073fad37
......@@ -9,7 +9,7 @@ from paddle import nn
class SARLoss(nn.Layer):
def __init__(self, **kwargs):
super(SARLoss, self).__init__()
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean", ignore_index=92)
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean", ignore_index=96)
def forward(self, predicts, batch):
predict = predicts[:, :-1, :] # ignore last index of outputs to be in same seq_len with targets
......
......@@ -118,8 +118,7 @@ class BaseDecoder(nn.Layer):
class ParallelSARDecoder(BaseDecoder):
"""
Args:
num_classes (int): Output class number.
channels (list[int]): Network layer channels.
out_channels (int): Output class number.
enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
dec_drop_rnn (float): Dropout of RNN layer in decoder.
......@@ -137,7 +136,7 @@ class ParallelSARDecoder(BaseDecoder):
"""
def __init__(self,
num_classes=93, # 90 + unknown + start + padding
out_channels, # 90 + unknown + start + padding
enc_bi_rnn=False,
dec_bi_rnn=False,
dec_drop_rnn=0.0,
......@@ -148,8 +147,6 @@ class ParallelSARDecoder(BaseDecoder):
pred_dropout=0.1,
max_text_length=30,
mask=True,
start_idx=91,
padding_idx=92, # 92
pred_concat=True,
**kwargs):
super().__init__()
......@@ -157,7 +154,8 @@ class ParallelSARDecoder(BaseDecoder):
self.num_classes = num_classes
self.enc_bi_rnn = enc_bi_rnn
self.d_k = d_k
self.start_idx = start_idx
self.start_idx = out_channels - 2
self.padding_idx = out_channels - 1
self.max_seq_len = max_text_length
self.mask = mask
self.pred_concat = pred_concat
......@@ -191,7 +189,7 @@ class ParallelSARDecoder(BaseDecoder):
# Decoder input embedding
self.embedding = nn.Embedding(
self.num_classes, encoder_rnn_out_size, padding_idx=padding_idx)
self.num_classes, encoder_rnn_out_size, padding_idx=self.padding_idx)
# Prediction layer
self.pred_dropout = nn.Dropout(pred_dropout)
......@@ -330,6 +328,7 @@ class ParallelSARDecoder(BaseDecoder):
class SARHead(nn.Layer):
def __init__(self,
out_channels,
enc_bi_rnn=False,
enc_drop_rnn=0.1,
enc_gru=False,
......@@ -351,7 +350,8 @@ class SARHead(nn.Layer):
# decoder module
self.decoder = ParallelSARDecoder(
enc_bi_rnn=enc_bi_rnn,
out_channels=out_channels,
enc_bi_rnn=enc_bi_rnn,
dec_bi_rnn=dec_bi_rnn,
dec_drop_rnn=dec_drop_rnn,
dec_gru=dec_gru,
......@@ -375,4 +375,4 @@ class SARHead(nn.Layer):
# (bsz, seq_len, num_classes)
return final_out
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册