// spectral_op.cu
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */
#include <algorithm>
#include <cmath>
#include <functional>
#include <list>
#include <memory>
#include <mutex>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/operators/conj_op.h"
#include "paddle/fluid/operators/spectral_helper.h"
#include "paddle/fluid/operators/spectral_op.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/pten/kernels/funcs/complex_functors.h"

namespace paddle {
namespace operators {

namespace {

// Calculates the normalization constant.
//
// The signal element count n is the product of sizes[d] for every d in
// `dims` (each d is used directly as an index into `sizes`, unchecked).
// Returns 1 for FFTNormMode::none, 1/sqrt(n) for FFTNormMode::by_sqrt_n and
// 1/n for any other mode.
double fft_normalization_scale(FFTNormMode normalization,
                               const std::vector<int64_t>& sizes,
                               const std::vector<int64_t>& dims) {
  if (normalization == FFTNormMode::none) {
    return 1.0;
  }

  // Number of elements spanned by the transformed axes.
  int64_t n = 1;
  for (const int64_t axis : dims) {
    n *= sizes[axis];
  }

  double denom;
  if (normalization == FFTNormMode::by_sqrt_n) {
    denom = std::sqrt(n);
  } else {
    denom = static_cast<double>(n);
  }
  return 1.0 / denom;
}

// Writes `in`, multiplied by the normalization constant for `normalization`,
// into `out`.
//
// The constant comes from fft_normalization_scale(normalization, sizes, axes).
// When it equals 1 the data is only copied with TensorCopy; otherwise both
// tensors are flattened to 1-D Eigen views and scaled through EigenScale.
// NOTE(review): the trailing `0` / `false` arguments are presumably the bias
// and the bias_after_scale flag of EigenScale -- confirm against its
// declaration.
template <typename DeviceContext, typename T>
void exec_normalization(const DeviceContext& ctx, const Tensor* in, Tensor* out,
                        FFTNormMode normalization,
                        const std::vector<int64_t>& sizes,
                        const std::vector<int64_t>& axes) {
  const double factor = fft_normalization_scale(normalization, sizes, axes);

  // Nothing to scale: plain copy.
  if (factor == 1.0) {
    framework::TensorCopy(*in, ctx.GetPlace(), out);
    return;
  }

  auto flat_out = framework::EigenVector<T>::Flatten(*out);
  auto flat_in = framework::EigenVector<T>::Flatten(*in);
  auto device = ctx.eigen_device();
  EigenScale<Eigen::GpuDevice, T>::Eval(*device, flat_out, flat_in,
                                        static_cast<T>(factor),
                                        static_cast<T>(0), false);
}
#if defined(PADDLE_WITH_CUDA)
// Builds the FFTConfigKey that identifies the transform plan for `input` and
// `output` (the plan itself is created or looked up by the caller).
//
// Both tensors are indexed on dimensions 0..signal_ndim, so they must have at
// least signal_ndim + 1 dimensions. Dimension 0 of `input` is stored as the
// first entry of the signal-size vector; each following entry is the larger
// of the input and output extent on that dimension.
FFTConfigKey create_fft_configkey(const framework::Tensor& input,
                                  const framework::Tensor& output,
                                  int signal_ndim) {
  const auto in_dtype = framework::TransToProtoVarType(input.dtype());
  const auto out_dtype = framework::TransToProtoVarType(output.dtype());

  // The key stores the real-valued counterpart of a complex input type.
  const auto value_type = framework::IsComplexType(in_dtype)
                              ? framework::ToRealType(in_dtype)
                              : in_dtype;
  auto fft_type = GetFFTTransformType(in_dtype, out_dtype);

  // signal sizes
  std::vector<int64_t> signal_size(signal_ndim + 1);
  signal_size[0] = input.dims()[0];
  for (int64_t axis = 1; axis <= signal_ndim; ++axis) {
    signal_size[axis] = std::max(input.dims()[axis], output.dims()[axis]);
  }

  return FFTConfigKey(framework::vectorize(input.dims()),
                      framework::vectorize(output.dims()), signal_size,
                      fft_type, value_type);
}
// Execute a pre-planned transform
F
Feiyu Chan 已提交
99
static void exec_cufft_plan_raw(const FFTConfig& config, void* in_data,
100 101
                                void* out_data, bool forward) {
  auto& plan = config.plan();
102

103
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cufftXtExec(
104 105
      plan, in_data, out_data, forward ? CUFFT_FORWARD : CUFFT_INVERSE));
}
template <typename DeviceContext, typename Ti, typename To>
F
Feiyu Chan 已提交
108
void exec_cufft_plan(const DeviceContext& ctx, const FFTConfig& config,
109 110 111 112 113 114 115 116 117
                     framework::Tensor* input, framework::Tensor* output,
                     bool forward) {
  // execute transform plan
  auto fft_type = config.transform_type();
  if (fft_type == FFTTransformType::C2R && forward) {
    forward = false;
    framework::Tensor input_conj(input->type());
    input_conj.mutable_data<Ti>(input->dims(), ctx.GetPlace());
    platform::ForRange<DeviceContext> for_range(ctx, input->numel());
118 119
    pten::funcs::ConjFunctor<Ti> functor(input->data<Ti>(), input->numel(),
                                         input_conj.data<Ti>());
120
    for_range(functor);
121
    exec_cufft_plan_raw(config, input_conj.data(), output->data(), forward);
122 123 124 125
  } else if (fft_type == FFTTransformType::R2C && !forward) {
    forward = true;
    framework::Tensor out_conj(output->type());
    out_conj.mutable_data<To>(output->dims(), ctx.GetPlace());
126
    exec_cufft_plan_raw(config, input->data(), out_conj.data(), forward);
127 128

    platform::ForRange<DeviceContext> for_range(ctx, output->numel());
129 130
    pten::funcs::ConjFunctor<To> functor(out_conj.data<To>(), output->numel(),
                                         output->data<To>());
131 132
    for_range(functor);
  } else {
133
    exec_cufft_plan_raw(config, input->data(), output->data(), forward);
134
  }
135
}
#elif defined(PADDLE_WITH_HIP)
// Builds the FFTConfigKey that identifies the transform plan for `input` and
// `output` (the plan itself is created or looked up by the caller).
//
// Both tensors are indexed on dimensions 0..signal_ndim, so they must have at
// least signal_ndim + 1 dimensions. Dimension 0 of `input` is stored as the
// first entry of the signal-size vector; each following entry is the larger
// of the input and output extent on that dimension.
FFTConfigKey create_fft_configkey(const framework::Tensor& input,
                                  const framework::Tensor& output,
                                  int signal_ndim) {
  const auto value_type =
      framework::IsComplexType(framework::TransToProtoVarType(input.dtype()))
          ? framework::ToRealType(framework::TransToProtoVarType(input.dtype()))
          : framework::TransToProtoVarType(input.dtype());
  // Convert the output dtype the same way as the input dtype (and as the CUDA
  // build does); the previous `output.type()` handed an already-converted
  // type to TransToProtoVarType.
  auto fft_type =
      GetFFTTransformType(framework::TransToProtoVarType(input.dtype()),
                          framework::TransToProtoVarType(output.dtype()));
  // signal sizes
  std::vector<int64_t> signal_size(signal_ndim + 1);

  signal_size[0] = input.dims()[0];
  for (int64_t i = 1; i <= signal_ndim; ++i) {
    auto in_size = input.dims()[i];
    auto out_size = output.dims()[i];
    signal_size[i] = std::max(in_size, out_size);
  }
  FFTConfigKey key(framework::vectorize(input.dims()),
                   framework::vectorize(output.dims()), signal_size, fft_type,
                   value_type);
  return key;
}
// Execute a pre-planned transform
F
Feiyu Chan 已提交
166
static void exec_hipfft_plan_raw(const FFTConfig& config, void* in_data,
167
                                 void* out_data, bool forward) {
168
  auto& plan = config.plan();
169

170 171 172 173
  auto value_type = config.data_type();
  if (value_type == framework::proto::VarType::FP32) {
    switch (config.transform_type()) {
      case FFTTransformType::C2C: {
174
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecC2C(
175 176 177
            plan, static_cast<hipfftComplex*>(in_data),
            static_cast<hipfftComplex*>(out_data),
            forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD));
178 179 180
        return;
      }
      case FFTTransformType::R2C: {
181
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecR2C(
182 183
            plan, static_cast<hipfftReal*>(in_data),
            static_cast<hipfftComplex*>(out_data)));
184 185 186
        return;
      }
      case FFTTransformType::C2R: {
187
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecC2R(
188 189
            plan, static_cast<hipfftComplex*>(in_data),
            static_cast<hipfftReal*>(out_data)));
190 191 192 193 194 195
        return;
      }
    }
  } else if (value_type == framework::proto::VarType::FP64) {
    switch (config.transform_type()) {
      case FFTTransformType::C2C: {
196
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecZ2Z(
197 198 199
            plan, static_cast<hipfftDoubleComplex*>(in_data),
            static_cast<hipfftDoubleComplex*>(out_data),
            forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD));
200 201 202
        return;
      }
      case FFTTransformType::R2C: {
203
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecD2Z(
204 205
            plan, static_cast<hipfftDoubleReal*>(in_data),
            static_cast<hipfftDoubleComplex*>(out_data)));
206 207 208
        return;
      }
      case FFTTransformType::C2R: {
209
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecZ2D(
210 211
            plan, static_cast<hipfftDoubleComplex*>(in_data),
            static_cast<hipfftDoubleReal*>(out_data)));
