#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
                     pad_y0, pad_y1):
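    """Reference implementation of upfirdn2d using plain Paddle ops.

    Upsamples ``input`` by zero-insertion (factors ``up_x``/``up_y``), pads
    (negative values crop), filters with the 2-D FIR ``kernel``, then
    downsamples by striding (factors ``down_x``/``down_y``).
    """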
    _, channel, in_h, in_w = input.shape
    input = input.reshape((-1, in_h, in_w, 1))

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    # Zero-stuffing upsample: insert (up - 1) zeros after each input sample
    # along both spatial axes.
    out = input.reshape((-1, in_h, 1, in_w, 1, minor))
    out = out.transpose((0, 1, 3, 5, 2, 4))
    out = out.reshape((-1, 1, 1, 1))
    out = F.pad(out, [0, up_x - 1, 0, up_y - 1])
    out = out.reshape((-1, in_h, in_w, minor, up_y, up_x))
    out = out.transpose((0, 3, 1, 4, 2, 5))
    out = out.reshape((-1, minor, in_h * up_y, in_w * up_x))

    # Apply non-negative padding here; negative pad values crop below.
    out = F.pad(
        out, [max(pad_x0, 0),
              max(pad_x1, 0),
              max(pad_y0, 0),
              max(pad_y1, 0)])
    out = out[:, :,
              max(-pad_y0, 0):out.shape[2] - max(-pad_y1, 0),
              max(-pad_x0, 0):out.shape[3] - max(-pad_x1, 0)]

    out = out.reshape(
        (-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1))
    # conv2d computes cross-correlation, so flip the kernel to perform a
    # true 2-D FIR convolution.
    w = paddle.flip(kernel, [0, 1]).reshape((1, 1, kernel_h, kernel_w))
    out = F.conv2d(out, w)
    out = out.reshape((
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    ))
    out = out.transpose((0, 2, 3, 1))
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.reshape((-1, channel, out_h, out_w))


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
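    """Apply upsample (zero-insertion), FIR filter, downsample with the same
    factor and (before, after) padding on both spatial axes."""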
    out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1],
                           pad[0], pad[1])

    return out


def make_kernel(k):
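    """Build a normalized 2-D FIR kernel; a 1-D input is expanded to 2-D via
    its outer product."""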
    k = paddle.to_tensor(k, dtype='float32')

    if k.ndim == 1:
        k = k.unsqueeze(0) * k.unsqueeze(1)

    k /= k.sum()

    return k


class Upfirdn2dUpsample(nn.Layer):
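    """Upsample by ``factor`` using zero-insertion followed by FIR filtering
    with ``kernel`` (scaled by ``factor**2`` to preserve signal energy)."""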
    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel) * (factor * factor)
        self.register_buffer("kernel", kernel)

        p = kernel.shape[0] - factor

        # pad0 + pad1 == kernel_size - 1, so the output is exactly factor x the input size.
        pad0 = (p + 1) // 2 + factor - 1
        pad1 = p // 2

        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input,
                        self.kernel,
                        up=self.factor,
                        down=1,
                        pad=self.pad)

        return out


class Upfirdn2dDownsample(nn.Layer):
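    """Downsample by ``factor`` using FIR filtering followed by strided
    decimation."""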
    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel)
        self.register_buffer("kernel", kernel)

        p = kernel.shape[0] - factor

        # pad0 + pad1 == kernel_size - factor, so the output is exactly input size / factor.
        pad0 = (p + 1) // 2
        pad1 = p // 2

        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input,
                        self.kernel,
                        up=1,
                        down=self.factor,
                        pad=self.pad)

        return out


class Upfirdn2dBlur(nn.Layer):
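    """FIR blur with explicit (before, after) padding; ``kernel`` is rescaled
    by ``upsample_factor**2`` when the blur follows an upsampling layer."""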
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        kernel = make_kernel(kernel)

        if upsample_factor > 1:
            kernel = kernel * (upsample_factor * upsample_factor)

        self.register_buffer("kernel", kernel, persistable=False)

        self.pad = pad

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, pad=self.pad)

        return out
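

if __name__ == "__main__":
    # Minimal smoke test, as a sketch only: the input shape and the
    # [1, 3, 3, 1] binomial kernel are illustrative assumptions, not part
    # of the module's API.
    x = paddle.randn([1, 3, 16, 16])
    blur_kernel = [1, 3, 3, 1]

    up = Upfirdn2dUpsample(blur_kernel, factor=2)      # 16x16 -> 32x32
    down = Upfirdn2dDownsample(blur_kernel, factor=2)  # 16x16 -> 8x8
    blur = Upfirdn2dBlur(blur_kernel, pad=(2, 1))      # shape-preserving

    print(up(x).shape)    # [1, 3, 32, 32]
    print(down(x).shape)  # [1, 3, 8, 8]
    print(blur(x).shape)  # [1, 3, 16, 16]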