From 5d31ba1cbfe9ecb820fdb975799cd7eefc4dfa65 Mon Sep 17 00:00:00 2001
From: hypox64
Date: Sat, 5 Oct 2019 23:53:57 +0800
Subject: [PATCH] add train part

---
 make_datasets/get_video_dataset.py        |  79 ++++++++
 make_datasets/use_irregular_holes_mask.py |  69 +++++++
 train/train.py                            | 209 ++++++++++++++++++++++
 3 files changed, 357 insertions(+)
 create mode 100644 make_datasets/get_video_dataset.py
 create mode 100644 make_datasets/use_irregular_holes_mask.py
 create mode 100644 train/train.py

diff --git a/make_datasets/get_video_dataset.py b/make_datasets/get_video_dataset.py
new file mode 100644
index 0000000..1e715da
--- /dev/null
+++ b/make_datasets/get_video_dataset.py
@@ -0,0 +1,79 @@
+import os
+import numpy as np
+import cv2
+import random
+
+import sys
+sys.path.append("..")
+from models import runmodel,loadmodel
+from util import mosaic,util,ffmpeg,filt
+from util import image_processing as impro
+from options import Options
+
+opt = Options().getparse()
+util.file_init(opt)
+
+videos = os.listdir('./video')
+videos.sort()
+opt.model_path = '../pretrained_models/add_youknow_128.pth'
+opt.use_gpu = True
+
+net = loadmodel.unet(opt)
+for path in videos:
+
+    path = os.path.join('./video',path)
+    util.clean_tempfiles()
+    ffmpeg.video2voice(path,'./tmp/voice_tmp.mp3')
+    ffmpeg.video2image(path,'./tmp/video2image/output_%05d.'+opt.tempimage_type)
+    imagepaths = os.listdir('./tmp/video2image')
+    imagepaths.sort()
+
+    # get ROI position in every frame
+    positions = []
+    img_ori_example = impro.imread(os.path.join('./tmp/video2image',imagepaths[0]))
+    mask_avg = np.zeros((impro.resize(img_ori_example, 128)).shape[:2])
+    for imagepath in imagepaths:
+        imagepath = os.path.join('./tmp/video2image',imagepath)
+        print('Find ROI location:',imagepath)
+        img = impro.imread(imagepath)
+        x,y,size,mask = runmodel.get_mosaic_position(img,net,opt,threshold = 64)
+        cv2.imwrite(os.path.join('./tmp/ROI_mask',
+                                 os.path.basename(imagepath)),mask)
+        positions.append([x,y,size])
+        mask_avg = mask_avg + mask
+    print('Optimize ROI locations...')
+    mask_index = filt.position_medfilt(np.array(positions), 13)
+
+    mask = np.clip(mask_avg/len(imagepaths),0,255).astype('uint8')
+    mask = impro.mask_threshold(mask,20,32)
+    x,y,size,area = impro.boundingSquare(mask,Ex_mul=1.5)
+    rat = min(img_ori_example.shape[:2])/128.0
+    x,y,size = int(rat*x),int(rat*y),int(rat*size)
+    cv2.imwrite(os.path.join('./tmp/ROI_mask_check',
+                             'test_show.png'),mask)
+    if size != 0:
+        mask_path = './dataset/'+os.path.splitext(os.path.basename(path))[0]+'/mask'
+        ori_path = './dataset/'+os.path.splitext(os.path.basename(path))[0]+'/ori'
+        mosaic_path = './dataset/'+os.path.splitext(os.path.basename(path))[0]+'/mosaic'
+        os.makedirs('./dataset/'+os.path.splitext(os.path.basename(path))[0]+'')
+        os.makedirs(mask_path)
+        os.makedirs(ori_path)
+        os.makedirs(mosaic_path)
+        print('Add mosaic to images...')
+        mosaic_size = mosaic.get_autosize(img_ori_example,mask,area_type = 'bounding')*random.uniform(1,2)
+        models = ['squa_avg','rect_avg','squa_mid']
+        mosaic_type = random.randint(0,len(models)-1)
+        rect_rat = random.uniform(1.2,1.6)
+        for i in range(len(imagepaths)):
+            mask = impro.imread(os.path.join('./tmp/ROI_mask',imagepaths[mask_index[i]]))
+            img_ori = impro.imread(os.path.join('./tmp/video2image',imagepaths[i]))
+            img_mosaic = mosaic.addmosaic_normal(img_ori,mask,mosaic_size,model = models[mosaic_type],rect_rat=rect_rat)
+            mask = impro.resize(mask, min(img_ori.shape[:2]))
+
+            img_ori_crop = impro.resize(img_ori[y-size:y+size,x-size:x+size],256)
+            img_mosaic_crop = impro.resize(img_mosaic[y-size:y+size,x-size:x+size],256)
+            mask_crop = impro.resize(mask[y-size:y+size,x-size:x+size],256)
+
+            cv2.imwrite(os.path.join(ori_path,os.path.basename(imagepaths[i])),img_ori_crop)
+            cv2.imwrite(os.path.join(mosaic_path,os.path.basename(imagepaths[i])),img_mosaic_crop)
+            cv2.imwrite(os.path.join(mask_path,os.path.basename(imagepaths[i])),mask_crop)
\ No newline at end of file
diff --git a/make_datasets/use_irregular_holes_mask.py b/make_datasets/use_irregular_holes_mask.py
new file mode 100644
index 0000000..f00606d
--- /dev/null
+++ b/make_datasets/use_irregular_holes_mask.py
@@ -0,0 +1,69 @@
+import numpy as np
+import cv2
+import os
+from torchvision import transforms
+from PIL import Image
+import random
+import sys
+sys.path.append("..")
+import util.image_processing as impro
+from util import util,mosaic
+import datetime
+
+ir_mask_path = './Irregular_Holes_mask'
+# img_path = 'D:/MyProject_new/face_512'
+img_path = '/media/hypo/Hypoyun/Hypoyun/手机摄影/20190219'
+output_dir = './datasets'
+util.makedirs(output_dir)
+HD = True    # if False, make a dataset for pix2pix; if True, make one for pix2pix_HD
+MASK = True
+if HD:
+    train_A_path = os.path.join(output_dir,'train_A')
+    train_B_path = os.path.join(output_dir,'train_B')
+    util.makedirs(train_A_path)
+    util.makedirs(train_B_path)
+else:
+    train_path = os.path.join(output_dir,'train')
+    util.makedirs(train_path)
+if MASK:
+    mask_path = os.path.join(output_dir,'mask')
+    util.makedirs(mask_path)
+
+transform_mask = transforms.Compose([
+    transforms.RandomResizedCrop(size=512, scale=(0.5,1)),
+    transforms.RandomHorizontalFlip(),
+    ])
+
+transform_img = transforms.Compose([
+
+    transforms.Resize(512),
+    transforms.RandomCrop(512)
+    ])
+
+mask_names = os.listdir(ir_mask_path)
+img_names = os.listdir(img_path)
+print('Find images:',len(img_names))
+
+for i,img_name in enumerate(img_names,1):
+    try:
+        img = Image.open(os.path.join(img_path,img_name))
+        img = transform_img(img)
+        img = np.array(img)
+        img = img[...,::-1]    # RGB (PIL) -> BGR (OpenCV)
+
+        mask = Image.open(os.path.join(ir_mask_path,random.choices(mask_names)[0]))
+        mask = transform_mask(mask)
+        mask = np.array(mask)
+
+        mosaic_img = mosaic.addmosaic_random(img, mask)
+        if HD:
+            cv2.imwrite(os.path.join(train_A_path,'%05d' % i+'.jpg'), mosaic_img)
+            cv2.imwrite(os.path.join(train_B_path,'%05d' % i+'.jpg'), img)
+        else:
+            merge_img = impro.makedataset(mosaic_img, img)
+            cv2.imwrite(os.path.join(train_path,'%05d' % i+'.jpg'), merge_img)
+        if MASK:
+            cv2.imwrite(os.path.join(mask_path,'%05d' % i+'.png'), mask)
+        print("Processing:",img_name," ","Remain:",len(img_names)-i)
+    except Exception as e:
+        print(img_name,e)
diff --git a/train/train.py b/train/train.py
new file mode 100644
index 0000000..a88c5ae
--- /dev/null
+++ b/train/train.py
@@ -0,0 +1,209 @@
+import os
+import numpy as np
+import cv2
+import random
+import torch
+import torch.nn as nn
+import time
+
+import sys
+sys.path.append("..")
+from models import runmodel,loadmodel
+from util import mosaic,util,ffmpeg,filt,data
+from util import image_processing as impro
+from cores import Options
+from models import pix2pix_model
+from matplotlib import pyplot as plt
+import torch.backends.cudnn as cudnn
+
+N = 25    # number of consecutive mosaic frames fed to the generator
+ITER = 1000000
+LR = 0.0002
+use_gpu = True
+CONTINUE = True
+# BATCHSIZE = 4
+dir_checkpoint = 'checkpoints/'
+SAVE_FRE = 5000
+start_iter = 0
+SIZE = 256
+lambda_L1 = 100.0
+opt = Options().getparse()
+opt.use_gpu = True
+videos = os.listdir('./dataset')
+videos.sort()
+lengths = []
+for video in videos:
+    video_images = os.listdir('./dataset/'+video+'/ori')
+    lengths.append(len(video_images))
+
+
+netG = pix2pix_model.define_G(3*N+1, 3, 128, 'resnet_9blocks', norm='instance',use_dropout=True, init_type='normal', gpu_ids=[])
+netD = pix2pix_model.define_D(3*2, 64, 'basic', n_layers_D=3, norm='instance', init_type='normal', init_gain=0.02, gpu_ids=[])
+
+if CONTINUE:
+    netG.load_state_dict(torch.load(dir_checkpoint+'last_G.pth'))
+    netD.load_state_dict(torch.load(dir_checkpoint+'last_D.pth'))
+    f = open('./iter','r')
+    start_iter = int(f.read())
+    f.close()
+if use_gpu:
+    netG.cuda()
+    netD.cuda()
+    cudnn.benchmark = True
+optimizer_G = torch.optim.Adam(netG.parameters(), lr=LR)
+optimizer_D = torch.optim.Adam(netD.parameters(), lr=LR)
+criterion_L1 = nn.L1Loss()
+criterion_L2 = nn.MSELoss()
+criterionGAN = pix2pix_model.GANLoss('lsgan').cuda()
+
+def showresult(img1,img2,img3,name):
+    img1 = (img1.cpu().detach().numpy()*255)
+    img2 = (img2.cpu().detach().numpy()*255)
+    img3 = (img3.cpu().detach().numpy()*255)
+    batchsize = img1.shape[0]
+    size = img1.shape[3]
+    ran = int(batchsize*random.random())
+    showimg = np.zeros((size,size*3,3))
+    showimg[0:size,0:size] = img1[ran].transpose((1, 2, 0))
+    showimg[0:size,size:size*2] = img2[ran].transpose((1, 2, 0))
+    showimg[0:size,size*2:size*3] = img3[ran].transpose((1, 2, 0))
+    cv2.imwrite(name, showimg)


