From f45e6cf6f476b25b52c194120401b920e8675785 Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Fri, 15 Oct 2021 12:46:24 +0800 Subject: [PATCH] dynamic load mkl as a fft backend when it is available and requested (#36414) --- paddle/fluid/operators/CMakeLists.txt | 15 ++- paddle/fluid/operators/spectral_op.cc | 113 +++++++++--------- paddle/fluid/platform/dynload/CMakeLists.txt | 6 + .../fluid/platform/dynload/dynamic_loader.cc | 16 +++ .../fluid/platform/dynload/dynamic_loader.h | 1 + paddle/fluid/platform/dynload/mklrt.cc | 51 ++++++++ paddle/fluid/platform/dynload/mklrt.h | 80 +++++++++++++ 7 files changed, 221 insertions(+), 61 deletions(-) create mode 100644 paddle/fluid/platform/dynload/mklrt.cc create mode 100644 paddle/fluid/platform/dynload/mklrt.h diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index b910b4ec73..bb31fcf854 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -102,10 +102,21 @@ else() op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) endif() + if (WITH_GPU AND (NOT WITH_ROCM)) - op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda ${OP_HEADER_DEPS}) + if (MKL_FOUND AND WITH_ONEMKL) + op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda dynload_mklrt ${OP_HEADER_DEPS}) + target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE}) + else() + op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda ${OP_HEADER_DEPS}) + endif() else() - op_library(spectral_op SRCS spectral_op.cc DEPS ${OP_HEADER_DEPS}) + if (MKL_FOUND AND WITH_ONEMKL) + op_library(spectral_op SRCS spectral_op.cc DEPS dynload_mklrt ${OP_HEADER_DEPS}) + target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE}) + else() + op_library(spectral_op SRCS spectral_op.cc DEPS ${OP_HEADER_DEPS}) + endif() endif() op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute) diff --git
a/paddle/fluid/operators/spectral_op.cc b/paddle/fluid/operators/spectral_op.cc index fb50702233..b5edc1dda5 100644 --- a/paddle/fluid/operators/spectral_op.cc +++ b/paddle/fluid/operators/spectral_op.cc @@ -27,7 +27,7 @@ #include "paddle/fluid/platform/complex.h" #if defined(PADDLE_WITH_ONEMKL) -#include +#include "paddle/fluid/platform/dynload/mklrt.h" #elif defined(PADDLE_WITH_POCKETFFT) #include "extern_pocketfft/pocketfft_hdronly.h" #endif @@ -357,46 +357,45 @@ FFTNormMode get_norm_from_string(const std::string& norm, bool forward) { // FFT Functors #if defined(PADDLE_WITH_ONEMKL) +#define MKL_DFTI_CHECK(expr) \ + do { \ + MKL_LONG status = (expr); \ + if (!platform::dynload::DftiErrorClass(status, DFTI_NO_ERROR)) \ + PADDLE_THROW(platform::errors::External( \ + platform::dynload::DftiErrorMessage(status))); \ + } while (0); + namespace { -static inline void MKL_DFTI_CHECK(MKL_INT status) { - if (status && !DftiErrorClass(status, DFTI_NO_ERROR)) { - PADDLE_THROW(platform::errors::External(DftiErrorMessage(status))); - } -} struct DftiDescriptorDeleter { void operator()(DFTI_DESCRIPTOR_HANDLE handle) { if (handle != nullptr) { - MKL_DFTI_CHECK(DftiFreeDescriptor(&handle)); + MKL_DFTI_CHECK(platform::dynload::DftiFreeDescriptor(&handle)); } } }; +// A RAII wrapper for MKL_DESCRIPTOR* class DftiDescriptor { public: void init(DFTI_CONFIG_VALUE precision, DFTI_CONFIG_VALUE signal_type, MKL_LONG signal_ndim, MKL_LONG* sizes) { - if (desc_ != nullptr) { - PADDLE_THROW(platform::errors::AlreadyExists( - "DFT DESCRIPTOR can only be initialized once.")); - } + PADDLE_ENFORCE_EQ(desc_.get(), nullptr, + platform::errors::AlreadyExists( + "DftiDescriptor has already been initialized.")); + DFTI_DESCRIPTOR* raw_desc; - if (signal_ndim == 1) { - MKL_DFTI_CHECK( - DftiCreateDescriptor(&raw_desc, precision, signal_type, 1, sizes[0])); - } else { - MKL_DFTI_CHECK(DftiCreateDescriptor(&raw_desc, precision, signal_type, - signal_ndim, sizes)); - } + 
MKL_DFTI_CHECK(platform::dynload::DftiCreateDescriptorX( + &raw_desc, precision, signal_type, signal_ndim, sizes)); desc_.reset(raw_desc); } DFTI_DESCRIPTOR* get() const { - if (desc_ == nullptr) { - PADDLE_THROW(platform::errors::PreconditionNotMet( - "DFTI DESCRIPTOR has not been initialized.")); - } - return desc_.get(); + DFTI_DESCRIPTOR* raw_desc = desc_.get(); + PADDLE_ENFORCE_NOT_NULL(raw_desc, + platform::errors::PreconditionNotMet( + "DFTI DESCRIPTOR has not been initialized.")); + return raw_desc; } private: @@ -421,7 +420,9 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, return DFTI_DOUBLE; default: PADDLE_THROW(platform::errors::InvalidArgument( - "Input data type should be FP32, FP64, COMPLEX64 or COMPLEX128.")); + "Invalid input datatype (%s), input data type should be FP32, " + "FP64, COMPLEX64 or COMPLEX128.", + framework::DataTypeToString(in_dtype))); } }(); @@ -430,35 +431,27 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, const DFTI_CONFIG_VALUE domain = (fft_type == FFTTransformType::C2C) ? DFTI_COMPLEX : DFTI_REAL; - // const bool complex_input = framework::IsComplexType(in_dtype); - // const bool complex_output = framework::IsComplexType(out_dtype); - // const DFTI_CONFIG_VALUE domain = [&] { - // if (forward) { - // return complex_input ? DFTI_COMPLEX : DFTI_REAL; - // } else { - // return complex_output ? 
DFTI_COMPLEX : DFTI_REAL; - // } - // }(); - DftiDescriptor descriptor; std::vector fft_sizes(signal_sizes.cbegin(), signal_sizes.cend()); const MKL_LONG signal_ndim = fft_sizes.size() - 1; descriptor.init(precision, domain, signal_ndim, fft_sizes.data() + 1); // placement inplace or not inplace - MKL_DFTI_CHECK( - DftiSetValue(descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE)); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue( + descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE)); // number of transformations const MKL_LONG batch_size = fft_sizes[0]; - MKL_DFTI_CHECK( - DftiSetValue(descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size)); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue( + descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size)); // input & output distance const MKL_LONG idist = in_strides[0]; const MKL_LONG odist = out_strides[0]; - MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist)); - MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_DISTANCE, odist)); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(), + DFTI_INPUT_DISTANCE, idist)); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(), + DFTI_OUTPUT_DISTANCE, odist)); // input & output stride std::vector mkl_in_stride(1 + signal_ndim, 0); @@ -467,15 +460,15 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, mkl_in_stride[i] = in_strides[i]; mkl_out_stride[i] = out_strides[i]; } - MKL_DFTI_CHECK( - DftiSetValue(descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride.data())); - MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_STRIDES, - mkl_out_stride.data())); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue( + descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride.data())); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue( + descriptor.get(), DFTI_OUTPUT_STRIDES, mkl_out_stride.data())); // conjugate even storage if (!(fft_type == FFTTransformType::C2C)) { - 
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, - DFTI_COMPLEX_COMPLEX)); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue( + descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX)); } MKL_LONG signal_numel = @@ -496,11 +489,12 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, return DFTI_BACKWARD_SCALE; } }(); - MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), scale_direction, scale)); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(), + scale_direction, scale)); } // commit the descriptor - MKL_DFTI_CHECK(DftiCommitDescriptor(descriptor.get())); + MKL_DFTI_CHECK(platform::dynload::DftiCommitDescriptor(descriptor.get())); return descriptor; } @@ -592,15 +586,16 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out, collapsed_input.numel(), collapsed_input_conj.data()); for_range(functor); - MKL_DFTI_CHECK(DftiComputeBackward(desc.get(), - collapsed_input_conj.data(), - collapsed_output.data())); + MKL_DFTI_CHECK(platform::dynload::DftiComputeBackward( + desc.get(), collapsed_input_conj.data(), + collapsed_output.data())); } else if (fft_type == FFTTransformType::R2C && !forward) { framework::Tensor collapsed_output_conj(collapsed_output.type()); collapsed_output_conj.mutable_data(collapsed_output.dims(), ctx.GetPlace()); - MKL_DFTI_CHECK(DftiComputeForward(desc.get(), collapsed_input.data(), - collapsed_output_conj.data())); + MKL_DFTI_CHECK(platform::dynload::DftiComputeForward( + desc.get(), collapsed_input.data(), + collapsed_output_conj.data())); // conjugate the output platform::ForRange for_range(ctx, collapsed_output.numel()); math::ConjFunctor functor(collapsed_output_conj.data(), @@ -609,13 +604,13 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out, for_range(functor); } else { if (forward) { - MKL_DFTI_CHECK(DftiComputeForward(desc.get(), - collapsed_input.data(), - collapsed_output.data())); + 
MKL_DFTI_CHECK(platform::dynload::DftiComputeForward( + desc.get(), collapsed_input.data(), + collapsed_output.data())); } else { - MKL_DFTI_CHECK(DftiComputeBackward(desc.get(), - collapsed_input.data(), - collapsed_output.data())); + MKL_DFTI_CHECK(platform::dynload::DftiComputeBackward( + desc.get(), collapsed_input.data(), + collapsed_output.data())); } } diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt index c0d4b349a9..8c64aad46c 100644 --- a/paddle/fluid/platform/dynload/CMakeLists.txt +++ b/paddle/fluid/platform/dynload/CMakeLists.txt @@ -49,3 +49,9 @@ endif() cc_library(dynload_lapack SRCS lapack.cc DEPS dynamic_loader) add_dependencies(dynload_lapack extern_lapack) # TODO(TJ): add iomp, mkldnn? + +if (MKL_FOUND AND WITH_ONEMKL) + message("ONEMKL INCLUDE directory is ${MKL_INCLUDE}") + cc_library(dynload_mklrt SRCS mklrt.cc DEPS dynamic_loader) + target_include_directories(dynload_mklrt PRIVATE ${MKL_INCLUDE}) +endif() diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc index a83f085f7d..0c5c47e38f 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.cc +++ b/paddle/fluid/platform/dynload/dynamic_loader.cc @@ -53,6 +53,12 @@ DEFINE_string(mklml_dir, "", "Specify path for loading libmklml_intel.so."); DEFINE_string(lapack_dir, "", "Specify path for loading liblapack.so."); +DEFINE_string(mkl_dir, "", + "Specify path for loading libmkl_rt.so. " + "For insrance, /opt/intel/oneapi/mkl/latest/lib/intel64/." 
+ "If default, " + "dlopen will search mkl from LD_LIBRARY_PATH"); + DEFINE_string(op_dir, "", "Specify path for loading user-defined op library."); #ifdef PADDLE_WITH_HIP @@ -518,6 +524,16 @@ void* GetCUFFTDsoHandle() { #endif } +void* GetMKLRTDsoHandle() { +#if defined(__APPLE__) || defined(__OSX__) + return GetDsoHandleFromSearchPath(FLAGS_mkl_dir, "libmkl_rt.dylib"); +#elif defined(_WIN32) + return GetDsoHandleFromSearchPath(FLAGS_mkl_dir, "mkl_rt.dll"); +#else + return GetDsoHandleFromSearchPath(FLAGS_mkl_dir, "libmkl_rt.so"); +#endif +} + } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/dynamic_loader.h b/paddle/fluid/platform/dynload/dynamic_loader.h index 82c36d9e22..6260efdf71 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.h +++ b/paddle/fluid/platform/dynload/dynamic_loader.h @@ -43,6 +43,7 @@ void* GetLAPACKDsoHandle(); void* GetOpDsoHandle(const std::string& dso_name); void* GetNvtxDsoHandle(); void* GetCUFFTDsoHandle(); +void* GetMKLRTDsoHandle(); void SetPaddleLibPath(const std::string&); } // namespace dynload diff --git a/paddle/fluid/platform/dynload/mklrt.cc b/paddle/fluid/platform/dynload/mklrt.cc new file mode 100644 index 0000000000..45fad15fb5 --- /dev/null +++ b/paddle/fluid/platform/dynload/mklrt.cc @@ -0,0 +1,51 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/platform/dynload/mklrt.h" + +namespace paddle { +namespace platform { +namespace dynload { + +std::once_flag mklrt_dso_flag; +void* mklrt_dso_handle = nullptr; + +#define DEFINE_WRAP(__name) DynLoad__##__name __name + +MKLDFTI_ROUTINE_EACH(DEFINE_WRAP); + +DFTI_EXTERN MKL_LONG DftiCreateDescriptorX(DFTI_DESCRIPTOR_HANDLE* desc, + enum DFTI_CONFIG_VALUE prec, + enum DFTI_CONFIG_VALUE domain, + MKL_LONG dim, MKL_LONG* sizes) { + if (prec == DFTI_SINGLE) { + if (dim == 1) { + return DftiCreateDescriptor_s_1d(desc, domain, sizes[0]); + } else { + return DftiCreateDescriptor_s_md(desc, domain, dim, sizes); + } + } else if (prec == DFTI_DOUBLE) { + if (dim == 1) { + return DftiCreateDescriptor_d_1d(desc, domain, sizes[0]); + } else { + return DftiCreateDescriptor_d_md(desc, domain, dim, sizes); + } + } else { + return DftiCreateDescriptor(desc, prec, domain, dim, sizes); + } +} + +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/mklrt.h b/paddle/fluid/platform/dynload/mklrt.h new file mode 100644 index 0000000000..423cd4d0a2 --- /dev/null +++ b/paddle/fluid/platform/dynload/mklrt.h @@ -0,0 +1,80 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include // NOLINT + +#include "paddle/fluid/platform/dynload/dynamic_loader.h" +#include "paddle/fluid/platform/port.h" + +namespace paddle { +namespace platform { +namespace dynload { + +extern std::once_flag mklrt_dso_flag; +extern void* mklrt_dso_handle; + +/** + * The following macro definition can generate structs + * (for each function) to dynamically load mkldfti routines + * via operator overloading. + */ +#define DYNAMIC_LOAD_MKLRT_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \ + using mklrtFunc = decltype(&::__name); \ + std::call_once(mklrt_dso_flag, []() { \ + mklrt_dso_handle = paddle::platform::dynload::GetMKLRTDsoHandle(); \ + }); \ + static void* p_##__name = dlsym(mklrt_dso_handle, #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + }; \ + extern DynLoad__##__name __name + +// mkl_dfti.h has a macro that shadows the function with the same name +// undefine this macro so as to export that function +#undef DftiCreateDescriptor + +#define MKLDFTI_ROUTINE_EACH(__macro) \ + __macro(DftiCreateDescriptor); \ + __macro(DftiCreateDescriptor_s_1d); \ + __macro(DftiCreateDescriptor_d_1d); \ + __macro(DftiCreateDescriptor_s_md); \ + __macro(DftiCreateDescriptor_d_md); \ + __macro(DftiSetValue); \ + __macro(DftiGetValue); \ + __macro(DftiCommitDescriptor); \ + __macro(DftiComputeForward); \ + __macro(DftiComputeBackward); \ + __macro(DftiFreeDescriptor); \ + __macro(DftiErrorClass); \ + __macro(DftiErrorMessage); + +MKLDFTI_ROUTINE_EACH(DYNAMIC_LOAD_MKLRT_WRAP) + +#undef DYNAMIC_LOAD_MKLRT_WRAP + +// define another function to avoid naming conflict +DFTI_EXTERN MKL_LONG DftiCreateDescriptorX(DFTI_DESCRIPTOR_HANDLE* desc, + enum DFTI_CONFIG_VALUE prec, + enum DFTI_CONFIG_VALUE domain, + MKL_LONG dim, MKL_LONG* sizes); + +} // namespace dynload +} // namespace platform +} // namespace paddle -- GitLab