From 560e82051c395fb99ad75e4eaf2c2078fc9337db Mon Sep 17 00:00:00 2001 From: wuzewu Date: Thu, 22 Jul 2021 17:05:03 +0800 Subject: [PATCH] Fix image classification module error --- paddlehub/module/cv_module.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddlehub/module/cv_module.py b/paddlehub/module/cv_module.py index 4c5aa03e..62e2a303 100644 --- a/paddlehub/module/cv_module.py +++ b/paddlehub/module/cv_module.py @@ -70,6 +70,7 @@ class ImageClassifierModule(RunModule, ImageServing): ''' images = batch[0] labels = paddle.unsqueeze(batch[1], axis=-1) + labels = labels.astype('int64') preds, feature = self(images) @@ -104,7 +105,7 @@ class ImageClassifierModule(RunModule, ImageServing): batch_data.append(image) except: pass - batch_image = np.array(batch_data) + batch_image = np.array(batch_data, dtype='float32') preds, feature = self(paddle.to_tensor(batch_image)) preds = F.softmax(preds, axis=1).numpy() pred_idxs = np.argsort(preds)[:, ::-1][:, :top_k] -- GitLab