diff --git a/ppocr/postprocess/pse_postprocess/pse_postprocess.py b/ppocr/postprocess/pse_postprocess/pse_postprocess.py index 34f1b8c9b5397a5513462468a9ee3d8530389607..962f3efe922c4a2656e0f44f478e1baf301a5542 100755 --- a/ppocr/postprocess/pse_postprocess/pse_postprocess.py +++ b/ppocr/postprocess/pse_postprocess/pse_postprocess.py @@ -58,6 +58,8 @@ class PSEPostProcess(object): kernels = (pred > self.thresh).astype('float32') text_mask = kernels[:, 0, :, :] + text_mask = paddle.unsqueeze(text_mask, axis=1) + kernels[:, 0:, :, :] = kernels[:, 0:, :, :] * text_mask score = score.numpy()