# code was heavily based on https://github.com/AliaksandrSiarohin/first-order-model
# Users should be careful about adopting these functions in any commercial matters.
# https://github.com/AliaksandrSiarohin/first-order-model/blob/master/LICENSE.md

import paddle
from paddle import nn
import paddle.nn.functional as F

from ...modules.first_order import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d, make_coordinate_grid
from ...modules.first_order import MobileResBlock2d, MobileUpBlock2d, MobileDownBlock2d
from ...modules.dense_motion import DenseMotionNetwork

import numpy as np
import cv2


class OcclusionAwareGenerator(nn.Layer):
    """
    Generator that, given a source image and keypoints, transforms the image
    according to the movement trajectories induced by the keypoints.
    Follows the Johnson architecture: encoder -> residual bottleneck -> decoder,
    with the encoder features warped by a dense motion field and optionally
    masked by a predicted occlusion map.

    Args:
        num_channels (int): number of image channels (e.g. 3 for RGB).
        num_kp (int): number of keypoints driving the motion.
        block_expansion (int): base number of feature channels.
        max_features (int): upper bound on feature channels at any depth.
        num_down_blocks (int): number of encoder downsampling blocks
            (the decoder mirrors this count).
        num_bottleneck_blocks (int): number of residual blocks in the bottleneck.
        estimate_occlusion_map (bool): whether the dense-motion network also
            predicts an occlusion map used to attenuate warped features.
        dense_motion_params (dict|None): extra kwargs for DenseMotionNetwork;
            if None, no dense-motion network is built and forward() skips warping.
        estimate_jacobian (bool): accepted for config compatibility; unused here.
        inference (bool): enable inference-time border smoothing of the
            deformation grid and occlusion map (reduces edge artifacts).
        mobile_net (bool): use the lightweight Mobile* block variants.
    """
    def __init__(self,
                 num_channels,
                 num_kp,
                 block_expansion,
                 max_features,
                 num_down_blocks,
                 num_bottleneck_blocks,
                 estimate_occlusion_map=False,
                 dense_motion_params=None,
                 estimate_jacobian=False,
                 inference=False,
                 mobile_net=False):
        super(OcclusionAwareGenerator, self).__init__()

        if dense_motion_params is not None:
            self.dense_motion_network = DenseMotionNetwork(
                num_kp=num_kp,
                num_channels=num_channels,
                estimate_occlusion_map=estimate_occlusion_map,
                **dense_motion_params,
                mobile_net=mobile_net)
        else:
            self.dense_motion_network = None

        self.first = SameBlock2d(num_channels,
                                 block_expansion,
                                 kernel_size=(7, 7),
                                 padding=(3, 3),
                                 mobile_net=mobile_net)

        # Select the (mobile or full) block variants once, instead of
        # duplicating each construction loop per branch.
        down_block_cls = MobileDownBlock2d if mobile_net else DownBlock2d
        up_block_cls = MobileUpBlock2d if mobile_net else UpBlock2d
        res_block_cls = MobileResBlock2d if mobile_net else ResBlock2d

        # Encoder: channels double at each level, capped at max_features.
        down_blocks = []
        for i in range(num_down_blocks):
            in_features = min(max_features, block_expansion * (2**i))
            out_features = min(max_features, block_expansion * (2**(i + 1)))
            down_blocks.append(
                down_block_cls(in_features,
                               out_features,
                               kernel_size=(3, 3),
                               padding=(1, 1)))
        self.down_blocks = nn.LayerList(down_blocks)

        # Decoder: mirror of the encoder, channels halve at each level.
        up_blocks = []
        for i in range(num_down_blocks):
            in_features = min(max_features,
                              block_expansion * (2**(num_down_blocks - i)))
            out_features = min(
                max_features,
                block_expansion * (2**(num_down_blocks - i - 1)))
            up_blocks.append(
                up_block_cls(in_features,
                             out_features,
                             kernel_size=(3, 3),
                             padding=(1, 1)))
        self.up_blocks = nn.LayerList(up_blocks)

        # Residual bottleneck at the deepest resolution.
        self.bottleneck = paddle.nn.Sequential()
        in_features = min(max_features, block_expansion * (2**num_down_blocks))
        for i in range(num_bottleneck_blocks):
            self.bottleneck.add_sublayer(
                'r' + str(i),
                res_block_cls(in_features, kernel_size=(3, 3), padding=(1, 1)))

        self.final = nn.Conv2D(block_expansion,
                               num_channels,
                               kernel_size=(7, 7),
                               padding=(3, 3))
        self.estimate_occlusion_map = estimate_occlusion_map
        self.num_channels = num_channels
        self.inference = inference
        # Border width (in pixels) treated as unreliable during inference;
        # used by the deformation/occlusion edge-smoothing below.
        self.pad = 5
        self.mobile_net = mobile_net

    def deform_input(self, inp, deformation):
        """Warp `inp` by the sampling grid `deformation` (NHWC, values in [-1, 1]).

        The grid is bilinearly resized to `inp`'s spatial size when needed.
        In inference mode the grid is blended toward the identity grid near
        the image border (Gaussian-feathered mask) to suppress edge artifacts.
        """
        _, h_old, w_old, _ = deformation.shape
        _, _, h, w = inp.shape
        if h_old != h or w_old != w:
            # Interpolate in NCHW, then restore the NHWC grid layout.
            deformation = deformation.transpose([0, 3, 1, 2])
            deformation = F.interpolate(deformation,
                                        size=(h, w),
                                        mode='bilinear',
                                        align_corners=False)
            deformation = deformation.transpose([0, 2, 3, 1])
        if self.inference:
            identity_grid = make_coordinate_grid((h, w), type=inp.dtype)
            identity_grid = identity_grid.reshape([1, h, w, 2])
            # Mask that is 1 in the interior and feathers to 0 within
            # `self.pad` pixels of the border.
            visualization_matrix = np.zeros((h, w)).astype("float32")
            visualization_matrix[self.pad:h - self.pad,
                                 self.pad:w - self.pad] = 1.0
            gauss_kernel = paddle.to_tensor(
                cv2.GaussianBlur(visualization_matrix, (9, 9),
                                 0.0,
                                 borderType=cv2.BORDER_ISOLATED))
            gauss_kernel = gauss_kernel.unsqueeze(0).unsqueeze(-1)
            # Near the border, fall back to the identity (no-motion) grid.
            deformation = gauss_kernel * deformation + (
                1 - gauss_kernel) * identity_grid

        return F.grid_sample(inp,
                             deformation,
                             mode='bilinear',
                             padding_mode='zeros',
                             align_corners=True)

    def forward(self, source_image, kp_driving, kp_source):
        """Generate the driven image.

        Args:
            source_image: NCHW source image tensor.
            kp_driving / kp_source: keypoint dicts consumed by the
                dense-motion network.

        Returns:
            dict with 'prediction' (generated image) and, when the
            dense-motion network is present, 'mask', 'sparse_deformed',
            'deformed' and optionally 'occlusion_map'.
        """
        # Encoding (downsampling) part
        out = self.first(source_image)
        for i in range(len(self.down_blocks)):
            out = self.down_blocks[i](out)

        # Transforming feature representation according to deformation and occlusion
        output_dict = {}
        if self.dense_motion_network is not None:
            dense_motion = self.dense_motion_network(source_image=source_image,
                                                     kp_driving=kp_driving,
                                                     kp_source=kp_source)
            output_dict['mask'] = dense_motion['mask']
            output_dict['sparse_deformed'] = dense_motion['sparse_deformed']

            if 'occlusion_map' in dense_motion:
                occlusion_map = dense_motion['occlusion_map']
                output_dict['occlusion_map'] = occlusion_map
            else:
                occlusion_map = None
            deformation = dense_motion['deformation']
            out = self.deform_input(out, deformation)

            if occlusion_map is not None:
                if out.shape[2] != occlusion_map.shape[2] or out.shape[
                        3] != occlusion_map.shape[3]:
                    occlusion_map = F.interpolate(occlusion_map,
                                                  size=out.shape[2:],
                                                  mode='bilinear',
                                                  align_corners=False)
                if self.inference and not self.mobile_net:
                    # Force the border band to fully visible so the decoder
                    # keeps the (identity-warped) edge content intact.
                    h, w = occlusion_map.shape[2:]
                    occlusion_map[:, :, 0:self.pad, :] = 1.0
                    occlusion_map[:, :, :, 0:self.pad] = 1.0
                    occlusion_map[:, :, h - self.pad:h, :] = 1.0
                    occlusion_map[:, :, :, w - self.pad:w] = 1.0
                out = out * occlusion_map

            output_dict["deformed"] = self.deform_input(source_image,
                                                        deformation)

        # Decoding part
        out = self.bottleneck(out)
        for i in range(len(self.up_blocks)):
            out = self.up_blocks[i](out)
        out = self.final(out)
        out = F.sigmoid(out)

        output_dict["prediction"] = out

        return output_dict