未验证 提交 cdfcee74 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update se_resnet18_vd_imagenet_model (#2044)

* add se_resnet18_vd_imagenet_model

* update se_resnet18_vd_imagenet_model
上级 0514cf31
## 命令行预测
# se_resnet18_vd_imagenet
```
hub run se_resnet18_vd_imagenet --input_path "/PATH/TO/IMAGE"
```
|模型名称|se_resnet18_vd_imagenet|
| :--- | :---: |
|类别|图像-图像分类|
|网络|SE-ResNet|
|数据集|ImageNet-2012|
|是否支持Fine-tuning|否|
|模型大小|48MB|
|最新更新日期|-|
|数据指标|-|
## API
```python
def get_expected_image_width()
```
## 一、模型基本信息
返回预处理的图片宽度,也就是224。
```python
def get_expected_image_height()
```
返回预处理的图片高度,也就是224。
- ### 模型介绍
```python
def get_pretrained_images_mean()
```
- Squeeze-and-Excitation Networks是由Momenta在2017年提出的一种图像分类结构。该结构通过对特征通道间的相关性进行建模,把重要的特征进行强化来提升准确率。SE_ResNet基于ResNet模型添加了SE Block。该PaddleHub Module结构为SE_ResNet18,基于ImageNet-2012数据集训练,接受输入图片大小为224 x 224 x 3,支持直接通过命令行或者Python接口进行预测。
返回预处理的图片均值,也就是 \[0.485, 0.456, 0.406\]
```python
def get_pretrained_images_std()
```
## 二、安装
返回预处理的图片标准差,也就是 \[0.229, 0.224, 0.225\]
- ### 1、环境依赖
- paddlepaddle >= 1.6.2
```python
def context(trainable=True, pretrained=True)
```
- paddlehub >= 1.6.0 | [如何安装paddlehub](../../../../docs/docs_ch/get_start/installation.rst)
**参数**
* trainable (bool): 计算图的参数是否为可训练的;
* pretrained (bool): 是否加载默认的预训练模型。
- ### 2、安装
**返回**
- ```shell
$ hub install se_resnet18_vd_imagenet
```
- 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md)
* inputs (dict): 计算图的输入,key 为 'image', value 为图片的张量;
* outputs (dict): 计算图的输出,key 为 'classification' 和 'feature_map',其相应的值为:
* classification (paddle.fluid.framework.Variable): 分类结果,也就是全连接层的输出;
* feature\_map (paddle.fluid.framework.Variable): 特征匹配,全连接层前面的那个张量。
* context\_prog(fluid.Program): 计算图,用于迁移学习。
## 三、模型API预测
```python
def classification(images=None,
paths=None,
batch_size=1,
use_gpu=False,
top_k=1):
```
- ### 1、命令行预测
**参数**
- ```shell
$ hub run se_resnet18_vd_imagenet --input_path "/PATH/TO/IMAGE"
```
- 通过命令行方式实现分类模型的调用,更多请见 [PaddleHub命令行指令](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
* images (list\[numpy.ndarray\]): 图片数据,每一个图片数据的shape 均为 \[H, W, C\],颜色空间为 BGR;
* paths (list\[str\]): 图片的路径;
* batch\_size (int): batch 的大小;
* use\_gpu (bool): 是否使用 GPU 来预测;
* top\_k (int): 返回预测结果的前 k 个。
- ### 2、预测代码示例
**返回**
- ```python
import paddlehub as hub
import cv2
res (list\[dict\]): 分类结果,列表的每一个元素均为字典,其中 key 为识别动物的类别,value为置信度。
classifier = hub.Module(name="se_resnet18_vd_imagenet")
result = classifier.classification(images=[cv2.imread('/PATH/TO/IMAGE')])
# or
# result = classifier.classification(paths=['/PATH/TO/IMAGE'])
```
```python
def save_inference_model(dirname,
model_filename=None,
params_filename=None,
combined=True)
```
- ### 3、API
将模型保存到指定路径。
**参数**
- ```python
def classification(images=None,
paths=None,
batch_size=1,
use_gpu=False,
top_k=1):
```
- 分类接口API。
- **参数**
* dirname: 存在模型的目录名称
* model\_filename: 模型文件名称,默认为\_\_model\_\_
* params\_filename: 参数文件名称,默认为\_\_params\_\_(仅当`combined`为True时生效)
* combined: 是否将参数保存到统一的一个文件中
- images (list\[numpy.ndarray\]): 图片数据,每一个图片数据的shape 均为 \[H, W, C\],颜色空间为 BGR; <br/>
- paths (list\[str\]): 图片的路径; <br/>
- batch\_size (int): batch 的大小;<br/>
- use\_gpu (bool): 是否使用 GPU;**若使用GPU,请先设置CUDA_VISIBLE_DEVICES环境变量** <br/>
- top\_k (int): 返回预测结果的前 k 个。
## 预测代码示例
- **返回**
```python
import paddlehub as hub
import cv2
- res (list\[dict\]): 分类结果,列表的每一个元素均为字典,其中 key 为识别的菜品类别,value为置信度。
classifier = hub.Module(name="se_resnet18_vd_imagenet")
result = classifier.classification(images=[cv2.imread('/PATH/TO/IMAGE')])
# or
# result = classifier.classification(paths=['/PATH/TO/IMAGE'])
```
## 服务部署
## 四、服务部署
PaddleHub Serving可以部署一个在线图像识别服务。
- PaddleHub Serving可以部署一个图像识别的在线服务。
## 第一步:启动PaddleHub Serving
- ### 第一步:启动PaddleHub Serving
运行启动命令:
```shell
$ hub serving start -m se_resnet18_vd_imagenet
```
- 运行启动命令:
- ```shell
$ hub serving start -m se_resnet18_vd_imagenet
```
这样就完成了一个在线图像识别服务化API的部署,默认端口号为8866。
- 这样就完成了一个图像识别的在线服务的部署,默认端口号为8866。
**NOTE:** 如使用GPU预测,则需要在启动服务之前,请设置CUDA\_VISIBLE\_DEVICES环境变量,否则不用设置。
- **NOTE:** 如使用GPU预测,则需要在启动服务之前,请设置CUDA\_VISIBLE\_DEVICES环境变量,否则不用设置。
## 第二步:发送预测请求
- ### 第二步:发送预测请求
配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果
- 配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果
```python
import requests
import json
import cv2
import base64
- ```python
import requests
import json
import cv2
import base64
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8')
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8')
# 发送HTTP请求
data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/se_resnet18_vd_imagenet"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
# 打印预测结果
print(r.json()["results"])
```
# 发送HTTP请求
data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/se_resnet18_vd_imagenet"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
# 打印预测结果
print(r.json()["results"])
```
## 五、更新历史
### 查看代码
* 1.0.0
https://github.com/PaddlePaddle/PaddleClas
初始发布
### 依赖
* 1.1.0
paddlepaddle >= 1.6.2
移除 Fluid API
paddlehub >= 1.6.0
- ```shell
$ hub install se_resnet18_vd_imagenet==1.1.0
```
# se_resnet18_vd_imagenet
|Module Name|se_resnet18_vd_imagenet|
| :--- | :---: |
|Category|image classification|
|Network|SE-ResNet|
|Dataset|ImageNet-2012|
|Fine-tuning supported or not|No|
|Module Size|48MB|
|Latest update date|-|
|Data indicators|-|
## I.Basic Information
- ### Module Introduction
- Res2Net is an improvement on ResNet, which can improve performance without increasing computation. This module is based on Res2Net, trained on ImageNet-2012, and can predict an image of size 224*224*3.
## II.Installation
- ### 1、Environmental Dependence
- paddlepaddle >= 1.6.2
- paddlehub >= 1.6.0 | [How to install PaddleHub](../../../../docs/docs_en/get_start/installation.rst)
- ### 2、Installation
- ```shell
$ hub install se_resnet18_vd_imagenet
```
- In case of any problems during installation, please refer to: [Windows_Quickstart](../../../../docs/docs_en/get_start/windows_quickstart.md) | [Linux_Quickstart](../../../../docs/docs_en/get_start/linux_quickstart.md) | [Mac_Quickstart](../../../../docs/docs_en/get_start/mac_quickstart.md)
## III.Module API Prediction
- ### 1、Command line Prediction
- ```shell
$ hub run se_resnet18_vd_imagenet --input_path "/PATH/TO/IMAGE"
```
- If you want to call the Hub module through the command line, please refer to: [PaddleHub Command Line Instruction](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
- ### 2、Prediction Code Example
- ```python
import paddlehub as hub
import cv2
classifier = hub.Module(name="se_resnet18_vd_imagenet")
result = classifier.classification(images=[cv2.imread('/PATH/TO/IMAGE')])
# or
# result = classifier.classification(paths=['/PATH/TO/IMAGE'])
```
- ### 3、API
- ```python
def classification(images=None,
paths=None,
batch_size=1,
use_gpu=False,
top_k=1):
```
- classification API.
- **Parameters**
- images (list\[numpy.ndarray\]): image data, ndarray.shape is in the format [H, W, C], BGR;
- paths (list[str]): image path;
- batch_size (int): the size of batch;
- use_gpu (bool): use GPU or not; **set the CUDA_VISIBLE_DEVICES environment variable first if you are using GPU**
- top\_k (int): return the first k results
- **Return**
- res (list\[dict\]): classication results, each element in the list is dict, key is the label name, and value is the corresponding probability
## IV.Server Deployment
- PaddleHub Serving can deploy an online service of image classification.
- ### Step 1: Start PaddleHub Serving
- Run the startup command:
- ```shell
$ hub serving start -m se_resnet18_vd_imagenet
```
- The servitization API is now deployed and the default port number is 8866.
- **NOTE:** If GPU is used for prediction, set CUDA_VISIBLE_DEVICES environment variable before the service, otherwise it need not be set.
- ### Step 2: Send a predictive request
- With a configured server, use the following lines of code to send the prediction request and obtain the result
- ```python
import requests
import json
import cv2
import base64
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8')
# Send an HTTP request
data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/se_resnet18_vd_imagenet"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
# print prediction results
print(r.json()["results"])
```
## V.Release Note
* 1.0.0
First release
* 1.1.0
Remove Fluid API
- ```shell
$ hub install se_resnet18_vd_imagenet==1.1.0
```
# coding=utf-8
import os
import time
from collections import OrderedDict
import cv2
import numpy as np
from PIL import Image
......
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
import ast
import argparse
import ast
import os
import numpy as np
import paddle.fluid as fluid
import paddlehub as hub
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
from paddlehub.module.module import moduleinfo, runnable, serving
from paddlehub.common.paddle_helper import add_vars_prefix
from se_resnet18_vd_imagenet.processor import postprocess, base64_to_cv2
from se_resnet18_vd_imagenet.data_feed import reader
from se_resnet18_vd_imagenet.se_resnet import SE_ResNet18_vd
@moduleinfo(
name="se_resnet18_vd_imagenet",
type="CV/image_classification",
author="paddlepaddle",
author_email="paddle-dev@baidu.com",
summary="SE_ResNet18_vd is a image classfication model, this module is trained with imagenet datasets.",
version="1.0.0")
class SEResNet18vdImageNet(hub.Module):
def _initialize(self):
self.default_pretrained_model_path = os.path.join(self.directory, "se_resnet18_vd_imagenet_model")
from paddle.inference import Config
from paddle.inference import create_predictor
from .data_feed import reader
from .processor import base64_to_cv2
from .processor import postprocess
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
from paddlehub.module.module import serving
@moduleinfo(name="se_resnet18_vd_imagenet",
type="CV/image_classification",
author="paddlepaddle",
author_email="paddle-dev@baidu.com",
summary="SE_ResNet18_vd is a image classfication model, this module is trained with imagenet datasets.",
version="1.1.0")
class SEResNet18vdImageNet:
def __init__(self):
self.default_pretrained_model_path = os.path.join(self.directory, "se_resnet18_vd_imagenet_model", "model")
label_file = os.path.join(self.directory, "label_list.txt")
with open(label_file, 'r', encoding='utf-8') as file:
self.label_list = file.read().split("\n")[:-1]
......@@ -51,10 +50,12 @@ class SEResNet18vdImageNet(hub.Module):
"""
predictor config setting
"""
cpu_config = AnalysisConfig(self.default_pretrained_model_path)
model = self.default_pretrained_model_path + '.pdmodel'
params = self.default_pretrained_model_path + '.pdiparams'
cpu_config = Config(model, params)
cpu_config.disable_glog_info()
cpu_config.disable_gpu()
self.cpu_predictor = create_paddle_predictor(cpu_config)
self.cpu_predictor = create_predictor(cpu_config)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
......@@ -63,58 +64,10 @@ class SEResNet18vdImageNet(hub.Module):
except:
use_gpu = False
if use_gpu:
gpu_config = AnalysisConfig(self.default_pretrained_model_path)
gpu_config = Config(model, params)
gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)
self.gpu_predictor = create_paddle_predictor(gpu_config)
def context(self, trainable=True, pretrained=True):
"""context for transfer learning.
Args:
trainable (bool): Set parameters in program to be trainable.
pretrained (bool) : Whether to load pretrained model.
Returns:
inputs (dict): key is 'image', corresponding vaule is image tensor.
outputs (dict): key is :
'classification', corresponding value is the result of classification.
'feature_map', corresponding value is the result of the layer before the fully connected layer.
context_prog (fluid.Program): program for transfer learning.
"""
context_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(context_prog, startup_prog):
with fluid.unique_name.guard():
image = fluid.layers.data(name="image", shape=[3, 224, 224], dtype="float32")
resnet_vd = SE_ResNet18_vd()
output, feature_map = resnet_vd.net(input=image, class_dim=len(self.label_list))
name_prefix = '@HUB_{}@'.format(self.name)
inputs = {'image': name_prefix + image.name}
outputs = {'classification': name_prefix + output.name, 'feature_map': name_prefix + feature_map.name}
add_vars_prefix(context_prog, name_prefix)
add_vars_prefix(startup_prog, name_prefix)
global_vars = context_prog.global_block().vars
inputs = {key: global_vars[value] for key, value in inputs.items()}
outputs = {key: global_vars[value] for key, value in outputs.items()}
place = fluid.CPUPlace()
exe = fluid.Executor(place)
# pretrained
if pretrained:
def _if_exist(var):
b = os.path.exists(os.path.join(self.default_pretrained_model_path, var.name))
return b
fluid.io.load_vars(exe, self.default_pretrained_model_path, context_prog, predicate=_if_exist)
else:
exe.run(startup_prog)
# trainable
for param in context_prog.global_block().iter_parameters():
param.trainable = trainable
return inputs, outputs, context_prog
self.gpu_predictor = create_predictor(gpu_config)
def classification(self, images=None, paths=None, batch_size=1, use_gpu=False, top_k=1):
"""
......@@ -136,7 +89,7 @@ class SEResNet18vdImageNet(hub.Module):
int(_places[0])
except:
raise RuntimeError(
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id."
"Attempt to use GPU for prediction, but environment variable CUDA_VISIBLE_DEVICES was not set correctly."
)
if not self.predictor_set:
......@@ -161,32 +114,19 @@ class SEResNet18vdImageNet(hub.Module):
pass
# feed batch image
batch_image = np.array([data['image'] for data in batch_data])
batch_image = PaddleTensor(batch_image.copy())
predictor_output = self.gpu_predictor.run([batch_image]) if use_gpu else self.cpu_predictor.run(
[batch_image])
out = postprocess(data_out=predictor_output[0].as_ndarray(), label_list=self.label_list, top_k=top_k)
predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
input_handle.copy_from_cpu(batch_image.copy())
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
out = postprocess(data_out=output_handle.copy_to_cpu(), label_list=self.label_list, top_k=top_k)
res += out
return res
def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = fluid.CPUPlace()
exe = fluid.Executor(place)
program, feeded_var_names, target_vars = fluid.io.load_inference_model(
dirname=self.default_pretrained_model_path, executor=exe)
fluid.io.save_inference_model(
dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
@serving
def serving_method(self, images, **kwargs):
"""
......@@ -201,11 +141,10 @@ class SEResNet18vdImageNet(hub.Module):
"""
Run as a command.
"""
self.parser = argparse.ArgumentParser(
description="Run the {} module.".format(self.name),
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
title="Config options", description="Run configuration for controlling module behavior, not required.")
......@@ -219,8 +158,10 @@ class SEResNet18vdImageNet(hub.Module):
"""
Add the command config options.
"""
self.arg_config_group.add_argument(
'--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU or not.")
self.arg_config_group.add_argument('--use_gpu',
type=ast.literal_eval,
default=False,
help="whether use GPU or not.")
self.arg_config_group.add_argument('--batch_size', type=ast.literal_eval, default=1, help="batch size.")
self.arg_config_group.add_argument('--top_k', type=ast.literal_eval, default=1, help="Return top k results.")
......
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import base64
import cv2
import os
import cv2
import numpy as np
......
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
__all__ = ["ResNet", "ResNet18_vd", "ResNet34_vd", "ResNet50_vd", "ResNet101_vd", "ResNet152_vd", "ResNet200_vd"]
class ResNet():
def __init__(self, layers=50, is_3x3=False):
self.layers = layers
self.is_3x3 = is_3x3
def net(self, input, class_dim=1000):
is_3x3 = self.is_3x3
layers = self.layers
supported_layers = [18, 34, 50, 101, 152, 200]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
num_filters = [64, 128, 256, 512]
if is_3x3 == False:
conv = self.conv_bn_layer(input=input, num_filters=64, filter_size=7, stride=2, act='relu')
else:
conv = self.conv_bn_layer(input=input, num_filters=32, filter_size=3, stride=2, act='relu', name='conv1_1')
conv = self.conv_bn_layer(input=conv, num_filters=32, filter_size=3, stride=1, act='relu', name='conv1_2')
conv = self.conv_bn_layer(input=conv, num_filters=64, filter_size=3, stride=1, act='relu', name='conv1_3')
conv = fluid.layers.pool2d(input=conv, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
if layers >= 50:
for block in range(len(depth)):
for i in range(depth[block]):
if layers in [101, 152, 200] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
if_first=block == i == 0,
name=conv_name)
else:
for block in range(len(depth)):
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.basic_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
if_first=block == i == 0,
name=conv_name)
pool = fluid.layers.pool2d(input=conv, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(
input=pool,
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(initializer=fluid.initializer.Uniform(-stdv, stdv)))
return out, pool
def conv_bn_layer(self, input, num_filters, filter_size, stride=1, groups=1, act=None, name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def conv_bn_layer_new(self, input, num_filters, filter_size, stride=1, groups=1, act=None, name=None):
pool = fluid.layers.pool2d(
input=input, pool_size=2, pool_stride=2, pool_padding=0, pool_type='avg', ceil_mode=True)
conv = fluid.layers.conv2d(
input=pool,
num_filters=num_filters,
filter_size=filter_size,
stride=1,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def shortcut(self, input, ch_out, stride, name, if_first=False):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1:
if if_first:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return self.conv_bn_layer_new(input, ch_out, 1, stride, name=name)
elif if_first:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return input
def bottleneck_block(self, input, num_filters, stride, name, if_first):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=1, act='relu', name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0, num_filters=num_filters, filter_size=3, stride=stride, act='relu', name=name + "_branch2b")
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters * 4, filter_size=1, act=None, name=name + "_branch2c")
short = self.shortcut(input, num_filters * 4, stride, if_first=if_first, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
def basic_block(self, input, num_filters, stride, name, if_first):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=3, act='relu', stride=stride, name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0, num_filters=num_filters, filter_size=3, act=None, name=name + "_branch2b")
short = self.shortcut(input, num_filters, stride, if_first=if_first, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
def ResNet18_vd():
model = ResNet(layers=18, is_3x3=True)
return model
def ResNet34_vd():
model = ResNet(layers=34, is_3x3=True)
return model
def ResNet50_vd():
model = ResNet(layers=50, is_3x3=True)
return model
def ResNet101_vd():
model = ResNet(layers=101, is_3x3=True)
return model
def ResNet152_vd():
model = ResNet(layers=152, is_3x3=True)
return model
def ResNet200_vd():
model = ResNet(layers=200, is_3x3=True)
return model
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
import math
__all__ = [
"SE_ResNet_vd", "SE_ResNet18_vd", "SE_ResNet34_vd", "SE_ResNet50_vd", "SE_ResNet101_vd", "SE_ResNet152_vd",
"SE_ResNet200_vd"
]
class SE_ResNet_vd():
def __init__(self, layers=50, is_3x3=False):
self.layers = layers
self.is_3x3 = is_3x3
def net(self, input, class_dim=1000):
is_3x3 = self.is_3x3
layers = self.layers
supported_layers = [18, 34, 50, 101, 152, 200]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
num_filters = [64, 128, 256, 512]
reduction_ratio = 16
if is_3x3 == False:
conv = self.conv_bn_layer(input=input, num_filters=64, filter_size=7, stride=2, act='relu')
else:
conv = self.conv_bn_layer(input=input, num_filters=32, filter_size=3, stride=2, act='relu', name='conv1_1')
conv = self.conv_bn_layer(input=conv, num_filters=32, filter_size=3, stride=1, act='relu', name='conv1_2')
conv = self.conv_bn_layer(input=conv, num_filters=64, filter_size=3, stride=1, act='relu', name='conv1_3')
conv = fluid.layers.pool2d(input=conv, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
if layers >= 50:
for block in range(len(depth)):
for i in range(depth[block]):
if layers in [101, 152, 200] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
if_first=block == i == 0,
reduction_ratio=reduction_ratio,
name=conv_name)
else:
for block in range(len(depth)):
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.basic_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
if_first=block == i == 0,
reduction_ratio=reduction_ratio,
name=conv_name)
pool = fluid.layers.pool2d(input=conv, pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(
input=pool,
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv), name='fc6_weights'),
bias_attr=ParamAttr(name='fc6_offset'))
return out, pool
def conv_bn_layer(self, input, num_filters, filter_size, stride=1, groups=1, act=None, name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def conv_bn_layer_new(self, input, num_filters, filter_size, stride=1, groups=1, act=None, name=None):
pool = fluid.layers.pool2d(
input=input, pool_size=2, pool_stride=2, pool_padding=0, pool_type='avg', ceil_mode=True)
conv = fluid.layers.conv2d(
input=pool,
num_filters=num_filters,
filter_size=filter_size,
stride=1,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def shortcut(self, input, ch_out, stride, name, if_first=False):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1:
if if_first:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return self.conv_bn_layer_new(input, ch_out, 1, stride, name=name)
elif if_first:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return input
def bottleneck_block(self, input, num_filters, stride, name, if_first, reduction_ratio):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=1, act='relu', name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0, num_filters=num_filters, filter_size=3, stride=stride, act='relu', name=name + "_branch2b")
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters * 4, filter_size=1, act=None, name=name + "_branch2c")
scale = self.squeeze_excitation(
input=conv2, num_channels=num_filters * 4, reduction_ratio=reduction_ratio, name='fc_' + name)
short = self.shortcut(input, num_filters * 4, stride, if_first=if_first, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=scale, act='relu')
def basic_block(self, input, num_filters, stride, name, if_first, reduction_ratio):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=3, act='relu', stride=stride, name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0, num_filters=num_filters, filter_size=3, act=None, name=name + "_branch2b")
scale = self.squeeze_excitation(
input=conv1, num_channels=num_filters, reduction_ratio=reduction_ratio, name='fc_' + name)
short = self.shortcut(input, num_filters, stride, if_first=if_first, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=scale, act='relu')
def squeeze_excitation(self, input, num_channels, reduction_ratio, name=None):
pool = fluid.layers.pool2d(input=input, pool_size=0, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
squeeze = fluid.layers.fc(
input=pool,
size=num_channels // reduction_ratio,
act='relu',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv), name=name + '_sqz_weights'),
bias_attr=ParamAttr(name=name + '_sqz_offset'))
stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0)
excitation = fluid.layers.fc(
input=squeeze,
size=num_channels,
act='sigmoid',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv), name=name + '_exc_weights'),
bias_attr=ParamAttr(name=name + '_exc_offset'))
scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)
return scale
def SE_ResNet18_vd():
model = SE_ResNet_vd(layers=18, is_3x3=True)
return model
def SE_ResNet34_vd():
model = SE_ResNet_vd(layers=34, is_3x3=True)
return model
def SE_ResNet50_vd():
model = SE_ResNet_vd(layers=50, is_3x3=True)
return model
def SE_ResNet101_vd():
model = SE_ResNet_vd(layers=101, is_3x3=True)
return model
def SE_ResNet152_vd():
model = SE_ResNet_vd(layers=152, is_3x3=True)
return model
def SE_ResNet200_vd():
model = SE_ResNet_vd(layers=200, is_3x3=True)
return model
import os
import shutil
import unittest
import cv2
import requests
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/brFsZ7qszSY/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8OHx8ZG9nfGVufDB8fHx8MTY2MzA1ODQ1MQ&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="se_resnet18_vd_imagenet")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('inference')
def test_classification1(self):
results = self.module.classification(paths=['tests/test.jpg'])
data = results[0]
self.assertTrue('Pembroke' in data)
self.assertTrue(data['Pembroke'] > 0.5)
def test_classification2(self):
results = self.module.classification(images=[cv2.imread('tests/test.jpg')])
data = results[0]
self.assertTrue('Pembroke' in data)
self.assertTrue(data['Pembroke'] > 0.5)
def test_classification3(self):
results = self.module.classification(images=[cv2.imread('tests/test.jpg')], use_gpu=True)
data = results[0]
self.assertTrue('Pembroke' in data)
self.assertTrue(data['Pembroke'] > 0.5)
def test_classification4(self):
self.assertRaises(AssertionError, self.module.classification, paths=['no.jpg'])
def test_classification5(self):
self.assertRaises(TypeError, self.module.classification, images=['tests/test.jpg'])
def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
self.assertTrue(os.path.exists('./inference/model.pdmodel'))
self.assertTrue(os.path.exists('./inference/model.pdiparams'))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册