提交 e3163aac 编写于 作者: B breezedeus

fix: `input_lengths` must be on the CPU when passed to `pack_padded_sequence`

上级 c480f012
......@@ -166,13 +166,7 @@ class CRNN(OcrModel):
self, batch, return_model_output: bool = False, return_preds: bool = False,
):
imgs, img_lengths, labels_list, label_lengths = batch
return self(
imgs,
img_lengths,
labels_list,
return_model_output,
return_preds,
)
return self(imgs, img_lengths, labels_list, return_model_output, return_preds)
def _compute_loss(
self,
......@@ -242,9 +236,16 @@ class CRNN(OcrModel):
c, h, w = features.shape[1], features.shape[2], features.shape[3]
features_seq = torch.reshape(features, shape=(-1, h * c, w))
features_seq = torch.transpose(features_seq, 1, 2)
features_seq = pack_padded_sequence(features_seq, input_lengths, batch_first=True, enforce_sorted=False)
features_seq = pack_padded_sequence(
features_seq,
input_lengths.to(device='cpu'),
batch_first=True,
enforce_sorted=False,
)
logits, _ = self.decoder(features_seq)
logits, output_lens = pad_packed_sequence(logits, batch_first=True, total_length=w)
logits, output_lens = pad_packed_sequence(
logits, batch_first=True, total_length=w
)
logits = self.linear(logits)
out: Dict[str, Any] = {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册