# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
16 17 18
import numpy as np
from functools import partial
from numpy import asarray
19
from numpy.fft._pocketfft import _cook_nd_args, _raw_fft, _raw_fftnd
20 21


F
Feiyu Chan 已提交
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
@enum.unique
class NormMode(enum.Enum):
    """Scaling applied to a transform of length ``n``.

    Each member selects the inverse-normalisation factor that
    ``_get_inv_norm`` returns for it.
    """

    # factor 1.0 (no scaling)
    none = 1
    # factor sqrt(n)
    by_sqrt_n = 2
    # factor n
    by_n = 3


def _get_norm_mode(norm, forward):
    if norm == "ortho":
        return NormMode.by_sqrt_n
    if norm is None or norm == "backward":
        return NormMode.none if forward else NormMode.by_n
    return NormMode.by_n if forward else NormMode.none


def _get_inv_norm(n, norm_mode):
    """Return the inverse-normalisation factor for a transform of length ``n``.

    Args:
        n: transform length.
        norm_mode: a ``NormMode`` member; anything else fails the assertion.

    Returns:
        ``1.0`` for ``NormMode.none``, ``np.sqrt(n)`` for
        ``NormMode.by_sqrt_n`` and ``n`` itself for ``NormMode.by_n``.
    """
    assert isinstance(norm_mode,
                      NormMode), "invalid norm_type {}".format(norm_mode)
    if norm_mode is NormMode.by_n:
        return n
    # Only `none` and `by_sqrt_n` remain at this point.
    return np.sqrt(n) if norm_mode is NormMode.by_sqrt_n else 1.0


# 1d transforms
def _fftc2c(a, n=None, axis=-1, norm=None, forward=None):
    """Run a 1-D transform along ``axis`` via numpy's private ``_raw_fft``.

    Args:
        a: array-like input; converted with ``asarray``.
        n: transform length; defaults to ``a.shape[axis]``.
        axis: axis to transform.
        norm: despite the name this must be a ``NormMode`` member, because it
            is handed straight to ``_get_inv_norm`` (which asserts the type).
        forward: direction flag forwarded unchanged to ``_raw_fft``.

    Returns:
        Whatever ``_raw_fft`` returns for these arguments.
    """
    arr = asarray(a)
    length = arr.shape[axis] if n is None else n
    scale = _get_inv_norm(length, norm)
    # The fourth positional argument is fixed to False here; it is presumably
    # the "is real" switch of `_raw_fft` — confirm against the numpy version in use.
    return _raw_fft(arr, length, axis, False, forward, scale)


def _fftr2c(a, n=None, axis=-1, norm=None, forward=None):
    """Run a 1-D real-input transform along ``axis`` via ``_raw_fft``.

    ``_raw_fft`` is always invoked with both of its flag arguments set to
    True, independent of ``forward``; when ``forward`` is falsy the result is
    complex-conjugated before being returned.

    Args:
        a: array-like input; converted with ``asarray``.
        n: transform length; defaults to ``a.shape[axis]``.
        axis: axis to transform.
        norm: a ``NormMode`` member (checked by ``_get_inv_norm``).
        forward: truthy to return the raw result, falsy to return its
            conjugate.

    Returns:
        The (possibly conjugated) output of ``_raw_fft``.
    """
    arr = asarray(a)
    length = arr.shape[axis] if n is None else n
    scale = _get_inv_norm(length, norm)
    spectrum = _raw_fft(arr, length, axis, True, True, scale)
    return spectrum if forward else spectrum.conj()


def _fftc2r(a, n=None, axis=-1, norm=None, forward=None):
    """Run a 1-D complex-to-real transform along ``axis`` via ``_raw_fft``.

    ``_raw_fft`` is always invoked with flag arguments ``True, False``; when
    ``forward`` is truthy the input is complex-conjugated first.

    Args:
        a: array-like input; converted with ``asarray``.
        n: output length; defaults to ``(a.shape[axis] - 1) * 2``.
        axis: axis to transform.
        norm: a ``NormMode`` member (checked by ``_get_inv_norm``).
        forward: truthy to conjugate the input before transforming.

    Returns:
        Whatever ``_raw_fft`` returns for these arguments.
    """
    arr = asarray(a)
    length = (arr.shape[axis] - 1) * 2 if n is None else n
    scale = _get_inv_norm(length, norm)
    if forward:
        arr = arr.conj()
    return _raw_fft(arr, length, axis, True, False, scale)


