// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/pten/kernels/funcs/complex_functors.h"
#include "paddle/pten/kernels/funcs/lapack/lapack_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DDim = framework::DDim;

// Maps a real scalar type T to the complex type that holds its eigenvalues;
// complex types map to themselves.
template <typename T, typename enable = void>
struct PaddleComplex;

template <typename T>
struct PaddleComplex<
    T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
  using type = paddle::platform::complex<T>;
};

template <typename T>
struct PaddleComplex<
    T, typename std::enable_if<
           std::is_same<T, platform::complex<float>>::value ||
           std::is_same<T, platform::complex<double>>::value>::type> {
  using type = T;
};

template <typename T>
using PaddleCType = typename PaddleComplex<T>::type;
template <typename T>
using Real = typename pten::funcs::Real<T>;

// Splits a batched tensor of square matrices with shape [..., n, n] into a
// vector of individual n x n matrices.
static void SplitBatchSquareMatrix(const Tensor& input,
                                   std::vector<Tensor>* output) {
  DDim input_dims = input.dims();
  int last_dim = input_dims.size() - 1;
  int n_dim = input_dims[last_dim];

  DDim flattened_input_dims, flattened_output_dims;
  if (input_dims.size() > 2) {
    flattened_input_dims = flatten_to_3d(input_dims, last_dim - 1, last_dim);
  } else {
    flattened_input_dims = framework::make_ddim({1, n_dim, n_dim});
  }

  Tensor flattened_input;
  flattened_input.ShareDataWith(input);
  flattened_input.Resize(flattened_input_dims);
  (*output) = flattened_input.Split(1, 0);
}

static void CheckLapackEigResult(const int info, const std::string& name) {
  PADDLE_ENFORCE_LE(info, 0, platform::errors::PreconditionNotMet(
                                 "The QR algorithm failed to compute all the "
                                 "eigenvalues in function %s.",
                                 name.c_str()));
  PADDLE_ENFORCE_GE(
      info, 0,
      platform::errors::InvalidArgument(
          "The %d-th argument has an illegal value in function %s.", -info,
          name.c_str()));
}

// Real-valued input: ?geev writes the real and imaginary parts of the
// eigenvalues into a scratch buffer w, which is then packed into the
// complex-valued output tensor.
template <typename DeviceContext, typename T>
static typename std::enable_if<std::is_floating_point<T>::value>::type
LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input,
              Tensor* output, Tensor* work, Tensor* rwork /*unused*/) {
  Tensor a;  // will be overwritten when lapackEig exits
  framework::TensorCopy(input, input.place(), &a);

  Tensor w;
  int64_t n_dim = input.dims()[1];
  auto* w_data =
      w.mutable_data<T>(framework::make_ddim({n_dim << 1}), ctx.GetPlace());

  int64_t work_mem = work->memory_size();
  int64_t required_work_mem = 3 * n_dim * sizeof(T);
  PADDLE_ENFORCE_GE(
      work_mem, required_work_mem,
      platform::errors::InvalidArgument(
          "The memory size of the work tensor in LapackEigvals function "
          "should be at least %" PRId64 " bytes, "
          "but received work\'s memory size = %" PRId64 " bytes.",
          required_work_mem, work_mem));

  int info = 0;
  pten::funcs::lapackEig<T>('N', 'N', static_cast<int>(n_dim),
                            a.template data<T>(), static_cast<int>(n_dim),
                            w_data, NULL, 1, NULL, 1, work->template data<T>(),
                            static_cast<int>(work_mem / sizeof(T)),
                            static_cast<T*>(NULL), &info);

  // Routine name used only for error reporting.
  std::string name = "framework::platform::dynload::sgeev_";
  if (framework::TransToProtoVarType(input.dtype()) ==
      framework::proto::VarType::FP64) {
    name = "framework::platform::dynload::dgeev_";
  }
  CheckLapackEigResult(info, name);

  platform::ForRange<DeviceContext> for_range(
      ctx.template device_context<DeviceContext>(), n_dim);
  pten::funcs::RealImagToComplexFunctor<PaddleCType<T>> functor(
      w_data, w_data + n_dim, output->template data<PaddleCType<T>>(), n_dim);
  for_range(functor);
}
// Complex-valued input: ?geev writes the eigenvalues directly into the output
// tensor and additionally needs a real-valued rwork buffer.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<T, platform::complex<float>>::value ||
    std::is_same<T, platform::complex<double>>::value>::type
LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input,
              Tensor* output, Tensor* work, Tensor* rwork) {
  Tensor a;  // will be overwritten when lapackEig exits
  framework::TensorCopy(input, input.place(), &a);

  int64_t work_mem = work->memory_size();
  int64_t n_dim = input.dims()[1];
  int64_t required_work_mem = 3 * n_dim * sizeof(T);
  PADDLE_ENFORCE_GE(
      work_mem, required_work_mem,
      platform::errors::InvalidArgument(
          "The memory size of the work tensor in LapackEigvals function "
          "should be at least %" PRId64 " bytes, "
          "but received work\'s memory size = %" PRId64 " bytes.",
          required_work_mem, work_mem));

  int64_t rwork_mem = rwork->memory_size();
  int64_t required_rwork_mem = (n_dim << 1) * sizeof(pten::funcs::Real<T>);
  PADDLE_ENFORCE_GE(
      rwork_mem, required_rwork_mem,
      platform::errors::InvalidArgument(
          "The memory size of the rwork tensor in LapackEigvals function "
          "should be at least %" PRId64 " bytes, "
          "but received rwork\'s memory size = %" PRId64 " bytes.",
          required_rwork_mem, rwork_mem));

  int info = 0;
  pten::funcs::lapackEig<T, pten::funcs::Real<T>>(
      'N', 'N', static_cast<int>(n_dim), a.template data<T>(),
      static_cast<int>(n_dim), output->template data<T>(), NULL, 1, NULL, 1,
      work->template data<T>(), static_cast<int>(work_mem / sizeof(T)),
      rwork->template data<pten::funcs::Real<T>>(), &info);

  // Routine name used only for error reporting.
  std::string name = "framework::platform::dynload::zgeev_";
  if (framework::TransToProtoVarType(input.dtype()) ==
      framework::proto::VarType::COMPLEX64) {
    name = "framework::platform::dynload::cgeev_";
  }
  CheckLapackEigResult(info, name);
}

template <typename DeviceContext, typename T>
class EigvalsKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const Tensor* input = ctx.Input<Tensor>("X");
    Tensor* output = ctx.Output<Tensor>("Out");
    output->mutable_data<PaddleCType<T>>(ctx.GetPlace());

    std::vector<Tensor> input_matrices;
    SplitBatchSquareMatrix(*input, /*->*/ &input_matrices);

    int64_t n_dim = input_matrices[0].dims()[1];
    int64_t n_batch = input_matrices.size();
    DDim output_dims = output->dims();
    output->Resize(framework::make_ddim({n_batch, n_dim}));
    std::vector<Tensor> output_vectors = output->Split(1, 0);

    // query the optimal workspace size with lwork = -1
    T qwork;
    int info;
    pten::funcs::lapackEig<T, pten::funcs::Real<T>>(
        'N', 'N', static_cast<int>(n_dim),
        input_matrices[0].template data<T>(), static_cast<int>(n_dim), NULL,
        NULL, 1, NULL, 1, &qwork, -1,
        static_cast<pten::funcs::Real<T>*>(NULL), &info);
    int64_t lwork = static_cast<int64_t>(qwork);

    Tensor work, rwork;
    try {
      work.mutable_data<T>(framework::make_ddim({lwork}), ctx.GetPlace());
    } catch (memory::allocation::BadAlloc&) {
      LOG(WARNING) << "Failed to allocate Lapack workspace with the optimal "
                   << "memory size = " << lwork * sizeof(T) << " bytes, "
                   << "try reallocating a smaller workspace with the minimum "
                   << "required size = " << 3 * n_dim * sizeof(T) << " bytes, "
                   << "this may lead to bad performance.";
      lwork = 3 * n_dim;
      work.mutable_data<T>(framework::make_ddim({lwork}), ctx.GetPlace());
    }
    if (framework::IsComplexType(
            framework::TransToProtoVarType(input->dtype()))) {
      rwork.mutable_data<pten::funcs::Real<T>>(
          framework::make_ddim({n_dim << 1}), ctx.GetPlace());
    }

    for (int64_t i = 0; i < n_batch; ++i) {
      LapackEigvals<DeviceContext, T>(ctx, input_matrices[i],
                                      &output_vectors[i], &work, &rwork);
    }
    output->Resize(output_dims);
  }
};
}  // namespace operators
}  // namespace paddle
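
// Illustrative only: this header defines the kernel; the operator itself is
// expected to be registered in an accompanying .cc file. A typical CPU
// registration might look like the sketch below (the op type name "eigvals"
// and the namespace aliases are assumptions, not part of this header).
//
//   namespace ops = paddle::operators;
//   namespace plat = paddle::platform;
//   REGISTER_OP_CPU_KERNEL(
//       eigvals, ops::EigvalsKernel<plat::CPUDeviceContext, float>,
//       ops::EigvalsKernel<plat::CPUDeviceContext, double>,
//       ops::EigvalsKernel<plat::CPUDeviceContext, plat::complex<float>>,
//       ops::EigvalsKernel<plat::CPUDeviceContext, plat::complex<double>>);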