loadmodel.py 742 字节
Newer Older
HypoX64's avatar
preview  
HypoX64 已提交
1 2 3 4
import torch
from .pix2pix_model import *
from .unet_model import UNet

H
hypox64 已提交
5 6 7 8 9
def pix2pix(opt):
    """Build a pix2pix generator and load its weights from ``opt.model_path``.

    Args:
        opt: options object. Reads ``opt.model_path`` (path of a saved state
            dict), ``opt.netG`` (generator architecture name forwarded to
            ``define_G``) and ``opt.use_gpu`` (move the model to CUDA when
            truthy).

    Returns:
        The generator in eval mode; on the GPU when ``opt.use_gpu`` is set,
        otherwise on the CPU.
    """
    print(opt.model_path, opt.netG)
    # Positional args presumably follow pix2pix's
    # define_G(input_nc, output_nc, ngf, netG, ...) signature — TODO confirm.
    # gpu_ids=[] presumably keeps define_G from placing the net on a GPU;
    # device placement is handled explicitly below.
    netG = define_G(3, 3, 64, opt.netG, norm='batch', use_dropout=True,
                    init_type='normal', gpu_ids=[])
    # map_location='cpu' lets a checkpoint saved from a CUDA session load on a
    # CPU-only machine. load_state_dict copies the tensors into the model's own
    # parameters, so the final device is decided solely by .cuda() below.
    netG.load_state_dict(torch.load(opt.model_path, map_location='cpu'))
    # The net was built with batch norm and dropout; eval() makes inference
    # deterministic (running BN statistics, dropout disabled).
    netG.eval()
    if opt.use_gpu:
        netG.cuda()
    return netG

H
hypox64 已提交
15 16 17 18 19 20 21
def unet_clean(opt):
    """Build a UNet and load the weights at ``opt.mosaic_position_model_path``.

    The network is constructed with ``n_channels=3`` and ``n_classes=1``.

    Args:
        opt: options object. Reads ``opt.mosaic_position_model_path`` (path of
            a saved state dict) and ``opt.use_gpu`` (move the model to CUDA
            when truthy).

    Returns:
        The UNet in eval mode; on the GPU when ``opt.use_gpu`` is set,
        otherwise on the CPU.
    """
    net = UNet(n_channels=3, n_classes=1)
    # map_location='cpu' lets a checkpoint saved from a CUDA session load on a
    # CPU-only machine; the final device is decided solely by .cuda() below.
    net.load_state_dict(
        torch.load(opt.mosaic_position_model_path, map_location='cpu'))
    net.eval()
    if opt.use_gpu:
        net.cuda()
    return net
H
hypox64 已提交
22

H
hypox64 已提交
23
def unet(opt):
    """Build a UNet and load its weights from ``opt.model_path``.

    The network is constructed with ``n_channels=3`` and ``n_classes=1``.

    Args:
        opt: options object. Reads ``opt.model_path`` (path of a saved state
            dict) and ``opt.use_gpu`` (move the model to CUDA when truthy).

    Returns:
        The UNet in eval mode; on the GPU when ``opt.use_gpu`` is set,
        otherwise on the CPU.
    """
    net = UNet(n_channels=3, n_classes=1)
    # map_location='cpu' lets a checkpoint saved from a CUDA session load on a
    # CPU-only machine; the final device is decided solely by .cuda() below.
    net.load_state_dict(torch.load(opt.model_path, map_location='cpu'))
    net.eval()
    if opt.use_gpu:
        net.cuda()
    return net