complex_functors.h 12.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
16 17 18 19
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif
#include <cmath>
20 21
#include <type_traits>

22
#include "paddle/phi/common/complex.h"
23
#include "paddle/phi/common/type_traits.h"
24
#include "paddle/phi/core/hostdevice.h"
25

26
namespace phi {
27
namespace funcs {
28 29 30 31 32 33 34 35

// SFINAE helper: names `void` only when T and RealT are different types.
// When they are the same type the alias is ill-formed, so a partial
// specialization that uses it as its `Enable` argument is discarded.
template <typename T, typename RealT>
using Complex =
    typename std::enable_if<!std::is_same<T, RealT>::value, void>::type;

// SFINAE helper: names `void` only when T and RealT are the same type.
// It selects the real-valued partial specializations in this file
// (for example those of AbsFunctor and AngleFunctor).
template <typename T, typename RealT>
using NoComplex =
    typename std::enable_if<std::is_same<T, RealT>::value, void>::type;

36
template <typename T>
37
using EnableComplex = typename std::enable_if<
38 39
    std::is_same<T, phi::dtype::complex<float>>::value ||
    std::is_same<T, phi::dtype::complex<double>>::value>::type;
40 41 42

template <typename T>
using DisableComplex = typename std::enable_if<
43 44
    !std::is_same<T, phi::dtype::complex<float>>::value &&
    !std::is_same<T, phi::dtype::complex<double>>::value>::type;
45

46 47 48 49
template <typename T, typename Enable = void>
struct RealFunctor;

template <typename T>
50
struct RealFunctor<T, Complex<T, dtype::Real<T>>> {
51
 public:
52
  RealFunctor(const T* input, dtype::Real<T>* output, int64_t numel)
53 54 55 56 57 58 59 60
      : input_(input), output_(output), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    output_[idx] = input_[idx].real;
  }

 private:
  const T* input_;
61
  dtype::Real<T>* output_;
62 63 64 65 66 67 68
  int64_t numel_;
};

template <typename T, typename Enable = void>
struct ImagFunctor;

template <typename T>
69 70
struct ImagFunctor<T, Complex<T, dtype::Real<T>>> {
  ImagFunctor(const T* input, dtype::Real<T>* output, int64_t numel)
71 72 73 74 75 76 77
      : input_(input), output_(output), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    output_[idx] = input_[idx].imag;
  }

  const T* input_;
78
  dtype::Real<T>* output_;
79 80 81
  int64_t numel_;
};

82 83 84 85
template <typename T, typename Enable = void>
struct AbsFunctor;

template <typename T>
86 87
struct AbsFunctor<T, Complex<T, dtype::Real<T>>> {
  AbsFunctor(const T* input, dtype::Real<T>* output, int64_t numel)
88 89 90 91 92 93 94
      : input_(input), output_(output), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    output_[idx] = abs(input_[idx]);
  }

  const T* input_;
95
  dtype::Real<T>* output_;
96 97 98 99
  int64_t numel_;
};

template <typename T>
100
struct AbsFunctor<T, NoComplex<T, dtype::Real<T>>> {
101 102 103 104
  AbsFunctor(const T* input, T* output, int64_t numel)
      : input_(input), output_(output), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
105
    output_[idx] = std::abs(input_[idx]);
106 107 108 109 110 111 112 113 114
  }

  const T* input_;
  T* output_;
  int64_t numel_;
};

template <typename T>
struct AbsGradFunctor {
115 116 117 118
  AbsGradFunctor(const dtype::Real<T>* dout,
                 const T* x,
                 T* output,
                 int64_t numel)
119 120 121 122 123 124
      : dout_(dout), x_(x), output_(output), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    if (x_[idx] == T(0)) {
      output_[idx] = T(0);
    } else {
125
      output_[idx] = T(dout_[idx]) * (x_[idx] / T(std::abs(x_[idx])));
126 127 128
    }
  }

129
  const dtype::Real<T>* dout_;
