# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
import unittest

from op_test import OpTest


def overlap_add(x, hop_length, axis=-1):
    """NumPy reference implementation of overlap-add for the op tests.

    Frames stored in ``x`` are placed ``hop_length`` samples apart on the
    output sequence and summed wherever they overlap.

    Args:
        x: ``numpy.ndarray`` with at least 2 dimensions. With ``axis=-1`` the
            layout is ``(..., frame_length, n_frames)``; with ``axis=0`` it
            is ``(n_frames, frame_length, ...)``.
        hop_length: Offset, in samples, between the starts of consecutive
            frames. Must satisfy ``0 < hop_length <= frame_length``.
        axis: Axis that indexes the frames; either ``0`` or ``-1``.

    Returns:
        ``numpy.ndarray`` with the dtype of ``x`` in which the two frame
        dimensions are replaced by a single dimension of length
        ``(n_frames - 1) * hop_length + frame_length``. The shape is
        ``(..., seq_length)`` for ``axis=-1`` and ``(seq_length, ...)`` for
        ``axis=0``.

    Raises:
        AssertionError: If ``axis`` is not 0/-1, ``x`` has fewer than two
            dimensions, or ``hop_length`` is outside ``(0, frame_length]``.
    """
    assert axis in [0, -1], 'axis should be 0/-1.'
    assert len(x.shape) >= 2, 'Input dims should be >= 2.'

    squeeze_output = False
    if len(x.shape) == 2:
        # Add a singleton batch dimension on the side opposite the frame
        # axis so the rest of the function can assume a 3-D layout.
        squeeze_output = True
        dim = 0 if axis == -1 else -1
        x = np.expand_dims(x, dim)  # batch

    n_frames = x.shape[axis]
    frame_length = x.shape[1] if axis == 0 else x.shape[-2]

    # Assure no gaps between frames.
    assert 0 < hop_length <= frame_length, \
        f'hop_length should be in (0, frame_length({frame_length})], but got {hop_length}.'

    seq_length = (n_frames - 1) * hop_length + frame_length

    reshape_output = False
    if len(x.shape) > 3:
        # Collapse all batch dimensions into one; the original batch shape is
        # restored from ``target_shape`` once the frames have been summed.
        # ``np.prod`` is used because the ``np.product`` alias was removed in
        # NumPy 2.0.
        reshape_output = True
        if axis == 0:
            target_shape = [seq_length] + list(x.shape[2:])
            x = x.reshape(n_frames, frame_length, int(np.prod(x.shape[2:])))
        else:
            target_shape = list(x.shape[:-2]) + [seq_length]
            x = x.reshape(int(np.prod(x.shape[:-2])), frame_length, n_frames)

    if axis == 0:
        # Bring the data into the (batch, frame_length, n_frames) layout.
        x = x.transpose((2, 1, 0))

    # ``x`` is 3-D at this point, so its first dimension is the batch size.
    y = np.zeros(shape=[x.shape[0], seq_length], dtype=x.dtype)
    for frame in range(n_frames):
        start = frame * hop_length
        # Accumulate this frame for every batch row at once.
        y[:, start:start + frame_length] += x[:, :, frame]

    if axis == 0:
        y = y.transpose((1, 0))

    if reshape_output:
        y = y.reshape(target_shape)

    if squeeze_output:
        y = y.squeeze(-1) if axis == 0 else y.squeeze(0)

    return y


class TestOverlapAddOp(OpTest):
    """Base test case for the ``overlap_add`` operator.

    The expected output is produced by the NumPy reference ``overlap_add``
    defined in this module. Subclasses override ``initTestCase`` to vary the
    input shape, dtype and operator attributes.
    """

    def setUp(self):
        """Build random input and the reference output for this case."""
        self.op_type = "overlap_add"
        self.shape, self.type, self.attrs = self.initTestCase()
        self.inputs = {
            'X': np.random.random(size=self.shape).astype(self.type),
        }
        self.outputs = {'Out': overlap_add(x=self.inputs['X'], **self.attrs)}

    def initTestCase(self):
        """Return the ``(input_shape, input_type, attrs)`` triple.

        The default case is a 2-D input of shape (50, 3) with ``axis=-1``.
        """
        input_shape = (50, 3)
        input_type = 'float64'
        attrs = {
            'hop_length': 4,
            'axis': -1,
        }
        return input_shape, input_type, attrs

    def test_check_output(self):
        """Run the output check in static mode."""
        paddle.enable_static()
        try:
            self.check_output()
        finally:
            # Always restore dynamic mode so a failing check does not leave
            # static mode enabled for tests that run afterwards.
            paddle.disable_static()

    def test_check_grad_normal(self):
        """Run the gradient check of ``Out`` w.r.t. ``X`` in static mode."""
        paddle.enable_static()
        try:
            self.check_grad(['X'], 'Out')
        finally:
            # Always restore dynamic mode, even when the check fails.
            paddle.disable_static()


class TestCase1(TestOverlapAddOp):
    """2-D input of shape (3, 50) with ``axis=0`` and ``hop_length=4``."""

    def initTestCase(self):
        """Return the ``(input_shape, input_type, attrs)`` triple."""
        op_attrs = {'hop_length': 4, 'axis': 0}
        return (3, 50), 'float64', op_attrs


class TestCase2(TestOverlapAddOp):
    """3-D input of shape (2, 40, 5) with ``axis=-1`` and ``hop_length=10``."""

    def initTestCase(self):
        """Return the ``(input_shape, input_type, attrs)`` triple."""
        op_attrs = {'hop_length': 10, 'axis': -1}
        return (2, 40, 5), 'float64', op_attrs


class TestCase3(TestOverlapAddOp):
    """3-D input of shape (5, 40, 2) with ``axis=0`` and ``hop_length=10``."""

    def initTestCase(self):
        """Return the ``(input_shape, input_type, attrs)`` triple."""
        op_attrs = {'hop_length': 10, 'axis': 0}
        return (5, 40, 2), 'float64', op_attrs


class TestCase4(TestOverlapAddOp):
    """4-D input of shape (3, 5, 12, 8) with ``axis=-1``, ``hop_length=5``."""

    def initTestCase(self):
        """Return the ``(input_shape, input_type, attrs)`` triple."""
        op_attrs = {'hop_length': 5, 'axis': -1}
        return (3, 5, 12, 8), 'float64', op_attrs


class TestCase5(TestOverlapAddOp):
    """4-D input of shape (8, 12, 5, 3) with ``axis=0``, ``hop_length=5``."""

    def initTestCase(self):
        """Return the ``(input_shape, input_type, attrs)`` triple."""
        op_attrs = {'hop_length': 5, 'axis': 0}
        return (8, 12, 5, 3), 'float64', op_attrs


# Hand control to the unittest runner when this file is executed directly.
if __name__ == '__main__':
    unittest.main()