diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..f21db8032425d5ff77d89415478b27e5fa012a87
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,22 @@
+
+MIT License
+
+Copyright (c) 2021 Eric.Lee2021
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 960c43c2815b094dcef10ee5a8aaf2e939b3129c..ab1622d7df3f36ef0c880048df64524616b5bca8 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,41 @@
-# Face Parsing
+# face parsing
+Face region segmentation.
+
+## Introduction
+Note: this project does not include the face detection step. The face detection project is available at: https://codechina.csdn.net/EricLee/yolo_v3
+
+* Image example:
+![image](https://codechina.csdn.net/EricLee/faceparsing/-/raw/master/samples/t.jpg)
+
+* Video example:
+![video](https://codechina.csdn.net/EricLee/faceparsing/-/raw/master/samples/sample.gif)
+
+## Requirements
+* Author's development environment:
+* Python 3.7
+* PyTorch >= 1.5.1
+
+## Dataset
+* CelebAMask-HQ dataset, download from:
+  https://github.com/switchablenorms/CelebAMask-HQ
+
+```
+• The CelebAMask-HQ dataset is available for non-commercial research purposes only.
+• You agree not to reproduce, duplicate, copy, sell, trade, resell or exploit for any commercial purposes, any portion of the images and any portion of derived data.
+• You agree not to further copy, publish or distribute any portion of the CelebAMask-HQ dataset. Except, for internal use at a single site within the same organization it is allowed to make copies of the dataset.
+
+```
+
+* Dataset preparation
+  Download and extract the dataset, then run the script prepropess_data.py to generate the training masks (check the parameter settings inside the script).
+
+## Pretrained model
+* [Pretrained model download (Baidu Netdisk, Password: )]()
+
+## Usage
+
+### Training
+* Run from the project root: python train.py (check the parameter settings inside the script)
+
+### Inference
+* Run from the project root: python inference.py (check the parameter settings inside the script)
diff --git a/face_dataset.py b/face_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..49b5865cc251de936628c50e679aa4f30736266f
--- /dev/null
+++ b/face_dataset.py
@@ -0,0 +1,60 @@
+#!/usr/bin/python
+# -*- encoding: utf-8 -*-
+
+import torch
+from torch.utils.data import Dataset
+import torchvision.transforms as transforms
+
+import os.path as osp
+import os
+from PIL import Image
+import numpy as np
+import json
+import cv2
+
+from transform import *
+
+
+
+class FaceMask(Dataset):
+    def __init__(self, rootpth, img_size, cropsize=(640, 480), mode='train', *args, **kwargs):
+        super(FaceMask, self).__init__(*args, **kwargs)
+        assert mode in ('train', 'val', 'test')
+        self.mode = mode
+        self.ignore_lb = 255
+        self.rootpth = rootpth
+        self.img_size = img_size
+
+        self.imgs = os.listdir(os.path.join(self.rootpth, 'CelebA-HQ-img'))
+
+        # pre-processing
+        self.to_tensor = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+        ])
+        self.trans_train = Compose([
+            ColorJitter(
+                brightness=0.5,
+                contrast=0.5,
+                saturation=0.5),
+            HorizontalFlip(),
+            RandomScale((0.75, 1.0, 1.25, 1.5, 1.75, 2.0)),
+            RandomCrop(cropsize)
+            ])
+
+    def __getitem__(self, idx):
+        impth = self.imgs[idx]
+        img = Image.open(osp.join(self.rootpth, 'CelebA-HQ-img', impth))
+        img = img.resize((self.img_size, self.img_size), Image.BILINEAR)
+        label = Image.open(osp.join(self.rootpth, 'mask_{}'.format(self.img_size), impth[:-3] + 'png')).convert('P')
+        # print(np.unique(np.array(label)))
+        if self.mode == 'train':
+            im_lb = dict(im=img, lb=label)
+            im_lb = self.trans_train(im_lb)
+            img, label = im_lb['im'], im_lb['lb']
+        img = self.to_tensor(img)
+        label = np.array(label).astype(np.int64)[np.newaxis, :]
+        return img, label
+
+    def __len__(self):
+        return len(self.imgs)
diff --git a/images/2020-09-06_17-01-01_835030.jpg b/images/2020-09-06_17-01-01_835030.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..96f41c0f1958916adfb97fe4af748a0972be830d
Binary files /dev/null and b/images/2020-09-06_17-01-01_835030.jpg differ
diff --git a/images/2020-09-06_17-09-01_763315.jpg b/images/2020-09-06_17-09-01_763315.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d2b8ec04d0b8ff5cdcf510a06c78d893d1a74fe1
Binary files /dev/null and b/images/2020-09-06_17-09-01_763315.jpg differ
diff --git a/images/2020-09-06_20-31-54_377040.jpg b/images/2020-09-06_20-31-54_377040.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e4921c11a1524d1f3c04098b5a7dcad9a96c665f
Binary files /dev/null and b/images/2020-09-06_20-31-54_377040.jpg differ
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..302f35bac6fd53e32c83f852198258262d86af22
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,113 @@
+# -*- encoding: utf-8 -*-
+
+from model import BiSeNet
+import torch
+
+import os
+import os.path as osp
+import numpy as np
+from PIL import Image
+import torchvision.transforms as transforms
+import cv2
+
+# Compute gaussian kernel
+def CenterGaussianHeatMap(img_height, img_width, c_x, c_y, variance):
+    gaussian_map = np.zeros((img_height, img_width))
+    for x_p in range(img_width):
+        for y_p in range(img_height):
+
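+            # squared distance of pixel (x_p, y_p) from the centre (c_x, c_y); the heat
+            # value computed below is exp(-dist_sq / (2 * variance**2)), so `variance`
+            # effectively plays the role of the Gaussian's standard deviation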
dist_sq = (x_p - c_x) * (x_p - c_x) + \ + (y_p - c_y) * (y_p - c_y) + exponent = dist_sq / 2.0 / variance / variance + gaussian_map[y_p, x_p] = np.exp(-exponent) + return gaussian_map + +def vis_parsing_maps(im, parsing_anno,x,y, stride): + # Colors for all 20 parts + part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], + [255, 0, 85], [255, 0, 170], + [0, 255, 0], [85, 255, 0], [170, 255, 0], + [0, 255, 85], [0, 255, 170], + [0, 0, 255], [85, 0, 255], [170, 0, 255], + [0, 85, 255], [0, 170, 255], + [255, 255, 0], [255, 255, 85], [255, 255, 170], + [255, 0, 255], [255, 85, 255], [255, 170, 255], + [0, 255, 255], [85, 255, 255], [170, 255, 255]] + + im = np.array(im) + vis_im = im.copy().astype(np.uint8) + vis_parsing_anno = parsing_anno.copy() + vis_parsing_anno_color = np.zeros((im.shape[0], im.shape[1], 3)) + 0 + + face_mask = np.zeros((im.shape[0], im.shape[1])) + + num_of_class = np.max(vis_parsing_anno) + + for pi in range(1, num_of_class + 1): + index = np.where(vis_parsing_anno == pi)# 获得对应分类的的像素坐标 + + idx_y = (index[0]+y).astype(np.int) + idx_x = (index[1]+x).astype(np.int) + + # continue + vis_parsing_anno_color[idx_y,idx_x, :] = part_colors[pi]# 给对应的类别的掩码赋值 + + face_mask[idx_y,idx_x] = 0.45 + # if pi in[1,2,3,4,5,6,7,8,10,11,12,13,14,17]: + # face_mask[idx_y,idx_x] = 0.35 + + vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) + + face_mask = np.expand_dims(face_mask, 2) + vis_im = vis_parsing_anno_color*face_mask + (1.-face_mask)*vis_im + vis_im = vis_im.astype(np.uint8) + + return vis_im + + +def inference( img_size, image_path, model_path): + n_classes = 19 + net = BiSeNet(n_classes=n_classes) + net.cuda() + + print('model : {}'.format(model_path)) + net.load_state_dict(torch.load(model_path)) + net.eval() + + to_tensor = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ]) + with torch.no_grad(): + idx = 0 + for f_ in os.listdir(image_path): + img_ = cv2.imread(image_path + f_) + img = Image.fromarray(cv2.cvtColor(img_,cv2.COLOR_BGR2RGB)) + + image = img.resize((img_size, img_size), Image.BILINEAR) + img = to_tensor(image) + img = torch.unsqueeze(img, 0) + img = img.cuda() + out = net(img)[0] + parsing_ = out.squeeze(0).cpu().numpy().argmax(0) + idx += 1 + print('<{}> image : '.format(idx),np.unique(parsing_)) + + parsing_ = cv2.resize(parsing_,(img_.shape[1],img_.shape[0]),interpolation=cv2.INTER_NEAREST) + parsing_ = parsing_.astype(np.uint8) + vis_im = vis_parsing_maps(img_, parsing_, 0,0,stride=1) + + # 保存输出结果 + test_result = './result/' + if not osp.exists(test_result): + os.makedirs(test_result) + cv2.imwrite(test_result+"p_{}-".format(img_size)+f_,vis_im) + + cv2.namedWindow("vis_im",0) + cv2.imshow("vis_im",vis_im) + if cv2.waitKey(500) == 27: + break +if __name__ == "__main__": + img_size = 512 + model_path = "./model_exp/2021-02-23_22-03-22/fp_latest.pth" + image_path = "./images/" + inference(img_size = img_size, image_path=image_path, model_path=model_path) diff --git a/loss.py b/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..f8f65aa05566853cb87678d97926bd03b911e166 --- /dev/null +++ b/loss.py @@ -0,0 +1,75 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np + + +class OhemCELoss(nn.Module): + def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs): + super(OhemCELoss, self).__init__() + self.thresh = -torch.log(torch.tensor(thresh, 
dtype=torch.float)).cuda() + self.n_min = n_min + self.ignore_lb = ignore_lb + self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none') + + def forward(self, logits, labels): + N, C, H, W = logits.size() + loss = self.criteria(logits, labels).view(-1) + loss, _ = torch.sort(loss, descending=True) + if loss[self.n_min] > self.thresh: + loss = loss[loss>self.thresh] + else: + loss = loss[:self.n_min] + return torch.mean(loss) + + +class SoftmaxFocalLoss(nn.Module): + def __init__(self, gamma, ignore_lb=255, *args, **kwargs): + super(SoftmaxFocalLoss, self).__init__() + self.gamma = gamma + self.nll = nn.NLLLoss(ignore_index=ignore_lb) + + def forward(self, logits, labels): + scores = F.softmax(logits, dim=1) + factor = torch.pow(1.-scores, self.gamma) + log_score = F.log_softmax(logits, dim=1) + log_score = factor * log_score + loss = self.nll(log_score, labels) + return loss + + +if __name__ == '__main__': + torch.manual_seed(15) + criteria1 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda() + criteria2 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda() + net1 = nn.Sequential( + nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1), + ) + net1.cuda() + net1.train() + net2 = nn.Sequential( + nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1), + ) + net2.cuda() + net2.train() + + with torch.no_grad(): + inten = torch.randn(16, 3, 20, 20).cuda() + lbs = torch.randint(0, 19, [16, 20, 20]).cuda() + lbs[1, :, :] = 255 + + logits1 = net1(inten) + logits1 = F.interpolate(logits1, inten.size()[2:], mode='bilinear') + logits2 = net2(inten) + logits2 = F.interpolate(logits2, inten.size()[2:], mode='bilinear') + + loss1 = criteria1(logits1, lbs) + loss2 = criteria2(logits2, lbs) + loss = loss1 + loss2 + print(loss.detach().cpu()) + loss.backward() diff --git a/model.py b/model.py new file mode 100644 index 0000000000000000000000000000000000000000..040f41ffe57c3a2278e4c4db68749716ef45c304 --- /dev/null +++ b/model.py @@ -0,0 +1,283 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +from resnet import Resnet18 +# from modules.bn import InPlaceABNSync as BatchNorm2d + + +class ConvBNReLU(nn.Module): + def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): + super(ConvBNReLU, self).__init__() + self.conv = nn.Conv2d(in_chan, + out_chan, + kernel_size = ks, + stride = stride, + padding = padding, + bias = False) + self.bn = nn.BatchNorm2d(out_chan) + self.init_weight() + + def forward(self, x): + x = self.conv(x) + x = F.relu(self.bn(x)) + return x + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + +class BiSeNetOutput(nn.Module): + def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): + super(BiSeNetOutput, self).__init__() + self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) + self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) + self.init_weight() + + def forward(self, x): + x = self.conv(x) + x = self.conv_out(x) + return x + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, 
nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class AttentionRefinementModule(nn.Module): + def __init__(self, in_chan, out_chan, *args, **kwargs): + super(AttentionRefinementModule, self).__init__() + self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) + self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False) + self.bn_atten = nn.BatchNorm2d(out_chan) + self.sigmoid_atten = nn.Sigmoid() + self.init_weight() + + def forward(self, x): + feat = self.conv(x) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv_atten(atten) + atten = self.bn_atten(atten) + atten = self.sigmoid_atten(atten) + out = torch.mul(feat, atten) + return out + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + +class ContextPath(nn.Module): + def __init__(self, *args, **kwargs): + super(ContextPath, self).__init__() + self.resnet = Resnet18() + self.arm16 = AttentionRefinementModule(256, 128) + self.arm32 = AttentionRefinementModule(512, 128) + self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) + + self.init_weight() + + def forward(self, x): + H0, W0 = x.size()[2:] + feat8, feat16, feat32 = self.resnet(x) + H8, W8 = feat8.size()[2:] + H16, W16 = feat16.size()[2:] + H32, W32 = feat32.size()[2:] + + avg = F.avg_pool2d(feat32, feat32.size()[2:]) + avg = self.conv_avg(avg) + avg_up = F.interpolate(avg, (H32, W32), mode='nearest') + + feat32_arm = self.arm32(feat32) + feat32_sum = feat32_arm + avg_up + feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') + feat32_up = self.conv_head32(feat32_up) + + feat16_arm = self.arm16(feat16) + feat16_sum = feat16_arm + feat32_up + feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') + feat16_up = self.conv_head16(feat16_up) + + return feat8, feat16_up, feat32_up # x8, x8, x16 + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv2d)): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +### This is not used, since I replace this with the resnet feature with the same size +class SpatialPath(nn.Module): + def __init__(self, *args, **kwargs): + super(SpatialPath, self).__init__() + self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) + self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) + self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) + self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) + self.init_weight() + + def forward(self, x): + feat = self.conv1(x) + feat = self.conv2(feat) + feat = self.conv3(feat) + feat = self.conv_out(feat) + return feat + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, 
a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class FeatureFusionModule(nn.Module): + def __init__(self, in_chan, out_chan, *args, **kwargs): + super(FeatureFusionModule, self).__init__() + self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) + self.conv1 = nn.Conv2d(out_chan, + out_chan//4, + kernel_size = 1, + stride = 1, + padding = 0, + bias = False) + self.conv2 = nn.Conv2d(out_chan//4, + out_chan, + kernel_size = 1, + stride = 1, + padding = 0, + bias = False) + self.relu = nn.ReLU(inplace=True) + self.sigmoid = nn.Sigmoid() + self.init_weight() + + def forward(self, fsp, fcp): + fcat = torch.cat([fsp, fcp], dim=1) + feat = self.convblk(fcat) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv1(atten) + atten = self.relu(atten) + atten = self.conv2(atten) + atten = self.sigmoid(atten) + feat_atten = torch.mul(feat, atten) + feat_out = feat_atten + feat + return feat_out + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class BiSeNet(nn.Module): + def __init__(self, n_classes, *args, **kwargs): + super(BiSeNet, self).__init__() + self.cp = ContextPath() + ## here self.sp is deleted + self.ffm = FeatureFusionModule(256, 256) + self.conv_out = BiSeNetOutput(256, 256, n_classes) + self.conv_out16 = BiSeNetOutput(128, 64, n_classes) + self.conv_out32 = BiSeNetOutput(128, 64, n_classes) + self.init_weight() + + def forward(self, x): + H, W = x.size()[2:] + feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature + feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature + feat_fuse = self.ffm(feat_sp, feat_cp8) + + feat_out = self.conv_out(feat_fuse) + feat_out16 = self.conv_out16(feat_cp8) + feat_out32 = self.conv_out32(feat_cp16) + + feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) + feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) + feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) + return feat_out, feat_out16, feat_out32 + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] + for name, child in self.named_children(): + child_wd_params, child_nowd_params = child.get_params() + if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): + lr_mul_wd_params += child_wd_params + lr_mul_nowd_params += child_nowd_params + else: + wd_params += 
child_wd_params + nowd_params += child_nowd_params + return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params + + +if __name__ == "__main__": + net = BiSeNet(19) + net.cuda() + net.eval() + in_ten = torch.randn(16, 3, 640, 480).cuda() + out, out16, out32 = net(in_ten) + print(out.shape) + + net.get_params() diff --git a/prepropess_data.py b/prepropess_data.py new file mode 100644 index 0000000000000000000000000000000000000000..54de918f43f16a583a8177e8589b00873e554fe5 --- /dev/null +++ b/prepropess_data.py @@ -0,0 +1,47 @@ +# -*- encoding: utf-8 -*- +#function : 训练样本预处理 + +import os +import os.path as osp +import cv2 +from transform import * +from PIL import Image + +if __name__ == "__main__": + + image_size = 256# 样本分辨率 + + face_data = './CelebAMask-HQ/CelebA-HQ-img' + face_sep_mask = './CelebAMask-HQ/CelebAMask-HQ-mask-anno' + mask_path = './CelebAMask-HQ/mask_{}'.format(image_size) + + if not os.path.exists(mask_path): + os.mkdir(mask_path) + + counter = 0 + total = 0 + for i in range(15): + + atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r', + 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat'] + + for j in range(i * 2000, (i + 1) * 2000): + + mask = np.zeros((512, 512)) + + for l, att in enumerate(atts, 1): + total += 1 + file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png']) + path = osp.join(face_sep_mask, str(i), file_name) + + if os.path.exists(path): + counter += 1 + sep_mask = np.array(Image.open(path).convert('P')) + + mask[sep_mask == 225] = l + if image_size != 512: + mask = cv2.resize(mask,(image_size,image_size),interpolation=cv2.INTER_NEAREST) + cv2.imwrite('{}/{}.png'.format(mask_path, j), mask) + print(j) + + print(counter, total) diff --git a/resnet.py b/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..aa2bf95130e9815ba378cb6f73207068b81a04b9 --- /dev/null +++ b/resnet.py @@ -0,0 +1,109 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as modelzoo + +# from modules.bn import InPlaceABNSync as BatchNorm2d + +resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + def __init__(self, in_chan, out_chan, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(in_chan, out_chan, stride) + self.bn1 = nn.BatchNorm2d(out_chan) + self.conv2 = conv3x3(out_chan, out_chan) + self.bn2 = nn.BatchNorm2d(out_chan) + self.relu = nn.ReLU(inplace=True) + self.downsample = None + if in_chan != out_chan or stride != 1: + self.downsample = nn.Sequential( + nn.Conv2d(in_chan, out_chan, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(out_chan), + ) + + def forward(self, x): + residual = self.conv1(x) + residual = F.relu(self.bn1(residual)) + residual = self.conv2(residual) + residual = self.bn2(residual) + + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x) + + out = shortcut + residual + out = self.relu(out) + return out + + +def create_layer_basic(in_chan, out_chan, bnum, stride=1): + layers = [BasicBlock(in_chan, out_chan, stride=stride)] + for i in range(bnum-1): + layers.append(BasicBlock(out_chan, out_chan, stride=1)) + return nn.Sequential(*layers) + + +class Resnet18(nn.Module): + def 
__init__(self): + super(Resnet18, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) + self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) + self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) + self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) + self.init_weight() + + def forward(self, x): + x = self.conv1(x) + x = F.relu(self.bn1(x)) + x = self.maxpool(x) + + x = self.layer1(x) + feat8 = self.layer2(x) # 1/8 + feat16 = self.layer3(feat8) # 1/16 + feat32 = self.layer4(feat16) # 1/32 + return feat8, feat16, feat32 + + def init_weight(self): + state_dict = modelzoo.load_url(resnet18_url) + self_state_dict = self.state_dict() + for k, v in state_dict.items(): + if 'fc' in k: continue + self_state_dict.update({k: v}) + self.load_state_dict(self_state_dict) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv2d)): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +if __name__ == "__main__": + net = Resnet18() + x = torch.randn(16, 3, 224, 224) + out = net(x) + print(out[0].size()) + print(out[1].size()) + print(out[2].size()) + net.get_params() diff --git a/samples/sample.gif b/samples/sample.gif new file mode 100644 index 0000000000000000000000000000000000000000..ce27992758a29c438ceba0b768f64c5af5d67530 Binary files /dev/null and b/samples/sample.gif differ diff --git a/samples/t.jpg b/samples/t.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4d970f352b53768dd167a4b745d59c2180dc2981 Binary files /dev/null and b/samples/t.jpg differ diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..244e26cd7e82744fb5508f232e511feaad5322eb --- /dev/null +++ b/train.py @@ -0,0 +1,158 @@ +# -*- encoding: utf-8 -*- + +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +from model import BiSeNet +from face_dataset import FaceMask +from loss import OhemCELoss +import torch.optim as Optimizer +import cv2 +import numpy as np + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +import torch.nn.functional as F + +import os.path as osp +import time +import datetime +import argparse + +def set_learning_rate(optimizer, lr): + for param_group in optimizer.param_groups: + param_group['lr'] = lr + +def train(fintune_model,image_size,lr0,path_data,model_exp): + + # dataset + n_classes = 19 + n_img_per_gpu = 16 + n_workers = 8 + cropsize = [int(image_size*0.85),int(image_size*0.85)] + + ds = FaceMask(path_data,img_size = image_size, cropsize=cropsize, mode='train') + # sampler = torch.utils.data.distributed.DistributedSampler(ds) + dl = DataLoader(ds, + batch_size = n_img_per_gpu, + shuffle = True, + num_workers = n_workers, + pin_memory = True, + drop_last = True) + + # model + ignore_idx = -100 + + use_cuda = torch.cuda.is_available() + device = torch.device("cuda:0" if use_cuda else "cpu") + net = BiSeNet(n_classes=n_classes) + net = net.to(device) + + if os.access(fintune_model,os.F_OK) and (fintune_model is not None):# checkpoint + chkpt = torch.load(fintune_model, map_location=device) + net.load_state_dict(chkpt) + print('load 
fintune model : {}'.format(fintune_model)) + else: + print('no fintune model') + + score_thres = 0.7 + n_min = n_img_per_gpu * cropsize[0] * cropsize[1]//16 + LossP = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx) + Loss2 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx) + Loss3 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx) + + ## optimizer + momentum = 0.9 + weight_decay = 5e-4 + lr_start = lr0 + max_epoch = 1000 + + optim = Optimizer.SGD( + net.parameters(), + lr = lr_start, + momentum = momentum, + weight_decay = weight_decay) + + ## train loop + msg_iter = 50 + loss_avg = [] + st = glob_st = time.time() + # diter = iter(dl) + epoch = 0 + flag_change_lr_cnt = 0 # 学习率更新计数器 + init_lr = lr_start # 学习率 + + best_loss = np.inf + loss_mean = 0. # 损失均值 + loss_idx = 0. # 损失计算计数器 + + print('start training ~') + it = 0 + for epoch in range(max_epoch): + net.train() + # 学习率更新策略 + if loss_mean!=0.: + if best_loss > (loss_mean/loss_idx): + flag_change_lr_cnt = 0 + best_loss = (loss_mean/loss_idx) + else: + flag_change_lr_cnt += 1 + + if flag_change_lr_cnt > 30: + init_lr = init_lr*0.1 + set_learning_rate(optimizer, init_lr) + flag_change_lr_cnt = 0 + + loss_mean = 0. # 损失均值 + loss_idx = 0. # 损失计算计数器 + + for i, (im, lb) in enumerate(dl): + + im = im.cuda() + lb = lb.cuda() + H, W = im.size()[2:] + lb = torch.squeeze(lb, 1) + + optim.zero_grad() + out, out16, out32 = net(im) + lossp = LossP(out, lb) + loss2 = Loss2(out16, lb) + loss3 = Loss3(out32, lb) + loss = lossp + loss2 + loss3 + + loss_mean += loss.item() + loss_idx += 1. + + loss.backward() + optim.step() + + + if it % msg_iter == 0: + + print('epoch <{}/{}> -->> <{}/{}> -> iter {} : loss {:.5f}, loss_mean :{:.5f}, best_loss :{:.5f},lr :{:.6f},batch_size : {}'.\ + format(epoch,max_epoch,i,int(ds.__len__()/n_img_per_gpu),it,loss.item(),loss_mean/loss_idx,best_loss,init_lr,n_img_per_gpu)) + # print(msg) + + if (it) % 500 == 0: + state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict() + torch.save(state, model_exp+'fp_{}_latest.pth'.format(image_size)) + it += 1 + torch.save(state, model_exp+'fp_{}_epoch-{}.pth'.format(image_size,epoch)) + +if __name__ == "__main__": + image_size = 512 + lr0 = 1e-4 + model_exp = './model_exp/' + path_data = './CelebAMask-HQ/' + if not osp.exists(model_exp): + os.makedirs(model_exp) + + + loc_time = time.localtime() + model_exp += time.strftime("%Y-%m-%d_%H-%M-%S", loc_time)+'/' + if not osp.exists(model_exp): + os.makedirs(model_exp) + + fintune_model = './weights/fp0.pth' + + train(fintune_model,image_size,lr0,path_data,model_exp) diff --git a/transform.py b/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..1be64768ba072be9897a961a956a11450ca8f158 --- /dev/null +++ b/transform.py @@ -0,0 +1,128 @@ +# -*- encoding: utf-8 -*- + + +from PIL import Image +import PIL.ImageEnhance as ImageEnhance +import random +import numpy as np + +class RandomCrop(object): + def __init__(self, size, *args, **kwargs): + self.size = size + + def __call__(self, im_lb): + im = im_lb['im'] + lb = im_lb['lb'] + assert im.size == lb.size + W, H = self.size + w, h = im.size + + if (W, H) == (w, h): return dict(im=im, lb=lb) + if w < W or h < H: + scale = float(W) / w if w < h else float(H) / h + w, h = int(scale * w + 1), int(scale * h + 1) + im = im.resize((w, h), Image.BILINEAR) + lb = lb.resize((w, h), Image.NEAREST) + sw, sh = random.random() * (w - W), random.random() * (h - H) + crop = int(sw), int(sh), int(sw) + W, 
int(sh) + H + return dict( + im = im.crop(crop), + lb = lb.crop(crop) + ) + + +class HorizontalFlip(object): + def __init__(self, p=0.5, *args, **kwargs): + self.p = p + + def __call__(self, im_lb): + if random.random() > self.p: + return im_lb + else: + im = im_lb['im'] + lb = im_lb['lb'] + + # atts = [1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye', 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r', + # 10 'nose', 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat'] + + flip_lb = np.array(lb) + flip_lb[lb == 2] = 3 + flip_lb[lb == 3] = 2 + flip_lb[lb == 4] = 5 + flip_lb[lb == 5] = 4 + flip_lb[lb == 7] = 8 + flip_lb[lb == 8] = 7 + flip_lb = Image.fromarray(flip_lb) + return dict(im = im.transpose(Image.FLIP_LEFT_RIGHT), + lb = flip_lb.transpose(Image.FLIP_LEFT_RIGHT), + ) + + +class RandomScale(object): + def __init__(self, scales=(1, ), *args, **kwargs): + self.scales = scales + + def __call__(self, im_lb): + im = im_lb['im'] + lb = im_lb['lb'] + W, H = im.size + scale = random.choice(self.scales) + w, h = int(W * scale), int(H * scale) + return dict(im = im.resize((w, h), Image.BILINEAR), + lb = lb.resize((w, h), Image.NEAREST), + ) + + +class ColorJitter(object): + def __init__(self, brightness=None, contrast=None, saturation=None, *args, **kwargs): + if not brightness is None and brightness>0: + self.brightness = [max(1-brightness, 0), 1+brightness] + if not contrast is None and contrast>0: + self.contrast = [max(1-contrast, 0), 1+contrast] + if not saturation is None and saturation>0: + self.saturation = [max(1-saturation, 0), 1+saturation] + + def __call__(self, im_lb): + im = im_lb['im'] + lb = im_lb['lb'] + r_brightness = random.uniform(self.brightness[0], self.brightness[1]) + r_contrast = random.uniform(self.contrast[0], self.contrast[1]) + r_saturation = random.uniform(self.saturation[0], self.saturation[1]) + im = ImageEnhance.Brightness(im).enhance(r_brightness) + im = ImageEnhance.Contrast(im).enhance(r_contrast) + im = ImageEnhance.Color(im).enhance(r_saturation) + return dict(im = im, + lb = lb, + ) + + +class MultiScale(object): + def __init__(self, scales): + self.scales = scales + + def __call__(self, img): + W, H = img.size + sizes = [(int(W*ratio), int(H*ratio)) for ratio in self.scales] + imgs = [] + [imgs.append(img.resize(size, Image.BILINEAR)) for size in sizes] + return imgs + + +class Compose(object): + def __init__(self, do_list): + self.do_list = do_list + + def __call__(self, im_lb): + for comp in self.do_list: + im_lb = comp(im_lb) + return im_lb + + + + +if __name__ == '__main__': + flip = HorizontalFlip(p = 1) + crop = RandomCrop((321, 321)) + rscales = RandomScale((0.75, 1.0, 1.5, 1.75, 2.0)) + img = Image.open('data/img.jpg') + lb = Image.open('data/label.png')
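+
+    # a minimal usage sketch (assuming data/img.jpg and data/label.png exist and share
+    # the same size): chain the transforms with Compose, similar to FaceMask.trans_train
+    # in face_dataset.py, and check that image and label stay aligned after augmentation
+    im_lb = Compose([flip, rscales, crop])(dict(im=img, lb=lb))
+    print(im_lb['im'].size, im_lb['lb'].size)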