130 131 132 133 134
  const T* x_;
  T* output_;
  int64_t numel_;
};

135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
// Explicit specialization of AbsGradFunctor for phi::dtype::bfloat16.
// Same formula as the generic version, but `abs` is called unqualified
// instead of std::abs.
template <>
struct AbsGradFunctor<phi::dtype::bfloat16> {
  // numel is stored but not read by operator().
  AbsGradFunctor(const dtype::Real<phi::dtype::bfloat16>* dout,
                 const phi::dtype::bfloat16* x,
                 phi::dtype::bfloat16* output,
                 int64_t numel)
      : dout_(dout), x_(x), output_(output), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    // Zero input yields a zero gradient and skips the division below.
    if (x_[idx] == static_cast<phi::dtype::bfloat16>(0)) {
      output_[idx] = static_cast<phi::dtype::bfloat16>(0);
      return;
    }
    output_[idx] = dout_[idx] * (x_[idx] / (abs(x_[idx])));
  }

  const dtype::Real<phi::dtype::bfloat16>* dout_;
  const phi::dtype::bfloat16* x_;
  phi::dtype::bfloat16* output_;
  int64_t numel_;
};

157
template <>
158
struct AbsGradFunctor<phi::dtype::complex<float>> {
159
  AbsGradFunctor(const float* dout,
160 161
                 const phi::dtype::complex<float>* x,
                 phi::dtype::complex<float>* output,
162
                 int64_t numel)
163 164 165
      : dout_(dout), x_(x), output_(output), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
166 167
    if (x_[idx] == phi::dtype::complex<float>(0)) {
      output_[idx] = phi::dtype::complex<float>(0);
168
    } else {
169 170
      output_[idx] = phi::dtype::complex<float>(dout_[idx]) *
                     (x_[idx] / phi::dtype::complex<float>(abs(x_[idx])));
171 172 173 174
    }
  }

  const float* dout_;
175 176
  const phi::dtype::complex<float>* x_;
  phi::dtype::complex<float>* output_;
177 178 179 180
  int64_t numel_;
};

template <>
181
struct AbsGradFunctor<phi::dtype::complex<double>> {
182
  AbsGradFunctor(const double* dout,
183 184
                 const phi::dtype::complex<double>* x,
                 phi::dtype::complex<double>* output,
185
                 int64_t numel)
186 187 188
      : dout_(dout), x_(x), output_(output), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
189 190
    if (x_[idx] == phi::dtype::complex<double>(0)) {
      output_[idx] = phi::dtype::complex<double>(0);
191
    } else {
192 193
      output_[idx] = phi::dtype::complex<double>(dout_[idx]) *
                     (x_[idx] / phi::dtype::complex<double>(abs(x_[idx])));
194 195 196 197
    }
  }

  const double* dout_;
198 199
  const phi::dtype::complex<double>* x_;
  phi::dtype::complex<double>* output_;
200 201 202
  int64_t numel_;
};

203 204 205 206 207 208 209 210 211
// Generic per-index functor for the second-order gradient of abs:
//   output[idx] = 0                          if x[idx] == 0
//   output[idx] = (ddx[idx] * x) / |x|       otherwise
// Explicit specializations exist below for the complex types.
template <typename T>
struct AbsGradGradFunctor {
  // numel is stored but not read by operator().
  AbsGradGradFunctor(const T* ddx, const T* x, T* output, int64_t numel)
      : ddx_(ddx), x_(x), output_(output), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    const T& x = x_[idx];
    // Zero input is handled separately so the division is never by zero.
    if (x == T(0)) {
      output_[idx] = T(0);
      return;
    }
    // Multiplication happens before the division (left-to-right), as in
    // `ddx * x / |x|`.
    const T magnitude = T(std::abs(x));
    output_[idx] = T(ddx_[idx]) * x / magnitude;
  }

  const T* ddx_;
  const T* x_;
  T* output_;
  int64_t numel_;
};

