video_model.py
import functools

import torch
import torch.nn as nn
import torch.nn.functional as F
from .unet_parts import *
from .pix2pix_model import *  # ResnetBlock used below is pulled in via these star imports

class encoder_2d(nn.Module):
    """Downsampling front end of the Resnet-based generator: a 7x7 input
    convolution followed by two stride-2 downsampling convolutions (the
    Resnet blocks themselves live in decoder_2d).

    Adapted from Justin Johnson's neural style transfer project
    (https://github.com/jcjohnson/fast-neural-style).
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Construct the encoder half of the generator

        Parameters:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images (unused here; kept for symmetry with decoder_2d)
            ngf (int)           -- the number of filters in the first conv layer
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers (unused here)
            n_blocks (int)      -- the number of ResNet blocks (unused here; the blocks live in decoder_2d)
            padding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero (unused here)
        """
        assert(n_blocks >= 0)
        super(encoder_2d, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]
        # output: (B, ngf*4, H/4, W/4), e.g. torch.Size([1, 256, 32, 32]) for a 128x128 input

        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
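
# A minimal shape check for encoder_2d (a sketch; the 4-channel, 128x128
# input mirrors how MosaicNet below uses this encoder):
#
#   enc = encoder_2d(4, -1, 64, n_blocks=9)   # output_nc/n_blocks are unused here
#   feat = enc(torch.randn(1, 4, 128, 128))
#   # feat.shape == torch.Size([1, 256, 32, 32])  -- ngf*4 channels at H/4 x W/4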


class decoder_2d(nn.Module):
    """Back end of the Resnet-based generator: the Resnet blocks followed by
    two upsampling stages and a final 7x7 output convolution.

    Adapted from Justin Johnson's neural style transfer project
    (https://github.com/jcjohnson/fast-neural-style).
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Construct the decoder half of the generator

        Parameters:
            input_nc (int)      -- the number of channels in input images (unused here; kept for symmetry with encoder_2d)
            output_nc (int)     -- the number of channels in output images
            ngf (int)           -- the number of filters in the last conv layer
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers
            n_blocks (int)      -- the number of ResNet blocks
            padding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero
        """
        super(decoder_2d, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = []

        n_downsampling = 2
        mult = 2 ** n_downsampling
        for i in range(n_blocks):       # add ResNet blocks
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
        # feature map stays at (B, ngf*4, H/4, W/4), e.g. torch.Size([1, 256, 32, 32])

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            # ConvTranspose2d upsampling is replaced by nearest-neighbour
            # Upsample + Conv2d to avoid checkerboard artifacts; see
            # https://distill.pub/2016/deconv-checkerboard/ and
            # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/190
            # model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
            #                              kernel_size=3, stride=2,
            #                              padding=1, output_padding=1,
            #                              bias=use_bias),
            #           norm_layer(int(ngf * mult / 2)),
            #           nn.ReLU(True)]
            model += [nn.Upsample(scale_factor=2, mode='nearest'),
                      nn.ReflectionPad2d(1),
                      nn.Conv2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=1, padding=0),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        # No output activation here; MosaicNet applies a Sigmoid in merge2.
        # model += [nn.Tanh()]
        # model += [nn.Sigmoid()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
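
# A minimal shape check for decoder_2d (a sketch matching MosaicNet's usage:
# the input is the fused (B, 256, H/4, W/4) feature map produced by merge1):
#
#   dec = decoder_2d(4, 3, 64, n_blocks=9)
#   out = dec(torch.randn(1, 256, 32, 32))
#   # out.shape == torch.Size([1, 3, 128, 128])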



class conv_3d(nn.Module):
    """Conv3d -> BatchNorm3d -> ReLU building block for encoder_3d."""
    def __init__(self, inchannel, outchannel, kernel_size=3, stride=2, padding=1):
        super(conv_3d, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm3d(outchannel),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv(x)
        return x
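
# With the default stride of 2, conv_3d halves the temporal and spatial axes:
#
#   block = conv_3d(1, 64)
#   # block(torch.randn(1, 1, 16, 128, 128)).shape == torch.Size([1, 64, 8, 64, 64])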


class encoder_3d(nn.Module):
    """3D-convolutional encoder over the stacked input frames.

    in_channel is the channel count of the full network input; the two
    stride-2 3D convolutions shrink the temporal axis to in_channel/4,
    which conver2d then squeezes to a single channel (so in_channel must
    be chosen so that int(in_channel/4) matches that downsampled size).
    """
    def __init__(self, in_channel):
        super(encoder_3d, self).__init__()
        self.down1 = conv_3d(1, 64, 3, 2, 1)
        self.down2 = conv_3d(64, 128, 3, 2, 1)
        self.down3 = conv_3d(128, 256, 3, 1, 1)
        self.conver2d = nn.Sequential(
            nn.Conv2d(int(in_channel/4), 1, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True),
        )


    def forward(self, x):
        # (B, T, H, W) -> (B, 1, T, H, W): stack the input frames along a depth axis
        x = x.view(x.size(0), 1, x.size(1), x.size(2), x.size(3))
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)

        # The reshapes below assume batch size 1: drop the batch axis so the 256
        # feature maps act as the batch of a 2D conv over the temporal slices,
        # then restore a (1, 256, H/4, W/4) layout.
        x = x.view(x.size(1), x.size(2), x.size(3), x.size(4))
        x = self.conver2d(x)
        x = x.view(x.size(1), x.size(0), x.size(2), x.size(3))
        return x
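
# A shape trace for encoder_3d (a sketch assuming in_channel=16, i.e. five
# 3-channel frames plus one extra channel, at 128x128 with batch size 1):
#
#   enc3d = encoder_3d(16)
#   feat = enc3d(torch.randn(1, 16, 128, 128))
#   # (1,1,16,128,128) -> (1,64,8,64,64) -> (1,128,4,32,32) -> (1,256,4,32,32)
#   # -> conver2d over the 4 temporal slices -> feat.shape == (1, 256, 32, 32)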



class MosaicNet(nn.Module):
    """Fuse a 2D encoding of the centre frame with a 3D encoding of the
    whole frame stack, then decode a single restored frame."""
    def __init__(self, in_channel, out_channel):
        super(MosaicNet, self).__init__()

        self.encoder_2d = encoder_2d(4,-1,64,n_blocks=9)  # output_nc/n_blocks unused by the encoder
        self.encoder_3d = encoder_3d(in_channel)
        self.decoder_2d = decoder_2d(4,3,64,n_blocks=9)
        self.merge1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(512, 256, 3, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )
        self.merge2 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(6, out_channel, kernel_size=7, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):

        N = int((x.size()[1])/3)  # number of 3-channel frames packed into x
        # 4-channel 2D input: the centre frame's 3 channels plus the single channel at index N-1
        x_2d = torch.cat((x[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], x[:,N-1:N,:,:]), 1)
        shortcut_2d = x[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:]  # centre frame, kept for the output skip connection

        x_2d = self.encoder_2d(x_2d)  # (B, 256, H/4, W/4)
        x_3d = self.encoder_3d(x)     # (1, 256, H/4, W/4); encoder_3d assumes batch size 1
        x = torch.cat((x_2d,x_3d),1)  # fuse 2D and 3D features -> (B, 512, H/4, W/4)
        x = self.merge1(x)            # back down to 256 channels
        x = self.decoder_2d(x)
        x = torch.cat((x,shortcut_2d),1)  # skip connection with the centre frame
        x = self.merge2(x)                # 7x7 conv + Sigmoid -> (B, out_channel, H, W)

        return x
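

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original module). Assumes
    # in_channel = 16 (five 3-channel frames plus one extra channel), 128x128
    # inputs, and batch size 1 as required by encoder_3d's reshapes. Run as a
    # module from the package root so the relative imports resolve.
    net = MosaicNet(in_channel=16, out_channel=3)
    out = net(torch.randn(1, 16, 128, 128))
    print(out.shape)  # expected: torch.Size([1, 3, 128, 128])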