spectral_op.cu 20.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */
#include <functional>
#include <list>
#include <memory>
#include <mutex>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/operators/conj_op.h"
23
#include "paddle/fluid/operators/spectral_helper.h"
24 25
#include "paddle/fluid/operators/spectral_op.h"
#include "paddle/fluid/operators/transpose_op.h"
26
#include "paddle/fluid/platform/enforce.h"
27 28 29 30 31 32

namespace paddle {
namespace operators {

namespace {

33 34 35 36 37 38 39
// Returns the factor the FFT result is multiplied by for the requested
// normalization mode.
//
// normalization: FFTNormMode::none yields 1; FFTNormMode::by_sqrt_n yields
//                1 / sqrt(n); every other mode yields 1 / n.
// sizes:         extents of the tensor, indexed by axis.
// dims:          the transformed axes; n is the product of sizes[d] over them.
double fft_normalization_scale(FFTNormMode normalization,
                               const std::vector<int64_t>& sizes,
                               const std::vector<int64_t>& dims) {
  // Nothing to scale when normalization is disabled.
  if (normalization == FFTNormMode::none) {
    return 1.0;
  }

  // Number of elements in one transformed signal.
  int64_t signal_numel = 1;
  for (size_t i = 0; i < dims.size(); ++i) {
    signal_numel *= sizes[dims[i]];
  }

  double scale_denom;
  if (normalization == FFTNormMode::by_sqrt_n) {
    scale_denom = std::sqrt(signal_numel);
  } else {
    scale_denom = static_cast<double>(signal_numel);
  }
  return 1.0 / scale_denom;
}

52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
// Writes `in`, multiplied by the normalization factor derived from
// `normalization`, `sizes` and `axes`, into `out`. When that factor is exactly
// 1 the data is copied to `out` without any scaling.
template <typename DeviceContext, typename T>
void exec_normalization(const DeviceContext& ctx, const Tensor* in, Tensor* out,
                        FFTNormMode normalization,
                        const std::vector<int64_t>& sizes,
                        const std::vector<int64_t>& axes) {
  const double factor = fft_normalization_scale(normalization, sizes, axes);

  // Unit factor: a plain copy is enough.
  if (factor == 1.0) {
    framework::TensorCopy(*in, ctx.GetPlace(), out);
    return;
  }

  // Scale on the device over the flattened views of both tensors.
  auto flat_out = framework::EigenVector<T>::Flatten(*out);
  auto flat_in = framework::EigenVector<T>::Flatten(*in);
  auto device = ctx.eigen_device();
  EigenScale<Eigen::GpuDevice, T>::Eval(*device, flat_out, flat_in,
                                        static_cast<T>(factor),
                                        static_cast<T>(0), false);
}
69

70
#if defined(PADDLE_WITH_CUDA)
F
Feiyu Chan 已提交
71 72 73
// Builds the key that identifies the FFT plan for transforming `input` into
// `output`. Both tensors are indexed on axes 0..signal_ndim, so they need at
// least signal_ndim + 1 dimensions.
FFTConfigKey create_fft_configkey(const framework::Tensor& input,
                                  const framework::Tensor& output,
                                  int signal_ndim) {
  // Complex inputs are described by the real type of their components.
  auto value_type = input.type();
  if (framework::IsComplexType(value_type)) {
    value_type = framework::ToRealType(value_type);
  }
  auto fft_type = GetFFTTransformType(input.type(), output.type());

  // Entry 0 is the leading extent of the input; entries 1..signal_ndim hold
  // the larger of the input and output extents along each axis.
  std::vector<int64_t> signal_size(signal_ndim + 1);
  signal_size[0] = input.dims()[0];
  for (int64_t axis = 1; axis <= signal_ndim; ++axis) {
    const int64_t in_extent = input.dims()[axis];
    const int64_t out_extent = output.dims()[axis];
    signal_size[axis] = std::max(in_extent, out_extent);
  }

  return FFTConfigKey(framework::vectorize(input.dims()),
                      framework::vectorize(output.dims()), signal_size,
                      fft_type, value_type);
}
93

94
// Execute a pre-planned transform
F
Feiyu Chan 已提交
95
static void exec_cufft_plan_raw(const FFTConfig& config, void* in_data,
96 97
                                void* out_data, bool forward) {
  auto& plan = config.plan();
98

99 100 101
  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftXtExec(
      plan, in_data, out_data, forward ? CUFFT_FORWARD : CUFFT_INVERSE));
}
102

103
template <typename DeviceContext, typename Ti, typename To>
F
Feiyu Chan 已提交
104
void exec_cufft_plan(const DeviceContext& ctx, const FFTConfig& config,
105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
                     framework::Tensor* input, framework::Tensor* output,
                     bool forward) {
  // execute transform plan
  auto fft_type = config.transform_type();
  if (fft_type == FFTTransformType::C2R && forward) {
    forward = false;
    framework::Tensor input_conj(input->type());
    input_conj.mutable_data<Ti>(input->dims(), ctx.GetPlace());
    platform::ForRange<DeviceContext> for_range(ctx, input->numel());
    math::ConjFunctor<Ti> functor(input->data<Ti>(), input->numel(),
                                  input_conj.data<Ti>());
    for_range(functor);
    exec_cufft_plan_raw(config, input_conj.data<void>(), output->data<void>(),
                        forward);
  } else if (fft_type == FFTTransformType::R2C && !forward) {
    forward = true;
    framework::Tensor out_conj(output->type());
    out_conj.mutable_data<To>(output->dims(), ctx.GetPlace());
    exec_cufft_plan_raw(config, input->data<void>(), out_conj.data<void>(),
                        forward);

    platform::ForRange<DeviceContext> for_range(ctx, output->numel());
    math::ConjFunctor<To> functor(out_conj.data<To>(), output->numel(),
                                  output->data<To>());
    for_range(functor);
  } else {
    exec_cufft_plan_raw(config, input->data<void>(), output->data<void>(),
                        forward);
133
  }
134
}
135

136
#elif defined(PADDLE_WITH_HIP)
137

F
Feiyu Chan 已提交
138
// Builds the key that identifies the FFT plan for transforming `input` into
// `output`. Both tensors are indexed on axes 0..signal_ndim, so they need at
// least signal_ndim + 1 dimensions.
FFTConfigKey create_fft_configkey(const framework::Tensor& input,
                                  const framework::Tensor& output,
                                  int signal_ndim) {
  // Complex inputs are described by the real type of their components.
  auto value_type = input.type();
  if (framework::IsComplexType(value_type)) {
    value_type = framework::ToRealType(value_type);
  }
  auto fft_type = GetFFTTransformType(input.type(), output.type());

  // Entry 0 is the leading extent of the input; entries 1..signal_ndim hold
  // the larger of the input and output extents along each axis.
  std::vector<int64_t> signal_size(signal_ndim + 1);
  signal_size[0] = input.dims()[0];
  for (int64_t axis = 1; axis <= signal_ndim; ++axis) {
    const int64_t in_extent = input.dims()[axis];
    const int64_t out_extent = output.dims()[axis];
    signal_size[axis] = std::max(in_extent, out_extent);
  }

  return FFTConfigKey(framework::vectorize(input.dims()),
                      framework::vectorize(output.dims()), signal_size,
                      fft_type, value_type);
}
160 161

// Execute a pre-planned transform
F
Feiyu Chan 已提交
162
static void exec_hipfft_plan_raw(const FFTConfig& config, void* in_data,
163
                                 void* out_data, bool forward) {
164
  auto& plan = config.plan();
165

166 167 168 169
  auto value_type = config.data_type();
  if (value_type == framework::proto::VarType::FP32) {
    switch (config.transform_type()) {
      case FFTTransformType::C2C: {
170 171 172 173
        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecC2C(
            plan, static_cast<hipfftComplex*>(in_data),
            static_cast<hipfftComplex*>(out_data),
            forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD));
174 175 176
        return;
      }
      case FFTTransformType::R2C: {
177 178 179
        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecR2C(
            plan, static_cast<hipfftReal*>(in_data),
            static_cast<hipfftComplex*>(out_data)));
180 181 182
        return;
      }
      case FFTTransformType::C2R: {
183 184 185
        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecC2R(
            plan, static_cast<hipfftComplex*>(in_data),
            static_cast<hipfftReal*>(out_data)));
186 187 188 189 190 191
        return;
      }
    }
  } else if (value_type == framework::proto::VarType::FP64) {
    switch (config.transform_type()) {
      case FFTTransformType::C2C: {
192 193 194 195
        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecZ2Z(
            plan, static_cast<hipfftDoubleComplex*>(in_data),
            static_cast<hipfftDoubleComplex*>(out_data),
            forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD));
196 197 198
        return;
      }
      case FFTTransformType::R2C: {
199 200 201
        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecD2Z(
            plan, static_cast<hipfftDoubleReal*>(in_data),
            static_cast<hipfftDoubleComplex*>(out_data)));
202 203 204
        return;
      }
      case FFTTransformType::C2R: {
205 206 207
        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecZ2D(
            plan, static_cast<hipfftDoubleComplex*>(in_data),
            static_cast<hipfftDoubleReal*>(out_data)));