+def loaddata():
+    video_index = random.randint(0,len(videos)-1)
+    video = videos[video_index]
+    img_index = random.randint(N,lengths[video_index]-N)
+    input_img = np.zeros((SIZE,SIZE,3*N+1), dtype='uint8')
+    for i in range(0,N):
+        # print('./dataset/'+video+'/mosaic/output_'+'%05d'%(img_index+i-int(N/2))+'.png')
+        img = cv2.imread('./dataset/'+video+'/mosaic/output_'+'%05d'%(img_index+i-int(N/2))+'.png')
+        img = impro.resize(img,SIZE)
+        input_img[:,:,i*3:(i+1)*3] = img
+    mask = cv2.imread('./dataset/'+video+'/mask/output_'+'%05d'%(img_index)+'.png',0)
+    mask = impro.resize(mask,256)
+    mask = impro.mask_threshold(mask,15,128)
+    input_img[:,:,-1] = mask
+    input_img = data.im2tensor(input_img,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False)
+
+    ground_true = cv2.imread('./dataset/'+video+'/ori/output_'+'%05d'%(img_index)+'.png')
+    ground_true = impro.resize(ground_true,SIZE)
+    # ground_true = im2tensor(ground_true,use_gpu)
+    ground_true = data.im2tensor(ground_true,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False)
+    return input_img,ground_true
+
+input_imgs = []
+ground_trues = []
+def preload():
+    while 1:
+        input_img,ground_true = loaddata()
+        input_imgs.append(input_img)
+        ground_trues.append(ground_true)
+        if len(input_imgs) > 10:
+            del(input_imgs[0])
+            del(ground_trues[0])
+import threading
+t = threading.Thread(target=preload,args=())    # background thread that keeps the sample buffer filled
+t.start()
+time.sleep(3)    # wait for the first samples to be loaded


