From 6c399d945eb95e69a3a6b6fd585c48021fbe95ee Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Thu, 27 May 2021 18:02:26 +0800 Subject: [PATCH] Modify Ops from complex64/128 to complex types. (#33133) * modify kron OP to complex template types * modify reshape, slice, trace, transpose OPs to complex template types * modify to complex template types in eigen slice files * change to complex template types for pad.cc and pad.cu * format code style --- paddle/fluid/operators/eigen/pad.cc | 7 +- paddle/fluid/operators/eigen/pad.cu | 7 +- paddle/fluid/operators/eigen/slice.cc | 4 -- paddle/fluid/operators/eigen/slice.cu | 7 +- paddle/fluid/operators/kron_op.cc | 11 ++- paddle/fluid/operators/kron_op.cu | 11 ++- paddle/fluid/operators/kron_op.h | 92 +++++--------------------- paddle/fluid/operators/reshape_op.cc | 34 +++++----- paddle/fluid/operators/slice_op.cc | 16 ++--- paddle/fluid/operators/trace_op.cc | 8 +-- paddle/fluid/operators/trace_op.cu | 8 +-- paddle/fluid/operators/transpose_op.cc | 16 ++--- paddle/fluid/operators/transpose_op.cu | 16 ++--- 13 files changed, 85 insertions(+), 152 deletions(-) diff --git a/paddle/fluid/operators/eigen/pad.cc b/paddle/fluid/operators/eigen/pad.cc index 72668bca9af..421c9eaf5cd 100644 --- a/paddle/fluid/operators/eigen/pad.cc +++ b/paddle/fluid/operators/eigen/pad.cc @@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/operators/eigen/eigen_function.h" -#include "paddle/fluid/platform/complex128.h" -#include "paddle/fluid/platform/complex64.h" +#include "paddle/fluid/platform/complex.h" namespace paddle { namespace operators { @@ -56,8 +55,8 @@ INSTANTIATION(EigenPad, int); INSTANTIATION(EigenPad, int64_t); INSTANTIATION(EigenPad, float); INSTANTIATION(EigenPad, double); -INSTANTIATION(EigenPad, platform::complex64); -INSTANTIATION(EigenPad, platform::complex128); +INSTANTIATION(EigenPad, platform::complex); +INSTANTIATION(EigenPad, platform::complex); #undef INSTANTIATION } // namespace operators diff --git a/paddle/fluid/operators/eigen/pad.cu b/paddle/fluid/operators/eigen/pad.cu index 1c936f886a3..ee7d0429105 100644 --- a/paddle/fluid/operators/eigen/pad.cu +++ b/paddle/fluid/operators/eigen/pad.cu @@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/eigen/eigen_function.h" -#include "paddle/fluid/platform/complex128.h" -#include "paddle/fluid/platform/complex64.h" +#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" namespace paddle { @@ -58,8 +57,8 @@ INSTANTIATION(EigenPad, int64_t); INSTANTIATION(EigenPad, float); INSTANTIATION(EigenPad, double); INSTANTIATION(EigenPad, platform::float16); -INSTANTIATION(EigenPad, platform::complex64); -INSTANTIATION(EigenPad, platform::complex128); +INSTANTIATION(EigenPad, platform::complex); +INSTANTIATION(EigenPad, platform::complex); #undef INSTANTIATION } // namespace operators diff --git a/paddle/fluid/operators/eigen/slice.cc b/paddle/fluid/operators/eigen/slice.cc index 240b4249ff1..2579b5f07eb 100644 --- a/paddle/fluid/operators/eigen/slice.cc +++ b/paddle/fluid/operators/eigen/slice.cc @@ -14,8 +14,6 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/eigen/eigen_function.h" #include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/complex.h" -#include "paddle/fluid/platform/complex128.h" -#include "paddle/fluid/platform/complex64.h" #include "paddle/fluid/platform/float16.h" namespace paddle { @@ -69,8 +67,6 @@ INSTANTIATION(EigenSlice, float); INSTANTIATION(EigenSlice, double); INSTANTIATION(EigenSlice, platform::float16); INSTANTIATION(EigenSlice, platform::bfloat16); -INSTANTIATION(EigenSlice, platform::complex64); -INSTANTIATION(EigenSlice, platform::complex128); INSTANTIATION(EigenSlice, platform::complex); INSTANTIATION(EigenSlice, platform::complex); #undef INSTANTIATION diff --git a/paddle/fluid/operators/eigen/slice.cu b/paddle/fluid/operators/eigen/slice.cu index 91c4a29f4ae..f059508394f 100644 --- a/paddle/fluid/operators/eigen/slice.cu +++ b/paddle/fluid/operators/eigen/slice.cu @@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/operators/eigen/eigen_function.h" -#include "paddle/fluid/platform/complex128.h" -#include "paddle/fluid/platform/complex64.h" +#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" namespace paddle { @@ -58,8 +57,8 @@ INSTANTIATION(EigenSlice, int64_t); INSTANTIATION(EigenSlice, float); INSTANTIATION(EigenSlice, double); INSTANTIATION(EigenSlice, platform::float16); -INSTANTIATION(EigenSlice, platform::complex64); -INSTANTIATION(EigenSlice, platform::complex128); +INSTANTIATION(EigenSlice, platform::complex); +INSTANTIATION(EigenSlice, platform::complex); #undef INSTANTIATION } // namespace operators diff --git a/paddle/fluid/operators/kron_op.cc b/paddle/fluid/operators/kron_op.cc index dab9948edc3..308330313a9 100644 --- a/paddle/fluid/operators/kron_op.cc +++ b/paddle/fluid/operators/kron_op.cc @@ -18,8 +18,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/kron_op.h" -#include "paddle/fluid/platform/complex128.h" -#include "paddle/fluid/platform/complex64.h" +#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" namespace paddle { @@ -185,9 +184,9 @@ REGISTER_OP_CPU_KERNEL( ops::KronKernel, ops::KronKernel, ops::KronKernel, + paddle::platform::complex>, ops::KronKernel); + paddle::platform::complex>); REGISTER_OPERATOR(kron_grad, ops::KronGradOp); REGISTER_OP_CPU_KERNEL( @@ -198,6 +197,6 @@ REGISTER_OP_CPU_KERNEL( ops::KronGradKernel, ops::KronGradKernel, ops::KronGradKernel, + paddle::platform::complex>, ops::KronGradKernel); + paddle::platform::complex>); diff --git a/paddle/fluid/operators/kron_op.cu b/paddle/fluid/operators/kron_op.cu index a348cb2e175..e5124e65007 100644 --- a/paddle/fluid/operators/kron_op.cu +++ b/paddle/fluid/operators/kron_op.cu @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/operators/kron_op.h" -#include "paddle/fluid/platform/complex128.h" -#include "paddle/fluid/platform/complex64.h" +#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; @@ -26,9 +25,9 @@ REGISTER_OP_CUDA_KERNEL( ops::KronKernel, ops::KronKernel, ops::KronKernel, + paddle::platform::complex>, ops::KronKernel); + paddle::platform::complex>); REGISTER_OP_CUDA_KERNEL( kron_grad, ops::KronGradKernel, @@ -38,6 +37,6 @@ REGISTER_OP_CUDA_KERNEL( ops::KronGradKernel, ops::KronGradKernel, ops::KronGradKernel, + paddle::platform::complex>, ops::KronGradKernel); + paddle::platform::complex>); diff --git a/paddle/fluid/operators/kron_op.h b/paddle/fluid/operators/kron_op.h index 6815fd460fa..6c3bad4e1bd 100644 --- a/paddle/fluid/operators/kron_op.h +++ b/paddle/fluid/operators/kron_op.h @@ -26,9 +26,6 @@ limitations under the License. */ namespace paddle { namespace operators { -using complex64 = paddle::platform::complex64; -using complex128 = paddle::platform::complex128; - // Process an element in the output, used with a parallel-for template struct KronElemFunctor { @@ -175,72 +172,13 @@ struct KronGradElemFunctor { const int ndims_; }; -template <> -struct KronGradElemFunctor { - KronGradElemFunctor(const complex64* dout, const complex64* A, - const complex64* B, complex64* dout_a, complex64* dout_b, - const int64_t* stride_dout, const int64_t* stride_a, - const int64_t* stride_b, const int64_t* shape_b, - const int64_t numel_a, const int64_t numel_b, - const int ndims) - : dout_(dout), - A_(A), - B_(B), - dout_a_(dout_a), - dout_b_(dout_b), - stride_dout_(stride_dout), - stride_a_(stride_a), - stride_b_(stride_b), - shape_b_(shape_b), - numel_a_(numel_a), - numel_b_(numel_b), - ndims_(ndims) {} - - HOSTDEVICE void operator()(int64_t idx) { - int64_t index = idx; - int64_t index_a = 0; - int64_t index_b = 0; - for (int i = 0; i < ndims_; i++) { - auto pos_i = index / 
stride_dout_[i]; - index = index % stride_dout_[i]; - auto pos_ai = pos_i / shape_b_[i]; - auto pos_bi = pos_i % shape_b_[i]; - index_a += stride_a_[i] * pos_ai; - index_b += stride_b_[i] * pos_bi; - } - - if (dout_a_) { - size_t index_out_a = index_a * numel_b_ + index_b; - dout_a_[index_out_a] = - dout_[idx] * complex64(B_[index_b].real, -B_[index_b].imag); - } - if (dout_b_) { - size_t index_out_b = index_b * numel_a_ + index_a; - dout_b_[index_out_b] = - dout_[idx] * complex64(A_[index_a].real, -A_[index_a].imag); - } - } - - private: - const complex64* dout_; - const complex64* A_; - const complex64* B_; - complex64* dout_a_; - complex64* dout_b_; - const int64_t* stride_dout_; - const int64_t* stride_a_; - const int64_t* stride_b_; - const int64_t* shape_b_; - const int64_t numel_a_; - const int64_t numel_b_; - const int ndims_; -}; - -template <> -struct KronGradElemFunctor { - KronGradElemFunctor(const complex128* dout, const complex128* A, - const complex128* B, complex128* dout_a, - complex128* dout_b, const int64_t* stride_dout, +template +struct KronGradElemFunctor> { + KronGradElemFunctor(const platform::complex* dout, + const platform::complex* A, + const platform::complex* B, + platform::complex* dout_a, + platform::complex* dout_b, const int64_t* stride_dout, const int64_t* stride_a, const int64_t* stride_b, const int64_t* shape_b, const int64_t numel_a, const int64_t numel_b, const int ndims) @@ -273,21 +211,23 @@ struct KronGradElemFunctor { if (dout_a_) { size_t index_out_a = index_a * numel_b_ + index_b; dout_a_[index_out_a] = - dout_[idx] * complex128(B_[index_b].real, -B_[index_b].imag); + dout_[idx] * + platform::complex(B_[index_b].real, -B_[index_b].imag); } if (dout_b_) { size_t index_out_b = index_b * numel_a_ + index_a; dout_b_[index_out_b] = - dout_[idx] * complex128(A_[index_a].real, -A_[index_a].imag); + dout_[idx] * + platform::complex(A_[index_a].real, -A_[index_a].imag); } } private: - const complex128* dout_; - const complex128* 
A_; - const complex128* B_; - complex128* dout_a_; - complex128* dout_b_; + const platform::complex* dout_; + const platform::complex* A_; + const platform::complex* B_; + platform::complex* dout_a_; + platform::complex* dout_b_; const int64_t* stride_dout_; const int64_t* stride_a_; const int64_t* stride_b_; diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index e119a21caa2..717029cb8f1 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -613,23 +613,24 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR( reshape2, float, ops::ReshapeKernel, double, ops::ReshapeKernel, int8_t, ops::ReshapeKernel, uint8_t, ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t, ops::ReshapeKernel, bool, ops::ReshapeKernel, - paddle::platform::bfloat16, ops::ReshapeKernel, paddle::platform::complex64, - ops::ReshapeKernel, paddle::platform::complex128, ops::ReshapeKernel); + paddle::platform::bfloat16, ops::ReshapeKernel, + paddle::platform::complex, ops::ReshapeKernel, + paddle::platform::complex, ops::ReshapeKernel); REGISTER_OP_CPU_KERNEL_FUNCTOR( reshape2_grad, float, ops::ReshapeGradKernel, double, ops::ReshapeGradKernel, int, ops::ReshapeGradKernel, uint8_t, ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, bool, ops::ReshapeGradKernel, paddle::platform::bfloat16, ops::ReshapeGradKernel, - paddle::platform::complex64, ops::ReshapeGradKernel, - paddle::platform::complex128, ops::ReshapeGradKernel); + paddle::platform::complex, ops::ReshapeGradKernel, + paddle::platform::complex, ops::ReshapeGradKernel); REGISTER_OP_CPU_KERNEL_FUNCTOR( reshape2_grad_grad, float, ops::ReshapeDoubleGradKernel, double, ops::ReshapeDoubleGradKernel, int, ops::ReshapeDoubleGradKernel, uint8_t, ops::ReshapeDoubleGradKernel, int64_t, ops::ReshapeDoubleGradKernel, bool, ops::ReshapeDoubleGradKernel, paddle::platform::bfloat16, - ops::ReshapeDoubleGradKernel, paddle::platform::complex64, - ops::ReshapeDoubleGradKernel, 
paddle::platform::complex128, + ops::ReshapeDoubleGradKernel, paddle::platform::complex, + ops::ReshapeDoubleGradKernel, paddle::platform::complex, ops::ReshapeDoubleGradKernel); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -650,22 +651,23 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double, uint8_t, ops::ReshapeKernel, int64_t, ops::ReshapeKernel, plat::float16, ops::ReshapeKernel, bool, ops::ReshapeKernel, - plat::complex64, ops::ReshapeKernel, - plat::complex128, ops::ReshapeKernel); + plat::complex, ops::ReshapeKernel, + plat::complex, ops::ReshapeKernel); REGISTER_OP_CUDA_KERNEL_FUNCTOR( reshape2_grad, float, ops::ReshapeGradKernel, double, ops::ReshapeGradKernel, int, ops::ReshapeGradKernel, uint8_t, ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, plat::float16, - ops::ReshapeGradKernel, bool, ops::ReshapeGradKernel, plat::complex64, - ops::ReshapeGradKernel, plat::complex128, ops::ReshapeGradKernel); + ops::ReshapeGradKernel, bool, ops::ReshapeGradKernel, plat::complex, + ops::ReshapeGradKernel, plat::complex, ops::ReshapeGradKernel); REGISTER_OP_CUDA_KERNEL_FUNCTOR( reshape2_grad_grad, float, ops::ReshapeDoubleGradKernel, double, ops::ReshapeDoubleGradKernel, int, ops::ReshapeDoubleGradKernel, uint8_t, ops::ReshapeDoubleGradKernel, int64_t, ops::ReshapeDoubleGradKernel, plat::float16, ops::ReshapeDoubleGradKernel, bool, - ops::ReshapeDoubleGradKernel, plat::complex64, ops::ReshapeDoubleGradKernel, - plat::complex128, ops::ReshapeDoubleGradKernel); + ops::ReshapeDoubleGradKernel, plat::complex, + ops::ReshapeDoubleGradKernel, plat::complex, + ops::ReshapeDoubleGradKernel); #endif #ifdef PADDLE_WITH_XPU @@ -673,14 +675,14 @@ REGISTER_OP_XPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double, ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t, ops::ReshapeKernel, plat::float16, ops::ReshapeKernel, bool, ops::ReshapeKernel, - plat::complex64, ops::ReshapeKernel, - plat::complex128, 
ops::ReshapeKernel); + plat::complex, ops::ReshapeKernel, + plat::complex, ops::ReshapeKernel); REGISTER_OP_XPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel, double, ops::ReshapeGradKernel, int, ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, plat::float16, ops::ReshapeGradKernel, bool, - ops::ReshapeGradKernel, plat::complex64, - ops::ReshapeGradKernel, plat::complex128, + ops::ReshapeGradKernel, plat::complex, + ops::ReshapeGradKernel, plat::complex, ops::ReshapeGradKernel); #endif diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc index c37fd679bed..b5298979721 100644 --- a/paddle/fluid/operators/slice_op.cc +++ b/paddle/fluid/operators/slice_op.cc @@ -436,9 +436,9 @@ REGISTER_OP_CPU_KERNEL( ops::SliceKernel, ops::SliceKernel, ops::SliceKernel, + paddle::platform::complex>, ops::SliceKernel); + paddle::platform::complex>); REGISTER_OP_CPU_KERNEL( slice_grad, ops::SliceGradKernel, @@ -446,9 +446,9 @@ REGISTER_OP_CPU_KERNEL( ops::SliceGradKernel, ops::SliceGradKernel, ops::SliceGradKernel, + paddle::platform::complex>, ops::SliceGradKernel); + paddle::platform::complex>); REGISTER_OP_CUDA_KERNEL( slice, ops::SliceKernel, @@ -458,9 +458,9 @@ REGISTER_OP_CUDA_KERNEL( ops::SliceKernel, ops::SliceKernel, + paddle::platform::complex>, ops::SliceKernel); + paddle::platform::complex>); REGISTER_OP_CUDA_KERNEL( slice_grad, @@ -471,6 +471,6 @@ REGISTER_OP_CUDA_KERNEL( ops::SliceGradKernel, ops::SliceGradKernel, + paddle::platform::complex>, ops::SliceGradKernel); + paddle::platform::complex>); diff --git a/paddle/fluid/operators/trace_op.cc b/paddle/fluid/operators/trace_op.cc index 623d4c7fc23..de71a089b69 100644 --- a/paddle/fluid/operators/trace_op.cc +++ b/paddle/fluid/operators/trace_op.cc @@ -167,18 +167,18 @@ REGISTER_OP_CPU_KERNEL( ops::TraceKernel, ops::TraceKernel, ops::TraceKernel, + paddle::platform::complex>, ops::TraceKernel); + paddle::platform::complex>); REGISTER_OP_CPU_KERNEL( trace_grad, 
ops::TraceGradKernel, ops::TraceGradKernel, ops::TraceGradKernel, ops::TraceGradKernel, ops::TraceGradKernel, + paddle::platform::complex>, ops::TraceGradKernel); + paddle::platform::complex>); /* ========================== register checkpoint ===========================*/ REGISTER_OP_VERSION(trace) diff --git a/paddle/fluid/operators/trace_op.cu b/paddle/fluid/operators/trace_op.cu index 2c2745018be..6798521c8f7 100644 --- a/paddle/fluid/operators/trace_op.cu +++ b/paddle/fluid/operators/trace_op.cu @@ -64,9 +64,9 @@ REGISTER_OP_CUDA_KERNEL( ops::TraceCUDAKernel, ops::TraceCUDAKernel, ops::TraceCUDAKernel, + paddle::platform::complex>, ops::TraceCUDAKernel); + paddle::platform::complex>); REGISTER_OP_CUDA_KERNEL( trace_grad, ops::TraceGradKernel, ops::TraceGradKernel, @@ -75,6 +75,6 @@ REGISTER_OP_CUDA_KERNEL( ops::TraceGradKernel, ops::TraceGradKernel, ops::TraceGradKernel, + paddle::platform::complex>, ops::TraceGradKernel); + paddle::platform::complex>); diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index 465970451f5..95b2c13ff6c 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -341,17 +341,17 @@ REGISTER_OP_CPU_KERNEL( transpose, ops::TransposeKernel, ops::TransposeKernel, ops::TransposeKernel, + paddle::platform::complex>, ops::TransposeKernel); + paddle::platform::complex>); REGISTER_OP_CPU_KERNEL( transpose_grad, ops::TransposeGradKernel, ops::TransposeGradKernel, ops::TransposeGradKernel, + paddle::platform::complex>, ops::TransposeGradKernel); + paddle::platform::complex>); REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker, ops::Transpose2GradMaker, @@ -366,9 +366,9 @@ REGISTER_OP_CPU_KERNEL( ops::TransposeKernel, ops::TransposeKernel, ops::TransposeKernel, + paddle::platform::complex>, ops::TransposeKernel); + paddle::platform::complex>); REGISTER_OP_CPU_KERNEL( transpose2_grad, ops::TransposeGradKernel, @@ -376,6 +376,6 @@ 
REGISTER_OP_CPU_KERNEL( ops::TransposeGradKernel, ops::TransposeGradKernel, ops::TransposeGradKernel, + paddle::platform::complex>, ops::TransposeGradKernel); + paddle::platform::complex>); diff --git a/paddle/fluid/operators/transpose_op.cu b/paddle/fluid/operators/transpose_op.cu index afeb22bd6fa..a462bbb4834 100644 --- a/paddle/fluid/operators/transpose_op.cu +++ b/paddle/fluid/operators/transpose_op.cu @@ -732,9 +732,9 @@ REGISTER_OP_CUDA_KERNEL( ops::TransposeGPUKernel, ops::TransposeGPUKernel, ops::TransposeGPUKernel, + paddle::platform::complex>, ops::TransposeGPUKernel); + paddle::platform::complex>); REGISTER_OP_CUDA_KERNEL( transpose_grad, ops::TransposeGradGPUKernel, @@ -742,9 +742,9 @@ REGISTER_OP_CUDA_KERNEL( ops::TransposeGradGPUKernel, ops::TransposeGradGPUKernel, + paddle::platform::complex>, ops::TransposeGradGPUKernel); + paddle::platform::complex>); REGISTER_OP_CUDA_KERNEL( transpose2, @@ -754,9 +754,9 @@ REGISTER_OP_CUDA_KERNEL( ops::TransposeGPUKernel, ops::TransposeGPUKernel, ops::TransposeGPUKernel, + paddle::platform::complex>, ops::TransposeGPUKernel); + paddle::platform::complex>); REGISTER_OP_CUDA_KERNEL( transpose2_grad, ops::TransposeGradGPUKernel, @@ -766,6 +766,6 @@ REGISTER_OP_CUDA_KERNEL( ops::TransposeGradGPUKernel, ops::TransposeGradGPUKernel, + paddle::platform::complex>, ops::TransposeGradGPUKernel); + paddle::platform::complex>); -- GitLab