208 209 210 211 212 213 214 215
        return;
      }
    }
  }
  PADDLE_THROW(platform::errors::InvalidArgument(
      "hipFFT only support transforms of type float32 and float64"));
}

216
template <typename DeviceContext, typename Ti, typename To>
F
Feiyu Chan 已提交
217
void exec_hipfft_plan(const DeviceContext& ctx, const FFTConfig& config,
218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249
                      framework::Tensor* input, framework::Tensor* output,
                      bool forward) {
  auto fft_type = config.transform_type();
  if (fft_type == FFTTransformType::C2R && forward) {
    forward = false;
    framework::Tensor input_conj(input->type());
    input_conj.mutable_data<Ti>(input->dims(), ctx.GetPlace());
    platform::ForRange<DeviceContext> for_range(ctx, input->numel());
    math::ConjFunctor<Ti> functor(input->data<Ti>(), input->numel(),
                                  input_conj.data<Ti>());
    for_range(functor);
    exec_hipfft_plan_raw(config, input_conj.data<void>(), output->data<void>(),
                         forward);
  } else if (fft_type == FFTTransformType::R2C && !forward) {
    forward = true;
    framework::Tensor out_conj(output->type());
    out_conj.mutable_data<To>(output->dims(), ctx.GetPlace());
    exec_hipfft_plan_raw(config, input->data<void>(), out_conj.data<void>(),
                         forward);

    platform::ForRange<DeviceContext> for_range(ctx, output->numel());
    math::ConjFunctor<To> functor(out_conj.data<To>(), output->numel(),
                                  output->data<To>());
    for_range(functor);
  } else {
    exec_hipfft_plan_raw(config, input->data<void>(), output->data<void>(),
                         forward);
  }
}