222
template <>
223 224 225 226
struct AbsGradGradFunctor<phi::dtype::complex<double>> {
  AbsGradGradFunctor(const phi::dtype::complex<double>* ddx,
                     const phi::dtype::complex<double>* x,
                     phi::dtype::complex<double>* output,
227
                     int64_t numel)
228 229 230
      : ddx_(ddx), x_(x), output_(output), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
231 232
    if (x_[idx] == phi::dtype::complex<double>(0)) {
      output_[idx] = phi::dtype::complex<double>(0);
233
    } else {
234 235
      output_[idx] = phi::dtype::complex<double>(ddx_[idx]) * x_[idx] /
                     phi::dtype::complex<double>(abs(x_[idx]));
236 237 238
    }
  }

239 240 241
  const phi::dtype::complex<double>* ddx_;
  const phi::dtype::complex<double>* x_;
  phi::dtype::complex<double>* output_;
242 243 244 245
  int64_t numel_;
};

template <>
246 247 248 249
struct AbsGradGradFunctor<phi::dtype::complex<float>> {
  AbsGradGradFunctor(const phi::dtype::complex<float>* ddx,
                     const phi::dtype::complex<float>* x,
                     phi::dtype::complex<float>* output,
250
                     int64_t numel)
251 252 253
      : ddx_(ddx), x_(x), output_(output), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
254 255
    if (x_[idx] == phi::dtype::complex<float>(0)) {
      output_[idx] = phi::dtype::complex<float>(0);
256
    } else {
257 258
      output_[idx] = phi::dtype::complex<float>(ddx_[idx]) * x_[idx] /
                     phi::dtype::complex<float>(abs(x_[idx]));
259 260 261
    }
  }

262 263 264
  const phi::dtype::complex<float>* ddx_;
  const phi::dtype::complex<float>* x_;
  phi::dtype::complex<float>* output_;
265 266
  int64_t numel_;
};
267 268 269 270
template <typename T, typename Enable = void>
struct RealToComplexFunctor;

template <typename T>
271 272
struct RealToComplexFunctor<T, Complex<T, dtype::Real<T>>> {
  RealToComplexFunctor(const dtype::Real<T>* input, T* output, int64_t numel)
273 274 275 276 277 278 279
      : input_(input), output_(output), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    output_[idx].real = input_[idx];
    output_[idx].imag = 0;
  }

280
  const dtype::Real<T>* input_;
281 282 283 284 285 286 287 288
  T* output_;
  int64_t numel_;
};

template <typename T, typename Enable = void>
struct ImagToComplexFunctor;

template <typename T>
289 290
struct ImagToComplexFunctor<T, Complex<T, dtype::Real<T>>> {
  ImagToComplexFunctor(const dtype::Real<T>* input, T* output, int64_t numel)
291 292 293 294 295 296 297
      : input_(input), output_(output), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    output_[idx].real = 0;
    output_[idx].imag = input_[idx];
  }

298
  const dtype::Real<T>* input_;
299 300 301 302
  T* output_;
  int64_t numel_;
};

303 304 305 306
template <typename T, typename Enable = void>
struct RealImagToComplexFunctor;

template <typename T>
307 308 309
struct RealImagToComplexFunctor<T, Complex<T, dtype::Real<T>>> {
  RealImagToComplexFunctor(const dtype::Real<T>* input_real,
                           const dtype::Real<T>* input_imag,
310 311
                           T* output,
                           int64_t numel)
312 313 314 315 316 317 318 319 320 321
      : input_real_(input_real),
        input_imag_(input_imag),
        output_(output),
        numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    output_[idx].real = input_real_[idx];
    output_[idx].imag = input_imag_[idx];
  }

322 323
  const dtype::Real<T>* input_real_;
  const dtype::Real<T>* input_imag_;
324 325 326 327
  T* output_;
  int64_t numel_;
};

C
chentianyu03 已提交
328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354
// Primary template: declared only. Implementations are provided by partial
// specializations.
template <typename T, typename Enable = void>
struct ConjFunctor;

