diff --git a/paddle/fluid/operators/matrix_power_op.cc b/paddle/fluid/operators/matrix_power_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c65af3129f3646163925be95b27b9fec25207f8c
--- /dev/null
+++ b/paddle/fluid/operators/matrix_power_op.cc
@@ -0,0 +1,131 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/matrix_power_op.h"
+
+namespace paddle {
+namespace operators {
+
+class MatrixPowerOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "matrix_power");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "matrix_power");
+    auto dims = ctx->GetInputDim("X");
+    auto n_dim = dims.size();
+    PADDLE_ENFORCE_GE(n_dim, 2,
+                      platform::errors::InvalidArgument(
+                          "The Input(X) should have at least 2 dimensions. "
+                          "But received a %d dimension tensor.",
+                          n_dim));
+    PADDLE_ENFORCE_EQ(dims[n_dim - 2], dims[n_dim - 1],
+                      platform::errors::InvalidArgument(
+                          "The inner-most 2 dimensions of Input(X) should be "
+                          "equal (square matrices). But received X's "
+                          "shape[-2] = %d and shape[-1] = %d.",
+                          dims[n_dim - 2], dims[n_dim - 1]));
+    ctx->SetOutputDim("Out", dims);
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+};
+
+class MatrixPowerOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput(
+        "X",
+        "(Tensor), The input tensor of matrix_power op. Its shape should be "
+        "[*, M, M] where * is zero or more batch dimensions, and matrices "
+        "on the inner-most 2 dimensions all should be square matrices.");
+    AddOutput("Out",
+              "(Tensor), The output tensor of matrix_power op. It has the "
+              "same shape as the input.");
+    AddAttr<int>("n", "(int), The exponent used to calculate the power of X.");
+    AddComment(R"DOC(
+Matrix Power Operator.
+
+Computes the n-th power of a square matrix or a batch of square matrices.
+
+)DOC");
+  }
+};
+
+class MatrixPowerOpInferVarType
+    : public framework::PassInDtypeAndVarTypeToOutput {
+ protected:
+  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
+      const override {
+    static std::unordered_map<std::string, std::string> u_map{
+        {"X", /*->*/ "Out"}};
+    return u_map;
+  }
+};
+
+class MatrixPowerGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* context) const override {
+    OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "matrix_power_grad");
+    OP_INOUT_CHECK(context->HasInput("Out"), "Input", "Out",
+                   "matrix_power_grad");
+    OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")), "Input",
+                   "Out@GRAD", "matrix_power_grad");
+    auto x_dims = context->GetInputDim("X");
+    auto x_grad_name = framework::GradVarName("X");
+    if (context->HasOutput(x_grad_name)) {
+      context->SetOutputDim(x_grad_name, x_dims);
+    }
+  }
+};
+
+template <typename T>
+class MatrixPowerGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType(this->ForwardOpType() + "_grad");
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("Out", this->Output("Out"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetAttrMap(this->Attrs());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(matrix_power, ops::MatrixPowerOp, ops::MatrixPowerOpMaker,
+                  ops::MatrixPowerOpInferVarType,
+                  ops::MatrixPowerGradOpMaker<paddle::framework::OpDesc>,
+                  ops::MatrixPowerGradOpMaker<paddle::imperative::OpBase>);
+
+REGISTER_OPERATOR(matrix_power_grad, ops::MatrixPowerGradOp);
+
+REGISTER_OP_CPU_KERNEL(
+    matrix_power,
+    ops::MatrixPowerKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::MatrixPowerKernel<paddle::platform::CPUDeviceContext, double>);
+
+REGISTER_OP_CPU_KERNEL(
+    matrix_power_grad,
+    ops::MatrixPowerGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::MatrixPowerGradKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/matrix_power_op.cu b/paddle/fluid/operators/matrix_power_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..d972e9499dc88444e2addc9c9082d9e6fd496e08
--- /dev/null
+++ b/paddle/fluid/operators/matrix_power_op.cu
@@ -0,0 +1,27 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/matrix_power_op.h"
+
+namespace ops = paddle::operators;
+namespace plf = paddle::platform;
+
+REGISTER_OP_CUDA_KERNEL(
+    matrix_power,
+    ops::MatrixPowerKernel<plf::CUDADeviceContext, float>,
+    ops::MatrixPowerKernel<plf::CUDADeviceContext, double>);
+
+REGISTER_OP_CUDA_KERNEL(
+    matrix_power_grad,
+    ops::MatrixPowerGradKernel<plf::CUDADeviceContext, float>,
+    ops::MatrixPowerGradKernel<plf::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/matrix_power_op.h b/paddle/fluid/operators/matrix_power_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..6c4b8860bf8c6692183f350d1be4017029d90c9b
--- /dev/null
+++ b/paddle/fluid/operators/matrix_power_op.h
@@ -0,0 +1,277 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <memory>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/matrix_inverse.h"
+#include "paddle/fluid/platform/for_range.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+struct IdentityMatrixFunctor {
+  IdentityMatrixFunctor(const int m, T* output) : m_(m), output_(output) {}
+
+  HOSTDEVICE void operator()(size_t index) const {
+    const int row = index / m_ % m_;
+    const int col = index % m_;
+    output_[index] = col == row ? static_cast<T>(1) : static_cast<T>(0);
+  }
+
+  const int m_;
+  T* output_;
+};
+
+template <typename DeviceContext, typename T>
+void MatrixPowerFunction(const Tensor* X, const int n, Tensor* Out,
+                         const paddle::framework::ExecutionContext& ctx) {
+  const auto& x_dims = X->dims();
+  const int x_ndim = x_dims.size();
+  T* out_data = Out->mutable_data<T>(ctx.GetPlace());
+
+  auto& dev_ctx = ctx.template device_context<DeviceContext>();
+  platform::ForRange<DeviceContext> for_range(dev_ctx, X->numel());
+
+  if (n == 0) {
+    // Out = Identity Matrix
+    IdentityMatrixFunctor<T> functor(x_dims[x_ndim - 1], out_data);
+    for_range(functor);
+    return;
+  }
+
+  auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+
+  Tensor new_x = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
+  int new_n = n;
+  if (n > 0) {
+    // newX = X
+    framework::TensorCopy(*X, ctx.GetPlace(), dev_ctx, &new_x);
+  } else {
+    // newX = X^{-1}, n = -n
+    math::MatrixInverseFunctor<DeviceContext, T> mat_inv;
+    mat_inv(dev_ctx, *X, &new_x);
+    new_n = -n;
+  }
+
+  if (new_n == 1) {
+    framework::TensorCopy(new_x, ctx.GetPlace(), dev_ctx, Out);
+    return;
+  }
+
+  auto no_trans_desc = math::CreateMatrixDescriptor(x_dims, 0, false);
+
+  if (new_n == 2) {
+    // Out = newX * newX
+    Out->mutable_data<T>(ctx.GetPlace());
+    blas.MatMul(new_x, no_trans_desc, new_x, no_trans_desc, static_cast<T>(1),
+                Out, static_cast<T>(0));
+    return;
+  } else if (new_n == 3) {
+    // Out = (newX * newX) * newX
+    // Note: the output matrices of MatMul must not overlap with the inputs,
+    // i.e. each gemm must be computable independently; otherwise the
+    // behavior is undefined.
+    Tensor temp = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
+    blas.MatMul(new_x, no_trans_desc, new_x, no_trans_desc, static_cast<T>(1),
+                &temp, static_cast<T>(0));
+    blas.MatMul(temp, no_trans_desc, new_x, no_trans_desc, static_cast<T>(1),
+                Out, static_cast<T>(0));
+    return;
+  } else if (new_n == 4) {
+    // Out = (newX * newX) * (newX * newX)
+    Tensor temp = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
+    blas.MatMul(new_x, no_trans_desc, new_x, no_trans_desc, static_cast<T>(1),
+                &temp, static_cast<T>(0));
+    blas.MatMul(temp, no_trans_desc, temp, no_trans_desc, static_cast<T>(1),
+                Out, static_cast<T>(0));
+    return;
+  }
+
+  // Calculate Out = newX^{n} for abs(n) > 4 in O(log n) matrix products,
+  // using binary exponentiation over the bits of new_n.
+  int bit = 0;
+  Tensor z = Tensor(X->type());
+  bool out_inited = false;
+  Tensor temp_out = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
+  Tensor temp_z = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
+  while (new_n > 0) {
+    bit = new_n & 0x1;
+    new_n >>= 1;
+    if (z.IsInitialized()) {
+      blas.MatMul(z, no_trans_desc, z, no_trans_desc, static_cast<T>(1),
+                  &temp_z, static_cast<T>(0));
+      framework::TensorCopy(temp_z, ctx.GetPlace(), dev_ctx, &z);
+    } else {
+      z = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
+      framework::TensorCopy(new_x, ctx.GetPlace(), dev_ctx, &z);
+    }
+    if (bit == 1) {
+      if (out_inited == true) {
+        blas.MatMul(*Out, no_trans_desc, z, no_trans_desc, static_cast<T>(1),
+                    &temp_out, static_cast<T>(0));
+        framework::TensorCopy(temp_out, ctx.GetPlace(), dev_ctx, Out);
+      } else {
+        framework::TensorCopy(z, ctx.GetPlace(), dev_ctx, Out);
+        out_inited = true;
+      }
+    }
+  }
+  return;
+}
+
+template <typename DeviceContext, typename T>
+class MatrixPowerKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    const Tensor* X = ctx.Input<Tensor>("X");
+    Tensor* Out = ctx.Output<Tensor>("Out");
+    int n = ctx.Attr<int>("n");
+
+    const auto& x_dims = X->dims();
+    const int x_ndim = x_dims.size();
+    PADDLE_ENFORCE_EQ(
+        x_dims[x_ndim - 2], x_dims[x_ndim - 1],
+        platform::errors::InvalidArgument(
+            "The inner-most 2 dimensions of Input(X) should be equal. "
+            "X's shape[-2] = %d and shape[-1] = %d.",
+            x_dims[x_ndim - 2], x_dims[x_ndim - 1]));
+
+    MatrixPowerFunction<DeviceContext, T>(X, n, Out, ctx);
+  }
+};
+
+template <typename DeviceContext, typename T>
+void MatrixPowerGradFunction(const Tensor* X, const Tensor* Out,
+                             const Tensor* dOut, const int n, Tensor* dX,
+                             const paddle::framework::ExecutionContext& ctx) {
+  dX->mutable_data<T>(ctx.GetPlace());
+  const auto& x_dims = X->dims();
+
+  auto& dev_ctx = ctx.template device_context<DeviceContext>();
+  auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+
+  if (n == 0) {
+    // \nabla X = 0
+    math::SetConstant<DeviceContext, T> zero;
+    zero(dev_ctx, dX, static_cast<T>(0));
+    return;
+  } else if (n == 1) {
+    // \nabla X = \nabla Out
+    framework::TensorCopy(*dOut, ctx.GetPlace(), dev_ctx, dX);
+    return;
+  }
+
+  auto trans_desc = math::CreateMatrixDescriptor(x_dims, 0, true);
+  auto no_trans_desc = math::CreateMatrixDescriptor(x_dims, 0, false);
+
+  if (n == -1) {
+    // \nabla X = -Out^{T} * \nabla Out * Out^{T}
+    Tensor temp_dx =
+        ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
+    blas.MatMul(*Out, trans_desc, *dOut, no_trans_desc, static_cast<T>(-1),
+                &temp_dx, static_cast<T>(0));
+    blas.MatMul(temp_dx, no_trans_desc, *Out, trans_desc, static_cast<T>(1),
+                dX, static_cast<T>(0));
+    return;
+  }
+
+  Tensor new_x = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
+  int new_n = n;
+  if (n > 0) {
+    // newX = X
+    framework::TensorCopy(*X, ctx.GetPlace(), dev_ctx, &new_x);
+  } else {
+    // newX = X^{-1}, n = -n
+    math::MatrixInverseFunctor<DeviceContext, T> mat_inv;
+    mat_inv(dev_ctx, *X, &new_x);
+    new_n = -n;
+  }
+
+  // Use the chain rule below to compute \nabla newX^{n}.
+  // First, get newX^{0}, newX^{1}, ..., newX^{n - 1};
+  // note that newX^{0} can be omitted.
+  std::vector<std::shared_ptr<Tensor>> tensor_list(new_n - 1);
+  tensor_list[0] = std::make_shared<Tensor>(new_x);
+  int index = 1;
+  while (index < new_n - 1) {
+    tensor_list[index] = std::make_shared<Tensor>(
+        ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx));
+    blas.MatMul(*tensor_list[index - 1], no_trans_desc, new_x, no_trans_desc,
+                static_cast<T>(1), tensor_list[index].get(),
+                static_cast<T>(0));
+    index++;
+  }
+
+  // Second, \nabla newX = \sum_{i = 0}^{n - 1} (newX^{T})^{i}
+  //                       * \nabla Out
+  //                       * (newX^{T})^{n - i - 1}
+  Tensor dx_new = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
+  blas.MatMul(*tensor_list[new_n - 2], trans_desc, *dOut, no_trans_desc,
+              static_cast<T>(1), &dx_new, static_cast<T>(0));
+  Tensor da_an_minus1 =
+      ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
+  blas.MatMul(*dOut, no_trans_desc, *tensor_list[new_n - 2], trans_desc,
+              static_cast<T>(1), &da_an_minus1, static_cast<T>(0));
+  blas.AXPY(X->numel(), static_cast<T>(1), da_an_minus1.data<T>(),
+            dx_new.data<T>());
+  int start = 0;
+  while (start < new_n - 2) {
+    Tensor a_da = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
+    Tensor a_da_a =
+        ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
+    blas.MatMul(*tensor_list[start], trans_desc, *dOut, no_trans_desc,
+                static_cast<T>(1), &a_da, static_cast<T>(0));
+    blas.MatMul(a_da, no_trans_desc, *tensor_list[new_n - 3 - start],
+                trans_desc, static_cast<T>(1), &a_da_a, static_cast<T>(0));
+    blas.AXPY(X->numel(), static_cast<T>(1), a_da_a.data<T>(),
+              dx_new.data<T>());
+    start++;
+  }
+
+  if (n > 0) {
+    // \nabla X = \nabla newX
+    framework::TensorCopy(dx_new, ctx.GetPlace(), dev_ctx, dX);
+  } else {
+    // \nabla X = -newX^{T} * \nabla newX * newX^{T}
+    Tensor temp_dx =
+        ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
+    blas.MatMul(new_x, trans_desc, dx_new, no_trans_desc, static_cast<T>(-1),
+                &temp_dx, static_cast<T>(0));
+    blas.MatMul(temp_dx, no_trans_desc, new_x, trans_desc, static_cast<T>(1),
+                dX, static_cast<T>(0));
+  }
+  return;
+}
+
+template <typename DeviceContext, typename T>
+class MatrixPowerGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const Tensor* X = ctx.Input<Tensor>("X");
+    const Tensor* Out = ctx.Input<Tensor>("Out");
+    const Tensor* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    const int n = ctx.Attr<int>("n");
+    Tensor* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
+
+    MatrixPowerGradFunction<DeviceContext, T>(X, Out, dOut, n, dX, ctx);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 27a414e092802d48544e05678996d5f24587fae7..1c38d5197986669398b23b7516d88f9ff6dafa61 100755
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -99,6 +99,7 @@ from .tensor.linalg import cholesky  # noqa: F401
 from .tensor.linalg import bmm  # noqa: F401
 from .tensor.linalg import histogram  # noqa: F401
 from .tensor.linalg import mv  # noqa: F401
+from .tensor.linalg import matrix_power  # noqa: F401
 from .tensor.logic import equal  # noqa: F401
 from .tensor.logic import greater_equal  # noqa: F401
 from .tensor.logic import greater_than  # noqa: F401
@@ -491,6 +492,7 @@ __all__ = [  # noqa
     'stack',
     'sqrt',
     'cholesky',
+    'matrix_power',
     'randperm',
     'linspace',
     'reshape',
diff --git a/python/paddle/fluid/tests/unittests/test_matrix_power_op.py b/python/paddle/fluid/tests/unittests/test_matrix_power_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..96823f49d2f08b094c997af4f81a2e725ab85efb
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_matrix_power_op.py
@@ -0,0 +1,353 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+import paddle
+from op_test import OpTest
+
+paddle.enable_static()
+
+
+class TestMatrixPowerOp(OpTest):
+    def config(self):
+        self.matrix_shape = [10, 10]
+        self.dtype = "float64"
+        self.n = 0
+
+    def setUp(self):
+        self.op_type = "matrix_power"
+        self.config()
+
+        np.random.seed(123)
+        mat = np.random.random(self.matrix_shape).astype(self.dtype)
+        powered_mat = np.linalg.matrix_power(mat, self.n)
+
+        self.inputs = {"X": mat}
+        self.outputs = {"Out": powered_mat}
+        self.attrs = {"n": self.n}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_grad(self):
+        self.check_grad(
+            ["X"], "Out", numeric_grad_delta=1e-5, max_relative_error=1e-7)
+
+
+class TestMatrixPowerOpN1(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [10, 10]
+        self.dtype = "float64"
+        self.n = 1
+
+
+class TestMatrixPowerOpN2(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [10, 10]
+        self.dtype = "float64"
+        self.n = 2
+
+
+class TestMatrixPowerOpN3(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [10, 10]
+        self.dtype = "float64"
+        self.n = 3
+
+
+class TestMatrixPowerOpN4(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [10, 10]
+        self.dtype = "float64"
+        self.n = 4
+
+
+class TestMatrixPowerOpN5(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [10, 10]
+        self.dtype = "float64"
+        self.n = 5
+
+
+class TestMatrixPowerOpN6(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [10, 10]
+        self.dtype = "float64"
+        self.n = 6
+
+
+class TestMatrixPowerOpN10(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [10, 10]
+        self.dtype = "float64"
+        self.n = 10
+
+
+class TestMatrixPowerOpNMinus(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [10, 10]
+        self.dtype = "float64"
+        self.n = -1
+
+    def test_grad(self):
+        self.check_grad(
+            ["X"], "Out", numeric_grad_delta=1e-5, max_relative_error=1e-6)
+
+
+class TestMatrixPowerOpNMinus2(TestMatrixPowerOpNMinus):
+    def config(self):
+        self.matrix_shape = [10, 10]
+        self.dtype = "float64"
+        self.n = -2
+
+
+class TestMatrixPowerOpNMinus3(TestMatrixPowerOpNMinus):
+    def config(self):
+        self.matrix_shape = [10, 10]
+        self.dtype = "float64"
+        self.n = -3
+
+
+class TestMatrixPowerOpNMinus4(TestMatrixPowerOpNMinus):
+    def config(self):
+        self.matrix_shape = [10, 10]
+        self.dtype = "float64"
+        self.n = -4
+
+
+class TestMatrixPowerOpNMinus5(TestMatrixPowerOpNMinus):
+    def config(self):
+        self.matrix_shape = [10, 10]
+        self.dtype = "float64"
+        self.n = -5
+
+
+class TestMatrixPowerOpNMinus6(TestMatrixPowerOpNMinus):
+    def config(self):
+        self.matrix_shape = [10, 10]
+        self.dtype = "float64"
+        self.n = -6
+
+
+class TestMatrixPowerOpNMinus10(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [10, 10]
+        self.dtype = "float64"
+        self.n = -10
+
+    def test_grad(self):
+        self.check_grad(
+            ["X"], "Out", numeric_grad_delta=1e-5, max_relative_error=1e-6)
+
+
+class TestMatrixPowerOpBatched1(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [8, 4, 4]
+        self.dtype = "float64"
+        self.n = 5
+
+
+class TestMatrixPowerOpBatched2(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [2, 6, 4, 4]
+        self.dtype = "float64"
+        self.n = 4
+
+
+class TestMatrixPowerOpBatched3(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [2, 6, 4, 4]
+        self.dtype = "float64"
+        self.n = 0
+
+
+class TestMatrixPowerOpBatchedLong(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [1, 2, 3, 4, 4, 3, 3]
+        self.dtype = "float64"
+        self.n = 3
+
+
+class TestMatrixPowerOpLarge1(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [32, 32]
+        self.dtype = "float64"
+        self.n = 3
+
+
+class TestMatrixPowerOpLarge2(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [10, 10]
+        self.dtype = "float64"
+        self.n = 32
+
+
+class TestMatrixPowerOpFP32(TestMatrixPowerOp):
+    def config(self):
+        self.matrix_shape = [10, 10]
+        self.dtype = "float32"
+        self.n = 2
+
+    def test_grad(self):
+        self.check_grad(["X"], "Out", max_relative_error=1e-2)
+
+
+class TestMatrixPowerOpBatchedFP32(TestMatrixPowerOpFP32):
+    def config(self):
+        self.matrix_shape = [2, 8, 4, 4]
+        self.dtype = "float32"
+        self.n = 2
+
+
+class TestMatrixPowerOpLarge1FP32(TestMatrixPowerOpFP32):
+    def config(self):
+        self.matrix_shape = [32, 32]
+        self.dtype = "float32"
+        self.n = 2
+
+
+class TestMatrixPowerOpLarge2FP32(TestMatrixPowerOpFP32):
+    def config(self):
+        self.matrix_shape = [10, 10]
+        self.dtype = "float32"
+        self.n = 32
+
+
+class TestMatrixPowerOpFP32Minus(TestMatrixPowerOpFP32):
+    def config(self):
+        self.matrix_shape = [10, 10]
+        self.dtype = "float32"
+        self.n = -1
+
+
+class TestMatrixPowerAPI(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(123)
+        self.places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            self.places.append(fluid.CUDAPlace(0))
+
+    def check_static_result(self, place):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            input_x = fluid.data(name="input_x", shape=[4, 4], dtype="float64")
+            result = paddle.linalg.matrix_power(x=input_x, n=-2)
+            input_np = np.random.random([4, 4]).astype("float64")
+            result_np = np.linalg.matrix_power(input_np, -2)
+
+            exe = fluid.Executor(place)
+            fetches = exe.run(fluid.default_main_program(),
+                              feed={"input_x": input_np},
+                              fetch_list=[result])
+            self.assertTrue(
+                np.allclose(fetches[0], np.linalg.matrix_power(input_np, -2)))
+
+    def test_static(self):
+        for place in self.places:
+            self.check_static_result(place=place)
+
+    def test_dygraph(self):
+        for place in self.places:
+            with fluid.dygraph.guard(place):
+                input_np = np.random.random([4, 4]).astype("float64")
+                input = paddle.to_tensor(input_np)
+                result = paddle.linalg.matrix_power(input, -2)
+                self.assertTrue(
+                    np.allclose(result.numpy(),
+                                np.linalg.matrix_power(input_np, -2)))
+
+
+class TestMatrixPowerAPIError(unittest.TestCase):
+    def test_errors(self):
+        input_np = np.random.random([4, 4]).astype("float64")
+
+        # input must be Variable.
+        self.assertRaises(TypeError, paddle.linalg.matrix_power, input_np)
+
+        # n must be int
+        for n in [2.0, '2', -2.0]:
+            input = fluid.data(
+                name="input_float32", shape=[4, 4], dtype='float32')
+            self.assertRaises(TypeError, paddle.linalg.matrix_power, input, n)
+
+        # The data type of input must be float32 or float64.
+        for dtype in ["bool", "int32", "int64", "float16"]:
+            input = fluid.data(name="input_" + dtype, shape=[4, 4], dtype=dtype)
+            self.assertRaises(TypeError, paddle.linalg.matrix_power, input, 2)
+
+        # When out is set, the data type must be the same as input.
+        input = fluid.data(name="input_1", shape=[4, 4], dtype="float32")
+        out = fluid.data(name="output", shape=[4, 4], dtype="float64")
+        self.assertRaises(TypeError, paddle.linalg.matrix_power, input, 2, out)
+
+        # The number of dimensions of input must be >= 2.
+        input = fluid.data(name="input_2", shape=[4], dtype="float32")
+        self.assertRaises(ValueError, paddle.linalg.matrix_power, input, 2)
+
+        # The inner-most 2 dimensions of input should be equal to each other.
+        input = fluid.data(name="input_3", shape=[4, 5], dtype="float32")
+        self.assertRaises(ValueError, paddle.linalg.matrix_power, input, 2)
+
+
+class TestMatrixPowerSingularAPI(unittest.TestCase):
+    def setUp(self):
+        self.places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            self.places.append(fluid.CUDAPlace(0))
+
+    def check_static_result(self, place):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            input = fluid.data(name="input", shape=[4, 4], dtype="float64")
+            result = paddle.linalg.matrix_power(x=input, n=-2)
+
+            input_np = np.zeros([4, 4]).astype("float64")
+
+            exe = fluid.Executor(place)
+            try:
+                fetches = exe.run(fluid.default_main_program(),
+                                  feed={"input": input_np},
+                                  fetch_list=[result])
+            except RuntimeError as ex:
+                print("The mat is singular")
+                pass
+            except ValueError as ex:
+                print("The mat is singular")
+                pass
+
+    def test_static(self):
+        paddle.enable_static()
+        for place in self.places:
+            self.check_static_result(place=place)
+        paddle.disable_static()
+
+    def test_dygraph(self):
+        for place in self.places:
+            with fluid.dygraph.guard(place):
+                input_np = np.ones([4, 4]).astype("float64")
+                input = fluid.dygraph.to_variable(input_np)
+                try:
+                    result = paddle.linalg.matrix_power(input, -2)
+                except RuntimeError as ex:
+                    print("The mat is singular")
+                    pass
+                except ValueError as ex:
+                    print("The mat is singular")
+                    pass
+
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py
index c771531b7b61be7933b5355204c532b847b13dc5..929a9696d1c12dc8a638844ae0a9739dcabe9dee 100644
--- a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py
+++ b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py
@@ -46,6 +46,7 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [
     'cudnn_lstm', \
     'rnn', \
     'lgamma', \
+    'matrix_power', \
 ]
 
 NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp',\
diff --git a/python/paddle/linalg.py b/python/paddle/linalg.py
index 5cef01d18aca48a6e777f7b85da304351426f495..ec6b7aa9e3d8212d26ea3536f9be594bb1a6f629 100644
--- a/python/paddle/linalg.py
+++ b/python/paddle/linalg.py
@@ -14,10 +14,12 @@
 
 from .tensor.linalg import cholesky  # noqa: F401
 from .tensor.linalg import norm  # noqa: F401
+from .tensor.linalg import matrix_power  # noqa: F401
 from .tensor import inverse as inv  # noqa: F401
 
 __all__ = [
     'cholesky',  #noqa
     'norm',
-    'inv'
+    'inv',
+    'matrix_power'
 ]
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index bcb508d11922fc9613953a90ef03133b133cb689..cc20e98006fec40e960b1ef93da37bed871ce476 100755
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -44,6 +44,7 @@ from .linalg import cholesky  # noqa: F401
 from .linalg import bmm  # noqa: F401
 from .linalg import histogram  # noqa: F401
 from .linalg import mv  # noqa: F401
+from .linalg import matrix_power  # noqa: F401
 from .logic import equal  # noqa: F401
 from .logic import greater_equal  # noqa: F401
 from .logic import greater_than  # noqa: F401
@@ -220,6 +221,7 @@ tensor_method_func = [  #noqa
     'bmm',
     'histogram',
     'mv',
+    'matrix_power',
     'abs',
     'acos',
     'all',
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index a1610581b67c033c5af46b88e270614a4b8cc1b7..74d9876cddd5cbdd47c686f496511e1025639cef 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -941,3 +941,73 @@ def mv(x, vec, name=None):
         type='mv', inputs={'X': x, 'Vec': vec}, outputs={'Out': out})
     return out
+
+
+def matrix_power(x, n, name=None):
+    r"""
+    Computes the n-th power of a square matrix or a batch of square matrices.
+
+    Let :math:`X` be a square matrix or a batch of square matrices, and :math:`n` be
+    an exponent, then:
+
+    .. math::
+        Out = X ^ {n}
+
+    Specifically,
+
+    - If `n > 0`, it returns the matrix or a batch of matrices raised to the power
+      of `n`.
+
+    - If `n = 0`, it returns the identity matrix or a batch of identity matrices.
+
+    - If `n < 0`, it returns the inverse of each matrix (if invertible) raised to
+      the power of `abs(n)`.
+
+    Args:
+        x (Tensor): A square matrix or a batch of square matrices to be raised
+            to power `n`. Its shape should be `[*, M, M]`, where `*` is zero or
+            more batch dimensions. Its data type should be float32 or float64.
+        n (int): The exponent. It can be any positive or negative integer, or zero.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor: The n-th power of the matrix (or the batch of matrices) `x`. Its
+            data type is the same as that of `x`.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.to_tensor([[1, 2, 3],
+                                  [1, 4, 9],
+                                  [1, 8, 27]], dtype='float64')
+            print(paddle.matrix_power(x, 2))
+            # [[6.  , 34. , 102.],
+            #  [14. , 90. , 282.],
+            #  [36. , 250., 804.]]
+
+            print(paddle.matrix_power(x, 0))
+            # [[1., 0., 0.],
+            #  [0., 1., 0.],
+            #  [0., 0., 1.]]
+
+            print(paddle.matrix_power(x, -2))
+            # [[ 12.91666667, -12.75000000,  2.83333333 ],
+            #  [-7.66666667 ,  8.         , -1.83333333 ],
+            #  [ 1.80555556 , -1.91666667 ,  0.44444444 ]]
+    """
+    if in_dygraph_mode():
+        return core.ops.matrix_power(x, "n", n)
+
+    check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'matrix_power')
+    check_type(n, 'n', int, 'matrix_power')
+    helper = LayerHelper('matrix_power', **locals())
+    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    helper.append_op(
+        type='matrix_power',
+        inputs={'X': x},
+        outputs={'Out': out},
+        attrs={'n': n})
+    return out
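
Reviewer note (not part of the patch): the two pieces of math the kernels above rely on are the binary-exponentiation loop used for abs(n) > 4 and the backward formula \nabla X = \sum_{i=0}^{n-1} (X^T)^i * \nabla Out * (X^T)^{n-1-i}. The NumPy sketch below is only an illustration of those formulas under assumed helper names (matrix_power_binexp, matrix_power_grad); it is cross-checked against np.linalg.matrix_power and a finite-difference estimate, and is not the operator implementation itself.

# Standalone NumPy sketch; all names are illustrative, not from the PR.
import numpy as np


def matrix_power_binexp(x, n):
    """Compute x^n by repeated squaring, mirroring the O(log n) loop in the kernel."""
    if n < 0:
        x, n = np.linalg.inv(x), -n
    out = np.eye(x.shape[-1], dtype=x.dtype)
    z = x
    while n > 0:
        if n & 1:           # multiply the current power into the result
            out = out @ z
        z = z @ z           # square the running factor
        n >>= 1
    return out


def matrix_power_grad(x, n, dout):
    """Gradient of L = sum(dout * x^n) w.r.t. x, for n > 0."""
    powers = [np.eye(x.shape[-1], dtype=x.dtype)]  # x^0, x^1, ..., x^(n-1)
    for _ in range(n - 1):
        powers.append(powers[-1] @ x)
    return sum(powers[i].T @ dout @ powers[n - 1 - i].T for i in range(n))


np.random.seed(0)
a = np.random.rand(4, 4)
assert np.allclose(matrix_power_binexp(a, 7), np.linalg.matrix_power(a, 7))

# Finite-difference check of the gradient formula for n = 3.
n, dout, eps = 3, np.random.rand(4, 4), 1e-6
g = matrix_power_grad(a, n, dout)
for i in range(4):
    for j in range(4):
        a_p = a.copy(); a_p[i, j] += eps
        a_m = a.copy(); a_m[i, j] -= eps
        num = np.sum(dout * (np.linalg.matrix_power(a_p, n) -
                             np.linalg.matrix_power(a_m, n))) / (2 * eps)
        assert abs(num - g[i, j]) < 1e-4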