未验证 提交 687902fc 编写于 作者: F Feiyu Chan 提交者: GitHub

[phi] update code for mkl based fft (#39889)

上级 584844ec
......@@ -25,9 +25,10 @@
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#if defined(PADDLE_WITH_ONEMKL)
#include "paddle/fluid/platform/dynload/mklrt.h"
#include "paddle/phi/backends/dynload/mklrt.h"
#elif defined(PADDLE_WITH_POCKETFFT)
#include "extern_pocketfft/pocketfft_hdronly.h"
#endif
......@@ -357,12 +358,12 @@ FFTNormMode get_norm_from_string(const std::string& norm, bool forward) {
// FFT Functors
#if defined(PADDLE_WITH_ONEMKL)
// MKL_DFTI_CHECK(expr): stores the result of `expr` in an MKL_LONG status and,
// when DftiErrorClass(status, DFTI_NO_ERROR) returns false, raises a
// platform::errors::External error through PADDLE_THROW, using the text from
// DftiErrorMessage(status).
// NOTE(review): two variants of the definition appear below (one calling
// platform::dynload, one calling phi::dynload). Presumably these are the
// before/after lines of the change and only the phi::dynload variant remains
// in the final file; confirm against the repository.
// NOTE(review): the definition ends in `while (0);`. With that trailing
// semicolon, `MKL_DFTI_CHECK(x);` expands to two statements, so it cannot be
// used as the unbraced body of an `if` that has an `else`; consider dropping
// the semicolon once all call sites are confirmed to supply their own.
#define MKL_DFTI_CHECK(expr) \
do { \
MKL_LONG status = (expr); \
if (!platform::dynload::DftiErrorClass(status, DFTI_NO_ERROR)) \
PADDLE_THROW(platform::errors::External( \
platform::dynload::DftiErrorMessage(status))); \
#define MKL_DFTI_CHECK(expr) \
do { \
MKL_LONG status = (expr); \
if (!phi::dynload::DftiErrorClass(status, DFTI_NO_ERROR)) \
PADDLE_THROW( \
platform::errors::External(phi::dynload::DftiErrorMessage(status))); \
} while (0);
namespace {
......@@ -370,7 +371,7 @@ namespace {
// Deleter functor for a DFTI descriptor handle. A non-null handle is released
// with DftiFreeDescriptor, and the returned status is checked by
// MKL_DFTI_CHECK; a null handle is left untouched.
// NOTE(review): both a platform::dynload call and a phi::dynload call appear
// in the body. Presumably these are the removed and added lines of the change
// and only the phi::dynload call belongs in the final file; confirm, since
// freeing the same handle twice would not be intended.
// NOTE(review): MKL_DFTI_CHECK reports failures via PADDLE_THROW. If this
// functor is invoked while an owning smart pointer is being destroyed, the
// error would propagate out of a destruction path; verify that is acceptable.
struct DftiDescriptorDeleter {
void operator()(DFTI_DESCRIPTOR_HANDLE handle) {
if (handle != nullptr) {
MKL_DFTI_CHECK(platform::dynload::DftiFreeDescriptor(&handle));
MKL_DFTI_CHECK(phi::dynload::DftiFreeDescriptor(&handle));
}
}
};
......@@ -385,7 +386,7 @@ class DftiDescriptor {
"DftiDescriptor has already been initialized."));
DFTI_DESCRIPTOR* raw_desc;
MKL_DFTI_CHECK(platform::dynload::DftiCreateDescriptorX(
MKL_DFTI_CHECK(phi::dynload::DftiCreateDescriptorX(
&raw_desc, precision, signal_type, signal_ndim, sizes));
desc_.reset(raw_desc);
}
......@@ -437,21 +438,21 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
descriptor.init(precision, domain, signal_ndim, fft_sizes.data() + 1);
// placement inplace or not inplace
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(
descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE));
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(descriptor.get(), DFTI_PLACEMENT,
DFTI_NOT_INPLACE));
// number of transformations
const MKL_LONG batch_size = fft_sizes[0];
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size));
// input & output distance
const MKL_LONG idist = in_strides[0];
const MKL_LONG odist = out_strides[0];
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(),
DFTI_INPUT_DISTANCE, idist));
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(),
DFTI_OUTPUT_DISTANCE, odist));
MKL_DFTI_CHECK(
phi::dynload::DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist));
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(descriptor.get(),
DFTI_OUTPUT_DISTANCE, odist));
// input & output stride
std::vector<MKL_LONG> mkl_in_stride(1 + signal_ndim, 0);
......@@ -460,14 +461,14 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
mkl_in_stride[i] = in_strides[i];
mkl_out_stride[i] = out_strides[i];
}
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride.data()));
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
descriptor.get(), DFTI_OUTPUT_STRIDES, mkl_out_stride.data()));
// conjugate even storage
if (!(fft_type == FFTTransformType::C2C)) {
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(
MKL_DFTI_CHECK(phi::dynload::DftiSetValue(
descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX));
}
......@@ -489,12 +490,12 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
return DFTI_BACKWARD_SCALE;
}
}();
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(),
scale_direction, scale));
MKL_DFTI_CHECK(
phi::dynload::DftiSetValue(descriptor.get(), scale_direction, scale));
}
// commit the descriptor
MKL_DFTI_CHECK(platform::dynload::DftiCommitDescriptor(descriptor.get()));
MKL_DFTI_CHECK(phi::dynload::DftiCommitDescriptor(descriptor.get()));
return descriptor;
}
......@@ -575,39 +576,39 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
framework::TransToProtoVarType(out->dtype()), input_stride,
output_stride, signal_sizes, normalization, forward);
const FFTTransformType fft_type = GetFFTTransformType(x->type(), out->type());
const FFTTransformType fft_type =
GetFFTTransformType(framework::TransToProtoVarType(x->dtype()),
framework::TransToProtoVarType(out->type()));
if (fft_type == FFTTransformType::C2R && forward) {
framework::Tensor collapsed_input_conj(
framework::TransToProtoVarType(collapsed_input.dtype()));
framework::Tensor collapsed_input_conj(collapsed_input.dtype());
collapsed_input_conj.mutable_data<Ti>(collapsed_input.dims(),
ctx.GetPlace());
// conjugate the input
platform::ForRange<DeviceContext> for_range(ctx, collapsed_input.numel());
math::ConjFunctor<Ti> functor(collapsed_input.data<Ti>(),
collapsed_input.numel(),
collapsed_input_conj.data<Ti>());
phi::funcs::ConjFunctor<Ti> functor(collapsed_input.data<Ti>(),
collapsed_input.numel(),
collapsed_input_conj.data<Ti>());
for_range(functor);
MKL_DFTI_CHECK(platform::dynload::DftiComputeBackward(
MKL_DFTI_CHECK(phi::dynload::DftiComputeBackward(
desc.get(), collapsed_input_conj.data(), collapsed_output.data()));
} else if (fft_type == FFTTransformType::R2C && !forward) {
framework::Tensor collapsed_output_conj(
framework::TransToProtoVarType(collapsed_output.dtype()));
framework::Tensor collapsed_output_conj(collapsed_output.dtype());
collapsed_output_conj.mutable_data<To>(collapsed_output.dims(),
ctx.GetPlace());
MKL_DFTI_CHECK(platform::dynload::DftiComputeForward(
MKL_DFTI_CHECK(phi::dynload::DftiComputeForward(
desc.get(), collapsed_input.data(), collapsed_output_conj.data()));
// conjugate the output
platform::ForRange<DeviceContext> for_range(ctx, collapsed_output.numel());
math::ConjFunctor<To> functor(collapsed_output_conj.data<To>(),
collapsed_output.numel(),
collapsed_output.data<To>());
phi::funcs::ConjFunctor<To> functor(collapsed_output_conj.data<To>(),
collapsed_output.numel(),
collapsed_output.data<To>());
for_range(functor);
} else {
if (forward) {
MKL_DFTI_CHECK(platform::dynload::DftiComputeForward(
MKL_DFTI_CHECK(phi::dynload::DftiComputeForward(
desc.get(), collapsed_input.data(), collapsed_output.data()));
} else {
MKL_DFTI_CHECK(platform::dynload::DftiComputeBackward(
MKL_DFTI_CHECK(phi::dynload::DftiComputeBackward(
desc.get(), collapsed_input.data(), collapsed_output.data()));
}
}
......
......@@ -17,7 +17,8 @@ limitations under the License. */
#include <mkl_dfti.h>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/phi/backends/dynload/dynamic_loader.h"
#include "paddle/phi/backends/dynload/mklrt.h"
#include "paddle/phi/backends/dynload/port.h"
namespace paddle {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册