Commit 408c23b2 authored by Eric.Lee

update

Parent 56bd9d9c
#-*-coding:utf-8-*-
# date:2021-12-05
# Author: Eric.Lee
## function: data augmentation
import numpy as np
import cv2
#-------------------------------------------------------------------------------
# eye_left_n, eye_right_n: perturbed reference point coordinates
def hand_alignment_aug_fun(imgn,eye_left_n,eye_right_n,\
facial_landmarks_n = None,\
angle = None,desiredLeftEye=(0.34, 0.42),desiredFaceWidth=160, desiredFaceHeight=None,draw_flag = False):
if desiredFaceHeight is None:
desiredFaceHeight = desiredFaceWidth
leftEyeCenter = eye_left_n
rightEyeCenter = eye_right_n
# compute the angle between the eye centroids
dY = rightEyeCenter[1] - leftEyeCenter[1]
dX = rightEyeCenter[0] - leftEyeCenter[0]
if angle is None:
angle = np.degrees(np.arctan2(dY, dX))
else:
# print(' a) disturb angle : ',angle)
angle += np.degrees(np.arctan2(dY, dX))# perturbation applied on top of the frontal angle
# print(' b) disturb angle : ',angle)
# compute the desired right eye x-coordinate based on the
# desired x-coordinate of the left eye
desiredRightEyeX = 1.0 - desiredLeftEye[0]
# determine the scale of the new resulting image by taking
# the ratio of the distance between the eyes in the *current*
# image to the distance between the eyes in the *desired* image
dist = np.sqrt((dX ** 2) + (dY ** 2))
desiredDist = (desiredRightEyeX - desiredLeftEye[0])
desiredDist *= desiredFaceWidth
scale = desiredDist / dist
# compute the center (x, y)-coordinates (i.e., the midpoint)
# between the two eyes in the input image
eyesCenter = ((leftEyeCenter[0] + rightEyeCenter[0]) / 2,(leftEyeCenter[1] + rightEyeCenter[1]) / 2)
# grab the rotation matrix for rotating and scaling the face
M = cv2.getRotationMatrix2D(eyesCenter, angle, scale)
# update the translation component of the matrix
tX = desiredFaceWidth * 0.5
tY = desiredFaceHeight * desiredLeftEye[1]
M[0, 2] += (tX - eyesCenter[0])
M[1, 2] += (tY - eyesCenter[1])
M_reg = np.zeros((3,3),dtype = np.float32)
M_reg[0,:] = M[0,:]
M_reg[1,:] = M[1,:]
M_reg[2,:] = (0,0,1.)
# print(M_reg)
M_I = np.linalg.inv(M_reg)# invert the matrix to obtain the mapping from the output image back to the original image
# print(M_I)
# apply the affine transformation
(w, h) = (desiredFaceWidth, desiredFaceHeight)
output = cv2.warpAffine(imgn, M, (w, h),flags=cv2.INTER_LINEAR,borderMode=cv2.BORDER_CONSTANT)# INTER_LINEAR INTER_CUBIC INTER_NEAREST
#BORDER_REFLECT BORDER_TRANSPARENT BORDER_REPLICATE CV_BORDER_WRAP BORDER_CONSTANT
pts_landmarks = []
if facial_landmarks_n is not None:# landmarks are optional (default None), so guard the loop
    for k in range(len(facial_landmarks_n)):
        x = facial_landmarks_n[k][0]
        y = facial_landmarks_n[k][1]
        x_r = (x*M[0][0] + y*M[0][1] + M[0][2])
        y_r = (x*M[1][0] + y*M[1][1] + M[1][2])
        pts_landmarks.append([x_r,y_r])
# if draw_flag:
# cv2.circle(output, (int(x_r),int(y_r)), np.int(1),(0,0,255), 1)
#
# cv2.circle(output, (ptx2,pty2), np.int(1),(0,0,255), 1)
# cv2.circle(output, (ptx3,pty3), np.int(1),(0,255,0), 1)
return output,pts_landmarks,M_I
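#-------------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original file). The image
# path "demo.jpg", the reference points, and the landmark values are placeholders.
if __name__ == "__main__":
    demo_img = cv2.imread("demo.jpg")  # hypothetical test image
    if demo_img is not None:
        pts_demo = [[120., 140.], [180., 150.]]  # example landmark coordinates
        out, pts_out, M_I_demo = hand_alignment_aug_fun(
            demo_img, (100, 150), (200, 150),
            facial_landmarks_n=pts_demo, angle=15,
            desiredLeftEye=(0.3, 0.5), desiredFaceWidth=256)
        # map the first aligned landmark back to the source image with M_I
        xh, yh = pts_out[0]
        x_src = xh * M_I_demo[0][0] + yh * M_I_demo[0][1] + M_I_demo[0][2]
        y_src = xh * M_I_demo[1][0] + yh * M_I_demo[1][1] + M_I_demo[1][2]
        print("aligned -> source :", x_src, y_src)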
#-*-coding:utf-8-*-
# date:2019-05-20
# Author: Eric.Lee
# function: data iter
import glob
import math
import os
import random
import shutil
from pathlib import Path
from PIL import Image
# import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from hand_data_iter.data_agu import *
import json
def draw_bd_handpose(img_,hand_,x,y):
    thick = 2
    colors = [(0,215,255),(255,115,55),(5,255,55),(25,15,255),(225,15,55)]
    # each finger is a chain of keypoint indices rooted at the wrist (index 0)
    fingers = [[0,1,2,3,4],[0,5,6,7,8],[0,9,10,11,12],[0,13,14,15,16],[0,17,18,19,20]]
    for color, chain in zip(colors, fingers):
        for a, b in zip(chain[:-1], chain[1:]):
            cv2.line(img_, (int(hand_[str(a)]['x']+x), int(hand_[str(a)]['y']+y)),
                (int(hand_[str(b)]['x']+x), int(hand_[str(b)]['y']+y)), color, thick)
def plot_box(bbox, img, color=None, label=None, line_thickness=None):
tl = line_thickness or round(0.002 * max(img.shape[0:2])) + 1
color = color or [random.randint(0, 255) for _ in range(3)]
c1, c2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
cv2.rectangle(img, c1, c2, color, thickness=tl)# target bbox
if label:
tf = max(tl - 2, 1)
t_size = cv2.getTextSize(label, 0, fontScale=tl / 4, thickness=tf)[0] # label size
c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 # label text bbox
cv2.rectangle(img, c1, c2, color, -1) # filled rectangle behind the label
# draw the label text
cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 4, [225, 255, 255],thickness=tf, lineType=cv2.LINE_AA)
def img_agu_channel_same(img_):
img_a = np.zeros(img_.shape, dtype = np.uint8)
gray = cv2.cvtColor(img_,cv2.COLOR_BGR2GRAY)# images are loaded with cv2.imread, i.e. BGR channel order
img_a[:,:,0] =gray
img_a[:,:,1] =gray
img_a[:,:,2] =gray
return img_a
# image whitening (zero mean, scaled variance)
def prewhiten(x):
mean = np.mean(x)
std = np.std(x)
std_adj = np.maximum(std, 1.0 / np.sqrt(x.size))
y = np.multiply(np.subtract(x, mean), 1 / std_adj)
return y
# image brightness / contrast augmentation
def contrast_img(img, c, b): # brightness adds the same offset b to every channel of every pixel
rows, cols, channels = img.shape
# build an all-zero (black) image array: np.zeros(img.shape, dtype=np.uint8)
blank = np.zeros([rows, cols, channels], img.dtype)
dst = cv2.addWeighted(img, c, blank, 1-c, b)
return dst
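# Illustrative sketch (not in the original file): how the two helpers above are
# typically chained; `sample_img` is a hypothetical BGR image array.
# sample_img = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
# sample_img = contrast_img(sample_img, c=1.1, b=5)   # mild contrast/brightness jitter
# sample_whitened = prewhiten(sample_img)             # zero-mean, scaled-variance float array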
class LoadImagesAndLabels(Dataset):
def __init__(self, ops, img_size=(224,224), flag_agu = False,fix_res = True,vis = False):
# vis = True
print('img_size (height,width) : ',img_size[0],img_size[1])
print("train_path : {}".format(ops.train_path))
path = ops.train_path
file_list = []
hand_anno_list = []
idx = 0
for f_ in os.listdir(path):
if ".jpg" in f_:
img_path = path +f_
label_path = img_path.replace('.jpg','.json')
if not os.path.exists(label_path):
continue
f = open(label_path, encoding='utf-8')# read the json annotation file
hand_dict_ = json.load(f)
f.close()
if len(hand_dict_)==0:
continue
hand_dict_ = hand_dict_["info"]
#----------------------------------------------
if vis:
img_ = cv2.imread(img_path)
img_ago = img_.copy()
# cv2.namedWindow("hand_d",0)
# cv2.imshow("hand_d",img_ago)
# cv2.waitKey(1)
#----------------------------------------------
# print("len hand_dict :",len(hand_dict_))
if len(hand_dict_)>0:
for msg in hand_dict_:
bbox = msg["bbox"]
pts = msg["pts"]
file_list.append(img_path)
hand_anno_list.append((bbox,pts))
idx += 1
print(" hands num : {}".format(idx),end = "\r")
#------------------------------------
if vis:
x1,y1,x2,y2 = int(bbox[0]),int(bbox[1]),int(bbox[2]),int(bbox[3])
hand = img_ago[y1:y2,x1:x2,:]
pts_ = []
x_max = -65535
y_max = -65535
x_min = 65535
y_min = 65535
for i in range(21):
x_,y_ = pts[str(i)]["x"],pts[str(i)]["y"]
x_ += x1
y_ += y1
pts_.append([x_,y_])
x_min = x_ if x_min>x_ else x_min
y_min = y_ if y_min>y_ else y_min
x_max = x_ if x_max<x_ else x_max
y_max = y_ if y_max<y_ else y_max
plot_box((x_min,y_min,x_max,y_max), img_, color=(255,100,100), label="hand", line_thickness=2)
offset_x = int((x_max-x_min)/8)
offset_y = int((y_max-y_min)/8)
pt_left = (x_min+random.randint(-offset_x,offset_x),(y_min+y_max)/2+random.randint(-offset_y,offset_y))
pt_right = (x_max+random.randint(-offset_x,offset_x),(y_min+y_max)/2+random.randint(-offset_y,offset_y))
angle_random = random.randint(-180,180)
scale_x = float(random.randint(20,32))/100.
hand_rot,pts_tor_landmarks,_ = hand_alignment_aug_fun(img_ago,pt_left,pt_right,
facial_landmarks_n = pts_,\
angle = angle_random,desiredLeftEye=(scale_x, 0.5),
desiredFaceWidth=img_size[0], desiredFaceHeight=None,draw_flag = True)
pts_hand = {}
for ptk in range(21):
xh,yh = pts_tor_landmarks[ptk][0],pts_tor_landmarks[ptk][1]
pts_hand[str(ptk)] = {}
pts_hand[str(ptk)] = {
"x":xh,
"y":yh,
}
draw_bd_handpose(hand_rot,pts_hand,0,0)
cv2.namedWindow("hand_rotd",0)
cv2.imshow("hand_rotd",hand_rot)
print("hand_rot shape : {}".format(hand_rot.shape))
cv2.waitKey(1)
#
print()
self.files = file_list
self.hand_anno_list = hand_anno_list
self.img_size = img_size
self.flag_agu = flag_agu
# self.fix_res = fix_res
self.vis = vis
def __len__(self):
return len(self.files)
def __getitem__(self, index):
img_path = self.files[index]
bbox,pts = self.hand_anno_list[index]
img = cv2.imread(img_path) # BGR
#-------------------------------------
x1,y1,x2,y2 = int(bbox[0]),int(bbox[1]),int(bbox[2]),int(bbox[3])
pts_ = []
x_max = -65535
y_max = -65535
x_min = 65535
y_min = 65535
for i in range(21):
x_,y_ = pts[str(i)]["x"],pts[str(i)]["y"]
x_ += x1
y_ += y1
pts_.append([x_,y_])
x_min = x_ if x_min>x_ else x_min
y_min = y_ if y_min>y_ else y_min
x_max = x_ if x_max<x_ else x_max
y_max = y_ if y_max<y_ else y_max
if random.random() > 0.55:
offset_x = int((x_max-x_min)/8)
offset_y = int((y_max-y_min)/8)
pt_left = (x_min+random.randint(-offset_x,offset_x),(y_min+y_max)/2+random.randint(-offset_y,offset_y))
pt_right = (x_max+random.randint(-offset_x,offset_x),(y_min+y_max)/2+random.randint(-offset_y,offset_y))
angle_random = random.randint(-180,180)
scale_x = float(random.randint(12,33))/100.
hand_rot,pts_tor_landmarks,_ = hand_alignment_aug_fun(img,pt_left,pt_right,
facial_landmarks_n = pts_,\
angle = angle_random,desiredLeftEye=(scale_x, 0.5),
desiredFaceWidth=self.img_size[0], desiredFaceHeight=None,draw_flag = False)
if self.vis:
pts_hand = {}
for ptk in range(21):
xh,yh = pts_tor_landmarks[ptk][0],pts_tor_landmarks[ptk][1]
pts_hand[str(ptk)] = {}
pts_hand[str(ptk)] = {
"x":xh,
"y":yh,
}
draw_bd_handpose(hand_rot,pts_hand,0,0)
cv2.namedWindow("hand_rotd",0)
cv2.imshow("hand_rotd",hand_rot)
cv2.waitKey(1)
img_ = hand_rot
pts_tor_landmarks_norm = []
for i in range(len(pts_tor_landmarks)):
x_ = float(pts_tor_landmarks[i][0])/float(self.img_size[0])
y_ = float(pts_tor_landmarks[i][1])/float(self.img_size[0])
pts_tor_landmarks_norm.append([x_,y_])
else:
w_ = max(abs(x_max-x_min),abs(y_max-y_min))
w_ = w_*(1.+float(random.randint(5,40))/100.)
x_mid = (x_max+x_min)/2
y_mid = (y_max+y_min)/2
x1,y1,x2,y2 = int(x_mid-w_/2.),int(y_mid-w_/2.),int(x_mid+w_/2.),int(y_mid+w_/2.)
x1 = np.clip(x1,0,img.shape[1]-1)
x2 = np.clip(x2,0,img.shape[1]-1)
y1 = np.clip(y1,0,img.shape[0]-1)
y2 = np.clip(y2,0,img.shape[0]-1)
img_ = img[y1:y2,x1:x2,:]
#-----------------
pts_tor_landmarks = []
pts_hand = {}
for ptk in range(21):
xh,yh = pts[str(ptk)]["x"],pts[str(ptk)]["y"]
xh = xh + bbox[0] -x1
yh = yh + bbox[1] -y1
pts_tor_landmarks.append([xh,yh])
pts_hand[str(ptk)] = {
"x":xh,
"y":yh,
}
#----------------
if random.random() > 0.5: # random horizontal mirror
img_ = cv2.flip(img_,1)
pts_tor_landmarks = []
pts_hand = {}
for ptk in range(21):
xh,yh = pts[str(ptk)]["x"],pts[str(ptk)]["y"]
xh = xh + bbox[0] -x1
yh = yh + bbox[1] -y1
pts_tor_landmarks.append([img_.shape[1]-1-xh,yh])
pts_hand[str(ptk)] = {
"x":img_.shape[1]-1-xh,
"y":yh,
}
pts_tor_landmarks_norm = []
for i in range(len(pts_tor_landmarks)):
x_ = float(pts_tor_landmarks[i][0])/float(abs(x2-x1))
y_ = float(pts_tor_landmarks[i][1])/float(abs(y2-y1))
pts_tor_landmarks_norm.append([x_,y_])
#-----------------
if self.vis:
draw_bd_handpose(img_,pts_hand,0,0)
img_ = cv2.resize(img_, self.img_size, interpolation = random.randint(0,5))
if self.vis:
cv2.namedWindow("hand_zfx",0)
cv2.imshow("hand_zfx",img_)
cv2.waitKey(1)
#-------------------------------------
if self.flag_agu:
if random.random() > 0.5:
c = float(random.randint(80,120))/100.
b = random.randint(-10,10)
img_ = contrast_img(img_, c, b)
if self.flag_agu:
if random.random() > 0.9:
# random hue shift; cast to int before clipping so uint8 arithmetic cannot wrap around
img_hsv=cv2.cvtColor(img_,cv2.COLOR_BGR2HSV)
hue_x = random.randint(-10,10)
img_hsv[:,:,0]=np.clip(img_hsv[:,:,0].astype(np.int32)+hue_x,0,180).astype(np.uint8)# OpenCV hue channel range is 0~180
img_=cv2.cvtColor(img_hsv,cv2.COLOR_HSV2BGR)
if self.flag_agu:
if random.random() > 0.95:
img_ = img_agu_channel_same(img_)
if self.vis:
cv2.namedWindow('crop',0)
cv2.imshow('crop',img_)
cv2.waitKey(1)
img_ = img_.astype(np.float32)
img_ = (img_-128.)/256.
img_ = img_.transpose(2, 0, 1)
pts_tor_landmarks_norm = np.array(pts_tor_landmarks_norm).ravel()
return img_,pts_tor_landmarks_norm
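# Note (illustrative, not in the original file): __getitem__ returns the image
# as a (3, H, W) float array normalized by (img - 128) / 256, and the labels as
# a flat vector [x0, y0, ..., x20, y20] scaled to [0, 1]. A minimal sketch for
# mapping a prediction back to pixel coordinates on the network input:
# pts_px = pts_pred.reshape(-1, 2) * np.array([img_w, img_h])  # img_w, img_h: input size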
#-*-coding:utf-8-*-
# date:2021-12-20
# Author: Eric.Lee
## function: handpose augmentation
import json
import cv2
import os
import random
from data_agu import hand_alignment_aug_fun
import numpy as np
def plot_box(bbox, img, color=None, label=None, line_thickness=None):
tl = line_thickness or round(0.002 * max(img.shape[0:2])) + 1
color = color or [random.randint(0, 255) for _ in range(3)]
c1, c2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
cv2.rectangle(img, c1, c2, color, thickness=tl)# target bbox
if label:
tf = max(tl - 2, 1)
t_size = cv2.getTextSize(label, 0, fontScale=tl / 4, thickness=tf)[0] # label size
c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 # label text bbox
cv2.rectangle(img, c1, c2, color, -1) # filled rectangle behind the label
# draw the label text
cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 4, [225, 255, 255],thickness=tf, lineType=cv2.LINE_AA)
def draw_bd_handpose(img_,hand_,x,y):
    thick = 2
    colors = [(0,215,255),(255,115,55),(5,255,55),(25,15,255),(225,15,55)]
    # each finger is a chain of keypoint indices rooted at the wrist (index 0)
    fingers = [[0,1,2,3,4],[0,5,6,7,8],[0,9,10,11,12],[0,13,14,15,16],[0,17,18,19,20]]
    for color, chain in zip(colors, fingers):
        for a, b in zip(chain[:-1], chain[1:]):
            cv2.line(img_, (int(hand_[str(a)]['x']+x), int(hand_[str(a)]['y']+y)),
                (int(hand_[str(b)]['x']+x), int(hand_[str(b)]['y']+y)), color, thick)
if __name__ == "__main__":
path = "../../dpcs/handpose_datasets/"
vis = True
hand_idx = 0
for f_ in os.listdir(path):
if ".jpg" in f_:
img_path = path +f_
label_path = img_path.replace('.jpg','.json')
if not os.path.exists(label_path):
continue
img_ = cv2.imread(img_path)
img_ago = img_.copy()
f = open(label_path, encoding='utf-8')# read the json annotation file
hand_dict_ = json.load(f)
f.close()
hand_dict_ = hand_dict_["info"]
print("len hand_dict :",len(hand_dict_))
if len(hand_dict_)>0:
for msg in hand_dict_:
bbox = msg["bbox"]
pts = msg["pts"]
print()
print(bbox)
# print(pts)
x1,y1,x2,y2 = int(bbox[0]),int(bbox[1]),int(bbox[2]),int(bbox[3])
hand = img_ago[y1:y2,x1:x2,:]
pts_ = []
x_max = -65535
y_max = -65535
x_min = 65535
y_min = 65535
for i in range(21):
x_,y_ = pts[str(i)]["x"],pts[str(i)]["y"]
x_ += x1
y_ += y1
pts_.append([x_,y_])
x_min = x_ if x_min>x_ else x_min
y_min = y_ if y_min>y_ else y_min
x_max = x_ if x_max<x_ else x_max
y_max = y_ if y_max<y_ else y_max
if vis:
plot_box((x_min,y_min,x_max,y_max), img_, color=(255,100,100), label="hand", line_thickness=2)
#
if True: # toggle: random perturbation (True) or canonical alignment (False)
offset_x = int((x_max-x_min)/8)
offset_y = int((y_max-y_min)/8)
pt_left = (x_min+random.randint(-offset_x,offset_x),(y_min+y_max)/2+random.randint(-offset_y,offset_y))
pt_right = (x_max+random.randint(-offset_x,offset_x),(y_min+y_max)/2+random.randint(-offset_y,offset_y))
angle_random = random.randint(-90,90)
scale_x = float(random.randint(20,40))/100.
hand_rot,pts_tor_landmarks,M_I = hand_alignment_aug_fun(img_ago,pt_left,pt_right,
facial_landmarks_n = pts_,\
angle = angle_random,desiredLeftEye=(scale_x, 0.5),
desiredFaceWidth=256, desiredFaceHeight=None,draw_flag = True)
else:
offset_x = 0
offset_y = 0
pt_left = (x_min+random.randint(-offset_x,offset_x),(y_min+y_max)/2+random.randint(-offset_y,offset_y))
pt_right = (x_max+random.randint(-offset_x,offset_x),(y_min+y_max)/2+random.randint(-offset_y,offset_y))
angle_random = 0
scale_x = 0.25
hand_rot,pts_tor_landmarks,M_I = hand_alignment_aug_fun(img_ago,pt_left,pt_right,
facial_landmarks_n = pts_,\
angle = angle_random,desiredLeftEye=(scale_x, 0.5),
desiredFaceWidth=256, desiredFaceHeight=None,draw_flag = True)
#
hand_idx += 1
if not os.path.exists("../test_datasets/"):# make sure the output folder exists
    os.makedirs("../test_datasets/")
cv2.imwrite("../test_datasets/{}.jpg".format(hand_idx),hand_rot)
#
pts_hand = {}
pts_hand_global_rot = {}
for ptk in range(21):
xh,yh = pts_tor_landmarks[ptk][0],pts_tor_landmarks[ptk][1]
pts_hand[str(ptk)] = {}
pts_hand[str(ptk)] = {
"x":xh,
"y":yh,
}
#------------
x_r = (xh*M_I[0][0] + yh*M_I[0][1] + M_I[0][2])
y_r = (xh*M_I[1][0] + yh*M_I[1][1] + M_I[1][2])
pts_hand_global_rot[str(ptk)] = {
"x":x_r,
"y":y_r,
}
if vis:
draw_bd_handpose(hand_rot,pts_hand,0,0)
cv2.namedWindow("hand_rot",0)
cv2.imshow("hand_rot",hand_rot)
cv2.namedWindow("hand_origin",0)
cv2.imshow("hand_origin",hand)
RGB = (random.randint(50,255),random.randint(50,255),random.randint(50,255))
plot_box((x1,y1,x2,y2), img_, color=(RGB), label="hand", line_thickness=3)
# draw_bd_handpose(img_,pts,bbox[0],bbox[1])
draw_bd_handpose(img_,pts_hand_global_rot,0,0)
if vis:
cv2.putText(img_, 'len:{}'.format(len(hand_dict_)), (5,40),
cv2.FONT_HERSHEY_COMPLEX, 1.5, (0, 255, 0),4)
cv2.putText(img_, 'len:{}'.format(len(hand_dict_)), (5,40),
cv2.FONT_HERSHEY_COMPLEX, 1.5, (0, 0, 255))
cv2.namedWindow("Gesture_json",0)
cv2.imshow("Gesture_json",img_)
if cv2.waitKey(1) == 27:
break
#-*-coding:utf-8-*-
# date:2019-05-20
# function: wing loss
import torch
import torch.nn as nn
import torch.optim as optim
import os
import math
def wing_loss(landmarks, labels, w=0.06, epsilon=0.01):
"""
Arguments:
landmarks, labels: float tensors with shape [batch_size, landmarks]. landmarks means x1,x2,x3,x4...y1,y2,y3,y4 1-D
w, epsilon: a float numbers.
Returns:
a float tensor with shape [].
"""
x = landmarks - labels
c = w * (1.0 - math.log(1.0 + w / epsilon))
absolute_x = torch.abs(x)
losses = torch.where(\
(w>absolute_x),\
w * torch.log(1.0 + absolute_x / epsilon),\
absolute_x - c)
# loss = tf.reduce_mean(tf.reduce_mean(losses, axis=[1]), axis=0)
losses = torch.mean(losses,dim=1,keepdim=True)
loss = torch.mean(losses)
return loss
def got_total_wing_loss(output,crop_landmarks):
loss = wing_loss(output, crop_landmarks)
return loss
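# Minimal check (illustrative, not part of the original file): wing loss on
# random tensors shaped [batch_size, 42], matching the 21-point hand model.
if __name__ == "__main__":
    pred = torch.randn(4, 42)
    target = pred + 0.005 * torch.randn(4, 42)  # small residuals exercise the log branch
    print("wing loss :", got_total_wing_loss(pred, target).item())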
import torch
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152']
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, img_size=224,dropout_factor = 1.):
self.inplanes = 64
self.dropout_factor = dropout_factor
super(ResNet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
# see this issue: https://github.com/xxradon/PytorchToCaffe/issues/16
# self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
assert img_size % 32 == 0
pool_kernel = int(img_size / 32)
self.avgpool = nn.AvgPool2d(pool_kernel, stride=1, ceil_mode=True)
self.dropout = nn.Dropout(self.dropout_factor)
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.dropout(x)
x = self.fc(x)
return x
def load_model(model, pretrained_state_dict):
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_state_dict.items() if
k in model_dict and model_dict[k].size() == pretrained_state_dict[k].size()}
model.load_state_dict(pretrained_dict, strict=False)
if len(pretrained_dict) == 0:
print("[INFO] No params were loaded ...")
else:
for k, v in pretrained_state_dict.items():
if k in pretrained_dict:
print("==>> Load {} {}".format(k, v.size()))
else:
print("[INFO] Skip {} {}".format(k, v.size()))
return model
def resnet18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained:
# model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
print("Load pretrained model from {}".format(model_urls['resnet18']))
pretrained_state_dict = model_zoo.load_url(model_urls['resnet18'])
model = load_model(model, pretrained_state_dict)
return model
def resnet34(pretrained=False, **kwargs):
"""Constructs a ResNet-34 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
if pretrained:
# model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
print("Load pretrained model from {}".format(model_urls['resnet34']))
pretrained_state_dict = model_zoo.load_url(model_urls['resnet34'])
model = load_model(model, pretrained_state_dict)
return model
def resnet50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
# model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
print("Load pretrained model from {}".format(model_urls['resnet50']))
pretrained_state_dict = model_zoo.load_url(model_urls['resnet50'])
model = load_model(model, pretrained_state_dict)
return model
def resnet101(pretrained=False, **kwargs):
"""Constructs a ResNet-101 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained:
# model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
print("Load pretrained model from {}".format(model_urls['resnet101']))
pretrained_state_dict = model_zoo.load_url(model_urls['resnet101'])
model = load_model(model, pretrained_state_dict)
return model
def resnet152(pretrained=False, **kwargs):
"""Constructs a ResNet-152 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained:
# model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
print("Load pretrained model from {}".format(model_urls['resnet152']))
pretrained_state_dict = model_zoo.load_url(model_urls['resnet152'])
model = load_model(model, pretrained_state_dict)
return model
if __name__ == "__main__":
input = torch.randn([32, 3, 256,256])
model = resnet34(False, num_classes=2, img_size=256)
output = model(input)
print(output.size())
#-*-coding:utf-8-*-
# date:2020-06-24
# Author: Eric.Lee
## function: train
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import sys
from utils.model_utils import *
from utils.common_utils import *
from hand_data_iter.datasets import *
from models.resnet import resnet50,resnet101
from loss.loss import *
import cv2
import time
import json
from datetime import datetime
def trainer(ops,f_log):
try:
os.environ['CUDA_VISIBLE_DEVICES'] = ops.GPUS
if ops.log_flag:
sys.stdout = f_log
set_seed(ops.seed)
#---------------------------------------------------------------- build the model
if ops.model == 'resnet_50':
model_ = resnet50(pretrained = True,num_classes = ops.num_classes,img_size = ops.img_size[0],dropout_factor=ops.dropout)
else:
model_ = resnet101(pretrained = True,num_classes = ops.num_classes,img_size = ops.img_size[0],dropout_factor=ops.dropout)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
model_ = model_.to(device)
# print(model_)# print the model structure
# Dataset
dataset = LoadImagesAndLabels(ops= ops,img_size=ops.img_size,flag_agu=ops.flag_agu,fix_res = ops.fix_res,vis = False)
print("handpose done")
print('len train datasets : %s'%(dataset.__len__()))
# Dataloader
dataloader = DataLoader(dataset,
batch_size=ops.batch_size,
num_workers=ops.num_workers,
shuffle=True,
pin_memory=False,
drop_last = True)
# optimizer setup
optimizer_Adam = torch.optim.Adam(model_.parameters(), lr=ops.init_lr, betas=(0.9, 0.99),weight_decay=1e-6)
# optimizer_SGD = optim.SGD(model_.parameters(), lr=ops.init_lr, momentum=ops.momentum, weight_decay=ops.weight_decay)# optimizer init
optimizer = optimizer_Adam
# load the finetune checkpoint if provided
if os.access(ops.fintune_model,os.F_OK):# checkpoint
chkpt = torch.load(ops.fintune_model, map_location=device)
model_.load_state_dict(chkpt)
print('load fintune model : {}'.format(ops.fintune_model))
print('/**********************************************/')
# loss function
if ops.loss_define != 'wing_loss':
criterion = nn.MSELoss(reduction='mean')# the deprecated `reduce` argument is dropped
step = 0
idx = 0
# variable initialization
best_loss = np.inf
loss_mean = 0. # running mean loss
loss_idx = 0. # loss sample counter
flag_change_lr_cnt = 0 # learning-rate update counter
init_lr = ops.init_lr # learning rate
epochs_loss_dict = {}
for epoch in range(0, ops.epochs):
if ops.log_flag:
sys.stdout = f_log
print('\nepoch %d ------>>>'%epoch)
model_.train()
# learning-rate decay schedule
if loss_mean!=0.:
if best_loss > (loss_mean/loss_idx):
flag_change_lr_cnt = 0
best_loss = (loss_mean/loss_idx)
else:
flag_change_lr_cnt += 1
if flag_change_lr_cnt > 20:
init_lr = init_lr*ops.lr_decay
set_learning_rate(optimizer, init_lr)
flag_change_lr_cnt = 0
loss_mean = 0. # running mean loss
loss_idx = 0. # loss sample counter
for i, (imgs_, pts_) in enumerate(dataloader):
# print('imgs_, pts_',imgs_.size(), pts_.size())
if use_cuda:
imgs_ = imgs_.cuda() # pytorch input layout: (batch, channel, height, width)
pts_ = pts_.cuda()
output = model_(imgs_.float())
if ops.loss_define == 'wing_loss':
loss = got_total_wing_loss(output, pts_.float())
else:
loss = criterion(output, pts_.float())
loss_mean += loss.item()
loss_idx += 1.
if i%10 == 0:
loc_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
print(' %s - %s - epoch [%s/%s] (%s/%s):'%(loc_time,ops.model,epoch,ops.epochs,i,int(dataset.__len__()/ops.batch_size)),\
'Mean Loss : %.6f - Loss: %.6f'%(loss_mean/loss_idx,loss.item()),\
' lr : %.5f'%init_lr,' bs :',ops.batch_size,\
' img_size: %s x %s'%(ops.img_size[0],ops.img_size[1]),' best_loss: %.6f'%best_loss)
# backpropagate gradients
loss.backward()
# optimizer updates the model parameters
optimizer.step()
# clear gradients for the next step
optimizer.zero_grad()
step += 1
torch.save(model_.state_dict(), ops.model_exp + 'model_epoch-{}.pth'.format(epoch))
except Exception as e:
print('Exception : ',e) # print the exception
print('Exception file : ', e.__traceback__.tb_frame.f_globals['__file__'])# file where the exception occurred
print('Exception line : ', e.__traceback__.tb_lineno)# line where the exception occurred
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=' Project Hand Train')
parser.add_argument('--seed', type=int, default = 126673,
help = 'seed') # random seed
parser.add_argument('--model_exp', type=str, default = './model_exp',
help = 'model_exp') # model output folder
parser.add_argument('--model', type=str, default = 'resnet_50',
help = 'model : resnet_50,resnet_101') # model type
parser.add_argument('--num_classes', type=int , default = 42,
help = 'num_classes') # number of landmarks * 2
parser.add_argument('--GPUS', type=str, default = '1',
help = 'GPUS') # GPU selection
parser.add_argument('--train_path', type=str,
default = "./handpose_datasets/",
help = 'datasets')# training set annotations
parser.add_argument('--pretrained', type=bool, default = True,
help = 'imageNet_Pretrain') # ImageNet pretraining
parser.add_argument('--fintune_model', type=str, default = 'None',
help = 'fintune_model') # finetune model
parser.add_argument('--loss_define', type=str, default = 'wing_loss',
help = 'define_loss') # loss function choice
parser.add_argument('--init_lr', type=float, default = 1e-3,
help = 'init learning Rate') # initial learning rate
parser.add_argument('--lr_decay', type=float, default = 0.1,
help = 'learningRate_decay') # learning-rate decay factor
parser.add_argument('--weight_decay', type=float, default = 1e-6,
help = 'weight_decay') # optimizer weight-decay (L2) strength
parser.add_argument('--momentum', type=float, default = 0.9,
help = 'momentum') # optimizer momentum
parser.add_argument('--batch_size', type=int, default = 128,
help = 'batch_size') # images per training batch
parser.add_argument('--dropout', type=float, default = 0.5,
help = 'dropout') # dropout
parser.add_argument('--epochs', type=int, default = 2000,
help = 'epochs') # training epochs
parser.add_argument('--num_workers', type=int, default = 10,
help = 'num_workers') # dataloader worker processes
parser.add_argument('--img_size', type=tuple , default = (256,256),
help = 'img_size') # model input image size
parser.add_argument('--flag_agu', type=bool , default = True,
help = 'data_augmentation') # whether the data generator applies augmentation
parser.add_argument('--fix_res', type=bool , default = False,
help = 'fix_resolution') # whether input samples keep the original aspect ratio
parser.add_argument('--clear_model_exp', type=bool, default = False,
help = 'clear_model_exp') # whether to clear the model output folder
parser.add_argument('--log_flag', type=bool, default = False,
help = 'log flag') # whether to save the training log
#--------------------------------------------------------------------------
args = parser.parse_args()# parse the arguments
#--------------------------------------------------------------------------
mkdir_(args.model_exp, flag_rm=args.clear_model_exp)
loc_time = time.localtime()
args.model_exp = args.model_exp + '/' + time.strftime("%Y-%m-%d_%H-%M-%S", loc_time)+'/'
mkdir_(args.model_exp, flag_rm=args.clear_model_exp)
f_log = None
if args.log_flag:
f_log = open(args.model_exp+'/train_{}.log'.format(time.strftime("%Y-%m-%d_%H-%M-%S",loc_time)), 'a+')
sys.stdout = f_log
print('---------------------------------- log : {}'.format(time.strftime("%Y-%m-%d %H:%M:%S", loc_time)))
print('\n/******************* {} ******************/\n'.format(parser.description))
unparsed = vars(args) # parse_args() returns a namespace; vars() converts it to a dict
for key in unparsed.keys():
print('{} : {}'.format(key,unparsed[key]))
unparsed['time'] = time.strftime("%Y-%m-%d %H:%M:%S", loc_time)
fs = open(args.model_exp+'train_ops.json',"w",encoding='utf-8')
json.dump(unparsed,fs,ensure_ascii=False,indent = 1)
fs.close()
trainer(ops = args,f_log = f_log)# train the model
if args.log_flag:
sys.stdout = f_log
print('well done : {}'.format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())))
#-*-coding:utf-8-*-
# date:2020-04-11
# Author: Eric.Lee
# function: common utils
import os
import shutil
import cv2
import numpy as np
import json
import random# needed by plot_box below
def mkdir_(path, flag_rm=False):
if os.path.exists(path):
if flag_rm == True:
shutil.rmtree(path)
os.mkdir(path)
print('remove {} done ~ '.format(path))
else:
os.mkdir(path)
def plot_box(bbox, img, color=None, label=None, line_thickness=None):
tl = line_thickness or round(0.002 * max(img.shape[0:2])) + 1
color = color or [random.randint(0, 255) for _ in range(3)]
c1, c2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
cv2.rectangle(img, c1, c2, color, thickness=tl)# target bbox
if label:
tf = max(tl - 2, 1)
t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0] # label size
c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 # label text bbox
cv2.rectangle(img, c1, c2, color, -1) # filled rectangle behind the label
# draw the label text
cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 4, [225, 255, 255],thickness=tf, lineType=cv2.LINE_AA)
class JSON_Encoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
else:
return super(JSON_Encoder, self).default(obj)
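# Minimal usage sketch (illustrative, not part of the original file): serialize
# numpy values with the encoder above.
if __name__ == "__main__":
    demo = {"pt": np.array([1.5, 2.5]), "n": np.int32(3)}
    print(json.dumps(demo, cls=JSON_Encoder))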
#-*-coding:utf-8-*-
# date:2020-04-11
# Author: Eric.Lee
# function: model utils
import os
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import random
def get_acc(output, label):
total = output.shape[0]
_, pred_label = output.max(1)
num_correct = (pred_label == label).sum().item()
return num_correct / float(total)
def set_learning_rate(optimizer, lr):
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def set_seed(seed = 666):
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
cudnn.deterministic = True
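# Minimal check (illustrative, not part of the original file): get_acc on a toy
# batch of logits and labels.
if __name__ == "__main__":
    set_seed(666)
    logits = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
    labels = torch.tensor([1, 0])
    print("acc :", get_acc(logits, labels))  # expected: 1.0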