// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/utils/data_type.h" namespace phi { namespace funcs { enum class FFTNormMode : int8_t { none, // No normalization by_sqrt_n, // Divide by sqrt(signal_size) by_n, // Divide by signal_size }; inline FFTNormMode get_norm_from_string(const std::string& norm, bool forward) { if (norm.empty() || norm == "backward") { return forward ? FFTNormMode::none : FFTNormMode::by_n; } if (norm == "forward") { return forward ? FFTNormMode::by_n : FFTNormMode::none; } if (norm == "ortho") { return FFTNormMode::by_sqrt_n; } PADDLE_THROW(phi::errors::InvalidArgument( "FFT norm string must be 'forward' or 'backward' or 'ortho', " "received %s", norm)); } enum class FFTTransformType : int8_t { C2C = 0, // Complex-to-complex R2C, // Real-to-complex C2R, // Complex-to-real }; // Create transform type enum from bools representing if input and output are // complex inline FFTTransformType GetFFTTransformType(DataType input_dtype, DataType output_dtype) { auto complex_input = IsComplexType(input_dtype); auto complex_output = IsComplexType(output_dtype); if (complex_input && complex_output) { return FFTTransformType::C2C; } else if (complex_input && !complex_output) { return FFTTransformType::C2R; } else if (!complex_input && complex_output) { return FFTTransformType::R2C; } PADDLE_THROW( phi::errors::InvalidArgument("Real to real FFTs are not supported")); } template struct FFTC2CFunctor { void operator()(const DeviceContext& ctx, const DenseTensor& X, DenseTensor* out, const std::vector& axes, FFTNormMode normalization, bool forward); }; template struct FFTR2CFunctor { void operator()(const DeviceContext& ctx, const DenseTensor& X, DenseTensor* out, const std::vector& axes, FFTNormMode normalization, bool forward); }; template struct FFTC2RFunctor { void operator()(const DeviceContext& ctx, const DenseTensor& X, DenseTensor* out, const std::vector& axes, FFTNormMode normalization, bool forward); }; } // namespace funcs } // namespace phi