提交 bbd277bd 编写于 作者: H hypox64

Allow fps limit(--fps) and traditional method to clean mosaic(--traditional)

上级 d3cf56de
...@@ -153,6 +153,7 @@ result/ ...@@ -153,6 +153,7 @@ result/
/python_test.py /python_test.py
/pretrained_models_old /pretrained_models_old
/deepmosaic_window /deepmosaic_window
/sftp-config.json
#./make_datasets #./make_datasets
/make_datasets/video /make_datasets/video
/make_datasets/tmp /make_datasets/tmp
...@@ -160,6 +161,9 @@ result/ ...@@ -160,6 +161,9 @@ result/
/make_datasets/datasets /make_datasets/datasets
/make_datasets/dataset /make_datasets/dataset
/make_datasets/datasets_img /make_datasets/datasets_img
/make_datasets/videos
#./models
/models/videoHD_model.py
#./train #./train
/train/clean/dataset /train/clean/dataset
#mediafile #mediafile
...@@ -177,4 +181,5 @@ result/ ...@@ -177,4 +181,5 @@ result/
*.rmvb *.rmvb
*.JPG *.JPG
*.MP4 *.MP4
*.JPEG *.JPEG
\ No newline at end of file *.exe
\ No newline at end of file
...@@ -48,17 +48,17 @@ git clone https://github.com/HypoX64/DeepMosaics ...@@ -48,17 +48,17 @@ git clone https://github.com/HypoX64/DeepMosaics
cd DeepMosaics cd DeepMosaics
``` ```
#### Get pre_trained models and test video #### Get pre_trained models and test video
You can download pre_trained models and test video and replace the files in the project.<br> You can download pre_trained models and put them into './pretrained_models'.<br>
[[Google Drive]](https://drive.google.com/open?id=1LTERcN33McoiztYEwBxMuRjjgxh4DEPs) [[百度云,提取码1x0a]](https://pan.baidu.com/s/10rN3U3zd5TmfGpO_PEShqQ) [[Google Drive]](https://drive.google.com/open?id=1LTERcN33McoiztYEwBxMuRjjgxh4DEPs) [[百度云,提取码1x0a]](https://pan.baidu.com/s/10rN3U3zd5TmfGpO_PEShqQ)
#### Simple example #### Simple example
* Add Mosaic (output video will be saved in './result') * Add Mosaic (output video will be saved in './result')
```bash ```bash
python3 deepmosaic.py python3 deepmosaic.py --media_path ./imgs/ruoruo.jpg --model_path ./pretrained_models/mosaic/add_face.pth --use_gpu -1
``` ```
* Clean Mosaic (output video will be saved in './result') * Clean Mosaic (output video will be saved in './result')
```bash ```bash
python3 deepmosaic.py --mode clean --model_path ./pretrained_models/clean_hands_unet_128.pth --media_path ./result/hands_test_AddMosaic.mp4 python3 deepmosaic.py --media_path ./result/ruoruo_add.jpg --model_path ./pretrained_models/mosaic/clean_face_HD.pth --use_gpu -1
``` ```
#### More parameters #### More parameters
If you want to test other images or videos, please refer to this file. If you want to test other images or videos, please refer to this file.
......
...@@ -49,18 +49,18 @@ ...@@ -49,18 +49,18 @@
git clone https://github.com/HypoX64/DeepMosaics git clone https://github.com/HypoX64/DeepMosaics
cd DeepMosaics cd DeepMosaics
``` ```
#### 下载测试视频以及预训练模型 #### 下载预训练模型
可以通过以下两种方法下载测试视频以及预训练模型,并将他们置于项目文件夹中.<br> 可以通过以下两种方法下载预训练模型,并将他们置于'./pretrained_models'文件夹中.<br>
[[Google Drive]](https://drive.google.com/open?id=1LTERcN33McoiztYEwBxMuRjjgxh4DEPs) [[百度云,提取码1x0a]](https://pan.baidu.com/s/10rN3U3zd5TmfGpO_PEShqQ) <br> [[Google Drive]](https://drive.google.com/open?id=1LTERcN33McoiztYEwBxMuRjjgxh4DEPs) [[百度云,提取码1x0a]](https://pan.baidu.com/s/10rN3U3zd5TmfGpO_PEShqQ) <br>
#### 简单的例子 #### 简单的例子
* 为视频添加马赛克,例子中认为是需要打码的区域 ,可以通过切换预训练模型切换自动打码区域(输出结果将储存到 './result') * 为视频添加马赛克,例子中认为是需要打码的区域 ,可以通过切换预训练模型切换自动打码区域(输出结果将储存到 './result')
```bash ```bash
python3 deepmosaic.py python3 deepmosaic.py --media_path ./imgs/ruoruo.jpg --model_path ./pretrained_models/mosaic/add_face.pth --use_gpu -1
``` ```
* 将视频中的马赛克移除,对于不同的打码物体需要使用对应的预训练模型进行马赛克消除(输出结果将储存到 './result') * 将视频中的马赛克移除,对于不同的打码物体需要使用对应的预训练模型进行马赛克消除(输出结果将储存到 './result')
```bash ```bash
python3 deepmosaic.py --mode clean --model_path ./pretrained_models/clean_hands_unet_128.pth --media_path ./result/hands_test_AddMosaic.mp4 python3 deepmosaic.py --media_path ./result/ruoruo_add.jpg --model_path ./pretrained_models/mosaic/clean_face_HD.pth --use_gpu -1
``` ```
#### 更多的参数 #### 更多的参数
如果想要测试其他的图片或视频,请参照以下文件输入参数. 如果想要测试其他的图片或视频,请参照以下文件输入参数.
......
...@@ -6,15 +6,23 @@ from models import runmodel,loadmodel ...@@ -6,15 +6,23 @@ from models import runmodel,loadmodel
from util import mosaic,util,ffmpeg,filt,data from util import mosaic,util,ffmpeg,filt,data
from util import image_processing as impro from util import image_processing as impro
'''
---------------------Video Init---------------------
'''
def video_init(opt,path): def video_init(opt,path):
util.clean_tempfiles() util.clean_tempfiles()
fps = ffmpeg.get_video_infos(path)[0] fps,endtime,height,width = ffmpeg.get_video_infos(path)
if opt.fps !=0:
fps = opt.fps
ffmpeg.video2voice(path,'./tmp/voice_tmp.mp3') ffmpeg.video2voice(path,'./tmp/voice_tmp.mp3')
ffmpeg.video2image(path,'./tmp/video2image/output_%05d.'+opt.tempimage_type) ffmpeg.video2image(path,'./tmp/video2image/output_%05d.'+opt.tempimage_type,fps)
imagepaths=os.listdir('./tmp/video2image') imagepaths=os.listdir('./tmp/video2image')
imagepaths.sort() imagepaths.sort()
return fps,imagepaths return fps,imagepaths,height,width
'''
---------------------Add Mosaic---------------------
'''
def addmosaic_img(opt,netS): def addmosaic_img(opt,netS):
path = opt.media_path path = opt.media_path
print('Add Mosaic:',path) print('Add Mosaic:',path)
...@@ -25,7 +33,7 @@ def addmosaic_img(opt,netS): ...@@ -25,7 +33,7 @@ def addmosaic_img(opt,netS):
def addmosaic_video(opt,netS): def addmosaic_video(opt,netS):
path = opt.media_path path = opt.media_path
fps,imagepaths = video_init(opt,path) fps,imagepaths = video_init(opt,path)[:2]
# get position # get position
positions = [] positions = []
for i,imagepath in enumerate(imagepaths,1): for i,imagepath in enumerate(imagepaths,1):
...@@ -33,7 +41,7 @@ def addmosaic_video(opt,netS): ...@@ -33,7 +41,7 @@ def addmosaic_video(opt,netS):
mask,x,y,area = runmodel.get_ROI_position(img,netS,opt) mask,x,y,area = runmodel.get_ROI_position(img,netS,opt)
positions.append([x,y,area]) positions.append([x,y,area])
cv2.imwrite(os.path.join('./tmp/ROI_mask',imagepath),mask) cv2.imwrite(os.path.join('./tmp/ROI_mask',imagepath),mask)
print('\r','Find ROI location:'+str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=40),end='') print('\r','Find ROI location:'+str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
print('\nOptimize ROI locations...') print('\nOptimize ROI locations...')
mask_index = filt.position_medfilt(np.array(positions), 7) mask_index = filt.position_medfilt(np.array(positions), 7)
...@@ -44,13 +52,16 @@ def addmosaic_video(opt,netS): ...@@ -44,13 +52,16 @@ def addmosaic_video(opt,netS):
if impro.mask_area(mask)>100: if impro.mask_area(mask)>100:
img = mosaic.addmosaic(img, mask, opt) img = mosaic.addmosaic(img, mask, opt)
cv2.imwrite(os.path.join('./tmp/addmosaic_image',imagepaths[i]),img) cv2.imwrite(os.path.join('./tmp/addmosaic_image',imagepaths[i]),img)
print('\r','Add Mosaic:'+str(i+1)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=40),end='') print('\r','Add Mosaic:'+str(i+1)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
print() print()
ffmpeg.image2video( fps, ffmpeg.image2video( fps,
'./tmp/addmosaic_image/output_%05d.'+opt.tempimage_type, './tmp/addmosaic_image/output_%05d.'+opt.tempimage_type,
'./tmp/voice_tmp.mp3', './tmp/voice_tmp.mp3',
os.path.join(opt.result_dir,os.path.splitext(os.path.basename(path))[0]+'_add.mp4')) os.path.join(opt.result_dir,os.path.splitext(os.path.basename(path))[0]+'_add.mp4'))
'''
---------------------Style Transfer---------------------
'''
def styletransfer_img(opt,netG): def styletransfer_img(opt,netG):
print('Style Transfer_img:',opt.media_path) print('Style Transfer_img:',opt.media_path)
img = impro.imread(opt.media_path) img = impro.imread(opt.media_path)
...@@ -61,13 +72,13 @@ def styletransfer_img(opt,netG): ...@@ -61,13 +72,13 @@ def styletransfer_img(opt,netG):
def styletransfer_video(opt,netG): def styletransfer_video(opt,netG):
path = opt.media_path path = opt.media_path
positions = [] positions = []
fps,imagepaths = video_init(opt,path) fps,imagepaths = video_init(opt,path)[:2]
for i,imagepath in enumerate(imagepaths,1): for i,imagepath in enumerate(imagepaths,1):
img = impro.imread(os.path.join('./tmp/video2image',imagepath)) img = impro.imread(os.path.join('./tmp/video2image',imagepath))
img = runmodel.run_styletransfer(opt, netG, img) img = runmodel.run_styletransfer(opt, netG, img)
cv2.imwrite(os.path.join('./tmp/style_transfer',imagepath),img) cv2.imwrite(os.path.join('./tmp/style_transfer',imagepath),img)
print('\r','Transfer:'+str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=40),end='') print('\r','Transfer:'+str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
print() print()
suffix = os.path.basename(opt.model_path).replace('.pth','').replace('style_','') suffix = os.path.basename(opt.model_path).replace('.pth','').replace('style_','')
ffmpeg.image2video( fps, ffmpeg.image2video( fps,
...@@ -75,6 +86,24 @@ def styletransfer_video(opt,netG): ...@@ -75,6 +86,24 @@ def styletransfer_video(opt,netG):
'./tmp/voice_tmp.mp3', './tmp/voice_tmp.mp3',
os.path.join(opt.result_dir,os.path.splitext(os.path.basename(path))[0]+'_'+suffix+'.mp4')) os.path.join(opt.result_dir,os.path.splitext(os.path.basename(path))[0]+'_'+suffix+'.mp4'))
'''
---------------------Clean Mosaic---------------------
'''
def get_mosaic_positions(opt,netM,imagepaths,savemask=True):
    """Locate the mosaic in every extracted frame and smooth the track.

    Each name in ``imagepaths`` is read from './tmp/video2image' and passed
    to ``runmodel.get_mosaic_position`` together with ``netM`` and ``opt``.
    When ``savemask`` is true, the mask returned for a frame is written to
    './tmp/mosaic_mask' under the same file name.

    Returns a numpy array with one ``[x, y, size]`` row per frame, where
    each of the three columns has been passed through ``filt.medfilt`` with
    window ``opt.medfilt_num``.
    """
    total = len(imagepaths)
    located = []
    for index, name in enumerate(imagepaths, 1):
        frame = impro.imread(os.path.join('./tmp/video2image', name))
        x, y, size, mask = runmodel.get_mosaic_position(frame, netM, opt)
        if savemask:
            cv2.imwrite(os.path.join('./tmp/mosaic_mask', name), mask)
        located.append([x, y, size])
        # '\r' rewrites the same console line, giving an in-place progress bar.
        progress = 'Find mosaic location:'+str(index)+'/'+str(total)
        print('\r', progress, util.get_bar(100*index/total, num=35), end='')
    print('\nOptimize mosaic locations...')
    located = np.array(located)
    # Filter x, y and size independently along the frame axis.
    for column in range(3):
        located[:, column] = filt.medfilt(located[:, column], opt.medfilt_num)
    return located
def cleanmosaic_img(opt,netG,netM): def cleanmosaic_img(opt,netG,netM):
path = opt.media_path path = opt.media_path
...@@ -85,7 +114,10 @@ def cleanmosaic_img(opt,netG,netM): ...@@ -85,7 +114,10 @@ def cleanmosaic_img(opt,netG,netM):
img_result = img_origin.copy() img_result = img_origin.copy()
if size != 0 : if size != 0 :
img_mosaic = img_origin[y-size:y+size,x-size:x+size] img_mosaic = img_origin[y-size:y+size,x-size:x+size]
img_fake = runmodel.run_pix2pix(img_mosaic,netG,opt) if opt.traditional:
img_fake = runmodel.traditional_cleaner(img_mosaic,opt)
else:
img_fake = runmodel.run_pix2pix(img_mosaic,netG,opt)
img_result = impro.replace_mosaic(img_origin,img_fake,x,y,size,opt.no_feather) img_result = impro.replace_mosaic(img_origin,img_fake,x,y,size,opt.no_feather)
else: else:
print('Do not find mosaic') print('Do not find mosaic')
...@@ -93,19 +125,8 @@ def cleanmosaic_img(opt,netG,netM): ...@@ -93,19 +125,8 @@ def cleanmosaic_img(opt,netG,netM):
def cleanmosaic_video_byframe(opt,netG,netM): def cleanmosaic_video_byframe(opt,netG,netM):
path = opt.media_path path = opt.media_path
fps,imagepaths = video_init(opt,path) fps,imagepaths = video_init(opt,path)[:2]
positions = [] positions = get_mosaic_positions(opt,netM,imagepaths,savemask=False)
# get position
for i,imagepath in enumerate(imagepaths,1):
img_origin = impro.imread(os.path.join('./tmp/video2image',imagepath))
x,y,size = runmodel.get_mosaic_position(img_origin,netM,opt)[:3]
positions.append([x,y,size])
print('\r','Find mosaic location:'+str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=40),end='')
print('\nOptimize mosaic locations...')
positions =np.array(positions)
for i in range(3):positions[:,i] = filt.medfilt(positions[:,i],opt.medfilt_num)
# clean mosaic # clean mosaic
for i,imagepath in enumerate(imagepaths,0): for i,imagepath in enumerate(imagepaths,0):
x,y,size = positions[i][0],positions[i][1],positions[i][2] x,y,size = positions[i][0],positions[i][1],positions[i][2]
...@@ -113,10 +134,13 @@ def cleanmosaic_video_byframe(opt,netG,netM): ...@@ -113,10 +134,13 @@ def cleanmosaic_video_byframe(opt,netG,netM):
img_result = img_origin.copy() img_result = img_origin.copy()
if size != 0: if size != 0:
img_mosaic = img_origin[y-size:y+size,x-size:x+size] img_mosaic = img_origin[y-size:y+size,x-size:x+size]
img_fake = runmodel.run_pix2pix(img_mosaic,netG,opt) if opt.traditional:
img_result = impro.replace_mosaic(img_origin,img_fake,x,y,size,opt.no_feather) img_fake = runmodel.traditional_cleaner(img_mosaic,opt)
else:
img_fake = runmodel.run_pix2pix(img_mosaic,netG,opt)
img_result = impro.replace_mosaic(img_origin,img_fake,x,y,size,opt.no_feather)
cv2.imwrite(os.path.join('./tmp/replace_mosaic',imagepath),img_result) cv2.imwrite(os.path.join('./tmp/replace_mosaic',imagepath),img_result)
print('\r','Clean Mosaic:'+str(i+1)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=40),end='') print('\r','Clean Mosaic:'+str(i+1)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
print() print()
ffmpeg.image2video( fps, ffmpeg.image2video( fps,
'./tmp/replace_mosaic/output_%05d.'+opt.tempimage_type, './tmp/replace_mosaic/output_%05d.'+opt.tempimage_type,
...@@ -127,48 +151,39 @@ def cleanmosaic_video_fusion(opt,netG,netM): ...@@ -127,48 +151,39 @@ def cleanmosaic_video_fusion(opt,netG,netM):
path = opt.media_path path = opt.media_path
N = 25 N = 25
INPUT_SIZE = 128 INPUT_SIZE = 128
fps,imagepaths = video_init(opt,path) fps,imagepaths,height,width = video_init(opt,path)
positions = [] positions = get_mosaic_positions(opt,netM,imagepaths,savemask=True)
# get position
for i,imagepath in enumerate(imagepaths,1):
img_origin = impro.imread(os.path.join('./tmp/video2image',imagepath))
# x,y,size = runmodel.get_mosaic_position(img_origin,net_mosaic_pos,opt)[:3]
x,y,size,mask = runmodel.get_mosaic_position(img_origin,netM,opt)
cv2.imwrite(os.path.join('./tmp/mosaic_mask',imagepath), mask)
positions.append([x,y,size])
print('\r','Find mosaic location:'+str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=40),end='')
print('\nOptimize mosaic locations...')
positions =np.array(positions)
for i in range(3):positions[:,i] = filt.medfilt(positions[:,i],opt.medfilt_num)
# clean mosaic # clean mosaic
img_pool = np.zeros((height,width,3*N), dtype='uint8')
for i,imagepath in enumerate(imagepaths,0): for i,imagepath in enumerate(imagepaths,0):
x,y,size = positions[i][0],positions[i][1],positions[i][2] x,y,size = positions[i][0],positions[i][1],positions[i][2]
img_origin = impro.imread(os.path.join('./tmp/video2image',imagepath))
# image read stream
mask = cv2.imread(os.path.join('./tmp/mosaic_mask',imagepath),0) mask = cv2.imread(os.path.join('./tmp/mosaic_mask',imagepath),0)
if i==0 :
for j in range(0,N):
img_pool[:,:,j*3:(j+1)*3] = impro.imread(os.path.join('./tmp/video2image',imagepaths[np.clip(i+j-12,0,len(imagepaths)-1)]))
else:
img_pool[:,:,0:(N-1)*3] = img_pool[:,:,3:N*3]
img_pool[:,:,(N-1)*3:] = impro.imread(os.path.join('./tmp/video2image',imagepaths[np.clip(i+12,0,len(imagepaths)-1)]))
img_origin = img_pool[:,:,int((N-1)/2)*3:(int((N-1)/2)+1)*3]
if size==0: if size==0: # can not find mosaic,
cv2.imwrite(os.path.join('./tmp/replace_mosaic',imagepath),img_origin) cv2.imwrite(os.path.join('./tmp/replace_mosaic',imagepath),img_origin)
else: else:
mosaic_input = np.zeros((INPUT_SIZE,INPUT_SIZE,3*N+1), dtype='uint8') mosaic_input = np.zeros((INPUT_SIZE,INPUT_SIZE,3*N+1), dtype='uint8')
for j in range(0,N): mosaic_input[:,:,0:N*3] = impro.resize(img_pool[y-size:y+size,x-size:x+size,:], INPUT_SIZE)
img = impro.imread(os.path.join('./tmp/video2image',imagepaths[np.clip(i+j-12,0,len(imagepaths)-1)])) mask = impro.resize(mask,np.min(img_origin.shape[:2]))[y-size:y+size,x-size:x+size]
img = img[y-size:y+size,x-size:x+size] mosaic_input[:,:,-1] = impro.resize(mask, INPUT_SIZE)
img = impro.resize(img,INPUT_SIZE)
mosaic_input[:,:,j*3:(j+1)*3] = img
mask = impro.resize(mask,np.min(img_origin.shape[:2]))
mask = mask[y-size:y+size,x-size:x+size]
mask = impro.resize(mask, INPUT_SIZE)
mosaic_input[:,:,-1] = mask
mosaic_input = data.im2tensor(mosaic_input,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1 = False) mosaic_input = data.im2tensor(mosaic_input,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1 = False)
unmosaic_pred = netG(mosaic_input) unmosaic_pred = netG(mosaic_input)
#unmosaic_pred = (unmosaic_pred.cpu().detach().numpy()*255)[0]
#img_fake = unmosaic_pred.transpose((1, 2, 0))
img_fake = data.tensor2im(unmosaic_pred,rgb2bgr = False ,is0_1 = False) img_fake = data.tensor2im(unmosaic_pred,rgb2bgr = False ,is0_1 = False)
img_result = impro.replace_mosaic(img_origin,img_fake,x,y,size,opt.no_feather) img_result = impro.replace_mosaic(img_origin,img_fake,x,y,size,opt.no_feather)
cv2.imwrite(os.path.join('./tmp/replace_mosaic',imagepath),img_result) cv2.imwrite(os.path.join('./tmp/replace_mosaic',imagepath),img_result)
print('\r','Clean Mosaic:'+str(i+1)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=40),end='') print('\r','Clean Mosaic:'+str(i+1)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),end='')
print() print()
ffmpeg.image2video( fps, ffmpeg.image2video( fps,
'./tmp/replace_mosaic/output_%05d.'+opt.tempimage_type, './tmp/replace_mosaic/output_%05d.'+opt.tempimage_type,
......
...@@ -10,15 +10,16 @@ class Options(): ...@@ -10,15 +10,16 @@ class Options():
def initialize(self): def initialize(self):
#base #base
self.parser.add_argument('--use_gpu',type=int,default=1, help='if 0 or -1, do not use gpu') self.parser.add_argument('--use_gpu',type=int,default=0, help='if -1, do not use gpu')
# self.parser.add_argument('--use_gpu', action='store_true', help='if input it, use gpu') # self.parser.add_argument('--use_gpu', action='store_true', help='if input it, use gpu')
self.parser.add_argument('--media_path', type=str, default='./hands_test.mp4',help='your videos or images path') self.parser.add_argument('--media_path', type=str, default='./imgs/ruoruo.jpg',help='your videos or images path')
self.parser.add_argument('--mode', type=str, default='auto',help='auto | add | clean | style') self.parser.add_argument('--mode', type=str, default='auto',help='auto | add | clean | style')
self.parser.add_argument('--model_path', type=str, default='./pretrained_models/add_hands_128.pth',help='pretrained model path') self.parser.add_argument('--model_path', type=str, default='./pretrained_models/mosaic/add_face.pth',help='pretrained model path')
self.parser.add_argument('--result_dir', type=str, default='./result',help='output media will be saved here') self.parser.add_argument('--result_dir', type=str, default='./result',help='output media will be saved here')
self.parser.add_argument('--tempimage_type', type=str, default='png',help='type of temp image, png | jpg, png is better but occupy more storage space') self.parser.add_argument('--tempimage_type', type=str, default='png',help='type of temp image, png | jpg, png is better but occupy more storage space')
self.parser.add_argument('--netG', type=str, default='auto', self.parser.add_argument('--netG', type=str, default='auto',
help='select model to use for netG(Clean mosaic and Transfer style) -> auto | unet_128 | unet_256 | resnet_9blocks | HD | video') help='select model to use for netG(Clean mosaic and Transfer style) -> auto | unet_128 | unet_256 | resnet_9blocks | HD | video')
self.parser.add_argument('--fps', type=int, default=0,help='read and output fps, if 0-> origin')
self.parser.add_argument('--output_size', type=int, default=0,help='size of output file,if 0 -> origin') self.parser.add_argument('--output_size', type=int, default=0,help='size of output file,if 0 -> origin')
#AddMosaic #AddMosaic
...@@ -29,8 +30,11 @@ class Options(): ...@@ -29,8 +30,11 @@ class Options():
#CleanMosaic #CleanMosaic
self.parser.add_argument('--mosaic_position_model_path', type=str, default='auto',help='name of model use to find mosaic position') self.parser.add_argument('--mosaic_position_model_path', type=str, default='auto',help='name of model use to find mosaic position')
self.parser.add_argument('--traditional', action='store_true', help='if true, use traditional image processing methods to clean mosaic')
self.parser.add_argument('--tr_blur', type=int, default=10, help='ksize of blur when using traditional method, it will affect final quality')
self.parser.add_argument('--tr_down', type=int, default=10, help='downsample when using traditional method,it will affect final quality')
self.parser.add_argument('--no_feather', action='store_true', help='if true, no edge feather and color correction, but run faster') self.parser.add_argument('--no_feather', action='store_true', help='if true, no edge feather and color correction, but run faster')
self.parser.add_argument('--no_large_area', action='store_true', help='if true, do not find the largest mosaic area') self.parser.add_argument('--all_mosaic_area', action='store_true', help='if true, find all mosaic area, else only find the largest area')
self.parser.add_argument('--medfilt_num', type=int, default=11,help='medfilt window of mosaic movement in the video') self.parser.add_argument('--medfilt_num', type=int, default=11,help='medfilt window of mosaic movement in the video')
self.parser.add_argument('--ex_mult', type=str, default='auto',help='mosaic area expansion') self.parser.add_argument('--ex_mult', type=str, default='auto',help='mosaic area expansion')
...@@ -50,17 +54,16 @@ class Options(): ...@@ -50,17 +54,16 @@ class Options():
model_name = os.path.basename(self.opt.model_path) model_name = os.path.basename(self.opt.model_path)
if torch.cuda.is_available() and self.opt.use_gpu > 0: if torch.cuda.is_available() and self.opt.use_gpu > -1:
self.opt.use_gpu = True self.opt.use_gpu = True
else: else:
self.opt.use_gpu = False self.opt.use_gpu = False
if self.opt.mode == 'auto': if self.opt.mode == 'auto':
if 'add' in model_name: if 'clean' in model_name or self.opt.traditional:
self.opt.mode = 'add'
elif 'clean' in model_name:
self.opt.mode = 'clean' self.opt.mode = 'clean'
elif 'add' in model_name:
self.opt.mode = 'add'
elif 'style' in model_name or 'edges' in model_name: elif 'style' in model_name or 'edges' in model_name:
self.opt.mode = 'style' self.opt.mode = 'style'
else: else:
......
...@@ -25,7 +25,9 @@ def main(): ...@@ -25,7 +25,9 @@ def main():
elif opt.mode == 'clean': elif opt.mode == 'clean':
netM = loadmodel.unet_clean(opt) netM = loadmodel.unet_clean(opt)
if opt.netG == 'video': if opt.traditional:
netG = None
elif opt.netG == 'video':
netG = loadmodel.video(opt) netG = loadmodel.video(opt)
else: else:
netG = loadmodel.pix2pix(opt) netG = loadmodel.pix2pix(opt)
...@@ -35,7 +37,7 @@ def main(): ...@@ -35,7 +37,7 @@ def main():
if util.is_img(file): if util.is_img(file):
core.cleanmosaic_img(opt,netG,netM) core.cleanmosaic_img(opt,netG,netM)
elif util.is_video(file): elif util.is_video(file):
if opt.netG == 'video': if opt.netG == 'video' and not opt.traditional:
core.cleanmosaic_video_fusion(opt,netG,netM) core.cleanmosaic_video_fusion(opt,netG,netM)
else: else:
core.cleanmosaic_video_byframe(opt,netG,netM) core.cleanmosaic_video_byframe(opt,netG,netM)
...@@ -56,12 +58,12 @@ def main(): ...@@ -56,12 +58,12 @@ def main():
util.clean_tempfiles(tmp_init = False) util.clean_tempfiles(tmp_init = False)
# main() main()
if __name__ == '__main__': # if __name__ == '__main__':
try: # try:
main() # main()
except Exception as e: # except Exception as e:
print('Error:',e) # print('Error:',e)
input('Please press any key to exit.\n') # input('Please press any key to exit.\n')
util.clean_tempfiles(tmp_init = False) # util.clean_tempfiles(tmp_init = False)
exit(0) # exit(0)
...@@ -8,22 +8,10 @@ import torch ...@@ -8,22 +8,10 @@ import torch
import numpy as np import numpy as np
def run_unet(img,net,size = 224,use_gpu = True): def run_unet(img,net,size = 224,use_gpu = True):
img=impro.image2folat(img,3)
img=img.reshape(1,3,size,size)
img = torch.from_numpy(img)
if use_gpu:
img=img.cuda()
pred = net(img)
pred = (pred.cpu().detach().numpy()*255)
pred = pred.reshape(size,size).astype('uint8')
return pred
def run_unet_rectim(img,net,size = 224,use_gpu = True):
img = impro.resize(img,size) img = impro.resize(img,size)
img1,img2 = impro.spiltimage(img,size) img = data.im2tensor(img,use_gpu = use_gpu, bgr2rgb = False,use_transform = False , is0_1 = True)
mask1 = run_unet(img1,net,size,use_gpu = use_gpu) mask = net(img)
mask2 = run_unet(img2,net,size,use_gpu = use_gpu) mask = data.tensor2im(mask, gray=True,rgb2bgr = False, is0_1 = True)
mask = impro.mergeimage(mask1,mask2,img,size)
return mask return mask
def run_pix2pix(img,net,opt): def run_pix2pix(img,net,opt):
...@@ -36,6 +24,13 @@ def run_pix2pix(img,net,opt): ...@@ -36,6 +24,13 @@ def run_pix2pix(img,net,opt):
img_fake = data.tensor2im(img_fake) img_fake = data.tensor2im(img_fake)
return img_fake return img_fake
def traditional_cleaner(img,opt):
    """Smooth a mosaic patch without a neural network.

    The patch is blurred with a square ``opt.tr_blur`` kernel, reduced by
    keeping every ``opt.tr_down``-th row and column, and then resized back
    to its original width and height with Lanczos interpolation.

    ``img`` must have three dimensions, because the subsampling step
    indexes a trailing channel axis.
    """
    height, width = img.shape[:2]
    kernel = (opt.tr_blur, opt.tr_blur)
    step = opt.tr_down
    blurred = cv2.blur(img, kernel)
    shrunk = blurred[::step, ::step, :]
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(shrunk, (width, height), interpolation=cv2.INTER_LANCZOS4)
def run_styletransfer(opt, net, img): def run_styletransfer(opt, net, img):
if opt.output_size != 0: if opt.output_size != 0:
...@@ -60,23 +55,22 @@ def run_styletransfer(opt, net, img): ...@@ -60,23 +55,22 @@ def run_styletransfer(opt, net, img):
return img return img
img = data.im2tensor(img,use_gpu=opt.use_gpu,gray=True,use_transform = False,is0_1 = False) img = data.im2tensor(img,use_gpu=opt.use_gpu,gray=True,use_transform = False,is0_1 = False)
else: else:
img = data.im2tensor(img,use_gpu=opt.use_gpu) img = data.im2tensor(img,use_gpu=opt.use_gpu,gray=False,use_transform = True)
img = net(img) img = net(img)
img = data.tensor2im(img) img = data.tensor2im(img)
return img return img
def get_ROI_position(img,net,opt): def get_ROI_position(img,net,opt):
mask = run_unet_rectim(img,net,use_gpu = opt.use_gpu) mask = run_unet(img,net,size=224,use_gpu = opt.use_gpu)
mask = impro.mask_threshold(mask,opt.mask_extend,opt.mask_threshold) mask = impro.mask_threshold(mask,opt.mask_extend,opt.mask_threshold)
x,y,halfsize,area = impro.boundingSquare(mask, 1) x,y,halfsize,area = impro.boundingSquare(mask, 1)
return mask,x,y,area return mask,x,y,area
def get_mosaic_position(img_origin,net_mosaic_pos,opt,threshold = 128 ): def get_mosaic_position(img_origin,net_mosaic_pos,opt,threshold = 128 ):
mask = run_unet_rectim(img_origin,net_mosaic_pos,use_gpu = opt.use_gpu) mask = run_unet(img_origin,net_mosaic_pos,size=224,use_gpu = opt.use_gpu)
#mask_1 = mask.copy()
mask = impro.mask_threshold(mask,30,threshold) mask = impro.mask_threshold(mask,30,threshold)
if not opt.no_large_area: if not opt.all_mosaic_area:
mask = impro.find_best_ROI(mask) mask = impro.find_mostlikely_ROI(mask)
x,y,size,area = impro.boundingSquare(mask,Ex_mul=opt.ex_mult) x,y,size,area = impro.boundingSquare(mask,Ex_mul=opt.ex_mult)
rat = min(img_origin.shape[:2])/224.0 rat = min(img_origin.shape[:2])/224.0
x,y,size = int(rat*x),int(rat*y),int(rat*size) x,y,size = int(rat*x),int(rat*y),int(rat*size)
......
import torch
import torch.nn as nn
import torch.nn.functional as F
from .unet_parts import *
class conv_3d(nn.Module):
    """Bias-free 3D convolution followed by batch normalisation and ReLU.

    The three layers are stored, in that order, in the ``conv`` sequential
    container.
    """

    def __init__(self,inchannel,outchannel,kernel_size=3,stride=2,padding=1):
        super().__init__()
        convolution = nn.Conv3d(
            inchannel,
            outchannel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        normalisation = nn.BatchNorm3d(outchannel)
        activation = nn.ReLU(inplace=True)
        self.conv = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x):
        """Apply convolution, normalisation and activation to ``x``."""
        return self.conv(x)
class encoder_3d(nn.Module):
    """Encode a stack of planes with strided 3D convolutions, then collapse depth.

    The input's channel axis is treated as the depth axis of a
    single-channel volume. Four ``conv_3d`` stages (kernel 3, stride 2,
    padding 1) reduce depth, height and width, and ``conver2d`` then mixes
    the remaining depth slices of every feature map down to one plane.

    Args:
        in_channel: number of channels of the tensor given to ``forward``.

    ``forward`` takes a ``(batch, in_channel, H, W)`` tensor and returns a
    ``(batch, 512, H', W')`` tensor.
    """

    def __init__(self,in_channel):
        super(encoder_3d, self).__init__()
        self.down1 = conv_3d(1, 64, 3, 2, 1)
        self.down2 = conv_3d(64, 128, 3, 2, 1)
        self.down3 = conv_3d(128, 256, 3, 2, 1)
        self.down4 = conv_3d(256, 512, 3, 2, 1)

        # Depth that survives the four stages above. A kernel-3 / stride-2 /
        # padding-1 convolution maps a length D to ceil(D / 2), so the value
        # is obtained by applying that step four times. The earlier closed
        # form int(in_channel/16)+1 gives the same result for e.g. 76 input
        # channels but is off by one whenever in_channel is a multiple of 16.
        depth = int(in_channel)
        for _ in range(4):
            depth = (depth + 1) // 2

        self.conver2d = nn.Sequential(
            nn.Conv2d(depth, 1, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Insert a singleton channel axis so the original channels become depth.
        x = x.view(x.size(0),1,x.size(1),x.size(2),x.size(3))
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)
        x = self.down4(x)

        # Fold batch and feature-map axes together so that every feature map
        # of every sample is reduced over depth by the same 2D convolution.
        # For a batch of one this is the same layout as before; larger
        # batches no longer fail in the reshape.
        batch, channels, depth, height, width = x.size()
        x = x.reshape(batch*channels, depth, height, width)
        x = self.conver2d(x)
        x = x.view(batch, channels, x.size(2), x.size(3))
        return x
class encoder_2d(nn.Module):
    """2D encoder: an input stage followed by four ``down`` stages.

    ``forward`` returns the output of every stage so that a decoder can use
    them as skip connections.
    """

    def __init__(self, in_channel):
        super().__init__()
        self.inc = inconv(in_channel, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)

    def forward(self, x):
        """Return the five stage outputs, from first (``inc``) to last (``down4``)."""
        features = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4):
            features.append(stage(features[-1]))
        return tuple(features)
class decoder_2d(nn.Module):
    """2D decoder: four ``up`` stages with skip inputs, then an output convolution."""

    def __init__(self, out_channel):
        super().__init__()
        self.up1 = up(1024, 256, bilinear=False)
        self.up2 = up(512, 128, bilinear=False)
        self.up3 = up(256, 64, bilinear=False)
        self.up4 = up(128, 64, bilinear=False)
        self.outc = outconv(64, out_channel)

    def forward(self,x5,x4,x3,x2,x1):
        """Start from ``x5`` and merge ``x4``..``x1`` in turn, then apply ``outc``."""
        stages = (self.up1, self.up2, self.up3, self.up4)
        skips = (x4, x3, x2, x1)
        merged = x5
        for upsample, skip in zip(stages, skips):
            merged = upsample(merged, skip)
        return self.outc(merged)
class HypoNet(nn.Module):
    """Network combining a 2D encoder/decoder with a 3D encoder branch.

    The 3D branch sees the whole input stack. The 2D branch sees a
    four-channel tensor built from the middle frame plus one extra plane.
    The 3D features are added to the deepest 2D features before decoding.

    Args:
        in_channel: channel count of the input stack (given to the 3D encoder).
        out_channel: channel count of the decoded output.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.encoder_2d = encoder_2d(4)
        self.encoder_3d = encoder_3d(in_channel)
        self.decoder_2d = decoder_2d(out_channel)

    def forward(self, x):
        # Number of three-channel frames in the stack.
        N = int(x.size(1) / 3)
        middle = int((N - 1) / 2)
        centre_frame = x[:, middle * 3:(middle + 1) * 3, :, :]
        # NOTE(review): this takes channel N-1 as the fourth 2D input. If the
        # caller stores the mask in the last channel (index 3*N), this slice
        # does not select it. Confirm the intent before changing it, since
        # trained weights depend on the current behaviour.
        extra_plane = x[:, N - 1:N, :, :]
        stacked_2d = torch.cat((centre_frame, extra_plane), 1)

        deep_3d = self.encoder_3d(x)
        x1, x2, x3, x4, x5 = self.encoder_2d(stacked_2d)
        # Fuse the two branches at the bottleneck.
        fused = x5 + deep_3d
        return self.decoder_2d(fused, x4, x3, x2, x1)
...@@ -12,37 +12,41 @@ sys.path.append("../..") ...@@ -12,37 +12,41 @@ sys.path.append("../..")
from util import mosaic,util,ffmpeg,filt,data from util import mosaic,util,ffmpeg,filt,data
from util import image_processing as impro from util import image_processing as impro
from cores import Options from cores import Options
from models import pix2pix_model,pix2pixHD_model,video_model,unet_model,loadmodel from models import pix2pix_model,pix2pixHD_model,video_model,unet_model,loadmodel,videoHD_model
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
N = 25 opt = Options()
ITER = 10000000 opt.parser.add_argument('--N',type=int,default=25, help='')
LR = 0.0002 opt.parser.add_argument('--lr',type=float,default=0.0002, help='')
beta1 = 0.5 opt.parser.add_argument('--beta1',type=float,default=0.5, help='')
use_gpu = True opt.parser.add_argument('--gan', action='store_true', help='if input it, use gan')
use_gan = False opt.parser.add_argument('--l2', action='store_true', help='if input it, use L2 loss')
use_L2 = False opt.parser.add_argument('--lambda_L1',type=float,default=100, help='')
CONTINUE = True opt.parser.add_argument('--lambda_gan',type=float,default=1, help='')
lambda_L1 = 100.0 opt.parser.add_argument('--finesize',type=int,default=256, help='')
lambda_gan = 0.5 opt.parser.add_argument('--loadsize',type=int,default=286, help='')
opt.parser.add_argument('--batchsize',type=int,default=1, help='')
opt.parser.add_argument('--perload_num',type=int,default=16, help='')
opt.parser.add_argument('--norm',type=str,default='instance', help='')
SAVE_FRE = 10000 opt.parser.add_argument('--maxiter',type=int,default=10000000, help='')
start_iter = 0 opt.parser.add_argument('--savefreq',type=int,default=10000, help='')
finesize = 256 opt.parser.add_argument('--startiter',type=int,default=0, help='')
loadsize = int(finesize*1.2) opt.parser.add_argument('--continuetrain', action='store_true', help='')
batchsize = 6 opt.parser.add_argument('--savename',type=str,default='MosaicNet', help='')
perload_num = 16
# savename = 'MosaicNet_instance_gan_256_hdD' opt = opt.getparse()
savename = 'MosaicNet_instance_test' dir_checkpoint = os.path.join('checkpoints/',opt.savename)
dir_checkpoint = 'checkpoints/'+savename
util.makedirs(dir_checkpoint) util.makedirs(dir_checkpoint)
util.writelog(os.path.join(dir_checkpoint,'loss.txt'),
str(time.asctime(time.localtime(time.time())))+'\n'+util.opt2str(opt))
N = opt.N
loss_sum = [0.,0.,0.,0.] loss_sum = [0.,0.,0.,0.]
loss_plot = [[],[]] loss_plot = [[],[]]
item_plot = [] item_plot = []
opt = Options().getparse()
videos = os.listdir('./dataset') videos = os.listdir('./dataset')
videos.sort() videos.sort()
lengths = [] lengths = []
...@@ -53,39 +57,39 @@ for video in videos: ...@@ -53,39 +57,39 @@ for video in videos:
#unet_128 #unet_128
#resnet_9blocks #resnet_9blocks
#netG = pix2pix_model.define_G(3*N+1, 3, 128, 'resnet_6blocks', norm='instance',use_dropout=True, init_type='normal', gpu_ids=[]) #netG = pix2pix_model.define_G(3*N+1, 3, 128, 'resnet_6blocks', norm='instance',use_dropout=True, init_type='normal', gpu_ids=[])
netG = video_model.MosaicNet(3*N+1, 3, norm='instance') netG = videoHD_model.MosaicNet(3*N+1, 3, norm=opt.norm)
loadmodel.show_paramsnumber(netG,'netG') loadmodel.show_paramsnumber(netG,'netG')
# netG = unet_model.UNet(3*N+1, 3) # netG = unet_model.UNet(3*N+1, 3)
if use_gan: if opt.gan:
netD = pix2pixHD_model.define_D(6, 64, 3, norm='instance', use_sigmoid=False, num_D=2) netD = pix2pixHD_model.define_D(6, 64, 3, norm=opt.norm, use_sigmoid=False, num_D=2)
#netD = pix2pix_model.define_D(3*2+1, 64, 'pixel', norm='instance') #netD = pix2pix_model.define_D(3*2+1, 64, 'pixel', norm='instance')
#netD = pix2pix_model.define_D(3*2, 64, 'basic', norm='instance') #netD = pix2pix_model.define_D(3*2, 64, 'basic', norm='instance')
#netD = pix2pix_model.define_D(3*2+1, 64, 'n_layers', n_layers_D=5, norm='instance') #netD = pix2pix_model.define_D(3*2+1, 64, 'n_layers', n_layers_D=5, norm='instance')
if CONTINUE: if opt.continuetrain:
if not os.path.isfile(os.path.join(dir_checkpoint,'last_G.pth')): if not os.path.isfile(os.path.join(dir_checkpoint,'last_G.pth')):
CONTINUE = False opt.continuetrain = False
print('can not load last_G, training on init weight.') print('can not load last_G, training on init weight.')
if CONTINUE: if opt.continuetrain:
netG.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_G.pth'))) netG.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_G.pth')))
if use_gan: if opt.gan:
netD.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_D.pth'))) netD.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_D.pth')))
f = open(os.path.join(dir_checkpoint,'iter'),'r') f = open(os.path.join(dir_checkpoint,'iter'),'r')
start_iter = int(f.read()) opt.startiter = int(f.read())
f.close() f.close()
optimizer_G = torch.optim.Adam(netG.parameters(), lr=LR,betas=(beta1, 0.999)) optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.lr,betas=(opt.beta1, 0.999))
criterion_L1 = nn.L1Loss() criterion_L1 = nn.L1Loss()
criterion_L2 = nn.MSELoss() criterion_L2 = nn.MSELoss()
if use_gan: if opt.gan:
optimizer_D = torch.optim.Adam(netG.parameters(), lr=LR,betas=(beta1, 0.999)) optimizer_D = torch.optim.Adam(netG.parameters(), lr=opt.lr,betas=(opt.beta1, 0.999))
# criterionGAN = pix2pix_model.GANLoss(gan_mode='lsgan').cuda() # criterionGAN = pix2pix_model.GANLoss(gan_mode='lsgan').cuda()
criterionGAN = pix2pixHD_model.GANLoss(tensor=torch.cuda.FloatTensor) criterionGAN = pix2pixHD_model.GANLoss(tensor=torch.cuda.FloatTensor)
netD.train() netD.train()
if use_gpu: if opt.use_gpu:
netG.cuda() netG.cuda()
if use_gan: if opt.gan:
netD.cuda() netD.cuda()
criterionGAN.cuda() criterionGAN.cuda()
cudnn.benchmark = True cudnn.benchmark = True
...@@ -93,22 +97,22 @@ if use_gpu: ...@@ -93,22 +97,22 @@ if use_gpu:
def loaddata(): def loaddata():
video_index = random.randint(0,len(videos)-1) video_index = random.randint(0,len(videos)-1)
video = videos[video_index] video = videos[video_index]
img_index = random.randint(N,lengths[video_index]- N) img_index = random.randint(int(N/2)+1,lengths[video_index]- int(N/2)-1)
input_img = np.zeros((loadsize,loadsize,3*N+1), dtype='uint8') input_img = np.zeros((opt.loadsize,opt.loadsize,3*N+1), dtype='uint8')
for i in range(0,N): for i in range(0,N):
img = cv2.imread('./dataset/'+video+'/mosaic/output_'+'%05d'%(img_index+i-int(N/2))+'.png') img = cv2.imread('./dataset/'+video+'/mosaic/output_'+'%05d'%(img_index+i-int(N/2))+'.png')
img = impro.resize(img,loadsize) img = impro.resize(img,opt.loadsize)
input_img[:,:,i*3:(i+1)*3] = img input_img[:,:,i*3:(i+1)*3] = img
mask = cv2.imread('./dataset/'+video+'/mask/output_'+'%05d'%(img_index)+'.png',0) mask = cv2.imread('./dataset/'+video+'/mask/output_'+'%05d'%(img_index)+'.png',0)
mask = impro.resize(mask,loadsize) mask = impro.resize(mask,opt.loadsize)
mask = impro.mask_threshold(mask,15,128) mask = impro.mask_threshold(mask,15,128)
input_img[:,:,-1] = mask input_img[:,:,-1] = mask
ground_true = cv2.imread('./dataset/'+video+'/ori/output_'+'%05d'%(img_index)+'.png') ground_true = cv2.imread('./dataset/'+video+'/ori/output_'+'%05d'%(img_index)+'.png')
ground_true = impro.resize(ground_true,loadsize) ground_true = impro.resize(ground_true,opt.loadsize)
input_img,ground_true = data.random_transform_video(input_img,ground_true,finesize,N) input_img,ground_true = data.random_transform_video(input_img,ground_true,opt.finesize,N)
input_img = data.im2tensor(input_img,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1=False) input_img = data.im2tensor(input_img,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1=False)
ground_true = data.im2tensor(ground_true,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1=False) ground_true = data.im2tensor(ground_true,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1=False)
...@@ -116,17 +120,17 @@ def loaddata(): ...@@ -116,17 +120,17 @@ def loaddata():
print('preloading data, please wait 5s...') print('preloading data, please wait 5s...')
if perload_num <= batchsize: if opt.perload_num <= opt.batchsize:
perload_num = batchsize*2 opt.perload_num = opt.batchsize*2
input_imgs = torch.rand(perload_num,N*3+1,finesize,finesize).cuda() input_imgs = torch.rand(opt.perload_num,N*3+1,opt.finesize,opt.finesize).cuda()
ground_trues = torch.rand(perload_num,3,finesize,finesize).cuda() ground_trues = torch.rand(opt.perload_num,3,opt.finesize,opt.finesize).cuda()
load_cnt = 0 load_cnt = 0
def preload(): def preload():
global load_cnt global load_cnt
while 1: while 1:
try: try:
ran = random.randint(0, perload_num-1) ran = random.randint(0, opt.perload_num-1)
input_imgs[ran],ground_trues[ran] = loaddata() input_imgs[ran],ground_trues[ran] = loaddata()
load_cnt += 1 load_cnt += 1
# time.sleep(0.1) # time.sleep(0.1)
...@@ -139,24 +143,24 @@ t.daemon = True ...@@ -139,24 +143,24 @@ t.daemon = True
t.start() t.start()
time_start=time.time() time_start=time.time()
while load_cnt < perload_num: while load_cnt < opt.perload_num:
time.sleep(0.1) time.sleep(0.1)
time_end=time.time() time_end=time.time()
print('load speed:',round((time_end-time_start)/perload_num,3),'s/it') print('load speed:',round((time_end-time_start)/opt.perload_num,3),'s/it')
util.copyfile('./train.py', os.path.join(dir_checkpoint,'train.py')) util.copyfile('./train.py', os.path.join(dir_checkpoint,'train.py'))
util.copyfile('../../models/video_model.py', os.path.join(dir_checkpoint,'model.py')) util.copyfile('../../models/videoHD_model.py', os.path.join(dir_checkpoint,'model.py'))
netG.train() netG.train()
time_start=time.time() time_start=time.time()
print("Begin training...") print("Begin training...")
for iter in range(start_iter+1,ITER): for iter in range(opt.startiter+1,opt.maxiter):
ran = random.randint(0, perload_num-batchsize-1) ran = random.randint(0, opt.perload_num-opt.batchsize-1)
inputdata = input_imgs[ran:ran+batchsize].clone() inputdata = input_imgs[ran:ran+opt.batchsize].clone()
target = ground_trues[ran:ran+batchsize].clone() target = ground_trues[ran:ran+opt.batchsize].clone()
if use_gan: if opt.gan:
# compute fake images: G(A) # compute fake images: G(A)
pred = netG(inputdata) pred = netG(inputdata)
# update D # update D
...@@ -186,12 +190,12 @@ for iter in range(start_iter+1,ITER): ...@@ -186,12 +190,12 @@ for iter in range(start_iter+1,ITER):
real_A = inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:] real_A = inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:]
fake_AB = torch.cat((real_A, pred), 1) fake_AB = torch.cat((real_A, pred), 1)
pred_fake = netD(fake_AB) pred_fake = netD(fake_AB)
loss_G_GAN = criterionGAN(pred_fake, True)*lambda_gan loss_G_GAN = criterionGAN(pred_fake, True)*opt.lambda_gan
# Second, G(A) = B # Second, G(A) = B
if use_L2: if opt.l2:
loss_G_L1 = (criterion_L1(pred, target)+criterion_L2(pred, target)) * lambda_L1 loss_G_L1 = (criterion_L1(pred, target)+criterion_L2(pred, target)) * opt.lambda_L1
else: else:
loss_G_L1 = criterion_L1(pred, target) * lambda_L1 loss_G_L1 = criterion_L1(pred, target) * opt.lambda_L1
# combine loss and calculate gradients # combine loss and calculate gradients
loss_G = loss_G_GAN + loss_G_L1 loss_G = loss_G_GAN + loss_G_L1
loss_sum[0] += loss_G_L1.item() loss_sum[0] += loss_G_L1.item()
...@@ -202,10 +206,10 @@ for iter in range(start_iter+1,ITER): ...@@ -202,10 +206,10 @@ for iter in range(start_iter+1,ITER):
else: else:
pred = netG(inputdata) pred = netG(inputdata)
if use_L2: if opt.l2:
loss_G_L1 = (criterion_L1(pred, target)+criterion_L2(pred, target)) * lambda_L1 loss_G_L1 = (criterion_L1(pred, target)+criterion_L2(pred, target)) * opt.lambda_L1
else: else:
loss_G_L1 = criterion_L1(pred, target) * lambda_L1 loss_G_L1 = criterion_L1(pred, target) * opt.lambda_L1
loss_sum[0] += loss_G_L1.item() loss_sum[0] += loss_G_L1.item()
optimizer_G.zero_grad() optimizer_G.zero_grad()
...@@ -215,15 +219,16 @@ for iter in range(start_iter+1,ITER): ...@@ -215,15 +219,16 @@ for iter in range(start_iter+1,ITER):
if (iter+1)%100 == 0: if (iter+1)%100 == 0:
try: try:
data.showresult(inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], data.showresult(inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:],
target, pred,os.path.join(dir_checkpoint,'result_train.png')) target, pred,os.path.join(dir_checkpoint,'result_train.jpg'))
except Exception as e: except Exception as e:
print(e) print(e)
if (iter+1)%1000 == 0: if (iter+1)%1000 == 0:
time_end = time.time() time_end = time.time()
if use_gan: if opt.gan:
print('iter:',iter+1,' L1_loss:', round(loss_sum[0]/1000,4),' G_loss:', round(loss_sum[1]/1000,4), savestr ='iter:{0:d} L1_loss:{1:.4f} G_loss:{2:.4f} D_f:{3:.4f} D_r:{4:.4f} time:{5:.2f}'.format(
' D_f:',round(loss_sum[2]/1000,4),' D_r:',round(loss_sum[3]/1000,4),' time:',round((time_end-time_start)/1000,2)) iter+1,loss_sum[0]/1000,loss_sum[1]/1000,loss_sum[2]/1000,loss_sum[3]/1000,(time_end-time_start)/1000)
util.writelog(os.path.join(dir_checkpoint,'loss.txt'), savestr,True)
if (iter+1)/1000 >= 10: if (iter+1)/1000 >= 10:
loss_plot[0].append(loss_sum[0]/1000) loss_plot[0].append(loss_sum[0]/1000)
loss_plot[1].append(loss_sum[1]/1000) loss_plot[1].append(loss_sum[1]/1000)
...@@ -231,18 +236,19 @@ for iter in range(start_iter+1,ITER): ...@@ -231,18 +236,19 @@ for iter in range(start_iter+1,ITER):
try: try:
plt.plot(item_plot,loss_plot[0]) plt.plot(item_plot,loss_plot[0])
plt.plot(item_plot,loss_plot[1]) plt.plot(item_plot,loss_plot[1])
plt.savefig(os.path.join(dir_checkpoint,'loss.png')) plt.savefig(os.path.join(dir_checkpoint,'loss.jpg'))
plt.close() plt.close()
except Exception as e: except Exception as e:
print("error:",e) print("error:",e)
else: else:
print('iter:',iter+1,' L1_loss:',round(loss_sum[0]/1000,4),' time:',round((time_end-time_start)/1000,2)) savestr ='iter:{0:d} L1_loss:{1:.4f} time:{2:.2f}'.format(iter+1,loss_sum[0]/1000,(time_end-time_start)/1000)
util.writelog(os.path.join(dir_checkpoint,'loss.txt'), savestr,True)
if (iter+1)/1000 >= 10: if (iter+1)/1000 >= 10:
loss_plot[0].append(loss_sum[0]/1000) loss_plot[0].append(loss_sum[0]/1000)
item_plot.append(iter+1) item_plot.append(iter+1)
try: try:
plt.plot(item_plot,loss_plot[0]) plt.plot(item_plot,loss_plot[0])
plt.savefig(os.path.join(dir_checkpoint,'loss.png')) plt.savefig(os.path.join(dir_checkpoint,'loss.jpg'))
plt.close() plt.close()
except Exception as e: except Exception as e:
print("error:",e) print("error:",e)
...@@ -250,17 +256,17 @@ for iter in range(start_iter+1,ITER): ...@@ -250,17 +256,17 @@ for iter in range(start_iter+1,ITER):
time_start=time.time() time_start=time.time()
if (iter+1)%SAVE_FRE == 0: if (iter+1)%opt.savefreq == 0:
if iter+1 != SAVE_FRE: if iter+1 != opt.savefreq:
os.rename(os.path.join(dir_checkpoint,'last_G.pth'),os.path.join(dir_checkpoint,str(iter+1-SAVE_FRE)+'G.pth')) os.rename(os.path.join(dir_checkpoint,'last_G.pth'),os.path.join(dir_checkpoint,str(iter+1-opt.savefreq)+'G.pth'))
torch.save(netG.cpu().state_dict(),os.path.join(dir_checkpoint,'last_G.pth')) torch.save(netG.cpu().state_dict(),os.path.join(dir_checkpoint,'last_G.pth'))
if use_gan: if opt.gan:
if iter+1 != SAVE_FRE: if iter+1 != opt.savefreq:
os.rename(os.path.join(dir_checkpoint,'last_D.pth'),os.path.join(dir_checkpoint,str(iter+1-SAVE_FRE)+'D.pth')) os.rename(os.path.join(dir_checkpoint,'last_D.pth'),os.path.join(dir_checkpoint,str(iter+1-opt.savefreq)+'D.pth'))
torch.save(netD.cpu().state_dict(),os.path.join(dir_checkpoint,'last_D.pth')) torch.save(netD.cpu().state_dict(),os.path.join(dir_checkpoint,'last_D.pth'))
if use_gpu: if opt.use_gpu:
netG.cuda() netG.cuda()
if use_gan: if opt.gan:
netD.cuda() netD.cuda()
f = open(os.path.join(dir_checkpoint,'iter'),'w+') f = open(os.path.join(dir_checkpoint,'iter'),'w+')
f.write(str(iter+1)) f.write(str(iter+1))
...@@ -272,27 +278,27 @@ for iter in range(start_iter+1,ITER): ...@@ -272,27 +278,27 @@ for iter in range(start_iter+1,ITER):
test_names = os.listdir('./test') test_names = os.listdir('./test')
test_names.sort() test_names.sort()
result = np.zeros((finesize*2,finesize*len(test_names),3), dtype='uint8') result = np.zeros((opt.finesize*2,opt.finesize*len(test_names),3), dtype='uint8')
for cnt,test_name in enumerate(test_names,0): for cnt,test_name in enumerate(test_names,0):
img_names = os.listdir(os.path.join('./test',test_name,'image')) img_names = os.listdir(os.path.join('./test',test_name,'image'))
img_names.sort() img_names.sort()
inputdata = np.zeros((finesize,finesize,3*N+1), dtype='uint8') inputdata = np.zeros((opt.finesize,opt.finesize,3*N+1), dtype='uint8')
for i in range(0,N): for i in range(0,N):
img = impro.imread(os.path.join('./test',test_name,'image',img_names[i])) img = impro.imread(os.path.join('./test',test_name,'image',img_names[i]))
img = impro.resize(img,finesize) img = impro.resize(img,opt.finesize)
inputdata[:,:,i*3:(i+1)*3] = img inputdata[:,:,i*3:(i+1)*3] = img
mask = impro.imread(os.path.join('./test',test_name,'mask.png'),'gray') mask = impro.imread(os.path.join('./test',test_name,'mask.png'),'gray')
mask = impro.resize(mask,finesize) mask = impro.resize(mask,opt.finesize)
mask = impro.mask_threshold(mask,15,128) mask = impro.mask_threshold(mask,15,128)
inputdata[:,:,-1] = mask inputdata[:,:,-1] = mask
result[0:finesize,finesize*cnt:finesize*(cnt+1),:] = inputdata[:,:,int((N-1)/2)*3:(int((N-1)/2)+1)*3] result[0:opt.finesize,opt.finesize*cnt:opt.finesize*(cnt+1),:] = inputdata[:,:,int((N-1)/2)*3:(int((N-1)/2)+1)*3]
inputdata = data.im2tensor(inputdata,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1 = False) inputdata = data.im2tensor(inputdata,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1 = False)
pred = netG(inputdata) pred = netG(inputdata)
pred = data.tensor2im(pred,rgb2bgr = False, is0_1 = False) pred = data.tensor2im(pred,rgb2bgr = False, is0_1 = False)
result[finesize:finesize*2,finesize*cnt:finesize*(cnt+1),:] = pred result[opt.finesize:opt.finesize*2,opt.finesize*cnt:opt.finesize*(cnt+1),:] = pred
cv2.imwrite(os.path.join(dir_checkpoint,str(iter+1)+'_test.png'), result) cv2.imwrite(os.path.join(dir_checkpoint,str(iter+1)+'_test.jpg'), result)
netG.train() netG.train()
...@@ -35,16 +35,17 @@ def is_video(path): ...@@ -35,16 +35,17 @@ def is_video(path):
else: else:
return False return False
file_list,dir_list = Traversal('./') def cleanall():
for file in file_list: file_list,dir_list = Traversal('./')
if ('tmp' in file) | ('pth' in file)|('pycache' in file) | is_video(file) | is_img(file): for file in file_list:
if os.path.exists(file): if ('tmp' in file) | ('pth' in file)|('pycache' in file) | is_video(file) | is_img(file):
if 'imgs' not in file: if os.path.exists(file):
os.remove(file) if 'imgs' not in file:
print('remove file:',file) os.remove(file)
print('remove file:',file)
for dir in dir_list: for dir in dir_list:
if ('tmp'in dir)|('pycache'in dir): if ('tmp'in dir)|('pycache'in dir):
if os.path.exists(dir): if os.path.exists(dir):
shutil.rmtree(dir) shutil.rmtree(dir)
print('remove dir:',dir) print('remove dir:',dir)
\ No newline at end of file \ No newline at end of file
...@@ -14,17 +14,21 @@ transform = transforms.Compose([ ...@@ -14,17 +14,21 @@ transform = transforms.Compose([
def tensor2im(image_tensor, imtype=np.uint8, gray=False, rgb2bgr = True ,is0_1 = False): def tensor2im(image_tensor, imtype=np.uint8, gray=False, rgb2bgr = True ,is0_1 = False):
image_tensor =image_tensor.data image_tensor =image_tensor.data
image_numpy = image_tensor[0].cpu().float().numpy() image_numpy = image_tensor[0].cpu().float().numpy()
# if gray:
# image_numpy = (image_numpy+1.0)/2.0 * 255.0
# else:
if image_numpy.shape[0] == 1:
image_numpy = np.tile(image_numpy, (3, 1, 1))
image_numpy = image_numpy.transpose((1, 2, 0))
if not is0_1: if not is0_1:
image_numpy = (image_numpy + 1)/2.0 image_numpy = (image_numpy + 1)/2.0
image_numpy = np.clip(image_numpy * 255.0,0,255) image_numpy = np.clip(image_numpy * 255.0,0,255)
# gray -> output 1ch
if gray:
h, w = image_numpy.shape[1:]
image_numpy = image_numpy.reshape(h,w)
return image_numpy.astype(imtype)
# output 3ch
if image_numpy.shape[0] == 1:
image_numpy = np.tile(image_numpy, (3, 1, 1))
image_numpy = image_numpy.transpose((1, 2, 0))
if rgb2bgr and not gray: if rgb2bgr and not gray:
image_numpy = image_numpy[...,::-1]-np.zeros_like(image_numpy) image_numpy = image_numpy[...,::-1]-np.zeros_like(image_numpy)
return image_numpy.astype(imtype) return image_numpy.astype(imtype)
......
...@@ -2,20 +2,28 @@ import os,json ...@@ -2,20 +2,28 @@ import os,json
# ffmpeg 3.4.6 # ffmpeg 3.4.6
def video2image(videopath,imagepath): def video2image(videopath,imagepath,fps=0):
os.system('ffmpeg -i "'+videopath+'" -f image2 '+imagepath) if fps == 0:
os.system('ffmpeg -i "'+videopath+'" -f image2 '+imagepath)
else:
os.system('ffmpeg -i "'+videopath+'" -r '+str(fps)+' -f image2 '+imagepath)
def video2voice(videopath,voicepath): def video2voice(videopath,voicepath):
os.system('ffmpeg -i "'+videopath+'" -f mp3 '+voicepath) os.system('ffmpeg -i "'+videopath+'" -f mp3 '+voicepath)
def image2video(fps,imagepath,voicepath,videopath): def image2video(fps,imagepath,voicepath,videopath):
os.system('ffmpeg -y -r '+str(fps)+' -i '+imagepath+' -vcodec libx264 -b 12M '+'./tmp/video_tmp.mp4') os.system('ffmpeg -y -r '+str(fps)+' -i '+imagepath+' -vcodec libx264 '+'./tmp/video_tmp.mp4')
#os.system('ffmpeg -f image2 -i '+imagepath+' -vcodec libx264 -r '+str(fps)+' ./tmp/video_tmp.mp4') #os.system('ffmpeg -f image2 -i '+imagepath+' -vcodec libx264 -r '+str(fps)+' ./tmp/video_tmp.mp4')
os.system('ffmpeg -i ./tmp/video_tmp.mp4 -i "'+voicepath+'" -vcodec copy -acodec copy '+videopath) os.system('ffmpeg -i ./tmp/video_tmp.mp4 -i "'+voicepath+'" -vcodec copy -acodec copy '+videopath)
def get_video_infos(videopath): def get_video_infos(videopath):
cmd_str = 'ffprobe -v quiet -print_format json -show_format -show_streams -i "' + videopath + '"' cmd_str = 'ffprobe -v quiet -print_format json -show_format -show_streams -i "' + videopath + '"'
out_string = os.popen(cmd_str).read() #out_string = os.popen(cmd_str).read()
#For chinese path in Windows
#https://blog.csdn.net/weixin_43903378/article/details/91979025
stream = os.popen(cmd_str)._stream
out_string = stream.buffer.read().decode(encoding='utf-8')
infos = json.loads(out_string) infos = json.loads(out_string)
try: try:
fps = eval(infos['streams'][0]['avg_frame_rate']) fps = eval(infos['streams'][0]['avg_frame_rate'])
...@@ -28,7 +36,7 @@ def get_video_infos(videopath): ...@@ -28,7 +36,7 @@ def get_video_infos(videopath):
width = int(infos['streams'][1]['width']) width = int(infos['streams'][1]['width'])
height = int(infos['streams'][1]['height']) height = int(infos['streams'][1]['height'])
return fps,endtime,width,height return fps,endtime,height,width
def cut_video(in_path,start_time,last_time,out_path,vcodec='h265'): def cut_video(in_path,start_time,last_time,out_path,vcodec='h265'):
if vcodec == 'copy': if vcodec == 'copy':
......
...@@ -19,7 +19,7 @@ def imread(file_path,mod = 'normal'): ...@@ -19,7 +19,7 @@ def imread(file_path,mod = 'normal'):
elif mod == 'all': elif mod == 'all':
img = cv2.imread(file_path,-1) img = cv2.imread(file_path,-1)
#For chinese path, use cv2.imdecode in windows. #In windows, for chinese path, use cv2.imdecode insteaded.
#It will loss EXIF, I can't fix it #It will loss EXIF, I can't fix it
else: else:
if mod == 'gray': if mod == 'gray':
...@@ -133,7 +133,7 @@ def mergeimage(img1,img2,orgin_image,size = 128): ...@@ -133,7 +133,7 @@ def mergeimage(img1,img2,orgin_image,size = 128):
result_img = cv2.add(new_img1,new_img2) result_img = cv2.add(new_img1,new_img2)
return result_img return result_img
def find_best_ROI(mask): def find_mostlikely_ROI(mask):
contours,hierarchy=cv2.findContours(mask, cv2.RETR_LIST,cv2.CHAIN_APPROX_SIMPLE) contours,hierarchy=cv2.findContours(mask, cv2.RETR_LIST,cv2.CHAIN_APPROX_SIMPLE)
if len(contours)>0: if len(contours)>0:
areas = [] areas = []
...@@ -182,9 +182,9 @@ def boundingSquare(mask,Ex_mul): ...@@ -182,9 +182,9 @@ def boundingSquare(mask,Ex_mul):
center = ((point0+point1)/2).astype('int') center = ((point0+point1)/2).astype('int')
return center[0],center[1],halfsize,area return center[0],center[1],halfsize,area
def mask_threshold(mask,blur,threshold): def mask_threshold(mask,ex_mun,threshold):
mask = cv2.threshold(mask,threshold,255,cv2.THRESH_BINARY)[1] mask = cv2.threshold(mask,threshold,255,cv2.THRESH_BINARY)[1]
mask = cv2.blur(mask, (blur, blur)) mask = cv2.blur(mask, (ex_mun, ex_mun))
mask = cv2.threshold(mask,threshold/5,255,cv2.THRESH_BINARY)[1] mask = cv2.threshold(mask,threshold/5,255,cv2.THRESH_BINARY)[1]
return mask return mask
...@@ -200,7 +200,7 @@ def mask_area(mask): ...@@ -200,7 +200,7 @@ def mask_area(mask):
def replace_mosaic(img_origin,img_fake,x,y,size,no_father): def replace_mosaic(img_origin,img_fake,x,y,size,no_father):
img_fake = resize(img_fake,size*2) img_fake = resize(img_fake,size*2,interpolation=cv2.INTER_LANCZOS4)
if no_father: if no_father:
img_origin[y-size:y+size,x-size:x+size]=img_fake img_origin[y-size:y+size,x-size:x+size]=img_fake
img_result = img_origin img_result = img_origin
......
...@@ -40,10 +40,12 @@ def is_videos(paths): ...@@ -40,10 +40,12 @@ def is_videos(paths):
tmp.append(path) tmp.append(path)
return tmp return tmp
def writelog(path,log): def writelog(path,log,isprint=False):
f = open(path,'a+') f = open(path,'a+')
f.write(log+'\n') f.write(log+'\n')
f.close() f.close()
if isprint:
print(log)
def makedirs(path): def makedirs(path):
if os.path.isdir(path): if os.path.isdir(path):
...@@ -87,3 +89,11 @@ def copyfile(src,dst): ...@@ -87,3 +89,11 @@ def copyfile(src,dst):
shutil.copyfile(src, dst) shutil.copyfile(src, dst)
except Exception as e: except Exception as e:
print(e) print(e)
def opt2str(opt):
message = ''
message += '---------------------- Options --------------------\n'
for k, v in sorted(vars(opt).items()):
message += '{:>25}: {:<35}\n'.format(str(k), str(v))
message += '----------------- End -------------------'
return message
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册