Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
f45e6cf6
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f45e6cf6
编写于
10月 15, 2021
作者:
F
Feiyu Chan
提交者:
GitHub
10月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dynamic load mkl as a fft backend when it is avaialble and requested (#36414)
上级
b3f02c57
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
221 addition
and
61 deletion
+221
-61
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+13
-2
paddle/fluid/operators/spectral_op.cc
paddle/fluid/operators/spectral_op.cc
+54
-59
paddle/fluid/platform/dynload/CMakeLists.txt
paddle/fluid/platform/dynload/CMakeLists.txt
+6
-0
paddle/fluid/platform/dynload/dynamic_loader.cc
paddle/fluid/platform/dynload/dynamic_loader.cc
+16
-0
paddle/fluid/platform/dynload/dynamic_loader.h
paddle/fluid/platform/dynload/dynamic_loader.h
+1
-0
paddle/fluid/platform/dynload/mklrt.cc
paddle/fluid/platform/dynload/mklrt.cc
+51
-0
paddle/fluid/platform/dynload/mklrt.h
paddle/fluid/platform/dynload/mklrt.h
+80
-0
未找到文件。
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
f45e6cf6
...
@@ -102,10 +102,21 @@ else()
...
@@ -102,10 +102,21 @@ else()
op_library
(
warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale
)
op_library
(
warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale
)
endif
()
endif
()
if
(
WITH_GPU
AND
(
NOT WITH_ROCM
))
if
(
WITH_GPU
AND
(
NOT WITH_ROCM
))
if
(
MKL_FOUND AND WITH_ONEMKL
)
op_library
(
spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda dynload_mklrt
${
OP_HEADER_DEPS
}
)
target_include_directories
(
spectral_op PRIVATE
${
MKL_INCLUDE
}
)
else
()
op_library
(
spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda
${
OP_HEADER_DEPS
}
)
op_library
(
spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda
${
OP_HEADER_DEPS
}
)
endif
()
else
()
else
()
if
(
MKL_FOUND AND WITH_ONEMKL
)
op_library
(
spectral_op SRCS spectral_op.cc DEPS dynload_mklrt
${
OP_HEADER_DEPS
}
)
target_include_directories
(
spectral_op PRIVATE
${
MKL_INCLUDE
}
)
else
()
op_library
(
spectral_op SRCS spectral_op.cc DEPS
${
OP_HEADER_DEPS
}
)
op_library
(
spectral_op SRCS spectral_op.cc DEPS
${
OP_HEADER_DEPS
}
)
endif
()
endif
()
endif
()
op_library
(
lstm_op DEPS
${
OP_HEADER_DEPS
}
lstm_compute
)
op_library
(
lstm_op DEPS
${
OP_HEADER_DEPS
}
lstm_compute
)
...
...
paddle/fluid/operators/spectral_op.cc
浏览文件 @
f45e6cf6
...
@@ -27,7 +27,7 @@
...
@@ -27,7 +27,7 @@
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
#if defined(PADDLE_WITH_ONEMKL)
#if defined(PADDLE_WITH_ONEMKL)
#include
<mkl_dfti.h>
#include
"paddle/fluid/platform/dynload/mklrt.h"
#elif defined(PADDLE_WITH_POCKETFFT)
#elif defined(PADDLE_WITH_POCKETFFT)
#include "extern_pocketfft/pocketfft_hdronly.h"
#include "extern_pocketfft/pocketfft_hdronly.h"
#endif
#endif
...
@@ -357,46 +357,45 @@ FFTNormMode get_norm_from_string(const std::string& norm, bool forward) {
...
@@ -357,46 +357,45 @@ FFTNormMode get_norm_from_string(const std::string& norm, bool forward) {
// FFT Functors
// FFT Functors
#if defined(PADDLE_WITH_ONEMKL)
#if defined(PADDLE_WITH_ONEMKL)
#define MKL_DFTI_CHECK(expr) \
do { \
MKL_LONG status = (expr); \
if (!platform::dynload::DftiErrorClass(status, DFTI_NO_ERROR)) \
PADDLE_THROW(platform::errors::External( \
platform::dynload::DftiErrorMessage(status))); \
} while (0);
namespace
{
namespace
{
static
inline
void
MKL_DFTI_CHECK
(
MKL_INT
status
)
{
if
(
status
&&
!
DftiErrorClass
(
status
,
DFTI_NO_ERROR
))
{
PADDLE_THROW
(
platform
::
errors
::
External
(
DftiErrorMessage
(
status
)));
}
}
struct
DftiDescriptorDeleter
{
struct
DftiDescriptorDeleter
{
void
operator
()(
DFTI_DESCRIPTOR_HANDLE
handle
)
{
void
operator
()(
DFTI_DESCRIPTOR_HANDLE
handle
)
{
if
(
handle
!=
nullptr
)
{
if
(
handle
!=
nullptr
)
{
MKL_DFTI_CHECK
(
DftiFreeDescriptor
(
&
handle
));
MKL_DFTI_CHECK
(
platform
::
dynload
::
DftiFreeDescriptor
(
&
handle
));
}
}
}
}
};
};
// A RAII wrapper for MKL_DESCRIPTOR*
class
DftiDescriptor
{
class
DftiDescriptor
{
public:
public:
void
init
(
DFTI_CONFIG_VALUE
precision
,
DFTI_CONFIG_VALUE
signal_type
,
void
init
(
DFTI_CONFIG_VALUE
precision
,
DFTI_CONFIG_VALUE
signal_type
,
MKL_LONG
signal_ndim
,
MKL_LONG
*
sizes
)
{
MKL_LONG
signal_ndim
,
MKL_LONG
*
sizes
)
{
if
(
desc_
!=
nullptr
)
{
PADDLE_ENFORCE_EQ
(
desc_
.
get
(),
nullptr
,
PADDLE_THROW
(
platform
::
errors
::
AlreadyExists
(
platform
::
errors
::
AlreadyExists
(
"DFT DESCRIPTOR can only be initialized once
."
));
"DftiDescriptor has already been initialized
."
));
}
DFTI_DESCRIPTOR
*
raw_desc
;
DFTI_DESCRIPTOR
*
raw_desc
;
if
(
signal_ndim
==
1
)
{
MKL_DFTI_CHECK
(
platform
::
dynload
::
DftiCreateDescriptorX
(
MKL_DFTI_CHECK
(
&
raw_desc
,
precision
,
signal_type
,
signal_ndim
,
sizes
));
DftiCreateDescriptor
(
&
raw_desc
,
precision
,
signal_type
,
1
,
sizes
[
0
]));
}
else
{
MKL_DFTI_CHECK
(
DftiCreateDescriptor
(
&
raw_desc
,
precision
,
signal_type
,
signal_ndim
,
sizes
));
}
desc_
.
reset
(
raw_desc
);
desc_
.
reset
(
raw_desc
);
}
}
DFTI_DESCRIPTOR
*
get
()
const
{
DFTI_DESCRIPTOR
*
get
()
const
{
if
(
desc_
==
nullptr
)
{
DFTI_DESCRIPTOR
*
raw_desc
=
desc_
.
get
();
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
PADDLE_ENFORCE_NOT_NULL
(
raw_desc
,
platform
::
errors
::
PreconditionNotMet
(
"DFTI DESCRIPTOR has not been initialized."
));
"DFTI DESCRIPTOR has not been initialized."
));
}
return
raw_desc
;
return
desc_
.
get
();
}
}
private:
private:
...
@@ -421,7 +420,9 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
...
@@ -421,7 +420,9 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
return
DFTI_DOUBLE
;
return
DFTI_DOUBLE
;
default:
default:
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Input data type should be FP32, FP64, COMPLEX64 or COMPLEX128."
));
"Invalid input datatype (%s), input data type should be FP32, "
"FP64, COMPLEX64 or COMPLEX128."
,
framework
::
DataTypeToString
(
in_dtype
)));
}
}
}();
}();
...
@@ -430,35 +431,27 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
...
@@ -430,35 +431,27 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
const
DFTI_CONFIG_VALUE
domain
=
const
DFTI_CONFIG_VALUE
domain
=
(
fft_type
==
FFTTransformType
::
C2C
)
?
DFTI_COMPLEX
:
DFTI_REAL
;
(
fft_type
==
FFTTransformType
::
C2C
)
?
DFTI_COMPLEX
:
DFTI_REAL
;
// const bool complex_input = framework::IsComplexType(in_dtype);
// const bool complex_output = framework::IsComplexType(out_dtype);
// const DFTI_CONFIG_VALUE domain = [&] {
// if (forward) {
// return complex_input ? DFTI_COMPLEX : DFTI_REAL;
// } else {
// return complex_output ? DFTI_COMPLEX : DFTI_REAL;
// }
// }();
DftiDescriptor
descriptor
;
DftiDescriptor
descriptor
;
std
::
vector
<
MKL_LONG
>
fft_sizes
(
signal_sizes
.
cbegin
(),
signal_sizes
.
cend
());
std
::
vector
<
MKL_LONG
>
fft_sizes
(
signal_sizes
.
cbegin
(),
signal_sizes
.
cend
());
const
MKL_LONG
signal_ndim
=
fft_sizes
.
size
()
-
1
;
const
MKL_LONG
signal_ndim
=
fft_sizes
.
size
()
-
1
;
descriptor
.
init
(
precision
,
domain
,
signal_ndim
,
fft_sizes
.
data
()
+
1
);
descriptor
.
init
(
precision
,
domain
,
signal_ndim
,
fft_sizes
.
data
()
+
1
);
// placement inplace or not inplace
// placement inplace or not inplace
MKL_DFTI_CHECK
(
MKL_DFTI_CHECK
(
platform
::
dynload
::
DftiSetValue
(
DftiSetValue
(
descriptor
.
get
(),
DFTI_PLACEMENT
,
DFTI_NOT_INPLACE
));
descriptor
.
get
(),
DFTI_PLACEMENT
,
DFTI_NOT_INPLACE
));
// number of transformations
// number of transformations
const
MKL_LONG
batch_size
=
fft_sizes
[
0
];
const
MKL_LONG
batch_size
=
fft_sizes
[
0
];
MKL_DFTI_CHECK
(
MKL_DFTI_CHECK
(
platform
::
dynload
::
DftiSetValue
(
DftiSetValue
(
descriptor
.
get
(),
DFTI_NUMBER_OF_TRANSFORMS
,
batch_size
));
descriptor
.
get
(),
DFTI_NUMBER_OF_TRANSFORMS
,
batch_size
));
// input & output distance
// input & output distance
const
MKL_LONG
idist
=
in_strides
[
0
];
const
MKL_LONG
idist
=
in_strides
[
0
];
const
MKL_LONG
odist
=
out_strides
[
0
];
const
MKL_LONG
odist
=
out_strides
[
0
];
MKL_DFTI_CHECK
(
DftiSetValue
(
descriptor
.
get
(),
DFTI_INPUT_DISTANCE
,
idist
));
MKL_DFTI_CHECK
(
platform
::
dynload
::
DftiSetValue
(
descriptor
.
get
(),
MKL_DFTI_CHECK
(
DftiSetValue
(
descriptor
.
get
(),
DFTI_OUTPUT_DISTANCE
,
odist
));
DFTI_INPUT_DISTANCE
,
idist
));
MKL_DFTI_CHECK
(
platform
::
dynload
::
DftiSetValue
(
descriptor
.
get
(),
DFTI_OUTPUT_DISTANCE
,
odist
));
// input & output stride
// input & output stride
std
::
vector
<
MKL_LONG
>
mkl_in_stride
(
1
+
signal_ndim
,
0
);
std
::
vector
<
MKL_LONG
>
mkl_in_stride
(
1
+
signal_ndim
,
0
);
...
@@ -467,15 +460,15 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
...
@@ -467,15 +460,15 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
mkl_in_stride
[
i
]
=
in_strides
[
i
];
mkl_in_stride
[
i
]
=
in_strides
[
i
];
mkl_out_stride
[
i
]
=
out_strides
[
i
];
mkl_out_stride
[
i
]
=
out_strides
[
i
];
}
}
MKL_DFTI_CHECK
(
MKL_DFTI_CHECK
(
platform
::
dynload
::
DftiSetValue
(
DftiSetValue
(
descriptor
.
get
(),
DFTI_INPUT_STRIDES
,
mkl_in_stride
.
data
()));
descriptor
.
get
(),
DFTI_INPUT_STRIDES
,
mkl_in_stride
.
data
()));
MKL_DFTI_CHECK
(
DftiSetValue
(
descriptor
.
get
(),
DFTI_OUTPUT_STRIDES
,
MKL_DFTI_CHECK
(
platform
::
dynload
::
DftiSetValue
(
mkl_out_stride
.
data
()));
descriptor
.
get
(),
DFTI_OUTPUT_STRIDES
,
mkl_out_stride
.
data
()));
// conjugate even storage
// conjugate even storage
if
(
!
(
fft_type
==
FFTTransformType
::
C2C
))
{
if
(
!
(
fft_type
==
FFTTransformType
::
C2C
))
{
MKL_DFTI_CHECK
(
DftiSetValue
(
descriptor
.
get
(),
DFTI_CONJUGATE_EVEN_STORAGE
,
MKL_DFTI_CHECK
(
platform
::
dynload
::
DftiSetValue
(
DFTI_COMPLEX_COMPLEX
));
descriptor
.
get
(),
DFTI_CONJUGATE_EVEN_STORAGE
,
DFTI_COMPLEX_COMPLEX
));
}
}
MKL_LONG
signal_numel
=
MKL_LONG
signal_numel
=
...
@@ -496,11 +489,12 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
...
@@ -496,11 +489,12 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
return
DFTI_BACKWARD_SCALE
;
return
DFTI_BACKWARD_SCALE
;
}
}
}();
}();
MKL_DFTI_CHECK
(
DftiSetValue
(
descriptor
.
get
(),
scale_direction
,
scale
));
MKL_DFTI_CHECK
(
platform
::
dynload
::
DftiSetValue
(
descriptor
.
get
(),
scale_direction
,
scale
));
}
}
// commit the descriptor
// commit the descriptor
MKL_DFTI_CHECK
(
DftiCommitDescriptor
(
descriptor
.
get
()));
MKL_DFTI_CHECK
(
platform
::
dynload
::
DftiCommitDescriptor
(
descriptor
.
get
()));
return
descriptor
;
return
descriptor
;
}
}
...
@@ -592,14 +586,15 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
...
@@ -592,14 +586,15 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
collapsed_input
.
numel
(),
collapsed_input
.
numel
(),
collapsed_input_conj
.
data
<
Ti
>
());
collapsed_input_conj
.
data
<
Ti
>
());
for_range
(
functor
);
for_range
(
functor
);
MKL_DFTI_CHECK
(
DftiComputeBackward
(
desc
.
get
(),
MKL_DFTI_CHECK
(
platform
::
dynload
::
DftiComputeBackward
(
collapsed_input_conj
.
data
<
void
>
(),
desc
.
get
(),
collapsed_input_conj
.
data
<
void
>
(),
collapsed_output
.
data
<
void
>
()));
collapsed_output
.
data
<
void
>
()));
}
else
if
(
fft_type
==
FFTTransformType
::
R2C
&&
!
forward
)
{
}
else
if
(
fft_type
==
FFTTransformType
::
R2C
&&
!
forward
)
{
framework
::
Tensor
collapsed_output_conj
(
collapsed_output
.
type
());
framework
::
Tensor
collapsed_output_conj
(
collapsed_output
.
type
());
collapsed_output_conj
.
mutable_data
<
To
>
(
collapsed_output
.
dims
(),
collapsed_output_conj
.
mutable_data
<
To
>
(
collapsed_output
.
dims
(),
ctx
.
GetPlace
());
ctx
.
GetPlace
());
MKL_DFTI_CHECK
(
DftiComputeForward
(
desc
.
get
(),
collapsed_input
.
data
<
void
>
(),
MKL_DFTI_CHECK
(
platform
::
dynload
::
DftiComputeForward
(
desc
.
get
(),
collapsed_input
.
data
<
void
>
(),
collapsed_output_conj
.
data
<
void
>
()));
collapsed_output_conj
.
data
<
void
>
()));
// conjugate the output
// conjugate the output
platform
::
ForRange
<
DeviceContext
>
for_range
(
ctx
,
collapsed_output
.
numel
());
platform
::
ForRange
<
DeviceContext
>
for_range
(
ctx
,
collapsed_output
.
numel
());
...
@@ -609,12 +604,12 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
...
@@ -609,12 +604,12 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
for_range
(
functor
);
for_range
(
functor
);
}
else
{
}
else
{
if
(
forward
)
{
if
(
forward
)
{
MKL_DFTI_CHECK
(
DftiComputeForward
(
desc
.
get
(),
MKL_DFTI_CHECK
(
platform
::
dynload
::
DftiComputeForward
(
collapsed_input
.
data
<
void
>
(),
desc
.
get
(),
collapsed_input
.
data
<
void
>
(),
collapsed_output
.
data
<
void
>
()));
collapsed_output
.
data
<
void
>
()));
}
else
{
}
else
{
MKL_DFTI_CHECK
(
DftiComputeBackward
(
desc
.
get
(),
MKL_DFTI_CHECK
(
platform
::
dynload
::
DftiComputeBackward
(
collapsed_input
.
data
<
void
>
(),
desc
.
get
(),
collapsed_input
.
data
<
void
>
(),
collapsed_output
.
data
<
void
>
()));
collapsed_output
.
data
<
void
>
()));
}
}
}
}
...
...
paddle/fluid/platform/dynload/CMakeLists.txt
浏览文件 @
f45e6cf6
...
@@ -49,3 +49,9 @@ endif()
...
@@ -49,3 +49,9 @@ endif()
cc_library
(
dynload_lapack SRCS lapack.cc DEPS dynamic_loader
)
cc_library
(
dynload_lapack SRCS lapack.cc DEPS dynamic_loader
)
add_dependencies
(
dynload_lapack extern_lapack
)
add_dependencies
(
dynload_lapack extern_lapack
)
# TODO(TJ): add iomp, mkldnn?
# TODO(TJ): add iomp, mkldnn?
if
(
MKL_FOUND AND WITH_ONEMKL
)
message
(
"ONEMKL INCLUDE directory is
${
MKL_INCLUDE
}
"
)
cc_library
(
dynload_mklrt SRCS mklrt.cc DEPS dynamic_loader
)
target_include_directories
(
dynload_mklrt PRIVATE
${
MKL_INCLUDE
}
)
endif
()
paddle/fluid/platform/dynload/dynamic_loader.cc
浏览文件 @
f45e6cf6
...
@@ -53,6 +53,12 @@ DEFINE_string(mklml_dir, "", "Specify path for loading libmklml_intel.so.");
...
@@ -53,6 +53,12 @@ DEFINE_string(mklml_dir, "", "Specify path for loading libmklml_intel.so.");
DEFINE_string
(
lapack_dir
,
""
,
"Specify path for loading liblapack.so."
);
DEFINE_string
(
lapack_dir
,
""
,
"Specify path for loading liblapack.so."
);
DEFINE_string
(
mkl_dir
,
""
,
"Specify path for loading libmkl_rt.so. "
"For insrance, /opt/intel/oneapi/mkl/latest/lib/intel64/."
"If default, "
"dlopen will search mkl from LD_LIBRARY_PATH"
);
DEFINE_string
(
op_dir
,
""
,
"Specify path for loading user-defined op library."
);
DEFINE_string
(
op_dir
,
""
,
"Specify path for loading user-defined op library."
);
#ifdef PADDLE_WITH_HIP
#ifdef PADDLE_WITH_HIP
...
@@ -518,6 +524,16 @@ void* GetCUFFTDsoHandle() {
...
@@ -518,6 +524,16 @@ void* GetCUFFTDsoHandle() {
#endif
#endif
}
}
void
*
GetMKLRTDsoHandle
()
{
#if defined(__APPLE__) || defined(__OSX__)
return
GetDsoHandleFromSearchPath
(
FLAGS_mkl_dir
,
"libmkl_rt.dylib"
);
#elif defined(_WIN32)
return
GetDsoHandleFromSearchPath
(
FLAGS_mkl_dir
,
"mkl_rt.dll"
);
#else
return
GetDsoHandleFromSearchPath
(
FLAGS_mkl_dir
,
"libmkl_rt.so"
);
#endif
}
}
// namespace dynload
}
// namespace dynload
}
// namespace platform
}
// namespace platform
}
// namespace paddle
}
// namespace paddle
paddle/fluid/platform/dynload/dynamic_loader.h
浏览文件 @
f45e6cf6
...
@@ -43,6 +43,7 @@ void* GetLAPACKDsoHandle();
...
@@ -43,6 +43,7 @@ void* GetLAPACKDsoHandle();
void
*
GetOpDsoHandle
(
const
std
::
string
&
dso_name
);
void
*
GetOpDsoHandle
(
const
std
::
string
&
dso_name
);
void
*
GetNvtxDsoHandle
();
void
*
GetNvtxDsoHandle
();
void
*
GetCUFFTDsoHandle
();
void
*
GetCUFFTDsoHandle
();
void
*
GetMKLRTDsoHandle
();
void
SetPaddleLibPath
(
const
std
::
string
&
);
void
SetPaddleLibPath
(
const
std
::
string
&
);
}
// namespace dynload
}
// namespace dynload
...
...
paddle/fluid/platform/dynload/mklrt.cc
0 → 100644
浏览文件 @
f45e6cf6
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/dynload/mklrt.h"
namespace
paddle
{
namespace
platform
{
namespace
dynload
{
std
::
once_flag
mklrt_dso_flag
;
void
*
mklrt_dso_handle
=
nullptr
;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
MKLDFTI_ROUTINE_EACH
(
DEFINE_WRAP
);
DFTI_EXTERN
MKL_LONG
DftiCreateDescriptorX
(
DFTI_DESCRIPTOR_HANDLE
*
desc
,
enum
DFTI_CONFIG_VALUE
prec
,
enum
DFTI_CONFIG_VALUE
domain
,
MKL_LONG
dim
,
MKL_LONG
*
sizes
)
{
if
(
prec
==
DFTI_SINGLE
)
{
if
(
dim
==
1
)
{
return
DftiCreateDescriptor_s_1d
(
desc
,
domain
,
sizes
[
0
]);
}
else
{
return
DftiCreateDescriptor_s_md
(
desc
,
domain
,
dim
,
sizes
);
}
}
else
if
(
prec
==
DFTI_DOUBLE
)
{
if
(
dim
==
1
)
{
return
DftiCreateDescriptor_d_1d
(
desc
,
domain
,
sizes
[
0
]);
}
else
{
return
DftiCreateDescriptor_d_md
(
desc
,
domain
,
dim
,
sizes
);
}
}
else
{
return
DftiCreateDescriptor
(
desc
,
prec
,
domain
,
dim
,
sizes
);
}
}
}
// namespace dynload
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/dynload/mklrt.h
0 → 100644
浏览文件 @
f45e6cf6
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <mkl_dfti.h>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
namespace
paddle
{
namespace
platform
{
namespace
dynload
{
extern
std
::
once_flag
mklrt_dso_flag
;
extern
void
*
mklrt_dso_handle
;
/**
* The following macro definition can generate structs
* (for each function) to dynamic load mkldfti routine
* via operator overloading.
*/
#define DYNAMIC_LOAD_MKLRT_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using mklrtFunc = decltype(&::__name); \
std::call_once(mklrt_dso_flag, []() { \
mklrt_dso_handle = paddle::platform::dynload::GetMKLRTDsoHandle(); \
}); \
static void* p_##__name = dlsym(mklrt_dso_handle, #__name); \
return reinterpret_cast<mklrtFunc>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
// mkl_dfti.h has a macro that shadows the function with the same name
// un-defeine this macro so as to export that function
#undef DftiCreateDescriptor
#define MKLDFTI_ROUTINE_EACH(__macro) \
__macro(DftiCreateDescriptor); \
__macro(DftiCreateDescriptor_s_1d); \
__macro(DftiCreateDescriptor_d_1d); \
__macro(DftiCreateDescriptor_s_md); \
__macro(DftiCreateDescriptor_d_md); \
__macro(DftiSetValue); \
__macro(DftiGetValue); \
__macro(DftiCommitDescriptor); \
__macro(DftiComputeForward); \
__macro(DftiComputeBackward); \
__macro(DftiFreeDescriptor); \
__macro(DftiErrorClass); \
__macro(DftiErrorMessage);
MKLDFTI_ROUTINE_EACH
(
DYNAMIC_LOAD_MKLRT_WRAP
)
#undef DYNAMIC_LOAD_MKLRT_WRAP
// define another function to avoid naming conflict
DFTI_EXTERN
MKL_LONG
DftiCreateDescriptorX
(
DFTI_DESCRIPTOR_HANDLE
*
desc
,
enum
DFTI_CONFIG_VALUE
prec
,
enum
DFTI_CONFIG_VALUE
domain
,
MKL_LONG
dim
,
MKL_LONG
*
sizes
);
}
// namespace dynload
}
// namespace platform
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录