Commit 7e7145a9 authored by: H hypox64

Spring Festival Evening

Parent 6c130aab
@@ -134,6 +134,7 @@ __pycache__/
 tmp/
 checkpoints/
 mask/
+mask_old/
 origin_image/
 datasets/
 dataset/
@@ -149,6 +150,8 @@ result/
 /result
 /reference
 /python_test.py
+/pretrained_models_old
+/deepmosaic_window
 #./make_datasets
 /make_datasets/video
 /make_datasets/tmp
......
@@ -6,17 +6,15 @@ from models import runmodel,loadmodel
 from util import mosaic,util,ffmpeg,filt,data
 from util import image_processing as impro
 
-def addmosaic_img(opt):
-    net = loadmodel.unet(opt)
+def addmosaic_img(opt,netS):
     path = opt.media_path
     print('Add Mosaic:',path)
     img = impro.imread(path)
-    mask = runmodel.get_ROI_position(img,net,opt)[0]
+    mask = runmodel.get_ROI_position(img,netS,opt)[0]
     img = mosaic.addmosaic(img,mask,opt)
     cv2.imwrite(os.path.join(opt.result_dir,os.path.splitext(os.path.basename(path))[0]+'_add.jpg'),img)
 
-def addmosaic_video(opt):
-    net = loadmodel.unet(opt)
+def addmosaic_video(opt,netS):
     path = opt.media_path
     util.clean_tempfiles()
     fps = ffmpeg.get_video_infos(path)[0]
@@ -30,7 +28,7 @@ def addmosaic_video(opt):
     for imagepath in imagepaths:
         print('Find ROI location:',imagepath)
         img = impro.imread(os.path.join('./tmp/video2image',imagepath))
-        mask,x,y,area = runmodel.get_ROI_position(img,net,opt)
+        mask,x,y,area = runmodel.get_ROI_position(img,netS,opt)
         positions.append([x,y,area])
         cv2.imwrite(os.path.join('./tmp/ROI_mask',imagepath),mask)
     print('Optimize ROI locations...')
@@ -49,13 +47,13 @@ def addmosaic_video(opt):
         './tmp/voice_tmp.mp3',
         os.path.join(opt.result_dir,os.path.splitext(os.path.basename(path))[0]+'_add.mp4'))
 
-def cleanmosaic_img(opt):
-    netG = loadmodel.pix2pix(opt)
-    net_mosaic_pos = loadmodel.unet_clean(opt)
+def cleanmosaic_img(opt,netG,netM):
     path = opt.media_path
     print('Clean Mosaic:',path)
     img_origin = impro.imread(path)
-    x,y,size = runmodel.get_mosaic_position(img_origin,net_mosaic_pos,opt)[:3]
+    x,y,size,mask = runmodel.get_mosaic_position(img_origin,netM,opt)
+    cv2.imwrite('./mask/'+os.path.basename(path), mask)
     img_result = img_origin.copy()
     if size != 0 :
         img_mosaic = img_origin[y-size:y+size,x-size:x+size]
@@ -65,9 +63,7 @@ def cleanmosaic_img(opt):
         print('Do not find mosaic')
     cv2.imwrite(os.path.join(opt.result_dir,os.path.splitext(os.path.basename(path))[0]+'_clean.jpg'),img_result)
 
-def cleanmosaic_video_byframe(opt):
-    netG = loadmodel.pix2pix(opt)
-    net_mosaic_pos = loadmodel.unet_clean(opt)
+def cleanmosaic_video_byframe(opt,netG,netM):
     path = opt.media_path
     util.clean_tempfiles()
     fps = ffmpeg.get_video_infos(path)[0]
@@ -80,7 +76,7 @@ def cleanmosaic_video_byframe(opt):
     # get position
     for imagepath in imagepaths:
         img_origin = impro.imread(os.path.join('./tmp/video2image',imagepath))
-        x,y,size = runmodel.get_mosaic_position(img_origin,net_mosaic_pos,opt)[:3]
+        x,y,size = runmodel.get_mosaic_position(img_origin,netM,opt)[:3]
         positions.append([x,y,size])
         print('Find mosaic location:',imagepath)
     print('Optimize mosaic locations...')
@@ -103,9 +99,7 @@ def cleanmosaic_video_byframe(opt):
         './tmp/voice_tmp.mp3',
         os.path.join(opt.result_dir,os.path.splitext(os.path.basename(path))[0]+'_clean.mp4'))
 
-def cleanmosaic_video_fusion(opt):
-    net = loadmodel.video(opt)
-    net_mosaic_pos = loadmodel.unet_clean(opt)
+def cleanmosaic_video_fusion(opt,netG,netM):
     path = opt.media_path
     N = 25
     INPUT_SIZE = 128
@@ -122,7 +116,7 @@ def cleanmosaic_video_fusion(opt):
     for imagepath in imagepaths:
         img_origin = impro.imread(os.path.join('./tmp/video2image',imagepath))
         # x,y,size = runmodel.get_mosaic_position(img_origin,net_mosaic_pos,opt)[:3]
-        x,y,size,mask = runmodel.get_mosaic_position(img_origin,net_mosaic_pos,opt)
+        x,y,size,mask = runmodel.get_mosaic_position(img_origin,netM,opt)
         cv2.imwrite(os.path.join('./tmp/mosaic_mask',imagepath), mask)
         positions.append([x,y,size])
         print('Find mosaic location:',imagepath)
@@ -151,11 +145,12 @@ def cleanmosaic_video_fusion(opt):
         mask = mask[y-size:y+size,x-size:x+size]
         mask = impro.resize(mask, INPUT_SIZE)
         mosaic_input[:,:,-1] = mask
-        mosaic_input = data.im2tensor(mosaic_input,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False)
-        unmosaic_pred = net(mosaic_input)
-        unmosaic_pred = (unmosaic_pred.cpu().detach().numpy()*255)[0]
-        img_fake = unmosaic_pred.transpose((1, 2, 0))
+        mosaic_input = data.im2tensor(mosaic_input,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1 = False)
+        unmosaic_pred = netG(mosaic_input)
+        #unmosaic_pred = (unmosaic_pred.cpu().detach().numpy()*255)[0]
+        #img_fake = unmosaic_pred.transpose((1, 2, 0))
+        img_fake = data.tensor2im(unmosaic_pred,rgb2bgr = False ,is0_1 = False)
         img_result = impro.replace_mosaic(img_origin,img_fake,x,y,size,opt.no_feather)
         cv2.imwrite(os.path.join('./tmp/replace_mosaic',imagepath),img_result)
......
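For reference, cleanmosaic_video_fusion feeds netG a stack of N = 25 neighbouring frames plus the mosaic mask as one extra channel (3*N+1 channels at INPUT_SIZE = 128), which is why the mask is written into the last channel above. A minimal sketch of assembling such an input, assuming the frames are already cropped and resized (the helper name and exact channel order are illustrative, not the project's code):

import numpy as np

N = 25
INPUT_SIZE = 128

def build_fusion_input(frames, mask):
    # frames: list of N BGR crops of shape (INPUT_SIZE, INPUT_SIZE, 3)
    # mask:   single-channel mask of shape (INPUT_SIZE, INPUT_SIZE)
    mosaic_input = np.zeros((INPUT_SIZE, INPUT_SIZE, 3 * N + 1), dtype=np.float32)
    for i, frame in enumerate(frames):
        mosaic_input[:, :, i * 3:(i + 1) * 3] = frame  # three colour channels per frame
    mosaic_input[:, :, -1] = mask                      # mask occupies the final channel
    return mosaic_input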
@@ -10,7 +10,7 @@ class Options():
     def initialize(self):
 
         #base
-        self.parser.add_argument('--use_gpu',type=bool,default=True, help='if True, use gpu')
+        self.parser.add_argument('--use_gpu',type=int,default=1, help='if 0, do not use gpu')
         # self.parser.add_argument('--use_gpu', action='store_true', help='if input it, use gpu')
         self.parser.add_argument('--media_path', type=str, default='./hands_test.mp4',help='your videos or images path')
         self.parser.add_argument('--mode', type=str, default='auto',help='add or clean mosaic into your media auto | add | clean')
......
 import os
 from cores import Options,core
 from util import util
+from models import loadmodel
 
 opt = Options().getparse()
 util.file_init(opt)
 
 def main():
+    if os.path.isdir(opt.media_path):
+        files = util.Traversal(opt.media_path)
+    else:
+        files = [opt.media_path]
+
     if opt.mode == 'add':
-        if util.is_img(opt.media_path):
-            core.addmosaic_img(opt)
-        elif util.is_video(opt.media_path):
-            core.addmosaic_video(opt)
-        else:
-            print('This type of file is not supported')
+        netS = loadmodel.unet(opt)
+        for file in files:
+            opt.media_path = file
+            if util.is_img(file):
+                core.addmosaic_img(opt,netS)
+            elif util.is_video(file):
+                core.addmosaic_video(opt,netS)
+                util.clean_tempfiles(tmp_init = False)
+            else:
+                print('This type of file is not supported')
 
     elif opt.mode == 'clean':
-        if util.is_img(opt.media_path):
-            core.cleanmosaic_img(opt)
-        elif util.is_video(opt.media_path):
-            if opt.netG == 'video':
-                core.cleanmosaic_video_fusion(opt)
-            else:
-                core.cleanmosaic_video_byframe(opt)
-        else:
-            print('This type of file is not supported')
-
-    util.clean_tempfiles(tmp_init = False)
+        netM = loadmodel.unet_clean(opt)
+        if opt.netG == 'video':
+            netG = loadmodel.video(opt)
+        else:
+            netG = loadmodel.pix2pix(opt)
+        for file in files:
+            opt.media_path = file
+            if util.is_img(file):
+                core.cleanmosaic_img(opt,netG,netM)
+            elif util.is_video(file):
+                if opt.netG == 'video':
+                    core.cleanmosaic_video_fusion(opt,netG,netM)
+                else:
+                    core.cleanmosaic_video_byframe(opt,netG,netM)
+                util.clean_tempfiles(tmp_init = False)
+            else:
+                print('This type of file is not supported')
 
-if __name__ == '__main__':
-    try:
-        main()
-    except Exception as e:
-        print('Error:',e)
-        input('Please press any key to exit.\n')
-        util.clean_tempfiles(tmp_init = False)
-        exit(0)
+main()
+# if __name__ == '__main__':
+#     try:
+#         main()
+#     except Exception as e:
+#         print('Error:',e)
+#         input('Please press any key to exit.\n')
+#         util.clean_tempfiles(tmp_init = False)
+#         exit(0)
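The rewritten entry point loads each network once and then walks every file under --media_path instead of reloading weights per file. A rough sketch of that load-once, process-many pattern, with generic load_model/process callables standing in for the project's loaders:

import os

def run_batch(media_path, load_model, process):
    # Collect the work list: a single file, or every file below a directory.
    if os.path.isdir(media_path):
        files = [os.path.join(root, name)
                 for root, _, names in os.walk(media_path) for name in names]
    else:
        files = [media_path]

    net = load_model()      # expensive step, done exactly once
    for path in files:
        process(path, net)  # per-file work reuses the already-loaded weights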
@@ -11,14 +11,15 @@ from util import util,mosaic
 import datetime
 import shutil
 
-mask_path = '/media/hypo/Porject/Datasets/unet/av/mask'
-img_path ='/media/hypo/Porject/Datasets/unet/av/origin_image'
+mask_dir = '/media/hypo/Project/MyProject/DeepMosaics/DeepMosaics/train/add/datasets/av/mask'
+img_dir ='/media/hypo/Project/MyProject/DeepMosaics/DeepMosaics/train/add/datasets/av/origin_image'
 output_dir = './datasets_img'
 util.makedirs(output_dir)
 HD = True # if false make dataset for pix2pix, if Ture for pix2pix_HD
-MASK = False # if True, output mask,too
+MASK = True # if True, output mask,too
 OUT_SIZE = 256
-FOLD_NUM = 5
+FOLD_NUM = 2
+Bounding = True
 
 if HD:
     train_A_path = os.path.join(output_dir,'train_A')
@@ -32,8 +33,8 @@ if MASK:
     mask_path = os.path.join(output_dir,'mask')
     util.makedirs(mask_path)
 
-mask_names = os.listdir(mask_path)
-img_names = os.listdir(img_path)
+mask_names = os.listdir(mask_dir)
+img_names = os.listdir(img_dir)
 mask_names.sort()
 img_names.sort()
 print('Find images:',len(img_names))
@@ -42,13 +43,14 @@ cnt = 0
 for fold in range(FOLD_NUM):
     for img_name,mask_name in zip(img_names,mask_names):
        try:
-            img = impro.imread(os.path.join(img_path,img_name))
-            mask = impro.imread(os.path.join(mask_path,mask_name),'gray')
+            img = impro.imread(os.path.join(img_dir,img_name))
+            mask = impro.imread(os.path.join(mask_dir,mask_name),'gray')
            mask = impro.resize_like(mask, img)
            x,y,size,area = impro.boundingSquare(mask, 1.5)
            if area > 100:
-                img = impro.resize(img[y-size:y+size,x-size:x+size],OUT_SIZE)
-                mask = impro.resize(mask[y-size:y+size,x-size:x+size],OUT_SIZE)
+                if Bounding:
+                    img = impro.resize(img[y-size:y+size,x-size:x+size],OUT_SIZE)
+                    mask = impro.resize(mask[y-size:y+size,x-size:x+size],OUT_SIZE)
                img_mosaic = mosaic.addmosaic_random(img, mask)
                if HD:
......
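impro.boundingSquare(mask, 1.5) is used above to crop an expanded square around the labelled region before resizing. A rough NumPy equivalent of that idea (the centring and clamping details are assumptions, not the project's exact implementation):

import numpy as np

def bounding_square(mask, ex_mul=1.5):
    # Returns centre (x, y), half-size and pixel area of an expanded square
    # around the nonzero part of a binary mask.
    ys, xs = np.nonzero(mask)
    area = len(xs)
    if area == 0:
        return 0, 0, 0, 0
    x, y = int(xs.mean()), int(ys.mean())
    half = int(max(xs.max() - xs.min(), ys.max() - ys.min()) / 2 * ex_mul)
    h, w = mask.shape[:2]
    half = min(half, x, y, w - x - 1, h - y - 1)  # keep the square inside the image
    return x, y, half, area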
@@ -11,11 +11,11 @@ from util import util,mosaic
 import datetime
 
 ir_mask_path = './Irregular_Holes_mask'
-img_dir ='/home/hypo/MyProject/Haystack/CV/output/all/face'
-MOD = 'HD' #HD | pix2pix | mosaic
+img_dir ='/media/hypo/Hypoyun/Datasets/other/face512'
+MOD = 'mosaic' #HD | pix2pix | mosaic
 MASK = False # if True, output mask,too
-BOUNDING = True # if true the mosaic size will be more big
-suffix = ''
+BOUNDING = False # if true the mosaic size will be more big
+suffix = '_1'
 output_dir = os.path.join('./datasets_img',MOD)
 util.makedirs(output_dir)
@@ -27,6 +27,13 @@ if MOD == 'HD':
 elif MOD == 'pix2pix':
     train_path = os.path.join(output_dir,'train')
     util.makedirs(train_path)
+elif MOD == 'mosaic':
+    ori_path = os.path.join(output_dir,'ori')
+    mosaic_path = os.path.join(output_dir,'mosaic')
+    mask_path = os.path.join(output_dir,'mask')
+    util.makedirs(ori_path)
+    util.makedirs(mosaic_path)
+    util.makedirs(mask_path)
 if MASK:
     mask_path = os.path.join(output_dir,'mask')
     util.makedirs(mask_path)
@@ -43,12 +50,13 @@ transform_img = transforms.Compose([
 ])
 
 mask_names = os.listdir(ir_mask_path)
-img_names = os.listdir(img_dir)
-print('Find images:',len(img_names))
+img_paths = util.Traversal(img_dir)
+img_paths = util.is_imgs(img_paths)
+print('Find images:',len(img_paths))
 
-for i,img_name in enumerate(img_names,1):
+for i,img_path in enumerate(img_paths,1):
     try:
-        img = Image.open(os.path.join(img_dir,img_name))
+        img = Image.open(img_path)
         img = transform_img(img)
         img = np.array(img)
         img = img[...,::-1]
@@ -70,11 +78,16 @@ for i,img_name in enumerate(img_names,1):
         if MOD == 'HD':#[128:384,128:384,:] --->256
             cv2.imwrite(os.path.join(train_A_path,'%05d' % i+suffix+'.jpg'), mosaic_img)
             cv2.imwrite(os.path.join(train_B_path,'%05d' % i+suffix+'.jpg'), img)
-        else:
+            if MASK:
+                cv2.imwrite(os.path.join(mask_path,'%05d' % i+suffix+'.png'), mask)
+        elif MOD == 'pix2pix':
             merge_img = impro.makedataset(mosaic_img, img)
             cv2.imwrite(os.path.join(train_path,'%05d' % i+suffix+'.jpg'), merge_img)
-        if MASK:
-            cv2.imwrite(os.path.join(mask_path,'%05d' % i+suffix+'.png'), mask)
-        print('\r','Proc/all:'+str(i)+'/'+str(len(img_names)),util.get_bar(100*i/len(img_names),num=40),end='')
+        elif MOD == 'mosaic':
+            cv2.imwrite(os.path.join(mosaic_path,'%05d' % i+suffix+'.jpg'), mosaic_img)
+            cv2.imwrite(os.path.join(ori_path,'%05d' % i+suffix+'.jpg'), img)
+            cv2.imwrite(os.path.join(mask_path,'%05d' % i+suffix+'.png'), mask)
+
+        print('\r','Proc/all:'+str(i)+'/'+str(len(img_paths)),util.get_bar(100*i/len(img_paths),num=40),end='')
     except Exception as e:
-        print(img_name,e)
+        print(img_path,e)
@@ -5,7 +5,7 @@ from util import mosaic
 from util import data
 import torch
 
-def run_unet(img,net,size = 128,use_gpu = True):
+def run_unet(img,net,size = 224,use_gpu = True):
     img=impro.image2folat(img,3)
     img=img.reshape(1,3,size,size)
     img = torch.from_numpy(img)
@@ -16,12 +16,12 @@ def run_unet(img,net,size = 128,use_gpu = True):
     pred = pred.reshape(size,size).astype('uint8')
     return pred
 
-def run_unet_rectim(img,net,size = 128,use_gpu = True):
+def run_unet_rectim(img,net,size = 224,use_gpu = True):
     img = impro.resize(img,size)
-    img1,img2 = impro.spiltimage(img)
-    mask1 = run_unet(img1,net,size = 128,use_gpu = use_gpu)
-    mask2 = run_unet(img2,net,size = 128,use_gpu = use_gpu)
-    mask = impro.mergeimage(mask1,mask2,img)
+    img1,img2 = impro.spiltimage(img,size)
+    mask1 = run_unet(img1,net,size,use_gpu = use_gpu)
+    mask2 = run_unet(img2,net,size,use_gpu = use_gpu)
+    mask = impro.mergeimage(mask1,mask2,img,size)
     return mask
 
 def run_pix2pix(img,net,opt):
@@ -42,8 +42,9 @@ def get_ROI_position(img,net,opt):
 def get_mosaic_position(img_origin,net_mosaic_pos,opt,threshold = 128 ):
     mask = run_unet_rectim(img_origin,net_mosaic_pos,use_gpu = opt.use_gpu)
-    mask = impro.mask_threshold(mask,10,threshold)
+    mask_1 = mask.copy()
+    mask = impro.mask_threshold(mask,20,threshold)
     x,y,size,area = impro.boundingSquare(mask,Ex_mul=1.5)
-    rat = min(img_origin.shape[:2])/128.0
+    rat = min(img_origin.shape[:2])/224.0
     x,y,size = int(rat*x),int(rat*y),int(rat*size)
-    return x,y,size,mask
+    return x,y,size,mask_1
\ No newline at end of file
@@ -31,4 +31,4 @@ class UNet(nn.Module):
         x = self.up3(x, x2)
         x = self.up4(x, x1)
         x = self.outc(x)
-        return torch.Tanh(x)
+        return x
\ No newline at end of file
@@ -90,7 +90,12 @@ class up(nn.Module):
 class outconv(nn.Module):
     def __init__(self, in_ch, out_ch):
         super(outconv, self).__init__()
-        self.conv = nn.Conv2d(in_ch, out_ch, 1)
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_ch, out_ch, 1),
+            nn.Sigmoid()
+        )
 
     def forward(self, x):
         x = self.conv(x)
......
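With this change the segmentation U-Net ends in a Sigmoid, so its output lies in [0,1] and pairs directly with the nn.BCELoss() criterion and the 0-1 normalized masks used in the training script below, replacing the earlier Tanh/[-1,1] convention. A minimal self-contained sketch of that pairing:

import torch
import torch.nn as nn

# 1x1 output convolution followed by Sigmoid, trained against binary masks with BCELoss.
outconv = nn.Sequential(nn.Conv2d(64, 1, 1), nn.Sigmoid())
criterion = nn.BCELoss()

features = torch.randn(2, 64, 224, 224)                  # stand-in for the last feature map
target = torch.randint(0, 2, (2, 1, 224, 224)).float()   # mask already in {0, 1}

pred = outconv(features)        # values in (0, 1) thanks to the Sigmoid
loss = criterion(pred, target)
loss.backward()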
@@ -22,17 +22,19 @@ import torch.backends.cudnn as cudnn
 LR = 0.0002
 EPOCHS = 100
-BATCHSIZE = 16
+BATCHSIZE = 8
 LOADSIZE = 256
 FINESIZE = 224
-CONTINUE = False
+CONTINUE = True
 use_gpu = True
-SAVE_FRE = 5
-cudnn.benchmark = False
+SAVE_FRE = 1
+MAX_LOAD = 35000
+#cudnn.benchmark = True
 
-dir_img = './datasets/av/origin_image/'
-dir_mask = './datasets/av/mask/'
-dir_checkpoint = 'checkpoints/'
+dir_img = './datasets/mosaic/mosaic/'
+dir_mask = './datasets/mosaic/mask/'
+dir_checkpoint = 'checkpoints/mosaic/'
 
 def Totensor(img,use_gpu=True):
@@ -43,15 +45,15 @@ def Totensor(img,use_gpu=True):
     return img
 
-def Toinputshape(imgs,masks,finesize):
+def Toinputshape(imgs,masks,finesize,test_flag = False):
     batchsize = len(imgs)
     result_imgs=[];result_masks=[]
     for i in range(batchsize):
         # print(imgs[i].shape,masks[i].shape)
-        img,mask = data.random_transform_image(imgs[i], masks[i], finesize)
+        img,mask = data.random_transform_image(imgs[i], masks[i], finesize, test_flag)
         # print(img.shape,mask.shape)
-        mask = (mask.reshape(1,finesize,finesize)/255.0-0.5)/0.5
-        img = (img.transpose((2, 0, 1))/255.0-0.5)/0.5
+        mask = (mask.reshape(1,finesize,finesize)/255.0)
+        img = (img.transpose((2, 0, 1))/255.0)
         result_imgs.append(img)
         result_masks.append(mask)
     result_imgs = np.array(result_imgs)
@@ -74,9 +76,10 @@ def batch_generator(images,masks,batchsize):
 def loadimage(dir_img,dir_mask,loadsize,eval_p):
     t1 = datetime.datetime.now()
     imgnames = os.listdir(dir_img)
     # imgnames = imgnames[:100]
-    print('images num:',len(imgnames))
     random.shuffle(imgnames)
+    imgnames = imgnames[:MAX_LOAD]
+    print('load images:',len(imgnames))
     imgnames = (f[:-4] for f in imgnames)
     images = []
     masks = []
@@ -94,7 +97,7 @@ def loadimage(dir_img,dir_mask,loadsize,eval_p):
     return train_images,train_masks,eval_images,eval_masks
 
+util.makedirs(dir_checkpoint)
 print('loading data......')
 train_images,train_masks,eval_images,eval_masks = loadimage(dir_img,dir_mask,LOADSIZE,0.2)
 dataset_eval_images,dataset_eval_masks = batch_generator(eval_images,eval_masks,BATCHSIZE)
@@ -104,6 +107,10 @@ dataset_train_images,dataset_train_masks = batch_generator(train_images,train_ma
 net = unet_model.UNet(n_channels = 3, n_classes = 1)
+if CONTINUE:
+    if not os.path.isfile(os.path.join(dir_checkpoint,'last.pth')):
+        CONTINUE = False
+        print('can not load last.pth, training on init weight.')
 if CONTINUE:
     net.load_state_dict(torch.load(dir_checkpoint+'last.pth'))
 if use_gpu:
@@ -117,6 +124,7 @@ criterion = nn.BCELoss()
 print('begin training......')
 for epoch in range(EPOCHS):
+    random_save = random.randint(0, len(dataset_train_images))
     starttime = datetime.datetime.now()
     print('Epoch {}/{}.'.format(epoch + 1, EPOCHS))
@@ -139,15 +147,18 @@ for epoch in range(EPOCHS):
         optimizer.step()
 
         if i%100 == 0:
-            data.showresult(img,mask,mask_pred,os.path.join(dir_checkpoint,'result.png'))
+            data.showresult(img,mask,mask_pred,os.path.join(dir_checkpoint,'result.png'),True)
+        if i == random_save:
+            data.showresult(img,mask,mask_pred,os.path.join(dir_checkpoint,'epoch_'+str(epoch+1)+'.png'),True)
 
     # torch.cuda.empty_cache()
     # # net.eval()
     epoch_loss_eval = 0
     with torch.no_grad():
+        #net.eval()
         for i,(img,mask) in enumerate(zip(dataset_eval_images,dataset_eval_masks)):
             # print(epoch,i,img.shape,mask.shape)
-            img,mask = Toinputshape(img, mask, FINESIZE)
+            img,mask = Toinputshape(img, mask, FINESIZE,test_flag=True)
             img = Totensor(img,use_gpu)
             mask = Totensor(mask,use_gpu)
             mask_pred = net(img)
@@ -164,5 +175,5 @@ for epoch in range(EPOCHS):
     if (epoch+1)%SAVE_FRE == 0:
         torch.save(net.cpu().state_dict(),dir_checkpoint+'epoch'+str(epoch+1)+'.pth')
-        data.showresult(img,mask,mask_pred,os.path.join(dir_checkpoint,'epoch_'+str(epoch+1)+'.png'))
         print('network saved.')
@@ -21,19 +21,19 @@ ITER = 10000000
 LR = 0.0002
 beta1 = 0.5
 use_gpu = True
-use_gan = False
-use_L2 = True
+use_gan = True
+use_L2 = False
 CONTINUE = True
 lambda_L1 = 100.0
 lambda_gan = 1
 SAVE_FRE = 10000
 start_iter = 0
-finesize = 128
-loadsize = int(finesize*1.1)
-batchsize = 8
+finesize = 256
+loadsize = int(finesize*1.2)
+batchsize = 1
 perload_num = 16
-savename = 'MosaicNet_batch'
+savename = 'MosaicNet_instance_gan_256_D5'
 dir_checkpoint = 'checkpoints/'+savename
 util.makedirs(dir_checkpoint)
@@ -57,10 +57,14 @@ loadmodel.show_paramsnumber(netG,'netG')
 # netG = unet_model.UNet(3*N+1, 3)
 
 if use_gan:
     #netD = pix2pix_model.define_D(3*2+1, 64, 'pixel', norm='instance')
-    netD = pix2pix_model.define_D(3*2+1, 64, 'basic', norm='instance')
-    #netD = pix2pix_model.define_D(3*2+1, 64, 'n_layers', n_layers_D=5, norm='instance', init_type='normal', init_gain=0.02, gpu_ids=[])
+    #netD = pix2pix_model.define_D(3*2+1, 64, 'basic', norm='instance')
+    netD = pix2pix_model.define_D(3*2+1, 64, 'n_layers', n_layers_D=5, norm='instance')
 
 if CONTINUE:
+    if not os.path.isfile(os.path.join(dir_checkpoint,'last_G.pth')):
+        CONTINUE = False
+        print('can not load last_G, training on init weight.')
+if CONTINUE:
     netG.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_G.pth')))
     if use_gan:
         netD.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_D.pth')))
......
@@ -37,7 +37,7 @@ def im2tensor(image_numpy, imtype=np.uint8, gray=False,bgr2rgb = True, reshape =
         image_numpy = (image_numpy/255.0-0.5)/0.5
         image_tensor = torch.from_numpy(image_numpy).float()
         if reshape:
-            image_tensor=image_tensor.reshape(1,1,h,w)
+            image_tensor = image_tensor.reshape(1,1,h,w)
     else:
         h, w ,ch = image_numpy.shape
         if bgr2rgb:
@@ -52,7 +52,7 @@ def im2tensor(image_numpy, imtype=np.uint8, gray=False,bgr2rgb = True, reshape =
         image_numpy = image_numpy.transpose((2, 0, 1))
         image_tensor = torch.from_numpy(image_numpy).float()
         if reshape:
-            image_tensor=image_tensor.reshape(1,ch,h,w)
+            image_tensor = image_tensor.reshape(1,ch,h,w)
     if use_gpu:
         image_tensor = image_tensor.cuda()
     return image_tensor
@@ -91,7 +91,7 @@ def random_transform_video(src,target,finesize,N):
     return src,target
 
-def random_transform_image(img,mask,finesize):
+def random_transform_image(img,mask,finesize,test_flag = False):
 
     # randomsize = int(finesize*(1.2+0.2*random.random())+2)
@@ -118,6 +118,9 @@ def random_transform_image(img,mask,finesize):
     # print(h,w,h_move,w_move)
     img_crop = img[h_move:h_move+finesize,w_move:w_move+finesize]
     mask_crop = mask[h_move:h_move+finesize,w_move:w_move+finesize]
+    if test_flag:
+        return img_crop,mask_crop
 
     #random rotation
     if random.random()<0.2:
@@ -143,12 +146,19 @@ def random_transform_image(img,mask,finesize):
         else:
             img = img[::-1,:,:]
             mask = mask[::-1,:]
+
+    #random blur
+    if random.random()>0.5:
+        size_ran = random.uniform(0.5,1.5)
+        img = cv2.resize(img, (int(finesize*size_ran),int(finesize*size_ran)))
+        img = cv2.resize(img, (finesize,finesize))
+        #img = cv2.blur(img, (random.randint(1,3), random.randint(1,3)))
 
     return img,mask
 
-def showresult(img1,img2,img3,name):
+def showresult(img1,img2,img3,name,is0_1 = False):
     size = img1.shape[3]
     showimg=np.zeros((size,size*3,3))
-    showimg[0:size,0:size] = tensor2im(img1,rgb2bgr = False, is0_1 = False)
-    showimg[0:size,size:size*2] = tensor2im(img2,rgb2bgr = False, is0_1 = False)
-    showimg[0:size,size*2:size*3] = tensor2im(img3,rgb2bgr = False, is0_1 = False)
+    showimg[0:size,0:size] = tensor2im(img1,rgb2bgr = False, is0_1 = is0_1)
+    showimg[0:size,size:size*2] = tensor2im(img2,rgb2bgr = False, is0_1 = is0_1)
+    showimg[0:size,size*2:size*3] = tensor2im(img3,rgb2bgr = False, is0_1 = is0_1)
     cv2.imwrite(name, showimg)
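The new is0_1 flags in im2tensor/tensor2im/showresult make the normalization convention explicit: the segmentation masks and U-Net inputs now live in [0,1], while the clean-mosaic generator keeps working in [-1,1] (is0_1 = False). A small sketch of the two round trips, assuming the real helpers additionally handle BGR/RGB order and reshaping:

import numpy as np

def to_float(img_uint8, is0_1):
    # uint8 image -> float in [0,1] (is0_1=True) or [-1,1] (is0_1=False)
    x = img_uint8.astype(np.float32) / 255.0
    return x if is0_1 else (x - 0.5) / 0.5

def to_uint8(x, is0_1):
    # inverse mapping back to a displayable uint8 image
    x = x if is0_1 else (x + 1.0) / 2.0
    return np.clip(x * 255.0, 0, 255).astype(np.uint8)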
@@ -84,9 +84,9 @@ def image2folat(img,ch):
         img = (img.transpose((2, 0, 1))/255.0).astype(np.float32)
     return img
 
-def spiltimage(img):
+def spiltimage(img,size = 128):
     h, w = img.shape[:2]
-    size = min(h,w)
+    # size = min(h,w)
     if w >= h:
         img1 = img[:,0:size]
         img2 = img[:,w-size:w]
@@ -96,12 +96,12 @@ def spiltimage(img):
     return img1,img2
 
-def mergeimage(img1,img2,orgin_image):
+def mergeimage(img1,img2,orgin_image,size = 128):
     h, w = orgin_image.shape[:2]
     new_img1 = np.zeros((h,w), dtype = "uint8")
     new_img2 = np.zeros((h,w), dtype = "uint8")
-    size = min(h,w)
+    # size = min(h,w)
     if w >= h:
         new_img1[:,0:size]=img1
         new_img2[:,w-size:w]=img2
......
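spiltimage/mergeimage now take an explicit size (128 by default, 224 when called from run_unet_rectim), so a non-square frame is covered by two fixed-size square crops whose masks are pasted back at the same offsets. A rough sketch of the idea for the landscape case, with the overlap simply combined by maximum (the project's merge may differ):

import numpy as np

def split_wide(img, size=128):
    # two size-wide crops covering the left and right ends of a landscape frame
    h, w = img.shape[:2]
    return img[:, 0:size], img[:, w - size:w]

def merge_wide(mask1, mask2, origin, size=128):
    # paste the two crop masks back at their offsets and combine them
    h, w = origin.shape[:2]
    full1 = np.zeros((h, w), dtype="uint8")
    full2 = np.zeros((h, w), dtype="uint8")
    full1[:, 0:size] = mask1
    full2[:, w - size:w] = mask2
    return np.maximum(full1, full2)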