Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
f8bc7215
M
models
项目概览
PaddlePaddle
/
models
大约 2 年 前同步成功
通知
232
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f8bc7215
编写于
6月 08, 2017
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add README.md
上级
ae9c48af
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
444 addition
and
25 deletion
+444
-25
image_classification/caffe2paddle/README.md
image_classification/caffe2paddle/README.md
+55
-0
image_classification/caffe2paddle/caffe2paddle.py
image_classification/caffe2paddle/caffe2paddle.py
+29
-25
image_classification/caffe2paddle/image.py
image_classification/caffe2paddle/image.py
+223
-0
image_classification/caffe2paddle/paddle_resnet.py
image_classification/caffe2paddle/paddle_resnet.py
+137
-0
未找到文件。
image_classification/caffe2paddle/README.md
0 → 100644
浏览文件 @
f8bc7215
## 使用说明
`caffe2paddle.py`
提供了将Caffe训练的模型转换为PaddlePaddle可使用的模型的接口
`ModelConverter`
,其封装了图像领域常用的Convolution、BatchNorm等layer的转换函数,可完成VGG、ResNet等常用模型的转换。模型转换的基本过程是:基于Caffe的Python API加载模型并依次获取每一个layer的信息,将其中的参数根据layer类型与PaddlePaddle适配后序列化保存(对于Pooling等无需训练的layer不做处理),输出可以直接为PaddlePaddle的Python API加载使用的模型文件。
`ModelConverter`
的定义及说明如下:
```
python
class
ModelConverter
(
object
):
#设置Caffe网络配置文件、模型文件路径和要保存为的Paddle模型的文件名,并使用Caffe API加载模型
def
__init__
(
self
,
caffe_model_file
,
caffe_pretrained_file
,
paddle_tar_name
)
#输出保存Paddle模型
def
to_tar
(
self
,
f
)
#将参数值序列化输出为二进制
@
staticmethod
def
serialize
(
data
,
f
)
#依次对各个layer进行转换,转换时参照name_map进行layer和参数命名
def
convert
(
self
,
name_map
=
{})
#对Caffe模型的Convolution层的参数进行转换,将使用name值对Paddle模型中对应layer的参数命名
@
wrap_name_default
(
"img_conv_layer"
)
def
convert_Convolution_layer
(
self
,
params
,
name
=
None
)
#对Caffe模型的InnerProduct层的参数进行转换,将使用name值对Paddle模型中对应layer的参数命名
@
wrap_name_default
(
"fc_layer"
)
def
convert_InnerProduct_layer
(
self
,
params
,
name
=
None
)
#对Caffe模型的BatchNorm层的参数进行转换,将使用name值对Paddle模型中对应layer的参数命名
@
wrap_name_default
(
"batch_norm_layer"
)
def
convert_BatchNorm_layer
(
self
,
params
,
name
=
None
)
#对Caffe模型的Scale层的参数进行转换,将使用name值对Paddle模型中对应layer的参数命名
def
convert_Scale_layer
(
self
,
params
,
name
=
None
)
#输入图片路径和均值文件路径,使用加载的Caffe模型进行预测
def
caffe_predict
(
self
,
img
,
mean_file
)
```
`ModelConverter`
的使用方法如下:
```
python
#指定Caffe网络配置文件、模型文件路径和要保存为的Paddle模型的文件名,并从指定文件加载模型
converter
=
ModelConverter
(
"./ResNet-50-deploy.prototxt"
,
"./ResNet-50-model.caffemodel"
,
"Paddle_ResNet50.tar.gz"
)
#进行模型转换
converter
.
convert
(
name_map
=
{})
#进行预测并输出预测概率以便对比验证模型转换结果
converter
.
caffe_predict
(
img
=
'./caffe/examples/images/cat.jpg'
)
```
为验证并使用转换得到的模型,需基于PaddlePaddle API编写对应的网络结构配置文件,具体可参照PaddlePaddle使用文档,我们这里附上ResNet的配置以供使用。需要注意,上文给出的模型转换在调用
`ModelConverter.convert`
时传入了空的
`name_map`
,这将在遍历每一个layer进行参数保存时使用PaddlePaddle默认的layer和参数命名规则:以
`wrap_name_default`
中的值和调用计数构造layer name,并以此为前缀构造参数名(比如第一个InnerProduct层的bias参数将被命名为
`___fc_layer_0__.wbias`
);为此,在编写PaddlePaddle网络配置时要保证和Caffe端模型使用同样的拓扑顺序,尤其是对于ResNet这种有分支的网络结构,要保证两分支在PaddlePaddle和Caffe中先后顺序一致,这样才能够使得模型参数正确加载。如果不希望使用默认的layer name,可以使用一种更为精细的方法:建立Caffe和PaddlePaddle网络配置间layer name对应关系的
`dict`
并在调用
`ModelConverter.convert`
时作为
`name_map`
传入,这样在命名保存layer中的参数时将使用相应的layer name,另外这里只针对Caffe网络配置中Convolution、InnerProduct和BatchNorm类别的layer建立
`name_map`
即可(一方面,对于Pooling等无需训练的layer不需要保存,故这里没有提供转换接口;另一方面,对于Caffe中的Scale类别的layer,由于Caffe和PaddlePaddle在实现上的一些差别,PaddlePaddle中的batch_norm层同时包含BatchNorm和Scale层的复合,故这里对Scale进行了特殊处理)。
image_classification/caffe2paddle.py
→
image_classification/caffe2paddle
/caffe2paddle
.py
浏览文件 @
f8bc7215
# -*- coding: utf-8 -*-
import
os
import
os
import
functools
import
functools
import
inspect
import
inspect
...
@@ -9,6 +8,7 @@ import cStringIO
...
@@ -9,6 +8,7 @@ import cStringIO
import
numpy
as
np
import
numpy
as
np
import
caffe
import
caffe
from
paddle.proto.ParameterConfig_pb2
import
ParameterConfig
from
paddle.proto.ParameterConfig_pb2
import
ParameterConfig
from
image
import
load_and_transform
def
__default_not_set_callback__
(
kwargs
,
name
):
def
__default_not_set_callback__
(
kwargs
,
name
):
...
@@ -90,15 +90,16 @@ def wrap_name_default(name_prefix=None, name_param="name"):
...
@@ -90,15 +90,16 @@ def wrap_name_default(name_prefix=None, name_param="name"):
class
ModelConverter
(
object
):
class
ModelConverter
(
object
):
def
__init__
(
self
,
caffe_model_file
,
caffe_pretrained_file
,
def
__init__
(
self
,
caffe_model_file
,
caffe_pretrained_file
,
paddle_tar_name
):
paddle_
output_path
,
paddle_
tar_name
):
self
.
net
=
caffe
.
Net
(
caffe_model_file
,
caffe_pretrained_file
,
self
.
net
=
caffe
.
Net
(
caffe_model_file
,
caffe_pretrained_file
,
caffe
.
TEST
)
caffe
.
TEST
)
self
.
output_path
=
paddle_output_path
self
.
tar_name
=
paddle_tar_name
self
.
tar_name
=
paddle_tar_name
self
.
params
=
dict
()
self
.
params
=
dict
()
self
.
pre_layer_name
=
""
self
.
pre_layer_name
=
""
self
.
pre_layer_type
=
""
self
.
pre_layer_type
=
""
def
convert
(
self
):
def
convert
(
self
,
name_map
=
{}
):
layer_dict
=
self
.
net
.
layer_dict
layer_dict
=
self
.
net
.
layer_dict
for
layer_name
in
layer_dict
.
keys
():
for
layer_name
in
layer_dict
.
keys
():
layer
=
layer_dict
[
layer_name
]
layer
=
layer_dict
[
layer_name
]
...
@@ -106,7 +107,10 @@ class ModelConverter(object):
...
@@ -106,7 +107,10 @@ class ModelConverter(object):
layer_type
=
layer
.
type
layer_type
=
layer
.
type
if
len
(
layer_params
)
>
0
:
if
len
(
layer_params
)
>
0
:
self
.
pre_layer_name
=
getattr
(
self
.
pre_layer_name
=
getattr
(
self
,
"convert_"
+
layer_type
+
"_layer"
)(
layer_params
)
self
,
"convert_"
+
layer_type
+
"_layer"
)(
layer_params
,
name
=
None
if
name_map
==
None
else
name_map
.
get
(
layer_name
))
self
.
pre_layer_type
=
layer_type
self
.
pre_layer_type
=
layer_type
with
gzip
.
open
(
self
.
tar_name
,
'w'
)
as
f
:
with
gzip
.
open
(
self
.
tar_name
,
'w'
)
as
f
:
self
.
to_tar
(
f
)
self
.
to_tar
(
f
)
...
@@ -136,7 +140,7 @@ class ModelConverter(object):
...
@@ -136,7 +140,7 @@ class ModelConverter(object):
f
.
write
(
struct
.
pack
(
"IIQ"
,
0
,
4
,
data
.
size
))
f
.
write
(
struct
.
pack
(
"IIQ"
,
0
,
4
,
data
.
size
))
f
.
write
(
data
.
tobytes
())
f
.
write
(
data
.
tobytes
())
@
wrap_name_default
(
"
conv
"
)
@
wrap_name_default
(
"
img_conv_layer
"
)
def
convert_Convolution_layer
(
self
,
params
,
name
=
None
):
def
convert_Convolution_layer
(
self
,
params
,
name
=
None
):
for
i
in
range
(
len
(
params
)):
for
i
in
range
(
len
(
params
)):
data
=
np
.
array
(
params
[
i
].
data
)
data
=
np
.
array
(
params
[
i
].
data
)
...
@@ -149,6 +153,7 @@ class ModelConverter(object):
...
@@ -149,6 +153,7 @@ class ModelConverter(object):
param_conf
.
name
=
file_name
param_conf
.
name
=
file_name
param_conf
.
size
=
reduce
(
lambda
a
,
b
:
a
*
b
,
data
.
shape
)
param_conf
.
size
=
reduce
(
lambda
a
,
b
:
a
*
b
,
data
.
shape
)
self
.
params
[
file_name
]
=
(
param_conf
,
data
.
flatten
())
self
.
params
[
file_name
]
=
(
param_conf
,
data
.
flatten
())
return
name
return
name
@
wrap_name_default
(
"fc_layer"
)
@
wrap_name_default
(
"fc_layer"
)
...
@@ -171,9 +176,10 @@ class ModelConverter(object):
...
@@ -171,9 +176,10 @@ class ModelConverter(object):
self
.
params
[
file_name
]
=
(
param_conf
,
data
.
flatten
())
self
.
params
[
file_name
]
=
(
param_conf
,
data
.
flatten
())
return
name
return
name
@
wrap_name_default
(
"batch_norm"
)
@
wrap_name_default
(
"batch_norm
_layer
"
)
def
convert_BatchNorm_layer
(
self
,
params
,
name
=
None
):
def
convert_BatchNorm_layer
(
self
,
params
,
name
=
None
):
scale
=
np
.
array
(
params
[
-
1
].
data
)
scale
=
1
/
np
.
array
(
params
[
-
1
].
data
)[
0
]
if
np
.
array
(
params
[
-
1
].
data
)[
0
]
!=
0
else
0
for
i
in
range
(
2
):
for
i
in
range
(
2
):
data
=
np
.
array
(
params
[
i
].
data
)
*
scale
data
=
np
.
array
(
params
[
i
].
data
)
*
scale
file_name
=
"_%s.w%s"
%
(
name
,
str
(
i
+
1
))
file_name
=
"_%s.w%s"
%
(
name
,
str
(
i
+
1
))
...
@@ -210,19 +216,7 @@ class ModelConverter(object):
...
@@ -210,19 +216,7 @@ class ModelConverter(object):
mean_file
=
'./caffe/imagenet/ilsvrc_2012_mean.npy'
):
mean_file
=
'./caffe/imagenet/ilsvrc_2012_mean.npy'
):
net
=
self
.
net
net
=
self
.
net
mu
=
np
.
load
(
mean_file
)
net
.
blobs
[
'data'
].
data
[...]
=
load_img
(
img
,
mean_file
)
mu
=
mu
.
mean
(
1
).
mean
(
1
)
transformer
=
caffe
.
io
.
Transformer
({
'data'
:
net
.
blobs
[
'data'
].
data
.
shape
})
transformer
.
set_transpose
(
'data'
,
(
2
,
0
,
1
))
transformer
.
set_mean
(
'data'
,
mu
)
transformer
.
set_raw_scale
(
'data'
,
255
)
transformer
.
set_channel_swap
(
'data'
,
(
2
,
1
,
0
))
im
=
caffe
.
io
.
load_image
(
img
)
net
.
blobs
[
'data'
].
data
[...]
=
transformer
.
preprocess
(
'data'
,
im
)
out
=
net
.
forward
()
out
=
net
.
forward
()
output_prob
=
net
.
blobs
[
'prob'
].
data
[
0
].
flatten
()
output_prob
=
net
.
blobs
[
'prob'
].
data
[
0
].
flatten
()
...
@@ -231,9 +225,19 @@ class ModelConverter(object):
...
@@ -231,9 +225,19 @@ class ModelConverter(object):
print
'predicted class is:'
,
output_prob
.
argmax
()
print
'predicted class is:'
,
output_prob
.
argmax
()
def
load_image
(
file
,
mean_file
):
im
=
load_and_transform
(
file
,
256
,
224
,
is_train
=
False
)
im
=
im
[(
2
,
1
,
0
),
:,
:]
mu
=
np
.
load
(
mean_file
)
mu
=
mu
.
mean
(
1
).
mean
(
1
)
im
=
im
-
mu
[:,
None
,
None
]
im
=
im
/
255.0
return
im
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
converter
=
ModelConverter
(
"./
VGG_ILSVRC_16_layers_
deploy.prototxt"
,
converter
=
ModelConverter
(
"./
resnet50/ResNet-50-
deploy.prototxt"
,
"./
VGG_ILSVRC_16_layers
.caffemodel"
,
"./
resnet50/ResNet-50-model
.caffemodel"
,
"
test_vgg16
.tar.gz"
)
"
paddle_resnet50
.tar.gz"
)
converter
.
convert
()
converter
.
convert
(
name_map
=
dict
()
)
converter
.
caffe_predict
(
img
=
'./caffe/examples/images/cat.jpg'
)
converter
.
caffe_predict
(
"./images/cat.jpg"
)
image_classification/caffe2paddle/image.py
0 → 100644
浏览文件 @
f8bc7215
import
numpy
as
np
try
:
import
cv2
except
:
print
(
"import cv2 error, please install opencv-python: pip install opencv-python"
)
__all__
=
[
"load_image"
,
"resize_short"
,
"to_chw"
,
"center_crop"
,
"random_crop"
,
"left_right_flip"
,
"simple_transform"
,
"load_and_transform"
]
"""
This file contains some common interfaces for image preprocess.
Many users are confused about the image layout. We introduce
the image layout as follows.
- CHW Layout
- The abbreviations: C=channel, H=Height, W=Width
- The default layout of image opened by cv2 or PIL is HWC.
PaddlePaddle only supports the CHW layout. And CHW is simply
a transpose of HWC. It must transpose the input image.
- Color format: RGB or BGR
OpenCV use BGR color format. PIL use RGB color format. Both
formats can be used for training. Noted that, the format should
be keep consistent between the training and inference peroid.
"""
def
load_image
(
file
,
is_color
=
True
):
"""
Load an color or gray image from the file path.
Example usage:
.. code-block:: python
im = load_image('cat.jpg')
:param file: the input image path.
:type file: string
:param is_color: If set is_color True, it will load and
return a color image. Otherwise, it will
load and return a gray image.
"""
# cv2.IMAGE_COLOR for OpenCV3
# cv2.CV_LOAD_IMAGE_COLOR for older OpenCV Version
# cv2.IMAGE_GRAYSCALE for OpenCV3
# cv2.CV_LOAD_IMAGE_GRAYSCALE for older OpenCV Version
# Here, use constant 1 and 0
# 1: COLOR, 0: GRAYSCALE
flag
=
1
if
is_color
else
0
im
=
cv2
.
imread
(
file
,
flag
)
return
im
def
resize_short
(
im
,
size
):
"""
Resize an image so that the length of shorter edge is size.
Example usage:
.. code-block:: python
im = load_image('cat.jpg')
im = resize_short(im, 256)
:param im: the input image with HWC layout.
:type im: ndarray
:param size: the shorter edge size of image after resizing.
:type size: int
"""
assert
im
.
shape
[
-
1
]
==
1
or
im
.
shape
[
-
1
]
==
3
h
,
w
=
im
.
shape
[:
2
]
h_new
,
w_new
=
size
,
size
if
h
>
w
:
h_new
=
size
*
h
/
w
else
:
w_new
=
size
*
w
/
h
im
=
cv2
.
resize
(
im
,
(
h_new
,
w_new
),
interpolation
=
cv2
.
INTER_CUBIC
)
return
im
def
to_chw
(
im
,
order
=
(
2
,
0
,
1
)):
"""
Transpose the input image order. The image layout is HWC format
opened by cv2 or PIL. Transpose the input image to CHW layout
according the order (2,0,1).
Example usage:
.. code-block:: python
im = load_image('cat.jpg')
im = resize_short(im, 256)
im = to_chw(im)
:param im: the input image with HWC layout.
:type im: ndarray
:param order: the transposed order.
:type order: tuple|list
"""
assert
len
(
im
.
shape
)
==
len
(
order
)
im
=
im
.
transpose
(
order
)
return
im
def
center_crop
(
im
,
size
,
is_color
=
True
):
"""
Crop the center of image with size.
Example usage:
.. code-block:: python
im = center_crop(im, 224)
:param im: the input image with HWC layout.
:type im: ndarray
:param size: the cropping size.
:type size: int
:param is_color: whether the image is color or not.
:type is_color: bool
"""
h
,
w
=
im
.
shape
[:
2
]
h_start
=
(
h
-
size
)
/
2
w_start
=
(
w
-
size
)
/
2
h_end
,
w_end
=
h_start
+
size
,
w_start
+
size
if
is_color
:
im
=
im
[
h_start
:
h_end
,
w_start
:
w_end
,
:]
else
:
im
=
im
[
h_start
:
h_end
,
w_start
:
w_end
]
return
im
def
random_crop
(
im
,
size
,
is_color
=
True
):
"""
Randomly crop input image with size.
Example usage:
.. code-block:: python
im = random_crop(im, 224)
:param im: the input image with HWC layout.
:type im: ndarray
:param size: the cropping size.
:type size: int
:param is_color: whether the image is color or not.
:type is_color: bool
"""
h
,
w
=
im
.
shape
[:
2
]
h_start
=
np
.
random
.
randint
(
0
,
h
-
size
+
1
)
w_start
=
np
.
random
.
randint
(
0
,
w
-
size
+
1
)
h_end
,
w_end
=
h_start
+
size
,
w_start
+
size
if
is_color
:
im
=
im
[
h_start
:
h_end
,
w_start
:
w_end
,
:]
else
:
im
=
im
[
h_start
:
h_end
,
w_start
:
w_end
]
return
im
def
left_right_flip
(
im
):
"""
Flip an image along the horizontal direction.
Return the flipped image.
Example usage:
.. code-block:: python
im = left_right_flip(im)
:paam im: input image with HWC layout
:type im: ndarray
"""
if
len
(
im
.
shape
)
==
3
:
return
im
[:,
::
-
1
,
:]
else
:
return
im
[:,
::
-
1
,
:]
def
simple_transform
(
im
,
resize_size
,
crop_size
,
is_train
,
is_color
=
True
):
"""
Simply data argumentation for training. These operations include
resizing, croping and flipping.
Example usage:
.. code-block:: python
im = simple_transform(im, 256, 224, True)
:param im: The input image with HWC layout.
:type im: ndarray
:param resize_size: The shorter edge length of the resized image.
:type resize_size: int
:param crop_size: The cropping size.
:type crop_size: int
:param is_train: Whether it is training or not.
:type is_train: bool
"""
im
=
resize_short
(
im
,
resize_size
)
if
is_train
:
im
=
random_crop
(
im
,
crop_size
)
if
np
.
random
.
randint
(
2
)
==
0
:
im
=
left_right_flip
(
im
)
else
:
im
=
center_crop
(
im
,
crop_size
)
im
=
to_chw
(
im
)
return
im
def
load_and_transform
(
filename
,
resize_size
,
crop_size
,
is_train
,
is_color
=
True
):
"""
Load image from the input file `filename` and transform image for
data argumentation. Please refer to the `simple_transform` interface
for the transform operations.
Example usage:
.. code-block:: python
im = load_and_transform('cat.jpg', 256, 224, True)
:param filename: The file name of input image.
:type filename: string
:param resize_size: The shorter edge length of the resized image.
:type resize_size: int
:param crop_size: The cropping size.
:type crop_size: int
:param is_train: Whether it is training or not.
:type is_train: bool
"""
im
=
load_image
(
filename
)
im
=
simple_transform
(
im
,
resize_size
,
crop_size
,
is_train
,
is_color
)
return
im
image_classification/caffe2paddle/paddle_resnet.py
0 → 100644
浏览文件 @
f8bc7215
from
PIL
import
Image
import
gzip
import
numpy
as
np
import
paddle.v2
as
paddle
from
image
import
load_and_transform
__all__
=
[
'resnet_imagenet'
,
'resnet_cifar10'
]
def
conv_bn_layer
(
input
,
ch_out
,
filter_size
,
stride
,
padding
,
active_type
=
paddle
.
activation
.
Relu
(),
ch_in
=
None
):
tmp
=
paddle
.
layer
.
img_conv
(
input
=
input
,
filter_size
=
filter_size
,
num_channels
=
ch_in
,
num_filters
=
ch_out
,
stride
=
stride
,
padding
=
padding
,
act
=
paddle
.
activation
.
Linear
(),
bias_attr
=
False
)
return
paddle
.
layer
.
batch_norm
(
input
=
tmp
,
act
=
active_type
)
def
shortcut
(
input
,
n_out
,
stride
,
b_projection
):
if
b_projection
:
return
conv_bn_layer
(
input
,
n_out
,
1
,
stride
,
0
,
paddle
.
activation
.
Linear
())
else
:
return
input
def
basicblock
(
input
,
ch_out
,
stride
,
b_projection
):
# TODO: bug fix for ch_in = input.num_filters
conv1
=
conv_bn_layer
(
input
,
ch_out
,
3
,
stride
,
1
)
conv2
=
conv_bn_layer
(
conv1
,
ch_out
,
3
,
1
,
1
,
paddle
.
activation
.
Linear
())
short
=
shortcut
(
input
,
ch_out
,
stride
,
b_projection
)
return
paddle
.
layer
.
addto
(
input
=
[
conv2
,
short
],
act
=
paddle
.
activation
.
Relu
())
def
bottleneck
(
input
,
ch_out
,
stride
,
b_projection
):
# TODO: bug fix for ch_in = input.num_filters
short
=
shortcut
(
input
,
ch_out
*
4
,
stride
,
b_projection
)
conv1
=
conv_bn_layer
(
input
,
ch_out
,
1
,
stride
,
0
)
conv2
=
conv_bn_layer
(
conv1
,
ch_out
,
3
,
1
,
1
)
conv3
=
conv_bn_layer
(
conv2
,
ch_out
*
4
,
1
,
1
,
0
,
paddle
.
activation
.
Linear
())
return
paddle
.
layer
.
addto
(
input
=
[
conv3
,
short
],
act
=
paddle
.
activation
.
Relu
())
def
layer_warp
(
block_func
,
input
,
features
,
count
,
stride
):
conv
=
block_func
(
input
,
features
,
stride
,
True
)
for
i
in
range
(
1
,
count
):
conv
=
block_func
(
conv
,
features
,
1
,
False
)
return
conv
def
resnet_imagenet
(
input
,
depth
=
50
):
cfg
=
{
18
:
([
2
,
2
,
2
,
1
],
basicblock
),
34
:
([
3
,
4
,
6
,
3
],
basicblock
),
50
:
([
3
,
4
,
6
,
3
],
bottleneck
),
101
:
([
3
,
4
,
23
,
3
],
bottleneck
),
152
:
([
3
,
8
,
36
,
3
],
bottleneck
)
}
stages
,
block_func
=
cfg
[
depth
]
conv1
=
conv_bn_layer
(
input
,
ch_in
=
3
,
ch_out
=
64
,
filter_size
=
7
,
stride
=
2
,
padding
=
3
)
pool1
=
paddle
.
layer
.
img_pool
(
input
=
conv1
,
pool_size
=
3
,
stride
=
2
)
res1
=
layer_warp
(
block_func
,
pool1
,
64
,
stages
[
0
],
1
)
res2
=
layer_warp
(
block_func
,
res1
,
128
,
stages
[
1
],
2
)
res3
=
layer_warp
(
block_func
,
res2
,
256
,
stages
[
2
],
2
)
res4
=
layer_warp
(
block_func
,
res3
,
512
,
stages
[
3
],
2
)
pool2
=
paddle
.
layer
.
img_pool
(
input
=
res4
,
pool_size
=
7
,
stride
=
1
,
pool_type
=
paddle
.
pooling
.
Avg
())
return
pool2
def
resnet_cifar10
(
input
,
depth
=
32
):
# depth should be one of 20, 32, 44, 56, 110, 1202
assert
(
depth
-
2
)
%
6
==
0
n
=
(
depth
-
2
)
/
6
nStages
=
{
16
,
64
,
128
}
conv1
=
conv_bn_layer
(
input
,
ch_in
=
3
,
ch_out
=
16
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
)
res1
=
layer_warp
(
basicblock
,
conv1
,
16
,
n
,
1
)
res2
=
layer_warp
(
basicblock
,
res1
,
32
,
n
,
2
)
res3
=
layer_warp
(
basicblock
,
res2
,
64
,
n
,
2
)
pool
=
paddle
.
layer
.
img_pool
(
input
=
res3
,
pool_size
=
8
,
stride
=
1
,
pool_type
=
paddle
.
pooling
.
Avg
())
return
pool
def
load_image
(
file
,
mean_file
):
im
=
load_and_transform
(
file
,
256
,
224
,
is_train
=
False
)
im
=
im
[(
2
,
1
,
0
),
:,
:]
mu
=
np
.
load
(
mean_file
)
mu
=
mu
.
mean
(
1
).
mean
(
1
)
im
=
im
-
mu
[:,
None
,
None
]
im
=
im
.
flatten
()
im
=
im
/
255.0
return
im
DATA_DIM
=
3
*
224
*
224
CLASS_DIM
=
1000
BATCH_SIZE
=
128
MODEL_FILE
=
'paddle_resnet50.tar.gz'
if
__name__
==
"__main__"
:
paddle
.
init
(
use_gpu
=
False
,
trainer_count
=
1
)
img
=
paddle
.
layer
.
data
(
"image"
,
type
=
paddle
.
data_type
.
dense_vector
(
DATA_DIM
))
out
=
paddle
.
layer
.
fc
(
input
=
resnet_imagenet
(
img
,
50
),
size
=
1000
,
act
=
paddle
.
activation
.
Softmax
())
parameters
=
paddle
.
parameters
.
Parameters
.
from_tar
(
gzip
.
open
(
MODEL_FILE
))
test_data
=
[]
test_data
.
append
((
load_image
(
"./images/cat.jpg"
),
))
output_prob
=
paddle
.
infer
(
output_layer
=
out
,
parameters
=
parameters
,
input
=
test_data
,
field
=
"value"
)[
0
]
print
np
.
sort
(
output_prob
)[::
-
1
]
print
np
.
argsort
(
output_prob
)[::
-
1
]
print
'predicted class is:'
,
output_prob
.
argmax
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录