提交 355832ee 编写于 作者: F FlyingQianMM

add hrnet_w18_small_v1 for segmentation

上级 8391440b
import os

# Select GPU card 0 (set before paddlex is imported).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import paddlex as pdx
from paddlex.seg import transforms

# Download and unpack the human segmentation dataset into the current
# directory.
dataset_url = 'https://paddlex.bj.bcebos.com/humanseg/data/human_seg_data.zip'
pdx.utils.download_and_decompress(dataset_url, path='./')

# Download and unpack the pretrained human segmentation weights.
pretrain_url = 'https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_ckpt.zip'
pretrain_dir = './output/human_seg/pretrain'
pdx.utils.download_and_decompress(pretrain_url, path=pretrain_dir)

# Transforms applied during training: resize, random horizontal flip,
# normalize.
train_transforms = transforms.Compose([
    transforms.Resize([192, 192]),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(),
])

# Transforms applied during evaluation: resize and normalize only.
eval_transforms = transforms.Compose([
    transforms.Resize([192, 192]),
    transforms.Normalize(),
])

# Datasets used for training and evaluation.
# API reference: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/semantic_segmentation.html#segdataset
data_dir = 'human_seg_data'
label_list = data_dir + '/labels.txt'

train_dataset = pdx.datasets.SegDataset(
    data_dir=data_dir,
    file_list=data_dir + '/train_list.txt',
    label_list=label_list,
    transforms=train_transforms,
    shuffle=True)

eval_dataset = pdx.datasets.SegDataset(
    data_dir=data_dir,
    file_list=data_dir + '/val_list.txt',
    label_list=label_list,
    transforms=eval_transforms)

# Build the model and train it.
# Training metrics can be inspected with VisualDL:
#   visualdl --logdir output/human_seg/vdl_log --port 8001
# then open http://0.0.0.0:8001 in a browser. 0.0.0.0 is for local access;
# for a remote service, replace it with that machine's IP.
# NOTE(review): the original hint pointed at output/unet/vdl_log, which does
# not match save_dir below; the path above assumes the logs are written to
# <save_dir>/vdl_log -- confirm against the PaddleX version in use.
# Model API reference: https://paddlex.readthedocs.io/zh_CN/latest/apis/models/semantic_segmentation.html#hrnet
num_classes = len(train_dataset.labels)

# width='18_small_v1' selects the lightweight variant of HRNet-W18.
model = pdx.seg.HRNet(num_classes=num_classes, width='18_small_v1')

save_dir = 'output/human_seg'
model.train(
    num_epochs=10,
    train_dataset=train_dataset,
    train_batch_size=8,
    eval_dataset=eval_dataset,
    learning_rate=0.001,
    pretrain_weights=pretrain_dir + '/humanseg_mobile_ckpt',
    save_dir=save_dir,
    use_vdl=True)
