Unverified · commit 5fa44c34 authored by chentianyu03, committed by GitHub

modify Ops to complex template (#33041)

* modify the conj, real, and imag OPs to use the complex template

* switch the dot Op to the complex template

* switch the Abs Op to the complex template

* add support for complex64 and complex128
Parent 86ea8dce
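The change is mechanical throughout the diff: every kernel registration and functor specialization that named the concrete complex64/complex128 types now names the single paddle::platform::complex<T> template, instantiated at float or double. A minimal standalone sketch of why that helps follows; the complex<T> struct and Abs function below are illustrative stand-ins, not Paddle code.

// Sketch: one templated function replaces a complex64 + complex128 pair.
// complex<T> is a hypothetical stand-in for paddle::platform::complex<T>,
// kept just large enough to compile.
#include <cmath>
#include <iostream>

template <typename T>
struct complex {
  T real;
  T imag;
};

// Before the change this body had to be written (or specialized) twice,
// once per concrete type; with the template it is written once.
template <typename T>
T Abs(const complex<T>& z) {
  return std::sqrt(z.real * z.real + z.imag * z.imag);
}

int main() {
  complex<float> zf{3.0f, 4.0f};  // plays the role of complex64
  complex<double> zd{3.0, 4.0};   // plays the role of complex128
  std::cout << Abs(zf) << " " << Abs(zd) << "\n";  // prints: 5 5
}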
@@ -164,9 +164,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::AbsKernel<paddle::platform::CPUDeviceContext, int>,
     ops::AbsKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::AbsKernel<paddle::platform::CPUDeviceContext,
-                   paddle::platform::complex64>,
+                   paddle::platform::complex<float>>,
     ops::AbsKernel<paddle::platform::CPUDeviceContext,
-                   paddle::platform::complex128>);
+                   paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     abs_grad, ops::AbsGradKernel<paddle::platform::CPUDeviceContext, float>,
@@ -174,9 +174,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::AbsGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::AbsGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::AbsGradKernel<paddle::platform::CPUDeviceContext,
-                       paddle::platform::complex64>,
+                       paddle::platform::complex<float>>,
     ops::AbsGradKernel<paddle::platform::CPUDeviceContext,
-                       paddle::platform::complex128>);
+                       paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     abs_grad_grad,
@@ -187,6 +187,6 @@ REGISTER_OP_CPU_KERNEL(
     ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext,
                              paddle::platform::float16>,
     ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                             paddle::platform::complex64>,
+                             paddle::platform::complex<float>>,
     ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                             paddle::platform::complex128>);
+                             paddle::platform::complex<double>>);
@@ -70,8 +70,8 @@ REGISTER_OP_CUDA_KERNEL(
     ops::AbsKernel<plat::CUDADeviceContext, int>,
     ops::AbsKernel<plat::CUDADeviceContext, int64_t>,
     ops::AbsKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::AbsKernel<plat::CUDADeviceContext, plat::complex64>,
-    ops::AbsKernel<plat::CUDADeviceContext, plat::complex128>);
+    ops::AbsKernel<plat::CUDADeviceContext, plat::complex<float>>,
+    ops::AbsKernel<plat::CUDADeviceContext, plat::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     abs_grad, ops::AbsGradKernel<plat::CUDADeviceContext, float>,
@@ -79,8 +79,8 @@ REGISTER_OP_CUDA_KERNEL(
     ops::AbsGradKernel<plat::CUDADeviceContext, int>,
     ops::AbsGradKernel<plat::CUDADeviceContext, int64_t>,
     ops::AbsGradKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex64>,
-    ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex128>);
+    ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex<float>>,
+    ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     abs_grad_grad, ops::AbsDoubleGradKernel<plat::CUDADeviceContext, float>,
@@ -88,5 +88,5 @@ REGISTER_OP_CUDA_KERNEL(
     ops::AbsDoubleGradKernel<plat::CUDADeviceContext, int>,
     ops::AbsDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
     ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex64>,
-    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex128>);
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex<float>>,
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex<double>>);
@@ -78,9 +78,9 @@ REGISTER_OPERATOR(conj, ops::ConjOp, ops::ConjOpMaker,
 REGISTER_OP_CPU_KERNEL(
     conj, ops::ConjKernel<paddle::platform::CPUDeviceContext,
-                          paddle::platform::complex64>,
+                          paddle::platform::complex<float>>,
     ops::ConjKernel<paddle::platform::CPUDeviceContext,
-                    paddle::platform::complex128>,
+                    paddle::platform::complex<double>>,
     ops::ConjKernel<paddle::platform::CPUDeviceContext, float>,
     ops::ConjKernel<paddle::platform::CPUDeviceContext, double>,
     ops::ConjKernel<paddle::platform::CPUDeviceContext, int>,
......
@@ -13,15 +13,14 @@
 // limitations under the License.
 #include "paddle/fluid/operators/conj_op.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     conj, ops::ConjKernel<paddle::platform::CUDADeviceContext,
-                          paddle::platform::complex64>,
+                          paddle::platform::complex<float>>,
     ops::ConjKernel<paddle::platform::CUDADeviceContext,
-                    paddle::platform::complex128>,
+                    paddle::platform::complex<double>>,
     ops::ConjKernel<paddle::platform::CUDADeviceContext, float>,
     ops::ConjKernel<paddle::platform::CUDADeviceContext, double>,
     ops::ConjKernel<paddle::platform::CUDADeviceContext, int>,
......
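A side note on the conj registrations above: unlike real/imag, the kernel is also registered for float, double, and int, where conjugation is the identity. A hedged sketch of that dispatch (complex<T> and Conj are illustrative stand-ins, not Paddle's implementation):

// Sketch: conjugation negates the imaginary part for complex types and is
// a no-op for real types, which is why the real registrations are cheap.
#include <iostream>

template <typename T>
struct complex {
  T real;
  T imag;
};

// Complex overload: negate the imaginary part.
template <typename T>
complex<T> Conj(const complex<T>& z) {
  return {z.real, -z.imag};
}

// Real overload: identity. Partial ordering prefers the complex overload
// whenever the argument is a complex<T>.
template <typename T>
T Conj(T x) {
  return x;
}

int main() {
  complex<float> z{1.0f, 2.0f};
  std::cout << Conj(z).imag << " " << Conj(3) << "\n";  // prints: -2 3
}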
@@ -33,7 +33,7 @@ class DotOp : public framework::OperatorWithKernel {
                           "Output(Out) of DotOp should not be null."));
     auto x_dims = ctx->GetInputDim("X");
-    auto x_rank = (size_t)x_dims.size();
+    auto x_rank = static_cast<size_t>(x_dims.size());
     PADDLE_ENFORCE_EQ(true, 1 == x_rank || 2 == x_rank,
                       platform::errors::PreconditionNotMet(
                           "ShapeError: The dimensions of input tensor X (%s) "
@@ -154,15 +154,15 @@ REGISTER_OP_CPU_KERNEL(
     ops::DotKernel<paddle::platform::CPUDeviceContext, int>,
     ops::DotKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::DotKernel<paddle::platform::CPUDeviceContext,
-                   paddle::platform::complex64>,
+                   paddle::platform::complex<float>>,
     ops::DotKernel<paddle::platform::CPUDeviceContext,
-                   paddle::platform::complex128>);
+                   paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     dot_grad, ops::DotGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::DotGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::DotGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::DotGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::DotGradKernel<paddle::platform::CPUDeviceContext,
-                       paddle::platform::complex64>,
+                       paddle::platform::complex<float>>,
     ops::DotGradKernel<paddle::platform::CPUDeviceContext,
-                       paddle::platform::complex128>);
+                       paddle::platform::complex<double>>);
@@ -22,12 +22,14 @@ REGISTER_OP_CUDA_KERNEL(
     ops::DotKernel<plat::CUDADeviceContext, double>,
     ops::DotKernel<plat::CUDADeviceContext, int>,
     ops::DotKernel<plat::CUDADeviceContext, int64_t>,
-    ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex64>,
-    ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex128>);
-REGISTER_OP_CUDA_KERNEL(
-    dot_grad, ops::DotGradKernel<plat::CUDADeviceContext, float>,
-    ops::DotGradKernel<plat::CUDADeviceContext, double>,
-    ops::DotGradKernel<plat::CUDADeviceContext, int>,
-    ops::DotGradKernel<plat::CUDADeviceContext, int64_t>,
-    ops::DotGradKernel<plat::CUDADeviceContext, paddle::platform::complex64>,
-    ops::DotGradKernel<plat::CUDADeviceContext, paddle::platform::complex128>);
+    ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex<float>>,
+    ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex<double>>);
+REGISTER_OP_CUDA_KERNEL(dot_grad,
+                        ops::DotGradKernel<plat::CUDADeviceContext, float>,
+                        ops::DotGradKernel<plat::CUDADeviceContext, double>,
+                        ops::DotGradKernel<plat::CUDADeviceContext, int>,
+                        ops::DotGradKernel<plat::CUDADeviceContext, int64_t>,
+                        ops::DotGradKernel<plat::CUDADeviceContext,
+                                           paddle::platform::complex<float>>,
+                        ops::DotGradKernel<plat::CUDADeviceContext,
+                                           paddle::platform::complex<double>>);
@@ -96,11 +96,11 @@ REGISTER_OPERATOR(imag, ops::ImagOp, ops::ImagOpMaker,
 REGISTER_OPERATOR(imag_grad, ops::ImagGradOp);
 REGISTER_OP_CPU_KERNEL(imag, ops::ImagKernel<paddle::platform::CPUDeviceContext,
-                                             paddle::platform::complex64>,
+                                             paddle::platform::complex<float>>,
                        ops::ImagKernel<paddle::platform::CPUDeviceContext,
-                                       paddle::platform::complex128>);
+                                       paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(imag_grad,
                        ops::ImagGradKernel<paddle::platform::CPUDeviceContext,
-                                           paddle::platform::complex64>,
+                                           paddle::platform::complex<float>>,
                        ops::ImagGradKernel<paddle::platform::CPUDeviceContext,
-                                           paddle::platform::complex128>);
+                                           paddle::platform::complex<double>>);
@@ -18,11 +18,11 @@ namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(imag,
                         ops::ImagKernel<paddle::platform::CUDADeviceContext,
-                                        paddle::platform::complex64>,
+                                        paddle::platform::complex<float>>,
                         ops::ImagKernel<paddle::platform::CUDADeviceContext,
-                                        paddle::platform::complex128>);
+                                        paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(imag_grad,
                         ops::ImagGradKernel<paddle::platform::CUDADeviceContext,
-                                            paddle::platform::complex64>,
+                                            paddle::platform::complex<float>>,
                         ops::ImagGradKernel<paddle::platform::CUDADeviceContext,
-                                            paddle::platform::complex128>);
+                                            paddle::platform::complex<double>>);
@@ -16,8 +16,7 @@ limitations under the License. */
 #include <type_traits>
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/hostdevice.h"
 namespace paddle {
@@ -66,7 +65,10 @@ using select_t = typename select<Head, Tail...>::type;
 template <typename T>
 using Real =
     select_t<cond<std::is_same<T, platform::complex64>::value, float>,
-             cond<std::is_same<T, platform::complex128>::value, double>, T>;
+             cond<std::is_same<T, platform::complex128>::value, double>,
+             cond<std::is_same<T, platform::complex<float>>::value, float>,
+             cond<std::is_same<T, platform::complex<double>>::value, double>,
+             T>;
 template <typename T, typename RealT>
 using Complex = typename std::enable_if<!std::is_same<T, RealT>::value>::type;
@@ -76,14 +78,18 @@ template <typename T, typename RealT>
 using NoComplex = typename std::enable_if<std::is_same<T, RealT>::value>::type;
 template <typename T>
-using EnableComplex =
-    typename std::enable_if<std::is_same<T, platform::complex64>::value ||
-                            std::is_same<T, platform::complex128>::value>::type;
+using EnableComplex = typename std::enable_if<
+    std::is_same<T, platform::complex64>::value ||
+    std::is_same<T, platform::complex128>::value ||
+    std::is_same<T, platform::complex<float>>::value ||
+    std::is_same<T, platform::complex<double>>::value>::type;
 template <typename T>
 using DisableComplex = typename std::enable_if<
     !std::is_same<T, platform::complex64>::value &&
-    !std::is_same<T, platform::complex128>::value>::type;
+    !std::is_same<T, platform::complex128>::value &&
+    !std::is_same<T, platform::complex<float>>::value &&
+    !std::is_same<T, platform::complex<double>>::value>::type;
 template <typename T, typename Enable = void>
 struct RealFunctor;
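For readers unfamiliar with the cond/select_t machinery behind Real<T> above: select_t yields the type of the first cond whose predicate holds, falling back to the trailing bare type, and EnableComplex/DisableComplex are the matching SFINAE gates. During the transition both the legacy complex64/complex128 names and the new complex<float>/complex<double> instantiations are listed in the traits. A self-contained sketch, with complex<T>, cond, and select mirroring (but not reproducing) the shapes in complex_functors.h:

// Standalone sketch of the cond/select_t dispatch used by Real<T>.
#include <type_traits>

template <typename T>
struct complex {};  // stand-in for paddle::platform::complex<T>

// cond pairs a compile-time predicate with the type picked when it holds.
template <bool B, typename T>
struct cond {
  static constexpr bool value = B;
  using type = T;
};

// select walks the cond list and yields the first match; a bare trailing
// type terminates the recursion and acts as the fallback.
template <typename Head, typename... Tail>
struct select {
  using type = typename std::conditional<Head::value, typename Head::type,
                                         typename select<Tail...>::type>::type;
};
template <typename T>
struct select<T> {  // one element left: the fallback type itself
  using type = T;
};

template <typename Head, typename... Tail>
using select_t = typename select<Head, Tail...>::type;

// Real<T> strips the complex wrapper, as in the patched header.
template <typename T>
using Real = select_t<cond<std::is_same<T, complex<float>>::value, float>,
                      cond<std::is_same<T, complex<double>>::value, double>,
                      T>;

static_assert(std::is_same<Real<complex<float>>, float>::value, "");
static_assert(std::is_same<Real<complex<double>>, double>::value, "");
static_assert(std::is_same<Real<float>, float>::value, "");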
@@ -173,44 +179,45 @@ struct AbsGradFunctor {
 };
 template <>
-struct AbsGradFunctor<paddle::platform::complex64> {
-  AbsGradFunctor(const float* dout, const paddle::platform::complex64* x,
-                 paddle::platform::complex64* output, int64_t numel)
+struct AbsGradFunctor<paddle::platform::complex<float>> {
+  AbsGradFunctor(const float* dout, const paddle::platform::complex<float>* x,
+                 paddle::platform::complex<float>* output, int64_t numel)
       : dout_(dout), x_(x), output_(output), numel_(numel) {}
   HOSTDEVICE void operator()(int64_t idx) const {
-    if (x_[idx] == paddle::platform::complex64(0)) {
-      output_[idx] = paddle::platform::complex64(0);
+    if (x_[idx] == paddle::platform::complex<float>(0)) {
+      output_[idx] = paddle::platform::complex<float>(0);
     } else {
-      output_[idx] = paddle::platform::complex64(dout_[idx]) *
-                     (x_[idx] / paddle::platform::complex64(abs(x_[idx])));
+      output_[idx] = paddle::platform::complex<float>(dout_[idx]) *
+                     (x_[idx] / paddle::platform::complex<float>(abs(x_[idx])));
     }
   }
   const float* dout_;
-  const paddle::platform::complex64* x_;
-  paddle::platform::complex64* output_;
+  const paddle::platform::complex<float>* x_;
+  paddle::platform::complex<float>* output_;
   int64_t numel_;
 };
 template <>
-struct AbsGradFunctor<paddle::platform::complex128> {
-  AbsGradFunctor(const double* dout, const paddle::platform::complex128* x,
-                 paddle::platform::complex128* output, int64_t numel)
+struct AbsGradFunctor<paddle::platform::complex<double>> {
+  AbsGradFunctor(const double* dout, const paddle::platform::complex<double>* x,
+                 paddle::platform::complex<double>* output, int64_t numel)
       : dout_(dout), x_(x), output_(output), numel_(numel) {}
   HOSTDEVICE void operator()(int64_t idx) const {
-    if (x_[idx] == paddle::platform::complex128(0)) {
-      output_[idx] = paddle::platform::complex128(0);
+    if (x_[idx] == paddle::platform::complex<double>(0)) {
+      output_[idx] = paddle::platform::complex<double>(0);
     } else {
-      output_[idx] = paddle::platform::complex128(dout_[idx]) *
-                     (x_[idx] / paddle::platform::complex128(abs(x_[idx])));
+      output_[idx] =
+          paddle::platform::complex<double>(dout_[idx]) *
+          (x_[idx] / paddle::platform::complex<double>(abs(x_[idx])));
     }
   }
   const double* dout_;
-  const paddle::platform::complex128* x_;
-  paddle::platform::complex128* output_;
+  const paddle::platform::complex<double>* x_;
+  paddle::platform::complex<double>* output_;
   int64_t numel_;
 };
@@ -234,46 +241,46 @@ struct AbsGradGradFunctor {
 };
 template <>
-struct AbsGradGradFunctor<paddle::platform::complex128> {
-  AbsGradGradFunctor(const paddle::platform::complex128* ddx,
-                     const paddle::platform::complex128* x,
-                     paddle::platform::complex128* output, int64_t numel)
+struct AbsGradGradFunctor<paddle::platform::complex<double>> {
+  AbsGradGradFunctor(const paddle::platform::complex<double>* ddx,
+                     const paddle::platform::complex<double>* x,
+                     paddle::platform::complex<double>* output, int64_t numel)
       : ddx_(ddx), x_(x), output_(output), numel_(numel) {}
   HOSTDEVICE void operator()(int64_t idx) const {
-    if (x_[idx] == paddle::platform::complex128(0)) {
-      output_[idx] = paddle::platform::complex128(0);
+    if (x_[idx] == paddle::platform::complex<double>(0)) {
+      output_[idx] = paddle::platform::complex<double>(0);
     } else {
-      output_[idx] = paddle::platform::complex128(ddx_[idx]) * x_[idx] /
-                     paddle::platform::complex128(abs(x_[idx]));
+      output_[idx] = paddle::platform::complex<double>(ddx_[idx]) * x_[idx] /
+                     paddle::platform::complex<double>(abs(x_[idx]));
     }
   }
-  const paddle::platform::complex128* ddx_;
-  const paddle::platform::complex128* x_;
-  paddle::platform::complex128* output_;
+  const paddle::platform::complex<double>* ddx_;
+  const paddle::platform::complex<double>* x_;
+  paddle::platform::complex<double>* output_;
   int64_t numel_;
 };
 template <>
-struct AbsGradGradFunctor<paddle::platform::complex64> {
-  AbsGradGradFunctor(const paddle::platform::complex64* ddx,
-                     const paddle::platform::complex64* x,
-                     paddle::platform::complex64* output, int64_t numel)
+struct AbsGradGradFunctor<paddle::platform::complex<float>> {
+  AbsGradGradFunctor(const paddle::platform::complex<float>* ddx,
+                     const paddle::platform::complex<float>* x,
+                     paddle::platform::complex<float>* output, int64_t numel)
      : ddx_(ddx), x_(x), output_(output), numel_(numel) {}
   HOSTDEVICE void operator()(int64_t idx) const {
-    if (x_[idx] == paddle::platform::complex64(0)) {
-      output_[idx] = paddle::platform::complex64(0);
+    if (x_[idx] == paddle::platform::complex<float>(0)) {
+      output_[idx] = paddle::platform::complex<float>(0);
     } else {
-      output_[idx] = paddle::platform::complex64(ddx_[idx]) * x_[idx] /
-                     paddle::platform::complex64(abs(x_[idx]));
+      output_[idx] = paddle::platform::complex<float>(ddx_[idx]) * x_[idx] /
+                     paddle::platform::complex<float>(abs(x_[idx]));
     }
   }
-  const paddle::platform::complex64* ddx_;
-  const paddle::platform::complex64* x_;
-  paddle::platform::complex64* output_;
+  const paddle::platform::complex<float>* ddx_;
+  const paddle::platform::complex<float>* x_;
+  paddle::platform::complex<float>* output_;
   int64_t numel_;
 };
 template <typename T, typename Enable = void>
......
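On the math in the AbsGradFunctor specializations above: for complex z != 0 the backward pass computes dout * z/|z|, and at z = 0 the gradient is defined to be 0, which sidesteps the division by |z| where it is undefined. A small numeric spot-check of the non-zero branch, using std::complex as a stand-in for paddle::platform::complex (a sketch, not Paddle code):

// Spot-check: for z = 3 + 4i with |z| = 5, dout * (z / |z|) = 0.6 + 0.8i.
#include <complex>
#include <iostream>

int main() {
  const std::complex<float> z{3.0f, 4.0f};
  const float dout = 1.0f;  // incoming gradient of |z|
  // Mirrors the else-branch of the functor in the diff.
  const std::complex<float> grad =
      std::complex<float>(dout) * (z / std::abs(z));
  std::cout << grad.real() << " " << grad.imag() << "\n";  // prints: 0.6 0.8
}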
@@ -95,11 +95,11 @@ REGISTER_OPERATOR(real, ops::RealOp, ops::RealOpMaker,
 REGISTER_OPERATOR(real_grad, ops::RealGradOp);
 REGISTER_OP_CPU_KERNEL(real, ops::RealKernel<paddle::platform::CPUDeviceContext,
-                                             paddle::platform::complex64>,
+                                             paddle::platform::complex<float>>,
                        ops::RealKernel<paddle::platform::CPUDeviceContext,
-                                       paddle::platform::complex128>);
+                                       paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(real_grad,
                        ops::RealGradKernel<paddle::platform::CPUDeviceContext,
-                                           paddle::platform::complex64>,
+                                           paddle::platform::complex<float>>,
                        ops::RealGradKernel<paddle::platform::CPUDeviceContext,
-                                           paddle::platform::complex128>);
+                                           paddle::platform::complex<double>>);
@@ -18,11 +18,11 @@ namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(real,
                         ops::RealKernel<paddle::platform::CUDADeviceContext,
-                                        paddle::platform::complex64>,
+                                        paddle::platform::complex<float>>,
                         ops::RealKernel<paddle::platform::CUDADeviceContext,
-                                        paddle::platform::complex128>);
+                                        paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(real_grad,
                         ops::RealGradKernel<paddle::platform::CUDADeviceContext,
-                                            paddle::platform::complex64>,
+                                            paddle::platform::complex<float>>,
                         ops::RealGradKernel<paddle::platform::CUDADeviceContext,
-                                            paddle::platform::complex128>);
+                                            paddle::platform::complex<double>>);