未验证 提交 f45e6cf6 编写于 作者: F Feiyu Chan 提交者: GitHub

dynamic load mkl as a fft backend when it is avaialble and requested (#36414)

上级 b3f02c57
......@@ -102,10 +102,21 @@ else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()
if (WITH_GPU AND (NOT WITH_ROCM))
if (MKL_FOUND AND WITH_ONEMKL)
op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda dynload_mklrt ${OP_HEADER_DEPS})
target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE})
else()
op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda ${OP_HEADER_DEPS})
endif()
else()
if (MKL_FOUND AND WITH_ONEMKL)
op_library(spectral_op SRCS spectral_op.cc DEPS dynload_mklrt ${OP_HEADER_DEPS})
target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE})
else()
op_library(spectral_op SRCS spectral_op.cc DEPS ${OP_HEADER_DEPS})
endif()
endif()
op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute)
......
......@@ -27,7 +27,7 @@
#include "paddle/fluid/platform/complex.h"
#if defined(PADDLE_WITH_ONEMKL)
#include <mkl_dfti.h>
#include "paddle/fluid/platform/dynload/mklrt.h"
#elif defined(PADDLE_WITH_POCKETFFT)
#include "extern_pocketfft/pocketfft_hdronly.h"
#endif
......@@ -357,46 +357,45 @@ FFTNormMode get_norm_from_string(const std::string& norm, bool forward) {
// FFT Functors
#if defined(PADDLE_WITH_ONEMKL)
#define MKL_DFTI_CHECK(expr) \
do { \
MKL_LONG status = (expr); \
if (!platform::dynload::DftiErrorClass(status, DFTI_NO_ERROR)) \
PADDLE_THROW(platform::errors::External( \
platform::dynload::DftiErrorMessage(status))); \
} while (0);
namespace {
static inline void MKL_DFTI_CHECK(MKL_INT status) {
if (status && !DftiErrorClass(status, DFTI_NO_ERROR)) {
PADDLE_THROW(platform::errors::External(DftiErrorMessage(status)));
}
}
struct DftiDescriptorDeleter {
void operator()(DFTI_DESCRIPTOR_HANDLE handle) {
if (handle != nullptr) {
MKL_DFTI_CHECK(DftiFreeDescriptor(&handle));
MKL_DFTI_CHECK(platform::dynload::DftiFreeDescriptor(&handle));
}
}
};
// A RAII wrapper for MKL_DESCRIPTOR*
class DftiDescriptor {
public:
void init(DFTI_CONFIG_VALUE precision, DFTI_CONFIG_VALUE signal_type,
MKL_LONG signal_ndim, MKL_LONG* sizes) {
if (desc_ != nullptr) {
PADDLE_THROW(platform::errors::AlreadyExists(
"DFT DESCRIPTOR can only be initialized once."));
}
PADDLE_ENFORCE_EQ(desc_.get(), nullptr,
platform::errors::AlreadyExists(
"DftiDescriptor has already been initialized."));
DFTI_DESCRIPTOR* raw_desc;
if (signal_ndim == 1) {
MKL_DFTI_CHECK(
DftiCreateDescriptor(&raw_desc, precision, signal_type, 1, sizes[0]));
} else {
MKL_DFTI_CHECK(DftiCreateDescriptor(&raw_desc, precision, signal_type,
signal_ndim, sizes));
}
MKL_DFTI_CHECK(platform::dynload::DftiCreateDescriptorX(
&raw_desc, precision, signal_type, signal_ndim, sizes));
desc_.reset(raw_desc);
}
DFTI_DESCRIPTOR* get() const {
if (desc_ == nullptr) {
PADDLE_THROW(platform::errors::PreconditionNotMet(
DFTI_DESCRIPTOR* raw_desc = desc_.get();
PADDLE_ENFORCE_NOT_NULL(raw_desc,
platform::errors::PreconditionNotMet(
"DFTI DESCRIPTOR has not been initialized."));
}
return desc_.get();
return raw_desc;
}
private:
......@@ -421,7 +420,9 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
return DFTI_DOUBLE;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Input data type should be FP32, FP64, COMPLEX64 or COMPLEX128."));
"Invalid input datatype (%s), input data type should be FP32, "
"FP64, COMPLEX64 or COMPLEX128.",
framework::DataTypeToString(in_dtype)));
}
}();
......@@ -430,35 +431,27 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
const DFTI_CONFIG_VALUE domain =
(fft_type == FFTTransformType::C2C) ? DFTI_COMPLEX : DFTI_REAL;
// const bool complex_input = framework::IsComplexType(in_dtype);
// const bool complex_output = framework::IsComplexType(out_dtype);
// const DFTI_CONFIG_VALUE domain = [&] {
// if (forward) {
// return complex_input ? DFTI_COMPLEX : DFTI_REAL;
// } else {
// return complex_output ? DFTI_COMPLEX : DFTI_REAL;
// }
// }();
DftiDescriptor descriptor;
std::vector<MKL_LONG> fft_sizes(signal_sizes.cbegin(), signal_sizes.cend());
const MKL_LONG signal_ndim = fft_sizes.size() - 1;
descriptor.init(precision, domain, signal_ndim, fft_sizes.data() + 1);
// placement inplace or not inplace
MKL_DFTI_CHECK(
DftiSetValue(descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE));
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(
descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE));
// number of transformations
const MKL_LONG batch_size = fft_sizes[0];
MKL_DFTI_CHECK(
DftiSetValue(descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size));
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(
descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size));
// input & output distance
const MKL_LONG idist = in_strides[0];
const MKL_LONG odist = out_strides[0];
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist));
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_DISTANCE, odist));
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(),
DFTI_INPUT_DISTANCE, idist));
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(),
DFTI_OUTPUT_DISTANCE, odist));
// input & output stride
std::vector<MKL_LONG> mkl_in_stride(1 + signal_ndim, 0);
......@@ -467,15 +460,15 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
mkl_in_stride[i] = in_strides[i];
mkl_out_stride[i] = out_strides[i];
}
MKL_DFTI_CHECK(
DftiSetValue(descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride.data()));
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_STRIDES,
mkl_out_stride.data()));
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(
descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride.data()));
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(
descriptor.get(), DFTI_OUTPUT_STRIDES, mkl_out_stride.data()));
// conjugate even storage
if (!(fft_type == FFTTransformType::C2C)) {
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE,
DFTI_COMPLEX_COMPLEX));
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(
descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX));
}
MKL_LONG signal_numel =
......@@ -496,11 +489,12 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
return DFTI_BACKWARD_SCALE;
}
}();
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), scale_direction, scale));
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(),
scale_direction, scale));
}
// commit the descriptor
MKL_DFTI_CHECK(DftiCommitDescriptor(descriptor.get()));
MKL_DFTI_CHECK(platform::dynload::DftiCommitDescriptor(descriptor.get()));
return descriptor;
}
......@@ -592,14 +586,15 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
collapsed_input.numel(),
collapsed_input_conj.data<Ti>());
for_range(functor);
MKL_DFTI_CHECK(DftiComputeBackward(desc.get(),
collapsed_input_conj.data<void>(),
MKL_DFTI_CHECK(platform::dynload::DftiComputeBackward(
desc.get(), collapsed_input_conj.data<void>(),
collapsed_output.data<void>()));
} else if (fft_type == FFTTransformType::R2C && !forward) {
framework::Tensor collapsed_output_conj(collapsed_output.type());
collapsed_output_conj.mutable_data<To>(collapsed_output.dims(),
ctx.GetPlace());
MKL_DFTI_CHECK(DftiComputeForward(desc.get(), collapsed_input.data<void>(),
MKL_DFTI_CHECK(platform::dynload::DftiComputeForward(
desc.get(), collapsed_input.data<void>(),
collapsed_output_conj.data<void>()));
// conjugate the output
platform::ForRange<DeviceContext> for_range(ctx, collapsed_output.numel());
......@@ -609,12 +604,12 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
for_range(functor);
} else {
if (forward) {
MKL_DFTI_CHECK(DftiComputeForward(desc.get(),
collapsed_input.data<void>(),
MKL_DFTI_CHECK(platform::dynload::DftiComputeForward(
desc.get(), collapsed_input.data<void>(),
collapsed_output.data<void>()));
} else {
MKL_DFTI_CHECK(DftiComputeBackward(desc.get(),
collapsed_input.data<void>(),
MKL_DFTI_CHECK(platform::dynload::DftiComputeBackward(
desc.get(), collapsed_input.data<void>(),
collapsed_output.data<void>()));
}
}
......
......@@ -49,3 +49,9 @@ endif()
cc_library(dynload_lapack SRCS lapack.cc DEPS dynamic_loader)
add_dependencies(dynload_lapack extern_lapack)
# TODO(TJ): add iomp, mkldnn?
if (MKL_FOUND AND WITH_ONEMKL)
message("ONEMKL INCLUDE directory is ${MKL_INCLUDE}")
cc_library(dynload_mklrt SRCS mklrt.cc DEPS dynamic_loader)
target_include_directories(dynload_mklrt PRIVATE ${MKL_INCLUDE})
endif()
......@@ -53,6 +53,12 @@ DEFINE_string(mklml_dir, "", "Specify path for loading libmklml_intel.so.");
DEFINE_string(lapack_dir, "", "Specify path for loading liblapack.so.");
DEFINE_string(mkl_dir, "",
"Specify path for loading libmkl_rt.so. "
"For insrance, /opt/intel/oneapi/mkl/latest/lib/intel64/."
"If default, "
"dlopen will search mkl from LD_LIBRARY_PATH");
DEFINE_string(op_dir, "", "Specify path for loading user-defined op library.");
#ifdef PADDLE_WITH_HIP
......@@ -518,6 +524,16 @@ void* GetCUFFTDsoHandle() {
#endif
}
void* GetMKLRTDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(FLAGS_mkl_dir, "libmkl_rt.dylib");
#elif defined(_WIN32)
return GetDsoHandleFromSearchPath(FLAGS_mkl_dir, "mkl_rt.dll");
#else
return GetDsoHandleFromSearchPath(FLAGS_mkl_dir, "libmkl_rt.so");
#endif
}
} // namespace dynload
} // namespace platform
} // namespace paddle
......@@ -43,6 +43,7 @@ void* GetLAPACKDsoHandle();
void* GetOpDsoHandle(const std::string& dso_name);
void* GetNvtxDsoHandle();
void* GetCUFFTDsoHandle();
void* GetMKLRTDsoHandle();
void SetPaddleLibPath(const std::string&);
} // namespace dynload
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/dynload/mklrt.h"
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag mklrt_dso_flag;
void* mklrt_dso_handle = nullptr;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
MKLDFTI_ROUTINE_EACH(DEFINE_WRAP);
DFTI_EXTERN MKL_LONG DftiCreateDescriptorX(DFTI_DESCRIPTOR_HANDLE* desc,
enum DFTI_CONFIG_VALUE prec,
enum DFTI_CONFIG_VALUE domain,
MKL_LONG dim, MKL_LONG* sizes) {
if (prec == DFTI_SINGLE) {
if (dim == 1) {
return DftiCreateDescriptor_s_1d(desc, domain, sizes[0]);
} else {
return DftiCreateDescriptor_s_md(desc, domain, dim, sizes);
}
} else if (prec == DFTI_DOUBLE) {
if (dim == 1) {
return DftiCreateDescriptor_d_1d(desc, domain, sizes[0]);
} else {
return DftiCreateDescriptor_d_md(desc, domain, dim, sizes);
}
} else {
return DftiCreateDescriptor(desc, prec, domain, dim, sizes);
}
}
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <mkl_dfti.h>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
namespace paddle {
namespace platform {
namespace dynload {
extern std::once_flag mklrt_dso_flag;
extern void* mklrt_dso_handle;
/**
* The following macro definition can generate structs
* (for each function) to dynamic load mkldfti routine
* via operator overloading.
*/
#define DYNAMIC_LOAD_MKLRT_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using mklrtFunc = decltype(&::__name); \
std::call_once(mklrt_dso_flag, []() { \
mklrt_dso_handle = paddle::platform::dynload::GetMKLRTDsoHandle(); \
}); \
static void* p_##__name = dlsym(mklrt_dso_handle, #__name); \
return reinterpret_cast<mklrtFunc>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
// mkl_dfti.h has a macro that shadows the function with the same name
// un-defeine this macro so as to export that function
#undef DftiCreateDescriptor
#define MKLDFTI_ROUTINE_EACH(__macro) \
__macro(DftiCreateDescriptor); \
__macro(DftiCreateDescriptor_s_1d); \
__macro(DftiCreateDescriptor_d_1d); \
__macro(DftiCreateDescriptor_s_md); \
__macro(DftiCreateDescriptor_d_md); \
__macro(DftiSetValue); \
__macro(DftiGetValue); \
__macro(DftiCommitDescriptor); \
__macro(DftiComputeForward); \
__macro(DftiComputeBackward); \
__macro(DftiFreeDescriptor); \
__macro(DftiErrorClass); \
__macro(DftiErrorMessage);
MKLDFTI_ROUTINE_EACH(DYNAMIC_LOAD_MKLRT_WRAP)
#undef DYNAMIC_LOAD_MKLRT_WRAP
// define another function to avoid naming conflict
DFTI_EXTERN MKL_LONG DftiCreateDescriptorX(DFTI_DESCRIPTOR_HANDLE* desc,
enum DFTI_CONFIG_VALUE prec,
enum DFTI_CONFIG_VALUE domain,
MKL_LONG dim, MKL_LONG* sizes);
} // namespace dynload
} // namespace platform
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册