Commit 2c94ac53 authored by syyxsxx

Merge remote-tracking branch 'paddle/develop' into develop

......@@ -104,8 +104,8 @@ pip install paddlex -i https://mirror.baidu.com/pypi/simple
## Communication and Feedback
- Project website: https://www.paddlepaddle.org.cn/paddle/paddlex
- PaddleX user QQ group: 1045148026 (scan the QR code below with mobile QQ to join)
![](./docs/gui/images/QR.jpg)
- PaddleX user QQ group: 957286141 (scan the QR code below with mobile QQ to join)
![](./docs/gui/images/QR2.jpg)
......
......@@ -23,6 +23,7 @@ import org.opencv.core.Scalar;
import org.opencv.core.Size;
import org.opencv.imgproc.Imgproc;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
......@@ -101,6 +102,15 @@ public class Transforms {
if (info.containsKey("coarsest_stride")) {
padding.coarsest_stride = (int) info.get("coarsest_stride");
}
if (info.containsKey("im_padding_value")) {
List<Double> im_padding_value = (List<Double>) info.get("im_padding_value");
if (im_padding_value.size() != 3) {
Log.e(TAG, "length of im_padding_value in padding must be 3.");
}
for (int k = 0; k < im_padding_value.size(); k++) {
padding.padding_value[k] = im_padding_value.get(k);
}
}
if (info.containsKey("target_size")) {
if (info.get("target_size") instanceof Integer) {
padding.width = (int) info.get("target_size");
......@@ -124,7 +134,7 @@ public class Transforms {
if(transformsMode.equalsIgnoreCase("RGB")){
Imgproc.cvtColor(inputMat, inputMat, Imgproc.COLOR_BGR2RGB);
}else if(!transformsMode.equalsIgnoreCase("BGR")){
Log.e(TAG, "transformsMode only support RGB or BGR");
Log.e(TAG, "transformsMode only support RGB or BGR.");
}
inputMat.convertTo(inputMat, CvType.CV_32FC(3));
......@@ -136,16 +146,15 @@ public class Transforms {
int h = inputMat.height();
int c = inputMat.channels();
imageBlob.setImageData(new float[w * h * c]);
int[] channelStride = new int[]{w * h, w * h * 2};
for (int y = 0; y < h; y++) {
for (int x = 0; x < w; x++) {
double[] color = inputMat.get(y, x);
imageBlob.getImageData()[y * w + x] = (float) (color[0]);
imageBlob.getImageData()[y * w + x + channelStride[0]] = (float) (color[1]);
imageBlob.getImageData()[y * w + x + channelStride[1]] = (float) (color[2]);
}
Mat singleChannelMat = new Mat(h, w, CvType.CV_32FC(1));
float[] singleChannelImageData = new float[w * h];
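// convert the HWC Mat to CHW layout: channel i fills imageData[i*w*h .. (i+1)*w*h)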
for (int i = 0; i < c; i++) {
Core.extractChannel(inputMat, singleChannelMat, i);
singleChannelMat.get(0, 0, singleChannelImageData);
System.arraycopy(singleChannelImageData, 0, imageBlob.getImageData(), i * w * h, w * h);
}
return imageBlob;
}
......@@ -248,6 +257,7 @@ public class Transforms {
private double width;
private double height;
private double coarsest_stride;
private double[] padding_value = {0.0, 0.0, 0.0};
public Mat run(Mat inputMat, ImageBlob imageBlob) {
int origin_w = inputMat.width();
......@@ -264,7 +274,7 @@ public class Transforms {
}
imageBlob.setNewImageSize(inputMat.height(),2);
imageBlob.setNewImageSize(inputMat.width(),3);
Core.copyMakeBorder(inputMat, inputMat, 0, (int)padding_h, 0, (int)padding_w, Core.BORDER_CONSTANT, new Scalar(0));
Core.copyMakeBorder(inputMat, inputMat, 0, (int)padding_h, 0, (int)padding_w, Core.BORDER_CONSTANT, new Scalar(padding_value));
return inputMat;
}
}
......
......@@ -31,8 +31,11 @@ import org.opencv.core.Scalar;
import org.opencv.core.Size;
import org.opencv.imgproc.Imgproc;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
......@@ -120,13 +123,11 @@ public class Visualize {
int new_w = (int)imageBlob.getNewImageSize()[3];
Mat mask = new Mat(new_h, new_w, CvType.CV_32FC(1));
float[] scoreData = new float[new_h*new_w];
for (int h = 0; h < new_h; h++) {
for (int w = 0; w < new_w; w++){
scoreData[new_h * h + w] = (1-result.getMask().getScoreData()[cutoutClass + h * new_h + w]) * 255;
}
}
System.arraycopy(result.getMask().getScoreData(), cutoutClass * new_h * new_w, scoreData, 0, new_h * new_w);
mask.put(0,0, scoreData);
Core.multiply(mask, new Scalar(255), mask);
mask.convertTo(mask,CvType.CV_8UC(1));
ListIterator<Map.Entry<String, int[]>> reverseReshapeInfo = new ArrayList<Map.Entry<String, int[]>>(imageBlob.getReshapeInfo().entrySet()).listIterator(imageBlob.getReshapeInfo().size());
while (reverseReshapeInfo.hasPrevious()) {
Map.Entry<String, int[]> entry = reverseReshapeInfo.previous();
......
# PaddleX Documentation
All PaddleX documentation lives under this directory. The docs are organized in the Read the Docs style; you can browse the [online documentation](https://paddlex.readthedocs.io/zh_CN/latest/index.html) directly.
All PaddleX documentation lives under this directory. The docs are organized in the Read the Docs style; you can browse the [online documentation](https://paddlex.readthedocs.io/zh_CN/develop/index.html) directly.
## Building the Documentation
Build the documentation in this directory with the following steps
......
......@@ -7,7 +7,7 @@
A unified predictor for image classification, object detection, instance segmentation, and semantic segmentation, enabling high-performance prediction.
```
paddlex.deploy.Predictor(model_dir, use_gpu=False, gpu_id=0, use_mkl=False, use_trt=False, use_glog=False, memory_optimize=True)
paddlex.deploy.Predictor(model_dir, use_gpu=False, gpu_id=0, use_mkl=False, mkl_thread_num=4, use_trt=False, use_glog=False, memory_optimize=True)
```
**Parameters**
......@@ -16,6 +16,7 @@ paddlex.deploy.Predictor(model_dir, use_gpu=False, gpu_id=0, use_mkl=False, use_
> * **use_gpu** (bool): Whether to use the GPU for prediction.
> * **gpu_id** (int): ID of the GPU to use.
> * **use_mkl** (bool): Whether to use the MKL-DNN acceleration library.
> * **mkl_thread_num** (int): Number of threads when using MKL-DNN. Defaults to 4.
> * **use_trt** (bool): Whether to use the TensorRT prediction engine.
> * **use_glog** (bool): Whether to print intermediate logs.
> * **memory_optimize** (bool): Whether to optimize memory usage.
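A minimal usage sketch based on the documented signature above (the `./inference_model` path is a placeholder for an exported inference model):

```
import paddlex as pdx

# CPU prediction with MKL-DNN enabled; mkl_thread_num controls the
# number of MKL-DNN threads (defaults to 4).
predictor = pdx.deploy.Predictor(
    model_dir='./inference_model',
    use_gpu=False,
    use_mkl=True,
    mkl_thread_num=8)
```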
......
......@@ -3,7 +3,7 @@
## paddlex.seg.DeepLabv3p
```python
paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride=16, aspp_with_sep_conv=True, decoder_use_sep_conv=True, encoder_with_aspp=True, enable_decoder=True, use_bce_loss=False, use_dice_loss=False, class_weight=None, ignore_index=255)
paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride=16, aspp_with_sep_conv=True, decoder_use_sep_conv=True, encoder_with_aspp=True, enable_decoder=True, use_bce_loss=False, use_dice_loss=False, class_weight=None, ignore_index=255, pooling_crop_size=None)
```
......@@ -12,7 +12,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
> **Parameters**
> > - **num_classes** (int): Number of classes.
> > - **backbone** (str): Backbone network of DeepLabv3+, used to compute the feature maps. One of ['Xception65', 'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0']. Defaults to 'MobileNetV2_x1.0'.
> > - **backbone** (str): Backbone network of DeepLabv3+, used to compute the feature maps. One of ['Xception65', 'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0', 'MobileNetV3_large_x1_0_ssld']. Defaults to 'MobileNetV2_x1.0'.
> > - **output_stride** (int): Downsampling factor of the backbone output feature map relative to the input, typically 8 or 16. Defaults to 16.
> > - **aspp_with_sep_conv** (bool): Whether the ASPP module uses separable convolutions. Defaults to True.
> > - **decoder_use_sep_conv** (bool): Whether the decoder module uses separable convolutions. Defaults to True.
......@@ -22,6 +22,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
> > - **use_dice_loss** (bool): Whether to use dice loss as the network loss. Only applicable to binary segmentation; can be combined with bce loss. When both `use_bce_loss` and `use_dice_loss` are False, cross-entropy loss is used. Defaults to False.
> > - **class_weight** (list/str): Per-class weights for the cross-entropy loss. When `class_weight` is a list, its length should equal `num_classes`. When `class_weight` is a str, weight.lower() should be 'dynamic', in which case the weights are computed each epoch from the per-class pixel ratios, each class weighted as: class ratio * num_classes. When `class_weight` takes the default None, every class has weight 1, i.e., the ordinary cross-entropy loss.
> > - **ignore_index** (int): Value ignored in the label; pixels whose label equals `ignore_index` do not contribute to the loss. Defaults to 255.
> > - **pooling_crop_size** (list): When the backbone is `MobileNetV3_large_x1_0_ssld`, this must be set to the model input size used during training, in [W, H] format. For example, if the model input size is [512, 512], `pooling_crop_size` should be set to [512, 512]. It is used when computing the image-level average in the encoder module; if None, a plain mean is taken, and if set to the model input size, the average is obtained with the `avg_pool` operator. Defaults to None.
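A minimal construction sketch for the MobileNetV3 variant described above (dataset and training setup omitted):

```python
import paddlex as pdx

# pooling_crop_size must match the [W, H] input size used in training
# when the MobileNetV3_large_x1_0_ssld backbone is selected.
model = pdx.seg.DeepLabv3p(
    num_classes=2,
    backbone='MobileNetV3_large_x1_0_ssld',
    pooling_crop_size=[512, 512])
```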
### train
......
......@@ -81,6 +81,7 @@
| Model | Model Size | Prediction Time (ms) | mIoU (%) |
|:-------|:-----------|:-------------|:----------|
| [DeepLabv3_MobileNetV3_large_x1_0_ssld](https://paddleseg.bj.bcebos.com/models/deeplabv3p_mobilenetv3_large_cityscapes.tar.gz) | 9.3MB | - | 73.28 |
| [DeepLabv3_MobileNetv2_x1.0](https://paddleseg.bj.bcebos.com/models/mobilenet_cityscapes.tgz) | 14.7MB | - | 69.8 |
| [DeepLabv3_Xception65](https://paddleseg.bj.bcebos.com/models/xception65_bn_cityscapes.tgz) | 329.3MB | - | 79.3 |
| [HRNet_W18](https://paddleseg.bj.bcebos.com/models/hrnet_w18_bn_cityscapes.tgz) | 77.3MB | - | 79.36 |
......
......@@ -27,7 +27,7 @@ import paddlex as pdx
predictor = pdx.deploy.Predictor('./inference_model')
image_list = ['xiaoduxiong_test_image/JPEGImages/WeChatIMG110.jpeg',
'xiaoduxiong_test_image/JPEGImages/WeChatIMG111.jpeg']
result = predictor.predict(image_list=image_list)
result = predictor.batch_predict(image_list=image_list)
```
* Video stream prediction
......
......@@ -80,6 +80,7 @@ PaddleX currently provides the MaskRCNN instance segmentation model, supporting 5 different backbones
| Model | Characteristics | Storage Size | GPU Prediction Speed | CPU (x86) Prediction Speed (ms) | Snapdragon 855 (ARM) Prediction Speed (ms) | mIoU |
| :---- | :------- | :---------- | :---------- | :----- | :----- |:--- |
| DeepLabv3p-MobileNetV2_x1.0 | Lightweight model for mobile scenarios | - | - | - | - | 69.8% |
| DeepLabv3-MobileNetV3_large_x1_0_ssld | Lightweight model for mobile scenarios | - | - | - | - | 73.28% |
| HRNet_W18_Small_v1 | Lightweight and fast, for mobile scenarios | - | - | - | - | - |
| FastSCNN | Lightweight and fast, for mobile or server scenarios that require high-speed prediction | - | - | - | - | 69.64 |
| HRNet_W18 | High-accuracy model for server-side scenarios that are sensitive to image resolution and demand finer detail prediction | - | - | - | - | 79.36 |
......
......@@ -12,6 +12,7 @@ PaddleX currently provides four semantic segmentation architectures: DeepLabv3p, UNet, HRNet and FastSCNN
| :---------------- | :------- | :------- | :--------- | :--------- | :----- |
| [DeepLabv3p-MobileNetV2-x0.25](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_mobilenetv2_x0.25.py) | - | 2.9MB | - | - | Small model, fast prediction, suitable for low-performance or mobile devices |
| [DeepLabv3p-MobileNetV2-x1.0](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_mobilenetv2.py) | 69.8% | 11MB | - | - | Small model, fast prediction, suitable for low-performance or mobile devices |
| [DeepLabv3_MobileNetV3_large_x1_0_ssld](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_mobilenetv3_large_ssld.py) | 73.28% | 9.3MB | - | - | Small model, fast prediction, relatively high accuracy, suitable for low-performance or mobile devices |
| [DeepLabv3p-Xception65](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_xception65.py) | 79.3% | 158MB | - | - | Large model, high accuracy, suitable for the server side |
| [UNet](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/unet.py) | - | 52MB | - | - | Relatively large model, high accuracy, suitable for the server side |
| [HRNet](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/hrnet.py) | 79.4% | 37MB | - | - | Relatively small model, high accuracy, suitable for server-side deployment |
......
......@@ -51,6 +51,12 @@ def arg_parser():
action="store_true",
default=False,
help="export onnx model for deployment")
parser.add_argument(
"--onnx_opset",
"-oo",
type=int,
default=10,
help="when use paddle2onnx, set onnx opset version to export")
parser.add_argument(
"--data_conversion",
"-dc",
......@@ -162,7 +168,7 @@ def main():
logging.error(
"paddlex --export_inference --model_dir model_path --save_dir infer_model"
)
pdx.convertor.export_onnx_model(model, args.save_dir)
pdx.convertor.export_onnx_model(model, args.save_dir, args.onnx_opset)
if args.data_conversion:
assert args.source is not None, "--source should be defined while converting dataset"
......
......@@ -29,10 +29,12 @@ def export_onnx(model_dir, save_dir, fixed_input_shape):
export_onnx_model(model, save_dir)
def export_onnx_model(model, save_dir):
if model.model_type == "detector" or model.__class__.__name__ == "FastSCNN":
def export_onnx_model(model, save_dir, opset_version=10):
if model.__class__.__name__ == "FastSCNN" or (
model.model_type == "detector" and
model.__class__.__name__ != "YOLOv3"):
logging.error(
"Only image classifier models and semantic segmentation models(except FastSCNN) are supported to export to ONNX"
"Only image classifier models, detection models(YOLOv3) and semantic segmentation models(except FastSCNN) are supported to export to ONNX"
)
try:
import x2paddle
......@@ -41,6 +43,406 @@ def export_onnx_model(model, save_dir):
except:
logging.error(
"You need to install x2paddle first, pip install x2paddle>=0.7.4")
from x2paddle.op_mapper.paddle_op_mapper import PaddleOpMapper
if opset_version == 10 and model.__class__.__name__ == "YOLOv3":
logging.warning(
"Export for openVINO by default, the output of multiclass_nms exported to onnx will contains background. If you need onnx completely consistent with paddle, please use X2Paddle to export"
)
x2paddle.op_mapper.paddle2onnx.opset10.paddle_custom_layer.multiclass_nms.multiclass_nms = multiclass_nms_for_openvino
from x2paddle.op_mapper.paddle2onnx.paddle_op_mapper import PaddleOpMapper
mapper = PaddleOpMapper()
mapper.convert(model.test_prog, save_dir)
mapper.convert(
model.test_prog,
save_dir,
scope=model.scope,
opset_version=opset_version)
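A hedged usage sketch of the export path above (the model and save paths are placeholders):

```python
import paddlex as pdx

# Load a trained YOLOv3 model and export it to ONNX with opset 10, which
# patches multiclass_nms for OpenVINO as described in the warning above.
model = pdx.load_model('./output/yolov3/best_model')
pdx.convertor.export_onnx_model(model, './onnx_model', opset_version=10)
```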
def multiclass_nms_for_openvino(op, block):
"""
Convert the paddle multiclass_nms to onnx op.
This op is get the select boxes from origin boxes.
This op is for OpenVINO, which donn't support dynamic shape).
"""
import math
import sys
import numpy as np
import paddle.fluid.core as core
import paddle.fluid as fluid
import onnx
import warnings
from onnx import helper, onnx_pb
inputs = dict()
outputs = dict()
attrs = dict()
for name in op.input_names:
inputs[name] = op.input(name)
for name in op.output_names:
outputs[name] = op.output(name)
for name in op.attr_names:
attrs[name] = op.attr(name)
result_name = outputs['Out'][0]
background = attrs['background_label']
normalized = attrs['normalized']
if not normalized:
warnings.warn(
'The parameter normalized of the multiclass_nms OP of Paddle is False, which differs from ONNX. \
Please set normalized=True in multiclass_nms of Paddle'
)
# convert the paddle attributes to onnx tensors
name_score_threshold = [outputs['Out'][0] + "@score_threshold"]
name_iou_threshold = [outputs['Out'][0] + "@iou_threshold"]
name_keep_top_k = [outputs['Out'][0] + '@keep_top_k']
name_keep_top_k_2D = [outputs['Out'][0] + '@keep_top_k_1D']
node_score_threshold = onnx.helper.make_node(
'Constant',
inputs=[],
outputs=name_score_threshold,
value=onnx.helper.make_tensor(
name=name_score_threshold[0] + "@const",
data_type=onnx.TensorProto.FLOAT,
dims=(),
vals=[float(attrs['score_threshold'])]))
node_iou_threshold = onnx.helper.make_node(
'Constant',
inputs=[],
outputs=name_iou_threshold,
value=onnx.helper.make_tensor(
name=name_iou_threshold[0] + "@const",
data_type=onnx.TensorProto.FLOAT,
dims=(),
vals=[float(attrs['nms_threshold'])]))
node_keep_top_k = onnx.helper.make_node(
'Constant',
inputs=[],
outputs=name_keep_top_k,
value=onnx.helper.make_tensor(
name=name_keep_top_k[0] + "@const",
data_type=onnx.TensorProto.INT64,
dims=(),
vals=[np.int64(attrs['keep_top_k'])]))
node_keep_top_k_2D = onnx.helper.make_node(
'Constant',
inputs=[],
outputs=name_keep_top_k_2D,
value=onnx.helper.make_tensor(
name=name_keep_top_k_2D[0] + "@const",
data_type=onnx.TensorProto.INT64,
dims=[1, 1],
vals=[np.int64(attrs['keep_top_k'])]))
# the paddle data format is x1,y1,x2,y2
kwargs = {'center_point_box': 0}
name_select_nms = [outputs['Out'][0] + "@select_index"]
node_select_nms= onnx.helper.make_node(
'NonMaxSuppression',
inputs=inputs['BBoxes'] + inputs['Scores'] + name_keep_top_k +\
name_iou_threshold + name_score_threshold,
outputs=name_select_nms)
# step 1 nodes select the nms class
node_list = [
node_score_threshold, node_iou_threshold, node_keep_top_k,
node_keep_top_k_2D, node_select_nms
]
# create some const value to use
name_const_value = [result_name+"@const_0",
result_name+"@const_1",\
result_name+"@const_2",\
result_name+"@const_-1"]
value_const_value = [0, 1, 2, -1]
for name, value in zip(name_const_value, value_const_value):
node = onnx.helper.make_node(
'Constant',
inputs=[],
outputs=[name],
value=onnx.helper.make_tensor(
name=name + "@const",
data_type=onnx.TensorProto.INT64,
dims=[1],
vals=[value]))
node_list.append(node)
# In this code block, we decode the raw score data, reshaping the N * C * M scores
# to a flat (N*C*M) vector, and at the same time decode the select indices to 1 * D,
# then gather the selected indices
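# e.g. with C classes and M boxes per class, a selected pair (class c, box m)
# maps to flat index c * M + m in the reshaped score tensor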
outputs_gather_1_ = [result_name + "@gather_1_"]
node_gather_1_ = onnx.helper.make_node(
'Gather',
inputs=name_select_nms + [result_name + "@const_1"],
outputs=outputs_gather_1_,
axis=1)
node_list.append(node_gather_1_)
outputs_gather_1 = [result_name + "@gather_1"]
node_gather_1 = onnx.helper.make_node(
'Unsqueeze',
inputs=outputs_gather_1_,
outputs=outputs_gather_1,
axes=[0])
node_list.append(node_gather_1)
outputs_gather_2_ = [result_name + "@gather_2_"]
node_gather_2_ = onnx.helper.make_node(
'Gather',
inputs=name_select_nms + [result_name + "@const_2"],
outputs=outputs_gather_2_,
axis=1)
node_list.append(node_gather_2_)
outputs_gather_2 = [result_name + "@gather_2"]
node_gather_2 = onnx.helper.make_node(
'Unsqueeze',
inputs=outputs_gather_2_,
outputs=outputs_gather_2,
axes=[0])
node_list.append(node_gather_2)
# reshape scores N * C * M to (N*C*M) * 1
outputs_reshape_scores_rank1 = [result_name + "@reshape_scores_rank1"]
node_reshape_scores_rank1 = onnx.helper.make_node(
"Reshape",
inputs=inputs['Scores'] + [result_name + "@const_-1"],
outputs=outputs_reshape_scores_rank1)
node_list.append(node_reshape_scores_rank1)
# get the shape of scores
outputs_shape_scores = [result_name + "@shape_scores"]
node_shape_scores = onnx.helper.make_node(
'Shape', inputs=inputs['Scores'], outputs=outputs_shape_scores)
node_list.append(node_shape_scores)
# gather the index: 2 shape of scores
outputs_gather_scores_dim1 = [result_name + "@gather_scores_dim1"]
node_gather_scores_dim1 = onnx.helper.make_node(
'Gather',
inputs=outputs_shape_scores + [result_name + "@const_2"],
outputs=outputs_gather_scores_dim1,
axis=0)
node_list.append(node_gather_scores_dim1)
# mul class * M
outputs_mul_classnum_boxnum = [result_name + "@mul_classnum_boxnum"]
node_mul_classnum_boxnum = onnx.helper.make_node(
'Mul',
inputs=outputs_gather_1 + outputs_gather_scores_dim1,
outputs=outputs_mul_classnum_boxnum)
node_list.append(node_mul_classnum_boxnum)
# add class * M * index
outputs_add_class_M_index = [result_name + "@add_class_M_index"]
node_add_class_M_index = onnx.helper.make_node(
'Add',
inputs=outputs_mul_classnum_boxnum + outputs_gather_2,
outputs=outputs_add_class_M_index)
node_list.append(node_add_class_M_index)
# Squeeze the indices to 1 dim
outputs_squeeze_select_index = [result_name + "@squeeze_select_index"]
node_squeeze_select_index = onnx.helper.make_node(
'Squeeze',
inputs=outputs_add_class_M_index,
outputs=outputs_squeeze_select_index,
axes=[0, 2])
node_list.append(node_squeeze_select_index)
# gather the data from flatten scores
outputs_gather_select_scores = [result_name + "@gather_select_scores"]
node_gather_select_scores = onnx.helper.make_node('Gather',
inputs=outputs_reshape_scores_rank1 + \
outputs_squeeze_select_index,
outputs=outputs_gather_select_scores,
axis=0)
node_list.append(node_gather_select_scores)
# get nums to input TopK
outputs_shape_select_num = [result_name + "@shape_select_num"]
node_shape_select_num = onnx.helper.make_node(
'Shape',
inputs=outputs_gather_select_scores,
outputs=outputs_shape_select_num)
node_list.append(node_shape_select_num)
outputs_gather_select_num = [result_name + "@gather_select_num"]
node_gather_select_num = onnx.helper.make_node(
'Gather',
inputs=outputs_shape_select_num + [result_name + "@const_0"],
outputs=outputs_gather_select_num,
axis=0)
node_list.append(node_gather_select_num)
outputs_unsqueeze_select_num = [result_name + "@unsqueeze_select_num"]
node_unsqueeze_select_num = onnx.helper.make_node(
'Unsqueeze',
inputs=outputs_gather_select_num,
outputs=outputs_unsqueeze_select_num,
axes=[0])
node_list.append(node_unsqueeze_select_num)
outputs_concat_topK_select_num = [result_name + "@conat_topK_select_num"]
node_concat_topK_select_num = onnx.helper.make_node(
'Concat',
inputs=outputs_unsqueeze_select_num + name_keep_top_k_2D,
outputs=outputs_concat_topK_select_num,
axis=0)
node_list.append(node_concat_topK_select_num)
outputs_cast_concat_topK_select_num = [
result_name + "@concat_topK_select_num"
]
node_outputs_cast_concat_topK_select_num = onnx.helper.make_node(
'Cast',
inputs=outputs_concat_topK_select_num,
outputs=outputs_cast_concat_topK_select_num,
to=6)
node_list.append(node_outputs_cast_concat_topK_select_num)
# get min(topK, num_select)
outputs_compare_topk_num_select = [
result_name + "@compare_topk_num_select"
]
node_compare_topk_num_select = onnx.helper.make_node(
'ReduceMin',
inputs=outputs_cast_concat_topK_select_num,
outputs=outputs_compare_topk_num_select,
keepdims=0)
node_list.append(node_compare_topk_num_select)
# unsqueeze the indices to 1D tensor
outputs_unsqueeze_topk_select_indices = [
result_name + "@unsqueeze_topk_select_indices"
]
node_unsqueeze_topk_select_indices = onnx.helper.make_node(
'Unsqueeze',
inputs=outputs_compare_topk_num_select,
outputs=outputs_unsqueeze_topk_select_indices,
axes=[0])
node_list.append(node_unsqueeze_topk_select_indices)
# cast the indices to INT64
outputs_cast_topk_indices = [result_name + "@cast_topk_indices"]
node_cast_topk_indices = onnx.helper.make_node(
'Cast',
inputs=outputs_unsqueeze_topk_select_indices,
outputs=outputs_cast_topk_indices,
to=7)
node_list.append(node_cast_topk_indices)
# select topk scores indices
outputs_topk_select_topk_indices = [result_name + "@topk_select_topk_values",\
result_name + "@topk_select_topk_indices"]
node_topk_select_topk_indices = onnx.helper.make_node(
'TopK',
inputs=outputs_gather_select_scores + outputs_cast_topk_indices,
outputs=outputs_topk_select_topk_indices)
node_list.append(node_topk_select_topk_indices)
# gather topk label, scores, boxes
outputs_gather_topk_scores = [result_name + "@gather_topk_scores"]
node_gather_topk_scores = onnx.helper.make_node(
'Gather',
inputs=outputs_gather_select_scores +
[outputs_topk_select_topk_indices[1]],
outputs=outputs_gather_topk_scores,
axis=0)
node_list.append(node_gather_topk_scores)
outputs_gather_topk_class = [result_name + "@gather_topk_class"]
node_gather_topk_class = onnx.helper.make_node(
'Gather',
inputs=outputs_gather_1 + [outputs_topk_select_topk_indices[1]],
outputs=outputs_gather_topk_class,
axis=1)
node_list.append(node_gather_topk_class)
# gather the boxes need to gather the boxes id, then get boxes
outputs_gather_topk_boxes_id = [result_name + "@gather_topk_boxes_id"]
node_gather_topk_boxes_id = onnx.helper.make_node(
'Gather',
inputs=outputs_gather_2 + [outputs_topk_select_topk_indices[1]],
outputs=outputs_gather_topk_boxes_id,
axis=1)
node_list.append(node_gather_topk_boxes_id)
# squeeze the gather_topk_boxes_id to 1 dim
outputs_squeeze_topk_boxes_id = [result_name + "@squeeze_topk_boxes_id"]
node_squeeze_topk_boxes_id = onnx.helper.make_node(
'Squeeze',
inputs=outputs_gather_topk_boxes_id,
outputs=outputs_squeeze_topk_boxes_id,
axes=[0, 2])
node_list.append(node_squeeze_topk_boxes_id)
outputs_gather_select_boxes = [result_name + "@gather_select_boxes"]
node_gather_select_boxes = onnx.helper.make_node(
'Gather',
inputs=inputs['BBoxes'] + outputs_squeeze_topk_boxes_id,
outputs=outputs_gather_select_boxes,
axis=1)
node_list.append(node_gather_select_boxes)
# concat the final result
# before concat need to cast the class to float
outputs_cast_topk_class = [result_name + "@cast_topk_class"]
node_cast_topk_class = onnx.helper.make_node(
'Cast',
inputs=outputs_gather_topk_class,
outputs=outputs_cast_topk_class,
to=1)
node_list.append(node_cast_topk_class)
outputs_unsqueeze_topk_scores = [result_name + "@unsqueeze_topk_scores"]
node_unsqueeze_topk_scores = onnx.helper.make_node(
'Unsqueeze',
inputs=outputs_gather_topk_scores,
outputs=outputs_unsqueeze_topk_scores,
axes=[0, 2])
node_list.append(node_unsqueeze_topk_scores)
inputs_concat_final_results = outputs_cast_topk_class + outputs_unsqueeze_topk_scores +\
outputs_gather_select_boxes
outputs_sort_by_score_results = [result_name + "@concat_topk_scores"]
node_sort_by_score_results = onnx.helper.make_node(
'Concat',
inputs=inputs_concat_final_results,
outputs=outputs_sort_by_score_results,
axis=2)
node_list.append(node_sort_by_score_results)
# select topk classes indices
outputs_squeeze_cast_topk_class = [
result_name + "@squeeze_cast_topk_class"
]
node_squeeze_cast_topk_class = onnx.helper.make_node(
'Squeeze',
inputs=outputs_cast_topk_class,
outputs=outputs_squeeze_cast_topk_class,
axes=[0, 2])
node_list.append(node_squeeze_cast_topk_class)
outputs_neg_squeeze_cast_topk_class = [
result_name + "@neg_squeeze_cast_topk_class"
]
node_neg_squeeze_cast_topk_class = onnx.helper.make_node(
'Neg',
inputs=outputs_squeeze_cast_topk_class,
outputs=outputs_neg_squeeze_cast_topk_class)
node_list.append(node_neg_squeeze_cast_topk_class)
outputs_topk_select_classes_indices = [result_name + "@topk_select_topk_classes_scores",\
result_name + "@topk_select_topk_classes_indices"]
node_topk_select_topk_indices = onnx.helper.make_node(
'TopK',
inputs=outputs_neg_squeeze_cast_topk_class + outputs_cast_topk_indices,
outputs=outputs_topk_select_classes_indices)
node_list.append(node_topk_select_topk_indices)
outputs_concat_final_results = outputs['Out']
node_concat_final_results = onnx.helper.make_node(
'Gather',
inputs=outputs_sort_by_score_results +
[outputs_topk_select_classes_indices[1]],
outputs=outputs_concat_final_results,
axis=1)
node_list.append(node_concat_final_results)
return node_list
......@@ -37,7 +37,7 @@ class DeepLabv3p(BaseAPI):
num_classes (int): Number of classes.
backbone (str): Backbone network of DeepLabv3+, used to compute the feature maps. One of ['Xception65', 'Xception41',
'MobileNetV2_x0.25', 'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5',
'MobileNetV2_x2.0']. Defaults to 'MobileNetV2_x1.0'.
'MobileNetV2_x2.0', 'MobileNetV3_large_x1_0_ssld']. Defaults to 'MobileNetV2_x1.0'.
output_stride (int): Downsampling factor of the backbone output feature map relative to the input, typically 8 or 16. Defaults to 16.
aspp_with_sep_conv (bool): Whether the ASPP module uses separable convolutions. Defaults to True.
decoder_use_sep_conv (bool): Whether the decoder module uses separable convolutions. Defaults to True.
......@@ -51,10 +51,13 @@ class DeepLabv3p(BaseAPI):
the weights are computed each epoch from the per-class pixel ratios, each class weighted as: class ratio * num_classes. When class_weight takes the default None, every class has weight 1,
i.e., the ordinary cross-entropy loss.
ignore_index (int): Value ignored in the label; pixels whose label equals ignore_index do not contribute to the loss. Defaults to 255.
pooling_crop_size (list): When the backbone is MobileNetV3_large_x1_0_ssld, this must be set to the model input size used during training, in [W, H] format.
It is used when computing the image-level average in the encoder module; if None, a plain mean is taken, and if set to the model input size, the average is obtained with the 'pool' operator.
Defaults to None.
Raises:
ValueError: use_bce_loss or use_dice_loss is True and num_classes > 2.
ValueError: backbone is not one of ['Xception65', 'Xception41', 'MobileNetV2_x0.25',
'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0'].
'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0', 'MobileNetV3_large_x1_0_ssld'].
ValueError: class_weight is a list whose length does not equal num_classes, or
class_weight is a str but class_weight.lower() != 'dynamic'.
TypeError: class_weight is not None and is neither a list nor a str.
......@@ -71,7 +74,8 @@ class DeepLabv3p(BaseAPI):
use_bce_loss=False,
use_dice_loss=False,
class_weight=None,
ignore_index=255):
ignore_index=255,
pooling_crop_size=None):
self.init_params = locals()
super(DeepLabv3p, self).__init__('segmenter')
# dice_loss and bce_loss are only applicable to binary segmentation
......@@ -85,12 +89,12 @@ class DeepLabv3p(BaseAPI):
if backbone not in [
'Xception65', 'Xception41', 'MobileNetV2_x0.25',
'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5',
'MobileNetV2_x2.0'
'MobileNetV2_x2.0', 'MobileNetV3_large_x1_0_ssld'
]:
raise ValueError(
"backbone: {} is set wrong. it should be one of "
"('Xception65', 'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5',"
" 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0')".
" 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0', 'MobileNetV3_large_x1_0_ssld')".
format(backbone))
if class_weight is not None:
......@@ -121,6 +125,30 @@ class DeepLabv3p(BaseAPI):
self.labels = None
self.sync_bn = True
self.fixed_input_shape = None
self.pooling_stride = [1, 1]
self.pooling_crop_size = pooling_crop_size
self.aspp_with_se = False
self.se_use_qsigmoid = False
self.aspp_convs_filters = 256
self.aspp_with_concat_projection = True
self.add_image_level_feature = True
self.use_sum_merge = False
self.conv_filters = 256
self.output_is_logits = False
self.backbone_lr_mult_list = None
if 'MobileNetV3' in backbone:
self.output_stride = 32
self.pooling_stride = (4, 5)
self.aspp_with_se = True
self.se_use_qsigmoid = True
self.aspp_convs_filters = 128
self.aspp_with_concat_projection = False
self.add_image_level_feature = False
self.use_sum_merge = True
self.output_is_logits = True
if self.output_is_logits:
self.conv_filters = self.num_classes
self.backbone_lr_mult_list = [0.15, 0.35, 0.65, 0.85, 1]
def _get_backbone(self, backbone):
def mobilenetv2(backbone):
......@@ -167,10 +195,22 @@ class DeepLabv3p(BaseAPI):
end_points=end_points,
decode_points=decode_points)
def mobilenetv3(backbone):
scale = 1.0
lr_mult_list = self.backbone_lr_mult_list
return paddlex.cv.nets.MobileNetV3(
scale=scale,
model_name='large',
output_stride=self.output_stride,
lr_mult_list=lr_mult_list,
for_seg=True)
if 'Xception' in backbone:
return xception(backbone)
elif 'MobileNetV2' in backbone:
return mobilenetv2(backbone)
elif 'MobileNetV3' in backbone:
return mobilenetv3(backbone)
def build_net(self, mode='train'):
model = paddlex.cv.nets.segmentation.DeepLabv3p(
......@@ -186,7 +226,17 @@ class DeepLabv3p(BaseAPI):
use_dice_loss=self.use_dice_loss,
class_weight=self.class_weight,
ignore_index=self.ignore_index,
fixed_input_shape=self.fixed_input_shape)
fixed_input_shape=self.fixed_input_shape,
pooling_stride=self.pooling_stride,
pooling_crop_size=self.pooling_crop_size,
aspp_with_se=self.aspp_with_se,
se_use_qsigmoid=self.se_use_qsigmoid,
aspp_convs_filters=self.aspp_convs_filters,
aspp_with_concat_projection=self.aspp_with_concat_projection,
add_image_level_feature=self.add_image_level_feature,
use_sum_merge=self.use_sum_merge,
conv_filters=self.conv_filters,
output_is_logits=self.output_is_logits)
inputs = model.generate_inputs()
model_out = model.build_net(inputs)
outputs = OrderedDict()
......@@ -370,9 +420,6 @@ class DeepLabv3p(BaseAPI):
elif info[0] == 'padding':
w, h = info[1][1], info[1][0]
one_pred = one_pred[0:h, 0:w]
else:
raise Exception(
"Unexpected info '{}' in im_info".format(info[0]))
one_pred = one_pred.astype('int64')
one_pred = one_pred[np.newaxis, :, :, np.newaxis]
one_label = one_label[np.newaxis, np.newaxis, :, :]
......@@ -430,9 +477,6 @@ class DeepLabv3p(BaseAPI):
w, h = info[1][1], info[1][0]
pred = pred[0:h, 0:w]
logit = logit[0:h, 0:w, :]
else:
raise Exception("Unexpected info '{}' in im_info".format(
info[0]))
pred_list.append(pred)
logit_list.append(logit)
......
......@@ -243,6 +243,32 @@ def get_prune_params(model):
for i in params_not_prune:
if i in prune_names:
prune_names.remove(i)
elif model_type.startswith('HRNet'):
for param in program.global_block().all_parameters():
if 'weight' not in param.name:
continue
prune_names.append(param.name)
params_not_prune = [
'conv-1_weights'
]
for i in params_not_prune:
if i in prune_names:
prune_names.remove(i)
elif model_type.startswith('FastSCNN'):
for param in program.global_block().all_parameters():
if 'weight' not in param.name:
continue
if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name:
continue
prune_names.append(param.name)
params_not_prune = [
'classifier/weights'
]
for i in params_not_prune:
if i in prune_names:
prune_names.remove(i)
elif model_type.startswith('DeepLabv3p'):
for param in program.global_block().all_parameters():
......
......@@ -122,6 +122,8 @@ coco_pretrain = {
}
cityscapes_pretrain = {
'DeepLabv3p_MobileNetV3_large_x1_0_ssld_CITYSCAPES':
'https://paddleseg.bj.bcebos.com/models/deeplabv3p_mobilenetv3_large_cityscapes.tar.gz',
'DeepLabv3p_MobileNetV2_x1.0_CITYSCAPES':
'https://paddleseg.bj.bcebos.com/models/mobilenet_cityscapes.tgz',
'DeepLabv3p_Xception65_CITYSCAPES':
......@@ -144,7 +146,8 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
if flag == 'COCO':
if class_name == 'DeepLabv3p' and backbone in [
'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5',
'MobileNetV2_x1.5', 'MobileNetV2_x2.0'
'MobileNetV2_x1.5', 'MobileNetV2_x2.0',
'MobileNetV3_large_x1_0_ssld'
]:
model_name = '{}_{}'.format(class_name, backbone)
logging.warning(warning_info.format(model_name, flag, 'IMAGENET'))
......
......@@ -311,7 +311,7 @@ class YOLOv3:
def _upsample(self, input, scale=2, name=None):
out = fluid.layers.resize_nearest(
input=input, scale=float(scale), name=name)
input=input, scale=float(scale), name=name, align_corners=False)
return out
def _detection_block(self,
......
......@@ -235,10 +235,13 @@ class HRNet(object):
name=name + '_layer_' + str(i + 1) + '_' + str(j + 1))
if self.feature_maps == "stage4":
y = fluid.layers.resize_bilinear(
input=y, out_shape=[height, width])
input=y,
out_shape=[height, width],
align_corners=False,
align_mode=1)
else:
y = fluid.layers.resize_nearest(
input=y, scale=2**(j - i))
input=y, scale=2**(j - i), align_corners=False)
residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None)
elif j < i:
......
......@@ -42,7 +42,9 @@ class MobileNetV3():
extra_block_filters=[[256, 512], [128, 256], [128, 256],
[64, 128]],
num_classes=None,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0]):
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
for_seg=False,
output_stride=None):
assert len(lr_mult_list) == 5, \
"lr_mult_list length in MobileNetV3 must be 5 but got {}!!".format(
len(lr_mult_list))
......@@ -57,48 +59,112 @@ class MobileNetV3():
self.num_classes = num_classes
self.lr_mult_list = lr_mult_list
self.curr_stage = 0
if model_name == "large":
self.cfg = [
# kernel_size, expand, channel, se_block, act_mode, stride
[3, 16, 16, False, 'relu', 1],
[3, 64, 24, False, 'relu', 2],
[3, 72, 24, False, 'relu', 1],
[5, 72, 40, True, 'relu', 2],
[5, 120, 40, True, 'relu', 1],
[5, 120, 40, True, 'relu', 1],
[3, 240, 80, False, 'hard_swish', 2],
[3, 200, 80, False, 'hard_swish', 1],
[3, 184, 80, False, 'hard_swish', 1],
[3, 184, 80, False, 'hard_swish', 1],
[3, 480, 112, True, 'hard_swish', 1],
[3, 672, 112, True, 'hard_swish', 1],
[5, 672, 160, True, 'hard_swish', 2],
[5, 960, 160, True, 'hard_swish', 1],
[5, 960, 160, True, 'hard_swish', 1],
]
self.cls_ch_squeeze = 960
self.cls_ch_expand = 1280
self.lr_interval = 3
elif model_name == "small":
self.cfg = [
# kernel_size, expand, channel, se_block, act_mode, stride
[3, 16, 16, True, 'relu', 2],
[3, 72, 24, False, 'relu', 2],
[3, 88, 24, False, 'relu', 1],
[5, 96, 40, True, 'hard_swish', 2],
[5, 240, 40, True, 'hard_swish', 1],
[5, 240, 40, True, 'hard_swish', 1],
[5, 120, 48, True, 'hard_swish', 1],
[5, 144, 48, True, 'hard_swish', 1],
[5, 288, 96, True, 'hard_swish', 2],
[5, 576, 96, True, 'hard_swish', 1],
[5, 576, 96, True, 'hard_swish', 1],
]
self.cls_ch_squeeze = 576
self.cls_ch_expand = 1280
self.lr_interval = 2
self.for_seg = for_seg
self.decode_point = None
if self.for_seg:
if model_name == "large":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, False, 'relu', 1],
[3, 64, 24, False, 'relu', 2],
[3, 72, 24, False, 'relu', 1],
[5, 72, 40, True, 'relu', 2],
[5, 120, 40, True, 'relu', 1],
[5, 120, 40, True, 'relu', 1],
[3, 240, 80, False, 'hard_swish', 2],
[3, 200, 80, False, 'hard_swish', 1],
[3, 184, 80, False, 'hard_swish', 1],
[3, 184, 80, False, 'hard_swish', 1],
[3, 480, 112, True, 'hard_swish', 1],
[3, 672, 112, True, 'hard_swish', 1],
# The number of channels in the last 4 stages is reduced by a
# factor of 2 compared to the standard implementation.
[5, 336, 80, True, 'hard_swish', 2],
[5, 480, 80, True, 'hard_swish', 1],
[5, 480, 80, True, 'hard_swish', 1],
]
self.cls_ch_squeeze = 480
self.cls_ch_expand = 1280
self.lr_interval = 3
elif model_name == "small":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, True, 'relu', 2],
[3, 72, 24, False, 'relu', 2],
[3, 88, 24, False, 'relu', 1],
[5, 96, 40, True, 'hard_swish', 2],
[5, 240, 40, True, 'hard_swish', 1],
[5, 240, 40, True, 'hard_swish', 1],
[5, 120, 48, True, 'hard_swish', 1],
[5, 144, 48, True, 'hard_swish', 1],
# The number of channels in the last 4 stages is reduced by a
# factor of 2 compared to the standard implementation.
[5, 144, 48, True, 'hard_swish', 2],
[5, 288, 48, True, 'hard_swish', 1],
[5, 288, 48, True, 'hard_swish', 1],
]
else:
raise NotImplementedError
else:
raise NotImplementedError
if model_name == "large":
self.cfg = [
# kernel_size, expand, channel, se_block, act_mode, stride
[3, 16, 16, False, 'relu', 1],
[3, 64, 24, False, 'relu', 2],
[3, 72, 24, False, 'relu', 1],
[5, 72, 40, True, 'relu', 2],
[5, 120, 40, True, 'relu', 1],
[5, 120, 40, True, 'relu', 1],
[3, 240, 80, False, 'hard_swish', 2],
[3, 200, 80, False, 'hard_swish', 1],
[3, 184, 80, False, 'hard_swish', 1],
[3, 184, 80, False, 'hard_swish', 1],
[3, 480, 112, True, 'hard_swish', 1],
[3, 672, 112, True, 'hard_swish', 1],
[5, 672, 160, True, 'hard_swish', 2],
[5, 960, 160, True, 'hard_swish', 1],
[5, 960, 160, True, 'hard_swish', 1],
]
self.cls_ch_squeeze = 960
self.cls_ch_expand = 1280
self.lr_interval = 3
elif model_name == "small":
self.cfg = [
# kernel_size, expand, channel, se_block, act_mode, stride
[3, 16, 16, True, 'relu', 2],
[3, 72, 24, False, 'relu', 2],
[3, 88, 24, False, 'relu', 1],
[5, 96, 40, True, 'hard_swish', 2],
[5, 240, 40, True, 'hard_swish', 1],
[5, 240, 40, True, 'hard_swish', 1],
[5, 120, 48, True, 'hard_swish', 1],
[5, 144, 48, True, 'hard_swish', 1],
[5, 288, 96, True, 'hard_swish', 2],
[5, 576, 96, True, 'hard_swish', 1],
[5, 576, 96, True, 'hard_swish', 1],
]
self.cls_ch_squeeze = 576
self.cls_ch_expand = 1280
self.lr_interval = 2
else:
raise NotImplementedError
if self.for_seg:
self.modify_bottle_params(output_stride)
def modify_bottle_params(self, output_stride=None):
if output_stride is not None and output_stride % 2 != 0:
raise Exception("output stride must to be even number")
if output_stride is None:
return
else:
stride = 2
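# track the cumulative stride (the stem conv already contributes a factor of 2);
# once it exceeds output_stride, force the remaining block strides to 1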
for i, _cfg in enumerate(self.cfg):
stride = stride * _cfg[-1]
if stride > output_stride:
s = 1
self.cfg[i][-1] = s
def _conv_bn_layer(self,
input,
......@@ -153,6 +219,14 @@ class MobileNetV3():
bn = fluid.layers.relu6(bn)
return bn
def make_divisible(self, v, divisor=8, min_value=None):
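# round v to the nearest multiple of divisor (default 8), never dropping
# below min_value nor below ~90% of the original value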
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
def _hard_swish(self, x):
return x * fluid.layers.relu6(x + 3) / 6.
......@@ -220,6 +294,9 @@ class MobileNetV3():
use_cudnn=False,
name=name + '_depthwise')
if self.curr_stage == 5:
self.decode_point = conv1
if use_se:
conv1 = self._se_block(
input=conv1, num_out_filter=num_mid_filter, name=name + '_se')
......@@ -282,7 +359,7 @@ class MobileNetV3():
conv = self._conv_bn_layer(
input,
filter_size=3,
num_filters=inplanes if scale <= 1.0 else int(inplanes * scale),
num_filters=self.make_divisible(inplanes * scale),
stride=2,
padding=1,
num_groups=1,
......@@ -290,6 +367,7 @@ class MobileNetV3():
act='hard_swish',
name='conv1')
i = 0
inplanes = self.make_divisible(inplanes * scale)
for layer_cfg in cfg:
self.block_stride *= layer_cfg[5]
if layer_cfg[5] == 2:
......@@ -297,19 +375,32 @@ class MobileNetV3():
conv = self._residual_unit(
input=conv,
num_in_filter=inplanes,
num_mid_filter=int(scale * layer_cfg[1]),
num_out_filter=int(scale * layer_cfg[2]),
num_mid_filter=self.make_divisible(scale * layer_cfg[1]),
num_out_filter=self.make_divisible(scale * layer_cfg[2]),
act=layer_cfg[4],
stride=layer_cfg[5],
filter_size=layer_cfg[0],
use_se=layer_cfg[3],
name='conv' + str(i + 2))
inplanes = int(scale * layer_cfg[2])
inplanes = self.make_divisible(scale * layer_cfg[2])
i += 1
self.curr_stage = i
blocks.append(conv)
if self.for_seg:
conv = self._conv_bn_layer(
input=conv,
filter_size=1,
num_filters=self.make_divisible(scale * self.cls_ch_squeeze),
stride=1,
padding=0,
num_groups=1,
if_act=True,
act='hard_swish',
name='conv_last')
return conv, self.decode_point
if self.num_classes:
conv = self._conv_bn_layer(
input=conv,
......
......@@ -21,7 +21,7 @@ from collections import OrderedDict
import paddle.fluid as fluid
from .model_utils.libs import scope, name_scope
from .model_utils.libs import bn, bn_relu, relu
from .model_utils.libs import bn, bn_relu, relu, qsigmoid
from .model_utils.libs import conv, max_pool, deconv
from .model_utils.libs import separate_conv
from .model_utils.libs import sigmoid_to_softmax
......@@ -82,7 +82,17 @@ class DeepLabv3p(object):
use_dice_loss=False,
class_weight=None,
ignore_index=255,
fixed_input_shape=None):
fixed_input_shape=None,
pooling_stride=[1, 1],
pooling_crop_size=None,
aspp_with_se=False,
se_use_qsigmoid=False,
aspp_convs_filters=256,
aspp_with_concat_projection=True,
add_image_level_feature=True,
use_sum_merge=False,
conv_filters=256,
output_is_logits=False):
# dice_loss and bce_loss are only applicable to binary segmentation
if num_classes > 2 and (use_bce_loss or use_dice_loss):
raise ValueError(
......@@ -117,6 +127,17 @@ class DeepLabv3p(object):
self.encoder_with_aspp = encoder_with_aspp
self.enable_decoder = enable_decoder
self.fixed_input_shape = fixed_input_shape
self.output_is_logits = output_is_logits
self.aspp_convs_filters = aspp_convs_filters
self.output_stride = output_stride
self.pooling_crop_size = pooling_crop_size
self.pooling_stride = pooling_stride
self.se_use_qsigmoid = se_use_qsigmoid
self.aspp_with_concat_projection = aspp_with_concat_projection
self.add_image_level_feature = add_image_level_feature
self.aspp_with_se = aspp_with_se
self.use_sum_merge = use_sum_merge
self.conv_filters = conv_filters
def _encoder(self, input):
# Encoder configuration using the ASPP architecture: pooling + 1x1 conv in parallel with three dilated convolutions at different rates, followed by concat and a 1x1 conv
......@@ -129,19 +150,36 @@ class DeepLabv3p(object):
elif self.output_stride == 8:
aspp_ratios = [12, 24, 36]
else:
raise Exception("DeepLabv3p only support stride 8 or 16")
aspp_ratios = []
param_attr = fluid.ParamAttr(
name=name_scope + 'weights',
regularizer=None,
initializer=fluid.initializer.TruncatedNormal(
loc=0.0, scale=0.06))
concat_logits = []
with scope('encoder'):
channel = 256
channel = self.aspp_convs_filters
with scope("image_pool"):
image_avg = fluid.layers.reduce_mean(
input, [2, 3], keep_dim=True)
image_avg = bn_relu(
if self.pooling_crop_size is None:
image_avg = fluid.layers.reduce_mean(
input, [2, 3], keep_dim=True)
else:
pool_w = int((self.pooling_crop_size[0] - 1.0) /
self.output_stride + 1.0)
pool_h = int((self.pooling_crop_size[1] - 1.0) /
self.output_stride + 1.0)
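# e.g. pooling_crop_size=[512, 512] with output_stride=32 yields a fixed
# 16x16 pooling window, avoiding a dynamic-shape global reduce_mean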
image_avg = fluid.layers.pool2d(
input,
pool_size=(pool_h, pool_w),
pool_stride=self.pooling_stride,
pool_type='avg',
pool_padding='VALID')
act = qsigmoid if self.se_use_qsigmoid else bn_relu
image_avg = act(
conv(
image_avg,
channel,
......@@ -153,6 +191,8 @@ class DeepLabv3p(object):
input_shape = fluid.layers.shape(input)
image_avg = fluid.layers.resize_bilinear(image_avg,
input_shape[2:])
if self.add_image_level_feature:
concat_logits.append(image_avg)
with scope("aspp0"):
aspp0 = bn_relu(
......@@ -164,77 +204,160 @@ class DeepLabv3p(object):
groups=1,
padding=0,
param_attr=param_attr))
with scope("aspp1"):
if self.aspp_with_sep_conv:
aspp1 = separate_conv(
input,
channel,
1,
3,
dilation=aspp_ratios[0],
act=relu)
else:
aspp1 = bn_relu(
conv(
concat_logits.append(aspp0)
if aspp_ratios:
with scope("aspp1"):
if self.aspp_with_sep_conv:
aspp1 = separate_conv(
input,
channel,
stride=1,
filter_size=3,
1,
3,
dilation=aspp_ratios[0],
padding=aspp_ratios[0],
param_attr=param_attr))
with scope("aspp2"):
if self.aspp_with_sep_conv:
aspp2 = separate_conv(
input,
channel,
1,
3,
dilation=aspp_ratios[1],
act=relu)
else:
aspp2 = bn_relu(
conv(
act=relu)
else:
aspp1 = bn_relu(
conv(
input,
channel,
stride=1,
filter_size=3,
dilation=aspp_ratios[0],
padding=aspp_ratios[0],
param_attr=param_attr))
concat_logits.append(aspp1)
with scope("aspp2"):
if self.aspp_with_sep_conv:
aspp2 = separate_conv(
input,
channel,
stride=1,
filter_size=3,
1,
3,
dilation=aspp_ratios[1],
padding=aspp_ratios[1],
param_attr=param_attr))
with scope("aspp3"):
if self.aspp_with_sep_conv:
aspp3 = separate_conv(
input,
channel,
1,
3,
dilation=aspp_ratios[2],
act=relu)
else:
aspp3 = bn_relu(
conv(
act=relu)
else:
aspp2 = bn_relu(
conv(
input,
channel,
stride=1,
filter_size=3,
dilation=aspp_ratios[1],
padding=aspp_ratios[1],
param_attr=param_attr))
concat_logits.append(aspp2)
with scope("aspp3"):
if self.aspp_with_sep_conv:
aspp3 = separate_conv(
input,
channel,
stride=1,
filter_size=3,
1,
3,
dilation=aspp_ratios[2],
padding=aspp_ratios[2],
param_attr=param_attr))
act=relu)
else:
aspp3 = bn_relu(
conv(
input,
channel,
stride=1,
filter_size=3,
dilation=aspp_ratios[2],
padding=aspp_ratios[2],
param_attr=param_attr))
concat_logits.append(aspp3)
with scope("concat"):
data = fluid.layers.concat(
[image_avg, aspp0, aspp1, aspp2, aspp3], axis=1)
data = bn_relu(
data = fluid.layers.concat(concat_logits, axis=1)
if self.aspp_with_concat_projection:
data = bn_relu(
conv(
data,
channel,
1,
1,
groups=1,
padding=0,
param_attr=param_attr))
data = fluid.layers.dropout(data, 0.9)
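# SE-style reweighting: when aspp_with_se is enabled, scale the fused
# features by the (gated) image-level feature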
if self.aspp_with_se:
data = data * image_avg
return data
def _decoder_with_sum_merge(self, encode_data, decode_shortcut,
param_attr):
decode_shortcut_shape = fluid.layers.shape(decode_shortcut)
encode_data = fluid.layers.resize_bilinear(encode_data,
decode_shortcut_shape[2:])
encode_data = conv(
encode_data,
self.conv_filters,
1,
1,
groups=1,
padding=0,
param_attr=param_attr)
with scope('merge'):
decode_shortcut = conv(
decode_shortcut,
self.conv_filters,
1,
1,
groups=1,
padding=0,
param_attr=param_attr)
return encode_data + decode_shortcut
def _decoder_with_concat(self, encode_data, decode_shortcut, param_attr):
with scope('concat'):
decode_shortcut = bn_relu(
conv(
decode_shortcut,
48,
1,
1,
groups=1,
padding=0,
param_attr=param_attr))
decode_shortcut_shape = fluid.layers.shape(decode_shortcut)
encode_data = fluid.layers.resize_bilinear(
encode_data, decode_shortcut_shape[2:])
encode_data = fluid.layers.concat(
[encode_data, decode_shortcut], axis=1)
if self.decoder_use_sep_conv:
with scope("separable_conv1"):
encode_data = separate_conv(
encode_data, self.conv_filters, 1, 3, dilation=1, act=relu)
with scope("separable_conv2"):
encode_data = separate_conv(
encode_data, self.conv_filters, 1, 3, dilation=1, act=relu)
else:
with scope("decoder_conv1"):
encode_data = bn_relu(
conv(
data,
channel,
1,
1,
groups=1,
padding=0,
encode_data,
self.conv_filters,
stride=1,
filter_size=3,
dilation=1,
padding=1,
param_attr=param_attr))
data = fluid.layers.dropout(data, 0.9)
return data
with scope("decoder_conv2"):
encode_data = bn_relu(
conv(
encode_data,
self.conv_filters,
stride=1,
filter_size=3,
dilation=1,
padding=1,
param_attr=param_attr))
return encode_data
def _decoder(self, encode_data, decode_shortcut):
# Decoder configuration
......@@ -246,52 +369,14 @@ class DeepLabv3p(object):
regularizer=None,
initializer=fluid.initializer.TruncatedNormal(
loc=0.0, scale=0.06))
with scope('decoder'):
with scope('concat'):
decode_shortcut = bn_relu(
conv(
decode_shortcut,
48,
1,
1,
groups=1,
padding=0,
param_attr=param_attr))
if self.use_sum_merge:
return self._decoder_with_sum_merge(
encode_data, decode_shortcut, param_attr)
decode_shortcut_shape = fluid.layers.shape(decode_shortcut)
encode_data = fluid.layers.resize_bilinear(
encode_data, decode_shortcut_shape[2:])
encode_data = fluid.layers.concat(
[encode_data, decode_shortcut], axis=1)
if self.decoder_use_sep_conv:
with scope("separable_conv1"):
encode_data = separate_conv(
encode_data, 256, 1, 3, dilation=1, act=relu)
with scope("separable_conv2"):
encode_data = separate_conv(
encode_data, 256, 1, 3, dilation=1, act=relu)
else:
with scope("decoder_conv1"):
encode_data = bn_relu(
conv(
encode_data,
256,
stride=1,
filter_size=3,
dilation=1,
padding=1,
param_attr=param_attr))
with scope("decoder_conv2"):
encode_data = bn_relu(
conv(
encode_data,
256,
stride=1,
filter_size=3,
dilation=1,
padding=1,
param_attr=param_attr))
return encode_data
return self._decoder_with_concat(encode_data, decode_shortcut,
param_attr)
def _get_loss(self, logit, label, mask):
avg_loss = 0
......@@ -335,8 +420,11 @@ class DeepLabv3p(object):
self.num_classes = 1
image = inputs['image']
data, decode_shortcuts = self.backbone(image)
decode_shortcut = decode_shortcuts[self.backbone.decode_points]
if 'MobileNetV3' in self.backbone.__class__.__name__:
data, decode_shortcut = self.backbone(image)
else:
data, decode_shortcuts = self.backbone(image)
decode_shortcut = decode_shortcuts[self.backbone.decode_points]
# Encoder and decoder setup
if self.encoder_with_aspp:
......@@ -351,18 +439,22 @@ class DeepLabv3p(object):
regularization_coeff=0.0),
initializer=fluid.initializer.TruncatedNormal(
loc=0.0, scale=0.01))
with scope('logit'):
with fluid.name_scope('last_conv'):
logit = conv(
data,
self.num_classes,
1,
stride=1,
padding=0,
bias_attr=True,
param_attr=param_attr)
image_shape = fluid.layers.shape(image)
logit = fluid.layers.resize_bilinear(logit, image_shape[2:])
if not self.output_is_logits:
with scope('logit'):
with fluid.name_scope('last_conv'):
logit = conv(
data,
self.num_classes,
1,
stride=1,
padding=0,
bias_attr=True,
param_attr=param_attr)
else:
logit = data
image_shape = fluid.layers.shape(image)
logit = fluid.layers.resize_bilinear(logit, image_shape[2:])
if self.num_classes == 1:
out = sigmoid_to_softmax(logit)
......
......@@ -112,6 +112,10 @@ def bn_relu(data, norm_type='bn', eps=1e-5):
return fluid.layers.relu(bn(data, norm_type=norm_type, eps=eps))
def qsigmoid(data):
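# hard sigmoid: relu6(x + 3) / 6, with 1/6 approximated as 0.16667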
return fluid.layers.relu6(data + 3) * 0.16667
def relu(data):
return fluid.layers.relu(data)
......
......@@ -73,8 +73,6 @@ class Compose(SegTransform):
tuple: A tuple composed of the fields required by the network; the fields are determined by the last preprocessing op in transforms.
"""
if im_info is None:
im_info = list()
if isinstance(im, np.ndarray):
if len(im.shape) != 3:
raise Exception(
......@@ -86,6 +84,8 @@ class Compose(SegTransform):
except:
raise ValueError('Can\'t read the image file {}!'.format(im))
im = im.astype('float32')
if im_info is None:
im_info = [('origin_shape', im.shape[0:2])]
if self.to_rgb:
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
if label is not None:
......
......@@ -191,9 +191,9 @@ def det_compose(im,
bboxes[:, 3] = bboxes[:, 3] * h_scale
else:
bboxes = outputs[2]['gt_bbox']
if not isinstance(
op,
pdx.cv.transforms.det_transforms.RandomHorizontalFlip):
if not isinstance(op, (
pdx.cv.transforms.det_transforms.RandomHorizontalFlip,
pdx.cv.transforms.det_transforms.Padding)):
for i in range(bboxes.shape[0]):
bbox = bboxes[i]
cname = labels[outputs[2]['gt_class'][i][0] - 1]
......
......@@ -19,7 +19,9 @@ import yaml
import paddlex
import paddle.fluid as fluid
from paddlex.cv.transforms import build_transforms
from paddlex.cv.models import BaseClassifier, YOLOv3, FasterRCNN, MaskRCNN, DeepLabv3p
from paddlex.cv.models import BaseClassifier
from paddlex.cv.models import PPYOLO, FasterRCNN, MaskRCNN
from paddlex.cv.models import DeepLabv3p
class Predictor:
......@@ -129,8 +131,8 @@ class Predictor:
thread_num=thread_num)
res['image'] = im
elif self.model_type == "detector":
if self.model_name == "YOLOv3":
im, im_size = YOLOv3._preprocess(
if self.model_name in ["PPYOLO", "YOLOv3"]:
im, im_size = PPYOLO._preprocess(
image,
self.transforms,
self.model_type,
......@@ -190,8 +192,8 @@ class Predictor:
res = {'bbox': (results[0][0], offset_to_lengths(results[0][1])), }
res['im_id'] = (np.array(
[[i] for i in range(batch_size)]).astype('int32'), [[]])
if self.model_name == "YOLOv3":
preds = YOLOv3._postprocess(res, batch_size, self.num_classes,
if self.model_name in ["PPYOLO", "YOLOv3"]:
preds = PPYOLO._postprocess(res, batch_size, self.num_classes,
self.labels)
elif self.model_name == "FasterRCNN":
preds = FasterRCNN._postprocess(res, batch_size,
......
......@@ -31,7 +31,7 @@ setuptools.setup(
install_requires=[
"pycocotools;platform_system!='Windows'", 'pyyaml', 'colorama', 'tqdm',
'paddleslim==1.0.1', 'visualdl>=2.0.0b', 'paddlehub>=1.6.2',
'shapely>=1.7.0'
'shapely>=1.7.0', "opencv-python"
],
classifiers=[
"Programming Language :: Python :: 3",
......
......@@ -20,6 +20,7 @@
|instance_segmentation/mask_rcnn_r18_fpn.py | Instance segmentation MaskRCNN | Xiaoduxiong sorting |
|instance_segmentation/mask_rcnn_r50_fpn.py | Instance segmentation MaskRCNN | Xiaoduxiong sorting |
|semantic_segmentation/deeplabv3p_mobilenetv2.py | Semantic segmentation DeepLabV3 | Optic disc segmentation |
|semantic_segmentation/deeplabv3p_mobilenetv3_large_ssld.py | Semantic segmentation DeepLabV3 | Optic disc segmentation |
|semantic_segmentation/deeplabv3p_mobilenetv2_x0.25.py | Semantic segmentation DeepLabV3 | Optic disc segmentation |
|semantic_segmentation/deeplabv3p_xception65.py | Semantic segmentation DeepLabV3 | Optic disc segmentation |
|semantic_segmentation/fast_scnn.py | Semantic segmentation FastSCNN | Optic disc segmentation |
......
# Environment variable configuration, used to control whether the GPU is used
# Docs: https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import paddlex as pdx
from paddlex.seg import transforms
# Download and decompress the optic disc segmentation dataset
optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz'
pdx.utils.download_and_decompress(optic_dataset, path='./')
# Define the transforms for training and validation
# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/transforms/seg_transforms.html
train_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(), transforms.ResizeRangeScaling(),
transforms.RandomPaddingCrop(crop_size=512), transforms.Normalize()
])
eval_transforms = transforms.Compose([
transforms.ResizeByLong(long_size=512),
transforms.Padding(target_size=512), transforms.Normalize()
])
# Define the datasets used for training and validation
# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/datasets.html#paddlex-datasets-segdataset
train_dataset = pdx.datasets.SegDataset(
data_dir='optic_disc_seg',
file_list='optic_disc_seg/train_list.txt',
label_list='optic_disc_seg/labels.txt',
transforms=train_transforms,
shuffle=True)
eval_dataset = pdx.datasets.SegDataset(
data_dir='optic_disc_seg',
file_list='optic_disc_seg/val_list.txt',
label_list='optic_disc_seg/labels.txt',
transforms=eval_transforms)
# Initialize the model and start training
# Training metrics can be viewed with VisualDL; see https://paddlex.readthedocs.io/zh_CN/develop/train/visualdl.html
num_classes = len(train_dataset.labels)
# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#paddlex-seg-deeplabv3p
model = pdx.seg.DeepLabv3p(
num_classes=num_classes,
backbone='MobileNetV3_large_x1_0_ssld',
pooling_crop_size=(512, 512))
# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#train
# Parameter descriptions and tuning tips: https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html
model.train(
num_epochs=40,
train_dataset=train_dataset,
train_batch_size=4,
eval_dataset=eval_dataset,
learning_rate=0.01,
save_dir='output/deeplabv3p_mobilenetv3_large_ssld',
use_vdl=True)