212 213 214 215 216 217 218 219
        return;
      }
    }
  }
  PADDLE_THROW(platform::errors::InvalidArgument(
      "hipFFT only support transforms of type float32 and float64"));
}

template <typename DeviceContext, typename Ti, typename To>
F
Feiyu Chan 已提交
221
void exec_hipfft_plan(const DeviceContext& ctx, const FFTConfig& config,
222 223 224 225 226 227 228 229
                      framework::Tensor* input, framework::Tensor* output,
                      bool forward) {
  auto fft_type = config.transform_type();
  if (fft_type == FFTTransformType::C2R && forward) {
    forward = false;
    framework::Tensor input_conj(input->type());
    input_conj.mutable_data<Ti>(input->dims(), ctx.GetPlace());
    platform::ForRange<DeviceContext> for_range(ctx, input->numel());
230 231
    pten::funcs::ConjFunctor<Ti> functor(input->data<Ti>(), input->numel(),
                                         input_conj.data<Ti>());
232
    for_range(functor);
233
    exec_hipfft_plan_raw(config, input_conj.data(), output->data(), forward);
234 235 236 237
  } else if (fft_type == FFTTransformType::R2C && !forward) {
    forward = true;
    framework::Tensor out_conj(output->type());
    out_conj.mutable_data<To>(output->dims(), ctx.GetPlace());
238
    exec_hipfft_plan_raw(config, input->data(), out_conj.data(), forward);
239 240

    platform::ForRange<DeviceContext> for_range(ctx, output->numel());
241 242
    pten::funcs::ConjFunctor<To> functor(out_conj.data<To>(), output->numel(),
                                         output->data<To>());
243 244
    for_range(functor);
  } else {
245
    exec_hipfft_plan_raw(config, input->data(), output->data(), forward);
246 247 248 249 250
  }
}

