diff --git a/paddle/fluid/operators/determinant_op.cc b/paddle/fluid/operators/determinant_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..98247fbc862bbc199316a4d4c8971d0f4a159544
--- /dev/null
+++ b/paddle/fluid/operators/determinant_op.cc
@@ -0,0 +1,194 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/determinant_op.h"
+
+namespace paddle {
+namespace operators {
+
+class DeterminantOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "determinant");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "determinant");
+  }
+};
+
+class DeterminantOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Input", "(Tensor) The input tensor of determinant.");
+    AddOutput("Out",
+              "(Tensor) The output tensor containing the determinant"
+              " value of a square matrix or batches of square matrices.");
+
+    AddComment(R"DOC(
+Determinant Operator.)DOC");
+  }
+};
+
+class DeterminantGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input",
+                   "DeterminantGradOp");
+    OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "DeterminantGradOp");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   framework::GradVarName("Out"), "DeterminantGradOp");
+    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Input")), "Output",
+                   framework::GradVarName("Input"), "DeterminantGradOp");
+
+    ctx->SetOutputDim(framework::GradVarName("Input"),
+                      ctx->GetInputDim("Input"));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace());
+  }
+};
+
+template <typename T>
+class DeterminantGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> grad_op) const override {
+    grad_op->SetType("determinant_grad");
+    grad_op->SetInput("Input", this->Input("Input"));
+    grad_op->SetInput("Out", this->Output("Out"));
+    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    grad_op->SetOutput(framework::GradVarName("Input"),
+                       this->InputGrad("Input"));
+    grad_op->SetAttrMap(this->Attrs());
+  }
+};
+
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(DeterminantGradNoNeedBufferVarsInferer,
+                                    "Input");
+
+class SlogDeterminantOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "determinant");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "determinant");
+  }
+};
+
+class SlogDeterminantOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Input", "(Tensor) The input tensor of SlogDeterminant.");
+    AddOutput("Out",
+              "(Tensor) The output tensor containing the sign of the"
+              " determinant and the natural logarithm of the absolute"
+              " value of the determinant.");
+
+    AddComment(R"DOC(
+SlogDeterminant Operator.)DOC");
+  }
+};
+
+class SlogDeterminantGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input",
+                   "SlogDeterminantGradOp");
+    OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out",
+                   "SlogDeterminantGradOp");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   framework::GradVarName("Out"), "SlogDeterminantGradOp");
+    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Input")), "Output",
+                   framework::GradVarName("Input"), "SlogDeterminantGradOp");
+
+    ctx->SetOutputDim(framework::GradVarName("Input"),
+                      ctx->GetInputDim("Input"));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace());
+  }
+};
+
+template <typename T>
+class SlogDeterminantGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> grad_op) const override {
+    grad_op->SetType("slogdeterminant_grad");
+    grad_op->SetInput("Input", this->Input("Input"));
+    grad_op->SetInput("Out", this->Output("Out"));
+    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    grad_op->SetOutput(framework::GradVarName("Input"),
+                       this->InputGrad("Input"));
+    grad_op->SetAttrMap(this->Attrs());
+  }
+};
+
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(SlogDeterminantGradNoNeedBufferVarsInferer,
+                                    "Input");
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+REGISTER_OPERATOR(determinant, ops::DeterminantOp, ops::DeterminantOpMaker,
+                  ops::DeterminantGradOpMaker<paddle::framework::OpDesc>,
+                  ops::DeterminantGradOpMaker<paddle::imperative::OpBase>);
+
+REGISTER_OPERATOR(determinant_grad, ops::DeterminantGradOp)
+
+REGISTER_OP_CPU_KERNEL(determinant,
+                       ops::DeterminantKernel<plat::CPUDeviceContext, float>,
+                       ops::DeterminantKernel<plat::CPUDeviceContext, double>);
+
+REGISTER_OP_CPU_KERNEL(
+    determinant_grad,
+    ops::DeterminantGradKernel<plat::CPUDeviceContext, float>,
+    ops::DeterminantGradKernel<plat::CPUDeviceContext, double>);
+
+REGISTER_OPERATOR(slogdeterminant, ops::SlogDeterminantOp,
+                  ops::SlogDeterminantOpMaker,
+                  ops::SlogDeterminantGradOpMaker<paddle::framework::OpDesc>,
+                  ops::SlogDeterminantGradOpMaker<paddle::imperative::OpBase>);
+
+REGISTER_OPERATOR(slogdeterminant_grad,
+                  ops::SlogDeterminantGradOp)  // reuse det grad op
+
+REGISTER_OP_CPU_KERNEL(
+    slogdeterminant,
+    ops::SlogDeterminantKernel<plat::CPUDeviceContext, float>,
+    ops::SlogDeterminantKernel<plat::CPUDeviceContext, double>);
+
+REGISTER_OP_CPU_KERNEL(
+    slogdeterminant_grad,
+    ops::SlogDeterminantGradKernel<plat::CPUDeviceContext, float>,
+    ops::SlogDeterminantGradKernel<plat::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/determinant_op.cu b/paddle/fluid/operators/determinant_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..d19d4c3d093860c1f603e4d752063b7a858c0460
--- /dev/null
+++ b/paddle/fluid/operators/determinant_op.cu
@@ -0,0 +1,36 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/determinant_op.h"
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+REGISTER_OP_CUDA_KERNEL(
+    determinant, ops::DeterminantKernel<plat::CUDADeviceContext, float>,
+    ops::DeterminantKernel<plat::CUDADeviceContext, double>);
+
+REGISTER_OP_CUDA_KERNEL(
+    determinant_grad,
+    ops::DeterminantGradKernel<plat::CUDADeviceContext, float>,
+    ops::DeterminantGradKernel<plat::CUDADeviceContext, double>);
+
+REGISTER_OP_CUDA_KERNEL(
+    slogdeterminant,
+    ops::SlogDeterminantKernel<plat::CUDADeviceContext, float>,
+    ops::SlogDeterminantKernel<plat::CUDADeviceContext, double>);
+
+REGISTER_OP_CUDA_KERNEL(
+    slogdeterminant_grad,
+    ops::SlogDeterminantGradKernel<plat::CUDADeviceContext, float>,
+    ops::SlogDeterminantGradKernel<plat::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/determinant_op.h b/paddle/fluid/operators/determinant_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..4c17869fb5d2a582b0124c859a4d87971a103114
--- /dev/null
+++ b/paddle/fluid/operators/determinant_op.h
@@ -0,0 +1,436 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/complex_functors.h"
+#include "paddle/fluid/operators/math/matrix_inverse.h"
+#include "paddle/fluid/operators/svd_helper.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/for_range.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+T sign(T val) {
+  return static_cast<T>(T(0) < val) - (val < T(0));
+}
+
+template <typename T>
+class EigenMatrix {};
+
+template <>
+class EigenMatrix<float> {
+ public:
+  using MatrixType = Eigen::MatrixXf;
+};
+
+template <>
+class EigenMatrix<double> {
+ public:
+  using MatrixType = Eigen::MatrixXd;
+};
+
+inline int64_t GetBatchCount(const framework::DDim dims) {
+  int64_t batch_count = 1;
+  auto dim_size = dims.size();
+  PADDLE_ENFORCE_GE(
+      dim_size, 2,
+      platform::errors::InvalidArgument(
+          "the input matrix dimension size should be no less than 2."));
+
+  // Cumulatively multiply each dimension except the last two to get the
+  // batch count; for example, a tensor with shape [3, 3, 3, 3] holds a
+  // batch of 9 matrices.
+  for (int64_t i = 0; i < dims.size() - 2; i++) {
+    batch_count *= dims[i];
+  }
+
+  return batch_count;
+}
+
+template <typename T>
+struct DeterminantFunctor {
+  void operator()(const Tensor& input, const framework::ExecutionContext ctx,
+                  int64_t rank, int64_t batch_count, Tensor* output) {
+    std::vector<T> input_vec;
+    std::vector<T> output_vec;
+    framework::TensorToVector(input, ctx.device_context(), &input_vec);
+    for (int64_t i = 0; i < batch_count; ++i) {  // maybe can be parallel
+      auto begin_iter = input_vec.begin() + i * rank * rank;
+      auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
+      std::vector<T> sub_vec(begin_iter,
+                             end_iter);  // get every square matrix data
+      typename EigenMatrix<T>::MatrixType matrix(rank, rank);
+      for (int64_t i = 0; i < rank; ++i) {
+        for (int64_t j = 0; j < rank; ++j) {
+          matrix(i, j) = sub_vec[rank * i + j];
+        }
+      }
+      output_vec.push_back(matrix.determinant());
+    }
+    framework::TensorFromVector(output_vec, output);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class DeterminantKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<framework::Tensor>("Input");
+    auto input_dim = vectorize(input->dims());
+    auto input_dim_size = input_dim.size();
+    auto* output = context.Output<framework::Tensor>("Out");
+
+    auto batch_count = GetBatchCount(input->dims());
+    VLOG(2) << "input dim:" << input->dims();
+    PADDLE_ENFORCE_GE(
+        input_dim_size, 2,
+        platform::errors::InvalidArgument(
+            "the input matrix dimension size should be no less than 2."));
+    PADDLE_ENFORCE_EQ(input_dim[input_dim_size - 1],
+                      input_dim[input_dim_size - 2],
+                      platform::errors::InvalidArgument(
+                          "the input matrix should be a square matrix."));
+    auto rank = input_dim[input_dim_size - 1];  // square matrix length
+    DeterminantFunctor<T>()(*input, context, rank, batch_count, output);
+    auto output_dims =
+        framework::slice_ddim(input->dims(), 0, input_dim_size - 2);
+    if (input_dim_size > 2) {
+      output->Resize(output_dims);
+    } else {
+      // when the input is a two-dimensional matrix, the det value is a number
+      output->Resize({1});
+    }
+    VLOG(2) << "output dim:" << output->dims();
+  }
+};
+
+template <typename T>
+struct FoundZeroFunctor {
+  FoundZeroFunctor(const T* x, int64_t numel, bool* res)
+      : x_(x), numel_(numel), res_(res) {}
+  HOSTDEVICE void operator()(size_t idx) const {
+    if (*res_ || idx >= static_cast<size_t>(numel_)) {
+      // a zero has already been found, nothing more to do
+      return;
+    }
+    *res_ = (x_[idx] == static_cast<T>(0));
+  }
+  const T* x_;
+  int64_t numel_;
+  bool* res_;
+};
+
+template <typename DeviceContext, typename T>
+inline bool CheckMatrixInvertible(const framework::ExecutionContext& ctx,
+                                  const framework::Tensor* det) {
+  auto& dev_ctx = ctx.template device_context<DeviceContext>();
+  auto numel = det->numel();
+
+  framework::Tensor dev_tensor;
+  auto* data = dev_tensor.mutable_data<bool>({1}, ctx.GetPlace());
+
+  // set false
+  math::SetConstant<DeviceContext, bool> zero;
+  zero(dev_ctx, &dev_tensor, false);
+
+  // find whether any determinant is zero
+  platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
+  FoundZeroFunctor<T> functor(det->data<T>(), numel, data);
+  for_range(functor);
+
+  // copy to host
+  dev_ctx.Wait();
+  framework::Tensor cpu_tensor;
+  framework::TensorCopy(dev_tensor, platform::CPUPlace(), &cpu_tensor);
+
+  // if a zero determinant was found, the matrix is not invertible;
+  // otherwise the matrix is invertible
+  auto* res = cpu_tensor.data<bool>();
+  return !(*res);
+}
+
+template <typename DeviceContext, typename T>
+class DeterminantGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    const auto* input = context.Input<framework::Tensor>("Input");
+    const auto* det = context.Input<framework::Tensor>("Out");
+    const auto* grad =
+        context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* ddet =
+        context.Output<framework::Tensor>(framework::GradVarName("Input"));
+
+    auto input_dims_size = input->dims().size();
+    if (input_dims_size > 2) {
+      PADDLE_ENFORCE_EQ(
+          grad->dims().size() + 2, input_dims_size,
+          platform::errors::InvalidArgument(
+              "The grad tensor of det dims size should be 2 less than"
+              " the input tensor's, but here they differ by %d",
+              input_dims_size - grad->dims().size()));
+    } else if (input_dims_size == 2) {
+      // input dims size 2 and grad dims size 1 is possible
+      PADDLE_ENFORCE_EQ(
+          grad->dims().size(), 1,
+          platform::errors::InvalidArgument(
+              "The grad tensor of det dims size should be 2 less than"
+              " the input tensor's, but here they differ by %d",
+              input_dims_size - grad->dims().size()));
+    } else {
+      // checked in forward, pass
+    }
+
+    // Check whether the matrix is invertible:
+    // (matrix A not invertible) == (det(A) == 0)
+    if (!CheckMatrixInvertible<DeviceContext, T>(context, det)) {
+      // The matrix is not invertible
+      VLOG(3) << "The input matrix is not invertible!";
+      ddet->Resize(input->dims());
+      ddet->mutable_data<T>(context.GetPlace());
+      math::SetConstant<DeviceContext, T> zero;
+      zero(dev_ctx, ddet, static_cast<T>(0.0f));
+      return;
+    }
+
+    // The matrix is invertible
+    // let |A| = Determinant(A)
+    // Ref to https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
+    // we set d|A| = unsqueeze(dA * |A|, [-1, -2]) *
+    //               inverse(A).transpose(-2, -1)
+
+    math::DeviceIndependenceTensorOperations<DeviceContext, T> helper(context);
+
+    // First: inverse(A)
+    framework::Tensor inverse_A;
+    // A must be a square matrix!
+    inverse_A.Resize(input->dims());
+    inverse_A.mutable_data<T>(context.GetPlace());
+
+    math::MatrixInverseFunctor<DeviceContext, T> mat_inv;
+    mat_inv(dev_ctx, *input, &inverse_A);
+
+    VLOG(3) << "inverse(A) dims: " << inverse_A.dims();
+
+    // Second: inverse(A).transpose(-2, -1)
+    framework::Tensor transpose_inverse_A = helper.Transpose(inverse_A);
+    VLOG(3) << "inverse(A).transpose(-2, -1) dims: "
+            << transpose_inverse_A.dims();
+
+    // Third: dA * |A|
+    auto mul_dA_detA = helper.Mul(*grad, *det);
+    VLOG(3) << "dA * |A| dims: " << mul_dA_detA.dims();
+
+    // Fourth: unsqueeze(dA * |A|, [-1, -2])
+    auto unsqueeze1 = helper.Unsqueeze(mul_dA_detA, -1);
+    auto unsqueeze2 = helper.Unsqueeze(unsqueeze1, -2);
+    VLOG(3) << "unsqueezed(dA * |A|) dims: " << unsqueeze2.dims();
+
+    // Finally: unsqueeze(dA * |A|) * inverse(A).transpose(-2, -1)
+    auto res = helper.Mul(unsqueeze2, transpose_inverse_A);
+
+    VLOG(3) << "unsqueeze(dA * |A|) * inverse(A) dims: " << res.dims();
+
+    framework::TensorCopy(res, context.GetPlace(), ddet);
+
+    ddet->Resize(input->dims());
+    VLOG(3) << "d|A| dims: " << ddet->dims();
+  }
+};
+
+template <typename T>
+struct SlogDeterminantFunctor {
+  void operator()(const Tensor& input, const framework::ExecutionContext ctx,
+                  int64_t rank, int64_t batch_count, Tensor* output) {
+    std::vector<T> input_vec;
+    std::vector<T> sign_vec;
+    std::vector<T> log_vec;
+    std::vector<T> output_vec;
+    framework::TensorToVector(input, ctx.device_context(), &input_vec);
+    for (int64_t i = 0; i < batch_count; ++i) {  // maybe can be parallel
+      auto begin_iter = input_vec.begin() + i * rank * rank;
+      auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
+      std::vector<T> sub_vec(begin_iter,
+                             end_iter);  // get every square matrix data
+      typename EigenMatrix<T>::MatrixType matrix(rank, rank);
+      for (int64_t i = 0; i < rank; ++i) {
+        for (int64_t j = 0; j < rank; ++j) {
+          matrix(i, j) = sub_vec[rank * i + j];
+        }
+      }
+      VLOG(2) << "det value: " << matrix.determinant();
+      VLOG(2) << "matrix val: " << matrix;
+      auto det_val = matrix.determinant();
+      sign_vec.push_back(sign(det_val));
+      det_val >= 0
+          ? log_vec.push_back(std::log(det_val))
+          : log_vec.push_back(std::log(std::abs(
+                det_val)));  // for computing the log of a negative value
+    }
+    // merge sign_vec and log_vec as final output_vec
+    output_vec.insert(output_vec.end(), sign_vec.begin(), sign_vec.end());
+    output_vec.insert(output_vec.end(), log_vec.begin(), log_vec.end());
+    framework::TensorFromVector(output_vec, output);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class SlogDeterminantKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<framework::Tensor>("Input");
+    auto input_dim = vectorize(input->dims());
+    auto input_dim_size = input_dim.size();
+    auto* output = context.Output<framework::Tensor>("Out");
+
+    auto batch_count = GetBatchCount(input->dims());
+    VLOG(2) << "input dim:" << input->dims();
+    PADDLE_ENFORCE_GE(
+        input_dim_size, 2,
+        platform::errors::InvalidArgument(
+            "the input matrix dimension size should be no less than 2."));
+    PADDLE_ENFORCE_EQ(input_dim[input_dim_size - 1],
+                      input_dim[input_dim_size - 2],
+                      platform::errors::InvalidArgument(
+                          "the input matrix should be a square matrix."));
+    auto rank = input_dim[input_dim_size - 1];  // square matrix length
+    SlogDeterminantFunctor<T>()(*input, context, rank, batch_count, output);
+    std::vector<int64_t> output_dim_vec(input_dim.begin(), input_dim.end() - 2);
+    if (input_dim.size() == static_cast<size_t>(2)) {
+      // when the input is a two-dimensional matrix, the det value is a number
+      output_dim_vec = {1};
+    }
+    output_dim_vec.insert(output_dim_vec.begin(),
+                          2);  // make the output dims the same as numpy's
+    auto output_dims = framework::make_ddim(output_dim_vec);
+    output->Resize(output_dims);
+    VLOG(2) << "output dim:" << output->dims();
+  }
+};
+
+template <typename DeviceContext, typename T>
+class SlogDeterminantGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    const auto* input = context.Input<framework::Tensor>("Input");
+    const auto* slogdet = context.Input<framework::Tensor>("Out");
+    const auto* grad =
+        context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* dslogdet =
+        context.Output<framework::Tensor>(framework::GradVarName("Input"));
+
+    PADDLE_ENFORCE_EQ(grad->dims()[0], 2,
+                      platform::errors::InvalidArgument(
+                          "The grad tensor of SlogDet should contain two"
+                          " grads: sign and absslogdet, but here %ld.",
+                          grad->dims()[0]));
+    if (input->dims().size() > 2) {
+      PADDLE_ENFORCE_EQ(
+          grad->dims().size() + 1, input->dims().size(),
+          platform::errors::InvalidArgument(
+              "The grad tensor of slogdet dims size should be 1 less than"
+              " the input tensor's, but here they differ by %d",
+              input->dims().size() - grad->dims().size()));
+    }
+
+    // Check whether the matrix is invertible:
+    // (matrix A not invertible) == (absslogdet(A) == 0)
+    auto slogdet_vec = slogdet->Split(1, 0);
+    auto absslogdet_val = slogdet_vec[0];
+    if (!CheckMatrixInvertible<DeviceContext, T>(context, &absslogdet_val)) {
+      // The matrix is not invertible
+      VLOG(3) << "The input matrix is not invertible!";
+      dslogdet->Resize(input->dims());
+      dslogdet->mutable_data<T>(context.GetPlace());
+      math::SetConstant<DeviceContext, T> zero;
+      zero(dev_ctx, dslogdet, std::numeric_limits<T>::quiet_NaN());
+      return;
+    }
+
+    // The matrix is invertible
+    // let sl|A| = SlogDeterminant(A)
+    // Ref to https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
+    // we set dsl|A| = unsqueeze(dslA, [-1, -2]) *
+    //                 inverse(A).conj().transpose(-2, -1)
+
+    math::DeviceIndependenceTensorOperations<DeviceContext, T> helper(context);
+
+    // First: inverse(A)
+    framework::Tensor inverse_A;
+    // A must be a square matrix!
+    inverse_A.Resize(input->dims());
+    inverse_A.mutable_data<T>(context.GetPlace());
+
+    math::MatrixInverseFunctor<DeviceContext, T> mat_inv;
+    mat_inv(dev_ctx, *input, &inverse_A);
+
+    VLOG(3) << "inverse(A) dims: " << inverse_A.dims();
+
+    // Second: inverse(A).conj()
+    framework::Tensor conj_inverse_A;
+    conj_inverse_A.Resize(inverse_A.dims());
+    auto numel = input->numel();
+    auto* conj_data = conj_inverse_A.mutable_data<T>(context.GetPlace(),
+                                                     size_t(numel * sizeof(T)));
+
+    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
+    math::ConjFunctor<T> functor(inverse_A.data<T>(), numel, conj_data);
+    for_range(functor);
+
+    VLOG(3) << "inverse(A).conj() dims: " << conj_inverse_A.dims();
+
+    // Third: inverse(A).conj().transpose(-2, -1)
+    framework::Tensor transpose_inverse_A = helper.Transpose(conj_inverse_A);
+    VLOG(3) << "inverse(A).conj().transpose(-2, -1) dims: "
+            << transpose_inverse_A.dims();
+
+    // Fourth: split the grad value into [sign_grad, absslogdet_grad]
+    auto grad_vec = grad->Split(1, 0);
+    auto det_grad = grad_vec[1];
+
+    // remove the useless first dimension
+    int det_grad_size = det_grad.dims().size();
+    std::vector<int> det_grad_vec;
+    for (int i = 1; i < det_grad_size; ++i) {
+      det_grad_vec.emplace_back(det_grad.dims()[i]);
+    }
+    det_grad.Resize(det_grad.dims().reshape(det_grad_vec));
+
+    // Fifth: unsqueeze(dslA, [-1, -2])
+    auto unsqueeze1 = helper.Unsqueeze(det_grad, -1);
+    auto unsqueeze2 = helper.Unsqueeze(unsqueeze1, -2);
+    VLOG(3) << "unsqueezed(dslA, [-1, -2]) dims: " << unsqueeze2.dims();
+
+    // Finally: unsqueeze(dslA) * inverse(A).conj().transpose(-2, -1)
+    auto res = helper.Mul(unsqueeze2, transpose_inverse_A);
+    VLOG(3) << "unsqueeze(dslA) * inverse(A) dims: " << res.dims();
+
+    framework::TensorCopy(res, context.GetPlace(), dslogdet);
+    dslogdet->Resize(input->dims());
+    VLOG(3) << "dsl|A| dims: " << dslogdet->dims();
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index dba9ea892fa7a572581dfbe4b9a0673becf504e7..e09138ef09409906ef31b09c72db404549ac371e 100755
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -101,6 +101,8 @@ from .tensor.linalg import cholesky  # noqa: F401
 from .tensor.linalg import bmm  # noqa: F401
 from .tensor.linalg import histogram  # noqa: F401
 from .tensor.linalg import mv  # noqa: F401
