From e3163aac396d01d36e3ddd71cdcedb9c0974db74 Mon Sep 17 00:00:00 2001 From: breezedeus Date: Sun, 8 Aug 2021 21:04:59 +0800 Subject: [PATCH] fix: `input_lengths` should be in cpu at `pack_padded_sequence` --- cnocr/models/crnn.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/cnocr/models/crnn.py b/cnocr/models/crnn.py index ef1a33d..3850077 100644 --- a/cnocr/models/crnn.py +++ b/cnocr/models/crnn.py @@ -166,13 +166,7 @@ class CRNN(OcrModel): self, batch, return_model_output: bool = False, return_preds: bool = False, ): imgs, img_lengths, labels_list, label_lengths = batch - return self( - imgs, - img_lengths, - labels_list, - return_model_output, - return_preds, - ) + return self(imgs, img_lengths, labels_list, return_model_output, return_preds) def _compute_loss( self, @@ -242,9 +236,16 @@ class CRNN(OcrModel): c, h, w = features.shape[1], features.shape[2], features.shape[3] features_seq = torch.reshape(features, shape=(-1, h * c, w)) features_seq = torch.transpose(features_seq, 1, 2) - features_seq = pack_padded_sequence(features_seq, input_lengths, batch_first=True, enforce_sorted=False) + features_seq = pack_padded_sequence( + features_seq, + input_lengths.to(device='cpu'), + batch_first=True, + enforce_sorted=False, + ) logits, _ = self.decoder(features_seq) - logits, output_lens = pad_packed_sequence(logits, batch_first=True, total_length=w) + logits, output_lens = pad_packed_sequence( + logits, batch_first=True, total_length=w + ) logits = self.linear(logits) out: Dict[str, Any] = {} -- GitLab