loadmodel.py 3.3 KB
Newer Older
HypoX64's avatar
preview  
HypoX64 已提交
1
import torch
H
hypox64 已提交
2 3
from .pix2pix_model import define_G
from .pix2pixHD_model import define_G as define_G_HD
HypoX64's avatar
preview  
HypoX64 已提交
4
from .unet_model import UNet
H
HypoX64 已提交
5
from .video_model import MosaicNet
H
HypoX64 已提交
6
from .videoHD_model import MosaicNet as MosaicNet_HD
H
HypoX64 已提交
7 8 9 10 11 12

def show_paramsnumber(net,netname='net'):
    """Print the total number of parameters of ``net``.

    The count is shown in millions, rounded to two decimals, e.g.
    ``"netG parameters: 1.0M"``.

    Args:
        net: a module exposing ``parameters()``.
        netname: label printed in front of the count.
    """
    total = 0
    for tensor in net.parameters():
        total += tensor.numel()
    millions = round(total / 1e6, 2)
    print(''.join((netname, ' parameters: ', str(millions), 'M')))

H
hypox64 已提交
13 14 15 16 17 18 19 20 21 22 23 24 25
def __patch_instance_norm_state_dict(state_dict, module, keys, i=0):
    """Fix InstanceNorm checkpoint incompatibility (checkpoints prior to 0.4).

    Resolves the dotted state_dict entry ``keys`` against ``module``. If the
    owning layer is an InstanceNorm layer, the entry is removed from
    ``state_dict`` in place when it is either:

    * ``num_batches_tracked``, or
    * ``running_mean`` / ``running_var`` whose attribute on the live layer
      is ``None``.

    Args:
        state_dict: mapping being patched; entries are popped in place.
        module: the object already reached by ``keys[:i]``.
        keys: the entry's name split on ``'.'``.
        i: index in ``keys`` at which resolution continues.
    """
    last = len(keys) - 1
    # Descend through the intermediate attributes to the layer that owns
    # the final name.
    for attr in keys[i:last]:
        module = getattr(module, attr)
    leaf = keys[max(i, last)]

    if not module.__class__.__name__.startswith('InstanceNorm'):
        return
    if leaf in ('running_mean', 'running_var'):
        if getattr(module, leaf) is None:
            state_dict.pop('.'.join(keys))
    elif leaf == 'num_batches_tracked':
        state_dict.pop('.'.join(keys))
HypoX64's avatar
preview  
HypoX64 已提交
26

H
hypox64 已提交
27
def pix2pix(opt):
    """Build the pix2pix generator selected by ``opt`` and load its weights.

    Args:
        opt: options object. This function reads:
            ``opt.netG``: ``'HD'`` selects the pix2pixHD ``'global'``
                generator; any other value is passed to ``define_G`` as the
                architecture name.
            ``opt.model_path``: file holding the generator's state_dict.
            ``opt.use_gpu``: move the network to CUDA when truthy.

    Returns:
        The generator in eval mode, on CUDA when ``opt.use_gpu`` is set.
    """
    if opt.netG == 'HD':
        netG = define_G_HD(3, 3, 64, 'global', 4)
    else:
        netG = define_G(3, 3, 64, opt.netG, norm='batch', use_dropout=True,
                        init_type='normal', gpu_ids=[])
    show_paramsnumber(netG, 'netG')
    # Deserialize on the CPU: a checkpoint saved from a GPU would otherwise
    # fail to load on a CPU-only machine. The weights are copied into netG
    # and moved to the GPU below when requested.
    netG.load_state_dict(torch.load(opt.model_path, map_location='cpu'))
    netG.eval()
    if opt.use_gpu:
        netG.cuda()
    return netG

H
hypox64 已提交
40

