caffe2paddle.py 6.7 KB
Newer Older
G
guosheng 已提交
1 2
import os
import struct
G
guosheng 已提交
3 4 5
import gzip
import tarfile
import cStringIO
G
guosheng 已提交
6
import numpy as np
G
guosheng 已提交
7
import cv2
G
guosheng 已提交
8
import caffe
G
guosheng 已提交
9
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
G
guosheng 已提交
10
from paddle.trainer_config_helpers.default_decorators import wrap_name_default
G
guosheng 已提交
11 12 13 14


class ModelConverter(object):
    """Convert a pretrained Caffe model into a PaddlePaddle parameter archive.

    Parameters of every Caffe layer that owns blobs are collected into
    ``self.params`` (Paddle parameter name -> ``(ParameterConfig, flat
    ndarray)``) and written by :meth:`convert` as a gzip-compressed tar file.
    """

    def __init__(self, caffe_model_file, caffe_pretrained_file,
                 paddle_tar_name):
        """Load the Caffe network in TEST phase.

        :param caffe_model_file: Caffe network definition (``.prototxt``).
        :param caffe_pretrained_file: Caffe weights (``.caffemodel``).
        :param paddle_tar_name: path of the gzip-compressed tar to write.
        """
        self.net = caffe.Net(caffe_model_file, caffe_pretrained_file,
                             caffe.TEST)
        self.tar_name = paddle_tar_name
        # Paddle parameter name -> (ParameterConfig, flattened ndarray).
        self.params = dict()
        # Name returned by the most recent convert_*_layer call; Scale layers
        # store their parameters under this (BatchNorm) name.
        self.pre_layer_name = ""
        # Caffe type of the previously visited layer (with or without blobs).
        self.pre_layer_type = ""

    def convert(self, name_map=None):
        """Convert every Caffe layer that has blobs and write the archive.

        Layers are visited in ``net.layer_dict`` order. Each layer owning
        blobs is dispatched to ``convert_<CaffeType>_layer``.

        :param name_map: optional dict mapping Caffe layer names to Paddle
            layer names. Layers without an entry (or every layer, when this
            is None) are passed ``name=None``; the ``@wrap_name_default``
            decorators presumably substitute a generated name.
        :raises AttributeError: if a layer with blobs has a Caffe type for
            which no ``convert_<type>_layer`` method exists.
        """
        for layer_name, layer in self.net.layer_dict.items():
            layer_params = layer.blobs
            layer_type = layer.type
            if len(layer_params) > 0:
                converter = getattr(self, "convert_%s_layer" % layer_type,
                                    None)
                if converter is None:
                    raise AttributeError(
                        "Unsupported Caffe layer type '%s' (layer '%s')" %
                        (layer_type, layer_name))
                paddle_name = (None if name_map is None else
                               name_map.get(layer_name))
                self.pre_layer_name = converter(layer_params,
                                                name=paddle_name)
            # Tracked for every layer so Scale can check it directly
            # follows a BatchNorm.
            self.pre_layer_type = layer_type

        with gzip.open(self.tar_name, 'w') as f:
            self.to_tar(f)

    @staticmethod
    def _add_tar_member(tar, member_name, payload):
        """Add the byte string *payload* to *tar* as file *member_name*."""
        info = tarfile.TarInfo(name=member_name)
        info.size = len(payload)
        tar.addfile(info, fileobj=cStringIO.StringIO(payload))

    def to_tar(self, f):
        """Write all collected parameters as a tar archive into file object f.

        Each parameter yields two members: ``<name>.protobuf`` holding the
        serialized ParameterConfig, and ``<name>`` holding the values in the
        layout written by :meth:`serialize`. Members are added in sorted
        name order so the archive content is reproducible.
        """
        # The with-block closes the TarFile, which writes the end-of-archive
        # blocks; the original never closed it and left the archive
        # unterminated.
        with tarfile.TarFile(fileobj=f, mode='w') as tar:
            for param_name in sorted(self.params):
                param_conf, param_data = self.params[param_name]
                self._add_tar_member(tar, "%s.protobuf" % param_name,
                                     param_conf.SerializeToString())
                buf = cStringIO.StringIO()
                self.serialize(param_data, buf)
                self._add_tar_member(tar, param_name, buf.getvalue())

    @staticmethod
    def serialize(data, f):
        """Write *data* to the binary file object *f*.

        Layout: ``struct`` format "IIQ" (native byte order) holding 0, 4
        (presumably format version and bytes per value) and the element
        count, followed by the raw values as float32.
        """
        # The header promises 4-byte values, so make sure the payload really
        # is float32 regardless of the dtype the caller passed in.
        data = np.asarray(data, dtype=np.float32)
        f.write(struct.pack("IIQ", 0, 4, data.size))
        f.write(data.tobytes())

    @staticmethod
    def _param_file_name(name, index, count):
        """Return the Paddle file name of blob *index* out of *count* blobs.

        With exactly two blobs they map to ``_<name>.w0`` / ``_<name>.wbias``,
        otherwise blob i maps to ``_<name>.w<i>``.
        """
        if count == 2:
            suffix = "0" if index == 0 else "bias"
        else:
            suffix = str(index)
        return "_%s.w%s" % (name, suffix)

    def _add_param(self, file_name, data, dims=None):
        """Register *data* (flattened) and its ParameterConfig under file_name.

        :param dims: explicit dims for the config, or None to leave them unset.
        """
        param_conf = ParameterConfig()
        param_conf.name = file_name
        if dims is not None:
            param_conf.dims.extend(dims)
        param_conf.size = int(data.size)
        self.params[file_name] = (param_conf, data.flatten())

    @wrap_name_default("conv")
    def convert_Convolution_layer(self, params, name=None):
        """Register the blobs of a Caffe Convolution layer.

        1-D blobs (presumably the bias) get dims ``[n, 1]``; all other blobs
        are stored without explicit dims.

        :returns: the Paddle layer name used for the parameter file names.
        """
        for i, blob in enumerate(params):
            data = np.array(blob.data)
            dims = [data.shape[0], 1] if data.ndim == 1 else None
            self._add_param(
                self._param_file_name(name, i, len(params)), data, dims)
        return name

    @wrap_name_default("fc_layer")
    def convert_InnerProduct_layer(self, params, name=None):
        """Register the blobs of a Caffe InnerProduct (fully connected) layer.

        Every blob is transposed before storing (presumably Caffe keeps
        (out, in) while Paddle expects (in, out)); 1-D blobs get dims
        ``[1, n]``.

        :returns: the Paddle layer name used for the parameter file names.
        """
        for i, blob in enumerate(params):
            data = np.transpose(np.array(blob.data))
            dims = list(data.shape)
            if len(dims) < 2:
                dims.insert(0, 1)
            self._add_param(
                self._param_file_name(name, i, len(params)), data, dims)
        return name

    @wrap_name_default("batch_norm")
    def convert_BatchNorm_layer(self, params, name=None):
        """Register the first two blobs of a Caffe BatchNorm layer.

        Blobs 0 and 1 are multiplied by ``1 / params[-1].data[0]`` (or by 0
        when that factor is 0) and stored as ``_<name>.w1`` / ``_<name>.w2``
        with dims ``[1, n]``. Presumably they are the accumulated mean and
        variance, and the last blob is Caffe's moving-average factor; verify
        against the Caffe BatchNorm layer definition.

        :returns: the Paddle layer name used for the parameter file names.
        """
        factor = np.array(params[-1].data)[0]
        scale = 1 / factor if factor != 0 else 0
        for i in range(2):
            data = np.array(params[i].data) * scale
            dims = list(data.shape)
            assert len(dims) == 1
            self._add_param("_%s.w%s" % (name, str(i + 1)), data, [1] + dims)
        return name

    def convert_Scale_layer(self, params, name=None):
        """Register a Caffe Scale layer's blobs under the preceding BatchNorm.

        The *name* argument is ignored: the parameters are stored as
        ``_<bn>.w0`` and ``_<bn>.wbias`` where ``<bn>`` is the name returned
        for the BatchNorm layer visited just before. Only the second blob
        gets explicit dims ``[1, n]``.

        :returns: the BatchNorm layer name.
        """
        assert self.pre_layer_type == "BatchNorm"
        name = self.pre_layer_name
        for i, blob in enumerate(params):
            data = np.array(blob.data)
            dims = list(data.shape)
            assert len(dims) == 1
            suffix = "0" if i == 0 else "bias"
            self._add_param("_%s.w%s" % (name, suffix), data,
                            [1] + dims if i == 1 else None)
        return name

    def caffe_predict(self,
                      img,
                      mean_file='./caffe/imagenet/ilsvrc_2012_mean.npy'):
        """Run the Caffe net on one image and print its class probabilities.

        Prints a list of ``(class index, probability)`` pairs sorted by
        descending probability, read from the net's ``prob`` blob.
        """
        net = self.net
        net.blobs['data'].data[...] = load_image(img, mean_file=mean_file)
        net.forward()

        output_prob = net.blobs['prob'].data[0].flatten()
        # Same output as the py2 ``print zip(...)`` statement.
        print(list(zip(np.argsort(output_prob)[::-1],
                       np.sort(output_prob)[::-1])))


