Unverified commit 153f1138, authored by Feiyu Chan, committed by GitHub

move fft kernels to phi (#44714)

* move fft kernels to phi, done with cufft, pocketfft, mkl_cdft, hipfft
* make stft_op use fft from phi/kernels/funcs, clean code
Parent 0cbc870e
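As a quick orientation (an editor's sketch, not part of the commit): after this change, operator code such as StftKernel reaches the FFT machinery through phi::funcs rather than the fluid spectral helpers. The wrapper name fft_r2c_forward below is hypothetical; the types and the functor call signature are taken from paddle/phi/kernels/funcs/fft.h later in this diff.

#include <string>
#include <vector>

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/fft.h"

// Hypothetical helper mirroring the StftKernel changes: map the user-facing
// norm string to the internal enum, then delegate to the relocated R2C functor.
template <typename DeviceContext, typename T, typename C>
void fft_r2c_forward(const DeviceContext& dev_ctx,
                     const phi::DenseTensor& x,
                     phi::DenseTensor* out,
                     const std::vector<int64_t>& axes,
                     const std::string& norm) {
  const phi::funcs::FFTNormMode normalization =
      phi::funcs::get_norm_from_string(norm, /*forward=*/true);
  phi::funcs::FFTR2CFunctor<DeviceContext, T, C> fft_r2c_func;
  fft_r2c_func(dev_ctx, x, out, axes, normalization, /*forward=*/true);
}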
......@@ -104,7 +104,7 @@ endif()
set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils gather_scatter_kernel backward_infermeta)
register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op quantize_linear_op
recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS})
target_link_libraries(run_program_op cuda_graph_with_memory_pool)
......@@ -129,22 +129,6 @@ else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()
if (WITH_GPU OR WITH_ROCM)
if (MKL_FOUND AND WITH_ONEMKL)
op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda dynload_mklrt ${OP_HEADER_DEPS})
target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE})
else()
op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda ${OP_HEADER_DEPS})
endif()
else()
if (MKL_FOUND AND WITH_ONEMKL)
op_library(spectral_op SRCS spectral_op.cc DEPS dynload_mklrt ${OP_HEADER_DEPS})
target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE})
else()
op_library(spectral_op SRCS spectral_op.cc DEPS ${OP_HEADER_DEPS})
endif()
endif()
if (WITH_ASCEND_CL)
op_library(sync_batch_norm_op)
endif()
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/spectral_op.h"
#include "paddle/fluid/operators/spectral_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// FFTC2C
class FFTC2COpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), the input tensor of fft_c2c op.");
AddOutput("Out", "(Tensor), the output tensor of fft_c2c op.");
AddAttr<std::vector<int64_t>>("axes",
"std::vector<int64_t>, the fft axes.");
AddAttr<std::string>("normalization",
"fft_norm_type, the fft normalization type.");
AddAttr<bool>("forward", "bool, the fft direction.");
AddComment(R"DOC(
Compute complex to complex FFT.
)DOC");
}
};
class FFTC2COp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fft_c2c");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fft_c2c");
const auto axes = ctx->Attrs().Get<std::vector<int64_t>>("axes");
const auto x_dim = ctx->GetInputDim("X");
for (size_t i = 0; i < axes.size(); i++) {
PADDLE_ENFORCE_GT(x_dim[axes[i]],
0,
platform::errors::InvalidArgument(
"Invalid fft n-point (%d).", x_dim[axes[i]]));
}
ctx->ShareDim("X", /*->*/ "Out"); // only for c2c
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
const auto kernel_dtype = framework::ToRealType(in_dtype);
return framework::OpKernelType(kernel_dtype, ctx.GetPlace());
}
};
template <typename T>
class FFTC2CGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("fft_c2c_grad");
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};
class FFTC2CGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
const auto out_grad_name = framework::GradVarName("Out");
OP_INOUT_CHECK(
ctx->HasInput(out_grad_name), "Input", out_grad_name, "fft_c2c_grad");
const auto x_grad_name = framework::GradVarName("X");
OP_INOUT_CHECK(
ctx->HasOutput(x_grad_name), "Output", x_grad_name, "fft_c2c_grad");
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
const auto kernel_dtype = framework::ToRealType(in_dtype);
return framework::OpKernelType(kernel_dtype, ctx.GetPlace());
}
};
// FFTR2C
class FFTR2COpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), the input tensor of fft_r2c op.");
AddOutput("Out", "(Tensor), the output tensor of fft_r2c op.");
AddAttr<std::vector<int64_t>>("axes",
"std::vector<int64_t>, the fft axes.");
AddAttr<std::string>("normalization",
"fft_norm_type, the fft normalization type.");
AddAttr<bool>("forward", "bool, the fft direction.");
AddAttr<bool>("onesided", "bool, perform onesided fft.");
AddComment(R"DOC(
Compute real to complex FFT.
)DOC");
}
};
class FFTR2COp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fft_r2c");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fft_r2c");
const auto axes = ctx->Attrs().Get<std::vector<int64_t>>("axes");
const auto x_dim = ctx->GetInputDim("X");
for (size_t i = 0; i < axes.size() - 1L; i++) {
PADDLE_ENFORCE_GT(x_dim[axes[i]],
0,
platform::errors::InvalidArgument(
"Invalid fft n-point (%d).", x_dim[axes[i]]));
}
const bool onesided = ctx->Attrs().Get<bool>("onesided");
if (!onesided) {
ctx->ShareDim("X", /*->*/ "Out");
} else {
framework::DDim out_dim(ctx->GetInputDim("X"));
const int64_t last_fft_axis = axes.back();
const int64_t last_fft_dim_size = out_dim.at(last_fft_axis);
out_dim.at(last_fft_axis) = last_fft_dim_size / 2 + 1;
ctx->SetOutputDim("Out", out_dim);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(in_dtype, ctx.GetPlace());
}
};
template <typename T>
class FFTR2CGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("fft_r2c_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};
class FFTR2CGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
const auto out_grad_name = framework::GradVarName("Out");
OP_INOUT_CHECK(
ctx->HasInput(out_grad_name), "Input", out_grad_name, "fft_r2c_grad");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fft_r2c_grad");
const auto x_grad_name = framework::GradVarName("X");
OP_INOUT_CHECK(
ctx->HasOutput(x_grad_name), "Output", x_grad_name, "fft_r2c_grad");
ctx->ShareDim("X", /*->*/ x_grad_name);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
const auto kernel_dtype = framework::ToRealType(in_dtype);
return framework::OpKernelType(kernel_dtype, ctx.GetPlace());
}
};
// FFTC2R
class FFTC2ROpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), the input tensor of fft_c2r op.");
AddOutput("Out", "(Tensor), the output tensor of fft_c2r op.");
AddAttr<std::vector<int64_t>>("axes",
"std::vector<int64_t>, the fft axes.");
AddAttr<std::string>("normalization",
"fft_norm_type, the fft normalization type.");
AddAttr<bool>("forward", "bool, the fft direction.");
AddAttr<int64_t>(
"last_dim_size",
"int",
"Length of the transformed "
"axis of the output. For n output points, last_dim_size//2 + 1 input"
" points are necessary. If the input is longer than this,"
" it is cropped. If it is shorter than this, it is padded"
" with zeros. If last_dim_size is not given, it is taken to be 2*(m-1)"
" where m is the length of the input along the axis "
"specified by axis.")
.SetDefault(0L);
AddComment(R"DOC(
Compute complex to real FFT.
)DOC");
}
};
class FFTC2ROp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fft_c2r");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fft_c2r");
const auto axes = ctx->Attrs().Get<std::vector<int64_t>>("axes");
const auto x_dim = ctx->GetInputDim("X");
for (size_t i = 0; i < axes.size() - 1L; i++) {
PADDLE_ENFORCE_GT(x_dim[axes[i]],
0,
platform::errors::InvalidArgument(
"Invalid fft n-point (%d).", x_dim[axes[i]]));
}
const int64_t last_dim_size = ctx->Attrs().Get<int64_t>("last_dim_size");
framework::DDim out_dim(ctx->GetInputDim("X"));
const int64_t last_fft_axis = axes.back();
if (last_dim_size == 0) {
const int64_t last_fft_dim_size = out_dim.at(last_fft_axis);
const int64_t fft_n_point = (last_fft_dim_size - 1) * 2;
PADDLE_ENFORCE_GT(fft_n_point,
0,
platform::errors::InvalidArgument(
"Invalid fft n-point (%d).", fft_n_point));
out_dim.at(last_fft_axis) = fft_n_point;
} else {
PADDLE_ENFORCE_GT(last_dim_size,
0,
platform::errors::InvalidArgument(
"Invalid fft n-point (%d).", last_dim_size));
out_dim.at(last_fft_axis) = last_dim_size;
}
ctx->SetOutputDim("Out", out_dim);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
const auto kernel_dtype = framework::ToRealType(in_dtype);
return framework::OpKernelType(kernel_dtype, ctx.GetPlace());
}
};
template <typename T>
class FFTC2RGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("fft_c2r_grad");
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};
class FFTC2RGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
const auto out_grad_name = framework::GradVarName("Out");
OP_INOUT_CHECK(
ctx->HasInput(out_grad_name), "Input", out_grad_name, "fft_c2r_grad");
const auto x_grad_name = framework::GradVarName("X");
OP_INOUT_CHECK(
ctx->HasOutput(x_grad_name), "Output", x_grad_name, "fft_c2r_grad");
const auto axes = ctx->Attrs().Get<std::vector<int64_t>>("axes");
const auto out_grad_dim = ctx->GetInputDim(out_grad_name);
framework::DDim x_grad_dim(out_grad_dim);
const int64_t last_fft_axis = axes.back();
const int64_t last_fft_dim_size = x_grad_dim.at(last_fft_axis);
x_grad_dim.at(last_fft_axis) = last_fft_dim_size / 2 + 1;
ctx->SetOutputDim(x_grad_name, x_grad_dim);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto in_dtype = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(in_dtype, ctx.GetPlace());
}
};
// common functions
FFTNormMode get_norm_from_string(const std::string& norm, bool forward) {
if (norm.empty() || norm == "backward") {
return forward ? FFTNormMode::none : FFTNormMode::by_n;
}
if (norm == "forward") {
return forward ? FFTNormMode::by_n : FFTNormMode::none;
}
if (norm == "ortho") {
return FFTNormMode::by_sqrt_n;
}
PADDLE_THROW(platform::errors::InvalidArgument(
"FFT norm string must be 'forward' or 'backward' or 'ortho', "
"received %s",
norm));
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fft_c2c,
ops::FFTC2COp,
ops::FFTC2COpMaker,
ops::FFTC2CGradOpMaker<paddle::framework::OpDesc>,
ops::FFTC2CGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(fft_c2c,
ops::FFTC2CKernel<phi::CPUContext, float>,
ops::FFTC2CKernel<phi::CPUContext, double>);
REGISTER_OPERATOR(fft_c2c_grad, ops::FFTC2CGradOp);
REGISTER_OP_CPU_KERNEL(fft_c2c_grad,
ops::FFTC2CGradKernel<phi::CPUContext, float>,
ops::FFTC2CGradKernel<phi::CPUContext, double>);
REGISTER_OPERATOR(fft_r2c,
ops::FFTR2COp,
ops::FFTR2COpMaker,
ops::FFTR2CGradOpMaker<paddle::framework::OpDesc>,
ops::FFTR2CGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(fft_r2c,
ops::FFTR2CKernel<phi::CPUContext, float>,
ops::FFTR2CKernel<phi::CPUContext, double>);
REGISTER_OPERATOR(fft_r2c_grad, ops::FFTR2CGradOp);
REGISTER_OP_CPU_KERNEL(fft_r2c_grad,
ops::FFTR2CGradKernel<phi::CPUContext, float>,
ops::FFTR2CGradKernel<phi::CPUContext, double>);
REGISTER_OPERATOR(fft_c2r,
ops::FFTC2ROp,
ops::FFTC2ROpMaker,
ops::FFTC2RGradOpMaker<paddle::framework::OpDesc>,
ops::FFTC2RGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(fft_c2r,
ops::FFTC2RKernel<phi::CPUContext, float>,
ops::FFTC2RKernel<phi::CPUContext, double>);
REGISTER_OPERATOR(fft_c2r_grad, ops::FFTC2RGradOp);
REGISTER_OP_CPU_KERNEL(fft_c2r_grad,
ops::FFTC2RGradKernel<phi::CPUContext, float>,
ops::FFTC2RGradKernel<phi::CPUContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/spectral_op.h"
#include "paddle/fluid/operators/spectral_op.cu.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(fft_c2c,
ops::FFTC2CKernel<phi::GPUContext, float>,
ops::FFTC2CKernel<phi::GPUContext, double>);
REGISTER_OP_CUDA_KERNEL(fft_c2c_grad,
ops::FFTC2CGradKernel<phi::GPUContext, float>,
ops::FFTC2CGradKernel<phi::GPUContext, double>);
REGISTER_OP_CUDA_KERNEL(fft_c2r,
ops::FFTC2RKernel<phi::GPUContext, float>,
ops::FFTC2RKernel<phi::GPUContext, double>);
REGISTER_OP_CUDA_KERNEL(fft_c2r_grad,
ops::FFTC2RGradKernel<phi::GPUContext, float>,
ops::FFTC2RGradKernel<phi::GPUContext, double>);
REGISTER_OP_CUDA_KERNEL(fft_r2c,
ops::FFTR2CKernel<phi::GPUContext, float>,
ops::FFTR2CKernel<phi::GPUContext, double>);
REGISTER_OP_CUDA_KERNEL(fft_r2c_grad,
ops::FFTR2CGradKernel<phi::GPUContext, float>,
ops::FFTR2CGradKernel<phi::GPUContext, double>);
This diff is collapsed.
......@@ -14,8 +14,6 @@
#include "paddle/fluid/operators/stft_op.h"
#include "paddle/fluid/operators/spectral_helper.h"
namespace paddle {
namespace operators {
class StftOp : public framework::OperatorWithKernel {
......
......@@ -13,7 +13,6 @@
// limitations under the License.
#include "paddle/fluid/operators/stft_op.h"
#include "paddle/fluid/operators/spectral_op.cu.h"
namespace ops = paddle::operators;
......
......@@ -18,8 +18,11 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/spectral_op.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/funcs/fft.h"
#include "paddle/phi/kernels/funcs/fft_fill_conj.h"
#include "paddle/phi/kernels/funcs/frame_functor.h"
#include "paddle/phi/kernels/funcs/padding.h"
namespace paddle {
namespace operators {
......@@ -76,25 +79,25 @@ class StftKernel : public framework::OpKernel<T> {
ctx, &frames, window, axes.back(), MulFunctor<T>(), &frames_w);
// FFTR2C
FFTNormMode normalization;
phi::funcs::FFTNormMode normalization;
if (normalized) {
normalization = get_norm_from_string("ortho", true);
normalization = phi::funcs::get_norm_from_string("ortho", true);
} else {
normalization = get_norm_from_string("backward", true);
normalization = phi::funcs::get_norm_from_string("backward", true);
}
FFTR2CFunctor<DeviceContext, T, C> fft_r2c_func;
phi::funcs::FFTR2CFunctor<DeviceContext, T, C> fft_r2c_func;
if (onesided) {
fft_r2c_func(dev_ctx, &frames_w, out, axes, normalization, true);
fft_r2c_func(dev_ctx, frames_w, out, axes, normalization, true);
} else {
framework::DDim onesided_dims(out->dims());
const int64_t onesided_axis_size = out->dims().at(axes.back()) / 2 + 1;
onesided_dims.at(axes.back()) = onesided_axis_size;
Tensor onesided_out;
onesided_out.mutable_data<C>(onesided_dims, ctx.GetPlace());
fft_r2c_func(
dev_ctx, &frames_w, &onesided_out, axes, normalization, true);
fill_conj<DeviceContext, C>(dev_ctx, &onesided_out, out, axes);
fft_r2c_func(dev_ctx, frames_w, &onesided_out, axes, normalization, true);
phi::funcs::FFTFillConj<DeviceContext, C>(
dev_ctx, &onesided_out, out, axes);
}
}
};
......@@ -131,17 +134,17 @@ class StftGradKernel : public framework::OpKernel<T> {
complex_d_frames_w.mutable_data<C>(d_frames_dims, ctx.GetPlace());
// dy -> d_frames_w
FFTNormMode normalization;
phi::funcs::FFTNormMode normalization;
if (normalized) {
normalization = get_norm_from_string("ortho", true);
normalization = phi::funcs::get_norm_from_string("ortho", true);
} else {
normalization = get_norm_from_string("backward", true);
normalization = phi::funcs::get_norm_from_string("backward", true);
}
FFTC2CFunctor<DeviceContext, C, C> fft_c2c_func;
phi::funcs::FFTC2CFunctor<DeviceContext, C, C> fft_c2c_func;
if (!onesided) {
fft_c2c_func(
dev_ctx, dy, &complex_d_frames_w, axes, normalization, false);
dev_ctx, *dy, &complex_d_frames_w, axes, normalization, false);
} else {
Tensor full_dy;
full_dy.mutable_data<C>(d_frames_dims, ctx.GetPlace());
......@@ -153,20 +156,11 @@ class StftGradKernel : public framework::OpKernel<T> {
pads[axes.back() * 2 + 1] = zero_length;
phi::funcs::PaddingFunctor<DeviceContext, C>(
rank,
ctx.template device_context<DeviceContext>(),
pads,
static_cast<C>(0),
*dy,
&full_dy);
rank, dev_ctx, pads, static_cast<C>(0), *dy, &full_dy);
fft_c2c_func(
dev_ctx, &full_dy, &complex_d_frames_w, axes, normalization, false);
dev_ctx, full_dy, &complex_d_frames_w, axes, normalization, false);
}
framework::TransComplexToReal(
framework::TransToProtoVarType(d_frames_w.dtype()),
framework::TransToProtoVarType(complex_d_frames_w.dtype()),
complex_d_frames_w,
&d_frames_w);
phi::RealKernel<C>(dev_ctx, complex_d_frames_w, &d_frames_w);
// d_frames_w -> d_frames
Tensor d_frames;
......
......@@ -98,6 +98,33 @@
func : erf
backward : erf_grad
- api : fft_c2c
args : (Tensor x, int64_t[] axes, str normalization, bool forward)
output : Tensor
infer_meta :
func : FFTC2CInferMeta
kernel :
func : fft_c2c
backward : fft_c2c_grad
- api : fft_c2r
args : (Tensor x, int64_t[] axes, str normalization, bool forward, int64_t last_dim_size=0L)
output : Tensor
infer_meta :
func : FFTC2RInferMeta
kernel :
func : fft_c2r
backward : fft_c2r_grad
- api : fft_r2c
args : (Tensor x, int64_t[] axes, str normalization, bool forward, bool onesided)
output : Tensor
infer_meta :
func : FFTR2CInferMeta
kernel :
func : fft_r2c
backward : fft_r2c_grad
- api : lgamma
args : (Tensor x)
output : Tensor(out)
......
......@@ -31,6 +31,15 @@
float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false,
int workspace_size_MB = 512, bool exhaustive_search = false]
- api : conv2d
extra :
attrs : [bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false,
bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false,
str fuse_activation = "", bool fuse_alpha = false, bool fuse_beta = false, bool use_addto = false,
bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f,
float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false,
int workspace_size_MB = 512, bool exhaustive_search = false]
- api : cross
inputs :
{x : X, y : Y}
......@@ -112,3 +121,15 @@
x : X
outputs :
out : Out
- api: fft_c2c
inputs: {x: X}
outputs: {out: Out}
- api: fft_c2r
inputs: {x: X}
outputs: {out: Out}
- api: fft_r2c
inputs: {x: X}
outputs: {out: Out}
......@@ -105,6 +105,38 @@
func : erf_grad
data_type : out_grad
- backward_api : fft_c2c_grad
forward: fft_c2c(Tensor x, int64_t[] axes, str normalization, bool forward) -> Tensor(out)
args : (Tensor out_grad, int64_t[] axes, str normalization, bool forward)
output: Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [out_grad]
kernel :
func : fft_c2c_grad
- backward_api : fft_c2r_grad
forward: fft_c2r(Tensor x, int64_t[] axes, str normalization, bool forward, int64_t last_dim_size) -> Tensor(out)
args : (Tensor out_grad, int64_t[] axes, str normalization, bool forward, int64_t last_dim_size)
output: Tensor(x_grad)
infer_meta :
func : FFTC2RGradInferMeta
kernel :
func : fft_c2r_grad
data_type: out_grad
- backward_api : fft_r2c_grad
forward: fft_r2c(Tensor x, int64_t[] axes, str normalization, bool forward, bool onesided) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int64_t[] axes, str normalization, bool forward, bool onesided)
output: Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : fft_r2c_grad
data_type: out_grad
no_need_buffer: x
- backward_api : lgamma_grad
forward : lgamma(Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
......
......@@ -97,4 +97,18 @@ inline DataType ToComplexType(const DataType& type) {
type));
}
}
inline DataType ToRealType(const DataType& type) {
switch (type) {
case DataType::COMPLEX64:
return DataType::FLOAT32;
case DataType::COMPLEX128:
return DataType::FLOAT64;
default:
PADDLE_THROW(errors::Unimplemented(
"Can not transform data type (%s) to real type, now only support "
"complex64 and complex128 value.",
type));
}
}
} // namespace phi
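A quick illustration of the mapping added above (an editor's sketch, not part of the commit; the include path is an assumption based on where ToComplexType is declared):

#include <cassert>

#include "paddle/phi/core/utils/data_type.h"

int main() {
  using phi::DataType;
  // complex -> real, as defined by ToRealType above
  assert(phi::ToRealType(DataType::COMPLEX64) == DataType::FLOAT32);
  assert(phi::ToRealType(DataType::COMPLEX128) == DataType::FLOAT64);
  // the pre-existing inverse mapping
  assert(phi::ToComplexType(DataType::FLOAT32) == DataType::COMPLEX64);
  return 0;
}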
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
namespace phi {
......@@ -285,6 +286,47 @@ void EigvalshGradInferMeta(const MetaTensor& out_v,
}
}
void FFTC2RGradInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axes,
const std::string& normalization,
bool forward,
int64_t last_dim_size,
MetaTensor* out,
MetaConfig config) {
PADDLE_ENFORCE_NOT_NULL(out,
phi::errors::InvalidArgument(
"Output of fft_c2r _grad should not be null."));
const phi::DDim x_dim = x.dims();
// only ensure that the fft axes' sizes are greater than zero at runtime;
// they might be -1 to indicate unknown size at compile time
if (config.is_runtime) {
for (size_t i = 0; i < axes.size(); i++) {
PADDLE_ENFORCE_GT(x_dim[axes[i]],
0,
phi::errors::InvalidArgument(
"Invalid fft n-point (%d).", x_dim[axes[i]]));
}
}
out->set_layout(x.layout());
out->set_dtype(ToComplexType(x.dtype()));
phi::DDim out_dim = x.dims();
const int64_t last_fft_axis = axes.back();
if (last_dim_size > 0) {
out_dim.at(last_fft_axis) = last_dim_size / 2 + 1;
} else if (config.is_runtime) {
const int64_t last_fft_dim_size = x_dim[last_fft_axis];
out_dim.at(last_fft_axis) = last_fft_dim_size / 2 + 1;
} else {
const int64_t last_fft_dim_size = x_dim[last_fft_axis];
out_dim.at(last_fft_axis) =
last_fft_dim_size == -1 ? -1 : last_fft_dim_size / 2 + 1;
}
out->set_dims(out_dim);
}
void FillDiagonalGradInferMeta(const MetaTensor& dout,
float value,
int offset,
......
......@@ -137,6 +137,14 @@ void EigvalshGradInferMeta(const MetaTensor& out_v,
bool is_test,
MetaTensor* x_grad);
void FFTC2RGradInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axes,
const std::string& normalization,
bool forward,
int64_t last_dim_size,
MetaTensor* out,
MetaConfig = MetaConfig());
void FillDiagonalGradInferMeta(
const MetaTensor& dout, float value, int offset, bool wrap, MetaTensor* dx);
......
......@@ -866,6 +866,112 @@ void FillDiagonalInferMeta(
out->set_dtype(x.dtype());
}
void FFTC2CInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axes,
const std::string& normalization,
bool forward,
MetaTensor* out,
MetaConfig config) {
PADDLE_ENFORCE_NOT_NULL(
out,
phi::errors::InvalidArgument("Output of fft_c2c should not be null."));
// only ensure that the fft axes' sizes are greater than zero at runtime;
// they might be -1 to indicate unknown size at compile time
if (config.is_runtime) {
const phi::DDim x_dim = x.dims();
for (size_t i = 0; i < axes.size(); i++) {
PADDLE_ENFORCE_GT(x_dim[axes[i]],
0,
phi::errors::InvalidArgument(
"Invalid fft n-point (%d).", x_dim[axes[i]]));
}
}
out->share_meta(x);
}
void FFTC2RInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axes,
const std::string& normalization,
bool forward,
int64_t last_dim_size,
MetaTensor* out,
MetaConfig config) {
PADDLE_ENFORCE_NOT_NULL(
out,
phi::errors::InvalidArgument("Output of fft_c2r should not be null."));
const phi::DDim x_dim = x.dims();
const int64_t last_fft_axis = axes.back();
// only ensure that the fft axes' sizes are greater than zero at runtime;
// they might be -1 to indicate unknown size at compile time
if (config.is_runtime) {
size_t signal_dims = axes.size();
for (size_t i = 0; i < signal_dims - 1; i++) {
PADDLE_ENFORCE_GT(x_dim[axes[i]],
0,
phi::errors::InvalidArgument(
"Invalid fft n-point (%d).", x_dim[axes[i]]));
}
}
out->set_layout(x.layout());
out->set_dtype(ToRealType(x.dtype()));
phi::DDim out_dim = x_dim;
if (last_dim_size > 0) {
out_dim.at(last_fft_axis) = last_dim_size;
} else if (config.is_runtime) {
const int64_t input_last_dim_size = x_dim[last_fft_axis];
const int64_t fft_n_point = (input_last_dim_size - 1) * 2;
PADDLE_ENFORCE_GT(
fft_n_point,
0,
phi::errors::InvalidArgument("Invalid fft n-point (%d).", fft_n_point));
out_dim.at(last_fft_axis) = fft_n_point;
} else {
const int64_t input_last_dim_size = x_dim[last_fft_axis];
out_dim.at(last_fft_axis) =
input_last_dim_size == -1 ? -1 : (input_last_dim_size - 1) * 2;
}
out->set_dims(out_dim);
}
void FFTR2CInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axes,
const std::string& normalization,
bool forward,
bool onesided,
MetaTensor* out,
MetaConfig config) {
PADDLE_ENFORCE_NOT_NULL(
out,
phi::errors::InvalidArgument("Output of fft_r2c should not be null."));
const phi::DDim x_dim = x.dims();
// only ensure that the fft axes' sizes are greater than zero at runtime;
// they might be -1 to indicate unknown size at compile time
if (config.is_runtime) {
for (size_t i = 0; i < axes.size(); i++) {
PADDLE_ENFORCE_GT(x_dim[axes[i]],
0,
phi::errors::InvalidArgument(
"Invalid fft n-point (%d).", x_dim[axes[i]]));
}
}
out->set_layout(x.layout());
out->set_dtype(ToComplexType(x.dtype()));
if (!onesided) {
out->share_dims(x);
} else {
phi::DDim out_dim = x.dims();
const int64_t last_fft_axis = axes.back();
const int64_t last_fft_dim_size = x_dim[last_fft_axis];
out_dim.at(last_fft_axis) = last_fft_dim_size / 2 + 1;
out->set_dims(out_dim);
}
}
void FlattenInferMeta(const MetaTensor& x,
int start_axis,
int stop_axis,
......
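A small worked check of the r2c/c2r shape rules implemented above (illustrative only, assuming an even-length last axis): fft_r2c with onesided=true maps a real axis of length n to n/2 + 1 points, and fft_c2r with last_dim_size left at its default 0 maps that back to (n/2 + 1 - 1) * 2 = n.

#include <cassert>
#include <cstdint>

int main() {
  const int64_t n = 8;                                  // real input length on the last fft axis
  const int64_t onesided_len = n / 2 + 1;               // FFTR2CInferMeta with onesided=true -> 5
  const int64_t restored_len = (onesided_len - 1) * 2;  // FFTC2RInferMeta with last_dim_size=0 -> 8
  assert(onesided_len == 5);
  assert(restored_len == n);
  return 0;
}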
......@@ -135,6 +135,29 @@ void ExpandInferMeta(const MetaTensor& x,
void FillDiagonalInferMeta(
const MetaTensor& x, float value, int offset, bool wrap, MetaTensor* out);
void FFTC2CInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axes,
const std::string& normalization,
bool forward,
MetaTensor* out,
MetaConfig = MetaConfig());
void FFTC2RInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axes,
const std::string& normalization,
bool forward,
int64_t last_dim_size,
MetaTensor* out,
MetaConfig = MetaConfig());
void FFTR2CInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axes,
const std::string& normalization,
bool forward,
bool onesided,
MetaTensor* out,
MetaConfig = MetaConfig());
void FlattenInferMeta(const MetaTensor& x,
int start_axis,
int stop_axis,
......
......@@ -65,7 +65,8 @@ set(COMMON_KERNEL_DEPS
matrix_solve
phi_dynload_warpctc
sequence_padding
sequence_scale)
sequence_scale
fft)
set(COMMON_KERNEL_DEPS
${COMMON_KERNEL_DEPS}
......
......@@ -18,6 +18,7 @@
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h"
namespace phi {
......@@ -26,6 +27,16 @@ void AssignKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);
template <typename Context>
DenseTensor Assign(const Context& dev_ctx, const DenseTensor& x) {
DenseTensor out;
MetaTensor meta_out(&out);
MetaTensor meta_x(x);
UnchangedInferMeta(meta_x, &meta_out);
AssignKernel<Context>(dev_ctx, x, &out);
return out;
}
// In order to be compatible with the `AsDispensable` input in the original
// assign op maker, the input parameter here needs to be dispensable, but
// this looks weird
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fft_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/fft_grad_kernel_impl.h"
PD_REGISTER_KERNEL(fft_c2c_grad,
CPU,
ALL_LAYOUT,
phi::FFTC2CGradKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(
fft_c2r_grad, CPU, ALL_LAYOUT, phi::FFTC2RGradKernel, float, double) {}
PD_REGISTER_KERNEL(fft_r2c_grad,
CPU,
ALL_LAYOUT,
phi::FFTR2CGradKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fft_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/fft_kernel_impl.h"
PD_REGISTER_KERNEL(fft_c2c,
CPU,
ALL_LAYOUT,
phi::FFTC2CKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(fft_c2r,
CPU,
ALL_LAYOUT,
phi::FFTC2RKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(fft_r2c, CPU, ALL_LAYOUT, phi::FFTR2CKernel, float, double) {
}
......@@ -62,4 +62,6 @@ PD_REGISTER_KERNEL(scale,
int8_t,
int16_t,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
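// Editor's note (not in the original diff): complex<float> and complex<double>
// appear to be added here because the phi FFT normalization path
// (funcs::detail::exec_normalization further down) rescales complex FFT
// outputs through ScaleKernel.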
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void FFTC2CGradKernel(const Context& ctx,
const DenseTensor& out_grad,
const std::vector<int64_t>& axes,
const std::string& normalization,
bool forward,
DenseTensor* x_grad);
template <typename T, typename Context>
void FFTC2RGradKernel(const Context& ctx,
const DenseTensor& out_grad,
const std::vector<int64_t>& axes,
const std::string& normalization,
bool forward,
int64_t last_dim_size,
DenseTensor* x_grad);
template <typename T, typename Context>
void FFTR2CGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const std::vector<int64_t>& axes,
const std::string& normalization,
bool forward,
bool onesided,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void FFTC2CKernel(const Context& ctx,
const DenseTensor& x,
const std::vector<int64_t>& axes,
const std::string& normalization,
bool forward,
DenseTensor* out);
template <typename T, typename Context>
void FFTC2RKernel(const Context& ctx,
const DenseTensor& x,
const std::vector<int64_t>& axes,
const std::string& normalization,
bool forward,
int64_t last_dim_size,
DenseTensor* out);
template <typename T, typename Context>
void FFTR2CKernel(const Context& ctx,
const DenseTensor& x,
const std::vector<int64_t>& axes,
const std::string& normalization,
bool forward,
bool onesided,
DenseTensor* out);
} // namespace phi
......@@ -16,3 +16,20 @@ math_library(pooling DEPS dense_tensor)
math_library(segment_pooling)
math_library(sequence2batch)
math_library(matrix_solve DEPS dense_tensor eigen3 blas math_function)
if(WITH_GPU OR WITH_ROCM)
if(MKL_FOUND AND WITH_ONEMKL)
math_library(fft spectral_op.cu DEPS dynload_cuda dynload_mklrt
dense_tensor)
target_include_directories(fft PRIVATE ${MKL_INCLUDE})
else()
math_library(fft spectral_op.cu DEPS dynload_cuda dense_tensor pocketfft)
endif()
else()
if(MKL_FOUND AND WITH_ONEMKL)
math_library(fft DEPS dynload_mklrt dense_tensor)
target_include_directories(fft PRIVATE ${MKL_INCLUDE})
else()
math_library(fft DEPS dense_tensor pocketfft)
endif()
endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/backends/dynload/cufft.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/fft.h"
#include "paddle/phi/kernels/funcs/fft_key.h"
namespace phi {
namespace funcs {
namespace detail {
// An RAII encapsulation of cuFFTHandle
class CuFFTHandle {
public:
CuFFTHandle() {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cufftCreate(&handle_));
}
CuFFTHandle(const CuFFTHandle& other) = delete;
CuFFTHandle& operator=(const CuFFTHandle& other) = delete;
CuFFTHandle(CuFFTHandle&& other) = delete;
CuFFTHandle& operator=(CuFFTHandle&& other) = delete;
::cufftHandle& get() { return handle_; }
const ::cufftHandle& get() const { return handle_; }
~CuFFTHandle() { phi::dynload::cufftDestroy(handle_); }
private:
::cufftHandle handle_;
};
// Returns true if the transform type has complex input
inline bool has_complex_input(FFTTransformType type) {
switch (type) {
case FFTTransformType::C2C:
case FFTTransformType::C2R:
return true;
case FFTTransformType::R2C:
return false;
}
PADDLE_THROW(phi::errors::InvalidArgument("Unknown FFTTransformType"));
}
// Returns true if the transform type has complex output
inline bool has_complex_output(FFTTransformType type) {
switch (type) {
case FFTTransformType::C2C:
case FFTTransformType::R2C:
return true;
case FFTTransformType::C2R:
return false;
}
PADDLE_THROW(phi::errors::InvalidArgument("Unknown FFTTransformType"));
}
class FFTConfig {
public:
using plan_size_type = long long int; // NOLINT (be consistent with cufft)
explicit FFTConfig(const FFTConfigKey& key)
: FFTConfig(
std::vector<int64_t>(key.sizes_, key.sizes_ + key.signal_ndim_ + 1),
key.fft_type_,
key.value_type_) {}
// sizes are full signal, including batch size and always two-sided
FFTConfig(const std::vector<int64_t>& sizes,
FFTTransformType fft_type,
DataType precision)
: fft_type_(fft_type), precision_(precision) {
const auto batch_size = static_cast<plan_size_type>(sizes[0]);
std::vector<plan_size_type> signal_sizes(sizes.cbegin() + 1, sizes.cend());
const int signal_ndim = sizes.size() - 1;
cudaDataType itype, otype, exec_type;
const bool complex_input = has_complex_input(fft_type);
const bool complex_output = has_complex_output(fft_type);
if (precision == DataType::FLOAT32) {
itype = complex_input ? CUDA_C_32F : CUDA_R_32F;
otype = complex_output ? CUDA_C_32F : CUDA_R_32F;
exec_type = CUDA_C_32F;
} else if (precision == DataType::FLOAT64) {
itype = complex_input ? CUDA_C_64F : CUDA_R_64F;
otype = complex_output ? CUDA_C_64F : CUDA_R_64F;
exec_type = CUDA_C_64F;
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Only transforms of type float32 and float64 are supported."));
}
// disable auto allocation of workspace to use allocator from the framework
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cufftSetAutoAllocation(plan(), /* autoAllocate */ 0));
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cufftXtMakePlanMany(plan(),
signal_ndim,
signal_sizes.data(),
/* inembed */ nullptr,
/* base_istride */ 1L,
/* idist */ 1L,
itype,
/* onembed */ nullptr,
/* base_ostride */ 1L,
/* odist */ 1L,
otype,
batch_size,
&ws_size_,
exec_type));
}
FFTConfig(const FFTConfig& other) = delete;
FFTConfig& operator=(const FFTConfig& other) = delete;
FFTConfig(FFTConfig&& other) = delete;
FFTConfig& operator=(FFTConfig&& other) = delete;
const cufftHandle& plan() const { return plan_.get(); }
FFTTransformType transform_type() const { return fft_type_; }
DataType data_type() const { return precision_; }
size_t workspace_size() const { return ws_size_; }
private:
CuFFTHandle plan_;
size_t ws_size_; // workspace size in bytes
FFTTransformType fft_type_;
DataType precision_;
};
// NOTE: R2C is forward-only, C2R is backward only
static void exec_plan(const FFTConfig& config,
void* in_data,
void* out_data,
bool forward) {
auto& plan = config.plan();
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cufftXtExec(
plan, in_data, out_data, forward ? CUFFT_FORWARD : CUFFT_INVERSE));
}
} // namespace detail
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include "paddle/phi/kernels/funcs/fft.h"
#include "paddle/phi/kernels/funcs/fft_cache.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/assign_kernel.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/scale_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
namespace phi {
namespace funcs {
namespace detail {
// Use the optimized path to perform single R2C or C2R if transformation dim is
// supported by cuFFT
static bool use_optimized_fft_path(const std::vector<int64_t>& axes) {
// For performance reason, when axes starts with (0, 1), do not use the
// optimized path.
if (axes.size() > kMaxFFTNdim ||
(axes.size() >= 2 && axes[0] == 0 && axes[1] == 1)) {
return false;
} else {
return true;
}
}
static double fft_normalization_scale(FFTNormMode normalization,
const std::vector<int64_t>& sizes,
const std::vector<int64_t>& dims) {
// auto norm = static_cast<fft_norm_mode>(normalization);
if (normalization == FFTNormMode::none) {
return static_cast<double>(1.0);
}
int64_t signal_numel = 1;
for (auto dim : dims) {
signal_numel *= sizes[dim];
}
const double scale_denom = (normalization == FFTNormMode::by_sqrt_n)
? std::sqrt(signal_numel)
: static_cast<double>(signal_numel);
return static_cast<double>(1.0 / scale_denom);
}
template <typename T>
void exec_normalization(const phi::GPUContext& ctx,
const DenseTensor& in,
DenseTensor* out,
FFTNormMode normalization,
const std::vector<int64_t>& sizes,
const std::vector<int64_t>& axes) {
const double scale = fft_normalization_scale(normalization, sizes, axes);
if (scale != 1.0) {
ScaleKernel<T, phi::GPUContext>(ctx, in, scale, 0, true, out);
} else {
AssignKernel<phi::GPUContext>(ctx, in, out);
}
}
bool has_large_prime_factor(int64_t n) {
constexpr int64_t first_large_prime = 11;
const std::array<int64_t, 4> prime_radices{{2, 3, 5, 7}};
for (auto prime : prime_radices) {
if (n < first_large_prime) {
return false;
}
while (n % prime == 0) {
n /= prime;
}
}
return n != 1;
}
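// Editor's note (illustrative, not in the original source): 120 = 2^3 * 3 * 5
// contains only radices that cuFFT handles natively, so it returns false,
// while 22 = 2 * 11 leaves the residual factor 11 and returns true.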
#if defined(PADDLE_WITH_CUDA)
inline bool use_cache(const int64_t* signal_size) {
bool using_cache = true;
int cufft_version;
phi::dynload::cufftGetVersion(&cufft_version);
if (10300 <= cufft_version && cufft_version <= 10400) {
using_cache = std::none_of(
signal_size + 1, signal_size + kMaxDataNdim, [](int64_t dim_size) {
return has_large_prime_factor(dim_size);
});
}
return using_cache;
}
#elif defined(PADDLE_WITH_HIP)
inline bool use_cache(const int64_t* signal_size) { return true; }
#endif
// up to 3d unnormalized fft transform (c2r, r2c, c2c)
template <typename Ti, typename To>
void exec_fft(const phi::GPUContext& ctx,
const DenseTensor& x,
DenseTensor* out,
const std::vector<int64_t>& axes,
bool forward) {
const phi::DDim& in_sizes = x.dims();
const int ndim = in_sizes.size();
const int signal_ndim = axes.size();
const int batch_ndim = ndim - signal_ndim;
const phi::DDim& out_sizes = out->dims();
// make a dim permutation
std::vector<int> dim_permute(ndim);
std::iota(dim_permute.begin(), dim_permute.end(), 0);
std::vector<bool> is_transformed_dim(ndim, false);
for (const auto& d : axes) {
is_transformed_dim[d] = true;
}
const auto batch_end =
std::partition(dim_permute.begin(), dim_permute.end(), [&](size_t axis) {
return !is_transformed_dim[axis];
});
std::copy(axes.cbegin(), axes.cend(), batch_end);
// transpose input according to the permutation
DenseTensor transposed_input =
Transpose<Ti, phi::GPUContext>(ctx, x, dim_permute);
const phi::DDim transposed_input_shape = transposed_input.dims();
// batch size
int64_t batch_size = 1L;
for (int i = 0; i < batch_ndim; i++) {
batch_size *= transposed_input_shape[i];
}
// make an collapsed input: collapse batch axes for input
std::vector<int64_t> collapsed_input_shape_;
collapsed_input_shape_.reserve(1 + signal_ndim);
collapsed_input_shape_.emplace_back(batch_size);
for (int i = 0; i < signal_ndim; i++) {
collapsed_input_shape_.push_back(in_sizes[axes[i]]);
}
phi::DDim collapsed_input_shape = phi::make_ddim(collapsed_input_shape_);
transposed_input.Resize(collapsed_input_shape);
DenseTensor& collapsed_input = transposed_input;
// make a collapsed output
phi::DDim transposed_output_shape = out_sizes.transpose(dim_permute);
std::vector<int64_t> collapsed_output_shape_;
collapsed_output_shape_.reserve(1 + signal_ndim);
collapsed_output_shape_.emplace_back(batch_size);
for (int i = 0; i < signal_ndim; i++) {
collapsed_output_shape_.push_back(out_sizes[axes[i]]);
}
phi::DDim collapsed_output_shape = phi::make_ddim(collapsed_output_shape_);
DenseTensor collapsed_output;
collapsed_output.Resize(collapsed_output_shape);
ctx.Alloc<To>(&collapsed_output);
FFTConfigKey key =
create_fft_configkey(collapsed_input, collapsed_output, signal_ndim);
int64_t device_id = ctx.GetPlace().GetDeviceId();
FFTConfig* config = nullptr;
std::unique_ptr<FFTConfig> config_ = nullptr;
bool using_cache = use_cache(key.sizes_);
if (using_cache) {
FFTConfigCache& plan_cache = get_fft_plan_cache(device_id);
std::unique_lock<std::mutex> guard(plan_cache.mutex, std::defer_lock);
guard.lock();
config = &(plan_cache.lookup(key));
} else {
config_ = std::make_unique<FFTConfig>(key);
config = config_.get();
}
const int64_t workspace_size = static_cast<int64_t>(config->workspace_size());
DenseTensor workspace_tensor = Empty<uint8_t>(ctx, {workspace_size});
// prepare cufft for execution
#if defined(PADDLE_WITH_CUDA)
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cufftSetStream(config->plan(), ctx.stream()));
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cufftSetWorkArea(config->plan(), workspace_tensor.data()));
#elif defined(PADDLE_WITH_HIP)
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::hipfftSetStream(config->plan(), ctx.stream()));
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::hipfftSetWorkArea(config->plan(), workspace_tensor.data()));
#endif
// execution of fft plan
const FFTTransformType fft_type = config->transform_type();
if (fft_type == FFTTransformType::C2R && forward) {
ConjKernel<Ti, phi::GPUContext>(ctx, collapsed_input, &collapsed_input);
exec_plan(*config, collapsed_input.data(), collapsed_output.data(), false);
} else if (fft_type == FFTTransformType::R2C && !forward) {
exec_plan(*config, collapsed_input.data(), collapsed_output.data(), true);
ConjKernel<To, phi::GPUContext>(ctx, collapsed_output, &collapsed_output);
} else {
exec_plan(
*config, collapsed_input.data(), collapsed_output.data(), forward);
}
// resize for the collapsed output
collapsed_output.Resize(transposed_output_shape);
phi::DenseTensor& transposed_output = collapsed_output;
// reverse the transposition
std::vector<int> reverse_dim_permute(ndim);
for (int i = 0; i < ndim; i++) {
reverse_dim_permute[dim_permute[i]] = i;
}
TransposeKernel<To, phi::GPUContext>(
ctx, transposed_output, reverse_dim_permute, out);
}
} // namespace detail
template <typename Ti, typename To>
struct FFTC2CFunctor<phi::GPUContext, Ti, To> {
void operator()(const phi::GPUContext& ctx,
const DenseTensor& x,
DenseTensor* out,
const std::vector<int64_t>& axes,
FFTNormMode normalization,
bool forward) {
if (axes.empty()) {
AssignKernel<phi::GPUContext>(ctx, x, out);
return;
}
std::vector<int64_t> working_axes = axes;
std::sort(working_axes.begin(), working_axes.end());
std::vector<int64_t> first_dims;
size_t max_dims;
DenseTensor working_tensor = x; // shallow copy
while (true) {
max_dims = std::min(static_cast<size_t>(detail::kMaxFFTNdim),
working_axes.size());
first_dims.assign(working_axes.end() - max_dims, working_axes.end());
detail::exec_fft<Ti, To>(ctx, working_tensor, out, first_dims, forward);
working_axes.resize(working_axes.size() - max_dims);
first_dims.clear();
if (working_axes.empty()) {
break;
}
if (working_tensor.IsSharedWith(x)) {
working_tensor = std::move(*out);
*out = EmptyLike<Ti>(ctx, x);
} else {
std::swap(*out, working_tensor);
}
}
std::vector<int64_t> out_dims = phi::vectorize(x.dims());
detail::exec_normalization<To>(
ctx, *out, out, normalization, out_dims, axes);
}
};
template <typename Ti, typename To>
struct FFTC2RFunctor<phi::GPUContext, Ti, To> {
void operator()(const phi::GPUContext& ctx,
const DenseTensor& x,
DenseTensor* out,
const std::vector<int64_t>& axes,
FFTNormMode normalization,
bool forward) {
std::vector<int64_t> out_dims = phi::vectorize(out->dims());
if (detail::use_optimized_fft_path(axes)) {
DenseTensor x_copy = Assign(ctx, x);
detail::exec_fft<Ti, To>(ctx, x_copy, out, axes, forward);
} else {
DenseTensor c2c_result = EmptyLike<Ti, phi::GPUContext>(ctx, x);
FFTC2CFunctor<phi::GPUContext, Ti, Ti> c2c_functor;
c2c_functor(ctx,
x,
&c2c_result,
{axes.begin(), axes.end() - 1},
FFTNormMode::none,
forward);
detail::exec_fft<Ti, To>(ctx, c2c_result, out, {axes.back()}, forward);
}
detail::exec_normalization<To>(
ctx, *out, out, normalization, out_dims, axes);
}
};
template <typename Ti, typename To>
struct FFTR2CFunctor<phi::GPUContext, Ti, To> {
void operator()(const phi::GPUContext& ctx,
const DenseTensor& x,
DenseTensor* out,
const std::vector<int64_t>& axes,
FFTNormMode normalization,
bool forward) {
if (detail::use_optimized_fft_path(axes)) {
detail::exec_fft<Ti, To>(ctx, x, out, axes, forward);
} else {
DenseTensor r2c_result = EmptyLike<To, phi::GPUContext>(ctx, *out);
detail::exec_fft<Ti, To>(ctx, x, &r2c_result, {axes.back()}, forward);
FFTC2CFunctor<phi::GPUContext, To, To> fft_c2c_func;
fft_c2c_func(ctx,
r2c_result,
out,
{axes.begin(), axes.end() - 1},
FFTNormMode::none,
forward);
}
const auto in_dims = phi::vectorize(x.dims());
detail::exec_normalization<To>(
ctx, *out, out, normalization, in_dims, axes);
}
};
using complex64_t = phi::dtype::complex<float>;
using complex128_t = phi::dtype::complex<double>;
template struct FFTC2CFunctor<phi::GPUContext, complex64_t, complex64_t>;
template struct FFTC2CFunctor<phi::GPUContext, complex128_t, complex128_t>;
template struct FFTC2RFunctor<phi::GPUContext, complex64_t, float>;
template struct FFTC2RFunctor<phi::GPUContext, complex128_t, double>;
template struct FFTR2CFunctor<phi::GPUContext, float, complex64_t>;
template struct FFTR2CFunctor<phi::GPUContext, double, complex128_t>;
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
namespace funcs {
enum class FFTNormMode : int8_t {
none, // No normalization
by_sqrt_n, // Divide by sqrt(signal_size)
by_n, // Divide by signal_size
};
inline FFTNormMode get_norm_from_string(const std::string& norm, bool forward) {
if (norm.empty() || norm == "backward") {
return forward ? FFTNormMode::none : FFTNormMode::by_n;
}
if (norm == "forward") {
return forward ? FFTNormMode::by_n : FFTNormMode::none;
}
if (norm == "ortho") {
return FFTNormMode::by_sqrt_n;
}
PADDLE_THROW(phi::errors::InvalidArgument(
"FFT norm string must be 'forward' or 'backward' or 'ortho', "
"received %s",
norm));
}
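// Editor's note (illustrative, not in the original source): the mapping above
// follows the numpy-style convention, e.g.
//   get_norm_from_string("backward", /*forward=*/true)  -> FFTNormMode::none
//   get_norm_from_string("backward", /*forward=*/false) -> FFTNormMode::by_n
//   get_norm_from_string("ortho",    either direction)  -> FFTNormMode::by_sqrt_n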
enum class FFTTransformType : int8_t {
C2C = 0, // Complex-to-complex
R2C, // Real-to-complex
C2R, // Complex-to-real
};
// Create transform type enum from bools representing if input and output are
// complex
inline FFTTransformType GetFFTTransformType(DataType input_dtype,
DataType output_dtype) {
auto complex_input = IsComplexType(input_dtype);
auto complex_output = IsComplexType(output_dtype);
if (complex_input && complex_output) {
return FFTTransformType::C2C;
} else if (complex_input && !complex_output) {
return FFTTransformType::C2R;
} else if (!complex_input && complex_output) {
return FFTTransformType::R2C;
}
PADDLE_THROW(
phi::errors::InvalidArgument("Real to real FFTs are not supported"));
}
template <typename DeviceContext, typename Ti, typename To>
struct FFTC2CFunctor {
void operator()(const DeviceContext& ctx,
const DenseTensor& X,
DenseTensor* out,
const std::vector<int64_t>& axes,
FFTNormMode normalization,
bool forward);
};
template <typename DeviceContext, typename Ti, typename To>
struct FFTR2CFunctor {
void operator()(const DeviceContext& ctx,
const DenseTensor& X,
DenseTensor* out,
const std::vector<int64_t>& axes,
FFTNormMode normalization,
bool forward);
};
template <typename DeviceContext, typename Ti, typename To>
struct FFTC2RFunctor {
void operator()(const DeviceContext& ctx,
const DenseTensor& X,
DenseTensor* out,
const std::vector<int64_t>& axes,
FFTNormMode normalization,
bool forward);
};
} // namespace funcs
} // namespace phi
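A minimal sketch (editor's addition, under assumptions) of what an fft_c2c kernel built on the functor declared above could look like. The real implementation lives in paddle/phi/kernels/impl/fft_kernel_impl.h, which is not shown in this diff, and the name FFTC2CKernelSketch is hypothetical.

#include <string>
#include <vector>

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/fft.h"

namespace phi {

// Allocate the (complex-typed) output, then delegate to funcs::FFTC2CFunctor.
template <typename T, typename Context>
void FFTC2CKernelSketch(const Context& ctx,
                        const DenseTensor& x,
                        const std::vector<int64_t>& axes,
                        const std::string& normalization,
                        bool forward,
                        DenseTensor* out) {
  ctx.template Alloc<T>(out);
  const auto norm_mode = funcs::get_norm_from_string(normalization, forward);
  funcs::FFTC2CFunctor<Context, T, T> fft_c2c_func;
  fft_c2c_func(ctx, x, out, axes, norm_mode, forward);
}

}  // namespace phi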
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <limits>
#include <list>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>
#if defined(PADDLE_WITH_CUDA)
#include "paddle/phi/kernels/funcs/cufft_util.h"
#elif defined(PADDLE_WITH_HIP)
#include "paddle/phi/kernels/funcs/hipfft_util.h"
#endif
namespace phi {
namespace funcs {
namespace detail {
#if CUDA_VERSION < 10000
// Note that the max plan number for CUDA version < 10 has to be 1023
// due to a bug that fails on the 1024th plan
constexpr size_t CUFFT_MAX_PLAN_NUM = 1023;
constexpr size_t CUFFT_DEFAULT_CACHE_SIZE = CUFFT_MAX_PLAN_NUM;
#else
constexpr size_t CUFFT_MAX_PLAN_NUM = std::numeric_limits<size_t>::max();
// The default max cache size chosen for CUDA version > 10 is arbitrary.
// This number puts a limit on how big of a plan cache should we maintain by
// default. Users can always configure it via cufft_set_plan_cache_max_size.
constexpr size_t CUFFT_DEFAULT_CACHE_SIZE = 4096;
#endif
static_assert(CUFFT_MAX_PLAN_NUM >= 0 &&
CUFFT_MAX_PLAN_NUM <= std::numeric_limits<size_t>::max(),
"CUFFT_MAX_PLAN_NUM not in size_t range");
static_assert(CUFFT_DEFAULT_CACHE_SIZE >= 0 &&
CUFFT_DEFAULT_CACHE_SIZE <= CUFFT_MAX_PLAN_NUM,
"CUFFT_DEFAULT_CACHE_SIZE not in [0, CUFFT_MAX_PLAN_NUM] range");
class FFTConfigCache {
public:
using kv_t = typename std::pair<FFTConfigKey, FFTConfig>;
using map_t =
typename std::unordered_map<std::reference_wrapper<FFTConfigKey>,
typename std::list<kv_t>::iterator,
KeyHash<FFTConfigKey>,
KeyEqual<FFTConfigKey>>;
using map_kkv_iter_t = typename map_t::iterator;
FFTConfigCache() : FFTConfigCache(CUFFT_DEFAULT_CACHE_SIZE) {}
explicit FFTConfigCache(int64_t max_size) { _set_max_size(max_size); }
FFTConfigCache(const FFTConfigCache& other) = delete;
FFTConfigCache& operator=(const FFTConfigCache& other) = delete;
FFTConfigCache(FFTConfigCache&& other) noexcept
: _usage_list(std::move(other._usage_list)),
_cache_map(std::move(other._cache_map)),
_max_size(other._max_size) {}
FFTConfigCache& operator=(FFTConfigCache&& other) noexcept {
_usage_list = std::move(other._usage_list);
_cache_map = std::move(other._cache_map);
_max_size = other._max_size;
return *this;
}
// If key is in this cache, return the cached config. Otherwise, emplace the
// config in this cache and return it.
FFTConfig& lookup(FFTConfigKey params) {
PADDLE_ENFORCE_GT(_max_size,
0,
phi::errors::InvalidArgument(
"The max size of FFTConfigCache must be great than 0,"
"But received is [%d]",
_max_size));
map_kkv_iter_t map_it = _cache_map.find(params);
// Hit, put to list front
if (map_it != _cache_map.end()) {
_usage_list.splice(_usage_list.begin(), _usage_list, map_it->second);
return map_it->second->second;
}
// Miss
// remove if needed
if (_usage_list.size() >= _max_size) {
auto last = _usage_list.end();
last--;
_cache_map.erase(last->first);
_usage_list.pop_back();
}
// construct new plan at list front, then insert into _cache_map
_usage_list.emplace_front(std::piecewise_construct,
std::forward_as_tuple(params),
std::forward_as_tuple(params));
auto kv_it = _usage_list.begin();
_cache_map.emplace(std::piecewise_construct,
std::forward_as_tuple(kv_it->first),
std::forward_as_tuple(kv_it));
return kv_it->second;
}
void clear() {
_cache_map.clear();
_usage_list.clear();
}
void resize(int64_t new_size) {
_set_max_size(new_size);
auto cur_size = _usage_list.size();
if (cur_size > _max_size) {
auto delete_it = _usage_list.end();
for (size_t i = 0; i < cur_size - _max_size; i++) {
delete_it--;
_cache_map.erase(delete_it->first);
}
_usage_list.erase(delete_it, _usage_list.end());
}
}
size_t size() const { return _cache_map.size(); }
size_t max_size() const noexcept { return _max_size; }
std::mutex mutex;
private:
// Only sets size and does value check. Does not resize the data structures.
void _set_max_size(int64_t new_size) {
// We check that 0 <= new_size <= CUFFT_MAX_PLAN_NUM here. Since
// CUFFT_MAX_PLAN_NUM is of type size_t, we need to do non-negativity check
// first.
PADDLE_ENFORCE_GE(
new_size,
0,
phi::errors::InvalidArgument(
"cuFFT plan cache size must be non-negative, But received is [%d]",
new_size));
PADDLE_ENFORCE_LE(new_size,
CUFFT_MAX_PLAN_NUM,
phi::errors::InvalidArgument(
"cuFFT plan cache size can not be larger than [%d], "
"But received is [%d]",
CUFFT_MAX_PLAN_NUM,
new_size));
_max_size = static_cast<size_t>(new_size);
}
std::list<kv_t> _usage_list;
map_t _cache_map;
size_t _max_size;
};
static std::vector<std::unique_ptr<FFTConfigCache>> plan_caches;
static std::mutex plan_caches_mutex;
static inline FFTConfigCache& get_fft_plan_cache(int64_t device_index) {
std::lock_guard<std::mutex> guard(plan_caches_mutex);
if (device_index >= plan_caches.size()) {
plan_caches.resize(device_index + 1);
}
if (!plan_caches[device_index]) {
plan_caches[device_index] = std::make_unique<FFTConfigCache>();
}
return *plan_caches[device_index];
}
} // namespace detail
} // namespace funcs
} // namespace phi
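The FFTConfigCache above is an LRU cache of FFT plans: a std::list keeps (key, config) entries ordered from most to least recently used, and an unordered_map points from key to list iterator, so a hit is an O(1) splice to the front and a miss may evict the entry at the back. The following is only an illustrative, self-contained sketch of that list-plus-map technique; LruCache, Lookup, and the string values are made-up names, not part of this diff or of Paddle.
#include <iostream>
#include <list>
#include <string>
#include <unordered_map>
#include <utility>

template <typename Key, typename Value>
class LruCache {
 public:
  explicit LruCache(size_t max_size) : max_size_(max_size) {}
  // Return the cached value for `key`, building it with `make` on a miss.
  template <typename Factory>
  Value& Lookup(const Key& key, Factory make) {
    auto it = index_.find(key);
    if (it != index_.end()) {
      // Hit: move the entry to the front (most recently used); splicing a
      // std::list node does not invalidate the iterators stored in the map.
      entries_.splice(entries_.begin(), entries_, it->second);
      return it->second->second;
    }
    // Miss: evict the least recently used entry if the cache is full.
    if (entries_.size() >= max_size_) {
      index_.erase(entries_.back().first);
      entries_.pop_back();
    }
    entries_.emplace_front(key, make(key));
    index_[key] = entries_.begin();
    return entries_.front().second;
  }

 private:
  std::list<std::pair<Key, Value>> entries_;  // MRU order, front is newest
  std::unordered_map<Key,
                     typename std::list<std::pair<Key, Value>>::iterator>
      index_;
  size_t max_size_;
};

int main() {
  LruCache<int, std::string> cache(2);
  auto build = [](int k) { return "plan-" + std::to_string(k); };
  cache.Lookup(1, build);
  cache.Lookup(2, build);
  cache.Lookup(1, build);  // hit: key 1 becomes most recently used
  cache.Lookup(3, build);  // cache full: evicts key 2
  std::cout << cache.Lookup(1, build) << std::endl;  // prints "plan-1"
  return 0;
}
The real cache additionally guards lookups with a per-cache mutex and keys the map by std::reference_wrapper into the list node so the key is not stored twice.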
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#define NOMINMAX // to use std::min std::max correctly on windows
#include <algorithm>
#include <functional>
#include <iostream>
#include <memory>
#include <numeric>
#include <string>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/padding.h"
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "thrust/device_vector.h"
#endif
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
enum class FFTNormMode : int64_t {
none, // No normalization
by_sqrt_n, // Divide by sqrt(signal_size)
by_n, // Divide by signal_size
};
FFTNormMode get_norm_from_string(const std::string& norm, bool forward);
// Enum representing the FFT type
enum class FFTTransformType : int64_t {
C2C = 0, // Complex-to-complex
R2C, // Real-to-complex
C2R, // Complex-to-real
};
// Create transform type enum from bools representing if input and output are
// complex
inline FFTTransformType GetFFTTransformType(
framework::proto::VarType::Type input_dtype,
framework::proto::VarType::Type output_dtype) {
auto complex_input = framework::IsComplexType(input_dtype);
auto complex_output = framework::IsComplexType(output_dtype);
if (complex_input && complex_output) {
return FFTTransformType::C2C;
} else if (complex_input && !complex_output) {
return FFTTransformType::C2R;
} else if (!complex_input && complex_output) {
return FFTTransformType::R2C;
}
PADDLE_THROW(
platform::errors::InvalidArgument("Real to real FFTs are not supported"));
}
// Returns true if the transform type has complex input
inline bool has_complex_input(FFTTransformType type) {
switch (type) {
case FFTTransformType::C2C:
case FFTTransformType::C2R:
return true;
case FFTTransformType::R2C:
return false;
}
PADDLE_THROW(platform::errors::InvalidArgument("Unknown FFTTransformType"));
}
// Returns true if the transform type has complex output
inline bool has_complex_output(FFTTransformType type) {
switch (type) {
case FFTTransformType::C2C:
case FFTTransformType::R2C:
return true;
case FFTTransformType::C2R:
return false;
}
PADDLE_THROW(platform::errors::InvalidArgument("Unknown FFTTransformType"));
}
template <typename T>
struct FFTFillConjGradFunctor {
T* input_;
const size_t axis_;
const int64_t* strides_;
const size_t double_length_;
FFTFillConjGradFunctor(T* input,
size_t axis,
const int64_t* strides,
size_t double_length)
: input_(input),
axis_(axis),
strides_(strides),
double_length_(double_length) {}
HOSTDEVICE void operator()(size_t index) {
    size_t offset = index;
    size_t index_i;
    // Decompose the linear index by strides to get the coordinate along axis_.
    for (size_t i = 0; i <= axis_; i++) {
      index_i = offset / strides_[i];
      offset %= strides_[i];
}
if ((0 < index_i) && (index_i < double_length_ + 1)) {
input_[index] *= static_cast<T>(2);
}
}
};
template <typename DeviceContext, typename Ti, typename To>
struct FFTC2CFunctor {
void operator()(const DeviceContext& ctx,
const Tensor* X,
Tensor* out,
const std::vector<int64_t>& axes,
FFTNormMode normalization,
bool forward);
};
template <typename DeviceContext, typename Ti, typename To>
struct FFTR2CFunctor {
void operator()(const DeviceContext& ctx,
const Tensor* X,
Tensor* out,
const std::vector<int64_t>& axes,
FFTNormMode normalization,
bool forward);
};
template <typename DeviceContext, typename Ti, typename To>
struct FFTC2RFunctor {
void operator()(const DeviceContext& ctx,
const Tensor* X,
Tensor* out,
const std::vector<int64_t>& axes,
FFTNormMode normalization,
bool forward);
};
namespace phi {
namespace funcs {
// Given a linear destination index and the strides of a tensor, get_idx
// returns the corresponding linear position in the source tensor.
......@@ -271,9 +137,9 @@ struct FFTFillConjFunctor {
};
template <typename DeviceContext, typename C>
void fill_conj(const DeviceContext& ctx,
const Tensor* src,
Tensor* dst,
void FFTFillConj(const DeviceContext& ctx,
const DenseTensor* src,
DenseTensor* dst,
const std::vector<int64_t>& axes) {
std::vector<int64_t> src_strides_v =
phi::vectorize<int64_t>(phi::stride(src->dims()));
......@@ -306,7 +172,7 @@ void fill_conj(const DeviceContext& ctx,
const auto dst_shape = dst_shape_v.data();
const auto p_is_fft_axis = _is_fft_axis.get();
#endif
platform::ForRange<DeviceContext> for_range(ctx, dst->numel());
ForRange<DeviceContext> for_range(ctx, dst->numel());
FFTFillConjFunctor<C> fill_conj_functor(src_data,
dst_data,
src_strides,
......@@ -319,189 +185,35 @@ void fill_conj(const DeviceContext& ctx,
for_range(fill_conj_functor);
}
template <typename DeviceContext, typename T>
class FFTC2CKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using C = paddle::platform::complex<T>;
auto& dev_ctx = ctx.device_context<DeviceContext>();
auto axes = ctx.Attr<std::vector<int64_t>>("axes");
const std::string& norm_str = ctx.Attr<std::string>("normalization");
const bool forward = ctx.Attr<bool>("forward");
const auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Output<Tensor>("Out");
y->mutable_data<C>(ctx.GetPlace());
auto normalization = get_norm_from_string(norm_str, forward);
FFTC2CFunctor<DeviceContext, C, C> fft_c2c_func;
fft_c2c_func(dev_ctx, x, y, axes, normalization, forward);
}
};
template <typename DeviceContext, typename T>
class FFTC2CGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using C = paddle::platform::complex<T>;
auto& dev_ctx = ctx.device_context<DeviceContext>();
auto axes = ctx.Attr<std::vector<int64_t>>("axes");
const std::string& norm_str = ctx.Attr<std::string>("normalization");
const bool forward = ctx.Attr<bool>("forward");
const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
dx->mutable_data<C>(ctx.GetPlace());
auto normalization = get_norm_from_string(norm_str, forward);
FFTC2CFunctor<DeviceContext, C, C> fft_c2c_func;
fft_c2c_func(dev_ctx, dy, dx, axes, normalization, !forward);
}
};
template <typename DeviceContext, typename T>
class FFTR2CKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using C = paddle::platform::complex<T>;
auto& dev_ctx = ctx.device_context<DeviceContext>();
auto axes = ctx.Attr<std::vector<int64_t>>("axes");
const std::string& norm_str = ctx.Attr<std::string>("normalization");
const bool forward = ctx.Attr<bool>("forward");
const bool onesided = ctx.Attr<bool>("onesided");
const auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Output<Tensor>("Out");
y->mutable_data<C>(ctx.GetPlace());
auto normalization = get_norm_from_string(norm_str, forward);
FFTR2CFunctor<DeviceContext, T, C> fft_r2c_func;
if (onesided) {
fft_r2c_func(dev_ctx, x, y, axes, normalization, forward);
} else {
framework::DDim onesided_dims(y->dims());
const int64_t onesided_last_axis_size = y->dims().at(axes.back()) / 2 + 1;
onesided_dims.at(axes.back()) = onesided_last_axis_size;
framework::Tensor onesided_out;
onesided_out.mutable_data<C>(onesided_dims, ctx.GetPlace());
fft_r2c_func(dev_ctx, x, &onesided_out, axes, normalization, forward);
fill_conj<DeviceContext, C>(dev_ctx, &onesided_out, y, axes);
}
}
};
template <typename DeviceContext, typename T>
class FFTR2CGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using C = paddle::platform::complex<T>;
auto& dev_ctx = ctx.device_context<DeviceContext>();
const auto axes = ctx.Attr<std::vector<int64_t>>("axes");
const std::string& norm_str = ctx.Attr<std::string>("normalization");
const bool forward = ctx.Attr<bool>("forward");
const bool onesided = ctx.Attr<bool>("onesided");
const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(ctx.GetPlace());
framework::Tensor complex_dx;
complex_dx.mutable_data<C>(dx->dims(), ctx.GetPlace());
auto normalization = get_norm_from_string(norm_str, forward);
FFTC2CFunctor<DeviceContext, C, C> fft_c2c_func;
if (!onesided) {
fft_c2c_func(dev_ctx, dy, &complex_dx, axes, normalization, !forward);
} else {
framework::Tensor full_dy;
full_dy.mutable_data<C>(dx->dims(), ctx.GetPlace());
auto zero_length = static_cast<int>(full_dy.dims().at(axes.back()) -
dy->dims().at(axes.back()));
auto rank = dy->dims().size();
      std::vector<int> pads(rank * 2, 0);
      pads[axes.back() * 2 + 1] = zero_length;
      phi::funcs::PaddingFunctor<DeviceContext, C>(
          rank,
          ctx.template device_context<DeviceContext>(),
          pads,
          static_cast<C>(0),
          *dy,
          &full_dy);
      fft_c2c_func(
          dev_ctx, &full_dy, &complex_dx, axes, normalization, !forward);
    }
    framework::TransComplexToReal(
        framework::TransToProtoVarType(dx->dtype()),
        framework::TransToProtoVarType(complex_dx.dtype()),
        complex_dx,
        dx);
  }
};
template <typename DeviceContext, typename T>
class FFTC2RKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using C = paddle::platform::complex<T>;
auto& dev_ctx = ctx.device_context<DeviceContext>();
auto axes = ctx.Attr<std::vector<int64_t>>("axes");
const std::string& norm_str = ctx.Attr<std::string>("normalization");
const bool forward = ctx.Attr<bool>("forward");
const auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Output<Tensor>("Out");
y->mutable_data<T>(ctx.GetPlace());
auto normalization = get_norm_from_string(norm_str, forward);
FFTC2RFunctor<DeviceContext, C, T> fft_c2r_func;
fft_c2r_func(dev_ctx, x, y, axes, normalization, forward);
  }
};
template <typename DeviceContext, typename T>
class FFTC2RGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using C = paddle::platform::complex<T>;
auto& dev_ctx = ctx.device_context<DeviceContext>();
auto axes = ctx.Attr<std::vector<int64_t>>("axes");
const std::string& norm_str = ctx.Attr<std::string>("normalization");
const bool forward = ctx.Attr<bool>("forward");
const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
C* pdx = dx->mutable_data<C>(ctx.GetPlace());
auto normalization = get_norm_from_string(norm_str, forward);
FFTR2CFunctor<DeviceContext, T, C> fft_r2c_func;
fft_r2c_func(dev_ctx, dy, dx, axes, normalization, !forward);
const int64_t double_length =
dy->dims()[axes.back()] - dx->dims()[axes.back()];
const framework::DDim strides = phi::stride(dx->dims());
#if defined(__NVCC__) || defined(__HIPCC__)
const thrust::device_vector<int64_t> strides_g(phi::vectorize(strides));
const int64_t* pstrides = thrust::raw_pointer_cast(strides_g.data());
#else
const int64_t* pstrides = strides.Get();
#endif
FFTFillConjGradFunctor<C> func(pdx, axes.back(), pstrides, double_length);
size_t limit = dx->numel();
platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
for_range(func);
}
};
} // namespace operators
} // namespace paddle
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/fft.h"
namespace phi {
namespace funcs {
namespace detail {
const int64_t kMaxFFTNdim = 3;
const int64_t kMaxDataNdim = kMaxFFTNdim + 1;
struct FFTConfigKey {
int signal_ndim_; // 1 <= signal_ndim <= kMaxFFTNdim
// These include additional batch dimension as well.
int64_t sizes_[kMaxDataNdim];
int64_t input_shape_[kMaxDataNdim];
int64_t output_shape_[kMaxDataNdim];
FFTTransformType fft_type_;
DataType value_type_;
using shape_t = std::vector<int64_t>;
FFTConfigKey() = default;
FFTConfigKey(const shape_t& in_shape,
const shape_t& out_shape,
const shape_t& signal_size,
FFTTransformType fft_type,
DataType value_type) {
// Padding bits must be zeroed for hashing
memset(this, 0, sizeof(*this));
signal_ndim_ = signal_size.size() - 1;
fft_type_ = fft_type;
value_type_ = value_type;
std::copy(signal_size.cbegin(), signal_size.cend(), sizes_);
std::copy(in_shape.cbegin(), in_shape.cend(), input_shape_);
std::copy(out_shape.cbegin(), out_shape.cend(), output_shape_);
}
};
// Hashing machinery for Key
// Fowler–Noll–Vo hash function
// see
// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
template <typename Key>
struct KeyHash {
// Key must be a POD because we read out its memory
// contenst as char* when hashing
static_assert(std::is_pod<Key>::value, "Key must be plain old data type");
size_t operator()(const Key& params) const {
auto ptr = reinterpret_cast<const uint8_t*>(&params);
uint32_t value = 0x811C9DC5;
for (int i = 0; i < static_cast<int>(sizeof(Key)); ++i) {
value ^= ptr[i];
value *= 0x01000193;
}
return static_cast<size_t>(value);
}
};
template <typename Key>
struct KeyEqual {
// Key must be a POD because we read out its memory
  // contents as char* when comparing
static_assert(std::is_pod<Key>::value, "Key must be plain old data type");
bool operator()(const Key& a, const Key& b) const {
auto ptr1 = reinterpret_cast<const uint8_t*>(&a);
auto ptr2 = reinterpret_cast<const uint8_t*>(&b);
return memcmp(ptr1, ptr2, sizeof(Key)) == 0;
}
};
static FFTConfigKey create_fft_configkey(const DenseTensor& input,
const DenseTensor& output,
int signal_ndim) {
// Create the transform plan (either from cache or locally)
DataType input_dtype = input.dtype();
const auto value_type =
IsComplexType(input_dtype) ? ToRealType(input_dtype) : input_dtype;
const auto fft_type = GetFFTTransformType(input.dtype(), output.dtype());
// signal sizes
std::vector<int64_t> signal_size(signal_ndim + 1);
signal_size[0] = input.dims()[0];
for (int64_t i = 1; i <= signal_ndim; ++i) {
auto in_size = input.dims()[i];
auto out_size = output.dims()[i];
signal_size[i] = std::max(in_size, out_size);
}
FFTConfigKey key(phi::vectorize(input.dims()),
phi::vectorize(output.dims()),
signal_size,
fft_type,
value_type);
return key;
}
} // namespace detail
} // namespace funcs
} // namespace phi
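KeyHash and KeyEqual above treat FFTConfigKey as a bag of bytes: equality is memcmp and the hash is byte-wise FNV-1a, which is why the constructor first memsets the whole struct so that padding bytes are deterministic. A standalone sketch of the same idea, with DemoKey as a made-up stand-in for FFTConfigKey:
#include <cstdint>
#include <cstring>
#include <iostream>

struct DemoKey {
  int signal_ndim;
  int64_t sizes[3];
};

// Byte-wise FNV-1a, matching the constants used by KeyHash above.
uint32_t Fnv1aHash(const void* data, size_t n) {
  const uint8_t* bytes = static_cast<const uint8_t*>(data);
  uint32_t value = 0x811C9DC5u;  // FNV offset basis
  for (size_t i = 0; i < n; ++i) {
    value ^= bytes[i];
    value *= 0x01000193u;  // FNV prime
  }
  return value;
}

int main() {
  DemoKey a, b;
  // Zero the whole structs first so padding bytes compare and hash equal.
  std::memset(&a, 0, sizeof(a));
  std::memset(&b, 0, sizeof(b));
  a.signal_ndim = b.signal_ndim = 2;
  a.sizes[0] = b.sizes[0] = 8;
  a.sizes[1] = b.sizes[1] = 16;
  std::cout << std::boolalpha
            << (Fnv1aHash(&a, sizeof(a)) == Fnv1aHash(&b, sizeof(b))) << " "
            << (std::memcmp(&a, &b, sizeof(a)) == 0) << std::endl;
  // prints: true true
  return 0;
}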
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/backends/dynload/hipfft.h"
#include "paddle/phi/kernels/funcs/fft.h"
#include "paddle/phi/kernels/funcs/fft_key.h"
namespace phi {
namespace funcs {
namespace detail {
// An RAII encapsulation of hipFFTHandle
class HIPFFTHandle {
public:
HIPFFTHandle() {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::hipfftCreate(&handle_));
}
HIPFFTHandle(const HIPFFTHandle& other) = delete;
HIPFFTHandle& operator=(const HIPFFTHandle& other) = delete;
HIPFFTHandle(HIPFFTHandle&& other) = delete;
HIPFFTHandle& operator=(HIPFFTHandle&& other) = delete;
::hipfftHandle& get() { return handle_; }
const ::hipfftHandle& get() const { return handle_; }
~HIPFFTHandle() { phi::dynload::hipfftDestroy(handle_); }
private:
::hipfftHandle handle_;
};
class FFTConfig {
public:
using plan_size_type = int;
explicit FFTConfig(const FFTConfigKey& key)
: FFTConfig(
std::vector<int64_t>(key.sizes_, key.sizes_ + key.signal_ndim_ + 1),
key.fft_type_,
key.value_type_) {}
FFTConfig(const std::vector<int64_t>& sizes,
FFTTransformType fft_type,
DataType precision)
: fft_type_(fft_type), precision_(precision) {
std::vector<plan_size_type> signal_sizes(sizes.begin() + 1, sizes.end());
const auto batch_size = static_cast<plan_size_type>(sizes[0]);
const int signal_ndim = sizes.size() - 1;
hipfftType exec_type = [&]() {
if (precision == DataType::FLOAT32) {
switch (fft_type) {
case FFTTransformType::C2C:
return HIPFFT_C2C;
case FFTTransformType::R2C:
return HIPFFT_R2C;
case FFTTransformType::C2R:
return HIPFFT_C2R;
}
} else if (precision == DataType::FLOAT64) {
switch (fft_type) {
case FFTTransformType::C2C:
return HIPFFT_Z2Z;
case FFTTransformType::R2C:
return HIPFFT_D2Z;
case FFTTransformType::C2R:
return HIPFFT_Z2D;
}
}
PADDLE_THROW(phi::errors::InvalidArgument(
"Only transforms of type float32 and float64 are supported."));
}();
// disable auto allocation of workspace to use allocator from the framework
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::hipfftSetAutoAllocation(plan(), /* autoAllocate */ 0));
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::hipfftMakePlanMany(plan(),
signal_ndim,
signal_sizes.data(),
/* inembed */ nullptr,
/* base_istride */ 1,
/* idist */ 1,
/* onembed */ nullptr,
/* base_ostride */ 1,
/* odist */ 1,
exec_type,
batch_size,
&ws_size_));
}
const hipfftHandle& plan() const { return plan_.get(); }
FFTTransformType transform_type() const { return fft_type_; }
DataType data_type() const { return precision_; }
size_t workspace_size() const { return ws_size_; }
private:
HIPFFTHandle plan_;
size_t ws_size_; // workspace size in bytes
FFTTransformType fft_type_;
DataType precision_;
};
// NOTE: R2C is forward-only, C2R is backward only
static void exec_plan(const FFTConfig& config,
void* in_data,
void* out_data,
bool forward) {
const hipfftHandle& plan = config.plan();
DataType value_type = config.data_type();
if (value_type == DataType::FLOAT32) {
switch (config.transform_type()) {
case FFTTransformType::C2C: {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::hipfftExecC2C(
plan,
static_cast<hipfftComplex*>(in_data),
static_cast<hipfftComplex*>(out_data),
forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD));
return;
}
case FFTTransformType::R2C: {
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::hipfftExecR2C(plan,
static_cast<hipfftReal*>(in_data),
static_cast<hipfftComplex*>(out_data)));
return;
}
case FFTTransformType::C2R: {
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::hipfftExecC2R(plan,
static_cast<hipfftComplex*>(in_data),
static_cast<hipfftReal*>(out_data)));
return;
}
}
} else if (value_type == DataType::FLOAT64) {
switch (config.transform_type()) {
case FFTTransformType::C2C: {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::hipfftExecZ2Z(
plan,
static_cast<hipfftDoubleComplex*>(in_data),
static_cast<hipfftDoubleComplex*>(out_data),
forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD));
return;
}
case FFTTransformType::R2C: {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::hipfftExecD2Z(
plan,
static_cast<hipfftDoubleReal*>(in_data),
static_cast<hipfftDoubleComplex*>(out_data)));
return;
}
case FFTTransformType::C2R: {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::hipfftExecZ2D(
plan,
static_cast<hipfftDoubleComplex*>(in_data),
static_cast<hipfftDoubleReal*>(out_data)));
return;
}
}
}
PADDLE_THROW(phi::errors::InvalidArgument(
"hipFFT only support transforms of type float32 and float64"));
}
} // namespace detail
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <numeric>
#include "paddle/phi/backends/dynload/mklrt.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/funcs/fft.h"
namespace phi {
namespace funcs {
namespace detail {
#define MKL_DFTI_CHECK(expr) \
do { \
MKL_LONG status = (expr); \
if (!phi::dynload::DftiErrorClass(status, DFTI_NO_ERROR)) \
PADDLE_THROW( \
phi::errors::External(phi::dynload::DftiErrorMessage(status))); \
} while (0);
struct DftiDescriptorDeleter {
void operator()(DFTI_DESCRIPTOR_HANDLE handle) {
if (handle != nullptr) {
MKL_DFTI_CHECK(phi::dynload::DftiFreeDescriptor(&handle));
}
}
};
// A RAII wrapper for MKL_DESCRIPTOR*
class DftiDescriptor {
public:
void init(DFTI_CONFIG_VALUE precision,
DFTI_CONFIG_VALUE signal_type,
MKL_LONG signal_ndim,
MKL_LONG* sizes) {
PADDLE_ENFORCE_EQ(desc_.get(),
nullptr,
phi::errors::AlreadyExists(
"DftiDescriptor has already been initialized."));
DFTI_DESCRIPTOR* raw_desc;
MKL_DFTI_CHECK(phi::dynload::DftiCreateDescriptorX(
&raw_desc, precision, signal_type, signal_ndim, sizes));
desc_.reset(raw_desc);
}
DFTI_DESCRIPTOR* get() const {
DFTI_DESCRIPTOR* raw_desc = desc_.get();
PADDLE_ENFORCE_NOT_NULL(raw_desc,
phi::errors::PreconditionNotMet(
"DFTI DESCRIPTOR has not been initialized."));
return raw_desc;
}
private:
std::unique_ptr<DFTI_DESCRIPTOR, DftiDescriptorDeleter> desc_;
};
static DftiDescriptor plan_mkl_fft(const DataType in_dtype,
const DataType out_dtype,
const phi::DDim& in_strides,
const phi::DDim& out_strides,
const std::vector<int64_t>& signal_sizes,
FFTNormMode normalization,
bool forward) {
const DFTI_CONFIG_VALUE precision = [&] {
switch (in_dtype) {
case DataType::FLOAT32:
return DFTI_SINGLE;
case DataType::COMPLEX64:
return DFTI_SINGLE;
case DataType::FLOAT64:
return DFTI_DOUBLE;
case DataType::COMPLEX128:
return DFTI_DOUBLE;
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"Invalid input datatype (%s), input data type should be FP32, "
"FP64, COMPLEX64 or COMPLEX128.",
in_dtype));
}
}();
// C2C, R2C, C2R
const FFTTransformType fft_type = GetFFTTransformType(in_dtype, out_dtype);
const DFTI_CONFIG_VALUE domain =
(fft_type == FFTTransformType::C2C) ? DFTI_COMPLEX : DFTI_REAL;
DftiDescriptor descriptor;
std::vector<MKL_LONG> fft_sizes(signal_sizes.cbegin(), signal_sizes.cend());
const MKL_LONG signal_ndim = fft_sizes.size() - 1;
descriptor.init(precision, domain, signal_ndim, fft_sizes.data() + 1);
// placement inplace or not inplace
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE));
// number of transformations
const MKL_LONG batch_size = fft_sizes[0];
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size));
// input & output distance
const MKL_LONG idist = in_strides[0];
const MKL_LONG odist = out_strides[0];
MKL_DFTI_CHECK(
phi::dynload::DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist));
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
descriptor.get(), DFTI_OUTPUT_DISTANCE, odist));
// input & output stride
std::vector<MKL_LONG> mkl_in_stride(1 + signal_ndim, 0);
std::vector<MKL_LONG> mkl_out_stride(1 + signal_ndim, 0);
for (MKL_LONG i = 1; i <= signal_ndim; i++) {
mkl_in_stride[i] = in_strides[i];
mkl_out_stride[i] = out_strides[i];
}
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride.data()));
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
descriptor.get(), DFTI_OUTPUT_STRIDES, mkl_out_stride.data()));
// conjugate even storage
if (!(fft_type == FFTTransformType::C2C)) {
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX));
}
MKL_LONG signal_numel = std::accumulate(fft_sizes.cbegin() + 1,
fft_sizes.cend(),
1UL,
std::multiplies<MKL_LONG>());
if (normalization != FFTNormMode::none) {
const double scale =
((normalization == FFTNormMode::by_sqrt_n)
? 1.0 / std::sqrt(static_cast<double>(signal_numel))
: 1.0 / static_cast<double>(signal_numel));
const auto scale_direction = [&]() {
if (fft_type == FFTTransformType::R2C ||
(fft_type == FFTTransformType::C2C && forward)) {
return DFTI_FORWARD_SCALE;
} else {
// (fft_type == FFTTransformType::C2R ||
// (fft_type == FFTTransformType::C2C && !forward))
return DFTI_BACKWARD_SCALE;
}
}();
MKL_DFTI_CHECK(
phi::dynload::DftiSetValue(descriptor.get(), scale_direction, scale));
}
// commit the descriptor
MKL_DFTI_CHECK(phi::dynload::DftiCommitDescriptor(descriptor.get()));
return descriptor;
}
} // namespace detail
} // namespace funcs
} // namespace phi
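plan_mkl_fft above folds normalization into the MKL descriptor by setting DFTI_FORWARD_SCALE or DFTI_BACKWARD_SCALE to 1/n or 1/sqrt(n) depending on the FFTNormMode. A standalone sketch of the scale factor each mode implies (the enum is repeated here purely for illustration):
#include <cmath>
#include <cstdint>
#include <iostream>

// "backward" leaves the forward transform unscaled, "ortho" scales by
// 1/sqrt(n), and "forward" scales by 1/n; the inverse transform gets the
// complementary factor.
enum class FFTNormMode : int64_t { none, by_sqrt_n, by_n };

double NormScale(FFTNormMode mode, int64_t signal_numel) {
  switch (mode) {
    case FFTNormMode::none:
      return 1.0;
    case FFTNormMode::by_sqrt_n:
      return 1.0 / std::sqrt(static_cast<double>(signal_numel));
    case FFTNormMode::by_n:
      return 1.0 / static_cast<double>(signal_numel);
  }
  return 1.0;
}

int main() {
  std::cout << NormScale(FFTNormMode::by_sqrt_n, 16) << " "    // 0.25
            << NormScale(FFTNormMode::by_n, 16) << std::endl;  // 0.0625
  return 0;
}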
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fft_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/fft_grad_kernel_impl.h"
PD_REGISTER_KERNEL(fft_c2c_grad,
GPU,
ALL_LAYOUT,
phi::FFTC2CGradKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(
fft_c2r_grad, GPU, ALL_LAYOUT, phi::FFTC2RGradKernel, float, double) {}
PD_REGISTER_KERNEL(fft_r2c_grad,
GPU,
ALL_LAYOUT,
phi::FFTR2CGradKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fft_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/fft_kernel_impl.h"
PD_REGISTER_KERNEL(fft_c2c,
GPU,
ALL_LAYOUT,
phi::FFTC2CKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(fft_c2r,
GPU,
ALL_LAYOUT,
phi::FFTC2RKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(fft_r2c, GPU, ALL_LAYOUT, phi::FFTR2CKernel, float, double) {
}
......@@ -74,4 +74,6 @@ PD_REGISTER_KERNEL(scale,
int8_t,
int16_t,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/fft_grad_kernel.h"
#include <string>
#include <vector>
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/fft.h"
#include "paddle/phi/kernels/funcs/fft_fill_conj.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/pad_kernel.h"
namespace phi {
template <typename T, typename Context>
void FFTC2CGradKernel(const Context& ctx,
const DenseTensor& out_grad,
const std::vector<int64_t>& axes,
const std::string& normalization,
bool forward,
DenseTensor* x_grad) {
ctx.template Alloc<T>(x_grad);
auto norm_type = funcs::get_norm_from_string(normalization, forward);
funcs::FFTC2CFunctor<Context, T, T> fft_c2c_func;
fft_c2c_func(ctx, out_grad, x_grad, axes, norm_type, !forward);
}
template <typename T, typename Context>
void FFTR2CGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const std::vector<int64_t>& axes,
const std::string& normalization,
bool forward,
bool onesided,
DenseTensor* x_grad) {
using R = typename T::value_type;
DenseTensor complex_x_grad = EmptyLike<T>(ctx, x);
ctx.template Alloc<R>(x_grad);
auto norm_type = funcs::get_norm_from_string(normalization, forward);
funcs::FFTC2CFunctor<Context, T, T> fft_c2c_func;
if (!onesided) {
fft_c2c_func(ctx, out_grad, &complex_x_grad, axes, norm_type, !forward);
} else {
DenseTensor full_dy;
DenseTensorMeta full_dy_meta(out_grad.type(), x_grad->dims());
full_dy.set_meta(full_dy_meta);
auto zero_length = static_cast<int>(full_dy.dims().at(axes.back()) -
out_grad.dims().at(axes.back()));
auto rank = out_grad.dims().size();
std::vector<int> pads(rank * 2, 0);
pads[axes.back() * 2 + 1] = zero_length;
PadKernel<T>(ctx, out_grad, pads, static_cast<float>(0.0), &full_dy);
fft_c2c_func(ctx, full_dy, &complex_x_grad, axes, norm_type, !forward);
}
RealKernel<T>(ctx, complex_x_grad, x_grad);
}
template <typename T, typename Context>
void FFTC2RGradKernel(const Context& ctx,
const DenseTensor& out_grad,
const std::vector<int64_t>& axes,
const std::string& normalization,
bool forward,
int64_t last_dim_size,
DenseTensor* x_grad) {
using C = phi::dtype::complex<T>;
ctx.template Alloc<C>(x_grad);
auto norm_type = funcs::get_norm_from_string(normalization, forward);
funcs::FFTR2CFunctor<Context, T, C> fft_r2c_func;
fft_r2c_func(ctx, out_grad, x_grad, axes, norm_type, !forward);
const int64_t double_length =
out_grad.dims()[axes.back()] - x_grad->dims()[axes.back()];
const phi::DDim strides = phi::stride(x_grad->dims());
#if defined(__NVCC__) || defined(__HIPCC__)
const thrust::device_vector<int64_t> strides_g(phi::vectorize(strides));
const int64_t* pstrides = thrust::raw_pointer_cast(strides_g.data());
#else
const int64_t* pstrides = strides.Get();
#endif
funcs::FFTFillConjGradFunctor<C> func(
x_grad->data<C>(), axes.back(), pstrides, double_length);
size_t limit = x_grad->numel();
funcs::ForRange<Context> for_range(ctx, limit);
for_range(func);
}
} // namespace phi
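FFTC2RGradKernel above computes the gradient of an inverse real FFT with an R2C transform on out_grad and then doubles the interior onesided bins via FFTFillConjGradFunctor: every bin other than DC (and, for even lengths, Nyquist) appears twice in the full conjugate-symmetric spectrum, so its gradient is counted twice. A standalone sketch of that index rule; DoubleInteriorBins is an illustrative helper, not a Paddle function:
#include <iostream>
#include <vector>

// Given the gradient of the onesided bins (length n/2 + 1) of a length-n real
// signal, double the bins represented twice in the full spectrum: indices
// 1 .. double_length, where double_length = n - (n/2 + 1). DC and, for even
// n, the Nyquist bin stay untouched.
std::vector<double> DoubleInteriorBins(std::vector<double> grad, int n) {
  const int onesided = static_cast<int>(grad.size());
  const int double_length = n - onesided;
  for (int k = 1; k <= double_length && k < onesided; ++k) {
    grad[k] *= 2.0;
  }
  return grad;
}

int main() {
  // n = 8 -> 5 onesided bins; bins 1..3 are doubled, bins 0 and 4 are not.
  const auto g = DoubleInteriorBins({1, 1, 1, 1, 1}, 8);
  for (double v : g) {
    std::cout << v << " ";
  }
  std::cout << std::endl;  // prints: 1 2 2 2 1
  return 0;
}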
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/fft_kernel.h"
#include <string>
#include <vector>
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/fft.h"
#include "paddle/phi/kernels/funcs/fft_fill_conj.h"
namespace phi {
template <typename T, typename Context>
void FFTC2CKernel(const Context& ctx,
const DenseTensor& x,
const std::vector<int64_t>& axes,
const std::string& normalization,
bool forward,
DenseTensor* out) {
ctx.template Alloc<T>(out);
const auto norm_type = funcs::get_norm_from_string(normalization, forward);
funcs::FFTC2CFunctor<Context, T, T> fft_c2c_func;
fft_c2c_func(ctx, x, out, axes, norm_type, forward);
}
template <typename T, typename Context>
void FFTC2RKernel(const Context& ctx,
const DenseTensor& x,
const std::vector<int64_t>& axes,
const std::string& normalization,
bool forward,
int64_t last_dim_size,
DenseTensor* out) {
using R = typename T::value_type; // get real type
ctx.template Alloc<R>(out);
const auto norm_type = funcs::get_norm_from_string(normalization, forward);
funcs::FFTC2RFunctor<Context, T, R> fft_c2r_func;
fft_c2r_func(ctx, x, out, axes, norm_type, forward);
}
template <typename T, typename Context>
void FFTR2CKernel(const Context& ctx,
const DenseTensor& x,
const std::vector<int64_t>& axes,
const std::string& normalization,
bool forward,
bool onesided,
DenseTensor* out) {
using C = phi::dtype::complex<T>;
ctx.template Alloc<C>(out);
auto norm_type = funcs::get_norm_from_string(normalization, forward);
funcs::FFTR2CFunctor<Context, T, C> fft_r2c_func;
if (onesided) {
fft_r2c_func(ctx, x, out, axes, norm_type, forward);
} else {
phi::DDim onesided_out_shape = x.dims();
const int64_t last_fft_axis = axes.back();
const int64_t onesided_last_axis_size =
out->dims().at(last_fft_axis) / 2 + 1;
onesided_out_shape[last_fft_axis] = onesided_last_axis_size;
DenseTensor onesided_out =
Empty<C, Context>(ctx, phi::vectorize(onesided_out_shape));
fft_r2c_func(ctx, x, &onesided_out, axes, norm_type, forward);
funcs::FFTFillConj<Context, C>(ctx, &onesided_out, out, axes);
}
}
} // namespace phi
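When onesided is false, FFTR2CKernel above computes only the n/2 + 1 unique bins into a temporary tensor and lets FFTFillConj rebuild the remaining bins from conjugate symmetry, X[n - k] = conj(X[k]) for a real input of length n. A standalone 1-D sketch of that reconstruction; FillConj1D is an illustrative helper, not the Paddle implementation:
#include <complex>
#include <iostream>
#include <vector>

// Rebuild the full spectrum of a length-n real signal from its onesided half
// (the first n/2 + 1 bins), mirroring what FFTFillConj does along the last
// FFT axis.
std::vector<std::complex<double>> FillConj1D(
    const std::vector<std::complex<double>>& onesided, int n) {
  std::vector<std::complex<double>> full(n);
  const int m = static_cast<int>(onesided.size());  // n / 2 + 1
  for (int k = 0; k < m; ++k) {
    full[k] = onesided[k];
  }
  for (int k = m; k < n; ++k) {
    full[k] = std::conj(onesided[n - k]);  // X[k] = conj(X[n - k])
  }
  return full;
}

int main() {
  // Onesided spectrum of the real signal [0, 1, 2, 3] (n = 4, 3 unique bins).
  std::vector<std::complex<double>> half = {{6, 0}, {-2, 2}, {-2, 0}};
  const auto full = FillConj1D(half, 4);
  for (const auto& c : full) {
    std::cout << "(" << c.real() << ", " << c.imag() << ") ";
  }
  std::cout << std::endl;  // prints: (6, 0) (-2, 2) (-2, 0) (-2, -2)
  return 0;
}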
......@@ -17,7 +17,7 @@ import numpy as np
import paddle
from .tensor.attribute import is_complex, is_floating_point, is_integer
from .tensor.creation import _real_to_complex_dtype, _complex_to_real_dtype
from .fluid.framework import _non_static_mode
from .fluid.framework import _in_legacy_dygraph, in_dygraph_mode
from . import _C_ops
from .fluid.data_feeder import check_variable_and_dtype
from .fluid.layer_helper import LayerHelper
......@@ -1392,7 +1392,9 @@ def fft_c2c(x, n, axis, norm, forward, name):
op_type = 'fft_c2c'
check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
if _non_static_mode():
if in_dygraph_mode():
out = _C_ops.final_state_fft_c2c(x, axes, norm, forward)
elif _in_legacy_dygraph():
attrs = ('axes', axes, 'normalization', norm, 'forward', forward)
out = getattr(_C_ops, op_type)(x, *attrs)
else:
......@@ -1426,7 +1428,9 @@ def fft_r2c(x, n, axis, norm, forward, onesided, name):
op_type = 'fft_r2c'
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], op_type)
if _non_static_mode():
if in_dygraph_mode():
out = _C_ops.final_state_fft_r2c(x, axes, norm, forward, onesided)
elif _in_legacy_dygraph():
attrs = ('axes', axes, 'normalization', norm, 'forward', forward,
'onesided', onesided)
out = getattr(_C_ops, op_type)(x, *attrs)
......@@ -1469,7 +1473,12 @@ def fft_c2r(x, n, axis, norm, forward, name):
op_type = 'fft_c2r'
check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
if _non_static_mode():
if in_dygraph_mode():
if n is not None:
out = _C_ops.final_state_fft_c2r(x, axes, norm, forward, n)
else:
out = _C_ops.final_state_fft_c2r(x, axes, norm, forward, 0)
elif _in_legacy_dygraph():
if n is not None:
attrs = ('axes', axes, 'normalization', norm, 'forward', forward,
'last_dim_size', n)
......@@ -1528,7 +1537,9 @@ def fftn_c2c(x, s, axes, norm, forward, name):
op_type = 'fft_c2c'
check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
if _non_static_mode():
if in_dygraph_mode():
out = _C_ops.final_state_fft_c2c(x, axes, norm, forward)
elif _in_legacy_dygraph():
attrs = ('axes', axes, 'normalization', norm, 'forward', forward)
out = getattr(_C_ops, op_type)(x, *attrs)
else:
......@@ -1579,7 +1590,9 @@ def fftn_r2c(x, s, axes, norm, forward, onesided, name):
op_type = 'fft_r2c'
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], op_type)
if _non_static_mode():
if in_dygraph_mode():
out = _C_ops.final_state_fft_r2c(x, axes, norm, forward, onesided)
elif _in_legacy_dygraph():
attrs = ('axes', axes, 'normalization', norm, 'forward', forward,
'onesided', onesided)
out = getattr(_C_ops, op_type)(x, *attrs)
......@@ -1642,7 +1655,12 @@ def fftn_c2r(x, s, axes, norm, forward, name):
op_type = 'fft_c2r'
check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
if _non_static_mode():
if in_dygraph_mode():
if s is not None:
out = _C_ops.final_state_fft_c2r(x, axes, norm, forward, s[-1])
else:
out = _C_ops.final_state_fft_c2r(x, axes, norm, forward, 0)
elif _in_legacy_dygraph():
if s:
attrs = ('axes', axes, 'normalization', norm, 'forward', forward,
'last_dim_size', s[-1])
......
......@@ -12,20 +12,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import numpy as np
from functools import partial
from numpy import asarray
from numpy.fft._pocketfft import _raw_fft, _raw_fftnd, _get_forward_norm, _get_backward_norm, _cook_nd_args
class NormMode(enum.Enum):
none = 1
by_sqrt_n = 2
by_n = 3
def _get_norm_mode(norm, forward):
if norm == "ortho":
return NormMode.by_sqrt_n
if norm is None or norm == "backward":
return NormMode.none if forward else NormMode.by_n
return NormMode.by_n if forward else NormMode.none
def _get_inv_norm(n, norm_mode):
assert isinstance(norm_mode,
NormMode), "invalid norm_type {}".format(norm_mode)
if norm_mode == NormMode.none:
return 1.0
if norm_mode == NormMode.by_sqrt_n:
return np.sqrt(n)
return n
# 1d transforms
def _fftc2c(a, n=None, axis=-1, norm=None, forward=None):
a = asarray(a)
if n is None:
n = a.shape[axis]
if forward:
inv_norm = _get_forward_norm(n, norm)
else:
inv_norm = _get_backward_norm(n, norm)
inv_norm = _get_inv_norm(n, norm)
output = _raw_fft(a, n, axis, False, forward, inv_norm)
return output
......@@ -34,10 +57,7 @@ def _fftr2c(a, n=None, axis=-1, norm=None, forward=None):
a = asarray(a)
if n is None:
n = a.shape[axis]
if forward:
inv_norm = _get_forward_norm(n, norm)
else:
inv_norm = _get_backward_norm(n, norm)
inv_norm = _get_inv_norm(n, norm)
output = _raw_fft(a, n, axis, True, True, inv_norm)
if not forward:
output = output.conj()
......@@ -48,43 +68,67 @@ def _fftc2r(a, n=None, axis=-1, norm=None, forward=None):
a = asarray(a)
if n is None:
n = (a.shape[axis] - 1) * 2
if forward:
inv_norm = _get_forward_norm(n, norm)
else:
inv_norm = _get_backward_norm(n, norm)
inv_norm = _get_inv_norm(n, norm)
output = _raw_fft(a.conj() if forward else a, n, axis, True, False,
inv_norm)
return output
def fft_c2c(x, axes, normalization, forward):
# general fft functors
def _fft_c2c_nd(x, axes, norm_mode, forward):
f = partial(_fftc2c, forward=forward)
y = _raw_fftnd(x, s=None, axes=axes, function=f, norm=normalization)
y = _raw_fftnd(x, s=None, axes=axes, function=f, norm=norm_mode)
return y
def fft_c2c_backward(dy, axes, normalization, forward):
f = partial(_fftc2c, forward=forward)
dx = _raw_fftnd(dy, s=None, axes=axes, function=f, norm=normalization)
return dx
def fft_r2c(x, axes, normalization, forward, onesided):
def _fft_r2c_nd(x, axes, norm_mode, forward, onesided):
a = asarray(x)
s, axes = _cook_nd_args(a, axes=axes)
if onesided:
a = _fftr2c(a, s[-1], axes[-1], normalization, forward)
for ii in range(len(axes) - 1):
a = _fftc2c(a, s[ii], axes[ii], normalization, forward)
a = _fftr2c(a, s[-1], axes[-1], norm_mode, forward)
a = _fft_c2c_nd(a, axes[:-1], norm_mode, forward)
else:
a = fft_c2c(x, axes, normalization, forward)
a = _fft_c2c_nd(x, axes, norm_mode, forward)
return a
def _fft_c2r_nd(x, axes, norm_mode, forward, last_dim_size):
a = asarray(x)
s, axes = _cook_nd_args(a, axes=axes, invreal=1)
if last_dim_size is not None:
s[-1] = last_dim_size
a = _fft_c2c_nd(a, axes[:-1], norm_mode, forward)
a = _fftc2r(a, s[-1], axes[-1], norm_mode, forward)
return a
def fft_r2c_backward(dy, x, axes, normalization, forward, onesided):
# kernels
def fft_c2c(x, axes, normalization, forward):
norm_mode = _get_norm_mode(normalization, forward)
return _fft_c2c_nd(x, axes, norm_mode, forward)
def fft_c2r(x, axes, normalization, forward, last_dim_size):
norm_mode = _get_norm_mode(normalization, forward)
return _fft_c2r_nd(x, axes, norm_mode, forward, last_dim_size)
def fft_r2c(x, axes, normalization, forward, onesided):
norm_mode = _get_norm_mode(normalization, forward)
return _fft_r2c_nd(x, axes, norm_mode, forward, onesided)
# backward kernel
def fft_c2c_backward(dy, axes, normalization, forward):
norm_mode = _get_norm_mode(normalization, forward)
dx = _fft_c2c_nd(dy, axes, norm_mode, not forward)
return dx
def fft_r2c_backward(x, dy, axes, normalization, forward, onesided):
a = dy
if not onesided:
a = fft_c2c_backward(a, axes, normalization, forward).real
a = fft_c2c_backward(a, axes, normalization, forward)
else:
pad_widths = [(0, 0)] * a.ndim
last_axis = axes[-1]
......@@ -93,16 +137,25 @@ def fft_r2c_backward(dy, x, axes, normalization, forward, onesided):
last_dim_size = a.shape[last_axis]
pad_widths[last_axis] = (0, x.shape[last_axis] - last_dim_size)
a = np.pad(a, pad_width=pad_widths)
a = fft_c2c_backward(a, axes, normalization, forward).real
return a
a = fft_c2c_backward(a, axes, normalization, forward)
return a.real
def fft_c2r(x, axes, normalization, forward, last_dim_size):
a = asarray(x)
s, axes = _cook_nd_args(a, axes=axes, invreal=1)
if last_dim_size is not None:
s[-1] = last_dim_size
for ii in range(len(axes) - 1):
a = _fftc2c(a, s[ii], axes[ii], normalization, forward)
a = _fftc2r(a, s[-1], axes[-1], normalization, forward)
def _fft_fill_conj_grad(x, axes, length_to_double):
last_fft_axis = axes[-1]
shape = x.shape
for multi_index in np.ndindex(*shape):
if 0 < multi_index[last_fft_axis] and multi_index[
last_fft_axis] <= length_to_double:
x[multi_index] *= 2
return x
def fft_c2r_backward(x, dy, axes, normalization, forward, last_dim_size):
norm_mode = _get_norm_mode(normalization, forward)
a = dy
a = _fft_r2c_nd(dy, axes, norm_mode, not forward, True)
last_fft_axis = axes[-1]
length_to_double = dy.shape[last_fft_axis] - x.shape[last_fft_axis]
a = _fft_fill_conj_grad(a, axes, length_to_double)
return a
......@@ -473,7 +473,7 @@ class TestIrfft2(unittest.TestCase):
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [
('test_bool_input',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)).astype(
np.bool_), None, -1, 'backward', NotImplementedError),
np.bool_), None, -1, 'backward', RuntimeError),
('test_n_nagative', np.random.randn(4, 4, 4) +
1j * np.random.randn(4, 4, 4), -1, -1, 'backward', ValueError),
('test_n_zero', np.random.randn(4, 4) + 1j * np.random.randn(4, 4), 0, -1,
......@@ -543,7 +543,7 @@ class TestIrfftException(unittest.TestCase):
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[('test_bool_input',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)).astype(
np.bool_), None, (-2, -1), 'backward', NotImplementedError),
np.bool_), None, (-2, -1), 'backward', RuntimeError),
('test_n_nagative',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2),
(-2, -1), 'backward', ValueError),
......@@ -625,7 +625,7 @@ class TestIrfft2Exception(unittest.TestCase):
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
[('test_bool_input',
(np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)).astype(
np.bool_), None, (-2, -1), 'backward', NotImplementedError),
np.bool_), None, (-2, -1), 'backward', RuntimeError),
('test_n_nagative',
np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2),
(-2, -1), 'backward', ValueError),
......