From fd36ede6d89c1d5397e6b351e020ffbbad0ed6a7 Mon Sep 17 00:00:00 2001 From: Liu-xiandong <85323580+Liu-xiandong@users.noreply.github.com> Date: Mon, 7 Mar 2022 12:15:34 +0800 Subject: [PATCH] [phi] move multi_dot OP (#40038) * [phi] move multi_dot OP * fix the segment bug * fix bug * delete useless comment * fix CI bug --- paddle/fluid/operators/multi_dot_op.cc | 397 --------------- .../phi/kernels/cpu/multi_dot_grad_kernel.cc | 22 + paddle/phi/kernels/cpu/multi_dot_kernel.cc | 22 + .../phi/kernels/gpu/multi_dot_grad_kernel.cu | 30 ++ paddle/phi/kernels/gpu/multi_dot_kernel.cu | 25 + .../phi/kernels/impl/multi_dot_kernel_impl.h | 456 ++++++++++++++++++ paddle/phi/kernels/multi_dot_grad_kernel.h | 27 ++ paddle/phi/kernels/multi_dot_kernel.h | 26 + paddle/phi/ops/compat/multi_dot_sig.cc | 27 ++ 9 files changed, 635 insertions(+), 397 deletions(-) create mode 100644 paddle/phi/kernels/cpu/multi_dot_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/multi_dot_kernel.cc create mode 100644 paddle/phi/kernels/gpu/multi_dot_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/multi_dot_kernel.cu create mode 100644 paddle/phi/kernels/impl/multi_dot_kernel_impl.h create mode 100644 paddle/phi/kernels/multi_dot_grad_kernel.h create mode 100644 paddle/phi/kernels/multi_dot_kernel.h create mode 100644 paddle/phi/ops/compat/multi_dot_sig.cc diff --git a/paddle/fluid/operators/multi_dot_op.cc b/paddle/fluid/operators/multi_dot_op.cc index fe4609b3ad9..b309e1b87ef 100644 --- a/paddle/fluid/operators/multi_dot_op.cc +++ b/paddle/fluid/operators/multi_dot_op.cc @@ -87,135 +87,6 @@ inline framework::DDim ComputeAndCheckShape( return out_dim; } -template -inline framework::Tensor MatMul(const framework::ExecutionContext& ctx, - const framework::Tensor& matrix_a, - const framework::Tensor& matrix_b, - const framework::DDim& a_dim, - const framework::DDim& b_dim) { - auto place = ctx.GetPlace(); - auto blas = phi::funcs::GetBlas(ctx); - - framework::Tensor matrix_c; - framework::DDim c_dim = phi::make_ddim({a_dim[0], b_dim[1]}); - matrix_c.Resize(c_dim); - matrix_c.mutable_data(place); - - auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a_dim, 0, false); - auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b_dim, 0, false); - const T alpha = static_cast(1.0); - blas.MatMul(matrix_a, mat_dim_a, matrix_b, mat_dim_b, alpha, &matrix_c, T(0)); - return matrix_c; -} - -/** - * @brief Recursively calculate matrix multiplication according to the optimal - * order - * Let k = order[i,j], then ins[i...j] = ins[i...k] * ins[k+1 ...j] - * - * @param - * ins: the input tensors - * ins_dims: the shape of ins after reshape - * order: the optimal order - * i: the left of sub chain - * j: the righe of sub chain - * save_result: set true by backward - * results: save the intermediate result during backward - */ -template -inline framework::Tensor MatChainMul( - const framework::ExecutionContext& ctx, - const std::vector& ins, - const std::vector& ins_dims, - const std::vector& order, const uint64_t i, const uint64_t j, - const bool save_result, std::vector* results) { - if (i == j) { - return *ins[i]; - } - - const auto A = MatChainMul(ctx, ins, ins_dims, order, i, - order[i * ins.size() + j], - save_result, results); - framework::DDim a_dim = A.dims(); - if (i == order[i * ins.size() + j]) { - a_dim = ins_dims[i]; - } - - const auto B = MatChainMul(ctx, ins, ins_dims, order, - order[i * ins.size() + j] + 1, j, - save_result, results); - framework::DDim b_dim = B.dims(); - if (j == order[i * ins.size() + j] + 1) { - b_dim = ins_dims[j]; - } - - auto result = MatMul(ctx, A, B, a_dim, b_dim); - if (save_result) { - (*results)[i * ins.size() + j] = result; - } - return result; -} - -/** - * @brief get the optimal order - */ -std::vector GetOrder(const std::vector& ins, - const std::vector& ins_dims) { - auto n = ins.size(); - // p: save the ins shape, the ins[i] shape is (p[i], p[i+1]) - std::vector p(n + 1); - for (uint64_t i = 0; i < n; i++) { - p[i] = ins_dims[i][0]; - } - p[n] = ins_dims[n - 1][1]; - - // m[i, j]: save the lowest cost for multiplying ins[i...j] - std::vector m(n * n, 0); - // define ins[i...j] means multiplying matrices from ins[i] to ins[j] - // order[i, j] = k, this means that ins[i...k] and ins[k...j] fist and then - // multiply the resulting matrices is the optimal order for ins[i...j] - std::vector order(n * n); - for (uint64_t l = 1; l < n; l++) { - for (uint64_t i = 0; i < n - l; i++) { - auto j = i + l; - m[i * n + j] = 0xffffffff; - for (uint64_t k = i; k < j; k++) { - uint64_t q = - m[i * n + k] + m[(k + 1) * n + j] + p[i] * p[k + 1] * p[j + 1]; - if (q < m[i * n + j]) { - m[i * n + j] = q; - order[i * n + j] = k; - } - } - } - } - return order; -} - -template -static inline framework::Tensor MultiDotMatChainOrder( - const framework::ExecutionContext& ctx, - const std::vector& ins, - const std::vector& ins_dims, const bool save_result, - std::vector* results) { - auto order = GetOrder(ins, ins_dims); - return MatChainMul(ctx, ins, ins_dims, order, 0, - ins.size() - 1, save_result, results); -} - -inline void GetDims(const std::vector& ins, - std::vector* ins_dims) { - const auto n = ins.size(); - for (size_t i = 0; i < n; i++) { - (*ins_dims)[i] = ins[i]->dims(); - if (i == 0 && (*ins_dims)[i].size() == 1) { - (*ins_dims)[i] = phi::make_ddim({1, (*ins_dims)[i][0]}); - } else if (i == n - 1 && (*ins_dims)[i].size() == 1) { - (*ins_dims)[i] = phi::make_ddim({(*ins_dims)[i][0], 1}); - } - } -} - class MultiDotOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -252,78 +123,6 @@ class MultiDotOp : public framework::OperatorWithKernel { } }; -/** - * 1. there are only 2 matrices: direct matrix multiplication A*B - * 2. there are only 3 matrices: calculate the cost of (A*B)*C and A*(B*C), - * choose the least cost order for calculation - * 3. more than 3 matrices: call MultiDotMatChainOrder - */ -template -class MultiDotKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto ins = ctx.MultiInput("X"); - auto* out = ctx.Output("Out"); - - auto place = ctx.GetPlace(); - out->mutable_data(place); - - auto blas = phi::funcs::GetBlas(ctx); - - auto n = ins.size(); - std::vector ins_dims(n); - GetDims(ins, &ins_dims); - - const T scale = static_cast(1.0); - if (n == 2) { - auto mat_dim_a = - phi::funcs::CreateMatrixDescriptor(ins_dims[0], 0, false); - auto mat_dim_b = - phi::funcs::CreateMatrixDescriptor(ins_dims[1], 0, false); - blas.MatMul(*ins[0], mat_dim_a, *ins[1], mat_dim_b, scale, out, T(0)); - } else if (n == 3) { - const auto Ma = ins_dims[0][0]; - const auto Ka = ins_dims[0][1]; - const auto Nb = ins_dims[1][1]; - const auto Nc = ins_dims[2][1]; - const uint64_t cost1 = Ma * Nb * (Ka + Nc); - const uint64_t cost2 = Ka * Nc * (Nb + Ma); - auto mat_dim_a = - phi::funcs::CreateMatrixDescriptor(ins_dims[0], 0, false); - auto mat_dim_b = - phi::funcs::CreateMatrixDescriptor(ins_dims[1], 0, false); - auto mat_dim_c = - phi::funcs::CreateMatrixDescriptor(ins_dims[2], 0, false); - if (cost1 < cost2) { - framework::Tensor tmp_out; - tmp_out.mutable_data(place, Ma * Nb * sizeof(T)); - framework::DDim tmp_dim = phi::make_ddim({Ma, Nb}); - blas.MatMul(*ins[0], mat_dim_a, *ins[1], mat_dim_b, scale, &tmp_out, - T(0)); - auto mat_dim_tmp = - phi::funcs::CreateMatrixDescriptor(tmp_dim, 0, false); - blas.MatMul(tmp_out, mat_dim_tmp, *ins[2], mat_dim_c, scale, out, T(0)); - } else { - framework::Tensor tmp_out; - tmp_out.mutable_data(place, Ka * Nc * sizeof(T)); - framework::DDim tmp_dim = phi::make_ddim({Ka, Nc}); - blas.MatMul(*ins[1], mat_dim_b, *ins[2], mat_dim_c, scale, &tmp_out, - T(0)); - auto mat_dim_tmp = - phi::funcs::CreateMatrixDescriptor(tmp_dim, 0, false); - blas.MatMul(*ins[0], mat_dim_a, tmp_out, mat_dim_tmp, scale, out, T(0)); - } - } else { - std::vector results; - const auto tmp = MultiDotMatChainOrder( - ctx, ins, ins_dims, false, &results); - auto out_dim = out->dims(); - *out = tmp; - out->Resize(out_dim); - } - } -}; - class MultiDotOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -341,180 +140,6 @@ class MultiDotOpGrad : public framework::OperatorWithKernel { } }; -template -class MultiDotGradKernel : public framework::OpKernel { - public: - /** - * @brief calculate dA and dB - * dA = dout * transpose(B) - * dB = transpose(A) * dout - */ - void CalcGrad(const framework::ExecutionContext& ctx, - const framework::Tensor& dout, const framework::Tensor& A, - const framework::Tensor& B, const framework::DDim& dout_dim, - const framework::DDim& a_dim, const framework::DDim& b_dim, - framework::Tensor* dA, framework::Tensor* dB) const { - auto mat_dim_dout = phi::funcs::CreateMatrixDescriptor(dout_dim, 0, false); - auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a_dim, 0, true); - auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b_dim, 0, true); - T alpha = static_cast(1.0); - auto blas = phi::funcs::GetBlas(ctx); - blas.MatMul(A, mat_dim_a, dout, mat_dim_dout, alpha, dB, T(0)); - blas.MatMul(dout, mat_dim_dout, B, mat_dim_b, alpha, dA, T(0)); - } - - /** - * @brief calculate multi matrix multiplication grad by a chain order - * @param - * dout: the grad of multi matrix multiplication out - * dx: the out grad of inputs - * ins: the input tensors - * ins_dims: the shape of ins after reshape - * order: the optimal order - * i: the left of sub chain - * j: the righe of sub chain - * results: the intermediate result of farward - */ - void MatChainMulGrad(const framework::ExecutionContext& ctx, - const framework::Tensor& dout, - std::vector* dx, - const std::vector& ins, - const framework::DDim& dout_dim, - const std::vector& ins_dims, - const std::vector& order, const uint64_t i, - const uint64_t j, - const std::vector& results) const { - if (i == j) { - *((*dx)[i]) = dout; - return; - } - - const auto n = ins.size(); - const auto right = order[i * n + j]; - const auto left = order[i * n + j] + 1; - // get the multi result of left sub chain - const auto* A = &results[i * n + right]; - framework::DDim a_dim = A->dims(); - if (i == right) { - A = ins[i]; - a_dim = ins_dims[i]; - } - // get the multi result of right sub chain - const auto* B = &results[left * n + j]; - framework::DDim b_dim = B->dims(); - if (left == j) { - B = ins[j]; - b_dim = ins_dims[j]; - } - framework::Tensor dA, dB; - dA.Resize({dout_dim[0], b_dim[0]}); - dB.Resize({a_dim[1], dout_dim[1]}); - dA.mutable_data(ctx.GetPlace()); - dB.mutable_data(ctx.GetPlace()); - - CalcGrad(ctx, dout, *A, *B, dout_dim, a_dim, b_dim, &dA, &dB); - MatChainMulGrad(ctx, dA, dx, ins, dA.dims(), ins_dims, order, i, right, - results); - MatChainMulGrad(ctx, dB, dx, ins, dB.dims(), ins_dims, order, left, j, - results); - } - - void MultiDotGradMatChainOrder( - const framework::ExecutionContext& ctx, const framework::Tensor& dout, - const std::vector& ins, - const framework::DDim& dout_dim, - const std::vector& ins_dims, - std::vector* dx) const { - auto order = GetOrder(ins, ins_dims); - auto n = ins.size(); - std::vector results(n * n); - MatChainMul(ctx, ins, ins_dims, order, 0, n - 1, true, - &results); - MatChainMulGrad(ctx, dout, dx, ins, dout_dim, ins_dims, order, 0, n - 1, - results); - } - - void Compute(const framework::ExecutionContext& ctx) const { - auto ins = ctx.MultiInput("X"); - auto dout = *ctx.Input(framework::GradVarName("Out")); - auto dx = ctx.MultiOutput(framework::GradVarName("X")); - - auto blas = phi::funcs::GetBlas(ctx); - auto place = ctx.GetPlace(); - - const auto n = ins.size(); - for (size_t i = 0; i < n; i++) { - dx[i]->mutable_data(place); - } - - std::vector ins_dims(n); - GetDims(ins, &ins_dims); - - framework::DDim dout_dim = dout.dims(); - if (ins[0]->dims().size() == 1 && ins[n - 1]->dims().size() == 1) { - dout_dim = phi::make_ddim({1, 1}); - } else if (ins[0]->dims().size() == 1) { - if (dout_dim.size() == 1) { - dout_dim = phi::make_ddim({1, dout_dim[0]}); - } - } else if (ins[n - 1]->dims().size() == 1) { - if (dout_dim.size() == 1) { - dout_dim = phi::make_ddim({dout_dim[0], 1}); - } - } - - T alpha = static_cast(1); - auto mat_dim_dout = phi::funcs::CreateMatrixDescriptor(dout_dim, 0, false); - if (n == 2) { - CalcGrad(ctx, dout, *ins[0], *ins[1], dout_dim, ins_dims[0], ins_dims[1], - dx[0], dx[1]); - } else if (n == 3) { - const auto Ma = ins_dims[0][0]; - const auto Ka = ins_dims[0][1]; - const auto Nb = ins_dims[1][1]; - const auto Nc = ins_dims[2][1]; - const uint64_t cost1 = Ma * Nb * (Ka + Nc); - const uint64_t cost2 = Ka * Nc * (Nb + Ma); - auto mat_dim_a = - phi::funcs::CreateMatrixDescriptor(ins_dims[0], 0, false); - auto mat_dim_b = - phi::funcs::CreateMatrixDescriptor(ins_dims[1], 0, false); - auto mat_dim_c = - phi::funcs::CreateMatrixDescriptor(ins_dims[2], 0, false); - if (cost1 < cost2) { - framework::Tensor tmp_out, tmp_dout; - tmp_out.Resize({Ma, Nb}); - tmp_out.mutable_data(place); - tmp_dout.Resize({mat_dim_dout.height_, Nb}); - tmp_dout.mutable_data(place); - blas.MatMul(*ins[0], mat_dim_a, *ins[1], mat_dim_b, alpha, &tmp_out, - T(0)); - CalcGrad(ctx, dout, tmp_out, *ins[2], dout_dim, tmp_out.dims(), - ins_dims[2], &tmp_dout, dx[2]); - CalcGrad(ctx, tmp_dout, *ins[0], *ins[1], tmp_dout.dims(), ins_dims[0], - ins_dims[1], dx[0], dx[1]); - } else { - framework::Tensor tmp_out, tmp_dout; - tmp_out.Resize({Ka, Nc}); - tmp_out.mutable_data(place); - tmp_dout.Resize({Ka, mat_dim_dout.width_}); - tmp_dout.mutable_data(place); - blas.MatMul(*ins[1], mat_dim_b, *ins[2], mat_dim_c, alpha, &tmp_out, - T(0)); - CalcGrad(ctx, dout, *ins[0], tmp_out, dout_dim, ins_dims[0], - tmp_dout.dims(), dx[0], &tmp_dout); - CalcGrad(ctx, tmp_dout, *ins[1], *ins[2], tmp_dout.dims(), ins_dims[1], - ins_dims[2], dx[1], dx[2]); - } - } else { - MultiDotGradMatChainOrder(ctx, dout, ins, dout_dim, ins_dims, &dx); - if (ins[n - 1]->dims().size() == 1) { - dx[n - 1]->Resize({dx[n - 1]->dims()[0]}); - } - } - } -}; - template class MultiDotOpGradMaker : public framework::SingleGradOpMaker { public: @@ -552,25 +177,3 @@ REGISTER_OPERATOR(multi_dot, ops::MultiDotOp, ops::MultiDotOpMaker, REGISTER_OPERATOR(multi_dot_grad, ops::MultiDotOpGrad, ops::MultiDotOpDoubleGradMaker, ops::MultiDotOpDoubleGradMaker); - -REGISTER_OP_CPU_KERNEL( - multi_dot, ops::MultiDotKernel, - ops::MultiDotKernel); -REGISTER_OP_CPU_KERNEL( - multi_dot_grad, - ops::MultiDotGradKernel, - ops::MultiDotGradKernel); - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -REGISTER_OP_CUDA_KERNEL( - multi_dot, ops::MultiDotKernel, - ops::MultiDotKernel, - ops::MultiDotKernel); -REGISTER_OP_CUDA_KERNEL( - multi_dot_grad, - ops::MultiDotGradKernel, - ops::MultiDotGradKernel, - ops::MultiDotGradKernel); -#endif diff --git a/paddle/phi/kernels/cpu/multi_dot_grad_kernel.cc b/paddle/phi/kernels/cpu/multi_dot_grad_kernel.cc new file mode 100644 index 00000000000..2cd75404be8 --- /dev/null +++ b/paddle/phi/kernels/cpu/multi_dot_grad_kernel.cc @@ -0,0 +1,22 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/multi_dot_grad_kernel.h" +#include "paddle/phi/kernels/impl/multi_dot_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL( + multi_dot_grad, CPU, ALL_LAYOUT, phi::MultiDotGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/multi_dot_kernel.cc b/paddle/phi/kernels/cpu/multi_dot_kernel.cc new file mode 100644 index 00000000000..a4249a98e46 --- /dev/null +++ b/paddle/phi/kernels/cpu/multi_dot_kernel.cc @@ -0,0 +1,22 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/multi_dot_kernel.h" +#include "paddle/phi/kernels/impl/multi_dot_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL( + multi_dot, CPU, ALL_LAYOUT, phi::MultiDotKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/multi_dot_grad_kernel.cu b/paddle/phi/kernels/gpu/multi_dot_grad_kernel.cu new file mode 100644 index 00000000000..6761d945e95 --- /dev/null +++ b/paddle/phi/kernels/gpu/multi_dot_grad_kernel.cu @@ -0,0 +1,30 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/impl/multi_dot_kernel_impl.h" +#include "paddle/phi/kernels/multi_dot_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" + +using float16 = phi::dtype::float16; + +PD_REGISTER_KERNEL(multi_dot_grad, + GPU, + ALL_LAYOUT, + phi::MultiDotGradKernel, + float, + double, + float16) {} diff --git a/paddle/phi/kernels/gpu/multi_dot_kernel.cu b/paddle/phi/kernels/gpu/multi_dot_kernel.cu new file mode 100644 index 00000000000..60b1fce5ddd --- /dev/null +++ b/paddle/phi/kernels/gpu/multi_dot_kernel.cu @@ -0,0 +1,25 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/impl/multi_dot_kernel_impl.h" +#include "paddle/phi/kernels/multi_dot_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" + +using float16 = phi::dtype::float16; + +PD_REGISTER_KERNEL( + multi_dot, GPU, ALL_LAYOUT, phi::MultiDotKernel, float, double, float16) {} diff --git a/paddle/phi/kernels/impl/multi_dot_kernel_impl.h b/paddle/phi/kernels/impl/multi_dot_kernel_impl.h new file mode 100644 index 00000000000..0833e94fe2c --- /dev/null +++ b/paddle/phi/kernels/impl/multi_dot_kernel_impl.h @@ -0,0 +1,456 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" + +namespace phi { + +template +inline DenseTensor MatMul(const Context& ctx, + const DenseTensor& matrix_a, + const DenseTensor& matrix_b, + const phi::DDim& a_dim, + const phi::DDim& b_dim) { + auto blas = phi::funcs::GetBlas(ctx); + + DenseTensor matrix_c; + phi::DDim c_dim = phi::make_ddim({a_dim[0], b_dim[1]}); + matrix_c.Resize(c_dim); + ctx.template Alloc(&matrix_c); + + auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a_dim, 0, false); + auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b_dim, 0, false); + const T alpha = static_cast(1.0); + blas.MatMul(matrix_a.data(), + mat_dim_a, + matrix_b.data(), + mat_dim_b, + alpha, + matrix_c.data(), + T(0)); + return matrix_c; +} + +/** + * @brief Recursively calculate matrix multiplication according to the optimal + * order + * Let k = order[i,j], then ins[i...j] = ins[i...k] * ins[k+1 ...j] + * + * @param + * ins: the input tensors + * ins_dims: the shape of ins after reshape + * order: the optimal order + * i: the left of sub chain + * j: the righe of sub chain + * save_result: set true by backward + * results: save the intermediate result during backward + */ +template +inline DenseTensor MatChainMul(const Context& ctx, + const std::vector& ins, + const std::vector& ins_dims, + const std::vector& order, + const uint64_t i, + const uint64_t j, + const bool save_result, + std::vector* results) { + if (i == j) { + return *ins[i]; + } + + const auto A = MatChainMul(ctx, + ins, + ins_dims, + order, + i, + order[i * ins.size() + j], + save_result, + results); + phi::DDim a_dim = A.dims(); + if (i == order[i * ins.size() + j]) { + a_dim = ins_dims[i]; + } + + const auto B = MatChainMul(ctx, + ins, + ins_dims, + order, + order[i * ins.size() + j] + 1, + j, + save_result, + results); + phi::DDim b_dim = B.dims(); + if (j == order[i * ins.size() + j] + 1) { + b_dim = ins_dims[j]; + } + + auto result = MatMul(ctx, A, B, a_dim, b_dim); + if (save_result) { + (*results)[i * ins.size() + j] = result; + } + return result; +} + +/** + * @brief get the optimal order + */ +template +std::vector GetOrder(const std::vector& ins, + const std::vector& ins_dims) { + auto n = ins.size(); + // p: save the ins shape, the ins[i] shape is (p[i], p[i+1]) + std::vector p(n + 1); + for (uint64_t i = 0; i < n; i++) { + p[i] = ins_dims[i][0]; + } + p[n] = ins_dims[n - 1][1]; + + // m[i, j]: save the lowest cost for multiplying ins[i...j] + std::vector m(n * n, 0); + // define ins[i...j] means multiplying matrices from ins[i] to ins[j] + // order[i, j] = k, this means that ins[i...k] and ins[k...j] fist and then + // multiply the resulting matrices is the optimal order for ins[i...j] + std::vector order(n * n); + for (uint64_t l = 1; l < n; l++) { + for (uint64_t i = 0; i < n - l; i++) { + auto j = i + l; + m[i * n + j] = 0xffffffff; + for (uint64_t k = i; k < j; k++) { + uint64_t q = + m[i * n + k] + m[(k + 1) * n + j] + p[i] * p[k + 1] * p[j + 1]; + if (q < m[i * n + j]) { + m[i * n + j] = q; + order[i * n + j] = k; + } + } + } + } + return order; +} + +template +static inline DenseTensor MultiDotMatChainOrder( + const Context& ctx, + const std::vector& ins, + const std::vector& ins_dims, + const bool save_result, + std::vector* results) { + auto order = GetOrder(ins, ins_dims); + return MatChainMul( + ctx, ins, ins_dims, order, 0, ins.size() - 1, save_result, results); +} + +template +inline void GetDims(const std::vector& ins, + std::vector* ins_dims) { + const auto n = ins.size(); + for (size_t i = 0; i < n; i++) { + (*ins_dims)[i] = ins[i]->dims(); + if (i == 0 && (*ins_dims)[i].size() == 1) { + (*ins_dims)[i] = phi::make_ddim({1, (*ins_dims)[i][0]}); + } else if (i == n - 1 && (*ins_dims)[i].size() == 1) { + (*ins_dims)[i] = phi::make_ddim({(*ins_dims)[i][0], 1}); + } + } +} + +template +void MultiDotKernel(const Context& ctx, + const std::vector& x, + DenseTensor* out) { + auto ins = x; + ctx.template Alloc(out); + + auto blas = phi::funcs::GetBlas(ctx); + + auto n = ins.size(); + std::vector ins_dims(n); + GetDims(ins, &ins_dims); + + const T scale = static_cast(1.0); + if (n == 2) { + auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(ins_dims[0], 0, false); + auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(ins_dims[1], 0, false); + blas.MatMul(*ins[0], mat_dim_a, *ins[1], mat_dim_b, scale, out, T(0)); + } else if (n == 3) { + const auto Ma = ins_dims[0][0]; + const auto Ka = ins_dims[0][1]; + const auto Nb = ins_dims[1][1]; + const auto Nc = ins_dims[2][1]; + const uint64_t cost1 = Ma * Nb * (Ka + Nc); + const uint64_t cost2 = Ka * Nc * (Nb + Ma); + auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(ins_dims[0], 0, false); + auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(ins_dims[1], 0, false); + auto mat_dim_c = phi::funcs::CreateMatrixDescriptor(ins_dims[2], 0, false); + if (cost1 < cost2) { + DenseTensor tmp_out; + phi::DDim tmp_dim = phi::make_ddim({Ma, Nb}); + tmp_out.Resize(tmp_dim); + ctx.template Alloc(&tmp_out); + blas.MatMul( + *ins[0], mat_dim_a, *ins[1], mat_dim_b, scale, &tmp_out, T(0)); + auto mat_dim_tmp = phi::funcs::CreateMatrixDescriptor(tmp_dim, 0, false); + blas.MatMul(tmp_out, mat_dim_tmp, *ins[2], mat_dim_c, scale, out, T(0)); + } else { + DenseTensor tmp_out; + phi::DDim tmp_dim = phi::make_ddim({Ka, Nc}); + tmp_out.Resize(tmp_dim); + ctx.template Alloc(&tmp_out); + std::cout << tmp_out << std::endl; + blas.MatMul( + *ins[1], mat_dim_b, *ins[2], mat_dim_c, scale, &tmp_out, T(0)); + auto mat_dim_tmp = phi::funcs::CreateMatrixDescriptor(tmp_dim, 0, false); + blas.MatMul(*ins[0], mat_dim_a, tmp_out, mat_dim_tmp, scale, out, T(0)); + } + } else { + std::vector results; + const auto tmp = + MultiDotMatChainOrder(ctx, ins, ins_dims, false, &results); + auto out_dim = out->dims(); + *out = tmp; + out->Resize(out_dim); + } +} + +/** + * @brief calculate dA and dB + * dA = dout * transpose(B) + * dB = transpose(A) * dout + */ +template +void CalcGrad(const Context& ctx, + const DenseTensor& dout, + const DenseTensor& A, + const DenseTensor& B, + const phi::DDim& dout_dim, + const phi::DDim& a_dim, + const phi::DDim& b_dim, + DenseTensor* dA, + DenseTensor* dB) { + auto mat_dim_dout = phi::funcs::CreateMatrixDescriptor(dout_dim, 0, false); + auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a_dim, 0, true); + auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b_dim, 0, true); + T alpha = static_cast(1.0); + auto blas = phi::funcs::GetBlas(ctx); + blas.MatMul(A, mat_dim_a, dout, mat_dim_dout, alpha, dB, T(0)); + blas.MatMul(dout, mat_dim_dout, B, mat_dim_b, alpha, dA, T(0)); +} + +/** + * @brief calculate multi matrix multiplication grad by a chain order + * @param + * dout: the grad of multi matrix multiplication out + * dx: the out grad of inputs + * ins: the input tensors + * ins_dims: the shape of ins after reshape + * order: the optimal order + * i: the left of sub chain + * j: the righe of sub chain + * results: the intermediate result of farward + */ +template +void MatChainMulGrad(const Context& ctx, + const DenseTensor& dout, + std::vector* dx, + const std::vector& ins, + const phi::DDim& dout_dim, + const std::vector& ins_dims, + const std::vector& order, + const uint64_t i, + const uint64_t j, + const std::vector& results) { + if (i == j) { + *((*dx)[i]) = dout; + return; + } + + const auto n = ins.size(); + const auto right = order[i * n + j]; + const auto left = order[i * n + j] + 1; + // get the multi result of left sub chain + const auto* A = &results[i * n + right]; + phi::DDim a_dim = A->dims(); + if (i == right) { + A = ins[i]; + a_dim = ins_dims[i]; + } + // get the multi result of right sub chain + const auto* B = &results[left * n + j]; + phi::DDim b_dim = B->dims(); + if (left == j) { + B = ins[j]; + b_dim = ins_dims[j]; + } + DenseTensor dA, dB; + dA.Resize({dout_dim[0], b_dim[0]}); + dB.Resize({a_dim[1], dout_dim[1]}); + ctx.template Alloc(&dA); + ctx.template Alloc(&dB); + + CalcGrad(ctx, dout, *A, *B, dout_dim, a_dim, b_dim, &dA, &dB); + MatChainMulGrad( + ctx, dA, dx, ins, dA.dims(), ins_dims, order, i, right, results); + MatChainMulGrad( + ctx, dB, dx, ins, dB.dims(), ins_dims, order, left, j, results); +} + +template +void MultiDotGradMatChainOrder(const Context& ctx, + const DenseTensor& dout, + const std::vector& ins, + const phi::DDim& dout_dim, + const std::vector& ins_dims, + std::vector* dx) { + auto order = GetOrder(ins, ins_dims); + auto n = ins.size(); + std::vector results(n * n); + MatChainMul(ctx, ins, ins_dims, order, 0, n - 1, true, &results); + MatChainMulGrad( + ctx, dout, dx, ins, dout_dim, ins_dims, order, 0, n - 1, results); +} + +template +void MultiDotGradKernel(const Context& ctx, + const DenseTensor& out_grad, + const std::vector& x, + std::vector x_grad) { + auto ins = x; + auto dout = out_grad; + auto dx = x_grad; + + auto blas = phi::funcs::GetBlas(ctx); + + const auto n = ins.size(); + for (size_t i = 0; i < n; i++) { + ctx.template Alloc(dx[i]); + } + + std::vector ins_dims(n); + GetDims(ins, &ins_dims); + + phi::DDim dout_dim = dout.dims(); + if (ins[0]->dims().size() == 1 && ins[n - 1]->dims().size() == 1) { + dout_dim = phi::make_ddim({1, 1}); + } else if (ins[0]->dims().size() == 1) { + if (dout_dim.size() == 1) { + dout_dim = phi::make_ddim({1, dout_dim[0]}); + } + } else if (ins[n - 1]->dims().size() == 1) { + if (dout_dim.size() == 1) { + dout_dim = phi::make_ddim({dout_dim[0], 1}); + } + } + + T alpha = static_cast(1); + auto mat_dim_dout = phi::funcs::CreateMatrixDescriptor(dout_dim, 0, false); + if (n == 2) { + CalcGrad(ctx, + dout, + *ins[0], + *ins[1], + dout_dim, + ins_dims[0], + ins_dims[1], + dx[0], + dx[1]); + } else if (n == 3) { + const auto Ma = ins_dims[0][0]; + const auto Ka = ins_dims[0][1]; + const auto Nb = ins_dims[1][1]; + const auto Nc = ins_dims[2][1]; + const uint64_t cost1 = Ma * Nb * (Ka + Nc); + const uint64_t cost2 = Ka * Nc * (Nb + Ma); + auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(ins_dims[0], 0, false); + auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(ins_dims[1], 0, false); + auto mat_dim_c = phi::funcs::CreateMatrixDescriptor(ins_dims[2], 0, false); + if (cost1 < cost2) { + DenseTensor tmp_out, tmp_dout; + tmp_out.Resize({Ma, Nb}); + ctx.template Alloc(&tmp_out); + tmp_dout.Resize({mat_dim_dout.height_, Nb}); + ctx.template Alloc(&tmp_dout); + blas.MatMul( + *ins[0], mat_dim_a, *ins[1], mat_dim_b, alpha, &tmp_out, T(0)); + CalcGrad(ctx, + dout, + tmp_out, + *ins[2], + dout_dim, + tmp_out.dims(), + ins_dims[2], + &tmp_dout, + dx[2]); + CalcGrad(ctx, + tmp_dout, + *ins[0], + *ins[1], + tmp_dout.dims(), + ins_dims[0], + ins_dims[1], + dx[0], + dx[1]); + } else { + DenseTensor tmp_out, tmp_dout; + tmp_out.Resize({Ka, Nc}); + ctx.template Alloc(&tmp_out); + tmp_dout.Resize({Ka, mat_dim_dout.width_}); + ctx.template Alloc(&tmp_dout); + blas.MatMul( + *ins[1], mat_dim_b, *ins[2], mat_dim_c, alpha, &tmp_out, T(0)); + CalcGrad(ctx, + dout, + *ins[0], + tmp_out, + dout_dim, + ins_dims[0], + tmp_dout.dims(), + dx[0], + &tmp_dout); + CalcGrad(ctx, + tmp_dout, + *ins[1], + *ins[2], + tmp_dout.dims(), + ins_dims[1], + ins_dims[2], + dx[1], + dx[2]); + } + } else { + MultiDotGradMatChainOrder( + ctx, dout, ins, dout_dim, ins_dims, &dx); + if (ins[n - 1]->dims().size() == 1) { + dx[n - 1]->Resize({dx[n - 1]->dims()[0]}); + } + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/multi_dot_grad_kernel.h b/paddle/phi/kernels/multi_dot_grad_kernel.h new file mode 100644 index 00000000000..e6d8ecd744e --- /dev/null +++ b/paddle/phi/kernels/multi_dot_grad_kernel.h @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void MultiDotGradKernel(const Context& ctx, + const DenseTensor& out_grad, + const std::vector& x, + std::vector x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/multi_dot_kernel.h b/paddle/phi/kernels/multi_dot_kernel.h new file mode 100644 index 00000000000..09866e8dde5 --- /dev/null +++ b/paddle/phi/kernels/multi_dot_kernel.h @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void MultiDotKernel(const Context& ctx, + const std::vector& x, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/multi_dot_sig.cc b/paddle/phi/ops/compat/multi_dot_sig.cc new file mode 100644 index 00000000000..598cbd980f3 --- /dev/null +++ b/paddle/phi/ops/compat/multi_dot_sig.cc @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature MultiDotGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "multi_dot_grad", {GradVarName("Out"), "X"}, {}, {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(multi_dot_grad, phi::MultiDotGradOpArgumentMapping); -- GitLab