#endif

250 251 252 253 254 255 256 257 258
// Execute a general unnormalized fft operation (can be c2c, onesided r2c or
// onesided c2r)
template <typename DeviceContext, typename Ti, typename To>
void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out,
              const std::vector<int64_t>& dim, bool forward) {
  const auto x_dims = framework::vectorize(X->dims());
  const int64_t ndim = static_cast<int64_t>(X->dims().size());
  auto tensor_place = ctx.GetPlace();

259
  // make a dim permutation
260 261 262 263 264 265 266 267 268 269 270 271
  std::vector<int> dim_permute(ndim);
  std::iota(dim_permute.begin(), dim_permute.end(), int{0});
  std::vector<bool> is_transformed_dim(ndim);
  for (const auto& d : dim) {
    is_transformed_dim[d] = true;
  }
  auto batch_end =
      std::partition(dim_permute.begin(), dim_permute.end(),
                     [&](int64_t d) { return !is_transformed_dim[d]; });
  std::sort(dim_permute.begin(), batch_end);
  std::copy(dim.cbegin(), dim.cend(), batch_end);

272 273 274 275 276 277 278
  // transpose input according to dim permutation
  auto transposed_input_shape = X->dims().transpose(dim_permute);
  framework::Tensor transposed_input;
  transposed_input.Resize(transposed_input_shape);
  transposed_input.mutable_data<Ti>(tensor_place);
  TransCompute<DeviceContext, Ti>(ndim, ctx, *X, &transposed_input,
                                  dim_permute);
279 280

  // Reshape batch dimensions into a single dimension
281 282 283 284 285
  const int64_t signal_ndim = static_cast<int64_t>(dim.size());
  std::vector<int64_t> collapsed_input_shape(signal_ndim + 1);

  auto transposed_input_shape_ = framework::vectorize(transposed_input_shape);
  const int64_t batch_dims = ndim - signal_ndim;
286
  auto batch_size =
287 288
      std::accumulate(transposed_input_shape_.begin(),
                      transposed_input_shape_.begin() + batch_dims,
289
                      static_cast<int>(1), std::multiplies<int>());
290
  collapsed_input_shape[0] = batch_size;
291

292 293
  std::copy(transposed_input_shape_.begin() + batch_dims,
            transposed_input_shape_.end(), collapsed_input_shape.begin() + 1);
294

295 296 297 298 299 300 301
  framework::Tensor& collapsed_input = transposed_input;
  collapsed_input.Resize(framework::make_ddim(collapsed_input_shape));

  // make a collpased output
  const auto out_dims = framework::vectorize(out->dims());
  std::vector<int64_t> collapsed_output_shape(1 + signal_ndim);
  collapsed_output_shape[0] = batch_size;
302
  for (size_t i = 0; i < dim.size(); ++i) {
303
    collapsed_output_shape[i + 1] = out_dims[dim[i]];
304
  }
305 306 307 308
  framework::Tensor collapsed_output;
  collapsed_output.Resize(framework::make_ddim(collapsed_output_shape));
  collapsed_output.mutable_data<To>(tensor_place);

F
Feiyu Chan 已提交
309 310
  FFTConfig* config = nullptr;

311
#if defined(PADDLE_WITH_CUDA)
F
Feiyu Chan 已提交
312
  std::unique_ptr<FFTConfig> config_ = nullptr;
313
  // create plan
F
Feiyu Chan 已提交
314 315
  FFTConfigKey key =
      create_fft_configkey(collapsed_input, collapsed_output, signal_ndim);
316 317 318 319 320 321
  bool using_cache = false;
#if !defined(CUFFT_VERSION) || (CUFFT_VERSION < 10200)
  using_cache = true;
#endif

  if (using_cache) {
F
Feiyu Chan 已提交
322 323 324 325 326 327 328 329 330 331 332 333
    const int64_t device_id = static_cast<int64_t>(
        reinterpret_cast<const platform::CUDAPlace*>(&collapsed_input.place())
            ->GetDeviceId());
    FFTConfigCache& plan_cache = get_fft_plan_cache(device_id);
    std::unique_lock<std::mutex> guard(plan_cache.mutex, std::defer_lock);
    guard.lock();
    config = &(plan_cache.lookup(key));
  } else {
    config_ = std::make_unique<FFTConfig>(key);
    config = config_.get();
  }

334
  // prepare cufft for execution
335
  PADDLE_ENFORCE_CUDA_SUCCESS(
F
Feiyu Chan 已提交
336
      platform::dynload::cufftSetStream(config->plan(), ctx.stream()));
337
  framework::Tensor workspace_tensor;
F
Feiyu Chan 已提交
338
  workspace_tensor.mutable_data<To>(tensor_place, config->workspace_size());
339
  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftSetWorkArea(
F
Feiyu Chan 已提交
340
      config->plan(), workspace_tensor.data<To>()));