# general fft functors
def _fft_c2c_nd(x, axes, norm_mode, forward):
    """Apply ``_fftc2c`` over ``axes`` through numpy's private ``_raw_fftnd``.

    Args:
        x: input array.
        axes: axes to transform, passed to ``_raw_fftnd`` as ``axes``.
        norm_mode: ``NormMode`` member, passed to ``_raw_fftnd`` as ``norm``.
        forward: direction flag bound into the 1-D transform.

    Returns:
        The result of ``_raw_fftnd``.
    """
    # Bind the direction so `_raw_fftnd` only has to supply the remaining
    # arguments of the 1-D transform.
    transform_1d = partial(_fftc2c, forward=forward)
    return _raw_fftnd(
        x, s=None, axes=axes, function=transform_1d, norm=norm_mode)


F
Feiyu Chan 已提交
84
def _fft_r2c_nd(x, axes, norm_mode, forward, onesided):
    """N-D transform of ``x`` over ``axes``, optionally one-sided.

    Args:
        x: array-like input.
        axes: axes to transform; resolved through ``_cook_nd_args``.
        norm_mode: ``NormMode`` member used for every 1-D step.
        forward: direction flag passed to the 1-D helpers.
        onesided: when truthy, the last axis goes through ``_fftr2c`` and the
            remaining axes through ``_fft_c2c_nd``; otherwise all axes go
            through ``_fft_c2c_nd``.

    Returns:
        The transformed array.
    """
    arr = asarray(x)
    sizes, axes = _cook_nd_args(arr, axes=axes)
    if not onesided:
        return _fft_c2c_nd(x, axes, norm_mode, forward)
    # Last axis first, then the remaining axes on its output.
    partial_result = _fftr2c(arr, sizes[-1], axes[-1], norm_mode, forward)
    return _fft_c2c_nd(partial_result, axes[:-1], norm_mode, forward)


def _fft_c2r_nd(x, axes, norm_mode, forward, last_dim_size):
    """N-D complex-to-real transform of ``x`` over ``axes``.

    All axes but the last go through ``_fft_c2c_nd``; the last axis then goes
    through ``_fftc2r``.

    Args:
        x: array-like input.
        axes: axes to transform; resolved through ``_cook_nd_args`` (called
            with ``invreal=1``).
        norm_mode: ``NormMode`` member used for every 1-D step.
        forward: direction flag passed to the 1-D helpers.
        last_dim_size: output length along the last axis; ``None`` keeps the
            size computed by ``_cook_nd_args``.

    Returns:
        The transformed array.
    """
    arr = asarray(x)
    sizes, axes = _cook_nd_args(arr, axes=axes, invreal=1)
    if last_dim_size is not None:
        sizes[-1] = last_dim_size
    leading = _fft_c2c_nd(arr, axes[:-1], norm_mode, forward)
    return _fftc2r(leading, sizes[-1], axes[-1], norm_mode, forward)


# kernels
def fft_c2c(x, axes, normalization, forward):
    """Complex-to-complex transform of ``x`` over ``axes``.

    ``normalization`` and ``forward`` are resolved to a ``NormMode`` by
    ``_get_norm_mode``; the work is done by ``_fft_c2c_nd``.
    """
    return _fft_c2c_nd(x, axes,
                       _get_norm_mode(normalization, forward), forward)


def fft_c2r(x, axes, normalization, forward, last_dim_size):
    """Complex-to-real transform of ``x`` over ``axes``.

    ``normalization`` and ``forward`` are resolved to a ``NormMode`` by
    ``_get_norm_mode``; the work is done by ``_fft_c2r_nd``, which receives
    ``last_dim_size`` unchanged.
    """
    return _fft_c2r_nd(x, axes,
                       _get_norm_mode(normalization, forward), forward,
                       last_dim_size)


