未验证 提交 e758d273 编写于 作者: A andyjpaddle 提交者: GitHub

fix sar bug (#5864)

上级 e6133036
...@@ -216,7 +216,7 @@ class ParallelSARDecoder(BaseDecoder): ...@@ -216,7 +216,7 @@ class ParallelSARDecoder(BaseDecoder):
self.pred_dropout = nn.Dropout(pred_dropout) self.pred_dropout = nn.Dropout(pred_dropout)
pred_num_classes = self.num_classes - 1 pred_num_classes = self.num_classes - 1
if pred_concat: if pred_concat:
fc_in_channel = decoder_rnn_out_size + d_model + d_enc fc_in_channel = decoder_rnn_out_size + d_model + encoder_rnn_out_size
else: else:
fc_in_channel = d_model fc_in_channel = d_model
self.prediction = nn.Linear(fc_in_channel, pred_num_classes) self.prediction = nn.Linear(fc_in_channel, pred_num_classes)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册