loadmodel.py 1.0 KB
Newer Older
HypoX64's avatar
preview  
HypoX64 已提交
1
import torch
H
hypox64 已提交
2 3
from .pix2pix_model import define_G
from .pix2pixHD_model import define_G as define_G_HD
HypoX64's avatar
preview  
HypoX64 已提交
4 5
from .unet_model import UNet

H
hypox64 已提交
6
def pix2pix(opt):
    """Build a pix2pix generator, load its weights and switch it to eval mode.

    Reads from ``opt``:
        netG: generator variant. ``'HD'`` builds the pix2pixHD ``'global'``
            generator, ``'video'`` builds a ``'unet_128'`` generator with
            ``3*25+1`` input channels, and any other value is passed to
            ``define_G`` as the architecture name.
        model_path: path of the saved ``state_dict`` to load.
        use_gpu: when truthy, the network is moved to CUDA after loading.

    Returns:
        The generator module, in eval mode.
    """
    if opt.netG == 'HD':
        netG = define_G_HD(3, 3, 64, 'global', 4)
    elif opt.netG == 'video':
        # NOTE(review): 3*25+1 input channels presumably means 25 stacked RGB
        # frames plus one extra channel -- confirm against the training code.
        netG = define_G(3*25+1, 3, 128, 'unet_128', norm='instance',
                        use_dropout=True, init_type='normal', gpu_ids=[])
    else:
        netG = define_G(3, 3, 64, opt.netG, norm='batch',
                        use_dropout=True, init_type='normal', gpu_ids=[])

    # Deserialize onto the CPU so checkpoints saved from a CUDA device can
    # still be opened on a CPU-only machine; .cuda() below moves the model
    # to the GPU when requested.
    netG.load_state_dict(torch.load(opt.model_path, map_location='cpu'))
    netG.eval()
    if opt.use_gpu:
        netG.cuda()
    return netG

H
hypox64 已提交
21 22 23 24 25 26 27
def unet_clean(opt):
    """Build a UNet, load the weights at ``opt.mosaic_position_model_path``
    and switch it to eval mode.

    The network is constructed with ``n_channels=3`` and ``n_classes=1``.
    When ``opt.use_gpu`` is truthy it is moved to CUDA after loading.

    Returns:
        The UNet module, in eval mode.
    """
    net = UNet(n_channels=3, n_classes=1)
    # Deserialize onto the CPU so GPU-saved checkpoints also load on
    # CPU-only machines; .cuda() below moves the model when requested.
    net.load_state_dict(
        torch.load(opt.mosaic_position_model_path, map_location='cpu'))
    net.eval()
    if opt.use_gpu:
        net.cuda()
    return net
H
hypox64 已提交
28

H
hypox64 已提交
29
def unet(opt):
    """Build a UNet, load the weights at ``opt.model_path`` and switch it to
    eval mode.

    The network is constructed with ``n_channels=3`` and ``n_classes=1``.
    When ``opt.use_gpu`` is truthy it is moved to CUDA after loading.

    Returns:
        The UNet module, in eval mode.
    """
    net = UNet(n_channels=3, n_classes=1)
    # Deserialize onto the CPU so GPU-saved checkpoints also load on
    # CPU-only machines; .cuda() below moves the model when requested.
    net.load_state_dict(torch.load(opt.model_path, map_location='cpu'))
    net.eval()
    if opt.use_gpu:
        net.cuda()
    return net