spectral_op.cu 20.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */
#include <functional>
#include <list>
#include <memory>
#include <mutex>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/operators/conj_op.h"
C
chentianyu03 已提交
23
#include "paddle/fluid/operators/math/complex_functors.h"
24
#include "paddle/fluid/operators/spectral_helper.h"
25 26
#include "paddle/fluid/operators/spectral_op.h"
#include "paddle/fluid/operators/transpose_op.h"
27
#include "paddle/fluid/platform/enforce.h"
28 29 30 31 32 33

namespace paddle {
namespace operators {

namespace {

// Returns the factor an unnormalized FFT result is multiplied by.
//
//   normalization: FFTNormMode::none      -> 1
//                  FFTNormMode::by_sqrt_n -> 1 / sqrt(N)
//                  any other mode         -> 1 / N
//   sizes:         extent of every dimension of the reference tensor.
//   dims:          indices into `sizes` of the transformed dimensions; N is the
//                  product of their extents.
double fft_normalization_scale(FFTNormMode normalization,
                               const std::vector<int64_t>& sizes,
                               const std::vector<int64_t>& dims) {
  if (normalization == FFTNormMode::none) {
    return 1.0;
  }

  // N = product of the extents of all transformed dimensions.
  const int64_t signal_numel = std::accumulate(
      dims.begin(), dims.end(), int64_t{1},
      [&sizes](int64_t acc, int64_t d) { return acc * sizes[d]; });

  const double denom = (normalization == FFTNormMode::by_sqrt_n)
                           ? std::sqrt(static_cast<double>(signal_numel))
                           : static_cast<double>(signal_numel);
  return 1.0 / denom;
}

// Writes `in` multiplied by fft_normalization_scale(normalization, sizes, axes)
// into `out`. When that factor is exactly 1 the data is copied unchanged with
// TensorCopy. Some callers pass the same tensor as both `in` and `out`.
// NOTE(review): scaling always goes through EigenScale<Eigen::GpuDevice, T>
// regardless of DeviceContext, so this helper is GPU-only as written.
template <typename DeviceContext, typename T>
void exec_normalization(const DeviceContext& ctx, const Tensor* in, Tensor* out,
                        FFTNormMode normalization,
                        const std::vector<int64_t>& sizes,
                        const std::vector<int64_t>& axes) {
  const double scale = fft_normalization_scale(normalization, sizes, axes);
  if (scale == 1.0) {
    // Nothing to scale: just move the data into the destination.
    framework::TensorCopy(*in, ctx.GetPlace(), out);
    return;
  }

  auto dst = framework::EigenVector<T>::Flatten(*out);
  auto src = framework::EigenVector<T>::Flatten(*in);
  auto eigen_device = ctx.eigen_device();
  // dst = scale * src + 0 (bias applied after scaling: false)
  EigenScale<Eigen::GpuDevice, T>::Eval(*eigen_device, dst, src,
                                        static_cast<T>(scale),
                                        static_cast<T>(0), false);
}

#if defined(PADDLE_WITH_CUDA)
// Builds the key that identifies a cuFFT plan for transforming the collapsed
// tensor `input` (leading batch dimension followed by `signal_ndim` signal
// dimensions) into `output`. The key records:
//   * the full shapes of `input` and `output`,
//   * the signal sizes: the batch extent first, then for every signal
//     dimension the larger of the input and output extents,
//   * the transform kind derived from the input/output dtypes,
//   * the real value type (the real counterpart of a complex input dtype).
FFTConfigKey create_fft_configkey(const framework::Tensor& input,
                                  const framework::Tensor& output,
                                  int signal_ndim) {
  const auto in_type = input.type();
  const auto value_type = framework::IsComplexType(in_type)
                              ? framework::ToRealType(in_type)
                              : in_type;
  const auto fft_type = GetFFTTransformType(in_type, output.type());

  std::vector<int64_t> signal_size;
  signal_size.reserve(signal_ndim + 1);
  signal_size.push_back(input.dims()[0]);
  for (int i = 1; i <= signal_ndim; ++i) {
    signal_size.push_back(std::max(input.dims()[i], output.dims()[i]));
  }

  return FFTConfigKey(framework::vectorize(input.dims()),
                      framework::vectorize(output.dims()), signal_size,
                      fft_type, value_type);
}

// Runs the already-configured cuFFT plan held by `config` from `in_data` to
// `out_data` with cufftXtExec. `forward` selects CUFFT_FORWARD, otherwise
// CUFFT_INVERSE. A failing cuFFT status is raised via
// PADDLE_ENFORCE_GPU_SUCCESS.
static void exec_cufft_plan_raw(const FFTConfig& config, void* in_data,
                                void* out_data, bool forward) {
  const int direction = forward ? CUFFT_FORWARD : CUFFT_INVERSE;
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cufftXtExec(
      config.plan(), in_data, out_data, direction));
}

// Executes `config`'s cuFFT plan on `input` -> `output`.
//
// Real transforms are only executed in one direction here: C2R plans run
// inverse and R2C plans run forward. The other directions are obtained by
// conjugation:
//   * forward C2R  -> inverse C2R applied to conj(input)
//   * inverse R2C  -> conj(forward R2C result)
// Every other combination is handed to cuFFT with the requested direction.
template <typename DeviceContext, typename Ti, typename To>
void exec_cufft_plan(const DeviceContext& ctx, const FFTConfig& config,
                     framework::Tensor* input, framework::Tensor* output,
                     bool forward) {
  const auto fft_type = config.transform_type();

  if (fft_type == FFTTransformType::C2R && forward) {
    // Conjugate the input into a temporary, then run the inverse C2R.
    framework::Tensor conj_in(input->type());
    conj_in.mutable_data<Ti>(input->dims(), ctx.GetPlace());
    const int64_t n = input->numel();
    platform::ForRange<DeviceContext> for_range(ctx, n);
    math::ConjFunctor<Ti> conj(input->data<Ti>(), n, conj_in.data<Ti>());
    for_range(conj);
    exec_cufft_plan_raw(config, conj_in.data(), output->data(),
                        /*forward=*/false);
    return;
  }

  if (fft_type == FFTTransformType::R2C && !forward) {
    // Run the forward R2C into a temporary, then conjugate into `output`.
    framework::Tensor conj_out(output->type());
    conj_out.mutable_data<To>(output->dims(), ctx.GetPlace());
    exec_cufft_plan_raw(config, input->data(), conj_out.data(),
                        /*forward=*/true);
    const int64_t n = output->numel();
    platform::ForRange<DeviceContext> for_range(ctx, n);
    math::ConjFunctor<To> conj(conj_out.data<To>(), n, output->data<To>());
    for_range(conj);
    return;
  }

  exec_cufft_plan_raw(config, input->data(), output->data(), forward);
}

#elif defined(PADDLE_WITH_HIP)
// Builds the key that identifies a hipFFT plan for transforming the collapsed
// tensor `input` (leading batch dimension followed by `signal_ndim` signal
// dimensions) into `output`. The key records:
//   * the full shapes of `input` and `output`,
//   * the signal sizes: the batch extent first, then for every signal
//     dimension the larger of the input and output extents,
//   * the transform kind derived from the input/output dtypes,
//   * the real value type (the real counterpart of a complex input dtype).
FFTConfigKey create_fft_configkey(const framework::Tensor& input,
                                  const framework::Tensor& output,
                                  int signal_ndim) {
  const auto in_type = input.type();
  const auto value_type = framework::IsComplexType(in_type)
                              ? framework::ToRealType(in_type)
                              : in_type;
  const auto fft_type = GetFFTTransformType(in_type, output.type());

  std::vector<int64_t> signal_size;
  signal_size.reserve(signal_ndim + 1);
  signal_size.push_back(input.dims()[0]);
  for (int i = 1; i <= signal_ndim; ++i) {
    signal_size.push_back(std::max(input.dims()[i], output.dims()[i]));
  }

  return FFTConfigKey(framework::vectorize(input.dims()),
                      framework::vectorize(output.dims()), signal_size,
                      fft_type, value_type);
}

// Executes the already-configured hipFFT plan held by `config` from `in_data`
// to `out_data`. hipFFT has one typed entry point per (precision, transform
// kind) pair, so the call is dispatched on config.data_type() (FP32 / FP64)
// and config.transform_type() (C2C / R2C / C2R). `forward` only affects C2C
// transforms; the R2C and C2R entry points take no direction argument.
// Any other precision raises InvalidArgument.
static void exec_hipfft_plan_raw(const FFTConfig& config, void* in_data,
                                 void* out_data, bool forward) {
  auto& plan = config.plan();
  const auto precision = config.data_type();
  const auto transform = config.transform_type();
  const int direction = forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD;

  if (precision == framework::proto::VarType::FP32) {
    switch (transform) {
      case FFTTransformType::C2C: {
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecC2C(
            plan, static_cast<hipfftComplex*>(in_data),
            static_cast<hipfftComplex*>(out_data), direction));
        return;
      }
      case FFTTransformType::R2C: {
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecR2C(
            plan, static_cast<hipfftReal*>(in_data),
            static_cast<hipfftComplex*>(out_data)));
        return;
      }
      case FFTTransformType::C2R: {
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecC2R(
            plan, static_cast<hipfftComplex*>(in_data),
            static_cast<hipfftReal*>(out_data)));
        return;
      }
    }
  } else if (precision == framework::proto::VarType::FP64) {
    switch (transform) {
      case FFTTransformType::C2C: {
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecZ2Z(
            plan, static_cast<hipfftDoubleComplex*>(in_data),
            static_cast<hipfftDoubleComplex*>(out_data), direction));
        return;
      }
      case FFTTransformType::R2C: {
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecD2Z(
            plan, static_cast<hipfftDoubleReal*>(in_data),
            static_cast<hipfftDoubleComplex*>(out_data)));
        return;
      }
      case FFTTransformType::C2R: {
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecZ2D(
            plan, static_cast<hipfftDoubleComplex*>(in_data),
            static_cast<hipfftDoubleReal*>(out_data)));
        return;
      }
    }
  }
  PADDLE_THROW(platform::errors::InvalidArgument(
      "hipFFT only support transforms of type float32 and float64"));
}

