loadmodel.py 2.6 KB
Newer Older
HypoX64's avatar
preview  
HypoX64 已提交
1
import torch
H
hypox64 已提交
2
from . import model_util
3 4 5 6
from .pix2pix_model import define_G as pix2pix_G
from .pix2pixHD_model import define_G as pix2pixHD_G
# from .video_model import MosaicNet
# from .videoHD_model import MosaicNet as MosaicNet_HD
H
hypox64 已提交
7
from .BiSeNet_model import BiSeNet
8
from .BVDNet import define_G as video_G
H
HypoX64 已提交
9 10 11 12 13 14

def show_paramsnumber(net, netname='net'):
    """Print how many parameters ``net`` has, in millions.

    Args:
        net: a module exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        netname: label printed in front of the count.

    The count is rounded to two decimals, e.g. ``"netG parameters: 11.38M"``.
    """
    total = 0
    for tensor in net.parameters():
        total += tensor.numel()
    millions = round(total / 1e6, 2)
    print(f'{netname} parameters: {millions}M')

H
hypox64 已提交
15
def pix2pix(opt):
    """Build a pix2pix / pix2pixHD generator and load its weights.

    Args:
        opt: options object. Reads ``opt.netG``, ``opt.model_path`` and
            ``opt.gpu_id``. ``opt.netG == 'HD'`` selects the pix2pixHD
            'global' generator; any other value is passed to the pix2pix
            ``define_G`` as the architecture name.

    Returns:
        The generator after ``model_util.todevice(..., opt.gpu_id)``, in eval mode.
    """
    if opt.netG == 'HD':
        netG = pix2pixHD_G(3, 3, 64, 'global', 4)
    else:
        netG = pix2pix_G(3, 3, 64, opt.netG, norm='batch', use_dropout=True,
                         init_type='normal', gpu_ids=[])
    show_paramsnumber(netG, 'netG')
    # Deserialize onto the CPU (as style() does) so checkpoints saved from a GPU
    # can still be loaded on machines without CUDA; device placement is left to
    # model_util.todevice below.
    netG.load_state_dict(torch.load(opt.model_path, map_location='cpu'))
    netG = model_util.todevice(netG, opt.gpu_id)
    netG.eval()
    return netG

H
hypox64 已提交
27

H
hypox64 已提交
28 29
def style(opt):
    """Build a ResNet-9-block pix2pix generator for style transfer and load it.

    Args:
        opt: options object. Reads ``opt.edges``, ``opt.model_path`` and
            ``opt.gpu_id``. When ``opt.edges`` is truthy the generator takes a
            1-channel input and uses dropout; otherwise it takes 3 channels
            without dropout.

    Returns:
        The generator after ``model_util.todevice(..., opt.gpu_id)``, in eval mode.
    """
    if opt.edges:
        in_channels, dropout = 1, True
    else:
        in_channels, dropout = 3, False
    netG = pix2pix_G(in_channels, 3, 64, 'resnet_9blocks', norm='instance',
                     use_dropout=dropout, init_type='normal', gpu_ids=[])

    # Loading logic mirrors the upstream loader so older pretrained checkpoints
    # still load; see
    # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/base_model.py
    if isinstance(netG, torch.nn.DataParallel):
        netG = netG.module
    state_dict = torch.load(opt.model_path, map_location='cpu')
    if hasattr(state_dict, '_metadata'):
        del state_dict._metadata

    # Patch InstanceNorm entries of pre-0.4 checkpoints. Take a snapshot of the
    # keys first, because the patch helper may mutate state_dict while we iterate.
    key_snapshot = list(state_dict.keys())
    for key in key_snapshot:
        model_util.patch_instance_norm_state_dict(state_dict, netG, key.split('.'))
    netG.load_state_dict(state_dict)

    netG = model_util.todevice(netG, opt.gpu_id)
    netG.eval()
    return netG

H
hypox64 已提交
53
def video(opt):
    """Build the BVDNet video generator and load its weights.

    Args:
        opt: options object. Reads ``opt.model_path`` and ``opt.gpu_id``.

    Returns:
        The generator after ``model_util.todevice(..., opt.gpu_id)``, in eval mode.
    """
    netG = video_G(N=2, n_blocks=1, gpu_id=opt.gpu_id)
    show_paramsnumber(netG, 'netG')
    # Deserialize onto the CPU (as style() does) so GPU-saved checkpoints load
    # without CUDA; load_state_dict copies values into the model's own tensors,
    # and model_util.todevice handles final placement.
    netG.load_state_dict(torch.load(opt.model_path, map_location='cpu'))
    netG = model_util.todevice(netG, opt.gpu_id)
    netG.eval()
    return netG

H
hypox64 已提交
61 62 63 64 65
def bisenet(opt, type='roi'):
    """Build a single-class BiSeNet (ResNet-18 context path) and load its weights.

    Args:
        opt: options object. Reads ``opt.gpu_id`` plus the checkpoint path for
            the chosen ``type``: ``opt.model_path`` for ``'roi'``,
            ``opt.mosaic_position_model_path`` for ``'mosaic'``.
        type: ``'roi'`` or ``'mosaic'``. The name shadows the builtin but is
            kept because callers may pass it by keyword.

    Returns:
        The network after ``model_util.todevice(..., opt.gpu_id)``, in eval mode.

    Raises:
        ValueError: if ``type`` is not ``'roi'`` or ``'mosaic'``. Previously such
            a value silently returned a network with untrained weights.
    """
    if type == 'roi':
        weights_path = opt.model_path
    elif type == 'mosaic':
        weights_path = opt.mosaic_position_model_path
    else:
        raise ValueError("bisenet type must be 'roi' or 'mosaic', got %r" % (type,))

    net = BiSeNet(num_classes=1, context_path='resnet18', train_flag=False)
    show_paramsnumber(net, 'segment')
    # Deserialize onto the CPU (as style() does) so GPU-saved checkpoints load
    # without CUDA; model_util.todevice handles final placement.
    net.load_state_dict(torch.load(weights_path, map_location='cpu'))
    net = model_util.todevice(net, opt.gpu_id)
    net.eval()
    return net