提交 560e8205 编写于 作者: W wuzewu

Fix image classification module error

上级 71defaf0
......@@ -70,6 +70,7 @@ class ImageClassifierModule(RunModule, ImageServing):
'''
images = batch[0]
labels = paddle.unsqueeze(batch[1], axis=-1)
labels = labels.astype('int64')
preds, feature = self(images)
......@@ -104,7 +105,7 @@ class ImageClassifierModule(RunModule, ImageServing):
batch_data.append(image)
except:
pass
batch_image = np.array(batch_data)
batch_image = np.array(batch_data, dtype='float32')
preds, feature = self(paddle.to_tensor(batch_image))
preds = F.softmax(preds, axis=1).numpy()
pred_idxs = np.argsort(preds)[:, ::-1][:, :top_k]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册