未验证 提交 687902fc 编写于 作者: F Feiyu Chan 提交者: GitHub

[phi] update code for mkl based fft (#39889)

上级 584844ec
...@@ -25,9 +25,10 @@ ...@@ -25,9 +25,10 @@
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/transpose_op.h" #include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/complex.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#if defined(PADDLE_WITH_ONEMKL) #if defined(PADDLE_WITH_ONEMKL)
#include "paddle/fluid/platform/dynload/mklrt.h" #include "paddle/phi/backends/dynload/mklrt.h"
#elif defined(PADDLE_WITH_POCKETFFT) #elif defined(PADDLE_WITH_POCKETFFT)
#include "extern_pocketfft/pocketfft_hdronly.h" #include "extern_pocketfft/pocketfft_hdronly.h"
#endif #endif
...@@ -357,12 +358,12 @@ FFTNormMode get_norm_from_string(const std::string& norm, bool forward) { ...@@ -357,12 +358,12 @@ FFTNormMode get_norm_from_string(const std::string& norm, bool forward) {
// FFT Functors // FFT Functors
#if defined(PADDLE_WITH_ONEMKL) #if defined(PADDLE_WITH_ONEMKL)
#define MKL_DFTI_CHECK(expr) \ #define MKL_DFTI_CHECK(expr) \
do { \ do { \
MKL_LONG status = (expr); \ MKL_LONG status = (expr); \
if (!platform::dynload::DftiErrorClass(status, DFTI_NO_ERROR)) \ if (!phi::dynload::DftiErrorClass(status, DFTI_NO_ERROR)) \
PADDLE_THROW(platform::errors::External( \ PADDLE_THROW( \
platform::dynload::DftiErrorMessage(status))); \ platform::errors::External(phi::dynload::DftiErrorMessage(status))); \
} while (0); } while (0);
namespace { namespace {
...@@ -370,7 +371,7 @@ namespace { ...@@ -370,7 +371,7 @@ namespace {
struct DftiDescriptorDeleter { struct DftiDescriptorDeleter {
void operator()(DFTI_DESCRIPTOR_HANDLE handle) { void operator()(DFTI_DESCRIPTOR_HANDLE handle) {
if (handle != nullptr) { if (handle != nullptr) {
MKL_DFTI_CHECK(platform::dynload::DftiFreeDescriptor(&handle)); MKL_DFTI_CHECK(phi::dynload::DftiFreeDescriptor(&handle));
} }
} }
}; };
...@@ -385,7 +386,7 @@ class DftiDescriptor { ...@@ -385,7 +386,7 @@ class DftiDescriptor {
"DftiDescriptor has already been initialized.")); "DftiDescriptor has already been initialized."));
DFTI_DESCRIPTOR* raw_desc; DFTI_DESCRIPTOR* raw_desc;
MKL_DFTI_CHECK(platform::dynload::DftiCreateDescriptorX( MKL_DFTI_CHECK(phi::dynload::DftiCreateDescriptorX(
&raw_desc, precision, signal_type, signal_ndim, sizes)); &raw_desc, precision, signal_type, signal_ndim, sizes));
desc_.reset(raw_desc); desc_.reset(raw_desc);
} }
...@@ -437,21 +438,21 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, ...@@ -437,21 +438,21 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
descriptor.init(precision, domain, signal_ndim, fft_sizes.data() + 1); descriptor.init(precision, domain, signal_ndim, fft_sizes.data() + 1);
// placement inplace or not inplace // placement inplace or not inplace
MKL_DFTI_CHECK(platform::dynload::DftiSetValue( MKL_DFTI_CHECK(phi::dynload::DftiSetValue(descriptor.get(), DFTI_PLACEMENT,
descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE)); DFTI_NOT_INPLACE));
// number of transformations // number of transformations
const MKL_LONG batch_size = fft_sizes[0]; const MKL_LONG batch_size = fft_sizes[0];
MKL_DFTI_CHECK(platform::dynload::DftiSetValue( MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size)); descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size));
// input & output distance // input & output distance
const MKL_LONG idist = in_strides[0]; const MKL_LONG idist = in_strides[0];
const MKL_LONG odist = out_strides[0]; const MKL_LONG odist = out_strides[0];
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(), MKL_DFTI_CHECK(
DFTI_INPUT_DISTANCE, idist)); phi::dynload::DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist));
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(), MKL_DFTI_CHECK(phi::dynload::DftiSetValue(descriptor.get(),
DFTI_OUTPUT_DISTANCE, odist)); DFTI_OUTPUT_DISTANCE, odist));
// input & output stride // input & output stride
std::vector<MKL_LONG> mkl_in_stride(1 + signal_ndim, 0); std::vector<MKL_LONG> mkl_in_stride(1 + signal_ndim, 0);
...@@ -460,14 +461,14 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, ...@@ -460,14 +461,14 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
mkl_in_stride[i] = in_strides[i]; mkl_in_stride[i] = in_strides[i];
mkl_out_stride[i] = out_strides[i]; mkl_out_stride[i] = out_strides[i];
} }
MKL_DFTI_CHECK(platform::dynload::DftiSetValue( MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride.data())); descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride.data()));
MKL_DFTI_CHECK(platform::dynload::DftiSetValue( MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
descriptor.get(), DFTI_OUTPUT_STRIDES, mkl_out_stride.data())); descriptor.get(), DFTI_OUTPUT_STRIDES, mkl_out_stride.data()));
// conjugate even storage // conjugate even storage
if (!(fft_type == FFTTransformType::C2C)) { if (!(fft_type == FFTTransformType::C2C)) {
MKL_DFTI_CHECK(platform::dynload::DftiSetValue( MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX)); descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX));
} }
...@@ -489,12 +490,12 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, ...@@ -489,12 +490,12 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
return DFTI_BACKWARD_SCALE; return DFTI_BACKWARD_SCALE;
} }
}(); }();
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(), MKL_DFTI_CHECK(
scale_direction, scale)); phi::dynload::DftiSetValue(descriptor.get(), scale_direction, scale));
} }
// commit the descriptor // commit the descriptor
MKL_DFTI_CHECK(platform::dynload::DftiCommitDescriptor(descriptor.get())); MKL_DFTI_CHECK(phi::dynload::DftiCommitDescriptor(descriptor.get()));
return descriptor; return descriptor;
} }
...@@ -575,39 +576,39 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out, ...@@ -575,39 +576,39 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
framework::TransToProtoVarType(out->dtype()), input_stride, framework::TransToProtoVarType(out->dtype()), input_stride,
output_stride, signal_sizes, normalization, forward); output_stride, signal_sizes, normalization, forward);
const FFTTransformType fft_type = GetFFTTransformType(x->type(), out->type()); const FFTTransformType fft_type =
GetFFTTransformType(framework::TransToProtoVarType(x->dtype()),
framework::TransToProtoVarType(out->type()));
if (fft_type == FFTTransformType::C2R && forward) { if (fft_type == FFTTransformType::C2R && forward) {
framework::Tensor collapsed_input_conj( framework::Tensor collapsed_input_conj(collapsed_input.dtype());
framework::TransToProtoVarType(collapsed_input.dtype()));
collapsed_input_conj.mutable_data<Ti>(collapsed_input.dims(), collapsed_input_conj.mutable_data<Ti>(collapsed_input.dims(),
ctx.GetPlace()); ctx.GetPlace());
// conjugate the input // conjugate the input
platform::ForRange<DeviceContext> for_range(ctx, collapsed_input.numel()); platform::ForRange<DeviceContext> for_range(ctx, collapsed_input.numel());
math::ConjFunctor<Ti> functor(collapsed_input.data<Ti>(), phi::funcs::ConjFunctor<Ti> functor(collapsed_input.data<Ti>(),
collapsed_input.numel(), collapsed_input.numel(),
collapsed_input_conj.data<Ti>()); collapsed_input_conj.data<Ti>());
for_range(functor); for_range(functor);
MKL_DFTI_CHECK(platform::dynload::DftiComputeBackward( MKL_DFTI_CHECK(phi::dynload::DftiComputeBackward(
desc.get(), collapsed_input_conj.data(), collapsed_output.data())); desc.get(), collapsed_input_conj.data(), collapsed_output.data()));
} else if (fft_type == FFTTransformType::R2C && !forward) { } else if (fft_type == FFTTransformType::R2C && !forward) {
framework::Tensor collapsed_output_conj( framework::Tensor collapsed_output_conj(collapsed_output.dtype());
framework::TransToProtoVarType(collapsed_output.dtype()));
collapsed_output_conj.mutable_data<To>(collapsed_output.dims(), collapsed_output_conj.mutable_data<To>(collapsed_output.dims(),
ctx.GetPlace()); ctx.GetPlace());
MKL_DFTI_CHECK(platform::dynload::DftiComputeForward( MKL_DFTI_CHECK(phi::dynload::DftiComputeForward(
desc.get(), collapsed_input.data(), collapsed_output_conj.data())); desc.get(), collapsed_input.data(), collapsed_output_conj.data()));
// conjugate the output // conjugate the output
platform::ForRange<DeviceContext> for_range(ctx, collapsed_output.numel()); platform::ForRange<DeviceContext> for_range(ctx, collapsed_output.numel());
math::ConjFunctor<To> functor(collapsed_output_conj.data<To>(), phi::funcs::ConjFunctor<To> functor(collapsed_output_conj.data<To>(),
collapsed_output.numel(), collapsed_output.numel(),
collapsed_output.data<To>()); collapsed_output.data<To>());
for_range(functor); for_range(functor);
} else { } else {
if (forward) { if (forward) {
MKL_DFTI_CHECK(platform::dynload::DftiComputeForward( MKL_DFTI_CHECK(phi::dynload::DftiComputeForward(
desc.get(), collapsed_input.data(), collapsed_output.data())); desc.get(), collapsed_input.data(), collapsed_output.data()));
} else { } else {
MKL_DFTI_CHECK(platform::dynload::DftiComputeBackward( MKL_DFTI_CHECK(phi::dynload::DftiComputeBackward(
desc.get(), collapsed_input.data(), collapsed_output.data())); desc.get(), collapsed_input.data(), collapsed_output.data()));
} }
} }
......
...@@ -17,7 +17,8 @@ limitations under the License. */ ...@@ -17,7 +17,8 @@ limitations under the License. */
#include <mkl_dfti.h> #include <mkl_dfti.h>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/phi/backends/dynload/dynamic_loader.h"
#include "paddle/phi/backends/dynload/mklrt.h"
#include "paddle/phi/backends/dynload/port.h" #include "paddle/phi/backends/dynload/port.h"
namespace paddle { namespace paddle {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册