Unverified commit fd42b2f6 authored by Jason, committed by GitHub

Merge pull request #99 from FlyingQianMM/develop_qh

add hrnet for classify, detect, segment
...@@ -35,7 +35,7 @@ train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, s
> > - **use_vdl** (bool): Whether to use VisualDL for visualization. Defaults to False.
> > - **sensitivities_file** (str): If set to a path, loads the sensitivity information at that path for pruning; if set to the string 'DEFAULT', automatically downloads sensitivity information computed on ImageNet data and uses it for pruning; if None, no pruning is performed. Defaults to None.
> > - **eval_metric_loss** (float): Tolerable accuracy loss. Defaults to 0.05.
> > - **early_stop** (bool): Whether to use the early-stopping strategy. Defaults to False.
> > - **early_stop_patience** (int): With early stopping enabled, training is terminated if the validation accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Defaults to 5.
> > - **resume_checkpoint** (str): Path of a previously saved model from which to resume training. If None, training is not resumed. Defaults to None.
...@@ -186,3 +186,7 @@ paddlex.cls.DenseNet161(num_classes=1000)
paddlex.cls.DenseNet201(num_classes=1000)
```
### HRNet_W18
```python
paddlex.cls.HRNet_W18(num_classes=1000)
```
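A minimal usage sketch of the new classifier (illustrative, not part of the diff); it assumes an ImageNet-format dataset loaded with `paddlex.datasets.ImageNet`, and the file paths below are placeholders.
```python
import paddlex as pdx
from paddlex.cls import transforms

# Placeholder paths; any ImageNet-format list files work the same way.
train_transforms = transforms.Compose(
    [transforms.RandomCrop(crop_size=224), transforms.Normalize()])
train_dataset = pdx.datasets.ImageNet(
    data_dir='dataset',
    file_list='dataset/train_list.txt',
    label_list='dataset/labels.txt',
    transforms=train_transforms)

model = pdx.cls.HRNet_W18(num_classes=len(train_dataset.labels))
model.train(
    num_epochs=10,
    train_dataset=train_dataset,
    train_batch_size=64,
    save_dir='output/hrnet_w18')
```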
...@@ -53,7 +53,7 @@ train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, sa
> > - **use_vdl** (bool): Whether to use VisualDL for visualization. Defaults to False.
> > - **sensitivities_file** (str): If set to a path, loads the sensitivity information at that path for pruning; if set to the string 'DEFAULT', automatically downloads sensitivity information computed on PascalVOC data and uses it for pruning; if None, no pruning is performed. Defaults to None.
> > - **eval_metric_loss** (float): Tolerable accuracy loss. Defaults to 0.05.
> > - **early_stop** (bool): Whether to use the early-stopping strategy. Defaults to False.
> > - **early_stop_patience** (int): With early stopping enabled, training is terminated if the validation accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Defaults to 5.
> > - **resume_checkpoint** (str): Path of a previously saved model from which to resume training. If None, training is not resumed. Defaults to None.
...@@ -107,7 +107,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
> **Parameters**
> > - **num_classes** (int): Number of classes, including the background class. Defaults to 81.
> > - **backbone** (str): Backbone network of FasterRCNN; one of ['ResNet18', 'ResNet50', 'ResNet50_vd', 'ResNet101', 'ResNet101_vd', 'HRNet_W18']. Defaults to 'ResNet50' (see the sketch after this list).
> > - **with_fpn** (bool): Whether to use the FPN structure. Defaults to True.
> > - **aspect_ratios** (list): Candidate aspect ratios for anchor generation. Defaults to [0.5, 1.0, 2.0].
> > - **anchor_sizes** (list): Candidate anchor sizes. Defaults to [32, 64, 128, 256, 512].
......
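A minimal construction sketch (illustrative, not part of the diff) using the newly accepted backbone value; per the implementation in this PR, FPN stays enabled for HRNet backbones.
```python
import paddlex as pdx

# 'HRNet_W18' is now a valid backbone; with_fpn is kept True for HRNet.
model = pdx.det.FasterRCNN(num_classes=81, backbone='HRNet_W18', with_fpn=True)
```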
...@@ -12,7 +12,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
> **Parameters**
> > - **num_classes** (int): Number of classes, including the background class. Defaults to 81.
> > - **backbone** (str): Backbone network of MaskRCNN; one of ['ResNet18', 'ResNet50', 'ResNet50_vd', 'ResNet101', 'ResNet101_vd', 'HRNet_W18']. Defaults to 'ResNet50' (see the sketch after this list).
> > - **with_fpn** (bool): Whether to use the FPN structure. Defaults to True.
> > - **aspect_ratios** (list): Candidate aspect ratios for anchor generation. Defaults to [0.5, 1.0, 2.0].
> > - **anchor_sizes** (list): Candidate anchor sizes. Defaults to [32, 64, 128, 256, 512].
......
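The same backbone value is accepted by MaskRCNN; a minimal sketch (illustrative, not part of the diff):
```python
import paddlex as pdx

model = pdx.det.MaskRCNN(num_classes=81, backbone='HRNet_W18')
```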
...@@ -47,7 +47,7 @@ train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, ev
> > - **use_vdl** (bool): Whether to use VisualDL for visualization. Defaults to False.
> > - **sensitivities_file** (str): If set to a path, loads the sensitivity information at that path for pruning; if set to the string 'DEFAULT', automatically downloads sensitivity information computed on ImageNet data and uses it for pruning; if None, no pruning is performed. Defaults to None.
> > - **eval_metric_loss** (float): Tolerable accuracy loss. Defaults to 0.05.
> > - **early_stop** (bool): Whether to use the early-stopping strategy. Defaults to False.
> > - **early_stop_patience** (int): With early stopping enabled, training is terminated if the validation accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Defaults to 5.
> > - **resume_checkpoint** (str): Path of a previously saved model from which to resume training. If None, training is not resumed. Defaults to None.
...@@ -124,7 +124,7 @@ train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, ev
> > - **save_interval_epochs** (int): Interval, in epochs, at which the model is saved. Defaults to 1.
> > - **log_interval_steps** (int): Interval, in steps, at which training logs are printed. Defaults to 2.
> > - **save_dir** (str): Directory in which to save the model. Defaults to 'output'.
> > - **pretrain_weights** (str): If set to a path, loads the pretrained model at that path; if set to the string 'COCO', automatically downloads weights pretrained on COCO image data; if None, no pretrained model is used. Defaults to 'COCO'.
> > - **optimizer** (paddle.fluid.optimizer): Optimizer. If None, the default optimizer is used: fluid.optimizer.Momentum with a polynomial learning-rate decay schedule.
> > - **learning_rate** (float): Initial learning rate of the default optimizer. Defaults to 0.01.
> > - **lr_decay_power** (float): Decay exponent of the default optimizer's learning rate. Defaults to 0.9.
...@@ -173,3 +173,88 @@ predict(self, im_file, transforms=None):
> **Returns**
> >
> > - **dict**: Contains the keys 'label_map' and 'score_map'. 'label_map' stores the predicted grayscale label map whose pixel values are class ids; 'score_map' stores the per-class probabilities with shape (h, w, num_classes).
## HRNet
```python
paddlex.seg.HRNet(num_classes=2, width=18, use_bce_loss=False, use_dice_loss=False, class_weight=None, ignore_index=255)
```
> Builds an HRNet segmenter.
> **Parameters**
> > - **num_classes** (int): Number of classes.
> > - **width** (int): Number of channels of the feature maps in the high-resolution branch. Defaults to 18. Allowed values are [18, 30, 32, 40, 44, 48, 60, 64].
> > - **use_bce_loss** (bool): Whether to use BCE loss as the network's loss function; only applicable to binary segmentation, and it can be combined with dice loss. Defaults to False.
> > - **use_dice_loss** (bool): Whether to use dice loss as the network's loss function; only applicable to binary segmentation, and it can be combined with BCE loss. When both use_bce_loss and use_dice_loss are False, cross-entropy loss is used. Defaults to False.
> > - **class_weight** (list/str): Per-class weights of the cross-entropy loss. When `class_weight` is a list, its length must equal `num_classes`. When `class_weight` is a str, weight.lower() must be 'dynamic', in which case the weights are computed automatically from the per-class pixel proportions of each round, each class being weighted as its proportion * num_classes. When `class_weight` is left at its default None, every class has weight 1, i.e. the ordinary cross-entropy loss.
> > - **ignore_index** (int): Label value to be ignored; pixels whose label equals `ignore_index` do not contribute to the loss. Defaults to 255.
### train
```python
train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None):
```
> Training interface of the HRNet model.
> **Parameters**
> >
> > - **num_epochs** (int): Number of training epochs.
> > - **train_dataset** (paddlex.datasets): Training data reader.
> > - **train_batch_size** (int): Batch size of the training data, also used as the batch size for validation. Defaults to 2.
> > - **eval_dataset** (paddlex.datasets): Evaluation data reader.
> > - **save_interval_epochs** (int): Interval, in epochs, at which the model is saved. Defaults to 1.
> > - **log_interval_steps** (int): Interval, in steps, at which training logs are printed. Defaults to 2.
> > - **save_dir** (str): Directory in which to save the model. Defaults to 'output'.
> > - **pretrain_weights** (str): If set to a path, loads the pretrained model at that path; if set to the string 'IMAGENET', automatically downloads weights pretrained on the ImageNet dataset; if None, no pretrained model is used. Defaults to 'IMAGENET'.
> > - **optimizer** (paddle.fluid.optimizer): Optimizer. If None, the default optimizer is used: fluid.optimizer.Momentum with a polynomial learning-rate decay schedule.
> > - **learning_rate** (float): Initial learning rate of the default optimizer. Defaults to 0.01.
> > - **lr_decay_power** (float): Decay exponent of the default optimizer's learning rate. Defaults to 0.9.
> > - **use_vdl** (bool): Whether to use VisualDL for visualization. Defaults to False.
> > - **sensitivities_file** (str): If set to a path, loads the sensitivity information at that path for pruning; if set to the string 'DEFAULT', automatically downloads sensitivity information computed on ImageNet data and uses it for pruning; if None, no pruning is performed. Defaults to None.
> > - **eval_metric_loss** (float): Tolerable accuracy loss. Defaults to 0.05.
> > - **early_stop** (bool): Whether to use the early-stopping strategy. Defaults to False.
> > - **early_stop_patience** (int): With early stopping enabled, training is terminated if the validation accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Defaults to 5.
> > - **resume_checkpoint** (str): Path of a previously saved model from which to resume training. If None, training is not resumed. Defaults to None.
#### evaluate
```
evaluate(self, eval_dataset, batch_size=1, epoch_id=None, return_details=False):
```
> Evaluation interface of the HRNet model.
> **Parameters**
> >
> > - **eval_dataset** (paddlex.datasets): Evaluation data reader.
> > - **batch_size** (int): Batch size used during evaluation. Defaults to 1.
> > - **epoch_id** (int): Training epoch to which the evaluated model belongs.
> > - **return_details** (bool): Whether to return detailed information. Defaults to False.
> **Returns**
> >
> > - **dict**: When return_details is False, returns a dict with the keys 'miou', 'category_iou', 'macc',
> > 'category_acc' and 'kappa', which are the mean IoU, per-class IoU, mean accuracy, per-class accuracy and kappa coefficient respectively.
> > - **tuple** (metrics, eval_details): When return_details is True, an additional dict (eval_details) is returned,
> > containing the key 'confusion_matrix', the confusion matrix of the evaluation.
#### predict
```
predict(self, im_file, transforms=None):
```
> Prediction interface of the HRNet model. Note that only when eval_dataset was defined during training does the saved model keep the prediction-time image preprocessing pipeline in `HRNet.test_transforms` and `HRNet.eval_transforms`. If eval_dataset was not defined at training time, the user has to define test_transforms again and pass it to the `predict` interface.
> **Parameters**
> >
> > - **im_file** (str): Path of the image to predict.
> > - **transforms** (paddlex.seg.transforms): Data preprocessing operations.
> **Returns**
> >
> > - **dict**: Contains the keys 'label_map' and 'score_map'. 'label_map' stores the predicted grayscale label map whose pixel values are class ids; 'score_map' stores the per-class probabilities with shape (h, w, num_classes).
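Putting the interfaces above together, a minimal prediction sketch (illustrative, not part of the diff); the checkpoint path and image file are placeholders:
```python
import paddlex as pdx

# Load a previously trained HRNet segmentation checkpoint (placeholder path).
model = pdx.load_model('output/hrnet/best_model')

# With transforms=None, the test_transforms saved during training are applied.
result = model.predict('test.jpg')
print(result['label_map'].shape)   # (h, w)
print(result['score_map'].shape)   # (h, w, num_classes)
```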
...@@ -27,6 +27,7 @@
| DenseNet161|116.3MB | 8.863 | 78.6 | 94.1 |
| DenseNet201| 84.6MB | 8.173 | 77.6 | 93.7 |
| ShuffleNetV2 | 9.0MB | 10.941 | 68.8 | 88.5 |
| HRNet_W18 | 21.29MB | 7.368 (V100 GPU) | 76.9 | 93.4 |
## Object Detection Models
...@@ -41,6 +42,7 @@
|FasterRCNN-ResNet50_vd-FPN|168.7MB | 45.773 | 38.9 |
|FasterRCNN-ResNet101-FPN| 251.7MB | 55.782 | 38.7 |
|FasterRCNN-ResNet101_vd-FPN |252MB | 58.785 | 40.5 |
|FasterRCNN-HRNet_W18-FPN |115.5MB | 57.11 | 36 |
|YOLOv3-DarkNet53|252.4MB | 21.944 | 38.9 |
|YOLOv3-MobileNetv1 |101.2MB | 12.771 | 29.3 |
|YOLOv3-MobileNetv3|94.6MB | - | 31.6 |
...@@ -49,4 +51,3 @@
## Instance Segmentation Models
> All metrics in the table are measured on the MSCOCO dataset.
...@@ -36,5 +36,6 @@ DenseNet121 = cv.models.DenseNet121
DenseNet161 = cv.models.DenseNet161
DenseNet201 = cv.models.DenseNet201
ShuffleNetV2 = cv.models.ShuffleNetV2
HRNet_W18 = cv.models.HRNet_W18
transforms = cv.transforms.cls_transforms
...@@ -34,11 +34,13 @@ from .classifier import DenseNet121
from .classifier import DenseNet161
from .classifier import DenseNet201
from .classifier import ShuffleNetV2
from .classifier import HRNet_W18
from .base import BaseAPI
from .yolo_v3 import YOLOv3
from .faster_rcnn import FasterRCNN
from .mask_rcnn import MaskRCNN
from .unet import UNet
from .deeplabv3p import DeepLabv3p
from .hrnet import HRNet
from .load_model import load_model
from .slim import prune
...@@ -79,9 +79,9 @@ class BaseAPI:
return int(batch_size // len(self.places))
else:
raise Exception("Please support correct batch_size, \
which can be divided by available cards({}) in {}"
.format(paddlex.env_info['num'], paddlex.env_info[
'place']))
def build_program(self):
# Build the training network
...@@ -198,6 +198,8 @@ class BaseAPI:
backbone = self.backbone
else:
backbone = self.__class__.__name__
if backbone == "HRNet":
backbone = backbone + "_W{}".format(self.width)
pretrain_weights = get_pretrain_weights(
pretrain_weights, self.model_type, backbone, pretrain_dir)
if startup_prog is None:
...@@ -210,8 +212,8 @@ class BaseAPI:
paddlex.utils.utils.load_pretrain_weights(
self.exe, self.train_prog, resume_checkpoint, resume=True)
if not osp.exists(osp.join(resume_checkpoint, "model.yml")):
raise Exception("There's not model.yml in {}".format(
resume_checkpoint))
with open(osp.join(resume_checkpoint, "model.yml")) as f:
info = yaml.load(f.read(), Loader=yaml.Loader)
self.completed_epochs = info['completed_epochs']
...@@ -269,13 +271,13 @@ class BaseAPI:
except:
pass
if hasattr(self, 'test_transforms'):
if hasattr(self.test_transforms, 'to_rgb'):
if self.test_transforms.to_rgb:
info['TransformsMode'] = 'RGB'
else:
info['TransformsMode'] = 'BGR'
if self.test_transforms is not None:
info['Transforms'] = list()
for op in self.test_transforms.transforms:
...@@ -362,8 +364,8 @@ class BaseAPI:
# Flag file marking that the model was saved successfully
open(osp.join(save_dir, '.success'), 'w').close()
logging.info("Model for inference deploy saved in {}.".format(
save_dir))
def train_loop(self,
num_epochs,
...@@ -377,7 +379,8 @@ class BaseAPI:
early_stop=False,
early_stop_patience=5):
if train_dataset.num_samples < train_batch_size:
raise Exception(
'The amount of training datset must be larger than batch size.')
if not osp.isdir(save_dir):
if osp.exists(save_dir):
os.remove(save_dir)
...@@ -415,8 +418,8 @@ class BaseAPI:
build_strategy=build_strategy,
exec_strategy=exec_strategy)
total_num_steps = math.floor(train_dataset.num_samples /
train_batch_size)
num_steps = 0
time_stat = list()
time_train_one_epoch = None
...@@ -430,8 +433,8 @@ class BaseAPI:
if self.model_type == 'detector':
eval_batch_size = self._get_single_card_bs(train_batch_size)
if eval_dataset is not None:
total_num_steps_eval = math.ceil(eval_dataset.num_samples /
eval_batch_size)
if use_vdl:
# VisualDL component
...@@ -473,7 +476,9 @@ class BaseAPI:
if use_vdl:
for k, v in step_metrics.items():
log_writer.add_scalar(
'Metrics/Training(Step): {}'.format(k), v,
num_steps)
# Estimate the remaining time
avg_step_time = np.mean(time_stat)
...@@ -481,11 +486,12 @@ class BaseAPI:
eta = (num_epochs - i - 1) * time_train_one_epoch + (
total_num_steps - step - 1) * avg_step_time
else:
eta = ((num_epochs - i) * total_num_steps - step - 1
) * avg_step_time
if time_eval_one_epoch is not None:
eval_eta = (
total_eval_times - i // save_interval_epochs
) * time_eval_one_epoch
else:
eval_eta = (
total_eval_times - i // save_interval_epochs
...@@ -495,10 +501,11 @@ class BaseAPI:
logging.info(
"[TRAIN] Epoch={}/{}, Step={}/{}, {}, time_each_step={}s, eta={}"
.format(i + 1, num_epochs, step + 1, total_num_steps,
dict2str(step_metrics),
round(avg_step_time, 2), eta_str))
train_metrics = OrderedDict(
zip(list(self.train_outputs.keys()), np.mean(
records, axis=0)))
logging.info('[TRAIN] Epoch {} finished, {} .'.format(
i + 1, dict2str(train_metrics)))
time_train_one_epoch = time.time() - epoch_start_time
...@@ -534,7 +541,8 @@ class BaseAPI:
if isinstance(v, np.ndarray):
if v.size > 1:
continue
log_writer.add_scalar(
"Metrics/Eval(Epoch): {}".format(k), v, i + 1)
self.save_model(save_dir=current_save_dir)
time_eval_one_epoch = time.time() - eval_epoch_start_time
eval_epoch_start_time = time.time()
......
...@@ -40,8 +40,8 @@ class BaseClassifier(BaseAPI):
self.init_params = locals()
super(BaseClassifier, self).__init__('classifier')
if not hasattr(paddlex.cv.nets, str.lower(model_name)):
raise Exception("ERROR: There's no model named {}.".format(
model_name))
self.model_name = model_name
self.labels = None
self.num_classes = num_classes
...@@ -218,15 +218,14 @@ class BaseClassifier(BaseAPI):
num_pad_samples = batch_size - num_samples
pad_images = np.tile(images[0:1], (num_pad_samples, 1, 1, 1))
images = np.concatenate([images, pad_images])
outputs = self.exe.run(self.parallel_test_prog,
feed={'image': images},
fetch_list=list(self.test_outputs.values()))
outputs = [outputs[0][:num_samples]]
true_labels.extend(labels)
pred_scores.extend(outputs[0].tolist())
logging.debug("[EVAL] Epoch={}, Step={}/{}".format(epoch_id, step +
1, total_steps))
pred_top1_label = np.argsort(pred_scores)[:, -1]
pred_topk_label = np.argsort(pred_scores)[:, -k:]
...@@ -263,8 +262,7 @@ class BaseClassifier(BaseAPI):
self.arrange_transforms(
transforms=self.test_transforms, mode='test')
im = self.test_transforms(img_file)
result = self.exe.run(self.test_prog,
feed={'image': im},
fetch_list=list(self.test_outputs.values()))
pred_label = np.argsort(result[0][0])[::-1][:true_topk]
...@@ -400,3 +398,9 @@ class ShuffleNetV2(BaseClassifier):
def __init__(self, num_classes=1000):
super(ShuffleNetV2, self).__init__(
model_name='ShuffleNetV2', num_classes=num_classes)
class HRNet_W18(BaseClassifier):
def __init__(self, num_classes=1000):
super(HRNet_W18, self).__init__(
model_name='HRNet_W18', num_classes=num_classes)
...@@ -32,7 +32,7 @@ class FasterRCNN(BaseAPI):
Args:
num_classes (int): Number of classes, including the background class. Defaults to 81.
backbone (str): Backbone network of FasterRCNN; one of ['ResNet18', 'ResNet50',
'ResNet50_vd', 'ResNet101', 'ResNet101_vd', 'HRNet_W18']. Defaults to 'ResNet50'.
with_fpn (bool): Whether to use the FPN structure. Defaults to True.
aspect_ratios (list): Candidate aspect ratios for anchor generation. Defaults to [0.5, 1.0, 2.0].
anchor_sizes (list): Candidate anchor sizes. Defaults to [32, 64, 128, 256, 512].
...@@ -47,7 +47,8 @@ class FasterRCNN(BaseAPI):
self.init_params = locals()
super(FasterRCNN, self).__init__('detector')
backbones = [
'ResNet18', 'ResNet50', 'ResNet50_vd', 'ResNet101', 'ResNet101_vd',
'HRNet_W18'
]
assert backbone in backbones, "backbone should be one of {}".format(
backbones)
...@@ -79,6 +80,12 @@ class FasterRCNN(BaseAPI):
layers = 101
variant = 'd'
norm_type = 'affine_channel'
elif backbone_name == 'HRNet_W18':
backbone = paddlex.cv.nets.hrnet.HRNet(
width=18, freeze_norm=True, norm_decay=0., freeze_at=0)
if self.with_fpn is False:
self.with_fpn = True
return backbone
if self.with_fpn:
backbone = paddlex.cv.nets.resnet.ResNet(
norm_type='bn' if norm_type is None else norm_type,
...@@ -227,7 +234,9 @@ class FasterRCNN(BaseAPI):
# Build the training, evaluation and test networks
self.build_program()
fuse_bn = True
if self.with_fpn and self.backbone in [
'ResNet18', 'ResNet50', 'HRNet_W18'
]:
fuse_bn = False
self.net_initialize(
startup_prog=fluid.default_startup_program(),
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
import paddle.fluid as fluid
import paddlex
from collections import OrderedDict
from .deeplabv3p import DeepLabv3p
class HRNet(DeepLabv3p):
"""实现HRNet网络的构建并进行训练、评估、预测和模型导出。
Args:
num_classes (int): 类别数。
width (int): 高分辨率分支中特征层的通道数量。默认值为18。可选择取值为[18, 30, 32, 40, 44, 48, 60, 64]。
use_bce_loss (bool): 是否使用bce loss作为网络的损失函数,只能用于两类分割。可与dice loss同时使用。默认False。
use_dice_loss (bool): 是否使用dice loss作为网络的损失函数,只能用于两类分割,可与bce loss同时使用。
当use_bce_loss和use_dice_loss都为False时,使用交叉熵损失函数。默认False。
class_weight (list/str): 交叉熵损失函数各类损失的权重。当class_weight为list的时候,长度应为
num_classes。当class_weight为str时, weight.lower()应为'dynamic',这时会根据每一轮各类像素的比重
自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,
即平时使用的交叉熵损失函数。
ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。默认255。
Raises:
ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。
ValueError: class_weight为list, 但长度不等于num_class。
class_weight为str, 但class_weight.low()不等于dynamic。
TypeError: class_weight不为None时,其类型不是list或str。
"""
def __init__(self,
num_classes=2,
width=18,
use_bce_loss=False,
use_dice_loss=False,
class_weight=None,
ignore_index=255):
self.init_params = locals()
super(DeepLabv3p, self).__init__('segmenter')
# dice loss and bce loss are only applicable to binary segmentation
if num_classes > 2 and (use_bce_loss or use_dice_loss):
raise ValueError(
"dice loss and bce loss is only applicable to binary classfication"
)
if class_weight is not None:
if isinstance(class_weight, list):
if len(class_weight) != num_classes:
raise ValueError(
"Length of class_weight should be equal to number of classes"
)
elif isinstance(class_weight, str):
if class_weight.lower() != 'dynamic':
raise ValueError(
"if class_weight is string, must be dynamic!")
else:
raise TypeError(
'Expect class_weight is a list or string but receive {}'.
format(type(class_weight)))
self.num_classes = num_classes
self.width = width
self.use_bce_loss = use_bce_loss
self.use_dice_loss = use_dice_loss
self.class_weight = class_weight
self.ignore_index = ignore_index
self.labels = None
def build_net(self, mode='train'):
model = paddlex.cv.nets.segmentation.HRNet(
self.num_classes,
width=self.width,
mode=mode,
use_bce_loss=self.use_bce_loss,
use_dice_loss=self.use_dice_loss,
class_weight=self.class_weight,
ignore_index=self.ignore_index)
inputs = model.generate_inputs()
model_out = model.build_net(inputs)
outputs = OrderedDict()
if mode == 'train':
self.optimizer.minimize(model_out)
outputs['loss'] = model_out
elif mode == 'eval':
outputs['loss'] = model_out[0]
outputs['pred'] = model_out[1]
outputs['label'] = model_out[2]
outputs['mask'] = model_out[3]
else:
outputs['pred'] = model_out[0]
outputs['logit'] = model_out[1]
return inputs, outputs
def default_optimizer(self,
learning_rate,
num_epochs,
num_steps_each_epoch,
lr_decay_power=0.9):
decay_step = num_epochs * num_steps_each_epoch
lr_decay = fluid.layers.polynomial_decay(
learning_rate,
decay_step,
end_learning_rate=0,
power=lr_decay_power)
optimizer = fluid.optimizer.Momentum(
lr_decay,
momentum=0.9,
regularization=fluid.regularizer.L2Decay(
regularization_coeff=5e-04))
return optimizer
def train(self,
num_epochs,
train_dataset,
train_batch_size=2,
eval_dataset=None,
save_interval_epochs=1,
log_interval_steps=2,
save_dir='output',
pretrain_weights='IMAGENET',
optimizer=None,
learning_rate=0.01,
lr_decay_power=0.9,
use_vdl=False,
sensitivities_file=None,
eval_metric_loss=0.05,
early_stop=False,
early_stop_patience=5,
resume_checkpoint=None):
"""训练。
Args:
num_epochs (int): 训练迭代轮数。
train_dataset (paddlex.datasets): 训练数据读取器。
train_batch_size (int): 训练数据batch大小。同时作为验证数据batch大小。默认2。
eval_dataset (paddlex.datasets): 评估数据读取器。
save_interval_epochs (int): 模型保存间隔(单位:迭代轮数)。默认为1。
log_interval_steps (int): 训练日志输出间隔(单位:迭代次数)。默认为2。
save_dir (str): 模型保存路径。默认'output'。
pretrain_weights (str): 若指定为路径时,则加载路径下预训练模型;若为字符串'IMAGENET',
则自动下载在IMAGENET图片数据上预训练的模型权重;若为None,则不使用预训练模型。默认为'IMAGENET'。
optimizer (paddle.fluid.optimizer): 优化器。当改参数为None时,使用默认的优化器:使用
fluid.optimizer.Momentum优化方法,polynomial的学习率衰减策略。
learning_rate (float): 默认优化器的初始学习率。默认0.01。
lr_decay_power (float): 默认优化器学习率多项式衰减系数。默认0.9。
use_vdl (bool): 是否使用VisualDL进行可视化。默认False。
sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',
则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
Raises:
ValueError: 模型从inference model进行加载。
"""
return super(HRNet, self).train(
num_epochs, train_dataset, train_batch_size, eval_dataset,
save_interval_epochs, log_interval_steps, save_dir,
pretrain_weights, optimizer, learning_rate, lr_decay_power,
use_vdl, sensitivities_file, eval_metric_loss, early_stop,
early_stop_patience, resume_checkpoint)
...@@ -32,7 +32,7 @@ class MaskRCNN(FasterRCNN):
Args:
num_classes (int): Number of classes, including the background class. Defaults to 81.
backbone (str): Backbone network of MaskRCNN; one of ['ResNet18', 'ResNet50',
'ResNet50_vd', 'ResNet101', 'ResNet101_vd', 'HRNet_W18']. Defaults to 'ResNet50'.
with_fpn (bool): Whether to use the FPN structure. Defaults to True.
aspect_ratios (list): Candidate aspect ratios for anchor generation. Defaults to [0.5, 1.0, 2.0].
anchor_sizes (list): Candidate anchor sizes. Defaults to [32, 64, 128, 256, 512].
...@@ -46,7 +46,8 @@ class MaskRCNN(FasterRCNN):
anchor_sizes=[32, 64, 128, 256, 512]):
self.init_params = locals()
backbones = [
'ResNet18', 'ResNet50', 'ResNet50_vd', 'ResNet101', 'ResNet101_vd',
'HRNet_W18'
]
assert backbone in backbones, "backbone should be one of {}".format(
backbones)
...@@ -194,7 +195,9 @@ class MaskRCNN(FasterRCNN):
# Build the training, evaluation and test networks
self.build_program()
fuse_bn = True
if self.with_fpn and self.backbone in [
'ResNet18', 'ResNet50', 'HRNet_W18'
]:
fuse_bn = False
self.net_initialize(
startup_prog=fluid.default_startup_program(),
......
...@@ -56,6 +56,20 @@ image_pretrain = {
'https://paddle-imagenet-models-name.bj.bcebos.com/Xception65_deeplab_pretrained.tar',
'ShuffleNetV2':
'https://paddle-imagenet-models-name.bj.bcebos.com/ShuffleNetV2_pretrained.tar',
'HRNet_W18':
'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W18_C_pretrained.tar',
'HRNet_W30':
'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W30_C_pretrained.tar',
'HRNet_W32':
'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W32_C_pretrained.tar',
'HRNet_W40':
'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W40_C_pretrained.tar',
'HRNet_W48':
'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W48_C_pretrained.tar',
'HRNet_W60':
'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W60_C_pretrained.tar',
'HRNet_W64':
'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W64_C_pretrained.tar',
}
coco_pretrain = {
......
...@@ -23,6 +23,7 @@ from .segmentation import DeepLabv3p
from .xception import Xception
from .densenet import DenseNet
from .shufflenet_v2 import ShuffleNetV2
from .hrnet import HRNet
def resnet18(input, num_classes=1000):
...@@ -51,14 +52,20 @@ def resnet50_vd(input, num_classes=1000):
def resnet50_vd_ssld(input, num_classes=1000):
model = ResNet(
layers=50,
num_classes=num_classes,
variant='d',
lr_mult_list=[1.0, 0.1, 0.2, 0.2, 0.3])
return model(input)
def resnet101_vd_ssld(input, num_classes=1000):
model = ResNet(
layers=101,
num_classes=num_classes,
variant='d',
lr_mult_list=[1.0, 0.1, 0.2, 0.2, 0.3])
return model(input)
...@@ -93,13 +100,17 @@ def mobilenetv3_large(input, num_classes=1000):
def mobilenetv3_small_ssld(input, num_classes=1000):
model = MobileNetV3(
num_classes=num_classes,
model_name='small',
lr_mult_list=[0.25, 0.25, 0.5, 0.5, 0.75])
return model(input)
def mobilenetv3_large_ssld(input, num_classes=1000):
model = MobileNetV3(
num_classes=num_classes,
model_name='large',
lr_mult_list=[0.25, 0.25, 0.5, 0.5, 0.75])
return model(input)
...@@ -133,6 +144,12 @@ def densenet201(input, num_classes=1000):
model = DenseNet(layers=201, num_classes=num_classes)
return model(input)
def shufflenetv2(input, num_classes=1000):
model = ShuffleNetV2(num_classes=num_classes)
return model(input)
def hrnet_w18(input, num_classes=1000):
model = HRNet(width=18, num_classes=num_classes)
return model(input)
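The following is an illustrative sketch, not part of the diff: exercising the new `hrnet_w18` builder inside a static-graph program. It assumes PaddlePaddle 1.x (`paddle.fluid`) and that the function is importable from `paddlex.cv.nets` as added above.
```python
import paddle.fluid as fluid
from paddlex.cv.nets import hrnet_w18  # exported by the __init__.py change above

main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    image = fluid.data(name='image', shape=[None, 3, 224, 224], dtype='float32')
    # Builds HRNet(width=18, num_classes=...) and returns the classification logits.
    logits = hrnet_w18(image, num_classes=1000)
```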
...@@ -68,13 +68,14 @@ class DarkNet(object):
bias_attr=False)
bn_name = name + ".bn"
if self.num_classes:
regularizer = None
else:
regularizer = L2Decay(float(self.norm_decay))
bn_param_attr = ParamAttr(
regularizer=regularizer, name=bn_name + '.scale')
bn_bias_attr = ParamAttr(
regularizer=regularizer, name=bn_name + '.offset')
out = fluid.layers.batch_norm(
input=conv,
......
...@@ -21,7 +21,7 @@ import copy
from paddle import fluid
from .fpn import (FPN, HRFPN)
from .rpn_head import (RPNHead, FPNRPNHead)
from .roi_extractor import (RoIAlign, FPNRoIAlign)
from .bbox_head import (BBoxHead, TwoFCHead)
...@@ -82,6 +82,11 @@ class FasterRCNN(object):
self.backbone = backbone
self.mode = mode
if with_fpn and fpn is None:
if self.backbone.__class__.__name__.startswith('HRNet'):
fpn = HRFPN()
fpn.min_level = 2
fpn.max_level = 6
else:
fpn = FPN()
self.fpn = fpn
self.num_classes = num_classes
......
...@@ -23,7 +23,7 @@ from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Xavier
from paddle.fluid.regularizer import L2Decay
__all__ = ['FPN', 'HRFPN']
def ConvNorm(input,
...@@ -219,8 +219,8 @@ class FPN(object):
body_name = body_name_list[i]
body_input = body_dict[body_name]
top_output = self.fpn_inner_output[i - 1]
fpn_inner_single = self._add_topdown_lateral(body_name, body_input,
top_output)
self.fpn_inner_output[i] = fpn_inner_single
fpn_dict = {}
fpn_name_list = []
...@@ -293,3 +293,107 @@ class FPN(object):
spatial_scale.insert(0, spatial_scale[0] * 0.5)
res_dict = OrderedDict([(k, fpn_dict[k]) for k in fpn_name_list])
return res_dict, spatial_scale
class HRFPN(object):
"""
HRNet, see https://arxiv.org/abs/1908.07919
Args:
num_chan (int): number of feature channels
pooling_type (str): pooling type of downsampling
share_conv (bool): whether to share conv for different layers' reduction
spatial_scale (list): feature map scaling factor
"""
def __init__(
self,
num_chan=256,
pooling_type="avg",
share_conv=False,
spatial_scale=[1. / 64, 1. / 32, 1. / 16, 1. / 8, 1. / 4], ):
self.num_chan = num_chan
self.pooling_type = pooling_type
self.share_conv = share_conv
self.spatial_scale = spatial_scale
def get_output(self, body_dict):
num_out = len(self.spatial_scale)
body_name_list = list(body_dict.keys())
num_backbone_stages = len(body_name_list)
outs = []
outs.append(body_dict[body_name_list[0]])
# resize
for i in range(1, len(body_dict)):
resized = self.resize_input_tensor(body_dict[body_name_list[i]],
outs[0], 2**i)
outs.append(resized)
# concat
out = fluid.layers.concat(outs, axis=1)
# reduction
out = fluid.layers.conv2d(
input=out,
num_filters=self.num_chan,
filter_size=1,
stride=1,
padding=0,
param_attr=ParamAttr(name='hrfpn_reduction_weights'),
bias_attr=False)
# conv
outs = [out]
for i in range(1, num_out):
outs.append(
self.pooling(
out,
size=2**i,
stride=2**i,
pooling_type=self.pooling_type))
outputs = []
for i in range(num_out):
conv_name = "shared_fpn_conv" if self.share_conv else "shared_fpn_conv_" + str(
i)
conv = fluid.layers.conv2d(
input=outs[i],
num_filters=self.num_chan,
filter_size=3,
stride=1,
padding=1,
param_attr=ParamAttr(name=conv_name + "_weights"),
bias_attr=False)
outputs.append(conv)
for idx in range(0, num_out - len(body_name_list)):
body_name_list.append("fpn_res5_sum_subsampled_{}x".format(2**(
idx + 1)))
outputs = outputs[::-1]
body_name_list = body_name_list[::-1]
res_dict = OrderedDict([(body_name_list[k], outputs[k])
for k in range(len(body_name_list))])
return res_dict, self.spatial_scale
def resize_input_tensor(self, body_input, ref_output, scale):
shape = fluid.layers.shape(ref_output)
shape_hw = fluid.layers.slice(shape, axes=[0], starts=[2], ends=[4])
out_shape_ = shape_hw
out_shape = fluid.layers.cast(out_shape_, dtype='int32')
out_shape.stop_gradient = True
body_output = fluid.layers.resize_bilinear(
body_input, scale=scale, out_shape=out_shape)
return body_output
def pooling(self, input, size, stride, pooling_type):
pool = fluid.layers.pool2d(
input=input,
pool_size=size,
pool_stride=stride,
pool_type=pooling_type)
return pool
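For orientation, an illustrative sketch (not part of the diff) of how the detection graphs in this PR wire HRFPN in when the backbone is an HRNet variant; the import path is an assumption based on this file's location.
```python
# Assumed module path: this fpn.py lives under paddlex/cv/nets/detection/.
from paddlex.cv.nets.detection.fpn import HRFPN

fpn = HRFPN(num_chan=256, pooling_type="avg", share_conv=False)
# FasterRCNN/MaskRCNN set the pyramid range explicitly for HRNet backbones:
fpn.min_level = 2
fpn.max_level = 6
```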
...@@ -21,7 +21,7 @@ import copy
import paddle.fluid as fluid
from .fpn import (FPN, HRFPN)
from .rpn_head import (RPNHead, FPNRPNHead)
from .roi_extractor import (RoIAlign, FPNRoIAlign)
from .bbox_head import (BBoxHead, TwoFCHead)
...@@ -92,8 +92,12 @@ class MaskRCNN(object):
self.backbone = backbone
self.mode = mode
if with_fpn and fpn is None:
if self.backbone.__class__.__name__.startswith('HRNet'):
fpn = HRFPN()
fpn.min_level = 2
fpn.max_level = 6
else:
fpn = FPN(num_chan=num_chan,
min_level=min_level,
max_level=max_level,
spatial_scale=spatial_scale)
...@@ -200,7 +204,7 @@ class MaskRCNN(object):
bg_thresh_hi=self.bg_thresh_hi,
bg_thresh_lo=self.bg_thresh_lo,
bbox_reg_weights=self.bbox_reg_weights,
class_nums=self.num_classes,
use_random=self.rpn_head.use_random)
rois = outputs[0]
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.framework import Variable
from paddle.fluid.regularizer import L2Decay
from numbers import Integral
from paddle.fluid.initializer import MSRA
import math
__all__ = ['HRNet']
class HRNet(object):
def __init__(self,
width=40,
has_se=False,
freeze_at=0,
norm_type='bn',
freeze_norm=False,
norm_decay=0.,
feature_maps=[2, 3, 4, 5],
num_classes=None):
super(HRNet, self).__init__()
if isinstance(feature_maps, Integral):
feature_maps = [feature_maps]
assert 0 <= freeze_at <= 4, "freeze_at should be 0, 1, 2, 3 or 4"
assert len(feature_maps) > 0, "need one or more feature maps"
assert norm_type in ['bn', 'sync_bn']
self.width = width
self.has_se = has_se
self.channels = {
18: [[18, 36], [18, 36, 72], [18, 36, 72, 144]],
30: [[30, 60], [30, 60, 120], [30, 60, 120, 240]],
32: [[32, 64], [32, 64, 128], [32, 64, 128, 256]],
40: [[40, 80], [40, 80, 160], [40, 80, 160, 320]],
44: [[44, 88], [44, 88, 176], [44, 88, 176, 352]],
48: [[48, 96], [48, 96, 192], [48, 96, 192, 384]],
60: [[60, 120], [60, 120, 240], [60, 120, 240, 480]],
64: [[64, 128], [64, 128, 256], [64, 128, 256, 512]],
}
self.freeze_at = freeze_at
self.norm_type = norm_type
self.norm_decay = norm_decay
self.freeze_norm = freeze_norm
self.feature_maps = feature_maps
self.num_classes = num_classes
self.end_points = []
return
def net(self, input, class_dim=1000):
width = self.width
channels_2, channels_3, channels_4 = self.channels[width]
num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3
x = self.conv_bn_layer(
input=input,
filter_size=3,
num_filters=64,
stride=2,
if_act=True,
name='layer1_1')
x = self.conv_bn_layer(
input=x,
filter_size=3,
num_filters=64,
stride=2,
if_act=True,
name='layer1_2')
la1 = self.layer1(x, name='layer2')
tr1 = self.transition_layer([la1], [256], channels_2, name='tr1')
st2 = self.stage(tr1, num_modules_2, channels_2, name='st2')
tr2 = self.transition_layer(st2, channels_2, channels_3, name='tr2')
st3 = self.stage(tr2, num_modules_3, channels_3, name='st3')
tr3 = self.transition_layer(st3, channels_3, channels_4, name='tr3')
st4 = self.stage(tr3, num_modules_4, channels_4, name='st4')
# classification
if self.num_classes:
last_cls = self.last_cls_out(x=st4, name='cls_head')
y = last_cls[0]
last_num_filters = [256, 512, 1024]
for i in range(3):
y = fluid.layers.elementwise_add(
last_cls[i + 1],
self.conv_bn_layer(
input=y,
filter_size=3,
num_filters=last_num_filters[i],
stride=2,
name='cls_head_add' + str(i + 1)))
y = self.conv_bn_layer(
input=y,
filter_size=1,
num_filters=2048,
stride=1,
name='cls_head_last_conv')
pool = fluid.layers.pool2d(
input=y, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(
input=pool,
size=class_dim,
param_attr=ParamAttr(
name='fc_weights',
initializer=fluid.initializer.Uniform(-stdv, stdv)),
bias_attr=ParamAttr(name='fc_offset'))
return out
# segmentation
if self.feature_maps == "stage4":
return st4
self.end_points = st4
return st4[-1]
def layer1(self, input, name=None):
conv = input
for i in range(4):
conv = self.bottleneck_block(
conv,
num_filters=64,
downsample=True if i == 0 else False,
name=name + '_' + str(i + 1))
return conv
def transition_layer(self, x, in_channels, out_channels, name=None):
num_in = len(in_channels)
num_out = len(out_channels)
out = []
for i in range(num_out):
if i < num_in:
if in_channels[i] != out_channels[i]:
residual = self.conv_bn_layer(
x[i],
filter_size=3,
num_filters=out_channels[i],
name=name + '_layer_' + str(i + 1))
out.append(residual)
else:
out.append(x[i])
else:
residual = self.conv_bn_layer(
x[-1],
filter_size=3,
num_filters=out_channels[i],
stride=2,
name=name + '_layer_' + str(i + 1))
out.append(residual)
return out
def branches(self, x, block_num, channels, name=None):
out = []
for i in range(len(channels)):
residual = x[i]
for j in range(block_num):
residual = self.basic_block(
residual,
channels[i],
name=name + '_branch_layer_' + str(i + 1) + '_' +
str(j + 1))
out.append(residual)
return out
def fuse_layers(self, x, channels, multi_scale_output=True, name=None):
out = []
for i in range(len(channels) if multi_scale_output else 1):
residual = x[i]
if self.feature_maps == "stage4":
shape = fluid.layers.shape(residual)
width = shape[-1]
height = shape[-2]
for j in range(len(channels)):
if j > i:
y = self.conv_bn_layer(
x[j],
filter_size=1,
num_filters=channels[i],
if_act=False,
name=name + '_layer_' + str(i + 1) + '_' + str(j + 1))
if self.feature_maps == "stage4":
y = fluid.layers.resize_bilinear(
input=y, out_shape=[height, width])
else:
y = fluid.layers.resize_nearest(
input=y, scale=2**(j - i))
residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None)
elif j < i:
y = x[j]
for k in range(i - j):
if k == i - j - 1:
y = self.conv_bn_layer(
y,
filter_size=3,
num_filters=channels[i],
stride=2,
if_act=False,
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1) + '_' + str(k + 1))
else:
y = self.conv_bn_layer(
y,
filter_size=3,
num_filters=channels[j],
stride=2,
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1) + '_' + str(k + 1))
residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None)
residual = fluid.layers.relu(residual)
out.append(residual)
return out
def high_resolution_module(self,
x,
channels,
multi_scale_output=True,
name=None):
residual = self.branches(x, 4, channels, name=name)
out = self.fuse_layers(
residual,
channels,
multi_scale_output=multi_scale_output,
name=name)
return out
def stage(self,
x,
num_modules,
channels,
multi_scale_output=True,
name=None):
out = x
for i in range(num_modules):
if i == num_modules - 1 and multi_scale_output == False:
out = self.high_resolution_module(
out,
channels,
multi_scale_output=False,
name=name + '_' + str(i + 1))
else:
out = self.high_resolution_module(
out, channels, name=name + '_' + str(i + 1))
return out
def last_cls_out(self, x, name=None):
out = []
num_filters_list = [32, 64, 128, 256]
for i in range(len(x)):
out.append(
self.bottleneck_block(
input=x[i],
num_filters=num_filters_list[i],
name=name + 'conv_' + str(i + 1),
downsample=True))
return out
def basic_block(self,
input,
num_filters,
stride=1,
downsample=False,
name=None):
residual = input
conv = self.conv_bn_layer(
input=input,
filter_size=3,
num_filters=num_filters,
stride=stride,
name=name + '_conv1')
conv = self.conv_bn_layer(
input=conv,
filter_size=3,
num_filters=num_filters,
if_act=False,
name=name + '_conv2')
if downsample:
residual = self.conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_filters,
if_act=False,
name=name + '_downsample')
if self.has_se:
conv = self.squeeze_excitation(
input=conv,
num_channels=num_filters,
reduction_ratio=16,
name=name + '_fc')
return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
def bottleneck_block(self,
input,
num_filters,
stride=1,
downsample=False,
name=None):
residual = input
conv = self.conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_filters,
name=name + '_conv1')
conv = self.conv_bn_layer(
input=conv,
filter_size=3,
num_filters=num_filters,
stride=stride,
name=name + '_conv2')
conv = self.conv_bn_layer(
input=conv,
filter_size=1,
num_filters=num_filters * 4,
if_act=False,
name=name + '_conv3')
if downsample:
residual = self.conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_filters * 4,
if_act=False,
name=name + '_downsample')
if self.has_se:
conv = self.squeeze_excitation(
input=conv,
num_channels=num_filters * 4,
reduction_ratio=16,
name=name + '_fc')
return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
def squeeze_excitation(self,
input,
num_channels,
reduction_ratio,
name=None):
pool = fluid.layers.pool2d(
input=input, pool_size=0, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
squeeze = fluid.layers.fc(
input=pool,
size=num_channels / reduction_ratio,
act='relu',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
name=name + '_sqz_weights'),
bias_attr=ParamAttr(name=name + '_sqz_offset'))
stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0)
excitation = fluid.layers.fc(
input=squeeze,
size=num_channels,
act='sigmoid',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
name=name + '_exc_weights'),
bias_attr=ParamAttr(name=name + '_exc_offset'))
scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)
return scale
def conv_bn_layer(self,
input,
filter_size,
num_filters,
stride=1,
padding=1,
num_groups=1,
if_act=True,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=num_groups,
act=None,
param_attr=ParamAttr(
initializer=MSRA(), name=name + '_weights'),
bias_attr=False)
bn_name = name + '_bn'
bn = self._bn(input=conv, bn_name=bn_name)
if if_act:
bn = fluid.layers.relu(bn)
return bn
def _bn(self, input, act=None, bn_name=None):
norm_lr = 0. if self.freeze_norm else 1.
norm_decay = self.norm_decay
if self.num_classes or self.feature_maps == "stage4":
regularizer = None
pattr_initializer = fluid.initializer.Constant(1.0)
battr_initializer = fluid.initializer.Constant(0.0)
else:
regularizer = L2Decay(norm_decay)
pattr_initializer = None
battr_initializer = None
pattr = ParamAttr(
name=bn_name + '_scale',
learning_rate=norm_lr,
regularizer=regularizer,
initializer=pattr_initializer)
battr = ParamAttr(
name=bn_name + '_offset',
learning_rate=norm_lr,
regularizer=regularizer,
initializer=battr_initializer)
global_stats = True if self.freeze_norm else False
out = fluid.layers.batch_norm(
input=input,
act=act,
name=bn_name + '.output.1',
param_attr=pattr,
bias_attr=battr,
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance',
use_global_stats=global_stats)
scale = fluid.framework._get_var(pattr.name)
bias = fluid.framework._get_var(battr.name)
if self.freeze_norm:
scale.stop_gradient = True
bias.stop_gradient = True
return out
def __call__(self, input):
assert isinstance(input, Variable)
if isinstance(self.feature_maps, (list, tuple)):
assert not (set(self.feature_maps) - set([2, 3, 4, 5])), \
"feature maps {} not in [2, 3, 4, 5]".format(self.feature_maps)
res_endpoints = []
res = input
feature_maps = self.feature_maps
out = self.net(input)
if self.num_classes or self.feature_maps == "stage4":
return out
for i in feature_maps:
res = self.end_points[i - 2]
if i in self.feature_maps:
res_endpoints.append(res)
if self.freeze_at >= i:
res.stop_gradient = True
return OrderedDict([('res{}_sum'.format(self.feature_maps[idx]), feat)
for idx, feat in enumerate(res_endpoints)])
...@@ -79,10 +79,14 @@ class MobileNetV1(object):
bn_name = name + "_bn"
norm_decay = self.norm_decay
if self.num_classes:
regularizer = None
else:
regularizer = L2Decay(norm_decay)
bn_param_attr = ParamAttr(
regularizer=regularizer, name=bn_name + '_scale')
bn_bias_attr = ParamAttr(
regularizer=regularizer, name=bn_name + '_offset')
return fluid.layers.batch_norm(
input=conv,
act=act,
...@@ -189,11 +193,11 @@ class MobileNetV1(object):
if self.num_classes:
out = fluid.layers.pool2d(
input=out, pool_type='avg', global_pooling=True)
output = fluid.layers.fc(input=out,
size=self.num_classes,
param_attr=ParamAttr(
initializer=fluid.initializer.MSRA(),
name="fc7_weights"),
bias_attr=ParamAttr(name="fc7_offset"))
return output
......
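Beyond the yapf-style reformatting, the substantive change repeated in the MobileNetV1 hunks above and the MobileNetV3/ResNet hunks below is that L2 weight decay on the batch-norm scale/offset (and, for MobileNetV3, the conv weights) is skipped when the backbone is built with `num_classes` set, i.e. when it serves as a standalone classifier. A minimal sketch of that selection logic, with `num_classes` and `norm_decay` standing in for the backbone attributes and `bn_param_attrs` being a hypothetical helper name:

```python
# Sketch of the regularizer switch introduced in these hunks; only the
# selection logic itself is taken from the diff, the helper name is illustrative.
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay


def bn_param_attrs(bn_name, num_classes, norm_decay):
    # Classification use (num_classes set): no weight decay on BN parameters.
    # Detection/segmentation use: keep the L2Decay(norm_decay) regularizer.
    regularizer = None if num_classes else L2Decay(norm_decay)
    param_attr = ParamAttr(regularizer=regularizer, name=bn_name + '_scale')
    bias_attr = ParamAttr(regularizer=regularizer, name=bn_name + '_offset')
    return param_attr, bias_attr
```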
@@ -31,6 +31,7 @@ class MobileNetV3():
         with_extra_blocks (bool): if extra blocks should be added.
         extra_block_filters (list): number of filter for each extra block.
     """
+
     def __init__(self,
                  scale=1.0,
                  model_name='small',
@@ -113,10 +114,16 @@ class MobileNetV3():
         lr_idx = self.curr_stage // self.lr_interval
         lr_idx = min(lr_idx, len(self.lr_mult_list) - 1)
         lr_mult = self.lr_mult_list[lr_idx]
-        conv_param_attr = ParamAttr(name=name + '_weights',
+        if self.num_classes:
+            regularizer = None
+        else:
+            regularizer = L2Decay(self.conv_decay)
+        conv_param_attr = ParamAttr(
+            name=name + '_weights',
             learning_rate=lr_mult,
-            regularizer=L2Decay(self.conv_decay))
-        conv = fluid.layers.conv2d(input=input,
+            regularizer=regularizer)
+        conv = fluid.layers.conv2d(
+            input=input,
             num_filters=num_filters,
             filter_size=filter_size,
             stride=stride,
@@ -127,11 +134,12 @@ class MobileNetV3():
             param_attr=conv_param_attr,
             bias_attr=False)
         bn_name = name + '_bn'
-        bn_param_attr = ParamAttr(name=bn_name + "_scale",
-            regularizer=L2Decay(self.norm_decay))
-        bn_bias_attr = ParamAttr(name=bn_name + "_offset",
-            regularizer=L2Decay(self.norm_decay))
-        bn = fluid.layers.batch_norm(input=conv,
+        bn_param_attr = ParamAttr(
+            name=bn_name + "_scale", regularizer=L2Decay(self.norm_decay))
+        bn_bias_attr = ParamAttr(
+            name=bn_name + "_offset", regularizer=L2Decay(self.norm_decay))
+        bn = fluid.layers.batch_norm(
+            input=conv,
             param_attr=bn_param_attr,
             bias_attr=bn_bias_attr,
             moving_mean_name=bn_name + '_mean',
@@ -154,10 +162,8 @@ class MobileNetV3():
         lr_mult = self.lr_mult_list[lr_idx]
         num_mid_filter = int(num_out_filter // ratio)
-        pool = fluid.layers.pool2d(input=input,
-            pool_type='avg',
-            global_pooling=True,
-            use_cudnn=False)
+        pool = fluid.layers.pool2d(
+            input=input, pool_type='avg', global_pooling=True, use_cudnn=False)
         conv1 = fluid.layers.conv2d(
             input=pool,
             filter_size=1,
@@ -191,7 +197,8 @@ class MobileNetV3():
                        use_se=False,
                        name=None):
         input_data = input
-        conv0 = self._conv_bn_layer(input=input,
+        conv0 = self._conv_bn_layer(
+            input=input,
             filter_size=1,
             num_filters=num_mid_filter,
             stride=1,
@@ -201,7 +208,8 @@ class MobileNetV3():
             name=name + '_expand')
         if self.block_stride == 16 and stride == 2:
             self.end_points.append(conv0)
-        conv1 = self._conv_bn_layer(input=conv0,
+        conv1 = self._conv_bn_layer(
+            input=conv0,
             filter_size=filter_size,
             num_filters=num_mid_filter,
             stride=stride,
@@ -213,11 +221,11 @@ class MobileNetV3():
             name=name + '_depthwise')
         if use_se:
-            conv1 = self._se_block(input=conv1,
-                num_out_filter=num_mid_filter,
-                name=name + '_se')
+            conv1 = self._se_block(
+                input=conv1, num_out_filter=num_mid_filter, name=name + '_se')

-        conv2 = self._conv_bn_layer(input=conv1,
+        conv2 = self._conv_bn_layer(
+            input=conv1,
             filter_size=1,
             num_filters=num_out_filter,
             stride=1,
@@ -227,7 +235,8 @@ class MobileNetV3():
         if num_in_filter != num_out_filter or stride != 1:
             return conv2
         else:
-            return fluid.layers.elementwise_add(x=input_data, y=conv2, act=None)
+            return fluid.layers.elementwise_add(
+                x=input_data, y=conv2, act=None)

     def _extra_block_dw(self,
                         input,
@@ -235,14 +244,16 @@ class MobileNetV3():
                         num_filters2,
                         stride,
                         name=None):
-        pointwise_conv = self._conv_bn_layer(input=input,
+        pointwise_conv = self._conv_bn_layer(
+            input=input,
             filter_size=1,
             num_filters=int(num_filters1),
             stride=1,
             padding="SAME",
             act='relu6',
             name=name + "_extra1")
-        depthwise_conv = self._conv_bn_layer(input=pointwise_conv,
+        depthwise_conv = self._conv_bn_layer(
+            input=pointwise_conv,
             filter_size=3,
             num_filters=int(num_filters2),
             stride=stride,
@@ -251,7 +262,8 @@ class MobileNetV3():
             act='relu6',
             use_cudnn=False,
             name=name + "_extra2_dw")
-        normal_conv = self._conv_bn_layer(input=depthwise_conv,
+        normal_conv = self._conv_bn_layer(
+            input=depthwise_conv,
             filter_size=1,
             num_filters=int(num_filters2),
             stride=1,
@@ -282,7 +294,8 @@ class MobileNetV3():
                 self.block_stride *= layer_cfg[5]
                 if layer_cfg[5] == 2:
                     blocks.append(conv)
-            conv = self._residual_unit(input=conv,
+            conv = self._residual_unit(
+                input=conv,
                 num_in_filter=inplanes,
                 num_mid_filter=int(scale * layer_cfg[1]),
                 num_out_filter=int(scale * layer_cfg[2]),
@@ -298,7 +311,8 @@ class MobileNetV3():
         blocks.append(conv)

         if self.num_classes:
-            conv = self._conv_bn_layer(input=conv,
+            conv = self._conv_bn_layer(
+                input=conv,
                 filter_size=1,
                 num_filters=int(scale * self.cls_ch_squeeze),
                 stride=1,
@@ -308,7 +322,8 @@ class MobileNetV3():
                 act='hard_swish',
                 name='conv_last')
-            conv = fluid.layers.pool2d(input=conv,
+            conv = fluid.layers.pool2d(
+                input=conv,
                 pool_type='avg',
                 global_pooling=True,
                 use_cudnn=False)
@@ -333,7 +348,8 @@ class MobileNetV3():
             return blocks

         # extra block
-        conv_extra = self._conv_bn_layer(conv,
+        conv_extra = self._conv_bn_layer(
+            conv,
             filter_size=1,
             num_filters=int(scale * cfg[-1][1]),
             stride=1,
...
@@ -135,8 +135,10 @@ class ResNet(object):
             filter_size=filter_size,
             stride=stride,
             padding=padding,
-            param_attr=ParamAttr(initializer=Constant(0.0), name=name + ".w_0"),
-            bias_attr=ParamAttr(initializer=Constant(0.0), name=name + ".b_0"),
+            param_attr=ParamAttr(
+                initializer=Constant(0.0), name=name + ".w_0"),
+            bias_attr=ParamAttr(
+                initializer=Constant(0.0), name=name + ".b_0"),
             act=act,
             name=name)
         return out
@@ -151,7 +153,8 @@ class ResNet(object):
                    name=None,
                    dcn_v2=False,
                    use_lr_mult_list=False):
-        lr_mult = self.lr_mult_list[self.curr_stage] if use_lr_mult_list else 1.0
+        lr_mult = self.lr_mult_list[
+            self.curr_stage] if use_lr_mult_list else 1.0
         _name = self.prefix_name + name if self.prefix_name != '' else name
         if not dcn_v2:
             conv = fluid.layers.conv2d(
@@ -162,8 +165,8 @@ class ResNet(object):
                 padding=(filter_size - 1) // 2,
                 groups=groups,
                 act=None,
-                param_attr=ParamAttr(name=_name + "_weights",
-                    learning_rate=lr_mult),
+                param_attr=ParamAttr(
+                    name=_name + "_weights", learning_rate=lr_mult),
                 bias_attr=False,
                 name=_name + '.conv2d.output.1')
         else:
@@ -202,14 +205,18 @@ class ResNet(object):
         norm_lr = 0. if self.freeze_norm else lr_mult
         norm_decay = self.norm_decay
+        if self.num_classes:
+            regularizer = None
+        else:
+            regularizer = L2Decay(norm_decay)
         pattr = ParamAttr(
             name=bn_name + '_scale',
             learning_rate=norm_lr,
-            regularizer=L2Decay(norm_decay))
+            regularizer=regularizer)
         battr = ParamAttr(
             name=bn_name + '_offset',
             learning_rate=norm_lr,
-            regularizer=L2Decay(norm_decay))
+            regularizer=regularizer)

         if self.norm_type in ['bn', 'sync_bn']:
             global_stats = True if self.freeze_norm else False
@@ -262,8 +269,8 @@ class ResNet(object):
                     pool_padding=0,
                     ceil_mode=True,
                     pool_type='avg')
-                return self._conv_norm(input, ch_out, 1, 1, name=name,
-                    use_lr_mult_list=True)
+                return self._conv_norm(
+                    input, ch_out, 1, 1, name=name, use_lr_mult_list=True)
             return self._conv_norm(input, ch_out, 1, stride, name=name)
         else:
             return input
...
@@ -14,5 +14,6 @@
 from .unet import UNet
 from .deeplabv3p import DeepLabv3p
+from .hrnet import HRNet
 from .model_utils import libs
 from .model_utils import loss
# coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import OrderedDict

import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr

from .model_utils.libs import sigmoid_to_softmax
from .model_utils.loss import softmax_with_loss
from .model_utils.loss import dice_loss
from .model_utils.loss import bce_loss

import paddlex
import paddlex.utils.logging as logging

class HRNet(object):
    def __init__(self,
                 num_classes,
                 mode='train',
                 width=18,
                 use_bce_loss=False,
                 use_dice_loss=False,
                 class_weight=None,
                 ignore_index=255):
        # dice_loss and bce_loss only apply to binary segmentation
        if num_classes > 2 and (use_bce_loss or use_dice_loss):
            raise ValueError(
                "dice loss and bce loss are only applicable to binary classification"
            )
        if class_weight is not None:
            if isinstance(class_weight, list):
                if len(class_weight) != num_classes:
                    raise ValueError(
                        "Length of class_weight should be equal to number of classes"
                    )
            elif isinstance(class_weight, str):
                if class_weight.lower() != 'dynamic':
                    raise ValueError(
                        "if class_weight is string, must be dynamic!")
            else:
                raise TypeError(
                    'Expect class_weight is a list or string but receive {}'.
                    format(type(class_weight)))
        self.num_classes = num_classes
        self.mode = mode
        self.use_bce_loss = use_bce_loss
        self.use_dice_loss = use_dice_loss
        self.class_weight = class_weight
        self.ignore_index = ignore_index
        self.backbone = paddlex.cv.nets.hrnet.HRNet(
            width=width, feature_maps="stage4")

    def build_net(self, inputs):
        if self.use_dice_loss or self.use_bce_loss:
            self.num_classes = 1
        image = inputs['image']
        st4 = self.backbone(image)

        # upsample
        shape = fluid.layers.shape(st4[0])[-2:]
        st4[1] = fluid.layers.resize_bilinear(st4[1], out_shape=shape)
        st4[2] = fluid.layers.resize_bilinear(st4[2], out_shape=shape)
        st4[3] = fluid.layers.resize_bilinear(st4[3], out_shape=shape)
        out = fluid.layers.concat(st4, axis=1)
        last_channels = sum(self.backbone.channels[self.backbone.width][-1])

        out = self._conv_bn_layer(
            input=out,
            filter_size=1,
            num_filters=last_channels,
            stride=1,
            if_act=True,
            name='conv-2')
        out = fluid.layers.conv2d(
            input=out,
            num_filters=self.num_classes,
            filter_size=1,
            stride=1,
            padding=0,
            act=None,
            param_attr=ParamAttr(
                initializer=MSRA(), name='conv-1_weights'),
            bias_attr=False)

        input_shape = fluid.layers.shape(image)[-2:]
        logit = fluid.layers.resize_bilinear(out, input_shape)

        if self.num_classes == 1:
            out = sigmoid_to_softmax(logit)
            out = fluid.layers.transpose(out, [0, 2, 3, 1])
        else:
            out = fluid.layers.transpose(logit, [0, 2, 3, 1])
        pred = fluid.layers.argmax(out, axis=3)
        pred = fluid.layers.unsqueeze(pred, axes=[3])

        if self.mode == 'train':
            label = inputs['label']
            mask = label != self.ignore_index
            return self._get_loss(logit, label, mask)
        elif self.mode == 'eval':
            label = inputs['label']
            mask = label != self.ignore_index
            loss = self._get_loss(logit, label, mask)
            return loss, pred, label, mask
        else:
            if self.num_classes == 1:
                logit = sigmoid_to_softmax(logit)
            else:
                logit = fluid.layers.softmax(logit, axis=1)
            return pred, logit

    def generate_inputs(self):
        inputs = OrderedDict()
        inputs['image'] = fluid.data(
            dtype='float32', shape=[None, 3, None, None], name='image')
        if self.mode == 'train':
            inputs['label'] = fluid.data(
                dtype='int32', shape=[None, 1, None, None], name='label')
        elif self.mode == 'eval':
            inputs['label'] = fluid.data(
                dtype='int32', shape=[None, 1, None, None], name='label')
        return inputs

    def _get_loss(self, logit, label, mask):
        avg_loss = 0
        if not (self.use_dice_loss or self.use_bce_loss):
            avg_loss += softmax_with_loss(
                logit,
                label,
                mask,
                num_classes=self.num_classes,
                weight=self.class_weight,
                ignore_index=self.ignore_index)
        else:
            if self.use_dice_loss:
                avg_loss += dice_loss(logit, label, mask)
            if self.use_bce_loss:
                avg_loss += bce_loss(
                    logit, label, mask, ignore_index=self.ignore_index)
        return avg_loss

    def _conv_bn_layer(self,
                       input,
                       filter_size,
                       num_filters,
                       stride=1,
                       padding=1,
                       num_groups=1,
                       if_act=True,
                       name=None):
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=num_groups,
            act=None,
            param_attr=ParamAttr(
                initializer=MSRA(), name=name + '_weights'),
            bias_attr=False)
        bn_name = name + '_bn'
        bn = fluid.layers.batch_norm(
            input=conv,
            param_attr=ParamAttr(
                name=bn_name + "_scale",
                initializer=fluid.initializer.Constant(1.0)),
            bias_attr=ParamAttr(
                name=bn_name + "_offset",
                initializer=fluid.initializer.Constant(0.0)),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance')
        if if_act:
            bn = fluid.layers.relu(bn)
        return bn
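For reference, a hedged sketch of how the segmentation HRNet class above is wired into a fluid program. The import path is an assumption inferred from the `from .hrnet import HRNet` line in the segmentation nets package above, and the program/variable names are illustrative, not part of the diff:

```python
import paddle.fluid as fluid
# Assumed module path for the new file shown above.
from paddlex.cv.nets.segmentation.hrnet import HRNet

startup_prog = fluid.Program()
train_prog = fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
    net = HRNet(num_classes=2, mode='train', width=18)
    inputs = net.generate_inputs()   # 'image' plus 'label' in train mode
    loss = net.build_net(inputs)     # train mode returns the loss tensor
```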
@@ -17,5 +17,6 @@ from . import cv
 UNet = cv.models.UNet
 DeepLabv3p = cv.models.DeepLabv3p
+HRNet = cv.models.HRNet
 transforms = cv.transforms.seg_transforms
 visualize = cv.models.utils.visualize.visualize_segmentation
import os
# Use GPU card 0
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import paddlex as pdx
from paddlex.seg import transforms

# Download and extract the optic disc segmentation dataset
optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz'
pdx.utils.download_and_decompress(optic_dataset, path='./')

# Define the transforms for training and validation
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(), transforms.ResizeRangeScaling(),
    transforms.RandomPaddingCrop(crop_size=512), transforms.Normalize()
])
eval_transforms = transforms.Compose([
    transforms.ResizeByLong(long_size=512),
    transforms.Padding(target_size=512), transforms.Normalize()
])

# Define the datasets used for training and validation
train_dataset = pdx.datasets.SegDataset(
    data_dir='optic_disc_seg',
    file_list='optic_disc_seg/train_list.txt',
    label_list='optic_disc_seg/labels.txt',
    transforms=train_transforms,
    shuffle=True)
eval_dataset = pdx.datasets.SegDataset(
    data_dir='optic_disc_seg',
    file_list='optic_disc_seg/val_list.txt',
    label_list='optic_disc_seg/labels.txt',
    transforms=eval_transforms)

# Initialize the model and start training
# Training metrics can be inspected with VisualDL
# Launch VisualDL with: visualdl --logdir output/hrnet/vdl_log --port 8001
# Then open http://0.0.0.0:8001 in a browser
# (0.0.0.0 works for local access; for a remote server, replace it with that machine's IP)
num_classes = len(train_dataset.labels)
model = pdx.seg.HRNet(num_classes=num_classes)
model.train(
    num_epochs=20,
    train_dataset=train_dataset,
    train_batch_size=4,
    eval_dataset=eval_dataset,
    learning_rate=0.01,
    save_dir='output/hrnet',
    use_vdl=True)
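After training, the saved model can be loaded back for prediction. A short hedged sketch follows; the checkpoint sub-directory under `save_dir` and the test image path are assumptions, not taken from the diff:

```python
import paddlex as pdx

# 'best_model' is the sub-directory expected under save_dir; the test image
# path is an illustrative file from the downloaded optic_disc_seg dataset.
model = pdx.load_model('output/hrnet/best_model')
image_name = 'optic_disc_seg/JPEGImages/H0005.jpg'
result = model.predict(image_name)
pdx.seg.visualize(image_name, result, weight=0.4, save_dir='./output/hrnet')
```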