From 0764fda25bb016bf143fc0a3aa93a3fb56b0cd73 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Wed, 2 Mar 2022 15:07:34 +0800 Subject: [PATCH] [Phi] Unify complex type trait and fix real imag bug (#40036) * unify complex type trait and fix real imag bug * add unittest for type tratis --- paddle/fluid/operators/angle_op.h | 6 +- paddle/fluid/operators/eig_op.h | 26 ++-- paddle/fluid/operators/eigh_op.h | 2 +- paddle/fluid/operators/eigvals_op.h | 14 +- paddle/fluid/operators/imag_op.cc | 2 +- paddle/fluid/operators/lstsq_op.h | 4 +- .../operators/math/eigen_values_vectors.h | 8 +- paddle/fluid/operators/math/inclusive_scan.h | 2 +- paddle/fluid/operators/qr_op.cu | 14 +- paddle/fluid/operators/qr_op.h | 18 +-- paddle/fluid/operators/real_op.cc | 2 +- paddle/fluid/operators/svd_helper.h | 12 +- paddle/fluid/operators/svd_op.h | 12 +- paddle/phi/common/type_traits.h | 96 ++++++++++++++ paddle/phi/infermeta/unary.cc | 7 + paddle/phi/infermeta/unary.h | 2 + paddle/phi/kernels/cpu/abs_kernel.cc | 6 +- paddle/phi/kernels/cpu/complex_kernel.cc | 8 +- paddle/phi/kernels/funcs/complex_functors.h | 123 ++++++------------ paddle/phi/kernels/gpu/abs_kernel.cu | 10 +- paddle/phi/kernels/gpu/complex_kernel.cu | 8 +- .../phi/kernels/impl/abs_grad_kernel_impl.h | 2 +- .../kernels/impl/complex_grad_kernel_impl.h | 4 +- paddle/phi/kernels/impl/complex_kernel_impl.h | 8 +- paddle/phi/tests/common/test_data_type.cc | 16 +++ 25 files changed, 247 insertions(+), 165 deletions(-) create mode 100644 paddle/phi/common/type_traits.h diff --git a/paddle/fluid/operators/angle_op.h b/paddle/fluid/operators/angle_op.h index db5a3ea2961..116a8053db3 100644 --- a/paddle/fluid/operators/angle_op.h +++ b/paddle/fluid/operators/angle_op.h @@ -36,8 +36,8 @@ class AngleKernel : public framework::OpKernel { auto numel = x->numel(); auto* x_data = x->data(); - auto* out_data = out->mutable_data>( - context.GetPlace(), size_t(x->numel() * sizeof(phi::funcs::Real))); + auto* out_data = out->mutable_data>( + context.GetPlace(), size_t(x->numel() * sizeof(phi::dtype::Real))); auto& dev_ctx = context.template device_context(); platform::ForRange for_range(dev_ctx, numel); @@ -57,7 +57,7 @@ class AngleGradKernel : public framework::OpKernel { ctx.Output(framework::GradVarName("X")); auto numel = d_out->numel(); - auto* dout_data = d_out->data>(); + auto* dout_data = d_out->data>(); auto* x_data = x->data(); auto* dx_data = d_x->mutable_data( ctx.GetPlace(), static_cast(numel * sizeof(T))); diff --git a/paddle/fluid/operators/eig_op.h b/paddle/fluid/operators/eig_op.h index 03b25c6705a..e9c6c1eb7ec 100644 --- a/paddle/fluid/operators/eig_op.h +++ b/paddle/fluid/operators/eig_op.h @@ -87,19 +87,19 @@ void LapackEig(Tensor* input, Tensor* values, Tensor* vectors, int info, int values_stride = values->dims()[values->dims().size() - 1]; Tensor rwork; - phi::funcs::Real* rwork_data = nullptr; + phi::dtype::Real* rwork_data = nullptr; rwork.Resize(phi::make_ddim({lda * 2})); - rwork_data = rwork.mutable_data>(context.GetPlace()); + rwork_data = rwork.mutable_data>(context.GetPlace()); // call lapackEig once to compute the size of work; T computed_work_size; - phi::funcs::lapackEig>( + phi::funcs::lapackEig>( jobvl, jobvr, order, input_data, lda, values_data, lvector_data, ldvl, rvector_data, ldvr, &computed_work_size, lwork, rwork_data, &info); lwork = std::max( - 1, static_cast(phi::funcs::Real(computed_work_size))); + 1, static_cast(phi::dtype::Real(computed_work_size))); Tensor work; work.Resize(phi::make_ddim({lwork})); T* work_data = work.mutable_data(context.GetPlace()); @@ -109,7 +109,7 @@ void LapackEig(Tensor* input, Tensor* values, Tensor* vectors, int info, T* current_values = &values_data[i * values_stride]; T* current_rvectors = &rvector_data[i * matrix_stride]; - phi::funcs::lapackEig>( + phi::funcs::lapackEig>( jobvl, jobvr, order, current_matrix, lda, current_values, lvector_data, ldvl, current_rvectors, ldvr, work_data, lwork, rwork_data, &info); PADDLE_ENFORCE_EQ( @@ -207,23 +207,23 @@ class EigKernel : public framework::OpKernel { origin_dim.push_back(last_item * 2); framework::DDim big_dim = phi::make_ddim(origin_dim); - real_values.mutable_data>(big_dim, + real_values.mutable_data>(big_dim, context.GetPlace()); - real_vectors.mutable_data>(x->dims(), + real_vectors.mutable_data>(x->dims(), context.GetPlace()); - ApplyEigKernel>( + ApplyEigKernel>( *x, &real_values, &real_vectors, context); auto dito = math::DeviceIndependenceTensorOperations< - DeviceContext, phi::funcs::Real, Tout>(context); + DeviceContext, phi::dtype::Real, Tout>(context); // 1. extract real part & imag part from real_values Tensor real_part = dito.Slice(real_values, {-1}, {0}, {order}); Tensor imag_part = dito.Slice(real_values, {-1}, {order}, {order * 2}); // 2. construct complex values - auto* real_part_data = real_part.data>(); - auto* imag_part_data = imag_part.data>(); + auto* real_part_data = real_part.data>(); + auto* imag_part_data = imag_part.data>(); int out_values_numel = out_values->numel(); platform::ForRange for_range( context.template device_context(), out_values_numel); @@ -236,7 +236,7 @@ class EigKernel : public framework::OpKernel { Tensor real_vector_trans = dito.Transpose(real_vectors); Tensor out_vectors_trans; out_vectors_trans.mutable_data(x->dims(), context.GetPlace()); - ConstructComplexVectors, Tout>( + ConstructComplexVectors, Tout>( &out_vectors_trans, *out_values, real_vector_trans, context, batch_count, order); TransposeTwoAxis(out_vectors_trans, out_vectors, @@ -272,7 +272,7 @@ void ComputeBackwardForComplexInput( // turn diag_unsqueezed into complex auto numel = diag_unsqueezed.numel(); Tensor diag_unsqueezed_complex; - auto* data_diag_un = diag_unsqueezed.data>(); + auto* data_diag_un = diag_unsqueezed.data>(); auto* data_diag_un_com = diag_unsqueezed_complex.mutable_data( diag_unsqueezed.dims(), context.GetPlace(), static_cast(numel * sizeof(Tout))); diff --git a/paddle/fluid/operators/eigh_op.h b/paddle/fluid/operators/eigh_op.h index 294794877b3..5279ec75093 100644 --- a/paddle/fluid/operators/eigh_op.h +++ b/paddle/fluid/operators/eigh_op.h @@ -40,7 +40,7 @@ template class EighGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - using ValueType = phi::funcs::Real; + using ValueType = phi::dtype::Real; auto& x_grad = *ctx.Output(framework::GradVarName("X")); x_grad.mutable_data(ctx.GetPlace()); auto& output_w = *ctx.Input("Eigenvalues"); diff --git a/paddle/fluid/operators/eigvals_op.h b/paddle/fluid/operators/eigvals_op.h index 59eabfb29b9..4627acc0d07 100644 --- a/paddle/fluid/operators/eigvals_op.h +++ b/paddle/fluid/operators/eigvals_op.h @@ -48,7 +48,7 @@ struct PaddleComplex< template using PaddleCType = typename PaddleComplex::type; template -using Real = typename phi::funcs::Real; +using Real = typename phi::dtype::Real; static void SpiltBatchSquareMatrix(const Tensor& input, std::vector* output) { @@ -144,7 +144,7 @@ LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input, required_work_mem, work_mem)); int64_t rwork_mem = rwork->memory_size(); - int64_t required_rwork_mem = (n_dim << 1) * sizeof(phi::funcs::Real); + int64_t required_rwork_mem = (n_dim << 1) * sizeof(phi::dtype::Real); PADDLE_ENFORCE_GE( rwork_mem, required_rwork_mem, platform::errors::InvalidArgument( @@ -154,11 +154,11 @@ LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input, required_rwork_mem, rwork_mem)); int info = 0; - phi::funcs::lapackEig>( + phi::funcs::lapackEig>( 'N', 'N', static_cast(n_dim), a.template data(), static_cast(n_dim), output->template data(), NULL, 1, NULL, 1, work->template data(), static_cast(work_mem / sizeof(T)), - rwork->template data>(), &info); + rwork->template data>(), &info); std::string name = "framework::platform::dynload::cgeev_"; if (framework::TransToProtoVarType(input.dtype()) == @@ -188,10 +188,10 @@ class EigvalsKernel : public framework::OpKernel { // query workspace size T qwork; int info; - phi::funcs::lapackEig>( + phi::funcs::lapackEig>( 'N', 'N', static_cast(n_dim), input_matrices[0].template data(), static_cast(n_dim), NULL, NULL, 1, NULL, 1, &qwork, -1, - static_cast*>(NULL), &info); + static_cast*>(NULL), &info); int64_t lwork = static_cast(qwork); Tensor work, rwork; @@ -208,7 +208,7 @@ class EigvalsKernel : public framework::OpKernel { } if (framework::IsComplexType( framework::TransToProtoVarType(input->dtype()))) { - rwork.mutable_data>(phi::make_ddim({n_dim << 1}), + rwork.mutable_data>(phi::make_ddim({n_dim << 1}), ctx.GetPlace()); } diff --git a/paddle/fluid/operators/imag_op.cc b/paddle/fluid/operators/imag_op.cc index 33b68d68992..567a69f383d 100644 --- a/paddle/fluid/operators/imag_op.cc +++ b/paddle/fluid/operators/imag_op.cc @@ -83,7 +83,7 @@ DECLARE_INPLACE_OP_INFERER(ImagGradOpInplaceInferer, } // namespace paddle DELCARE_INFER_SHAPE_FUNCTOR(imag, ImagInferShapeFunctor, - PT_INFER_META(phi::UnchangedInferMeta)); + PT_INFER_META(phi::RealAndImagInferMeta)); namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/lstsq_op.h b/paddle/fluid/operators/lstsq_op.h index a4c3d1c81fb..3cbbc62e7be 100644 --- a/paddle/fluid/operators/lstsq_op.h +++ b/paddle/fluid/operators/lstsq_op.h @@ -46,7 +46,7 @@ template class LstsqCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - using ValueType = phi::funcs::Real; + using ValueType = phi::dtype::Real; const Tensor& x = *context.Input("X"); auto y = context.Input("Y"); @@ -169,7 +169,7 @@ class LstsqCPUKernel : public framework::OpKernel { &rank_32, &wkopt, lwork, &rwkopt, &info); } - lwork = std::max(1, static_cast(phi::funcs::Real(wkopt))); + lwork = std::max(1, static_cast(phi::dtype::Real(wkopt))); Tensor work; work.Resize(phi::make_ddim({lwork})); T* work_data = work.mutable_data(context.GetPlace()); diff --git a/paddle/fluid/operators/math/eigen_values_vectors.h b/paddle/fluid/operators/math/eigen_values_vectors.h index 9b6ebf73d9b..1ade2190bb9 100644 --- a/paddle/fluid/operators/math/eigen_values_vectors.h +++ b/paddle/fluid/operators/math/eigen_values_vectors.h @@ -63,7 +63,7 @@ struct MatrixEighFunctor { void operator()(const framework::ExecutionContext &ctx, const Tensor &input, Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower, bool has_vectors) { - using ValueType = phi::funcs::Real; + using ValueType = phi::dtype::Real; auto *out_value = eigen_values->mutable_data(ctx.GetPlace()); auto dito = @@ -123,7 +123,7 @@ struct MatrixEighFunctor { for (auto i = 0; i < batch_size; i++) { auto *value_data = out_value + i * values_stride; auto *input_data = input_vector + i * vector_stride; - phi::funcs::lapackEigh>( + phi::funcs::lapackEigh>( jobz, uplo, n, input_data, lda, value_data, work_data, lwork, rwork_data, lrwork, iwork_data, liwork, &info); CheckEighResult(i, info); @@ -151,7 +151,7 @@ struct MatrixEighFunctor { void operator()(const framework::ExecutionContext &ctx, const Tensor &input, Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower, bool has_vectors) { - using ValueType = phi::funcs::Real; + using ValueType = phi::dtype::Real; auto *out_value = eigen_values->mutable_data(ctx.GetPlace()); auto &dev_ctx = ctx.template device_context(); @@ -233,7 +233,7 @@ struct MatrixEighFunctor { } } - using ValueType = phi::funcs::Real; + using ValueType = phi::dtype::Real; inline void EvdBuffer(cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, const T *A, int lda, const ValueType *W, int *lwork) const; diff --git a/paddle/fluid/operators/math/inclusive_scan.h b/paddle/fluid/operators/math/inclusive_scan.h index 38692a64611..9994ccc10cb 100644 --- a/paddle/fluid/operators/math/inclusive_scan.h +++ b/paddle/fluid/operators/math/inclusive_scan.h @@ -115,7 +115,7 @@ static __global__ void InclusiveScanInnerDimCUDAKernel(const T *x, T *y, size_t num_rows, size_t row_size, T init, BinaryOp op) { - using RealT = phi::funcs::Real; + using RealT = phi::dtype::Real; constexpr auto kSharedBufferSize = framework::IsComplex::value ? 4 * kThreadNumX : 2 * kThreadNumX; __shared__ RealT sbuf[kThreadNumY][kSharedBufferSize]; diff --git a/paddle/fluid/operators/qr_op.cu b/paddle/fluid/operators/qr_op.cu index 5e841a097fe..a57a8d5cf8b 100644 --- a/paddle/fluid/operators/qr_op.cu +++ b/paddle/fluid/operators/qr_op.cu @@ -56,13 +56,13 @@ class QrGPUKernel : public framework::OpKernel { int tau_stride = min_mn; if (compute_q) { - q.mutable_data>( + q.mutable_data>( context.GetPlace(), - size_t(batch_size * m * k * sizeof(phi::funcs::Real))); + size_t(batch_size * m * k * sizeof(phi::dtype::Real))); } - r.mutable_data>( + r.mutable_data>( context.GetPlace(), - size_t(batch_size * k * n * sizeof(phi::funcs::Real))); + size_t(batch_size * k * n * sizeof(phi::dtype::Real))); auto dito = math::DeviceIndependenceTensorOperations { // Note: allocate temporary tensors because of lacking in-place operatios. // Prepare qr Tensor qr; - qr.mutable_data>( + qr.mutable_data>( context.GetPlace(), - size_t(batch_size * m * n * sizeof(phi::funcs::Real))); + size_t(batch_size * m * n * sizeof(phi::dtype::Real))); // BatchedGeqrf performs computation in-place and 'qr' must be a copy of // input paddle::framework::TensorCopy(x, context.GetPlace(), &qr); @@ -126,7 +126,7 @@ class QrGPUKernel : public framework::OpKernel { for (int i = 0; i < batch_size; ++i) { memory::Copy(dev_ctx.GetPlace(), (new_qr_data + i * new_qr_stride), dev_ctx.GetPlace(), (qr_data + i * qr_stride), - qr_stride * sizeof(phi::funcs::Real), + qr_stride * sizeof(phi::dtype::Real), dev_ctx.stream()); } BatchedOrgqr( diff --git a/paddle/fluid/operators/qr_op.h b/paddle/fluid/operators/qr_op.h index cef9371fea0..f09a07e96cd 100644 --- a/paddle/fluid/operators/qr_op.h +++ b/paddle/fluid/operators/qr_op.h @@ -74,19 +74,19 @@ class QrCPUKernel : public framework::OpKernel { int q_stride = m * k; int r_stride = k * n; - auto* x_data = x.data>(); + auto* x_data = x.data>(); T* q_data = nullptr; if (compute_q) { - q_data = q.mutable_data>( + q_data = q.mutable_data>( context.GetPlace(), - size_t(batch_size * m * k * sizeof(phi::funcs::Real))); + size_t(batch_size * m * k * sizeof(phi::dtype::Real))); memset(q_data, 0, - size_t(batch_size * m * k * sizeof(phi::funcs::Real))); + size_t(batch_size * m * k * sizeof(phi::dtype::Real))); } - auto* r_data = r.mutable_data>( + auto* r_data = r.mutable_data>( context.GetPlace(), - size_t(batch_size * k * n * sizeof(phi::funcs::Real))); - memset(r_data, 0, size_t(batch_size * k * n * sizeof(phi::funcs::Real))); + size_t(batch_size * k * n * sizeof(phi::dtype::Real))); + memset(r_data, 0, size_t(batch_size * k * n * sizeof(phi::dtype::Real))); // Implement QR by calling Eigen for (int i = 0; i < batch_size; ++i) { @@ -142,7 +142,7 @@ class QrGradKernel : public framework::OpKernel { // Use a different name dA instead of dX framework::Tensor& dA = *ctx.Output(framework::GradVarName("X")); - dA.mutable_data>(ctx.GetPlace()); + dA.mutable_data>(ctx.GetPlace()); auto& dev_ctx = ctx.template device_context(); phi::funcs::SetConstant()(dev_ctx, &dA, T(0)); @@ -224,7 +224,7 @@ class QrGradKernel : public framework::OpKernel { } else { // If m < n for input matrices A, we partition A = [X|Y] and R = [U|V] // Calculate dX and dY individually and concatenate them to get dA - dA.mutable_data>(ctx.GetPlace()); + dA.mutable_data>(ctx.GetPlace()); auto Y = dito.Slice(A, {-1}, {m}, {n}); auto U = dito.Slice(R, {-1}, {0}, {m}); diff --git a/paddle/fluid/operators/real_op.cc b/paddle/fluid/operators/real_op.cc index 1f3691978b5..28a8484f539 100644 --- a/paddle/fluid/operators/real_op.cc +++ b/paddle/fluid/operators/real_op.cc @@ -83,7 +83,7 @@ DECLARE_INPLACE_OP_INFERER(RealGradOpInplaceInferer, } // namespace paddle DELCARE_INFER_SHAPE_FUNCTOR(real, RealInferShapeFunctor, - PT_INFER_META(phi::UnchangedInferMeta)); + PT_INFER_META(phi::RealAndImagInferMeta)); namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/svd_helper.h b/paddle/fluid/operators/svd_helper.h index bcb3ee44f04..166f49999d5 100644 --- a/paddle/fluid/operators/svd_helper.h +++ b/paddle/fluid/operators/svd_helper.h @@ -105,7 +105,7 @@ struct RealMulComplexFunctor { "The image part of y must to be 0" "but got [%d]", y.imag)); - return platform::complex>(x.real * y.real, + return platform::complex>(x.real * y.real, x.imag * y.real); } }; @@ -391,11 +391,11 @@ struct DeviceIndependenceTensorOperations { // batch_diag for CPU only Tensor BatchDiag(const Tensor& x, int batch) { Tensor out; - auto* x_data = x.data>(); + auto* x_data = x.data>(); auto numel = x.numel(); - auto* out_data = out.mutable_data>( + auto* out_data = out.mutable_data>( x.dims(), context.GetPlace(), - static_cast(numel * sizeof(phi::funcs::Real))); + static_cast(numel * sizeof(phi::dtype::Real))); auto x_dims = x.dims(); int num_dims = x_dims.size(); @@ -661,9 +661,9 @@ struct DeviceIndependenceTensorOperations { Tensor Real(const Tensor& x) { Tensor out; auto numel = x.numel(); - auto* out_data = out.mutable_data>( + auto* out_data = out.mutable_data>( x.dims(), context.GetPlace(), - static_cast(numel * sizeof(phi::funcs::Real))); + static_cast(numel * sizeof(phi::dtype::Real))); auto* x_data = x.data(); auto for_range = GetForRange(numel); phi::funcs::RealFunctor functor(x_data, out_data, numel); diff --git a/paddle/fluid/operators/svd_op.h b/paddle/fluid/operators/svd_op.h index f5e451ac705..42a847206a3 100644 --- a/paddle/fluid/operators/svd_op.h +++ b/paddle/fluid/operators/svd_op.h @@ -46,14 +46,14 @@ class SvdCPUKernel : public framework::OpKernel { int col_u = full ? rows : k; int col_v = full ? cols : k; int batches = numel / (rows * cols); - auto* U_out = U->mutable_data>( + auto* U_out = U->mutable_data>( context.GetPlace(), - size_t(batches * rows * col_u * sizeof(phi::funcs::Real))); - auto* VH_out = VH->mutable_data>( + size_t(batches * rows * col_u * sizeof(phi::dtype::Real))); + auto* VH_out = VH->mutable_data>( context.GetPlace(), - size_t(batches * col_v * cols * sizeof(phi::funcs::Real))); - auto* S_out = S->mutable_data>( - context.GetPlace(), size_t(batches * k * sizeof(phi::funcs::Real))); + size_t(batches * col_v * cols * sizeof(phi::dtype::Real))); + auto* S_out = S->mutable_data>( + context.GetPlace(), size_t(batches * k * sizeof(phi::dtype::Real))); /*SVD Use the Eigen Library*/ math::BatchSvd(x_data, U_out, VH_out, S_out, rows, cols, batches, full); } diff --git a/paddle/phi/common/type_traits.h b/paddle/phi/common/type_traits.h new file mode 100644 index 00000000000..ef894eee468 --- /dev/null +++ b/paddle/phi/common/type_traits.h @@ -0,0 +1,96 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/data_type.h" + +namespace phi { +namespace dtype { + +template +struct cond { + static constexpr bool value = B; + using type = T; +}; + +template +struct eval_if { + using type = typename TrueF::type; +}; + +template +struct eval_if { + using type = typename FalseF::type; +}; + +template +using eval_if_t = typename eval_if::type; + +template +struct select { + using type = eval_if_t>; +}; + +template +struct select { + using type = T; +}; + +template +struct select> { + // last one had better be true! + static_assert(B, "No match select type!"); + using type = T; +}; + +template +using select_t = typename select::type; + +// runtime real and complex type conversion + +template +using Real = select_t>::value, float>, + cond>::value, double>, + T>; + +template +using Complex = select_t::value, complex>, + cond::value, complex>, + T>; + +inline DataType ToReal(DataType dtype) { + switch (dtype) { + case phi::DataType::COMPLEX64: + return phi::DataType::FLOAT32; + case phi::DataType::COMPLEX128: + return phi::DataType::FLOAT64; + default: + return dtype; + } +} + +inline DataType ToComplex(DataType dtype) { + switch (dtype) { + case phi::DataType::FLOAT32: + return phi::DataType::COMPLEX64; + case phi::DataType::FLOAT64: + return phi::DataType::COMPLEX128; + default: + return dtype; + } +} + +} // namespace dtype +} // namespace phi diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 983e0162264..fbd9259a83f 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/type_traits.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/kernels/funcs/unfold_functor.h" @@ -51,6 +52,12 @@ void UnchangedInferMetaCheckAxis(const MetaTensor& x, out->share_meta(x); } +void RealAndImagInferMeta(const MetaTensor& x, MetaTensor* out) { + out->set_dims(x.dims()); + out->set_dtype(dtype::ToReal(x.dtype())); + out->set_layout(x.layout()); +} + void FlattenInferMeta(const MetaTensor& x, int start_axis, int stop_axis, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index a2d779e0f70..3c0628981af 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -39,6 +39,8 @@ void UnchangedInferMetaCheckAxis(const MetaTensor& x, int axis, MetaTensor* out); +void RealAndImagInferMeta(const MetaTensor& x, MetaTensor* out); + void FlattenInferMeta(const MetaTensor& x, int start_axis, int stop_axis, diff --git a/paddle/phi/kernels/cpu/abs_kernel.cc b/paddle/phi/kernels/cpu/abs_kernel.cc index efe7d090405..9f89fc27a71 100644 --- a/paddle/phi/kernels/cpu/abs_kernel.cc +++ b/paddle/phi/kernels/cpu/abs_kernel.cc @@ -25,9 +25,9 @@ template void AbsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) { auto numel = x.numel(); auto* x_data = x.data(); - ctx.template Alloc>( - out, size_t(x.numel() * sizeof(phi::funcs::Real))); - auto* out_data = out->data>(); + ctx.template Alloc>( + out, size_t(x.numel() * sizeof(phi::dtype::Real))); + auto* out_data = out->data>(); phi::funcs::ForRange for_range(ctx, numel); phi::funcs::AbsFunctor functor(x_data, out_data, numel); diff --git a/paddle/phi/kernels/cpu/complex_kernel.cc b/paddle/phi/kernels/cpu/complex_kernel.cc index 801502e1673..859d5a84527 100644 --- a/paddle/phi/kernels/cpu/complex_kernel.cc +++ b/paddle/phi/kernels/cpu/complex_kernel.cc @@ -37,11 +37,15 @@ PD_REGISTER_KERNEL(real, ALL_LAYOUT, phi::RealKernel, phi::dtype::complex, - phi::dtype::complex) {} + phi::dtype::complex) { + kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); +} PD_REGISTER_KERNEL(imag, CPU, ALL_LAYOUT, phi::ImagKernel, phi::dtype::complex, - phi::dtype::complex) {} + phi::dtype::complex) { + kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); +} diff --git a/paddle/phi/kernels/funcs/complex_functors.h b/paddle/phi/kernels/funcs/complex_functors.h index 86dbdd099ec..8b292cb5dc5 100644 --- a/paddle/phi/kernels/funcs/complex_functors.h +++ b/paddle/phi/kernels/funcs/complex_functors.h @@ -20,56 +20,12 @@ limitations under the License. */ #include #include "paddle/phi/common/complex.h" +#include "paddle/phi/common/type_traits.h" #include "paddle/phi/core/hostdevice.h" namespace phi { namespace funcs { -template -struct cond { - static constexpr bool value = B; - using type = T; -}; - -template -struct eval_if { - using type = typename TrueF::type; -}; - -template -struct eval_if { - using type = typename FalseF::type; -}; - -template -using eval_if_t = typename eval_if::type; - -template -struct select { - using type = eval_if_t>; -}; - -template -struct select { - using type = T; -}; - -template -struct select> { - // last one had better be true! - static_assert(B, "No match select type!"); - using type = T; -}; - -template -using select_t = typename select::type; - -template -using Real = - select_t>::value, float>, - cond>::value, double>, - T>; - template using Complex = typename std::enable_if::value>::type; @@ -91,9 +47,9 @@ template struct RealFunctor; template -struct RealFunctor>> { +struct RealFunctor>> { public: - RealFunctor(const T* input, Real* output, int64_t numel) + RealFunctor(const T* input, dtype::Real* output, int64_t numel) : input_(input), output_(output), numel_(numel) {} HOSTDEVICE void operator()(int64_t idx) const { @@ -102,7 +58,7 @@ struct RealFunctor>> { private: const T* input_; - Real* output_; + dtype::Real* output_; int64_t numel_; }; @@ -110,8 +66,8 @@ template struct ImagFunctor; template -struct ImagFunctor>> { - ImagFunctor(const T* input, Real* output, int64_t numel) +struct ImagFunctor>> { + ImagFunctor(const T* input, dtype::Real* output, int64_t numel) : input_(input), output_(output), numel_(numel) {} HOSTDEVICE void operator()(int64_t idx) const { @@ -119,7 +75,7 @@ struct ImagFunctor>> { } const T* input_; - Real* output_; + dtype::Real* output_; int64_t numel_; }; @@ -127,8 +83,8 @@ template struct AbsFunctor; template -struct AbsFunctor>> { - AbsFunctor(const T* input, Real* output, int64_t numel) +struct AbsFunctor>> { + AbsFunctor(const T* input, dtype::Real* output, int64_t numel) : input_(input), output_(output), numel_(numel) {} HOSTDEVICE void operator()(int64_t idx) const { @@ -136,12 +92,12 @@ struct AbsFunctor>> { } const T* input_; - Real* output_; + dtype::Real* output_; int64_t numel_; }; template -struct AbsFunctor>> { +struct AbsFunctor>> { AbsFunctor(const T* input, T* output, int64_t numel) : input_(input), output_(output), numel_(numel) {} @@ -203,7 +159,10 @@ struct AbsGradCUDAFunctor> { template struct AbsGradFunctor { - AbsGradFunctor(const Real* dout, const T* x, T* output, int64_t numel) + AbsGradFunctor(const dtype::Real* dout, + const T* x, + T* output, + int64_t numel) : dout_(dout), x_(x), output_(output), numel_(numel) {} HOSTDEVICE void operator()(int64_t idx) const { @@ -214,7 +173,7 @@ struct AbsGradFunctor { } } - const Real* dout_; + const dtype::Real* dout_; const T* x_; T* output_; int64_t numel_; @@ -334,8 +293,8 @@ template struct RealToComplexFunctor; template -struct RealToComplexFunctor>> { - RealToComplexFunctor(const Real* input, T* output, int64_t numel) +struct RealToComplexFunctor>> { + RealToComplexFunctor(const dtype::Real* input, T* output, int64_t numel) : input_(input), output_(output), numel_(numel) {} HOSTDEVICE void operator()(int64_t idx) const { @@ -343,7 +302,7 @@ struct RealToComplexFunctor>> { output_[idx].imag = 0; } - const Real* input_; + const dtype::Real* input_; T* output_; int64_t numel_; }; @@ -352,8 +311,8 @@ template struct ImagToComplexFunctor; template -struct ImagToComplexFunctor>> { - ImagToComplexFunctor(const Real* input, T* output, int64_t numel) +struct ImagToComplexFunctor>> { + ImagToComplexFunctor(const dtype::Real* input, T* output, int64_t numel) : input_(input), output_(output), numel_(numel) {} HOSTDEVICE void operator()(int64_t idx) const { @@ -361,7 +320,7 @@ struct ImagToComplexFunctor>> { output_[idx].imag = input_[idx]; } - const Real* input_; + const dtype::Real* input_; T* output_; int64_t numel_; }; @@ -370,9 +329,9 @@ template struct RealImagToComplexFunctor; template -struct RealImagToComplexFunctor>> { - RealImagToComplexFunctor(const Real* input_real, - const Real* input_imag, +struct RealImagToComplexFunctor>> { + RealImagToComplexFunctor(const dtype::Real* input_real, + const dtype::Real* input_imag, T* output, int64_t numel) : input_real_(input_real), @@ -385,8 +344,8 @@ struct RealImagToComplexFunctor>> { output_[idx].imag = input_imag_[idx]; } - const Real* input_real_; - const Real* input_imag_; + const dtype::Real* input_real_; + const dtype::Real* input_imag_; T* output_; int64_t numel_; }; @@ -423,8 +382,8 @@ struct AngleFunctor; // angel function for complex template -struct AngleFunctor>> { - AngleFunctor(const T* input, phi::funcs::Real* output, int64_t numel) +struct AngleFunctor>> { + AngleFunctor(const T* input, dtype::Real* output, int64_t numel) : input_(input), output_(output), numel_(numel) {} HOSTDEVICE void operator()(int64_t idx) const { @@ -432,13 +391,13 @@ struct AngleFunctor>> { } const T* input_; - phi::funcs::Real* output_; + dtype::Real* output_; int64_t numel_; }; // angel function for real template -struct AngleFunctor>> { +struct AngleFunctor>> { AngleFunctor(const T* input, T* output, int64_t numel) : input_(input), output_(output), numel_(numel) {} @@ -456,25 +415,22 @@ struct AngleGradFunctor; // angle grad for complex template -struct AngleGradFunctor>> { - AngleGradFunctor(const phi::funcs::Real* dout, - const T* x, - T* dx, - int64_t numel) +struct AngleGradFunctor>> { + AngleGradFunctor(const dtype::Real* dout, const T* x, T* dx, int64_t numel) : dout_(dout), x_(x), dx_(dx), numel_(numel) {} HOSTDEVICE void operator()(int64_t idx) const { if (x_[idx] == T(0)) { dx_[idx] = T(0); } else { - const phi::funcs::Real r_square = + const phi::dtype::Real r_square = x_[idx].real * x_[idx].real + x_[idx].imag * x_[idx].imag; dx_[idx] = T(-dout_[idx] * x_[idx].imag / r_square, dout_[idx] * x_[idx].real / r_square); } } - const phi::funcs::Real* dout_; + const phi::dtype::Real* dout_; const T* x_; T* dx_; int64_t numel_; @@ -482,16 +438,13 @@ struct AngleGradFunctor>> { // angle grad for real template -struct AngleGradFunctor>> { - AngleGradFunctor(const phi::funcs::Real* dout, - const T* x, - T* dx, - int64_t numel) +struct AngleGradFunctor>> { + AngleGradFunctor(const dtype::Real* dout, const T* x, T* dx, int64_t numel) : dout_(dout), x_(x), dx_(dx), numel_(numel) {} HOSTDEVICE void operator()(int64_t idx) const { dx_[idx] = 0; } - const phi::funcs::Real* dout_; + const dtype::Real* dout_; const T* x_; T* dx_; int64_t numel_; diff --git a/paddle/phi/kernels/gpu/abs_kernel.cu b/paddle/phi/kernels/gpu/abs_kernel.cu index e122e6b1e9c..5c424316a83 100644 --- a/paddle/phi/kernels/gpu/abs_kernel.cu +++ b/paddle/phi/kernels/gpu/abs_kernel.cu @@ -27,14 +27,14 @@ template struct CudaAbsFunctor; template -struct CudaAbsFunctor>> { - __device__ __forceinline__ phi::funcs::Real operator()(const T x) const { +struct CudaAbsFunctor>> { + __device__ __forceinline__ phi::dtype::Real operator()(const T x) const { return abs(x); } }; template -struct CudaAbsFunctor>> { +struct CudaAbsFunctor>> { __device__ __forceinline__ T operator()(const T x) const { return std::abs(x); } @@ -42,12 +42,12 @@ struct CudaAbsFunctor>> { template void AbsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) { - ctx.template Alloc>(out); + ctx.template Alloc>(out); std::vector ins = {&x}; std::vector outs = {out}; auto functor = CudaAbsFunctor(); - funcs::ElementwiseKernel>(ctx, ins, &outs, functor); + funcs::ElementwiseKernel>(ctx, ins, &outs, functor); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/complex_kernel.cu b/paddle/phi/kernels/gpu/complex_kernel.cu index d0b086718a4..e03e079581a 100644 --- a/paddle/phi/kernels/gpu/complex_kernel.cu +++ b/paddle/phi/kernels/gpu/complex_kernel.cu @@ -38,11 +38,15 @@ PD_REGISTER_KERNEL(real, ALL_LAYOUT, phi::RealKernel, phi::dtype::complex, - phi::dtype::complex) {} + phi::dtype::complex) { + kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); +} PD_REGISTER_KERNEL(imag, GPU, ALL_LAYOUT, phi::ImagKernel, phi::dtype::complex, - phi::dtype::complex) {} + phi::dtype::complex) { + kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); +} diff --git a/paddle/phi/kernels/impl/abs_grad_kernel_impl.h b/paddle/phi/kernels/impl/abs_grad_kernel_impl.h index 78c25200bbd..9dad40b57c9 100644 --- a/paddle/phi/kernels/impl/abs_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/abs_grad_kernel_impl.h @@ -47,7 +47,7 @@ void AbsGradKernel(const Context& ctx, const DenseTensor& dout, DenseTensor* dx) { auto numel = dout.numel(); - auto* dout_data = dout.data>(); + auto* dout_data = dout.data>(); auto* x_data = x.data(); ctx.template Alloc(dx, static_cast(numel * sizeof(T))); diff --git a/paddle/phi/kernels/impl/complex_grad_kernel_impl.h b/paddle/phi/kernels/impl/complex_grad_kernel_impl.h index a10481284b1..03896a2353d 100644 --- a/paddle/phi/kernels/impl/complex_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/complex_grad_kernel_impl.h @@ -24,7 +24,7 @@ void RealGradKernel(const Context& dev_ctx, const DenseTensor& dout, DenseTensor* dx) { auto numel = dout.numel(); - auto* dout_data = dout.data>(); + auto* dout_data = dout.data>(); auto* dx_data = dev_ctx.template Alloc(dx, static_cast(numel * sizeof(T))); @@ -38,7 +38,7 @@ void ImagGradKernel(const Context& dev_ctx, const DenseTensor& dout, DenseTensor* dx) { auto numel = dout.numel(); - auto* dout_data = dout.data>(); + auto* dout_data = dout.data>(); auto* dx_data = dev_ctx.template Alloc(dx, static_cast(numel * sizeof(T))); diff --git a/paddle/phi/kernels/impl/complex_kernel_impl.h b/paddle/phi/kernels/impl/complex_kernel_impl.h index ff5cf86ed2e..72b13288339 100644 --- a/paddle/phi/kernels/impl/complex_kernel_impl.h +++ b/paddle/phi/kernels/impl/complex_kernel_impl.h @@ -39,8 +39,8 @@ void RealKernel(const Context& dev_ctx, DenseTensor* out) { auto numel = x.numel(); auto* x_data = x.data(); - auto* out_data = dev_ctx.template Alloc>( - out, static_cast(numel * sizeof(phi::funcs::Real))); + auto* out_data = dev_ctx.template Alloc>( + out, static_cast(numel * sizeof(phi::dtype::Real))); phi::funcs::ForRange for_range(dev_ctx, numel); phi::funcs::RealFunctor functor(x_data, out_data, numel); @@ -53,8 +53,8 @@ void ImagKernel(const Context& dev_ctx, DenseTensor* out) { auto numel = x.numel(); auto* x_data = x.data(); - auto* out_data = dev_ctx.template Alloc>( - out, static_cast(numel * sizeof(phi::funcs::Real))); + auto* out_data = dev_ctx.template Alloc>( + out, static_cast(numel * sizeof(phi::dtype::Real))); phi::funcs::ForRange for_range(dev_ctx, numel); phi::funcs::ImagFunctor functor(x_data, out_data, numel); diff --git a/paddle/phi/tests/common/test_data_type.cc b/paddle/phi/tests/common/test_data_type.cc index c962c68b4d5..5a1b41d796d 100644 --- a/paddle/phi/tests/common/test_data_type.cc +++ b/paddle/phi/tests/common/test_data_type.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/phi/api/ext/exception.h" #include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/type_traits.h" namespace phi { namespace tests { @@ -71,5 +72,20 @@ TEST(DataType, OStream) { } } +TEST(TypeTraits, Complex) { + EXPECT_EQ(phi::dtype::ToReal(phi::DataType::COMPLEX64), + phi::DataType::FLOAT32); + EXPECT_EQ(phi::dtype::ToReal(phi::DataType::COMPLEX128), + phi::DataType::FLOAT64); + EXPECT_EQ(phi::dtype::ToReal(phi::DataType::FLOAT32), phi::DataType::FLOAT32); + + EXPECT_EQ(phi::dtype::ToComplex(phi::DataType::FLOAT32), + phi::DataType::COMPLEX64); + EXPECT_EQ(phi::dtype::ToComplex(phi::DataType::FLOAT64), + phi::DataType::COMPLEX128); + EXPECT_EQ(phi::dtype::ToComplex(phi::DataType::COMPLEX64), + phi::DataType::COMPLEX64); +} + } // namespace tests } // namespace phi -- GitLab