341
  // execute transform plan
F
Feiyu Chan 已提交
342
  exec_cufft_plan<DeviceContext, Ti, To>(ctx, *config, &collapsed_input,
343
                                         &collapsed_output, forward);
344

345 346
#elif defined(PADDLE_WITH_HIP)
  // create plan
F
Feiyu Chan 已提交
347 348 349 350 351 352 353 354 355 356
  FFTConfigKey key =
      create_fft_configkey(collapsed_input, collapsed_output, signal_ndim);
  const int64_t device_id = static_cast<int64_t>(
      reinterpret_cast<const platform::CUDAPlace*>(&collapsed_input.place())
          ->GetDeviceId());
  FFTConfigCache& plan_cache = get_fft_plan_cache(device_id);
  std::unique_lock<std::mutex> guard(plan_cache.mutex, std::defer_lock);
  guard.lock();
  config = &(plan_cache.lookup(key));

357 358
  // prepare cufft for execution
  PADDLE_ENFORCE_CUDA_SUCCESS(
F
Feiyu Chan 已提交
359
      platform::dynload::hipfftSetStream(config->plan(), ctx.stream()));
360
  framework::Tensor workspace_tensor;
F
Feiyu Chan 已提交
361
  workspace_tensor.mutable_data<To>(tensor_place, config->workspace_size());
