Unverified commit cdfcee74, authored by jm_12138, committed by GitHub

update se_resnet18_vd_imagenet_model (#2044)

* add se_resnet18_vd_imagenet_model

* update se_resnet18_vd_imagenet_model
Parent 0514cf31
# se_resnet18_vd_imagenet

|Model Name|se_resnet18_vd_imagenet|
| :--- | :---: |
|Category|Image - Image Classification|
|Network|SE-ResNet|
|Dataset|ImageNet-2012|
|Fine-tuning supported or not|No|
|Model Size|48MB|
|Latest update date|-|
|Data indicators|-|


## I. Basic Information

- ### Module Introduction

  - Squeeze-and-Excitation Networks (SENet) is an image classification architecture proposed by Momenta in 2017. It models the dependencies between feature channels and strengthens the most informative ones to improve accuracy. SE_ResNet adds SE blocks to the ResNet backbone. This PaddleHub Module uses an SE_ResNet18 network trained on ImageNet-2012, accepts input images of size 224 x 224 x 3 (normalized with mean \[0.485, 0.456, 0.406\] and std \[0.229, 0.224, 0.225\]; a preprocessing sketch follows the update history below), and supports prediction directly from the command line or the Python API.
  - The network implementation references [PaddleClas](https://github.com/PaddlePaddle/PaddleClas).


## II. Installation

- ### 1. Environmental Dependence

  - paddlepaddle >= 1.6.2

  - paddlehub >= 1.6.0 | [How to install PaddleHub](../../../../docs/docs_ch/get_start/installation.rst)

- ### 2. Installation

  - ```shell
    $ hub install se_resnet18_vd_imagenet
    ```
  - In case of any problems during installation, please refer to: [Windows Quickstart](../../../../docs/docs_ch/get_start/windows_quickstart.md) | [Linux Quickstart](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [MacOS Quickstart](../../../../docs/docs_ch/get_start/mac_quickstart.md)


## III. Module API Prediction

- ### 1. Command line Prediction

  - ```shell
    $ hub run se_resnet18_vd_imagenet --input_path "/PATH/TO/IMAGE"
    ```
  - For more information about the command line, please refer to: [PaddleHub Command Line Instruction](../../../../docs/docs_ch/tutorial/cmd_usage.rst)

- ### 2. Prediction Code Example

  - ```python
    import paddlehub as hub
    import cv2

    classifier = hub.Module(name="se_resnet18_vd_imagenet")
    result = classifier.classification(images=[cv2.imread('/PATH/TO/IMAGE')])
    # or
    # result = classifier.classification(paths=['/PATH/TO/IMAGE'])
    ```

- ### 3. API

  - ```python
    def classification(images=None,
                       paths=None,
                       batch_size=1,
                       use_gpu=False,
                       top_k=1):
    ```
  - Classification API.

  - **Parameters**

    - images (list\[numpy.ndarray\]): image data, each with shape \[H, W, C\] in BGR color space; <br/>
    - paths (list\[str\]): image paths; <br/>
    - batch\_size (int): batch size; <br/>
    - use\_gpu (bool): whether to use GPU; **if so, set the CUDA_VISIBLE_DEVICES environment variable first** <br/>
    - top\_k (int): return the top k prediction results.

  - **Return**

    - res (list\[dict\]): classification results; each element of the list is a dict whose keys are the predicted categories and whose values are the corresponding confidences.


## IV. Server Deployment

- PaddleHub Serving can deploy an online image classification service.

- ### Step 1: Start PaddleHub Serving

  - Run the startup command:

  - ```shell
    $ hub serving start -m se_resnet18_vd_imagenet
    ```

  - This deploys the online image classification service API; the default port number is 8866.

  - **NOTE:** If GPU is used for prediction, set the CUDA\_VISIBLE\_DEVICES environment variable before starting the service; otherwise it does not need to be set.

- ### Step 2: Send a prediction request

  - With the server configured, the following few lines of code send a prediction request and obtain the result:

  - ```python
    import requests
    import json
    import cv2
    import base64


    def cv2_to_base64(image):
        data = cv2.imencode('.jpg', image)[1]
        return base64.b64encode(data.tostring()).decode('utf8')


    # Send an HTTP request
    data = {'images': [cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
    headers = {"Content-type": "application/json"}
    url = "http://127.0.0.1:8866/predict/se_resnet18_vd_imagenet"
    r = requests.post(url=url, headers=headers, data=json.dumps(data))

    # Print the prediction results
    print(r.json()["results"])
    ```


## V. Update History

* 1.0.0

  Initial release

* 1.1.0

  Remove the Fluid APIs

  - ```shell
    $ hub install se_resnet18_vd_imagenet==1.1.0
    ```
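
The introduction above notes that inputs are resized to 224 x 224 and normalized with mean [0.485, 0.456, 0.406] and std [0.229, 0.224, 0.225]. The module's reader applies this preprocessing internally, so the NumPy sketch below is illustrative only; the BGR-to-RGB conversion and the bilinear interpolation are assumptions about details the documentation does not spell out.

```python
import cv2
import numpy as np

IMG_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMG_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess(path):
    """Illustrative 224 x 224 ImageNet-style preprocessing (the module handles this itself)."""
    img = cv2.imread(path)                                             # BGR, HWC, uint8
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)                         # assumption: mean/std are the usual RGB constants
    img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)  # assumption: bilinear resize
    img = img.astype(np.float32) / 255.0
    img = (img - IMG_MEAN) / IMG_STD
    return img.transpose((2, 0, 1))                                    # HWC -> CHW


# '/PATH/TO/IMAGE' is a placeholder, as in the examples above.
batch = np.expand_dims(preprocess('/PATH/TO/IMAGE'), axis=0)           # shape (1, 3, 224, 224)
```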
# se_resnet18_vd_imagenet
|Module Name|se_resnet18_vd_imagenet|
| :--- | :---: |
|Category|image classification|
|Network|SE-ResNet|
|Dataset|ImageNet-2012|
|Fine-tuning supported or not|No|
|Module Size|48MB|
|Latest update date|-|
|Data indicators|-|
## I.Basic Information
- ### Module Introduction
- Squeeze-and-Excitation Networks (SENet) is an image classification architecture proposed by Momenta in 2017. It models the dependencies between feature channels and strengthens the most informative ones to improve accuracy; SE_ResNet adds SE blocks to the ResNet backbone. This module is based on SE_ResNet18_vd, trained on ImageNet-2012, and can predict an image of size 224 x 224 x 3. A small sketch of the SE recalibration follows.
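
Below is a minimal NumPy sketch of the squeeze-and-excitation recalibration described above. The reduction ratio of 16, the random weights, and the omission of bias terms are illustrative assumptions; the module's actual network definition appears further down in this commit.

```python
import numpy as np


def se_block(feature_map, w_squeeze, w_excite):
    """Squeeze-and-excitation recalibration of an NCHW feature map (bias terms omitted)."""
    squeezed = feature_map.mean(axis=(2, 3))                  # squeeze: global average pool -> (N, C)
    hidden = np.maximum(squeezed @ w_squeeze, 0.0)            # reduce + ReLU -> (N, C // ratio)
    excitation = 1.0 / (1.0 + np.exp(-(hidden @ w_excite)))   # expand + sigmoid -> (N, C)
    return feature_map * excitation[:, :, None, None]         # excite: channel-wise rescaling


# Illustrative shapes: 64 channels, reduction ratio 16, random weights.
channels, ratio = 64, 16
rng = np.random.default_rng(0)
x = rng.standard_normal((1, channels, 56, 56)).astype(np.float32)
w1 = 0.1 * rng.standard_normal((channels, channels // ratio)).astype(np.float32)
w2 = 0.1 * rng.standard_normal((channels // ratio, channels)).astype(np.float32)
print(se_block(x, w1, w2).shape)  # (1, 64, 56, 56)
```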
## II.Installation
- ### 1、Environmental Dependence
- paddlepaddle >= 1.6.2
- paddlehub >= 1.6.0 | [How to install PaddleHub](../../../../docs/docs_en/get_start/installation.rst)
- ### 2、Installation
- ```shell
$ hub install se_resnet18_vd_imagenet
```
- In case of any problems during installation, please refer to: [Windows_Quickstart](../../../../docs/docs_en/get_start/windows_quickstart.md) | [Linux_Quickstart](../../../../docs/docs_en/get_start/linux_quickstart.md) | [Mac_Quickstart](../../../../docs/docs_en/get_start/mac_quickstart.md)
## III.Module API Prediction
- ### 1、Command line Prediction
- ```shell
$ hub run se_resnet18_vd_imagenet --input_path "/PATH/TO/IMAGE"
```
- If you want to call the Hub module through the command line, please refer to: [PaddleHub Command Line Instruction](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
- ### 2、Prediction Code Example
- ```python
import paddlehub as hub
import cv2
classifier = hub.Module(name="se_resnet18_vd_imagenet")
result = classifier.classification(images=[cv2.imread('/PATH/TO/IMAGE')])
# or
# result = classifier.classification(paths=['/PATH/TO/IMAGE'])
```
- ### 3、API
- ```python
def classification(images=None,
paths=None,
batch_size=1,
use_gpu=False,
top_k=1):
```
- classification API.
- **Parameters**
- images (list\[numpy.ndarray\]): image data, ndarray.shape is in the format [H, W, C], BGR;
- paths (list[str]): image path;
- batch_size (int): the size of batch;
- use_gpu (bool): use GPU or not; **set the CUDA_VISIBLE_DEVICES environment variable first if you are using GPU**
- top\_k (int): return the top k results
- **Return**
- res (list\[dict\]): classification results; each element in the list is a dict whose keys are the label names and whose values are the corresponding probabilities (see the usage sketch below)
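
A short usage sketch of these parameters, assuming the result layout follows the Return description above; the image paths and the batch_size and top_k values are placeholders.

```python
import cv2
import paddlehub as hub

classifier = hub.Module(name="se_resnet18_vd_imagenet")

# Placeholder paths; replace them with real image files.
images = [cv2.imread('/PATH/TO/IMAGE_1'), cv2.imread('/PATH/TO/IMAGE_2')]

# Predict both images in a single batch and keep the 5 most likely labels for each.
results = classifier.classification(images=images, batch_size=2, top_k=5)

for idx, res in enumerate(results):
    # res maps label name -> confidence, as described in the Return section above.
    for label, score in sorted(res.items(), key=lambda kv: kv[1], reverse=True):
        print(f"image {idx}: {label}: {score:.4f}")
```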
## IV.Server Deployment
- PaddleHub Serving can deploy an online service of image classification.
- ### Step 1: Start PaddleHub Serving
- Run the startup command:
- ```shell
$ hub serving start -m se_resnet18_vd_imagenet
```
- The online classification service API is now deployed; the default port number is 8866.
- **NOTE:** If GPU is used for prediction, set the CUDA_VISIBLE_DEVICES environment variable before starting the service; otherwise it does not need to be set.
- ### Step 2: Send a predictive request
- With a configured server, use the following lines of code to send the prediction request and obtain the result
- ```python
import requests
import json
import cv2
import base64
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8')
# Send an HTTP request
data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/se_resnet18_vd_imagenet"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
# print prediction results
print(r.json()["results"])
```
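
A self-contained variant of the request above that also extracts the most likely label, assuming the "results" field mirrors the return format of classification() (one label-to-confidence dict per submitted image):

```python
import base64
import json

import cv2
import requests


def cv2_to_base64(image):
    data = cv2.imencode('.jpg', image)[1]
    return base64.b64encode(data.tobytes()).decode('utf8')


# Same request as above; "/PATH/TO/IMAGE" is a placeholder for a real file.
data = {'images': [cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/se_resnet18_vd_imagenet"
r = requests.post(url=url, headers=headers, data=json.dumps(data))

# Assumption: each entry of "results" is a label-to-confidence dict for one submitted image.
for image_result in r.json()["results"]:
    best_label = max(image_result, key=image_result.get)
    print(best_label, image_result[best_label])
```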
## V.Release Note
* 1.0.0
First release
* 1.1.0
Remove Fluid API
- ```shell
$ hub install se_resnet18_vd_imagenet==1.1.0
```
```python
import os
import time
from collections import OrderedDict

import cv2
import numpy as np
from PIL import Image
# ...
```
```diff
-# coding=utf-8
 from __future__ import absolute_import
 from __future__ import division
 
-import ast
 import argparse
+import ast
 import os
 
 import numpy as np
-import paddle.fluid as fluid
-import paddlehub as hub
-from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
-from paddlehub.module.module import moduleinfo, runnable, serving
-from paddlehub.common.paddle_helper import add_vars_prefix
-
-from se_resnet18_vd_imagenet.processor import postprocess, base64_to_cv2
-from se_resnet18_vd_imagenet.data_feed import reader
-from se_resnet18_vd_imagenet.se_resnet import SE_ResNet18_vd
+from paddle.inference import Config
+from paddle.inference import create_predictor
+
+from .data_feed import reader
+from .processor import base64_to_cv2
+from .processor import postprocess
+from paddlehub.module.module import moduleinfo
+from paddlehub.module.module import runnable
+from paddlehub.module.module import serving
 
 
-@moduleinfo(
-    name="se_resnet18_vd_imagenet",
-    type="CV/image_classification",
-    author="paddlepaddle",
-    author_email="paddle-dev@baidu.com",
-    summary="SE_ResNet18_vd is a image classfication model, this module is trained with imagenet datasets.",
-    version="1.0.0")
-class SEResNet18vdImageNet(hub.Module):
-    def _initialize(self):
-        self.default_pretrained_model_path = os.path.join(self.directory, "se_resnet18_vd_imagenet_model")
+@moduleinfo(name="se_resnet18_vd_imagenet",
+            type="CV/image_classification",
+            author="paddlepaddle",
+            author_email="paddle-dev@baidu.com",
+            summary="SE_ResNet18_vd is a image classfication model, this module is trained with imagenet datasets.",
+            version="1.1.0")
+class SEResNet18vdImageNet:
+    def __init__(self):
+        self.default_pretrained_model_path = os.path.join(self.directory, "se_resnet18_vd_imagenet_model", "model")
         label_file = os.path.join(self.directory, "label_list.txt")
         with open(label_file, 'r', encoding='utf-8') as file:
             self.label_list = file.read().split("\n")[:-1]
@@ -51,10 +50,12 @@ class SEResNet18vdImageNet(hub.Module):
         """
         predictor config setting
         """
-        cpu_config = AnalysisConfig(self.default_pretrained_model_path)
+        model = self.default_pretrained_model_path + '.pdmodel'
+        params = self.default_pretrained_model_path + '.pdiparams'
+        cpu_config = Config(model, params)
         cpu_config.disable_glog_info()
         cpu_config.disable_gpu()
-        self.cpu_predictor = create_paddle_predictor(cpu_config)
+        self.cpu_predictor = create_predictor(cpu_config)
 
         try:
             _places = os.environ["CUDA_VISIBLE_DEVICES"]
@@ -63,58 +64,10 @@ class SEResNet18vdImageNet(hub.Module):
         except:
             use_gpu = False
         if use_gpu:
-            gpu_config = AnalysisConfig(self.default_pretrained_model_path)
+            gpu_config = Config(model, params)
             gpu_config.disable_glog_info()
             gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)
-            self.gpu_predictor = create_paddle_predictor(gpu_config)
-
-    def context(self, trainable=True, pretrained=True):
-        """context for transfer learning.
-
-        Args:
-            trainable (bool): Set parameters in program to be trainable.
-            pretrained (bool) : Whether to load pretrained model.
-
-        Returns:
-            inputs (dict): key is 'image', corresponding vaule is image tensor.
-            outputs (dict): key is :
-                'classification', corresponding value is the result of classification.
-                'feature_map', corresponding value is the result of the layer before the fully connected layer.
-            context_prog (fluid.Program): program for transfer learning.
-        """
-        context_prog = fluid.Program()
-        startup_prog = fluid.Program()
-        with fluid.program_guard(context_prog, startup_prog):
-            with fluid.unique_name.guard():
-                image = fluid.layers.data(name="image", shape=[3, 224, 224], dtype="float32")
-                resnet_vd = SE_ResNet18_vd()
-                output, feature_map = resnet_vd.net(input=image, class_dim=len(self.label_list))
-
-                name_prefix = '@HUB_{}@'.format(self.name)
-                inputs = {'image': name_prefix + image.name}
-                outputs = {'classification': name_prefix + output.name, 'feature_map': name_prefix + feature_map.name}
-                add_vars_prefix(context_prog, name_prefix)
-                add_vars_prefix(startup_prog, name_prefix)
-                global_vars = context_prog.global_block().vars
-                inputs = {key: global_vars[value] for key, value in inputs.items()}
-                outputs = {key: global_vars[value] for key, value in outputs.items()}
-
-                place = fluid.CPUPlace()
-                exe = fluid.Executor(place)
-                # pretrained
-                if pretrained:
-
-                    def _if_exist(var):
-                        b = os.path.exists(os.path.join(self.default_pretrained_model_path, var.name))
-                        return b
-
-                    fluid.io.load_vars(exe, self.default_pretrained_model_path, context_prog, predicate=_if_exist)
-                else:
-                    exe.run(startup_prog)
-                # trainable
-                for param in context_prog.global_block().iter_parameters():
-                    param.trainable = trainable
-        return inputs, outputs, context_prog
+            self.gpu_predictor = create_predictor(gpu_config)
 
     def classification(self, images=None, paths=None, batch_size=1, use_gpu=False, top_k=1):
         """
@@ -136,7 +89,7 @@ class SEResNet18vdImageNet(hub.Module):
                 int(_places[0])
             except:
                 raise RuntimeError(
-                    "Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id."
+                    "Attempt to use GPU for prediction, but environment variable CUDA_VISIBLE_DEVICES was not set correctly."
                 )
 
         if not self.predictor_set:
@@ -161,32 +114,19 @@ class SEResNet18vdImageNet(hub.Module):
                 pass
             # feed batch image
             batch_image = np.array([data['image'] for data in batch_data])
-            batch_image = PaddleTensor(batch_image.copy())
-            predictor_output = self.gpu_predictor.run([batch_image]) if use_gpu else self.cpu_predictor.run(
-                [batch_image])
-            out = postprocess(data_out=predictor_output[0].as_ndarray(), label_list=self.label_list, top_k=top_k)
+
+            predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
+            input_names = predictor.get_input_names()
+            input_handle = predictor.get_input_handle(input_names[0])
+            input_handle.copy_from_cpu(batch_image.copy())
+            predictor.run()
+            output_names = predictor.get_output_names()
+            output_handle = predictor.get_output_handle(output_names[0])
+
+            out = postprocess(data_out=output_handle.copy_to_cpu(), label_list=self.label_list, top_k=top_k)
             res += out
         return res
 
-    def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True):
-        if combined:
-            model_filename = "__model__" if not model_filename else model_filename
-            params_filename = "__params__" if not params_filename else params_filename
-        place = fluid.CPUPlace()
-        exe = fluid.Executor(place)
-
-        program, feeded_var_names, target_vars = fluid.io.load_inference_model(
-            dirname=self.default_pretrained_model_path, executor=exe)
-
-        fluid.io.save_inference_model(
-            dirname=dirname,
-            main_program=program,
-            executor=exe,
-            feeded_var_names=feeded_var_names,
-            target_vars=target_vars,
-            model_filename=model_filename,
-            params_filename=params_filename)
-
     @serving
     def serving_method(self, images, **kwargs):
         """
@@ -201,11 +141,10 @@ class SEResNet18vdImageNet(hub.Module):
         """
         Run as a command.
         """
-        self.parser = argparse.ArgumentParser(
-            description="Run the {} module.".format(self.name),
-            prog='hub run {}'.format(self.name),
-            usage='%(prog)s',
-            add_help=True)
+        self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
+                                              prog='hub run {}'.format(self.name),
+                                              usage='%(prog)s',
+                                              add_help=True)
         self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
         self.arg_config_group = self.parser.add_argument_group(
             title="Config options", description="Run configuration for controlling module behavior, not required.")
@@ -219,8 +158,10 @@ class SEResNet18vdImageNet(hub.Module):
         """
         Add the command config options.
         """
-        self.arg_config_group.add_argument(
-            '--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU or not.")
+        self.arg_config_group.add_argument('--use_gpu',
+                                           type=ast.literal_eval,
+                                           default=False,
+                                           help="whether use GPU or not.")
         self.arg_config_group.add_argument('--batch_size', type=ast.literal_eval, default=1, help="batch size.")
         self.arg_config_group.add_argument('--top_k', type=ast.literal_eval, default=1, help="Return top k results.")
```
```python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import base64
import os

import cv2
import numpy as np
# ...
```
```python
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
__all__ = ["ResNet", "ResNet18_vd", "ResNet34_vd", "ResNet50_vd", "ResNet101_vd", "ResNet152_vd", "ResNet200_vd"]
class ResNet():
def __init__(self, layers=50, is_3x3=False):
self.layers = layers
self.is_3x3 = is_3x3
def net(self, input, class_dim=1000):
is_3x3 = self.is_3x3
layers = self.layers
supported_layers = [18, 34, 50, 101, 152, 200]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
num_filters = [64, 128, 256, 512]
if is_3x3 == False:
conv = self.conv_bn_layer(input=input, num_filters=64, filter_size=7, stride=2, act='relu')
else:
conv = self.conv_bn_layer(input=input, num_filters=32, filter_size=3, stride=2, act='relu', name='conv1_1')
conv = self.conv_bn_layer(input=conv, num_filters=32, filter_size=3, stride=1, act='relu', name='conv1_2')
conv = self.conv_bn_layer(input=conv, num_filters=64, filter_size=3, stride=1, act='relu', name='conv1_3')
conv = fluid.layers.pool2d(input=conv, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
if layers >= 50:
for block in range(len(depth)):
for i in range(depth[block]):
if layers in [101, 152, 200] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
if_first=block == i == 0,
name=conv_name)
else:
for block in range(len(depth)):
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.basic_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
if_first=block == i == 0,
name=conv_name)
pool = fluid.layers.pool2d(input=conv, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(
input=pool,
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(initializer=fluid.initializer.Uniform(-stdv, stdv)))
return out, pool
def conv_bn_layer(self, input, num_filters, filter_size, stride=1, groups=1, act=None, name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def conv_bn_layer_new(self, input, num_filters, filter_size, stride=1, groups=1, act=None, name=None):
pool = fluid.layers.pool2d(
input=input, pool_size=2, pool_stride=2, pool_padding=0, pool_type='avg', ceil_mode=True)
conv = fluid.layers.conv2d(
input=pool,
num_filters=num_filters,
filter_size=filter_size,
stride=1,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def shortcut(self, input, ch_out, stride, name, if_first=False):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1:
if if_first:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return self.conv_bn_layer_new(input, ch_out, 1, stride, name=name)
elif if_first:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return input
def bottleneck_block(self, input, num_filters, stride, name, if_first):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=1, act='relu', name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0, num_filters=num_filters, filter_size=3, stride=stride, act='relu', name=name + "_branch2b")
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters * 4, filter_size=1, act=None, name=name + "_branch2c")
short = self.shortcut(input, num_filters * 4, stride, if_first=if_first, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
def basic_block(self, input, num_filters, stride, name, if_first):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=3, act='relu', stride=stride, name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0, num_filters=num_filters, filter_size=3, act=None, name=name + "_branch2b")
short = self.shortcut(input, num_filters, stride, if_first=if_first, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
def ResNet18_vd():
model = ResNet(layers=18, is_3x3=True)
return model
def ResNet34_vd():
model = ResNet(layers=34, is_3x3=True)
return model
def ResNet50_vd():
model = ResNet(layers=50, is_3x3=True)
return model
def ResNet101_vd():
model = ResNet(layers=101, is_3x3=True)
return model
def ResNet152_vd():
model = ResNet(layers=152, is_3x3=True)
return model
def ResNet200_vd():
model = ResNet(layers=200, is_3x3=True)
    return model
```
```python
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
import math
__all__ = [
"SE_ResNet_vd", "SE_ResNet18_vd", "SE_ResNet34_vd", "SE_ResNet50_vd", "SE_ResNet101_vd", "SE_ResNet152_vd",
"SE_ResNet200_vd"
]
class SE_ResNet_vd():
def __init__(self, layers=50, is_3x3=False):
self.layers = layers
self.is_3x3 = is_3x3
def net(self, input, class_dim=1000):
is_3x3 = self.is_3x3
layers = self.layers
supported_layers = [18, 34, 50, 101, 152, 200]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
num_filters = [64, 128, 256, 512]
reduction_ratio = 16
if is_3x3 == False:
conv = self.conv_bn_layer(input=input, num_filters=64, filter_size=7, stride=2, act='relu')
else:
conv = self.conv_bn_layer(input=input, num_filters=32, filter_size=3, stride=2, act='relu', name='conv1_1')
conv = self.conv_bn_layer(input=conv, num_filters=32, filter_size=3, stride=1, act='relu', name='conv1_2')
conv = self.conv_bn_layer(input=conv, num_filters=64, filter_size=3, stride=1, act='relu', name='conv1_3')
conv = fluid.layers.pool2d(input=conv, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
if layers >= 50:
for block in range(len(depth)):
for i in range(depth[block]):
if layers in [101, 152, 200] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
if_first=block == i == 0,
reduction_ratio=reduction_ratio,
name=conv_name)
else:
for block in range(len(depth)):
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.basic_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
if_first=block == i == 0,
reduction_ratio=reduction_ratio,
name=conv_name)
pool = fluid.layers.pool2d(input=conv, pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(
input=pool,
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv), name='fc6_weights'),
bias_attr=ParamAttr(name='fc6_offset'))
return out, pool
def conv_bn_layer(self, input, num_filters, filter_size, stride=1, groups=1, act=None, name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def conv_bn_layer_new(self, input, num_filters, filter_size, stride=1, groups=1, act=None, name=None):
pool = fluid.layers.pool2d(
input=input, pool_size=2, pool_stride=2, pool_padding=0, pool_type='avg', ceil_mode=True)
conv = fluid.layers.conv2d(
input=pool,
num_filters=num_filters,
filter_size=filter_size,
stride=1,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def shortcut(self, input, ch_out, stride, name, if_first=False):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1:
if if_first:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return self.conv_bn_layer_new(input, ch_out, 1, stride, name=name)
elif if_first:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return input
def bottleneck_block(self, input, num_filters, stride, name, if_first, reduction_ratio):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=1, act='relu', name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0, num_filters=num_filters, filter_size=3, stride=stride, act='relu', name=name + "_branch2b")
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters * 4, filter_size=1, act=None, name=name + "_branch2c")
scale = self.squeeze_excitation(
input=conv2, num_channels=num_filters * 4, reduction_ratio=reduction_ratio, name='fc_' + name)
short = self.shortcut(input, num_filters * 4, stride, if_first=if_first, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=scale, act='relu')
def basic_block(self, input, num_filters, stride, name, if_first, reduction_ratio):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=3, act='relu', stride=stride, name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0, num_filters=num_filters, filter_size=3, act=None, name=name + "_branch2b")
scale = self.squeeze_excitation(
input=conv1, num_channels=num_filters, reduction_ratio=reduction_ratio, name='fc_' + name)
short = self.shortcut(input, num_filters, stride, if_first=if_first, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=scale, act='relu')
def squeeze_excitation(self, input, num_channels, reduction_ratio, name=None):
pool = fluid.layers.pool2d(input=input, pool_size=0, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
squeeze = fluid.layers.fc(
input=pool,
size=num_channels // reduction_ratio,
act='relu',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv), name=name + '_sqz_weights'),
bias_attr=ParamAttr(name=name + '_sqz_offset'))
stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0)
excitation = fluid.layers.fc(
input=squeeze,
size=num_channels,
act='sigmoid',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv), name=name + '_exc_weights'),
bias_attr=ParamAttr(name=name + '_exc_offset'))
scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)
return scale
def SE_ResNet18_vd():
model = SE_ResNet_vd(layers=18, is_3x3=True)
return model
def SE_ResNet34_vd():
model = SE_ResNet_vd(layers=34, is_3x3=True)
return model
def SE_ResNet50_vd():
model = SE_ResNet_vd(layers=50, is_3x3=True)
return model
def SE_ResNet101_vd():
model = SE_ResNet_vd(layers=101, is_3x3=True)
return model
def SE_ResNet152_vd():
model = SE_ResNet_vd(layers=152, is_3x3=True)
return model
def SE_ResNet200_vd():
model = SE_ResNet_vd(layers=200, is_3x3=True)
    return model
```
```python
import os
import shutil
import unittest

import cv2
import requests
import paddlehub as hub

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


class TestHubModule(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        img_url = 'https://unsplash.com/photos/brFsZ7qszSY/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8OHx8ZG9nfGVufDB8fHx8MTY2MzA1ODQ1MQ&force=true&w=640'
        if not os.path.exists('tests'):
            os.makedirs('tests')
        response = requests.get(img_url)
        assert response.status_code == 200, 'Network Error.'
        with open('tests/test.jpg', 'wb') as f:
            f.write(response.content)
        cls.module = hub.Module(name="se_resnet18_vd_imagenet")

    @classmethod
    def tearDownClass(cls) -> None:
        shutil.rmtree('tests')
        shutil.rmtree('inference')

    def test_classification1(self):
        results = self.module.classification(paths=['tests/test.jpg'])
        data = results[0]
        self.assertTrue('Pembroke' in data)
        self.assertTrue(data['Pembroke'] > 0.5)

    def test_classification2(self):
        results = self.module.classification(images=[cv2.imread('tests/test.jpg')])
        data = results[0]
        self.assertTrue('Pembroke' in data)
        self.assertTrue(data['Pembroke'] > 0.5)

    def test_classification3(self):
        results = self.module.classification(images=[cv2.imread('tests/test.jpg')], use_gpu=True)
        data = results[0]
        self.assertTrue('Pembroke' in data)
        self.assertTrue(data['Pembroke'] > 0.5)

    def test_classification4(self):
        self.assertRaises(AssertionError, self.module.classification, paths=['no.jpg'])

    def test_classification5(self):
        self.assertRaises(TypeError, self.module.classification, images=['tests/test.jpg'])

    def test_save_inference_model(self):
        self.module.save_inference_model('./inference/model')
        self.assertTrue(os.path.exists('./inference/model.pdmodel'))
        self.assertTrue(os.path.exists('./inference/model.pdiparams'))


if __name__ == "__main__":
    unittest.main()
```