提交 92c9eaf3 编写于 作者: A andyjpaddle

fix sar train on cpu

上级 529133fb
...@@ -275,7 +275,6 @@ class ParallelSARDecoder(BaseDecoder): ...@@ -275,7 +275,6 @@ class ParallelSARDecoder(BaseDecoder):
if img_metas is not None and self.mask: if img_metas is not None and self.mask:
valid_ratios = img_metas[-1] valid_ratios = img_metas[-1]
label = label.cuda()
lab_embedding = self.embedding(label) lab_embedding = self.embedding(label)
# bsz * seq_len * emb_dim # bsz * seq_len * emb_dim
out_enc = out_enc.unsqueeze(1) out_enc = out_enc.unsqueeze(1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册