// Specialization for phi::dtype::complex<float> / complex<double>:
// output[idx] is constructed from (real, -imag) of input[idx].
template <typename T>
struct ConjFunctor<T, EnableComplex<T>> {
  // Note the argument order: (input, numel, output). numel is stored but not
  // read by operator().
  ConjFunctor(const T* input, int64_t numel, T* output)
      : input_(input), numel_(numel), output_(output) {}

  // The index parameter is size_t here, unlike the int64_t used by most
  // other functors in this file.
  HOSTDEVICE void operator()(size_t idx) const {
    const T& src = input_[idx];
    output_[idx] = T(src.real, -src.imag);
  }

  const T* input_;
  int64_t numel_;
  T* output_;
};

// Specialization for every T other than phi::dtype::complex<float> /
// complex<double>: the element is copied unchanged.
template <typename T>
struct ConjFunctor<T, DisableComplex<T>> {
  // Note the argument order: (input, numel, output). numel is stored but not
  // read by operator().
  ConjFunctor(const T* input, int64_t numel, T* output)
      : input_(input), numel_(numel), output_(output) {}

  // Plain element copy for one index.
  HOSTDEVICE void operator()(size_t idx) const {
    output_[idx] = input_[idx];
  }

  const T* input_;
  int64_t numel_;
  T* output_;
};

355 356 357 358 359
template <typename T, typename Enable = void>
struct AngleFunctor;

// angel function for complex
template <typename T>
360 361
struct AngleFunctor<T, phi::funcs::Complex<T, dtype::Real<T>>> {
  AngleFunctor(const T* input, dtype::Real<T>* output, int64_t numel)
362 363 364 365 366 367 368
      : input_(input), output_(output), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    output_[idx] = arg(input_[idx]);
  }

  const T* input_;
369
  dtype::Real<T>* output_;
370 371 372 373 374
  int64_t numel_;
};

// angel function for real
template <typename T>
375
struct AngleFunctor<T, phi::funcs::NoComplex<T, dtype::Real<T>>> {
376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392
  AngleFunctor(const T* input, T* output, int64_t numel)
      : input_(input), output_(output), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    output_[idx] = input_[idx] < static_cast<T>(0) ? M_PI : 0;
  }

  const T* input_;
  T* output_;
  int64_t numel_;
};

template <typename T, typename Enable = void>
struct AngleGradFunctor;

// angle grad for complex
template <typename T>
393 394
struct AngleGradFunctor<T, phi::funcs::Complex<T, dtype::Real<T>>> {
  AngleGradFunctor(const dtype::Real<T>* dout, const T* x, T* dx, int64_t numel)
395 396 397 398 399 400
      : dout_(dout), x_(x), dx_(dx), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    if (x_[idx] == T(0)) {
      dx_[idx] = T(0);
    } else {
401
      const phi::dtype::Real<T> r_square =
402 403 404 405 406 407
          x_[idx].real * x_[idx].real + x_[idx].imag * x_[idx].imag;
      dx_[idx] = T(-dout_[idx] * x_[idx].imag / r_square,
                   dout_[idx] * x_[idx].real / r_square);
    }
  }

408
  const phi::dtype::Real<T>* dout_;
409 410 411 412 413 414 415
  const T* x_;
  T* dx_;
  int64_t numel_;
};

// angle grad for real
template <typename T>
416 417
struct AngleGradFunctor<T, phi::funcs::NoComplex<T, dtype::Real<T>>> {
  AngleGradFunctor(const dtype::Real<T>* dout, const T* x, T* dx, int64_t numel)
418 419 420 421
      : dout_(dout), x_(x), dx_(dx), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const { dx_[idx] = 0; }

422
  const dtype::Real<T>* dout_;
423 424 425 426 427 428
  const T* x_;
  T* dx_;
  int64_t numel_;
};

}  // namespace funcs
429
}  // namespace phi