From 5d7b2ce8a6c43a2389f5a158a9b2519deb0602d7 Mon Sep 17 00:00:00 2001 From: breezedeus Date: Thu, 26 Aug 2021 12:49:15 +0800 Subject: [PATCH] fix: different devices when running in gpu --- cnocr/models/ctc.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cnocr/models/ctc.py b/cnocr/models/ctc.py index b5bda80..7c84f9c 100644 --- a/cnocr/models/ctc.py +++ b/cnocr/models/ctc.py @@ -65,7 +65,9 @@ class CTCPostProcessor(object): best_path = torch.argmax(probs, dim=1) # [N, T] if input_lengths is not None: - length_mask = gen_length_mask(input_lengths, probs.shape) # [N, 1, T] + length_mask = gen_length_mask(input_lengths, probs.shape).to( + device=probs.device + ) # [N, 1, T] probs.masked_fill_(length_mask, 1.0) best_path.masked_fill_(length_mask.squeeze(1), blank) @@ -102,4 +104,4 @@ class CTCPostProcessor(object): vocab=self.vocab, input_lengths=input_lengths, blank=len(self.vocab), - ) \ No newline at end of file + ) -- GitLab