// Executes `config`'s hipFFT plan on `input` -> `output`.
//
// Real transforms are only executed in one direction here: C2R plans run
// inverse and R2C plans run forward. The other directions are obtained by
// conjugation:
//   * forward C2R  -> inverse C2R applied to conj(input)
//   * inverse R2C  -> conj(forward R2C result)
// Every other combination is handed to hipFFT with the requested direction.
template <typename DeviceContext, typename Ti, typename To>
void exec_hipfft_plan(const DeviceContext& ctx, const FFTConfig& config,
                      framework::Tensor* input, framework::Tensor* output,
                      bool forward) {
  const auto fft_type = config.transform_type();

  if (fft_type == FFTTransformType::C2R && forward) {
    // Conjugate the input into a temporary, then run the inverse C2R.
    framework::Tensor conj_in(input->type());
    conj_in.mutable_data<Ti>(input->dims(), ctx.GetPlace());
    const int64_t n = input->numel();
    platform::ForRange<DeviceContext> for_range(ctx, n);
    math::ConjFunctor<Ti> conj(input->data<Ti>(), n, conj_in.data<Ti>());
    for_range(conj);
    exec_hipfft_plan_raw(config, conj_in.data(), output->data(),
                         /*forward=*/false);
    return;
  }

  if (fft_type == FFTTransformType::R2C && !forward) {
    // Run the forward R2C into a temporary, then conjugate into `output`.
    framework::Tensor conj_out(output->type());
    conj_out.mutable_data<To>(output->dims(), ctx.GetPlace());
    exec_hipfft_plan_raw(config, input->data(), conj_out.data(),
                         /*forward=*/true);
    const int64_t n = output->numel();
    platform::ForRange<DeviceContext> for_range(ctx, n);
    math::ConjFunctor<To> conj(conj_out.data<To>(), n, output->data<To>());
    for_range(conj);
    return;
  }

  exec_hipfft_plan_raw(config, input->data(), output->data(), forward);
}

