Unverified commit 0764fda2, authored by Chen Weihang, committed by GitHub

[Phi] Unify complex type trait and fix real imag bug (#40036)

* unify complex type trait and fix real imag bug

* add unittest for type traits
Parent b4d931e8
@@ -36,8 +36,8 @@ class AngleKernel : public framework::OpKernel<T> {
     auto numel = x->numel();
     auto* x_data = x->data<T>();
-    auto* out_data = out->mutable_data<phi::funcs::Real<T>>(
-        context.GetPlace(), size_t(x->numel() * sizeof(phi::funcs::Real<T>)));
+    auto* out_data = out->mutable_data<phi::dtype::Real<T>>(
+        context.GetPlace(), size_t(x->numel() * sizeof(phi::dtype::Real<T>)));
     auto& dev_ctx = context.template device_context<DeviceContext>();
     platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
@@ -57,7 +57,7 @@ class AngleGradKernel : public framework::OpKernel<T> {
         ctx.Output<framework::Tensor>(framework::GradVarName("X"));
     auto numel = d_out->numel();
-    auto* dout_data = d_out->data<phi::funcs::Real<T>>();
+    auto* dout_data = d_out->data<phi::dtype::Real<T>>();
     auto* x_data = x->data<T>();
     auto* dx_data = d_x->mutable_data<T>(
         ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));
...
@@ -87,19 +87,19 @@ void LapackEig(Tensor* input, Tensor* values, Tensor* vectors, int info,
   int values_stride = values->dims()[values->dims().size() - 1];
   Tensor rwork;
-  phi::funcs::Real<T>* rwork_data = nullptr;
+  phi::dtype::Real<T>* rwork_data = nullptr;
   rwork.Resize(phi::make_ddim({lda * 2}));
-  rwork_data = rwork.mutable_data<phi::funcs::Real<T>>(context.GetPlace());
+  rwork_data = rwork.mutable_data<phi::dtype::Real<T>>(context.GetPlace());
   // call lapackEig once to compute the size of work;
   T computed_work_size;
-  phi::funcs::lapackEig<T, phi::funcs::Real<T>>(
+  phi::funcs::lapackEig<T, phi::dtype::Real<T>>(
       jobvl, jobvr, order, input_data, lda, values_data, lvector_data, ldvl,
       rvector_data, ldvr, &computed_work_size, lwork, rwork_data, &info);
   lwork = std::max<int>(
-      1, static_cast<int>(phi::funcs::Real<T>(computed_work_size)));
+      1, static_cast<int>(phi::dtype::Real<T>(computed_work_size)));
   Tensor work;
   work.Resize(phi::make_ddim({lwork}));
   T* work_data = work.mutable_data<T>(context.GetPlace());
@@ -109,7 +109,7 @@ void LapackEig(Tensor* input, Tensor* values, Tensor* vectors, int info,
     T* current_values = &values_data[i * values_stride];
     T* current_rvectors = &rvector_data[i * matrix_stride];
-    phi::funcs::lapackEig<T, phi::funcs::Real<T>>(
+    phi::funcs::lapackEig<T, phi::dtype::Real<T>>(
         jobvl, jobvr, order, current_matrix, lda, current_values, lvector_data,
         ldvl, current_rvectors, ldvr, work_data, lwork, rwork_data, &info);
     PADDLE_ENFORCE_EQ(
@@ -207,23 +207,23 @@ class EigKernel : public framework::OpKernel<T> {
       origin_dim.push_back(last_item * 2);
       framework::DDim big_dim = phi::make_ddim(origin_dim);
-      real_values.mutable_data<phi::funcs::Real<T>>(big_dim,
+      real_values.mutable_data<phi::dtype::Real<T>>(big_dim,
                                                     context.GetPlace());
-      real_vectors.mutable_data<phi::funcs::Real<T>>(x->dims(),
+      real_vectors.mutable_data<phi::dtype::Real<T>>(x->dims(),
                                                      context.GetPlace());
-      ApplyEigKernel<DeviceContext, phi::funcs::Real<T>>(
+      ApplyEigKernel<DeviceContext, phi::dtype::Real<T>>(
           *x, &real_values, &real_vectors, context);
       auto dito = math::DeviceIndependenceTensorOperations<
-          DeviceContext, phi::funcs::Real<T>, Tout>(context);
+          DeviceContext, phi::dtype::Real<T>, Tout>(context);
       // 1. extract real part & imag part from real_values
       Tensor real_part = dito.Slice(real_values, {-1}, {0}, {order});
       Tensor imag_part = dito.Slice(real_values, {-1}, {order}, {order * 2});
       // 2. construct complex values
-      auto* real_part_data = real_part.data<phi::funcs::Real<T>>();
-      auto* imag_part_data = imag_part.data<phi::funcs::Real<T>>();
+      auto* real_part_data = real_part.data<phi::dtype::Real<T>>();
+      auto* imag_part_data = imag_part.data<phi::dtype::Real<T>>();
       int out_values_numel = out_values->numel();
       platform::ForRange<DeviceContext> for_range(
           context.template device_context<DeviceContext>(), out_values_numel);
@@ -236,7 +236,7 @@ class EigKernel : public framework::OpKernel<T> {
       Tensor real_vector_trans = dito.Transpose(real_vectors);
       Tensor out_vectors_trans;
      out_vectors_trans.mutable_data<Tout>(x->dims(), context.GetPlace());
-      ConstructComplexVectors<phi::funcs::Real<T>, Tout>(
+      ConstructComplexVectors<phi::dtype::Real<T>, Tout>(
          &out_vectors_trans, *out_values, real_vector_trans, context,
          batch_count, order);
       TransposeTwoAxis<DeviceContext, Tout>(out_vectors_trans, out_vectors,
@@ -272,7 +272,7 @@ void ComputeBackwardForComplexInput(
   // turn diag_unsqueezed into complex
   auto numel = diag_unsqueezed.numel();
   Tensor diag_unsqueezed_complex;
-  auto* data_diag_un = diag_unsqueezed.data<phi::funcs::Real<Tout>>();
+  auto* data_diag_un = diag_unsqueezed.data<phi::dtype::Real<Tout>>();
   auto* data_diag_un_com = diag_unsqueezed_complex.mutable_data<Tout>(
       diag_unsqueezed.dims(), context.GetPlace(),
       static_cast<size_t>(numel * sizeof(Tout)));
...
@@ -40,7 +40,7 @@ template <typename DeviceContext, typename T>
 class EighGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    using ValueType = phi::funcs::Real<T>;
+    using ValueType = phi::dtype::Real<T>;
     auto& x_grad = *ctx.Output<framework::Tensor>(framework::GradVarName("X"));
     x_grad.mutable_data<T>(ctx.GetPlace());
     auto& output_w = *ctx.Input<Tensor>("Eigenvalues");
...
@@ -48,7 +48,7 @@ struct PaddleComplex<
 template <typename T>
 using PaddleCType = typename PaddleComplex<T>::type;
 template <typename T>
-using Real = typename phi::funcs::Real<T>;
+using Real = typename phi::dtype::Real<T>;
 static void SpiltBatchSquareMatrix(const Tensor& input,
                                    std::vector<Tensor>* output) {
@@ -144,7 +144,7 @@ LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input,
                     required_work_mem, work_mem));
   int64_t rwork_mem = rwork->memory_size();
-  int64_t required_rwork_mem = (n_dim << 1) * sizeof(phi::funcs::Real<T>);
+  int64_t required_rwork_mem = (n_dim << 1) * sizeof(phi::dtype::Real<T>);
   PADDLE_ENFORCE_GE(
       rwork_mem, required_rwork_mem,
       platform::errors::InvalidArgument(
@@ -154,11 +154,11 @@ LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input,
           required_rwork_mem, rwork_mem));
   int info = 0;
-  phi::funcs::lapackEig<T, phi::funcs::Real<T>>(
+  phi::funcs::lapackEig<T, phi::dtype::Real<T>>(
       'N', 'N', static_cast<int>(n_dim), a.template data<T>(),
       static_cast<int>(n_dim), output->template data<T>(), NULL, 1, NULL, 1,
       work->template data<T>(), static_cast<int>(work_mem / sizeof(T)),
-      rwork->template data<phi::funcs::Real<T>>(), &info);
+      rwork->template data<phi::dtype::Real<T>>(), &info);
   std::string name = "framework::platform::dynload::cgeev_";
   if (framework::TransToProtoVarType(input.dtype()) ==
@@ -188,10 +188,10 @@ class EigvalsKernel : public framework::OpKernel<T> {
     // query workspace size
     T qwork;
     int info;
-    phi::funcs::lapackEig<T, phi::funcs::Real<T>>(
+    phi::funcs::lapackEig<T, phi::dtype::Real<T>>(
        'N', 'N', static_cast<int>(n_dim), input_matrices[0].template data<T>(),
        static_cast<int>(n_dim), NULL, NULL, 1, NULL, 1, &qwork, -1,
-        static_cast<Real<T>*>(NULL), &info);
+        static_cast<phi::dtype::Real<T>*>(NULL), &info);
     int64_t lwork = static_cast<int64_t>(qwork);
     Tensor work, rwork;
@@ -208,7 +208,7 @@ class EigvalsKernel : public framework::OpKernel<T> {
     }
     if (framework::IsComplexType(
             framework::TransToProtoVarType(input->dtype()))) {
-      rwork.mutable_data<phi::funcs::Real<T>>(phi::make_ddim({n_dim << 1}),
+      rwork.mutable_data<phi::dtype::Real<T>>(phi::make_ddim({n_dim << 1}),
                                               ctx.GetPlace());
     }
...
@@ -83,7 +83,7 @@ DECLARE_INPLACE_OP_INFERER(ImagGradOpInplaceInferer,
 }  // namespace paddle
 DELCARE_INFER_SHAPE_FUNCTOR(imag, ImagInferShapeFunctor,
-                            PT_INFER_META(phi::UnchangedInferMeta));
+                            PT_INFER_META(phi::RealAndImagInferMeta));
 namespace ops = paddle::operators;
...
@@ -46,7 +46,7 @@ template <typename DeviceContext, typename T>
 class LstsqCPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    using ValueType = phi::funcs::Real<T>;
+    using ValueType = phi::dtype::Real<T>;
     const Tensor& x = *context.Input<Tensor>("X");
     auto y = context.Input<Tensor>("Y");
@@ -169,7 +169,7 @@ class LstsqCPUKernel : public framework::OpKernel<T> {
                          &rank_32, &wkopt, lwork, &rwkopt, &info);
     }
-    lwork = std::max<int>(1, static_cast<int>(phi::funcs::Real<T>(wkopt)));
+    lwork = std::max<int>(1, static_cast<int>(phi::dtype::Real<T>(wkopt)));
     Tensor work;
     work.Resize(phi::make_ddim({lwork}));
     T* work_data = work.mutable_data<T>(context.GetPlace());
...
@@ -63,7 +63,7 @@ struct MatrixEighFunctor<platform::CPUDeviceContext, T> {
   void operator()(const framework::ExecutionContext &ctx, const Tensor &input,
                   Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower,
                   bool has_vectors) {
-    using ValueType = phi::funcs::Real<T>;
+    using ValueType = phi::dtype::Real<T>;
     auto *out_value = eigen_values->mutable_data<ValueType>(ctx.GetPlace());
     auto dito =
@@ -123,7 +123,7 @@ struct MatrixEighFunctor<platform::CPUDeviceContext, T> {
     for (auto i = 0; i < batch_size; i++) {
       auto *value_data = out_value + i * values_stride;
       auto *input_data = input_vector + i * vector_stride;
-      phi::funcs::lapackEigh<T, phi::funcs::Real<T>>(
+      phi::funcs::lapackEigh<T, phi::dtype::Real<T>>(
          jobz, uplo, n, input_data, lda, value_data, work_data, lwork,
          rwork_data, lrwork, iwork_data, liwork, &info);
       CheckEighResult(i, info);
@@ -151,7 +151,7 @@ struct MatrixEighFunctor<platform::CUDADeviceContext, T> {
   void operator()(const framework::ExecutionContext &ctx, const Tensor &input,
                   Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower,
                   bool has_vectors) {
-    using ValueType = phi::funcs::Real<T>;
+    using ValueType = phi::dtype::Real<T>;
     auto *out_value = eigen_values->mutable_data<ValueType>(ctx.GetPlace());
     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
@@ -233,7 +233,7 @@ struct MatrixEighFunctor<platform::CUDADeviceContext, T> {
     }
   }
-  using ValueType = phi::funcs::Real<T>;
+  using ValueType = phi::dtype::Real<T>;
   inline void EvdBuffer(cusolverDnHandle_t handle, cusolverEigMode_t jobz,
                         cublasFillMode_t uplo, int n, const T *A, int lda,
                         const ValueType *W, int *lwork) const;
...
@@ -115,7 +115,7 @@ static __global__ void InclusiveScanInnerDimCUDAKernel(const T *x, T *y,
                                                        size_t num_rows,
                                                        size_t row_size, T init,
                                                        BinaryOp op) {
-  using RealT = phi::funcs::Real<T>;
+  using RealT = phi::dtype::Real<T>;
   constexpr auto kSharedBufferSize =
       framework::IsComplex<T>::value ? 4 * kThreadNumX : 2 * kThreadNumX;
   __shared__ RealT sbuf[kThreadNumY][kSharedBufferSize];
...
@@ -56,13 +56,13 @@ class QrGPUKernel : public framework::OpKernel<T> {
     int tau_stride = min_mn;
     if (compute_q) {
-      q.mutable_data<phi::funcs::Real<T>>(
+      q.mutable_data<phi::dtype::Real<T>>(
           context.GetPlace(),
-          size_t(batch_size * m * k * sizeof(phi::funcs::Real<T>)));
+          size_t(batch_size * m * k * sizeof(phi::dtype::Real<T>)));
     }
-    r.mutable_data<phi::funcs::Real<T>>(
+    r.mutable_data<phi::dtype::Real<T>>(
         context.GetPlace(),
-        size_t(batch_size * k * n * sizeof(phi::funcs::Real<T>)));
+        size_t(batch_size * k * n * sizeof(phi::dtype::Real<T>)));
     auto dito =
         math::DeviceIndependenceTensorOperations<platform::CUDADeviceContext,
@@ -71,9 +71,9 @@ class QrGPUKernel : public framework::OpKernel<T> {
     // Note: allocate temporary tensors because of lacking in-place operatios.
     // Prepare qr
     Tensor qr;
-    qr.mutable_data<phi::funcs::Real<T>>(
+    qr.mutable_data<phi::dtype::Real<T>>(
        context.GetPlace(),
-        size_t(batch_size * m * n * sizeof(phi::funcs::Real<T>)));
+        size_t(batch_size * m * n * sizeof(phi::dtype::Real<T>)));
     // BatchedGeqrf performs computation in-place and 'qr' must be a copy of
     // input
     paddle::framework::TensorCopy(x, context.GetPlace(), &qr);
@@ -126,7 +126,7 @@ class QrGPUKernel : public framework::OpKernel<T> {
     for (int i = 0; i < batch_size; ++i) {
       memory::Copy(dev_ctx.GetPlace(), (new_qr_data + i * new_qr_stride),
                    dev_ctx.GetPlace(), (qr_data + i * qr_stride),
-                   qr_stride * sizeof(phi::funcs::Real<T>),
+                   qr_stride * sizeof(phi::dtype::Real<T>),
                    dev_ctx.stream());
     }
     BatchedOrgqr<platform::CUDADeviceContext, T>(
...
@@ -74,19 +74,19 @@ class QrCPUKernel : public framework::OpKernel<T> {
     int q_stride = m * k;
     int r_stride = k * n;
-    auto* x_data = x.data<phi::funcs::Real<T>>();
+    auto* x_data = x.data<phi::dtype::Real<T>>();
     T* q_data = nullptr;
     if (compute_q) {
-      q_data = q.mutable_data<phi::funcs::Real<T>>(
+      q_data = q.mutable_data<phi::dtype::Real<T>>(
          context.GetPlace(),
-          size_t(batch_size * m * k * sizeof(phi::funcs::Real<T>)));
+          size_t(batch_size * m * k * sizeof(phi::dtype::Real<T>)));
       memset(q_data, 0,
-             size_t(batch_size * m * k * sizeof(phi::funcs::Real<T>)));
+             size_t(batch_size * m * k * sizeof(phi::dtype::Real<T>)));
     }
-    auto* r_data = r.mutable_data<phi::funcs::Real<T>>(
+    auto* r_data = r.mutable_data<phi::dtype::Real<T>>(
        context.GetPlace(),
-        size_t(batch_size * k * n * sizeof(phi::funcs::Real<T>)));
-    memset(r_data, 0, size_t(batch_size * k * n * sizeof(phi::funcs::Real<T>)));
+        size_t(batch_size * k * n * sizeof(phi::dtype::Real<T>)));
+    memset(r_data, 0, size_t(batch_size * k * n * sizeof(phi::dtype::Real<T>)));
     // Implement QR by calling Eigen
     for (int i = 0; i < batch_size; ++i) {
@@ -142,7 +142,7 @@ class QrGradKernel : public framework::OpKernel<T> {
     // Use a different name dA instead of dX
     framework::Tensor& dA =
         *ctx.Output<framework::Tensor>(framework::GradVarName("X"));
-    dA.mutable_data<phi::funcs::Real<T>>(ctx.GetPlace());
+    dA.mutable_data<phi::dtype::Real<T>>(ctx.GetPlace());
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     phi::funcs::SetConstant<DeviceContext, T>()(dev_ctx, &dA, T(0));
@@ -224,7 +224,7 @@ class QrGradKernel : public framework::OpKernel<T> {
     } else {
       // If m < n for input matrices A, we partition A = [X|Y] and R = [U|V]
       // Calculate dX and dY individually and concatenate them to get dA
-      dA.mutable_data<phi::funcs::Real<T>>(ctx.GetPlace());
+      dA.mutable_data<phi::dtype::Real<T>>(ctx.GetPlace());
       auto Y = dito.Slice(A, {-1}, {m}, {n});
       auto U = dito.Slice(R, {-1}, {0}, {m});
...
@@ -83,7 +83,7 @@ DECLARE_INPLACE_OP_INFERER(RealGradOpInplaceInferer,
 }  // namespace paddle
 DELCARE_INFER_SHAPE_FUNCTOR(real, RealInferShapeFunctor,
-                            PT_INFER_META(phi::UnchangedInferMeta));
+                            PT_INFER_META(phi::RealAndImagInferMeta));
 namespace ops = paddle::operators;
...
@@ -105,7 +105,7 @@ struct RealMulComplexFunctor {
                           "The image part of y must to be 0"
                           "but got [%d]",
                           y.imag));
-    return platform::complex<phi::funcs::Real<T>>(x.real * y.real,
+    return platform::complex<phi::dtype::Real<T>>(x.real * y.real,
                                                   x.imag * y.real);
   }
 };
@@ -391,11 +391,11 @@ struct DeviceIndependenceTensorOperations {
   // batch_diag for CPU only
   Tensor BatchDiag(const Tensor& x, int batch) {
     Tensor out;
-    auto* x_data = x.data<phi::funcs::Real<T>>();
+    auto* x_data = x.data<phi::dtype::Real<T>>();
     auto numel = x.numel();
-    auto* out_data = out.mutable_data<phi::funcs::Real<T>>(
+    auto* out_data = out.mutable_data<phi::dtype::Real<T>>(
         x.dims(), context.GetPlace(),
-        static_cast<size_t>(numel * sizeof(phi::funcs::Real<T>)));
+        static_cast<size_t>(numel * sizeof(phi::dtype::Real<T>)));
     auto x_dims = x.dims();
     int num_dims = x_dims.size();
@@ -661,9 +661,9 @@ struct DeviceIndependenceTensorOperations {
   Tensor Real(const Tensor& x) {
     Tensor out;
     auto numel = x.numel();
-    auto* out_data = out.mutable_data<phi::funcs::Real<T>>(
+    auto* out_data = out.mutable_data<phi::dtype::Real<T>>(
        x.dims(), context.GetPlace(),
-        static_cast<size_t>(numel * sizeof(phi::funcs::Real<T>)));
+        static_cast<size_t>(numel * sizeof(phi::dtype::Real<T>)));
     auto* x_data = x.data<T>();
     auto for_range = GetForRange(numel);
     phi::funcs::RealFunctor<T> functor(x_data, out_data, numel);
...
@@ -46,14 +46,14 @@ class SvdCPUKernel : public framework::OpKernel<T> {
     int col_u = full ? rows : k;
     int col_v = full ? cols : k;
     int batches = numel / (rows * cols);
-    auto* U_out = U->mutable_data<phi::funcs::Real<T>>(
+    auto* U_out = U->mutable_data<phi::dtype::Real<T>>(
         context.GetPlace(),
-        size_t(batches * rows * col_u * sizeof(phi::funcs::Real<T>)));
-    auto* VH_out = VH->mutable_data<phi::funcs::Real<T>>(
+        size_t(batches * rows * col_u * sizeof(phi::dtype::Real<T>)));
+    auto* VH_out = VH->mutable_data<phi::dtype::Real<T>>(
        context.GetPlace(),
-        size_t(batches * col_v * cols * sizeof(phi::funcs::Real<T>)));
-    auto* S_out = S->mutable_data<phi::funcs::Real<T>>(
-        context.GetPlace(), size_t(batches * k * sizeof(phi::funcs::Real<T>)));
+        size_t(batches * col_v * cols * sizeof(phi::dtype::Real<T>)));
+    auto* S_out = S->mutable_data<phi::dtype::Real<T>>(
+        context.GetPlace(), size_t(batches * k * sizeof(phi::dtype::Real<T>)));
     /*SVD Use the Eigen Library*/
     math::BatchSvd<T>(x_data, U_out, VH_out, S_out, rows, cols, batches, full);
   }
...
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/data_type.h"
namespace phi {
namespace dtype {
template <bool B, typename T>
struct cond {
static constexpr bool value = B;
using type = T;
};
template <bool B, typename TrueF, typename FalseF>
struct eval_if {
using type = typename TrueF::type;
};
template <typename TrueF, typename FalseF>
struct eval_if<false, TrueF, FalseF> {
using type = typename FalseF::type;
};
template <bool B, typename T, typename F>
using eval_if_t = typename eval_if<B, T, F>::type;
template <typename Head, typename... Tail>
struct select {
using type = eval_if_t<Head::value, Head, select<Tail...>>;
};
template <typename T>
struct select<T> {
using type = T;
};
template <bool B, typename T>
struct select<cond<B, T>> {
// last one had better be true!
static_assert(B, "No match select type!");
using type = T;
};
template <typename Head, typename... Tail>
using select_t = typename select<Head, Tail...>::type;
// runtime real and complex type conversion
template <typename T>
using Real = select_t<cond<std::is_same<T, complex<float>>::value, float>,
cond<std::is_same<T, complex<double>>::value, double>,
T>;
template <typename T>
using Complex = select_t<cond<std::is_same<T, float>::value, complex<float>>,
cond<std::is_same<T, double>::value, complex<double>>,
T>;
inline DataType ToReal(DataType dtype) {
switch (dtype) {
case phi::DataType::COMPLEX64:
return phi::DataType::FLOAT32;
case phi::DataType::COMPLEX128:
return phi::DataType::FLOAT64;
default:
return dtype;
}
}
inline DataType ToComplex(DataType dtype) {
switch (dtype) {
case phi::DataType::FLOAT32:
return phi::DataType::COMPLEX64;
case phi::DataType::FLOAT64:
return phi::DataType::COMPLEX128;
default:
return dtype;
}
}
} // namespace dtype
} // namespace phi
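
For readers of this change, a minimal usage sketch of the unified trait header above (standalone, assuming it lands at paddle/phi/common/type_traits.h as the #include lines elsewhere in this commit indicate; phi::dtype::complex comes from paddle/phi/common/complex.h):

#include <type_traits>
#include "paddle/phi/common/complex.h"      // phi::dtype::complex
#include "paddle/phi/common/type_traits.h"  // phi::dtype::Real / Complex / ToReal

// Compile-time mapping: Real<T> walks the select_t/cond chain; the first
// cond whose condition holds supplies the type, and the bare trailing T
// is the fallback for non-complex types.
static_assert(
    std::is_same<phi::dtype::Real<phi::dtype::complex<float>>, float>::value,
    "complex<float> -> float");
static_assert(std::is_same<phi::dtype::Real<double>, double>::value,
              "real types map to themselves");
static_assert(std::is_same<phi::dtype::Complex<float>,
                           phi::dtype::complex<float>>::value,
              "float -> complex<float>");

// The runtime DataType mapping mirrors the compile-time one.
bool CheckRuntimeMapping() {
  return phi::dtype::ToReal(phi::DataType::COMPLEX64) ==
             phi::DataType::FLOAT32 &&
         phi::dtype::ToComplex(phi::DataType::FLOAT64) ==
             phi::DataType::COMPLEX128;
}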
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <algorithm>
 #include <set>
 #include "paddle/phi/common/data_type.h"
+#include "paddle/phi/common/type_traits.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/phi/kernels/funcs/unfold_functor.h"
@@ -51,6 +52,12 @@ void UnchangedInferMetaCheckAxis(const MetaTensor& x,
   out->share_meta(x);
 }
+void RealAndImagInferMeta(const MetaTensor& x, MetaTensor* out) {
+  out->set_dims(x.dims());
+  out->set_dtype(dtype::ToReal(x.dtype()));
+  out->set_layout(x.layout());
+}
 void FlattenInferMeta(const MetaTensor& x,
                       int start_axis,
                       int stop_axis,
...
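
Why the dedicated infer-meta function matters here: UnchangedInferMeta shares the input's meta wholesale, so real/imag on a complex64 input were previously inferred to produce complex64 outputs; RealAndImagInferMeta keeps dims and layout but maps the dtype through ToReal. A self-contained sketch of the before/after behavior (DType and MetaInfo are hypothetical stand-ins, not the phi API):

#include <cassert>

// Hypothetical stand-ins for phi::DataType and phi::MetaTensor.
enum class DType { FLOAT32, FLOAT64, COMPLEX64, COMPLEX128 };
struct MetaInfo {
  DType dtype;
};

DType ToReal(DType d) {
  if (d == DType::COMPLEX64) return DType::FLOAT32;
  if (d == DType::COMPLEX128) return DType::FLOAT64;
  return d;  // real dtypes pass through unchanged
}

// Old behavior: the output meta is shared unchanged with the input.
MetaInfo UnchangedInferMeta(const MetaInfo& x) { return x; }

// New behavior: same dims/layout, but the dtype becomes its real counterpart.
MetaInfo RealAndImagInferMeta(const MetaInfo& x) { return {ToReal(x.dtype)}; }

int main() {
  MetaInfo x{DType::COMPLEX64};
  assert(UnchangedInferMeta(x).dtype == DType::COMPLEX64);  // the old bug
  assert(RealAndImagInferMeta(x).dtype == DType::FLOAT32);  // the fix
  return 0;
}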
@@ -39,6 +39,8 @@ void UnchangedInferMetaCheckAxis(const MetaTensor& x,
                                  int axis,
                                  MetaTensor* out);
+void RealAndImagInferMeta(const MetaTensor& x, MetaTensor* out);
 void FlattenInferMeta(const MetaTensor& x,
                       int start_axis,
                       int stop_axis,
...
@@ -25,9 +25,9 @@ template <typename T, typename Context>
 void AbsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
   auto numel = x.numel();
   auto* x_data = x.data<T>();
-  ctx.template Alloc<phi::funcs::Real<T>>(
-      out, size_t(x.numel() * sizeof(phi::funcs::Real<T>)));
-  auto* out_data = out->data<phi::funcs::Real<T>>();
+  ctx.template Alloc<phi::dtype::Real<T>>(
+      out, size_t(x.numel() * sizeof(phi::dtype::Real<T>)));
+  auto* out_data = out->data<phi::dtype::Real<T>>();
   phi::funcs::ForRange<Context> for_range(ctx, numel);
   phi::funcs::AbsFunctor<T> functor(x_data, out_data, numel);
...
@@ -37,11 +37,15 @@ PD_REGISTER_KERNEL(real,
                    ALL_LAYOUT,
                    phi::RealKernel,
                    phi::dtype::complex<float>,
-                   phi::dtype::complex<double>) {}
+                   phi::dtype::complex<double>) {
+  kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
+}
 PD_REGISTER_KERNEL(imag,
                    CPU,
                    ALL_LAYOUT,
                    phi::ImagKernel,
                    phi::dtype::complex<float>,
-                   phi::dtype::complex<double>) {}
+                   phi::dtype::complex<double>) {
+  kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
+}
@@ -20,56 +20,12 @@ limitations under the License. */
 #include <type_traits>
 #include "paddle/phi/common/complex.h"
+#include "paddle/phi/common/type_traits.h"
 #include "paddle/phi/core/hostdevice.h"
 namespace phi {
 namespace funcs {
-template <bool B, typename T>
-struct cond {
-  static constexpr bool value = B;
-  using type = T;
-};
-template <bool B, typename TrueF, typename FalseF>
-struct eval_if {
-  using type = typename TrueF::type;
-};
-template <typename TrueF, typename FalseF>
-struct eval_if<false, TrueF, FalseF> {
-  using type = typename FalseF::type;
-};
-template <bool B, typename T, typename F>
-using eval_if_t = typename eval_if<B, T, F>::type;
-template <typename Head, typename... Tail>
-struct select {
-  using type = eval_if_t<Head::value, Head, select<Tail...>>;
-};
-template <typename T>
-struct select<T> {
-  using type = T;
-};
-template <bool B, typename T>
-struct select<cond<B, T>> {
-  // last one had better be true!
-  static_assert(B, "No match select type!");
-  using type = T;
-};
-template <typename Head, typename... Tail>
-using select_t = typename select<Head, Tail...>::type;
-template <typename T>
-using Real =
-    select_t<cond<std::is_same<T, phi::dtype::complex<float>>::value, float>,
-             cond<std::is_same<T, phi::dtype::complex<double>>::value, double>,
-             T>;
 template <typename T, typename RealT>
 using Complex = typename std::enable_if<!std::is_same<T, RealT>::value>::type;
@@ -91,9 +47,9 @@ template <typename T, typename Enable = void>
 struct RealFunctor;
 template <typename T>
-struct RealFunctor<T, Complex<T, Real<T>>> {
+struct RealFunctor<T, Complex<T, dtype::Real<T>>> {
  public:
-  RealFunctor(const T* input, Real<T>* output, int64_t numel)
+  RealFunctor(const T* input, dtype::Real<T>* output, int64_t numel)
       : input_(input), output_(output), numel_(numel) {}
   HOSTDEVICE void operator()(int64_t idx) const {
@@ -102,7 +58,7 @@ struct RealFunctor<T, Complex<T, Real<T>>> {
  private:
   const T* input_;
-  Real<T>* output_;
+  dtype::Real<T>* output_;
   int64_t numel_;
 };
@@ -110,8 +66,8 @@ template <typename T, typename Enable = void>
 struct ImagFunctor;
 template <typename T>
-struct ImagFunctor<T, Complex<T, Real<T>>> {
-  ImagFunctor(const T* input, Real<T>* output, int64_t numel)
+struct ImagFunctor<T, Complex<T, dtype::Real<T>>> {
+  ImagFunctor(const T* input, dtype::Real<T>* output, int64_t numel)
       : input_(input), output_(output), numel_(numel) {}
   HOSTDEVICE void operator()(int64_t idx) const {
@@ -119,7 +75,7 @@ struct ImagFunctor<T, Complex<T, Real<T>>> {
   }
   const T* input_;
-  Real<T>* output_;
+  dtype::Real<T>* output_;
   int64_t numel_;
 };
@@ -127,8 +83,8 @@ template <typename T, typename Enable = void>
 struct AbsFunctor;
 template <typename T>
-struct AbsFunctor<T, Complex<T, Real<T>>> {
-  AbsFunctor(const T* input, Real<T>* output, int64_t numel)
+struct AbsFunctor<T, Complex<T, dtype::Real<T>>> {
+  AbsFunctor(const T* input, dtype::Real<T>* output, int64_t numel)
       : input_(input), output_(output), numel_(numel) {}
   HOSTDEVICE void operator()(int64_t idx) const {
@@ -136,12 +92,12 @@ struct AbsFunctor<T, Complex<T, Real<T>>> {
   }
   const T* input_;
-  Real<T>* output_;
+  dtype::Real<T>* output_;
   int64_t numel_;
 };
 template <typename T>
-struct AbsFunctor<T, NoComplex<T, Real<T>>> {
+struct AbsFunctor<T, NoComplex<T, dtype::Real<T>>> {
   AbsFunctor(const T* input, T* output, int64_t numel)
       : input_(input), output_(output), numel_(numel) {}
@@ -203,7 +159,10 @@ struct AbsGradCUDAFunctor<phi::dtype::complex<double>> {
 template <typename T>
 struct AbsGradFunctor {
-  AbsGradFunctor(const Real<T>* dout, const T* x, T* output, int64_t numel)
+  AbsGradFunctor(const dtype::Real<T>* dout,
+                 const T* x,
+                 T* output,
+                 int64_t numel)
       : dout_(dout), x_(x), output_(output), numel_(numel) {}
   HOSTDEVICE void operator()(int64_t idx) const {
@@ -214,7 +173,7 @@ struct AbsGradFunctor {
     }
   }
-  const Real<T>* dout_;
+  const dtype::Real<T>* dout_;
   const T* x_;
   T* output_;
   int64_t numel_;
@@ -334,8 +293,8 @@ template <typename T, typename Enable = void>
 struct RealToComplexFunctor;
 template <typename T>
-struct RealToComplexFunctor<T, Complex<T, Real<T>>> {
-  RealToComplexFunctor(const Real<T>* input, T* output, int64_t numel)
+struct RealToComplexFunctor<T, Complex<T, dtype::Real<T>>> {
+  RealToComplexFunctor(const dtype::Real<T>* input, T* output, int64_t numel)
       : input_(input), output_(output), numel_(numel) {}
   HOSTDEVICE void operator()(int64_t idx) const {
@@ -343,7 +302,7 @@ struct RealToComplexFunctor<T, Complex<T, Real<T>>> {
     output_[idx].imag = 0;
   }
-  const Real<T>* input_;
+  const dtype::Real<T>* input_;
   T* output_;
   int64_t numel_;
 };
@@ -352,8 +311,8 @@ template <typename T, typename Enable = void>
 struct ImagToComplexFunctor;
 template <typename T>
-struct ImagToComplexFunctor<T, Complex<T, Real<T>>> {
-  ImagToComplexFunctor(const Real<T>* input, T* output, int64_t numel)
+struct ImagToComplexFunctor<T, Complex<T, dtype::Real<T>>> {
+  ImagToComplexFunctor(const dtype::Real<T>* input, T* output, int64_t numel)
       : input_(input), output_(output), numel_(numel) {}
   HOSTDEVICE void operator()(int64_t idx) const {
@@ -361,7 +320,7 @@ struct ImagToComplexFunctor<T, Complex<T, Real<T>>> {
     output_[idx].imag = input_[idx];
   }
-  const Real<T>* input_;
+  const dtype::Real<T>* input_;
   T* output_;
   int64_t numel_;
 };
@@ -370,9 +329,9 @@ template <typename T, typename Enable = void>
 struct RealImagToComplexFunctor;
 template <typename T>
-struct RealImagToComplexFunctor<T, Complex<T, Real<T>>> {
-  RealImagToComplexFunctor(const Real<T>* input_real,
-                           const Real<T>* input_imag,
+struct RealImagToComplexFunctor<T, Complex<T, dtype::Real<T>>> {
+  RealImagToComplexFunctor(const dtype::Real<T>* input_real,
+                           const dtype::Real<T>* input_imag,
                            T* output,
                            int64_t numel)
       : input_real_(input_real),
@@ -385,8 +344,8 @@ struct RealImagToComplexFunctor<T, Complex<T, Real<T>>> {
     output_[idx].imag = input_imag_[idx];
   }
-  const Real<T>* input_real_;
-  const Real<T>* input_imag_;
+  const dtype::Real<T>* input_real_;
+  const dtype::Real<T>* input_imag_;
   T* output_;
   int64_t numel_;
 };
@@ -423,8 +382,8 @@ struct AngleFunctor;
 // angel function for complex
 template <typename T>
-struct AngleFunctor<T, phi::funcs::Complex<T, phi::funcs::Real<T>>> {
-  AngleFunctor(const T* input, phi::funcs::Real<T>* output, int64_t numel)
+struct AngleFunctor<T, phi::funcs::Complex<T, dtype::Real<T>>> {
+  AngleFunctor(const T* input, dtype::Real<T>* output, int64_t numel)
       : input_(input), output_(output), numel_(numel) {}
   HOSTDEVICE void operator()(int64_t idx) const {
@@ -432,13 +391,13 @@ struct AngleFunctor<T, phi::funcs::Complex<T, phi::funcs::Real<T>>> {
   }
   const T* input_;
-  phi::funcs::Real<T>* output_;
+  dtype::Real<T>* output_;
   int64_t numel_;
 };
 // angel function for real
 template <typename T>
-struct AngleFunctor<T, phi::funcs::NoComplex<T, phi::funcs::Real<T>>> {
+struct AngleFunctor<T, phi::funcs::NoComplex<T, dtype::Real<T>>> {
   AngleFunctor(const T* input, T* output, int64_t numel)
       : input_(input), output_(output), numel_(numel) {}
@@ -456,25 +415,22 @@ struct AngleGradFunctor;
 // angle grad for complex
 template <typename T>
-struct AngleGradFunctor<T, phi::funcs::Complex<T, phi::funcs::Real<T>>> {
-  AngleGradFunctor(const phi::funcs::Real<T>* dout,
-                   const T* x,
-                   T* dx,
-                   int64_t numel)
+struct AngleGradFunctor<T, phi::funcs::Complex<T, dtype::Real<T>>> {
+  AngleGradFunctor(const dtype::Real<T>* dout, const T* x, T* dx, int64_t numel)
       : dout_(dout), x_(x), dx_(dx), numel_(numel) {}
   HOSTDEVICE void operator()(int64_t idx) const {
     if (x_[idx] == T(0)) {
       dx_[idx] = T(0);
     } else {
-      const phi::funcs::Real<T> r_square =
+      const phi::dtype::Real<T> r_square =
           x_[idx].real * x_[idx].real + x_[idx].imag * x_[idx].imag;
       dx_[idx] = T(-dout_[idx] * x_[idx].imag / r_square,
                    dout_[idx] * x_[idx].real / r_square);
     }
   }
-  const phi::funcs::Real<T>* dout_;
+  const phi::dtype::Real<T>* dout_;
   const T* x_;
   T* dx_;
   int64_t numel_;
@@ -482,16 +438,13 @@ struct AngleGradFunctor<T, phi::funcs::Complex<T, phi::funcs::Real<T>>> {
 // angle grad for real
 template <typename T>
-struct AngleGradFunctor<T, phi::funcs::NoComplex<T, phi::funcs::Real<T>>> {
-  AngleGradFunctor(const phi::funcs::Real<T>* dout,
-                   const T* x,
-                   T* dx,
-                   int64_t numel)
+struct AngleGradFunctor<T, phi::funcs::NoComplex<T, dtype::Real<T>>> {
+  AngleGradFunctor(const dtype::Real<T>* dout, const T* x, T* dx, int64_t numel)
       : dout_(dout), x_(x), dx_(dx), numel_(numel) {}
   HOSTDEVICE void operator()(int64_t idx) const { dx_[idx] = 0; }
-  const phi::funcs::Real<T>* dout_;
+  const dtype::Real<T>* dout_;
   const T* x_;
   T* dx_;
   int64_t numel_;
...
@@ -27,14 +27,14 @@ template <typename T, typename Enable = void>
 struct CudaAbsFunctor;
 template <typename T>
-struct CudaAbsFunctor<T, phi::funcs::Complex<T, phi::funcs::Real<T>>> {
-  __device__ __forceinline__ phi::funcs::Real<T> operator()(const T x) const {
+struct CudaAbsFunctor<T, phi::funcs::Complex<T, phi::dtype::Real<T>>> {
+  __device__ __forceinline__ phi::dtype::Real<T> operator()(const T x) const {
     return abs(x);
   }
 };
 template <typename T>
-struct CudaAbsFunctor<T, phi::funcs::NoComplex<T, phi::funcs::Real<T>>> {
+struct CudaAbsFunctor<T, phi::funcs::NoComplex<T, phi::dtype::Real<T>>> {
   __device__ __forceinline__ T operator()(const T x) const {
     return std::abs(x);
   }
@@ -42,12 +42,12 @@ struct CudaAbsFunctor<T, phi::funcs::NoComplex<T, phi::funcs::Real<T>>> {
 template <typename T, typename Context>
 void AbsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
-  ctx.template Alloc<phi::funcs::Real<T>>(out);
+  ctx.template Alloc<phi::dtype::Real<T>>(out);
   std::vector<const DenseTensor*> ins = {&x};
   std::vector<DenseTensor*> outs = {out};
   auto functor = CudaAbsFunctor<T>();
-  funcs::ElementwiseKernel<phi::funcs::Real<T>>(ctx, ins, &outs, functor);
+  funcs::ElementwiseKernel<phi::dtype::Real<T>>(ctx, ins, &outs, functor);
 }
 }  // namespace phi
...
@@ -38,11 +38,15 @@ PD_REGISTER_KERNEL(real,
                    ALL_LAYOUT,
                    phi::RealKernel,
                    phi::dtype::complex<float>,
-                   phi::dtype::complex<double>) {}
+                   phi::dtype::complex<double>) {
+  kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
+}
 PD_REGISTER_KERNEL(imag,
                    GPU,
                    ALL_LAYOUT,
                    phi::ImagKernel,
                    phi::dtype::complex<float>,
-                   phi::dtype::complex<double>) {}
+                   phi::dtype::complex<double>) {
+  kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
+}
@@ -47,7 +47,7 @@ void AbsGradKernel(const Context& ctx,
                    const DenseTensor& dout,
                    DenseTensor* dx) {
   auto numel = dout.numel();
-  auto* dout_data = dout.data<phi::funcs::Real<T>>();
+  auto* dout_data = dout.data<phi::dtype::Real<T>>();
   auto* x_data = x.data<T>();
   ctx.template Alloc<T>(dx, static_cast<size_t>(numel * sizeof(T)));
...
@@ -24,7 +24,7 @@ void RealGradKernel(const Context& dev_ctx,
                     const DenseTensor& dout,
                     DenseTensor* dx) {
   auto numel = dout.numel();
-  auto* dout_data = dout.data<phi::funcs::Real<T>>();
+  auto* dout_data = dout.data<phi::dtype::Real<T>>();
   auto* dx_data =
       dev_ctx.template Alloc<T>(dx, static_cast<size_t>(numel * sizeof(T)));
@@ -38,7 +38,7 @@ void ImagGradKernel(const Context& dev_ctx,
                     const DenseTensor& dout,
                     DenseTensor* dx) {
   auto numel = dout.numel();
-  auto* dout_data = dout.data<phi::funcs::Real<T>>();
+  auto* dout_data = dout.data<phi::dtype::Real<T>>();
   auto* dx_data =
       dev_ctx.template Alloc<T>(dx, static_cast<size_t>(numel * sizeof(T)));
...
@@ -39,8 +39,8 @@ void RealKernel(const Context& dev_ctx,
                 DenseTensor* out) {
   auto numel = x.numel();
   auto* x_data = x.data<T>();
-  auto* out_data = dev_ctx.template Alloc<phi::funcs::Real<T>>(
-      out, static_cast<size_t>(numel * sizeof(phi::funcs::Real<T>)));
+  auto* out_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(
+      out, static_cast<size_t>(numel * sizeof(phi::dtype::Real<T>)));
   phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
   phi::funcs::RealFunctor<T> functor(x_data, out_data, numel);
@@ -53,8 +53,8 @@ void ImagKernel(const Context& dev_ctx,
                 DenseTensor* out) {
   auto numel = x.numel();
   auto* x_data = x.data<T>();
-  auto* out_data = dev_ctx.template Alloc<phi::funcs::Real<T>>(
-      out, static_cast<size_t>(numel * sizeof(phi::funcs::Real<T>)));
+  auto* out_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(
+      out, static_cast<size_t>(numel * sizeof(phi::dtype::Real<T>)));
   phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
   phi::funcs::ImagFunctor<T> functor(x_data, out_data, numel);
...
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/phi/api/ext/exception.h"
 #include "paddle/phi/common/data_type.h"
+#include "paddle/phi/common/type_traits.h"
 namespace phi {
 namespace tests {
@@ -71,5 +72,20 @@ TEST(DataType, OStream) {
   }
 }
+TEST(TypeTraits, Complex) {
+  EXPECT_EQ(phi::dtype::ToReal(phi::DataType::COMPLEX64),
+            phi::DataType::FLOAT32);
+  EXPECT_EQ(phi::dtype::ToReal(phi::DataType::COMPLEX128),
+            phi::DataType::FLOAT64);
+  EXPECT_EQ(phi::dtype::ToReal(phi::DataType::FLOAT32), phi::DataType::FLOAT32);
+  EXPECT_EQ(phi::dtype::ToComplex(phi::DataType::FLOAT32),
+            phi::DataType::COMPLEX64);
+  EXPECT_EQ(phi::dtype::ToComplex(phi::DataType::FLOAT64),
+            phi::DataType::COMPLEX128);
+  EXPECT_EQ(phi::dtype::ToComplex(phi::DataType::COMPLEX64),
+            phi::DataType::COMPLEX64);
+}
 }  // namespace tests
 }  // namespace phi