From 687902fcfd89c7648ab68d6e57041cfb7ae20fc8 Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Fri, 25 Feb 2022 22:45:32 +0800 Subject: [PATCH] [phi] update code for mkl based fft (#39889) --- paddle/fluid/operators/spectral_op.cc | 75 ++++++++++++++------------- paddle/fluid/platform/dynload/mklrt.h | 3 +- 2 files changed, 40 insertions(+), 38 deletions(-) diff --git a/paddle/fluid/operators/spectral_op.cc b/paddle/fluid/operators/spectral_op.cc index fe76448a185..db3dc214bfe 100644 --- a/paddle/fluid/operators/spectral_op.cc +++ b/paddle/fluid/operators/spectral_op.cc @@ -25,9 +25,10 @@ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/operators/transpose_op.h" #include "paddle/fluid/platform/complex.h" +#include "paddle/phi/kernels/funcs/complex_functors.h" #if defined(PADDLE_WITH_ONEMKL) -#include "paddle/fluid/platform/dynload/mklrt.h" +#include "paddle/phi/backends/dynload/mklrt.h" #elif defined(PADDLE_WITH_POCKETFFT) #include "extern_pocketfft/pocketfft_hdronly.h" #endif @@ -357,12 +358,12 @@ FFTNormMode get_norm_from_string(const std::string& norm, bool forward) { // FFT Functors #if defined(PADDLE_WITH_ONEMKL) -#define MKL_DFTI_CHECK(expr) \ - do { \ - MKL_LONG status = (expr); \ - if (!platform::dynload::DftiErrorClass(status, DFTI_NO_ERROR)) \ - PADDLE_THROW(platform::errors::External( \ - platform::dynload::DftiErrorMessage(status))); \ +#define MKL_DFTI_CHECK(expr) \ + do { \ + MKL_LONG status = (expr); \ + if (!phi::dynload::DftiErrorClass(status, DFTI_NO_ERROR)) \ + PADDLE_THROW( \ + platform::errors::External(phi::dynload::DftiErrorMessage(status))); \ } while (0); namespace { @@ -370,7 +371,7 @@ namespace { struct DftiDescriptorDeleter { void operator()(DFTI_DESCRIPTOR_HANDLE handle) { if (handle != nullptr) { - MKL_DFTI_CHECK(platform::dynload::DftiFreeDescriptor(&handle)); + MKL_DFTI_CHECK(phi::dynload::DftiFreeDescriptor(&handle)); } } }; @@ -385,7 +386,7 @@ class DftiDescriptor { "DftiDescriptor has already been initialized.")); DFTI_DESCRIPTOR* raw_desc; - MKL_DFTI_CHECK(platform::dynload::DftiCreateDescriptorX( + MKL_DFTI_CHECK(phi::dynload::DftiCreateDescriptorX( &raw_desc, precision, signal_type, signal_ndim, sizes)); desc_.reset(raw_desc); } @@ -437,21 +438,21 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, descriptor.init(precision, domain, signal_ndim, fft_sizes.data() + 1); // placement inplace or not inplace - MKL_DFTI_CHECK(platform::dynload::DftiSetValue( - descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE)); + MKL_DFTI_CHECK(phi::dynload::DftiSetValue(descriptor.get(), DFTI_PLACEMENT, + DFTI_NOT_INPLACE)); // number of transformations const MKL_LONG batch_size = fft_sizes[0]; - MKL_DFTI_CHECK(platform::dynload::DftiSetValue( + MKL_DFTI_CHECK(phi::dynload::DftiSetValue( descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size)); // input & output distance const MKL_LONG idist = in_strides[0]; const MKL_LONG odist = out_strides[0]; - MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(), - DFTI_INPUT_DISTANCE, idist)); - MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(), - DFTI_OUTPUT_DISTANCE, odist)); + MKL_DFTI_CHECK( + phi::dynload::DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist)); + MKL_DFTI_CHECK(phi::dynload::DftiSetValue(descriptor.get(), + DFTI_OUTPUT_DISTANCE, odist)); // input & output stride std::vector mkl_in_stride(1 + signal_ndim, 0); @@ -460,14 +461,14 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, mkl_in_stride[i] = in_strides[i]; mkl_out_stride[i] = out_strides[i]; } - MKL_DFTI_CHECK(platform::dynload::DftiSetValue( + MKL_DFTI_CHECK(phi::dynload::DftiSetValue( descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride.data())); - MKL_DFTI_CHECK(platform::dynload::DftiSetValue( + MKL_DFTI_CHECK(phi::dynload::DftiSetValue( descriptor.get(), DFTI_OUTPUT_STRIDES, mkl_out_stride.data())); // conjugate even storage if (!(fft_type == FFTTransformType::C2C)) { - MKL_DFTI_CHECK(platform::dynload::DftiSetValue( + MKL_DFTI_CHECK(phi::dynload::DftiSetValue( descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX)); } @@ -489,12 +490,12 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, return DFTI_BACKWARD_SCALE; } }(); - MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(), - scale_direction, scale)); + MKL_DFTI_CHECK( + phi::dynload::DftiSetValue(descriptor.get(), scale_direction, scale)); } // commit the descriptor - MKL_DFTI_CHECK(platform::dynload::DftiCommitDescriptor(descriptor.get())); + MKL_DFTI_CHECK(phi::dynload::DftiCommitDescriptor(descriptor.get())); return descriptor; } @@ -575,39 +576,39 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out, framework::TransToProtoVarType(out->dtype()), input_stride, output_stride, signal_sizes, normalization, forward); - const FFTTransformType fft_type = GetFFTTransformType(x->type(), out->type()); + const FFTTransformType fft_type = + GetFFTTransformType(framework::TransToProtoVarType(x->dtype()), + framework::TransToProtoVarType(out->type())); if (fft_type == FFTTransformType::C2R && forward) { - framework::Tensor collapsed_input_conj( - framework::TransToProtoVarType(collapsed_input.dtype())); + framework::Tensor collapsed_input_conj(collapsed_input.dtype()); collapsed_input_conj.mutable_data(collapsed_input.dims(), ctx.GetPlace()); // conjugate the input platform::ForRange for_range(ctx, collapsed_input.numel()); - math::ConjFunctor functor(collapsed_input.data(), - collapsed_input.numel(), - collapsed_input_conj.data()); + phi::funcs::ConjFunctor functor(collapsed_input.data(), + collapsed_input.numel(), + collapsed_input_conj.data()); for_range(functor); - MKL_DFTI_CHECK(platform::dynload::DftiComputeBackward( + MKL_DFTI_CHECK(phi::dynload::DftiComputeBackward( desc.get(), collapsed_input_conj.data(), collapsed_output.data())); } else if (fft_type == FFTTransformType::R2C && !forward) { - framework::Tensor collapsed_output_conj( - framework::TransToProtoVarType(collapsed_output.dtype())); + framework::Tensor collapsed_output_conj(collapsed_output.dtype()); collapsed_output_conj.mutable_data(collapsed_output.dims(), ctx.GetPlace()); - MKL_DFTI_CHECK(platform::dynload::DftiComputeForward( + MKL_DFTI_CHECK(phi::dynload::DftiComputeForward( desc.get(), collapsed_input.data(), collapsed_output_conj.data())); // conjugate the output platform::ForRange for_range(ctx, collapsed_output.numel()); - math::ConjFunctor functor(collapsed_output_conj.data(), - collapsed_output.numel(), - collapsed_output.data()); + phi::funcs::ConjFunctor functor(collapsed_output_conj.data(), + collapsed_output.numel(), + collapsed_output.data()); for_range(functor); } else { if (forward) { - MKL_DFTI_CHECK(platform::dynload::DftiComputeForward( + MKL_DFTI_CHECK(phi::dynload::DftiComputeForward( desc.get(), collapsed_input.data(), collapsed_output.data())); } else { - MKL_DFTI_CHECK(platform::dynload::DftiComputeBackward( + MKL_DFTI_CHECK(phi::dynload::DftiComputeBackward( desc.get(), collapsed_input.data(), collapsed_output.data())); } } diff --git a/paddle/fluid/platform/dynload/mklrt.h b/paddle/fluid/platform/dynload/mklrt.h index 3b7d23277e0..334b98a1c3d 100644 --- a/paddle/fluid/platform/dynload/mklrt.h +++ b/paddle/fluid/platform/dynload/mklrt.h @@ -17,7 +17,8 @@ limitations under the License. */ #include #include // NOLINT -#include "paddle/fluid/platform/dynload/dynamic_loader.h" +#include "paddle/phi/backends/dynload/dynamic_loader.h" +#include "paddle/phi/backends/dynload/mklrt.h" #include "paddle/phi/backends/dynload/port.h" namespace paddle { -- GitLab