提交 64ec4fb7 编写于 作者: A andyjpaddle

fix sar bug

上级 25e2f818
...@@ -216,7 +216,7 @@ class ParallelSARDecoder(BaseDecoder): ...@@ -216,7 +216,7 @@ class ParallelSARDecoder(BaseDecoder):
self.pred_dropout = nn.Dropout(pred_dropout) self.pred_dropout = nn.Dropout(pred_dropout)
pred_num_classes = self.num_classes - 1 pred_num_classes = self.num_classes - 1
if pred_concat: if pred_concat:
fc_in_channel = decoder_rnn_out_size + d_model + d_enc fc_in_channel = decoder_rnn_out_size + d_model + encoder_rnn_out_size
else: else:
fc_in_channel = d_model fc_in_channel = d_model
self.prediction = nn.Linear(fc_in_channel, pred_num_classes) self.prediction = nn.Linear(fc_in_channel, pred_num_classes)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册