#endif

// Execute a general unnormalized fft operation (can be c2c, onesided r2c or
// onesided c2r)
template <typename DeviceContext, typename Ti, typename To>
void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out,
              const std::vector<int64_t>& dim, bool forward) {
  const auto x_dims = framework::vectorize(X->dims());
  const int64_t ndim = static_cast<int64_t>(X->dims().size());
  auto tensor_place = ctx.GetPlace();

260
  // make a dim permutation
261 262 263 264 265 266 267 268 269 270 271 272
  std::vector<int> dim_permute(ndim);
  std::iota(dim_permute.begin(), dim_permute.end(), int{0});
  std::vector<bool> is_transformed_dim(ndim);
  for (const auto& d : dim) {
    is_transformed_dim[d] = true;
  }
  auto batch_end =
      std::partition(dim_permute.begin(), dim_permute.end(),
                     [&](int64_t d) { return !is_transformed_dim[d]; });
  std::sort(dim_permute.begin(), batch_end);
  std::copy(dim.cbegin(), dim.cend(), batch_end);

273 274 275 276 277 278 279
  // transpose input according to dim permutation
  auto transposed_input_shape = X->dims().transpose(dim_permute);
  framework::Tensor transposed_input;
  transposed_input.Resize(transposed_input_shape);
  transposed_input.mutable_data<Ti>(tensor_place);
  TransCompute<DeviceContext, Ti>(ndim, ctx, *X, &transposed_input,
                                  dim_permute);
280 281

  // Reshape batch dimensions into a single dimension
282 283 284 285 286
  const int64_t signal_ndim = static_cast<int64_t>(dim.size());
  std::vector<int64_t> collapsed_input_shape(signal_ndim + 1);

  auto transposed_input_shape_ = framework::vectorize(transposed_input_shape);
  const int64_t batch_dims = ndim - signal_ndim;
287
  auto batch_size =
288 289
      std::accumulate(transposed_input_shape_.begin(),
                      transposed_input_shape_.begin() + batch_dims,
290
                      static_cast<int>(1), std::multiplies<int>());
291
  collapsed_input_shape[0] = batch_size;
292

293 294
  std::copy(transposed_input_shape_.begin() + batch_dims,
            transposed_input_shape_.end(), collapsed_input_shape.begin() + 1);
295

296 297 298 299 300 301 302
  framework::Tensor& collapsed_input = transposed_input;
  collapsed_input.Resize(framework::make_ddim(collapsed_input_shape));

  // make a collpased output
  const auto out_dims = framework::vectorize(out->dims());
  std::vector<int64_t> collapsed_output_shape(1 + signal_ndim);
  collapsed_output_shape[0] = batch_size;
303
  for (size_t i = 0; i < dim.size(); ++i) {
304
    collapsed_output_shape[i + 1] = out_dims[dim[i]];
305
  }
306 307 308 309
  framework::Tensor collapsed_output;
  collapsed_output.Resize(framework::make_ddim(collapsed_output_shape));
  collapsed_output.mutable_data<To>(tensor_place);

F
Feiyu Chan 已提交
310 311
  FFTConfig* config = nullptr;

312
#if defined(PADDLE_WITH_CUDA)
F
Feiyu Chan 已提交
313
  std::unique_ptr<FFTConfig> config_ = nullptr;
314
  // create plan
F
Feiyu Chan 已提交
315 316
  FFTConfigKey key =
      create_fft_configkey(collapsed_input, collapsed_output, signal_ndim);
317 318 319 320 321 322
  bool using_cache = false;
#if !defined(CUFFT_VERSION) || (CUFFT_VERSION < 10200)
  using_cache = true;
#endif

  if (using_cache) {
F
Feiyu Chan 已提交
323 324 325 326 327 328 329 330 331 332 333 334
    const int64_t device_id = static_cast<int64_t>(
        reinterpret_cast<const platform::CUDAPlace*>(&collapsed_input.place())
            ->GetDeviceId());
    FFTConfigCache& plan_cache = get_fft_plan_cache(device_id);
    std::unique_lock<std::mutex> guard(plan_cache.mutex, std::defer_lock);
    guard.lock();
    config = &(plan_cache.lookup(key));
  } else {
    config_ = std::make_unique<FFTConfig>(key);
    config = config_.get();
  }

335
  // prepare cufft for execution
336
  PADDLE_ENFORCE_GPU_SUCCESS(
F
Feiyu Chan 已提交
337
      platform::dynload::cufftSetStream(config->plan(), ctx.stream()));
338
  framework::Tensor workspace_tensor;
F
Feiyu Chan 已提交
339
  workspace_tensor.mutable_data<To>(tensor_place, config->workspace_size());
340
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cufftSetWorkArea(
F
Feiyu Chan 已提交
341
      config->plan(), workspace_tensor.data<To>()));