def _resize_dsize(h, w, resize_size):
    """Return the cv2 ``(width, height)`` that scales the shorter side.

    The shorter of *h* / *w* becomes *resize_size*; the longer side is scaled
    proportionally with integer (floor) arithmetic.
    """
    if h > w:
        return resize_size, resize_size * h // w
    return resize_size * w // h, resize_size


def load_image(file, resize_size=256, crop_size=224, mean_file=None):
    """Read an image and preprocess it into a CHW float array.

    Steps: resize so the shorter side equals ``resize_size`` (aspect ratio
    kept); center-crop ``crop_size`` x ``crop_size``; transpose HWC -> CHW;
    optionally subtract a per-channel mean; finally divide by 255.

    :param file: image path, read with ``cv2.imread``.
    :param resize_size: target length of the shorter image side.
    :param crop_size: side length of the center crop (expected to be
        <= ``resize_size``).
    :param mean_file: optional ``.npy`` array; its mean over axes 1 and 2
        is subtracted channel-wise.
    :returns: float ndarray of shape (channels, crop_size, crop_size).
    :raises IOError: if the image cannot be read.
    """
    im = cv2.imread(file)
    if im is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError("Cannot read image file: %s" % file)

    h, w = im.shape[:2]
    # cv2.resize takes dsize as (width, height); the original passed
    # (height, width), which swapped the aspect ratio of non-square images.
    im = cv2.resize(im, _resize_dsize(h, w, resize_size),
                    interpolation=cv2.INTER_CUBIC)

    # Center crop; floor division keeps the offsets integral on py2 and py3.
    h, w = im.shape[:2]
    h_start = (h - crop_size) // 2
    w_start = (w - crop_size) // 2
    im = im[h_start:h_start + crop_size, w_start:w_start + crop_size, :]

    # HWC -> CHW
    im = im.transpose((2, 0, 1))

    if mean_file:
        mu = np.load(mean_file).mean(1).mean(1)
        im = im - mu[:, None, None]
    # NOTE(review): scaling by 1/255 happens after mean subtraction; confirm
    # this is the input range the Caffe/Paddle models expect.
    im = im / 255.0
    return im


