4.1 KB
Newer Older
wanglong03 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
#!/bin/env python

#   A demo showing how to use a model generated (converted) by caffe2fluid.
#   Only ImageNet data is supported.

import os
import sys
import inspect
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid

def load_data(imgfile, shape):
    """Load an image and preprocess it into a single-image input batch.

    Args:
        imgfile: path of the image file to read.
        shape: network input shape as [C, H, W] (list or tuple); H and W
            are the resize target.

    Returns:
        A float32 numpy array of shape [1] + shape: the image resized to
        (H, W), laid out CHW, reordered to BGR, minus the per-channel mean.
    """
    from PIL import Image

    h, w = shape[1:]
    # convert('RGB') guarantees three channels, so the CHW transpose and the
    # BGR reorder below also work for grayscale, palette or RGBA inputs.
    im = Image.open(imgfile).convert('RGB')

    # Image.ANTIALIAS is an alias of LANCZOS that Pillow 10 removed; prefer
    # LANCZOS and fall back only for very old Pillow releases.
    resample = getattr(Image, 'LANCZOS', None)
    if resample is None:
        resample = Image.ANTIALIAS
    im = im.resize((w, h), resample)

    # np.array() on a PIL image yields HWC (height, width, channel) order;
    # PaddlePaddle requires CHW, so transpose it.
    im = np.array(im).astype(np.float32)
    im = im.transpose((2, 0, 1))  # CHW
    im = im[(2, 1, 0), :, :]  # RGB -> BGR

    # The mean to be subtracted from each image.
    # By default, the per-channel ImageNet mean (applied in BGR order).
    mean = np.array([104., 117., 124.], dtype=np.float32)
    mean = mean.reshape([3, 1, 1])
    im = im - mean
    # list() so a tuple shape works as well as a list.
    return im.reshape([1] + list(shape))

def build_model(net_file, net_name):
    """Import a caffe2fluid-generated network class and build its graph.

    Args:
        net_file: path to the Python file generated by caffe2fluid.
        net_name: name of the network class defined in that file.

    Returns:
        A ``(net, input_shape)`` tuple, where ``net`` is an instance of the
        network class fed by an 'image' data layer and ``input_shape`` is the
        shape the class declares for its 'data' input; or ``None`` if the
        module or class could not be loaded (the reason is printed).
    """
    import importlib

    print('build model with net_file[%s] and net_name[%s]' % (net_file, net_name))

    net_path = os.path.dirname(net_file)
    # splitext removes exactly the extension; str.rstrip('.py') would strip
    # any trailing '.', 'p' or 'y' characters ("happy.py" -> "ha").
    module_name = os.path.splitext(os.path.basename(net_file))[0]
    if net_path not in sys.path:
        sys.path.insert(0, net_path)

    try:
        module = importlib.import_module(module_name)
        MyNet = getattr(module, net_name)
    except Exception as e:
        # Importing generated code can fail in arbitrary ways; report the
        # cause instead of hiding it, and signal failure with None.
        print('failed to load module[%s]' % (module_name))
        print(e)
        return None

    input_name = 'data'
    input_shape = MyNet.input_shapes()[input_name]
    # The network consumes its input under the 'data' key; the feed variable
    # itself is named 'image'.
    images = fluid.layers.data(name='image', shape=input_shape, dtype='float32')

    net = MyNet({input_name: images})
    return net, input_shape

def dump_results(results, names, root):
    """Save each fetched result to ``<root>/<name>.npy``.

    Args:
        results: sequence of arrays, aligned index-by-index with ``names``.
        names: sequence of output names used as file stems.
        root: output directory; created (including missing parents) if it
            does not exist yet.
    """
    # makedirs (not mkdir) so nested roots work; the isdir guard keeps this
    # compatible with Python 2, which lacks makedirs(exist_ok=...).
    if not os.path.isdir(root):
        os.makedirs(root)

    for name, res in zip(names, results):
        np.save(os.path.join(root, name) + '.npy', res)

def infer(net_file, net_name, model_file, imgfile, debug=False):
    """Run inference with a caffe2fluid-converted model.

    The model consists of a generated network definition (a .py file) and
    its weights file.

    Args:
        net_file: path to the generated network .py file.
        net_name: name of the network class defined in ``net_file``.
        model_file: path to the weights; a path containing '.npy' is loaded
            with an explicit ``place``, anything else without one.
        imgfile: path to the image to classify.
        debug: when true, fetch every layer's output and dump them to
            'results.layers' instead of printing the predicted class.
    """
    #1, build model
    built = build_model(net_file, net_name)
    if built is None:
        # build_model already printed why loading failed.
        return
    net, input_shape = built
    prediction = net.get_output()

    #2, run the startup program, then load weights for this model
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    if model_file.find('.npy') > 0:
        net.load(data_path=model_file, exe=exe, place=place)
    else:
        net.load(data_path=model_file, exe=exe)

    #3, test this model
    test_program = fluid.default_main_program().clone()

    # Normal mode fetches only the final prediction; debug mode fetches every
    # named layer so each output can be dumped under its layer name.
    fetch_list_var = []
    fetch_list_name = []
    if not debug:
        fetch_list_var.append(prediction)
    else:
        for k, v in net.layers.items():
            fetch_list_var.append(v)
            fetch_list_name.append(k)

    np_images = load_data(imgfile, input_shape)
    results = exe.run(program=test_program,
                      feed={'image': np_images},
                      fetch_list=fetch_list_var)

    if debug:
        dump_path = 'results.layers'
        dump_results(results, fetch_list_name, dump_path)
        print('all results dumped to [%s]' % (dump_path))
    else:
        # results[0] is the prediction only in non-debug mode.
        result = results[0]
        print('predicted class:', np.argmax(result))

if __name__ == "__main__":
    # Usage: python <this script> [net_file] [weight_file] [imgfile] [net_name]
    # With no arguments, the ResNet50 defaults below are used. It may be more
    # convenient to call this tool through a wrapper shell script.
    net_file = 'models/resnet50/resnet50.py'
    weight_file = 'models/resnet50/resnet50.npy'
    imgfile = 'data/65.jpeg'
    net_name = 'ResNet50'

    argc = len(sys.argv)
    if argc == 5:
        net_file, weight_file, imgfile, net_name = sys.argv[1:5]
    elif argc > 1:
        # Wrong number of arguments: show usage and stop, rather than
        # silently running the default model.
        print('usage:')
        print('\tpython %s [net_file] [weight_file] [imgfile] [net_name]' % (sys.argv[0]))
        print('\teg:python %s %s %s %s %s' % (sys.argv[0],
            net_file, weight_file, imgfile, net_name))
        sys.exit(1)

    infer(net_file, net_name, weight_file, imgfile)