diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a8752db2bd808e5b6f1c68dd013d809111c2ea89
--- /dev/null
+++ b/README.md
@@ -0,0 +1,44 @@
+# facial_landmark
+98-point facial landmark detection
+
+## Overview
+Note: this project does not include the hand-detection component; the hand-detection project lives at https://codechina.csdn.net/EricLee/yolo_v3
+* Image example:
+![image](https://codechina.csdn.net/EricLee/facial_landmark/-/raw/master/samples/6.jpg)
+* Video example:
+![video](https://codechina.csdn.net/EricLee/facial_landmark/-/raw/master/samples/sample.gif)
+
+## Requirements
+* Author's development environment:
+  * Python 3.7
+  * PyTorch >= 1.5.1
+
+## Dataset
+* Official dataset page:
+  https://wywu.github.io/projects/LAB/WFLW.html
+* [Download link for the training set prepared for this project (Baidu Netdisk, password: qruc)](https://pan.baidu.com/s/1DyFDviOEtmk0gb4N0cYHEw)
+
+```
+ @inproceedings{wayne2018lab,
+ author = {Wu, Wayne and Qian, Chen and Yang, Shuo and Wang, Quan and Cai, Yici and Zhou, Qiang},
+ title = {Look at Boundary: A Boundary-Aware Face Alignment Algorithm},
+ booktitle = {CVPR},
+ month = June,
+ year = {2018}
+ }
+```
+
+The dataset's facial landmark layout is shown below (please contact us for removal if the image infringes any rights):
+![image](https://codechina.csdn.net/EricLee/facial_landmark/-/raw/master/WFLW_annotation.png)
+
+
+## Pretrained Model
+* [Pretrained model download (Baidu Netdisk, password: 5twg)](https://pan.baidu.com/s/1Psz-xsb3S07A1hnz0wQ4fw)
+
+## Usage
+
+### Training
+* Run from the repository root: python train.py (check the parameter settings inside the script)
+
+### Inference
+* Run from the repository root: python inference.py (check the parameter settings inside the script)
diff --git a/WFLW_annotation.png b/WFLW_annotation.png
new file mode 100644
index 0000000000000000000000000000000000000000..766c09b4a30cf4f5086d10e329edd2cc51fc61d1
Binary files /dev/null and b/WFLW_annotation.png differ
diff --git a/data_iter/data_agu.py b/data_iter/data_agu.py
new file mode 100644
index 0000000000000000000000000000000000000000..d850a87e71452de707144796aefb2648ab9ca736
--- /dev/null
+++ b/data_iter/data_agu.py
@@ -0,0 +1,141 @@
+#-*-coding:utf-8-*-
+# date:2019-05-20
+# Author: Eric.Lee
+# function: face rotation image augmentation
+
+import cv2
+import numpy as np
+import random
+
+# Lookup table mapping each landmark index to its mirrored index for horizontal flips.
+flip_landmarks_dict = {
+    0:32,1:31,2:30,3:29,4:28,5:27,6:26,7:25,8:24,9:23,10:22,11:21,12:20,13:19,14:18,15:17,
+    16:16,17:15,18:14,19:13,20:12,21:11,22:10,23:9,24:8,25:7,26:6,27:5,28:4,29:3,30:2,31:1,32:0,
+    33:46,34:45,35:44,36:43,37:42,38:50,39:49,40:48,41:47,
+    46:33,45:34,44:35,43:36,42:37,50:38,49:39,48:40,47:41,
+    60:72,61:71,62:70,63:69,64:68,65:75,66:74,67:73,
+    72:60,71:61,70:62,69:63,68:64,75:65,74:66,73:67,
+    96:97,97:96,
+    51:51,52:52,53:53,54:54,
+    55:59,56:58,57:57,58:56,59:55,
+    76:82,77:81,78:80,79:79,80:78,81:77,82:76,
+    87:83,86:84,85:85,84:86,83:87,
+    88:92,89:91,90:90,91:89,92:88,
+    95:93,94:94,93:95
+    }
+
+# Aspect-ratio-preserving resize with constant padding (no distortion).
+def letterbox(img_, img_size=256, mean_rgb=(128,128,128)):
+
+    shape_ = img_.shape[:2]  # shape = [height, width]
+    ratio = float(img_size) / max(shape_)  # ratio = new / old
+    new_shape_ = (round(shape_[1] * ratio), round(shape_[0] * ratio))
+    dw_ = (img_size - new_shape_[0]) / 2  # width padding
+    dh_ = (img_size - new_shape_[1]) / 2  # height padding
+    top_, bottom_ = round(dh_ - 0.1), round(dh_ + 0.1)
+    left_, right_ = round(dw_ - 0.1), round(dw_ + 0.1)
+    # resize img
+    img_a = cv2.resize(img_, new_shape_, interpolation=cv2.INTER_LINEAR)
+
+    img_a = cv2.copyMakeBorder(img_a, top_, bottom_, left_, right_, cv2.BORDER_CONSTANT, value=mean_rgb)  # padded square
+
+    return img_a
+
+# Replicate the grayscale channel into all three channels (color-drop augmentation).
+def img_agu_channel_same(img_):
+    img_a = np.zeros(img_.shape, dtype=np.uint8)
+    gray = cv2.cvtColor(img_, cv2.COLOR_RGB2GRAY)
+    img_a[:,:,0] = gray
+    img_a[:,:,1] = gray
+    img_a[:,:,2] = gray
+
+    return img_a
+
+# Random rotation / crop / flip augmentation.
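+# The rotation is taken about the midpoint of the two eye centers; the output
+# canvas is enlarged to (nW, nH) so the rotated face is never clipped, and the
+# landmarks are transformed with the same affine matrix M as the pixels.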
+def face_random_rotate(image, pts, angle, Eye_Left, Eye_Right, fix_res=True, img_size=(256,256), vis=False):
+    cx,cy = (Eye_Left[0] + Eye_Right[0]) / 2, (Eye_Left[1] + Eye_Right[1]) / 2
+    (h , w) = image.shape[:2]
+    M = cv2.getRotationMatrix2D((cx , cy) , angle , 1.0)
+    cos = np.abs(M[0 , 0])
+    sin = np.abs(M[0 , 1])
+
+    # Compute the bounding size of the rotated image.
+    nW = int((h * sin) + (w * cos))
+    nH = int((h * cos) + (w * sin))
+
+    M[0 , 2] += int(0.5 * nW) - cx
+    M[1 , 2] += int(0.5 * nH) - cy
+
+    resize_model = [cv2.INTER_LINEAR,cv2.INTER_CUBIC,cv2.INTER_NEAREST,cv2.INTER_AREA,cv2.INTER_LANCZOS4]
+
+    img_rot = cv2.warpAffine(image , M , (nW , nH), flags=resize_model[random.randint(0,4)])
+    # flags : INTER_LINEAR INTER_CUBIC INTER_NEAREST
+    # borderMode : BORDER_REFLECT BORDER_TRANSPARENT BORDER_REPLICATE CV_BORDER_WRAP BORDER_CONSTANT
+
+    # Apply the same affine transform to the landmarks.
+    pts_r = []
+    for pt in pts:
+        x = float(pt[0])
+        y = float(pt[1])
+
+        x_r = (x*M[0][0] + y*M[0][1] + M[0][2])
+        y_r = (x*M[1][0] + y*M[1][1] + M[1][2])
+
+        pts_r.append([x_r,y_r])
+
+    x = [pt[0] for pt in pts_r]
+    y = [pt[1] for pt in pts_r]
+
+    x1,y1,x2,y2 = np.min(x),np.min(y),np.max(x),np.max(y)
+
+    # Randomly jitter the crop box around the landmark bounding box.
+    translation_pixels = 60
+
+    scaling = 0.3
+    x1 += random.randint(-int(max((x2-x1)*scaling,translation_pixels)),int((x2-x1)*0.25))
+    y1 += random.randint(-int(max((y2-y1)*scaling,translation_pixels)),int((y2-y1)*0.25))
+    x2 += random.randint(-int((x2-x1)*0.15),int(max((x2-x1)*scaling,translation_pixels)))
+    y2 += random.randint(-int((y2-y1)*0.15),int(max((y2-y1)*scaling,translation_pixels)))
+
+    x1,y1,x2,y2 = int(x1),int(y1),int(x2),int(y2)
+    x1 = int(max(0,x1))
+    y1 = int(max(0,y1))
+    x2 = int(min(x2,img_rot.shape[1]-1))
+    y2 = int(min(y2,img_rot.shape[0]-1))
+
+    crop_rot = img_rot[y1:y2,x1:x2,:]
+
+    crop_pts = []
+    width_crop = float(x2-x1)
+    height_crop = float(y2-y1)
+    for pt in pts_r:
+        x = pt[0]
+        y = pt[1]
+        crop_pts.append([float(x-x1)/width_crop,float(y-y1)/height_crop])  # normalize to [0,1] within the crop
+
+    # Random horizontal flip; landmark indices are remapped via flip_landmarks_dict.
+    if random.random() >= 0.5:
+        crop_rot = cv2.flip(crop_rot,1)
+        crop_pts_flip = []
+        for i in range(len(crop_pts)):
+            x = 1. - crop_pts[flip_landmarks_dict[i]][0]
+            y = crop_pts[flip_landmarks_dict[i]][1]
+            crop_pts_flip.append([x,y])
+        crop_pts = crop_pts_flip
+
+    if vis:
+        for pt in crop_pts:
+            x = int(pt[0]*width_crop)
+            y = int(pt[1]*height_crop)
+
+            cv2.circle(crop_rot, (int(x),int(y)), 2, (255,0,255),-1)
+
+    if fix_res:
+        crop_rot = letterbox(crop_rot, img_size=img_size[0], mean_rgb=(128,128,128))
+    else:
+        crop_rot = cv2.resize(crop_rot, img_size, interpolation=resize_model[random.randint(0,4)])
+
+    return crop_rot,crop_pts
diff --git a/data_iter/datasets.py b/data_iter/datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..974f6bf4398a37b836dcc7d5d8c864b6fdd60822
--- /dev/null
+++ b/data_iter/datasets.py
@@ -0,0 +1,117 @@
+#-*-coding:utf-8-*-
+# date:2019-05-20
+# Author: Eric.Lee
+# function: data iterator
+
+import glob
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+from PIL import Image
+# import matplotlib.pyplot as plt
+from tqdm import tqdm
+import cv2
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+from torch.utils.data import DataLoader
+from data_iter.data_agu import *
+import json
+
+# Image whitening (per-image standardization).
+def prewhiten(x):
+    mean = np.mean(x)
+    std = np.std(x)
+    std_adj = np.maximum(std, 1.0 / np.sqrt(x.size))
+    y = np.multiply(np.subtract(x, mean), 1 / std_adj)
+    return y
+
+# Brightness / contrast augmentation: brightness adds b to every channel of every pixel.
+def contrast_img(img, c, b):
+    rows, cols, channels = img.shape
+    # Blend with an all-zero (black) image: dst = c*img + b.
+    blank = np.zeros([rows, cols, channels], img.dtype)
+    dst = cv2.addWeighted(img, c, blank, 1-c, b)
+    return dst
+
+class LoadImagesAndLabels(Dataset):
+    def __init__(self, ops, img_size=(224,224), flag_agu=False, fix_res=True, vis=False):
+        print('img_size (height,width) : ',img_size[0],img_size[1])
+        r_ = open(ops.train_list,'r')
+        lines = r_.readlines()
+        idx = 0
+        file_list = []
+        landmarks_list = []
+        for line in lines:
+            # Each line: 196 landmark values, 4 bbox values, 6 attributes, image path.
+            msg = line.strip().split(' ')
+            idx += 1
+            landmarks = msg[0:196]
+            bbox = msg[196:200]
+            attributes = msg[200:206]
+            img_file = msg[206]
+            pts = []
+            for i in range(int(len(landmarks)/2)):
+                x = float(landmarks[i*2+0])
+                y = float(landmarks[i*2+1])
+                pts.append([x,y])  # landmarks in global (image) coordinates
+
+            landmarks_list.append(pts)
+            file_list.append(ops.images_path+img_file)
+
+        self.files = file_list
+        self.landmarks = landmarks_list
+        self.img_size = img_size
+        self.flag_agu = flag_agu
+        self.fix_res = fix_res
+        self.vis = vis
+
+    def __len__(self):
+        return len(self.files)
+
+    def __getitem__(self, index):
+        img_path = self.files[index]
+        pts = self.landmarks[index]
+        img = cv2.imread(img_path)  # BGR
+        if self.flag_agu == True:
+            # Eye centers are the mean of landmarks 60-67 (left) and 68-75 (right).
+            left_eye = np.average(pts[60:68], axis=0)
+            right_eye = np.average(pts[68:76], axis=0)
+
+            angle_random = random.randint(-36,36)
+            # Returns the cropped face and landmarks normalized to the crop.
+            img_, landmarks_ = face_random_rotate(img, pts, angle_random, left_eye, right_eye,
+                fix_res=self.fix_res, img_size=self.img_size, vis=False)
+        if self.flag_agu == True:
+            if random.random() > 0.5:
+                c = float(random.randint(50,150))/100.
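+                # c in [0.5, 1.5] scales contrast and b in [-20, 20] shifts brightness (see contrast_img above).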
+                b = random.randint(-20,20)
+                img_ = contrast_img(img_, c, b)
+        if self.flag_agu == True:
+            if random.random() > 0.7:
+                # Random hue shift in HSV space.
+                img_hsv = cv2.cvtColor(img_, cv2.COLOR_BGR2HSV)
+                hue_x = random.randint(-10,10)
+                # Cast to int before shifting so uint8 arithmetic cannot wrap around; OpenCV hue range is 0~180.
+                img_hsv[:,:,0] = np.clip(img_hsv[:,:,0].astype(np.int32) + hue_x, 0, 180).astype(np.uint8)
+                img_ = cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR)
+        if self.flag_agu == True:
+            if random.random() > 0.8:
+                img_ = img_agu_channel_same(img_)
+        if self.vis == True:
+            cv2.namedWindow('crop',0)
+            cv2.imshow('crop',img_)
+            cv2.waitKey(1)
+        img_ = img_.astype(np.float32)
+        img_ = (img_-128.)/256.
+        img_ = img_.transpose(2, 0, 1)
+        landmarks_ = np.array(landmarks_).ravel()
+        return img_,landmarks_
diff --git a/image/1.jpg b/image/1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3ae60f9dc48993db28b41793273e2be289695272
Binary files /dev/null and b/image/1.jpg differ
diff --git a/image/10.jpg b/image/10.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..123ef639a778cb985b438846481d0a25af59064d
Binary files /dev/null and b/image/10.jpg differ
diff --git a/image/11.jpg b/image/11.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f4f3df43e6fd93c812ec4c26c41b0b021976d1a8
Binary files /dev/null and b/image/11.jpg differ
diff --git a/image/12.jpg b/image/12.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..93c2318e2c39b671ed714e8fabf04ed08dbe108c
Binary files /dev/null and b/image/12.jpg differ
diff --git a/image/2.jpg b/image/2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..fbd541766a767d92802e17368ca9f9d37ab5b8aa
Binary files /dev/null and b/image/2.jpg differ
diff --git a/image/3.jpg b/image/3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..35705991168fecd02f34d826749f66dec2de09ee
Binary files /dev/null and b/image/3.jpg differ
diff --git a/image/4.jpg b/image/4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e1c97c476411f078a9d2d4af469a8b0b182fbd08
Binary files /dev/null and b/image/4.jpg differ
diff --git a/image/5.jpg b/image/5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..96012f4c7117bee2c2a3f9fcec8558c7ff3fd81f
Binary files /dev/null and b/image/5.jpg differ
diff --git a/image/6.jpg b/image/6.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..188ffc964cf8b3941f4b0c68a1fb961144715081
Binary files /dev/null and b/image/6.jpg differ
diff --git a/image/7.jpg b/image/7.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e907d75af543df77043a64966d5544e5f59c19bd
Binary files /dev/null and b/image/7.jpg differ
diff --git a/image/8.jpg b/image/8.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5a8d00b1b760698aadd8c5739f83d481c9a4e68a
Binary files /dev/null and b/image/8.jpg differ
diff --git a/image/9.jpg b/image/9.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..487acba63e2bc9734e9a7d731c3186b89c146324
Binary files /dev/null and b/image/9.jpg differ
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..52cc2fc2f865dbaaef040ed1f1caaf887069e887
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,125 @@
+#-*-coding:utf-8-*-
+# date:2020-04-25
+# Author: Eric.Lee
+# function: inference
+
+import os
+import argparse
+import torch
+import torch.nn as nn
+from data_iter.datasets import letterbox
+import numpy as np
+
+import math
+import cv2
+import torch.nn.functional as F
+
+from models.resnet import resnet50, resnet34
+from utils.common_utils import *
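+# The network regresses 196 values: 98 (x, y) landmark pairs, each normalized
+# to [0, 1] relative to the input crop, so they are rescaled by the original
+# image width/height before drawing.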
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(description=' Project Landmarks Test')
+
+    parser.add_argument('--test_model', type=str, default = './model_exp/2021-02-21_17-51-30/resnet_50-epoch-724.pth',
+        help = 'test_model')  # model checkpoint path
+    parser.add_argument('--model', type=str, default = 'resnet_50',
+        help = 'model : resnet_x')  # model type
+    parser.add_argument('--num_classes', type=int , default = 196,
+        help = 'num_classes')  # number of regression outputs (98 landmarks * 2)
+    parser.add_argument('--GPUS', type=str, default = '0',
+        help = 'GPUS')  # GPU selection
+    parser.add_argument('--test_path', type=str, default = './image/',
+        help = 'test_path')  # test image folder
+    parser.add_argument('--img_size', type=tuple , default = (256,256),
+        help = 'img_size')  # model input size
+    parser.add_argument('--fix_res', type=bool , default = False,
+        help = 'fix_resolution')  # whether to preserve the aspect ratio of input samples
+    parser.add_argument('--vis', type=bool , default = True,
+        help = 'vis')  # whether to visualize results
+
+    print('\n/******************* {} ******************/\n'.format(parser.description))
+    #--------------------------------------------------------------------------
+    ops = parser.parse_args()  # parse the arguments
+    #--------------------------------------------------------------------------
+    print('----------------------------------')
+
+    unparsed = vars(ops)  # parse_args() returns a Namespace; vars() converts it to a dict
+    for key in unparsed.keys():
+        print('{} : {}'.format(key,unparsed[key]))
+
+    #---------------------------------------------------------------------------
+    os.environ['CUDA_VISIBLE_DEVICES'] = ops.GPUS
+
+    test_path = ops.test_path  # test image folder path
+
+    #---------------------------------------------------------------- Build the model
+    print('use model : %s'%(ops.model))
+
+    if ops.model == 'resnet_50':
+        model_ = resnet50(num_classes=ops.num_classes, img_size=ops.img_size[0])
+    elif ops.model == 'resnet_34':
+        model_ = resnet34(num_classes=ops.num_classes, img_size=ops.img_size[0])
+
+    use_cuda = torch.cuda.is_available()
+
+    device = torch.device("cuda:0" if use_cuda else "cpu")
+    model_ = model_.to(device)
+    model_.eval()  # switch to inference mode
+
+    # print(model_)  # print the model structure
+
+    # Load the test checkpoint.
+    if os.access(ops.test_model,os.F_OK):
+        chkpt = torch.load(ops.test_model, map_location=device)
+        model_.load_state_dict(chkpt)
+        print('load test model : {}'.format(ops.test_model))
+
+    #---------------------------------------------------------------- Predict images
+    font = cv2.FONT_HERSHEY_SIMPLEX
+    with torch.no_grad():
+        idx = 0
+        for file in os.listdir(ops.test_path):
+            if '.jpg' not in file:
+                continue
+            idx += 1
+            print('{}) image : {}'.format(idx,file))
+            img = cv2.imread(ops.test_path + file)
+            img_width = img.shape[1]
+            img_height = img.shape[0]
+            # Preprocess the input image.
+            if ops.fix_res:
+                img_ = letterbox(img, img_size=ops.img_size[0], mean_rgb=(128,128,128))
+            else:
+                img_ = cv2.resize(img, (ops.img_size[1],ops.img_size[0]), interpolation=cv2.INTER_CUBIC)
+
+            img_ = img_.astype(np.float32)
+            img_ = (img_-128.)/256.
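+            # Map pixel values from [0, 255] to roughly [-0.5, 0.5], matching the training-time preprocessing in datasets.py.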
+            img_ = img_.transpose(2, 0, 1)
+            img_ = torch.from_numpy(img_)
+            img_ = img_.unsqueeze_(0)
+
+            if use_cuda:
+                img_ = img_.cuda()  # (bs, 3, h, w)
+
+            pre_ = model_(img_.float())
+            output = pre_.cpu().detach().numpy()
+            output = np.squeeze(output)
+            dict_landmarks = draw_landmarks(img, output, draw_circle=False)
+
+            draw_contour(img, dict_landmarks)
+
+            if ops.vis:
+                cv2.namedWindow('image',0)
+                cv2.imshow('image',img)
+                cv2.imwrite("./samples/"+file,img)
+                if cv2.waitKey(1000) == 27:
+                    break
+
+        cv2.destroyAllWindows()
+
+        print('well done ')
diff --git a/loss/loss.py b/loss/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..69a506723512b72228cf02a4cf46095e193e50ff
--- /dev/null
+++ b/loss/loss.py
@@ -0,0 +1,36 @@
+#-*-coding:utf-8-*-
+# date:2019-05-20
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import os
+import math
+
+
+def wing_loss(landmarks, labels, w=0.06, epsilon=0.01):
+    """
+    Arguments:
+        landmarks, labels: float tensors with shape [batch_size, num_landmarks*2],
+            laid out as x0, y0, x1, y1, ...
+        w, epsilon: float hyperparameters of the wing loss.
+    Returns:
+        a scalar float tensor.
+    """
+    x = landmarks - labels
+    c = w * (1.0 - math.log(1.0 + w / epsilon))
+    absolute_x = torch.abs(x)
+
+    losses = torch.where(w > absolute_x,
+                         w * torch.log(1.0 + absolute_x / epsilon),
+                         absolute_x - c)
+
+    losses = torch.mean(losses, dim=1, keepdim=True)
+    loss = torch.mean(losses)
+    return loss
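+# For |x| < w the loss is logarithmic, w*ln(1 + |x|/epsilon); beyond that it is
+# linear, |x| - c. The constant c = w*(1 - ln(1 + w/epsilon)) makes the two
+# pieces meet at |x| = w, so small residuals keep a relatively strong gradient.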
+def got_total_wing_loss(output, crop_landmarks):
+    loss = wing_loss(output, crop_landmarks)
+
+    return loss
diff --git a/models/resnet.py b/models/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..433b22f51f92dfe5e3e0f4d503c1d64c54c3dadd
--- /dev/null
+++ b/models/resnet.py
@@ -0,0 +1,262 @@
+import torch
+import torch.nn as nn
+import math
+import torch.utils.model_zoo as model_zoo
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+           'resnet152']
+
+
+model_urls = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+                               padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * 4)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+
+    def __init__(self, block, layers, num_classes=1000, img_size=224, dropout_factor=1.):
+        self.inplanes = 64
+        self.dropout_factor = dropout_factor
+        super(ResNet, self).__init__()
+
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace=True)
+        # see this issue: https://github.com/xxradon/PytorchToCaffe/issues/16
+        # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+
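+        # The stem plus layer1-layer4 downsample by 32x overall, so img_size must
+        # be a multiple of 32; the average pool below then reduces the final
+        # feature map to 1x1 before the fully connected regression head.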
+        assert img_size % 32 == 0
+        pool_kernel = int(img_size / 32)
+        self.avgpool = nn.AvgPool2d(pool_kernel, stride=1, ceil_mode=True)
+
+        self.dropout = nn.Dropout(self.dropout_factor)
+
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+
+        x = self.dropout(x)
+
+        x = self.fc(x)
+
+        return x
+
+
+def load_model(model, pretrained_state_dict):
+    model_dict = model.state_dict()
+    pretrained_dict = {k: v for k, v in pretrained_state_dict.items() if
+                       k in model_dict and model_dict[k].size() == pretrained_state_dict[k].size()}
+    model.load_state_dict(pretrained_dict, strict=False)
+    if len(pretrained_dict) == 0:
+        print("[INFO] No params were loaded ...")
+    else:
+        for k, v in pretrained_state_dict.items():
+            if k in pretrained_dict:
+                print("==>> Load {} {}".format(k, v.size()))
+            else:
+                print("[INFO] Skip {} {}".format(k, v.size()))
+    return model
+
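+# Note: load_model only copies parameters whose names and shapes match, so the
+# ImageNet fc layer (1000 outputs) is skipped automatically when num_classes is
+# 196 for landmark regression.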
+
+def resnet18(pretrained=False, **kwargs):
+    """Constructs a ResNet-18 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
+    if pretrained:
+        print("Load pretrained model from {}".format(model_urls['resnet18']))
+        pretrained_state_dict = model_zoo.load_url(model_urls['resnet18'])
+        model = load_model(model, pretrained_state_dict)
+    return model
+
+
+def resnet34(pretrained=False, **kwargs):
+    """Constructs a ResNet-34 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
+    if pretrained:
+        print("Load pretrained model from {}".format(model_urls['resnet34']))
+        pretrained_state_dict = model_zoo.load_url(model_urls['resnet34'])
+        model = load_model(model, pretrained_state_dict)
+    return model
+
+
+def resnet50(pretrained=False, **kwargs):
+    """Constructs a ResNet-50 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
+    if pretrained:
+        print("Load pretrained model from {}".format(model_urls['resnet50']))
+        pretrained_state_dict = model_zoo.load_url(model_urls['resnet50'])
+        model = load_model(model, pretrained_state_dict)
+    return model
+
+
+def resnet101(pretrained=False, **kwargs):
+    """Constructs a ResNet-101 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
+    if pretrained:
+        print("Load pretrained model from {}".format(model_urls['resnet101']))
+        pretrained_state_dict = model_zoo.load_url(model_urls['resnet101'])
+        model = load_model(model, pretrained_state_dict)
+    return model
+
+
+def resnet152(pretrained=False, **kwargs):
+    """Constructs a ResNet-152 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
+    if pretrained:
+        print("Load pretrained model from {}".format(model_urls['resnet152']))
+        pretrained_state_dict = model_zoo.load_url(model_urls['resnet152'])
+        model = load_model(model, pretrained_state_dict)
+    return model
+
+if __name__ == "__main__":
+    input = torch.randn([32, 3, 256, 256])
+    model = resnet34(False, num_classes=2, img_size=256)
+    output = model(input)
+    print(output.size())
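+    # Expected output for this smoke test: torch.Size([32, 2]).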
diff --git a/samples/1.jpg b/samples/1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..65a1368f884bd82ff8cf5220abfd6c07b4691deb
Binary files /dev/null and b/samples/1.jpg differ
diff --git a/samples/10.jpg b/samples/10.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..87094256e051b72d0327f9840bb1bab1e0b5a3c8
Binary files /dev/null and b/samples/10.jpg differ
diff --git a/samples/11.jpg b/samples/11.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b4a389845f93bd70f11ff7eba167b8839f4826db
Binary files /dev/null and b/samples/11.jpg differ
diff --git a/samples/12.jpg b/samples/12.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9f026d6d24b52e17bb418f1255043a781e94328c
Binary files /dev/null and b/samples/12.jpg differ
diff --git a/samples/2.jpg b/samples/2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b1a8bebca8fada93efd2e1bdf3a2c138d60e6dcc
Binary files /dev/null and b/samples/2.jpg differ
diff --git a/samples/3.jpg b/samples/3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..413fbb42c0d12e52d43e9872c879e07dfe6b1f75
Binary files /dev/null and b/samples/3.jpg differ
diff --git a/samples/4.jpg b/samples/4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4b7beaf2b475b6e8cc5280be24a19f8c4d0488e1
Binary files /dev/null and b/samples/4.jpg differ
diff --git a/samples/5.jpg b/samples/5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5de8d0e8a1c3261c8742e6ef32b9dc771b891ddb
Binary files /dev/null and b/samples/5.jpg differ
diff --git a/samples/6.jpg b/samples/6.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c19d284d5154278c3ebb7ca7757f4603a0c54860
Binary files /dev/null and b/samples/6.jpg differ
diff --git a/samples/7.jpg b/samples/7.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..693653ddac2aa66d27c8722dd92bc81ac72dbe05
Binary files /dev/null and b/samples/7.jpg differ
diff --git a/samples/8.jpg b/samples/8.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..570a94072c6afeabb1bb3eaff9c0960424068df4
Binary files /dev/null and b/samples/8.jpg differ
diff --git a/samples/9.jpg b/samples/9.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..122871e5894bf19d002c941949a10752f34cb0c3
Binary files /dev/null and b/samples/9.jpg differ
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..68a4ab4046a69cebd7798c72189a219f2d981b51
--- /dev/null
+++ b/train.py
@@ -0,0 +1,223 @@
+#-*-coding:utf-8-*-
+# date:2020-04-24
+# Author: Eric.Lee
+## function: train
+
+import os
+import argparse
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import sys
+
+from utils.model_utils import *
+from utils.common_utils import *
+from data_iter.datasets import *
+
+from models.resnet import resnet50, resnet34
+from loss.loss import *
+import cv2
+import time
+from datetime import datetime
+
+def trainer(ops,f_log):
+    try:
+        os.environ['CUDA_VISIBLE_DEVICES'] = ops.GPUS
+
+        if ops.log_flag:
+            sys.stdout = f_log
+
+        set_seed(ops.seed)
+        #---------------------------------------------------------------- Build the model
+        print('use model : %s'%(ops.model))
+
+        if ops.model == 'resnet_50':
+            model_ = resnet50(pretrained=ops.pretrained, num_classes=ops.num_classes, img_size=ops.img_size[0], dropout_factor=ops.dropout)
+        elif ops.model == 'resnet_34':
+            model_ = resnet34(pretrained=ops.pretrained, num_classes=ops.num_classes, img_size=ops.img_size[0], dropout_factor=ops.dropout)
+
+        use_cuda = torch.cuda.is_available()
+
+        device = torch.device("cuda:0" if use_cuda else "cpu")
+        model_ = model_.to(device)
+
+        # print(model_)  # print the model structure
+        # Dataset
+        dataset = LoadImagesAndLabels(ops=ops, img_size=ops.img_size, flag_agu=ops.flag_agu, fix_res=ops.fix_res, vis=False)
+        print('WFLW dataset loaded')
+
+        print('len train datasets : %s'%(dataset.__len__()))
+        # Dataloader
+        dataloader = DataLoader(dataset,
+                                batch_size=ops.batch_size,
+                                num_workers=ops.num_workers,
+                                shuffle=True,
+                                pin_memory=False,
+                                drop_last=True)
+        # Optimizer
+        optimizer_SGD = optim.SGD(model_.parameters(), lr=ops.init_lr, momentum=ops.momentum, weight_decay=ops.weight_decay)
+        optimizer = optimizer_SGD
+        # Load a finetune checkpoint if one exists.
+        if os.access(ops.fintune_model,os.F_OK):
+            chkpt = torch.load(ops.fintune_model, map_location=device)
+            model_.load_state_dict(chkpt)
+            print('load fintune model : {}'.format(ops.fintune_model))
+
+        print('/**********************************************/')
+        # Loss function
+        if ops.loss_define != 'wing_loss':
+            criterion = nn.MSELoss(reduction='mean')
+
+        step = 0
+        idx = 0
+
+        # Running statistics
+        best_loss = np.inf
+        loss_mean = 0.  # running mean loss
+        loss_idx = 0.   # number of loss samples
+        flag_change_lr_cnt = 0  # epochs since the last improvement
+        init_lr = ops.init_lr  # current learning rate
+
+        epochs_loss_dict = {}
+
+        for epoch in range(0, ops.epochs):
+            if ops.log_flag:
+                sys.stdout = f_log
+            print('\nepoch %d ------>>>'%epoch)
+            model_.train()
+            # Learning-rate update policy: decay when the epoch-mean loss
+            # has not improved for more than 20 epochs.
+            if loss_mean != 0.:
+                if best_loss > (loss_mean/loss_idx):
+                    flag_change_lr_cnt = 0
+                    best_loss = (loss_mean/loss_idx)
+                else:
+                    flag_change_lr_cnt += 1
+
+                if flag_change_lr_cnt > 20:
+                    init_lr = init_lr*ops.lr_decay
+                    set_learning_rate(optimizer, init_lr)
+                    flag_change_lr_cnt = 0
+
+            loss_mean = 0.  # reset running mean loss
+            loss_idx = 0.   # reset loss counter
+
+            for i, (imgs_, pts_) in enumerate(dataloader):
+                if use_cuda:
+                    imgs_ = imgs_.cuda()  # PyTorch input layout: (batch, channel, height, width)
+                    pts_ = pts_.cuda()
+
+                output = model_(imgs_.float())
+                if ops.loss_define == 'wing_loss':
+                    # Loss weights for individual face regions can be tuned here.
+                    loss = got_total_wing_loss(output, pts_.float())
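+                    # Index mapping (x,y interleaved, so point i occupies channels 2i and 2i+1):
+                    # [:,192:196] = eye centers (points 96-97), [:,120:152] = eyes (points 60-75),
+                    # [:,102:120] = nose (points 51-59).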
+                    loss_eye_center = got_total_wing_loss(output[:,192:196], pts_[:,192:196].float())*0.06
+                    loss_eye = got_total_wing_loss(output[:,120:152], pts_[:,120:152].float())*0.15
+                    loss_nose = got_total_wing_loss(output[:,102:120], pts_[:,102:120].float())*0.08
+                    loss += loss_eye_center + loss_eye + loss_nose
+                else:
+                    loss = criterion(output, pts_.float())
+                loss_mean += loss.item()
+                loss_idx += 1.
+                if i%10 == 0:
+                    loc_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+                    print('  %s - %s - epoch [%s/%s] (%s/%s):'%(loc_time,ops.model,epoch,ops.epochs,i,int(dataset.__len__()/ops.batch_size)),\
+                        'Mean Loss : %.6f - Loss: %.6f'%(loss_mean/loss_idx,loss.item()),\
+                        ' lr : %.7f'%init_lr,' bs :',ops.batch_size,\
+                        ' img_size: %s x %s'%(ops.img_size[0],ops.img_size[1]),' best_loss: %.6f'%best_loss)
+                # Backpropagate.
+                loss.backward()
+                # Update the model parameters.
+                optimizer.step()
+                # Zero the gradients.
+                optimizer.zero_grad()
+                step += 1
+
+            torch.save(model_.state_dict(), ops.model_exp + '{}-epoch-{}.pth'.format(ops.model,epoch))
+
+    except Exception as e:
+        print('Exception : ',e)  # print the exception
+        print('Exception file : ', e.__traceback__.tb_frame.f_globals['__file__'])  # file where it occurred
+        print('Exception line : ', e.__traceback__.tb_lineno)  # line where it occurred
+
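+# Example invocation (paths are hypothetical; adjust them to your setup):
+#   python train.py --GPUS 0 --batch_size 16 \
+#       --fintune_model ./model_exp/<run_dir>/resnet_50-epoch-103.pth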
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(description=' Project Classification Train')
+    parser.add_argument('--seed', type=int, default = 32,
+        help = 'seed')  # random seed
+    parser.add_argument('--model_exp', type=str, default = './model_exp',
+        help = 'model_exp')  # model output folder
+    parser.add_argument('--model', type=str, default = 'resnet_50',
+        help = 'model : resnet_50,resnet_34')  # model type
+    parser.add_argument('--num_classes', type=int , default = 196,
+        help = 'num_classes')  # number of landmarks * 2
+    parser.add_argument('--GPUS', type=str, default = '0',
+        help = 'GPUS')  # GPU selection
+
+    parser.add_argument('--images_path', type=str, default = './datasets/WFLW_images/',
+        help = 'images_path')  # image folder
+
+    parser.add_argument('--train_list', type=str,
+        default = './datasets/WFLW_annotations/list_98pt_rect_attr_train_test/list_98pt_rect_attr_train.txt',
+        help = 'annotations_train_list')  # training annotation list
+
+    parser.add_argument('--pretrained', type=bool, default = True,
+        help = 'imageNet_Pretrain')  # load ImageNet-pretrained weights
+    parser.add_argument('--fintune_model', type=str, default = './model_exp/2021-02-21_17-51-10/resnet_50-epoch-103.pth',
+        help = 'fintune_model')  # finetune checkpoint
+    parser.add_argument('--loss_define', type=str, default = 'wing_loss',
+        help = 'define_loss')  # loss function choice
+    parser.add_argument('--init_lr', type=float, default = 1e-5,
+        help = 'init_learningRate')  # initial learning rate
+    parser.add_argument('--lr_decay', type=float, default = 0.1,
+        help = 'learningRate_decay')  # learning-rate decay factor
+    parser.add_argument('--weight_decay', type=float, default = 5e-6,
+        help = 'weight_decay')  # optimizer weight decay
+    parser.add_argument('--momentum', type=float, default = 0.9,
+        help = 'momentum')  # optimizer momentum
+    parser.add_argument('--batch_size', type=int, default = 16,
+        help = 'batch_size')  # images per training batch
+    parser.add_argument('--dropout', type=float, default = 0.5,
+        help = 'dropout')  # dropout factor
+    parser.add_argument('--epochs', type=int, default = 1000,
+        help = 'epochs')  # number of training epochs
+    parser.add_argument('--num_workers', type=int, default = 8,
+        help = 'num_workers')  # dataloader worker count
+    parser.add_argument('--img_size', type=tuple , default = (256,256),
+        help = 'img_size')  # model input size
+    parser.add_argument('--flag_agu', type=bool , default = True,
+        help = 'data_augmentation')  # whether to apply data augmentation
+    parser.add_argument('--fix_res', type=bool , default = False,
+        help = 'fix_resolution')  # whether to preserve the aspect ratio of input samples
+    parser.add_argument('--clear_model_exp', type=bool, default = False,
+        help = 'clear_model_exp')  # whether to clear the model output folder
+    parser.add_argument('--log_flag', type=bool, default = False,
+        help = 'log flag')  # whether to write a training log
+
+    #--------------------------------------------------------------------------
+    args = parser.parse_args()  # parse the arguments
+    #--------------------------------------------------------------------------
+    mkdir_(args.model_exp, flag_rm=args.clear_model_exp)
+    loc_time = time.localtime()
+    args.model_exp = args.model_exp + '/' + time.strftime("%Y-%m-%d_%H-%M-%S", loc_time)+'/'
+    mkdir_(args.model_exp, flag_rm=args.clear_model_exp)
+
+    f_log = None
+    if args.log_flag:
+        f_log = open(args.model_exp+'/train_{}.log'.format(time.strftime("%Y-%m-%d_%H-%M-%S",loc_time)), 'a+')
+        sys.stdout = f_log
+
+    print('---------------------------------- log : {}'.format(time.strftime("%Y-%m-%d %H:%M:%S", loc_time)))
+    print('\n/******************* {} ******************/\n'.format(parser.description))
+
+    unparsed = vars(args)  # parse_args() returns a Namespace; vars() converts it to a dict
+    for key in unparsed.keys():
+        print('{} : {}'.format(key,unparsed[key]))
+
+    unparsed['time'] = time.strftime("%Y-%m-%d %H:%M:%S", loc_time)
+
+    trainer(ops=args, f_log=f_log)  # train the model
+
+    if args.log_flag:
+        sys.stdout = f_log
+    print('well done : {}'.format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())))
diff --git a/utils/common_utils.py b/utils/common_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e5b172d5efb740952deacaa38d34aa2e3abe980
--- /dev/null
+++ b/utils/common_utils.py
@@ -0,0 +1,132 @@
+#-*-coding:utf-8-*-
+# date:2020-04-11
+# Author: Eric.Lee
+# function: common utils
+
+import os
+import shutil
+import cv2
+import numpy as np
+import json
+import random  # needed by plot_box for random colors
+
+def mkdir_(path, flag_rm=False):
+    if os.path.exists(path):
+        if flag_rm == True:
+            shutil.rmtree(path)
+            os.mkdir(path)
+            print('remove {} done ~ '.format(path))
+    else:
+        os.mkdir(path)
+
+def plot_box(bbox, img, color=None, label=None, line_thickness=None):
+    tl = line_thickness or round(0.002 * max(img.shape[0:2])) + 1
+    color = color or [random.randint(0, 255) for _ in range(3)]
+    c1, c2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
+    cv2.rectangle(img, c1, c2, color, thickness=tl)  # target bbox
+    if label:
+        tf = max(tl - 2, 1)
+        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]  # label size
+        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3  # text bbox
+        cv2.rectangle(img, c1, c2, color, -1)  # filled label background
+        # Draw the label text.
+        cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 4, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
+
+class JSON_Encoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, np.integer):
+            return int(obj)
+        elif isinstance(obj, np.floating):
+            return float(obj)
+        elif isinstance(obj, np.ndarray):
+            return obj.tolist()
+        else:
+            return super(JSON_Encoder, self).default(obj)
+
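+# The drawing helpers below assume the WFLW 98-point layout; the table maps
+# each landmark index range to a region name and BGR color.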
+WFLW_REGIONS = [
+    (0, 32, 'basin', (255,30,30)),
+    (33, 41, 'left_eyebrow', (0,255,0)),
+    (42, 50, 'right_eyebrow', (0,255,0)),
+    (51, 54, 'bridge_nose', (0,170,255)),
+    (55, 59, 'wing_nose', (0,255,255)),
+    (60, 67, 'left_eye', (255,0,255)),
+    (68, 75, 'right_eye', (255,0,255)),
+    (76, 87, 'out_lip', (255,255,0)),
+    (88, 95, 'in_lip', (50,220,255)),
+]
+
+def draw_landmarks(img, output, draw_circle):
+    img_width = img.shape[1]
+    img_height = img.shape[0]
+    dict_landmarks = {}
+    for i in range(int(output.shape[0]/2)):
+        x = output[i*2+0]*float(img_width)
+        y = output[i*2+1]*float(img_height)
+        if 96 <= i <= 97:
+            # Eye centers are drawn directly and not stored in the dict.
+            cv2.circle(img, (int(x),int(y)), 2, (0,0,255), -1)
+            continue
+        for first, last, name, color in WFLW_REGIONS:
+            if first <= i <= last:
+                dict_landmarks.setdefault(name, []).append([int(x),int(y),color])
+                if draw_circle:
+                    cv2.circle(img, (int(x),int(y)), 2, color, -1)
+                break
+
+    return dict_landmarks
+
+def draw_contour(image, dict_landmarks):
+    for key in dict_landmarks.keys():
+        _,_,color = dict_landmarks[key][0]
+
+        if 'basin' == key or 'wing_nose' == key:
+            # Open polylines: the face contour and nose wing are not closed curves.
+            pts = np.array([[dict_landmarks[key][i][0],dict_landmarks[key][i][1]] for i in range(len(dict_landmarks[key]))],np.int32)
+            cv2.polylines(image,[pts],False,color)
+        else:
+            points_array = np.zeros((1,len(dict_landmarks[key]),2),dtype=np.int32)
+            for i in range(len(dict_landmarks[key])):
+                x,y,_ = dict_landmarks[key][i]
+                points_array[0,i,0] = x
+                points_array[0,i,1] = y
+
+            cv2.drawContours(image,points_array,-1,color,thickness=1)
diff --git a/utils/model_utils.py b/utils/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..48cc3d33db8dafe7ed54f344825f6905e1adac26
--- /dev/null
+++ b/utils/model_utils.py
@@ -0,0 +1,61 @@
+#-*-coding:utf-8-*-
+# date:2020-04-11
+# Author: Eric.Lee
+# function: model utils
+
+import os
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+import random
+
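+# Note: train.py only uses set_seed and set_learning_rate from this module;
+# get_acc and split_trainval_datasets are generic classification helpers.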
+def get_acc(output, label):
+    total = output.shape[0]
+    _, pred_label = output.max(1)
+    num_correct = (pred_label == label).sum().item()
+    return num_correct / float(total)
+
+def set_learning_rate(optimizer, lr):
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+def set_seed(seed=666):
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+        cudnn.deterministic = True
+
+def split_trainval_datasets(ops):
+    print(' --------------->>> split_trainval_datasets ')
+    train_split_datasets = []
+    train_split_datasets_label = []
+
+    val_split_datasets = []
+    val_split_datasets_label = []
+    for idx,doc in enumerate(sorted(os.listdir(ops.train_path), key=lambda x:int(x.split('.')[0]), reverse=False)):
+        # print(' %s label is %s \n'%(doc,idx))
+
+        data_list = os.listdir(ops.train_path+doc)
+        random.shuffle(data_list)
+
+        cal_split_num = int(len(data_list)*ops.val_factor)
+
+        for i,file in enumerate(data_list):
+            if '.jpg' in file:
+                if i < cal_split_num:
+                    val_split_datasets.append(ops.train_path+doc + '/' + file)
+                    val_split_datasets_label.append(idx)
+                else:
+                    train_split_datasets.append(ops.train_path+doc + '/' + file)
+                    train_split_datasets_label.append(idx)
+
+                print(ops.train_path+doc + '/' + file,idx)
+
+    print('\n')
+    print('train_split_datasets len {}'.format(len(train_split_datasets)))
+    print('val_split_datasets len {}'.format(len(val_split_datasets)))
+
+    return train_split_datasets,train_split_datasets_label,val_split_datasets,val_split_datasets_label