未验证 提交 11518a43 编写于 作者: F Feiyu Chan 提交者: GitHub

Add FFT related operators and APIs (#35665)

* 1. add interface for fft;
2. add data type predicate;
3. fix paddle.roll.

* add fft c2c cufft kernel

* implement argument checking & op calling parts for fft_c2c and fftn_c2c

* add operator and opmaker definitions

* only register float and double for cpu.

* add common code for implementing FFT, add pocketfft as a dependency

* add fft c2c cufft kernel function

* fix bugs in python interface

* add support for c2r, r2c operators, op makers, kernels and kernel functors.

* test and fix bugs

* 1. fft_c2c function: add support for onesided=False;
2. add complex<float>, complex<double> support for concat and flip.

* 1. fft: fix python api bugs;
2. shape_op: add support for complex data types.

* fft c2c cufft kernel done with complie and link

* fix shape_op, add mkl placeholder

* remove mkl

* complete fft c2c in gpu

* 1. implement mkl-based fft, FFTC2CFunctor and common function exec_fft;
2. change the design, add input and output typename as template parameter for all FFTFunctors, update pocketfft-based implementation.

* complete fft c2c on gpu in ND

* complete fft c2c on gpu in ND

* complete fft c2c backward in ND

* fix MKL-based implementation

* Add frame op and CPU/GPU kernels.

* Add frame op forward unittest.

* Add frame op forward unittest.

* Remove axis parameter in FrameFunctor.

* Add frame op grad CPU/GPU kernels and unittest.

* Add frame op grad CPU/GPU kernels and unittest.

* Update doc string.

* Update after review and remove librosa requirement in unittest.

* Update grad kernel.

* add fft_c2r op

* Remove data allocation in TransCompute function.

* add fft r2c onesided with cpu(pocketfft/mkl) and gpu

* last fft c2r functor

* fix C2R and R2C for cufft, becase the direction is not an option in these cases.

* add fft r2c onesided with cpu(pocketfft/mkl) and gpu

* fix bugs in python APIs

* fix fft_c2r grad kernal

* fix bugs in python APIs

* add cuda fft c2r grad kernal functor

* clean code

* fix fft_c2r python API

* fill fft r2c result with conjugate symmetry (#19)

fill fft r2c result with conjugate symmetry

* add placeholder for unittests (#24)

* simple parameterize test function by auto generate test case from parm list (#25)

* miscellaneous fixes for python APIs (#26)

* add placeholder for unittests

* resize fft inputs before computation is n or s is provided.

* add complex kernels for pad and pad_grad

* simplify argument checking.

* add type promotion

* add int to float or complex promotion

* fix output data type for static mode

* fix fft's input dtype dispatch, import fft to paddle

* fix typos in axes checking (#27)

* fix typos in axes checking

* fix argument checking (#28)

* fix argument checking

* Add C2R Python layer normal and abnormal use cases (#29)

* documents and single case

* test c2r case

* New C2R Python layer normal and exception use cases

* complete rfft,rfft2,rfftn,ihfft,ihfft2,ihfftn unittest and doc string (#30)

* Documentation of the common interfaces of c2r and c2c (#31)

* Documentation of the common interfaces of c2r and c2c

* clean c++ code  (#32)

* clean code

* Add numpy-based implementation of spectral ops (#33)

* add numpy reference implementation of spectral ops

* Add fft_c2r numpy based implementation for unittest. (#34)

* add fft_c2r numpy implementation

* Add deframe op and stft/istft api. (#23)

* Add frame api

* Add deframe op and kernels.

* Add stft and istft apis.

* Add deframe api. Update stft and istft apis.

* Fix bug in frame_from_librosa function when input dims >= 3

* Rename deframe to overlap_add.

* Update istft.

* Update after code review.

* Add overlap_add op and stft/istft api unittest (#35)

* Add overlap_add op unittest.

* Register complex kernels of squeeze/unsquuze op.

* Add stft/istft api unittest.

* Add unittest for fft helper functions (#36)

* add unittests for fft helper functions. add complex kernel for roll op.

* complete static graph unittest for all public api (#37)

* Unittest of op with FFT C2C, C2R and r2c added (#38)

* documents and single case

* test c2r case

* New C2R Python layer normal and exception use cases

* Documentation of the common interfaces of c2r and c2c

* Unittest of op with FFT C2C, C2R and r2c added
Co-authored-by: lijiaqi0612's avatarlijiaqi <lijiaqi0612@163.com>

* add fft related options to CMakeLists.txt

* fix typos and clean code (#39)

* fix invisible character in mkl branch and fix error in error message

* clean code: remove docstring from unittest for signal.py.

* always convert numpy array to paddle.Tensor to avoid comparing numpy dtype with paddle dtype. (#40)

* always convert numpy array to paddle.Tensor to avoid comparing numpy dtype with paddle dtype.

* fix CI Errors: numpy dtype comparison, thrust when cuda is not available (#41)

1. always convert numpy array to paddle.Tensor to avoid comparing numpy dtype with paddle dtype.
2. promote floating point tensor to complex tensor ior fft_c2c and fft_c2r;
3. fix unittest to catch UnImplementedError and RuntimeError;
4. fix compile error by avoid using thrust when cuda is not available.
5.  fix sample code, use paddle.fft instead of paddle.tensor.fft

* remove inclusion of thrust, add __all__ list for fft (#42)

* Add api doc and update unittest. (#43)

* Add doc strings.
* Update overlap_add op unittest

* fix MKL-based FFT implementation (#44)

* fix MKL-based FFT implementation, MKL CDFT's FORWARD DOMAIN is always REAL for R2C and C2R

* remove code for debug (#45)

* use dynload for cufft (#46)

* use std::ptrdiff_t as datatype of stride (instead of int64_t) to avoid argument mismatch on some platforms.

* add complex support for fill_zeros_like

* use dynload for cufft

* Update doc and unittest. (#47)

* Add doc of frame op and overlap_add op.

* Update unittest.

* use dynload for cufft (#48)

1. use dynload for cufft
2. fix unittest;
3. temporarily disable Rocm.

* fix conflicts and merge upstream (#49)

fix conflicts and merge upstream

* fix compile error: only link dyload_cuda when cuda is available (#50)

* fix compile error: only link dyload_cuda when cuda is available

* fix dynload for cufft on windows (#51)

1. fix dynload for cufft on windows;
2. fix unittests.

* add NOMINMAX to compile on windows (#52)

 add NOMINMAX to compile on windows

* explicitly specify capture mode for lambdas (#55)

 explicitly specify capture mode for lambdas

* fix fft sample (#53)

* fix fft sample

* update scipy and numpy version for unittests of fft (#56)

update scipy and numpy version for unittests of fft

* Add static graph unittests of frame and overlap_add api. (#57)

* Remove cache of cuFFT & Disable ONEMKL (#59)

1. replace numpy.fft with scipy.fft as numpy<1.20 not support ortho norm
2. remove cache of cufft plans;
3. enhance error checking.
4. default WITH_ONEMKL to OFF
Co-authored-by: Njeff41404 <jeff41404@gmail.com>
Co-authored-by: Nroot <root@bjyz-sys-gpu-kongming9.bjyz.baidu.com>
Co-authored-by: NKP <109694228@qq.com>
Co-authored-by: lijiaqi0612's avatarlijiaqi <lijiaqi0612@163.com>
Co-authored-by: NXiaoxu Chen <chenxx_id@163.com>
Co-authored-by: Nlijiaqi0612 <33169170+lijiaqi0612@users.noreply.github.com>
上级 01063218
...@@ -38,6 +38,8 @@ project(paddle CXX C) ...@@ -38,6 +38,8 @@ project(paddle CXX C)
# enable language CUDA # enable language CUDA
# TODO(Shibo Tao): remove find_package(CUDA) completely. # TODO(Shibo Tao): remove find_package(CUDA) completely.
find_package(CUDA QUIET) find_package(CUDA QUIET)
find_package(MKL CONFIG QUIET)
option(WITH_ONEMKL "Compile PaddlePaddle with oneMKL" OFF)
option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND}) option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND})
option(WITH_TENSORRT "Compile PaddlePaddle with NVIDIA TensorRT" OFF) option(WITH_TENSORRT "Compile PaddlePaddle with NVIDIA TensorRT" OFF)
option(WITH_XPU "Compile PaddlePaddle with BAIDU KUNLUN XPU" OFF) option(WITH_XPU "Compile PaddlePaddle with BAIDU KUNLUN XPU" OFF)
...@@ -225,6 +227,7 @@ option(WITH_STRIP "Strip so files of Whl packages" OFF) ...@@ -225,6 +227,7 @@ option(WITH_STRIP "Strip so files of Whl packages" OFF)
option(NEW_RELEASE_CUBIN "PaddlePaddle next-level release strategy for pypi cubin package" OFF) option(NEW_RELEASE_CUBIN "PaddlePaddle next-level release strategy for pypi cubin package" OFF)
option(NEW_RELEASE_JIT "PaddlePaddle next-level release strategy for backup jit package" OFF) option(NEW_RELEASE_JIT "PaddlePaddle next-level release strategy for backup jit package" OFF)
option(WITH_ASCEND_INT64 "Compile with int64 kernel for ascend NPU" OFF) option(WITH_ASCEND_INT64 "Compile with int64 kernel for ascend NPU" OFF)
option(WITH_POCKETFFT "Compile with pocketfft support" ON)
# PY_VERSION # PY_VERSION
if(NOT PY_VERSION) if(NOT PY_VERSION)
...@@ -373,6 +376,10 @@ if (WITH_MIPS) ...@@ -373,6 +376,10 @@ if (WITH_MIPS)
add_definitions(-DPADDLE_WITH_MIPS) add_definitions(-DPADDLE_WITH_MIPS)
endif() endif()
if (WITH_ONEMKL)
add_definitions(-DPADDLE_WITH_ONEMKL)
endif()
if (WITH_HETERPS) if (WITH_HETERPS)
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -faligned-new") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -faligned-new")
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
include(ExternalProject)
set(POCKETFFT_PATH "${THIRD_PARTY_PATH}/pocketfft" CACHE STRING "A path setting for external_pocketfft path.")
set(POCKETFFT_PREFIX_DIR ${POCKETFFT_PATH})
set(POCKETFFT_REPOSITORY https://gitlab.mpcdf.mpg.de/mtr/pocketfft.git)
set(POCKETFFT_TAG release_for_eigen)
SET(POCKETFFT_INCLUDE_DIR ${POCKETFFT_PREFIX_DIR}/src)
message("POCKETFFT_INCLUDE_DIR is ${POCKETFFT_INCLUDE_DIR}")
include_directories(${POCKETFFT_INCLUDE_DIR})
ExternalProject_Add(
extern_pocketfft
${EXTERNAL_PROJECT_LOG_ARGS}
${SHALLOW_CLONE}
GIT_REPOSITORY ${POCKETFFT_REPOSITORY}
GIT_TAG ${POCKETFFT_TAG}
PREFIX ${POCKETFFT_PREFIX_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)
add_library(pocketfft INTERFACE)
add_dependencies(pocketfft extern_pocketfft)
...@@ -361,4 +361,10 @@ if (WITH_CRYPTO) ...@@ -361,4 +361,10 @@ if (WITH_CRYPTO)
add_definitions(-DPADDLE_WITH_CRYPTO) add_definitions(-DPADDLE_WITH_CRYPTO)
endif (WITH_CRYPTO) endif (WITH_CRYPTO)
if (WITH_POCKETFFT)
include(external/pocketfft)
list(APPEND third_party_deps extern_pocketfft)
add_definitions(-DPADDLE_WITH_POCKETFFT)
endif (WITH_POCKETFFT)
add_custom_target(third_party ALL DEPENDS ${third_party_deps}) add_custom_target(third_party ALL DEPENDS ${third_party_deps})
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <iostream>
#include <string> #include <string>
#include <typeindex> #include <typeindex>
...@@ -170,11 +171,26 @@ extern inline proto::VarType::Type ToComplexType(proto::VarType::Type t) { ...@@ -170,11 +171,26 @@ extern inline proto::VarType::Type ToComplexType(proto::VarType::Type t) {
return proto::VarType::COMPLEX128; return proto::VarType::COMPLEX128;
default: default:
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unknown complex value data type (%s), now only support float32 and " "Unknown real value data type (%s), now only support float32 and "
"float64.", "float64.",
DataTypeToString(t))); DataTypeToString(t)));
} }
} }
extern inline proto::VarType::Type ToRealType(proto::VarType::Type t) {
switch (t) {
case proto::VarType::COMPLEX64:
return proto::VarType::FP32;
case proto::VarType::COMPLEX128:
return proto::VarType::FP64;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unknown complex value data type (%s), now only support complex64 "
"and "
"complex128.",
DataTypeToString(t)));
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -59,6 +59,10 @@ if (WITH_GPU) ...@@ -59,6 +59,10 @@ if (WITH_GPU)
endif() endif()
endif() endif()
if (WITH_POCKETFFT)
SET(OP_HEADER_DEPS ${OP_HEADER_DEPS} pocketfft)
endif()
SET(OP_MKL_DEPS "") SET(OP_MKL_DEPS "")
if (NOT WITH_MKL OR NOT WITH_AVX) if (NOT WITH_MKL OR NOT WITH_AVX)
...@@ -75,7 +79,7 @@ if(WITH_UNITY_BUILD) ...@@ -75,7 +79,7 @@ if(WITH_UNITY_BUILD)
endif() endif()
register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op lstm_op run_program_op eye_op recurrent_op register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op lstm_op run_program_op eye_op recurrent_op
sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS}) sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS}) op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS})
...@@ -94,6 +98,12 @@ else() ...@@ -94,6 +98,12 @@ else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif() endif()
if (WITH_GPU AND (NOT WITH_ROCM))
op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda ${OP_HEADER_DEPS})
else()
op_library(spectral_op SRCS spectral_op.cc DEPS ${OP_HEADER_DEPS})
endif()
op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute) op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute)
op_library(eye_op DEPS ${OP_HEADER_DEPS}) op_library(eye_op DEPS ${OP_HEADER_DEPS})
op_library(recurrent_op DEPS ${OP_HEADER_DEPS}) op_library(recurrent_op DEPS ${OP_HEADER_DEPS})
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/concat_op.h" #include "paddle/fluid/operators/concat_op.h"
#include <paddle/fluid/platform/complex.h>
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -237,7 +238,11 @@ REGISTER_OP_CPU_KERNEL( ...@@ -237,7 +238,11 @@ REGISTER_OP_CPU_KERNEL(
ops::ConcatKernel<paddle::platform::CPUDeviceContext, ops::ConcatKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>, paddle::platform::float16>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, int>, ops::ConcatKernel<paddle::platform::CPUDeviceContext, int>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, uint8_t>); ops::ConcatKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
concat_grad, concat_grad,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, double>, ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, double>,
...@@ -247,4 +252,8 @@ REGISTER_OP_CPU_KERNEL( ...@@ -247,4 +252,8 @@ REGISTER_OP_CPU_KERNEL(
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, ops::ConcatGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>, paddle::platform::float16>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int>, ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, uint8_t>); ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/concat_op.h" #include "paddle/fluid/operators/concat_op.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
...@@ -24,7 +25,11 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -24,7 +25,11 @@ REGISTER_OP_CUDA_KERNEL(
ops::ConcatKernel<paddle::platform::CUDADeviceContext, plat::float16>, ops::ConcatKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, int64_t>, ops::ConcatKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, int>, ops::ConcatKernel<paddle::platform::CUDADeviceContext, int>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, uint8_t>); ops::ConcatKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext,
plat::complex<float>>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext,
plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
concat_grad, concat_grad,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, double>, ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, double>,
...@@ -33,4 +38,8 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -33,4 +38,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, plat::float16>, ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int64_t>, ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int>, ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, uint8_t>); ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext,
plat::complex<float>>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext,
plat::complex<double>>);
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/eigen/eigen_function.h" #include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -42,6 +43,8 @@ template struct EigenScale<Eigen::DefaultDevice, int8_t>; ...@@ -42,6 +43,8 @@ template struct EigenScale<Eigen::DefaultDevice, int8_t>;
template struct EigenScale<Eigen::DefaultDevice, int16_t>; template struct EigenScale<Eigen::DefaultDevice, int16_t>;
template struct EigenScale<Eigen::DefaultDevice, int>; template struct EigenScale<Eigen::DefaultDevice, int>;
template struct EigenScale<Eigen::DefaultDevice, int64_t>; template struct EigenScale<Eigen::DefaultDevice, int64_t>;
template struct EigenScale<Eigen::DefaultDevice, platform::complex<float>>;
template struct EigenScale<Eigen::DefaultDevice, platform::complex<double>>;
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/eigen/eigen_function.h" #include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
namespace paddle { namespace paddle {
...@@ -41,6 +42,8 @@ template struct EigenScale<Eigen::GpuDevice, int16_t>; ...@@ -41,6 +42,8 @@ template struct EigenScale<Eigen::GpuDevice, int16_t>;
template struct EigenScale<Eigen::GpuDevice, int>; template struct EigenScale<Eigen::GpuDevice, int>;
template struct EigenScale<Eigen::GpuDevice, int64_t>; template struct EigenScale<Eigen::GpuDevice, int64_t>;
template struct EigenScale<Eigen::GpuDevice, platform::float16>; template struct EigenScale<Eigen::GpuDevice, platform::float16>;
template struct EigenScale<Eigen::GpuDevice, platform::complex<float>>;
template struct EigenScale<Eigen::GpuDevice, platform::complex<double>>;
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/fill_zeros_like_op.h" #include "paddle/fluid/operators/fill_zeros_like_op.h"
#include "paddle/fluid/platform/complex.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -93,7 +94,11 @@ REGISTER_OP_CPU_KERNEL( ...@@ -93,7 +94,11 @@ REGISTER_OP_CPU_KERNEL(
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int64_t>, ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, float>, ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, float>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, double>, ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>); ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
fill_zeros_like2, fill_zeros_like2,
...@@ -101,4 +106,8 @@ REGISTER_OP_CPU_KERNEL( ...@@ -101,4 +106,8 @@ REGISTER_OP_CPU_KERNEL(
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int64_t>, ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, float>, ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, float>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, double>, ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>); ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fill_zeros_like_op.h" #include "paddle/fluid/operators/fill_zeros_like_op.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
...@@ -25,7 +26,11 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -25,7 +26,11 @@ REGISTER_OP_CUDA_KERNEL(
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>, ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>, paddle::platform::float16>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>); ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
fill_zeros_like2, fill_zeros_like2,
...@@ -35,4 +40,8 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -35,4 +40,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>, ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>, paddle::platform::float16>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>); ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/complex.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -145,6 +146,7 @@ class FlipOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -145,6 +146,7 @@ class FlipOpGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(flip, ops::FlipOp, ops::FlipOpMaker, ops::FlipOpInferVarType, REGISTER_OPERATOR(flip, ops::FlipOp, ops::FlipOpMaker, ops::FlipOpInferVarType,
ops::FlipOpGradMaker<paddle::framework::OpDesc>, ops::FlipOpGradMaker<paddle::framework::OpDesc>,
ops::FlipOpGradMaker<paddle::imperative::OpBase>); ops::FlipOpGradMaker<paddle::imperative::OpBase>);
...@@ -153,7 +155,9 @@ REGISTER_OP_CPU_KERNEL( ...@@ -153,7 +155,9 @@ REGISTER_OP_CPU_KERNEL(
ops::FlipKernel<paddle::platform::CPUDeviceContext, double>, ops::FlipKernel<paddle::platform::CPUDeviceContext, double>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, int32_t>, ops::FlipKernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, int64_t>, ops::FlipKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, bool>); ops::FlipKernel<paddle::platform::CPUDeviceContext, bool>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, plat::complex<float>>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, plat::complex<double>>);
/* ========================== register checkpoint ===========================*/ /* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(flip) REGISTER_OP_VERSION(flip)
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/complex.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -163,4 +164,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -163,4 +164,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::FlipKernel<paddle::platform::CUDADeviceContext, plat::float16>, ops::FlipKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, int>, ops::FlipKernel<paddle::platform::CUDADeviceContext, int>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, int64_t>, ops::FlipKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, bool>); ops::FlipKernel<paddle::platform::CUDADeviceContext, bool>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, plat::complex<float>>,
ops::FlipKernel<paddle::platform::CUDADeviceContext,
plat::complex<double>>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/frame_op.h"
namespace paddle {
namespace operators {
class FrameOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "frame");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "frame");
const int frame_length = ctx->Attrs().Get<int>("frame_length");
const int hop_length = ctx->Attrs().Get<int>("hop_length");
const int axis = ctx->Attrs().Get<int>("axis");
const auto x_dims = ctx->GetInputDim("X");
const int x_rank = x_dims.size();
PADDLE_ENFORCE_GE(
x_rank, 1, platform::errors::InvalidArgument(
"Input(X) of FrameOp should be a tensor which contains "
"at least 1 dimension, but got rank %s.",
x_rank));
PADDLE_ENFORCE_GT(hop_length, 0,
platform::errors::InvalidArgument(
"Attribute(hop_length) of FrameOp should be greater "
"than 0, but got %s.",
hop_length));
PADDLE_ENFORCE_EQ(
(axis == 0 || axis == -1), true,
platform::errors::InvalidArgument(
"Attribute(axis) of FrameOp should 0 or -1, but got %s.", axis));
std::vector<int64_t> output_shape;
int seq_length;
int n_frames;
int start_axis;
int end_axis;
if (axis == 0) {
seq_length = x_dims[0];
start_axis = 1;
end_axis = x_rank - 1;
} else {
seq_length = x_dims[x_rank - 1];
start_axis = 0;
end_axis = x_rank - 2;
}
PADDLE_ENFORCE_LE(frame_length, seq_length,
platform::errors::InvalidArgument(
"Attribute(frame_length) of FrameOp should be less "
"equal than sequence length, but got (%s) > (%s).",
frame_length, seq_length));
// It won't go into for loop when x_rank == 1U.
for (int i = start_axis; i <= end_axis; i++) {
output_shape.push_back(x_dims[i]);
}
n_frames = 1 + (seq_length - frame_length) / hop_length;
if (axis == 0) {
// (n_frames, frame_length, ...)
output_shape.insert(output_shape.begin(), frame_length);
output_shape.insert(output_shape.begin(), n_frames);
} else {
// (..., frame_length, n_frames)
output_shape.push_back(frame_length);
output_shape.push_back(n_frames);
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(in_dtype, ctx.GetPlace());
}
};
class FrameOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of frame op.");
AddOutput("Out", "(Tensor), The output tensor of frame op.");
AddAttr<int>(
"frame_length",
"Length of the frame and `0 < frame_length <= x.shape[axis]`.");
AddAttr<int>("hop_length",
"Number of steps to advance between adjacent frames and "
"`0 < hop_length`.");
AddAttr<int>("axis",
"Specify the axis to operate on the input Tensors. Its value "
"should be 0(the first dimension) or -1(the last dimension).")
.SetDefault(-1);
AddComment(R"DOC(
Slice the N-dimensional (where N >= 1) input into (overlapping) frames.
)DOC");
}
};
class FrameOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "frame_grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "frame_grad");
const auto x_dims = ctx->GetInputDim("X");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(in_dtype, ctx.GetPlace());
}
};
template <typename T>
class FrameOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("frame_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(frame, ops::FrameOp, ops::FrameOpMaker,
ops::FrameOpGradMaker<paddle::framework::OpDesc>,
ops::FrameOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(frame_grad, ops::FrameOpGrad);
REGISTER_OP_CPU_KERNEL(
frame, ops::FrameKernel<paddle::platform::CPUDeviceContext, int>,
ops::FrameKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FrameKernel<paddle::platform::CPUDeviceContext, float>,
ops::FrameKernel<paddle::platform::CPUDeviceContext, double>,
ops::FrameKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::FrameKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
frame_grad, ops::FrameGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::FrameGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FrameGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::FrameGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::FrameGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::FrameGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/frame_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
frame, ops::FrameKernel<paddle::platform::CUDADeviceContext, int>,
ops::FrameKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FrameKernel<paddle::platform::CUDADeviceContext, float>,
ops::FrameKernel<paddle::platform::CUDADeviceContext, double>,
ops::FrameKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FrameKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::FrameKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
frame_grad, ops::FrameGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::FrameGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FrameGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::FrameGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::FrameGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FrameGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::FrameGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/seq2col.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
struct FrameFunctor {
void operator()(const DeviceContext& dev_ctx, const Tensor* input,
Tensor* output, size_t seq_length, size_t frame_length,
size_t n_frames, size_t hop_length,
bool is_grad = false) const {
auto numel = output->numel();
const auto* input_data = input->data<T>();
auto* output_data = output->data<T>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
if (!is_grad) {
math::Seq2ColFunctor<T> functor(input_data, output_data, seq_length,
frame_length, n_frames, hop_length);
for_range(functor);
} else {
math::Col2SeqFunctor<T> functor(input_data, output_data, seq_length,
frame_length, n_frames, hop_length);
for_range(functor);
}
}
};
template <typename DeviceContext, typename T>
class FrameKernel : public framework::OpKernel<T> {
public:
/*
Frame kernel slices frames from input sequences. The main steps as follows:
- Case 1 - input dims == 1:
- axis is -1: Call a FrameFunctor to compute directly.
- axis is 0: Transpose output firstly, and then it falls into
case axis is -1. Finally, it restores the dims of
output tensor.
- Case 2 - input dims == 2:
- axis is -1: Call a FrameFunctor to compute directly.
- axis is 0: Transpose both input and output firstly, and then it falls
into case axis is -1. Finally, it restores the dims of
output tensor.
- Case 3 - input dims > 2:
Flatten the input and output to 2D and 3D respectively so that it
falls into Case 2. Finally, it restores the dims of output tensor.
*/
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* x = ctx.Input<Tensor>("X");
Tensor* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
const size_t x_rank = x->dims().size();
const size_t out_rank = out->dims().size();
const int frame_length = ctx.Attr<int>("frame_length");
const int hop_length = ctx.Attr<int>("hop_length");
const int axis = ctx.Attr<int>("axis");
const int n_frames =
(axis == 0) ? out->dims()[0] : out->dims()[out_rank - 1];
const int seq_length = (axis == 0) ? x->dims()[0] : x->dims()[x_rank - 1];
auto& dev_ctx = ctx.device_context<DeviceContext>();
// When the number of input dims is larger than 2, it needs to copy
// from x to resize input into 2d and output into 3d. Morevoer, output
// dims will be restored at the last step.
Tensor x_(x->type());
x_ = *x;
framework::DDim preserved_dims;
if (x_rank > 2) {
// Save dims used to flatten both input and output tensors and restore
// output tensor.
framework::DDim x_resized_dims;
framework::DDim out_resized_dims;
if (axis == 0) {
preserved_dims = framework::slice_ddim(x_.dims(), 1, x_rank);
x_resized_dims = {seq_length, framework::product(preserved_dims)};
out_resized_dims = {n_frames, frame_length,
framework::product(preserved_dims)};
} else {
preserved_dims = framework::slice_ddim(x_.dims(), 0, x_rank - 1);
x_resized_dims = {framework::product(preserved_dims), seq_length};
out_resized_dims = {framework::product(preserved_dims), frame_length,
n_frames};
}
x_.Resize(x_resized_dims);
out->Resize(out_resized_dims);
}
Tensor trans_x(x_.type());
Tensor trans_out(out->type());
// Transpose input and output in case that axis is 0.
if (axis == 0) {
if (x_rank == 1U) {
trans_x = x_;
std::vector<int> perm_out{1, 0};
auto out_dims_vec = framework::vectorize(out->dims());
for (int i = 0; i < out->dims().size(); ++i) {
out_dims_vec[i] = out->dims()[perm_out[i]];
}
trans_out.Resize(framework::make_ddim(out_dims_vec));
trans_out.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_out.size(), dev_ctx, *out,
&trans_out, perm_out);
} else {
std::vector<int> perm_x{1, 0};
auto x_dims_vec = framework::vectorize(x_.dims());
for (int i = 0; i < x_.dims().size(); ++i) {
x_dims_vec[i] = x_.dims()[perm_x[i]];
}
trans_x.Resize(framework::make_ddim(x_dims_vec));
trans_x.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_x.size(), dev_ctx, x_, &trans_x,
perm_x);
std::vector<int> perm_out{2, 1, 0};
auto out_dims_vec = framework::vectorize(out->dims());
for (int i = 0; i < out->dims().size(); ++i) {
out_dims_vec[i] = out->dims()[perm_out[i]];
}
trans_out.Resize(framework::make_ddim(out_dims_vec));
trans_out.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_out.size(), dev_ctx, *out,
&trans_out, perm_out);
}
} else {
trans_x = x_;
trans_out = *out;
}
FrameFunctor<DeviceContext, T>()(dev_ctx, &trans_x, &trans_out, seq_length,
frame_length, n_frames, hop_length,
/*is_grad*/ false);
// Transpose output in case axis is 0.
if (axis == 0) {
if (x_rank == 1U) {
std::vector<int> perm_out{1, 0};
TransCompute<DeviceContext, T>(perm_out.size(), dev_ctx, trans_out, out,
perm_out);
} else {
std::vector<int> perm_out{2, 1, 0};
TransCompute<DeviceContext, T>(perm_out.size(), dev_ctx, trans_out, out,
perm_out);
}
}
// Restore output dims when the number of dims is larger than 2.
if (x_rank > 2) {
std::vector<int64_t> restored_out_shape;
for (int i = 0; i < preserved_dims.size(); i++) {
restored_out_shape.push_back(preserved_dims[i]);
}
if (axis == 0) {
// (n_frames, frame_length, ...)
restored_out_shape.insert(restored_out_shape.begin(), frame_length);
restored_out_shape.insert(restored_out_shape.begin(), n_frames);
} else {
// (..., frame_length, n_frames)
restored_out_shape.push_back(frame_length);
restored_out_shape.push_back(n_frames);
}
out->Resize(framework::make_ddim(restored_out_shape));
}
}
};
template <typename DeviceContext, typename T>
class FrameGradKernel : public framework::OpKernel<T> {
public:
/*
Frame gradient kernel accumulate gradient `d_x` from `d_out`. The
main steps as follows:
- Case 1 - d_x dims == 1:
- axis is -1: Call a FrameFunctor to compute directly. Notes that
`is_grad` is set to true to select gradient data functor.
- axis is 0: Transpose `d_out` firstly, and then it falls into
case axis is -1.
- Case 2 - d_x dims == 2:
- axis is -1: Call a FrameFunctor to compute directly.
- axis is 0: Transpose both `d_x` and `d_out` firstly, and then it
falls into case axis is -1. Finally, it restores the
dims of `d_x`.
- Case 3 - d_x dims > 2:
Flatten the `d_x` and `d_out` to 2D and 3D respectively so that it
falls into Case 2. Finally, it restores the dims of `d_x` tensor.
*/
void Compute(const framework::ExecutionContext& ctx) const {
const Tensor* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
Tensor* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
d_x->mutable_data<T>(ctx.GetPlace());
const size_t d_out_rank = d_out->dims().size();
const size_t d_x_rank = d_x->dims().size();
const int frame_length = ctx.Attr<int>("frame_length");
const int hop_length = ctx.Attr<int>("hop_length");
const int axis = ctx.Attr<int>("axis");
const int n_frames =
(axis == 0) ? d_out->dims()[0] : d_out->dims()[d_out_rank - 1];
const int seq_length =
(axis == 0) ? d_x->dims()[0] : d_x->dims()[d_x_rank - 1];
auto& dev_ctx = ctx.device_context<DeviceContext>();
Tensor d_out_(d_out->type());
d_out_ = *d_out;
framework::DDim preserved_dims;
if (d_x_rank > 2) {
// Save dims used to flatten both input and output tensors and restore
// output tensor.
framework::DDim d_x_resized_dims;
framework::DDim d_out_resized_dims;
if (axis == 0) {
preserved_dims = framework::slice_ddim(d_x->dims(), 1, d_x_rank);
d_x_resized_dims = {seq_length, framework::product(preserved_dims)};
d_out_resized_dims = {n_frames, frame_length,
framework::product(preserved_dims)};
} else {
preserved_dims = framework::slice_ddim(d_x->dims(), 0, d_x_rank - 1);
d_x_resized_dims = {framework::product(preserved_dims), seq_length};
d_out_resized_dims = {framework::product(preserved_dims), frame_length,
n_frames};
}
d_x->Resize(d_x_resized_dims);
d_out_.Resize(d_out_resized_dims);
}
Tensor trans_d_x(d_x->type());
Tensor trans_d_out(d_out_.type());
// Transpose input and output in case that axis is 0.
if (axis == 0) {
if (d_x_rank == 1U) {
trans_d_x = *d_x;
std::vector<int> perm_d_out{1, 0};
auto d_out_dims_vec = framework::vectorize(d_out_.dims());
for (int i = 0; i < d_out_.dims().size(); ++i) {
d_out_dims_vec[i] = d_out_.dims()[perm_d_out[i]];
}
trans_d_out.Resize(framework::make_ddim(d_out_dims_vec));
trans_d_out.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_d_out.size(), dev_ctx, d_out_,
&trans_d_out, perm_d_out);
} else {
std::vector<int> perm_d_x{1, 0};
auto d_x_dims_vec = framework::vectorize(d_x->dims());
for (int i = 0; i < d_x->dims().size(); ++i) {
d_x_dims_vec[i] = d_x->dims()[perm_d_x[i]];
}
trans_d_x.Resize(framework::make_ddim(d_x_dims_vec));
trans_d_x.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_d_x.size(), dev_ctx, *d_x,
&trans_d_x, perm_d_x);
std::vector<int> perm_d_out{2, 1, 0};
auto d_out_dims_vec = framework::vectorize(d_out_.dims());
for (int i = 0; i < d_out_.dims().size(); ++i) {
d_out_dims_vec[i] = d_out_.dims()[perm_d_out[i]];
}
trans_d_out.Resize(framework::make_ddim(d_out_dims_vec));
trans_d_out.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_d_out.size(), dev_ctx, d_out_,
&trans_d_out, perm_d_out);
}
} else {
trans_d_x = *d_x;
trans_d_out = d_out_;
}
FrameFunctor<DeviceContext, T>()(dev_ctx, &trans_d_out, &trans_d_x,
seq_length, frame_length, n_frames,
hop_length,
/*is_grad*/ true);
// Transpose output in case axis is 0.
if (axis == 0 && d_x_rank > 1U) {
std::vector<int> perm_d_x{1, 0};
TransCompute<DeviceContext, T>(perm_d_x.size(), dev_ctx, trans_d_x, d_x,
perm_d_x);
}
// Restore output dims when the number of dims is larger than 2.
if (d_x_rank > 2) {
std::vector<int64_t> restored_d_x_shape;
for (int i = 0; i < preserved_dims.size(); i++) {
restored_d_x_shape.push_back(preserved_dims[i]);
}
if (axis == 0) {
// (seq_length, ...)
restored_d_x_shape.insert(restored_d_x_shape.begin(), seq_length);
} else {
// (..., seq_length)
restored_d_x_shape.push_back(seq_length);
}
d_x->Resize(framework::make_ddim(restored_d_x_shape));
}
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct Seq2ColFunctor {
Seq2ColFunctor(const T* seq, T* col, size_t seq_length, size_t frame_length,
size_t n_frames, size_t hop_length)
: seq_(seq),
col_(col),
seq_length_(seq_length),
frame_length_(frame_length),
n_frames_(n_frames),
hop_length_(hop_length) {}
/*
Convert sequences to frames.
1. Dimension infomation:
Sequences Frames
(N, seq_length) -> (N, frame_length, n_frames)
2. Mapping from `i` to `src_idx` and `trg_idx` can be derived from:
a. Notion
- `i` stands for the flattened index of a bunch of frames.
- `src_idx` and `trg_idx` are the 1D indices of seqs and frames
respectivly.
b. Sample idx
```cpp
sample_idx = i / (n_frames_ * frame_length_);
```
c. Maps `i` to `f` and `n`.
```cpp
f = i % (n_frames_ * frame_length_) / n_frames_;
n = i % (n_frames_ * frame_length_) % n_frames_;
```
d. Replace `sample_idx`, `f` and `n` in the following eqations:
```cpp
src_idx = sample_idx * seq_length_ + n * hop_length_ + f;
trg_idx = sample_idx * n_frames_ * frame_length_ + f * n_frames_ + n;
col_[trg_idx] = seq_[src_idx];
```
e. Result can be deduced shown in the function body below.
*/
HOSTDEVICE void operator()(size_t i) const {
size_t src_idx;
size_t trg_idx;
src_idx = i / (n_frames_ * frame_length_) * seq_length_ +
i % (n_frames_ * frame_length_) % n_frames_ * hop_length_ +
i % (n_frames_ * frame_length_) / n_frames_;
trg_idx = i / (n_frames_ * frame_length_) * n_frames_ * frame_length_ +
i % (n_frames_ * frame_length_) / n_frames_ * n_frames_ +
i % (n_frames_ * frame_length_) % n_frames_;
col_[trg_idx] = seq_[src_idx];
}
const T* seq_;
T* col_;
size_t seq_length_;
size_t frame_length_;
size_t n_frames_;
size_t hop_length_;
};
template <typename T>
struct Col2SeqFunctor {
Col2SeqFunctor(const T* col, T* seq, size_t seq_length, size_t frame_length,
size_t n_frames, size_t hop_length)
: col_(col),
seq_(seq),
seq_length_(seq_length),
frame_length_(frame_length),
n_frames_(n_frames),
hop_length_(hop_length) {}
/*
Accumulate output gradient d_out to d_x.
1. Dimension infomation:
d_out d_x
(N, frame_length, n_frames) -> (N, seq_length)
2. Using a sliding window to find source indices from `d_out` according to
`i`:
a. Notion
- `i` stands for the flattened index of `d_x`.
- `seq_i` stands for a relative index of a `d_x` sample.
- `left`: Starting index of a frame window.
- `right`: Ending index of a frame window.
b. Sample idx
```cpp
sample_idx = i / seq_length_;
```
c. Slides a window with length of `frame_length` to find `f` and `n`.
- `n`: The idx of num_frames_, increases in each hop.
- `f`: The idx of frame_lengths_, relative idx from left of a sliding
window.
d. Accumulate all grads from d_out.
```cpp
seq_[i] +=
col_[sample_idx * frame_length_ * n_frames_ + f * n_frames_ + n];
```
*/
HOSTDEVICE void operator()(size_t i) const {
size_t sample_idx = i / seq_length_;
size_t seq_i = i % seq_length_;
// Sliding window
seq_[i] = 0; // Init seq_[i] to 0, and sums up all
// grads from col_ in the while loop.
size_t n = get_start_frame_idx(seq_i);
size_t f;
size_t left = n * hop_length_;
size_t right = left + frame_length_ - 1;
while (left <= seq_i && right < seq_length_) {
f = seq_i - left;
seq_[i] +=
col_[sample_idx * frame_length_ * n_frames_ + f * n_frames_ + n];
// Next frame.
left += hop_length_;
right += hop_length_;
n += 1;
}
}
/*
Calculate minimum value of frame index `n` to satisfy the inequality:
seq_i <= right
==> seq_i <= left + frame_length - 1
==> seq_i <= hop_length_ * n + frame_length_ - 1
*/
HOSTDEVICE size_t get_start_frame_idx(size_t seq_i) const {
int64_t tmp = seq_i + 1 - frame_length_;
if (tmp > 0) {
size_t n = tmp / hop_length_;
if (tmp % hop_length_ == 0) {
return n;
} else {
return n + 1;
}
} else {
return 0;
}
}
const T* col_;
T* seq_;
size_t seq_length_;
size_t frame_length_;
size_t n_frames_;
size_t hop_length_;
};
} // namespace math
} // namespace operators
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/overlap_add_op.h"
namespace paddle {
namespace operators {
class OverlapAddOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "overlap_add");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "overlap_add");
const int hop_length = ctx->Attrs().Get<int>("hop_length");
const int axis = ctx->Attrs().Get<int>("axis");
const auto x_dims = ctx->GetInputDim("X");
const int x_rank = x_dims.size();
PADDLE_ENFORCE_GE(
x_rank, 2,
platform::errors::InvalidArgument(
"Input(X) of OverlapAddOp should be a tensor which contains "
"at least 2 dimensions, but got rank %s.",
x_rank));
PADDLE_ENFORCE_GT(
hop_length, 0,
platform::errors::InvalidArgument(
"Attribute(hop_length) of OverlapAddOp should be greater "
"than 0, but got %s.",
hop_length));
PADDLE_ENFORCE_EQ(
(axis == 0 || axis == -1), true,
platform::errors::InvalidArgument(
"Attribute(axis) of OverlapAddOp should 0 or -1, but got %s.",
axis));
std::vector<int64_t> output_shape;
int n_frames;
int frame_length;
int start_axis;
int end_axis;
if (axis == 0) {
n_frames = x_dims[0];
frame_length = x_dims[1];
start_axis = 2;
end_axis = x_rank - 1;
} else {
n_frames = x_dims[x_rank - 1];
frame_length = x_dims[x_rank - 2];
start_axis = 0;
end_axis = x_rank - 3;
}
PADDLE_ENFORCE_LE(
hop_length, frame_length,
platform::errors::InvalidArgument(
"Attribute(hop_length) of OverlapAddOp should be less or equal "
"than frame_length, but got hop_length(%s) > frame_length(%s).",
hop_length, frame_length));
const int seq_length = (n_frames - 1) * hop_length + frame_length;
// It won't go into for loop when x_rank == 2U.
for (int i = start_axis; i <= end_axis; i++) {
output_shape.push_back(x_dims[i]);
}
if (axis == 0) {
// (seq_length, ...)
output_shape.insert(output_shape.begin(), seq_length);
} else {
// (..., seq_length)
output_shape.push_back(seq_length);
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(in_dtype, ctx.GetPlace());
}
};
class OverlapAddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of overlap_add op.");
AddOutput("Out", "(Tensor), The output tensor of overlap_add op.");
AddAttr<int>("hop_length",
"Number of steps to advance between adjacent frames and "
"`0 < hop_length <= frame_length`.");
AddAttr<int>("axis",
"Specify the axis to operate on the input Tensors. Its value "
"should be 0(the first dimension) or -1(the last dimension).")
.SetDefault(-1);
AddComment(R"DOC(
Reconstructs a tensor consisted of overlap added sequences from input frames.
)DOC");
}
};
class OverlapAddOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "overlap_add_grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "overlap_add_grad");
const auto x_dims = ctx->GetInputDim("X");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(in_dtype, ctx.GetPlace());
}
};
template <typename T>
class OverlapAddOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("overlap_add_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(overlap_add, ops::OverlapAddOp, ops::OverlapAddOpMaker,
ops::OverlapAddOpGradMaker<paddle::framework::OpDesc>,
ops::OverlapAddOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(overlap_add_grad, ops::OverlapAddOpGrad);
REGISTER_OP_CPU_KERNEL(
overlap_add, ops::OverlapAddKernel<paddle::platform::CPUDeviceContext, int>,
ops::OverlapAddKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::OverlapAddKernel<paddle::platform::CPUDeviceContext, float>,
ops::OverlapAddKernel<paddle::platform::CPUDeviceContext, double>,
ops::OverlapAddKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::OverlapAddKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
overlap_add_grad,
ops::OverlapAddGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::OverlapAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::OverlapAddGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::OverlapAddGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::OverlapAddGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::OverlapAddGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/overlap_add_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
overlap_add,
ops::OverlapAddKernel<paddle::platform::CUDADeviceContext, int>,
ops::OverlapAddKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::OverlapAddKernel<paddle::platform::CUDADeviceContext, float>,
ops::OverlapAddKernel<paddle::platform::CUDADeviceContext, double>,
ops::OverlapAddKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::OverlapAddKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::OverlapAddKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
overlap_add_grad,
ops::OverlapAddGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::OverlapAddGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::OverlapAddGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::OverlapAddGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::OverlapAddGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::OverlapAddGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::OverlapAddGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/seq2col.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
struct OverlapAddFunctor {
void operator()(const DeviceContext& dev_ctx, const Tensor* input,
Tensor* output, size_t seq_length, size_t frame_length,
size_t n_frames, size_t hop_length,
bool is_grad = false) const {
auto numel = output->numel();
const auto* input_data = input->data<T>();
auto* output_data = output->data<T>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
if (!is_grad) {
math::Col2SeqFunctor<T> functor(input_data, output_data, seq_length,
frame_length, n_frames, hop_length);
for_range(functor);
} else {
math::Seq2ColFunctor<T> functor(input_data, output_data, seq_length,
frame_length, n_frames, hop_length);
for_range(functor);
}
}
};
template <typename DeviceContext, typename T>
class OverlapAddKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
const Tensor* x = ctx.Input<Tensor>("X");
Tensor* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
const size_t x_rank = x->dims().size();
const size_t out_rank = out->dims().size();
const int hop_length = ctx.Attr<int>("hop_length");
const int axis = ctx.Attr<int>("axis");
const int n_frames = (axis == 0) ? x->dims()[0] : x->dims()[x_rank - 1];
const int frame_length = (axis == 0) ? x->dims()[1] : x->dims()[x_rank - 2];
const int seq_length =
(axis == 0) ? out->dims()[0] : out->dims()[out_rank - 1];
auto& dev_ctx = ctx.device_context<DeviceContext>();
Tensor x_(x->type());
x_ = *x;
framework::DDim preserved_dims;
if (out_rank > 2) {
// Save dims used to flatten both input and output tensors and restore
// output tensor.
framework::DDim x_resized_dims;
framework::DDim out_resized_dims;
if (axis == 0) {
preserved_dims = framework::slice_ddim(out->dims(), 1, out_rank);
x_resized_dims = {n_frames, frame_length,
framework::product(preserved_dims)};
out_resized_dims = {seq_length, framework::product(preserved_dims)};
} else {
preserved_dims = framework::slice_ddim(out->dims(), 0, out_rank - 1);
x_resized_dims = {framework::product(preserved_dims), frame_length,
n_frames};
out_resized_dims = {framework::product(preserved_dims), seq_length};
}
x_.Resize(x_resized_dims);
out->Resize(out_resized_dims);
}
Tensor trans_x(x_.type());
Tensor trans_out(out->type());
// Transpose input and output in case that axis is 0.
if (axis == 0) {
if (out_rank == 1U) {
trans_out = *out;
std::vector<int> perm_x{1, 0};
auto x_dims_vec = framework::vectorize(x_.dims());
for (int i = 0; i < x_.dims().size(); ++i) {
x_dims_vec[i] = x_.dims()[perm_x[i]];
}
trans_x.Resize(framework::make_ddim(x_dims_vec));
trans_x.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_x.size(), dev_ctx, x_, &trans_x,
perm_x);
} else {
std::vector<int> perm_out{1, 0};
auto out_dims_vec = framework::vectorize(out->dims());
for (int i = 0; i < out->dims().size(); ++i) {
out_dims_vec[i] = out->dims()[perm_out[i]];
}
trans_out.Resize(framework::make_ddim(out_dims_vec));
trans_out.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_out.size(), dev_ctx, *out,
&trans_out, perm_out);
std::vector<int> perm_x{2, 1, 0};
auto x_dims_vec = framework::vectorize(x_.dims());
for (int i = 0; i < x_.dims().size(); ++i) {
x_dims_vec[i] = x_.dims()[perm_x[i]];
}
trans_x.Resize(framework::make_ddim(x_dims_vec));
trans_x.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_x.size(), dev_ctx, x_, &trans_x,
perm_x);
}
} else {
trans_x = x_;
trans_out = *out;
}
OverlapAddFunctor<DeviceContext, T>()(dev_ctx, &trans_x, &trans_out,
seq_length, frame_length, n_frames,
hop_length, /*is_grad*/ false);
// Transpose output in case axis is 0.
if (axis == 0 && out_rank > 1U) {
std::vector<int> perm_out{1, 0};
TransCompute<DeviceContext, T>(perm_out.size(), dev_ctx, trans_out, out,
perm_out);
}
// Restore output dims when the number of dims is larger than 2.
if (out_rank > 2) {
std::vector<int64_t> restored_out_shape;
for (int i = 0; i < preserved_dims.size(); i++) {
restored_out_shape.push_back(preserved_dims[i]);
}
if (axis == 0) {
// (seq_length, ...)
restored_out_shape.insert(restored_out_shape.begin(), seq_length);
} else {
// (..., seq_length)
restored_out_shape.push_back(seq_length);
}
out->Resize(framework::make_ddim(restored_out_shape));
}
}
};
template <typename DeviceContext, typename T>
class OverlapAddGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
Tensor* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
d_x->mutable_data<T>(ctx.GetPlace());
const size_t d_out_rank = d_out->dims().size();
const size_t d_x_rank = d_x->dims().size();
const int hop_length = ctx.Attr<int>("hop_length");
const int axis = ctx.Attr<int>("axis");
const int n_frames =
(axis == 0) ? d_x->dims()[0] : d_x->dims()[d_x_rank - 1];
const int frame_length =
(axis == 0) ? d_x->dims()[1] : d_x->dims()[d_x_rank - 2];
const int seq_length =
(axis == 0) ? d_out->dims()[0] : d_out->dims()[d_out_rank - 1];
auto& dev_ctx = ctx.device_context<DeviceContext>();
// When the number of input dims is larger than 2, it needs to copy
// from x to resize input into 2d and output into 3d. Morevoer, output
// dims will be restored at the last step.
Tensor d_out_(d_out->type());
d_out_ = *d_out;
framework::DDim preserved_dims;
if (d_out_rank > 2) {
// Save dims used to flatten both input and output tensors and restore
// output tensor.
framework::DDim d_x_resized_dims;
framework::DDim d_out_resized_dims;
if (axis == 0) {
preserved_dims = framework::slice_ddim(d_out_.dims(), 1, d_out_rank);
d_x_resized_dims = {n_frames, frame_length,
framework::product(preserved_dims)};
d_out_resized_dims = {seq_length, framework::product(preserved_dims)};
} else {
preserved_dims =
framework::slice_ddim(d_out_.dims(), 0, d_out_rank - 1);
d_x_resized_dims = {framework::product(preserved_dims), frame_length,
n_frames};
d_out_resized_dims = {framework::product(preserved_dims), seq_length};
}
d_x->Resize(d_x_resized_dims);
d_out_.Resize(d_out_resized_dims);
}
Tensor trans_d_x(d_x->type());
Tensor trans_d_out(d_out_.type());
// Transpose input and output in case that axis is 0.
if (axis == 0) {
if (d_out_rank == 1U) {
trans_d_out = d_out_;
std::vector<int> perm_d_x{1, 0};
auto d_x_dims_vec = framework::vectorize(d_x->dims());
for (int i = 0; i < d_x->dims().size(); ++i) {
d_x_dims_vec[i] = d_x->dims()[perm_d_x[i]];
}
trans_d_x.Resize(framework::make_ddim(d_x_dims_vec));
trans_d_x.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_d_x.size(), dev_ctx, *d_x,
&trans_d_x, perm_d_x);
} else {
std::vector<int> perm_d_out{1, 0};
auto d_out_dims_vec = framework::vectorize(d_out_.dims());
for (int i = 0; i < d_out_.dims().size(); ++i) {
d_out_dims_vec[i] = d_out_.dims()[perm_d_out[i]];
}
trans_d_out.Resize(framework::make_ddim(d_out_dims_vec));
trans_d_out.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_d_out.size(), dev_ctx, d_out_,
&trans_d_out, perm_d_out);
std::vector<int> perm_d_x{2, 1, 0};
auto d_x_dims_vec = framework::vectorize(d_x->dims());
for (int i = 0; i < d_x->dims().size(); ++i) {
d_x_dims_vec[i] = d_x->dims()[perm_d_x[i]];
}
trans_d_x.Resize(framework::make_ddim(d_x_dims_vec));
trans_d_x.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(perm_d_x.size(), dev_ctx, *d_x,
&trans_d_x, perm_d_x);
}
} else {
trans_d_x = *d_x;
trans_d_out = d_out_;
}
OverlapAddFunctor<DeviceContext, T>()(dev_ctx, &trans_d_out, &trans_d_x,
seq_length, frame_length, n_frames,
hop_length,
/*is_grad*/ true);
// Transpose output in case axis is 0.
if (axis == 0) {
if (d_out_rank == 1U) {
std::vector<int> perm_d_x{1, 0};
TransCompute<DeviceContext, T>(perm_d_x.size(), dev_ctx, trans_d_x, d_x,
perm_d_x);
} else {
std::vector<int> perm_d_x{2, 1, 0};
TransCompute<DeviceContext, T>(perm_d_x.size(), dev_ctx, trans_d_x, d_x,
perm_d_x);
}
}
// Restore output dims when the number of dims is larger than 2.
if (d_out_rank > 2) {
std::vector<int64_t> restored_d_x_shape;
for (int i = 0; i < preserved_dims.size(); i++) {
restored_d_x_shape.push_back(preserved_dims[i]);
}
if (axis == 0) {
// (n_frames, frame_length, ...)
restored_d_x_shape.insert(restored_d_x_shape.begin(), frame_length);
restored_d_x_shape.insert(restored_d_x_shape.begin(), n_frames);
} else {
// (..., frame_length, n_frames)
restored_d_x_shape.push_back(frame_length);
restored_d_x_shape.push_back(n_frames);
}
d_x->Resize(framework::make_ddim(restored_d_x_shape));
}
}
};
} // namespace operators
} // namespace paddle
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/pad_op.h" #include "paddle/fluid/operators/pad_op.h"
#include <memory> #include <memory>
#include "paddle/fluid/platform/complex.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -170,10 +171,18 @@ REGISTER_OP_CPU_KERNEL( ...@@ -170,10 +171,18 @@ REGISTER_OP_CPU_KERNEL(
pad, ops::PadKernel<paddle::platform::CPUDeviceContext, float>, pad, ops::PadKernel<paddle::platform::CPUDeviceContext, float>,
ops::PadKernel<paddle::platform::CPUDeviceContext, double>, ops::PadKernel<paddle::platform::CPUDeviceContext, double>,
ops::PadKernel<paddle::platform::CPUDeviceContext, int>, ops::PadKernel<paddle::platform::CPUDeviceContext, int>,
ops::PadKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::PadKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::PadKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::PadKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
pad_grad, ops::PadGradKernel<paddle::platform::CPUDeviceContext, float>, pad_grad, ops::PadGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::PadGradKernel<paddle::platform::CPUDeviceContext, double>); ops::PadGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::PadGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::PadGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
pad, ops::PadKernel<paddle::platform::CUDADeviceContext, double>, pad, ops::PadKernel<paddle::platform::CUDADeviceContext, double>,
...@@ -181,9 +190,17 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -181,9 +190,17 @@ REGISTER_OP_CUDA_KERNEL(
ops::PadKernel<paddle::platform::CUDADeviceContext, int>, ops::PadKernel<paddle::platform::CUDADeviceContext, int>,
ops::PadKernel<paddle::platform::CUDADeviceContext, int64_t>, ops::PadKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::PadKernel<paddle::platform::CUDADeviceContext, ops::PadKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>); paddle::platform::float16>,
ops::PadKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::PadKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
pad_grad, ops::PadGradKernel<paddle::platform::CUDADeviceContext, double>, pad_grad, ops::PadGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::PadGradKernel<paddle::platform::CUDADeviceContext, float>, ops::PadGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::PadGradKernel<paddle::platform::CUDADeviceContext, ops::PadGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>); paddle::platform::float16>,
ops::PadGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::PadGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/complex.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -148,12 +149,20 @@ REGISTER_OP_CPU_KERNEL( ...@@ -148,12 +149,20 @@ REGISTER_OP_CPU_KERNEL(
roll, ops::RollKernel<paddle::platform::CPUDeviceContext, float>, roll, ops::RollKernel<paddle::platform::CPUDeviceContext, float>,
ops::RollKernel<paddle::platform::CPUDeviceContext, double>, ops::RollKernel<paddle::platform::CPUDeviceContext, double>,
ops::RollKernel<paddle::platform::CPUDeviceContext, int>, ops::RollKernel<paddle::platform::CPUDeviceContext, int>,
ops::RollKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::RollKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::RollKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::RollKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
roll_grad, ops::RollGradKernel<paddle::platform::CPUDeviceContext, float>, roll_grad, ops::RollGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::RollGradKernel<paddle::platform::CPUDeviceContext, double>, ops::RollGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::RollGradKernel<paddle::platform::CPUDeviceContext, int>, ops::RollGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::RollGradKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::RollGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::RollGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::RollGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_VERSION(roll) REGISTER_OP_VERSION(roll)
.AddCheckpoint( .AddCheckpoint(
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "paddle/fluid/framework/array.h" #include "paddle/fluid/framework/array.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/roll_op.h" #include "paddle/fluid/operators/roll_op.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
...@@ -188,9 +189,17 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -188,9 +189,17 @@ REGISTER_OP_CUDA_KERNEL(
roll, ops::RollKernel<paddle::platform::CUDADeviceContext, float>, roll, ops::RollKernel<paddle::platform::CUDADeviceContext, float>,
ops::RollKernel<paddle::platform::CUDADeviceContext, double>, ops::RollKernel<paddle::platform::CUDADeviceContext, double>,
ops::RollKernel<paddle::platform::CUDADeviceContext, int>, ops::RollKernel<paddle::platform::CUDADeviceContext, int>,
ops::RollKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::RollKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::RollKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::RollKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
roll_grad, ops::RollGradKernel<paddle::platform::CUDADeviceContext, float>, roll_grad, ops::RollGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::RollGradKernel<paddle::platform::CUDADeviceContext, double>, ops::RollGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::RollGradKernel<paddle::platform::CUDADeviceContext, int>, ops::RollGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::RollGradKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::RollGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::RollGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::RollGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/shape_op.h" #include "paddle/fluid/operators/shape_op.h"
#include <string> #include <string>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/complex.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -64,6 +65,7 @@ Return the shape of the input. ...@@ -64,6 +65,7 @@ Return the shape of the input.
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR( REGISTER_OPERATOR(
shape, ops::ShapeOp, ops::ShapeOpMaker, shape, ops::ShapeOp, ops::ShapeOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
...@@ -71,4 +73,6 @@ REGISTER_OPERATOR( ...@@ -71,4 +73,6 @@ REGISTER_OPERATOR(
REGISTER_OP_CPU_KERNEL(shape, ops::ShapeKernel<bool>, ops::ShapeKernel<int>, REGISTER_OP_CPU_KERNEL(shape, ops::ShapeKernel<bool>, ops::ShapeKernel<int>,
ops::ShapeKernel<int8_t>, ops::ShapeKernel<uint8_t>, ops::ShapeKernel<int8_t>, ops::ShapeKernel<uint8_t>,
ops::ShapeKernel<int64_t>, ops::ShapeKernel<float>, ops::ShapeKernel<int64_t>, ops::ShapeKernel<float>,
ops::ShapeKernel<double>); ops::ShapeKernel<double>,
ops::ShapeKernel<plat::complex<float>>,
ops::ShapeKernel<plat::complex<double>>);
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/shape_op.h" #include "paddle/fluid/operators/shape_op.h"
#include "paddle/fluid/platform/complex.h"
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
shape, paddle::operators::ShapeKernel<bool>, shape, paddle::operators::ShapeKernel<bool>,
...@@ -21,4 +22,6 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -21,4 +22,6 @@ REGISTER_OP_CUDA_KERNEL(
paddle::operators::ShapeKernel<int64_t>, paddle::operators::ShapeKernel<int64_t>,
paddle::operators::ShapeKernel<float>, paddle::operators::ShapeKernel<float>,
paddle::operators::ShapeKernel<double>, paddle::operators::ShapeKernel<double>,
paddle::operators::ShapeKernel<paddle::platform::float16>); paddle::operators::ShapeKernel<paddle::platform::float16>,
paddle::operators::ShapeKernel<paddle::platform::complex<float>>,
paddle::operators::ShapeKernel<paddle::platform::complex<double>>);
此差异已折叠。
此差异已折叠。
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#define NOMINMAX // to use std::min std::max correctly on windows
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/conj_op.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/math/padding.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/for_range.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "thrust/device_vector.h"
#endif
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
enum class FFTNormMode : int64_t {
none, // No normalization
by_sqrt_n, // Divide by sqrt(signal_size)
by_n, // Divide by signal_size
};
FFTNormMode get_norm_from_string(const std::string& norm, bool forward);
// Enum representing the FFT type
enum class FFTTransformType : int64_t {
C2C = 0, // Complex-to-complex
R2C, // Real-to-complex
C2R, // Complex-to-real
};
// Create transform type enum from bools representing if input and output are
// complex
inline FFTTransformType GetFFTTransformType(
framework::proto::VarType::Type input_dtype,
framework::proto::VarType::Type output_dtype) {
auto complex_input = framework::IsComplexType(input_dtype);
auto complex_output = framework::IsComplexType(output_dtype);
if (complex_input && complex_output) {
return FFTTransformType::C2C;
} else if (complex_input && !complex_output) {
return FFTTransformType::C2R;
} else if (!complex_input && complex_output) {
return FFTTransformType::R2C;
}
PADDLE_THROW(
platform::errors::InvalidArgument("Real to real FFTs are not supported"));
}
// Returns true if the transform type has complex input
inline bool has_complex_input(FFTTransformType type) {
switch (type) {
case FFTTransformType::C2C:
case FFTTransformType::C2R:
return true;
case FFTTransformType::R2C:
return false;
}
PADDLE_THROW(platform::errors::InvalidArgument("Unknown FFTTransformType"));
}
// Returns true if the transform type has complex output
inline bool has_complex_output(FFTTransformType type) {
switch (type) {
case FFTTransformType::C2C:
case FFTTransformType::R2C:
return true;
case FFTTransformType::C2R:
return false;
}
PADDLE_THROW(platform::errors::InvalidArgument("Unknown FFTTransformType"));
}
template <typename T>
struct FFTFillConjGradFunctor {
T* input_;
const size_t axis_;
const int64_t* strides_;
const size_t double_length_;
FFTFillConjGradFunctor(T* input, size_t axis, const int64_t* strides,
size_t double_length)
: input_(input),
axis_(axis),
strides_(strides),
double_length_(double_length) {}
HOSTDEVICE void operator()(size_t index) {
size_t offtset = index; // back
size_t index_i;
for (size_t i = 0; i <= axis_; i++) {
index_i = offtset / strides_[i];
offtset %= strides_[i];
}
if ((0 < index_i) && (index_i < double_length_ + 1)) {
input_[index] *= static_cast<T>(2);
}
}
};
template <typename DeviceContext, typename Ti, typename To>
struct FFTC2CFunctor {
void operator()(const DeviceContext& ctx, const Tensor* X, Tensor* out,
const std::vector<int64_t>& axes, FFTNormMode normalization,
bool forward);
};
template <typename DeviceContext, typename Ti, typename To>
struct FFTR2CFunctor {
void operator()(const DeviceContext& ctx, const Tensor* X, Tensor* out,
const std::vector<int64_t>& axes, FFTNormMode normalization,
bool forward);
};
template <typename DeviceContext, typename Ti, typename To>
struct FFTC2RFunctor {
void operator()(const DeviceContext& ctx, const Tensor* X, Tensor* out,
const std::vector<int64_t>& axes, FFTNormMode normalization,
bool forward);
};
// Giving a linear destination index and strides of tensor, get_idx return the
// corresponding linear position of source tensor.
// The linear index is the position of flatten tensor.
// Giving a linear destination index and strides of tensor, get_idx return the
// corresponding linear position of source tensor.
// The linear index is the position of flatten tensor.
HOSTDEVICE inline int64_t get_src_idx(const int64_t dst_idx,
const int64_t* dst_strides,
const int64_t* dst_shape,
const int64_t* src_strides,
const bool* is_fft_axis, const bool conj,
const int64_t rank) {
int64_t src_idx = 0;
int64_t quotient = dst_idx;
int64_t remainder = 0;
for (int64_t i = 0; i < rank; i++) {
remainder = quotient % dst_strides[i];
quotient = quotient / dst_strides[i];
if (conj && is_fft_axis[i]) {
src_idx += ((dst_shape[i] - quotient) % dst_shape[i]) * src_strides[i];
} else {
src_idx += src_strides[i] * quotient;
}
quotient = remainder;
}
return src_idx;
}
HOSTDEVICE inline bool is_conj_part(const int64_t dst_idx,
const int64_t* dst_strides,
const int64_t last_axis,
const int64_t last_axis_size) {
int64_t quotient = dst_idx;
int64_t remainder = 0;
for (int64_t i = 0; i < last_axis + 1; i++) {
remainder = quotient % dst_strides[i];
quotient = quotient / dst_strides[i];
if ((i == last_axis) && (quotient > last_axis_size - 1)) {
return true;
}
quotient = remainder;
}
return false;
}
// FFTFillConjFunctor fill the destination tensor with source tensor and
// conjugate symmetry element of source tensor .
// Use framework::ForRange to iterate destination element with
// supporting different device
template <typename C>
struct FFTFillConjFunctor {
FFTFillConjFunctor(const C* src_data, C* dst_data, const int64_t* src_strides,
const int64_t* dst_strides, const int64_t* dst_shape,
const bool* is_fft_axis, const int64_t last_axis,
const int64_t last_axis_size, const int64_t rank)
: src_data_(src_data),
dst_data_(dst_data),
src_strides_(src_strides),
dst_strides_(dst_strides),
dst_shape_(dst_shape),
is_fft_axis_(is_fft_axis),
last_axis_(last_axis),
last_axis_size_(last_axis_size),
rank_(rank) {}
HOSTDEVICE void operator()(int64_t dst_idx) {
if (is_conj_part(dst_idx, dst_strides_, last_axis_, last_axis_size_)) {
const auto conj_idx =
get_src_idx(dst_idx, dst_strides_, dst_shape_, src_strides_,
is_fft_axis_, true, rank_);
auto src_value = src_data_[conj_idx];
auto conj_value = C(src_value.real, -src_value.imag);
dst_data_[dst_idx] = conj_value;
} else {
const auto copy_idx =
get_src_idx(dst_idx, dst_strides_, dst_shape_, src_strides_,
is_fft_axis_, false, rank_);
dst_data_[dst_idx] = src_data_[copy_idx];
}
}
const C* src_data_;
C* dst_data_;
const int64_t* src_strides_;
const int64_t* dst_strides_;
const int64_t* dst_shape_;
const bool* is_fft_axis_;
const int64_t last_axis_;
const int64_t last_axis_size_;
const int64_t rank_;
};
template <typename DeviceContext, typename C>
void fill_conj(const DeviceContext& ctx, const Tensor* src, Tensor* dst,
const std::vector<int64_t>& axes) {
std::vector<int64_t> src_strides_v =
framework::vectorize<int64_t>(framework::stride(src->dims()));
std::vector<int64_t> dst_strides_v =
framework::vectorize<int64_t>(framework::stride(dst->dims()));
std::vector<int64_t> dst_shape_v = framework::vectorize<int64_t>(dst->dims());
const auto src_data = src->data<C>();
auto dst_data = dst->data<C>();
const auto last_axis = axes.back();
const auto last_axis_size = dst->dims().at(last_axis) / 2 + 1;
const int64_t rank = dst->dims().size();
auto _is_fft_axis = std::make_unique<bool[]>(rank);
for (const auto i : axes) {
_is_fft_axis[i] = true;
}
#if defined(__NVCC__) || defined(__HIPCC__)
const thrust::device_vector<int64_t> src_strides_g(src_strides_v);
const auto src_strides = thrust::raw_pointer_cast(src_strides_g.data());
const thrust::device_vector<int64_t> dst_strides_g(dst_strides_v);
const auto dst_strides = thrust::raw_pointer_cast(dst_strides_g.data());
const thrust::device_vector<int64_t> dst_shape_g(dst_shape_v);
const auto dst_shape = thrust::raw_pointer_cast(dst_shape_g.data());
const thrust::device_vector<bool> is_fft_axis_g(_is_fft_axis.get(),
_is_fft_axis.get() + rank);
const auto p_is_fft_axis = thrust::raw_pointer_cast(is_fft_axis_g.data());
#else
const auto src_strides = src_strides_v.data();
const auto dst_strides = dst_strides_v.data();
const auto dst_shape = dst_shape_v.data();
const auto p_is_fft_axis = _is_fft_axis.get();
#endif
platform::ForRange<DeviceContext> for_range(ctx, dst->numel());
FFTFillConjFunctor<C> fill_conj_functor(src_data, dst_data, src_strides,
dst_strides, dst_shape, p_is_fft_axis,
last_axis, last_axis_size, rank);
for_range(fill_conj_functor);
}
template <typename DeviceContext, typename T>
class FFTC2CKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using C = paddle::platform::complex<T>;
auto& dev_ctx = ctx.device_context<DeviceContext>();
auto axes = ctx.Attr<std::vector<int64_t>>("axes");
const std::string& norm_str = ctx.Attr<std::string>("normalization");
const bool forward = ctx.Attr<bool>("forward");
const auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Output<Tensor>("Out");
y->mutable_data<C>(ctx.GetPlace());
auto normalization = get_norm_from_string(norm_str, forward);
FFTC2CFunctor<DeviceContext, C, C> fft_c2c_func;
fft_c2c_func(dev_ctx, x, y, axes, normalization, forward);
}
};
template <typename DeviceContext, typename T>
class FFTC2CGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using C = paddle::platform::complex<T>;
auto& dev_ctx = ctx.device_context<DeviceContext>();
auto axes = ctx.Attr<std::vector<int64_t>>("axes");
const std::string& norm_str = ctx.Attr<std::string>("normalization");
const bool forward = ctx.Attr<bool>("forward");
const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
dx->mutable_data<C>(ctx.GetPlace());
auto normalization = get_norm_from_string(norm_str, forward);
FFTC2CFunctor<DeviceContext, C, C> fft_c2c_func;
fft_c2c_func(dev_ctx, dy, dx, axes, normalization, !forward);
}
};
template <typename DeviceContext, typename T>
class FFTR2CKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using C = paddle::platform::complex<T>;
auto& dev_ctx = ctx.device_context<DeviceContext>();
auto axes = ctx.Attr<std::vector<int64_t>>("axes");
const std::string& norm_str = ctx.Attr<std::string>("normalization");
const bool forward = ctx.Attr<bool>("forward");
const bool onesided = ctx.Attr<bool>("onesided");
const auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Output<Tensor>("Out");
y->mutable_data<C>(ctx.GetPlace());
auto normalization = get_norm_from_string(norm_str, forward);
FFTR2CFunctor<DeviceContext, T, C> fft_r2c_func;
if (onesided) {
fft_r2c_func(dev_ctx, x, y, axes, normalization, forward);
} else {
framework::DDim onesided_dims(y->dims());
const int64_t onesided_last_axis_size = y->dims().at(axes.back()) / 2 + 1;
onesided_dims.at(axes.back()) = onesided_last_axis_size;
framework::Tensor onesided_out;
onesided_out.mutable_data<C>(onesided_dims, ctx.GetPlace());
fft_r2c_func(dev_ctx, x, &onesided_out, axes, normalization, forward);
fill_conj<DeviceContext, C>(dev_ctx, &onesided_out, y, axes);
}
}
};
template <typename DeviceContext, typename T>
class FFTR2CGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using C = paddle::platform::complex<T>;
auto& dev_ctx = ctx.device_context<DeviceContext>();
const auto axes = ctx.Attr<std::vector<int64_t>>("axes");
const std::string& norm_str = ctx.Attr<std::string>("normalization");
const bool forward = ctx.Attr<bool>("forward");
const bool onesided = ctx.Attr<bool>("onesided");
const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(ctx.GetPlace());
framework::Tensor complex_dx;
complex_dx.mutable_data<C>(dx->dims(), ctx.GetPlace());
auto normalization = get_norm_from_string(norm_str, forward);
FFTC2CFunctor<DeviceContext, C, C> fft_c2c_func;
if (!onesided) {
fft_c2c_func(dev_ctx, dy, &complex_dx, axes, normalization, !forward);
} else {
framework::Tensor full_dy;
full_dy.mutable_data<C>(dx->dims(), ctx.GetPlace());
auto zero_length = static_cast<int>(full_dy.dims().at(axes.back()) -
dy->dims().at(axes.back()));
auto rank = dy->dims().size();
std::vector<int> pads(rank * 2, 0);
pads[axes.back() * 2 + 1] = zero_length;
paddle::operators::math::PaddingFunctor<DeviceContext, C>(
rank, ctx, pads, static_cast<C>(0), *dy, &full_dy);
fft_c2c_func(dev_ctx, &full_dy, &complex_dx, axes, normalization,
!forward);
}
framework::TransComplexToReal(dx->type(), complex_dx.type(), complex_dx,
dx);
}
};
template <typename DeviceContext, typename T>
class FFTC2RKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using C = paddle::platform::complex<T>;
auto& dev_ctx = ctx.device_context<DeviceContext>();
auto axes = ctx.Attr<std::vector<int64_t>>("axes");
const std::string& norm_str = ctx.Attr<std::string>("normalization");
const bool forward = ctx.Attr<bool>("forward");
const auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Output<Tensor>("Out");
y->mutable_data<T>(ctx.GetPlace());
auto normalization = get_norm_from_string(norm_str, forward);
FFTC2RFunctor<DeviceContext, C, T> fft_c2r_func;
fft_c2r_func(dev_ctx, x, y, axes, normalization, forward);
}
};
template <typename DeviceContext, typename T>
class FFTC2RGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using C = paddle::platform::complex<T>;
auto& dev_ctx = ctx.device_context<DeviceContext>();
auto axes = ctx.Attr<std::vector<int64_t>>("axes");
const std::string& norm_str = ctx.Attr<std::string>("normalization");
const bool forward = ctx.Attr<bool>("forward");
const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
C* pdx = dx->mutable_data<C>(ctx.GetPlace());
auto normalization = get_norm_from_string(norm_str, forward);
FFTR2CFunctor<DeviceContext, T, C> fft_r2c_func;
fft_r2c_func(dev_ctx, dy, dx, axes, normalization, !forward);
const int64_t double_length =
dy->dims()[axes.back()] - dx->dims()[axes.back()];
const framework::DDim strides = framework::stride(dx->dims());
#if defined(__NVCC__) || defined(__HIPCC__)
const thrust::device_vector<int64_t> strides_g(
framework::vectorize(strides));
const int64_t* pstrides = thrust::raw_pointer_cast(strides_g.data());
#else
const int64_t* pstrides = strides.Get();
#endif
FFTFillConjGradFunctor<C> func(pdx, axes.back(), pstrides, double_length);
size_t limit = dx->numel();
platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
for_range(func);
}
};
} // namespace operators
} // namespace paddle
...@@ -389,7 +389,11 @@ REGISTER_OP_CPU_KERNEL( ...@@ -389,7 +389,11 @@ REGISTER_OP_CPU_KERNEL(
ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int>, ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int>,
ops::SqueezeKernel<paddle::platform::CPUDeviceContext, uint8_t>, ops::SqueezeKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>, ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::SqueezeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::SqueezeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
squeeze_grad, squeeze_grad,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, float>, ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, float>,
...@@ -398,7 +402,12 @@ REGISTER_OP_CPU_KERNEL( ...@@ -398,7 +402,12 @@ REGISTER_OP_CPU_KERNEL(
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int>, ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, uint8_t>, ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>, ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
squeeze2, ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, float>, squeeze2, ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, float>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, double>, ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, double>,
...@@ -406,7 +415,12 @@ REGISTER_OP_CPU_KERNEL( ...@@ -406,7 +415,12 @@ REGISTER_OP_CPU_KERNEL(
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int>, ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, uint8_t>, ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int8_t>, ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int64_t>); ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
squeeze2_grad, squeeze2_grad,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, float>, ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, float>,
...@@ -415,4 +429,8 @@ REGISTER_OP_CPU_KERNEL( ...@@ -415,4 +429,8 @@ REGISTER_OP_CPU_KERNEL(
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int>, ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, uint8_t>, ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>, ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
...@@ -25,7 +25,11 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -25,7 +25,11 @@ REGISTER_OP_CUDA_KERNEL(
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int>, ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>, ops::SqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>, ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
squeeze_grad, squeeze_grad,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, float>, ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, float>,
...@@ -35,7 +39,11 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -35,7 +39,11 @@ REGISTER_OP_CUDA_KERNEL(
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int>, ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, uint8_t>, ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>, ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
squeeze2, ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, float>, squeeze2, ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, float>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, double>, ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, double>,
...@@ -44,7 +52,11 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -44,7 +52,11 @@ REGISTER_OP_CUDA_KERNEL(
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int>, ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int8_t>, ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, uint8_t>, ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int64_t>); ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
squeeze2_grad, squeeze2_grad,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, float>, ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, float>,
...@@ -54,4 +66,8 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -54,4 +66,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int>, ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>, ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, uint8_t>, ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
...@@ -362,7 +362,11 @@ REGISTER_OP_CPU_KERNEL( ...@@ -362,7 +362,11 @@ REGISTER_OP_CPU_KERNEL(
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, uint8_t>, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
unsqueeze_grad, unsqueeze_grad,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, float>, ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, float>,
...@@ -371,7 +375,11 @@ REGISTER_OP_CPU_KERNEL( ...@@ -371,7 +375,11 @@ REGISTER_OP_CPU_KERNEL(
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int>, ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, uint8_t>, ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>, ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
unsqueeze2, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>, unsqueeze2, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
...@@ -379,7 +387,11 @@ REGISTER_OP_CPU_KERNEL( ...@@ -379,7 +387,11 @@ REGISTER_OP_CPU_KERNEL(
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, uint8_t>, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
unsqueeze2_grad, unsqueeze2_grad,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, float>, ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, float>,
...@@ -388,4 +400,8 @@ REGISTER_OP_CPU_KERNEL( ...@@ -388,4 +400,8 @@ REGISTER_OP_CPU_KERNEL(
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int>, ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, uint8_t>, ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>, ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
...@@ -25,7 +25,11 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -25,7 +25,11 @@ REGISTER_OP_CUDA_KERNEL(
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
unsqueeze_grad, unsqueeze_grad,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, float>, ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, float>,
...@@ -36,7 +40,11 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -36,7 +40,11 @@ REGISTER_OP_CUDA_KERNEL(
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int>, ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>, ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, uint8_t>, ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
unsqueeze2, unsqueeze2,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>,
...@@ -46,7 +54,11 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -46,7 +54,11 @@ REGISTER_OP_CUDA_KERNEL(
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
unsqueeze2_grad, unsqueeze2_grad,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, float>, ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, float>,
...@@ -57,4 +69,8 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -57,4 +69,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int>, ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, uint8_t>, ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>, ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
...@@ -60,6 +60,8 @@ struct PADDLE_ALIGN(sizeof(T) * 2) complex { ...@@ -60,6 +60,8 @@ struct PADDLE_ALIGN(sizeof(T) * 2) complex {
T real; T real;
T imag; T imag;
using value_type = T;
complex() = default; complex() = default;
complex(const complex<T>& o) = default; complex(const complex<T>& o) = default;
complex& operator=(const complex<T>& o) = default; complex& operator=(const complex<T>& o) = default;
......
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce) cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce)
list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc cusolver.cc cusparse.cc nvtx.cc) list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc cusolver.cc cusparse.cc nvtx.cc cufft.cc)
if (NOT WITH_NV_JETSON) if (NOT WITH_NV_JETSON)
list(APPEND CUDA_SRCS nvjpeg.cc) list(APPEND CUDA_SRCS nvjpeg.cc)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/dynload/cufft.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag cufft_dso_flag;
void* cufft_dso_handle = nullptr;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
CUFFT_FFT_ROUTINE_EACH(DEFINE_WRAP);
bool HasCUFFT() {
std::call_once(cufft_dso_flag,
[]() { cufft_dso_handle = GetCUFFTDsoHandle(); });
return cufft_dso_handle != nullptr;
}
void EnforceCUFFTLoaded(const char* fn_name) {
PADDLE_ENFORCE_NOT_NULL(
cufft_dso_handle,
platform::errors::PreconditionNotMet(
"Cannot load cufft shared library. Cannot invoke method %s.",
fn_name));
}
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_CUDA
#include <cufft.h>
#include <cufftXt.h>
#include <glog/logging.h>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
namespace paddle {
namespace platform {
namespace dynload {
extern std::once_flag cufft_dso_flag;
extern void* cufft_dso_handle;
extern bool HasCUFFT();
extern void EnforceCUFFTLoaded(const char* fn_name);
#define DECLARE_DYNAMIC_LOAD_CUFFT_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using cufft_func = decltype(&::__name); \
std::call_once(cufft_dso_flag, []() { \
cufft_dso_handle = paddle::platform::dynload::GetCUFFTDsoHandle(); \
}); \
EnforceCUFFTLoaded(#__name); \
static void* p_##__name = dlsym(cufft_dso_handle, #__name); \
return reinterpret_cast<cufft_func>(p_##__name)(args...); \
} \
}; \
extern struct DynLoad__##__name __name
/**
* include all needed cufft functions in HPPL
* different cufft version has different interfaces
**/
#define CUFFT_FFT_ROUTINE_EACH(__macro) \
__macro(cufftPlan1d); \
__macro(cufftPlan2d); \
__macro(cufftPlan3d); \
__macro(cufftPlanMany); \
__macro(cufftMakePlan1d); \
__macro(cufftMakePlan2d); \
__macro(cufftMakePlan3d); \
__macro(cufftMakePlanMany); \
__macro(cufftMakePlanMany64); \
__macro(cufftGetSizeMany64); \
__macro(cufftEstimate1d); \
__macro(cufftEstimate2d); \
__macro(cufftEstimate3d); \
__macro(cufftEstimateMany); \
__macro(cufftCreate); \
__macro(cufftGetSize1d); \
__macro(cufftGetSize2d); \
__macro(cufftGetSize3d); \
__macro(cufftGetSizeMany); \
__macro(cufftGetSize); \
__macro(cufftSetWorkArea); \
__macro(cufftSetAutoAllocation); \
__macro(cufftExecC2C); \
__macro(cufftExecR2C); \
__macro(cufftExecC2R); \
__macro(cufftExecZ2Z); \
__macro(cufftExecD2Z); \
__macro(cufftExecZ2D); \
__macro(cufftSetStream); \
__macro(cufftDestroy); \
__macro(cufftGetVersion); \
__macro(cufftGetProperty); \
__macro(cufftXtSetGPUs); \
__macro(cufftXtMalloc); \
__macro(cufftXtMemcpy); \
__macro(cufftXtFree); \
__macro(cufftXtSetWorkArea); \
__macro(cufftXtExecDescriptorC2C); \
__macro(cufftXtExecDescriptorR2C); \
__macro(cufftXtExecDescriptorC2R); \
__macro(cufftXtExecDescriptorZ2Z); \
__macro(cufftXtExecDescriptorD2Z); \
__macro(cufftXtExecDescriptorZ2D); \
__macro(cufftXtQueryPlan); \
__macro(cufftXtSetCallback); \
__macro(cufftXtClearCallback); \
__macro(cufftXtSetCallbackSharedSize); \
__macro(cufftXtMakePlanMany); \
__macro(cufftXtGetSizeMany); \
__macro(cufftXtExec); \
__macro(cufftXtExecDescriptor); \
__macro(cufftXtSetWorkAreaPolicy);
CUFFT_FFT_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUFFT_WRAP)
} // namespace dynload
} // namespace platform
} // namespace paddle
#endif
...@@ -109,6 +109,9 @@ static constexpr char* win_cusolver_lib = ...@@ -109,6 +109,9 @@ static constexpr char* win_cusolver_lib =
static constexpr char* win_cusparse_lib = static constexpr char* win_cusparse_lib =
"cusparse64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR "cusparse64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR
".dll;cusparse64_" CUDA_VERSION_MAJOR ".dll;cusparse64_10.dll"; ".dll;cusparse64_" CUDA_VERSION_MAJOR ".dll;cusparse64_10.dll";
static constexpr char* win_cufft_lib =
"cufft64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR
".dll;cufft64_" CUDA_VERSION_MAJOR ".dll;cufft64_10.dll";
#else #else
static constexpr char* win_curand_lib = static constexpr char* win_curand_lib =
"curand64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR "curand64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR
...@@ -122,6 +125,9 @@ static constexpr char* win_cusolver_lib = ...@@ -122,6 +125,9 @@ static constexpr char* win_cusolver_lib =
static constexpr char* win_cusparse_lib = static constexpr char* win_cusparse_lib =
"cusparse64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR "cusparse64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR
".dll;cusparse64_" CUDA_VERSION_MAJOR ".dll"; ".dll;cusparse64_" CUDA_VERSION_MAJOR ".dll";
static constexpr char* win_cufft_lib =
"cufft64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR
".dll;cufft64_" CUDA_VERSION_MAJOR ".dll";
#endif // CUDA_VERSION #endif // CUDA_VERSION
#endif #endif
...@@ -489,6 +495,17 @@ void* GetNvtxDsoHandle() { ...@@ -489,6 +495,17 @@ void* GetNvtxDsoHandle() {
#endif #endif
} }
void* GetCUFFTDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcufft.dylib");
#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA)
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, win_cufft_lib, true,
{cuda_lib_path});
#else
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcufft.so");
#endif
}
} // namespace dynload } // namespace dynload
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -41,6 +41,7 @@ void* GetTensorRtDsoHandle(); ...@@ -41,6 +41,7 @@ void* GetTensorRtDsoHandle();
void* GetMKLMLDsoHandle(); void* GetMKLMLDsoHandle();
void* GetOpDsoHandle(const std::string& dso_name); void* GetOpDsoHandle(const std::string& dso_name);
void* GetNvtxDsoHandle(); void* GetNvtxDsoHandle();
void* GetCUFFTDsoHandle();
void SetPaddleLibPath(const std::string&); void SetPaddleLibPath(const std::string&);
} // namespace dynload } // namespace dynload
......
...@@ -64,6 +64,7 @@ import paddle.reader # noqa: F401 ...@@ -64,6 +64,7 @@ import paddle.reader # noqa: F401
import paddle.static # noqa: F401 import paddle.static # noqa: F401
import paddle.vision # noqa: F401 import paddle.vision # noqa: F401
from .tensor import fft
from .tensor.random import bernoulli # noqa: F401 from .tensor.random import bernoulli # noqa: F401
from .tensor.attribute import rank # noqa: F401 from .tensor.attribute import rank # noqa: F401
......
...@@ -6727,8 +6727,10 @@ def pad(x, paddings, pad_value=0., name=None): ...@@ -6727,8 +6727,10 @@ def pad(x, paddings, pad_value=0., name=None):
x = fluid.data(name='data', shape=[300, 300], dtype='float32') x = fluid.data(name='data', shape=[300, 300], dtype='float32')
out = fluid.layers.pad(x=x, paddings=[0, 1, 1, 2], pad_value=0.) out = fluid.layers.pad(x=x, paddings=[0, 1, 1, 2], pad_value=0.)
""" """
check_variable_and_dtype( check_variable_and_dtype(x, 'x', [
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], "pad") 'float16', 'float32', 'float64', 'int32', 'int64', 'complex64',
'complex128'
], "pad")
helper = LayerHelper('pad', **locals()) helper = LayerHelper('pad', **locals())
dtype = helper.input_dtype(input_param_name='x') dtype = helper.input_dtype(input_param_name='x')
......
...@@ -702,6 +702,7 @@ endif() ...@@ -702,6 +702,7 @@ endif()
add_subdirectory(sequence) add_subdirectory(sequence)
add_subdirectory(dygraph_to_static) add_subdirectory(dygraph_to_static)
add_subdirectory(rnn) add_subdirectory(rnn)
add_subdirectory(fft)
if (WITH_XPU) if (WITH_XPU)
add_subdirectory(xpu) add_subdirectory(xpu)
......
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from functools import partial
from numpy import asarray
from numpy.fft._pocketfft import _raw_fft, _raw_fftnd, _get_forward_norm, _get_backward_norm, _cook_nd_args
def _fftc2c(a, n=None, axis=-1, norm=None, forward=None):
a = asarray(a)
if n is None:
n = a.shape[axis]
if forward:
inv_norm = _get_forward_norm(n, norm)
else:
inv_norm = _get_backward_norm(n, norm)
output = _raw_fft(a, n, axis, False, forward, inv_norm)
return output
def _fftr2c(a, n=None, axis=-1, norm=None, forward=None):
a = asarray(a)
if n is None:
n = a.shape[axis]
if forward:
inv_norm = _get_forward_norm(n, norm)
else:
inv_norm = _get_backward_norm(n, norm)
output = _raw_fft(a, n, axis, True, True, inv_norm)
if not forward:
output = output.conj()
return output
def _fftc2r(a, n=None, axis=-1, norm=None, forward=None):
a = asarray(a)
if n is None:
n = (a.shape[axis] - 1) * 2
if forward:
inv_norm = _get_forward_norm(n, norm)
else:
inv_norm = _get_backward_norm(n, norm)
output = _raw_fft(a.conj()
if forward else a, n, axis, True, False, inv_norm)
return output
def fft_c2c(x, axes, normalization, forward):
f = partial(_fftc2c, forward=forward)
y = _raw_fftnd(x, s=None, axes=axes, function=f, norm=normalization)
return y
def fft_c2c_backward(dy, axes, normalization, forward):
f = partial(_fftc2c, forward=forward)
dx = _raw_fftnd(dy, s=None, axes=axes, function=f, norm=normalization)
return dx
def fft_r2c(x, axes, normalization, forward, onesided):
a = asarray(x)
s, axes = _cook_nd_args(a, axes=axes)
if onesided:
a = _fftr2c(a, s[-1], axes[-1], normalization, forward)
for ii in range(len(axes) - 1):
a = _fftc2c(a, s[ii], axes[ii], normalization, forward)
else:
a = fft_c2c(x, axes, normalization, forward)
return a
def fft_r2c_backward(dy, x, axes, normalization, forward, onesided):
a = dy
if not onesided:
a = fft_c2c_backward(a, axes, normalization, forward).real
else:
pad_widths = [(0, 0)] * a.ndim
last_axis = axes[-1]
if last_axis < 0:
last_axis += a.ndim
last_dim_size = a.shape[last_axis]
pad_widths[last_axis] = (0, x.shape[last_axis] - last_dim_size)
a = np.pad(a, pad_width=pad_widths)
a = fft_c2c_backward(a, axes, normalization, forward).real
return a
def fft_c2r(x, axes, normalization, forward, last_dim_size):
a = asarray(x)
s, axes = _cook_nd_args(a, axes=axes, invreal=1)
if last_dim_size is not None:
s[-1] = last_dim_size
for ii in range(len(axes) - 1):
a = _fftc2c(a, s[ii], axes[ii], normalization, forward)
a = _fftc2r(a, s[-1], axes[-1], normalization, forward)
return a
此差异已折叠。
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import re
import sys
from spectral_op_np import fft_c2c, fft_r2c, fft_c2r
import paddle.fluid.core as core
import paddle.fluid.dygraph as dg
import paddle.static as static
from numpy.random import random as rand
from paddle.fluid import Program, program_guard
sys.path.append("../")
from op_test import OpTest
paddle.enable_static()
TEST_CASE_NAME = 'test_case'
def parameterize(attrs, input_values=None):
if isinstance(attrs, str):
attrs = [attrs]
input_dicts = (attrs if input_values is None else
[dict(zip(attrs, vals)) for vals in input_values])
def decorator(base_class):
test_class_module = sys.modules[base_class.__module__].__dict__
for idx, input_dict in enumerate(input_dicts):
test_class_dict = dict(base_class.__dict__)
test_class_dict.update(input_dict)
name = class_name(base_class, idx, input_dict)
test_class_module[name] = type(name, (base_class, ),
test_class_dict)
for method_name in list(base_class.__dict__):
if method_name.startswith("test"):
delattr(base_class, method_name)
return base_class
return decorator
def to_safe_name(s):
return str(re.sub("[^a-zA-Z0-9_]+", "_", s))
def class_name(cls, num, params_dict):
suffix = to_safe_name(
next((v for v in params_dict.values() if isinstance(v, str)), ""))
if TEST_CASE_NAME in params_dict:
suffix = to_safe_name(params_dict["test_case"])
return "{}_{}{}".format(cls.__name__, num, suffix and "_" + suffix)
@parameterize((TEST_CASE_NAME, 'x', 'axes', 'norm', 'forward'), [
('test_axes_is_sqe_type', (np.random.random(
(12, 14)) + 1j * np.random.random((12, 14))).astype(np.complex128),
[0, 1], 'forward', True), ('test_axis_not_last', (np.random.random(
(4, 4, 4)) + 1j * np.random.random((4, 4, 4))).astype(np.complex128),
(0, 1), "backward", False),
('test_norm_forward', (np.random.random((12, 14)) + 1j * np.random.random(
(12, 14))).astype(np.complex128), (0, ), "forward",
False), ('test_norm_backward', (np.random.random(
(12, 14)) + 1j * np.random.random((12, 14))).astype(np.complex128),
(0, ), "backward", True), ('test_norm_ortho', (np.random.random(
(12, 14)) + 1j * np.random.random(
(12, 14))).astype(np.complex128), (1, ), "ortho", True)
])
class TestFFTC2COp(OpTest):
# Because framwork not support complex numerial gradient, we skip gradient check.
no_need_check_grad = True
def setUp(self):
self.op_type = "fft_c2c"
out = fft_c2c(self.x, self.axes, self.norm, self.forward)
self.inputs = {'X': self.x}
self.attrs = {
'axes': self.axes,
'normalization': self.norm,
"forward": self.forward
}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
@parameterize(
(TEST_CASE_NAME, 'x', 'axes', 'norm', 'forward', 'last_dim_size'),
[('test_axes_is_sqe_type', (np.random.random(
(12, 14)) + 1j * np.random.random((12, 14))).astype(np.complex128),
[0, 1], 'forward', True, 26), ('test_axis_not_last', (np.random.random(
(4, 4, 4)) + 1j * np.random.random((4, 4, 4))).astype(np.complex128),
(0, 1), "backward", False, None),
('test_norm_forward', (np.random.random((12, 14)) + 1j * np.random.random(
(12, 14))).astype(np.complex128), (0, ), "forward", False, 22),
('test_norm_backward', (np.random.random((12, 14)) + 1j * np.random.random(
(12, 14))).astype(np.complex128), (0, ), "backward", True,
22), ('test_norm_ortho', (np.random.random(
(12, 14)) + 1j * np.random.random((12, 14))).astype(np.complex128),
(1, ), "ortho", True, 26)])
class TestFFTC2ROp(OpTest):
# Because framwork not support complex numerial gradient, we skip gradient check.
no_need_check_grad = True
def setUp(self):
self.op_type = "fft_c2r"
out = fft_c2r(self.x, self.axes, self.norm, self.forward,
self.last_dim_size)
self.inputs = {'X': self.x}
self.attrs = {
"axes": self.axes,
"normalization": self.norm,
"forward": self.forward,
"last_dim_size": self.last_dim_size
}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
@parameterize(
(TEST_CASE_NAME, 'x', 'axes', 'norm', 'forward', 'onesided'),
[('test_axes_is_sqe_type', np.random.randn(12, 14).astype(np.float64),
(0, 1), 'forward', True,
True), ('test_axis_not_last', np.random.randn(4, 4, 4).astype(np.float64),
(0, 1), "backward", False, True),
('test_norm_forward', np.random.randn(12, 14).astype(np.float64), (0, 1),
"forward", False, False),
('test_norm_backward', np.random.randn(12, 14).astype(np.float64), (0, ),
"backward", True, False), ('test_norm_ortho',
np.random.randn(12, 14).astype(np.float64),
(1, ), "ortho", True, False)])
class TestFFTR2COp(OpTest):
# Because framwork not support complex numerial gradient, we skip gradient check.
no_need_check_grad = True
def setUp(self):
self.op_type = "fft_r2c"
out = fft_r2c(self.x, self.axes, self.norm, self.forward, self.onesided)
self.inputs = {'X': self.x}
self.attrs = {
'axes': self.axes,
'normalization': self.norm,
"forward": self.forward,
'onesided': self.onesided
}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from numpy.lib.stride_tricks import as_strided
import paddle
import unittest
from op_test import OpTest
def frame_from_librosa(x, frame_length, hop_length, axis=-1):
if axis == -1 and not x.flags["C_CONTIGUOUS"]:
x = np.ascontiguousarray(x)
elif axis == 0 and not x.flags["F_CONTIGUOUS"]:
x = np.asfortranarray(x)
n_frames = 1 + (x.shape[axis] - frame_length) // hop_length
strides = np.asarray(x.strides)
if axis == -1:
shape = list(x.shape)[:-1] + [frame_length, n_frames]
strides = list(strides) + [hop_length * x.itemsize]
elif axis == 0:
shape = [n_frames, frame_length] + list(x.shape)[1:]
strides = [hop_length * x.itemsize] + list(strides)
else:
raise ValueError("Frame axis={} must be either 0 or -1".format(axis))
return as_strided(x, shape=shape, strides=strides)
class TestFrameOp(OpTest):
def setUp(self):
self.op_type = "frame"
self.shape, self.type, self.attrs = self.initTestCase()
self.inputs = {
'X': np.random.random(size=self.shape).astype(self.type),
}
self.outputs = {
'Out': frame_from_librosa(
x=self.inputs['X'], **self.attrs)
}
def initTestCase(self):
input_shape = (150, )
input_type = 'float64'
attrs = {
'frame_length': 50,
'hop_length': 15,
'axis': -1,
}
return input_shape, input_type, attrs
def test_check_output(self):
paddle.enable_static()
self.check_output()
paddle.disable_static()
def test_check_grad_normal(self):
paddle.enable_static()
self.check_grad(['X'], 'Out')
paddle.disable_static()
class TestCase1(TestFrameOp):
def initTestCase(self):
input_shape = (150, )
input_type = 'float64'
attrs = {
'frame_length': 50,
'hop_length': 15,
'axis': 0,
}
return input_shape, input_type, attrs
class TestCase2(TestFrameOp):
def initTestCase(self):
input_shape = (8, 150)
input_type = 'float64'
attrs = {
'frame_length': 50,
'hop_length': 15,
'axis': -1,
}
return input_shape, input_type, attrs
class TestCase3(TestFrameOp):
def initTestCase(self):
input_shape = (150, 8)
input_type = 'float64'
attrs = {
'frame_length': 50,
'hop_length': 15,
'axis': 0,
}
return input_shape, input_type, attrs
class TestCase4(TestFrameOp):
def initTestCase(self):
input_shape = (4, 2, 150)
input_type = 'float64'
attrs = {
'frame_length': 50,
'hop_length': 15,
'axis': -1,
}
return input_shape, input_type, attrs
class TestCase5(TestFrameOp):
def initTestCase(self):
input_shape = (150, 4, 2)
input_type = 'float64'
attrs = {
'frame_length': 50,
'hop_length': 15,
'axis': 0,
}
return input_shape, input_type, attrs
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import unittest
from op_test import OpTest
def overlap_add(x, hop_length, axis=-1):
assert axis in [0, -1], 'axis should be 0/-1.'
assert len(x.shape) >= 2, 'Input dims shoulb be >= 2.'
squeeze_output = False
if len(x.shape) == 2:
squeeze_output = True
dim = 0 if axis == -1 else -1
x = np.expand_dims(x, dim) # batch
n_frames = x.shape[axis]
frame_length = x.shape[1] if axis == 0 else x.shape[-2]
# Assure no gaps between frames.
assert 0 < hop_length <= frame_length, \
f'hop_length should be in (0, frame_length({frame_length})], but got {hop_length}.'
seq_length = (n_frames - 1) * hop_length + frame_length
reshape_output = False
if len(x.shape) > 3:
reshape_output = True
if axis == 0:
target_shape = [seq_length] + list(x.shape[2:])
x = x.reshape(n_frames, frame_length, np.product(x.shape[2:]))
else:
target_shape = list(x.shape[:-2]) + [seq_length]
x = x.reshape(np.product(x.shape[:-2]), frame_length, n_frames)
if axis == 0:
x = x.transpose((2, 1, 0))
y = np.zeros(shape=[np.product(x.shape[:-2]), seq_length], dtype=x.dtype)
for i in range(x.shape[0]):
for frame in range(x.shape[-1]):
sample = frame * hop_length
y[i, sample:sample + frame_length] += x[i, :, frame]
if axis == 0:
y = y.transpose((1, 0))
if reshape_output:
y = y.reshape(target_shape)
if squeeze_output:
y = y.squeeze(-1) if axis == 0 else y.squeeze(0)
return y
class TestOverlapAddOp(OpTest):
def setUp(self):
self.op_type = "overlap_add"
self.shape, self.type, self.attrs = self.initTestCase()
self.inputs = {
'X': np.random.random(size=self.shape).astype(self.type),
}
self.outputs = {'Out': overlap_add(x=self.inputs['X'], **self.attrs)}
def initTestCase(self):
input_shape = (50, 3)
input_type = 'float64'
attrs = {
'hop_length': 4,
'axis': -1,
}
return input_shape, input_type, attrs
def test_check_output(self):
paddle.enable_static()
self.check_output()
paddle.disable_static()
def test_check_grad_normal(self):
paddle.enable_static()
self.check_grad(['X'], 'Out')
paddle.disable_static()
class TestCase1(TestOverlapAddOp):
def initTestCase(self):
input_shape = (3, 50)
input_type = 'float64'
attrs = {
'hop_length': 4,
'axis': 0,
}
return input_shape, input_type, attrs
class TestCase2(TestOverlapAddOp):
def initTestCase(self):
input_shape = (2, 40, 5)
input_type = 'float64'
attrs = {
'hop_length': 10,
'axis': -1,
}
return input_shape, input_type, attrs
class TestCase3(TestOverlapAddOp):
def initTestCase(self):
input_shape = (5, 40, 2)
input_type = 'float64'
attrs = {
'hop_length': 10,
'axis': 0,
}
return input_shape, input_type, attrs
class TestCase4(TestOverlapAddOp):
def initTestCase(self):
input_shape = (3, 5, 12, 8)
input_type = 'float64'
attrs = {
'hop_length': 5,
'axis': -1,
}
return input_shape, input_type, attrs
class TestCase5(TestOverlapAddOp):
def initTestCase(self):
input_shape = (8, 12, 5, 3)
input_type = 'float64'
attrs = {
'hop_length': 5,
'axis': 0,
}
return input_shape, input_type, attrs
if __name__ == '__main__':
unittest.main()
此差异已折叠。
...@@ -216,6 +216,8 @@ from .array import array_write # noqa: F401 ...@@ -216,6 +216,8 @@ from .array import array_write # noqa: F401
from .array import create_array # noqa: F401 from .array import create_array # noqa: F401
from .einsum import einsum # noqa: F401 from .einsum import einsum # noqa: F401
from . import fft
from . import signal
#this list used in math_op_patch.py for _binary_creator_ #this list used in math_op_patch.py for _binary_creator_
tensor_method_func = [ #noqa tensor_method_func = [ #noqa
......
...@@ -35,6 +35,41 @@ def _complex_to_real_dtype(dtype): ...@@ -35,6 +35,41 @@ def _complex_to_real_dtype(dtype):
return dtype return dtype
def _real_to_complex_dtype(dtype):
if dtype == core.VarDesc.VarType.FP32:
return core.VarDesc.VarType.COMPLEX64
elif dtype == core.VarDesc.VarType.FP64:
return core.VarDesc.VarType.COMPLEX128
else:
return dtype
def is_complex(x):
dtype = x.dtype
is_complex_dtype = (dtype == core.VarDesc.VarType.COMPLEX64 or
dtype == core.VarDesc.VarType.COMPLEX128)
return is_complex_dtype
def is_floating_point(x):
dtype = x.dtype
is_fp_dtype = (dtype == core.VarDesc.VarType.FP32 or
dtype == core.VarDesc.VarType.FP64 or
dtype == core.VarDesc.VarType.FP16 or
dtype == core.VarDesc.VarType.BF16)
return is_fp_dtype
def is_interger(x):
dtype = x.dtype
is_int_dtype = (dtype == core.VarDesc.VarType.UINT8 or
dtype == core.VarDesc.VarType.INT8 or
dtype == core.VarDesc.VarType.INT16 or
dtype == core.VarDesc.VarType.INT32 or
dtype == core.VarDesc.VarType.INT64)
return is_int_dtype
def real(x, name=None): def real(x, name=None):
""" """
Returns a new tensor containing real values of the input tensor. Returns a new tensor containing real values of the input tensor.
......
此差异已折叠。
...@@ -682,7 +682,7 @@ def roll(x, shifts, axis=None, name=None): ...@@ -682,7 +682,7 @@ def roll(x, shifts, axis=None, name=None):
axis = [axis] axis = [axis]
len_origin_shape = len(origin_shape) len_origin_shape = len(origin_shape)
if axis: if axis is not None:
for i in range(len(axis)): for i in range(len(axis)):
if axis[i] >= len_origin_shape or axis[i] < -len_origin_shape: if axis[i] >= len_origin_shape or axis[i] < -len_origin_shape:
raise ValueError( raise ValueError(
......
此差异已折叠。
...@@ -6,6 +6,7 @@ gym ...@@ -6,6 +6,7 @@ gym
opencv-python<=4.2.0.32 opencv-python<=4.2.0.32
visualdl visualdl
paddle2onnx>=0.4 paddle2onnx>=0.4
scipy scipy>=1.6
prettytable prettytable
distro distro
numpy>=1.20
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册