diff --git a/paddle/fluid/operators/determinant_op.cc b/paddle/fluid/operators/determinant_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..379a401cde62e870f87c3ad8862fe9ccb18bd1d7
--- /dev/null
+++ b/paddle/fluid/operators/determinant_op.cc
@@ -0,0 +1,191 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/determinant_op.h"
+
+namespace paddle {
+namespace operators {
+
+class DeterminantOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "determinant");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "determinant");
+  }
+};
+
+class DeterminantOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Input", "(Tensor) The input tensor of determinant.");
+    AddOutput("Out",
+              "(Tensor) The output tensor containing the determinant "
+              "value of a square matrix or batches of square matrices.");
+
+    AddComment(R"DOC(
+Determinant Operator.)DOC");
+  }
+};
+
+class DeterminantGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input",
+                   "DeterminantGradOp");
+    OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "DeterminantGradOp");
+    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Input")), "Output",
+                   framework::GradVarName("Input"), "DeterminantGradOp");
+
+    ctx->SetOutputDim(framework::GradVarName("Input"),
+                      ctx->GetInputDim("Input"));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace());
+  }
+};
+
+template <typename T>
+class DeterminantGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> grad_op) const override {
+    grad_op->SetType("determinant_grad");
+    grad_op->SetInput("Input", this->Input("Input"));
+    grad_op->SetInput("Out", this->Output("Out"));
+    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    grad_op->SetOutput(framework::GradVarName("Input"),
+                       this->InputGrad("Input"));
+    grad_op->SetAttrMap(this->Attrs());
+  }
+};
+
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(DeterminantGradNoNeedBufferVarsInferer,
+                                    "Input");
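+
+// The slogdeterminant op below packs two results into one output tensor of
+// shape [2, *batch]: row 0 holds sign(det) per matrix and row 1 holds
+// log(|det|) (see SlogDeterminantKernel in determinant_op.h).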
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "determinant"); + } +}; + +class SlogDeterminantOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Input", "(Tensor) The input tensor of SlogDeterminant."); + AddOutput("Out", + "(Tensor) The output tensor containing the sign of the" + "determinant and the natural logarithm" + "of the absolute value of determinant,"); + + AddComment(R"DOC( +SlogDeterminant Operator.)DOC"); + } +}; + +class SlogDeterminantGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", + "SlogDeterminantGradOp"); + OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", + "SlogDeterminantGradOp"); + + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Input")), "Output", + framework::GradVarName("Input"), "SlogDeterminantGradOp"); + + ctx->SetOutputDim(framework::GradVarName("Input"), + ctx->GetInputDim("Input")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); + } +}; + +template +class SlogDeterminantGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr grad_op) const override { + grad_op->SetType("slogdeterminant_grad"); + grad_op->SetInput("Input", this->Input("Input")); + grad_op->SetInput("Out", this->Output("Out")); + grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + grad_op->SetOutput(framework::GradVarName("Input"), + this->InputGrad("Input")); + grad_op->SetAttrMap(this->Attrs()); + } +}; + +DECLARE_NO_NEED_BUFFER_VARS_INFERER(SlogDeterminantGradNoNeedBufferVarsInferer, + "Input"); + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OPERATOR(determinant, ops::DeterminantOp, ops::DeterminantOpMaker, + ops::DeterminantGradOpMaker, + ops::DeterminantGradOpMaker); + +REGISTER_OPERATOR(determinant_grad, ops::DeterminantGradOp) + +REGISTER_OP_CPU_KERNEL(determinant, + ops::DeterminantKernel, + ops::DeterminantKernel); + +REGISTER_OP_CPU_KERNEL( + determinant_grad, ops::DeterminantGradKernel, + ops::DeterminantGradKernel); + +REGISTER_OPERATOR(slogdeterminant, ops::SlogDeterminantOp, + ops::SlogDeterminantOpMaker, + ops::SlogDeterminantGradOpMaker, + ops::SlogDeterminantGradOpMaker); + +REGISTER_OPERATOR(slogdeterminant_grad, + ops::DeterminantGradOp) // reuse det grad op + +REGISTER_OP_CPU_KERNEL( + slogdeterminant, ops::SlogDeterminantKernel, + ops::SlogDeterminantKernel); + +REGISTER_OP_CPU_KERNEL( + slogdeterminant_grad, + ops::DeterminantGradKernel, + ops::DeterminantGradKernel); diff --git a/paddle/fluid/operators/determinant_op.cu b/paddle/fluid/operators/determinant_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..f17d94d805228b6e4a296554df0ccce5e7573b84 --- /dev/null +++ b/paddle/fluid/operators/determinant_op.cu @@ -0,0 +1,72 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+
+template <typename T>
+__global__ void DeterminantGrad(const size_t numel, T* out) {
+  size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid < numel) {
+    out[tid] = static_cast<T>(1);
+  }
+}
+
+template <typename T>
+class DeterminantGradCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* dx = context.Output<Tensor>(framework::GradVarName("Input"));
+    T* dx_data = dx->mutable_data<T>(context.GetPlace());
+
+    int64_t numel = dx->numel();
+    // dx lives in device memory, so fill it with a kernel launch instead of
+    // writing it from the host.
+    auto& dev_ctx =
+        context.template device_context<platform::CUDADeviceContext>();
+    int threads = PADDLE_CUDA_NUM_THREADS;
+    int blocks = static_cast<int>((numel + threads - 1) / threads);
+    DeterminantGrad<T><<<blocks, threads, 0, dev_ctx.stream()>>>(numel,
+                                                                 dx_data);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+REGISTER_OP_CUDA_KERNEL(
+    determinant, ops::DeterminantKernel<plat::CUDADeviceContext, float>,
+    ops::DeterminantKernel<plat::CUDADeviceContext, double>);
+
+REGISTER_OP_CUDA_KERNEL(
+    determinant_grad,
+    ops::DeterminantGradKernel<plat::CUDADeviceContext, float>,
+    ops::DeterminantGradKernel<plat::CUDADeviceContext, double>);
+
+REGISTER_OP_CUDA_KERNEL(
+    slogdeterminant,
+    ops::SlogDeterminantKernel<plat::CUDADeviceContext, float>,
+    ops::SlogDeterminantKernel<plat::CUDADeviceContext, double>);
+
+REGISTER_OP_CUDA_KERNEL(
+    slogdeterminant_grad,
+    ops::SlogDeterminantGradKernel<plat::CUDADeviceContext, float>,
+    ops::SlogDeterminantGradKernel<plat::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/determinant_op.h b/paddle/fluid/operators/determinant_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..ead1262d9fe06042b59d973002ecc6f2775dfe43
--- /dev/null
+++ b/paddle/fluid/operators/determinant_op.h
@@ -0,0 +1,206 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <Eigen/Dense>
+#include <Eigen/LU>
+#include <algorithm>
+#include <cmath>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
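+
+// Branchless signum: (T(0) < val) - (val < T(0)) maps positives to 1,
+// zero to 0 and negatives to -1, e.g. sign(2.5) == 1, sign(-0.1) == -1,
+// sign(0.0) == 0.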
+template <typename T>
+T sign(T val) {
+  return static_cast<T>(T(0) < val) - (val < T(0));
+}
+
+template <typename T>
+class EigenMatrix {};
+
+template <>
+class EigenMatrix<float> {
+ public:
+  using MatrixType = Eigen::MatrixXf;
+};
+
+template <>
+class EigenMatrix<double> {
+ public:
+  using MatrixType = Eigen::MatrixXd;
+};
+
+inline int64_t GetBatchCount(const framework::DDim dims) {
+  int64_t batch_count = 1;
+  auto dim_size = dims.size();
+  PADDLE_ENFORCE_GE(dim_size, 2,
+                    platform::errors::InvalidArgument(
+                        "To get the number of batch square matrices, "
+                        "the size of dimension should be no less than 2, "
+                        "but received %d.",
+                        dim_size));
+
+  // Multiply the sizes of every dimension except the last two to get the
+  // batch count; e.g. a tensor of shape [3, 3, 3, 3] holds 9 matrices.
+  for (int64_t i = 0; i < dims.size() - 2; i++) {
+    batch_count *= dims[i];
+  }
+
+  return batch_count;
+}
+
+template <typename T>
+struct DeterminantFunctor {
+  void operator()(const Tensor& input, const framework::ExecutionContext ctx,
+                  int64_t rank, int64_t batch_count, Tensor* output) {
+    std::vector<T> input_vec;
+    std::vector<T> output_vec;
+    framework::TensorToVector(input, ctx.device_context(), &input_vec);
+    for (int64_t i = 0; i < batch_count; ++i) {  // maybe can be parallelized
+      auto begin_iter = input_vec.begin() + i * rank * rank;
+      auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
+      std::vector<T> sub_vec(begin_iter,
+                             end_iter);  // take out one square matrix's data
+      typename EigenMatrix<T>::MatrixType matrix(rank, rank);
+      for (int64_t row = 0; row < rank; ++row) {
+        for (int64_t col = 0; col < rank; ++col) {
+          matrix(row, col) = sub_vec[rank * row + col];
+        }
+      }
+      output_vec.push_back(matrix.determinant());
+    }
+    framework::TensorFromVector(output_vec, output);
+  }
+};
+
+template <typename T>
+class DeterminantKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<framework::Tensor>("Input");
+    auto input_dim = vectorize(input->dims());
+    auto input_dim_size = input_dim.size();
+    auto* output = context.Output<framework::Tensor>("Out");
+
+    auto batch_count = GetBatchCount(input->dims());
+    VLOG(2) << "input dim:" << input->dims();
+    PADDLE_ENFORCE_GE(
+        input_dim_size, 2,
+        platform::errors::InvalidArgument(
+            "the input matrix dimension size should be no less than 2."));
+    PADDLE_ENFORCE_EQ(input_dim[input_dim_size - 1],
+                      input_dim[input_dim_size - 2],
+                      platform::errors::InvalidArgument(
+                          "the input matrix should be a square matrix."));
+    auto rank = input_dim[input_dim_size - 1];  // square matrix side length
+    DeterminantFunctor<T>()(*input, context, rank, batch_count, output);
+    if (input_dim_size > 2) {
+      auto output_dims =
+          framework::slice_ddim(input->dims(), 0, input_dim_size - 2);
+      output->Resize(output_dims);
+    }
+    VLOG(2) << "output dim:" << output->dims();
+  }
+};
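+
+// Shape walkthrough for DeterminantKernel: an Input of shape [3, 3, 3, 3, 3]
+// is treated as 27 stacked 3x3 matrices (batch_count = 27, rank = 3), and
+// Out is resized to the batch dims [3, 3, 3].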
+
+template <typename T>
+class DeterminantGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "DeterminantGrad is not supported at this time."));
+  }
+};
+
+template <typename T>
+struct SlogDeterminantFunctor {
+  void operator()(const Tensor& input, const framework::ExecutionContext ctx,
+                  int rank, int batch_count, Tensor* output) {
+    std::vector<T> input_vec;
+    std::vector<T> sign_vec;
+    std::vector<T> log_vec;
+    std::vector<T> output_vec;
+    framework::TensorToVector(input, ctx.device_context(), &input_vec);
+    for (int i = 0; i < batch_count; ++i) {  // maybe can be parallelized
+      auto begin_iter = input_vec.begin() + i * rank * rank;
+      auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
+      std::vector<T> sub_vec(begin_iter,
+                             end_iter);  // take out one square matrix's data
+      typename EigenMatrix<T>::MatrixType matrix(rank, rank);
+      for (int row = 0; row < rank; ++row) {
+        for (int col = 0; col < rank; ++col) {
+          matrix(row, col) = sub_vec[rank * row + col];
+        }
+      }
+      VLOG(2) << "det value: " << matrix.determinant();
+      VLOG(2) << "matrix val: " << matrix;
+      auto det_val = matrix.determinant();
+      sign_vec.push_back(sign(det_val));
+      // log|det|; yields -inf when the determinant is zero.
+      log_vec.push_back(std::log(std::abs(det_val)));
+    }
+    // concatenate sign_vec and log_vec into the final output_vec
+    output_vec.insert(output_vec.end(), sign_vec.begin(), sign_vec.end());
+    output_vec.insert(output_vec.end(), log_vec.begin(), log_vec.end());
+    framework::TensorFromVector(output_vec, output);
+  }
+};
+
+template <typename T>
+class SlogDeterminantKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<framework::Tensor>("Input");
+    auto input_dim = vectorize(input->dims());
+    auto input_dim_size = input_dim.size();
+    auto* output = context.Output<framework::Tensor>("Out");
+
+    auto batch_count = GetBatchCount(input->dims());
+    VLOG(2) << "input dim:" << input->dims();
+    PADDLE_ENFORCE_GE(
+        input_dim_size, 2,
+        platform::errors::InvalidArgument(
+            "the input matrix dimension size should be no less than 2."));
+    PADDLE_ENFORCE_EQ(input_dim[input_dim_size - 1],
+                      input_dim[input_dim_size - 2],
+                      platform::errors::InvalidArgument(
+                          "the input matrix should be a square matrix."));
+    auto rank = input_dim[input_dim_size - 1];  // square matrix side length
+    SlogDeterminantFunctor<T>()(*input, context, rank, batch_count, output);
+    std::vector<int64_t> output_dim_vec(input_dim.begin(),
+                                        input_dim.end() - 2);
+    output_dim_vec.insert(output_dim_vec.begin(),
+                          2);  // [2, *batch], matching numpy's slogdet
+    auto output_dims = framework::make_ddim(output_dim_vec);
+    output->Resize(output_dims);
+    VLOG(2) << "output dim:" << output->dims();
+  }
+};
+
+template <typename T>
+class SlogDeterminantGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "SlogDeterminantGrad is not supported at this time."));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 6bd58ee558f0beb02e8f4a89ae8d9bd953235958..60e6c954c161e195066989b93e75156fbaff3651 100755
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -101,6 +101,8 @@ from .tensor.linalg import cholesky  # noqa: F401
 from .tensor.linalg import bmm  # noqa: F401
 from .tensor.linalg import histogram  # noqa: F401
 from .tensor.linalg import mv  # noqa: F401
+from .tensor.linalg import det  # noqa: F401
+from .tensor.linalg import slogdet  # noqa: F401
 from .tensor.linalg import multi_dot  # noqa: F401
 from .tensor.linalg import matrix_power  # noqa: F401
 from .tensor.linalg import svd  # noqa: F401
diff --git a/python/paddle/fluid/tests/unittests/test_determinant_op.py b/python/paddle/fluid/tests/unittests/test_determinant_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..c19d44eb030cfbcf935a9cb968dff6450f6c1e08
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_determinant_op.py
@@ -0,0 +1,155 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest, skip_check_grad_ci
+import paddle
+import paddle.nn.functional as F
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+import paddle.tensor as tensor
+
+paddle.enable_static()
+
+
+@skip_check_grad_ci(reason="determinant grad is in progress.")
+class TestDeterminantOp(OpTest):
+    def setUp(self):
+        self.init_data()
+        self.op_type = "determinant"
+        self.outputs = {'Out': self.target}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        pass
+
+    def init_data(self):
+        np.random.seed(0)
+        self.case = np.random.rand(3, 3, 3, 3, 3).astype('float64')
+        self.inputs = {'Input': self.case}
+        self.target = np.linalg.det(self.case)
+
+
+class TestDeterminantOpCase1(TestDeterminantOp):
+    def init_data(self):
+        np.random.seed(0)
+        self.case = np.random.rand(3, 3, 3, 3).astype(np.float32)
+        self.inputs = {'Input': self.case}
+        self.target = np.linalg.det(self.case)
+
+
+class TestDeterminantOpCase2(TestDeterminantOp):
+    def init_data(self):
+        np.random.seed(0)
+        self.case = np.random.rand(4, 2, 4, 4).astype('float64')
+        self.inputs = {'Input': self.case}
+        self.target = np.linalg.det(self.case)
+
+
+class TestDeterminantAPI(unittest.TestCase):
+    def setUp(self):
+        self.shape = [3, 3, 3, 3]
+        np.random.seed(0)
+        self.x = np.random.rand(3, 3, 3, 3).astype(np.float32)
+        self.place = paddle.CPUPlace()
+
+    def test_api_static(self):
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.fluid.data('X', self.shape)
+            out = paddle.linalg.det(x)
+            exe = paddle.static.Executor(self.place)
+            res = exe.run(feed={'X': self.x}, fetch_list=[out])
+        out_ref = np.linalg.det(self.x)
+        for out in res:
+            self.assertTrue(np.allclose(out, out_ref, rtol=1e-03))
+
+    def test_api_dygraph(self):
+        paddle.disable_static(self.place)
+        x_tensor = paddle.to_tensor(self.x)
+        out = paddle.linalg.det(x_tensor)
+        out_ref = np.linalg.det(self.x)
+        self.assertTrue(np.allclose(out.numpy(), out_ref, rtol=1e-03))
+        paddle.enable_static()
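+
+# np.linalg.slogdet returns a (sign, logabsdet) pair; wrapping it in
+# np.array(...) stacks the pair into shape [2, *batch], which is the layout
+# the slogdeterminant op produces.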
+
+
+@skip_check_grad_ci(reason="slogdeterminant grad is in progress.")
+class TestSlogDeterminantOp(OpTest):
+    def setUp(self):
+        self.op_type = "slogdeterminant"
+        self.init_data()
+        self.outputs = {'Out': self.target}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        pass
+
+    def init_data(self):
+        np.random.seed(0)
+        self.case = np.random.rand(3, 3, 3, 3).astype('float64')
+        self.inputs = {'Input': self.case}
+        self.target = np.array(np.linalg.slogdet(self.case))
+
+
+class TestSlogDeterminantOpCase1(TestSlogDeterminantOp):
+    def init_data(self):
+        np.random.seed(0)
+        self.case = np.random.rand(2, 2, 5, 5).astype(np.float32)
+        self.inputs = {'Input': self.case}
+        self.target = np.array(np.linalg.slogdet(self.case))
+
+
+class TestSlogDeterminantAPI(unittest.TestCase):
+    def setUp(self):
+        self.shape = [3, 3, 3, 3]
+        np.random.seed(0)
+        self.x = np.random.rand(3, 3, 3, 3).astype(np.float32)
+        self.place = paddle.CPUPlace()
+
+    def test_api_static(self):
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.fluid.data('X', self.shape)
+            out = paddle.linalg.slogdet(x)
+            exe = paddle.static.Executor(self.place)
+            res = exe.run(feed={'X': self.x}, fetch_list=[out])
+        out_ref = np.array(np.linalg.slogdet(self.x))
+        for out in res:
+            self.assertTrue(np.allclose(out, out_ref, rtol=1e-03))
+
+    def test_api_dygraph(self):
+        paddle.disable_static(self.place)
+        x_tensor = paddle.to_tensor(self.x)
+        out = paddle.linalg.slogdet(x_tensor)
+        out_ref = np.array(np.linalg.slogdet(self.x))
+        self.assertTrue(np.allclose(out.numpy(), out_ref, rtol=1e-03))
+        paddle.enable_static()
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/linalg.py b/python/paddle/linalg.py
index 361bee09def514bd3a6ed55276552b5ecc93be17..f12cafd3421d61439b21c7d10cc6707bf6d5465a 100644
--- a/python/paddle/linalg.py
+++ b/python/paddle/linalg.py
@@ -22,6 +22,8 @@ from .tensor.linalg import multi_dot  # noqa: F401
 from .tensor.linalg import matrix_rank
 from .tensor.linalg import svd
 from .tensor.linalg import eigh  # noqa: F401
+from .tensor.linalg import det
+from .tensor.linalg import slogdet
 from .tensor.linalg import pinv
 
 __all__ = [
@@ -34,6 +36,8 @@ __all__ = [
     'matrix_rank',
     'svd',
     'matrix_power',
+    'det',
+    'slogdet',
     'eigh',
     'pinv'
 ]
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index c7862f61894e5a8b24c5371e5f17e5331e1ca165..fbe6bd1697dbd434b9e367ef0e5abdb99176ad0a 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -14,7 +14,7 @@
 import numpy as np
 from ..fluid.layer_helper import LayerHelper
-from ..fluid.data_feeder import check_variable_and_dtype, check_type
+from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype
 from ..fluid.framework import in_dygraph_mode, _varbase_creator, Variable
 from ..fluid.layers import transpose, cast  # noqa: F401
@@ -1351,6 +1351,109 @@ def mv(x, vec, name=None):
     return out
 
 
+def det(x):
+    """
+    Calculates the determinant value of a square matrix or batches of square matrices.
+
+    Args:
+        x (Tensor): the input matrix of size `(n, n)` or the batch of matrices of size
+            `(*, n, n)` where `*` is one or more batch dimensions.
+
+    Returns:
+        y (Tensor): the determinant value of a square matrix or batches of square matrices.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.randn([3,3,3])
+
+            A = paddle.det(x)
+
+            print(A)
+
+            # [ 0.02547996,  2.52317095, -6.15900707]
+
+    """
+    if in_dygraph_mode():
+        return core.ops.determinant(x)
+
+    check_dtype(x.dtype, 'Input', ['float32', 'float64'], 'det')
+
+    input_shape = list(x.shape)
+    assert len(input_shape) >= 2, \
+        "The x must be at least 2-dimensional, " \
+        "but received Input x's dimension: %s.\n" % \
+        len(input_shape)
+
+    assert (input_shape[-1] == input_shape[-2]), \
+        "Expect the input to be a square matrix, " \
+        "but received a %s-by-%s matrix.\n" \
+        % (input_shape[-2], input_shape[-1])
+
+    helper = LayerHelper('determinant', **locals())
+    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+    helper.append_op(
+        type='determinant', inputs={'Input': [x]}, outputs={'Out': [out]})
+    return out
+
+
+def slogdet(x):
+    """
+    Calculates the sign and natural logarithm of the absolute value of the
+    determinant of a square matrix or batches of square matrices.
+    The determinant can be recovered as ``sign * exp(logabsdet)``.
+
+    Supports float32 and float64 inputs.
+
+    Note that for matrices with a zero determinant, this returns ``(0, -inf)``.
+
+    Args:
+        x (Tensor): the batch of matrices of size :math:`(*, n, n)`
+            where :math:`*` is one or more batch dimensions.
+
+    Returns:
+        y (Tensor): A tensor containing the sign of the determinant and the
+            natural logarithm of the absolute value of the determinant,
+            respectively.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.randn([3,3,3])
+
+            A = paddle.slogdet(x)
+
+            print(A)
+
+            # [[ 1.        ,  1.        , -1.        ],
+            #  [-0.98610914, -0.43010661, -0.10872950]]
+
+    """
+    if in_dygraph_mode():
+        return core.ops.slogdeterminant(x)
+
+    check_dtype(x.dtype, 'Input', ['float32', 'float64'], 'slogdet')
+
+    input_shape = list(x.shape)
+    assert len(input_shape) >= 2, \
+        "The x must be at least 2-dimensional, " \
+        "but received Input x's dimension: %s.\n" % \
+        len(input_shape)
+
+    assert (input_shape[-1] == input_shape[-2]), \
+        "Expect the input to be a square matrix, " \
+        "but received a %s-by-%s matrix.\n" \
+        % (input_shape[-2], input_shape[-1])
+
+    helper = LayerHelper('slogdeterminant', **locals())
+    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+    helper.append_op(
+        type='slogdeterminant', inputs={'Input': [x]}, outputs={'Out': [out]})
+    return out
+
+
 def svd(x, full_matrices=False, name=None):
     r"""
     Computes the singular value decomposition of one matrix or a batch of regular matrices.