+netG.train()
+loss_sum = [0.,0.]
+loss_plot = [[],[]]
+item_plot = []
+time_start = time.time()
+print("Begin training...")
+for iter in range(start_iter+1,ITER):
+
+    # input_img,ground_true = loaddata()
+    ran = random.randint(0, 9)
+    input_img = input_imgs[ran]
+    ground_true = ground_trues[ran]
+
+    pred = netG(input_img)
+
+    fake_AB = torch.cat((input_img[:,int((N+1)/2)*3:(int((N+1)/2)+1)*3,:,:], pred), 1)
+    pred_fake = netD(fake_AB.detach())
+    loss_D_fake = criterionGAN(pred_fake, False)
+
+    real_AB = torch.cat((input_img[:,int((N+1)/2)*3:(int((N+1)/2)+1)*3,:,:], ground_true), 1)
+    pred_real = netD(real_AB)
+    loss_D_real = criterionGAN(pred_real, True)
+    loss_D = (loss_D_fake + loss_D_real) * 0.5
+
+    optimizer_D.zero_grad()
+    loss_D.backward()
+    optimizer_D.step()
+
+    fake_AB = torch.cat((input_img[:,int((N+1)/2)*3:(int((N+1)/2)+1)*3,:,:], pred), 1)
+    pred_fake = netD(fake_AB)
+    loss_G_GAN = criterionGAN(pred_fake, True)
+    # Second, G(A) = B
+    loss_G_L1 = criterion_L1(pred, ground_true) * lambda_L1
+    # combine loss and calculate gradients
+    loss_G = loss_G_GAN + loss_G_L1
+    loss_sum[0] += loss_G_L1.item()
+    loss_sum[1] += loss_G.item()
+
+    optimizer_G.zero_grad()
+    loss_G.backward()
+    optimizer_G.step()