+from .tensor.linalg import det  # noqa: F401
+from .tensor.linalg import slogdet  # noqa: F401
 from .tensor.linalg import multi_dot  # noqa: F401
 from .tensor.linalg import matrix_power  # noqa: F401
 from .tensor.linalg import svd  # noqa: F401
diff --git a/python/paddle/fluid/tests/unittests/test_determinant_op.py b/python/paddle/fluid/tests/unittests/test_determinant_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8110bffa2f7137cd88d7fa8294b59d74f6d3e71
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_determinant_op.py
@@ -0,0 +1,149 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest
+import paddle
+import paddle.nn.functional as F
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+import paddle.tensor as tensor
+
+paddle.enable_static()
+
+
+class TestDeterminantOp(OpTest):
+    def setUp(self):
+        self.init_data()
+        self.op_type = "determinant"
+        self.outputs = {'Out': self.target}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['Input'], ['Out'])
+
+    def init_data(self):
+        np.random.seed(0)
+        self.case = np.random.rand(3, 3, 3, 5, 5).astype('float64')
+        self.inputs = {'Input': self.case}
+        self.target = np.linalg.det(self.case)
+
+
+class TestDeterminantOpCase1(TestDeterminantOp):
+    def init_data(self):
+        np.random.seed(0)
+        self.case = np.random.rand(10, 10).astype('float32')
+        self.inputs = {'Input': self.case}
+        self.target = np.linalg.det(self.case)
+
+
+class TestDeterminantOpCase2(TestDeterminantOp):
+    def init_data(self):
+        np.random.seed(0)
+        # not invertible matrix
+        self.case = np.ones([4, 2, 4, 4]).astype('float64')
+        self.inputs = {'Input': self.case}
+        self.target = np.linalg.det(self.case)
+
+
+class TestDeterminantAPI(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(0)
+        self.shape = [3, 3, 5, 5]
+        self.x = np.random.random(self.shape).astype(np.float32)
+        self.place = paddle.CPUPlace()
+
+    def test_api_static(self):
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.fluid.data('X', self.shape)
+            out = paddle.linalg.det(x)
+            exe = paddle.static.Executor(self.place)
+            res = exe.run(feed={'X': self.x}, fetch_list=[out])
+        out_ref = np.linalg.det(self.x)
+
+        for out in res:
+            self.assertEqual(np.allclose(out, out_ref, rtol=1e-03), True)
+
+    def test_api_dygraph(self):
+        paddle.disable_static(self.place)
+        x_tensor = paddle.to_tensor(self.x)
+        out = paddle.linalg.det(x_tensor)
+        out_ref = np.linalg.det(self.x)
+        self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-03), True)
+        paddle.enable_static()
+
+
+class TestSlogDeterminantOp(OpTest):
+    def setUp(self):
+        self.op_type = "slogdeterminant"
+        self.init_data()
+        self.outputs = {'Out': self.target}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        # the slog det's grad value is always huge
+        self.check_grad(['Input'], ['Out'], max_relative_error=0.1)
+
+    def init_data(self):
+        np.random.seed(0)
+        self.case = np.random.rand(4, 5, 5).astype('float64')
+        self.inputs = {'Input': self.case}
+        self.target = np.array(np.linalg.slogdet(self.case))
+
+
+class TestSlogDeterminantOpCase1(TestSlogDeterminantOp):
+    def init_data(self):
+        np.random.seed(0)
+        self.case = np.random.rand(2, 2, 5, 5).astype(np.float32)
+        self.inputs = {'Input': self.case}
+        self.target = np.array(np.linalg.slogdet(self.case))
+
+
+class TestSlogDeterminantAPI(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(0)
+        self.shape = [3, 3, 5, 5]
+        self.x = np.random.random(self.shape).astype(np.float32)
+        self.place = paddle.CPUPlace()
+
+    def test_api_static(self):
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.fluid.data('X', self.shape)
+            out = paddle.linalg.slogdet(x)
+            exe = paddle.static.Executor(self.place)
+            res = exe.run(feed={'X': self.x}, fetch_list=[out])
+        out_ref = np.array(np.linalg.slogdet(self.x))
+        for out in res:
+            self.assertEqual(np.allclose(out, out_ref, rtol=1e-03), True)
+
+    def test_api_dygraph(self):
+        paddle.disable_static(self.place)
+        x_tensor = paddle.to_tensor(self.x)
+        out = paddle.linalg.slogdet(x_tensor)
+        out_ref = np.array(np.linalg.slogdet(self.x))
+        self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-03), True)
+        paddle.enable_static()
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/linalg.py b/python/paddle/linalg.py
index f3bac8f147ab857b1d3fb60449551ea76e878b22..d57d9a4bdb6780c1eed2f8a65fc71bddc45c1c82 100644
--- a/python/paddle/linalg.py
+++ b/python/paddle/linalg.py
@@ -23,6 +23,8 @@ from .tensor.linalg import multi_dot  # noqa: F401
 from .tensor.linalg import matrix_rank
 from .tensor.linalg import svd
 from .tensor.linalg import eigh  # noqa: F401
