提交 78d4347d 编写于 作者: W Wang,Jeff

Add some comments

上级 6c30c1ed
...@@ -24,10 +24,12 @@ from resnet import resnet_cifar10 ...@@ -24,10 +24,12 @@ from resnet import resnet_cifar10
def inference_network(): def inference_network():
# The image is 32 * 32 with RGB representation.
data_shape = [3, 32, 32] data_shape = [3, 32, 32]
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
predict = resnet_cifar10(images, 32) predict = resnet_cifar10(images, 32)
# predict = vgg_bn_drop(images) # predict = vgg_bn_drop(images) # un-comment to use vgg net
return predict return predict
...@@ -87,7 +89,7 @@ def infer(use_cuda, inference_program, params_dirname=None): ...@@ -87,7 +89,7 @@ def infer(use_cuda, inference_program, params_dirname=None):
inferencer = fluid.Inferencer( inferencer = fluid.Inferencer(
infer_func=inference_program, param_path=params_dirname, place=place) infer_func=inference_program, param_path=params_dirname, place=place)
# inference # Prepare testing data.
from PIL import Image from PIL import Image
import numpy as np import numpy as np
import os import os
...@@ -105,13 +107,16 @@ def infer(use_cuda, inference_program, params_dirname=None): ...@@ -105,13 +107,16 @@ def infer(use_cuda, inference_program, params_dirname=None):
# image is B(Blue), G(green), R(Red). But PIL open # image is B(Blue), G(green), R(Red). But PIL open
# image in RGB mode. It must swap the channel order. # image in RGB mode. It must swap the channel order.
im = im[(2, 1, 0), :, :] # BGR im = im[(2, 1, 0), :, :] # BGR
# im = im.flatten()
im = im / 255.0 im = im / 255.0
# Add one dimension to mimic the list format.
im = numpy.expand_dims(im, axis=0) im = numpy.expand_dims(im, axis=0)
return im return im
cur_dir = os.path.dirname(os.path.realpath(__file__)) cur_dir = os.path.dirname(os.path.realpath(__file__))
img = load_image(cur_dir + '/image/dog.png') img = load_image(cur_dir + '/image/dog.png')
# inference
results = inferencer.infer({'pixel': img}) results = inferencer.infer({'pixel': img})
print("infer results: ", results) print("infer results: ", results)
...@@ -134,4 +139,6 @@ def main(use_cuda): ...@@ -134,4 +139,6 @@ def main(use_cuda):
if __name__ == '__main__': if __name__ == '__main__':
# For demo purpose, the training runs on CPU
# Please change accordingly.
main(use_cuda=False) main(use_cuda=False)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册