Unverified commit 18a4c142, authored by: T Tingquan Gao, committed by: GitHub

Update preprocess and HubServing (#382)

add hub serving code and doc
Parent: c933dcd8
{
    "modules_info": {
        "clas_system": {
            "init_args": {
                "version": "1.0.0",
                "use_gpu": true
            },
            "predict_args": {
            }
        }
    },
    "port": 8866,
    "use_multiprocess": false,
    "workers": 2
}
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
sys.path.insert(0, ".")
import time
from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo, serving
import cv2
import numpy as np
import paddlehub as hub
import tools.infer.predict as paddle_predict
from tools.infer.utils import Base64ToCV2
from deploy.hubserving.clas.params import read_params
@moduleinfo(
    name="clas_system",
    version="1.0.0",
    summary="class system service",
    author="paddle-dev",
    author_email="paddle-dev@baidu.com",
    type="cv/class")
class ClasSystem(hub.Module):
    def _initialize(self, use_gpu=None):
        """
        Initialize with the necessary elements.
        """
        cfg = read_params()
        if use_gpu is not None:
            cfg.use_gpu = use_gpu
        cfg.hubserving = True
        cfg.enable_benchmark = False
        self.args = cfg
        if cfg.use_gpu:
            try:
                _places = os.environ["CUDA_VISIBLE_DEVICES"]
                int(_places[0])
                print("Use GPU, GPU Memory: {}".format(cfg.gpu_mem))
                print("CUDA_VISIBLE_DEVICES: ", _places)
            except Exception:
                raise RuntimeError(
                    "Environment variable CUDA_VISIBLE_DEVICES is not set correctly. "
                    "If you want to use the GPU, please set CUDA_VISIBLE_DEVICES "
                    "via export CUDA_VISIBLE_DEVICES=cuda_device_id.")
        else:
            print("Use CPU")

    def read_images(self, paths=[]):
        images = []
        for img_path in paths:
            assert os.path.isfile(
                img_path), "The {} isn't a valid file.".format(img_path)
            img = cv2.imread(img_path)
            if img is None:
                logger.info("error in loading image: {}".format(img_path))
                continue
            # OpenCV reads BGR; convert to RGB
            img = img[:, :, ::-1]
            images.append(img)
        return images

    def predict(self, images=[], paths=[], top_k=1):
        """
        Args:
            images (list[numpy.ndarray]): image data, the shape of each is [H, W, C].
                Provide either `images` or `paths`, not both.
            paths (list[str]): image file paths. Provide either `paths` or `images`.
            top_k (int): number of top-scoring classes to return for each image.
        Returns:
            res (list): one [classes, scores] pair per input image.
        """
        if images != [] and isinstance(images, list) and paths == []:
            predicted_data = images
        elif images == [] and isinstance(paths, list) and paths != []:
            predicted_data = self.read_images(paths)
        else:
            raise TypeError(
                "The input data is inconsistent with expectations.")

        assert predicted_data != [], "There is not any image to be predicted. Please check the input data."

        all_results = []
        for img in predicted_data:
            if img is None:
                logger.info("error in loading image")
                all_results.append([])
                continue
            starttime = time.time()
            self.args.image_file = img
            self.args.top_k = top_k
            classes, scores = paddle_predict.main(self.args)
            elapse = time.time() - starttime
            logger.info("Predict time: {}".format(elapse))
            all_results.append([classes.tolist(), scores.tolist()])
        return all_results

    @serving
    def serving_method(self, images, **kwargs):
        """
        Run as a service.
        """
        to_cv2 = Base64ToCV2()
        images_decode = [to_cv2(image) for image in images]
        results = self.predict(images_decode, **kwargs)
        return results


if __name__ == '__main__':
    clas = ClasSystem()
    image_path = ['./deploy/hubserving/ILSVRC2012_val_00006666.JPEG', ]
    res = clas.predict(paths=image_path, top_k=5)
    print(res)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class Config(object):
    pass


def read_params():
    cfg = Config()

    cfg.model_file = "./inference/cls_infer.pdmodel"
    cfg.params_file = "./inference/cls_infer.pdiparams"

    cfg.batch_size = 1
    cfg.use_gpu = False
    cfg.ir_optim = True
    cfg.gpu_mem = 8000
    cfg.use_fp16 = False
    cfg.use_tensorrt = False

    # params for preprocess
    cfg.resize_short = 256
    cfg.resize = 224
    cfg.normalize = True

    return cfg
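

# NOTE (illustrative sketch, not part of this commit): the three preprocess
# parameters above configure a pipeline equivalent to the standalone helper
# below, which mirrors ResizeImage/CropImage/NormalizeImage/ToTensor in
# tools/infer/utils.py and assumes an RGB uint8 input image.
def _preprocess_sketch(img, resize_short=256, resize=224, normalize=True):
    import cv2
    import numpy as np

    # resize the short side to `resize_short`, keeping the aspect ratio
    h, w = img.shape[:2]
    scale = float(resize_short) / min(h, w)
    img = cv2.resize(img, (int(round(w * scale)), int(round(h * scale))))
    # center-crop a `resize` x `resize` window
    h, w = img.shape[:2]
    top, left = (h - resize) // 2, (w - resize) // 2
    img = img[top:top + resize, left:left + resize]
    if normalize:
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        img = (img.astype(np.float32) / 255.0 - mean) / std
    # HWC -> CHW, the layout the Paddle predictor expects
    return img.transpose((2, 0, 1))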
[English](readme_en.md) | Simplified Chinese

# Service deployment based on PaddleHub Serving

The HubServing service pack `clas` contains 3 required files plus an optional configuration file; the directory is organized as follows:
```
deploy/hubserving/clas/
└─ __init__.py Empty file, required
└─ config.json Configuration file, optional, passed in as a parameter when starting the service with a configuration
└─ module.py Main module, required, contains the complete logic of the service
└─ params.py Parameter file, required, contains parameters such as the model path and pre- and post-processing parameters
```

## Quick start
### 1. Prepare the environment
```shell
# Install paddlehub
pip3 install paddlehub --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
```

### 2. Download the inference model
Before installing the service module, you need to prepare the inference model and put it in the correct path. The default model paths are:
```
Classification model structure file: ./inference/cls_infer.pdmodel
Classification model weights file: ./inference/cls_infer.pdiparams
```

**The model path can be viewed and modified in `params.py`.** We also provide a large number of models pretrained on the ImageNet-1k dataset. For the model list and download addresses, see the [model library overview](../../docs/zh_CN/models/models_intro.md). You can also replace the model with one you have trained and converted yourself.

### 3. Install the service module
The installation commands for Linux and Windows are as follows.

* On Linux, for example:
```shell
# Install the service module:
hub install deploy/hubserving/clas/
```

* On Windows (the folder separator is `\`), for example:
```shell
# Install the service module:
hub install deploy\hubserving\clas\
```

### 4. Start the service
#### Way 1. Start from the command line (CPU only)
**Start command:**
```shell
$ hub serving start --modules Module1==Version1 \
                    --port XXXX \
                    --use_multiprocess \
                    --workers \
```
**Parameters:**

|Parameter|Usage|
|-|-|
|--modules/-m| [**Required**] PaddleHub Serving pre-installed model, listed as multiple Module==Version key-value pairs<br>*`When Version is not specified, the latest version is selected by default`*|
|--port/-p| [**Optional**] Service port, default is 8866|
|--use_multiprocess| [**Optional**] Whether to enable the concurrent mode; the default is single-process mode, and the concurrent mode is recommended for multi-core CPU machines<br>*`Windows only supports single-process mode`*|
|--workers| [**Optional**] The number of concurrent tasks in concurrent mode, default is `2*cpu_count-1`, where `cpu_count` is the number of CPU cores|

For example, to start the service with the default parameters: ```hub serving start -m clas_system```

This deploys a service API, listening on the default port 8866.

#### Way 2. Start with a configuration file (CPU or GPU)
**Start command:**
```hub serving start -c config.json```

The format of `config.json` is as follows:
```json
{
    "modules_info": {
        "clas_system": {
            "init_args": {
                "version": "1.0.0",
                "use_gpu": true
            },
            "predict_args": {
            }
        }
    },
    "port": 8866,
    "use_multiprocess": false,
    "workers": 2
}
```
- The configurable parameters in `init_args` match the `_initialize` function interface in `module.py`. In particular, **when `use_gpu` is `true`, the service is started with the GPU**.
- The configurable parameters in `predict_args` match the `predict` function interface in `module.py`.

**Note:**
- When starting the service with a configuration file, other command-line parameters are ignored.
- If you use GPU prediction (i.e. `use_gpu` is set to `true`), you need to set the CUDA_VISIBLE_DEVICES environment variable before starting the service, e.g. ```export CUDA_VISIBLE_DEVICES=0```; otherwise it does not need to be set.
- **`use_gpu` and `use_multiprocess` cannot both be `true` at the same time.**

For example, to start the service on GPU card 3:
```shell
export CUDA_VISIBLE_DEVICES=3
hub serving start -c deploy/hubserving/clas/config.json
```

## Send prediction requests
With the service configured, a prediction request can be sent with the following command to get the result:
```python tools/test_hubserving.py server_url image_path```

The script takes two required parameters and one optional parameter:
- **server_url**: service address, in the format
`http://[ip_address]:[port]/predict/[module_name]`
- **image_path**: test image path, either a single image file or a directory of images
- **top_k**: [**Optional**] return the top `top_k` scores; default is `1`.

Example:
```python tools/test_hubserving.py http://127.0.0.1:8866/predict/clas_system ./deploy/hubserving/ILSVRC2012_val_00006666.JPEG 5```

## Returned result format
The returned result is a list containing the predicted classes (`clas`) and a list `scores` of the corresponding scores; `scores` holds the top `top_k` scores.

**Note:** If you need to add, delete, or modify the returned fields, modify the `module.py` file of the corresponding module. For the complete process, see the next section on customizing the service module.

## Customize the service module
If you need to modify the service logic, the following steps are generally required:

- 1. Stop the service
```hub serving stop --port/-p XXXX```
- 2. Modify the code in the corresponding files, such as `module.py` and `params.py`, according to your actual needs.
For example, to replace the model used by the deployed service, modify the model path parameters `cfg.model_file` and `cfg.params_file` in `params.py`. **It is strongly recommended to run `module.py` directly for debugging after the modification, and to start the service test only once prediction runs correctly.**
- 3. Uninstall the old service pack
```hub uninstall clas_system```
- 4. Install the new, modified service pack
```hub install deploy/hubserving/clas/```
- 5. Restart the service
```hub serving start -m clas_system```
English | [Simplified Chinese](readme.md)
# Service deployment based on PaddleHub Serving
The HubServing service pack `clas` contains 3 required files plus an optional configuration file; the directory is organized as follows:
```
deploy/hubserving/clas/
└─ __init__.py Empty file, required
└─ config.json Configuration file, optional, passed in as a parameter when starting the service with a configuration
└─ module.py Main module file, required, contains the complete logic of the service
└─ params.py Parameter file, required, contains parameters such as the model path and pre- and post-processing parameters
```
## Quick start service
### 1. Prepare the environment
```shell
# Install paddlehub
pip3 install paddlehub --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
```
### 2. Download inference model
Before installing the service module, you need to prepare the inference model and put it in the correct path. The default model path is:
```
Model structure file: ./inference/cls_infer.pdmodel
Model parameters file: ./inference/cls_infer.pdiparams
```
**The model path can be found and modified in `params.py`.** More models provided by PaddleClas can be obtained from the [model library](../../docs/en/models/models_intro_en.md). You can also use models trained by yourself.
### 3. Install Service Module
* On Linux platform, the examples are as follows.
```shell
hub install deploy/hubserving/clas/
```
* On Windows platform, the examples are as follows.
```shell
hub install deploy\hubserving\clas\
```
### 4. Start service
#### Way 1. Start with command line parameters (CPU only)
**start command:**
```shell
$ hub serving start --modules Module1==Version1 \
--port XXXX \
--use_multiprocess \
--workers \
```
**parameters:**
|parameters|usage|
|-|-|
|--modules/-m|PaddleHub Serving pre-installed model, listed in the form of multiple Module==Version key-value pairs<br>*`When Version is not specified, the latest version is selected by default`*|
|--port/-p|Service port, default is 8866|
|--use_multiprocess|Enable concurrent mode, the default is single-process mode, this mode is recommended for multi-core CPU machines<br>*`Windows operating system only supports single-process mode`*|
|--workers|The number of concurrent tasks specified in concurrent mode, the default is `2*cpu_count-1`, where `cpu_count` is the number of CPU cores|
For example, start the classification service with the default parameters:
```shell
hub serving start -m clas_system
```
This completes the deployment of a service API, using the default port number 8866.
#### Way 2. Start with a configuration file (CPU or GPU)
**start command:**
```shell
hub serving start --config/-c config.json
```
The format of `config.json` is as follows:
```json
{
    "modules_info": {
        "clas_system": {
            "init_args": {
                "version": "1.0.0",
                "use_gpu": true
            },
            "predict_args": {
            }
        }
    },
    "port": 8866,
    "use_multiprocess": false,
    "workers": 2
}
```
- The configurable parameters in `init_args` are consistent with the `_initialize` function interface in `module.py`. Among them, **when `use_gpu` is `true`, it means that the GPU is used to start the service**.
- The configurable parameters in `predict_args` are consistent with the `predict` function interface in `module.py`.
**Note:**
- When using the configuration file to start the service, other parameters will be ignored.
- If you use GPU prediction (that is, `use_gpu` is set to `true`), you need to set the environment variable CUDA_VISIBLE_DEVICES before starting the service, such as: ```export CUDA_VISIBLE_DEVICES=0```, otherwise you do not need to set it.
- **`use_gpu` and `use_multiprocess` cannot be `true` at the same time.**
For example, use GPU card No. 3 to start the service:
```shell
export CUDA_VISIBLE_DEVICES=3
hub serving start -c deploy/hubserving/clas/config.json
```
## Send prediction requests
After the service starts, you can use the following command to send a prediction request to obtain the prediction result:
```shell
python tools/test_hubserving.py server_url image_path
```
The script takes two required parameters and one optional parameter:
- **server_url**: service address, in the format
`http://[ip_address]:[port]/predict/[module_name]`
- **image_path**: test image path, either a single image file or a directory of images
- **top_k**: [**Optional**] return the top `top_k` scores; default is `1`.
**Eg.**
```shell
python tools/test_hubserving.py http://127.0.0.1:8866/predict/clas_system ./deploy/hubserving/ILSVRC2012_val_00006666.JPEG 5
```
## Returned result format
The returned result is a list containing the classification results (`clas`) and the corresponding scores (`scores`); `scores` is a list of the top `top_k` scores.
**Note:** If you need to add, delete or modify the returned fields, you can modify the file `module.py` of the corresponding module. For the complete process, refer to the user-defined modification service module in the next section.
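
If you prefer not to use the test script, the request can also be assembled by hand. Below is a minimal client sketch of the wire format described above, assuming the service is running locally on the default port; the image path is a placeholder:

```python
import base64
import json

import requests

# encode the raw image file bytes, matching cv2_to_base64 in tools/test_hubserving.py
with open("./test.jpeg", "rb") as f:  # placeholder image path
    img_b64 = base64.b64encode(f.read()).decode("utf8")

data = {"images": [img_b64], "top_k": 5}
r = requests.post(
    url="http://127.0.0.1:8866/predict/clas_system",
    headers={"Content-type": "application/json"},
    data=json.dumps(data))

# each entry of "results" is one [classes, scores] pair per input image
classes, scores = r.json()["results"][0]
print(classes, scores)
```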
## User-defined service module modification
If you need to modify the service logic, the following steps are generally required:
- 1. Stop service
```shell
hub serving stop --port/-p XXXX
```
- 2. Modify the code in the corresponding files, such as `module.py` and `params.py`, according to your actual needs.
For example, if you need to replace the model used by the deployed service, modify the model path parameters `cfg.model_file` and `cfg.params_file` in `params.py` (other related parameters may need to change at the same time; modify and debug according to the actual situation). It is suggested to run `module.py` directly for debugging after the modification, and to start the service test only once prediction runs correctly; a configuration sketch follows these steps.
- 3. Uninstall old service module
```shell
hub uninstall clas_system
```
- 4. Install modified service module
```shell
hub install deploy/hubserving/clas/
```
- 5. Restart service
```shell
hub serving start -m clas_system
```
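
The sketch for step 2: replacing the model usually only requires changing the two path parameters in `params.py` (the file names below are hypothetical placeholders):

```python
# deploy/hubserving/clas/params.py
cfg.model_file = "./inference/my_model.pdmodel"     # your exported model structure
cfg.params_file = "./inference/my_model.pdiparams"  # your exported model weights
```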
@@ -198,7 +198,7 @@ After the training is completed, you can predict by using the pre-trained model
 ```python
 python tools/infer/infer.py \
     -i image path \
-    -m MobileNetV3_large_x1_0 \
+    --model MobileNetV3_large_x1_0 \
     --pretrained_model "./output/MobileNetV3_large_x1_0/best_model/ppcls" \
     --use_gpu True \
     --load_static_weights False
@@ -206,7 +206,7 @@ python tools/infer/infer.py \
 Among them:
 + `image_file`(i): The path of the image file to be predicted, such as `./test.jpeg`;
-+ `model`(m): Model name, such as `MobileNetV3_large_x1_0`;
++ `model`: Model name, such as `MobileNetV3_large_x1_0`;
 + `pretrained_model`: Weight file path, such as `./pretrained/MobileNetV3_large_x1_0_pretrained/`;
 + `use_gpu`: Whether to use the GPU, default by `True`;
 + `load_static_weights`: Whether to load the pre-trained model obtained from static image training, default by `False`;
@@ -248,15 +248,15 @@ The above command will generate the model structure file (`cls_infer.pdmodel`) a
 ```bash
 python tools/infer/predict.py \
     --image_file image path \
-    -m "./inference/cls_infer.pdmodel" \
-    -p "./inference/cls_infer.pdiparams" \
+    --model_file "./inference/cls_infer.pdmodel" \
+    --params_file "./inference/cls_infer.pdiparams" \
     --use_gpu=True \
     --use_tensorrt=False
 ```

 Among them:
 + `image_file`: The path of the image file to be predicted, such as `./test.jpeg`;
-+ `model_file`(m): Model file path, such as `./MobileNetV3_large_x1_0/cls_infer.pdmodel`;
-+ `params_file`(p): Weight file path, such as `./MobileNetV3_large_x1_0/cls_infer.pdiparams`;
++ `model_file`: Model file path, such as `./MobileNetV3_large_x1_0/cls_infer.pdmodel`;
++ `params_file`: Weight file path, such as `./MobileNetV3_large_x1_0/cls_infer.pdiparams`;
 + `use_tensorrt`: Whether to use the TensorRT, default by `True`;
 + `use_gpu`: Whether to use the GPU, default by `True`.
@@ -212,7 +212,7 @@ python tools/eval.py \
 ```python
 python tools/infer/infer.py \
     -i path of the image to be predicted \
-    -m MobileNetV3_large_x1_0 \
+    --model MobileNetV3_large_x1_0 \
     --pretrained_model "./output/MobileNetV3_large_x1_0/best_model/ppcls" \
     --use_gpu True \
     --load_static_weights False
@@ -220,7 +220,7 @@ python tools/infer/infer.py \
 Parameter description:
 + `image_file` (abbr. i): path of the image to be predicted, or the image folder for batch prediction, such as `./test.jpeg`
-+ `model` (abbr. m): model name, such as `MobileNetV3_large_x1_0`
++ `model`: model name, such as `MobileNetV3_large_x1_0`
 + `pretrained_model`: path of the model weights, such as `./output/MobileNetV3_large_x1_0/best_model/ppcls`
 + `use_gpu`: whether to use the GPU, default: `True`
 + `load_static_weights`: whether the model weights were obtained from static-graph training, default: `False`
@@ -259,15 +259,15 @@ python tools/export_model.py \
 ```bash
 python tools/infer/predict.py \
     --image_file image path \
-    -m "./inference/cls_infer.pdmodel" \
-    -p "./inference/cls_infer.pdiparams" \
+    --model_file "./inference/cls_infer.pdmodel" \
+    --params_file "./inference/cls_infer.pdiparams" \
     --use_gpu=True \
     --use_tensorrt=False
 ```

 Among them:
 + `image_file`: path of the image file to be predicted, such as `./test.jpeg`
-+ `model_file` (abbr. m): path of the model structure file, such as `./inference/cls_infer.pdmodel`
-+ `params_file` (abbr. p): path of the model weights file, such as `./inference/cls_infer.pdiparams`
++ `model_file`: path of the model structure file, such as `./inference/cls_infer.pdmodel`
++ `params_file`: path of the model weights file, such as `./inference/cls_infer.pdiparams`
 + `use_tensorrt`: whether to use the TensorRT prediction engine, default: `True`
 + `use_gpu`: whether to use the GPU for prediction, default: `True`
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

# NOTE: this commit rewrites ExponentialMovingAverage for dynamic-graph
# training. The previous static-graph implementation is kept first for
# reference; the new dynamic-graph implementation that replaces it follows.
# The imports below are only needed by the static-graph version.
import paddle
import paddle.fluid as fluid
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
from paddle.fluid.framework import Program, program_guard, name_scope, default_main_program
from paddle.fluid import unique_name, layers


# --- Previous static-graph implementation (removed by this commit) ---
class ExponentialMovingAverage(object):
    def __init__(self,
                 decay=0.999,
                 thres_steps=None,
                 zero_debias=False,
                 name=None):
        self._decay = decay
        self._thres_steps = thres_steps
        self._name = name if name is not None else ''
        self._decay_var = self._get_ema_decay()

        self._params_tmps = []
        for param in default_main_program().global_block().all_parameters():
            if param.do_model_average != False:
                tmp = param.block.create_var(
                    name=unique_name.generate(".".join(
                        [self._name + param.name, 'ema_tmp'])),
                    dtype=param.dtype,
                    persistable=False,
                    stop_gradient=True)
                self._params_tmps.append((param, tmp))

        self._ema_vars = {}
        for param, tmp in self._params_tmps:
            with param.block.program._optimized_guard(
                [param, tmp]), name_scope('moving_average'):
                self._ema_vars[param.name] = self._create_ema_vars(param)

        self.apply_program = Program()
        block = self.apply_program.global_block()
        with program_guard(main_program=self.apply_program):
            decay_pow = self._get_decay_pow(block)
            for param, tmp in self._params_tmps:
                param = block._clone_variable(param)
                tmp = block._clone_variable(tmp)
                ema = block._clone_variable(self._ema_vars[param.name])
                layers.assign(input=param, output=tmp)
                # bias correction
                if zero_debias:
                    ema = ema / (1.0 - decay_pow)
                layers.assign(input=ema, output=param)

        self.restore_program = Program()
        block = self.restore_program.global_block()
        with program_guard(main_program=self.restore_program):
            for param, tmp in self._params_tmps:
                tmp = block._clone_variable(tmp)
                param = block._clone_variable(param)
                layers.assign(input=tmp, output=param)

    def _get_ema_decay(self):
        with default_main_program()._lr_schedule_guard():
            decay_var = layers.tensor.create_global_var(
                shape=[1],
                value=self._decay,
                dtype='float32',
                persistable=True,
                name="scheduled_ema_decay_rate")

            if self._thres_steps is not None:
                decay_t = (self._thres_steps + 1.0) / (self._thres_steps + 10.0)
                with layers.control_flow.Switch() as switch:
                    with switch.case(decay_t < self._decay):
                        layers.tensor.assign(decay_t, decay_var)
                    with switch.default():
                        layers.tensor.assign(
                            np.array(
                                [self._decay], dtype=np.float32),
                            decay_var)
        return decay_var

    def _get_decay_pow(self, block):
        global_steps = layers.learning_rate_scheduler._decay_step_counter()
        decay_var = block._clone_variable(self._decay_var)
        decay_pow_acc = layers.elementwise_pow(decay_var, global_steps + 1)
        return decay_pow_acc

    def _create_ema_vars(self, param):
        param_ema = layers.create_global_var(
            name=unique_name.generate(self._name + param.name + '_ema'),
            shape=param.shape,
            value=0.0,
            dtype=param.dtype,
            persistable=True)
        return param_ema

    def update(self):
        """
        Update Exponential Moving Average. Should only call this method in
        train program.
        """
        param_master_emas = []
        for param, tmp in self._params_tmps:
            with param.block.program._optimized_guard(
                [param, tmp]), name_scope('moving_average'):
                param_ema = self._ema_vars[param.name]
                if param.name + '.master' in self._ema_vars:
                    master_ema = self._ema_vars[param.name + '.master']
                    param_master_emas.append([param_ema, master_ema])
                else:
                    ema_t = param_ema * self._decay_var + param * (
                        1 - self._decay_var)
                    layers.assign(input=ema_t, output=param_ema)

        # for fp16 params
        for param_ema, master_ema in param_master_emas:
            default_main_program().global_block().append_op(
                type="cast",
                inputs={"X": master_ema},
                outputs={"Out": param_ema},
                attrs={
                    "in_dtype": master_ema.dtype,
                    "out_dtype": param_ema.dtype
                })

    @signature_safe_contextmanager
    def apply(self, executor, need_restore=True):
        """
        Apply moving average to parameters for evaluation.

        Args:
            executor (Executor): The Executor to execute applying.
            need_restore (bool): Whether to restore parameters after applying.
        """
        executor.run(self.apply_program)
        try:
            yield
        finally:
            if need_restore:
                self.restore(executor)

    def restore(self, executor):
        """Restore parameters.

        Args:
            executor (Executor): The Executor to execute restoring.
        """
        executor.run(self.restore_program)


# --- New dynamic-graph implementation (added by this commit) ---
class ExponentialMovingAverage():
    def __init__(self, model, decay, thres_steps=True):
        self._model = model
        self._decay = decay
        self._thres_steps = thres_steps
        self._shadow = {}
        self._backup = {}

    def register(self):
        self._update_step = 0
        for name, param in self._model.named_parameters():
            if param.stop_gradient is False:
                self._shadow[name] = param.numpy().copy()

    def update(self):
        """
        Update Exponential Moving Average. Should only call this method in
        train program.
        """
        decay = min(self._decay, (1 + self._update_step) / (
            10 + self._update_step)) if self._thres_steps else self._decay
        for name, param in self._model.named_parameters():
            if param.stop_gradient is False:
                assert name in self._shadow
                new_val = np.array(param.numpy().copy())
                old_val = np.array(self._shadow[name])
                new_average = decay * old_val + (1 - decay) * new_val
                self._shadow[name] = new_average
        self._update_step += 1
        return decay

    def apply(self):
        for name, param in self._model.named_parameters():
            if param.stop_gradient is False:
                assert name in self._shadow
                self._backup[name] = np.array(param.numpy().copy())
                param.set_value(np.array(self._shadow[name]))

    def restore(self):
        for name, param in self._model.named_parameters():
            if param.stop_gradient is False:
                assert name in self._backup
                param.set_value(self._backup[name])
        self._backup = {}
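

# Minimal usage sketch (illustrative, not part of this commit) for the new
# dynamic-graph ExponentialMovingAverage above, driven on a hypothetical toy
# model. With thres_steps enabled, the effective decay is warmed up as
# min(decay, (1 + step) / (10 + step)), so early updates follow the raw
# weights closely before settling at the configured decay.
def _ema_usage_sketch():
    import paddle

    model = paddle.nn.Linear(4, 2)  # toy model
    optimizer = paddle.optimizer.SGD(
        learning_rate=0.1, parameters=model.parameters())
    ema = ExponentialMovingAverage(model, decay=0.999)
    ema.register()  # snapshot the initial parameters

    for step in range(5):  # dummy training steps
        loss = model(paddle.randn([8, 4])).mean()
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
        ema.update()  # fold the current weights into the shadow copy

    ema.apply()  # swap the EMA weights in for evaluation
    # ... run evaluation here ...
    ema.restore()  # swap the raw training weights back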
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil
import sys


def main():
    """
    When training with the use_ema flag, the saved EMA model must be cleaned
    before evaluating it.
    To generate a clean model:

        python ema_clean.py cleaned_model_dir ema_model_dir
    """
    cleaned_model_dir = sys.argv[1]
    ema_model_dir = sys.argv[2]
    if not os.path.exists(cleaned_model_dir):
        os.makedirs(cleaned_model_dir)

    items = os.listdir(ema_model_dir)
    for item in items:
        if item.find('ema') > -1:
            # strip the EMA suffix so the file matches the original
            # parameter name
            item_clean = item.replace('_ema_0', '')
            shutil.copyfile(os.path.join(ema_model_dir, item),
                            os.path.join(cleaned_model_dir, item_clean))
        elif item.find('mean') > -1 or item.find('variance') > -1:
            # batch-norm statistics are copied through unchanged
            shutil.copyfile(os.path.join(ema_model_dir, item),
                            os.path.join(cleaned_model_dir, item))


if __name__ == '__main__':
    main()
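
# Example (hypothetical paths and file names): suppose EMA training saved
# parameters under ./output/ema/, with files such as "conv1_weights_ema_0".
# Running
#     python ema_clean.py ./output/ema_cleaned ./output/ema
# copies "conv1_weights_ema_0" to "./output/ema_cleaned/conv1_weights" and
# copies batch-norm "mean"/"variance" files through unchanged, so the cleaned
# directory can be loaded like an ordinary checkpoint.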
@@ -13,7 +13,7 @@
 # limitations under the License.

 import numpy as np
-import argparse
+import cv2
 import utils
 import shutil
 import os
@@ -30,56 +30,6 @@ from paddle.distributed import ParallelEnv
 import paddle.nn.functional as F


-def parse_args():
-    def str2bool(v):
-        return v.lower() in ("true", "t", "1")
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument("-i", "--image_file", type=str)
-    parser.add_argument("-m", "--model", type=str)
-    parser.add_argument("-p", "--pretrained_model", type=str)
-    parser.add_argument("--class_num", type=int, default=1000)
-    parser.add_argument("--use_gpu", type=str2bool, default=True)
-    parser.add_argument(
-        "--load_static_weights",
-        type=str2bool,
-        default=False,
-        help='Whether to load the pretrained weights saved in static mode')
-
-    # parameters for pre-label the images
-    parser.add_argument(
-        "--pre_label_image",
-        type=str2bool,
-        default=False,
-        help="Whether to pre-label the images using the loaded weights")
-    parser.add_argument("--pre_label_out_idr", type=str, default=None)
-
-    return parser.parse_args()
-
-
-def create_operators():
-    size = 224
-    img_mean = [0.485, 0.456, 0.406]
-    img_std = [0.229, 0.224, 0.225]
-    img_scale = 1.0 / 255.0
-
-    decode_op = utils.DecodeImage()
-    resize_op = utils.ResizeImage(resize_short=256)
-    crop_op = utils.CropImage(size=(size, size))
-    normalize_op = utils.NormalizeImage(
-        scale=img_scale, mean=img_mean, std=img_std)
-    totensor_op = utils.ToTensor()
-
-    return [decode_op, resize_op, crop_op, normalize_op, totensor_op]
-
-
-def preprocess(fname, ops):
-    data = open(fname, 'rb').read()
-    for op in ops:
-        data = op(data)
-    return data
-
-
 def postprocess(outputs, topk=5):
     output = outputs[0]
     prob = np.array(output).flatten()
@@ -112,8 +62,7 @@ def save_prelabel_results(class_id, input_filepath, output_idr):

 def main():
-    args = parse_args()
-    operators = create_operators()
+    args = utils.parse_args()

     # assign the place
     place = 'gpu:{}'.format(ParallelEnv().dev_id) if args.use_gpu else 'cpu'
     place = paddle.set_device(place)
@@ -122,7 +71,8 @@ def main():
     load_dygraph_pretrain(net, args.pretrained_model, args.load_static_weights)
     image_list = get_image_list(args.image_file)
     for idx, filename in enumerate(image_list):
-        data = preprocess(filename, operators)
+        img = cv2.imread(filename)[:, :, ::-1]
+        data = utils.preprocess(img, args)
         data = np.expand_dims(data, axis=0)
         data = paddle.to_tensor(data)
         net.eval()
@@ -12,35 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import argparse
-import utils
+import sys
+sys.path.insert(0, ".")
+
+import tools.infer.utils as utils
 import numpy as np
+import cv2
 import time

 from paddle.inference import Config
 from paddle.inference import create_predictor

-
-def parse_args():
-    def str2bool(v):
-        return v.lower() in ("true", "t", "1")
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument("-i", "--image_file", type=str)
-    parser.add_argument("-m", "--model_file", type=str)
-    parser.add_argument("-p", "--params_file", type=str)
-    parser.add_argument("-b", "--batch_size", type=int, default=1)
-    parser.add_argument("--use_fp16", type=str2bool, default=False)
-    parser.add_argument("--use_gpu", type=str2bool, default=True)
-    parser.add_argument("--ir_optim", type=str2bool, default=True)
-    parser.add_argument("--use_tensorrt", type=str2bool, default=False)
-    parser.add_argument("--gpu_mem", type=int, default=8000)
-    parser.add_argument("--enable_benchmark", type=str2bool, default=False)
-    parser.add_argument("--model_name", type=str)
-
-    return parser.parse_args()
-

 def create_paddle_predictor(args):
     config = Config(args.model_file, args.params_file)
@@ -65,33 +47,7 @@ def create_paddle_predictor(args):
     return predictor


-def create_operators():
-    size = 224
-    img_mean = [0.485, 0.456, 0.406]
-    img_std = [0.229, 0.224, 0.225]
-    img_scale = 1.0 / 255.0
-
-    decode_op = utils.DecodeImage()
-    resize_op = utils.ResizeImage(resize_short=256)
-    crop_op = utils.CropImage(size=(size, size))
-    normalize_op = utils.NormalizeImage(
-        scale=img_scale, mean=img_mean, std=img_std)
-    totensor_op = utils.ToTensor()
-
-    return [decode_op, resize_op, crop_op, normalize_op, totensor_op]
-
-
-def preprocess(fname, ops):
-    data = open(fname, 'rb').read()
-    for op in ops:
-        data = op(data)
-    return data
-
-
-def main():
-    args = parse_args()
-
+def main(args):
     if not args.enable_benchmark:
         assert args.batch_size == 1
         assert args.use_fp16 is False
@@ -102,7 +58,6 @@ def main():
     if args.use_fp16 is True:
         assert args.use_tensorrt is True

-    operators = create_operators()
     predictor = create_paddle_predictor(args)

     input_names = predictor.get_input_names()
@@ -114,7 +69,15 @@ def main():
     test_num = 500
     test_time = 0.0
     if not args.enable_benchmark:
-        inputs = preprocess(args.image_file, operators)
+        # for PaddleHubServing
+        if args.hubserving:
+            img = args.image_file
+        # for predict only
+        else:
+            img = cv2.imread(args.image_file)[:, :, ::-1]
+        assert img is not None, "Error in loading image: {}".format(
+            args.image_file)
+        inputs = utils.preprocess(img, args)
         inputs = np.expand_dims(
             inputs, axis=0).repeat(
                 args.batch_size, axis=0).copy()
@@ -123,12 +86,7 @@ def main():
         predictor.run()

         output = output_tensor.copy_to_cpu()
-        output = output.flatten()
-        cls = np.argmax(output)
-        score = output[cls]
-        print("Current image file: {}".format(args.image_file))
-        print("\ttop-1 class: {0}".format(cls))
-        print("\ttop-1 score: {0}".format(score))
+        return utils.postprocess(output, args)
     else:
         for i in range(0, test_num + 10):
             inputs = np.random.rand(args.batch_size, 3, 224,
@@ -152,4 +110,8 @@ def main():

 if __name__ == "__main__":
-    main()
+    args = utils.parse_args()
+    classes, scores = main(args)
+    print("Current image file: {}".format(args.image_file))
+    print("\ttop-1 class: {0}".format(classes[0]))
+    print("\ttop-1 score: {0}".format(scores[0]))
@@ -12,23 +12,82 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import argparse
+
 import cv2
 import numpy as np


-class DecodeImage(object):
-    def __init__(self, to_rgb=True):
-        self.to_rgb = to_rgb
-
-    def __call__(self, img):
-        data = np.frombuffer(img, dtype='uint8')
-        img = cv2.imdecode(data, 1)
-        if self.to_rgb:
-            assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
-                img.shape)
-            img = img[:, :, ::-1]
-
-        return img
+def parse_args():
+    def str2bool(v):
+        return v.lower() in ("true", "t", "1")
+
+    # general params
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-i", "--image_file", type=str)
+    parser.add_argument("--use_gpu", type=str2bool, default=True)
+
+    # params for preprocess
+    parser.add_argument("--resize_short", type=int, default=256)
+    parser.add_argument("--resize", type=int, default=224)
+    parser.add_argument("--normalize", type=str2bool, default=True)
+
+    # params for predict
+    parser.add_argument("--model_file", type=str)
+    parser.add_argument("--params_file", type=str)
+    parser.add_argument("-b", "--batch_size", type=int, default=1)
+    parser.add_argument("--use_fp16", type=str2bool, default=False)
+    parser.add_argument("--ir_optim", type=str2bool, default=True)
+    parser.add_argument("--use_tensorrt", type=str2bool, default=False)
+    parser.add_argument("--gpu_mem", type=int, default=8000)
+    parser.add_argument("--enable_benchmark", type=str2bool, default=False)
+    parser.add_argument("--model_name", type=str)
+    parser.add_argument("--top_k", type=int, default=1)
+    parser.add_argument("--hubserving", type=str2bool, default=False)
+
+    # params for infer
+    parser.add_argument("--model", type=str)
+    parser.add_argument("--pretrained_model", type=str)
+    parser.add_argument("--class_num", type=int, default=1000)
+    parser.add_argument(
+        "--load_static_weights",
+        type=str2bool,
+        default=False,
+        help='Whether to load the pretrained weights saved in static mode')
+
+    # parameters for pre-label the images
+    parser.add_argument(
+        "--pre_label_image",
+        type=str2bool,
+        default=False,
+        help="Whether to pre-label the images using the loaded weights")
+    parser.add_argument("--pre_label_out_idr", type=str, default=None)
+
+    return parser.parse_args()
+
+
+def preprocess(img, args):
+    resize_op = ResizeImage(resize_short=args.resize_short)
+    img = resize_op(img)
+    crop_op = CropImage(size=(args.resize, args.resize))
+    img = crop_op(img)
+    if args.normalize:
+        img_mean = [0.485, 0.456, 0.406]
+        img_std = [0.229, 0.224, 0.225]
+        img_scale = 1.0 / 255.0
+        normalize_op = NormalizeImage(
+            scale=img_scale, mean=img_mean, std=img_std)
+        img = normalize_op(img)
+    tensor_op = ToTensor()
+    img = tensor_op(img)
+    return img
+
+
+def postprocess(output, args):
+    output = output.flatten()
+    classes = np.argpartition(output, -args.top_k)[-args.top_k:]
+    classes = classes[np.argsort(-output[classes])]
+    scores = output[classes]
+    return classes, scores


 class ResizeImage(object):
@@ -82,3 +141,15 @@ class ToTensor(object):
     def __call__(self, img):
         img = img.transpose((2, 0, 1))
         return img
+
+
+class Base64ToCV2(object):
+    def __init__(self):
+        pass
+
+    def __call__(self, b64str):
+        import base64
+        data = base64.b64decode(b64str.encode('utf8'))
+        data = np.fromstring(data, np.uint8)
+        data = cv2.imdecode(data, cv2.IMREAD_COLOR)[:, :, ::-1]
+        return data
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

from ppcls.utils import logger

import time
import requests
import json
import base64
import imghdr


def get_image_file_list(img_file):
    imgs_lists = []
    if img_file is None or not os.path.exists(img_file):
        raise Exception("not found any img file in {}".format(img_file))

    img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'}
    if os.path.isfile(img_file) and imghdr.what(img_file) in img_end:
        imgs_lists.append(img_file)
    elif os.path.isdir(img_file):
        for single_file in os.listdir(img_file):
            file_path = os.path.join(img_file, single_file)
            if imghdr.what(file_path) in img_end:
                imgs_lists.append(file_path)
    if len(imgs_lists) == 0:
        raise Exception("not found any img file in {}".format(img_file))
    return imgs_lists


def cv2_to_base64(image):
    # `image` is the raw bytes of the image file
    return base64.b64encode(image).decode('utf8')


def main(url, image_path, top_k=1):
    image_file_list = get_image_file_list(image_path)
    headers = {"Content-type": "application/json"}
    cnt = 0
    total_time = 0
    all_score = 0.0
    for image_file in image_file_list:
        img = open(image_file, 'rb').read()
        if img is None:
            logger.info("error in loading image: {}".format(image_file))
            continue

        data = {'images': [cv2_to_base64(img)], 'top_k': top_k}
        starttime = time.time()
        r = requests.post(url=url, headers=headers, data=json.dumps(data))
        assert r.status_code == 200, "Request error, status_code: {}".format(
            r.status_code)
        elapse = time.time() - starttime
        total_time += elapse

        res = r.json()["results"][0]
        classes = res[0]
        scores = res[1]
        all_score += scores[0]
        cnt += 1

        scores = map(lambda x: round(x, 5), scores)
        results = dict(zip(classes, scores))

        file_str = image_file.split('/')[-1]
        message = "No.{}, File:{}, The top-{} result(s): {}, Time cost: {:.3f}".format(
            cnt, file_str, top_k, results, elapse)
        logger.info(message)

    logger.info("The average time cost: {}".format(float(total_time) / cnt))
    logger.info("The average top-1 score: {}".format(float(all_score) / cnt))


if __name__ == '__main__':
    if len(sys.argv) != 3 and len(sys.argv) != 4:
        logger.info("Usage: %s server_url image_path [top_k]" % sys.argv[0])
    else:
        server_url = sys.argv[1]
        image_path = sys.argv[2]
        top_k = int(sys.argv[3]) if len(sys.argv) == 4 else 1
        main(server_url, image_path, top_k)