提交 560e8205 编写于 作者: W wuzewu

Fix image classification module error

上级 71defaf0
...@@ -70,6 +70,7 @@ class ImageClassifierModule(RunModule, ImageServing): ...@@ -70,6 +70,7 @@ class ImageClassifierModule(RunModule, ImageServing):
''' '''
images = batch[0] images = batch[0]
labels = paddle.unsqueeze(batch[1], axis=-1) labels = paddle.unsqueeze(batch[1], axis=-1)
labels = labels.astype('int64')
preds, feature = self(images) preds, feature = self(images)
...@@ -104,7 +105,7 @@ class ImageClassifierModule(RunModule, ImageServing): ...@@ -104,7 +105,7 @@ class ImageClassifierModule(RunModule, ImageServing):
batch_data.append(image) batch_data.append(image)
except: except:
pass pass
batch_image = np.array(batch_data) batch_image = np.array(batch_data, dtype='float32')
preds, feature = self(paddle.to_tensor(batch_image)) preds, feature = self(paddle.to_tensor(batch_image))
preds = F.softmax(preds, axis=1).numpy() preds = F.softmax(preds, axis=1).numpy()
pred_idxs = np.argsort(preds)[:, ::-1][:, :top_k] pred_idxs = np.argsort(preds)[:, ::-1][:, :top_k]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册