H
hypox64 已提交
41 42
def style(opt):
    """Build the ResNet-9-block style-transfer generator and load its weights.

    A truthy ``opt.edges`` builds a generator with 1 input channel and
    dropout enabled. Otherwise the generator has 3 input channels and no
    dropout. Weights are read from ``opt.model_path`` onto the CPU and
    patched so that InstanceNorm checkpoints from before PyTorch 0.4 still
    load. The network is moved to CUDA when ``opt.use_gpu`` is set.

    NOTE(review): unlike the other loaders in this module, this one never
    calls ``netG.eval()``, so any dropout layers may stay active at
    inference time. Confirm this is intentional.
    """
    in_channels, dropout = (1, True) if opt.edges else (3, False)
    netG = define_G(in_channels, 3, 64, 'resnet_9blocks', norm='instance',
                    use_dropout=dropout, init_type='normal', gpu_ids=[])

    # Compatibility code for loading old pretrained models, adapted from
    # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/models/base_model.py
    if isinstance(netG, torch.nn.DataParallel):
        netG = netG.module

    checkpoint = torch.load(opt.model_path, map_location='cpu')
    if hasattr(checkpoint, '_metadata'):
        del checkpoint._metadata
    # Snapshot the key names first: the patch helper pops entries as it goes.
    for name in list(checkpoint):
        __patch_instance_norm_state_dict(checkpoint, netG, name.split('.'))
    netG.load_state_dict(checkpoint)

    if opt.use_gpu:
        netG.cuda()
    return netG

H
hypox64 已提交
66
def video(opt):
    """Build the video MosaicNet generator and load its weights.

    Args:
        opt: options object. This function reads:
            ``opt.model_path``: file holding the state_dict. A path
                containing ``'HD'`` selects ``MosaicNet_HD`` with instance
                norm; otherwise ``MosaicNet`` with batch norm is used.
            ``opt.use_gpu``: move the network to CUDA when truthy.

    Returns:
        The network in eval mode, on CUDA when ``opt.use_gpu`` is set.
    """
    # 3*25+1 input channels: presumably 25 stacked RGB frames plus one extra
    # channel -- TODO confirm against the model definitions.
    if 'HD' in opt.model_path:
        netG = MosaicNet_HD(3*25+1, 3, norm='instance')
    else:
        netG = MosaicNet(3*25+1, 3, norm='batch')
    show_paramsnumber(netG, 'netG')
    # Deserialize on the CPU so GPU-saved checkpoints also load on CPU-only
    # machines; the network is moved to the GPU below when requested.
    netG.load_state_dict(torch.load(opt.model_path, map_location='cpu'))
    netG.eval()
    if opt.use_gpu:
        netG.cuda()
    return netG


H
hypox64 已提交
79 80
def unet_clean(opt):
    """Build the UNet (3 input channels, 1 output class) and load its weights.

    Weights come from ``opt.mosaic_position_model_path``, not
    ``opt.model_path``. The parameter count is printed under the label
    ``'segment'``.

    Args:
        opt: options object; reads ``opt.mosaic_position_model_path`` and
            ``opt.use_gpu``.

    Returns:
        The network in eval mode, on CUDA when ``opt.use_gpu`` is set.
    """
    net = UNet(n_channels=3, n_classes=1)
    show_paramsnumber(net, 'segment')
    # Deserialize on the CPU so GPU-saved checkpoints also load on CPU-only
    # machines; the network is moved to the GPU below when requested.
    net.load_state_dict(torch.load(opt.mosaic_position_model_path, map_location='cpu'))
    net.eval()
    if opt.use_gpu:
        net.cuda()
    return net
H
hypox64 已提交
87

H
hypox64 已提交
88
def unet(opt):
    """Build the UNet (3 input channels, 1 output class) and load its weights.

    The parameter count is printed under the label ``'segment'``.

    Args:
        opt: options object; reads ``opt.model_path`` (state_dict file) and
            ``opt.use_gpu``.

    Returns:
        The network in eval mode, on CUDA when ``opt.use_gpu`` is set.
    """
    net = UNet(n_channels=3, n_classes=1)
    show_paramsnumber(net, 'segment')
    # Deserialize on the CPU so GPU-saved checkpoints also load on CPU-only
    # machines; the network is moved to the GPU below when requested.
    net.load_state_dict(torch.load(opt.model_path, map_location='cpu'))
    net.eval()
    if opt.use_gpu:
        net.cuda()
    return net