#endif

// Executes a general unnormalized FFT of X over the axes listed in `dim` and
// writes the result into `out`. Depending on Ti/To this is a C2C, one-sided
// R2C or one-sided C2R transform. `out` must already be allocated with its
// final dims; they determine the output signal extents.
//
// Steps:
//   1. Permute X so the untransformed ("batch") axes come first in ascending
//      order, followed by the transformed axes in the order given by `dim`.
//   2. Collapse all batch axes into one leading dimension, giving
//      [batch, s_1, ..., s_k] with k = dim.size().
//   3. Obtain an FFT plan for that layout and run it on ctx's stream.
//   4. Reshape and transpose the result back into the axis order of `out`.
//
// X is only read (it is transposed into a private buffer before the plan
// runs), so callers need not copy it. Entries of `dim` are used as indices
// without bounds checks and must already be normalized to [0, rank(X)).
template <typename DeviceContext, typename Ti, typename To>
void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out,
              const std::vector<int64_t>& dim, bool forward) {
  const int64_t ndim = static_cast<int64_t>(X->dims().size());
  auto tensor_place = ctx.GetPlace();

  // make a dim permutation: batch axes (ascending) first, then `dim`.
  std::vector<int> dim_permute(ndim);
  std::iota(dim_permute.begin(), dim_permute.end(), int{0});
  std::vector<bool> is_transformed_dim(ndim);
  for (const auto& d : dim) {
    is_transformed_dim[d] = true;
  }
  auto batch_end =
      std::partition(dim_permute.begin(), dim_permute.end(),
                     [&](int64_t d) { return !is_transformed_dim[d]; });
  std::sort(dim_permute.begin(), batch_end);
  std::copy(dim.cbegin(), dim.cend(), batch_end);

  // transpose input according to dim permutation
  auto transposed_input_shape = X->dims().transpose(dim_permute);
  framework::Tensor transposed_input;
  transposed_input.Resize(transposed_input_shape);
  transposed_input.mutable_data<Ti>(tensor_place);
  TransCompute<DeviceContext, Ti>(ndim, ctx, *X, &transposed_input,
                                  dim_permute);

  // Reshape batch dimensions into a single dimension
  const int64_t signal_ndim = static_cast<int64_t>(dim.size());
  const int64_t batch_dims = ndim - signal_ndim;
  const auto transposed_input_shape_ =
      framework::vectorize(transposed_input_shape);
  // Accumulate in int64_t: the product of the batch extents can overflow an
  // int even when every individual extent fits in one.
  const int64_t batch_size =
      std::accumulate(transposed_input_shape_.begin(),
                      transposed_input_shape_.begin() + batch_dims, int64_t{1},
                      std::multiplies<int64_t>());
  std::vector<int64_t> collapsed_input_shape(signal_ndim + 1);
  collapsed_input_shape[0] = batch_size;
  std::copy(transposed_input_shape_.begin() + batch_dims,
            transposed_input_shape_.end(), collapsed_input_shape.begin() + 1);

  framework::Tensor& collapsed_input = transposed_input;
  collapsed_input.Resize(framework::make_ddim(collapsed_input_shape));

  // make a collapsed output
  const auto out_dims = framework::vectorize(out->dims());
  std::vector<int64_t> collapsed_output_shape(1 + signal_ndim);
  collapsed_output_shape[0] = batch_size;
  for (size_t i = 0; i < dim.size(); ++i) {
    collapsed_output_shape[i + 1] = out_dims[dim[i]];
  }
  framework::Tensor collapsed_output;
  collapsed_output.Resize(framework::make_ddim(collapsed_output_shape));
  collapsed_output.mutable_data<To>(tensor_place);

  FFTConfig* config = nullptr;

#if defined(PADDLE_WITH_CUDA)
  std::unique_ptr<FFTConfig> config_ = nullptr;
  // create plan
  FFTConfigKey key =
      create_fft_configkey(collapsed_input, collapsed_output, signal_ndim);
  bool using_cache = false;
#if !defined(CUFFT_VERSION) || (CUFFT_VERSION < 10200)
  using_cache = true;
#endif

  // The cache lock must stay held for as long as the cached plan is used.
  // Another thread's lookup may evict it, or reconfigure its stream and work
  // area, as soon as the mutex is released. It is declared at function scope,
  // matching the HIP path. It stays unlocked when a private plan is used.
  std::unique_lock<std::mutex> guard;
  if (using_cache) {
    // NOTE(review): this reinterprets the tensor's Place as a CUDAPlace; it is
    // only valid if Place is layout-compatible with CUDAPlace -- confirm.
    const int64_t device_id = static_cast<int64_t>(
        reinterpret_cast<const platform::CUDAPlace*>(&collapsed_input.place())
            ->GetDeviceId());
    FFTConfigCache& plan_cache = get_fft_plan_cache(device_id);
    guard = std::unique_lock<std::mutex>(plan_cache.mutex);
    config = &(plan_cache.lookup(key));
  } else {
    config_ = std::make_unique<FFTConfig>(key);
    config = config_.get();
  }

  // prepare cufft for execution
  PADDLE_ENFORCE_GPU_SUCCESS(
      platform::dynload::cufftSetStream(config->plan(), ctx.stream()));
  framework::Tensor workspace_tensor;
  workspace_tensor.mutable_data<To>(tensor_place, config->workspace_size());
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cufftSetWorkArea(
      config->plan(), workspace_tensor.data<To>()));
  // execute transform plan
  exec_cufft_plan<DeviceContext, Ti, To>(ctx, *config, &collapsed_input,
                                         &collapsed_output, forward);

