/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */
#include <functional>
#include <list>
#include <memory>
#include <mutex>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/operators/conj_op.h"
#include "paddle/fluid/operators/spectral_helper.h"
#include "paddle/fluid/operators/spectral_op.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"

namespace paddle {
namespace operators {

namespace {

// Calculates the normalization constant
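// (returns 1 for FFTNormMode::none, 1/sqrt(n) for FFTNormMode::by_sqrt_n, and
// 1/n otherwise, where n is the product of the lengths of the transformed axes)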
double fft_normalization_scale(FFTNormMode normalization,
                               const std::vector<int64_t>& sizes,
                               const std::vector<int64_t>& dims) {
  // auto norm = static_cast<fft_norm_mode>(normalization);
  if (normalization == FFTNormMode::none) {
    return static_cast<double>(1.0);
  }

  int64_t signal_numel = 1;
  for (auto dim : dims) {
    signal_numel *= sizes[dim];
  }
  const double scale_denom = (normalization == FFTNormMode::by_sqrt_n)
                                 ? std::sqrt(signal_numel)
                                 : static_cast<double>(signal_numel);
  return static_cast<double>(1.0 / scale_denom);
}

template <typename DeviceContext, typename T>
void exec_normalization(const DeviceContext& ctx, const Tensor* in, Tensor* out,
                        FFTNormMode normalization,
                        const std::vector<int64_t>& sizes,
                        const std::vector<int64_t>& axes) {
  double scale = fft_normalization_scale(normalization, sizes, axes);
  if (scale != 1.0) {
    auto eigen_out = framework::EigenVector<T>::Flatten(*out);
    auto eigen_in = framework::EigenVector<T>::Flatten(*in);
    auto dev = ctx.eigen_device();
    EigenScale<Eigen::GpuDevice, T>::Eval(*dev, eigen_out, eigen_in,
                                          static_cast<T>(scale),
                                          static_cast<T>(0), false);
  } else {
    framework::TensorCopy(*in, ctx.GetPlace(), out);
  }
}

#if defined(PADDLE_WITH_CUDA)
FFTConfigKey create_fft_configkey(const framework::Tensor& input,
                                  const framework::Tensor& output,
                                  int signal_ndim) {
  // Create the transform plan (either from cache or locally)
  const auto value_type =
      framework::IsComplexType(framework::TransToProtoVarType(input.dtype()))
          ? framework::ToRealType(framework::TransToProtoVarType(input.dtype()))
          : framework::TransToProtoVarType(input.dtype());
  auto fft_type =
      GetFFTTransformType(framework::TransToProtoVarType(input.dtype()),
                          framework::TransToProtoVarType(output.dtype()));
  // signal sizes
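  // (use the larger of the input/output extent along each transformed axis so
  // that onesided R2C/C2R tensors map back to the full logical signal length)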
  std::vector<int64_t> signal_size(signal_ndim + 1);

  signal_size[0] = input.dims()[0];
  for (int64_t i = 1; i <= signal_ndim; ++i) {
    auto in_size = input.dims()[i];
    auto out_size = output.dims()[i];
    signal_size[i] = std::max(in_size, out_size);
  }
  FFTConfigKey key(phi::vectorize(input.dims()), phi::vectorize(output.dims()),
                   signal_size, fft_type, value_type);
  return key;
}

// Execute a pre-planned transform
static void exec_cufft_plan_raw(const FFTConfig& config, void* in_data,
                                void* out_data, bool forward) {
  auto& plan = config.plan();

  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cufftXtExec(
      plan, in_data, out_data, forward ? CUFFT_FORWARD : CUFFT_INVERSE));
}

template <typename DeviceContext, typename Ti, typename To>
void exec_cufft_plan(const DeviceContext& ctx, const FFTConfig& config,
                     framework::Tensor* input, framework::Tensor* output,
                     bool forward) {
  // execute transform plan
  auto fft_type = config.transform_type();
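  // cuFFT's C2R plans only run in the inverse direction and R2C plans only in
  // the forward direction. A forward C2R (or inverse R2C) is emulated via the
  // identity FFT(x) == conj(unnormalized_IFFT(conj(x))): conjugate the input
  // (or output) and flip the requested direction.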
  if (fft_type == FFTTransformType::C2R && forward) {
    forward = false;
    framework::Tensor input_conj(input->type());
    input_conj.mutable_data<Ti>(input->dims(), ctx.GetPlace());
    platform::ForRange<DeviceContext> for_range(ctx, input->numel());
    phi::funcs::ConjFunctor<Ti> functor(input->data<Ti>(), input->numel(),
                                        input_conj.data<Ti>());
    for_range(functor);
    exec_cufft_plan_raw(config, input_conj.data(), output->data(), forward);
  } else if (fft_type == FFTTransformType::R2C && !forward) {
    forward = true;
    framework::Tensor out_conj(output->type());
    out_conj.mutable_data<To>(output->dims(), ctx.GetPlace());
    exec_cufft_plan_raw(config, input->data(), out_conj.data(), forward);

    platform::ForRange<DeviceContext> for_range(ctx, output->numel());
    phi::funcs::ConjFunctor<To> functor(out_conj.data<To>(), output->numel(),
                                        output->data<To>());
    for_range(functor);
  } else {
    exec_cufft_plan_raw(config, input->data(), output->data(), forward);
  }
}

#elif defined(PADDLE_WITH_HIP)

FFTConfigKey create_fft_configkey(const framework::Tensor& input,
                                  const framework::Tensor& output,
                                  int signal_ndim) {
  // Create the transform plan (either from cache or locally)
  const auto value_type =
      framework::IsComplexType(framework::TransToProtoVarType(input.dtype()))
          ? framework::ToRealType(framework::TransToProtoVarType(input.dtype()))
          : framework::TransToProtoVarType(input.dtype());
  auto fft_type =
      GetFFTTransformType(framework::TransToProtoVarType(input.dtype()),
                          framework::TransToProtoVarType(output.dtype()));
  // signal sizes
  std::vector<int64_t> signal_size(signal_ndim + 1);

  signal_size[0] = input.dims()[0];
  for (int64_t i = 1; i <= signal_ndim; ++i) {
    auto in_size = input.dims()[i];
    auto out_size = output.dims()[i];
    signal_size[i] = std::max(in_size, out_size);
  }
  FFTConfigKey key(phi::vectorize(input.dims()), phi::vectorize(output.dims()),
                   signal_size, fft_type, value_type);
  return key;
}

// Execute a pre-planned transform
static void exec_hipfft_plan_raw(const FFTConfig& config, void* in_data,
                                 void* out_data, bool forward) {
  auto& plan = config.plan();

  // Unlike cufftXtExec in the CUDA path, hipFFT exposes a separate entry point
  // per precision and transform type, so dispatch on the plan's data type and
  // transform kind here.
  auto value_type = config.data_type();
  if (value_type == framework::proto::VarType::FP32) {
    switch (config.transform_type()) {
      case FFTTransformType::C2C: {
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecC2C(
            plan, static_cast<hipfftComplex*>(in_data),
            static_cast<hipfftComplex*>(out_data),
            forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD));
        return;
      }
      case FFTTransformType::R2C: {
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecR2C(
            plan, static_cast<hipfftReal*>(in_data),
            static_cast<hipfftComplex*>(out_data)));
        return;
      }
      case FFTTransformType::C2R: {
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecC2R(
            plan, static_cast<hipfftComplex*>(in_data),
            static_cast<hipfftReal*>(out_data)));
        return;
      }
    }
  } else if (value_type == framework::proto::VarType::FP64) {
    switch (config.transform_type()) {
      case FFTTransformType::C2C: {
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecZ2Z(
            plan, static_cast<hipfftDoubleComplex*>(in_data),
            static_cast<hipfftDoubleComplex*>(out_data),
            forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD));
        return;
      }
      case FFTTransformType::R2C: {
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecD2Z(
            plan, static_cast<hipfftDoubleReal*>(in_data),
            static_cast<hipfftDoubleComplex*>(out_data)));
        return;
      }
      case FFTTransformType::C2R: {
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftExecZ2D(
            plan, static_cast<hipfftDoubleComplex*>(in_data),
            static_cast<hipfftDoubleReal*>(out_data)));
        return;
      }
    }
  }
  PADDLE_THROW(platform::errors::InvalidArgument(
      "hipFFT only supports transforms of type float32 and float64"));
}

template <typename DeviceContext, typename Ti, typename To>
void exec_hipfft_plan(const DeviceContext& ctx, const FFTConfig& config,
                      framework::Tensor* input, framework::Tensor* output,
                      bool forward) {
  auto fft_type = config.transform_type();
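  // Same direction/conjugation workaround as in exec_cufft_plan above.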
  if (fft_type == FFTTransformType::C2R && forward) {
    forward = false;
    framework::Tensor input_conj(input->type());
    input_conj.mutable_data<Ti>(input->dims(), ctx.GetPlace());
    platform::ForRange<DeviceContext> for_range(ctx, input->numel());
    phi::funcs::ConjFunctor<Ti> functor(input->data<Ti>(), input->numel(),
                                        input_conj.data<Ti>());
    for_range(functor);
    exec_hipfft_plan_raw(config, input_conj.data(), output->data(), forward);
  } else if (fft_type == FFTTransformType::R2C && !forward) {
    forward = true;
    framework::Tensor out_conj(output->type());
    out_conj.mutable_data<To>(output->dims(), ctx.GetPlace());
    exec_hipfft_plan_raw(config, input->data(), out_conj.data(), forward);

    platform::ForRange<DeviceContext> for_range(ctx, output->numel());
    phi::funcs::ConjFunctor<To> functor(out_conj.data<To>(), output->numel(),
                                        output->data<To>());
    for_range(functor);
  } else {
    exec_hipfft_plan_raw(config, input->data(), output->data(), forward);
  }
}

#endif

// Execute a general unnormalized fft operation (can be c2c, onesided r2c or
// onesided c2r)
template <typename DeviceContext, typename Ti, typename To>
void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out,
              const std::vector<int64_t>& dim, bool forward) {
  const auto x_dims = phi::vectorize(X->dims());
  const int64_t ndim = static_cast<int64_t>(X->dims().size());
  auto tensor_place = ctx.GetPlace();

  // make a dim permutation
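  // Untransformed (batch) axes are moved to the front in ascending order and
  // the transformed axes to the back in the order given by `dim`, so that the
  // FFT runs over the innermost axes of a contiguous tensor.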
  std::vector<int> dim_permute(ndim);
  std::iota(dim_permute.begin(), dim_permute.end(), int{0});
  std::vector<bool> is_transformed_dim(ndim);
  for (const auto& d : dim) {
    is_transformed_dim[d] = true;
  }
  auto batch_end =
      std::partition(dim_permute.begin(), dim_permute.end(),
                     [&](int64_t d) { return !is_transformed_dim[d]; });
  std::sort(dim_permute.begin(), batch_end);
  std::copy(dim.cbegin(), dim.cend(), batch_end);

  // transpose input according to dim permutation
  auto transposed_input_shape = X->dims().transpose(dim_permute);
  framework::Tensor transposed_input;
  transposed_input.Resize(transposed_input_shape);
  transposed_input.mutable_data<Ti>(tensor_place);
  TransCompute<DeviceContext, Ti>(ndim, ctx, *X, &transposed_input,
                                  dim_permute);

  // Reshape batch dimensions into a single dimension
  const int64_t signal_ndim = static_cast<int64_t>(dim.size());
  std::vector<int64_t> collapsed_input_shape(signal_ndim + 1);

  auto transposed_input_shape_ = phi::vectorize(transposed_input_shape);
  const int64_t batch_dims = ndim - signal_ndim;
  auto batch_size =
      std::accumulate(transposed_input_shape_.begin(),
                      transposed_input_shape_.begin() + batch_dims,
                      static_cast<int>(1), std::multiplies<int>());
  collapsed_input_shape[0] = batch_size;

  std::copy(transposed_input_shape_.begin() + batch_dims,
            transposed_input_shape_.end(), collapsed_input_shape.begin() + 1);

  framework::Tensor& collapsed_input = transposed_input;
  collapsed_input.Resize(phi::make_ddim(collapsed_input_shape));

  // make a collapsed output
  const auto out_dims = phi::vectorize(out->dims());
  std::vector<int64_t> collapsed_output_shape(1 + signal_ndim);
  collapsed_output_shape[0] = batch_size;
  for (size_t i = 0; i < dim.size(); ++i) {
    collapsed_output_shape[i + 1] = out_dims[dim[i]];
  }
  framework::Tensor collapsed_output;
  collapsed_output.Resize(phi::make_ddim(collapsed_output_shape));
  collapsed_output.mutable_data<To>(tensor_place);

  FFTConfig* config = nullptr;

#if defined(PADDLE_WITH_CUDA)
  std::unique_ptr<FFTConfig> config_ = nullptr;
  // create plan
  FFTConfigKey key =
      create_fft_configkey(collapsed_input, collapsed_output, signal_ndim);
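  // The per-device plan cache is only used with older cuFFT releases
  // (CUFFT_VERSION < 10200); with newer cuFFT a fresh plan is built per call.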
  bool using_cache = false;
#if !defined(CUFFT_VERSION) || (CUFFT_VERSION < 10200)
  using_cache = true;
#endif

  if (using_cache) {
    const int64_t device_id = static_cast<int64_t>(
        reinterpret_cast<const platform::CUDAPlace*>(&collapsed_input.place())
            ->GetDeviceId());
    FFTConfigCache& plan_cache = get_fft_plan_cache(device_id);
    std::unique_lock<std::mutex> guard(plan_cache.mutex, std::defer_lock);
    guard.lock();
    config = &(plan_cache.lookup(key));
  } else {
    config_ = std::make_unique<FFTConfig>(key);
    config = config_.get();
  }

  // prepare cufft for execution
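  // Bind the op's CUDA stream and a framework-allocated workspace buffer to
  // the plan before executing it.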
  PADDLE_ENFORCE_GPU_SUCCESS(
      platform::dynload::cufftSetStream(config->plan(), ctx.stream()));
  framework::Tensor workspace_tensor;
  workspace_tensor.mutable_data<To>(tensor_place, config->workspace_size());
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cufftSetWorkArea(
      config->plan(), workspace_tensor.data<To>()));
  // execute transform plan
  exec_cufft_plan<DeviceContext, Ti, To>(ctx, *config, &collapsed_input,
                                         &collapsed_output, forward);

#elif defined(PADDLE_WITH_HIP)
  // create plan
  FFTConfigKey key =
      create_fft_configkey(collapsed_input, collapsed_output, signal_ndim);
  const int64_t device_id = static_cast<int64_t>(
      reinterpret_cast<const platform::CUDAPlace*>(&collapsed_input.place())
          ->GetDeviceId());
  FFTConfigCache& plan_cache = get_fft_plan_cache(device_id);
  std::unique_lock<std::mutex> guard(plan_cache.mutex, std::defer_lock);
  guard.lock();
  config = &(plan_cache.lookup(key));

  // prepare hipfft for execution
  PADDLE_ENFORCE_GPU_SUCCESS(
      platform::dynload::hipfftSetStream(config->plan(), ctx.stream()));
  framework::Tensor workspace_tensor;
  workspace_tensor.mutable_data<To>(tensor_place, config->workspace_size());
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::hipfftSetWorkArea(
      config->plan(), workspace_tensor.data<To>()));
  // execute transform plan
  exec_hipfft_plan<DeviceContext, Ti, To>(ctx, *config, &collapsed_input,
                                          &collapsed_output, forward);
#endif

  // Invert the earlier reshape and transpose to restore the output's original
  // batch and signal dimensions
  auto transposed_out_shape = out->dims().transpose(dim_permute);

  collapsed_output.Resize(transposed_out_shape);
  auto& transposed_output = collapsed_output;

  std::vector<int> reverse_dim_permute(ndim);
  for (size_t i = 0; i < ndim; i++) {
    reverse_dim_permute[dim_permute[i]] = i;
  }

  TransCompute<DeviceContext, To>(ndim, ctx, transposed_output, out,
                                  reverse_dim_permute);
}

}  // anonymous namespace

// Use the optimized path to perform a single R2C or C2R transform if the
// transformed dims are supported by cuFFT
bool use_optimized_fft_path(const std::vector<int64_t>& axes) {
  // For performance reasons, when axes start with (0, 1), do not use the
  // optimized path.
  if (axes.size() > kMaxFFTNdim ||
      (axes.size() >= 2 && axes[0] == 0 && axes[1] == 1)) {
    return false;
  } else {
    return true;
  }
}

template <typename Ti, typename To>
struct FFTC2CFunctor<platform::CUDADeviceContext, Ti, To> {
  void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
                  Tensor* out, const std::vector<int64_t>& axes,
                  FFTNormMode normalization, bool forward) {
    if (axes.empty()) {
      framework::TensorCopy(*X, ctx.GetPlace(), out);
      return;
    }

    framework::Tensor* p_out = out;
    std::vector<int64_t> out_dims = phi::vectorize(X->dims());
    std::vector<int64_t> working_axes(axes.begin(), axes.end());
    std::vector<int64_t> first_dims;
    size_t max_dims;
    framework::Tensor working_tensor;
    working_tensor.mutable_data<Ti>(X->dims(), ctx.GetPlace());
    framework::Tensor* p_working_tensor = &working_tensor;
    framework::TensorCopy(*X, ctx.GetPlace(), &working_tensor);

    while (true) {
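      // Each pass transforms at most kMaxFFTNdim trailing axes (a single
      // cuFFT/hipFFT plan supports a limited number of dimensions); the
      // working tensor and the output are swapped between passes until all
      // axes are consumed.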
      max_dims =
          std::min(static_cast<size_t>(kMaxFFTNdim), working_axes.size());
      first_dims.assign(working_axes.end() - max_dims, working_axes.end());

      exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, p_working_tensor,
                                                    p_out, first_dims, forward);
      working_axes.resize(working_axes.size() - max_dims);
      first_dims.clear();

      if (working_axes.empty()) {
        break;
      }

      std::swap(p_out, p_working_tensor);
    }
    exec_normalization<platform::CUDADeviceContext, To>(
        ctx, p_out, out, normalization, out_dims, axes);
  }
};

template <typename Ti, typename To>
struct FFTC2RFunctor<platform::CUDADeviceContext, Ti, To> {
  void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
                  Tensor* out, const std::vector<int64_t>& axes,
                  FFTNormMode normalization, bool forward) {
    std::vector<int64_t> in_dims = phi::vectorize(X->dims());
    std::vector<int64_t> out_dims = phi::vectorize(out->dims());
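    // If cuFFT/hipFFT can handle all requested axes in a single plan, run one
    // C2R transform on a copy of the input (C2R may overwrite its input
    // buffer). Otherwise run a C2C transform over all but the last axis, then
    // a 1D C2R over the last axis.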
    if (use_optimized_fft_path(axes)) {
      framework::Tensor x_copy(X->type());
      x_copy.mutable_data<Ti>(X->dims(), ctx.GetPlace());
      framework::TensorCopy(*X, ctx.GetPlace(), &x_copy);
      exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, &x_copy, out, axes,
                                                    forward);
    } else {
      framework::Tensor temp_tensor;
      temp_tensor.mutable_data<Ti>(X->dims(), ctx.GetPlace());
      const std::vector<int64_t> dims(axes.begin(), axes.end() - 1);

      FFTC2CFunctor<platform::CUDADeviceContext, Ti, Ti> c2c_functor;
      c2c_functor(ctx, X, &temp_tensor, dims, FFTNormMode::none, forward);

      exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, &temp_tensor, out,
                                                    {axes.back()}, forward);
    }
    exec_normalization<platform::CUDADeviceContext, To>(
        ctx, out, out, normalization, out_dims, axes);
  }
};

// n-dimensional real-to-complex FFT using the cuFFT library
template <typename Ti, typename To>
struct FFTR2CFunctor<platform::CUDADeviceContext, Ti, To> {
  void operator()(const platform::CUDADeviceContext& ctx, const Tensor* X,
                  Tensor* out, const std::vector<int64_t>& axes,
                  FFTNormMode normalization, bool forward) {
    // Step1: R2C transform on the last dimension
    framework::Tensor* r2c_out = out;
    const std::vector<int64_t> last_dim{axes.back()};
    std::vector<int64_t> out_dims = phi::vectorize(out->dims());
    exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, X, r2c_out, last_dim,
                                                  forward);

    // Step2: C2C transform on the remaining dimension
    framework::Tensor c2c_out;
    if (axes.size() > 1) {
      c2c_out.mutable_data<To>(out->dims(), ctx.GetPlace());
      std::vector<int64_t> remain_dim(axes.begin(), axes.end() - 1);
      FFTC2CFunctor<platform::CUDADeviceContext, To, To> fft_c2c_func;
      fft_c2c_func(ctx, r2c_out, &c2c_out, remain_dim, FFTNormMode::none,
                   forward);
    }

    const auto in_sizes = phi::vectorize(X->dims());
    framework::Tensor* norm_tensor = axes.size() > 1 ? &c2c_out : r2c_out;
    exec_normalization<platform::CUDADeviceContext, To>(
        ctx, norm_tensor, out, normalization, in_sizes, axes);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    fft_c2c, ops::FFTC2CKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTC2CKernel<paddle::platform::CUDADeviceContext, double>);

REGISTER_OP_CUDA_KERNEL(
    fft_c2c_grad,
    ops::FFTC2CGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTC2CGradKernel<paddle::platform::CUDADeviceContext, double>);

REGISTER_OP_CUDA_KERNEL(
    fft_c2r, ops::FFTC2RKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTC2RKernel<paddle::platform::CUDADeviceContext, double>);

REGISTER_OP_CUDA_KERNEL(
    fft_c2r_grad,
    ops::FFTC2RGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTC2RGradKernel<paddle::platform::CUDADeviceContext, double>);

REGISTER_OP_CUDA_KERNEL(
    fft_r2c, ops::FFTR2CKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTR2CKernel<paddle::platform::CUDADeviceContext, double>);

REGISTER_OP_CUDA_KERNEL(
    fft_r2c_grad,
    ops::FFTR2CGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FFTR2CGradKernel<paddle::platform::CUDADeviceContext, double>);