From 11518a436183474458e46be1239084fba6775d99 Mon Sep 17 00:00:00 2001
From: Feiyu Chan
Date: Sat, 18 Sep 2021 13:49:41 +0800
Subject: [PATCH] Add FFT related operators and APIs (#35665)

* 1. add interface for fft; 2. add data type predicate; 3. fix paddle.roll.
* add fft c2c cufft kernel
* implement argument checking & op calling parts for fft_c2c and fftn_c2c
* add operator and opmaker definitions
* only register float and double for cpu.
* add common code for implementing FFT, add pocketfft as a dependency
* add fft c2c cufft kernel function
* fix bugs in python interface
* add support for c2r, r2c operators, op makers, kernels and kernel functors.
* test and fix bugs
* 1. fft_c2c function: add support for onesided=False; 2. add complex data type support for concat and flip.
* 1. fft: fix python api bugs; 2. shape_op: add support for complex data types.
* fft c2c cufft kernel now compiles and links
* fix shape_op, add mkl placeholder
* remove mkl
* complete fft c2c on gpu
* 1. implement mkl-based fft, FFTC2CFunctor and common function exec_fft; 2. change the design, add input and output typename as template parameter for all FFTFunctors, update pocketfft-based implementation.
* complete fft c2c on gpu in ND
* complete fft c2c on gpu in ND
* complete fft c2c backward in ND
* fix MKL-based implementation
* Add frame op and CPU/GPU kernels.
* Add frame op forward unittest.
* Add frame op forward unittest.
* Remove axis parameter in FrameFunctor.
* Add frame op grad CPU/GPU kernels and unittest.
* Add frame op grad CPU/GPU kernels and unittest.
* Update doc string.
* Update after review and remove librosa requirement in unittest.
* Update grad kernel.
* add fft_c2r op
* Remove data allocation in TransCompute function.
* add fft r2c onesided with cpu(pocketfft/mkl) and gpu
* last fft c2r functor
* fix C2R and R2C for cufft, because the direction is not an option in these cases.
* add fft r2c onesided with cpu(pocketfft/mkl) and gpu
* fix bugs in python APIs
* fix fft_c2r grad kernel
* fix bugs in python APIs
* add cuda fft c2r grad kernel functor
* clean code
* fix fft_c2r python API
* fill fft r2c result with conjugate symmetry (#19)
* add placeholder for unittests (#24)
* simplify parameterized test functions by auto-generating test cases from a param list (#25)
* miscellaneous fixes for python APIs (#26)
* add placeholder for unittests
* resize fft inputs before computation if n or s is provided.
* add complex kernels for pad and pad_grad
* simplify argument checking.
* add type promotion
* add int to float or complex promotion
* fix output data type for static mode
* fix fft's input dtype dispatch, import fft to paddle
* fix typos in axes checking (#27)
* fix typos in axes checking
* fix argument checking (#28)
* fix argument checking
* Add C2R Python layer normal and abnormal use cases (#29)
* documents and single case
* test c2r case
* New C2R Python layer normal and exception use cases
* complete rfft, rfft2, rfftn, ihfft, ihfft2, ihfftn unittest and doc string (#30)
* Documentation of the common interfaces of c2r and c2c (#31)
* Documentation of the common interfaces of c2r and c2c
* clean c++ code (#32)
* clean code
* Add numpy-based implementation of spectral ops (#33)
* add numpy reference implementation of spectral ops
* Add fft_c2r numpy based implementation for unittest. (#34)
* add fft_c2r numpy implementation
* Add deframe op and stft/istft api. (#23)
* Add frame api
* Add deframe op and kernels.
* Add stft and istft apis.
* Add deframe api. Update stft and istft apis.
* Fix bug in frame_from_librosa function when input dims >= 3
* Rename deframe to overlap_add.
* Update istft.
* Update after code review.
* Add overlap_add op and stft/istft api unittest (#35)
* Add overlap_add op unittest.
* Register complex kernels of squeeze/unsqueeze op.
* Add stft/istft api unittest.
* Add unittest for fft helper functions (#36)
* add unittests for fft helper functions. add complex kernel for roll op.
* complete static graph unittest for all public api (#37)
* Unittest of op with FFT C2C, C2R and R2C added (#38)
* documents and single case
* test c2r case
* New C2R Python layer normal and exception use cases
* Documentation of the common interfaces of c2r and c2c
* Unittest of op with FFT C2C, C2R and R2C added
  Co-authored-by: lijiaqi
* add fft related options to CMakeLists.txt
* fix typos and clean code (#39)
* fix invisible character in mkl branch and fix error in error message
* clean code: remove docstring from unittest for signal.py.
* always convert numpy array to paddle.Tensor to avoid comparing numpy dtype with paddle dtype. (#40)
* always convert numpy array to paddle.Tensor to avoid comparing numpy dtype with paddle dtype.
* fix CI Errors: numpy dtype comparison, thrust when cuda is not available (#41)
  1. always convert numpy array to paddle.Tensor to avoid comparing numpy dtype with paddle dtype.
  2. promote floating point tensor to complex tensor for fft_c2c and fft_c2r;
  3. fix unittest to catch UnImplementedError and RuntimeError;
  4. fix compile error by avoiding thrust when cuda is not available.
  5. fix sample code, use paddle.fft instead of paddle.tensor.fft
* remove inclusion of thrust, add __all__ list for fft (#42)
* Add api doc and update unittest. (#43)
* Add doc strings.
* Update overlap_add op unittest
* fix MKL-based FFT implementation (#44)
* fix MKL-based FFT implementation, MKL CDFT's FORWARD DOMAIN is always REAL for R2C and C2R
* remove code for debug (#45)
* use dynload for cufft (#46)
* use std::ptrdiff_t as datatype of stride (instead of int64_t) to avoid argument mismatch on some platforms.
* add complex support for fill_zeros_like
* use dynload for cufft
* Update doc and unittest. (#47)
* Add doc of frame op and overlap_add op.
* Update unittest.
* use dynload for cufft (#48)
  1. use dynload for cufft;
  2. fix unittest;
  3. temporarily disable ROCm.
* fix conflicts and merge upstream (#49)
* fix compile error: only link dynload_cuda when cuda is available (#50)
* fix compile error: only link dynload_cuda when cuda is available
* fix dynload for cufft on windows (#51)
  1. fix dynload for cufft on windows;
  2. fix unittests.
* add NOMINMAX to compile on windows (#52)
* explicitly specify capture mode for lambdas (#55)
* fix fft sample (#53)
* fix fft sample
* update scipy and numpy version for unittests of fft (#56)
* Add static graph unittests of frame and overlap_add api. (#57)
* Remove cache of cuFFT & Disable ONEMKL (#59)
  1. replace numpy.fft with scipy.fft as numpy<1.20 does not support ortho norm;
  2. remove cache of cufft plans;
  3. enhance error checking;
  4. default WITH_ONEMKL to OFF

Co-authored-by: jeff41404
Co-authored-by: root
Co-authored-by: KP <109694228@qq.com>
Co-authored-by: lijiaqi
Co-authored-by: Xiaoxu Chen
Co-authored-by: lijiaqi0612 <33169170+lijiaqi0612@users.noreply.github.com>
---
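A minimal usage sketch of the new paddle.fft Python APIs introduced by this patch, assuming the numpy.fft-style signatures that the added unittests compare against; the automatic promotion of floating point input mentioned in the commit message is relied on for the c2c case.

import numpy as np
import paddle

x = paddle.to_tensor(np.random.randn(4, 8).astype("float32"))

# r2c: real input, onesided complex output with 8 // 2 + 1 = 5 frequency bins.
freq = paddle.fft.rfft(x, axis=-1)

# c2r: inverse transform; n restores the original signal length.
recovered = paddle.fft.irfft(freq, n=8, axis=-1)

# c2c: N-dimensional transform; floating point input is promoted to a complex
# tensor, and norm="ortho" selects orthonormal scaling.
spec = paddle.fft.fftn(x, norm="ortho")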
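Likewise, a NumPy reference sketch (an illustration only, not the patch's kernels) of the frame / overlap_add semantics implemented by the Seq2ColFunctor and Col2SeqFunctor added below, for a 1-D sequence with axis=-1; out[f, n] = seq[n * hop_length + f] mirrors the src_idx/trg_idx mapping documented in seq2col.h, and overlapping samples are summed on reconstruction, as in Col2SeqFunctor.

import numpy as np

def frame(seq, frame_length, hop_length):
    """Slice a 1-D sequence into overlapping frames of shape (frame_length, n_frames)."""
    seq_length = seq.shape[0]
    n_frames = 1 + (seq_length - frame_length) // hop_length
    out = np.empty((frame_length, n_frames), dtype=seq.dtype)
    for n in range(n_frames):
        out[:, n] = seq[n * hop_length : n * hop_length + frame_length]
    return out

def overlap_add(frames, hop_length):
    """Reconstruct a 1-D sequence by summing overlapping frames (inverse layout of frame)."""
    frame_length, n_frames = frames.shape
    seq_length = (n_frames - 1) * hop_length + frame_length
    seq = np.zeros(seq_length, dtype=frames.dtype)
    for n in range(n_frames):
        seq[n * hop_length : n * hop_length + frame_length] += frames[:, n]
    return seq

x = np.arange(10.0)
f = frame(x, frame_length=4, hop_length=2)   # shape (4, 4): n_frames = 1 + (10 - 4) // 2
y = overlap_add(f, hop_length=2)             # length (4 - 1) * 2 + 4 = 10; overlaps are summed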
 CMakeLists.txt                                |    7 +
 cmake/FindGperftools.cmake                    |    2 +-
 cmake/external/pocketfft.cmake                |   44 +
 cmake/third_party.cmake                       |    6 +
 paddle/fluid/framework/data_type.h            |   18 +-
 paddle/fluid/operators/CMakeLists.txt         |   12 +-
 paddle/fluid/operators/concat_op.cc           |   13 +-
 paddle/fluid/operators/concat_op.cu.cc        |   13 +-
 paddle/fluid/operators/eigen/scale.cc         |    3 +
 paddle/fluid/operators/eigen/scale.cu         |    3 +
 paddle/fluid/operators/fill_zeros_like_op.cc  |   13 +-
 .../fluid/operators/fill_zeros_like_op.cu.cc  |   13 +-
 paddle/fluid/operators/flip_op.cc             |    6 +-
 paddle/fluid/operators/flip_op.cu             |    6 +-
 paddle/fluid/operators/frame_op.cc            |  186 ++
 paddle/fluid/operators/frame_op.cu            |   41 +
 paddle/fluid/operators/frame_op.h             |  341 ++++
 paddle/fluid/operators/math/seq2col.h         |  186 ++
 paddle/fluid/operators/overlap_add_op.cc      |  188 ++
 paddle/fluid/operators/overlap_add_op.cu      |   43 +
 paddle/fluid/operators/overlap_add_op.h       |  304 ++++
 paddle/fluid/operators/pad_op.cc              |   25 +-
 paddle/fluid/operators/roll_op.cc             |   13 +-
 paddle/fluid/operators/roll_op.cu             |   13 +-
 paddle/fluid/operators/shape_op.cc            |    6 +-
 paddle/fluid/operators/shape_op.cu            |    5 +-
 paddle/fluid/operators/spectral_op.cc         |  870 +++++++++
 paddle/fluid/operators/spectral_op.cu         |  643 +++++++
 paddle/fluid/operators/spectral_op.h          |  461 +++++
 paddle/fluid/operators/squeeze_op.cc          |   26 +-
 paddle/fluid/operators/squeeze_op.cu.cc       |   24 +-
 paddle/fluid/operators/unsqueeze_op.cc        |   24 +-
 paddle/fluid/operators/unsqueeze_op.cu.cc     |   24 +-
 paddle/fluid/platform/complex.h               |    2 +
 paddle/fluid/platform/dynload/CMakeLists.txt  |    2 +-
 paddle/fluid/platform/dynload/cufft.cc        |   44 +
 paddle/fluid/platform/dynload/cufft.h         |  113 ++
 .../fluid/platform/dynload/dynamic_loader.cc  |   17 +
 .../fluid/platform/dynload/dynamic_loader.h   |    1 +
 python/paddle/__init__.py                     |    1 +
 python/paddle/fluid/layers/nn.py              |    6 +-
 .../fluid/tests/unittests/CMakeLists.txt      |    1 +
 .../fluid/tests/unittests/fft/CMakeLists.txt  |    6 +
 .../fluid/tests/unittests/fft/__init__.py     |   13 +
 .../tests/unittests/fft/spectral_op_np.py     |  108 ++
 .../fluid/tests/unittests/fft/test_fft.py     |  960 ++++++++++
 .../fft/test_fft_with_static_graph.py         |  894 +++++++++
 .../tests/unittests/fft/test_spectral_op.py   |  178 ++
 .../fluid/tests/unittests/test_frame_op.py    |  140 ++
 .../tests/unittests/test_overlap_add_op.py    |  157 ++
 .../fluid/tests/unittests/test_signal.py      | 1005 ++++++++++
 python/paddle/tensor/__init__.py              |    2 +
 python/paddle/tensor/attribute.py             |   35 +
 python/paddle/tensor/fft.py                   | 1609 +++++++++++++++++
 python/paddle/tensor/manipulation.py          |    2 +-
 python/paddle/tensor/signal.py                |  576 ++++++
 python/unittest_py/requirements.txt           |    3 +-
 57 files changed, 9413 insertions(+), 44 deletions(-)
 create mode 100644 cmake/external/pocketfft.cmake
 create mode 100644 paddle/fluid/operators/frame_op.cc
 create mode 100644 paddle/fluid/operators/frame_op.cu
 create mode 100644 paddle/fluid/operators/frame_op.h
 create mode 100644 paddle/fluid/operators/math/seq2col.h
 create mode 100644 paddle/fluid/operators/overlap_add_op.cc
 create mode 100644 paddle/fluid/operators/overlap_add_op.cu
 create mode 100644 paddle/fluid/operators/overlap_add_op.h
 create mode 100644 paddle/fluid/operators/spectral_op.cc
 create mode 100644 paddle/fluid/operators/spectral_op.cu
 create mode 100644 paddle/fluid/operators/spectral_op.h
 mode change 100755 => 100644
paddle/fluid/operators/squeeze_op.cu.cc mode change 100755 => 100644 paddle/fluid/operators/unsqueeze_op.cc mode change 100755 => 100644 paddle/fluid/operators/unsqueeze_op.cu.cc create mode 100644 paddle/fluid/platform/dynload/cufft.cc create mode 100644 paddle/fluid/platform/dynload/cufft.h create mode 100644 python/paddle/fluid/tests/unittests/fft/CMakeLists.txt create mode 100644 python/paddle/fluid/tests/unittests/fft/__init__.py create mode 100644 python/paddle/fluid/tests/unittests/fft/spectral_op_np.py create mode 100644 python/paddle/fluid/tests/unittests/fft/test_fft.py create mode 100644 python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py create mode 100644 python/paddle/fluid/tests/unittests/fft/test_spectral_op.py create mode 100644 python/paddle/fluid/tests/unittests/test_frame_op.py create mode 100644 python/paddle/fluid/tests/unittests/test_overlap_add_op.py create mode 100644 python/paddle/fluid/tests/unittests/test_signal.py create mode 100644 python/paddle/tensor/fft.py create mode 100644 python/paddle/tensor/signal.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 219f6fe20ba..98772e96781 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,6 +38,8 @@ project(paddle CXX C) # enable language CUDA # TODO(Shibo Tao): remove find_package(CUDA) completely. find_package(CUDA QUIET) +find_package(MKL CONFIG QUIET) +option(WITH_ONEMKL "Compile PaddlePaddle with oneMKL" OFF) option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND}) option(WITH_TENSORRT "Compile PaddlePaddle with NVIDIA TensorRT" OFF) option(WITH_XPU "Compile PaddlePaddle with BAIDU KUNLUN XPU" OFF) @@ -225,6 +227,7 @@ option(WITH_STRIP "Strip so files of Whl packages" OFF) option(NEW_RELEASE_CUBIN "PaddlePaddle next-level release strategy for pypi cubin package" OFF) option(NEW_RELEASE_JIT "PaddlePaddle next-level release strategy for backup jit package" OFF) option(WITH_ASCEND_INT64 "Compile with int64 kernel for ascend NPU" OFF) +option(WITH_POCKETFFT "Compile with pocketfft support" ON) # PY_VERSION if(NOT PY_VERSION) @@ -373,6 +376,10 @@ if (WITH_MIPS) add_definitions(-DPADDLE_WITH_MIPS) endif() +if (WITH_ONEMKL) + add_definitions(-DPADDLE_WITH_ONEMKL) +endif() + if (WITH_HETERPS) if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -faligned-new") diff --git a/cmake/FindGperftools.cmake b/cmake/FindGperftools.cmake index 928f573a4fb..318f9f5fd3b 100644 --- a/cmake/FindGperftools.cmake +++ b/cmake/FindGperftools.cmake @@ -20,7 +20,7 @@ find_library(GPERFTOOLS_TCMALLOC NAMES tcmalloc HINTS ${Gperftools_ROOT_DIR}/lib) - + find_library(GPERFTOOLS_PROFILER NAMES profiler HINTS ${Gperftools_ROOT_DIR}/lib) diff --git a/cmake/external/pocketfft.cmake b/cmake/external/pocketfft.cmake new file mode 100644 index 00000000000..7323f67d115 --- /dev/null +++ b/cmake/external/pocketfft.cmake @@ -0,0 +1,44 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +include(ExternalProject) + + +set(POCKETFFT_PATH "${THIRD_PARTY_PATH}/pocketfft" CACHE STRING "A path setting for external_pocketfft path.") +set(POCKETFFT_PREFIX_DIR ${POCKETFFT_PATH}) + +set(POCKETFFT_REPOSITORY https://gitlab.mpcdf.mpg.de/mtr/pocketfft.git) +set(POCKETFFT_TAG release_for_eigen) + +SET(POCKETFFT_INCLUDE_DIR ${POCKETFFT_PREFIX_DIR}/src) +message("POCKETFFT_INCLUDE_DIR is ${POCKETFFT_INCLUDE_DIR}") +include_directories(${POCKETFFT_INCLUDE_DIR}) + +ExternalProject_Add( + extern_pocketfft + ${EXTERNAL_PROJECT_LOG_ARGS} + ${SHALLOW_CLONE} + GIT_REPOSITORY ${POCKETFFT_REPOSITORY} + GIT_TAG ${POCKETFFT_TAG} + PREFIX ${POCKETFFT_PREFIX_DIR} + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) + +add_library(pocketfft INTERFACE) + +add_dependencies(pocketfft extern_pocketfft) diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index aa31745c213..6487b5062c4 100644 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -361,4 +361,10 @@ if (WITH_CRYPTO) add_definitions(-DPADDLE_WITH_CRYPTO) endif (WITH_CRYPTO) +if (WITH_POCKETFFT) + include(external/pocketfft) + list(APPEND third_party_deps extern_pocketfft) + add_definitions(-DPADDLE_WITH_POCKETFFT) +endif (WITH_POCKETFFT) + add_custom_target(third_party ALL DEPENDS ${third_party_deps}) diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h index 72ee126e13c..08749b6b751 100644 --- a/paddle/fluid/framework/data_type.h +++ b/paddle/fluid/framework/data_type.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include #include @@ -170,11 +171,26 @@ extern inline proto::VarType::Type ToComplexType(proto::VarType::Type t) { return proto::VarType::COMPLEX128; default: PADDLE_THROW(platform::errors::Unimplemented( - "Unknown complex value data type (%s), now only support float32 and " + "Unknown real value data type (%s), now only support float32 and " "float64.", DataTypeToString(t))); } } +extern inline proto::VarType::Type ToRealType(proto::VarType::Type t) { + switch (t) { + case proto::VarType::COMPLEX64: + return proto::VarType::FP32; + case proto::VarType::COMPLEX128: + return proto::VarType::FP64; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unknown complex value data type (%s), now only support complex64 " + "and " + "complex128.", + DataTypeToString(t))); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 4a82f558ff4..b9025560f6d 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -59,6 +59,10 @@ if (WITH_GPU) endif() endif() +if (WITH_POCKETFFT) + SET(OP_HEADER_DEPS ${OP_HEADER_DEPS} pocketfft) +endif() + SET(OP_MKL_DEPS "") if (NOT WITH_MKL OR NOT WITH_AVX) @@ -75,7 +79,7 @@ if(WITH_UNITY_BUILD) endif() register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op lstm_op run_program_op eye_op recurrent_op - sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS}) + sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS}) op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS}) @@ -94,6 +98,12 @@ else() op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) endif() +if (WITH_GPU AND (NOT WITH_ROCM)) + op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda 
${OP_HEADER_DEPS}) +else() + op_library(spectral_op SRCS spectral_op.cc DEPS ${OP_HEADER_DEPS}) +endif() + op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute) op_library(eye_op DEPS ${OP_HEADER_DEPS}) op_library(recurrent_op DEPS ${OP_HEADER_DEPS}) diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index 4783aa3a86f..a400d27b798 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/concat_op.h" +#include #include #include #include @@ -237,7 +238,11 @@ REGISTER_OP_CPU_KERNEL( ops::ConcatKernel, ops::ConcatKernel, - ops::ConcatKernel); + ops::ConcatKernel, + ops::ConcatKernel>, + ops::ConcatKernel>); REGISTER_OP_CPU_KERNEL( concat_grad, ops::ConcatGradKernel, @@ -247,4 +252,8 @@ REGISTER_OP_CPU_KERNEL( ops::ConcatGradKernel, ops::ConcatGradKernel, - ops::ConcatGradKernel); + ops::ConcatGradKernel, + ops::ConcatGradKernel>, + ops::ConcatGradKernel>); diff --git a/paddle/fluid/operators/concat_op.cu.cc b/paddle/fluid/operators/concat_op.cu.cc index 63025c3bd03..2be76329857 100644 --- a/paddle/fluid/operators/concat_op.cu.cc +++ b/paddle/fluid/operators/concat_op.cu.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/concat_op.h" +#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; @@ -24,7 +25,11 @@ REGISTER_OP_CUDA_KERNEL( ops::ConcatKernel, ops::ConcatKernel, ops::ConcatKernel, - ops::ConcatKernel); + ops::ConcatKernel, + ops::ConcatKernel>, + ops::ConcatKernel>); REGISTER_OP_CUDA_KERNEL( concat_grad, ops::ConcatGradKernel, @@ -33,4 +38,8 @@ REGISTER_OP_CUDA_KERNEL( ops::ConcatGradKernel, ops::ConcatGradKernel, ops::ConcatGradKernel, - ops::ConcatGradKernel); + ops::ConcatGradKernel, + ops::ConcatGradKernel>, + ops::ConcatGradKernel>); diff --git a/paddle/fluid/operators/eigen/scale.cc b/paddle/fluid/operators/eigen/scale.cc index e85878f20aa..d9fbb878e35 100644 --- a/paddle/fluid/operators/eigen/scale.cc +++ b/paddle/fluid/operators/eigen/scale.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/eigen/eigen_function.h" #include "paddle/fluid/platform/bfloat16.h" +#include "paddle/fluid/platform/complex.h" namespace paddle { namespace operators { @@ -42,6 +43,8 @@ template struct EigenScale; template struct EigenScale; template struct EigenScale; template struct EigenScale; +template struct EigenScale>; +template struct EigenScale>; } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/eigen/scale.cu b/paddle/fluid/operators/eigen/scale.cu index 6a77f72f620..5e485799af5 100644 --- a/paddle/fluid/operators/eigen/scale.cu +++ b/paddle/fluid/operators/eigen/scale.cu @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/operators/eigen/eigen_function.h" +#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" namespace paddle { @@ -41,6 +42,8 @@ template struct EigenScale; template struct EigenScale; template struct EigenScale; template struct EigenScale; +template struct EigenScale>; +template struct EigenScale>; } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/fill_zeros_like_op.cc b/paddle/fluid/operators/fill_zeros_like_op.cc index c727c657ed7..2d340829332 100644 --- a/paddle/fluid/operators/fill_zeros_like_op.cc +++ b/paddle/fluid/operators/fill_zeros_like_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/fill_zeros_like_op.h" +#include "paddle/fluid/platform/complex.h" namespace paddle { namespace operators { @@ -93,7 +94,11 @@ REGISTER_OP_CPU_KERNEL( ops::FillZerosLikeKernel, ops::FillZerosLikeKernel, ops::FillZerosLikeKernel, - ops::FillZerosLikeKernel); + ops::FillZerosLikeKernel, + ops::FillZerosLikeKernel>, + ops::FillZerosLikeKernel>); REGISTER_OP_CPU_KERNEL( fill_zeros_like2, @@ -101,4 +106,8 @@ REGISTER_OP_CPU_KERNEL( ops::FillZerosLikeKernel, ops::FillZerosLikeKernel, ops::FillZerosLikeKernel, - ops::FillZerosLikeKernel); + ops::FillZerosLikeKernel, + ops::FillZerosLikeKernel>, + ops::FillZerosLikeKernel>); diff --git a/paddle/fluid/operators/fill_zeros_like_op.cu.cc b/paddle/fluid/operators/fill_zeros_like_op.cu.cc index 1831635def7..4cb0887c1f3 100644 --- a/paddle/fluid/operators/fill_zeros_like_op.cu.cc +++ b/paddle/fluid/operators/fill_zeros_like_op.cu.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/fill_zeros_like_op.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; @@ -25,7 +26,11 @@ REGISTER_OP_CUDA_KERNEL( ops::FillZerosLikeKernel, ops::FillZerosLikeKernel, - ops::FillZerosLikeKernel); + ops::FillZerosLikeKernel, + ops::FillZerosLikeKernel>, + ops::FillZerosLikeKernel>); REGISTER_OP_CUDA_KERNEL( fill_zeros_like2, @@ -35,4 +40,8 @@ REGISTER_OP_CUDA_KERNEL( ops::FillZerosLikeKernel, ops::FillZerosLikeKernel, - ops::FillZerosLikeKernel); + ops::FillZerosLikeKernel, + ops::FillZerosLikeKernel>, + ops::FillZerosLikeKernel>); diff --git a/paddle/fluid/operators/flip_op.cc b/paddle/fluid/operators/flip_op.cc index d062243acf3..5e6d263f190 100644 --- a/paddle/fluid/operators/flip_op.cc +++ b/paddle/fluid/operators/flip_op.cc @@ -17,6 +17,7 @@ limitations under the License. 
*/ #include #include #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/complex.h" namespace paddle { namespace operators { @@ -145,6 +146,7 @@ class FlipOpGradMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OPERATOR(flip, ops::FlipOp, ops::FlipOpMaker, ops::FlipOpInferVarType, ops::FlipOpGradMaker, ops::FlipOpGradMaker); @@ -153,7 +155,9 @@ REGISTER_OP_CPU_KERNEL( ops::FlipKernel, ops::FlipKernel, ops::FlipKernel, - ops::FlipKernel); + ops::FlipKernel, + ops::FlipKernel>, + ops::FlipKernel>); /* ========================== register checkpoint ===========================*/ REGISTER_OP_VERSION(flip) diff --git a/paddle/fluid/operators/flip_op.cu b/paddle/fluid/operators/flip_op.cu index 581a994ba84..26b3d11bc6c 100644 --- a/paddle/fluid/operators/flip_op.cu +++ b/paddle/fluid/operators/flip_op.cu @@ -16,6 +16,7 @@ limitations under the License. */ #include #include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/platform/complex.h" namespace paddle { namespace operators { @@ -163,4 +164,7 @@ REGISTER_OP_CUDA_KERNEL( ops::FlipKernel, ops::FlipKernel, ops::FlipKernel, - ops::FlipKernel); + ops::FlipKernel, + ops::FlipKernel>, + ops::FlipKernel>); diff --git a/paddle/fluid/operators/frame_op.cc b/paddle/fluid/operators/frame_op.cc new file mode 100644 index 00000000000..7568941e980 --- /dev/null +++ b/paddle/fluid/operators/frame_op.cc @@ -0,0 +1,186 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/operators/frame_op.h" + +namespace paddle { +namespace operators { + +class FrameOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "frame"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "frame"); + + const int frame_length = ctx->Attrs().Get("frame_length"); + const int hop_length = ctx->Attrs().Get("hop_length"); + const int axis = ctx->Attrs().Get("axis"); + + const auto x_dims = ctx->GetInputDim("X"); + const int x_rank = x_dims.size(); + + PADDLE_ENFORCE_GE( + x_rank, 1, platform::errors::InvalidArgument( + "Input(X) of FrameOp should be a tensor which contains " + "at least 1 dimension, but got rank %s.", + x_rank)); + PADDLE_ENFORCE_GT(hop_length, 0, + platform::errors::InvalidArgument( + "Attribute(hop_length) of FrameOp should be greater " + "than 0, but got %s.", + hop_length)); + PADDLE_ENFORCE_EQ( + (axis == 0 || axis == -1), true, + platform::errors::InvalidArgument( + "Attribute(axis) of FrameOp should 0 or -1, but got %s.", axis)); + + std::vector output_shape; + int seq_length; + int n_frames; + + int start_axis; + int end_axis; + + if (axis == 0) { + seq_length = x_dims[0]; + start_axis = 1; + end_axis = x_rank - 1; + } else { + seq_length = x_dims[x_rank - 1]; + start_axis = 0; + end_axis = x_rank - 2; + } + + PADDLE_ENFORCE_LE(frame_length, seq_length, + platform::errors::InvalidArgument( + "Attribute(frame_length) of FrameOp should be less " + "equal than sequence length, but got (%s) > (%s).", + frame_length, seq_length)); + + // It won't go into for loop when x_rank == 1U. + for (int i = start_axis; i <= end_axis; i++) { + output_shape.push_back(x_dims[i]); + } + + n_frames = 1 + (seq_length - frame_length) / hop_length; + + if (axis == 0) { + // (n_frames, frame_length, ...) + output_shape.insert(output_shape.begin(), frame_length); + output_shape.insert(output_shape.begin(), n_frames); + } else { + // (..., frame_length, n_frames) + output_shape.push_back(frame_length); + output_shape.push_back(n_frames); + } + + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(in_dtype, ctx.GetPlace()); + } +}; + +class FrameOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of frame op."); + AddOutput("Out", "(Tensor), The output tensor of frame op."); + AddAttr( + "frame_length", + "Length of the frame and `0 < frame_length <= x.shape[axis]`."); + AddAttr("hop_length", + "Number of steps to advance between adjacent frames and " + "`0 < hop_length`."); + AddAttr("axis", + "Specify the axis to operate on the input Tensors. Its value " + "should be 0(the first dimension) or -1(the last dimension).") + .SetDefault(-1); + AddComment(R"DOC( + Slice the N-dimensional (where N >= 1) input into (overlapping) frames. 
+ )DOC"); + } +}; + +class FrameOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "frame_grad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@GRAD", "frame_grad"); + const auto x_dims = ctx->GetInputDim("X"); + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(in_dtype, ctx.GetPlace()); + } +}; + +template +class FrameOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr retv) const override { + retv->SetType("frame_grad"); + retv->SetInput("X", this->Input("X")); + retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + retv->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(frame, ops::FrameOp, ops::FrameOpMaker, + ops::FrameOpGradMaker, + ops::FrameOpGradMaker); + +REGISTER_OPERATOR(frame_grad, ops::FrameOpGrad); + +REGISTER_OP_CPU_KERNEL( + frame, ops::FrameKernel, + ops::FrameKernel, + ops::FrameKernel, + ops::FrameKernel, + ops::FrameKernel>, + ops::FrameKernel>); + +REGISTER_OP_CPU_KERNEL( + frame_grad, ops::FrameGradKernel, + ops::FrameGradKernel, + ops::FrameGradKernel, + ops::FrameGradKernel, + ops::FrameGradKernel>, + ops::FrameGradKernel>); diff --git a/paddle/fluid/operators/frame_op.cu b/paddle/fluid/operators/frame_op.cu new file mode 100644 index 00000000000..797e0aa0111 --- /dev/null +++ b/paddle/fluid/operators/frame_op.cu @@ -0,0 +1,41 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/frame_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + frame, ops::FrameKernel, + ops::FrameKernel, + ops::FrameKernel, + ops::FrameKernel, + ops::FrameKernel, + ops::FrameKernel>, + ops::FrameKernel>); + +REGISTER_OP_CUDA_KERNEL( + frame_grad, ops::FrameGradKernel, + ops::FrameGradKernel, + ops::FrameGradKernel, + ops::FrameGradKernel, + ops::FrameGradKernel, + ops::FrameGradKernel>, + ops::FrameGradKernel>); diff --git a/paddle/fluid/operators/frame_op.h b/paddle/fluid/operators/frame_op.h new file mode 100644 index 00000000000..482c6411812 --- /dev/null +++ b/paddle/fluid/operators/frame_op.h @@ -0,0 +1,341 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/seq2col.h" +#include "paddle/fluid/operators/transpose_op.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +template +struct FrameFunctor { + void operator()(const DeviceContext& dev_ctx, const Tensor* input, + Tensor* output, size_t seq_length, size_t frame_length, + size_t n_frames, size_t hop_length, + bool is_grad = false) const { + auto numel = output->numel(); + const auto* input_data = input->data(); + auto* output_data = output->data(); + + platform::ForRange for_range(dev_ctx, numel); + if (!is_grad) { + math::Seq2ColFunctor functor(input_data, output_data, seq_length, + frame_length, n_frames, hop_length); + for_range(functor); + } else { + math::Col2SeqFunctor functor(input_data, output_data, seq_length, + frame_length, n_frames, hop_length); + for_range(functor); + } + } +}; + +template +class FrameKernel : public framework::OpKernel { + public: + /* + Frame kernel slices frames from input sequences. The main steps as follows: + + - Case 1 - input dims == 1: + - axis is -1: Call a FrameFunctor to compute directly. + - axis is 0: Transpose output firstly, and then it falls into + case axis is -1. Finally, it restores the dims of + output tensor. + + - Case 2 - input dims == 2: + - axis is -1: Call a FrameFunctor to compute directly. + - axis is 0: Transpose both input and output firstly, and then it falls + into case axis is -1. Finally, it restores the dims of + output tensor. + + - Case 3 - input dims > 2: + Flatten the input and output to 2D and 3D respectively so that it + falls into Case 2. Finally, it restores the dims of output tensor. + */ + void Compute(const framework::ExecutionContext& ctx) const override { + const Tensor* x = ctx.Input("X"); + Tensor* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + const size_t x_rank = x->dims().size(); + const size_t out_rank = out->dims().size(); + + const int frame_length = ctx.Attr("frame_length"); + const int hop_length = ctx.Attr("hop_length"); + const int axis = ctx.Attr("axis"); + const int n_frames = + (axis == 0) ? out->dims()[0] : out->dims()[out_rank - 1]; + const int seq_length = (axis == 0) ? x->dims()[0] : x->dims()[x_rank - 1]; + + auto& dev_ctx = ctx.device_context(); + + // When the number of input dims is larger than 2, it needs to copy + // from x to resize input into 2d and output into 3d. Morevoer, output + // dims will be restored at the last step. 
+ Tensor x_(x->type()); + x_ = *x; + + framework::DDim preserved_dims; + if (x_rank > 2) { + // Save dims used to flatten both input and output tensors and restore + // output tensor. + framework::DDim x_resized_dims; + framework::DDim out_resized_dims; + if (axis == 0) { + preserved_dims = framework::slice_ddim(x_.dims(), 1, x_rank); + x_resized_dims = {seq_length, framework::product(preserved_dims)}; + out_resized_dims = {n_frames, frame_length, + framework::product(preserved_dims)}; + } else { + preserved_dims = framework::slice_ddim(x_.dims(), 0, x_rank - 1); + x_resized_dims = {framework::product(preserved_dims), seq_length}; + out_resized_dims = {framework::product(preserved_dims), frame_length, + n_frames}; + } + x_.Resize(x_resized_dims); + out->Resize(out_resized_dims); + } + + Tensor trans_x(x_.type()); + Tensor trans_out(out->type()); + + // Transpose input and output in case that axis is 0. + if (axis == 0) { + if (x_rank == 1U) { + trans_x = x_; + + std::vector perm_out{1, 0}; + auto out_dims_vec = framework::vectorize(out->dims()); + for (int i = 0; i < out->dims().size(); ++i) { + out_dims_vec[i] = out->dims()[perm_out[i]]; + } + trans_out.Resize(framework::make_ddim(out_dims_vec)); + trans_out.mutable_data(ctx.GetPlace()); + TransCompute(perm_out.size(), dev_ctx, *out, + &trans_out, perm_out); + } else { + std::vector perm_x{1, 0}; + auto x_dims_vec = framework::vectorize(x_.dims()); + for (int i = 0; i < x_.dims().size(); ++i) { + x_dims_vec[i] = x_.dims()[perm_x[i]]; + } + trans_x.Resize(framework::make_ddim(x_dims_vec)); + trans_x.mutable_data(ctx.GetPlace()); + TransCompute(perm_x.size(), dev_ctx, x_, &trans_x, + perm_x); + + std::vector perm_out{2, 1, 0}; + auto out_dims_vec = framework::vectorize(out->dims()); + for (int i = 0; i < out->dims().size(); ++i) { + out_dims_vec[i] = out->dims()[perm_out[i]]; + } + trans_out.Resize(framework::make_ddim(out_dims_vec)); + trans_out.mutable_data(ctx.GetPlace()); + TransCompute(perm_out.size(), dev_ctx, *out, + &trans_out, perm_out); + } + } else { + trans_x = x_; + trans_out = *out; + } + + FrameFunctor()(dev_ctx, &trans_x, &trans_out, seq_length, + frame_length, n_frames, hop_length, + /*is_grad*/ false); + + // Transpose output in case axis is 0. + if (axis == 0) { + if (x_rank == 1U) { + std::vector perm_out{1, 0}; + TransCompute(perm_out.size(), dev_ctx, trans_out, out, + perm_out); + } else { + std::vector perm_out{2, 1, 0}; + TransCompute(perm_out.size(), dev_ctx, trans_out, out, + perm_out); + } + } + + // Restore output dims when the number of dims is larger than 2. + if (x_rank > 2) { + std::vector restored_out_shape; + for (int i = 0; i < preserved_dims.size(); i++) { + restored_out_shape.push_back(preserved_dims[i]); + } + + if (axis == 0) { + // (n_frames, frame_length, ...) + restored_out_shape.insert(restored_out_shape.begin(), frame_length); + restored_out_shape.insert(restored_out_shape.begin(), n_frames); + } else { + // (..., frame_length, n_frames) + restored_out_shape.push_back(frame_length); + restored_out_shape.push_back(n_frames); + } + + out->Resize(framework::make_ddim(restored_out_shape)); + } + } +}; + +template +class FrameGradKernel : public framework::OpKernel { + public: + /* + Frame gradient kernel accumulate gradient `d_x` from `d_out`. The + main steps as follows: + + - Case 1 - d_x dims == 1: + - axis is -1: Call a FrameFunctor to compute directly. Notes that + `is_grad` is set to true to select gradient data functor. 
+ - axis is 0: Transpose `d_out` firstly, and then it falls into + case axis is -1. + + - Case 2 - d_x dims == 2: + - axis is -1: Call a FrameFunctor to compute directly. + - axis is 0: Transpose both `d_x` and `d_out` firstly, and then it + falls into case axis is -1. Finally, it restores the + dims of `d_x`. + + - Case 3 - d_x dims > 2: + Flatten the `d_x` and `d_out` to 2D and 3D respectively so that it + falls into Case 2. Finally, it restores the dims of `d_x` tensor. + */ + void Compute(const framework::ExecutionContext& ctx) const { + const Tensor* d_out = ctx.Input(framework::GradVarName("Out")); + Tensor* d_x = ctx.Output(framework::GradVarName("X")); + d_x->mutable_data(ctx.GetPlace()); + const size_t d_out_rank = d_out->dims().size(); + const size_t d_x_rank = d_x->dims().size(); + + const int frame_length = ctx.Attr("frame_length"); + const int hop_length = ctx.Attr("hop_length"); + const int axis = ctx.Attr("axis"); + const int n_frames = + (axis == 0) ? d_out->dims()[0] : d_out->dims()[d_out_rank - 1]; + const int seq_length = + (axis == 0) ? d_x->dims()[0] : d_x->dims()[d_x_rank - 1]; + + auto& dev_ctx = ctx.device_context(); + + Tensor d_out_(d_out->type()); + d_out_ = *d_out; + + framework::DDim preserved_dims; + if (d_x_rank > 2) { + // Save dims used to flatten both input and output tensors and restore + // output tensor. + framework::DDim d_x_resized_dims; + framework::DDim d_out_resized_dims; + if (axis == 0) { + preserved_dims = framework::slice_ddim(d_x->dims(), 1, d_x_rank); + d_x_resized_dims = {seq_length, framework::product(preserved_dims)}; + d_out_resized_dims = {n_frames, frame_length, + framework::product(preserved_dims)}; + } else { + preserved_dims = framework::slice_ddim(d_x->dims(), 0, d_x_rank - 1); + d_x_resized_dims = {framework::product(preserved_dims), seq_length}; + d_out_resized_dims = {framework::product(preserved_dims), frame_length, + n_frames}; + } + d_x->Resize(d_x_resized_dims); + d_out_.Resize(d_out_resized_dims); + } + + Tensor trans_d_x(d_x->type()); + Tensor trans_d_out(d_out_.type()); + + // Transpose input and output in case that axis is 0. 
+ if (axis == 0) { + if (d_x_rank == 1U) { + trans_d_x = *d_x; + + std::vector perm_d_out{1, 0}; + auto d_out_dims_vec = framework::vectorize(d_out_.dims()); + for (int i = 0; i < d_out_.dims().size(); ++i) { + d_out_dims_vec[i] = d_out_.dims()[perm_d_out[i]]; + } + trans_d_out.Resize(framework::make_ddim(d_out_dims_vec)); + trans_d_out.mutable_data(ctx.GetPlace()); + TransCompute(perm_d_out.size(), dev_ctx, d_out_, + &trans_d_out, perm_d_out); + } else { + std::vector perm_d_x{1, 0}; + auto d_x_dims_vec = framework::vectorize(d_x->dims()); + for (int i = 0; i < d_x->dims().size(); ++i) { + d_x_dims_vec[i] = d_x->dims()[perm_d_x[i]]; + } + trans_d_x.Resize(framework::make_ddim(d_x_dims_vec)); + trans_d_x.mutable_data(ctx.GetPlace()); + TransCompute(perm_d_x.size(), dev_ctx, *d_x, + &trans_d_x, perm_d_x); + + std::vector perm_d_out{2, 1, 0}; + auto d_out_dims_vec = framework::vectorize(d_out_.dims()); + for (int i = 0; i < d_out_.dims().size(); ++i) { + d_out_dims_vec[i] = d_out_.dims()[perm_d_out[i]]; + } + trans_d_out.Resize(framework::make_ddim(d_out_dims_vec)); + trans_d_out.mutable_data(ctx.GetPlace()); + TransCompute(perm_d_out.size(), dev_ctx, d_out_, + &trans_d_out, perm_d_out); + } + } else { + trans_d_x = *d_x; + trans_d_out = d_out_; + } + + FrameFunctor()(dev_ctx, &trans_d_out, &trans_d_x, + seq_length, frame_length, n_frames, + hop_length, + /*is_grad*/ true); + + // Transpose output in case axis is 0. + if (axis == 0 && d_x_rank > 1U) { + std::vector perm_d_x{1, 0}; + TransCompute(perm_d_x.size(), dev_ctx, trans_d_x, d_x, + perm_d_x); + } + + // Restore output dims when the number of dims is larger than 2. + if (d_x_rank > 2) { + std::vector restored_d_x_shape; + for (int i = 0; i < preserved_dims.size(); i++) { + restored_d_x_shape.push_back(preserved_dims[i]); + } + + if (axis == 0) { + // (seq_length, ...) + restored_d_x_shape.insert(restored_d_x_shape.begin(), seq_length); + } else { + // (..., seq_length) + restored_d_x_shape.push_back(seq_length); + } + + d_x->Resize(framework::make_ddim(restored_d_x_shape)); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/seq2col.h b/paddle/fluid/operators/math/seq2col.h new file mode 100644 index 00000000000..56134b6f0ea --- /dev/null +++ b/paddle/fluid/operators/math/seq2col.h @@ -0,0 +1,186 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace paddle { +namespace operators { +namespace math { + +template +struct Seq2ColFunctor { + Seq2ColFunctor(const T* seq, T* col, size_t seq_length, size_t frame_length, + size_t n_frames, size_t hop_length) + : seq_(seq), + col_(col), + seq_length_(seq_length), + frame_length_(frame_length), + n_frames_(n_frames), + hop_length_(hop_length) {} + + /* + Convert sequences to frames. + + 1. Dimension infomation: + + Sequences Frames + (N, seq_length) -> (N, frame_length, n_frames) + + 2. 
Mapping from `i` to `src_idx` and `trg_idx` can be derived from: + + a. Notion + - `i` stands for the flattened index of a bunch of frames. + - `src_idx` and `trg_idx` are the 1D indices of seqs and frames + respectivly. + + b. Sample idx + ```cpp + sample_idx = i / (n_frames_ * frame_length_); + ``` + + c. Maps `i` to `f` and `n`. + ```cpp + f = i % (n_frames_ * frame_length_) / n_frames_; + n = i % (n_frames_ * frame_length_) % n_frames_; + ``` + + d. Replace `sample_idx`, `f` and `n` in the following eqations: + ```cpp + src_idx = sample_idx * seq_length_ + n * hop_length_ + f; + trg_idx = sample_idx * n_frames_ * frame_length_ + f * n_frames_ + n; + col_[trg_idx] = seq_[src_idx]; + ``` + + e. Result can be deduced shown in the function body below. + */ + HOSTDEVICE void operator()(size_t i) const { + size_t src_idx; + size_t trg_idx; + src_idx = i / (n_frames_ * frame_length_) * seq_length_ + + i % (n_frames_ * frame_length_) % n_frames_ * hop_length_ + + i % (n_frames_ * frame_length_) / n_frames_; + trg_idx = i / (n_frames_ * frame_length_) * n_frames_ * frame_length_ + + i % (n_frames_ * frame_length_) / n_frames_ * n_frames_ + + i % (n_frames_ * frame_length_) % n_frames_; + col_[trg_idx] = seq_[src_idx]; + } + + const T* seq_; + T* col_; + size_t seq_length_; + size_t frame_length_; + size_t n_frames_; + size_t hop_length_; +}; + +template +struct Col2SeqFunctor { + Col2SeqFunctor(const T* col, T* seq, size_t seq_length, size_t frame_length, + size_t n_frames, size_t hop_length) + : col_(col), + seq_(seq), + seq_length_(seq_length), + frame_length_(frame_length), + n_frames_(n_frames), + hop_length_(hop_length) {} + + /* + Accumulate output gradient d_out to d_x. + + 1. Dimension infomation: + + d_out d_x + (N, frame_length, n_frames) -> (N, seq_length) + + 2. Using a sliding window to find source indices from `d_out` according to + `i`: + + a. Notion + - `i` stands for the flattened index of `d_x`. + - `seq_i` stands for a relative index of a `d_x` sample. + - `left`: Starting index of a frame window. + - `right`: Ending index of a frame window. + + b. Sample idx + ```cpp + sample_idx = i / seq_length_; + ``` + + c. Slides a window with length of `frame_length` to find `f` and `n`. + - `n`: The idx of num_frames_, increases in each hop. + - `f`: The idx of frame_lengths_, relative idx from left of a sliding + window. + + d. Accumulate all grads from d_out. + ```cpp + seq_[i] += + col_[sample_idx * frame_length_ * n_frames_ + f * n_frames_ + n]; + ``` + */ + HOSTDEVICE void operator()(size_t i) const { + size_t sample_idx = i / seq_length_; + size_t seq_i = i % seq_length_; + + // Sliding window + seq_[i] = 0; // Init seq_[i] to 0, and sums up all + // grads from col_ in the while loop. + + size_t n = get_start_frame_idx(seq_i); + size_t f; + size_t left = n * hop_length_; + size_t right = left + frame_length_ - 1; + + while (left <= seq_i && right < seq_length_) { + f = seq_i - left; + seq_[i] += + col_[sample_idx * frame_length_ * n_frames_ + f * n_frames_ + n]; + // Next frame. 
+ left += hop_length_; + right += hop_length_; + n += 1; + } + } + + /* + Calculate minimum value of frame index `n` to satisfy the inequality: + + seq_i <= right + ==> seq_i <= left + frame_length - 1 + ==> seq_i <= hop_length_ * n + frame_length_ - 1 + */ + HOSTDEVICE size_t get_start_frame_idx(size_t seq_i) const { + int64_t tmp = seq_i + 1 - frame_length_; + if (tmp > 0) { + size_t n = tmp / hop_length_; + if (tmp % hop_length_ == 0) { + return n; + } else { + return n + 1; + } + } else { + return 0; + } + } + + const T* col_; + T* seq_; + size_t seq_length_; + size_t frame_length_; + size_t n_frames_; + size_t hop_length_; +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/overlap_add_op.cc b/paddle/fluid/operators/overlap_add_op.cc new file mode 100644 index 00000000000..627c613e297 --- /dev/null +++ b/paddle/fluid/operators/overlap_add_op.cc @@ -0,0 +1,188 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/overlap_add_op.h" + +namespace paddle { +namespace operators { + +class OverlapAddOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "overlap_add"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "overlap_add"); + + const int hop_length = ctx->Attrs().Get("hop_length"); + const int axis = ctx->Attrs().Get("axis"); + + const auto x_dims = ctx->GetInputDim("X"); + const int x_rank = x_dims.size(); + + PADDLE_ENFORCE_GE( + x_rank, 2, + platform::errors::InvalidArgument( + "Input(X) of OverlapAddOp should be a tensor which contains " + "at least 2 dimensions, but got rank %s.", + x_rank)); + + PADDLE_ENFORCE_GT( + hop_length, 0, + platform::errors::InvalidArgument( + "Attribute(hop_length) of OverlapAddOp should be greater " + "than 0, but got %s.", + hop_length)); + + PADDLE_ENFORCE_EQ( + (axis == 0 || axis == -1), true, + platform::errors::InvalidArgument( + "Attribute(axis) of OverlapAddOp should 0 or -1, but got %s.", + axis)); + + std::vector output_shape; + int n_frames; + int frame_length; + + int start_axis; + int end_axis; + if (axis == 0) { + n_frames = x_dims[0]; + frame_length = x_dims[1]; + start_axis = 2; + end_axis = x_rank - 1; + } else { + n_frames = x_dims[x_rank - 1]; + frame_length = x_dims[x_rank - 2]; + start_axis = 0; + end_axis = x_rank - 3; + } + + PADDLE_ENFORCE_LE( + hop_length, frame_length, + platform::errors::InvalidArgument( + "Attribute(hop_length) of OverlapAddOp should be less or equal " + "than frame_length, but got hop_length(%s) > frame_length(%s).", + hop_length, frame_length)); + + const int seq_length = (n_frames - 1) * hop_length + frame_length; + + // It won't go into for loop when x_rank == 2U. 
+ for (int i = start_axis; i <= end_axis; i++) { + output_shape.push_back(x_dims[i]); + } + + if (axis == 0) { + // (seq_length, ...) + output_shape.insert(output_shape.begin(), seq_length); + } else { + // (..., seq_length) + output_shape.push_back(seq_length); + } + + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(in_dtype, ctx.GetPlace()); + } +}; + +class OverlapAddOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of overlap_add op."); + AddOutput("Out", "(Tensor), The output tensor of overlap_add op."); + AddAttr("hop_length", + "Number of steps to advance between adjacent frames and " + "`0 < hop_length <= frame_length`."); + AddAttr("axis", + "Specify the axis to operate on the input Tensors. Its value " + "should be 0(the first dimension) or -1(the last dimension).") + .SetDefault(-1); + AddComment(R"DOC( + Reconstructs a tensor consisted of overlap added sequences from input frames. + )DOC"); + } +}; + +class OverlapAddOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "overlap_add_grad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@GRAD", "overlap_add_grad"); + const auto x_dims = ctx->GetInputDim("X"); + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(in_dtype, ctx.GetPlace()); + } +}; + +template +class OverlapAddOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr retv) const override { + retv->SetType("overlap_add_grad"); + retv->SetInput("X", this->Input("X")); + retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + retv->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(overlap_add, ops::OverlapAddOp, ops::OverlapAddOpMaker, + ops::OverlapAddOpGradMaker, + ops::OverlapAddOpGradMaker); + +REGISTER_OPERATOR(overlap_add_grad, ops::OverlapAddOpGrad); + +REGISTER_OP_CPU_KERNEL( + overlap_add, ops::OverlapAddKernel, + ops::OverlapAddKernel, + ops::OverlapAddKernel, + ops::OverlapAddKernel, + ops::OverlapAddKernel>, + ops::OverlapAddKernel>); + +REGISTER_OP_CPU_KERNEL( + overlap_add_grad, + ops::OverlapAddGradKernel, + ops::OverlapAddGradKernel, + ops::OverlapAddGradKernel, + ops::OverlapAddGradKernel, + ops::OverlapAddGradKernel>, + ops::OverlapAddGradKernel>); diff --git a/paddle/fluid/operators/overlap_add_op.cu b/paddle/fluid/operators/overlap_add_op.cu new file mode 100644 index 00000000000..2b7935e0191 --- /dev/null +++ b/paddle/fluid/operators/overlap_add_op.cu @@ -0,0 +1,43 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/overlap_add_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + overlap_add, + ops::OverlapAddKernel, + ops::OverlapAddKernel, + ops::OverlapAddKernel, + ops::OverlapAddKernel, + ops::OverlapAddKernel, + ops::OverlapAddKernel>, + ops::OverlapAddKernel>); + +REGISTER_OP_CUDA_KERNEL( + overlap_add_grad, + ops::OverlapAddGradKernel, + ops::OverlapAddGradKernel, + ops::OverlapAddGradKernel, + ops::OverlapAddGradKernel, + ops::OverlapAddGradKernel, + ops::OverlapAddGradKernel>, + ops::OverlapAddGradKernel>); diff --git a/paddle/fluid/operators/overlap_add_op.h b/paddle/fluid/operators/overlap_add_op.h new file mode 100644 index 00000000000..865659ee942 --- /dev/null +++ b/paddle/fluid/operators/overlap_add_op.h @@ -0,0 +1,304 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/seq2col.h" +#include "paddle/fluid/operators/transpose_op.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +template +struct OverlapAddFunctor { + void operator()(const DeviceContext& dev_ctx, const Tensor* input, + Tensor* output, size_t seq_length, size_t frame_length, + size_t n_frames, size_t hop_length, + bool is_grad = false) const { + auto numel = output->numel(); + const auto* input_data = input->data(); + auto* output_data = output->data(); + + platform::ForRange for_range(dev_ctx, numel); + if (!is_grad) { + math::Col2SeqFunctor functor(input_data, output_data, seq_length, + frame_length, n_frames, hop_length); + for_range(functor); + } else { + math::Seq2ColFunctor functor(input_data, output_data, seq_length, + frame_length, n_frames, hop_length); + for_range(functor); + } + } +}; + +template +class OverlapAddKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + const Tensor* x = ctx.Input("X"); + Tensor* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + const size_t x_rank = x->dims().size(); + const size_t out_rank = out->dims().size(); + + const int hop_length = ctx.Attr("hop_length"); + const int axis = ctx.Attr("axis"); + const int n_frames = (axis == 0) ? x->dims()[0] : x->dims()[x_rank - 1]; + const int frame_length = (axis == 0) ? x->dims()[1] : x->dims()[x_rank - 2]; + const int seq_length = + (axis == 0) ? out->dims()[0] : out->dims()[out_rank - 1]; + + auto& dev_ctx = ctx.device_context(); + + Tensor x_(x->type()); + x_ = *x; + + framework::DDim preserved_dims; + if (out_rank > 2) { + // Save dims used to flatten both input and output tensors and restore + // output tensor. + framework::DDim x_resized_dims; + framework::DDim out_resized_dims; + if (axis == 0) { + preserved_dims = framework::slice_ddim(out->dims(), 1, out_rank); + x_resized_dims = {n_frames, frame_length, + framework::product(preserved_dims)}; + out_resized_dims = {seq_length, framework::product(preserved_dims)}; + } else { + preserved_dims = framework::slice_ddim(out->dims(), 0, out_rank - 1); + x_resized_dims = {framework::product(preserved_dims), frame_length, + n_frames}; + out_resized_dims = {framework::product(preserved_dims), seq_length}; + } + x_.Resize(x_resized_dims); + out->Resize(out_resized_dims); + } + + Tensor trans_x(x_.type()); + Tensor trans_out(out->type()); + + // Transpose input and output in case that axis is 0. 
+ if (axis == 0) { + if (out_rank == 1U) { + trans_out = *out; + + std::vector perm_x{1, 0}; + auto x_dims_vec = framework::vectorize(x_.dims()); + for (int i = 0; i < x_.dims().size(); ++i) { + x_dims_vec[i] = x_.dims()[perm_x[i]]; + } + trans_x.Resize(framework::make_ddim(x_dims_vec)); + trans_x.mutable_data(ctx.GetPlace()); + TransCompute(perm_x.size(), dev_ctx, x_, &trans_x, + perm_x); + } else { + std::vector perm_out{1, 0}; + auto out_dims_vec = framework::vectorize(out->dims()); + for (int i = 0; i < out->dims().size(); ++i) { + out_dims_vec[i] = out->dims()[perm_out[i]]; + } + trans_out.Resize(framework::make_ddim(out_dims_vec)); + trans_out.mutable_data(ctx.GetPlace()); + TransCompute(perm_out.size(), dev_ctx, *out, + &trans_out, perm_out); + + std::vector perm_x{2, 1, 0}; + auto x_dims_vec = framework::vectorize(x_.dims()); + for (int i = 0; i < x_.dims().size(); ++i) { + x_dims_vec[i] = x_.dims()[perm_x[i]]; + } + trans_x.Resize(framework::make_ddim(x_dims_vec)); + trans_x.mutable_data(ctx.GetPlace()); + TransCompute(perm_x.size(), dev_ctx, x_, &trans_x, + perm_x); + } + } else { + trans_x = x_; + trans_out = *out; + } + + OverlapAddFunctor()(dev_ctx, &trans_x, &trans_out, + seq_length, frame_length, n_frames, + hop_length, /*is_grad*/ false); + + // Transpose output in case axis is 0. + if (axis == 0 && out_rank > 1U) { + std::vector perm_out{1, 0}; + TransCompute(perm_out.size(), dev_ctx, trans_out, out, + perm_out); + } + + // Restore output dims when the number of dims is larger than 2. + if (out_rank > 2) { + std::vector restored_out_shape; + for (int i = 0; i < preserved_dims.size(); i++) { + restored_out_shape.push_back(preserved_dims[i]); + } + + if (axis == 0) { + // (seq_length, ...) + restored_out_shape.insert(restored_out_shape.begin(), seq_length); + } else { + // (..., seq_length) + restored_out_shape.push_back(seq_length); + } + + out->Resize(framework::make_ddim(restored_out_shape)); + } + } +}; + +template +class OverlapAddGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const Tensor* d_out = ctx.Input(framework::GradVarName("Out")); + Tensor* d_x = ctx.Output(framework::GradVarName("X")); + d_x->mutable_data(ctx.GetPlace()); + const size_t d_out_rank = d_out->dims().size(); + const size_t d_x_rank = d_x->dims().size(); + + const int hop_length = ctx.Attr("hop_length"); + const int axis = ctx.Attr("axis"); + const int n_frames = + (axis == 0) ? d_x->dims()[0] : d_x->dims()[d_x_rank - 1]; + const int frame_length = + (axis == 0) ? d_x->dims()[1] : d_x->dims()[d_x_rank - 2]; + const int seq_length = + (axis == 0) ? d_out->dims()[0] : d_out->dims()[d_out_rank - 1]; + + auto& dev_ctx = ctx.device_context(); + + // When the number of input dims is larger than 2, it needs to copy + // from x to resize input into 2d and output into 3d. Morevoer, output + // dims will be restored at the last step. + Tensor d_out_(d_out->type()); + d_out_ = *d_out; + + framework::DDim preserved_dims; + if (d_out_rank > 2) { + // Save dims used to flatten both input and output tensors and restore + // output tensor. 
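+ // For axis == 0, d_x keeps (n_frames, frame_length) leading and the rest + // is flattened; otherwise the preserved batch dims are flattened in front, + // giving a 3-D d_x and a 2-D d_out to work on.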
+ framework::DDim d_x_resized_dims; + framework::DDim d_out_resized_dims; + if (axis == 0) { + preserved_dims = framework::slice_ddim(d_out_.dims(), 1, d_out_rank); + d_x_resized_dims = {n_frames, frame_length, + framework::product(preserved_dims)}; + d_out_resized_dims = {seq_length, framework::product(preserved_dims)}; + } else { + preserved_dims = + framework::slice_ddim(d_out_.dims(), 0, d_out_rank - 1); + d_x_resized_dims = {framework::product(preserved_dims), frame_length, + n_frames}; + d_out_resized_dims = {framework::product(preserved_dims), seq_length}; + } + d_x->Resize(d_x_resized_dims); + d_out_.Resize(d_out_resized_dims); + } + + Tensor trans_d_x(d_x->type()); + Tensor trans_d_out(d_out_.type()); + + // Transpose input and output in case that axis is 0. + if (axis == 0) { + if (d_out_rank == 1U) { + trans_d_out = d_out_; + + std::vector perm_d_x{1, 0}; + auto d_x_dims_vec = framework::vectorize(d_x->dims()); + for (int i = 0; i < d_x->dims().size(); ++i) { + d_x_dims_vec[i] = d_x->dims()[perm_d_x[i]]; + } + trans_d_x.Resize(framework::make_ddim(d_x_dims_vec)); + trans_d_x.mutable_data(ctx.GetPlace()); + TransCompute(perm_d_x.size(), dev_ctx, *d_x, + &trans_d_x, perm_d_x); + } else { + std::vector perm_d_out{1, 0}; + auto d_out_dims_vec = framework::vectorize(d_out_.dims()); + for (int i = 0; i < d_out_.dims().size(); ++i) { + d_out_dims_vec[i] = d_out_.dims()[perm_d_out[i]]; + } + trans_d_out.Resize(framework::make_ddim(d_out_dims_vec)); + trans_d_out.mutable_data(ctx.GetPlace()); + TransCompute(perm_d_out.size(), dev_ctx, d_out_, + &trans_d_out, perm_d_out); + + std::vector perm_d_x{2, 1, 0}; + auto d_x_dims_vec = framework::vectorize(d_x->dims()); + for (int i = 0; i < d_x->dims().size(); ++i) { + d_x_dims_vec[i] = d_x->dims()[perm_d_x[i]]; + } + trans_d_x.Resize(framework::make_ddim(d_x_dims_vec)); + trans_d_x.mutable_data(ctx.GetPlace()); + TransCompute(perm_d_x.size(), dev_ctx, *d_x, + &trans_d_x, perm_d_x); + } + } else { + trans_d_x = *d_x; + trans_d_out = d_out_; + } + + OverlapAddFunctor()(dev_ctx, &trans_d_out, &trans_d_x, + seq_length, frame_length, n_frames, + hop_length, + /*is_grad*/ true); + + // Transpose output in case axis is 0. + if (axis == 0) { + if (d_out_rank == 1U) { + std::vector perm_d_x{1, 0}; + TransCompute(perm_d_x.size(), dev_ctx, trans_d_x, d_x, + perm_d_x); + } else { + std::vector perm_d_x{2, 1, 0}; + TransCompute(perm_d_x.size(), dev_ctx, trans_d_x, d_x, + perm_d_x); + } + } + + // Restore output dims when the number of dims is larger than 2. + if (d_out_rank > 2) { + std::vector restored_d_x_shape; + for (int i = 0; i < preserved_dims.size(); i++) { + restored_d_x_shape.push_back(preserved_dims[i]); + } + + if (axis == 0) { + // (n_frames, frame_length, ...) + restored_d_x_shape.insert(restored_d_x_shape.begin(), frame_length); + restored_d_x_shape.insert(restored_d_x_shape.begin(), n_frames); + } else { + // (..., frame_length, n_frames) + restored_d_x_shape.push_back(frame_length); + restored_d_x_shape.push_back(n_frames); + } + + d_x->Resize(framework::make_ddim(restored_d_x_shape)); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/pad_op.cc b/paddle/fluid/operators/pad_op.cc index 3bf66c77bad..1ace706bac6 100644 --- a/paddle/fluid/operators/pad_op.cc +++ b/paddle/fluid/operators/pad_op.cc @@ -14,6 +14,7 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/pad_op.h" #include +#include "paddle/fluid/platform/complex.h" namespace paddle { namespace operators { @@ -170,10 +171,18 @@ REGISTER_OP_CPU_KERNEL( pad, ops::PadKernel, ops::PadKernel, ops::PadKernel, - ops::PadKernel); + ops::PadKernel, + ops::PadKernel>, + ops::PadKernel>); REGISTER_OP_CPU_KERNEL( pad_grad, ops::PadGradKernel, - ops::PadGradKernel); + ops::PadGradKernel, + ops::PadGradKernel>, + ops::PadGradKernel>); REGISTER_OP_CUDA_KERNEL( pad, ops::PadKernel, @@ -181,9 +190,17 @@ REGISTER_OP_CUDA_KERNEL( ops::PadKernel, ops::PadKernel, ops::PadKernel); + paddle::platform::float16>, + ops::PadKernel>, + ops::PadKernel>); REGISTER_OP_CUDA_KERNEL( pad_grad, ops::PadGradKernel, ops::PadGradKernel, ops::PadGradKernel); + paddle::platform::float16>, + ops::PadGradKernel>, + ops::PadGradKernel>); diff --git a/paddle/fluid/operators/roll_op.cc b/paddle/fluid/operators/roll_op.cc index a0c28ae6cba..b6a8111592f 100644 --- a/paddle/fluid/operators/roll_op.cc +++ b/paddle/fluid/operators/roll_op.cc @@ -18,6 +18,7 @@ #include #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/complex.h" namespace paddle { namespace operators { @@ -148,12 +149,20 @@ REGISTER_OP_CPU_KERNEL( roll, ops::RollKernel, ops::RollKernel, ops::RollKernel, - ops::RollKernel); + ops::RollKernel, + ops::RollKernel>, + ops::RollKernel>); REGISTER_OP_CPU_KERNEL( roll_grad, ops::RollGradKernel, ops::RollGradKernel, ops::RollGradKernel, - ops::RollGradKernel); + ops::RollGradKernel, + ops::RollGradKernel>, + ops::RollGradKernel>); REGISTER_OP_VERSION(roll) .AddCheckpoint( diff --git a/paddle/fluid/operators/roll_op.cu b/paddle/fluid/operators/roll_op.cu index 136c5c0aca8..a170ce2fb11 100644 --- a/paddle/fluid/operators/roll_op.cu +++ b/paddle/fluid/operators/roll_op.cu @@ -16,6 +16,7 @@ #include "paddle/fluid/framework/array.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/roll_op.h" +#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { @@ -188,9 +189,17 @@ REGISTER_OP_CUDA_KERNEL( roll, ops::RollKernel, ops::RollKernel, ops::RollKernel, - ops::RollKernel); + ops::RollKernel, + ops::RollKernel>, + ops::RollKernel>); REGISTER_OP_CUDA_KERNEL( roll_grad, ops::RollGradKernel, ops::RollGradKernel, ops::RollGradKernel, - ops::RollGradKernel); + ops::RollGradKernel, + ops::RollGradKernel>, + ops::RollGradKernel>); diff --git a/paddle/fluid/operators/shape_op.cc b/paddle/fluid/operators/shape_op.cc index d8ec12659f7..dd135b89714 100644 --- a/paddle/fluid/operators/shape_op.cc +++ b/paddle/fluid/operators/shape_op.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/fluid/operators/shape_op.h" #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/complex.h" namespace paddle { namespace operators { @@ -64,6 +65,7 @@ Return the shape of the input. 
} // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OPERATOR( shape, ops::ShapeOp, ops::ShapeOpMaker, paddle::framework::EmptyGradOpMaker, @@ -71,4 +73,6 @@ REGISTER_OPERATOR( REGISTER_OP_CPU_KERNEL(shape, ops::ShapeKernel, ops::ShapeKernel, ops::ShapeKernel, ops::ShapeKernel, ops::ShapeKernel, ops::ShapeKernel, - ops::ShapeKernel); + ops::ShapeKernel, + ops::ShapeKernel>, + ops::ShapeKernel>); diff --git a/paddle/fluid/operators/shape_op.cu b/paddle/fluid/operators/shape_op.cu index fce723c7841..c6e380a94f8 100644 --- a/paddle/fluid/operators/shape_op.cu +++ b/paddle/fluid/operators/shape_op.cu @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/shape_op.h" +#include "paddle/fluid/platform/complex.h" REGISTER_OP_CUDA_KERNEL( shape, paddle::operators::ShapeKernel, @@ -21,4 +22,6 @@ REGISTER_OP_CUDA_KERNEL( paddle::operators::ShapeKernel, paddle::operators::ShapeKernel, paddle::operators::ShapeKernel, - paddle::operators::ShapeKernel); + paddle::operators::ShapeKernel, + paddle::operators::ShapeKernel>, + paddle::operators::ShapeKernel>); diff --git a/paddle/fluid/operators/spectral_op.cc b/paddle/fluid/operators/spectral_op.cc new file mode 100644 index 00000000000..fb50702233b --- /dev/null +++ b/paddle/fluid/operators/spectral_op.cc @@ -0,0 +1,870 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/spectral_op.h" + +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/operators/transpose_op.h" +#include "paddle/fluid/platform/complex.h" + +#if defined(PADDLE_WITH_ONEMKL) +#include +#elif defined(PADDLE_WITH_POCKETFFT) +#include "extern_pocketfft/pocketfft_hdronly.h" +#endif + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +// FFTC2C +class FFTC2COpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), the input tensor of fft_c2c op."); + AddOutput("Out", "(Tensor), the output tensor of fft_c2c op."); + AddAttr>("axes", + "std::vector, the fft axes."); + AddAttr("normalization", + "fft_norm_type, the fft normalization type."); + AddAttr("forward", "bool, the fft direction."); + AddComment(R"DOC( + Compute complex to complex FFT. 
+ )DOC"); + } +}; + +class FFTC2COp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fft_c2c"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fft_c2c"); + const auto axes = ctx->Attrs().Get>("axes"); + const auto x_dim = ctx->GetInputDim("X"); + for (size_t i = 0; i < axes.size(); i++) { + PADDLE_ENFORCE_GT(x_dim[axes[i]], 0, + platform::errors::InvalidArgument( + "Invalid fft n-point (%d).", x_dim[axes[i]])); + } + ctx->ShareDim("X", /*->*/ "Out"); // only for c2c + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + const auto kernel_dtype = framework::ToRealType(in_dtype); + return framework::OpKernelType(kernel_dtype, ctx.GetPlace()); + } +}; + +template +class FFTC2CGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr grad_op) const override { + grad_op->SetType("fft_c2c_grad"); + grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + grad_op->SetAttrMap(this->Attrs()); + } +}; + +class FFTC2CGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + const auto out_grad_name = framework::GradVarName("Out"); + OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name, + "fft_c2c_grad"); + const auto x_grad_name = framework::GradVarName("X"); + OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), "Output", x_grad_name, + "fft_c2c_grad"); + + ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + const auto in_dtype = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); + const auto kernel_dtype = framework::ToRealType(in_dtype); + return framework::OpKernelType(kernel_dtype, ctx.GetPlace()); + } +}; + +// FFTR2C +class FFTR2COpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), the input tensor of fft_r2c op."); + AddOutput("Out", "(Tensor), the output tensor of fft_r2c op."); + AddAttr>("axes", + "std::vector, the fft axes."); + AddAttr("normalization", + "fft_norm_type, the fft normalization type."); + AddAttr("forward", "bool, the fft direction."); + AddAttr("onesided", "bool, perform onesided fft."); + AddComment(R"DOC( + Compute real to complex FFT. 
+ )DOC"); + } +}; + +class FFTR2COp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fft_r2c"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fft_r2c"); + const auto axes = ctx->Attrs().Get>("axes"); + const auto x_dim = ctx->GetInputDim("X"); + for (size_t i = 0; i < axes.size() - 1L; i++) { + PADDLE_ENFORCE_GT(x_dim[axes[i]], 0, + platform::errors::InvalidArgument( + "Invalid fft n-point (%d).", x_dim[axes[i]])); + } + + const bool onesided = ctx->Attrs().Get("onesided"); + if (!onesided) { + ctx->ShareDim("X", /*->*/ "Out"); + } else { + framework::DDim out_dim(ctx->GetInputDim("X")); + const int64_t last_fft_axis = axes.back(); + const int64_t last_fft_dim_size = out_dim.at(last_fft_axis); + out_dim.at(last_fft_axis) = last_fft_dim_size / 2 + 1; + ctx->SetOutputDim("Out", out_dim); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(in_dtype, ctx.GetPlace()); + } +}; + +template +class FFTR2CGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr grad_op) const override { + grad_op->SetType("fft_r2c_grad"); + grad_op->SetInput("X", this->Input("X")); + grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + grad_op->SetAttrMap(this->Attrs()); + } +}; + +class FFTR2CGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + const auto out_grad_name = framework::GradVarName("Out"); + OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name, + "fft_r2c_grad"); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fft_r2c_grad"); + + const auto x_grad_name = framework::GradVarName("X"); + OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), "Output", x_grad_name, + "fft_r2c_grad"); + + ctx->ShareDim("X", /*->*/ x_grad_name); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + const auto in_dtype = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); + const auto kernel_dtype = framework::ToRealType(in_dtype); + return framework::OpKernelType(kernel_dtype, ctx.GetPlace()); + } +}; + +// FFTC2R +class FFTC2ROpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), the input tensor of fft_c2r op."); + AddOutput("Out", "(Tensor), the output tensor of fft_c2r op."); + AddAttr>("axes", + "std::vector, the fft axes."); + AddAttr("normalization", + "fft_norm_type, the fft normalization type."); + AddAttr("forward", "bool, the fft direction."); + AddAttr( + "last_dim_size", "int", + "Length of the transformed " + "axis of the output. For n output points, last_dim_size//2 + 1 input" + " points are necessary. If the input is longer than this," + " it is cropped. If it is shorter than this, it is padded" + " with zeros. 
If last_dim_size is not given, it is taken to be 2*(m-1)" + " where m is the length of the input along the axis " + "specified by axis.") + .SetDefault(0L); + AddComment(R"DOC( + Compute complex to real FFT. + )DOC"); + } +}; + +class FFTC2ROp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fft_c2r"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fft_c2r"); + + const auto axes = ctx->Attrs().Get>("axes"); + const auto x_dim = ctx->GetInputDim("X"); + for (size_t i = 0; i < axes.size() - 1L; i++) { + PADDLE_ENFORCE_GT(x_dim[axes[i]], 0, + platform::errors::InvalidArgument( + "Invalid fft n-point (%d).", x_dim[axes[i]])); + } + + const int64_t last_dim_size = ctx->Attrs().Get("last_dim_size"); + framework::DDim out_dim(ctx->GetInputDim("X")); + const int64_t last_fft_axis = axes.back(); + if (last_dim_size == 0) { + const int64_t last_fft_dim_size = out_dim.at(last_fft_axis); + const int64_t fft_n_point = (last_fft_dim_size - 1) * 2; + PADDLE_ENFORCE_GT(fft_n_point, 0, + platform::errors::InvalidArgument( + "Invalid fft n-point (%d).", fft_n_point)); + out_dim.at(last_fft_axis) = fft_n_point; + } else { + PADDLE_ENFORCE_GT(last_dim_size, 0, + platform::errors::InvalidArgument( + "Invalid fft n-point (%d).", last_dim_size)); + out_dim.at(last_fft_axis) = last_dim_size; + } + ctx->SetOutputDim("Out", out_dim); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + const auto kernel_dtype = framework::ToRealType(in_dtype); + return framework::OpKernelType(kernel_dtype, ctx.GetPlace()); + } +}; + +template +class FFTC2RGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr grad_op) const override { + grad_op->SetType("fft_c2r_grad"); + grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + grad_op->SetAttrMap(this->Attrs()); + } +}; + +class FFTC2RGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + const auto out_grad_name = framework::GradVarName("Out"); + OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name, + "fft_c2r_grad"); + + const auto x_grad_name = framework::GradVarName("X"); + OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), "Output", x_grad_name, + "fft_c2r_grad"); + + const auto axes = ctx->Attrs().Get>("axes"); + + const auto out_grad_dim = ctx->GetInputDim(out_grad_name); + framework::DDim x_grad_dim(out_grad_dim); + const int64_t last_fft_axis = axes.back(); + const int64_t last_fft_dim_size = x_grad_dim.at(last_fft_axis); + x_grad_dim.at(last_fft_axis) = last_fft_dim_size / 2 + 1; + ctx->SetOutputDim(x_grad_name, x_grad_dim); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + const auto in_dtype = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); + return framework::OpKernelType(in_dtype, ctx.GetPlace()); + } +}; + +// common functions +FFTNormMode get_norm_from_string(const std::string&
norm, bool forward) { + if (norm.empty() || norm == "backward") { + return forward ? FFTNormMode::none : FFTNormMode::by_n; + } + + if (norm == "forward") { + return forward ? FFTNormMode::by_n : FFTNormMode::none; + } + + if (norm == "ortho") { + return FFTNormMode::by_sqrt_n; + } + + PADDLE_THROW(platform::errors::InvalidArgument( + "FFT norm string must be 'forward' or 'backward' or 'ortho', " + "received %s", + norm)); +} + +// FFT Functors +#if defined(PADDLE_WITH_ONEMKL) + +namespace { +static inline void MKL_DFTI_CHECK(MKL_INT status) { + if (status && !DftiErrorClass(status, DFTI_NO_ERROR)) { + PADDLE_THROW(platform::errors::External(DftiErrorMessage(status))); + } +} + +struct DftiDescriptorDeleter { + void operator()(DFTI_DESCRIPTOR_HANDLE handle) { + if (handle != nullptr) { + MKL_DFTI_CHECK(DftiFreeDescriptor(&handle)); + } + } +}; + +class DftiDescriptor { + public: + void init(DFTI_CONFIG_VALUE precision, DFTI_CONFIG_VALUE signal_type, + MKL_LONG signal_ndim, MKL_LONG* sizes) { + if (desc_ != nullptr) { + PADDLE_THROW(platform::errors::AlreadyExists( + "DFT DESCRIPTOR can only be initialized once.")); + } + DFTI_DESCRIPTOR* raw_desc; + if (signal_ndim == 1) { + MKL_DFTI_CHECK( + DftiCreateDescriptor(&raw_desc, precision, signal_type, 1, sizes[0])); + } else { + MKL_DFTI_CHECK(DftiCreateDescriptor(&raw_desc, precision, signal_type, + signal_ndim, sizes)); + } + desc_.reset(raw_desc); + } + + DFTI_DESCRIPTOR* get() const { + if (desc_ == nullptr) { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "DFTI DESCRIPTOR has not been initialized.")); + } + return desc_.get(); + } + + private: + std::unique_ptr desc_; +}; + +DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, + const framework::proto::VarType::Type& out_dtype, + const framework::DDim& in_strides, + const framework::DDim& out_strides, + const std::vector& signal_sizes, + FFTNormMode normalization, bool forward) { + const DFTI_CONFIG_VALUE precision = [&] { + switch (in_dtype) { + case framework::proto::VarType::FP32: + return DFTI_SINGLE; + case framework::proto::VarType::COMPLEX64: + return DFTI_SINGLE; + case framework::proto::VarType::FP64: + return DFTI_DOUBLE; + case framework::proto::VarType::COMPLEX128: + return DFTI_DOUBLE; + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "Input data type should be FP32, FP64, COMPLEX64 or COMPLEX128.")); + } + }(); + + // C2C, R2C, C2R + const FFTTransformType fft_type = GetFFTTransformType(in_dtype, out_dtype); + const DFTI_CONFIG_VALUE domain = + (fft_type == FFTTransformType::C2C) ? DFTI_COMPLEX : DFTI_REAL; + + // const bool complex_input = framework::IsComplexType(in_dtype); + // const bool complex_output = framework::IsComplexType(out_dtype); + // const DFTI_CONFIG_VALUE domain = [&] { + // if (forward) { + // return complex_input ? DFTI_COMPLEX : DFTI_REAL; + // } else { + // return complex_output ? 
DFTI_COMPLEX : DFTI_REAL; + // } + // }(); + + DftiDescriptor descriptor; + std::vector fft_sizes(signal_sizes.cbegin(), signal_sizes.cend()); + const MKL_LONG signal_ndim = fft_sizes.size() - 1; + descriptor.init(precision, domain, signal_ndim, fft_sizes.data() + 1); + + // placement inplace or not inplace + MKL_DFTI_CHECK( + DftiSetValue(descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE)); + + // number of transformations + const MKL_LONG batch_size = fft_sizes[0]; + MKL_DFTI_CHECK( + DftiSetValue(descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size)); + + // input & output distance + const MKL_LONG idist = in_strides[0]; + const MKL_LONG odist = out_strides[0]; + MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist)); + MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_DISTANCE, odist)); + + // input & output stride + std::vector mkl_in_stride(1 + signal_ndim, 0); + std::vector mkl_out_stride(1 + signal_ndim, 0); + for (MKL_LONG i = 1; i <= signal_ndim; i++) { + mkl_in_stride[i] = in_strides[i]; + mkl_out_stride[i] = out_strides[i]; + } + MKL_DFTI_CHECK( + DftiSetValue(descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride.data())); + MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_STRIDES, + mkl_out_stride.data())); + + // conjugate even storage + if (!(fft_type == FFTTransformType::C2C)) { + MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, + DFTI_COMPLEX_COMPLEX)); + } + + MKL_LONG signal_numel = + std::accumulate(fft_sizes.cbegin() + 1, fft_sizes.cend(), 1UL, + std::multiplies()); + if (normalization != FFTNormMode::none) { + const double scale = + ((normalization == FFTNormMode::by_sqrt_n) + ? 1.0 / std::sqrt(static_cast(signal_numel)) + : 1.0 / static_cast(signal_numel)); + const auto scale_direction = [&]() { + if (fft_type == FFTTransformType::R2C || + (fft_type == FFTTransformType::C2C && forward)) { + return DFTI_FORWARD_SCALE; + } else { + // (fft_type == FFTTransformType::C2R || + // (fft_type == FFTTransformType::C2C && !forward)) + return DFTI_BACKWARD_SCALE; + } + }(); + MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), scale_direction, scale)); + } + + // commit the descriptor + MKL_DFTI_CHECK(DftiCommitDescriptor(descriptor.get())); + return descriptor; +} + +// Execute a general fft operation (can be c2c, onesided r2c or onesided c2r) +template +void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out, + const std::vector& axes, FFTNormMode normalization, + bool forward) { + const framework::DDim& in_sizes = x->dims(); + const int ndim = in_sizes.size(); + const int signal_ndim = axes.size(); + const int batch_ndim = ndim - signal_ndim; + const framework::DDim& out_sizes = out->dims(); + + // make a dim permutation + std::vector dim_permute(ndim); + std::iota(dim_permute.begin(), dim_permute.end(), 0); + std::vector is_transformed_dim(ndim, false); + for (const auto& d : axes) { + is_transformed_dim[d] = true; + } + const auto batch_end = + std::partition(dim_permute.begin(), dim_permute.end(), + [&](size_t axis) { return !is_transformed_dim[axis]; }); + std::copy(axes.cbegin(), axes.cend(), batch_end); + + // transpose input according to that permutation + framework::DDim transposed_input_shape = in_sizes.transpose(dim_permute); + std::vector transposed_input_shape_ = + framework::vectorize(transposed_input_shape); + framework::Tensor transposed_input; + transposed_input.Resize(transposed_input_shape); + const auto place = ctx.GetPlace(); + transposed_input.mutable_data(place); + 
TransCompute(ndim, ctx, *x, &transposed_input, + dim_permute); + + // make an collapsed input: collapse batch axes for input + const int batch_size = std::accumulate( + transposed_input_shape.Get(), transposed_input_shape.Get() + batch_ndim, + 1L, std::multiplies()); + std::vector collapsed_input_shape_(1 + signal_ndim); + collapsed_input_shape_[0] = batch_size; + std::copy(transposed_input_shape_.begin() + batch_ndim, + transposed_input_shape_.end(), collapsed_input_shape_.begin() + 1); + const framework::DDim collapsed_input_shape = + framework::make_ddim(collapsed_input_shape_); + transposed_input.Resize(collapsed_input_shape); + framework::Tensor& collapsed_input = transposed_input; + + // make a collapsed output + std::vector collapsed_output_shape_(1 + signal_ndim); + collapsed_output_shape_[0] = batch_size; + for (int i = 0; i < signal_ndim; i++) { + collapsed_output_shape_[1 + i] = out_sizes[axes[i]]; + } + const framework::DDim collapsed_output_shape = + framework::make_ddim(collapsed_output_shape_); + framework::Tensor collapsed_output; + collapsed_output.Resize(collapsed_output_shape); + collapsed_output.mutable_data(place, out->type()); + + // signal sizes + std::vector signal_sizes(1 + signal_ndim); + signal_sizes[0] = batch_size; + for (int i = 0; i < signal_ndim; i++) { + signal_sizes[1 + i] = + std::max(collapsed_input_shape[1 + i], collapsed_output_shape[1 + i]); + } + + // input & output stride + const framework::DDim input_stride = framework::stride(collapsed_input_shape); + const framework::DDim output_stride = + framework::stride(collapsed_output_shape); + + // make a DFTI_DESCRIPTOR + DftiDescriptor desc = + _plan_mkl_fft(x->type(), out->type(), input_stride, output_stride, + signal_sizes, normalization, forward); + + const FFTTransformType fft_type = GetFFTTransformType(x->type(), out->type()); + if (fft_type == FFTTransformType::C2R && forward) { + framework::Tensor collapsed_input_conj(collapsed_input.type()); + collapsed_input_conj.mutable_data(collapsed_input.dims(), + ctx.GetPlace()); + // conjugate the input + platform::ForRange for_range(ctx, collapsed_input.numel()); + math::ConjFunctor functor(collapsed_input.data(), + collapsed_input.numel(), + collapsed_input_conj.data()); + for_range(functor); + MKL_DFTI_CHECK(DftiComputeBackward(desc.get(), + collapsed_input_conj.data(), + collapsed_output.data())); + } else if (fft_type == FFTTransformType::R2C && !forward) { + framework::Tensor collapsed_output_conj(collapsed_output.type()); + collapsed_output_conj.mutable_data(collapsed_output.dims(), + ctx.GetPlace()); + MKL_DFTI_CHECK(DftiComputeForward(desc.get(), collapsed_input.data(), + collapsed_output_conj.data())); + // conjugate the output + platform::ForRange for_range(ctx, collapsed_output.numel()); + math::ConjFunctor functor(collapsed_output_conj.data(), + collapsed_output.numel(), + collapsed_output.data()); + for_range(functor); + } else { + if (forward) { + MKL_DFTI_CHECK(DftiComputeForward(desc.get(), + collapsed_input.data(), + collapsed_output.data())); + } else { + MKL_DFTI_CHECK(DftiComputeBackward(desc.get(), + collapsed_input.data(), + collapsed_output.data())); + } + } + + // resize for the collapsed output + framework::DDim transposed_output_shape = out_sizes.transpose(dim_permute); + collapsed_output.Resize(transposed_output_shape); + framework::Tensor& transposed_output = collapsed_output; + + // reverse the transposition + std::vector reverse_dim_permute(ndim); + for (int i = 0; i < ndim; i++) { + reverse_dim_permute[dim_permute[i]] = i; + 
} + TransCompute(ndim, ctx, transposed_output, + out, reverse_dim_permute); +} +} // anonymous namespace + +template +struct FFTC2CFunctor { + void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x, + Tensor* out, const std::vector& axes, + FFTNormMode normalization, bool forward) { + exec_fft(ctx, x, out, axes, + normalization, forward); + } +}; + +template +struct FFTR2CFunctor { + void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x, + Tensor* out, const std::vector& axes, + FFTNormMode normalization, bool forward) { + exec_fft(ctx, x, out, axes, + normalization, forward); + } +}; + +template +struct FFTC2RFunctor { + void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x, + Tensor* out, const std::vector& axes, + FFTNormMode normalization, bool forward) { + if (axes.size() > 1) { + const std::vector c2c_dims(axes.begin(), axes.end() - 1); + Tensor temp; + temp.mutable_data(x->dims(), ctx.GetPlace()); + + FFTC2CFunctor c2c_functor; + c2c_functor(ctx, x, &temp, c2c_dims, normalization, forward); + + const std::vector new_axes{axes.back()}; + exec_fft(ctx, &temp, out, new_axes, + normalization, forward); + } else { + exec_fft(ctx, x, out, axes, + normalization, forward); + } + } +}; + +#elif defined(PADDLE_WITH_POCKETFFT) + +namespace { +template +T compute_factor(int64_t size, FFTNormMode normalization) { + constexpr auto one = static_cast(1); + switch (normalization) { + case FFTNormMode::none: + return one; + case FFTNormMode::by_n: + return one / static_cast(size); + case FFTNormMode::by_sqrt_n: + return one / std::sqrt(static_cast(size)); + } + PADDLE_THROW( + platform::errors::InvalidArgument("Unsupported normalization type")); +} +} // anonymous namespace + +template +struct FFTC2CFunctor { + void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x, + Tensor* out, const std::vector& axes, + FFTNormMode normalization, bool forward) { + using R = typename Ti::value_type; + using C = std::complex; + + const auto& input_dim = x->dims(); + const std::vector in_sizes = + framework::vectorize(input_dim); + std::vector in_strides = + framework::vectorize(framework::stride(input_dim)); + const int64_t data_size = sizeof(C); + std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(), + [&](std::ptrdiff_t s) { return s * data_size; }); + + const auto* in_data = reinterpret_cast(x->data()); + auto* out_data = reinterpret_cast(out->data()); + // pocketfft requires std::vector + std::vector axes_(axes.size()); + std::copy(axes.begin(), axes.end(), axes_.begin()); + // compuet factor + int64_t signal_numel = 1; + for (auto i : axes) { + signal_numel *= in_sizes[i]; + } + R factor = compute_factor(signal_numel, normalization); + pocketfft::c2c(in_sizes, in_strides, in_strides, axes_, forward, in_data, + out_data, factor); + } +}; + +template +struct FFTR2CFunctor { + void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x, + Tensor* out, const std::vector& axes, + FFTNormMode normalization, bool forward) { + using R = Ti; + using C = std::complex; + + const auto& input_dim = x->dims(); + const std::vector in_sizes = + framework::vectorize(input_dim); + std::vector in_strides = + framework::vectorize(framework::stride(input_dim)); + { + const int64_t data_size = sizeof(R); + std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(), + [&](std::ptrdiff_t s) { return s * data_size; }); + } + + const auto& output_dim = out->dims(); + const std::vector out_sizes = + 
framework::vectorize(output_dim); + std::vector out_strides = + framework::vectorize(framework::stride(output_dim)); + { + const int64_t data_size = sizeof(C); + std::transform(out_strides.begin(), out_strides.end(), + out_strides.begin(), + [&](std::ptrdiff_t s) { return s * data_size; }); + } + + const auto* in_data = x->data(); + auto* out_data = reinterpret_cast(out->data()); + // pocketfft requires std::vector + std::vector axes_(axes.size()); + std::copy(axes.begin(), axes.end(), axes_.begin()); + // compuet normalization factor + int64_t signal_numel = 1; + for (auto i : axes) { + signal_numel *= in_sizes[i]; + } + R factor = compute_factor(signal_numel, normalization); + pocketfft::r2c(in_sizes, in_strides, out_strides, axes_, forward, in_data, + out_data, factor); + } +}; + +template +struct FFTC2RFunctor { + void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x, + Tensor* out, const std::vector& axes, + FFTNormMode normalization, bool forward) { + using R = To; + using C = std::complex; + + const auto& input_dim = x->dims(); + const std::vector in_sizes = + framework::vectorize(input_dim); + std::vector in_strides = + framework::vectorize(framework::stride(input_dim)); + { + const int64_t data_size = sizeof(C); + std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(), + [&](std::ptrdiff_t s) { return s * data_size; }); + } + + const auto& output_dim = out->dims(); + const std::vector out_sizes = + framework::vectorize(output_dim); + std::vector out_strides = + framework::vectorize(framework::stride(output_dim)); + { + const int64_t data_size = sizeof(R); + std::transform(out_strides.begin(), out_strides.end(), + out_strides.begin(), + [&](std::ptrdiff_t s) { return s * data_size; }); + } + + const auto* in_data = reinterpret_cast(x->data()); + auto* out_data = out->data(); + // pocketfft requires std::vector + std::vector axes_(axes.size()); + std::copy(axes.begin(), axes.end(), axes_.begin()); + // compuet normalization factor + int64_t signal_numel = 1; + for (auto i : axes) { + signal_numel *= out_sizes[i]; + } + R factor = compute_factor(signal_numel, normalization); + pocketfft::c2r(out_sizes, in_strides, out_strides, axes_, forward, in_data, + out_data, factor); + } +}; + +#endif + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(fft_c2c, ops::FFTC2COp, ops::FFTC2COpMaker, + ops::FFTC2CGradOpMaker, + ops::FFTC2CGradOpMaker); +REGISTER_OP_CPU_KERNEL( + fft_c2c, ops::FFTC2CKernel, + ops::FFTC2CKernel); + +REGISTER_OPERATOR(fft_c2c_grad, ops::FFTC2CGradOp); +REGISTER_OP_CPU_KERNEL( + fft_c2c_grad, + ops::FFTC2CGradKernel, + ops::FFTC2CGradKernel); + +REGISTER_OPERATOR(fft_r2c, ops::FFTR2COp, ops::FFTR2COpMaker, + ops::FFTR2CGradOpMaker, + ops::FFTR2CGradOpMaker); +REGISTER_OP_CPU_KERNEL( + fft_r2c, ops::FFTR2CKernel, + ops::FFTR2CKernel); + +REGISTER_OPERATOR(fft_r2c_grad, ops::FFTR2CGradOp); +REGISTER_OP_CPU_KERNEL( + fft_r2c_grad, + ops::FFTR2CGradKernel, + ops::FFTR2CGradKernel); + +REGISTER_OPERATOR(fft_c2r, ops::FFTC2ROp, ops::FFTC2ROpMaker, + ops::FFTC2RGradOpMaker, + ops::FFTC2RGradOpMaker); +REGISTER_OP_CPU_KERNEL( + fft_c2r, ops::FFTC2RKernel, + ops::FFTC2RKernel); + +REGISTER_OPERATOR(fft_c2r_grad, ops::FFTC2RGradOp); +REGISTER_OP_CPU_KERNEL( + fft_c2r_grad, + ops::FFTC2RGradKernel, + ops::FFTC2RGradKernel); diff --git a/paddle/fluid/operators/spectral_op.cu b/paddle/fluid/operators/spectral_op.cu new file mode 100644 index 00000000000..9aa5ca39d73 --- /dev/null +++ 
b/paddle/fluid/operators/spectral_op.cu @@ -0,0 +1,643 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/operators/conj_op.h" +#include "paddle/fluid/operators/spectral_op.h" +#include "paddle/fluid/operators/transpose_op.h" +#include "paddle/fluid/platform/dynload/cufft.h" + +namespace paddle { +namespace operators { + +namespace { + +using ScalarType = framework::proto::VarType::Type; +const int64_t kMaxCUFFTNdim = 3; +const int64_t kMaxDataNdim = kMaxCUFFTNdim + 1; + +static inline std::string get_cufft_error_info(cufftResult error) { + switch (error) { + case CUFFT_SUCCESS: + return "CUFFT_SUCCESS"; + case CUFFT_INVALID_PLAN: + return "CUFFT_INVALID_PLAN"; + case CUFFT_ALLOC_FAILED: + return "CUFFT_ALLOC_FAILED"; + case CUFFT_INVALID_TYPE: + return "CUFFT_INVALID_TYPE"; + case CUFFT_INVALID_VALUE: + return "CUFFT_INVALID_VALUE"; + case CUFFT_INTERNAL_ERROR: + return "CUFFT_INTERNAL_ERROR"; + case CUFFT_EXEC_FAILED: + return "CUFFT_EXEC_FAILED"; + case CUFFT_SETUP_FAILED: + return "CUFFT_SETUP_FAILED"; + case CUFFT_INVALID_SIZE: + return "CUFFT_INVALID_SIZE"; + case CUFFT_UNALIGNED_DATA: + return "CUFFT_UNALIGNED_DATA"; + case CUFFT_INCOMPLETE_PARAMETER_LIST: + return "CUFFT_INCOMPLETE_PARAMETER_LIST"; + case CUFFT_INVALID_DEVICE: + return "CUFFT_INVALID_DEVICE"; + case CUFFT_PARSE_ERROR: + return "CUFFT_PARSE_ERROR"; + case CUFFT_NO_WORKSPACE: + return "CUFFT_NO_WORKSPACE"; + case CUFFT_NOT_IMPLEMENTED: + return "CUFFT_NOT_IMPLEMENTED"; +#ifndef __HIPCC__ + case CUFFT_LICENSE_ERROR: + return "CUFFT_LICENSE_ERROR"; +#endif + case CUFFT_NOT_SUPPORTED: + return "CUFFT_NOT_SUPPORTED"; + default: + std::ostringstream ss; + ss << "unknown error " << error; + return ss.str(); + } +} + +static inline void CUFFT_CHECK(cufftResult error) { + if (error != CUFFT_SUCCESS) { + PADDLE_THROW(platform::errors::External(get_cufft_error_info(error))); + } +} + +// This struct is used to easily compute hashes of the +// parameters. It will be the **key** to the plan cache. +struct PlanKey { + // between 1 and kMaxCUFFTNdim, i.e., 1 <= signal_ndim <= 3 + int64_t signal_ndim_; + // These include additional batch dimension as well. 
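+ // (Plain fixed-size arrays keep PlanKey trivially copyable, so the whole + // struct can be hashed and compared bytewise as the cache key.)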
+ int64_t sizes_[kMaxDataNdim]; + int64_t input_shape_[kMaxDataNdim]; + int64_t output_shape_[kMaxDataNdim]; + FFTTransformType fft_type_; + ScalarType value_type_; + + PlanKey() = default; + + PlanKey(const std::vector& in_shape, + const std::vector& out_shape, + const std::vector& signal_size, FFTTransformType fft_type, + ScalarType value_type) { + // Padding bits must be zeroed for hashing + memset(this, 0, sizeof(*this)); + signal_ndim_ = signal_size.size() - 1; + fft_type_ = fft_type; + value_type_ = value_type; + + std::copy(signal_size.cbegin(), signal_size.cend(), sizes_); + std::copy(in_shape.cbegin(), in_shape.cend(), input_shape_); + std::copy(out_shape.cbegin(), out_shape.cend(), output_shape_); + } +}; + +// An RAII encapsulation of cuFFTHandle +class CuFFTHandle { + ::cufftHandle handle_; + + public: + CuFFTHandle() { CUFFT_CHECK(platform::dynload::cufftCreate(&handle_)); } + + ::cufftHandle& get() { return handle_; } + const ::cufftHandle& get() const { return handle_; } + + ~CuFFTHandle() { +// Not using fftDestroy() for rocFFT to work around double freeing of handles +#ifndef __HIPCC__ + CUFFT_CHECK(platform::dynload::cufftDestroy(handle_)); +#endif + } +}; + +#ifdef __HIPCC__ +using plan_size_type = int; +#else +using plan_size_type = long long int; // NOLINT +#endif + +// This class contains all the information needed to execute a cuFFT plan: +// 1. the plan +// 2. the workspace size needed +class CuFFTConfig { + public: + // Only move semantics is enought for this class. Although we already use + // unique_ptr for the plan, still remove copy constructor and assignment op so + // we don't accidentally copy and take perf hit. + CuFFTConfig(const CuFFTConfig&) = delete; + CuFFTConfig& operator=(CuFFTConfig const&) = delete; + + explicit CuFFTConfig(const PlanKey& plan_key) + : CuFFTConfig( + std::vector(plan_key.sizes_, + plan_key.sizes_ + plan_key.signal_ndim_ + 1), + plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {} + + // sizes are full signal, including batch size and always two-sided + CuFFTConfig(const std::vector& sizes, const int64_t signal_ndim, + FFTTransformType fft_type, ScalarType dtype) + : fft_type_(fft_type), value_type_(dtype) { + // signal sizes (excluding batch dim) + std::vector signal_sizes(sizes.begin() + 1, sizes.end()); + + // input batch size + const auto batch = static_cast(sizes[0]); + // const int64_t signal_ndim = sizes.size() - 1; + PADDLE_ENFORCE_EQ(signal_ndim, sizes.size() - 1, + platform::errors::InvalidArgument( + "The signal_ndim must be equal to sizes.size() - 1," + "But signal_ndim is: [%d], sizes.size() - 1 is: [%d]", + signal_ndim, sizes.size() - 1)); + +#ifdef __HIPCC__ + hipfftType exec_type = [&] { + if (dtype == framework::proto::VarType::FP32) { + switch (fft_type) { + case FFTTransformType::C2C: + return HIPFFT_C2C; + case FFTTransformType::R2C: + return HIPFFT_R2C; + case FFTTransformType::C2R: + return HIPFFT_C2R; + } + } else if (dtype == framework::proto::VarType::FP64) { + switch (fft_type) { + case FFTTransformType::C2C: + return HIPFFT_Z2Z; + case FFTTransformType::R2C: + return HIPFFT_D2Z; + case FFTTransformType::C2R: + return HIPFFT_Z2D; + } + } + PADDLE_THROW(platform::errors::InvalidArgument( + "hipFFT only support transforms of type float32 and float64")); + }(); +#else + cudaDataType itype, otype, exec_type; + const auto complex_input = has_complex_input(fft_type); + const auto complex_output = has_complex_output(fft_type); + if (dtype == framework::proto::VarType::FP32) { + itype = complex_input 
? CUDA_C_32F : CUDA_R_32F; + otype = complex_output ? CUDA_C_32F : CUDA_R_32F; + exec_type = CUDA_C_32F; + } else if (dtype == framework::proto::VarType::FP64) { + itype = complex_input ? CUDA_C_64F : CUDA_R_64F; + otype = complex_output ? CUDA_C_64F : CUDA_R_64F; + exec_type = CUDA_C_64F; + } else if (dtype == framework::proto::VarType::FP16) { + itype = complex_input ? CUDA_C_16F : CUDA_R_16F; + otype = complex_output ? CUDA_C_16F : CUDA_R_16F; + exec_type = CUDA_C_16F; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "cuFFT only support transforms of type float16, float32 and " + "float64")); + } +#endif + + // disable auto allocation of workspace to use allocator from the framework + CUFFT_CHECK(platform::dynload::cufftSetAutoAllocation( + plan(), /* autoAllocate */ 0)); + + size_t ws_size_t; + +// make plan +#ifdef __HIPCC__ + CUFFT_CHECK(hipfftMakePlanMany( + plan(), signal_ndim, signal_sizes.data(), + /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, + /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, exec_type, + batch, &ws_size_t)); +#else + + CUFFT_CHECK(platform::dynload::cufftXtMakePlanMany( + plan(), signal_ndim, signal_sizes.data(), + /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype, + /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype, + batch, &ws_size_t, exec_type)); +#endif + + ws_size = ws_size_t; + } + + const cufftHandle& plan() const { return plan_ptr.get(); } + + FFTTransformType transform_type() const { return fft_type_; } + ScalarType data_type() const { return value_type_; } + size_t workspace_size() const { return ws_size; } + + private: + CuFFTHandle plan_ptr; + size_t ws_size; + FFTTransformType fft_type_; + ScalarType value_type_; +}; + +// Execute a pre-planned transform +static void exec_cufft_plan(const CuFFTConfig& config, void* in_data, + void* out_data, bool forward) { + auto& plan = config.plan(); +#ifdef __HIPCC__ + auto value_type = config.data_type(); + if (value_type == framework::proto::VarType::FP32) { + switch (config.transform_type()) { + case FFTTransformType::C2C: { + CUFFT_CHECK(hipfftExecC2C(plan, static_cast(in_data), + static_cast(out_data), + forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD)); + return; + } + case FFTTransformType::R2C: { + CUFFT_CHECK(hipfftExecR2C(plan, static_cast(in_data), + static_cast(out_data))); + return; + } + case FFTTransformType::C2R: { + CUFFT_CHECK(hipfftExecC2R(plan, static_cast(in_data), + static_cast(out_data))); + return; + } + } + } else if (value_type == framework::proto::VarType::FP64) { + switch (config.transform_type()) { + case FFTTransformType::C2C: { + CUFFT_CHECK(hipfftExecZ2Z(plan, + static_cast(in_data), + static_cast(out_data), + forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD)); + return; + } + case FFTTransformType::R2C: { + CUFFT_CHECK(hipfftExecD2Z(plan, static_cast(in_data), + static_cast(out_data))); + return; + } + case FFTTransformType::C2R: { + CUFFT_CHECK(hipfftExecZ2D(plan, + static_cast(in_data), + static_cast(out_data))); + return; + } + } + } + PADDLE_THROW(platform::errors::InvalidArgument( + "hipFFT only support transforms of type float32 and float64")); +#else + CUFFT_CHECK(platform::dynload::cufftXtExec( + plan, in_data, out_data, forward ? 
CUFFT_FORWARD : CUFFT_INVERSE)); +#endif +} + +// Execute a general unnormalized fft operation (can be c2c, onesided r2c or +// onesided c2r) +template +void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out, + const std::vector& dim, bool forward) { + const auto x_dims = framework::vectorize(X->dims()); + const auto out_dims = framework::vectorize(out->dims()); + const int64_t ndim = static_cast(X->dims().size()); + const int64_t signal_ndim = static_cast(dim.size()); + const int64_t batch_dims = ndim - signal_ndim; + auto tensor_place = ctx.GetPlace(); + + // Transpose batch dimensions first, then with transforming dims + std::vector dim_permute(ndim); + std::vector reverse_dim_permute(ndim); + std::vector trans_dims(ndim); + std::iota(dim_permute.begin(), dim_permute.end(), int{0}); + std::vector is_transformed_dim(ndim); + for (const auto& d : dim) { + is_transformed_dim[d] = true; + } + auto batch_end = + std::partition(dim_permute.begin(), dim_permute.end(), + [&](int64_t d) { return !is_transformed_dim[d]; }); + std::sort(dim_permute.begin(), batch_end); + std::copy(dim.cbegin(), dim.cend(), batch_end); + + for (size_t i = 0; i < ndim; i++) { + trans_dims[i] = x_dims[dim_permute[i]]; // shape of input transpose + reverse_dim_permute[dim_permute[i]] = + static_cast(i); // reverse of dim permute + } + framework::Tensor input; + input.Resize(framework::make_ddim(trans_dims)); + input.mutable_data(tensor_place); + /* + auto in_ret = TransposeSimple::run(ctx, *X, dim_permute, input); + if (!in_ret) { + TransCompute(ndim, ctx, *X, input, dim_permute); + } + */ + TransCompute(ndim, ctx, *X, &input, dim_permute); + + // Reshape batch dimensions into a single dimension + std::vector batched_sizes(signal_ndim + 1); + auto batch_size = + std::accumulate(trans_dims.begin(), trans_dims.begin() + batch_dims, + static_cast(1), std::multiplies()); + batched_sizes[0] = batch_size; + std::copy(trans_dims.begin() + batch_dims, trans_dims.end(), + batched_sizes.begin() + 1); + input.Resize(framework::make_ddim(batched_sizes)); + + // Check the shape of transforming dims with input and output + std::vector signal_size(signal_ndim + 1); + signal_size[0] = batch_size; + for (int64_t i = 0; i < signal_ndim; ++i) { + auto in_size = input.dims()[i + 1]; + auto out_size = out_dims[dim[i]]; + signal_size[i + 1] = std::max(in_size, out_size); + PADDLE_ENFORCE_EQ( + (in_size == signal_size[i + 1] || + in_size == (signal_size[i + 1] / 2) + 1), + true, + platform::errors::InvalidArgument( + "The dimension[%d] of Input size: [%d] must be equal or half to " + "The dimension[%d] of Output size: [%d]", + dim[i], in_size, dim[i], out_size)); + PADDLE_ENFORCE_EQ( + (out_size == signal_size[i + 1] || + out_size == (signal_size[i + 1] / 2) + 1), + true, + platform::errors::InvalidArgument( + "The dimension[%d] of Output size: [%d] must be equal or half to " + "The dimension[%d] of Input size: [%d]", + dim[i], out_size, dim[i], in_size)); + } + + std::vector reshape_out_sizes(ndim); + for (size_t i = 0; i < ndim; ++i) { + reshape_out_sizes[i] = out_dims[dim_permute[i]]; + } + std::vector batched_out_sizes(batched_sizes.begin(), + batched_sizes.end()); + for (size_t i = 0; i < dim.size(); ++i) { + batched_out_sizes[i + 1] = out_dims[dim[i]]; + } + + // output + framework::Tensor output; + output.Resize(framework::make_ddim(batched_out_sizes)); + output.mutable_data(tensor_place); + + // Create the transform plan (either from cache or locally) + const auto value_type = framework::IsComplexType(input.type()) + ? 
framework::ToRealType(input.type()) + : input.type(); + auto fft_type = GetFFTTransformType(input.type(), output.type()); + PlanKey Key(framework::vectorize(input.dims()), + framework::vectorize(output.dims()), signal_size, fft_type, + value_type); + CuFFTConfig uncached_plan(Key); + CuFFTConfig* config = &uncached_plan; + auto& plan = config->plan(); + + // prepare cufft for execution + CUFFT_CHECK(platform::dynload::cufftSetStream(plan, ctx.stream())); + framework::Tensor workspace_tensor; + workspace_tensor.mutable_data(tensor_place, config->workspace_size()); + CUFFT_CHECK( + platform::dynload::cufftSetWorkArea(plan, workspace_tensor.data())); + + // execute transform plan + if (fft_type == FFTTransformType::C2R && forward) { + forward = false; + framework::Tensor input_conj(input.type()); + input_conj.mutable_data(input.dims(), ctx.GetPlace()); + platform::ForRange for_range(ctx, input.numel()); + math::ConjFunctor functor(input.data(), input.numel(), + input_conj.data()); + for_range(functor); + exec_cufft_plan(*config, input_conj.data(), output.data(), + forward); + } else if (fft_type == FFTTransformType::R2C && !forward) { + forward = true; + framework::Tensor out_conj(output.type()); + out_conj.mutable_data(output.dims(), ctx.GetPlace()); + exec_cufft_plan(*config, input.data(), out_conj.data(), + forward); + + platform::ForRange for_range(ctx, output.numel()); + math::ConjFunctor functor(out_conj.data(), output.numel(), + output.data()); + for_range(functor); + } else { + exec_cufft_plan(*config, input.data(), output.data(), forward); + } + + // Inverting output by reshape and transpose to original batch and dimension + output.Resize(framework::make_ddim(reshape_out_sizes)); + out->Resize(framework::make_ddim(out_dims)); + TransCompute(ndim, ctx, output, out, reverse_dim_permute); +} + +// Calculates the normalization constant +double fft_normalization_scale(FFTNormMode normalization, + const std::vector& sizes, + const std::vector& dims) { + // auto norm = static_cast(normalization); + if (normalization == FFTNormMode::none) { + return static_cast(1.0); + } + + int64_t signal_numel = 1; + for (auto dim : dims) { + signal_numel *= sizes[dim]; + } + const double scale_denom = (normalization == FFTNormMode::by_sqrt_n) + ? std::sqrt(signal_numel) + : static_cast(signal_numel); + return static_cast(1.0 / scale_denom); +} + +template +void exec_normalization(const DeviceContext& ctx, const Tensor* in, Tensor* out, + FFTNormMode normalization, + const std::vector& sizes, + const std::vector& axes) { + double scale = fft_normalization_scale(normalization, sizes, axes); + if (scale != 1.0) { + auto eigen_out = framework::EigenVector::Flatten(*out); + auto eigen_in = framework::EigenVector::Flatten(*in); + auto dev = ctx.eigen_device(); + EigenScale::Eval(*dev, eigen_out, eigen_in, + static_cast(scale), + static_cast(0), false); + } else { + framework::TensorCopy(*in, ctx.GetPlace(), out); + } +} +} // anonymous namespace + +// Use the optimized path to perform single R2C or C2R if transformation dim is +// supported by cuFFT +bool use_optimized_cufft_path(const std::vector& axes) { + // For performance reason, when axes starts with (0, 1), do not use the + // optimized path. 
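+ // A single cuFFT plan covers at most kMaxCUFFTNdim (3) transform dims, so + // larger axis sets also fall back to the generic multi-pass path.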
+ if (axes.size() > kMaxCUFFTNdim || + (axes.size() >= 2 && axes[0] == 0 && axes[1] == 1)) { + return false; + } else { + return true; + } +} + +template +struct FFTC2CFunctor { + void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X, + Tensor* out, const std::vector& axes, + FFTNormMode normalization, bool forward) { + if (axes.empty()) { + framework::TensorCopy(*X, ctx.GetPlace(), out); + return; + } + + framework::Tensor* p_out = out; + std::vector out_dims = framework::vectorize(X->dims()); + std::vector working_axes(axes.begin(), axes.end()); + std::vector first_dims; + size_t max_dims; + framework::Tensor working_tensor; + working_tensor.mutable_data(X->dims(), ctx.GetPlace()); + framework::Tensor* p_working_tensor = &working_tensor; + framework::TensorCopy(*X, ctx.GetPlace(), &working_tensor); + + while (true) { + max_dims = + std::min(static_cast(kMaxCUFFTNdim), working_axes.size()); + first_dims.assign(working_axes.end() - max_dims, working_axes.end()); + + exec_fft(ctx, p_working_tensor, + p_out, first_dims, forward); + working_axes.resize(working_axes.size() - max_dims); + first_dims.clear(); + + if (working_axes.empty()) { + break; + } + + std::swap(p_out, p_working_tensor); + } + exec_normalization( + ctx, p_out, out, normalization, out_dims, axes); + } +}; + +template +struct FFTC2RFunctor { + void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X, + Tensor* out, const std::vector& axes, + FFTNormMode normalization, bool forward) { + std::vector in_dims = framework::vectorize(X->dims()); + std::vector out_dims = framework::vectorize(out->dims()); + + if (use_optimized_cufft_path(axes)) { + framework::Tensor x_copy(X->type()); + x_copy.mutable_data(X->dims(), ctx.GetPlace()); + framework::TensorCopy(*X, ctx.GetPlace(), &x_copy); + exec_fft(ctx, &x_copy, out, axes, + forward); + } else { + framework::Tensor temp_tensor; + temp_tensor.mutable_data(X->dims(), ctx.GetPlace()); + const std::vector dims(axes.begin(), axes.end() - 1); + + FFTC2CFunctor c2c_functor; + c2c_functor(ctx, X, &temp_tensor, dims, FFTNormMode::none, forward); + + exec_fft(ctx, &temp_tensor, out, + {axes.back()}, forward); + } + exec_normalization( + ctx, out, out, normalization, out_dims, axes); + } +}; + +// n dimension real to complex FFT use cufft lib +template +struct FFTR2CFunctor { + void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X, + Tensor* out, const std::vector& axes, + FFTNormMode normalization, bool forward) { + // Step1: R2C transform on the last dimension + framework::Tensor* r2c_out = out; + const std::vector last_dim{axes.back()}; + std::vector out_dims = framework::vectorize(out->dims()); + exec_fft(ctx, X, r2c_out, last_dim, + forward); + + // Step2: C2C transform on the remaining dimension + framework::Tensor c2c_out; + if (axes.size() > 1) { + c2c_out.mutable_data(out->dims(), ctx.GetPlace()); + std::vector remain_dim(axes.begin(), axes.end() - 1); + FFTC2CFunctor fft_c2c_func; + fft_c2c_func(ctx, r2c_out, &c2c_out, remain_dim, FFTNormMode::none, + forward); + } + + const auto in_sizes = framework::vectorize(X->dims()); + framework::Tensor* norm_tensor = axes.size() > 1 ? 
&c2c_out : r2c_out; + exec_normalization( + ctx, norm_tensor, out, normalization, in_sizes, axes); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + fft_c2c, ops::FFTC2CKernel, + ops::FFTC2CKernel); + +REGISTER_OP_CUDA_KERNEL( + fft_c2c_grad, + ops::FFTC2CGradKernel, + ops::FFTC2CGradKernel); + +REGISTER_OP_CUDA_KERNEL( + fft_c2r, ops::FFTC2RKernel, + ops::FFTC2RKernel); + +REGISTER_OP_CUDA_KERNEL( + fft_c2r_grad, + ops::FFTC2RGradKernel, + ops::FFTC2RGradKernel); + +REGISTER_OP_CUDA_KERNEL( + fft_r2c, ops::FFTR2CKernel, + ops::FFTR2CKernel); + +REGISTER_OP_CUDA_KERNEL( + fft_r2c_grad, + ops::FFTR2CGradKernel, + ops::FFTR2CGradKernel); diff --git a/paddle/fluid/operators/spectral_op.h b/paddle/fluid/operators/spectral_op.h new file mode 100644 index 00000000000..e549c4a454b --- /dev/null +++ b/paddle/fluid/operators/spectral_op.h @@ -0,0 +1,461 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#define NOMINMAX // to use std::min std::max correctly on windows +#include +#include +#include +#include +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/data_type_transform.h" +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/conj_op.h" +#include "paddle/fluid/operators/eigen/eigen_function.h" +#include "paddle/fluid/operators/math/padding.h" +#include "paddle/fluid/platform/complex.h" +#include "paddle/fluid/platform/for_range.h" +#if defined(__NVCC__) || defined(__HIPCC__) +#include "thrust/device_vector.h" +#endif + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +enum class FFTNormMode : int64_t { + none, // No normalization + by_sqrt_n, // Divide by sqrt(signal_size) + by_n, // Divide by signal_size +}; + +FFTNormMode get_norm_from_string(const std::string& norm, bool forward); + +// Enum representing the FFT type +enum class FFTTransformType : int64_t { + C2C = 0, // Complex-to-complex + R2C, // Real-to-complex + C2R, // Complex-to-real +}; + +// Create transform type enum from bools representing if input and output are +// complex +inline FFTTransformType GetFFTTransformType( + framework::proto::VarType::Type input_dtype, + framework::proto::VarType::Type output_dtype) { + auto complex_input = framework::IsComplexType(input_dtype); + auto complex_output = framework::IsComplexType(output_dtype); + if (complex_input && complex_output) { + return FFTTransformType::C2C; + } else if (complex_input && !complex_output) { + return FFTTransformType::C2R; + } else if (!complex_input && complex_output) { + return FFTTransformType::R2C; + } + PADDLE_THROW( + platform::errors::InvalidArgument("Real to real FFTs are not supported")); +} + +// Returns true if the transform type has complex input +inline bool has_complex_input(FFTTransformType type) { + switch (type) { + case 
FFTTransformType::C2C: + case FFTTransformType::C2R: + return true; + + case FFTTransformType::R2C: + return false; + } + PADDLE_THROW(platform::errors::InvalidArgument("Unknown FFTTransformType")); +} + +// Returns true if the transform type has complex output +inline bool has_complex_output(FFTTransformType type) { + switch (type) { + case FFTTransformType::C2C: + case FFTTransformType::R2C: + return true; + + case FFTTransformType::C2R: + return false; + } + PADDLE_THROW(platform::errors::InvalidArgument("Unknown FFTTransformType")); +} + +template +struct FFTFillConjGradFunctor { + T* input_; + const size_t axis_; + const int64_t* strides_; + const size_t double_length_; + + FFTFillConjGradFunctor(T* input, size_t axis, const int64_t* strides, + size_t double_length) + : input_(input), + axis_(axis), + strides_(strides), + double_length_(double_length) {} + + HOSTDEVICE void operator()(size_t index) { + size_t offset = index; + size_t index_i; + for (size_t i = 0; i <= axis_; i++) { + index_i = offset / strides_[i]; + offset %= strides_[i]; + } + + if ((0 < index_i) && (index_i < double_length_ + 1)) { + input_[index] *= static_cast(2); + } + } +}; + +template +struct FFTC2CFunctor { + void operator()(const DeviceContext& ctx, const Tensor* X, Tensor* out, + const std::vector& axes, FFTNormMode normalization, + bool forward); +}; + +template +struct FFTR2CFunctor { + void operator()(const DeviceContext& ctx, const Tensor* X, Tensor* out, + const std::vector& axes, FFTNormMode normalization, + bool forward); +}; + +template +struct FFTC2RFunctor { + void operator()(const DeviceContext& ctx, const Tensor* X, Tensor* out, + const std::vector& axes, FFTNormMode normalization, + bool forward); +}; + +// Given a linear destination index and the tensor strides, get_src_idx returns the +// corresponding linear position in the source tensor. +// The linear index is the position in the flattened tensor. +HOSTDEVICE inline int64_t get_src_idx(const int64_t dst_idx, + const int64_t* dst_strides, + const int64_t* dst_shape, + const int64_t* src_strides, + const bool* is_fft_axis, const bool conj, + const int64_t rank) { + int64_t src_idx = 0; + int64_t quotient = dst_idx; + int64_t remainder = 0; + + for (int64_t i = 0; i < rank; i++) { + remainder = quotient % dst_strides[i]; + quotient = quotient / dst_strides[i]; + if (conj && is_fft_axis[i]) { + src_idx += ((dst_shape[i] - quotient) % dst_shape[i]) * src_strides[i]; + } else { + src_idx += src_strides[i] * quotient; + } + quotient = remainder; + } + + return src_idx; +} + +HOSTDEVICE inline bool is_conj_part(const int64_t dst_idx, + const int64_t* dst_strides, + const int64_t last_axis, + const int64_t last_axis_size) { + int64_t quotient = dst_idx; + int64_t remainder = 0; + + for (int64_t i = 0; i < last_axis + 1; i++) { + remainder = quotient % dst_strides[i]; + quotient = quotient / dst_strides[i]; + + if ((i == last_axis) && (quotient > last_axis_size - 1)) { + return true; + } + + quotient = remainder; + } + + return false; +} + +// FFTFillConjFunctor fills the destination tensor with the source tensor and +// the conjugate-symmetric elements of the source tensor.
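+// For a real input signal the spectrum is conjugate (Hermitian) symmetric, i.e. X[k] = conj(X[N - k]) along the transformed axes, so the missing half of a onesided R2C result can be reconstructed from the stored half.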
+// Use framework::ForRange to iterate destination element with +// supporting different device +template +struct FFTFillConjFunctor { + FFTFillConjFunctor(const C* src_data, C* dst_data, const int64_t* src_strides, + const int64_t* dst_strides, const int64_t* dst_shape, + const bool* is_fft_axis, const int64_t last_axis, + const int64_t last_axis_size, const int64_t rank) + : src_data_(src_data), + dst_data_(dst_data), + src_strides_(src_strides), + dst_strides_(dst_strides), + dst_shape_(dst_shape), + is_fft_axis_(is_fft_axis), + last_axis_(last_axis), + last_axis_size_(last_axis_size), + rank_(rank) {} + HOSTDEVICE void operator()(int64_t dst_idx) { + if (is_conj_part(dst_idx, dst_strides_, last_axis_, last_axis_size_)) { + const auto conj_idx = + get_src_idx(dst_idx, dst_strides_, dst_shape_, src_strides_, + is_fft_axis_, true, rank_); + auto src_value = src_data_[conj_idx]; + auto conj_value = C(src_value.real, -src_value.imag); + dst_data_[dst_idx] = conj_value; + } else { + const auto copy_idx = + get_src_idx(dst_idx, dst_strides_, dst_shape_, src_strides_, + is_fft_axis_, false, rank_); + dst_data_[dst_idx] = src_data_[copy_idx]; + } + } + + const C* src_data_; + C* dst_data_; + const int64_t* src_strides_; + const int64_t* dst_strides_; + const int64_t* dst_shape_; + const bool* is_fft_axis_; + const int64_t last_axis_; + const int64_t last_axis_size_; + const int64_t rank_; +}; + +template +void fill_conj(const DeviceContext& ctx, const Tensor* src, Tensor* dst, + const std::vector& axes) { + std::vector src_strides_v = + framework::vectorize(framework::stride(src->dims())); + std::vector dst_strides_v = + framework::vectorize(framework::stride(dst->dims())); + std::vector dst_shape_v = framework::vectorize(dst->dims()); + const auto src_data = src->data(); + auto dst_data = dst->data(); + const auto last_axis = axes.back(); + const auto last_axis_size = dst->dims().at(last_axis) / 2 + 1; + const int64_t rank = dst->dims().size(); + auto _is_fft_axis = std::make_unique(rank); + for (const auto i : axes) { + _is_fft_axis[i] = true; + } + +#if defined(__NVCC__) || defined(__HIPCC__) + const thrust::device_vector src_strides_g(src_strides_v); + const auto src_strides = thrust::raw_pointer_cast(src_strides_g.data()); + const thrust::device_vector dst_strides_g(dst_strides_v); + const auto dst_strides = thrust::raw_pointer_cast(dst_strides_g.data()); + const thrust::device_vector dst_shape_g(dst_shape_v); + const auto dst_shape = thrust::raw_pointer_cast(dst_shape_g.data()); + const thrust::device_vector is_fft_axis_g(_is_fft_axis.get(), + _is_fft_axis.get() + rank); + const auto p_is_fft_axis = thrust::raw_pointer_cast(is_fft_axis_g.data()); +#else + const auto src_strides = src_strides_v.data(); + const auto dst_strides = dst_strides_v.data(); + const auto dst_shape = dst_shape_v.data(); + const auto p_is_fft_axis = _is_fft_axis.get(); +#endif + platform::ForRange for_range(ctx, dst->numel()); + FFTFillConjFunctor fill_conj_functor(src_data, dst_data, src_strides, + dst_strides, dst_shape, p_is_fft_axis, + last_axis, last_axis_size, rank); + for_range(fill_conj_functor); +} + +template +class FFTC2CKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + using C = paddle::platform::complex; + auto& dev_ctx = ctx.device_context(); + + auto axes = ctx.Attr>("axes"); + const std::string& norm_str = ctx.Attr("normalization"); + const bool forward = ctx.Attr("forward"); + const auto* x = ctx.Input("X"); + auto* y = 
ctx.Output("Out"); + + y->mutable_data(ctx.GetPlace()); + auto normalization = get_norm_from_string(norm_str, forward); + + FFTC2CFunctor fft_c2c_func; + fft_c2c_func(dev_ctx, x, y, axes, normalization, forward); + } +}; + +template +class FFTC2CGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + using C = paddle::platform::complex; + auto& dev_ctx = ctx.device_context(); + + auto axes = ctx.Attr>("axes"); + const std::string& norm_str = ctx.Attr("normalization"); + const bool forward = ctx.Attr("forward"); + const auto* dy = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + + dx->mutable_data(ctx.GetPlace()); + auto normalization = get_norm_from_string(norm_str, forward); + + FFTC2CFunctor fft_c2c_func; + fft_c2c_func(dev_ctx, dy, dx, axes, normalization, !forward); + } +}; + +template +class FFTR2CKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + using C = paddle::platform::complex; + auto& dev_ctx = ctx.device_context(); + + auto axes = ctx.Attr>("axes"); + const std::string& norm_str = ctx.Attr("normalization"); + const bool forward = ctx.Attr("forward"); + const bool onesided = ctx.Attr("onesided"); + const auto* x = ctx.Input("X"); + auto* y = ctx.Output("Out"); + + y->mutable_data(ctx.GetPlace()); + auto normalization = get_norm_from_string(norm_str, forward); + + FFTR2CFunctor fft_r2c_func; + + if (onesided) { + fft_r2c_func(dev_ctx, x, y, axes, normalization, forward); + } else { + framework::DDim onesided_dims(y->dims()); + const int64_t onesided_last_axis_size = y->dims().at(axes.back()) / 2 + 1; + onesided_dims.at(axes.back()) = onesided_last_axis_size; + framework::Tensor onesided_out; + onesided_out.mutable_data(onesided_dims, ctx.GetPlace()); + fft_r2c_func(dev_ctx, x, &onesided_out, axes, normalization, forward); + fill_conj(dev_ctx, &onesided_out, y, axes); + } + } +}; + +template +class FFTR2CGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + using C = paddle::platform::complex; + auto& dev_ctx = ctx.device_context(); + + const auto axes = ctx.Attr>("axes"); + const std::string& norm_str = ctx.Attr("normalization"); + const bool forward = ctx.Attr("forward"); + const bool onesided = ctx.Attr("onesided"); + + const auto* dy = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + dx->mutable_data(ctx.GetPlace()); + framework::Tensor complex_dx; + complex_dx.mutable_data(dx->dims(), ctx.GetPlace()); + + auto normalization = get_norm_from_string(norm_str, forward); + FFTC2CFunctor fft_c2c_func; + + if (!onesided) { + fft_c2c_func(dev_ctx, dy, &complex_dx, axes, normalization, !forward); + } else { + framework::Tensor full_dy; + full_dy.mutable_data(dx->dims(), ctx.GetPlace()); + auto zero_length = static_cast(full_dy.dims().at(axes.back()) - + dy->dims().at(axes.back())); + auto rank = dy->dims().size(); + + std::vector pads(rank * 2, 0); + pads[axes.back() * 2 + 1] = zero_length; + + paddle::operators::math::PaddingFunctor( + rank, ctx, pads, static_cast(0), *dy, &full_dy); + fft_c2c_func(dev_ctx, &full_dy, &complex_dx, axes, normalization, + !forward); + } + framework::TransComplexToReal(dx->type(), complex_dx.type(), complex_dx, + dx); + } +}; + +template +class FFTC2RKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& 
ctx) const override { + using C = paddle::platform::complex; + auto& dev_ctx = ctx.device_context(); + + auto axes = ctx.Attr>("axes"); + const std::string& norm_str = ctx.Attr("normalization"); + const bool forward = ctx.Attr("forward"); + const auto* x = ctx.Input("X"); + auto* y = ctx.Output("Out"); + + y->mutable_data(ctx.GetPlace()); + auto normalization = get_norm_from_string(norm_str, forward); + + FFTC2RFunctor fft_c2r_func; + fft_c2r_func(dev_ctx, x, y, axes, normalization, forward); + } +}; + +template +class FFTC2RGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + using C = paddle::platform::complex; + auto& dev_ctx = ctx.device_context(); + + auto axes = ctx.Attr>("axes"); + const std::string& norm_str = ctx.Attr("normalization"); + const bool forward = ctx.Attr("forward"); + const auto* dy = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + + C* pdx = dx->mutable_data(ctx.GetPlace()); + auto normalization = get_norm_from_string(norm_str, forward); + + FFTR2CFunctor fft_r2c_func; + fft_r2c_func(dev_ctx, dy, dx, axes, normalization, !forward); + + const int64_t double_length = + dy->dims()[axes.back()] - dx->dims()[axes.back()]; + const framework::DDim strides = framework::stride(dx->dims()); + +#if defined(__NVCC__) || defined(__HIPCC__) + const thrust::device_vector strides_g( + framework::vectorize(strides)); + const int64_t* pstrides = thrust::raw_pointer_cast(strides_g.data()); +#else + const int64_t* pstrides = strides.Get(); +#endif + + FFTFillConjGradFunctor func(pdx, axes.back(), pstrides, double_length); + size_t limit = dx->numel(); + platform::ForRange for_range(dev_ctx, limit); + for_range(func); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/squeeze_op.cc b/paddle/fluid/operators/squeeze_op.cc index 43b9b39bda8..8894ca650de 100644 --- a/paddle/fluid/operators/squeeze_op.cc +++ b/paddle/fluid/operators/squeeze_op.cc @@ -389,7 +389,11 @@ REGISTER_OP_CPU_KERNEL( ops::SqueezeKernel, ops::SqueezeKernel, ops::SqueezeKernel, - ops::SqueezeKernel); + ops::SqueezeKernel, + ops::SqueezeKernel>, + ops::SqueezeKernel>); REGISTER_OP_CPU_KERNEL( squeeze_grad, ops::SqueezeGradKernel, @@ -398,7 +402,12 @@ REGISTER_OP_CPU_KERNEL( ops::SqueezeGradKernel, ops::SqueezeGradKernel, ops::SqueezeGradKernel, - ops::SqueezeGradKernel); + ops::SqueezeGradKernel, + ops::SqueezeGradKernel>, + ops::SqueezeGradKernel>); + REGISTER_OP_CPU_KERNEL( squeeze2, ops::Squeeze2Kernel, ops::Squeeze2Kernel, @@ -406,7 +415,12 @@ REGISTER_OP_CPU_KERNEL( ops::Squeeze2Kernel, ops::Squeeze2Kernel, ops::Squeeze2Kernel, - ops::Squeeze2Kernel); + ops::Squeeze2Kernel, + ops::Squeeze2Kernel>, + ops::Squeeze2Kernel>); + REGISTER_OP_CPU_KERNEL( squeeze2_grad, ops::Squeeze2GradKernel, @@ -415,4 +429,8 @@ REGISTER_OP_CPU_KERNEL( ops::Squeeze2GradKernel, ops::Squeeze2GradKernel, ops::Squeeze2GradKernel, - ops::Squeeze2GradKernel); + ops::Squeeze2GradKernel, + ops::Squeeze2GradKernel>, + ops::Squeeze2GradKernel>); diff --git a/paddle/fluid/operators/squeeze_op.cu.cc b/paddle/fluid/operators/squeeze_op.cu.cc old mode 100755 new mode 100644 index 23431df12b6..9b4000c26ff --- a/paddle/fluid/operators/squeeze_op.cu.cc +++ b/paddle/fluid/operators/squeeze_op.cu.cc @@ -25,7 +25,11 @@ REGISTER_OP_CUDA_KERNEL( ops::SqueezeKernel, ops::SqueezeKernel, ops::SqueezeKernel, - ops::SqueezeKernel); + ops::SqueezeKernel, + ops::SqueezeKernel>, + ops::SqueezeKernel>); 
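These registrations add complex64/complex128 variants of the squeeze CUDA kernels, so squeeze can be applied directly to complex tensors such as FFT outputs. A tiny usage sketch, illustrative only and assuming the complex kernels behave like the existing real-valued ones:

    import numpy as np
    import paddle

    x = paddle.to_tensor(np.zeros([1, 4], dtype=np.complex64))
    y = paddle.squeeze(x, axis=0)  # served by the complex64 kernel registered above
    print(y.shape)  # [4]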
REGISTER_OP_CUDA_KERNEL( squeeze_grad, ops::SqueezeGradKernel, @@ -35,7 +39,11 @@ REGISTER_OP_CUDA_KERNEL( ops::SqueezeGradKernel, ops::SqueezeGradKernel, ops::SqueezeGradKernel, - ops::SqueezeGradKernel); + ops::SqueezeGradKernel, + ops::SqueezeGradKernel>, + ops::SqueezeGradKernel>); REGISTER_OP_CUDA_KERNEL( squeeze2, ops::Squeeze2Kernel, ops::Squeeze2Kernel, @@ -44,7 +52,11 @@ REGISTER_OP_CUDA_KERNEL( ops::Squeeze2Kernel, ops::Squeeze2Kernel, ops::Squeeze2Kernel, - ops::Squeeze2Kernel); + ops::Squeeze2Kernel, + ops::Squeeze2Kernel>, + ops::Squeeze2Kernel>); REGISTER_OP_CUDA_KERNEL( squeeze2_grad, ops::Squeeze2GradKernel, @@ -54,4 +66,8 @@ REGISTER_OP_CUDA_KERNEL( ops::Squeeze2GradKernel, ops::Squeeze2GradKernel, ops::Squeeze2GradKernel, - ops::Squeeze2GradKernel); + ops::Squeeze2GradKernel, + ops::Squeeze2GradKernel>, + ops::Squeeze2GradKernel>); diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc old mode 100755 new mode 100644 index ed7a4f92f09..77b06fb2d4b --- a/paddle/fluid/operators/unsqueeze_op.cc +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -362,7 +362,11 @@ REGISTER_OP_CPU_KERNEL( ops::UnsqueezeKernel, ops::UnsqueezeKernel, ops::UnsqueezeKernel, - ops::UnsqueezeKernel); + ops::UnsqueezeKernel, + ops::UnsqueezeKernel>, + ops::UnsqueezeKernel>); REGISTER_OP_CPU_KERNEL( unsqueeze_grad, ops::UnsqueezeGradKernel, @@ -371,7 +375,11 @@ REGISTER_OP_CPU_KERNEL( ops::UnsqueezeGradKernel, ops::UnsqueezeGradKernel, ops::UnsqueezeGradKernel, - ops::UnsqueezeGradKernel); + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel>, + ops::UnsqueezeGradKernel>); REGISTER_OP_CPU_KERNEL( unsqueeze2, ops::UnsqueezeKernel, ops::UnsqueezeKernel, @@ -379,7 +387,11 @@ REGISTER_OP_CPU_KERNEL( ops::UnsqueezeKernel, ops::UnsqueezeKernel, ops::UnsqueezeKernel, - ops::UnsqueezeKernel); + ops::UnsqueezeKernel, + ops::UnsqueezeKernel>, + ops::UnsqueezeKernel>); REGISTER_OP_CPU_KERNEL( unsqueeze2_grad, ops::Unsqueeze2GradKernel, @@ -388,4 +400,8 @@ REGISTER_OP_CPU_KERNEL( ops::Unsqueeze2GradKernel, ops::Unsqueeze2GradKernel, ops::Unsqueeze2GradKernel, - ops::Unsqueeze2GradKernel); + ops::Unsqueeze2GradKernel, + ops::Unsqueeze2GradKernel>, + ops::Unsqueeze2GradKernel>); diff --git a/paddle/fluid/operators/unsqueeze_op.cu.cc b/paddle/fluid/operators/unsqueeze_op.cu.cc old mode 100755 new mode 100644 index 2781b3ef8c8..d1fe251ef77 --- a/paddle/fluid/operators/unsqueeze_op.cu.cc +++ b/paddle/fluid/operators/unsqueeze_op.cu.cc @@ -25,7 +25,11 @@ REGISTER_OP_CUDA_KERNEL( ops::UnsqueezeKernel, ops::UnsqueezeKernel, ops::UnsqueezeKernel, - ops::UnsqueezeKernel); + ops::UnsqueezeKernel, + ops::UnsqueezeKernel>, + ops::UnsqueezeKernel>); REGISTER_OP_CUDA_KERNEL( unsqueeze_grad, ops::UnsqueezeGradKernel, @@ -36,7 +40,11 @@ REGISTER_OP_CUDA_KERNEL( ops::UnsqueezeGradKernel, ops::UnsqueezeGradKernel, ops::UnsqueezeGradKernel, - ops::UnsqueezeGradKernel); + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel>, + ops::UnsqueezeGradKernel>); REGISTER_OP_CUDA_KERNEL( unsqueeze2, ops::UnsqueezeKernel, @@ -46,7 +54,11 @@ REGISTER_OP_CUDA_KERNEL( ops::UnsqueezeKernel, ops::UnsqueezeKernel, ops::UnsqueezeKernel, - ops::UnsqueezeKernel); + ops::UnsqueezeKernel, + ops::UnsqueezeKernel>, + ops::UnsqueezeKernel>); REGISTER_OP_CUDA_KERNEL( unsqueeze2_grad, ops::Unsqueeze2GradKernel, @@ -57,4 +69,8 @@ REGISTER_OP_CUDA_KERNEL( ops::Unsqueeze2GradKernel, ops::Unsqueeze2GradKernel, ops::Unsqueeze2GradKernel, - ops::Unsqueeze2GradKernel); + ops::Unsqueeze2GradKernel, + ops::Unsqueeze2GradKernel>, + 
ops::Unsqueeze2GradKernel>); diff --git a/paddle/fluid/platform/complex.h b/paddle/fluid/platform/complex.h index 2c1b42ea488..065ccd375c9 100644 --- a/paddle/fluid/platform/complex.h +++ b/paddle/fluid/platform/complex.h @@ -60,6 +60,8 @@ struct PADDLE_ALIGN(sizeof(T) * 2) complex { T real; T imag; + using value_type = T; + complex() = default; complex(const complex& o) = default; complex& operator=(const complex& o) = default; diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt index ac98ff02035..eed3568f1d8 100644 --- a/paddle/fluid/platform/dynload/CMakeLists.txt +++ b/paddle/fluid/platform/dynload/CMakeLists.txt @@ -1,6 +1,6 @@ cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce) -list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc cusolver.cc cusparse.cc nvtx.cc) +list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc cusolver.cc cusparse.cc nvtx.cc cufft.cc) if (NOT WITH_NV_JETSON) list(APPEND CUDA_SRCS nvjpeg.cc) diff --git a/paddle/fluid/platform/dynload/cufft.cc b/paddle/fluid/platform/dynload/cufft.cc new file mode 100644 index 00000000000..a125fb72260 --- /dev/null +++ b/paddle/fluid/platform/dynload/cufft.cc @@ -0,0 +1,44 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/dynload/cufft.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { +namespace dynload { +std::once_flag cufft_dso_flag; +void* cufft_dso_handle = nullptr; + +#define DEFINE_WRAP(__name) DynLoad__##__name __name + +CUFFT_FFT_ROUTINE_EACH(DEFINE_WRAP); + +bool HasCUFFT() { + std::call_once(cufft_dso_flag, + []() { cufft_dso_handle = GetCUFFTDsoHandle(); }); + return cufft_dso_handle != nullptr; +} + +void EnforceCUFFTLoaded(const char* fn_name) { + PADDLE_ENFORCE_NOT_NULL( + cufft_dso_handle, + platform::errors::PreconditionNotMet( + "Cannot load cufft shared library. Cannot invoke method %s.", + fn_name)); +} + +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/cufft.h b/paddle/fluid/platform/dynload/cufft.h new file mode 100644 index 00000000000..ef924d7b5ee --- /dev/null +++ b/paddle/fluid/platform/dynload/cufft.h @@ -0,0 +1,113 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#ifdef PADDLE_WITH_CUDA +#include +#include +#include +#include // NOLINT + +#include "paddle/fluid/platform/dynload/dynamic_loader.h" +#include "paddle/fluid/platform/port.h" + +namespace paddle { +namespace platform { +namespace dynload { + +extern std::once_flag cufft_dso_flag; +extern void* cufft_dso_handle; +extern bool HasCUFFT(); + +extern void EnforceCUFFTLoaded(const char* fn_name); +#define DECLARE_DYNAMIC_LOAD_CUFFT_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \ + using cufft_func = decltype(&::__name); \ + std::call_once(cufft_dso_flag, []() { \ + cufft_dso_handle = paddle::platform::dynload::GetCUFFTDsoHandle(); \ + }); \ + EnforceCUFFTLoaded(#__name); \ + static void* p_##__name = dlsym(cufft_dso_handle, #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + }; \ + extern struct DynLoad__##__name __name + +/** + * include all needed cufft functions in HPPL + * different cufft version has different interfaces + **/ +#define CUFFT_FFT_ROUTINE_EACH(__macro) \ + __macro(cufftPlan1d); \ + __macro(cufftPlan2d); \ + __macro(cufftPlan3d); \ + __macro(cufftPlanMany); \ + __macro(cufftMakePlan1d); \ + __macro(cufftMakePlan2d); \ + __macro(cufftMakePlan3d); \ + __macro(cufftMakePlanMany); \ + __macro(cufftMakePlanMany64); \ + __macro(cufftGetSizeMany64); \ + __macro(cufftEstimate1d); \ + __macro(cufftEstimate2d); \ + __macro(cufftEstimate3d); \ + __macro(cufftEstimateMany); \ + __macro(cufftCreate); \ + __macro(cufftGetSize1d); \ + __macro(cufftGetSize2d); \ + __macro(cufftGetSize3d); \ + __macro(cufftGetSizeMany); \ + __macro(cufftGetSize); \ + __macro(cufftSetWorkArea); \ + __macro(cufftSetAutoAllocation); \ + __macro(cufftExecC2C); \ + __macro(cufftExecR2C); \ + __macro(cufftExecC2R); \ + __macro(cufftExecZ2Z); \ + __macro(cufftExecD2Z); \ + __macro(cufftExecZ2D); \ + __macro(cufftSetStream); \ + __macro(cufftDestroy); \ + __macro(cufftGetVersion); \ + __macro(cufftGetProperty); \ + __macro(cufftXtSetGPUs); \ + __macro(cufftXtMalloc); \ + __macro(cufftXtMemcpy); \ + __macro(cufftXtFree); \ + __macro(cufftXtSetWorkArea); \ + __macro(cufftXtExecDescriptorC2C); \ + __macro(cufftXtExecDescriptorR2C); \ + __macro(cufftXtExecDescriptorC2R); \ + __macro(cufftXtExecDescriptorZ2Z); \ + __macro(cufftXtExecDescriptorD2Z); \ + __macro(cufftXtExecDescriptorZ2D); \ + __macro(cufftXtQueryPlan); \ + __macro(cufftXtSetCallback); \ + __macro(cufftXtClearCallback); \ + __macro(cufftXtSetCallbackSharedSize); \ + __macro(cufftXtMakePlanMany); \ + __macro(cufftXtGetSizeMany); \ + __macro(cufftXtExec); \ + __macro(cufftXtExecDescriptor); \ + __macro(cufftXtSetWorkAreaPolicy); + +CUFFT_FFT_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUFFT_WRAP) + +} // namespace dynload +} // namespace platform +} // namespace paddle + +#endif diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc index 37932600e7a..bf2dc7aaba1 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.cc +++ b/paddle/fluid/platform/dynload/dynamic_loader.cc @@ -109,6 +109,9 @@ static constexpr char* win_cusolver_lib = static constexpr char* win_cusparse_lib = "cusparse64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR ".dll;cusparse64_" CUDA_VERSION_MAJOR ".dll;cusparse64_10.dll"; +static constexpr char* win_cufft_lib = + "cufft64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR + ".dll;cufft64_" CUDA_VERSION_MAJOR ".dll;cufft64_10.dll"; #else static constexpr char* win_curand_lib = "curand64_" 
CUDA_VERSION_MAJOR CUDA_VERSION_MINOR @@ -122,6 +125,9 @@ static constexpr char* win_cusolver_lib = static constexpr char* win_cusparse_lib = "cusparse64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR ".dll;cusparse64_" CUDA_VERSION_MAJOR ".dll"; +static constexpr char* win_cufft_lib = + "cufft64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR + ".dll;cufft64_" CUDA_VERSION_MAJOR ".dll"; #endif // CUDA_VERSION #endif @@ -489,6 +495,17 @@ void* GetNvtxDsoHandle() { #endif } +void* GetCUFFTDsoHandle() { +#if defined(__APPLE__) || defined(__OSX__) + return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcufft.dylib"); +#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) + return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, win_cufft_lib, true, + {cuda_lib_path}); +#else + return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcufft.so"); +#endif +} + } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/dynamic_loader.h b/paddle/fluid/platform/dynload/dynamic_loader.h index e282c033c44..08f0aec8b01 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.h +++ b/paddle/fluid/platform/dynload/dynamic_loader.h @@ -41,6 +41,7 @@ void* GetTensorRtDsoHandle(); void* GetMKLMLDsoHandle(); void* GetOpDsoHandle(const std::string& dso_name); void* GetNvtxDsoHandle(); +void* GetCUFFTDsoHandle(); void SetPaddleLibPath(const std::string&); } // namespace dynload diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 94b049d5d30..6bd58ee558f 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -64,6 +64,7 @@ import paddle.reader # noqa: F401 import paddle.static # noqa: F401 import paddle.vision # noqa: F401 +from .tensor import fft from .tensor.random import bernoulli # noqa: F401 from .tensor.attribute import rank # noqa: F401 diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 06a14295a81..6a1320e65ab 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -6727,8 +6727,10 @@ def pad(x, paddings, pad_value=0., name=None): x = fluid.data(name='data', shape=[300, 300], dtype='float32') out = fluid.layers.pad(x=x, paddings=[0, 1, 1, 2], pad_value=0.) 
""" - check_variable_and_dtype( - x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], "pad") + check_variable_and_dtype(x, 'x', [ + 'float16', 'float32', 'float64', 'int32', 'int64', 'complex64', + 'complex128' + ], "pad") helper = LayerHelper('pad', **locals()) dtype = helper.input_dtype(input_param_name='x') diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 00067095209..951dae1e61b 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -702,6 +702,7 @@ endif() add_subdirectory(sequence) add_subdirectory(dygraph_to_static) add_subdirectory(rnn) +add_subdirectory(fft) if (WITH_XPU) add_subdirectory(xpu) diff --git a/python/paddle/fluid/tests/unittests/fft/CMakeLists.txt b/python/paddle/fluid/tests/unittests/fft/CMakeLists.txt new file mode 100644 index 00000000000..f71e04c09aa --- /dev/null +++ b/python/paddle/fluid/tests/unittests/fft/CMakeLists.txt @@ -0,0 +1,6 @@ +file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + +foreach(TEST_OP ${TEST_OPS}) + py_test_modules(${TEST_OP} MODULES ${TEST_OP}) +endforeach(TEST_OP) diff --git a/python/paddle/fluid/tests/unittests/fft/__init__.py b/python/paddle/fluid/tests/unittests/fft/__init__.py new file mode 100644 index 00000000000..b9a7651e449 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/fft/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/paddle/fluid/tests/unittests/fft/spectral_op_np.py b/python/paddle/fluid/tests/unittests/fft/spectral_op_np.py new file mode 100644 index 00000000000..b00111f6821 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/fft/spectral_op_np.py @@ -0,0 +1,108 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +from functools import partial +from numpy import asarray +from numpy.fft._pocketfft import _raw_fft, _raw_fftnd, _get_forward_norm, _get_backward_norm, _cook_nd_args + + +def _fftc2c(a, n=None, axis=-1, norm=None, forward=None): + a = asarray(a) + if n is None: + n = a.shape[axis] + if forward: + inv_norm = _get_forward_norm(n, norm) + else: + inv_norm = _get_backward_norm(n, norm) + output = _raw_fft(a, n, axis, False, forward, inv_norm) + return output + + +def _fftr2c(a, n=None, axis=-1, norm=None, forward=None): + a = asarray(a) + if n is None: + n = a.shape[axis] + if forward: + inv_norm = _get_forward_norm(n, norm) + else: + inv_norm = _get_backward_norm(n, norm) + output = _raw_fft(a, n, axis, True, True, inv_norm) + if not forward: + output = output.conj() + return output + + +def _fftc2r(a, n=None, axis=-1, norm=None, forward=None): + a = asarray(a) + if n is None: + n = (a.shape[axis] - 1) * 2 + if forward: + inv_norm = _get_forward_norm(n, norm) + else: + inv_norm = _get_backward_norm(n, norm) + output = _raw_fft(a.conj() + if forward else a, n, axis, True, False, inv_norm) + return output + + +def fft_c2c(x, axes, normalization, forward): + f = partial(_fftc2c, forward=forward) + y = _raw_fftnd(x, s=None, axes=axes, function=f, norm=normalization) + return y + + +def fft_c2c_backward(dy, axes, normalization, forward): + f = partial(_fftc2c, forward=forward) + dx = _raw_fftnd(dy, s=None, axes=axes, function=f, norm=normalization) + return dx + + +def fft_r2c(x, axes, normalization, forward, onesided): + a = asarray(x) + s, axes = _cook_nd_args(a, axes=axes) + if onesided: + a = _fftr2c(a, s[-1], axes[-1], normalization, forward) + for ii in range(len(axes) - 1): + a = _fftc2c(a, s[ii], axes[ii], normalization, forward) + else: + a = fft_c2c(x, axes, normalization, forward) + return a + + +def fft_r2c_backward(dy, x, axes, normalization, forward, onesided): + a = dy + if not onesided: + a = fft_c2c_backward(a, axes, normalization, forward).real + else: + pad_widths = [(0, 0)] * a.ndim + last_axis = axes[-1] + if last_axis < 0: + last_axis += a.ndim + last_dim_size = a.shape[last_axis] + pad_widths[last_axis] = (0, x.shape[last_axis] - last_dim_size) + a = np.pad(a, pad_width=pad_widths) + a = fft_c2c_backward(a, axes, normalization, forward).real + return a + + +def fft_c2r(x, axes, normalization, forward, last_dim_size): + a = asarray(x) + s, axes = _cook_nd_args(a, axes=axes, invreal=1) + if last_dim_size is not None: + s[-1] = last_dim_size + for ii in range(len(axes) - 1): + a = _fftc2c(a, s[ii], axes[ii], normalization, forward) + a = _fftc2r(a, s[-1], axes[-1], normalization, forward) + return a diff --git a/python/paddle/fluid/tests/unittests/fft/test_fft.py b/python/paddle/fluid/tests/unittests/fft/test_fft.py new file mode 100644 index 00000000000..26355e0411f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/fft/test_fft.py @@ -0,0 +1,960 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import re +import sys +import unittest + +import numpy as np +import paddle +import scipy.fft + +DEVICES = [paddle.CPUPlace()] +if paddle.is_compiled_with_cuda(): + DEVICES.append(paddle.CUDAPlace(0)) + +TEST_CASE_NAME = 'suffix' +# All test case will use float64 for compare percision, refs: +# https://github.com/PaddlePaddle/Paddle/wiki/Upgrade-OP-Precision-to-Float64 +RTOL = { + 'float32': 1e-03, + 'complex64': 1e-3, + 'float64': 1e-7, + 'complex128': 1e-7 +} +ATOL = {'float32': 0.0, 'complex64': 0, 'float64': 0.0, 'complex128': 0} + + +def rand_x(dims=1, + dtype='float64', + min_dim_len=1, + max_dim_len=10, + complex=False): + shape = [np.random.randint(min_dim_len, max_dim_len) for i in range(dims)] + if complex: + return np.random.randn(*shape).astype(dtype) + 1.j * np.random.randn( + *shape).astype(dtype) + else: + return np.random.randn(*shape).astype(dtype) + + +def place(devices, key='place'): + def decorate(cls): + module = sys.modules[cls.__module__].__dict__ + raw_classes = { + k: v + for k, v in module.items() if k.startswith(cls.__name__) + } + + for raw_name, raw_cls in raw_classes.items(): + for d in devices: + test_cls = dict(raw_cls.__dict__) + test_cls.update({key: d}) + new_name = raw_name + '.' + d.__class__.__name__ + module[new_name] = type(new_name, (raw_cls, ), test_cls) + del module[raw_name] + return cls + + return decorate + + +def parameterize(fields, values=None): + + fields = [fields] if isinstance(fields, str) else fields + params = [dict(zip(fields, vals)) for vals in values] + + def decorate(cls): + test_cls_module = sys.modules[cls.__module__].__dict__ + for k, v in enumerate(params): + test_cls = dict(cls.__dict__) + test_cls.update(v) + name = cls.__name__ + str(k) + name = name + '.' 
+ v.get('suffix') if v.get('suffix') else name + + test_cls_module[name] = type(name, (cls, ), test_cls) + + for m in list(cls.__dict__): + if m.startswith("test"): + delattr(cls, m) + return cls + + return decorate + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), + [('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'), + ('test_x_complex', rand_x( + 5, complex=True), None, -1, + 'backward'), ('test_n_grater_input_length', rand_x( + 5, max_dim_len=5), 11, -1, + 'backward'), ('test_n_smaller_than_input_length', rand_x( + 5, min_dim_len=5, complex=True), 3, -1, 'backward'), + ('test_axis_not_last', rand_x(5), None, 3, 'backward'), + ('test_norm_forward', rand_x(5), None, 3, 'forward'), + ('test_norm_ortho', rand_x(5), None, 3, 'ortho')]) +class TestFft(unittest.TestCase): + def test_fft(self): + with paddle.fluid.dygraph.guard(self.place): + self.assertTrue( + np.allclose( + scipy.fft.fft(self.x, self.n, self.axis, self.norm), + paddle.fft.fft( + paddle.to_tensor(self.x), self.n, self.axis, self.norm), + rtol=RTOL.get(str(self.x.dtype)), + atol=ATOL.get(str(self.x.dtype)))) + + +@place(DEVICES) +@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ + ('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError), + ('test_n_zero', rand_x(2), 0, -1, 'backward', ValueError), + ('test_axis_out_of_range', rand_x(1), None, 10, 'backward', ValueError), + ('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', ValueError), + ('test_norm_not_in_enum_value', rand_x(2), None, -1, 'random', ValueError) +]) +class TestFftException(unittest.TestCase): + def test_Fft(self): + with self.assertRaises(self.expect_exception): + paddle.fft.fft( + paddle.to_tensor(self.x), self.n, self.axis, self.norm) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ + ('test_x_float64', rand_x(5), None, (0, 1), 'backward'), + ('test_x_complex128', rand_x( + 5, complex=True), None, (0, 1), 'backward'), + ('test_n_grater_input_length', rand_x( + 5, max_dim_len=5), (6, 6), (0, 1), 'backward'), + ('test_n_smaller_than_input_length', rand_x( + 5, min_dim_len=5, complex=True), (4, 4), (0, 1), 'backward'), + ('test_axis_random', rand_x(5), None, (1, 2), 'backward'), + ('test_axis_none', rand_x(5), None, None, 'backward'), + ('test_norm_forward', rand_x(5), None, (0, 1), 'forward'), + ('test_norm_ortho', rand_x(5), None, (0, 1), 'ortho'), + ]) +class TestFft2(unittest.TestCase): + def test_Fft2(self): + with paddle.fluid.dygraph.guard(self.place): + self.assertTrue( + np.allclose( + scipy.fft.fft2(self.x, self.n, self.axis, self.norm), + paddle.fft.fft2( + paddle.to_tensor(self.x), self.n, self.axis, self.norm), + rtol=RTOL.get(str(self.x.dtype)), + atol=ATOL.get(str(self.x.dtype)))) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [('test_x_complex_input', rand_x( + 2, complex=True), None, (0, 1), None, + ValueError), ('test_x_1dim_tensor', rand_x(1), None, (0, 1), None, + ValueError), ('test_n_nagative', rand_x(2), -1, (0, 1), + 'backward', ValueError), + ('test_n_len_not_equal_axis', rand_x( + 5, max_dim_len=5), 11, (0, 1), 'backward', + ValueError), ('test_n_zero', rand_x(2), (0, 0), (0, 1), 'backward', + ValueError), ('test_axis_out_of_range', rand_x(2), None, + (0, 1, 2), 'backward', ValueError), + ('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', ValueError), + ('test_axis_not_sequence', rand_x(5), None, -10, 'backward', ValueError), + 
('test_norm_not_enum', rand_x(2), None, -1, 'random', ValueError)]) +class TestFft2Exception(unittest.TestCase): + def test_fft2(self): + with paddle.fluid.dygraph.guard(self.place): + with self.assertRaises(self.expect_exception): + paddle.fft.fft2( + paddle.to_tensor(self.x), self.n, self.axis, self.norm) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), + [('test_x_float64', rand_x(5, np.float64), None, None, 'backward'), + ('test_x_complex128', rand_x( + 5, complex=True), None, None, + 'backward'), ('test_n_grater_input_length', rand_x( + 5, max_dim_len=5), (6, 6), (1, 2), 'backward'), ( + 'test_n_smaller_input_length', rand_x( + 5, min_dim_len=5, complex=True), (3, 3), (1, 2), 'backward'), + ('test_axis_not_default', rand_x(5), None, (1, 2), + 'backward'), ('test_norm_forward', rand_x(5), None, None, 'forward'), + ('test_norm_ortho', rand_x(5), None, None, 'ortho')]) +class TestFftn(unittest.TestCase): + def test_Fftn(self): + with paddle.fluid.dygraph.guard(self.place): + np.testing.assert_allclose( + scipy.fft.fftn(self.x, self.n, self.axis, self.norm), + paddle.fft.fftn( + paddle.to_tensor(self.x), self.n, self.axis, self.norm), + rtol=RTOL.get(str(self.x.dtype)), + atol=ATOL.get(str(self.x.dtype))) + + +@place(DEVICES) +@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ + ('test_x_complex128', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.complex128), None, -1, "backward"), + ('test_n_grater_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 4, -1, + "backward"), + ('test_n_smaller_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 2, -1, + "backward"), + ('test_axis_not_last', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, 1, + "backward"), + ('test_norm_forward', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, 1, + "forward"), + ('test_norm_ortho', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, -1, + "ortho"), +]) +class TestHfft(unittest.TestCase): + """Test hfft with norm condition + """ + + def test_hfft(self): + with paddle.fluid.dygraph.guard(self.place): + np.testing.assert_allclose( + scipy.fft.hfft(self.x, self.n, self.axis, self.norm), + paddle.fft.hfft( + paddle.to_tensor(self.x), self.n, self.axis, self.norm), + rtol=1e-5, + atol=0) + + +@place(DEVICES) +@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ + ('test_x_complex128', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.complex128), None, -1, "backward"), + ('test_n_grater_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 4, -1, + "backward"), + ('test_n_smaller_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 2, -1, + "backward"), + ('test_axis_not_last', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, -1, + "backward"), + ('test_norm_forward', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, -1, + "forward"), + ('test_norm_ortho', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, -1, + "ortho"), +]) +class TestIrfft(unittest.TestCase): + """Test irfft with norm condition + """ + + def test_irfft(self): + with paddle.fluid.dygraph.guard(self.place): + np.testing.assert_allclose( + scipy.fft.irfft(self.x, self.n, self.axis, self.norm), + paddle.fft.irfft( + paddle.to_tensor(self.x), self.n, self.axis, self.norm), + rtol=1e-5, + atol=0) + + +@place(DEVICES) +@parameterize((TEST_CASE_NAME, 
'x', 'n', 'axis', 'norm'), [ + ('test_x_complex128', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.complex128), None, None, "backward"), + ('test_n_grater_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [4], None, + "backward"), + ('test_n_smaller_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [2], None, + "backward"), + ('test_axis_not_last', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None, + "backward"), + ('test_norm_forward', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None, + "forward"), + ('test_norm_ortho', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None, + "ortho"), +]) +class Testirfftn(unittest.TestCase): + """Test irfftn with norm condition + """ + + def test_irfftn(self): + with paddle.fluid.dygraph.guard(self.place): + np.testing.assert_allclose( + scipy.fft.irfftn(self.x, self.n, self.axis, self.norm), + paddle.fft.irfftn( + paddle.to_tensor(self.x), self.n, self.axis, self.norm), + rtol=1e-5, + atol=0) + + +@place(DEVICES) +@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ + ('test_x_complex128', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.complex128), None, None, "backward"), + ('test_n_grater_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [4], None, + "backward"), + ('test_n_smaller_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [2], None, + "backward"), + ('test_axis_not_last', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None, + "backward"), + ('test_norm_forward', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None, + "forward"), + ('test_norm_ortho', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None, + "ortho"), +]) +class Testhfftn(unittest.TestCase): + """Test hfftn with norm condition + """ + + def test_hfftn(self): + with paddle.fluid.dygraph.guard(self.place): + np.testing.assert_allclose( + scipy.fft.hfftn(self.x, self.n, self.axis, self.norm), + paddle.fft.hfftn( + paddle.to_tensor(self.x), self.n, self.axis, self.norm), + rtol=1e-5, + atol=0) + + +@place(DEVICES) +@parameterize((TEST_CASE_NAME, 'x', 's', 'axis', 'norm'), [ + ('test_x_complex128', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.complex128), None, (-2, -1), "backward"), + ('test_with_s', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + [2, 2], (-2, -1), "backward", ValueError), + ('test_axis_not_last', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), + "backward"), + ('test_norm_forward', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), + "forward"), + ('test_norm_ortho', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), + "ortho"), +]) +class Testhfft2(unittest.TestCase): + """Test hfft2 with norm condition + """ + + def test_hfft2(self): + with paddle.fluid.dygraph.guard(self.place): + np.testing.assert_allclose( + scipy.fft.hfft2(self.x, self.s, self.axis, self.norm), + paddle.fft.hfft2( + paddle.to_tensor(self.x), self.s, self.axis, self.norm), + rtol=1e-5, + atol=0) + + +@place(DEVICES) +@parameterize((TEST_CASE_NAME, 'x', 's', 'axis', 'norm'), [ + ('test_x_complex128', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.complex128), None, (-2, -1), "backward"), + ('test_n_equal_input_length', + np.random.randn(4, 4, 4) + 1j * 
np.random.randn(4, 4, 4), (4, 6), (-2, -1), + "backward"), + ('test_axis_not_last', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), + "backward"), + ('test_norm_forward', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), + "forward"), + ('test_norm_ortho', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), + "ortho"), +]) +class TestIrfft2(unittest.TestCase): + """Test irfft2 with norm condition + """ + + def test_irfft2(self): + with paddle.fluid.dygraph.guard(self.place): + np.testing.assert_allclose( + scipy.fft.irfft2(self.x, self.s, self.axis, self.norm), + paddle.fft.irfft2( + paddle.to_tensor(self.x), self.s, self.axis, self.norm), + rtol=1e-5, + atol=0) + + +@place(DEVICES) +@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [( + 'test_bool_input', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)).astype(np.bool8), + None, -1, 'backward', NotImplementedError), ( + 'test_n_nagative', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), -1, -1, + 'backward', ValueError), ( + 'test_n_zero', np.random.randn(4, 4) + 1j * np.random.randn(4, 4), + 0, -1, 'backward', ValueError), ( + 'test_n_type', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (1, 2, 3), -1, 'backward', ValueError), ( + 'test_axis_out_of_range', + np.random.randn(4) + 1j * np.random.randn(4), None, 10, + 'backward', ValueError), ( + 'test_axis_with_array', + np.random.randn(4) + 1j * np.random.randn(4), None, + (0, 1), 'backward', ValueError), ( + 'test_norm_not_in_enum_value', + np.random.randn(4, 4) + 1j * np.random.randn(4, 4), + None, -1, 'random', ValueError)]) +class TestHfftException(unittest.TestCase): + '''Test hfft with buoudary condition + Test case include: + - n out of range + - axis out of range + - norm out of range + ''' + + def test_hfft(self): + with paddle.fluid.dygraph.guard(self.place): + with self.assertRaises(self.expect_exception): + paddle.fft.hfft( + paddle.to_tensor(self.x), self.n, self.axis, self.norm) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [('test_n_nagative', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), -1, -1, + 'backward', ValueError), + ('test_n_zero', np.random.randn(4, 4) + 1j * np.random.randn(4, 4), 0, -1, + 'backward', ValueError), + ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (1, 2), -1, 'backward', ValueError), + ('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4), + None, 10, 'backward', ValueError), + ('test_axis_with_array', np.random.randn(4) + 1j * np.random.randn(4), + None, (0, 1), 'backward', + ValueError), ('test_norm_not_in_enum_value', + np.random.randn(4, 4) + 1j * np.random.randn(4, 4), None, + None, 'random', ValueError)]) +class TestIrfftException(unittest.TestCase): + '''Test Irfft with buoudary condition + Test case include: + - n out of range + - axis out of range + - norm out of range + - the dimensions of n and axis are different + ''' + + def test_irfft(self): + with paddle.fluid.dygraph.guard(self.place): + with self.assertRaises(self.expect_exception): + paddle.fft.irfft( + paddle.to_tensor(self.x), self.n, self.axis, self.norm) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [('test_bool_input', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.bool8), None, (-2, -1), 'backward', NotImplementedError), + 
('test_n_nagative', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2), + (-2, -1), 'backward', ValueError), + ('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (0, 0), (-2, -1), 'backward', ValueError), + ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + 3, None, 'backward', ValueError), + ('test_n_axis_dim', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (1, 2), (-1), + 'backward', ValueError), ('test_axis_out_of_range', + np.random.randn(4) + 1j * np.random.randn(4), + None, (1, 2), 'backward', ValueError), + ('test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None, -1, + 'backward', + ValueError), ('test_norm_not_in_enum_value', + np.random.randn(4, 4) + 1j * np.random.randn(4, 4), None, + None, 'random', ValueError)]) +class TestHfft2Exception(unittest.TestCase): + '''Test hfft2 with buoudary condition + Test case include: + - n out of range + - axis out of range + - the dimensions of n and axis are different + - norm out of range + ''' + + def test_hfft2(self): + with paddle.fluid.dygraph.guard(self.place): + with self.assertRaises(self.expect_exception): + paddle.fft.hfft2( + paddle.to_tensor(self.x), self.n, self.axis, self.norm) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [('test_n_nagative', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2), + (-2, -1), 'backward', ValueError), + ('test_zero_point', + np.random.randn(4, 4, 1) + 1j * np.random.randn(4, 4, 1), None, (-2, -1), + "backward", ValueError), + ('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (0, 0), (-2, -1), 'backward', ValueError), + ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + 3, -1, 'backward', + ValueError), ('test_n_axis_dim', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (1, 2), (-3, -2, -1), 'backward', ValueError), + ('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4), + None, (1, 2), 'backward', ValueError), ( + 'test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None, + 1, 'backward', + ValueError), ('test_norm_not_in_enum_value', + np.random.randn(4, 4) + 1j * np.random.randn(4, 4), + None, None, 'random', ValueError)]) +class TestIrfft2Exception(unittest.TestCase): + '''Test irfft2 with buoudary condition + Test case include: + - n out of range + - axis out of range + - norm out of range + - the dimensions of n and axis are different + ''' + + def test_irfft2(self): + with paddle.fluid.dygraph.guard(self.place): + with self.assertRaises(self.expect_exception): + paddle.fft.irfft2( + paddle.to_tensor(self.x), self.n, self.axis, self.norm) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [('test_bool_input', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.bool8), None, (-2, -1), 'backward', NotImplementedError), + ('test_n_nagative', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2), + (-2, -1), 'backward', ValueError), + ('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (0, 0), (-2, -1), 'backward', ValueError), + ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + 3, -1, 'backward', ValueError), + ('test_n_axis_dim', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (1, 2), (-3, -2, -1), 'backward', + ValueError), ('test_axis_out_of_range', + np.random.randn(4) + 1j * 
np.random.randn(4), None, + (10, 20), 'backward', ValueError), + ('test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None, 1, + 'backward', + ValueError), ('test_norm_not_in_enum_value', + np.random.randn(4, 4) + 1j * np.random.randn(4, 4), None, + None, 'random', ValueError)]) +class TestHfftnException(unittest.TestCase): + '''Test hfftn with buoudary condition + Test case include: + - n out of range + - axis out of range + - norm out of range + - the dimensions of n and axis are different + ''' + + def test_hfftn(self): + with paddle.fluid.dygraph.guard(self.place): + with self.assertRaises(self.expect_exception): + paddle.fft.hfftn( + paddle.to_tensor(self.x), self.n, self.axis, self.norm) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [('test_n_nagative', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2), + (-2, -1), 'backward', ValueError), + ('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (0, 0), (-2, -1), 'backward', ValueError), + ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + 3, -1, 'backward', + ValueError), ('test_n_axis_dim', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (1, 2), (-3, -2, -1), 'backward', ValueError), + ('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4), + None, (10, 20), 'backward', ValueError), + ('test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None, 1, + 'backward', + ValueError), ('test_norm_not_in_enum_value', + np.random.randn(4, 4) + 1j * np.random.randn(4, 4), None, + None, 'random', ValueError)]) +class TestIrfftnException(unittest.TestCase): + '''Test irfftn with buoudary condition + Test case include: + - n out of range + - axis out of range + - norm out of range + - the dimensions of n and axis are different + ''' + + def test_irfftn(self): + with paddle.fluid.dygraph.guard(self.place): + with self.assertRaises(self.expect_exception): + paddle.fft.irfftn( + paddle.to_tensor(self.x), self.n, self.axis, self.norm) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), + [('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'), ( + 'test_n_grater_than_input_length', rand_x( + 5, max_dim_len=5), 11, -1, 'backward'), + ('test_n_smaller_than_input_length', rand_x( + 5, min_dim_len=5), 3, -1, + 'backward'), ('test_axis_not_last', rand_x(5), None, 3, 'backward'), + ('test_norm_forward', rand_x(5), None, 3, 'forward'), + ('test_norm_ortho', rand_x(5), None, 3, 'ortho')]) +class TestRfft(unittest.TestCase): + def test_rfft(self): + with paddle.fluid.dygraph.guard(self.place): + self.assertTrue( + np.allclose( + scipy.fft.rfft(self.x, self.n, self.axis, self.norm), + paddle.fft.rfft( + paddle.to_tensor(self.x), self.n, self.axis, self.norm), + rtol=RTOL.get(str(self.x.dtype)), + atol=ATOL.get(str(self.x.dtype)))) + + +@place(DEVICES) +@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ + ('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError), + ('test_n_zero', rand_x(2), 0, -1, 'backward', ValueError), + ('test_axis_out_of_range', rand_x(1), None, 10, 'backward', ValueError), + ('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', ValueError), + ('test_norm_not_in_enum_value', rand_x(2), None, -1, 'random', ValueError) +]) +class TestRfftException(unittest.TestCase): + def test_rfft(self): + with self.assertRaises(self.expect_exception): + paddle.fft.rfft( + paddle.to_tensor(self.x), 
self.n, self.axis, self.norm) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ + ('test_x_float64', rand_x(5), None, (0, 1), 'backward'), + ('test_n_grater_input_length', rand_x( + 5, max_dim_len=5), (6, 6), (0, 1), 'backward'), + ('test_n_smaller_than_input_length', rand_x( + 5, min_dim_len=5), (4, 4), (0, 1), 'backward'), + ('test_axis_random', rand_x(5), None, (1, 2), 'backward'), + ('test_axis_none', rand_x(5), None, None, 'backward'), + ('test_norm_forward', rand_x(5), None, (0, 1), 'forward'), + ('test_norm_ortho', rand_x(5), None, (0, 1), 'ortho'), + ]) +class TestRfft2(unittest.TestCase): + def test_rfft2(self): + with paddle.fluid.dygraph.guard(self.place): + self.assertTrue( + np.allclose( + scipy.fft.rfft2(self.x, self.n, self.axis, self.norm), + paddle.fft.rfft2( + paddle.to_tensor(self.x), self.n, self.axis, self.norm), + rtol=RTOL.get(str(self.x.dtype)), + atol=ATOL.get(str(self.x.dtype)))) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ + ('test_x_complex_input', rand_x( + 2, complex=True), None, (0, 1), 'backward', RuntimeError), + ('test_x_1dim_tensor', rand_x(1), None, (0, 1), 'backward', ValueError), + ('test_n_nagative', rand_x(2), -1, (0, 1), 'backward', ValueError), + ('test_n_zero', rand_x(2), 0, (0, 1), 'backward', ValueError), + ('test_axis_out_of_range', rand_x(2), None, (0, 1, 2), 'backward', + ValueError), + ('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', + ValueError), + ('test_axis_not_sequence', rand_x(5), None, -10, 'backward', + ValueError), + ('test_norm_not_enum', rand_x(2), None, -1, 'random', ValueError), + ]) +class TestRfft2Exception(unittest.TestCase): + def test_rfft(self): + with paddle.fluid.dygraph.guard(self.place): + with self.assertRaises(self.expect_exception): + paddle.fft.rfft2( + paddle.to_tensor(self.x), self.n, self.axis, self.norm) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ + ('test_x_float64', rand_x(5, np.float64), None, None, 'backward'), + ('test_n_grater_input_length', rand_x( + 5, max_dim_len=5), (6, 6), (1, 2), 'backward'), + ('test_n_smaller_input_length', rand_x( + 5, min_dim_len=5), (3, 3), (1, 2), 'backward'), + ('test_axis_not_default', rand_x(5), None, (1, 2), 'backward'), + ('test_norm_forward', rand_x(5), None, None, 'forward'), + ('test_norm_ortho', rand_x(5), None, None, 'ortho'), + ]) +class TestRfftn(unittest.TestCase): + def test_rfftn(self): + with paddle.fluid.dygraph.guard(self.place): + self.assertTrue( + np.allclose( + scipy.fft.rfftn(self.x, self.n, self.axis, self.norm), + paddle.fft.rfftn( + paddle.to_tensor(self.x), self.n, self.axis, self.norm), + rtol=RTOL.get(str(self.x.dtype)), + atol=ATOL.get(str(self.x.dtype)))) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [('test_x_complex', rand_x( + 4, complex=True), None, None, 'backward', + RuntimeError), ('test_n_nagative', rand_x(4), (-1, -1), (1, 2), + 'backward', ValueError), + ('test_n_not_sequence', rand_x(4), -1, None, 'backward', ValueError), + ('test_n_zero', rand_x(4), 0, None, 'backward', ValueError), ( + 'test_axis_out_of_range', rand_x(1), None, [0, 1], 'backward', + ValueError), + ('test_norm_not_in_enum', rand_x(2), None, -1, 'random', ValueError)]) +class TestRfftnException(unittest.TestCase): + def test_rfft(self): + with paddle.fluid.dygraph.guard(self.place): + with self.assertRaises(self.expect_exception): + paddle.fft.rfftn( + 
paddle.to_tensor(self.x), self.n, self.axis, self.norm) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), + [('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'), ( + 'test_n_grater_than_input_length', rand_x( + 5, max_dim_len=5), 11, -1, 'backward'), + ('test_n_smaller_than_input_length', rand_x( + 5, min_dim_len=5), 3, -1, + 'backward'), ('test_axis_not_last', rand_x(5), None, 3, 'backward'), + ('test_norm_forward', rand_x(5), None, 3, 'forward'), + ('test_norm_ortho', rand_x(5), None, 3, 'ortho')]) +class TestIhfft(unittest.TestCase): + def test_ihfft(self): + with paddle.fluid.dygraph.guard(self.place): + np.testing.assert_allclose( + scipy.fft.ihfft(self.x, self.n, self.axis, self.norm), + paddle.fft.ihfft( + paddle.to_tensor(self.x), self.n, self.axis, self.norm), + rtol=RTOL.get(str(self.x.dtype)), + atol=ATOL.get(str(self.x.dtype))) + + +@place(DEVICES) +@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ + ('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError), + ('test_n_zero', rand_x(2), 0, -1, 'backward', ValueError), + ('test_axis_out_of_range', rand_x(1), None, 10, 'backward', ValueError), + ('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', ValueError), + ('test_norm_not_in_enum_value', rand_x(2), None, -1, 'random', ValueError) +]) +class TestIhfftException(unittest.TestCase): + def test_ihfft(self): + with paddle.fluid.dygraph.guard(self.place): + with self.assertRaises(self.expect_exception): + paddle.fft.ihfft( + paddle.to_tensor(self.x), self.n, self.axis, self.norm) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ + ('test_x_float64', rand_x(5), None, (0, 1), 'backward'), + ('test_n_grater_input_length', rand_x( + 5, max_dim_len=5), (11, 11), (0, 1), 'backward'), + ('test_n_smaller_than_input_length', rand_x( + 5, min_dim_len=5), (1, 1), (0, 1), 'backward'), + ('test_axis_random', rand_x(5), None, (1, 2), 'backward'), + ('test_axis_none', rand_x(5), None, None, 'backward'), + ('test_norm_forward', rand_x(5), None, (0, 1), 'forward'), + ('test_norm_ortho', rand_x(5), None, (0, 1), 'ortho'), + ]) +class TestIhfft2(unittest.TestCase): + def test_ihfft2(self): + with paddle.fluid.dygraph.guard(self.place): + np.testing.assert_allclose( + scipy.fft.ihfft2(self.x, self.n, self.axis, self.norm), + paddle.fft.ihfft2( + paddle.to_tensor(self.x), self.n, self.axis, self.norm), + rtol=RTOL.get(str(self.x.dtype)), + atol=ATOL.get(str(self.x.dtype))) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [('test_x_complex_input', rand_x( + 2, complex=True), None, (0, 1), None, ValueError), + ('test_x_1dim_tensor', rand_x(1), None, (0, 1), None, + ValueError), ('test_n_nagative', rand_x(2), -1, (0, 1), 'backward', + ValueError), ('test_n_len_not_equal_axis', rand_x( + 5, max_dim_len=5), 11, (0, 1), 'backward', ValueError), + ('test_n_zero', rand_x(2), (0, 0), (0, 1), 'backward', ValueError), + ('test_axis_out_of_range', rand_x(2), None, (0, 1, 2), 'backward', + ValueError), ('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', + ValueError), ('test_axis_not_sequence', rand_x(5), None, + -10, 'backward', ValueError), + ('test_norm_not_enum', rand_x(2), None, -1, 'random', ValueError)]) +class TestIhfft2Exception(unittest.TestCase): + def test_rfft(self): + with paddle.fluid.dygraph.guard(self.place): + with self.assertRaises(self.expect_exception): + paddle.fft.ihfft2( + paddle.to_tensor(self.x), 
self.n, self.axis, self.norm) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), + [('test_x_float64', rand_x(5, np.float64), None, None, 'backward'), + ('test_n_grater_input_length', rand_x( + 5, max_dim_len=5), (11, 11), (0, 1), + 'backward'), ('test_n_smaller_input_length', rand_x( + 5, min_dim_len=5), (1, 1), (0, 1), 'backward'), + ('test_axis_not_default', rand_x(5), None, (1, 2), + 'backward'), ('test_norm_forward', rand_x(5), None, None, 'forward'), + ('test_norm_ortho', rand_x(5), None, None, 'ortho')]) +class TestIhfftn(unittest.TestCase): + def test_rfftn(self): + with paddle.fluid.dygraph.guard(self.place): + self.assertTrue( + np.allclose( + scipy.fft.ihfftn(self.x, self.n, self.axis, self.norm), + paddle.fft.ihfftn( + paddle.to_tensor(self.x), self.n, self.axis, self.norm), + rtol=RTOL.get(str(self.x.dtype)), + atol=ATOL.get(str(self.x.dtype)))) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [('test_x_complex', rand_x( + 4, complex=True), None, None, 'backward', RuntimeError), + ('test_n_nagative', rand_x(4), -1, None, 'backward', ValueError), + ('test_n_zero', rand_x(4), 0, None, 'backward', ValueError), ( + 'test_axis_out_of_range', rand_x(1), None, [0, 1], 'backward', + ValueError), + ('test_norm_not_in_enum', rand_x(2), None, -1, 'random', ValueError)]) +class TestIhfftnException(unittest.TestCase): + def test_rfft(self): + with paddle.fluid.dygraph.guard(self.place): + with self.assertRaises(self.expect_exception): + paddle.fft.ihfftn( + paddle.to_tensor(self.x), self.n, self.axis, self.norm) + + +@place(DEVICES) +@parameterize((TEST_CASE_NAME, 'n', 'd', 'dtype'), [ + ('test_without_d', 20, 1, 'float32'), + ('test_with_d', 20, 0.5, 'float32'), +]) +class TestFftFreq(unittest.TestCase): + def test_fftfreq(self): + with paddle.fluid.dygraph.guard(self.place): + np.testing.assert_allclose( + scipy.fft.fftfreq(self.n, self.d).astype(self.dtype), + paddle.fft.fftfreq(self.n, self.d, self.dtype).numpy(), + rtol=RTOL.get(str(self.dtype)), + atol=ATOL.get(str(self.dtype))) + + +@place(DEVICES) +@parameterize((TEST_CASE_NAME, 'n', 'd', 'dtype'), [ + ('test_without_d', 20, 1, 'float32'), + ('test_with_d', 20, 0.5, 'float32'), +]) +class TestRfftFreq(unittest.TestCase): + def test_rfftfreq(self): + with paddle.fluid.dygraph.guard(self.place): + np.testing.assert_allclose( + scipy.fft.rfftfreq(self.n, self.d).astype(self.dtype), + paddle.fft.rfftfreq(self.n, self.d, self.dtype).numpy(), + rtol=RTOL.get(str(self.dtype)), + atol=ATOL.get(str(self.dtype))) + + +@place(DEVICES) +@parameterize((TEST_CASE_NAME, 'x', 'axes', 'dtype'), [ + ('test_1d', np.random.randn(10), (0, ), 'float64'), + ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'), +]) +class TestFftShift(unittest.TestCase): + def test_fftshift(self): + with paddle.fluid.dygraph.guard(self.place): + np.testing.assert_allclose( + scipy.fft.fftshift(self.x, self.axes), + paddle.fft.fftshift(paddle.to_tensor(self.x), + self.axes).numpy(), + rtol=RTOL.get(str(self.x.dtype)), + atol=ATOL.get(str(self.x.dtype))) + + +@place(DEVICES) +@parameterize((TEST_CASE_NAME, 'x', 'axes'), [ + ('test_1d', np.random.randn(10), (0, ), 'float64'), + ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'), +]) +class TestIfftShift(unittest.TestCase): + def test_ifftshift(self): + with paddle.fluid.dygraph.guard(self.place): + np.testing.assert_allclose( + scipy.fft.ifftshift(self.x, self.axes), + paddle.fft.ifftshift(paddle.to_tensor(self.x), + 
self.axes).numpy(), + rtol=RTOL.get(str(self.x.dtype)), + atol=ATOL.get(str(self.x.dtype))) + + +if __name__ == '__main__': + unittest.main() + +# yapf: enable diff --git a/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py b/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py new file mode 100644 index 00000000000..ac9d1557b53 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py @@ -0,0 +1,894 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import re +import sys +import unittest + +import numpy as np +import paddle +import scipy.fft + +from test_fft import (ATOL, DEVICES, RTOL, TEST_CASE_NAME, parameterize, place, + rand_x) + + +@contextlib.contextmanager +def stgraph(func, place, x, n, axes, norm): + """static graph exec context""" + paddle.enable_static() + mp, sp = paddle.static.Program(), paddle.static.Program() + with paddle.static.program_guard(mp, sp): + input = paddle.static.data('input', x.shape, dtype=x.dtype) + output = func(input, n, axes, norm) + + exe = paddle.static.Executor(place) + exe.run(sp) + [output] = exe.run(mp, feed={'input': x}, fetch_list=[output]) + yield output + paddle.disable_static() + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), + [('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'), + ('test_x_complex64', rand_x( + 5, np.float64, complex=True), None, -1, + 'backward'), ('test_n_grater_than_input_length', rand_x( + 5, max_dim_len=5), 11, -1, + 'backward'), ('test_n_smaller_than_input_length', rand_x( + 5, min_dim_len=5), 3, -1, 'backward'), + ('test_axis_not_last', rand_x(5), None, 3, 'backward'), + ('test_norm_forward', rand_x(5), None, 3, 'forward'), + ('test_norm_ortho', rand_x(5), None, 3, 'ortho')]) +class TestFft(unittest.TestCase): + def test_static_rfft(self): + with stgraph(paddle.fft.fft, self.place, self.x, self.n, self.axis, + self.norm) as y: + np.testing.assert_allclose( + scipy.fft.fft(self.x, self.n, self.axis, self.norm), + y, + rtol=RTOL.get(str(self.x.dtype)), + atol=ATOL.get(str(self.x.dtype))) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError), + ('test_n_zero', rand_x(2), 0, -1, 'backward', ValueError), + ('test_axis_out_of_range', rand_x(1), None, 10, 'backward', + ValueError), ('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', + ValueError), ('test_norm_not_in_enum_value', rand_x(2), + None, -1, 'random', ValueError)]) +class TestFftException(unittest.TestCase): + def test_fft(self): + with self.assertRaises(self.expect_exception): + with stgraph(paddle.fft.fft, self.place, self.x, self.n, self.axis, + self.norm) as y: + pass + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ + ('test_x_float64', rand_x(5), None, (0, 1), 'backward'), + ('test_x_complex128', rand_x( + 5, 
complex=True), None, (0, 1), 'backward'), + ('test_n_grater_input_length', rand_x( + 5, max_dim_len=5), (6, 6), (0, 1), 'backward'), + ('test_n_smaller_than_input_length', rand_x( + 5, min_dim_len=5), (4, 4), (0, 1), 'backward'), + ('test_axis_random', rand_x(5), None, (1, 2), 'backward'), + ('test_axis_none', rand_x(5), None, None, 'backward'), + ('test_norm_forward', rand_x(5), None, (0, 1), 'forward'), + ('test_norm_ortho', rand_x(5), None, (0, 1), 'ortho'), + ]) +class TestFft2(unittest.TestCase): + def test_static_fft2(self): + with stgraph(paddle.fft.fft2, self.place, self.x, self.n, self.axis, + self.norm) as y: + np.testing.assert_allclose( + scipy.fft.fft2(self.x, self.n, self.axis, self.norm), + y, + rtol=RTOL.get(str(self.x.dtype)), + atol=ATOL.get(str(self.x.dtype))) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [ + # ('test_x_not_tensor', [0, 1], None, (0, 1), 'backward', ValueError), + ('test_x_1dim_tensor', rand_x(1), None, (0, 1), 'backward', ValueError), + ('test_n_nagative', rand_x(2), -1, (0, 1), 'backward', ValueError), + ('test_n_zero', rand_x(2), 0, (0, 1), 'backward', ValueError), + ('test_axis_out_of_range', rand_x(2), None, (0, 1, 2), 'backward', + ValueError), + ('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', + ValueError), + ('test_axis_not_sequence', rand_x(5), None, -10, 'backward', + ValueError), + ('test_norm_not_enum', rand_x(2), None, -1, 'random', ValueError) + ]) +class TestFft2Exception(unittest.TestCase): + def test_static_fft2(self): + with self.assertRaises(self.expect_exception): + with stgraph(paddle.fft.fft2, self.place, self.x, self.n, self.axis, + self.norm) as y: + pass + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), + [('test_x_float64', rand_x(5, np.float64), None, None, 'backward'), + ('test_x_complex128', rand_x( + 5, np.float64, complex=True), None, None, + 'backward'), ('test_n_grater_input_length', rand_x( + 5, max_dim_len=5), (6, 6), (1, 2), + 'backward'), ('test_n_smaller_input_length', rand_x( + 5, min_dim_len=5), (3, 3), (1, 2), 'backward'), + ('test_axis_not_default', rand_x(5), None, (1, 2), + 'backward'), ('test_norm_forward', rand_x(5), None, None, 'forward'), + ('test_norm_ortho', rand_x(5), None, None, 'ortho')]) +class TestFftn(unittest.TestCase): + def test_static_fftn(self): + with stgraph(paddle.fft.fftn, self.place, self.x, self.n, self.axis, + self.norm) as y: + np.testing.assert_allclose( + scipy.fft.fftn(self.x, self.n, self.axis, self.norm), + y, + rtol=RTOL.get(str(self.x.dtype)), + atol=ATOL.get(str(self.x.dtype))) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [('test_x_complex', rand_x( + 4, complex=True), None, None, 'backward', + TypeError), ('test_n_nagative', rand_x(4), (-1, -1), (1, 2), 'backward', + ValueError), ('test_n_not_sequence', rand_x(4), -1, None, + 'backward', ValueError), + ('test_n_zero', rand_x(4), 0, None, 'backward', ValueError), + ('test_axis_out_of_range', rand_x(1), None, [0, 1], 'backward', + ValueError), ('test_norm_not_in_enum', rand_x(2), None, -1, 'random', + ValueError)]) +class TestRfftnException(unittest.TestCase): + def test_static_rfftn(self): + with self.assertRaises(self.expect_exception): + with stgraph(paddle.fft.rfftn, self.place, self.x, self.n, + self.axis, self.norm) as y: + pass + + +@place(DEVICES) +@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ + ('test_x_complex128', + (np.random.randn(4, 4, 4) 
+ 1j * np.random.randn(4, 4, 4) + ).astype(np.complex128), None, -1, "backward"), + ('test_n_grater_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 4, -1, + "backward"), + ('test_n_smaller_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 2, -1, + "backward"), + ('test_axis_not_last', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, 1, + "backward"), + ('test_norm_forward', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, 1, + "forward"), + ('test_norm_ortho', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, -1, + "ortho"), +]) +class TestHfft(unittest.TestCase): + """Test hfft with norm condition + """ + + def test_hfft(self): + with stgraph(paddle.fft.hfft, self.place, self.x, self.n, self.axis, + self.norm) as y: + np.testing.assert_allclose( + scipy.fft.hfft(self.x, self.n, self.axis, self.norm), + y, + rtol=1e-5, + atol=0) + + +@place(DEVICES) +@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ + ('test_x_complex128', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.complex128), None, -1, "backward"), + ('test_n_grater_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 4, -1, + "backward"), + ('test_n_smaller_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 2, -1, + "backward"), + ('test_axis_not_last', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, -1, + "backward"), + ('test_norm_forward', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, -1, + "forward"), + ('test_norm_ortho', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, -1, + "ortho"), +]) +class TestIrfft(unittest.TestCase): + """Test irfft with norm condition + """ + + def test_irfft(self): + with stgraph(paddle.fft.irfft, self.place, self.x, self.n, self.axis, + self.norm) as y: + np.testing.assert_allclose( + scipy.fft.irfft(self.x, self.n, self.axis, self.norm), + y, + rtol=1e-5, + atol=0) + + +@place(DEVICES) +@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ + ('test_x_complex128', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.complex128), None, None, "backward"), + ('test_n_grater_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [4], None, + "backward"), + ('test_n_smaller_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [2], None, + "backward"), + ('test_axis_not_last', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None, + "backward"), + ('test_norm_forward', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None, + "forward"), + ('test_norm_ortho', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None, + "ortho"), +]) +class Testirfftn(unittest.TestCase): + """Test irfftn with norm condition + """ + + def test_static_irfftn(self): + with stgraph(paddle.fft.irfftn, self.place, self.x, self.n, self.axis, + self.norm) as y: + np.testing.assert_allclose( + scipy.fft.irfftn(self.x, self.n, self.axis, self.norm), + y, + rtol=1e-5, + atol=0) + + +@place(DEVICES) +@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ + ('test_x_complex128', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.complex128), None, None, "backward"), + ('test_n_grater_than_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [4], None, + "backward"), + ('test_n_smaller_than_input_length', + 
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [2], None, + "backward"), + ('test_axis_not_last', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None, + "backward"), + ('test_norm_forward', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None, + "forward"), + ('test_norm_ortho', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, None, + "ortho"), +]) +class Testhfftn(unittest.TestCase): + """Test hfftn with norm condition + """ + + def test_static_hfftn(self): + with stgraph(paddle.fft.hfftn, self.place, self.x, self.n, self.axis, + self.norm) as y: + np.testing.assert_allclose( + scipy.fft.hfftn(self.x, self.n, self.axis, self.norm), + y, + rtol=1e-5, + atol=0) + + +@place(DEVICES) +@parameterize((TEST_CASE_NAME, 'x', 's', 'axis', 'norm'), [ + ('test_x_complex128', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.complex128), None, (-2, -1), "backward"), + ('test_n_grater_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [4, 8], (-2, -1), + "backward"), + ('test_n_smaller_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [2, 4], (-2, -1), + "backward"), + ('test_axis_not_last', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), + "backward"), + ('test_norm_forward', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), + "forward"), + ('test_norm_ortho', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), + "ortho"), +]) +class Testhfft2(unittest.TestCase): + """Test hfft2 with norm condition + """ + + def test_static_hfft2(self): + with stgraph(paddle.fft.hfft2, self.place, self.x, self.s, self.axis, + self.norm) as y: + np.testing.assert_allclose( + scipy.fft.hfft2(self.x, self.s, self.axis, self.norm), + y, + rtol=1e-5, + atol=0) + + +@place(DEVICES) +@parameterize((TEST_CASE_NAME, 'x', 's', 'axis', 'norm'), [ + ('test_x_complex128', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.complex128), None, (-2, -1), "backward"), + ('test_n_equal_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (2, 4), (-2, -1), + "backward"), + ('test_axis_not_last', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), + "backward"), + ('test_norm_forward', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), + "forward"), + ('test_norm_ortho', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), + "ortho"), +]) +class TestIrfft2(unittest.TestCase): + """Test irfft2 with norm condition + """ + + def test_static_irfft2(self): + with stgraph(paddle.fft.irfft2, self.place, self.x, self.s, self.axis, + self.norm) as y: + np.testing.assert_allclose( + scipy.fft.irfft2(self.x, self.s, self.axis, self.norm), + y, + rtol=1e-5, + atol=0) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [('test_input_dtype', np.random.randn(4, 4, 4), None, -1, 'backward', + TypeError), ('test_bool_input', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.bool8), None, -1, 'backward', TypeError), + ('test_n_nagative', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), -1, -1, + 'backward', ValueError), + ('test_n_zero', np.random.randn(4, 4) + 1j * np.random.randn(4, 4), 0, -1, + 'backward', ValueError), + ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (1, 2, 3), -1, 'backward', ValueError), + 
('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4), + None, 10, 'backward', ValueError), ( + 'test_axis_with_array', np.random.randn(4) + 1j * np.random.randn(4), + None, (0, 1), 'backward', + ValueError), ('test_norm_not_in_enum_value', + np.random.randn(4, 4) + 1j * np.random.randn(4, 4), + None, -1, 'random', ValueError)]) +class TestHfftException(unittest.TestCase): + '''Test hfft with buoudary condition + Test case include: + - non complex input + - n out of range + - axis out of range + - norm out of range + ''' + + def test_static_hfft(self): + with self.assertRaises(self.expect_exception): + with stgraph(paddle.fft.hfft, self.place, self.x, self.n, self.axis, + self.norm) as y: + pass + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [('test_input_dtype', np.random.randn(4, 4, 4), None, -1, 'backward', + TypeError), ('test_bool_input', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.bool8), None, -1, 'backward', TypeError), + ('test_n_nagative', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), -1, -1, + 'backward', ValueError), + ('test_n_zero', np.random.randn(4, 4) + 1j * np.random.randn(4, 4), 0, -1, + 'backward', ValueError), + ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (1, 2), -1, 'backward', ValueError), + ('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4), + None, 10, 'backward', ValueError), ( + 'test_axis_with_array', np.random.randn(4) + 1j * np.random.randn(4), + None, (0, 1), 'backward', + ValueError), ('test_norm_not_in_enum_value', + np.random.randn(4, 4) + 1j * np.random.randn(4, 4), + None, None, 'random', ValueError)]) +class TestIrfftException(unittest.TestCase): + '''Test Irfft with buoudary condition + Test case include: + - non complex input + - n out of range + - axis out of range + - norm out of range + - the dimensions of n and axis are different + ''' + + def test_static_irfft(self): + with self.assertRaises(self.expect_exception): + with stgraph(paddle.fft.irfft, self.place, self.x, self.n, + self.axis, self.norm) as y: + pass + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [('test_input_dtype', np.random.randn(4, 4, 4), None, None, 'backward', + TypeError), ('test_bool_input', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.bool8), None, (-2, -1), 'backward', TypeError), + ('test_n_nagative', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2), + (-2, -1), 'backward', ValueError), + ('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (0, 0), (-2, -1), 'backward', ValueError), + ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + 3, None, 'backward', + ValueError), ('test_n_axis_dim', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (1, 2), (-1), 'backward', ValueError), + ('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4), + None, (1, 2), 'backward', ValueError), ( + 'test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None, + -1, 'backward', + ValueError), ('test_norm_not_in_enum_value', + np.random.randn(4, 4) + 1j * np.random.randn(4, 4), + None, None, 'random', ValueError)]) +class TestHfft2Exception(unittest.TestCase): + '''Test hfft2 with buoudary condition + Test case include: + - non complex input + - n out of range + - axis out of range + - the dimensions of n and axis are different + - norm 
out of range + ''' + + def test_static_hfft2(self): + with self.assertRaises(self.expect_exception): + with stgraph(paddle.fft.hfft2, self.place, self.x, self.n, + self.axis, self.norm) as y: + pass + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [('test_input_dtype', np.random.randn(4, 4, 4), None, None, 'backward', + TypeError), ('test_bool_input', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.bool8), None, (-2, -1), 'backward', TypeError), + ('test_n_nagative', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2), + (-2, -1), 'backward', ValueError), + ('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (0, 0), (-2, -1), 'backward', ValueError), + ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + 3, -1, 'backward', + ValueError), ('test_n_axis_dim', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (1, 2), (-3, -2, -1), 'backward', ValueError), + ('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4), + None, (1, 2), 'backward', ValueError), ( + 'test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None, + 1, 'backward', + ValueError), ('test_norm_not_in_enum_value', + np.random.randn(4, 4) + 1j * np.random.randn(4, 4), + None, None, 'random', ValueError)]) +class TestIrfft2Exception(unittest.TestCase): + '''Test irfft2 with buoudary condition + Test case include: + - non complex input + - n out of range + - axis out of range + - norm out of range + - the dimensions of n and axis are different + ''' + + def test_static_irfft2(self): + with self.assertRaises(self.expect_exception): + with stgraph(paddle.fft.irfft2, self.place, self.x, self.n, + self.axis, self.norm) as y: + pass + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [('test_input_dtype', np.random.randn(4, 4, 4), None, None, 'backward', + TypeError), ('test_bool_input', + (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + ).astype(np.bool8), None, (-2, -1), 'backward', TypeError), + ('test_n_nagative', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2), + (-2, -1), 'backward', ValueError), + ('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (0, 0), (-2, -1), 'backward', ValueError), + ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + 3, -1, 'backward', + ValueError), ('test_n_axis_dim', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + (1, 2), (-3, -2, -1), 'backward', ValueError), + ('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4), + None, (10, 20), 'backward', ValueError), ( + 'test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None, + 1, 'backward', + ValueError), ('test_norm_not_in_enum_value', + np.random.randn(4, 4) + 1j * np.random.randn(4, 4), + None, None, 'random', ValueError)]) +class TestHfftnException(unittest.TestCase): + '''Test hfftn with buoudary condition + Test case include: + - non complex input + - n out of range + - axis out of range + - norm out of range + - the dimensions of n and axis are different + ''' + + def test_static_hfftn(self): + with self.assertRaises(self.expect_exception): + with stgraph(paddle.fft.hfftn, self.place, self.x, self.n, + self.axis, self.norm) as y: + pass + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [ + ('test_input_dtype', np.random.randn(4, 4, 4), 
None, None, 'backward', + TypeError), + # ('test_bool_input', + # (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) + # ).astype(np.bool8), None, (-2, -1), 'backward', ValueError), + ('test_n_nagative', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2), + (-2, -1), 'backward', ValueError), + ('test_n_zero', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (0, 0), + (-2, -1), 'backward', ValueError), + ('test_n_type', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), 3, -1, + 'backward', ValueError), + ('test_n_axis_dim', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (1, 2), + (-3, -2, -1), 'backward', ValueError), + ('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4), + None, (10, 20), 'backward', ValueError), + ('test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None, + 1, 'backward', ValueError), + ('test_norm_not_in_enum_value', + np.random.randn(4, 4) + 1j * np.random.randn(4, 4), None, None, + 'random', ValueError) + ]) +class TestIrfftnException(unittest.TestCase): + '''Test irfftn with buoudary condition + Test case include: + - non complex input + - n out of range + - axis out of range + - norm out of range + - the dimensions of n and axis are different + ''' + + def test_static_irfftn(self): + with self.assertRaises(self.expect_exception): + with stgraph(paddle.fft.irfftn, self.place, self.x, self.n, + self.axis, self.norm) as y: + pass + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), + [('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'), ( + 'test_n_grater_than_input_length', rand_x( + 5, max_dim_len=5), 11, -1, 'backward'), + ('test_n_smaller_than_input_length', rand_x( + 5, min_dim_len=5), 3, -1, + 'backward'), ('test_axis_not_last', rand_x(5), None, 3, 'backward'), + ('test_norm_forward', rand_x(5), None, 3, 'forward'), + ('test_norm_ortho', rand_x(5), None, 3, 'ortho')]) +class TestRfft(unittest.TestCase): + def test_static_rfft(self): + with stgraph(paddle.fft.rfft, self.place, self.x, self.n, self.axis, + self.norm) as y: + np.testing.assert_allclose( + scipy.fft.rfft(self.x, self.n, self.axis, self.norm), + y, + rtol=RTOL.get(str(self.x.dtype)), + atol=ATOL.get(str(self.x.dtype))) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError), + ('test_n_zero', rand_x(2), 0, -1, 'backward', ValueError), + ('test_axis_out_of_range', rand_x(1), None, 10, 'backward', + ValueError), ('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', + ValueError), ('test_norm_not_in_enum_value', rand_x(2), + None, -1, 'random', ValueError)]) +class TestRfftException(unittest.TestCase): + def test_rfft(self): + with self.assertRaises(self.expect_exception): + with stgraph(paddle.fft.rfft, self.place, self.x, self.n, self.axis, + self.norm) as y: + pass + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ + ('test_x_float64', rand_x(5), None, (0, 1), 'backward'), + ('test_n_grater_input_length', rand_x( + 5, max_dim_len=5), (6, 6), (0, 1), 'backward'), + ('test_n_smaller_than_input_length', rand_x( + 5, min_dim_len=5), (4, 4), (0, 1), 'backward'), + ('test_axis_random', rand_x(5), None, (1, 2), 'backward'), + ('test_axis_none', rand_x(5), None, None, 'backward'), + ('test_norm_forward', rand_x(5), None, (0, 1), 'forward'), + ('test_norm_ortho', rand_x(5), None, (0, 1), 'ortho'), + ]) +class 
TestRfft2(unittest.TestCase): + def test_static_rfft2(self): + with stgraph(paddle.fft.rfft2, self.place, self.x, self.n, self.axis, + self.norm) as y: + np.testing.assert_allclose( + scipy.fft.rfft2(self.x, self.n, self.axis, self.norm), + y, + rtol=RTOL.get(str(self.x.dtype)), + atol=ATOL.get(str(self.x.dtype))) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [ + ('test_x_complex_input', rand_x( + 2, complex=True), None, (0, 1), 'backward', TypeError), + # ('test_x_not_tensor', [0, 1], None, (0, 1), 'backward', ValueError), + ('test_x_1dim_tensor', rand_x(1), None, (0, 1), 'backward', ValueError), + ('test_n_nagative', rand_x(2), -1, (0, 1), 'backward', ValueError), + ('test_n_zero', rand_x(2), 0, (0, 1), 'backward', ValueError), + ('test_axis_out_of_range', rand_x(2), None, (0, 1, 2), 'backward', + ValueError), + ('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', + ValueError), + ('test_axis_not_sequence', rand_x(5), None, -10, 'backward', + ValueError), + ('test_norm_not_enum', rand_x(2), None, -1, 'random', ValueError) + ]) +class TestRfft2Exception(unittest.TestCase): + def test_static_rfft(self): + with self.assertRaises(self.expect_exception): + with stgraph(paddle.fft.rfft2, self.place, self.x, self.n, + self.axis, self.norm) as y: + pass + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), + [('test_x_float64', rand_x(5, np.float64), None, None, 'backward'), + ('test_n_grater_input_length', rand_x( + 5, max_dim_len=5), (6, 6), (1, 2), + 'backward'), ('test_n_smaller_input_length', rand_x( + 5, min_dim_len=5), (3, 3), (1, 2), 'backward'), + ('test_axis_not_default', rand_x(5), None, (1, 2), + 'backward'), ('test_norm_forward', rand_x(5), None, None, 'forward'), + ('test_norm_ortho', rand_x(5), None, None, 'ortho')]) +class TestRfftn(unittest.TestCase): + def test_static_rfft(self): + with stgraph(paddle.fft.rfftn, self.place, self.x, self.n, self.axis, + self.norm) as y: + np.testing.assert_allclose( + scipy.fft.rfftn(self.x, self.n, self.axis, self.norm), + y, + rtol=RTOL.get(str(self.x.dtype)), + atol=ATOL.get(str(self.x.dtype))) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [('test_x_complex', rand_x( + 4, complex=True), None, None, 'backward', + TypeError), ('test_n_nagative', rand_x(4), (-1, -1), (1, 2), 'backward', + ValueError), ('test_n_not_sequence', rand_x(4), -1, None, + 'backward', ValueError), + ('test_n_zero', rand_x(4), 0, None, 'backward', ValueError), + ('test_axis_out_of_range', rand_x(1), None, [0, 1], 'backward', + ValueError), ('test_norm_not_in_enum', rand_x(2), None, -1, 'random', + ValueError)]) +class TestRfftnException(unittest.TestCase): + def test_static_rfftn(self): + with self.assertRaises(self.expect_exception): + with stgraph(paddle.fft.rfftn, self.place, self.x, self.n, + self.axis, self.norm) as y: + pass + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), + [('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'), ( + 'test_n_grater_than_input_length', rand_x( + 5, max_dim_len=5), 11, -1, 'backward'), + ('test_n_smaller_than_input_length', rand_x( + 5, min_dim_len=5), 3, -1, + 'backward'), ('test_axis_not_last', rand_x(5), None, 3, 'backward'), + ('test_norm_forward', rand_x(5), None, 3, 'forward'), + ('test_norm_ortho', rand_x(5), None, 3, 'ortho')]) +class TestIhfft(unittest.TestCase): + def test_static_ihfft(self): + with stgraph(paddle.fft.ihfft, 
self.place, self.x, self.n, self.axis, + self.norm) as y: + np.testing.assert_allclose( + scipy.fft.ihfft(self.x, self.n, self.axis, self.norm), + y, + rtol=RTOL.get(str(self.x.dtype)), + atol=ATOL.get(str(self.x.dtype))) + + +@place(DEVICES) +@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ + ('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError), + ('test_n_zero', rand_x(2), 0, -1, 'backward', ValueError), + ('test_axis_out_of_range', rand_x(1), None, 10, 'backward', ValueError), + ('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', ValueError), + ('test_norm_not_in_enum_value', rand_x(2), None, -1, 'random', ValueError) +]) +class TestIhfftException(unittest.TestCase): + def test_static_ihfft(self): + with self.assertRaises(self.expect_exception): + with stgraph(paddle.fft.ihfft, self.place, self.x, self.n, + self.axis, self.norm) as y: + pass + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ + ('test_x_float64', rand_x(5), None, (0, 1), 'backward'), + ('test_n_grater_input_length', rand_x( + 5, max_dim_len=5), (11, 11), (0, 1), 'backward'), + ('test_n_smaller_than_input_length', rand_x( + 5, min_dim_len=5), (1, 1), (0, 1), 'backward'), + ('test_axis_random', rand_x(5), None, (1, 2), 'backward'), + ('test_axis_none', rand_x(5), None, None, 'backward'), + ('test_norm_forward', rand_x(5), None, (0, 1), 'forward'), + ('test_norm_ortho', rand_x(5), None, (0, 1), 'ortho'), + ]) +class TestIhfft2(unittest.TestCase): + def test_static_ihfft2(self): + with stgraph(paddle.fft.ihfft2, self.place, self.x, self.n, self.axis, + self.norm) as y: + np.testing.assert_allclose( + scipy.fft.ihfft2(self.x, self.n, self.axis, self.norm), + y, + rtol=RTOL.get(str(self.x.dtype)), + atol=ATOL.get(str(self.x.dtype))) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [ + ('test_x_complex_input', rand_x( + 2, complex=True), None, (0, 1), None, ValueError), + # ('test_x_not_tensor', [0, 1], None, (0, 1), None, ValueError), + ('test_x_1dim_tensor', rand_x(1), None, (0, 1), None, ValueError), + ('test_n_nagative', rand_x(2), -1, (0, 1), 'backward', ValueError), + ('test_n_len_not_equal_axis', rand_x( + 5, max_dim_len=5), 11, (0, 1), 'backward', ValueError), + ('test_n_zero', rand_x(2), (0, 0), (0, 1), 'backward', ValueError), + ('test_axis_out_of_range', rand_x(2), None, (0, 1, 2), 'backward', + ValueError), + ('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', + ValueError), + ('test_axis_not_sequence', rand_x(5), None, -10, 'backward', + ValueError), + ('test_norm_not_enum', rand_x(2), None, -1, 'random', ValueError) + ]) +class TestIhfft2Exception(unittest.TestCase): + def test_static_ihfft2(self): + with self.assertRaises(self.expect_exception): + with stgraph(paddle.fft.ihfft2, self.place, self.x, self.n, + self.axis, self.norm) as y: + pass + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), + [('test_x_float64', rand_x(5, np.float64), None, None, 'backward'), + ('test_n_grater_input_length', rand_x( + 5, max_dim_len=5), (11, 11), (0, 1), + 'backward'), ('test_n_smaller_input_length', rand_x( + 5, min_dim_len=5), (1, 1), (0, 1), 'backward'), + ('test_axis_not_default', rand_x(5), None, (1, 2), + 'backward'), ('test_norm_forward', rand_x(5), None, None, 'forward'), + ('test_norm_ortho', rand_x(5), None, None, 'ortho')]) +class TestIhfftn(unittest.TestCase): + def test_static_ihfftn(self): + with stgraph(paddle.fft.ihfftn, self.place, 
self.x, self.n, self.axis, + self.norm) as y: + np.testing.assert_allclose( + scipy.fft.ihfftn(self.x, self.n, self.axis, self.norm), + y, + rtol=RTOL.get(str(self.x.dtype)), + atol=ATOL.get(str(self.x.dtype))) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), + [('test_x_complex', rand_x( + 4, complex=True), None, None, 'backward', TypeError), + ('test_n_nagative', rand_x(4), -1, None, 'backward', + ValueError), ('test_n_zero', rand_x(4), 0, None, 'backward', ValueError), + ('test_axis_out_of_range', rand_x(1), None, [0, 1], 'backward', + ValueError), ('test_norm_not_in_enum', rand_x(2), None, -1, 'random', + ValueError)]) +class TestIhfftnException(unittest.TestCase): + def test_static_ihfftn(self): + with self.assertRaises(self.expect_exception): + with stgraph(paddle.fft.ihfftn, self.place, self.x, self.n, + self.axis, self.norm) as y: + pass + + +if __name__ == '__main__': + unittest.main() + +# yapf: enable diff --git a/python/paddle/fluid/tests/unittests/fft/test_spectral_op.py b/python/paddle/fluid/tests/unittests/fft/test_spectral_op.py new file mode 100644 index 00000000000..a84092e36f6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/fft/test_spectral_op.py @@ -0,0 +1,178 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
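The tests in this file check the raw spectral ops (fft_c2c, fft_c2r, fft_r2c) against numpy-based reference implementations imported from spectral_op_np. As a rough illustration only — the real helpers live in spectral_op_np and may differ in detail — a scipy-based reference for the complex-to-complex case could look like the sketch below; fft_c2c_reference is a hypothetical name, and it assumes the op's 'normalization' attribute maps directly onto scipy.fft's norm argument ('backward', 'forward' or 'ortho') while 'forward' selects the transform direction.

import numpy as np
import scipy.fft


def fft_c2c_reference(x, axes, normalization, forward):
    # Sketch of a complex-to-complex FFT reference (assumption, not the
    # actual spectral_op_np implementation): transform over `axes`,
    # passing the normalization mode straight through to scipy.
    if forward:
        return scipy.fft.fftn(x, axes=axes, norm=normalization)
    return scipy.fft.ifftn(x, axes=axes, norm=normalization)


# Example: compare against a naive DFT on a small random input.
x = np.random.randn(4, 4) + 1j * np.random.randn(4, 4)
out = fft_c2c_reference(x, axes=(0, 1), normalization='backward', forward=True)
np.testing.assert_allclose(out, np.fft.fftn(x, axes=(0, 1)), rtol=1e-10)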
+ +from __future__ import print_function +import unittest + +import numpy as np +import paddle + +import re +import sys +from spectral_op_np import fft_c2c, fft_r2c, fft_c2r +import paddle.fluid.core as core +import paddle.fluid.dygraph as dg +import paddle.static as static +from numpy.random import random as rand +from paddle.fluid import Program, program_guard +sys.path.append("../") +from op_test import OpTest + +paddle.enable_static() + +TEST_CASE_NAME = 'test_case' + + +def parameterize(attrs, input_values=None): + + if isinstance(attrs, str): + attrs = [attrs] + input_dicts = (attrs if input_values is None else + [dict(zip(attrs, vals)) for vals in input_values]) + + def decorator(base_class): + test_class_module = sys.modules[base_class.__module__].__dict__ + for idx, input_dict in enumerate(input_dicts): + test_class_dict = dict(base_class.__dict__) + test_class_dict.update(input_dict) + + name = class_name(base_class, idx, input_dict) + + test_class_module[name] = type(name, (base_class, ), + test_class_dict) + + for method_name in list(base_class.__dict__): + if method_name.startswith("test"): + delattr(base_class, method_name) + return base_class + + return decorator + + +def to_safe_name(s): + return str(re.sub("[^a-zA-Z0-9_]+", "_", s)) + + +def class_name(cls, num, params_dict): + suffix = to_safe_name( + next((v for v in params_dict.values() if isinstance(v, str)), "")) + if TEST_CASE_NAME in params_dict: + suffix = to_safe_name(params_dict["test_case"]) + return "{}_{}{}".format(cls.__name__, num, suffix and "_" + suffix) + + +@parameterize((TEST_CASE_NAME, 'x', 'axes', 'norm', 'forward'), [ + ('test_axes_is_sqe_type', (np.random.random( + (12, 14)) + 1j * np.random.random((12, 14))).astype(np.complex128), + [0, 1], 'forward', True), ('test_axis_not_last', (np.random.random( + (4, 4, 4)) + 1j * np.random.random((4, 4, 4))).astype(np.complex128), + (0, 1), "backward", False), + ('test_norm_forward', (np.random.random((12, 14)) + 1j * np.random.random( + (12, 14))).astype(np.complex128), (0, ), "forward", + False), ('test_norm_backward', (np.random.random( + (12, 14)) + 1j * np.random.random((12, 14))).astype(np.complex128), + (0, ), "backward", True), ('test_norm_ortho', (np.random.random( + (12, 14)) + 1j * np.random.random( + (12, 14))).astype(np.complex128), (1, ), "ortho", True) +]) +class TestFFTC2COp(OpTest): + # Because framwork not support complex numerial gradient, we skip gradient check. 
+ no_need_check_grad = True + + def setUp(self): + self.op_type = "fft_c2c" + + out = fft_c2c(self.x, self.axes, self.norm, self.forward) + + self.inputs = {'X': self.x} + self.attrs = { + 'axes': self.axes, + 'normalization': self.norm, + "forward": self.forward + } + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + +@parameterize( + (TEST_CASE_NAME, 'x', 'axes', 'norm', 'forward', 'last_dim_size'), + [('test_axes_is_sqe_type', (np.random.random( + (12, 14)) + 1j * np.random.random((12, 14))).astype(np.complex128), + [0, 1], 'forward', True, 26), ('test_axis_not_last', (np.random.random( + (4, 4, 4)) + 1j * np.random.random((4, 4, 4))).astype(np.complex128), + (0, 1), "backward", False, None), + ('test_norm_forward', (np.random.random((12, 14)) + 1j * np.random.random( + (12, 14))).astype(np.complex128), (0, ), "forward", False, 22), + ('test_norm_backward', (np.random.random((12, 14)) + 1j * np.random.random( + (12, 14))).astype(np.complex128), (0, ), "backward", True, + 22), ('test_norm_ortho', (np.random.random( + (12, 14)) + 1j * np.random.random((12, 14))).astype(np.complex128), + (1, ), "ortho", True, 26)]) +class TestFFTC2ROp(OpTest): + # Because framwork not support complex numerial gradient, we skip gradient check. + no_need_check_grad = True + + def setUp(self): + self.op_type = "fft_c2r" + + out = fft_c2r(self.x, self.axes, self.norm, self.forward, + self.last_dim_size) + + self.inputs = {'X': self.x} + self.attrs = { + "axes": self.axes, + "normalization": self.norm, + "forward": self.forward, + "last_dim_size": self.last_dim_size + } + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + +@parameterize( + (TEST_CASE_NAME, 'x', 'axes', 'norm', 'forward', 'onesided'), + [('test_axes_is_sqe_type', np.random.randn(12, 14).astype(np.float64), + (0, 1), 'forward', True, + True), ('test_axis_not_last', np.random.randn(4, 4, 4).astype(np.float64), + (0, 1), "backward", False, True), + ('test_norm_forward', np.random.randn(12, 14).astype(np.float64), (0, 1), + "forward", False, False), + ('test_norm_backward', np.random.randn(12, 14).astype(np.float64), (0, ), + "backward", True, False), ('test_norm_ortho', + np.random.randn(12, 14).astype(np.float64), + (1, ), "ortho", True, False)]) +class TestFFTR2COp(OpTest): + # Because framwork not support complex numerial gradient, we skip gradient check. + no_need_check_grad = True + + def setUp(self): + self.op_type = "fft_r2c" + + out = fft_r2c(self.x, self.axes, self.norm, self.forward, self.onesided) + + self.inputs = {'X': self.x} + self.attrs = { + 'axes': self.axes, + 'normalization': self.norm, + "forward": self.forward, + 'onesided': self.onesided + } + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() diff --git a/python/paddle/fluid/tests/unittests/test_frame_op.py b/python/paddle/fluid/tests/unittests/test_frame_op.py new file mode 100644 index 00000000000..f26662dcf4f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_frame_op.py @@ -0,0 +1,140 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from numpy.lib.stride_tricks import as_strided +import paddle +import unittest + +from op_test import OpTest + + +def frame_from_librosa(x, frame_length, hop_length, axis=-1): + if axis == -1 and not x.flags["C_CONTIGUOUS"]: + x = np.ascontiguousarray(x) + elif axis == 0 and not x.flags["F_CONTIGUOUS"]: + x = np.asfortranarray(x) + + n_frames = 1 + (x.shape[axis] - frame_length) // hop_length + strides = np.asarray(x.strides) + + if axis == -1: + shape = list(x.shape)[:-1] + [frame_length, n_frames] + strides = list(strides) + [hop_length * x.itemsize] + + elif axis == 0: + shape = [n_frames, frame_length] + list(x.shape)[1:] + strides = [hop_length * x.itemsize] + list(strides) + + else: + raise ValueError("Frame axis={} must be either 0 or -1".format(axis)) + + return as_strided(x, shape=shape, strides=strides) + + +class TestFrameOp(OpTest): + def setUp(self): + self.op_type = "frame" + self.shape, self.type, self.attrs = self.initTestCase() + self.inputs = { + 'X': np.random.random(size=self.shape).astype(self.type), + } + self.outputs = { + 'Out': frame_from_librosa( + x=self.inputs['X'], **self.attrs) + } + + def initTestCase(self): + input_shape = (150, ) + input_type = 'float64' + attrs = { + 'frame_length': 50, + 'hop_length': 15, + 'axis': -1, + } + return input_shape, input_type, attrs + + def test_check_output(self): + paddle.enable_static() + self.check_output() + paddle.disable_static() + + def test_check_grad_normal(self): + paddle.enable_static() + self.check_grad(['X'], 'Out') + paddle.disable_static() + + +class TestCase1(TestFrameOp): + def initTestCase(self): + input_shape = (150, ) + input_type = 'float64' + attrs = { + 'frame_length': 50, + 'hop_length': 15, + 'axis': 0, + } + return input_shape, input_type, attrs + + +class TestCase2(TestFrameOp): + def initTestCase(self): + input_shape = (8, 150) + input_type = 'float64' + attrs = { + 'frame_length': 50, + 'hop_length': 15, + 'axis': -1, + } + return input_shape, input_type, attrs + + +class TestCase3(TestFrameOp): + def initTestCase(self): + input_shape = (150, 8) + input_type = 'float64' + attrs = { + 'frame_length': 50, + 'hop_length': 15, + 'axis': 0, + } + return input_shape, input_type, attrs + + +class TestCase4(TestFrameOp): + def initTestCase(self): + input_shape = (4, 2, 150) + input_type = 'float64' + attrs = { + 'frame_length': 50, + 'hop_length': 15, + 'axis': -1, + } + return input_shape, input_type, attrs + + +class TestCase5(TestFrameOp): + def initTestCase(self): + input_shape = (150, 4, 2) + input_type = 'float64' + attrs = { + 'frame_length': 50, + 'hop_length': 15, + 'axis': 0, + } + return input_shape, input_type, attrs + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_overlap_add_op.py b/python/paddle/fluid/tests/unittests/test_overlap_add_op.py new file mode 100644 index 00000000000..7af67d01b57 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_overlap_add_op.py @@ -0,0 +1,157 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +import unittest + +from op_test import OpTest + + +def overlap_add(x, hop_length, axis=-1): + assert axis in [0, -1], 'axis should be 0/-1.' + assert len(x.shape) >= 2, 'Input dims shoulb be >= 2.' + + squeeze_output = False + if len(x.shape) == 2: + squeeze_output = True + dim = 0 if axis == -1 else -1 + x = np.expand_dims(x, dim) # batch + + n_frames = x.shape[axis] + frame_length = x.shape[1] if axis == 0 else x.shape[-2] + + # Assure no gaps between frames. + assert 0 < hop_length <= frame_length, \ + f'hop_length should be in (0, frame_length({frame_length})], but got {hop_length}.' + + seq_length = (n_frames - 1) * hop_length + frame_length + + reshape_output = False + if len(x.shape) > 3: + reshape_output = True + if axis == 0: + target_shape = [seq_length] + list(x.shape[2:]) + x = x.reshape(n_frames, frame_length, np.product(x.shape[2:])) + else: + target_shape = list(x.shape[:-2]) + [seq_length] + x = x.reshape(np.product(x.shape[:-2]), frame_length, n_frames) + + if axis == 0: + x = x.transpose((2, 1, 0)) + + y = np.zeros(shape=[np.product(x.shape[:-2]), seq_length], dtype=x.dtype) + for i in range(x.shape[0]): + for frame in range(x.shape[-1]): + sample = frame * hop_length + y[i, sample:sample + frame_length] += x[i, :, frame] + + if axis == 0: + y = y.transpose((1, 0)) + + if reshape_output: + y = y.reshape(target_shape) + + if squeeze_output: + y = y.squeeze(-1) if axis == 0 else y.squeeze(0) + + return y + + +class TestOverlapAddOp(OpTest): + def setUp(self): + self.op_type = "overlap_add" + self.shape, self.type, self.attrs = self.initTestCase() + self.inputs = { + 'X': np.random.random(size=self.shape).astype(self.type), + } + self.outputs = {'Out': overlap_add(x=self.inputs['X'], **self.attrs)} + + def initTestCase(self): + input_shape = (50, 3) + input_type = 'float64' + attrs = { + 'hop_length': 4, + 'axis': -1, + } + return input_shape, input_type, attrs + + def test_check_output(self): + paddle.enable_static() + self.check_output() + paddle.disable_static() + + def test_check_grad_normal(self): + paddle.enable_static() + self.check_grad(['X'], 'Out') + paddle.disable_static() + + +class TestCase1(TestOverlapAddOp): + def initTestCase(self): + input_shape = (3, 50) + input_type = 'float64' + attrs = { + 'hop_length': 4, + 'axis': 0, + } + return input_shape, input_type, attrs + + +class TestCase2(TestOverlapAddOp): + def initTestCase(self): + input_shape = (2, 40, 5) + input_type = 'float64' + attrs = { + 'hop_length': 10, + 'axis': -1, + } + return input_shape, input_type, attrs + + +class TestCase3(TestOverlapAddOp): + def initTestCase(self): + input_shape = (5, 40, 2) + input_type = 'float64' + attrs = { + 'hop_length': 10, + 'axis': 0, + } + return input_shape, input_type, attrs + + +class TestCase4(TestOverlapAddOp): + def initTestCase(self): + input_shape = (3, 5, 12, 8) + input_type = 'float64' + attrs = { + 'hop_length': 5, + 'axis': -1, + } + return input_shape, input_type, 
attrs + + +class TestCase5(TestOverlapAddOp): + def initTestCase(self): + input_shape = (8, 12, 5, 3) + input_type = 'float64' + attrs = { + 'hop_length': 5, + 'axis': 0, + } + return input_shape, input_type, attrs + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_signal.py b/python/paddle/fluid/tests/unittests/test_signal.py new file mode 100644 index 00000000000..a109a5aa5d1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_signal.py @@ -0,0 +1,1005 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import sys +import unittest + +import numpy as np +from numpy import fft +from numpy.lib.stride_tricks import as_strided +import paddle +import scipy.signal + +paddle.set_default_dtype('float64') + +DEVICES = [paddle.CPUPlace()] +if paddle.is_compiled_with_cuda(): + DEVICES.append(paddle.CUDAPlace(0)) +TEST_CASE_NAME = 'test_case' + +# Constrain STFT block sizes to 256 KB +MAX_MEM_BLOCK = 2**8 * 2**10 + + +def fix_length(data, size, axis=-1, **kwargs): + kwargs.setdefault("mode", "constant") + + n = data.shape[axis] + + if n > size: + slices = [slice(None)] * data.ndim + slices[axis] = slice(0, size) + return data[tuple(slices)] + + elif n < size: + lengths = [(0, 0)] * data.ndim + lengths[axis] = (0, size - n) + return np.pad(data, lengths, **kwargs) + + return data + + +def tiny(x): + # Make sure we have an array view + x = np.asarray(x) + + # Only floating types generate a tiny + if np.issubdtype(x.dtype, np.floating) or np.issubdtype(x.dtype, + np.complexfloating): + dtype = x.dtype + else: + dtype = np.float32 + + return np.finfo(dtype).tiny + + +def normalize(S, norm=np.inf, axis=0, threshold=None, fill=None): + # Avoid div-by-zero + if threshold is None: + threshold = tiny(S) + + elif threshold <= 0: + raise Exception("threshold={} must be strictly " + "positive".format(threshold)) + + if fill not in [None, False, True]: + raise Exception("fill={} must be None or boolean".format(fill)) + + if not np.all(np.isfinite(S)): + raise Exception("Input must be finite") + + # All norms only depend on magnitude, let's do that first + mag = np.abs(S).astype(np.float) + + # For max/min norms, filling with 1 works + fill_norm = 1 + + if norm == np.inf: + length = np.max(mag, axis=axis, keepdims=True) + + elif norm == -np.inf: + length = np.min(mag, axis=axis, keepdims=True) + + elif norm == 0: + if fill is True: + raise Exception("Cannot normalize with norm=0 and fill=True") + + length = np.sum(mag > 0, axis=axis, keepdims=True, dtype=mag.dtype) + + elif np.issubdtype(type(norm), np.number) and norm > 0: + length = np.sum(mag**norm, axis=axis, keepdims=True)**(1.0 / norm) + + if axis is None: + fill_norm = mag.size**(-1.0 / norm) + else: + fill_norm = mag.shape[axis]**(-1.0 / norm) + + elif norm is None: + return S + + else: + raise Exception("Unsupported norm: {}".format(repr(norm))) + + # indices where norm is below the threshold + small_idx = length < 
threshold + + Snorm = np.empty_like(S) + if fill is None: + # Leave small indices un-normalized + length[small_idx] = 1.0 + Snorm[:] = S / length + + elif fill: + # If we have a non-zero fill value, we locate those entries by + # doing a nan-divide. + # If S was finite, then length is finite (except for small positions) + length[small_idx] = np.nan + Snorm[:] = S / length + Snorm[np.isnan(Snorm)] = fill_norm + else: + # Set small values to zero by doing an inf-divide. + # This is safe (by IEEE-754) as long as S is finite. + length[small_idx] = np.inf + Snorm[:] = S / length + + return Snorm + + +def __window_ss_fill(x, win_sq, n_frames, hop_length): # pragma: no cover + """Helper function for window sum-square calculation.""" + + n = len(x) + n_fft = len(win_sq) + for i in range(n_frames): + sample = i * hop_length + x[sample:min(n, sample + n_fft)] += win_sq[:max(0, + min(n_fft, n - sample))] + + +def window_sumsquare( + window, + n_frames, + hop_length=512, + win_length=None, + n_fft=2048, + dtype=np.float32, + norm=None, ): + if win_length is None: + win_length = n_fft + + n = n_fft + hop_length * (n_frames - 1) + x = np.zeros(n, dtype=dtype) + + # Compute the squared window at the desired length + win_sq = get_window(window, win_length) + win_sq = normalize(win_sq, norm=norm)**2 + win_sq = pad_center(win_sq, n_fft) + + # Fill the envelope + __window_ss_fill(x, win_sq, n_frames, hop_length) + + return x + + +def dtype_c2r(d, default=np.float32): + mapping = { + np.dtype(np.complex64): np.float32, + np.dtype(np.complex128): np.float64, + } + + # If we're given a real type already, return it + dt = np.dtype(d) + if dt.kind == "f": + return dt + + # Otherwise, try to map the dtype. + # If no match is found, return the default. + return np.dtype(mapping.get(np.dtype(d), default)) + + +def dtype_r2c(d, default=np.complex64): + mapping = { + np.dtype(np.float32): np.complex64, + np.dtype(np.float64): np.complex128, + } + + # If we're given a complex type already, return it + dt = np.dtype(d) + if dt.kind == "c": + return dt + + # Otherwise, try to map the dtype. + # If no match is found, return the default. + return np.dtype(mapping.get(dt, default)) + + +def frame(x, frame_length, hop_length, axis=-1): + if not isinstance(x, np.ndarray): + raise Exception("Input must be of type numpy.ndarray, " + "given type(x)={}".format(type(x))) + + if x.shape[axis] < frame_length: + raise Exception("Input is too short (n={:d})" + " for frame_length={:d}".format(x.shape[axis], + frame_length)) + + if hop_length < 1: + raise Exception("Invalid hop_length: {:d}".format(hop_length)) + + if axis == -1 and not x.flags["F_CONTIGUOUS"]: + print("librosa.util.frame called with axis={} " + "on a non-contiguous input. This will result in a copy.".format( + axis)) + x = np.asfortranarray(x) + elif axis == 0 and not x.flags["C_CONTIGUOUS"]: + print("librosa.util.frame called with axis={} " + "on a non-contiguous input. 
This will result in a copy.".format( + axis)) + x = np.ascontiguousarray(x) + + n_frames = 1 + (x.shape[axis] - frame_length) // hop_length + strides = np.asarray(x.strides) + + new_stride = np.prod(strides[strides > 0] // x.itemsize) * x.itemsize + + if axis == -1: + shape = list(x.shape)[:-1] + [frame_length, n_frames] + strides = list(strides) + [hop_length * new_stride] + + elif axis == 0: + shape = [n_frames, frame_length] + list(x.shape)[1:] + strides = [hop_length * new_stride] + list(strides) + + else: + raise Exception("Frame axis={} must be either 0 or -1".format(axis)) + + return as_strided(x, shape=shape, strides=strides) + + +def pad_center(data, size, axis=-1, **kwargs): + kwargs.setdefault("mode", "constant") + + n = data.shape[axis] + + lpad = int((size - n) // 2) + + lengths = [(0, 0)] * data.ndim + lengths[axis] = (lpad, int(size - n - lpad)) + + if lpad < 0: + raise Exception(("Target size ({:d}) must be " + "at least input size ({:d})").format(size, n)) + + return np.pad(data, lengths, **kwargs) + + +def get_window(window, Nx, fftbins=True): + if callable(window): + return window(Nx) + + elif isinstance(window, (str, tuple)) or np.isscalar(window): + # TODO: if we add custom window functions in librosa, call them here + + return scipy.signal.get_window(window, Nx, fftbins=fftbins) + + elif isinstance(window, (np.ndarray, list)): + if len(window) == Nx: + return np.asarray(window) + + raise Exception("Window size mismatch: " + "{:d} != {:d}".format(len(window), Nx)) + else: + raise Exception("Invalid window specification: {}".format(window)) + + +def __overlap_add(y, ytmp, hop_length): + # numba-accelerated overlap add for inverse stft + # y is the pre-allocated output buffer + # ytmp is the windowed inverse-stft frames + # hop_length is the hop-length of the STFT analysis + + n_fft = ytmp.shape[0] + for frame in range(ytmp.shape[1]): + sample = frame * hop_length + y[sample:(sample + n_fft)] += ytmp[:, frame] + + +def stft(x, + n_fft=2048, + hop_length=None, + win_length=None, + window="hann", + center=True, + pad_mode="reflect"): + y = x + input_rank = len(y.shape) + if input_rank == 2: + assert y.shape[0] == 1 # Only 1d input supported in librosa + y = y.squeeze(0) + dtype = None + + # By default, use the entire frame + if win_length is None: + win_length = n_fft + + # Set the default hop, if it's not already specified + if hop_length is None: + hop_length = int(win_length // 4) + + fft_window = get_window(window, win_length, fftbins=True) + + # Pad the window out to n_fft size + fft_window = pad_center(fft_window, n_fft) + + # Reshape so that the window can be broadcast + fft_window = fft_window.reshape((-1, 1)) + + # Pad the time series so that frames are centered + if center: + if n_fft > y.shape[-1]: + print("n_fft={} is too small for input signal of length={}".format( + n_fft, y.shape[-1])) + + y = np.pad(y, int(n_fft // 2), mode=pad_mode) + + elif n_fft > y.shape[-1]: + raise Exception("n_fft={} is too large for input signal of length={}". + format(n_fft, y.shape[-1])) + + # Window the time series. + y_frames = frame(y, frame_length=n_fft, hop_length=hop_length) + + if dtype is None: + dtype = dtype_r2c(y.dtype) + + # Pre-allocate the STFT matrix + stft_matrix = np.empty( + (int(1 + n_fft // 2), y_frames.shape[1]), dtype=dtype, order="F") + + # how many columns can we fit within MAX_MEM_BLOCK? 
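+    # Each STFT column holds ``1 + n_fft // 2`` complex samples of
+    # ``stft_matrix.itemsize`` bytes, so the quotient below is the number of
+    # frames transformed per rfft call within one 256 KB block; clamping to
+    # at least one column keeps the loop below progressing for very large n_fft.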
+ n_columns = MAX_MEM_BLOCK // (stft_matrix.shape[0] * stft_matrix.itemsize) + n_columns = max(n_columns, 1) + + for bl_s in range(0, stft_matrix.shape[1], n_columns): + bl_t = min(bl_s + n_columns, stft_matrix.shape[1]) + + stft_matrix[:, bl_s:bl_t] = fft.rfft( + fft_window * y_frames[:, bl_s:bl_t], axis=0) + + if input_rank == 2: + stft_matrix = np.expand_dims(stft_matrix, 0) + + return stft_matrix + + +def istft( + x, + hop_length=None, + win_length=None, + window="hann", + center=True, + length=None, ): + + stft_matrix = x + input_rank = len(stft_matrix.shape) + if input_rank == 3: + assert stft_matrix.shape[0] == 1 # Only 2d input supported in librosa + stft_matrix = stft_matrix.squeeze(0) + dtype = None + + n_fft = 2 * (stft_matrix.shape[0] - 1) + + # By default, use the entire frame + if win_length is None: + win_length = n_fft + + # Set the default hop, if it's not already specified + if hop_length is None: + hop_length = int(win_length // 4) + + ifft_window = get_window(window, win_length, fftbins=True) + + # Pad out to match n_fft, and add a broadcasting axis + ifft_window = pad_center(ifft_window, n_fft)[:, np.newaxis] + + # For efficiency, trim STFT frames according to signal length if available + if length: + if center: + padded_length = length + int(n_fft) + else: + padded_length = length + n_frames = min(stft_matrix.shape[1], + int(np.ceil(padded_length / hop_length))) + else: + n_frames = stft_matrix.shape[1] + + expected_signal_len = n_fft + hop_length * (n_frames - 1) + + if dtype is None: + dtype = dtype_c2r(stft_matrix.dtype) + + y = np.zeros(expected_signal_len, dtype=dtype) + + n_columns = MAX_MEM_BLOCK // (stft_matrix.shape[0] * stft_matrix.itemsize) + n_columns = min(n_columns, 1) + + frame = 0 + for bl_s in range(0, n_frames, n_columns): + bl_t = min(bl_s + n_columns, n_frames) + + # invert the block and apply the window function + ytmp = ifft_window * fft.irfft(stft_matrix[:, bl_s:bl_t], axis=0) + + # Overlap-add the istft block starting at the i'th frame + __overlap_add(y[frame * hop_length:], ytmp, hop_length) + + frame += bl_t - bl_s + + # Normalize by sum of squared window + ifft_window_sum = window_sumsquare( + window, + n_frames, + win_length=win_length, + n_fft=n_fft, + hop_length=hop_length, + dtype=dtype, ) + + approx_nonzero_indices = ifft_window_sum > tiny(ifft_window_sum) + y[approx_nonzero_indices] /= ifft_window_sum[approx_nonzero_indices] + + if length is None: + # If we don't need to control length, just do the usual center trimming + # to eliminate padded data + if center: + y = y[int(n_fft // 2):-int(n_fft // 2)] + else: + if center: + # If we're centering, crop off the first n_fft//2 samples + # and then trim/pad to the target length. 
+ # We don't trim the end here, so that if the signal is zero-padded + # to a longer duration, the decay is smooth by windowing + start = int(n_fft // 2) + else: + # If we're not centering, start at 0 and trim/pad as necessary + start = 0 + + y = fix_length(y[start:], length) + + if input_rank == 3: + y = np.expand_dims(y, 0) + + return y + + +def frame_for_api_test(x, frame_length, hop_length, axis=-1): + if axis == -1 and not x.flags["C_CONTIGUOUS"]: + x = np.ascontiguousarray(x) + elif axis == 0 and not x.flags["F_CONTIGUOUS"]: + x = np.asfortranarray(x) + + n_frames = 1 + (x.shape[axis] - frame_length) // hop_length + strides = np.asarray(x.strides) + + if axis == -1: + shape = list(x.shape)[:-1] + [frame_length, n_frames] + strides = list(strides) + [hop_length * x.itemsize] + + elif axis == 0: + shape = [n_frames, frame_length] + list(x.shape)[1:] + strides = [hop_length * x.itemsize] + list(strides) + + else: + raise ValueError("Frame axis={} must be either 0 or -1".format(axis)) + + return as_strided(x, shape=shape, strides=strides) + + +def overlap_add_for_api_test(x, hop_length, axis=-1): + assert axis in [0, -1], 'axis should be 0/-1.' + assert len(x.shape) >= 2, 'Input dims shoulb be >= 2.' + + squeeze_output = False + if len(x.shape) == 2: + squeeze_output = True + dim = 0 if axis == -1 else -1 + x = np.expand_dims(x, dim) # batch + + n_frames = x.shape[axis] + frame_length = x.shape[1] if axis == 0 else x.shape[-2] + + # Assure no gaps between frames. + assert 0 < hop_length <= frame_length, \ + f'hop_length should be in (0, frame_length({frame_length})], but got {hop_length}.' + + seq_length = (n_frames - 1) * hop_length + frame_length + + reshape_output = False + if len(x.shape) > 3: + reshape_output = True + if axis == 0: + target_shape = [seq_length] + list(x.shape[2:]) + x = x.reshape(n_frames, frame_length, np.product(x.shape[2:])) + else: + target_shape = list(x.shape[:-2]) + [seq_length] + x = x.reshape(np.product(x.shape[:-2]), frame_length, n_frames) + + if axis == 0: + x = x.transpose((2, 1, 0)) + + y = np.zeros(shape=[np.product(x.shape[:-2]), seq_length], dtype=x.dtype) + for i in range(x.shape[0]): + for frame in range(x.shape[-1]): + sample = frame * hop_length + y[i, sample:sample + frame_length] += x[i, :, frame] + + if axis == 0: + y = y.transpose((1, 0)) + + if reshape_output: + y = y.reshape(target_shape) + + if squeeze_output: + y = y.squeeze(-1) if axis == 0 else y.squeeze(0) + + return y + + +def place(devices, key='place'): + def decorate(cls): + module = sys.modules[cls.__module__].__dict__ + raw_classes = { + k: v + for k, v in module.items() if k.startswith(cls.__name__) + } + + for raw_name, raw_cls in raw_classes.items(): + for d in devices: + test_cls = dict(raw_cls.__dict__) + test_cls.update({key: d}) + new_name = raw_name + '.' 
+ d.__class__.__name__ + module[new_name] = type(new_name, (raw_cls, ), test_cls) + del module[raw_name] + return cls + + return decorate + + +def setUpModule(): + global rtol + global atol + # All test case will use float64 for compare percision, refs: + # https://github.com/PaddlePaddle/Paddle/wiki/Upgrade-OP-Precision-to-Float64 + rtol = { + 'float32': 1e-06, + 'float64': 1e-7, + 'complex64': 1e-06, + 'complex128': 1e-7, + } + atol = { + 'float32': 0.0, + 'float64': 0.0, + 'complex64': 0.0, + 'complex128': 0.0, + } + + +def tearDownModule(): + pass + + +def rand_x(dims=1, + dtype='float64', + min_dim_len=1, + max_dim_len=10, + shape=None, + complex=False): + + if shape is None: + shape = [ + np.random.randint(min_dim_len, max_dim_len) for i in range(dims) + ] + if complex: + return np.random.randn(*shape).astype(dtype) + 1.j * np.random.randn( + *shape).astype(dtype) + else: + return np.random.randn(*shape).astype(dtype) + + +def parameterize(attrs, input_values=None): + + if isinstance(attrs, str): + attrs = [attrs] + input_dicts = (attrs if input_values is None else + [dict(zip(attrs, vals)) for vals in input_values]) + + def decorator(base_class): + test_class_module = sys.modules[base_class.__module__].__dict__ + for idx, input_dict in enumerate(input_dicts): + test_class_dict = dict(base_class.__dict__) + test_class_dict.update(input_dict) + + name = class_name(base_class, idx, input_dict) + + test_class_module[name] = type(name, (base_class, ), + test_class_dict) + + for method_name in list(base_class.__dict__): + if method_name.startswith("test"): + delattr(base_class, method_name) + return base_class + + return decorator + + +def class_name(cls, num, params_dict): + suffix = to_safe_name( + next((v for v in params_dict.values() if isinstance(v, str)), "")) + if TEST_CASE_NAME in params_dict: + suffix = to_safe_name(params_dict["test_case"]) + return "{}_{}{}".format(cls.__name__, num, suffix and "_" + suffix) + + +def to_safe_name(s): + return str(re.sub("[^a-zA-Z0-9_]+", "_", s)) + + +# yapf: disable +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'frame_length', 'hop_length', 'axis'), + [ + ('test_1d_input1', rand_x(1, np.float64, shape=[150]), 50, 15, 0), + ('test_1d_input2', rand_x(1, np.float64, shape=[150]), 50, 15, -1), + ('test_2d_input1', rand_x(2, np.float64, shape=[150, 8]), 50, 15, 0), + ('test_2d_input2', rand_x(2, np.float64, shape=[8, 150]), 50, 15, -1), + ('test_3d_input1', rand_x(3, np.float64, shape=[150, 4, 2]), 50, 15, 0), + ('test_3d_input2', rand_x(3, np.float64, shape=[4, 2, 150]), 50, 15, -1), + ]) +class TestFrame(unittest.TestCase): + def test_frame(self): + self.assertTrue( + np.allclose( + frame_for_api_test(self.x, self.frame_length, self.hop_length, self.axis), + paddle.tensor.signal.frame( + paddle.to_tensor(self.x), + self.frame_length, + self.hop_length, + self.axis), + rtol=rtol.get(str(self.x.dtype)), + atol=atol.get(str(self.x.dtype)))) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'frame_length', 'hop_length', 'axis'), + [ + ('test_1d_input1', rand_x(1, np.float64, shape=[150]), 50, 15, 0), + ('test_1d_input2', rand_x(1, np.float64, shape=[150]), 50, 15, -1), + ('test_2d_input1', rand_x(2, np.float64, shape=[150, 8]), 50, 15, 0), + ('test_2d_input2', rand_x(2, np.float64, shape=[8, 150]), 50, 15, -1), + ('test_3d_input1', rand_x(3, np.float64, shape=[150, 4, 2]), 50, 15, 0), + ('test_3d_input2', rand_x(3, np.float64, shape=[4, 2, 150]), 50, 15, -1), + ]) +class TestFrameStatic(unittest.TestCase): + def 
test_frame_static(self): + paddle.enable_static() + mp, sp = paddle.static.Program(), paddle.static.Program() + with paddle.static.program_guard(mp, sp): + input = paddle.static.data('input', self.x.shape, dtype=self.x.dtype) + output = paddle.tensor.signal.frame( + input, + self.frame_length, + self.hop_length, + self.axis), + exe = paddle.static.Executor(self.place) + exe.run(sp) + [output] = exe.run(mp, feed={'input': self.x}, fetch_list=[output]) + paddle.disable_static() + + self.assertTrue( + np.allclose( + frame_for_api_test(self.x, self.frame_length, self.hop_length, self.axis), + output, + rtol=rtol.get(str(self.x.dtype)), + atol=atol.get(str(self.x.dtype)))) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'frame_length', 'hop_length', 'axis', 'expect_exception'), + [ + ('test_axis', rand_x(1, np.float64, shape=[150]), 50, 15, 2, ValueError), + ('test_hop_length', rand_x(1, np.float64, shape=[150]), 50, 0, -1, ValueError), + ('test_frame_length1', rand_x(2, np.float64, shape=[150, 8]), 0, 15, 0, ValueError), + ('test_frame_length2', rand_x(2, np.float64, shape=[150, 8]), 151, 15, 0, ValueError), + ]) +class TestFrameException(unittest.TestCase): + def test_frame(self): + with self.assertRaises(self.expect_exception): + paddle.tensor.signal.frame( + paddle.to_tensor(self.x), + self.frame_length, + self.hop_length, + self.axis) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'hop_length', 'axis'), + [ + ('test_2d_input1', rand_x(2, np.float64, shape=[3, 50]), 4, 0), + ('test_2d_input2', rand_x(2, np.float64, shape=[50, 3]), 4, -1), + ('test_3d_input1', rand_x(3, np.float64, shape=[5, 40, 2]), 10, 0), + ('test_3d_input2', rand_x(3, np.float64, shape=[2, 40, 5]), 10, -1), + ('test_4d_input1', rand_x(4, np.float64, shape=[8, 12, 5, 3]), 5, 0), + ('test_4d_input2', rand_x(4, np.float64, shape=[3, 5, 12, 8]), 5, -1), + ]) +class TestOverlapAdd(unittest.TestCase): + def test_overlap_add(self): + self.assertTrue( + np.allclose( + overlap_add_for_api_test(self.x, self.hop_length, self.axis), + paddle.tensor.signal.overlap_add( + paddle.to_tensor(self.x), + self.hop_length, + self.axis), + rtol=rtol.get(str(self.x.dtype)), + atol=atol.get(str(self.x.dtype)))) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'hop_length', 'axis'), + [ + ('test_2d_input1', rand_x(2, np.float64, shape=[3, 50]), 4, 0), + ('test_2d_input2', rand_x(2, np.float64, shape=[50, 3]), 4, -1), + ('test_3d_input1', rand_x(3, np.float64, shape=[5, 40, 2]), 10, 0), + ('test_3d_input2', rand_x(3, np.float64, shape=[2, 40, 5]), 10, -1), + ('test_4d_input1', rand_x(4, np.float64, shape=[8, 12, 5, 3]), 5, 0), + ('test_4d_input2', rand_x(4, np.float64, shape=[3, 5, 12, 8]), 5, -1), + ]) +class TestOverlapAddStatic(unittest.TestCase): + def test_overlap_add_static(self): + paddle.enable_static() + mp, sp = paddle.static.Program(), paddle.static.Program() + with paddle.static.program_guard(mp, sp): + input = paddle.static.data('input', self.x.shape, dtype=self.x.dtype) + output = paddle.tensor.signal.overlap_add( + input, + self.hop_length, + self.axis), + exe = paddle.static.Executor(self.place) + exe.run(sp) + [output] = exe.run(mp, feed={'input': self.x}, fetch_list=[output]) + paddle.disable_static() + + self.assertTrue( + np.allclose( + overlap_add_for_api_test(self.x, self.hop_length, self.axis), + output, + rtol=rtol.get(str(self.x.dtype)), + atol=atol.get(str(self.x.dtype)))) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'hop_length', 'axis', 'expect_exception'), + [ + 
('test_axis', rand_x(2, np.float64, shape=[3, 50]), 4, 2, ValueError), + ('test_hop_length', rand_x(2, np.float64, shape=[50, 3]), -1, -1, ValueError), + ]) +class TestOverlapAddException(unittest.TestCase): + def test_overlap_add(self): + with self.assertRaises(self.expect_exception): + paddle.tensor.signal.overlap_add( + paddle.to_tensor(self.x), + self.hop_length, + self.axis) + + +# ================= STFT +# common args +# x +# n_fft, +# hop_length=None, +# win_length=None, +# window=None, +# center=True, +# pad_mode='reflect', + +# paddle only +# normalized=False, +# onesided=True, + +# ================= ISTFT +# common args +# x, +# hop_length=None, +# win_length=None, +# window=None, +# center=True, +# length=None, + +# paddle only +# n_fft, +# normalized=False, +# onesided=True, +# return_complex=False, + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n_fft', 'hop_length', 'win_length', 'window', 'center', 'pad_mode', 'normalized', 'onesided'), + [ + ('test_1d_input', rand_x(1, np.float64, shape=[160000]), + 512, None, None, get_window('hann', 512), True, 'reflect', False, True), + ('test_2d_input', rand_x(2, np.float64, shape=[1, 160000]), + 512, None, None, get_window('hann', 512), True, 'reflect', False, True), + ('test_hop_length', rand_x(2, np.float64, shape=[1, 160000]), + 512, 255, None, get_window('hann', 512), True, 'reflect', False, True), + ('test_win_length', rand_x(2, np.float64, shape=[1, 160000]), + 512, 255, 499, get_window('hann', 499), True, 'reflect', False, True), + ('test_window', rand_x(2, np.float64, shape=[1, 160000]), + 512, None, None, None, True, 'reflect', False, True), + ('test_center', rand_x(2, np.float64, shape=[1, 160000]), + 512, None, None, None, False, 'reflect', False, True), + ]) +class TestStft(unittest.TestCase): + def test_stft(self): + if self.window is None: + win_p = None + win_l = 'boxcar' # rectangular window + else: + win_p = paddle.to_tensor(self.window) + win_l = self.window + + self.assertTrue( + np.allclose( + stft(self.x, self.n_fft, self.hop_length, self.win_length, win_l, self.center, self.pad_mode), + paddle.tensor.signal.stft( + paddle.to_tensor(self.x), + self.n_fft, + self.hop_length, + self.win_length, + win_p, + self.center, + self.pad_mode, + self.normalized, + self.onesided), + rtol=rtol.get(str(self.x.dtype)), + atol=atol.get(str(self.x.dtype)))) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n_fft', 'hop_length', 'win_length', 'window', 'center', 'pad_mode', 'normalized', 'onesided', 'expect_exception'), + [ + ('test_dims', rand_x(1, np.float64, shape=[1, 2, 3]), + 512, None, None, None, True, 'reflect', False, True, AssertionError), + ('test_hop_length', rand_x(1, np.float64, shape=[16000]), + 512, 0, None, None, True, 'reflect', False, True, AssertionError), + ('test_nfft1', rand_x(1, np.float64, shape=[16000]), + 0, None, None, None, True, 'reflect', False, True, AssertionError), + ('test_nfft2', rand_x(1, np.float64, shape=[16000]), + 16001, None, None, None, True, 'reflect', False, True, AssertionError), + ('test_win_length', rand_x(1, np.float64, shape=[16000]), + 512, None, 0, None, True, 'reflect', False, True, AssertionError), + ('test_win_length', rand_x(1, np.float64, shape=[16000]), + 512, None, 513, None, True, 'reflect', False, True, AssertionError), + ('test_pad_mode', rand_x(1, np.float64, shape=[16000]), + 512, None, None, None, True, 'nonsense', False, True, AssertionError), + ('test_complex_onesided', rand_x(1, np.float64, shape=[16000], complex=True), + 512, None, None, 
None, False, 'reflect', False, True, AssertionError), + ]) +class TestStftException(unittest.TestCase): + def test_stft(self): + if self.window is None: + win_p = None + else: + win_p = paddle.to_tensor(self.window) + + with self.assertRaises(self.expect_exception): + paddle.tensor.signal.stft( + paddle.to_tensor(self.x), + self.n_fft, + self.hop_length, + self.win_length, + win_p, + self.center, + self.pad_mode, + self.normalized, + self.onesided), + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n_fft', 'hop_length', 'win_length', 'window', 'center', 'normalized', 'onesided', 'length', 'return_complex'), + [ + ('test_2d_input', rand_x(2, np.float64, shape=[257, 471], complex=True), + 512, None, None, get_window('hann', 512), True, False, True, None, False), + ('test_3d_input', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), + 512, None, None, get_window('hann', 512), True, False, True, None, False), + ('test_hop_length', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), + 512, 99, None, get_window('hann', 512), True, False, True, None, False), + ('test_win_length', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), + 512, 99, 299, get_window('hann', 299), True, False, True, None, False), + ('test_window', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), + 512, None, None, None, True, False, True, None, False), + ('test_center', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), + 512, None, None, None, False, False, True, None, False), + ('test_length', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), + 512, None, None, None, False, False, True, 1888, False), + ]) +class TestIstft(unittest.TestCase): + def test_istft(self): + if self.window is None: + win_p = None + win_l = 'boxcar' # rectangular window + else: + win_p = paddle.to_tensor(self.window) + win_l = self.window + + self.assertTrue( + np.allclose( + istft(self.x, self.hop_length, self.win_length, win_l, self.center, self.length), + paddle.tensor.signal.istft( + paddle.to_tensor(self.x), + self.n_fft, + self.hop_length, + self.win_length, + win_p, + self.center, + self.normalized, + self.onesided, + self.length, + self.return_complex), + rtol=rtol.get(str(self.x.dtype)), + atol=atol.get(str(self.x.dtype)))) + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'n_fft', 'hop_length', 'win_length', 'window', 'center', 'normalized', 'onesided', 'length', 'return_complex', 'expect_exception'), + [ + ('test_dims', rand_x(4, np.float64, shape=[1, 2, 3, 4], complex=True), + 512, None, None, get_window('hann', 512), True, False, True, None, False, AssertionError), + ('test_n_fft', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), + 257, None, None, get_window('hann', 512), True, False, True, None, False, AssertionError), + ('test_hop_length1', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), + 512, 0, None, get_window('hann', 512), True, False, True, None, False, AssertionError), + ('test_hop_length2', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), + 512, 513, None, get_window('hann', 512), True, False, True, None, False, AssertionError), + ('test_win_length1', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), + 512, None, 0, get_window('hann', 512), True, False, True, None, False, AssertionError), + ('test_win_length2', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), + 512, None, 513, get_window('hann', 512), True, False, True, None, False, AssertionError), + ('test_onesided1', rand_x(3, np.float64, shape=[1, 257, 471], 
complex=True), + 20, None, None, get_window('hann', 512), True, False, True, None, False, AssertionError), + ('test_onesided2', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), + 256, None, None, None, True, False, False, None, False, AssertionError), + ('test_window', rand_x(3, np.float64, shape=[1, 512, 471], complex=True), + 512, None, 511, get_window('hann', 512), True, False, False, None, False, AssertionError), + ('test_return_complex1', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), + 512, None, None, get_window('hann', 512), True, False, True, None, True, AssertionError), + ('test_return_complex2', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), + 512, None, None, rand_x(1, np.float64, shape=[512], complex=True), True, False, True, None, False, AssertionError), + ('test_NOLA', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), + 512, 512, None, get_window('hann', 512), True, False, True, None, False, ValueError), + ]) +class TestIstftException(unittest.TestCase): + def test_istft(self): + if self.window is None: + win_p = None + else: + win_p = paddle.to_tensor(self.window) + + with self.assertRaises(self.expect_exception): + paddle.tensor.signal.istft( + paddle.to_tensor(self.x), + self.n_fft, + self.hop_length, + self.win_length, + win_p, + self.center, + self.normalized, + self.onesided, + self.length, + self.return_complex), + + +# yapf: enable + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 052ffb12d47..c0f7d88d3bf 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -216,6 +216,8 @@ from .array import array_write # noqa: F401 from .array import create_array # noqa: F401 from .einsum import einsum # noqa: F401 +from . import fft +from . import signal #this list used in math_op_patch.py for _binary_creator_ tensor_method_func = [ #noqa diff --git a/python/paddle/tensor/attribute.py b/python/paddle/tensor/attribute.py index 3a86b09c5c3..8d8c2a83de1 100644 --- a/python/paddle/tensor/attribute.py +++ b/python/paddle/tensor/attribute.py @@ -35,6 +35,41 @@ def _complex_to_real_dtype(dtype): return dtype +def _real_to_complex_dtype(dtype): + if dtype == core.VarDesc.VarType.FP32: + return core.VarDesc.VarType.COMPLEX64 + elif dtype == core.VarDesc.VarType.FP64: + return core.VarDesc.VarType.COMPLEX128 + else: + return dtype + + +def is_complex(x): + dtype = x.dtype + is_complex_dtype = (dtype == core.VarDesc.VarType.COMPLEX64 or + dtype == core.VarDesc.VarType.COMPLEX128) + return is_complex_dtype + + +def is_floating_point(x): + dtype = x.dtype + is_fp_dtype = (dtype == core.VarDesc.VarType.FP32 or + dtype == core.VarDesc.VarType.FP64 or + dtype == core.VarDesc.VarType.FP16 or + dtype == core.VarDesc.VarType.BF16) + return is_fp_dtype + + +def is_interger(x): + dtype = x.dtype + is_int_dtype = (dtype == core.VarDesc.VarType.UINT8 or + dtype == core.VarDesc.VarType.INT8 or + dtype == core.VarDesc.VarType.INT16 or + dtype == core.VarDesc.VarType.INT32 or + dtype == core.VarDesc.VarType.INT64) + return is_int_dtype + + def real(x, name=None): """ Returns a new tensor containing real values of the input tensor. diff --git a/python/paddle/tensor/fft.py b/python/paddle/tensor/fft.py new file mode 100644 index 00000000000..98ca858c0eb --- /dev/null +++ b/python/paddle/tensor/fft.py @@ -0,0 +1,1609 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Sequence +import numpy as np +import paddle +from .attribute import is_complex, is_floating_point, is_interger, _real_to_complex_dtype, _complex_to_real_dtype +from ..fluid.framework import in_dygraph_mode +from .. import _C_ops +from ..fluid.data_feeder import check_variable_and_dtype +from ..fluid.layer_helper import LayerHelper + +__all__ = [ + 'fft', + 'fft2', + 'fftn', + 'ifft', + 'ifft2', + 'ifftn', + 'rfft', + 'rfft2', + 'rfftn', + 'irfft', + 'irfft2', + 'irfftn', + 'hfft', + 'hfft2', + 'hfftn', + 'ihfft', + 'ihfft2', + 'ihfftn', + 'fftfreq', + 'rfftfreq', + 'fftshift', + 'ifftshift', +] + + +def _check_normalization(norm): + if norm not in ['forward', 'backward', 'ortho']: + raise ValueError( + "Unexpected norm: {}. Norm should be forward, backward or ortho". + format(norm)) + + +def _check_fft_n(n): + if not isinstance(n, int): + raise ValueError( + "Invalid FFT argument n({}), it shoule be an integer.".format(n)) + if n <= 0: + raise ValueError( + "Invalid FFT argument n({}), it should be positive.".format(n)) + + +def _check_fft_shape(x, s): + ndim = x.ndim + if not isinstance(s, Sequence): + raise ValueError( + "Invaid FFT argument s({}), it should be a sequence of integers.") + + if len(s) > ndim: + raise ValueError( + "Length of FFT argument s should not be larger than the rank of input. " + "Received s: {}, rank of x: {}".format(s, ndim)) + for size in s: + if not isinstance(size, int) or size <= 0: + raise ValueError("FFT sizes {} contains invalid value ({})".format( + s, size)) + + +def _check_fft_axis(x, axis): + ndim = x.ndim + if not isinstance(axis, int): + raise ValueError( + "Invalid FFT axis ({}), it shoule be an integer.".format(axis)) + if axis < -ndim or axis >= ndim: + raise ValueError( + "Invalid FFT axis ({}), it should be in range [-{}, {})".format( + axis, ndim, ndim)) + + +def _check_fft_axes(x, axes): + ndim = x.ndim + if not isinstance(axes, Sequence): + raise ValueError( + "Invalid FFT axes ({}), it should be a sequence of integers.". + format(axes)) + if len(axes) > ndim: + raise ValueError( + "Length of fft axes should not be larger than the rank of input. " + "Received, len of axes: {}, rank of x: {}".format(len(axes), ndim)) + for axis in axes: + if not isinstance(axis, int) or axis < -ndim or axis >= ndim: + raise ValueError( + "FFT axes {} contains invalid value ({}), it should be in range [-{}, {})". 
+ format(axes, axis, ndim, ndim)) + + +def _resize_fft_input(x, s, axes): + if len(s) != len(axes): + raise ValueError("length of `s` should equals length of `axes`.") + shape = x.shape + ndim = x.ndim + + axes_to_pad = [] + paddings = [] + axes_to_slice = [] + slices = [] + for i, axis in enumerate(axes): + if shape[axis] < s[i]: + axes_to_pad.append(axis) + paddings.append(s[i] - shape[axis]) + elif shape[axis] > s[i]: + axes_to_slice.append(axis) + slices.append((0, s[i])) + + if axes_to_slice: + x = paddle.slice( + x, + axes_to_slice, + starts=[item[0] for item in slices], + ends=[item[1] for item in slices]) + if axes_to_pad: + padding_widths = [0] * (2 * ndim) + for axis, pad in zip(axes_to_pad, paddings): + padding_widths[2 * axis + 1] = pad + x = paddle.nn.functional.pad(x, padding_widths) + return x + + +def _normalize_axes(x, axes): + ndim = x.ndim + return [item if item >= 0 else (item + ndim) for item in axes] + + +def _check_at_least_ndim(x, rank): + if x.ndim < rank: + raise ValueError("The rank of the input ({}) should >= {}".format( + x.ndim, rank)) + + +# public APIs 1d +def fft(x, n=None, axis=-1, norm="backward", name=None): + """ + Calculate one-dimensional discrete Fourier transform. + + This function uses the efficient fast Fourier transform (FFT) algorithm [1] to + calculate the 1-D * n * point discrete Fourier transform (DFT). + + Args: + x (Tensor): The input data. It's a Tensor type. It's a complex. + n (int, optional): The length of the output transform axis. If `n` is less than + the length input, the input will be cropped. If larger, the input is filled + with zeros. If `n` is not given, the input length along the axis specified + by `axis` is used. + axis (int, optional): Axis used to calculate FFT. If not specified, the last axis + is used by default. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward", meaning no normalization on + the forward transforms and scaling by ``1/n`` on the `ifft`. "forward" instead applies + the ``1/n`` factor on the forward tranform. For ``norm="ortho"``, both directions are + scaled by ``1/sqrt(n)``. + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + complex tensor. The truncated or zero-padded input, transformed along the axis indicated + by `axis`, or the last one if `axis` is not specified. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.exp(3j * np.pi * np.arange(7) / 7) + xp = paddle.to_tensor(x) + fft_xp = paddle.fft.fft(xp).numpy() + print(fft_xp) + # [1.+1.25396034e+00j 1.+4.38128627e+00j 1.-4.38128627e+00j + # 1.-1.25396034e+00j 1.-4.81574619e-01j 1.+8.88178420e-16j + # 1.+4.81574619e-01j] + + + """ + if is_interger(x) or is_floating_point(x): + return fft_r2c( + x, n, axis, norm, forward=True, onesided=False, name=name) + else: + return fft_c2c(x, n, axis, norm, forward=True, name=name) + + +def ifft(x, n=None, axis=-1, norm="backward", name=None): + """ + Compute the 1-D inverse discrete Fourier Transform. + + This function computes the inverse of the 1-D *n*-point discrete Fourier transform + computed by `fft`. In other words, ``ifft(fft(x)) == x`` to within numerical accuracy. 
+ + The input should be ordered in the same way as is returned by `fft`, + i.e., + + * ``x[0]`` should contain the zero frequency term, + * ``x[1:n//2]`` should contain the positive-frequency terms, + * ``x[n//2 + 1:]`` should contain the negative-frequency terms, in + increasing order starting from the most negative frequency. + + For an even number of input points, ``x[n//2]`` represents the sum of + the values at the positive and negative Nyquist frequencies, as the two + are aliased together. + + Args: + x (Tensor): The input data. It's a Tensor type. It's a complex. + n (int, optional): The length of the output transform axis. If `n` is less than + the length input, the input will be cropped. If larger, the input is filled + with zeros. If `n` is not given, the input length along the axis specified + by `axis` is used. + axis (int, optional): Axis used to calculate FFT. If not specified, the last axis + is used by default. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward", meaning no normalization on + the forward transforms and scaling by ``1/n`` on the `ifft`. "forward" instead applies + the ``1/n`` factor on the forward tranform. For ``norm="ortho"``, both directions are + scaled by ``1/sqrt(n)``. + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + complex tensor. The truncated or zero-padded input, transformed along the axis indicated + by `axis`, or the last one if `axis` is not specified. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.exp(3j * np.pi * np.arange(7) / 7) + xp = paddle.to_tensor(x) + ifft_xp = paddle.fft.ifft(xp).numpy() + print(ifft_xp) + # [0.14285714+1.79137191e-01j 0.14285714+6.87963741e-02j + # 0.14285714+1.26882631e-16j 0.14285714-6.87963741e-02j + # 0.14285714-1.79137191e-01j 0.14285714-6.25898038e-01j + # 0.14285714+6.25898038e-01j] + + """ + if is_interger(x) or is_floating_point(x): + return fft_r2c( + x, n, axis, norm, forward=False, onesided=False, name=name) + else: + return fft_c2c(x, n, axis, norm, forward=False, name=name) + + +def rfft(x, n=None, axis=-1, norm="backward", name=None): + """ + The one dimensional FFT for real input. + + This function computes the one dimensional *n*-point discrete Fourier + Transform (DFT) of a real-valued tensor by means of an efficient algorithm + called the Fast Fourier Transform (FFT). + + When the DFT is computed for purely real input, the output is + Hermitian-symmetric. This function does not compute the negative frequency + terms, and the length of the transformed axis of the output is therefore + ``n//2 + 1``. + + Args: + x(Tensor) : Real-valued input tensor + n(int, optional): Number of points along transformation axis in the + input to use. If `n` is smaller than the length of the input, the + input is cropped. If it is larger, the input is padded with zeros. + If `n` is not given, the length of the input along the axis + specified by `axis` is used. + axis(int, optional): Axis over which to compute the FFT. Default value + is last axis. + norm(str, optional) : Normalization mode, indicates which direction of + the forward/backward pair of transforms is scaled and with what + normalization factor. 
Include {"backward", "ortho", "forward"}, + default value is "backward". + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name` . + + Returns: + out(Tensor) : complex tensor + + Raises: + + + Examples: + .. code-block:: python + import paddle + + x = paddle.to_tensor([0.0, 1.0, 0.0, 0.0]) + print(paddle.fft.rfft(x)) + # Tensor(shape=[3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, + # [ (1+0j), -1j , (-1+0j)]) + """ + return fft_r2c(x, n, axis, norm, forward=True, onesided=True, name=name) + + +def irfft(x, n=None, axis=-1, norm="backward", name=None): + """ + Computes the inverse of `rfft`. + + This function calculates the inverse of the one-dimensional *n* point discrete + Fourier transform of the actual input calculated by "rfft". In other words, + ``irfft(rfft(a),len(a)) == a`` is within the numerical accuracy range. + + The input shall be in the form of "rfft", i.e. the actual zero frequency term, + followed by the complex positive frequency term, in the order of increasing frequency. + Because the discrete Fourier transform of the actual input is Hermite symmetric, + the negative frequency term is regarded as the complex conjugate term of the corresponding + positive frequency term. + + Args: + x (Tensor): The input data. It's a Tensor type. It's a complex. + n (int, optional): The length of the output transform axis. For `n` output + points, ``n//2 + 1``input points are necessary. If the length of the input tensor is greater + than `n`, it will be cropped, if it is shorter than this, fill in zero. If `n` is not given, + it is considered to be ``2 * (k-1)``, where ``k`` is the length of the input axis specified + along the ` axis'. + axis (int, optional): Axis used to calculate FFT. If not specified, the last axis + is used by default. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward". + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name` . + + Returns: + Real tensor. Truncated or zero fill input for the transformation along the axis indicated by + `axis`, or the last input if `axis` is not specified. The length of the conversion axis + is `n`, or ``2 * k-2``, if `k` is None, where `k` is the length of the input conversion axis. + If the output is an odd number, you need to specify the value of 'n', such as ``2 * k-1`` + in some cases. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.array([1, -1j, -1]) + xp = paddle.to_tensor(x) + irfft_xp = paddle.fft.irfft(xp).numpy() + print(irfft_xp) + # [0. 0. 0. 4.] + + """ + return fft_c2r(x, n, axis, norm, forward=False, name=name) + + +def hfft(x, n=None, axis=-1, norm="backward", name=None): + """ + Compute the FFT of a signal that has Hermitian symmetry, a real + spectrum. + + Args: + x (Tensor): The input data. It's a Tensor type. It's a complex. + n (int, optional): The length of the output transform axis. For `n` output + points, ``n//2 + 1`` input points are necessary. If the length of the input tensor is greater + than `n`, it will be cropped, if it is shorter than this, fill in zero. 
If `n` is not given, + it is considered to be ``2 * (k-1)``, where ``k`` is the length of the input axis specified + along the ` axis'. + axis (int,optional): Axis used to calculate FFT. If not specified, the last axis + is used by default. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward". + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name` . + + Returns: + Real tensor. Truncated or zero fill input for the transformation along the axis indicated by + `axis`, or the last input if `axis` is not specified. The length of the conversion axis + is `n`, or ``2 * k-2``, if `k` is None, where `k` is the length of the input conversion axis. + If the output is an odd number, you need to specify the value of 'n', such as ``2 * k-1`` in + some cases. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.array([1, -1j, -1]) + xp = paddle.to_tensor(x) + hfft_xp = paddle.fft.hfft(xp).numpy() + print(hfft_xp) + # [0. 0. 0. 4.] + """ + + return fft_c2r(x, n, axis, norm, forward=True, name=name) + + +def ihfft(x, n=None, axis=-1, norm="backward", name=None): + """ + The inverse FFT of a signal that has Hermitian symmetry. + + This function computes the one dimensional *n*-point inverse FFT of a signal + that has Hermitian symmetry by means of an efficient algorithm called + the Fast Fourier Transform (FFT). + + When the DFT is computed for purely real input, the output is + Hermitian-symmetric. This function does not compute the negative frequency + terms, and the length of the transformed axis of the output is therefore + ``n//2 + 1``. + + Args: + x(Tensor): Input tensor. + n(int, optional): The number of points along transformation axis in the + input to use. If `n` is smaller than the length of the input, the + input is cropped. If it is larger, the input is padded with zeros. + If `n` is not given, the length of the input along the axis + specified by `axis` is used. + axis(int, optional) : Axis over which to compute the inverse FFT. If not + given, the last axis is used. + norm(str, optional) : Normalization mode, indicates which direction of + the forward/backward pair of transforms is scaled and with what + normalization factor. Include {"backward", "ortho", "forward"}, + default value is "backward". + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name` . + + Returns: + out(Tensor) : complex tensor. + + Examples: + .. 
code-block:: python + import paddle + + spectrum = paddle.to_tensor([10.0, -5.0, 0.0, -1.0, 0.0, -5.0]) + print(paddle.fft.ifft(spectrum)) + # Tensor(shape=[6], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, + # [(-0.1666666716337204+0j), (1-1.9868215517249155e-08j), (2.3333334922790527-1.9868215517249155e-08j), (3.5+0j), (2.3333334922790527+1.9868215517249155e-08j), (1+1.9868215517249155e-08j)]) + print(paddle.fft.ihfft(spectrum)) + # Tensor(shape = [4], dtype = complex64, place = CUDAPlace(0), stop_gradient = True, + # [(-0.1666666716337204+0j), (1-1.9868215517249155e-08j), (2.3333334922790527-1.9868215517249155e-08j), (3.5+0j)]) + + """ + return fft_r2c(x, n, axis, norm, forward=False, onesided=True, name=name) + + +# public APIs nd +def fftn(x, s=None, axes=None, norm="backward", name=None): + """ + Compute the N-D discrete Fourier Transform. + + This function calculates the n-D discrete Fourier transform on any number of axes + in the M-D array by fast Fourier transform (FFT). + + Args: + x (Tensor): The input data. It's a Tensor type. It's a complex. + s (sequence of ints, optional): Shape (length of each transformed axis) of the output + (``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.). + This corresponds to ``n`` for ``fft(x, n)``. + Along any axis, if the given shape is smaller than that of the input, + the input is cropped. If it is larger, the input is padded with zeros. + if `s` is not given, the shape of the input along the axes specified + by `axes` is used. + axes (sequence of ints, optional): Axes used to calculate FFT. If not given, the last ``len(s)`` + axes are used, or all axes if `s` is also not specified. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward", meaning no normalization on + the forward transforms and scaling by ``1/n`` on the `ifft`. "forward" instead applies + the ``1/n`` factor on the forward tranform. For ``norm="ortho"``, both directions are + scaled by ``1/sqrt(n)``. + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + complex tensor. The truncated or zero-padded input, transformed along the axes indicated by + `axes`, or by a combination of `s` and `x`, as explained in the parameters section above. + + Examples: + + .. 
code-block:: python + + import numpy as np + import paddle + + x = x = np.mgrid[:4, :4, :4][1] + xp = paddle.to_tensor(x) + fftn_xp = paddle.fft.fftn(xp, axes=(1, 2)).numpy() + print(fftn_xp) + # [[[24.+0.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.+8.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.+0.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.-8.j 0.+0.j 0.+0.j 0.-0.j]] + # [[24.+0.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.+8.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.+0.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.-8.j 0.+0.j 0.+0.j 0.-0.j]] + # [[24.+0.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.+8.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.+0.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.-8.j 0.+0.j 0.+0.j 0.-0.j]] + # [[24.+0.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.+8.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.+0.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.-8.j 0.+0.j 0.+0.j 0.-0.j]]] + """ + if is_interger(x) or is_floating_point(x): + return fftn_r2c( + x, s, axes, norm, forward=True, onesided=False, name=name) + else: + return fftn_c2c(x, s, axes, norm, forward=True, name=name) + + +def ifftn(x, s=None, axes=None, norm="backward", name=None): + """ + Compute the N-D inverse discrete Fourier Transform. + + This function computes the inverse of the N-D discrete + Fourier Transform over any number of axes in an M-D array by + means of the Fast Fourier Transform (FFT). In other words, + ``ifftn(fftn(x)) == x`` to within numerical accuracy. + + The input, analogously to `ifft`, should be ordered in the same way as is + returned by `fftn`, i.e., it should have the term for zero frequency + in all axes in the low-order corner, the positive frequency terms in the + first half of all axes, the term for the Nyquist frequency in the middle + of all axes and the negative frequency terms in the second half of all + axes, in order of decreasingly negative frequency. + + Args: + x (Tensor): The input data. It's a Tensor type. It's a complex. + s (sequence of ints, optional): Shape (length of each transformed axis) of the output + (``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.). + This corresponds to ``n`` for ``fft(x, n)``. + Along any axis, if the given shape is smaller than that of the input, + the input is cropped. If it is larger, the input is padded with zeros. + if `s` is not given, the shape of the input along the axes specified + by `axes` is used. + axes (sequence of ints, optional): Axes used to calculate FFT. If not given, the last ``len(s)`` + axes are used, or all axes if `s` is also not specified. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward", meaning no normalization on + the forward transforms and scaling by ``1/n`` on the `ifft`. "forward" instead applies + the ``1/n`` factor on the forward tranform. For ``norm="ortho"``, both directions are + scaled by ``1/sqrt(n)``. + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + complex tensor. The truncated or zero-padded input, transformed along the axes indicated by + `axes`, or by a combination of `s` and `x`, as explained in the parameters section above. + + Examples: + + .. 
code-block:: python + + import numpy as np + import paddle + + x = np.eye(3) + xp = paddle.to_tensor(x) + ifftn_xp = paddle.fft.ifftn(xp, axes=(1,)).numpy() + print(ifftn_xp) + + # [[ 0.33333333+0.j 0.33333333+0.j 0.33333333-0.j ] + # [ 0.33333333+0.j -0.16666667+0.28867513j -0.16666667-0.28867513j] + # [ 0.33333333+0.j -0.16666667-0.28867513j -0.16666667+0.28867513j]] + + """ + if is_interger(x) or is_floating_point(x): + return fftn_r2c( + x, s, axes, norm, forward=False, onesided=False, name=name) + else: + return fftn_c2c(x, s, axes, norm, forward=False, name=name) + + +def rfftn(x, s=None, axes=None, norm="backward", name=None): + """ + The N dimensional FFT for real input. + + This function computes the N-dimensional discrete Fourier Transform over + any number of axes in an M-dimensional real array by means of the Fast + Fourier Transform (FFT). By default, all axes are transformed, with the + real transform performed over the last axis, while the remaining + transforms are complex. + + The transform for real input is performed over the last transformation + axis, as by `rfft`, then the transform over the remaining axes is + performed as by `fftn`. The order of the output is as for `rfft` for the + final transformation axis, and as for `fftn` for the remaining + transformation axes. + + Args: + x(Tensor) : Input tensor, taken to be real. + s(Sequence[int]) : Shape to use from the exec fft. The final element of + `s` corresponds to `n` for ``rfft(x, n)``, while for the remaining + axes, it corresponds to `n` for ``fft(x, n)``. Along any axis, if + the given shape is smaller than that of the input, the input is + cropped. If it is larger, the input is padded with zeros. if `s` is + not given, the shape of the input along the axes specified by `axes` + is used. + axes(Sequence[int]) : Axes over which to compute the FFT. If not given, + the last ``len(s)`` axes are used, or all axes if `s` is also not + specified. + norm(str, optional) : Normalization mode, indicates which direction of + the forward/backward pair of transforms is scaled and with what + normalization factor. Include {"backward", "ortho", "forward"}, + default value is "backward". + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name` . + + Returns: + out(Tensor): complex tensor + + + Raises: + ValueError: If `s` and `axes` have different length. + + Examples: + .. code-block:: python + import paddle + + # default, all axis will be used to exec fft + x = paddle.ones((2, 3, 4)) + print(paddle.fft.rfftn(x)) + # Tensor(shape=[2, 3, 3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, + # [[[(24+0j), 0j , 0j ], + # [0j , 0j , 0j ], + # [0j , 0j , 0j ]], + # + # [[0j , 0j , 0j ], + # [0j , 0j , 0j ], + # [0j , 0j , 0j ]]]) + + # use axes(2, 0) + print(paddle.fft.rfftn(x, axes=(2, 0))) + # Tensor(shape=[2, 3, 3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, + # [[[(24+0j), 0j , 0j ], + # [0j , 0j , 0j ], + # [0j , 0j , 0j ]], + # + # [[0j , 0j , 0j ], + # [0j , 0j , 0j ], + # [0j , 0j , 0j ]]]) + + """ + return fftn_r2c(x, s, axes, norm, forward=True, onesided=True, name=name) + + +def irfftn(x, s=None, axes=None, norm="backward", name=None): + """ + Computes the inverse of `rfftn`. + + This function computes the inverse of the N-D discrete + Fourier Transform for real input over any number of axes in an + M-D array by means of the Fast Fourier Transform (FFT). 
In + other words, ``irfftn(rfftn(x), x.shape) == x`` to within numerical + accuracy. (The ``x.shape`` is necessary like ``len(x)`` is for `irfft`, + and for the same reason.) + + The input should be ordered in the same way as is returned by `rfftn`, + i.e., as for `irfft` for the final transformation axis, and as for `ifftn` + along all the other axes. + + Args: + x (Tensor): The input data. It's a Tensor type. + s (sequence of ints, optional): The length of the output transform axis. + (``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.). `s` is also the + number of input points used along this axis, except for the last axis, + where ``s[-1]//2+1`` points of the input are used. Along any axis, if + the shape indicated by `s` is smaller than that of the input, the input + is cropped. If it is larger, the input is padded with zeros. + If `s` is not given, the shape of the input along the axes specified by axes + is used. Except for the last axis which is taken to be ``2*(k-1)`` where + ``k`` is the length of the input along that axis. + axes (sequence of ints, optional): Axes over which to compute the inverse FFT. If not given, the last + `len(s)` axes are used, or all axes if `s` is also not specified. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward". + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Real tensor. The truncated or zero-padded input, transformed along the axes indicated by `axes`, + or by a combination of `s` or `x`, as explained in the parameters section above. The length of + each transformed axis is as given by the corresponding element of `s`, or the length of the input + in every axis except for the last one if `s` is not given. In the final transformed axis the length + of the output when `s` is not given is ``2*(m-1)``, where ``m`` is the length of the final + transformed axis of the input. To get an odd number of output points in the final axis, + `s` must be specified. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = (np.array([2, 2, 3]) + 1j * np.array([2, 2, 3])).astype(np.complex128) + xp = paddle.to_tensor(x) + irfftn_xp = paddle.fft.irfftn(xp).numpy() + print(irfftn_xp) + # [ 2.25 -1.25 0.25 0.75] + + """ + return fftn_c2r(x, s, axes, norm, forward=False, name=name) + + +def hfftn(x, s=None, axes=None, norm="backward", name=None): + """ + Compute the N-D FFT of Hermitian symmetric complex input, i.e., a + signal with a real spectrum. + + This function computes the N-D discrete Fourier Transform of Hermitian + symmetric complex input over any number of axes in an M-D array by means of + the Fast Fourier Transform (FFT). In other words, ``ihfftn(hfftn(x, s)) == x`` + to within numerical accuracy. (Here ``s`` is ``x.shape`` with + ``s[-1] = x.shape[-1] * 2 - 1``. This is necessary for the same reason that + `irfft` requires ``x.shape``.) + + Args: + x (Tensor): The input data. It's a Tensor type. + s (sequence of ints, optional): The length of the output transform axis. + (``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.). `s` is also the + number of input points used along this axis, except for the last axis, + where ``s[-1]//2+1`` points of the input are used.
Along any axis, if + the shape indicated by `s` is smaller than that of the input, the input + is cropped. If it is larger, the input is padded with zeros. + If `s` is not given, the shape of the input along the axes specified by axes + is used. Except for the last axis which is taken to be ``2*(k-1)`` where + ``k`` is the length of the input along that axis. + axes (sequence of ints, optional): Axes over which to compute the inverse FFT. If not given, the last + `len(s)` axes are used, or all axes if `s` is also not specified. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward". + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Real tensor. Truncate or zero fill input, transforming along the axis indicated by axis or + a combination of `s` or `X`. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = (np.array([2, 2, 3]) + 1j * np.array([2, 2, 3])).astype(np.complex128) + xp = paddle.to_tensor(x) + hfftn_xp = paddle.fft.hfftn(xp).numpy() + print(hfftn_xp) + # [ 9. 3. 1. -5.] + + + """ + return fftn_c2r(x, s, axes, norm, forward=True, name=name) + + +def ihfftn(x, s=None, axes=None, norm="backward", name=None): + """ + The n dimensional inverse FFT of a signal that has Hermitian symmetry. + + This function computes the n dimensional inverse FFT over any number of axes + in an M-dimensional of a signal that has Hermitian symmetry by means of an + efficient algorithm called the Fast Fourier Transform (FFT). + + Args: + x(Tensor): Input tensor. + s(Sequence[int], optional) : Shape (length along each transformed axis) + to use from the input. (``s[0]`` refers to axis 0, ``s[1]`` to axis + 1, etc.). Along any axis, if the given shape is smaller than that + of the input, the input is cropped. If it is larger, the input is + padded with zeros. if `s` is not given, the shape of the input + along the axes specified by `axes` is used. + axis(Sequence[int], optional) : Axis over which to compute the inverse FFT. If not + given, the last axis is used. + norm(str, optional) : Normalization mode, indicates which direction of + the forward/backward pair of transforms is scaled and with what + normalization factor. Include {"backward", "ortho", "forward"}, + default value is "backward". + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name` . + + Returns: + out(Tensor) : complex tensor. + + Examples: + .. 
code-block:: python + import paddle + + spectrum = paddle.to_tensor([10.0, -5.0, 0.0, -1.0, 0.0, -5.0]) + print(paddle.fft.ifft(spectrum)) + # Tensor(shape=[6], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, + # [(-0.1666666716337204+0j), (1-1.9868215517249155e-08j), (2.3333334922790527-1.9868215517249155e-08j), (3.5+0j), (2.3333334922790527+1.9868215517249155e-08j), (1+1.9868215517249155e-08j)]) + print(paddle.fft.ihfft(spectrum)) + # Tensor(shape = [4], dtype = complex64, place = CUDAPlace(0), stop_gradient = True, + # [(-0.1666666716337204+0j), (1-1.9868215517249155e-08j), (2.3333334922790527-1.9868215517249155e-08j), (3.5+0j)]) + + """ + return fftn_r2c(x, s, axes, norm, forward=False, onesided=True, name=name) + + +# public APIs 2d +def fft2(x, s=None, axes=(-2, -1), norm="backward", name=None): + """ + Compute the 2-D discrete Fourier Transform + + This function computes the N-D discrete Fourier Transform + over any axes in an M-D array by means of the + Fast Fourier Transform (FFT). By default, the transform is computed over + the last two axes of the input array, i.e., a 2-dimensional FFT. + + Args: + x (Tensor): The input data. It's a Tensor type. + s (sequence of ints, optional): Shape (length of each transformed axis) of the output. + It should be a sequence of 2 integers. This corresponds to ``n`` for ``fft(x, n)``. + Along each axis, if the given shape is smaller than that of the input, + the input is cropped. If it is larger, the input is padded with zeros. + if `s` is not given, the shape of the input along the axes specified + by `axes` is used. Default is None. + axes (sequence of ints, optional): Axes over which to compute the FFT. It should be a + sequence of 2 integers. If not specified, the last two axes are used by default. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward". + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Complex tensor. The truncated or zero-padded input, transformed along the axes indicated by `axes`, + or the last two axes if `axes` is not given. + + Raises: + ValueError: if `s` not be a sequence of 2 integers or None. + ValueError: if `axes` not be a sequence of 2 integers or None. + ValueError: If the input dimension is smaller than 2. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.mgrid[:2, :2][1] + xp = paddle.to_tensor(x) + fft2_xp = paddle.fft.fft2(xp).numpy() + print(fft2_xp) + # [[ 2.+0.j -2.+0.j] + # [ 0.+0.j 0.+0.j]] + + """ + _check_at_least_ndim(x, 2) + if s is not None: + if not isinstance(s, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". + format(s)) + if axes is not None: + if not isinstance(axes, Sequence) or len(axes) != 2: + raise ValueError( + "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". + format(axes)) + return fftn(x, s, axes, norm, name) + + +def ifft2(x, s=None, axes=(-2, -1), norm="backward", name=None): + """ + Compute the 2-D inverse discrete Fourier Transform. + + This function computes the inverse of the 2-D discrete Fourier + Transform over any number of axes in an M-D array by means of + the Fast Fourier Transform (FFT). 
In other words, ``ifft2(fft2(x)) == x`` + to within numerical accuracy. By default, the inverse transform is + computed over the last two axes of the input array. + + The input, analogously to `ifft`, should be ordered in the same way as is + returned by `fft2`, i.e., it should have the term for zero frequency + in the low-order corner of the two axes, the positive frequency terms in + the first half of these axes, the term for the Nyquist frequency in the + middle of the axes and the negative frequency terms in the second half of + both axes, in order of decreasingly negative frequency. + + Args: + x (Tensor): The input data. It's a Tensor type. + s (sequence of ints, optional): Shape (length of each transformed axis) of the output. + It should be a sequence of 2 integers. This corresponds to ``n`` for ``fft(x, n)``. + Along each axis, if the given shape is smaller than that of the input, + the input is cropped. If it is larger, the input is padded with zeros. + if `s` is not given, the shape of the input along the axes specified + by `axes` is used. Default is None. + axes (sequence of ints, optional): Axes over which to compute the FFT. It should be a + sequence of 2 integers. If not specified, the last two axes are used by default. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward". + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Complex tensor. The truncated or zero-padded input, transformed along the axes indicated by `axes`, + or the last two axes if `axes` is not given. + + Raises: + ValueError: if `s` not be a sequence of 2 integers or None. + ValueError: if `axes` not be a sequence of 2 integers or None. + ValueError: If the input dimension is smaller than 2. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.mgrid[:2, :2][1] + xp = paddle.to_tensor(x) + ifft2_xp = paddle.fft.ifft2(xp).numpy() + print(ifft2_xp) + # [[ 0.5+0.j -0.5+0.j] + # [ 0. +0.j 0. +0.j]] + """ + _check_at_least_ndim(x, 2) + if s is not None: + if not isinstance(s, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". + format(s)) + if axes is not None: + if not isinstance(axes, Sequence) or len(axes) != 2: + raise ValueError( + "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". + format(axes)) + return ifftn(x, s, axes, norm, name) + + +def rfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): + """ + The two dimensional FFT with real tensor input. + + This is really just `rfftn` with different default behavior. + For more details see `rfftn`. + + Args: + x(Tensor): Input tensor, taken to be real. + s(Sequence[int]) : Shape of the FFT. + axes(Sequence[int], optional): Axes over which to compute the FFT. + norm(str, optional) : {"backward", "ortho", "forward"}, + default is "backward". Indicates which direction of the + forward/backward pair of transforms is scaled and with what + normalization factor. + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name` . + + Returns: + out(Tensor): The result of the real 2-D FFT. + + Raises: + + + Examples: + + .. 
code-block:: python + import paddle + import numpy as np + + x = paddle.to_tensor(np.mgrid[:5, :5][0].astype(np.float32)) + print(paddle.fft.rfft2(x)) + # Tensor(shape=[5, 3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, + # [[ (50+0j) , (1.1920928955078125e-07+0j) , 0j ], + # [(-12.5+17.204774856567383j) , (-9.644234211236835e-08+7.006946134424652e-08j) , 0j ], + # [(-12.500000953674316+4.061495304107666j) , (3.6837697336977726e-08-1.1337477445749755e-07j), 0j ], + # [(-12.500000953674316-4.061495304107666j) , (3.6837697336977726e-08+1.1337477445749755e-07j), 0j ], + # [(-12.5-17.204774856567383j) , (-9.644234211236835e-08-7.006946134424652e-08j) , 0j ]]) + """ + _check_at_least_ndim(x, 2) + if s is not None: + if not isinstance(s, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". + format(s)) + if axes is not None: + if not isinstance(axes, Sequence) or len(axes) != 2: + raise ValueError( + "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". + format(axes)) + return rfftn(x, s, axes, norm, name) + + +def irfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): + """ + Computes the inverse of `rfft2`. + + Args: + x (Tensor): The input data. It's a Tensor type. + s (sequence of ints, optional): Shape of the real output to the inverse FFT. Default is None. + axes (sequence of ints, optional): The axes over which to compute the inverse FFT. Axes + must be two-dimensional. If not specified, the last two axes are used by default. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward". + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name` . + + Returns: + Real tensor. The result of the inverse real 2-D FFT. + + Raises: + ValueError: if `s` not be a sequence of 2 integers or None. + ValueError: if `axes` not be a sequence of 2 integers or None. + ValueError: If the input dimension is smaller than 2. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = (np.array([[3,2,3],[2, 2, 3]]) + 1j * np.array([[3,2,3],[2, 2, 3]])).astype(np.complex128) + xp = paddle.to_tensor(x) + irfft2_xp = paddle.fft.irfft2(xp).numpy() + print(irfft2_xp) + # [[ 2.375 -1.125 0.375 0.875] + # [ 0.125 0.125 0.125 0.125]] + + """ + _check_at_least_ndim(x, 2) + if s is not None: + if not isinstance(s, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". + format(s)) + if axes is not None: + if not isinstance(axes, Sequence) or len(axes) != 2: + raise ValueError( + "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". + format(axes)) + return irfftn(x, s, axes, norm, name) + + +def hfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): + """ + Compute the 2-D FFT of a Hermitian complex array. + + Args: + x (Tensor): The input data. It's a Tensor type. + s (sequence of ints, optional): Shape of the real output. Default is None. + axes (sequence of ints, optional): Axes over which to compute the FFT. Axes must be + two-dimensional. If not specified, the last two axes are used by default. 
+ norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward". + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Real tensor. The real result of the 2-D Hermitian complex real FFT. + + Raises: + ValueError: if `s` not be a sequence of 2 integers or None. + ValueError: if `axes` not be a sequence of 2 integers or None. + ValueError: If the input dimension is smaller than 2. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = (np.array([[3,2,3],[2, 2, 3]]) + 1j * np.array([[3,2,3],[2, 2, 3]])).astype(np.complex128) + xp = paddle.to_tensor(x) + hfft2_xp = paddle.fft.hfft2(xp).numpy() + print(hfft2_xp) + # [[19. 7. 3. -9.] + # [ 1. 1. 1. 1.]] + + + """ + _check_at_least_ndim(x, 2) + if s is not None: + if not isinstance(s, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". + format(s)) + if axes is not None: + if not isinstance(axes, Sequence) or len(axes) != 2: + raise ValueError( + "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". + format(axes)) + return hfftn(x, s, axes, norm, name) + + +def ihfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): + """ + Compute the two dimensional inverse FFT of a real spectrum. + + This is really `ihfftn` with different defaults. + For more details see `ihfftn`. + + Args: + x(Tensor): Input tensor + s(Sequence[int], optional): Shape of the real input to the inverse FFT. + axes(Sequance[int], optional): The axes over which to compute the + inverse fft. Default is the last two axes. + norm(str, optional): {"backward", "ortho", "forward"}. Default is + "backward". + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name` . + + Returns: + out(Tensor) : The result of the inverse real 2-D FFT. + """ + _check_at_least_ndim(x, 2) + if s is not None: + if not isinstance(s, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". + format(s)) + if axes is not None: + if not isinstance(axes, Sequence) or len(axes) != 2: + raise ValueError( + "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". + format(axes)) + return ihfftn(x, s, axes, norm, name) + + +# public APIs utilities +def fftfreq(n, d=1.0, dtype=None, name=None): + """ + Return the Discrete Fourier Transform sample frequencies. + + The returned float array `f` contains the frequency bin centers in cycles + per unit of the sample spacing (with zero at the start). For instance, if + the sample spacing is in seconds, then the frequency unit is cycles/second. + + Given input length `n` and a sample spacing `d`:: + + f = [0, 1, ..., n/2-1, -n/2, ..., -1] / (d*n) if n is even + f = [0, 1, ..., (n-1)/2, -(n-1)/2, ..., -1] / (d*n) if n is odd + + Args: + n (int): Dimension inputed. + d (scalar, optional): Sample spacing (inverse of the sampling rate). Defaults is 1. + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor. 
A tensor of length 'n' containing the sampling frequency. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.array([3, 1, 2, 2, 3], dtype=float) + scalar_temp = 0.5 + n = x.size + fftfreq_xp = paddle.fft.fftfreq(n, d=scalar_temp) + print(fftfreq_xp) + + # Tensor(shape=[5], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + # [ 0. , 0.40000001, 0.80000001, -0.80000001, -0.40000001]) + """ + + dtype = paddle.framework.get_default_dtype() + val = 1.0 / (n * d) + pos_max = (n + 1) // 2 + neg_max = n // 2 + indices = paddle.arange(-neg_max, pos_max, dtype=dtype, name=name) + indices = paddle.roll(indices, -neg_max, name=name) + return indices * val + + +def rfftfreq(n, d=1.0, dtype=None, name=None): + """ + Return the Discrete Fourier Transform sample frequencies. + + The returned floating-point array "F" contains the center of the frequency unit, + and the unit is the number of cycles of the sampling interval (the starting point is zero). + + Given input length `n` and a sample spacing `d`:: + + f = [0, 1, ..., n/2-1, n/2] / (d*n) if n is even + f = [0, 1, ..., (n-1)/2-1, (n-1)/2] / (d*n) if n is odd + + the Nyquist frequency component is considered to be positive. + + Args: + n (int): Dimension inputed. + d (scalar, optional): Sample spacing (inverse of the sampling rate). Defaults is 1. + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor. A tensor of length ``n//2 + 1`` containing the sample frequencies. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.array([3, 1, 2, 2, 3], dtype=float) + scalar_temp = 0.3 + n = x.size + rfftfreq_xp = paddle.fft.rfftfreq(n, d=scalar_temp) + print(rfftfreq_xp) + + # Tensor(shape=[3], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + # [0. , 0.66666669, 1.33333337]) + + """ + + dtype = paddle.framework.get_default_dtype() + val = 1.0 / (n * d) + pos_max = 1 + n // 2 + indices = paddle.arange(0, pos_max, dtype=dtype, name=name) + return indices * val + + +def fftshift(x, axes=None, name=None): + """ + Shift the zero-frequency component to the center of the spectrum. + + This function swaps half spaces for all the axes listed (all by default). + Note that ``y[0]`` is the Nyquist component only if ``len(x)`` is even. + + Args: + n (int): Dimension inputed. + axes (int|tuple, optional): The axis on which to move. The default is none, which moves all axes. + Default is None. + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor. The shifted tensor. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.array([3, 1, 2, 2, 3], dtype=float) + scalar_temp = 0.3 + n = x.size + fftfreq_xp = paddle.fft.fftfreq(n, d=scalar_temp) + res = paddle.fft.fftshift(fftfreq_xp).numpy() + print(res) + # [-1.3333334 -0.6666667 0. 0.6666667 1.3333334] + + """ + shape = paddle.shape(x) + if axes is None: + # shift all axes + rank = paddle.rank(x).reshape([1]) + axes = axes or paddle.arange(0, rank) + shifts = [size // 2 for size in shape] + elif isinstance(axes, int): + shifts = shape[axes] // 2 + else: + shifts = [shape[ax] // 2 for ax in axes] + return paddle.roll(x, shifts, axes, name=name) + + +def ifftshift(x, axes=None, name=None): + """ + The inverse of `fftshift`. 
Although the even length 'x' is the same, the function of the + odd length 'x' is different. An example. + + Args: + n (int): Dimension inputed. + axes (int|tuple, optional): The axis on which to move. The default is none, which moves all axes. + Default is None. + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor. The shifted tensor. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.array([3, 1, 2, 2, 3], dtype=float) + scalar_temp = 0.3 + n = x.size + fftfreq_xp = paddle.fft.fftfreq(n, d=scalar_temp) + res = paddle.fft.ifftshift(fftfreq_xp).numpy() + print(res) + # [ 1.3333334 -1.3333334 -0.6666667 0. 0.6666667] + + """ + shape = paddle.shape(x) + if axes is None: + # shift all axes + rank = paddle.rank(x).reshape([1]) + axes = axes or paddle.arange(0, rank) + shifts = [-size // 2 for size in shape] + elif isinstance(axes, int): + shifts = -shape[axes] // 2 + else: + shifts = [-shape[ax] // 2 for ax in axes] + return paddle.roll(x, shifts, axes, name=name) + + +# internal functions +def fft_c2c(x, n, axis, norm, forward, name): + if is_interger(x): + x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype())) + elif is_floating_point(x): + x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) + _check_normalization(norm) + + axis = axis or -1 + _check_fft_axis(x, axis) + axes = [axis] + axes = _normalize_axes(x, axes) + if n is not None: + _check_fft_n(n) + s = [n] + x = _resize_fft_input(x, s, axes) + op_type = 'fft_c2c' + + check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type) + if in_dygraph_mode(): + attrs = ('axes', axes, 'normalization', norm, 'forward', forward) + out = getattr(_C_ops, op_type)(x, *attrs) + else: + inputs = {'X': [x], } + attrs = {'axes': axes, 'normalization': norm, 'forward': forward} + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name='x') + out = helper.create_variable_for_type_inference(dtype) + outputs = {"Out": [out]} + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) + return out + + +def fft_r2c(x, n, axis, norm, forward, onesided, name): + if is_interger(x): + x = paddle.cast(x, paddle.get_default_dtype()) + _check_normalization(norm) + axis = axis or -1 + _check_fft_axis(x, axis) + axes = [axis] + axes = _normalize_axes(x, axes) + if n is not None: + _check_fft_n(n) + s = [n] + x = _resize_fft_input(x, s, axes) + op_type = 'fft_r2c' + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], op_type) + + if in_dygraph_mode(): + attrs = ('axes', axes, 'normalization', norm, 'forward', forward, + 'onesided', onesided) + out = getattr(_C_ops, op_type)(x, *attrs) + else: + inputs = {'X': [x], } + attrs = { + 'axes': axes, + 'normalization': norm, + 'forward': forward, + 'onesided': onesided, + } + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name='x') + out = helper.create_variable_for_type_inference( + _real_to_complex_dtype(dtype)) + outputs = {"Out": [out]} + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) + return out + + +def fft_c2r(x, n, axis, norm, forward, name): + if is_interger(x): + x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype())) + elif is_floating_point(x): + x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) + _check_normalization(norm) + axis = axis or -1 + 
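# For the C2R transform the complex input holds only n // 2 + 1 points along the transform axis, so when n is given the input is truncated or zero-padded to that size before the kernel is called. +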
_check_fft_axis(x, axis) + axes = [axis] + axes = _normalize_axes(x, axes) + if n is not None: + _check_fft_n(n) + s = [n // 2 + 1] + x = _resize_fft_input(x, s, axes) + op_type = 'fft_c2r' + check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type) + + if in_dygraph_mode(): + if n is not None: + attrs = ('axes', axes, 'normalization', norm, 'forward', forward, + 'last_dim_size', n) + else: + attrs = ('axes', axes, 'normalization', norm, 'forward', forward) + out = getattr(_C_ops, op_type)(x, *attrs) + else: + inputs = {'X': [x], } + attrs = {'axes': axes, 'normalization': norm, 'forward': forward} + if n is not None: + attrs['last_dim_size'] = n + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name='x') + out = helper.create_variable_for_type_inference( + _complex_to_real_dtype(dtype)) + outputs = {"Out": [out]} + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) + return out + + +def fftn_c2c(x, s, axes, norm, forward, name): + if is_interger(x): + x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype())) + elif is_floating_point(x): + x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) + _check_normalization(norm) + if s is not None: + _check_fft_shape(x, s) + + rank = x.ndim + if axes is None: + if s is None: + axes = list(range(rank)) + else: + fft_ndims = len(s) + axes = list(range(rank - fft_ndims, rank)) + else: + _check_fft_axes(x, axes) + axes = _normalize_axes(x, axes) + axes_argsoft = np.argsort(axes).tolist() + axes = [axes[i] for i in axes_argsoft] + if s is not None: + if len(s) != len(axes): + raise ValueError( + "Length of s ({}) and length of axes ({}) does not match.". + format(len(s), len(axes))) + s = [s[i] for i in axes_argsoft] + + if s is not None: + x = _resize_fft_input(x, s, axes) + op_type = 'fft_c2c' + check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type) + + if in_dygraph_mode(): + attrs = ('axes', axes, 'normalization', norm, 'forward', forward) + out = getattr(_C_ops, op_type)(x, *attrs) + else: + inputs = {'X': [x], } + attrs = {'axes': axes, 'normalization': norm, 'forward': forward} + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name='x') + out = helper.create_variable_for_type_inference(dtype) + outputs = {"Out": [out]} + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) + return out + + +def fftn_r2c(x, s, axes, norm, forward, onesided, name): + if is_interger(x): + x = paddle.cast(x, paddle.get_default_dtype()) + _check_normalization(norm) + if s is not None: + _check_fft_shape(x, s) + + rank = x.ndim + if axes is None: + if s is None: + axes = list(range(rank)) + else: + fft_ndims = len(s) + axes = list(range(rank - fft_ndims, rank)) + else: + _check_fft_axes(x, axes) + axes = _normalize_axes(x, axes) + axes_argsoft = np.argsort(axes[:-1]).tolist() + axes = [axes[i] for i in axes_argsoft] + [axes[-1]] + if s is not None: + if len(s) != len(axes): + raise ValueError( + "Length of s ({}) and length of axes ({}) does not match.". 
+ format(len(s), len(axes))) + s = [s[i] for i in axes_argsoft] + [s[-1]] + + if s is not None: + x = _resize_fft_input(x, s, axes) + + op_type = 'fft_r2c' + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], op_type) + + if in_dygraph_mode(): + attrs = ('axes', axes, 'normalization', norm, 'forward', forward, + 'onesided', onesided) + out = getattr(_C_ops, op_type)(x, *attrs) + else: + inputs = {'X': [x], } + attrs = { + 'axes': axes, + 'normalization': norm, + 'forward': forward, + 'onesided': onesided, + } + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name='x') + out = helper.create_variable_for_type_inference( + _real_to_complex_dtype(dtype)) + outputs = {"Out": [out]} + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) + + return out + + +def fftn_c2r(x, s, axes, norm, forward, name): + if is_interger(x): + x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype())) + elif is_floating_point(x): + x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) + _check_normalization(norm) + if s is not None: + _check_fft_shape(x, s) + + rank = x.ndim + if axes is None: + if s is None: + axes = list(range(rank)) + else: + fft_ndims = len(s) + axes = list(range(rank - fft_ndims, rank)) + else: + _check_fft_axes(x, axes) + axes = _normalize_axes(x, axes) + axes_argsoft = np.argsort(axes[:-1]).tolist() + axes = [axes[i] for i in axes_argsoft] + [axes[-1]] + if s is not None: + if len(s) != len(axes): + raise ValueError( + "Length of s ({}) and length of axes ({}) does not match.". + format(len(s), len(axes))) + s = [s[i] for i in axes_argsoft] + [s[-1]] + + if s is not None: + fft_input_shape = list(s) + fft_input_shape[-1] = fft_input_shape[-1] // 2 + 1 + x = _resize_fft_input(x, fft_input_shape, axes) + + op_type = 'fft_c2r' + check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type) + + if in_dygraph_mode(): + if s: + attrs = ('axes', axes, 'normalization', norm, 'forward', forward, + 'last_dim_size', s[-1]) + else: + attrs = ('axes', axes, 'normalization', norm, 'forward', forward) + out = getattr(_C_ops, op_type)(x, *attrs) + else: + inputs = {'X': [x], } + attrs = {'axes': axes, 'normalization': norm, 'forward': forward} + if s: + attrs["last_dim_size"] = s[-1] + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name='x') + out = helper.create_variable_for_type_inference( + _complex_to_real_dtype(dtype)) + outputs = {"Out": [out]} + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) + return out diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 30477d20e75..4129a1060da 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -682,7 +682,7 @@ def roll(x, shifts, axis=None, name=None): axis = [axis] len_origin_shape = len(origin_shape) - if axis: + if axis is not None: for i in range(len(axis)): if axis[i] >= len_origin_shape or axis[i] < -len_origin_shape: raise ValueError( diff --git a/python/paddle/tensor/signal.py b/python/paddle/tensor/signal.py new file mode 100644 index 00000000000..86022a17483 --- /dev/null +++ b/python/paddle/tensor/signal.py @@ -0,0 +1,576 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import paddle + +from .attribute import is_complex, is_floating_point +from .fft import fft_r2c, fft_c2r, fft_c2c +from ..fluid.data_feeder import check_variable_and_dtype +from ..fluid.framework import in_dygraph_mode +from ..fluid.layer_helper import LayerHelper +from .. import _C_ops + +__all__ = [ + 'frame', + 'overlap_add', + 'stft', + 'istft', +] + + +def frame(x, frame_length, hop_length, axis=-1, name=None): + """ + Slice the N-dimensional (where N >= 1) input into (overlapping) frames. + + Args: + x (Tensor): The input data which is a N-dimensional (where N >= 1) Tensor + with shape `[..., seq_length]` or `[seq_length, ...]`. + frame_length (int): Length of the frame and `0 < frame_length <= x.shape[axis]`. + hop_length (int): Number of steps to advance between adjacent frames + and `0 < hop_length`. + axis (int, optional): Specify the axis to operate on the input Tensors. Its + value should be 0(the first dimension) or -1(the last dimension). If not + specified, the last axis is used by default. + + Returns: + The output frames tensor with shape `[..., frame_length, num_frames]` if `axis==-1`, + otherwise `[num_frames, frame_length, ...]` where + + `num_framse = 1 + (x.shape[axis] - frame_length) // hop_length` + + Examples: + + .. code-block:: python + + import paddle + from paddle.tensor.signal import frame + + # 1D + x = paddle.arange(8) + y0 = frame(x, frame_length=4, hop_length=2, axis=-1) # [4, 3] + # [[0, 2, 4], + # [1, 3, 5], + # [2, 4, 6], + # [3, 5, 7]] + + y1 = frame(x, frame_length=4, hop_length=2, axis=0) # [3, 4] + # [[0, 1, 2, 3], + # [2, 3, 4, 5], + # [4, 5, 6, 7]] + + # 2D + x0 = paddle.arange(16).reshape([2, 8]) + y0 = frame(x0, frame_length=4, hop_length=2, axis=-1) # [2, 4, 3] + # [[[0, 2, 4], + # [1, 3, 5], + # [2, 4, 6], + # [3, 5, 7]], + # + # [[8 , 10, 12], + # [9 , 11, 13], + # [10, 12, 14], + # [11, 13, 15]]] + + x1 = paddle.arange(16).reshape([8, 2]) + y1 = frame(x1, frame_length=4, hop_length=2, axis=0) # [3, 4, 2] + # [[[0 , 1 ], + # [2 , 3 ], + # [4 , 5 ], + # [6 , 7 ]], + # + # [4 , 5 ], + # [6 , 7 ], + # [8 , 9 ], + # [10, 11]], + # + # [8 , 9 ], + # [10, 11], + # [12, 13], + # [14, 15]]] + + # > 2D + x0 = paddle.arange(32).reshape([2, 2, 8]) + y0 = frame(x0, frame_length=4, hop_length=2, axis=-1) # [2, 2, 4, 3] + + x1 = paddle.arange(32).reshape([8, 2, 2]) + y1 = frame(x1, frame_length=4, hop_length=2, axis=0) # [3, 4, 2, 2] + """ + if axis not in [0, -1]: + raise ValueError(f'Unexpected axis: {axis}. It should be 0 or -1.') + + if not isinstance(frame_length, int) or frame_length <= 0: + raise ValueError( + f'Unexpected frame_length: {frame_length}. It should be an positive integer.' + ) + + if not isinstance(hop_length, int) or hop_length <= 0: + raise ValueError( + f'Unexpected hop_length: {hop_length}. It should be an positive integer.' 
+ ) + + if frame_length > x.shape[axis]: + raise ValueError( + f'Attribute frame_length should be less equal than sequence length, ' + f'but got ({frame_length}) > ({x.shape[axis]}).') + + op_type = 'frame' + + if in_dygraph_mode(): + attrs = ('frame_length', frame_length, 'hop_length', hop_length, 'axis', + axis) + op = getattr(_C_ops, op_type) + out = op(x, *attrs) + else: + check_variable_and_dtype( + x, 'x', ['int32', 'int64', 'float16', 'float32', + 'float64'], op_type) + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name='x') + out = helper.create_variable_for_type_inference(dtype=dtype) + helper.append_op( + type=op_type, + inputs={'X': x}, + attrs={ + 'frame_length': frame_length, + 'hop_length': hop_length, + 'axis': axis + }, + outputs={'Out': out}) + return out + + +def overlap_add(x, hop_length, axis=-1, name=None): + """ + Reconstructs a tensor consisted of overlap added sequences from input frames. + + Args: + x (Tensor): The input data which is a N-dimensional (where N >= 2) Tensor + with shape `[..., frame_length, num_frames]` or + `[num_frames, frame_length ...]`. + hop_length (int): Number of steps to advance between adjacent frames and + `0 < hop_length <= frame_length`. + axis (int, optional): Specify the axis to operate on the input Tensors. Its + value should be 0(the first dimension) or -1(the last dimension). If not + specified, the last axis is used by default. + + Returns: + The output frames tensor with shape `[..., seq_length]` if `axis==-1`, + otherwise `[seq_length, ...]` where + + `seq_length = (n_frames - 1) * hop_length + frame_length` + + Examples: + + .. code-block:: python + + import paddle + from paddle.tensor.signal import overlap_add + + # 2D + x0 = paddle.arange(16).reshape([8, 2]) + # [[0 , 1 ], + # [2 , 3 ], + # [4 , 5 ], + # [6 , 7 ], + # [8 , 9 ], + # [10, 11], + # [12, 13], + # [14, 15]] + y0 = overlap_add(x0, hop_length=2, axis=-1) # [10] + # [0 , 2 , 5 , 9 , 13, 17, 21, 25, 13, 15] + + x1 = paddle.arange(16).reshape([2, 8]) + # [[0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ], + # [8 , 9 , 10, 11, 12, 13, 14, 15]] + y1 = overlap_add(x1, hop_length=2, axis=0) # [10] + # [0 , 1 , 10, 12, 14, 16, 18, 20, 14, 15] + + # > 2D + x0 = paddle.arange(32).reshape([2, 1, 8, 2]) + y0 = overlap_add(x0, hop_length=2, axis=-1) # [2, 1, 10] + + x1 = paddle.arange(32).reshape([2, 8, 1, 2]) + y1 = overlap_add(x1, hop_length=2, axis=0) # [10, 1, 2] + """ + if axis not in [0, -1]: + raise ValueError(f'Unexpected axis: {axis}. It should be 0 or -1.') + + if not isinstance(hop_length, int) or hop_length <= 0: + raise ValueError( + f'Unexpected hop_length: {hop_length}. It should be an positive integer.' + ) + + op_type = 'overlap_add' + + if in_dygraph_mode(): + attrs = ('hop_length', hop_length, 'axis', axis) + op = getattr(_C_ops, op_type) + out = op(x, *attrs) + else: + check_variable_and_dtype( + x, 'x', ['int32', 'int64', 'float16', 'float32', + 'float64'], op_type) + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name='x') + out = helper.create_variable_for_type_inference(dtype=dtype) + helper.append_op( + type=op_type, + inputs={'X': x}, + attrs={'hop_length': hop_length, + 'axis': axis}, + outputs={'Out': out}) + return out + + +def stft(x, + n_fft, + hop_length=None, + win_length=None, + window=None, + center=True, + pad_mode='reflect', + normalized=False, + onesided=True, + name=None): + """ + Short-time Fourier transform (STFT). 
+ + The STFT computes the discrete Fourier transforms (DFT) of short overlapping + windows of the input using this formula: + + .. math:: + X_t[\omega] = \sum_{n = 0}^{N-1}% + \text{window}[n]\ x[t \times H + n]\ % + e^{-{2 \pi j \omega n}/{N}} + + Where: + - :math:`t`: The :math:`t`-th input window. + - :math:`\omega`: Frequency :math:`0 \leq \omega < \text{n\_fft}` for `onesided=False`, + or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for `onesided=True`. + - :math:`N`: Value of `n_fft`. + - :math:`H`: Value of `hop_length`. + + Args: + x (Tensor): The input data which is a 1-dimensional or 2-dimensional Tensor with + shape `[..., seq_length]`. It can be a real-valued or a complex Tensor. + n_fft (int): The number of input samples to perform Fourier transform. + hop_length (int, optional): Number of steps to advance between adjacent windows + and `0 < hop_length`. Default: `None`(treated as equal to `n_fft//4`) + win_length (int, optional): The size of window. Default: `None`(treated as equal + to `n_fft`) + window (Tensor, optional): A 1-dimensional tensor of size `win_length`. It will + be center padded to length `n_fft` if `win_length < n_fft`. Default: `None`( + treated as a rectangle window with value equal to 1 of size `win_length`). + center (bool, optional): Whether to pad `x` to make that the + :math:`t \times hop\_length` at the center of :math:`t`-th frame. Default: `True`. + pad_mode (str, optional): Choose padding pattern when `center` is `True`. See + `paddle.nn.functional.pad` for all padding options. Default: `"reflect"` + normalized (bool, optional): Control whether to scale the output by `1/sqrt(n_fft)`. + Default: `False` + onesided (bool, optional): Control whether to return half of the Fourier transform + output that satisfies the conjugate symmetry condition when input is a real-valued + tensor. It can not be `True` if input is a complex tensor. Default: `True` + name (str, optional): The default value is None. Normally there is no need for user + to set this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + The complex STFT output tensor with shape `[..., n_fft//2 + 1, num_frames]`( + real-valued input and `onesided` is `True`) or `[..., n_fft, num_frames]`( + `onesided` is `False`) + + Exampels: + .. code-block:: python + + import paddle + from paddle.tensor.signal import stft + + # real-valued input + x = paddle.randn([8, 48000], dtype=paddle.float64) + y1 = stft(x, n_fft=512) # [8, 257, 376] + y2 = stft(x, n_fft=512, onesided=False) # [8, 512, 376] + + # complex input + x = paddle.randn([8, 48000], dtype=paddle.float64) + \ + paddle.randn([8, 48000], dtype=paddle.float64)*1j # [8, 48000] complex128 + y1 = stft(x, n_fft=512, center=False, onesided=False) # [8, 512, 372] + """ + check_variable_and_dtype( + x, 'x', ['float16', 'float32', 'float64', 'complex64', 'complex128'], + 'stft') + + x_rank = len(x.shape) + assert x_rank in [1, 2], \ + f'x should be a 1D or 2D real tensor, but got rank of x is {x_rank}' + + if x_rank == 1: # (batch, seq_length) + x = x.unsqueeze(0) + + if hop_length is None: + hop_length = int(n_fft // 4) + + assert hop_length > 0, \ + f'hop_length should be > 0, but got {hop_length}.' + + if win_length is None: + win_length = n_fft + + assert 0 < n_fft <= x.shape[-1], \ + f'n_fft should be in (0, seq_length({x.shape[-1]})], but got {n_fft}.' + + assert 0 < win_length <= n_fft, \ + f'win_length should be in (0, n_fft({n_fft})], but got {win_length}.' 
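+ # Prepare the analysis window: default to a rectangular window of ones, then center-pad it with zeros to length n_fft when win_length < n_fft.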
+ + if window is not None: + assert len(window.shape) == 1 and len(window) == win_length, \ + f'expected a 1D window tensor of size equal to win_length({win_length}), but got window with shape {window.shape}.' + else: + window = paddle.ones(shape=(win_length, ), dtype=x.dtype) + + if win_length < n_fft: + pad_left = (n_fft - win_length) // 2 + pad_right = n_fft - win_length - pad_left + window = paddle.nn.functional.pad(window, + pad=[pad_left, pad_right], + mode='constant') + + if center: + assert pad_mode in ['constant', 'reflect'], \ + 'pad_mode should be "reflect" or "constant", but got "{}".'.format(pad_mode) + + pad_length = n_fft // 2 + # FIXME: Input `x` can be a complex tensor but pad does not support complex input. + x = paddle.nn.functional.pad(x.unsqueeze(-1), + pad=[pad_length, pad_length], + mode=pad_mode, + data_format="NLC").squeeze(-1) + + x_frames = frame(x=x, frame_length=n_fft, hop_length=hop_length, axis=-1) + x_frames = x_frames.transpose( + perm=[0, 2, + 1]) # switch n_fft to last dim, e.g. (batch, num_frames, n_fft) + x_frames = x_frames * window + + norm = 'ortho' if normalized else 'backward' + if is_complex(x_frames): + assert not onesided, \ + 'onesided should be False when input or window is a complex Tensor.' + + if not is_complex(x): + out = fft_r2c( + x=x_frames, + n=None, + axis=-1, + norm=norm, + forward=True, + onesided=onesided, + name=name) + else: + out = fft_c2c( + x=x_frames, n=None, axis=-1, norm=norm, forward=True, name=name) + + out = out.transpose(perm=[0, 2, 1]) # (batch, n_fft, num_frames) + + if x_rank == 1: + out.squeeze_(0) + + return out + + +def istft(x, + n_fft, + hop_length=None, + win_length=None, + window=None, + center=True, + normalized=False, + onesided=True, + length=None, + return_complex=False, + name=None): + """ + Inverse short-time Fourier transform (ISTFT). + + Reconstruct the time-domain signal from the given complex input and window tensor when + the nonzero overlap-add (NOLA) condition is met: + + .. math:: + \sum_{t = -\infty}^{\infty}% + \text{window}^2[n - t \times H]\ \neq \ 0, \ \text{for all } n + + Where: + - :math:`t`: The :math:`t`-th input window. + - :math:`N`: Value of `n_fft`. + - :math:`H`: Value of `hop_length`. + + The result of `istft` is expected to be the inverse of `paddle.tensor.signal.stft`, but it is + not guaranteed to reconstruct an exactly realizable time-domain signal from an STFT + complex tensor which has been modified (via masking or otherwise). Therefore, `istft` + gives the [Griffin-Lim optimal estimate](https://ieeexplore.ieee.org/document/1164317) + (optimal in a least-squares sense) for the corresponding signal. + + Args: + x (Tensor): The input data which is a 2-dimensional or 3-dimensional **complex** + Tensor with shape `[..., n_fft, num_frames]`. + n_fft (int): The size of Fourier transform. + hop_length (int, optional): Number of steps to advance between adjacent windows + from time-domain signal and `0 < hop_length < win_length`. Default: `None`( + treated as equal to `n_fft//4`) + win_length (int, optional): The size of window. Default: `None`(treated as equal + to `n_fft`) + window (Tensor, optional): A 1-dimensional tensor of size `win_length`. It will + be center padded to length `n_fft` if `win_length < n_fft`. It should be a + real-valued tensor if `return_complex` is False. Default: `None`(treated as + a rectangular window with value equal to 1 of size `win_length`). + center (bool, optional): Whether the time-domain signal has been + center padded. Default: `True`.
+ normalized (bool, optional): Control whether to scale the output by `1/sqrt(n_fft)`. + Default: `False` + onesided (bool, optional): It means that whether the input STFT tensor is a half + of the conjugate symmetry STFT tensor transformed from a real-valued signal + and `istft` will return a real-valued tensor when it is set to `True`. + Default: `True`. + length (int, optional): Specify the length of time-domain signal. Default: `None`( + treated as the whole length of signal). + return_complex (bool, optional): It means that whether the time-domain signal is + real-valued. If `return_complex` is set to `True`, `onesided` should be set to + `False` cause the output is complex. + name (str, optional): The default value is None. Normally there is no need for user + to set this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A tensor of least squares estimation of the reconstructed signal(s) with shape + `[..., seq_length]` + + Exampels: + .. code-block:: python + + import numpy as np + import paddle + from paddle.tensor.signal import stft, istft + + paddle.seed(0) + + # STFT + x = paddle.randn([8, 48000], dtype=paddle.float64) + y = stft(x, n_fft=512) # [8, 257, 376] + + # ISTFT + x_ = istft(y, n_fft=512) # [8, 48000] + + np.allclose(x, x_) # True + """ + check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], 'istft') + + x_rank = len(x.shape) + assert x_rank in [2, 3], \ + 'x should be a 2D or 3D complex tensor, but got rank of x is {}'.format(x_rank) + + if x_rank == 2: # (batch, n_fft, n_frames) + x = x.unsqueeze(0) + + if hop_length is None: + hop_length = int(n_fft // 4) + + if win_length is None: + win_length = n_fft + + # Assure no gaps between frames. + assert 0 < hop_length <= win_length, \ + 'hop_length should be in (0, win_length({})], but got {}.'.format(win_length, hop_length) + + assert 0 < win_length <= n_fft, \ + 'win_length should be in (0, n_fft({})], but got {}.'.format(n_fft, win_length) + + n_frames = x.shape[-1] + fft_size = x.shape[-2] + + if onesided: + assert (fft_size == n_fft // 2 + 1), \ + 'fft_size should be equal to n_fft // 2 + 1({}) when onesided is True, but got {}.'.format(n_fft // 2 + 1, fft_size) + else: + assert (fft_size == n_fft), \ + 'fft_size should be equal to n_fft({}) when onesided is False, but got {}.'.format(n_fft, fft_size) + + if window is not None: + assert len(window.shape) == 1 and len(window) == win_length, \ + 'expected a 1D window tensor of size equal to win_length({}), but got window with shape {}.'.format(win_length, window.shape) + else: + window = paddle.ones(shape=(win_length, )) + + if win_length < n_fft: + pad_left = (n_fft - win_length) // 2 + pad_right = n_fft - win_length - pad_left + # FIXME: Input `window` can be a complex tensor but pad does not supprt complex input. + window = paddle.nn.functional.pad(window, + pad=[pad_left, pad_right], + mode='constant') + + x = x.transpose( + perm=[0, 2, + 1]) # switch n_fft to last dim, egs: (batch, num_frames, n_fft) + norm = 'ortho' if normalized else 'backward' + + if return_complex: + assert not onesided, \ + 'onesided should be False when input(output of istft) or window is a complex Tensor.' + + out = fft_c2c(x=x, n=None, axis=-1, norm=norm, forward=False, name=None) + else: + assert not is_complex(window), \ + 'Data type of window should not be complex when return_complex is False.' 
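+ # Real-valued output: fft_c2r expects a onesided spectrum, so when onesided is False keep only the first n_fft // 2 + 1 bins before the inverse transform.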
+ + if onesided is False: + x = x[:, :, :n_fft // 2 + 1] + out = fft_c2r(x=x, n=None, axis=-1, norm=norm, forward=False, name=None) + + out = overlap_add( + x=(out * window).transpose( + perm=[0, 2, 1]), # (batch, n_fft, num_frames) + hop_length=hop_length, + axis=-1) # (batch, seq_length) + + window_envelop = overlap_add( + x=paddle.tile( + x=window * window, repeat_times=[n_frames, 1]).transpose( + perm=[1, 0]), # (n_fft, num_frames) + hop_length=hop_length, + axis=-1) # (seq_length, ) + + if length is None: + if center: + out = out[:, (n_fft // 2):-(n_fft // 2)] + window_envelop = window_envelop[(n_fft // 2):-(n_fft // 2)] + else: + if center: + start = n_fft // 2 + else: + start = 0 + + out = out[:, start:start + length] + window_envelop = window_envelop[start:start + length] + + # Check whether the Nonzero Overlap Add (NOLA) constraint is met. + if window_envelop.abs().min().item() < 1e-11: + raise ValueError( + 'Abort istft because Nonzero Overlap Add (NOLA) condition failed. For more information about NOLA constraint please see `scipy.signal.check_NOLA`(https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.check_NOLA.html).' + ) + + out = out / window_envelop + + if x_rank == 2: + out.squeeze_(0) + + return out diff --git a/python/unittest_py/requirements.txt b/python/unittest_py/requirements.txt index 8fd1be69a3d..0a793fc64d0 100644 --- a/python/unittest_py/requirements.txt +++ b/python/unittest_py/requirements.txt @@ -6,6 +6,7 @@ gym opencv-python<=4.2.0.32 visualdl paddle2onnx>=0.4 -scipy +scipy>=1.6 prettytable distro +numpy>=1.20 -- GitLab