#elif defined(PADDLE_WITH_HIP)
  // create plan
  FFTConfigKey key =
      create_fft_configkey(collapsed_input, collapsed_output, signal_ndim);
  // NOTE(review): this reinterprets the tensor's Place as a CUDAPlace; it is
  // only valid if Place is layout-compatible with CUDAPlace -- confirm.
  const int64_t device_id = static_cast<int64_t>(
      reinterpret_cast<const platform::CUDAPlace*>(&collapsed_input.place())
          ->GetDeviceId());
  FFTConfigCache& plan_cache = get_fft_plan_cache(device_id);
  // Held until the end of the function so the cached plan cannot be evicted
  // or reconfigured by another thread while it is in use.
  std::unique_lock<std::mutex> guard(plan_cache.mutex);
  config = &(plan_cache.lookup(key));

  // prepare hipfft for execution
  PADDLE_ENFORCE_GPU_SUCCESS(
      platform::dynload::hipfftSetStream(config->plan(), ctx.stream()));
  framework::Tensor workspace_tensor;
  workspace_tensor.mutable_data<To>(tensor_place, config->workspace_size());
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftSetWorkArea(
      config->plan(), workspace_tensor.data<To>()));
  // execute transform plan
  exec_hipfft_plan<DeviceContext, Ti, To>(ctx, *config, &collapsed_input,
                                          &collapsed_output, forward);
