Unverified commit 0764fda2 authored by Chen Weihang, committed by GitHub

[Phi] Unify complex type trait and fix real imag bug (#40036)

* unify complex type trait and fix real imag bug

* add unittest for type traits
Parent b4d931e8
......@@ -36,8 +36,8 @@ class AngleKernel : public framework::OpKernel<T> {
auto numel = x->numel();
auto* x_data = x->data<T>();
auto* out_data = out->mutable_data<phi::funcs::Real<T>>(
context.GetPlace(), size_t(x->numel() * sizeof(phi::funcs::Real<T>)));
auto* out_data = out->mutable_data<phi::dtype::Real<T>>(
context.GetPlace(), size_t(x->numel() * sizeof(phi::dtype::Real<T>)));
auto& dev_ctx = context.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
......@@ -57,7 +57,7 @@ class AngleGradKernel : public framework::OpKernel<T> {
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto numel = d_out->numel();
auto* dout_data = d_out->data<phi::funcs::Real<T>>();
auto* dout_data = d_out->data<phi::dtype::Real<T>>();
auto* x_data = x->data<T>();
auto* dx_data = d_x->mutable_data<T>(
ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));
......
......@@ -87,19 +87,19 @@ void LapackEig(Tensor* input, Tensor* values, Tensor* vectors, int info,
int values_stride = values->dims()[values->dims().size() - 1];
Tensor rwork;
phi::funcs::Real<T>* rwork_data = nullptr;
phi::dtype::Real<T>* rwork_data = nullptr;
rwork.Resize(phi::make_ddim({lda * 2}));
rwork_data = rwork.mutable_data<phi::funcs::Real<T>>(context.GetPlace());
rwork_data = rwork.mutable_data<phi::dtype::Real<T>>(context.GetPlace());
// call lapackEig once to compute the size of work;
T computed_work_size;
phi::funcs::lapackEig<T, phi::funcs::Real<T>>(
phi::funcs::lapackEig<T, phi::dtype::Real<T>>(
jobvl, jobvr, order, input_data, lda, values_data, lvector_data, ldvl,
rvector_data, ldvr, &computed_work_size, lwork, rwork_data, &info);
lwork = std::max<int>(
1, static_cast<int>(phi::funcs::Real<T>(computed_work_size)));
1, static_cast<int>(phi::dtype::Real<T>(computed_work_size)));
Tensor work;
work.Resize(phi::make_ddim({lwork}));
T* work_data = work.mutable_data<T>(context.GetPlace());
......@@ -109,7 +109,7 @@ void LapackEig(Tensor* input, Tensor* values, Tensor* vectors, int info,
T* current_values = &values_data[i * values_stride];
T* current_rvectors = &rvector_data[i * matrix_stride];
phi::funcs::lapackEig<T, phi::funcs::Real<T>>(
phi::funcs::lapackEig<T, phi::dtype::Real<T>>(
jobvl, jobvr, order, current_matrix, lda, current_values, lvector_data,
ldvl, current_rvectors, ldvr, work_data, lwork, rwork_data, &info);
PADDLE_ENFORCE_EQ(
......@@ -207,23 +207,23 @@ class EigKernel : public framework::OpKernel<T> {
origin_dim.push_back(last_item * 2);
framework::DDim big_dim = phi::make_ddim(origin_dim);
real_values.mutable_data<phi::funcs::Real<T>>(big_dim,
real_values.mutable_data<phi::dtype::Real<T>>(big_dim,
context.GetPlace());
real_vectors.mutable_data<phi::funcs::Real<T>>(x->dims(),
real_vectors.mutable_data<phi::dtype::Real<T>>(x->dims(),
context.GetPlace());
ApplyEigKernel<DeviceContext, phi::funcs::Real<T>>(
ApplyEigKernel<DeviceContext, phi::dtype::Real<T>>(
*x, &real_values, &real_vectors, context);
auto dito = math::DeviceIndependenceTensorOperations<
DeviceContext, phi::funcs::Real<T>, Tout>(context);
DeviceContext, phi::dtype::Real<T>, Tout>(context);
// 1. extract real part & imag part from real_values
Tensor real_part = dito.Slice(real_values, {-1}, {0}, {order});
Tensor imag_part = dito.Slice(real_values, {-1}, {order}, {order * 2});
// 2. construct complex values
auto* real_part_data = real_part.data<phi::funcs::Real<T>>();
auto* imag_part_data = imag_part.data<phi::funcs::Real<T>>();
auto* real_part_data = real_part.data<phi::dtype::Real<T>>();
auto* imag_part_data = imag_part.data<phi::dtype::Real<T>>();
int out_values_numel = out_values->numel();
platform::ForRange<DeviceContext> for_range(
context.template device_context<DeviceContext>(), out_values_numel);
......@@ -236,7 +236,7 @@ class EigKernel : public framework::OpKernel<T> {
Tensor real_vector_trans = dito.Transpose(real_vectors);
Tensor out_vectors_trans;
out_vectors_trans.mutable_data<Tout>(x->dims(), context.GetPlace());
ConstructComplexVectors<phi::funcs::Real<T>, Tout>(
ConstructComplexVectors<phi::dtype::Real<T>, Tout>(
&out_vectors_trans, *out_values, real_vector_trans, context,
batch_count, order);
TransposeTwoAxis<DeviceContext, Tout>(out_vectors_trans, out_vectors,
......@@ -272,7 +272,7 @@ void ComputeBackwardForComplexInput(
// turn diag_unsqueezed into complex
auto numel = diag_unsqueezed.numel();
Tensor diag_unsqueezed_complex;
auto* data_diag_un = diag_unsqueezed.data<phi::funcs::Real<Tout>>();
auto* data_diag_un = diag_unsqueezed.data<phi::dtype::Real<Tout>>();
auto* data_diag_un_com = diag_unsqueezed_complex.mutable_data<Tout>(
diag_unsqueezed.dims(), context.GetPlace(),
static_cast<size_t>(numel * sizeof(Tout)));
......
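The EigKernel hunk above runs the real-typed LAPACK path first, then slices the result into a real half ({0, order}) and an imaginary half ({order, order * 2}) and recombines the two into complex eigenvalues through a ForRange functor. A minimal standalone sketch of that recombination, using std::complex as a stand-in for phi::dtype::complex and assuming the [real_0 .. real_{order-1}, imag_0 .. imag_{order-1}] layout implied by the two Slice calls (one batch only):

#include <complex>
#include <vector>

// Illustration only: `buf` holds the real parts of one batch's eigenvalues in its first
// `order` entries and the imaginary parts in the next `order` entries, matching the two
// slices taken from real_values above.
std::vector<std::complex<float>> CombineRealImag(const std::vector<float>& buf, int order) {
  std::vector<std::complex<float>> values(order);
  for (int i = 0; i < order; ++i) {
    values[i] = std::complex<float>(buf[i], buf[order + i]);
  }
  return values;
}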
......@@ -40,7 +40,7 @@ template <typename DeviceContext, typename T>
class EighGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using ValueType = phi::funcs::Real<T>;
using ValueType = phi::dtype::Real<T>;
auto& x_grad = *ctx.Output<framework::Tensor>(framework::GradVarName("X"));
x_grad.mutable_data<T>(ctx.GetPlace());
auto& output_w = *ctx.Input<Tensor>("Eigenvalues");
......
......@@ -48,7 +48,7 @@ struct PaddleComplex<
template <typename T>
using PaddleCType = typename PaddleComplex<T>::type;
template <typename T>
using Real = typename phi::funcs::Real<T>;
using Real = typename phi::dtype::Real<T>;
static void SpiltBatchSquareMatrix(const Tensor& input,
std::vector<Tensor>* output) {
......@@ -144,7 +144,7 @@ LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input,
required_work_mem, work_mem));
int64_t rwork_mem = rwork->memory_size();
int64_t required_rwork_mem = (n_dim << 1) * sizeof(phi::funcs::Real<T>);
int64_t required_rwork_mem = (n_dim << 1) * sizeof(phi::dtype::Real<T>);
PADDLE_ENFORCE_GE(
rwork_mem, required_rwork_mem,
platform::errors::InvalidArgument(
......@@ -154,11 +154,11 @@ LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input,
required_rwork_mem, rwork_mem));
int info = 0;
phi::funcs::lapackEig<T, phi::funcs::Real<T>>(
phi::funcs::lapackEig<T, phi::dtype::Real<T>>(
'N', 'N', static_cast<int>(n_dim), a.template data<T>(),
static_cast<int>(n_dim), output->template data<T>(), NULL, 1, NULL, 1,
work->template data<T>(), static_cast<int>(work_mem / sizeof(T)),
rwork->template data<phi::funcs::Real<T>>(), &info);
rwork->template data<phi::dtype::Real<T>>(), &info);
std::string name = "framework::platform::dynload::cgeev_";
if (framework::TransToProtoVarType(input.dtype()) ==
......@@ -188,10 +188,10 @@ class EigvalsKernel : public framework::OpKernel<T> {
// query workspace size
T qwork;
int info;
phi::funcs::lapackEig<T, phi::funcs::Real<T>>(
phi::funcs::lapackEig<T, phi::dtype::Real<T>>(
'N', 'N', static_cast<int>(n_dim), input_matrices[0].template data<T>(),
static_cast<int>(n_dim), NULL, NULL, 1, NULL, 1, &qwork, -1,
static_cast<Real<T>*>(NULL), &info);
static_cast<phi::dtype::Real<T>*>(NULL), &info);
int64_t lwork = static_cast<int64_t>(qwork);
Tensor work, rwork;
......@@ -208,7 +208,7 @@ class EigvalsKernel : public framework::OpKernel<T> {
}
if (framework::IsComplexType(
framework::TransToProtoVarType(input->dtype()))) {
rwork.mutable_data<phi::funcs::Real<T>>(phi::make_ddim({n_dim << 1}),
rwork.mutable_data<phi::dtype::Real<T>>(phi::make_ddim({n_dim << 1}),
ctx.GetPlace());
}
......
......@@ -83,7 +83,7 @@ DECLARE_INPLACE_OP_INFERER(ImagGradOpInplaceInferer,
} // namespace paddle
DELCARE_INFER_SHAPE_FUNCTOR(imag, ImagInferShapeFunctor,
PT_INFER_META(phi::UnchangedInferMeta));
PT_INFER_META(phi::RealAndImagInferMeta));
namespace ops = paddle::operators;
......
......@@ -46,7 +46,7 @@ template <typename DeviceContext, typename T>
class LstsqCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using ValueType = phi::funcs::Real<T>;
using ValueType = phi::dtype::Real<T>;
const Tensor& x = *context.Input<Tensor>("X");
auto y = context.Input<Tensor>("Y");
......@@ -169,7 +169,7 @@ class LstsqCPUKernel : public framework::OpKernel<T> {
&rank_32, &wkopt, lwork, &rwkopt, &info);
}
lwork = std::max<int>(1, static_cast<int>(phi::funcs::Real<T>(wkopt)));
lwork = std::max<int>(1, static_cast<int>(phi::dtype::Real<T>(wkopt)));
Tensor work;
work.Resize(phi::make_ddim({lwork}));
T* work_data = work.mutable_data<T>(context.GetPlace());
......
......@@ -63,7 +63,7 @@ struct MatrixEighFunctor<platform::CPUDeviceContext, T> {
void operator()(const framework::ExecutionContext &ctx, const Tensor &input,
Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower,
bool has_vectors) {
using ValueType = phi::funcs::Real<T>;
using ValueType = phi::dtype::Real<T>;
auto *out_value = eigen_values->mutable_data<ValueType>(ctx.GetPlace());
auto dito =
......@@ -123,7 +123,7 @@ struct MatrixEighFunctor<platform::CPUDeviceContext, T> {
for (auto i = 0; i < batch_size; i++) {
auto *value_data = out_value + i * values_stride;
auto *input_data = input_vector + i * vector_stride;
phi::funcs::lapackEigh<T, phi::funcs::Real<T>>(
phi::funcs::lapackEigh<T, phi::dtype::Real<T>>(
jobz, uplo, n, input_data, lda, value_data, work_data, lwork,
rwork_data, lrwork, iwork_data, liwork, &info);
CheckEighResult(i, info);
......@@ -151,7 +151,7 @@ struct MatrixEighFunctor<platform::CUDADeviceContext, T> {
void operator()(const framework::ExecutionContext &ctx, const Tensor &input,
Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower,
bool has_vectors) {
using ValueType = phi::funcs::Real<T>;
using ValueType = phi::dtype::Real<T>;
auto *out_value = eigen_values->mutable_data<ValueType>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
......@@ -233,7 +233,7 @@ struct MatrixEighFunctor<platform::CUDADeviceContext, T> {
}
}
using ValueType = phi::funcs::Real<T>;
using ValueType = phi::dtype::Real<T>;
inline void EvdBuffer(cusolverDnHandle_t handle, cusolverEigMode_t jobz,
cublasFillMode_t uplo, int n, const T *A, int lda,
const ValueType *W, int *lwork) const;
......
......@@ -115,7 +115,7 @@ static __global__ void InclusiveScanInnerDimCUDAKernel(const T *x, T *y,
size_t num_rows,
size_t row_size, T init,
BinaryOp op) {
using RealT = phi::funcs::Real<T>;
using RealT = phi::dtype::Real<T>;
constexpr auto kSharedBufferSize =
framework::IsComplex<T>::value ? 4 * kThreadNumX : 2 * kThreadNumX;
__shared__ RealT sbuf[kThreadNumY][kSharedBufferSize];
......
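The doubled shared-buffer size in the hunk above (4 * kThreadNumX for complex T versus 2 * kThreadNumX otherwise) follows from RealT being phi::dtype::Real<T>: a complex element presumably occupies two RealT slots in sbuf, so the buffer needs twice as many RealT entries to hold the same number of scanned elements.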
......@@ -56,13 +56,13 @@ class QrGPUKernel : public framework::OpKernel<T> {
int tau_stride = min_mn;
if (compute_q) {
q.mutable_data<phi::funcs::Real<T>>(
q.mutable_data<phi::dtype::Real<T>>(
context.GetPlace(),
size_t(batch_size * m * k * sizeof(phi::funcs::Real<T>)));
size_t(batch_size * m * k * sizeof(phi::dtype::Real<T>)));
}
r.mutable_data<phi::funcs::Real<T>>(
r.mutable_data<phi::dtype::Real<T>>(
context.GetPlace(),
size_t(batch_size * k * n * sizeof(phi::funcs::Real<T>)));
size_t(batch_size * k * n * sizeof(phi::dtype::Real<T>)));
auto dito =
math::DeviceIndependenceTensorOperations<platform::CUDADeviceContext,
......@@ -71,9 +71,9 @@ class QrGPUKernel : public framework::OpKernel<T> {
// Note: allocate temporary tensors because of lacking in-place operations.
// Prepare qr
Tensor qr;
qr.mutable_data<phi::funcs::Real<T>>(
qr.mutable_data<phi::dtype::Real<T>>(
context.GetPlace(),
size_t(batch_size * m * n * sizeof(phi::funcs::Real<T>)));
size_t(batch_size * m * n * sizeof(phi::dtype::Real<T>)));
// BatchedGeqrf performs computation in-place and 'qr' must be a copy of
// input
paddle::framework::TensorCopy(x, context.GetPlace(), &qr);
......@@ -126,7 +126,7 @@ class QrGPUKernel : public framework::OpKernel<T> {
for (int i = 0; i < batch_size; ++i) {
memory::Copy(dev_ctx.GetPlace(), (new_qr_data + i * new_qr_stride),
dev_ctx.GetPlace(), (qr_data + i * qr_stride),
qr_stride * sizeof(phi::funcs::Real<T>),
qr_stride * sizeof(phi::dtype::Real<T>),
dev_ctx.stream());
}
BatchedOrgqr<platform::CUDADeviceContext, T>(
......
......@@ -74,19 +74,19 @@ class QrCPUKernel : public framework::OpKernel<T> {
int q_stride = m * k;
int r_stride = k * n;
auto* x_data = x.data<phi::funcs::Real<T>>();
auto* x_data = x.data<phi::dtype::Real<T>>();
T* q_data = nullptr;
if (compute_q) {
q_data = q.mutable_data<phi::funcs::Real<T>>(
q_data = q.mutable_data<phi::dtype::Real<T>>(
context.GetPlace(),
size_t(batch_size * m * k * sizeof(phi::funcs::Real<T>)));
size_t(batch_size * m * k * sizeof(phi::dtype::Real<T>)));
memset(q_data, 0,
size_t(batch_size * m * k * sizeof(phi::funcs::Real<T>)));
size_t(batch_size * m * k * sizeof(phi::dtype::Real<T>)));
}
auto* r_data = r.mutable_data<phi::funcs::Real<T>>(
auto* r_data = r.mutable_data<phi::dtype::Real<T>>(
context.GetPlace(),
size_t(batch_size * k * n * sizeof(phi::funcs::Real<T>)));
memset(r_data, 0, size_t(batch_size * k * n * sizeof(phi::funcs::Real<T>)));
size_t(batch_size * k * n * sizeof(phi::dtype::Real<T>)));
memset(r_data, 0, size_t(batch_size * k * n * sizeof(phi::dtype::Real<T>)));
// Implement QR by calling Eigen
for (int i = 0; i < batch_size; ++i) {
......@@ -142,7 +142,7 @@ class QrGradKernel : public framework::OpKernel<T> {
// Use a different name dA instead of dX
framework::Tensor& dA =
*ctx.Output<framework::Tensor>(framework::GradVarName("X"));
dA.mutable_data<phi::funcs::Real<T>>(ctx.GetPlace());
dA.mutable_data<phi::dtype::Real<T>>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
phi::funcs::SetConstant<DeviceContext, T>()(dev_ctx, &dA, T(0));
......@@ -224,7 +224,7 @@ class QrGradKernel : public framework::OpKernel<T> {
} else {
// If m < n for input matrices A, we partition A = [X|Y] and R = [U|V]
// Calculate dX and dY individually and concatenate them to get dA
dA.mutable_data<phi::funcs::Real<T>>(ctx.GetPlace());
dA.mutable_data<phi::dtype::Real<T>>(ctx.GetPlace());
auto Y = dito.Slice(A, {-1}, {m}, {n});
auto U = dito.Slice(R, {-1}, {0}, {m});
......
......@@ -83,7 +83,7 @@ DECLARE_INPLACE_OP_INFERER(RealGradOpInplaceInferer,
} // namespace paddle
DELCARE_INFER_SHAPE_FUNCTOR(real, RealInferShapeFunctor,
PT_INFER_META(phi::UnchangedInferMeta));
PT_INFER_META(phi::RealAndImagInferMeta));
namespace ops = paddle::operators;
......
......@@ -105,7 +105,7 @@ struct RealMulComplexFunctor {
"The image part of y must to be 0"
"but got [%d]",
y.imag));
return platform::complex<phi::funcs::Real<T>>(x.real * y.real,
return platform::complex<phi::dtype::Real<T>>(x.real * y.real,
x.imag * y.real);
}
};
......@@ -391,11 +391,11 @@ struct DeviceIndependenceTensorOperations {
// batch_diag for CPU only
Tensor BatchDiag(const Tensor& x, int batch) {
Tensor out;
auto* x_data = x.data<phi::funcs::Real<T>>();
auto* x_data = x.data<phi::dtype::Real<T>>();
auto numel = x.numel();
auto* out_data = out.mutable_data<phi::funcs::Real<T>>(
auto* out_data = out.mutable_data<phi::dtype::Real<T>>(
x.dims(), context.GetPlace(),
static_cast<size_t>(numel * sizeof(phi::funcs::Real<T>)));
static_cast<size_t>(numel * sizeof(phi::dtype::Real<T>)));
auto x_dims = x.dims();
int num_dims = x_dims.size();
......@@ -661,9 +661,9 @@ struct DeviceIndependenceTensorOperations {
Tensor Real(const Tensor& x) {
Tensor out;
auto numel = x.numel();
auto* out_data = out.mutable_data<phi::funcs::Real<T>>(
auto* out_data = out.mutable_data<phi::dtype::Real<T>>(
x.dims(), context.GetPlace(),
static_cast<size_t>(numel * sizeof(phi::funcs::Real<T>)));
static_cast<size_t>(numel * sizeof(phi::dtype::Real<T>)));
auto* x_data = x.data<T>();
auto for_range = GetForRange(numel);
phi::funcs::RealFunctor<T> functor(x_data, out_data, numel);
......
......@@ -46,14 +46,14 @@ class SvdCPUKernel : public framework::OpKernel<T> {
int col_u = full ? rows : k;
int col_v = full ? cols : k;
int batches = numel / (rows * cols);
auto* U_out = U->mutable_data<phi::funcs::Real<T>>(
auto* U_out = U->mutable_data<phi::dtype::Real<T>>(
context.GetPlace(),
size_t(batches * rows * col_u * sizeof(phi::funcs::Real<T>)));
auto* VH_out = VH->mutable_data<phi::funcs::Real<T>>(
size_t(batches * rows * col_u * sizeof(phi::dtype::Real<T>)));
auto* VH_out = VH->mutable_data<phi::dtype::Real<T>>(
context.GetPlace(),
size_t(batches * col_v * cols * sizeof(phi::funcs::Real<T>)));
auto* S_out = S->mutable_data<phi::funcs::Real<T>>(
context.GetPlace(), size_t(batches * k * sizeof(phi::funcs::Real<T>)));
size_t(batches * col_v * cols * sizeof(phi::dtype::Real<T>)));
auto* S_out = S->mutable_data<phi::dtype::Real<T>>(
context.GetPlace(), size_t(batches * k * sizeof(phi::dtype::Real<T>)));
/*SVD Use the Eigen Library*/
math::BatchSvd<T>(x_data, U_out, VH_out, S_out, rows, cols, batches, full);
}
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/data_type.h"
namespace phi {
namespace dtype {
template <bool B, typename T>
struct cond {
static constexpr bool value = B;
using type = T;
};
template <bool B, typename TrueF, typename FalseF>
struct eval_if {
using type = typename TrueF::type;
};
template <typename TrueF, typename FalseF>
struct eval_if<false, TrueF, FalseF> {
using type = typename FalseF::type;
};
template <bool B, typename T, typename F>
using eval_if_t = typename eval_if<B, T, F>::type;
template <typename Head, typename... Tail>
struct select {
using type = eval_if_t<Head::value, Head, select<Tail...>>;
};
template <typename T>
struct select<T> {
using type = T;
};
template <bool B, typename T>
struct select<cond<B, T>> {
// last one had better be true!
static_assert(B, "No match select type!");
using type = T;
};
template <typename Head, typename... Tail>
using select_t = typename select<Head, Tail...>::type;
// runtime real and complex type conversion
template <typename T>
using Real = select_t<cond<std::is_same<T, complex<float>>::value, float>,
cond<std::is_same<T, complex<double>>::value, double>,
T>;
template <typename T>
using Complex = select_t<cond<std::is_same<T, float>::value, complex<float>>,
cond<std::is_same<T, double>::value, complex<double>>,
T>;
inline DataType ToReal(DataType dtype) {
switch (dtype) {
case phi::DataType::COMPLEX64:
return phi::DataType::FLOAT32;
case phi::DataType::COMPLEX128:
return phi::DataType::FLOAT64;
default:
return dtype;
}
}
inline DataType ToComplex(DataType dtype) {
switch (dtype) {
case phi::DataType::FLOAT32:
return phi::DataType::COMPLEX64;
case phi::DataType::FLOAT64:
return phi::DataType::COMPLEX128;
default:
return dtype;
}
}
} // namespace dtype
} // namespace phi
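A minimal usage sketch for this new header, assuming a translation unit that can see the phi headers (phi::dtype::complex and the DataType enum come in through paddle/phi/common/data_type.h, which the header already includes):

#include <cassert>
#include <type_traits>

#include "paddle/phi/common/type_traits.h"

// Compile-time mapping: Real<T> strips the complex wrapper, Complex<T> adds it,
// and both leave already-real / already-complex types unchanged.
static_assert(
    std::is_same<phi::dtype::Real<phi::dtype::complex<float>>, float>::value, "");
static_assert(std::is_same<phi::dtype::Real<double>, double>::value, "");
static_assert(
    std::is_same<phi::dtype::Complex<float>, phi::dtype::complex<float>>::value, "");

int main() {
  // Runtime mapping used by the infermeta change and kernel registrations further down.
  assert(phi::dtype::ToReal(phi::DataType::COMPLEX64) == phi::DataType::FLOAT32);
  assert(phi::dtype::ToReal(phi::DataType::FLOAT32) == phi::DataType::FLOAT32);
  assert(phi::dtype::ToComplex(phi::DataType::FLOAT64) == phi::DataType::COMPLEX128);
  return 0;
}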
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <algorithm>
#include <set>
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/kernels/funcs/unfold_functor.h"
......@@ -51,6 +52,12 @@ void UnchangedInferMetaCheckAxis(const MetaTensor& x,
out->share_meta(x);
}
void RealAndImagInferMeta(const MetaTensor& x, MetaTensor* out) {
out->set_dims(x.dims());
out->set_dtype(dtype::ToReal(x.dtype()));
out->set_layout(x.layout());
}
void FlattenInferMeta(const MetaTensor& x,
int start_axis,
int stop_axis,
......
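RealAndImagInferMeta replaces UnchangedInferMeta for the real and imag ops (see the imag_op.cc and real_op.cc hunks above). UnchangedInferMeta copies the input dtype to the output, so for a complex input the inferred output dtype stayed complex; that is the "real imag bug" in the commit title. The new function keeps the input's dims and layout but maps the dtype through dtype::ToReal, so a complex64 tensor of shape [2, 3] is now inferred to produce a float32 output of shape [2, 3], and complex128 produces float64.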
......@@ -39,6 +39,8 @@ void UnchangedInferMetaCheckAxis(const MetaTensor& x,
int axis,
MetaTensor* out);
void RealAndImagInferMeta(const MetaTensor& x, MetaTensor* out);
void FlattenInferMeta(const MetaTensor& x,
int start_axis,
int stop_axis,
......
......@@ -25,9 +25,9 @@ template <typename T, typename Context>
void AbsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
auto numel = x.numel();
auto* x_data = x.data<T>();
ctx.template Alloc<phi::funcs::Real<T>>(
out, size_t(x.numel() * sizeof(phi::funcs::Real<T>)));
auto* out_data = out->data<phi::funcs::Real<T>>();
ctx.template Alloc<phi::dtype::Real<T>>(
out, size_t(x.numel() * sizeof(phi::dtype::Real<T>)));
auto* out_data = out->data<phi::dtype::Real<T>>();
phi::funcs::ForRange<Context> for_range(ctx, numel);
phi::funcs::AbsFunctor<T> functor(x_data, out_data, numel);
......
......@@ -37,11 +37,15 @@ PD_REGISTER_KERNEL(real,
ALL_LAYOUT,
phi::RealKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
PD_REGISTER_KERNEL(imag,
CPU,
ALL_LAYOUT,
phi::ImagKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
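These non-empty registration bodies are the kernel-side half of the same fix. Without them, the registered output dtype apparently defaults to the kernel-key dtype, so real/imag registered only for complex64/complex128 keys would report complex outputs; kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype())) declares that output 0 is produced in the real counterpart of the key dtype (COMPLEX64 -> FLOAT32, COMPLEX128 -> FLOAT64), consistent with RealAndImagInferMeta above. The GPU registrations later in this diff get the same bodies.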
......@@ -20,56 +20,12 @@ limitations under the License. */
#include <type_traits>
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/hostdevice.h"
namespace phi {
namespace funcs {
template <bool B, typename T>
struct cond {
static constexpr bool value = B;
using type = T;
};
template <bool B, typename TrueF, typename FalseF>
struct eval_if {
using type = typename TrueF::type;
};
template <typename TrueF, typename FalseF>
struct eval_if<false, TrueF, FalseF> {
using type = typename FalseF::type;
};
template <bool B, typename T, typename F>
using eval_if_t = typename eval_if<B, T, F>::type;
template <typename Head, typename... Tail>
struct select {
using type = eval_if_t<Head::value, Head, select<Tail...>>;
};
template <typename T>
struct select<T> {
using type = T;
};
template <bool B, typename T>
struct select<cond<B, T>> {
// last one had better be true!
static_assert(B, "No match select type!");
using type = T;
};
template <typename Head, typename... Tail>
using select_t = typename select<Head, Tail...>::type;
template <typename T>
using Real =
select_t<cond<std::is_same<T, phi::dtype::complex<float>>::value, float>,
cond<std::is_same<T, phi::dtype::complex<double>>::value, double>,
T>;
template <typename T, typename RealT>
using Complex = typename std::enable_if<!std::is_same<T, RealT>::value>::type;
......@@ -91,9 +47,9 @@ template <typename T, typename Enable = void>
struct RealFunctor;
template <typename T>
struct RealFunctor<T, Complex<T, Real<T>>> {
struct RealFunctor<T, Complex<T, dtype::Real<T>>> {
public:
RealFunctor(const T* input, Real<T>* output, int64_t numel)
RealFunctor(const T* input, dtype::Real<T>* output, int64_t numel)
: input_(input), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
......@@ -102,7 +58,7 @@ struct RealFunctor<T, Complex<T, Real<T>>> {
private:
const T* input_;
Real<T>* output_;
dtype::Real<T>* output_;
int64_t numel_;
};
......@@ -110,8 +66,8 @@ template <typename T, typename Enable = void>
struct ImagFunctor;
template <typename T>
struct ImagFunctor<T, Complex<T, Real<T>>> {
ImagFunctor(const T* input, Real<T>* output, int64_t numel)
struct ImagFunctor<T, Complex<T, dtype::Real<T>>> {
ImagFunctor(const T* input, dtype::Real<T>* output, int64_t numel)
: input_(input), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
......@@ -119,7 +75,7 @@ struct ImagFunctor<T, Complex<T, Real<T>>> {
}
const T* input_;
Real<T>* output_;
dtype::Real<T>* output_;
int64_t numel_;
};
......@@ -127,8 +83,8 @@ template <typename T, typename Enable = void>
struct AbsFunctor;
template <typename T>
struct AbsFunctor<T, Complex<T, Real<T>>> {
AbsFunctor(const T* input, Real<T>* output, int64_t numel)
struct AbsFunctor<T, Complex<T, dtype::Real<T>>> {
AbsFunctor(const T* input, dtype::Real<T>* output, int64_t numel)
: input_(input), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
......@@ -136,12 +92,12 @@ struct AbsFunctor<T, Complex<T, Real<T>>> {
}
const T* input_;
Real<T>* output_;
dtype::Real<T>* output_;
int64_t numel_;
};
template <typename T>
struct AbsFunctor<T, NoComplex<T, Real<T>>> {
struct AbsFunctor<T, NoComplex<T, dtype::Real<T>>> {
AbsFunctor(const T* input, T* output, int64_t numel)
: input_(input), output_(output), numel_(numel) {}
......@@ -203,7 +159,10 @@ struct AbsGradCUDAFunctor<phi::dtype::complex<double>> {
template <typename T>
struct AbsGradFunctor {
AbsGradFunctor(const Real<T>* dout, const T* x, T* output, int64_t numel)
AbsGradFunctor(const dtype::Real<T>* dout,
const T* x,
T* output,
int64_t numel)
: dout_(dout), x_(x), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
......@@ -214,7 +173,7 @@ struct AbsGradFunctor {
}
}
const Real<T>* dout_;
const dtype::Real<T>* dout_;
const T* x_;
T* output_;
int64_t numel_;
......@@ -334,8 +293,8 @@ template <typename T, typename Enable = void>
struct RealToComplexFunctor;
template <typename T>
struct RealToComplexFunctor<T, Complex<T, Real<T>>> {
RealToComplexFunctor(const Real<T>* input, T* output, int64_t numel)
struct RealToComplexFunctor<T, Complex<T, dtype::Real<T>>> {
RealToComplexFunctor(const dtype::Real<T>* input, T* output, int64_t numel)
: input_(input), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
......@@ -343,7 +302,7 @@ struct RealToComplexFunctor<T, Complex<T, Real<T>>> {
output_[idx].imag = 0;
}
const Real<T>* input_;
const dtype::Real<T>* input_;
T* output_;
int64_t numel_;
};
......@@ -352,8 +311,8 @@ template <typename T, typename Enable = void>
struct ImagToComplexFunctor;
template <typename T>
struct ImagToComplexFunctor<T, Complex<T, Real<T>>> {
ImagToComplexFunctor(const Real<T>* input, T* output, int64_t numel)
struct ImagToComplexFunctor<T, Complex<T, dtype::Real<T>>> {
ImagToComplexFunctor(const dtype::Real<T>* input, T* output, int64_t numel)
: input_(input), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
......@@ -361,7 +320,7 @@ struct ImagToComplexFunctor<T, Complex<T, Real<T>>> {
output_[idx].imag = input_[idx];
}
const Real<T>* input_;
const dtype::Real<T>* input_;
T* output_;
int64_t numel_;
};
......@@ -370,9 +329,9 @@ template <typename T, typename Enable = void>
struct RealImagToComplexFunctor;
template <typename T>
struct RealImagToComplexFunctor<T, Complex<T, Real<T>>> {
RealImagToComplexFunctor(const Real<T>* input_real,
const Real<T>* input_imag,
struct RealImagToComplexFunctor<T, Complex<T, dtype::Real<T>>> {
RealImagToComplexFunctor(const dtype::Real<T>* input_real,
const dtype::Real<T>* input_imag,
T* output,
int64_t numel)
: input_real_(input_real),
......@@ -385,8 +344,8 @@ struct RealImagToComplexFunctor<T, Complex<T, Real<T>>> {
output_[idx].imag = input_imag_[idx];
}
const Real<T>* input_real_;
const Real<T>* input_imag_;
const dtype::Real<T>* input_real_;
const dtype::Real<T>* input_imag_;
T* output_;
int64_t numel_;
};
......@@ -423,8 +382,8 @@ struct AngleFunctor;
// angle function for complex
template <typename T>
struct AngleFunctor<T, phi::funcs::Complex<T, phi::funcs::Real<T>>> {
AngleFunctor(const T* input, phi::funcs::Real<T>* output, int64_t numel)
struct AngleFunctor<T, phi::funcs::Complex<T, dtype::Real<T>>> {
AngleFunctor(const T* input, dtype::Real<T>* output, int64_t numel)
: input_(input), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
......@@ -432,13 +391,13 @@ struct AngleFunctor<T, phi::funcs::Complex<T, phi::funcs::Real<T>>> {
}
const T* input_;
phi::funcs::Real<T>* output_;
dtype::Real<T>* output_;
int64_t numel_;
};
// angle function for real
template <typename T>
struct AngleFunctor<T, phi::funcs::NoComplex<T, phi::funcs::Real<T>>> {
struct AngleFunctor<T, phi::funcs::NoComplex<T, dtype::Real<T>>> {
AngleFunctor(const T* input, T* output, int64_t numel)
: input_(input), output_(output), numel_(numel) {}
......@@ -456,25 +415,22 @@ struct AngleGradFunctor;
// angle grad for complex
template <typename T>
struct AngleGradFunctor<T, phi::funcs::Complex<T, phi::funcs::Real<T>>> {
AngleGradFunctor(const phi::funcs::Real<T>* dout,
const T* x,
T* dx,
int64_t numel)
struct AngleGradFunctor<T, phi::funcs::Complex<T, dtype::Real<T>>> {
AngleGradFunctor(const dtype::Real<T>* dout, const T* x, T* dx, int64_t numel)
: dout_(dout), x_(x), dx_(dx), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
if (x_[idx] == T(0)) {
dx_[idx] = T(0);
} else {
const phi::funcs::Real<T> r_square =
const phi::dtype::Real<T> r_square =
x_[idx].real * x_[idx].real + x_[idx].imag * x_[idx].imag;
dx_[idx] = T(-dout_[idx] * x_[idx].imag / r_square,
dout_[idx] * x_[idx].real / r_square);
}
}
const phi::funcs::Real<T>* dout_;
const phi::dtype::Real<T>* dout_;
const T* x_;
T* dx_;
int64_t numel_;
......@@ -482,16 +438,13 @@ struct AngleGradFunctor<T, phi::funcs::Complex<T, phi::funcs::Real<T>>> {
// angle grad for real
template <typename T>
struct AngleGradFunctor<T, phi::funcs::NoComplex<T, phi::funcs::Real<T>>> {
AngleGradFunctor(const phi::funcs::Real<T>* dout,
const T* x,
T* dx,
int64_t numel)
struct AngleGradFunctor<T, phi::funcs::NoComplex<T, dtype::Real<T>>> {
AngleGradFunctor(const dtype::Real<T>* dout, const T* x, T* dx, int64_t numel)
: dout_(dout), x_(x), dx_(dx), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const { dx_[idx] = 0; }
const phi::funcs::Real<T>* dout_;
const dtype::Real<T>* dout_;
const T* x_;
T* dx_;
int64_t numel_;
......
......@@ -27,14 +27,14 @@ template <typename T, typename Enable = void>
struct CudaAbsFunctor;
template <typename T>
struct CudaAbsFunctor<T, phi::funcs::Complex<T, phi::funcs::Real<T>>> {
__device__ __forceinline__ phi::funcs::Real<T> operator()(const T x) const {
struct CudaAbsFunctor<T, phi::funcs::Complex<T, phi::dtype::Real<T>>> {
__device__ __forceinline__ phi::dtype::Real<T> operator()(const T x) const {
return abs(x);
}
};
template <typename T>
struct CudaAbsFunctor<T, phi::funcs::NoComplex<T, phi::funcs::Real<T>>> {
struct CudaAbsFunctor<T, phi::funcs::NoComplex<T, phi::dtype::Real<T>>> {
__device__ __forceinline__ T operator()(const T x) const {
return std::abs(x);
}
......@@ -42,12 +42,12 @@ struct CudaAbsFunctor<T, phi::funcs::NoComplex<T, phi::funcs::Real<T>>> {
template <typename T, typename Context>
void AbsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
ctx.template Alloc<phi::funcs::Real<T>>(out);
ctx.template Alloc<phi::dtype::Real<T>>(out);
std::vector<const DenseTensor*> ins = {&x};
std::vector<DenseTensor*> outs = {out};
auto functor = CudaAbsFunctor<T>();
funcs::ElementwiseKernel<phi::funcs::Real<T>>(ctx, ins, &outs, functor);
funcs::ElementwiseKernel<phi::dtype::Real<T>>(ctx, ins, &outs, functor);
}
} // namespace phi
......
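The two CudaAbsFunctor specializations above encode the same rule as the Real<T> trait: abs of a complex value produces the underlying real type, while abs of a real value produces the value's own type. A standalone analogue using std::complex (illustration only, not the Paddle functors):

#include <cmath>
#include <complex>

// Result-type mapping analogous to phi::dtype::Real<T>: the underlying real type for
// complex inputs, identity otherwise.
template <typename T>
struct AbsResult {
  using type = T;
};
template <typename R>
struct AbsResult<std::complex<R>> {
  using type = R;
};

template <typename T>
typename AbsResult<T>::type AbsValue(const T& x) {
  return std::abs(x);  // std::abs(std::complex<R>) returns R; std::abs(R) returns R
}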
......@@ -38,11 +38,15 @@ PD_REGISTER_KERNEL(real,
ALL_LAYOUT,
phi::RealKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
PD_REGISTER_KERNEL(imag,
GPU,
ALL_LAYOUT,
phi::ImagKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
......@@ -47,7 +47,7 @@ void AbsGradKernel(const Context& ctx,
const DenseTensor& dout,
DenseTensor* dx) {
auto numel = dout.numel();
auto* dout_data = dout.data<phi::funcs::Real<T>>();
auto* dout_data = dout.data<phi::dtype::Real<T>>();
auto* x_data = x.data<T>();
ctx.template Alloc<T>(dx, static_cast<size_t>(numel * sizeof(T)));
......
......@@ -24,7 +24,7 @@ void RealGradKernel(const Context& dev_ctx,
const DenseTensor& dout,
DenseTensor* dx) {
auto numel = dout.numel();
auto* dout_data = dout.data<phi::funcs::Real<T>>();
auto* dout_data = dout.data<phi::dtype::Real<T>>();
auto* dx_data =
dev_ctx.template Alloc<T>(dx, static_cast<size_t>(numel * sizeof(T)));
......@@ -38,7 +38,7 @@ void ImagGradKernel(const Context& dev_ctx,
const DenseTensor& dout,
DenseTensor* dx) {
auto numel = dout.numel();
auto* dout_data = dout.data<phi::funcs::Real<T>>();
auto* dout_data = dout.data<phi::dtype::Real<T>>();
auto* dx_data =
dev_ctx.template Alloc<T>(dx, static_cast<size_t>(numel * sizeof(T)));
......
......@@ -39,8 +39,8 @@ void RealKernel(const Context& dev_ctx,
DenseTensor* out) {
auto numel = x.numel();
auto* x_data = x.data<T>();
auto* out_data = dev_ctx.template Alloc<phi::funcs::Real<T>>(
out, static_cast<size_t>(numel * sizeof(phi::funcs::Real<T>)));
auto* out_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(
out, static_cast<size_t>(numel * sizeof(phi::dtype::Real<T>)));
phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
phi::funcs::RealFunctor<T> functor(x_data, out_data, numel);
......@@ -53,8 +53,8 @@ void ImagKernel(const Context& dev_ctx,
DenseTensor* out) {
auto numel = x.numel();
auto* x_data = x.data<T>();
auto* out_data = dev_ctx.template Alloc<phi::funcs::Real<T>>(
out, static_cast<size_t>(numel * sizeof(phi::funcs::Real<T>)));
auto* out_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(
out, static_cast<size_t>(numel * sizeof(phi::dtype::Real<T>)));
phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
phi::funcs::ImagFunctor<T> functor(x_data, out_data, numel);
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/phi/api/ext/exception.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/type_traits.h"
namespace phi {
namespace tests {
......@@ -71,5 +72,20 @@ TEST(DataType, OStream) {
}
}
TEST(TypeTraits, Complex) {
EXPECT_EQ(phi::dtype::ToReal(phi::DataType::COMPLEX64),
phi::DataType::FLOAT32);
EXPECT_EQ(phi::dtype::ToReal(phi::DataType::COMPLEX128),
phi::DataType::FLOAT64);
EXPECT_EQ(phi::dtype::ToReal(phi::DataType::FLOAT32), phi::DataType::FLOAT32);
EXPECT_EQ(phi::dtype::ToComplex(phi::DataType::FLOAT32),
phi::DataType::COMPLEX64);
EXPECT_EQ(phi::dtype::ToComplex(phi::DataType::FLOAT64),
phi::DataType::COMPLEX128);
EXPECT_EQ(phi::dtype::ToComplex(phi::DataType::COMPLEX64),
phi::DataType::COMPLEX64);
}
} // namespace tests
} // namespace phi