Commit bad7914f by 氢键H-H: init

.vscode
.pytest_cache
__pycache__
resources/
# FCN-DenseNet
[![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](./LICENSE)
[![Release](https://img.shields.io/badge/release-v1.0-blue)](https://github.com/JoveH-H/FCN-DenseNet/releases/tag/v1.0)
[![Author](https://img.shields.io/badge/Author-Jove-%2300a8ff)](https://github.com/JoveH-H)
For details, see the blog post (in Chinese): [《语义分割 FCN-DenseNet 应用入门》](https://joveh-h.blog.csdn.net/article/details/125654652)
import os
import numpy as np
import PIL.Image as Image
import torch
from torch.utils import data
import random
import scipy.stats
import cv2
class MySynData(data.Dataset):
"""
synthesis data
"""
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
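    # the mean/std above are the standard ImageNet statistics, matching the
    # normalization the pretrained DenseNet backbones expect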
def __init__(self, obj_root, bg_root, transform=True, hflip=False, vflip=False, crop=False):
super(MySynData, self).__init__()
self.obj_root, self.bg_root = obj_root, bg_root
self.is_transform = transform
self.is_hflip = hflip
self.is_vflip = vflip
self.is_crop = crop
obj_names = os.listdir(obj_root)
bg_names = os.listdir(bg_root)
self.name_combs = [(obj_name, bg_name) for obj_name in obj_names for bg_name in bg_names]
def __len__(self):
return len(self.name_combs)
def __getitem__(self, index):
obj_name, bg_name = self.name_combs[index]
        obj = Image.open('%s/%s' % (self.obj_root, obj_name)).convert('RGBA')  # object must carry an alpha channel
        bg = Image.open('%s/%s' % (self.bg_root, bg_name)).convert('RGB')
sbc, sbr = bg.size
ratio = 400.0 / max(sbr, sbc)
bg = bg.resize((int(sbc * ratio), int(sbr * ratio)))
bg = np.array(bg, dtype=np.uint8)
r, c, _ = bg.shape
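        # randomly place the object on the background: the row/column offsets are
        # drawn from Weibull distributions and the object's long side from a normal
        # distribution (empirical shape/scale choices)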
r_location = scipy.stats.weibull_min.rvs(1.56, 0, 0.22, size=1)[0] * r
r_location = int(r_location)
r_location = min(r_location, r-1)
c_location = scipy.stats.weibull_min.rvs(1.72, 0, 0.27, size=1)[0] * c
c_location = int(c_location)
c_location = min(c_location, c-1)
length = scipy.stats.norm.rvs(0.61, 0.07, size=1)[0] * max(r, c)
length = max(length, 10)
sbc, sbr = obj.size
ratio = length / max(sbr, sbc)
obj = obj.resize((int(sbc * ratio), int(sbr * ratio)))
sbc, sbr = obj.size
r_location_end = min(r_location + sbr, r)
c_location_end = min(c_location + sbc, c)
obj_r_end = min(r_location_end - r_location, sbr)
obj_c_end = min(c_location_end - c_location, sbc)
obj = np.array(obj, dtype=np.uint8)
        m_obj = obj[:, :, 3]  # alpha channel as the object mask
m_obj[m_obj != 0] = 1
m_obj = np.expand_dims(m_obj, 2)
obj = obj[:, :, :3]
mask = np.zeros((bg.shape[0], bg.shape[1], 1))
bg[r_location:r_location_end, c_location:c_location_end] = \
bg[r_location:r_location_end, c_location:c_location_end] * (1 - m_obj[:obj_r_end, :obj_c_end]) \
+ obj[:obj_r_end, :obj_c_end] * m_obj[:obj_r_end, :obj_c_end]
mask[r_location:r_location_end, c_location:c_location_end] = \
m_obj[:obj_r_end, :obj_c_end]
bg = bg.astype(np.uint8)
mask = mask.astype(np.uint8)
mask[mask != 0] = 1
if self.is_crop:
H = int(0.9 * bg.shape[0])
W = int(0.9 * bg.shape[1])
H_offset = random.choice(range(bg.shape[0] - H))
W_offset = random.choice(range(bg.shape[1] - W))
H_slice = slice(H_offset, H_offset + H)
W_slice = slice(W_offset, W_offset + W)
bg = bg[H_slice, W_slice, :]
mask = mask[H_slice, W_slice]
        if self.is_hflip and random.randint(0, 1):
            bg = np.ascontiguousarray(bg[:, ::-1, :])  # copy: cv2.resize rejects negative strides
            mask = np.ascontiguousarray(mask[:, ::-1])
        if self.is_vflip and random.randint(0, 1):
            bg = np.ascontiguousarray(bg[::-1, :, :])
            mask = np.ascontiguousarray(mask[::-1, :])
bg = cv2.resize(bg, dsize=(256, 256), interpolation=cv2.INTER_NEAREST)
mask = cv2.resize(mask, dsize=(256, 256), interpolation=cv2.INTER_NEAREST)
        if self.is_transform:
            bg, mask = self.transform(bg, mask)
        return bg, mask
def transform(self, img, gt):
img = img.astype(np.float64) / 255
img -= self.mean
img /= self.std
img = img.transpose(2, 0, 1)
img = torch.from_numpy(img).float()
gt = torch.from_numpy(gt)
return img, gt
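# Usage sketch (illustrative, not part of the original file): drive MySynData
# with a DataLoader; the obj/bg directory paths here are assumptions, and each
# object image must be an RGBA PNG.
def _demo_syn_data():
    loader = data.DataLoader(
        MySynData('./resources/images/obj', './resources/images/bg'),
        batch_size=4, shuffle=True)
    img, mask = next(iter(loader))
    print(img.shape, mask.shape)  # torch.Size([4, 3, 256, 256]) torch.Size([4, 256, 256])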
class MyData(data.Dataset):
"""
load images for testing
root: director/to/images/
structure:
- root
- images (images here)
- masks (ground truth)
"""
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
def __init__(self, root, transform=True, hflip=False, vflip=False, crop=False):
super(MyData, self).__init__()
self.root = root
self.is_transform = transform
self.is_hflip = hflip
self.is_vflip = vflip
self.is_crop = crop
img_root = os.path.join(self.root, 'img')
gt_root = os.path.join(self.root, 'mask')
file_names = os.listdir(gt_root)
        self.img_names = []
        self.gt_names = []
        self.names = []
        for name in file_names:
            if not name.endswith('.png'):
                continue
            self.img_names.append(img_root + '/' + name)
            self.gt_names.append(gt_root + '/' + name)
            self.names.append(name[:-4])
def __len__(self):
return len(self.gt_names)
def __getitem__(self, index):
# load image
img_file = self.img_names[index]
img = Image.open(img_file)
img = np.array(img, dtype=np.uint8)
if len(img.shape) < 3:
img = np.stack((img, img, img), 2)
if img.shape[2] > 3:
img = img[:, :, :3]
gt_file = self.gt_names[index]
gt = Image.open(gt_file)
        gt = np.array(gt, dtype=np.uint8)  # uint8: cv2.resize cannot handle int32 inputs
        gt[gt != 0] = 1
if self.is_crop:
H = int(0.9 * img.shape[0])
W = int(0.9 * img.shape[1])
H_offset = random.choice(range(img.shape[0] - H))
W_offset = random.choice(range(img.shape[1] - W))
H_slice = slice(H_offset, H_offset + H)
W_slice = slice(W_offset, W_offset + W)
img = img[H_slice, W_slice, :]
gt = gt[H_slice, W_slice]
        if self.is_hflip and random.randint(0, 1):
            img = np.ascontiguousarray(img[:, ::-1, :])  # copy: cv2.resize rejects negative strides
            gt = np.ascontiguousarray(gt[:, ::-1])
        if self.is_vflip and random.randint(0, 1):
            img = np.ascontiguousarray(img[::-1, :, :])
            gt = np.ascontiguousarray(gt[::-1, :])
img = cv2.resize(img, dsize=(256, 256), interpolation=cv2.INTER_NEAREST)
gt = cv2.resize(gt, dsize=(256, 256), interpolation=cv2.INTER_NEAREST)
        if self.is_transform:
            img, gt = self.transform(img, gt)
        return img, gt
def transform(self, img, gt):
img = img.astype(np.float64) / 255
img -= self.mean
img /= self.std
img = img.transpose(2, 0, 1)
img = torch.from_numpy(img).float()
gt = torch.from_numpy(gt)
return img, gt
class MyTestData(data.Dataset):
"""
load images for testing
root: director/to/images/
structure:
- root
- images (images here)
- masks (ground truth)
"""
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
def __init__(self, root, transform=True):
super(MyTestData, self).__init__()
self.root = root
self._transform = transform
img_root = os.path.join(self.root, 'img')
file_names = os.listdir(img_root)
self.img_names = []
self.names = []
        for name in file_names:
            if not name.endswith('.png'):
                continue
            self.img_names.append(img_root + '/' + name)
            self.names.append(name[:-4])
def __len__(self):
return len(self.img_names)
def __getitem__(self, index):
# load image
img_file = self.img_names[index]
        img = Image.open(img_file).convert('RGB')  # transform() expects exactly 3 channels
img_size = img.size
img = img.resize((256, 256))
img = np.array(img, dtype=np.uint8)
        if self._transform:
            img = self.transform(img)
        return img, self.names[index], img_size
def transform(self, img):
img = img.astype(np.float64) / 255
img -= self.mean
img /= self.std
img = img.transpose(2, 0, 1)
img = torch.from_numpy(img).float()
return img
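# Usage sketch (illustrative, not part of the original file): MyData yields a
# normalized (image, mask) pair and MyTestData yields (image, name, original_size),
# both resized to 256x256. The directory paths below are assumptions.
def _demo_data():
    train_set = MyData('./resources/images/data/train', transform=True)
    img, gt = train_set[0]
    print(img.shape, gt.shape)  # torch.Size([3, 256, 256]) torch.Size([256, 256])
    test_set = MyTestData('./resources/images/data/test', transform=True)
    img, name, size = test_set[0]
    print(img.shape, name, size)  # name is the file stem, size the original (w, h)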
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from collections import OrderedDict
import re
__all__ = ['DenseNet', 'densenet57', 'densenet89', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
model_urls = {
'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}
def densenet89(pretrained=False, **kwargs):
r"""Densenet-89 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 8, 16, 12),
**kwargs)
model.classifier = None
features = model.features
features.block0 = nn.Sequential(features.conv0, features.norm0, features.relu0, features.pool0)
features.denseblock1 = nn.Sequential(*list(features.denseblock1))
features.transition1 = nn.Sequential(*list(features.transition1)[:-1])
features.denseblock2 = nn.Sequential(*list(features.denseblock2))
features.transition2 = nn.Sequential(*list(features.transition2)[:-1])
features.denseblock3 = nn.Sequential(*list(features.denseblock3))
features.transition3 = nn.Sequential(*list(features.transition3)[:-1])
features.denseblock4 = nn.Sequential(*(list(features.denseblock4)+[features.norm5]))
model.features = features
return model
def densenet57(pretrained=False, **kwargs):
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(4, 6, 8, 8),
**kwargs)
model.classifier = None
features = model.features
features.block0 = nn.Sequential(features.conv0, features.norm0, features.relu0, features.pool0)
features.denseblock1 = nn.Sequential(*list(features.denseblock1))
features.transition1 = nn.Sequential(*list(features.transition1)[:-1])
features.denseblock2 = nn.Sequential(*list(features.denseblock2))
features.transition2 = nn.Sequential(*list(features.transition2)[:-1])
features.denseblock3 = nn.Sequential(*list(features.denseblock3))
features.transition3 = nn.Sequential(*list(features.transition3)[:-1])
features.denseblock4 = nn.Sequential(*(list(features.denseblock4)+[features.norm5]))
model.features = features
return model
def densenet121(pretrained=False, **kwargs):
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
**kwargs)
if pretrained:
        # '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        # load weights from a local copy of the official checkpoint; to fetch it
        # automatically instead, use:
        # state_dict = model_zoo.load_url(model_urls['densenet121'], model_dir='./pretrained')
        state_dict = torch.load('densenet121-a639ec97.pth')
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model.load_state_dict(state_dict)
model.classifier = None
features = model.features
features.block0 = nn.Sequential(features.conv0, features.norm0, features.relu0, features.pool0)
features.denseblock1 = nn.Sequential(*list(features.denseblock1))
features.transition1 = nn.Sequential(*list(features.transition1)[:-1])
features.denseblock2 = nn.Sequential(*list(features.denseblock2))
features.transition2 = nn.Sequential(*list(features.transition2)[:-1])
features.denseblock3 = nn.Sequential(*list(features.denseblock3))
features.transition3 = nn.Sequential(*list(features.transition3)[:-1])
features.denseblock4 = nn.Sequential(*(list(features.denseblock4)+[features.norm5]))
model.features = features
return model
def densenet169(pretrained=False, **kwargs):
r"""Densenet-169 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
**kwargs)
if pretrained:
        # '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_urls['densenet169'])
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model.load_state_dict(state_dict)
model.classifier = None
features = model.features
features.block0 = nn.Sequential(features.conv0, features.norm0, features.relu0, features.pool0)
features.denseblock1 = nn.Sequential(*list(features.denseblock1))
features.transition1 = nn.Sequential(*list(features.transition1)[:-1])
features.denseblock2 = nn.Sequential(*list(features.denseblock2))
features.transition2 = nn.Sequential(*list(features.transition2)[:-1])
features.denseblock3 = nn.Sequential(*list(features.denseblock3))
features.transition3 = nn.Sequential(*list(features.transition3)[:-1])
features.denseblock4 = nn.Sequential(*(list(features.denseblock4) + [features.norm5]))
model.features = features
return model
def densenet201(pretrained=False, **kwargs):
r"""Densenet-201 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
**kwargs)
if pretrained:
        # '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_urls['densenet201'])
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model.load_state_dict(state_dict)
model.classifier = None
features = model.features
features.block0 = nn.Sequential(features.conv0, features.norm0, features.relu0, features.pool0)
features.denseblock1 = nn.Sequential(*list(features.denseblock1))
features.transition1 = nn.Sequential(*list(features.transition1)[:-1])
features.denseblock2 = nn.Sequential(*list(features.denseblock2))
features.transition2 = nn.Sequential(*list(features.transition2)[:-1])
features.denseblock3 = nn.Sequential(*list(features.denseblock3))
features.transition3 = nn.Sequential(*list(features.transition3)[:-1])
features.denseblock4 = nn.Sequential(*(list(features.denseblock4) + [features.norm5]))
model.features = features
return model
def densenet161(pretrained=False, **kwargs):
r"""Densenet-161 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
**kwargs)
if pretrained:
        # '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_urls['densenet161'])
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model.load_state_dict(state_dict)
model.classifier = None
features = model.features
features.block0 = nn.Sequential(features.conv0, features.norm0, features.relu0, features.pool0)
features.denseblock1 = nn.Sequential(*list(features.denseblock1))
features.transition1 = nn.Sequential(*list(features.transition1)[:-1])
features.denseblock2 = nn.Sequential(*list(features.denseblock2))
features.transition2 = nn.Sequential(*list(features.transition2)[:-1])
features.denseblock3 = nn.Sequential(*list(features.denseblock3))
features.transition3 = nn.Sequential(*list(features.transition3)[:-1])
features.denseblock4 = nn.Sequential(*(list(features.denseblock4) + [features.norm5]))
model.features = features
return model
class _DenseLayer(nn.Sequential):
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                        growth_rate, kernel_size=1, stride=1, bias=False))
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                        kernel_size=3, stride=1, padding=1, bias=False))
self.drop_rate = drop_rate
def forward(self, x):
new_features = super(_DenseLayer, self).forward(x)
if self.drop_rate > 0:
new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
return torch.cat([x, new_features], 1)
class _DenseBlock(nn.Sequential):
def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
super(_DenseBlock, self).__init__()
for i in range(num_layers):
layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
self.add_module('denselayer%d' % (i + 1), layer)
class _Transition(nn.Sequential):
def __init__(self, num_input_features, num_output_features):
super(_Transition, self).__init__()
self.add_module('norm', nn.BatchNorm2d(num_input_features))
self.add_module('relu', nn.ReLU(inplace=True))
self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
kernel_size=1, stride=1, bias=False))
self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
class DenseNet(nn.Module):
r"""Densenet-BC model class, based on
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
growth_rate (int) - how many filters to add each layer (`k` in paper)
block_config (list of 4 ints) - how many layers in each pooling block
num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for the number of bottleneck layers
          (i.e. bn_size * k features in the bottleneck layer)
drop_rate (float) - dropout rate after each dense layer
num_classes (int) - number of classification classes
"""
def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):
super(DenseNet, self).__init__()
# First convolution
self.features = nn.Sequential(OrderedDict([
('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
('norm0', nn.BatchNorm2d(num_init_features)),
('relu0', nn.ReLU(inplace=True)),
('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
]))
# Each denseblock
num_features = num_init_features
for i, num_layers in enumerate(block_config):
block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
self.features.add_module('denseblock%d' % (i + 1), block)
num_features = num_features + num_layers * growth_rate
if i != len(block_config) - 1:
trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
self.features.add_module('transition%d' % (i + 1), trans)
num_features = num_features // 2
# Final batch norm
self.features.add_module('norm5', nn.BatchNorm2d(num_features))
# Linear layer
self.classifier = nn.Linear(num_features, num_classes)
# Official init from torch repo.
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight.data)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.bias.data.zero_()
def forward(self, x):
outputs = []
x = self.features.block0(x) # 1/4
x = self.features.denseblock1(x)
x = self.features.transition1(x)
x = F.avg_pool2d(x, kernel_size=2, stride=2) # 1/8
x = self.features.denseblock2(x)
x = self.features.transition2(x)
outputs.append(x)
x = F.avg_pool2d(x, kernel_size=2, stride=2) # 1/16
x = self.features.denseblock3(x)
x = self.features.transition3(x)
outputs.append(x)
x = F.avg_pool2d(x, kernel_size=2, stride=2) # 1/32
x = self.features.denseblock4(x)
outputs.append(x)
return outputs
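# Shape-check sketch (not part of the original file): the three returned feature
# maps line up with dim_dict in model.py; for densenet121 and a 256x256 input
# they are (N, 256, 32, 32), (N, 512, 16, 16) and (N, 1024, 8, 8).
def _check_feature_shapes():
    net = densenet121(pretrained=False)
    with torch.no_grad():
        feats = net(torch.rand(1, 3, 256, 256))
    for f in feats:
        print(tuple(f.shape))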
import torch.nn as nn
def nothing(x):
    # identity: stands in for a channel-reduction conv when none is needed (vgg base)
    return x
dim_dict = {
'resnet101': [512, 1024, 2048],
'resnet152': [512, 1024, 2048],
'resnet50': [512, 1024, 2048],
'resnet34': [128, 256, 512],
'resnet18': [128, 256, 512],
'densenet57': [144, 200, 456],
'densenet89': [192, 352, 736],
'densenet121': [256, 512, 1024],
'densenet161': [384, 1056, 2208],
'densenet169': [256, 640, 1664],
'densenet201': [256, 896, 1920]
}
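# each entry above lists the channel counts of the three feature maps the
# backbone's forward() returns (at 1/8, 1/16 and 1/32 of the input resolution);
# Deconv's 1x1 convs reduce them to 256/512/512 before fusion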
class Deconv(nn.Module):
def __init__(self, base='vgg'):
super(Deconv, self).__init__()
if base == 'vgg':
self.pred5 = nn.Sequential(
nn.Conv2d(512, 1, kernel_size=1),
nn.ReLU()
)
self.reduce_channels = [nothing, nothing, nothing]
else:
self.pred5 = nn.Sequential(
nn.Conv2d(512, 1, kernel_size=1),
nn.ReLU(),
nn.UpsamplingBilinear2d(scale_factor=2)
)
self.reduce_channels = nn.ModuleList([
nn.Conv2d(in_dim, out_dim, kernel_size=1) for in_dim, out_dim in zip(dim_dict[base], [256, 512, 512])
])
self.pred4 = nn.Sequential(
nn.Conv2d(512, 1, kernel_size=1),
nn.ReLU(),
nn.UpsamplingBilinear2d(scale_factor=2)
)
self.pred3 = nn.Sequential(
nn.Conv2d(256, 1, kernel_size=1),
nn.UpsamplingBilinear2d(scale_factor=8)
)
for m in self.modules():
if isinstance(m, nn.Conv2d):
m.weight.data.normal_(0, 0.01)
m.bias.data.fill_(0)
def forward(self, x):
x = [r(_x) for r, _x in zip(self.reduce_channels, x)]
pred5 = self.pred5(x[2])
pred4 = self.pred4(pred5 + x[1])
pred3 = self.pred3(pred4 + x[0])
return pred3
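# End-to-end sketch (not part of the original file): chain a DenseNet backbone
# with the Deconv head, as the inference script below does. For a 256x256 input
# the head emits a single-channel 256x256 logit map.
def _demo_forward():
    import torch
    import densenet  # assumes densenet.py sits alongside this file
    feature = densenet.densenet121(pretrained=False)
    head = Deconv('densenet121')
    with torch.no_grad():
        logits = head(feature(torch.rand(1, 3, 256, 256)))
    print(tuple(logits.shape))  # (1, 1, 256, 256)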
import torch
import torch.nn.functional as F
from dataset import MyTestData
from model import Deconv
import densenet
import numpy as np
import os
import sys
import argparse
import time
from PIL import Image
home = os.path.expanduser("~")
parser = argparse.ArgumentParser()
parser.add_argument('--input_dir', default='./resources/images/data/test/')  # test dataset
parser.add_argument('--output_dir', default='./resources/images/data/test/')  # prediction output directory
parser.add_argument('--para_dir', default='./parameters_densenet121/')  # trained parameters
parser.add_argument('--b', type=int, default=1)  # batch size
parser.add_argument('--q', default='densenet121')  # backbone variant
opt = parser.parse_args()
print(opt)
def main():
if not os.path.exists(opt.output_dir):
os.mkdir(opt.output_dir)
bsize = opt.b
feature = getattr(densenet, opt.q)(pretrained=False)
feature.cuda()
feature.eval()
    state_dict = torch.load('%s/feature_model.pth' % opt.para_dir)
    feature.load_state_dict(state_dict)
deconv = Deconv(opt.q)
deconv.cuda()
deconv.eval()
    state_dict = torch.load('%s/deconv_model.pth' % opt.para_dir)
    deconv.load_state_dict(state_dict)
test_loader = torch.utils.data.DataLoader(MyTestData(opt.input_dir), batch_size=bsize, shuffle=False, num_workers=1, pin_memory=True)
step_len = len(test_loader)
    for id, (data, img_name, img_size) in enumerate(test_loader):
        with torch.no_grad():  # inference only; skip building the autograd graph
            inputs = data.cuda()
            start_time = time.time()
            feats = feature(inputs)
            outputs = deconv(feats)
            outputs = torch.sigmoid(outputs)  # F.sigmoid is deprecated
        outputs = outputs.cpu().squeeze(1).numpy()
end_time = time.time()
for i, msk in enumerate(outputs):
msk = (msk * 255).astype(np.uint8)
msk = Image.fromarray(msk)
            msk = msk.resize((int(img_size[0][i]), int(img_size[1][i])))  # back to the original (w, h)
msk.save('%s/%s_pred.png' % (opt.output_dir, img_name[i]), 'PNG')
        # print progress
step_now = id + 1
step_schedule_num = int(40 * step_now / step_len)
print("\r", end="")
print("step: {}/{} [{}{}] - time: {:.2f}ms".format(step_now, step_len,
">" * step_schedule_num,
"-" * (40 - step_schedule_num),
(end_time - start_time) * 1000), end="")
sys.stdout.flush()
print("\r")
if __name__ == "__main__":
main()
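# Post-processing sketch (not part of the original script): the saved *_pred.png
# files hold sigmoid probabilities scaled to 0..255, so thresholding at 128
# (p > 0.5) yields a binary mask. The example path below is an assumption.
def _binarize_prediction(pred_path='./resources/images/data/test/000001_pred.png'):
    pred = np.array(Image.open(pred_path))
    return (pred > 128).astype(np.uint8)  # 1 = object, 0 = background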
import os
from shutil import copyfile, move
import random
absolute_path = os.path.abspath(os.path.dirname(__file__)).replace('\\', '/')
dataset_dir_path = absolute_path + '/resources/images/dataset'
data_dir_path = absolute_path + '/resources/images/data'
# config
movement_type = 0  # index into movement_type_list: 0 = copy, 1 = move
img_id_num = 6  # number of leading characters kept as the image id
test_ratio = 0.1
movement_type_list = ['copy', 'move']
movement_type = movement_type_list[movement_type]
input("任意键开始")
FileNameList = os.listdir(dataset_dir_path)
FileNameList.sort()  # sort first for a deterministic base order
random.shuffle(FileNameList)
len_img = len(FileNameList)
now_id = 0
for i in range(len_img):
    # only process *_img files; each has a matching *_label file
if os.path.splitext(FileNameList[i])[0][-4:] == "_img":
now_id += 1
if now_id > len_img / 2 * test_ratio:
data_type_path = data_dir_path + '/train'
else:
data_type_path = data_dir_path + '/test'
if movement_type == 'copy':
copyfile(dataset_dir_path + '/' + FileNameList[i], data_type_path + '/img/' + FileNameList[i][0:img_id_num] + '.png')
copyfile(dataset_dir_path + '/' + os.path.splitext(FileNameList[i])[0][:-4] + '_label.png', data_type_path + '/mask/' + FileNameList[i][0:img_id_num] + '.png')
else:
move(dataset_dir_path + '/' + FileNameList[i], data_type_path + '/img/' + FileNameList[i][0:img_id_num] + '.png')
move(dataset_dir_path + '/' + os.path.splitext(FileNameList[i])[0][:-4] + '_label.png', data_type_path + '/mask/' + FileNameList[i][0:img_id_num] + '.png')
input("任意键退出")
import os
from shutil import copyfile
import random
absolute_path = os.path.abspath(os.path.dirname(__file__)).replace('\\', '/')
img_dir_path = absolute_path + '/resources/images/images'
lab_dir_path = absolute_path + '/resources/images/annotations/trimaps'
data_dir_path = absolute_path + '/resources/images/data'
# config
test_ratio = 0.03
input("Press Enter to start")
ImgFileNameList = os.listdir(img_dir_path)
ImgFileNameList.sort()  # sort first for a deterministic base order; image and trimap files share names
random.shuffle(ImgFileNameList)
len_img = len(ImgFileNameList)
for i in range(len_img):
if i + 1 > len_img * test_ratio:
copyfile(img_dir_path + '/' + ImgFileNameList[i], data_dir_path + '/train/img/' + ImgFileNameList[i])
copyfile(lab_dir_path + '/' + ImgFileNameList[i], data_dir_path + '/train/mask/' + ImgFileNameList[i])
else:
copyfile(img_dir_path + '/' + ImgFileNameList[i], data_dir_path + '/test/img/' + ImgFileNameList[i])
copyfile(lab_dir_path + '/' + ImgFileNameList[i], data_dir_path + '/test/mask/' + ImgFileNameList[i])
input("任意键退出")
import os
import numpy as np
import imgviz
from PIL import Image
absolute_path = os.path.abspath(os.path.dirname(__file__)).replace('\\', '/')
data_dir_path = absolute_path + '/resources/images/annotations/trimaps'
ImgFileNameList = os.listdir(data_dir_path)
def get_gray_cls(van_lbl, array_lbl):
    cls = [2, 3]  # gray value of each class; background (2) is listed first so it maps to 0
    for x in range(van_lbl.size[0]):
        for y in range(van_lbl.size[1]):
            if array_lbl[y, x] not in cls:
                cls.append(array_lbl[y, x])  # index as [row, col], matching the membership check
    return cls
def get_P_cls(cls_gray):
    cls_P = []  # represent each gray-level class by an index 0..N
    for i in range(len(cls_gray)):
        cls_P.append(i)
    return cls_P
def array_gray_to_P(cls_gray, cls_P, array):
    # compute every mask against the original array, so a value remapped by an
    # earlier rule cannot be remapped again by a later one
    out = array.copy()
    for i in range(len(cls_gray)):
        out[array == cls_gray[i]] = cls_P[i]
    return out
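# Toy check (illustrative, not part of the original file): with cls_gray = [2, 3, 1]
# the mapping is 2->0, 3->1, 1->2; because the masks come from the original array,
# a pixel remapped to 1 by the 3->1 rule is not remapped again by the 1->2 rule.
def _demo_mapping():
    a = np.array([[1, 2], [3, 2]], dtype=np.uint8)
    print(array_gray_to_P([2, 3, 1], [0, 1, 2], a))  # [[2 0] [1 0]]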
if __name__ == '__main__':
    van_file = data_dir_path + '/Abyssinian_1.png'  # must contain every class; call it the "pioneer" image
    van_lbl = Image.open(van_file).convert('L')  # pioneer image as grayscale
    array_lbl = np.array(van_lbl)  # numpy array of the grayscale image
    cls_gray = get_gray_cls(van_lbl, array_lbl)  # gray value of every class present
    cls_P = get_P_cls(cls_gray)  # map each class's gray value to 0..N
    # iterate over every original label image
    len_img = len(ImgFileNameList)
    for i in range(len_img):
        orig_lbl = Image.open(data_dir_path + '/' + ImgFileNameList[i]).convert('L')  # to grayscale
        array_gray = np.array(orig_lbl)
        array_P = array_gray_to_P(cls_gray, cls_P, array_gray)  # remap gray values to 0..N
        label = Image.fromarray(array_P.astype(np.uint8), mode='P')  # PIL palette ('P') mode
        # attach a VOC-style palette
        colormap = imgviz.label_colormap()
        label.putpalette(colormap.flatten())
        label.save(data_dir_path + '/' + ImgFileNameList[i])
import base64
import json
import os
import os.path as osp
import imgviz
import PIL.Image
from labelme import utils
absolute_path = os.path.abspath(os.path.dirname(__file__)).replace('\\', '/')
images_dir_path = absolute_path + '/resources/images/data_json'
dataset_dir_path = absolute_path + '/resources/images/dataset'
def main(json_file_dir=images_dir_path, out_dir=dataset_dir_path):
input("任意键开始")
if not osp.exists(out_dir):
os.mkdir(out_dir)
FileNameList = os.listdir(json_file_dir)
for i in range(len(FileNameList)):
        # only process .json files
if os.path.splitext(FileNameList[i])[1] == ".json":
json_file = json_file_dir + '/' + FileNameList[i]
data = json.load(open(json_file))
imageData = data.get("imageData")
if not imageData:
imagePath = os.path.join(os.path.dirname(json_file), data["imagePath"])
with open(imagePath, "rb") as f:
imageData = f.read()
imageData = base64.b64encode(imageData).decode("utf-8")
img = utils.img_b64_to_arr(imageData)
label_name_to_value = {"_background_": 0}
for shape in sorted(data["shapes"], key=lambda x: x["label"]):
label_name = shape["label"]
if label_name in label_name_to_value:
label_value = label_name_to_value[label_name]
else:
label_value = len(label_name_to_value)
label_name_to_value[label_name] = label_value
lbl = utils.shapes_to_label(img.shape, data['shapes'], label_name_to_value)
label_names = [None] * (max(label_name_to_value.values()) + 1)
for name, value in label_name_to_value.items():
label_names[value] = name
            lbl_viz = imgviz.label2rgb(
                label=lbl, img=imgviz.asgray(img), label_names=label_names, loc="rb"
            )  # visualization overlay (computed but not saved by this script)
PIL.Image.fromarray(img).save(osp.join(out_dir, "{}_img.png".format(os.path.splitext(FileNameList[i])[0])))
utils.lblsave(osp.join(out_dir, "{}_label.png".format(os.path.splitext(FileNameList[i])[0])), lbl)
print("Finish to: {}".format(json_file))
input("任意键退出")
if __name__ == "__main__":
main()
import torch
import os
import densenet
from model import Deconv
import argparse
# absolute path of this file's directory
Absolute_File_Path = os.path.dirname(__file__).replace('\\', '/')
parser = argparse.ArgumentParser()
parser.add_argument('--input_dir', default='./resources/images/data/test/')  # test dataset
parser.add_argument('--output_dir', default='./resources/images/data/test/')  # output directory
parser.add_argument('--para_dir', default='./parameters_densenet89/')  # trained parameters
parser.add_argument('--b', type=int, default=1)  # batch size
parser.add_argument('--q', default='densenet89')  # backbone variant
opt = parser.parse_args()
print(opt)
Net_Class_Set = 1
Net_List = ['feature', 'deconv']
# 'feature' takes one RGB image; 'deconv' takes the three backbone feature maps
# (channel counts per model.dim_dict; for densenet89 they are 192/352/736, with
# the spatial size halving at each level)
Net_Input_List = [torch.rand(1, 3, 320, 240),
                  (torch.rand(1, 192, 32, 32), torch.rand(1, 352, 16, 16), torch.rand(1, 736, 8, 8))]
# which sub-network to export; see the documentation for details
Net_Class = Net_List[Net_Class_Set]
Net_Input = Net_Input_List[Net_Class_Set]
# 实例化一个网络对象
if Net_Class_Set == 0:
model = getattr(densenet, opt.q)(pretrained=True).cpu()
else:
model = Deconv(opt.q).cpu()
Model_File_Path = Absolute_File_Path + "/model/{}_model.pth".format(Net_Class)
Onnx_File_Path = Absolute_File_Path + "/model/{}_model.onnx".format(Net_Class)
model.load_state_dict(torch.load(Model_File_Path, map_location='cpu'))
model.eval()
def torch2onnx(model, save_path):
    """
    Export the loaded PyTorch model to ONNX.
    :param model: nn.Module with trained weights loaded
    :param save_path: destination .onnx path
    """
    model.eval()
    data = Net_Input
    input_names = ["{}_input".format(Net_Class)]
    output_names = ["{}_out".format(Net_Class)]
    torch.onnx.export(model, data, save_path, export_params=True, opset_version=11,
                      input_names=input_names, output_names=output_names)
    input("torch2onnx finished. Press Enter to exit...")
if __name__ == '__main__':
torch2onnx(model, Onnx_File_Path)
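# Verification sketch (not part of the original script): run the exported model
# with onnxruntime and check the output shape. Assumes onnxruntime is installed;
# shown for the single-input 'feature' network only (the deconv input is a tuple).
def _check_onnx(onnx_path=Onnx_File_Path):
    import onnxruntime as ort
    sess = ort.InferenceSession(onnx_path)
    out = sess.run(None, {sess.get_inputs()[0].name: Net_Input.numpy()})
    print(out[0].shape)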
import gc
import torch
import torch.nn.functional as F
from dataset import MyData
from model import Deconv
import densenet
import os
import sys
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--train_dir', default='./resources/images/data/train/') # training dataset
parser.add_argument('--val_dir', default='./resources/images/data/test/') # test dataset
parser.add_argument('--check_dir', default='./parameters') # save checkpoint parameters
parser.add_argument('--q', default='densenet121')  # backbone variant
parser.add_argument('--b', type=int, default=4) # batch size
parser.add_argument('--e', type=int, default=100)  # epochs
parser.add_argument('--save_interval', type=int, default=5)  # validate/save every N epochs
opt = parser.parse_args()
def validation(feature, net, loader):
feature.eval()
net.eval()
    total_loss = 0
    with torch.no_grad():  # no gradients needed during validation
        for ib, (data, lbl) in enumerate(loader):
            inputs = data.cuda()
            lbl = lbl.float().unsqueeze(1).cuda()
            feats = feature(inputs)
            msk = net(feats)
            loss = F.binary_cross_entropy_with_logits(msk, lbl)
            total_loss += loss.item()
feature.train()
net.train()
return total_loss / len(loader)
def main():
train_dir = opt.train_dir
val_dir = opt.val_dir
check_dir = opt.check_dir + '_' + opt.q
bsize = opt.b
epoch_sum = opt.e
    save_interval = opt.save_interval
if not os.path.exists(check_dir):
os.mkdir(check_dir)
feature = getattr(densenet, opt.q)(pretrained=True)
feature.cuda()
deconv = Deconv(opt.q)
deconv.cuda()
train_loader = torch.utils.data.DataLoader(MyData(train_dir, transform=True, crop=False, hflip=False, vflip=False),
batch_size=bsize, shuffle=True, num_workers=1, pin_memory=True)
val_loader = torch.utils.data.DataLoader(MyData(val_dir, transform=True, crop=False, hflip=False, vflip=False),
batch_size=bsize, shuffle=False, num_workers=1, pin_memory=True)
optimizer = torch.optim.AdamW([
{'params': feature.parameters(), 'lr': 1e-3},
{'params': deconv.parameters(), 'lr': 1e-3},
])
min_loss = 10000.0
for it in range(epoch_sum):
step_len = len(train_loader)
for ib, (data, lbl) in enumerate(train_loader):
            inputs = data.cuda()
            lbl = lbl.float().unsqueeze(1).cuda()
            feats = feature(inputs)
            msk = deconv(feats)
            loss = F.binary_cross_entropy_with_logits(msk, lbl)
            optimizer.zero_grad()  # clears gradients of both sub-networks via the optimizer
loss.backward()
optimizer.step()
            # print progress
step_now = ib + 1
step_schedule_num = int(40 * step_now / step_len)
epoch_now = it + 1
print("\r", end="")
print("epoch: {}/{} step: {}/{} [{}{}] - loss: {:.5f}".format(epoch_now, epoch_sum,
step_now, step_len,
">" * step_schedule_num,
"-" * (40 - step_schedule_num),
loss.item()), end="")
sys.stdout.flush()
            # free intermediate tensors
del inputs, msk, lbl, loss, feats
gc.collect()
print("\r")
        if epoch_now % save_interval == 0:
val_loss = validation(feature, deconv, val_loader)
if val_loss < min_loss:
filename = ('{}/deconv_model.pth'.format(check_dir))
torch.save(deconv.state_dict(), filename)
filename = ('{}/feature_model.pth'.format(check_dir))
torch.save(feature.state_dict(), filename)
print('epoch: {} val loss: {:.5f} save model'.format(epoch_now, val_loss))
min_loss = val_loss
else:
print('epoch: {} val loss: {:.5f} pass'.format(epoch_now, val_loss))
if __name__ == "__main__":
main()