342
  // execute transform plan
F
Feiyu Chan 已提交
343
  exec_cufft_plan<DeviceContext, Ti, To>(ctx, *config, &collapsed_input,
344
                                         &collapsed_output, forward);
345

346 347
#elif defined(PADDLE_WITH_HIP)
  // create plan
F
Feiyu Chan 已提交
348 349 350 351 352 353 354 355 356 357
  FFTConfigKey key =
      create_fft_configkey(collapsed_input, collapsed_output, signal_ndim);
  const int64_t device_id = static_cast<int64_t>(
      reinterpret_cast<const platform::CUDAPlace*>(&collapsed_input.place())
          ->GetDeviceId());
  FFTConfigCache& plan_cache = get_fft_plan_cache(device_id);
  std::unique_lock<std::mutex> guard(plan_cache.mutex, std::defer_lock);
  guard.lock();
  config = &(plan_cache.lookup(key));

358
  // prepare cufft for execution
359
  PADDLE_ENFORCE_GPU_SUCCESS(
F
Feiyu Chan 已提交
360
      platform::dynload::hipfftSetStream(config->plan(), ctx.stream()));
361
  framework::Tensor workspace_tensor;
F
Feiyu Chan 已提交
362
  workspace_tensor.mutable_data<To>(tensor_place, config->workspace_size());
