From c9f7cff03d9b4093f785b6824265769e4634fc35 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 16 Sep 2021 09:21:43 +0800 Subject: [PATCH] Add a new op: paddle.linalg.multi_dot (#35224) --- paddle/fluid/operators/multi_dot_op.cc | 567 ++++++++++++++++++ python/paddle/__init__.py | 1 + .../tests/unittests/test_multi_dot_op.py | 263 ++++++++ .../white_list/check_shape_white_list.py | 1 + python/paddle/linalg.py | 2 + python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/linalg.py | 106 +++- 7 files changed, 929 insertions(+), 13 deletions(-) create mode 100644 paddle/fluid/operators/multi_dot_op.cc create mode 100644 python/paddle/fluid/tests/unittests/test_multi_dot_op.py diff --git a/paddle/fluid/operators/multi_dot_op.cc b/paddle/fluid/operators/multi_dot_op.cc new file mode 100644 index 0000000000..2d06170d34 --- /dev/null +++ b/paddle/fluid/operators/multi_dot_op.cc @@ -0,0 +1,567 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/strided_memcpy.h" +#include "paddle/fluid/operators/utils.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +/** + * @brief compute the output shape and check the input shape valid or not + */ +inline framework::DDim ComputeAndCheckShape( + const bool is_runtime, const std::vector& inputs_dims) { + const size_t n = inputs_dims.size(); + auto first_dim = inputs_dims[0]; + + bool is_vector = false; + framework::DDim out_dim; + + PADDLE_ENFORCE_LT( + first_dim.size(), static_cast(3), + platform::errors::InvalidArgument( + "multi_dot: the first input tensor must be 1D or 2D but got[%d]!", + static_cast(first_dim.size()))); + + // If the first tensor is 1D of size n view it as a row vector (1, n) + if (first_dim.size() == 1) { + first_dim = framework::make_ddim({1, static_cast(first_dim[0])}); + is_vector = true; + } + + auto last_dim = inputs_dims[n - 1]; + PADDLE_ENFORCE_LT( + last_dim.size(), static_cast(3), + platform::errors::InvalidArgument( + "the last input tensor of multi_dot must be 1D or 2D but got[%d]!", + static_cast(first_dim.size()))); + + // If the last tensor is 1D of size n view it as a column vector (n, 1) + if (last_dim.size() == 1) { + last_dim = framework::make_ddim({static_cast(last_dim[0]), 1}); + out_dim = is_vector ? framework::make_ddim({1}) + : framework::make_ddim({first_dim[0]}); + } else { + out_dim = is_vector ? 
framework::make_ddim({last_dim[1]})
+                        : framework::make_ddim({first_dim[0], last_dim[1]});
+  }
+
+  auto width = first_dim[1];
+  for (size_t i = 1; i < n - 1; i++) {
+    PADDLE_ENFORCE_EQ(inputs_dims[i].size(), static_cast<int>(2),
+                      platform::errors::InvalidArgument(
+                          "the input tensor of multi_dot op must be 2D."));
+
+    const auto& tmp_dim = inputs_dims[i];
+    PADDLE_ENFORCE_EQ(
+        tmp_dim[0], width,
+        platform::errors::InvalidArgument(
+            "the input matrix does not meet the multiplication requirements."));
+    width = tmp_dim[1];
+  }
+
+  PADDLE_ENFORCE_EQ(
+      last_dim[0], width,
+      platform::errors::InvalidArgument(
+          "the input matrix does not meet the multiplication requirements."));
+
+  return out_dim;
+}
+
+template <typename DeviceContext, typename T>
+inline framework::Tensor MatMul(const framework::ExecutionContext& ctx,
+                                const framework::Tensor& matrix_a,
+                                const framework::Tensor& matrix_b,
+                                const framework::DDim& a_dim,
+                                const framework::DDim& b_dim) {
+  auto place = ctx.GetPlace();
+  auto blas = math::GetBlas<DeviceContext, T>(ctx);
+
+  framework::Tensor matrix_c;
+  framework::DDim c_dim = framework::make_ddim({a_dim[0], b_dim[1]});
+  matrix_c.Resize(c_dim);
+  matrix_c.mutable_data<T>(place);
+
+  auto mat_dim_a = math::CreateMatrixDescriptor(a_dim, 0, false);
+  auto mat_dim_b = math::CreateMatrixDescriptor(b_dim, 0, false);
+  const T alpha = static_cast<T>(1.0);
+  blas.MatMul(matrix_a, mat_dim_a, matrix_b, mat_dim_b, alpha, &matrix_c, T(0));
+  return matrix_c;
+}
+
+/**
+ * @brief Recursively calculate matrix multiplication according to the optimal
+ * order
+ * Let k = order[i,j], then ins[i...j] = ins[i...k] * ins[k+1...j]
+ *
+ * @param
+ * ins: the input tensors
+ * ins_dims: the shape of ins after reshape
+ * order: the optimal order
+ * i: the left end of the sub-chain
+ * j: the right end of the sub-chain
+ * save_result: set true by backward
+ * results: save the intermediate results during backward
+ */
+template <typename DeviceContext, typename T>
+inline framework::Tensor MatChainMul(
+    const framework::ExecutionContext& ctx,
+    const std::vector<const framework::Tensor*>& ins,
+    const std::vector<framework::DDim>& ins_dims,
+    const std::vector<uint64_t>& order, const uint64_t i, const uint64_t j,
+    const bool save_result, std::vector<framework::Tensor>* results) {
+  if (i == j) {
+    return *ins[i];
+  }
+
+  const auto A = MatChainMul<DeviceContext, T>(ctx, ins, ins_dims, order, i,
+                                               order[i * ins.size() + j],
+                                               save_result, results);
+  framework::DDim a_dim = A.dims();
+  if (i == order[i * ins.size() + j]) {
+    a_dim = ins_dims[i];
+  }
+
+  const auto B = MatChainMul<DeviceContext, T>(ctx, ins, ins_dims, order,
+                                               order[i * ins.size() + j] + 1, j,
+                                               save_result, results);
+  framework::DDim b_dim = B.dims();
+  if (j == order[i * ins.size() + j] + 1) {
+    b_dim = ins_dims[j];
+  }
+
+  auto result = MatMul<DeviceContext, T>(ctx, A, B, a_dim, b_dim);
+  if (save_result) {
+    (*results)[i * ins.size() + j] = result;
+  }
+  return result;
+}
+
+/**
+ * @brief get the optimal order
+ */
+std::vector<uint64_t> GetOrder(
+    const std::vector<const framework::Tensor*>& ins,
+    const std::vector<framework::DDim>& ins_dims) {
+  auto n = ins.size();
+  // p: save the ins shapes, the ins[i] shape is (p[i], p[i+1])
+  std::vector<uint64_t> p(n + 1);
+  for (uint64_t i = 0; i < n; i++) {
+    p[i] = ins_dims[i][0];
+  }
+  p[n] = ins_dims[n - 1][1];
+
+  // m[i, j]: save the lowest cost for multiplying ins[i...j]
+  std::vector<uint64_t> m(n * n, 0);
+  // ins[i...j] denotes multiplying the matrices from ins[i] to ins[j];
+  // order[i, j] = k means that multiplying ins[i...k] and ins[k+1...j] first
+  // and then multiplying the two results is the optimal order for ins[i...j]
+  std::vector<uint64_t> order(n * n);
+  for (uint64_t l = 1; l < n; l++) {
+    for (uint64_t i = 0; i < n - l; i++) {
+      auto j = i + l;
+      m[i * n + j] = 0xffffffff;
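+      // Try every split point k: computing ins[i...j] as
+      // (ins[i...k]) * (ins[k+1...j]) costs m[i*n+k] + m[(k+1)*n+j] plus
+      // p[i] * p[k+1] * p[j+1] for the final multiply; keep the cheapest k
+      // in order[i*n+j].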
for (uint64_t k = i; k < j; k++) { + uint64_t q = + m[i * n + k] + m[(k + 1) * n + j] + p[i] * p[k + 1] * p[j + 1]; + if (q < m[i * n + j]) { + m[i * n + j] = q; + order[i * n + j] = k; + } + } + } + } + return order; +} + +template +static inline framework::Tensor MultiDotMatChainOrder( + const framework::ExecutionContext& ctx, + const std::vector& ins, + const std::vector& ins_dims, const bool save_result, + std::vector* results) { + auto order = GetOrder(ins, ins_dims); + return MatChainMul(ctx, ins, ins_dims, order, 0, + ins.size() - 1, save_result, results); +} + +inline void GetDims(const std::vector& ins, + std::vector* ins_dims) { + const auto n = ins.size(); + for (size_t i = 0; i < n; i++) { + (*ins_dims)[i] = ins[i]->dims(); + if (i == 0 && (*ins_dims)[i].size() == 1) { + (*ins_dims)[i] = framework::make_ddim({1, (*ins_dims)[i][0]}); + } else if (i == n - 1 && (*ins_dims)[i].size() == 1) { + (*ins_dims)[i] = framework::make_ddim({(*ins_dims)[i][0], 1}); + } + } +} + +class MultiDotOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input tensors of multi_dot operator.").AsDuplicable(); + AddOutput("Out", "The output tensor of multi_dot operator"); + AddComment(R"DOC( +Compute the dot product of two or more arrays in a single function call, while automatically selecting the fastest evaluation order. + +multi_dot chains MatMul and uses optimal parenthesization of the matrices [1] [2]. Depending on the shapes of the matrices, this can speed up the multiplication a lot. + +If the first argument is 1-D it is treated as a row vector. If the last argument is 1-D it is treated as a column vector. The other arguments must be 2-D. + )DOC"); + } +}; + +class MultiDotOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "multi_dot"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "multi_dot"); + + auto inputs_dims = ctx->GetInputsDim("X"); + + const size_t inputs_num = inputs_dims.size(); + PADDLE_ENFORCE_GT( + inputs_num, static_cast(1), + platform::errors::InvalidArgument( + "The number of input tensors in multi_dot op should > 1.")); + auto out_dims = ComputeAndCheckShape(ctx->IsRuntime(), inputs_dims); + ctx->SetOutputDim("Out", out_dims); + ctx->ShareLoD("X", "Out"); + } +}; + +/** + * 1. there are only 2 matrices: direct matrix multiplication A*B + * 2. there are only 3 matrices: calculate the cost of (A*B)*C and A*(B*C), + * choose the least cost order for calculation + * 3. 
more than 3 matrices: call MultiDotMatChainOrder + */ +template +class MultiDotKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto ins = ctx.MultiInput("X"); + auto* out = ctx.Output("Out"); + + auto place = ctx.GetPlace(); + out->mutable_data(place); + + auto blas = math::GetBlas(ctx); + + auto n = ins.size(); + std::vector ins_dims(n); + GetDims(ins, &ins_dims); + + const T scale = static_cast(1.0); + if (n == 2) { + auto mat_dim_a = math::CreateMatrixDescriptor(ins_dims[0], 0, false); + auto mat_dim_b = math::CreateMatrixDescriptor(ins_dims[1], 0, false); + blas.MatMul(*ins[0], mat_dim_a, *ins[1], mat_dim_b, scale, out, T(0)); + } else if (n == 3) { + const auto Ma = ins_dims[0][0]; + const auto Ka = ins_dims[0][1]; + const auto Nb = ins_dims[1][1]; + const auto Nc = ins_dims[2][1]; + const uint64_t cost1 = Ma * Nb * (Ka + Nc); + const uint64_t cost2 = Ka * Nc * (Nb + Ma); + auto mat_dim_a = math::CreateMatrixDescriptor(ins_dims[0], 0, false); + auto mat_dim_b = math::CreateMatrixDescriptor(ins_dims[1], 0, false); + auto mat_dim_c = math::CreateMatrixDescriptor(ins_dims[2], 0, false); + if (cost1 < cost2) { + framework::Tensor tmp_out; + tmp_out.mutable_data(place, Ma * Nb * sizeof(T)); + framework::DDim tmp_dim = framework::make_ddim({Ma, Nb}); + blas.MatMul(*ins[0], mat_dim_a, *ins[1], mat_dim_b, scale, &tmp_out, + T(0)); + auto mat_dim_tmp = math::CreateMatrixDescriptor(tmp_dim, 0, false); + blas.MatMul(tmp_out, mat_dim_tmp, *ins[2], mat_dim_c, scale, out, T(0)); + } else { + framework::Tensor tmp_out; + tmp_out.mutable_data(place, Ka * Nc * sizeof(T)); + framework::DDim tmp_dim = framework::make_ddim({Ka, Nc}); + blas.MatMul(*ins[1], mat_dim_b, *ins[2], mat_dim_c, scale, &tmp_out, + T(0)); + auto mat_dim_tmp = math::CreateMatrixDescriptor(tmp_dim, 0, false); + blas.MatMul(*ins[0], mat_dim_a, tmp_out, mat_dim_tmp, scale, out, T(0)); + } + } else { + std::vector results; + const auto tmp = MultiDotMatChainOrder( + ctx, ins, ins_dims, false, &results); + auto out_dim = out->dims(); + *out = tmp; + out->Resize(out_dim); + } + } +}; + +class MultiDotOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "multi_dot"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@GRAD", "multi_dot"); + + auto in_x = "X"; + auto out_x_g_n = framework::GradVarName(in_x); + auto ins_dims = ctx->GetInputsDim(in_x); + ctx->SetOutputsDim(out_x_g_n, ins_dims); + ctx->ShareAllLoD(in_x, out_x_g_n); + } +}; + +template +class MultiDotGradKernel : public framework::OpKernel { + public: + /** + * @brief calculate dA and dB + * dA = dout * transpose(B) + * dB = transpose(A) * dout + */ + void CalcGrad(const framework::ExecutionContext& ctx, + const framework::Tensor& dout, const framework::Tensor& A, + const framework::Tensor& B, const framework::DDim& dout_dim, + const framework::DDim& a_dim, const framework::DDim& b_dim, + framework::Tensor* dA, framework::Tensor* dB) const { + auto mat_dim_dout = math::CreateMatrixDescriptor(dout_dim, 0, false); + auto mat_dim_a = math::CreateMatrixDescriptor(a_dim, 0, true); + auto mat_dim_b = math::CreateMatrixDescriptor(b_dim, 0, true); + T alpha = static_cast(1.0); + auto blas = math::GetBlas(ctx); + blas.MatMul(A, mat_dim_a, dout, mat_dim_dout, alpha, dB, T(0)); + blas.MatMul(dout, 
mat_dim_dout, B, mat_dim_b, alpha, dA, T(0)); + } + + /** + * @brief calculate multi matrix multiplication grad by a chain order + * @param + * dout: the grad of multi matrix multiplication out + * dx: the out grad of inputs + * ins: the input tensors + * ins_dims: the shape of ins after reshape + * order: the optimal order + * i: the left of sub chain + * j: the righe of sub chain + * results: the intermediate result of farward + */ + void MatChainMulGrad(const framework::ExecutionContext& ctx, + const framework::Tensor& dout, + std::vector* dx, + const std::vector& ins, + const framework::DDim& dout_dim, + const std::vector& ins_dims, + const std::vector& order, const uint64_t i, + const uint64_t j, + const std::vector& results) const { + if (i == j) { + *((*dx)[i]) = dout; + return; + } + + const auto n = ins.size(); + const auto right = order[i * n + j]; + const auto left = order[i * n + j] + 1; + // get the multi result of left sub chain + const auto* A = &results[i * n + right]; + framework::DDim a_dim = A->dims(); + if (i == right) { + A = ins[i]; + a_dim = ins_dims[i]; + } + // get the multi result of right sub chain + const auto* B = &results[left * n + j]; + framework::DDim b_dim = B->dims(); + if (left == j) { + B = ins[j]; + b_dim = ins_dims[j]; + } + framework::Tensor dA, dB; + dA.Resize({dout_dim[0], b_dim[0]}); + dB.Resize({a_dim[1], dout_dim[1]}); + dA.mutable_data(ctx.GetPlace()); + dB.mutable_data(ctx.GetPlace()); + + CalcGrad(ctx, dout, *A, *B, dout_dim, a_dim, b_dim, &dA, &dB); + MatChainMulGrad(ctx, dA, dx, ins, dA.dims(), ins_dims, order, i, right, + results); + MatChainMulGrad(ctx, dB, dx, ins, dB.dims(), ins_dims, order, left, j, + results); + } + + void MultiDotGradMatChainOrder( + const framework::ExecutionContext& ctx, const framework::Tensor& dout, + const std::vector& ins, + const framework::DDim& dout_dim, + const std::vector& ins_dims, + std::vector* dx) const { + auto order = GetOrder(ins, ins_dims); + auto n = ins.size(); + std::vector results(n * n); + MatChainMul(ctx, ins, ins_dims, order, 0, n - 1, true, + &results); + MatChainMulGrad(ctx, dout, dx, ins, dout_dim, ins_dims, order, 0, n - 1, + results); + } + + void Compute(const framework::ExecutionContext& ctx) const { + auto ins = ctx.MultiInput("X"); + auto dout = *ctx.Input(framework::GradVarName("Out")); + auto dx = ctx.MultiOutput(framework::GradVarName("X")); + + auto blas = math::GetBlas(ctx); + auto place = ctx.GetPlace(); + + const auto n = ins.size(); + for (size_t i = 0; i < n; i++) { + dx[i]->mutable_data(place); + } + + std::vector ins_dims(n); + GetDims(ins, &ins_dims); + + framework::DDim dout_dim = dout.dims(); + if (ins[0]->dims().size() == 1 && ins[n - 1]->dims().size() == 1) { + dout_dim = framework::make_ddim({1, 1}); + } else if (ins[0]->dims().size() == 1) { + if (dout_dim.size() == 1) { + dout_dim = framework::make_ddim({1, dout_dim[0]}); + } + } else if (ins[n - 1]->dims().size() == 1) { + if (dout_dim.size() == 1) { + dout_dim = framework::make_ddim({dout_dim[0], 1}); + } + } + + T alpha = static_cast(1); + auto mat_dim_dout = math::CreateMatrixDescriptor(dout_dim, 0, false); + if (n == 2) { + CalcGrad(ctx, dout, *ins[0], *ins[1], dout_dim, ins_dims[0], ins_dims[1], + dx[0], dx[1]); + } else if (n == 3) { + const auto Ma = ins_dims[0][0]; + const auto Ka = ins_dims[0][1]; + const auto Nb = ins_dims[1][1]; + const auto Nc = ins_dims[2][1]; + const uint64_t cost1 = Ma * Nb * (Ka + Nc); + const uint64_t cost2 = Ka * Nc * (Nb + Ma); + auto mat_dim_a = 
math::CreateMatrixDescriptor(ins_dims[0], 0, false); + auto mat_dim_b = math::CreateMatrixDescriptor(ins_dims[1], 0, false); + auto mat_dim_c = math::CreateMatrixDescriptor(ins_dims[2], 0, false); + if (cost1 < cost2) { + framework::Tensor tmp_out, tmp_dout; + tmp_out.Resize({Ma, Nb}); + tmp_out.mutable_data(place); + tmp_dout.Resize({mat_dim_dout.height_, Nb}); + tmp_dout.mutable_data(place); + blas.MatMul(*ins[0], mat_dim_a, *ins[1], mat_dim_b, alpha, &tmp_out, + T(0)); + CalcGrad(ctx, dout, tmp_out, *ins[2], dout_dim, tmp_out.dims(), + ins_dims[2], &tmp_dout, dx[2]); + CalcGrad(ctx, tmp_dout, *ins[0], *ins[1], tmp_dout.dims(), ins_dims[0], + ins_dims[1], dx[0], dx[1]); + } else { + framework::Tensor tmp_out, tmp_dout; + tmp_out.Resize({Ka, Nc}); + tmp_out.mutable_data(place); + tmp_dout.Resize({Ka, mat_dim_dout.width_}); + tmp_dout.mutable_data(place); + blas.MatMul(*ins[1], mat_dim_b, *ins[2], mat_dim_c, alpha, &tmp_out, + T(0)); + CalcGrad(ctx, dout, *ins[0], tmp_out, dout_dim, ins_dims[0], + tmp_dout.dims(), dx[0], &tmp_dout); + CalcGrad(ctx, tmp_dout, *ins[1], *ins[2], tmp_dout.dims(), ins_dims[1], + ins_dims[2], dx[1], dx[2]); + } + } else { + MultiDotGradMatChainOrder(ctx, dout, ins, dout_dim, ins_dims, &dx); + if (ins[n - 1]->dims().size() == 1) { + dx[n - 1]->Resize({dx[n - 1]->dims()[0]}); + } + } + } +}; + +template +class MultiDotOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("multi_dot_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X", false)); + } +}; +template +class MultiDotOpDoubleGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr grad_op) const override { + grad_op->SetType("multi_dot"); + grad_op->SetInput("X", this->Input(("X"))); + grad_op->SetInput("DOut", this->Input(framework::GradVarName("Out"))); + grad_op->SetOutput("DDx", this->OutputGrad(framework::GradVarName("X"))); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(multi_dot, ops::MultiDotOp, ops::MultiDotOpMaker, + ops::MultiDotOpGradMaker, + ops::MultiDotOpGradMaker); +REGISTER_OPERATOR(multi_dot_grad, ops::MultiDotOpGrad, + ops::MultiDotOpDoubleGradMaker, + ops::MultiDotOpDoubleGradMaker); + +REGISTER_OP_CPU_KERNEL( + multi_dot, ops::MultiDotKernel, + ops::MultiDotKernel); +REGISTER_OP_CPU_KERNEL( + multi_dot_grad, + ops::MultiDotGradKernel, + ops::MultiDotGradKernel); + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +REGISTER_OP_CUDA_KERNEL( + multi_dot, ops::MultiDotKernel, + ops::MultiDotKernel, + ops::MultiDotKernel); +REGISTER_OP_CUDA_KERNEL( + multi_dot_grad, + ops::MultiDotGradKernel, + ops::MultiDotGradKernel, + ops::MultiDotGradKernel); +#endif diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index c0f8a1b19b..6d482785e9 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -99,6 +99,7 @@ from .tensor.linalg import cholesky # noqa: F401 from .tensor.linalg import bmm # noqa: F401 from .tensor.linalg import histogram # noqa: F401 from .tensor.linalg import mv # noqa: F401 +from .tensor.linalg import multi_dot # noqa: F401 from .tensor.linalg import matrix_power # noqa: F401 from .tensor.logic import equal # 
noqa: F401 from .tensor.logic import greater_equal # noqa: F401 diff --git a/python/paddle/fluid/tests/unittests/test_multi_dot_op.py b/python/paddle/fluid/tests/unittests/test_multi_dot_op.py new file mode 100644 index 0000000000..97047b1ae0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_multi_dot_op.py @@ -0,0 +1,263 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest, skip_check_grad_ci +from numpy.linalg import multi_dot +from op_test import OpTest +import paddle + +paddle.enable_static() + + +#the unittest of multi_dot +#compare the result of paddle multi_dot and numpy multi_dot +class TestMultiDotOp(OpTest): + def setUp(self): + self.op_type = "multi_dot" + self.dtype = self.get_dtype() + self.get_inputs_and_outputs() + + def get_dtype(self): + return "float64" + + def get_inputs_and_outputs(self): + self.A = np.random.random((2, 8)).astype(self.dtype) + self.B = np.random.random((8, 4)).astype(self.dtype) + self.inputs = {'X': [('x0', self.A), ('x1', self.B)]} + self.outputs = {'Out': multi_dot([self.A, self.B])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['x0'], 'Out') + self.check_grad(['x1'], 'Out') + + +#(A*B)*C +class TestMultiDotOp3Mat(TestMultiDotOp): + def get_inputs_and_outputs(self): + self.A = np.random.random((2, 10)).astype(self.dtype) + self.B = np.random.random((10, 4)).astype(self.dtype) + self.C = np.random.random((4, 3)).astype(self.dtype) + self.inputs = {'X': [('x0', self.A), ('x1', self.B), ('x2', self.C)]} + self.outputs = {'Out': multi_dot([self.A, self.B, self.C])} + + def test_check_grad(self): + self.check_grad(['x0'], 'Out') + self.check_grad(['x1'], 'Out') + self.check_grad(['x2'], 'Out') + + +#A*(B*C) +class TestMultiDotOp3Mat2(TestMultiDotOp): + def get_inputs_and_outputs(self): + self.A = np.random.random((3, 4)).astype(self.dtype) + self.B = np.random.random((4, 8)).astype(self.dtype) + self.C = np.random.random((8, 2)).astype(self.dtype) + self.inputs = {'X': [('x0', self.A), ('x1', self.B), ('x2', self.C)]} + self.outputs = {'Out': multi_dot([self.A, self.B, self.C])} + + def test_check_grad(self): + self.check_grad(['x0'], 'Out') + self.check_grad(['x1'], 'Out') + self.check_grad(['x2'], 'Out') + + +class TestMultiDotOp4Mat(TestMultiDotOp): + def get_inputs_and_outputs(self): + self.A = np.random.random((8, 6)).astype(self.dtype) + self.B = np.random.random((6, 3)).astype(self.dtype) + self.C = np.random.random((3, 4)).astype(self.dtype) + self.D = np.random.random((4, 5)).astype(self.dtype) + self.inputs = { + 'X': + [('x0', self.A), ('x1', self.B), ('x2', self.C), ('x3', self.D)] + } + self.outputs = {'Out': multi_dot([self.A, self.B, self.C, self.D])} + + def test_check_grad(self): + self.check_grad(['x0'], 'Out') + self.check_grad(['x1'], 'Out') + self.check_grad(['x2'], 'Out') + self.check_grad(['x3'], 'Out') + + +class 
TestMultiDotOpFirst1D(TestMultiDotOp): + def get_inputs_and_outputs(self): + self.A = np.random.random((4)).astype(self.dtype) + self.B = np.random.random((4, 3)).astype(self.dtype) + self.inputs = {'X': [('x0', self.A), ('x1', self.B)]} + self.outputs = {'Out': multi_dot([self.A, self.B])} + + +class TestMultiDotOp3MatFirst1D(TestMultiDotOp3Mat): + def get_inputs_and_outputs(self): + self.A = np.random.random((4)).astype(self.dtype) + self.B = np.random.random((4, 3)).astype(self.dtype) + self.C = np.random.random((3, 3)).astype(self.dtype) + self.inputs = {'X': [('x0', self.A), ('x1', self.B), ('x2', self.C)]} + self.outputs = {'Out': multi_dot([self.A, self.B, self.C])} + + +class TestMultiDotOp4MatFirst1D(TestMultiDotOp4Mat): + def get_inputs_and_outputs(self): + self.A = np.random.random((4)).astype(self.dtype) + self.B = np.random.random((4, 3)).astype(self.dtype) + self.C = np.random.random((3, 4)).astype(self.dtype) + self.D = np.random.random((4, 5)).astype(self.dtype) + self.inputs = { + 'X': + [('x0', self.A), ('x1', self.B), ('x2', self.C), ('x3', self.D)] + } + self.outputs = {'Out': multi_dot([self.A, self.B, self.C, self.D])} + + +class TestMultiDotOpLast1D(TestMultiDotOp): + def get_inputs_and_outputs(self): + self.A = np.random.random((3, 6)).astype(self.dtype) + self.B = np.random.random((6)).astype(self.dtype) + self.inputs = {'X': [('x0', self.A), ('x1', self.B)]} + self.outputs = {'Out': multi_dot([self.A, self.B])} + + +class TestMultiDotOp3MatLast1D(TestMultiDotOp3Mat): + def get_inputs_and_outputs(self): + self.A = np.random.random((2, 4)).astype(self.dtype) + self.B = np.random.random((4, 3)).astype(self.dtype) + self.C = np.random.random((3)).astype(self.dtype) + self.inputs = {'X': [('x0', self.A), ('x1', self.B), ('x2', self.C)]} + self.outputs = {'Out': multi_dot([self.A, self.B, self.C])} + + def test_check_grad(self): + self.check_grad(['x0'], 'Out') + self.check_grad(['x1'], 'Out') + self.check_grad(['x2'], 'Out') + + +class TestMultiDotOp4MatLast1D(TestMultiDotOp4Mat): + def get_inputs_and_outputs(self): + self.A = np.random.random((2, 3)).astype(self.dtype) + self.B = np.random.random((3, 2)).astype(self.dtype) + self.C = np.random.random((2, 3)).astype(self.dtype) + self.D = np.random.random((3)).astype(self.dtype) + self.inputs = { + 'X': + [('x0', self.A), ('x1', self.B), ('x2', self.C), ('x3', self.D)] + } + self.outputs = {'Out': multi_dot([self.A, self.B, self.C, self.D])} + + +class TestMultiDotOpFirstAndLast1D(TestMultiDotOp): + def get_inputs_and_outputs(self): + self.A = np.random.random((4, )).astype(self.dtype) + self.B = np.random.random((4)).astype(self.dtype) + self.inputs = {'X': [('x0', self.A), ('x1', self.B)]} + self.outputs = {'Out': multi_dot([self.A, self.B])} + + +class TestMultiDotOp3MatFirstAndLast1D(TestMultiDotOp3Mat): + def get_inputs_and_outputs(self): + self.A = np.random.random((6, )).astype(self.dtype) + self.B = np.random.random((6, 4)).astype(self.dtype) + self.C = np.random.random((4)).astype(self.dtype) + self.inputs = {'X': [('x0', self.A), ('x1', self.B), ('x2', self.C)]} + self.outputs = {'Out': multi_dot([self.A, self.B, self.C])} + + +class TestMultiDotOp4MatFirstAndLast1D(TestMultiDotOp4Mat): + def get_inputs_and_outputs(self): + self.A = np.random.random((3, )).astype(self.dtype) + self.B = np.random.random((3, 4)).astype(self.dtype) + self.C = np.random.random((4, 2)).astype(self.dtype) + self.D = np.random.random((2)).astype(self.dtype) + self.inputs = { + 'X': + [('x0', self.A), ('x1', self.B), ('x2', self.C), 
('x3', self.D)] + } + self.outputs = {'Out': multi_dot([self.A, self.B, self.C, self.D])} + + +#####python API test####### +class TestMultiDotOpError(unittest.TestCase): + def test_errors(self): + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + # The inputs type of multi_dot must be list matrix. + input1 = 12 + self.assertRaises(TypeError, paddle.multi_dot, [input1, input1]) + + # The inputs dtype of multi_dot must be float64, float64 or float16. + input2 = paddle.static.data( + name='input2', shape=[10, 10], dtype="int32") + self.assertRaises(TypeError, paddle.multi_dot, [input2, input2]) + + # the number of tensor must be larger than 1 + x0 = paddle.static.data(name='x0', shape=[3, 2], dtype="float64") + self.assertRaises(ValueError, paddle.multi_dot, [x0]) + + #the first tensor must be 1D or 2D + x1 = paddle.static.data(name='x1', shape=[3, 2, 3], dtype="float64") + x2 = paddle.static.data(name='x2', shape=[3, 2], dtype="float64") + self.assertRaises(ValueError, paddle.multi_dot, [x1, x2]) + + #the last tensor must be 1D or 2D + x3 = paddle.static.data(name='x3', shape=[3, 2], dtype="float64") + x4 = paddle.static.data(name='x4', shape=[3, 2, 2], dtype="float64") + self.assertRaises(ValueError, paddle.multi_dot, [x3, x4]) + + #the tensor must be 2D, except first and last tensor + x5 = paddle.static.data(name='x5', shape=[3, 2], dtype="float64") + x6 = paddle.static.data(name='x6', shape=[2], dtype="float64") + x7 = paddle.static.data(name='x7', shape=[2, 2], dtype="float64") + self.assertRaises(ValueError, paddle.multi_dot, [x5, x6, x7]) + + +class APITestMultiDot(unittest.TestCase): + def test_out(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x0 = paddle.static.data(name='x0', shape=[3, 2], dtype="float64") + x1 = paddle.static.data(name='x1', shape=[2, 3], dtype='float64') + result = paddle.multi_dot([x0, x1]) + exe = paddle.static.Executor(paddle.CPUPlace()) + data1 = np.random.rand(3, 2).astype("float64") + data2 = np.random.rand(2, 3).astype("float64") + np_res = exe.run(feed={'x0': data1, + 'x1': data2}, + fetch_list=[result]) + expected_result = np.linalg.multi_dot([data1, data2]) + + self.assertTrue( + np.allclose( + np_res, expected_result, atol=1e-5), + "two value is\ + {}\n{}, check diff!".format(np_res, expected_result)) + + def test_dygraph_without_out(self): + paddle.disable_static() + device = paddle.CPUPlace() + input_array1 = np.random.rand(3, 4).astype("float64") + input_array2 = np.random.rand(4, 3).astype("float64") + data1 = paddle.to_tensor(input_array1) + data2 = paddle.to_tensor(input_array2) + out = paddle.multi_dot([data1, data2]) + expected_result = np.linalg.multi_dot([input_array1, input_array2]) + self.assertTrue(np.allclose(expected_result, out.numpy())) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py b/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py index 15f28d94c7..626ea6c2ae 100644 --- a/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py @@ -28,4 +28,5 @@ NEED_TO_FIX_OP_LIST = [ 'cvm', 'cudnn_lstm', 'rnn', + 'multi_dot', ] diff --git a/python/paddle/linalg.py b/python/paddle/linalg.py index eabb017a0f..27dc2595bf 100644 --- a/python/paddle/linalg.py +++ b/python/paddle/linalg.py @@ -16,6 +16,7 @@ from .tensor.linalg import cholesky # noqa: F401 from 
.tensor.linalg import norm # noqa: F401 from .tensor.linalg import matrix_power # noqa: F401 from .tensor import inverse as inv # noqa: F401 +from .tensor.linalg import multi_dot # noqa: F401 from .tensor.linalg import matrix_rank from .tensor.linalg import svd @@ -23,6 +24,7 @@ __all__ = [ 'cholesky', #noqa 'norm', 'inv', + 'multi_dot', 'matrix_rank', 'svd', 'matrix_power' diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 19624cf6b8..fce4b764a8 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -45,6 +45,8 @@ from .linalg import bmm # noqa: F401 from .linalg import histogram # noqa: F401 from .linalg import mv # noqa: F401 from .linalg import matrix_power # noqa: F401 +from .linalg import multi_dot # noqa: F401 +from .linalg import svd # noqa: F401 from .logic import equal # noqa: F401 from .logic import greater_equal # noqa: F401 from .logic import greater_than # noqa: F401 diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index b50643471e..96a3610b18 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -789,25 +789,25 @@ def matrix_rank(x, tol=None, hermitian=False, name=None): r""" Computes the rank of a matrix. - The rank of a matrix is the number of singular values that are greater than the specified tol threshold when hermitian=False, + The rank of a matrix is the number of singular values that are greater than the specified tol threshold when hermitian=False, or the number of eigenvalues in absolute value that are greater than the specified tol threshold when hermitian=True. Args: - x (Tensor): The input tensor. - Its shape should be [..., m, n], where ... is zero or more batch dimensions. If x is a batch of matrices then the output - has the same batch dimensions. The data type of x should be float32 or float64. - tol (float,Tensor,optional): the tolerance value. Default: None. - If tol is not specified, and sigma is the largest singular value (or eigenvalue in absolute value), and eps is the - epsilon value for the dtype of x, then tol is computed with formula tol=sigma * max(m,n) * eps. Note that if x is + x (Tensor): The input tensor. + Its shape should be [..., m, n], where ... is zero or more batch dimensions. If x is a batch of matrices then the output + has the same batch dimensions. The data type of x should be float32 or float64. + tol (float,Tensor,optional): the tolerance value. Default: None. + If tol is not specified, and sigma is the largest singular value (or eigenvalue in absolute value), and eps is the + epsilon value for the dtype of x, then tol is computed with formula tol=sigma * max(m,n) * eps. Note that if x is a batch of matrices, tol is computed this way for every batch. hermitian (bool,optional): indicates whether x is Hermitian. Default: False. - When hermitian=True, x is assumed to be Hermitian, but x is not checked inside the function. Instead, We just use the + When hermitian=True, x is assumed to be Hermitian, but x is not checked inside the function. Instead, We just use the lower triangular of the matrix to compute. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: Tensor: Rank of tensor x. - + Examples: .. 
code-block:: python @@ -824,7 +824,7 @@ def matrix_rank(x, tol=None, hermitian=False, name=None): # d = [[1, 1, 1, 1], # [1, 1, 1, 1], # [1, 1, 1, 1]] - + """ if in_dygraph_mode(): @@ -1112,12 +1112,12 @@ def matrix_power(x, n, name=None): .. math:: Out = X ^ {n} - + Specifically, - If `n > 0`, it returns the matrix or a batch of matrices raised to the power of `n`. - + - If `n = 0`, it returns the identity matrix or a batch of identity matrices. - If `n < 0`, it returns the inverse of each matrix (if invertible) raised to @@ -1128,7 +1128,7 @@ def matrix_power(x, n, name=None): to power `n`. Its shape should be `[*, M, M]`, where `*` is zero or more batch dimensions. Its data type should be float32 or float64. n (int): The exponent. It can be any positive, negative integer or zero. - name (str, optional): Name for the operation (optional, default is None). + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: @@ -1171,3 +1171,83 @@ def matrix_power(x, n, name=None): outputs={'Out': out}, attrs={'n': n}) return out + + +def multi_dot(x, name=None): + """ + Multi_dot is an operator that calculates multiple matrix multiplications. + + Supports inputs of float, double and float16 dtypes. This function does not + support batched inputs. + + The input tensor in [x] must be 2-D except for the first and last can be 1-D. + If the first tensor is a 1-D vector of shape(n, ) it is treated as row vector + of shape(1, n), similarly if the last tensor is a 1D vector of shape(n, ), it + is treated as a column vector of shape(n, 1). + + If the first and last tensor are 2-D matrix, then the output is also 2-D matrix, + otherwise the output is a 1-D vector. + + Multi_dot will select the lowest cost multiplication order for calculation. The + cost of multiplying two matrices with shapes (a, b) and (b, c) is a * b * c. + Given matrices A, B, C with shapes (20, 5), (5, 100), (100, 10) respectively, + we can calculate the cost of different multiplication orders as follows: + - Cost((AB)C) = 20x5x100 + 20x100x10 = 30000 + - Cost(A(BC)) = 5x100x10 + 20x5x10 = 6000 + + In this case, multiplying B and C first, then multiply A, which is 5 times faster + than sequential calculation. + + Args: + x ([Tensor]): The input tensors which is a list Tensor. + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + Tensor: The output Tensor. + + + Examples: + + .. 
code-block:: python + + import paddle + import numpy as np + + # A * B + A_data = np.random.random([3, 4]).astype(np.float32) + B_data = np.random.random([4, 5]).astype(np.float32) + A = paddle.to_tensor(A_data) + B = paddle.to_tensor(B_data) + out = paddle.multi_dot([A, B]) + print(out.numpy().shape) + # [3, 5] + + # A * B * C + A_data = np.random.random([10, 5]).astype(np.float32) + B_data = np.random.random([5, 8]).astype(np.float32) + C_data = np.random.random([8, 7]).astype(np.float32) + A = paddle.to_tensor(A_data) + B = paddle.to_tensor(B_data) + C = paddle.to_tensor(C_data) + out = paddle.multi_dot([A, B, C]) + print(out.numpy().shape) + # [10, 7] + + """ + if in_dygraph_mode(): + return _C_ops.multi_dot(x) + + check_type(x, 'x', (list, tuple), 'multi_dot') + for id, item in enumerate(x): + check_variable_and_dtype(item, 'x[' + str(id) + ']', + ['float16', 'float32', 'float64'], 'multi_dot') + if item.dtype != x[0].dtype: + raise TypeError( + "All the Tensors in the input must have the same data type.") + + helper = LayerHelper('multi_dot', **locals()) + dtype = helper.input_dtype(input_param_name='x') + out = helper.create_variable_for_type_inference(dtype) + helper.append_op(type='multi_dot', inputs={"X": x}, outputs={"Out": out}) + return out -- GitLab
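As a quick illustration of the cost model described in the multi_dot docstring and implemented by GetOrder above, the same matrix-chain dynamic program can be sketched in a few lines of plain NumPy. This is only an illustrative sketch, not code from the patch; the helper name chain_order and the printed shapes are made up here, and numpy.linalg.multi_dot (the reference the unit tests compare against) is used for the numeric cross-check.

import numpy as np

def chain_order(dims):
    # dims[i] is the (rows, cols) shape of matrix i, so matrix i is (p[i], p[i+1]).
    p = [d[0] for d in dims] + [dims[-1][1]]
    n = len(dims)
    m = [[0] * n for _ in range(n)]      # m[i][j]: lowest cost of multiplying ins[i...j]
    order = [[0] * n for _ in range(n)]  # order[i][j]: best split point k
    for length in range(1, n):
        for i in range(n - length):
            j = i + length
            m[i][j] = float("inf")
            for k in range(i, j):
                q = m[i][k] + m[k + 1][j] + p[i] * p[k + 1] * p[j + 1]
                if q < m[i][j]:
                    m[i][j] = q
                    order[i][j] = k
    return m, order

# Shapes from the docstring example: A is (20, 5), B is (5, 100), C is (100, 10).
m, order = chain_order([(20, 5), (5, 100), (100, 10)])
print(m[0][2], order[0][2])  # 6000 0  -> compute B @ C first, i.e. A @ (B @ C)

# numpy.linalg.multi_dot agrees numerically with the cheaper ordering.
A, B, C = np.random.rand(20, 5), np.random.rand(5, 100), np.random.rand(100, 10)
print(np.allclose(np.linalg.multi_dot([A, B, C]), A @ (B @ C)))  # True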