Commit 69e91d85 (unverified)
Authored by littletomatodonkey on Dec 20, 2021; committed via GitHub on Dec 20, 2021.

add note (#5427)

* add note * fix log * fix hflip in train * fix readme * fix serving doc * fix link

Parent: 41c2fd76

Showing 13 changed files with 267 additions and 31 deletions (+267, -31).
docs/tipc/serving/serving.md  (+4, -2)
docs/tipc/serving/template/code/pipeline_http_client.py  (+23, -0)
docs/tipc/serving/template/code/web_service.py  (+47, -5)
docs/tipc/train_infer_python/template/code/export_model.py  (+42, -3)
docs/tipc/train_infer_python/template/code/infer.py  (+73, -9)
tutorials/mobilenetv3_prod/Step6/README.md  (+11, -6)
tutorials/mobilenetv3_prod/Step6/paddlevision/transforms/functional.py  (+16, -0)
tutorials/mobilenetv3_prod/Step6/paddlevision/transforms/functional_tensor.py  (+4, -0)
tutorials/mobilenetv3_prod/Step6/paddlevision/transforms/transforms.py  (+27, -1)
tutorials/mobilenetv3_prod/Step6/presets.py  (+2, -2)
tutorials/mobilenetv3_prod/Step6/shell/train.sh  (+10, -0)
tutorials/mobilenetv3_prod/Step6/shell/train_dist.sh  (+1, -1)
tutorials/mobilenetv3_prod/Step6/train.py  (+7, -2)
docs/tipc/serving/serving.md  (+4, -2)

@@ -35,12 +35,14 @@ Paddle Serving, backed by the PaddlePaddle deep learning framework, aims to help deep learning developers...

<a name="21---"></a>

### 2.1 Prepare test data and the deployment environment

**【基本流程】**

**(1) Prepare test data:**

Take at least one image from the validation or test set; it will be used later to verify the inference process.

**(2) Prepare the deployment environment**

Docker is an open-source application container engine that makes applications easier to package and port. Serving deployment is recommended to run inside Docker.

First prepare the Docker environment. The AIStudio environment already ships with a suitable Docker installation. In a non-AIStudio environment, install Docker following "1.3.2 Docker environment setup" in the [reference documentation](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/environment.md#2).

Then install the three Paddle Serving packages: paddle-serving-server, paddle-serving-client, and paddle-serving-app.

@@ -68,7 +70,7 @@ For more download links of Paddle Serving Server wheels for other runtime environments, see...

```
python3 -m paddle_serving_client.convert --dirname {static graph model dir} --model_filename {model structure file} --params_filename {model parameter file} --serving_server {output dir for the converted server-side model and config} --serving_client {output dir for the converted client-side model and config}
```

The "converted server-side model and config files" produced by the command above are used in the subsequent serving deployment. The `paddle_serving_client.convert` command is a conversion utility built into the `paddle_serving_client` wheel and does not need to be modified.

【实战】
docs/tipc/serving/template/code/pipeline_http_client.py  (+23, -0)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import requests
import json

@@ -7,6 +21,15 @@ import os

def cv2_to_base64(image):
    """cv2_to_base64

    Convert a numpy array to a base64 object.

    Args:
        image: Input array.
    Returns: Base64 output of the input.
    """
    return base64.b64encode(image).decode('utf8')
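For reference, here is a minimal sketch of how a completed client might call the deployed pipeline service. The port, op name (`tipc_example`, taken from the web_service.py template below), feed key `"image"`, and image path are illustrative assumptions; the `{"key": [...], "value": [...]}` payload follows the usual Paddle Serving pipeline convention.

```python
import base64
import json

import requests


def cv2_to_base64(image):
    """Encode raw image bytes (or a numpy buffer) as a base64 string."""
    return base64.b64encode(image).decode('utf8')


if __name__ == "__main__":
    # Hypothetical endpoint: http://<host>:<http_port>/<op_name>/prediction
    url = "http://127.0.0.1:18080/tipc_example/prediction"
    with open("./images/demo.jpg", "rb") as f:   # any local test image
        image_data = f.read()
    payload = {"key": ["image"], "value": [cv2_to_base64(image_data)]}
    response = requests.post(url=url, data=json.dumps(payload))
    print(response.json())
```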
docs/tipc/serving/template/code/web_service.py  (+47, -5)

@@ -16,26 +16,68 @@ from paddle_serving_server.web_service import WebService, Op

class TIPCExampleOp(Op):
    """TIPCExampleOp

    ExampleOp for the serving server. You can rename it yourself.
    """

    def init_op(self):
        """init_op

        Initialize the class.

        Args: None
        Returns: None
        """
        pass

    def preprocess(self, input_dicts, data_id, log_id):
        """preprocess

        In the preprocess stage, data is assembled for the process stage. Users can
        override this function to build the model feed features.

        Args:
            input_dicts: input data to be preprocessed
            data_id: inner unique id, auto-incremented
            log_id: global unique id for RTT, 0 by default
        Return:
            output_data: data for the process stage
            is_skip_process: whether to skip the process stage, False by default
            prod_errcode: None by default; otherwise a product error occurred.
                It is handled in the same way as an exception.
            prod_errinfo: "" by default
        """
        pass

    def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
        """postprocess

        In the postprocess stage, data is assembled for the next op or for output.

        Args:
            input_dicts: data returned in the preprocess stage, dict (single predict) or list (batch predict)
            fetch_dict: data returned in the process stage, dict (single predict) or list (batch predict)
            data_id: inner unique id, auto-incremented
            log_id: log id, 0 by default
        Returns:
            fetch_dict: the fetch result, which must be of dict type.
            prod_errcode: None by default; otherwise a product error occurred.
                It is handled in the same way as an exception.
            prod_errinfo: "" by default
        """
        # postprocess for the service output
        pass


class TIPCExampleService(WebService):
    """TIPCExampleService

    Service class used to define the serving Op.
    """

    def get_pipeline_response(self, read_op):
        tipc_example_op = TIPCExampleOp(name="tipc_example", input_ops=[read_op])
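As a concrete illustration, here is a hedged sketch of how the two hooks might be filled in for a simple image-classification op. The feed name `"input"`, fetch name `"softmax"`, and the cv2-based preprocessing are assumptions rather than part of the template; the return tuples follow the docstrings above.

```python
import base64

import cv2
import numpy as np
from paddle_serving_server.web_service import Op


class ExampleClsOp(Op):
    """Hypothetical filled-in op for image classification."""

    def preprocess(self, input_dicts, data_id, log_id):
        # The pipeline hands over one upstream dict; take its single entry.
        (_, input_dict), = input_dicts.items()
        raw = base64.b64decode(input_dict["image"].encode("utf8"))
        img = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
        img = cv2.resize(img, (224, 224)).astype("float32") / 255.0
        img = img.transpose((2, 0, 1))[np.newaxis, :]     # NCHW, batch of 1
        # output_data, is_skip_process, prod_errcode, prod_errinfo
        return {"input": img}, False, None, ""

    def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
        scores = fetch_dict["softmax"]
        result = {"class_id": str(int(np.argmax(scores))),
                  "prob": str(float(np.max(scores)))}
        # fetch_dict, prod_errcode, prod_errinfo
        return result, None, ""
```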
docs/tipc/train_infer_python/template/code/export_model.py  (+42, -3)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

@@ -9,6 +23,16 @@ import numpy as np

# parse args
def get_args(add_help=True):
    """get_args

    Parse all args using the argparse lib.

    Args:
        add_help: Whether to add the -h option on args.
    Returns:
        An object which contains many parameters used for inference.
    """
    import argparse
    parser = argparse.ArgumentParser(
        description='PaddlePaddle Args', add_help=add_help)

@@ -17,14 +41,29 @@ def get_args(add_help=True):

def build_model(args):
    """build_model

    Build your own model.

    Args:
        args: Parameters generated using argparse.
    Returns:
        A model whose type is nn.Layer.
    """
    pass


def export(args):
    """export

    Export the inference model using jit.save.

    Args:
        args: Parameters generated using argparse.
    Returns: None
    """
    model = build_model(args)

    # decorate model with jit.save
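The template leaves the body of `export` as a stub. Below is a minimal sketch of what it might look like with `paddle.jit.save`, assuming an image model that takes a `[N, 3, 224, 224]` float32 input; the toy model and the output prefix are placeholders for `build_model(args)` and the real export directory.

```python
import os

import paddle
import paddle.nn as nn
from paddle.static import InputSpec


def export(model, save_prefix):
    """Export a trained dygraph model as a static inference model."""
    model.eval()
    # Trace the model with a fixed input signature.
    static_model = paddle.jit.to_static(
        model,
        input_spec=[InputSpec(shape=[None, 3, 224, 224], dtype="float32")])
    # Writes {save_prefix}.pdmodel and {save_prefix}.pdiparams,
    # which the inference script loads later.
    paddle.jit.save(static_model, save_prefix)


if __name__ == "__main__":
    # Toy stand-in for build_model(args).
    toy = nn.Sequential(
        nn.Conv2D(3, 8, 3, padding=1), nn.AdaptiveAvgPool2D(1),
        nn.Flatten(), nn.Linear(8, 10))
    os.makedirs("./output", exist_ok=True)
    export(toy, "./output/inference")
```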
docs/tipc/train_infer_python/template/code/infer.py  (+73, -9)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
from paddle import inference
import numpy as np
from PIL import Image
from reprod_log import ReprodLogger

from preprocess_ops import ResizeImage, CenterCropImage, NormalizeImage, ToCHW, Compose


class InferenceEngine(object):
    """InferenceEngine

    Inference engine class which contains preprocess, run, postprocess.
    """

    def __init__(self, args):
        """
        Args:
            args: Parameters generated using argparse.
        Returns: None
        """
        super().__init__()

    def load_predictor(self, model_file_path, params_file_path):
        """load_predictor

        Initialize the inference engine.

        Args:
            model_file_path: inference model path (*.pdmodel)
            params_file_path: inference parameter path (*.pdiparams)
        Returns: None
        """
        pass

    def preprocess(self, x):
        """preprocess

        Preprocess the input.

        Args:
            x: Raw input; it can be an image path, a numpy array and so on.
        Returns: Input data after preprocessing.
        """
        pass

    def postprocess(self, x):
        """postprocess

        Postprocess the inference engine output.

        Args:
            x: Inference engine output.
        Returns: Output data after postprocessing.
        """
        pass

    def run(self, x):
        """run

        Inference process using the inference engine.

        Args:
            x: Input data after preprocessing.
        Returns: Inference engine output.
        """
        pass

@@ -46,6 +99,17 @@ def get_args(add_help=True):

def infer_main(args):
    """infer_main

    Main inference function.

    Args:
        args: Parameters generated using argparse.
    Returns:
        class_id: Class index of the input.
        prob: Probability of the input.
    """
    # init inference engine
    inference_engine = InferenceEngine(args)

@@ -86,4 +150,4 @@ def infer_main(args):

if __name__ == "__main__":
    args = get_args()
    infer_main(args)
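Since the template only stubs out the engine, here is a hedged sketch of how `load_predictor` and `run` could be implemented with the `paddle.inference` API that the file already imports. The model/parameter file names and the single-input, single-output layout are assumptions.

```python
import numpy as np
from paddle import inference


def load_predictor(model_file_path, params_file_path):
    """Create a predictor plus its input/output handles."""
    config = inference.Config(model_file_path, params_file_path)
    config.disable_gpu()                 # CPU inference for this sketch
    predictor = inference.create_predictor(config)
    input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
    output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
    return predictor, input_handle, output_handle


def run(predictor, input_handle, output_handle, x):
    """Feed a preprocessed NCHW batch and return the raw model output."""
    input_handle.copy_from_cpu(x.astype("float32"))
    predictor.run()
    return output_handle.copy_to_cpu()


if __name__ == "__main__":
    predictor, in_h, out_h = load_predictor("inference.pdmodel",
                                            "inference.pdiparams")
    dummy = np.random.rand(1, 3, 224, 224).astype("float32")
    probs = run(predictor, in_h, out_h, dummy)
    print("class_id:", int(np.argmax(probs)), "prob:", float(np.max(probs)))
```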
tutorials/mobilenetv3_prod/Step6/README.md  (+11, -6)

@@ -29,6 +29,13 @@

Thanks to [vision](https://github.com/pytorch/vision) for making the MobileNetV3 paper reproduction more efficient.

Note: to simplify the workflow, training alignment is only done for the `standard ImageNet training procedure`, specifically:

* Training runs for 120 epochs in total with a total batch size of 256*8=2048, a learning rate of 0.8, and a Piecewise Decay schedule (the learning rate drops by 10x every 30 epochs).
* Training preprocessing: RandomResizedCrop(size=224) + RandomFlip(p=0.5) + Normalize
* Evaluation preprocessing: Resize(256) + CenterCrop(224) + Normalize

The reference metrics for `mobilenet_v3_small` here were also obtained by retraining.

## 2. Dataset and reproduced accuracy

The dataset is ImageNet: the training set contains 1,281,167 images and the validation set contains 50,000 images.

@@ -38,9 +45,7 @@

| Model | top1/5 acc (reference) | top1/5 acc (reproduced) | Download |
|:---------:|:------:|:----------:|:----------:|
- | Mo | 0.677/0.874 | 0.677/0.874 | [pretrained model](https://paddle-model-ecology.bj.bcebos.com/model/mobilenetv3_reprod/mobilenet_v3_small_paddle_pretrained.pdparams) \| [inference model](https://paddle-model-ecology.bj.bcebos.com/model/mobilenetv3_reprod/mobilenet_v3_small_paddle_infer.tar) \| [log (coming soon)]() |
- * Note: the pretrained model currently provided was converted from the weights released with the reference code; the full training results and logs are coming soon!
+ | Mo | -/- | 0.601/0.826 | [pretrained model](https://paddle-model-ecology.bj.bcebos.com/model/mobilenetv3_reprod/mobilenet_v3_small_pretrained.pdparams) \| [inference model (coming soon!)]() \| [log](https://paddle-model-ecology.bj.bcebos.com/model/mobilenetv3_reprod/train_mobilenet_v3_small.log) |

## 3. Prepare the environment and data

@@ -95,7 +100,7 @@ tar -xf test_images/lite_data.tar

```bash
export CUDA_VISIBLE_DEVICES=0
- python3.7 train.py --data-path=./ILSVRC2012 --lr=0.00125 --batch-size=32
+ python3.7 train.py --data-path=./ILSVRC2012 --lr=0.1 --batch-size=256
```

Part of the training log is shown below.

@@ -109,7 +114,7 @@ python3.7 train.py --data-path=./ILSVRC2012 --lr=0.00125 --batch-size=32

```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
- python3.7 -m paddle.distributed.launch --gpus="0,1,2,3" train.py --data-path="./ILSVRC2012" --lr=0.01 --batch-size=64
+ python3.7 -m paddle.distributed.launch --gpus="0,1,2,3" train.py --data-path="./ILSVRC2012" --lr=0.4 --batch-size=256
```

For more configuration options, see the `get_args_parser` function in [train.py](./train.py).

@@ -146,7 +151,7 @@ python tools/predict.py --pretrained=./mobilenet_v3_small_paddle_pretrained.pdpa

<img src="./images/demo.jpg" width="300">
</div>

- The final output is `class_id: 8, prob: 0.9503437280654907`, meaning the predicted class ID is `8` with a confidence of `0.950`.
+ The final output is `class_id: 8, prob: 0.9091238975524902`, meaning the predicted class ID is `8` with a confidence of `0.909`.

* Predict on CPU
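The learning rates in these commands are consistent with linearly scaling the reference setting above (total batch size 256*8=2048 at lr 0.8). Below is a small sketch of that scaling rule; it is an inference from the numbers shown here (and assumes `--batch-size` is per device in the distributed command), not something the README states explicitly.

```python
def scaled_lr(total_batch_size, base_lr=0.8, base_batch_size=2048):
    """Linear learning-rate scaling implied by the reference setting."""
    return base_lr * total_batch_size / base_batch_size


print(scaled_lr(256))      # 1 GPU  x batch 256 -> 0.1  (single-GPU command)
print(scaled_lr(4 * 256))  # 4 GPUs x batch 256 -> 0.4  (distributed command)
print(scaled_lr(8 * 256))  # 8 GPUs x batch 256 -> 0.8  (train_dist.sh below)
```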
tutorials/mobilenetv3_prod/Step6/paddlevision/transforms/functional.py  (+16, -0)

@@ -397,3 +397,19 @@ def resized_crop(
    img = crop(img, top, left, height, width)
    img = resize(img, size, interpolation)
    return img


def hflip(img):
    """Horizontally flip the given image.

    Args:
        img (PIL Image or Tensor): Image to be flipped. If img
            is a Tensor, it is expected to be in [..., H, W] format,
            where ... means it can have an arbitrary number of leading
            dimensions.

    Returns:
        PIL Image or Tensor: Horizontally flipped image.
    """
    if not isinstance(img, paddle.Tensor):
        return F_pil.hflip(img)

    return F_t.hflip(img)
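A quick usage sketch for the new `hflip`; the import path assumes the snippet runs from the Step6 directory, and the sample image path is illustrative.

```python
import numpy as np
import paddle
from PIL import Image

from paddlevision.transforms import functional as F

pil_img = Image.open("./images/demo.jpg")    # any local test image
flipped_pil = F.hflip(pil_img)               # PIL input -> F_pil.hflip

tensor_img = paddle.rand([3, 224, 224])
flipped_tensor = F.hflip(tensor_img)         # Tensor input -> F_t.hflip
# The width (last) axis is reversed:
assert np.allclose(flipped_tensor.numpy(), tensor_img.numpy()[..., ::-1])
```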
tutorials/mobilenetv3_prod/Step6/paddlevision/transforms/functional_tensor.py  (+4, -0)

@@ -268,3 +268,7 @@ def resize(img: Tensor,
        out_dtype=out_dtype)

    return img


def hflip(img):
    return img.flip(-1)
tutorials/mobilenetv3_prod/Step6/paddlevision/transforms/transforms.py  (+27, -1)

import math
import numbers
import random
import warnings
from collections.abc import Sequence
from typing import Tuple, List

@@ -17,7 +18,7 @@ from .functional import InterpolationMode, _interpolation_modes_from_int

__all__ = [
    "Compose", "ToTensor", "Normalize", "Resize", "CenterCrop",
    "RandomResizedCrop", "RandomHorizontalFlip"
]

@@ -370,3 +371,28 @@ def _setup_size(size, error_msg):
        raise ValueError(error_msg)

    return size


class RandomHorizontalFlip(paddle.nn.Layer):
    """Horizontally flip the given image randomly with a given probability.
    If the image is a paddle Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        if random.random() < self.p:
            return F.hflip(img)
        return img
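A usage sketch mirroring how presets.py (below) wires the new transform into the training pipeline. The top-level import path and the ImageNet mean/std values are assumptions here; the transform names all come from `__all__` above.

```python
from PIL import Image

from paddlevision import transforms

train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),   # flip with probability 0.5
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

img = Image.open("./images/demo.jpg")   # any local test image
x = train_transforms(img)               # CHW tensor ready for the model
```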
tutorials/mobilenetv3_prod/Step6/presets.py  (+2, -2)

@@ -10,8 +10,8 @@ class ClassificationPresetTrain:
                 auto_augment_policy=None,
                 random_erase_prob=0.0):
        trans = [transforms.RandomResizedCrop(crop_size)]
-       # if hflip_prob > 0:
-       #     trans.append(transforms.RandomHorizontalFlip(hflip_prob))
+       if hflip_prob > 0:
+           trans.append(transforms.RandomHorizontalFlip(hflip_prob))
        #if auto_augment_policy is not None:
        #    aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
        #    trans.append(autoaugment.AutoAugment(policy=aa_policy))
tutorials/mobilenetv3_prod/Step6/shell/train.sh  (new file, +10, -0)

export CUDA_VISIBLE_DEVICES=0
python3.7 train.py \
    --data-path /paddle/data/ILSVRC2012/ \
    --model mobilenet_v3_small \
    --lr 0.1 \
    --batch-size=256 \
    --output-dir "./output/" \
    --epochs 120 \
    --workers=6
tutorials/mobilenetv3_prod/Step6/shell/train_dist.sh  (+1, -1)

@@ -5,7 +5,7 @@ python3.7 -m paddle.distributed.launch \
    train.py \
    --data-path /paddle/data/ILSVRC2012/ \
    --model mobilenet_v3_small \
-   --lr 0.4 \
+   --lr 0.8 \
    --batch-size=256 \
    --output-dir "./output/" \
    --epochs 120 \
tutorials/mobilenetv3_prod/Step6/train.py  (+7, -2)

@@ -221,7 +221,6 @@ def main(args):
        lr_scheduler.step()
        if paddle.distributed.get_rank() == 0:
            top1 = evaluate(model, criterion, data_loader_test, device=device)
-           best_top1 = max(best_top1, top1)
            if args.output_dir:
                paddle.save(model.state_dict(),
                            os.path.join(args.output_dir,

@@ -233,6 +232,12 @@ def main(args):
                            os.path.join(args.output_dir, 'latest.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(args.output_dir, 'latest.pdopt'))
+               if top1 > best_top1:
+                   best_top1 = top1
+                   paddle.save(model.state_dict(),
+                               os.path.join(args.output_dir, 'best.pdparams'))
+                   paddle.save(optimizer.state_dict(),
+                               os.path.join(args.output_dir, 'best.pdopt'))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))

@@ -286,7 +291,7 @@ def get_args_parser(add_help=True):
        type=float,
        help='decrease lr by a factor of lr-gamma')
    parser.add_argument(
-       '--print-freq', default=1, type=int, help='print frequency')
+       '--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument(
        '--output-dir', default='.', help='path where to save')
    parser.add_argument(
        '--resume', default='', help='resume from checkpoint')
    parser.add_argument(