Unverified commit 319f95d0, authored by K KP, committed by GitHub

Add complex type compatibility for stft api and stft op. (#40113)

* Add stft_op.

* Add stft_grad_op.

* Add stft_op unittest.

* [DLTP-45176] Add complex compatibility in static mode for stft api.

* [DLTP-45176] Add complex compatibility in static mode for stft api.

* Add doc.

* Update unittests of stft op.

* Update spectral helper.

* Fix coding style.
Parent: 3d0be938
paddle/fluid/operators/frame_op.cc
@@ -64,18 +64,26 @@ class FrameOp : public framework::OperatorWithKernel {
       end_axis = x_rank - 2;
     }
 
-    PADDLE_ENFORCE_LE(frame_length, seq_length,
-                      platform::errors::InvalidArgument(
-                          "Attribute(frame_length) of FrameOp should be less "
-                          "equal than sequence length, but got (%s) > (%s).",
-                          frame_length, seq_length));
+    bool contain_unknown_dim = phi::contain_unknown_dim(x_dims);
+    bool check = ctx->IsRuntime() || !contain_unknown_dim;
+    if (check) {
+      PADDLE_ENFORCE_LE(frame_length, seq_length,
+                        platform::errors::InvalidArgument(
+                            "Attribute(frame_length) of FrameOp should be less "
+                            "equal than sequence length, but got (%s) > (%s).",
+                            frame_length, seq_length));
+    }
 
     // It won't go into for loop when x_rank == 1U.
     for (int i = start_axis; i <= end_axis; i++) {
       output_shape.push_back(x_dims[i]);
     }
 
-    n_frames = 1 + (seq_length - frame_length) / hop_length;
+    if (seq_length == -1) {
+      n_frames = -1;
+    } else {
+      n_frames = 1 + (seq_length - frame_length) / hop_length;
+    }
 
     if (axis == 0) {
       // (n_frames, frame_length, ...)
...
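For intuition, the frame-count inference above reduces to one line of integer arithmetic; a minimal standalone sketch (the helper name and the sample sizes are hypothetical, for illustration only):

#include <cstdio>

// Sketch of FrameOp's frame-count inference. A seq_length of -1 stands for
// a dimension that is unknown at graph-construction time.
int infer_n_frames(int seq_length, int frame_length, int hop_length) {
  if (seq_length == -1) return -1;  // unknown in, unknown out
  return 1 + (seq_length - frame_length) / hop_length;
}

int main() {
  std::printf("%d\n", infer_n_frames(16000, 512, 160));  // prints 97
  std::printf("%d\n", infer_n_frames(-1, 512, 160));     // prints -1
  return 0;
}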
paddle/fluid/operators/mean_op.cc
@@ -98,9 +98,17 @@ REGISTER_OP_CPU_KERNEL(
     mean, ops::MeanKernel<paddle::platform::CPUDeviceContext, float>,
     ops::MeanKernel<paddle::platform::CPUDeviceContext, double>,
     ops::MeanKernel<paddle::platform::CPUDeviceContext,
-                    paddle::platform::bfloat16>);
+                    paddle::platform::bfloat16>,
+    ops::MeanKernel<paddle::platform::CPUDeviceContext,
+                    paddle::platform::complex<float>>,
+    ops::MeanKernel<paddle::platform::CPUDeviceContext,
+                    paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     mean_grad, ops::MeanGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::MeanGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::MeanGradKernel<paddle::platform::CPUDeviceContext,
-                        paddle::platform::bfloat16>);
+                        paddle::platform::bfloat16>,
+    ops::MeanGradKernel<paddle::platform::CPUDeviceContext,
+                        paddle::platform::complex<float>>,
+    ops::MeanGradKernel<paddle::platform::CPUDeviceContext,
+                        paddle::platform::complex<double>>);
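The new registrations let mean and mean_grad accept complex tensors; numerically the reduction is unchanged, with real and imaginary parts averaged together. A minimal std::complex illustration (standalone, not Paddle's kernel code):

#include <complex>
#include <cstdio>
#include <vector>

int main() {
  std::vector<std::complex<float>> x = {{1.f, 2.f}, {3.f, -4.f}, {5.f, 8.f}};
  std::complex<float> sum(0.f, 0.f);
  for (const auto& v : x) sum += v;  // accumulate real and imaginary parts
  const std::complex<float> mean = sum / static_cast<float>(x.size());
  std::printf("(%g, %g)\n", mean.real(), mean.imag());  // prints (3, 2)
  return 0;
}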
paddle/fluid/operators/mean_op.cu
@@ -102,10 +102,17 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
     mean, ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, float>,
     ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, plat::float16>);
+    ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, plat::float16>,
+    ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext,
+                        paddle::platform::complex<float>>,
+    ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext,
+                        paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     mean_grad,
     ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext,
-                            plat::float16>);
+    ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
+    ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext,
+                            paddle::platform::complex<float>>,
+    ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext,
+                            paddle::platform::complex<double>>);
paddle/fluid/operators/overlap_add_op.cc
@@ -54,6 +54,7 @@ class OverlapAddOp : public framework::OperatorWithKernel {
     std::vector<int64_t> output_shape;
     int n_frames;
     int frame_length;
+    int seq_length;
     int start_axis;
     int end_axis;
@@ -69,14 +70,22 @@ class OverlapAddOp : public framework::OperatorWithKernel {
       end_axis = x_rank - 3;
     }
 
-    PADDLE_ENFORCE_LE(
-        hop_length, frame_length,
-        platform::errors::InvalidArgument(
-            "Attribute(hop_length) of OverlapAddOp should be less or equal "
-            "than frame_length, but got hop_length(%s) > frame_length(%s).",
-            hop_length, frame_length));
+    bool contain_unknown_dim = phi::contain_unknown_dim(x_dims);
+    bool check = ctx->IsRuntime() || !contain_unknown_dim;
+    if (check) {
+      PADDLE_ENFORCE_LE(
+          hop_length, frame_length,
+          platform::errors::InvalidArgument(
+              "Attribute(hop_length) of OverlapAddOp should be less or equal "
+              "than frame_length, but got hop_length(%s) > frame_length(%s).",
+              hop_length, frame_length));
+    }
 
-    const int seq_length = (n_frames - 1) * hop_length + frame_length;
+    if (n_frames == -1) {
+      seq_length = -1;
+    } else {
+      seq_length = (n_frames - 1) * hop_length + frame_length;
+    }
 
     // It won't go into for loop when x_rank == 2U.
     for (int i = start_axis; i <= end_axis; i++) {
...
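seq_length here inverts FrameOp's n_frames formula exactly whenever hop_length divides (seq_length - frame_length); a standalone check of that round trip, with hypothetical sizes:

#include <cassert>

int main() {
  // FrameOp:      n_frames   = 1 + (seq_length - frame_length) / hop_length
  // OverlapAddOp: seq_length = (n_frames - 1) * hop_length + frame_length
  const int frame_length = 512, hop_length = 128;
  for (int seq_length = 512; seq_length <= 2048; seq_length += hop_length) {
    const int n_frames = 1 + (seq_length - frame_length) / hop_length;
    const int restored = (n_frames - 1) * hop_length + frame_length;
    assert(restored == seq_length);  // exact when hop divides (seq - frame)
  }
  return 0;
}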
paddle/fluid/operators/spectral_helper.h
@@ -16,451 +16,469 @@
 #include "paddle/fluid/operators/spectral_op.h"
 
-#ifdef PADDLE_WITH_HIP
-#include "paddle/fluid/platform/dynload/hipfft.h"
-#endif
-
-#ifdef PADDLE_WITH_CUDA
-#include "paddle/fluid/platform/dynload/cufft.h"
-#endif
+#if defined(PADDLE_WITH_ONEMKL)
+#include "paddle/phi/backends/dynload/mklrt.h"
+#elif defined(PADDLE_WITH_POCKETFFT)
+#include "extern_pocketfft/pocketfft_hdronly.h"
+#endif
 
 namespace paddle {
 namespace operators {
 
-using ScalarType = framework::proto::VarType::Type;
-const int64_t kMaxFFTNdim = 3;
-const int64_t kMaxDataNdim = kMaxFFTNdim + 1;
-// This struct is used to easily compute hashes of the
-// parameters. It will be the **key** to the plan cache.
-struct FFTConfigKey {
-  // between 1 and kMaxFFTNdim, i.e., 1 <= signal_ndim <= 3
-  int64_t signal_ndim_;
-  // These include additional batch dimension as well.
-  int64_t sizes_[kMaxDataNdim];
-  int64_t input_shape_[kMaxDataNdim];
-  int64_t output_shape_[kMaxDataNdim];
-  FFTTransformType fft_type_;
-  ScalarType value_type_;
-
-  FFTConfigKey() = default;
-
-  FFTConfigKey(const std::vector<int64_t>& in_shape,
-               const std::vector<int64_t>& out_shape,
-               const std::vector<int64_t>& signal_size,
-               FFTTransformType fft_type, ScalarType value_type) {
-    // Padding bits must be zeroed for hashing
-    memset(this, 0, sizeof(*this));
-    signal_ndim_ = signal_size.size() - 1;
-    fft_type_ = fft_type;
-    value_type_ = value_type;
-    std::copy(signal_size.cbegin(), signal_size.cend(), sizes_);
-    std::copy(in_shape.cbegin(), in_shape.cend(), input_shape_);
-    std::copy(out_shape.cbegin(), out_shape.cend(), output_shape_);
-  }
-};
-
-#if defined(PADDLE_WITH_CUDA)
-// An RAII encapsulation of cuFFTHandle
-class CuFFTHandle {
-  ::cufftHandle handle_;
-
- public:
-  CuFFTHandle() {
-    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cufftCreate(&handle_));
-  }
-
-  CuFFTHandle(const CuFFTHandle& other) = delete;
-  CuFFTHandle& operator=(const CuFFTHandle& other) = delete;
-
-  CuFFTHandle(CuFFTHandle&& other) = delete;
-  CuFFTHandle& operator=(CuFFTHandle&& other) = delete;
-
-  ::cufftHandle& get() { return handle_; }
-  const ::cufftHandle& get() const { return handle_; }
-
-  ~CuFFTHandle() {
-    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cufftDestroy(handle_));
-  }
-};
-
-using plan_size_type = long long int;  // NOLINT
-// This class contains all the information needed to execute a cuFFT plan:
-//   1. the plan
-//   2. the workspace size needed
-class FFTConfig {
- public:
-  // Only move semantics is enought for this class. Although we already use
-  // unique_ptr for the plan, still remove copy constructor and assignment op so
-  // we don't accidentally copy and take perf hit.
-  explicit FFTConfig(const FFTConfigKey& plan_key)
-      : FFTConfig(
-            std::vector<int64_t>(plan_key.sizes_,
-                                 plan_key.sizes_ + plan_key.signal_ndim_ + 1),
-            plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {}
-
-  // sizes are full signal, including batch size and always two-sided
-  FFTConfig(const std::vector<int64_t>& sizes, const int64_t signal_ndim,
-            FFTTransformType fft_type, ScalarType dtype)
-      : fft_type_(fft_type), value_type_(dtype) {
-    // signal sizes (excluding batch dim)
-    std::vector<plan_size_type> signal_sizes(sizes.begin() + 1, sizes.end());
-    // input batch size
-    const auto batch = static_cast<plan_size_type>(sizes[0]);
-    // const int64_t signal_ndim = sizes.size() - 1;
-    PADDLE_ENFORCE_EQ(signal_ndim, sizes.size() - 1,
-                      platform::errors::InvalidArgument(
-                          "The signal_ndim must be equal to sizes.size() - 1,"
-                          "But signal_ndim is: [%d], sizes.size() - 1 is: [%d]",
-                          signal_ndim, sizes.size() - 1));
-
-    cudaDataType itype, otype, exec_type;
-    const auto complex_input = has_complex_input(fft_type);
-    const auto complex_output = has_complex_output(fft_type);
-    if (dtype == framework::proto::VarType::FP32) {
-      itype = complex_input ? CUDA_C_32F : CUDA_R_32F;
-      otype = complex_output ? CUDA_C_32F : CUDA_R_32F;
-      exec_type = CUDA_C_32F;
-    } else if (dtype == framework::proto::VarType::FP64) {
-      itype = complex_input ? CUDA_C_64F : CUDA_R_64F;
-      otype = complex_output ? CUDA_C_64F : CUDA_R_64F;
-      exec_type = CUDA_C_64F;
-    } else if (dtype == framework::proto::VarType::FP16) {
-      itype = complex_input ? CUDA_C_16F : CUDA_R_16F;
-      otype = complex_output ? CUDA_C_16F : CUDA_R_16F;
-      exec_type = CUDA_C_16F;
-    } else {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "cuFFT only support transforms of type float16, float32 and "
-          "float64"));
-    }
-
-    // disable auto allocation of workspace to use allocator from the framework
-    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cufftSetAutoAllocation(
-        plan(), /* autoAllocate */ 0));
-
-    size_t ws_size_t;
-    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cufftXtMakePlanMany(
-        plan(), signal_ndim, signal_sizes.data(),
-        /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype,
-        /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype,
-        batch, &ws_size_t, exec_type));
-
-    ws_size = ws_size_t;
-  }
-
-  FFTConfig(const FFTConfig& other) = delete;
-  FFTConfig& operator=(const FFTConfig& other) = delete;
-
-  FFTConfig(FFTConfig&& other) = delete;
-  FFTConfig& operator=(FFTConfig&& other) = delete;
-
-  const cufftHandle& plan() const { return plan_ptr.get(); }
-
-  FFTTransformType transform_type() const { return fft_type_; }
-  ScalarType data_type() const { return value_type_; }
-  size_t workspace_size() const { return ws_size; }
-
- private:
-  CuFFTHandle plan_ptr;
-  size_t ws_size;
-  FFTTransformType fft_type_;
-  ScalarType value_type_;
-};
-
-#elif defined(PADDLE_WITH_HIP)
-// An RAII encapsulation of cuFFTHandle
-class HIPFFTHandle {
-  ::hipfftHandle handle_;
-
- public:
-  HIPFFTHandle() {
-    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftCreate(&handle_));
-  }
-
-  HIPFFTHandle(const HIPFFTHandle& other) = delete;
-  HIPFFTHandle& operator=(const HIPFFTHandle& other) = delete;
-
-  HIPFFTHandle(HIPFFTHandle&& other) = delete;
-  HIPFFTHandle& operator=(HIPFFTHandle&& other) = delete;
-
-  ::hipfftHandle& get() { return handle_; }
-  const ::hipfftHandle& get() const { return handle_; }
-
-  ~HIPFFTHandle() {
-    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftDestroy(handle_));
-  }
-};
-using plan_size_type = int;
-// This class contains all the information needed to execute a cuFFT plan:
-//   1. the plan
-//   2. the workspace size needed
-class FFTConfig {
- public:
-  // Only move semantics is enought for this class. Although we already use
-  // unique_ptr for the plan, still remove copy constructor and assignment op so
-  // we don't accidentally copy and take perf hit.
-  explicit FFTConfig(const FFTConfigKey& plan_key)
-      : FFTConfig(
-            std::vector<int64_t>(plan_key.sizes_,
-                                 plan_key.sizes_ + plan_key.signal_ndim_ + 1),
-            plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {}
-
-  // sizes are full signal, including batch size and always two-sided
-  FFTConfig(const std::vector<int64_t>& sizes, const int64_t signal_ndim,
-            FFTTransformType fft_type, ScalarType dtype)
-      : fft_type_(fft_type), value_type_(dtype) {
-    // signal sizes (excluding batch dim)
-    std::vector<plan_size_type> signal_sizes(sizes.begin() + 1, sizes.end());
-    // input batch size
-    const auto batch = static_cast<plan_size_type>(sizes[0]);
-    // const int64_t signal_ndim = sizes.size() - 1;
-    PADDLE_ENFORCE_EQ(signal_ndim, sizes.size() - 1,
-                      platform::errors::InvalidArgument(
-                          "The signal_ndim must be equal to sizes.size() - 1,"
-                          "But signal_ndim is: [%d], sizes.size() - 1 is: [%d]",
-                          signal_ndim, sizes.size() - 1));
-
-    hipfftType exec_type = [&] {
-      if (dtype == framework::proto::VarType::FP32) {
-        switch (fft_type) {
-          case FFTTransformType::C2C:
-            return HIPFFT_C2C;
-          case FFTTransformType::R2C:
-            return HIPFFT_R2C;
-          case FFTTransformType::C2R:
-            return HIPFFT_C2R;
-        }
-      } else if (dtype == framework::proto::VarType::FP64) {
-        switch (fft_type) {
-          case FFTTransformType::C2C:
-            return HIPFFT_Z2Z;
-          case FFTTransformType::R2C:
-            return HIPFFT_D2Z;
-          case FFTTransformType::C2R:
-            return HIPFFT_Z2D;
-        }
-      }
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "hipFFT only support transforms of type float32 and float64"));
-    }();
-
-    // disable auto allocation of workspace to use allocator from the framework
-    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftSetAutoAllocation(
-        plan(), /* autoAllocate */ 0));
-
-    size_t ws_size_t;
-    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftMakePlanMany(
-        plan(), signal_ndim, signal_sizes.data(),
-        /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1,
-        /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, exec_type,
-        batch, &ws_size_t));
-
-    ws_size = ws_size_t;
-  }
-
-  const hipfftHandle& plan() const { return plan_ptr.get(); }
-
-  FFTTransformType transform_type() const { return fft_type_; }
-  ScalarType data_type() const { return value_type_; }
-  size_t workspace_size() const { return ws_size; }
-
- private:
-  HIPFFTHandle plan_ptr;
-  size_t ws_size;
-  FFTTransformType fft_type_;
-  ScalarType value_type_;
-};
-#endif
-
-// Hashing machinery for Key
-// Fowler–Noll–Vo hash function
-// see
-// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
-template <typename Key>
-struct KeyHash {
-  // Key must be a POD because we read out its memory
-  // contenst as char* when hashing
-  static_assert(std::is_pod<Key>::value, "Key must be plain old data type");
-
-  size_t operator()(const Key& params) const {
-    auto ptr = reinterpret_cast<const uint8_t*>(&params);
-    uint32_t value = 0x811C9DC5;
-    for (int i = 0; i < static_cast<int>(sizeof(Key)); ++i) {
-      value ^= ptr[i];
-      value *= 0x01000193;
-    }
-    return static_cast<size_t>(value);
-  }
-};
-
-template <typename Key>
-struct KeyEqual {
-  // Key must be a POD because we read out its memory
-  // contenst as char* when comparing
-  static_assert(std::is_pod<Key>::value, "Key must be plain old data type");
-
-  bool operator()(const Key& a, const Key& b) const {
-    auto ptr1 = reinterpret_cast<const uint8_t*>(&a);
-    auto ptr2 = reinterpret_cast<const uint8_t*>(&b);
-    return memcmp(ptr1, ptr2, sizeof(Key)) == 0;
-  }
-};
-
-#if CUDA_VERSION < 10000
-// Note that the max plan number for CUDA version < 10 has to be 1023
-// due to a bug that fails on the 1024th plan
-constexpr size_t CUFFT_MAX_PLAN_NUM = 1023;
-constexpr size_t CUFFT_DEFAULT_CACHE_SIZE = CUFFT_MAX_PLAN_NUM;
-#else
-constexpr size_t CUFFT_MAX_PLAN_NUM = std::numeric_limits<size_t>::max();
-// The default max cache size chosen for CUDA version > 10 is arbitrary.
-// This number puts a limit on how big of a plan cache should we maintain by
-// default. Users can always configure it via cufft_set_plan_cache_max_size.
-constexpr size_t CUFFT_DEFAULT_CACHE_SIZE = 4096;
-#endif
-static_assert(CUFFT_MAX_PLAN_NUM >= 0 &&
-                  CUFFT_MAX_PLAN_NUM <= std::numeric_limits<size_t>::max(),
-              "CUFFT_MAX_PLAN_NUM not in size_t range");
-static_assert(CUFFT_DEFAULT_CACHE_SIZE >= 0 &&
-                  CUFFT_DEFAULT_CACHE_SIZE <= CUFFT_MAX_PLAN_NUM,
-              "CUFFT_DEFAULT_CACHE_SIZE not in [0, CUFFT_MAX_PLAN_NUM] range");
-
-// This cache assumes that the mapping from key to value never changes.
-// This is **NOT** thread-safe. Please use a mutex when using it **AND** the
-// value returned from try_emplace_value.
-// The contract of using this cache is that try_emplace_value should only be
-// used when the max_size is positive.
-class FFTConfigCache {
- public:
-  using kv_t = typename std::pair<FFTConfigKey, FFTConfig>;
-  using map_t = typename std::unordered_map<
-      std::reference_wrapper<FFTConfigKey>, typename std::list<kv_t>::iterator,
-      KeyHash<FFTConfigKey>, KeyEqual<FFTConfigKey>>;
-  using map_kkv_iter_t = typename map_t::iterator;
-
-  FFTConfigCache() : FFTConfigCache(CUFFT_DEFAULT_CACHE_SIZE) {}
-
-  explicit FFTConfigCache(int64_t max_size) { _set_max_size(max_size); }
-
-  FFTConfigCache(const FFTConfigCache& other) = delete;
-  FFTConfigCache& operator=(const FFTConfigCache& other) = delete;
-
-  FFTConfigCache(FFTConfigCache&& other) noexcept
-      : _usage_list(std::move(other._usage_list)),
-        _cache_map(std::move(other._cache_map)),
-        _max_size(other._max_size) {}
-
-  FFTConfigCache& operator=(FFTConfigCache&& other) noexcept {
-    _usage_list = std::move(other._usage_list);
-    _cache_map = std::move(other._cache_map);
-    _max_size = other._max_size;
-    return *this;
-  }
-
-  // If key is in this cache, return the cached config. Otherwise, emplace the
-  // config in this cache and return it.
-  FFTConfig& lookup(FFTConfigKey params) {
-    PADDLE_ENFORCE_GT(_max_size, 0,
-                      platform::errors::InvalidArgument(
-                          "The max size of FFTConfigCache must be great than 0,"
-                          "But received is [%d]",
-                          _max_size));
-
-    map_kkv_iter_t map_it = _cache_map.find(params);
-    // Hit, put to list front
-    if (map_it != _cache_map.end()) {
-      _usage_list.splice(_usage_list.begin(), _usage_list, map_it->second);
-      return map_it->second->second;
-    }
-
-    // Miss
-    // remove if needed
-    if (_usage_list.size() >= _max_size) {
-      auto last = _usage_list.end();
-      last--;
-      _cache_map.erase(last->first);
-      _usage_list.pop_back();
-    }
-
-    // construct new plan at list front, then insert into _cache_map
-    _usage_list.emplace_front(std::piecewise_construct,
-                              std::forward_as_tuple(params),
-                              std::forward_as_tuple(params));
-    auto kv_it = _usage_list.begin();
-    _cache_map.emplace(std::piecewise_construct,
-                       std::forward_as_tuple(kv_it->first),
-                       std::forward_as_tuple(kv_it));
-    return kv_it->second;
-  }
-
-  void clear() {
-    _cache_map.clear();
-    _usage_list.clear();
-  }
-
-  void resize(int64_t new_size) {
-    _set_max_size(new_size);
-    auto cur_size = _usage_list.size();
-    if (cur_size > _max_size) {
-      auto delete_it = _usage_list.end();
-      for (size_t i = 0; i < cur_size - _max_size; i++) {
-        delete_it--;
-        _cache_map.erase(delete_it->first);
-      }
-      _usage_list.erase(delete_it, _usage_list.end());
-    }
-  }
-
-  size_t size() const { return _cache_map.size(); }
-
-  size_t max_size() const noexcept { return _max_size; }
-
-  std::mutex mutex;
-
- private:
-  // Only sets size and does value check. Does not resize the data structures.
-  void _set_max_size(int64_t new_size) {
-    // We check that 0 <= new_size <= CUFFT_MAX_PLAN_NUM here. Since
-    // CUFFT_MAX_PLAN_NUM is of type size_t, we need to do non-negativity check
-    // first.
-    PADDLE_ENFORCE_GE(
-        new_size, 0,
-        platform::errors::InvalidArgument(
-            "cuFFT plan cache size must be non-negative, But received is [%d]",
-            new_size));
-    PADDLE_ENFORCE_LE(new_size, CUFFT_MAX_PLAN_NUM,
-                      platform::errors::InvalidArgument(
-                          "cuFFT plan cache size can not be larger than [%d], "
-                          "But received is [%d]",
-                          CUFFT_MAX_PLAN_NUM, new_size));
-    _max_size = static_cast<size_t>(new_size);
-  }
-
-  std::list<kv_t> _usage_list;
-  map_t _cache_map;
-  size_t _max_size;
-};
-
-static std::vector<std::unique_ptr<FFTConfigCache>> plan_caches;
-static std::mutex plan_caches_mutex;
-
-static inline FFTConfigCache& get_fft_plan_cache(int64_t device_index) {
-  std::lock_guard<std::mutex> guard(plan_caches_mutex);
-
-  if (device_index >= plan_caches.size()) {
-    plan_caches.resize(device_index + 1);
-  }
-
-  if (!plan_caches[device_index]) {
-    plan_caches[device_index] = std::make_unique<FFTConfigCache>();
-  }
-
-  return *plan_caches[device_index];
-}
+using Tensor = framework::Tensor;
+
+// FFT Functors
+#if defined(PADDLE_WITH_ONEMKL)
+
+#define MKL_DFTI_CHECK(expr)                                       \
+  do {                                                             \
+    MKL_LONG status = (expr);                                      \
+    if (!phi::dynload::DftiErrorClass(status, DFTI_NO_ERROR))      \
+      PADDLE_THROW(platform::errors::External(                     \
+          phi::dynload::DftiErrorMessage(status)));                \
+  } while (0);
+
+struct DftiDescriptorDeleter {
+  void operator()(DFTI_DESCRIPTOR_HANDLE handle) {
+    if (handle != nullptr) {
+      MKL_DFTI_CHECK(phi::dynload::DftiFreeDescriptor(&handle));
+    }
+  }
+};
+
+// A RAII wrapper for MKL_DESCRIPTOR*
+class DftiDescriptor {
+ public:
+  void init(DFTI_CONFIG_VALUE precision, DFTI_CONFIG_VALUE signal_type,
+            MKL_LONG signal_ndim, MKL_LONG* sizes) {
+    PADDLE_ENFORCE_EQ(desc_.get(), nullptr,
+                      platform::errors::AlreadyExists(
+                          "DftiDescriptor has already been initialized."));
+
+    DFTI_DESCRIPTOR* raw_desc;
+    MKL_DFTI_CHECK(phi::dynload::DftiCreateDescriptorX(
+        &raw_desc, precision, signal_type, signal_ndim, sizes));
+    desc_.reset(raw_desc);
+  }
+
+  DFTI_DESCRIPTOR* get() const {
+    DFTI_DESCRIPTOR* raw_desc = desc_.get();
+    PADDLE_ENFORCE_NOT_NULL(raw_desc,
+                            platform::errors::PreconditionNotMet(
+                                "DFTI DESCRIPTOR has not been initialized."));
+    return raw_desc;
+  }
+
+ private:
+  std::unique_ptr<DFTI_DESCRIPTOR, DftiDescriptorDeleter> desc_;
+};
+
+static DftiDescriptor _plan_mkl_fft(
+    const framework::proto::VarType::Type& in_dtype,
+    const framework::proto::VarType::Type& out_dtype,
+    const framework::DDim& in_strides, const framework::DDim& out_strides,
+    const std::vector<int>& signal_sizes, FFTNormMode normalization,
+    bool forward) {
+  const DFTI_CONFIG_VALUE precision = [&] {
+    switch (in_dtype) {
+      case framework::proto::VarType::FP32:
+        return DFTI_SINGLE;
+      case framework::proto::VarType::COMPLEX64:
+        return DFTI_SINGLE;
+      case framework::proto::VarType::FP64:
+        return DFTI_DOUBLE;
+      case framework::proto::VarType::COMPLEX128:
+        return DFTI_DOUBLE;
+      default:
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "Invalid input datatype (%s), input data type should be FP32, "
+            "FP64, COMPLEX64 or COMPLEX128.",
+            framework::DataTypeToString(in_dtype)));
+    }
+  }();
+
+  // C2C, R2C, C2R
+  const FFTTransformType fft_type = GetFFTTransformType(in_dtype, out_dtype);
+  const DFTI_CONFIG_VALUE domain =
+      (fft_type == FFTTransformType::C2C) ? DFTI_COMPLEX : DFTI_REAL;
+
+  DftiDescriptor descriptor;
+  std::vector<MKL_LONG> fft_sizes(signal_sizes.cbegin(), signal_sizes.cend());
+  const MKL_LONG signal_ndim = fft_sizes.size() - 1;
+  descriptor.init(precision, domain, signal_ndim, fft_sizes.data() + 1);
+
+  // placement inplace or not inplace
+  MKL_DFTI_CHECK(phi::dynload::DftiSetValue(descriptor.get(), DFTI_PLACEMENT,
+                                            DFTI_NOT_INPLACE));
+
+  // number of transformations
+  const MKL_LONG batch_size = fft_sizes[0];
+  MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
+      descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size));
+
+  // input & output distance
+  const MKL_LONG idist = in_strides[0];
+  const MKL_LONG odist = out_strides[0];
+  MKL_DFTI_CHECK(
+      phi::dynload::DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist));
+  MKL_DFTI_CHECK(phi::dynload::DftiSetValue(descriptor.get(),
+                                            DFTI_OUTPUT_DISTANCE, odist));
+
+  // input & output stride
+  std::vector<MKL_LONG> mkl_in_stride(1 + signal_ndim, 0);
+  std::vector<MKL_LONG> mkl_out_stride(1 + signal_ndim, 0);
+  for (MKL_LONG i = 1; i <= signal_ndim; i++) {
+    mkl_in_stride[i] = in_strides[i];
+    mkl_out_stride[i] = out_strides[i];
+  }
+  MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
+      descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride.data()));
+  MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
+      descriptor.get(), DFTI_OUTPUT_STRIDES, mkl_out_stride.data()));
+
+  // conjugate even storage
+  if (!(fft_type == FFTTransformType::C2C)) {
+    MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
+        descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX));
+  }
+
+  MKL_LONG signal_numel =
+      std::accumulate(fft_sizes.cbegin() + 1, fft_sizes.cend(), 1UL,
+                      std::multiplies<MKL_LONG>());
+  if (normalization != FFTNormMode::none) {
+    const double scale =
+        ((normalization == FFTNormMode::by_sqrt_n)
+             ? 1.0 / std::sqrt(static_cast<double>(signal_numel))
+             : 1.0 / static_cast<double>(signal_numel));
+    const auto scale_direction = [&]() {
+      if (fft_type == FFTTransformType::R2C ||
+          (fft_type == FFTTransformType::C2C && forward)) {
+        return DFTI_FORWARD_SCALE;
+      } else {
+        // (fft_type == FFTTransformType::C2R ||
+        //  (fft_type == FFTTransformType::C2C && !forward))
+        return DFTI_BACKWARD_SCALE;
+      }
+    }();
+    MKL_DFTI_CHECK(
+        phi::dynload::DftiSetValue(descriptor.get(), scale_direction, scale));
+  }
+
+  // commit the descriptor
+  MKL_DFTI_CHECK(phi::dynload::DftiCommitDescriptor(descriptor.get()));
+  return descriptor;
+}
+
+// Execute a general fft operation (can be c2c, onesided r2c or onesided c2r)
+template <typename DeviceContext, typename Ti, typename To>
+void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
+              const std::vector<int64_t>& axes, FFTNormMode normalization,
+              bool forward) {
+  const framework::DDim& in_sizes = x->dims();
+  const int ndim = in_sizes.size();
+  const int signal_ndim = axes.size();
+  const int batch_ndim = ndim - signal_ndim;
+  const framework::DDim& out_sizes = out->dims();
+
+  // make a dim permutation
+  std::vector<int> dim_permute(ndim);
+  std::iota(dim_permute.begin(), dim_permute.end(), 0);
+  std::vector<bool> is_transformed_dim(ndim, false);
+  for (const auto& d : axes) {
+    is_transformed_dim[d] = true;
+  }
+  const auto batch_end =
+      std::partition(dim_permute.begin(), dim_permute.end(),
+                     [&](size_t axis) { return !is_transformed_dim[axis]; });
+  std::copy(axes.cbegin(), axes.cend(), batch_end);
+
+  // transpose input according to that permutation
+  framework::DDim transposed_input_shape = in_sizes.transpose(dim_permute);
+  std::vector<int64_t> transposed_input_shape_ =
+      phi::vectorize(transposed_input_shape);
+  framework::Tensor transposed_input;
+  transposed_input.Resize(transposed_input_shape);
+  const auto place = ctx.GetPlace();
+  transposed_input.mutable_data<Ti>(place);
+  TransCompute<platform::CPUDeviceContext, Ti>(ndim, ctx, *x, &transposed_input,
+                                               dim_permute);
+
+  // make an collapsed input: collapse batch axes for input
+  const int batch_size = std::accumulate(
+      transposed_input_shape.Get(), transposed_input_shape.Get() + batch_ndim,
+      1L, std::multiplies<int64_t>());
+  std::vector<int> collapsed_input_shape_(1 + signal_ndim);
+  collapsed_input_shape_[0] = batch_size;
+  std::copy(transposed_input_shape_.begin() + batch_ndim,
+            transposed_input_shape_.end(), collapsed_input_shape_.begin() + 1);
+  const framework::DDim collapsed_input_shape =
+      phi::make_ddim(collapsed_input_shape_);
+  transposed_input.Resize(collapsed_input_shape);
+  framework::Tensor& collapsed_input = transposed_input;
+
+  // make a collapsed output
+  std::vector<int> collapsed_output_shape_(1 + signal_ndim);
+  collapsed_output_shape_[0] = batch_size;
+  for (int i = 0; i < signal_ndim; i++) {
+    collapsed_output_shape_[1 + i] = out_sizes[axes[i]];
+  }
+  const framework::DDim collapsed_output_shape =
+      phi::make_ddim(collapsed_output_shape_);
+  framework::Tensor collapsed_output;
+  collapsed_output.Resize(collapsed_output_shape);
+  collapsed_output.mutable_data(place, out->type());
+
+  // signal sizes
+  std::vector<int> signal_sizes(1 + signal_ndim);
+  signal_sizes[0] = batch_size;
+  for (int i = 0; i < signal_ndim; i++) {
+    signal_sizes[1 + i] =
+        std::max(collapsed_input_shape[1 + i], collapsed_output_shape[1 + i]);
+  }
+
+  // input & output stride
+  const framework::DDim input_stride = phi::stride(collapsed_input_shape);
+  const framework::DDim output_stride = phi::stride(collapsed_output_shape);
+
+  // make a DFTI_DESCRIPTOR
+  DftiDescriptor desc =
+      _plan_mkl_fft(framework::TransToProtoVarType(x->dtype()),
+                    framework::TransToProtoVarType(out->dtype()), input_stride,
+                    output_stride, signal_sizes, normalization, forward);
+
+  const FFTTransformType fft_type =
+      GetFFTTransformType(framework::TransToProtoVarType(x->dtype()),
+                          framework::TransToProtoVarType(out->type()));
+  if (fft_type == FFTTransformType::C2R && forward) {
+    framework::Tensor collapsed_input_conj(collapsed_input.dtype());
+    collapsed_input_conj.mutable_data<Ti>(collapsed_input.dims(),
+                                          ctx.GetPlace());
+    // conjugate the input
+    platform::ForRange<DeviceContext> for_range(ctx, collapsed_input.numel());
+    phi::funcs::ConjFunctor<Ti> functor(collapsed_input.data<Ti>(),
+                                        collapsed_input.numel(),
+                                        collapsed_input_conj.data<Ti>());
+    for_range(functor);
+    MKL_DFTI_CHECK(phi::dynload::DftiComputeBackward(
+        desc.get(), collapsed_input_conj.data(), collapsed_output.data()));
+  } else if (fft_type == FFTTransformType::R2C && !forward) {
+    framework::Tensor collapsed_output_conj(collapsed_output.dtype());
+    collapsed_output_conj.mutable_data<To>(collapsed_output.dims(),
+                                           ctx.GetPlace());
+    MKL_DFTI_CHECK(phi::dynload::DftiComputeForward(
+        desc.get(), collapsed_input.data(), collapsed_output_conj.data()));
+    // conjugate the output
+    platform::ForRange<DeviceContext> for_range(ctx, collapsed_output.numel());
+    phi::funcs::ConjFunctor<To> functor(collapsed_output_conj.data<To>(),
+                                        collapsed_output.numel(),
+                                        collapsed_output.data<To>());
+    for_range(functor);
+  } else {
+    if (forward) {
+      MKL_DFTI_CHECK(phi::dynload::DftiComputeForward(
+          desc.get(), collapsed_input.data(), collapsed_output.data()));
+    } else {
+      MKL_DFTI_CHECK(phi::dynload::DftiComputeBackward(
+          desc.get(), collapsed_input.data(), collapsed_output.data()));
+    }
+  }
+
+  // resize for the collapsed output
+  framework::DDim transposed_output_shape = out_sizes.transpose(dim_permute);
+  collapsed_output.Resize(transposed_output_shape);
+  framework::Tensor& transposed_output = collapsed_output;
+
+  // reverse the transposition
+  std::vector<int> reverse_dim_permute(ndim);
+  for (int i = 0; i < ndim; i++) {
+    reverse_dim_permute[dim_permute[i]] = i;
+  }
+  TransCompute<platform::CPUDeviceContext, To>(ndim, ctx, transposed_output,
+                                               out, reverse_dim_permute);
+}
+
+template <typename Ti, typename To>
+struct FFTC2CFunctor<platform::CPUDeviceContext, Ti, To> {
+  void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
+                  Tensor* out, const std::vector<int64_t>& axes,
+                  FFTNormMode normalization, bool forward) {
+    exec_fft<platform::CPUDeviceContext, Ti, To>(ctx, x, out, axes,
+                                                 normalization, forward);
+  }
+};
+
+template <typename Ti, typename To>
+struct FFTR2CFunctor<platform::CPUDeviceContext, Ti, To> {
+  void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
+                  Tensor* out, const std::vector<int64_t>& axes,
+                  FFTNormMode normalization, bool forward) {
+    exec_fft<platform::CPUDeviceContext, Ti, To>(ctx, x, out, axes,
+                                                 normalization, forward);
+  }
+};
+
+template <typename Ti, typename To>
+struct FFTC2RFunctor<platform::CPUDeviceContext, Ti, To> {
+  void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
+                  Tensor* out, const std::vector<int64_t>& axes,
+                  FFTNormMode normalization, bool forward) {
+    if (axes.size() > 1) {
+      const std::vector<int64_t> c2c_dims(axes.begin(), axes.end() - 1);
+      Tensor temp;
+      temp.mutable_data<Ti>(x->dims(), ctx.GetPlace());
+
+      FFTC2CFunctor<platform::CPUDeviceContext, Ti, Ti> c2c_functor;
+      c2c_functor(ctx, x, &temp, c2c_dims, normalization, forward);
+
+      const std::vector<int64_t> new_axes{axes.back()};
+      exec_fft<platform::CPUDeviceContext, Ti, To>(ctx, &temp, out, new_axes,
+                                                   normalization, forward);
+    } else {
+      exec_fft<platform::CPUDeviceContext, Ti, To>(ctx, x, out, axes,
+                                                   normalization, forward);
+    }
+  }
+};
+
+#elif defined(PADDLE_WITH_POCKETFFT)
+
+template <typename T>
+T compute_factor(int64_t size, FFTNormMode normalization) {
+  constexpr auto one = static_cast<T>(1);
+  switch (normalization) {
+    case FFTNormMode::none:
+      return one;
+    case FFTNormMode::by_n:
+      return one / static_cast<T>(size);
+    case FFTNormMode::by_sqrt_n:
+      return one / std::sqrt(static_cast<T>(size));
+  }
+  PADDLE_THROW(
+      platform::errors::InvalidArgument("Unsupported normalization type"));
+}
+
+template <typename Ti, typename To>
+struct FFTC2CFunctor<platform::CPUDeviceContext, Ti, To> {
+  void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
+                  Tensor* out, const std::vector<int64_t>& axes,
+                  FFTNormMode normalization, bool forward) {
+    using R = typename Ti::value_type;
+    using C = std::complex<R>;
+
+    const auto& input_dim = x->dims();
+    const std::vector<size_t> in_sizes = phi::vectorize<size_t>(input_dim);
+    std::vector<std::ptrdiff_t> in_strides =
+        phi::vectorize<std::ptrdiff_t>(phi::stride(input_dim));
+    const int64_t data_size = sizeof(C);
+    std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(),
+                   [&](std::ptrdiff_t s) { return s * data_size; });
+
+    const auto* in_data = reinterpret_cast<const C*>(x->data<Ti>());
+    auto* out_data = reinterpret_cast<C*>(out->data<To>());
+    // pocketfft requires std::vector<size_t>
+    std::vector<size_t> axes_(axes.size());
+    std::copy(axes.begin(), axes.end(), axes_.begin());
+    // compuet factor
+    int64_t signal_numel = 1;
+    for (auto i : axes) {
+      signal_numel *= in_sizes[i];
+    }
+    R factor = compute_factor<R>(signal_numel, normalization);
+    pocketfft::c2c(in_sizes, in_strides, in_strides, axes_, forward, in_data,
+                   out_data, factor);
+  }
+};
+
+template <typename Ti, typename To>
+struct FFTR2CFunctor<platform::CPUDeviceContext, Ti, To> {
+  void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
+                  Tensor* out, const std::vector<int64_t>& axes,
+                  FFTNormMode normalization, bool forward) {
+    using R = Ti;
+    using C = std::complex<R>;
+
+    const auto& input_dim = x->dims();
+    const std::vector<size_t> in_sizes = phi::vectorize<size_t>(input_dim);
+    std::vector<std::ptrdiff_t> in_strides =
+        phi::vectorize<std::ptrdiff_t>(phi::stride(input_dim));
+    {
+      const int64_t data_size = sizeof(R);
+      std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(),
+                     [&](std::ptrdiff_t s) { return s * data_size; });
+    }
+
+    const auto& output_dim = out->dims();
+    const std::vector<size_t> out_sizes = phi::vectorize<size_t>(output_dim);
+    std::vector<std::ptrdiff_t> out_strides =
+        phi::vectorize<std::ptrdiff_t>(phi::stride(output_dim));
+    {
+      const int64_t data_size = sizeof(C);
+      std::transform(out_strides.begin(), out_strides.end(),
+                     out_strides.begin(),
+                     [&](std::ptrdiff_t s) { return s * data_size; });
+    }
+
+    const auto* in_data = x->data<R>();
+    auto* out_data = reinterpret_cast<C*>(out->data<To>());
+    // pocketfft requires std::vector<size_t>
+    std::vector<size_t> axes_(axes.size());
+    std::copy(axes.begin(), axes.end(), axes_.begin());
+    // compuet normalization factor
+    int64_t signal_numel = 1;
+    for (auto i : axes) {
+      signal_numel *= in_sizes[i];
+    }
+    R factor = compute_factor<R>(signal_numel, normalization);
+    pocketfft::r2c(in_sizes, in_strides, out_strides, axes_, forward, in_data,
+                   out_data, factor);
+  }
+};
+
+template <typename Ti, typename To>
+struct FFTC2RFunctor<platform::CPUDeviceContext, Ti, To> {
+  void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
+                  Tensor* out, const std::vector<int64_t>& axes,
+                  FFTNormMode normalization, bool forward) {
+    using R = To;
+    using C = std::complex<R>;
+
+    const auto& input_dim = x->dims();
+    const std::vector<size_t> in_sizes = phi::vectorize<size_t>(input_dim);
+    std::vector<std::ptrdiff_t> in_strides =
+        phi::vectorize<std::ptrdiff_t>(phi::stride(input_dim));
+    {
+      const int64_t data_size = sizeof(C);
+      std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(),
+                     [&](std::ptrdiff_t s) { return s * data_size; });
+    }
+
+    const auto& output_dim = out->dims();
+    const std::vector<size_t> out_sizes = phi::vectorize<size_t>(output_dim);
+    std::vector<std::ptrdiff_t> out_strides =
+        phi::vectorize<std::ptrdiff_t>(phi::stride(output_dim));
+    {
+      const int64_t data_size = sizeof(R);
+      std::transform(out_strides.begin(), out_strides.end(),
+                     out_strides.begin(),
+                     [&](std::ptrdiff_t s) { return s * data_size; });
+    }
+
+    const auto* in_data = reinterpret_cast<const C*>(x->data<Ti>());
+    auto* out_data = out->data<R>();
+    // pocketfft requires std::vector<size_t>
+    std::vector<size_t> axes_(axes.size());
+    std::copy(axes.begin(), axes.end(), axes_.begin());
+    // compuet normalization factor
+    int64_t signal_numel = 1;
+    for (auto i : axes) {
+      signal_numel *= out_sizes[i];
+    }
+    R factor = compute_factor<R>(signal_numel, normalization);
+    pocketfft::c2r(out_sizes, in_strides, out_strides, axes_, forward, in_data,
+                   out_data, factor);
+  }
+};
+
+#endif
 
 }  // namespace operators
 }  // namespace paddle
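Both backends apply one of the same three normalization constants, 1, 1/sqrt(n), or 1/n, selected by FFTNormMode; a self-contained sketch of that arithmetic, mirroring compute_factor above (the norm_factor name is ours, not Paddle's):

#include <cmath>
#include <cstdint>
#include <cstdio>

enum class FFTNormMode { none, by_sqrt_n, by_n };

// Mirrors compute_factor above: the constant applied to transform outputs.
template <typename T>
T norm_factor(std::int64_t size, FFTNormMode normalization) {
  switch (normalization) {
    case FFTNormMode::none:
      return static_cast<T>(1);
    case FFTNormMode::by_n:
      return static_cast<T>(1) / static_cast<T>(size);
    case FFTNormMode::by_sqrt_n:
      return static_cast<T>(1) / std::sqrt(static_cast<T>(size));
  }
  return static_cast<T>(1);  // unreachable
}

int main() {
  // For an n = 1024 point transform:
  std::printf("none:      %g\n", norm_factor<double>(1024, FFTNormMode::none));       // 1
  std::printf("by_sqrt_n: %g\n", norm_factor<double>(1024, FFTNormMode::by_sqrt_n));  // 0.03125
  std::printf("by_n:      %g\n", norm_factor<double>(1024, FFTNormMode::by_n));       // 0.000976562
  return 0;
}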
paddle/fluid/operators/spectral_op.cc
@@ -13,28 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/operators/spectral_op.h"
+#include "paddle/fluid/operators/spectral_helper.h"
 
-#include <algorithm>
-#include <functional>
-#include <memory>
-#include <numeric>
-#include <string>
-#include <vector>
-
-#include "paddle/fluid/framework/data_type.h"
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/operators/transpose_op.h"
-#include "paddle/fluid/platform/complex.h"
-#include "paddle/phi/kernels/funcs/complex_functors.h"
-
-#if defined(PADDLE_WITH_ONEMKL)
-#include "paddle/phi/backends/dynload/mklrt.h"
-#elif defined(PADDLE_WITH_POCKETFFT)
-#include "extern_pocketfft/pocketfft_hdronly.h"
-#endif
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/platform/for_range.h"
-
 namespace paddle {
 namespace operators {
@@ -355,465 +334,6 @@ FFTNormMode get_norm_from_string(const std::string& norm, bool forward) {
                      norm));
 }
 
// FFT Functors
#if defined(PADDLE_WITH_ONEMKL)
#define MKL_DFTI_CHECK(expr) \
do { \
MKL_LONG status = (expr); \
if (!phi::dynload::DftiErrorClass(status, DFTI_NO_ERROR)) \
PADDLE_THROW( \
platform::errors::External(phi::dynload::DftiErrorMessage(status))); \
} while (0);
namespace {
struct DftiDescriptorDeleter {
void operator()(DFTI_DESCRIPTOR_HANDLE handle) {
if (handle != nullptr) {
MKL_DFTI_CHECK(phi::dynload::DftiFreeDescriptor(&handle));
}
}
};
// A RAII wrapper for MKL_DESCRIPTOR*
class DftiDescriptor {
public:
void init(DFTI_CONFIG_VALUE precision, DFTI_CONFIG_VALUE signal_type,
MKL_LONG signal_ndim, MKL_LONG* sizes) {
PADDLE_ENFORCE_EQ(desc_.get(), nullptr,
platform::errors::AlreadyExists(
"DftiDescriptor has already been initialized."));
DFTI_DESCRIPTOR* raw_desc;
MKL_DFTI_CHECK(phi::dynload::DftiCreateDescriptorX(
&raw_desc, precision, signal_type, signal_ndim, sizes));
desc_.reset(raw_desc);
}
DFTI_DESCRIPTOR* get() const {
DFTI_DESCRIPTOR* raw_desc = desc_.get();
PADDLE_ENFORCE_NOT_NULL(raw_desc,
platform::errors::PreconditionNotMet(
"DFTI DESCRIPTOR has not been initialized."));
return raw_desc;
}
private:
std::unique_ptr<DFTI_DESCRIPTOR, DftiDescriptorDeleter> desc_;
};
DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
const framework::proto::VarType::Type& out_dtype,
const framework::DDim& in_strides,
const framework::DDim& out_strides,
const std::vector<int>& signal_sizes,
FFTNormMode normalization, bool forward) {
const DFTI_CONFIG_VALUE precision = [&] {
switch (in_dtype) {
case framework::proto::VarType::FP32:
return DFTI_SINGLE;
case framework::proto::VarType::COMPLEX64:
return DFTI_SINGLE;
case framework::proto::VarType::FP64:
return DFTI_DOUBLE;
case framework::proto::VarType::COMPLEX128:
return DFTI_DOUBLE;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Invalid input datatype (%s), input data type should be FP32, "
"FP64, COMPLEX64 or COMPLEX128.",
framework::DataTypeToString(in_dtype)));
}
}();
// C2C, R2C, C2R
const FFTTransformType fft_type = GetFFTTransformType(in_dtype, out_dtype);
const DFTI_CONFIG_VALUE domain =
(fft_type == FFTTransformType::C2C) ? DFTI_COMPLEX : DFTI_REAL;
DftiDescriptor descriptor;
std::vector<MKL_LONG> fft_sizes(signal_sizes.cbegin(), signal_sizes.cend());
const MKL_LONG signal_ndim = fft_sizes.size() - 1;
descriptor.init(precision, domain, signal_ndim, fft_sizes.data() + 1);
// placement inplace or not inplace
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(descriptor.get(), DFTI_PLACEMENT,
DFTI_NOT_INPLACE));
// number of transformations
const MKL_LONG batch_size = fft_sizes[0];
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size));
// input & output distance
const MKL_LONG idist = in_strides[0];
const MKL_LONG odist = out_strides[0];
MKL_DFTI_CHECK(
phi::dynload::DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist));
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(descriptor.get(),
DFTI_OUTPUT_DISTANCE, odist));
// input & output stride
std::vector<MKL_LONG> mkl_in_stride(1 + signal_ndim, 0);
std::vector<MKL_LONG> mkl_out_stride(1 + signal_ndim, 0);
for (MKL_LONG i = 1; i <= signal_ndim; i++) {
mkl_in_stride[i] = in_strides[i];
mkl_out_stride[i] = out_strides[i];
}
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride.data()));
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
descriptor.get(), DFTI_OUTPUT_STRIDES, mkl_out_stride.data()));
// conjugate even storage
if (!(fft_type == FFTTransformType::C2C)) {
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX));
}
MKL_LONG signal_numel =
std::accumulate(fft_sizes.cbegin() + 1, fft_sizes.cend(), 1UL,
std::multiplies<MKL_LONG>());
if (normalization != FFTNormMode::none) {
const double scale =
((normalization == FFTNormMode::by_sqrt_n)
? 1.0 / std::sqrt(static_cast<double>(signal_numel))
: 1.0 / static_cast<double>(signal_numel));
const auto scale_direction = [&]() {
if (fft_type == FFTTransformType::R2C ||
(fft_type == FFTTransformType::C2C && forward)) {
return DFTI_FORWARD_SCALE;
} else {
// (fft_type == FFTTransformType::C2R ||
// (fft_type == FFTTransformType::C2C && !forward))
return DFTI_BACKWARD_SCALE;
}
}();
MKL_DFTI_CHECK(
phi::dynload::DftiSetValue(descriptor.get(), scale_direction, scale));
}
// commit the descriptor
MKL_DFTI_CHECK(phi::dynload::DftiCommitDescriptor(descriptor.get()));
return descriptor;
}
// Execute a general fft operation (can be c2c, onesided r2c or onesided c2r)
template <typename DeviceContext, typename Ti, typename To>
void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
const std::vector<int64_t>& axes, FFTNormMode normalization,
bool forward) {
const framework::DDim& in_sizes = x->dims();
const int ndim = in_sizes.size();
const int signal_ndim = axes.size();
const int batch_ndim = ndim - signal_ndim;
const framework::DDim& out_sizes = out->dims();
// make a dim permutation
std::vector<int> dim_permute(ndim);
std::iota(dim_permute.begin(), dim_permute.end(), 0);
std::vector<bool> is_transformed_dim(ndim, false);
for (const auto& d : axes) {
is_transformed_dim[d] = true;
}
const auto batch_end =
std::partition(dim_permute.begin(), dim_permute.end(),
[&](size_t axis) { return !is_transformed_dim[axis]; });
std::copy(axes.cbegin(), axes.cend(), batch_end);
// transpose input according to that permutation
framework::DDim transposed_input_shape = in_sizes.transpose(dim_permute);
std::vector<int64_t> transposed_input_shape_ =
phi::vectorize(transposed_input_shape);
framework::Tensor transposed_input;
transposed_input.Resize(transposed_input_shape);
const auto place = ctx.GetPlace();
transposed_input.mutable_data<Ti>(place);
TransCompute<platform::CPUDeviceContext, Ti>(ndim, ctx, *x, &transposed_input,
dim_permute);
// make an collapsed input: collapse batch axes for input
const int batch_size = std::accumulate(
transposed_input_shape.Get(), transposed_input_shape.Get() + batch_ndim,
1L, std::multiplies<int64_t>());
std::vector<int> collapsed_input_shape_(1 + signal_ndim);
collapsed_input_shape_[0] = batch_size;
std::copy(transposed_input_shape_.begin() + batch_ndim,
transposed_input_shape_.end(), collapsed_input_shape_.begin() + 1);
const framework::DDim collapsed_input_shape =
phi::make_ddim(collapsed_input_shape_);
transposed_input.Resize(collapsed_input_shape);
framework::Tensor& collapsed_input = transposed_input;
// make a collapsed output
std::vector<int> collapsed_output_shape_(1 + signal_ndim);
collapsed_output_shape_[0] = batch_size;
for (int i = 0; i < signal_ndim; i++) {
collapsed_output_shape_[1 + i] = out_sizes[axes[i]];
}
const framework::DDim collapsed_output_shape =
phi::make_ddim(collapsed_output_shape_);
framework::Tensor collapsed_output;
collapsed_output.Resize(collapsed_output_shape);
collapsed_output.mutable_data(place, out->type());
// signal sizes
std::vector<int> signal_sizes(1 + signal_ndim);
signal_sizes[0] = batch_size;
for (int i = 0; i < signal_ndim; i++) {
signal_sizes[1 + i] =
std::max(collapsed_input_shape[1 + i], collapsed_output_shape[1 + i]);
}
// input & output stride
const framework::DDim input_stride = phi::stride(collapsed_input_shape);
const framework::DDim output_stride = phi::stride(collapsed_output_shape);
// make a DFTI_DESCRIPTOR
DftiDescriptor desc =
_plan_mkl_fft(framework::TransToProtoVarType(x->dtype()),
framework::TransToProtoVarType(out->dtype()), input_stride,
output_stride, signal_sizes, normalization, forward);
const FFTTransformType fft_type =
GetFFTTransformType(framework::TransToProtoVarType(x->dtype()),
framework::TransToProtoVarType(out->type()));
if (fft_type == FFTTransformType::C2R && forward) {
framework::Tensor collapsed_input_conj(collapsed_input.dtype());
collapsed_input_conj.mutable_data<Ti>(collapsed_input.dims(),
ctx.GetPlace());
// conjugate the input
platform::ForRange<DeviceContext> for_range(ctx, collapsed_input.numel());
phi::funcs::ConjFunctor<Ti> functor(collapsed_input.data<Ti>(),
collapsed_input.numel(),
collapsed_input_conj.data<Ti>());
for_range(functor);
MKL_DFTI_CHECK(phi::dynload::DftiComputeBackward(
desc.get(), collapsed_input_conj.data(), collapsed_output.data()));
} else if (fft_type == FFTTransformType::R2C && !forward) {
framework::Tensor collapsed_output_conj(collapsed_output.dtype());
collapsed_output_conj.mutable_data<To>(collapsed_output.dims(),
ctx.GetPlace());
MKL_DFTI_CHECK(phi::dynload::DftiComputeForward(
desc.get(), collapsed_input.data(), collapsed_output_conj.data()));
// conjugate the output
platform::ForRange<DeviceContext> for_range(ctx, collapsed_output.numel());
phi::funcs::ConjFunctor<To> functor(collapsed_output_conj.data<To>(),
collapsed_output.numel(),
collapsed_output.data<To>());
for_range(functor);
} else {
if (forward) {
MKL_DFTI_CHECK(phi::dynload::DftiComputeForward(
desc.get(), collapsed_input.data(), collapsed_output.data()));
} else {
MKL_DFTI_CHECK(phi::dynload::DftiComputeBackward(
desc.get(), collapsed_input.data(), collapsed_output.data()));
}
}
// resize for the collapsed output
framework::DDim transposed_output_shape = out_sizes.transpose(dim_permute);
collapsed_output.Resize(transposed_output_shape);
framework::Tensor& transposed_output = collapsed_output;
// reverse the transposition
std::vector<int> reverse_dim_permute(ndim);
for (int i = 0; i < ndim; i++) {
reverse_dim_permute[dim_permute[i]] = i;
}
TransCompute<platform::CPUDeviceContext, To>(ndim, ctx, transposed_output,
out, reverse_dim_permute);
}
} // anonymous namespace
template <typename Ti, typename To>
struct FFTC2CFunctor<platform::CPUDeviceContext, Ti, To> {
void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {
exec_fft<platform::CPUDeviceContext, Ti, To>(ctx, x, out, axes,
normalization, forward);
}
};
template <typename Ti, typename To>
struct FFTR2CFunctor<platform::CPUDeviceContext, Ti, To> {
void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {
exec_fft<platform::CPUDeviceContext, Ti, To>(ctx, x, out, axes,
normalization, forward);
}
};
template <typename Ti, typename To>
struct FFTC2RFunctor<platform::CPUDeviceContext, Ti, To> {
void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {
if (axes.size() > 1) {
const std::vector<int64_t> c2c_dims(axes.begin(), axes.end() - 1);
Tensor temp;
temp.mutable_data<Ti>(x->dims(), ctx.GetPlace());
FFTC2CFunctor<platform::CPUDeviceContext, Ti, Ti> c2c_functor;
c2c_functor(ctx, x, &temp, c2c_dims, normalization, forward);
const std::vector<int64_t> new_axes{axes.back()};
exec_fft<platform::CPUDeviceContext, Ti, To>(ctx, &temp, out, new_axes,
normalization, forward);
} else {
exec_fft<platform::CPUDeviceContext, Ti, To>(ctx, x, out, axes,
normalization, forward);
}
}
};
#elif defined(PADDLE_WITH_POCKETFFT)
namespace {
template <typename T>
T compute_factor(int64_t size, FFTNormMode normalization) {
constexpr auto one = static_cast<T>(1);
switch (normalization) {
case FFTNormMode::none:
return one;
case FFTNormMode::by_n:
return one / static_cast<T>(size);
case FFTNormMode::by_sqrt_n:
return one / std::sqrt(static_cast<T>(size));
}
PADDLE_THROW(
platform::errors::InvalidArgument("Unsupported normalization type"));
}
} // anonymous namespace
template <typename Ti, typename To>
struct FFTC2CFunctor<platform::CPUDeviceContext, Ti, To> {
void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {
using R = typename Ti::value_type;
using C = std::complex<R>;
const auto& input_dim = x->dims();
const std::vector<size_t> in_sizes = phi::vectorize<size_t>(input_dim);
std::vector<std::ptrdiff_t> in_strides =
phi::vectorize<std::ptrdiff_t>(phi::stride(input_dim));
const int64_t data_size = sizeof(C);
std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(),
[&](std::ptrdiff_t s) { return s * data_size; });
const auto* in_data = reinterpret_cast<const C*>(x->data<Ti>());
auto* out_data = reinterpret_cast<C*>(out->data<To>());
// pocketfft requires std::vector<size_t>
std::vector<size_t> axes_(axes.size());
std::copy(axes.begin(), axes.end(), axes_.begin());
// compuet factor
int64_t signal_numel = 1;
for (auto i : axes) {
signal_numel *= in_sizes[i];
}
R factor = compute_factor<R>(signal_numel, normalization);
pocketfft::c2c(in_sizes, in_strides, in_strides, axes_, forward, in_data,
out_data, factor);
}
};
template <typename Ti, typename To>
struct FFTR2CFunctor<platform::CPUDeviceContext, Ti, To> {
void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {
using R = Ti;
using C = std::complex<R>;
const auto& input_dim = x->dims();
const std::vector<size_t> in_sizes = phi::vectorize<size_t>(input_dim);
std::vector<std::ptrdiff_t> in_strides =
phi::vectorize<std::ptrdiff_t>(phi::stride(input_dim));
{
const int64_t data_size = sizeof(R);
std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(),
[&](std::ptrdiff_t s) { return s * data_size; });
}
const auto& output_dim = out->dims();
const std::vector<size_t> out_sizes = phi::vectorize<size_t>(output_dim);
std::vector<std::ptrdiff_t> out_strides =
phi::vectorize<std::ptrdiff_t>(phi::stride(output_dim));
{
const int64_t data_size = sizeof(C);
std::transform(out_strides.begin(), out_strides.end(),
out_strides.begin(),
[&](std::ptrdiff_t s) { return s * data_size; });
}
const auto* in_data = x->data<R>();
auto* out_data = reinterpret_cast<C*>(out->data<To>());
// pocketfft requires std::vector<size_t>
std::vector<size_t> axes_(axes.size());
std::copy(axes.begin(), axes.end(), axes_.begin());
// compuet normalization factor
int64_t signal_numel = 1;
for (auto i : axes) {
signal_numel *= in_sizes[i];
}
R factor = compute_factor<R>(signal_numel, normalization);
pocketfft::r2c(in_sizes, in_strides, out_strides, axes_, forward, in_data,
out_data, factor);
}
};
template <typename Ti, typename To>
struct FFTC2RFunctor<platform::CPUDeviceContext, Ti, To> {
void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {
using R = To;
using C = std::complex<R>;
const auto& input_dim = x->dims();
const std::vector<size_t> in_sizes = phi::vectorize<size_t>(input_dim);
std::vector<std::ptrdiff_t> in_strides =
phi::vectorize<std::ptrdiff_t>(phi::stride(input_dim));
{
const int64_t data_size = sizeof(C);
std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(),
[&](std::ptrdiff_t s) { return s * data_size; });
}
const auto& output_dim = out->dims();
const std::vector<size_t> out_sizes = phi::vectorize<size_t>(output_dim);
std::vector<std::ptrdiff_t> out_strides =
phi::vectorize<std::ptrdiff_t>(phi::stride(output_dim));
{
const int64_t data_size = sizeof(R);
std::transform(out_strides.begin(), out_strides.end(),
out_strides.begin(),
[&](std::ptrdiff_t s) { return s * data_size; });
}
const auto* in_data = reinterpret_cast<const C*>(x->data<Ti>());
auto* out_data = out->data<R>();
// pocketfft requires std::vector<size_t>
std::vector<size_t> axes_(axes.size());
std::copy(axes.begin(), axes.end(), axes_.begin());
// compuet normalization factor
int64_t signal_numel = 1;
for (auto i : axes) {
signal_numel *= out_sizes[i];
}
R factor = compute_factor<R>(signal_numel, normalization);
pocketfft::c2r(out_sizes, in_strides, out_strides, axes_, forward, in_data,
out_data, factor);
}
};
#endif
 }  // namespace operators
 }  // namespace paddle
...
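The onesided R2C and C2R paths moved out of this file rely on the Hermitian symmetry of a real signal's spectrum, X[k] = conj(X[n - k]); a naive-DFT demonstration of that property (illustration only, not library code):

#include <cmath>
#include <complex>
#include <cstdio>
#include <vector>

// Naive O(n^2) DFT, just enough to demonstrate the symmetry.
std::vector<std::complex<double>> naive_dft(const std::vector<double>& x) {
  const double pi = std::acos(-1.0);
  const std::size_t n = x.size();
  std::vector<std::complex<double>> X(n);
  for (std::size_t k = 0; k < n; ++k) {
    for (std::size_t t = 0; t < n; ++t) {
      const double angle = -2.0 * pi * static_cast<double>(k * t) / n;
      X[k] += x[t] * std::complex<double>(std::cos(angle), std::sin(angle));
    }
  }
  return X;
}

int main() {
  const std::vector<double> x = {0.5, -1.0, 2.0, 3.5, -0.25, 1.0};
  const auto X = naive_dft(x);
  const std::size_t n = x.size();
  for (std::size_t k = 1; k < n; ++k) {
    // Hermitian symmetry for real input: X[k] == conj(X[n - k]), which is
    // why a onesided R2C transform stores only n/2 + 1 outputs.
    std::printf("k=%zu  |X[k] - conj(X[n-k])| = %g\n", k,
                std::abs(X[k] - std::conj(X[n - k])));
  }
  return 0;
}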
paddle/fluid/operators/spectral_op.cu
@@ -8,496 +8,9 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <functional>
-#include <list>
-#include <memory>
-#include <mutex>
-#include <numeric>
-#include <sstream>
-#include <stdexcept>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "paddle/fluid/operators/conj_op.h"
-#include "paddle/fluid/operators/spectral_helper.h"
+#include "paddle/fluid/operators/spectral_op.cu.h"
 #include "paddle/fluid/operators/spectral_op.h"
-#include "paddle/fluid/operators/transpose_op.h"
-#include "paddle/fluid/platform/enforce.h"
-#include "paddle/phi/kernels/funcs/complex_functors.h"
namespace paddle {
namespace operators {
namespace {
// Calculates the normalization constant
double fft_normalization_scale(FFTNormMode normalization,
const std::vector<int64_t>& sizes,
const std::vector<int64_t>& dims) {
// auto norm = static_cast<fft_norm_mode>(normalization);
if (normalization == FFTNormMode::none) {
return static_cast<double>(1.0);
}
int64_t signal_numel = 1;
for (auto dim : dims) {
signal_numel *= sizes[dim];
}
const double scale_denom = (normalization == FFTNormMode::by_sqrt_n)
? std::sqrt(signal_numel)
: static_cast<double>(signal_numel);
return static_cast<double>(1.0 / scale_denom);
}
template <typename DeviceContext, typename T>
void exec_normalization(const DeviceContext& ctx, const Tensor* in, Tensor* out,
FFTNormMode normalization,
const std::vector<int64_t>& sizes,
const std::vector<int64_t>& axes) {
double scale = fft_normalization_scale(normalization, sizes, axes);
if (scale != 1.0) {
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
auto dev = ctx.eigen_device();
EigenScale<Eigen::GpuDevice, T>::Eval(*dev, eigen_out, eigen_in,
static_cast<T>(scale),
static_cast<T>(0), false);
} else {
framework::TensorCopy(*in, ctx.GetPlace(), out);
}
}
#if defined(PADDLE_WITH_CUDA)
FFTConfigKey create_fft_configkey(const framework::Tensor& input,
const framework::Tensor& output,
int signal_ndim) {
// Create the transform plan (either from cache or locally)
const auto value_type =
framework::IsComplexType(framework::TransToProtoVarType(input.dtype()))
? framework::ToRealType(framework::TransToProtoVarType(input.dtype()))
: framework::TransToProtoVarType(input.dtype());
auto fft_type =
GetFFTTransformType(framework::TransToProtoVarType(input.dtype()),
framework::TransToProtoVarType(output.dtype()));
// signal sizes
std::vector<int64_t> signal_size(signal_ndim + 1);
signal_size[0] = input.dims()[0];
for (int64_t i = 1; i <= signal_ndim; ++i) {
auto in_size = input.dims()[i];
auto out_size = output.dims()[i];
signal_size[i] = std::max(in_size, out_size);
}
FFTConfigKey key(phi::vectorize(input.dims()), phi::vectorize(output.dims()),
signal_size, fft_type, value_type);
return key;
}
// Execute a pre-planned transform
static void exec_cufft_plan_raw(const FFTConfig& config, void* in_data,
void* out_data, bool forward) {
auto& plan = config.plan();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cufftXtExec(
plan, in_data, out_data, forward ? CUFFT_FORWARD : CUFFT_INVERSE));
}
template <typename DeviceContext, typename Ti, typename To>
void exec_cufft_plan(const DeviceContext& ctx, const FFTConfig& config,
framework::Tensor* input, framework::Tensor* output,
bool forward) {
// execute transform plan
auto fft_type = config.transform_type();
if (fft_type == FFTTransformType::C2R && forward) {
forward = false;
framework::Tensor input_conj(input->type());
input_conj.mutable_data<Ti>(input->dims(), ctx.GetPlace());
platform::ForRange<DeviceContext> for_range(ctx, input->numel());
phi::funcs::ConjFunctor<Ti> functor(input->data<Ti>(), input->numel(),
input_conj.data<Ti>());
for_range(functor);
exec_cufft_plan_raw(config, input_conj.data(), output->data(), forward);
} else if (fft_type == FFTTransformType::R2C && !forward) {
forward = true;
framework::Tensor out_conj(output->type());
out_conj.mutable_data<To>(output->dims(), ctx.GetPlace());
exec_cufft_plan_raw(config, input->data(), out_conj.data(), forward);
platform::ForRange<DeviceContext> for_range(ctx, output->numel());
phi::funcs::ConjFunctor<To> functor(out_conj.data<To>(), output->numel(),
output->data<To>());
for_range(functor);
} else {
exec_cufft_plan_raw(config, input->data(), output->data(), forward);
}
}
#elif defined(PADDLE_WITH_HIP)
FFTConfigKey create_fft_configkey(const framework::Tensor& input,
const framework::Tensor& output,
int signal_ndim) {
// Create the transform plan (either from cache or locally)
const auto value_type =
framework::IsComplexType(framework::TransToProtoVarType(input.dtype()))
? framework::ToRealType(framework::TransToProtoVarType(input.dtype()))
: framework::TransToProtoVarType(input.dtype());
auto fft_type =
GetFFTTransformType(framework::TransToProtoVarType(input.dtype()),
framework::TransToProtoVarType(output.type()));
// signal sizes
std::vector<int64_t> signal_size(signal_ndim + 1);
signal_size[0] = input.dims()[0];
for (int64_t i = 1; i <= signal_ndim; ++i) {
auto in_size = input.dims()[i];
auto out_size = output.dims()[i];
signal_size[i] = std::max(in_size, out_size);
}
FFTConfigKey key(phi::vectorize(input.dims()), phi::vectorize(output.dims()),
signal_size, fft_type, value_type);
return key;
}
// Execute a pre-planned transform
static void exec_hipfft_plan_raw(const FFTConfig& config, void* in_data,
void* out_data, bool forward) {
auto& plan = config.plan();
auto value_type = config.data_type();
if (value_type == framework::proto::VarType::FP32) {
switch (config.transform_type()) {
case FFTTransformType::C2C: {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecC2C(
plan, static_cast<hipfftComplex*>(in_data),
static_cast<hipfftComplex*>(out_data),
forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD));
return;
}
case FFTTransformType::R2C: {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecR2C(
plan, static_cast<hipfftReal*>(in_data),
static_cast<hipfftComplex*>(out_data)));
return;
}
case FFTTransformType::C2R: {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecC2R(
plan, static_cast<hipfftComplex*>(in_data),
static_cast<hipfftReal*>(out_data)));
return;
}
}
} else if (value_type == framework::proto::VarType::FP64) {
switch (config.transform_type()) {
case FFTTransformType::C2C: {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecZ2Z(
plan, static_cast<hipfftDoubleComplex*>(in_data),
static_cast<hipfftDoubleComplex*>(out_data),
forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD));
return;
}
case FFTTransformType::R2C: {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecD2Z(
plan, static_cast<hipfftDoubleReal*>(in_data),
static_cast<hipfftDoubleComplex*>(out_data)));
return;
}
case FFTTransformType::C2R: {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecZ2D(
plan, static_cast<hipfftDoubleComplex*>(in_data),
static_cast<hipfftDoubleReal*>(out_data)));
return;
}
}
}
PADDLE_THROW(platform::errors::InvalidArgument(
"hipFFT only support transforms of type float32 and float64"));
}
template <typename DeviceContext, typename Ti, typename To>
void exec_hipfft_plan(const DeviceContext& ctx, const FFTConfig& config,
framework::Tensor* input, framework::Tensor* output,
bool forward) {
auto fft_type = config.transform_type();
if (fft_type == FFTTransformType::C2R && forward) {
forward = false;
framework::Tensor input_conj(input->type());
input_conj.mutable_data<Ti>(input->dims(), ctx.GetPlace());
platform::ForRange<DeviceContext> for_range(ctx, input->numel());
phi::funcs::ConjFunctor<Ti> functor(input->data<Ti>(), input->numel(),
input_conj.data<Ti>());
for_range(functor);
exec_hipfft_plan_raw(config, input_conj.data(), output->data(), forward);
} else if (fft_type == FFTTransformType::R2C && !forward) {
forward = true;
framework::Tensor out_conj(output->type());
out_conj.mutable_data<To>(output->dims(), ctx.GetPlace());
exec_hipfft_plan_raw(config, input->data(), out_conj.data(), forward);
platform::ForRange<DeviceContext> for_range(ctx, output->numel());
phi::funcs::ConjFunctor<To> functor(out_conj.data<To>(), output->numel(),
output->data<To>());
for_range(functor);
} else {
exec_hipfft_plan_raw(config, input->data(), output->data(), forward);
}
}
#endif
// Execute a general unnormalized fft operation (can be c2c, onesided r2c or
// onesided c2r)
template <typename DeviceContext, typename Ti, typename To>
void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out,
const std::vector<int64_t>& dim, bool forward) {
const auto x_dims = phi::vectorize(X->dims());
const int64_t ndim = static_cast<int64_t>(X->dims().size());
auto tensor_place = ctx.GetPlace();
// make a dim permutation
std::vector<int> dim_permute(ndim);
std::iota(dim_permute.begin(), dim_permute.end(), int{0});
std::vector<bool> is_transformed_dim(ndim);
for (const auto& d : dim) {
is_transformed_dim[d] = true;
}
auto batch_end =
std::partition(dim_permute.begin(), dim_permute.end(),
[&](int64_t d) { return !is_transformed_dim[d]; });
std::sort(dim_permute.begin(), batch_end);
std::copy(dim.cbegin(), dim.cend(), batch_end);
// transpose input according to dim permutation
auto transposed_input_shape = X->dims().transpose(dim_permute);
framework::Tensor transposed_input;
transposed_input.Resize(transposed_input_shape);
transposed_input.mutable_data<Ti>(tensor_place);
TransCompute<DeviceContext, Ti>(ndim, ctx, *X, &transposed_input,
dim_permute);
// Reshape batch dimensions into a single dimension
const int64_t signal_ndim = static_cast<int64_t>(dim.size());
std::vector<int64_t> collapsed_input_shape(signal_ndim + 1);
auto transposed_input_shape_ = phi::vectorize(transposed_input_shape);
const int64_t batch_dims = ndim - signal_ndim;
auto batch_size =
std::accumulate(transposed_input_shape_.begin(),
transposed_input_shape_.begin() + batch_dims,
static_cast<int>(1), std::multiplies<int>());
collapsed_input_shape[0] = batch_size;
std::copy(transposed_input_shape_.begin() + batch_dims,
transposed_input_shape_.end(), collapsed_input_shape.begin() + 1);
framework::Tensor& collapsed_input = transposed_input;
collapsed_input.Resize(phi::make_ddim(collapsed_input_shape));
// make a collapsed output
const auto out_dims = phi::vectorize(out->dims());
std::vector<int64_t> collapsed_output_shape(1 + signal_ndim);
collapsed_output_shape[0] = batch_size;
for (size_t i = 0; i < dim.size(); ++i) {
collapsed_output_shape[i + 1] = out_dims[dim[i]];
}
framework::Tensor collapsed_output;
collapsed_output.Resize(phi::make_ddim(collapsed_output_shape));
collapsed_output.mutable_data<To>(tensor_place);
FFTConfig* config = nullptr;
#if defined(PADDLE_WITH_CUDA)
std::unique_ptr<FFTConfig> config_ = nullptr;
// create plan
FFTConfigKey key =
create_fft_configkey(collapsed_input, collapsed_output, signal_ndim);
bool using_cache = false;
#if !defined(CUFFT_VERSION) || (CUFFT_VERSION < 10200)
using_cache = true;
#endif
if (using_cache) {
const int64_t device_id = static_cast<int64_t>(
reinterpret_cast<const platform::CUDAPlace*>(&collapsed_input.place())
->GetDeviceId());
FFTConfigCache& plan_cache = get_fft_plan_cache(device_id);
std::unique_lock<std::mutex> guard(plan_cache.mutex, std::defer_lock);
guard.lock();
config = &(plan_cache.lookup(key));
} else {
config_ = std::make_unique<FFTConfig>(key);
config = config_.get();
}
// prepare cufft for execution
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cufftSetStream(config->plan(), ctx.stream()));
framework::Tensor workspace_tensor;
workspace_tensor.mutable_data<To>(tensor_place, config->workspace_size());
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cufftSetWorkArea(
config->plan(), workspace_tensor.data<To>()));
// execute transform plan
exec_cufft_plan<DeviceContext, Ti, To>(ctx, *config, &collapsed_input,
&collapsed_output, forward);
#elif defined(PADDLE_WITH_HIP)
// create plan
FFTConfigKey key =
create_fft_configkey(collapsed_input, collapsed_output, signal_ndim);
const int64_t device_id = static_cast<int64_t>(
reinterpret_cast<const platform::CUDAPlace*>(&collapsed_input.place())
->GetDeviceId());
FFTConfigCache& plan_cache = get_fft_plan_cache(device_id);
std::unique_lock<std::mutex> guard(plan_cache.mutex, std::defer_lock);
guard.lock();
config = &(plan_cache.lookup(key));
// prepare cufft for execution
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::hipfftSetStream(config->plan(), ctx.stream()));
framework::Tensor workspace_tensor;
workspace_tensor.mutable_data<To>(tensor_place, config->workspace_size());
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftSetWorkArea(
config->plan(), workspace_tensor.data<To>()));
// execute transform plan
exec_hipfft_plan<DeviceContext, Ti, To>(ctx, *config, &collapsed_input,
&collapsed_output, forward);
#endif
// Invert the output by reshaping and transposing back to the original batch
// and dimension order
auto transposed_out_shape = out->dims().transpose(dim_permute);
collapsed_output.Resize(transposed_out_shape);
auto& transposed_output = collapsed_output;
std::vector<int> reverse_dim_permute(ndim);
for (int64_t i = 0; i < ndim; i++) {
reverse_dim_permute[dim_permute[i]] = i;
}
TransCompute<DeviceContext, To>(ndim, ctx, transposed_output, out,
reverse_dim_permute);
}
} // anonymous namespace
// Use the optimized path to perform a single R2C or C2R transform if the
// transformed dims are supported by cuFFT
bool use_optimized_fft_path(const std::vector<int64_t>& axes) {
// For performance reasons, when axes starts with (0, 1), do not use the
// optimized path.
if (axes.size() > kMaxFFTNdim ||
(axes.size() >= 2 && axes[0] == 0 && axes[1] == 1)) {
return false;
} else {
return true;
}
}
template <typename Ti, typename To>
struct FFTC2CFunctor<platform::CUDADeviceContext, Ti, To> {
void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {
if (axes.empty()) {
framework::TensorCopy(*X, ctx.GetPlace(), out);
return;
}
framework::Tensor* p_out = out;
std::vector<int64_t> out_dims = phi::vectorize(X->dims());
std::vector<int64_t> working_axes(axes.begin(), axes.end());
std::vector<int64_t> first_dims;
size_t max_dims;
framework::Tensor working_tensor;
working_tensor.mutable_data<Ti>(X->dims(), ctx.GetPlace());
framework::Tensor* p_working_tensor = &working_tensor;
framework::TensorCopy(*X, ctx.GetPlace(), &working_tensor);
while (true) {
max_dims =
std::min(static_cast<size_t>(kMaxFFTNdim), working_axes.size());
first_dims.assign(working_axes.end() - max_dims, working_axes.end());
exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, p_working_tensor,
p_out, first_dims, forward);
working_axes.resize(working_axes.size() - max_dims);
first_dims.clear();
if (working_axes.empty()) {
break;
}
std::swap(p_out, p_working_tensor);
}
exec_normalization<platform::CUDADeviceContext, To>(
ctx, p_out, out, normalization, out_dims, axes);
}
};
template <typename Ti, typename To>
struct FFTC2RFunctor<platform::CUDADeviceContext, Ti, To> {
void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {
std::vector<int64_t> in_dims = phi::vectorize(X->dims());
std::vector<int64_t> out_dims = phi::vectorize(out->dims());
if (use_optimized_fft_path(axes)) {
framework::Tensor x_copy(X->type());
x_copy.mutable_data<Ti>(X->dims(), ctx.GetPlace());
framework::TensorCopy(*X, ctx.GetPlace(), &x_copy);
exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, &x_copy, out, axes,
forward);
} else {
framework::Tensor temp_tensor;
temp_tensor.mutable_data<Ti>(X->dims(), ctx.GetPlace());
const std::vector<int64_t> dims(axes.begin(), axes.end() - 1);
FFTC2CFunctor<platform::CUDADeviceContext, Ti, Ti> c2c_functor;
c2c_functor(ctx, X, &temp_tensor, dims, FFTNormMode::none, forward);
exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, &temp_tensor, out,
{axes.back()}, forward);
}
exec_normalization<platform::CUDADeviceContext, To>(
ctx, out, out, normalization, out_dims, axes);
}
};
// n dimension real to complex FFT use cufft lib
template <typename Ti, typename To>
struct FFTR2CFunctor<platform::CUDADeviceContext, Ti, To> {
void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {
// Step1: R2C transform on the last dimension
framework::Tensor* r2c_out = out;
const std::vector<int64_t> last_dim{axes.back()};
std::vector<int64_t> out_dims = phi::vectorize(out->dims());
exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, X, r2c_out, last_dim,
forward);
// Step2: C2C transform on the remaining dimension
framework::Tensor c2c_out;
if (axes.size() > 1) {
c2c_out.mutable_data<To>(out->dims(), ctx.GetPlace());
std::vector<int64_t> remain_dim(axes.begin(), axes.end() - 1);
FFTC2CFunctor<platform::CUDADeviceContext, To, To> fft_c2c_func;
fft_c2c_func(ctx, r2c_out, &c2c_out, remain_dim, FFTNormMode::none,
forward);
}
const auto in_sizes = phi::vectorize(X->dims());
framework::Tensor* norm_tensor = axes.size() > 1 ? &c2c_out : r2c_out;
exec_normalization<platform::CUDADeviceContext, To>(
ctx, norm_tensor, out, normalization, in_sizes, axes);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
...
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <list>
#include <memory>
#include <mutex>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/conj_op.h"
#include "paddle/fluid/operators/spectral_op.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/dynload/hipfft.h"
#endif
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/cufft.h"
#endif
namespace paddle {
namespace operators {
using ScalarType = framework::proto::VarType::Type;
const int64_t kMaxFFTNdim = 3;
const int64_t kMaxDataNdim = kMaxFFTNdim + 1;
// This struct is used to easily compute hashes of the
// parameters. It will be the **key** to the plan cache.
struct FFTConfigKey {
// between 1 and kMaxFFTNdim, i.e., 1 <= signal_ndim <= 3
int64_t signal_ndim_;
// These include the additional batch dimension as well.
int64_t sizes_[kMaxDataNdim];
int64_t input_shape_[kMaxDataNdim];
int64_t output_shape_[kMaxDataNdim];
FFTTransformType fft_type_;
ScalarType value_type_;
FFTConfigKey() = default;
FFTConfigKey(const std::vector<int64_t>& in_shape,
const std::vector<int64_t>& out_shape,
const std::vector<int64_t>& signal_size,
FFTTransformType fft_type, ScalarType value_type) {
// Padding bits must be zeroed for hashing
memset(this, 0, sizeof(*this));
signal_ndim_ = signal_size.size() - 1;
fft_type_ = fft_type;
value_type_ = value_type;
std::copy(signal_size.cbegin(), signal_size.cend(), sizes_);
std::copy(in_shape.cbegin(), in_shape.cend(), input_shape_);
std::copy(out_shape.cbegin(), out_shape.cend(), output_shape_);
}
};
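// Illustrative use (hypothetical shapes, not taken from a real call site):
// a batched 1-D R2C plan for a float32 input of shape [N, n] and output
// [N, n / 2 + 1] would be keyed as
//   FFTConfigKey key({N, n}, {N, n / 2 + 1}, /*signal_size=*/{N, n},
//                    FFTTransformType::R2C, framework::proto::VarType::FP32);
// Because the memset above zeroes padding bytes, keys can then be hashed and
// compared bytewise (see KeyHash / KeyEqual below).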
#if defined(PADDLE_WITH_CUDA)
// An RAII encapsulation of cuFFTHandle
class CuFFTHandle {
::cufftHandle handle_;
public:
CuFFTHandle() {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cufftCreate(&handle_));
}
CuFFTHandle(const CuFFTHandle& other) = delete;
CuFFTHandle& operator=(const CuFFTHandle& other) = delete;
CuFFTHandle(CuFFTHandle&& other) = delete;
CuFFTHandle& operator=(CuFFTHandle&& other) = delete;
::cufftHandle& get() { return handle_; }
const ::cufftHandle& get() const { return handle_; }
~CuFFTHandle() {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cufftDestroy(handle_));
}
};
using plan_size_type = long long int; // NOLINT
// This class contains all the information needed to execute a cuFFT plan:
// 1. the plan
// 2. the workspace size needed
class FFTConfig {
public:
// Only move semantics is enough for this class. Although we already use
// unique_ptr for the plan, still remove the copy constructor and assignment
// op so we don't accidentally copy and take a perf hit.
explicit FFTConfig(const FFTConfigKey& plan_key)
: FFTConfig(
std::vector<int64_t>(plan_key.sizes_,
plan_key.sizes_ + plan_key.signal_ndim_ + 1),
plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {}
// sizes are the full signal sizes, including the batch size, and are always
// two-sided
FFTConfig(const std::vector<int64_t>& sizes, const int64_t signal_ndim,
FFTTransformType fft_type, ScalarType dtype)
: fft_type_(fft_type), value_type_(dtype) {
// signal sizes (excluding batch dim)
std::vector<plan_size_type> signal_sizes(sizes.begin() + 1, sizes.end());
// input batch size
const auto batch = static_cast<plan_size_type>(sizes[0]);
// const int64_t signal_ndim = sizes.size() - 1;
PADDLE_ENFORCE_EQ(signal_ndim, sizes.size() - 1,
platform::errors::InvalidArgument(
"The signal_ndim must be equal to sizes.size() - 1,"
"But signal_ndim is: [%d], sizes.size() - 1 is: [%d]",
signal_ndim, sizes.size() - 1));
cudaDataType itype, otype, exec_type;
const auto complex_input = has_complex_input(fft_type);
const auto complex_output = has_complex_output(fft_type);
if (dtype == framework::proto::VarType::FP32) {
itype = complex_input ? CUDA_C_32F : CUDA_R_32F;
otype = complex_output ? CUDA_C_32F : CUDA_R_32F;
exec_type = CUDA_C_32F;
} else if (dtype == framework::proto::VarType::FP64) {
itype = complex_input ? CUDA_C_64F : CUDA_R_64F;
otype = complex_output ? CUDA_C_64F : CUDA_R_64F;
exec_type = CUDA_C_64F;
} else if (dtype == framework::proto::VarType::FP16) {
itype = complex_input ? CUDA_C_16F : CUDA_R_16F;
otype = complex_output ? CUDA_C_16F : CUDA_R_16F;
exec_type = CUDA_C_16F;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"cuFFT only support transforms of type float16, float32 and "
"float64"));
}
// disable auto allocation of workspace to use allocator from the framework
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cufftSetAutoAllocation(
plan(), /* autoAllocate */ 0));
size_t ws_size_t;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cufftXtMakePlanMany(
plan(), signal_ndim, signal_sizes.data(),
/* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype,
/* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype,
batch, &ws_size_t, exec_type));
ws_size = ws_size_t;
}
FFTConfig(const FFTConfig& other) = delete;
FFTConfig& operator=(const FFTConfig& other) = delete;
FFTConfig(FFTConfig&& other) = delete;
FFTConfig& operator=(FFTConfig&& other) = delete;
const cufftHandle& plan() const { return plan_ptr.get(); }
FFTTransformType transform_type() const { return fft_type_; }
ScalarType data_type() const { return value_type_; }
size_t workspace_size() const { return ws_size; }
private:
CuFFTHandle plan_ptr;
size_t ws_size;
FFTTransformType fft_type_;
ScalarType value_type_;
};
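// Sketch of the intended flow (the helpers referenced here are defined later
// in this file; the tensor names are hypothetical): a key describes the
// transform, and the config owns the plan.
//   FFTConfigKey key = create_fft_configkey(collapsed_in, collapsed_out, 1);
//   FFTConfig config(key);  // builds the cufftHandle via cufftXtMakePlanMany
//   size_t ws = config.workspace_size();  // caller allocates the work area;
//                                         // see cufftSetWorkArea in exec_fft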
#elif defined(PADDLE_WITH_HIP)
// An RAII encapsulation of hipfftHandle
class HIPFFTHandle {
::hipfftHandle handle_;
public:
HIPFFTHandle() {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftCreate(&handle_));
}
HIPFFTHandle(const HIPFFTHandle& other) = delete;
HIPFFTHandle& operator=(const HIPFFTHandle& other) = delete;
HIPFFTHandle(HIPFFTHandle&& other) = delete;
HIPFFTHandle& operator=(HIPFFTHandle&& other) = delete;
::hipfftHandle& get() { return handle_; }
const ::hipfftHandle& get() const { return handle_; }
~HIPFFTHandle() {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftDestroy(handle_));
}
};
using plan_size_type = int;
// This class contains all the information needed to execute a hipFFT plan:
// 1. the plan
// 2. the workspace size needed
class FFTConfig {
public:
// Only move semantics is enough for this class. Although we already use
// unique_ptr for the plan, still remove the copy constructor and assignment
// op so we don't accidentally copy and take a perf hit.
explicit FFTConfig(const FFTConfigKey& plan_key)
: FFTConfig(
std::vector<int64_t>(plan_key.sizes_,
plan_key.sizes_ + plan_key.signal_ndim_ + 1),
plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {}
// sizes are the full signal sizes, including the batch size, and are always
// two-sided
FFTConfig(const std::vector<int64_t>& sizes, const int64_t signal_ndim,
FFTTransformType fft_type, ScalarType dtype)
: fft_type_(fft_type), value_type_(dtype) {
// signal sizes (excluding batch dim)
std::vector<plan_size_type> signal_sizes(sizes.begin() + 1, sizes.end());
// input batch size
const auto batch = static_cast<plan_size_type>(sizes[0]);
// const int64_t signal_ndim = sizes.size() - 1;
PADDLE_ENFORCE_EQ(signal_ndim, sizes.size() - 1,
platform::errors::InvalidArgument(
"The signal_ndim must be equal to sizes.size() - 1,"
"But signal_ndim is: [%d], sizes.size() - 1 is: [%d]",
signal_ndim, sizes.size() - 1));
hipfftType exec_type = [&] {
if (dtype == framework::proto::VarType::FP32) {
switch (fft_type) {
case FFTTransformType::C2C:
return HIPFFT_C2C;
case FFTTransformType::R2C:
return HIPFFT_R2C;
case FFTTransformType::C2R:
return HIPFFT_C2R;
}
} else if (dtype == framework::proto::VarType::FP64) {
switch (fft_type) {
case FFTTransformType::C2C:
return HIPFFT_Z2Z;
case FFTTransformType::R2C:
return HIPFFT_D2Z;
case FFTTransformType::C2R:
return HIPFFT_Z2D;
}
}
PADDLE_THROW(platform::errors::InvalidArgument(
"hipFFT only support transforms of type float32 and float64"));
}();
// disable auto allocation of workspace to use allocator from the framework
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftSetAutoAllocation(
plan(), /* autoAllocate */ 0));
size_t ws_size_t;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftMakePlanMany(
plan(), signal_ndim, signal_sizes.data(),
/* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1,
/* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, exec_type,
batch, &ws_size_t));
ws_size = ws_size_t;
}
const hipfftHandle& plan() const { return plan_ptr.get(); }
FFTTransformType transform_type() const { return fft_type_; }
ScalarType data_type() const { return value_type_; }
size_t workspace_size() const { return ws_size; }
private:
HIPFFTHandle plan_ptr;
size_t ws_size;
FFTTransformType fft_type_;
ScalarType value_type_;
};
#endif
// Hashing machinery for Key
// Fowler–Noll–Vo hash function
// see
// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
template <typename Key>
struct KeyHash {
// Key must be a POD because we read out its memory
// contents as char* when hashing
static_assert(std::is_pod<Key>::value, "Key must be plain old data type");
size_t operator()(const Key& params) const {
auto ptr = reinterpret_cast<const uint8_t*>(&params);
uint32_t value = 0x811C9DC5;
for (int i = 0; i < static_cast<int>(sizeof(Key)); ++i) {
value ^= ptr[i];
value *= 0x01000193;
}
return static_cast<size_t>(value);
}
};
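// The loop above is 32-bit FNV-1a: starting from the offset basis 0x811C9DC5,
// each byte is folded in as value = (value ^ byte) * 0x01000193. Any byte
// difference -- including padding -- changes the hash, which is why
// FFTConfigKey zeroes its padding with memset before use.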
template <typename Key>
struct KeyEqual {
// Key must be a POD because we read out its memory
// contents as char* when comparing
static_assert(std::is_pod<Key>::value, "Key must be plain old data type");
bool operator()(const Key& a, const Key& b) const {
auto ptr1 = reinterpret_cast<const uint8_t*>(&a);
auto ptr2 = reinterpret_cast<const uint8_t*>(&b);
return memcmp(ptr1, ptr2, sizeof(Key)) == 0;
}
};
#if CUDA_VERSION < 10000
// Note that the max plan number for CUDA version < 10 has to be 1023
// due to a bug that fails on the 1024th plan
constexpr size_t CUFFT_MAX_PLAN_NUM = 1023;
constexpr size_t CUFFT_DEFAULT_CACHE_SIZE = CUFFT_MAX_PLAN_NUM;
#else
constexpr size_t CUFFT_MAX_PLAN_NUM = std::numeric_limits<size_t>::max();
// The default max cache size chosen for CUDA version > 10 is arbitrary.
// This number puts a limit on how big a plan cache we should maintain by
// default. Users can always configure it via cufft_set_plan_cache_max_size.
constexpr size_t CUFFT_DEFAULT_CACHE_SIZE = 4096;
#endif
static_assert(CUFFT_MAX_PLAN_NUM >= 0 &&
CUFFT_MAX_PLAN_NUM <= std::numeric_limits<size_t>::max(),
"CUFFT_MAX_PLAN_NUM not in size_t range");
static_assert(CUFFT_DEFAULT_CACHE_SIZE >= 0 &&
CUFFT_DEFAULT_CACHE_SIZE <= CUFFT_MAX_PLAN_NUM,
"CUFFT_DEFAULT_CACHE_SIZE not in [0, CUFFT_MAX_PLAN_NUM] range");
// This cache assumes that the mapping from key to value never changes.
// This is **NOT** thread-safe. Please use a mutex when using it **AND** the
// value returned from try_emplace_value.
// The contract of using this cache is that try_emplace_value should only be
// used when the max_size is positive.
class FFTConfigCache {
public:
using kv_t = typename std::pair<FFTConfigKey, FFTConfig>;
using map_t = typename std::unordered_map<
std::reference_wrapper<FFTConfigKey>, typename std::list<kv_t>::iterator,
KeyHash<FFTConfigKey>, KeyEqual<FFTConfigKey>>;
using map_kkv_iter_t = typename map_t::iterator;
FFTConfigCache() : FFTConfigCache(CUFFT_DEFAULT_CACHE_SIZE) {}
explicit FFTConfigCache(int64_t max_size) { _set_max_size(max_size); }
FFTConfigCache(const FFTConfigCache& other) = delete;
FFTConfigCache& operator=(const FFTConfigCache& other) = delete;
FFTConfigCache(FFTConfigCache&& other) noexcept
: _usage_list(std::move(other._usage_list)),
_cache_map(std::move(other._cache_map)),
_max_size(other._max_size) {}
FFTConfigCache& operator=(FFTConfigCache&& other) noexcept {
_usage_list = std::move(other._usage_list);
_cache_map = std::move(other._cache_map);
_max_size = other._max_size;
return *this;
}
// If key is in this cache, return the cached config. Otherwise, emplace the
// config in this cache and return it.
FFTConfig& lookup(FFTConfigKey params) {
PADDLE_ENFORCE_GT(_max_size, 0,
platform::errors::InvalidArgument(
"The max size of FFTConfigCache must be great than 0,"
"But received is [%d]",
_max_size));
map_kkv_iter_t map_it = _cache_map.find(params);
// Hit, put to list front
if (map_it != _cache_map.end()) {
_usage_list.splice(_usage_list.begin(), _usage_list, map_it->second);
return map_it->second->second;
}
// Miss
// remove if needed
if (_usage_list.size() >= _max_size) {
auto last = _usage_list.end();
last--;
_cache_map.erase(last->first);
_usage_list.pop_back();
}
// construct new plan at list front, then insert into _cache_map
_usage_list.emplace_front(std::piecewise_construct,
std::forward_as_tuple(params),
std::forward_as_tuple(params));
auto kv_it = _usage_list.begin();
_cache_map.emplace(std::piecewise_construct,
std::forward_as_tuple(kv_it->first),
std::forward_as_tuple(kv_it));
return kv_it->second;
}
void clear() {
_cache_map.clear();
_usage_list.clear();
}
void resize(int64_t new_size) {
_set_max_size(new_size);
auto cur_size = _usage_list.size();
if (cur_size > _max_size) {
auto delete_it = _usage_list.end();
for (size_t i = 0; i < cur_size - _max_size; i++) {
delete_it--;
_cache_map.erase(delete_it->first);
}
_usage_list.erase(delete_it, _usage_list.end());
}
}
size_t size() const { return _cache_map.size(); }
size_t max_size() const noexcept { return _max_size; }
std::mutex mutex;
private:
// Only sets size and does value check. Does not resize the data structures.
void _set_max_size(int64_t new_size) {
// We check that 0 <= new_size <= CUFFT_MAX_PLAN_NUM here. Since
// CUFFT_MAX_PLAN_NUM is of type size_t, we need to do non-negativity check
// first.
PADDLE_ENFORCE_GE(
new_size, 0,
platform::errors::InvalidArgument(
"cuFFT plan cache size must be non-negative, But received is [%d]",
new_size));
PADDLE_ENFORCE_LE(new_size, CUFFT_MAX_PLAN_NUM,
platform::errors::InvalidArgument(
"cuFFT plan cache size can not be larger than [%d], "
"But received is [%d]",
CUFFT_MAX_PLAN_NUM, new_size));
_max_size = static_cast<size_t>(new_size);
}
std::list<kv_t> _usage_list;
map_t _cache_map;
size_t _max_size;
};
static std::vector<std::unique_ptr<FFTConfigCache>> plan_caches;
static std::mutex plan_caches_mutex;
static inline FFTConfigCache& get_fft_plan_cache(int64_t device_index) {
std::lock_guard<std::mutex> guard(plan_caches_mutex);
if (device_index >= plan_caches.size()) {
plan_caches.resize(device_index + 1);
}
if (!plan_caches[device_index]) {
plan_caches[device_index] = std::make_unique<FFTConfigCache>();
}
return *plan_caches[device_index];
}
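// Typical access pattern (a sketch mirroring exec_fft below); the cache is
// not thread-safe, so callers hold the cache mutex across lookup():
//   FFTConfigCache& cache = get_fft_plan_cache(device_id);
//   std::unique_lock<std::mutex> guard(cache.mutex);
//   FFTConfig& config = cache.lookup(key);  // hit: promoted to LRU front;
//                                           // miss: plan built and cached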
// Calculates the normalization constant
static double fft_normalization_scale(FFTNormMode normalization,
const std::vector<int64_t>& sizes,
const std::vector<int64_t>& dims) {
// auto norm = static_cast<fft_norm_mode>(normalization);
if (normalization == FFTNormMode::none) {
return static_cast<double>(1.0);
}
int64_t signal_numel = 1;
for (auto dim : dims) {
signal_numel *= sizes[dim];
}
const double scale_denom = (normalization == FFTNormMode::by_sqrt_n)
? std::sqrt(signal_numel)
: static_cast<double>(signal_numel);
return static_cast<double>(1.0 / scale_denom);
}
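// Worked example: for a transform over dims = {0, 1} of a tensor with
// sizes = {4, 8, 3}, signal_numel = 4 * 8 = 32, so the returned scale is
// 1.0 for FFTNormMode::none, 1.0 / sqrt(32) for by_sqrt_n, and 1.0 / 32
// otherwise.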
template <typename DeviceContext, typename T>
void exec_normalization(const DeviceContext& ctx, const Tensor* in, Tensor* out,
FFTNormMode normalization,
const std::vector<int64_t>& sizes,
const std::vector<int64_t>& axes) {
double scale = fft_normalization_scale(normalization, sizes, axes);
if (scale != 1.0) {
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
auto dev = ctx.eigen_device();
EigenScale<Eigen::GpuDevice, T>::Eval(*dev, eigen_out, eigen_in,
static_cast<T>(scale),
static_cast<T>(0), false);
} else {
framework::TensorCopy(*in, ctx.GetPlace(), out);
}
}
#if defined(PADDLE_WITH_CUDA)
static FFTConfigKey create_fft_configkey(const framework::Tensor& input,
const framework::Tensor& output,
int signal_ndim) {
// Create the transform plan (either from cache or locally)
const auto value_type =
framework::IsComplexType(framework::TransToProtoVarType(input.dtype()))
? framework::ToRealType(framework::TransToProtoVarType(input.dtype()))
: framework::TransToProtoVarType(input.dtype());
auto fft_type =
GetFFTTransformType(framework::TransToProtoVarType(input.dtype()),
framework::TransToProtoVarType(output.dtype()));
// signal sizes
std::vector<int64_t> signal_size(signal_ndim + 1);
signal_size[0] = input.dims()[0];
for (int64_t i = 1; i <= signal_ndim; ++i) {
auto in_size = input.dims()[i];
auto out_size = output.dims()[i];
signal_size[i] = std::max(in_size, out_size);
}
FFTConfigKey key(phi::vectorize(input.dims()), phi::vectorize(output.dims()),
signal_size, fft_type, value_type);
return key;
}
// Execute a pre-planned transform
static void exec_cufft_plan_raw(const FFTConfig& config, void* in_data,
void* out_data, bool forward) {
auto& plan = config.plan();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cufftXtExec(
plan, in_data, out_data, forward ? CUFFT_FORWARD : CUFFT_INVERSE));
}
template <typename DeviceContext, typename Ti, typename To>
void exec_cufft_plan(const DeviceContext& ctx, const FFTConfig& config,
framework::Tensor* input, framework::Tensor* output,
bool forward) {
// execute transform plan
auto fft_type = config.transform_type();
if (fft_type == FFTTransformType::C2R && forward) {
forward = false;
framework::Tensor input_conj(input->type());
input_conj.mutable_data<Ti>(input->dims(), ctx.GetPlace());
platform::ForRange<DeviceContext> for_range(ctx, input->numel());
phi::funcs::ConjFunctor<Ti> functor(input->data<Ti>(), input->numel(),
input_conj.data<Ti>());
for_range(functor);
exec_cufft_plan_raw(config, input_conj.data(), output->data(), forward);
} else if (fft_type == FFTTransformType::R2C && !forward) {
forward = true;
framework::Tensor out_conj(output->type());
out_conj.mutable_data<To>(output->dims(), ctx.GetPlace());
exec_cufft_plan_raw(config, input->data(), out_conj.data(), forward);
platform::ForRange<DeviceContext> for_range(ctx, output->numel());
phi::funcs::ConjFunctor<To> functor(out_conj.data<To>(), output->numel(),
output->data<To>());
for_range(functor);
} else {
exec_cufft_plan_raw(config, input->data(), output->data(), forward);
}
}
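// Rationale (sketch): cuFFT exposes C2R only as an inverse-direction
// transform and R2C only as a forward-direction one, so the opposite
// directions are emulated through conjugation, using the identity
// conj(DFT(conj(x))) == unnormalized IDFT(x); the scaling is applied later
// by exec_normalization.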
#elif defined(PADDLE_WITH_HIP)
static FFTConfigKey create_fft_configkey(const framework::Tensor& input,
const framework::Tensor& output,
int signal_ndim) {
// Create the transform plan (either from cache or locally)
const auto value_type =
framework::IsComplexType(framework::TransToProtoVarType(input.dtype()))
? framework::ToRealType(framework::TransToProtoVarType(input.dtype()))
: framework::TransToProtoVarType(input.dtype());
auto fft_type =
GetFFTTransformType(framework::TransToProtoVarType(input.dtype()),
framework::TransToProtoVarType(output.type()));
// signal sizes
std::vector<int64_t> signal_size(signal_ndim + 1);
signal_size[0] = input.dims()[0];
for (int64_t i = 1; i <= signal_ndim; ++i) {
auto in_size = input.dims()[i];
auto out_size = output.dims()[i];
signal_size[i] = std::max(in_size, out_size);
}
FFTConfigKey key(phi::vectorize(input.dims()), phi::vectorize(output.dims()),
signal_size, fft_type, value_type);
return key;
}
// Execute a pre-planned transform
static void exec_hipfft_plan_raw(const FFTConfig& config, void* in_data,
void* out_data, bool forward) {
auto& plan = config.plan();
auto value_type = config.data_type();
if (value_type == framework::proto::VarType::FP32) {
switch (config.transform_type()) {
case FFTTransformType::C2C: {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecC2C(
plan, static_cast<hipfftComplex*>(in_data),
static_cast<hipfftComplex*>(out_data),
forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD));
return;
}
case FFTTransformType::R2C: {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecR2C(
plan, static_cast<hipfftReal*>(in_data),
static_cast<hipfftComplex*>(out_data)));
return;
}
case FFTTransformType::C2R: {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecC2R(
plan, static_cast<hipfftComplex*>(in_data),
static_cast<hipfftReal*>(out_data)));
return;
}
}
} else if (value_type == framework::proto::VarType::FP64) {
switch (config.transform_type()) {
case FFTTransformType::C2C: {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecZ2Z(
plan, static_cast<hipfftDoubleComplex*>(in_data),
static_cast<hipfftDoubleComplex*>(out_data),
forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD));
return;
}
case FFTTransformType::R2C: {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecD2Z(
plan, static_cast<hipfftDoubleReal*>(in_data),
static_cast<hipfftDoubleComplex*>(out_data)));
return;
}
case FFTTransformType::C2R: {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecZ2D(
plan, static_cast<hipfftDoubleComplex*>(in_data),
static_cast<hipfftDoubleReal*>(out_data)));
return;
}
}
}
PADDLE_THROW(platform::errors::InvalidArgument(
"hipFFT only support transforms of type float32 and float64"));
}
template <typename DeviceContext, typename Ti, typename To>
void exec_hipfft_plan(const DeviceContext& ctx, const FFTConfig& config,
framework::Tensor* input, framework::Tensor* output,
bool forward) {
auto fft_type = config.transform_type();
if (fft_type == FFTTransformType::C2R && forward) {
forward = false;
framework::Tensor input_conj(input->type());
input_conj.mutable_data<Ti>(input->dims(), ctx.GetPlace());
platform::ForRange<DeviceContext> for_range(ctx, input->numel());
phi::funcs::ConjFunctor<Ti> functor(input->data<Ti>(), input->numel(),
input_conj.data<Ti>());
for_range(functor);
exec_hipfft_plan_raw(config, input_conj.data(), output->data(), forward);
} else if (fft_type == FFTTransformType::R2C && !forward) {
forward = true;
framework::Tensor out_conj(output->type());
out_conj.mutable_data<To>(output->dims(), ctx.GetPlace());
exec_hipfft_plan_raw(config, input->data(), out_conj.data(), forward);
platform::ForRange<DeviceContext> for_range(ctx, output->numel());
phi::funcs::ConjFunctor<To> functor(out_conj.data<To>(), output->numel(),
output->data<To>());
for_range(functor);
} else {
exec_hipfft_plan_raw(config, input->data(), output->data(), forward);
}
}
#endif
// Execute a general unnormalized fft operation (can be c2c, onesided r2c or
// onesided c2r)
template <typename DeviceContext, typename Ti, typename To>
void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out,
const std::vector<int64_t>& dim, bool forward) {
const auto x_dims = phi::vectorize(X->dims());
const int64_t ndim = static_cast<int64_t>(X->dims().size());
auto tensor_place = ctx.GetPlace();
// make a dim permutation
std::vector<int> dim_permute(ndim);
std::iota(dim_permute.begin(), dim_permute.end(), int{0});
std::vector<bool> is_transformed_dim(ndim);
for (const auto& d : dim) {
is_transformed_dim[d] = true;
}
auto batch_end =
std::partition(dim_permute.begin(), dim_permute.end(),
[&](int64_t d) { return !is_transformed_dim[d]; });
std::sort(dim_permute.begin(), batch_end);
std::copy(dim.cbegin(), dim.cend(), batch_end);
// transpose input according to dim permutation
auto transposed_input_shape = X->dims().transpose(dim_permute);
framework::Tensor transposed_input;
transposed_input.Resize(transposed_input_shape);
transposed_input.mutable_data<Ti>(tensor_place);
TransCompute<DeviceContext, Ti>(ndim, ctx, *X, &transposed_input,
dim_permute);
// Reshape batch dimensions into a single dimension
const int64_t signal_ndim = static_cast<int64_t>(dim.size());
std::vector<int64_t> collapsed_input_shape(signal_ndim + 1);
auto transposed_input_shape_ = phi::vectorize(transposed_input_shape);
const int64_t batch_dims = ndim - signal_ndim;
auto batch_size =
std::accumulate(transposed_input_shape_.begin(),
transposed_input_shape_.begin() + batch_dims,
static_cast<int>(1), std::multiplies<int>());
collapsed_input_shape[0] = batch_size;
std::copy(transposed_input_shape_.begin() + batch_dims,
transposed_input_shape_.end(), collapsed_input_shape.begin() + 1);
framework::Tensor& collapsed_input = transposed_input;
collapsed_input.Resize(phi::make_ddim(collapsed_input_shape));
// make a collapsed output
const auto out_dims = phi::vectorize(out->dims());
std::vector<int64_t> collapsed_output_shape(1 + signal_ndim);
collapsed_output_shape[0] = batch_size;
for (size_t i = 0; i < dim.size(); ++i) {
collapsed_output_shape[i + 1] = out_dims[dim[i]];
}
framework::Tensor collapsed_output;
collapsed_output.Resize(phi::make_ddim(collapsed_output_shape));
collapsed_output.mutable_data<To>(tensor_place);
FFTConfig* config = nullptr;
#if defined(PADDLE_WITH_CUDA)
std::unique_ptr<FFTConfig> config_ = nullptr;
// create plan
FFTConfigKey key =
create_fft_configkey(collapsed_input, collapsed_output, signal_ndim);
bool using_cache = false;
#if !defined(CUFFT_VERSION) || (CUFFT_VERSION < 10200)
using_cache = true;
#endif
if (using_cache) {
const int64_t device_id = static_cast<int64_t>(
reinterpret_cast<const platform::CUDAPlace*>(&collapsed_input.place())
->GetDeviceId());
FFTConfigCache& plan_cache = get_fft_plan_cache(device_id);
std::unique_lock<std::mutex> guard(plan_cache.mutex, std::defer_lock);
guard.lock();
config = &(plan_cache.lookup(key));
} else {
config_ = std::make_unique<FFTConfig>(key);
config = config_.get();
}
// prepare cufft for execution
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cufftSetStream(config->plan(), ctx.stream()));
framework::Tensor workspace_tensor;
workspace_tensor.mutable_data<To>(tensor_place, config->workspace_size());
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cufftSetWorkArea(
config->plan(), workspace_tensor.data<To>()));
// execute transform plan
exec_cufft_plan<DeviceContext, Ti, To>(ctx, *config, &collapsed_input,
&collapsed_output, forward);
#elif defined(PADDLE_WITH_HIP)
// create plan
FFTConfigKey key =
create_fft_configkey(collapsed_input, collapsed_output, signal_ndim);
const int64_t device_id = static_cast<int64_t>(
reinterpret_cast<const platform::CUDAPlace*>(&collapsed_input.place())
->GetDeviceId());
FFTConfigCache& plan_cache = get_fft_plan_cache(device_id);
std::unique_lock<std::mutex> guard(plan_cache.mutex, std::defer_lock);
guard.lock();
config = &(plan_cache.lookup(key));
// prepare cufft for execution
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::hipfftSetStream(config->plan(), ctx.stream()));
framework::Tensor workspace_tensor;
workspace_tensor.mutable_data<To>(tensor_place, config->workspace_size());
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftSetWorkArea(
config->plan(), workspace_tensor.data<To>()));
// execute transform plan
exec_hipfft_plan<DeviceContext, Ti, To>(ctx, *config, &collapsed_input,
&collapsed_output, forward);
#endif
// Invert the output by reshaping and transposing back to the original batch
// and dimension order
auto transposed_out_shape = out->dims().transpose(dim_permute);
collapsed_output.Resize(transposed_out_shape);
auto& transposed_output = collapsed_output;
std::vector<int> reverse_dim_permute(ndim);
for (int64_t i = 0; i < ndim; i++) {
reverse_dim_permute[dim_permute[i]] = i;
}
TransCompute<DeviceContext, To>(ndim, ctx, transposed_output, out,
reverse_dim_permute);
}
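// Worked example (illustrative): X of shape [2, 3, 4, 5] with dim = {1, 3}
// yields dim_permute = {0, 2, 1, 3}, a transposed shape of [2, 4, 3, 5], and
// collapsed_input_shape = {8, 3, 5} (batch 2 * 4 = 8); the plan then runs as
// a batched 2-D transform before the inverse permutation restores the layout.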
// Use the optimized path to perform a single R2C or C2R transform if the
// transformed dims are supported by cuFFT
static bool use_optimized_fft_path(const std::vector<int64_t>& axes) {
// For performance reasons, when axes starts with (0, 1), do not use the
// optimized path.
if (axes.size() > kMaxFFTNdim ||
(axes.size() >= 2 && axes[0] == 0 && axes[1] == 1)) {
return false;
} else {
return true;
}
}
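// e.g. axes = {2} or {1, 2} take the single-transform path above, while
// axes = {0, 1} or any set longer than kMaxFFTNdim (= 3) falls back to the
// C2C-plus-last-axis decomposition in FFTC2RFunctor below.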
template <typename Ti, typename To>
struct FFTC2CFunctor<platform::CUDADeviceContext, Ti, To> {
void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {
if (axes.empty()) {
framework::TensorCopy(*X, ctx.GetPlace(), out);
return;
}
framework::Tensor* p_out = out;
std::vector<int64_t> out_dims = phi::vectorize(X->dims());
std::vector<int64_t> working_axes(axes.begin(), axes.end());
std::vector<int64_t> first_dims;
size_t max_dims;
framework::Tensor working_tensor;
working_tensor.mutable_data<Ti>(X->dims(), ctx.GetPlace());
framework::Tensor* p_working_tensor = &working_tensor;
framework::TensorCopy(*X, ctx.GetPlace(), &working_tensor);
while (true) {
max_dims =
std::min(static_cast<size_t>(kMaxFFTNdim), working_axes.size());
first_dims.assign(working_axes.end() - max_dims, working_axes.end());
exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, p_working_tensor,
p_out, first_dims, forward);
working_axes.resize(working_axes.size() - max_dims);
first_dims.clear();
if (working_axes.empty()) {
break;
}
std::swap(p_out, p_working_tensor);
}
exec_normalization<platform::CUDADeviceContext, To>(
ctx, p_out, out, normalization, out_dims, axes);
}
};
template <typename Ti, typename To>
struct FFTC2RFunctor<platform::CUDADeviceContext, Ti, To> {
void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {
std::vector<int64_t> in_dims = phi::vectorize(X->dims());
std::vector<int64_t> out_dims = phi::vectorize(out->dims());
if (use_optimized_fft_path(axes)) {
framework::Tensor x_copy(X->type());
x_copy.mutable_data<Ti>(X->dims(), ctx.GetPlace());
framework::TensorCopy(*X, ctx.GetPlace(), &x_copy);
exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, &x_copy, out, axes,
forward);
} else {
framework::Tensor temp_tensor;
temp_tensor.mutable_data<Ti>(X->dims(), ctx.GetPlace());
const std::vector<int64_t> dims(axes.begin(), axes.end() - 1);
FFTC2CFunctor<platform::CUDADeviceContext, Ti, Ti> c2c_functor;
c2c_functor(ctx, X, &temp_tensor, dims, FFTNormMode::none, forward);
exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, &temp_tensor, out,
{axes.back()}, forward);
}
exec_normalization<platform::CUDADeviceContext, To>(
ctx, out, out, normalization, out_dims, axes);
}
};
// n dimension real to complex FFT use cufft lib
template <typename Ti, typename To>
struct FFTR2CFunctor<platform::CUDADeviceContext, Ti, To> {
void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
Tensor* out, const std::vector<int64_t>& axes,
FFTNormMode normalization, bool forward) {
// Step1: R2C transform on the last dimension
framework::Tensor* r2c_out = out;
const std::vector<int64_t> last_dim{axes.back()};
std::vector<int64_t> out_dims = phi::vectorize(out->dims());
exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, X, r2c_out, last_dim,
forward);
// Step2: C2C transform on the remaining dimension
framework::Tensor c2c_out;
if (axes.size() > 1) {
c2c_out.mutable_data<To>(out->dims(), ctx.GetPlace());
std::vector<int64_t> remain_dim(axes.begin(), axes.end() - 1);
FFTC2CFunctor<platform::CUDADeviceContext, To, To> fft_c2c_func;
fft_c2c_func(ctx, r2c_out, &c2c_out, remain_dim, FFTNormMode::none,
forward);
}
const auto in_sizes = phi::vectorize(X->dims());
framework::Tensor* norm_tensor = axes.size() > 1 ? &c2c_out : r2c_out;
exec_normalization<platform::CUDADeviceContext, To>(
ctx, norm_tensor, out, normalization, in_sizes, axes);
}
};
} // namespace operators
} // namespace paddle
...
@@ -11,8 +11,11 @@
#pragma once
#define NOMINMAX // to use std::min std::max correctly on windows
#include <algorithm>
#include <functional>
#include <iostream>
#include <memory>
#include <numeric>
#include <string>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
@@ -23,8 +26,10 @@
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/conj_op.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/padding.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "thrust/device_vector.h"
...
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/stft_op.h"
#include "paddle/fluid/operators/spectral_helper.h"
namespace paddle {
namespace operators {
class StftOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "frame");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "frame");
const int n_fft = ctx->Attrs().Get<int>("n_fft");
const int hop_length = ctx->Attrs().Get<int>("hop_length");
const auto x_dims = ctx->GetInputDim("X");
const int x_rank = x_dims.size();
const bool onesided = ctx->Attrs().Get<bool>("onesided");
PADDLE_ENFORCE_EQ(
x_rank, 2,
platform::errors::InvalidArgument(
"Input(X) of StftOp should be a tensor with shape [N, T], "
"but got rank %s.",
x_rank));
PADDLE_ENFORCE_GT(
hop_length, 0,
platform::errors::InvalidArgument(
"Attribute(hop_length) should be greater than 0, but got %s.",
hop_length));
int seq_length = x_dims[x_rank - 1];
int n_frames = 1 + (seq_length - n_fft) / hop_length;
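// e.g. (illustrative numbers) seq_length = 16000, n_fft = 512 and
// hop_length = 160 give n_frames = 1 + (16000 - 512) / 160 = 97.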
PADDLE_ENFORCE_LE(n_fft, seq_length,
platform::errors::InvalidArgument(
"Attribute(n_fft) should be less than or equal to "
"the sequence length, but got (%s) > (%s).",
n_fft, seq_length));
std::vector<int64_t> output_shape;
output_shape.push_back(x_dims[0]);
if (onesided) {
output_shape.push_back(n_fft / 2 + 1);
} else {
output_shape.push_back(n_fft);
}
output_shape.push_back(n_frames);
ctx->SetOutputDim("Out", phi::make_ddim(output_shape));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(in_dtype, ctx.GetPlace());
}
};
class StftOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input waveforms with shape (N, T)");
AddOutput("Out",
"The complex STFT output tensor with shape (N, n_fft, "
"num_frames) or (N, n_fft/2 + 1, num_frames)");
AddAttr<int>("n_fft", "The number of input samples to perform FFT");
AddAttr<int>("hop_length", "Number of samples between adjacent frames");
AddAttr<bool>("normalized",
"Control whether to scale the output by 1/sqrt(n_fft)");
AddAttr<bool>("onesided",
"Control whether to return half of the FFT output");
AddComment(R"DOC(
Short-time Fourier transform (STFT).
)DOC");
}
};
template <typename T>
class StftGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("stft_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};
class StftGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
const auto out_grad_name = framework::GradVarName("Out");
OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name,
"stft_grad");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "stft_grad");
const auto x_grad_name = framework::GradVarName("X");
OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), "Output", x_grad_name,
"stft_grad");
ctx->ShareDim("X", /*->*/ x_grad_name);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
const auto kernel_dtype = framework::ToRealType(in_dtype);
return framework::OpKernelType(kernel_dtype, ctx.GetPlace());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(stft, ops::StftOp, ops::StftOpMaker,
ops::StftGradOpMaker<paddle::framework::OpDesc>,
ops::StftGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(stft_grad, ops::StftGradOp);
REGISTER_OP_CPU_KERNEL(
stft, ops::StftKernel<paddle::platform::CPUDeviceContext, float>,
ops::StftKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
stft_grad, ops::StftGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::StftGradKernel<paddle::platform::CPUDeviceContext, double>);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/spectral_op.cu.h"
#include "paddle/fluid/operators/stft_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
stft, ops::StftKernel<paddle::platform::CUDADeviceContext, float>,
ops::StftKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
stft_grad, ops::StftGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::StftGradKernel<paddle::platform::CUDADeviceContext, double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/frame_op.h"
#include "paddle/fluid/operators/spectral_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class StftKernel : public framework::OpKernel<T> {
public:
/*
Batch Signals (N, T) -> Frames (N, n_fft, num_frames) -> FFTR2C -> (N,
n_fft/2 + 1, num_frames) or (N, n_fft, num_frames)
*/
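// e.g. (illustrative): x of shape (4, 16000) with n_fft = 512,
// hop_length = 160 and onesided = true yields frames of shape
// (4, 512, 97) and a complex output of shape (4, 257, 97).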
void Compute(const framework::ExecutionContext& ctx) const override {
using C = paddle::platform::complex<T>;
const Tensor* x = ctx.Input<Tensor>("X");
Tensor* out = ctx.Output<Tensor>("Out");
out->mutable_data<C>(ctx.GetPlace());
const size_t x_rank = x->dims().size();
const size_t out_rank = out->dims().size();
const int n_fft = ctx.Attr<int>("n_fft");
const int hop_length = ctx.Attr<int>("hop_length");
const bool normalized = ctx.Attr<bool>("normalized");
const bool onesided = ctx.Attr<bool>("onesided");
const int n_frames = out->dims()[out_rank - 1];
const int seq_length = x->dims()[x_rank - 1];
auto& dev_ctx = ctx.device_context<DeviceContext>();
std::vector<int64_t> axes = {1};
// Frame
Tensor frames;
framework::DDim frames_dims(out->dims());
frames_dims.at(axes.back()) = n_fft;
frames.mutable_data<T>(frames_dims, ctx.GetPlace());
FrameFunctor<DeviceContext, T>()(dev_ctx, x, &frames, seq_length, n_fft,
n_frames, hop_length, /*is_grad*/ false);
// FFTR2C
FFTNormMode normalization;
if (normalized) {
normalization = get_norm_from_string("ortho", true);
} else {
normalization = get_norm_from_string("backward", true);
}
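// "ortho" matches the `normalized` attribute's 1/sqrt(n_fft) scaling;
// "backward" leaves the forward transform unscaled.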
FFTR2CFunctor<DeviceContext, T, C> fft_r2c_func;
if (onesided) {
fft_r2c_func(dev_ctx, &frames, out, axes, normalization, true);
} else {
framework::DDim onesided_dims(out->dims());
const int64_t onesided_axis_size = out->dims().at(axes.back()) / 2 + 1;
onesided_dims.at(axes.back()) = onesided_axis_size;
Tensor onesided_out;
onesided_out.mutable_data<C>(onesided_dims, ctx.GetPlace());
fft_r2c_func(dev_ctx, &frames, &onesided_out, axes, normalization, true);
fill_conj<DeviceContext, C>(dev_ctx, &onesided_out, out, axes);
}
}
};
template <typename DeviceContext, typename T>
class StftGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using C = paddle::platform::complex<T>;
auto& dev_ctx = ctx.device_context<DeviceContext>();
const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(ctx.GetPlace());
const size_t dy_rank = dy->dims().size();
const size_t dx_rank = dx->dims().size();
const int n_fft = ctx.Attr<int>("n_fft");
const int hop_length = ctx.Attr<int>("hop_length");
const bool normalized = ctx.Attr<bool>("normalized");
const bool onesided = ctx.Attr<bool>("onesided");
const int n_frames = dy->dims()[dy_rank - 1];
const int seq_length = dx->dims()[dx_rank - 1];
std::vector<int64_t> axes = {1};
Tensor d_frames;
framework::DDim d_frames_dims(dy->dims());
d_frames_dims.at(axes.back()) = n_fft;
d_frames.mutable_data<T>(d_frames_dims, ctx.GetPlace());
Tensor complex_d_frames;
complex_d_frames.mutable_data<C>(d_frames_dims, ctx.GetPlace());
// dy -> d_frames
FFTNormMode normalization;
if (normalized) {
normalization = get_norm_from_string("ortho", true);
} else {
normalization = get_norm_from_string("backward", true);
}
FFTC2CFunctor<DeviceContext, C, C> fft_c2c_func;
if (!onesided) {
fft_c2c_func(dev_ctx, dy, &complex_d_frames, axes, normalization, false);
} else {
Tensor full_dy;
full_dy.mutable_data<C>(d_frames_dims, ctx.GetPlace());
auto zero_length = static_cast<int>(full_dy.dims().at(axes.back()) -
dy->dims().at(axes.back()));
auto rank = dy->dims().size();
std::vector<int> pads(rank * 2, 0);
pads[axes.back() * 2 + 1] = zero_length;
phi::funcs::PaddingFunctor<DeviceContext, C>(
rank, ctx.template device_context<DeviceContext>(), pads,
static_cast<C>(0), *dy, &full_dy);
fft_c2c_func(dev_ctx, &full_dy, &complex_d_frames, axes, normalization,
false);
}
framework::TransComplexToReal(
framework::TransToProtoVarType(d_frames.dtype()),
framework::TransToProtoVarType(complex_d_frames.dtype()),
complex_d_frames, &d_frames);
// d_frames -> dx
FrameFunctor<DeviceContext, T>()(dev_ctx, &d_frames, dx, seq_length, n_fft,
n_frames, hop_length, /*is_grad*/ true);
}
};
} // namespace operators
} // namespace paddle
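For intuition, the pipeline StftKernel implements can be sketched in NumPy as below. This is an illustrative stand-in, not part of the PR: stft_forward_sketch and its loop-based framing are hypothetical, and the non-onesided branch imitates what fill_conj recovers from the Hermitian symmetry of a real signal's spectrum.

import numpy as np

def stft_forward_sketch(x, n_fft, hop_length, normalized=False, onesided=True):
    # x: (N, T) real signals; mirrors Frames -> FFTR2C -> (optional) conj fill.
    n_frames = 1 + (x.shape[-1] - n_fft) // hop_length
    frames = np.stack(
        [x[:, i * hop_length:i * hop_length + n_fft] for i in range(n_frames)],
        axis=-1)  # (N, n_fft, num_frames)
    half = np.fft.rfft(frames, axis=1, norm='ortho' if normalized else None)
    if onesided:
        return half  # (N, n_fft//2 + 1, num_frames)
    # Rebuild the full spectrum from the conjugate-symmetric half.
    m = half.shape[1]
    tail = np.conj(half[:, 1:n_fft - m + 1, :][:, ::-1, :])
    return np.concatenate([half, tail], axis=1)  # (N, n_fft, num_frames)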
@@ -15,6 +15,7 @@
 #include "paddle/phi/kernels/pad3d_kernel.h"
 
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/complex.h"
 #include "paddle/phi/core/kernel_registry.h"
 
 namespace phi {
@@ -574,5 +575,13 @@ void Pad3dKernel(const Context& dev_ctx,
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    pad3d, CPU, ALL_LAYOUT, phi::Pad3dKernel, float, double, int, int64_t) {}
+PD_REGISTER_KERNEL(pad3d,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::Pad3dKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
@@ -19,6 +19,7 @@
 #include "paddle/fluid/platform/device/gpu/gpu_info.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/complex.h"
 #include "paddle/phi/core/kernel_registry.h"
 
 namespace phi {
@@ -585,4 +586,6 @@ PD_REGISTER_KERNEL(pad3d,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
@@ -266,9 +266,10 @@ def generate_activation_fn(op_type):
                                  op_type)
     else:
         # abs exp square ops support dtype(int32, int64, float16, float32, float64)
-        check_variable_and_dtype(
-            x, 'x', ['int32', 'int64', 'float16', 'float32', 'float64'],
-            op_type)
+        check_variable_and_dtype(x, 'x', [
+            'int32', 'int64', 'float16', 'float32', 'float64', 'complex64',
+            'complex128'
+        ], op_type)
 
     helper = LayerHelper(op_type, **locals())
...
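A hypothetical static-mode snippet that now passes this dtype check; paddle.abs is chosen because complex magnitude is well defined, and it is assumed here to dispatch to this fluid-layer check — whether other activations have complex kernels is decided per op, outside this hunk.

import numpy as np
import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name='x', shape=[3], dtype='complex64')
    y = paddle.abs(x)  # real-valued magnitude of a complex tensor

exe = paddle.static.Executor(paddle.CPUPlace())
out, = exe.run(main,
               feed={'x': np.full([3], 3 + 4j, dtype='complex64')},
               fetch_list=[y])
print(out)  # ~[5. 5. 5.]
paddle.disable_static()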
@@ -5616,9 +5616,10 @@ def transpose(x, perm, name=None):
         out, _ = _C_ops.transpose2(x, 'axis', perm)
         return out
 
-    check_variable_and_dtype(
-        x, 'x', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
-        'transpose')
+    check_variable_and_dtype(x, 'x', [
+        'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'complex64',
+        'complex128'
+    ], 'transpose')
     check_type(perm, 'perm', (list, tuple), 'transpose')
     if isinstance(perm, tuple):
         perm = list(perm)
return out return out
helper = LayerHelper("squeeze", **locals()) helper = LayerHelper("squeeze", **locals())
check_variable_and_dtype( check_variable_and_dtype(input, 'input', [
input, 'input', 'float16', 'float32', 'float64', 'bool', 'int8', 'int32', 'int64',
['float16', 'float32', 'float64', 'bool', 'int8', 'int32', 'int64'], 'complex64', 'complex128'
'squeeze') ], 'squeeze')
check_type(axes, 'axis/axes', (list, tuple), 'squeeze') check_type(axes, 'axis/axes', (list, tuple), 'squeeze')
out = helper.create_variable_for_type_inference(dtype=input.dtype) out = helper.create_variable_for_type_inference(dtype=input.dtype)
x_shape = helper.create_variable_for_type_inference(dtype=input.dtype) x_shape = helper.create_variable_for_type_inference(dtype=input.dtype)
@@ -6471,8 +6472,16 @@ def unsqueeze(input, axes, name=None):
     check_type(axes, 'axis/axes', (int, list, tuple, Variable), 'unsqueeze')
     check_variable_and_dtype(input, 'input', [
-        'float16', 'float32', 'float64', 'bool', 'int8', 'int16', 'int32',
-        'int64'
+        'float16',
+        'float32',
+        'float64',
+        'bool',
+        'int8',
+        'int16',
+        'int32',
+        'int64',
+        'complex64',
+        'complex128',
     ], 'unsqueeze')
     helper = LayerHelper("unsqueeze2", **locals())
     inputs = {"X": input}
...
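Taken together, the three relaxed checks let a complex Variable move through common layout ops at graph-construction time. A minimal sketch against the fluid-layer functions edited above (paddle.transpose re-exports the fluid implementation):

import paddle
import paddle.fluid as fluid

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    x = paddle.static.data(name='x', shape=[2, 1, 3], dtype='complex128')
    y = paddle.transpose(x, perm=[2, 1, 0])  # [3, 1, 2]
    y = fluid.layers.squeeze(y, axes=[1])    # [3, 2]
    y = fluid.layers.unsqueeze(y, axes=[0])  # [1, 3, 2]
paddle.disable_static()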
@@ -756,7 +756,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
     check_shape(shape)
     check_dtype(dtype, 'dtype', [
         'bool', 'float16', 'float32', 'float64', 'uint8', 'int16', 'int32',
-        'int64'
+        'int64', 'complex64', 'complex128'
     ], 'fill_constant')
     check_type(shape, 'shape', (Variable, list, tuple), 'fill_constant')
...
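With complex dtypes accepted by fill_constant, graph-time constants can stay complex end to end. A sketch, assuming paddle.full (which lowers to fill_constant) forwards a Python complex fill_value unchanged:

import paddle

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    c = paddle.full(shape=[2, 2], fill_value=1 + 2j, dtype='complex64')
    print(c.dtype)  # complex64
paddle.disable_static()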
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from numpy.lib.stride_tricks import as_strided
import paddle
import unittest

from op_test import OpTest


def frame_from_librosa(x, frame_length, hop_length, axis=-1):
    if axis == -1 and not x.flags["C_CONTIGUOUS"]:
        x = np.ascontiguousarray(x)
    elif axis == 0 and not x.flags["F_CONTIGUOUS"]:
        x = np.asfortranarray(x)

    n_frames = 1 + (x.shape[axis] - frame_length) // hop_length
    strides = np.asarray(x.strides)

    if axis == -1:
        shape = list(x.shape)[:-1] + [frame_length, n_frames]
        strides = list(strides) + [hop_length * x.itemsize]
    elif axis == 0:
        shape = [n_frames, frame_length] + list(x.shape)[1:]
        strides = [hop_length * x.itemsize] + list(strides)
    else:
        raise ValueError("Frame axis={} must be either 0 or -1".format(axis))

    return as_strided(x, shape=shape, strides=strides)


def stft_np(x, n_fft, hop_length, **kwargs):
    frames = frame_from_librosa(x, n_fft, hop_length)
    res = np.fft.rfft(frames, axis=1)
    return res


class TestStftOp(OpTest):
    def setUp(self):
        self.op_type = "stft"
        self.shape, self.type, self.attrs = self.initTestCase()
        self.inputs = {
            'X': np.random.random(size=self.shape).astype(self.type),
        }
        self.outputs = {'Out': stft_np(x=self.inputs['X'], **self.attrs)}

    def initTestCase(self):
        input_shape = (2, 100)
        input_type = 'float64'
        attrs = {
            'n_fft': 50,
            'hop_length': 15,
            'normalized': False,
            'onesided': True,
        }
        return input_shape, input_type, attrs

    def test_check_output(self):
        paddle.enable_static()
        self.check_output()
        paddle.disable_static()

    def test_check_grad_normal(self):
        paddle.enable_static()
        self.check_grad(['X'], 'Out')
        paddle.disable_static()


if __name__ == '__main__':
    unittest.main()
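Beyond OpTest, the NumPy reference above can be cross-checked against the public API in dygraph. An illustrative snippet, assuming center=False and a rectangular float64 window make paddle.signal.stft's framing match frame_from_librosa:

import numpy as np
import paddle

x = np.random.random((2, 100)).astype('float64')
ref = stft_np(x, n_fft=50, hop_length=15)
out = paddle.signal.stft(
    paddle.to_tensor(x), n_fft=50, hop_length=15, center=False,
    window=paddle.ones([50], dtype='float64')).numpy()
np.testing.assert_allclose(ref, out, rtol=1e-6, atol=1e-8)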
@@ -119,10 +119,11 @@ def frame(x, frame_length, hop_length, axis=-1, name=None):
             f'Unexpected hop_length: {hop_length}. It should be a positive integer.'
         )
 
-    if frame_length > x.shape[axis]:
-        raise ValueError(
-            f'Attribute frame_length should be less than or equal to the sequence length, '
-            f'but got ({frame_length}) > ({x.shape[axis]}).')
+    if in_dygraph_mode():
+        if frame_length > x.shape[axis]:
+            raise ValueError(
+                f'Attribute frame_length should be less than or equal to the sequence length, '
+                f'but got ({frame_length}) > ({x.shape[axis]}).')
 
     op_type = 'frame'
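That is, when the sequence length is unknown at graph-build time (a -1 dim), the check is deferred to FrameOp at runtime. A hypothetical static-mode graph that would previously have raised here:

import paddle

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    x = paddle.static.data(name='x', shape=[None], dtype='float32')
    # x.shape[-1] is -1 here, so the eager-only check above must not fire.
    frames = paddle.signal.frame(x, frame_length=512, hop_length=128)
paddle.disable_static()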
@@ -306,8 +307,7 @@ def stft(x,
             y1 = stft(x, n_fft=512, center=False, onesided=False)  # [8, 512, 372]
 
     """
-    check_variable_and_dtype(
-        x, 'x', ['float16', 'float32', 'float64', 'complex64', 'complex128'],
-        'stft')
+    check_variable_and_dtype(
+        x, 'x', ['float32', 'float64', 'complex64', 'complex128'], 'stft')
 
     x_rank = len(x.shape)
     assert x_rank in [1, 2], \
@@ -325,8 +325,9 @@ def stft(x,
     if win_length is None:
         win_length = n_fft
 
-    assert 0 < n_fft <= x.shape[-1], \
-        f'n_fft should be in (0, seq_length({x.shape[-1]})], but got {n_fft}.'
+    if in_dygraph_mode():
+        assert 0 < n_fft <= x.shape[-1], \
+            f'n_fft should be in (0, seq_length({x.shape[-1]})], but got {n_fft}.'
 
     assert 0 < win_length <= n_fft, \
         f'win_length should be in (0, n_fft({n_fft})], but got {win_length}.'
 
@@ -359,7 +360,7 @@ def stft(x,
     x_frames = x_frames.transpose(
         perm=[0, 2,
               1])  # switch n_fft to last dim, e.g. (batch, num_frames, n_fft)
-    x_frames = x_frames * window
+    x_frames = paddle.multiply(x_frames, window)
 
     norm = 'ortho' if normalized else 'backward'
     if is_complex(x_frames):
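This is the static-mode compatibility the PR title refers to: an stft graph over a dynamic-length batch now builds and yields a complex Variable. A sketch (center=False and an explicit float64 window keep the example on the code path shown above):

import paddle

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    x = paddle.static.data(name='x', shape=[8, None], dtype='float64')
    y = paddle.signal.stft(
        x, n_fft=512, hop_length=128, center=False,
        window=paddle.ones([512], dtype='float64'))
    print(y.dtype)  # complex128; the n_fft bound is re-checked at run time
paddle.disable_static()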
@@ -495,18 +496,22 @@ def istft(x,
     n_frames = x.shape[-1]
     fft_size = x.shape[-2]
 
-    if onesided:
-        assert (fft_size == n_fft // 2 + 1), \
-            'fft_size should be equal to n_fft // 2 + 1({}) when onesided is True, but got {}.'.format(n_fft // 2 + 1, fft_size)
-    else:
-        assert (fft_size == n_fft), \
-            'fft_size should be equal to n_fft({}) when onesided is False, but got {}.'.format(n_fft, fft_size)
+    if in_dygraph_mode():
+        if onesided:
+            assert (fft_size == n_fft // 2 + 1), \
+                'fft_size should be equal to n_fft // 2 + 1({}) when onesided is True, but got {}.'.format(n_fft // 2 + 1, fft_size)
+        else:
+            assert (fft_size == n_fft), \
+                'fft_size should be equal to n_fft({}) when onesided is False, but got {}.'.format(n_fft, fft_size)
 
     if window is not None:
         assert len(window.shape) == 1 and len(window) == win_length, \
             'expected a 1D window tensor of size equal to win_length({}), but got window with shape {}.'.format(win_length, window.shape)
     else:
-        window = paddle.ones(shape=(win_length, ))
+        window_dtype = paddle.float32 if x.dtype in [
+            paddle.float32, paddle.complex64
+        ] else paddle.float64
+        window = paddle.ones(shape=(win_length, ), dtype=window_dtype)
 
     if win_length < n_fft:
         pad_left = (n_fft - win_length) // 2
@@ -534,15 +539,15 @@ def istft(x,
         x = x[:, :, :n_fft // 2 + 1]
 
     out = fft_c2r(x=x, n=None, axis=-1, norm=norm, forward=False, name=None)
 
+    out = paddle.multiply(out, window).transpose(
+        perm=[0, 2, 1])  # (batch, n_fft, num_frames)
     out = overlap_add(
-        x=(out * window).transpose(
-            perm=[0, 2, 1]),  # (batch, n_fft, num_frames)
-        hop_length=hop_length,
-        axis=-1)  # (batch, seq_length)
+        x=out, hop_length=hop_length, axis=-1)  # (batch, seq_length)
     window_envelop = overlap_add(
         x=paddle.tile(
-            x=window * window, repeat_times=[n_frames, 1]).transpose(
+            x=paddle.multiply(window, window).unsqueeze(0),
+            repeat_times=[n_frames, 1]).transpose(
                 perm=[1, 0]),  # (n_fft, num_frames)
         hop_length=hop_length,
         axis=-1)  # (seq_length, )
@@ -561,7 +566,7 @@ def istft(x,
         window_envelop = window_envelop[start:start + length]
 
     # Check whether the Nonzero Overlap Add (NOLA) constraint is met.
-    if window_envelop.abs().min().item() < 1e-11:
+    if in_dygraph_mode() and window_envelop.abs().min().item() < 1e-11:
         raise ValueError(
             'Abort istft because Nonzero Overlap Add (NOLA) condition failed. For more information about NOLA constraint please see `scipy.signal.check_NOLA`(https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.check_NOLA.html).'
         )
...
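Since the NOLA guard can now only fire in dygraph (envelope values are unknown while a graph is being built), a window/hop pair intended for static-mode istft can be vetted up front, for example with SciPy:

from scipy import signal

# A Hann window of length 512 with hop 128 (overlap 384) satisfies NOLA.
assert signal.check_NOLA(signal.get_window('hann', 512), 512, 512 - 128)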
@@ -147,7 +147,9 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
         var_names = {'x': x, 'y': y}
         for name, val in var_names.items():
             check_variable_and_dtype(
-                val, name, ['float16', 'float32', 'float64'], 'matmul')
+                val, name,
+                ['float16', 'float32', 'float64', 'complex64', 'complex128'],
+                'matmul')
 
     __check_input(x, y)
...
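A minimal static-mode sketch exercising the widened matmul check; the complex GEMM kernels themselves are assumed to be registered separately:

import paddle

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    a = paddle.static.data(name='a', shape=[2, 3], dtype='complex64')
    b = paddle.static.data(name='b', shape=[3, 4], dtype='complex64')
    c = paddle.matmul(a, b)  # [2, 4], complex64
paddle.disable_static()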