未验证 提交 11b9f5f9 编写于 作者: X Xiaoxu Chen 提交者: GitHub

[Cherry-pick]FFT function enhancements and bugfixes (#36537)

* update fft api path (#36219)

* update fft api path
* add sample code for ihfft2
Co-authored-by: Nchenfeiyu <chenfeiyu@baidu.com>

* fix fft axis (#36321)

fix: `-1` is used when fft's axis is `0`

* use unified external error message for cufft api (#36114)

* fft: modify sample code result (#36325)

* dynamically load mkl as a fft backend when it is available and requested (#36414)

* add rocm support for fft api (#36415)

* move signal apis

* move fft and signal API path (#2)

* move signal apis

* move fft.py and signal.py to paddle/, fix typos

* fix relative imports from fft.py and signal.py

* fix typos in signal.py (#3)

* move signal apis

* move fft.py and signal.py to paddle/, fix typos

* fix relative imports from fft.py and signal.py

* fix typos

* disable Cache when CUFFT_VERSION >= 10200 (#4)

* move signal apis

* move fft.py and signal.py to paddle/, fix typos

* fix relative imports from fft.py and signal.py

* fix typos

* Add LRUCache for fft plans

* add LRUCache for cufft and hipfft (#5)

* move signal apis

* move fft.py and signal.py to paddle/, fix typos

* fix relative imports from fft.py and signal.py

* fix typos

* WIP: add cache

* delete move constructor and operator= for CuFFTHandle and FFTConfig

* remove log from CuFFTHandle and FFTConfig

* add lrucache for fft rocm backend

* disable LRUCache when CUFFT_VERSION >= 10200

* disable copy and move for hipFFTHandle; format code
Co-authored-by: NXiaoxu Chen <chenxx_id@163.com>

* remove debug message of cufftHandler

* roll_op: support Tensor as input for shifts (#36727)

* fix fftshift/ifftshift on static mode

* update roll_op version

* add more test cases for fftshift/ifftshift
Co-authored-by: Nzhiboniu <31800336+zhiboniu@users.noreply.github.com>
Co-authored-by: Nchenfeiyu <chenfeiyu@baidu.com>
Co-authored-by: LJQ️ <33169170+lijiaqi0612@users.noreply.github.com>
上级 96edcea4
...@@ -255,8 +255,8 @@ if(WITH_GPU) ...@@ -255,8 +255,8 @@ if(WITH_GPU)
include(external/cub) # download cub include(external/cub) # download cub
list(APPEND third_party_deps extern_cub) list(APPEND third_party_deps extern_cub)
endif() endif()
set(URL "https://paddlepaddledeps.bj.bcebos.com/externalErrorMsg.tar.gz" CACHE STRING "" FORCE) set(URL "https://paddlepaddledeps.bj.bcebos.com/externalErrorMsg_20210928.tar.gz" CACHE STRING "" FORCE)
file_download_and_uncompress(${URL} "externalError" MD5 061f3b7895aadcbe2c3ed592590f8b10) # download file externalErrorMsg.tar.gz file_download_and_uncompress(${URL} "externalError" MD5 a712a49384e77ca216ad866712f7cafa) # download file externalErrorMsg.tar.gz
if(WITH_TESTING) if(WITH_TESTING)
# copy externalErrorMsg.pb, just for unittest can get error message correctly. # copy externalErrorMsg.pb, just for unittest can get error message correctly.
set(SRC_DIR ${THIRD_PARTY_PATH}/externalError/data) set(SRC_DIR ${THIRD_PARTY_PATH}/externalError/data)
......
...@@ -105,10 +105,20 @@ else() ...@@ -105,10 +105,20 @@ else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif() endif()
if (WITH_GPU AND (NOT WITH_ROCM)) if (WITH_GPU OR WITH_ROCM)
if (MKL_FOUND AND WITH_ONEMKL)
op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda dynload_mklrt ${OP_HEADER_DEPS})
target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE})
else()
op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda ${OP_HEADER_DEPS}) op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda ${OP_HEADER_DEPS})
endif()
else() else()
if (MKL_FOUND AND WITH_ONEMKL)
op_library(spectral_op SRCS spectral_op.cc DEPS dynload_mklrt ${OP_HEADER_DEPS})
target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE})
else()
op_library(spectral_op SRCS spectral_op.cc DEPS ${OP_HEADER_DEPS}) op_library(spectral_op SRCS spectral_op.cc DEPS ${OP_HEADER_DEPS})
endif()
endif() endif()
op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute) op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute)
......
...@@ -40,6 +40,7 @@ class RollOp : public framework::OperatorWithKernel { ...@@ -40,6 +40,7 @@ class RollOp : public framework::OperatorWithKernel {
auto dims = ctx->Attrs().Get<std::vector<int64_t>>("axis"); auto dims = ctx->Attrs().Get<std::vector<int64_t>>("axis");
auto shifts = ctx->Attrs().Get<std::vector<int64_t>>("shifts"); auto shifts = ctx->Attrs().Get<std::vector<int64_t>>("shifts");
if (!ctx->HasInput("ShiftsTensor")) {
if (dims.size() != 0) { if (dims.size() != 0) {
PADDLE_ENFORCE_EQ(dims.size(), shifts.size(), PADDLE_ENFORCE_EQ(dims.size(), shifts.size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -56,6 +57,7 @@ class RollOp : public framework::OperatorWithKernel { ...@@ -56,6 +57,7 @@ class RollOp : public framework::OperatorWithKernel {
"shifts.size() = %d", "shifts.size() = %d",
shifts.size())); shifts.size()));
} }
}
ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
auto type = ctx->GetInputsVarType("X")[0]; auto type = ctx->GetInputsVarType("X")[0];
...@@ -105,6 +107,10 @@ class RollOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -105,6 +107,10 @@ class RollOpMaker : public framework::OpProtoAndCheckerMaker {
"The number of places by which the elements " "The number of places by which the elements "
"of the tensor are shifted.") "of the tensor are shifted.")
.SetDefault({}); .SetDefault({});
AddInput("ShiftsTensor",
"The number of places by which the elements of the tensor "
"are shifted.")
.AsDispensable();
AddAttr<std::vector<int64_t>>( AddAttr<std::vector<int64_t>>(
"axis", "axis",
"Axis along which to roll. It must have the same size " "Axis along which to roll. It must have the same size "
...@@ -129,6 +135,9 @@ class RollGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -129,6 +135,9 @@ class RollGradMaker : public framework::SingleGradOpMaker<T> {
void Apply(GradOpPtr<T> op) const override { void Apply(GradOpPtr<T> op) const override {
op->SetType("roll_grad"); op->SetType("roll_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
if (this->HasInput("ShiftsTensor")) {
op->SetInput("ShiftsTensor", this->Input("ShiftsTensor"));
}
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
...@@ -174,7 +183,12 @@ REGISTER_OP_VERSION(roll) ...@@ -174,7 +183,12 @@ REGISTER_OP_VERSION(roll)
"(std::vector<int64_t>) Axis along which to roll. " "(std::vector<int64_t>) Axis along which to roll. "
"It must have the same size with shifts, or size = 0.", "It must have the same size with shifts, or size = 0.",
std::vector<int64_t>()) std::vector<int64_t>())
.DeleteAttr( .DeleteAttr("dims",
"dims",
"(std::vector<int64_t>) Dims along which to roll. " "(std::vector<int64_t>) Dims along which to roll. "
"It must have the same size with shifts, or size = 0.")); "It must have the same size with shifts, or size = 0."))
.AddCheckpoint(
R"ROC(Upgrade roll add a dispensable input "ShiftsTensor".)ROC",
paddle::framework::compatible::OpVersionDesc().NewInput(
"ShiftsTensor",
"The number of places by which the elements of"
"the tensor are shifted."));
...@@ -59,6 +59,16 @@ class RollKernel<platform::CUDADeviceContext, T> ...@@ -59,6 +59,16 @@ class RollKernel<platform::CUDADeviceContext, T>
auto* in = context.Input<LoDTensor>("X"); auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out"); auto* out = context.Output<LoDTensor>("Out");
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts"); std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
if (context.HasInput("ShiftsTensor")) {
const auto* shifts_tensor =
context.Input<framework::Tensor>("ShiftsTensor");
PADDLE_ENFORCE_EQ(
shifts_tensor->dims().size(), 1,
platform::errors::InvalidArgument(
"The rank of ShiftsTensor is expected to be 1, got %s",
shifts_tensor->dims().size()));
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
}
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis"); std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
auto* in_data = in->data<T>(); auto* in_data = in->data<T>();
...@@ -134,6 +144,16 @@ class RollGradKernel<platform::CUDADeviceContext, T> ...@@ -134,6 +144,16 @@ class RollGradKernel<platform::CUDADeviceContext, T>
auto* in = context.Input<LoDTensor>(framework::GradVarName("Out")); auto* in = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* out = context.Output<LoDTensor>(framework::GradVarName("X")); auto* out = context.Output<LoDTensor>(framework::GradVarName("X"));
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts"); std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
if (context.HasInput("ShiftsTensor")) {
const auto* shifts_tensor =
context.Input<framework::Tensor>("ShiftsTensor");
PADDLE_ENFORCE_EQ(
shifts_tensor->dims().size(), 1,
platform::errors::InvalidArgument(
"The rank of ShiftsTensor is expected to be 1, got %s",
shifts_tensor->dims().size()));
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
}
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis"); std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
auto* in_data = in->data<T>(); auto* in_data = in->data<T>();
......
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -85,6 +87,16 @@ class RollKernel : public framework::OpKernel<T> { ...@@ -85,6 +87,16 @@ class RollKernel : public framework::OpKernel<T> {
auto& input = input_var->Get<LoDTensor>(); auto& input = input_var->Get<LoDTensor>();
auto* output = output_var->GetMutable<LoDTensor>(); auto* output = output_var->GetMutable<LoDTensor>();
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts"); std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
if (context.HasInput("ShiftsTensor")) {
const auto* shifts_tensor =
context.Input<framework::Tensor>("ShiftsTensor");
PADDLE_ENFORCE_EQ(
shifts_tensor->dims().size(), 1,
platform::errors::InvalidArgument(
"The rank of ShiftsTensor is expected to be 1, got %s",
shifts_tensor->dims().size()));
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
}
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis"); std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
std::vector<T> out_vec; std::vector<T> out_vec;
...@@ -123,6 +135,11 @@ class RollGradKernel : public framework::OpKernel<T> { ...@@ -123,6 +135,11 @@ class RollGradKernel : public framework::OpKernel<T> {
auto& input = input_var->Get<LoDTensor>(); auto& input = input_var->Get<LoDTensor>();
auto* output = output_var->GetMutable<LoDTensor>(); auto* output = output_var->GetMutable<LoDTensor>();
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts"); std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
if (context.HasInput("ShiftsTensor")) {
const auto* shifts_tensor =
context.Input<framework::Tensor>("ShiftsTensor");
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
}
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis"); std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
std::vector<T> out_vec; std::vector<T> out_vec;
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/spectral_op.h"
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/dynload/hipfft.h"
#endif
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/cufft.h"
#endif
namespace paddle {
namespace operators {
// Alias for the framework's variable data type enum; used below to record the
// value type of a transform.
using ScalarType = framework::proto::VarType::Type;
// Upper bound on the number of signal dimensions described by one
// FFTConfigKey.
const int64_t kMaxFFTNdim = 3;
// Signal dimensions plus one batch dimension; this is the length of the
// fixed-size shape arrays stored in FFTConfigKey.
const int64_t kMaxDataNdim = kMaxFFTNdim + 1;
// This struct is used to easily compute hashes of the
// parameters. It will be the **key** to the plan cache.
struct FFTConfigKey {
// between 1 and kMaxFFTNdim, i.e., 1 <= signal_ndim <= 3
int64_t signal_ndim_;
// These include additional batch dimension as well.
int64_t sizes_[kMaxDataNdim];
int64_t input_shape_[kMaxDataNdim];
int64_t output_shape_[kMaxDataNdim];
FFTTransformType fft_type_;
ScalarType value_type_;
FFTConfigKey() = default;
FFTConfigKey(const std::vector<int64_t>& in_shape,
const std::vector<int64_t>& out_shape,
const std::vector<int64_t>& signal_size,
FFTTransformType fft_type, ScalarType value_type) {
// Padding bits must be zeroed for hashing
memset(this, 0, sizeof(*this));
signal_ndim_ = signal_size.size() - 1;
fft_type_ = fft_type;
value_type_ = value_type;
std::copy(signal_size.cbegin(), signal_size.cend(), sizes_);
std::copy(in_shape.cbegin(), in_shape.cend(), input_shape_);
std::copy(out_shape.cbegin(), out_shape.cend(), output_shape_);
}
};
#if defined(PADDLE_WITH_CUDA)
// An RAII encapsulation of cufftHandle: the handle is created in the
// constructor and destroyed in the destructor. Copy and move are disabled so
// that exactly one object is responsible for each handle.
class CuFFTHandle {
  ::cufftHandle handle_;

 public:
  // Creates a new cuFFT handle; a failure is reported through
  // PADDLE_ENFORCE_CUDA_SUCCESS.
  CuFFTHandle() {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftCreate(&handle_));
  }

  CuFFTHandle(const CuFFTHandle& other) = delete;
  CuFFTHandle& operator=(const CuFFTHandle& other) = delete;

  CuFFTHandle(CuFFTHandle&& other) = delete;
  CuFFTHandle& operator=(CuFFTHandle&& other) = delete;

  ::cufftHandle& get() { return handle_; }
  const ::cufftHandle& get() const { return handle_; }

  ~CuFFTHandle() {
    // Destructors are implicitly noexcept, so reporting a failure through
    // PADDLE_ENFORCE_CUDA_SUCCESS here would end in std::terminate instead of
    // a catchable error. Releasing the handle is therefore best effort and
    // the returned status is deliberately discarded.
    static_cast<void>(platform::dynload::cufftDestroy(handle_));
  }
};
using plan_size_type = long long int; // NOLINT
// This class contains all the information needed to execute a cuFFT plan:
// 1. the plan
// 2. the workspace size needed
class FFTConfig {
public:
// Only move semantics is enought for this class. Although we already use
// unique_ptr for the plan, still remove copy constructor and assignment op so
// we don't accidentally copy and take perf hit.
explicit FFTConfig(const FFTConfigKey& plan_key)
: FFTConfig(
std::vector<int64_t>(plan_key.sizes_,
plan_key.sizes_ + plan_key.signal_ndim_ + 1),
plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {}
// sizes are full signal, including batch size and always two-sided
FFTConfig(const std::vector<int64_t>& sizes, const int64_t signal_ndim,
FFTTransformType fft_type, ScalarType dtype)
: fft_type_(fft_type), value_type_(dtype) {
// signal sizes (excluding batch dim)
std::vector<plan_size_type> signal_sizes(sizes.begin() + 1, sizes.end());
// input batch size
const auto batch = static_cast<plan_size_type>(sizes[0]);
// const int64_t signal_ndim = sizes.size() - 1;
PADDLE_ENFORCE_EQ(signal_ndim, sizes.size() - 1,
platform::errors::InvalidArgument(
"The signal_ndim must be equal to sizes.size() - 1,"
"But signal_ndim is: [%d], sizes.size() - 1 is: [%d]",
signal_ndim, sizes.size() - 1));
cudaDataType itype, otype, exec_type;
const auto complex_input = has_complex_input(fft_type);
const auto complex_output = has_complex_output(fft_type);
if (dtype == framework::proto::VarType::FP32) {
itype = complex_input ? CUDA_C_32F : CUDA_R_32F;
otype = complex_output ? CUDA_C_32F : CUDA_R_32F;
exec_type = CUDA_C_32F;
} else if (dtype == framework::proto::VarType::FP64) {
itype = complex_input ? CUDA_C_64F : CUDA_R_64F;
otype = complex_output ? CUDA_C_64F : CUDA_R_64F;
exec_type = CUDA_C_64F;
} else if (dtype == framework::proto::VarType::FP16) {
itype = complex_input ? CUDA_C_16F : CUDA_R_16F;
otype = complex_output ? CUDA_C_16F : CUDA_R_16F;
exec_type = CUDA_C_16F;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"cuFFT only support transforms of type float16, float32 and "
"float64"));
}
// disable auto allocation of workspace to use allocator from the framework
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftSetAutoAllocation(
plan(), /* autoAllocate */ 0));
size_t ws_size_t;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftXtMakePlanMany(
plan(), signal_ndim, signal_sizes.data(),
/* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype,
/* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype,
batch, &ws_size_t, exec_type));
ws_size = ws_size_t;
}
FFTConfig(const FFTConfig& other) = delete;
FFTConfig& operator=(const FFTConfig& other) = delete;
FFTConfig(FFTConfig&& other) = delete;
FFTConfig& operator=(FFTConfig&& other) = delete;
const cufftHandle& plan() const { return plan_ptr.get(); }
FFTTransformType transform_type() const { return fft_type_; }
ScalarType data_type() const { return value_type_; }
size_t workspace_size() const { return ws_size; }
private:
CuFFTHandle plan_ptr;
size_t ws_size;
FFTTransformType fft_type_;
ScalarType value_type_;
};
#elif defined(PADDLE_WITH_HIP)
// An RAII encapsulation of hipfftHandle: the handle is created in the
// constructor and destroyed in the destructor. Copy and move are disabled so
// that exactly one object is responsible for each handle.
class HIPFFTHandle {
  ::hipfftHandle handle_;

 public:
  // Creates a new hipFFT handle; a failure is reported through
  // PADDLE_ENFORCE_CUDA_SUCCESS.
  HIPFFTHandle() {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftCreate(&handle_));
  }

  HIPFFTHandle(const HIPFFTHandle& other) = delete;
  HIPFFTHandle& operator=(const HIPFFTHandle& other) = delete;

  HIPFFTHandle(HIPFFTHandle&& other) = delete;
  HIPFFTHandle& operator=(HIPFFTHandle&& other) = delete;

  ::hipfftHandle& get() { return handle_; }
  const ::hipfftHandle& get() const { return handle_; }

  ~HIPFFTHandle() {
    // Destructors are implicitly noexcept, so reporting a failure through
    // PADDLE_ENFORCE_CUDA_SUCCESS here would end in std::terminate instead of
    // a catchable error. Releasing the handle is therefore best effort and
    // the returned status is deliberately discarded.
    static_cast<void>(platform::dynload::hipfftDestroy(handle_));
  }
};
using plan_size_type = int;
// This class contains all the information needed to execute a hipFFT plan:
// 1. the plan
// 2. the workspace size needed
class FFTConfig {
 public:
  // Builds the plan described by a cache key: the first signal_ndim_ + 1
  // entries of sizes_ (batch dimension first) are forwarded to the
  // constructor below.
  explicit FFTConfig(const FFTConfigKey& plan_key)
      : FFTConfig(
            std::vector<int64_t>(plan_key.sizes_,
                                 plan_key.sizes_ + plan_key.signal_ndim_ + 1),
            plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {}

  // sizes are full signal, including batch size and always two-sided
  //
  // Raises InvalidArgument when `sizes` is empty, when signal_ndim does not
  // equal sizes.size() - 1, or when dtype is not float32/float64.
  FFTConfig(const std::vector<int64_t>& sizes, const int64_t signal_ndim,
            FFTTransformType fft_type, ScalarType dtype)
      : fft_type_(fft_type), value_type_(dtype) {
    // Validate `sizes` before reading sizes[0] or forming sizes.begin() + 1;
    // both are undefined on an empty vector.
    PADDLE_ENFORCE_EQ(sizes.empty(), false,
                      platform::errors::InvalidArgument(
                          "The sizes of an FFT plan must contain at least the "
                          "batch size, But received an empty vector."));
    // Compare as signed values so that the check does not depend on an
    // implicit signed/unsigned conversion.
    PADDLE_ENFORCE_EQ(signal_ndim, static_cast<int64_t>(sizes.size()) - 1,
                      platform::errors::InvalidArgument(
                          "The signal_ndim must be equal to sizes.size() - 1,"
                          "But signal_ndim is: [%d], sizes.size() - 1 is: [%d]",
                          signal_ndim, sizes.size() - 1));

    // signal sizes (excluding batch dim)
    std::vector<plan_size_type> signal_sizes(sizes.begin() + 1, sizes.end());

    // input batch size
    const auto batch = static_cast<plan_size_type>(sizes[0]);

    // Map (value type, transform kind) to the hipFFT transform type. Any
    // combination not handled below falls through to the error at the end.
    hipfftType exec_type = [&] {
      if (dtype == framework::proto::VarType::FP32) {
        switch (fft_type) {
          case FFTTransformType::C2C:
            return HIPFFT_C2C;
          case FFTTransformType::R2C:
            return HIPFFT_R2C;
          case FFTTransformType::C2R:
            return HIPFFT_C2R;
        }
      } else if (dtype == framework::proto::VarType::FP64) {
        switch (fft_type) {
          case FFTTransformType::C2C:
            return HIPFFT_Z2Z;
          case FFTTransformType::R2C:
            return HIPFFT_D2Z;
          case FFTTransformType::C2R:
            return HIPFFT_Z2D;
        }
      }
      PADDLE_THROW(platform::errors::InvalidArgument(
          "hipFFT only support transforms of type float32 and float64"));
    }();

    // disable auto allocation of workspace to use allocator from the framework
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftSetAutoAllocation(
        plan(), /* autoAllocate */ 0));

    size_t ws_size_t;

    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftMakePlanMany(
        plan(), signal_ndim, signal_sizes.data(),
        /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1,
        /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, exec_type,
        batch, &ws_size_t));

    ws_size = ws_size_t;
  }

  // Copy and move are both disabled: the object holds its plan handle
  // directly as a member, and the handle wrapper itself is neither copyable
  // nor movable. Spelled out explicitly to match the cuFFT variant.
  FFTConfig(const FFTConfig& other) = delete;
  FFTConfig& operator=(const FFTConfig& other) = delete;

  FFTConfig(FFTConfig&& other) = delete;
  FFTConfig& operator=(FFTConfig&& other) = delete;

  const hipfftHandle& plan() const { return plan_ptr.get(); }

  FFTTransformType transform_type() const { return fft_type_; }
  ScalarType data_type() const { return value_type_; }
  size_t workspace_size() const { return ws_size; }

 private:
  HIPFFTHandle plan_ptr;
  size_t ws_size;  // workspace size reported by hipfftMakePlanMany
  FFTTransformType fft_type_;
  ScalarType value_type_;
};
#endif
// Hashing machinery for Key
// 32-bit Fowler–Noll–Vo (FNV-1a) hash computed over the object's raw bytes.
// see
// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
template <typename Key>
struct KeyHash {
  // The key's memory is read byte by byte, so it has to be a POD type.
  static_assert(std::is_pod<Key>::value, "Key must be plain old data type");

  // Returns the hash of all sizeof(Key) bytes of `params`, widened to size_t.
  size_t operator()(const Key& params) const {
    const uint8_t* const first = reinterpret_cast<const uint8_t*>(&params);
    const uint8_t* const last = first + sizeof(Key);
    uint32_t digest = 0x811C9DC5;  // FNV offset basis
    for (const uint8_t* byte = first; byte != last; ++byte) {
      // xor the byte in, then multiply by the FNV prime
      digest = (digest ^ *byte) * 0x01000193;
    }
    return static_cast<size_t>(digest);
  }
};
// Equality predicate for Key: two keys are equal when all of their bytes are.
template <typename Key>
struct KeyEqual {
  // The keys' memory is compared byte by byte, so Key has to be a POD type.
  static_assert(std::is_pod<Key>::value, "Key must be plain old data type");

  // Returns true when the sizeof(Key) bytes of `a` and `b` are identical.
  bool operator()(const Key& a, const Key& b) const {
    const void* const lhs_bytes = static_cast<const void*>(&a);
    const void* const rhs_bytes = static_cast<const void*>(&b);
    const int difference = memcmp(lhs_bytes, rhs_bytes, sizeof(Key));
    return difference == 0;
  }
};
// Limits for the plan cache: CUFFT_MAX_PLAN_NUM is the largest size a cache
// may be given, CUFFT_DEFAULT_CACHE_SIZE is the size of a default-constructed
// cache.
// NOTE(review): an identifier that is not defined evaluates to 0 inside #if,
// so a build that does not define CUDA_VERSION takes the first branch and is
// capped at 1023 plans — confirm that this is intended for the HIP backend.
#if CUDA_VERSION < 10000
// Note that the max plan number for CUDA version < 10 has to be 1023
// due to a bug that fails on the 1024th plan
constexpr size_t CUFFT_MAX_PLAN_NUM = 1023;
constexpr size_t CUFFT_DEFAULT_CACHE_SIZE = CUFFT_MAX_PLAN_NUM;
#else
constexpr size_t CUFFT_MAX_PLAN_NUM = std::numeric_limits<size_t>::max();
// The default max cache size chosen for CUDA version > 10 is arbitrary.
// This number puts a limit on how big of a plan cache should we maintain by
// default. Users can always configure it via cufft_set_plan_cache_max_size.
constexpr size_t CUFFT_DEFAULT_CACHE_SIZE = 4096;
#endif
// Compile-time sanity checks on the two limits. Both constants are size_t, so
// the `>= 0` halves hold by construction; the second assert is the one that
// ties the default size to the maximum.
static_assert(CUFFT_MAX_PLAN_NUM >= 0 &&
CUFFT_MAX_PLAN_NUM <= std::numeric_limits<size_t>::max(),
"CUFFT_MAX_PLAN_NUM not in size_t range");
static_assert(CUFFT_DEFAULT_CACHE_SIZE >= 0 &&
CUFFT_DEFAULT_CACHE_SIZE <= CUFFT_MAX_PLAN_NUM,
"CUFFT_DEFAULT_CACHE_SIZE not in [0, CUFFT_MAX_PLAN_NUM] range");
// This cache assumes that the mapping from key to value never changes.
// This is **NOT** thread-safe. Please use a mutex when using it **AND** the
// value returned from try_emplace_value.
// The contract of using this cache is that try_emplace_value should only be
// used when the max_size is positive.
class FFTConfigCache {
public:
using kv_t = typename std::pair<FFTConfigKey, FFTConfig>;
using map_t = typename std::unordered_map<
std::reference_wrapper<FFTConfigKey>, typename std::list<kv_t>::iterator,
KeyHash<FFTConfigKey>, KeyEqual<FFTConfigKey>>;
using map_kkv_iter_t = typename map_t::iterator;
FFTConfigCache() : FFTConfigCache(CUFFT_DEFAULT_CACHE_SIZE) {}
explicit FFTConfigCache(int64_t max_size) { _set_max_size(max_size); }
FFTConfigCache(const FFTConfigCache& other) = delete;
FFTConfigCache& operator=(const FFTConfigCache& other) = delete;
FFTConfigCache(FFTConfigCache&& other) noexcept
: _usage_list(std::move(other._usage_list)),
_cache_map(std::move(other._cache_map)),
_max_size(other._max_size) {}
FFTConfigCache& operator=(FFTConfigCache&& other) noexcept {
_usage_list = std::move(other._usage_list);
_cache_map = std::move(other._cache_map);
_max_size = other._max_size;
return *this;
}
// If key is in this cache, return the cached config. Otherwise, emplace the
// config in this cache and return it.
FFTConfig& lookup(FFTConfigKey params) {
PADDLE_ENFORCE_GT(_max_size, 0,
platform::errors::InvalidArgument(
"The max size of FFTConfigCache must be great than 0,"
"But received is [%d]",
_max_size));
map_kkv_iter_t map_it = _cache_map.find(params);
// Hit, put to list front
if (map_it != _cache_map.end()) {
_usage_list.splice(_usage_list.begin(), _usage_list, map_it->second);
return map_it->second->second;
}
// Miss
// remove if needed
if (_usage_list.size() >= _max_size) {
auto last = _usage_list.end();
last--;
_cache_map.erase(last->first);
_usage_list.pop_back();
}
// construct new plan at list front, then insert into _cache_map
_usage_list.emplace_front(std::piecewise_construct,
std::forward_as_tuple(params),
std::forward_as_tuple(params));
auto kv_it = _usage_list.begin();
_cache_map.emplace(std::piecewise_construct,
std::forward_as_tuple(kv_it->first),
std::forward_as_tuple(kv_it));
return kv_it->second;
}
void clear() {
_cache_map.clear();
_usage_list.clear();
}
void resize(int64_t new_size) {
_set_max_size(new_size);
auto cur_size = _usage_list.size();
if (cur_size > _max_size) {
auto delete_it = _usage_list.end();
for (size_t i = 0; i < cur_size - _max_size; i++) {
delete_it--;
_cache_map.erase(delete_it->first);
}
_usage_list.erase(delete_it, _usage_list.end());
}
}
size_t size() const { return _cache_map.size(); }
size_t max_size() const noexcept { return _max_size; }
std::mutex mutex;
private:
// Only sets size and does value check. Does not resize the data structures.
void _set_max_size(int64_t new_size) {
// We check that 0 <= new_size <= CUFFT_MAX_PLAN_NUM here. Since
// CUFFT_MAX_PLAN_NUM is of type size_t, we need to do non-negativity check
// first.
PADDLE_ENFORCE_GE(
new_size, 0,
platform::errors::InvalidArgument(
"cuFFT plan cache size must be non-negative, But received is [%d]",
new_size));
PADDLE_ENFORCE_LE(new_size, CUFFT_MAX_PLAN_NUM,
platform::errors::InvalidArgument(
"cuFFT plan cache size can not be larger than [%d], "
"But received is [%d]",
CUFFT_MAX_PLAN_NUM, new_size));
_max_size = static_cast<size_t>(new_size);
}
std::list<kv_t> _usage_list;
map_t _cache_map;
size_t _max_size;
};
// Plan caches indexed by device id, created lazily by get_fft_plan_cache.
// plan_caches_mutex guards the vector itself.
// NOTE(review): these are `static` objects in a header, so every translation
// unit that includes it gets its own copy — confirm that a single translation
// unit per backend includes this file.
static std::vector<std::unique_ptr<FFTConfigCache>> plan_caches;
static std::mutex plan_caches_mutex;

// Returns the plan cache of device `device_index`, growing the table and
// creating the cache on first use.
//
// Raises InvalidArgument when `device_index` is negative; without the check a
// negative index would be converted to a huge unsigned value and index the
// table out of bounds.
static inline FFTConfigCache& get_fft_plan_cache(int64_t device_index) {
  PADDLE_ENFORCE_GE(
      device_index, 0,
      platform::errors::InvalidArgument(
          "The device index of the FFT plan cache must be non-negative, "
          "But received is [%d]",
          device_index));
  const size_t slot = static_cast<size_t>(device_index);

  std::lock_guard<std::mutex> guard(plan_caches_mutex);

  if (slot >= plan_caches.size()) {
    plan_caches.resize(slot + 1);
  }

  if (!plan_caches[slot]) {
    plan_caches[slot] = std::make_unique<FFTConfigCache>();
  }

  return *plan_caches[slot];
}
} // namespace operators
} // namespace paddle
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/complex.h"
#if defined(PADDLE_WITH_ONEMKL) #if defined(PADDLE_WITH_ONEMKL)
#include <mkl_dfti.h> #include "paddle/fluid/platform/dynload/mklrt.h"
#elif defined(PADDLE_WITH_POCKETFFT) #elif defined(PADDLE_WITH_POCKETFFT)
#include "extern_pocketfft/pocketfft_hdronly.h" #include "extern_pocketfft/pocketfft_hdronly.h"
#endif #endif
...@@ -357,46 +357,45 @@ FFTNormMode get_norm_from_string(const std::string& norm, bool forward) { ...@@ -357,46 +357,45 @@ FFTNormMode get_norm_from_string(const std::string& norm, bool forward) {
// FFT Functors // FFT Functors
#if defined(PADDLE_WITH_ONEMKL) #if defined(PADDLE_WITH_ONEMKL)
#define MKL_DFTI_CHECK(expr) \
do { \
MKL_LONG status = (expr); \
if (!platform::dynload::DftiErrorClass(status, DFTI_NO_ERROR)) \
PADDLE_THROW(platform::errors::External( \
platform::dynload::DftiErrorMessage(status))); \
} while (0);
namespace { namespace {
static inline void MKL_DFTI_CHECK(MKL_INT status) {
if (status && !DftiErrorClass(status, DFTI_NO_ERROR)) {
PADDLE_THROW(platform::errors::External(DftiErrorMessage(status)));
}
}
struct DftiDescriptorDeleter { struct DftiDescriptorDeleter {
void operator()(DFTI_DESCRIPTOR_HANDLE handle) { void operator()(DFTI_DESCRIPTOR_HANDLE handle) {
if (handle != nullptr) { if (handle != nullptr) {
MKL_DFTI_CHECK(DftiFreeDescriptor(&handle)); MKL_DFTI_CHECK(platform::dynload::DftiFreeDescriptor(&handle));
} }
} }
}; };
// A RAII wrapper for MKL_DESCRIPTOR*
class DftiDescriptor { class DftiDescriptor {
public: public:
void init(DFTI_CONFIG_VALUE precision, DFTI_CONFIG_VALUE signal_type, void init(DFTI_CONFIG_VALUE precision, DFTI_CONFIG_VALUE signal_type,
MKL_LONG signal_ndim, MKL_LONG* sizes) { MKL_LONG signal_ndim, MKL_LONG* sizes) {
if (desc_ != nullptr) { PADDLE_ENFORCE_EQ(desc_.get(), nullptr,
PADDLE_THROW(platform::errors::AlreadyExists( platform::errors::AlreadyExists(
"DFT DESCRIPTOR can only be initialized once.")); "DftiDescriptor has already been initialized."));
}
DFTI_DESCRIPTOR* raw_desc; DFTI_DESCRIPTOR* raw_desc;
if (signal_ndim == 1) { MKL_DFTI_CHECK(platform::dynload::DftiCreateDescriptorX(
MKL_DFTI_CHECK( &raw_desc, precision, signal_type, signal_ndim, sizes));
DftiCreateDescriptor(&raw_desc, precision, signal_type, 1, sizes[0]));
} else {
MKL_DFTI_CHECK(DftiCreateDescriptor(&raw_desc, precision, signal_type,
signal_ndim, sizes));
}
desc_.reset(raw_desc); desc_.reset(raw_desc);
} }
DFTI_DESCRIPTOR* get() const { DFTI_DESCRIPTOR* get() const {
if (desc_ == nullptr) { DFTI_DESCRIPTOR* raw_desc = desc_.get();
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_ENFORCE_NOT_NULL(raw_desc,
platform::errors::PreconditionNotMet(
"DFTI DESCRIPTOR has not been initialized.")); "DFTI DESCRIPTOR has not been initialized."));
} return raw_desc;
return desc_.get();
} }
private: private:
...@@ -421,7 +420,9 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, ...@@ -421,7 +420,9 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
return DFTI_DOUBLE; return DFTI_DOUBLE;
default: default:
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"Input data type should be FP32, FP64, COMPLEX64 or COMPLEX128.")); "Invalid input datatype (%s), input data type should be FP32, "
"FP64, COMPLEX64 or COMPLEX128.",
framework::DataTypeToString(in_dtype)));
} }
}(); }();
...@@ -430,35 +431,27 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, ...@@ -430,35 +431,27 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
const DFTI_CONFIG_VALUE domain = const DFTI_CONFIG_VALUE domain =
(fft_type == FFTTransformType::C2C) ? DFTI_COMPLEX : DFTI_REAL; (fft_type == FFTTransformType::C2C) ? DFTI_COMPLEX : DFTI_REAL;
// const bool complex_input = framework::IsComplexType(in_dtype);
// const bool complex_output = framework::IsComplexType(out_dtype);
// const DFTI_CONFIG_VALUE domain = [&] {
// if (forward) {
// return complex_input ? DFTI_COMPLEX : DFTI_REAL;
// } else {
// return complex_output ? DFTI_COMPLEX : DFTI_REAL;
// }
// }();
DftiDescriptor descriptor; DftiDescriptor descriptor;
std::vector<MKL_LONG> fft_sizes(signal_sizes.cbegin(), signal_sizes.cend()); std::vector<MKL_LONG> fft_sizes(signal_sizes.cbegin(), signal_sizes.cend());
const MKL_LONG signal_ndim = fft_sizes.size() - 1; const MKL_LONG signal_ndim = fft_sizes.size() - 1;
descriptor.init(precision, domain, signal_ndim, fft_sizes.data() + 1); descriptor.init(precision, domain, signal_ndim, fft_sizes.data() + 1);
// placement inplace or not inplace // placement inplace or not inplace
MKL_DFTI_CHECK( MKL_DFTI_CHECK(platform::dynload::DftiSetValue(
DftiSetValue(descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE)); descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE));
// number of transformations // number of transformations
const MKL_LONG batch_size = fft_sizes[0]; const MKL_LONG batch_size = fft_sizes[0];
MKL_DFTI_CHECK( MKL_DFTI_CHECK(platform::dynload::DftiSetValue(
DftiSetValue(descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size)); descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size));
// input & output distance // input & output distance
const MKL_LONG idist = in_strides[0]; const MKL_LONG idist = in_strides[0];
const MKL_LONG odist = out_strides[0]; const MKL_LONG odist = out_strides[0];
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist)); MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(),
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_DISTANCE, odist)); DFTI_INPUT_DISTANCE, idist));
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(),
DFTI_OUTPUT_DISTANCE, odist));
// input & output stride // input & output stride
std::vector<MKL_LONG> mkl_in_stride(1 + signal_ndim, 0); std::vector<MKL_LONG> mkl_in_stride(1 + signal_ndim, 0);
...@@ -467,15 +460,15 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, ...@@ -467,15 +460,15 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
mkl_in_stride[i] = in_strides[i]; mkl_in_stride[i] = in_strides[i];
mkl_out_stride[i] = out_strides[i]; mkl_out_stride[i] = out_strides[i];
} }
MKL_DFTI_CHECK( MKL_DFTI_CHECK(platform::dynload::DftiSetValue(
DftiSetValue(descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride.data())); descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride.data()));
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_STRIDES, MKL_DFTI_CHECK(platform::dynload::DftiSetValue(
mkl_out_stride.data())); descriptor.get(), DFTI_OUTPUT_STRIDES, mkl_out_stride.data()));
// conjugate even storage // conjugate even storage
if (!(fft_type == FFTTransformType::C2C)) { if (!(fft_type == FFTTransformType::C2C)) {
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, MKL_DFTI_CHECK(platform::dynload::DftiSetValue(
DFTI_COMPLEX_COMPLEX)); descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX));
} }
MKL_LONG signal_numel = MKL_LONG signal_numel =
...@@ -496,11 +489,12 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, ...@@ -496,11 +489,12 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
return DFTI_BACKWARD_SCALE; return DFTI_BACKWARD_SCALE;
} }
}(); }();
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), scale_direction, scale)); MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(),
scale_direction, scale));
} }
// commit the descriptor // commit the descriptor
MKL_DFTI_CHECK(DftiCommitDescriptor(descriptor.get())); MKL_DFTI_CHECK(platform::dynload::DftiCommitDescriptor(descriptor.get()));
return descriptor; return descriptor;
} }
...@@ -592,14 +586,15 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out, ...@@ -592,14 +586,15 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
collapsed_input.numel(), collapsed_input.numel(),
collapsed_input_conj.data<Ti>()); collapsed_input_conj.data<Ti>());
for_range(functor); for_range(functor);
MKL_DFTI_CHECK(DftiComputeBackward(desc.get(), MKL_DFTI_CHECK(platform::dynload::DftiComputeBackward(
collapsed_input_conj.data<void>(), desc.get(), collapsed_input_conj.data<void>(),
collapsed_output.data<void>())); collapsed_output.data<void>()));
} else if (fft_type == FFTTransformType::R2C && !forward) { } else if (fft_type == FFTTransformType::R2C && !forward) {
framework::Tensor collapsed_output_conj(collapsed_output.type()); framework::Tensor collapsed_output_conj(collapsed_output.type());
collapsed_output_conj.mutable_data<To>(collapsed_output.dims(), collapsed_output_conj.mutable_data<To>(collapsed_output.dims(),
ctx.GetPlace()); ctx.GetPlace());
MKL_DFTI_CHECK(DftiComputeForward(desc.get(), collapsed_input.data<void>(), MKL_DFTI_CHECK(platform::dynload::DftiComputeForward(
desc.get(), collapsed_input.data<void>(),
collapsed_output_conj.data<void>())); collapsed_output_conj.data<void>()));
// conjugate the output // conjugate the output
platform::ForRange<DeviceContext> for_range(ctx, collapsed_output.numel()); platform::ForRange<DeviceContext> for_range(ctx, collapsed_output.numel());
...@@ -609,12 +604,12 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out, ...@@ -609,12 +604,12 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
for_range(functor); for_range(functor);
} else { } else {
if (forward) { if (forward) {
MKL_DFTI_CHECK(DftiComputeForward(desc.get(), MKL_DFTI_CHECK(platform::dynload::DftiComputeForward(
collapsed_input.data<void>(), desc.get(), collapsed_input.data<void>(),
collapsed_output.data<void>())); collapsed_output.data<void>()));
} else { } else {
MKL_DFTI_CHECK(DftiComputeBackward(desc.get(), MKL_DFTI_CHECK(platform::dynload::DftiComputeBackward(
collapsed_input.data<void>(), desc.get(), collapsed_input.data<void>(),
collapsed_output.data<void>())); collapsed_output.data<void>()));
} }
} }
......
...@@ -8,10 +8,6 @@ ...@@ -8,10 +8,6 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <cufft.h>
#include <cufftXt.h>
#include <functional> #include <functional>
#include <list> #include <list>
#include <memory> #include <memory>
...@@ -24,263 +20,168 @@ ...@@ -24,263 +20,168 @@
#include <vector> #include <vector>
#include "paddle/fluid/operators/conj_op.h" #include "paddle/fluid/operators/conj_op.h"
#include "paddle/fluid/operators/spectral_helper.h"
#include "paddle/fluid/operators/spectral_op.h" #include "paddle/fluid/operators/spectral_op.h"
#include "paddle/fluid/operators/transpose_op.h" #include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/dynload/cufft.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace { namespace {
using ScalarType = framework::proto::VarType::Type; // Calculates the normalization constant
const int64_t kMaxCUFFTNdim = 3; double fft_normalization_scale(FFTNormMode normalization,
const int64_t kMaxDataNdim = kMaxCUFFTNdim + 1; const std::vector<int64_t>& sizes,
const std::vector<int64_t>& dims) {
static inline std::string get_cufft_error_info(cufftResult error) { // auto norm = static_cast<fft_norm_mode>(normalization);
switch (error) { if (normalization == FFTNormMode::none) {
case CUFFT_SUCCESS: return static_cast<double>(1.0);
return "CUFFT_SUCCESS";
case CUFFT_INVALID_PLAN:
return "CUFFT_INVALID_PLAN";
case CUFFT_ALLOC_FAILED:
return "CUFFT_ALLOC_FAILED";
case CUFFT_INVALID_TYPE:
return "CUFFT_INVALID_TYPE";
case CUFFT_INVALID_VALUE:
return "CUFFT_INVALID_VALUE";
case CUFFT_INTERNAL_ERROR:
return "CUFFT_INTERNAL_ERROR";
case CUFFT_EXEC_FAILED:
return "CUFFT_EXEC_FAILED";
case CUFFT_SETUP_FAILED:
return "CUFFT_SETUP_FAILED";
case CUFFT_INVALID_SIZE:
return "CUFFT_INVALID_SIZE";
case CUFFT_UNALIGNED_DATA:
return "CUFFT_UNALIGNED_DATA";
case CUFFT_INCOMPLETE_PARAMETER_LIST:
return "CUFFT_INCOMPLETE_PARAMETER_LIST";
case CUFFT_INVALID_DEVICE:
return "CUFFT_INVALID_DEVICE";
case CUFFT_PARSE_ERROR:
return "CUFFT_PARSE_ERROR";
case CUFFT_NO_WORKSPACE:
return "CUFFT_NO_WORKSPACE";
case CUFFT_NOT_IMPLEMENTED:
return "CUFFT_NOT_IMPLEMENTED";
#ifndef __HIPCC__
case CUFFT_LICENSE_ERROR:
return "CUFFT_LICENSE_ERROR";
#endif
case CUFFT_NOT_SUPPORTED:
return "CUFFT_NOT_SUPPORTED";
default:
std::ostringstream ss;
ss << "unknown error " << error;
return ss.str();
} }
}
static inline void CUFFT_CHECK(cufftResult error) { int64_t signal_numel = 1;
if (error != CUFFT_SUCCESS) { for (auto dim : dims) {
PADDLE_THROW(platform::errors::External(get_cufft_error_info(error))); signal_numel *= sizes[dim];
} }
const double scale_denom = (normalization == FFTNormMode::by_sqrt_n)
? std::sqrt(signal_numel)
: static_cast<double>(signal_numel);
return static_cast<double>(1.0 / scale_denom);
} }
// This struct is used to easily compute hashes of the template <typename DeviceContext, typename T>
// parameters. It will be the **key** to the plan cache. void exec_normalization(const DeviceContext& ctx, const Tensor* in, Tensor* out,
struct PlanKey { FFTNormMode normalization,
// between 1 and kMaxCUFFTNdim, i.e., 1 <= signal_ndim <= 3 const std::vector<int64_t>& sizes,
int64_t signal_ndim_; const std::vector<int64_t>& axes) {
// These include additional batch dimension as well. double scale = fft_normalization_scale(normalization, sizes, axes);
int64_t sizes_[kMaxDataNdim]; if (scale != 1.0) {
int64_t input_shape_[kMaxDataNdim]; auto eigen_out = framework::EigenVector<T>::Flatten(*out);
int64_t output_shape_[kMaxDataNdim]; auto eigen_in = framework::EigenVector<T>::Flatten(*in);
FFTTransformType fft_type_; auto dev = ctx.eigen_device();
ScalarType value_type_; EigenScale<Eigen::GpuDevice, T>::Eval(*dev, eigen_out, eigen_in,
static_cast<T>(scale),
PlanKey() = default; static_cast<T>(0), false);
} else {
PlanKey(const std::vector<int64_t>& in_shape, framework::TensorCopy(*in, ctx.GetPlace(), out);
const std::vector<int64_t>& out_shape,
const std::vector<int64_t>& signal_size, FFTTransformType fft_type,
ScalarType value_type) {
// Padding bits must be zeroed for hashing
memset(this, 0, sizeof(*this));
signal_ndim_ = signal_size.size() - 1;
fft_type_ = fft_type;
value_type_ = value_type;
std::copy(signal_size.cbegin(), signal_size.cend(), sizes_);
std::copy(in_shape.cbegin(), in_shape.cend(), input_shape_);
std::copy(out_shape.cbegin(), out_shape.cend(), output_shape_);
} }
}; }
// An RAII encapsulation of cuFFTHandle
class CuFFTHandle {
::cufftHandle handle_;
public:
CuFFTHandle() { CUFFT_CHECK(platform::dynload::cufftCreate(&handle_)); }
::cufftHandle& get() { return handle_; } #if defined(PADDLE_WITH_CUDA)
const ::cufftHandle& get() const { return handle_; } FFTConfigKey create_fft_configkey(const framework::Tensor& input,
const framework::Tensor& output,
int signal_ndim) {
// Create the transform plan (either from cache or locally)
const auto value_type = framework::IsComplexType(input.type())
? framework::ToRealType(input.type())
: input.type();
auto fft_type = GetFFTTransformType(input.type(), output.type());
// signal sizes
std::vector<int64_t> signal_size(signal_ndim + 1);
~CuFFTHandle() { signal_size[0] = input.dims()[0];
// Not using fftDestroy() for rocFFT to work around double freeing of handles for (int64_t i = 1; i <= signal_ndim; ++i) {
#ifndef __HIPCC__ auto in_size = input.dims()[i];
CUFFT_CHECK(platform::dynload::cufftDestroy(handle_)); auto out_size = output.dims()[i];
#endif signal_size[i] = std::max(in_size, out_size);
} }
}; FFTConfigKey key(framework::vectorize(input.dims()),
framework::vectorize(output.dims()), signal_size, fft_type,
value_type);
return key;
}
#ifdef __HIPCC__ // Execute a pre-planned transform
using plan_size_type = int; static void exec_cufft_plan_raw(const FFTConfig& config, void* in_data,
#else void* out_data, bool forward) {
using plan_size_type = long long int; // NOLINT auto& plan = config.plan();
#endif
// This class contains all the information needed to execute a cuFFT plan: PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftXtExec(
// 1. the plan plan, in_data, out_data, forward ? CUFFT_FORWARD : CUFFT_INVERSE));
// 2. the workspace size needed }
class CuFFTConfig {
public:
// Only move semantics is enought for this class. Although we already use
// unique_ptr for the plan, still remove copy constructor and assignment op so
// we don't accidentally copy and take perf hit.
CuFFTConfig(const CuFFTConfig&) = delete;
CuFFTConfig& operator=(CuFFTConfig const&) = delete;
explicit CuFFTConfig(const PlanKey& plan_key)
: CuFFTConfig(
std::vector<int64_t>(plan_key.sizes_,
plan_key.sizes_ + plan_key.signal_ndim_ + 1),
plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {}
// sizes are full signal, including batch size and always two-sided
CuFFTConfig(const std::vector<int64_t>& sizes, const int64_t signal_ndim,
FFTTransformType fft_type, ScalarType dtype)
: fft_type_(fft_type), value_type_(dtype) {
// signal sizes (excluding batch dim)
std::vector<plan_size_type> signal_sizes(sizes.begin() + 1, sizes.end());
// input batch size
const auto batch = static_cast<plan_size_type>(sizes[0]);
// const int64_t signal_ndim = sizes.size() - 1;
PADDLE_ENFORCE_EQ(signal_ndim, sizes.size() - 1,
platform::errors::InvalidArgument(
"The signal_ndim must be equal to sizes.size() - 1,"
"But signal_ndim is: [%d], sizes.size() - 1 is: [%d]",
signal_ndim, sizes.size() - 1));
#ifdef __HIPCC__
hipfftType exec_type = [&] {
if (dtype == framework::proto::VarType::FP32) {
switch (fft_type) {
case FFTTransformType::C2C:
return HIPFFT_C2C;
case FFTTransformType::R2C:
return HIPFFT_R2C;
case FFTTransformType::C2R:
return HIPFFT_C2R;
}
} else if (dtype == framework::proto::VarType::FP64) {
switch (fft_type) {
case FFTTransformType::C2C:
return HIPFFT_Z2Z;
case FFTTransformType::R2C:
return HIPFFT_D2Z;
case FFTTransformType::C2R:
return HIPFFT_Z2D;
}
}
PADDLE_THROW(platform::errors::InvalidArgument(
"hipFFT only support transforms of type float32 and float64"));
}();
#else
cudaDataType itype, otype, exec_type;
const auto complex_input = has_complex_input(fft_type);
const auto complex_output = has_complex_output(fft_type);
if (dtype == framework::proto::VarType::FP32) {
itype = complex_input ? CUDA_C_32F : CUDA_R_32F;
otype = complex_output ? CUDA_C_32F : CUDA_R_32F;
exec_type = CUDA_C_32F;
} else if (dtype == framework::proto::VarType::FP64) {
itype = complex_input ? CUDA_C_64F : CUDA_R_64F;
otype = complex_output ? CUDA_C_64F : CUDA_R_64F;
exec_type = CUDA_C_64F;
} else if (dtype == framework::proto::VarType::FP16) {
itype = complex_input ? CUDA_C_16F : CUDA_R_16F;
otype = complex_output ? CUDA_C_16F : CUDA_R_16F;
exec_type = CUDA_C_16F;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"cuFFT only support transforms of type float16, float32 and "
"float64"));
}
#endif
// disable auto allocation of workspace to use allocator from the framework template <typename DeviceContext, typename Ti, typename To>
CUFFT_CHECK(platform::dynload::cufftSetAutoAllocation( void exec_cufft_plan(const DeviceContext& ctx, const FFTConfig& config,
plan(), /* autoAllocate */ 0)); framework::Tensor* input, framework::Tensor* output,
bool forward) {
size_t ws_size_t; // execute transform plan
auto fft_type = config.transform_type();
// make plan if (fft_type == FFTTransformType::C2R && forward) {
#ifdef __HIPCC__ forward = false;
CUFFT_CHECK(hipfftMakePlanMany( framework::Tensor input_conj(input->type());
plan(), signal_ndim, signal_sizes.data(), input_conj.mutable_data<Ti>(input->dims(), ctx.GetPlace());
/* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, platform::ForRange<DeviceContext> for_range(ctx, input->numel());
/* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, exec_type, math::ConjFunctor<Ti> functor(input->data<Ti>(), input->numel(),
batch, &ws_size_t)); input_conj.data<Ti>());
#else for_range(functor);
exec_cufft_plan_raw(config, input_conj.data<void>(), output->data<void>(),
CUFFT_CHECK(platform::dynload::cufftXtMakePlanMany( forward);
plan(), signal_ndim, signal_sizes.data(), } else if (fft_type == FFTTransformType::R2C && !forward) {
/* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype, forward = true;
/* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype, framework::Tensor out_conj(output->type());
batch, &ws_size_t, exec_type)); out_conj.mutable_data<To>(output->dims(), ctx.GetPlace());
#endif exec_cufft_plan_raw(config, input->data<void>(), out_conj.data<void>(),
forward);
ws_size = ws_size_t; platform::ForRange<DeviceContext> for_range(ctx, output->numel());
math::ConjFunctor<To> functor(out_conj.data<To>(), output->numel(),
output->data<To>());
for_range(functor);
} else {
exec_cufft_plan_raw(config, input->data<void>(), output->data<void>(),
forward);
} }
}
const cufftHandle& plan() const { return plan_ptr.get(); } #elif defined(PADDLE_WITH_HIP)
FFTTransformType transform_type() const { return fft_type_; } FFTConfigKey create_fft_configkey(const framework::Tensor& input,
ScalarType data_type() const { return value_type_; } const framework::Tensor& output,
size_t workspace_size() const { return ws_size; } int signal_ndim) {
// Create the transform plan (either from cache or locally)
const auto value_type = framework::IsComplexType(input.type())
? framework::ToRealType(input.type())
: input.type();
auto fft_type = GetFFTTransformType(input.type(), output.type());
// signal sizes
std::vector<int64_t> signal_size(signal_ndim + 1);
private: signal_size[0] = input.dims()[0];
CuFFTHandle plan_ptr; for (int64_t i = 1; i <= signal_ndim; ++i) {
size_t ws_size; auto in_size = input.dims()[i];
FFTTransformType fft_type_; auto out_size = output.dims()[i];
ScalarType value_type_; signal_size[i] = std::max(in_size, out_size);
}; }
FFTConfigKey key(framework::vectorize(input.dims()),
framework::vectorize(output.dims()), signal_size, fft_type,
value_type);
return key;
}
// Execute a pre-planned transform // Execute a pre-planned transform
static void exec_cufft_plan(const CuFFTConfig& config, void* in_data, static void exec_hipfft_plan_raw(const FFTConfig& config, void* in_data,
void* out_data, bool forward) { void* out_data, bool forward) {
auto& plan = config.plan(); auto& plan = config.plan();
#ifdef __HIPCC__
auto value_type = config.data_type(); auto value_type = config.data_type();
if (value_type == framework::proto::VarType::FP32) { if (value_type == framework::proto::VarType::FP32) {
switch (config.transform_type()) { switch (config.transform_type()) {
case FFTTransformType::C2C: { case FFTTransformType::C2C: {
CUFFT_CHECK(hipfftExecC2C(plan, static_cast<hipfftComplex*>(in_data), PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecC2C(
plan, static_cast<hipfftComplex*>(in_data),
static_cast<hipfftComplex*>(out_data), static_cast<hipfftComplex*>(out_data),
forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD)); forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD));
return; return;
} }
case FFTTransformType::R2C: { case FFTTransformType::R2C: {
CUFFT_CHECK(hipfftExecR2C(plan, static_cast<hipfftReal*>(in_data), PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecR2C(
plan, static_cast<hipfftReal*>(in_data),
static_cast<hipfftComplex*>(out_data))); static_cast<hipfftComplex*>(out_data)));
return; return;
} }
case FFTTransformType::C2R: { case FFTTransformType::C2R: {
CUFFT_CHECK(hipfftExecC2R(plan, static_cast<hipfftComplex*>(in_data), PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecC2R(
plan, static_cast<hipfftComplex*>(in_data),
static_cast<hipfftReal*>(out_data))); static_cast<hipfftReal*>(out_data)));
return; return;
} }
...@@ -288,20 +189,21 @@ static void exec_cufft_plan(const CuFFTConfig& config, void* in_data, ...@@ -288,20 +189,21 @@ static void exec_cufft_plan(const CuFFTConfig& config, void* in_data,
} else if (value_type == framework::proto::VarType::FP64) { } else if (value_type == framework::proto::VarType::FP64) {
switch (config.transform_type()) { switch (config.transform_type()) {
case FFTTransformType::C2C: { case FFTTransformType::C2C: {
CUFFT_CHECK(hipfftExecZ2Z(plan, PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecZ2Z(
static_cast<hipfftDoubleComplex*>(in_data), plan, static_cast<hipfftDoubleComplex*>(in_data),
static_cast<hipfftDoubleComplex*>(out_data), static_cast<hipfftDoubleComplex*>(out_data),
forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD)); forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD));
return; return;
} }
case FFTTransformType::R2C: { case FFTTransformType::R2C: {
CUFFT_CHECK(hipfftExecD2Z(plan, static_cast<hipfftDoubleReal*>(in_data), PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecD2Z(
plan, static_cast<hipfftDoubleReal*>(in_data),
static_cast<hipfftDoubleComplex*>(out_data))); static_cast<hipfftDoubleComplex*>(out_data)));
return; return;
} }
case FFTTransformType::C2R: { case FFTTransformType::C2R: {
CUFFT_CHECK(hipfftExecZ2D(plan, PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecZ2D(
static_cast<hipfftDoubleComplex*>(in_data), plan, static_cast<hipfftDoubleComplex*>(in_data),
static_cast<hipfftDoubleReal*>(out_data))); static_cast<hipfftDoubleReal*>(out_data)));
return; return;
} }
...@@ -309,28 +211,53 @@ static void exec_cufft_plan(const CuFFTConfig& config, void* in_data, ...@@ -309,28 +211,53 @@ static void exec_cufft_plan(const CuFFTConfig& config, void* in_data,
} }
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"hipFFT only support transforms of type float32 and float64")); "hipFFT only support transforms of type float32 and float64"));
#else
CUFFT_CHECK(platform::dynload::cufftXtExec(
plan, in_data, out_data, forward ? CUFFT_FORWARD : CUFFT_INVERSE));
#endif
} }
template <typename DeviceContext, typename Ti, typename To>
void exec_hipfft_plan(const DeviceContext& ctx, const FFTConfig& config,
framework::Tensor* input, framework::Tensor* output,
bool forward) {
auto fft_type = config.transform_type();
if (fft_type == FFTTransformType::C2R && forward) {
forward = false;
framework::Tensor input_conj(input->type());
input_conj.mutable_data<Ti>(input->dims(), ctx.GetPlace());
platform::ForRange<DeviceContext> for_range(ctx, input->numel());
math::ConjFunctor<Ti> functor(input->data<Ti>(), input->numel(),
input_conj.data<Ti>());
for_range(functor);
exec_hipfft_plan_raw(config, input_conj.data<void>(), output->data<void>(),
forward);
} else if (fft_type == FFTTransformType::R2C && !forward) {
forward = true;
framework::Tensor out_conj(output->type());
out_conj.mutable_data<To>(output->dims(), ctx.GetPlace());
exec_hipfft_plan_raw(config, input->data<void>(), out_conj.data<void>(),
forward);
platform::ForRange<DeviceContext> for_range(ctx, output->numel());
math::ConjFunctor<To> functor(out_conj.data<To>(), output->numel(),
output->data<To>());
for_range(functor);
} else {
exec_hipfft_plan_raw(config, input->data<void>(), output->data<void>(),
forward);
}
}
#endif
// Execute a general unnormalized fft operation (can be c2c, onesided r2c or // Execute a general unnormalized fft operation (can be c2c, onesided r2c or
// onesided c2r) // onesided c2r)
template <typename DeviceContext, typename Ti, typename To> template <typename DeviceContext, typename Ti, typename To>
void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out, void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out,
const std::vector<int64_t>& dim, bool forward) { const std::vector<int64_t>& dim, bool forward) {
const auto x_dims = framework::vectorize(X->dims()); const auto x_dims = framework::vectorize(X->dims());
const auto out_dims = framework::vectorize(out->dims());
const int64_t ndim = static_cast<int64_t>(X->dims().size()); const int64_t ndim = static_cast<int64_t>(X->dims().size());
const int64_t signal_ndim = static_cast<int64_t>(dim.size());
const int64_t batch_dims = ndim - signal_ndim;
auto tensor_place = ctx.GetPlace(); auto tensor_place = ctx.GetPlace();
// Transpose batch dimensions first, then with transforming dims // make a dim permutation
std::vector<int> dim_permute(ndim); std::vector<int> dim_permute(ndim);
std::vector<int> reverse_dim_permute(ndim);
std::vector<int64_t> trans_dims(ndim);
std::iota(dim_permute.begin(), dim_permute.end(), int{0}); std::iota(dim_permute.begin(), dim_permute.end(), int{0});
std::vector<bool> is_transformed_dim(ndim); std::vector<bool> is_transformed_dim(ndim);
for (const auto& d : dim) { for (const auto& d : dim) {
...@@ -342,167 +269,120 @@ void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out, ...@@ -342,167 +269,120 @@ void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out,
std::sort(dim_permute.begin(), batch_end); std::sort(dim_permute.begin(), batch_end);
std::copy(dim.cbegin(), dim.cend(), batch_end); std::copy(dim.cbegin(), dim.cend(), batch_end);
for (size_t i = 0; i < ndim; i++) { // transpose input according to dim permutation
trans_dims[i] = x_dims[dim_permute[i]]; // shape of input transpose auto transposed_input_shape = X->dims().transpose(dim_permute);
reverse_dim_permute[dim_permute[i]] = framework::Tensor transposed_input;
static_cast<int>(i); // reverse of dim permute transposed_input.Resize(transposed_input_shape);
} transposed_input.mutable_data<Ti>(tensor_place);
framework::Tensor input; TransCompute<DeviceContext, Ti>(ndim, ctx, *X, &transposed_input,
input.Resize(framework::make_ddim(trans_dims)); dim_permute);
input.mutable_data<Ti>(tensor_place);
/*
auto in_ret = TransposeSimple<Ti>::run(ctx, *X, dim_permute, input);
if (!in_ret) {
TransCompute<DeviceContext, Ti>(ndim, ctx, *X, input, dim_permute);
}
*/
TransCompute<DeviceContext, Ti>(ndim, ctx, *X, &input, dim_permute);
// Reshape batch dimensions into a single dimension // Reshape batch dimensions into a single dimension
std::vector<int64_t> batched_sizes(signal_ndim + 1); const int64_t signal_ndim = static_cast<int64_t>(dim.size());
std::vector<int64_t> collapsed_input_shape(signal_ndim + 1);
auto transposed_input_shape_ = framework::vectorize(transposed_input_shape);
const int64_t batch_dims = ndim - signal_ndim;
auto batch_size = auto batch_size =
std::accumulate(trans_dims.begin(), trans_dims.begin() + batch_dims, std::accumulate(transposed_input_shape_.begin(),
transposed_input_shape_.begin() + batch_dims,
static_cast<int>(1), std::multiplies<int>()); static_cast<int>(1), std::multiplies<int>());
batched_sizes[0] = batch_size; collapsed_input_shape[0] = batch_size;
std::copy(trans_dims.begin() + batch_dims, trans_dims.end(),
batched_sizes.begin() + 1);
input.Resize(framework::make_ddim(batched_sizes));
// Check the shape of transforming dims with input and output std::copy(transposed_input_shape_.begin() + batch_dims,
std::vector<int64_t> signal_size(signal_ndim + 1); transposed_input_shape_.end(), collapsed_input_shape.begin() + 1);
signal_size[0] = batch_size;
for (int64_t i = 0; i < signal_ndim; ++i) {
auto in_size = input.dims()[i + 1];
auto out_size = out_dims[dim[i]];
signal_size[i + 1] = std::max(in_size, out_size);
PADDLE_ENFORCE_EQ(
(in_size == signal_size[i + 1] ||
in_size == (signal_size[i + 1] / 2) + 1),
true,
platform::errors::InvalidArgument(
"The dimension[%d] of Input size: [%d] must be equal or half to "
"The dimension[%d] of Output size: [%d]",
dim[i], in_size, dim[i], out_size));
PADDLE_ENFORCE_EQ(
(out_size == signal_size[i + 1] ||
out_size == (signal_size[i + 1] / 2) + 1),
true,
platform::errors::InvalidArgument(
"The dimension[%d] of Output size: [%d] must be equal or half to "
"The dimension[%d] of Input size: [%d]",
dim[i], out_size, dim[i], in_size));
}
std::vector<int64_t> reshape_out_sizes(ndim); framework::Tensor& collapsed_input = transposed_input;
for (size_t i = 0; i < ndim; ++i) { collapsed_input.Resize(framework::make_ddim(collapsed_input_shape));
reshape_out_sizes[i] = out_dims[dim_permute[i]];
} // make a collpased output
std::vector<int64_t> batched_out_sizes(batched_sizes.begin(), const auto out_dims = framework::vectorize(out->dims());
batched_sizes.end()); std::vector<int64_t> collapsed_output_shape(1 + signal_ndim);
collapsed_output_shape[0] = batch_size;
for (size_t i = 0; i < dim.size(); ++i) { for (size_t i = 0; i < dim.size(); ++i) {
batched_out_sizes[i + 1] = out_dims[dim[i]]; collapsed_output_shape[i + 1] = out_dims[dim[i]];
}
framework::Tensor collapsed_output;
collapsed_output.Resize(framework::make_ddim(collapsed_output_shape));
collapsed_output.mutable_data<To>(tensor_place);
FFTConfig* config = nullptr;
#if defined(PADDLE_WITH_CUDA)
std::unique_ptr<FFTConfig> config_ = nullptr;
// create plan
FFTConfigKey key =
create_fft_configkey(collapsed_input, collapsed_output, signal_ndim);
if (CUFFT_VERSION < 10200) {
const int64_t device_id = static_cast<int64_t>(
reinterpret_cast<const platform::CUDAPlace*>(&collapsed_input.place())
->GetDeviceId());
FFTConfigCache& plan_cache = get_fft_plan_cache(device_id);
std::unique_lock<std::mutex> guard(plan_cache.mutex, std::defer_lock);
guard.lock();
config = &(plan_cache.lookup(key));
} else {
config_ = std::make_unique<FFTConfig>(key);
config = config_.get();
} }
// output
framework::Tensor output;
output.Resize(framework::make_ddim(batched_out_sizes));
output.mutable_data<To>(tensor_place);
// Create the transform plan (either from cache or locally)
const auto value_type = framework::IsComplexType(input.type())
? framework::ToRealType(input.type())
: input.type();
auto fft_type = GetFFTTransformType(input.type(), output.type());
PlanKey Key(framework::vectorize(input.dims()),
framework::vectorize(output.dims()), signal_size, fft_type,
value_type);
CuFFTConfig uncached_plan(Key);
CuFFTConfig* config = &uncached_plan;
auto& plan = config->plan();
// prepare cufft for execution // prepare cufft for execution
CUFFT_CHECK(platform::dynload::cufftSetStream(plan, ctx.stream())); PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cufftSetStream(config->plan(), ctx.stream()));
framework::Tensor workspace_tensor; framework::Tensor workspace_tensor;
workspace_tensor.mutable_data<To>(tensor_place, config->workspace_size()); workspace_tensor.mutable_data<To>(tensor_place, config->workspace_size());
CUFFT_CHECK( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftSetWorkArea(
platform::dynload::cufftSetWorkArea(plan, workspace_tensor.data<To>())); config->plan(), workspace_tensor.data<To>()));
// execute transform plan // execute transform plan
if (fft_type == FFTTransformType::C2R && forward) { exec_cufft_plan<DeviceContext, Ti, To>(ctx, *config, &collapsed_input,
forward = false; &collapsed_output, forward);
framework::Tensor input_conj(input.type());
input_conj.mutable_data<Ti>(input.dims(), ctx.GetPlace()); #elif defined(PADDLE_WITH_HIP)
platform::ForRange<DeviceContext> for_range(ctx, input.numel()); // create plan
math::ConjFunctor<Ti> functor(input.data<Ti>(), input.numel(), FFTConfigKey key =
input_conj.data<Ti>()); create_fft_configkey(collapsed_input, collapsed_output, signal_ndim);
for_range(functor); const int64_t device_id = static_cast<int64_t>(
exec_cufft_plan(*config, input_conj.data<void>(), output.data<void>(), reinterpret_cast<const platform::CUDAPlace*>(&collapsed_input.place())
forward); ->GetDeviceId());
} else if (fft_type == FFTTransformType::R2C && !forward) { FFTConfigCache& plan_cache = get_fft_plan_cache(device_id);
forward = true; std::unique_lock<std::mutex> guard(plan_cache.mutex, std::defer_lock);
framework::Tensor out_conj(output.type()); guard.lock();
out_conj.mutable_data<To>(output.dims(), ctx.GetPlace()); config = &(plan_cache.lookup(key));
exec_cufft_plan(*config, input.data<void>(), out_conj.data<void>(),
forward);
platform::ForRange<DeviceContext> for_range(ctx, output.numel()); // prepare cufft for execution
math::ConjFunctor<To> functor(out_conj.data<To>(), output.numel(), PADDLE_ENFORCE_CUDA_SUCCESS(
output.data<To>()); platform::dynload::hipfftSetStream(config->plan(), ctx.stream()));
for_range(functor); framework::Tensor workspace_tensor;
} else { workspace_tensor.mutable_data<To>(tensor_place, config->workspace_size());
exec_cufft_plan(*config, input.data<void>(), output.data<void>(), forward); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftSetWorkArea(
} config->plan(), workspace_tensor.data<To>()));
// execute transform plan
exec_hipfft_plan<DeviceContext, Ti, To>(ctx, *config, &collapsed_input,
&collapsed_output, forward);
#endif
// Inverting output by reshape and transpose to original batch and dimension // Inverting output by reshape and transpose to original batch and dimension
output.Resize(framework::make_ddim(reshape_out_sizes)); auto transposed_out_shape = out->dims().transpose(dim_permute);
out->Resize(framework::make_ddim(out_dims));
TransCompute<DeviceContext, To>(ndim, ctx, output, out, reverse_dim_permute);
}
// Calculates the normalization constant collapsed_output.Resize(transposed_out_shape);
double fft_normalization_scale(FFTNormMode normalization, auto& transposed_output = collapsed_output;
const std::vector<int64_t>& sizes,
const std::vector<int64_t>& dims) {
// auto norm = static_cast<fft_norm_mode>(normalization);
if (normalization == FFTNormMode::none) {
return static_cast<double>(1.0);
}
int64_t signal_numel = 1; std::vector<int> reverse_dim_permute(ndim);
for (auto dim : dims) { for (size_t i = 0; i < ndim; i++) {
signal_numel *= sizes[dim]; reverse_dim_permute[dim_permute[i]] = i;
} }
const double scale_denom = (normalization == FFTNormMode::by_sqrt_n)
? std::sqrt(signal_numel)
: static_cast<double>(signal_numel);
return static_cast<double>(1.0 / scale_denom);
}
template <typename DeviceContext, typename T> TransCompute<DeviceContext, To>(ndim, ctx, transposed_output, out,
void exec_normalization(const DeviceContext& ctx, const Tensor* in, Tensor* out, reverse_dim_permute);
FFTNormMode normalization,
const std::vector<int64_t>& sizes,
const std::vector<int64_t>& axes) {
double scale = fft_normalization_scale(normalization, sizes, axes);
if (scale != 1.0) {
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
auto dev = ctx.eigen_device();
EigenScale<Eigen::GpuDevice, T>::Eval(*dev, eigen_out, eigen_in,
static_cast<T>(scale),
static_cast<T>(0), false);
} else {
framework::TensorCopy(*in, ctx.GetPlace(), out);
}
} }
} // anonymous namespace } // anonymous namespace
// Use the optimized path to perform single R2C or C2R if transformation dim is // Use the optimized path to perform single R2C or C2R if transformation dim is
// supported by cuFFT // supported by cuFFT
bool use_optimized_cufft_path(const std::vector<int64_t>& axes) { bool use_optimized_fft_path(const std::vector<int64_t>& axes) {
// For performance reason, when axes starts with (0, 1), do not use the // For performance reason, when axes starts with (0, 1), do not use the
// optimized path. // optimized path.
if (axes.size() > kMaxCUFFTNdim || if (axes.size() > kMaxFFTNdim ||
(axes.size() >= 2 && axes[0] == 0 && axes[1] == 1)) { (axes.size() >= 2 && axes[0] == 0 && axes[1] == 1)) {
return false; return false;
} else { } else {
...@@ -532,7 +412,7 @@ struct FFTC2CFunctor<platform::CUDADeviceContext, Ti, To> { ...@@ -532,7 +412,7 @@ struct FFTC2CFunctor<platform::CUDADeviceContext, Ti, To> {
while (true) { while (true) {
max_dims = max_dims =
std::min(static_cast<size_t>(kMaxCUFFTNdim), working_axes.size()); std::min(static_cast<size_t>(kMaxFFTNdim), working_axes.size());
first_dims.assign(working_axes.end() - max_dims, working_axes.end()); first_dims.assign(working_axes.end() - max_dims, working_axes.end());
exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, p_working_tensor, exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, p_working_tensor,
...@@ -559,7 +439,7 @@ struct FFTC2RFunctor<platform::CUDADeviceContext, Ti, To> { ...@@ -559,7 +439,7 @@ struct FFTC2RFunctor<platform::CUDADeviceContext, Ti, To> {
std::vector<int64_t> in_dims = framework::vectorize(X->dims()); std::vector<int64_t> in_dims = framework::vectorize(X->dims());
std::vector<int64_t> out_dims = framework::vectorize(out->dims()); std::vector<int64_t> out_dims = framework::vectorize(out->dims());
if (use_optimized_cufft_path(axes)) { if (use_optimized_fft_path(axes)) {
framework::Tensor x_copy(X->type()); framework::Tensor x_copy(X->type());
x_copy.mutable_data<Ti>(X->dims(), ctx.GetPlace()); x_copy.mutable_data<Ti>(X->dims(), ctx.GetPlace());
framework::TensorCopy(*X, ctx.GetPlace(), &x_copy); framework::TensorCopy(*X, ctx.GetPlace(), &x_copy);
......
...@@ -7,7 +7,7 @@ if (NOT WITH_NV_JETSON) ...@@ -7,7 +7,7 @@ if (NOT WITH_NV_JETSON)
endif() endif()
if (WITH_ROCM) if (WITH_ROCM)
list(APPEND HIP_SRCS rocblas.cc miopen.cc hiprand.cc) list(APPEND HIP_SRCS rocblas.cc miopen.cc hiprand.cc hipfft.cc)
endif() endif()
# There is no macOS version of NCCL. # There is no macOS version of NCCL.
...@@ -49,3 +49,9 @@ endif() ...@@ -49,3 +49,9 @@ endif()
cc_library(dynload_lapack SRCS lapack.cc DEPS dynamic_loader) cc_library(dynload_lapack SRCS lapack.cc DEPS dynamic_loader)
add_dependencies(dynload_lapack extern_lapack) add_dependencies(dynload_lapack extern_lapack)
# TODO(TJ): add iomp, mkldnn? # TODO(TJ): add iomp, mkldnn?
# Build the dynamic-loading wrapper for libmkl_rt only when both MKL_FOUND and
# WITH_ONEMKL are set. mklrt.cc needs the MKL headers, so the MKL include
# directory is added privately to this target.
if (MKL_FOUND AND WITH_ONEMKL)
message("ONEMKL INCLUDE directory is ${MKL_INCLUDE}")
cc_library(dynload_mklrt SRCS mklrt.cc DEPS dynamic_loader)
target_include_directories(dynload_mklrt PRIVATE ${MKL_INCLUDE})
endif()
...@@ -53,6 +53,12 @@ DEFINE_string(mklml_dir, "", "Specify path for loading libmklml_intel.so."); ...@@ -53,6 +53,12 @@ DEFINE_string(mklml_dir, "", "Specify path for loading libmklml_intel.so.");
DEFINE_string(lapack_dir, "", "Specify path for loading liblapack.so."); DEFINE_string(lapack_dir, "", "Specify path for loading liblapack.so.");
// Directory searched for the oneMKL runtime (libmkl_rt). Left empty, the
// dynamic loader falls back to the default library search path.
DEFINE_string(mkl_dir, "",
              "Specify path for loading libmkl_rt.so. "
              "For instance, /opt/intel/oneapi/mkl/latest/lib/intel64/. "
              "If default, "
              "dlopen will search mkl from LD_LIBRARY_PATH");
DEFINE_string(op_dir, "", "Specify path for loading user-defined op library."); DEFINE_string(op_dir, "", "Specify path for loading user-defined op library.");
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
...@@ -350,6 +356,16 @@ void* GetCurandDsoHandle() { ...@@ -350,6 +356,16 @@ void* GetCurandDsoHandle() {
#endif #endif
} }
#ifdef PADDLE_WITH_HIP
// Returns the handle of the rocFFT shared library, looked up under
// FLAGS_rocm_dir. Only the file name differs between platforms.
void* GetROCFFTDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
  const char* const rocfft_lib_name = "librocfft.dylib";
#else
  const char* const rocfft_lib_name = "librocfft.so";
#endif
  return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, rocfft_lib_name);
}
#endif
void* GetNvjpegDsoHandle() { void* GetNvjpegDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__) #if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvjpeg.dylib"); return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvjpeg.dylib");
...@@ -518,6 +534,16 @@ void* GetCUFFTDsoHandle() { ...@@ -518,6 +534,16 @@ void* GetCUFFTDsoHandle() {
#endif #endif
} }
// Returns the handle of the oneMKL runtime library (mkl_rt), looked up under
// FLAGS_mkl_dir. Only the file name differs between platforms.
void* GetMKLRTDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
  const char* const mklrt_lib_name = "libmkl_rt.dylib";
#elif defined(_WIN32)
  const char* const mklrt_lib_name = "mkl_rt.dll";
#else
  const char* const mklrt_lib_name = "libmkl_rt.so";
#endif
  return GetDsoHandleFromSearchPath(FLAGS_mkl_dir, mklrt_lib_name);
}
} // namespace dynload } // namespace dynload
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -43,6 +43,8 @@ void* GetLAPACKDsoHandle(); ...@@ -43,6 +43,8 @@ void* GetLAPACKDsoHandle();
void* GetOpDsoHandle(const std::string& dso_name); void* GetOpDsoHandle(const std::string& dso_name);
void* GetNvtxDsoHandle(); void* GetNvtxDsoHandle();
void* GetCUFFTDsoHandle(); void* GetCUFFTDsoHandle();
void* GetMKLRTDsoHandle();
void* GetROCFFTDsoHandle();
void SetPaddleLibPath(const std::string&); void SetPaddleLibPath(const std::string&);
} // namespace dynload } // namespace dynload
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/dynload/hipfft.h"
namespace paddle {
namespace platform {
namespace dynload {
// Guards the one-time lookup of the shared library handle below.
std::once_flag hipfft_dso_flag;
// Handle of the dynamically loaded FFT library; filled in lazily by the
// DynLoad__* wrappers on their first call.
void* hipfft_dso_handle = nullptr;

// Defines one global wrapper object per routine listed in
// HIPFFT_FFT_ROUTINE_EACH, matching the extern declarations in hipfft.h.
#define DEFINE_WRAP(__name) DynLoad__##__name __name

HIPFFT_FFT_ROUTINE_EACH(DEFINE_WRAP);
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_HIP
#include <hipfft.h>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
namespace paddle {
namespace platform {
namespace dynload {
// Shared state for lazy loading: the flag makes the library lookup run once,
// the handle stores the result for every wrapper to use.
extern std::once_flag hipfft_dso_flag;
extern void *hipfft_dso_handle;
// Declares a callable struct `DynLoad__<name>` for the hipFFT routine `<name>`
// plus an extern object of that type carrying the routine's own name.
// On a call, the library handle is obtained once through GetROCFFTDsoHandle()
// (guarded by hipfft_dso_flag); the routine's symbol is then resolved with
// dlsym() a single time (function-local static) and the arguments are
// forwarded to it.
// NOTE(review): the dlsym() result is invoked without a null check, so a
// missing symbol presumably shows up as a null-pointer call; confirm.
#define DECLARE_DYNAMIC_LOAD_HIPFFT_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using hipfftFunc = decltype(&::__name); \
std::call_once(hipfft_dso_flag, []() { \
hipfft_dso_handle = paddle::platform::dynload::GetROCFFTDsoHandle(); \
}); \
static void *p_##__name = dlsym(hipfft_dso_handle, #__name); \
return reinterpret_cast<hipfftFunc>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
// X-macro listing every hipFFT routine that gets a dynamic-loading wrapper.
// `__macro` is applied to each routine name in turn; the line after the list
// uses it to declare the wrapper structs and their extern objects.
#define HIPFFT_FFT_ROUTINE_EACH(__macro) \
__macro(hipfftPlan1d); \
__macro(hipfftPlan2d); \
__macro(hipfftPlan3d); \
__macro(hipfftPlanMany); \
__macro(hipfftMakePlan1d); \
__macro(hipfftMakePlanMany); \
__macro(hipfftMakePlanMany64); \
__macro(hipfftGetSizeMany64); \
__macro(hipfftEstimate1d); \
__macro(hipfftEstimate2d); \
__macro(hipfftEstimate3d); \
__macro(hipfftEstimateMany); \
__macro(hipfftCreate); \
__macro(hipfftGetSize1d); \
__macro(hipfftGetSizeMany); \
__macro(hipfftGetSize); \
__macro(hipfftSetWorkArea); \
__macro(hipfftSetAutoAllocation); \
__macro(hipfftExecC2C); \
__macro(hipfftExecR2C); \
__macro(hipfftExecC2R); \
__macro(hipfftExecZ2Z); \
__macro(hipfftExecD2Z); \
__macro(hipfftExecZ2D); \
__macro(hipfftSetStream); \
__macro(hipfftDestroy); \
__macro(hipfftGetVersion); \
__macro(hipfftGetProperty);
HIPFFT_FFT_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_HIPFFT_WRAP);
// Maps a hipFFT status code to a human-readable description. Status values
// without a dedicated entry yield "HIPFFT_STATUS_UNKNOWN_ERROR".
inline const char *hipfftGetErrorString(hipfftResult_t status) {
  const char *description = "HIPFFT_STATUS_UNKNOWN_ERROR";
  switch (status) {
    case HIPFFT_SUCCESS:
      description = "'HIPFFT_SUCCESS'. The hipFFT operation was successful.";
      break;
    case HIPFFT_INVALID_PLAN:
      description =
          "'HIPFFT_INVALID_PLAN'. hipFFT was passed an invalid plan handle.";
      break;
    case HIPFFT_ALLOC_FAILED:
      description =
          "'HIPFFT_ALLOC_FAILED'. hipFFT failed to allocate GPU or CPU "
          "memory.";
      break;
    case HIPFFT_INVALID_TYPE:
      description = "'HIPFFT_INVALID_TYPE'. No longer used.";
      break;
    case HIPFFT_INVALID_VALUE:
      description =
          "'HIPFFT_INVALID_VALUE'. User specified an invalid pointer or "
          "parameter.";
      break;
    case HIPFFT_INTERNAL_ERROR:
      description =
          "'HIPFFT_INTERNAL_ERROR'. Driver or internal hipFFT library "
          "error.";
      break;
    case HIPFFT_EXEC_FAILED:
      description =
          "'HIPFFT_EXEC_FAILED'. Failed to execute an FFT on the GPU.";
      break;
    case HIPFFT_SETUP_FAILED:
      description =
          "'HIPFFT_SETUP_FAILED'. The hipFFT library failed to initialize.";
      break;
    case HIPFFT_INVALID_SIZE:
      description =
          "'HIPFFT_INVALID_SIZE'. User specified an invalid transform size.";
      break;
    case HIPFFT_UNALIGNED_DATA:
      description = "'HIPFFT_UNALIGNED_DATA'. No longer used.";
      break;
    case HIPFFT_INCOMPLETE_PARAMETER_LIST:
      description =
          "'HIPFFT_INCOMPLETE_PARAMETER_LIST'. Missing parameters in call.";
      break;
    case HIPFFT_INVALID_DEVICE:
      description =
          "'HIPFFT_INVALID_DEVICE'. Execution of a plan was on different "
          "GPU than plan creation.";
      break;
    case HIPFFT_PARSE_ERROR:
      description = "'HIPFFT_PARSE_ERROR'. Internal plan database error.";
      break;
    case HIPFFT_NO_WORKSPACE:
      description =
          "'HIPFFT_NO_WORKSPACE'. No workspace has been provided prior to "
          "plan execution.";
      break;
    case HIPFFT_NOT_IMPLEMENTED:
      description =
          "'HIPFFT_NOT_IMPLEMENTED'. Function does not implement "
          "functionality for parameters given.";
      break;
    case HIPFFT_NOT_SUPPORTED:
      description =
          "'HIPFFT_NOT_SUPPORTED'. Operation is not supported for "
          "parameters given.";
      break;
    default:
      // Keep the unknown-error text assigned above.
      break;
  }
  return description;
}
} // namespace dynload
} // namespace platform
} // namespace paddle
#endif
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/dynload/mklrt.h"
namespace paddle {
namespace platform {
namespace dynload {
// Guards the one-time lookup of the mkl_rt library handle.
std::once_flag mklrt_dso_flag;
// Handle of the dynamically loaded mkl_rt library; set lazily by the
// DynLoad__* wrappers declared in mklrt.h.
void* mklrt_dso_handle = nullptr;
// Defines one global wrapper object per routine in MKLDFTI_ROUTINE_EACH,
// matching the extern declarations produced in mklrt.h.
#define DEFINE_WRAP(__name) DynLoad__##__name __name
MKLDFTI_ROUTINE_EACH(DEFINE_WRAP);
// Creates a DFTI descriptor by dispatching to the precision- and
// rank-specific creation routine:
//   - DFTI_SINGLE / DFTI_DOUBLE with dim == 1 use the *_1d variant and pass
//     the single length sizes[0];
//   - DFTI_SINGLE / DFTI_DOUBLE with any other dim use the *_md variant and
//     pass the whole sizes array;
//   - any other precision value is forwarded unchanged to the generic
//     DftiCreateDescriptor routine.
// Returns the status value produced by the routine that was called.
DFTI_EXTERN MKL_LONG DftiCreateDescriptorX(DFTI_DESCRIPTOR_HANDLE* desc,
                                           enum DFTI_CONFIG_VALUE prec,
                                           enum DFTI_CONFIG_VALUE domain,
                                           MKL_LONG dim, MKL_LONG* sizes) {
  const bool is_one_dimensional = (dim == 1);
  switch (prec) {
    case DFTI_SINGLE:
      if (is_one_dimensional) {
        return DftiCreateDescriptor_s_1d(desc, domain, sizes[0]);
      }
      return DftiCreateDescriptor_s_md(desc, domain, dim, sizes);
    case DFTI_DOUBLE:
      if (is_one_dimensional) {
        return DftiCreateDescriptor_d_1d(desc, domain, sizes[0]);
      }
      return DftiCreateDescriptor_d_md(desc, domain, dim, sizes);
    default:
      return DftiCreateDescriptor(desc, prec, domain, dim, sizes);
  }
}
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <mkl_dfti.h>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
namespace paddle {
namespace platform {
namespace dynload {
// Shared state for lazy loading: the flag makes the library lookup run once,
// the handle stores the result for every wrapper to use.
extern std::once_flag mklrt_dso_flag;
extern void* mklrt_dso_handle;
/**
 * The following macro generates, for each MKL DFTI routine `<name>`, a struct
 * `DynLoad__<name>` whose overloaded call operator loads the routine
 * dynamically, plus an extern object of that type named `<name>`.
 *
 * On a call, the library handle is obtained once through GetMKLRTDsoHandle()
 * (guarded by mklrt_dso_flag); the routine's symbol is then resolved with
 * dlsym() a single time (function-local static) and the arguments are
 * forwarded to it.
 *
 * NOTE(review): the dlsym() result is invoked without a null check, so a
 * missing symbol presumably shows up as a null-pointer call; confirm.
 */
#define DYNAMIC_LOAD_MKLRT_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using mklrtFunc = decltype(&::__name); \
std::call_once(mklrt_dso_flag, []() { \
mklrt_dso_handle = paddle::platform::dynload::GetMKLRTDsoHandle(); \
}); \
static void* p_##__name = dlsym(mklrt_dso_handle, #__name); \
return reinterpret_cast<mklrtFunc>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
// mkl_dfti.h defines a macro that shadows the function of the same name;
// undefine that macro so that the function itself can be exported.
#undef DftiCreateDescriptor
// X-macro listing every MKL DFTI routine that gets a dynamic-loading wrapper.
// `__macro` is applied to each routine name in turn.
#define MKLDFTI_ROUTINE_EACH(__macro) \
__macro(DftiCreateDescriptor); \
__macro(DftiCreateDescriptor_s_1d); \
__macro(DftiCreateDescriptor_d_1d); \
__macro(DftiCreateDescriptor_s_md); \
__macro(DftiCreateDescriptor_d_md); \
__macro(DftiSetValue); \
__macro(DftiGetValue); \
__macro(DftiCommitDescriptor); \
__macro(DftiComputeForward); \
__macro(DftiComputeBackward); \
__macro(DftiFreeDescriptor); \
__macro(DftiErrorClass); \
__macro(DftiErrorMessage);
// Declare the wrapper struct and extern object for every routine listed
// above; the helper macro is not needed afterwards and is removed.
MKLDFTI_ROUTINE_EACH(DYNAMIC_LOAD_MKLRT_WRAP)
#undef DYNAMIC_LOAD_MKLRT_WRAP
// Descriptor-creation entry point declared under a distinct name to avoid a
// naming conflict with DftiCreateDescriptor.
//   desc   - receives the created descriptor handle
//   prec   - precision selector (e.g. DFTI_SINGLE or DFTI_DOUBLE)
//   domain - forward domain of the transform
//   dim    - number of transform dimensions
//   sizes  - transform lengths; presumably `dim` entries - TODO confirm
// Returns the MKL_LONG status produced by the underlying creation routine.
DFTI_EXTERN MKL_LONG DftiCreateDescriptorX(DFTI_DESCRIPTOR_HANDLE* desc,
enum DFTI_CONFIG_VALUE prec,
enum DFTI_CONFIG_VALUE domain,
MKL_LONG dim, MKL_LONG* sizes);
} // namespace dynload
} // namespace platform
} // namespace paddle
...@@ -31,6 +31,7 @@ limitations under the License. */ ...@@ -31,6 +31,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include <cublas_v2.h> #include <cublas_v2.h>
#include <cudnn.h> #include <cudnn.h>
#include <cufft.h>
#include <curand.h> #include <curand.h>
#include <thrust/system/cuda/error.h> #include <thrust/system/cuda/error.h>
#include <thrust/system_error.h> #include <thrust/system_error.h>
...@@ -85,6 +86,7 @@ limitations under the License. */ ...@@ -85,6 +86,7 @@ limitations under the License. */
#endif // PADDLE_WITH_CUDA #endif // PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/dynload/hipfft.h"
#include "paddle/fluid/platform/dynload/hiprand.h" #include "paddle/fluid/platform/dynload/hiprand.h"
#include "paddle/fluid/platform/dynload/miopen.h" #include "paddle/fluid/platform/dynload/miopen.h"
#include "paddle/fluid/platform/dynload/rocblas.h" #include "paddle/fluid/platform/dynload/rocblas.h"
...@@ -714,6 +716,7 @@ DEFINE_EXTERNAL_API_TYPE(curandStatus_t, CURAND_STATUS_SUCCESS, CURAND); ...@@ -714,6 +716,7 @@ DEFINE_EXTERNAL_API_TYPE(curandStatus_t, CURAND_STATUS_SUCCESS, CURAND);
DEFINE_EXTERNAL_API_TYPE(cudnnStatus_t, CUDNN_STATUS_SUCCESS, CUDNN); DEFINE_EXTERNAL_API_TYPE(cudnnStatus_t, CUDNN_STATUS_SUCCESS, CUDNN);
DEFINE_EXTERNAL_API_TYPE(cublasStatus_t, CUBLAS_STATUS_SUCCESS, CUBLAS); DEFINE_EXTERNAL_API_TYPE(cublasStatus_t, CUBLAS_STATUS_SUCCESS, CUBLAS);
DEFINE_EXTERNAL_API_TYPE(cusolverStatus_t, CUSOLVER_STATUS_SUCCESS, CUSOLVER); DEFINE_EXTERNAL_API_TYPE(cusolverStatus_t, CUSOLVER_STATUS_SUCCESS, CUSOLVER);
DEFINE_EXTERNAL_API_TYPE(cufftResult_t, CUFFT_SUCCESS, CUFFT);
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
DEFINE_EXTERNAL_API_TYPE(ncclResult_t, ncclSuccess, NCCL); DEFINE_EXTERNAL_API_TYPE(ncclResult_t, ncclSuccess, NCCL);
...@@ -751,6 +754,8 @@ inline const char* GetErrorMsgUrl(T status) { ...@@ -751,6 +754,8 @@ inline const char* GetErrorMsgUrl(T status) {
return "https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/" return "https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/"
"types.html#ncclresult-t"; "types.html#ncclresult-t";
break; break;
case platform::proto::ApiType::CUFFT:
return "https://docs.nvidia.com/cuda/cufft/index.html#cufftresult";
default: default:
return "Unknown type of External API, can't get error message URL!"; return "Unknown type of External API, can't get error message URL!";
break; break;
...@@ -839,6 +844,7 @@ template std::string GetExternalErrorMsg<curandStatus_t>(curandStatus_t); ...@@ -839,6 +844,7 @@ template std::string GetExternalErrorMsg<curandStatus_t>(curandStatus_t);
template std::string GetExternalErrorMsg<cudnnStatus_t>(cudnnStatus_t); template std::string GetExternalErrorMsg<cudnnStatus_t>(cudnnStatus_t);
template std::string GetExternalErrorMsg<cublasStatus_t>(cublasStatus_t); template std::string GetExternalErrorMsg<cublasStatus_t>(cublasStatus_t);
template std::string GetExternalErrorMsg<cusolverStatus_t>(cusolverStatus_t); template std::string GetExternalErrorMsg<cusolverStatus_t>(cusolverStatus_t);
template std::string GetExternalErrorMsg<cufftResult_t>(cufftResult_t);
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
template std::string GetExternalErrorMsg<ncclResult_t>(ncclResult_t); template std::string GetExternalErrorMsg<ncclResult_t>(ncclResult_t);
#endif #endif
...@@ -899,6 +905,15 @@ inline std::string build_nvidia_error_msg(cusolverStatus_t stat) { ...@@ -899,6 +905,15 @@ inline std::string build_nvidia_error_msg(cusolverStatus_t stat) {
return sout.str(); return sout.str();
} }
/*************** CUFFT ERROR ***************/
// Reports whether a cuFFT status code denotes a failure, i.e. anything other
// than CUFFT_SUCCESS.
inline bool is_error(cufftResult_t stat) {
  const bool succeeded = (stat == CUFFT_SUCCESS);
  return !succeeded;
}
// Builds the error text for a cuFFT status: a "CUFFT error(<code>). " prefix
// followed by the message returned by GetExternalErrorMsg for that status.
inline std::string build_nvidia_error_msg(cufftResult_t stat) {
  std::ostringstream message;
  message << "CUFFT error(" << stat << "). ";
  message << GetExternalErrorMsg(stat);
  return message.str();
}
/**************** NCCL ERROR ****************/ /**************** NCCL ERROR ****************/
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
inline bool is_error(ncclResult_t nccl_result) { inline bool is_error(ncclResult_t nccl_result) {
...@@ -1099,6 +1114,14 @@ inline std::string build_rocm_error_msg(ncclResult_t nccl_result) { ...@@ -1099,6 +1114,14 @@ inline std::string build_rocm_error_msg(ncclResult_t nccl_result) {
} }
#endif // not(__APPLE__) and PADDLE_WITH_NCCL #endif // not(__APPLE__) and PADDLE_WITH_NCCL
/***** HIPFFT ERROR *****/
// Reports whether a hipFFT status code denotes a failure, i.e. anything other
// than HIPFFT_SUCCESS.
inline bool is_error(hipfftResult_t stat) {
  const bool succeeded = (stat == HIPFFT_SUCCESS);
  return !succeeded;
}
// Builds the error text for a hipFFT status: the " HIPFFT error, " prefix, the
// description from hipfftGetErrorString, and a trailing space.
inline std::string build_rocm_error_msg(hipfftResult_t stat) {
  std::string message = " HIPFFT error, ";
  message += platform::dynload::hipfftGetErrorString(stat);
  message += " ";
  return message;
}
namespace details { namespace details {
template <typename T> template <typename T>
...@@ -1115,6 +1138,7 @@ DEFINE_EXTERNAL_API_TYPE(hipError_t, hipSuccess); ...@@ -1115,6 +1138,7 @@ DEFINE_EXTERNAL_API_TYPE(hipError_t, hipSuccess);
DEFINE_EXTERNAL_API_TYPE(hiprandStatus_t, HIPRAND_STATUS_SUCCESS); DEFINE_EXTERNAL_API_TYPE(hiprandStatus_t, HIPRAND_STATUS_SUCCESS);
DEFINE_EXTERNAL_API_TYPE(miopenStatus_t, miopenStatusSuccess); DEFINE_EXTERNAL_API_TYPE(miopenStatus_t, miopenStatusSuccess);
DEFINE_EXTERNAL_API_TYPE(rocblas_status, rocblas_status_success); DEFINE_EXTERNAL_API_TYPE(rocblas_status, rocblas_status_success);
DEFINE_EXTERNAL_API_TYPE(hipfftResult_t, HIPFFT_SUCCESS);
#if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL) #if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL)
DEFINE_EXTERNAL_API_TYPE(ncclResult_t, ncclSuccess); DEFINE_EXTERNAL_API_TYPE(ncclResult_t, ncclSuccess);
......
...@@ -9,10 +9,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -9,10 +9,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
#include <list> #include <list>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/platform/enforce.h"
TEST(ENFORCE, OK) { TEST(ENFORCE, OK) {
PADDLE_ENFORCE(true, paddle::platform::errors::Unavailable( PADDLE_ENFORCE(true, paddle::platform::errors::Unavailable(
...@@ -330,6 +331,10 @@ TEST(enforce, hip_success) { ...@@ -330,6 +331,10 @@ TEST(enforce, hip_success) {
CheckCudaStatusFailure(rocblas_status_invalid_handle, "Rocblas error")); CheckCudaStatusFailure(rocblas_status_invalid_handle, "Rocblas error"));
EXPECT_TRUE( EXPECT_TRUE(
CheckCudaStatusFailure(rocblas_status_invalid_value, "Rocblas error")); CheckCudaStatusFailure(rocblas_status_invalid_value, "Rocblas error"));
EXPECT_TRUE(CheckCudaStatusSuccess(HIPFFT_SUCCESS));
EXPECT_TRUE(CheckCudaStatusFailure(HIPFFT_INVALID_PLAN, "HIPFFT error"));
EXPECT_TRUE(CheckCudaStatusFailure(HIPFFT_ALLOC_FAILED, "HIPFFT error"));
#if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL) #if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL)
EXPECT_TRUE(CheckCudaStatusSuccess(ncclSuccess)); EXPECT_TRUE(CheckCudaStatusSuccess(ncclSuccess));
EXPECT_TRUE(CheckCudaStatusFailure(ncclUnhandledCudaError, "Rccl error")); EXPECT_TRUE(CheckCudaStatusFailure(ncclUnhandledCudaError, "Rccl error"));
...@@ -418,6 +423,25 @@ TEST(enforce, cuda_success) { ...@@ -418,6 +423,25 @@ TEST(enforce, cuda_success) {
"negative vector size, for example).To correct: ensure that all the " "negative vector size, for example).To correct: ensure that all the "
"parameters being passed have valid values")); "parameters being passed have valid values"));
EXPECT_TRUE(CheckCudaStatusSuccess(CUFFT_SUCCESS));
EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_INVALID_PLAN, "CUFFT error"));
EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_ALLOC_FAILED, "CUFFT error"));
EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_INVALID_TYPE, "CUFFT error"));
EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_INVALID_VALUE, "CUFFT error"));
EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_INTERNAL_ERROR, "CUFFT error"));
EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_EXEC_FAILED, "CUFFT error"));
EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_SETUP_FAILED, "CUFFT error"));
EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_INVALID_SIZE, "CUFFT error"));
EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_UNALIGNED_DATA, "CUFFT error"));
EXPECT_TRUE(
CheckCudaStatusFailure(CUFFT_INCOMPLETE_PARAMETER_LIST, "CUFFT error"));
EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_INVALID_DEVICE, "CUFFT error"));
EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_PARSE_ERROR, "CUFFT error"));
EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_NO_WORKSPACE, "CUFFT error"));
EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_NOT_IMPLEMENTED, "CUFFT error"));
EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_LICENSE_ERROR, "CUFFT error"));
EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_NOT_SUPPORTED, "CUFFT error"));
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
EXPECT_TRUE(CheckCudaStatusSuccess(ncclSuccess)); EXPECT_TRUE(CheckCudaStatusSuccess(ncclSuccess));
EXPECT_TRUE(CheckCudaStatusFailure(ncclUnhandledCudaError, "NCCL error")); EXPECT_TRUE(CheckCudaStatusFailure(ncclUnhandledCudaError, "NCCL error"));
......
...@@ -24,6 +24,7 @@ enum ApiType { ...@@ -24,6 +24,7 @@ enum ApiType {
CUBLAS = 3; CUBLAS = 3;
CUSOLVER = 4; CUSOLVER = 4;
NCCL = 5; NCCL = 5;
CUFFT = 6;
} }
message MessageDesc { message MessageDesc {
......
...@@ -64,7 +64,6 @@ import paddle.reader # noqa: F401 ...@@ -64,7 +64,6 @@ import paddle.reader # noqa: F401
import paddle.static # noqa: F401 import paddle.static # noqa: F401
import paddle.vision # noqa: F401 import paddle.vision # noqa: F401
from .tensor import fft
from .tensor.random import bernoulli # noqa: F401 from .tensor.random import bernoulli # noqa: F401
from .tensor.attribute import rank # noqa: F401 from .tensor.attribute import rank # noqa: F401
...@@ -297,6 +296,8 @@ from .hapi import summary # noqa: F401 ...@@ -297,6 +296,8 @@ from .hapi import summary # noqa: F401
from .hapi import flops # noqa: F401 from .hapi import flops # noqa: F401
from . import hub # noqa: F401 from . import hub # noqa: F401
from . import linalg # noqa: F401 from . import linalg # noqa: F401
from . import fft # noqa: F401
from . import signal # noqa: F401
import paddle.text # noqa: F401 import paddle.text # noqa: F401
import paddle.vision # noqa: F401 import paddle.vision # noqa: F401
......
...@@ -15,30 +15,30 @@ ...@@ -15,30 +15,30 @@
from typing import Sequence from typing import Sequence
import numpy as np import numpy as np
import paddle import paddle
from .attribute import is_complex, is_floating_point, is_interger, _real_to_complex_dtype, _complex_to_real_dtype from .tensor.attribute import is_complex, is_floating_point, is_interger, _real_to_complex_dtype, _complex_to_real_dtype
from ..fluid.framework import in_dygraph_mode from .fluid.framework import in_dygraph_mode
from .. import _C_ops from . import _C_ops
from ..fluid.data_feeder import check_variable_and_dtype from .fluid.data_feeder import check_variable_and_dtype
from ..fluid.layer_helper import LayerHelper from .fluid.layer_helper import LayerHelper
__all__ = [ __all__ = [
'fft', 'fft',
'fft2',
'fftn',
'ifft', 'ifft',
'ifft2',
'ifftn',
'rfft', 'rfft',
'rfft2',
'rfftn',
'irfft', 'irfft',
'irfft2',
'irfftn',
'hfft', 'hfft',
'hfft2',
'hfftn',
'ihfft', 'ihfft',
'fft2',
'ifft2',
'rfft2',
'irfft2',
'hfft2',
'ihfft2', 'ihfft2',
'fftn',
'ifftn',
'rfftn',
'irfftn',
'hfftn',
'ihfftn', 'ihfftn',
'fftfreq', 'fftfreq',
'rfftfreq', 'rfftfreq',
...@@ -362,7 +362,7 @@ def irfft(x, n=None, axis=-1, norm="backward", name=None): ...@@ -362,7 +362,7 @@ def irfft(x, n=None, axis=-1, norm="backward", name=None):
xp = paddle.to_tensor(x) xp = paddle.to_tensor(x)
irfft_xp = paddle.fft.irfft(xp).numpy() irfft_xp = paddle.fft.irfft(xp).numpy()
print(irfft_xp) print(irfft_xp)
# [0. 0. 0. 4.] # [0. 1. 0. 0.]
""" """
return fft_c2r(x, n, axis, norm, forward=False, name=name) return fft_c2r(x, n, axis, norm, forward=False, name=name)
...@@ -500,7 +500,7 @@ def fftn(x, s=None, axes=None, norm="backward", name=None): ...@@ -500,7 +500,7 @@ def fftn(x, s=None, axes=None, norm="backward", name=None):
import numpy as np import numpy as np
import paddle import paddle
x = x = np.mgrid[:4, :4, :4][1] x = np.mgrid[:4, :4, :4][1]
xp = paddle.to_tensor(x) xp = paddle.to_tensor(x)
fftn_xp = paddle.fft.fftn(xp, axes=(1, 2)).numpy() fftn_xp = paddle.fft.fftn(xp, axes=(1, 2)).numpy()
print(fftn_xp) print(fftn_xp)
...@@ -654,9 +654,9 @@ def rfftn(x, s=None, axes=None, norm="backward", name=None): ...@@ -654,9 +654,9 @@ def rfftn(x, s=None, axes=None, norm="backward", name=None):
# use axes(2, 0) # use axes(2, 0)
print(paddle.fft.rfftn(x, axes=(2, 0))) print(paddle.fft.rfftn(x, axes=(2, 0)))
# Tensor(shape=[2, 3, 3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, # Tensor(shape=[2, 3, 3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True,
# [[[(24+0j), 0j , 0j ], # [[[(8+0j), 0j , 0j ],
# [0j , 0j , 0j ], # [(8+0j), 0j , 0j ],
# [0j , 0j , 0j ]], # [(8+0j), 0j , 0j ]],
# #
# [[0j , 0j , 0j ], # [[0j , 0j , 0j ],
# [0j , 0j , 0j ], # [0j , 0j , 0j ],
...@@ -1135,7 +1135,24 @@ def ihfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): ...@@ -1135,7 +1135,24 @@ def ihfft2(x, s=None, axes=(-2, -1), norm="backward", name=None):
refer to :ref:`api_guide_Name` . refer to :ref:`api_guide_Name` .
Returns: Returns:
out(Tensor) : The result of the inverse real 2-D FFT. out(Tensor) : The result of the inverse hermitian 2-D FFT.
Examples:
.. code-block:: python
import numpy as np
import paddle
x = np.mgrid[:5, :5][0].astype(np.float64)
xp = paddle.to_tensor(x)
ihfft2_xp = paddle.fft.ihfft2(xp).numpy()
print(ihfft2_xp)
# [[ 2. +0.j 0. +0.j 0. +0.j ]
# [-0.5-0.68819096j 0. +0.j 0. +0.j ]
# [-0.5-0.16245985j 0. +0.j 0. +0.j ]
# [-0.5+0.16245985j 0. +0.j 0. +0.j ]
# [-0.5+0.68819096j 0. +0.j 0. +0.j ]]
""" """
_check_at_least_ndim(x, 2) _check_at_least_ndim(x, 2)
if s is not None: if s is not None:
...@@ -1273,9 +1290,8 @@ def fftshift(x, axes=None, name=None): ...@@ -1273,9 +1290,8 @@ def fftshift(x, axes=None, name=None):
import paddle import paddle
x = np.array([3, 1, 2, 2, 3], dtype=float) x = np.array([3, 1, 2, 2, 3], dtype=float)
scalar_temp = 0.3
n = x.size n = x.size
fftfreq_xp = paddle.fft.fftfreq(n, d=scalar_temp) fftfreq_xp = paddle.fft.fftfreq(n, d=0.3)
res = paddle.fft.fftshift(fftfreq_xp).numpy() res = paddle.fft.fftshift(fftfreq_xp).numpy()
print(res) print(res)
# [-1.3333334 -0.6666667 0. 0.6666667 1.3333334] # [-1.3333334 -0.6666667 0. 0.6666667 1.3333334]
...@@ -1284,13 +1300,13 @@ def fftshift(x, axes=None, name=None): ...@@ -1284,13 +1300,13 @@ def fftshift(x, axes=None, name=None):
shape = paddle.shape(x) shape = paddle.shape(x)
if axes is None: if axes is None:
# shift all axes # shift all axes
rank = paddle.rank(x).reshape([1]) rank = len(x.shape)
axes = axes or paddle.arange(0, rank) axes = list(range(0, rank))
shifts = [size // 2 for size in shape] shifts = shape // 2
elif isinstance(axes, int): elif isinstance(axes, int):
shifts = shape[axes] // 2 shifts = shape[axes] // 2
else: else:
shifts = [shape[ax] // 2 for ax in axes] shifts = paddle.concat([shape[ax] // 2 for ax in axes])
return paddle.roll(x, shifts, axes, name=name) return paddle.roll(x, shifts, axes, name=name)
...@@ -1317,9 +1333,8 @@ def ifftshift(x, axes=None, name=None): ...@@ -1317,9 +1333,8 @@ def ifftshift(x, axes=None, name=None):
import paddle import paddle
x = np.array([3, 1, 2, 2, 3], dtype=float) x = np.array([3, 1, 2, 2, 3], dtype=float)
scalar_temp = 0.3
n = x.size n = x.size
fftfreq_xp = paddle.fft.fftfreq(n, d=scalar_temp) fftfreq_xp = paddle.fft.fftfreq(n, d=0.3)
res = paddle.fft.ifftshift(fftfreq_xp).numpy() res = paddle.fft.ifftshift(fftfreq_xp).numpy()
print(res) print(res)
# [ 1.3333334 -1.3333334 -0.6666667 0. 0.6666667] # [ 1.3333334 -1.3333334 -0.6666667 0. 0.6666667]
...@@ -1328,13 +1343,13 @@ def ifftshift(x, axes=None, name=None): ...@@ -1328,13 +1343,13 @@ def ifftshift(x, axes=None, name=None):
shape = paddle.shape(x) shape = paddle.shape(x)
if axes is None: if axes is None:
# shift all axes # shift all axes
rank = paddle.rank(x).reshape([1]) rank = len(x.shape)
axes = axes or paddle.arange(0, rank) axes = list(range(0, rank))
shifts = [-size // 2 for size in shape] shifts = shape // 2
elif isinstance(axes, int): elif isinstance(axes, int):
shifts = -shape[axes] // 2 shifts = -shape[axes] // 2
else: else:
shifts = [-shape[ax] // 2 for ax in axes] shifts = paddle.concat([-shape[ax] // 2 for ax in axes])
return paddle.roll(x, shifts, axes, name=name) return paddle.roll(x, shifts, axes, name=name)
...@@ -1346,7 +1361,7 @@ def fft_c2c(x, n, axis, norm, forward, name): ...@@ -1346,7 +1361,7 @@ def fft_c2c(x, n, axis, norm, forward, name):
x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) x = paddle.cast(x, _real_to_complex_dtype(x.dtype))
_check_normalization(norm) _check_normalization(norm)
axis = axis or -1 axis = axis if axis is not None else -1
_check_fft_axis(x, axis) _check_fft_axis(x, axis)
axes = [axis] axes = [axis]
axes = _normalize_axes(x, axes) axes = _normalize_axes(x, axes)
...@@ -1376,7 +1391,7 @@ def fft_r2c(x, n, axis, norm, forward, onesided, name): ...@@ -1376,7 +1391,7 @@ def fft_r2c(x, n, axis, norm, forward, onesided, name):
if is_interger(x): if is_interger(x):
x = paddle.cast(x, paddle.get_default_dtype()) x = paddle.cast(x, paddle.get_default_dtype())
_check_normalization(norm) _check_normalization(norm)
axis = axis or -1 axis = axis if axis is not None else -1
_check_fft_axis(x, axis) _check_fft_axis(x, axis)
axes = [axis] axes = [axis]
axes = _normalize_axes(x, axes) axes = _normalize_axes(x, axes)
...@@ -1415,7 +1430,7 @@ def fft_c2r(x, n, axis, norm, forward, name): ...@@ -1415,7 +1430,7 @@ def fft_c2r(x, n, axis, norm, forward, name):
elif is_floating_point(x): elif is_floating_point(x):
x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) x = paddle.cast(x, _real_to_complex_dtype(x.dtype))
_check_normalization(norm) _check_normalization(norm)
axis = axis or -1 axis = axis if axis is not None else -1
_check_fft_axis(x, axis) _check_fft_axis(x, axis)
axes = [axis] axes = [axis]
axes = _normalize_axes(x, axes) axes = _normalize_axes(x, axes)
......
...@@ -1009,10 +1009,11 @@ class TestRfftFreq(unittest.TestCase): ...@@ -1009,10 +1009,11 @@ class TestRfftFreq(unittest.TestCase):
@place(DEVICES) @place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'axes', 'dtype'), [ @parameterize(
('test_1d', np.random.randn(10), (0, ), 'float64'), (TEST_CASE_NAME, 'x', 'axes', 'dtype'),
[('test_1d', np.random.randn(10), (0, ), 'float64'),
('test_2d', np.random.randn(10, 10), (0, 1), 'float64'), ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
]) ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64')])
class TestFftShift(unittest.TestCase): class TestFftShift(unittest.TestCase):
def test_fftshift(self): def test_fftshift(self):
"""Test fftshift with norm condition """Test fftshift with norm condition
...@@ -1030,6 +1031,7 @@ class TestFftShift(unittest.TestCase): ...@@ -1030,6 +1031,7 @@ class TestFftShift(unittest.TestCase):
@parameterize((TEST_CASE_NAME, 'x', 'axes'), [ @parameterize((TEST_CASE_NAME, 'x', 'axes'), [
('test_1d', np.random.randn(10), (0, ), 'float64'), ('test_1d', np.random.randn(10), (0, ), 'float64'),
('test_2d', np.random.randn(10, 10), (0, 1), 'float64'), ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'),
]) ])
class TestIfftShift(unittest.TestCase): class TestIfftShift(unittest.TestCase):
def test_ifftshift(self): def test_ifftshift(self):
......
...@@ -122,6 +122,34 @@ class TestRollAPI(unittest.TestCase): ...@@ -122,6 +122,34 @@ class TestRollAPI(unittest.TestCase):
self.assertRaises(ValueError, test_axis_out_range) self.assertRaises(ValueError, test_axis_out_range)
def test_shifts_as_tensor_dygraph(self):
    """Check `paddle.roll` in dygraph mode when `shifts` is passed as a Tensor."""
    wanted = np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]])
    with fluid.dygraph.guard():
        data = paddle.arange(9).reshape([3, 3])
        # Floor-dividing the shape tensor keeps `shifts` a Tensor rather
        # than a Python list, which is the case under test.
        half_extent = paddle.shape(data) // 2
        rolled = paddle.roll(data, shifts=half_extent, axis=[0, 1])
        self.assertTrue(np.allclose(rolled.numpy(), wanted))
def test_shifts_as_tensor_static(self):
    """Check `paddle.roll` in static-graph mode when `shifts` is passed as a Tensor.

    The program is executed on CPU and, when Paddle is compiled with CUDA,
    on GPU device 0 as well; both results are compared with the same
    expected array.
    """
    with program_guard(Program(), Program()):
        x = paddle.arange(9).reshape([3, 3]).astype('float32')
        shape = paddle.shape(x)
        # Floor-dividing the shape tensor keeps `shifts` a Tensor.
        shifts = shape // 2
        axes = [0, 1]
        out = paddle.roll(x, shifts=shifts, axis=axes)
        expected_out = np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]])

        exe = fluid.Executor(fluid.CPUPlace())
        [out_np] = exe.run(fetch_list=[out])
        self.assertTrue(np.allclose(out_np, expected_out))

        if paddle.is_compiled_with_cuda():
            # Run on the GPU here: creating another CPU executor in this
            # branch would only repeat the CPU check above.
            exe = fluid.Executor(fluid.CUDAPlace(0))
            [out_np] = exe.run(fetch_list=[out])
            self.assertTrue(np.allclose(out_np, expected_out))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -652,7 +652,7 @@ class TestFrame(unittest.TestCase): ...@@ -652,7 +652,7 @@ class TestFrame(unittest.TestCase):