+    # a = netD(ground_true)
+    # print(a.size())
+    # loss = criterion_L1(pred, ground_true)+criterion_L2(pred, ground_true)
+    # # loss = criterion_L2(pred, ground_true)
+    # loss_sum += loss.item()
+
+    # optimizer_G.zero_grad()
+    # loss.backward()
+    # optimizer_G.step()
+
+    if (iter+1)%100 == 0:
+        showresult(input_img[:,int((N+1)/2)*3:(int((N+1)/2)+1)*3,:,:], ground_true, pred,'./result_train.png')
+    if (iter+1)%100 == 0:
+        time_end = time.time()
+        print('iter:',iter+1,' L1_loss:', round(loss_sum[0]/100,4),'G_loss:', round(loss_sum[1]/100,4),'time:',round((time_end-time_start)/100,4))
+        if (iter+1)/100 >= 10:
+            loss_plot[0].append(loss_sum[0]/100)
+            loss_plot[1].append(loss_sum[1]/100)
+            item_plot.append(iter+1)
+            plt.plot(item_plot,loss_plot[0])
+            plt.plot(item_plot,loss_plot[1])
+            plt.savefig('./loss.png')
+            plt.close()
+        loss_sum = [0.,0.]
+
+        # show test result
+        # netG.eval()
+        # input_img = np.zeros((SIZE,SIZE,3*N), dtype='uint8')
+        # imgs = os.listdir('./test')
+        # for i in range(0,N):
+        #     # print('./dataset/'+video+'/mosaic/output_'+'%05d'%(img_index+i-int(N/2))+'.png')
+        #     img = cv2.imread('./test/'+imgs[i])
+        #     img = impro.resize(img,SIZE)
+        #     input_img[:,:,i*3:(i+1)*3] = img
+        # input_img = im2tensor(input_img,use_gpu)
+        # ground_true = cv2.imread('./test/output_'+'%05d'%13+'.png')
+        # ground_true = impro.resize(ground_true,SIZE)
+        # ground_true = im2tensor(ground_true,use_gpu)
+        # pred = netG(input_img)
+        # showresult(input_img[:,int((N+1)/2)*3:(int((N+1)/2)+1)*3,:,:],pred,pred,'./result_test.png')
+
+        netG.train()
+        time_start = time.time()
+
+    if (iter+1)%SAVE_FRE == 0:
+        torch.save(netG.cpu().state_dict(),dir_checkpoint+'last_G.pth')
+        torch.save(netD.cpu().state_dict(),dir_checkpoint+'last_D.pth')
+        if use_gpu:
+            netG.cuda()
+            netD.cuda()
+        f = open('./iter','w+')
+        f.write(str(iter+1))
+        f.close()
+        # torch.save(netG.cpu().state_dict(),dir_checkpoint+'iter'+str(iter+1)+'.pth')
+        print('network saved.')
-- 
GitLab