Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
687902fc
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
687902fc
编写于
2月 25, 2022
作者:
F
Feiyu Chan
提交者:
GitHub
2月 25, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[phi] update code for mkl based fft (#39889)
上级
584844ec
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
40 addition
and
38 deletion
+40
-38
paddle/fluid/operators/spectral_op.cc
paddle/fluid/operators/spectral_op.cc
+38
-37
paddle/fluid/platform/dynload/mklrt.h
paddle/fluid/platform/dynload/mklrt.h
+2
-1
未找到文件。
paddle/fluid/operators/spectral_op.cc
浏览文件 @
687902fc
...
@@ -25,9 +25,10 @@
...
@@ -25,9 +25,10 @@
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#if defined(PADDLE_WITH_ONEMKL)
#if defined(PADDLE_WITH_ONEMKL)
#include "paddle/
fluid/platform
/dynload/mklrt.h"
#include "paddle/
phi/backends
/dynload/mklrt.h"
#elif defined(PADDLE_WITH_POCKETFFT)
#elif defined(PADDLE_WITH_POCKETFFT)
#include "extern_pocketfft/pocketfft_hdronly.h"
#include "extern_pocketfft/pocketfft_hdronly.h"
#endif
#endif
...
@@ -360,9 +361,9 @@ FFTNormMode get_norm_from_string(const std::string& norm, bool forward) {
...
@@ -360,9 +361,9 @@ FFTNormMode get_norm_from_string(const std::string& norm, bool forward) {
#define MKL_DFTI_CHECK(expr) \
#define MKL_DFTI_CHECK(expr) \
do { \
do { \
MKL_LONG status = (expr); \
MKL_LONG status = (expr); \
if (!p
latform::dynload::DftiErrorClass(status, DFTI_NO_ERROR))
\
if (!p
hi::dynload::DftiErrorClass(status, DFTI_NO_ERROR))
\
PADDLE_THROW(
platform::errors::External(
\
PADDLE_THROW(
\
platform::
dynload::DftiErrorMessage(status)));
\
platform::
errors::External(phi::dynload::DftiErrorMessage(status)));
\
} while (0);
} while (0);
namespace
{
namespace
{
...
@@ -370,7 +371,7 @@ namespace {
...
@@ -370,7 +371,7 @@ namespace {
struct
DftiDescriptorDeleter
{
struct
DftiDescriptorDeleter
{
void
operator
()(
DFTI_DESCRIPTOR_HANDLE
handle
)
{
void
operator
()(
DFTI_DESCRIPTOR_HANDLE
handle
)
{
if
(
handle
!=
nullptr
)
{
if
(
handle
!=
nullptr
)
{
MKL_DFTI_CHECK
(
p
latform
::
dynload
::
DftiFreeDescriptor
(
&
handle
));
MKL_DFTI_CHECK
(
p
hi
::
dynload
::
DftiFreeDescriptor
(
&
handle
));
}
}
}
}
};
};
...
@@ -385,7 +386,7 @@ class DftiDescriptor {
...
@@ -385,7 +386,7 @@ class DftiDescriptor {
"DftiDescriptor has already been initialized."
));
"DftiDescriptor has already been initialized."
));
DFTI_DESCRIPTOR
*
raw_desc
;
DFTI_DESCRIPTOR
*
raw_desc
;
MKL_DFTI_CHECK
(
p
latform
::
dynload
::
DftiCreateDescriptorX
(
MKL_DFTI_CHECK
(
p
hi
::
dynload
::
DftiCreateDescriptorX
(
&
raw_desc
,
precision
,
signal_type
,
signal_ndim
,
sizes
));
&
raw_desc
,
precision
,
signal_type
,
signal_ndim
,
sizes
));
desc_
.
reset
(
raw_desc
);
desc_
.
reset
(
raw_desc
);
}
}
...
@@ -437,20 +438,20 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
...
@@ -437,20 +438,20 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
descriptor
.
init
(
precision
,
domain
,
signal_ndim
,
fft_sizes
.
data
()
+
1
);
descriptor
.
init
(
precision
,
domain
,
signal_ndim
,
fft_sizes
.
data
()
+
1
);
// placement inplace or not inplace
// placement inplace or not inplace
MKL_DFTI_CHECK
(
p
latform
::
dynload
::
DftiSetValue
(
MKL_DFTI_CHECK
(
p
hi
::
dynload
::
DftiSetValue
(
descriptor
.
get
(),
DFTI_PLACEMENT
,
descriptor
.
get
(),
DFTI_PLACEMENT
,
DFTI_NOT_INPLACE
));
DFTI_NOT_INPLACE
));
// number of transformations
// number of transformations
const
MKL_LONG
batch_size
=
fft_sizes
[
0
];
const
MKL_LONG
batch_size
=
fft_sizes
[
0
];
MKL_DFTI_CHECK
(
p
latform
::
dynload
::
DftiSetValue
(
MKL_DFTI_CHECK
(
p
hi
::
dynload
::
DftiSetValue
(
descriptor
.
get
(),
DFTI_NUMBER_OF_TRANSFORMS
,
batch_size
));
descriptor
.
get
(),
DFTI_NUMBER_OF_TRANSFORMS
,
batch_size
));
// input & output distance
// input & output distance
const
MKL_LONG
idist
=
in_strides
[
0
];
const
MKL_LONG
idist
=
in_strides
[
0
];
const
MKL_LONG
odist
=
out_strides
[
0
];
const
MKL_LONG
odist
=
out_strides
[
0
];
MKL_DFTI_CHECK
(
platform
::
dynload
::
DftiSetValue
(
descriptor
.
get
(),
MKL_DFTI_CHECK
(
DFTI_INPUT_DISTANCE
,
idist
));
phi
::
dynload
::
DftiSetValue
(
descriptor
.
get
(),
DFTI_INPUT_DISTANCE
,
idist
));
MKL_DFTI_CHECK
(
p
latform
::
dynload
::
DftiSetValue
(
descriptor
.
get
(),
MKL_DFTI_CHECK
(
p
hi
::
dynload
::
DftiSetValue
(
descriptor
.
get
(),
DFTI_OUTPUT_DISTANCE
,
odist
));
DFTI_OUTPUT_DISTANCE
,
odist
));
// input & output stride
// input & output stride
...
@@ -460,14 +461,14 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
...
@@ -460,14 +461,14 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
mkl_in_stride
[
i
]
=
in_strides
[
i
];
mkl_in_stride
[
i
]
=
in_strides
[
i
];
mkl_out_stride
[
i
]
=
out_strides
[
i
];
mkl_out_stride
[
i
]
=
out_strides
[
i
];
}
}
MKL_DFTI_CHECK
(
p
latform
::
dynload
::
DftiSetValue
(
MKL_DFTI_CHECK
(
p
hi
::
dynload
::
DftiSetValue
(
descriptor
.
get
(),
DFTI_INPUT_STRIDES
,
mkl_in_stride
.
data
()));
descriptor
.
get
(),
DFTI_INPUT_STRIDES
,
mkl_in_stride
.
data
()));
MKL_DFTI_CHECK
(
p
latform
::
dynload
::
DftiSetValue
(
MKL_DFTI_CHECK
(
p
hi
::
dynload
::
DftiSetValue
(
descriptor
.
get
(),
DFTI_OUTPUT_STRIDES
,
mkl_out_stride
.
data
()));
descriptor
.
get
(),
DFTI_OUTPUT_STRIDES
,
mkl_out_stride
.
data
()));
// conjugate even storage
// conjugate even storage
if
(
!
(
fft_type
==
FFTTransformType
::
C2C
))
{
if
(
!
(
fft_type
==
FFTTransformType
::
C2C
))
{
MKL_DFTI_CHECK
(
p
latform
::
dynload
::
DftiSetValue
(
MKL_DFTI_CHECK
(
p
hi
::
dynload
::
DftiSetValue
(
descriptor
.
get
(),
DFTI_CONJUGATE_EVEN_STORAGE
,
DFTI_COMPLEX_COMPLEX
));
descriptor
.
get
(),
DFTI_CONJUGATE_EVEN_STORAGE
,
DFTI_COMPLEX_COMPLEX
));
}
}
...
@@ -489,12 +490,12 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
...
@@ -489,12 +490,12 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
return
DFTI_BACKWARD_SCALE
;
return
DFTI_BACKWARD_SCALE
;
}
}
}();
}();
MKL_DFTI_CHECK
(
platform
::
dynload
::
DftiSetValue
(
descriptor
.
get
(),
MKL_DFTI_CHECK
(
scale_direction
,
scale
));
phi
::
dynload
::
DftiSetValue
(
descriptor
.
get
(),
scale_direction
,
scale
));
}
}
// commit the descriptor
// commit the descriptor
MKL_DFTI_CHECK
(
p
latform
::
dynload
::
DftiCommitDescriptor
(
descriptor
.
get
()));
MKL_DFTI_CHECK
(
p
hi
::
dynload
::
DftiCommitDescriptor
(
descriptor
.
get
()));
return
descriptor
;
return
descriptor
;
}
}
...
@@ -575,39 +576,39 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
...
@@ -575,39 +576,39 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
framework
::
TransToProtoVarType
(
out
->
dtype
()),
input_stride
,
framework
::
TransToProtoVarType
(
out
->
dtype
()),
input_stride
,
output_stride
,
signal_sizes
,
normalization
,
forward
);
output_stride
,
signal_sizes
,
normalization
,
forward
);
const
FFTTransformType
fft_type
=
GetFFTTransformType
(
x
->
type
(),
out
->
type
());
const
FFTTransformType
fft_type
=
GetFFTTransformType
(
framework
::
TransToProtoVarType
(
x
->
dtype
()),
framework
::
TransToProtoVarType
(
out
->
type
()));
if
(
fft_type
==
FFTTransformType
::
C2R
&&
forward
)
{
if
(
fft_type
==
FFTTransformType
::
C2R
&&
forward
)
{
framework
::
Tensor
collapsed_input_conj
(
framework
::
Tensor
collapsed_input_conj
(
collapsed_input
.
dtype
());
framework
::
TransToProtoVarType
(
collapsed_input
.
dtype
()));
collapsed_input_conj
.
mutable_data
<
Ti
>
(
collapsed_input
.
dims
(),
collapsed_input_conj
.
mutable_data
<
Ti
>
(
collapsed_input
.
dims
(),
ctx
.
GetPlace
());
ctx
.
GetPlace
());
// conjugate the input
// conjugate the input
platform
::
ForRange
<
DeviceContext
>
for_range
(
ctx
,
collapsed_input
.
numel
());
platform
::
ForRange
<
DeviceContext
>
for_range
(
ctx
,
collapsed_input
.
numel
());
math
::
ConjFunctor
<
Ti
>
functor
(
collapsed_input
.
data
<
Ti
>
(),
phi
::
funcs
::
ConjFunctor
<
Ti
>
functor
(
collapsed_input
.
data
<
Ti
>
(),
collapsed_input
.
numel
(),
collapsed_input
.
numel
(),
collapsed_input_conj
.
data
<
Ti
>
());
collapsed_input_conj
.
data
<
Ti
>
());
for_range
(
functor
);
for_range
(
functor
);
MKL_DFTI_CHECK
(
p
latform
::
dynload
::
DftiComputeBackward
(
MKL_DFTI_CHECK
(
p
hi
::
dynload
::
DftiComputeBackward
(
desc
.
get
(),
collapsed_input_conj
.
data
(),
collapsed_output
.
data
()));
desc
.
get
(),
collapsed_input_conj
.
data
(),
collapsed_output
.
data
()));
}
else
if
(
fft_type
==
FFTTransformType
::
R2C
&&
!
forward
)
{
}
else
if
(
fft_type
==
FFTTransformType
::
R2C
&&
!
forward
)
{
framework
::
Tensor
collapsed_output_conj
(
framework
::
Tensor
collapsed_output_conj
(
collapsed_output
.
dtype
());
framework
::
TransToProtoVarType
(
collapsed_output
.
dtype
()));
collapsed_output_conj
.
mutable_data
<
To
>
(
collapsed_output
.
dims
(),
collapsed_output_conj
.
mutable_data
<
To
>
(
collapsed_output
.
dims
(),
ctx
.
GetPlace
());
ctx
.
GetPlace
());
MKL_DFTI_CHECK
(
p
latform
::
dynload
::
DftiComputeForward
(
MKL_DFTI_CHECK
(
p
hi
::
dynload
::
DftiComputeForward
(
desc
.
get
(),
collapsed_input
.
data
(),
collapsed_output_conj
.
data
()));
desc
.
get
(),
collapsed_input
.
data
(),
collapsed_output_conj
.
data
()));
// conjugate the output
// conjugate the output
platform
::
ForRange
<
DeviceContext
>
for_range
(
ctx
,
collapsed_output
.
numel
());
platform
::
ForRange
<
DeviceContext
>
for_range
(
ctx
,
collapsed_output
.
numel
());
math
::
ConjFunctor
<
To
>
functor
(
collapsed_output_conj
.
data
<
To
>
(),
phi
::
funcs
::
ConjFunctor
<
To
>
functor
(
collapsed_output_conj
.
data
<
To
>
(),
collapsed_output
.
numel
(),
collapsed_output
.
numel
(),
collapsed_output
.
data
<
To
>
());
collapsed_output
.
data
<
To
>
());
for_range
(
functor
);
for_range
(
functor
);
}
else
{
}
else
{
if
(
forward
)
{
if
(
forward
)
{
MKL_DFTI_CHECK
(
p
latform
::
dynload
::
DftiComputeForward
(
MKL_DFTI_CHECK
(
p
hi
::
dynload
::
DftiComputeForward
(
desc
.
get
(),
collapsed_input
.
data
(),
collapsed_output
.
data
()));
desc
.
get
(),
collapsed_input
.
data
(),
collapsed_output
.
data
()));
}
else
{
}
else
{
MKL_DFTI_CHECK
(
p
latform
::
dynload
::
DftiComputeBackward
(
MKL_DFTI_CHECK
(
p
hi
::
dynload
::
DftiComputeBackward
(
desc
.
get
(),
collapsed_input
.
data
(),
collapsed_output
.
data
()));
desc
.
get
(),
collapsed_input
.
data
(),
collapsed_output
.
data
()));
}
}
}
}
...
...
paddle/fluid/platform/dynload/mklrt.h
浏览文件 @
687902fc
...
@@ -17,7 +17,8 @@ limitations under the License. */
...
@@ -17,7 +17,8 @@ limitations under the License. */
#include <mkl_dfti.h>
#include <mkl_dfti.h>
#include <mutex> // NOLINT
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/phi/backends/dynload/dynamic_loader.h"
#include "paddle/phi/backends/dynload/mklrt.h"
#include "paddle/phi/backends/dynload/port.h"
#include "paddle/phi/backends/dynload/port.h"
namespace
paddle
{
namespace
paddle
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录