#endif

  // Inverting output by reshape and transpose to original batch and dimension
  auto transposed_out_shape = out->dims().transpose(dim_permute);
  collapsed_output.Resize(transposed_out_shape);
  auto& transposed_output = collapsed_output;

  std::vector<int> reverse_dim_permute(ndim);
  for (int64_t i = 0; i < ndim; ++i) {
    reverse_dim_permute[dim_permute[i]] = static_cast<int>(i);
  }

  TransCompute<DeviceContext, To>(ndim, ctx, transposed_output, out,
                                  reverse_dim_permute);
}

}  // anonymous namespace

// Decides whether an R2C/C2R transform over `axes` can be performed by a single
// exec_fft call (the "optimized path"). It cannot when there are more than
// kMaxFFTNdim axes. For performance reasons it is also not used when the axes
// start with (0, 1).
bool use_optimized_fft_path(const std::vector<int64_t>& axes) {
  const bool too_many_axes = axes.size() > kMaxFFTNdim;
  const bool starts_with_0_1 =
      axes.size() >= 2 && axes[0] == 0 && axes[1] == 1;
  return !too_many_axes && !starts_with_0_1;
}

// Complex-to-complex FFT over `axes` on a CUDA device.
//
// Axes are consumed from the back in chunks of at most kMaxFFTNdim, each chunk
// being one exec_fft call. The first chunk reads X directly; exec_fft only
// reads its input, so no copy of X is made. Later chunks alternate between
// `out` and a scratch tensor, which is allocated only when more than
// kMaxFFTNdim axes are requested. The final result is normalized into `out`
// using X's sizes. An empty `axes` just copies X into `out`.
//
// NOTE: alternating buffers assumes Ti == To, which holds for every C2C use in
// this file.
template <typename Ti, typename To>
struct FFTC2CFunctor<platform::CUDADeviceContext, Ti, To> {
  void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
                  Tensor* out, const std::vector<int64_t>& axes,
                  FFTNormMode normalization, bool forward) {
    if (axes.empty()) {
      framework::TensorCopy(*X, ctx.GetPlace(), out);
      return;
    }

    const std::vector<int64_t> in_sizes = framework::vectorize(X->dims());
    std::vector<int64_t> pending_axes(axes.begin(), axes.end());

    framework::Tensor scratch;  // allocated lazily, only if needed
    const framework::Tensor* src = X;
    framework::Tensor* dst = out;

    while (true) {
      const size_t chunk =
          std::min(static_cast<size_t>(kMaxFFTNdim), pending_axes.size());
      const std::vector<int64_t> chunk_axes(pending_axes.end() - chunk,
                                            pending_axes.end());
      exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, src, dst, chunk_axes,
                                                    forward);
      pending_axes.resize(pending_axes.size() - chunk);
      if (pending_axes.empty()) {
        break;
      }
      // The tensor just written becomes the next input. Write the next chunk
      // into whichever of `out` / `scratch` is not being read.
      src = dst;
      if (dst == out) {
        scratch.mutable_data<To>(X->dims(), ctx.GetPlace());
        dst = &scratch;
      } else {
        dst = out;
      }
    }
    exec_normalization<platform::CUDADeviceContext, To>(
        ctx, dst, out, normalization, in_sizes, axes);
  }
};

