# tps.py
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
from paddle import nn, ParamAttr
from paddle.nn import functional as F
import numpy as np


class ConvBNLayer(nn.Layer):
    """Conv2D (without bias) followed by BatchNorm with an optional activation.

    Padding is ``(kernel_size - 1) // 2``, so for an odd ``kernel_size`` and
    ``stride=1`` the spatial size of the input is preserved.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        kernel_size (int): square convolution kernel size.
        stride (int): convolution stride.
        groups (int): number of convolution groups.
        act (str|None): activation applied by ``nn.BatchNorm`` (e.g. 'relu').
        name (str|None): prefix used to build explicit parameter names
            ("<name>_weights", "bn_<name>_scale", "bn_<name>_offset",
            "bn_<name>_mean", "bn_<name>_variance"). When None, Paddle's
            auto-generated parameter names are used; previously the default
            ``None`` crashed with a TypeError on string concatenation.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 groups=1,
                 act=None,
                 name=None):
        super(ConvBNLayer, self).__init__()
        if name is None:
            # None is the Paddle default for every one of these arguments.
            conv_weight_attr = None
            bn_param_attr = None
            bn_bias_attr = None
            moving_mean_name = None
            moving_variance_name = None
        else:
            # Explicit names keep parameters compatible with checkpoints that
            # were saved under this naming scheme.
            bn_name = "bn_" + name
            conv_weight_attr = ParamAttr(name=name + "_weights")
            bn_param_attr = ParamAttr(name=bn_name + '_scale')
            bn_bias_attr = ParamAttr(name=bn_name + '_offset')
            moving_mean_name = bn_name + '_mean'
            moving_variance_name = bn_name + '_variance'

        self.conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            groups=groups,
            weight_attr=conv_weight_attr,
            bias_attr=False)  # BatchNorm's offset replaces the conv bias
        self.bn = nn.BatchNorm(
            out_channels,
            act=act,
            param_attr=bn_param_attr,
            bias_attr=bn_bias_attr,
            moving_mean_name=moving_mean_name,
            moving_variance_name=moving_variance_name)

    def forward(self, x):
        """Apply convolution, then batch normalization (and ``act``)."""
        x = self.conv(x)
        x = self.bn(x)
        return x


class LocalizationNetwork(nn.Layer):
    """Regress the F fiducial (control) points used by the TPS transform.

    Architecture: four ConvBNLayer+ReLU blocks, each followed by a 2x2 max
    pool except the last, which is followed by global average pooling. Then
    two fully connected layers: ``fc1`` (ReLU) and ``fc2``. ``fc2`` starts
    with all-zero weights and a bias equal to ``get_initial_fiducials()``, so
    at initialization the network outputs those fixed points for any input.

    Args:
        in_channels (int): channels of the input tensor.
        num_fiducial (int): number of fiducial points F. It must be a
            positive even number because the points are split evenly between
            a top row and a bottom row.
        loc_lr (float): learning-rate multiplier for fc1 weights and for
            fc2 weights/bias.
        model_name (str): "large" selects the wide configuration; any other
            value selects the small one.

    Raises:
        ValueError: if ``num_fiducial`` is not a positive even number.
    """

    def __init__(self, in_channels, num_fiducial, loc_lr, model_name):
        super(LocalizationNetwork, self).__init__()
        if num_fiducial <= 0 or num_fiducial % 2 != 0:
            # An odd F produces only 2 * (F // 2) initial points, which would
            # later fail the fc2 bias Assign / output reshape obscurely.
            raise ValueError(
                "num_fiducial must be a positive even integer, got {}".format(
                    num_fiducial))
        self.F = num_fiducial
        out_dim = num_fiducial * 2  # (x, y) for each fiducial point

        if model_name == "large":
            num_filters_list = [64, 128, 256, 512]
            fc_dim = 256
        else:
            num_filters_list = [16, 32, 64, 128]
            fc_dim = 64

        # Convs are registered via add_sublayer; pools have no parameters and
        # are only kept in the plain list that forward() iterates.
        self.block_list = []
        last_index = len(num_filters_list) - 1
        for fno, num_filters in enumerate(num_filters_list):
            name = "loc_conv%d" % fno
            conv = self.add_sublayer(
                name,
                ConvBNLayer(
                    in_channels=in_channels,
                    out_channels=num_filters,
                    kernel_size=3,
                    act='relu',
                    name=name))
            self.block_list.append(conv)
            if fno == last_index:
                pool = nn.AdaptiveAvgPool2D(1)  # -> B x C x 1 x 1
            else:
                pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
            in_channels = num_filters
            self.block_list.append(pool)

        name = "loc_fc1"
        stdv = 1.0 / math.sqrt(num_filters_list[-1] * 1.0)
        self.fc1 = nn.Linear(
            in_channels,
            fc_dim,
            weight_attr=ParamAttr(
                learning_rate=loc_lr,
                name=name + "_w",
                initializer=nn.initializer.Uniform(-stdv, stdv)),
            bias_attr=ParamAttr(name=name + '.b_0'),
            name=name)

        # fc2: zero weights + fiducial bias, so the initial prediction is the
        # fixed layout from get_initial_fiducials().
        initial_bias = self.get_initial_fiducials().reshape(-1)
        name = "loc_fc2"
        param_attr = ParamAttr(
            learning_rate=loc_lr,
            initializer=nn.initializer.Assign(np.zeros([fc_dim, out_dim])),
            name=name + "_w")
        bias_attr = ParamAttr(
            learning_rate=loc_lr,
            initializer=nn.initializer.Assign(initial_bias),
            name=name + "_b")
        self.fc2 = nn.Linear(
            fc_dim,
            out_dim,
            weight_attr=param_attr,
            bias_attr=bias_attr,
            name=name)
        self.out_channels = out_dim

    def forward(self, x):
        """Estimate the fiducial points of the geometric transformation.

        Args:
            x: input image tensor (presumably N x C x H x W; the two
                ``squeeze(axis=2)`` calls assume the pooled output is
                B x C x 1 x 1).

        Returns:
            batch_C_prime: tensor of shape [B, F, 2] with the predicted
            (x, y) fiducial coordinates.
        """
        for block in self.block_list:
            x = block(x)
        x = x.squeeze(axis=2).squeeze(axis=2)  # B x C x 1 x 1 -> B x C
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = x.reshape(shape=[-1, self.F, 2])
        return x

    def get_initial_fiducials(self):
        """Return the initial fiducial layout (see RARE paper Fig. 6 (a)).

        Returns:
            numpy array of shape [F, 2]. The first F/2 rows have x spaced
            evenly in [-1, 1] and y running from 0.0 to -1.0; the last F/2
            rows share those x values with y running from 1.0 to 0.0.
        """
        half = self.F // 2
        ctrl_pts_x = np.linspace(-1.0, 1.0, half)
        ctrl_pts_y_top = np.linspace(0.0, -1.0, num=half)
        ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=half)
        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
        return np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)


class GridGenerator(nn.Layer):
    """Turn predicted fiducial points into a sampling grid via a TPS mapping.

    The mapping is computed as follows:

    * T = inv(delta_C) @ [C'; ex], where C' is the predicted B x F x 2
      points and ex is a 3 x 2 block produced by ``self.fc``.
    * P' = P_hat @ T, where P_hat holds every output pixel's
      [1, x, y, rbf(...)] features.

    ``self.fc`` is initialized to zero with ``learning_rate=0.0``, so ``ex``
    presumably stays all-zero unless a checkpoint overrides these
    parameters.

    Args:
        in_channels (int): size of the flattened fiducial input, i.e. F * 2.
        num_fiducial (int): number of fiducial points F.
    """

    def __init__(self, in_channels, num_fiducial):
        super(GridGenerator, self).__init__()
        self.eps = 1e-6  # keeps log() finite when a pixel hits a fiducial
        self.F = num_fiducial

        name = "ex_fc"
        initializer = nn.initializer.Constant(value=0.0)
        param_attr = ParamAttr(
            learning_rate=0.0, initializer=initializer, name=name + "_w")
        bias_attr = ParamAttr(
            learning_rate=0.0, initializer=initializer, name=name + "_b")
        self.fc = nn.Linear(
            in_channels,
            6,
            weight_attr=param_attr,
            bias_attr=bias_attr,
            name=name)

    def forward(self, batch_C_prime, I_r_size):
        """Generate the grid for the grid_sampler.

        Args:
            batch_C_prime: B x F x 2 predicted fiducial points.
            I_r_size: (height, width) of the rectified image.

        Returns:
            batch_P_prime: B x (height * width) x 2 sampling coordinates.
        """
        C = self.build_C_paddle()
        P = self.build_P_paddle(I_r_size)  # already a tensor

        inv_delta_C_tensor = self.build_inv_delta_C_paddle(C).astype('float32')
        P_hat_tensor = self.build_P_hat_paddle(C, P).astype('float32')

        # The TPS matrices depend only on F and the output size, never on
        # learned parameters, so no gradient is needed through them.
        inv_delta_C_tensor.stop_gradient = True
        P_hat_tensor.stop_gradient = True

        batch_C_ex_part_tensor = self.get_expand_tensor(batch_C_prime)
        batch_C_ex_part_tensor.stop_gradient = True

        # B x (F + 3) x 2
        batch_C_prime_with_zeros = paddle.concat(
            [batch_C_prime, batch_C_ex_part_tensor], axis=1)
        batch_T = paddle.matmul(inv_delta_C_tensor, batch_C_prime_with_zeros)
        batch_P_prime = paddle.matmul(P_hat_tensor, batch_T)
        return batch_P_prime

    def build_C_paddle(self):
        """Return the F x 2 float64 fiducial coordinates C in I_r.

        The first F/2 points lie on y = -1 and the last F/2 on y = 1, with
        x spaced evenly in [-1, 1] for both rows.
        """
        half = self.F // 2
        ctrl_pts_x = paddle.linspace(-1.0, 1.0, half, dtype='float64')
        ctrl_pts_y_top = -1 * paddle.ones([half], dtype='float64')
        ctrl_pts_y_bottom = paddle.ones([half], dtype='float64')
        ctrl_pts_top = paddle.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = paddle.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
        return paddle.concat([ctrl_pts_top, ctrl_pts_bottom], axis=0)  # F x 2

    def build_P_paddle(self, I_r_size):
        """Return normalized pixel-centre coordinates of the output image.

        Args:
            I_r_size: (height, width) of the rectified image.

        Returns:
            (height * width) x 2 float64 tensor of (x, y) pairs, ordered row
            by row. Each coordinate is (2k + 1 - size) / size for
            k = 0 .. size-1.
        """
        I_r_height, I_r_width = I_r_size
        I_r_grid_x = paddle.divide(
            paddle.arange(
                -I_r_width, I_r_width, 2, dtype='float64') + 1.0,
            paddle.to_tensor(
                I_r_width, dtype='float64'))
        I_r_grid_y = paddle.divide(
            paddle.arange(
                -I_r_height, I_r_height, 2, dtype='float64') + 1.0,
            paddle.to_tensor(
                I_r_height, dtype='float64'))
        # meshgrid(x, y) -> width x height x 2; transpose -> height x width x 2
        P = paddle.stack(paddle.meshgrid(I_r_grid_x, I_r_grid_y), axis=2)
        P = paddle.transpose(P, perm=[1, 0, 2])
        return P.reshape([-1, 2])

    def build_inv_delta_C_paddle(self, C):
        """Return inv_delta_C, the inverse TPS system matrix needed for T.

        Args:
            C: F x 2 float64 fiducial coordinates (from build_C_paddle).

        Returns:
            (F + 3) x (F + 3) float64 tensor.
        """
        num_points = self.F
        # Pairwise distances in one broadcast op (F x F x 2 -> F x F) instead
        # of F*(F-1)/2 per-element tensor assignments on every forward pass.
        pair_diff = paddle.unsqueeze(C, axis=1) - paddle.unsqueeze(C, axis=0)
        hat_C = paddle.norm(pair_diff, p=2, axis=2)  # zero diagonal
        # Put 1 on the diagonal so r^2 * log(r) is 0 there instead of NaN.
        hat_C = hat_C + paddle.eye(num_points, dtype='float64')
        hat_C = (hat_C**2) * paddle.log(hat_C)

        top = paddle.concat(  # F x (F + 3)
            [paddle.ones((num_points, 1), dtype='float64'), C, hat_C], axis=1)
        middle = paddle.concat(  # 2 x (F + 3)
            [
                paddle.zeros((2, 3), dtype='float64'),
                paddle.transpose(C, perm=[1, 0])
            ],
            axis=1)
        bottom = paddle.concat(  # 1 x (F + 3)
            [
                paddle.zeros((1, 3), dtype='float64'),
                paddle.ones((1, num_points), dtype='float64')
            ],
            axis=1)
        delta_C = paddle.concat([top, middle, bottom], axis=0)
        return paddle.inverse(delta_C)  # (F + 3) x (F + 3)

    def build_P_hat_paddle(self, C, P):
        """Return P_hat = [1, P, r^2 * log(r + eps)] for every output pixel.

        Args:
            C: F x 2 float64 fiducial coordinates.
            P: n x 2 float64 pixel coordinates, n = height * width.

        Returns:
            n x (F + 3) float64 tensor.
        """
        eps = self.eps
        n = P.shape[0]
        # n x 1 x 2  minus  1 x F x 2  ->  n x F x 2 (broadcast)
        P_diff = paddle.unsqueeze(P, axis=1) - paddle.unsqueeze(C, axis=0)
        rbf_norm = paddle.norm(P_diff, p=2, axis=2, keepdim=False)  # n x F
        rbf = paddle.multiply(
            paddle.square(rbf_norm), paddle.log(rbf_norm + eps))  # n x F
        return paddle.concat(
            [paddle.ones((n, 1), dtype='float64'), P, rbf], axis=1)

    def get_expand_tensor(self, batch_C_prime):
        """Map B x F x 2 fiducials to the B x 3 x 2 extension block via fc."""
        batch_size, num_points, num_coords = batch_C_prime.shape
        flat = batch_C_prime.reshape([batch_size, num_points * num_coords])
        batch_C_ex_part_tensor = self.fc(flat)
        return batch_C_ex_part_tensor.reshape([-1, 3, 2])


class TPS(nn.Layer):
    """Thin-plate-spline rectification: predict fiducials, then warp the image.

    Args:
        in_channels (int): channels of the input image.
        num_fiducial (int): number of fiducial points F.
        loc_lr (float): learning-rate multiplier for the localization FC layers.
        model_name (str): "large" for the wide localization network, any other
            value for the small one.
    """

    def __init__(self, in_channels, num_fiducial, loc_lr, model_name):
        super(TPS, self).__init__()
        self.loc_net = LocalizationNetwork(in_channels, num_fiducial, loc_lr,
                                           model_name)
        self.grid_generator = GridGenerator(self.loc_net.out_channels,
                                            num_fiducial)
        # Rectification only resamples the image; the channel count is unchanged.
        self.out_channels = in_channels

    def forward(self, image):
        """Rectify ``image`` with the predicted TPS warp.

        Args:
            image: input tensor; ``shape[2]`` and ``shape[3]`` are used as
                the output height and width (presumably N x C x H x W).

        Returns:
            ``image`` resampled by ``F.grid_sample`` on the generated grid,
            with the same spatial size as the input.
        """
        # Note: this mutates the caller's tensor by enabling gradients on it.
        image.stop_gradient = False
        batch_C_prime = self.loc_net(image)
        batch_P_prime = self.grid_generator(batch_C_prime, image.shape[2:])
        batch_P_prime = batch_P_prime.reshape(
            [-1, image.shape[2], image.shape[3], 2])
        # NOTE(review): the grid holds pixel-centre coordinates
        # ((2k + 1) / size - 1), the align_corners=False convention, while
        # grid_sample is called with Paddle's default align_corners — confirm
        # this is intended (changing it would alter pretrained-model outputs).
        batch_I_r = F.grid_sample(x=image, grid=batch_P_prime)
        return batch_I_r