......@@ -186,10 +186,10 @@ paddlex.seg.HRNet(num_classes=2, width=18, use_bce_loss=False, use_dice_loss=Fal
> **参数**
> > - **num_classes** (int): 类别数。
> > - **width** (int): 高分辨率分支中特征层的通道数量。默认值为18。可选择取值为[18, 30, 32, 40, 44, 48, 60, 64]
> > - **width** (int|str): 高分辨率分支中特征层的通道数量。默认值为18。可选择取值为[18, 30, 32, 40, 44, 48, 60, 64, '18_small_v1']。'18_small_v1'是18的轻量级版本。
> > - **use_bce_loss** (bool): 是否使用bce loss作为网络的损失函数,只能用于两类分割。可与dice loss同时使用。默认False。
> > - **use_dice_loss** (bool): 是否使用dice loss作为网络的损失函数,只能用于两类分割,可与bce loss同时使用。当use_bce_loss和use_dice_loss都为False时,使用交叉熵损失函数。默认False。
> > - **class_weight** (list/str): 交叉熵损失函数各类损失的权重。当`class_weight`为list的时候,长度应为`num_classes`。当`class_weight`为str时, weight.lower()应为'dynamic',这时会根据每一轮各类像素的比重自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,即平时使用的交叉熵损失函数。
> > - **class_weight** (list|str): 交叉熵损失函数各类损失的权重。当`class_weight`为list的时候,长度应为`num_classes`。当`class_weight`为str时, `class_weight.lower()`应为'dynamic',这时会根据每一轮各类像素的比重自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。`class_weight`取默认值None时,各类的权重为1,即平时使用的交叉熵损失函数。
> > - **ignore_index** (int): label上忽略的值,label为`ignore_index`的像素不参与损失函数的计算。默认255。
### train 训练接口
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -24,11 +24,12 @@ class HRNet(DeepLabv3p):
Args:
num_classes (int): 类别数。
width (int): 高分辨率分支中特征层的通道数量。默认值为18。可选择取值为[18, 30, 32, 40, 44, 48, 60, 64]。
width (int|str): 高分辨率分支中特征层的通道数量。默认值为18。可选择取值为[18, 30, 32, 40, 44, 48, 60, 64, '18_small_v1']。
'18_small_v1'是18的轻量级版本。
use_bce_loss (bool): 是否使用bce loss作为网络的损失函数,只能用于两类分割。可与dice loss同时使用。默认False。
use_dice_loss (bool): 是否使用dice loss作为网络的损失函数,只能用于两类分割,可与bce loss同时使用。
当use_bce_loss和use_dice_loss都为False时,使用交叉熵损失函数。默认False。
class_weight (list/str): 交叉熵损失函数各类损失的权重。当class_weight为list的时候,长度应为
class_weight (list|str): 交叉熵损失函数各类损失的权重。当class_weight为list的时候,长度应为
num_classes。当class_weight为str时, weight.lower()应为'dynamic',这时会根据每一轮各类像素的比重
自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,
即平时使用的交叉熵损失函数。
......@@ -173,6 +174,6 @@ class HRNet(DeepLabv3p):
return super(HRNet, self).train(
num_epochs, train_dataset, train_batch_size, eval_dataset,
save_interval_epochs, log_interval_steps, save_dir,
pretrain_weights, optimizer, learning_rate, lr_decay_power, use_vdl,
sensitivities_file, eval_metric_loss, early_stop,
pretrain_weights, optimizer, learning_rate, lr_decay_power,
use_vdl, sensitivities_file, eval_metric_loss, early_stop,
early_stop_patience, resume_checkpoint)
......@@ -51,15 +51,38 @@ class HRNet(object):
self.width = width
self.has_se = has_se
self.num_modules = {
'18_small_v1': [1, 1, 1, 1],
'18': [1, 1, 4, 3],
'30': [1, 1, 4, 3],
'32': [1, 1, 4, 3],
'40': [1, 1, 4, 3],
'44': [1, 1, 4, 3],
'48': [1, 1, 4, 3],
'60': [1, 1, 4, 3],
'64': [1, 1, 4, 3]
}
self.num_blocks = {
'18_small_v1': [[1], [2, 2], [2, 2, 2], [2, 2, 2, 2]],
'18': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
'30': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
'32': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
'40': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
'44': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
'48': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
'60': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
'64': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]]
}
self.channels = {
18: [[18, 36], [18, 36, 72], [18, 36, 72, 144]],
30: [[30, 60], [30, 60, 120], [30, 60, 120, 240]],
32: [[32, 64], [32, 64, 128], [32, 64, 128, 256]],
40: [[40, 80], [40, 80, 160], [40, 80, 160, 320]],
44: [[44, 88], [44, 88, 176], [44, 88, 176, 352]],
48: [[48, 96], [48, 96, 192], [48, 96, 192, 384]],
60: [[60, 120], [60, 120, 240], [60, 120, 240, 480]],
64: [[64, 128], [64, 128, 256], [64, 128, 256, 512]],
'18_small_v1': [[32], [16, 32], [16, 32, 64], [16, 32, 64, 128]],
'18': [[64], [18, 36], [18, 36, 72], [18, 36, 72, 144]],
'30': [[64], [30, 60], [30, 60, 120], [30, 60, 120, 240]],
'32': [[64], [32, 64], [32, 64, 128], [32, 64, 128, 256]],
'40': [[64], [40, 80], [40, 80, 160], [40, 80, 160, 320]],
'44': [[64], [44, 88], [44, 88, 176], [44, 88, 176, 352]],
'48': [[64], [48, 96], [48, 96, 192], [48, 96, 192, 384]],
'60': [[64], [60, 120], [60, 120, 240], [60, 120, 240, 480]],
'64': [[64], [64, 128], [64, 128, 256], [64, 128, 256, 512]],
}
self.freeze_at = freeze_at
......@@ -73,31 +96,38 @@ class HRNet(object):
def net(self, input):
width = self.width
channels_2, channels_3, channels_4 = self.channels[width]
num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3
channels_1, channels_2, channels_3, channels_4 = self.channels[str(
width)]
num_modules_1, num_modules_2, num_modules_3, num_modules_4 = self.num_modules[
str(width)]
num_blocks_1, num_blocks_2, num_blocks_3, num_blocks_4 = self.num_blocks[
str(width)]
x = self.conv_bn_layer(
input=input,
filter_size=3,
num_filters=64,
num_filters=channels_1[0],
stride=2,
if_act=True,
name='layer1_1')
x = self.conv_bn_layer(
input=x,
filter_size=3,
num_filters=64,
num_filters=channels_1[0],
stride=2,
if_act=True,
name='layer1_2')
la1 = self.layer1(x, name='layer2')
la1 = self.layer1(x, num_blocks_1, channels_1, name='layer2')
tr1 = self.transition_layer([la1], [256], channels_2, name='tr1')
st2 = self.stage(tr1, num_modules_2, channels_2, name='st2')
st2 = self.stage(
tr1, num_modules_2, num_blocks_2, channels_2, name='st2')
tr2 = self.transition_layer(st2, channels_2, channels_3, name='tr2')
st3 = self.stage(tr2, num_modules_3, channels_3, name='st3')
st3 = self.stage(
tr2, num_modules_3, num_blocks_3, channels_3, name='st3')
tr3 = self.transition_layer(st3, channels_3, channels_4, name='tr3')
st4 = self.stage(tr3, num_modules_4, channels_4, name='st4')
st4 = self.stage(
tr3, num_modules_4, num_blocks_4, channels_4, name='st4')
# classification
if self.num_classes:
......@@ -139,12 +169,12 @@ class HRNet(object):
self.end_points = st4
return st4[-1]
def layer1(self, input, name=None):
def layer1(self, input, num_blocks, channels, name=None):
conv = input
for i in range(4):
for i in range(num_blocks[0]):
conv = self.bottleneck_block(
conv,
num_filters=64,
num_filters=channels[0],
downsample=True if i == 0 else False,
name=name + '_' + str(i + 1))
return conv
......@@ -178,7 +208,7 @@ class HRNet(object):
out = []
for i in range(len(channels)):
residual = x[i]
for j in range(block_num):
for j in range(block_num[i]):
residual = self.basic_block(
residual,
channels[i],
......@@ -240,10 +270,11 @@ class HRNet(object):
def high_resolution_module(self,
x,
num_blocks,
channels,
multi_scale_output=True,
name=None):
residual = self.branches(x, 4, channels, name=name)
residual = self.branches(x, num_blocks, channels, name=name)
out = self.fuse_layers(
residual,
channels,
......@@ -254,6 +285,7 @@ class HRNet(object):
def stage(self,
x,
num_modules,
num_blocks,
channels,
multi_scale_output=True,
name=None):
......@@ -262,12 +294,13 @@ class HRNet(object):
if i == num_modules - 1 and multi_scale_output == False:
out = self.high_resolution_module(
out,
num_blocks,
channels,
multi_scale_output=False,
name=name + '_' + str(i + 1))
else:
out = self.high_resolution_module(
out, channels, name=name + '_' + str(i + 1))
out, num_blocks, channels, name=name + '_' + str(i + 1))
return out
......
......@@ -83,7 +83,8 @@ class HRNet(object):
st4[3] = fluid.layers.resize_bilinear(st4[3], out_shape=shape)
out = fluid.layers.concat(st4, axis=1)
last_channels = sum(self.backbone.channels[self.backbone.width][-1])
last_channels = sum(self.backbone.channels[str(self.backbone.width)][
-1])
out = self._conv_bn_layer(
input=out,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册