+from .tensor.linalg import det
+from .tensor.linalg import slogdet
 from .tensor.linalg import pinv
 
 __all__ = [
@@ -35,6 +37,8 @@ __all__ = [
     'matrix_rank',
     'svd',
     'matrix_power',
+    'det',
+    'slogdet',
     'eigh',
     'pinv',
     'solve'
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index ce97224ae1157d5accb9f8cc1dec5cac2aedfe17..b9fb0e7c563e708f26e43b112fcf20dc359db939 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -14,7 +14,7 @@
 import numpy as np
 from ..fluid.layer_helper import LayerHelper
-from ..fluid.data_feeder import check_variable_and_dtype, check_type
+from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype
 from ..fluid.framework import in_dygraph_mode, _varbase_creator, Variable
 
 from ..fluid.layers import transpose, cast  # noqa: F401
@@ -1351,6 +1351,109 @@ def mv(x, vec, name=None):
     return out
 
 
+def det(x):
+    """
+    Calculates the determinant value of a square matrix or batches of square matrices.
+
+    Args:
+        x (Tensor): the input matrix of size `(n, n)` or the batch of matrices of size
+            `(*, n, n)`, where `*` is one or more batch dimensions.
+
+    Returns:
+        y (Tensor): the determinant value of a square matrix or batches of square matrices.
+
+    Example:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.randn([3,3,3])
+
+            A = paddle.det(x)
+
+            print(A)
+
+            # [ 0.02547996,  2.52317095, -6.15900707]
+
+    """
+    if in_dygraph_mode():
+        return core.ops.determinant(x)
+
+    check_dtype(x.dtype, 'Input', ['float32', 'float64'], 'det')
+
+    input_shape = list(x.shape)
+    assert len(input_shape) >= 2, \
+        "The x must be at least 2-dimensional, " \
+        "but received Input x's dimension: %s.\n" % len(input_shape)
+
+    assert input_shape[-1] == input_shape[-2], \
+        "Expected a square matrix, " \
+        "but received a %s x %s matrix.\n" % (input_shape[-2], input_shape[-1])
+
+    helper = LayerHelper('determinant', **locals())
+    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+    helper.append_op(
+        type='determinant', inputs={'Input': [x]}, outputs={'Out': [out]})
+    return out
+
+
+def slogdet(x):
+    """
+    Calculates the sign and natural logarithm of the absolute value of the determinant
+    of a square matrix or batches of square matrices.
+    The determinant can be computed with ``sign * exp(logabsdet)``.
+
+    Supports input of float and double.
+
+    Note that for matrices that have a zero determinant, this returns ``(0, -inf)``.
+
+    Args:
+        x (Tensor): the batch of matrices of size :math:`(*, n, n)`
+            where :math:`*` is one or more batch dimensions.
+
+    Returns:
+        y (Tensor): A tensor containing the sign of the determinant and the natural
+            logarithm of the absolute value of the determinant, respectively.
+
+    Example:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.randn([3,3,3])
+
+            A = paddle.slogdet(x)
+
+            print(A)
+
+            # [[ 1.        ,  1.        , -1.        ],
+            #  [-0.98610914, -0.43010661, -0.10872950]]
+
+    """
+    if in_dygraph_mode():
+        return core.ops.slogdeterminant(x)
+
+    check_dtype(x.dtype, 'Input', ['float32', 'float64'], 'slogdet')
+
+    input_shape = list(x.shape)
+    assert len(input_shape) >= 2, \
+        "The x must be at least 2-dimensional, " \
+        "but received Input x's dimension: %s.\n" % len(input_shape)
+
+    assert input_shape[-1] == input_shape[-2], \
+        "Expected a square matrix, " \
+        "but received a %s x %s matrix.\n" % (input_shape[-2], input_shape[-1])
+
+    helper = LayerHelper('slogdeterminant', **locals())
+    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+    helper.append_op(
+        type='slogdeterminant', inputs={'Input': [x]}, outputs={'Out': [out]})
+    return out
+
+
 def svd(x, full_matrices=False, name=None):
     r"""
     Computes the singular value decomposition of one matrix or a batch of regular matrices.