363
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftSetWorkArea(
F
Feiyu Chan 已提交
364
      config->plan(), workspace_tensor.data<To>()));
365
  // execute transform plan
F
Feiyu Chan 已提交
366
  exec_hipfft_plan<DeviceContext, Ti, To>(ctx, *config, &collapsed_input,
367 368
                                          &collapsed_output, forward);
#endif
369 370

  // Inverting output by reshape and transpose to original batch and dimension
371
  auto transposed_out_shape = out->dims().transpose(dim_permute);
372

373 374
  collapsed_output.Resize(transposed_out_shape);
  auto& transposed_output = collapsed_output;
375

376 377 378
  std::vector<int> reverse_dim_permute(ndim);
  for (size_t i = 0; i < ndim; i++) {
    reverse_dim_permute[dim_permute[i]] = i;
379 380
  }

381 382
  TransCompute<DeviceContext, To>(ndim, ctx, transposed_output, out,
                                  reverse_dim_permute);
383
}
}  // anonymous namespace

// Use the optimized path to perform single R2C or C2R if transformation dim is
// supported by cuFFT
F
Feiyu Chan 已提交
389
bool use_optimized_fft_path(const std::vector<int64_t>& axes) {
390 391
  // For performance reason, when axes starts with (0, 1), do not use the
  // optimized path.
F
Feiyu Chan 已提交
392
  if (axes.size() > kMaxFFTNdim ||
393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421
      (axes.size() >= 2 && axes[0] == 0 && axes[1] == 1)) {
    return false;
  } else {
    return true;
  }
}

// Complex-to-complex FFT over `axes` on a CUDADeviceContext.
//
// With no axes, X is copied to `out` unchanged. Otherwise the axes are
// consumed from the back in groups of at most kMaxFFTNdim, one exec_fft call
// per group. A scratch copy of X and `out` serve as source and destination,
// swapping roles after each group. The tensor written by the last group is
// then normalized into `out`, using X's dims as the sizes.
template <typename Ti, typename To>
struct FFTC2CFunctor<platform::CUDADeviceContext, Ti, To> {
  void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
                  Tensor* out, const std::vector<int64_t>& axes,
                  FFTNormMode normalization, bool forward) {
    if (axes.empty()) {
      framework::TensorCopy(*X, ctx.GetPlace(), out);
      return;
    }

    const std::vector<int64_t> norm_sizes = framework::vectorize(X->dims());
    std::vector<int64_t> pending_axes(axes.begin(), axes.end());

    // Scratch buffer initialised with X; it is the source of the first pass.
    framework::Tensor scratch;
    scratch.mutable_data<Ti>(X->dims(), ctx.GetPlace());
    framework::TensorCopy(*X, ctx.GetPlace(), &scratch);

    framework::Tensor* src = &scratch;
    framework::Tensor* dst = out;

    for (;;) {
      // Take up to kMaxFFTNdim axes from the end of the pending list.
      const size_t chunk =
          std::min(static_cast<size_t>(kMaxFFTNdim), pending_axes.size());
      const std::vector<int64_t> chunk_axes(pending_axes.end() - chunk,
                                            pending_axes.end());

      exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, src, dst, chunk_axes,
                                                    forward);
      pending_axes.resize(pending_axes.size() - chunk);

      if (pending_axes.empty()) {
        break;
      }

      // The result just written becomes the input of the next pass.
      std::swap(src, dst);
    }

    exec_normalization<platform::CUDADeviceContext, To>(
        ctx, dst, out, normalization, norm_sizes, axes);
  }
};