def fft_r2c(x, axes, normalization, forward, onesided):
    """Real-to-complex transform of ``x`` over ``axes``.

    ``normalization`` and ``forward`` are resolved to a ``NormMode`` by
    ``_get_norm_mode``; the work is done by ``_fft_r2c_nd``, which receives
    ``onesided`` unchanged.
    """
    return _fft_r2c_nd(x, axes,
                       _get_norm_mode(normalization, forward), forward,
                       onesided)


# backward (gradient) kernels
def fft_c2c_backward(dy, axes, normalization, forward):
    """Backward counterpart of ``fft_c2c``.

    The ``NormMode`` is derived from the original ``forward`` flag, but the
    transform itself is run on ``dy`` in the opposite direction
    (``not forward``).

    Returns:
        The result of ``_fft_c2c_nd`` on ``dy``.
    """
    mode = _get_norm_mode(normalization, forward)
    return _fft_c2c_nd(dy, axes, mode, not forward)


def fft_r2c_backward(x, dy, axes, normalization, forward, onesided):
    """Backward counterpart of ``fft_r2c``.

    When ``onesided`` is truthy, ``dy`` is first padded at the end of the last
    transformed axis (``np.pad`` with its default mode) until that axis is as
    long as the same axis of ``x``. The (possibly padded) array is then passed
    to ``fft_c2c_backward`` and the real part of the result is returned.

    Args:
        x: the original input; only its shape is read.
        dy: incoming array to propagate back.
        axes: transformed axes; ``axes[-1]`` may be negative.
        normalization: normalization name forwarded to ``fft_c2c_backward``.
        forward: direction flag forwarded to ``fft_c2c_backward``.
        onesided: whether ``dy`` needs padding before the transform.

    Returns:
        ``.real`` of the ``fft_c2c_backward`` result.
    """
    grad = dy
    if onesided:
        pad_widths = [(0, 0)] * grad.ndim
        last_axis = axes[-1]
        if last_axis < 0:
            # Convert to a non-negative position for list indexing below.
            last_axis += grad.ndim
        missing = x.shape[last_axis] - grad.shape[last_axis]
        pad_widths[last_axis] = (0, missing)
        grad = np.pad(grad, pad_width=pad_widths)
    return fft_c2c_backward(grad, axes, normalization, forward).real
142 143


F
Feiyu Chan 已提交
144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160
def _fft_fill_conj_grad(x, axes, length_to_double):
    last_fft_axis = axes[-1]
    shape = x.shape
    for multi_index in np.ndindex(*shape):
        if 0 < multi_index[last_fft_axis] and multi_index[
                last_fft_axis] <= length_to_double:
            x[multi_index] *= 2
    return x


def fft_c2r_backward(x, dy, axes, normalization, forward, last_dim_size):
    """Backward counterpart of ``fft_c2r``.

    ``dy`` is run through the one-sided ``_fft_r2c_nd`` in the opposite
    direction (``not forward``), using the ``NormMode`` derived from the
    original ``forward`` flag. ``_fft_fill_conj_grad`` then doubles the
    entries at indices ``1 .. k`` along ``axes[-1]``, where ``k`` is the
    length of ``dy`` minus the length of ``x`` on that axis.

    Args:
        x: the original input; only its shape is read.
        dy: incoming array to propagate back.
        axes: transformed axes.
        normalization: normalization name.
        forward: direction flag of the original transform.
        last_dim_size: accepted for signature parity; not used in this
            function.

    Returns:
        The array returned by ``_fft_fill_conj_grad``.
    """
    mode = _get_norm_mode(normalization, forward)
    grad = _fft_r2c_nd(dy, axes, mode, not forward, True)
    fft_axis = axes[-1]
    band = dy.shape[fft_axis] - x.shape[fft_axis]
    return _fft_fill_conj_grad(grad, axes, band)