diff --git a/image_classification/caffe2paddle/README.md b/image_classification/caffe2paddle/README.md index 887a5f4f5592f8ba8575fca732039fafb25430a9..c90e000186e974803494cd5d25df1fc71004c37b 100644 --- a/image_classification/caffe2paddle/README.md +++ b/image_classification/caffe2paddle/README.md @@ -1,55 +1,39 @@ ## 使用说明 -`caffe2paddle.py`提供了将Caffe训练的模型转换为PaddlePaddle可使用的模型的接口`ModelConverter`,其封装了图像领域常用的Convolution、BatchNorm等layer的转换函数,可完成VGG、ResNet等常用模型的转换。模型转换的基本过程是:基于Caffe的Python API加载模型并依次获取每一个layer的信息,将其中的参数根据layer类型与PaddlePaddle适配后序列化保存(对于Pooling等无需训练的layer不做处理),输出可以直接为PaddlePaddle的Python API加载使用的模型文件。 +`caffe2paddle.py`提供了将Caffe训练的模型转换为PaddlePaddle可使用的模型的接口`ModelConverter`,其封装了图像领域常用的Convolution、BatchNorm等layer的转换函数,可以完成VGG、ResNet等常用模型的转换。模型转换的基本过程是:基于Caffe的Python API加载模型并依次获取每一个layer的信息,将其中的参数根据layer类型与PaddlePaddle适配后序列化保存(对于Pooling等无需训练的layer不做处理),输出可以直接为PaddlePaddle的Python API加载使用的模型文件。 -`ModelConverter`的定义及说明如下: +可以按如下方法使用`ModelConverter`接口: ```python -class ModelConverter(object): - #设置Caffe网络配置文件、模型文件路径和要保存为的Paddle模型的文件名,并使用Caffe API加载模型 - def __init__(self, caffe_model_file, caffe_pretrained_file, paddle_tar_name) - - #输出保存Paddle模型 - def to_tar(self, f) - - #将参数值序列化输出为二进制 - @staticmethod - def serialize(data, f) - - #依次对各个layer进行转换,转换时参照name_map进行layer和参数命名 - def convert(self, name_map={}) - - #对Caffe模型的Convolution层的参数进行转换,将使用name值对Paddle模型中对应layer的参数命名 - @wrap_name_default("img_conv_layer") - def convert_Convolution_layer(self, params, name=None) - - #对Caffe模型的InnerProduct层的参数进行转换,将使用name值对Paddle模型中对应layer的参数命名 - @wrap_name_default("fc_layer") - def convert_InnerProduct_layer(self, params, name=None) - - #对Caffe模型的BatchNorm层的参数进行转换,将使用name值对Paddle模型中对应layer的参数命名 - @wrap_name_default("batch_norm_layer") - def convert_BatchNorm_layer(self, params, name=None) - - #对Caffe模型的Scale层的参数进行转换,将使用name值对Paddle模型中对应layer的参数命名 - def convert_Scale_layer(self, params, name=None) +# 定义以下变量为相应的文件路径和文件名 +caffe_model_file = "./ResNet-50-deploy.prototxt" # Caffe网络配置文件的路径 +caffe_pretrained_file = "./ResNet-50-model.caffemodel" # Caffe模型文件的路径 +paddle_tar_name = "Paddle_ResNet50.tar.gz" # 输出的Paddle模型的文件名 + +# 初始化,从指定文件加载模型 +converter = ModelConverter(caffe_model_file=caffe_model_file, + caffe_pretrained_file=caffe_pretrained_file, + paddle_tar_name=paddle_tar_name) +# 进行模型转换 +converter.convert() +``` - #输入图片路径和均值文件路径,使用加载的Caffe模型进行预测 - def caffe_predict(self, img, mean_file) +`caffe2paddle.py`中已提供以上步骤,修改其中文件相关变量的值后执行`python caffe2paddle.py`即可完成模型转换。此外,为辅助验证转换结果,`ModelConverter`中封装了使用Caffe API预测的接口`caffe_predict`,使用如下所示,将会打印按类别概率排序的(类别id, 概率)的列表: +```python +# img为图片路径,mean_file为图像均值文件的路径 +converter.caffe_predict(img="./cat.jpg", mean_file="./imagenet/ilsvrc_2012_mean.npy") ``` -`ModelConverter`的使用方法如下: +需要注意,在模型转换时会对layer的参数进行命名,这里默认使用PaddlePaddle中默认的layer和参数命名规则:以`wrap_name_default`中的值和该layer类型的调用计数构造layer name,并以此为前缀构造参数名,比如第一个InnerProduct层(相应转换函数说明见下方)的bias参数将被命名为`___fc_layer_0__.wbias`。 ```python - #指定Caffe网络配置文件、模型文件路径和要保存为的Paddle模型的文件名,并从指定文件加载模型 - converter = ModelConverter("./ResNet-50-deploy.prototxt", - "./ResNet-50-model.caffemodel", - "Paddle_ResNet50.tar.gz") - #进行模型转换 - converter.convert(name_map={}) - #进行预测并输出预测概率以便对比验证模型转换结果 - converter.caffe_predict(img='./caffe/examples/images/cat.jpg') +# 对InnerProduct层的参数进行转换,使用name值构造对应layer的参数名 +# wrap_name_default设置默认name值为fc_layer +@wrap_name_default("fc_layer") +def convert_InnerProduct_layer(self, params, name=None) ``` -为验证并使用转换得到的模型,需基于PaddlePaddle API编写对应的网络结构配置文件,具体可参照PaddlePaddle使用文档,我们这里附上ResNet的配置以供使用。需要注意,上文给出的模型转换在调用`ModelConverter.convert`时传入了空的`name_map`,这将在遍历每一个layer进行参数保存时使用PaddlePaddle默认的layer和参数命名规则:以`wrap_name_default`中的值和调用计数构造layer name,并以此为前缀构造参数名(比如第一个InnerProduct层的bias参数将被命名为`___fc_layer_0__.wbias`);为此,在编写PaddlePaddle网络配置时要保证和Caffe端模型使用同样的拓扑顺序,尤其是对于ResNet这种有分支的网络结构,要保证两分支在PaddlePaddle和Caffe中先后顺序一致,这样才能够使得模型参数正确加载。如果不希望使用默认的layer name,可以使用一种更为精细的方法:建立Caffe和PaddlePaddle网络配置间layer name对应关系的`dict`并在调用`ModelConverter.convert`时作为`name_map`传入,这样在命名保存layer中的参数时将使用相应的layer name,另外这里只针对Caffe网络配置中Convolution、InnerProduct和BatchNorm类别的layer建立`name_map`即可(一方面,对于Pooling等无需训练的layer不需要保存,故这里没有提供转换接口;另一方面,对于Caffe中的Scale类别的layer,由于Caffe和PaddlePaddle在实现上的一些差别,PaddlePaddle中的batch_norm层同时包含BatchNorm和Scale层的复合,故这里对Scale进行了特殊处理)。 +为此,在验证和使用转换得到的模型时,编写PaddlePaddle网络配置无需指定layer name并且要保证和Caffe端模型使用同样的拓扑顺序,尤其是对于ResNet这种有分支的网络结构,要保证两分支在PaddlePaddle和Caffe中先后顺序一致,这样才能够使得模型参数正确加载。 + +如果不希望使用默认的命名,并且在PaddlePaddle网络配置中指定了layer name,可以建立Caffe和PaddlePaddle网络配置间layer name对应关系的`dict`并在调用`ModelConverter.convert`时作为`name_map`的值传入,这样在命名保存layer中的参数时将使用相应的layer name,不受拓扑顺序的影响。另外这里只针对Caffe网络配置中Convolution、InnerProduct和BatchNorm类别的layer建立`name_map`即可(一方面,对于Pooling等无需训练的layer不需要保存,故这里没有提供转换接口;另一方面,对于Caffe中的Scale类别的layer,由于Caffe和PaddlePaddle在实现上的一些差别,PaddlePaddle中的batch_norm层是BatchNorm和Scale层的复合,故这里对Scale进行了特殊处理)。 diff --git a/image_classification/caffe2paddle/caffe2paddle.py b/image_classification/caffe2paddle/caffe2paddle.py index 4d331d1d08104b379fbe66ad048727c75f79722c..4d01a3c2dc9125ba72fa5d561187daef5cc8547e 100644 --- a/image_classification/caffe2paddle/caffe2paddle.py +++ b/image_classification/caffe2paddle/caffe2paddle.py @@ -1,105 +1,26 @@ import os -import functools -import inspect import struct import gzip import tarfile import cStringIO import numpy as np +import cv2 import caffe from paddle.proto.ParameterConfig_pb2 import ParameterConfig -from image import load_and_transform - - -def __default_not_set_callback__(kwargs, name): - return name not in kwargs or kwargs[name] is None - - -def wrap_param_default(param_names=None, - default_factory=None, - not_set_callback=__default_not_set_callback__): - assert param_names is not None - assert isinstance(param_names, list) or isinstance(param_names, tuple) - for each_param_name in param_names: - assert isinstance(each_param_name, basestring) - - def __impl__(func): - @functools.wraps(func) - def __wrapper__(*args, **kwargs): - if len(args) != 0: - argspec = inspect.getargspec(func) - num_positional = len(argspec.args) - if argspec.defaults: - num_positional -= len(argspec.defaults) - assert argspec.varargs or len( - args - ) <= num_positional, "Must use keyword arguments for non-positional args" - for name in param_names: - if not_set_callback(kwargs, name): # Not set - kwargs[name] = default_factory(func) - return func(*args, **kwargs) - - if hasattr(func, "argspec"): - __wrapper__.argspec = func.argspec - else: - __wrapper__.argspec = inspect.getargspec(func) - return __wrapper__ - - return __impl__ - - -class DefaultNameFactory(object): - def __init__(self, name_prefix): - self.__counter__ = 0 - self.__name_prefix__ = name_prefix - - def __call__(self, func): - if self.__name_prefix__ is None: - self.__name_prefix__ = func.__name__ - tmp = "__%s_%d__" % (self.__name_prefix__, self.__counter__) - self.__check_name__(tmp) - self.__counter__ += 1 - return tmp - - def __check_name__(self, nm): - pass - - def reset(self): - self.__counter__ = 0 - - -def wrap_name_default(name_prefix=None, name_param="name"): - """ - Decorator to set "name" arguments default to "{name_prefix}_{invoke_count}". - - .. code:: python - - @wrap_name_default("some_name") - def func(name=None): - print name # name will never be None. If name is not set, - # name will be "some_name_%d" - - :param name_prefix: name prefix. wrapped function"s __name__ if None. - :type name_prefix: basestring - :return: a decorator to set default name - :rtype: callable - """ - factory = DefaultNameFactory(name_prefix) - return wrap_param_default([name_param], factory) +from paddle.trainer_config_helpers.default_decorators import wrap_name_default class ModelConverter(object): def __init__(self, caffe_model_file, caffe_pretrained_file, - paddle_output_path, paddle_tar_name): + paddle_tar_name): self.net = caffe.Net(caffe_model_file, caffe_pretrained_file, caffe.TEST) - self.output_path = paddle_output_path self.tar_name = paddle_tar_name self.params = dict() self.pre_layer_name = "" self.pre_layer_type = "" - def convert(self, name_map={}): + def convert(self, name_map=None): layer_dict = self.net.layer_dict for layer_name in layer_dict.keys(): layer = layer_dict[layer_name] @@ -216,28 +137,51 @@ class ModelConverter(object): mean_file='./caffe/imagenet/ilsvrc_2012_mean.npy'): net = self.net - net.blobs['data'].data[...] = load_img(img, mean_file) + net.blobs['data'].data[...] = load_image(img, mean_file=mean_file) out = net.forward() output_prob = net.blobs['prob'].data[0].flatten() - print np.sort(output_prob)[::-1] - print np.argsort(output_prob)[::-1] - print 'predicted class is:', output_prob.argmax() - - -def load_image(file, mean_file): - im = load_and_transform(file, 256, 224, is_train=False) - im = im[(2, 1, 0), :, :] - mu = np.load(mean_file) - mu = mu.mean(1).mean(1) - im = im - mu[:, None, None] + print zip(np.argsort(output_prob)[::-1], np.sort(output_prob)[::-1]) + + +def load_image(file, resize_size=256, crop_size=224, mean_file=None): + # load image + im = cv2.imread(file) + # resize + h, w = im.shape[:2] + h_new, w_new = resize_size, resize_size + if h > w: + h_new = resize_size * h / w + else: + w_new = resize_size * w / h + im = cv2.resize(im, (h_new, w_new), interpolation=cv2.INTER_CUBIC) + # crop + h, w = im.shape[:2] + h_start = (h - crop_size) / 2 + w_start = (w - crop_size) / 2 + h_end, w_end = h_start + crop_size, w_start + crop_size + im = im[h_start:h_end, w_start:w_end, :] + # transpose to CHW order + im = im.transpose((2, 0, 1)) + + if mean_file: + mu = np.load(mean_file) + mu = mu.mean(1).mean(1) + im = im - mu[:, None, None] im = im / 255.0 return im if __name__ == "__main__": - converter = ModelConverter("./resnet50/ResNet-50-deploy.prototxt", - "./resnet50/ResNet-50-model.caffemodel", - "paddle_resnet50.tar.gz") - converter.convert(name_map=dict()) - converter.caffe_predict("./images/cat.jpg") + caffe_model_file = "./ResNet-50-deploy.prototxt" + caffe_pretrained_file = "./ResNet-50-model.caffemodel" + paddle_tar_name = "Paddle_ResNet50.tar.gz" + + converter = ModelConverter( + caffe_model_file=caffe_model_file, + caffe_pretrained_file=caffe_pretrained_file, + paddle_tar_name=paddle_tar_name) + converter.convert() + + converter.caffe_predict("./cat.jpg", + "./caffe/imagenet/ilsvrc_2012_mean.npy") diff --git a/image_classification/caffe2paddle/image.py b/image_classification/caffe2paddle/image.py deleted file mode 100644 index 5fd51704f19b06423f85239824379dd5c7c3c52e..0000000000000000000000000000000000000000 --- a/image_classification/caffe2paddle/image.py +++ /dev/null @@ -1,223 +0,0 @@ -import numpy as np -try: - import cv2 -except: - print( - "import cv2 error, please install opencv-python: pip install opencv-python" - ) - -__all__ = [ - "load_image", "resize_short", "to_chw", "center_crop", "random_crop", - "left_right_flip", "simple_transform", "load_and_transform" -] -""" -This file contains some common interfaces for image preprocess. -Many users are confused about the image layout. We introduce -the image layout as follows. -- CHW Layout - - The abbreviations: C=channel, H=Height, W=Width - - The default layout of image opened by cv2 or PIL is HWC. - PaddlePaddle only supports the CHW layout. And CHW is simply - a transpose of HWC. It must transpose the input image. -- Color format: RGB or BGR - OpenCV use BGR color format. PIL use RGB color format. Both - formats can be used for training. Noted that, the format should - be keep consistent between the training and inference peroid. -""" - - -def load_image(file, is_color=True): - """ - Load an color or gray image from the file path. - Example usage: - - .. code-block:: python - im = load_image('cat.jpg') - :param file: the input image path. - :type file: string - :param is_color: If set is_color True, it will load and - return a color image. Otherwise, it will - load and return a gray image. - """ - # cv2.IMAGE_COLOR for OpenCV3 - # cv2.CV_LOAD_IMAGE_COLOR for older OpenCV Version - # cv2.IMAGE_GRAYSCALE for OpenCV3 - # cv2.CV_LOAD_IMAGE_GRAYSCALE for older OpenCV Version - # Here, use constant 1 and 0 - # 1: COLOR, 0: GRAYSCALE - flag = 1 if is_color else 0 - im = cv2.imread(file, flag) - return im - - -def resize_short(im, size): - """ - Resize an image so that the length of shorter edge is size. - Example usage: - - .. code-block:: python - im = load_image('cat.jpg') - im = resize_short(im, 256) - - :param im: the input image with HWC layout. - :type im: ndarray - :param size: the shorter edge size of image after resizing. - :type size: int - """ - assert im.shape[-1] == 1 or im.shape[-1] == 3 - h, w = im.shape[:2] - h_new, w_new = size, size - if h > w: - h_new = size * h / w - else: - w_new = size * w / h - im = cv2.resize(im, (h_new, w_new), interpolation=cv2.INTER_CUBIC) - return im - - -def to_chw(im, order=(2, 0, 1)): - """ - Transpose the input image order. The image layout is HWC format - opened by cv2 or PIL. Transpose the input image to CHW layout - according the order (2,0,1). - Example usage: - - .. code-block:: python - im = load_image('cat.jpg') - im = resize_short(im, 256) - im = to_chw(im) - - :param im: the input image with HWC layout. - :type im: ndarray - :param order: the transposed order. - :type order: tuple|list - """ - assert len(im.shape) == len(order) - im = im.transpose(order) - return im - - -def center_crop(im, size, is_color=True): - """ - Crop the center of image with size. - Example usage: - - .. code-block:: python - im = center_crop(im, 224) - - :param im: the input image with HWC layout. - :type im: ndarray - :param size: the cropping size. - :type size: int - :param is_color: whether the image is color or not. - :type is_color: bool - """ - h, w = im.shape[:2] - h_start = (h - size) / 2 - w_start = (w - size) / 2 - h_end, w_end = h_start + size, w_start + size - if is_color: - im = im[h_start:h_end, w_start:w_end, :] - else: - im = im[h_start:h_end, w_start:w_end] - return im - - -def random_crop(im, size, is_color=True): - """ - Randomly crop input image with size. - Example usage: - - .. code-block:: python - im = random_crop(im, 224) - - :param im: the input image with HWC layout. - :type im: ndarray - :param size: the cropping size. - :type size: int - :param is_color: whether the image is color or not. - :type is_color: bool - """ - h, w = im.shape[:2] - h_start = np.random.randint(0, h - size + 1) - w_start = np.random.randint(0, w - size + 1) - h_end, w_end = h_start + size, w_start + size - if is_color: - im = im[h_start:h_end, w_start:w_end, :] - else: - im = im[h_start:h_end, w_start:w_end] - return im - - -def left_right_flip(im): - """ - Flip an image along the horizontal direction. - Return the flipped image. - Example usage: - - .. code-block:: python - im = left_right_flip(im) - - :paam im: input image with HWC layout - :type im: ndarray - """ - if len(im.shape) == 3: - return im[:, ::-1, :] - else: - return im[:, ::-1, :] - - -def simple_transform(im, resize_size, crop_size, is_train, is_color=True): - """ - Simply data argumentation for training. These operations include - resizing, croping and flipping. - Example usage: - - .. code-block:: python - im = simple_transform(im, 256, 224, True) - :param im: The input image with HWC layout. - :type im: ndarray - :param resize_size: The shorter edge length of the resized image. - :type resize_size: int - :param crop_size: The cropping size. - :type crop_size: int - :param is_train: Whether it is training or not. - :type is_train: bool - """ - im = resize_short(im, resize_size) - if is_train: - im = random_crop(im, crop_size) - if np.random.randint(2) == 0: - im = left_right_flip(im) - else: - im = center_crop(im, crop_size) - im = to_chw(im) - - return im - - -def load_and_transform(filename, - resize_size, - crop_size, - is_train, - is_color=True): - """ - Load image from the input file `filename` and transform image for - data argumentation. Please refer to the `simple_transform` interface - for the transform operations. - Example usage: - - .. code-block:: python - im = load_and_transform('cat.jpg', 256, 224, True) - :param filename: The file name of input image. - :type filename: string - :param resize_size: The shorter edge length of the resized image. - :type resize_size: int - :param crop_size: The cropping size. - :type crop_size: int - :param is_train: Whether it is training or not. - :type is_train: bool - """ - im = load_image(filename) - im = simple_transform(im, resize_size, crop_size, is_train, is_color) - return im diff --git a/image_classification/caffe2paddle/paddle_resnet.py b/image_classification/caffe2paddle/paddle_resnet.py deleted file mode 100644 index 7d9a9e43a9698501bb086e0e361d416bb2aec7a1..0000000000000000000000000000000000000000 --- a/image_classification/caffe2paddle/paddle_resnet.py +++ /dev/null @@ -1,137 +0,0 @@ -from PIL import Image -import gzip -import numpy as np -import paddle.v2 as paddle -from image import load_and_transform - -__all__ = ['resnet_imagenet', 'resnet_cifar10'] - - -def conv_bn_layer(input, - ch_out, - filter_size, - stride, - padding, - active_type=paddle.activation.Relu(), - ch_in=None): - tmp = paddle.layer.img_conv( - input=input, - filter_size=filter_size, - num_channels=ch_in, - num_filters=ch_out, - stride=stride, - padding=padding, - act=paddle.activation.Linear(), - bias_attr=False) - return paddle.layer.batch_norm(input=tmp, act=active_type) - - -def shortcut(input, n_out, stride, b_projection): - if b_projection: - return conv_bn_layer(input, n_out, 1, stride, 0, - paddle.activation.Linear()) - else: - return input - - -def basicblock(input, ch_out, stride, b_projection): - # TODO: bug fix for ch_in = input.num_filters - conv1 = conv_bn_layer(input, ch_out, 3, stride, 1) - conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1, paddle.activation.Linear()) - short = shortcut(input, ch_out, stride, b_projection) - return paddle.layer.addto( - input=[conv2, short], act=paddle.activation.Relu()) - - -def bottleneck(input, ch_out, stride, b_projection): - # TODO: bug fix for ch_in = input.num_filters - short = shortcut(input, ch_out * 4, stride, b_projection) - conv1 = conv_bn_layer(input, ch_out, 1, stride, 0) - conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1) - conv3 = conv_bn_layer(conv2, ch_out * 4, 1, 1, 0, - paddle.activation.Linear()) - return paddle.layer.addto( - input=[conv3, short], act=paddle.activation.Relu()) - - -def layer_warp(block_func, input, features, count, stride): - conv = block_func(input, features, stride, True) - for i in range(1, count): - conv = block_func(conv, features, 1, False) - return conv - - -def resnet_imagenet(input, depth=50): - cfg = { - 18: ([2, 2, 2, 1], basicblock), - 34: ([3, 4, 6, 3], basicblock), - 50: ([3, 4, 6, 3], bottleneck), - 101: ([3, 4, 23, 3], bottleneck), - 152: ([3, 8, 36, 3], bottleneck) - } - stages, block_func = cfg[depth] - conv1 = conv_bn_layer( - input, ch_in=3, ch_out=64, filter_size=7, stride=2, padding=3) - pool1 = paddle.layer.img_pool(input=conv1, pool_size=3, stride=2) - res1 = layer_warp(block_func, pool1, 64, stages[0], 1) - res2 = layer_warp(block_func, res1, 128, stages[1], 2) - res3 = layer_warp(block_func, res2, 256, stages[2], 2) - res4 = layer_warp(block_func, res3, 512, stages[3], 2) - pool2 = paddle.layer.img_pool( - input=res4, pool_size=7, stride=1, pool_type=paddle.pooling.Avg()) - return pool2 - - -def resnet_cifar10(input, depth=32): - # depth should be one of 20, 32, 44, 56, 110, 1202 - assert (depth - 2) % 6 == 0 - n = (depth - 2) / 6 - nStages = {16, 64, 128} - conv1 = conv_bn_layer( - input, ch_in=3, ch_out=16, filter_size=3, stride=1, padding=1) - res1 = layer_warp(basicblock, conv1, 16, n, 1) - res2 = layer_warp(basicblock, res1, 32, n, 2) - res3 = layer_warp(basicblock, res2, 64, n, 2) - pool = paddle.layer.img_pool( - input=res3, pool_size=8, stride=1, pool_type=paddle.pooling.Avg()) - return pool - - -def load_image(file, mean_file): - im = load_and_transform(file, 256, 224, is_train=False) - im = im[(2, 1, 0), :, :] - mu = np.load(mean_file) - mu = mu.mean(1).mean(1) - im = im - mu[:, None, None] - im = im.flatten() - im = im / 255.0 - return im - - -DATA_DIM = 3 * 224 * 224 -CLASS_DIM = 1000 -BATCH_SIZE = 128 - -MODEL_FILE = 'paddle_resnet50.tar.gz' - -if __name__ == "__main__": - paddle.init(use_gpu=False, trainer_count=1) - - img = paddle.layer.data( - "image", type=paddle.data_type.dense_vector(DATA_DIM)) - out = paddle.layer.fc( - input=resnet_imagenet(img, 50), - size=1000, - act=paddle.activation.Softmax()) - - parameters = paddle.parameters.Parameters.from_tar(gzip.open(MODEL_FILE)) - - test_data = [] - test_data.append((load_image("./images/cat.jpg"), )) - output_prob = paddle.infer( - output_layer=out, parameters=parameters, input=test_data, - field="value")[0] - - print np.sort(output_prob)[::-1] - print np.argsort(output_prob)[::-1] - print 'predicted class is:', output_prob.argmax()