未验证 提交 69788b6d 编写于 作者: W wuzewu 提交者: GitHub

update resnet50_vd_10w

<p align="center">
<img src="http://bj.bcebos.com/ibox-thumbnail98/77fa9b7003e4665867855b2b65216519?authorization=bce-auth-v1%2Ffbe74140929444858491fbf2b6bc0935%2F2020-04-08T11%3A05%3A10Z%2F1800%2F%2F1df0ecb4a52adefeae240c9e2189e8032560333e399b3187ef1a76e4ffa5f19f" hspace='5' width=800/> <br /> ResNet 系列的网络结构
</p>
模型的详情可参考[论文](https://arxiv.org/pdf/1812.01187.pdf)
## API
```python
def get_expected_image_width()
```
返回预处理的图片宽度,也就是224。
```python
def get_expected_image_height()
```
返回预处理的图片高度,也就是224。
```python
def get_pretrained_images_mean()
```
返回预处理的图片均值,也就是 \[0.485, 0.456, 0.406\]
```python
def get_pretrained_images_std()
```
返回预处理的图片标准差,也就是 \[0.229, 0.224, 0.225\]
```python
def context(trainable=True, pretrained=True)
```
**参数**
* trainable (bool): 计算图的参数是否为可训练的;
* pretrained (bool): 是否加载默认的预训练模型。
**返回**
* inputs (dict): 计算图的输入,key 为 'image', value 为图片的张量;
* outputs (dict): 计算图的输出,key 为 'feature\_map', value为全连接层前面的那个张量。
* context\_prog(fluid.Program): 计算图,用于迁移学习。
```python
def save_inference_model(dirname,
model_filename=None,
params_filename=None,
combined=True)
```
将模型保存到指定路径。
**参数**
* dirname: 存在模型的目录名称
* model_filename: 模型文件名称,默认为\_\_model\_\_
* params_filename: 参数文件名称,默认为\_\_params\_\_(仅当`combined`为True时生效)
* combined: 是否将参数保存到统一的一个文件中
## 代码示例
```python
import paddlehub as hub
import cv2
classifier = hub.Module(name="resnet50_vd_10w")
input_dict, output_dict, program = classifier.context(trainable=True)
```
### 查看代码
[PaddleClas](https://github.com/PaddlePaddle/PaddleClas)
### 依赖
paddlepaddle >= 1.6.2
paddlehub >= 1.6.0
# coding=utf-8
import os
import time
from collections import OrderedDict
import cv2
import numpy as np
from PIL import Image
__all__ = ['reader']
DATA_DIM = 224
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
def resize_short(img, target_size):
percent = float(target_size) / min(img.size[0], img.size[1])
resized_width = int(round(img.size[0] * percent))
resized_height = int(round(img.size[1] * percent))
img = img.resize((resized_width, resized_height), Image.LANCZOS)
return img
def crop_image(img, target_size, center):
width, height = img.size
size = target_size
if center == True:
w_start = (width - size) / 2
h_start = (height - size) / 2
else:
w_start = np.random.randint(0, width - size + 1)
h_start = np.random.randint(0, height - size + 1)
w_end = w_start + size
h_end = h_start + size
img = img.crop((w_start, h_start, w_end, h_end))
return img
def process_image(img):
img = resize_short(img, target_size=256)
img = crop_image(img, target_size=DATA_DIM, center=True)
if img.mode != 'RGB':
img = img.convert('RGB')
img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255
img -= img_mean
img /= img_std
return img
def reader(images=None, paths=None):
"""
Preprocess to yield image.
Args:
images (list[numpy.ndarray]): images data, shape of each is [H, W, C].
paths (list[str]): paths to images.
Yield:
each (collections.OrderedDict): info of original image, preprocessed image.
"""
component = list()
if paths:
for im_path in paths:
each = OrderedDict()
assert os.path.isfile(
im_path), "The {} isn't a valid file path.".format(im_path)
each['org_im_path'] = im_path
each['org_im'] = Image.open(im_path)
each['org_im_width'], each['org_im_height'] = each['org_im'].size
component.append(each)
if images is not None:
assert type(images), "images is a list."
for im in images:
each = OrderedDict()
each['org_im'] = Image.fromarray(im[:, :, ::-1])
each['org_im_path'] = 'ndarray_time={}'.format(
round(time.time(), 6) * 1e6)
each['org_im_width'], each['org_im_height'] = each['org_im'].size
component.append(each)
for element in component:
element['image'] = process_image(element['org_im'])
yield element
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
import ast
import argparse
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
import numpy as np
import paddle.fluid as fluid
import paddlehub as hub
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
from paddlehub.module.module import moduleinfo, runnable, serving
from paddlehub.common.paddle_helper import add_vars_prefix
from resnet50_vd_10w.processor import postprocess, base64_to_cv2
from resnet50_vd_10w.data_feed import reader
from resnet50_vd_10w.resnet_vd import ResNet50_vd
import paddle
from paddle import ParamAttr
import paddle.nn as nn
from paddle.nn import Conv2d, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2d, MaxPool2d, AvgPool2d
from paddle.nn.initializer import Uniform
from paddlehub.module.module import moduleinfo
from paddlehub.module.cv_module import ImageClassifierModule
class ConvBNLayer(nn.Layer):
"""Basic conv bn layer."""
def __init__(
self,
num_channels: int,
num_filters: int,
filter_size: int,
stride: int = 1,
groups: int = 1,
is_vd_mode: bool = False,
act: str = None,
name: str = None,
):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
self._conv = Conv2d(in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = BatchNorm(num_filters,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def forward(self, inputs: paddle.Tensor):
if self.is_vd_mode:
inputs = self._pool2d_avg(inputs)
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class BottleneckBlock(nn.Layer):
"""Bottleneck Block for ResNet50_vd."""
def __init__(self,
num_channels: int,
num_filters: int,
stride: int,
shortcut: bool = True,
if_first: bool = False,
name: str = None):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu',
name=name + "_branch2b")
self.conv2 = ConvBNLayer(num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
act=None,
name=name + "_branch2c")
if not shortcut:
self.short = ConvBNLayer(num_channels=num_channels,
num_filters=num_filters * 4,
filter_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs: paddle.Tensor):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.elementwise_add(x=short, y=conv2, act='relu')
return y
class BasicBlock(nn.Layer):
"""Basic block for ResNet50_vd."""
def __init__(self,
num_channels: int,
num_filters: int,
stride: int,
shortcut: bool = True,
if_first: bool = False,
name: str = None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(num_channels=num_channels,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
act=None,
name=name + "_branch2b")
if not shortcut:
self.short = ConvBNLayer(num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs: paddle.Tensor):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.elementwise_add(x=short, y=conv1, act='relu')
return y
@moduleinfo(
name="resnet50_vd_10w",
type="CV/image_classification",
@moduleinfo(name="resnet50_vd_10w",
type="CV/classification",
author="paddlepaddle",
author_email="paddle-dev@baidu.com",
summary=
"ResNet50vd is a image classfication model, this module is trained with Baidu's self-built dataset with 100,000 categories.",
version="1.0.0")
class ResNet50vd(hub.Module):
def _initialize(self):
self.default_pretrained_model_path = os.path.join(
self.directory, "model")
def get_expected_image_width(self):
return 224
def get_expected_image_height(self):
return 224
def get_pretrained_images_mean(self):
im_mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3)
return im_mean
def get_pretrained_images_std(self):
im_std = np.array([0.229, 0.224, 0.225]).reshape(1, 3)
return im_std
def context(self, trainable=True, pretrained=True):
"""context for transfer learning.
Args:
trainable (bool): Set parameters in program to be trainable.
pretrained (bool) : Whether to load pretrained model.
Returns:
inputs (dict): key is 'image', corresponding vaule is image tensor.
outputs (dict): key is 'feature_map', corresponding value is the result of the layer before the fully connected layer.
context_prog (fluid.Program): program for transfer learning.
"""
context_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(context_prog, startup_prog):
with fluid.unique_name.guard():
image = fluid.layers.data(
name="image", shape=[3, 224, 224], dtype="float32")
resnet_vd = ResNet50_vd()
feature_map = resnet_vd.net(input=image)
name_prefix = '@HUB_{}@'.format(self.name)
inputs = {'image': name_prefix + image.name}
outputs = {'feature_map': name_prefix + feature_map.name}
add_vars_prefix(context_prog, name_prefix)
add_vars_prefix(startup_prog, name_prefix)
global_vars = context_prog.global_block().vars
inputs = {
key: global_vars[value]
for key, value in inputs.items()
}
outputs = {
key: global_vars[value]
for key, value in outputs.items()
}
place = fluid.CPUPlace()
exe = fluid.Executor(place)
# pretrained
if pretrained:
def _if_exist(var):
b = os.path.exists(
os.path.join(self.default_pretrained_model_path,
var.name))
return b
fluid.io.load_vars(
exe,
self.default_pretrained_model_path,
context_prog,
predicate=_if_exist)
author_email="",
summary="resnet50_vd_imagenet_ssld is a classification model, "
"this module is trained with Baidu open sourced dataset.",
version="1.1.0",
meta=ImageClassifierModule)
class ResNet50_vd(nn.Layer):
"""ResNet50_vd model."""
def __init__(self, class_dim: int = 1000, load_checkpoint: str = None):
super(ResNet50_vd, self).__init__()
self.layers = 50
depth = [3, 4, 6, 3]
num_channels = [64, 256, 512, 1024]
num_filters = [64, 128, 256, 512]
self.conv1_1 = ConvBNLayer(num_channels=3, num_filters=32, filter_size=3, stride=2, act='relu', name="conv1_1")
self.conv1_2 = ConvBNLayer(num_channels=32, num_filters=32, filter_size=3, stride=1, act='relu', name="conv1_2")
self.conv1_3 = ConvBNLayer(num_channels=32, num_filters=64, filter_size=3, stride=1, act='relu', name="conv1_3")
self.pool2d_max = MaxPool2d(kernel_size=3, stride=2, padding=1)
self.block_list = []
for block in range(len(depth)):
shortcut = False
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(num_channels=num_channels[block] if i == 0 else num_filters[block] * 4,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name))
self.block_list.append(bottleneck_block)
shortcut = True
self.pool2d_avg = AdaptiveAvgPool2d(1)
self.pool2d_avg_channels = num_channels[-1] * 2
stdv = 1.0 / math.sqrt(self.pool2d_avg_channels * 1.0)
self.out = Linear(self.pool2d_avg_channels,
class_dim,
weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv), name="fc_0.w_0"),
bias_attr=ParamAttr(name="fc_0.b_0"))
if load_checkpoint is not None:
model_dict = paddle.load(load_checkpoint)[0]
self.set_dict(model_dict)
print("load custom checkpoint success")
else:
exe.run(startup_prog)
# trainable
for param in context_prog.global_block().iter_parameters():
param.trainable = trainable
return inputs, outputs, context_prog
def save_inference_model(self,
dirname,
model_filename=None,
params_filename=None,
combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = fluid.CPUPlace()
exe = fluid.Executor(place)
program, feeded_var_names, target_vars = fluid.io.load_inference_model(
dirname=self.default_pretrained_model_path, executor=exe)
fluid.io.save_inference_model(
dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
checkpoint = os.path.join(self.directory, 'resnet50_vd_10w.pdparams')
if not os.path.exists(checkpoint):
os.system(
'wget https://paddlehub.bj.bcebos.com/dygraph/image_classification/resnet50_vd_10w.pdparams -O ' +
checkpoint)
model_dict = paddle.load(checkpoint)[0]
self.set_dict(model_dict)
print("load pretrained checkpoint success")
def forward(self, inputs: paddle.Tensor):
y = self.conv1_1(inputs)
y = self.conv1_2(y)
y = self.conv1_3(y)
y = self.pool2d_max(y)
for block in self.block_list:
y = block(y)
y = self.pool2d_avg(y)
y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels])
y = self.out(y)
return y
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import base64
import cv2
import os
import numpy as np
def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
def softmax(x):
orig_shape = x.shape
if len(x.shape) > 1:
tmp = np.max(x, axis=1)
x -= tmp.reshape((x.shape[0], 1))
x = np.exp(x)
tmp = np.sum(x, axis=1)
x /= tmp.reshape((x.shape[0], 1))
else:
tmp = np.max(x)
x -= tmp
x = np.exp(x)
tmp = np.sum(x)
x /= tmp
return x
def postprocess(data_out, label_list, top_k):
"""
Postprocess output of network, one image at a time.
Args:
data_out (numpy.ndarray): output data of network.
label_list (list): list of label.
top_k (int): Return top k results.
"""
output = []
for result in data_out:
result_i = softmax(result)
output_i = {}
indexs = np.argsort(result_i)[::-1][0:top_k]
for index in indexs:
label = label_list[index].split(',')[0]
output_i[label] = float(result_i[index])
output.append(output_i)
return output
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
__all__ = [
"ResNet", "ResNet50_vd", "ResNet101_vd", "ResNet152_vd", "ResNet200_vd"
]
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class ResNet():
def __init__(self, layers=50, is_3x3=False):
self.params = train_parameters
self.layers = layers
self.is_3x3 = is_3x3
def net(self, input):
is_3x3 = self.is_3x3
layers = self.layers
supported_layers = [50, 101, 152, 200]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
num_filters = [64, 128, 256, 512]
if is_3x3 == False:
conv = self.conv_bn_layer(
input=input,
num_filters=64,
filter_size=7,
stride=2,
act='relu')
else:
conv = self.conv_bn_layer(
input=input,
num_filters=32,
filter_size=3,
stride=2,
act='relu',
name='conv1_1')
conv = self.conv_bn_layer(
input=conv,
num_filters=32,
filter_size=3,
stride=1,
act='relu',
name='conv1_2')
conv = self.conv_bn_layer(
input=conv,
num_filters=64,
filter_size=3,
stride=1,
act='relu',
name='conv1_3')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
for block in range(len(depth)):
for i in range(depth[block]):
if layers in [101, 152, 200] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
if_first=block == 0,
name=conv_name)
pool = fluid.layers.pool2d(
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
return pool
def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def conv_bn_layer_new(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
pool = fluid.layers.pool2d(
input=input,
pool_size=2,
pool_stride=2,
pool_padding=0,
pool_type='avg')
conv = fluid.layers.conv2d(
input=pool,
num_filters=num_filters,
filter_size=filter_size,
stride=1,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def shortcut(self, input, ch_out, stride, name, if_first=False):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1:
if if_first:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return self.conv_bn_layer_new(
input, ch_out, 1, stride, name=name)
else:
return input
def bottleneck_block(self, input, num_filters, stride, name, if_first):
conv0 = self.conv_bn_layer(
input=input,
num_filters=num_filters,
filter_size=1,
act='relu',
name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu',
name=name + "_branch2b")
conv2 = self.conv_bn_layer(
input=conv1,
num_filters=num_filters * 4,
filter_size=1,
act=None,
name=name + "_branch2c")
short = self.shortcut(
input,
num_filters * 4,
stride,
if_first=if_first,
name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
def ResNet50_vd():
model = ResNet(layers=50, is_3x3=True)
return model
def ResNet101_vd():
model = ResNet(layers=101, is_3x3=True)
return model
def ResNet152_vd():
model = ResNet(layers=152, is_3x3=True)
return model
def ResNet200_vd():
model = ResNet(layers=200, is_3x3=True)
return model
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册