self.assertTrue( self.assertTrue(
np.allclose( np.allclose(
frame_for_api_test(self.x, self.frame_length, self.hop_length, self.axis), frame_for_api_test(self.x, self.frame_length, self.hop_length, self.axis),
paddle.tensor.signal.frame( paddle.signal.frame(
paddle.to_tensor(self.x), paddle.to_tensor(self.x),
self.frame_length, self.frame_length,
self.hop_length, self.hop_length,
...@@ -678,7 +678,7 @@ class TestFrameStatic(unittest.TestCase): ...@@ -678,7 +678,7 @@ class TestFrameStatic(unittest.TestCase):
mp, sp = paddle.static.Program(), paddle.static.Program() mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp): with paddle.static.program_guard(mp, sp):
input = paddle.static.data('input', self.x.shape, dtype=self.x.dtype) input = paddle.static.data('input', self.x.shape, dtype=self.x.dtype)
output = paddle.tensor.signal.frame( output = paddle.signal.frame(
input, input,
self.frame_length, self.frame_length,
self.hop_length, self.hop_length,
...@@ -708,7 +708,7 @@ class TestFrameStatic(unittest.TestCase): ...@@ -708,7 +708,7 @@ class TestFrameStatic(unittest.TestCase):
class TestFrameException(unittest.TestCase): class TestFrameException(unittest.TestCase):
def test_frame(self): def test_frame(self):
with self.assertRaises(self.expect_exception): with self.assertRaises(self.expect_exception):
paddle.tensor.signal.frame( paddle.signal.frame(
paddle.to_tensor(self.x), paddle.to_tensor(self.x),
self.frame_length, self.frame_length,
self.hop_length, self.hop_length,
...@@ -731,7 +731,7 @@ class TestOverlapAdd(unittest.TestCase): ...@@ -731,7 +731,7 @@ class TestOverlapAdd(unittest.TestCase):
self.assertTrue( self.assertTrue(
np.allclose( np.allclose(
overlap_add_for_api_test(self.x, self.hop_length, self.axis), overlap_add_for_api_test(self.x, self.hop_length, self.axis),
paddle.tensor.signal.overlap_add( paddle.signal.overlap_add(
paddle.to_tensor(self.x), paddle.to_tensor(self.x),
self.hop_length, self.hop_length,
self.axis), self.axis),
...@@ -756,7 +756,7 @@ class TestOverlapAddStatic(unittest.TestCase): ...@@ -756,7 +756,7 @@ class TestOverlapAddStatic(unittest.TestCase):
mp, sp = paddle.static.Program(), paddle.static.Program() mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp): with paddle.static.program_guard(mp, sp):
input = paddle.static.data('input', self.x.shape, dtype=self.x.dtype) input = paddle.static.data('input', self.x.shape, dtype=self.x.dtype)
output = paddle.tensor.signal.overlap_add( output = paddle.signal.overlap_add(
input, input,
self.hop_length, self.hop_length,
self.axis), self.axis),
...@@ -783,7 +783,7 @@ class TestOverlapAddStatic(unittest.TestCase): ...@@ -783,7 +783,7 @@ class TestOverlapAddStatic(unittest.TestCase):
class TestOverlapAddException(unittest.TestCase): class TestOverlapAddException(unittest.TestCase):
def test_overlap_add(self): def test_overlap_add(self):
with self.assertRaises(self.expect_exception): with self.assertRaises(self.expect_exception):
paddle.tensor.signal.overlap_add( paddle.signal.overlap_add(
paddle.to_tensor(self.x), paddle.to_tensor(self.x),
self.hop_length, self.hop_length,
self.axis) self.axis)
...@@ -848,7 +848,7 @@ class TestStft(unittest.TestCase): ...@@ -848,7 +848,7 @@ class TestStft(unittest.TestCase):
self.assertTrue( self.assertTrue(
np.allclose( np.allclose(
stft(self.x, self.n_fft, self.hop_length, self.win_length, win_l, self.center, self.pad_mode), stft(self.x, self.n_fft, self.hop_length, self.win_length, win_l, self.center, self.pad_mode),
paddle.tensor.signal.stft( paddle.signal.stft(
paddle.to_tensor(self.x), paddle.to_tensor(self.x),
self.n_fft, self.n_fft,
self.hop_length, self.hop_length,
...@@ -891,7 +891,7 @@ class TestStftException(unittest.TestCase): ...@@ -891,7 +891,7 @@ class TestStftException(unittest.TestCase):
win_p = paddle.to_tensor(self.window) win_p = paddle.to_tensor(self.window)
with self.assertRaises(self.expect_exception): with self.assertRaises(self.expect_exception):
paddle.tensor.signal.stft( paddle.signal.stft(
paddle.to_tensor(self.x), paddle.to_tensor(self.x),
self.n_fft, self.n_fft,
self.hop_length, self.hop_length,
...@@ -934,7 +934,7 @@ class TestIstft(unittest.TestCase): ...@@ -934,7 +934,7 @@ class TestIstft(unittest.TestCase):
self.assertTrue( self.assertTrue(
np.allclose( np.allclose(
istft(self.x, self.hop_length, self.win_length, win_l, self.center, self.length), istft(self.x, self.hop_length, self.win_length, win_l, self.center, self.length),
paddle.tensor.signal.istft( paddle.signal.istft(
paddle.to_tensor(self.x), paddle.to_tensor(self.x),
self.n_fft, self.n_fft,
self.hop_length, self.hop_length,
...@@ -986,7 +986,7 @@ class TestIstftException(unittest.TestCase): ...@@ -986,7 +986,7 @@ class TestIstftException(unittest.TestCase):
win_p = paddle.to_tensor(self.window) win_p = paddle.to_tensor(self.window)
with self.assertRaises(self.expect_exception): with self.assertRaises(self.expect_exception):
paddle.tensor.signal.istft( paddle.signal.istft(
paddle.to_tensor(self.x), paddle.to_tensor(self.x),
self.n_fft, self.n_fft,
self.hop_length, self.hop_length,
......
...@@ -16,16 +16,14 @@ from typing import Optional ...@@ -16,16 +16,14 @@ from typing import Optional
import paddle import paddle
from .attribute import is_complex, is_floating_point from .tensor.attribute import is_complex, is_floating_point
from .fft import fft_r2c, fft_c2r, fft_c2c from .fft import fft_r2c, fft_c2r, fft_c2c
from ..fluid.data_feeder import check_variable_and_dtype from .fluid.data_feeder import check_variable_and_dtype
from ..fluid.framework import in_dygraph_mode from .fluid.framework import in_dygraph_mode
from ..fluid.layer_helper import LayerHelper from .fluid.layer_helper import LayerHelper
from .. import _C_ops from . import _C_ops
__all__ = [ __all__ = [
'frame',
'overlap_add',
'stft', 'stft',
'istft', 'istft',
] ]
...@@ -56,7 +54,7 @@ def frame(x, frame_length, hop_length, axis=-1, name=None): ...@@ -56,7 +54,7 @@ def frame(x, frame_length, hop_length, axis=-1, name=None):
.. code-block:: python .. code-block:: python
import paddle import paddle
from paddle.tensor.signal import frame from paddle.signal import frame
# 1D # 1D
x = paddle.arange(8) x = paddle.arange(8)
...@@ -177,7 +175,7 @@ def overlap_add(x, hop_length, axis=-1, name=None): ...@@ -177,7 +175,7 @@ def overlap_add(x, hop_length, axis=-1, name=None):
.. code-block:: python .. code-block:: python
import paddle import paddle
from paddle.tensor.signal import overlap_add from paddle.signal import overlap_add
# 2D # 2D
x0 = paddle.arange(16).reshape([8, 2]) x0 = paddle.arange(16).reshape([8, 2])
...@@ -291,11 +289,11 @@ def stft(x, ...@@ -291,11 +289,11 @@ def stft(x,
real-valued input and `onesided` is `True`) or `[..., n_fft, num_frames]`( real-valued input and `onesided` is `True`) or `[..., n_fft, num_frames]`(
`onesided` is `False`) `onesided` is `False`)
Exampels: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
from paddle.tensor.signal import stft from paddle.signal import stft
# real-valued input # real-valued input
x = paddle.randn([8, 48000], dtype=paddle.float64) x = paddle.randn([8, 48000], dtype=paddle.float64)
...@@ -415,7 +413,7 @@ def istft(x, ...@@ -415,7 +413,7 @@ def istft(x,
- :math:`N`: Value of `n_fft`. - :math:`N`: Value of `n_fft`.
- :math:`H`: Value of `hop_length`. - :math:`H`: Value of `hop_length`.
Result of `istft` expected to be the inverse of `paddle.tensor.signal.stft`, but it is Result of `istft` expected to be the inverse of `paddle.signal.stft`, but it is
not guaranteed to reconstruct a exactly realizible time-domain signal from a STFT not guaranteed to reconstruct a exactly realizible time-domain signal from a STFT
complex tensor which has been modified (via masking or otherwise). Therefore, `istft` complex tensor which has been modified (via masking or otherwise). Therefore, `istft`
gives the [Griffin-Lim optimal estimate](https://ieeexplore.ieee.org/document/1164317) gives the [Griffin-Lim optimal estimate](https://ieeexplore.ieee.org/document/1164317)
...@@ -454,12 +452,12 @@ def istft(x, ...@@ -454,12 +452,12 @@ def istft(x,
A tensor of least squares estimation of the reconstructed signal(s) with shape A tensor of least squares estimation of the reconstructed signal(s) with shape
`[..., seq_length]` `[..., seq_length]`
Exampels: Examples:
.. code-block:: python .. code-block:: python
import numpy as np import numpy as np
import paddle import paddle
from paddle.tensor.signal import stft, istft from paddle.signal import stft, istft
paddle.seed(0) paddle.seed(0)
......
...@@ -222,8 +222,6 @@ from .array import array_write # noqa: F401 ...@@ -222,8 +222,6 @@ from .array import array_write # noqa: F401
from .array import create_array # noqa: F401 from .array import create_array # noqa: F401
from .einsum import einsum # noqa: F401 from .einsum import einsum # noqa: F401
from . import fft
from . import signal
#this list used in math_op_patch.py for _binary_creator_ #this list used in math_op_patch.py for _binary_creator_
tensor_method_func = [ #noqa tensor_method_func = [ #noqa
......
...@@ -696,9 +696,18 @@ def roll(x, shifts, axis=None, name=None): ...@@ -696,9 +696,18 @@ def roll(x, shifts, axis=None, name=None):
helper = LayerHelper("roll", **locals()) helper = LayerHelper("roll", **locals())
check_type(axis, 'axis', (list, tuple), 'roll') check_type(axis, 'axis', (list, tuple), 'roll')
check_type(shifts, 'shifts', (list, tuple), 'roll')
out = helper.create_variable_for_type_inference(x.dtype) out = helper.create_variable_for_type_inference(x.dtype)
if isinstance(shifts, Variable):
helper.append_op(
type='roll',
inputs={'X': x,
"ShiftsTensor": shifts},
outputs={'Out': out},
attrs={'axis': axis})
else:
check_type(shifts, 'shifts', (list, tuple), 'roll')
helper.append_op( helper.append_op(
type='roll', type='roll',
inputs={'X': x}, inputs={'X': x},
......
Usage: #### **Introduction for crawling new error message:**
Please run:
```
bash start.sh
```
If you want to update all external error message, you need to run command `bash start.sh` in current directory,
and upload the generated file `externalErrorMsg.tar.gz` to https://paddlepaddledeps.bj.bcebos.com/externalErrorMsg.tar.gz 1. add new spider code in spider.py for crawling error message from website.
2. Run `bash start.sh` in the current directory to generate a new `externalErrorMsg_${date}.tar.gz` file, for example `externalErrorMsg_20210928.tar.gz`.
3. Upload the above tar file to the **paddlepaddledeps** BOS bucket at https://paddlepaddledeps.bj.bcebos.com, and copy its download link `${download_url}`. ***Be careful not to delete the original tar file.***
4. Compute the md5 value `${md5}` of the above tar file, and modify the cmake/third_party.cmake file as follows:
```
set(URL "${download_url}" CACHE STRING "" FORCE)
file_download_and_uncompress(${URL} "externalError" MD5 ${md5})
```
for example:
```
set(URL "https://paddlepaddledeps.bj.bcebos.com/externalErrorMsg_20210928.tar.gz" CACHE STRING "" FORCE)
file_download_and_uncompress(${URL} "externalError" MD5 a712a49384e77ca216ad866712f7cafa)
```
5. commit your changes, and create pull request.
...@@ -17,8 +17,10 @@ import re ...@@ -17,8 +17,10 @@ import re
import urllib.request import urllib.request
import json import json
import collections import collections
import sys, getopt import sys
import getopt
import external_error_pb2 import external_error_pb2
from html.parser import HTMLParser
def parsing(externalErrorDesc): def parsing(externalErrorDesc):
...@@ -335,6 +337,31 @@ def parsing(externalErrorDesc): ...@@ -335,6 +337,31 @@ def parsing(externalErrorDesc):
_Messages.message = "'%s'. %s" % (error[0], m_message) _Messages.message = "'%s'. %s" % (error[0], m_message)
print("End crawling errorMessage for nvidia NCCL API!\n") print("End crawling errorMessage for nvidia NCCL API!\n")
#*************************************************************************************************#
#*********************************** CUFFT Error Message **************************************#
print("start crawling errorMessage for nvidia CUFFT API--->")
url = 'https://docs.nvidia.com/cuda/cufft/index.html#cufftresult'
allMessageDesc = externalErrorDesc.errors.add()
allMessageDesc.type = external_error_pb2.CUFFT
html = urllib.request.urlopen(url).read().decode('utf-8')
class CUFFTHTMLParser(HTMLParser):
    '''Collect cufftResult_t enum entries from the cuFFT HTML page.

    For the text node that contains ``typedef enum cufftResult_t``, every
    line between the first and the last one is expected to look like
    ``STATUS = code, // description``. Each such line is appended to
    ``allMessageDesc.messages`` with ``code`` set to the integer value and
    ``message`` set to ``"'STATUS'. description"``.
    '''

    def handle_data(self, data):
        if 'typedef enum cufftResult_t' in data:
            # Skip the first and last lines of the block (presumably the
            # enum's opening and closing lines -- confirm against the page).
            for line in data.strip().splitlines()[1:-1]:
                # maxsplit=2 yields exactly three fields, so any further
                # '=' or '//' inside the description stays in the
                # description instead of breaking the unpacking.
                status, code, desc = re.split(
                    '=|//', line.strip(), maxsplit=2)
                _Messages = allMessageDesc.messages.add()
                _Messages.code = int(code.strip(' ,'))
                _Messages.message = "'%s'. %s" % (status.strip(),
                                                  desc.strip())
CUFFTHTMLParser().feed(html)
def main(argv): def main(argv):
try: try:
......
...@@ -32,4 +32,4 @@ fi ...@@ -32,4 +32,4 @@ fi
protobuf/bin/protoc -I../../paddle/fluid/platform/ --python_out . ../../paddle/fluid/platform/external_error.proto protobuf/bin/protoc -I../../paddle/fluid/platform/ --python_out . ../../paddle/fluid/platform/external_error.proto
python3.7 spider.py python3.7 spider.py
tar czvf externalErrorMsg.tar.gz externalErrorMsg.pb tar czvf externalErrorMsg_$(date +'%Y%m%d').tar.gz externalErrorMsg.pb
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册