From f8bc7215ec3d433ac69a1e4d5116dda71c754fec Mon Sep 17 00:00:00 2001 From: guosheng Date: Thu, 8 Jun 2017 02:17:09 +0800 Subject: [PATCH] add README.md --- image_classification/caffe2paddle/README.md | 55 +++++ .../{ => caffe2paddle}/caffe2paddle.py | 54 +++-- image_classification/caffe2paddle/image.py | 223 ++++++++++++++++++ .../caffe2paddle/paddle_resnet.py | 137 +++++++++++ 4 files changed, 444 insertions(+), 25 deletions(-) create mode 100644 image_classification/caffe2paddle/README.md rename image_classification/{ => caffe2paddle}/caffe2paddle.py (86%) create mode 100644 image_classification/caffe2paddle/image.py create mode 100644 image_classification/caffe2paddle/paddle_resnet.py diff --git a/image_classification/caffe2paddle/README.md b/image_classification/caffe2paddle/README.md new file mode 100644 index 00000000..887a5f4f --- /dev/null +++ b/image_classification/caffe2paddle/README.md @@ -0,0 +1,55 @@ +## 使用说明 + +`caffe2paddle.py`提供了将Caffe训练的模型转换为PaddlePaddle可使用的模型的接口`ModelConverter`,其封装了图像领域常用的Convolution、BatchNorm等layer的转换函数,可完成VGG、ResNet等常用模型的转换。模型转换的基本过程是:基于Caffe的Python API加载模型并依次获取每一个layer的信息,将其中的参数根据layer类型与PaddlePaddle适配后序列化保存(对于Pooling等无需训练的layer不做处理),输出可以直接为PaddlePaddle的Python API加载使用的模型文件。 + +`ModelConverter`的定义及说明如下: + +```python +class ModelConverter(object): + #设置Caffe网络配置文件、模型文件路径和要保存为的Paddle模型的文件名,并使用Caffe API加载模型 + def __init__(self, caffe_model_file, caffe_pretrained_file, paddle_tar_name) + + #输出保存Paddle模型 + def to_tar(self, f) + + #将参数值序列化输出为二进制 + @staticmethod + def serialize(data, f) + + #依次对各个layer进行转换,转换时参照name_map进行layer和参数命名 + def convert(self, name_map={}) + + #对Caffe模型的Convolution层的参数进行转换,将使用name值对Paddle模型中对应layer的参数命名 + @wrap_name_default("img_conv_layer") + def convert_Convolution_layer(self, params, name=None) + + #对Caffe模型的InnerProduct层的参数进行转换,将使用name值对Paddle模型中对应layer的参数命名 + @wrap_name_default("fc_layer") + def convert_InnerProduct_layer(self, params, name=None) + + 
#对Caffe模型的BatchNorm层的参数进行转换,将使用name值对Paddle模型中对应layer的参数命名 + @wrap_name_default("batch_norm_layer") + def convert_BatchNorm_layer(self, params, name=None) + + #对Caffe模型的Scale层的参数进行转换,将使用name值对Paddle模型中对应layer的参数命名 + def convert_Scale_layer(self, params, name=None) + + #输入图片路径和均值文件路径,使用加载的Caffe模型进行预测 + def caffe_predict(self, img, mean_file) + +``` + +`ModelConverter`的使用方法如下: + +```python + #指定Caffe网络配置文件、模型文件路径和要保存为的Paddle模型的文件名,并从指定文件加载模型 + converter = ModelConverter("./ResNet-50-deploy.prototxt", + "./ResNet-50-model.caffemodel", + "Paddle_ResNet50.tar.gz") + #进行模型转换 + converter.convert(name_map={}) + #进行预测并输出预测概率以便对比验证模型转换结果 + converter.caffe_predict(img='./caffe/examples/images/cat.jpg') +``` + +为验证并使用转换得到的模型,需基于PaddlePaddle API编写对应的网络结构配置文件,具体可参照PaddlePaddle使用文档,我们这里附上ResNet的配置以供使用。需要注意,上文给出的模型转换在调用`ModelConverter.convert`时传入了空的`name_map`,这将在遍历每一个layer进行参数保存时使用PaddlePaddle默认的layer和参数命名规则:以`wrap_name_default`中的值和调用计数构造layer name,并以此为前缀构造参数名(比如第一个InnerProduct层的bias参数将被命名为`___fc_layer_0__.wbias`);为此,在编写PaddlePaddle网络配置时要保证和Caffe端模型使用同样的拓扑顺序,尤其是对于ResNet这种有分支的网络结构,要保证两分支在PaddlePaddle和Caffe中先后顺序一致,这样才能够使得模型参数正确加载。如果不希望使用默认的layer name,可以使用一种更为精细的方法:建立Caffe和PaddlePaddle网络配置间layer name对应关系的`dict`并在调用`ModelConverter.convert`时作为`name_map`传入,这样在命名保存layer中的参数时将使用相应的layer name,另外这里只针对Caffe网络配置中Convolution、InnerProduct和BatchNorm类别的layer建立`name_map`即可(一方面,对于Pooling等无需训练的layer不需要保存,故这里没有提供转换接口;另一方面,对于Caffe中的Scale类别的layer,由于Caffe和PaddlePaddle在实现上的一些差别,PaddlePaddle中的batch_norm层同时包含BatchNorm和Scale层的复合,故这里对Scale进行了特殊处理)。 diff --git a/image_classification/caffe2paddle.py b/image_classification/caffe2paddle/caffe2paddle.py similarity index 86% rename from image_classification/caffe2paddle.py rename to image_classification/caffe2paddle/caffe2paddle.py index 9157a32a..4d331d1d 100644 --- a/image_classification/caffe2paddle.py +++ b/image_classification/caffe2paddle/caffe2paddle.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import os import functools import inspect @@ -9,6 +8,7 @@ import 
cStringIO import numpy as np import caffe from paddle.proto.ParameterConfig_pb2 import ParameterConfig +from image import load_and_transform def __default_not_set_callback__(kwargs, name): @@ -90,15 +90,16 @@ def wrap_name_default(name_prefix=None, name_param="name"): class ModelConverter(object): def __init__(self, caffe_model_file, caffe_pretrained_file, - paddle_tar_name): + paddle_output_path, paddle_tar_name): self.net = caffe.Net(caffe_model_file, caffe_pretrained_file, caffe.TEST) + self.output_path = paddle_output_path self.tar_name = paddle_tar_name self.params = dict() self.pre_layer_name = "" self.pre_layer_type = "" - def convert(self): + def convert(self, name_map={}): layer_dict = self.net.layer_dict for layer_name in layer_dict.keys(): layer = layer_dict[layer_name] @@ -106,7 +107,10 @@ class ModelConverter(object): layer_type = layer.type if len(layer_params) > 0: self.pre_layer_name = getattr( - self, "convert_" + layer_type + "_layer")(layer_params) + self, "convert_" + layer_type + "_layer")( + layer_params, + name=None + if name_map == None else name_map.get(layer_name)) self.pre_layer_type = layer_type with gzip.open(self.tar_name, 'w') as f: self.to_tar(f) @@ -136,7 +140,7 @@ class ModelConverter(object): f.write(struct.pack("IIQ", 0, 4, data.size)) f.write(data.tobytes()) - @wrap_name_default("conv") + @wrap_name_default("img_conv_layer") def convert_Convolution_layer(self, params, name=None): for i in range(len(params)): data = np.array(params[i].data) @@ -149,6 +153,7 @@ class ModelConverter(object): param_conf.name = file_name param_conf.size = reduce(lambda a, b: a * b, data.shape) self.params[file_name] = (param_conf, data.flatten()) + return name @wrap_name_default("fc_layer") @@ -171,9 +176,10 @@ class ModelConverter(object): self.params[file_name] = (param_conf, data.flatten()) return name - @wrap_name_default("batch_norm") + @wrap_name_default("batch_norm_layer") def convert_BatchNorm_layer(self, params, name=None): - scale = 
np.array(params[-1].data) + scale = 1 / np.array(params[-1].data)[0] if np.array( + params[-1].data)[0] != 0 else 0 for i in range(2): data = np.array(params[i].data) * scale file_name = "_%s.w%s" % (name, str(i + 1)) @@ -210,19 +216,7 @@ class ModelConverter(object): mean_file='./caffe/imagenet/ilsvrc_2012_mean.npy'): net = self.net - mu = np.load(mean_file) - mu = mu.mean(1).mean(1) - - transformer = caffe.io.Transformer({ - 'data': net.blobs['data'].data.shape - }) - transformer.set_transpose('data', (2, 0, 1)) - transformer.set_mean('data', mu) - transformer.set_raw_scale('data', 255) - transformer.set_channel_swap('data', (2, 1, 0)) - im = caffe.io.load_image(img) - - net.blobs['data'].data[...] = transformer.preprocess('data', im) + net.blobs['data'].data[...] = load_img(img, mean_file) out = net.forward() output_prob = net.blobs['prob'].data[0].flatten() @@ -231,9 +225,19 @@ class ModelConverter(object): print 'predicted class is:', output_prob.argmax() +def load_image(file, mean_file): + im = load_and_transform(file, 256, 224, is_train=False) + im = im[(2, 1, 0), :, :] + mu = np.load(mean_file) + mu = mu.mean(1).mean(1) + im = im - mu[:, None, None] + im = im / 255.0 + return im + + if __name__ == "__main__": - converter = ModelConverter("./VGG_ILSVRC_16_layers_deploy.prototxt", - "./VGG_ILSVRC_16_layers.caffemodel", - "test_vgg16.tar.gz") - converter.convert() - converter.caffe_predict(img='./caffe/examples/images/cat.jpg') + converter = ModelConverter("./resnet50/ResNet-50-deploy.prototxt", + "./resnet50/ResNet-50-model.caffemodel", + "paddle_resnet50.tar.gz") + converter.convert(name_map=dict()) + converter.caffe_predict("./images/cat.jpg") diff --git a/image_classification/caffe2paddle/image.py b/image_classification/caffe2paddle/image.py new file mode 100644 index 00000000..5fd51704 --- /dev/null +++ b/image_classification/caffe2paddle/image.py @@ -0,0 +1,223 @@ +import numpy as np +try: + import cv2 +except: + print( + "import cv2 error, please 
install opencv-python: pip install opencv-python" + ) + +__all__ = [ + "load_image", "resize_short", "to_chw", "center_crop", "random_crop", + "left_right_flip", "simple_transform", "load_and_transform" +] +""" +This file contains some common interfaces for image preprocessing. +Many users are confused about the image layout. We introduce +the image layout as follows. +- CHW Layout + - The abbreviations: C=channel, H=Height, W=Width + - The default layout of image opened by cv2 or PIL is HWC. + PaddlePaddle only supports the CHW layout. And CHW is simply + a transpose of HWC. The input image must be transposed. +- Color format: RGB or BGR + OpenCV uses BGR color format. PIL uses RGB color format. Both + formats can be used for training. Note that the format should + be kept consistent between the training and inference period. +""" + + +def load_image(file, is_color=True): + """ + Load a color or gray image from the file path. + Example usage: + + .. code-block:: python + im = load_image('cat.jpg') + :param file: the input image path. + :type file: string + :param is_color: If set is_color True, it will load and + return a color image. Otherwise, it will + load and return a gray image. + """ + # cv2.IMREAD_COLOR for OpenCV3 + # cv2.CV_LOAD_IMAGE_COLOR for older OpenCV Version + # cv2.IMREAD_GRAYSCALE for OpenCV3 + # cv2.CV_LOAD_IMAGE_GRAYSCALE for older OpenCV Version + # Here, use constant 1 and 0 + # 1: COLOR, 0: GRAYSCALE + flag = 1 if is_color else 0 + im = cv2.imread(file, flag) + return im + + +def resize_short(im, size): + """ + Resize an image so that the length of shorter edge is size. + Example usage: + + .. code-block:: python + im = load_image('cat.jpg') + im = resize_short(im, 256) + + :param im: the input image with HWC layout. + :type im: ndarray + :param size: the shorter edge size of image after resizing.
+ :type size: int + """ + assert im.shape[-1] == 1 or im.shape[-1] == 3 + h, w = im.shape[:2] + h_new, w_new = size, size + if h > w: + h_new = size * h / w + else: + w_new = size * w / h + im = cv2.resize(im, (h_new, w_new), interpolation=cv2.INTER_CUBIC) + return im + + +def to_chw(im, order=(2, 0, 1)): + """ + Transpose the input image order. The image layout is HWC format + opened by cv2 or PIL. Transpose the input image to CHW layout + according to the order (2,0,1). + Example usage: + + .. code-block:: python + im = load_image('cat.jpg') + im = resize_short(im, 256) + im = to_chw(im) + + :param im: the input image with HWC layout. + :type im: ndarray + :param order: the transposed order. + :type order: tuple|list + """ + assert len(im.shape) == len(order) + im = im.transpose(order) + return im + + +def center_crop(im, size, is_color=True): + """ + Crop the center of image with size. + Example usage: + + .. code-block:: python + im = center_crop(im, 224) + + :param im: the input image with HWC layout. + :type im: ndarray + :param size: the cropping size. + :type size: int + :param is_color: whether the image is color or not. + :type is_color: bool + """ + h, w = im.shape[:2] + h_start = (h - size) / 2 + w_start = (w - size) / 2 + h_end, w_end = h_start + size, w_start + size + if is_color: + im = im[h_start:h_end, w_start:w_end, :] + else: + im = im[h_start:h_end, w_start:w_end] + return im + + +def random_crop(im, size, is_color=True): + """ + Randomly crop input image with size. + Example usage: + + .. code-block:: python + im = random_crop(im, 224) + + :param im: the input image with HWC layout. + :type im: ndarray + :param size: the cropping size. + :type size: int + :param is_color: whether the image is color or not.
+ :type is_color: bool + """ + h, w = im.shape[:2] + h_start = np.random.randint(0, h - size + 1) + w_start = np.random.randint(0, w - size + 1) + h_end, w_end = h_start + size, w_start + size + if is_color: + im = im[h_start:h_end, w_start:w_end, :] + else: + im = im[h_start:h_end, w_start:w_end] + return im + + +def left_right_flip(im): + """ + Flip an image along the horizontal direction. + Return the flipped image. + Example usage: + + .. code-block:: python + im = left_right_flip(im) + + :param im: input image with HWC layout + :type im: ndarray + """ + if len(im.shape) == 3: + return im[:, ::-1, :] + else: + return im[:, ::-1, :] + + +def simple_transform(im, resize_size, crop_size, is_train, is_color=True): + """ + Simple data augmentation for training. These operations include + resizing, cropping and flipping. + Example usage: + + .. code-block:: python + im = simple_transform(im, 256, 224, True) + :param im: The input image with HWC layout. + :type im: ndarray + :param resize_size: The shorter edge length of the resized image. + :type resize_size: int + :param crop_size: The cropping size. + :type crop_size: int + :param is_train: Whether it is training or not. + :type is_train: bool + """ + im = resize_short(im, resize_size) + if is_train: + im = random_crop(im, crop_size) + if np.random.randint(2) == 0: + im = left_right_flip(im) + else: + im = center_crop(im, crop_size) + im = to_chw(im) + + return im + + +def load_and_transform(filename, + resize_size, + crop_size, + is_train, + is_color=True): + """ + Load image from the input file `filename` and transform image for + data augmentation. Please refer to the `simple_transform` interface + for the transform operations. + Example usage: + + .. code-block:: python + im = load_and_transform('cat.jpg', 256, 224, True) + :param filename: The file name of input image. + :type filename: string + :param resize_size: The shorter edge length of the resized image.
+ :type resize_size: int + :param crop_size: The cropping size. + :type crop_size: int + :param is_train: Whether it is training or not. + :type is_train: bool + """ + im = load_image(filename) + im = simple_transform(im, resize_size, crop_size, is_train, is_color) + return im diff --git a/image_classification/caffe2paddle/paddle_resnet.py b/image_classification/caffe2paddle/paddle_resnet.py new file mode 100644 index 00000000..7d9a9e43 --- /dev/null +++ b/image_classification/caffe2paddle/paddle_resnet.py @@ -0,0 +1,137 @@ +from PIL import Image +import gzip +import numpy as np +import paddle.v2 as paddle +from image import load_and_transform + +__all__ = ['resnet_imagenet', 'resnet_cifar10'] + + +def conv_bn_layer(input, + ch_out, + filter_size, + stride, + padding, + active_type=paddle.activation.Relu(), + ch_in=None): + tmp = paddle.layer.img_conv( + input=input, + filter_size=filter_size, + num_channels=ch_in, + num_filters=ch_out, + stride=stride, + padding=padding, + act=paddle.activation.Linear(), + bias_attr=False) + return paddle.layer.batch_norm(input=tmp, act=active_type) + + +def shortcut(input, n_out, stride, b_projection): + if b_projection: + return conv_bn_layer(input, n_out, 1, stride, 0, + paddle.activation.Linear()) + else: + return input + + +def basicblock(input, ch_out, stride, b_projection): + # TODO: bug fix for ch_in = input.num_filters + conv1 = conv_bn_layer(input, ch_out, 3, stride, 1) + conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1, paddle.activation.Linear()) + short = shortcut(input, ch_out, stride, b_projection) + return paddle.layer.addto( + input=[conv2, short], act=paddle.activation.Relu()) + + +def bottleneck(input, ch_out, stride, b_projection): + # TODO: bug fix for ch_in = input.num_filters + short = shortcut(input, ch_out * 4, stride, b_projection) + conv1 = conv_bn_layer(input, ch_out, 1, stride, 0) + conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1) + conv3 = conv_bn_layer(conv2, ch_out * 4, 1, 1, 0, + 
paddle.activation.Linear()) + return paddle.layer.addto( + input=[conv3, short], act=paddle.activation.Relu()) + + +def layer_warp(block_func, input, features, count, stride): + conv = block_func(input, features, stride, True) + for i in range(1, count): + conv = block_func(conv, features, 1, False) + return conv + + +def resnet_imagenet(input, depth=50): + cfg = { + 18: ([2, 2, 2, 1], basicblock), + 34: ([3, 4, 6, 3], basicblock), + 50: ([3, 4, 6, 3], bottleneck), + 101: ([3, 4, 23, 3], bottleneck), + 152: ([3, 8, 36, 3], bottleneck) + } + stages, block_func = cfg[depth] + conv1 = conv_bn_layer( + input, ch_in=3, ch_out=64, filter_size=7, stride=2, padding=3) + pool1 = paddle.layer.img_pool(input=conv1, pool_size=3, stride=2) + res1 = layer_warp(block_func, pool1, 64, stages[0], 1) + res2 = layer_warp(block_func, res1, 128, stages[1], 2) + res3 = layer_warp(block_func, res2, 256, stages[2], 2) + res4 = layer_warp(block_func, res3, 512, stages[3], 2) + pool2 = paddle.layer.img_pool( + input=res4, pool_size=7, stride=1, pool_type=paddle.pooling.Avg()) + return pool2 + + +def resnet_cifar10(input, depth=32): + # depth should be one of 20, 32, 44, 56, 110, 1202 + assert (depth - 2) % 6 == 0 + n = (depth - 2) / 6 + nStages = {16, 64, 128} + conv1 = conv_bn_layer( + input, ch_in=3, ch_out=16, filter_size=3, stride=1, padding=1) + res1 = layer_warp(basicblock, conv1, 16, n, 1) + res2 = layer_warp(basicblock, res1, 32, n, 2) + res3 = layer_warp(basicblock, res2, 64, n, 2) + pool = paddle.layer.img_pool( + input=res3, pool_size=8, stride=1, pool_type=paddle.pooling.Avg()) + return pool + + +def load_image(file, mean_file): + im = load_and_transform(file, 256, 224, is_train=False) + im = im[(2, 1, 0), :, :] + mu = np.load(mean_file) + mu = mu.mean(1).mean(1) + im = im - mu[:, None, None] + im = im.flatten() + im = im / 255.0 + return im + + +DATA_DIM = 3 * 224 * 224 +CLASS_DIM = 1000 +BATCH_SIZE = 128 + +MODEL_FILE = 'paddle_resnet50.tar.gz' + +if __name__ == "__main__": + 
paddle.init(use_gpu=False, trainer_count=1) + + img = paddle.layer.data( + "image", type=paddle.data_type.dense_vector(DATA_DIM)) + out = paddle.layer.fc( + input=resnet_imagenet(img, 50), + size=1000, + act=paddle.activation.Softmax()) + + parameters = paddle.parameters.Parameters.from_tar(gzip.open(MODEL_FILE)) + + test_data = [] + test_data.append((load_image("./images/cat.jpg"), )) + output_prob = paddle.infer( + output_layer=out, parameters=parameters, input=test_data, + field="value")[0] + + print np.sort(output_prob)[::-1] + print np.argsort(output_prob)[::-1] + print 'predicted class is:', output_prob.argmax() -- GitLab