From 6bd2762cbd9170fffa93b30815b46f542c99ae2b Mon Sep 17 00:00:00 2001 From: Jiabin Yang <360788950@qq.com> Date: Wed, 14 Sep 2022 17:19:31 +0800 Subject: [PATCH] [PHI] Support bmm and bmm_grad in xpu (#45887) * support bmm and bmm_grad in xpu * add error removal * test=kunlun * refactor code for better structure * test=kunlun * add fp16 kernel for bmm * test=kunlun --- paddle/fluid/eager/utils.h | 2 +- paddle/fluid/operators/bmm_op_xpu.cc | 226 ---------------------- paddle/phi/kernels/xpu/bmm_grad_kernel.cc | 107 ++++++++++ paddle/phi/kernels/xpu/bmm_kernel.cc | 80 ++++++++ paddle/phi/kernels/xpu/bmm_xpu_utils.h | 64 ++++++ 5 files changed, 252 insertions(+), 227 deletions(-) delete mode 100644 paddle/fluid/operators/bmm_op_xpu.cc create mode 100644 paddle/phi/kernels/xpu/bmm_grad_kernel.cc create mode 100644 paddle/phi/kernels/xpu/bmm_kernel.cc create mode 100644 paddle/phi/kernels/xpu/bmm_xpu_utils.h diff --git a/paddle/fluid/eager/utils.h b/paddle/fluid/eager/utils.h index a42b118771..e82d8d03a0 100644 --- a/paddle/fluid/eager/utils.h +++ b/paddle/fluid/eager/utils.h @@ -223,7 +223,7 @@ class EagerUtils { const std::vector& out_var, std::vector* result); - // end Intermidate needed + // end Intermidate needed. static void CheckAndRetainGrad(const paddle::experimental::Tensor& tensor); static void CheckAndRetainGrad( diff --git a/paddle/fluid/operators/bmm_op_xpu.cc b/paddle/fluid/operators/bmm_op_xpu.cc deleted file mode 100644 index f6e1d0227c..0000000000 --- a/paddle/fluid/operators/bmm_op_xpu.cc +++ /dev/null @@ -1,226 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifdef PADDLE_WITH_XPU - -#include -#include - -#include "paddle/fluid/operators/matmul_v2_op.h" -#include "paddle/fluid/operators/xpu_api_wrapper.h" -#include "paddle/fluid/platform/device/device_wrapper.h" - -namespace paddle { -namespace operators { - -template -static void MatMulXPUFunction(const Tensor* x, - const Tensor* y, - Tensor* out, - bool trans_x, - bool trans_y, - const paddle::framework::ExecutionContext& ctx) { - using XPUType = typename XPUTypeTrait::Type; - const auto& x_dims = x->dims(); - const auto& y_dims = y->dims(); - auto& dev_ctx = - ctx.template device_context(); - - auto mat_dim_a = phi::funcs::CreateMatrixDescriptor( - RowMatrixFromVector(x_dims), 0, trans_x); - auto mat_dim_b = phi::funcs::CreateMatrixDescriptor( - ColumnMatrixFromVector(y_dims), 0, trans_y); - - T* data_c = out->data(); - int m = mat_dim_a.height_; - int n = mat_dim_b.width_; - int k = mat_dim_a.width_; - int batch_size = mat_dim_a.batch_size_; - // batch matmul - int r = xpu::fc_batched( - dev_ctx.x_context(), // Context* ctx, - batch_size, // int batch_size, - mat_dim_a.trans_, // bool x_trans, - mat_dim_b.trans_, // bool w_trans, - m, // int m, - n, // int n, - k, // int k, - 1.0, // float alpha, - reinterpret_cast(x->data()), // const TX* x, - mat_dim_a.stride_, // int stride_a, - reinterpret_cast(y->data()), // const TW* w, - mat_dim_b.stride_, // int stride_b, - 0.0, // float beta, - reinterpret_cast(data_c), // TY* y, - m * n, // int stride_c, - nullptr, // const float* x_maxptr, - nullptr); // const float* w_maxptr - - PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_batched"); -} - -template -class BmmXPUKernel : public framework::OpKernel { - public: - void Compute(const paddle::framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* out = ctx.Output("Out"); - out->mutable_data(ctx.GetPlace()); - - if (x->numel() == 0 || y->numel() == 0) { - return; - } - bool trans_x = false; - bool trans_y = false; - - auto x_dims = x->dims(); - auto y_dims = y->dims(); - - PADDLE_ENFORCE_EQ(x_dims.size(), - 3, - platform::errors::InvalidArgument( - "Input(X) of BmmOp must be 3-dimensional in BmmOp, " - "but received X's shape: [%s].", - x_dims)); - PADDLE_ENFORCE_EQ(y_dims.size(), - 3, - platform::errors::InvalidArgument( - "Input(Y) of BmmOp must be 3-dimensional in BmmOp, " - "but received Y's shape: [%s].", - y_dims)); - PADDLE_ENFORCE_EQ( - x_dims[0], - y_dims[0], - platform::errors::InvalidArgument( - "Input(X) and Input(Y) must have the same batch size in BmmOp, " - "but received X's batch size: [%s]," - "Y's batch size [%s]", - x_dims[0], - y_dims[0])); - PADDLE_ENFORCE_EQ( - x_dims[2], - y_dims[1], - platform::errors::InvalidArgument( - "Input(X)'s width must be equal with Input(Y)'s height in BmmOp," - "but receive X's width: [%s]," - "Y's height: [%s].", - x_dims[2], - y_dims[1])); - - if (std::is_same::value) { - MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); - } else { - if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) { - MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); - } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) { - MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); - } else { - MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); - } - } - } -}; - -template -class BmmXPUGradKernel : public framework::OpKernel { - public: - void MatMul(const framework::ExecutionContext& ctx, - const framework::Tensor& a, - bool trans_a, - const framework::Tensor& b, - bool trans_b, - framework::Tensor* out) const { - out->mutable_data(ctx.GetPlace()); - if (std::is_same::value) { - MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); - } else { - if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) { - MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); - } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) { - MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); - } else { - MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); - } - } - } - - void CalcInputGrad(const framework::ExecutionContext& context, - const framework::Tensor& a, - bool trans_a, - const framework::Tensor& b, - bool trans_b, - framework::Tensor* out) const { - if (out == nullptr) return; - MatMul(context, a, trans_a, b, trans_b, out); - } - - void Compute(const framework::ExecutionContext& context) const override { - auto x = *context.Input("X"); - auto y = *context.Input("Y"); - auto dout = - *context.Input(framework::GradVarName("Out")); - auto* dx = context.Output(framework::GradVarName("X")); - auto* dy = context.Output(framework::GradVarName("Y")); - ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, false, false); - - framework::DDim dx_dims; - if (dx) { - dx_dims = dx->dims(); - if (dx_dims != x.dims()) { - dx->Resize(x.dims()); - } - } - - framework::DDim dy_dims; - if (dy) { - dy_dims = dy->dims(); - if (dy_dims != y.dims()) { - dy->Resize(y.dims()); - } - } - - CalcInputGrad(context, dout, false, y, true, dx); - CalcInputGrad(context, x, true, dout, false, dy); - - // CalcInputGrad(context, dout, false, false, y, true, false, dx); - // CalcInputGrad(context, x, true, true, dout, false, true, dy); - - if (dx) { - if (dx_dims != x.dims()) { - dx->Resize(dx_dims); - } - } - - if (dy) { - if (dy_dims != y.dims()) { - dy->Resize(dy_dims); - } - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -REGISTER_OP_XPU_KERNEL(bmm, - ops::BmmXPUKernel, - ops::BmmXPUKernel); -REGISTER_OP_XPU_KERNEL(bmm_grad, - ops::BmmXPUGradKernel, - ops::BmmXPUGradKernel); - -#endif diff --git a/paddle/phi/kernels/xpu/bmm_grad_kernel.cc b/paddle/phi/kernels/xpu/bmm_grad_kernel.cc new file mode 100644 index 0000000000..246da888d2 --- /dev/null +++ b/paddle/phi/kernels/xpu/bmm_grad_kernel.cc @@ -0,0 +1,107 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/bmm_grad_kernel.h" + +#include "paddle/phi/kernels/xpu/bmm_xpu_utils.h" + +namespace phi { + +template +void MatMul(const Context& dev_ctx, + const DenseTensor& a, + bool trans_a, + const DenseTensor& b, + bool trans_b, + DenseTensor* out) { + dev_ctx.template Alloc(out); + xpu::Context* xpu_ctx = dev_ctx.x_context(); + if (std::is_same::value) { + MatMulXPUFunction(a, b, out, trans_a, trans_b, xpu_ctx); + } else { + if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) { + MatMulXPUFunction(a, b, out, trans_a, trans_b, xpu_ctx); + } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) { + MatMulXPUFunction(a, b, out, trans_a, trans_b, xpu_ctx); + } else { + MatMulXPUFunction(a, b, out, trans_a, trans_b, xpu_ctx); + } + } +} + +template +void CalcInputGrad(const Context& dev_ctx, + const DenseTensor& a, + bool trans_a, + const DenseTensor& b, + bool trans_b, + DenseTensor* out) { + if (out == nullptr) return; + MatMul(dev_ctx, a, trans_a, b, trans_b, out); +} + +template +void BmmGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + DenseTensor* x_grad, + DenseTensor* y_grad) { + DenseTensor x_help = x; + DenseTensor y_help = y; + DenseTensor out_grad_help = out_grad; + ReshapeXYOutIntoMatrixSequence( + &x_help, &y_help, &out_grad_help, false, false); + + phi::DDim dx_dims; + if (x_grad) { + dx_dims = x_grad->dims(); + if (dx_dims != x_help.dims()) { + x_grad->Resize(x_help.dims()); + } + } + + phi::DDim dy_dims; + if (y_grad) { + dy_dims = y_grad->dims(); + if (dy_dims != y_help.dims()) { + y_grad->Resize(y_help.dims()); + } + } + + CalcInputGrad( + dev_ctx, out_grad_help, false, y_help, true, x_grad); + CalcInputGrad( + dev_ctx, x_help, true, out_grad_help, false, y_grad); + + if (x_grad) { + if (dx_dims != x_help.dims()) { + x_grad->Resize(dx_dims); + } + } + if (y_grad) { + if (dy_dims != y_help.dims()) { + y_grad->Resize(dy_dims); + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(bmm_grad, + XPU, + ALL_LAYOUT, + phi::BmmGradKernel, + float, + paddle::platform::float16) {} diff --git a/paddle/phi/kernels/xpu/bmm_kernel.cc b/paddle/phi/kernels/xpu/bmm_kernel.cc new file mode 100644 index 0000000000..b75383bbaa --- /dev/null +++ b/paddle/phi/kernels/xpu/bmm_kernel.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/bmm_kernel.h" +#include "paddle/phi/kernels/xpu/bmm_xpu_utils.h" +namespace phi { +template +void BmmKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + dev_ctx.template Alloc(out); + if (x.numel() == 0 || y.numel() == 0) { + return; + } + bool trans_x = false; + bool trans_y = false; + + auto x_dims = x.dims(); + auto y_dims = y.dims(); + + PADDLE_ENFORCE_EQ(x_dims.size(), + 3, + phi::errors::InvalidArgument( + "Input(X) of BmmOp must be 3-dimensional in BmmOp, " + "but received X's shape: [%s]", + x_dims)); + PADDLE_ENFORCE_EQ(y_dims.size(), + 3, + phi::errors::InvalidArgument( + "Input(Y) of BmmOp must be 3-dimensional in BmmOp, " + "but received Y's shape: [%s].", + y_dims)); + PADDLE_ENFORCE_EQ( + x_dims[0], + y_dims[0], + phi::errors::InvalidArgument( + "Input(X) and Input(Y) must have the same batch size in BmmOp, " + "but received X's batch size: [%s]," + "Y's batch size [%s]", + x_dims[0], + y_dims[0])); + PADDLE_ENFORCE_EQ( + x_dims[2], + y_dims[1], + phi::errors::InvalidArgument( + "Input(X)'s width must be equal with Input(Y)'s height in BmmOp," + "but receive X's width: [%s]," + "Y's height: [%s].", + x_dims[2], + y_dims[1])); + + xpu::Context* xpu_ctx = dev_ctx.x_context(); + if (std::is_same::value) { + MatMulXPUFunction(x, y, out, trans_x, trans_y, xpu_ctx); + } else { + if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) { + MatMulXPUFunction(x, y, out, trans_x, trans_y, xpu_ctx); + } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) { + MatMulXPUFunction(x, y, out, trans_x, trans_y, xpu_ctx); + } else { + MatMulXPUFunction(x, y, out, trans_x, trans_y, xpu_ctx); + } + } +} +} // namespace phi + +PD_REGISTER_KERNEL( + bmm, XPU, ALL_LAYOUT, phi::BmmKernel, float, paddle::platform::float16) {} diff --git a/paddle/phi/kernels/xpu/bmm_xpu_utils.h b/paddle/phi/kernels/xpu/bmm_xpu_utils.h new file mode 100644 index 0000000000..f0ac5c7e14 --- /dev/null +++ b/paddle/phi/kernels/xpu/bmm_xpu_utils.h @@ -0,0 +1,64 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/matmul_grad_kernel_impl.h" +#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h" + +namespace phi { +template +static void MatMulXPUFunction(const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out, + bool trans_x, + bool trans_y, + xpu::Context* xpu_ctx) { + using XPUType = typename XPUTypeTrait::Type; + const auto& x_dims = x.dims(); + const auto& y_dims = y.dims(); + + auto mat_dim_a = phi::funcs::CreateMatrixDescriptor( + RowMatrixFromVector(x_dims), 0, trans_x); + auto mat_dim_b = phi::funcs::CreateMatrixDescriptor( + ColumnMatrixFromVector(y_dims), 0, trans_y); + + T* data_c = out->data(); + int m = mat_dim_a.height_; + int n = mat_dim_b.width_; + int k = mat_dim_a.width_; + int batch_size = mat_dim_a.batch_size_; + // batch matmul + int r = xpu::fc_batched( + xpu_ctx, // Context* ctx, + batch_size, // int batch_size, + mat_dim_a.trans_, // bool x_trans, + mat_dim_b.trans_, // bool w_trans, + m, // int m, + n, // int n, + k, // int k, + 1.0, // float alpha, + reinterpret_cast(x.data()), // const TX* x, + mat_dim_a.stride_, // int stride_a, + reinterpret_cast(y.data()), // const TW* w, + mat_dim_b.stride_, // int stride_b, + 0.0, // float beta, + reinterpret_cast(data_c), // TY* y, + m * n, // int stride_c, + nullptr, // const float* x_maxptr, + nullptr); // const float* w_maxptr + + PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_batched"); +} +} // namespace phi -- GitLab