// Complex-to-real FFT over `axes` on a CUDA device.
//
// When use_optimized_fft_path(axes) holds, one exec_fft call performs the
// whole C2R transform. Otherwise an unnormalized C2C transform over all but the
// last axis runs first, followed by a C2R exec_fft over the last axis alone.
// The result in `out` is then normalized in place using the output sizes.
template <typename Ti, typename To>
struct FFTC2RFunctor<platform::CUDADeviceContext, Ti, To> {
  void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
                  Tensor* out, const std::vector<int64_t>& axes,
                  FFTNormMode normalization, bool forward) {
    std::vector<int64_t> out_dims = framework::vectorize(out->dims());

    if (use_optimized_fft_path(axes)) {
      // exec_fft only reads X: it transposes X into its own buffer before the
      // plan runs. A defensive copy of X is therefore unnecessary.
      exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, X, out, axes,
                                                    forward);
    } else {
      framework::Tensor temp_tensor;
      temp_tensor.mutable_data<Ti>(X->dims(), ctx.GetPlace());
      const std::vector<int64_t> dims(axes.begin(), axes.end() - 1);

      FFTC2CFunctor<platform::CUDADeviceContext, Ti, Ti> c2c_functor;
      c2c_functor(ctx, X, &temp_tensor, dims, FFTNormMode::none, forward);

      exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, &temp_tensor, out,
                                                    {axes.back()}, forward);
    }
    exec_normalization<platform::CUDADeviceContext, To>(
        ctx, out, out, normalization, out_dims, axes);
  }
};

// n-dimensional real-to-complex FFT over `axes` on a CUDA device.
//
// The R2C transform runs along the last axis in `axes`, writing directly into
// `out`. If further axes remain, an unnormalized C2C transform over them reads
// `out` and writes a temporary. Whichever tensor holds the final spectrum is
// then normalized into `out` using the sizes of the real input X.
template <typename Ti, typename To>
struct FFTR2CFunctor<platform::CUDADeviceContext, Ti, To> {
  void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
                  Tensor* out, const std::vector<int64_t>& axes,
                  FFTNormMode normalization, bool forward) {
    // R2C along the last requested axis, straight into `out`.
    exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, X, out, {axes.back()},
                                                  forward);

    framework::Tensor* spectrum = out;
    framework::Tensor c2c_out;
    const bool has_leading_axes = axes.size() > 1;
    if (has_leading_axes) {
      // C2C over the remaining axes.
      c2c_out.mutable_data<To>(out->dims(), ctx.GetPlace());
      const std::vector<int64_t> leading_axes(axes.begin(), axes.end() - 1);
      FFTC2CFunctor<platform::CUDADeviceContext, To, To> c2c;
      c2c(ctx, out, &c2c_out, leading_axes, FFTNormMode::none, forward);
      spectrum = &c2c_out;
    }

    exec_normalization<platform::CUDADeviceContext, To>(
        ctx, spectrum, out, normalization, framework::vectorize(X->dims()),
        axes);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// CUDA kernel registrations. Every FFT op and its gradient op below is
// registered with float and double instantiations on CUDADeviceContext.

// Complex-to-complex FFT and its gradient.
REGISTER_OP_CUDA_KERNEL(
    fft_c2c, ops::FFTC2CKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTC2CKernel<paddle::platform::CUDADeviceContext, double>);

REGISTER_OP_CUDA_KERNEL(
    fft_c2c_grad,
    ops::FFTC2CGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTC2CGradKernel<paddle::platform::CUDADeviceContext, double>);

// Complex-to-real FFT and its gradient.
REGISTER_OP_CUDA_KERNEL(
    fft_c2r, ops::FFTC2RKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTC2RKernel<paddle::platform::CUDADeviceContext, double>);

REGISTER_OP_CUDA_KERNEL(
    fft_c2r_grad,
    ops::FFTC2RGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTC2RGradKernel<paddle::platform::CUDADeviceContext, double>);

// Real-to-complex FFT and its gradient.
REGISTER_OP_CUDA_KERNEL(
    fft_r2c, ops::FFTR2CKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTR2CKernel<paddle::platform::CUDADeviceContext, double>);

REGISTER_OP_CUDA_KERNEL(
    fft_r2c_grad,
    ops::FFTR2CGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTR2CGradKernel<paddle::platform::CUDADeviceContext, double>);