// Complex-to-real FFT over `axes` on a CUDADeviceContext.
//
// If use_optimized_fft_path(axes) holds, a copy of X is transformed over all
// axes with a single exec_fft call. Otherwise a complex-to-complex transform
// (without normalization) is applied over every axis except the last, and
// exec_fft is then run over the last axis only. In both cases `out` is
// finally normalized in place, using its own dims as the sizes.
// NOTE(review): X is copied on the optimized path presumably because the
// transform may overwrite its input -- confirm.
template <typename Ti, typename To>
struct FFTC2RFunctor<platform::CUDADeviceContext, Ti, To> {
  void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
                  Tensor* out, const std::vector<int64_t>& axes,
                  FFTNormMode normalization, bool forward) {
    const std::vector<int64_t> norm_sizes = framework::vectorize(out->dims());

    if (!use_optimized_fft_path(axes)) {
      // Complex transform over the leading axes into an intermediate tensor.
      framework::Tensor c2c_result;
      c2c_result.mutable_data<Ti>(X->dims(), ctx.GetPlace());
      const std::vector<int64_t> leading_axes(axes.begin(), axes.end() - 1);

      FFTC2CFunctor<platform::CUDADeviceContext, Ti, Ti> c2c;
      c2c(ctx, X, &c2c_result, leading_axes, FFTNormMode::none, forward);

      // Complex-to-real transform over the last axis.
      const std::vector<int64_t> last_axis{axes.back()};
      exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, &c2c_result, out,
                                                    last_axis, forward);
    } else {
      framework::Tensor input_copy(X->type());
      input_copy.mutable_data<Ti>(X->dims(), ctx.GetPlace());
      framework::TensorCopy(*X, ctx.GetPlace(), &input_copy);
      exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, &input_copy, out,
                                                    axes, forward);
    }

    exec_normalization<platform::CUDADeviceContext, To>(
        ctx, out, out, normalization, norm_sizes, axes);
  }
};

// n dimension real to complex FFT use cufft lib.
//
// The last entry of `axes` is transformed first with exec_fft, writing into
// `out`. When more axes are given, a complex-to-complex transform (without
// normalization) over the remaining axes is applied to that result, writing
// into a temporary. Whichever tensor holds the final result is then
// normalized into `out`, using X's dims as the sizes. `axes` must be
// non-empty, since axes.back() is read unconditionally.
template <typename Ti, typename To>
struct FFTR2CFunctor<platform::CUDADeviceContext, Ti, To> {
  void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
                  Tensor* out, const std::vector<int64_t>& axes,
                  FFTNormMode normalization, bool forward) {
    const bool has_leading_axes = axes.size() > 1;

    // Step1: R2C transform on the last dimension
    const std::vector<int64_t> last_axis{axes.back()};
    exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, X, out, last_axis,
                                                  forward);

    // Step2: C2C transform on the remaining dimension
    framework::Tensor c2c_result;
    if (has_leading_axes) {
      c2c_result.mutable_data<To>(out->dims(), ctx.GetPlace());
      const std::vector<int64_t> leading_axes(axes.begin(), axes.end() - 1);
      FFTC2CFunctor<platform::CUDADeviceContext, To, To> c2c;
      c2c(ctx, out, &c2c_result, leading_axes, FFTNormMode::none, forward);
    }

    // Step3: normalize the final result into `out`
    const auto norm_sizes = framework::vectorize(X->dims());
    framework::Tensor* final_result = has_leading_axes ? &c2c_result : out;
    exec_normalization<platform::CUDADeviceContext, To>(
        ctx, final_result, out, normalization, norm_sizes, axes);
  }
};

}  // namespace operators
}  // namespace paddle

// Kernel registrations. Each FFT operator (fft_c2c, fft_c2r, fft_r2c) and its
// gradient operator is registered for CUDADeviceContext with float and double
// instantiations of the corresponding kernel class.
namespace ops = paddle::operators;

// Complex-to-complex transform.
REGISTER_OP_CUDA_KERNEL(
    fft_c2c, ops::FFTC2CKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTC2CKernel<paddle::platform::CUDADeviceContext, double>);

// Gradient of the complex-to-complex transform.
REGISTER_OP_CUDA_KERNEL(
    fft_c2c_grad,
    ops::FFTC2CGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTC2CGradKernel<paddle::platform::CUDADeviceContext, double>);

// Complex-to-real transform.
REGISTER_OP_CUDA_KERNEL(
    fft_c2r, ops::FFTC2RKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTC2RKernel<paddle::platform::CUDADeviceContext, double>);

// Gradient of the complex-to-real transform.
REGISTER_OP_CUDA_KERNEL(
    fft_c2r_grad,
    ops::FFTC2RGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTC2RGradKernel<paddle::platform::CUDADeviceContext, double>);

// Real-to-complex transform.
REGISTER_OP_CUDA_KERNEL(
    fft_r2c, ops::FFTR2CKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTR2CKernel<paddle::platform::CUDADeviceContext, double>);

// Gradient of the real-to-complex transform.
REGISTER_OP_CUDA_KERNEL(
    fft_r2c_grad,
    ops::FFTR2CGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTR2CGradKernel<paddle::platform::CUDADeviceContext, double>);