Unverified commit 11518a43. Author: Feiyu Chan. Committer: GitHub

Add FFT related operators and APIs (#35665)

* 1. add interface for fft;
2. add data type predicate;
3. fix paddle.roll.

* add fft c2c cufft kernel

* implement argument checking & op calling parts for fft_c2c and fftn_c2c

* add operator and opmaker definitions

* only register float and double for cpu.

* add common code for implementing FFT, add pocketfft as a dependency

* add fft c2c cufft kernel function

* fix bugs in python interface

* add support for c2r, r2c operators, op makers, kernels and kernel functors.

* test and fix bugs

* 1. fft_c2c function: add support for onesided=False;
2. add complex<float>, complex<double> support for concat and flip.

* 1. fft: fix python api bugs;
2. shape_op: add support for complex data types.

* fft c2c cufft kernel done with compile and link

* fix shape_op, add mkl placeholder

* remove mkl

* complete fft c2c in gpu

* 1. implement mkl-based fft, FFTC2CFunctor and common function exec_fft;
2. change the design, add input and output typename as template parameter for all FFTFunctors, update pocketfft-based implementation.

* complete fft c2c on gpu in ND

* complete fft c2c on gpu in ND

* complete fft c2c backward in ND

* fix MKL-based implementation

* Add frame op and CPU/GPU kernels.

* Add frame op forward unittest.

* Add frame op forward unittest.

* Remove axis parameter in FrameFunctor.

* Add frame op grad CPU/GPU kernels and unittest.

* Add frame op grad CPU/GPU kernels and unittest.

* Update doc string.

* Update after review and remove librosa requirement in unittest.

* Update grad kernel.

* add fft_c2r op

* Remove data allocation in TransCompute function.

* add fft r2c onesided with cpu(pocketfft/mkl) and gpu

* last fft c2r functor

* fix C2R and R2C for cufft, because the direction is not an option in these cases.

* add fft r2c onesided with cpu(pocketfft/mkl) and gpu

* fix bugs in python APIs

* fix fft_c2r grad kernel

* fix bugs in python APIs

* add cuda fft c2r grad kernel functor

* clean code

* fix fft_c2r python API

* fill fft r2c result with conjugate symmetry (#19)

fill fft r2c result with conjugate symmetry

* add placeholder for unittests (#24)

* simply parameterize test functions by auto-generating test cases from a param list (#25)

* miscellaneous fixes for python APIs (#26)

* add placeholder for unittests

* resize fft inputs before computation if n or s is provided.

* add complex kernels for pad and pad_grad

* simplify argument checking.

* add type promotion

* add int to float or complex promotion

* fix output data type for static mode

* fix fft's input dtype dispatch, import fft to paddle

* fix typos in axes checking (#27)

* fix typos in axes checking

* fix argument checking (#28)

* fix argument checking

* Add C2R Python layer normal and abnormal use cases (#29)

* documents and single case

* test c2r case

* New C2R Python layer normal and exception use cases

* complete rfft,rfft2,rfftn,ihfft,ihfft2,ihfftn unittest and doc string (#30)

* Documentation of the common interfaces of c2r and c2c (#31)

* Documentation of the common interfaces of c2r and c2c

* clean c++ code  (#32)

* clean code

* Add numpy-based implementation of spectral ops (#33)

* add numpy reference implementation of spectral ops

* Add fft_c2r numpy based implementation for unittest. (#34)

* add fft_c2r numpy implementation

* Add deframe op and stft/istft api. (#23)

* Add frame api

* Add deframe op and kernels.

* Add stft and istft apis.

* Add deframe api. Update stft and istft apis.

* Fix bug in frame_from_librosa function when input dims >= 3

* Rename deframe to overlap_add.

* Update istft.

* Update after code review.

* Add overlap_add op and stft/istft api unittest (#35)

* Add overlap_add op unittest.

* Register complex kernels of squeeze/unsqueeze op.

* Add stft/istft api unittest.

* Add unittest for fft helper functions (#36)

* add unittests for fft helper functions. add complex kernel for roll op.

* complete static graph unittest for all public api (#37)

* Unittest of op with FFT C2C, C2R and r2c added (#38)

* documents and single case

* test c2r case

* New C2R Python layer normal and exception use cases

* Documentation of the common interfaces of c2r and c2c

* Unittest of op with FFT C2C, C2R and r2c added
Co-authored-by: lijiaqi <lijiaqi0612@163.com>

* add fft related options to CMakeLists.txt

* fix typos and clean code (#39)

* fix invisible character in mkl branch and fix error in error message

* clean code: remove docstring from unittest for signal.py.

* always convert numpy array to paddle.Tensor to avoid comparing numpy dtype with paddle dtype. (#40)

* always convert numpy array to paddle.Tensor to avoid comparing numpy dtype with paddle dtype.

* fix CI Errors: numpy dtype comparison, thrust when cuda is not available (#41)

1. always convert numpy array to paddle.Tensor to avoid comparing numpy dtype with paddle dtype.
2. promote floating point tensor to complex tensor for fft_c2c and fft_c2r;
3. fix unittest to catch UnImplementedError and RuntimeError;
4. fix compile error by avoid using thrust when cuda is not available.
5.  fix sample code, use paddle.fft instead of paddle.tensor.fft

* remove inclusion of thrust, add __all__ list for fft (#42)

* Add api doc and update unittest. (#43)

* Add doc strings.
* Update overlap_add op unittest

* fix MKL-based FFT implementation (#44)

* fix MKL-based FFT implementation, MKL CDFT's FORWARD DOMAIN is always REAL for R2C and C2R

* remove code for debug (#45)

* use dynload for cufft (#46)

* use std::ptrdiff_t as datatype of stride (instead of int64_t) to avoid argument mismatch on some platforms.

* add complex support for fill_zeros_like

* use dynload for cufft

* Update doc and unittest. (#47)

* Add doc of frame op and overlap_add op.

* Update unittest.

* use dynload for cufft (#48)

1. use dynload for cufft
2. fix unittest;
3. temporarily disable Rocm.

* fix conflicts and merge upstream (#49)

fix conflicts and merge upstream

* fix compile error: only link dynload_cuda when cuda is available (#50)

* fix compile error: only link dynload_cuda when cuda is available

* fix dynload for cufft on windows (#51)

1. fix dynload for cufft on windows;
2. fix unittests.

* add NOMINMAX to compile on windows (#52)

 add NOMINMAX to compile on windows

* explicitly specify capture mode for lambdas (#55)

 explicitly specify capture mode for lambdas

* fix fft sample (#53)

* fix fft sample

* update scipy and numpy version for unittests of fft (#56)

update scipy and numpy version for unittests of fft

* Add static graph unittests of frame and overlap_add api. (#57)

* Remove cache of cuFFT & Disable ONEMKL (#59)

1. replace numpy.fft with scipy.fft, as numpy<1.20 does not support ortho norm;
2. remove cache of cufft plans;
3. enhance error checking.
4. default WITH_ONEMKL to OFF
Co-authored-by: jeff41404 <jeff41404@gmail.com>
Co-authored-by: root <root@bjyz-sys-gpu-kongming9.bjyz.baidu.com>
Co-authored-by: KP <109694228@qq.com>
Co-authored-by: lijiaqi <lijiaqi0612@163.com>
Co-authored-by: Xiaoxu Chen <chenxx_id@163.com>
Co-authored-by: lijiaqi0612 <33169170+lijiaqi0612@users.noreply.github.com>
Parent 01063218
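The commit message mentions filling the R2C result with conjugate symmetry when a two-sided output is requested. The following is a minimal pure-Python sketch of that idea, not the kernel code from this commit: the helper names `naive_dft` and `fill_conj_symmetry` are hypothetical, and the sketch only covers the 1-D case. For a real signal of length n, bin k of the full spectrum equals the complex conjugate of bin n - k, so the bins beyond n // 2 can be rebuilt from the one-sided half.

```python
import cmath


def naive_dft(x):
    """Reference O(n^2) forward DFT, used only to produce test data."""
    n = len(x)
    return [
        sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
        for k in range(n)
    ]


def fill_conj_symmetry(onesided, n):
    """Rebuild the full length-n spectrum of a real signal.

    `onesided` holds bins 0 .. n // 2 (n // 2 + 1 values). The remaining
    bins follow from X[k] = conj(X[n - k]).
    """
    if len(onesided) != n // 2 + 1:
        raise ValueError("expected n // 2 + 1 one-sided bins")
    full = list(onesided)
    for k in range(n // 2 + 1, n):
        full.append(complex(onesided[n - k]).conjugate())
    return full


signal = [1.0, 2.0, 0.5, -1.0, 3.0, 0.0, -2.0]
spectrum = naive_dft(signal)
half = spectrum[: len(signal) // 2 + 1]
rebuilt = fill_conj_symmetry(half, len(signal))
```

Here `rebuilt` should match `spectrum` up to floating-point error, for both odd and even lengths.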
@@ -38,6 +38,8 @@ project(paddle CXX C)
 # enable language CUDA
 # TODO(Shibo Tao): remove find_package(CUDA) completely.
 find_package(CUDA QUIET)
+find_package(MKL CONFIG QUIET)
+option(WITH_ONEMKL "Compile PaddlePaddle with oneMKL" OFF)
 option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND})
 option(WITH_TENSORRT "Compile PaddlePaddle with NVIDIA TensorRT" OFF)
 option(WITH_XPU "Compile PaddlePaddle with BAIDU KUNLUN XPU" OFF)
@@ -225,6 +227,7 @@ option(WITH_STRIP "Strip so files of Whl packages" OFF)
 option(NEW_RELEASE_CUBIN "PaddlePaddle next-level release strategy for pypi cubin package" OFF)
 option(NEW_RELEASE_JIT "PaddlePaddle next-level release strategy for backup jit package" OFF)
 option(WITH_ASCEND_INT64 "Compile with int64 kernel for ascend NPU" OFF)
+option(WITH_POCKETFFT "Compile with pocketfft support" ON)
 # PY_VERSION
 if(NOT PY_VERSION)
@@ -373,6 +376,10 @@ if (WITH_MIPS)
   add_definitions(-DPADDLE_WITH_MIPS)
 endif()
+if (WITH_ONEMKL)
+  add_definitions(-DPADDLE_WITH_ONEMKL)
+endif()
 if (WITH_HETERPS)
   if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -faligned-new")
......
@@ -20,7 +20,7 @@
 find_library(GPERFTOOLS_TCMALLOC
   NAMES tcmalloc
   HINTS ${Gperftools_ROOT_DIR}/lib)
 find_library(GPERFTOOLS_PROFILER
   NAMES profiler
   HINTS ${Gperftools_ROOT_DIR}/lib)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
include(ExternalProject)
set(POCKETFFT_PATH "${THIRD_PARTY_PATH}/pocketfft" CACHE STRING "A path setting for external_pocketfft path.")
set(POCKETFFT_PREFIX_DIR ${POCKETFFT_PATH})
set(POCKETFFT_REPOSITORY https://gitlab.mpcdf.mpg.de/mtr/pocketfft.git)
set(POCKETFFT_TAG release_for_eigen)
SET(POCKETFFT_INCLUDE_DIR ${POCKETFFT_PREFIX_DIR}/src)
message("POCKETFFT_INCLUDE_DIR is ${POCKETFFT_INCLUDE_DIR}")
include_directories(${POCKETFFT_INCLUDE_DIR})
ExternalProject_Add(
  extern_pocketfft
  ${EXTERNAL_PROJECT_LOG_ARGS}
  ${SHALLOW_CLONE}
  GIT_REPOSITORY    ${POCKETFFT_REPOSITORY}
  GIT_TAG           ${POCKETFFT_TAG}
  PREFIX            ${POCKETFFT_PREFIX_DIR}
  UPDATE_COMMAND    ""
  CONFIGURE_COMMAND ""
  BUILD_COMMAND     ""
  INSTALL_COMMAND   ""
  TEST_COMMAND      ""
)
add_library(pocketfft INTERFACE)
add_dependencies(pocketfft extern_pocketfft)
@@ -361,4 +361,10 @@ if (WITH_CRYPTO)
   add_definitions(-DPADDLE_WITH_CRYPTO)
 endif (WITH_CRYPTO)
+if (WITH_POCKETFFT)
+  include(external/pocketfft)
+  list(APPEND third_party_deps extern_pocketfft)
+  add_definitions(-DPADDLE_WITH_POCKETFFT)
+endif (WITH_POCKETFFT)
 add_custom_target(third_party ALL DEPENDS ${third_party_deps})
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once
+#include <iostream>
 #include <string>
 #include <typeindex>
@@ -170,11 +171,26 @@ extern inline proto::VarType::Type ToComplexType(proto::VarType::Type t) {
       return proto::VarType::COMPLEX128;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
-          "Unknown complex value data type (%s), now only support float32 and "
+          "Unknown real value data type (%s), now only support float32 and "
           "float64.",
           DataTypeToString(t)));
   }
 }
+extern inline proto::VarType::Type ToRealType(proto::VarType::Type t) {
+  switch (t) {
+    case proto::VarType::COMPLEX64:
+      return proto::VarType::FP32;
+    case proto::VarType::COMPLEX128:
+      return proto::VarType::FP64;
+    default:
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Unknown complex value data type (%s), now only support complex64 "
+          "and "
+          "complex128.",
+          DataTypeToString(t)));
+  }
+}
 } // namespace framework
 } // namespace paddle
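The hunk above adds `ToRealType` as the inverse of the existing `ToComplexType`. As a rough illustration of that pairing, here is a small Python sketch. The dtype strings and the function names `to_complex_type` and `to_real_type` are stand-ins chosen for this example, not Paddle APIs, and the float32 to complex64 case is assumed from the part of `ToComplexType` that the hunk does not show.

```python
# Hypothetical string dtypes standing in for proto::VarType::Type values.
_REAL_TO_COMPLEX = {"float32": "complex64", "float64": "complex128"}
_COMPLEX_TO_REAL = {c: r for r, c in _REAL_TO_COMPLEX.items()}


def to_complex_type(dtype):
    """Map a real dtype to the complex dtype of matching precision."""
    if dtype not in _REAL_TO_COMPLEX:
        raise NotImplementedError(
            "Unknown real value data type (%s), only float32 and float64 "
            "are supported." % dtype
        )
    return _REAL_TO_COMPLEX[dtype]


def to_real_type(dtype):
    """Map a complex dtype to the real dtype of matching precision."""
    if dtype not in _COMPLEX_TO_REAL:
        raise NotImplementedError(
            "Unknown complex value data type (%s), only complex64 and "
            "complex128 are supported." % dtype
        )
    return _COMPLEX_TO_REAL[dtype]
```

A pair like this is what lets an R2C operator derive its output dtype from a real input, and a C2R operator do the reverse.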
@@ -59,6 +59,10 @@ if (WITH_GPU)
     endif()
 endif()
+if (WITH_POCKETFFT)
+    SET(OP_HEADER_DEPS ${OP_HEADER_DEPS} pocketfft)
+endif()
 SET(OP_MKL_DEPS "")
 if (NOT WITH_MKL OR NOT WITH_AVX)
@@ -75,7 +79,7 @@ if(WITH_UNITY_BUILD)
 endif()
 register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op lstm_op run_program_op eye_op recurrent_op
-    sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
+    sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
 op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS})
@@ -94,6 +98,12 @@ else()
     op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
 endif()
+if (WITH_GPU AND (NOT WITH_ROCM))
+    op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda ${OP_HEADER_DEPS})
+else()
+    op_library(spectral_op SRCS spectral_op.cc DEPS ${OP_HEADER_DEPS})
+endif()
 op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute)
 op_library(eye_op DEPS ${OP_HEADER_DEPS})
 op_library(recurrent_op DEPS ${OP_HEADER_DEPS})
......
@@ -14,6 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/concat_op.h"
+#include <paddle/fluid/platform/complex.h>
 #include <memory>
 #include <string>
 #include <vector>
@@ -237,7 +238,11 @@ REGISTER_OP_CPU_KERNEL(
     ops::ConcatKernel<paddle::platform::CPUDeviceContext,
                       paddle::platform::float16>,
     ops::ConcatKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::ConcatKernel<paddle::platform::CPUDeviceContext, uint8_t>);
+    ops::ConcatKernel<paddle::platform::CPUDeviceContext, uint8_t>,
+    ops::ConcatKernel<paddle::platform::CPUDeviceContext,
+                      paddle::platform::complex<float>>,
+    ops::ConcatKernel<paddle::platform::CPUDeviceContext,
+                      paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     concat_grad,
     ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, double>,
@@ -247,4 +252,8 @@ REGISTER_OP_CPU_KERNEL(
     ops::ConcatGradKernel<paddle::platform::CPUDeviceContext,
                           paddle::platform::float16>,
     ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, uint8_t>);
+    ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
+    ops::ConcatGradKernel<paddle::platform::CPUDeviceContext,
+                          paddle::platform::complex<float>>,
+    ops::ConcatGradKernel<paddle::platform::CPUDeviceContext,
+                          paddle::platform::complex<double>>);
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/concat_op.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
 namespace ops = paddle::operators;
@@ -24,7 +25,11 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ConcatKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::ConcatKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::ConcatKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::ConcatKernel<paddle::platform::CUDADeviceContext, uint8_t>);
+    ops::ConcatKernel<paddle::platform::CUDADeviceContext, uint8_t>,
+    ops::ConcatKernel<paddle::platform::CUDADeviceContext,
+                      plat::complex<float>>,
+    ops::ConcatKernel<paddle::platform::CUDADeviceContext,
+                      plat::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     concat_grad,
     ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, double>,
@@ -33,4 +38,8 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, uint8_t>);
+    ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
+    ops::ConcatGradKernel<paddle::platform::CUDADeviceContext,
+                          plat::complex<float>>,
+    ops::ConcatGradKernel<paddle::platform::CUDADeviceContext,
+                          plat::complex<double>>);
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/eigen/eigen_function.h"
 #include "paddle/fluid/platform/bfloat16.h"
+#include "paddle/fluid/platform/complex.h"
 namespace paddle {
 namespace operators {
@@ -42,6 +43,8 @@ template struct EigenScale<Eigen::DefaultDevice, int8_t>;
 template struct EigenScale<Eigen::DefaultDevice, int16_t>;
 template struct EigenScale<Eigen::DefaultDevice, int>;
 template struct EigenScale<Eigen::DefaultDevice, int64_t>;
+template struct EigenScale<Eigen::DefaultDevice, platform::complex<float>>;
+template struct EigenScale<Eigen::DefaultDevice, platform::complex<double>>;
 } // namespace operators
 } // namespace paddle
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/eigen/eigen_function.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
 namespace paddle {
@@ -41,6 +42,8 @@ template struct EigenScale<Eigen::GpuDevice, int16_t>;
 template struct EigenScale<Eigen::GpuDevice, int>;
 template struct EigenScale<Eigen::GpuDevice, int64_t>;
 template struct EigenScale<Eigen::GpuDevice, platform::float16>;
+template struct EigenScale<Eigen::GpuDevice, platform::complex<float>>;
+template struct EigenScale<Eigen::GpuDevice, platform::complex<double>>;
 } // namespace operators
 } // namespace paddle
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/fill_zeros_like_op.h"
+#include "paddle/fluid/platform/complex.h"
 namespace paddle {
 namespace operators {
@@ -93,7 +94,11 @@ REGISTER_OP_CPU_KERNEL(
     ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, float>,
     ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>);
+    ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>,
+    ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext,
+                             paddle::platform::complex<float>>,
+    ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext,
+                             paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     fill_zeros_like2,
@@ -101,4 +106,8 @@ REGISTER_OP_CPU_KERNEL(
     ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, float>,
     ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>);
+    ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>,
+    ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext,
+                             paddle::platform::complex<float>>,
+    ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext,
+                             paddle::platform::complex<double>>);
@@ -14,6 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/fill_zeros_like_op.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
 namespace ops = paddle::operators;
@@ -25,7 +26,11 @@ REGISTER_OP_CUDA_KERNEL(
     ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>,
     ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
                              paddle::platform::float16>,
-    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>);
+    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>,
+    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
+                             paddle::platform::complex<float>>,
+    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
+                             paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     fill_zeros_like2,
@@ -35,4 +40,8 @@ REGISTER_OP_CUDA_KERNEL(
     ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>,
     ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
                              paddle::platform::float16>,
-    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>);
+    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>,
+    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
+                             paddle::platform::complex<float>>,
+    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
+                             paddle::platform::complex<double>>);
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <unordered_map>
 #include <vector>
 #include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/complex.h"
 namespace paddle {
 namespace operators {
@@ -145,6 +146,7 @@ class FlipOpGradMaker : public framework::SingleGradOpMaker<T> {
 } // namespace paddle
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 REGISTER_OPERATOR(flip, ops::FlipOp, ops::FlipOpMaker, ops::FlipOpInferVarType,
                   ops::FlipOpGradMaker<paddle::framework::OpDesc>,
                   ops::FlipOpGradMaker<paddle::imperative::OpBase>);
@@ -153,7 +155,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::FlipKernel<paddle::platform::CPUDeviceContext, double>,
     ops::FlipKernel<paddle::platform::CPUDeviceContext, int32_t>,
     ops::FlipKernel<paddle::platform::CPUDeviceContext, int64_t>,
-    ops::FlipKernel<paddle::platform::CPUDeviceContext, bool>);
+    ops::FlipKernel<paddle::platform::CPUDeviceContext, bool>,
+    ops::FlipKernel<paddle::platform::CPUDeviceContext, plat::complex<float>>,
+    ops::FlipKernel<paddle::platform::CPUDeviceContext, plat::complex<double>>);
 /* ========================== register checkpoint ===========================*/
 REGISTER_OP_VERSION(flip)
......
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/memory/malloc.h"
+#include "paddle/fluid/platform/complex.h"
 namespace paddle {
 namespace operators {
@@ -163,4 +164,7 @@ REGISTER_OP_CUDA_KERNEL(
     ops::FlipKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::FlipKernel<paddle::platform::CUDADeviceContext, int>,
     ops::FlipKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::FlipKernel<paddle::platform::CUDADeviceContext, bool>);
+    ops::FlipKernel<paddle::platform::CUDADeviceContext, bool>,
+    ops::FlipKernel<paddle::platform::CUDADeviceContext, plat::complex<float>>,
+    ops::FlipKernel<paddle::platform::CUDADeviceContext,
+                    plat::complex<double>>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/frame_op.h"

namespace paddle {
namespace operators {

class FrameOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "frame");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "frame");

    const int frame_length = ctx->Attrs().Get<int>("frame_length");
    const int hop_length = ctx->Attrs().Get<int>("hop_length");
    const int axis = ctx->Attrs().Get<int>("axis");

    const auto x_dims = ctx->GetInputDim("X");
    const int x_rank = x_dims.size();

    PADDLE_ENFORCE_GE(
        x_rank, 1, platform::errors::InvalidArgument(
                       "Input(X) of FrameOp should be a tensor which contains "
                       "at least 1 dimension, but got rank %s.",
                       x_rank));
    PADDLE_ENFORCE_GT(hop_length, 0,
                      platform::errors::InvalidArgument(
                          "Attribute(hop_length) of FrameOp should be greater "
                          "than 0, but got %s.",
                          hop_length));
    PADDLE_ENFORCE_EQ(
        (axis == 0 || axis == -1), true,
        platform::errors::InvalidArgument(
            "Attribute(axis) of FrameOp should 0 or -1, but got %s.", axis));

    std::vector<int64_t> output_shape;
    int seq_length;
    int n_frames;

    int start_axis;
    int end_axis;

    if (axis == 0) {
      seq_length = x_dims[0];
      start_axis = 1;
      end_axis = x_rank - 1;
    } else {
      seq_length = x_dims[x_rank - 1];
      start_axis = 0;
      end_axis = x_rank - 2;
    }

    PADDLE_ENFORCE_LE(frame_length, seq_length,
                      platform::errors::InvalidArgument(
                          "Attribute(frame_length) of FrameOp should be less "
                          "equal than sequence length, but got (%s) > (%s).",
                          frame_length, seq_length));

    // It won't go into for loop when x_rank == 1U.
    for (int i = start_axis; i <= end_axis; i++) {
      output_shape.push_back(x_dims[i]);
    }

    n_frames = 1 + (seq_length - frame_length) / hop_length;

    if (axis == 0) {
      // (n_frames, frame_length, ...)
      output_shape.insert(output_shape.begin(), frame_length);
      output_shape.insert(output_shape.begin(), n_frames);
    } else {
      // (..., frame_length, n_frames)
      output_shape.push_back(frame_length);
      output_shape.push_back(n_frames);
    }

    ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(in_dtype, ctx.GetPlace());
  }
};
class FrameOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of frame op.");
AddOutput("Out", "(Tensor), The output tensor of frame op.");
AddAttr<int>(
"frame_length",
"Length of each frame; must satisfy `0 < frame_length <= x.shape[axis]`.");
AddAttr<int>("hop_length",
"Number of steps to advance between adjacent frames; must "
"satisfy `0 < hop_length`.");
AddAttr<int>("axis",
"Axis of the input tensor to operate on. Its value "
"should be 0 (the first dimension) or -1 (the last dimension).")
.SetDefault(-1);
AddComment(R"DOC(
Slice the N-dimensional (where N >= 1) input into (overlapping) frames.
)DOC");
}
};
class FrameOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "frame_grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "frame_grad");
const auto x_dims = ctx->GetInputDim("X");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(in_dtype, ctx.GetPlace());
}
};
template <typename T>
class FrameOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("frame_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(frame, ops::FrameOp, ops::FrameOpMaker,
ops::FrameOpGradMaker<paddle::framework::OpDesc>,
ops::FrameOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(frame_grad, ops::FrameOpGrad);
REGISTER_OP_CPU_KERNEL(
frame, ops::FrameKernel<paddle::platform::CPUDeviceContext, int>,
ops::FrameKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FrameKernel<paddle::platform::CPUDeviceContext, float>,
ops::FrameKernel<paddle::platform::CPUDeviceContext, double>,
ops::FrameKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::FrameKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
frame_grad, ops::FrameGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::FrameGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FrameGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::FrameGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::FrameGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::FrameGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
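/*
A minimal, standalone sketch of the shape arithmetic performed by
FrameOp::InferShape above. `FrameOutputShape` is a hypothetical helper
introduced for illustration only: it is not part of Paddle, it uses plain
std::vector instead of framework::DDim, and it additionally asserts
`frame_length > 0`, which the attribute description requires.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not a Paddle API): returns the output shape of the
// frame op for axis == 0 or axis == -1.
std::vector<int64_t> FrameOutputShape(const std::vector<int64_t>& x_dims,
                                      int frame_length, int hop_length,
                                      int axis) {
  assert(!x_dims.empty());
  assert(axis == 0 || axis == -1);
  assert(hop_length > 0);
  const int64_t seq_length = (axis == 0) ? x_dims.front() : x_dims.back();
  assert(frame_length > 0 && frame_length <= seq_length);
  // Same formula as in InferShape: integer division drops a trailing
  // partial frame.
  const int64_t n_frames = 1 + (seq_length - frame_length) / hop_length;
  std::vector<int64_t> out;
  if (axis == 0) {
    // (n_frames, frame_length, ...)
    out.push_back(n_frames);
    out.push_back(frame_length);
    out.insert(out.end(), x_dims.begin() + 1, x_dims.end());
  } else {
    // (..., frame_length, n_frames)
    out.assign(x_dims.begin(), x_dims.end() - 1);
    out.push_back(frame_length);
    out.push_back(n_frames);
  }
  return out;
}
```

With these assumptions, an input of shape (8, 5) framed along axis 0 with
frame_length 4 and hop_length 2 gives n_frames = 1 + (8 - 4) / 2 = 3 and an
output shape of (3, 4, 5).
*/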
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/frame_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
frame, ops::FrameKernel<paddle::platform::CUDADeviceContext, int>,
ops::FrameKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FrameKernel<paddle::platform::CUDADeviceContext, float>,
ops::FrameKernel<paddle::platform::CUDADeviceContext, double>,
ops::FrameKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FrameKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::FrameKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
frame_grad, ops::FrameGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::FrameGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FrameGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::FrameGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::FrameGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FrameGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::FrameGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/seq2col.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
struct FrameFunctor {
void operator()(const DeviceContext& dev_ctx, const Tensor* input,
Tensor* output, size_t seq_length, size_t frame_length,
size_t n_frames, size_t hop_length,
bool is_grad = false) const {
auto numel = output->numel();
const auto* input_data = input->data<T>();
auto* output_data = output->data<T>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
if (!is_grad) {
math::Seq2ColFunctor<T> functor(input_data, output_data, seq_length,
frame_length, n_frames, hop_length);
for_range(functor);
} else {
math::Col2SeqFunctor<T> functor(input_data, output_data, seq_length,
frame_length, n_frames, hop_length);
for_range(functor);
}
}
};
template <typename DeviceContext, typename T>
class FrameKernel : public framework::OpKernel<T> {
public:
/*
Frame kernel slices frames from input sequences. The main steps are as
follows:
- Case 1 - input dims == 1:
- axis is -1: Call a FrameFunctor to compute directly.
- axis is 0: Transpose the output first so that it falls into the
axis == -1 case. Finally, transpose the output back to restore
its dims.
- Case 2 - input dims == 2:
- axis is -1: Call a FrameFunctor to compute directly.
- axis is 0: Transpose both input and output first so that it falls
into the axis == -1 case. Finally, transpose the output back to
restore its dims.
- Case 3 - input dims > 2:
Flatten the input and output to 2D and 3D respectively so that it
falls into Case 2. Finally, restore the dims of the output tensor.
*/
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* x = ctx.Input<Tensor>("X");
Tensor* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
const size_t x_rank = x->dims().size();
const size_t out_rank = out->dims().size();
const int frame_length = ctx.Attr<int>("frame_length");
const int hop_length = ctx.Attr<int>("hop_length");
const int axis = ctx.Attr<int>("axis");
const int n_frames =
(axis == 0) ? out->dims()[0] : out->dims()[out_rank - 1];
const int seq_length = (axis == 0) ? x->dims()[0] : x->dims()[x_rank - 1];
auto& dev_ctx = ctx.device_context<DeviceContext>();
// When the number of input dims is larger than 2, x is copied to x_ so
// that the input can be resized to 2D and the output to 3D. Moreover, the
// output dims are restored in the last step.
Tensor x_(x->type());
x_ = *x;
framework::DDim preserved_dims;
if (x_rank > 2) {
// Save dims used to flatten both input and output tensors and restore
// output tensor.
framework::DDim x_resized_dims;
framework::DDim out_resized_dims;
if (axis == 0) {
preserved_dims = framework::slice_ddim(x_.dims(), 1, x_rank);
x_resized_dims = {seq_length, framework::product(preserved_dims)};
out_resized_dims = {n_frames, frame_length,
framework::product(preserved_dims)};
} else {
preserved_dims = framework::slice_ddim(x_.dims(), 0, x_rank - 1);
x_resized_dims = {framework::product(preserved_dims), seq_length};
out_resized_dims = {framework::product(preserved_dims), frame_length,
n_frames};
}
x_.Resize(x_resized_dims);
out->Resize(out_resized_dims);
}
Tensor trans_x(x_.type());
Tensor trans_out(out->type());
// Transpose input and output in case that axis is 0.
if (axis == 0) {
if (x_rank == 1U) {
trans_x = x_;
std::vector<int> perm_out{1, 0};
auto out_dims_vec = framework::vectorize(out->dims());
for (int i = 0; i < out->dims().size(); ++i) {
out_dims_vec[i] = out->dims()[perm_out[i]];
}
trans_out.Resize(framework::make_ddim(out_dims_vec));
trans_out.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_out.size(), dev_ctx, *out,
&trans_out, perm_out);
} else {
std::vector<int> perm_x{1, 0};
auto x_dims_vec = framework::vectorize(x_.dims());
for (int i = 0; i < x_.dims().size(); ++i) {
x_dims_vec[i] = x_.dims()[perm_x[i]];
}
trans_x.Resize(framework::make_ddim(x_dims_vec));
trans_x.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_x.size(), dev_ctx, x_, &trans_x,
perm_x);
std::vector<int> perm_out{2, 1, 0};
auto out_dims_vec = framework::vectorize(out->dims());
for (int i = 0; i < out->dims().size(); ++i) {
out_dims_vec[i] = out->dims()[perm_out[i]];
}
trans_out.Resize(framework::make_ddim(out_dims_vec));
trans_out.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_out.size(), dev_ctx, *out,
&trans_out, perm_out);
}
} else {
trans_x = x_;
trans_out = *out;
}
FrameFunctor<DeviceContext, T>()(dev_ctx, &trans_x, &trans_out, seq_length,
frame_length, n_frames, hop_length,
/*is_grad*/ false);
// Transpose output in case axis is 0.
if (axis == 0) {
if (x_rank == 1U) {
std::vector<int> perm_out{1, 0};
TransCompute<DeviceContext, T>(perm_out.size(), dev_ctx, trans_out, out,
perm_out);
} else {
std::vector<int> perm_out{2, 1, 0};
TransCompute<DeviceContext, T>(perm_out.size(), dev_ctx, trans_out, out,
perm_out);
}
}
// Restore output dims when the number of dims is larger than 2.
if (x_rank > 2) {
std::vector<int64_t> restored_out_shape;
for (int i = 0; i < preserved_dims.size(); i++) {
restored_out_shape.push_back(preserved_dims[i]);
}
if (axis == 0) {
// (n_frames, frame_length, ...)
restored_out_shape.insert(restored_out_shape.begin(), frame_length);
restored_out_shape.insert(restored_out_shape.begin(), n_frames);
} else {
// (..., frame_length, n_frames)
restored_out_shape.push_back(frame_length);
restored_out_shape.push_back(n_frames);
}
out->Resize(framework::make_ddim(restored_out_shape));
}
}
};
template <typename DeviceContext, typename T>
class FrameGradKernel : public framework::OpKernel<T> {
public:
/*
Frame gradient kernel accumulates the gradient `d_x` from `d_out`. The
main steps are as follows:
- Case 1 - d_x dims == 1:
- axis is -1: Call a FrameFunctor to compute directly. Note that
`is_grad` is set to true to select the gradient functor.
- axis is 0: Transpose `d_out` first so that it falls into the
axis == -1 case.
- Case 2 - d_x dims == 2:
- axis is -1: Call a FrameFunctor to compute directly.
- axis is 0: Transpose both `d_x` and `d_out` first so that it
falls into the axis == -1 case. Finally, transpose `d_x` back to
restore its dims.
- Case 3 - d_x dims > 2:
Flatten `d_x` and `d_out` to 2D and 3D respectively so that it
falls into Case 2. Finally, restore the dims of the `d_x` tensor.
*/
void Compute(const framework::ExecutionContext& ctx) const {
const Tensor* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
Tensor* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
d_x->mutable_data<T>(ctx.GetPlace());
const size_t d_out_rank = d_out->dims().size();
const size_t d_x_rank = d_x->dims().size();
const int frame_length = ctx.Attr<int>("frame_length");
const int hop_length = ctx.Attr<int>("hop_length");
const int axis = ctx.Attr<int>("axis");
const int n_frames =
(axis == 0) ? d_out->dims()[0] : d_out->dims()[d_out_rank - 1];
const int seq_length =
(axis == 0) ? d_x->dims()[0] : d_x->dims()[d_x_rank - 1];
auto& dev_ctx = ctx.device_context<DeviceContext>();
Tensor d_out_(d_out->type());
d_out_ = *d_out;
framework::DDim preserved_dims;
if (d_x_rank > 2) {
// Save dims used to flatten both input and output tensors and restore
// output tensor.
framework::DDim d_x_resized_dims;
framework::DDim d_out_resized_dims;
if (axis == 0) {
preserved_dims = framework::slice_ddim(d_x->dims(), 1, d_x_rank);
d_x_resized_dims = {seq_length, framework::product(preserved_dims)};
d_out_resized_dims = {n_frames, frame_length,
framework::product(preserved_dims)};
} else {
preserved_dims = framework::slice_ddim(d_x->dims(), 0, d_x_rank - 1);
d_x_resized_dims = {framework::product(preserved_dims), seq_length};
d_out_resized_dims = {framework::product(preserved_dims), frame_length,
n_frames};
}
d_x->Resize(d_x_resized_dims);
d_out_.Resize(d_out_resized_dims);
}
Tensor trans_d_x(d_x->type());
Tensor trans_d_out(d_out_.type());
// Transpose input and output in case that axis is 0.
if (axis == 0) {
if (d_x_rank == 1U) {
trans_d_x = *d_x;
std::vector<int> perm_d_out{1, 0};
auto d_out_dims_vec = framework::vectorize(d_out_.dims());
for (int i = 0; i < d_out_.dims().size(); ++i) {
d_out_dims_vec[i] = d_out_.dims()[perm_d_out[i]];
}
trans_d_out.Resize(framework::make_ddim(d_out_dims_vec));
trans_d_out.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_d_out.size(), dev_ctx, d_out_,
&trans_d_out, perm_d_out);
} else {
std::vector<int> perm_d_x{1, 0};
auto d_x_dims_vec = framework::vectorize(d_x->dims());
for (int i = 0; i < d_x->dims().size(); ++i) {
d_x_dims_vec[i] = d_x->dims()[perm_d_x[i]];
}
trans_d_x.Resize(framework::make_ddim(d_x_dims_vec));
trans_d_x.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_d_x.size(), dev_ctx, *d_x,
&trans_d_x, perm_d_x);
std::vector<int> perm_d_out{2, 1, 0};
auto d_out_dims_vec = framework::vectorize(d_out_.dims());
for (int i = 0; i < d_out_.dims().size(); ++i) {
d_out_dims_vec[i] = d_out_.dims()[perm_d_out[i]];
}
trans_d_out.Resize(framework::make_ddim(d_out_dims_vec));
trans_d_out.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_d_out.size(), dev_ctx, d_out_,
&trans_d_out, perm_d_out);
}
} else {
trans_d_x = *d_x;
trans_d_out = d_out_;
}
FrameFunctor<DeviceContext, T>()(dev_ctx, &trans_d_out, &trans_d_x,
seq_length, frame_length, n_frames,
hop_length,
/*is_grad*/ true);
// Transpose output in case axis is 0.
if (axis == 0 && d_x_rank > 1U) {
std::vector<int> perm_d_x{1, 0};
TransCompute<DeviceContext, T>(perm_d_x.size(), dev_ctx, trans_d_x, d_x,
perm_d_x);
}
// Restore output dims when the number of dims is larger than 2.
if (d_x_rank > 2) {
std::vector<int64_t> restored_d_x_shape;
for (int i = 0; i < preserved_dims.size(); i++) {
restored_d_x_shape.push_back(preserved_dims[i]);
}
if (axis == 0) {
// (seq_length, ...)
restored_d_x_shape.insert(restored_d_x_shape.begin(), seq_length);
} else {
// (..., seq_length)
restored_d_x_shape.push_back(seq_length);
}
d_x->Resize(framework::make_ddim(restored_d_x_shape));
}
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct Seq2ColFunctor {
Seq2ColFunctor(const T* seq, T* col, size_t seq_length, size_t frame_length,
size_t n_frames, size_t hop_length)
: seq_(seq),
col_(col),
seq_length_(seq_length),
frame_length_(frame_length),
n_frames_(n_frames),
hop_length_(hop_length) {}
/*
Convert sequences to frames.
1. Dimension information:
Sequences Frames
(N, seq_length) -> (N, frame_length, n_frames)
2. Mapping from `i` to `src_idx` and `trg_idx` can be derived from:
a. Notation
- `i` stands for the flattened index of a bunch of frames.
- `src_idx` and `trg_idx` are the 1D indices of seqs and frames
respectively.
b. Sample idx
```cpp
sample_idx = i / (n_frames_ * frame_length_);
```
c. Maps `i` to `f` and `n`.
```cpp
f = i % (n_frames_ * frame_length_) / n_frames_;
n = i % (n_frames_ * frame_length_) % n_frames_;
```
d. Substitute `sample_idx`, `f` and `n` into the following equations:
```cpp
src_idx = sample_idx * seq_length_ + n * hop_length_ + f;
trg_idx = sample_idx * n_frames_ * frame_length_ + f * n_frames_ + n;
col_[trg_idx] = seq_[src_idx];
```
e. The resulting expressions are the ones used in the function body below.
*/
HOSTDEVICE void operator()(size_t i) const {
size_t src_idx;
size_t trg_idx;
src_idx = i / (n_frames_ * frame_length_) * seq_length_ +
i % (n_frames_ * frame_length_) % n_frames_ * hop_length_ +
i % (n_frames_ * frame_length_) / n_frames_;
trg_idx = i / (n_frames_ * frame_length_) * n_frames_ * frame_length_ +
i % (n_frames_ * frame_length_) / n_frames_ * n_frames_ +
i % (n_frames_ * frame_length_) % n_frames_;
col_[trg_idx] = seq_[src_idx];
}
const T* seq_;
T* col_;
size_t seq_length_;
size_t frame_length_;
size_t n_frames_;
size_t hop_length_;
};
template <typename T>
struct Col2SeqFunctor {
Col2SeqFunctor(const T* col, T* seq, size_t seq_length, size_t frame_length,
size_t n_frames, size_t hop_length)
: col_(col),
seq_(seq),
seq_length_(seq_length),
frame_length_(frame_length),
n_frames_(n_frames),
hop_length_(hop_length) {}
/*
Accumulate output gradient d_out to d_x.
1. Dimension information:
d_out d_x
(N, frame_length, n_frames) -> (N, seq_length)
2. Using a sliding window to find source indices from `d_out` according to
`i`:
a. Notation
- `i` stands for the flattened index of `d_x`.
- `seq_i` stands for a relative index of a `d_x` sample.
- `left`: Starting index of a frame window.
- `right`: Ending index of a frame window.
b. Sample idx
```cpp
sample_idx = i / seq_length_;
```
c. Slide a window of length `frame_length` to find `f` and `n`.
- `n`: The frame index in [0, n_frames_), incremented at each hop.
- `f`: The index within a frame in [0, frame_length_), relative to the
left edge of the sliding window.
d. Accumulate all grads from d_out.
```cpp
seq_[i] +=
col_[sample_idx * frame_length_ * n_frames_ + f * n_frames_ + n];
```
*/
HOSTDEVICE void operator()(size_t i) const {
size_t sample_idx = i / seq_length_;
size_t seq_i = i % seq_length_;
// Sliding window
seq_[i] = 0; // Init seq_[i] to 0, then sum up all
// grads from col_ in the while loop.
size_t n = get_start_frame_idx(seq_i);
size_t f;
size_t left = n * hop_length_;
size_t right = left + frame_length_ - 1;
while (left <= seq_i && right < seq_length_) {
f = seq_i - left;
seq_[i] +=
col_[sample_idx * frame_length_ * n_frames_ + f * n_frames_ + n];
// Next frame.
left += hop_length_;
right += hop_length_;
n += 1;
}
}
/*
Calculate minimum value of frame index `n` to satisfy the inequality:
seq_i <= right
==> seq_i <= left + frame_length - 1
==> seq_i <= hop_length_ * n + frame_length_ - 1
*/
HOSTDEVICE size_t get_start_frame_idx(size_t seq_i) const {
int64_t tmp = seq_i + 1 - frame_length_;
if (tmp > 0) {
size_t n = tmp / hop_length_;
if (tmp % hop_length_ == 0) {
return n;
} else {
return n + 1;
}
} else {
return 0;
}
}
const T* col_;
T* seq_;
size_t seq_length_;
size_t frame_length_;
size_t n_frames_;
size_t hop_length_;
};
} // namespace math
} // namespace operators
} // namespace paddle
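/*
The index mappings documented in Seq2ColFunctor and Col2SeqFunctor above can
be checked on the CPU with a small reference sketch. Both functions below are
hypothetical helpers, not Paddle code: they work on std::vector<float>,
derive `n_frames` from the other arguments, and `Col2SeqReference` uses a
scatter loop over (sample, frame, offset) instead of the per-element sliding
window, which should produce the same sums.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference: seq is (batch, seq_length), the result is
// (batch, frame_length, n_frames), both flattened in row-major order.
std::vector<float> Seq2ColReference(const std::vector<float>& seq,
                                    size_t batch, size_t seq_length,
                                    size_t frame_length, size_t hop_length) {
  assert(hop_length > 0 && frame_length > 0 && frame_length <= seq_length);
  assert(seq.size() == batch * seq_length);
  const size_t n_frames = 1 + (seq_length - frame_length) / hop_length;
  std::vector<float> col(batch * frame_length * n_frames);
  for (size_t i = 0; i < col.size(); ++i) {
    const size_t sample_idx = i / (n_frames * frame_length);
    const size_t f = i % (n_frames * frame_length) / n_frames;
    const size_t n = i % n_frames;
    // trg_idx in the functor simplifies to i, so col can be indexed by i.
    col[i] = seq[sample_idx * seq_length + n * hop_length + f];
  }
  return col;
}

// Hypothetical reference for the accumulation done by Col2SeqFunctor:
// every frame element is added back to the sequence position it came from.
std::vector<float> Col2SeqReference(const std::vector<float>& col,
                                    size_t batch, size_t seq_length,
                                    size_t frame_length, size_t hop_length) {
  assert(hop_length > 0 && frame_length > 0 && frame_length <= seq_length);
  const size_t n_frames = 1 + (seq_length - frame_length) / hop_length;
  assert(col.size() == batch * frame_length * n_frames);
  std::vector<float> seq(batch * seq_length, 0.0f);
  for (size_t s = 0; s < batch; ++s) {
    for (size_t n = 0; n < n_frames; ++n) {
      for (size_t f = 0; f < frame_length; ++f) {
        seq[s * seq_length + n * hop_length + f] +=
            col[(s * frame_length + f) * n_frames + n];
      }
    }
  }
  return seq;
}
```

For the sequence 0..5 with frame_length 4 and hop_length 2 there are two
frames, (0, 1, 2, 3) and (2, 3, 4, 5). Positions 2 and 3 are covered by both
frames, so accumulating the frames back doubles those two entries.
*/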
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/overlap_add_op.h"
namespace paddle {
namespace operators {
class OverlapAddOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "overlap_add");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "overlap_add");
const int hop_length = ctx->Attrs().Get<int>("hop_length");
const int axis = ctx->Attrs().Get<int>("axis");
const auto x_dims = ctx->GetInputDim("X");
const int x_rank = x_dims.size();
PADDLE_ENFORCE_GE(
x_rank, 2,
platform::errors::InvalidArgument(
"Input(X) of OverlapAddOp should be a tensor which contains "
"at least 2 dimensions, but got rank %s.",
x_rank));
PADDLE_ENFORCE_GT(
hop_length, 0,
platform::errors::InvalidArgument(
"Attribute(hop_length) of OverlapAddOp should be greater "
"than 0, but got %s.",
hop_length));
PADDLE_ENFORCE_EQ(
(axis == 0 || axis == -1), true,
platform::errors::InvalidArgument(
"Attribute(axis) of OverlapAddOp should be 0 or -1, but got %s.",
axis));
std::vector<int64_t> output_shape;
int n_frames;
int frame_length;
int start_axis;
int end_axis;
if (axis == 0) {
n_frames = x_dims[0];
frame_length = x_dims[1];
start_axis = 2;
end_axis = x_rank - 1;
} else {
n_frames = x_dims[x_rank - 1];
frame_length = x_dims[x_rank - 2];
start_axis = 0;
end_axis = x_rank - 3;
}
PADDLE_ENFORCE_LE(
hop_length, frame_length,
platform::errors::InvalidArgument(
"Attribute(hop_length) of OverlapAddOp should be less than or equal "
"to frame_length, but got hop_length(%s) > frame_length(%s).",
hop_length, frame_length));
const int seq_length = (n_frames - 1) * hop_length + frame_length;
// The loop body is skipped when x_rank == 2.
for (int i = start_axis; i <= end_axis; i++) {
output_shape.push_back(x_dims[i]);
}
if (axis == 0) {
// (seq_length, ...)
output_shape.insert(output_shape.begin(), seq_length);
} else {
// (..., seq_length)
output_shape.push_back(seq_length);
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(in_dtype, ctx.GetPlace());
}
};
class OverlapAddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of overlap_add op.");
AddOutput("Out", "(Tensor), The output tensor of overlap_add op.");
AddAttr<int>("hop_length",
"Number of steps to advance between adjacent frames; must "
"satisfy `0 < hop_length <= frame_length`.");
AddAttr<int>("axis",
"Axis of the input tensor to operate on. Its value "
"should be 0 (the first dimension) or -1 (the last dimension).")
.SetDefault(-1);
AddComment(R"DOC(
Reconstructs a tensor consisting of overlap-added sequences from input frames.
)DOC");
}
};
class OverlapAddOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "overlap_add_grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "overlap_add_grad");
const auto x_dims = ctx->GetInputDim("X");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(in_dtype, ctx.GetPlace());
}
};
template <typename T>
class OverlapAddOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("overlap_add_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(overlap_add, ops::OverlapAddOp, ops::OverlapAddOpMaker,
ops::OverlapAddOpGradMaker<paddle::framework::OpDesc>,
ops::OverlapAddOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(overlap_add_grad, ops::OverlapAddOpGrad);
REGISTER_OP_CPU_KERNEL(
overlap_add, ops::OverlapAddKernel<paddle::platform::CPUDeviceContext, int>,
ops::OverlapAddKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::OverlapAddKernel<paddle::platform::CPUDeviceContext, float>,
ops::OverlapAddKernel<paddle::platform::CPUDeviceContext, double>,
ops::OverlapAddKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::OverlapAddKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
overlap_add_grad,
ops::OverlapAddGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::OverlapAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::OverlapAddGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::OverlapAddGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::OverlapAddGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::OverlapAddGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
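/*
A minimal, standalone sketch of the shape arithmetic performed by
OverlapAddOp::InferShape above. `OverlapAddOutputShape` is a hypothetical
helper introduced for illustration only; it is not part of Paddle and uses
plain std::vector instead of framework::DDim.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not a Paddle API): returns the output shape of the
// overlap_add op for axis == 0 or axis == -1.
std::vector<int64_t> OverlapAddOutputShape(const std::vector<int64_t>& x_dims,
                                           int hop_length, int axis) {
  assert(x_dims.size() >= 2);
  assert(axis == 0 || axis == -1);
  const size_t rank = x_dims.size();
  const int64_t n_frames = (axis == 0) ? x_dims[0] : x_dims[rank - 1];
  const int64_t frame_length = (axis == 0) ? x_dims[1] : x_dims[rank - 2];
  assert(hop_length > 0 && hop_length <= frame_length);
  // Same formula as in InferShape.
  const int64_t seq_length = (n_frames - 1) * hop_length + frame_length;
  std::vector<int64_t> out;
  if (axis == 0) {
    // (seq_length, ...)
    out.push_back(seq_length);
    out.insert(out.end(), x_dims.begin() + 2, x_dims.end());
  } else {
    // (..., seq_length)
    out.assign(x_dims.begin(), x_dims.end() - 2);
    out.push_back(seq_length);
  }
  return out;
}
```

Under these assumptions, frames of shape (4, 3) along axis -1 with
hop_length 2 reconstruct a sequence of length (3 - 1) * 2 + 4 = 8.
*/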
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/overlap_add_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
overlap_add,
ops::OverlapAddKernel<paddle::platform::CUDADeviceContext, int>,
ops::OverlapAddKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::OverlapAddKernel<paddle::platform::CUDADeviceContext, float>,
ops::OverlapAddKernel<paddle::platform::CUDADeviceContext, double>,
ops::OverlapAddKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::OverlapAddKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::OverlapAddKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
overlap_add_grad,
ops::OverlapAddGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::OverlapAddGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::OverlapAddGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::OverlapAddGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::OverlapAddGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::OverlapAddGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::OverlapAddGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/seq2col.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
struct OverlapAddFunctor {
void operator()(const DeviceContext& dev_ctx, const Tensor* input,
Tensor* output, size_t seq_length, size_t frame_length,
size_t n_frames, size_t hop_length,
bool is_grad = false) const {
auto numel = output->numel();
const auto* input_data = input->data<T>();
auto* output_data = output->data<T>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
if (!is_grad) {
math::Col2SeqFunctor<T> functor(input_data, output_data, seq_length,
frame_length, n_frames, hop_length);
for_range(functor);
} else {
math::Seq2ColFunctor<T> functor(input_data, output_data, seq_length,
frame_length, n_frames, hop_length);
for_range(functor);
}
}
};
template <typename DeviceContext, typename T>
class OverlapAddKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
const Tensor* x = ctx.Input<Tensor>("X");
Tensor* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
const size_t x_rank = x->dims().size();
const size_t out_rank = out->dims().size();
const int hop_length = ctx.Attr<int>("hop_length");
const int axis = ctx.Attr<int>("axis");
const int n_frames = (axis == 0) ? x->dims()[0] : x->dims()[x_rank - 1];
const int frame_length = (axis == 0) ? x->dims()[1] : x->dims()[x_rank - 2];
const int seq_length =
(axis == 0) ? out->dims()[0] : out->dims()[out_rank - 1];
auto& dev_ctx = ctx.device_context<DeviceContext>();
Tensor x_(x->type());
x_ = *x;
framework::DDim preserved_dims;
if (out_rank > 2) {
// Save dims used to flatten both input and output tensors and restore
// output tensor.
framework::DDim x_resized_dims;
framework::DDim out_resized_dims;
if (axis == 0) {
preserved_dims = framework::slice_ddim(out->dims(), 1, out_rank);
x_resized_dims = {n_frames, frame_length,
framework::product(preserved_dims)};
out_resized_dims = {seq_length, framework::product(preserved_dims)};
} else {
preserved_dims = framework::slice_ddim(out->dims(), 0, out_rank - 1);
x_resized_dims = {framework::product(preserved_dims), frame_length,
n_frames};
out_resized_dims = {framework::product(preserved_dims), seq_length};
}
x_.Resize(x_resized_dims);
out->Resize(out_resized_dims);
}
Tensor trans_x(x_.type());
Tensor trans_out(out->type());
// Transpose input and output in case that axis is 0.
if (axis == 0) {
if (out_rank == 1U) {
trans_out = *out;
std::vector<int> perm_x{1, 0};
auto x_dims_vec = framework::vectorize(x_.dims());
for (int i = 0; i < x_.dims().size(); ++i) {
x_dims_vec[i] = x_.dims()[perm_x[i]];
}
trans_x.Resize(framework::make_ddim(x_dims_vec));
trans_x.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_x.size(), dev_ctx, x_, &trans_x,
perm_x);
} else {
std::vector<int> perm_out{1, 0};
auto out_dims_vec = framework::vectorize(out->dims());
for (int i = 0; i < out->dims().size(); ++i) {
out_dims_vec[i] = out->dims()[perm_out[i]];
}
trans_out.Resize(framework::make_ddim(out_dims_vec));
trans_out.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_out.size(), dev_ctx, *out,
&trans_out, perm_out);
std::vector<int> perm_x{2, 1, 0};
auto x_dims_vec = framework::vectorize(x_.dims());
for (int i = 0; i < x_.dims().size(); ++i) {
x_dims_vec[i] = x_.dims()[perm_x[i]];
}
trans_x.Resize(framework::make_ddim(x_dims_vec));
trans_x.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_x.size(), dev_ctx, x_, &trans_x,
perm_x);
}
} else {
trans_x = x_;
trans_out = *out;
}
OverlapAddFunctor<DeviceContext, T>()(dev_ctx, &trans_x, &trans_out,
seq_length, frame_length, n_frames,
hop_length, /*is_grad*/ false);
// Transpose output in case axis is 0.
if (axis == 0 && out_rank > 1U) {
std::vector<int> perm_out{1, 0};
TransCompute<DeviceContext, T>(perm_out.size(), dev_ctx, trans_out, out,
perm_out);
}
// Restore output dims when the number of dims is larger than 2.
if (out_rank > 2) {
std::vector<int64_t> restored_out_shape;
for (int i = 0; i < preserved_dims.size(); i++) {
restored_out_shape.push_back(preserved_dims[i]);
}
if (axis == 0) {
// (seq_length, ...)
restored_out_shape.insert(restored_out_shape.begin(), seq_length);
} else {
// (..., seq_length)
restored_out_shape.push_back(seq_length);
}
out->Resize(framework::make_ddim(restored_out_shape));
}
}
};
template <typename DeviceContext, typename T>
class OverlapAddGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
Tensor* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
d_x->mutable_data<T>(ctx.GetPlace());
const size_t d_out_rank = d_out->dims().size();
const size_t d_x_rank = d_x->dims().size();
const int hop_length = ctx.Attr<int>("hop_length");
const int axis = ctx.Attr<int>("axis");
const int n_frames =
(axis == 0) ? d_x->dims()[0] : d_x->dims()[d_x_rank - 1];
const int frame_length =
(axis == 0) ? d_x->dims()[1] : d_x->dims()[d_x_rank - 2];
const int seq_length =
(axis == 0) ? d_out->dims()[0] : d_out->dims()[d_out_rank - 1];
auto& dev_ctx = ctx.device_context<DeviceContext>();
// When the number of input dims is larger than 2, d_out is copied so that
// the input can be flattened to 2-D and the output to 3-D. Moreover, the
// output dims are restored at the last step.
Tensor d_out_(d_out->type());
d_out_ = *d_out;
framework::DDim preserved_dims;
if (d_out_rank > 2) {
// Save dims used to flatten both input and output tensors and restore
// output tensor.
framework::DDim d_x_resized_dims;
framework::DDim d_out_resized_dims;
if (axis == 0) {
preserved_dims = framework::slice_ddim(d_out_.dims(), 1, d_out_rank);
d_x_resized_dims = {n_frames, frame_length,
framework::product(preserved_dims)};
d_out_resized_dims = {seq_length, framework::product(preserved_dims)};
} else {
preserved_dims =
framework::slice_ddim(d_out_.dims(), 0, d_out_rank - 1);
d_x_resized_dims = {framework::product(preserved_dims), frame_length,
n_frames};
d_out_resized_dims = {framework::product(preserved_dims), seq_length};
}
d_x->Resize(d_x_resized_dims);
d_out_.Resize(d_out_resized_dims);
}
Tensor trans_d_x(d_x->type());
Tensor trans_d_out(d_out_.type());
// Transpose input and output in case that axis is 0.
if (axis == 0) {
if (d_out_rank == 1U) {
trans_d_out = d_out_;
std::vector<int> perm_d_x{1, 0};
auto d_x_dims_vec = framework::vectorize(d_x->dims());
for (int i = 0; i < d_x->dims().size(); ++i) {
d_x_dims_vec[i] = d_x->dims()[perm_d_x[i]];
}
trans_d_x.Resize(framework::make_ddim(d_x_dims_vec));
trans_d_x.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_d_x.size(), dev_ctx, *d_x,
&trans_d_x, perm_d_x);
} else {
std::vector<int> perm_d_out{1, 0};
auto d_out_dims_vec = framework::vectorize(d_out_.dims());
for (int i = 0; i < d_out_.dims().size(); ++i) {
d_out_dims_vec[i] = d_out_.dims()[perm_d_out[i]];
}
trans_d_out.Resize(framework::make_ddim(d_out_dims_vec));
trans_d_out.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_d_out.size(), dev_ctx, d_out_,
&trans_d_out, perm_d_out);
std::vector<int> perm_d_x{2, 1, 0};
auto d_x_dims_vec = framework::vectorize(d_x->dims());
for (int i = 0; i < d_x->dims().size(); ++i) {
d_x_dims_vec[i] = d_x->dims()[perm_d_x[i]];
}
trans_d_x.Resize(framework::make_ddim(d_x_dims_vec));
trans_d_x.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_d_x.size(), dev_ctx, *d_x,
&trans_d_x, perm_d_x);
}
} else {
trans_d_x = *d_x;
trans_d_out = d_out_;
}
OverlapAddFunctor<DeviceContext, T>()(dev_ctx, &trans_d_out, &trans_d_x,
seq_length, frame_length, n_frames,
hop_length,
/*is_grad*/ true);
// Transpose output in case axis is 0.
if (axis == 0) {
if (d_out_rank == 1U) {
std::vector<int> perm_d_x{1, 0};
TransCompute<DeviceContext, T>(perm_d_x.size(), dev_ctx, trans_d_x, d_x,
perm_d_x);
} else {
std::vector<int> perm_d_x{2, 1, 0};
TransCompute<DeviceContext, T>(perm_d_x.size(), dev_ctx, trans_d_x, d_x,
perm_d_x);
}
}
// Restore output dims when the number of dims is larger than 2.
if (d_out_rank > 2) {
std::vector<int64_t> restored_d_x_shape;
for (int i = 0; i < preserved_dims.size(); i++) {
restored_d_x_shape.push_back(preserved_dims[i]);
}
if (axis == 0) {
// (n_frames, frame_length, ...)
restored_d_x_shape.insert(restored_d_x_shape.begin(), frame_length);
restored_d_x_shape.insert(restored_d_x_shape.begin(), n_frames);
} else {
// (..., frame_length, n_frames)
restored_d_x_shape.push_back(frame_length);
restored_d_x_shape.push_back(n_frames);
}
d_x->Resize(framework::make_ddim(restored_d_x_shape));
}
}
};
} // namespace operators
} // namespace paddle
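For readers following the forward kernel above, here is a minimal standalone sketch of the overlap-add reduction for the layout where the input has shape (frame_length, n_frames). It is a serial illustration only, not the `math::Col2SeqFunctor` implementation. The helper name `OverlapAddReference` is hypothetical, and the output length `(n_frames - 1) * hop_length + frame_length` is an assumption that this file does not state.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical reference helper, not part of Paddle.
// x is row-major with shape (frame_length, n_frames), so
// x[j * n_frames + i] is sample j of frame i.
std::vector<double> OverlapAddReference(const std::vector<double>& x,
                                        std::size_t frame_length,
                                        std::size_t n_frames,
                                        std::size_t hop_length) {
  if (frame_length == 0 || n_frames == 0) return {};
  // Assumed output length (see the note above the block).
  const std::size_t seq_length = (n_frames - 1) * hop_length + frame_length;
  std::vector<double> out(seq_length, 0.0);
  for (std::size_t i = 0; i < n_frames; ++i) {
    // Frame i starts at offset i * hop_length in the output sequence.
    for (std::size_t j = 0; j < frame_length; ++j) {
      out[i * hop_length + j] += x[j * n_frames + i];
    }
  }
  return out;
}
```

With frame_length = 3, n_frames = 2 and hop_length = 2, the two frames cover output positions 0..2 and 2..4, so only position 2 receives a contribution from both frames.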
@@ -14,6 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/pad_op.h"
 #include <memory>
+#include "paddle/fluid/platform/complex.h"
 namespace paddle {
 namespace operators {
@@ -170,10 +171,18 @@ REGISTER_OP_CPU_KERNEL(
 pad, ops::PadKernel<paddle::platform::CPUDeviceContext, float>,
 ops::PadKernel<paddle::platform::CPUDeviceContext, double>,
 ops::PadKernel<paddle::platform::CPUDeviceContext, int>,
-ops::PadKernel<paddle::platform::CPUDeviceContext, int64_t>);
+ops::PadKernel<paddle::platform::CPUDeviceContext, int64_t>,
+ops::PadKernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<float>>,
+ops::PadKernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
 pad_grad, ops::PadGradKernel<paddle::platform::CPUDeviceContext, float>,
-ops::PadGradKernel<paddle::platform::CPUDeviceContext, double>);
+ops::PadGradKernel<paddle::platform::CPUDeviceContext, double>,
+ops::PadGradKernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<float>>,
+ops::PadGradKernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
 pad, ops::PadKernel<paddle::platform::CUDADeviceContext, double>,
@@ -181,9 +190,17 @@ REGISTER_OP_CUDA_KERNEL(
 ops::PadKernel<paddle::platform::CUDADeviceContext, int>,
 ops::PadKernel<paddle::platform::CUDADeviceContext, int64_t>,
 ops::PadKernel<paddle::platform::CUDADeviceContext,
-paddle::platform::float16>);
+paddle::platform::float16>,
+ops::PadKernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<float>>,
+ops::PadKernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
 pad_grad, ops::PadGradKernel<paddle::platform::CUDADeviceContext, double>,
 ops::PadGradKernel<paddle::platform::CUDADeviceContext, float>,
 ops::PadGradKernel<paddle::platform::CUDADeviceContext,
-paddle::platform::float16>);
+paddle::platform::float16>,
+ops::PadGradKernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<float>>,
+ops::PadGradKernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<double>>);
@@ -18,6 +18,7 @@
 #include <vector>
 #include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/complex.h"
 namespace paddle {
 namespace operators {
@@ -148,12 +149,20 @@ REGISTER_OP_CPU_KERNEL(
 roll, ops::RollKernel<paddle::platform::CPUDeviceContext, float>,
 ops::RollKernel<paddle::platform::CPUDeviceContext, double>,
 ops::RollKernel<paddle::platform::CPUDeviceContext, int>,
-ops::RollKernel<paddle::platform::CPUDeviceContext, int64_t>);
+ops::RollKernel<paddle::platform::CPUDeviceContext, int64_t>,
+ops::RollKernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<float>>,
+ops::RollKernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
 roll_grad, ops::RollGradKernel<paddle::platform::CPUDeviceContext, float>,
 ops::RollGradKernel<paddle::platform::CPUDeviceContext, double>,
 ops::RollGradKernel<paddle::platform::CPUDeviceContext, int>,
-ops::RollGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
+ops::RollGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
+ops::RollGradKernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<float>>,
+ops::RollGradKernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<double>>);
 REGISTER_OP_VERSION(roll)
 .AddCheckpoint(
......
@@ -16,6 +16,7 @@
 #include "paddle/fluid/framework/array.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/roll_op.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 namespace paddle {
@@ -188,9 +189,17 @@ REGISTER_OP_CUDA_KERNEL(
 roll, ops::RollKernel<paddle::platform::CUDADeviceContext, float>,
 ops::RollKernel<paddle::platform::CUDADeviceContext, double>,
 ops::RollKernel<paddle::platform::CUDADeviceContext, int>,
-ops::RollKernel<paddle::platform::CUDADeviceContext, int64_t>);
+ops::RollKernel<paddle::platform::CUDADeviceContext, int64_t>,
+ops::RollKernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<float>>,
+ops::RollKernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
 roll_grad, ops::RollGradKernel<paddle::platform::CUDADeviceContext, float>,
 ops::RollGradKernel<paddle::platform::CUDADeviceContext, double>,
 ops::RollGradKernel<paddle::platform::CUDADeviceContext, int>,
-ops::RollGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
+ops::RollGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
+ops::RollGradKernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<float>>,
+ops::RollGradKernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<double>>);
@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/shape_op.h"
 #include <string>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/complex.h"
 namespace paddle {
 namespace operators {
@@ -64,6 +65,7 @@ Return the shape of the input.
 } // namespace paddle
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 REGISTER_OPERATOR(
 shape, ops::ShapeOp, ops::ShapeOpMaker,
 paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
@@ -71,4 +73,6 @@ REGISTER_OPERATOR(
 REGISTER_OP_CPU_KERNEL(shape, ops::ShapeKernel<bool>, ops::ShapeKernel<int>,
 ops::ShapeKernel<int8_t>, ops::ShapeKernel<uint8_t>,
 ops::ShapeKernel<int64_t>, ops::ShapeKernel<float>,
-ops::ShapeKernel<double>);
+ops::ShapeKernel<double>,
+ops::ShapeKernel<plat::complex<float>>,
+ops::ShapeKernel<plat::complex<double>>);
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/shape_op.h"
+#include "paddle/fluid/platform/complex.h"
 REGISTER_OP_CUDA_KERNEL(
 shape, paddle::operators::ShapeKernel<bool>,
@@ -21,4 +22,6 @@ REGISTER_OP_CUDA_KERNEL(
 paddle::operators::ShapeKernel<int64_t>,
 paddle::operators::ShapeKernel<float>,
 paddle::operators::ShapeKernel<double>,
-paddle::operators::ShapeKernel<paddle::platform::float16>);
+paddle::operators::ShapeKernel<paddle::platform::float16>,
+paddle::operators::ShapeKernel<paddle::platform::complex<float>>,
+paddle::operators::ShapeKernel<paddle::platform::complex<double>>);
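The spectral operator definitions in this change rely on a few length and scaling rules: a onesided real-to-complex transform keeps `n / 2 + 1` points along the last FFT axis, a complex-to-real transform defaults to `2 * (m - 1)` output points when `last_dim_size` is 0, and the normalization string maps to a scale of 1, 1/n or 1/sqrt(n) depending on direction. The sketch below restates those rules as standalone functions. All three helper names are hypothetical and are not part of Paddle.

```cpp
#include <cmath>
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical helpers that restate the rules used by the FFT operators.

// Onesided real-to-complex: n real points give n / 2 + 1 complex points.
inline int64_t OnesidedR2CLength(int64_t n) { return n / 2 + 1; }

// Complex-to-real: m complex points give last_dim_size real points, or
// 2 * (m - 1) when last_dim_size is left at its default of 0.
inline int64_t C2ROutputLength(int64_t m, int64_t last_dim_size) {
  const int64_t n = (last_dim_size == 0) ? 2 * (m - 1) : last_dim_size;
  if (n <= 0) throw std::invalid_argument("Invalid fft n-point.");
  return n;
}

// Scale applied to a transform over n signal points for a given
// normalization string and direction.
inline double NormScale(const std::string& norm, bool forward, int64_t n) {
  const double size = static_cast<double>(n);
  if (norm.empty() || norm == "backward") return forward ? 1.0 : 1.0 / size;
  if (norm == "forward") return forward ? 1.0 / size : 1.0;
  if (norm == "ortho") return 1.0 / std::sqrt(size);
  throw std::invalid_argument(
      "norm must be 'forward', 'backward' or 'ortho'");
}
```

One consequence worth noting: an odd signal length does not survive a default round trip. Seven real points become four complex points, and the default complex-to-real length for four points is six, so the caller has to pass `last_dim_size = 7` explicitly.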
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/spectral_op.h"
#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/complex.h"
#if defined(PADDLE_WITH_ONEMKL)
#include <mkl_dfti.h>
#elif defined(PADDLE_WITH_POCKETFFT)
#include "extern_pocketfft/pocketfft_hdronly.h"
#endif
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// FFTC2C
class FFTC2COpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), the input tensor of fft_c2c op.");
AddOutput("Out", "(Tensor), the output tensor of fft_c2c op.");
AddAttr<std::vector<int64_t>>("axes",
"std::vector<int64_t>, the fft axes.");
AddAttr<std::string>("normalization",
"fft_norm_type, the fft normalization type.");
AddAttr<bool>("forward", "bool, the fft direction.");
AddComment(R"DOC(
Compute complex to complex FFT.
)DOC");
}
};
class FFTC2COp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fft_c2c");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fft_c2c");
const auto axes = ctx->Attrs().Get<std::vector<int64_t>>("axes");
const auto x_dim = ctx->GetInputDim("X");
for (size_t i = 0; i < axes.size(); i++) {
PADDLE_ENFORCE_GT(x_dim[axes[i]], 0,
platform::errors::InvalidArgument(
"Invalid fft n-point (%d).", x_dim[axes[i]]));
}
ctx->ShareDim("X", /*->*/ "Out"); // only for c2c
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
const auto kernel_dtype = framework::ToRealType(in_dtype);
return framework::OpKernelType(kernel_dtype, ctx.GetPlace());
}
};
template <typename T>
class FFTC2CGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("fft_c2c_grad");
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};
class FFTC2CGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
const auto out_grad_name = framework::GradVarName("Out");
OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name,
"fft_c2c_grad");
const auto x_grad_name = framework::GradVarName("X");
OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), "Output", x_grad_name,
"fft_c2c_grad");
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
const auto kernel_dtype = framework::ToRealType(in_dtype);
return framework::OpKernelType(kernel_dtype, ctx.GetPlace());
}
};
// FFTR2C
class FFTR2COpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), the input tensor of fft_r2c op.");
AddOutput("Out", "(Tensor), the output tensor of fft_r2c op.");
AddAttr<std::vector<int64_t>>("axes",
"std::vector<int64_t>, the fft axes.");
AddAttr<std::string>("normalization",
"fft_norm_type, the fft normalization type.");
AddAttr<bool>("forward", "bool, the fft direction.");
AddAttr<bool>("onesided", "bool, perform onesided fft.");
AddComment(R"DOC(
Compute real to complex FFT.
)DOC");
}
};
class FFTR2COp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fft_r2c");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fft_r2c");
const auto axes = ctx->Attrs().Get<std::vector<int64_t>>("axes");
const auto x_dim = ctx->GetInputDim("X");
for (size_t i = 0; i < axes.size() - 1L; i++) {
PADDLE_ENFORCE_GT(x_dim[axes[i]], 0,
platform::errors::InvalidArgument(
"Invalid fft n-point (%d).", x_dim[axes[i]]));
}
const bool onesided = ctx->Attrs().Get<bool>("onesided");
if (!onesided) {
ctx->ShareDim("X", /*->*/ "Out");
} else {
framework::DDim out_dim(ctx->GetInputDim("X"));
const int64_t last_fft_axis = axes.back();
const int64_t last_fft_dim_size = out_dim.at(last_fft_axis);
out_dim.at(last_fft_axis) = last_fft_dim_size / 2 + 1;
ctx->SetOutputDim("Out", out_dim);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(in_dtype, ctx.GetPlace());
}
};
template <typename T>
class FFTR2CGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("fft_r2c_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};
class FFTR2CGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
const auto out_grad_name = framework::GradVarName("Out");
OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name,
"fft_r2c_grad");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fft_r2c_grad");
const auto x_grad_name = framework::GradVarName("X");
OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), "Output", x_grad_name,
"fft_r2c_grad");
ctx->ShareDim("X", /*->*/ x_grad_name);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
const auto kernel_dtype = framework::ToRealType(in_dtype);
return framework::OpKernelType(kernel_dtype, ctx.GetPlace());
}
};
// FFTC2R
class FFTC2ROpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), the input tensor of fft_c2r op.");
AddOutput("Out", "(Tensor), the output tensor of fft_c2r op.");
AddAttr<std::vector<int64_t>>("axes",
"std::vector<int64_t>, the fft axes.");
AddAttr<std::string>("normalization",
"fft_norm_type, the fft normalization type.");
AddAttr<bool>("forward", "bool, the fft direction.");
AddAttr<int64_t>(
"last_dim_size",
"int, length of the transformed axis of the output. For n output "
"points, last_dim_size//2 + 1 input points are necessary. If the "
"input is longer than this, it is cropped. If it is shorter than "
"this, it is padded with zeros. If last_dim_size is not given, it "
"is taken to be 2*(m-1), where m is the length of the input along "
"the last axis specified by axes.")
.SetDefault(0L);
AddComment(R"DOC(
Compute complex to real FFT.
)DOC");
}
};
class FFTC2ROp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fft_c2r");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fft_c2r");
const auto axes = ctx->Attrs().Get<std::vector<int64_t>>("axes");
const auto x_dim = ctx->GetInputDim("X");
for (size_t i = 0; i < axes.size() - 1L; i++) {
PADDLE_ENFORCE_GT(x_dim[axes[i]], 0,
platform::errors::InvalidArgument(
"Invalid fft n-point (%d).", x_dim[axes[i]]));
}
const int64_t last_dim_size = ctx->Attrs().Get<int64_t>("last_dim_size");
framework::DDim out_dim(ctx->GetInputDim("X"));
const int64_t last_fft_axis = axes.back();
if (last_dim_size == 0) {
const int64_t last_fft_dim_size = out_dim.at(last_fft_axis);
const int64_t fft_n_point = (last_fft_dim_size - 1) * 2;
PADDLE_ENFORCE_GT(fft_n_point, 0,
platform::errors::InvalidArgument(
"Invalid fft n-point (%d).", fft_n_point));
out_dim.at(last_fft_axis) = fft_n_point;
} else {
PADDLE_ENFORCE_GT(last_dim_size, 0,
platform::errors::InvalidArgument(
"Invalid fft n-point (%d).", last_dim_size));
out_dim.at(last_fft_axis) = last_dim_size;
}
ctx->SetOutputDim("Out", out_dim);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
const auto kernel_dtype = framework::ToRealType(in_dtype);
return framework::OpKernelType(kernel_dtype, ctx.GetPlace());
}
};
template <typename T>
class FFTC2RGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("fft_c2r_grad");
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};
class FFTC2RGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
const auto out_grad_name = framework::GradVarName("Out");
OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name,
"fft_c2r_grad");
const auto x_grad_name = framework::GradVarName("X");
OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), "Output", x_grad_name,
"fft_c2r_grad");
const auto axes = ctx->Attrs().Get<std::vector<int64_t>>("axes");
const auto out_grad_dim = ctx->GetInputDim(out_grad_name);
framework::DDim x_grad_dim(out_grad_dim);
const int64_t last_fft_axis = axes.back();
const int64_t last_fft_dim_size = x_grad_dim.at(last_fft_axis);
x_grad_dim.at(last_fft_axis) = last_fft_dim_size / 2 + 1;
ctx->SetOutputDim(x_grad_name, x_grad_dim);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(in_dtype, ctx.GetPlace());
}
};
// common functions
FFTNormMode get_norm_from_string(const std::string& norm, bool forward) {
if (norm.empty() || norm == "backward") {
return forward ? FFTNormMode::none : FFTNormMode::by_n;
}
if (norm == "forward") {
return forward ? FFTNormMode::by_n : FFTNormMode::none;
}
if (norm == "ortho") {
return FFTNormMode::by_sqrt_n;
}
PADDLE_THROW(platform::errors::InvalidArgument(
"FFT norm string must be 'forward' or 'backward' or 'ortho', "
"received %s",
norm));
}
// FFT Functors
#if defined(PADDLE_WITH_ONEMKL)
namespace {
static inline void MKL_DFTI_CHECK(MKL_INT status) {
if (status && !DftiErrorClass(status, DFTI_NO_ERROR)) {
PADDLE_THROW(platform::errors::External(DftiErrorMessage(status)));
}
}
struct DftiDescriptorDeleter {
void operator()(DFTI_DESCRIPTOR_HANDLE handle) {
if (handle != nullptr) {
MKL_DFTI_CHECK(DftiFreeDescriptor(&handle));
}
}
};
class DftiDescriptor {
public:
void init(DFTI_CONFIG_VALUE precision, DFTI_CONFIG_VALUE signal_type,
MKL_LONG signal_ndim, MKL_LONG* sizes) {
if (desc_ != nullptr) {
PADDLE_THROW(platform::errors::AlreadyExists(
"DFTI DESCRIPTOR can only be initialized once."));
}
DFTI_DESCRIPTOR* raw_desc;
if (signal_ndim == 1) {
MKL_DFTI_CHECK(
DftiCreateDescriptor(&raw_desc, precision, signal_type, 1, sizes[0]));
} else {
MKL_DFTI_CHECK(DftiCreateDescriptor(&raw_desc, precision, signal_type,
signal_ndim, sizes));
}
desc_.reset(raw_desc);
}
DFTI_DESCRIPTOR* get() const {
if (desc_ == nullptr) {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"DFTI DESCRIPTOR has not been initialized."));
}
return desc_.get();
}
private:
std::unique_ptr<DFTI_DESCRIPTOR, DftiDescriptorDeleter> desc_;
};
DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
const framework::proto::VarType::Type& out_dtype,
const framework::DDim& in_strides,
const framework::DDim& out_strides,
const std::vector<int>& signal_sizes,
FFTNormMode normalization, bool forward) {
const DFTI_CONFIG_VALUE precision = [&] {
switch (in_dtype) {
case framework::proto::VarType::FP32:
return DFTI_SINGLE;
case framework::proto::VarType::COMPLEX64:
return DFTI_SINGLE;
case framework::proto::VarType::FP64:
return DFTI_DOUBLE;
case framework::proto::VarType::COMPLEX128:
return DFTI_DOUBLE;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Input data type should be FP32, FP64, COMPLEX64 or COMPLEX128."));
}
}();
// C2C, R2C, C2R
const FFTTransformType fft_type = GetFFTTransformType(in_dtype, out_dtype);
const DFTI_CONFIG_VALUE domain =
(fft_type == FFTTransformType::C2C) ? DFTI_COMPLEX : DFTI_REAL;
// const bool complex_input = framework::IsComplexType(in_dtype);
// const bool complex_output = framework::IsComplexType(out_dtype);
// const DFTI_CONFIG_VALUE domain = [&] {
// if (forward) {
// return complex_input ? DFTI_COMPLEX : DFTI_REAL;
// } else {
// return complex_output ? DFTI_COMPLEX : DFTI_REAL;
// }
// }();
DftiDescriptor descriptor;
std::vector<MKL_LONG> fft_sizes(signal_sizes.cbegin(), signal_sizes.cend());
const MKL_LONG signal_ndim = fft_sizes.size() - 1;
descriptor.init(precision, domain, signal_ndim, fft_sizes.data() + 1);
// placement inplace or not inplace
MKL_DFTI_CHECK(
DftiSetValue(descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE));
// number of transformations
const MKL_LONG batch_size = fft_sizes[0];
MKL_DFTI_CHECK(
DftiSetValue(descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size));
// input & output distance
const MKL_LONG idist = in_strides[0];
const MKL_LONG odist = out_strides[0];
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist));
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_DISTANCE, odist));
// input & output stride
std::vector<MKL_LONG> mkl_in_stride(1 + signal_ndim, 0);
std::vector<MKL_LONG> mkl_out_stride(1 + signal_ndim, 0);
for (MKL_LONG i = 1; i <= signal_ndim; i++) {
mkl_in_stride[i] = in_strides[i];
mkl_out_stride[i] = out_strides[i];
}
MKL_DFTI_CHECK(
DftiSetValue(descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride.data()));
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_STRIDES,
mkl_out_stride.data()));
// conjugate even storage
if (!(fft_type == FFTTransformType::C2C)) {
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE,
DFTI_COMPLEX_COMPLEX));
}
MKL_LONG signal_numel =
std::accumulate(fft_sizes.cbegin() + 1, fft_sizes.cend(), 1UL,
std::multiplies<MKL_LONG>());
if (normalization != FFTNormMode::none) {
const double scale =
((normalization == FFTNormMode::by_sqrt_n)
? 1.0 / std::sqrt(static_cast<double>(signal_numel))
: 1.0 / static_cast<double>(signal_numel));
const auto scale_direction = [&]() {
if (fft_type == FFTTransformType::R2C ||
(fft_type == FFTTransformType::C2C && forward)) {
return DFTI_FORWARD_SCALE;
} else {
// (fft_type == FFTTransformType::C2R ||
// (fft_type == FFTTransformType::C2C && !forward))
return DFTI_BACKWARD_SCALE;
}
}();
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), scale_direction, scale));
}
// commit the descriptor
MKL_DFTI_CHECK(DftiCommitDescriptor(descriptor.get()));
return descriptor;
}
// Execute a general fft operation (can be c2c, onesided r2c or onesided c2r)
template <typename DeviceContext, typename Ti, typename To>
void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
const std::vector<int64_t>& axes, FFTNormMode normalization,
bool forward) {
const framework::DDim& in_sizes = x->dims();
const int ndim = in_sizes.size();
const int signal_ndim = axes.size();
const int batch_ndim = ndim - signal_ndim;
const framework::DDim& out_sizes = out->dims();
// make a dim permutation
std::vector<int> dim_permute(ndim);
std::iota(dim_permute.begin(), dim_permute.end(), 0);
std::vector<bool> is_transformed_dim(ndim, false);
for (const auto& d : axes) {
is_transformed_dim[d] = true;
}
const auto batch_end =
std::partition(dim_permute.begin(), dim_permute.end(),
[&](size_t axis) { return !is_transformed_dim[axis]; });
std::copy(axes.cbegin(), axes.cend(), batch_end);
// transpose input according to that permutation
framework::DDim transposed_input_shape = in_sizes.transpose(dim_permute);
std::vector<int64_t> transposed_input_shape_ =
framework::vectorize(transposed_input_shape);
framework::Tensor transposed_input;
transposed_input.Resize(transposed_input_shape);
const auto place = ctx.GetPlace();
transposed_input.mutable_data<Ti>(place);
TransCompute<platform::CPUDeviceContext, Ti>(ndim, ctx, *x, &transposed_input,
dim_permute);
// make a collapsed input: collapse the batch axes into one leading axis
const int batch_size = std::accumulate(
transposed_input_shape.Get(), transposed_input_shape.Get() + batch_ndim,
1L, std::multiplies<int64_t>());
std::vector<int> collapsed_input_shape_(1 + signal_ndim);
collapsed_input_shape_[0] = batch_size;
std::copy(transposed_input_shape_.begin() + batch_ndim,
transposed_input_shape_.end(), collapsed_input_shape_.begin() + 1);
const framework::DDim collapsed_input_shape =
framework::make_ddim(collapsed_input_shape_);
transposed_input.Resize(collapsed_input_shape);
framework::Tensor& collapsed_input = transposed_input;
// make a collapsed output
std::vector<int> collapsed_output_shape_(1 + signal_ndim);
collapsed_output_shape_[0] = batch_size;
for (int i = 0; i < signal_ndim; i++) {
collapsed_output_shape_[1 + i] = out_sizes[axes[i]];
}
const framework::DDim collapsed_output_shape =
framework::make_ddim(collapsed_output_shape_);
framework::Tensor collapsed_output;
collapsed_output.Resize(collapsed_output_shape);
collapsed_output.mutable_data(place, out->type());
// signal sizes
std::vector<int> signal_sizes(1 + signal_ndim);
signal_sizes[0] = batch_size;
for (int i = 0; i < signal_ndim; i++) {
signal_sizes[1 + i] =
std::max(collapsed_input_shape[1 + i], collapsed_output_shape[1 + i]);
}
// input & output stride
const framework::DDim input_stride = framework::stride(collapsed_input_shape);
const framework::DDim output_stride =
framework::stride(collapsed_output_shape);
// make a DFTI_DESCRIPTOR
DftiDescriptor desc =
_plan_mkl_fft(x->type(), out->type(), input_stride, output_stride,
signal_sizes, normalization, forward);
const FFTTransformType fft_type = GetFFTTransformType(x->type(), out->type());
if (fft_type == FFTTransformType::C2R && forward) {
framework::Tensor collapsed_input_conj(collapsed_input.type());
collapsed_input_conj.mutable_data<Ti>(collapsed_input.dims(),
ctx.GetPlace());
// conjugate the input
platform::ForRange<DeviceContext> for_range(ctx, collapsed_input.numel());
math::ConjFunctor<Ti> functor(collapsed_input.data<Ti>(),
collapsed_input.numel(),
collapsed_input_conj.data<Ti>());
for_range(functor);
MKL_DFTI_CHECK(DftiComputeBackward(desc.get(),
collapsed_input_conj.data<void>(),
collapsed_output.data<void>()));
} else if (fft_type == FFTTransformType::R2C && !forward) {
framework::Tensor collapsed_output_conj(collapsed_output.type());
collapsed_output_conj.mutable_data<To>(collapsed_output.dims(),
ctx.GetPlace());
MKL_DFTI_CHECK(DftiComputeForward(desc.get(), collapsed_input.data<void>(),
collapsed_output_conj.data<void>()));
// conjugate the output
platform::ForRange<DeviceContext> for_range(ctx, collapsed_output.numel());
math::ConjFunctor<To> functor(collapsed_output_conj.data<To>(),
collapsed_output.numel(),
collapsed_output.data<To>());
for_range(functor);
} else {
if (forward) {
MKL_DFTI_CHECK(DftiComputeForward(desc.get(),
collapsed_input.data<void>(),
collapsed_output.data<void>()));
} else {
MKL_DFTI_CHECK(DftiComputeBackward(desc.get(),
collapsed_input.data<void>(),
collapsed_output.data<void>()));
}
}
// resize for the collapsed output
framework::DDim transposed_output_shape = out_sizes.transpose(dim_permute);
collapsed_output.Resize(transposed_output_shape);
framework::Tensor& transposed_output = collapsed_output;
// reverse the transposition
std::vector<int> reverse_dim_permute(ndim);
for (int i = 0; i < ndim; i++) {
reverse_dim_permute[dim_permute[i]] = i;
}
TransCompute<platform::CPUDeviceContext, To>(ndim, ctx, transposed_output,
out, reverse_dim_permute);
}
} // anonymous namespace
template <typename Ti, typename To>
struct FFTC2CFunctor<platform::CPUDeviceContext, Ti, To> {
void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {
exec_fft<platform::CPUDeviceContext, Ti, To>(ctx, x, out, axes,
normalization, forward);
}
};
template <typename Ti, typename To>
struct FFTR2CFunctor<platform::CPUDeviceContext, Ti, To> {
void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {
exec_fft<platform::CPUDeviceContext, Ti, To>(ctx, x, out, axes,
normalization, forward);
}
};
template <typename Ti, typename To>
struct FFTC2RFunctor<platform::CPUDeviceContext, Ti, To> {
void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {
if (axes.size() > 1) {
const std::vector<int64_t> c2c_dims(axes.begin(), axes.end() - 1);
Tensor temp;
temp.mutable_data<Ti>(x->dims(), ctx.GetPlace());
FFTC2CFunctor<platform::CPUDeviceContext, Ti, Ti> c2c_functor;
c2c_functor(ctx, x, &temp, c2c_dims, normalization, forward);
const std::vector<int64_t> new_axes{axes.back()};
exec_fft<platform::CPUDeviceContext, Ti, To>(ctx, &temp, out, new_axes,
normalization, forward);
} else {
exec_fft<platform::CPUDeviceContext, Ti, To>(ctx, x, out, axes,
normalization, forward);
}
}
};
#elif defined(PADDLE_WITH_POCKETFFT)
namespace {
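// Returns the scale factor applied to the transform output:
// 1 for FFTNormMode::none, 1 / size for by_n, and 1 / sqrt(size) for by_sqrt_n.
template <typename T>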
T compute_factor(int64_t size, FFTNormMode normalization) {
constexpr auto one = static_cast<T>(1);
switch (normalization) {
case FFTNormMode::none:
return one;
case FFTNormMode::by_n:
return one / static_cast<T>(size);
case FFTNormMode::by_sqrt_n:
return one / std::sqrt(static_cast<T>(size));
}
PADDLE_THROW(
platform::errors::InvalidArgument("Unsupported normalization type"));
}
} // anonymous namespace
template <typename Ti, typename To>
struct FFTC2CFunctor<platform::CPUDeviceContext, Ti, To> {
void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {
using R = typename Ti::value_type;
using C = std::complex<R>;
const auto& input_dim = x->dims();
const std::vector<size_t> in_sizes =
framework::vectorize<size_t>(input_dim);
std::vector<std::ptrdiff_t> in_strides =
framework::vectorize<std::ptrdiff_t>(framework::stride(input_dim));
const int64_t data_size = sizeof(C);
std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(),
[&](std::ptrdiff_t s) { return s * data_size; });
const auto* in_data = reinterpret_cast<const C*>(x->data<Ti>());
auto* out_data = reinterpret_cast<C*>(out->data<To>());
// pocketfft requires std::vector<size_t>
std::vector<size_t> axes_(axes.size());
std::copy(axes.begin(), axes.end(), axes_.begin());
// compute the normalization factor
int64_t signal_numel = 1;
for (auto i : axes) {
signal_numel *= in_sizes[i];
}
R factor = compute_factor<R>(signal_numel, normalization);
pocketfft::c2c(in_sizes, in_strides, in_strides, axes_, forward, in_data,
out_data, factor);
}
};
template <typename Ti, typename To>
struct FFTR2CFunctor<platform::CPUDeviceContext, Ti, To> {
void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {
using R = Ti;
using C = std::complex<R>;
const auto& input_dim = x->dims();
const std::vector<size_t> in_sizes =
framework::vectorize<size_t>(input_dim);
std::vector<std::ptrdiff_t> in_strides =
framework::vectorize<std::ptrdiff_t>(framework::stride(input_dim));
{
const int64_t data_size = sizeof(R);
std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(),
[&](std::ptrdiff_t s) { return s * data_size; });
}
const auto& output_dim = out->dims();
const std::vector<size_t> out_sizes =
framework::vectorize<size_t>(output_dim);
std::vector<std::ptrdiff_t> out_strides =
framework::vectorize<std::ptrdiff_t>(framework::stride(output_dim));
{
const int64_t data_size = sizeof(C);
std::transform(out_strides.begin(), out_strides.end(),
out_strides.begin(),
[&](std::ptrdiff_t s) { return s * data_size; });
}
const auto* in_data = x->data<R>();
auto* out_data = reinterpret_cast<C*>(out->data<To>());
// pocketfft requires std::vector<size_t>
std::vector<size_t> axes_(axes.size());
std::copy(axes.begin(), axes.end(), axes_.begin());
// compute the normalization factor
int64_t signal_numel = 1;
for (auto i : axes) {
signal_numel *= in_sizes[i];
}
R factor = compute_factor<R>(signal_numel, normalization);
pocketfft::r2c(in_sizes, in_strides, out_strides, axes_, forward, in_data,
out_data, factor);
}
};
template <typename Ti, typename To>
struct FFTC2RFunctor<platform::CPUDeviceContext, Ti, To> {
void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {
using R = To;
using C = std::complex<R>;
const auto& input_dim = x->dims();
const std::vector<size_t> in_sizes =
framework::vectorize<size_t>(input_dim);
std::vector<std::ptrdiff_t> in_strides =
framework::vectorize<std::ptrdiff_t>(framework::stride(input_dim));
{
const int64_t data_size = sizeof(C);
std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(),
[&](std::ptrdiff_t s) { return s * data_size; });
}
const auto& output_dim = out->dims();
const std::vector<size_t> out_sizes =
framework::vectorize<size_t>(output_dim);
std::vector<std::ptrdiff_t> out_strides =
framework::vectorize<std::ptrdiff_t>(framework::stride(output_dim));
{
const int64_t data_size = sizeof(R);
std::transform(out_strides.begin(), out_strides.end(),
out_strides.begin(),
[&](std::ptrdiff_t s) { return s * data_size; });
}
const auto* in_data = reinterpret_cast<const C*>(x->data<Ti>());
auto* out_data = out->data<R>();
// pocketfft requires std::vector<size_t>
std::vector<size_t> axes_(axes.size());
std::copy(axes.begin(), axes.end(), axes_.begin());
// compute the normalization factor
int64_t signal_numel = 1;
for (auto i : axes) {
signal_numel *= out_sizes[i];
}
R factor = compute_factor<R>(signal_numel, normalization);
pocketfft::c2r(out_sizes, in_strides, out_strides, axes_, forward, in_data,
out_data, factor);
}
};
#endif
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fft_c2c, ops::FFTC2COp, ops::FFTC2COpMaker,
ops::FFTC2CGradOpMaker<paddle::framework::OpDesc>,
ops::FFTC2CGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
fft_c2c, ops::FFTC2CKernel<paddle::platform::CPUDeviceContext, float>,
ops::FFTC2CKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OPERATOR(fft_c2c_grad, ops::FFTC2CGradOp);
REGISTER_OP_CPU_KERNEL(
fft_c2c_grad,
ops::FFTC2CGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::FFTC2CGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OPERATOR(fft_r2c, ops::FFTR2COp, ops::FFTR2COpMaker,
ops::FFTR2CGradOpMaker<paddle::framework::OpDesc>,
ops::FFTR2CGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
fft_r2c, ops::FFTR2CKernel<paddle::platform::CPUDeviceContext, float>,
ops::FFTR2CKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OPERATOR(fft_r2c_grad, ops::FFTR2CGradOp);
REGISTER_OP_CPU_KERNEL(
fft_r2c_grad,
ops::FFTR2CGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::FFTR2CGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OPERATOR(fft_c2r, ops::FFTC2ROp, ops::FFTC2ROpMaker,
ops::FFTC2RGradOpMaker<paddle::framework::OpDesc>,
ops::FFTC2RGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
fft_c2r, ops::FFTC2RKernel<paddle::platform::CPUDeviceContext, float>,
ops::FFTC2RKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OPERATOR(fft_c2r_grad, ops::FFTC2RGradOp);
REGISTER_OP_CPU_KERNEL(
fft_c2r_grad,
ops::FFTC2RGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::FFTC2RGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cufft.h>
#include <cufftXt.h>
#include <functional>
#include <list>
#include <memory>
#include <mutex>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/conj_op.h"
#include "paddle/fluid/operators/spectral_op.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/dynload/cufft.h"
namespace paddle {
namespace operators {
namespace {
using ScalarType = framework::proto::VarType::Type;
const int64_t kMaxCUFFTNdim = 3;
const int64_t kMaxDataNdim = kMaxCUFFTNdim + 1;
static inline std::string get_cufft_error_info(cufftResult error) {
switch (error) {
case CUFFT_SUCCESS:
return "CUFFT_SUCCESS";
case CUFFT_INVALID_PLAN:
return "CUFFT_INVALID_PLAN";
case CUFFT_ALLOC_FAILED:
return "CUFFT_ALLOC_FAILED";
case CUFFT_INVALID_TYPE:
return "CUFFT_INVALID_TYPE";
case CUFFT_INVALID_VALUE:
return "CUFFT_INVALID_VALUE";
case CUFFT_INTERNAL_ERROR:
return "CUFFT_INTERNAL_ERROR";
case CUFFT_EXEC_FAILED:
return "CUFFT_EXEC_FAILED";
case CUFFT_SETUP_FAILED:
return "CUFFT_SETUP_FAILED";
case CUFFT_INVALID_SIZE:
return "CUFFT_INVALID_SIZE";
case CUFFT_UNALIGNED_DATA:
return "CUFFT_UNALIGNED_DATA";
case CUFFT_INCOMPLETE_PARAMETER_LIST:
return "CUFFT_INCOMPLETE_PARAMETER_LIST";
case CUFFT_INVALID_DEVICE:
return "CUFFT_INVALID_DEVICE";
case CUFFT_PARSE_ERROR:
return "CUFFT_PARSE_ERROR";
case CUFFT_NO_WORKSPACE:
return "CUFFT_NO_WORKSPACE";
case CUFFT_NOT_IMPLEMENTED:
return "CUFFT_NOT_IMPLEMENTED";
#ifndef __HIPCC__
case CUFFT_LICENSE_ERROR:
return "CUFFT_LICENSE_ERROR";
#endif
case CUFFT_NOT_SUPPORTED:
return "CUFFT_NOT_SUPPORTED";
default:
std::ostringstream ss;
ss << "unknown error " << error;
return ss.str();
}
}
static inline void CUFFT_CHECK(cufftResult error) {
if (error != CUFFT_SUCCESS) {
PADDLE_THROW(platform::errors::External(get_cufft_error_info(error)));
}
}
// This struct is used to easily compute hashes of the
// parameters. It will be the **key** to the plan cache.
struct PlanKey {
// between 1 and kMaxCUFFTNdim, i.e., 1 <= signal_ndim <= 3
int64_t signal_ndim_;
// These include additional batch dimension as well.
int64_t sizes_[kMaxDataNdim];
int64_t input_shape_[kMaxDataNdim];
int64_t output_shape_[kMaxDataNdim];
FFTTransformType fft_type_;
ScalarType value_type_;
PlanKey() = default;
PlanKey(const std::vector<int64_t>& in_shape,
const std::vector<int64_t>& out_shape,
const std::vector<int64_t>& signal_size, FFTTransformType fft_type,
ScalarType value_type) {
// Padding bits must be zeroed for hashing
memset(this, 0, sizeof(*this));
signal_ndim_ = signal_size.size() - 1;
fft_type_ = fft_type;
value_type_ = value_type;
std::copy(signal_size.cbegin(), signal_size.cend(), sizes_);
std::copy(in_shape.cbegin(), in_shape.cend(), input_shape_);
std::copy(out_shape.cbegin(), out_shape.cend(), output_shape_);
}
};
// An RAII encapsulation of cuFFTHandle
class CuFFTHandle {
::cufftHandle handle_;
public:
CuFFTHandle() { CUFFT_CHECK(platform::dynload::cufftCreate(&handle_)); }
::cufftHandle& get() { return handle_; }
const ::cufftHandle& get() const { return handle_; }
~CuFFTHandle() {
// Not using fftDestroy() for rocFFT to work around double freeing of handles
#ifndef __HIPCC__
CUFFT_CHECK(platform::dynload::cufftDestroy(handle_));
#endif
}
};
#ifdef __HIPCC__
using plan_size_type = int;
#else
using plan_size_type = long long int; // NOLINT
#endif
// This class contains all the information needed to execute a cuFFT plan:
// 1. the plan
// 2. the workspace size needed
class CuFFTConfig {
public:
// Move semantics are enough for this class. Although the plan is already
// owned by an RAII handle, still delete the copy constructor and copy
// assignment so we don't accidentally copy and take a perf hit.
CuFFTConfig(const CuFFTConfig&) = delete;
CuFFTConfig& operator=(CuFFTConfig const&) = delete;
explicit CuFFTConfig(const PlanKey& plan_key)
: CuFFTConfig(
std::vector<int64_t>(plan_key.sizes_,
plan_key.sizes_ + plan_key.signal_ndim_ + 1),
plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {}
// sizes describe the full signal, including the batch size, and are always
// two-sided
CuFFTConfig(const std::vector<int64_t>& sizes, const int64_t signal_ndim,
FFTTransformType fft_type, ScalarType dtype)
: fft_type_(fft_type), value_type_(dtype) {
// signal sizes (excluding batch dim)
std::vector<plan_size_type> signal_sizes(sizes.begin() + 1, sizes.end());
// input batch size
const auto batch = static_cast<plan_size_type>(sizes[0]);
// const int64_t signal_ndim = sizes.size() - 1;
PADDLE_ENFORCE_EQ(signal_ndim, sizes.size() - 1,
platform::errors::InvalidArgument(
"The signal_ndim must be equal to sizes.size() - 1, "
"but signal_ndim is: [%d], sizes.size() - 1 is: [%d]",
signal_ndim, sizes.size() - 1));
#ifdef __HIPCC__
hipfftType exec_type = [&] {
if (dtype == framework::proto::VarType::FP32) {
switch (fft_type) {
case FFTTransformType::C2C:
return HIPFFT_C2C;
case FFTTransformType::R2C:
return HIPFFT_R2C;
case FFTTransformType::C2R:
return HIPFFT_C2R;
}
} else if (dtype == framework::proto::VarType::FP64) {
switch (fft_type) {
case FFTTransformType::C2C:
return HIPFFT_Z2Z;
case FFTTransformType::R2C:
return HIPFFT_D2Z;
case FFTTransformType::C2R:
return HIPFFT_Z2D;
}
}
PADDLE_THROW(platform::errors::InvalidArgument(
"hipFFT only supports transforms of type float32 and float64"));
}();
#else
cudaDataType itype, otype, exec_type;
const auto complex_input = has_complex_input(fft_type);
const auto complex_output = has_complex_output(fft_type);
if (dtype == framework::proto::VarType::FP32) {
itype = complex_input ? CUDA_C_32F : CUDA_R_32F;
otype = complex_output ? CUDA_C_32F : CUDA_R_32F;
exec_type = CUDA_C_32F;
} else if (dtype == framework::proto::VarType::FP64) {
itype = complex_input ? CUDA_C_64F : CUDA_R_64F;
otype = complex_output ? CUDA_C_64F : CUDA_R_64F;
exec_type = CUDA_C_64F;
} else if (dtype == framework::proto::VarType::FP16) {
itype = complex_input ? CUDA_C_16F : CUDA_R_16F;
otype = complex_output ? CUDA_C_16F : CUDA_R_16F;
exec_type = CUDA_C_16F;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"cuFFT only supports transforms of type float16, float32 and "
"float64"));
}
#endif
// disable auto allocation of workspace to use allocator from the framework
CUFFT_CHECK(platform::dynload::cufftSetAutoAllocation(
plan(), /* autoAllocate */ 0));
size_t ws_size_t;
// make plan
#ifdef __HIPCC__
CUFFT_CHECK(hipfftMakePlanMany(
plan(), signal_ndim, signal_sizes.data(),
/* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1,
/* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, exec_type,
batch, &ws_size_t));
#else
CUFFT_CHECK(platform::dynload::cufftXtMakePlanMany(
plan(), signal_ndim, signal_sizes.data(),
/* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype,
/* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype,
batch, &ws_size_t, exec_type));
#endif
ws_size = ws_size_t;
}
const cufftHandle& plan() const { return plan_ptr.get(); }
FFTTransformType transform_type() const { return fft_type_; }
ScalarType data_type() const { return value_type_; }
size_t workspace_size() const { return ws_size; }
private:
CuFFTHandle plan_ptr;
size_t ws_size;
FFTTransformType fft_type_;
ScalarType value_type_;
};
// Execute a pre-planned transform
static void exec_cufft_plan(const CuFFTConfig& config, void* in_data,
void* out_data, bool forward) {
auto& plan = config.plan();
#ifdef __HIPCC__
auto value_type = config.data_type();
if (value_type == framework::proto::VarType::FP32) {
switch (config.transform_type()) {
case FFTTransformType::C2C: {
CUFFT_CHECK(hipfftExecC2C(plan, static_cast<hipfftComplex*>(in_data),
static_cast<hipfftComplex*>(out_data),
forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD));
return;
}
case FFTTransformType::R2C: {
CUFFT_CHECK(hipfftExecR2C(plan, static_cast<hipfftReal*>(in_data),
static_cast<hipfftComplex*>(out_data)));
return;
}
case FFTTransformType::C2R: {
CUFFT_CHECK(hipfftExecC2R(plan, static_cast<hipfftComplex*>(in_data),
static_cast<hipfftReal*>(out_data)));
return;
}
}
} else if (value_type == framework::proto::VarType::FP64) {
switch (config.transform_type()) {
case FFTTransformType::C2C: {
CUFFT_CHECK(hipfftExecZ2Z(plan,
static_cast<hipfftDoubleComplex*>(in_data),
static_cast<hipfftDoubleComplex*>(out_data),
forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD));
return;
}
case FFTTransformType::R2C: {
CUFFT_CHECK(hipfftExecD2Z(plan, static_cast<hipfftDoubleReal*>(in_data),
static_cast<hipfftDoubleComplex*>(out_data)));
return;
}
case FFTTransformType::C2R: {
CUFFT_CHECK(hipfftExecZ2D(plan,
static_cast<hipfftDoubleComplex*>(in_data),
static_cast<hipfftDoubleReal*>(out_data)));
return;
}
}
}
PADDLE_THROW(platform::errors::InvalidArgument(
"hipFFT only supports transforms of type float32 and float64"));
#else
CUFFT_CHECK(platform::dynload::cufftXtExec(
plan, in_data, out_data, forward ? CUFFT_FORWARD : CUFFT_INVERSE));
#endif
}
// Execute a general unnormalized fft operation (can be c2c, onesided r2c or
// onesided c2r)
template <typename DeviceContext, typename Ti, typename To>
void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out,
const std::vector<int64_t>& dim, bool forward) {
const auto x_dims = framework::vectorize(X->dims());
const auto out_dims = framework::vectorize(out->dims());
const int64_t ndim = static_cast<int64_t>(X->dims().size());
const int64_t signal_ndim = static_cast<int64_t>(dim.size());
const int64_t batch_dims = ndim - signal_ndim;
auto tensor_place = ctx.GetPlace();
// Permute the axes so that batch dimensions come first, followed by the
// dimensions being transformed
std::vector<int> dim_permute(ndim);
std::vector<int> reverse_dim_permute(ndim);
std::vector<int64_t> trans_dims(ndim);
std::iota(dim_permute.begin(), dim_permute.end(), int{0});
std::vector<bool> is_transformed_dim(ndim);
for (const auto& d : dim) {
is_transformed_dim[d] = true;
}
auto batch_end =
std::partition(dim_permute.begin(), dim_permute.end(),
[&](int64_t d) { return !is_transformed_dim[d]; });
std::sort(dim_permute.begin(), batch_end);
std::copy(dim.cbegin(), dim.cend(), batch_end);
for (int64_t i = 0; i < ndim; i++) {
trans_dims[i] = x_dims[dim_permute[i]]; // shape of input transpose
reverse_dim_permute[dim_permute[i]] =
static_cast<int>(i); // reverse of dim permute
}
framework::Tensor input;
input.Resize(framework::make_ddim(trans_dims));
input.mutable_data<Ti>(tensor_place);
/*
auto in_ret = TransposeSimple<Ti>::run(ctx, *X, dim_permute, input);
if (!in_ret) {
TransCompute<DeviceContext, Ti>(ndim, ctx, *X, input, dim_permute);
}
*/
TransCompute<DeviceContext, Ti>(ndim, ctx, *X, &input, dim_permute);
// Reshape batch dimensions into a single dimension
std::vector<int64_t> batched_sizes(signal_ndim + 1);
// accumulate in int64_t so that large batch sizes do not overflow int
auto batch_size =
std::accumulate(trans_dims.begin(), trans_dims.begin() + batch_dims,
static_cast<int64_t>(1), std::multiplies<int64_t>());
batched_sizes[0] = batch_size;
std::copy(trans_dims.begin() + batch_dims, trans_dims.end(),
batched_sizes.begin() + 1);
input.Resize(framework::make_ddim(batched_sizes));
// Check the shape of transforming dims with input and output
std::vector<int64_t> signal_size(signal_ndim + 1);
signal_size[0] = batch_size;
for (int64_t i = 0; i < signal_ndim; ++i) {
auto in_size = input.dims()[i + 1];
auto out_size = out_dims[dim[i]];
signal_size[i + 1] = std::max(in_size, out_size);
PADDLE_ENFORCE_EQ(
(in_size == signal_size[i + 1] ||
in_size == (signal_size[i + 1] / 2) + 1),
true,
platform::errors::InvalidArgument(
"The dimension[%d] of Input size: [%d] must be equal to or half of "
"the dimension[%d] of Output size: [%d]",
dim[i], in_size, dim[i], out_size));
PADDLE_ENFORCE_EQ(
(out_size == signal_size[i + 1] ||
out_size == (signal_size[i + 1] / 2) + 1),
true,
platform::errors::InvalidArgument(
"The dimension[%d] of Output size: [%d] must be equal to or half of "
"the dimension[%d] of Input size: [%d]",
dim[i], out_size, dim[i], in_size));
}
std::vector<int64_t> reshape_out_sizes(ndim);
for (int64_t i = 0; i < ndim; ++i) {
reshape_out_sizes[i] = out_dims[dim_permute[i]];
}
std::vector<int64_t> batched_out_sizes(batched_sizes.begin(),
batched_sizes.end());
for (size_t i = 0; i < dim.size(); ++i) {
batched_out_sizes[i + 1] = out_dims[dim[i]];
}
// output
framework::Tensor output;
output.Resize(framework::make_ddim(batched_out_sizes));
output.mutable_data<To>(tensor_place);
// Create the transform plan (either from cache or locally)
const auto value_type = framework::IsComplexType(input.type())
? framework::ToRealType(input.type())
: input.type();
auto fft_type = GetFFTTransformType(input.type(), output.type());
PlanKey Key(framework::vectorize(input.dims()),
framework::vectorize(output.dims()), signal_size, fft_type,
value_type);
CuFFTConfig uncached_plan(Key);
CuFFTConfig* config = &uncached_plan;
auto& plan = config->plan();
// prepare cufft for execution
CUFFT_CHECK(platform::dynload::cufftSetStream(plan, ctx.stream()));
framework::Tensor workspace_tensor;
workspace_tensor.mutable_data<To>(tensor_place, config->workspace_size());
CUFFT_CHECK(
platform::dynload::cufftSetWorkArea(plan, workspace_tensor.data<To>()));
// execute transform plan
if (fft_type == FFTTransformType::C2R && forward) {
forward = false;
framework::Tensor input_conj(input.type());
input_conj.mutable_data<Ti>(input.dims(), ctx.GetPlace());
platform::ForRange<DeviceContext> for_range(ctx, input.numel());
math::ConjFunctor<Ti> functor(input.data<Ti>(), input.numel(),
input_conj.data<Ti>());
for_range(functor);
exec_cufft_plan(*config, input_conj.data<void>(), output.data<void>(),
forward);
} else if (fft_type == FFTTransformType::R2C && !forward) {
forward = true;
framework::Tensor out_conj(output.type());
out_conj.mutable_data<To>(output.dims(), ctx.GetPlace());
exec_cufft_plan(*config, input.data<void>(), out_conj.data<void>(),
forward);
platform::ForRange<DeviceContext> for_range(ctx, output.numel());
math::ConjFunctor<To> functor(out_conj.data<To>(), output.numel(),
output.data<To>());
for_range(functor);
} else {
exec_cufft_plan(*config, input.data<void>(), output.data<void>(), forward);
}
// Restore the original layout: reshape the output to the transposed shape,
// then apply the inverse permutation
output.Resize(framework::make_ddim(reshape_out_sizes));
out->Resize(framework::make_ddim(out_dims));
TransCompute<DeviceContext, To>(ndim, ctx, output, out, reverse_dim_permute);
}
// Calculates the normalization constant
double fft_normalization_scale(FFTNormMode normalization,
const std::vector<int64_t>& sizes,
const std::vector<int64_t>& dims) {
// auto norm = static_cast<fft_norm_mode>(normalization);
if (normalization == FFTNormMode::none) {
return static_cast<double>(1.0);
}
int64_t signal_numel = 1;
for (auto dim : dims) {
signal_numel *= sizes[dim];
}
const double scale_denom = (normalization == FFTNormMode::by_sqrt_n)
? std::sqrt(signal_numel)
: static_cast<double>(signal_numel);
return static_cast<double>(1.0 / scale_denom);
}
template <typename DeviceContext, typename T>
void exec_normalization(const DeviceContext& ctx, const Tensor* in, Tensor* out,
FFTNormMode normalization,
const std::vector<int64_t>& sizes,
const std::vector<int64_t>& axes) {
double scale = fft_normalization_scale(normalization, sizes, axes);
if (scale != 1.0) {
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
auto dev = ctx.eigen_device();
EigenScale<Eigen::GpuDevice, T>::Eval(*dev, eigen_out, eigen_in,
static_cast<T>(scale),
static_cast<T>(0), false);
} else {
framework::TensorCopy(*in, ctx.GetPlace(), out);
}
}
} // anonymous namespace
// Use the optimized path to perform a single R2C or C2R transform if the
// transform dimensions are supported by cuFFT
bool use_optimized_cufft_path(const std::vector<int64_t>& axes) {
// For performance reasons, when axes start with (0, 1), do not use the
// optimized path.
if (axes.size() > kMaxCUFFTNdim ||
(axes.size() >= 2 && axes[0] == 0 && axes[1] == 1)) {
return false;
} else {
return true;
}
}
template <typename Ti, typename To>
struct FFTC2CFunctor<platform::CUDADeviceContext, Ti, To> {
void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {
if (axes.empty()) {
framework::TensorCopy(*X, ctx.GetPlace(), out);
return;
}
framework::Tensor* p_out = out;
std::vector<int64_t> out_dims = framework::vectorize(X->dims());
std::vector<int64_t> working_axes(axes.begin(), axes.end());
std::vector<int64_t> first_dims;
size_t max_dims;
framework::Tensor working_tensor;
working_tensor.mutable_data<Ti>(X->dims(), ctx.GetPlace());
framework::Tensor* p_working_tensor = &working_tensor;
framework::TensorCopy(*X, ctx.GetPlace(), &working_tensor);
while (true) {
max_dims =
std::min(static_cast<size_t>(kMaxCUFFTNdim), working_axes.size());
first_dims.assign(working_axes.end() - max_dims, working_axes.end());
exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, p_working_tensor,
p_out, first_dims, forward);
working_axes.resize(working_axes.size() - max_dims);
first_dims.clear();
if (working_axes.empty()) {
break;
}
std::swap(p_out, p_working_tensor);
}
exec_normalization<platform::CUDADeviceContext, To>(
ctx, p_out, out, normalization, out_dims, axes);
}
};
template <typename Ti, typename To>
struct FFTC2RFunctor<platform::CUDADeviceContext, Ti, To> {
void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {
std::vector<int64_t> in_dims = framework::vectorize(X->dims());
std::vector<int64_t> out_dims = framework::vectorize(out->dims());
if (use_optimized_cufft_path(axes)) {
framework::Tensor x_copy(X->type());
x_copy.mutable_data<Ti>(X->dims(), ctx.GetPlace());
framework::TensorCopy(*X, ctx.GetPlace(), &x_copy);
exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, &x_copy, out, axes,
forward);
} else {
framework::Tensor temp_tensor;
temp_tensor.mutable_data<Ti>(X->dims(), ctx.GetPlace());
const std::vector<int64_t> dims(axes.begin(), axes.end() - 1);
FFTC2CFunctor<platform::CUDADeviceContext, Ti, Ti> c2c_functor;
c2c_functor(ctx, X, &temp_tensor, dims, FFTNormMode::none, forward);
exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, &temp_tensor, out,
{axes.back()}, forward);
}
exec_normalization<platform::CUDADeviceContext, To>(
ctx, out, out, normalization, out_dims, axes);
}
};
// N-dimensional real-to-complex FFT using the cuFFT library
template <typename Ti, typename To>
struct FFTR2CFunctor<platform::CUDADeviceContext, Ti, To> {
void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {
// Step1: R2C transform on the last dimension
framework::Tensor* r2c_out = out;
const std::vector<int64_t> last_dim{axes.back()};
std::vector<int64_t> out_dims = framework::vectorize(out->dims());
exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, X, r2c_out, last_dim,
forward);
// Step2: C2C transform on the remaining dimensions
framework::Tensor c2c_out;
if (axes.size() > 1) {
c2c_out.mutable_data<To>(out->dims(), ctx.GetPlace());
std::vector<int64_t> remain_dim(axes.begin(), axes.end() - 1);
FFTC2CFunctor<platform::CUDADeviceContext, To, To> fft_c2c_func;
fft_c2c_func(ctx, r2c_out, &c2c_out, remain_dim, FFTNormMode::none,
forward);
}
const auto in_sizes = framework::vectorize(X->dims());
framework::Tensor* norm_tensor = axes.size() > 1 ? &c2c_out : r2c_out;
exec_normalization<platform::CUDADeviceContext, To>(
ctx, norm_tensor, out, normalization, in_sizes, axes);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
fft_c2c, ops::FFTC2CKernel<paddle::platform::CUDADeviceContext, float>,
ops::FFTC2CKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
fft_c2c_grad,
ops::FFTC2CGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::FFTC2CGradKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
fft_c2r, ops::FFTC2RKernel<paddle::platform::CUDADeviceContext, float>,
ops::FFTC2RKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
fft_c2r_grad,
ops::FFTC2RGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::FFTC2RGradKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
fft_r2c, ops::FFTR2CKernel<paddle::platform::CUDADeviceContext, float>,
ops::FFTR2CKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
fft_r2c_grad,
ops::FFTR2CGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::FFTR2CGradKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#define NOMINMAX // to use std::min std::max correctly on windows
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/conj_op.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/math/padding.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/for_range.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "thrust/device_vector.h"
#endif
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
enum class FFTNormMode : int64_t {
none, // No normalization
by_sqrt_n, // Divide by sqrt(signal_size)
by_n, // Divide by signal_size
};
FFTNormMode get_norm_from_string(const std::string& norm, bool forward);
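Only the declaration of `get_norm_from_string` appears here, so its body is not shown. The sketch below is an assumption about how such a mapping could look if it follows NumPy's "backward" / "ortho" / "forward" convention; `NormMode`, `NormFromString` and `NormScale` are hypothetical stand-ins, not names from this commit.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <stdexcept>
#include <string>

// Stand-in for FFTNormMode so the sketch builds on its own.
enum class NormMode : int64_t { none, by_sqrt_n, by_n };

// Assumed mapping (NumPy convention): "backward" scales only the inverse
// transform, "forward" scales only the forward one, "ortho" scales both.
NormMode NormFromString(const std::string& norm, bool forward) {
  if (norm.empty() || norm == "backward") {
    return forward ? NormMode::none : NormMode::by_n;
  }
  if (norm == "forward") {
    return forward ? NormMode::by_n : NormMode::none;
  }
  if (norm == "ortho") {
    return NormMode::by_sqrt_n;
  }
  throw std::invalid_argument("unknown norm: " + norm);
}

// Scale factor applied to every output element for a given signal size.
double NormScale(NormMode mode, int64_t signal_size) {
  assert(signal_size > 0);
  const double n = static_cast<double>(signal_size);
  switch (mode) {
    case NormMode::none:
      return 1.0;
    case NormMode::by_sqrt_n:
      return 1.0 / std::sqrt(n);
    case NormMode::by_n:
      return 1.0 / n;
  }
  return 1.0;  // unreachable for valid enum values
}
```

With this assumed mapping, a forward transform with "backward" is left unscaled, and the matching inverse is divided by the signal size, so the round trip returns the original signal.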
// Enum representing the FFT type
enum class FFTTransformType : int64_t {
C2C = 0, // Complex-to-complex
R2C, // Real-to-complex
C2R, // Complex-to-real
};
// Derive the transform type from the input and output dtypes, based on whether
// each one is complex
inline FFTTransformType GetFFTTransformType(
framework::proto::VarType::Type input_dtype,
framework::proto::VarType::Type output_dtype) {
auto complex_input = framework::IsComplexType(input_dtype);
auto complex_output = framework::IsComplexType(output_dtype);
if (complex_input && complex_output) {
return FFTTransformType::C2C;
} else if (complex_input && !complex_output) {
return FFTTransformType::C2R;
} else if (!complex_input && complex_output) {
return FFTTransformType::R2C;
}
PADDLE_THROW(
platform::errors::InvalidArgument("Real to real FFTs are not supported"));
}
// Returns true if the transform type has complex input
inline bool has_complex_input(FFTTransformType type) {
switch (type) {
case FFTTransformType::C2C:
case FFTTransformType::C2R:
return true;
case FFTTransformType::R2C:
return false;
}
PADDLE_THROW(platform::errors::InvalidArgument("Unknown FFTTransformType"));
}
// Returns true if the transform type has complex output
inline bool has_complex_output(FFTTransformType type) {
switch (type) {
case FFTTransformType::C2C:
case FFTTransformType::R2C:
return true;
case FFTTransformType::C2R:
return false;
}
PADDLE_THROW(platform::errors::InvalidArgument("Unknown FFTTransformType"));
}
template <typename T>
struct FFTFillConjGradFunctor {
T* input_;
const size_t axis_;
const int64_t* strides_;
const size_t double_length_;
FFTFillConjGradFunctor(T* input, size_t axis, const int64_t* strides,
size_t double_length)
: input_(input),
axis_(axis),
strides_(strides),
double_length_(double_length) {}
HOSTDEVICE void operator()(size_t index) {
size_t offset = index;
size_t index_i = 0;  // coordinate of `index` along axis_
for (size_t i = 0; i <= axis_; i++) {
index_i = offset / strides_[i];
offset %= strides_[i];
}
if ((0 < index_i) && (index_i < double_length_ + 1)) {
input_[index] *= static_cast<T>(2);
}
}
};
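The functor above doubles the gradient entries whose coordinate along the last transform axis lies in `1..double_length`. A plausible reading is that each of those one-sided bins also stands for its dropped conjugate mirror in the full spectrum. The 1-D sketch below illustrates the index range only; `DoubleMirroredBins` is a hypothetical name and `std::complex` replaces Paddle's complex type.

```cpp
#include <cassert>
#include <complex>
#include <cstddef>
#include <vector>

// Doubles the bins of a one-sided gradient that have a mirrored counterpart
// in the full spectrum. `full_length` is the size of the real signal along
// the transformed axis, so the one-sided size is full_length / 2 + 1.
void DoubleMirroredBins(std::vector<std::complex<double>>* grad,
                        std::size_t full_length) {
  assert(grad->size() == full_length / 2 + 1);
  // Same quantity as double_length in the C2R grad kernel:
  // full size minus one-sided size.
  const std::size_t double_length = full_length - grad->size();
  for (std::size_t k = 1; k <= double_length; ++k) {
    (*grad)[k] *= 2.0;
  }
}
```

For a real signal of length 4 only bin 1 is doubled (bin 0 and the Nyquist bin 2 have no separate mirror); for length 5, bins 1 and 2 are doubled.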
template <typename DeviceContext, typename Ti, typename To>
struct FFTC2CFunctor {
void operator()(const DeviceContext& ctx, const Tensor* X, Tensor* out,
const std::vector<int64_t>& axes, FFTNormMode normalization,
bool forward);
};
template <typename DeviceContext, typename Ti, typename To>
struct FFTR2CFunctor {
void operator()(const DeviceContext& ctx, const Tensor* X, Tensor* out,
const std::vector<int64_t>& axes, FFTNormMode normalization,
bool forward);
};
template <typename DeviceContext, typename Ti, typename To>
struct FFTC2RFunctor {
void operator()(const DeviceContext& ctx, const Tensor* X, Tensor* out,
const std::vector<int64_t>& axes, FFTNormMode normalization,
bool forward);
};
// Given a linear destination index and the strides of the tensors,
// get_src_idx returns the corresponding linear position in the source tensor.
// A linear index is the position of an element in the flattened tensor.
HOSTDEVICE inline int64_t get_src_idx(const int64_t dst_idx,
const int64_t* dst_strides,
const int64_t* dst_shape,
const int64_t* src_strides,
const bool* is_fft_axis, const bool conj,
const int64_t rank) {
int64_t src_idx = 0;
int64_t quotient = dst_idx;
int64_t remainder = 0;
for (int64_t i = 0; i < rank; i++) {
remainder = quotient % dst_strides[i];
quotient = quotient / dst_strides[i];
if (conj && is_fft_axis[i]) {
src_idx += ((dst_shape[i] - quotient) % dst_shape[i]) * src_strides[i];
} else {
src_idx += src_strides[i] * quotient;
}
quotient = remainder;
}
return src_idx;
}
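A small worked example may make the index mapping above easier to follow. This is a host-only sketch of the same arithmetic (`HOSTDEVICE` is dropped so it builds without Paddle); `SrcIndex` and `ExampleSrcIndex` are hypothetical names, and the 2 x 4 destination / 2 x 3 source layout is chosen purely for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Host-only restatement of get_src_idx: peel off one coordinate per axis
// from the destination index, mirror it on FFT axes when `conj` is set,
// and accumulate the source offset.
inline int64_t SrcIndex(int64_t dst_idx, const int64_t* dst_strides,
                        const int64_t* dst_shape, const int64_t* src_strides,
                        const bool* is_fft_axis, bool conj, int64_t rank) {
  int64_t src_idx = 0;
  int64_t rest = dst_idx;
  for (int64_t i = 0; i < rank; ++i) {
    const int64_t coord = rest / dst_strides[i];
    rest %= dst_strides[i];
    const int64_t src_coord = (conj && is_fft_axis[i])
                                  ? (dst_shape[i] - coord) % dst_shape[i]
                                  : coord;
    src_idx += src_coord * src_strides[i];
  }
  return src_idx;
}

// Illustrative layout: destination is 2 x 4 (full spectrum), source is
// 2 x 3 (one-sided), and only axis 1 is an FFT axis.
inline int64_t ExampleSrcIndex(int64_t dst_idx, bool conj) {
  const int64_t dst_strides[] = {4, 1};
  const int64_t dst_shape[] = {2, 4};
  const int64_t src_strides[] = {3, 1};
  const bool is_fft_axis[] = {false, true};
  return SrcIndex(dst_idx, dst_strides, dst_shape, src_strides, is_fft_axis,
                  conj, 2);
}
```

In this layout, destination element 5 (row 1, column 1) is a plain copy of source element 4, and destination element 7 (row 1, column 3) is the conjugate of that same source element, because column 3 mirrors to column (4 - 3) % 4 = 1.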
HOSTDEVICE inline bool is_conj_part(const int64_t dst_idx,
const int64_t* dst_strides,
const int64_t last_axis,
const int64_t last_axis_size) {
int64_t quotient = dst_idx;
int64_t remainder = 0;
for (int64_t i = 0; i < last_axis + 1; i++) {
remainder = quotient % dst_strides[i];
quotient = quotient / dst_strides[i];
if ((i == last_axis) && (quotient > last_axis_size - 1)) {
return true;
}
quotient = remainder;
}
return false;
}
// FFTFillConjFunctor fills the destination tensor with the source tensor and
// the conjugate-symmetric elements of the source tensor.
// It is driven by platform::ForRange, which iterates over the destination
// elements on whichever device is in use.
template <typename C>
struct FFTFillConjFunctor {
FFTFillConjFunctor(const C* src_data, C* dst_data, const int64_t* src_strides,
const int64_t* dst_strides, const int64_t* dst_shape,
const bool* is_fft_axis, const int64_t last_axis,
const int64_t last_axis_size, const int64_t rank)
: src_data_(src_data),
dst_data_(dst_data),
src_strides_(src_strides),
dst_strides_(dst_strides),
dst_shape_(dst_shape),
is_fft_axis_(is_fft_axis),
last_axis_(last_axis),
last_axis_size_(last_axis_size),
rank_(rank) {}
HOSTDEVICE void operator()(int64_t dst_idx) {
if (is_conj_part(dst_idx, dst_strides_, last_axis_, last_axis_size_)) {
const auto conj_idx =
get_src_idx(dst_idx, dst_strides_, dst_shape_, src_strides_,
is_fft_axis_, true, rank_);
auto src_value = src_data_[conj_idx];
auto conj_value = C(src_value.real, -src_value.imag);
dst_data_[dst_idx] = conj_value;
} else {
const auto copy_idx =
get_src_idx(dst_idx, dst_strides_, dst_shape_, src_strides_,
is_fft_axis_, false, rank_);
dst_data_[dst_idx] = src_data_[copy_idx];
}
}
const C* src_data_;
C* dst_data_;
const int64_t* src_strides_;
const int64_t* dst_strides_;
const int64_t* dst_shape_;
const bool* is_fft_axis_;
const int64_t last_axis_;
const int64_t last_axis_size_;
const int64_t rank_;
};
template <typename DeviceContext, typename C>
void fill_conj(const DeviceContext& ctx, const Tensor* src, Tensor* dst,
const std::vector<int64_t>& axes) {
std::vector<int64_t> src_strides_v =
framework::vectorize<int64_t>(framework::stride(src->dims()));
std::vector<int64_t> dst_strides_v =
framework::vectorize<int64_t>(framework::stride(dst->dims()));
std::vector<int64_t> dst_shape_v = framework::vectorize<int64_t>(dst->dims());
const auto src_data = src->data<C>();
auto dst_data = dst->data<C>();
const auto last_axis = axes.back();
const auto last_axis_size = dst->dims().at(last_axis) / 2 + 1;
const int64_t rank = dst->dims().size();
auto _is_fft_axis = std::make_unique<bool[]>(rank);
for (const auto i : axes) {
_is_fft_axis[i] = true;
}
#if defined(__NVCC__) || defined(__HIPCC__)
const thrust::device_vector<int64_t> src_strides_g(src_strides_v);
const auto src_strides = thrust::raw_pointer_cast(src_strides_g.data());
const thrust::device_vector<int64_t> dst_strides_g(dst_strides_v);
const auto dst_strides = thrust::raw_pointer_cast(dst_strides_g.data());
const thrust::device_vector<int64_t> dst_shape_g(dst_shape_v);
const auto dst_shape = thrust::raw_pointer_cast(dst_shape_g.data());
const thrust::device_vector<bool> is_fft_axis_g(_is_fft_axis.get(),
_is_fft_axis.get() + rank);
const auto p_is_fft_axis = thrust::raw_pointer_cast(is_fft_axis_g.data());
#else
const auto src_strides = src_strides_v.data();
const auto dst_strides = dst_strides_v.data();
const auto dst_shape = dst_shape_v.data();
const auto p_is_fft_axis = _is_fft_axis.get();
#endif
platform::ForRange<DeviceContext> for_range(ctx, dst->numel());
FFTFillConjFunctor<C> fill_conj_functor(src_data, dst_data, src_strides,
dst_strides, dst_shape, p_is_fft_axis,
last_axis, last_axis_size, rank);
for_range(fill_conj_functor);
}
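In one dimension, the effect of `fill_conj` reduces to rebuilding a full spectrum from its one-sided half through conjugate symmetry. The sketch below shows only that 1-D case on the host; `FillConj1D` is a hypothetical helper and `std::complex` is used in place of Paddle's complex type.

```cpp
#include <cassert>
#include <complex>
#include <cstddef>
#include <vector>

// Rebuilds a length-n spectrum of a real signal from its one-sided half
// (n / 2 + 1 bins). Bins past the one-sided range are the conjugates of
// their mirrored counterparts: full[k] = conj(onesided[n - k]).
std::vector<std::complex<double>> FillConj1D(
    const std::vector<std::complex<double>>& onesided, std::size_t n) {
  assert(onesided.size() == n / 2 + 1);
  std::vector<std::complex<double>> full(n);
  for (std::size_t k = 0; k < n; ++k) {
    full[k] = (k < onesided.size()) ? onesided[k] : std::conj(onesided[n - k]);
  }
  return full;
}
```

For n = 4 and one-sided input `{a0, a1, a2}`, the result would be `{a0, a1, a2, conj(a1)}`; the N-dimensional functor above applies the same mirroring on every FFT axis.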
template <typename DeviceContext, typename T>
class FFTC2CKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using C = paddle::platform::complex<T>;
auto& dev_ctx = ctx.device_context<DeviceContext>();
auto axes = ctx.Attr<std::vector<int64_t>>("axes");
const std::string& norm_str = ctx.Attr<std::string>("normalization");
const bool forward = ctx.Attr<bool>("forward");
const auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Output<Tensor>("Out");
y->mutable_data<C>(ctx.GetPlace());
auto normalization = get_norm_from_string(norm_str, forward);
FFTC2CFunctor<DeviceContext, C, C> fft_c2c_func;
fft_c2c_func(dev_ctx, x, y, axes, normalization, forward);
}
};
template <typename DeviceContext, typename T>
class FFTC2CGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using C = paddle::platform::complex<T>;
auto& dev_ctx = ctx.device_context<DeviceContext>();
auto axes = ctx.Attr<std::vector<int64_t>>("axes");
const std::string& norm_str = ctx.Attr<std::string>("normalization");
const bool forward = ctx.Attr<bool>("forward");
const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
dx->mutable_data<C>(ctx.GetPlace());
auto normalization = get_norm_from_string(norm_str, forward);
FFTC2CFunctor<DeviceContext, C, C> fft_c2c_func;
fft_c2c_func(dev_ctx, dy, dx, axes, normalization, !forward);
}
};
template <typename DeviceContext, typename T>
class FFTR2CKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using C = paddle::platform::complex<T>;
auto& dev_ctx = ctx.device_context<DeviceContext>();
auto axes = ctx.Attr<std::vector<int64_t>>("axes");
const std::string& norm_str = ctx.Attr<std::string>("normalization");
const bool forward = ctx.Attr<bool>("forward");
const bool onesided = ctx.Attr<bool>("onesided");
const auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Output<Tensor>("Out");
y->mutable_data<C>(ctx.GetPlace());
auto normalization = get_norm_from_string(norm_str, forward);
FFTR2CFunctor<DeviceContext, T, C> fft_r2c_func;
if (onesided) {
fft_r2c_func(dev_ctx, x, y, axes, normalization, forward);
} else {
framework::DDim onesided_dims(y->dims());
const int64_t onesided_last_axis_size = y->dims().at(axes.back()) / 2 + 1;
onesided_dims.at(axes.back()) = onesided_last_axis_size;
framework::Tensor onesided_out;
onesided_out.mutable_data<C>(onesided_dims, ctx.GetPlace());
fft_r2c_func(dev_ctx, x, &onesided_out, axes, normalization, forward);
fill_conj<DeviceContext, C>(dev_ctx, &onesided_out, y, axes);
}
}
};
template <typename DeviceContext, typename T>
class FFTR2CGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using C = paddle::platform::complex<T>;
auto& dev_ctx = ctx.device_context<DeviceContext>();
const auto axes = ctx.Attr<std::vector<int64_t>>("axes");
const std::string& norm_str = ctx.Attr<std::string>("normalization");
const bool forward = ctx.Attr<bool>("forward");
const bool onesided = ctx.Attr<bool>("onesided");
const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(ctx.GetPlace());
framework::Tensor complex_dx;
complex_dx.mutable_data<C>(dx->dims(), ctx.GetPlace());
auto normalization = get_norm_from_string(norm_str, forward);
FFTC2CFunctor<DeviceContext, C, C> fft_c2c_func;
if (!onesided) {
fft_c2c_func(dev_ctx, dy, &complex_dx, axes, normalization, !forward);
} else {
framework::Tensor full_dy;
full_dy.mutable_data<C>(dx->dims(), ctx.GetPlace());
auto zero_length = static_cast<int>(full_dy.dims().at(axes.back()) -
dy->dims().at(axes.back()));
auto rank = dy->dims().size();
std::vector<int> pads(rank * 2, 0);
pads[axes.back() * 2 + 1] = zero_length;
paddle::operators::math::PaddingFunctor<DeviceContext, C>(
rank, ctx, pads, static_cast<C>(0), *dy, &full_dy);
fft_c2c_func(dev_ctx, &full_dy, &complex_dx, axes, normalization,
!forward);
}
framework::TransComplexToReal(dx->type(), complex_dx.type(), complex_dx,
dx);
}
};
template <typename DeviceContext, typename T>
class FFTC2RKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using C = paddle::platform::complex<T>;
auto& dev_ctx = ctx.device_context<DeviceContext>();
auto axes = ctx.Attr<std::vector<int64_t>>("axes");
const std::string& norm_str = ctx.Attr<std::string>("normalization");
const bool forward = ctx.Attr<bool>("forward");
const auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Output<Tensor>("Out");
y->mutable_data<T>(ctx.GetPlace());
auto normalization = get_norm_from_string(norm_str, forward);
FFTC2RFunctor<DeviceContext, C, T> fft_c2r_func;
fft_c2r_func(dev_ctx, x, y, axes, normalization, forward);
}
};
template <typename DeviceContext, typename T>
class FFTC2RGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using C = paddle::platform::complex<T>;
auto& dev_ctx = ctx.device_context<DeviceContext>();
auto axes = ctx.Attr<std::vector<int64_t>>("axes");
const std::string& norm_str = ctx.Attr<std::string>("normalization");
const bool forward = ctx.Attr<bool>("forward");
const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
C* pdx = dx->mutable_data<C>(ctx.GetPlace());
auto normalization = get_norm_from_string(norm_str, forward);
FFTR2CFunctor<DeviceContext, T, C> fft_r2c_func;
fft_r2c_func(dev_ctx, dy, dx, axes, normalization, !forward);
const int64_t double_length =
dy->dims()[axes.back()] - dx->dims()[axes.back()];
const framework::DDim strides = framework::stride(dx->dims());
#if defined(__NVCC__) || defined(__HIPCC__)
const thrust::device_vector<int64_t> strides_g(
framework::vectorize(strides));
const int64_t* pstrides = thrust::raw_pointer_cast(strides_g.data());
#else
const int64_t* pstrides = strides.Get();
#endif
FFTFillConjGradFunctor<C> func(pdx, axes.back(), pstrides, double_length);
size_t limit = dx->numel();
platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
for_range(func);
}
};
} // namespace operators
} // namespace paddle
@@ -389,7 +389,11 @@ REGISTER_OP_CPU_KERNEL(
 ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int>,
 ops::SqueezeKernel<paddle::platform::CPUDeviceContext, uint8_t>,
 ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
-ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>);
+ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>,
+ops::SqueezeKernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<float>>,
+ops::SqueezeKernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
 squeeze_grad,
 ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, float>,
@@ -398,7 +402,12 @@ REGISTER_OP_CPU_KERNEL(
 ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int>,
 ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
 ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>,
-ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
+ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
+ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<float>>,
+ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
 squeeze2, ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, float>,
 ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, double>,
@@ -406,7 +415,12 @@ REGISTER_OP_CPU_KERNEL(
 ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int>,
 ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, uint8_t>,
 ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int8_t>,
-ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int64_t>);
+ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int64_t>,
+ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<float>>,
+ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
 squeeze2_grad,
 ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, float>,
@@ -415,4 +429,8 @@ REGISTER_OP_CPU_KERNEL(
 ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int>,
 ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
 ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
-ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>);
+ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>,
+ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<float>>,
+ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<double>>);
@@ -25,7 +25,11 @@ REGISTER_OP_CUDA_KERNEL(
 ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int>,
 ops::SqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>,
 ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
-ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
+ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>,
+ops::SqueezeKernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<float>>,
+ops::SqueezeKernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
 squeeze_grad,
 ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, float>,
@@ -35,7 +39,11 @@ REGISTER_OP_CUDA_KERNEL(
 ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
 ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
 ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
-ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
+ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
+ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<float>>,
+ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
 squeeze2, ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, float>,
 ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, double>,
@@ -44,7 +52,11 @@ REGISTER_OP_CUDA_KERNEL(
 ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int>,
 ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int8_t>,
 ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, uint8_t>,
-ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int64_t>);
+ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int64_t>,
+ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<float>>,
+ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
 squeeze2_grad,
 ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, float>,
@@ -54,4 +66,8 @@ REGISTER_OP_CUDA_KERNEL(
 ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
 ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
 ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
-ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
+ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>,
+ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<float>>,
+ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<double>>);
@@ -362,7 +362,11 @@ REGISTER_OP_CPU_KERNEL(
 ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
 ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, uint8_t>,
 ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
-ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>);
+ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>,
+ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<float>>,
+ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
 unsqueeze_grad,
 ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, float>,
@@ -371,7 +375,11 @@ REGISTER_OP_CPU_KERNEL(
 ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int>,
 ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
 ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>,
-ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
+ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
+ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<float>>,
+ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
 unsqueeze2, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>,
 ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
@@ -379,7 +387,11 @@ REGISTER_OP_CPU_KERNEL(
 ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
 ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, uint8_t>,
 ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
-ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>);
+ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>,
+ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<float>>,
+ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
 unsqueeze2_grad,
 ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, float>,
@@ -388,4 +400,8 @@ REGISTER_OP_CPU_KERNEL(
 ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int>,
 ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
 ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
-ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>);
+ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>,
+ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<float>>,
+ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext,
+paddle::platform::complex<double>>);
@@ -25,7 +25,11 @@ REGISTER_OP_CUDA_KERNEL(
 ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
 ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>,
 ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
-ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
+ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>,
+ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<float>>,
+ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
 unsqueeze_grad,
 ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, float>,
@@ -36,7 +40,11 @@ REGISTER_OP_CUDA_KERNEL(
 ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
 ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
 ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
-ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
+ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
+ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<float>>,
+ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
 unsqueeze2,
 ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>,
@@ -46,7 +54,11 @@ REGISTER_OP_CUDA_KERNEL(
 ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
 ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>,
 ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
-ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
+ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>,
+ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<float>>,
+ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
 unsqueeze2_grad,
 ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, float>,
@@ -57,4 +69,8 @@ REGISTER_OP_CUDA_KERNEL(
 ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
 ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
 ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
-ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
+ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>,
+ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<float>>,
+ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext,
+paddle::platform::complex<double>>);
@@ -60,6 +60,8 @@ struct PADDLE_ALIGN(sizeof(T) * 2) complex {
 T real;
 T imag;
+using value_type = T;
 complex() = default;
 complex(const complex<T>& o) = default;
 complex& operator=(const complex<T>& o) = default;
......
 cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce)
-list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc cusolver.cc cusparse.cc nvtx.cc)
+list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc cusolver.cc cusparse.cc nvtx.cc cufft.cc)
 if (NOT WITH_NV_JETSON)
 list(APPEND CUDA_SRCS nvjpeg.cc)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/dynload/cufft.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag cufft_dso_flag;
void* cufft_dso_handle = nullptr;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
CUFFT_FFT_ROUTINE_EACH(DEFINE_WRAP);
bool HasCUFFT() {
std::call_once(cufft_dso_flag,
[]() { cufft_dso_handle = GetCUFFTDsoHandle(); });
return cufft_dso_handle != nullptr;
}
void EnforceCUFFTLoaded(const char* fn_name) {
PADDLE_ENFORCE_NOT_NULL(
cufft_dso_handle,
platform::errors::PreconditionNotMet(
"Cannot load cufft shared library. Cannot invoke method %s.",
fn_name));
}
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_CUDA
#include <cufft.h>
#include <cufftXt.h>
#include <glog/logging.h>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
namespace paddle {
namespace platform {
namespace dynload {
extern std::once_flag cufft_dso_flag;
extern void* cufft_dso_handle;
extern bool HasCUFFT();
extern void EnforceCUFFTLoaded(const char* fn_name);
#define DECLARE_DYNAMIC_LOAD_CUFFT_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using cufft_func = decltype(&::__name); \
std::call_once(cufft_dso_flag, []() { \
cufft_dso_handle = paddle::platform::dynload::GetCUFFTDsoHandle(); \
}); \
EnforceCUFFTLoaded(#__name); \
static void* p_##__name = dlsym(cufft_dso_handle, #__name); \
return reinterpret_cast<cufft_func>(p_##__name)(args...); \
} \
}; \
extern struct DynLoad__##__name __name
/**
* Include all needed cuFFT functions in HPPL.
* Different cuFFT versions have different interfaces.
**/
#define CUFFT_FFT_ROUTINE_EACH(__macro) \
__macro(cufftPlan1d); \
__macro(cufftPlan2d); \
__macro(cufftPlan3d); \
__macro(cufftPlanMany); \
__macro(cufftMakePlan1d); \
__macro(cufftMakePlan2d); \
__macro(cufftMakePlan3d); \
__macro(cufftMakePlanMany); \
__macro(cufftMakePlanMany64); \
__macro(cufftGetSizeMany64); \
__macro(cufftEstimate1d); \
__macro(cufftEstimate2d); \
__macro(cufftEstimate3d); \
__macro(cufftEstimateMany); \
__macro(cufftCreate); \
__macro(cufftGetSize1d); \
__macro(cufftGetSize2d); \
__macro(cufftGetSize3d); \
__macro(cufftGetSizeMany); \
__macro(cufftGetSize); \
__macro(cufftSetWorkArea); \
__macro(cufftSetAutoAllocation); \
__macro(cufftExecC2C); \
__macro(cufftExecR2C); \
__macro(cufftExecC2R); \
__macro(cufftExecZ2Z); \
__macro(cufftExecD2Z); \
__macro(cufftExecZ2D); \
__macro(cufftSetStream); \
__macro(cufftDestroy); \
__macro(cufftGetVersion); \
__macro(cufftGetProperty); \
__macro(cufftXtSetGPUs); \
__macro(cufftXtMalloc); \
__macro(cufftXtMemcpy); \
__macro(cufftXtFree); \
__macro(cufftXtSetWorkArea); \
__macro(cufftXtExecDescriptorC2C); \
__macro(cufftXtExecDescriptorR2C); \
__macro(cufftXtExecDescriptorC2R); \
__macro(cufftXtExecDescriptorZ2Z); \
__macro(cufftXtExecDescriptorD2Z); \
__macro(cufftXtExecDescriptorZ2D); \
__macro(cufftXtQueryPlan); \
__macro(cufftXtSetCallback); \
__macro(cufftXtClearCallback); \
__macro(cufftXtSetCallbackSharedSize); \
__macro(cufftXtMakePlanMany); \
__macro(cufftXtGetSizeMany); \
__macro(cufftXtExec); \
__macro(cufftXtExecDescriptor); \
__macro(cufftXtSetWorkAreaPolicy);
CUFFT_FFT_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUFFT_WRAP)
} // namespace dynload
} // namespace platform
} // namespace paddle
#endif
@@ -109,6 +109,9 @@ static constexpr char* win_cusolver_lib =
 static constexpr char* win_cusparse_lib =
 "cusparse64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR
 ".dll;cusparse64_" CUDA_VERSION_MAJOR ".dll;cusparse64_10.dll";
+static constexpr char* win_cufft_lib =
+"cufft64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR
+".dll;cufft64_" CUDA_VERSION_MAJOR ".dll;cufft64_10.dll";
 #else
 static constexpr char* win_curand_lib =
 "curand64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR
@@ -122,6 +125,9 @@ static constexpr char* win_cusolver_lib =
 static constexpr char* win_cusparse_lib =
 "cusparse64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR
 ".dll;cusparse64_" CUDA_VERSION_MAJOR ".dll";
+static constexpr char* win_cufft_lib =
+"cufft64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR
+".dll;cufft64_" CUDA_VERSION_MAJOR ".dll";
 #endif // CUDA_VERSION
 #endif
@@ -489,6 +495,17 @@ void* GetNvtxDsoHandle() {
 #endif
 }
+void* GetCUFFTDsoHandle() {
+#if defined(__APPLE__) || defined(__OSX__)
+return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcufft.dylib");
+#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA)
+return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, win_cufft_lib, true,
+{cuda_lib_path});
+#else
+return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcufft.so");
+#endif
+}
 } // namespace dynload
 } // namespace platform
 } // namespace paddle
@@ -41,6 +41,7 @@ void* GetTensorRtDsoHandle();
 void* GetMKLMLDsoHandle();
 void* GetOpDsoHandle(const std::string& dso_name);
 void* GetNvtxDsoHandle();
+void* GetCUFFTDsoHandle();
 void SetPaddleLibPath(const std::string&);
 } // namespace dynload
......
@@ -64,6 +64,7 @@ import paddle.reader # noqa: F401
 import paddle.static # noqa: F401
 import paddle.vision # noqa: F401
+from .tensor import fft
 from .tensor.random import bernoulli # noqa: F401
 from .tensor.attribute import rank # noqa: F401
......
@@ -6727,8 +6727,10 @@ def pad(x, paddings, pad_value=0., name=None):
 x = fluid.data(name='data', shape=[300, 300], dtype='float32')
 out = fluid.layers.pad(x=x, paddings=[0, 1, 1, 2], pad_value=0.)
 """
-check_variable_and_dtype(
-x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], "pad")
+check_variable_and_dtype(x, 'x', [
+'float16', 'float32', 'float64', 'int32', 'int64', 'complex64',
+'complex128'
+], "pad")
 helper = LayerHelper('pad', **locals())
 dtype = helper.input_dtype(input_param_name='x')
......
@@ -702,6 +702,7 @@ endif()
 add_subdirectory(sequence)
 add_subdirectory(dygraph_to_static)
 add_subdirectory(rnn)
+add_subdirectory(fft)
 if (WITH_XPU)
 add_subdirectory(xpu)
......
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from functools import partial
from numpy import asarray
from numpy.fft._pocketfft import _raw_fft, _raw_fftnd, _get_forward_norm, _get_backward_norm, _cook_nd_args
def _fftc2c(a, n=None, axis=-1, norm=None, forward=None):
a = asarray(a)
if n is None:
n = a.shape[axis]
if forward:
inv_norm = _get_forward_norm(n, norm)
else:
inv_norm = _get_backward_norm(n, norm)
output = _raw_fft(a, n, axis, False, forward, inv_norm)
return output
def _fftr2c(a, n=None, axis=-1, norm=None, forward=None):
a = asarray(a)
if n is None:
n = a.shape[axis]
if forward:
inv_norm = _get_forward_norm(n, norm)
else:
inv_norm = _get_backward_norm(n, norm)
output = _raw_fft(a, n, axis, True, True, inv_norm)
if not forward:
output = output.conj()
return output
def _fftc2r(a, n=None, axis=-1, norm=None, forward=None):
a = asarray(a)
if n is None:
n = (a.shape[axis] - 1) * 2
if forward:
inv_norm = _get_forward_norm(n, norm)
else:
inv_norm = _get_backward_norm(n, norm)
output = _raw_fft(a.conj()
if forward else a, n, axis, True, False, inv_norm)
return output
def fft_c2c(x, axes, normalization, forward):
f = partial(_fftc2c, forward=forward)
y = _raw_fftnd(x, s=None, axes=axes, function=f, norm=normalization)
return y
def fft_c2c_backward(dy, axes, normalization, forward):
f = partial(_fftc2c, forward=forward)
dx = _raw_fftnd(dy, s=None, axes=axes, function=f, norm=normalization)
return dx
def fft_r2c(x, axes, normalization, forward, onesided):
a = asarray(x)
s, axes = _cook_nd_args(a, axes=axes)
if onesided:
a = _fftr2c(a, s[-1], axes[-1], normalization, forward)
for ii in range(len(axes) - 1):
a = _fftc2c(a, s[ii], axes[ii], normalization, forward)
else:
a = fft_c2c(x, axes, normalization, forward)
return a
def fft_r2c_backward(dy, x, axes, normalization, forward, onesided):
a = dy
if not onesided:
a = fft_c2c_backward(a, axes, normalization, forward).real
else:
pad_widths = [(0, 0)] * a.ndim
last_axis = axes[-1]
if last_axis < 0:
last_axis += a.ndim
last_dim_size = a.shape[last_axis]
pad_widths[last_axis] = (0, x.shape[last_axis] - last_dim_size)
a = np.pad(a, pad_width=pad_widths)
a = fft_c2c_backward(a, axes, normalization, forward).real
return a
def fft_c2r(x, axes, normalization, forward, last_dim_size):
a = asarray(x)
s, axes = _cook_nd_args(a, axes=axes, invreal=1)
if last_dim_size is not None:
s[-1] = last_dim_size
for ii in range(len(axes) - 1):
a = _fftc2c(a, s[ii], axes[ii], normalization, forward)
a = _fftc2r(a, s[-1], axes[-1], normalization, forward)
return a
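The reference helpers above rely on the conjugate symmetry of the spectrum of a real signal: a onesided real-to-complex transform keeps only the first `n // 2 + 1` bins along the last transformed axis, and the complex-to-real path assumes an output length of `(m - 1) * 2` when no size is given. A minimal sketch of both relations follows. It uses only NumPy's public `numpy.fft` functions rather than the private `_pocketfft` helpers the file imports, and the variable names are illustrative, not taken from the commit.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(8)      # real input, n = 8

full = np.fft.fft(x)            # two-sided spectrum, length 8
half = np.fft.rfft(x)           # onesided spectrum, length 8 // 2 + 1 = 5

# The onesided result is the leading slice of the two-sided one.
onesided_matches = np.allclose(half, full[: x.size // 2 + 1])

# The dropped bins are conjugates of kept ones: X[n - k] == conj(X[k]).
hermitian = np.allclose(full[5:], np.conj(full[1:4][::-1]))

# Default output length used when inverting a onesided spectrum of m bins.
default_len = (half.shape[-1] - 1) * 2
roundtrip = np.fft.irfft(half)  # real signal of length default_len
```

Because the dropped bins carry no extra information, the backward helper for the onesided case can zero-pad the gradient back to the full length before running the complex transform, as `fft_r2c_backward` does with `np.pad`.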
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import re
import sys
import unittest
import numpy as np
import paddle
import scipy.fft
DEVICES = [paddle.CPUPlace()]
if paddle.is_compiled_with_cuda():
DEVICES.append(paddle.CUDAPlace(0))
TEST_CASE_NAME = 'suffix'
# All test cases use float64 when comparing precision; see:
# https://github.com/PaddlePaddle/Paddle/wiki/Upgrade-OP-Precision-to-Float64
RTOL = {
'float32': 1e-03,
'complex64': 1e-3,
'float64': 1e-7,
'complex128': 1e-7
}
ATOL = {'float32': 0.0, 'complex64': 0, 'float64': 0.0, 'complex128': 0}
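With `ATOL` set to zero, the comparison in these tests is purely relative. `numpy.allclose(a, b, rtol, atol)` checks `|a - b| <= atol + rtol * |b|` element-wise, so an element whose reference value is exactly zero only passes if the other value is exactly zero too. A small sketch of that behavior, with made-up arrays chosen for illustration:

```python
import numpy as np

expected = np.array([0.0, 1.0])
actual = np.array([1e-12, 1.0 + 1e-9])

# Relative-only check: the first element fails because rtol * |0.0| == 0.
strict = np.allclose(actual, expected, rtol=1e-7, atol=0.0)

# A small absolute tolerance absorbs round-off around zero.
relaxed = np.allclose(actual, expected, rtol=1e-7, atol=1e-10)
```

Random Gaussian inputs rarely produce exact zeros in a spectrum, which is presumably why a zero absolute tolerance is workable here; that reading is an inference, not something the commit states.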
def rand_x(dims=1,
dtype='float64',
min_dim_len=1,
max_dim_len=10,
complex=False):
shape = [np.random.randint(min_dim_len, max_dim_len) for i in range(dims)]
if complex:
return np.random.randn(*shape).astype(dtype) + 1.j * np.random.randn(
*shape).astype(dtype)
else:
return np.random.randn(*shape).astype(dtype)
def place(devices, key='place'):
def decorate(cls):
module = sys.modules[cls.__module__].__dict__
raw_classes = {
k: v
for k, v in module.items() if k.startswith(cls.__name__)
}
for raw_name, raw_cls in raw_classes.items():
for d in devices:
test_cls = dict(raw_cls.__dict__)
test_cls.update({key: d})
new_name = raw_name + '.' + d.__class__.__name__
module[new_name] = type(new_name, (raw_cls, ), test_cls)
del module[raw_name]
return cls
return decorate
def parameterize(fields, values=None):
fields = [fields] if isinstance(fields, str) else fields
params = [dict(zip(fields, vals)) for vals in values]
def decorate(cls):
test_cls_module = sys.modules[cls.__module__].__dict__
for k, v in enumerate(params):
test_cls = dict(cls.__dict__)
test_cls.update(v)
name = cls.__name__ + str(k)
name = name + '.' + v.get('suffix') if v.get('suffix') else name
test_cls_module[name] = type(name, (cls, ), test_cls)
for m in list(cls.__dict__):
if m.startswith("test"):
delattr(cls, m)
return cls
return decorate
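The `parameterize` decorator above builds one subclass per parameter tuple with `type(...)`, registers it in the defining module's namespace so that unittest discovery can find it, and removes the `test*` methods from the original class so the un-parameterized class does not run. A stripped-down sketch of the same idea follows. It assumes a hypothetical helper named `make_cases` that returns the generated classes instead of writing them into `sys.modules`, and it keeps the base class out of `unittest.TestCase` rather than deleting its test methods.

```python
import unittest


def make_cases(base, fields, values):
    """Return one TestCase subclass per tuple in `values` (hypothetical helper)."""
    cases = []
    for index, vals in enumerate(values):
        attrs = dict(zip(fields, vals))  # field name -> value for this case
        name = f"{base.__name__}{index}"
        # The subclass inherits the test methods; attrs become class attributes.
        cases.append(type(name, (base, unittest.TestCase), attrs))
    return cases


class AddCheck:
    # Deliberately not a TestCase, so only the generated subclasses are run.
    def test_add(self):
        self.assertEqual(self.a + self.b, self.expected)


generated = make_cases(AddCheck, ("a", "b", "expected"), [(1, 2, 3), (2, 2, 4)])

suite = unittest.TestSuite()
for case in generated:
    suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(case))
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

Note that `zip(fields, vals)` silently drops surplus values, so a parameter tuple longer than the field list does not raise an error in either version.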
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'),
[('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'),
('test_x_complex', rand_x(
5, complex=True), None, -1,
'backward'), ('test_n_grater_input_length', rand_x(
5, max_dim_len=5), 11, -1,
'backward'), ('test_n_smaller_than_input_length', rand_x(
5, min_dim_len=5, complex=True), 3, -1, 'backward'),
('test_axis_not_last', rand_x(5), None, 3, 'backward'),
('test_norm_forward', rand_x(5), None, 3, 'forward'),
('test_norm_ortho', rand_x(5), None, 3, 'ortho')])
class TestFft(unittest.TestCase):
def test_fft(self):
with paddle.fluid.dygraph.guard(self.place):
self.assertTrue(
np.allclose(
scipy.fft.fft(self.x, self.n, self.axis, self.norm),
paddle.fft.fft(
paddle.to_tensor(self.x), self.n, self.axis, self.norm),
rtol=RTOL.get(str(self.x.dtype)),
atol=ATOL.get(str(self.x.dtype))))
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [
('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError),
('test_n_zero', rand_x(2), 0, -1, 'backward', ValueError),
('test_axis_out_of_range', rand_x(1), None, 10, 'backward', ValueError),
('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', ValueError),
('test_norm_not_in_enum_value', rand_x(2), None, -1, 'random', ValueError)
])
class TestFftException(unittest.TestCase):
def test_Fft(self):
with self.assertRaises(self.expect_exception):
paddle.fft.fft(
paddle.to_tensor(self.x), self.n, self.axis, self.norm)
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
('test_x_float64', rand_x(5), None, (0, 1), 'backward'),
('test_x_complex128', rand_x(
5, complex=True), None, (0, 1), 'backward'),
('test_n_grater_input_length', rand_x(
5, max_dim_len=5), (6, 6), (0, 1), 'backward'),
('test_n_smaller_than_input_length', rand_x(
5, min_dim_len=5, complex=True), (4, 4), (0, 1), 'backward'),
('test_axis_random', rand_x(5), None, (1, 2), 'backward'),
('test_axis_none', rand_x(5), None, None, 'backward'),
('test_norm_forward', rand_x(5), None, (0, 1), 'forward'),
('test_norm_ortho', rand_x(5), None, (0, 1), 'ortho'),
])
class TestFft2(unittest.TestCase):
def test_Fft2(self):
with paddle.fluid.dygraph.guard(self.place):
self.assertTrue(
np.allclose(
scipy.fft.fft2(self.x, self.n, self.axis, self.norm),
paddle.fft.fft2(
paddle.to_tensor(self.x), self.n, self.axis, self.norm),
rtol=RTOL.get(str(self.x.dtype)),
atol=ATOL.get(str(self.x.dtype))))
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[('test_x_complex_input', rand_x(
2, complex=True), None, (0, 1), None,
ValueError), ('test_x_1dim_tensor', rand_x(1), None, (0, 1), None,
ValueError), ('test_n_nagative', rand_x(2), -1, (0, 1),
'backward', ValueError),
('test_n_len_not_equal_axis', rand_x(
5, max_dim_len=5), 11, (0, 1), 'backward',
ValueError), ('test_n_zero', rand_x(2), (0, 0), (0, 1), 'backward',
ValueError), ('test_axis_out_of_range', rand_x(2), None,
(0, 1, 2), 'backward', ValueError),
('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', ValueError),
('test_axis_not_sequence', rand_x(5), None, -10, 'backward', ValueError),
('test_norm_not_enum', rand_x(2), None, -1, 'random', ValueError)])
class TestFft2Exception(unittest.TestCase):
def test_fft2(self):
with paddle.fluid.dygraph.guard(self.place):
with self.assertRaises(self.expect_exception):
paddle.fft.fft2(
paddle.to_tensor(self.x), self.n, self.axis, self.norm)
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'),
[('test_x_float64', rand_x(5, np.float64), None, None, 'backward'),
('test_x_complex128', rand_x(
5, complex=True), None, None,
'backward'), ('test_n_grater_input_length', rand_x(
5, max_dim_len=5), (6, 6), (1, 2), 'backward'), (
'test_n_smaller_input_length', rand_x(
5, min_dim_len=5, complex=True), (3, 3), (1, 2), 'backward'),
('test_axis_not_default', rand_x(5), None, (1, 2),
'backward'), ('test_norm_forward', rand_x(5), None, None, 'forward'),
('test_norm_ortho', rand_x(5), None, None, 'ortho')])
class TestFftn(unittest.TestCase):
def test_Fftn(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
scipy.fft.fftn(self.x, self.n, self.axis, self.norm),
paddle.fft.fftn(
paddle.to_tensor(self.x), self.n, self.axis, self.norm),
rtol=RTOL.get(str(self.x.dtype)),
atol=ATOL.get(str(self.x.dtype)))
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
('test_x_complex128',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)
).astype(np.complex128), None, -1, "backward"),
('test_n_grater_than_input_length',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 4, -1,
"backward"),
('test_n_smaller_than_input_length',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 2, -1,
"backward"),
('test_axis_not_last',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, 1,
"backward"),
('test_norm_forward',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, 1,
"forward"),
('test_norm_ortho',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, -1,
"ortho"),
])
class TestHfft(unittest.TestCase):
"""Test hfft with norm condition
"""
def test_hfft(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
scipy.fft.hfft(self.x, self.n, self.axis, self.norm),
paddle.fft.hfft(
paddle.to_tensor(self.x), self.n, self.axis, self.norm),
rtol=1e-5,
atol=0)
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
('test_x_complex128',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)
).astype(np.complex128), None, -1, "backward"),
('test_n_grater_than_input_length',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 4, -1,
"backward"),
('test_n_smaller_than_input_length',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 2, -1,
"backward"),
('test_axis_not_last',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, -1,
"backward"),
('test_norm_forward',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, -1,
"forward"),
('test_norm_ortho',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, -1,
"ortho"),
])
class TestIrfft(unittest.TestCase):
"""Test irfft with norm condition
"""
def test_irfft(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
scipy.fft.irfft(self.x, self.n, self.axis, self.norm),
paddle.fft.irfft(
paddle.to_tensor(self.x), self.n, self.axis, self.norm),
rtol=1e-5,
atol=0)
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
('test_x_complex128',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)
).astype(np.complex128), None, None, "backward"),
('test_n_grater_than_input_length',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [4], None,
"backward"),
('test_n_smaller_than_input_length',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [2], None,
"backward"),
('test_axis_not_last',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None,
"backward"),
('test_norm_forward',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None,
"forward"),
('test_norm_ortho',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None,
"ortho"),
])
class Testirfftn(unittest.TestCase):
"""Test irfftn with norm condition
"""
def test_irfftn(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
scipy.fft.irfftn(self.x, self.n, self.axis, self.norm),
paddle.fft.irfftn(
paddle.to_tensor(self.x), self.n, self.axis, self.norm),
rtol=1e-5,
atol=0)
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
('test_x_complex128',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)
).astype(np.complex128), None, None, "backward"),
('test_n_grater_than_input_length',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [4], None,
"backward"),
('test_n_smaller_than_input_length',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [2], None,
"backward"),
('test_axis_not_last',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None,
"backward"),
('test_norm_forward',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None,
"forward"),
('test_norm_ortho',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None,
"ortho"),
])
class Testhfftn(unittest.TestCase):
"""Test hfftn with norm condition
"""
def test_hfftn(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
scipy.fft.hfftn(self.x, self.n, self.axis, self.norm),
paddle.fft.hfftn(
paddle.to_tensor(self.x), self.n, self.axis, self.norm),
rtol=1e-5,
atol=0)
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 's', 'axis', 'norm'), [
('test_x_complex128',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)
).astype(np.complex128), None, (-2, -1), "backward"),
('test_with_s', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
[2, 2], (-2, -1), "backward", ValueError),
('test_axis_not_last',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1),
"backward"),
('test_norm_forward',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1),
"forward"),
('test_norm_ortho',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1),
"ortho"),
])
class Testhfft2(unittest.TestCase):
"""Test hfft2 with norm condition
"""
def test_hfft2(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
scipy.fft.hfft2(self.x, self.s, self.axis, self.norm),
paddle.fft.hfft2(
paddle.to_tensor(self.x), self.s, self.axis, self.norm),
rtol=1e-5,
atol=0)
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 's', 'axis', 'norm'), [
('test_x_complex128',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)
).astype(np.complex128), None, (-2, -1), "backward"),
('test_n_equal_input_length',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (4, 6), (-2, -1),
"backward"),
('test_axis_not_last',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1),
"backward"),
('test_norm_forward',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1),
"forward"),
('test_norm_ortho',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1),
"ortho"),
])
class TestIrfft2(unittest.TestCase):
"""Test irfft2 with norm condition
"""
def test_irfft2(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
scipy.fft.irfft2(self.x, self.s, self.axis, self.norm),
paddle.fft.irfft2(
paddle.to_tensor(self.x), self.s, self.axis, self.norm),
rtol=1e-5,
atol=0)
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [(
'test_bool_input',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)).astype(np.bool8),
None, -1, 'backward', NotImplementedError), (
'test_n_nagative',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), -1, -1,
'backward', ValueError), (
'test_n_zero', np.random.randn(4, 4) + 1j * np.random.randn(4, 4),
0, -1, 'backward', ValueError), (
'test_n_type',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
(1, 2, 3), -1, 'backward', ValueError), (
'test_axis_out_of_range',
np.random.randn(4) + 1j * np.random.randn(4), None, 10,
'backward', ValueError), (
'test_axis_with_array',
np.random.randn(4) + 1j * np.random.randn(4), None,
(0, 1), 'backward', ValueError), (
'test_norm_not_in_enum_value',
np.random.randn(4, 4) + 1j * np.random.randn(4, 4),
None, -1, 'random', ValueError)])
class TestHfftException(unittest.TestCase):
'''Test hfft with boundary conditions
Test cases include:
- n out of range
- axis out of range
- norm out of range
'''
def test_hfft(self):
with paddle.fluid.dygraph.guard(self.place):
with self.assertRaises(self.expect_exception):
paddle.fft.hfft(
paddle.to_tensor(self.x), self.n, self.axis, self.norm)
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[('test_n_nagative',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), -1, -1,
'backward', ValueError),
('test_n_zero', np.random.randn(4, 4) + 1j * np.random.randn(4, 4), 0, -1,
'backward', ValueError),
('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
(1, 2), -1, 'backward', ValueError),
('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4),
None, 10, 'backward', ValueError),
('test_axis_with_array', np.random.randn(4) + 1j * np.random.randn(4),
None, (0, 1), 'backward',
ValueError), ('test_norm_not_in_enum_value',
np.random.randn(4, 4) + 1j * np.random.randn(4, 4), None,
None, 'random', ValueError)])
class TestIrfftException(unittest.TestCase):
'''Test irfft with boundary conditions
Test cases include:
- n out of range
- axis out of range
- norm out of range
- the dimensions of n and axis are different
'''
def test_irfft(self):
with paddle.fluid.dygraph.guard(self.place):
with self.assertRaises(self.expect_exception):
paddle.fft.irfft(
paddle.to_tensor(self.x), self.n, self.axis, self.norm)
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[('test_bool_input',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)
).astype(np.bool8), None, (-2, -1), 'backward', NotImplementedError),
('test_n_nagative',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2),
(-2, -1), 'backward', ValueError),
('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
(0, 0), (-2, -1), 'backward', ValueError),
('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
3, None, 'backward', ValueError),
('test_n_axis_dim',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (1, 2), (-1),
'backward', ValueError), ('test_axis_out_of_range',
np.random.randn(4) + 1j * np.random.randn(4),
None, (1, 2), 'backward', ValueError),
('test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None, -1,
'backward',
ValueError), ('test_norm_not_in_enum_value',
np.random.randn(4, 4) + 1j * np.random.randn(4, 4), None,
None, 'random', ValueError)])
class TestHfft2Exception(unittest.TestCase):
    '''Test hfft2 with boundary conditions
    Test cases include:
    - n out of range
    - axis out of range
    - the dimensions of n and axis are different
    - norm out of range
    '''

    def test_hfft2(self):
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.hfft2(
                    paddle.to_tensor(self.x), self.n, self.axis, self.norm)
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[('test_n_nagative',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2),
(-2, -1), 'backward', ValueError),
('test_zero_point',
np.random.randn(4, 4, 1) + 1j * np.random.randn(4, 4, 1), None, (-2, -1),
"backward", ValueError),
('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
(0, 0), (-2, -1), 'backward', ValueError),
('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
3, -1, 'backward',
ValueError), ('test_n_axis_dim',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
(1, 2), (-3, -2, -1), 'backward', ValueError),
('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4),
None, (1, 2), 'backward', ValueError), (
'test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None,
1, 'backward',
ValueError), ('test_norm_not_in_enum_value',
np.random.randn(4, 4) + 1j * np.random.randn(4, 4),
None, None, 'random', ValueError)])
class TestIrfft2Exception(unittest.TestCase):
    '''Test irfft2 with boundary conditions
    Test cases include:
    - n out of range
    - axis out of range
    - norm out of range
    - the dimensions of n and axis are different
    '''

    def test_irfft2(self):
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.irfft2(
                    paddle.to_tensor(self.x), self.n, self.axis, self.norm)
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[('test_bool_input',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)
).astype(np.bool8), None, (-2, -1), 'backward', NotImplementedError),
('test_n_nagative',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2),
(-2, -1), 'backward', ValueError),
('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
(0, 0), (-2, -1), 'backward', ValueError),
('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
3, -1, 'backward', ValueError),
('test_n_axis_dim',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
(1, 2), (-3, -2, -1), 'backward',
ValueError), ('test_axis_out_of_range',
np.random.randn(4) + 1j * np.random.randn(4), None,
(10, 20), 'backward', ValueError),
('test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None, 1,
'backward',
ValueError), ('test_norm_not_in_enum_value',
np.random.randn(4, 4) + 1j * np.random.randn(4, 4), None,
None, 'random', ValueError)])
class TestHfftnException(unittest.TestCase):
    '''Test hfftn with boundary conditions
    Test cases include:
    - n out of range
    - axis out of range
    - norm out of range
    - the dimensions of n and axis are different
    '''

    def test_hfftn(self):
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.hfftn(
                    paddle.to_tensor(self.x), self.n, self.axis, self.norm)
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[('test_n_nagative',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2),
(-2, -1), 'backward', ValueError),
('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
(0, 0), (-2, -1), 'backward', ValueError),
('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
3, -1, 'backward',
ValueError), ('test_n_axis_dim',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
(1, 2), (-3, -2, -1), 'backward', ValueError),
('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4),
None, (10, 20), 'backward', ValueError),
('test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None, 1,
'backward',
ValueError), ('test_norm_not_in_enum_value',
np.random.randn(4, 4) + 1j * np.random.randn(4, 4), None,
None, 'random', ValueError)])
class TestIrfftnException(unittest.TestCase):
    '''Test irfftn with boundary conditions
    Test cases include:
    - n out of range
    - axis out of range
    - norm out of range
    - the dimensions of n and axis are different
    '''

    def test_irfftn(self):
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.irfftn(
                    paddle.to_tensor(self.x), self.n, self.axis, self.norm)
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'),
[('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'), (
'test_n_grater_than_input_length', rand_x(
5, max_dim_len=5), 11, -1, 'backward'),
('test_n_smaller_than_input_length', rand_x(
5, min_dim_len=5), 3, -1,
'backward'), ('test_axis_not_last', rand_x(5), None, 3, 'backward'),
('test_norm_forward', rand_x(5), None, 3, 'forward'),
('test_norm_ortho', rand_x(5), None, 3, 'ortho')])
class TestRfft(unittest.TestCase):
    def test_rfft(self):
        with paddle.fluid.dygraph.guard(self.place):
            self.assertTrue(
                np.allclose(
                    scipy.fft.rfft(self.x, self.n, self.axis, self.norm),
                    paddle.fft.rfft(
                        paddle.to_tensor(self.x), self.n, self.axis, self.norm),
                    rtol=RTOL.get(str(self.x.dtype)),
                    atol=ATOL.get(str(self.x.dtype))))
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [
('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError),
('test_n_zero', rand_x(2), 0, -1, 'backward', ValueError),
('test_axis_out_of_range', rand_x(1), None, 10, 'backward', ValueError),
('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', ValueError),
('test_norm_not_in_enum_value', rand_x(2), None, -1, 'random', ValueError)
])
class TestRfftException(unittest.TestCase):
    def test_rfft(self):
        with self.assertRaises(self.expect_exception):
            paddle.fft.rfft(
                paddle.to_tensor(self.x), self.n, self.axis, self.norm)
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
('test_x_float64', rand_x(5), None, (0, 1), 'backward'),
('test_n_grater_input_length', rand_x(
5, max_dim_len=5), (6, 6), (0, 1), 'backward'),
('test_n_smaller_than_input_length', rand_x(
5, min_dim_len=5), (4, 4), (0, 1), 'backward'),
('test_axis_random', rand_x(5), None, (1, 2), 'backward'),
('test_axis_none', rand_x(5), None, None, 'backward'),
('test_norm_forward', rand_x(5), None, (0, 1), 'forward'),
('test_norm_ortho', rand_x(5), None, (0, 1), 'ortho'),
])
class TestRfft2(unittest.TestCase):
    def test_rfft2(self):
        with paddle.fluid.dygraph.guard(self.place):
            self.assertTrue(
                np.allclose(
                    scipy.fft.rfft2(self.x, self.n, self.axis, self.norm),
                    paddle.fft.rfft2(
                        paddle.to_tensor(self.x), self.n, self.axis, self.norm),
                    rtol=RTOL.get(str(self.x.dtype)),
                    atol=ATOL.get(str(self.x.dtype))))
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [
('test_x_complex_input', rand_x(
2, complex=True), None, (0, 1), 'backward', RuntimeError),
('test_x_1dim_tensor', rand_x(1), None, (0, 1), 'backward', ValueError),
('test_n_nagative', rand_x(2), -1, (0, 1), 'backward', ValueError),
('test_n_zero', rand_x(2), 0, (0, 1), 'backward', ValueError),
('test_axis_out_of_range', rand_x(2), None, (0, 1, 2), 'backward',
ValueError),
('test_axis_with_array', rand_x(1), None, (0, 1), 'backward',
ValueError),
('test_axis_not_sequence', rand_x(5), None, -10, 'backward',
ValueError),
('test_norm_not_enum', rand_x(2), None, -1, 'random', ValueError),
])
class TestRfft2Exception(unittest.TestCase):
    def test_rfft2(self):
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.rfft2(
                    paddle.to_tensor(self.x), self.n, self.axis, self.norm)
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
('test_x_float64', rand_x(5, np.float64), None, None, 'backward'),
('test_n_grater_input_length', rand_x(
5, max_dim_len=5), (6, 6), (1, 2), 'backward'),
('test_n_smaller_input_length', rand_x(
5, min_dim_len=5), (3, 3), (1, 2), 'backward'),
('test_axis_not_default', rand_x(5), None, (1, 2), 'backward'),
('test_norm_forward', rand_x(5), None, None, 'forward'),
('test_norm_ortho', rand_x(5), None, None, 'ortho'),
])
class TestRfftn(unittest.TestCase):
    def test_rfftn(self):
        with paddle.fluid.dygraph.guard(self.place):
            self.assertTrue(
                np.allclose(
                    scipy.fft.rfftn(self.x, self.n, self.axis, self.norm),
                    paddle.fft.rfftn(
                        paddle.to_tensor(self.x), self.n, self.axis, self.norm),
                    rtol=RTOL.get(str(self.x.dtype)),
                    atol=ATOL.get(str(self.x.dtype))))
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[('test_x_complex', rand_x(
4, complex=True), None, None, 'backward',
RuntimeError), ('test_n_nagative', rand_x(4), (-1, -1), (1, 2),
'backward', ValueError),
('test_n_not_sequence', rand_x(4), -1, None, 'backward', ValueError),
('test_n_zero', rand_x(4), 0, None, 'backward', ValueError), (
'test_axis_out_of_range', rand_x(1), None, [0, 1], 'backward',
ValueError),
('test_norm_not_in_enum', rand_x(2), None, -1, 'random', ValueError)])
class TestRfftnException(unittest.TestCase):
    def test_rfftn(self):
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.rfftn(
                    paddle.to_tensor(self.x), self.n, self.axis, self.norm)
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'),
[('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'), (
'test_n_grater_than_input_length', rand_x(
5, max_dim_len=5), 11, -1, 'backward'),
('test_n_smaller_than_input_length', rand_x(
5, min_dim_len=5), 3, -1,
'backward'), ('test_axis_not_last', rand_x(5), None, 3, 'backward'),
('test_norm_forward', rand_x(5), None, 3, 'forward'),
('test_norm_ortho', rand_x(5), None, 3, 'ortho')])
class TestIhfft(unittest.TestCase):
    def test_ihfft(self):
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.ihfft(self.x, self.n, self.axis, self.norm),
                paddle.fft.ihfft(
                    paddle.to_tensor(self.x), self.n, self.axis, self.norm),
                rtol=RTOL.get(str(self.x.dtype)),
                atol=ATOL.get(str(self.x.dtype)))
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [
('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError),
('test_n_zero', rand_x(2), 0, -1, 'backward', ValueError),
('test_axis_out_of_range', rand_x(1), None, 10, 'backward', ValueError),
('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', ValueError),
('test_norm_not_in_enum_value', rand_x(2), None, -1, 'random', ValueError)
])
class TestIhfftException(unittest.TestCase):
    def test_ihfft(self):
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.ihfft(
                    paddle.to_tensor(self.x), self.n, self.axis, self.norm)
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
('test_x_float64', rand_x(5), None, (0, 1), 'backward'),
('test_n_grater_input_length', rand_x(
5, max_dim_len=5), (11, 11), (0, 1), 'backward'),
('test_n_smaller_than_input_length', rand_x(
5, min_dim_len=5), (1, 1), (0, 1), 'backward'),
('test_axis_random', rand_x(5), None, (1, 2), 'backward'),
('test_axis_none', rand_x(5), None, None, 'backward'),
('test_norm_forward', rand_x(5), None, (0, 1), 'forward'),
('test_norm_ortho', rand_x(5), None, (0, 1), 'ortho'),
])
class TestIhfft2(unittest.TestCase):
    def test_ihfft2(self):
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.ihfft2(self.x, self.n, self.axis, self.norm),
                paddle.fft.ihfft2(
                    paddle.to_tensor(self.x), self.n, self.axis, self.norm),
                rtol=RTOL.get(str(self.x.dtype)),
                atol=ATOL.get(str(self.x.dtype)))
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[('test_x_complex_input', rand_x(
2, complex=True), None, (0, 1), None, ValueError),
('test_x_1dim_tensor', rand_x(1), None, (0, 1), None,
ValueError), ('test_n_nagative', rand_x(2), -1, (0, 1), 'backward',
ValueError), ('test_n_len_not_equal_axis', rand_x(
5, max_dim_len=5), 11, (0, 1), 'backward', ValueError),
('test_n_zero', rand_x(2), (0, 0), (0, 1), 'backward', ValueError),
('test_axis_out_of_range', rand_x(2), None, (0, 1, 2), 'backward',
ValueError), ('test_axis_with_array', rand_x(1), None, (0, 1), 'backward',
ValueError), ('test_axis_not_sequence', rand_x(5), None,
-10, 'backward', ValueError),
('test_norm_not_enum', rand_x(2), None, -1, 'random', ValueError)])
class TestIhfft2Exception(unittest.TestCase):
    def test_ihfft2(self):
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.ihfft2(
                    paddle.to_tensor(self.x), self.n, self.axis, self.norm)
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'),
[('test_x_float64', rand_x(5, np.float64), None, None, 'backward'),
('test_n_grater_input_length', rand_x(
5, max_dim_len=5), (11, 11), (0, 1),
'backward'), ('test_n_smaller_input_length', rand_x(
5, min_dim_len=5), (1, 1), (0, 1), 'backward'),
('test_axis_not_default', rand_x(5), None, (1, 2),
'backward'), ('test_norm_forward', rand_x(5), None, None, 'forward'),
('test_norm_ortho', rand_x(5), None, None, 'ortho')])
class TestIhfftn(unittest.TestCase):
    def test_ihfftn(self):
        with paddle.fluid.dygraph.guard(self.place):
            self.assertTrue(
                np.allclose(
                    scipy.fft.ihfftn(self.x, self.n, self.axis, self.norm),
                    paddle.fft.ihfftn(
                        paddle.to_tensor(self.x), self.n, self.axis, self.norm),
                    rtol=RTOL.get(str(self.x.dtype)),
                    atol=ATOL.get(str(self.x.dtype))))
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[('test_x_complex', rand_x(
4, complex=True), None, None, 'backward', RuntimeError),
('test_n_nagative', rand_x(4), -1, None, 'backward', ValueError),
('test_n_zero', rand_x(4), 0, None, 'backward', ValueError), (
'test_axis_out_of_range', rand_x(1), None, [0, 1], 'backward',
ValueError),
('test_norm_not_in_enum', rand_x(2), None, -1, 'random', ValueError)])
class TestIhfftnException(unittest.TestCase):
    def test_ihfftn(self):
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.ihfftn(
                    paddle.to_tensor(self.x), self.n, self.axis, self.norm)
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'n', 'd', 'dtype'), [
('test_without_d', 20, 1, 'float32'),
('test_with_d', 20, 0.5, 'float32'),
])
class TestFftFreq(unittest.TestCase):
    def test_fftfreq(self):
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.fftfreq(self.n, self.d).astype(self.dtype),
                paddle.fft.fftfreq(self.n, self.d, self.dtype).numpy(),
                rtol=RTOL.get(str(self.dtype)),
                atol=ATOL.get(str(self.dtype)))
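# The frequency bins checked by the fftfreq and rfftfreq tests in this file can
# be sketched in pure Python. This is only an illustrative reference, assuming
# the usual NumPy/SciPy bin ordering; the helper names `fftfreq_reference` and
# `rfftfreq_reference` are hypothetical and are not part of paddle or scipy.

```python
def fftfreq_reference(n, d=1.0):
    """Sample frequencies of an n-point DFT with sample spacing d."""
    if n <= 0:
        raise ValueError("n must be a positive integer")
    # Non-negative bins 0 .. (n - 1) // 2 come first, then the negative bins.
    bins = list(range((n - 1) // 2 + 1)) + list(range(-(n // 2), 0))
    return [k / (n * d) for k in bins]


def rfftfreq_reference(n, d=1.0):
    """Non-negative sample frequencies used by real-input transforms."""
    if n <= 0:
        raise ValueError("n must be a positive integer")
    # Only bins 0 .. n // 2 are kept for a real-input transform.
    return [k / (n * d) for k in range(n // 2 + 1)]
```

# With n=20 and d=0.5, as in 'test_with_d', each bin k maps to k / 10.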
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'n', 'd', 'dtype'), [
('test_without_d', 20, 1, 'float32'),
('test_with_d', 20, 0.5, 'float32'),
])
class TestRfftFreq(unittest.TestCase):
    def test_rfftfreq(self):
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.rfftfreq(self.n, self.d).astype(self.dtype),
                paddle.fft.rfftfreq(self.n, self.d, self.dtype).numpy(),
                rtol=RTOL.get(str(self.dtype)),
                atol=ATOL.get(str(self.dtype)))
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'axes', 'dtype'), [
('test_1d', np.random.randn(10), (0, ), 'float64'),
('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
])
class TestFftShift(unittest.TestCase):
    def test_fftshift(self):
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.fftshift(self.x, self.axes),
                paddle.fft.fftshift(paddle.to_tensor(self.x),
                                    self.axes).numpy(),
                rtol=RTOL.get(str(self.x.dtype)),
                atol=ATOL.get(str(self.x.dtype)))
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'axes', 'dtype'), [
('test_1d', np.random.randn(10), (0, ), 'float64'),
('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
])
class TestIfftShift(unittest.TestCase):
    def test_ifftshift(self):
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.ifftshift(self.x, self.axes),
                paddle.fft.ifftshift(paddle.to_tensor(self.x),
                                     self.axes).numpy(),
                rtol=RTOL.get(str(self.x.dtype)),
                atol=ATOL.get(str(self.x.dtype)))
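# What the fftshift/ifftshift tests in this file verify can be illustrated on a
# plain list: fftshift rotates the zero-frequency bin to the centre, and
# ifftshift undoes it. This is a one-dimensional sketch only; the helper names
# `fftshift_1d` and `ifftshift_1d` are hypothetical, not paddle or scipy APIs.

```python
def fftshift_1d(values):
    """Rotate right by len(values) // 2 so bin 0 lands in the centre."""
    n = len(values)
    shift = n // 2
    return values[n - shift:] + values[:n - shift]


def ifftshift_1d(values):
    """Rotate left by len(values) // 2; the inverse of fftshift_1d."""
    shift = len(values) // 2
    return values[shift:] + values[:shift]
```

# For odd lengths the two rotations differ, which is why both functions exist:
# applying fftshift_1d twice to a 5-element list does not restore it, while
# ifftshift_1d(fftshift_1d(v)) always does.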
if __name__ == '__main__':
    unittest.main()
# yapf: enable
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import re
import sys
import unittest
import numpy as np
import paddle
import scipy.fft
from test_fft import (ATOL, DEVICES, RTOL, TEST_CASE_NAME, parameterize, place,
rand_x)
@contextlib.contextmanager
def stgraph(func, place, x, n, axes, norm):
    """static graph exec context"""
    paddle.enable_static()
    try:
        mp, sp = paddle.static.Program(), paddle.static.Program()
        with paddle.static.program_guard(mp, sp):
            input = paddle.static.data('input', x.shape, dtype=x.dtype)
            output = func(input, n, axes, norm)
        exe = paddle.static.Executor(place)
        exe.run(sp)
        [output] = exe.run(mp, feed={'input': x}, fetch_list=[output])
        yield output
    finally:
        # Restore dynamic mode even when building or running the graph raises,
        # as the exception tests below expect it to.
        paddle.disable_static()
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'),
[('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'),
('test_x_complex64', rand_x(
5, np.float64, complex=True), None, -1,
'backward'), ('test_n_grater_than_input_length', rand_x(
5, max_dim_len=5), 11, -1,
'backward'), ('test_n_smaller_than_input_length', rand_x(
5, min_dim_len=5), 3, -1, 'backward'),
('test_axis_not_last', rand_x(5), None, 3, 'backward'),
('test_norm_forward', rand_x(5), None, 3, 'forward'),
('test_norm_ortho', rand_x(5), None, 3, 'ortho')])
class TestFft(unittest.TestCase):
    def test_static_fft(self):
        with stgraph(paddle.fft.fft, self.place, self.x, self.n, self.axis,
                     self.norm) as y:
            np.testing.assert_allclose(
                scipy.fft.fft(self.x, self.n, self.axis, self.norm),
                y,
                rtol=RTOL.get(str(self.x.dtype)),
                atol=ATOL.get(str(self.x.dtype)))
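# The 'backward', 'forward' and 'ortho' values exercised by the norm test cases
# differ only in how the forward transform is scaled. The naive O(n^2) DFT
# below is a sketch of that convention as documented for scipy.fft; it is not
# how paddle or scipy compute the transform, and `dft_reference` is a
# hypothetical helper name.

```python
import cmath
import math


def dft_reference(x, norm="backward"):
    """Naive forward DFT of a 1-D sequence with scipy-style norm scaling."""
    n = len(x)
    if n == 0:
        raise ValueError("input must not be empty")
    # Forward-transform scale for each normalization mode.
    scales = {"backward": 1.0, "ortho": 1.0 / math.sqrt(n), "forward": 1.0 / n}
    if norm not in scales:
        raise ValueError("norm must be 'backward', 'ortho' or 'forward'")
    scale = scales[norm]
    return [
        scale * sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n)
                    for t in range(n))
        for k in range(n)
    ]
```

# An unknown mode such as 'random' raises ValueError here, which mirrors what
# the 'test_norm_not_in_enum_value' cases expect from paddle.fft.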
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError),
('test_n_zero', rand_x(2), 0, -1, 'backward', ValueError),
('test_axis_out_of_range', rand_x(1), None, 10, 'backward',
ValueError), ('test_axis_with_array', rand_x(1), None, (0, 1), 'backward',
ValueError), ('test_norm_not_in_enum_value', rand_x(2),
None, -1, 'random', ValueError)])
class TestFftException(unittest.TestCase):
    def test_fft(self):
        with self.assertRaises(self.expect_exception):
            with stgraph(paddle.fft.fft, self.place, self.x, self.n, self.axis,
                         self.norm) as y:
                pass
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
('test_x_float64', rand_x(5), None, (0, 1), 'backward'),
('test_x_complex128', rand_x(
5, complex=True), None, (0, 1), 'backward'),
('test_n_grater_input_length', rand_x(
5, max_dim_len=5), (6, 6), (0, 1), 'backward'),
('test_n_smaller_than_input_length', rand_x(
5, min_dim_len=5), (4, 4), (0, 1), 'backward'),
('test_axis_random', rand_x(5), None, (1, 2), 'backward'),
('test_axis_none', rand_x(5), None, None, 'backward'),
('test_norm_forward', rand_x(5), None, (0, 1), 'forward'),
('test_norm_ortho', rand_x(5), None, (0, 1), 'ortho'),
])
class TestFft2(unittest.TestCase):
    def test_static_fft2(self):
        with stgraph(paddle.fft.fft2, self.place, self.x, self.n, self.axis,
                     self.norm) as y:
            np.testing.assert_allclose(
                scipy.fft.fft2(self.x, self.n, self.axis, self.norm),
                y,
                rtol=RTOL.get(str(self.x.dtype)),
                atol=ATOL.get(str(self.x.dtype)))
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[
# ('test_x_not_tensor', [0, 1], None, (0, 1), 'backward', ValueError),
('test_x_1dim_tensor', rand_x(1), None, (0, 1), 'backward', ValueError),
('test_n_nagative', rand_x(2), -1, (0, 1), 'backward', ValueError),
('test_n_zero', rand_x(2), 0, (0, 1), 'backward', ValueError),
('test_axis_out_of_range', rand_x(2), None, (0, 1, 2), 'backward',
ValueError),
('test_axis_with_array', rand_x(1), None, (0, 1), 'backward',
ValueError),
('test_axis_not_sequence', rand_x(5), None, -10, 'backward',
ValueError),
('test_norm_not_enum', rand_x(2), None, -1, 'random', ValueError)
])
class TestFft2Exception(unittest.TestCase):
    def test_static_fft2(self):
        with self.assertRaises(self.expect_exception):
            with stgraph(paddle.fft.fft2, self.place, self.x, self.n, self.axis,
                         self.norm) as y:
                pass
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'),
[('test_x_float64', rand_x(5, np.float64), None, None, 'backward'),
('test_x_complex128', rand_x(
5, np.float64, complex=True), None, None,
'backward'), ('test_n_grater_input_length', rand_x(
5, max_dim_len=5), (6, 6), (1, 2),
'backward'), ('test_n_smaller_input_length', rand_x(
5, min_dim_len=5), (3, 3), (1, 2), 'backward'),
('test_axis_not_default', rand_x(5), None, (1, 2),
'backward'), ('test_norm_forward', rand_x(5), None, None, 'forward'),
('test_norm_ortho', rand_x(5), None, None, 'ortho')])
class TestFftn(unittest.TestCase):
    def test_static_fftn(self):
        with stgraph(paddle.fft.fftn, self.place, self.x, self.n, self.axis,
                     self.norm) as y:
            np.testing.assert_allclose(
                scipy.fft.fftn(self.x, self.n, self.axis, self.norm),
                y,
                rtol=RTOL.get(str(self.x.dtype)),
                atol=ATOL.get(str(self.x.dtype)))
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[('test_x_complex', rand_x(
4, complex=True), None, None, 'backward',
TypeError), ('test_n_nagative', rand_x(4), (-1, -1), (1, 2), 'backward',
ValueError), ('test_n_not_sequence', rand_x(4), -1, None,
'backward', ValueError),
('test_n_zero', rand_x(4), 0, None, 'backward', ValueError),
('test_axis_out_of_range', rand_x(1), None, [0, 1], 'backward',
ValueError), ('test_norm_not_in_enum', rand_x(2), None, -1, 'random',
ValueError)])
class TestRfftnException(unittest.TestCase):
    def test_static_rfftn(self):
        with self.assertRaises(self.expect_exception):
            with stgraph(paddle.fft.rfftn, self.place, self.x, self.n,
                         self.axis, self.norm) as y:
                pass
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
('test_x_complex128',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)
).astype(np.complex128), None, -1, "backward"),
('test_n_grater_than_input_length',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 4, -1,
"backward"),
('test_n_smaller_than_input_length',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 2, -1,
"backward"),
('test_axis_not_last',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, 1,
"backward"),
('test_norm_forward',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, 1,
"forward"),
('test_norm_ortho',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, -1,
"ortho"),
])
class TestHfft(unittest.TestCase):
    """Test hfft with norm condition
    """

    def test_hfft(self):
        with stgraph(paddle.fft.hfft, self.place, self.x, self.n, self.axis,
                     self.norm) as y:
            np.testing.assert_allclose(
                scipy.fft.hfft(self.x, self.n, self.axis, self.norm),
                y,
                rtol=1e-5,
                atol=0)
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
('test_x_complex128',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)
).astype(np.complex128), None, -1, "backward"),
('test_n_grater_than_input_length',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 4, -1,
"backward"),
('test_n_smaller_than_input_length',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 2, -1,
"backward"),
('test_axis_not_last',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, -1,
"backward"),
('test_norm_forward',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, -1,
"forward"),
('test_norm_ortho',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, -1,
"ortho"),
])
class TestIrfft(unittest.TestCase):
    """Test irfft with norm condition
    """

    def test_irfft(self):
        with stgraph(paddle.fft.irfft, self.place, self.x, self.n, self.axis,
                     self.norm) as y:
            np.testing.assert_allclose(
                scipy.fft.irfft(self.x, self.n, self.axis, self.norm),
                y,
                rtol=1e-5,
                atol=0)
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
('test_x_complex128',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)
).astype(np.complex128), None, None, "backward"),
('test_n_grater_than_input_length',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [4], None,
"backward"),
('test_n_smaller_than_input_length',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [2], None,
"backward"),
('test_axis_not_last',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None,
"backward"),
('test_norm_forward',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None,
"forward"),
('test_norm_ortho',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None,
"ortho"),
])
class Testirfftn(unittest.TestCase):
    """Test irfftn with norm condition
    """

    def test_static_irfftn(self):
        with stgraph(paddle.fft.irfftn, self.place, self.x, self.n, self.axis,
                     self.norm) as y:
            np.testing.assert_allclose(
                scipy.fft.irfftn(self.x, self.n, self.axis, self.norm),
                y,
                rtol=1e-5,
                atol=0)
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
('test_x_complex128',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)
).astype(np.complex128), None, None, "backward"),
('test_n_grater_than_input_length',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [4], None,
"backward"),
('test_n_smaller_than_input_length',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [2], None,
"backward"),
('test_axis_not_last',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None,
"backward"),
('test_norm_forward',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None,
"forward"),
('test_norm_ortho',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None,
"ortho"),
])
class Testhfftn(unittest.TestCase):
    """Test hfftn with norm condition
    """

    def test_static_hfftn(self):
        with stgraph(paddle.fft.hfftn, self.place, self.x, self.n, self.axis,
                     self.norm) as y:
            np.testing.assert_allclose(
                scipy.fft.hfftn(self.x, self.n, self.axis, self.norm),
                y,
                rtol=1e-5,
                atol=0)
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 's', 'axis', 'norm'), [
('test_x_complex128',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)
).astype(np.complex128), None, (-2, -1), "backward"),
('test_n_grater_input_length',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [4, 8], (-2, -1),
"backward"),
('test_n_smaller_input_length',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [2, 4], (-2, -1),
"backward"),
('test_axis_not_last',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1),
"backward"),
('test_norm_forward',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1),
"forward"),
('test_norm_ortho',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1),
"ortho"),
])
class Testhfft2(unittest.TestCase):
    """Test hfft2 with norm condition
    """

    def test_static_hfft2(self):
        with stgraph(paddle.fft.hfft2, self.place, self.x, self.s, self.axis,
                     self.norm) as y:
            np.testing.assert_allclose(
                scipy.fft.hfft2(self.x, self.s, self.axis, self.norm),
                y,
                rtol=1e-5,
                atol=0)
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 's', 'axis', 'norm'), [
('test_x_complex128',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)
).astype(np.complex128), None, (-2, -1), "backward"),
('test_n_equal_input_length',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (2, 4), (-2, -1),
"backward"),
('test_axis_not_last',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1),
"backward"),
('test_norm_forward',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1),
"forward"),
('test_norm_ortho',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1),
"ortho"),
])
class TestIrfft2(unittest.TestCase):
    """Test irfft2 with norm condition
    """

    def test_static_irfft2(self):
        with stgraph(paddle.fft.irfft2, self.place, self.x, self.s, self.axis,
                     self.norm) as y:
            np.testing.assert_allclose(
                scipy.fft.irfft2(self.x, self.s, self.axis, self.norm),
                y,
                rtol=1e-5,
                atol=0)
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[('test_input_dtype', np.random.randn(4, 4, 4), None, -1, 'backward',
TypeError), ('test_bool_input',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)
).astype(np.bool8), None, -1, 'backward', TypeError),
('test_n_nagative',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), -1, -1,
'backward', ValueError),
('test_n_zero', np.random.randn(4, 4) + 1j * np.random.randn(4, 4), 0, -1,
'backward', ValueError),
('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
(1, 2, 3), -1, 'backward', ValueError),
('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4),
None, 10, 'backward', ValueError), (
'test_axis_with_array', np.random.randn(4) + 1j * np.random.randn(4),
None, (0, 1), 'backward',
ValueError), ('test_norm_not_in_enum_value',
np.random.randn(4, 4) + 1j * np.random.randn(4, 4),
None, -1, 'random', ValueError)])
class TestHfftException(unittest.TestCase):
    '''Test hfft with boundary conditions
    Test cases include:
    - non-complex input
    - n out of range
    - axis out of range
    - norm out of range
    '''

    def test_static_hfft(self):
        with self.assertRaises(self.expect_exception):
            with stgraph(paddle.fft.hfft, self.place, self.x, self.n, self.axis,
                         self.norm) as y:
                pass
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[('test_input_dtype', np.random.randn(4, 4, 4), None, -1, 'backward',
TypeError), ('test_bool_input',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)
).astype(np.bool8), None, -1, 'backward', TypeError),
('test_n_nagative',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), -1, -1,
'backward', ValueError),
('test_n_zero', np.random.randn(4, 4) + 1j * np.random.randn(4, 4), 0, -1,
'backward', ValueError),
('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
(1, 2), -1, 'backward', ValueError),
('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4),
None, 10, 'backward', ValueError), (
'test_axis_with_array', np.random.randn(4) + 1j * np.random.randn(4),
None, (0, 1), 'backward',
ValueError), ('test_norm_not_in_enum_value',
np.random.randn(4, 4) + 1j * np.random.randn(4, 4),
None, None, 'random', ValueError)])
class TestIrfftException(unittest.TestCase):
    '''Test irfft with boundary conditions
    Test cases include:
    - non-complex input
    - n out of range
    - axis out of range
    - norm out of range
    - the dimensions of n and axis are different
    '''

    def test_static_irfft(self):
        with self.assertRaises(self.expect_exception):
            with stgraph(paddle.fft.irfft, self.place, self.x, self.n,
                         self.axis, self.norm) as y:
                pass
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[('test_input_dtype', np.random.randn(4, 4, 4), None, None, 'backward',
TypeError), ('test_bool_input',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)
).astype(np.bool_), None, (-2, -1), 'backward', TypeError),
('test_n_nagative',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2),
(-2, -1), 'backward', ValueError),
('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
(0, 0), (-2, -1), 'backward', ValueError),
('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
3, None, 'backward',
ValueError), ('test_n_axis_dim',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
(1, 2), (-1), 'backward', ValueError),
('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4),
None, (1, 2), 'backward', ValueError), (
'test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None,
-1, 'backward',
ValueError), ('test_norm_not_in_enum_value',
np.random.randn(4, 4) + 1j * np.random.randn(4, 4),
None, None, 'random', ValueError)])
class TestHfft2Exception(unittest.TestCase):
'''Test hfft2 with boundary conditions
Test cases include:
- non complex input
- n out of range
- axis out of range
- the dimensions of n and axis are different
- norm out of range
'''
def test_static_hfft2(self):
with self.assertRaises(self.expect_exception):
with stgraph(paddle.fft.hfft2, self.place, self.x, self.n,
self.axis, self.norm) as y:
pass
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[('test_input_dtype', np.random.randn(4, 4, 4), None, None, 'backward',
TypeError), ('test_bool_input',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)
).astype(np.bool_), None, (-2, -1), 'backward', TypeError),
('test_n_nagative',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2),
(-2, -1), 'backward', ValueError),
('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
(0, 0), (-2, -1), 'backward', ValueError),
('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
3, -1, 'backward',
ValueError), ('test_n_axis_dim',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
(1, 2), (-3, -2, -1), 'backward', ValueError),
('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4),
None, (1, 2), 'backward', ValueError), (
'test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None,
1, 'backward',
ValueError), ('test_norm_not_in_enum_value',
np.random.randn(4, 4) + 1j * np.random.randn(4, 4),
None, None, 'random', ValueError)])
class TestIrfft2Exception(unittest.TestCase):
'''Test irfft2 with boundary conditions
Test cases include:
- non complex input
- n out of range
- axis out of range
- norm out of range
- the dimensions of n and axis are different
'''
def test_static_irfft2(self):
with self.assertRaises(self.expect_exception):
with stgraph(paddle.fft.irfft2, self.place, self.x, self.n,
self.axis, self.norm) as y:
pass
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[('test_input_dtype', np.random.randn(4, 4, 4), None, None, 'backward',
TypeError), ('test_bool_input',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)
).astype(np.bool_), None, (-2, -1), 'backward', TypeError),
('test_n_nagative',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2),
(-2, -1), 'backward', ValueError),
('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
(0, 0), (-2, -1), 'backward', ValueError),
('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
3, -1, 'backward',
ValueError), ('test_n_axis_dim',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
(1, 2), (-3, -2, -1), 'backward', ValueError),
('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4),
None, (10, 20), 'backward', ValueError), (
'test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None,
1, 'backward',
ValueError), ('test_norm_not_in_enum_value',
np.random.randn(4, 4) + 1j * np.random.randn(4, 4),
None, None, 'random', ValueError)])
class TestHfftnException(unittest.TestCase):
'''Test hfftn with boundary conditions
Test cases include:
- non complex input
- n out of range
- axis out of range
- norm out of range
- the dimensions of n and axis are different
'''
def test_static_hfftn(self):
with self.assertRaises(self.expect_exception):
with stgraph(paddle.fft.hfftn, self.place, self.x, self.n,
self.axis, self.norm) as y:
pass
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[
('test_input_dtype', np.random.randn(4, 4, 4), None, None, 'backward',
TypeError),
# ('test_bool_input',
# (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)
# ).astype(np.bool8), None, (-2, -1), 'backward', ValueError),
('test_n_nagative',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2),
(-2, -1), 'backward', ValueError),
('test_n_zero',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (0, 0),
(-2, -1), 'backward', ValueError),
('test_n_type',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 3, -1,
'backward', ValueError),
('test_n_axis_dim',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (1, 2),
(-3, -2, -1), 'backward', ValueError),
('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4),
None, (10, 20), 'backward', ValueError),
('test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None,
1, 'backward', ValueError),
('test_norm_not_in_enum_value',
np.random.randn(4, 4) + 1j * np.random.randn(4, 4), None, None,
'random', ValueError)
])
class TestIrfftnException(unittest.TestCase):
'''Test irfftn with boundary conditions
Test cases include:
- non complex input
- n out of range
- axis out of range
- norm out of range
- the dimensions of n and axis are different
'''
def test_static_irfftn(self):
with self.assertRaises(self.expect_exception):
with stgraph(paddle.fft.irfftn, self.place, self.x, self.n,
self.axis, self.norm) as y:
pass
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'),
[('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'), (
'test_n_grater_than_input_length', rand_x(
5, max_dim_len=5), 11, -1, 'backward'),
('test_n_smaller_than_input_length', rand_x(
5, min_dim_len=5), 3, -1,
'backward'), ('test_axis_not_last', rand_x(5), None, 3, 'backward'),
('test_norm_forward', rand_x(5), None, 3, 'forward'),
('test_norm_ortho', rand_x(5), None, 3, 'ortho')])
class TestRfft(unittest.TestCase):
def test_static_rfft(self):
with stgraph(paddle.fft.rfft, self.place, self.x, self.n, self.axis,
self.norm) as y:
np.testing.assert_allclose(
scipy.fft.rfft(self.x, self.n, self.axis, self.norm),
y,
rtol=RTOL.get(str(self.x.dtype)),
atol=ATOL.get(str(self.x.dtype)))
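# A minimal sketch, not part of the original test file, of the reference
# behaviour that the rfft cases above compare against. It assumes numpy.fft
# as a stand-in for scipy.fft (the two share the `n`, `axis` and `norm`
# arguments) and uses an arbitrary (3, 8) input chosen for illustration.

```python
import numpy as np

x = np.random.randn(3, 8)

# A real input of length n yields n // 2 + 1 complex bins along the axis.
y = np.fft.rfft(x, n=None, axis=-1, norm="backward")

# n larger than the input zero-pads; n smaller truncates before transforming.
y_padded = np.fft.rfft(x, n=11, axis=-1)
y_cropped = np.fft.rfft(x, n=3, axis=-1)

# "forward" moves the 1/n scaling onto the forward transform.
y_forward = np.fft.rfft(x, axis=-1, norm="forward")
```

# The three length cases mirror test_x_float64, test_n_grater_than_input_length
# and test_n_smaller_than_input_length in the parameter list above.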
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError),
('test_n_zero', rand_x(2), 0, -1, 'backward', ValueError),
('test_axis_out_of_range', rand_x(1), None, 10, 'backward',
ValueError), ('test_axis_with_array', rand_x(1), None, (0, 1), 'backward',
ValueError), ('test_norm_not_in_enum_value', rand_x(2),
None, -1, 'random', ValueError)])
class TestRfftException(unittest.TestCase):
def test_rfft(self):
with self.assertRaises(self.expect_exception):
with stgraph(paddle.fft.rfft, self.place, self.x, self.n, self.axis,
self.norm) as y:
pass
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
('test_x_float64', rand_x(5), None, (0, 1), 'backward'),
('test_n_grater_input_length', rand_x(
5, max_dim_len=5), (6, 6), (0, 1), 'backward'),
('test_n_smaller_than_input_length', rand_x(
5, min_dim_len=5), (4, 4), (0, 1), 'backward'),
('test_axis_random', rand_x(5), None, (1, 2), 'backward'),
('test_axis_none', rand_x(5), None, None, 'backward'),
('test_norm_forward', rand_x(5), None, (0, 1), 'forward'),
('test_norm_ortho', rand_x(5), None, (0, 1), 'ortho'),
])
class TestRfft2(unittest.TestCase):
def test_static_rfft2(self):
with stgraph(paddle.fft.rfft2, self.place, self.x, self.n, self.axis,
self.norm) as y:
np.testing.assert_allclose(
scipy.fft.rfft2(self.x, self.n, self.axis, self.norm),
y,
rtol=RTOL.get(str(self.x.dtype)),
atol=ATOL.get(str(self.x.dtype)))
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[
('test_x_complex_input', rand_x(
2, complex=True), None, (0, 1), 'backward', TypeError),
# ('test_x_not_tensor', [0, 1], None, (0, 1), 'backward', ValueError),
('test_x_1dim_tensor', rand_x(1), None, (0, 1), 'backward', ValueError),
('test_n_nagative', rand_x(2), -1, (0, 1), 'backward', ValueError),
('test_n_zero', rand_x(2), 0, (0, 1), 'backward', ValueError),
('test_axis_out_of_range', rand_x(2), None, (0, 1, 2), 'backward',
ValueError),
('test_axis_with_array', rand_x(1), None, (0, 1), 'backward',
ValueError),
('test_axis_not_sequence', rand_x(5), None, -10, 'backward',
ValueError),
('test_norm_not_enum', rand_x(2), None, -1, 'random', ValueError)
])
class TestRfft2Exception(unittest.TestCase):
def test_static_rfft2(self):
with self.assertRaises(self.expect_exception):
with stgraph(paddle.fft.rfft2, self.place, self.x, self.n,
self.axis, self.norm) as y:
pass
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'),
[('test_x_float64', rand_x(5, np.float64), None, None, 'backward'),
('test_n_grater_input_length', rand_x(
5, max_dim_len=5), (6, 6), (1, 2),
'backward'), ('test_n_smaller_input_length', rand_x(
5, min_dim_len=5), (3, 3), (1, 2), 'backward'),
('test_axis_not_default', rand_x(5), None, (1, 2),
'backward'), ('test_norm_forward', rand_x(5), None, None, 'forward'),
('test_norm_ortho', rand_x(5), None, None, 'ortho')])
class TestRfftn(unittest.TestCase):
def test_static_rfftn(self):
with stgraph(paddle.fft.rfftn, self.place, self.x, self.n, self.axis,
self.norm) as y:
np.testing.assert_allclose(
scipy.fft.rfftn(self.x, self.n, self.axis, self.norm),
y,
rtol=RTOL.get(str(self.x.dtype)),
atol=ATOL.get(str(self.x.dtype)))
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[('test_x_complex', rand_x(
4, complex=True), None, None, 'backward',
TypeError), ('test_n_nagative', rand_x(4), (-1, -1), (1, 2), 'backward',
ValueError), ('test_n_not_sequence', rand_x(4), -1, None,
'backward', ValueError),
('test_n_zero', rand_x(4), 0, None, 'backward', ValueError),
('test_axis_out_of_range', rand_x(1), None, [0, 1], 'backward',
ValueError), ('test_norm_not_in_enum', rand_x(2), None, -1, 'random',
ValueError)])
class TestRfftnException(unittest.TestCase):
def test_static_rfftn(self):
with self.assertRaises(self.expect_exception):
with stgraph(paddle.fft.rfftn, self.place, self.x, self.n,
self.axis, self.norm) as y:
pass
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'),
[('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'), (
'test_n_grater_than_input_length', rand_x(
5, max_dim_len=5), 11, -1, 'backward'),
('test_n_smaller_than_input_length', rand_x(
5, min_dim_len=5), 3, -1,
'backward'), ('test_axis_not_last', rand_x(5), None, 3, 'backward'),
('test_norm_forward', rand_x(5), None, 3, 'forward'),
('test_norm_ortho', rand_x(5), None, 3, 'ortho')])
class TestIhfft(unittest.TestCase):
def test_static_ihfft(self):
with stgraph(paddle.fft.ihfft, self.place, self.x, self.n, self.axis,
self.norm) as y:
np.testing.assert_allclose(
scipy.fft.ihfft(self.x, self.n, self.axis, self.norm),
y,
rtol=RTOL.get(str(self.x.dtype)),
atol=ATOL.get(str(self.x.dtype)))
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [
('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError),
('test_n_zero', rand_x(2), 0, -1, 'backward', ValueError),
('test_axis_out_of_range', rand_x(1), None, 10, 'backward', ValueError),
('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', ValueError),
('test_norm_not_in_enum_value', rand_x(2), None, -1, 'random', ValueError)
])
class TestIhfftException(unittest.TestCase):
def test_static_ihfft(self):
with self.assertRaises(self.expect_exception):
with stgraph(paddle.fft.ihfft, self.place, self.x, self.n,
self.axis, self.norm) as y:
pass
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
('test_x_float64', rand_x(5), None, (0, 1), 'backward'),
('test_n_grater_input_length', rand_x(
5, max_dim_len=5), (11, 11), (0, 1), 'backward'),
('test_n_smaller_than_input_length', rand_x(
5, min_dim_len=5), (1, 1), (0, 1), 'backward'),
('test_axis_random', rand_x(5), None, (1, 2), 'backward'),
('test_axis_none', rand_x(5), None, None, 'backward'),
('test_norm_forward', rand_x(5), None, (0, 1), 'forward'),
('test_norm_ortho', rand_x(5), None, (0, 1), 'ortho'),
])
class TestIhfft2(unittest.TestCase):
def test_static_ihfft2(self):
with stgraph(paddle.fft.ihfft2, self.place, self.x, self.n, self.axis,
self.norm) as y:
np.testing.assert_allclose(
scipy.fft.ihfft2(self.x, self.n, self.axis, self.norm),
y,
rtol=RTOL.get(str(self.x.dtype)),
atol=ATOL.get(str(self.x.dtype)))
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[
('test_x_complex_input', rand_x(
2, complex=True), None, (0, 1), None, ValueError),
# ('test_x_not_tensor', [0, 1], None, (0, 1), None, ValueError),
('test_x_1dim_tensor', rand_x(1), None, (0, 1), None, ValueError),
('test_n_nagative', rand_x(2), -1, (0, 1), 'backward', ValueError),
('test_n_len_not_equal_axis', rand_x(
5, max_dim_len=5), 11, (0, 1), 'backward', ValueError),
('test_n_zero', rand_x(2), (0, 0), (0, 1), 'backward', ValueError),
('test_axis_out_of_range', rand_x(2), None, (0, 1, 2), 'backward',
ValueError),
('test_axis_with_array', rand_x(1), None, (0, 1), 'backward',
ValueError),
('test_axis_not_sequence', rand_x(5), None, -10, 'backward',
ValueError),
('test_norm_not_enum', rand_x(2), None, -1, 'random', ValueError)
])
class TestIhfft2Exception(unittest.TestCase):
def test_static_ihfft2(self):
with self.assertRaises(self.expect_exception):
with stgraph(paddle.fft.ihfft2, self.place, self.x, self.n,
self.axis, self.norm) as y:
pass
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'),
[('test_x_float64', rand_x(5, np.float64), None, None, 'backward'),
('test_n_grater_input_length', rand_x(
5, max_dim_len=5), (11, 11), (0, 1),
'backward'), ('test_n_smaller_input_length', rand_x(
5, min_dim_len=5), (1, 1), (0, 1), 'backward'),
('test_axis_not_default', rand_x(5), None, (1, 2),
'backward'), ('test_norm_forward', rand_x(5), None, None, 'forward'),
('test_norm_ortho', rand_x(5), None, None, 'ortho')])
class TestIhfftn(unittest.TestCase):
def test_static_ihfftn(self):
with stgraph(paddle.fft.ihfftn, self.place, self.x, self.n, self.axis,
self.norm) as y:
np.testing.assert_allclose(
scipy.fft.ihfftn(self.x, self.n, self.axis, self.norm),
y,
rtol=RTOL.get(str(self.x.dtype)),
atol=ATOL.get(str(self.x.dtype)))
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[('test_x_complex', rand_x(
4, complex=True), None, None, 'backward', TypeError),
('test_n_nagative', rand_x(4), -1, None, 'backward',
ValueError), ('test_n_zero', rand_x(4), 0, None, 'backward', ValueError),
('test_axis_out_of_range', rand_x(1), None, [0, 1], 'backward',
ValueError), ('test_norm_not_in_enum', rand_x(2), None, -1, 'random',
ValueError)])
class TestIhfftnException(unittest.TestCase):
def test_static_ihfftn(self):
with self.assertRaises(self.expect_exception):
with stgraph(paddle.fft.ihfftn, self.place, self.x, self.n,
self.axis, self.norm) as y:
pass
if __name__ == '__main__':
unittest.main()
# yapf: enable
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import re
import sys
from spectral_op_np import fft_c2c, fft_r2c, fft_c2r
import paddle.fluid.core as core
import paddle.fluid.dygraph as dg
import paddle.static as static
from numpy.random import random as rand
from paddle.fluid import Program, program_guard
sys.path.append("../")
from op_test import OpTest
paddle.enable_static()
TEST_CASE_NAME = 'test_case'
def parameterize(attrs, input_values=None):
if isinstance(attrs, str):
attrs = [attrs]
input_dicts = (attrs if input_values is None else
[dict(zip(attrs, vals)) for vals in input_values])
def decorator(base_class):
test_class_module = sys.modules[base_class.__module__].__dict__
for idx, input_dict in enumerate(input_dicts):
test_class_dict = dict(base_class.__dict__)
test_class_dict.update(input_dict)
name = class_name(base_class, idx, input_dict)
test_class_module[name] = type(name, (base_class, ),
test_class_dict)
for method_name in list(base_class.__dict__):
if method_name.startswith("test"):
delattr(base_class, method_name)
return base_class
return decorator
def to_safe_name(s):
return str(re.sub("[^a-zA-Z0-9_]+", "_", s))
def class_name(cls, num, params_dict):
suffix = to_safe_name(
next((v for v in params_dict.values() if isinstance(v, str)), ""))
if TEST_CASE_NAME in params_dict:
suffix = to_safe_name(params_dict["test_case"])
return "{}_{}{}".format(cls.__name__, num, suffix and "_" + suffix)
@parameterize((TEST_CASE_NAME, 'x', 'axes', 'norm', 'forward'), [
('test_axes_is_sqe_type', (np.random.random(
(12, 14)) + 1j * np.random.random((12, 14))).astype(np.complex128),
[0, 1], 'forward', True), ('test_axis_not_last', (np.random.random(
(4, 4, 4)) + 1j * np.random.random((4, 4, 4))).astype(np.complex128),
(0, 1), "backward", False),
('test_norm_forward', (np.random.random((12, 14)) + 1j * np.random.random(
(12, 14))).astype(np.complex128), (0, ), "forward",
False), ('test_norm_backward', (np.random.random(
(12, 14)) + 1j * np.random.random((12, 14))).astype(np.complex128),
(0, ), "backward", True), ('test_norm_ortho', (np.random.random(
(12, 14)) + 1j * np.random.random(
(12, 14))).astype(np.complex128), (1, ), "ortho", True)
])
class TestFFTC2COp(OpTest):
# The framework does not support numerical gradients for complex types, so the gradient check is skipped.
no_need_check_grad = True
def setUp(self):
self.op_type = "fft_c2c"
out = fft_c2c(self.x, self.axes, self.norm, self.forward)
self.inputs = {'X': self.x}
self.attrs = {
'axes': self.axes,
'normalization': self.norm,
"forward": self.forward
}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
@parameterize(
(TEST_CASE_NAME, 'x', 'axes', 'norm', 'forward', 'last_dim_size'),
[('test_axes_is_sqe_type', (np.random.random(
(12, 14)) + 1j * np.random.random((12, 14))).astype(np.complex128),
[0, 1], 'forward', True, 26), ('test_axis_not_last', (np.random.random(
(4, 4, 4)) + 1j * np.random.random((4, 4, 4))).astype(np.complex128),
(0, 1), "backward", False, None),
('test_norm_forward', (np.random.random((12, 14)) + 1j * np.random.random(
(12, 14))).astype(np.complex128), (0, ), "forward", False, 22),
('test_norm_backward', (np.random.random((12, 14)) + 1j * np.random.random(
(12, 14))).astype(np.complex128), (0, ), "backward", True,
22), ('test_norm_ortho', (np.random.random(
(12, 14)) + 1j * np.random.random((12, 14))).astype(np.complex128),
(1, ), "ortho", True, 26)])
class TestFFTC2ROp(OpTest):
# The framework does not support numerical gradients for complex types, so the gradient check is skipped.
no_need_check_grad = True
def setUp(self):
self.op_type = "fft_c2r"
out = fft_c2r(self.x, self.axes, self.norm, self.forward,
self.last_dim_size)
self.inputs = {'X': self.x}
self.attrs = {
"axes": self.axes,
"normalization": self.norm,
"forward": self.forward,
"last_dim_size": self.last_dim_size
}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
@parameterize(
(TEST_CASE_NAME, 'x', 'axes', 'norm', 'forward', 'onesided'),
[('test_axes_is_sqe_type', np.random.randn(12, 14).astype(np.float64),
(0, 1), 'forward', True,
True), ('test_axis_not_last', np.random.randn(4, 4, 4).astype(np.float64),
(0, 1), "backward", False, True),
('test_norm_forward', np.random.randn(12, 14).astype(np.float64), (0, 1),
"forward", False, False),
('test_norm_backward', np.random.randn(12, 14).astype(np.float64), (0, ),
"backward", True, False), ('test_norm_ortho',
np.random.randn(12, 14).astype(np.float64),
(1, ), "ortho", True, False)])
class TestFFTR2COp(OpTest):
# The framework does not support numerical gradients for complex types, so the gradient check is skipped.
no_need_check_grad = True
def setUp(self):
self.op_type = "fft_r2c"
out = fft_r2c(self.x, self.axes, self.norm, self.forward, self.onesided)
self.inputs = {'X': self.x}
self.attrs = {
'axes': self.axes,
'normalization': self.norm,
"forward": self.forward,
'onesided': self.onesided
}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from numpy.lib.stride_tricks import as_strided
import paddle
import unittest
from op_test import OpTest
def frame_from_librosa(x, frame_length, hop_length, axis=-1):
if axis == -1 and not x.flags["C_CONTIGUOUS"]:
x = np.ascontiguousarray(x)
elif axis == 0 and not x.flags["F_CONTIGUOUS"]:
x = np.asfortranarray(x)
n_frames = 1 + (x.shape[axis] - frame_length) // hop_length
strides = np.asarray(x.strides)
if axis == -1:
shape = list(x.shape)[:-1] + [frame_length, n_frames]
strides = list(strides) + [hop_length * x.itemsize]
elif axis == 0:
shape = [n_frames, frame_length] + list(x.shape)[1:]
strides = [hop_length * x.itemsize] + list(strides)
else:
raise ValueError("Frame axis={} must be either 0 or -1".format(axis))
return as_strided(x, shape=shape, strides=strides)
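# A minimal sketch, not part of the original file, of what the strided view
# above contains for a 1-D signal with axis=-1. `frame_loop` is a hypothetical
# loop-based reference that copies each frame instead of using as_strided;
# the sizes match the default case in TestFrameOp.initTestCase.

```python
import numpy as np


def frame_loop(x, frame_length, hop_length):
    """Hypothetical reference: slice a 1-D signal into overlapping frames."""
    n_frames = 1 + (x.shape[-1] - frame_length) // hop_length
    out = np.empty((frame_length, n_frames), dtype=x.dtype)
    for i in range(n_frames):
        start = i * hop_length
        # Column i holds the samples [start, start + frame_length).
        out[:, i] = x[start:start + frame_length]
    return out


x = np.arange(150, dtype=np.float64)
frames = frame_loop(x, frame_length=50, hop_length=15)
```

# With 150 samples, frame_length=50 and hop_length=15, the frame count is
# 1 + (150 - 50) // 15 = 7, so the result has shape (50, 7).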
class TestFrameOp(OpTest):
def setUp(self):
self.op_type = "frame"
self.shape, self.type, self.attrs = self.initTestCase()
self.inputs = {
'X': np.random.random(size=self.shape).astype(self.type),
}
self.outputs = {
'Out': frame_from_librosa(
x=self.inputs['X'], **self.attrs)
}
def initTestCase(self):
input_shape = (150, )
input_type = 'float64'
attrs = {
'frame_length': 50,
'hop_length': 15,
'axis': -1,
}
return input_shape, input_type, attrs
def test_check_output(self):
paddle.enable_static()
self.check_output()
paddle.disable_static()
def test_check_grad_normal(self):
paddle.enable_static()
self.check_grad(['X'], 'Out')
paddle.disable_static()
class TestCase1(TestFrameOp):
def initTestCase(self):
input_shape = (150, )
input_type = 'float64'
attrs = {
'frame_length': 50,
'hop_length': 15,
'axis': 0,
}
return input_shape, input_type, attrs
class TestCase2(TestFrameOp):
def initTestCase(self):
input_shape = (8, 150)
input_type = 'float64'
attrs = {
'frame_length': 50,
'hop_length': 15,
'axis': -1,
}
return input_shape, input_type, attrs
class TestCase3(TestFrameOp):
def initTestCase(self):
input_shape = (150, 8)
input_type = 'float64'
attrs = {
'frame_length': 50,
'hop_length': 15,
'axis': 0,
}
return input_shape, input_type, attrs
class TestCase4(TestFrameOp):
def initTestCase(self):
input_shape = (4, 2, 150)
input_type = 'float64'
attrs = {
'frame_length': 50,
'hop_length': 15,
'axis': -1,
}
return input_shape, input_type, attrs
class TestCase5(TestFrameOp):
def initTestCase(self):
input_shape = (150, 4, 2)
input_type = 'float64'
attrs = {
'frame_length': 50,
'hop_length': 15,
'axis': 0,
}
return input_shape, input_type, attrs
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import unittest
from op_test import OpTest
def overlap_add(x, hop_length, axis=-1):
assert axis in [0, -1], 'axis should be 0/-1.'
assert len(x.shape) >= 2, 'Input dims should be >= 2.'
squeeze_output = False
if len(x.shape) == 2:
squeeze_output = True
dim = 0 if axis == -1 else -1
x = np.expand_dims(x, dim) # batch
n_frames = x.shape[axis]
frame_length = x.shape[1] if axis == 0 else x.shape[-2]
# Assure no gaps between frames.
assert 0 < hop_length <= frame_length, \
f'hop_length should be in (0, frame_length({frame_length})], but got {hop_length}.'
seq_length = (n_frames - 1) * hop_length + frame_length
reshape_output = False
if len(x.shape) > 3:
reshape_output = True
if axis == 0:
target_shape = [seq_length] + list(x.shape[2:])
x = x.reshape(n_frames, frame_length, np.prod(x.shape[2:]))
else:
target_shape = list(x.shape[:-2]) + [seq_length]
x = x.reshape(np.prod(x.shape[:-2]), frame_length, n_frames)
if axis == 0:
x = x.transpose((2, 1, 0))
y = np.zeros(shape=[np.prod(x.shape[:-2]), seq_length], dtype=x.dtype)
for i in range(x.shape[0]):
for frame in range(x.shape[-1]):
sample = frame * hop_length
y[i, sample:sample + frame_length] += x[i, :, frame]
if axis == 0:
y = y.transpose((1, 0))
if reshape_output:
y = y.reshape(target_shape)
if squeeze_output:
y = y.squeeze(-1) if axis == 0 else y.squeeze(0)
return y
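# A minimal sketch, not part of the original file, of the core accumulation
# in overlap_add for the 2-D, axis=-1 case only. `overlap_add_2d` is a
# hypothetical name; the batching, transposing and reshaping done by the
# function above are left out.

```python
import numpy as np


def overlap_add_2d(frames, hop_length):
    """Hypothetical reference: frames has shape (frame_length, n_frames)."""
    frame_length, n_frames = frames.shape
    seq_length = (n_frames - 1) * hop_length + frame_length
    y = np.zeros(seq_length, dtype=frames.dtype)
    for i in range(n_frames):
        start = i * hop_length
        # Overlapping regions of consecutive frames are summed.
        y[start:start + frame_length] += frames[:, i]
    return y


frames = np.ones((4, 3))
y = overlap_add_2d(frames, hop_length=2)
```

# Three frames of length 4 with hop_length=2 give (3 - 1) * 2 + 4 = 8 output
# samples; the samples covered by two frames sum to 2.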
class TestOverlapAddOp(OpTest):
def setUp(self):
self.op_type = "overlap_add"
self.shape, self.type, self.attrs = self.initTestCase()
self.inputs = {
'X': np.random.random(size=self.shape).astype(self.type),
}
self.outputs = {'Out': overlap_add(x=self.inputs['X'], **self.attrs)}
def initTestCase(self):
input_shape = (50, 3)
input_type = 'float64'
attrs = {
'hop_length': 4,
'axis': -1,
}
return input_shape, input_type, attrs
def test_check_output(self):
paddle.enable_static()
self.check_output()
paddle.disable_static()
def test_check_grad_normal(self):
paddle.enable_static()
self.check_grad(['X'], 'Out')
paddle.disable_static()
class TestCase1(TestOverlapAddOp):
def initTestCase(self):
input_shape = (3, 50)
input_type = 'float64'
attrs = {
'hop_length': 4,
'axis': 0,
}
return input_shape, input_type, attrs
class TestCase2(TestOverlapAddOp):
def initTestCase(self):
input_shape = (2, 40, 5)
input_type = 'float64'
attrs = {
'hop_length': 10,
'axis': -1,
}
return input_shape, input_type, attrs
class TestCase3(TestOverlapAddOp):
def initTestCase(self):
input_shape = (5, 40, 2)
input_type = 'float64'
attrs = {
'hop_length': 10,
'axis': 0,
}
return input_shape, input_type, attrs
class TestCase4(TestOverlapAddOp):
def initTestCase(self):
input_shape = (3, 5, 12, 8)
input_type = 'float64'
attrs = {
'hop_length': 5,
'axis': -1,
}
return input_shape, input_type, attrs
class TestCase5(TestOverlapAddOp):
def initTestCase(self):
input_shape = (8, 12, 5, 3)
input_type = 'float64'
attrs = {
'hop_length': 5,
'axis': 0,
}
return input_shape, input_type, attrs
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import sys
import unittest
import numpy as np
from numpy import fft
from numpy.lib.stride_tricks import as_strided
import paddle
import scipy.signal
paddle.set_default_dtype('float64')
DEVICES = [paddle.CPUPlace()]
if paddle.is_compiled_with_cuda():
DEVICES.append(paddle.CUDAPlace(0))
TEST_CASE_NAME = 'test_case'
# Constrain STFT block sizes to 256 KB
MAX_MEM_BLOCK = 2**8 * 2**10
def fix_length(data, size, axis=-1, **kwargs):
kwargs.setdefault("mode", "constant")
n = data.shape[axis]
if n > size:
slices = [slice(None)] * data.ndim
slices[axis] = slice(0, size)
return data[tuple(slices)]
elif n < size:
lengths = [(0, 0)] * data.ndim
lengths[axis] = (0, size - n)
return np.pad(data, lengths, **kwargs)
return data
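# A minimal sketch, not part of the original file, of the two branches of
# fix_length applied to a 1-D array along the last axis. The array and the
# target sizes (3 and 8) are arbitrary values chosen for illustration.

```python
import numpy as np

data = np.arange(5)

# n > size: keep the first `size` entries along the axis.
slices = [slice(None)] * data.ndim
slices[-1] = slice(0, 3)
truncated = data[tuple(slices)]

# n < size: pad at the end of the axis, with zeros in "constant" mode.
lengths = [(0, 0)] * data.ndim
lengths[-1] = (0, 8 - data.shape[-1])
padded = np.pad(data, lengths, mode="constant")
```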
def tiny(x):
# Make sure we have an array view
x = np.asarray(x)
# Only floating types generate a tiny
if np.issubdtype(x.dtype, np.floating) or np.issubdtype(x.dtype,
np.complexfloating):
dtype = x.dtype
else:
dtype = np.float32
return np.finfo(dtype).tiny
def normalize(S, norm=np.inf, axis=0, threshold=None, fill=None):
# Avoid div-by-zero
if threshold is None:
threshold = tiny(S)
elif threshold <= 0:
raise Exception("threshold={} must be strictly "
"positive".format(threshold))
if fill not in [None, False, True]:
raise Exception("fill={} must be None or boolean".format(fill))
if not np.all(np.isfinite(S)):
raise Exception("Input must be finite")
# All norms only depend on magnitude, let's do that first
mag = np.abs(S).astype(float)
# For max/min norms, filling with 1 works
fill_norm = 1
if norm == np.inf:
length = np.max(mag, axis=axis, keepdims=True)
elif norm == -np.inf:
length = np.min(mag, axis=axis, keepdims=True)
elif norm == 0:
if fill is True:
raise Exception("Cannot normalize with norm=0 and fill=True")
length = np.sum(mag > 0, axis=axis, keepdims=True, dtype=mag.dtype)
elif np.issubdtype(type(norm), np.number) and norm > 0:
length = np.sum(mag**norm, axis=axis, keepdims=True)**(1.0 / norm)
if axis is None:
fill_norm = mag.size**(-1.0 / norm)
else:
fill_norm = mag.shape[axis]**(-1.0 / norm)
elif norm is None:
return S
else:
raise Exception("Unsupported norm: {}".format(repr(norm)))
# indices where norm is below the threshold
small_idx = length < threshold
Snorm = np.empty_like(S)
if fill is None:
# Leave small indices un-normalized
length[small_idx] = 1.0
Snorm[:] = S / length
elif fill:
# If we have a non-zero fill value, we locate those entries by
# doing a nan-divide.
# If S was finite, then length is finite (except for small positions)
length[small_idx] = np.nan
Snorm[:] = S / length
Snorm[np.isnan(Snorm)] = fill_norm
else:
# Set small values to zero by doing an inf-divide.
# This is safe (by IEEE-754) as long as S is finite.
length[small_idx] = np.inf
Snorm[:] = S / length
return Snorm
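# A minimal sketch, not part of the original file, of the default path
# through normalize: norm=np.inf, axis=0, fill=None. The 2x2 input is an
# arbitrary example; the other norms and fill modes are not covered.

```python
import numpy as np

S = np.array([[3.0, -4.0], [1.0, 2.0]])

# The inf-norm of each column is its largest magnitude.
mag = np.abs(S)
length = np.max(mag, axis=0, keepdims=True)

# Columns whose norm is below the threshold are left un-normalized.
threshold = np.finfo(S.dtype).tiny
length[length < threshold] = 1.0

S_norm = S / length
```

# After normalization, the largest magnitude in every column is 1.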
def __window_ss_fill(x, win_sq, n_frames, hop_length): # pragma: no cover
"""Helper function for window sum-square calculation."""
n = len(x)
n_fft = len(win_sq)
for i in range(n_frames):
sample = i * hop_length
x[sample:min(n, sample + n_fft)] += win_sq[:max(0,
min(n_fft, n - sample))]
def window_sumsquare(
window,
n_frames,
hop_length=512,
win_length=None,
n_fft=2048,
dtype=np.float32,
norm=None, ):
if win_length is None:
win_length = n_fft
n = n_fft + hop_length * (n_frames - 1)
x = np.zeros(n, dtype=dtype)
# Compute the squared window at the desired length
win_sq = get_window(window, win_length)
win_sq = normalize(win_sq, norm=norm)**2
win_sq = pad_center(win_sq, n_fft)
# Fill the envelope
__window_ss_fill(x, win_sq, n_frames, hop_length)
return x
def dtype_c2r(d, default=np.float32):
mapping = {
np.dtype(np.complex64): np.float32,
np.dtype(np.complex128): np.float64,
}
# If we're given a real type already, return it
dt = np.dtype(d)
if dt.kind == "f":
return dt
# Otherwise, try to map the dtype.
# If no match is found, return the default.
return np.dtype(mapping.get(np.dtype(d), default))
def dtype_r2c(d, default=np.complex64):
mapping = {
np.dtype(np.float32): np.complex64,
np.dtype(np.float64): np.complex128,
}
# If we're given a complex type already, return it
dt = np.dtype(d)
if dt.kind == "c":
return dt
# Otherwise, try to map the dtype.
# If no match is found, return the default.
return np.dtype(mapping.get(dt, default))
def frame(x, frame_length, hop_length, axis=-1):
if not isinstance(x, np.ndarray):
raise Exception("Input must be of type numpy.ndarray, "
"given type(x)={}".format(type(x)))
if x.shape[axis] < frame_length:
raise Exception("Input is too short (n={:d})"
" for frame_length={:d}".format(x.shape[axis],
frame_length))
if hop_length < 1:
raise Exception("Invalid hop_length: {:d}".format(hop_length))
if axis == -1 and not x.flags["F_CONTIGUOUS"]:
print("librosa.util.frame called with axis={} "
"on a non-contiguous input. This will result in a copy.".format(
axis))
x = np.asfortranarray(x)
elif axis == 0 and not x.flags["C_CONTIGUOUS"]:
print("librosa.util.frame called with axis={} "
"on a non-contiguous input. This will result in a copy.".format(
axis))
x = np.ascontiguousarray(x)
n_frames = 1 + (x.shape[axis] - frame_length) // hop_length
strides = np.asarray(x.strides)
new_stride = np.prod(strides[strides > 0] // x.itemsize) * x.itemsize
if axis == -1:
shape = list(x.shape)[:-1] + [frame_length, n_frames]
strides = list(strides) + [hop_length * new_stride]
elif axis == 0:
shape = [n_frames, frame_length] + list(x.shape)[1:]
strides = [hop_length * new_stride] + list(strides)
else:
raise Exception("Frame axis={} must be either 0 or -1".format(axis))
return as_strided(x, shape=shape, strides=strides)
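# Illustration only (not part of this commit): the framing above can also be
# written with plain index arithmetic, which copies data instead of building a
# strided view. This sketch assumes a 1-D input and the axis=-1 layout
# (frame_length, n_frames); `frame_1d` is a hypothetical name.

```python
import numpy as np


def frame_1d(x, frame_length, hop_length):
    """Return an array of shape (frame_length, n_frames) from a 1-D signal."""
    n_frames = 1 + (len(x) - frame_length) // hop_length
    # Entry (k, t) selects sample t * hop_length + k.
    idx = (np.arange(frame_length)[:, None] +
           hop_length * np.arange(n_frames)[None, :])
    return x[idx]


x = np.arange(10)
frames = frame_1d(x, frame_length=4, hop_length=3)
# The columns start at samples 0, 3 and 6.
```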
def pad_center(data, size, axis=-1, **kwargs):
kwargs.setdefault("mode", "constant")
n = data.shape[axis]
lpad = int((size - n) // 2)
lengths = [(0, 0)] * data.ndim
lengths[axis] = (lpad, int(size - n - lpad))
if lpad < 0:
raise Exception(("Target size ({:d}) must be "
"at least input size ({:d})").format(size, n))
return np.pad(data, lengths, **kwargs)
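# Illustration only (not part of this commit): how the left/right padding is
# split when the difference between size and input length is odd. This sketch
# assumes 1-D input and constant (zero) padding; `pad_center_1d` is a
# hypothetical name.

```python
import numpy as np


def pad_center_1d(data, size):
    """Zero-pad a 1-D array to `size`, putting any extra sample on the right."""
    n = len(data)
    if size < n:
        raise ValueError(
            "Target size ({}) must be at least input size ({})".format(size, n))
    lpad = (size - n) // 2
    return np.pad(data, (lpad, size - n - lpad), mode="constant")


padded = pad_center_1d(np.array([1, 2, 3]), 8)  # 2 zeros left, 3 zeros right
```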
def get_window(window, Nx, fftbins=True):
if callable(window):
return window(Nx)
elif isinstance(window, (str, tuple)) or np.isscalar(window):
# TODO: if we add custom window functions in librosa, call them here
return scipy.signal.get_window(window, Nx, fftbins=fftbins)
elif isinstance(window, (np.ndarray, list)):
if len(window) == Nx:
return np.asarray(window)
raise Exception("Window size mismatch: "
"{:d} != {:d}".format(len(window), Nx))
else:
raise Exception("Invalid window specification: {}".format(window))
def __overlap_add(y, ytmp, hop_length):
# numba-accelerated overlap add for inverse stft
# y is the pre-allocated output buffer
# ytmp is the windowed inverse-stft frames
# hop_length is the hop-length of the STFT analysis
n_fft = ytmp.shape[0]
for frame in range(ytmp.shape[1]):
sample = frame * hop_length
y[sample:(sample + n_fft)] += ytmp[:, frame]
def stft(x,
n_fft=2048,
hop_length=None,
win_length=None,
window="hann",
center=True,
pad_mode="reflect"):
y = x
input_rank = len(y.shape)
if input_rank == 2:
assert y.shape[0] == 1 # Only 1d input supported in librosa
y = y.squeeze(0)
dtype = None
# By default, use the entire frame
if win_length is None:
win_length = n_fft
# Set the default hop, if it's not already specified
if hop_length is None:
hop_length = int(win_length // 4)
fft_window = get_window(window, win_length, fftbins=True)
# Pad the window out to n_fft size
fft_window = pad_center(fft_window, n_fft)
# Reshape so that the window can be broadcast
fft_window = fft_window.reshape((-1, 1))
# Pad the time series so that frames are centered
if center:
if n_fft > y.shape[-1]:
print("n_fft={} is too large for input signal of length={}".format(
n_fft, y.shape[-1]))
y = np.pad(y, int(n_fft // 2), mode=pad_mode)
elif n_fft > y.shape[-1]:
raise Exception("n_fft={} is too large for input signal of length={}".
format(n_fft, y.shape[-1]))
# Window the time series.
y_frames = frame(y, frame_length=n_fft, hop_length=hop_length)
if dtype is None:
dtype = dtype_r2c(y.dtype)
# Pre-allocate the STFT matrix
stft_matrix = np.empty(
(int(1 + n_fft // 2), y_frames.shape[1]), dtype=dtype, order="F")
# how many columns can we fit within MAX_MEM_BLOCK?
n_columns = MAX_MEM_BLOCK // (stft_matrix.shape[0] * stft_matrix.itemsize)
n_columns = max(n_columns, 1)
for bl_s in range(0, stft_matrix.shape[1], n_columns):
bl_t = min(bl_s + n_columns, stft_matrix.shape[1])
stft_matrix[:, bl_s:bl_t] = fft.rfft(
fft_window * y_frames[:, bl_s:bl_t], axis=0)
if input_rank == 2:
stft_matrix = np.expand_dims(stft_matrix, 0)
return stft_matrix
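# Illustration only (not part of this commit): the block-wise loop in stft
# bounds the size of each intermediate FFT result. The sketch below isolates
# that idea. The budget value and the name `rfft_in_blocks` are hypothetical;
# the `max(..., 1)` guard keeps the block size from dropping to zero.

```python
import numpy as np


def rfft_in_blocks(frames, max_mem_block=2**18):
    """rfft along axis 0, processing a bounded number of columns at a time."""
    n_fft, n_frames = frames.shape
    out = np.empty((1 + n_fft // 2, n_frames), dtype=np.complex128)
    # Number of output columns that fit in the byte budget, at least one.
    n_columns = max(max_mem_block // (out.shape[0] * out.itemsize), 1)
    for start in range(0, n_frames, n_columns):
        stop = min(start + n_columns, n_frames)
        out[:, start:stop] = np.fft.rfft(frames[:, start:stop], axis=0)
    return out


rng = np.random.default_rng(0)
frames = rng.standard_normal((16, 50))
blocked = rfft_in_blocks(frames, max_mem_block=1024)  # 7 columns per block
```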
def istft(
x,
hop_length=None,
win_length=None,
window="hann",
center=True,
length=None, ):
stft_matrix = x
input_rank = len(stft_matrix.shape)
if input_rank == 3:
assert stft_matrix.shape[0] == 1 # Only 2d input supported in librosa
stft_matrix = stft_matrix.squeeze(0)
dtype = None
n_fft = 2 * (stft_matrix.shape[0] - 1)
# By default, use the entire frame
if win_length is None:
win_length = n_fft
# Set the default hop, if it's not already specified
if hop_length is None:
hop_length = int(win_length // 4)
ifft_window = get_window(window, win_length, fftbins=True)
# Pad out to match n_fft, and add a broadcasting axis
ifft_window = pad_center(ifft_window, n_fft)[:, np.newaxis]
# For efficiency, trim STFT frames according to signal length if available
if length:
if center:
padded_length = length + int(n_fft)
else:
padded_length = length
n_frames = min(stft_matrix.shape[1],
int(np.ceil(padded_length / hop_length)))
else:
n_frames = stft_matrix.shape[1]
expected_signal_len = n_fft + hop_length * (n_frames - 1)
if dtype is None:
dtype = dtype_c2r(stft_matrix.dtype)
y = np.zeros(expected_signal_len, dtype=dtype)
n_columns = MAX_MEM_BLOCK // (stft_matrix.shape[0] * stft_matrix.itemsize)
n_columns = max(n_columns, 1)
frame = 0
for bl_s in range(0, n_frames, n_columns):
bl_t = min(bl_s + n_columns, n_frames)
# invert the block and apply the window function
ytmp = ifft_window * fft.irfft(stft_matrix[:, bl_s:bl_t], axis=0)
# Overlap-add the istft block starting at the i'th frame
__overlap_add(y[frame * hop_length:], ytmp, hop_length)
frame += bl_t - bl_s
# Normalize by sum of squared window
ifft_window_sum = window_sumsquare(
window,
n_frames,
win_length=win_length,
n_fft=n_fft,
hop_length=hop_length,
dtype=dtype, )
approx_nonzero_indices = ifft_window_sum > tiny(ifft_window_sum)
y[approx_nonzero_indices] /= ifft_window_sum[approx_nonzero_indices]
if length is None:
# If we don't need to control length, just do the usual center trimming
# to eliminate padded data
if center:
y = y[int(n_fft // 2):-int(n_fft // 2)]
else:
if center:
# If we're centering, crop off the first n_fft//2 samples
# and then trim/pad to the target length.
# We don't trim the end here, so that if the signal is zero-padded
# to a longer duration, the decay is smooth by windowing
start = int(n_fft // 2)
else:
# If we're not centering, start at 0 and trim/pad as necessary
start = 0
y = fix_length(y[start:], length)
if input_rank == 3:
y = np.expand_dims(y, 0)
return y
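# Illustration only (not part of this commit): a compact round trip showing
# why istft divides by the sum of squared windows. It assumes center=True
# with reflect padding, a periodic Hann window and win_length == n_fft; all
# names are hypothetical and the code is a simplified stand-in for the
# reference functions above.

```python
import numpy as np


def hann(n):
    return 0.5 - 0.5 * np.cos(2 * np.pi * np.arange(n) / n)


def stft_sketch(y, n_fft, hop):
    y = np.pad(y, n_fft // 2, mode="reflect")
    n_frames = 1 + (len(y) - n_fft) // hop
    idx = np.arange(n_fft)[:, None] + hop * np.arange(n_frames)[None, :]
    return np.fft.rfft(hann(n_fft)[:, None] * y[idx], axis=0)


def istft_sketch(spec, hop):
    n_fft = 2 * (spec.shape[0] - 1)
    n_frames = spec.shape[1]
    win = hann(n_fft)
    # Inverse transform of each frame, windowed a second time.
    frames = win[:, None] * np.fft.irfft(spec, n=n_fft, axis=0)
    y = np.zeros(n_fft + hop * (n_frames - 1))
    wss = np.zeros_like(y)
    for t in range(n_frames):
        y[t * hop:t * hop + n_fft] += frames[:, t]  # overlap-add
        wss[t * hop:t * hop + n_fft] += win**2  # window sum-square
    nonzero = wss > np.finfo(wss.dtype).tiny
    y[nonzero] /= wss[nonzero]
    # Remove the n_fft // 2 samples of padding on each side.
    return y[n_fft // 2:-(n_fft // 2)]


rng = np.random.default_rng(0)
signal = rng.standard_normal(1024)
recon = istft_sketch(stft_sketch(signal, n_fft=64, hop=16), hop=16)
```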
def frame_for_api_test(x, frame_length, hop_length, axis=-1):
if axis == -1 and not x.flags["C_CONTIGUOUS"]:
x = np.ascontiguousarray(x)
elif axis == 0 and not x.flags["F_CONTIGUOUS"]:
x = np.asfortranarray(x)
n_frames = 1 + (x.shape[axis] - frame_length) // hop_length
strides = np.asarray(x.strides)
if axis == -1:
shape = list(x.shape)[:-1] + [frame_length, n_frames]
strides = list(strides) + [hop_length * x.itemsize]
elif axis == 0:
shape = [n_frames, frame_length] + list(x.shape)[1:]
strides = [hop_length * x.itemsize] + list(strides)
else:
raise ValueError("Frame axis={} must be either 0 or -1".format(axis))
return as_strided(x, shape=shape, strides=strides)
def overlap_add_for_api_test(x, hop_length, axis=-1):
assert axis in [0, -1], 'axis should be 0/-1.'
assert len(x.shape) >= 2, 'Input dims should be >= 2.'
squeeze_output = False
if len(x.shape) == 2:
squeeze_output = True
dim = 0 if axis == -1 else -1
x = np.expand_dims(x, dim) # batch
n_frames = x.shape[axis]
frame_length = x.shape[1] if axis == 0 else x.shape[-2]
# Assure no gaps between frames.
assert 0 < hop_length <= frame_length, \
f'hop_length should be in (0, frame_length({frame_length})], but got {hop_length}.'
seq_length = (n_frames - 1) * hop_length + frame_length
reshape_output = False
if len(x.shape) > 3:
reshape_output = True
if axis == 0:
target_shape = [seq_length] + list(x.shape[2:])
x = x.reshape(n_frames, frame_length, np.prod(x.shape[2:]))
else:
target_shape = list(x.shape[:-2]) + [seq_length]
x = x.reshape(np.prod(x.shape[:-2]), frame_length, n_frames)
if axis == 0:
x = x.transpose((2, 1, 0))
y = np.zeros(shape=[np.prod(x.shape[:-2]), seq_length], dtype=x.dtype)
for i in range(x.shape[0]):
for frame in range(x.shape[-1]):
sample = frame * hop_length
y[i, sample:sample + frame_length] += x[i, :, frame]
if axis == 0:
y = y.transpose((1, 0))
if reshape_output:
y = y.reshape(target_shape)
if squeeze_output:
y = y.squeeze(-1) if axis == 0 else y.squeeze(0)
return y
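# Illustration only (not part of this commit): the core of the overlap-add
# reference above, reduced to a single (frame_length, n_frames) matrix.
# `overlap_add_1d` is a hypothetical name.

```python
import numpy as np


def overlap_add_1d(frames, hop_length):
    """Sum the columns of `frames` at offsets 0, hop_length, 2 * hop_length, ..."""
    frame_length, n_frames = frames.shape
    y = np.zeros((n_frames - 1) * hop_length + frame_length, dtype=frames.dtype)
    for t in range(n_frames):
        start = t * hop_length
        y[start:start + frame_length] += frames[:, t]
    return y


y = overlap_add_1d(np.ones((4, 3)), hop_length=2)
# Samples covered by two frames sum to 2; the two samples at each end stay at 1.
```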
def place(devices, key='place'):
def decorate(cls):
module = sys.modules[cls.__module__].__dict__
raw_classes = {
k: v
for k, v in module.items() if k.startswith(cls.__name__)
}
for raw_name, raw_cls in raw_classes.items():
for d in devices:
test_cls = dict(raw_cls.__dict__)
test_cls.update({key: d})
new_name = raw_name + '.' + d.__class__.__name__
module[new_name] = type(new_name, (raw_cls, ), test_cls)
del module[raw_name]
return cls
return decorate
def setUpModule():
global rtol
global atol
# All test cases use float64 when comparing precision; see:
# https://github.com/PaddlePaddle/Paddle/wiki/Upgrade-OP-Precision-to-Float64
rtol = {
'float32': 1e-06,
'float64': 1e-7,
'complex64': 1e-06,
'complex128': 1e-7,
}
atol = {
'float32': 0.0,
'float64': 0.0,
'complex64': 0.0,
'complex128': 0.0,
}
def tearDownModule():
pass
def rand_x(dims=1,
dtype='float64',
min_dim_len=1,
max_dim_len=10,
shape=None,
complex=False):
if shape is None:
shape = [
np.random.randint(min_dim_len, max_dim_len) for i in range(dims)
]
if complex:
return np.random.randn(*shape).astype(dtype) + 1.j * np.random.randn(
*shape).astype(dtype)
else:
return np.random.randn(*shape).astype(dtype)
def parameterize(attrs, input_values=None):
if isinstance(attrs, str):
attrs = [attrs]
input_dicts = (attrs if input_values is None else
[dict(zip(attrs, vals)) for vals in input_values])
def decorator(base_class):
test_class_module = sys.modules[base_class.__module__].__dict__
for idx, input_dict in enumerate(input_dicts):
test_class_dict = dict(base_class.__dict__)
test_class_dict.update(input_dict)
name = class_name(base_class, idx, input_dict)
test_class_module[name] = type(name, (base_class, ),
test_class_dict)
for method_name in list(base_class.__dict__):
if method_name.startswith("test"):
delattr(base_class, method_name)
return base_class
return decorator
def class_name(cls, num, params_dict):
    suffix = to_safe_name(
        next((v for v in params_dict.values() if isinstance(v, str)), ""))
    if TEST_CASE_NAME in params_dict:
        # Look the name up with the same key that was just tested.
        suffix = to_safe_name(params_dict[TEST_CASE_NAME])
    return "{}_{}{}".format(cls.__name__, num, suffix and "_" + suffix)
def to_safe_name(s):
return str(re.sub("[^a-zA-Z0-9_]+", "_", s))
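# Illustration only (not part of this commit): the parameterize/class_name
# helpers above generate one TestCase subclass per parameter row. The sketch
# below shows the same idea but returns the generated classes in a dict
# instead of injecting them into the module namespace. `expand_cases`,
# `SquareCase` and the "case_name" key are hypothetical.

```python
import re
import unittest


def to_safe_name(s):
    return re.sub("[^a-zA-Z0-9_]+", "_", s)


def expand_cases(base, names, rows):
    """Return {class_name: subclass} with one subclass of `base` per row."""
    generated = {}
    for idx, row in enumerate(rows):
        params = dict(zip(names, row))
        name = "{}_{}_{}".format(base.__name__, idx, to_safe_name(str(row[0])))
        # Each row becomes class attributes of a fresh subclass.
        generated[name] = type(name, (base, ), params)
    return generated


class SquareCase(unittest.TestCase):
    x = None
    expected = None

    def test_square(self):
        if self.x is None:
            self.skipTest("template class without parameters")
        self.assertEqual(self.x * self.x, self.expected)


cases = expand_cases(SquareCase, ("case_name", "x", "expected"),
                     [("two", 2, 4), ("minus 3", -3, 9)])
suite = unittest.TestSuite(
    unittest.defaultTestLoader.loadTestsFromTestCase(c) for c in cases.values())
result = unittest.TestResult()
suite.run(result)
```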
# yapf: disable
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'frame_length', 'hop_length', 'axis'),
[
('test_1d_input1', rand_x(1, np.float64, shape=[150]), 50, 15, 0),
('test_1d_input2', rand_x(1, np.float64, shape=[150]), 50, 15, -1),
('test_2d_input1', rand_x(2, np.float64, shape=[150, 8]), 50, 15, 0),
('test_2d_input2', rand_x(2, np.float64, shape=[8, 150]), 50, 15, -1),
('test_3d_input1', rand_x(3, np.float64, shape=[150, 4, 2]), 50, 15, 0),
('test_3d_input2', rand_x(3, np.float64, shape=[4, 2, 150]), 50, 15, -1),
])
class TestFrame(unittest.TestCase):
def test_frame(self):
self.assertTrue(
np.allclose(
frame_for_api_test(self.x, self.frame_length, self.hop_length, self.axis),
paddle.tensor.signal.frame(
paddle.to_tensor(self.x),
self.frame_length,
self.hop_length,
self.axis),
rtol=rtol.get(str(self.x.dtype)),
atol=atol.get(str(self.x.dtype))))
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'frame_length', 'hop_length', 'axis'),
[
('test_1d_input1', rand_x(1, np.float64, shape=[150]), 50, 15, 0),
('test_1d_input2', rand_x(1, np.float64, shape=[150]), 50, 15, -1),
('test_2d_input1', rand_x(2, np.float64, shape=[150, 8]), 50, 15, 0),
('test_2d_input2', rand_x(2, np.float64, shape=[8, 150]), 50, 15, -1),
('test_3d_input1', rand_x(3, np.float64, shape=[150, 4, 2]), 50, 15, 0),
('test_3d_input2', rand_x(3, np.float64, shape=[4, 2, 150]), 50, 15, -1),
])
class TestFrameStatic(unittest.TestCase):
    def test_frame_static(self):
        paddle.enable_static()
        mp, sp = paddle.static.Program(), paddle.static.Program()
        with paddle.static.program_guard(mp, sp):
            input = paddle.static.data('input', self.x.shape, dtype=self.x.dtype)
            # No trailing comma after the call: `output` should be the
            # variable itself, not a 1-tuple.
            output = paddle.tensor.signal.frame(
                input,
                self.frame_length,
                self.hop_length,
                self.axis)
        exe = paddle.static.Executor(self.place)
        exe.run(sp)
        [output] = exe.run(mp, feed={'input': self.x}, fetch_list=[output])
        paddle.disable_static()
        self.assertTrue(
            np.allclose(
                frame_for_api_test(self.x, self.frame_length, self.hop_length, self.axis),
                output,
                rtol=rtol.get(str(self.x.dtype)),
                atol=atol.get(str(self.x.dtype))))
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'frame_length', 'hop_length', 'axis', 'expect_exception'),
[
('test_axis', rand_x(1, np.float64, shape=[150]), 50, 15, 2, ValueError),
('test_hop_length', rand_x(1, np.float64, shape=[150]), 50, 0, -1, ValueError),
('test_frame_length1', rand_x(2, np.float64, shape=[150, 8]), 0, 15, 0, ValueError),
('test_frame_length2', rand_x(2, np.float64, shape=[150, 8]), 151, 15, 0, ValueError),
])
class TestFrameException(unittest.TestCase):
def test_frame(self):
with self.assertRaises(self.expect_exception):
paddle.tensor.signal.frame(
paddle.to_tensor(self.x),
self.frame_length,
self.hop_length,
self.axis)
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'hop_length', 'axis'),
[
('test_2d_input1', rand_x(2, np.float64, shape=[3, 50]), 4, 0),
('test_2d_input2', rand_x(2, np.float64, shape=[50, 3]), 4, -1),
('test_3d_input1', rand_x(3, np.float64, shape=[5, 40, 2]), 10, 0),
('test_3d_input2', rand_x(3, np.float64, shape=[2, 40, 5]), 10, -1),
('test_4d_input1', rand_x(4, np.float64, shape=[8, 12, 5, 3]), 5, 0),
('test_4d_input2', rand_x(4, np.float64, shape=[3, 5, 12, 8]), 5, -1),
])
class TestOverlapAdd(unittest.TestCase):
def test_overlap_add(self):
self.assertTrue(
np.allclose(
overlap_add_for_api_test(self.x, self.hop_length, self.axis),
paddle.tensor.signal.overlap_add(
paddle.to_tensor(self.x),
self.hop_length,
self.axis),
rtol=rtol.get(str(self.x.dtype)),
atol=atol.get(str(self.x.dtype))))
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'hop_length', 'axis'),
[
('test_2d_input1', rand_x(2, np.float64, shape=[3, 50]), 4, 0),
('test_2d_input2', rand_x(2, np.float64, shape=[50, 3]), 4, -1),
('test_3d_input1', rand_x(3, np.float64, shape=[5, 40, 2]), 10, 0),
('test_3d_input2', rand_x(3, np.float64, shape=[2, 40, 5]), 10, -1),
('test_4d_input1', rand_x(4, np.float64, shape=[8, 12, 5, 3]), 5, 0),
('test_4d_input2', rand_x(4, np.float64, shape=[3, 5, 12, 8]), 5, -1),
])
class TestOverlapAddStatic(unittest.TestCase):
    def test_overlap_add_static(self):
        paddle.enable_static()
        mp, sp = paddle.static.Program(), paddle.static.Program()
        with paddle.static.program_guard(mp, sp):
            input = paddle.static.data('input', self.x.shape, dtype=self.x.dtype)
            # No trailing comma after the call: `output` should be the
            # variable itself, not a 1-tuple.
            output = paddle.tensor.signal.overlap_add(
                input,
                self.hop_length,
                self.axis)
        exe = paddle.static.Executor(self.place)
        exe.run(sp)
        [output] = exe.run(mp, feed={'input': self.x}, fetch_list=[output])
        paddle.disable_static()
        self.assertTrue(
            np.allclose(
                overlap_add_for_api_test(self.x, self.hop_length, self.axis),
                output,
                rtol=rtol.get(str(self.x.dtype)),
                atol=atol.get(str(self.x.dtype))))
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'hop_length', 'axis', 'expect_exception'),
[
('test_axis', rand_x(2, np.float64, shape=[3, 50]), 4, 2, ValueError),
('test_hop_length', rand_x(2, np.float64, shape=[50, 3]), -1, -1, ValueError),
])
class TestOverlapAddException(unittest.TestCase):
def test_overlap_add(self):
with self.assertRaises(self.expect_exception):
paddle.tensor.signal.overlap_add(
paddle.to_tensor(self.x),
self.hop_length,
self.axis)
# ================= STFT
# common args
# x
# n_fft,
# hop_length=None,
# win_length=None,
# window=None,
# center=True,
# pad_mode='reflect',
# paddle only
# normalized=False,
# onesided=True,
# ================= ISTFT
# common args
# x,
# hop_length=None,
# win_length=None,
# window=None,
# center=True,
# length=None,
# paddle only
# n_fft,
# normalized=False,
# onesided=True,
# return_complex=False,
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n_fft', 'hop_length', 'win_length', 'window', 'center', 'pad_mode', 'normalized', 'onesided'),
[
('test_1d_input', rand_x(1, np.float64, shape=[160000]),
512, None, None, get_window('hann', 512), True, 'reflect', False, True),
('test_2d_input', rand_x(2, np.float64, shape=[1, 160000]),
512, None, None, get_window('hann', 512), True, 'reflect', False, True),
('test_hop_length', rand_x(2, np.float64, shape=[1, 160000]),
512, 255, None, get_window('hann', 512), True, 'reflect', False, True),
('test_win_length', rand_x(2, np.float64, shape=[1, 160000]),
512, 255, 499, get_window('hann', 499), True, 'reflect', False, True),
('test_window', rand_x(2, np.float64, shape=[1, 160000]),
512, None, None, None, True, 'reflect', False, True),
('test_center', rand_x(2, np.float64, shape=[1, 160000]),
512, None, None, None, False, 'reflect', False, True),
])
class TestStft(unittest.TestCase):
def test_stft(self):
if self.window is None:
win_p = None
win_l = 'boxcar' # rectangular window
else:
win_p = paddle.to_tensor(self.window)
win_l = self.window
self.assertTrue(
np.allclose(
stft(self.x, self.n_fft, self.hop_length, self.win_length, win_l, self.center, self.pad_mode),
paddle.tensor.signal.stft(
paddle.to_tensor(self.x),
self.n_fft,
self.hop_length,
self.win_length,
win_p,
self.center,
self.pad_mode,
self.normalized,
self.onesided),
rtol=rtol.get(str(self.x.dtype)),
atol=atol.get(str(self.x.dtype))))
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n_fft', 'hop_length', 'win_length', 'window', 'center', 'pad_mode', 'normalized', 'onesided', 'expect_exception'),
[
('test_dims', rand_x(1, np.float64, shape=[1, 2, 3]),
512, None, None, None, True, 'reflect', False, True, AssertionError),
('test_hop_length', rand_x(1, np.float64, shape=[16000]),
512, 0, None, None, True, 'reflect', False, True, AssertionError),
('test_nfft1', rand_x(1, np.float64, shape=[16000]),
0, None, None, None, True, 'reflect', False, True, AssertionError),
('test_nfft2', rand_x(1, np.float64, shape=[16000]),
16001, None, None, None, True, 'reflect', False, True, AssertionError),
('test_win_length1', rand_x(1, np.float64, shape=[16000]),
512, None, 0, None, True, 'reflect', False, True, AssertionError),
('test_win_length2', rand_x(1, np.float64, shape=[16000]),
512, None, 513, None, True, 'reflect', False, True, AssertionError),
('test_pad_mode', rand_x(1, np.float64, shape=[16000]),
512, None, None, None, True, 'nonsense', False, True, AssertionError),
('test_complex_onesided', rand_x(1, np.float64, shape=[16000], complex=True),
512, None, None, None, False, 'reflect', False, True, AssertionError),
])
class TestStftException(unittest.TestCase):
def test_stft(self):
if self.window is None:
win_p = None
else:
win_p = paddle.to_tensor(self.window)
with self.assertRaises(self.expect_exception):
paddle.tensor.signal.stft(
paddle.to_tensor(self.x),
self.n_fft,
self.hop_length,
self.win_length,
win_p,
self.center,
self.pad_mode,
self.normalized,
self.onesided)
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n_fft', 'hop_length', 'win_length', 'window', 'center', 'normalized', 'onesided', 'length', 'return_complex'),
[
('test_2d_input', rand_x(2, np.float64, shape=[257, 471], complex=True),
512, None, None, get_window('hann', 512), True, False, True, None, False),
('test_3d_input', rand_x(3, np.float64, shape=[1, 257, 471], complex=True),
512, None, None, get_window('hann', 512), True, False, True, None, False),
('test_hop_length', rand_x(3, np.float64, shape=[1, 257, 471], complex=True),
512, 99, None, get_window('hann', 512), True, False, True, None, False),
('test_win_length', rand_x(3, np.float64, shape=[1, 257, 471], complex=True),
512, 99, 299, get_window('hann', 299), True, False, True, None, False),
('test_window', rand_x(3, np.float64, shape=[1, 257, 471], complex=True),
512, None, None, None, True, False, True, None, False),
('test_center', rand_x(3, np.float64, shape=[1, 257, 471], complex=True),
512, None, None, None, False, False, True, None, False),
('test_length', rand_x(3, np.float64, shape=[1, 257, 471], complex=True),
512, None, None, None, False, False, True, 1888, False),
])
class TestIstft(unittest.TestCase):
def test_istft(self):
if self.window is None:
win_p = None
win_l = 'boxcar' # rectangular window
else:
win_p = paddle.to_tensor(self.window)
win_l = self.window
self.assertTrue(
np.allclose(
istft(self.x, self.hop_length, self.win_length, win_l, self.center, self.length),
paddle.tensor.signal.istft(
paddle.to_tensor(self.x),
self.n_fft,
self.hop_length,
self.win_length,
win_p,
self.center,
self.normalized,
self.onesided,
self.length,
self.return_complex),
rtol=rtol.get(str(self.x.dtype)),
atol=atol.get(str(self.x.dtype))))
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n_fft', 'hop_length', 'win_length', 'window', 'center', 'normalized', 'onesided', 'length', 'return_complex', 'expect_exception'),
[
('test_dims', rand_x(4, np.float64, shape=[1, 2, 3, 4], complex=True),
512, None, None, get_window('hann', 512), True, False, True, None, False, AssertionError),
('test_n_fft', rand_x(3, np.float64, shape=[1, 257, 471], complex=True),
257, None, None, get_window('hann', 512), True, False, True, None, False, AssertionError),
('test_hop_length1', rand_x(3, np.float64, shape=[1, 257, 471], complex=True),
512, 0, None, get_window('hann', 512), True, False, True, None, False, AssertionError),
('test_hop_length2', rand_x(3, np.float64, shape=[1, 257, 471], complex=True),
512, 513, None, get_window('hann', 512), True, False, True, None, False, AssertionError),
('test_win_length1', rand_x(3, np.float64, shape=[1, 257, 471], complex=True),
512, None, 0, get_window('hann', 512), True, False, True, None, False, AssertionError),
('test_win_length2', rand_x(3, np.float64, shape=[1, 257, 471], complex=True),
512, None, 513, get_window('hann', 512), True, False, True, None, False, AssertionError),
('test_onesided1', rand_x(3, np.float64, shape=[1, 257, 471], complex=True),
20, None, None, get_window('hann', 512), True, False, True, None, False, AssertionError),
('test_onesided2', rand_x(3, np.float64, shape=[1, 257, 471], complex=True),
256, None, None, None, True, False, False, None, False, AssertionError),
('test_window', rand_x(3, np.float64, shape=[1, 512, 471], complex=True),
512, None, 511, get_window('hann', 512), True, False, False, None, False, AssertionError),
('test_return_complex1', rand_x(3, np.float64, shape=[1, 257, 471], complex=True),
512, None, None, get_window('hann', 512), True, False, True, None, True, AssertionError),
('test_return_complex2', rand_x(3, np.float64, shape=[1, 257, 471], complex=True),
512, None, None, rand_x(1, np.float64, shape=[512], complex=True), True, False, True, None, False, AssertionError),
('test_NOLA', rand_x(3, np.float64, shape=[1, 257, 471], complex=True),
512, 512, None, get_window('hann', 512), True, False, True, None, False, ValueError),
])
class TestIstftException(unittest.TestCase):
def test_istft(self):
if self.window is None:
win_p = None
else:
win_p = paddle.to_tensor(self.window)
with self.assertRaises(self.expect_exception):
paddle.tensor.signal.istft(
paddle.to_tensor(self.x),
self.n_fft,
self.hop_length,
self.win_length,
win_p,
self.center,
self.normalized,
self.onesided,
self.length,
self.return_complex)
# yapf: enable
if __name__ == '__main__':
unittest.main()
@@ -216,6 +216,8 @@ from .array import array_write # noqa: F401
from .array import create_array # noqa: F401
from .einsum import einsum # noqa: F401
from . import fft
from . import signal
#this list used in math_op_patch.py for _binary_creator_
tensor_method_func = [ #noqa
......
@@ -35,6 +35,41 @@ def _complex_to_real_dtype(dtype):
return dtype
def _real_to_complex_dtype(dtype):
if dtype == core.VarDesc.VarType.FP32:
return core.VarDesc.VarType.COMPLEX64
elif dtype == core.VarDesc.VarType.FP64:
return core.VarDesc.VarType.COMPLEX128
else:
return dtype
def is_complex(x):
dtype = x.dtype
is_complex_dtype = (dtype == core.VarDesc.VarType.COMPLEX64 or
dtype == core.VarDesc.VarType.COMPLEX128)
return is_complex_dtype
def is_floating_point(x):
dtype = x.dtype
is_fp_dtype = (dtype == core.VarDesc.VarType.FP32 or
dtype == core.VarDesc.VarType.FP64 or
dtype == core.VarDesc.VarType.FP16 or
dtype == core.VarDesc.VarType.BF16)
return is_fp_dtype
def is_interger(x):
dtype = x.dtype
is_int_dtype = (dtype == core.VarDesc.VarType.UINT8 or
dtype == core.VarDesc.VarType.INT8 or
dtype == core.VarDesc.VarType.INT16 or
dtype == core.VarDesc.VarType.INT32 or
dtype == core.VarDesc.VarType.INT64)
return is_int_dtype
def real(x, name=None):
"""
Returns a new tensor containing real values of the input tensor.
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Sequence
import numpy as np
import paddle
from .attribute import is_complex, is_floating_point, is_interger, _real_to_complex_dtype, _complex_to_real_dtype
from ..fluid.framework import in_dygraph_mode
from .. import _C_ops
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.layer_helper import LayerHelper
__all__ = [
'fft',
'fft2',
'fftn',
'ifft',
'ifft2',
'ifftn',
'rfft',
'rfft2',
'rfftn',
'irfft',
'irfft2',
'irfftn',
'hfft',
'hfft2',
'hfftn',
'ihfft',
'ihfft2',
'ihfftn',
'fftfreq',
'rfftfreq',
'fftshift',
'ifftshift',
]
def _check_normalization(norm):
if norm not in ['forward', 'backward', 'ortho']:
raise ValueError(
"Unexpected norm: {}. Norm should be forward, backward or ortho".
format(norm))
def _check_fft_n(n):
    if not isinstance(n, int):
        raise ValueError(
            "Invalid FFT argument n({}), it should be an integer.".format(n))
    if n <= 0:
        raise ValueError(
            "Invalid FFT argument n({}), it should be positive.".format(n))
def _check_fft_shape(x, s):
    ndim = x.ndim
    if not isinstance(s, Sequence):
        # The original message had a `{}` placeholder but never formatted it.
        raise ValueError(
            "Invalid FFT argument s({}), it should be a sequence of integers.".
            format(s))
    if len(s) > ndim:
        raise ValueError(
            "Length of FFT argument s should not be larger than the rank of input. "
            "Received s: {}, rank of x: {}".format(s, ndim))
    for size in s:
        if not isinstance(size, int) or size <= 0:
            raise ValueError("FFT sizes {} contains invalid value ({})".format(
                s, size))
def _check_fft_axis(x, axis):
    ndim = x.ndim
    if not isinstance(axis, int):
        raise ValueError(
            "Invalid FFT axis ({}), it should be an integer.".format(axis))
    if axis < -ndim or axis >= ndim:
        raise ValueError(
            "Invalid FFT axis ({}), it should be in range [-{}, {})".format(
                axis, ndim, ndim))
def _check_fft_axes(x, axes):
ndim = x.ndim
if not isinstance(axes, Sequence):
raise ValueError(
"Invalid FFT axes ({}), it should be a sequence of integers.".
format(axes))
if len(axes) > ndim:
raise ValueError(
"Length of fft axes should not be larger than the rank of input. "
"Received, len of axes: {}, rank of x: {}".format(len(axes), ndim))
for axis in axes:
if not isinstance(axis, int) or axis < -ndim or axis >= ndim:
raise ValueError(
"FFT axes {} contains invalid value ({}), it should be in range [-{}, {})".
format(axes, axis, ndim, ndim))
def _resize_fft_input(x, s, axes):
if len(s) != len(axes):
raise ValueError("Length of `s` should equal the length of `axes`.")
shape = x.shape
ndim = x.ndim
axes_to_pad = []
paddings = []
axes_to_slice = []
slices = []
for i, axis in enumerate(axes):
if shape[axis] < s[i]:
axes_to_pad.append(axis)
paddings.append(s[i] - shape[axis])
elif shape[axis] > s[i]:
axes_to_slice.append(axis)
slices.append((0, s[i]))
if axes_to_slice:
x = paddle.slice(
x,
axes_to_slice,
starts=[item[0] for item in slices],
ends=[item[1] for item in slices])
if axes_to_pad:
padding_widths = [0] * (2 * ndim)
for axis, pad in zip(axes_to_pad, paddings):
padding_widths[2 * axis + 1] = pad
x = paddle.nn.functional.pad(x, padding_widths)
return x
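# Illustration only (not part of this commit): the crop-or-zero-pad rule that
# _resize_fft_input applies per axis, sketched with NumPy instead of
# paddle.slice / paddle.nn.functional.pad. `resize_for_fft` is a hypothetical
# name. NumPy's own `s` argument is used below only as a cross-check.

```python
import numpy as np


def resize_for_fft(x, s, axes):
    """Crop or zero-pad `x` at the end of each axis in `axes` to sizes `s`."""
    if len(s) != len(axes):
        raise ValueError("Length of `s` should equal the length of `axes`.")
    for size, axis in zip(s, axes):
        n = x.shape[axis]
        if n > size:
            index = [slice(None)] * x.ndim
            index[axis] = slice(0, size)  # keep the leading `size` entries
            x = x[tuple(index)]
        elif n < size:
            pad = [(0, 0)] * x.ndim
            pad[axis] = (0, size - n)  # append zeros after the data
            x = np.pad(x, pad)
    return x


x = np.arange(12.0).reshape(3, 4)
y = resize_for_fft(x, s=(5, 2), axes=(0, 1))  # pad axis 0 to 5, crop axis 1 to 2
```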
def _normalize_axes(x, axes):
ndim = x.ndim
return [item if item >= 0 else (item + ndim) for item in axes]
def _check_at_least_ndim(x, rank):
if x.ndim < rank:
raise ValueError("The rank of the input ({}) should be >= {}".format(
x.ndim, rank))
# public APIs 1d
def fft(x, n=None, axis=-1, norm="backward", name=None):
"""
Calculate one-dimensional discrete Fourier transform.
This function uses the efficient fast Fourier transform (FFT) algorithm [1] to
calculate the 1-D *n*-point discrete Fourier transform (DFT).
Args:
x (Tensor): The input tensor. Complex input is expected; real and integer
inputs are also accepted and are transformed to a full (two-sided) spectrum.
n (int, optional): The length of the output transform axis. If `n` is less than
the input length, the input is cropped. If it is larger, the input is padded
with zeros. If `n` is not given, the input length along the axis specified
by `axis` is used.
axis (int, optional): Axis used to calculate FFT. If not specified, the last axis
is used by default.
norm (str, optional): Specifies which direction of the `forward`/`backward` transform
pair is scaled and which normalization factor is used. The value must be one
of "forward", "backward" or "ortho". The default is "backward", meaning no normalization on
the forward transform and scaling by ``1/n`` on the `ifft`. "forward" instead applies
the ``1/n`` factor on the forward transform. For ``norm="ortho"``, both directions are
scaled by ``1/sqrt(n)``.
name (str, optional): The default value is None. Normally there is no need for user to set
this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
complex tensor. The truncated or zero-padded input, transformed along the axis indicated
by `axis`, or the last one if `axis` is not specified.
Examples:
.. code-block:: python
import numpy as np
import paddle
x = np.exp(3j * np.pi * np.arange(7) / 7)
xp = paddle.to_tensor(x)
fft_xp = paddle.fft.fft(xp).numpy()
print(fft_xp)
# [1.+1.25396034e+00j 1.+4.38128627e+00j 1.-4.38128627e+00j
# 1.-1.25396034e+00j 1.-4.81574619e-01j 1.+8.88178420e-16j
# 1.+4.81574619e-01j]
"""
if is_interger(x) or is_floating_point(x):
return fft_r2c(
x, n, axis, norm, forward=True, onesided=False, name=name)
else:
return fft_c2c(x, n, axis, norm, forward=True, name=name)
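# Illustration only (not part of this commit): the three `norm` modes
# described in the docstring above, shown with numpy.fft as a stand-in. That
# paddle.fft.fft scales the same way is taken from the docstring, not
# verified here. Requires NumPy >= 1.20 for norm="forward"/"backward".

```python
import numpy as np

x = np.exp(3j * np.pi * np.arange(7) / 7)
n = x.shape[-1]

backward = np.fft.fft(x, norm="backward")  # no scaling on the forward transform
forward = np.fft.fft(x, norm="forward")  # scaled by 1/n
ortho = np.fft.fft(x, norm="ortho")  # scaled by 1/sqrt(n)

# Each mode is inverted by ifft called with the same `norm` value.
roundtrip = np.fft.ifft(forward, norm="forward")
```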
def ifft(x, n=None, axis=-1, norm="backward", name=None):
"""
Compute the 1-D inverse discrete Fourier Transform.
This function computes the inverse of the 1-D *n*-point discrete Fourier transform
computed by `fft`. In other words, ``ifft(fft(x)) == x`` to within numerical accuracy.
The input should be ordered in the same way as is returned by `fft`,
i.e.,
* ``x[0]`` should contain the zero frequency term,
* ``x[1:n//2]`` should contain the positive-frequency terms,
* ``x[n//2 + 1:]`` should contain the negative-frequency terms, in
increasing order starting from the most negative frequency.
For an even number of input points, ``x[n//2]`` represents the sum of
the values at the positive and negative Nyquist frequencies, as the two
are aliased together.
Args:
x (Tensor): The input tensor, either real or complex.
n (int, optional): The length of the output transform axis. If `n` is less than
the length of the input, the input will be cropped. If larger, the input is filled
with zeros. If `n` is not given, the input length along the axis specified
by `axis` is used.
axis (int, optional): Axis used to calculate FFT. If not specified, the last axis
is used by default.
norm (str): Indicates which direction to scale the `forward` or `backward` transform
pair and what normalization factor to use. The parameter value must be one
of "forward" or "backward" or "ortho". Default is "backward", meaning no normalization on
the forward transforms and scaling by ``1/n`` on the `ifft`. "forward" instead applies
the ``1/n`` factor on the forward transform. For ``norm="ortho"``, both directions are
scaled by ``1/sqrt(n)``.
name (str, optional): The default value is None. Normally there is no need for user to set
this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
complex tensor. The truncated or zero-padded input, transformed along the axis indicated
by `axis`, or the last one if `axis` is not specified.
Examples:
.. code-block:: python
import numpy as np
import paddle
x = np.exp(3j * np.pi * np.arange(7) / 7)
xp = paddle.to_tensor(x)
ifft_xp = paddle.fft.ifft(xp).numpy()
print(ifft_xp)
# [0.14285714+1.79137191e-01j 0.14285714+6.87963741e-02j
# 0.14285714+1.26882631e-16j 0.14285714-6.87963741e-02j
# 0.14285714-1.79137191e-01j 0.14285714-6.25898038e-01j
# 0.14285714+6.25898038e-01j]
"""
if is_interger(x) or is_floating_point(x):
return fft_r2c(
x, n, axis, norm, forward=False, onesided=False, name=name)
else:
return fft_c2c(x, n, axis, norm, forward=False, name=name)
def rfft(x, n=None, axis=-1, norm="backward", name=None):
"""
The one dimensional FFT for real input.
This function computes the one dimensional *n*-point discrete Fourier
Transform (DFT) of a real-valued tensor by means of an efficient algorithm
called the Fast Fourier Transform (FFT).
When the DFT is computed for purely real input, the output is
Hermitian-symmetric. This function does not compute the negative frequency
terms, and the length of the transformed axis of the output is therefore
``n//2 + 1``.
Args:
x(Tensor) : Real-valued input tensor
n(int, optional): Number of points along transformation axis in the
input to use. If `n` is smaller than the length of the input, the
input is cropped. If it is larger, the input is padded with zeros.
If `n` is not given, the length of the input along the axis
specified by `axis` is used.
axis(int, optional): Axis over which to compute the FFT. Default value
is last axis.
norm(str, optional) : Normalization mode, indicates which direction of
the forward/backward pair of transforms is scaled and with what
normalization factor. Include {"backward", "ortho", "forward"},
default value is "backward".
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name` .
Returns:
out(Tensor) : complex tensor
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([0.0, 1.0, 0.0, 0.0])
print(paddle.fft.rfft(x))
# Tensor(shape=[3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True,
# [ (1+0j), -1j , (-1+0j)])
"""
return fft_r2c(x, n, axis, norm, forward=True, onesided=True, name=name)
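# Illustrative sketch (not part of paddle.fft): why a one-sided transform of real input
# keeps only ``n//2 + 1`` terms. The helpers `naive_dft` and `naive_rfft` are
# hypothetical pure-Python references; the input matches the `rfft` docstring example.

```python
import cmath


def naive_dft(x):
    """Full O(n^2) forward DFT with no scaling (hypothetical reference helper)."""
    n = len(x)
    return [
        sum(x[m] * cmath.exp(-2j * cmath.pi * k * m / n) for m in range(n))
        for k in range(n)
    ]


def naive_rfft(x):
    """Keep only the non-negative frequency terms of the full spectrum."""
    return naive_dft(x)[: len(x) // 2 + 1]


real_signal = [0.0, 1.0, 0.0, 0.0]
full = naive_dft(real_signal)
n = len(full)

# For real input, X[n - k] is the complex conjugate of X[k], so the second
# half of the spectrum carries no extra information.
hermitian = all(
    abs(full[(n - k) % n] - full[k].conjugate()) < 1e-12 for k in range(n)
)

half = naive_rfft(real_signal)
print(hermitian, len(half))
```

# The three retained terms agree with the values shown in the docstring example,
# up to floating-point rounding.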
def irfft(x, n=None, axis=-1, norm="backward", name=None):
"""
Computes the inverse of `rfft`.
This function computes the inverse of the one-dimensional *n*-point discrete
Fourier transform of real input computed by `rfft`. In other words,
``irfft(rfft(a), len(a)) == a`` to within numerical accuracy.
The input is expected to be in the form returned by `rfft`, i.e. the real zero-frequency term
followed by the complex positive-frequency terms in order of increasing frequency.
Because the discrete Fourier transform of real input is Hermitian-symmetric,
the negative-frequency terms are taken to be the complex conjugates of the corresponding
positive-frequency terms.
Args:
x (Tensor): The input tensor, which should be complex.
n (int, optional): The length of the output transform axis. For `n` output
points, ``n//2 + 1`` input points are necessary. If the input is longer than that, it is
cropped; if it is shorter, it is padded with zeros. If `n` is not given,
it is taken to be ``2 * (k-1)``, where ``k`` is the length of the input along the axis
specified by `axis`.
axis (int, optional): Axis used to calculate FFT. If not specified, the last axis
is used by default.
norm (str): Indicates which direction to scale the `forward` or `backward` transform
pair and what normalization factor to use. The parameter value must be one
of "forward" or "backward" or "ortho". Default is "backward".
name (str, optional): The default value is None. Normally there is no need for user to set
this property. For more information, please refer to :ref:`api_guide_Name` .
Returns:
Real tensor. The truncated or zero-padded input, transformed along the axis indicated by
`axis`, or the last one if `axis` is not specified. The length of the transformed axis
is `n`, or ``2 * (k-1)`` if `n` is None, where ``k`` is the length of the transformed axis of the input.
To get an odd number of output points, `n` must be specified, for example as ``2 * k - 1``.
Examples:
.. code-block:: python
import numpy as np
import paddle
x = np.array([1, -1j, -1])
xp = paddle.to_tensor(x)
irfft_xp = paddle.fft.irfft(xp).numpy()
print(irfft_xp)
# [0. 1. 0. 0.]
"""
return fft_c2r(x, n, axis, norm, forward=False, name=name)
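# Illustrative sketch (not part of paddle.fft): why `irfft` needs `n` to recover an
# odd-length signal. A half spectrum of length k can come from a signal of length
# 2*(k-1) or 2*k-1, and the default picks the even one. `naive_rfft` and
# `naive_irfft` are hypothetical pure-Python helpers using "backward" scaling.

```python
import cmath


def naive_rfft(x):
    """One-sided forward DFT of a real sequence (hypothetical reference helper)."""
    n = len(x)
    return [
        sum(x[m] * cmath.exp(-2j * cmath.pi * k * m / n) for m in range(n))
        for k in range(n // 2 + 1)
    ]


def naive_irfft(half, n=None):
    """Inverse of naive_rfft; `n` defaults to 2 * (len(half) - 1)."""
    if n is None:
        n = 2 * (len(half) - 1)
    needed = n // 2 + 1
    # Crop or zero-pad the half spectrum to the n//2 + 1 terms that are used.
    half = (list(half) + [0j] * needed)[:needed]
    out = []
    for m in range(n):
        total = complex(half[0]).real  # zero-frequency term, imaginary part ignored
        for j in range(1, needed):
            # Each positive frequency stands for itself and its conjugate partner,
            # except the Nyquist term of an even-length signal, which has none.
            weight = 1.0 if (n % 2 == 0 and j == n // 2) else 2.0
            total += weight * (half[j] * cmath.exp(2j * cmath.pi * j * m / n)).real
        out.append(total / n)
    return out


odd_signal = [1.0, 2.0, 3.0, 4.0, 5.0]
half_spectrum = naive_rfft(odd_signal)          # 5 // 2 + 1 == 3 terms
default_length = len(naive_irfft(half_spectrum))  # falls back to 2 * (3 - 1)
restored = naive_irfft(half_spectrum, n=5)      # explicit n recovers the signal
print(default_length, [round(v, 6) for v in restored])
```

# Passing ``n=len(a)`` explicitly, as in ``irfft(rfft(a), len(a))``, avoids the ambiguity.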
def hfft(x, n=None, axis=-1, norm="backward", name=None):
"""
Compute the FFT of a signal that has Hermitian symmetry, i.e., a real
spectrum.
Args:
x (Tensor): The input tensor, which should be complex.
n (int, optional): The length of the output transform axis. For `n` output
points, ``n//2 + 1`` input points are necessary. If the input is longer than that, it is
cropped; if it is shorter, it is padded with zeros. If `n` is not given,
it is taken to be ``2 * (k-1)``, where ``k`` is the length of the input along the axis
specified by `axis`.
axis (int,optional): Axis used to calculate FFT. If not specified, the last axis
is used by default.
norm (str): Indicates which direction to scale the `forward` or `backward` transform
pair and what normalization factor to use. The parameter value must be one
of "forward" or "backward" or "ortho". Default is "backward".
name (str, optional): The default value is None. Normally there is no need for user to set
this property. For more information, please refer to :ref:`api_guide_Name` .
Returns:
Real tensor. The truncated or zero-padded input, transformed along the axis indicated by
`axis`, or the last one if `axis` is not specified. The length of the transformed axis
is `n`, or ``2 * (k-1)`` if `n` is None, where ``k`` is the length of the transformed axis of the input.
To get an odd number of output points, `n` must be specified, for example as ``2 * k - 1``.
Examples:
.. code-block:: python
import numpy as np
import paddle
x = np.array([1, -1j, -1])
xp = paddle.to_tensor(x)
hfft_xp = paddle.fft.hfft(xp).numpy()
print(hfft_xp)
# [0. 0. 0. 4.]
"""
return fft_c2r(x, n, axis, norm, forward=True, name=name)
def ihfft(x, n=None, axis=-1, norm="backward", name=None):
"""
The inverse FFT of a signal that has Hermitian symmetry.
This function computes the one dimensional *n*-point inverse FFT of a signal
that has Hermitian symmetry by means of an efficient algorithm called
the Fast Fourier Transform (FFT).
When the DFT is computed for purely real input, the output is
Hermitian-symmetric. This function does not compute the negative frequency
terms, and the length of the transformed axis of the output is therefore
``n//2 + 1``.
Args:
x(Tensor): Input tensor.
n(int, optional): The number of points along transformation axis in the
input to use. If `n` is smaller than the length of the input, the
input is cropped. If it is larger, the input is padded with zeros.
If `n` is not given, the length of the input along the axis
specified by `axis` is used.
axis(int, optional) : Axis over which to compute the inverse FFT. If not
given, the last axis is used.
norm(str, optional) : Normalization mode, indicates which direction of
the forward/backward pair of transforms is scaled and with what
normalization factor. Include {"backward", "ortho", "forward"},
default value is "backward".
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name` .
Returns:
out(Tensor) : complex tensor.
Examples:
.. code-block:: python
import paddle
spectrum = paddle.to_tensor([10.0, -5.0, 0.0, -1.0, 0.0, -5.0])
print(paddle.fft.ifft(spectrum))
# Tensor(shape=[6], dtype=complex64, place=CUDAPlace(0), stop_gradient=True,
# [(-0.1666666716337204+0j), (1-1.9868215517249155e-08j), (2.3333334922790527-1.9868215517249155e-08j), (3.5+0j), (2.3333334922790527+1.9868215517249155e-08j), (1+1.9868215517249155e-08j)])
print(paddle.fft.ihfft(spectrum))
# Tensor(shape = [4], dtype = complex64, place = CUDAPlace(0), stop_gradient = True,
# [(-0.1666666716337204+0j), (1-1.9868215517249155e-08j), (2.3333334922790527-1.9868215517249155e-08j), (3.5+0j)])
"""
return fft_r2c(x, n, axis, norm, forward=False, onesided=True, name=name)
# public APIs nd
def fftn(x, s=None, axes=None, norm="backward", name=None):
"""
Compute the N-D discrete Fourier Transform.
This function calculates the n-D discrete Fourier transform on any number of axes
in the M-D array by fast Fourier transform (FFT).
Args:
x (Tensor): The input tensor, either real or complex.
s (sequence of ints, optional): Shape (length of each transformed axis) of the output
(``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.).
This corresponds to ``n`` for ``fft(x, n)``.
Along any axis, if the given shape is smaller than that of the input,
the input is cropped. If it is larger, the input is padded with zeros.
If `s` is not given, the shape of the input along the axes specified
by `axes` is used.
axes (sequence of ints, optional): Axes used to calculate FFT. If not given, the last ``len(s)``
axes are used, or all axes if `s` is also not specified.
norm (str): Indicates which direction to scale the `forward` or `backward` transform
pair and what normalization factor to use. The parameter value must be one
of "forward" or "backward" or "ortho". Default is "backward", meaning no normalization on
the forward transforms and scaling by ``1/n`` on the `ifft`. "forward" instead applies
the ``1/n`` factor on the forward transform. For ``norm="ortho"``, both directions are
scaled by ``1/sqrt(n)``.
name (str, optional): The default value is None. Normally there is no need for user to set
this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
complex tensor. The truncated or zero-padded input, transformed along the axes indicated by
`axes`, or by a combination of `s` and `x`, as explained in the parameters section above.
Examples:
.. code-block:: python
import numpy as np
import paddle
x = np.mgrid[:4, :4, :4][1]
xp = paddle.to_tensor(x)
fftn_xp = paddle.fft.fftn(xp, axes=(1, 2)).numpy()
print(fftn_xp)
# [[[24.+0.j 0.+0.j 0.+0.j 0.-0.j]
# [-8.+8.j 0.+0.j 0.+0.j 0.-0.j]
# [-8.+0.j 0.+0.j 0.+0.j 0.-0.j]
# [-8.-8.j 0.+0.j 0.+0.j 0.-0.j]]
# [[24.+0.j 0.+0.j 0.+0.j 0.-0.j]
# [-8.+8.j 0.+0.j 0.+0.j 0.-0.j]
# [-8.+0.j 0.+0.j 0.+0.j 0.-0.j]
# [-8.-8.j 0.+0.j 0.+0.j 0.-0.j]]
# [[24.+0.j 0.+0.j 0.+0.j 0.-0.j]
# [-8.+8.j 0.+0.j 0.+0.j 0.-0.j]
# [-8.+0.j 0.+0.j 0.+0.j 0.-0.j]
# [-8.-8.j 0.+0.j 0.+0.j 0.-0.j]]
# [[24.+0.j 0.+0.j 0.+0.j 0.-0.j]
# [-8.+8.j 0.+0.j 0.+0.j 0.-0.j]
# [-8.+0.j 0.+0.j 0.+0.j 0.-0.j]
# [-8.-8.j 0.+0.j 0.+0.j 0.-0.j]]]
"""
if is_interger(x) or is_floating_point(x):
return fftn_r2c(
x, s, axes, norm, forward=True, onesided=False, name=name)
else:
return fftn_c2c(x, s, axes, norm, forward=True, name=name)
def ifftn(x, s=None, axes=None, norm="backward", name=None):
"""
Compute the N-D inverse discrete Fourier Transform.
This function computes the inverse of the N-D discrete
Fourier Transform over any number of axes in an M-D array by
means of the Fast Fourier Transform (FFT). In other words,
``ifftn(fftn(x)) == x`` to within numerical accuracy.
The input, analogously to `ifft`, should be ordered in the same way as is
returned by `fftn`, i.e., it should have the term for zero frequency
in all axes in the low-order corner, the positive frequency terms in the
first half of all axes, the term for the Nyquist frequency in the middle
of all axes and the negative frequency terms in the second half of all
axes, in order of decreasingly negative frequency.
Args:
x (Tensor): The input tensor, either real or complex.
s (sequence of ints, optional): Shape (length of each transformed axis) of the output
(``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.).
This corresponds to ``n`` for ``fft(x, n)``.
Along any axis, if the given shape is smaller than that of the input,
the input is cropped. If it is larger, the input is padded with zeros.
If `s` is not given, the shape of the input along the axes specified
by `axes` is used.
axes (sequence of ints, optional): Axes used to calculate FFT. If not given, the last ``len(s)``
axes are used, or all axes if `s` is also not specified.
norm (str): Indicates which direction to scale the `forward` or `backward` transform
pair and what normalization factor to use. The parameter value must be one
of "forward" or "backward" or "ortho". Default is "backward", meaning no normalization on
the forward transforms and scaling by ``1/n`` on the `ifft`. "forward" instead applies
the ``1/n`` factor on the forward transform. For ``norm="ortho"``, both directions are
scaled by ``1/sqrt(n)``.
name (str, optional): The default value is None. Normally there is no need for user to set
this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
complex tensor. The truncated or zero-padded input, transformed along the axes indicated by
`axes`, or by a combination of `s` and `x`, as explained in the parameters section above.
Examples:
.. code-block:: python
import numpy as np
import paddle
x = np.eye(3)
xp = paddle.to_tensor(x)
ifftn_xp = paddle.fft.ifftn(xp, axes=(1,)).numpy()
print(ifftn_xp)
# [[ 0.33333333+0.j 0.33333333+0.j 0.33333333-0.j ]
# [ 0.33333333+0.j -0.16666667+0.28867513j -0.16666667-0.28867513j]
# [ 0.33333333+0.j -0.16666667-0.28867513j -0.16666667+0.28867513j]]
"""
if is_interger(x) or is_floating_point(x):
return fftn_r2c(
x, s, axes, norm, forward=False, onesided=False, name=name)
else:
return fftn_c2c(x, s, axes, norm, forward=False, name=name)
def rfftn(x, s=None, axes=None, norm="backward", name=None):
"""
The N dimensional FFT for real input.
This function computes the N-dimensional discrete Fourier Transform over
any number of axes in an M-dimensional real array by means of the Fast
Fourier Transform (FFT). By default, all axes are transformed, with the
real transform performed over the last axis, while the remaining
transforms are complex.
The transform for real input is performed over the last transformation
axis, as by `rfft`, then the transform over the remaining axes is
performed as by `fftn`. The order of the output is as for `rfft` for the
final transformation axis, and as for `fftn` for the remaining
transformation axes.
Args:
x(Tensor) : Input tensor, taken to be real.
s(Sequence[int], optional) : Shape to use from the input. The final element of
`s` corresponds to `n` for ``rfft(x, n)``, while for the remaining
axes, it corresponds to `n` for ``fft(x, n)``. Along any axis, if
the given shape is smaller than that of the input, the input is
cropped. If it is larger, the input is padded with zeros. If `s` is
not given, the shape of the input along the axes specified by `axes`
is used.
axes(Sequence[int], optional) : Axes over which to compute the FFT. If not given,
the last ``len(s)`` axes are used, or all axes if `s` is also not
specified.
norm(str, optional) : Normalization mode, indicates which direction of
the forward/backward pair of transforms is scaled and with what
normalization factor. Include {"backward", "ortho", "forward"},
default value is "backward".
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name` .
Returns:
out(Tensor): complex tensor
Raises:
ValueError: If `s` and `axes` have different length.
Examples:
.. code-block:: python
import paddle
# By default, the FFT is computed over all axes.
x = paddle.ones((2, 3, 4))
print(paddle.fft.rfftn(x))
# Tensor(shape=[2, 3, 3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True,
# [[[(24+0j), 0j , 0j ],
# [0j , 0j , 0j ],
# [0j , 0j , 0j ]],
#
# [[0j , 0j , 0j ],
# [0j , 0j , 0j ],
# [0j , 0j , 0j ]]])
# use axes (2, 0): the real transform runs over axis 0, the last entry of `axes`
print(paddle.fft.rfftn(x, axes=(2, 0)))
# Tensor(shape=[2, 3, 4], dtype=complex64, place=CUDAPlace(0), stop_gradient=True,
# [[[(8+0j), 0j , 0j , 0j ],
# [(8+0j), 0j , 0j , 0j ],
# [(8+0j), 0j , 0j , 0j ]],
#
# [[0j , 0j , 0j , 0j ],
# [0j , 0j , 0j , 0j ],
# [0j , 0j , 0j , 0j ]]])
"""
return fftn_r2c(x, s, axes, norm, forward=True, onesided=True, name=name)
def irfftn(x, s=None, axes=None, norm="backward", name=None):
"""
Computes the inverse of `rfftn`.
This function computes the inverse of the N-D discrete
Fourier Transform for real input over any number of axes in an
M-D array by means of the Fast Fourier Transform (FFT). In
other words, ``irfftn(rfftn(x), x.shape) == x`` to within numerical
accuracy. (The ``x.shape`` is necessary like ``len(a)`` is for `irfft`,
and for the same reason.)
The input should be ordered in the same way as is returned by `rfftn`,
i.e., as for `irfft` for the final transformation axis, and as for `ifftn`
along all the other axes.
Args:
x (Tensor): The input data. It's a Tensor type.
s (sequence of ints, optional): The length of the output transform axis.
(``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.). `s` is also the
number of input points used along this axis, except for the last axis,
where ``s[-1]//2+1`` points of the input are used. Along any axis, if
the shape indicated by `s` is smaller than that of the input, the input
is cropped. If it is larger, the input is padded with zeros.
If `s` is not given, the shape of the input along the axes specified by axes
is used. Except for the last axis which is taken to be ``2*(k-1)`` where
``k`` is the length of the input along that axis.
axes (sequence of ints, optional): Axes over which to compute the inverse FFT. If not given, the last
`len(s)` axes are used, or all axes if `s` is also not specified.
norm (str): Indicates which direction to scale the `forward` or `backward` transform
pair and what normalization factor to use. The parameter value must be one
of "forward" or "backward" or "ortho". Default is "backward".
name (str, optional): The default value is None. Normally there is no need for user to set
this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Real tensor. The truncated or zero-padded input, transformed along the axes indicated by `axes`,
or by a combination of `s` or `x`, as explained in the parameters section above. The length of
each transformed axis is as given by the corresponding element of `s`, or the length of the input
in every axis except for the last one if `s` is not given. In the final transformed axis the length
of the output when `s` is not given is ``2*(m-1)``, where ``m`` is the length of the final
transformed axis of the input. To get an odd number of output points in the final axis,
`s` must be specified.
Examples:
.. code-block:: python
import numpy as np
import paddle
x = (np.array([2, 2, 3]) + 1j * np.array([2, 2, 3])).astype(np.complex128)
xp = paddle.to_tensor(x)
irfftn_xp = paddle.fft.irfftn(xp).numpy()
print(irfftn_xp)
# [ 2.25 -1.25 0.25 0.75]
"""
return fftn_c2r(x, s, axes, norm, forward=False, name=name)
def hfftn(x, s=None, axes=None, norm="backward", name=None):
"""
Compute the N-D FFT of Hermitian symmetric complex input, i.e., a
signal with a real spectrum.
This function calculates the N-D discrete Fourier transform of Hermitian-symmetric
complex input over any number of axes in an M-D array by means of the fast Fourier transform (FFT).
In other words, ``ihfftn(hfftn(x, s)) == x`` to within numerical accuracy.
(Here ``s`` is ``x.shape`` with ``s[-1] = x.shape[-1] * 2 - 1``. This is necessary
for the same reason that `irfftn` requires ``x.shape``.)
Args:
x (Tensor): The input data. It's a Tensor type.
s (sequence of ints, optional): The length of the output transform axis.
(``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.). `s` is also the
number of input points used along this axis, except for the last axis,
where ``s[-1]//2+1`` points of the input are used. Along any axis, if
the shape indicated by `s` is smaller than that of the input, the input
is cropped. If it is larger, the input is padded with zeros.
If `s` is not given, the shape of the input along the axes specified by axes
is used. Except for the last axis which is taken to be ``2*(k-1)`` where
``k`` is the length of the input along that axis.
axes (sequence of ints, optional): Axes over which to compute the FFT. If not given, the last
`len(s)` axes are used, or all axes if `s` is also not specified.
norm (str): Indicates which direction to scale the `forward` or `backward` transform
pair and what normalization factor to use. The parameter value must be one
of "forward" or "backward" or "ortho". Default is "backward".
name (str, optional): The default value is None. Normally there is no need for user to set
this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Real tensor. The truncated or zero-padded input, transformed along the axes indicated by `axes`,
or by a combination of `s` and `x`.
Examples:
.. code-block:: python
import numpy as np
import paddle
x = (np.array([2, 2, 3]) + 1j * np.array([2, 2, 3])).astype(np.complex128)
xp = paddle.to_tensor(x)
hfftn_xp = paddle.fft.hfftn(xp).numpy()
print(hfftn_xp)
# [ 9. 3. 1. -5.]
"""
return fftn_c2r(x, s, axes, norm, forward=True, name=name)
def ihfftn(x, s=None, axes=None, norm="backward", name=None):
"""
The n dimensional inverse FFT of a signal that has Hermitian symmetry.
This function computes the n dimensional inverse FFT over any number of axes
in an M-dimensional array of a signal that has Hermitian symmetry by means of an
efficient algorithm called the Fast Fourier Transform (FFT).
Args:
x(Tensor): Input tensor.
s(Sequence[int], optional) : Shape (length along each transformed axis)
to use from the input. (``s[0]`` refers to axis 0, ``s[1]`` to axis
1, etc.). Along any axis, if the given shape is smaller than that
of the input, the input is cropped. If it is larger, the input is
padded with zeros. If `s` is not given, the shape of the input
along the axes specified by `axes` is used.
axes(Sequence[int], optional) : Axes over which to compute the inverse FFT. If not
given, the last ``len(s)`` axes are used, or all axes if `s` is also not specified.
norm(str, optional) : Normalization mode, indicates which direction of
the forward/backward pair of transforms is scaled and with what
normalization factor. Include {"backward", "ortho", "forward"},
default value is "backward".
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name` .
Returns:
out(Tensor) : complex tensor.
Examples:
.. code-block:: python
import paddle
spectrum = paddle.to_tensor([10.0, -5.0, 0.0, -1.0, 0.0, -5.0])
print(paddle.fft.ifft(spectrum))
# Tensor(shape=[6], dtype=complex64, place=CUDAPlace(0), stop_gradient=True,
# [(-0.1666666716337204+0j), (1-1.9868215517249155e-08j), (2.3333334922790527-1.9868215517249155e-08j), (3.5+0j), (2.3333334922790527+1.9868215517249155e-08j), (1+1.9868215517249155e-08j)])
print(paddle.fft.ihfftn(spectrum))
# Tensor(shape = [4], dtype = complex64, place = CUDAPlace(0), stop_gradient = True,
# [(-0.1666666716337204+0j), (1-1.9868215517249155e-08j), (2.3333334922790527-1.9868215517249155e-08j), (3.5+0j)])
"""
return fftn_r2c(x, s, axes, norm, forward=False, onesided=True, name=name)
# public APIs 2d
def fft2(x, s=None, axes=(-2, -1), norm="backward", name=None):
"""
Compute the 2-D discrete Fourier Transform.
This function computes the N-D discrete Fourier Transform
over any axes in an M-D array by means of the
Fast Fourier Transform (FFT). By default, the transform is computed over
the last two axes of the input array, i.e., a 2-dimensional FFT.
Args:
x (Tensor): The input data. It's a Tensor type.
s (sequence of ints, optional): Shape (length of each transformed axis) of the output.
It should be a sequence of 2 integers. This corresponds to ``n`` for ``fft(x, n)``.
Along each axis, if the given shape is smaller than that of the input,
the input is cropped. If it is larger, the input is padded with zeros.
If `s` is not given, the shape of the input along the axes specified
by `axes` is used. Default is None.
axes (sequence of ints, optional): Axes over which to compute the FFT. It should be a
sequence of 2 integers. If not specified, the last two axes are used by default.
norm (str): Indicates which direction to scale the `forward` or `backward` transform
pair and what normalization factor to use. The parameter value must be one
of "forward" or "backward" or "ortho". Default is "backward".
name (str, optional): The default value is None. Normally there is no need for user to set
this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Complex tensor. The truncated or zero-padded input, transformed along the axes indicated by `axes`,
or the last two axes if `axes` is not given.
Raises:
ValueError: If `s` is neither a sequence of 2 integers nor None.
ValueError: If `axes` is neither a sequence of 2 integers nor None.
ValueError: If the input dimension is smaller than 2.
Examples:
.. code-block:: python
import numpy as np
import paddle
x = np.mgrid[:2, :2][1]
xp = paddle.to_tensor(x)
fft2_xp = paddle.fft.fft2(xp).numpy()
print(fft2_xp)
# [[ 2.+0.j -2.+0.j]
# [ 0.+0.j 0.+0.j]]
"""
_check_at_least_ndim(x, 2)
if s is not None:
if not isinstance(s, Sequence) or len(s) != 2:
raise ValueError(
"Invalid FFT argument s ({}), it should be a sequence of 2 integers.".
format(s))
if axes is not None:
if not isinstance(axes, Sequence) or len(axes) != 2:
raise ValueError(
"Invalid FFT argument axes ({}), it should be a sequence of 2 integers.".
format(axes))
return fftn(x, s, axes, norm, name)
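# Illustrative sketch (not part of paddle.fft): a 2-D DFT is separable, so it can be
# computed as 1-D transforms along the last axis followed by 1-D transforms along the
# first axis. `dft_1d` and `dft_2d` are hypothetical pure-Python helpers with
# "backward" scaling; the input is the grid from the `fft2` docstring example.

```python
import cmath


def dft_1d(x):
    """O(n^2) forward DFT of one sequence (hypothetical reference helper)."""
    n = len(x)
    return [
        sum(x[m] * cmath.exp(-2j * cmath.pi * k * m / n) for m in range(n))
        for k in range(n)
    ]


def dft_2d(rows):
    """2-D DFT of a list of equal-length rows via two passes of dft_1d."""
    # Pass 1: transform every row, i.e. along the last axis.
    row_pass = [dft_1d(row) for row in rows]
    # Pass 2: transform every column, i.e. along the first axis.
    col_pass = [dft_1d(list(col)) for col in zip(*row_pass)]
    # Transpose back so the result is indexed as [row][column].
    return [list(row) for row in zip(*col_pass)]


grid = [[0.0, 1.0], [0.0, 1.0]]  # same values as np.mgrid[:2, :2][1]
result = dft_2d(grid)
print(result)
```

# The result agrees with the docstring output ``[[2, -2], [0, 0]]`` up to rounding.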
def ifft2(x, s=None, axes=(-2, -1), norm="backward", name=None):
"""
Compute the 2-D inverse discrete Fourier Transform.
This function computes the inverse of the 2-D discrete Fourier
Transform over any number of axes in an M-D array by means of
the Fast Fourier Transform (FFT). In other words, ``ifft2(fft2(x)) == x``
to within numerical accuracy. By default, the inverse transform is
computed over the last two axes of the input array.
The input, analogously to `ifft`, should be ordered in the same way as is
returned by `fft2`, i.e., it should have the term for zero frequency
in the low-order corner of the two axes, the positive frequency terms in
the first half of these axes, the term for the Nyquist frequency in the
middle of the axes and the negative frequency terms in the second half of
both axes, in order of decreasingly negative frequency.
Args:
x (Tensor): The input data. It's a Tensor type.
s (sequence of ints, optional): Shape (length of each transformed axis) of the output.
It should be a sequence of 2 integers. This corresponds to ``n`` for ``fft(x, n)``.
Along each axis, if the given shape is smaller than that of the input,
the input is cropped. If it is larger, the input is padded with zeros.
If `s` is not given, the shape of the input along the axes specified
by `axes` is used. Default is None.
axes (sequence of ints, optional): Axes over which to compute the FFT. It should be a
sequence of 2 integers. If not specified, the last two axes are used by default.
norm (str): Indicates which direction to scale the `forward` or `backward` transform
pair and what normalization factor to use. The parameter value must be one
of "forward" or "backward" or "ortho". Default is "backward".
name (str, optional): The default value is None. Normally there is no need for user to set
this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Complex tensor. The truncated or zero-padded input, transformed along the axes indicated by `axes`,
or the last two axes if `axes` is not given.
Raises:
ValueError: If `s` is neither a sequence of 2 integers nor None.
ValueError: If `axes` is neither a sequence of 2 integers nor None.
ValueError: If the input dimension is smaller than 2.
Examples:
.. code-block:: python
import numpy as np
import paddle
x = np.mgrid[:2, :2][1]
xp = paddle.to_tensor(x)
ifft2_xp = paddle.fft.ifft2(xp).numpy()
print(ifft2_xp)
# [[ 0.5+0.j -0.5+0.j]
# [ 0. +0.j 0. +0.j]]
"""
_check_at_least_ndim(x, 2)
if s is not None:
if not isinstance(s, Sequence) or len(s) != 2:
raise ValueError(
"Invalid FFT argument s ({}), it should be a sequence of 2 integers.".
format(s))
if axes is not None:
if not isinstance(axes, Sequence) or len(axes) != 2:
raise ValueError(
"Invalid FFT argument axes ({}), it should be a sequence of 2 integers.".
format(axes))
return ifftn(x, s, axes, norm, name)
def rfft2(x, s=None, axes=(-2, -1), norm="backward", name=None):
"""
The two dimensional FFT with real tensor input.
This is really just `rfftn` with different default behavior.
For more details see `rfftn`.
Args:
x(Tensor): Input tensor, taken to be real.
s(Sequence[int]) : Shape of the FFT.
axes(Sequence[int], optional): Axes over which to compute the FFT.
norm(str, optional) : {"backward", "ortho", "forward"},
default is "backward". Indicates which direction of the
forward/backward pair of transforms is scaled and with what
normalization factor.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name` .
Returns:
out(Tensor): The result of the real 2-D FFT.
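Raises:
ValueError: If `s` is neither a sequence of 2 integers nor None.
ValueError: If `axes` is neither a sequence of 2 integers nor None.
ValueError: If the input dimension is smaller than 2.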
Examples:
.. code-block:: python
import paddle
import numpy as np
x = paddle.to_tensor(np.mgrid[:5, :5][0].astype(np.float32))
print(paddle.fft.rfft2(x))
# Tensor(shape=[5, 3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True,
# [[ (50+0j) , (1.1920928955078125e-07+0j) , 0j ],
# [(-12.5+17.204774856567383j) , (-9.644234211236835e-08+7.006946134424652e-08j) , 0j ],
# [(-12.500000953674316+4.061495304107666j) , (3.6837697336977726e-08-1.1337477445749755e-07j), 0j ],
# [(-12.500000953674316-4.061495304107666j) , (3.6837697336977726e-08+1.1337477445749755e-07j), 0j ],
# [(-12.5-17.204774856567383j) , (-9.644234211236835e-08-7.006946134424652e-08j) , 0j ]])
"""
_check_at_least_ndim(x, 2)
if s is not None:
if not isinstance(s, Sequence) or len(s) != 2:
raise ValueError(
"Invalid FFT argument s ({}), it should be a sequence of 2 integers.".
format(s))
if axes is not None:
if not isinstance(axes, Sequence) or len(axes) != 2:
raise ValueError(
"Invalid FFT argument axes ({}), it should be a sequence of 2 integers.".
format(axes))
return rfftn(x, s, axes, norm, name)
def irfft2(x, s=None, axes=(-2, -1), norm="backward", name=None):
"""
Computes the inverse of `rfft2`.
Args:
x (Tensor): The input data. It's a Tensor type.
s (sequence of ints, optional): Shape of the real output to the inverse FFT. Default is None.
axes (sequence of ints, optional): The axes over which to compute the inverse FFT. Axes
must be two-dimensional. If not specified, the last two axes are used by default.
norm (str): Indicates which direction to scale the `forward` or `backward` transform
pair and what normalization factor to use. The parameter value must be one
of "forward" or "backward" or "ortho". Default is "backward".
name (str, optional): The default value is None. Normally there is no need for user to set
this property. For more information, please refer to :ref:`api_guide_Name` .
Returns:
Real tensor. The result of the inverse real 2-D FFT.
Raises:
ValueError: If `s` is neither None nor a sequence of 2 integers.
ValueError: If `axes` is neither None nor a sequence of 2 integers.
ValueError: If the input dimension is smaller than 2.
Examples:
.. code-block:: python
import numpy as np
import paddle
x = (np.array([[3,2,3],[2, 2, 3]]) + 1j * np.array([[3,2,3],[2, 2, 3]])).astype(np.complex128)
xp = paddle.to_tensor(x)
irfft2_xp = paddle.fft.irfft2(xp).numpy()
print(irfft2_xp)
# [[ 2.375 -1.125 0.375 0.875]
# [ 0.125 0.125 0.125 0.125]]
"""
_check_at_least_ndim(x, 2)
if s is not None:
if not isinstance(s, Sequence) or len(s) != 2:
raise ValueError(
"Invalid FFT argument s ({}), it should be a sequence of 2 integers.".
format(s))
if axes is not None:
if not isinstance(axes, Sequence) or len(axes) != 2:
raise ValueError(
"Invalid FFT argument axes ({}), it should be a sequence of 2 integers.".
format(axes))
return irfftn(x, s, axes, norm, name)
def hfft2(x, s=None, axes=(-2, -1), norm="backward", name=None):
"""
Compute the 2-D FFT of a Hermitian complex array.
Args:
x (Tensor): The input data. It's a Tensor type.
s (sequence of ints, optional): Shape of the real output. Default is None.
axes (sequence of ints, optional): Axes over which to compute the FFT. Axes must be
two-dimensional. If not specified, the last two axes are used by default.
norm (str): Indicates which direction to scale the `forward` or `backward` transform
pair and what normalization factor to use. The parameter value must be one
of "forward" or "backward" or "ortho". Default is "backward".
name (str, optional): The default value is None. Normally there is no need for user to set
this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Real tensor. The real-valued result of the 2-D FFT of a Hermitian complex input.
Raises:
ValueError: If `s` is neither None nor a sequence of 2 integers.
ValueError: If `axes` is neither None nor a sequence of 2 integers.
ValueError: If the input dimension is smaller than 2.
Examples:
.. code-block:: python
import numpy as np
import paddle
x = (np.array([[3,2,3],[2, 2, 3]]) + 1j * np.array([[3,2,3],[2, 2, 3]])).astype(np.complex128)
xp = paddle.to_tensor(x)
hfft2_xp = paddle.fft.hfft2(xp).numpy()
print(hfft2_xp)
# [[19. 7. 3. -9.]
# [ 1. 1. 1. 1.]]
"""
_check_at_least_ndim(x, 2)
if s is not None:
if not isinstance(s, Sequence) or len(s) != 2:
raise ValueError(
"Invalid FFT argument s ({}), it should be a sequence of 2 integers.".
format(s))
if axes is not None:
if not isinstance(axes, Sequence) or len(axes) != 2:
raise ValueError(
"Invalid FFT argument axes ({}), it should be a sequence of 2 integers.".
format(axes))
return hfftn(x, s, axes, norm, name)
def ihfft2(x, s=None, axes=(-2, -1), norm="backward", name=None):
"""
Compute the two dimensional inverse FFT of a real spectrum.
This is really `ihfftn` with different defaults.
For more details see `ihfftn`.
Args:
x(Tensor): Input tensor
s(Sequence[int], optional): Shape of the real input to the inverse FFT.
axes(Sequence[int], optional): The axes over which to compute the
inverse FFT. Default is the last two axes.
norm(str, optional): {"backward", "ortho", "forward"}. Default is
"backward".
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name` .
Returns:
out(Tensor) : The result of the inverse real 2-D FFT.
"""
_check_at_least_ndim(x, 2)
if s is not None:
if not isinstance(s, Sequence) or len(s) != 2:
raise ValueError(
"Invalid FFT argument s ({}), it should be a sequence of 2 integers.".
format(s))
if axes is not None:
if not isinstance(axes, Sequence) or len(axes) != 2:
raise ValueError(
"Invalid FFT argument axes ({}), it should be a sequence of 2 integers.".
format(axes))
return ihfftn(x, s, axes, norm, name)
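# The 2-D wrappers above repeat the same `s`/`axes` validation. A minimal sketch of
# how that check could be factored out; `check_2d_arg` is a hypothetical helper name
# and is not part of this commit. It raises the same ValueError as the inline checks.

```python
from collections.abc import Sequence


def check_2d_arg(value, arg_name):
    """Raise ValueError unless value is None or a sequence of length 2."""
    if value is None:
        return
    if not isinstance(value, Sequence) or len(value) != 2:
        raise ValueError(
            "Invalid FFT argument {} ({}), it should be a sequence of 2 integers."
            .format(arg_name, value))


check_2d_arg((4, 4), "s")      # accepted
check_2d_arg(None, "axes")     # accepted, the caller falls back to its default
```

# Like the inline checks, this only tests the container type and length, so a
# two-character `str` would also pass; element types are not inspected here.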
# public APIs utilities
def fftfreq(n, d=1.0, dtype=None, name=None):
"""
Return the Discrete Fourier Transform sample frequencies.
The returned float array `f` contains the frequency bin centers in cycles
per unit of the sample spacing (with zero at the start). For instance, if
the sample spacing is in seconds, then the frequency unit is cycles/second.
Given input length `n` and a sample spacing `d`::
f = [0, 1, ..., n/2-1, -n/2, ..., -1] / (d*n) if n is even
f = [0, 1, ..., (n-1)/2, -(n-1)/2, ..., -1] / (d*n) if n is odd
Args:
n (int): Length of the input (number of samples).
d (scalar, optional): Sample spacing (inverse of the sampling rate). Default is 1.
name (str, optional): The default value is None. Normally there is no need for user to set
this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor. A tensor of length `n` containing the sample frequencies.
Examples:
.. code-block:: python
import numpy as np
import paddle
x = np.array([3, 1, 2, 2, 3], dtype=float)
scalar_temp = 0.5
n = x.size
fftfreq_xp = paddle.fft.fftfreq(n, d=scalar_temp)
print(fftfreq_xp)
# Tensor(shape=[5], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [ 0. , 0.40000001, 0.80000001, -0.80000001, -0.40000001])
"""
# Fall back to the default dtype only when the caller did not pass one.
if dtype is None:
dtype = paddle.framework.get_default_dtype()
val = 1.0 / (n * d)
pos_max = (n + 1) // 2
neg_max = n // 2
indices = paddle.arange(-neg_max, pos_max, dtype=dtype, name=name)
indices = paddle.roll(indices, -neg_max, name=name)
return indices * val
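# A minimal pure-Python sketch of the bin ordering computed above, for illustration
# only. `fftfreq_sketch` is a hypothetical name; it mirrors the arange-then-roll
# logic with plain lists instead of tensors.

```python
def fftfreq_sketch(n, d=1.0):
    """Return the DFT sample frequencies for length n and spacing d as a list."""
    val = 1.0 / (n * d)
    pos_max = (n + 1) // 2   # number of non-negative bins: 0 .. pos_max - 1
    neg_max = n // 2         # number of negative bins: -neg_max .. -1
    indices = list(range(-neg_max, pos_max))
    # Rotate left by neg_max so that bin 0 comes first, as roll(-neg_max) does.
    indices = indices[neg_max:] + indices[:neg_max]
    return [i * val for i in indices]


freqs = fftfreq_sketch(5, d=0.5)
print(freqs)
```

# For n = 5 and d = 0.5 the bins are ordered 0, 1, 2, -2, -1 and scaled by 0.4,
# which agrees with the docstring example up to float32 rounding.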
def rfftfreq(n, d=1.0, dtype=None, name=None):
"""
Return the Discrete Fourier Transform sample frequencies.
The returned float tensor `f` contains the frequency bin centers in cycles
per unit of the sample spacing (with zero at the start).
Given input length `n` and a sample spacing `d`::
f = [0, 1, ..., n/2-1, n/2] / (d*n) if n is even
f = [0, 1, ..., (n-1)/2-1, (n-1)/2] / (d*n) if n is odd
Unlike `fftfreq`, the Nyquist frequency component is considered to be positive.
Args:
n (int): Length of the input (number of samples).
d (scalar, optional): Sample spacing (inverse of the sampling rate). Default is 1.
name (str, optional): The default value is None. Normally there is no need for user to set
this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor. A tensor of length ``n//2 + 1`` containing the sample frequencies.
Examples:
.. code-block:: python
import numpy as np
import paddle
x = np.array([3, 1, 2, 2, 3], dtype=float)
scalar_temp = 0.3
n = x.size
rfftfreq_xp = paddle.fft.rfftfreq(n, d=scalar_temp)
print(rfftfreq_xp)
# Tensor(shape=[3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [0. , 0.66666669, 1.33333337])
"""
# Fall back to the default dtype only when the caller did not pass one.
if dtype is None:
dtype = paddle.framework.get_default_dtype()
val = 1.0 / (n * d)
pos_max = 1 + n // 2
indices = paddle.arange(0, pos_max, dtype=dtype, name=name)
return indices * val
def fftshift(x, axes=None, name=None):
"""
Shift the zero-frequency component to the center of the spectrum.
This function swaps half spaces for all the axes listed (all by default).
Note that ``y[0]`` is the Nyquist component only if ``len(x)`` is even.
Args:
x (Tensor): The input tensor.
axes (int|tuple, optional): The axes over which to shift. If None, all axes are shifted.
Default is None.
name (str, optional): The default value is None. Normally there is no need for user to set
this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor. The shifted tensor.
Examples:
.. code-block:: python
import numpy as np
import paddle
x = np.array([3, 1, 2, 2, 3], dtype=float)
scalar_temp = 0.3
n = x.size
fftfreq_xp = paddle.fft.fftfreq(n, d=scalar_temp)
res = paddle.fft.fftshift(fftfreq_xp).numpy()
print(res)
# [-1.3333334 -0.6666667 0. 0.6666667 1.3333334]
"""
shape = paddle.shape(x)
if axes is None:
# shift all axes
rank = paddle.rank(x).reshape([1])
axes = axes or paddle.arange(0, rank)
shifts = [size // 2 for size in shape]
elif isinstance(axes, int):
shifts = shape[axes] // 2
else:
shifts = [shape[ax] // 2 for ax in axes]
return paddle.roll(x, shifts, axes, name=name)
def ifftshift(x, axes=None, name=None):
"""
The inverse of `fftshift`. The two functions are identical for even-length `x`, but
differ by one sample for odd-length `x`.
Args:
x (Tensor): The input tensor.
axes (int|tuple, optional): The axes over which to shift. If None, all axes are shifted.
Default is None.
name (str, optional): The default value is None. Normally there is no need for user to set
this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor. The shifted tensor.
Examples:
.. code-block:: python
import numpy as np
import paddle
x = np.array([3, 1, 2, 2, 3], dtype=float)
scalar_temp = 0.3
n = x.size
fftfreq_xp = paddle.fft.fftfreq(n, d=scalar_temp)
res = paddle.fft.ifftshift(fftfreq_xp).numpy()
print(res)
# [ 1.3333334 -1.3333334 -0.6666667 0. 0.6666667]
"""
shape = paddle.shape(x)
if axes is None:
# shift all axes
rank = paddle.rank(x).reshape([1])
axes = axes or paddle.arange(0, rank)
shifts = [-(size // 2) for size in shape]
elif isinstance(axes, int):
shifts = -(shape[axes] // 2)
else:
shifts = [-(shape[ax] // 2) for ax in axes]
return paddle.roll(x, shifts, axes, name=name)
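# A small pure-Python sketch of why `fftshift` and `ifftshift` differ for odd
# lengths. `roll_1d`, `fftshift_1d` and `ifftshift_1d` are hypothetical names that
# act on plain lists; they only illustrate the shift amounts used above.

```python
def roll_1d(seq, shift):
    """Circularly shift a list right by `shift` positions (negative shifts left)."""
    n = len(seq)
    if n == 0:
        return list(seq)
    k = shift % n
    return list(seq[n - k:]) + list(seq[:n - k])


def fftshift_1d(seq):
    return roll_1d(seq, len(seq) // 2)


def ifftshift_1d(seq):
    return roll_1d(seq, -(len(seq) // 2))


bins = [0, 1, 2, -2, -1]            # fftfreq ordering for n = 5
centered = fftshift_1d(bins)         # zero bin moved to the middle
restored = ifftshift_1d(centered)    # back to the original ordering
twice = fftshift_1d(centered)        # not the original ordering when n is odd
print(centered, restored, twice)
```

# Writing the shift as -(n // 2) keeps the result independent of how integer
# division rounds negative operands.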
# internal functions
def fft_c2c(x, n, axis, norm, forward, name):
if is_interger(x):
x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype()))
elif is_floating_point(x):
x = paddle.cast(x, _real_to_complex_dtype(x.dtype))
_check_normalization(norm)
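# `axis or -1` would silently turn axis=0 into -1, so test for None explicitly.
axis = axis if axis is not None else -1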
_check_fft_axis(x, axis)
axes = [axis]
axes = _normalize_axes(x, axes)
if n is not None:
_check_fft_n(n)
s = [n]
x = _resize_fft_input(x, s, axes)
op_type = 'fft_c2c'
check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
if in_dygraph_mode():
attrs = ('axes', axes, 'normalization', norm, 'forward', forward)
out = getattr(_C_ops, op_type)(x, *attrs)
else:
inputs = {'X': [x], }
attrs = {'axes': axes, 'normalization': norm, 'forward': forward}
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(dtype)
outputs = {"Out": [out]}
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
return out
def fft_r2c(x, n, axis, norm, forward, onesided, name):
if is_interger(x):
x = paddle.cast(x, paddle.get_default_dtype())
_check_normalization(norm)
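# `axis or -1` would silently turn axis=0 into -1, so test for None explicitly.
axis = axis if axis is not None else -1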
_check_fft_axis(x, axis)
axes = [axis]
axes = _normalize_axes(x, axes)
if n is not None:
_check_fft_n(n)
s = [n]
x = _resize_fft_input(x, s, axes)
op_type = 'fft_r2c'
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], op_type)
if in_dygraph_mode():
attrs = ('axes', axes, 'normalization', norm, 'forward', forward,
'onesided', onesided)
out = getattr(_C_ops, op_type)(x, *attrs)
else:
inputs = {'X': [x], }
attrs = {
'axes': axes,
'normalization': norm,
'forward': forward,
'onesided': onesided,
}
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(
_real_to_complex_dtype(dtype))
outputs = {"Out": [out]}
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
return out
def fft_c2r(x, n, axis, norm, forward, name):
if is_interger(x):
x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype()))
elif is_floating_point(x):
x = paddle.cast(x, _real_to_complex_dtype(x.dtype))
_check_normalization(norm)
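# `axis or -1` would silently turn axis=0 into -1, so test for None explicitly.
axis = axis if axis is not None else -1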
_check_fft_axis(x, axis)
axes = [axis]
axes = _normalize_axes(x, axes)
if n is not None:
_check_fft_n(n)
s = [n // 2 + 1]
x = _resize_fft_input(x, s, axes)
op_type = 'fft_c2r'
check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
if in_dygraph_mode():
if n is not None:
attrs = ('axes', axes, 'normalization', norm, 'forward', forward,
'last_dim_size', n)
else:
attrs = ('axes', axes, 'normalization', norm, 'forward', forward)
out = getattr(_C_ops, op_type)(x, *attrs)
else:
inputs = {'X': [x], }
attrs = {'axes': axes, 'normalization': norm, 'forward': forward}
if n is not None:
attrs['last_dim_size'] = n
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(
_complex_to_real_dtype(dtype))
outputs = {"Out": [out]}
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
return out
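# Why `fft_c2r` forwards `last_dim_size`: a onesided spectrum of n // 2 + 1 bins can
# come from two different real lengths. A small arithmetic sketch follows;
# `onesided_length` and `default_real_length` are hypothetical names, and the
# 2 * (m - 1) default is an assumption consistent with the `irfft2` docstring
# example, where 3 input bins give 4 output samples.

```python
def onesided_length(n):
    """Number of complex bins kept for a real signal of length n."""
    return n // 2 + 1


def default_real_length(m):
    """Assumed output length when n is not given: the even candidate."""
    return 2 * (m - 1)


# Real lengths 8 and 9 both produce 5 onesided bins ...
bins_for = {n: onesided_length(n) for n in (8, 9)}
# ... so without an explicit n only the even length can be recovered.
recovered = default_real_length(bins_for[9])
print(bins_for, recovered)
```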
def fftn_c2c(x, s, axes, norm, forward, name):
if is_interger(x):
x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype()))
elif is_floating_point(x):
x = paddle.cast(x, _real_to_complex_dtype(x.dtype))
_check_normalization(norm)
if s is not None:
_check_fft_shape(x, s)
rank = x.ndim
if axes is None:
if s is None:
axes = list(range(rank))
else:
fft_ndims = len(s)
axes = list(range(rank - fft_ndims, rank))
else:
_check_fft_axes(x, axes)
axes = _normalize_axes(x, axes)
axes_argsoft = np.argsort(axes).tolist()
axes = [axes[i] for i in axes_argsoft]
if s is not None:
if len(s) != len(axes):
raise ValueError(
"Length of s ({}) and length of axes ({}) do not match.".
format(len(s), len(axes)))
s = [s[i] for i in axes_argsoft]
if s is not None:
x = _resize_fft_input(x, s, axes)
op_type = 'fft_c2c'
check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
if in_dygraph_mode():
attrs = ('axes', axes, 'normalization', norm, 'forward', forward)
out = getattr(_C_ops, op_type)(x, *attrs)
else:
inputs = {'X': [x], }
attrs = {'axes': axes, 'normalization': norm, 'forward': forward}
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(dtype)
outputs = {"Out": [out]}
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
return out
def fftn_r2c(x, s, axes, norm, forward, onesided, name):
if is_interger(x):
x = paddle.cast(x, paddle.get_default_dtype())
_check_normalization(norm)
if s is not None:
_check_fft_shape(x, s)
rank = x.ndim
if axes is None:
if s is None:
axes = list(range(rank))
else:
fft_ndims = len(s)
axes = list(range(rank - fft_ndims, rank))
else:
_check_fft_axes(x, axes)
axes = _normalize_axes(x, axes)
axes_argsoft = np.argsort(axes[:-1]).tolist()
axes = [axes[i] for i in axes_argsoft] + [axes[-1]]
if s is not None:
if len(s) != len(axes):
raise ValueError(
"Length of s ({}) and length of axes ({}) do not match.".
format(len(s), len(axes)))
s = [s[i] for i in axes_argsoft] + [s[-1]]
if s is not None:
x = _resize_fft_input(x, s, axes)
op_type = 'fft_r2c'
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], op_type)
if in_dygraph_mode():
attrs = ('axes', axes, 'normalization', norm, 'forward', forward,
'onesided', onesided)
out = getattr(_C_ops, op_type)(x, *attrs)
else:
inputs = {'X': [x], }
attrs = {
'axes': axes,
'normalization': norm,
'forward': forward,
'onesided': onesided,
}
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(
_real_to_complex_dtype(dtype))
outputs = {"Out": [out]}
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
return out
def fftn_c2r(x, s, axes, norm, forward, name):
if is_interger(x):
x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype()))
elif is_floating_point(x):
x = paddle.cast(x, _real_to_complex_dtype(x.dtype))
_check_normalization(norm)
if s is not None:
_check_fft_shape(x, s)
rank = x.ndim
if axes is None:
if s is None:
axes = list(range(rank))
else:
fft_ndims = len(s)
axes = list(range(rank - fft_ndims, rank))
else:
_check_fft_axes(x, axes)
axes = _normalize_axes(x, axes)
axes_argsoft = np.argsort(axes[:-1]).tolist()
axes = [axes[i] for i in axes_argsoft] + [axes[-1]]
if s is not None:
if len(s) != len(axes):
raise ValueError(
"Length of s ({}) and length of axes ({}) do not match.".
format(len(s), len(axes)))
s = [s[i] for i in axes_argsoft] + [s[-1]]
if s is not None:
fft_input_shape = list(s)
fft_input_shape[-1] = fft_input_shape[-1] // 2 + 1
x = _resize_fft_input(x, fft_input_shape, axes)
op_type = 'fft_c2r'
check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
if in_dygraph_mode():
if s:
attrs = ('axes', axes, 'normalization', norm, 'forward', forward,
'last_dim_size', s[-1])
else:
attrs = ('axes', axes, 'normalization', norm, 'forward', forward)
out = getattr(_C_ops, op_type)(x, *attrs)
else:
inputs = {'X': [x], }
attrs = {'axes': axes, 'normalization': norm, 'forward': forward}
if s:
attrs["last_dim_size"] = s[-1]
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(
_complex_to_real_dtype(dtype))
outputs = {"Out": [out]}
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
return out
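# The N-D helpers above sort `axes` and permute `s` with the same index order, and
# the r2c/c2r variants leave the last axis where the caller put it. A minimal sketch
# of that bookkeeping with plain lists; `sort_axes_with_shape` and `keep_last` are
# hypothetical names, and `sorted` stands in for `np.argsort`.

```python
def sort_axes_with_shape(axes, s=None, keep_last=False):
    """Sort axes ascending and reorder s to match; optionally pin the last axis."""
    head = axes[:-1] if keep_last else axes
    order = sorted(range(len(head)), key=lambda i: head[i])
    if keep_last:
        order = order + [len(axes) - 1]   # the last axis stays last
    sorted_axes = [axes[i] for i in order]
    sorted_s = None if s is None else [s[i] for i in order]
    return sorted_axes, sorted_s


print(sort_axes_with_shape([2, 0, 1], [8, 4, 6]))
print(sort_axes_with_shape([2, 0, 1], [8, 4, 6], keep_last=True))
```

# Pinning the last axis matters for the real transforms because the halved
# (onesided) dimension is always the last entry of `axes`.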
@@ -682,7 +682,7 @@ def roll(x, shifts, axis=None, name=None):
 axis = [axis]
 len_origin_shape = len(origin_shape)
-if axis:
+if axis is not None:
 for i in range(len(axis)):
 if axis[i] >= len_origin_shape or axis[i] < -len_origin_shape:
 raise ValueError(
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import paddle
from .attribute import is_complex, is_floating_point
from .fft import fft_r2c, fft_c2r, fft_c2c
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.framework import in_dygraph_mode
from ..fluid.layer_helper import LayerHelper
from .. import _C_ops
__all__ = [
'frame',
'overlap_add',
'stft',
'istft',
]
def frame(x, frame_length, hop_length, axis=-1, name=None):
"""
Slice the N-dimensional (where N >= 1) input into (overlapping) frames.
Args:
x (Tensor): The input data which is an N-dimensional (where N >= 1) Tensor
with shape `[..., seq_length]` or `[seq_length, ...]`.
frame_length (int): Length of the frame and `0 < frame_length <= x.shape[axis]`.
hop_length (int): Number of steps to advance between adjacent frames
and `0 < hop_length`.
axis (int, optional): Specify the axis to operate on the input Tensors. Its
value should be 0(the first dimension) or -1(the last dimension). If not
specified, the last axis is used by default.
Returns:
The output frames tensor with shape `[..., frame_length, num_frames]` if `axis==-1`,
otherwise `[num_frames, frame_length, ...]` where
`num_frames = 1 + (x.shape[axis] - frame_length) // hop_length`
Examples:
.. code-block:: python
import paddle
from paddle.tensor.signal import frame
# 1D
x = paddle.arange(8)
y0 = frame(x, frame_length=4, hop_length=2, axis=-1) # [4, 3]
# [[0, 2, 4],
# [1, 3, 5],
# [2, 4, 6],
# [3, 5, 7]]
y1 = frame(x, frame_length=4, hop_length=2, axis=0) # [3, 4]
# [[0, 1, 2, 3],
# [2, 3, 4, 5],
# [4, 5, 6, 7]]
# 2D
x0 = paddle.arange(16).reshape([2, 8])
y0 = frame(x0, frame_length=4, hop_length=2, axis=-1) # [2, 4, 3]
# [[[0, 2, 4],
# [1, 3, 5],
# [2, 4, 6],
# [3, 5, 7]],
#
# [[8 , 10, 12],
# [9 , 11, 13],
# [10, 12, 14],
# [11, 13, 15]]]
x1 = paddle.arange(16).reshape([8, 2])
y1 = frame(x1, frame_length=4, hop_length=2, axis=0) # [3, 4, 2]
# [[[0 , 1 ],
# [2 , 3 ],
# [4 , 5 ],
# [6 , 7 ]],
#
# [[4 , 5 ],
# [6 , 7 ],
# [8 , 9 ],
# [10, 11]],
#
# [[8 , 9 ],
# [10, 11],
# [12, 13],
# [14, 15]]]
# > 2D
x0 = paddle.arange(32).reshape([2, 2, 8])
y0 = frame(x0, frame_length=4, hop_length=2, axis=-1) # [2, 2, 4, 3]
x1 = paddle.arange(32).reshape([8, 2, 2])
y1 = frame(x1, frame_length=4, hop_length=2, axis=0) # [3, 4, 2, 2]
"""
if axis not in [0, -1]:
raise ValueError(f'Unexpected axis: {axis}. It should be 0 or -1.')
if not isinstance(frame_length, int) or frame_length <= 0:
raise ValueError(
f'Unexpected frame_length: {frame_length}. It should be a positive integer.'
)
if not isinstance(hop_length, int) or hop_length <= 0:
raise ValueError(
f'Unexpected hop_length: {hop_length}. It should be a positive integer.'
)
if frame_length > x.shape[axis]:
raise ValueError(
f'Attribute frame_length should be less than or equal to sequence length, '
f'but got ({frame_length}) > ({x.shape[axis]}).')
op_type = 'frame'
if in_dygraph_mode():
attrs = ('frame_length', frame_length, 'hop_length', hop_length, 'axis',
axis)
op = getattr(_C_ops, op_type)
out = op(x, *attrs)
else:
check_variable_and_dtype(
x, 'x', ['int32', 'int64', 'float16', 'float32',
'float64'], op_type)
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op(
type=op_type,
inputs={'X': x},
attrs={
'frame_length': frame_length,
'hop_length': hop_length,
'axis': axis
},
outputs={'Out': out})
return out
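# A pure-Python sketch of the framing rule documented above, restricted to 1-D
# input and the `axis=0` layout `[num_frames, frame_length]`. `frame_1d` is a
# hypothetical name used only for illustration; the real op works on tensors.

```python
def frame_1d(seq, frame_length, hop_length):
    """Slice a 1-D sequence into overlapping frames of equal length."""
    if frame_length <= 0 or hop_length <= 0 or frame_length > len(seq):
        raise ValueError("invalid frame_length or hop_length")
    num_frames = 1 + (len(seq) - frame_length) // hop_length
    return [
        list(seq[t * hop_length:t * hop_length + frame_length])
        for t in range(num_frames)
    ]


frames = frame_1d(list(range(8)), frame_length=4, hop_length=2)
print(frames)
```

# Trailing samples that do not fill a whole frame are dropped, which is what the
# floor division in the `num_frames` formula expresses.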
def overlap_add(x, hop_length, axis=-1, name=None):
"""
Reconstructs a tensor consisting of overlap-added sequences from input frames.
Args:
x (Tensor): The input data which is an N-dimensional (where N >= 2) Tensor
with shape `[..., frame_length, num_frames]` or
`[num_frames, frame_length, ...]`.
hop_length (int): Number of steps to advance between adjacent frames and
`0 < hop_length <= frame_length`.
axis (int, optional): Specify the axis to operate on the input Tensors. Its
value should be 0(the first dimension) or -1(the last dimension). If not
specified, the last axis is used by default.
Returns:
The output tensor with shape `[..., seq_length]` if `axis==-1`,
otherwise `[seq_length, ...]` where
`seq_length = (num_frames - 1) * hop_length + frame_length`
Examples:
.. code-block:: python
import paddle
from paddle.tensor.signal import overlap_add
# 2D
x0 = paddle.arange(16).reshape([8, 2])
# [[0 , 1 ],
# [2 , 3 ],
# [4 , 5 ],
# [6 , 7 ],
# [8 , 9 ],
# [10, 11],
# [12, 13],
# [14, 15]]
y0 = overlap_add(x0, hop_length=2, axis=-1) # [10]
# [0 , 2 , 5 , 9 , 13, 17, 21, 25, 13, 15]
x1 = paddle.arange(16).reshape([2, 8])
# [[0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ],
# [8 , 9 , 10, 11, 12, 13, 14, 15]]
y1 = overlap_add(x1, hop_length=2, axis=0) # [10]
# [0 , 1 , 10, 12, 14, 16, 18, 20, 14, 15]
# > 2D
x0 = paddle.arange(32).reshape([2, 1, 8, 2])
y0 = overlap_add(x0, hop_length=2, axis=-1) # [2, 1, 10]
x1 = paddle.arange(32).reshape([2, 8, 1, 2])
y1 = overlap_add(x1, hop_length=2, axis=0) # [10, 1, 2]
"""
if axis not in [0, -1]:
raise ValueError(f'Unexpected axis: {axis}. It should be 0 or -1.')
if not isinstance(hop_length, int) or hop_length <= 0:
raise ValueError(
f'Unexpected hop_length: {hop_length}. It should be a positive integer.'
)
op_type = 'overlap_add'
if in_dygraph_mode():
attrs = ('hop_length', hop_length, 'axis', axis)
op = getattr(_C_ops, op_type)
out = op(x, *attrs)
else:
check_variable_and_dtype(
x, 'x', ['int32', 'int64', 'float16', 'float32',
'float64'], op_type)
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op(
type=op_type,
inputs={'X': x},
attrs={'hop_length': hop_length,
'axis': axis},
outputs={'Out': out})
return out
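# A pure-Python sketch of the overlap-add rule documented above, for frames given
# as a list of equal-length lists (the `axis=0` layout). `overlap_add_1d` is a
# hypothetical name used only for illustration.

```python
def overlap_add_1d(frames, hop_length):
    """Sum frames into one sequence, each frame offset by hop_length samples."""
    num_frames = len(frames)
    frame_length = len(frames[0])
    seq_length = (num_frames - 1) * hop_length + frame_length
    out = [0] * seq_length
    for t, current in enumerate(frames):
        for i, value in enumerate(current):
            out[t * hop_length + i] += value
    return out


# The two rows correspond to the two columns of x0 in the docstring example.
frames = [list(range(0, 16, 2)), list(range(1, 16, 2))]
signal = overlap_add_1d(frames, hop_length=2)
print(signal)
```

# Samples covered by more than one frame are summed, so overlap-add only undoes
# framing after the overlap weights have been divided out.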
def stft(x,
n_fft,
hop_length=None,
win_length=None,
window=None,
center=True,
pad_mode='reflect',
normalized=False,
onesided=True,
name=None):
"""
Short-time Fourier transform (STFT).
The STFT computes the discrete Fourier transforms (DFT) of short overlapping
windows of the input using this formula:
.. math::
X_t[\omega] = \sum_{n = 0}^{N-1}%
\text{window}[n]\ x[t \times H + n]\ %
e^{-{2 \pi j \omega n}/{N}}
Where:
- :math:`t`: The :math:`t`-th input window.
- :math:`\omega`: Frequency :math:`0 \leq \omega < \text{n\_fft}` for `onesided=False`,
or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for `onesided=True`.
- :math:`N`: Value of `n_fft`.
- :math:`H`: Value of `hop_length`.
Args:
x (Tensor): The input data which is a 1-dimensional or 2-dimensional Tensor with
shape `[..., seq_length]`. It can be a real-valued or a complex Tensor.
n_fft (int): The number of input samples to perform Fourier transform.
hop_length (int, optional): Number of steps to advance between adjacent windows
and `0 < hop_length`. Default: `None`(treated as equal to `n_fft//4`)
win_length (int, optional): The size of window. Default: `None`(treated as equal
to `n_fft`)
window (Tensor, optional): A 1-dimensional tensor of size `win_length`. It will
be center padded to length `n_fft` if `win_length < n_fft`. Default: `None`(
treated as a rectangle window with value equal to 1 of size `win_length`).
center (bool, optional): Whether to pad `x` so that the :math:`t`-th frame is
centered at time :math:`t \times hop\_length`. Default: `True`.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. See
`paddle.nn.functional.pad` for all padding options. Default: `"reflect"`
normalized (bool, optional): Control whether to scale the output by `1/sqrt(n_fft)`.
Default: `False`
onesided (bool, optional): Control whether to return half of the Fourier transform
output that satisfies the conjugate symmetry condition when input is a real-valued
tensor. It can not be `True` if input is a complex tensor. Default: `True`
name (str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
The complex STFT output tensor with shape `[..., n_fft//2 + 1, num_frames]`(
real-valued input and `onesided` is `True`) or `[..., n_fft, num_frames]`(
`onesided` is `False`)
Examples:
.. code-block:: python
import paddle
from paddle.tensor.signal import stft
# real-valued input
x = paddle.randn([8, 48000], dtype=paddle.float64)
y1 = stft(x, n_fft=512) # [8, 257, 376]
y2 = stft(x, n_fft=512, onesided=False) # [8, 512, 376]
# complex input
x = paddle.randn([8, 48000], dtype=paddle.float64) + \
paddle.randn([8, 48000], dtype=paddle.float64)*1j # [8, 48000] complex128
y1 = stft(x, n_fft=512, center=False, onesided=False) # [8, 512, 372]
"""
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'complex64', 'complex128'],
'stft')
x_rank = len(x.shape)
assert x_rank in [1, 2], \
f'x should be a 1D or 2D tensor, but got rank of x is {x_rank}'
if x_rank == 1: # (batch, seq_length)
x = x.unsqueeze(0)
if hop_length is None:
hop_length = int(n_fft // 4)
assert hop_length > 0, \
f'hop_length should be > 0, but got {hop_length}.'
if win_length is None:
win_length = n_fft
assert 0 < n_fft <= x.shape[-1], \
f'n_fft should be in (0, seq_length({x.shape[-1]})], but got {n_fft}.'
assert 0 < win_length <= n_fft, \
f'win_length should be in (0, n_fft({n_fft})], but got {win_length}.'
if window is not None:
assert len(window.shape) == 1 and len(window) == win_length, \
f'expected a 1D window tensor of size equal to win_length({win_length}), but got window with shape {window.shape}.'
else:
window = paddle.ones(shape=(win_length, ), dtype=x.dtype)
if win_length < n_fft:
pad_left = (n_fft - win_length) // 2
pad_right = n_fft - win_length - pad_left
window = paddle.nn.functional.pad(window,
pad=[pad_left, pad_right],
mode='constant')
if center:
assert pad_mode in ['constant', 'reflect'], \
'pad_mode should be "reflect" or "constant", but got "{}".'.format(pad_mode)
pad_length = n_fft // 2
# FIXME: Input `x` can be a complex tensor but pad does not support complex input.
x = paddle.nn.functional.pad(x.unsqueeze(-1),
pad=[pad_length, pad_length],
mode=pad_mode,
data_format="NLC").squeeze(-1)
x_frames = frame(x=x, frame_length=n_fft, hop_length=hop_length, axis=-1)
x_frames = x_frames.transpose(
perm=[0, 2,
1]) # switch n_fft to last dim, e.g. (batch, num_frames, n_fft)
x_frames = x_frames * window
norm = 'ortho' if normalized else 'backward'
if is_complex(x_frames):
assert not onesided, \
'onesided should be False when input or window is a complex Tensor.'
if not is_complex(x):
out = fft_r2c(
x=x_frames,
n=None,
axis=-1,
norm=norm,
forward=True,
onesided=onesided,
name=name)
else:
out = fft_c2c(
x=x_frames, n=None, axis=-1, norm=norm, forward=True, name=name)
out = out.transpose(perm=[0, 2, 1]) # (batch, n_fft, num_frames)
if x_rank == 1:
out.squeeze_(0)
return out
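# A sketch of the shape arithmetic in `stft`: how many frames come out, and how
# a short window is center padded to `n_fft`. `stft_num_frames` and
# `window_padding` are hypothetical helper names, not part of this module.

```python
def stft_num_frames(seq_length, n_fft, hop_length=None, center=True):
    """Number of STFT frames for a signal of the given length."""
    if hop_length is None:
        hop_length = n_fft // 4
    if center:
        seq_length += 2 * (n_fft // 2)   # n_fft // 2 samples padded on each side
    return 1 + (seq_length - n_fft) // hop_length


def window_padding(win_length, n_fft):
    """Left and right zero padding that centers the window inside n_fft."""
    pad_left = (n_fft - win_length) // 2
    pad_right = n_fft - win_length - pad_left
    return pad_left, pad_right


print(stft_num_frames(48000, 512))                # centered, as in the first example
print(stft_num_frames(48000, 512, center=False))  # as in the complex example
print(window_padding(401, 512))                   # odd remainder goes to the right
```

# These counts match the shapes quoted in the docstring examples (376 and 372
# frames for a 48000-sample input with `n_fft=512`).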
def istft(x,
n_fft,
hop_length=None,
win_length=None,
window=None,
center=True,
normalized=False,
onesided=True,
length=None,
return_complex=False,
name=None):
"""
Inverse short-time Fourier transform (ISTFT).
Reconstruct the time-domain signal from the given complex input and window tensor when
the nonzero overlap-add (NOLA) condition is met:
.. math::
\sum_{t = -\infty}^{\infty}%
\text{window}^2[n - t \times H]\ \neq \ 0, \ \text{for } all \ n
Where:
- :math:`t`: The :math:`t`-th input window.
- :math:`N`: Value of `n_fft`.
- :math:`H`: Value of `hop_length`.
The result of `istft` is expected to be the inverse of `paddle.tensor.signal.stft`, but it is
not guaranteed to reconstruct an exactly realizable time-domain signal from an STFT
complex tensor that has been modified (via masking or otherwise). In that case, `istft`
gives the [Griffin-Lim optimal estimate](https://ieeexplore.ieee.org/document/1164317)
(optimal in a least-squares sense) for the corresponding signal.
Args:
x (Tensor): The input data which is a 2-dimensional or 3-dimensional **complex**
Tensor with shape `[..., n_fft, num_frames]`.
n_fft (int): The size of Fourier transform.
hop_length (int, optional): Number of steps to advance between adjacent windows
from time-domain signal and `0 < hop_length <= win_length`. Default: `None`(
treated as equal to `n_fft//4`)
win_length (int, optional): The size of window. Default: `None`(treated as equal
to `n_fft`)
window (Tensor, optional): A 1-dimensional tensor of size `win_length`. It will
be center padded to length `n_fft` if `win_length < n_fft`. It should be a
real-valued tensor if `return_complex` is False. Default: `None`(treated as
a rectangle window with value equal to 1 of size `win_length`).
center (bool, optional): Whether the time-domain signal was center padded
before the STFT. Default: `True`.
normalized (bool, optional): Control whether to scale the output by `1/sqrt(n_fft)`.
Default: `False`
onesided (bool, optional): Whether the input STFT tensor is one half of the
conjugate-symmetric STFT of a real-valued signal. `istft` returns a
real-valued tensor when it is set to `True`.
Default: `True`.
length (int, optional): Specify the length of time-domain signal. Default: `None`(
treated as the whole length of signal).
return_complex (bool, optional): Whether the time-domain signal is
complex-valued. If `return_complex` is set to `True`, `onesided` should be set to
`False` because the output is complex. Default: `False`.
name (str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
A tensor holding the least-squares estimate of the reconstructed signal(s), with shape
`[..., seq_length]`.
Examples:
.. code-block:: python
import numpy as np
import paddle
from paddle.tensor.signal import stft, istft
paddle.seed(0)
# STFT
x = paddle.randn([8, 48000], dtype=paddle.float64)
y = stft(x, n_fft=512) # [8, 257, 376]
# ISTFT
x_ = istft(y, n_fft=512) # [8, 48000]
np.allclose(x, x_) # True
"""
    check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], 'istft')
    x_rank = len(x.shape)
    assert x_rank in [2, 3], \
        'x should be a 2D or 3D complex tensor, but got rank of x is {}'.format(x_rank)
    if x_rank == 2:  # (batch, n_fft, n_frames)
        x = x.unsqueeze(0)
    if hop_length is None:
        hop_length = int(n_fft // 4)
    if win_length is None:
        win_length = n_fft
    # Assure no gaps between frames.
    assert 0 < hop_length <= win_length, \
        'hop_length should be in (0, win_length({})], but got {}.'.format(win_length, hop_length)
    assert 0 < win_length <= n_fft, \
        'win_length should be in (0, n_fft({})], but got {}.'.format(n_fft, win_length)
    n_frames = x.shape[-1]
    fft_size = x.shape[-2]
    if onesided:
        assert (fft_size == n_fft // 2 + 1), \
            'fft_size should be equal to n_fft // 2 + 1({}) when onesided is True, but got {}.'.format(n_fft // 2 + 1, fft_size)
    else:
        assert (fft_size == n_fft), \
            'fft_size should be equal to n_fft({}) when onesided is False, but got {}.'.format(n_fft, fft_size)
    if window is not None:
        assert len(window.shape) == 1 and len(window) == win_length, \
            'expected a 1D window tensor of size equal to win_length({}), but got window with shape {}.'.format(win_length, window.shape)
    else:
        window = paddle.ones(shape=(win_length, ))
    if win_length < n_fft:
        pad_left = (n_fft - win_length) // 2
        pad_right = n_fft - win_length - pad_left
        # FIXME: Input `window` can be a complex tensor but pad does not support complex input.
        window = paddle.nn.functional.pad(window,
                                          pad=[pad_left, pad_right],
                                          mode='constant')
    x = x.transpose(
        perm=[0, 2,
              1])  # switch n_fft to last dim, e.g. (batch, num_frames, n_fft)
    norm = 'ortho' if normalized else 'backward'
    if return_complex:
        assert not onesided, \
            'onesided should be False when input(output of istft) or window is a complex Tensor.'
        out = fft_c2c(x=x, n=None, axis=-1, norm=norm, forward=False, name=None)
    else:
        assert not is_complex(window), \
            'Data type of window should not be complex when return_complex is False.'
        if onesided is False:
            x = x[:, :, :n_fft // 2 + 1]
        out = fft_c2r(x=x, n=None, axis=-1, norm=norm, forward=False, name=None)
    out = overlap_add(
        x=(out * window).transpose(
            perm=[0, 2, 1]),  # (batch, n_fft, num_frames)
        hop_length=hop_length,
        axis=-1)  # (batch, seq_length)
    window_envelop = overlap_add(
        x=paddle.tile(
            x=window * window, repeat_times=[n_frames, 1]).transpose(
                perm=[1, 0]),  # (n_fft, num_frames)
        hop_length=hop_length,
        axis=-1)  # (seq_length, )
    if length is None:
        if center:
            out = out[:, (n_fft // 2):-(n_fft // 2)]
            window_envelop = window_envelop[(n_fft // 2):-(n_fft // 2)]
    else:
        if center:
            start = n_fft // 2
        else:
            start = 0
        out = out[:, start:start + length]
        window_envelop = window_envelop[start:start + length]
    # Check whether the Nonzero Overlap Add (NOLA) constraint is met.
    if window_envelop.abs().min().item() < 1e-11:
        raise ValueError(
            'Abort istft because Nonzero Overlap Add (NOLA) condition failed. For more information about NOLA constraint please see `scipy.signal.check_NOLA`(https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.check_NOLA.html).'
        )
    out = out / window_envelop
    if x_rank == 2:
        out.squeeze_(0)
    return out
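The overlap-add and NOLA steps at the end of `istft` can be illustrated outside Paddle. The following is a minimal NumPy sketch, not the operator added in this commit: the helper names `overlap_add_1d`, `window_envelope`, and `satisfies_nola` are hypothetical, and the sketch assumes a single 1-D window already padded to length `n_fft`. It sums the squared window over all frames and checks that the result stays away from zero over the region that `istft` keeps.

```python
import numpy as np


def overlap_add_1d(frames, hop_length):
    """Sum frames of shape (frame_length, num_frames) into one 1-D signal."""
    frame_length, num_frames = frames.shape
    seq_length = (num_frames - 1) * hop_length + frame_length
    out = np.zeros(seq_length, dtype=frames.dtype)
    for i in range(num_frames):
        start = i * hop_length
        out[start:start + frame_length] += frames[:, i]
    return out


def window_envelope(window, num_frames, hop_length):
    """Overlap-add of the squared window (hypothetical helper for this sketch)."""
    squared = np.tile((window * window)[:, None], (1, num_frames))
    return overlap_add_1d(squared, hop_length)


def satisfies_nola(window, num_frames, hop_length, center=True, tol=1e-11):
    """Return True if the envelope, trimmed like istft when center=True,
    never drops below tol (the threshold used in the code above)."""
    n_fft = len(window)
    envelope = window_envelope(window, num_frames, hop_length)
    if center:
        envelope = envelope[n_fft // 2:-(n_fft // 2)]
    return bool(np.abs(envelope).min() >= tol)
```

With the docstring example (376 frames, `n_fft=512`, `hop_length=128`), the envelope has `(376 - 1) * 128 + 512 = 48512` samples, and trimming `n_fft // 2 = 256` samples from each side leaves the 48000 samples of the original signal. A window whose nonzero support is shorter than `hop_length` leaves gaps in the envelope, so dividing by it would be undefined, which is the case the `ValueError` guards against.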
@@ -6,6 +6,7 @@ gym
 opencv-python<=4.2.0.32
 visualdl
 paddle2onnx>=0.4
-scipy
+scipy>=1.6
 prettytable
 distro
+numpy>=1.20