// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "Eigen/Dense" #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; using DDim = framework::DDim; template struct PaddleComplex { using Type = paddle::platform::complex; }; template <> struct PaddleComplex> { using Type = paddle::platform::complex; }; template <> struct PaddleComplex> { using Type = paddle::platform::complex; }; template struct StdComplex { using Type = std::complex; }; template <> struct StdComplex> { using Type = std::complex; }; template <> struct StdComplex> { using Type = std::complex; }; template using PaddleCType = typename PaddleComplex::Type; template using StdCType = typename StdComplex::Type; template using EigenMatrixPaddle = Eigen::Matrix; template using EigenVectorPaddle = Eigen::Matrix, Eigen::Dynamic, 1>; template using EigenMatrixStd = Eigen::Matrix, Eigen::Dynamic, Eigen::Dynamic>; template using EigenVectorStd = Eigen::Matrix, Eigen::Dynamic, 1>; static void SpiltBatchSquareMatrix(const Tensor &input, std::vector *output) { DDim input_dims = input.dims(); int last_dim = input_dims.size() - 1; int n_dim = input_dims[last_dim]; DDim flattened_input_dims, flattened_output_dims; if (input_dims.size() > 2) { flattened_input_dims = flatten_to_3d(input_dims, last_dim - 1, last_dim); } else { flattened_input_dims = framework::make_ddim({1, n_dim, n_dim}); } Tensor flattened_input; flattened_input.ShareDataWith(input); flattened_input.Resize(flattened_input_dims); (*output) = flattened_input.Split(1, 0); } template class EigvalsKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { const Tensor *input = ctx.Input("X"); Tensor *output = ctx.Output("Out"); auto input_type = input->type(); auto output_type = framework::IsComplexType(input_type) ? input_type : framework::ToComplexType(input_type); output->mutable_data(ctx.GetPlace(), output_type); std::vector input_matrices; SpiltBatchSquareMatrix(*input, /*->*/ &input_matrices); int n_dim = input_matrices[0].dims()[1]; int n_batch = input_matrices.size(); DDim output_dims = output->dims(); output->Resize(framework::make_ddim({n_batch, n_dim})); std::vector output_vectors = output->Split(1, 0); Eigen::Map> input_emp(NULL, n_dim, n_dim); Eigen::Map> output_evp(NULL, n_dim); EigenMatrixStd input_ems; EigenVectorStd output_evs; for (int i = 0; i < n_batch; ++i) { new (&input_emp) Eigen::Map>( input_matrices[i].data(), n_dim, n_dim); new (&output_evp) Eigen::Map>( output_vectors[i].data>(), n_dim); input_ems = input_emp.template cast>(); output_evs = input_ems.eigenvalues(); output_evp = output_evs.template cast>(); } output->Resize(output_dims); } }; } // namespace operators } // namespace paddle