362
  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftSetWorkArea(
F
Feiyu Chan 已提交
363
      config->plan(), workspace_tensor.data<To>()));
364
  // execute transform plan
F
Feiyu Chan 已提交
365
  exec_hipfft_plan<DeviceContext, Ti, To>(ctx, *config, &collapsed_input,
366 367
                                          &collapsed_output, forward);
#endif
368 369

  // Inverting output by reshape and transpose to original batch and dimension
370
  auto transposed_out_shape = out->dims().transpose(dim_permute);
371

372 373
  collapsed_output.Resize(transposed_out_shape);
  auto& transposed_output = collapsed_output;
374

375 376 377
  std::vector<int> reverse_dim_permute(ndim);
  for (size_t i = 0; i < ndim; i++) {
    reverse_dim_permute[dim_permute[i]] = i;
378 379
  }

380 381
  TransCompute<DeviceContext, To>(ndim, ctx, transposed_output, out,
                                  reverse_dim_permute);
382
}
383

384 385 386 387
}  // anonymous namespace

// Use the optimized path to perform single R2C or C2R if transformation dim is
// supported by cuFFT
F
Feiyu Chan 已提交
388
bool use_optimized_fft_path(const std::vector<int64_t>& axes) {
389 390
  // For performance reason, when axes starts with (0, 1), do not use the
  // optimized path.
F
Feiyu Chan 已提交
391
  if (axes.size() > kMaxFFTNdim ||
392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420
      (axes.size() >= 2 && axes[0] == 0 && axes[1] == 1)) {
    return false;
  } else {
    return true;
  }
}

// Complex-to-complex FFT over `axes` on CUDADeviceContext.
// The axes are transformed in groups of at most kMaxFFTNdim, taken from the
// back of `axes`, ping-ponging between `out` and a scratch copy of X; the
// normalization is applied once at the end and always lands in `out`.
template <typename Ti, typename To>
struct FFTC2CFunctor<platform::CUDADeviceContext, Ti, To> {
  void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
                  Tensor* out, const std::vector<int64_t>& axes,
                  FFTNormMode normalization, bool forward) {
    // No axes to transform: the output is a copy of the input.
    if (axes.empty()) {
      framework::TensorCopy(*X, ctx.GetPlace(), out);
      return;
    }

    const std::vector<int64_t> out_dims = framework::vectorize(X->dims());

    // Scratch tensor that starts out as a copy of X.
    framework::Tensor scratch;
    scratch.mutable_data<Ti>(X->dims(), ctx.GetPlace());
    framework::TensorCopy(*X, ctx.GetPlace(), &scratch);

    framework::Tensor* src = &scratch;
    framework::Tensor* dst = out;

    std::vector<int64_t> pending_axes(axes.begin(), axes.end());
    for (;;) {
      const size_t group_size =
          std::min(static_cast<size_t>(kMaxFFTNdim), pending_axes.size());
      const std::vector<int64_t> group(pending_axes.end() - group_size,
                                       pending_axes.end());

      exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, src, dst, group,
                                                    forward);
      pending_axes.resize(pending_axes.size() - group_size);

      if (pending_axes.empty()) {
        break;
      }

      // The result of this pass is the input of the next one.
      std::swap(src, dst);
    }

    // `dst` holds the last result, which may be `out` itself or the scratch.
    exec_normalization<platform::CUDADeviceContext, To>(
        ctx, dst, out, normalization, out_dims, axes);
  }
};

