Unverified commit 5fa44c34, authored by chentianyu03 and committed by GitHub

modify Ops to complex template (#33041)

* modify conj, real, imag Ops to use the complex template

* switch the dot Op to the complex template

* switch the Abs Op to the complex template

* add support for complex64 and complex128
Parent 86ea8dce
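The whole change follows one pattern: the two concrete types complex64 and complex128 are replaced by a single class template, paddle::platform::complex<T>, so complex-aware functors and kernels can be written once over complex<T> and instantiated for float and double. The sketch below only illustrates that idea; it uses std::complex and plain host code rather than paddle::platform::complex and HOSTDEVICE, so it is not the code merged in this commit.

#include <complex>
#include <cstdint>

// Primary template for real-valued T (definition omitted in this sketch).
template <typename T>
struct AbsGradFunctor;

// A single partial specialization over complex<T> can cover both precisions;
// the actual header in the diff below still spells out complex<float> and
// complex<double> explicitly.
template <typename T>
struct AbsGradFunctor<std::complex<T>> {
  AbsGradFunctor(const T* dout, const std::complex<T>* x,
                 std::complex<T>* output, int64_t numel)
      : dout_(dout), x_(x), output_(output), numel_(numel) {}

  void operator()(int64_t idx) const {
    if (x_[idx] == std::complex<T>(0)) {
      output_[idx] = std::complex<T>(0);
    } else {
      // gradient of |x| for complex x: dout * x / |x|
      output_[idx] = std::complex<T>(dout_[idx]) *
                     (x_[idx] / std::complex<T>(std::abs(x_[idx])));
    }
  }

  const T* dout_;
  const std::complex<T>* x_;
  std::complex<T>* output_;
  int64_t numel_;
};

int main() {
  const float dout[] = {1.0f};
  const std::complex<float> x[] = {{3.0f, 4.0f}};  // |x| = 5
  std::complex<float> out[1];
  AbsGradFunctor<std::complex<float>> grad(dout, x, out, 1);  // the complex64 case
  grad(0);  // out[0] == (0.6f, 0.8f)
  // AbsGradFunctor<std::complex<double>> covers the complex128 case the same way.
  return 0;
}

The kernel registrations in the diff then simply list the two instantiations, complex<float> and complex<double>, where they previously listed complex64 and complex128.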
@@ -164,9 +164,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::AbsKernel<paddle::platform::CPUDeviceContext, int>,
     ops::AbsKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::AbsKernel<paddle::platform::CPUDeviceContext,
-                   paddle::platform::complex64>,
+                   paddle::platform::complex<float>>,
     ops::AbsKernel<paddle::platform::CPUDeviceContext,
-                   paddle::platform::complex128>);
+                   paddle::platform::complex<double>>);
 
 REGISTER_OP_CPU_KERNEL(
     abs_grad, ops::AbsGradKernel<paddle::platform::CPUDeviceContext, float>,
@@ -174,9 +174,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::AbsGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::AbsGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::AbsGradKernel<paddle::platform::CPUDeviceContext,
-                       paddle::platform::complex64>,
+                       paddle::platform::complex<float>>,
     ops::AbsGradKernel<paddle::platform::CPUDeviceContext,
-                       paddle::platform::complex128>);
+                       paddle::platform::complex<double>>);
 
 REGISTER_OP_CPU_KERNEL(
     abs_grad_grad,
@@ -187,6 +187,6 @@ REGISTER_OP_CPU_KERNEL(
     ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext,
                              paddle::platform::float16>,
     ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                             paddle::platform::complex64>,
+                             paddle::platform::complex<float>>,
     ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                             paddle::platform::complex128>);
+                             paddle::platform::complex<double>>);

@@ -70,8 +70,8 @@ REGISTER_OP_CUDA_KERNEL(
     ops::AbsKernel<plat::CUDADeviceContext, int>,
     ops::AbsKernel<plat::CUDADeviceContext, int64_t>,
     ops::AbsKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::AbsKernel<plat::CUDADeviceContext, plat::complex64>,
-    ops::AbsKernel<plat::CUDADeviceContext, plat::complex128>);
+    ops::AbsKernel<plat::CUDADeviceContext, plat::complex<float>>,
+    ops::AbsKernel<plat::CUDADeviceContext, plat::complex<double>>);
 
 REGISTER_OP_CUDA_KERNEL(
     abs_grad, ops::AbsGradKernel<plat::CUDADeviceContext, float>,
@@ -79,8 +79,8 @@ REGISTER_OP_CUDA_KERNEL(
     ops::AbsGradKernel<plat::CUDADeviceContext, int>,
     ops::AbsGradKernel<plat::CUDADeviceContext, int64_t>,
    ops::AbsGradKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex64>,
-    ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex128>);
+    ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex<float>>,
+    ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex<double>>);
 
 REGISTER_OP_CUDA_KERNEL(
     abs_grad_grad, ops::AbsDoubleGradKernel<plat::CUDADeviceContext, float>,
@@ -88,5 +88,5 @@ REGISTER_OP_CUDA_KERNEL(
     ops::AbsDoubleGradKernel<plat::CUDADeviceContext, int>,
     ops::AbsDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
     ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex64>,
-    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex128>);
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex<float>>,
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex<double>>);

@@ -78,9 +78,9 @@ REGISTER_OPERATOR(conj, ops::ConjOp, ops::ConjOpMaker,
 
 REGISTER_OP_CPU_KERNEL(
     conj, ops::ConjKernel<paddle::platform::CPUDeviceContext,
-                          paddle::platform::complex64>,
+                          paddle::platform::complex<float>>,
     ops::ConjKernel<paddle::platform::CPUDeviceContext,
-                    paddle::platform::complex128>,
+                    paddle::platform::complex<double>>,
     ops::ConjKernel<paddle::platform::CPUDeviceContext, float>,
     ops::ConjKernel<paddle::platform::CPUDeviceContext, double>,
     ops::ConjKernel<paddle::platform::CPUDeviceContext, int>,
...

@@ -13,15 +13,14 @@
 // limitations under the License.
 #include "paddle/fluid/operators/conj_op.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 
 namespace ops = paddle::operators;
 
 REGISTER_OP_CUDA_KERNEL(
     conj, ops::ConjKernel<paddle::platform::CUDADeviceContext,
-                          paddle::platform::complex64>,
+                          paddle::platform::complex<float>>,
     ops::ConjKernel<paddle::platform::CUDADeviceContext,
-                    paddle::platform::complex128>,
+                    paddle::platform::complex<double>>,
     ops::ConjKernel<paddle::platform::CUDADeviceContext, float>,
     ops::ConjKernel<paddle::platform::CUDADeviceContext, double>,
     ops::ConjKernel<paddle::platform::CUDADeviceContext, int>,
...

@@ -33,7 +33,7 @@ class DotOp : public framework::OperatorWithKernel {
                           "Output(Out) of DotOp should not be null."));
 
     auto x_dims = ctx->GetInputDim("X");
-    auto x_rank = (size_t)x_dims.size();
+    auto x_rank = static_cast<size_t>(x_dims.size());
     PADDLE_ENFORCE_EQ(true, 1 == x_rank || 2 == x_rank,
                       platform::errors::PreconditionNotMet(
                           "ShapeError: The dimensions of input tensor X (%s) "
@@ -154,15 +154,15 @@ REGISTER_OP_CPU_KERNEL(
     ops::DotKernel<paddle::platform::CPUDeviceContext, int>,
     ops::DotKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::DotKernel<paddle::platform::CPUDeviceContext,
-                   paddle::platform::complex64>,
+                   paddle::platform::complex<float>>,
     ops::DotKernel<paddle::platform::CPUDeviceContext,
-                   paddle::platform::complex128>);
+                   paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     dot_grad, ops::DotGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::DotGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::DotGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::DotGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::DotGradKernel<paddle::platform::CPUDeviceContext,
-                       paddle::platform::complex64>,
+                       paddle::platform::complex<float>>,
     ops::DotGradKernel<paddle::platform::CPUDeviceContext,
-                       paddle::platform::complex128>);
+                       paddle::platform::complex<double>>);

@@ -22,12 +22,14 @@ REGISTER_OP_CUDA_KERNEL(
     ops::DotKernel<plat::CUDADeviceContext, double>,
     ops::DotKernel<plat::CUDADeviceContext, int>,
     ops::DotKernel<plat::CUDADeviceContext, int64_t>,
-    ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex64>,
-    ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex128>);
-REGISTER_OP_CUDA_KERNEL(
-    dot_grad, ops::DotGradKernel<plat::CUDADeviceContext, float>,
-    ops::DotGradKernel<plat::CUDADeviceContext, double>,
-    ops::DotGradKernel<plat::CUDADeviceContext, int>,
-    ops::DotGradKernel<plat::CUDADeviceContext, int64_t>,
-    ops::DotGradKernel<plat::CUDADeviceContext, paddle::platform::complex64>,
-    ops::DotGradKernel<plat::CUDADeviceContext, paddle::platform::complex128>);
+    ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex<float>>,
+    ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex<double>>);
+REGISTER_OP_CUDA_KERNEL(dot_grad,
+                        ops::DotGradKernel<plat::CUDADeviceContext, float>,
+                        ops::DotGradKernel<plat::CUDADeviceContext, double>,
+                        ops::DotGradKernel<plat::CUDADeviceContext, int>,
+                        ops::DotGradKernel<plat::CUDADeviceContext, int64_t>,
+                        ops::DotGradKernel<plat::CUDADeviceContext,
+                                           paddle::platform::complex<float>>,
+                        ops::DotGradKernel<plat::CUDADeviceContext,
+                                           paddle::platform::complex<double>>);

@@ -96,11 +96,11 @@ REGISTER_OPERATOR(imag, ops::ImagOp, ops::ImagOpMaker,
 REGISTER_OPERATOR(imag_grad, ops::ImagGradOp);
 
 REGISTER_OP_CPU_KERNEL(imag, ops::ImagKernel<paddle::platform::CPUDeviceContext,
-                                             paddle::platform::complex64>,
+                                             paddle::platform::complex<float>>,
                        ops::ImagKernel<paddle::platform::CPUDeviceContext,
-                                       paddle::platform::complex128>);
+                                       paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(imag_grad,
                        ops::ImagGradKernel<paddle::platform::CPUDeviceContext,
-                                           paddle::platform::complex64>,
+                                           paddle::platform::complex<float>>,
                        ops::ImagGradKernel<paddle::platform::CPUDeviceContext,
-                                           paddle::platform::complex128>);
+                                           paddle::platform::complex<double>>);

@@ -18,11 +18,11 @@ namespace ops = paddle::operators;
 
 REGISTER_OP_CUDA_KERNEL(imag,
                         ops::ImagKernel<paddle::platform::CUDADeviceContext,
-                                        paddle::platform::complex64>,
+                                        paddle::platform::complex<float>>,
                         ops::ImagKernel<paddle::platform::CUDADeviceContext,
-                                        paddle::platform::complex128>);
+                                        paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(imag_grad,
                         ops::ImagGradKernel<paddle::platform::CUDADeviceContext,
-                                            paddle::platform::complex64>,
+                                            paddle::platform::complex<float>>,
                         ops::ImagGradKernel<paddle::platform::CUDADeviceContext,
-                                            paddle::platform::complex128>);
+                                            paddle::platform::complex<double>>);

@@ -16,8 +16,7 @@ limitations under the License. */
 
 #include <type_traits>
 
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/hostdevice.h"
 
 namespace paddle {
@@ -66,7 +65,10 @@ using select_t = typename select<Head, Tail...>::type;
 template <typename T>
 using Real =
     select_t<cond<std::is_same<T, platform::complex64>::value, float>,
-             cond<std::is_same<T, platform::complex128>::value, double>, T>;
+             cond<std::is_same<T, platform::complex128>::value, double>,
+             cond<std::is_same<T, platform::complex<float>>::value, float>,
+             cond<std::is_same<T, platform::complex<double>>::value, double>,
+             T>;
 
 template <typename T, typename RealT>
 using Complex = typename std::enable_if<!std::is_same<T, RealT>::value>::type;
@@ -76,14 +78,18 @@ template <typename T, typename RealT>
 using NoComplex = typename std::enable_if<std::is_same<T, RealT>::value>::type;
 
 template <typename T>
-using EnableComplex =
-    typename std::enable_if<std::is_same<T, platform::complex64>::value ||
-                            std::is_same<T, platform::complex128>::value>::type;
+using EnableComplex = typename std::enable_if<
+    std::is_same<T, platform::complex64>::value ||
+    std::is_same<T, platform::complex128>::value ||
+    std::is_same<T, platform::complex<float>>::value ||
+    std::is_same<T, platform::complex<double>>::value>::type;
 
 template <typename T>
 using DisableComplex = typename std::enable_if<
     !std::is_same<T, platform::complex64>::value &&
-    !std::is_same<T, platform::complex128>::value>::type;
+    !std::is_same<T, platform::complex128>::value &&
+    !std::is_same<T, platform::complex<float>>::value &&
+    !std::is_same<T, platform::complex<double>>::value>::type;
 
 template <typename T, typename Enable = void>
 struct RealFunctor;
@@ -173,44 +179,45 @@ struct AbsGradFunctor {
 };
 
 template <>
-struct AbsGradFunctor<paddle::platform::complex64> {
-  AbsGradFunctor(const float* dout, const paddle::platform::complex64* x,
-                 paddle::platform::complex64* output, int64_t numel)
+struct AbsGradFunctor<paddle::platform::complex<float>> {
+  AbsGradFunctor(const float* dout, const paddle::platform::complex<float>* x,
+                 paddle::platform::complex<float>* output, int64_t numel)
       : dout_(dout), x_(x), output_(output), numel_(numel) {}
 
   HOSTDEVICE void operator()(int64_t idx) const {
-    if (x_[idx] == paddle::platform::complex64(0)) {
-      output_[idx] = paddle::platform::complex64(0);
+    if (x_[idx] == paddle::platform::complex<float>(0)) {
+      output_[idx] = paddle::platform::complex<float>(0);
     } else {
-      output_[idx] = paddle::platform::complex64(dout_[idx]) *
-                     (x_[idx] / paddle::platform::complex64(abs(x_[idx])));
+      output_[idx] = paddle::platform::complex<float>(dout_[idx]) *
+                     (x_[idx] / paddle::platform::complex<float>(abs(x_[idx])));
     }
   }
 
   const float* dout_;
-  const paddle::platform::complex64* x_;
-  paddle::platform::complex64* output_;
+  const paddle::platform::complex<float>* x_;
+  paddle::platform::complex<float>* output_;
   int64_t numel_;
 };
 
 template <>
-struct AbsGradFunctor<paddle::platform::complex128> {
-  AbsGradFunctor(const double* dout, const paddle::platform::complex128* x,
-                 paddle::platform::complex128* output, int64_t numel)
+struct AbsGradFunctor<paddle::platform::complex<double>> {
+  AbsGradFunctor(const double* dout, const paddle::platform::complex<double>* x,
+                 paddle::platform::complex<double>* output, int64_t numel)
      : dout_(dout), x_(x), output_(output), numel_(numel) {}
 
   HOSTDEVICE void operator()(int64_t idx) const {
-    if (x_[idx] == paddle::platform::complex128(0)) {
-      output_[idx] = paddle::platform::complex128(0);
+    if (x_[idx] == paddle::platform::complex<double>(0)) {
+      output_[idx] = paddle::platform::complex<double>(0);
     } else {
-      output_[idx] = paddle::platform::complex128(dout_[idx]) *
-                     (x_[idx] / paddle::platform::complex128(abs(x_[idx])));
+      output_[idx] =
+          paddle::platform::complex<double>(dout_[idx]) *
+          (x_[idx] / paddle::platform::complex<double>(abs(x_[idx])));
     }
   }
 
   const double* dout_;
-  const paddle::platform::complex128* x_;
-  paddle::platform::complex128* output_;
+  const paddle::platform::complex<double>* x_;
+  paddle::platform::complex<double>* output_;
   int64_t numel_;
 };
@@ -234,46 +241,46 @@ struct AbsGradGradFunctor {
 };
 
 template <>
-struct AbsGradGradFunctor<paddle::platform::complex128> {
-  AbsGradGradFunctor(const paddle::platform::complex128* ddx,
-                     const paddle::platform::complex128* x,
-                     paddle::platform::complex128* output, int64_t numel)
+struct AbsGradGradFunctor<paddle::platform::complex<double>> {
+  AbsGradGradFunctor(const paddle::platform::complex<double>* ddx,
+                     const paddle::platform::complex<double>* x,
+                     paddle::platform::complex<double>* output, int64_t numel)
       : ddx_(ddx), x_(x), output_(output), numel_(numel) {}
 
   HOSTDEVICE void operator()(int64_t idx) const {
-    if (x_[idx] == paddle::platform::complex128(0)) {
-      output_[idx] = paddle::platform::complex128(0);
+    if (x_[idx] == paddle::platform::complex<double>(0)) {
+      output_[idx] = paddle::platform::complex<double>(0);
     } else {
-      output_[idx] = paddle::platform::complex128(ddx_[idx]) * x_[idx] /
-                     paddle::platform::complex128(abs(x_[idx]));
+      output_[idx] = paddle::platform::complex<double>(ddx_[idx]) * x_[idx] /
+                     paddle::platform::complex<double>(abs(x_[idx]));
     }
   }
 
-  const paddle::platform::complex128* ddx_;
-  const paddle::platform::complex128* x_;
-  paddle::platform::complex128* output_;
+  const paddle::platform::complex<double>* ddx_;
+  const paddle::platform::complex<double>* x_;
+  paddle::platform::complex<double>* output_;
   int64_t numel_;
 };
 
 template <>
-struct AbsGradGradFunctor<paddle::platform::complex64> {
-  AbsGradGradFunctor(const paddle::platform::complex64* ddx,
-                     const paddle::platform::complex64* x,
-                     paddle::platform::complex64* output, int64_t numel)
+struct AbsGradGradFunctor<paddle::platform::complex<float>> {
+  AbsGradGradFunctor(const paddle::platform::complex<float>* ddx,
+                     const paddle::platform::complex<float>* x,
+                     paddle::platform::complex<float>* output, int64_t numel)
      : ddx_(ddx), x_(x), output_(output), numel_(numel) {}
 
   HOSTDEVICE void operator()(int64_t idx) const {
-    if (x_[idx] == paddle::platform::complex64(0)) {
-      output_[idx] = paddle::platform::complex64(0);
+    if (x_[idx] == paddle::platform::complex<float>(0)) {
+      output_[idx] = paddle::platform::complex<float>(0);
    } else {
-      output_[idx] = paddle::platform::complex64(ddx_[idx]) * x_[idx] /
-                     paddle::platform::complex64(abs(x_[idx]));
+      output_[idx] = paddle::platform::complex<float>(ddx_[idx]) * x_[idx] /
+                     paddle::platform::complex<float>(abs(x_[idx]));
     }
   }
 
-  const paddle::platform::complex64* ddx_;
-  const paddle::platform::complex64* x_;
-  paddle::platform::complex64* output_;
+  const paddle::platform::complex<float>* ddx_;
+  const paddle::platform::complex<float>* x_;
+  paddle::platform::complex<float>* output_;
   int64_t numel_;
 };
 
 template <typename T, typename Enable = void>
...

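For reference, the Real<T> alias extended above is built from a small cond/select_t dispatch. The following self-contained sketch reproduces that pattern with std::complex standing in for paddle::platform::complex; it mirrors the mechanism rather than Paddle's exact trait code.

#include <complex>
#include <type_traits>

// cond<B, T>: a compile-time condition paired with the type chosen if it holds.
template <bool B, typename T>
struct cond {
  static constexpr bool value = B;
  using type = T;
};

// select yields the first cond whose condition is true; the list must end
// with a bare type that acts as the default.
template <typename Head, typename... Tail>
struct select {
  using type =
      typename std::conditional<Head::value, Head, select<Tail...>>::type::type;
};

template <typename T>
struct select<T> {  // only the default remains
  using type = T;
};

template <typename Head, typename... Tail>
using select_t = typename select<Head, Tail...>::type;

// Real<T>: complex<float> -> float, complex<double> -> double, otherwise T.
template <typename T>
using Real =
    select_t<cond<std::is_same<T, std::complex<float>>::value, float>,
             cond<std::is_same<T, std::complex<double>>::value, double>, T>;

static_assert(std::is_same<Real<std::complex<float>>, float>::value, "");
static_assert(std::is_same<Real<std::complex<double>>, double>::value, "");
static_assert(std::is_same<Real<int>, int>::value, "");

int main() { return 0; }

EnableComplex<T> and DisableComplex<T> in the same header are plain std::enable_if gates over the same list of types, which is why the diff only has to add the two new complex<float>/complex<double> cases to each of them.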
@@ -95,11 +95,11 @@ REGISTER_OPERATOR(real, ops::RealOp, ops::RealOpMaker,
 REGISTER_OPERATOR(real_grad, ops::RealGradOp);
 
 REGISTER_OP_CPU_KERNEL(real, ops::RealKernel<paddle::platform::CPUDeviceContext,
-                                             paddle::platform::complex64>,
+                                             paddle::platform::complex<float>>,
                        ops::RealKernel<paddle::platform::CPUDeviceContext,
-                                       paddle::platform::complex128>);
+                                       paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(real_grad,
                        ops::RealGradKernel<paddle::platform::CPUDeviceContext,
-                                           paddle::platform::complex64>,
+                                           paddle::platform::complex<float>>,
                        ops::RealGradKernel<paddle::platform::CPUDeviceContext,
-                                           paddle::platform::complex128>);
+                                           paddle::platform::complex<double>>);

@@ -18,11 +18,11 @@ namespace ops = paddle::operators;
 
 REGISTER_OP_CUDA_KERNEL(real,
                         ops::RealKernel<paddle::platform::CUDADeviceContext,
-                                        paddle::platform::complex64>,
+                                        paddle::platform::complex<float>>,
                         ops::RealKernel<paddle::platform::CUDADeviceContext,
-                                        paddle::platform::complex128>);
+                                        paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(real_grad,
                         ops::RealGradKernel<paddle::platform::CUDADeviceContext,
-                                            paddle::platform::complex64>,
+                                            paddle::platform::complex<float>>,
                         ops::RealGradKernel<paddle::platform::CUDADeviceContext,
-                                            paddle::platform::complex128>);
+                                            paddle::platform::complex<double>>);