loadmodel.py 1.4 KB
Newer Older
HypoX64's avatar
preview  
HypoX64 已提交
1
import torch
H
hypox64 已提交
2 3
from .pix2pix_model import define_G
from .pix2pixHD_model import define_G as define_G_HD
HypoX64's avatar
preview  
HypoX64 已提交
4
from .unet_model import UNet
H
HypoX64 已提交
5 6 7 8 9 10 11
from .video_model import MosaicNet

def show_paramsnumber(net,netname='net'):
    """Print how many parameters ``net`` has, in millions.

    The count is rounded to two decimal places and printed as
    ``'<netname> parameters: <count>M'``.

    Args:
        net: object whose ``parameters()`` yields items with a ``numel()``
            method (e.g. a ``torch.nn.Module``).
        netname: label printed in front of the count.
    """
    total = 0
    for tensor in net.parameters():
        total += tensor.numel()
    millions = round(total / 1e6, 2)
    print(netname + ' parameters: ' + str(millions) + 'M')

HypoX64's avatar
preview  
HypoX64 已提交
12

H
hypox64 已提交
13
def pix2pix(opt):
    """Build the pix2pix / pix2pixHD generator and load its trained weights.

    Args:
        opt: options object. Reads ``opt.netG`` (``'HD'`` selects the
            pix2pixHD generator; any other value is forwarded to
            ``define_G``), ``opt.model_path`` (checkpoint file) and
            ``opt.use_gpu`` (truthy -> move the network to CUDA).

    Returns:
        The generator in eval mode, on the GPU if ``opt.use_gpu`` is truthy.
    """
    if opt.netG == 'HD':
        netG = define_G_HD(3, 3, 64, 'global', 4)
    else:
        netG = define_G(3, 3, 64, opt.netG, norm='batch', use_dropout=True, init_type='normal', gpu_ids=[])
    show_paramsnumber(netG, 'netG')
    # Deserialize onto the CPU first. Without map_location, a checkpoint
    # saved from a CUDA model cannot be loaded on a machine without CUDA,
    # even when opt.use_gpu is off. The weights are moved to the GPU below
    # if requested.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    netG.load_state_dict(torch.load(opt.model_path, map_location='cpu'))
    # Switch to inference behaviour (the non-HD generator is built with
    # batch norm and dropout, which behave differently in training mode).
    netG.eval()
    if opt.use_gpu:
        netG.cuda()
    return netG

H
hypox64 已提交
26
def video(opt):
    """Build the ``MosaicNet`` video model and load its trained weights.

    Args:
        opt: options object. Reads ``opt.model_path`` (checkpoint file) and
            ``opt.use_gpu`` (truthy -> move the network to CUDA).

    Returns:
        The network in eval mode, on the GPU if ``opt.use_gpu`` is truthy.
    """
    # 3*25+1 input channels, 3 output channels. Presumably 25 stacked
    # 3-channel frames plus one extra channel -- TODO confirm against MosaicNet.
    netG = MosaicNet(3*25+1, 3)
    show_paramsnumber(netG, 'netG')
    # Deserialize onto the CPU first so CUDA-saved checkpoints also load on
    # CPU-only machines; the weights are moved to the GPU below if requested.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    netG.load_state_dict(torch.load(opt.model_path, map_location='cpu'))
    netG.eval()
    if opt.use_gpu:
        netG.cuda()
    return netG


H
hypox64 已提交
36 37
def unet_clean(opt):
    """Build a 3-in / 1-out ``UNet`` and load the mosaic-position weights.

    Args:
        opt: options object. Reads ``opt.mosaic_position_model_path``
            (checkpoint file) and ``opt.use_gpu`` (truthy -> move the
            network to CUDA).

    Returns:
        The network in eval mode, on the GPU if ``opt.use_gpu`` is truthy.
    """
    net = UNet(n_channels=3, n_classes=1)
    show_paramsnumber(net, 'segment')
    # Deserialize onto the CPU first so CUDA-saved checkpoints also load on
    # CPU-only machines; the weights are moved to the GPU below if requested.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    net.load_state_dict(torch.load(opt.mosaic_position_model_path, map_location='cpu'))
    net.eval()
    if opt.use_gpu:
        net.cuda()
    return net
H
hypox64 已提交
44

H
hypox64 已提交
45
def unet(opt):
    """Build a 3-in / 1-out ``UNet`` and load weights from ``opt.model_path``.

    Args:
        opt: options object. Reads ``opt.model_path`` (checkpoint file) and
            ``opt.use_gpu`` (truthy -> move the network to CUDA).

    Returns:
        The network in eval mode, on the GPU if ``opt.use_gpu`` is truthy.
    """
    net = UNet(n_channels=3, n_classes=1)
    show_paramsnumber(net, 'segment')
    # Deserialize onto the CPU first so CUDA-saved checkpoints also load on
    # CPU-only machines; the weights are moved to the GPU below if requested.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    net.load_state_dict(torch.load(opt.model_path, map_location='cpu'))
    net.eval()
    if opt.use_gpu:
        net.cuda()
    return net