// Complex-to-real FFT over `axes` on CUDADeviceContext.
// When use_optimized_fft_path(axes) holds, one transform over all axes is run
// on a copy of X. Otherwise a C2C transform handles every axis but the last
// and a final transform handles the last axis. The result in `out` is then
// normalized in place.
template <typename Ti, typename To>
struct FFTC2RFunctor<platform::CUDADeviceContext, Ti, To> {
  void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
                  Tensor* out, const std::vector<int64_t>& axes,
                  FFTNormMode normalization, bool forward) {
    const std::vector<int64_t> out_dims = framework::vectorize(out->dims());

    if (!use_optimized_fft_path(axes)) {
      // Unnormalized C2C over all but the last axis into a temporary.
      framework::Tensor c2c_result;
      c2c_result.mutable_data<Ti>(X->dims(), ctx.GetPlace());
      const std::vector<int64_t> leading_axes(axes.begin(), axes.end() - 1);

      FFTC2CFunctor<platform::CUDADeviceContext, Ti, Ti> c2c_functor;
      c2c_functor(ctx, X, &c2c_result, leading_axes, FFTNormMode::none,
                  forward);

      // Final transform over the last axis.
      const std::vector<int64_t> last_axis{axes.back()};
      exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, &c2c_result, out,
                                                    last_axis, forward);
    } else {
      // The transform is run on a private copy of X rather than on X itself.
      framework::Tensor x_copy(X->type());
      x_copy.mutable_data<Ti>(X->dims(), ctx.GetPlace());
      framework::TensorCopy(*X, ctx.GetPlace(), &x_copy);
      exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, &x_copy, out, axes,
                                                    forward);
    }

    exec_normalization<platform::CUDADeviceContext, To>(
        ctx, out, out, normalization, out_dims, axes);
  }
};

// N-dimensional real-to-complex FFT on CUDADeviceContext.
// The last entry of `axes` is transformed first, writing into `out`. If there
// are more axes, an unnormalized C2C transform over the remaining ones follows.
// The normalization, computed from the dims of X, is applied last and always
// writes to `out`.
template <typename Ti, typename To>
struct FFTR2CFunctor<platform::CUDADeviceContext, Ti, To> {
  void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
                  Tensor* out, const std::vector<int64_t>& axes,
                  FFTNormMode normalization, bool forward) {
    // Step 1: R2C transform on the last axis.
    const std::vector<int64_t> last_dim{axes.back()};
    exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, X, out, last_dim,
                                                  forward);

    // Step 2: C2C transform on the remaining axes, if any. `c2c_out` is
    // declared here so that it outlives the normalization below.
    framework::Tensor c2c_out;
    framework::Tensor* norm_tensor = out;
    if (axes.size() > 1) {
      c2c_out.mutable_data<To>(out->dims(), ctx.GetPlace());
      const std::vector<int64_t> remain_dim(axes.begin(), axes.end() - 1);
      FFTC2CFunctor<platform::CUDADeviceContext, To, To> fft_c2c_func;
      fft_c2c_func(ctx, out, &c2c_out, remain_dim, FFTNormMode::none,
                   forward);
      norm_tensor = &c2c_out;
    }

    // Step 3: normalize whichever tensor holds the latest result into `out`.
    const auto in_sizes = framework::vectorize(X->dims());
    exec_normalization<platform::CUDADeviceContext, To>(
        ctx, norm_tensor, out, normalization, in_sizes, axes);
  }
};

}  // namespace operators
}  // namespace paddle

// Short alias used by the kernel registrations below.
namespace ops = paddle::operators;

// Each registration binds an operator name to its CUDADeviceContext kernel
// class, instantiated once for float and once for double.

// fft_c2c and its gradient.
REGISTER_OP_CUDA_KERNEL(
    fft_c2c, ops::FFTC2CKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTC2CKernel<paddle::platform::CUDADeviceContext, double>);

REGISTER_OP_CUDA_KERNEL(
    fft_c2c_grad,
    ops::FFTC2CGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTC2CGradKernel<paddle::platform::CUDADeviceContext, double>);

// fft_c2r and its gradient.
REGISTER_OP_CUDA_KERNEL(
    fft_c2r, ops::FFTC2RKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTC2RKernel<paddle::platform::CUDADeviceContext, double>);

REGISTER_OP_CUDA_KERNEL(
    fft_c2r_grad,
    ops::FFTC2RGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTC2RGradKernel<paddle::platform::CUDADeviceContext, double>);

// fft_r2c and its gradient.
REGISTER_OP_CUDA_KERNEL(
    fft_r2c, ops::FFTR2CKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTR2CKernel<paddle::platform::CUDADeviceContext, double>);

REGISTER_OP_CUDA_KERNEL(
    fft_r2c_grad,
    ops::FFTR2CGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTR2CGradKernel<paddle::platform::CUDADeviceContext, double>);