Unverified · Commit 69e91d85 · Author: littletomatodonkey · Committer: GitHub

add note (#5427)

* add note

* fix log

* fix hflip in train

* fix readme

* fix serving doc

* fix link
Parent 41c2fd76
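...@@ -35,12 +35,14 @@ Paddle Serving, built on the deep learning framework PaddlePaddle, aims to help deep learning develop ...@@ -35,12 +35,14 @@ Paddle Serving, built on the deep learning framework PaddlePaddle, aims to help deep learning develop
<a name="21---"></a> <a name="21---"></a>
### 2.1 Prepare test data and deployment environment ### 2.1 Prepare test data and deployment environment
【Basic workflow】 **【Basic workflow】**
**(1) Prepare test data:** Take at least one image from the validation or test set to verify the subsequent inference process. **(1) Prepare test data:** Take at least one image from the validation or test set to verify the subsequent inference process.
**(2) Prepare the deployment environment** **(2) Prepare the deployment environment**
docker is an open-source application container engine that makes applications easier to package and port. We recommend running the Serving deployment inside docker.
First prepare the docker environment. The AIStudio environment already has a suitable docker installed. In a non-AIStudio environment, install docker by following "1.3.2 Docker environment configuration" in the [reference document](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/environment.md#2). First prepare the docker environment. The AIStudio environment already has a suitable docker installed. In a non-AIStudio environment, install docker by following "1.3.2 Docker environment configuration" in the [reference document](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/environment.md#2).
Then install the three Paddle Serving packages: paddle-serving-server, paddle-serving-client, and paddle-serving-app. Then install the three Paddle Serving packages: paddle-serving-server, paddle-serving-client, and paddle-serving-app.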
...@@ -68,7 +70,7 @@ For whl package download links for more Paddle Serving Server runtime environments, see ...@@ -68,7 +70,7 @@ For whl package download links for more Paddle Serving Server runtime environments, see
``` ```
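python3 -m paddle_serving_client.convert --dirname {static graph model path} --model_filename {model structure file} --params_filename {model parameter file} --serving_server {output path for the converted server-side model and config files} --serving_client {output path for the converted client-side model and config files} python3 -m paddle_serving_client.convert --dirname {static graph model path} --model_filename {model structure file} --params_filename {model parameter file} --serving_server {output path for the converted server-side model and config files} --serving_client {output path for the converted client-side model and config files}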
``` ```
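In the command above, the "converted server-side model and config files" will be used for the subsequent serving deployment. In the command above, the "converted server-side model and config files" will be used for the subsequent serving deployment. The `paddle_serving_client.convert` command is the conversion function built into the `paddle_serving_client` whl package and does not need to be modified.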
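
As a rough sketch of how the conversion command above could be assembled from Python, the snippet below builds the argument list from the flags shown in the command. All directory and file names are hypothetical placeholders rather than values from this repository, and the command is only constructed here, not executed.

```python
import shlex


def build_convert_command(dirname, model_filename, params_filename,
                          serving_server, serving_client):
    """Return the paddle_serving_client.convert call as an argument list."""
    return [
        "python3", "-m", "paddle_serving_client.convert",
        "--dirname", dirname,
        "--model_filename", model_filename,
        "--params_filename", params_filename,
        "--serving_server", serving_server,
        "--serving_client", serving_client,
    ]


# Hypothetical paths, for illustration only.
cmd = build_convert_command(
    dirname="./inference_model",
    model_filename="inference.pdmodel",
    params_filename="inference.pdiparams",
    serving_server="./serving_server",
    serving_client="./serving_client",
)

# Print a shell-safe version; pass `cmd` to subprocess.run to execute it.
print(" ".join(shlex.quote(arg) for arg in cmd))
```

Keeping the arguments in a list avoids shell-quoting problems if a path contains spaces.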
【Hands-on】 【Hands-on】
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np import numpy as np
import requests import requests
import json import json
...@@ -7,6 +21,15 @@ import os ...@@ -7,6 +21,15 @@ import os
def cv2_to_base64(image): def cv2_to_base64(image):
"""cv2_to_base64
Convert a numpy array to a base64 string.
Args:
image: Input array.
Returns: Base64 output of the input.
"""
return base64.b64encode(image).decode('utf8') return base64.b64encode(image).decode('utf8')
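
A minimal sketch of the Base64 round trip that `cv2_to_base64` relies on, using only the standard library. The byte string stands in for encoded image bytes, and `bytes_to_base64` and `base64_to_bytes` are hypothetical helper names: the first mirrors the encoding step above, the second shows the decoding a receiving service would have to perform.

```python
import base64


def bytes_to_base64(data: bytes) -> str:
    """Encode raw bytes as a UTF-8 Base64 string (same call as cv2_to_base64)."""
    return base64.b64encode(data).decode("utf8")


def base64_to_bytes(text: str) -> bytes:
    """Decode a Base64 string back into the original bytes."""
    return base64.b64decode(text.encode("utf8"))


payload = bytes(range(8))  # stand-in for image bytes
encoded = bytes_to_base64(payload)
decoded = base64_to_bytes(encoded)
print(encoded, decoded == payload)
```

Because the encoded value is a plain string, it can be placed directly in a JSON request body.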
......
...@@ -16,26 +16,68 @@ from paddle_serving_server.web_service import WebService, Op ...@@ -16,26 +16,68 @@ from paddle_serving_server.web_service import WebService, Op
class TIPCExampleOp(Op): class TIPCExampleOp(Op):
""" """TIPCExampleOp
ExampleOp for serving server, you can rename by yourself
ExampleOp for the serving server. You can rename it as needed.
""" """
def init_op(self): def init_op(self):
""" """init_op
initialize the class
Initialize the class.
Args: None
Returns: None
""" """
pass pass
def preprocess(self, input_dicts, data_id, log_id): def preprocess(self, input_dicts, data_id, log_id):
# preprocess for the inputs """preprocess
In the preprocess stage, assemble data for the process stage. Users can
override this function to build the model feed features.
Args:
input_dicts: input data to be preprocessed
data_id: inner unique id, auto-incremented
log_id: global unique id for RTT, 0 by default
Returns:
output_data: data for the process stage
is_skip_process: whether to skip the process stage, False by default
prod_errcode: None by default; otherwise a product error occurred.
It is handled in the same way as an exception.
prod_errinfo: "" by default
"""
pass pass
def postprocess(self, input_dicts, fetch_dict, data_id, log_id): def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
"""postprocess
In the postprocess stage, assemble data for the next op or the output.
Args:
input_dicts: data returned in the preprocess stage, dict (for single predict) or list (for batch predict)
fetch_dict: data returned in the process stage, dict (for single predict) or list (for batch predict)
data_id: inner unique id, auto-incremented
log_id: logid, 0 by default
Returns:
fetch_dict: fetch result, must be of dict type.
prod_errcode: None by default; otherwise a product error occurred.
It is handled in the same way as an exception.
prod_errinfo: "" by default
"""
# postprocess for the service output # postprocess for the service output
pass pass
class TIPCExampleService(WebService): class TIPCExampleService(WebService):
"""TIPCExampleService
Service class to define the Serving OP.
"""
def get_pipeline_response(self, read_op): def get_pipeline_response(self, read_op):
tipc_example_op = TIPCExampleOp( tipc_example_op = TIPCExampleOp(
name="tipc_example", input_ops=[read_op]) name="tipc_example", input_ops=[read_op])
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
...@@ -9,6 +23,16 @@ import numpy as np ...@@ -9,6 +23,16 @@ import numpy as np
# parse args # parse args
def get_args(add_help=True): def get_args(add_help=True):
"""get_args
Parse all args using argparse lib
Args:
add_help: Whether to add -h option on args
Returns:
An object which contains many parameters used for inference.
"""
import argparse import argparse
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='PaddlePaddle Args', add_help=add_help) description='PaddlePaddle Args', add_help=add_help)
...@@ -17,14 +41,29 @@ def get_args(add_help=True): ...@@ -17,14 +41,29 @@ def get_args(add_help=True):
def build_model(args): def build_model(args):
""" """build_model
build model
Build your own model.
Args:
args: Parameters generated using argparser.
Returns:
A model whose type is nn.Layer
""" """
pass pass
def export(args): def export(args):
# build your own model """export
export inference model using jit.save
Args:
args: Parameters generated using argparser.
Returns: None
"""
model = build_model(args) model = build_model(args)
# decorate model with jit.save # decorate model with jit.save
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
import paddle import paddle
from paddle import inference from paddle import inference
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from reprod_log import ReprodLogger
from preprocess_ops import ResizeImage, CenterCropImage, NormalizeImage, ToCHW, Compose
class InferenceEngine(object): class InferenceEngine(object):
"""InferenceEngine
Inference engine class which contains preprocess, run, and postprocess.
"""
def __init__(self, args): def __init__(self, args):
"""
Args:
args: Parameters generated using argparser.
Returns: None
"""
super().__init__() super().__init__()
pass pass
def load_predictor(self, model_file_path, params_file_path): def load_predictor(self, model_file_path, params_file_path):
""" """load_predictor
initialize the inference engine initialize the inference engine
Args:
model_file_path: inference model path (*.pdmodel)
params_file_path: inference parameter path (*.pdiparams)
Returns: None
""" """
pass pass
def preprocess(self, img_path): def preprocess(self, x):
# preprocess for data """preprocess
Preprocess the input.
Args:
x: Raw input, it can be an image path, a numpy array and so on.
Returns: Input data after preprocess.
"""
pass pass
def postprocess(self, x): def postprocess(self, x):
# postprocess for the inference engine output """postprocess
Postprocess the inference engine output.
Args:
x: Inference engine output.
Returns: Output data after postprocess.
"""
pass pass
def run(self, x): def run(self, x):
# run using the infer """run
Inference process using inference engine.
Args:
x: Input data after preprocess.
Returns: Inference engine output
"""
pass pass
...@@ -46,6 +99,17 @@ def get_args(add_help=True): ...@@ -46,6 +99,17 @@ def get_args(add_help=True):
def infer_main(args): def infer_main(args):
"""infer_main
Main inference function.
Args:
args: Parameters generated using argparser.
Returns:
class_id: Class index of the input.
prob: Probability of the input.
"""
# init inference engine # init inference engine
inference_engine = InferenceEngine(args) inference_engine = InferenceEngine(args)
...@@ -86,4 +150,4 @@ def infer_main(args): ...@@ -86,4 +150,4 @@ def infer_main(args):
if __name__ == "__main__": if __name__ == "__main__":
args = get_args() args = get_args()
infer_main(args) infer_main(args)
\ No newline at end of file
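
The `infer_main` docstring says the function returns a `class_id` and a `prob`. One common way a classification postprocess step produces these is a softmax followed by an argmax. The sketch below assumes that approach and uses plain Python lists for the logits; it is not taken from the elided body of `postprocess`.

```python
import math


def postprocess(logits):
    """Return (class_id, prob) from raw scores via a numerically stable softmax."""
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    probs = [value / total for value in exps]
    # Index of the largest probability (first one wins on ties).
    class_id = max(range(len(probs)), key=probs.__getitem__)
    return class_id, probs[class_id]


class_id, prob = postprocess([0.5, 2.0, -1.0])
print(f"class_id: {class_id}, prob: {prob}")
```

Subtracting the maximum before exponentiating leaves the result unchanged and avoids overflow for large scores.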
...@@ -29,6 +29,13 @@ ...@@ -29,6 +29,13 @@
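Thanks to [vision](https://github.com/pytorch/vision), which made reproducing the MobileNetV3 paper more efficient. Thanks to [vision](https://github.com/pytorch/vision), which made reproducing the MobileNetV3 paper more efficient.
Note: to simplify the process, training is aligned only against the `ImageNet standard training procedure`. Specifically:
* Training runs for 120 epochs in total, with a total batch size of 256*8=2048 and a learning rate of 0.8; the decay strategy is Piecewise Decay (the learning rate is reduced by a factor of 10 every 30 epochs)
* Training preprocessing: RandomResizedCrop(size=224) + RandomFlip(p=0.5) + Normalize
* Evaluation preprocessing: Resize(256) + CenterCrop(224) + Normalize
The reference metrics for `mobilenet_v3_small` here were also obtained by retraining.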
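
The schedule described above can be sketched in a few lines. This is an illustration only, assuming a step size of 30 epochs and a decay factor of 0.1 as stated in the note. The helper names `piecewise_lr` and `scaled_lr` are hypothetical. `scaled_lr` reflects the pattern that the learning rates in this diff appear to follow (0.1 for 1 GPU, 0.4 for 4 GPUs, 0.8 for 8 GPUs, each with a per-card batch size of 256), which is an inference from those values rather than a documented rule.

```python
def piecewise_lr(epoch, base_lr=0.8, step_size=30, gamma=0.1):
    """Learning rate at a given epoch: multiplied by gamma every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)


def scaled_lr(num_gpus, lr_per_gpu=0.1):
    """Scale the learning rate linearly with the number of GPUs."""
    return lr_per_gpu * num_gpus


for epoch in (0, 29, 30, 60, 90, 119):
    print(f"epoch {epoch:3d}: lr = {piecewise_lr(epoch):.6f}")

for gpus in (1, 4, 8):
    print(f"{gpus} GPU(s): total batch = {256 * gpus}, lr = {scaled_lr(gpus):.1f}")
```

With 120 epochs and a 30-epoch step, the learning rate takes four distinct values over the run.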
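## 2. Dataset and reproduced accuracy ## 2. Dataset and reproduced accuracy
The dataset is ImageNet; the training set contains 1281167 images and the validation set contains 50000 images. The dataset is ImageNet; the training set contains 1281167 images and the validation set contains 50000 images.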
...@@ -38,9 +45,7 @@ ...@@ -38,9 +45,7 @@
| Model | top1/5 acc (reference) | top1/5 acc (reproduced) | Download links | | Model | top1/5 acc (reference) | top1/5 acc (reproduced) | Download links |
|:---------:|:------:|:----------:|:----------:| |:---------:|:------:|:----------:|:----------:|
| Mo | 0.677/0.874 | 0.677/0.874 | [Pretrained model](https://paddle-model-ecology.bj.bcebos.com/model/mobilenetv3_reprod/mobilenet_v3_small_paddle_pretrained.pdparams) \| [Inference model](https://paddle-model-ecology.bj.bcebos.com/model/mobilenetv3_reprod/mobilenet_v3_small_paddle_infer.tar) \| [Log (coming soon)]() | | Mo | -/- | 0.601/0.826 | [Pretrained model](https://paddle-model-ecology.bj.bcebos.com/model/mobilenetv3_reprod/mobilenet_v3_small_pretrained.pdparams) \| [Inference model (coming soon!)]() \| [Log](https://paddle-model-ecology.bj.bcebos.com/model/mobilenetv3_reprod/train_mobilenet_v3_small.log) |
* Note: the pretrained model currently provided was converted from the weights supplied with the reference code; full training results and logs are coming soon!
## 3. Prepare environment and data ## 3. Prepare environment and data
...@@ -95,7 +100,7 @@ tar -xf test_images/lite_data.tar ...@@ -95,7 +100,7 @@ tar -xf test_images/lite_data.tar
```bash ```bash
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
python3.7 train.py --data-path=./ILSVRC2012 --lr=0.00125 --batch-size=32 python3.7 train.py --data-path=./ILSVRC2012 --lr=0.1 --batch-size=256
``` ```
Part of the training log is shown below. Part of the training log is shown below.
...@@ -109,7 +114,7 @@ python3.7 train.py --data-path=./ILSVRC2012 --lr=0.00125 --batch-size=32 ...@@ -109,7 +114,7 @@ python3.7 train.py --data-path=./ILSVRC2012 --lr=0.00125 --batch-size=32
```bash ```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3 export CUDA_VISIBLE_DEVICES=0,1,2,3
python3.7 -m paddle.distributed.launch --gpus="0,1,2,3" train.py --data-path="./ILSVRC2012" --lr=0.01 --batch-size=64 python3.7 -m paddle.distributed.launch --gpus="0,1,2,3" train.py --data-path="./ILSVRC2012" --lr=0.4 --batch-size=256
``` ```
For more configuration parameters, see the `get_args_parser` function in [train.py](./train.py). For more configuration parameters, see the `get_args_parser` function in [train.py](./train.py).
...@@ -146,7 +151,7 @@ python tools/predict.py --pretrained=./mobilenet_v3_small_paddle_pretrained.pdpa ...@@ -146,7 +151,7 @@ python tools/predict.py --pretrained=./mobilenet_v3_small_paddle_pretrained.pdpa
<img src="./images/demo.jpg" width="300"> <img src="./images/demo.jpg" width="300">
</div> </div>
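The final output is `class_id: 8, prob: 0.9503437280654907`, meaning the predicted class ID is `8` with a confidence of `0.950` The final output is `class_id: 8, prob: 0.9091238975524902`, meaning the predicted class ID is `8` with a confidence of `0.909`
* Predict with CPU * Predict with CPU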
......
...@@ -397,3 +397,19 @@ def resized_crop( ...@@ -397,3 +397,19 @@ def resized_crop(
img = crop(img, top, left, height, width) img = crop(img, top, left, height, width)
img = resize(img, size, interpolation) img = resize(img, size, interpolation)
return img return img
def hflip(img):
"""Horizontally flip the given image.
Args:
img (PIL Image or Tensor): Image to be flipped. If img
is a Tensor, it is expected to be in [..., H, W] format,
where ... means it can have an arbitrary number of leading
dimensions.
Returns:
PIL Image or Tensor: Horizontally flipped image.
"""
if not isinstance(img, paddle.Tensor):
return F_pil.hflip(img)
return F_t.hflip(img)
...@@ -268,3 +268,7 @@ def resize(img: Tensor, ...@@ -268,3 +268,7 @@ def resize(img: Tensor,
out_dtype=out_dtype) out_dtype=out_dtype)
return img return img
def hflip(img):
return img.flip(-1)
import math import math
import numbers import numbers
import random
import warnings import warnings
from collections.abc import Sequence from collections.abc import Sequence
from typing import Tuple, List from typing import Tuple, List
...@@ -17,7 +18,7 @@ from .functional import InterpolationMode, _interpolation_modes_from_int ...@@ -17,7 +18,7 @@ from .functional import InterpolationMode, _interpolation_modes_from_int
__all__ = [ __all__ = [
"Compose", "ToTensor", "Normalize", "Resize", "CenterCrop", "Compose", "ToTensor", "Normalize", "Resize", "CenterCrop",
"RandomResizedCrop" "RandomResizedCrop", "RandomHorizontalFlip"
] ]
...@@ -370,3 +371,28 @@ def _setup_size(size, error_msg): ...@@ -370,3 +371,28 @@ def _setup_size(size, error_msg):
raise ValueError(error_msg) raise ValueError(error_msg)
return size return size
class RandomHorizontalFlip(paddle.nn.Layer):
"""Horizontally flip the given image randomly with a given probability.
If the image is a paddle Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading
dimensions.
Args:
p (float): probability of the image being flipped. Default value is 0.5
"""
def __init__(self, p=0.5):
super().__init__()
self.p = p
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be flipped.
Returns:
PIL Image or Tensor: Randomly flipped image.
"""
if random.random() < self.p:
return F.hflip(img)
return img
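
To make the flip logic concrete without depending on Paddle, the sketch below applies the same idea to a nested list of rows: flipping along the last (width) axis reverses each row, and the transform flips only when a random draw falls below `p`. The `rng` parameter is an addition for reproducible illustration and is not part of the class in this diff.

```python
import random


def hflip(rows):
    """Flip a 2-D image (list of rows) horizontally by reversing each row."""
    return [list(reversed(row)) for row in rows]


class RandomHorizontalFlip:
    """Flip the input with probability p; otherwise return it unchanged."""

    def __init__(self, p=0.5, rng=None):
        self.p = p
        # Hypothetical hook: inject a seeded generator for repeatable runs.
        self.rng = rng if rng is not None else random.Random()

    def __call__(self, rows):
        if self.rng.random() < self.p:
            return hflip(rows)
        return rows


image = [[1, 2, 3], [4, 5, 6]]
print(hflip(image))
print(RandomHorizontalFlip(p=1.0)(image))
print(RandomHorizontalFlip(p=0.0)(image))
```

Since `random()` returns values in [0, 1), `p=1.0` always flips and `p=0.0` never does, which is why the training preset only adds the transform when `hflip_prob > 0`.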
...@@ -10,8 +10,8 @@ class ClassificationPresetTrain: ...@@ -10,8 +10,8 @@ class ClassificationPresetTrain:
auto_augment_policy=None, auto_augment_policy=None,
random_erase_prob=0.0): random_erase_prob=0.0):
trans = [transforms.RandomResizedCrop(crop_size)] trans = [transforms.RandomResizedCrop(crop_size)]
#if hflip_prob > 0: if hflip_prob > 0:
# trans.append(transforms.RandomHorizontalFlip(hflip_prob)) trans.append(transforms.RandomHorizontalFlip(hflip_prob))
#if auto_augment_policy is not None: #if auto_augment_policy is not None:
# aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) # aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
# trans.append(autoaugment.AutoAugment(policy=aa_policy)) # trans.append(autoaugment.AutoAugment(policy=aa_policy))
......
export CUDA_VISIBLE_DEVICES=0
python3.7 train.py \
--data-path /paddle/data/ILSVRC2012/ \
--model mobilenet_v3_small \
--lr 0.1 \
--batch-size=256 \
--output-dir "./output/" \
--epochs 120 \
--workers=6
...@@ -5,7 +5,7 @@ python3.7 -m paddle.distributed.launch \ ...@@ -5,7 +5,7 @@ python3.7 -m paddle.distributed.launch \
train.py \ train.py \
--data-path /paddle/data/ILSVRC2012/ \ --data-path /paddle/data/ILSVRC2012/ \
--model mobilenet_v3_small \ --model mobilenet_v3_small \
--lr 0.4 \ --lr 0.8 \
--batch-size=256 \ --batch-size=256 \
--output-dir "./output/" \ --output-dir "./output/" \
--epochs 120 \ --epochs 120 \
......
...@@ -221,7 +221,6 @@ def main(args): ...@@ -221,7 +221,6 @@ def main(args):
lr_scheduler.step() lr_scheduler.step()
if paddle.distributed.get_rank() == 0: if paddle.distributed.get_rank() == 0:
top1 = evaluate(model, criterion, data_loader_test, device=device) top1 = evaluate(model, criterion, data_loader_test, device=device)
best_top1 = max(best_top1, top1)
if args.output_dir: if args.output_dir:
paddle.save(model.state_dict(), paddle.save(model.state_dict(),
os.path.join(args.output_dir, os.path.join(args.output_dir,
...@@ -233,6 +232,12 @@ def main(args): ...@@ -233,6 +232,12 @@ def main(args):
os.path.join(args.output_dir, 'latest.pdparams')) os.path.join(args.output_dir, 'latest.pdparams'))
paddle.save(optimizer.state_dict(), paddle.save(optimizer.state_dict(),
os.path.join(args.output_dir, 'latest.pdopt')) os.path.join(args.output_dir, 'latest.pdopt'))
if top1 > best_top1:
best_top1 = top1
paddle.save(model.state_dict(),
os.path.join(args.output_dir, 'best.pdparams'))
paddle.save(optimizer.state_dict(),
os.path.join(args.output_dir, 'best.pdopt'))
total_time = time.time() - start_time total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time))) total_time_str = str(datetime.timedelta(seconds=int(total_time)))
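
The change in this hunk replaces an unconditional `max` update with a comparison, so that a separate `best` checkpoint is written only when the evaluation metric improves. A minimal sketch of that bookkeeping, without any model or file I/O, is shown below. The starting value of `0.0` for `best_top1` and the function name `track_best` are assumptions, since the initialization is outside the lines shown.

```python
def track_best(top1_per_epoch):
    """Return the best top-1 value and the epochs at which a 'best' save would occur."""
    best_top1 = 0.0  # assumed initial value
    saved_as_best = []
    for epoch, top1 in enumerate(top1_per_epoch):
        # The latest checkpoint would be saved every epoch (not modeled here).
        if top1 > best_top1:
            best_top1 = top1
            saved_as_best.append(epoch)
    return best_top1, saved_as_best


best, epochs = track_best([0.10, 0.30, 0.20, 0.50])
print(best, epochs)
```

The strict `>` comparison means a tie with the current best does not overwrite the saved checkpoint.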
...@@ -286,7 +291,7 @@ def get_args_parser(add_help=True): ...@@ -286,7 +291,7 @@ def get_args_parser(add_help=True):
type=float, type=float,
help='decrease lr by a factor of lr-gamma') help='decrease lr by a factor of lr-gamma')
parser.add_argument( parser.add_argument(
'--print-freq', default=1, type=int, help='print frequency') '--print-freq', default=10, type=int, help='print frequency')
parser.add_argument('--output-dir', default='.', help='path where to save') parser.add_argument('--output-dir', default='.', help='path where to save')
parser.add_argument('--resume', default='', help='resume from checkpoint') parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument( parser.add_argument(
......