提交 ae9c48af 编写于 作者: G guosheng

add caffe_predict to test

上级 54c03f63
......@@ -205,10 +205,35 @@ class ModelConverter(object):
self.params[file_name] = (param_conf, data.flatten())
return name
def caffe_predict(self,
img,
mean_file='./caffe/imagenet/ilsvrc_2012_mean.npy'):
net = self.net
mu = np.load(mean_file)
mu = mu.mean(1).mean(1)
transformer = caffe.io.Transformer({
'data': net.blobs['data'].data.shape
})
transformer.set_transpose('data', (2, 0, 1))
transformer.set_mean('data', mu)
transformer.set_raw_scale('data', 255)
transformer.set_channel_swap('data', (2, 1, 0))
im = caffe.io.load_image(img)
net.blobs['data'].data[...] = transformer.preprocess('data', im)
out = net.forward()
output_prob = net.blobs['prob'].data[0].flatten()
print np.sort(output_prob)[::-1]
print np.argsort(output_prob)[::-1]
print 'predicted class is:', output_prob.argmax()
if __name__ == "__main__":
converter = ModelConverter("./VGG_ILSVRC_16_layers_deploy.prototxt",
"./VGG_ILSVRC_16_layers.caffemodel",
"/Users/baidu/caffe/caffe/python/paddle_model",
"test_vgg16.tar.gz")
converter.convert()
converter.caffe_predict(img='./caffe/examples/images/cat.jpg')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册