if __name__ == "__main__":
    import argparse

    # Defaults reproduce the previously hard-coded ResNet-50 conversion, so
    # running the script without arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description="Convert a Caffe model into a PaddlePaddle parameter "
        "tarball and print Caffe's predictions for a sample image.")
    parser.add_argument("--caffe_model_file",
                        default="./ResNet-50-deploy.prototxt",
                        help="Caffe network definition (.prototxt)")
    parser.add_argument("--caffe_pretrained_file",
                        default="./ResNet-50-model.caffemodel",
                        help="Caffe weights (.caffemodel)")
    parser.add_argument("--paddle_tar_name",
                        default="Paddle_ResNet50.tar.gz",
                        help="output gzip-compressed tar of Paddle parameters")
    parser.add_argument("--test_image",
                        default="./cat.jpg",
                        help="image to run through the Caffe net after "
                        "conversion (empty string to skip)")
    parser.add_argument("--mean_file",
                        default="./caffe/imagenet/ilsvrc_2012_mean.npy",
                        help="mean .npy file used when preprocessing the image")
    args = parser.parse_args()

    converter = ModelConverter(
        caffe_model_file=args.caffe_model_file,
        caffe_pretrained_file=args.caffe_pretrained_file,
        paddle_tar_name=args.paddle_tar_name)
    converter.convert()

    if args.test_image:
        # Print the Caffe net's sorted class probabilities for the image.
        converter.caffe_predict(args.test_image, args.mean_file)